├── .gitignore ├── LICENSE ├── README.md ├── latency.py ├── misc.py ├── model.py ├── mxnet_mobilenet.py ├── train_final.py ├── train_nas.py └── viterbi.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project-specific 2 | output 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # PyCharm 104 | .idea 105 | 106 | # PyTorch weights 107 | *.tar 108 | *.pth 109 | 
*.gz 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Maxim Berman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AOWS 2 | **AOWS: Adaptive and optimal network width search with latency constraints**, Maxim Berman, Leonid Pishchulin, Ning Xu, Matthew B. Blaschko, Gérard Medioni, _NAS workshop @ ICLR 2020_ and _CVPR 2020 (oral)_. 
3 | 4 | 5 | AOWS-teaser 6 | 7 | ## Usage 8 | 9 | ### Latency model 10 | _main file: `latency.py`_ 11 | 12 | _depends on: [PyTorch](http://pytorch.org/), [CVXPY](https://www.cvxpy.org/), matplotlib, numpy, [numba](http://numba.pydata.org/), scipy, [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt/tree/e22844a449a880123435fce7e6444f1516ebbe60)_ 13 | 14 | 15 | * Generate training and validation samples 16 | ```Shell 17 | python latency.py generate --device trt --dtype fp16 \ 18 | --biased --count 8000 output/samples_trt16.jsonl 19 | python latency.py generate --device trt --dtype fp16 \ 20 | --count 200 output/val_trt16.jsonl 21 | ``` 22 | * Fit the model (`K` controls the amount of regularization and should be set by validation) 23 | ```Shell 24 | python latency.py fit output/samples_trt16.jsonl \ 25 | output/model_trt16_K100.0.jsonl -K 100.0 26 | ``` 27 | * Validate the model (produces a plot) 28 | ```Shell 29 | python latency.py validate output/val_trt16.jsonl \ 30 | output/model_trt16_K100.0.jsonl output/correlation_plot.png 31 | ``` 32 | 33 | Additionally, one can benchmark a single configuration with this script using e.g. 34 | ```Shell 35 | python latency.py benchmark --device trt --dtype fp16 \ 36 | "(16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304)" 37 | ``` 38 | 39 | ### Network width search 40 | _main file: `train_nas.py`_ 41 | 42 | _depends on: [PyTorch](http://pytorch.org/), numpy, [numba](http://numba.pydata.org/)_ 43 | * **Train a slimmable network and select a configuration with OWS.** See `-h` for optimization options. 44 | ```Shell 45 | python train_nas.py --data /imagenet --latency-target 0.04 \ 46 | --latency-model output/model_trt16_K100.0.jsonl \ 47 | --expname output/ows-trt16-0.04 --resume-last 48 | ``` 49 | In OWS, the latency target `--latency-target` can be changed during or after training. 
Using the parameter `--resume-last` allows resuming from the last checkpoint without having to retrain, which makes it possible to vary the latency target. 50 | 51 | _Implementation detail: for ease of implementation we use here a fixed moving average with a window of `--window=100000` samples for each unary weight, while in the article we used the statistics available over one full last epoch._ 52 | 53 | * **Train a slimmable network with AOWS.** See `-h` for optimization options. The outputs, including the best configuration for each epoch, are put in the directory corresponding to the parameter `--expname`. 54 | ```Shell 55 | python train_nas.py --data /imagenet --latency-target 0.04 \ 56 | --latency-model output/model_trt16_K100.0.jsonl \ 57 | --AOWS --expname output/aows-trt16-0.04 --resume-last 58 | ``` 59 | In AOWS, the latency target `--latency-target` should be set at the beginning of the training, since it impacts the training. 60 | 61 | 62 | ### Training the final model 63 | _main file: `train_final.py`_ 64 | 65 | _depends on: [mxnet](https://mxnet.apache.org/)_ 66 | 67 | Modified version of gluon-cv's [train_imagenet.py](https://github.com/dmlc/gluon-cv/blob/18f8ab526ffb97660e6e5661f991064c20e2699d/scripts/classification/imagenet/train_imagenet.py) for training mobilenet-v1 with varying channel numbers. Refer to gluon-cv's documentation for detailed usage. 
68 | 69 | Example command: 70 | ``` 71 | python train_final.py \ 72 | --rec-train /imagenet/imagenet_train.rec \ 73 | --rec-train-idx /imagenet/imagenet_train.idx \ 74 | --rec-val /ramdisk/imagenet_val.rec \ 75 | --rec-val-idx /ramdisk/imagenet_val.idx \ 76 | --use-rec --mode hybrid --lr 0.4 --lr-mode cosine \ 77 | --num-epochs 200 --batch-size 256 -j 32 --num-gpus 4 \ 78 | --dtype float16 --warmup-epochs 5 --no-wd \ 79 | --label-smoothing --mixup \ 80 | --save-dir params_mymobilenet --logging-file mymobilenet.log \ 81 | --configuration "(16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304)" 82 | ``` 83 | 84 | ## Citation 85 | ```BibTeX 86 | @InProceedings{Berman2020AOWS, 87 | author = {Berman, Maxim and Pishchulin, Leonid and Xu, Ning and Blaschko, Matthew B. and Medioni, Gerard}, 88 | title = {{AOWS}: adaptive and optimal network width search with latency constraints}, 89 | booktitle = {Proceedings of the {IEEE} Computer Society Conference on Computer Vision and Pattern Recognition}, 90 | month = jun, 91 | year = {2020}, 92 | } 93 | ``` 94 | 95 | ## Disclaimer 96 | The code was re-implemented and is not fully tested at this point. 
97 | 98 | -------------------------------------------------------------------------------- /latency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import os.path as osp 6 | import random 7 | import sys 8 | import time 9 | from collections import Counter 10 | from collections import defaultdict 11 | from collections import namedtuple 12 | 13 | import cvxpy as cp 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import torch 17 | from scipy.sparse import lil_matrix 18 | # lazy: from torch2trt import torch2trt 19 | 20 | from misc import DelayedKeyboardInterrupt 21 | from misc import tuplify 22 | from model import SlimMobilenet 23 | from model import LayerType 24 | from viterbi import complete 25 | from viterbi import maxsum 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | Vartype = namedtuple("Vartype", LayerType._fields + ('in_channels', 'out_channels')) 30 | torch.backends.cudnn.benchmark = True 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser( 35 | description="Generate samples and fit a latency model.") 36 | 37 | subparsers = parser.add_subparsers(dest='mode') 38 | subparsers.required = True 39 | 40 | parser_bench = subparsers.add_parser('benchmark', 41 | help="Benchmark a single channel configuration", 42 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 43 | parser_bench.add_argument("configuration", 44 | help="configuration to test (comma-separated channels or MOBILENET)") 45 | 46 | parser_gen = subparsers.add_parser('generate', 47 | help="Generate latency samples", 48 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | for subparser in (parser_bench, parser_gen): 50 | subparser.add_argument("-D", "--device", choices=["cpu", "gpu", "trt"], 51 | default="gpu", help="Use GPU, CPU or TensorRT latency") 52 | subparser.add_argument("--dtype", choices=["fp32", "fp16"], 53 | default="fp16", 
help="Datatype for network") 54 | subparser.add_argument("-B", "--batch-size", type=int, default=64, 55 | help="Batch size used for profiling") 56 | subparser.add_argument("-I", "--iterations", type=int, default=60, 57 | help="Profiling iterations") 58 | subparser.add_argument("-W", "--warmup", type=int, default=10, 59 | help="Warmup iterations") 60 | subparser.add_argument("--reduction", choices=['mean', 'min'], default='mean', 61 | help="Reduce timings by their mean or by their minimum (minimum can reduce variance)") 62 | parser_gen.add_argument("--biased", action="store_true", 63 | help="Bias sampling towards missing configurations") 64 | parser_gen.add_argument("-N", "--count", type=int, default=8000, 65 | help="Minimum number of samples to generate") 66 | parser_gen.add_argument("-R", "--repetitions", type=int, default=0, 67 | help="Minimum number of samples per choice") 68 | parser_gen.add_argument("--save-every", type=int, default=1, 69 | help="Number of inferences before saving intermediate output") 70 | parser_gen.add_argument("samples_file", help="Output samples file") 71 | 72 | parser_fit = subparsers.add_parser('fit', help="Fit a latency model") 73 | parser_fit.add_argument("-K", "--regularize", type=float, default=0.0, 74 | help="Amount of monotonicity regularization (Equation 7)") 75 | parser_fit.add_argument("samples_file", help="Training samples") 76 | parser_fit.add_argument("model_file", help="Output model file") 77 | 78 | parser_val = subparsers.add_parser('validate', help="Validate a latency model") 79 | parser_val.add_argument("samples_file", help="Validation samples") 80 | parser_val.add_argument("model_file", help="Model file") 81 | parser_val.add_argument("plot_file", help="Plot file") 82 | 83 | args = parser.parse_args() 84 | 85 | if 'configuration' in args: 86 | defaults = {'MOBILENET': "32,64,128,128,256,256,512,512,512,512,512,512,1024,1024"} 87 | if args.configuration in defaults: 88 | args.configuration = defaults[args.configuration] 
def benchmark(device, dtype, batch_size, iterations, warmup, reduction, configuration, silent=False):
    """Measure the per-sample inference latency of one channel configuration.

    Args:
        device: 'cpu', 'gpu' or 'trt'. 'gpu' and 'trt' both run on CUDA; 'trt'
            additionally converts the network with torch2trt.
        dtype: 'fp16' or 'fp32', the datatype of the network and its input.
        batch_size: number of 3x224x224 inputs per forward pass.
        iterations: number of timed forward passes.
        warmup: number of untimed forward passes executed first.
        reduction: name of the numpy function (e.g. 'mean' or 'min') used to
            reduce the list of per-iteration timings.
        configuration: channel numbers handed to ``SlimMobilenet.reduce``.
        silent: when true, do not print the result.

    Returns:
        The reduced time of one forward pass, in milliseconds, divided by
        ``batch_size``.

    Raises:
        ValueError: if ``device`` or ``dtype`` is not one of the supported names.
    """
    torch_devices = {'cpu': 'cpu', 'gpu': 'cuda', 'trt': 'cuda'}
    torch_dtypes = {'fp16': torch.float16, 'fp32': torch.float32}
    if device not in torch_devices:
        raise ValueError(f"unknown device {device!r}, expected one of {sorted(torch_devices)}")
    if dtype not in torch_dtypes:
        raise ValueError(f"unknown dtype {dtype!r}, expected one of {sorted(torch_dtypes)}")
    dev = torch.device(torch_devices[device])
    fp = torch_dtypes[dtype]
    on_cuda = dev.type == 'cuda'

    def synchronize():
        # CUDA kernels are asynchronous: wait for them before reading the clock.
        # On CPU there is nothing to wait for, and touching torch.cuda would
        # fail on machines without a GPU.
        if on_cuda:
            torch.cuda.current_stream().synchronize()

    net = SlimMobilenet.reduce(configuration).to(dev).type(fp).eval()
    x = torch.ones((batch_size, 3, 224, 224)).to(dev).type(fp)
    if device == 'trt':
        from torch2trt import torch2trt
        net = torch2trt(net, [x], fp16_mode=(dtype == 'fp16'), max_batch_size=batch_size)

    timings = []
    # Pure inference: no autograd bookkeeping should be part of the measurement.
    with torch.no_grad():
        for _ in range(warmup):
            net(x)
            synchronize()

        # perf_counter is monotonic and high-resolution, unlike the wall clock.
        t0 = time.perf_counter()
        for _ in range(iterations):
            net(x)
            synchronize()
            t1 = time.perf_counter()
            timings.append(t1 - t0)
            t0 = t1

    ms = 1000.0 * getattr(np, reduction)(timings) / batch_size
    if not silent:
        print(f"{configuration}: {ms}ms")

    return ms
