├── .gitignore
├── LICENSE
├── README.md
├── latency.py
├── misc.py
├── model.py
├── mxnet_mobilenet.py
├── train_final.py
├── train_nas.py
└── viterbi.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project-specific
2 | output
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # PyCharm
104 | .idea
105 |
106 | # PyTorch weights
107 | *.tar
108 | *.pth
109 | *.gz
110 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Maxim Berman
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AOWS
2 | **AOWS: Adaptive and optimal network width search with latency constraints**, Maxim Berman, Leonid Pishchulin, Ning Xu, Matthew B. Blaschko, Gérard Medioni, _NAS workshop @ ICLR 2020_ and _CVPR 2020 (oral)_.
3 |
4 |
5 |
6 |
7 | ## Usage
8 |
9 | ### Latency model
10 | _main file: `latency.py`_
11 |
12 | _depends on: [PyTorch](http://pytorch.org/), [CVXPY](https://www.cvxpy.org/), matplotlib, numpy, [numba](http://numba.pydata.org/), scipy, [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt/tree/e22844a449a880123435fce7e6444f1516ebbe60)_
13 |
14 |
15 | * Generate training and validation samples
16 | ```Shell
17 | python latency.py generate --device trt --dtype fp16 \
18 | --biased --count 8000 output/samples_trt16.jsonl
19 | python latency.py generate --device trt --dtype fp16 \
20 | --count 200 output/val_trt16.jsonl
21 | ```
22 | * Fit the model (`K` controls the amount of monotonicity regularization and should be chosen using the validation samples)
23 | ```Shell
24 | python latency.py fit output/samples_trt16.jsonl \
25 | output/model_trt16_K100.0.jsonl -K 100.0
26 | ```
27 | * Validate the model (produces a plot)
28 | ```Shell
29 | python latency.py validate output/val_trt16.jsonl \
30 | output/model_trt16_K100.0.jsonl output/correlation_plot.png
31 | ```
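The latency model treats a network's latency as a sum of per-layer latencies, with one unknown per (layer type, input width, output width) combination. `build_equation` in `latency.py` turns each sample into a row of counts, and `solve_lsq` fits the unknowns with CVXPY under `x >= 0` plus the monotonicity penalty weighted by `K`. The NumPy sketch below shows only the additive least-squares core on synthetic data. It makes several simplifications: layers are keyed by position instead of by `Vartype`, the non-negativity and monotonicity terms are dropped, and the names `layer_keys`, `design_matrix`, `fit_additive_model` and `true_cost` are hypothetical.

```python
import numpy as np


def layer_keys(config, in_channels=3, out_channels=1000):
    """One (layer index, c_in, c_out) key per layer of a chain network.

    Simplification: layers are identified by position, whereas latency.py
    identifies them by layer type (input size, stride, depthwise, ...).
    """
    chans = [in_channels, *config, out_channels]
    return [(i, chans[i], chans[i + 1]) for i in range(len(chans) - 1)]


def design_matrix(configs, variables):
    """Row s counts how often each latency unknown appears in configuration s."""
    M = np.zeros((len(configs), len(variables)))
    for s, config in enumerate(configs):
        for key in layer_keys(config):
            M[s, variables[key]] += 1.0
    return M


def fit_additive_model(configs, latencies):
    """Plain least squares (latency.py additionally enforces x >= 0 and monotonicity)."""
    variables = {}
    for config in configs:
        for key in layer_keys(config):
            variables.setdefault(key, len(variables))
    M = design_matrix(configs, variables)
    x, *_ = np.linalg.lstsq(M, np.asarray(latencies, dtype=float), rcond=None)
    return variables, x


# Synthetic example: three searchable layers, each with two width choices.
rng = np.random.default_rng(0)
choices = [(8, 16), (16, 32), (32, 64)]


def true_cost(key):
    # Hypothetical ground-truth per-layer latency in ms.
    i, c_in, c_out = key
    return 1e-3 * c_in * c_out / (i + 1)


train = [tuple(int(rng.choice(c)) for c in choices) for _ in range(40)]
y = np.array([sum(true_cost(k) for k in layer_keys(cfg)) for cfg in train])
variables, x = fit_additive_model(train, y)
pred = design_matrix(train, variables) @ x
print(f"max abs error on training samples: {np.abs(pred - y).max():.2e}")
```

Because the synthetic latencies are exactly additive, this fit reproduces the training samples up to floating-point error. On real measurements, `latency.py validate` reports the RMSE on a held-out file instead.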
32 |
33 | Additionally, a single configuration can be benchmarked with this script, e.g.
34 | ```Shell
35 | python latency.py benchmark --device trt --dtype fp16 \
36 | "(16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304)"
37 | ```
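The `configuration` argument lists one output width per MobileNet-v1 layer (14 values). `parse_args` in `latency.py` splits the string on commas and keeps only the digits of each item, so the tuple-style quoting above works. The keyword `MOBILENET` expands to the standard MobileNet-v1 widths. Below is a standalone sketch of that parsing; the function name `parse_configuration` is hypothetical.

```python
MOBILENET = "32,64,128,128,256,256,512,512,512,512,512,512,1024,1024"


def parse_configuration(text):
    """Turn '(16, 32, ...)' or 'MOBILENET' into a list of channel counts."""
    if text == "MOBILENET":
        text = MOBILENET
    # Drop every non-digit character (parentheses, spaces) from each item.
    return [int("".join(ch for ch in item if ch.isdigit())) for item in text.split(",")]


config = parse_configuration(
    "(16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304)")
print(config)  # [16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304]
```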
38 |
39 | ### Network width search
40 | _main file: `train_nas.py`_
41 |
42 | _depends on: [PyTorch](http://pytorch.org/), numpy, [numba](http://numba.pydata.org/)_
43 | * **Train a slimmable network and select a configuration with OWS.** See `-h` for optimization options.
44 | ```Shell
45 | python train_nas.py --data /imagenet --latency-target 0.04 \
46 | --latency-model output/model_trt16_K100.0.jsonl \
47 | --expname output/ows-trt16-0.04 --resume-last
48 | ```
49 | In OWS, the latency target `--latency-target` can be changed during or after training. With `--resume-last`, the last checkpoint is resumed without retraining, which makes it cheap to try different latency targets.
50 |
51 | _Implementation detail: for simplicity, this implementation uses a fixed-window moving average over the last `--window=100000` samples for each unary weight, whereas the article used the statistics collected over the last full epoch._
52 |
53 | * **Train a slimmable network with AOWS.** See `-h` for optimization options. The outputs, including the best configuration for each epoch, are written to the directory given by `--expname`.
54 | ```Shell
55 | python train_nas.py --data /imagenet --latency-target 0.04 \
56 | --latency-model output/model_trt16_K100.0.jsonl \
57 | --AOWS --expname output/aows-trt16-0.04 --resume-last
58 | ```
59 | In AOWS, the latency target `--latency-target` must be set at the start of training, since it directly affects the training itself.
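In OWS, the final widths are chosen by optimizing over a chain. Each layer contributes a unary term (the moving-average weights mentioned in the implementation note), and each pair of consecutive layers contributes a latency term from the latency model, which is keyed by input and output width. This decoupling is why the target can be changed after training. The sketch below solves such a chain with a max-sum (Viterbi) recursion. It enforces the latency target through a Lagrange multiplier found by bisection, which is an assumption made here for illustration. It is not the code of `viterbi.py` or `train_nas.py`: the toy unary scores, the latency table and all function names are hypothetical.

```python
import itertools

import numpy as np


def chain_maxsum(unary, pairwise):
    """Maximize sum_i unary[i][s_i] + sum_i pairwise[i][s_i, s_{i+1}] on a chain.

    unary[i] is a 1-D array over the width choices of layer i; pairwise[i] has
    shape (len(unary[i]), len(unary[i + 1])). Returns (best value, states).
    """
    score = np.asarray(unary[0], dtype=float)
    backptr = []
    for i, P in enumerate(pairwise):
        cand = score[:, None] + P           # cand[a, b]: best prefix ending a -> b
        backptr.append(cand.argmax(axis=0))
        score = cand.max(axis=0) + unary[i + 1]
    states = [int(score.argmax())]
    for bp in reversed(backptr):            # follow back-pointers to layer 0
        states.append(int(bp[states[-1]]))
    return float(score.max()), states[::-1]


def brute_force(unary, pairwise):
    """Exhaustive search, only to check chain_maxsum on tiny problems."""
    best = -np.inf
    for s in itertools.product(*(range(len(u)) for u in unary)):
        value = sum(u[k] for u, k in zip(unary, s))
        value += sum(P[s[i], s[i + 1]] for i, P in enumerate(pairwise))
        best = max(best, value)
    return float(best)


def chain_latency(latency, states):
    return float(sum(T[states[i], states[i + 1]] for i, T in enumerate(latency)))


def select_under_target(unary, latency, target, lo=0.0, hi=1e3, iters=50):
    """Bisect on the Lagrange multiplier of the latency term.

    Returns the feasible states found with the smallest multiplier tried, or
    None. Lagrangian relaxation can leave a duality gap, so the result meets
    the target but is not guaranteed optimal for the constrained problem.
    """
    best = None
    for _ in range(iters):
        lam = 0.5 * (lo + hi)
        _, states = chain_maxsum(unary, [-lam * T for T in latency])
        if chain_latency(latency, states) <= target:
            best, hi = states, lam          # feasible: try a weaker penalty
        else:
            lo = lam                        # too slow: penalize latency more
    return best


# Hypothetical toy search space: 3 layers with their width choices.
widths = [np.array([8, 16, 24]), np.array([16, 32, 48]), np.array([32, 64])]
unary = [np.log1p(w) for w in widths]                                  # "wider is better" proxy
latency = [1e-3 * np.outer(a, b) for a, b in zip(widths, widths[1:])]  # ms per width pair
states = select_under_target(unary, latency, target=2.0)
print([int(w[s]) for w, s in zip(widths, states)], chain_latency(latency, states))
```

The max-sum recursion costs one matrix pass per layer, so re-solving for a new target (or a new multiplier) is far cheaper than retraining.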
60 |
61 |
62 | ### Training the final model
63 | _main file: `train_final.py`_
64 |
65 | _depends on: [mxnet](https://mxnet.apache.org/)_
66 |
67 | Modified version of gluon-cv's [train_imagenet.py](https://github.com/dmlc/gluon-cv/blob/18f8ab526ffb97660e6e5661f991064c20e2699d/scripts/classification/imagenet/train_imagenet.py) for training MobileNet-v1 with varying channel numbers. Refer to gluon-cv's documentation for detailed usage.
68 |
69 | Example command:
70 | ```Shell
71 | python train_final.py \
72 | --rec-train /imagenet/imagenet_train.rec \
73 | --rec-train-idx /imagenet/imagenet_train.idx \
74 | --rec-val /ramdisk/imagenet_val.rec \
75 | --rec-val-idx /ramdisk/imagenet_val.idx \
76 | --use-rec --mode hybrid --lr 0.4 --lr-mode cosine \
77 | --num-epochs 200 --batch-size 256 -j 32 --num-gpus 4 \
78 | --dtype float16 --warmup-epochs 5 --no-wd \
79 | --label-smoothing --mixup \
80 | --save-dir params_mymobilenet --logging-file mymobilenet.log \
81 | --configuration "(16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304)"
82 | ```
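To get a rough idea of a configuration's size before launching a long training run, its parameter count can be computed directly from the widths. The sketch below assumes the layer layout of `SlimMobilenet.reduce` in `model.py`: a 3x3 stem convolution, then depthwise 3x3 and pointwise 1x1 pairs, BatchNorm (weight and bias) after every convolution, no convolution biases, and a 1000-way linear classifier with bias. The Gluon model in `mxnet_mobilenet.py` is assumed to follow the same layout. `mobilenet_v1_params` is a hypothetical helper.

```python
def mobilenet_v1_params(config, in_channels=3, num_classes=1000):
    """Count trainable parameters of a MobileNet-v1 with per-layer widths `config`."""
    total = 9 * in_channels * config[0] + 2 * config[0]   # stem 3x3 conv + BN
    for c_in, c_out in zip(config, config[1:]):
        total += 9 * c_in + 2 * c_in                       # depthwise 3x3 conv + BN
        total += c_in * c_out + 2 * c_out                  # pointwise 1x1 conv + BN
    total += config[-1] * num_classes + num_classes         # linear classifier
    return total


mobilenet = (32, 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024)
searched = (16, 32, 64, 112, 360, 48, 464, 664, 152, 664, 256, 208, 816, 304)
print(mobilenet_v1_params(mobilenet))  # 4231976
print(mobilenet_v1_params(searched))
```

Parameter count is only a proxy. The search itself optimizes measured latency through the latency model.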
83 |
84 | ## Citation
85 | ```BibTeX
86 | @InProceedings{Berman2020AOWS,
87 | author = {Berman, Maxim and Pishchulin, Leonid and Xu, Ning and Blaschko, Matthew B. and Medioni, Gerard},
88 | title = {{AOWS}: adaptive and optimal network width search with latency constraints},
89 | booktitle = {Proceedings of the {IEEE} Computer Society Conference on Computer Vision and Pattern Recognition},
90 | month = jun,
91 | year = {2020},
92 | }
93 | ```
94 |
95 | ## Disclaimer
96 | The code was re-implemented and is not fully tested at this point.
97 |
98 |
--------------------------------------------------------------------------------
/latency.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import os.path as osp
6 | import random
7 | import sys
8 | import time
9 | from collections import Counter
10 | from collections import defaultdict
11 | from collections import namedtuple
12 |
13 | import cvxpy as cp
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | import torch
17 | from scipy.sparse import lil_matrix
18 | # lazy: from torch2trt import torch2trt
19 |
20 | from misc import DelayedKeyboardInterrupt
21 | from misc import tuplify
22 | from model import SlimMobilenet
23 | from model import LayerType
24 | from viterbi import complete
25 | from viterbi import maxsum
26 |
27 |
28 | logger = logging.getLogger(__name__)
29 | Vartype = namedtuple("Vartype", LayerType._fields + ('in_channels', 'out_channels'))
30 | torch.backends.cudnn.benchmark = True
31 |
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser(
35 | description="Generate samples and fit a latency model.")
36 |
37 | subparsers = parser.add_subparsers(dest='mode')
38 | subparsers.required = True
39 |
40 | parser_bench = subparsers.add_parser('benchmark',
41 | help="Benchmark a single channel configuration",
42 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
43 | parser_bench.add_argument("configuration",
44 | help="configuration to test (comma-separated channels or MOBILENET)")
45 |
46 | parser_gen = subparsers.add_parser('generate',
47 | help="Generate latency samples",
48 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | for subparser in (parser_bench, parser_gen):
50 | subparser.add_argument("-D", "--device", choices=["cpu", "gpu", "trt"],
51 | default="gpu", help="Use GPU, CPU or TensorRT latency")
52 | subparser.add_argument("--dtype", choices=["fp32", "fp16"],
53 | default="fp16", help="Datatype for network")
54 | subparser.add_argument("-B", "--batch-size", type=int, default=64,
55 | help="Batch size used for profiling")
56 | subparser.add_argument("-I", "--iterations", type=int, default=60,
57 | help="Profiling iterations")
58 | subparser.add_argument("-W", "--warmup", type=int, default=10,
59 | help="Warmup iterations")
60 | subparser.add_argument("--reduction", choices=['mean', 'min'], default='mean',
61 | help="Reduce timings by their mean or by their minimum (minimum can reduce variance)")
62 | parser_gen.add_argument("--biased", action="store_true",
63 | help="Bias sampling towards missing configurations")
64 | parser_gen.add_argument("-N", "--count", type=int, default=8000,
65 | help="Minimum number of samples to generate")
66 | parser_gen.add_argument("-R", "--repetitions", type=int, default=0,
67 | help="Minimum number of samples per choice")
68 | parser_gen.add_argument("--save-every", type=int, default=1,
69 | help="Number of inferences before saving intermediate output")
70 | parser_gen.add_argument("samples_file", help="Output samples file")
71 |
72 | parser_fit = subparsers.add_parser('fit', help="Fit a latency model")
73 | parser_fit.add_argument("-K", "--regularize", type=float, default=0.0,
74 | help="Amount of monotonicity regularization (Equation 7)")
75 | parser_fit.add_argument("samples_file", help="Training samples")
76 | parser_fit.add_argument("model_file", help="Output model file")
77 |
78 | parser_val = subparsers.add_parser('validate', help="Validate a latency model")
79 | parser_val.add_argument("samples_file", help="Validation samples")
80 | parser_val.add_argument("model_file", help="Model file")
81 | parser_val.add_argument("plot_file", help="Plot file")
82 |
83 | args = parser.parse_args()
84 |
85 | if 'configuration' in args:
86 | defaults = {'MOBILENET': "32,64,128,128,256,256,512,512,512,512,512,512,1024,1024"}
87 | if args.configuration in defaults:
88 | args.configuration = defaults[args.configuration]
89 | args.configuration = [int(''.join(ci for ci in c if ci.isdigit())) for c in args.configuration.split(',')]
90 |
91 | return args
92 |
93 |
94 | def get_model(min_width=0.2, max_width=1.5, levels=14):
95 | return SlimMobilenet(min_width=min_width, max_width=max_width, levels=levels)
96 |
97 |
98 | def benchmark(device, dtype, batch_size, iterations, warmup, reduction, configuration, silent=False):
99 | if device == 'cpu':
100 | dev = torch.device('cpu')
101 | elif device in ['gpu', 'trt']:
102 | dev = torch.device('cuda')
103 | fp = dict(fp16=torch.float16, fp32=torch.float32).get(dtype)
104 | net = SlimMobilenet.reduce(configuration).to(dev).type(fp).eval()
105 | x = torch.ones((batch_size, 3, 224, 224)).to(dev).type(fp)
106 | if device == 'trt':
107 | from torch2trt import torch2trt
108 | net = torch2trt(net, [x], fp16_mode=(dtype == 'fp16'), max_batch_size=batch_size)
109 | sync = torch.cuda.synchronize if dev.type == 'cuda' else (lambda: None)
110 | for i in range(warmup):
111 | outputs = net(x)
112 | sync()
113 |
114 | timings = []
115 | t0 = time.perf_counter()
116 | for i in range(iterations):
117 | outputs = net(x)
118 | sync()
119 | t1 = time.perf_counter()
120 | timings.append(t1 - t0)
121 | t0 = t1
122 |
123 | ms = 1000.0 * getattr(np, reduction)(timings) / batch_size
124 | if not silent:
125 | print(f"{configuration}: {ms}ms")
126 |
127 | return ms
128 |
129 |
130 | def gen_configuration_biased(net, repetitions):
131 | M = min(repetitions.values())
132 | unary = []
133 | pairwise = []
134 | for i, L in enumerate(net.components):
135 | input_choices = [net.in_channels] if i == 0 else net.configurations[i - 1]
136 | output_choices = ([net.out_channels] if i == len(net.components) - 1
137 | else net.configurations[i])
138 | U = np.zeros(len(input_choices))
139 | P = np.zeros((len(input_choices), len(output_choices)))
140 | for i1, I in enumerate(input_choices):
141 | for i2, O in enumerate(output_choices):
142 | var = Vartype(**L._asdict(), in_channels=I, out_channels=O)
143 | P[i1, i2] = float(repetitions[var] == M)
144 | unary.append(U)
145 | pairwise.append(P)
146 | unary.append(np.zeros(len(output_choices)))
147 | un, pair, states = complete(unary, pairwise)
148 | iconfig = maxsum(un, pair, states)[1]
149 | configuration = [C[i] for (C, i) in zip(net.configurations, iconfig[1:-1])]
150 | return configuration
151 |
152 |
153 | def gen_configuration(net, repetitions, biased=False):
154 | if biased:
155 | return gen_configuration_biased(net, repetitions)
156 | return [random.choice(conf) for conf in net.configurations]
157 |
158 |
159 | def collect_repetitions(net, configuration=None):
160 | if configuration is None:
161 | configuration = net.configurations
162 | if isinstance(configuration[0], (int, np.integer)): # single configuration
163 | configuration = [[c] for c in configuration]
164 | layertypes = Counter()
165 | for i, L in enumerate(net.components):
166 | input_choices = [net.in_channels] if i == 0 else configuration[i - 1]
167 | output_choices = ([net.out_channels] if i == len(net.components) - 1
168 | else configuration[i])
169 | for I in input_choices:
170 | for O in output_choices:
171 | var = Vartype(**L._asdict(), in_channels=I, out_channels=O)
172 | layertypes[var] += 1
173 | return layertypes
174 |
175 |
176 | def sample_file_iterator(samples_file):
177 | with open(samples_file, 'r') as f:
178 | for line in f:
179 | yield tuplify(json.loads(line))
180 |
181 |
182 | def generate(device, dtype, batch_size, iterations, warmup, reduction, biased,
183 | count, repetitions, samples_file=os.devnull, save_every=10):
184 | os.makedirs(osp.dirname(samples_file) or '.', exist_ok=True)
185 |
186 | net = get_model()
187 | combinations = collect_repetitions(net)
188 | logger.info(f"{len(net.configurations)} modulers")
189 | logger.debug(f"search space: {net.configurations}")
190 | logger.debug(f"components: {net.components}")
191 | logger.info(f"Latency model has {len(combinations)} parameters")
192 |
193 | repeats = Counter()
194 | for c in combinations:
195 | repeats[c] = 0
196 |
197 | samples = []
198 | if osp.isfile(samples_file):
199 | for sample in sample_file_iterator(samples_file):
200 | samples.append(sample)
201 | repeats.update(collect_repetitions(net, sample[0]))
202 | logger.info(f"Loaded {samples_file}, "
203 | f"min_repetition={min(repeats.values())} "
204 | f"count={len(samples)} ")
205 | logger.info(f"Writing new samples to {samples_file}")
206 | new_samples = []
207 | while (len(samples) + len(new_samples) < count
208 | or min(repeats.values()) < repetitions):
209 | configuration = gen_configuration(net, repeats, biased=biased)
210 | ms = benchmark(device, dtype, batch_size, iterations, warmup, reduction, configuration, silent=True)
211 | repeats.update(collect_repetitions(net, configuration))
212 | logger.info(f"{configuration}: {ms:.04f}ms, "
213 | f"min_repetition={min(repeats.values())} "
214 | f"count={len(samples) + len(new_samples)} ")
215 | new_samples.append([[int(d) for d in configuration], ms])
216 | if (len(new_samples) % save_every) == 0:
217 | with open(samples_file, 'a') as f:
218 | for sample in new_samples:
219 | dump = json.dumps(sample) + '\n'
220 | with DelayedKeyboardInterrupt():
221 | f.write(dump)
222 | samples.extend(new_samples)
223 | new_samples = []
224 |
225 | samples.extend(new_samples)
226 | return samples
227 |
228 |
229 | def build_equation(samples):
230 | """
231 | Build the linear system M @ x ~= y. `samples` can be any iterable (e.g. a generator) of (configuration, ms) pairs.
232 | """
233 | net = get_model()
234 | variables = {}
235 | ivariables = {}
236 | Mcoord = []
237 | y = []
238 | for (i, sample) in enumerate(samples):
239 | y.append(sample[1])
240 | local_repeats = collect_repetitions(net, sample[0])
241 | for (L, r) in local_repeats.items():
242 | if L not in variables:
243 | j = len(variables)
244 | variables[L] = j
245 | ivariables[j] = L
246 | Mcoord.append((i, variables[L], r))
247 | y = np.array(y)
248 | M = lil_matrix((len(y), len(variables)))
249 | for (i, j, r) in Mcoord:
250 | M[i, j] = r
251 | return M, y, variables, ivariables
252 |
253 |
254 | def solve_lsq(M, y, regularize=0.0, K=None):
255 | n = M.shape[1]
256 | x = cp.Variable(n)
257 | M_cp = cp.Constant(M)
258 | obj = cp.sum_squares(M_cp @ x - y)
259 | constraints = [x >= 0]
260 | if regularize:
261 | t = cp.Variable(K.shape[0])  # slack variables: t >= max(0, K @ x)
262 | K_cp = cp.Constant(K)
263 | obj += regularize * cp.sum_squares(t)
264 | constraints += [t >= 0, K_cp @ x <= t]
265 | objective = cp.Minimize(obj)
266 | prob = cp.Problem(objective, constraints)
267 | prob.solve(cp.SCS, verbose=True)
268 | return x.value
269 |
270 |
271 | def get_inequalities(variables):
272 | def other(L, *args):
273 | props = L._asdict()
274 | for k in args:
275 | del props[k]
276 | return tuple(props.values())
277 | buckets = defaultdict(list)
278 | for order in ['in_channels', 'out_channels', 'in_size']:
279 | for V in variables:
280 | buckets[(order, other(V, order))].append(V)  # key on the varied field so in/out buckets never merge
281 | inequalities = []
282 | for bucket in buckets.values():
283 | bucket = sorted(bucket)
284 | for i in range(len(bucket) - 1):
285 | inequalities.append((bucket[i], bucket[i + 1]))
286 | K = lil_matrix((len(inequalities), len(variables)))
287 | for i, (C1, C2) in enumerate(inequalities):
288 | K[i, variables[C1]] = 1
289 | K[i, variables[C2]] = -1
290 | return K
291 |
292 |
293 | def fit_model(samples, regularize=0.0):
294 | M, y, variables, ivariables = build_equation(samples)
295 | K = get_inequalities(variables)
296 | x = solve_lsq(M, y, regularize, K)
297 | model = []
298 | for i, ms in enumerate(x):
299 | model.append((ivariables[i], ms))
300 | return model
301 |
302 |
303 | def dump_model(model, model_file):
304 | with open(model_file, 'w') as f:
305 | for m in model:
306 | var, ms = m
307 | dump = json.dumps([var._asdict(), ms]) + '\n'
308 | f.write(dump)
309 |
310 |
311 | def load_model(model_file):
312 | with open(model_file, 'r') as f:
313 | for line in f:
314 | var, ms = tuplify(json.loads(line))
315 | var = Vartype(**var)
316 | yield (var, ms)
317 |
318 |
319 | def fit(samples_file, model_file, regularize=0.0):
320 | os.makedirs(osp.dirname(model_file) or '.', exist_ok=True)
321 | samples = sample_file_iterator(samples_file)
322 | model = fit_model(samples, regularize)
323 | dump_model(model, model_file)
324 | return model
325 |
326 |
327 | def validate(samples_file, model_file, plot_file):
328 | os.makedirs(osp.dirname(plot_file) or '.', exist_ok=True)
329 | model = load_model(model_file)
330 | model_dict = dict(model)
331 | samples = sample_file_iterator(samples_file)
332 | M, y, variables, ivariables = build_equation(samples)
333 | x = [model_dict[ivariables[i]] for i in range(len(variables))]
334 | yhat = M @ x
335 | rmse = np.sqrt(((y - yhat) ** 2).mean())
336 | title = f"RMSE {rmse:.04f}, NRMSE {100 * rmse / y.mean():.02f}%"
337 | print(title)
338 | plt.plot(y, yhat, 'o')
339 | plt.xlabel("ground truth (ms)")
340 | plt.ylabel("predicted (ms)")
341 | plt.title(title)
342 | plt.savefig(plot_file)
343 |
344 |
345 | if __name__ == "__main__":
346 | logger = logging.getLogger(__file__)
347 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG,
348 | format='%(name)s: %(message)s')
349 | args = parse_args().__dict__
350 |
351 | globals()[args.pop('mode')](**args)
352 |
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | import signal
2 | import logging
3 | from collections import defaultdict
4 | from collections import deque
5 | import numpy as np
6 | import torch
7 |
8 |
9 | # https://stackoverflow.com/a/21919644/805502
10 | class DelayedKeyboardInterrupt(object):
11 | def __enter__(self):
12 | self.signal_received = False
13 | self.old_handler = signal.signal(signal.SIGINT, self.handler)
14 |
15 | def handler(self, sig, frame):
16 | self.signal_received = (sig, frame)
17 | logging.debug('SIGINT received. Delaying KeyboardInterrupt.')
18 |
19 | def __exit__(self, type, value, traceback):
20 | signal.signal(signal.SIGINT, self.old_handler)
21 | if self.signal_received:
22 | self.old_handler(*self.signal_received)
23 |
24 |
25 | # https://stackoverflow.com/a/25294767/805502
26 | def tuplify(listything):
27 | if isinstance(listything, list): return tuple(map(tuplify, listything))
28 | if isinstance(listything, dict): return {k:tuplify(v) for k,v in listything.items()}
29 | return listything
30 |
31 |
32 | class SWDict(dict):
33 | """
34 | Single-write dict. Useful for making sure no inference is computed twice.
35 | """
36 | def __setitem__(self, key, value):
37 | if key in self:
38 | raise ValueError('key', key, 'already set')
39 | super().__setitem__(key, value)
40 |
41 |
42 | class SWDefaultDict(defaultdict):
43 | """
44 | Single-write defaultdict.
45 | """
46 | def __setitem__(self, key, value):
47 | if key in self:
48 | raise ValueError('key', key, 'already set')
49 | super().__setitem__(key, value)
50 |
51 |
52 | class MovingAverageMeter(object):
53 | def __init__(self, window):
54 | self.window = window
55 | self.reset()
56 |
57 | def reset(self):
58 | self.history = deque()
59 | self.avg = 0
60 | self.sum = None
61 | self.val = None
62 |
63 | @property
64 | def count(self):
65 | return len(self.history)
66 |
67 | @property
68 | def isfull(self):
69 | return len(self.history) == self.window
70 |
71 | def __getstate__(self):
72 | state = self.__dict__.copy()
73 | state['history'] = np.array(state['history'])
74 | return state
75 |
76 | def __setstate__(self, state):
77 | state['history'] = deque(state['history'])
78 | self.__dict__.update(state)
79 |
80 | def update(self, val, epoch, iteration):
81 | self.history.append(val)
82 | if self.sum is None:
83 | self.sum = val
84 | else:
85 | self.sum += val
86 | if len(self.history) > self.window:
87 | self.sum -= self.history.popleft()
88 | self.val = val
89 | self.avg = self.sum / self.count
90 |
91 | def __repr__(self):
92 | return "MovingAverageMeter(window={}, count={}, val={}, avg={})".format(
93 | self.window, self.count, self.val, self.avg)
94 |
95 |
96 | class AverageMeter(object):
97 | """Computes and stores the average and current value"""
98 |
99 | def __init__(self):
100 | self.reset()
101 |
102 | def reset(self):
103 | self.val = 0
104 | self.avg = 0
105 | self.sum = 0
106 | self.count = 0
107 |
108 | def update(self, val, n=1):
109 | self.val = val
110 | self.sum += val * n
111 | self.count += n
112 | self.avg = self.sum / self.count
113 |
114 |
115 | def accuracy(output, target, topk=(1,), return_correct_k=False):
116 | """Computes the precision@k for the specified values of k"""
117 | maxk = max(topk)
118 | batch_size = target.size(0)
119 |
120 | _, pred = output.topk(maxk, 1, True, True)
121 | pred = pred.t()
122 | correct = pred.eq(target.view(1, -1).expand_as(pred))
123 |
124 | res = []
125 | correct_ks = []
126 |
127 | for k in topk:
128 | correct_k = correct[:k].float().sum(0)
129 | res.append(correct_k.sum().mul_(100.0 / batch_size))
130 | correct_ks.append(correct_k)
131 | if return_correct_k:
132 | return res, correct_ks
133 | return res
134 |
135 |
136 | def soft_cross_entropy(output, target):
137 | """
138 | For knowledge distillation in self-distillation
139 | """
140 | output_log_prob = torch.nn.functional.log_softmax(output, dim=1)
141 | target = target.unsqueeze(1)
142 | output_log_prob = output_log_prob.unsqueeze(2)
143 | cross_entropy_loss = -torch.bmm(target, output_log_prob).view(output.size(0))
144 | return cross_entropy_loss
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Model defining the search space.
3 | Implemented: Mobilenet-v1 family.
4 | """
5 |
6 | import torch
7 | from torch import nn
8 | import numpy as np
9 | from collections import namedtuple
10 | import logging
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 | LayerType = namedtuple('LayerType', ['in_size', 'kernel_size', 'stride', 'dw', 'bias'])
15 |
16 |
17 | class Moduler(nn.Module):
18 | """
19 | Dynamically trims input channels (randomly or based on argument)
20 | """
21 |
22 | def __init__(self, configurations):
23 | super().__init__()
24 | self.configurations = configurations
25 | self.base_channels = 8
26 | self.probability = None # for biased sampling
27 |
28 | def forward(self, data, channels=None, record=True):
29 | if not isinstance(data, dict):
30 | data = dict(x=data)
31 | x = data['x']
32 |
33 | if channels is None:
34 | idx = np.random.choice(np.arange(len(self.configurations)),
35 | size=x.size(0),
36 | p=self.probability)
37 | confs = self.configurations[idx]
38 | else:
39 | confs = channels * np.ones((x.size(0),), int)
40 |
41 | mask = x.new_zeros((x.size(0), x.size(1) + 1))
42 | mask[np.arange(len(confs)), confs] = 1.0
43 | mask = 1 - mask[:, :x.size(1)].cumsum(1)
44 | x = x * mask.unsqueeze(2).unsqueeze(3)
45 |
46 | data['x'] = x
47 | if record: # record chosen channels
48 | if 'decision' not in data: data['decision'] = []
49 | data['decision'].append(confs)
50 | return data
51 |
52 | def __repr__(self):
53 | return "Moduler({})".format(self.configurations)
54 |
55 |
56 | class Flatten(nn.Module):
57 | def forward(self, x):
58 | return x.view(x.size(0), -1)
59 |
60 |
61 | class SlimMobilenet(nn.Module):
62 |
63 | in_channels = 3
64 | out_channels = 1000
65 |
66 | @staticmethod
67 | def gen_conv(inp, oup, stride, dw=False, bn=True):
68 | mod = []
69 | if dw:
70 | mod = [
71 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
72 | nn.BatchNorm2d(inp),
73 | nn.ReLU(inplace=True),
74 |
75 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
76 | nn.BatchNorm2d(oup),
77 | nn.ReLU(inplace=True),
78 | ]
79 | else:
80 | mod = [
81 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
82 | nn.BatchNorm2d(oup),
83 | nn.ReLU(inplace=True),
84 | ]
85 | if not bn:
86 | mod = [m for m in mod if not isinstance(m, nn.BatchNorm2d)]
87 | return nn.Sequential(*mod)
88 |
89 | @staticmethod
90 | def strides_channels():
91 | """
92 | follows Mobilenet-v1 definition
93 | """
94 | blocks = [[32, 64], # special first stem block
95 | [128, 128],
96 | [256, 256],
97 | [512, 512, 512, 512, 512, 512],
98 | [1024, 1024],
99 | ]
100 | strides = []
101 | for block in blocks:
102 | strides.extend([2] + [1] * (len(block) - 1))
103 | base_channels = np.array([c for block in blocks for c in block])
104 |
105 | return strides, base_channels
106 |
107 | def __init__(self, min_width=0.2, max_width=1.5, levels=14, fc_dropout=0.0, in_size=(224, 224)):
108 | super().__init__()
109 |
110 | def divise8(i):
111 | return (np.maximum(np.round(i / 8), 1) * 8).astype(int)
112 |
113 | strides, base_channels = self.strides_channels()
114 | depthwise = [0] + [1] * (len(base_channels) - 1)
115 |
116 | self.configurations = divise8(base_channels.reshape(-1, 1) * np.linspace(min_width, max_width, levels).reshape(1, -1))
117 | self.configurations = [np.unique(c) for c in self.configurations]
118 |
119 | self.components = []
120 |
121 | channels = [self.in_channels] + [int(c[-1]) for c in self.configurations]
122 | inp = iter(channels)
123 | oup = iter(channels[1:])
124 |
125 | self.model = nn.ModuleList()
126 | for dw, strid in zip(depthwise, strides):
127 | I = next(inp)
128 | O = next(oup)
129 | mod = self.gen_conv(I, O, strid, dw)
130 | component = LayerType(in_size=in_size, kernel_size=3, stride=strid, dw=bool(dw), bias=False)
131 | in_size = (in_size[0] // strid, in_size[1] // strid)
132 | self.model.append(mod)
133 | self.components.append(component)
134 |
135 | self.filters = nn.ModuleList()
136 | for conf, base_chan in zip(self.configurations, base_channels):
137 | F = Moduler(conf)
138 | F.base_channels = base_chan
139 | self.filters.append(F)
140 |
141 | self.pool = nn.AvgPool2d(7)
142 | self.fc_dropout = None if not fc_dropout else nn.Dropout(fc_dropout)
143 | in_size = (in_size[0] // 7, in_size[1] // 7)
144 |
145 | I = next(inp)
146 | self.fc = nn.Linear(I, self.out_channels)
147 | self.components.append(LayerType(in_size=in_size, kernel_size=1, stride=1, dw=False, bias=True))
148 |
149 | def forward(self, data, configuration=None):
150 | if not isinstance(data, dict):
151 | data = dict(x=data)
152 | for i, (conv, filter) in enumerate(zip(self.model, self.filters)):
153 | data['x'] = conv(data['x'])
154 | data = filter(data,
155 | channels=(configuration[i] if configuration is not None else None))
156 | data['x'] = self.pool(data['x'])
157 | data['x'] = data['x'].view(data['x'].size(0), -1)
158 | if self.fc_dropout is not None:
159 | data['x'] = self.fc_dropout(data['x'])
160 | data['x'] = self.fc(data['x'])
161 | data['decision'] = torch.tensor(np.array(data['decision']).T, device=data['x'].device)
162 | return data
163 |
164 | @classmethod
165 | def reduce(cls, C=(32, 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024),
166 | bn=True):
167 | """
168 | Remove all modulers and reduce according to a single channel configuration
169 | """
170 | modules = []
171 | I = cls.in_channels
172 | depthwise = [False] + [True] * (len(C) - 1)
173 | strides, base_channels = cls.strides_channels()
174 | assert len(strides) == len(C)
175 |
176 | for O, stride, dw in zip(C, strides, depthwise):
177 | modules.append(cls.gen_conv(I, O, stride, dw, bn=bn))
178 | I = O
179 | modules += [nn.AvgPool2d(7), Flatten(), nn.Linear(I, cls.out_channels)]
180 | reduced = nn.Sequential(*modules)
181 | return reduced
182 |
183 |
184 |
--------------------------------------------------------------------------------
/mxnet_mobilenet.py:
--------------------------------------------------------------------------------
1 | # Licensed to the Apache Software Foundation (ASF) under one
2 | # or more contributor license agreements. See the NOTICE file
3 | # distributed with this work for additional information
4 | # regarding copyright ownership. The ASF licenses this file
5 | # to you under the Apache License, Version 2.0 (the
6 | # "License"); you may not use this file except in compliance
7 | # with the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing,
12 | # software distributed under the License is distributed on an
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | # KIND, either express or implied. See the License for the
15 | # specific language governing permissions and limitations
16 | # under the License.
17 |
18 | # coding: utf-8
19 | # pylint: disable= arguments-differ,unused-argument,missing-docstring,too-many-function-args
20 | """
21 | MobileNet and MobileNetV2, implemented in Gluon.
22 | File based on Gluon-cv mobilenet.py at https://github.com/dmlc/gluon-cv/blob/3c4150a964c776e4f7da0eb30b55ab05b7554c8d/gluoncv/model_zoo/mobilenet.py
23 | """
24 |
25 | from mxnet.gluon import nn
26 | from mxnet.gluon.nn import BatchNorm
27 | from mxnet.context import cpu
28 | from mxnet.gluon.block import HybridBlock
29 | from gluoncv.nn import ReLU6
30 | import logging
31 |
32 | logger = logging.getLogger(__file__)
33 |
34 |
35 | __all__ = ['get_mobilenet']
36 |
37 |
38 | # pylint: disable= too-many-arguments
39 | def _add_conv(out, channels=1, kernel=1, stride=1, pad=0,
40 | num_group=1, active=True, relu6=False, norm_layer=BatchNorm, norm_kwargs=None):
41 | out.add(nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False))
42 | out.add(norm_layer(scale=True, **({} if norm_kwargs is None else norm_kwargs)))
43 | if active:
44 | out.add(ReLU6() if relu6 else nn.Activation('relu'))
45 |
46 |
47 | def _add_conv_dw(out, dw_channels, channels, stride, relu6=False,
48 | norm_layer=BatchNorm, norm_kwargs=None):
49 | _add_conv(out, channels=dw_channels, kernel=3, stride=stride,
50 | pad=1, num_group=dw_channels, relu6=relu6,
51 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
52 | _add_conv(out, channels=channels, relu6=relu6,
53 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
54 |
55 |
56 | class LinearBottleneck(nn.HybridBlock):
57 | r"""LinearBottleneck used in MobileNetV2 model from the
58 | `"Inverted Residuals and Linear Bottlenecks:
59 | Mobile Networks for Classification, Detection and Segmentation"
60 | `_ paper.
61 |
62 | Parameters
63 | ----------
64 | in_channels : int
65 | Number of input channels.
66 | channels : int
67 | Number of output channels.
68 | t : int
69 | Layer expansion ratio.
70 | stride : int
71 | Stride of the 3x3 depthwise convolution.
72 | norm_layer : object
73 | Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
74 | Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
75 | norm_kwargs : dict
76 | Additional `norm_layer` arguments, for example `num_devices=4`
77 | for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
78 | """
79 |
80 | def __init__(self, in_channels, channels, t, stride,
81 | norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
82 | super(LinearBottleneck, self).__init__(**kwargs)
83 | self.use_shortcut = stride == 1 and in_channels == channels
84 | with self.name_scope():
85 | self.out = nn.HybridSequential()
86 |
87 | if t != 1:
88 | _add_conv(self.out,
89 | in_channels * t,
90 | relu6=True,
91 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
92 | _add_conv(self.out,
93 | in_channels * t,
94 | kernel=3,
95 | stride=stride,
96 | pad=1,
97 | num_group=in_channels * t,
98 | relu6=True,
99 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
100 | _add_conv(self.out,
101 | channels,
102 | active=False,
103 | relu6=True,
104 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
105 |
106 | def hybrid_forward(self, F, x):
107 | out = self.out(x)
108 | if self.use_shortcut:
109 | out = F.elemwise_add(out, x)
110 | return out
111 |
112 |
113 | # Net
114 | class MobileNet(HybridBlock):
115 | r"""MobileNet model from the
116 | `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
117 | `_ paper.
118 |
119 | Parameters
120 | ----------
121 | multiplier : float, default 1.0
122 | The width multiplier for controlling the model size. Only multipliers that are no
123 | less than 0.25 are supported. The actual number of channels is equal to the original
124 | channel size multiplied by this multiplier.
125 | classes : int, default 1000
126 | Number of classes for the output layer.
127 | norm_layer : object
128 | Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
129 | Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
130 | norm_kwargs : dict
131 | Additional `norm_layer` arguments, for example `num_devices=4`
132 | for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
133 | """
134 |
135 | def __init__(self, multiplier=1.0, classes=1000,
136 | norm_layer=BatchNorm, norm_kwargs=None, configuration=None, **kwargs):
137 | if configuration is None or configuration == 'MOBILENET':
138 | configuration = (32, 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024)
139 | else:
140 | configuration = [int(''.join(ci for ci in c if ci.isdigit())) for c in configuration.split(',')]
141 | configuration = tuple(configuration)
142 |
143 | logger.info('Mobilenet with channel configuration: {}'.format(configuration))
144 |
145 | super(MobileNet, self).__init__(**kwargs)
146 | with self.name_scope():
147 | self.features = nn.HybridSequential(prefix='')
148 | with self.features.name_scope():
149 | _add_conv(self.features, channels=int(configuration[0] * multiplier), kernel=3, pad=1, stride=2,
150 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
151 | dw_channels = [int(x * multiplier) for x in configuration[:-1]]
152 | configuration = [int(x * multiplier) for x in configuration[1:]]
153 | strides = [1, 2] * 3 + [1] * 5 + [2, 1]
154 | for dwc, c, s in zip(dw_channels, configuration, strides):
155 | _add_conv_dw(self.features, dw_channels=dwc, channels=c, stride=s,
156 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
157 | self.features.add(nn.GlobalAvgPool2D())
158 | self.features.add(nn.Flatten())
159 |
160 | self.output = nn.Dense(classes)
161 |
162 | def hybrid_forward(self, F, x):
163 | x = self.features(x)
164 | x = self.output(x)
165 | return x
166 |
167 |
168 | class MobileNetV2(nn.HybridBlock):
169 | r"""MobileNetV2 model from the
170 | `"Inverted Residuals and Linear Bottlenecks:
171 | Mobile Networks for Classification, Detection and Segmentation"
172 | `_ paper.
173 | Parameters
174 | ----------
175 | multiplier : float, default 1.0
176 | The width multiplier for controlling the model size. The actual number of channels
177 | is equal to the original channel size multiplied by this multiplier.
178 | classes : int, default 1000
179 | Number of classes for the output layer.
180 | norm_layer : object
181 | Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
182 | Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
183 | norm_kwargs : dict
184 | Additional `norm_layer` arguments, for example `num_devices=4`
185 | for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
186 | """
187 |
188 | def __init__(self, multiplier=1.0, classes=1000,
189 | norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
190 | super(MobileNetV2, self).__init__(**kwargs)
191 | with self.name_scope():
192 | self.features = nn.HybridSequential(prefix='features_')
193 | with self.features.name_scope():
194 | _add_conv(self.features, int(32 * multiplier), kernel=3,
195 | stride=2, pad=1, relu6=True,
196 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
197 |
198 | in_channels_group = [int(x * multiplier) for x in [32] + [16] + [24] * 2
199 | + [32] * 3 + [64] * 4 + [96] * 3 + [160] * 3]
200 | channels_group = [int(x * multiplier) for x in [16] + [24] * 2 + [32] * 3
201 | + [64] * 4 + [96] * 3 + [160] * 3 + [320]]
202 | ts = [1] + [6] * 16
203 | strides = [1, 2] * 2 + [1, 1, 2] + [1] * 6 + [2] + [1] * 3
204 |
205 | for in_c, c, t, s in zip(in_channels_group, channels_group, ts, strides):
206 | self.features.add(LinearBottleneck(in_channels=in_c,
207 | channels=c,
208 | t=t,
209 | stride=s,
210 | norm_layer=norm_layer,
211 | norm_kwargs=norm_kwargs))
212 |
213 | last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
214 | _add_conv(self.features,
215 | last_channels,
216 | relu6=True,
217 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)
218 |
219 | self.features.add(nn.GlobalAvgPool2D())
220 |
221 | self.output = nn.HybridSequential(prefix='output_')
222 | with self.output.name_scope():
223 | self.output.add(
224 | nn.Conv2D(classes, 1, use_bias=False, prefix='pred_'),
225 | nn.Flatten())
226 |
227 | def hybrid_forward(self, F, x):
228 | x = self.features(x)
229 | x = self.output(x)
230 | return x
231 |
232 |
233 | # Constructor
234 | def get_mobilenet(multiplier=1.0, pretrained=False, ctx=cpu(),
235 | root='~/.mxnet/models', norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
236 | r"""MobileNet model from the
237 | `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
238 | `_ paper.
239 |
240 | Parameters
241 | ----------
242 | multiplier : float
243 | The width multiplier for controlling the model size. Only multipliers that are no
244 | less than 0.25 are supported. The actual number of channels is equal to the original
245 | channel size multiplied by this multiplier.
246 | pretrained : bool or str
247 | Must be False: this modified script provides no pretrained weights and raises
248 | NotImplementedError otherwise. Kept for compatibility with the gluon-cv signature.
249 | ctx : Context, default CPU
250 | The context in which to load the pretrained weights.
251 | root : str, default '~/.mxnet/models'
252 | Location for keeping the model parameters.
253 | norm_layer : object
254 | Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
255 | Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
256 | norm_kwargs : dict
257 | Additional `norm_layer` arguments, for example `num_devices=4`
258 | for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
259 | """
260 | if pretrained:
261 | raise NotImplementedError("No pretrained weights in this modified mobilenet script")
262 | net = MobileNet(multiplier, norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
263 | return net
264 |
--------------------------------------------------------------------------------
/train_final.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a MobileNet variant with custom per-layer channel numbers.
3 | File based on Gluon-cv train_imagenet.py at https://github.com/dmlc/gluon-cv/blob/18f8ab526ffb97660e6e5661f991064c20e2699d/scripts/classification/imagenet/train_imagenet.py
4 | """
5 |
6 | import argparse, time, logging, os, math
7 |
8 | import numpy as np
9 | import mxnet as mx
10 | import gluoncv as gcv
11 | from mxnet import gluon, nd
12 | from mxnet import autograd as ag
13 | from mxnet.gluon.data.vision import transforms
14 |
15 | from gluoncv.data import imagenet
16 | from gluoncv.model_zoo import get_model
17 | from gluoncv.utils import makedirs, LRSequential, LRScheduler
18 |
19 | import mxnet_mobilenet as mobilenet
20 |
21 |
22 | _models = {
23 | 'mobilenet': mobilenet.get_mobilenet,
24 | }
25 |
26 |
27 | # CLI
28 | def parse_args():
29 | parser = argparse.ArgumentParser(description='Train a model for image classification.')
30 | parser.add_argument('--data-dir', type=str, default='~/.mxnet/datasets/imagenet',
31 | help='training and validation pictures to use.')
32 | parser.add_argument('--rec-train', type=str, default='~/.mxnet/datasets/imagenet/rec/train.rec',
33 | help='the training data')
34 | parser.add_argument('--rec-train-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/train.idx',
35 | help='the index of training data')
36 | parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val.rec',
37 | help='the validation data')
38 | parser.add_argument('--rec-val-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/val.idx',
39 | help='the index of validation data')
40 | parser.add_argument('--use-rec', action='store_true',
41 | help='use image record iter for data input. default is false.')
42 | parser.add_argument('--batch-size', type=int, default=32,
43 | help='training batch size per device (CPU/GPU).')
44 | parser.add_argument('--dtype', type=str, default='float32',
45 | help='data type for training. default is float32')
46 | parser.add_argument('--num-gpus', type=int, default=0,
47 | help='number of gpus to use.')
48 | parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
49 | help='number of preprocessing workers')
50 | parser.add_argument('--num-epochs', type=int, default=3,
51 | help='number of training epochs.')
52 | parser.add_argument('--lr', type=float, default=0.1,
53 | help='learning rate. default is 0.1.')
54 | parser.add_argument('--momentum', type=float, default=0.9,
55 | help='momentum value for optimizer, default is 0.9.')
56 | parser.add_argument('--wd', type=float, default=0.0001,
57 | help='weight decay rate. default is 0.0001.')
58 | parser.add_argument('--lr-mode', type=str, default='step',
59 | help='learning rate scheduler mode. options are step, poly and cosine.')
60 | parser.add_argument('--lr-decay', type=float, default=0.1,
61 | help='decay rate of learning rate. default is 0.1.')
62 | parser.add_argument('--lr-decay-period', type=int, default=0,
63 | help='interval for periodic learning rate decays. default is 0 to disable.')
64 | parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
65 | help='epochs at which learning rate decays. default is 40,60.')
66 | parser.add_argument('--warmup-lr', type=float, default=0.0,
67 | help='starting warmup learning rate. default is 0.0.')
68 | parser.add_argument('--warmup-epochs', type=int, default=0,
69 | help='number of warmup epochs.')
70 | parser.add_argument('--last-gamma', action='store_true',
71 | help='whether to init gamma of the last BN layer in each bottleneck to 0.')
72 | parser.add_argument('--mode', type=str,
73 | help="mode in which to train the model. 'hybrid' hybridizes the network; any other value trains imperatively.")
74 | parser.add_argument('--configuration', type=str, default=None,
75 | help=("The custom mobilenet configuration to train. "
76 | "Can be MOBILENET or comma-separated channel numbers."))
77 | parser.add_argument('--model', type=str, default="mobilenet",
78 | help='type of model to use. set to mobilenet in this modified version.')
79 | parser.add_argument('--input-size', type=int, default=224,
80 | help='size of the input image size. default is 224')
81 | parser.add_argument('--crop-ratio', type=float, default=0.875,
82 | help='Crop ratio during validation. default is 0.875')
83 | parser.add_argument('--use-pretrained', action='store_true',
84 | help='enable using pretrained model from gluon.')
85 | parser.add_argument('--use_se', action='store_true',
86 | help='use SE layers or not in resnext. default is false.')
87 | parser.add_argument('--mixup', action='store_true',
88 | help='whether train the model with mix-up. default is false.')
89 | parser.add_argument('--mixup-alpha', type=float, default=0.2,
90 | help='beta distribution parameter for mixup sampling, default is 0.2.')
91 | parser.add_argument('--mixup-off-epoch', type=int, default=0,
92 | help='how many last epochs to train without mixup, default is 0.')
93 | parser.add_argument('--label-smoothing', action='store_true',
94 | help='use label smoothing or not in training. default is false.')
95 | parser.add_argument('--no-wd', action='store_true',
96 | help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
97 | parser.add_argument('--teacher', type=str, default=None,
98 | help='teacher model for distillation training')
99 | parser.add_argument('--temperature', type=float, default=20,
100 | help='temperature parameter for distillation teacher model')
101 | parser.add_argument('--hard-weight', type=float, default=0.5,
102 | help='weight for the loss of one-hot label for distillation training')
103 | parser.add_argument('--batch-norm', action='store_true',
104 | help='enable batch normalization or not in vgg. default is false.')
105 | parser.add_argument('--save-frequency', type=int, default=10,
106 | help='frequency of model saving.')
107 | parser.add_argument('--save-dir', type=str, default='params',
108 | help='directory of saved models')
109 | parser.add_argument('--resume-epoch', type=int, default=0,
110 | help='epoch to resume training from.')
111 | parser.add_argument('--resume-params', type=str, default='',
112 | help='path of parameters to load from.')
113 | parser.add_argument('--resume-states', type=str, default='',
114 | help='path of trainer state to load from.')
115 | parser.add_argument('--log-interval', type=int, default=50,
116 | help='Number of batches to wait before logging.')
117 | parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
118 | help='name of training log file')
119 | parser.add_argument('--use-gn', action='store_true',
120 | help='whether to use group norm.')
121 | opt = parser.parse_args()
122 | return opt
123 |
124 |
125 | def main():
126 | opt = parse_args()
127 |
128 | filehandler = logging.FileHandler(opt.logging_file)
129 | streamhandler = logging.StreamHandler()
130 |
131 | logger = logging.getLogger('')
132 | logger.setLevel(logging.INFO)
133 | logger.addHandler(filehandler)
134 | logger.addHandler(streamhandler)
135 |
136 | logger.info(opt)
137 |
138 | batch_size = opt.batch_size
139 | classes = 1000
140 | num_training_samples = 1281167
141 |
142 | num_gpus = opt.num_gpus
143 | batch_size *= max(1, num_gpus)
144 | context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
145 | num_workers = opt.num_workers
146 |
147 | lr_decay = opt.lr_decay
148 | lr_decay_period = opt.lr_decay_period
149 | if opt.lr_decay_period > 0:
150 | lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
151 | else:
152 | lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
153 | lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
154 | num_batches = num_training_samples // batch_size
155 |
156 | lr_scheduler = LRSequential([
157 | LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr,
158 | nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
159 | LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
160 | nepochs=opt.num_epochs - opt.warmup_epochs,
161 | iters_per_epoch=num_batches,
162 | step_epoch=lr_decay_epoch,
163 | step_factor=lr_decay, power=2)
164 | ])
165 |
166 | model_name = opt.model
167 |
168 | kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
169 | if opt.use_gn:
170 | from gluoncv.nn import GroupNorm
171 | kwargs['norm_layer'] = GroupNorm
172 | if model_name.startswith('vgg'):
173 | kwargs['batch_norm'] = opt.batch_norm
174 | elif model_name.startswith('resnext'):
175 | kwargs['use_se'] = opt.use_se
176 |
177 | if opt.last_gamma:
178 | kwargs['last_gamma'] = True
179 | if opt.configuration is not None:
180 | kwargs['configuration'] = opt.configuration
181 |
182 | optimizer = 'nag'
183 | optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
184 | if opt.dtype != 'float32':
185 | optimizer_params['multi_precision'] = True
186 |
187 | if model_name in _models:
188 | net = _models[model_name](**kwargs)
189 | else:
190 | net = get_model(model_name, **kwargs)
191 | net.cast(opt.dtype)
192 | if opt.resume_params != '':
193 | net.load_parameters(opt.resume_params, ctx = context)
194 |
195 | # teacher model for distillation training
196 | if opt.teacher is not None and opt.hard_weight < 1.0:
197 | teacher_name = opt.teacher
198 | teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
199 | teacher.cast(opt.dtype)
200 | distillation = True
201 | else:
202 | distillation = False
203 |
204 | # Two functions for reading data from record file or raw images
205 | def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
206 | rec_train = os.path.expanduser(rec_train)
207 | rec_train_idx = os.path.expanduser(rec_train_idx)
208 | rec_val = os.path.expanduser(rec_val)
209 | rec_val_idx = os.path.expanduser(rec_val_idx)
210 | jitter_param = 0.4
211 | lighting_param = 0.1
212 | input_size = opt.input_size
213 | crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
214 | resize = int(math.ceil(input_size / crop_ratio))
215 | mean_rgb = [123.68, 116.779, 103.939]
216 | std_rgb = [58.393, 57.12, 57.375]
217 |
218 | def batch_fn(batch, ctx):
219 | data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
220 | label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
221 | return data, label
222 |
223 | train_data = mx.io.ImageRecordIter(
224 | path_imgrec = rec_train,
225 | path_imgidx = rec_train_idx,
226 | preprocess_threads = num_workers,
227 | shuffle = True,
228 | batch_size = batch_size,
229 |
230 | data_shape = (3, input_size, input_size),
231 | mean_r = mean_rgb[0],
232 | mean_g = mean_rgb[1],
233 | mean_b = mean_rgb[2],
234 | std_r = std_rgb[0],
235 | std_g = std_rgb[1],
236 | std_b = std_rgb[2],
237 | rand_mirror = True,
238 | random_resized_crop = True,
239 | max_aspect_ratio = 4. / 3.,
240 | min_aspect_ratio = 3. / 4.,
241 | max_random_area = 1,
242 | min_random_area = 0.08,
243 | brightness = jitter_param,
244 | saturation = jitter_param,
245 | contrast = jitter_param,
246 | pca_noise = lighting_param,
247 | )
248 | val_data = mx.io.ImageRecordIter(
249 | path_imgrec = rec_val,
250 | path_imgidx = rec_val_idx,
251 | preprocess_threads = num_workers,
252 | shuffle = False,
253 | batch_size = batch_size,
254 |
255 | resize = resize,
256 | data_shape = (3, input_size, input_size),
257 | mean_r = mean_rgb[0],
258 | mean_g = mean_rgb[1],
259 | mean_b = mean_rgb[2],
260 | std_r = std_rgb[0],
261 | std_g = std_rgb[1],
262 | std_b = std_rgb[2],
263 | )
264 | return train_data, val_data, batch_fn
265 |
266 | def get_data_loader(data_dir, batch_size, num_workers):
267 | normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
268 | jitter_param = 0.4
269 | lighting_param = 0.1
270 | input_size = opt.input_size
271 | crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
272 | resize = int(math.ceil(input_size / crop_ratio))
273 |
274 | def batch_fn(batch, ctx):
275 | data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
276 | label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
277 | return data, label
278 |
279 | transform_train = transforms.Compose([
280 | transforms.RandomResizedCrop(input_size),
281 | transforms.RandomFlipLeftRight(),
282 | transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
283 | saturation=jitter_param),
284 | transforms.RandomLighting(lighting_param),
285 | transforms.ToTensor(),
286 | normalize
287 | ])
288 | transform_test = transforms.Compose([
289 | transforms.Resize(resize, keep_ratio=True),
290 | transforms.CenterCrop(input_size),
291 | transforms.ToTensor(),
292 | normalize
293 | ])
294 |
295 | train_data = gluon.data.DataLoader(
296 | imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
297 | batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
298 | val_data = gluon.data.DataLoader(
299 | imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
300 | batch_size=batch_size, shuffle=False, num_workers=num_workers)
301 |
302 | return train_data, val_data, batch_fn
303 |
304 | if opt.use_rec:
305 | train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
306 | opt.rec_val, opt.rec_val_idx,
307 | batch_size, num_workers)
308 | else:
309 | train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)
310 |
311 | if opt.mixup:
312 | train_metric = mx.metric.RMSE()
313 | else:
314 | train_metric = mx.metric.Accuracy()
315 | acc_top1 = mx.metric.Accuracy()
316 | acc_top5 = mx.metric.TopKAccuracy(5)
317 |
318 | save_frequency = opt.save_frequency
319 | if opt.save_dir and save_frequency:
320 | save_dir = opt.save_dir
321 | makedirs(save_dir)
322 | else:
323 | save_dir = ''
324 | save_frequency = 0
325 |
326 | def mixup_transform(label, classes, lam=1, eta=0.0):
327 | if isinstance(label, nd.NDArray):
328 | label = [label]
329 | res = []
330 | for l in label:
331 | y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
332 | y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
333 | res.append(lam*y1 + (1-lam)*y2)
334 | return res
335 |
336 | def smooth(label, classes, eta=0.1):
337 | if isinstance(label, nd.NDArray):
338 | label = [label]
339 | smoothed = []
340 | for l in label:
341 | res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
342 | smoothed.append(res)
343 | return smoothed
344 |
345 | def test(ctx, val_data):
346 | if opt.use_rec:
347 | val_data.reset()
348 | acc_top1.reset()
349 | acc_top5.reset()
350 | for i, batch in enumerate(val_data):
351 | data, label = batch_fn(batch, ctx)
352 | outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
353 | acc_top1.update(label, outputs)
354 | acc_top5.update(label, outputs)
355 |
356 | _, top1 = acc_top1.get()
357 | _, top5 = acc_top5.get()
358 | return (1-top1, 1-top5)
359 |
360 | def train(ctx):
361 | if isinstance(ctx, mx.Context):
362 | ctx = [ctx]
363 | if opt.resume_params == '':
364 | net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
365 |
366 | if opt.no_wd:
367 | for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
368 | v.wd_mult = 0.0
369 |
370 | trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
371 | if opt.resume_states != '':
372 | trainer.load_states(opt.resume_states)
373 |
374 | if opt.label_smoothing or opt.mixup:
375 | sparse_label_loss = False
376 | else:
377 | sparse_label_loss = True
378 | if distillation:
379 | L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
380 | hard_weight=opt.hard_weight,
381 | sparse_label=sparse_label_loss)
382 | else:
383 | L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
384 |
385 | best_val_score = 1
386 |
387 | for epoch in range(opt.resume_epoch, opt.num_epochs):
388 | tic = time.time()
389 | if opt.use_rec:
390 | train_data.reset()
391 | train_metric.reset()
392 | btic = time.time()
393 |
394 | for i, batch in enumerate(train_data):
395 | data, label = batch_fn(batch, ctx)
396 |
397 | if opt.mixup:
398 | lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
399 | if epoch >= opt.num_epochs - opt.mixup_off_epoch:
400 | lam = 1
401 | data = [lam*X + (1-lam)*X[::-1] for X in data]
402 |
403 | if opt.label_smoothing:
404 | eta = 0.1
405 | else:
406 | eta = 0.0
407 | label = mixup_transform(label, classes, lam, eta)
408 |
409 | elif opt.label_smoothing:
410 | hard_label = label
411 | label = smooth(label, classes)
412 |
413 | if distillation:
414 | teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
415 | for X in data]
416 |
417 | with ag.record():
418 | outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
419 | if distillation:
420 | loss = [L(yhat.astype('float32', copy=False),
421 | y.astype('float32', copy=False),
422 | p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
423 | else:
424 | loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
425 | for l in loss:
426 | l.backward()
427 | trainer.step(batch_size)
428 |
429 | if opt.mixup:
430 | output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
431 | for out in outputs]
432 | train_metric.update(label, output_softmax)
433 | else:
434 | if opt.label_smoothing:
435 | train_metric.update(hard_label, outputs)
436 | else:
437 | train_metric.update(label, outputs)
438 |
439 | if opt.log_interval and not (i+1)%opt.log_interval:
440 | train_metric_name, train_metric_score = train_metric.get()
441 | logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
442 | epoch, i, batch_size*opt.log_interval/(time.time()-btic),
443 | train_metric_name, train_metric_score, trainer.learning_rate))
444 | btic = time.time()
445 |
446 | train_metric_name, train_metric_score = train_metric.get()
447 | throughput = int(batch_size * (i + 1) / (time.time() - tic))
448 |
449 | err_top1_val, err_top5_val = test(ctx, val_data)
450 |
451 | logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
452 | logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
453 | logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))
454 |
455 | if save_dir and err_top1_val < best_val_score:
456 | best_val_score = err_top1_val
457 | net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
458 | trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))
459 |
460 | if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
461 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
462 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
463 | if save_frequency and save_dir:
464 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
465 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))
466 |
467 | if opt.mode == 'hybrid':
468 | net.hybridize(static_alloc=True, static_shape=True)
469 | if distillation:
470 | teacher.hybridize(static_alloc=True, static_shape=True)
471 | train(context)
472 |
473 | if __name__ == '__main__':
474 | main()
475 |
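`smooth()` and `mixup_transform()` in train_final.py build soft targets. The true class gets the on-value `1 - eta + eta/classes` and every other class gets the off-value `eta/classes`. Mixup then blends each target with the target of the sample at the mirrored batch position. Below is a minimal pure-Python sketch of the same arithmetic. It assumes labels are plain integer lists rather than per-device `mx.nd.NDArray`s. The names `one_hot` and `mixup_targets` are hypothetical.

```python
# Pure-Python sketch of the soft targets built by smooth() and mixup_transform().
# Labels are plain integer lists here instead of mx.nd.NDArray.


def one_hot(label, classes, eta=0.0):
    """Label-smoothed one-hot row: the true class gets 1 - eta + eta/classes."""
    off = eta / classes
    on = 1.0 - eta + off
    return [on if k == label else off for k in range(classes)]


def smooth(labels, classes, eta=0.1):
    """Smoothed one-hot rows for a batch of integer labels."""
    return [one_hot(label, classes, eta) for label in labels]


def mixup_targets(labels, classes, lam, eta=0.0):
    """Blend each target with the target at the mirrored batch position,
    matching how the images are mixed with X[::-1]."""
    y1 = [one_hot(label, classes, eta) for label in labels]
    y2 = [one_hot(label, classes, eta) for label in labels[::-1]]
    return [[lam * a + (1 - lam) * b for a, b in zip(r1, r2)]
            for r1, r2 in zip(y1, y2)]


if __name__ == "__main__":
    print(smooth([2], classes=4))
    print(mixup_targets([0, 1], classes=2, lam=0.7))
```

Each row sums to 1, since `(1 - eta + eta/classes) + (classes - 1) * eta/classes = 1`. This keeps the dense targets valid distributions for `SoftmaxCrossEntropyLoss(sparse_label=False)`. During the last `--mixup-off-epoch` epochs the script sets `lam = 1`, and the blend then reduces to the plain (smoothed) one-hot rows.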
--------------------------------------------------------------------------------
/train_nas.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import sys
5 | import time
6 | import logging
7 | import warnings
8 | from collections import defaultdict
9 | from glob import glob
10 |
11 | import torch
12 | import numpy as np
13 | import pickle
14 |
15 | from torch import nn
16 | from torch.utils import data
17 | from viterbi import maxsum, sumprod_log, complete, score
18 | from torchvision import datasets
19 | from torchvision import transforms
20 | from types import SimpleNamespace
21 |
22 | import latency
23 | import misc
24 | from model import SlimMobilenet
25 |
26 |
27 | logger = logging.getLogger(__name__)
28 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
29 |
30 |
31 | parser = argparse.ArgumentParser(description='Train mobilenet slimmable/AOWS model')
32 | parser.add_argument('--data', metavar='DIR', default='/imagenet',
33 | help='path to dataset')
34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
35 | help='number of data loading workers (default: 4)')
36 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
37 | help='number of total epochs to run')
38 | parser.add_argument('-b', '--batch-size', default=512, type=int,
39 | metavar='N', help='mini-batch size (default: 512)')
40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
41 | metavar='LR', help='initial learning rate')
42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
43 | help='momentum')
44 | parser.add_argument('--no-cuda', action="store_true")
45 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
46 | metavar='W', help='weight decay (default: 1e-5)')
47 | parser.add_argument('--print-freq', '-p', default=10, type=int,
48 | metavar='N', help='print frequency (default: 10)')
49 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
50 | help='path to latest checkpoint (default: none)')
51 | parser.add_argument('--resume-last', action='store_true',
52 | help='resume to last checkpoint if found (supersedes resume)')
53 | parser.add_argument("--debug", nargs="*", choices=["mini"], default=[],
54 |                     help="debug modes; 'mini' stops each epoch after about 20 iterations to check that everything works.")
55 | parser.add_argument('--max-width', type=float, default=1.5)
56 | parser.add_argument('--min-width', type=float, default=0.2)
57 | parser.add_argument('--levels', type=int, default=14)
58 | parser.add_argument('--latency-target', type=float, default=0.04, help="latency target used as the OWS constraint (same units as the latency model)")
59 | parser.add_argument('--window', type=int, default=100000,
60 | help="size of window over which the moving average of losses is computed in OWS and AOWS.")
61 | parser.add_argument('--latency-model', type=str, default='output/model_trt16_K100.0.jsonl', help="latency model")
62 | parser.add_argument('--gamma-iter', type=int, default=12, help="number of bisection steps (one Viterbi solve each) used to set gamma.")
63 | parser.add_argument('--AOWS', action="store_true", help="use AOWS")
64 | parser.add_argument('--AOWS-warmup', type=int, default=5, help="AOWS warmup epochs")
65 | parser.add_argument('--AOWS-min-temp', type=float, default=0.0005, help="minimum (final) temperature")
66 | parser.add_argument('--expname', default='output/nas_output')
67 |
68 |
69 | def main():
70 | args = parser.parse_args()
71 |
72 | logger.info("=> creating model")
73 | model = SlimMobilenet(min_width=args.min_width, max_width=args.max_width, levels=args.levels)
74 | logger.info(model)
75 | if not args.no_cuda:
76 | model = nn.DataParallel(model).cuda()
77 |
78 | criterion = nn.CrossEntropyLoss(reduction='none')
79 | if not args.no_cuda:
80 | criterion = criterion.cuda()
81 |
82 | args.lr = args.lr * args.batch_size / 256
83 | logger.info("learning rate scaling: using lr={}".format(args.lr))
84 |
85 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
86 | momentum=args.momentum,
87 | weight_decay=args.weight_decay)
88 |
89 | torch.backends.cudnn.benchmark = True
90 |
91 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
92 | augment = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), ]
93 |
94 | dataset_train = datasets.ImageFolder(os.path.join(args.data, 'train'),
95 | transforms.Compose(augment + [transforms.ToTensor(), normalize]))
96 | train_loader = data.DataLoader(dataset_train,
97 |                                    batch_size=args.batch_size, shuffle=True,
98 | num_workers=args.workers, pin_memory=True)
99 |
100 | start_epoch = 0
101 | ows_state = SimpleNamespace()
102 |
103 | filters = model.filters if hasattr(model, 'filters') else model.module.filters
104 | ows_state.histories = [{c: misc.MovingAverageMeter(args.window) for c in F.configurations} for F in filters]
105 | ows_state.latency = dict(latency.load_model(args.latency_model))
106 |
107 | if args.resume_last:
108 | avail = glob(osp.join(args.expname, 'checkpoint*.pth'))
109 | avail = [(int(f[-len('.pth') - 3:-len('.pth')]), f) for f in avail]
110 | avail = sorted(avail)
111 | if avail:
112 | args.resume = avail[-1][1]
113 | if args.resume:
114 | if os.path.isfile(args.resume):
115 | logger.info("=> loading checkpoint '{}'".format(args.resume))
116 |             checkpoint = torch.load(args.resume, map_location='cpu' if args.no_cuda else None)
117 | start_epoch = checkpoint['epoch']
118 | state_dict = checkpoint['state_dict']
119 | key = next(iter(state_dict.keys()))
120 | if key.startswith('module.') and args.no_cuda:
121 | state_dict = {k[len('module.'):]: v for (k, v) in state_dict.items()}
122 | model.load_state_dict(state_dict)
123 | optimizer.load_state_dict(checkpoint['optimizer'])
124 | if 'ows_state' in checkpoint:
125 | ows_state.histories = checkpoint['ows_state'].histories
126 | logger.info("=> loaded checkpoint '{}' (epoch {})"
127 | .format(args.resume, checkpoint['epoch']))
128 | else:
129 | logger.info(f"=> no checkpoint found at '{args.resume}'")
130 | args.resume = ''
131 |
132 | if args.resume:
133 |         # Solve OWS once first, so that the last epoch can be re-evaluated with a different latency target
134 | best_path, _, _, _, timing = solve_ows(model, start_epoch, len(train_loader), -1, ows_state, args, eval_only=True)
135 | logger.info('Evaluation from resumed checkpoint...')
136 | best_path_str = (f"Best configuration: {best_path}, "
137 | f"predicted latency: {timing}")
138 | logger.info(best_path_str)
139 |
140 | for epoch in range(start_epoch, args.epochs):
141 | history = train(train_loader, model, criterion, optimizer, epoch, ows_state, args)
142 |
143 |         logger.info(f"=> saving decision history for epoch {epoch + 1}")
144 | decision_target = 'decision{:03d}.pkl'.format(epoch + 1)
145 | if args.expname:
146 | os.makedirs(args.expname, exist_ok=True)
147 | decision_target = osp.join(args.expname, decision_target)
148 | with open(decision_target, 'wb') as f:
149 | pickle.dump(history, f, protocol=4)
150 |
151 | logger.info(f"=> saving checkpoint for epoch {epoch + 1}")
152 |
153 | current_state = {
154 | 'epoch': epoch + 1,
155 | 'state_dict': model.state_dict(),
156 | 'optimizer': optimizer.state_dict(),
157 | }
158 |
159 | current_state['ows_state'] = ows_state
160 | best_path_str = (f"Best configuration: {history['OWS'][-1]['best_path']}, "
161 | f"predicted latency: {history['OWS'][-1]['pred_latency']}")
162 | logger.info(best_path_str)
163 | with open(osp.join(args.expname, f"ows_result_{epoch + 1:03d}.txt"), 'w') as f:
164 | f.write(best_path_str + '\n')
165 | filename = save_checkpoint(current_state, args.expname)
166 | logger.info(f"checkpoint saved to {filename}.")
167 |
168 |
169 |
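With `--resume-last`, `main()` globs `checkpoint*.pth` in the experiment directory and recovers each epoch by slicing the three digits that `save_checkpoint` writes (`checkpoint{epoch:03d}.pth`), then resumes from the highest one. A minimal standalone sketch of that selection; `latest_checkpoint` is a hypothetical helper, and it parses the number with a regular expression instead of a fixed-width slice:

```python
import os.path as osp
import re
from typing import Iterable, Optional

# Hypothetical helper mirroring the --resume-last branch of main().
# Assumes names of the form checkpoint<epoch>.pth, as written by save_checkpoint().
_CKPT_RE = re.compile(r"checkpoint(\d+)\.pth$")


def latest_checkpoint(paths: Iterable[str]) -> Optional[str]:
    """Return the path with the highest epoch number, or None if nothing matches."""
    best = None
    for path in paths:
        match = _CKPT_RE.search(osp.basename(path))
        if match is None:
            continue  # skip files that do not follow the naming scheme
        epoch = int(match.group(1))
        if best is None or epoch > best[0]:
            best = (epoch, path)
    return None if best is None else best[1]


print(latest_checkpoint(["out/checkpoint002.pth", "out/checkpoint010.pth", "out/checkpoint009.pth"]))
# out/checkpoint010.pth
```

The regular expression keeps working past epoch 999, where the fixed slice in `main()` would read `checkpoint1000.pth` as epoch 0.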
170 | def train(train_loader, model, criterion, optimizer, epoch, ows_state, args):
171 | meters = defaultdict(misc.AverageMeter)
172 |
173 | model.train()
174 |
175 | filters = model.filters if hasattr(model, 'filters') else model.module.filters
176 | history = defaultdict(list)
177 |
178 | end = time.time()
179 | for iteration, (input, target) in enumerate(train_loader):
180 | if "mini" in args.debug and iteration > 20: break
181 |
182 | best_path, temperature, gamma_max, best_perf, timing = solve_ows(
183 | model, epoch, len(train_loader), iteration, ows_state, args)
184 |
185 | # measure data loading time
186 | meters["data_time"].update(time.time() - end)
187 |
188 | if not args.no_cuda:
189 | target = target.cuda(non_blocking=True)
190 |
191 | compute_results = misc.SWDefaultDict(misc.SWDict)
192 |
193 | minconf = [F.configurations[0] for F in filters]
194 | maxconf = [F.configurations[-1] for F in filters]
195 |
196 | optimizer.zero_grad()
197 |
198 | # sandwich rule: train maximum configuration
199 | outp = model(input, configuration=maxconf)
200 | loss = criterion(outp['x'], target)
201 | loss.mean().backward()
202 | compute_results['max']['x'] = outp['x'].detach()
203 | compute_results['max']['loss_numpy'] = loss.detach().cpu().numpy()
204 | compute_results['max']['prob'] = torch.nn.functional.softmax(compute_results['max']['x'], dim=1)
205 |
206 | # sandwich rule: train minimum and random configuration with self-distillation
207 | for kind in ('min', 'rand'):
208 | conf = None if kind == 'rand' else minconf
209 | outp = model(input, configuration=conf)
210 |
211 | loss = misc.soft_cross_entropy(outp['x'], compute_results['max']['prob'].detach())
212 | compute_results[kind]['soft_loss_numpy'] = loss.detach().cpu().numpy()
213 | with torch.no_grad():
214 | hard_loss_numpy = criterion(outp['x'], target).detach().cpu().numpy()
215 | compute_results[kind]['loss_numpy'] = hard_loss_numpy
216 |
217 | compute_results[kind]['x'] = outp['x'].detach()
218 | if kind == 'rand':
219 | compute_results['rand']['decision'] = outp['decision'].cpu().numpy()
220 | loss.mean().backward()
221 |
222 | for path, image_loss, image_refloss in zip(compute_results['rand']['decision'],
223 | compute_results['rand']['loss_numpy'],
224 | compute_results['max']['loss_numpy']):
225 | for i, pi in enumerate(path):
226 | ows_state.histories[i][pi].update(-(image_loss - image_refloss) / len(path), epoch, iteration)
227 |
228 | for refname in ('min', 'max', 'rand'):
229 |             meters['loss_' + refname].update(compute_results[refname]['loss_numpy'].mean(), input.size(0))
230 | refloss = compute_results[refname]['loss_numpy']
231 | (prec1, prec5), refcorrect_ks = misc.accuracy(compute_results[refname]['x'].data,
232 | target, topk=(1, 5), return_correct_k=True)
233 | refcorrect1, refcorrect5 = [a.cpu().numpy().astype(bool) for a in refcorrect_ks]
234 | history['loss_' + refname].append(refloss)
235 | history['top1_' + refname].append(refcorrect1)
236 | history['top5_' + refname].append(refcorrect5)
237 | meters['top1_' + refname].update(prec1.item(), input.size(0))
238 | meters['top5_' + refname].update(prec5.item(), input.size(0))
239 | if 'soft_loss_numpy' in compute_results[refname]:
240 |                 meters['loss_soft_' + refname].update(compute_results[refname]['soft_loss_numpy'].mean(), input.size(0))
241 | history['loss_soft_' + refname].append(compute_results[refname]['soft_loss_numpy'])
242 |
243 | history['configuration'].append(compute_results['rand']['decision'])
244 |         history['configuration_loss'].append(compute_results['rand']['loss_numpy'])
245 |
246 | optimizer.step()
247 |
248 | # measure elapsed time
249 | meters["batch_time"].update(time.time() - end)
250 | end = time.time()
251 |
252 | if iteration % args.print_freq == 0:
253 | toprint = f"Epoch: [{epoch}][{iteration}/{len(train_loader)}]\t"
254 | toprint += ('Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
255 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
256 | 'Prec@1 {top1_rand.val:.3f} ({top1_rand.avg:.3f})\t'
257 | 'Prec@5 {top5_rand.val:.3f} ({top5_rand.avg:.3f})\t'.format(**meters))
258 |
259 | for key, meter in meters.items():
260 | if key.startswith('loss'):
261 | toprint += f'{key} {meter.val:.4f} ({meter.avg:.4f})\t'
262 | logger.info(toprint)
263 |
264 | # prints a string summarizing the sampling probabilities for each filter
265 | probas_str = ""
266 | for i, F in enumerate(filters):
267 | if F.probability is not None:
268 | probas_str += '|{} '.format(i)
269 | for p in F.probability:
270 | probas_str += str(int(100 * p)) + ' '
271 | probas_log = None
272 | if any(F.probability is not None for F in filters):
273 | probas_log = tuple(F.probability for F in filters),
274 | history['OWS'].append(dict(best_path=best_path, temperature=temperature, gamma_max=gamma_max,
275 | best_pref=best_perf, pred_latency=timing, probas_log=probas_log))
276 | if probas_str:
277 | probas_str = '\n' + probas_str
278 | ows_str = f"predicted latency: {timing}, perf: {best_perf}, T: {temperature}, gamma: {gamma_max}"
279 | logger.info('best_path: ' + ','.join(map(str, best_path)) + ows_str + probas_str)
280 |
281 |
282 | return history
283 |
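For the 'min' and 'rand' passes of the sandwich rule, `train()` uses the detached softmax of the maximum network (`compute_results['max']['prob']`) as the target instead of the labels, a form of in-place self-distillation. `misc.soft_cross_entropy` is not shown in this section; a NumPy sketch, assuming it is the usual per-sample soft-target cross-entropy -sum_k p_k log softmax(z)_k (per-sample because `train()` stores the result per image and calls `.mean()` on it):

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def soft_cross_entropy(student_logits: np.ndarray, teacher_probs: np.ndarray) -> np.ndarray:
    """Per-sample cross-entropy of the student's predictions against soft teacher targets."""
    return -(teacher_probs * log_softmax(student_logits)).sum(axis=1)


student = np.array([[2.0, 0.0], [0.0, 0.0]])
teacher = np.array([[1.0, 0.0], [0.5, 0.5]])
print(soft_cross_entropy(student, teacher))  # approx [0.1269, 0.6931]
```

With a one-hot teacher this reduces to the ordinary hard-label cross-entropy, which is why the same `loss_numpy` bookkeeping works for both kinds of pass.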
284 |
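After the forward passes, `train()` turns per-image losses into OWS unaries: for each image, the gap between the loss of the randomly sampled configuration and the loss of the maximum configuration is negated, split evenly across the layers, and added to the running average of the width that each layer used. A sketch with plain Python containers; `WindowedMean` is a hypothetical stand-in that assumes `misc.MovingAverageMeter(window)` averages its most recent `window` updates (its definition is not shown here, and its extra `epoch, iteration` arguments are dropped):

```python
from collections import deque
from typing import Dict, List, Sequence


class WindowedMean:
    """Hypothetical stand-in for misc.MovingAverageMeter: mean of the last `window` updates."""

    def __init__(self, window: int):
        self.values = deque(maxlen=window)

    def update(self, value: float) -> None:
        self.values.append(value)

    @property
    def avg(self) -> float:
        return sum(self.values) / len(self.values) if self.values else 0.0


def credit_layers(histories: List[Dict[float, WindowedMean]],
                  decisions: Sequence[Sequence[float]],
                  losses: Sequence[float],
                  ref_losses: Sequence[float]) -> None:
    """Split each image's loss gap to the maximum network evenly across its layers."""
    for path, loss, ref_loss in zip(decisions, losses, ref_losses):
        # higher share = the sampled sub-network did better relative to the max network
        share = -(loss - ref_loss) / len(path)
        for layer, width in enumerate(path):
            histories[layer][width].update(share)


widths = (0.5, 1.0)
histories = [{w: WindowedMean(window=100) for w in widths} for _ in range(2)]
credit_layers(histories, decisions=[(0.5, 1.0), (1.0, 1.0)], losses=[2.0, 1.2], ref_losses=[1.0, 1.0])
print([[round(m.avg, 3) for m in h.values()] for h in histories])  # [[-0.5, -0.1], [0.0, -0.3]]
```

In `solve_ows`, these averages (`M.avg`) become the unary potentials of the chain, so widths whose sampled networks stay close to the maximum network accumulate higher scores.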
285 | def aows_temp(epoch, epoch_len, iteration, args):
286 | schedule = [(0, 1.0), (args.AOWS_warmup, 1.0),
287 | (args.AOWS_warmup + 1, 0.01),
288 | (10, 0.001), (args.epochs, args.AOWS_min_temp)]
289 | cur_phase = 0
290 | for iphase, (phase, _) in enumerate(schedule):
291 | if epoch >= phase:
292 | cur_phase = iphase
293 | phase, start_temp = schedule[cur_phase]
294 | if cur_phase == len(schedule) - 1:
295 | return start_temp
296 | end_phase, end_temp = schedule[cur_phase + 1]
297 | max_iter = epoch_len * (end_phase - phase)
298 | cur_iter = epoch_len * (epoch - phase) + iteration
299 | ratio = cur_iter / max_iter
300 | log_T = (1.0 - ratio) * np.log10(start_temp) + ratio * np.log10(end_temp)
301 | return 10 ** log_T
302 |
303 |
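`aows_temp` is a piecewise schedule over (epoch, temperature) knots that interpolates linearly in log10(T), so within each phase the temperature shrinks by a constant factor per iteration. A standalone sketch with the knots passed explicitly instead of read from `args`; `piecewise_log_temp` is a hypothetical name, and the example knots are the ones the code builds from the CLI defaults (`--AOWS-warmup 5`, `--epochs 20`, `--AOWS-min-temp 0.0005`):

```python
import math
from typing import List, Tuple


def piecewise_log_temp(epoch: int, epoch_len: int, iteration: int,
                       schedule: List[Tuple[int, float]]) -> float:
    """Temperature interpolated linearly in log10 space between (epoch, T) knots."""
    cur = 0
    for i, (start_epoch, _) in enumerate(schedule):
        if epoch >= start_epoch:
            cur = i
    phase, start_temp = schedule[cur]
    if cur == len(schedule) - 1:
        return start_temp  # past the last knot: hold the final temperature
    end_phase, end_temp = schedule[cur + 1]
    ratio = (epoch_len * (epoch - phase) + iteration) / (epoch_len * (end_phase - phase))
    log_t = (1.0 - ratio) * math.log10(start_temp) + ratio * math.log10(end_temp)
    return 10 ** log_t


default_schedule = [(0, 1.0), (5, 1.0), (6, 0.01), (10, 0.001), (20, 0.0005)]
for epoch in (5, 6, 8, 10, 20):
    print(epoch, piecewise_log_temp(epoch, epoch_len=100, iteration=0, schedule=default_schedule))
```

The knot at epoch 10 is hard-coded in `aows_temp`, so the schedule appears to assume `--AOWS-warmup` + 1 < 10; with a longer warmup the knots are no longer in epoch order.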
304 | def solve_ows(model, epoch, len_epoch, iteration, ows_state, args, eval_only=False):
305 | """
306 |     Solves the OWS problem (latency-constrained max-sum over the chain of filter widths) and, when AOWS is active, sets each filter's sampling probabilities.
307 | """
308 | if hasattr(model, 'module'): model = model.module
309 |
310 | unaries = [[0.0]] + [[M.avg for M in C.values()] for C in ows_state.histories] + [[0.0]]
311 |
312 | if not hasattr(ows_state, 'pairwise'):
313 | pairwise = []
314 |
315 | possible_in_channels = [3]
316 | possible_outputs = iter([F.configurations for F in model.filters] + [[1000]])
317 | for L in model.components:
318 | possible_out = next(possible_outputs)
319 | pair = np.zeros((len(possible_in_channels), len(possible_out)))
320 | for incoming, p in enumerate(possible_in_channels):
321 | for outgoing, l in enumerate(possible_out):
322 | var = latency.Vartype(**L._asdict(), in_channels=p, out_channels=l)
323 | pair[incoming, outgoing] = ows_state.latency[var]
324 | pairwise.append(pair)
325 | possible_in_channels = possible_out
326 | ows_state.pairwise = pairwise
327 |
328 | unaries, pairwise, states = complete(unaries, ows_state.pairwise)
329 |
330 | def solve(gamma):
331 | _, ipath = maxsum(unaries, -gamma * pairwise, states)
332 | perf, timing = score(ipath, unaries, pairwise, detail=True)
333 | return ipath, perf, timing
334 |
335 | gamma_min = 0.0
336 | gamma_max = 10.0
337 | timing_max = solve(gamma_max)[2]
338 |
339 | expanding_iterations = 0
340 | while timing_max > args.latency_target:
341 | expanding_iterations += 1
342 | if expanding_iterations > 2:
343 |             logger.warning("Too many expanding loops for gamma; increase the initial gamma_max in solve_ows or check that the latency target is reachable")
344 | gamma_max *= 2
345 | timing_max = solve(gamma_max)[2]
346 |
347 | for _ in range(args.gamma_iter):
348 | mid_gamma = 0.5 * (gamma_min + gamma_max)
349 | timing_middle = solve(mid_gamma)[2]
350 | if timing_middle > args.latency_target:
351 | gamma_min = mid_gamma
352 | else:
353 | gamma_max = mid_gamma
354 | ipath, perf, timing = solve(gamma_max)
355 |
356 | T = np.inf
357 | if args.AOWS and epoch >= args.AOWS_warmup and not eval_only:
358 | T = aows_temp(epoch, len_epoch, iteration, args)
359 | marginals = sumprod_log(unaries / T, -gamma_max * pairwise / T, states)
360 | assert marginals.shape[0] == len(model.filters) + 2, "{} {}".format(marginals.shape[0], len(model.filters))
361 | for F, marginal in zip(model.filters, marginals[1:-1]):
362 | F.probability = marginal[:len(F.configurations)]
363 |
364 | best_path = tuple(F.configurations[i] for (F, i) in zip(model.filters, ipath[1:-1]))
365 | return best_path, T, gamma_max, perf, timing
366 |
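On its first call, `solve_ows` caches one latency matrix per network component: rows index the possible input widths (starting from the 3 RGB channels), columns the possible output widths (ending with the 1000 ImageNet classes), and each entry is looked up in the latency model loaded by `latency.load_model`. A sketch of that construction, assuming a plain dictionary keyed by a hypothetical `(layer_name, in_channels, out_channels)` tuple in place of `latency.Vartype`, whose fields are not shown here:

```python
from typing import Dict, List, Sequence, Tuple

import numpy as np

# Hypothetical lookup: (layer name, in_channels, out_channels) -> predicted latency.
LatencyTable = Dict[Tuple[str, int, int], float]


def pairwise_latencies(layers: Sequence[str], widths: Sequence[Sequence[int]],
                       table: LatencyTable, in_channels: int = 3,
                       num_classes: int = 1000) -> List[np.ndarray]:
    """One (len(possible inputs), len(possible outputs)) latency matrix per component."""
    outputs = [list(w) for w in widths] + [[num_classes]]  # last component: classifier
    if len(outputs) != len(layers):
        raise ValueError("expected one more component than searchable filters")
    possible_in = [in_channels]
    result = []
    for name, possible_out in zip(layers, outputs):
        pair = np.zeros((len(possible_in), len(possible_out)))
        for a, c_in in enumerate(possible_in):
            for b, c_out in enumerate(possible_out):
                pair[a, b] = table[(name, c_in, c_out)]
        result.append(pair)
        possible_in = possible_out  # this layer's outputs feed the next layer
    return result


table = {("conv1", 3, 16): 1.0, ("conv1", 3, 32): 2.0,
         ("fc", 16, 1000): 0.5, ("fc", 32, 1000): 0.75}
pw = pairwise_latencies(["conv1", "fc"], [[16, 32]], table)
print([p.shape for p in pw])  # [(1, 2), (2, 1)]
```

Because latency depends on both the incoming and outgoing width, it naturally sits on the edges of the chain, while the accuracy proxies sit on the nodes.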
367 |
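`solve_ows` turns the constrained problem (maximize the summed unaries subject to predicted latency <= `--latency-target`) into the penalized objective sum(unary) - gamma * sum(latency), solved with `maxsum`, and picks gamma by bisection: the upper bracket is doubled until its path is fast enough, `--gamma-iter` halvings then shrink the bracket, and the final path is taken at the feasible end `gamma_max`. A sketch on a toy chain, with brute-force enumeration standing in for Viterbi; the helper names are hypothetical, and unlike the original loop this version stops expanding when the target looks unreachable:

```python
import itertools
from typing import Sequence, Tuple

import numpy as np


def best_path(unary: Sequence[np.ndarray], latency: Sequence[np.ndarray],
              gamma: float) -> Tuple[Tuple[int, ...], float, float]:
    """Brute-force argmax of sum(unary) - gamma * sum(latency); returns (path, perf, latency)."""
    best = None
    for path in itertools.product(*[range(len(u)) for u in unary]):
        perf = sum(u[p] for u, p in zip(unary, path))
        lat = sum(latency[i][path[i], path[i + 1]] for i in range(len(path) - 1))
        objective = perf - gamma * lat
        if best is None or objective > best[0]:
            best = (objective, path, perf, lat)
    return best[1], best[2], best[3]


def bisect_gamma(unary, latency, target: float, gamma_max: float = 10.0, iters: int = 12):
    """Bisect the latency penalty gamma; the returned gamma always meets the target."""
    gamma_min = 0.0
    while best_path(unary, latency, gamma_max)[2] > target:
        gamma_max *= 2  # expand until the upper bracket is feasible
        if gamma_max > 1e12:
            raise ValueError("latency target looks unreachable")
    for _ in range(iters):
        mid = 0.5 * (gamma_min + gamma_max)
        if best_path(unary, latency, mid)[2] > target:
            gamma_min = mid  # too slow: penalize latency more
        else:
            gamma_max = mid  # feasible: try a smaller penalty
    return gamma_max, best_path(unary, latency, gamma_max)


# Toy chain: dummy input node, one layer with two widths (narrow, wide), dummy output node.
unary = [np.array([0.0]), np.array([-1.0, 0.0]), np.array([0.0])]
latency = [np.array([[1.0, 3.0]]), np.array([[1.0], [3.0]])]
gamma, (path, perf, lat) = bisect_gamma(unary, latency, target=4.0)
print(path, perf, lat)  # (0, 0, 0) -1.0 2.0
```

Bisection is valid because the latency of the optimal path never increases as gamma grows. The penalty can only reach paths on the convex hull of the (latency, score) trade-off, so the chosen path may sit strictly below the target.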
368 | def save_checkpoint(state, expname=''):
369 | filename = f"checkpoint{state['epoch']:03d}.pth"
370 | if expname:
371 | os.makedirs(expname, exist_ok=True)
372 | filename = osp.join(expname, filename)
373 | torch.save(state, filename)
374 | return filename
375 |
376 |
377 | if __name__ == '__main__':
378 | logger = logging.getLogger(__file__)
379 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG,
380 | format='%(name)s: %(message)s')
381 | main()
382 |
--------------------------------------------------------------------------------
/viterbi.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | from numba import jit
4 | import logging
5 | logging.getLogger('numba').setLevel(logging.WARNING)
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def complete(unary, pairwise, fill=-np.inf):
12 | """
13 |     Pad ragged per-node unary lists and per-edge pairwise matrices into dense arrays (unused entries set to fill); also return the number of states per node.
14 | """
15 | N = len(unary)
16 | states = np.array(list(map(len, unary)))
17 | max_S = max(states)
18 | un = np.full((N, max_S), fill)
19 | pair = np.full((N - 1, max_S, max_S), fill)
20 | for i in range(N):
21 | un[i, :len(unary[i])] = unary[i]
22 | if i < N - 1:
23 | pair[i, :len(unary[i]), :len(unary[i + 1])] = pairwise[i]
24 |
25 | return un, pair, states
26 |
27 |
28 | @jit(nopython=True)
29 | def maxsum(unary, pairwise, states):
30 | N, S = unary.shape
31 | partial = unary[0]
32 | selected = np.zeros((N - 1, S), np.int64)
33 | for s in range(N - 1):
34 | new_partial = np.full((S,), -np.inf)
35 | for j in range(states[s + 1]):
36 | best_ = -np.inf
37 | best_i = 0
38 | for i in range(states[s]):
39 | candidate = partial[i] + pairwise[s, i, j]
40 | if candidate > best_:
41 | best_ = candidate
42 | best_i = i
43 | selected[s, j] = best_i
44 | new_partial[j] = unary[s + 1, j] + best_
45 | partial = new_partial
46 |
47 | path = np.zeros((N,), np.int64)
48 | score = -np.inf
49 | best_j = 0
50 | for j in range(states[N - 1]):
51 | candidate = partial[j]
52 | if candidate > score:
53 | score = candidate
54 | best_j = j
55 | path[N - 1] = best_j
56 | for i in range(N - 2, -1, -1):
57 | best_j = selected[i, best_j]
58 | path[i] = best_j
59 | return score, path
60 |
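The numba `maxsum` above is a forward max-sum pass with back-pointers followed by a backtrace. Its inner loops stop at `states[s]`, so the padding written by `complete()` is never read; that matters because `solve_ows` passes `-gamma * pairwise`, which turns the `-inf` padding into `+inf`. A vectorized NumPy sketch (not the repository's code) that keeps the same guarantee with explicit masks, checked against exhaustive enumeration in the spirit of `maxsum_brute`:

```python
import itertools

import numpy as np


def maxsum_masked(unary: np.ndarray, pairwise: np.ndarray, states: np.ndarray):
    """Vectorized max-sum (Viterbi) on padded arrays; padded entries are masked, never read."""
    n, s_max = unary.shape
    valid = np.arange(s_max)[None, :] < np.asarray(states)[:, None]  # valid[i, k]: k < states[i]
    u = np.where(valid, unary, -np.inf)
    partial = u[0]
    back = np.zeros((n - 1, s_max), dtype=np.int64)
    for s in range(n - 1):
        edge = valid[s][:, None] & valid[s + 1][None, :]
        # cand[i, j]: best score of a prefix ending in state i, extended to state j
        cand = partial[:, None] + np.where(edge, pairwise[s], -np.inf)
        back[s] = cand.argmax(axis=0)
        partial = u[s + 1] + cand.max(axis=0)
    path = np.zeros(n, dtype=np.int64)
    path[-1] = int(partial.argmax())
    for s in range(n - 2, -1, -1):
        path[s] = back[s, path[s + 1]]  # follow back-pointers from the last node
    return float(partial.max()), path


def maxsum_exhaustive(unary, pairwise, states):
    """Reference answer by enumerating every path (exponential; for checking only)."""
    def total(p):
        return (sum(unary[i, k] for i, k in enumerate(p))
                + sum(pairwise[i, p[i], p[i + 1]] for i in range(len(p) - 1)))
    best = max(itertools.product(*[range(k) for k in states]), key=total)
    return float(total(best)), np.array(best)


rng = np.random.default_rng(0)
states = np.array([1, 3, 2, 1])
n, s_max = len(states), int(states.max())
unary = np.full((n, s_max), np.inf)  # padding deliberately set to +inf
pairwise = np.full((n - 1, s_max, s_max), np.inf)
for i, k in enumerate(states):
    unary[i, :k] = rng.normal(size=k)
for i in range(n - 1):
    pairwise[i, :states[i], :states[i + 1]] = rng.normal(size=(states[i], states[i + 1]))
print(maxsum_masked(unary, pairwise, states))
print(maxsum_exhaustive(unary, pairwise, states))
```

A fill-based shortcut (relying on `-inf` padding alone) would pick a padded state here, which is why both this sketch and the original index by `states`.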
61 |
62 | def score(path, unary, pairwise, detail=False):
63 |     if not len(path): return (0.0, 0.0) if detail else 0.0
64 | S = unary[0][path[0]]
65 | Sp = 0.0
66 | prev = path[0]
67 | for i, p in enumerate(path[1:]):
68 | Sp += pairwise[i][prev, p]
69 | S += unary[i + 1][p]
70 | prev = p
71 | if detail:
72 | return S, Sp
73 | return S + Sp
74 |
75 |
76 | def maxsum_brute(unary, pairwise, states, K=3):
77 | """
78 | Brute-force max-sum (for debugging)
79 | """
80 | best_path = None
81 | best_score = -float('inf')
82 | for path in itertools.product(*[range(s) for s in states]):
83 | sc = score(path, unary, pairwise)
84 | if sc > best_score:
85 | best_path = path
86 | best_score = sc
87 | return best_score, best_path
88 |
89 |
90 | @jit(nopython=True)
91 | def sumprod_log(unary, pairwise, states, logspace=False):
92 | N, max_s = unary.shape
93 | alpha = np.zeros((N, max_s), dtype=unary.dtype)
94 | beta = np.zeros((N, max_s), dtype=unary.dtype)
95 | alpha[0, :states[0]] = unary[0][:states[0]]
96 | alpha[0, states[0]:] = -np.inf
97 | for s in range(N - 1):
98 | for k2 in range(states[s + 1]):
99 | M = -np.inf
100 | for k1 in range(states[s]):
101 | C = alpha[s, k1] + pairwise[s, k1, k2]
102 | if C > M:
103 | M = C
104 | for k1 in range(states[s]):
105 | alpha[s + 1, k2] += np.exp(alpha[s, k1] + pairwise[s, k1, k2] - M)
106 | alpha[s + 1, k2] = unary[s + 1, k2] + np.log(alpha[s + 1, k2]) + M
107 | for k2 in range(states[s + 1], max_s):
108 | alpha[s + 1, k2] = -np.inf
109 | M = alpha[s + 1, :states[s + 1]].max()
110 | Z = np.log(np.exp(alpha[s + 1, :states[s + 1]] - M).sum()) + M
111 | alpha[s + 1, :states[s + 1]] = alpha[s + 1, :states[s + 1]] - Z
112 | for s in range(N - 2, -1, -1):
113 | for k1 in range(states[s]):
114 | M = -np.inf
115 | for k2 in range(states[s + 1]):
116 | C = beta[s + 1, k2] + pairwise[s, k1, k2] + unary[s + 1, k2]
117 | if C > M:
118 | M = C
119 | for k2 in range(states[s + 1]):
120 | beta[s, k1] += np.exp(beta[s + 1, k2] + pairwise[s, k1, k2] + unary[s + 1, k2] - M)
121 | beta[s, k1] = np.log(beta[s, k1]) + M
122 | for k1 in range(states[s], max_s):
123 | beta[s, k1] = -np.inf
124 | M = beta[s, :states[s]].max()
125 | Z = np.log(np.exp(beta[s, :states[s]] - M).sum()) + M
126 | beta[s, :states[s]] = beta[s, :states[s]] - Z
127 | marg = alpha + beta
128 | for s in range(N):
129 | marg[s] = marg[s] - marg[s].max()
130 | if not logspace:
131 | marg = np.exp(marg)
132 | marg = marg / marg.sum(1).reshape(-1, 1)
133 | return marg
134 |
135 |
136 |
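`sumprod_log` is the sum-product (forward-backward) counterpart of `maxsum`: `alpha` accumulates, in log space, the mass of all prefixes ending in each state (unary included), `beta` the mass of all suffixes after it, and `alpha + beta` normalized per node gives the marginals that AOWS copies into `F.probability` after dividing unaries and pairwise terms by the temperature T. A NumPy sketch of the same recursion (not the repository's code; it skips the per-step renormalization of `alpha` and `beta`, which only shifts each node by a constant), checked against brute-force enumeration:

```python
import itertools

import numpy as np


def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    """Stable log(sum(exp(a))) along `axis`; an all -inf slice gives -inf."""
    m = np.max(a, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # avoid (-inf) - (-inf) = nan
    with np.errstate(divide="ignore"):
        out = np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True)) + m
    return np.squeeze(out, axis=axis)


def chain_marginals(unary: np.ndarray, pairwise: np.ndarray, states: np.ndarray) -> np.ndarray:
    """Node marginals of p(path) proportional to exp(sum unary + sum pairwise)."""
    n, s_max = unary.shape
    valid = np.arange(s_max)[None, :] < np.asarray(states)[:, None]
    u = np.where(valid, unary, -np.inf)
    edges = [np.where(valid[s][:, None] & valid[s + 1][None, :], pairwise[s], -np.inf)
             for s in range(n - 1)]
    alpha = np.empty((n, s_max))  # log-mass of prefixes ending in each state, unary included
    beta = np.zeros((n, s_max))   # log-mass of suffixes following each state
    alpha[0] = u[0]
    for s in range(n - 1):
        alpha[s + 1] = u[s + 1] + _logsumexp(alpha[s][:, None] + edges[s], axis=0)
    for s in range(n - 2, -1, -1):
        beta[s] = _logsumexp(edges[s] + (u[s + 1] + beta[s + 1])[None, :], axis=1)
    log_marg = alpha + beta
    return np.exp(log_marg - _logsumexp(log_marg, axis=1)[:, None])


def marginals_exhaustive(unary, pairwise, states):
    """Reference marginals by enumerating every path (for checking only)."""
    paths = list(itertools.product(*[range(k) for k in states]))
    scores = np.array([sum(unary[i, k] for i, k in enumerate(p))
                       + sum(pairwise[i, p[i], p[i + 1]] for i in range(len(p) - 1))
                       for p in paths])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    marg = np.zeros(unary.shape)
    for p, w in zip(paths, weights):
        for i, k in enumerate(p):
            marg[i, k] += w
    return marg


rng = np.random.default_rng(1)
states = np.array([1, 3, 2, 1])
n, s_max = len(states), int(states.max())
unary = np.full((n, s_max), -np.inf)  # padding as produced by complete()
pairwise = np.full((n - 1, s_max, s_max), -np.inf)
for i, k in enumerate(states):
    unary[i, :k] = rng.normal(size=k)
for i in range(n - 1):
    pairwise[i, :states[i], :states[i + 1]] = rng.normal(size=(states[i], states[i + 1]))
temperature = 0.5  # AOWS divides both potentials by T before this step
marg = chain_marginals(unary / temperature, pairwise / temperature, states)
print(marg.round(3))
```

As T decreases, the marginals concentrate on the max-sum path, which is how the AOWS temperature schedule moves sampling from broad exploration toward the current OWS solution.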
--------------------------------------------------------------------------------