├── LICENSE
├── README.md
├── __init__.py
├── build_lookup_table.py
├── common.py
├── constants.py
├── data
│   └── __init__.py
├── docs
│   ├── alexnet
│   │   └── README.md
│   └── mobilenet
│       └── README.md
├── eval.py
├── fig
│   ├── netadapt_algo.png
│   └── netadapt_fig.png
├── functions.py
├── latency_lut
│   └── __init__.py
├── master.py
├── models
│   ├── alexnet
│   │   └── __init__.py
│   ├── helloworld
│   │   ├── lut.pkl
│   │   └── model_0.pth.tar
│   └── mobilenet
│       └── __init__.py
├── nets
│   ├── __init__.py
│   ├── alexnet.py
│   ├── helloworld.py
│   └── mobilenet.py
├── network_utils
│   ├── __init__.py
│   ├── network_utils_abstract.py
│   ├── network_utils_alexnet.py
│   ├── network_utils_helloworld.py
│   └── network_utils_mobilenet.py
├── requirements.txt
├── scripts
│   ├── netadapt_alexnet-0.5latency.sh
│   ├── netadapt_alexnet-0.5mac.sh
│   ├── netadapt_helloworld.sh
│   ├── netadapt_mobilenet-0.5latency.sh
│   ├── netadapt_mobilenet-0.5mac.sh
│   ├── unittest.sh
│   ├── unittest_alexnet.sh
│   ├── unittest_helloworld.sh
│   └── unittest_mobilenet.sh
├── train.py
├── unittest
│   ├── unittest_master_helloworld.py
│   ├── unittest_network_utils_alexnet.py
│   ├── unittest_network_utils_helloworld.py
│   ├── unittest_network_utils_mobilenet.py
│   └── unittest_worker_helloworld.py
└── worker.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 denru01
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications
2 | ============================
3 | This repo contains the official PyTorch reimplementation of the paper "NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications"
4 | [[paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Tien-Ju_Yang_NetAdapt_Platform-Aware_Neural_ECCV_2018_paper.pdf)]
5 | [[project](https://netadapt.mit.edu)]. The results in the paper were generated by the TensorFlow implementation from Google AI.
6 |
7 |
8 |
9 |
10 | ## Summary ##
11 | 0. [Requirements](#requirements)
12 | 0. [Usage](#usage)
13 | 0. [Example](#example)
14 | 0. [Customization](#customization)
15 | 0. [Citation](#citation)
16 |
17 |
29 |
30 | ## Requirements ##
31 | The code base is tested with the following settings:
32 |
33 | 1. Python 3.7.0
34 | 2. CUDA 10.0
35 | 3. PyTorch 1.2.0
36 | 4. torchvision 0.4.0
37 | 5. numpy 1.17.0
38 | 6. scipy 1.3.1
39 |
40 | First, clone the repo into the directory where you want to work:
41 |
42 | git clone https://github.com/denru01/netadapt.git
43 | cd netadapt
44 |
45 | In what follows, we assume you are at the repo root.
46 |
47 | If your Python and CUDA versions match the ones listed above, you can install the required Python packages with:
48 |
49 | pip install -r requirements.txt
50 |
51 | To verify that the downloaded code base works correctly, please run either
52 |
53 | sh scripts/unittest.sh
54 |
55 | or
56 |
57 | sh scripts/unittest_helloworld.sh
58 | sh scripts/unittest_alexnet.sh
59 | sh scripts/unittest_mobilenet.sh
60 |
61 | If everything is set up correctly, you should not see any FAIL.
62 |
63 | ## Usage ##
64 | In order to apply NetAdapt, run:
65 |
66 | python master.py [-h] [-gp GPUS [GPUS ...]] [-re] [-im INIT_MODEL_PATH]
67 | [-mi MAX_ITERS] [-lr FINETUNE_LR] [-bu BUDGET]
68 | [-bur BUDGET_RATIO] [-rt RESOURCE_TYPE]
69 | [-ir INIT_RESOURCE_REDUCTION]
70 | [-irr INIT_RESOURCE_REDUCTION_RATIO]
71 | [-rd RESOURCE_REDUCTION_DECAY]
72 | [-st SHORT_TERM_FINE_TUNE_ITERATION] [-lt LOOKUP_TABLE_PATH]
73 | [-dp DATASET_PATH] [-a ARCH] [-si SAVE_INTERVAL]
74 | working_folder input_data_shape input_data_shape
75 | input_data_shape
76 |
77 | - `working_folder`: Root folder where models, related files and history information are saved. You can see how models are pruned progressively in `working_folder/master/history.txt`.
78 |
79 | - `input_data_shape`: Input data shape (C, H, W) (default: 3 224 224). If you want to apply NetAdapt to different tasks, you might need to change data shape.
80 |
81 | - `-h, --help`: Show this help message and exit.
82 |
83 | - `-gp GPUS [GPUS ...], --gpus GPUS [GPUS ...]`: Indices of available gpus (default: 0).
84 |
85 | - `-re, --resume`: Resume from a previous iteration. To resume, specify `--resume` and set `working_folder` to the folder you want to resume from.
86 | The restored arguments will overwrite the arguments provided here.
87 | For example, suppose you want to simplify a model by pruning and fine-tuning for 30 iterations (under `working_folder`), but your program terminates after 20 iterations.
88 | You can then use `--resume` to restore the state and continue for the remaining 10 iterations.
89 |
90 | - `-im INIT_MODEL_PATH, --init_model_path INIT_MODEL_PATH`: Path to pretrained model.
91 |
92 | - `-mi MAX_ITERS, --max_iters MAX_ITERS`: Maximum iteration of removing filters and short-term fine-tune (default: 10).
93 |
94 | - `-lr FINETUNE_LR, --finetune_lr FINETUNE_LR`: Short-term fine-tune learning rate (default: 0.001).
95 |
96 | - `-bu BUDGET, --budget BUDGET`: Resource constraint. If resource < `budget`, the process is terminated.
97 |
98 | - `-bur BUDGET_RATIO, --budget_ratio BUDGET_RATIO`: If `--budget` is not specified, `budget` = `budget_ratio`\*(pretrained model resource) (default: 0.25).
99 |
100 | - `-rt RESOURCE_TYPE, --resource_type RESOURCE_TYPE`: Resource constraint type (default: FLOPS). We currently support `FLOPS`, `WEIGHTS`, and `LATENCY` (device `cuda:0`). If you want to add other resource
101 | types, please modify `def compute_resource(...)` in the `network_utils` Python files (e.g. `network_utils/network_utils_alexnet.py`).
102 |
103 | - `-ir INIT_RESOURCE_REDUCTION, --init_resource_reduction INIT_RESOURCE_REDUCTION`: For each iteration, target resource = current resource - `init_resource_reduction`\*(`resource_reduction_decay`\*\*(iteration-1)).
104 |
105 | - `-irr INIT_RESOURCE_REDUCTION_RATIO, --init_resource_reduction_ratio INIT_RESOURCE_REDUCTION_RATIO`: If `--init_resource_reduction` is not specified,
106 | `init_resource_reduction` = `init_resource_reduction_ratio`\*(pretrained model resource) (default: 0.025).
107 |
108 | - `-rd RESOURCE_REDUCTION_DECAY, --resource_reduction_decay RESOURCE_REDUCTION_DECAY`: For each iteration, target resource = current resource - `init_resource_reduction`\*(`resource_reduction_decay`\*\*(iteration-1)) (default: 0.96). See the sketch after this argument list for the resulting schedule.
109 |
110 | - `-st SHORT_TERM_FINE_TUNE_ITERATION, --short_term_fine_tune_iteration SHORT_TERM_FINE_TUNE_ITERATION`: Short-term fine-tune iteration (default: 10).
111 |
112 | - `-lt LOOKUP_TABLE_PATH, --lookup_table_path LOOKUP_TABLE_PATH`: Path to lookup table.
113 |
114 | - `-dp DATASET_PATH, --dataset_path DATASET_PATH`: Path to dataset.
115 |
116 | - `-a ARCH, --arch ARCH`: Which `network_utils` to use; defines how networks are pruned, fine-tuned, and evaluated. If you want to use
117 | your own method, please see [**Customization**](#customization) and specify here. (default: alexnet)
118 |
119 | - `-si SAVE_INTERVAL, --save_interval SAVE_INTERVAL`: Interval (in iterations) at which all pruned models from an iteration are kept.
120 | Use `-1` to save only the best model at each iteration.
121 | Use `1` to save all models at each iteration. (default: -1).
122 |
123 |
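For intuition, the sketch below (not part of the repository; values are purely illustrative) traces the per-iteration target resource implied by the default values above. `master.py` stops as soon as the resource of the current best model drops below `budget` or `max_iters` is reached.

```python
# Illustrative sketch of the default resource-reduction schedule (normalized units).
initial_resource = 1.0                      # resource of the pretrained model
budget = 0.25 * initial_resource            # from --budget_ratio (default 0.25)
init_reduction = 0.025 * initial_resource   # from --init_resource_reduction_ratio (default 0.025)
decay = 0.96                                # from --resource_reduction_decay (default 0.96)
max_iters = 10                              # from --max_iters (default 10)

current = initial_resource
for iteration in range(1, max_iters + 1):
    if current <= budget:
        break
    target = current - init_reduction * (decay ** (iteration - 1))
    print('iteration {}: target resource = {:.4f}'.format(iteration, target))
    current = target  # assumes each iteration hits its target exactly
```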
124 |
125 | ## Example ##
126 |
187 |
188 | We provide a simple example of applying **NetAdapt** to a very small [network](nets/helloworld.py):
189 |
190 | sh scripts/netadapt_helloworld.sh
191 |
192 | Detailed examples of applying **NetAdapt** to **AlexNet**/**MobileNet** on **CIFAR-10** are shown [**here (AlexNet)**](docs/alexnet/README.md) and [**here (MobileNet)**](docs/mobilenet/README.md).
193 |
194 |
195 |
196 |
197 |
198 | If you want to apply the algorithm to different networks or even different tasks,
199 | please see the following [**Customization**](#customization) section.
200 |
201 |
202 | ## Customization ##
203 |
204 | To apply NetAdapt to different networks or tasks, please follow these instructions:
205 |
206 | 1. Create your own `network_utils` Python file (say, `network_utils_yourNetworkOrTask.py`) and place it under `network_utils`.
207 |
208 | 2. Implement functions described in [`network_utils_abstract.py`](network_utils/network_utils_abstract.py).
209 |
210 | 3. As we provide an example of applying NetAdapt to AlexNet, you can also build your `network_utils` based on [`network_utils_alexnet.py`](network_utils/network_utils_alexnet.py):
211 |
212 | ```bash
213 | cd network_utils
214 | cp network_utils_alexnet.py ./network_utils_yourNetworkOrTask.py
215 | ```
216 |
217 | 4. Add `from .network_utils_yourNetworkOrTask import *` to `__init__.py`, which is under [the same directory](network_utils/__init__.py).
218 |
219 | 5. Modify `class networkUtils_alexnet(...)` in [line 44](network_utils/network_utils_alexnet.py#L44) in `network_utils_yourNetworkOrTask.py` to `class networkUtils_yourNetworkOrTask(...)`.
220 |
221 | 6. Modify `def alexnet(...)` in [line 325-326](network_utils/network_utils_alexnet.py#L325-L326) to:
222 |     ```python
223 | def yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr=1e-3):
224 | return networkUtils_yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr)
225 | ```
226 |
227 | 7. Specify training/validation data loader, loss functions, optimizers, network architecture, training method, and evaluation method in `network_utils_yourNetworkOrTask.py` if there is any difference from the AlexNet example:
228 |
229 |    - Modify the data loader and loss functions in `def __init__(...):` in [line 52](network_utils/network_utils_alexnet.py#L52-L125).
230 |
231 |    - Specify additive skip connections if there are any and modify `def simplify_network_def_based_on_constraint(...)` in `network_utils_yourNetworkOrTask.py`.
232 | You can see how our implementation uses additive skip connections [here](functions.py#L543-L549).
233 |
234 | - Modify training method (short-term finetune) in function `def fine_tune(...):` in [line 245](network_utils/network_utils_alexnet.py#L245-L288).
235 |
236 | - Modify evaluation method in function `def evaluate(...):` in [line 291](network_utils/network_utils_alexnet.py#L291-L322).
237 |
238 | You can see how these methods are utilized by the framework [here](worker.py#L39-L62).
239 |
240 | 8. Our current code base supports pruning `Conv2d`, `ConvTranspose2d`, and `Linear` with additive skip connections.
241 | If your network architecture is not supported, please modify [this](network_utils/network_utils_alexnet.py#L142-L199).
242 | If you want to use other metrics (resource type) to prune networks, please modify [this](network_utils/network_utils_alexnet.py#L234-L238).
243 |
244 | 9. You can apply NetAdapt to different networks or tasks by using `--arch yourNetworkOrTask` in `scripts/netadapt_alexnet-0.5mac.sh` (see the sketch after this list for how the name is resolved).
245 | As for the values of other arguments, please see [**Usage**](#usage).
246 | Generally, if you want to apply NetAdapt to a different task, you might change `input_data_shape`.
247 | If your network architecture is very different from that of MobileNet, you would have to modify the values of `--init_resource_reduction_ratio` and `--resource_reduction_decay` to get a different resource reduction schedule.
248 |
249 |
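Note that `--arch yourNetworkOrTask` is resolved purely by name: `master.py` collects every lowercase callable exported from the `network_utils` package and instantiates the one whose name matches, which is why the import in step 4 and the factory function in step 6 are both required. A rough sketch of that lookup (paraphrasing `master.py`; the final call is left as a comment because `model`, `input_data_shape`, and `dataset_path` are values `master.py` already holds):

```python
import network_utils as networkUtils

# Names accepted by --arch: every lowercase callable exported by network_utils/__init__.py.
network_utils_all = sorted(name for name in networkUtils.__dict__
                           if name.islower() and not name.startswith('__')
                           and callable(networkUtils.__dict__[name]))
print(network_utils_all)  # e.g. ['alexnet', 'helloworld', 'mobilenet', 'yourNetworkOrTask']

# master.py then builds the selected utils through the factory function from step 6:
#   network_utils = networkUtils.__dict__[args.arch](model, args.input_data_shape, args.dataset_path)
```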
250 |
251 | ## Citation
252 | If you use our code or method in your work, please consider citing the following:
253 |
254 | ```
255 | @InProceedings{eccv_2018_yang_netadapt,
256 | author = {Yang, Tien-Ju and Howard, Andrew and Chen, Bo and Zhang, Xiao and Go, Alec and Sandler, Mark and Sze, Vivienne and Adam, Hartwig},
257 | title = {NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications},
258 | booktitle = {The European Conference on Computer Vision (ECCV)},
259 | month = {September},
260 | year = {2018}
261 | }
262 | ```
263 |
264 | Please direct any questions to the authors: Tien-Ju Yang (tjy@mit.edu) and Yi-Lun Liao (ylliao@mit.edu).
265 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/__init__.py
--------------------------------------------------------------------------------
/build_lookup_table.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import nets as models
4 | import functions as fns
5 | from argparse import ArgumentParser
6 |
7 | model_names = sorted(name for name in models.__dict__
8 | if name.islower() and not name.startswith("__")
9 | and callable(models.__dict__[name]))
10 |
11 | NUM_CLASSES = 10
12 |
13 | INPUT_DATA_SHAPE = (3, 224, 224)
14 |
15 |
16 | '''
17 | `MIN_CONV_FEATURE_SIZE`: The sampled sizes of feature maps of conv layers
18 |     along the channel dimension are multiples of `MIN_CONV_FEATURE_SIZE`.
19 | 
20 | `MIN_FC_FEATURE_SIZE`: The sampled sizes of features of FC layers are
21 |     multiples of `MIN_FC_FEATURE_SIZE`.
22 | '''
23 | MIN_CONV_FEATURE_SIZE = 8
24 | MIN_FC_FEATURE_SIZE = 64
25 |
26 | '''
27 | `MEASURE_LATENCY_BATCH_SIZE`: the batch size of input data
28 | when running forward functions to measure latency.
29 | `MEASURE_LATENCY_SAMPLE_TIMES`: the number of times to run the forward function of
30 | a layer in order to get its latency.
31 | '''
32 | MEASURE_LATENCY_BATCH_SIZE = 128
33 | MEASURE_LATENCY_SAMPLE_TIMES = 500
34 |
35 |
36 | arg_parser = ArgumentParser(description='Build latency lookup table')
37 | arg_parser.add_argument('--dir', metavar='DIR', default='latency_lut/lut_alexnet.pkl',
38 |                     help='path to save the lookup table')
39 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet',
40 | choices=model_names,
41 | help='model architecture: ' +
42 | ' | '.join(model_names) +
43 | ' (default: alexnet)')
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 | args = arg_parser.parse_args()
49 | print(args)
50 |
51 | build_lookup_table = True
52 | lookup_table_path = args.dir
53 | model_arch = args.arch
54 |
55 | print('Load', model_arch)
56 | print('--------------------------------------')
57 | model = models.__dict__[model_arch](num_classes=NUM_CLASSES)
58 | network_def = fns.get_network_def_from_model(model, INPUT_DATA_SHAPE)
59 | for layer_name, layer_properties in network_def.items():
60 | print(layer_name)
61 | print(' ', layer_properties, '\n')
62 | print('-------------------------------------------')
63 |
64 | num_w = fns.compute_resource(network_def, 'WEIGHTS')
65 | flops = fns.compute_resource(network_def, 'FLOPS')
66 | num_param = fns.compute_resource(network_def, 'WEIGHTS')
67 | print('Number of FLOPs: ', flops)
68 | print('Number of weights: ', num_w)
69 | print('Number of parameters: ', num_param)
70 | print('-------------------------------------------')
71 |
72 | model = model.cuda()
73 |
74 | print('Building latency lookup table for',
75 | torch.cuda.get_device_name())
76 | if build_lookup_table:
77 | fns.build_latency_lookup_table(network_def, lookup_table_path=lookup_table_path,
78 |                                        min_fc_feature_size=MIN_FC_FEATURE_SIZE,
79 | min_conv_feature_size=MIN_CONV_FEATURE_SIZE,
80 | measure_latency_batch_size=MEASURE_LATENCY_BATCH_SIZE,
81 | measure_latency_sample_times=MEASURE_LATENCY_SAMPLE_TIMES,
82 | verbose=True)
83 | print('-------------------------------------------')
84 | print('Finish building latency lookup table.')
85 | print(' Device:', torch.cuda.get_device_name())
86 | print(' Model: ', model_arch)
87 | print('-------------------------------------------')
88 |
89 | latency = fns.compute_resource(network_def, 'LATENCY', lookup_table_path)
90 | print('Computed latency: ', latency)
91 | latency = fns.measure_latency(model,
92 | [MEASURE_LATENCY_BATCH_SIZE, *INPUT_DATA_SHAPE])
93 | print('Exact latency: ', latency)
94 |
95 |
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | # Master-related filenames.
2 | MASTER_MODEL_FILENAME_TEMPLATE = 'iter_{}_best_model.pth.tar'
3 |
4 | # Worker-related filenames.
5 | WORKER_MODEL_FILENAME_TEMPLATE = 'iter_{}_block_{}_model.pth.tar'
6 | WORKER_ACCURACY_FILENAME_TEMPLATE = 'iter_{}_block_{}_accuracy.txt'
7 | WORKER_RESOURCE_FILENAME_TEMPLATE = 'iter_{}_block_{}_resource.txt'
8 | WORKER_LOG_FILENAME_TEMPLATE = 'iter_{}_block_{}_log.txt'
9 | WORKER_FINISH_FILENAME_TEMPLATE = 'iter_{}_block_{}_finish.signal'
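
# Example: at NetAdapt iteration 3, the worker that prunes block 1 produces
# 'iter_3_block_1_model.pth.tar', 'iter_3_block_1_accuracy.txt',
# 'iter_3_block_1_resource.txt', and finally 'iter_3_block_1_finish.signal',
# which master.py polls to detect that the worker has finished.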
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | # Define constants.
2 | STRING_SEPARATOR = '.'
3 |
4 | # Define layer types.
5 | CONV_LAYER_TYPES = ['Conv2d', 'ConvTranspose2d']
6 | FC_LAYER_TYPES = ['Linear']
7 | BNORM_LAYER_TYPES = ['BatchNorm2d']
8 |
9 | # Define data types.
10 | WEIGHTSTRING = 'weight'
11 | BIASSTRING = 'bias'
12 | RUNNING_MEANSTRING = 'running_mean'
13 | RUNNING_VARSTRING = 'running_var'
14 | NUM_BATCHES_TRACKED = 'num_batches_tracked'
15 |
16 | # Define keys.
17 | KEY_LAYER_TYPE_STR = 'layer_type_str' #(e.g. Linear, Conv2d)
18 | KEY_IS_DEPTHWISE = 'is_depthwise'
19 | KEY_NUM_IN_CHANNELS = 'num_in_channels'
20 | KEY_NUM_OUT_CHANNELS = 'num_out_channels'
21 | KEY_KERNEL_SIZE = 'kernel_size'
22 | KEY_STRIDE = 'stride'
23 | KEY_PADDING = 'padding'
24 | KEY_GROUPS = 'groups'
25 | KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR = 'before_squared_pixel_shuffle_factor'
26 | KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR = 'after_squared_pixel_shuffle_factor'
27 | KEY_INPUT_FEATURE_MAP_SIZE = 'input_feature_map_size'
28 | KEY_OUTPUT_FEATURE_MAP_SIZE = 'output_feature_map_size'
29 | KEY_MODEL = 'model'
30 | KEY_LATENCY = 'latency'
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/alexnet/README.md:
--------------------------------------------------------------------------------
1 | 1. **Training AlexNet on CIFAR-10.**
2 |
3 | Training:
4 | ```bash
5 | python train.py data/ --dir models/alexnet/model.pth.tar --arch alexnet
6 | ```
7 |
8 | Evaluation:
9 | ```bash
10 | python eval.py data/ --dir models/alexnet/model.pth.tar --arch alexnet
11 | ```
12 |
13 | One trained model can be found [here](https://drive.google.com/file/d/1GwugqlSl5ogRvJ2F2-4uz1fXBzt36ahB/view?usp=sharing).
14 |
15 | 2. **Measuring Latency**
16 |
17 | Here we build the latency lookup table for `cuda:0` device:
18 | ```bash
19 | python build_lookup_table.py --dir latency_lut/lut_alexnet.pkl --arch alexnet
20 | ```
21 | It measures latency of different layers contained in the network (i.e. **AlexNet** here).
22 | For conv layers, the sampled numbers of feature channels are multiples of `MIN_CONV_FEATURE_SIZE`.
23 | For fc layers, the sampled numbers of features are multiples of `MIN_FC_FEATURE_SIZE`.
24 |
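
The resulting `.pkl` is consumed by `master.py` through the `-lt/--lookup_table_path` flag (together with `-rt LATENCY`). As a rough illustration (mirroring the end of `build_lookup_table.py`; assumes the table built above and the repo root on the Python path), the table can also be queried directly to estimate latency without running the model:

```python
import nets as models
import functions as fns

model = models.alexnet(num_classes=10)
network_def = fns.get_network_def_from_model(model, (3, 224, 224))
# Latency estimated from the lookup table (no GPU measurement at this point).
latency = fns.compute_resource(network_def, 'LATENCY', 'latency_lut/lut_alexnet.pkl')
print(latency)
```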
25 | 3. **Applying NetAdapt**
26 |
27 | Modify which GPUs will be utilized (`-gp`) in `netadapt_alexnet-0.5mac.sh` and run the script to apply NetAdapt to a pretrained model:
28 | ```bash
29 | sh scripts/netadapt_alexnet-0.5mac.sh
30 | ```
31 |
32 | You can see how the model is simplified at each iteration in `models/alexnet/prune-by-mac/master/history.txt` and
33 | select the one that satisfies the constraints to run long-term fine-tune.
34 |
35 | After obtaining the adapted model, we need to finetune the model (here we select the one after 18 iterations):
36 | ```bash
37 | python train.py data/ --arch alexnet --resume models/alexnet/prune-by-mac/master/iter_18_best_model.pth.tar --dir models/alexnet/prune-by-mac/master/finetune_model.pth.tar --lr 0.001
38 | ```
39 |
40 |
41 |
42 |
43 |
44 |
45 | If you want to get a model with 50% latency, please run:
46 | ```bash
47 | sh scripts/netadapt_alexnet-0.5latency.sh
48 | ```
49 |
50 | 4. **Evaluation Using Adapted Models**
51 |
52 | After applying NetAdapt to a pretrained model, we can evaluate this adapted model using:
53 | ```bash
54 | python eval.py data/ --dir models/alexnet/prune-by-mac/master/finetune_model.pth.tar --arch alexnet
55 | ```
56 |
57 | 	The adapted model can be restored **without modifying the original Python file**.
58 |
59 | We provide one adapted model [here](https://drive.google.com/file/d/1VH9c2orF2W0P21gD8NrdTdvwP_uJJgYA/view?usp=sharing).
--------------------------------------------------------------------------------
/docs/mobilenet/README.md:
--------------------------------------------------------------------------------
1 | 1. **Training MobileNet on CIFAR-10.**
2 |
3 | Training:
4 | ```bash
5 | python train.py data/ --dir models/mobilenet/model.pth.tar --arch mobilenet
6 | ```
7 |
8 | Evaluation:
9 | ```bash
10 | python eval.py data/ --dir models/mobilenet/model.pth.tar --arch mobilenet
11 | ```
12 |
13 | One trained model can be found [here](https://drive.google.com/file/d/1jtRZOHK1daRTKD4jYu4US84lf8YjqEdJ/view?usp=sharing).
14 |
15 |
16 | 2. **Measuring Latency**
17 |
18 | Here we build the latency lookup table for `cuda:0` device:
19 | ```bash
20 | python build_lookup_table.py --dir latency_lut/lut_mobilenet.pkl --arch mobilenet
21 | ```
22 | It measures latency of different layers contained in the network (i.e. **MobileNet** here).
23 | For conv layers, the sampled numbers of feature channels are multiples of `MIN_CONV_FEATURE_SIZE`.
24 | For fc layers, the sampled numbers of features are multiples of `MIN_FC_FEATURE_SIZE`.
25 |
26 | 3. **Applying NetAdapt**
27 |
28 | Modify which GPUs will be utilized (`-gp`) in `netadapt_mobilenet-0.5mac.sh` and run the script to apply NetAdapt to a pretrained model:
29 | ```bash
30 | sh scripts/netadapt_mobilenet-0.5mac.sh
31 | ```
32 |
33 | You can see how the model is simplified at each iteration in `models/mobilenet/prune-by-mac/master/history.txt` and
34 | select the one that satisfies the constraints to run long-term fine-tune.
35 |
36 | After obtaining the adapted model, we need to finetune the model (here we select the one after 28 iterations):
37 | ```bash
38 | python train.py data/ --arch mobilenet --resume models/mobilenet/prune-by-mac/master/iter_28_best_model.pth.tar --dir models/mobilenet/prune-by-mac/master/finetune_model.pth.tar --lr 0.001
39 | ```
40 |
41 |
42 |
43 |
44 |
45 |
46 | If you want to get a model with 50% latency, please run:
47 | ```bash
48 | sh scripts/netadapt_mobilenet-0.5latency.sh
49 | ```
50 |
51 |
52 | 4. **Evaluation Using Adapted Models**
53 |
54 | After applying NetAdapt to a pretrained model, we can evaluate this adapted model using:
55 | ```bash
56 | python eval.py data/ --dir models/mobilenet/prune-by-mac/master/finetune_model.pth.tar --arch mobilenet
57 | ```
58 |
59 | 	The adapted model can be restored **without modifying the original Python file**.
60 |
61 | We provide one adapted model [here](https://drive.google.com/file/d/1wkQPolgv34ESb0gyeYdygbuNuNYhIqyx/view?usp=sharing).
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | import time
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as datasets
9 | import torch.backends.cudnn as cudnn
10 | import pickle
11 |
12 | import nets as models
13 | import functions as fns
14 |
15 | _NUM_CLASSES = 10
16 |
17 | model_names = sorted(name for name in models.__dict__
18 | if name.islower() and not name.startswith("__")
19 | and callable(models.__dict__[name]))
20 |
21 |
22 | def compute_topk_accuracy(output, target, topk=(1,)):
23 | """Computes the accuracy over the k top predictions for the specified values of k"""
24 | with torch.no_grad():
25 | maxk = max(topk)
26 | batch_size = target.size(0)
27 |
28 | _, pred = output.topk(maxk, 1, True, True)
29 | pred = pred.t()
30 | correct = pred.eq(target.view(1, -1).expand_as(pred))
31 |
32 | res = []
33 | for k in topk:
34 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
35 | res.append(correct_k.mul_(100.0 / batch_size))
36 | return res
37 |
38 |
39 | def compute_accuracy(output, target):
40 | output = output.argmax(dim=1)
41 | acc = 0.0
42 | acc = torch.sum(target == output).item()
43 | acc = acc/output.size(0)*100
44 | return acc
45 |
46 |
47 | class AverageMeter(object):
48 | """Computes and stores the average and current value"""
49 | def __init__(self):
50 | self.reset()
51 |
52 | def reset(self):
53 | self.val = 0
54 | self.avg = 0
55 | self.sum = 0
56 | self.count = 0
57 |
58 | def get_avg(self):
59 | return self.avg
60 |
61 | def update(self, val, n=1):
62 | self.val = val
63 | self.sum += val * n
64 | self.count += n
65 | self.avg = self.sum / self.count
66 |
67 |
68 | def eval(test_loader, model, args):
69 | batch_time = AverageMeter()
70 | acc = AverageMeter()
71 |
72 | # switch to eval mode
73 | model.eval()
74 |
75 | end = time.time()
76 | for i, (images, target) in enumerate(test_loader):
77 | if not args.no_cuda:
78 | images = images.cuda()
79 | target = target.cuda()
80 | output = model(images)
81 | batch_acc = compute_accuracy(output, target)
82 | acc.update(batch_acc, images.size(0))
83 | batch_time.update(time.time() - end)
84 | end = time.time()
85 |
86 | # Update statistics
87 | estimated_time_remained = batch_time.get_avg()*(len(test_loader)-i-1)
88 | fns.update_progress(i, len(test_loader),
89 | ESA='{:8.2f}'.format(estimated_time_remained)+'s',
90 | acc='{:4.2f}'.format(float(batch_acc))
91 | )
92 | print()
93 | print('Test accuracy: {:4.2f}% (time = {:8.2f}s)'.format(
94 | float(acc.get_avg()), batch_time.get_avg()*len(test_loader)))
95 | print('===================================================================')
96 | return float(acc.get_avg())
97 |
98 |
99 | if __name__ == '__main__':
100 | # Parse the input arguments.
101 | arg_parser = ArgumentParser()
102 | arg_parser.add_argument('data', metavar='DIR', help='path to dataset')
103 | arg_parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
104 | help='number of data loading workers (default: 4)')
105 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet',
106 | choices=model_names,
107 | help='model architecture: ' +
108 | ' | '.join(model_names) +
109 | ' (default: alexnet)')
110 | arg_parser.add_argument('-b', '--batch-size', default=128, type=int,
111 | metavar='N',
112 | help='batch size (default: 128)')
113 | arg_parser.add_argument('--dir', type=str, default='models/', dest='save_dir',
114 |                     help='path to the model to be evaluated (default: models/)')
115 | arg_parser.add_argument('--no-cuda', action='store_true', default=False, dest='no_cuda',
116 |                     help='disables evaluation on GPU')
117 |
118 | args = arg_parser.parse_args()
119 | print(args)
120 |
121 | # Data loader
122 | test_dataset = datasets.CIFAR10(root=args.data, train=False, download=True,
123 | transform=transforms.Compose([
124 | transforms.Resize(224),
125 | transforms.ToTensor(),
126 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
127 | ]))
128 | test_loader = torch.utils.data.DataLoader(
129 | test_dataset, batch_size=args.batch_size, shuffle=False,
130 | num_workers=args.workers, pin_memory=True)
131 |
132 | # Network
133 | model_arch = args.arch
134 | cudnn.benchmark = True
135 | num_classes = _NUM_CLASSES
136 | model = models.__dict__[model_arch](num_classes=num_classes)
137 |
138 | if not args.no_cuda:
139 | model = model.cuda()
140 |
141 | # Evaluation
142 | filename = os.path.join(args.save_dir)
143 |
144 | model = torch.load(filename)
145 | print(model)
146 |
147 | best_acc = eval(test_loader, model, args)
148 | print('Testing accuracy:', best_acc)
--------------------------------------------------------------------------------
/fig/netadapt_algo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/fig/netadapt_algo.png
--------------------------------------------------------------------------------
/fig/netadapt_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/fig/netadapt_fig.png
--------------------------------------------------------------------------------
/latency_lut/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/master.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | import pickle
4 | import time
5 | import torch
6 | from shutil import copyfile
7 | import subprocess
8 | import sys
9 | import warnings
10 | import common
11 | import network_utils as networkUtils
12 |
13 | '''
14 | The main file of NetAdapt.
15 |
16 | Launch workers to simplify and finetune pretrained models.
17 | '''
18 |
19 | # Define constants.
20 | _MASTER_FOLDER_FILENAME = 'master'
21 | _WORKER_FOLDER_FILENAME = 'worker'
22 | _WORKER_PY_FILENAME = 'worker.py'
23 | _HISTORY_PICKLE_FILENAME = 'history.pickle'
24 | _HISTORY_TEXT_FILENAME = 'history.txt'
25 | _SLEEP_TIME = 1
26 |
27 | # Define keys.
28 | _KEY_MASTER_ARGS = 'master_args'
29 | _KEY_HISTORY = 'history'
30 | _KEY_RESOURCE = 'resource'
31 | _KEY_ACCURACY = 'accuracy'
32 | _KEY_SOURCE_MODEL_PATH = 'source_model_path'
33 | _KEY_BLOCK = 'block'
34 | _KEY_ITERATION = 'iteration'
35 | _KEY_GPU = 'gpu'
36 | _KEY_MODEL = 'model'
37 | _KEY_NETWORK_DEF = 'network_def'
38 | _KEY_NUM_OUT_CHANNELS = 'num_out_channels'
39 |
40 | # Supported network_utils
41 | network_utils_all = sorted(name for name in networkUtils.__dict__
42 | if name.islower() and not name.startswith("__")
43 | and callable(networkUtils.__dict__[name]))
44 |
45 |
46 | def _launch_worker(worker_folder, model_path, block, resource_type, constraint, netadapt_iteration,
47 | short_term_fine_tune_iteration, input_data_shape, job_list, available_gpus,
48 | lookup_table_path, dataset_path, model_arch):
49 | '''
50 | `master.py` launches several `worker.py`.
51 | Each `worker.py` prunes one specific block and fine-tune it.
52 | This function launches one worker to run on one gpu.
53 |
54 | Input:
55 | `worker_folder`: (string) directory where `worker.py` will save models.
56 | `model_path`:(string) path to model which `worker.py` will load as pretrained model.
57 | `block`: (int) index of block to be simplified.
58 | `resource_type`: (string) (e.g. `WEIGHTS`, `FLOPS`, and `LATENCY`).
59 | `constraint`: (float) the value of constraints (e.g. 10**6 (weights)).
60 | `netadapt_iteration`: (int) indicates the current iteration of NetAdapt.
61 | `short_term_fine_tune_iteration`: (int) short-term fine-tune iteration.
62 | `input_data_shape`: (list) input data shape (C, H, W).
63 | `job_list`: (list of dict) list of current jobs. Each job is a dict, showing current iteration, block and gpu idx.
64 | `available_gpus`: (list) list of available gpu idx.
65 | `lookup_table_path`: (string) path to lookup table.
66 | `dataset_path`: (string) path to dataset.
67 | `model_arch`: (string) specifies which network_utils will be used.
68 |
69 | Output:
70 | updated_job_list: (list of dict)
71 | updated_available_gpus: (list)
72 | '''
73 | updated_job_list = job_list.copy()
74 | updated_available_gpus = available_gpus.copy()
75 | gpu = updated_available_gpus[0]
76 |
77 | if lookup_table_path == None:
78 | lookup_table_path = ''
79 |
80 | print(' Launch a worker for block {}'.format(block))
81 | with open(os.path.join(worker_folder,
82 | common.WORKER_LOG_FILENAME_TEMPLATE.format(netadapt_iteration, block)), 'w') as file_id:
83 | command_list = [sys.executable, _WORKER_PY_FILENAME, worker_folder, model_path, str(block), resource_type,
84 | str(constraint), str(netadapt_iteration), str(short_term_fine_tune_iteration), str(gpu),
85 | lookup_table_path, dataset_path] + [str(e) for e in input_data_shape] + [model_arch] + [str(args.finetune_lr)]
86 |
87 | print(command_list)
88 |
89 | subprocess.Popen(command_list, stdout=file_id, stderr=file_id)
90 |
91 | updated_job_list.append({_KEY_ITERATION: netadapt_iteration, _KEY_BLOCK: block, _KEY_GPU: gpu})
92 | del updated_available_gpus[0]
93 |
94 | return updated_job_list, updated_available_gpus
95 |
96 |
97 | def _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus):
98 | '''
99 | update job list and available gpu list based on whether a worker finishes pruning and fine-tuning.
100 |
101 | Input:
102 | `worker_folder`: (string) directory where `worker.py` will save models.
103 | `job_list`: (list of dict) list of current jobs. Each job is a dict, showing current iteration, block and gpu idx.
104 | `available_gpus`: (list) list of available gpu idx.
105 |
106 | Output:
107 | `updated_job_list`: (list of dict) if a worker finishes its job, the job will be removed from this list.
108 | `updated_available_gpus`: (list) if a worker finishes its job, the gpu will be available.
109 | '''
110 | updated_job_list = []
111 | updated_available_gpus = available_gpus.copy()
112 | for job in job_list:
113 | if os.path.exists(os.path.join(worker_folder, common.WORKER_FINISH_FILENAME_TEMPLATE.format(job[_KEY_ITERATION],
114 | job[_KEY_BLOCK]))):
115 | # Find corresponding finish file of worker
116 | updated_available_gpus.append(job[_KEY_GPU])
117 | else:
118 | updated_job_list.append(job)
119 |
120 | return updated_job_list, updated_available_gpus
121 |
122 |
123 | def _find_best_model(worker_folder, iteration, num_blocks, starting_accuracy, starting_resource):
124 | '''
125 |     After all workers finish their jobs, select the model with the smallest accuracy drop per unit of resource reduction.
126 |
127 | Input:
128 | `worker_folder`: (string) directory where `worker.py` will save models.
129 | `iteration`: (int) NetAdapt iteration.
130 | `num_blocks`: (int) num of simplifiable blocks at each iteration.
131 | `starting_accuracy`: (float) initial accuracy before pruning and fine-tuning.
132 |         `starting_resource`: (float) initial resource consumption.
133 |
134 | Output:
135 | `best_accuracy`: (float) accuracy of the best pruned model.
136 | `best_model_path`: (string) path to the best model.
137 | `best_resource`: (float) resource consumption of the best model.
138 | `best_block`: (int) block index of the best model.
139 | '''
140 |
141 | best_ratio = float('Inf')
142 | best_accuracy = 0.0
143 | best_model_path = None
144 | best_resource = None
145 | best_block = None
146 | for block_idx in range(num_blocks):
147 | with open(os.path.join(worker_folder, common.WORKER_ACCURACY_FILENAME_TEMPLATE.format(iteration, block_idx)),
148 | 'r') as file_id:
149 | accuracy = float(file_id.read())
150 | with open(os.path.join(worker_folder, common.WORKER_RESOURCE_FILENAME_TEMPLATE.format(iteration, block_idx)),
151 | 'r') as file_id:
152 | resource = float(file_id.read())
153 | #ratio_resource_accuracy = (starting_accuracy - accuracy) / (starting_resource - resource + 1e-5)
154 | ratio_resource_accuracy = (starting_accuracy - accuracy + 1e-6) / (starting_resource - resource + 1e-5)
155 |
156 | print('Block id {}: resource {}, accuracy {}'.format(block_idx, resource, accuracy))
157 | if resource < starting_resource and ratio_resource_accuracy < best_ratio:
158 | #if resource < starting_resource and accuracy > best_accuracy:
159 | best_ratio = ratio_resource_accuracy
160 | best_accuracy = accuracy
161 | best_model_path = os.path.join(worker_folder,
162 | common.WORKER_MODEL_FILENAME_TEMPLATE.format(iteration, block_idx))
163 | best_resource = resource
164 | best_block = block_idx
165 | print('Best block id: {}\n'.format(best_block))
166 |
167 | return best_accuracy, best_model_path, best_resource, best_block
168 |
169 |
170 | def _save_and_print_history(network_utils, history, pickle_file_path, text_file_path):
171 | '''
172 | save history info (log: history.txt, history file: history.pickle)
173 |
174 | Input:
175 | `network_utils`: (defined in network_utils/network_utils_*) use the .extra_history_info()
176 | to get the num of output channels.
177 | `history`: (dict) records accuracy, resource, block idx, model path for each iteration and
178 | input arguments.
179 | `pickle_file_path`: (string) path to save history dict.
180 | `text_file_path`: (string) path to save history log.
181 | '''
182 | with open(pickle_file_path, 'wb') as file_id:
183 | pickle.dump(history, file_id)
184 | with open(text_file_path, 'w') as file_id:
185 |         file_id.write('Iteration,Accuracy,Resource,Block,Source Model,Num of Output Channels\n')
186 | for iter in range(len(history[_KEY_HISTORY])):
187 |
188 |             # assume the extra history info is the # of output channels per layer
189 | num_filters_str = network_utils.extra_history_info(history[_KEY_HISTORY][iter][_KEY_NETWORK_DEF])
190 | file_id.write('{},{},{},{},{},{}\n'.format(iter, history[_KEY_HISTORY][iter][_KEY_ACCURACY],
191 | history[_KEY_HISTORY][iter][_KEY_RESOURCE],
192 | history[_KEY_HISTORY][iter][_KEY_BLOCK],
193 | history[_KEY_HISTORY][iter][_KEY_SOURCE_MODEL_PATH],
194 | num_filters_str))
195 |
196 |
197 | def master(args):
198 | """
199 | The main function of the master.
200 |
201 | Note: iteration 0 means the initial model.
202 |
203 | Input:
204 | args: input arguments
205 |
206 | raise:
207 | ValueError: when:
208 | (1) no available gpus (i.e. len(args.gpus) == 0)
209 |             (2) resume from a previous iteration with a lookup table required but no lookup table found
210 | (3) files exist under working_folder/master or working_folder/worker and not use `--resume`
211 | (4) target resource is not achievable (i.e. the resource consumption at a certain iteration is the same as that at previous iteration)
212 | """
213 |
214 | # Set the important paths.
215 | master_folder = os.path.join(args.working_folder, _MASTER_FOLDER_FILENAME)
216 | worker_folder = os.path.join(args.working_folder, _WORKER_FOLDER_FILENAME)
217 | history_pickle_file = os.path.join(master_folder, _HISTORY_PICKLE_FILENAME)
218 | history_text_file = os.path.join(master_folder, _HISTORY_TEXT_FILENAME)
219 |
220 | # Get available GPUs.
221 | available_gpus = args.gpus
222 | if len(available_gpus) == 0:
223 | raise ValueError('At least one gpu must be specified.')
224 |
225 | # Resume or do iteration 0.
226 | if args.resume:
227 | with open(history_pickle_file, 'rb') as file_id:
228 | history = pickle.load(file_id)
229 | args = history[_KEY_MASTER_ARGS]
230 |
231 | # Initialize variables.
232 | current_iter = len(history[_KEY_HISTORY]) - 1
233 | current_resource = history[_KEY_HISTORY][-1][_KEY_RESOURCE]
234 | current_model_path = os.path.join(master_folder,
235 | common.MASTER_MODEL_FILENAME_TEMPLATE.format(current_iter))
236 | current_accuracy = history[_KEY_HISTORY][-1][_KEY_ACCURACY]
237 |
238 | # Get the network utils.
239 | model = torch.load(current_model_path, map_location=lambda storage, loc: storage)
240 |
241 | # Select network_utils.
242 | model_arch = args.arch
243 | network_utils = networkUtils.__dict__[model_arch](model, args.input_data_shape, args.dataset_path)
244 |
245 | if args.lookup_table_path != None and not os.path.exists(args.lookup_table_path):
246 | errMsg = 'Resume from a previous task but the {} lookup table is not found.'.format(args.resource_type)
247 | raise ValueError(errMsg)
248 | del model
249 |
250 | # Print the message.
251 | print(('Resume from iteration {:>3}: current_accuracy = {:>8.3f}, '
252 | 'current_resource = {:>8.3f}').format(current_iter, current_accuracy, current_resource))
253 | print('arguments:', args)
254 |
255 | else:
256 | # Initialize the iteration.
257 | current_iter = 0
258 |
259 | # Create the folder structure.
260 | if not os.path.exists(args.working_folder):
261 | os.makedirs(args.working_folder)
262 | print('Create directory', args.working_folder)
263 | if not os.path.exists(master_folder):
264 | os.mkdir(master_folder)
265 | print('Create directory', master_folder)
266 | elif os.listdir(master_folder):
267 | errMsg = 'Find previous files in the master directory {}. Please use `--resume` or delete those files'.format(master_folder)
268 | raise ValueError(errMsg)
269 |
270 | if not os.path.exists(worker_folder):
271 | os.mkdir(worker_folder)
272 | print('Create directory', worker_folder)
273 | elif os.listdir(worker_folder):
274 | errMsg = 'Find previous files in the worker directory {}. Please use `--resume` or delete those files'.format(worker_folder)
275 | raise ValueError(errMsg)
276 |
277 | # Backup the initial model.
278 | current_model_path = os.path.join(master_folder,
279 | common.MASTER_MODEL_FILENAME_TEMPLATE.format(current_iter))
280 | copyfile(args.init_model_path, current_model_path)
281 |
282 | # Initialize variables.
283 | model = torch.load(current_model_path)
284 |
285 | # Select network_utils.
286 | model_arch = args.arch
287 | network_utils = networkUtils.__dict__[model_arch](model, args.input_data_shape, args.dataset_path)
288 |
289 | network_def = network_utils.get_network_def_from_model(model)
290 | if args.lookup_table_path != None and not os.path.exists(args.lookup_table_path):
291 | warnMsg = 'The {} lookup table is not found and going to be built.'.format(args.resource_type)
292 | warnings.warn(warnMsg)
293 | network_utils.build_lookup_table(network_def, args.resource_type, args.lookup_table_path)
294 | current_resource = network_utils.compute_resource(network_def, args.resource_type, args.lookup_table_path)
295 |
296 | current_accuracy = network_utils.evaluate(model)
297 | current_block = None
298 |
299 | if args.init_resource_reduction == None:
300 | args.init_resource_reduction = args.init_resource_reduction_ratio*current_resource
301 | print('`--init_resource_reduction` is not specified')
302 | print('Use `--init_resource_reduction_ratio` ({}) to get `init_resource_reduction` ({})\n'.format(
303 | args.init_resource_reduction_ratio, args.init_resource_reduction))
304 | if args.budget == None:
305 | args.budget = args.budget_ratio*current_resource
306 | print('`--budget` is not specified')
307 | print('Use `--budget_ratio` ({}) to get `budget` ({})\n'.format(
308 | args.budget_ratio, args.budget))
309 |
310 | # Create and save the history.
311 | history = {_KEY_MASTER_ARGS: args, _KEY_HISTORY: []}
312 | history[_KEY_HISTORY].append({_KEY_RESOURCE: current_resource,
313 | _KEY_SOURCE_MODEL_PATH: args.init_model_path,
314 | _KEY_ACCURACY: current_accuracy,
315 | _KEY_BLOCK: current_block,
316 | _KEY_NETWORK_DEF: network_def})
317 | _save_and_print_history(network_utils, history, history_pickle_file, history_text_file)
318 | del model, network_def
319 |
320 | # Print the message.
321 | print(('Start from iteration {:>3}: current_accuracy = {:>8.3f}, '
322 | 'current_resource = {:>8.3f}').format(current_iter, current_accuracy, current_resource))
323 |
324 |
325 | current_iter += 1
326 |
327 | # Start adaptation.
328 | while current_iter <= args.max_iters and current_resource > args.budget:
329 |
330 | start_time = time.time()
331 |
332 | # Set the target resource.
333 | target_resource = current_resource - args.init_resource_reduction * (
334 | args.resource_reduction_decay ** (current_iter - 1))
335 |
336 | # Print the message.
337 | print('===================================================================')
338 | print(
339 | ('Process iteration {:>3}: current_accuracy = {:>8.3f}, '
340 | 'current_resource = {:>8.3f}, target_resource = {:>8.3f}').format(
341 | current_iter, current_accuracy, current_resource, target_resource))
342 |
343 | # Launch the workers.
344 | job_list = []
345 |
346 | # Launch worker for each block
347 | for block_idx in range(network_utils.get_num_simplifiable_blocks()):
348 | # Check and update the gpu availability.
349 | job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)
350 | while not available_gpus:
351 | # print(' Wait for the next available gpu...')
352 | time.sleep(_SLEEP_TIME)
353 | job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)
354 |
355 | # Launch a worker.
356 | job_list, available_gpus = _launch_worker(worker_folder, current_model_path, block_idx, args.resource_type,
357 | target_resource, current_iter,
358 | args.short_term_fine_tune_iteration, args.input_data_shape,
359 | job_list, available_gpus, args.lookup_table_path,
360 | args.dataset_path, args.arch)
361 | print('Update job list: ', job_list)
362 | print('Update available gpu:', available_gpus, '\n')
363 |
364 | # Wait until all the workers finish.
365 | job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)
366 | while job_list:
367 | time.sleep(_SLEEP_TIME)
368 | job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)
369 |
370 | # Find the best model.
371 | best_accuracy, best_model_path, best_resource, best_block = (
372 | _find_best_model(worker_folder, current_iter, network_utils.get_num_simplifiable_blocks(), current_accuracy,
373 | current_resource))
374 |
375 | # Check whether the target_resource is achieved.
376 | if not best_model_path:
377 | raise ValueError('target_resource {} is not achievable in iter {}.'.format(target_resource, current_iter))
378 | if best_resource > target_resource:
379 | warnMsg = "Iteration {}: target resource {} is not achieved. Current best resource is {}".format(current_iter, target_resource, best_resource)
380 | warnings.warn(warnMsg)
381 |
382 | # Update the variables.
383 | current_model_path = os.path.join(master_folder,
384 | common.MASTER_MODEL_FILENAME_TEMPLATE.format(current_iter))
385 | copyfile(best_model_path, current_model_path)
386 | current_accuracy = best_accuracy
387 | current_resource = best_resource
388 | current_block = best_block
389 |
390 | if args.save_interval == -1 or (current_iter % args.save_interval != 0):
391 | for block_idx in range(network_utils.get_num_simplifiable_blocks()):
392 | temp_model_path = os.path.join(worker_folder, common.WORKER_MODEL_FILENAME_TEMPLATE.format(current_iter, block_idx))
393 | os.remove(temp_model_path)
394 | print('Remove', temp_model_path)
395 | print(' ')
396 |
397 | # Save and print the history.
398 | model = torch.load(current_model_path)
399 | if type(model) is dict:
400 | model = model[_KEY_MODEL]
401 | network_def = network_utils.get_network_def_from_model(model)
402 | history[_KEY_HISTORY].append({_KEY_RESOURCE: current_resource,
403 | _KEY_SOURCE_MODEL_PATH: best_model_path,
404 | _KEY_ACCURACY: current_accuracy,
405 | _KEY_BLOCK: current_block,
406 | _KEY_NETWORK_DEF: network_def})
407 | _save_and_print_history(network_utils, history, history_pickle_file, history_text_file)
408 | del model, network_def
409 |
410 | current_iter += 1
411 |
412 | print('Finish iteration {}: time {}'.format(current_iter-1, time.time()-start_time))
413 |
414 |
415 | if __name__ == '__main__':
416 | # Parse the input arguments.
417 | arg_parser = ArgumentParser()
418 | arg_parser.add_argument('working_folder', type=str,
419 | help='Root folder where models, related files and history information are saved.')
420 | arg_parser.add_argument('input_data_shape', nargs=3, default=[3, 224, 224], type=int,
421 | help='Input data shape (C, H, W) (default: 3 224 224).')
422 | arg_parser.add_argument('-gp', '--gpus', nargs='+', default=[0], type=int,
423 | help='Indices of available gpus (default: 0).')
424 | arg_parser.add_argument('-re', '--resume', action='store_true',
425 | help='Resume from previous iteration. In order to resume, specify `--resume` and specify `working_folder` as the one you want to resume.')
426 | arg_parser.add_argument('-im', '--init_model_path',
427 | help='Path to pretrained model.')
428 | arg_parser.add_argument('-mi', '--max_iters', type=int, default=10,
429 | help='Maximum iteration of removing filters and short-term fine-tune (default: 10).')
430 | arg_parser.add_argument('-lr', '--finetune_lr', type=float, default=0.001,
431 | help='Short-term fine-tune learning rate (default: 0.001).')
432 |
433 | arg_parser.add_argument('-bu', '--budget', type=float, default=None,
434 | help='Resource constraint. If resource < `budget`, the process is terminated.')
435 | arg_parser.add_argument('-bur', '--budget_ratio', type=float, default=0.25,
436 |                         help='If `--budget` is not specified, `budget` = `budget_ratio`*(pretrained model resource) (default: 0.25).')
437 |
438 | arg_parser.add_argument('-rt', '--resource_type', type=str, default='FLOPS',
439 | help='Resource constraint type (default: FLOPS). We currently support `FLOPS`, `WEIGHTS`, and `LATENCY` (device cuda:0). If you want to add other resource types, please modify network_util.')
440 |
441 | arg_parser.add_argument('-ir', '--init_resource_reduction', type=float, default=None,
442 | help='For each iteration, target resource = current resource - `init_resource_reduction`*(`resource_reduction_decay`**(iteration-1)).')
443 | arg_parser.add_argument('-irr', '--init_resource_reduction_ratio', type=float, default=0.025,
444 | help='If `--init_resource_reduction` is not specified, `init_resource_reduction` = `init_resource_reduction_ratio`*(pretrained model resource) (default: 0.025).')
445 |
446 |
447 | arg_parser.add_argument('-rd', '--resource_reduction_decay', type=float, default=0.96,
448 | help='For each iteration, target resource = current resource - `init_resource_reduction`*(`resource_reduction_decay`**(iteration-1)) (default: 0.96).')
449 | arg_parser.add_argument('-st', '--short_term_fine_tune_iteration', type=int, default=10,
450 | help='Short-term fine-tune iteration (default: 10).')
451 |
452 | arg_parser.add_argument('-lt', '--lookup_table_path', type=str, default=None,
453 | help='Path to lookup table.')
454 | arg_parser.add_argument('-dp', '--dataset_path', type=str, default='',
455 | help='Path to dataset.')
456 |
457 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet',
458 | choices=network_utils_all,
459 | help='network_utils: ' +
460 | ' | '.join(network_utils_all) +
461 | ' (default: alexnet). Defines how networks are pruned, fine-tuned, and evaluated. If you want to use your own method, please specify here.')
462 |
463 | arg_parser.add_argument('-si', '--save_interval', type=int, default=-1,
464 | help='Interval of iterations that all pruned models at the same iteration will be saved. Use `-1` to save only the best model at each iteration. Use `1` to save all models at each iteration. (default: -1).')
465 |
466 | print(network_utils_all)
467 |
468 | args = arg_parser.parse_args()
469 |
470 | # Launch the master.
471 | print(args)
472 | master(args)
473 |
--------------------------------------------------------------------------------
/models/alexnet/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/helloworld/lut.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/models/helloworld/lut.pkl
--------------------------------------------------------------------------------
/models/helloworld/model_0.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/models/helloworld/model_0.pth.tar
--------------------------------------------------------------------------------
/models/mobilenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/models/mobilenet/__init__.py
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .alexnet import *
2 | from .mobilenet import *
3 | from .helloworld import *
--------------------------------------------------------------------------------
/nets/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | __all__ = ['AlexNet', 'alexnet']
7 |
8 |
9 | model_urls = {
10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
11 | }
12 |
13 |
14 | class AlexNet(nn.Module):
15 |
16 | def __init__(self):
17 | super(AlexNet, self).__init__()
18 | self.features = nn.Sequential(
19 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
20 | nn.ReLU(inplace=True),
21 | nn.MaxPool2d(kernel_size=3, stride=2),
22 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
23 | nn.ReLU(inplace=True),
24 | nn.MaxPool2d(kernel_size=3, stride=2),
25 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
26 | nn.ReLU(inplace=True),
27 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.MaxPool2d(kernel_size=3, stride=2),
32 | )
33 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
34 | self.classifier = nn.Sequential(
35 | nn.Dropout(),
36 | nn.Linear(256 * 6 * 6, 4096),
37 | nn.ReLU(inplace=True),
38 | nn.Dropout(),
39 | nn.Linear(4096, 4096),
40 | nn.ReLU(inplace=True),
41 | nn.Linear(4096, 1000),
42 | )
43 |
44 | def forward(self, x):
45 | x = self.features(x)
46 | x = self.avgpool(x)
47 | x = torch.flatten(x, 1)
48 | x = self.classifier(x)
49 | return x
50 |
51 |
52 | def alexnet(pretrained=False, progress=True, num_classes=1000):
53 | r"""AlexNet model architecture from the
54 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
55 | Args:
56 | pretrained (bool): If True, returns a model pre-trained on ImageNet
57 | progress (bool): If True, displays a progress bar of the download to stderr
58 | """
59 | model = AlexNet()
60 | if pretrained:
61 | state_dict = model_zoo.load_url(model_urls['alexnet'], progress=progress)
62 | model.load_state_dict(state_dict)
63 | if num_classes != 1000:
64 | num_in_feature = model.classifier[6].in_features
65 | model.classifier[6] = nn.Linear(num_in_feature, num_classes)
66 | return model
--------------------------------------------------------------------------------
/nets/helloworld.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = ['HelloWorld', 'helloworld']
6 |
7 | class HelloWorld(nn.Module):
8 |
9 | def __init__(self, num_classes=10):
10 | super(HelloWorld, self).__init__()
11 | self.features = nn.Sequential(
12 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
13 | nn.ReLU(inplace=True),
14 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1, bias=False)
19 | )
20 | self.avgpool = nn.AvgPool2d(32, 32)
21 |
22 | def forward(self, x):
23 | x = self.features(x)
24 | x = self.avgpool(x)
25 | x = x.view(x.shape[0], -1)
26 | return x
27 |
28 | def helloworld(num_classes=10):
29 |     return HelloWorld(num_classes=num_classes)
30 |
--------------------------------------------------------------------------------
/nets/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class MobileNet(nn.Module):
5 | def __init__(self, relu6=False):
6 | super(MobileNet, self).__init__()
7 |
8 | def relu(relu6):
9 | if relu6:
10 | return nn.ReLU6(inplace=True)
11 | else:
12 | return nn.ReLU(inplace=True)
13 |
14 | def conv_bn(inp, oup, stride, relu6):
15 | return nn.Sequential(
16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
17 | nn.BatchNorm2d(oup),
18 | relu(relu6),
19 | )
20 |
21 | def conv_dw(inp, oup, stride, relu6):
22 | return nn.Sequential(
23 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
24 | nn.BatchNorm2d(inp),
25 | relu(relu6),
26 |
27 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
28 | nn.BatchNorm2d(oup),
29 | relu(relu6),
30 | )
31 |
32 | self.model = nn.Sequential(
33 | conv_bn( 3, 32, 2, relu6),
34 | conv_dw( 32, 64, 1, relu6),
35 | conv_dw( 64, 128, 2, relu6),
36 | conv_dw(128, 128, 1, relu6),
37 | conv_dw(128, 256, 2, relu6),
38 | conv_dw(256, 256, 1, relu6),
39 | conv_dw(256, 512, 2, relu6),
40 | conv_dw(512, 512, 1, relu6),
41 | conv_dw(512, 512, 1, relu6),
42 | conv_dw(512, 512, 1, relu6),
43 | conv_dw(512, 512, 1, relu6),
44 | conv_dw(512, 512, 1, relu6),
45 | conv_dw(512, 1024, 2, relu6),
46 | conv_dw(1024, 1024, 1, relu6),
47 | nn.AvgPool2d(7),
48 | )
49 | self.fc = nn.Linear(1024, 1000)
50 |
51 | def forward(self, x):
52 | x = self.model(x)
53 | x = x.view(x.shape[0], -1)
54 | x = self.fc(x)
55 | return x
56 |
57 | def mobilenet(pretrained=False, progress=False, num_classes=1000):
58 | model = MobileNet()
59 | if pretrained:
60 | print('Cannot download pretrained model')
61 | if num_classes != 1000:
62 | num_in_feature = model.fc.in_features
63 | model.fc = nn.Linear(num_in_feature, num_classes)
64 | return model
--------------------------------------------------------------------------------
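The `conv_dw` blocks above are the standard depthwise-separable factorization: a 3x3 depthwise convolution followed by a 1x1 pointwise convolution. A small sketch, not part of the repo, of why this saves weights compared with a dense 3x3 convolution:

```python
import torch.nn as nn

inp, oup = 128, 256
standard = nn.Conv2d(inp, oup, 3, padding=1, bias=False)
separable = nn.Sequential(
    nn.Conv2d(inp, inp, 3, padding=1, groups=inp, bias=False),  # depthwise 3x3
    nn.Conv2d(inp, oup, 1, bias=False),                         # pointwise 1x1
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))   # 128 * 256 * 3 * 3 = 294912
print(count(separable))  # 128 * 3 * 3 + 128 * 256 = 33920
```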
/network_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .network_utils_abstract import *
2 | from .network_utils_alexnet import *
3 | from .network_utils_mobilenet import *
4 | from .network_utils_helloworld import *
--------------------------------------------------------------------------------
/network_utils/network_utils_abstract.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class NetworkUtilsAbstract(ABC):
5 |
6 | def __init__(self):
7 | super().__init__()
8 |
9 |
10 | @abstractmethod
11 | def get_network_def_from_model(self, model):
12 | '''
13 | network_def contains information about each layer within a model
14 |
15 | Input:
16 | `model`: pytorch model (e.g. nn.Sequential())
17 |
18 | Output:
19 | `network_def`: network_def contains layerwise information (e.g. num of output/input channels).
20 | network_def will be used to compute resource and to guide model pruning.
21 |
22 | please refer to def get_network_def_from_model() in functions.py
23 | to see one implementation.
24 | '''
25 | pass
26 |
27 |
28 | @abstractmethod
29 | def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type,
30 | lookup_table_path):
31 | '''
32 | Derive how much a certain block of layers ('block') should be simplified
33 | based on resource constraints.
34 |
35 | Input:
36 | `network_def`: defined in get_network_def_from_model()
37 | `constraint`: (float) representing the FLOPs/weights/latency constraint the simplified model should satisfy
38 | `resource_type`: (string) `FLOPs`, `WEIGHTS`, or `LATENCY`
39 | `lookup_table_path`: (string) path to latency lookup table. Needed only when resource_type == 'LATENCY'
40 |
41 | Output:
42 | `simplified_network_def`: simplified network definition. Indicates how much the network should
43 | be simplified/pruned.
44 | `simplified_resource`: (float) the estimated resource consumption of simplified network_def.
45 |
46 | please refer to def simplify_network_def_based_on_constraint(...) in functions.py
47 | to see one implementation.
48 | '''
49 | pass
50 |
51 |
52 | @abstractmethod
53 | def simplify_model_based_on_network_def(self, simplified_network_def, model):
54 | '''
55 | Choose which filters to preserve
56 |
57 | Input:
58 | `simplified_network_def`: network_def shows how a model will be pruned.
59 | defined in get_network_def_from_model().
60 | Get simplified_network_def from the output `simplified_network_def` of
61 | self.simplify_network_def_based_on_constraint()
62 |
63 | `model`: model to be simplified.
64 |
65 | Output:
66 | `simplified_model`: simplified model.
67 |
68 | please refer to def simplify_model_based_on_network_def(...) in functions.py
69 | to see one implementation
70 | '''
71 | pass
72 |
73 |
74 | @abstractmethod
75 | def extra_history_info(self, network_def):
76 | '''
77 | return # of output channels per layer
78 |
79 | Input:
80 | `network_def`: defined in get_network_def_from_model()
81 |
82 | Output:
83 | `num_filters_str`: (string) show the num of output channels for each layer.
84 | Or you can define your own log
85 | '''
86 | pass
87 |
88 |
89 | @abstractmethod
90 | def build_lookup_table(self, network_def, resource_type, lookup_table_path):
91 | '''
92 | Build lookup table for layers defined by `network_def`.
93 |
94 | Input:
95 | `network_def`: defined in get_network_def_from_model()
96 | `resource_type`: (string) resource type (e.g. 'LATENCY')
97 | `lookup_table_path`: (string) path to save the file of lookup table
98 | '''
99 | pass
100 |
101 |
102 | @abstractmethod
103 | def compute_resource(self, network_def, resource_type, lookup_table_path):
104 | '''
105 | compute resource based on resource type
106 |
107 | Input:
108 | `network_def`: defined in get_network_def_from_model()
109 | `resource_type`: (string) resource type (e.g. 'WEIGHTS'/'LATENCY'/'FLOPS')
110 | `lookup_table_path`: (string) path to lookup table
111 |
112 | Output:
113 | `resource`: (float)
114 | '''
115 | pass
116 |
117 |
118 | @abstractmethod
119 | def get_num_simplifiable_blocks(self):
120 | '''
121 | Output:
122 | `num_simplifiable_blocks`: (int) num of blocks whose num of output channels can be reduced.
123 | Note that simplifiable blocks do not include the output layer
124 | '''
125 | pass
126 |
127 |
128 | @abstractmethod
129 | def fine_tune(self, model, iterations):
130 | '''
131 | short-term fine-tune a simplified model
132 |
133 | Input:
134 | `model`: model to be fine-tuned
135 | `iterations`: (int) num of short-term fine-tune iterations
136 |
137 | Output:
138 | `model`: fine-tuned model
139 | '''
140 | pass
141 |
142 |
143 | @abstractmethod
144 | def evaluate(self, model):
145 | '''
146 | Evaluate the accuracy of the model
147 |
148 | Input:
149 | `model`: model to be evaluated
150 |
151 | Output:
152 | `accuracy`: (float) (0~100)
153 | '''
154 | pass
--------------------------------------------------------------------------------
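To adapt a new task or architecture, every abstract method above must be implemented by a subclass. A hypothetical skeleton, for illustration only (`networkUtils_mynet` is not part of the repo), assuming the repo root is on `PYTHONPATH`:

```python
from network_utils.network_utils_abstract import NetworkUtilsAbstract

class networkUtils_mynet(NetworkUtilsAbstract):
    def __init__(self, model, input_data_shape, dataset_path=None, finetune_lr=1e-3):
        super().__init__()
        self.input_data_shape = input_data_shape

    # Each method below must be filled in for the new task/architecture.
    def get_network_def_from_model(self, model): ...
    def simplify_network_def_based_on_constraint(self, network_def, block, constraint,
                                                 resource_type, lookup_table_path): ...
    def simplify_model_based_on_network_def(self, simplified_network_def, model): ...
    def extra_history_info(self, network_def): ...
    def build_lookup_table(self, network_def, resource_type, lookup_table_path): ...
    def compute_resource(self, network_def, resource_type, lookup_table_path): ...
    def get_num_simplifiable_blocks(self): ...
    def fine_tune(self, model, iterations): ...
    def evaluate(self, model): ...
```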
/network_utils/network_utils_alexnet.py:
--------------------------------------------------------------------------------
1 | from .network_utils_abstract import NetworkUtilsAbstract
2 | from collections import OrderedDict
3 | import os
4 | import sys
5 | import copy
6 | import time
7 | import torch
8 | import pickle
9 | import warnings
10 | import torch.nn as nn
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 | import torch.utils.data.sampler as sampler
14 |
15 | sys.path.append(os.path.abspath('../'))
16 |
17 | from constants import *
18 | import functions as fns
19 |
20 | '''
21 | This is an example of NetAdapt applied to AlexNet.
22 | We measure the latency on GPU.
23 | '''
24 |
25 | '''
26 | The channel dimensions of the feature maps of simplified layers are multiples of
27 | '_MIN_CONV_FEATURE_SIZE' (for conv layers) or '_MIN_FC_FEATURE_SIZE' (for FC layers).
28 | The reason is that on mobile devices, computing a (B, 7, H, W) tensor
29 | can take longer than computing a (B, 8, H, W) tensor.
30 | '''
31 | _MIN_CONV_FEATURE_SIZE = 8
32 | _MIN_FC_FEATURE_SIZE = 64
33 |
34 | '''
35 | How many times to run the forward function of a layer in order to get its latency.
36 | '''
37 | _MEASURE_LATENCY_SAMPLE_TIMES = 500
38 |
39 | '''
40 | The batch size of input data when running forward functions to measure latency.
41 | '''
42 | _MEASURE_LATENCY_BATCH_SIZE = 128
43 |
44 | class networkUtils_alexnet(NetworkUtilsAbstract):
45 | num_simplifiable_blocks = None
46 | input_data_shape = None
47 | train_loader = None
48 | holdout_loader = None
49 | val_loader = None
50 | optimizer = None
51 |
52 | def __init__(self, model, input_data_shape, dataset_path, finetune_lr=1e-3):
53 | '''
54 | Initialize:
55 | (1) network definition 'network_def'
56 | (2) num of simplifiable blocks 'num_simplifiable_blocks'.
57 | (3) loss function 'criterion'
58 | (4) data loader for training/validation set 'train_loader' and 'holdout_loader',
59 |
60 | Need to be implemented:
61 | (1) finetune/evaluation data loader
62 | (2) loss function
63 | (3) optimizer
64 |
65 | Input:
66 | `model`: model from which we will get network_def.
67 | `input_data_shape`: (list) [C, H, W].
68 | `dataset_path`: (string) path to dataset.
69 | `finetune_lr`: (float) short-term fine-tune learning rate.
70 | '''
71 |
72 | super().__init__()
73 |
74 | # Set the shape of the input data.
75 | self.input_data_shape = input_data_shape
76 | # Set network definition (conv & fc)
77 | network_def = self.get_network_def_from_model(model)
78 | # Set num_simplifiable_blocks.
79 | self.num_simplifiable_blocks = 0
80 | for layer_name, layer_properties in network_def.items():
81 | if not layer_properties[KEY_IS_DEPTHWISE]:
82 | self.num_simplifiable_blocks += 1
83 | # We cannot reduce the number of filters in the output layer,
84 | # so the output layer is not counted as a simplifiable block.
85 | self.num_simplifiable_blocks -= 1
86 |
87 | '''
88 | The following variables need to be defined depending on tasks:
89 | (1) finetune/evaluation data loader
90 | (2) loss function
91 | (3) optimizer
92 | '''
93 | # Data loaders for fine tuning and evaluation.
94 | self.batch_size = 128
95 | self.num_workers = 4
96 | self.momentum = 0.9
97 | self.weight_decay = 1e-4
98 | self.finetune_lr = finetune_lr
99 |
100 | train_dataset = datasets.CIFAR10(root=dataset_path, train=True, download=True,
101 | transform=transforms.Compose([
102 | transforms.RandomCrop(32, padding=4),
103 | transforms.Resize(224),
104 | transforms.RandomHorizontalFlip(),
105 | transforms.ToTensor(),
106 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
107 | ]))
108 |
109 | train_loader = torch.utils.data.DataLoader(
110 | train_dataset, batch_size=self.batch_size,
111 | num_workers=self.num_workers, pin_memory=True, shuffle=True)
112 | self.train_loader = train_loader
113 |
114 | val_dataset = datasets.CIFAR10(root=dataset_path, train=True, download=True,
115 | transform=transforms.Compose([
116 | transforms.Resize(224),
117 | transforms.ToTensor(),
118 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
119 | ]))
120 | val_loader = torch.utils.data.DataLoader(
121 | val_dataset, batch_size=self.batch_size, shuffle=False,
122 | num_workers=self.num_workers, pin_memory=True)
123 | self.val_loader = val_loader
124 |
125 | self.criterion = torch.nn.BCEWithLogitsLoss()
126 |
127 |
128 | def _get_layer_by_param_name(self, model, param_name):
129 | '''
130 | please refer to def get_layer_by_param_name(...) in functions.py
131 | '''
132 | return fns.get_layer_by_param_name(model, param_name)
133 |
134 |
135 | def _get_keys_from_ordered_dict(self, ordered_dict):
136 | '''
137 | please refer to def get_keys_from_ordered_dict(...) in functions.py
138 | '''
139 | return fns.get_keys_from_ordered_dict(ordered_dict)
140 |
141 |
142 | def get_network_def_from_model(self, model):
143 | '''
144 | please refer to get_network_def_from_model(...) in functions.py
145 | '''
146 | return fns.get_network_def_from_model(model, self.input_data_shape)
147 |
148 |
149 | def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type,
150 | lookup_table_path=None):
151 | '''
152 | Derive how much a certain block of layers ('block') should be simplified
153 | based on resource constraints.
154 |
155 | Here we treat one block as one layer although a block can contain several layers.
156 |
157 | Input:
158 | `network_def`: simplifiable network definition (conv & fc). Get network def from self.get_network_def_from_model(...)
159 | `block`: (int) index of block to simplify
160 | `constraint`: (float) representing the FLOPs/weights/latency constraint the simplified model should satisfy
161 | `resource_type`: `FLOPs`, `WEIGHTS`, or `LATENCY`
162 | `lookup_table_path`: (string) path to latency lookup table. Needed only when resource_type == 'LATENCY'
163 |
164 | Output:
165 | `simplified_network_def`: simplified network definition. Indicates how much the network should
166 | be simplified/pruned.
167 | `simplified_resource`: (float) the estimated resource consumption of simplified models.
168 | '''
169 |
170 | return fns.simplify_network_def_based_on_constraint(network_def, block, constraint,
171 | resource_type, lookup_table_path)
172 |
173 |
174 | def simplify_model_based_on_network_def(self, simplified_network_def, model):
175 | '''
176 | Choose which filters to preserve
177 |
178 | Here the filters with the largest L2 magnitude are kept
179 |
180 | please refer to def simplify_model_based_on_network_def(...) in functions.py
181 | '''
182 |
183 | return fns.simplify_model_based_on_network_def(simplified_network_def, model)
184 |
185 |
186 | def extra_history_info(self, network_def):
187 | '''
188 | return # of output channels per layer
189 |
190 | Input:
191 | `network_def`: (dict)
192 |
193 | Output:
194 | `num_filters_str`: (string) show the num of output channels for each layer
195 | '''
196 | num_filters_str = [str(layer_properties[KEY_NUM_OUT_CHANNELS]) for _, layer_properties in
197 | network_def.items()]
198 | num_filters_str = ' '.join(num_filters_str)
199 | return num_filters_str
200 |
201 |
202 | def _compute_weights_and_flops(self, network_def):
203 | '''
204 | please refer to def compute_weights_and_macs(...) in functions.py
205 | '''
206 | return fns.compute_weights_and_macs(network_def)
207 |
208 |
209 | def _compute_latency_from_lookup_table(self, network_def, lookup_table_path):
210 | '''
211 | please refer to def compute_latency_from_lookup_table(...) in functions.py
212 | '''
213 | return fns.compute_latency_from_lookup_table(network_def, lookup_table_path)
214 |
215 |
216 | def build_lookup_table(self, network_def_full, resource_type, lookup_table_path,
217 | min_conv_feature_size=_MIN_CONV_FEATURE_SIZE,
218 | min_fc_feature_size=_MIN_FC_FEATURE_SIZE,
219 | measure_latency_batch_size=_MEASURE_LATENCY_BATCH_SIZE,
220 | measure_latency_sample_times=_MEASURE_LATENCY_SAMPLE_TIMES,
221 | verbose=True):
222 | # Build lookup table for latency
223 | '''
224 | please refer to def build_latency_lookup_table(...) in functions.py
225 | '''
226 | return fns.build_latency_lookup_table(network_def_full, lookup_table_path,
227 | min_conv_feature_size=min_conv_feature_size,
228 | min_fc_feature_size=min_fc_feature_size,
229 | measure_latency_batch_size=measure_latency_batch_size,
230 | measure_latency_sample_times=measure_latency_sample_times,
231 | verbose=verbose)
232 |
233 |
234 | def compute_resource(self, network_def, resource_type, lookup_table_path=None):
235 | '''
236 | please refer to def compute_resource(...) in functions.py
237 | '''
238 | return fns.compute_resource(network_def, resource_type, lookup_table_path)
239 |
240 |
241 | def get_num_simplifiable_blocks(self):
242 | return self.num_simplifiable_blocks
243 |
244 |
245 | def fine_tune(self, model, iterations, print_frequency=100):
246 | '''
247 | short-term fine-tune a simplified model
248 |
249 | Input:
250 | `model`: model to be fine-tuned.
251 | `iterations`: (int) num of short-term fine-tune iterations.
252 | `print_frequency`: (int) how often to print fine-tune info.
253 |
254 | Output:
255 | `model`: fine-tuned model.
256 | '''
257 |
258 | _NUM_CLASSES = 10
259 | optimizer = torch.optim.SGD(model.parameters(), self.finetune_lr,
260 | momentum=self.momentum, weight_decay=self.weight_decay)
261 | model = model.cuda()
262 | model.train()
263 | dataloader_iter = iter(self.train_loader)
264 | for i in range(iterations):
265 | try:
266 | (input, target) = next(dataloader_iter)
267 | except StopIteration:
268 | dataloader_iter = iter(self.train_loader)
269 | (input, target) = next(dataloader_iter)
270 |
271 | if i % print_frequency == 0:
272 | print('Fine-tuning iteration {}'.format(i))
273 | sys.stdout.flush()
274 |
275 | target.unsqueeze_(1)
276 | target_onehot = torch.FloatTensor(target.shape[0], _NUM_CLASSES)
277 | target_onehot.zero_()
278 | target_onehot.scatter_(1, target, 1)
279 | target.squeeze_(1)
280 | input, target = input.cuda(), target.cuda()
281 | target_onehot = target_onehot.cuda()
282 |
283 | pred = model(input)
284 | loss = self.criterion(pred, target_onehot)
285 | optimizer.zero_grad()
286 | loss.backward() # compute gradient and do SGD step
287 | optimizer.step()
288 | return model
289 |
290 |
291 | def evaluate(self, model, print_frequency=10):
292 | '''
293 | Evaluate the accuracy of the model
294 |
295 | Input:
296 | `model`: model to be evaluated.
297 | `print_frequency`: how often to print evaluation info.
298 |
299 | Output:
300 | accuracy: (float) (0~100)
301 | '''
302 |
303 | model = model.cuda()
304 | model.eval()
305 | acc = .0
306 | num_samples = .0
307 | with torch.no_grad():
308 | for i, (input, target) in enumerate(self.val_loader):
309 | input, target = input.cuda(), target.cuda()
310 | pred = model(input)
311 | pred = pred.argmax(dim=1)
312 | batch_acc = torch.sum(target == pred)
313 | acc += batch_acc.item()
314 | num_samples += pred.shape[0]
315 |
316 | if i % print_frequency == 0:
317 | fns.update_progress(i, len(self.val_loader))
318 | print(' ')
319 | print(' ')
320 | print('Test accuracy: {:4.2f}% '.format(float(acc/num_samples*100)))
321 | print('===================================================================')
322 | return acc/num_samples*100
323 |
324 |
325 | def alexnet(model, input_data_shape, dataset_path, finetune_lr=1e-3):
326 | return networkUtils_alexnet(model, input_data_shape, dataset_path, finetune_lr)
--------------------------------------------------------------------------------
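Because `fine_tune()` above trains with `BCEWithLogitsLoss`, the integer labels are converted to one-hot targets before the loss is computed. A standalone sketch of that conversion, not part of the repo:

```python
import torch

_NUM_CLASSES = 10
target = torch.tensor([3, 7])                               # integer class labels
target_onehot = torch.zeros(target.shape[0], _NUM_CLASSES)  # (batch, num_classes)
target_onehot.scatter_(1, target.unsqueeze(1), 1)           # put a 1 at each label's column
print(target_onehot[0])  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
```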
/network_utils/network_utils_helloworld.py:
--------------------------------------------------------------------------------
1 | from .network_utils_abstract import NetworkUtilsAbstract
2 | import os
3 | import sys
4 | import torch
5 | import copy
6 | import pickle
7 | import warnings
8 | sys.path.append(os.path.abspath('../'))
9 | from constants import *
10 |
11 |
12 | class networkUtils_helloworld(NetworkUtilsAbstract):
13 |
14 | def __init__(self, model, input_data_shape, dataset_path=None, finetune_lr=1e-3):
15 | super().__init__()
16 | '''
17 | 4 conv layers:
18 | conv1: 3, 16
19 | conv2: 16, 32
20 | conv3: 32, 64
21 | conv4: 64, 10
22 |
23 | Input:
24 | `model`: model from which we will get network_def.
25 | `input_data_shape`: (list) [C, H, W].
26 | `dataset_path`: (string) path to dataset.
27 | `finetune_lr`: (float) short-term fine-tune learning rate.
28 | '''
29 |
30 | self.input_data_shape = input_data_shape
31 | self.lookup_table = None
32 |
33 |
34 | def get_network_def_from_model(self, model):
35 | '''
36 | return network def (list) of the input model containing layerwise info
37 |
38 | Input:
39 | `model`: model we will get network_def from
40 |
41 | Output:
42 | `network_def`: (list) each element corresponds to one layer and
43 | is a tuple (num_input_channels, num_output_channels)
44 | '''
45 | network_def = list()
46 | for idx in range(4):
47 | layer = getattr(model.features, str(idx * 2))
48 | network_def.append((layer.in_channels, layer.out_channels))
49 | return network_def
50 |
51 |
52 | def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type,
53 | lookup_table_path=None):
54 | '''
55 | Derive how much a certain block of layers ('block') should be simplified
56 | based on resource constraints.
57 |
58 | Input:
59 | `network_def`: (list) simplifiable network definition (conv).
60 | defined in self.get_network_def_from_model(...)
61 | `block`: (int) index of block to simplify
62 | `constraint`: (float) representing the FLOPs/weights constraint the simplified model should satisfy
63 | `resource_type`: (string) `FLOPS` or `WEIGHTS`
64 | `lookup_table_path`: (string) path to the lookup table. Here the lookup table stores FLOPs, and it is needed only when resource_type == 'FLOPS'
65 |
66 | Output:
67 | `simplified_network_def`: (list) simplified network_def whose resource is `simplified_resource`
68 | `simplified_resource`: (float) resource consumption of `simplified_network_def`
69 | '''
70 |
71 | assert block < self.get_num_simplifiable_blocks()
72 |
73 | # Determine the number of filters and the resource consumption.
74 | simplified_network_def = copy.deepcopy(network_def)
75 | simplified_resource = None
76 | return_with_constraint_satisfied = False
77 | num_out_channels_try = list(range(network_def[block][1], 0, -1))
78 |
79 | for current_num_out_channels in num_out_channels_try:
80 | simplified_network_def[block] = (simplified_network_def[block][0], current_num_out_channels)
81 | simplified_network_def[block+1] = (current_num_out_channels, simplified_network_def[block+1][1])
82 | simplified_resource = self.compute_resource(simplified_network_def, resource_type, lookup_table_path)
83 | if simplified_resource <= constraint:
84 | return_with_constraint_satisfied = True
85 | break
86 | if not return_with_constraint_satisfied:
87 | warnings.warn('Constraint not satisfied: constraint = {}, simplified_resource = {}'.format(constraint,
88 | simplified_resource))
89 | return simplified_network_def, simplified_resource
90 |
91 |
92 | def simplify_model_based_on_network_def(self, simplified_network_def, model):
93 | '''
94 | Choose which filters to preserve (here only the first `num_filters` filters will be preserved)
95 |
96 | Input:
97 | `simplified_network_def`: (list) network_def shows how each layer should be simplified.
98 | `model`: model to be simplified.
99 |
100 | Output:
101 | `simplified_model`: simplified model
102 | '''
103 | simplified_model = copy.deepcopy(model)
104 |
105 | for idx in range(self.get_num_simplifiable_blocks()):
106 | layer = getattr(simplified_model.features, str(idx * 2))
107 | num_filters = simplified_network_def[idx][1]
108 | # Here we keep the first `num_filters` weights
109 | # Not based on magnitude
110 |
111 | # update output channel weight
112 | setattr(layer, WEIGHTSTRING, torch.nn.Parameter(getattr(layer, WEIGHTSTRING)[0:num_filters, :, :, :]))
113 | layer.out_channels = num_filters
114 | # update input channel weight (next layer)
115 | layer = getattr(simplified_model.features, str((idx+1)*2))
116 | setattr(layer, WEIGHTSTRING, torch.nn.Parameter(getattr(layer, WEIGHTSTRING)[:, 0:num_filters, :, :]))
117 | layer.in_channels = num_filters
118 | return simplified_model
119 |
120 |
121 | def extra_history_info(self, network_def):
122 | '''
123 | Output num of channels layerwise
124 |
125 | Input:
126 | `network_def`: (list) defined in self.get_network_def_from_model()
127 |
128 | Output:
129 | `num_filters_str`: (string) representing num of output channels
130 | '''
131 | num_filters_str = []
132 | for layer_idx in range(len(network_def)):
133 | num_filters_str.append(str(network_def[layer_idx][1]))
134 | num_filters_str = ' '.join(num_filters_str)
135 | return num_filters_str
136 |
137 |
138 | def _compute_weights(self, network_def):
139 | '''
140 | Compute the number of parameters of a whole network.
141 | (considering only weights)
142 |
143 | Input:
144 | `network_def`: (list) defined in get_network_def_from_model()
145 |
146 | Output:
147 | `total_num_weights`: (float) num of weights
148 | '''
149 | total_num_weights = 0.0
150 | for layer_idx, layer_properties in enumerate(network_def):
151 | layer_num_weights = network_def[layer_idx][0] * network_def[layer_idx][1] * 3 * 3
152 | total_num_weights += layer_num_weights
153 | return total_num_weights
154 |
155 |
156 | def _compute_flops_from_lookup_table(self, network_def, lookup_table_path):
157 | # Note that it returns FLOPs
158 | '''
159 | Compute FLOPs from a lookup table.
160 |
161 | Although the number of FLOPs could be computed directly,
162 | we use a lookup table here to show how the NetAdapt framework uses lookup tables for resource estimation.
163 |
164 | Input:
165 | `network_def`: (list) defined in get_network_def_from_model()
166 | `lookup_table_path`: (string) path to lookup table
167 |
168 | Output:
169 | `resource`: (float) num of flops
170 | '''
171 | resource = 0
172 | if self.lookup_table is None:
173 | with open(lookup_table_path, 'rb') as file_id:
174 | self.lookup_table = pickle.load(file_id)
175 | for layer_idx in range(len(network_def)):
176 | if (network_def[layer_idx][0], network_def[layer_idx][1]) in self.lookup_table[layer_idx].keys():
177 | resource += self.lookup_table[layer_idx][(network_def[layer_idx][0], network_def[layer_idx][1])]
178 | return resource
179 |
180 |
181 | def build_lookup_table(self, network_def_full, resource_type, lookup_table_path):
182 | '''
183 | Build lookup table
184 | Here we only build a lookup table for FLOPs
185 |
186 | Input:
187 | `network_def_full`: (list) defined in get_network_def_from_model()
188 | `resource_type`: not used here; this example always builds a FLOPs lookup table
189 | `lookup_table_path`: (string) path to save lookup table
190 | '''
191 |
192 | lookup_table = []
193 | print("Building lookup table.")
194 | for i in range(4):
195 | feature_map_resource = dict()
196 | for num_in_channels in range(network_def_full[i][0], 0, -1):
197 | for num_out_channels in range(network_def_full[i][1], 0, -1):
198 | feature_map_resource[(num_in_channels, num_out_channels)] = num_in_channels*num_out_channels*32*32*3*3
199 | lookup_table.append(feature_map_resource)
200 | with open(lookup_table_path, 'wb') as file_id:
201 | pickle.dump(lookup_table, file_id)
202 | return
203 |
204 |
205 | def compute_resource(self, network_def, resource_type, lookup_table_path=None):
206 | '''
207 | Input:
208 | `network_def`: (list) defined in get_network_def_from_model()
209 | `resource_type`: (string) 'FLOPS'/'WEIGHTS'
210 | `lookup_table_path`: (string) path to lookup table
211 |
212 | Output:
213 | resource: (float) num of flops or weights
214 | '''
215 | if resource_type == 'FLOPS':
216 | return self._compute_flops_from_lookup_table(network_def, lookup_table_path)
217 | else:
218 | return self._compute_weights(network_def)
219 |
220 |
221 | def get_num_simplifiable_blocks(self):
222 | '''
223 | 4 conv layers
224 |
225 | the # of output channels of the last layer is not reducible
226 | '''
227 | return 3
228 |
229 |
230 | def fine_tune(self, model, iterations, print_frequency=100):
231 | '''
232 | this example does not actually fine-tune
233 |
234 | specify a data loader, loss function, and optimizer to customize fine-tuning for your own task
235 |
236 | Input:
237 | `model`: model whose weights will be modified
238 | `iterations`: (int) num of iteration to change model weights
239 | `print_frequency`: (int) how often to print log info
240 |
241 | Output:
242 | `finetune_model`: model whose weights have been modified
243 | '''
244 |
245 | finetune_model = copy.deepcopy(model)
246 | for i in range(iterations):
247 | for idx in range(4):
248 | layer = getattr(finetune_model.features, str(idx * 2))
249 | layer.weight.data = layer.weight.data + idx
250 | return finetune_model
251 |
252 |
253 | def evaluate(self, model):
254 | '''
255 | for simplicity, we return a value determined solely by the network architecture
256 |
257 | define a real evaluation function for your own task
258 |
259 | Input:
260 | `model`: model whose architecture will determine the output value
261 |
262 | Output:
263 | `acc`: (int) value that depends on the input model architecture
264 | '''
265 |
266 | network_def = self.get_network_def_from_model(model)
267 | acc = 0
268 | if network_def[0][1] != 16 and network_def[1][1] == 32 and network_def[2][1] == 64:
269 | acc = 1
270 | elif network_def[0][1] == 16 and network_def[1][1] != 32 and network_def[2][1] == 64:
271 | acc = 5
272 | elif network_def[0][1] == 16 and network_def[1][1] == 32 and network_def[2][1] != 64:
273 | acc = 80
274 | elif network_def[0][1] != 16 and network_def[1][1] == 32 and network_def[2][1] != 64:
275 | acc = 85
276 | elif network_def[0][1] == 16 and network_def[1][1] != 32 and network_def[2][1] != 64:
277 | acc = 10
278 | elif network_def[0][1] != 16 and network_def[1][1] != 32 and network_def[2][1] == 64:
279 | acc = 12
280 | elif network_def[0][1] != 16 and network_def[1][1] != 32 and network_def[2][1] != 64:
281 | acc = 90
282 | else:
283 | acc = 95
284 |
285 | return acc
286 |
287 |
288 | def helloworld(model, input_data_shape, dataset_path=None, finetune_lr=1e-3):
289 | return networkUtils_helloworld(model, input_data_shape, dataset_path, finetune_lr)
--------------------------------------------------------------------------------
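The lookup table built by `build_lookup_table()` above is simply a list with one dict per conv layer, keyed by `(num_in_channels, num_out_channels)` and valued with `in * out * 32 * 32 * 3 * 3` FLOPs (3x3 kernels on 32x32 feature maps). A minimal sketch, not part of the repo, of building and querying it the way `_compute_flops_from_lookup_table()` does:

```python
network_def = [(3, 16), (16, 32), (32, 64), (64, 10)]  # helloworld network_def

lookup_table = []
for in_c, out_c in network_def:
    # One dict per layer covering every possible (in, out) channel pair.
    lookup_table.append({(i, o): i * o * 32 * 32 * 3 * 3
                         for i in range(in_c, 0, -1)
                         for o in range(out_c, 0, -1)})

# Querying the table for the unpruned network:
flops = sum(lookup_table[idx][(i, o)] for idx, (i, o) in enumerate(network_def))
print(flops)  # 29933568 FLOPs
```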
/network_utils/network_utils_mobilenet.py:
--------------------------------------------------------------------------------
1 | from .network_utils_abstract import NetworkUtilsAbstract
2 | from collections import OrderedDict
3 | import os
4 | import sys
5 | import copy
6 | import time
7 | import torch
8 | import pickle
9 | import warnings
10 | import torch.nn as nn
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 | import torch.utils.data.sampler as sampler
14 |
15 | sys.path.append(os.path.abspath('../'))
16 |
17 | from constants import *
18 | import functions as fns
19 |
20 | '''
21 | This is an example of NetAdapt applied to MobileNet.
22 | We measure the latency on GPU.
23 | '''
24 |
25 | '''
26 | The channel dimensions of the feature maps of simplified layers are multiples of
27 | '_MIN_CONV_FEATURE_SIZE' (for conv layers) or '_MIN_FC_FEATURE_SIZE' (for FC layers).
28 | The reason is that on mobile devices, computing a (B, 7, H, W) tensor
29 | can take longer than computing a (B, 8, H, W) tensor.
30 | '''
31 | _MIN_CONV_FEATURE_SIZE = 8
32 | _MIN_FC_FEATURE_SIZE = 64
33 |
34 | '''
35 | How many times to run the forward function of a layer in order to get its latency.
36 | '''
37 | _MEASURE_LATENCY_SAMPLE_TIMES = 500
38 |
39 | '''
40 | The batch size of input data when running forward functions to measure latency.
41 | '''
42 | _MEASURE_LATENCY_BATCH_SIZE = 128
43 |
44 | class networkUtils_mobilenet(NetworkUtilsAbstract):
45 | num_simplifiable_blocks = None
46 | input_data_shape = None
47 | train_loader = None
48 | holdout_loader = None
49 | val_loader = None
50 | optimizer = None
51 |
52 | def __init__(self, model, input_data_shape, dataset_path, finetune_lr=1e-3):
53 | '''
54 | Initialize:
55 | (1) network definition 'network_def'
56 | (2) num of simplifiable blocks 'num_simplifiable_blocks'.
57 | (3) loss function 'criterion'
58 | (4) data loader for training/validation set 'train_loader' and 'holdout_loader',
59 | (5) optimizer 'optimizer'
60 |
61 | Need to be implemented:
62 | (1) finetune/evaluation data loader
63 | (2) loss function
64 | (3) optimizer
65 |
66 | Input:
67 | `model`: model from which we will get network_def.
68 | `input_data_shape`: (list) [C, H, W].
69 | `dataset_path`: (string) path to dataset.
70 | `finetune_lr`: (float) short-term fine-tune learning rate.
71 | '''
72 |
73 | super().__init__()
74 |
75 | # Set the shape of the input data.
76 | self.input_data_shape = input_data_shape
77 | # Set network definition (conv & fc)
78 | network_def = self.get_network_def_from_model(model)
79 | # Set num_simplifiable_blocks.
80 | self.num_simplifiable_blocks = 0
81 | for layer_name, layer_properties in network_def.items():
82 | if not layer_properties[KEY_IS_DEPTHWISE]:
83 | self.num_simplifiable_blocks += 1
84 | # We cannot reduce the number of filters in the output layer.
85 | self.num_simplifiable_blocks -= 1
86 |
87 | '''
88 | The following variables need to be defined depending on tasks:
89 | (1) finetune/evaluation data loader
90 | (2) loss function
91 | (3) optimizer
92 | '''
93 | # Data loaders for fine tuning and evaluation.
94 | self.batch_size = 128
95 | self.num_workers = 4
96 | self.momentum = 0.9
97 | self.weight_decay = 1e-4
98 | self.finetune_lr = finetune_lr
99 |
100 | train_dataset = datasets.CIFAR10(root=dataset_path, train=True, download=True,
101 | transform=transforms.Compose([
102 | transforms.RandomCrop(32, padding=4),
103 | transforms.Resize(224),
104 | transforms.RandomHorizontalFlip(),
105 | transforms.ToTensor(),
106 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
107 | ]))
108 |
109 | train_loader = torch.utils.data.DataLoader(
110 | train_dataset, batch_size=self.batch_size,
111 | num_workers=self.num_workers, pin_memory=True, shuffle=True)
112 | self.train_loader = train_loader
113 |
114 | val_dataset = datasets.CIFAR10(root=dataset_path, train=True, download=True,
115 | transform=transforms.Compose([
116 | transforms.Resize(224),
117 | transforms.ToTensor(),
118 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
119 | ]))
120 | val_loader = torch.utils.data.DataLoader(
121 | val_dataset, batch_size=self.batch_size, shuffle=False,
122 | num_workers=self.num_workers, pin_memory=True)
123 | self.val_loader = val_loader
124 |
125 | self.criterion = torch.nn.BCEWithLogitsLoss()
126 | self.optimizer = torch.optim.SGD(model.parameters(),
127 | finetune_lr, momentum=self.momentum, weight_decay=self.weight_decay)
128 |
129 |
130 | def _get_layer_by_param_name(self, model, param_name):
131 | '''
132 | please refer to def get_layer_by_param_name(...) in functions.py
133 | '''
134 | return fns.get_layer_by_param_name(model, param_name)
135 |
136 |
137 | def _get_keys_from_ordered_dict(self, ordered_dict):
138 | '''
139 | please refer to def get_keys_from_ordered_dict(...) in functions.py
140 | '''
141 | return fns.get_keys_from_ordered_dict(ordered_dict)
142 |
143 |
144 | def get_network_def_from_model(self, model):
145 | '''
146 | please refer to get_network_def_from_model(...) in functions.py
147 | '''
148 | return fns.get_network_def_from_model(model, self.input_data_shape)
149 |
150 |
151 | def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type,
152 | lookup_table_path=None):
153 | '''
154 | Derive how much a certain block of layers ('block') should be simplified
155 | based on resource constraints.
156 |
157 | Here we treat one block as one layer although a block can contain several layers.
158 |
159 | Input:
160 | `network_def`: simplifiable network definition (conv & fc). Get network def from self.get_network_def_from_model(...)
161 | `block`: (int) index of block to simplify
162 | `constraint`: (float) representing the FLOPs/weights/latency constraint the simplified model should satisfy
163 | `resource_type`: `FLOPs`, `WEIGHTS`, or `LATENCY`
164 | `lookup_table_path`: (string) path to latency lookup table. Needed only when resource_type == 'LATENCY'
165 |
166 | Output:
167 | `simplified_network_def`: simplified network definition. Indicates how much the network should
168 | be simplified/pruned.
169 | `simplified_resource`: (float) the estimated resource consumption of simplified models.
170 | '''
171 | return fns.simplify_network_def_based_on_constraint(network_def, block, constraint,
172 | resource_type, lookup_table_path)
173 |
174 |
175 | def simplify_model_based_on_network_def(self, simplified_network_def, model):
176 | '''
177 | Choose which filters to preserve
178 |
179 | Here the filters with the largest L2 magnitude are kept
180 |
181 | please refer to def simplify_model_based_on_network_def(...) in functions.py
182 | '''
183 | return fns.simplify_model_based_on_network_def(simplified_network_def, model)
184 |
185 |
186 | def extra_history_info(self, network_def):
187 | '''
188 | return # of output channels per layer
189 |
190 | Input:
191 | `network_def`: (dict)
192 |
193 | Output:
194 | `num_filters_str`: (string) show the num of output channels for each layer
195 | '''
196 | num_filters_str = [str(layer_properties[KEY_NUM_OUT_CHANNELS]) for _, layer_properties in
197 | network_def.items()]
198 | num_filters_str = ' '.join(num_filters_str)
199 | return num_filters_str
200 |
201 |
202 | def _compute_weights_and_flops(self, network_def):
203 | '''
204 | please refer to def compute_weights_and_macs(...) in functions.py
205 | '''
206 | return fns.compute_weights_and_macs(network_def)
207 |
208 |
209 | def _compute_latency_from_lookup_table(self, network_def, lookup_table_path):
210 | '''
211 | please refer to def compute_latency_from_lookup_table(...) in functions.py
212 | '''
213 | return fns.compute_latency_from_lookup_table(network_def, lookup_table_path)
214 |
215 |
216 | def build_lookup_table(self, network_def_full, resource_type, lookup_table_path,
217 | min_conv_feature_size=_MIN_CONV_FEATURE_SIZE,
218 | min_fc_feature_size=_MIN_FC_FEATURE_SIZE,
219 | measure_latency_batch_size=_MEASURE_LATENCY_BATCH_SIZE,
220 | measure_latency_sample_times=_MEASURE_LATENCY_SAMPLE_TIMES,
221 | verbose=True):
222 | # Build lookup table for latency
223 | '''
224 | please refer to def build_latency_lookup_table(...) in functions.py
225 | '''
226 | return fns.build_latency_lookup_table(network_def_full, lookup_table_path,
227 | min_conv_feature_size=min_conv_feature_size,
228 | min_fc_feature_size=min_fc_feature_size,
229 | measure_latency_batch_size=measure_latency_batch_size,
230 | measure_latency_sample_times=measure_latency_sample_times,
231 | verbose=verbose)
232 |
233 |
234 | def compute_resource(self, network_def, resource_type, lookup_table_path=None):
235 | '''
236 | please refer to def compute_resource(...) in functions.py
237 | '''
238 | return fns.compute_resource(network_def, resource_type, lookup_table_path)
239 |
240 |
241 | def get_num_simplifiable_blocks(self):
242 | return self.num_simplifiable_blocks
243 |
244 |
245 | def fine_tune(self, model, iterations, print_frequency=100):
246 | '''
247 | short-term fine-tune a simplified model
248 |
249 | Input:
250 | `model`: model to be fine-tuned.
251 | `iterations`: (int) num of short-term fine-tune iterations.
252 | `print_frequency`: (int) how often to print fine-tune info.
253 |
254 | Output:
255 | `model`: fine-tuned model.
256 | '''
257 |
258 | _NUM_CLASSES = 10
259 | optimizer = torch.optim.SGD(model.parameters(), self.finetune_lr,
260 | momentum=self.momentum, weight_decay=self.weight_decay)
261 | model = model.cuda()
262 | model.train()
263 | dataloader_iter = iter(self.train_loader)
264 | for i in range(iterations):
265 | try:
266 | (input, target) = next(dataloader_iter)
267 | except StopIteration:
268 | dataloader_iter = iter(self.train_loader)
269 | (input, target) = next(dataloader_iter)
270 |
271 | if i % print_frequency == 0:
272 | print('Fine-tuning iteration {}'.format(i))
273 | sys.stdout.flush()
274 |
275 | target.unsqueeze_(1)
276 | target_onehot = torch.FloatTensor(target.shape[0], _NUM_CLASSES)
277 | target_onehot.zero_()
278 | target_onehot.scatter_(1, target, 1)
279 | target.squeeze_(1)
280 | input, target = input.cuda(), target.cuda()
281 | target_onehot = target_onehot.cuda()
282 |
283 | pred = model(input)
284 | loss = self.criterion(pred, target_onehot)
285 | optimizer.zero_grad()
286 | loss.backward() # compute gradient and do SGD step
287 | optimizer.step()
288 | return model
289 |
290 |
291 | def evaluate(self, model, print_frequency=10):
292 | '''
293 | Evaluate the accuracy of the model
294 |
295 | Input:
296 | `model`: model to be evaluated.
297 | `print_frequency`: how often to print evaluation info.
298 |
299 | Output:
300 | accuracy: (float) (0~100)
301 | '''
302 |
303 | model = model.cuda()
304 | model.eval()
305 | acc = .0
306 | num_samples = .0
307 | with torch.no_grad():
308 | for i, (input, target) in enumerate(self.val_loader):
309 | input, target = input.cuda(), target.cuda()
310 | pred = model(input)
311 | pred = pred.argmax(dim=1)
312 | batch_acc = torch.sum(target == pred)
313 | acc += batch_acc.item()
314 | num_samples += pred.shape[0]
315 |
316 | if i % print_frequency == 0:
317 | fns.update_progress(i, len(self.val_loader))
318 | print(' ')
319 | print(' ')
320 | print('Test accuracy: {:4.2f}% '.format(float(acc/num_samples*100)))
321 | print('===================================================================')
322 | return acc/num_samples*100
323 |
324 |
325 | def mobilenet(model, input_data_shape, dataset_path, finetune_lr=1e-3):
326 | return networkUtils_mobilenet(model, input_data_shape, dataset_path, finetune_lr)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch == 1.2.0
2 | torchvision == 0.4.0
3 | numpy == 1.17.0
4 | scipy == 1.3.1
--------------------------------------------------------------------------------
/scripts/netadapt_alexnet-0.5latency.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python master.py models/alexnet/prune-by-latency 3 224 224 \
2 | -im models/alexnet/model.pth.tar -gp 0 1 2 3 4 5 6 \
3 | -mi 30 -bur 0.25 -rt LATENCY -irr 0.025 -rd 0.96 \
4 | -lr 0.001 -st 500 -lt latency_lut/lut_alexnet.pkl \
5 | -dp data/ --arch alexnet
--------------------------------------------------------------------------------
/scripts/netadapt_alexnet-0.5mac.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python master.py models/alexnet/prune-by-mac 3 224 224 \
2 | -im models/alexnet/model.pth.tar -gp 0 1 2 3 4 5 6 \
3 | -mi 30 -bur 0.25 -rt FLOPS -irr 0.025 -rd 0.96 \
4 | -lr 0.001 -st 500 \
5 | -dp data/ --arch alexnet
--------------------------------------------------------------------------------
/scripts/netadapt_helloworld.sh:
--------------------------------------------------------------------------------
1 | python master.py models/helloworld/netadapt 3 32 32 \
2 | -gp 0 1 2 -mi 3 -bur 0.25 -rt FLOPS \
3 | -irr 0.025 -rd 1.0 -lr 0.001 -st 5 \
4 | -im models/helloworld/model_0.pth.tar \
5 | -lt models/helloworld/lut.pkl -dp data/ \
6 | --arch helloworld -si 1
--------------------------------------------------------------------------------
/scripts/netadapt_mobilenet-0.5latency.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python master.py models/mobilenet/prune-by-latency 3 224 224 \
2 | -im models/mobilenet/model.pth.tar -gp 0 1 2 3 4 5 6 \
3 | -mi 30 -bur 0.25 -rt LATENCY -irr 0.025 -rd 0.96 \
4 | -lr 0.001 -st 500 -lt latency_lut/lut_mobilenet.pkl \
5 | -dp data/ --arch mobilenet
--------------------------------------------------------------------------------
/scripts/netadapt_mobilenet-0.5mac.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python master.py models/mobilenet/prune-by-mac 3 224 224 \
2 | -im models/mobilenet/model.pth.tar -gp 0 1 2 3 4 5 6 \
3 | -mi 30 -bur 0.25 -rt FLOPS -irr 0.025 -rd 0.96 \
4 | -lr 0.001 -st 500 \
5 | -dp data/ --arch mobilenet
--------------------------------------------------------------------------------
/scripts/unittest.sh:
--------------------------------------------------------------------------------
1 | cd unittest
2 | python unittest_network_utils_helloworld.py
3 | python unittest_network_utils_alexnet.py
4 | python unittest_network_utils_mobilenet.py
5 | python unittest_worker_helloworld.py
6 | cp unittest_master_helloworld.py ../unittest_master_helloworld.py
7 | cd ..
8 | python unittest_master_helloworld.py
9 | rm unittest_master_helloworld.py
--------------------------------------------------------------------------------
/scripts/unittest_alexnet.sh:
--------------------------------------------------------------------------------
1 | cd unittest
2 | python unittest_network_utils_alexnet.py
3 |
--------------------------------------------------------------------------------
/scripts/unittest_helloworld.sh:
--------------------------------------------------------------------------------
1 | cd unittest
2 | python unittest_network_utils_helloworld.py
3 | python unittest_worker_helloworld.py
4 | cp unittest_master_helloworld.py ../unittest_master_helloworld.py
5 | cd ..
6 | python unittest_master_helloworld.py
7 | rm unittest_master_helloworld.py
--------------------------------------------------------------------------------
/scripts/unittest_mobilenet.sh:
--------------------------------------------------------------------------------
1 | cd unittest
2 | python unittest_network_utils_mobilenet.py
3 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | import time
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as datasets
9 | import torch.backends.cudnn as cudnn
10 |
11 | import nets as models
12 | import functions as fns
13 |
14 | _NUM_CLASSES = 10
15 |
16 | model_names = sorted(name for name in models.__dict__
17 | if name.islower() and not name.startswith("__")
18 | and callable(models.__dict__[name]))
19 |
20 |
21 | def adjust_learning_rate(optimizer, epoch, args):
22 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
23 | lr = args.lr * (0.1 ** (epoch // 50))
24 | for param_group in optimizer.param_groups:
25 | param_group['lr'] = lr
26 |
27 |
28 | class AverageMeter(object):
29 | """Computes and stores the average and current value"""
30 | def __init__(self):
31 | self.reset()
32 |
33 | def reset(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 |
39 | def get_avg(self):
40 | return self.avg
41 |
42 | def update(self, val, n=1):
43 | self.val = val
44 | self.sum += val * n
45 | self.count += n
46 | self.avg = self.sum / self.count
47 |
48 |
49 | def compute_accuracy(output, target):
50 | output = output.argmax(dim=1)
51 | acc = 0.0
52 | acc = torch.sum(target == output).item()
53 | acc = acc/output.size(0)*100
54 | return acc
55 |
56 |
57 | def train(train_loader, model, criterion, optimizer, epoch, args):
58 | batch_time = AverageMeter()
59 | losses = AverageMeter()
60 | acc = AverageMeter()
61 |
62 | # switch to train mode
63 | model.train()
64 |
65 | print('===================================================================')
66 | end = time.time()
67 |
68 | for i, (images, target) in enumerate(train_loader):
69 | target.unsqueeze_(1)
70 | target_onehot = torch.FloatTensor(target.shape[0], _NUM_CLASSES)
71 | target_onehot.zero_()
72 | target_onehot.scatter_(1, target, 1)
73 | target.squeeze_(1)
74 |
75 | if not args.no_cuda:
76 | images = images.cuda()
77 | target_onehot = target_onehot.cuda()
78 | target = target.cuda()
79 |
80 | # compute output and loss
81 | output = model(images)
82 | loss = criterion(output, target_onehot)
83 |
84 | # measure accuracy and record loss
85 | batch_acc = compute_accuracy(output, target)
86 |
87 | losses.update(loss.item(), images.size(0))
88 | acc.update(batch_acc, images.size(0))
89 |
90 | # compute gradient and do SGD step
91 | optimizer.zero_grad()
92 | loss.backward()
93 | optimizer.step()
94 |
95 | # measure elapsed time
96 | batch_time.update(time.time() - end)
97 | end = time.time()
98 |
99 | # Update statistics
100 | estimated_time_remained = batch_time.get_avg()*(len(train_loader)-i-1)
101 | fns.update_progress(i, len(train_loader),
102 | ESA='{:8.2f}'.format(estimated_time_remained)+'s',
103 | loss='{:4.2f}'.format(loss.item()),
104 | acc='{:4.2f}%'.format(float(batch_acc))
105 | )
106 |
107 | print()
108 | print('Finish epoch {}: time = {:8.2f}s, loss = {:4.2f}, acc = {:4.2f}%'.format(
109 | epoch+1, batch_time.get_avg()*len(train_loader),
110 | float(losses.get_avg()), float(acc.get_avg())))
111 | print('===================================================================')
112 | return
113 |
114 |
115 | def eval(test_loader, model, args):
116 | batch_time = AverageMeter()
117 | acc = AverageMeter()
118 |
119 | # switch to eval mode
120 | model.eval()
121 |
122 | end = time.time()
123 | for i, (images, target) in enumerate(test_loader):
124 | if not args.no_cuda:
125 | images = images.cuda()
126 | target = target.cuda()
127 | output = model(images)
128 | batch_acc = compute_accuracy(output, target)
129 | acc.update(batch_acc, images.size(0))
130 | batch_time.update(time.time() - end)
131 | end = time.time()
132 |
133 | # Update statistics
134 | estimated_time_remained = batch_time.get_avg()*(len(test_loader)-i-1)
135 | fns.update_progress(i, len(test_loader),
136 | ESA='{:8.2f}'.format(estimated_time_remained)+'s',
137 | acc='{:4.2f}'.format(float(batch_acc))
138 | )
139 | print()
140 | print('Test accuracy: {:4.2f}% (time = {:8.2f}s)'.format(
141 | float(acc.get_avg()), batch_time.get_avg()*len(test_loader)))
142 | print('===================================================================')
143 | return float(acc.get_avg())
144 |
145 |
146 | if __name__ == '__main__':
147 | # Parse the input arguments.
148 | arg_parser = ArgumentParser()
149 | arg_parser.add_argument('data', metavar='DIR', help='path to dataset')
150 | arg_parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
151 | help='number of data loading workers (default: 4)')
152 | arg_parser.add_argument('--epochs', default=150, type=int, metavar='N',
153 | help='number of total epochs to run (default: 150)')
154 | arg_parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
155 | help='manual epoch number (useful on restarts)')
156 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet',
157 | choices=model_names,
158 | help='model architecture: ' +
159 | ' | '.join(model_names) +
160 | ' (default: alexnet)')
161 | arg_parser.add_argument('-b', '--batch-size', default=128, type=int,
162 | metavar='N',
163 | help='batch size (default: 128)')
164 | arg_parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
165 | metavar='LR', help='initial learning rate (default: 0.1)', dest='lr')
166 | arg_parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
167 | help='momentum (default: 0.9)')
168 | arg_parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
169 | metavar='W', help='weight decay (default: 5e-4)',
170 | dest='weight_decay')
171 | arg_parser.add_argument('--resume', default='', type=str, metavar='PATH',
172 | help='path to latest checkpoint (default: none)')
173 | arg_parser.add_argument('--dir', type=str, default='models/', dest='save_dir',
174 | help='path to save models (default: models/)')
175 | arg_parser.add_argument('--no-cuda', action='store_true', default=False, dest='no_cuda',
176 | help='disables training on GPU')
177 | args = arg_parser.parse_args()
178 | print(args)
179 |
180 | path = os.path.dirname(args.save_dir)
181 | if not os.path.exists(path):
182 | os.makedirs(path)
183 | print('Create new directory `{}`'.format(path))
184 |
185 | # Data loader
186 | train_dataset = datasets.CIFAR10(root=args.data, train=True, download=True,
187 | transform=transforms.Compose([
188 | transforms.RandomCrop(32, padding=4),
189 | transforms.Resize(224),
190 | transforms.RandomHorizontalFlip(),
191 | transforms.ToTensor(),
192 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
193 | ]))
194 | train_loader = torch.utils.data.DataLoader(
195 | train_dataset, batch_size=args.batch_size, shuffle=True,
196 | num_workers=args.workers, pin_memory=True)
197 | test_dataset = datasets.CIFAR10(root=args.data, train=False, download=True,
198 | transform=transforms.Compose([
199 | transforms.Resize(224),
200 | transforms.ToTensor(),
201 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
202 | ]))
203 | test_loader = torch.utils.data.DataLoader(
204 | test_dataset, batch_size=args.batch_size, shuffle=False,
205 | num_workers=args.workers, pin_memory=True)
206 |
207 | # Network
208 | cudnn.benchmark = True
209 | num_classes = _NUM_CLASSES
210 | model_arch = args.arch
211 | model = models.__dict__[model_arch](num_classes=num_classes)
212 | criterion = nn.BCEWithLogitsLoss()
213 | if not args.no_cuda:
214 | model = model.cuda()
215 | criterion = criterion.cuda()
216 |
217 | # optionally resume from a checkpoint
218 | if args.resume:
219 | if os.path.isfile(args.resume):
220 | print("Loading checkpoint '{}'".format(args.resume))
221 | model = torch.load(args.resume)
222 |
223 | else:
224 | print("No checkpoint found at '{}'".format(args.resume))
225 |
226 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
227 | momentum=args.momentum,
228 | weight_decay=args.weight_decay)
229 |
230 | # Train & evaluation
231 | best_acc = 0
232 | filename = os.path.join(args.save_dir)
233 |
234 | for epoch in range(args.start_epoch, args.epochs):
235 | print('Epoch [{}/{}]'.format(epoch+1, args.epochs - args.start_epoch))
236 | adjust_learning_rate(optimizer, epoch, args)
237 | # train for one epoch
238 | train(train_loader, model, criterion, optimizer, epoch, args)
239 | acc = eval(test_loader, model, args)
240 |
241 | if acc > best_acc:
242 | torch.save(model, filename)
243 | best_acc = acc
244 | print('Save model: ' + filename)
245 | print(' ')
246 | print('Best accuracy:', best_acc)
247 |
248 | model = torch.load(filename)
249 | print(model)
250 |
251 | best_acc = eval(test_loader, model, args)
252 | print('Best accuracy:', best_acc)
--------------------------------------------------------------------------------
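A quick sketch, not part of the repo, of the step schedule implemented by `adjust_learning_rate()` above: the initial learning rate is divided by 10 every 50 epochs.

```python
base_lr = 0.1
for epoch in (0, 49, 50, 99, 100, 149):
    lr = base_lr * (0.1 ** (epoch // 50))
    print(epoch, round(lr, 6))  # 0.1 for epochs 0-49, 0.01 for 50-99, 0.001 for 100-149
```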
/unittest/unittest_master_helloworld.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import network_utils as networkUtils
4 | import nets as models
5 | import unittest
6 | import pickle
7 | from constants import *
8 | import subprocess
9 | import common
10 | import sys
11 | import shutil
12 |
13 |
14 | MODEL_ARCH = 'helloworld'
15 | INPUT_DATA_SHAPE = (3, 32, 32)
16 |
17 | FLOPS_LOOKUP_TABLE_PATH = os.path.join('models', MODEL_ARCH, 'lut.pkl')
18 |
19 | MODEL_PATH = os.path.join('models', MODEL_ARCH, 'model_0.pth.tar')
20 |
21 | model = models.__dict__[MODEL_ARCH]()
22 | for i in range(4):
23 | layer = getattr(model.features, str(i*2))
24 | layer.weight.data = torch.zeros_like(layer.weight.data)
25 | torch.save(model, MODEL_PATH)
26 |
27 | DATASET_PATH = './'
28 | network_utils = networkUtils.__dict__[MODEL_ARCH](model, INPUT_DATA_SHAPE, DATASET_PATH)
29 |
30 | SHORT_TERM_FINE_TUNE_ITERATION = 5
31 | MAX_ITERS = 3
32 | BUDGET_RATIO = 0.8
33 | INIT_REDUCTION_RATIO = 0.025
34 | REDUCTION_DECAY = 1.0
35 | FINETUNE_LR = 0.001
36 | SAVE_INTERVAL = 1
37 |
38 |
39 | def run_master(working_folder, resource_type='WEIGHTS',
40 | budget_ratio=BUDGET_RATIO,
41 | budget=None,
42 | init_reduction_ratio=INIT_REDUCTION_RATIO,
43 | init_reduction=None,
44 | reduction_decay=REDUCTION_DECAY,
45 | finetune_lr=FINETUNE_LR,
46 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
47 | max_iters=MAX_ITERS,
48 | lookup_table_path=None,
49 | resume=False):
50 | if not os.path.exists(working_folder):
51 | os.mkdir(working_folder)
52 | print('Create directory', working_folder)
53 | with open(os.path.join(working_folder, 'master_log.txt'), 'w') as file_id:
54 | command_list = [sys.executable, 'master.py', working_folder, str(3), str(32), str(32),
55 | '-im', MODEL_PATH,
56 | '-gp', str(0), str(1), str(2),
57 | '-mi', str(max_iters),
58 | '-bur', str(budget_ratio),
59 | '-rt', resource_type,
60 | '-irr', str(init_reduction_ratio),
61 | '-rd', str(reduction_decay),
62 | '-lr', str(finetune_lr),
63 | '-st', str(short_term_fine_tune_iteration),
64 | '-dp', DATASET_PATH,
65 | '--arch', MODEL_ARCH,
66 | '-si', str(SAVE_INTERVAL)]
67 | if lookup_table_path is not None:
68 | command_list = command_list + ['-lt', lookup_table_path]
69 | if resume:
70 | command_list = command_list + ['--resume']
71 | if init_reduction is not None:
72 | command_list = command_list + ['-ir', str(init_reduction)]
73 | if budget is not None:
74 | command_list = command_list + ['-bu', str(budget)]
75 | print(command_list)
76 | return subprocess.call(command_list, stdout=file_id, stderr=file_id)
77 |
78 |
79 | class NormalUsage(unittest.TestCase):
80 | '''
81 | No ValueError
82 | '''
83 | def __init__(self, *args, **kwargs):
84 | super(NormalUsage, self).__init__(*args, **kwargs)
85 |
86 |
87 | def check_master_results(self, working_folder, acc_gt, res_gt, output_feature_gt, resource_type,
88 | short_term_fine_tune_iteration, max_iters, lookup_table_path):
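        # Layout of master/history.txt assumed here (inferred from the parsing below):
        # the first line is a header, followed by one comma-separated row per iteration
        # (iteration 0 is the unsimplified starting model), where
        #   tokens[1] = accuracy, tokens[2] = resource, tokens[3] = simplified block index,
        #   tokens[4] = path of the saved model, tokens[5] = space-separated output channel counts,
        # and the first field is presumably the iteration index.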
89 | history_path = os.path.join(working_folder, 'master', 'history.txt')
90 | with open(history_path) as f:
91 | content = f.readlines()
92 | content = [x.strip() for x in content]
93 | self.assertEqual(len(content)-1, max_iters+1, "master/history.txt length error")
94 |
95 | for i in range(1, len(content)):
96 | print('Check iteration {}'.format(i-1))
97 | tokens = content[i].split(',')
98 | print(tokens)
99 |
100 | # check accuracy
101 | load_model_path = tokens[4]
102 | saved_model = torch.load(load_model_path)
103 | acc = network_utils.evaluate(saved_model)
104 | saved_acc = float(tokens[1])
105 | self.assertEqual(acc, saved_acc, "The accuracy of saved model is not equal to that in history.txt")
106 | self.assertEqual(acc_gt[i-1], acc, "The accuracy of saved model is incorrect")
107 |
108 | # check resource
109 | saved_network_def = network_utils.get_network_def_from_model(saved_model)
110 | resource = network_utils.compute_resource(saved_network_def, resource_type, lookup_table_path)
111 | saved_resource = float(tokens[2])
112 |             self.assertEqual(resource, saved_resource, "The resource of saved model is not equal to that in history.txt")
113 | self.assertEqual(res_gt[i-1], resource, "The resource of saved model is incorrect.")
114 |
115 | # check simplified block idx
116 | if i != 1:
117 | tokens_pre = content[i-1].split(',')
118 | output_features = tokens[5].split(' ')
119 | output_features_pre = tokens_pre[5].split(' ')
120 | find_simplified_block = False
121 | for output_idx in range(len(output_features)):
122 | if output_features[output_idx] != output_features_pre[output_idx]:
123 | if not find_simplified_block:
124 |                             self.assertEqual(output_idx, int(tokens[3]), "Simplified block index does not match master/history.txt")
125 | self.assertEqual(output_features[output_idx], str(output_feature_gt[i-1][output_idx]), "Simplified block has incorrect # of output channels")
126 | find_simplified_block = True
127 | else:
128 |                             self.fail("More than one block was simplified in a single iteration")
129 | else:
130 |                         self.assertTrue(output_idx != int(tokens[3]), "The block recorded as simplified in master/history.txt did not actually change")
131 |
132 | # check network_def
133 | for idx in range(4):
134 | if idx == 0:
135 | self.assertEqual(saved_network_def[idx], (3, output_feature_gt[i-1][idx]), "network_def of simplified model error")
136 | else:
137 | self.assertEqual(saved_network_def[idx], (output_feature_gt[i-1][idx-1], output_feature_gt[i-1][idx]), "network_def of simplified model error")
138 |
139 | # check model weights
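            # (the helloworld fine_tune stub appears to add layer_idx to every weight once per
            #  fine-tune step, so after master iteration i-1 the weights of layer idx should
            #  equal idx * short_term_fine_tune_iteration * (i-1))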
140 | for idx in range(4):
141 | layer = getattr(saved_model.features, str(idx*2))
142 | temp = (layer.weight.data == (torch.zeros_like(layer.weight.data) + idx*short_term_fine_tune_iteration*(i-1)))
143 | temp = torch.min(temp)
144 | temp = temp.item()
145 | self.assertTrue(temp, "Model weights after short-term fine-tune are incorrect")
146 |
147 |
148 | def test_master_weights(self):
149 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights')
150 | res_gt = [29232, 28476, 27531, 26181]
151 | acc_gt = [95, 80, 85, 90]
152 | output_feature_gt = [[16, 32, 64, 10],
153 | [16, 32, 62, 10],
154 | [13, 32, 62, 10],
155 | [13, 30, 62, 10]
156 | ]
157 | run_master(working_folder, resource_type='WEIGHTS',
158 | budget_ratio=BUDGET_RATIO,
159 | init_reduction_ratio=INIT_REDUCTION_RATIO,
160 | reduction_decay=REDUCTION_DECAY,
161 | finetune_lr=FINETUNE_LR,
162 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
163 | max_iters=MAX_ITERS,
164 | lookup_table_path=None,
165 | resume=False)
166 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
167 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
168 | max_iters=MAX_ITERS, lookup_table_path=None)
169 | shutil.rmtree(working_folder)
170 |
171 |
172 | def test_master_flops_with_built_lookup_table(self):
173 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_flops_with_built_lookup_table')
174 | lookup_table_path = os.path.join('models', MODEL_ARCH, 'lut.pkl')
175 | res_gt = [29232*32*32, 28476*32*32, 27531*32*32, 26181*32*32]
176 | acc_gt = [95, 80, 85, 90]
177 | output_feature_gt = [[16, 32, 64, 10],
178 | [16, 32, 62, 10],
179 | [13, 32, 62, 10],
180 | [13, 30, 62, 10]
181 | ]
182 | run_master(working_folder, resource_type='FLOPS',
183 | budget_ratio=BUDGET_RATIO,
184 | init_reduction_ratio=INIT_REDUCTION_RATIO,
185 | reduction_decay=REDUCTION_DECAY,
186 | finetune_lr=FINETUNE_LR,
187 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
188 | max_iters=MAX_ITERS,
189 | lookup_table_path=lookup_table_path,
190 | resume=False)
191 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='FLOPS',
192 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
193 | max_iters=MAX_ITERS, lookup_table_path=lookup_table_path)
194 | shutil.rmtree(working_folder)
195 |
196 |
197 | def test_master_flops_without_built_lookup_table(self):
198 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_flops_without_built_lookup_table')
199 | lookup_table_path = os.path.join('models', MODEL_ARCH, 'not_built_flops_lut.pkl')
200 | res_gt = [29232*32*32, 28476*32*32, 27531*32*32, 26181*32*32]
201 | acc_gt = [95, 80, 85, 90]
202 | output_feature_gt = [[16, 32, 64, 10],
203 | [16, 32, 62, 10],
204 | [13, 32, 62, 10],
205 | [13, 30, 62, 10]
206 | ]
207 | run_master(working_folder, resource_type='FLOPS',
208 | budget_ratio=BUDGET_RATIO,
209 | init_reduction_ratio=INIT_REDUCTION_RATIO,
210 | reduction_decay=REDUCTION_DECAY,
211 | finetune_lr=FINETUNE_LR,
212 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
213 | lookup_table_path=lookup_table_path,
214 | resume=False)
215 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='FLOPS',
216 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
217 | max_iters=MAX_ITERS, lookup_table_path=lookup_table_path)
218 | os.remove(lookup_table_path)
219 | shutil.rmtree(working_folder)
220 |
221 |
222 | def master_flops_without_built_lookup_table_resume(self, working_folder, max_iters_before_resume):
223 | # run `max_iters_before_resume` iteration
224 | # modify history
225 | # resume
226 | lookup_table_path = os.path.join('models', MODEL_ARCH, 'not_built_flops_lut.pkl')
227 | res_gt = [29232*32*32, 28476*32*32, 27531*32*32, 26181*32*32]
228 | acc_gt = [95, 80, 85, 90]
229 | output_feature_gt = [[16, 32, 64, 10],
230 | [16, 32, 62, 10],
231 | [13, 32, 62, 10],
232 | [13, 30, 62, 10]
233 | ]
234 | run_master(working_folder, resource_type='FLOPS',
235 | budget_ratio=BUDGET_RATIO,
236 | init_reduction_ratio=INIT_REDUCTION_RATIO,
237 | reduction_decay=REDUCTION_DECAY,
238 | finetune_lr=FINETUNE_LR,
239 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
240 | max_iters=max_iters_before_resume,
241 | lookup_table_path=lookup_table_path,
242 | resume=False)
243 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='FLOPS',
244 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
245 | max_iters=max_iters_before_resume, lookup_table_path=lookup_table_path)
246 |
247 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'rb') as file_id:
248 | history_pkl = pickle.load(file_id)
249 | his_args = history_pkl['master_args']
250 | his_args.max_iters = MAX_ITERS
251 | history_pkl['master_args'] = his_args
252 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'wb') as file_id:
253 | pickle.dump(history_pkl, file_id)
254 |
255 | run_master(working_folder, resource_type='FLOPS',
256 | budget_ratio=0,
257 | init_reduction_ratio=0,
258 | reduction_decay=0,
259 | finetune_lr=FINETUNE_LR,
260 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
261 | max_iters=max_iters_before_resume,
262 | lookup_table_path=' ',
263 | resume=True)
264 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='FLOPS',
265 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
266 | max_iters=MAX_ITERS, lookup_table_path=lookup_table_path)
267 |
268 | os.remove(lookup_table_path)
269 |
270 |
271 | def test_master_weights_budget_met(self):
272 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights_budget_met')
273 | res_gt = [29232, 28476]
274 | acc_gt = [95, 80]
275 | output_feature_gt = [[16, 32, 64, 10],
276 | [16, 32, 62, 10]
277 | ]
278 | run_master(working_folder, resource_type='WEIGHTS',
279 | budget_ratio=0.9999,
280 | init_reduction_ratio=INIT_REDUCTION_RATIO,
281 | reduction_decay=REDUCTION_DECAY,
282 | finetune_lr=FINETUNE_LR,
283 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
284 | lookup_table_path=None,
285 | resume=False)
286 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
287 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
288 | max_iters=1, lookup_table_path=None)
289 | history_path = os.path.join(working_folder, 'master', 'history.txt')
290 | with open(history_path) as f:
291 | content = f.readlines()
292 | content = [x.strip() for x in content]
293 | self.assertEqual(len(content)-1, 2, "Master does not terminate when budget is met")
294 | shutil.rmtree(working_folder)
295 |
296 |
297 | def test_master_weights_use_init_reduction_not_ratio(self):
298 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights_use_init_reduction_not_ratio')
299 | res_gt = [29232, 28476, 27531, 26181]
300 | acc_gt = [95, 80, 85, 90]
301 | output_feature_gt = [[16, 32, 64, 10],
302 | [16, 32, 62, 10],
303 | [13, 32, 62, 10],
304 | [13, 30, 62, 10]
305 | ]
306 | run_master(working_folder, resource_type='WEIGHTS',
307 | budget_ratio=BUDGET_RATIO,
308 | init_reduction_ratio=0,
309 | init_reduction=INIT_REDUCTION_RATIO*29232,
310 | reduction_decay=REDUCTION_DECAY,
311 | finetune_lr=FINETUNE_LR,
312 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
313 | lookup_table_path=None,
314 | resume=False)
315 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
316 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
317 | max_iters=MAX_ITERS, lookup_table_path=None)
318 | shutil.rmtree(working_folder)
319 |
320 |
321 | def test_master_weights_use_budget_not_ratio(self):
322 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights_use_budget_not_ratio')
323 | res_gt = [29232, 28476, 27531, 26181]
324 | acc_gt = [95, 80, 85, 90]
325 | output_feature_gt = [[16, 32, 64, 10],
326 | [16, 32, 62, 10],
327 | [13, 32, 62, 10],
328 | [13, 30, 62, 10]
329 | ]
330 | run_master(working_folder, resource_type='WEIGHTS',
331 | budget_ratio=0,
332 | budget=29232*BUDGET_RATIO,
333 | init_reduction_ratio=INIT_REDUCTION_RATIO,
334 | reduction_decay=REDUCTION_DECAY,
335 | finetune_lr=FINETUNE_LR,
336 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
337 | lookup_table_path=None,
338 | resume=False)
339 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
340 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
341 | max_iters=MAX_ITERS, lookup_table_path=None)
342 | shutil.rmtree(working_folder)
343 |
344 |
345 | def test_master_delete_previous_files_and_not_resume(self):
346 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_delete_previous_files_and_not_resume')
347 | res_gt = [29232, 28476, 27531, 26181]
348 | acc_gt = [95, 80, 85, 90]
349 | output_feature_gt = [[16, 32, 64, 10],
350 | [16, 32, 62, 10],
351 | [13, 32, 62, 10],
352 | [13, 30, 62, 10]
353 | ]
354 | run_master(working_folder, resource_type='WEIGHTS',
355 | budget_ratio=BUDGET_RATIO,
356 | init_reduction_ratio=INIT_REDUCTION_RATIO,
357 | reduction_decay=REDUCTION_DECAY,
358 | finetune_lr=FINETUNE_LR,
359 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
360 | lookup_table_path=None,
361 | resume=False)
362 | for file in os.listdir(os.path.join(working_folder, 'master')):
363 | os.remove(os.path.join(working_folder, 'master', file))
364 | for file in os.listdir(os.path.join(working_folder, 'worker')):
365 | os.remove(os.path.join(working_folder, 'worker', file))
366 | returncode = run_master(working_folder, resource_type='WEIGHTS',
367 | budget_ratio=BUDGET_RATIO,
368 | init_reduction_ratio=INIT_REDUCTION_RATIO,
369 | reduction_decay=REDUCTION_DECAY,
370 | finetune_lr=FINETUNE_LR,
371 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
372 | lookup_table_path=None,
373 | resume=False)
374 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
375 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
376 | max_iters=MAX_ITERS, lookup_table_path=None)
377 | self.assertEqual(returncode, 0, "Master function error when all previous files are deleted.")
378 | shutil.rmtree(working_folder)
379 |
380 |
381 | def test_master_resume(self):
382 | max_iters_before_resume = 1
383 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_flops_without_built_lookup_table-resume_' + str(max_iters_before_resume))
384 | self.master_flops_without_built_lookup_table_resume(working_folder, max_iters_before_resume)
385 | shutil.rmtree(working_folder)
386 |
387 | max_iters_before_resume = 2
388 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_flops_without_built_lookup_table-resume_' + str(max_iters_before_resume))
389 | self.master_flops_without_built_lookup_table_resume(working_folder, max_iters_before_resume)
390 | shutil.rmtree(working_folder)
391 |
392 | def test_constraint_tight(self):
393 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_constraint_tight')
394 | res_gt = [29232, 5418, 693]
395 | acc_gt = [95, 80, 85]
396 | output_feature_gt = [[16, 32, 64, 10],
397 | [16, 32, 1, 10],
398 | [1, 32, 1, 10]
399 | ]
400 | run_master(working_folder, resource_type='WEIGHTS',
401 | budget_ratio=0,
402 | init_reduction_ratio=1,
403 | reduction_decay=REDUCTION_DECAY,
404 | finetune_lr=FINETUNE_LR,
405 | max_iters=MAX_ITERS-1,
406 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
407 | lookup_table_path=None,
408 | resume=False)
409 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
410 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
411 | max_iters=MAX_ITERS-1, lookup_table_path=None)
412 |
413 | master_log_path = os.path.join(working_folder, 'master_log.txt')
414 | with open(master_log_path) as f:
415 | content = f.readlines()
416 | content = [x.strip() for x in content]
417 |
418 | warning_counter = 0
419 | for line in content:
420 | if 'UserWarning' in line:
421 | warning_counter += 1
422 |         self.assertEqual(warning_counter, 2, "Unexpected number of target-resource UserWarnings from master")
423 | shutil.rmtree(working_folder)
424 |
425 |
426 |     # normal usage;
427 |     # however, if the constraint is too tight to be achievable,
428 |     # master will raise a ValueError
429 | def test_constraint_not_achievable(self):
430 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_constraint_not_achievable')
431 | res_gt = [29232, 5418, 693, 135]
432 | acc_gt = [95, 80, 85, 90]
433 | output_feature_gt = [[16, 32, 64, 10],
434 | [16, 32, 1, 10],
435 | [1, 32, 1, 10],
436 | [1, 1, 1, 10]
437 | ]
438 | returncode = run_master(working_folder, resource_type='WEIGHTS',
439 | budget_ratio=0,
440 | init_reduction_ratio=1,
441 | reduction_decay=REDUCTION_DECAY,
442 | finetune_lr=FINETUNE_LR,
443 | max_iters=MAX_ITERS+1,
444 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
445 | lookup_table_path=None,
446 | resume=False)
447 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
448 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
449 | max_iters=MAX_ITERS, lookup_table_path=None)
450 | self.assertEqual(returncode, 1, "Master resource constraint not achievable error")
451 |
452 | master_log_path = os.path.join(working_folder, 'master_log.txt')
453 | with open(master_log_path) as f:
454 | content = f.readlines()
455 | content = [x.strip() for x in content]
456 |
457 | warning_counter = 0
458 | for line in content:
459 | if 'UserWarning' in line:
460 | warning_counter += 1
461 |         self.assertEqual(warning_counter, 3, "Unexpected number of target-resource UserWarnings from master")
462 | shutil.rmtree(working_folder)
463 |
464 |
465 | class ValueErrCase(unittest.TestCase):
466 | def __init__(self, *args, **kwargs):
467 | super(ValueErrCase, self).__init__(*args, **kwargs)
468 |
469 |
470 | def test_not_resume_and_previous_files_exist(self):
471 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_not_resume')
472 | returncode = run_master(working_folder, resource_type='WEIGHTS',
473 | budget_ratio=0,
474 | init_reduction_ratio=1,
475 | reduction_decay=REDUCTION_DECAY,
476 | finetune_lr=FINETUNE_LR,
477 | max_iters=1,
478 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
479 | lookup_table_path=None,
480 | resume=False)
481 | self.assertEqual(returncode, 0, "Normal master execution error")
482 | returncode = run_master(working_folder, resource_type='WEIGHTS',
483 | budget_ratio=0,
484 | init_reduction_ratio=1,
485 | reduction_decay=REDUCTION_DECAY,
486 | finetune_lr=FINETUNE_LR,
487 | max_iters=1,
488 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
489 | lookup_table_path=None,
490 | resume=False)
491 | self.assertEqual(returncode, 1, "Master does not detect the error incurred when previous files exist and `--resume` is not specified")
492 | shutil.rmtree(working_folder)
493 |
494 |
495 | def test_resume_no_lookup_table(self):
496 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_resume_no_lookuptable')
497 | lookup_table_path = os.path.join('models', MODEL_ARCH, 'not_built_flops_lut.pkl')
498 | run_master(working_folder, resource_type='FLOPS',
499 | budget_ratio=BUDGET_RATIO,
500 | init_reduction_ratio=INIT_REDUCTION_RATIO,
501 | reduction_decay=REDUCTION_DECAY,
502 | max_iters=1,
503 | finetune_lr=FINETUNE_LR,
504 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
505 | lookup_table_path=lookup_table_path,
506 | resume=False)
507 | os.remove(lookup_table_path)
508 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'rb') as file_id:
509 | history_pkl = pickle.load(file_id)
510 | his_args = history_pkl['master_args']
511 | his_args.max_iters = MAX_ITERS
512 | history_pkl['master_args'] = his_args
513 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'wb') as file_id:
514 | pickle.dump(history_pkl, file_id)
515 | returncode = run_master(working_folder, resource_type='FLOPS',
516 | budget_ratio=BUDGET_RATIO,
517 | init_reduction_ratio=INIT_REDUCTION_RATIO,
518 | reduction_decay=REDUCTION_DECAY,
519 | max_iters=1,
520 | finetune_lr=FINETUNE_LR,
521 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
522 | lookup_table_path=lookup_table_path,
523 | resume=True)
524 | self.assertEqual(returncode, 1, "Master does not detect the error incurred when resuming from previous iterations but lookup table not found")
525 | shutil.rmtree(working_folder)
526 |
527 |
528 |
529 | if __name__ == '__main__':
530 |
531 | unittest.main()
532 |
--------------------------------------------------------------------------------
/unittest/unittest_network_utils_alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import sys
4 | sys.path.append(os.path.abspath('../'))
5 | import pickle
6 | import network_utils as networkUtils
7 | import unittest
8 | import nets as models
9 | from constants import *
10 | import copy
11 |
12 | MODEL_ARCH = 'alexnet'
13 | INPUT_DATA_SHAPE = (3, 224, 224)
14 | LOOKUP_TABLE_PATH = os.path.join('../models', MODEL_ARCH, 'lut.pkl')
15 | DATASET_PATH = '../data/'
16 |
17 | model = models.__dict__[MODEL_ARCH](num_classes=10)
18 | network_utils = networkUtils.__dict__[MODEL_ARCH](model, INPUT_DATA_SHAPE, DATASET_PATH)
19 |
20 | class TestNetworkUtils_alexnet(unittest.TestCase):
21 | def __init__(self, *args, **kwargs):
22 | super(TestNetworkUtils_alexnet, self).__init__(*args, **kwargs)
23 |
24 |
25 | def check_network_def(self, network_def, input_channels, output_channels, only_num_channels=False):
26 | self.assertEqual(len(network_def), 8, "network_def length error")
27 | layer_idx = 0
28 |
29 | kernel_size_gt = [(11, 11), (5, 5), (3, 3), (3, 3), (3, 3), (1, 1), (1, 1), (1, 1)]
30 | stride_gt = [(4, 4), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
31 | padding_gt = [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1), (0, 0), (0, 0), (0, 0)]
32 | for layer_name, layer_properties in network_def.items():
33 | self.assertEqual(layer_properties[KEY_NUM_IN_CHANNELS], input_channels[layer_idx], "network_def num of input channels error")
34 | self.assertEqual(layer_properties[KEY_NUM_OUT_CHANNELS], output_channels[layer_idx], "network_def num of output channels error")
35 | self.assertFalse(layer_properties[KEY_IS_DEPTHWISE], "network_def is_depthwise error")
36 | self.assertEqual(layer_properties[KEY_GROUPS], 1, "network_def group error")
37 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], kernel_size_gt[layer_idx], "network_def kernel size error")
38 | self.assertEqual(layer_properties[KEY_PADDING], padding_gt[layer_idx], "network_def padding error")
39 | self.assertEqual(layer_properties[KEY_STRIDE], stride_gt[layer_idx], "network_def stride error")
40 |
41 | if layer_idx < 5:
42 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Conv2d', "network_def layer type string error")
43 | else:
44 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Linear', "network_def layer type string error")
45 |
46 | input_feature_map_spatial_size = [224, 27, 13, 13, 13, 1, 1, 1]
47 | output_feature_map_spatial_size = [55, 27, 13, 13, 13, 1, 1, 1]
48 | if not only_num_channels:
49 | self.assertEqual(layer_properties[KEY_INPUT_FEATURE_MAP_SIZE], [1, input_channels[layer_idx],
50 | input_feature_map_spatial_size[layer_idx], input_feature_map_spatial_size[layer_idx]],
51 | "network_def input feature map size error")
52 | self.assertEqual(layer_properties[KEY_OUTPUT_FEATURE_MAP_SIZE], [1, output_channels[layer_idx],
53 | output_feature_map_spatial_size[layer_idx], output_feature_map_spatial_size[layer_idx]],
54 | "network_def output feature map size error")
55 | #print(layer_idx)
56 | layer_idx += 1
57 |
58 |
59 | def gen_layer_weight(self, tensor):
60 | gen_tensor = torch.zeros_like(tensor)
61 | for i in range(gen_tensor.shape[0]):
62 | gen_tensor[i, ::] += i
63 | return gen_tensor
64 |
65 |
66 | def test_network_def(self):
67 | network_def = network_utils.get_network_def_from_model(model)
68 | #print(network_def)
69 |
70 | input_channels = [3, 64, 192, 384, 256, 9216, 4096, 4096]
71 | output_channels = [64, 192, 384, 256, 256, 4096, 4096, 10]
72 | self.check_network_def(network_def, input_channels, output_channels)
73 | self.assertEqual(network_utils.get_num_simplifiable_blocks(), 7, "Num of simplifiable blocks error")
74 |
75 |
76 | def test_compute_resource(self):
77 | network_def = network_utils.get_network_def_from_model(model)
78 | num_w = network_utils.compute_resource(network_def, 'WEIGHTS')
79 | num_mac = network_utils.compute_resource(network_def, 'FLOPS')
80 |
81 | self.assertEqual(num_w, 57035456, "Num of weights error")
82 | self.assertEqual(num_mac, 710133440, "Num of MACs error")
83 |
84 |
85 | def test_extra_history_info(self):
86 | network_def = network_utils.get_network_def_from_model(model)
87 | output_feature_info = network_utils.extra_history_info(network_def)
88 | output_channels = [64, 192, 384, 256, 256, 4096, 4096, 10]
89 | output_channels_str = [str(x) for x in output_channels]
90 | output_feature_info_gt = ' '.join(output_channels_str)
91 | self.assertEqual(output_feature_info, output_feature_info_gt, "extra_history_info error")
92 |
93 |
94 | def delta_to_layer_num_channels(self, delta, simp_block_idx):
95 | input_channels_gt = [3, 64, 192, 384, 256, 9216, 4096, 4096]
96 | output_channels_gt = [64, 192, 384, 256, 256, 4096, 4096, 10]
97 |
98 | output_channels_gt[simp_block_idx] = output_channels_gt[simp_block_idx] - delta
99 | if simp_block_idx == 4:
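            # conv5's 256 output channels are flattened into the 9216 inputs of the first FC layer
            # (36 = 6*6 values per channel after the last pooling), so removing one conv5 filter
            # removes 36 flattened input features from that FC layer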
100 | input_channels_gt[simp_block_idx+1] = input_channels_gt[simp_block_idx+1] - delta*36
101 | else:
102 | input_channels_gt[simp_block_idx+1] = input_channels_gt[simp_block_idx+1] - delta
103 |
104 | return input_channels_gt, output_channels_gt
105 |
106 |
107 | def run_simplify_network_def_and_check_for_one_resource_type(self, constraint, resource_type, simp_block_indices, delta, res_gt):
108 | network_def = network_utils.get_network_def_from_model(model)
109 |
110 | for i in range(len(simp_block_indices)):
111 | simp_block_idx = simp_block_indices[i]
112 | simp_network_def, simp_resource = network_utils.simplify_network_def_based_on_constraint(network_def, simp_block_idx, constraint, resource_type)
113 | self.assertEqual(simp_resource, res_gt[i], "Simplified network resource {} error ({})".format(resource_type, simp_block_idx))
114 | input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta[i], simp_block_idx)
115 | self.check_network_def(simp_network_def, input_channels_gt, output_channels_gt, only_num_channels=True)
116 | print(i)
117 |
118 | def test_simplify_network_def_based_on_constraint(self):
119 | total_num_w = 57035456
120 | total_num_mac = 710133440
121 | constraint_num_w = total_num_w*0.975
122 | constraint_num_mac = total_num_mac*0.975
123 |
124 | simp_block_indices = [0, 1, 4, 6]
125 | delta_w = [56, 184, 16, 352]
126 | delta_mac = [8, 16, 40, 4088]
127 |
128 | num_w_gt = [56746328, 56105152, 54639296, 55590144]
129 | num_mac_gt = [673355240, 682126016, 688660160, 693348112]
130 |
131 | self.run_simplify_network_def_and_check_for_one_resource_type(constraint=constraint_num_w,
132 | resource_type="WEIGHTS", simp_block_indices=simp_block_indices,
133 | delta=delta_w, res_gt=num_w_gt)
134 | self.run_simplify_network_def_and_check_for_one_resource_type(constraint=constraint_num_mac,
135 | resource_type="FLOPS", simp_block_indices=simp_block_indices,
136 | delta=delta_mac, res_gt=num_mac_gt)
137 |
138 |
139 | def test_simplify_model_based_on_network_def(self):
140 | network_def = network_utils.get_network_def_from_model(model)
141 | total_num_w = 57035456
142 | constraint_num_w = total_num_w*0.975
143 | simp_block_indices = [0, 1, 4, 6]
144 | delta_w = [56, 184, 16, 352]
145 | topk_w = [8, 8, 240, 3744]
146 |
147 | conv_idx = [0, 3, 6, 8, 10]
148 | fc_idx = [1, 4, 6]
149 |
150 | for i in range(len(simp_block_indices)):
151 | simp_block_idx = simp_block_indices[i]
152 | simp_network_def, _ = network_utils.simplify_network_def_based_on_constraint(network_def,
153 | simp_block_idx, constraint_num_w, "WEIGHTS")
154 | simp_model = network_utils.simplify_model_based_on_network_def(simp_network_def, model)
155 | updated_network_def = network_utils.get_network_def_from_model(simp_model)
156 | input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta_w[i], simp_block_idx)
157 | self.check_network_def(updated_network_def, input_channels_gt, output_channels_gt)
158 |
159 | for block_idx in range(7):
160 | if block_idx < 5: # conv
161 | layer = getattr(model, 'features')
162 | layer = getattr(layer, str(conv_idx[block_idx]))
163 | simp_layer = getattr(simp_model, 'features')
164 | simp_layer = getattr(simp_layer, str(conv_idx[block_idx]))
165 | else:
166 | layer = getattr(model, 'classifier')
167 | layer = getattr(layer, str(fc_idx[block_idx - 5]))
168 | simp_layer = getattr(simp_model, 'classifier')
169 | simp_layer = getattr(simp_layer, str(fc_idx[block_idx - 5]))
170 |
171 | if block_idx != simp_block_idx and block_idx != simp_block_idx + 1:
172 | equal_weight = (layer.weight.data == simp_layer.weight.data)
173 | equal_bias = (layer.bias.data == simp_layer.bias.data)
174 | self.assertTrue(equal_weight.min(), "simplify_model_based_on_network_def modify unrelated layers (weights)")
175 | self.assertTrue(equal_bias.min(), "simplify_model_based_on_network_def modify unrelated layers (biases)")
176 | elif block_idx == simp_block_idx:
177 | layer_weight = layer.weight.data
178 | layer_weight = layer_weight.view(layer_weight.shape[0], -1)
179 | layer_weight_norm = layer_weight*layer_weight
180 | layer_weight_norm = layer_weight_norm.sum(1)
181 | _, kept_filter_idx = torch.topk(layer_weight_norm, topk_w[i], sorted=False)
182 | kept_filter_idx, _ = torch.sort(kept_filter_idx)
183 |
184 | equal_prune_weights = (layer.weight.data[kept_filter_idx, ::] == simp_layer.weight.data)
185 | self.assertTrue(equal_prune_weights.min(), "Output channels of the pruned layer error")
186 |
187 | equal_prune_biases = (layer.bias.data[kept_filter_idx] == simp_layer.bias.data)
188 | self.assertTrue(equal_prune_biases.min(), "Output channels of the pruned layer error")
189 |
190 | # check the input features of the next layer
191 | if (block_idx + 1) < 5: # conv
192 | next_layer = getattr(model, 'features')
193 | next_layer = getattr(next_layer, str(conv_idx[block_idx + 1]))
194 | simp_next_layer = getattr(simp_model, 'features')
195 | simp_next_layer = getattr(simp_next_layer, str(conv_idx[block_idx + 1]))
196 | else:
197 | next_layer = getattr(model, 'classifier')
198 | next_layer = getattr(next_layer, str(fc_idx[(block_idx + 1) - 5]))
199 | simp_next_layer = getattr(simp_model, 'classifier')
200 | simp_next_layer = getattr(simp_next_layer, str(fc_idx[(block_idx + 1) - 5]))
201 |
202 | if block_idx != 4:
203 | if block_idx < 5:
204 | equal_weights = (next_layer.weight.data[:, kept_filter_idx, ::] == simp_next_layer.weight.data)
205 | else:
206 | equal_weights = (next_layer.weight.data[:, kept_filter_idx] == simp_next_layer.weight.data)
207 | self.assertTrue(equal_weights.min(), "Input channels of the layer after the pruned layer error")
208 | else: # conv -> FC
209 | kept_filter_idx_fc = []
210 | for filter_idx in kept_filter_idx:
211 | for i in range(36):
212 | kept_filter_idx_fc.append(filter_idx*36 + i)
213 | equal_weights = (next_layer.weight.data[:, kept_filter_idx_fc] == simp_next_layer.weight.data)
214 | self.assertTrue(equal_weights.min(), "Input channels of the FC layer after the pruned conv layer error")
215 |
216 |
217 |
218 | def test_build_latency_lookup_table(self):
219 | network_def = network_utils.get_network_def_from_model(model)
220 |         lookup_table_path = './unittest_lookup_table.pkl'
221 | min_conv_feature_size = 32
222 | min_fc_feature_size = 1024
223 | measure_latency_batch_size = 1
224 | measure_latency_sample_times = 1
225 |
226 | network_utils.build_lookup_table(network_def, 'LATENCY', lookup_table_path, min_conv_feature_size,
227 | min_fc_feature_size, measure_latency_batch_size, measure_latency_sample_times)
228 |
229 | with open(lookup_table_path, 'rb') as file_id:
230 | lookup_table = pickle.load(file_id)
231 | self.assertEqual(len(lookup_table), 8, "Lookup table length error")
232 |
233 | input_channels_gt = [3, 64, 192, 384, 256, 9216, 4096, 4096]
234 | output_channels_gt = [64, 192, 384, 256, 256, 4096, 4096, 10]
235 | layer_idx = 0
236 |
237 | for layer_name, layer_properties in lookup_table.items():
238 | self.assertEqual(layer_properties[KEY_IS_DEPTHWISE], network_def[layer_name][KEY_IS_DEPTHWISE], "lookup table layer properties error (is_depthwise)")
239 | self.assertEqual(layer_properties[KEY_NUM_IN_CHANNELS], network_def[layer_name][KEY_NUM_IN_CHANNELS], "lookup table layer properties error (num_in_channels)")
240 | self.assertEqual(layer_properties[KEY_NUM_OUT_CHANNELS], network_def[layer_name][KEY_NUM_OUT_CHANNELS], "lookup table layer properties error (num_out_channels)")
241 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], network_def[layer_name][KEY_KERNEL_SIZE], "lookup table layer properties error (kernel_size)")
242 | self.assertEqual(layer_properties[KEY_STRIDE], network_def[layer_name][KEY_STRIDE], "lookup table layer properties error (stride)")
243 | self.assertEqual(layer_properties[KEY_PADDING], network_def[layer_name][KEY_PADDING], "lookup table layer properties error (padding)")
244 | self.assertEqual(layer_properties[KEY_GROUPS], network_def[layer_name][KEY_GROUPS], "lookup table layer properties error (groups)")
245 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], network_def[layer_name][KEY_LAYER_TYPE_STR], "lookup table layer properties error (layer_type_str)")
246 | self.assertEqual(layer_properties[KEY_INPUT_FEATURE_MAP_SIZE], network_def[layer_name][KEY_INPUT_FEATURE_MAP_SIZE], "lookup table layer properties error (input_feature_size)")
247 |
248 | layer_latency_table = layer_properties[KEY_LATENCY]
249 |
250 | num_in_samples = input_channels_gt[layer_idx]
251 | num_output_samples = output_channels_gt[layer_idx]
252 | if layer_idx < 5:
253 | if num_in_samples < min_conv_feature_size:
254 | num_in_samples = 1
255 | else:
256 | num_in_samples = num_in_samples/min_conv_feature_size
257 | num_output_samples = num_output_samples/min_conv_feature_size
258 | else:
259 | num_in_samples = num_in_samples/min_fc_feature_size
260 | if num_output_samples < min_fc_feature_size:
261 | num_output_samples = 1
262 | else:
263 | num_output_samples = num_output_samples/min_fc_feature_size
264 |
265 | self.assertEqual(len(layer_latency_table), num_in_samples*num_output_samples, "Layerwise latency dict length error (layer index: {})".format(layer_idx))
266 | layer_idx += 1
267 |
268 | os.remove(lookup_table_path)
269 |
270 | if __name__ == '__main__':
271 | unittest.main()
272 |
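Because this file appends os.path.abspath('../') to sys.path and uses '../models' and '../data' paths (both resolved against the current working directory), it is presumably meant to be run with unittest/ as the working directory, e.g.: cd unittest && python unittest_network_utils_alexnet.py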
--------------------------------------------------------------------------------
/unittest/unittest_network_utils_helloworld.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import sys
4 | import pickle
5 | sys.path.append(os.path.abspath('../'))
6 |
7 | import network_utils as networkUtils
8 | import unittest
9 | import nets as models
10 | from constants import *
11 | import copy
12 |
13 | MODEL_ARCH = 'helloworld'
14 | INPUT_DATA_SHAPE = (3, 32, 32)
15 | LOOKUP_TABLE_PATH = os.path.join('../models', MODEL_ARCH, 'lut.pkl')
16 |
17 | model = models.__dict__[MODEL_ARCH]()
18 | for i in range(4):
19 | layer = getattr(model.features, str(i*2))
20 | layer.weight.data = torch.zeros_like(layer.weight.data)
21 | model = model.cuda()
22 |
23 | network_utils = networkUtils.helloworld(model, INPUT_DATA_SHAPE)
24 |
25 | class TestNetworkUtils_helloworld(unittest.TestCase):
26 | def __init__(self, *args, **kwargs):
27 | super(TestNetworkUtils_helloworld, self).__init__(*args, **kwargs)
28 |
29 |
30 | def check_network_def(self, network_def, input_channels, output_channels):
31 | for layer_idx in range(len(network_def)):
32 |             self.assertEqual(network_def[layer_idx][0], input_channels[layer_idx], "Number of input channels in network_def does not match that of the model")
33 |             self.assertEqual(network_def[layer_idx][1], output_channels[layer_idx], "Number of output channels in network_def does not match that of the model")
34 | self.assertEqual(len(network_def), 4, "Number of layers in network_def is not equal to that in the original model")
35 |
36 |
37 |     def test_num_simplifiable_blocks(self):
38 |         self.assertEqual(network_utils.get_num_simplifiable_blocks(), 3, "Number of simplifiable blocks error")
39 |
40 |
41 | def test_network_def(self):
42 | network_def = network_utils.get_network_def_from_model(model)
43 | output_channels = [16, 32, 64, 10]
44 | input_channels = [3, 16, 32, 64]
45 | for layer_idx in range(len(network_def)):
46 |             self.assertEqual(network_def[layer_idx][0], input_channels[layer_idx], "Number of input channels in network_def does not match that of the model")
47 |             self.assertEqual(network_def[layer_idx][1], output_channels[layer_idx], "Number of output channels in network_def does not match that of the model")
48 | self.assertEqual(len(network_def), 4, "Number of layers in network_def is not equal to that in the original model")
49 |
50 |
51 | def test_compute_weights(self):
52 | network_def = network_utils.get_network_def_from_model(model)
53 | num_w = network_utils.compute_resource(network_def, 'WEIGHTS')
54 | num_w_gt = 3*3*(3*16 + 16*32 + 32*64 + 64*10)
55 | self.assertEqual(num_w, num_w_gt, "Number of weights error")
56 |
57 |
58 | def test_compute_flops(self):
59 | network_def = network_utils.get_network_def_from_model(model)
60 | network_utils.build_lookup_table(network_def, resource_type='FLOPS', lookup_table_path=LOOKUP_TABLE_PATH)
61 | with open(LOOKUP_TABLE_PATH, 'rb') as file_id:
62 | lookup_table = pickle.load(file_id)
63 | self.assertEqual(len(lookup_table), 4, "Lookup table has wrong number of layers")
64 | for layer_idx in range(4):
65 | feature_flops_dict = lookup_table[layer_idx]
66 | for feature_points in feature_flops_dict.keys():
67 | entry = feature_flops_dict[feature_points]
68 | self.assertEqual(32*32*9*feature_points[0]*feature_points[1], entry, "Lookup table entry error")
69 | flops = network_utils.compute_resource(network_def, 'FLOPS', lookup_table_path=LOOKUP_TABLE_PATH)
70 | flops_gt = 3*3*(3*16 + 16*32 + 32*64 + 64*10)*32*32
71 | self.assertEqual(flops, flops_gt, "FLOPS estimation error")
72 |
73 |
74 | def test_extra_history_info(self):
75 | network_def = network_utils.get_network_def_from_model(model)
76 | network_def[0] = (3, 8)
77 | network_def[1] = (8, 32)
78 | num_filters_str = network_utils.extra_history_info(network_def)
79 | self.assertEqual(num_filters_str, "8 32 64 10", "Network architecture in extra_history_info is wrong")
80 |
81 |
82 |     def test_simplify_network_def(self):
83 | network_def = network_utils.get_network_def_from_model(model)
84 |
85 | constraint_1 = 29232 - 3*3*(8*3 + 8*32)
86 | simp_network_def_1, simp_resource_1 = network_utils.simplify_network_def_based_on_constraint(network_def, 0, constraint_1, "WEIGHTS")
87 | self.check_network_def(simp_network_def_1, [3, 8, 32, 64], [8, 32, 64, 10])
88 |
89 | constraint_2 = 29232 - 3*3*(16*16 + 16*64)
90 | simp_network_def_2, simp_resource_2 = network_utils.simplify_network_def_based_on_constraint(network_def, 1, constraint_2, "WEIGHTS")
91 | self.check_network_def(simp_network_def_2, [3, 16, 16, 64], [16, 16, 64, 10])
92 |
93 | constraint_3 = 29232 - 3*3*(16*32 + 16*10)
94 | simp_network_def_3, simp_resource_3 = network_utils.simplify_network_def_based_on_constraint(network_def, 2, constraint_3, "WEIGHTS")
95 | self.check_network_def(simp_network_def_3, [3, 16, 32, 48], [16, 32, 48, 10])
96 |
97 |
98 |     def test_simplify_network_def_constraint_too_tight(self):
99 | network_def = network_utils.get_network_def_from_model(model)
100 |
101 | constraint_1 = -1
102 | simp_network_def_1, simp_resource_1 = network_utils.simplify_network_def_based_on_constraint(network_def, 0, constraint_1, "WEIGHTS")
103 | self.check_network_def(simp_network_def_1, [3, 1, 32, 64], [1, 32, 64, 10])
104 |
105 | constraint_2 = 0
106 | simp_network_def_2, simp_resource_2 = network_utils.simplify_network_def_based_on_constraint(network_def, 1, constraint_2, "WEIGHTS")
107 | self.check_network_def(simp_network_def_2, [3, 16, 1, 64], [16, 1, 64, 10])
108 |
109 | constraint_3 = 1
110 | simp_network_def_3, simp_resource_3 = network_utils.simplify_network_def_based_on_constraint(network_def, 2, constraint_3, "WEIGHTS")
111 | self.check_network_def(simp_network_def_3, [3, 16, 32, 1], [16, 32, 1, 10])
112 |
113 |
114 | def test_simplify_model(self):
115 | network_def = network_utils.get_network_def_from_model(model)
116 |
117 | constraint_1 = 29232 - 3*3*(9*3 + 9*32)
118 | simp_network_def_1, simp_resource_1 = network_utils.simplify_network_def_based_on_constraint(network_def, 0, constraint_1, "WEIGHTS")
119 | self.check_network_def(simp_network_def_1, [3, 7, 32, 64], [7, 32, 64, 10])
120 | simp_model_1 = network_utils.simplify_model_based_on_network_def(simp_network_def_1, model)
121 | update_network_def_1 = network_utils.get_network_def_from_model(simp_model_1)
122 | self.check_network_def(update_network_def_1, [3, 7, 32, 64], [7, 32, 64, 10])
123 |
124 | constraint_2 = 29232 - 3*3*(19*16 + 19*64)
125 | simp_network_def_2, simp_resource_2 = network_utils.simplify_network_def_based_on_constraint(network_def, 1, constraint_2, "WEIGHTS")
126 | self.check_network_def(simp_network_def_2, [3, 16, 13, 64], [16, 13, 64, 10])
127 | simp_model_2 = network_utils.simplify_model_based_on_network_def(simp_network_def_2, model)
128 | update_network_def_2 = network_utils.get_network_def_from_model(simp_model_2)
129 | self.check_network_def(update_network_def_2, [3, 16, 13, 64], [16, 13, 64, 10])
130 |
131 | constraint_3 = 29232 - 3*3*(15*32 + 15*10)
132 | simp_network_def_3, simp_resource_3 = network_utils.simplify_network_def_based_on_constraint(network_def, 2, constraint_3, "WEIGHTS")
133 | self.check_network_def(simp_network_def_3, [3, 16, 32, 49], [16, 32, 49, 10])
134 | simp_model_3 = network_utils.simplify_model_based_on_network_def(simp_network_def_3, model)
135 | update_network_def_3 = network_utils.get_network_def_from_model(simp_model_3)
136 | self.check_network_def(update_network_def_3, [3, 16, 32, 49], [16, 32, 49, 10])
137 |
138 |
139 | acc_1 = network_utils.evaluate(simp_model_1)
140 | self.assertEqual(acc_1, 1, "Evaluation function error")
141 |
142 | acc_2 = network_utils.evaluate(simp_model_2)
143 | self.assertEqual(acc_2, 5, "Evaluation function error")
144 |
145 | acc_3 = network_utils.evaluate(simp_model_3)
146 | self.assertEqual(acc_3, 80, "Evaluation function error")
147 |
148 | input = torch.randn((4, 3, 32, 32)).cuda()
149 | _ = simp_model_1(input)
150 | _ = simp_model_2(input)
151 | _ = simp_model_3(input)
152 |
153 |
154 | def test_finetune(self):
155 | model_0 = copy.deepcopy(model)
156 |
157 | for layer_idx in range(4):
158 | layer = getattr(model_0.features, str(layer_idx*2))
159 | layer.weight.data = torch.zeros_like(layer.weight.data)
160 |
161 | model_finetune = network_utils.fine_tune(model_0, iterations=1)
162 | for layer_idx in range(4):
163 | layer = getattr(model_finetune.features, str(layer_idx*2))
164 | temp = (layer.weight.data == (torch.zeros_like(layer.weight.data).cuda() + layer_idx))
165 | temp = torch.min(temp)
166 | temp = temp.item()
167 |             self.assertTrue(temp, "Finetune function error when iteration = 1")
168 |
169 | model_finetune_0 = network_utils.fine_tune(model_0, iterations=0)
170 | for layer_idx in range(4):
171 | layer = getattr(model_finetune_0.features, str(layer_idx*2))
172 | temp = (layer.weight.data == (torch.zeros_like(layer.weight.data).cuda()))
173 | temp = torch.min(temp)
174 | temp = temp.item()
175 | self.assertTrue(temp, "Finetune function error when iteration = 0")
176 |
177 |
178 | if __name__ == '__main__':
179 | unittest.main()
180 |
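A minimal sketch, inferred from the assertions above, of how the helloworld network_def and its resource numbers fit together; the dict below and the helper name are illustrative only (the repo may store network_def as a list of tuples):

    network_def = {0: (3, 16), 1: (16, 32), 2: (32, 64), 3: (64, 10)}  # layer -> (in_channels, out_channels), all 3x3 convs

    def count_weights(network_def):
        return sum(3 * 3 * c_in * c_out for c_in, c_out in network_def.values())

    count_weights(network_def)            # 29232, the WEIGHTS figure used throughout these tests
    count_weights(network_def) * 32 * 32  # the FLOPS figure: every layer runs on 32x32 feature maps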
--------------------------------------------------------------------------------
/unittest/unittest_network_utils_mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import sys
4 | import pickle
5 | sys.path.append(os.path.abspath('../'))
6 |
7 | import network_utils as networkUtils
8 | import unittest
9 | import nets as models
10 | from constants import *
11 | import copy
12 |
13 | MODEL_ARCH = 'mobilenet'
14 | INPUT_DATA_SHAPE = (3, 224, 224)
15 | LOOKUP_TABLE_PATH = os.path.join('../models', MODEL_ARCH, 'lut.pkl')
16 | DATASET_PATH = '../data/'
17 |
18 | model = models.__dict__[MODEL_ARCH](num_classes=10)
19 | network_utils = networkUtils.__dict__[MODEL_ARCH](model, INPUT_DATA_SHAPE, DATASET_PATH)
20 |
21 | class TestNetworkUtils_mobilenet(unittest.TestCase):
22 | def __init__(self, *args, **kwargs):
23 | super(TestNetworkUtils_mobilenet, self).__init__(*args, **kwargs)
24 |
25 |
26 | def check_network_def(self, network_def, input_channels, output_channels, only_num_channels=False):
27 | self.assertEqual(len(network_def), 28, "network_def length error")
28 | layer_idx = 0
29 | for layer_name, layer_properties in network_def.items():
30 | self.assertEqual(layer_properties[KEY_NUM_IN_CHANNELS], input_channels[layer_idx], "network_def num of input channels error")
31 | self.assertEqual(layer_properties[KEY_NUM_OUT_CHANNELS], output_channels[layer_idx], "network_def num of output channels error")
32 |
33 | if layer_idx % 2 == 1 and layer_idx != 27:
34 | self.assertTrue(layer_properties[KEY_IS_DEPTHWISE], "network_def is_depthwise error")
35 | self.assertEqual(layer_properties[KEY_GROUPS], layer_properties[KEY_NUM_IN_CHANNELS], "network_def group error")
36 | else:
37 | self.assertFalse(layer_properties[KEY_IS_DEPTHWISE], "network_def is_depthwise error")
38 | self.assertEqual(layer_properties[KEY_GROUPS], 1, "network_def group error")
39 | if layer_idx == 27 or (layer_idx % 2 == 0 and layer_idx != 0):
40 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], (1, 1), "network_def kernel size error")
41 | self.assertEqual(layer_properties[KEY_PADDING], (0, 0), "network_def padding error")
42 | else:
43 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], (3, 3), "network_def kernel size error")
44 | self.assertEqual(layer_properties[KEY_PADDING], (1, 1), "network_def padding error")
45 | if layer_idx != 27:
46 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Conv2d', "network_def layer type string error")
47 | else:
48 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Linear', "network_def layer type string error")
49 | if layer_idx in [0, 3, 7, 11, 23]:
50 | self.assertEqual(layer_properties[KEY_STRIDE], (2, 2), "network_def stride error")
51 | else:
52 | self.assertEqual(layer_properties[KEY_STRIDE], (1, 1), "network_def stride error")
53 | input_feature_map_spatial_size = [224, 112, 112, 112, 56, 56, 56, 56, 28, 28, 28, 28, 14,
54 | 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 7, 7, 7, 1]
55 | output_feature_map_spatial_size = [112, 112, 112, 56, 56, 56, 56, 28, 28, 28, 28, 14,
56 | 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 7, 7, 7, 7, 1]
57 | if not only_num_channels:
58 | self.assertEqual(layer_properties[KEY_INPUT_FEATURE_MAP_SIZE], [1, input_channels[layer_idx],
59 | input_feature_map_spatial_size[layer_idx], input_feature_map_spatial_size[layer_idx]],
60 | "network_def input feature map size error")
61 | self.assertEqual(layer_properties[KEY_OUTPUT_FEATURE_MAP_SIZE], [1, output_channels[layer_idx],
62 | output_feature_map_spatial_size[layer_idx], output_feature_map_spatial_size[layer_idx]],
63 | "network_def output feature map size error")
64 | #print(layer_idx)
65 | layer_idx += 1
66 |
67 |
68 | def gen_layer_weight(self, tensor):
69 | gen_tensor = torch.zeros_like(tensor)
70 | for i in range(gen_tensor.shape[0]):
71 | gen_tensor[i, ::] += i
72 | return gen_tensor
73 |
74 |
75 | def test_network_def(self):
76 | network_def = network_utils.get_network_def_from_model(model)
77 | #print(network_def)
78 | #print(len(network_def))
79 | input_channels = [3, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
80 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024]
81 | output_channels = [32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
82 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024, 10]
83 | self.check_network_def(network_def, input_channels, output_channels)
84 | self.assertEqual(network_utils.get_num_simplifiable_blocks(), 14, "Num of simplifiable blocks error")
85 |
86 |
87 | def test_compute_resource(self):
88 | network_def = network_utils.get_network_def_from_model(model)
89 | num_w = network_utils.compute_resource(network_def, 'WEIGHTS')
90 | num_mac = network_utils.compute_resource(network_def, 'FLOPS')
91 | self.assertEqual(num_w, 3195328, "Num of weights error")
92 | self.assertEqual(num_mac, 567726592, "Num of MACs error")
93 |
94 |
95 | def test_extra_history_info(self):
96 | network_def = network_utils.get_network_def_from_model(model)
97 | output_feature_info = network_utils.extra_history_info(network_def)
98 | output_channels = [32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
99 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024, 10]
100 | output_channels_str = [str(x) for x in output_channels]
101 | output_feature_info_gt = ' '.join(output_channels_str)
102 | self.assertEqual(output_feature_info, output_feature_info_gt, "extra_history_info error")
103 |
104 |
105 | def delta_to_layer_num_channels(self, delta, simp_block_idx):
106 | input_channels_gt = [3, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
107 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024]
108 | output_channels_gt = [32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
109 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024, 10]
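        # mapping assumed here: block 0 is the first full conv (layers 0-2 change);
        # block b for 1 <= b <= 12 is the pointwise conv at layer 2b, so the depthwise conv at
        # layer 2b+1 and the input of the pointwise conv at layer 2b+2 shrink with it;
        # block 13 is the last pointwise conv (layer 26), which only feeds the FC layer (layer 27)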
110 |
111 | if simp_block_idx == 0:
112 | input_channels_gt[simp_block_idx + 1] = input_channels_gt[simp_block_idx + 1] - delta
113 | input_channels_gt[simp_block_idx + 2] = input_channels_gt[simp_block_idx + 2] - delta
114 | output_channels_gt[simp_block_idx] = output_channels_gt[simp_block_idx] - delta
115 | output_channels_gt[simp_block_idx+1] = output_channels_gt[simp_block_idx+1] - delta
116 | elif simp_block_idx != 13:
117 | print(input_channels_gt)
118 | print(output_channels_gt)
119 | input_channels_gt[2*simp_block_idx+1] = input_channels_gt[2*simp_block_idx+1] - delta
120 | input_channels_gt[2*simp_block_idx+2] = input_channels_gt[2*simp_block_idx+2] - delta
121 | output_channels_gt[2*simp_block_idx] = output_channels_gt[2*simp_block_idx] - delta
122 | output_channels_gt[2*simp_block_idx+1]= output_channels_gt[2*simp_block_idx+1] - delta
123 | else:
124 | output_channels_gt[2*simp_block_idx] = output_channels_gt[2*simp_block_idx] - delta
125 | input_channels_gt[2*simp_block_idx+1] = input_channels_gt[2*simp_block_idx+1] - delta
126 | return input_channels_gt, output_channels_gt
127 |
128 |
129 | def run_simplify_network_def_and_check_for_one_resource_type(self, constraint, resource_type, simp_block_indices, delta, res_gt):
130 | network_def = network_utils.get_network_def_from_model(model)
131 |
132 | for i in range(len(simp_block_indices)):
133 | simp_block_idx = simp_block_indices[i]
134 | simp_network_def, simp_resource = network_utils.simplify_network_def_based_on_constraint(network_def, simp_block_idx, constraint, resource_type)
135 | self.assertEqual(simp_resource, res_gt[i], "Simplified network resource {} error".format(resource_type))
136 | input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta[i], simp_block_idx)
137 | self.check_network_def(simp_network_def, input_channels_gt, output_channels_gt, only_num_channels=True)
138 |
139 |
140 | def test_simplify_network_def_based_on_constraint(self):
141 | total_num_w = 3195328
142 | total_num_mac = 567726592
143 | constraint_num_w = total_num_w*0.975
144 | constraint_num_mac = total_num_mac*0.975
145 |
146 | simp_block_indices = [0, 1, 5, 7, 9, 11, 13]
147 | delta_w = [24, 56, 104, 80, 80, 56, 80]
148 | delta_mac = [16, 24, 48, 72, 72, 96, 288]
149 |
150 | num_w_gt = [3192928, 3185864, 3114520, 3112688, 3112688, 3108808, 3112608]
151 | num_mac_gt = [547656192, 547781632, 553191232, 553148896, 553148896, 553233568, 553273024]
152 |
153 | self.run_simplify_network_def_and_check_for_one_resource_type(constraint=constraint_num_w,
154 | resource_type="WEIGHTS", simp_block_indices=simp_block_indices,
155 | delta=delta_w, res_gt=num_w_gt)
156 | self.run_simplify_network_def_and_check_for_one_resource_type(constraint=constraint_num_mac,
157 | resource_type="FLOPS", simp_block_indices=simp_block_indices,
158 | delta=delta_mac, res_gt=num_mac_gt)
159 |
160 |
161 | def test_simplify_model_based_on_network_def(self):
162 | network_def = network_utils.get_network_def_from_model(model)
163 | total_num_w = 3195328
164 | constraint_num_w = total_num_w*0.975
165 | simp_block_indices = [0, 1, 5, 7, 9, 11, 13]
166 | delta_w = [24, 56, 104, 80, 80, 56, 80]
167 | topk_w = [8, 8, 152, 432, 432, 456, 944]
168 |
169 | for i in range(len(simp_block_indices)):
170 | simp_block_idx = simp_block_indices[i]
171 | simp_network_def, _ = network_utils.simplify_network_def_based_on_constraint(network_def,
172 | simp_block_idx, constraint_num_w, "WEIGHTS")
173 | simp_model = network_utils.simplify_model_based_on_network_def(simp_network_def, model)
174 | updated_network_def = network_utils.get_network_def_from_model(simp_model)
175 | input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta_w[i], simp_block_idx)
176 | self.check_network_def(updated_network_def, input_channels_gt, output_channels_gt)
177 |
178 | conv_layers = getattr(model, 'model')
179 | simp_conv_layers = getattr(simp_model, 'model')
180 | for block_idx in range(14):
181 | module = getattr(conv_layers, str(block_idx))
182 | simp_module = getattr(simp_conv_layers, str(block_idx))
183 | if block_idx != simp_block_idx and block_idx != simp_block_idx + 1:
184 | if block_idx != 0:
185 | for layer_idx in ['0', '1', '3', '4']:
186 | layer = getattr(module, layer_idx)
187 | simp_layer = getattr(simp_module, layer_idx)
188 | if layer_idx in ['0', '3']:
189 | equal = (simp_layer.weight.data == layer.weight.data)
190 | self.assertTrue(equal.min(), "simplify_model_based_on_network_def modify unrelated conv layers")
191 | else:
192 | equal_weight = (simp_layer.weight.data == layer.weight.data)
193 | equal_bias = (simp_layer.bias.data == layer.bias.data)
194 | equal_num_features = (simp_layer.num_features == layer.num_features)
195 | self.assertTrue(equal_weight.min(), "simplify_model_based_on_network_def modify unrelated batchnorm layers (weight)")
196 | self.assertTrue(equal_bias.min(), "simplify_model_based_on_network_def modify unrelated batchnorm layers (bias)")
197 | self.assertTrue(equal_num_features, "simplify_model_based_on_network_def modify unrelated batchnorm layers (num_features)")
198 | else:
199 | layer = getattr(module, '0')
200 | simp_layer = getattr(simp_module, '0')
201 | equal = (simp_layer.weight.data == layer.weight.data)
202 | self.assertTrue(equal.min(), "simplify_model_based_on_network_def modify unrelated conv layers")
203 |
204 | layer = getattr(module, '1')
205 | simp_layer = getattr(simp_module, '1')
206 | equal_weight = (simp_layer.weight.data == layer.weight.data)
207 | equal_bias = (simp_layer.bias.data == layer.bias.data)
208 | equal_num_features = (simp_layer.num_features == layer.num_features)
209 | self.assertTrue(equal_weight.min(), "simplify_model_based_on_network_def modify unrelated batchnorm layers (weight)")
210 | self.assertTrue(equal_bias.min(), "simplify_model_based_on_network_def modify unrelated batchnorm layers (bias)")
211 | self.assertTrue(equal_num_features, "simplify_model_based_on_network_def modify unrelated batchnorm layers (num_features)")
212 |
213 | elif block_idx == simp_block_idx:
214 | # check (regular/pointwise layer output channels and input channels of the next depthwise layer)
215 |                 # or check (pointwise layer output channels and input features of the next FC layer)
216 | if block_idx == 0:
217 | layer = getattr(module, '0')
218 | simp_layer = getattr(simp_module, '0')
219 | else: # pointwise
220 | # first check depthwise layer within the same block
221 | layer = getattr(module, '0')
222 |                     simp_layer = getattr(simp_module, '0')
223 | equal_dep = (layer.weight.data == simp_layer.weight.data)
224 | self.assertTrue(equal_dep.min(), "Depthwise layer within the target block error")
225 |
226 | layer = getattr(module, '3')
227 | simp_layer = getattr(simp_module, '3')
228 |
229 | layer_weight = layer.weight.data
230 | weight_vector = layer_weight.view(layer_weight.shape[0], -1)
231 | weight_norm = weight_vector*weight_vector
232 | weight_norm = torch.sum(weight_norm, dim=1)
233 | _, kept_filter_idx = torch.topk(weight_norm, topk_w[i], sorted=False)
234 | kept_filter_idx, _ = kept_filter_idx.sort()
235 |
236 | weight_gt = layer_weight[kept_filter_idx, :, :, :]
237 | weight_simp = simp_layer.weight.data
238 | equal_weight = (weight_gt == weight_simp)
239 |
240 | self.assertTrue(equal_weight.min(), "Output channels of the pruned layer error")
241 |
242 |                 # check that simplification also updated the input channels of the next few layers
243 | if block_idx != 13: # depthwise -> batchnorm -> pointwise
244 | next_module = getattr(conv_layers, str(block_idx+1))
245 | simp_next_module = getattr(simp_conv_layers, str(block_idx+1))
246 |
247 | dep_layer = getattr(next_module, '0')
248 | simp_dep_layer = getattr(simp_next_module, '0')
249 | dep_layer_weight = dep_layer.weight.data[kept_filter_idx, :, :, :]
250 | equal_dep_weights = (dep_layer_weight == simp_dep_layer.weight.data)
251 | self.assertTrue(equal_dep_weights.min(), "Input channels of the depthwise layer after pruned layers error")
252 |
253 | batchnorm_layer = getattr(next_module, '1')
254 | simp_batchnorm_layer = getattr(simp_next_module, '1')
255 | batchnorm_layer_weight = batchnorm_layer.weight.data[kept_filter_idx]
256 | equal_batchnorm_weights = (batchnorm_layer_weight == simp_batchnorm_layer.weight.data)
257 | self.assertTrue(equal_batchnorm_weights.min(), "Weights of the batchnorm layer after pruned layers error")
258 |
259 | batchnorm_layer_bias = batchnorm_layer.bias.data[kept_filter_idx]
260 | equal_batchnorm_bias = (batchnorm_layer_bias == simp_batchnorm_layer.bias.data)
261 | self.assertTrue(equal_batchnorm_bias.min(), "Biases of the batchnorm layer after pruned layers error")
262 |
263 | equal_batchnorm_num_features = (len(kept_filter_idx) == simp_batchnorm_layer.num_features)
264 | self.assertTrue(equal_batchnorm_num_features, "Number of features of the batchnorm layer after pruned layers error")
265 |
266 | pt_layer = getattr(next_module, '3')
267 | simp_pt_layer = getattr(simp_next_module, '3')
268 | pt_layer_weight = pt_layer.weight.data[:, kept_filter_idx, :, :]
269 | equal_pt_weights = (pt_layer_weight == simp_pt_layer.weight.data)
270 | self.assertTrue(equal_pt_weights.min(), "Input channels of the pointwise layer after pruned layers error")
271 |
272 | else: # FC
273 | fc_layer = getattr(model, 'fc')
274 | simp_fc_layer = getattr(simp_model, 'fc')
275 | fc_layer_weight = fc_layer.weight.data
276 | fc_layer_weight = fc_layer_weight[:, kept_filter_idx]
277 | equal_fc_weights = (fc_layer_weight == simp_fc_layer.weight.data)
278 | self.assertTrue(equal_fc_weights.min(), "Input features of FC layer error")
279 |
280 | def test_simplify_model_based_on_network_def_check_weights(self):
281 | # make sure we prune the correct filters by checking the weights of a pruned model
282 | # the weights of the original model are initialized to certain values
283 |
284 | total_num_w = 3195328
285 | constraint_num_w = total_num_w*0.975
286 | simp_block_indices = [0, 1, 5, 7, 9, 11, 13]
287 | delta_w = [24, 56, 104, 80, 80, 56, 80]
288 | topk_w = [8, 8, 152, 432, 432, 456, 944]
289 |
290 | # initialize model weights
291 | model_init = copy.deepcopy(model)
292 | conv_layers = getattr(model_init, 'model')
293 | for block_idx in range(14):
294 | module = getattr(conv_layers, str(block_idx))
295 |
296 | # regular/depthwise
297 | layer = getattr(module, '0')
298 | layer.weight.data = self.gen_layer_weight(layer.weight.data)
299 |
300 | if block_idx != 0: # pointwise
301 | layer = getattr(module, '3')
302 | layer.weight.data = self.gen_layer_weight(layer.weight.data)
303 | model_init.fc.weight.data = self.gen_layer_weight(model_init.fc.weight.data)
304 |
305 |
306 | network_def = network_utils.get_network_def_from_model(model_init)
307 |
308 | for i in range(len(simp_block_indices)):
309 | simp_block_idx = simp_block_indices[i]
310 | simp_network_def, _ = network_utils.simplify_network_def_based_on_constraint(network_def,
311 | simp_block_idx, constraint_num_w, "WEIGHTS")
312 | simp_model = network_utils.simplify_model_based_on_network_def(simp_network_def, model_init)
313 | updated_network_def = network_utils.get_network_def_from_model(simp_model)
314 | input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta_w[i], simp_block_idx)
315 | self.check_network_def(updated_network_def, input_channels_gt, output_channels_gt)
316 |
317 | simp_conv_layers = getattr(simp_model, 'model')
318 | for block_idx in range(14):
319 | if block_idx == simp_block_idx:
320 | simp_module = getattr(simp_conv_layers, str(block_idx))
321 |
322 | if block_idx == 0:
323 | simp_layer = getattr(simp_module, '0')
324 | else: # pointwise
325 | simp_layer = getattr(simp_module, '3')
326 |
327 | for weight_idx in range(topk_w[i]):
328 | equal_weights = (simp_layer.weight.data[weight_idx, ::] == delta_w[i] + weight_idx)
329 | self.assertTrue(equal_weights.min(), "Weights of the pruned layers error")
330 |
331 | if simp_block_idx != 13:
332 | # check the next depthwise layer
333 | simp_module = getattr(simp_conv_layers, str(block_idx+1))
334 | simp_layer = getattr(simp_module, '0')
335 | for weight_idx in range(topk_w[i]):
336 | equal_weights = (simp_layer.weight.data[weight_idx, ::] == delta_w[i] + weight_idx)
337 | self.assertTrue(equal_weights.min(), "Weights of the pruned layers error")
338 |
339 |
340 | def test_build_latency_lookup_table(self):
341 | network_def = network_utils.get_network_def_from_model(model)
342 | lookup_table_path = './unittest_lookup_table.pkl'
343 | min_conv_feature_size = 32
344 | min_fc_feature_size = 128
345 | measure_latency_batch_size = 1
346 | measure_latency_sample_times = 1
347 |
348 | network_utils.build_lookup_table(network_def, 'LATENCY', lookup_table_path, min_conv_feature_size,
349 | min_fc_feature_size, measure_latency_batch_size, measure_latency_sample_times)
350 |
351 | with open(lookup_table_path, 'rb') as file_id:
352 | lookup_table = pickle.load(file_id)
353 | self.assertEqual(len(lookup_table), 28, "Lookup table length error")
354 |
355 | input_channels_gt = [3, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
356 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024]
357 | output_channels_gt = [32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
358 | 512, 512, 512, 512, 512, 512, 1024, 1024, 1024, 10]
359 | layer_idx = 0
360 |
361 | dep_layer_latency_dict_list = []
362 | pt_layer_latency_dict_list = []
363 |
364 | for layer_name, layer_properties in lookup_table.items():
365 | self.assertEqual(layer_properties[KEY_IS_DEPTHWISE], network_def[layer_name][KEY_IS_DEPTHWISE], "lookup table layer properties error (is_depthwise)")
366 | self.assertEqual(layer_properties[KEY_NUM_IN_CHANNELS], network_def[layer_name][KEY_NUM_IN_CHANNELS], "lookup table layer properties error (num_in_channels)")
367 | self.assertEqual(layer_properties[KEY_NUM_OUT_CHANNELS], network_def[layer_name][KEY_NUM_OUT_CHANNELS], "lookup table layer properties error (num_out_channels)")
368 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], network_def[layer_name][KEY_KERNEL_SIZE], "lookup table layer properties error (kernel_size)")
369 | self.assertEqual(layer_properties[KEY_STRIDE], network_def[layer_name][KEY_STRIDE], "lookup table layer properties error (stride)")
370 | self.assertEqual(layer_properties[KEY_PADDING], network_def[layer_name][KEY_PADDING], "lookup table layer properties error (padding)")
371 | self.assertEqual(layer_properties[KEY_GROUPS], network_def[layer_name][KEY_GROUPS], "lookup table layer properties error (groups)")
372 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], network_def[layer_name][KEY_LAYER_TYPE_STR], "lookup table layer properties error (layer_type_str)")
373 | self.assertEqual(layer_properties[KEY_INPUT_FEATURE_MAP_SIZE], network_def[layer_name][KEY_INPUT_FEATURE_MAP_SIZE], "lookup table layer properties error (input_feature_size)")
374 |
375 | layer_latency_table = layer_properties[KEY_LATENCY]
376 |
377 | num_in_samples = input_channels_gt[layer_idx]
378 | num_output_samples = output_channels_gt[layer_idx]
379 | if layer_idx != 27:
380 | if num_in_samples < min_conv_feature_size:
381 | num_in_samples = 1
382 | else:
383 | num_in_samples = num_in_samples/min_conv_feature_size
384 | num_output_samples = num_output_samples/min_conv_feature_size
385 | else:
386 | num_in_samples = num_in_samples/min_fc_feature_size
387 | if num_output_samples < min_fc_feature_size:
388 | num_output_samples = 1
389 | if layer_idx != 27 and layer_idx % 2 == 1:
390 | self.assertEqual(len(layer_latency_table), num_in_samples, "Layerwise latency dict length error (layer index: {})".format(layer_idx))
391 | else:
392 | self.assertEqual(len(layer_latency_table), num_in_samples*num_output_samples, "Layerwise latency dict length error (layer index: {})".format(layer_idx))
393 |
394 | if layer_idx >= 13 and layer_idx <= 22:
395 | if layer_idx % 2 == 0: # pointwise layer
396 | pt_layer_latency_dict_list.append(layer_latency_table)
397 | else: # depthwise layer
398 | dep_layer_latency_dict_list.append(layer_latency_table)
399 |
400 | layer_idx += 1
401 |
402 | # check whether same layers have the same results
403 | for i in range(1, len(dep_layer_latency_dict_list)):
404 | latency_dict_gt = dep_layer_latency_dict_list[0]
405 | latency_dict = dep_layer_latency_dict_list[i]
406 | for key, value in latency_dict_gt.items():
407 | self.assertEqual(latency_dict_gt[key], latency_dict[key], "Lookup table of same depthwise layers ({}) error".format(i))
408 |
409 | for i in range(1, len(pt_layer_latency_dict_list)):
410 | latency_dict_gt = pt_layer_latency_dict_list[0]
411 | latency_dict = pt_layer_latency_dict_list[i]
412 | for key, value in latency_dict_gt.items():
413 | self.assertEqual(latency_dict_gt[key], latency_dict[key], "Lookup table of same pointwise layers ({}) error".format(i))
414 |
415 | os.remove(lookup_table_path)
416 |
417 | if __name__ == '__main__':
418 | unittest.main()
419 |
--------------------------------------------------------------------------------
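The pruned-filter checks in the test above reconstruct the expected result by ranking each filter by its squared L2 norm and keeping the top k. A minimal standalone sketch of that selection criterion (an illustration only, not the repo's `simplify_model_based_on_network_def`; the helper name and example shapes are made up):

```python
import torch
import torch.nn as nn

def select_filters_by_magnitude(conv: nn.Conv2d, num_kept: int) -> torch.Tensor:
    """Return sorted indices of the `num_kept` filters with the largest squared L2 norm."""
    weight = conv.weight.data                          # (out_channels, in_channels, kH, kW)
    weight_vector = weight.view(weight.shape[0], -1)   # flatten each filter
    weight_norm = (weight_vector * weight_vector).sum(dim=1)
    _, kept_filter_idx = torch.topk(weight_norm, num_kept, sorted=False)
    kept_filter_idx, _ = kept_filter_idx.sort()
    return kept_filter_idx

# Usage: keep the 8 strongest filters of a 32-filter pointwise conv,
# then slice its weights along the output-channel dimension.
conv = nn.Conv2d(16, 32, kernel_size=1)
kept_filter_idx = select_filters_by_magnitude(conv, 8)
pruned_weight = conv.weight.data[kept_filter_idx, :, :, :]
```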
/unittest/unittest_worker_helloworld.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import sys
4 | sys.path.append(os.path.abspath('../'))
5 |
6 | import network_utils as networkUtils
7 | import unittest
8 | from constants import *
9 | import subprocess
10 | import time
11 | import common
12 | import sys
13 | import nets as models
14 | import shutil
15 |
16 |
17 | MODEL_ARCH = 'helloworld'
18 | INPUT_DATA_SHAPE = (3, 32, 32)
19 |
20 | FLOPS_LOOKUP_TABLE_PATH = os.path.join('../models', MODEL_ARCH, 'lut.pkl')
21 |
22 | WORKER_FOLDER = os.path.join('../models', MODEL_ARCH, 'unittest_worker')
23 | if not os.path.exists(WORKER_FOLDER):
24 | os.mkdir(WORKER_FOLDER)
25 | print('Created directory', WORKER_FOLDER)
26 | _WORKER_PY_FILENAME = '../worker.py'
27 |
28 | MODEL_PATH = os.path.join('../models', MODEL_ARCH, 'model_0.pth.tar')
29 |
30 | model = models.__dict__[MODEL_ARCH]()
31 | for i in range(4):
32 | layer = getattr(model.features, str(i*2))
33 | layer.weight.data = torch.zeros_like(layer.weight.data)
34 | torch.save(model, MODEL_PATH)
35 |
36 | DATASET_PATH = './'
37 | network_utils = networkUtils.__dict__[MODEL_ARCH](model, INPUT_DATA_SHAPE, DATASET_PATH)
38 |
39 |
40 | class TestWorker_helloworld(unittest.TestCase):
41 | def __init__(self, *args, **kwargs):
42 | super(TestWorker_helloworld, self).__init__(*args, **kwargs)
43 |
44 |
45 | def run_worker(self, constraint, netadapt_iteration, block, resource_type, short_term_fine_tune_iteration, finetune_lr=0.001, lookup_table_path=''):
46 | gpu = 0
47 | with open(os.path.join(WORKER_FOLDER, common.WORKER_LOG_FILENAME_TEMPLATE.format(netadapt_iteration, block)), 'w') as file_id:
48 | command_list = [sys.executable, _WORKER_PY_FILENAME, WORKER_FOLDER, \
49 | MODEL_PATH, str(block), resource_type, str(constraint), \
50 | str(netadapt_iteration), str(short_term_fine_tune_iteration), str(gpu), \
51 | lookup_table_path, DATASET_PATH] + [str(e) for e in INPUT_DATA_SHAPE] + [MODEL_ARCH] + [str(finetune_lr)]
52 | print(command_list)
53 | return subprocess.call(command_list, stdout=file_id, stderr=file_id)
54 | #return os.system(' '.join(command_list))
55 |
56 |
57 | def check_worker_simplify_and_finetune(self, constraint, netadapt_iteration, block, resource_type,
58 | short_term_fine_tune_iteration, resource_gt, acc_gt, network_def_gt,
59 | finetune_lr=0.001, lookup_table_path=''):
60 | t = time.time()
61 | returncode = self.run_worker(constraint, netadapt_iteration, block,
62 | resource_type, short_term_fine_tune_iteration,
63 | finetune_lr=finetune_lr, lookup_table_path=lookup_table_path)
64 | print('Worker finish time: {}s'.format(time.time() - t))
65 |
66 | # Check return code
67 | self.assertEqual(returncode, 0, "Normal worker return value error")
68 | saved_model = torch.load(os.path.join(WORKER_FOLDER,
69 | common.WORKER_MODEL_FILENAME_TEMPLATE.format(netadapt_iteration, block)))
70 | acc = network_utils.evaluate(saved_model)
71 | network_def_saved_model = network_utils.get_network_def_from_model(saved_model)
72 | res = network_utils.compute_resource(network_def_saved_model, resource_type=resource_type, lookup_table_path=lookup_table_path)
73 |
74 | with open(os.path.join(WORKER_FOLDER, common.WORKER_ACCURACY_FILENAME_TEMPLATE.format(netadapt_iteration, block)),
75 | 'r') as file_id:
76 | saved_acc = float(file_id.read())
77 | with open(os.path.join(WORKER_FOLDER, common.WORKER_RESOURCE_FILENAME_TEMPLATE.format(netadapt_iteration, block)),
78 | 'r') as file_id:
79 | saved_res = float(file_id.read())
80 |
81 | self.assertEqual(acc, acc_gt, "Evaluation of simplified model error")
82 | self.assertEqual(acc, saved_acc, "The value in the accuracy file is not equal to the accuracy of the simplified model")
83 | self.assertEqual(res, resource_gt, "Resource of simplified model error")
84 | self.assertEqual(res, saved_res, "The value in the resource file is not equal to the resource of the simplified model")
85 |
86 | for idx in range(4):
87 | layer = getattr(saved_model.features, str(idx*2))
88 | temp = (layer.weight.data == (torch.zeros_like(layer.weight.data) + idx*short_term_fine_tune_iteration))
89 | temp = torch.min(temp)
90 | temp = temp.item()
91 | self.assertTrue(temp, "Model weights after short-term fine-tune are incorrect")
92 |
93 | self.assertEqual(network_def_saved_model[idx], network_def_gt[idx], "network_def of simplified model is incorrect")
94 | return
95 |
96 |
97 | def test_worker_simplify_and_finetune_weights(self):
98 | '''
99 | Check simplifying and finetuning block 0~2
100 | '''
101 |
102 | netadapt_iteration = 2
103 | all_resource_weights = 29232
104 |
105 | constraint_weights = [all_resource_weights - 3*3*(8*(3 + 32)),
106 | all_resource_weights - 3*3*(7*(16 + 64)),
107 | all_resource_weights - 3*3*(31*(32 + 10))]
108 | resource_weights_gt = [all_resource_weights - 3*3*(8*(3 + 32)),
109 | all_resource_weights - 3*3*(7*(16 + 64)),
110 | all_resource_weights - 3*3*(31*(32 + 10))]
111 | eval_acc_gt = [1, 5, 80]
112 |
113 | network_def_gt = [
114 | [(3, 8), (8, 32), (32, 64), (64, 10)],
115 | [(3, 16), (16, 25), (25, 64), (64, 10)],
116 | [(3, 16), (16, 32), (32, 33), (33, 10)]
117 | ]
118 |
119 | for block in range(3):
120 | self.check_worker_simplify_and_finetune(constraint=constraint_weights[block],
121 | netadapt_iteration=netadapt_iteration, block=block,
122 | resource_type="WEIGHTS", short_term_fine_tune_iteration=block, finetune_lr=0.001,
123 | lookup_table_path='', resource_gt=resource_weights_gt[block], acc_gt=eval_acc_gt[block],
124 | network_def_gt=network_def_gt[block])
125 | return
126 |
127 |
128 | def test_worker_simplify_and_finetune_flops(self):
129 | '''
130 | Check simplifying and finetuning block 0~2
131 | '''
132 |
133 | netadapt_iteration = -1
134 | all_resource_flops = 29232*32*32
135 |
136 | constraint_flops = [all_resource_flops - 3*3*(8*(3 + 32))*32*32,
137 | all_resource_flops - 3*3*(7*(16 + 64))*32*32,
138 | all_resource_flops - 3*3*(31*(32 + 10))*32*32]
139 | resource_flops_gt = [all_resource_flops - 3*3*(8*(3 + 32))*32*32,
140 | all_resource_flops - 3*3*(7*(16 + 64))*32*32,
141 | all_resource_flops - 3*3*(31*(32 + 10))*32*32]
142 |
143 |
144 | eval_acc_gt = [1, 5, 80]
145 |
146 | network_def_gt = [
147 | [(3, 8), (8, 32), (32, 64), (64, 10)],
148 | [(3, 16), (16, 25), (25, 64), (64, 10)],
149 | [(3, 16), (16, 32), (32, 33), (33, 10)]
150 | ]
151 |
152 | for block in range(3):
153 | self.check_worker_simplify_and_finetune(constraint=constraint_flops[block],
154 | netadapt_iteration=netadapt_iteration, block=block,
155 | resource_type="FLOPS", short_term_fine_tune_iteration=block, finetune_lr=0.001,
156 | lookup_table_path=FLOPS_LOOKUP_TABLE_PATH, resource_gt=resource_flops_gt[block], acc_gt=eval_acc_gt[block],
157 | network_def_gt=network_def_gt[block])
158 | return
159 |
160 |
161 | def test_worker_weights_tight_constraint(self):
162 | '''
163 | Check simplifying and finetuning block 0~2
164 | '''
165 |
166 | netadapt_iteration = 2
167 | all_resource_weights = 29232
168 |
169 | constraint_weights = [0, 1, -1]
170 | resource_weights_gt = [all_resource_weights - 3*3*(15*(3 + 32)),
171 | all_resource_weights - 3*3*(31*(16 + 64)),
172 | all_resource_weights - 3*3*(63*(32 + 10))]
173 | eval_acc_gt = [1, 5, 80]
174 |
175 | network_def_gt = [
176 | [(3, 1), (1, 32), (32, 64), (64, 10)],
177 | [(3, 16), (16, 1), (1, 64), (64, 10)],
178 | [(3, 16), (16, 32), (32, 1), (1, 10)]
179 | ]
180 |
181 | for block in range(3):
182 | self.check_worker_simplify_and_finetune(constraint=constraint_weights[block],
183 | netadapt_iteration=netadapt_iteration, block=block,
184 | resource_type="WEIGHTS", short_term_fine_tune_iteration=block, finetune_lr=0.001,
185 | lookup_table_path='', resource_gt=resource_weights_gt[block], acc_gt=eval_acc_gt[block],
186 | network_def_gt=network_def_gt[block])
187 | return
188 |
189 |
190 | def test_worker_block_out_of_bound(self):
191 | netadapt_iteration = -5
192 | all_resource_flops = 29232*32*32
193 | returncode = self.run_worker(constraint=all_resource_flops,
194 | netadapt_iteration=netadapt_iteration, block=5,
195 | resource_type="FLOPS", short_term_fine_tune_iteration=5, finetune_lr=0.001)
196 |
197 | self.assertEqual(returncode, 1, "Abnormal worker not detected")
198 |
199 |
200 | if __name__ == '__main__':
201 | unittest.main(exit=False)  # exit=False so that the cleanup below still runs
202 | shutil.rmtree(WORKER_FOLDER)
--------------------------------------------------------------------------------
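`run_worker` above drives `worker.py` as a subprocess with positional arguments. A hypothetical expansion of `command_list` for the first WEIGHTS test case (block 0, netadapt iteration 2, constraint 29232 - 3\*3\*(8\*(3 + 32)) = 26712), with `'python'` standing in for `sys.executable`:

```python
command_list = [
    'python', '../worker.py',
    '../models/helloworld/unittest_worker',   # worker_folder
    '../models/helloworld/model_0.pth.tar',   # model_path
    '0',                                      # block to simplify
    'WEIGHTS',                                # resource_type
    '26712',                                  # constraint
    '2',                                      # netadapt_iteration
    '0',                                      # short_term_fine_tune_iteration (== block in the test)
    '0',                                      # gpu index
    '',                                       # lookup_table_path (not needed for WEIGHTS)
    './',                                     # dataset_path
    '3', '32', '32',                          # input_data_shape
    'helloworld',                             # arch
    '0.001',                                  # finetune_lr
]
```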
/worker.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | import os
4 | import common
5 | import constants
6 | import network_utils as networkUtils
7 |
8 | '''
9 | Launched by `master.py`
10 |
11 | Simplifies one block of a model and then fine-tunes it for several iterations.
12 | '''
13 |
14 |
15 | # Supported network_utils
16 | network_utils_all = sorted(name for name in networkUtils.__dict__
17 | if name.islower() and not name.startswith("__")
18 | and callable(networkUtils.__dict__[name]))
19 |
20 |
21 | def worker(args):
22 | """
23 | The main function of the worker.
24 | `worker.py` loads a pretrained model, simplifies one specific block of it, and short-term fine-tunes the pruned model.
25 | The accuracy and resource consumption of the simplified model are then recorded.
26 | `worker.py` signals completion by writing a finish file, which is utilized by `master.py`.
27 |
28 | Input:
29 | args: command-line arguments
30 |
31 | raise:
32 | ValueError: when the block index >= the number of simplifiable blocks (i.e. simplifying a nonexistent block or the output layer)
33 | """
34 |
35 | # Set the GPU.
36 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
37 |
38 | # Get the network utils.
39 | model = torch.load(args.model_path)
40 | network_utils = networkUtils.__dict__[args.arch](model, args.input_data_shape, args.dataset_path, args.finetune_lr)
41 |
42 | if network_utils.get_num_simplifiable_blocks() <= args.block:
43 | raise ValueError("Block index >= number of simplifiable blocks")
44 |
45 | network_def = network_utils.get_network_def_from_model(model)
46 | simplified_network_def, simplified_resource = (
47 | network_utils.simplify_network_def_based_on_constraint(network_def,
48 | args.block,
49 | args.constraint,
50 | args.resource_type,
51 | args.lookup_table_path))
52 | # Choose the filters.
53 | simplified_model = network_utils.simplify_model_based_on_network_def(simplified_network_def, model)
54 |
55 | print('Original model:')
56 | print(model)
57 | print('')
58 | print('Simplified model:')
59 | print(simplified_model)
60 |
61 | fine_tuned_model = network_utils.fine_tune(simplified_model, args.short_term_fine_tune_iteration)
62 | fine_tuned_accuracy = network_utils.evaluate(fine_tuned_model)
63 | print('Accuracy after finetune:', fine_tuned_accuracy)
64 |
65 | # Save the results.
66 | torch.save(fine_tuned_model,
67 | os.path.join(args.worker_folder,
68 | common.WORKER_MODEL_FILENAME_TEMPLATE.format(args.netadapt_iteration, args.block)))
69 | with open(os.path.join(args.worker_folder,
70 | common.WORKER_ACCURACY_FILENAME_TEMPLATE.format(args.netadapt_iteration, args.block)),
71 | 'w') as file_id:
72 | file_id.write(str(fine_tuned_accuracy))
73 | with open(os.path.join(args.worker_folder,
74 | common.WORKER_RESOURCE_FILENAME_TEMPLATE.format(args.netadapt_iteration, args.block)),
75 | 'w') as file_id:
76 | file_id.write(str(simplified_resource))
77 | with open(os.path.join(args.worker_folder,
78 | common.WORKER_FINISH_FILENAME_TEMPLATE.format(args.netadapt_iteration, args.block)),
79 | 'w') as file_id:
80 | file_id.write('finished.')
81 |
82 | # release GPU memory
83 | del simplified_model, fine_tuned_model
84 | return
85 |
86 | if __name__ == '__main__':
87 | # Parse the input arguments.
88 | arg_parser = ArgumentParser()
89 | arg_parser.add_argument('worker_folder', type=str,
90 | help='directory where model and logging information will be saved')
91 | arg_parser.add_argument('model_path', type=str, help='path to model which is to be simplified')
92 | arg_parser.add_argument('block', type=int, help='index of block to be simplified')
93 | arg_parser.add_argument('resource_type', type=str, help='FLOPS/WEIGHTS/LATENCY')
94 | arg_parser.add_argument('constraint', type=float, help='floating value specifying resource constraint')
95 | arg_parser.add_argument('netadapt_iteration', type=int, help='netadapt iteration')
96 | arg_parser.add_argument('short_term_fine_tune_iteration', type=int, help='number of iterations of fine-tuning after simplification')
97 | arg_parser.add_argument('gpu', type=str, help='index of gpu to run short-term fine-tuning')
98 | arg_parser.add_argument('lookup_table_path', type=str, default='', help='path to lookup table')
99 | arg_parser.add_argument('dataset_path', type=str, default='', help='path to dataset')
100 | arg_parser.add_argument('input_data_shape', nargs=3, default=[], type=int, help='input shape (for ImageNet: `3 224 224`)')
101 | arg_parser.add_argument('arch', default='alexnet',
102 | choices=network_utils_all,
103 | help='network_utils: ' +
104 | ' | '.join(network_utils_all) +
105 | ' (default: alexnet)')
106 | arg_parser.add_argument('finetune_lr', type=float, default=0.001, help='short-term fine-tune learning rate')
107 | args = arg_parser.parse_args()
108 |
109 | # Launch a worker.
110 | worker(args)
111 |
--------------------------------------------------------------------------------
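The accuracy, resource, and finish files written by `worker()` above are later read back on the master side. A minimal sketch of that read-back, assuming only the filename templates used in this file (the actual bookkeeping and selection logic lives in `master.py`; the helper name is made up):

```python
import os
import common

def read_worker_result(worker_folder, netadapt_iteration, block):
    """Return (accuracy, resource) written by a worker, or None if it has not finished."""
    finish_path = os.path.join(worker_folder,
                               common.WORKER_FINISH_FILENAME_TEMPLATE.format(netadapt_iteration, block))
    if not os.path.exists(finish_path):
        return None  # the worker has not written its finish file yet
    with open(os.path.join(worker_folder,
                           common.WORKER_ACCURACY_FILENAME_TEMPLATE.format(netadapt_iteration, block)), 'r') as file_id:
        accuracy = float(file_id.read())
    with open(os.path.join(worker_folder,
                           common.WORKER_RESOURCE_FILENAME_TEMPLATE.format(netadapt_iteration, block)), 'r') as file_id:
        resource = float(file_id.read())
    return accuracy, resource
```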