def _append_samples(samples_file, batch):
    """Append every sample of ``batch`` to ``samples_file`` as one JSON line."""
    with open(samples_file, 'a') as f:
        for sample in batch:
            dump = json.dumps(sample) + '\n'
            # SIGINT handling is postponed until the line has been written.
            with DelayedKeyboardInterrupt():
                f.write(dump)


def generate(device, dtype, batch_size, iterations, warmup, reduction, biased,
             count, repetitions, samples_file=os.devnull, save_every=10):
    """Benchmark random configurations and append the timings to a samples file.

    Samples already present in ``samples_file`` are loaded first and count
    towards the stopping criteria. New configurations are benchmarked until
    there are at least ``count`` samples in total and every latency-model
    variable has been covered at least ``repetitions`` times.

    Args:
        device, dtype, batch_size, iterations, warmup, reduction: forwarded
            to ``benchmark``.
        biased: forwarded to ``gen_configuration`` to choose between biased
            and uniform sampling.
        count: minimum total number of samples.
        repetitions: minimum number of occurrences of each variable.
        samples_file: JSON-lines file the samples are appended to.
        save_every: number of new samples buffered between two writes.

    Returns:
        List of all samples (loaded and new), each ``(configuration, ms)``.

    Raises:
        ValueError: if ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError("save_every must be a positive integer")
    out_dir = osp.dirname(samples_file)
    if out_dir:  # a bare filename has an empty dirname, which os.makedirs rejects
        os.makedirs(out_dir, exist_ok=True)

    net = get_model()
    combinations = collect_repetitions(net)
    logger.info(f"{len(net.configurations)} modulers")
    logger.debug(f"search space: {net.configurations}")
    logger.debug(f"components: {net.components}")
    logger.info(f"Latency model has {len(combinations)} parameters")

    # Start every variable at zero so that min() also sees unseen variables.
    repeats = Counter()
    for c in combinations:
        repeats[c] = 0

    samples = []
    if osp.isfile(samples_file):
        for sample in sample_file_iterator(samples_file):
            samples.append(sample)
            repeats.update(collect_repetitions(net, sample[0]))
        logger.info(f"Loaded {samples_file}, "
                    f"min_repetition={min(repeats.values())} "
                    f"count={len(samples)} ")
    logger.info(f"Writing new samples to {samples_file}")
    new_samples = []
    while (len(samples) + len(new_samples) < count
           or min(repeats.values()) < repetitions):
        configuration = gen_configuration(net, repeats, biased=biased)
        ms = benchmark(device, dtype, batch_size, iterations, warmup, reduction, configuration, silent=True)
        repeats.update(collect_repetitions(net, configuration))
        logger.info(f"{configuration}: {ms:.04f}ms, "
                    f"min_repetition={min(repeats.values())} "
                    f"count={len(samples) + len(new_samples)} ")
        new_samples.append([[int(d) for d in configuration], ms])
        if (len(new_samples) % save_every) == 0:
            _append_samples(samples_file, new_samples)
            samples.extend(new_samples)
            new_samples = []

    # Write the last, partially filled batch too; otherwise up to
    # save_every - 1 measurements would be returned but never saved.
    if new_samples:
        _append_samples(samples_file, new_samples)
        samples.extend(new_samples)
    return samples
def get_inequalities(variables):
    """Build the sparse matrix of pairwise monotonicity constraints.

    Variables that differ only in ``in_channels``, only in ``out_channels`` or
    only in ``in_size`` are grouped together, each group is sorted, and one row
    is emitted per pair of consecutive variables. A row holds ``+1`` for the
    smaller variable and ``-1`` for the larger one, so ``K @ x`` is positive
    exactly where monotonicity is violated (``solve_lsq`` constrains
    ``K @ x <= t`` and penalizes ``t``).

    Args:
        variables: mapping from a ``Vartype``-like namedtuple to its column index.

    Returns:
        ``scipy.sparse.lil_matrix`` of shape ``(n_inequalities, len(variables))``.
    """
    def other(L, dropped):
        # All field values of L except the one being varied.
        props = L._asdict()
        del props[dropped]
        return tuple(props.values())

    buckets = defaultdict(list)
    for order in ('in_channels', 'out_channels', 'in_size'):
        for V in variables:
            # The varied field is part of the key: without it, the remaining
            # values of "vary in_channels" and "vary out_channels" groups are
            # indistinguishable tuples, and unrelated variables end up in the
            # same bucket whenever the channel numbers coincide.
            buckets[(order,) + other(V, order)].append(V)

    inequalities = []
    for bucket in buckets.values():
        ordered = sorted(bucket)
        inequalities.extend(zip(ordered, ordered[1:]))

    K = lil_matrix((len(inequalities), len(variables)))
    for row, (smaller, larger) in enumerate(inequalities):
        K[row, variables[smaller]] = 1
        K[row, variables[larger]] = -1
    return K
def validate(samples_file, model_file, plot_file):
    """Compare a fitted latency model with measured samples and plot the result.

    Prints the RMSE and the RMSE normalized by the mean measured latency, and
    saves a scatter plot of predicted against measured latency.

    Args:
        samples_file: JSON-lines file of validation samples.
        model_file: JSON-lines model file as written by ``dump_model``.
        plot_file: path of the image to write.

    Raises:
        KeyError: if a sample uses a variable that is absent from the model.
    """
    plot_dir = osp.dirname(plot_file)
    if plot_dir:  # a bare filename has an empty dirname, which os.makedirs rejects
        os.makedirs(plot_dir, exist_ok=True)

    model_dict = dict(load_model(model_file))
    samples = sample_file_iterator(samples_file)
    M, y, variables, ivariables = build_equation(samples)
    # Order the model coefficients like the columns of M.
    x = np.array([model_dict[ivariables[i]] for i in range(len(variables))])
    yhat = M @ x
    rmse = np.sqrt(((y - yhat) ** 2).mean())
    title = f"RMSE {rmse:.04f}, NRMSE {100 * rmse / y.mean():.02f}%"
    print(title)

    # Draw on a dedicated figure and close it, so repeated calls neither stack
    # points on the implicit global figure nor keep figures alive.
    fig, ax = plt.subplots()
    try:
        ax.plot(y, yhat, 'o')
        ax.set_xlabel("ground truth (ms)")
        ax.set_ylabel("predicted (ms)")
        ax.set_title(title)
        fig.savefig(plot_file)
    finally:
        plt.close(fig)
class MovingAverageMeter(object):
    """Running average over the last ``window`` recorded values.

    ``val`` is the latest value, ``avg`` the mean of the values currently in
    the window, and ``sum`` their sum (``None`` until the first update).
    """

    def __init__(self, window):
        self.window = window
        self.reset()

    def reset(self):
        """Forget all recorded values."""
        self.history = deque()
        self.avg = 0
        self.sum = None
        self.val = None

    @property
    def count(self):
        """Number of values currently in the window."""
        return len(self.history)

    @property
    def isfull(self):
        """Whether the window holds exactly ``window`` values."""
        return len(self.history) == self.window

    def __getstate__(self):
        # The deque is stored as a numpy array in the pickled state.
        state = self.__dict__.copy()
        state['history'] = np.array(state['history'])
        return state

    def __setstate__(self, state):
        # Work on a copy so the caller's state dict is left untouched.
        state = dict(state)
        state['history'] = deque(state['history'])
        self.__dict__.update(state)

    def update(self, val, epoch, iteration):
        """Record ``val`` and refresh ``sum``, ``val`` and ``avg``.

        ``epoch`` and ``iteration`` are accepted but not used.
        """
        self.history.append(val)
        # Out-of-place arithmetic on purpose: with `self.sum = val` followed by
        # `+=`, an array or tensor `val` would be modified in place, corrupting
        # both the caller's object and the copy stored in `history`.
        if self.sum is None:
            self.sum = val
        else:
            self.sum = self.sum + val
        if len(self.history) > self.window:
            self.sum = self.sum - self.history.popleft()
        self.val = val
        self.avg = self.sum / self.count

    def __repr__(self):
        return "MovingAverageMeter(window={}, count={}, val={}, avg={})".format(
            self.window, self.count, self.val, self.avg)
class Moduler(nn.Module):
    """
    Zeroes the trailing channels of its input, keeping a per-sample number of
    leading channels (drawn at random from ``configurations``, or forced
    through the ``channels`` argument)
    """

    def __init__(self, configurations):
        super().__init__()
        self.configurations = configurations
        self.base_channels = 8
        self.probability = None  # optional sampling weights (biased sampling)

    def forward(self, data, channels=None, record=True):
        """
        Mask ``data['x']`` (a bare tensor is wrapped into a dict first) and,
        when ``record`` is set, append the kept channel counts to
        ``data['decision']``. Returns the dict.
        """
        if not isinstance(data, dict):
            data = {'x': data}
        x = data['x']
        batch, width = x.size(0), x.size(1)

        if channels is not None:
            confs = channels * np.ones((batch,), int)
        else:
            picked = np.random.choice(np.arange(len(self.configurations)),
                                      size=batch,
                                      p=self.probability)
            confs = self.configurations[picked]

        # Mark, for every sample, the first channel to drop (the extra column
        # absorbs "keep everything"); the cumulative sum then flags that
        # channel and all following ones.
        first_dropped = x.new_zeros((batch, width + 1))
        first_dropped[np.arange(len(confs)), confs] = 1.0
        keep = 1 - first_dropped[:, :width].cumsum(1)
        data['x'] = x * keep.unsqueeze(2).unsqueeze(3)

        if record:
            data.setdefault('decision', []).append(confs)
        return data

    def __repr__(self):
        return "Moduler({})".format(self.configurations)
| mod = [ 81 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 82 | nn.BatchNorm2d(oup), 83 | nn.ReLU(inplace=True), 84 | ] 85 | if not bn: 86 | mod = [m for m in mod if not isinstance(m, nn.BatchNorm2d)] 87 | return nn.Sequential(*mod) 88 | 89 | @staticmethod 90 | def strides_channels(): 91 | """ 92 | follows Mobilenet-v1 definition 93 | """ 94 | blocks = [[32, 64], # special first stem block 95 | [128, 128], 96 | [256, 256], 97 | [512, 512, 512, 512, 512, 512], 98 | [1024, 1024], 99 | ] 100 | strides = [] 101 | for block in blocks: 102 | strides.extend([2] + [1] * (len(block) - 1)) 103 | base_channels = np.array([c for block in blocks for c in block]) 104 | 105 | return strides, base_channels 106 | 107 | def __init__(self, min_width=0.2, max_width=1.5, levels=14, fc_dropout=0.0, in_size=(224, 224)): 108 | super().__init__() 109 | 110 | def divise8(i): 111 | return (np.maximum(np.round(i / 8), 1) * 8).astype(int) 112 | 113 | strides, base_channels = self.strides_channels() 114 | depthwise = [0] + [1] * (len(base_channels) - 1) 115 | 116 | self.configurations = divise8(base_channels.reshape(-1, 1) * np.linspace(min_width, max_width, levels).reshape(1, -1)) 117 | self.configurations = [np.unique(c) for c in self.configurations] 118 | 119 | self.components = [] 120 | 121 | channels = [self.in_channels] + [int(c[-1]) for c in self.configurations] 122 | inp = iter(channels) 123 | oup = iter(channels[1:]) 124 | 125 | self.model = nn.ModuleList() 126 | for dw, strid in zip(depthwise, strides): 127 | I = next(inp) 128 | O = next(oup) 129 | mod = self.gen_conv(I, O, strid, dw) 130 | component = LayerType(in_size=in_size, kernel_size=3, stride=strid, dw=bool(dw), bias=False) 131 | in_size = (in_size[0] // strid, in_size[1] // strid) 132 | self.model.append(mod) 133 | self.components.append(component) 134 | 135 | self.filters = nn.ModuleList() 136 | for conf, base_chan in zip(self.configurations, base_channels): 137 | F = Moduler(conf) 138 | F.base_channels = base_chan 139 | 
self.filters.append(F) 140 | 141 | self.pool = nn.AvgPool2d(7) 142 | self.fc_dropout = None if not fc_dropout else nn.Dropout(fc_dropout) 143 | in_size = (in_size[0] // 7, in_size[1] // 7) 144 | 145 | I = next(inp) 146 | self.fc = nn.Linear(I, self.out_channels) 147 | self.components.append(LayerType(in_size=in_size, kernel_size=1, stride=1, dw=False, bias=True)) 148 | 149 | def forward(self, data, configuration=None): 150 | if not isinstance(data, dict): 151 | data = dict(x=data) 152 | for i, (conv, filter) in enumerate(zip(self.model, self.filters)): 153 | data['x'] = conv(data['x']) 154 | data = filter(data, 155 | channels=(configuration[i] if configuration is not None else None)) 156 | data['x'] = self.pool(data['x']) 157 | data['x'] = data['x'].view(data['x'].size(0), -1) 158 | if self.fc_dropout is not None: 159 | data['x'] = self.fc_dropout(data['x']) 160 | data['x'] = self.fc(data['x']) 161 | data['decision'] = torch.tensor(np.array(data['decision']).T, device=data['x'].device) 162 | return data 163 | 164 | @classmethod 165 | def reduce(cls, C=(32, 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), 166 | bn=True): 167 | """ 168 | Remove all modulers and reduce according to a single channel configuration 169 | """ 170 | modules = [] 171 | I = cls.in_channels 172 | depthwise = [False] + [True] * (len(C) - 1) 173 | strides, base_channels = cls.strides_channels() 174 | assert len(strides) == len(C) 175 | 176 | for O, stride, dw in zip(C, strides, depthwise): 177 | modules.append(cls.gen_conv(I, O, stride, dw, bn=bn)) 178 | I = O 179 | modules += [nn.AvgPool2d(7), Flatten(), nn.Linear(I, cls.out_channels)] 180 | reduced = nn.Sequential(*modules) 181 | return reduced 182 | 183 | 184 | -------------------------------------------------------------------------------- /mxnet_mobilenet.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more 
# pylint: disable= too-many-arguments
def _add_conv(out, channels=1, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=False, norm_layer=BatchNorm, norm_kwargs=None):
    """Append a bias-free convolution, a normalization layer and, when
    ``active`` is set, an activation (ReLU6 if ``relu6`` else ReLU) to ``out``.
    """
    extra_norm_args = norm_kwargs if norm_kwargs is not None else {}

    conv = nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False)
    out.add(conv)
    out.add(norm_layer(scale=True, **extra_norm_args))

    if not active:
        return
    if relu6:
        activation = ReLU6()
    else:
        activation = nn.Activation('relu')
    out.add(activation)
norm_kwargs=None): 49 | _add_conv(out, channels=dw_channels, kernel=3, stride=stride, 50 | pad=1, num_group=dw_channels, relu6=relu6, 51 | norm_layer=norm_layer, norm_kwargs=norm_kwargs) 52 | _add_conv(out, channels=channels, relu6=relu6, 53 | norm_layer=norm_layer, norm_kwargs=norm_kwargs) 54 | 55 | 56 | class LinearBottleneck(nn.HybridBlock): 57 | r"""LinearBottleneck used in MobileNetV2 model from the 58 | `"Inverted Residuals and Linear Bottlenecks: 59 | Mobile Networks for Classification, Detection and Segmentation" 60 | `_ paper. 61 | 62 | Parameters 63 | ---------- 64 | in_channels : int 65 | Number of input channels. 66 | channels : int 67 | Number of output channels. 68 | t : int 69 | Layer expansion ratio. 70 | stride : int 71 | stride 72 | norm_layer : object 73 | Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`) 74 | Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`. 75 | norm_kwargs : dict 76 | Additional `norm_layer` arguments, for example `num_devices=4` 77 | for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`. 
class LinearBottleneck(nn.HybridBlock):
    r"""Inverted-residual block with a linear bottleneck, as used by
    MobileNetV2 (`"Inverted Residuals and Linear Bottlenecks:
    Mobile Networks for Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_).

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    channels : int
        Number of output channels.
    t : int
        Layer expansion ratio.
    stride : int
        Stride of the depthwise convolution.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """

    def __init__(self, in_channels, channels, t, stride,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(LinearBottleneck, self).__init__(**kwargs)
        # The residual connection needs matching shapes on both sides.
        self.use_shortcut = stride == 1 and in_channels == channels
        expanded = in_channels * t
        norm = dict(norm_layer=norm_layer, norm_kwargs=norm_kwargs)

        with self.name_scope():
            self.out = nn.HybridSequential()

            # 1x1 expansion, skipped when the ratio is 1.
            if t != 1:
                _add_conv(self.out, expanded, relu6=True, **norm)
            # 3x3 depthwise convolution on the expanded features.
            _add_conv(self.out, expanded, kernel=3, stride=stride, pad=1,
                      num_group=expanded, relu6=True, **norm)
            # 1x1 projection without activation.
            _add_conv(self.out, channels, active=False, relu6=True, **norm)

    def hybrid_forward(self, F, x):
        """Apply the block, adding the input back when the shortcut is enabled."""
        transformed = self.out(x)
        return F.elemwise_add(transformed, x) if self.use_shortcut else transformed
class MobileNet(HybridBlock):
    r"""MobileNet model from the
    `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
    <https://arxiv.org/abs/1704.04861>`_ paper, with a configurable number
    of channels per layer.

    Parameters
    ----------
    multiplier : float, default 1.0
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    classes : int, default 1000
        Number of classes for the output layer.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    configuration : None, str or sequence of int
        Output channels of the 14 convolution layers. ``None`` or
        ``'MOBILENET'`` selects the standard MobileNet widths. A string is
        split on commas and every non-digit character of each item is
        ignored. Any other sequence is converted item by item with ``int``.

    Raises
    ------
    ValueError
        If the configuration does not contain exactly 14 channel numbers,
        or a comma-separated item contains no digit.
    """

    def __init__(self, multiplier=1.0, classes=1000,
                 norm_layer=BatchNorm, norm_kwargs=None, configuration=None, **kwargs):
        default_configuration = (32, 64, 128, 128, 256, 256, 512, 512,
                                 512, 512, 512, 512, 1024, 1024)
        if configuration is None:
            configuration = default_configuration
        elif isinstance(configuration, str):
            if configuration == 'MOBILENET':
                configuration = default_configuration
            else:
                # Keep only the digits of each item so that decorated
                # inputs such as "(32, 64, ...)" are accepted as well.
                configuration = tuple(
                    int(''.join(ci for ci in c if ci.isdigit()))
                    for c in configuration.split(','))
        else:
            configuration = tuple(int(c) for c in configuration)

        # One stem convolution plus 13 depthwise-separable blocks (the
        # strides list below is fixed); zip() would otherwise silently
        # build a truncated network.
        if len(configuration) != len(default_configuration):
            raise ValueError(
                'Expected {} channel numbers in configuration, got {}: {}'.format(
                    len(default_configuration), len(configuration), configuration))

        logger.info('Mobilenet with channel configuration: {}'.format(configuration))

        super(MobileNet, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            with self.features.name_scope():
                _add_conv(self.features, channels=int(configuration[0] * multiplier), kernel=3, pad=1, stride=2,
                          norm_layer=norm_layer, norm_kwargs=norm_kwargs)
                # Block i consumes the width of layer i and produces layer i+1.
                dw_channels = [int(x * multiplier) for x in configuration[:-1]]
                configuration = [int(x * multiplier) for x in configuration[1:]]
                strides = [1, 2] * 3 + [1] * 5 + [2, 1]
                for dwc, c, s in zip(dw_channels, configuration, strides):
                    _add_conv_dw(self.features, dw_channels=dwc, channels=c, stride=s,
                                 norm_layer=norm_layer, norm_kwargs=norm_kwargs)
                self.features.add(nn.GlobalAvgPool2D())
                self.features.add(nn.Flatten())

            self.output = nn.Dense(classes)

    def hybrid_forward(self, F, x):
        """Extract features and apply the dense classifier."""
        x = self.features(x)
        x = self.output(x)
        return x
class MobileNetV2(nn.HybridBlock):
    r"""MobileNetV2 model from the
    `"Inverted Residuals and Linear Bottlenecks:
    Mobile Networks for Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_ paper.

    Parameters
    ----------
    multiplier : float, default 1.0
        The width multiplier for controlling the model size. The actual number of channels
        is equal to the original channel size multiplied by this multiplier.
    classes : int, default 1000
        Number of classes for the output layer.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """

    def __init__(self, multiplier=1.0, classes=1000,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(MobileNetV2, self).__init__(**kwargs)
        norm = dict(norm_layer=norm_layer, norm_kwargs=norm_kwargs)

        def scaled(widths):
            # Apply the width multiplier, truncating to an integer.
            return [int(width * multiplier) for width in widths]

        with self.name_scope():
            self.features = nn.HybridSequential(prefix='features_')
            with self.features.name_scope():
                _add_conv(self.features, int(32 * multiplier), kernel=3,
                          stride=2, pad=1, relu6=True, **norm)

                # Widths shared by consecutive bottlenecks: block i reads
                # what block i-1 produced.
                stage_widths = [16] + [24] * 2 + [32] * 3 + [64] * 4 + [96] * 3 + [160] * 3
                block_inputs = scaled([32] + stage_widths)
                block_outputs = scaled(stage_widths + [320])
                expansions = [1] + [6] * 16
                block_strides = [1, 2] * 2 + [1, 1, 2] + [1] * 6 + [2] + [1] * 3

                for in_c, out_c, expansion, stride in zip(block_inputs, block_outputs,
                                                          expansions, block_strides):
                    self.features.add(LinearBottleneck(in_channels=in_c,
                                                       channels=out_c,
                                                       t=expansion,
                                                       stride=stride,
                                                       **norm))

                # The last 1x1 convolution is only widened, never narrowed.
                if multiplier > 1.0:
                    last_channels = int(1280 * multiplier)
                else:
                    last_channels = 1280
                _add_conv(self.features, last_channels, relu6=True, **norm)

                self.features.add(nn.GlobalAvgPool2D())

            self.output = nn.HybridSequential(prefix='output_')
            with self.output.name_scope():
                self.output.add(
                    nn.Conv2D(classes, 1, use_bias=False, prefix='pred_'),
                    nn.Flatten())

    def hybrid_forward(self, F, x):
        """Extract features and apply the convolutional classifier."""
        return self.output(self.features(x))
# Constructor
def get_mobilenet(multiplier=1.0, pretrained=False, ctx=cpu(),
                  root='~/.mxnet/models', norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
    r"""Build the configurable :class:`MobileNet` defined in this module.

    Parameters
    ----------
    multiplier : float
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    pretrained : bool or str
        Must be falsy: this modified script ships no pretrained weights.
    ctx : Context, default CPU
        Accepted for interface compatibility; not used here.
    root : str, default $MXNET_HOME/models
        Accepted for interface compatibility; not used here.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    **kwargs
        Forwarded to :class:`MobileNet` (e.g. ``classes``, ``configuration``).

    Raises
    ------
    NotImplementedError
        If ``pretrained`` is truthy.
    """
    if pretrained:
        raise NotImplementedError("No pretrained weights in this modified mobilenet script")
    return MobileNet(multiplier, norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
# CLI
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, which is the behavior callers without arguments get.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--data-dir', type=str, default='~/.mxnet/datasets/imagenet',
                        help='training and validation pictures to use.')
    parser.add_argument('--rec-train', type=str, default='~/.mxnet/datasets/imagenet/rec/train.rec',
                        help='the training data')
    parser.add_argument('--rec-train-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/train.idx',
                        help='the index of training data')
    parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val.rec',
                        help='the validation data')
    parser.add_argument('--rec-val-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/val.idx',
                        help='the index of validation data')
    parser.add_argument('--use-rec', action='store_true',
                        help='use image record iter for data input. default is false.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training. default is float32')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--lr-mode', type=str, default='step',
                        help='learning rate scheduler mode. options are step, poly and cosine.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--lr-decay-period', type=int, default=0,
                        help='interval for periodic learning rate decays. default is 0 to disable.')
    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
                        help='epochs at which learning rate decays. default is 40,60.')
    parser.add_argument('--warmup-lr', type=float, default=0.0,
                        help='starting warmup learning rate. default is 0.0.')
    parser.add_argument('--warmup-epochs', type=int, default=0,
                        help='number of warmup epochs.')
    parser.add_argument('--last-gamma', action='store_true',
                        help='whether to init gamma of the last BN layer in each bottleneck to 0.')
    parser.add_argument('--mode', type=str,
                        help='mode in which to train the model. options are symbolic, imperative, hybrid')
    parser.add_argument('--configuration', type=str, default=None,
                        help=("The custom mobilenet configuration to train. "
                              "Can be MOBILENET or comma-separated channel numbers."))
    parser.add_argument('--model', type=str, default="mobilenet",
                        help='type of model to use. set to mobilenet in this modified version.')
    parser.add_argument('--input-size', type=int, default=224,
                        help='size of the input image size. default is 224')
    parser.add_argument('--crop-ratio', type=float, default=0.875,
                        help='Crop ratio during validation. default is 0.875')
    parser.add_argument('--use-pretrained', action='store_true',
                        help='enable using pretrained model from gluon.')
    parser.add_argument('--use_se', action='store_true',
                        help='use SE layers or not in resnext. default is false.')
    parser.add_argument('--mixup', action='store_true',
                        help='whether train the model with mix-up. default is false.')
    parser.add_argument('--mixup-alpha', type=float, default=0.2,
                        help='beta distribution parameter for mixup sampling, default is 0.2.')
    parser.add_argument('--mixup-off-epoch', type=int, default=0,
                        help='how many last epochs to train without mixup, default is 0.')
    parser.add_argument('--label-smoothing', action='store_true',
                        help='use label smoothing or not in training. default is false.')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
    parser.add_argument('--teacher', type=str, default=None,
                        help='teacher model for distillation training')
    parser.add_argument('--temperature', type=float, default=20,
                        help='temperature parameter for distillation teacher model')
    parser.add_argument('--hard-weight', type=float, default=0.5,
                        help='weight for the loss of one-hot label for distillation training')
    parser.add_argument('--batch-norm', action='store_true',
                        help='enable batch normalization or not in vgg. default is false.')
    parser.add_argument('--save-frequency', type=int, default=10,
                        help='frequency of model saving.')
    parser.add_argument('--save-dir', type=str, default='params',
                        help='directory of saved models')
    parser.add_argument('--resume-epoch', type=int, default=0,
                        help='epoch to resume training from.')
    parser.add_argument('--resume-params', type=str, default='',
                        help='path of parameters to load from.')
    parser.add_argument('--resume-states', type=str, default='',
                        help='path of trainer state to load from.')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='Number of batches to wait before logging.')
    parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
                        help='name of training log file')
    parser.add_argument('--use-gn', action='store_true',
                        help='whether to use group norm.')
    opt = parser.parse_args(argv)
    return opt
def main():
    """Train and validate an ImageNet classifier according to the CLI options.

    Sets up logging, the learning-rate schedule, the network (and optional
    distillation teacher), the data pipeline (record files or raw images),
    then runs the training loop with per-epoch validation and checkpointing.
    """
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    # --batch-size is per device; scale to the global batch size.
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    # The main scheduler starts after warmup, so shift the decay epochs.
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        # Warmup starts from --warmup-lr (default 0.0) instead of a
        # hard-coded 0 that silently ignored the option.
        LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True
    if opt.configuration is not None:
        kwargs['configuration'] = opt.configuration

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    # Local model constructors take precedence over the gluon-cv model zoo.
    if model_name in _models:
        net = _models[model_name](**kwargs)
    else:
        net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    # Equality, not identity: `is not ''` depends on string interning.
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        """Build train/val ``ImageRecordIter`` iterators and their batch splitter."""
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,

            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,

            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        """Build train/val ``DataLoader`` objects over raw images and their batch splitter."""
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                      opt.rec_val, opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    # Mixup produces soft labels, so accuracy is replaced by RMSE.
    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        """Mix the (optionally smoothed) one-hot labels with their batch-reversed copy."""
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            y2 = l[::-1].one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        """Convert sparse labels into label-smoothed one-hot vectors."""
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        """Return the (top-1 error, top-5 error) of ``net`` on ``val_data``."""
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        """Run the training loop with per-epoch validation and checkpointing."""
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        # Fresh initialization only when no parameters were loaded above.
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        # Dense (soft) labels are required for smoothing and mixup.
        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                             hard_weight=opt.hard_weight,
                                                             sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    # Mixup is switched off for the last --mixup-off-epoch epochs.
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam*X + (1-lam)*X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    # Keep the sparse labels for the accuracy metric.
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    if distillation:
                        loss = [L(yhat.astype('float32', copy=False),
                                  y.astype('float32', copy=False),
                                  p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                    else:
                        loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                      for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i+1)%opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                train_metric_name, train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
# Command-line interface of the slimmable / AOWS architecture-search training.
parser = argparse.ArgumentParser(description='Train mobilenet slimmable/AOWS model')
parser.add_argument('--data', metavar='DIR', default='/imagenet',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=512, type=int,
                    metavar='N', help='mini-batch size (default: 512)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--no-cuda', action="store_true")
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
                    metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--resume-last', action='store_true',
                    help='resume to last checkpoint if found (supersedes resume)')
# nargs="*" yields a list when the flag is given, so the default is a list
# too; the former string default "none" only worked because
# `"mini" in "none"` is a substring test that happens to be false.
parser.add_argument("--debug", nargs="*", choices=["mini"], default=[],
                    help="do a mini epoch to check everything works.")
parser.add_argument('--max-width', type=float, default=1.5)
parser.add_argument('--min-width', type=float, default=0.2)
parser.add_argument('--levels', type=int, default=14)
parser.add_argument('--latency-target', type=float, default=0.04, help="latency target in objective")
parser.add_argument('--window', type=int, default=100000,
                    help="size of window over which the moving average of losses is computed in OWS and AOWS.")
parser.add_argument('--latency-model', type=str, default='output/model_trt16_K100.0.jsonl', help="latency model")
parser.add_argument('--gamma-iter', type=int, default=12, help="Number of Viterbi iterations to set gamma.")
parser.add_argument('--AOWS', action="store_true", help="use AOWS")
parser.add_argument('--AOWS-warmup', type=int, default=5, help="AOWS warmup epochs")
parser.add_argument('--AOWS-min-temp', type=float, default=0.0005, help="minimum (final) temperature")
parser.add_argument('--expname', default='output/nas_output')
'{}'".format(args.resume)) 116 | checkpoint = torch.load(args.resume) 117 | start_epoch = checkpoint['epoch'] 118 | state_dict = checkpoint['state_dict'] 119 | key = next(iter(state_dict.keys())) 120 | if key.startswith('module.') and args.no_cuda: 121 | state_dict = {k[len('module.'):]: v for (k, v) in state_dict.items()} 122 | model.load_state_dict(state_dict) 123 | optimizer.load_state_dict(checkpoint['optimizer']) 124 | if 'ows_state' in checkpoint: 125 | ows_state.histories = checkpoint['ows_state'].histories 126 | logger.info("=> loaded checkpoint '{}' (epoch {})" 127 | .format(args.resume, checkpoint['epoch'])) 128 | else: 129 | logger.info(f"=> no checkpoint found at '{args.resume}'") 130 | args.resume = '' 131 | 132 | if args.resume: 133 | # Solve OWS one first time in order to allow re-evaluation on the last epoch with varying latency target 134 | best_path, _, _, _, timing = solve_ows(model, start_epoch, len(train_loader), -1, ows_state, args, eval_only=True) 135 | logger.info('Evaluation from resumed checkpoint...') 136 | best_path_str = (f"Best configuration: {best_path}, " 137 | f"predicted latency: {timing}") 138 | logger.info(best_path_str) 139 | 140 | for epoch in range(start_epoch, args.epochs): 141 | history = train(train_loader, model, criterion, optimizer, epoch, ows_state, args) 142 | 143 | logger.info(f"=> saving decision history for epoch {format(epoch + 1)}") 144 | decision_target = 'decision{:03d}.pkl'.format(epoch + 1) 145 | if args.expname: 146 | os.makedirs(args.expname, exist_ok=True) 147 | decision_target = osp.join(args.expname, decision_target) 148 | with open(decision_target, 'wb') as f: 149 | pickle.dump(history, f, protocol=4) 150 | 151 | logger.info(f"=> saving checkpoint for epoch {epoch + 1}") 152 | 153 | current_state = { 154 | 'epoch': epoch + 1, 155 | 'state_dict': model.state_dict(), 156 | 'optimizer': optimizer.state_dict(), 157 | } 158 | 159 | current_state['ows_state'] = ows_state 160 | best_path_str = (f"Best 
configuration: {history['OWS'][-1]['best_path']}, " 161 | f"predicted latency: {history['OWS'][-1]['pred_latency']}") 162 | logger.info(best_path_str) 163 | with open(osp.join(args.expname, f"ows_result_{epoch + 1:03d}.txt"), 'w') as f: 164 | f.write(best_path_str + '\n') 165 | filename = save_checkpoint(current_state, args.expname) 166 | logger.info(f"checkpoint saved to {filename}.") 167 | 168 | 169 | 170 | def train(train_loader, model, criterion, optimizer, epoch, ows_state, args): 171 | meters = defaultdict(misc.AverageMeter) 172 | 173 | model.train() 174 | 175 | filters = model.filters if hasattr(model, 'filters') else model.module.filters 176 | history = defaultdict(list) 177 | 178 | end = time.time() 179 | for iteration, (input, target) in enumerate(train_loader): 180 | if "mini" in args.debug and iteration > 20: break 181 | 182 | best_path, temperature, gamma_max, best_perf, timing = solve_ows( 183 | model, epoch, len(train_loader), iteration, ows_state, args) 184 | 185 | # measure data loading time 186 | meters["data_time"].update(time.time() - end) 187 | 188 | if not args.no_cuda: 189 | target = target.cuda(non_blocking=True) 190 | 191 | compute_results = misc.SWDefaultDict(misc.SWDict) 192 | 193 | minconf = [F.configurations[0] for F in filters] 194 | maxconf = [F.configurations[-1] for F in filters] 195 | 196 | optimizer.zero_grad() 197 | 198 | # sandwich rule: train maximum configuration 199 | outp = model(input, configuration=maxconf) 200 | loss = criterion(outp['x'], target) 201 | loss.mean().backward() 202 | compute_results['max']['x'] = outp['x'].detach() 203 | compute_results['max']['loss_numpy'] = loss.detach().cpu().numpy() 204 | compute_results['max']['prob'] = torch.nn.functional.softmax(compute_results['max']['x'], dim=1) 205 | 206 | # sandwich rule: train minimum and random configuration with self-distillation 207 | for kind in ('min', 'rand'): 208 | conf = None if kind == 'rand' else minconf 209 | outp = model(input, configuration=conf) 
210 | 211 | loss = misc.soft_cross_entropy(outp['x'], compute_results['max']['prob'].detach()) 212 | compute_results[kind]['soft_loss_numpy'] = loss.detach().cpu().numpy() 213 | with torch.no_grad(): 214 | hard_loss_numpy = criterion(outp['x'], target).detach().cpu().numpy() 215 | compute_results[kind]['loss_numpy'] = hard_loss_numpy 216 | 217 | compute_results[kind]['x'] = outp['x'].detach() 218 | if kind == 'rand': 219 | compute_results['rand']['decision'] = outp['decision'].cpu().numpy() 220 | loss.mean().backward() 221 | 222 | for path, image_loss, image_refloss in zip(compute_results['rand']['decision'], 223 | compute_results['rand']['loss_numpy'], 224 | compute_results['max']['loss_numpy']): 225 | for i, pi in enumerate(path): 226 | ows_state.histories[i][pi].update(-(image_loss - image_refloss) / len(path), epoch, iteration) 227 | 228 | for refname in ('min', 'max', 'rand'): 229 | meters['loss_' + kind].update(compute_results[kind]['loss_numpy'].mean(), input.size(0)) 230 | refloss = compute_results[refname]['loss_numpy'] 231 | (prec1, prec5), refcorrect_ks = misc.accuracy(compute_results[refname]['x'].data, 232 | target, topk=(1, 5), return_correct_k=True) 233 | refcorrect1, refcorrect5 = [a.cpu().numpy().astype(bool) for a in refcorrect_ks] 234 | history['loss_' + refname].append(refloss) 235 | history['top1_' + refname].append(refcorrect1) 236 | history['top5_' + refname].append(refcorrect5) 237 | meters['top1_' + refname].update(prec1.item(), input.size(0)) 238 | meters['top5_' + refname].update(prec5.item(), input.size(0)) 239 | if 'soft_loss_numpy' in compute_results[refname]: 240 | meters['loss_soft_' + kind].update(compute_results[kind]['soft_loss_numpy'].mean(), input.size(0)) 241 | history['loss_soft_' + refname].append(compute_results[refname]['soft_loss_numpy']) 242 | 243 | history['configuration'].append(compute_results['rand']['decision']) 244 | history['configuration'].append(compute_results['rand']['loss_numpy']) 245 | 246 | optimizer.step() 
247 | 248 | # measure elapsed time 249 | meters["batch_time"].update(time.time() - end) 250 | end = time.time() 251 | 252 | if iteration % args.print_freq == 0: 253 | toprint = f"Epoch: [{epoch}][{iteration}/{len(train_loader)}]\t" 254 | toprint += ('Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 255 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 256 | 'Prec@1 {top1_rand.val:.3f} ({top1_rand.avg:.3f})\t' 257 | 'Prec@5 {top5_rand.val:.3f} ({top5_rand.avg:.3f})\t'.format(**meters)) 258 | 259 | for key, meter in meters.items(): 260 | if key.startswith('loss'): 261 | toprint += f'{key} {meter.val:.4f} ({meter.avg:.4f})\t' 262 | logger.info(toprint) 263 | 264 | # prints a string summarizing the sampling probabilities for each filter 265 | probas_str = "" 266 | for i, F in enumerate(filters): 267 | if F.probability is not None: 268 | probas_str += '|{} '.format(i) 269 | for p in F.probability: 270 | probas_str += str(int(100 * p)) + ' ' 271 | probas_log = None 272 | if any(F.probability is not None for F in filters): 273 | probas_log = tuple(F.probability for F in filters), 274 | history['OWS'].append(dict(best_path=best_path, temperature=temperature, gamma_max=gamma_max, 275 | best_pref=best_perf, pred_latency=timing, probas_log=probas_log)) 276 | if probas_str: 277 | probas_str = '\n' + probas_str 278 | ows_str = f"predicted latency: {timing}, perf: {best_perf}, T: {temperature}, gamma: {gamma_max}" 279 | logger.info('best_path: ' + ','.join(map(str, best_path)) + ows_str + probas_str) 280 | 281 | 282 | return history 283 | 284 | 285 | def aows_temp(epoch, epoch_len, iteration, args): 286 | schedule = [(0, 1.0), (args.AOWS_warmup, 1.0), 287 | (args.AOWS_warmup + 1, 0.01), 288 | (10, 0.001), (args.epochs, args.AOWS_min_temp)] 289 | cur_phase = 0 290 | for iphase, (phase, _) in enumerate(schedule): 291 | if epoch >= phase: 292 | cur_phase = iphase 293 | phase, start_temp = schedule[cur_phase] 294 | if cur_phase == len(schedule) - 1: 295 | return start_temp 296 
| end_phase, end_temp = schedule[cur_phase + 1] 297 | max_iter = epoch_len * (end_phase - phase) 298 | cur_iter = epoch_len * (epoch - phase) + iteration 299 | ratio = cur_iter / max_iter 300 | log_T = (1.0 - ratio) * np.log10(start_temp) + ratio * np.log10(end_temp) 301 | return 10 ** log_T 302 | 303 | 304 | def solve_ows(model, epoch, len_epoch, iteration, ows_state, args, eval_only=False): 305 | """ 306 | Solves OWS equation and sets AOWS probabilities when AOWS is activated. 307 | """ 308 | if hasattr(model, 'module'): model = model.module 309 | 310 | unaries = [[0.0]] + [[M.avg for M in C.values()] for C in ows_state.histories] + [[0.0]] 311 | 312 | if not hasattr(ows_state, 'pairwise'): 313 | pairwise = [] 314 | 315 | possible_in_channels = [3] 316 | possible_outputs = iter([F.configurations for F in model.filters] + [[1000]]) 317 | for L in model.components: 318 | possible_out = next(possible_outputs) 319 | pair = np.zeros((len(possible_in_channels), len(possible_out))) 320 | for incoming, p in enumerate(possible_in_channels): 321 | for outgoing, l in enumerate(possible_out): 322 | var = latency.Vartype(**L._asdict(), in_channels=p, out_channels=l) 323 | pair[incoming, outgoing] = ows_state.latency[var] 324 | pairwise.append(pair) 325 | possible_in_channels = possible_out 326 | ows_state.pairwise = pairwise 327 | 328 | unaries, pairwise, states = complete(unaries, ows_state.pairwise) 329 | 330 | def solve(gamma): 331 | _, ipath = maxsum(unaries, -gamma * pairwise, states) 332 | perf, timing = score(ipath, unaries, pairwise, detail=True) 333 | return ipath, perf, timing 334 | 335 | gamma_min = 0.0 336 | gamma_max = 10.0 337 | timing_max = solve(gamma_max)[2] 338 | 339 | expanding_iterations = 0 340 | while timing_max > args.latency_target: 341 | expanding_iterations += 1 342 | if expanding_iterations > 2: 343 | logging.warning("Too many expanding loops for gamma, try adjusting gamma_max in the code") 344 | gamma_max *= 2 345 | timing_max = solve(gamma_max)[2] 
346 | 347 | for _ in range(args.gamma_iter): 348 | mid_gamma = 0.5 * (gamma_min + gamma_max) 349 | timing_middle = solve(mid_gamma)[2] 350 | if timing_middle > args.latency_target: 351 | gamma_min = mid_gamma 352 | else: 353 | gamma_max = mid_gamma 354 | ipath, perf, timing = solve(gamma_max) 355 | 356 | T = np.inf 357 | if args.AOWS and epoch >= args.AOWS_warmup and not eval_only: 358 | T = aows_temp(epoch, len_epoch, iteration, args) 359 | marginals = sumprod_log(unaries / T, -gamma_max * pairwise / T, states) 360 | assert marginals.shape[0] == len(model.filters) + 2, "{} {}".format(marginals.shape[0], len(model.filters)) 361 | for F, marginal in zip(model.filters, marginals[1:-1]): 362 | F.probability = marginal[:len(F.configurations)] 363 | 364 | best_path = tuple(F.configurations[i] for (F, i) in zip(model.filters, ipath[1:-1])) 365 | return best_path, T, gamma_max, perf, timing 366 | 367 | 368 | def save_checkpoint(state, expname=''): 369 | filename = f"checkpoint{state['epoch']:03d}.pth" 370 | if expname: 371 | os.makedirs(expname, exist_ok=True) 372 | filename = osp.join(expname, filename) 373 | torch.save(state, filename) 374 | return filename 375 | 376 | 377 | if __name__ == '__main__': 378 | logger = logging.getLogger(__file__) 379 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG, 380 | format='%(name)s: %(message)s') 381 | main() 382 | -------------------------------------------------------------------------------- /viterbi.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from numba import jit 4 | import sys, logging 5 | logging.getLogger('numba').setLevel(logging.WARNING) 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def complete(unary, pairwise, fill=-np.inf): 12 | """ 13 | Convert lists of unaries and pairwises into tensors by filling blanks 14 | """ 15 | N = len(unary) 16 | states = np.array(list(map(len, unary))) 17 | max_S = max(states) 
18 | un = np.full((N, max_S), fill) 19 | pair = np.full((N - 1, max_S, max_S), fill) 20 | for i in range(N): 21 | un[i, :len(unary[i])] = unary[i] 22 | if i < N - 1: 23 | pair[i, :len(unary[i]), :len(unary[i + 1])] = pairwise[i] 24 | 25 | return un, pair, states 26 | 27 | 28 | @jit(nopython=True) 29 | def maxsum(unary, pairwise, states): 30 | N, S = unary.shape 31 | partial = unary[0] 32 | selected = np.zeros((N - 1, S), np.int64) 33 | for s in range(N - 1): 34 | new_partial = np.full((S,), -np.inf) 35 | for j in range(states[s + 1]): 36 | best_ = -np.inf 37 | best_i = 0 38 | for i in range(states[s]): 39 | candidate = partial[i] + pairwise[s, i, j] 40 | if candidate > best_: 41 | best_ = candidate 42 | best_i = i 43 | selected[s, j] = best_i 44 | new_partial[j] = unary[s + 1, j] + best_ 45 | partial = new_partial 46 | 47 | path = np.zeros((N,), np.int64) 48 | score = -np.inf 49 | best_j = 0 50 | for j in range(states[N - 1]): 51 | candidate = partial[j] 52 | if candidate > score: 53 | score = candidate 54 | best_j = j 55 | path[N - 1] = best_j 56 | for i in range(N - 2, -1, -1): 57 | best_j = selected[i, best_j] 58 | path[i] = best_j 59 | return score, path 60 | 61 | 62 | def score(path, unary, pairwise, detail=False): 63 | if not len(path): return 0.0 64 | S = unary[0][path[0]] 65 | Sp = 0.0 66 | prev = path[0] 67 | for i, p in enumerate(path[1:]): 68 | Sp += pairwise[i][prev, p] 69 | S += unary[i + 1][p] 70 | prev = p 71 | if detail: 72 | return S, Sp 73 | return S + Sp 74 | 75 | 76 | def maxsum_brute(unary, pairwise, states, K=3): 77 | """ 78 | Brute-force max-sum (for debugging) 79 | """ 80 | best_path = None 81 | best_score = -float('inf') 82 | for path in itertools.product(*[range(s) for s in states]): 83 | sc = score(path, unary, pairwise) 84 | if sc > best_score: 85 | best_path = path 86 | best_score = sc 87 | return best_score, best_path 88 | 89 | 90 | @jit(nopython=True) 91 | def sumprod_log(unary, pairwise, states, logspace=False): 92 | N, max_s = 
unary.shape 93 | alpha = np.zeros((N, max_s), dtype=unary.dtype) 94 | beta = np.zeros((N, max_s), dtype=unary.dtype) 95 | alpha[0, :states[0]] = unary[0][:states[0]] 96 | alpha[0, states[0]:] = -np.inf 97 | for s in range(N - 1): 98 | for k2 in range(states[s + 1]): 99 | M = -np.inf 100 | for k1 in range(states[s]): 101 | C = alpha[s, k1] + pairwise[s, k1, k2] 102 | if C > M: 103 | M = C 104 | for k1 in range(states[s]): 105 | alpha[s + 1, k2] += np.exp(alpha[s, k1] + pairwise[s, k1, k2] - M) 106 | alpha[s + 1, k2] = unary[s + 1, k2] + np.log(alpha[s + 1, k2]) + M 107 | for k2 in range(states[s + 1], max_s): 108 | alpha[s + 1, k2] = -np.inf 109 | M = alpha[s + 1, :states[s + 1]].max() 110 | Z = np.log(np.exp(alpha[s + 1, :states[s + 1]] - M).sum()) + M 111 | alpha[s + 1, :states[s + 1]] = alpha[s + 1, :states[s + 1]] - Z 112 | for s in range(N - 2, -1, -1): 113 | for k1 in range(states[s]): 114 | M = -np.inf 115 | for k2 in range(states[s + 1]): 116 | C = beta[s + 1, k2] + pairwise[s, k1, k2] + unary[s + 1, k2] 117 | if C > M: 118 | M = C 119 | for k2 in range(states[s + 1]): 120 | beta[s, k1] += np.exp(beta[s + 1, k2] + pairwise[s, k1, k2] + unary[s + 1, k2] - M) 121 | beta[s, k1] = np.log(beta[s, k1]) + M 122 | for k1 in range(states[s], max_s): 123 | beta[s, k1] = -np.inf 124 | M = beta[s, :states[s]].max() 125 | Z = np.log(np.exp(beta[s, :states[s]] - M).sum()) + M 126 | beta[s, :states[s]] = beta[s, :states[s]] - Z 127 | marg = alpha + beta 128 | for s in range(N): 129 | marg[s] = marg[s] - marg[s].max() 130 | if not logspace: 131 | marg = np.exp(marg) 132 | marg = marg / marg.sum(1).reshape(-1, 1) 133 | return marg 134 | 135 | 136 | --------------------------------------------------------------------------------