├── LICENSE ├── README.md ├── __init__.py ├── build_lookup_table.py ├── common.py ├── constants.py ├── data └── __init__.py ├── docs ├── alexnet │ └── README.md └── mobilenet │ └── README.md ├── eval.py ├── fig ├── netadapt_algo.png └── netadapt_fig.png ├── functions.py ├── latency_lut └── __init__.py ├── master.py ├── models ├── alexnet │ └── __init__.py ├── helloworld │ ├── lut.pkl │ └── model_0.pth.tar └── mobilenet │ └── __init__.py ├── nets ├── __init__.py ├── alexnet.py ├── helloworld.py └── mobilenet.py ├── network_utils ├── __init__.py ├── network_utils_abstract.py ├── network_utils_alexnet.py ├── network_utils_helloworld.py └── network_utils_mobilenet.py ├── requirements.txt ├── scripts ├── netadapt_alexnet-0.5latency.sh ├── netadapt_alexnet-0.5mac.sh ├── netadapt_helloworld.sh ├── netadapt_mobilenet-0.5latency.sh ├── netadapt_mobilenet-0.5mac.sh ├── unittest.sh ├── unittest_alexnet.sh ├── unittest_helloworld.sh └── unittest_mobilenet.sh ├── train.py ├── unittest ├── unittest_master_helloworld.py ├── unittest_network_utils_alexnet.py ├── unittest_network_utils_helloworld.py ├── unittest_network_utils_mobilenet.py └── unittest_worker_helloworld.py └── worker.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 denru01 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications 2 | ============================ 3 | This repo contains the official Pytorch reimplementation of the paper "NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications" 4 | [[paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Tien-Ju_Yang_NetAdapt_Platform-Aware_Neural_ECCV_2018_paper.pdf)] 5 | [[project](https://netadapt.mit.edu)]. The results in the paper were generated by the Tensorflow implementation from Google AI. 6 |

7 | photo not available 8 |

9 | 10 | ## Summary ## 11 | 0. [Requirements](#requirements) 12 | 0. [Usage](#usage) 13 | 0. [Example](#example) 14 | 0. [Customization](#customization) 15 | 0. [Citation](#citation) 16 | 17 | 29 | 30 | ## Requirements ## 31 | The code base is tested with the following setting: 32 | 33 | 1. Python 3.7.0 34 | 2. CUDA 10.0 35 | 3. Pytorch 1.2.0 36 | 4. torchvision 0.4.0 37 | 5. numpy 1.17.0 38 | 6. scipy 1.3.1 39 | 40 | First clone the repo in the directory you want to work: 41 | 42 | git clone https://github.com/denru01/netadapt.git 43 | cd netadapt 44 | 45 | In the following context, we assume you are at the repo root. 46 | 47 | If the versions of Python and CUDA are the same as yours, you can download the python packages using: 48 | 49 | pip install -r requirements.txt 50 | 51 | To verify the downloaded code base is correct, please run either 52 | 53 | sh scripts/unittest.sh 54 | 55 | or 56 | 57 | sh scripts/unittest_helloworld.sh 58 | sh scripts/unittest_alexnet.sh 59 | sh scripts/unittest_mobilenet.sh 60 | 61 | If it is correct, you should not see any FAIL. 62 | 63 | ## Usage ## 64 | In order to apply NetAdapt, run: 65 | 66 | python master.py [-h] [-gp GPUS [GPUS ...]] [-re] [-im INIT_MODEL_PATH] 67 | [-mi MAX_ITERS] [-lr FINETUNE_LR] [-bu BUDGET] 68 | [-bur BUDGET_RATIO] [-rt RESOURCE_TYPE] 69 | [-ir INIT_RESOURCE_REDUCTION] 70 | [-irr INIT_RESOURCE_REDUCTION_RATIO] 71 | [-rd RESOURCE_REDUCTION_DECAY] 72 | [-st SHORT_TERM_FINE_TUNE_ITERATION] [-lt LOOKUP_TABLE_PATH] 73 | [-dp DATASET_PATH] [-a ARCH] [-si SAVE_INTERVAL] 74 | working_folder input_data_shape input_data_shape 75 | input_data_shape 76 | 77 | - `working_folder`: Root folder where models, related files and history information are saved. You can see how models are pruned progressively in `working_folder/master/history.txt`. 78 | 79 | - `input_data_shape`: Input data shape (C, H, W) (default: 3 224 224). If you want to apply NetAdapt to different tasks, you might need to change data shape. 
80 | 81 | - `-h, --help`: Show this help message and exit. 82 | 83 | - `-gp GPUS [GPUS ...], --gpus GPUS [GPUS ...]`: Indices of available gpus (default: 0). 84 | 85 | - `-re, --resume`: Resume from previous iteration. In order to resume, specify `--resume` and specify `working_folder` as the one you want to resume. 86 | The resumed arguments will overwrite the arguments provided here. 87 | For example, suppose you want to simplify a model by pruning and finetuning for 30 iterations (under `working_folder`), but your program terminated after 20 iterations. 88 | Then you can use `--resume` to restore and continue for the last 10 iterations. 89 | 90 | - `-im INIT_MODEL_PATH, --init_model_path INIT_MODEL_PATH`: Path to pretrained model. 91 | 92 | - `-mi MAX_ITERS, --max_iters MAX_ITERS`: Maximum iteration of removing filters and short-term fine-tune (default: 10). 93 | 94 | - `-lr FINETUNE_LR, --finetune_lr FINETUNE_LR`: Short-term fine-tune learning rate (default: 0.001). 95 | 96 | - `-bu BUDGET, --budget BUDGET`: Resource constraint. If resource < `budget`, the process is terminated. 97 | 98 | - `-bur BUDGET_RATIO, --budget_ratio BUDGET_RATIO`: If `--budget` is not specified, `budget` = `budget_ratio`\*(pretrained model resource) (default: 0.25). 99 | 100 | - `-rt RESOURCE_TYPE, --resource_type RESOURCE_TYPE`: Resource constraint type (default: FLOPS). We currently support `FLOPS`, `WEIGHTS`, and `LATENCY` (device `cuda:0`). If you want to add other resource 101 | types, please modify `def compute_resource(...)` in `network_utils` python files (e.g. `network_utils/network_utils_alexnet`). 102 | 103 | - `-ir INIT_RESOURCE_REDUCTION, --init_resource_reduction INIT_RESOURCE_REDUCTION`: For each iteration, target resource = current resource - `init_resource_reduction`\*(`resource_reduction_decay`\*\*(iteration-1)). 
104 | 105 | - `-irr INIT_RESOURCE_REDUCTION_RATIO, --init_resource_reduction_ratio INIT_RESOURCE_REDUCTION_RATIO`: If `--init_resource_reduction` is not specified, 106 | `init_resource_reduction` = `init_resource_reduction_ratio`\*(pretrained model resource) (default: 0.025). 107 | 108 | - `-rd RESOURCE_REDUCTION_DECAY, --resource_reduction_decay RESOURCE_REDUCTION_DECAY`: For each iteration, target resource = current resource - `init_resource_reduction`\*(`resource_reduction_decay`\*\*(iteration-1)) (default: 0.96). 109 | 110 | - `-st SHORT_TERM_FINE_TUNE_ITERATION, --short_term_fine_tune_iteration SHORT_TERM_FINE_TUNE_ITERATION`: Short-term fine-tune iteration (default: 10). 111 | 112 | - `-lt LOOKUP_TABLE_PATH, --lookup_table_path LOOKUP_TABLE_PATH`: Path to lookup table. 113 | 114 | - `-dp DATASET_PATH, --dataset_path DATASET_PATH`: Path to dataset. 115 | 116 | - `-a ARCH, --arch ARCH network_utils`: Defines how networks are pruned, fine-tuned, and evaluated. If you want to use 117 | your own method, please see [**Customization**](#customization) and specify here. (default: alexnet) 118 | 119 | - `-si SAVE_INTERVAL, --save_interval SAVE_INTERVAL`: Interval of iterations that all pruned models at the same iteration will be saved. 120 | Use `-1` to save only the best model at each iteration. 121 | Use `1` to save all models at each iteration. (default: -1). 122 | 123 | 124 | 125 | ## Example ## 126 | 187 | 188 | We provide a simple example of applying **NetAdapt** to a very small [network](nets/helloworld.py): 189 | 190 | sh scripts/netadapt_helloworld.sh 191 | 192 | Detailed examples of applying **NetAdapt** to **AlexNet**/**MobileNet** on **CIFAR-10** are shown [**here (AlexNet)**](docs/alexnet/README.md) and [**here (MobileNet)**](docs/mobilenet/README.md). 193 | 194 |

195 | photo not available 196 |

197 | 198 | If you want to apply the algorithm to different networks or even different tasks, 199 | please see the following [**Customization**](#customization) section. 200 | 201 | 202 | ## Customization ## 203 | 204 | To apply NetAdapt to different networks or different tasks, please follow the instructions: 205 | 206 | 1. Create your own `network_utils` python file (say, `network_utils_yourNetworkOrTask.py`) and place it under `network_utils`. 207 | 208 | 2. Implement functions described in [`network_utils_abstract.py`](network_utils/network_utils_abstract.py). 209 | 210 | 3. As we provide an example of applying NetAdapt to AlexNet, you can also build your `network_utils` based on [`network_utils_alexnet.py`](network_utils/network_utils_alexnet.py): 211 | 212 | ```bash 213 | cd network_utils 214 | cp network_utils_alexnet.py ./network_utils_yourNetworkOrTask.py 215 | ``` 216 | 217 | 4. Add `from .network_utils_yourNetworkOrTask import *` to `__init__.py`, which is under [the same directory](network_utils/__init__.py). 218 | 219 | 5. Modify `class networkUtils_alexnet(...)` in [line 44](network_utils/network_utils_alexnet.py#L44) in `network_utils_yourNetworkOrTask.py` to `class networkUtils_yourNetworkOrTask(...)`. 220 | 221 | 6. Modify `def alexnet(...)` in [line 325-326](network_utils/network_utils_alexnet.py#L325-L326) to: 222 | ```bash 223 | def yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr=1e-3): 224 | return networkUtils_yourNetworkOrTask(model, input_data_shape, dataset_path, finetune_lr) 225 | ``` 226 | 227 | 7. Specify training/validation data loader, loss functions, optimizers, network architecture, training method, and evaluation method in `network_utils_yourNetworkOrTask.py` if there is any difference from the AlexNet example: 228 | 229 | - Modify data loader and loss functions in function `def __init__(...):` in [line 52](network_utils/network_utils_alexnet.py#L52-L125). 
230 | 231 | - Specify additive skip connections if there is any and modify `def simplify_network_def_based_on_constraint(...)` in `network_utils_yourNetworkOrTask.py`. 232 | You can see how our implementation uses additive skip connections [here](functions.py#L543-L549). 233 | 234 | - Modify training method (short-term finetune) in function `def fine_tune(...):` in [line 245](network_utils/network_utils_alexnet.py#L245-L288). 235 | 236 | - Modify evaluation method in function `def evaluate(...):` in [line 291](network_utils/network_utils_alexnet.py#L291-L322). 237 | 238 | You can see how these methods are utilized by the framework [here](worker.py#L39-L62). 239 | 240 | 8. Our current code base supports pruning `Conv2d`, `ConvTranspose2d`, and `Linear` with additive skip connection. 241 | If your network architecture is not supported, please modify [this](network_utils/network_utils_alexnet.py#L142-L199). 242 | If you want to use other metrics (resource type) to prune networks, please modify [this](network_utils/network_utils_alexnet.py#L234-L238). 243 | 244 | 9. We can apply NetAdapt to different networks or tasks by using `--arch yourNetworkOrTask` in `scripts/netadapt_alexnet-0.5mac.sh`. 245 | As for the values of other arguments, please see [**Usage**](#usage). 246 | Generally, if you want to apply NetAdapt to a different task, you might change `input_data_shape`. 247 | If your network architecture is very different from that of MobileNet, you would have to modify the values of `--init_resource_reduction_ratio` and `--resource_reduction_decay` to get a different resource reduction schedule. 
248 | 249 | 250 | 251 | ## Citation 252 | If you use our code or method in your work, please consider citing the following: 253 | 254 | ``` 255 | @InProceedings{eccv_2018_yang_netadapt, 256 | author = {Yang, Tien-Ju and Howard, Andrew and Chen, Bo and Zhang, Xiao and Go, Alec and Sandler, Mark and Sze, Vivienne and Adam, Hartwig}, 257 | title = {NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications}, 258 | booktitle = {The European Conference on Computer Vision (ECCV)}, 259 | month = {September}, 260 | year = {2018} 261 | } 262 | ``` 263 | 264 | Please direct any questions to the authors: Tien-Ju Yang (tjy@mit.edu) and Yi-Lun Liao (ylliao@mit.edu). 265 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/__init__.py -------------------------------------------------------------------------------- /build_lookup_table.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import nets as models 4 | import functions as fns 5 | from argparse import ArgumentParser 6 | 7 | model_names = sorted(name for name in models.__dict__ 8 | if name.islower() and not name.startswith("__") 9 | and callable(models.__dict__[name])) 10 | 11 | NUM_CLASSES = 10 12 | 13 | INPUT_DATA_SHAPE = (3, 224, 224) 14 | 15 | 16 | ''' 17 | `MIN_CONV_FEATURE_SIZE`: The sampled size of feature maps of layers (conv layer) 18 | along channel dimmension are multiples of 'MIN_CONV_FEATURE_SIZE'. 19 | 20 | `MIN_FC_FEATURE_SIZE`: The sampled size of features of FC layers are 21 | multiples of 'MIN_FC_FEATURE_SIZE'. 
22 | ''' 23 | MIN_CONV_FEATURE_SIZE = 8 24 | MIN_FC_FEATRE_SIZE = 64 25 | 26 | ''' 27 | `MEASURE_LATENCY_BATCH_SIZE`: the batch size of input data 28 | when running forward functions to measure latency. 29 | `MEASURE_LATENCY_SAMPLE_TIMES`: the number of times to run the forward function of 30 | a layer in order to get its latency. 31 | ''' 32 | MEASURE_LATENCY_BATCH_SIZE = 128 33 | MEASURE_LATENCY_SAMPLE_TIMES = 500 34 | 35 | 36 | arg_parser = ArgumentParser(description='Build latency lookup table') 37 | arg_parser.add_argument('--dir', metavar='DIR', default='latency_lut/lut_alexnet.pkl', 38 | help='path to saving lookup table') 39 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet', 40 | choices=model_names, 41 | help='model architecture: ' + 42 | ' | '.join(model_names) + 43 | ' (default: alexnet)') 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | args = arg_parser.parse_args() 49 | print(args) 50 | 51 | build_lookup_table = True 52 | lookup_table_path = args.dir 53 | model_arch = args.arch 54 | 55 | print('Load', model_arch) 56 | print('--------------------------------------') 57 | model = models.__dict__[model_arch](num_classes=NUM_CLASSES) 58 | network_def = fns.get_network_def_from_model(model, INPUT_DATA_SHAPE) 59 | for layer_name, layer_properties in network_def.items(): 60 | print(layer_name) 61 | print(' ', layer_properties, '\n') 62 | print('-------------------------------------------') 63 | 64 | num_w = fns.compute_resource(network_def, 'WEIGHTS') 65 | flops = fns.compute_resource(network_def, 'FLOPS') 66 | num_param = fns.compute_resource(network_def, 'WEIGHTS') 67 | print('Number of FLOPs: ', flops) 68 | print('Number of weights: ', num_w) 69 | print('Number of parameters: ', num_param) 70 | print('-------------------------------------------') 71 | 72 | model = model.cuda() 73 | 74 | print('Building latency lookup table for', 75 | torch.cuda.get_device_name()) 76 | if build_lookup_table: 77 | 
fns.build_latency_lookup_table(network_def, lookup_table_path=lookup_table_path, 78 | min_fc_feature_size=MIN_FC_FEATRE_SIZE, 79 | min_conv_feature_size=MIN_CONV_FEATURE_SIZE, 80 | measure_latency_batch_size=MEASURE_LATENCY_BATCH_SIZE, 81 | measure_latency_sample_times=MEASURE_LATENCY_SAMPLE_TIMES, 82 | verbose=True) 83 | print('-------------------------------------------') 84 | print('Finish building latency lookup table.') 85 | print(' Device:', torch.cuda.get_device_name()) 86 | print(' Model: ', model_arch) 87 | print('-------------------------------------------') 88 | 89 | latency = fns.compute_resource(network_def, 'LATENCY', lookup_table_path) 90 | print('Computed latency: ', latency) 91 | latency = fns.measure_latency(model, 92 | [MEASURE_LATENCY_BATCH_SIZE, *INPUT_DATA_SHAPE]) 93 | print('Exact latency: ', latency) 94 | 95 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | # Master-related filenames. 2 | MASTER_MODEL_FILENAME_TEMPLATE = 'iter_{}_best_model.pth.tar' 3 | 4 | # Worker-related filenames. 5 | WORKER_MODEL_FILENAME_TEMPLATE = 'iter_{}_block_{}_model.pth.tar' 6 | WORKER_ACCURACY_FILENAME_TEMPLATE = 'iter_{}_block_{}_accuracy.txt' 7 | WORKER_RESOURCE_FILENAME_TEMPLATE = 'iter_{}_block_{}_resource.txt' 8 | WORKER_LOG_FILENAME_TEMPLATE = 'iter_{}_block_{}_log.txt' 9 | WORKER_FINISH_FILENAME_TEMPLATE = 'iter_{}_block_{}_finish.signal' -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # Define constants. 2 | STRING_SEPARATOR = '.' 3 | 4 | # Define layer types. 5 | CONV_LAYER_TYPES = ['Conv2d', 'ConvTranspose2d'] 6 | FC_LAYER_TYPES = ['Linear'] 7 | BNORM_LAYER_TYPES = ['BatchNorm2d'] 8 | 9 | # Define data types. 
10 | WEIGHTSTRING = 'weight' 11 | BIASSTRING = 'bias' 12 | RUNNING_MEANSTRING = 'running_mean' 13 | RUNNING_VARSTRING = 'running_var' 14 | NUM_BATCHES_TRACKED = 'num_batches_tracked' 15 | 16 | # Define keys. 17 | KEY_LAYER_TYPE_STR = 'layer_type_str' #(e.g. Linear, Conv2d) 18 | KEY_IS_DEPTHWISE = 'is_depthwise' 19 | KEY_NUM_IN_CHANNELS = 'num_in_channels' 20 | KEY_NUM_OUT_CHANNELS = 'num_out_channels' 21 | KEY_KERNEL_SIZE = 'kernel_size' 22 | KEY_STRIDE = 'stride' 23 | KEY_PADDING = 'padding' 24 | KEY_GROUPS = 'groups' 25 | KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR = 'before_squared_pixel_shuffle_factor' 26 | KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR = 'after_squared_pixel_shuffle_factor' 27 | KEY_INPUT_FEATURE_MAP_SIZE = 'input_feature_map_size' 28 | KEY_OUTPUT_FEATURE_MAP_SIZE = 'output_feature_map_size' 29 | KEY_MODEL = 'model' 30 | KEY_LATENCY = 'latency' -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/alexnet/README.md: -------------------------------------------------------------------------------- 1 | 1. **Training AlexNet on CIFAR-10.** 2 | 3 | Training: 4 | ```bash 5 | python train.py data/ --dir models/alexnet/model.pth.tar --arch alexnet 6 | ``` 7 | 8 | Evaluation: 9 | ```bash 10 | python eval.py data/ --dir models/alexnet/model.pth.tar --arch alexnet 11 | ``` 12 | 13 | One trained model can be found [here](https://drive.google.com/file/d/1GwugqlSl5ogRvJ2F2-4uz1fXBzt36ahB/view?usp=sharing). 14 | 15 | 2. **Measuring Latency** 16 | 17 | Here we build the latency lookup table for `cuda:0` device: 18 | ```bash 19 | python build_lookup_table.py --dir latency_lut/lut_alexnet.pkl --arch alexnet 20 | ``` 21 | It measures latency of different layers contained in the network (i.e. **AlexNet** here). 
22 | For conv layers, the sampled numbers of feature channels are multiples of `MIN_CONV_FEATURE_SIZE`. 23 | For fc layers, the sampled numbers of features are multiples of `MIN_FC_FEATURE_SIZE`. 24 | 25 | 3. **Applying NetAdapt** 26 | 27 | Modify which GPUs will be utilized (`-gp`) in `netadapt_alexnet-0.5mac.sh` and run the script to apply NetAdapt to a pretrained model: 28 | ```bash 29 | sh scripts/netadapt_alexnet-0.5mac.sh 30 | ``` 31 | 32 | You can see how the model is simplified at each iteration in `models/alexnet/prune-by-mac/master/history.txt` and 33 | select the one that satisfies the constraints to run long-term fine-tune. 34 | 35 | After obtaining the adapted model, we need to finetune the model (here we select the one after 18 iterations): 36 | ```bash 37 | python train.py data/ --arch alexnet --resume models/alexnet/prune-by-mac/master/iter_18_best_model.pth.tar --dir models/alexnet/prune-by-mac/master/finetune_model.pth.tar --lr 0.001 38 | ``` 39 | 40 |

41 | photo not available 42 |

43 | 44 | 45 | If you want to get a model with 50% latency, please run: 46 | ```bash 47 | sh scripts/netadapt_alexnet-0.5latency.sh 48 | ``` 49 | 50 | 4. **Evaluation Using Adapted Models** 51 | 52 | After applying NetAdapt to a pretrained model, we can evaluate this adapted model using: 53 | ```bash 54 | python eval.py data/ --dir models/alexnet/prune-by-mac/master/finetune_model.pth.tar --arch alexnet 55 | ``` 56 | 57 | The adapted model can be restored **without modifying the original python file**. 58 | 59 | We provide one adapted model [here](https://drive.google.com/file/d/1VH9c2orF2W0P21gD8NrdTdvwP_uJJgYA/view?usp=sharing). -------------------------------------------------------------------------------- /docs/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | 1. **Training MobileNet on CIFAR-10.** 2 | 3 | Training: 4 | ```bash 5 | python train.py data/ --dir models/mobilenet/model.pth.tar --arch mobilenet 6 | ``` 7 | 8 | Evaluation: 9 | ```bash 10 | python eval.py data/ --dir models/mobilenet/model.pth.tar --arch mobilenet 11 | ``` 12 | 13 | One trained model can be found [here](https://drive.google.com/file/d/1jtRZOHK1daRTKD4jYu4US84lf8YjqEdJ/view?usp=sharing). 14 | 15 | 16 | 2. **Measuring Latency** 17 | 18 | Here we build the latency lookup table for `cuda:0` device: 19 | ```bash 20 | python build_lookup_table.py --dir latency_lut/lut_mobilenet.pkl --arch mobilenet 21 | ``` 22 | It measures latency of different layers contained in the network (i.e. **MobileNet** here). 23 | For conv layers, the sampled numbers of feature channels are multiples of `MIN_CONV_FEATURE_SIZE`. 24 | For fc layers, the sampled numbers of features are multiples of `MIN_FC_FEATURE_SIZE`. 25 | 26 | 3. 
**Applying NetAdapt** 27 | 28 | Modify which GPUs will be utilized (`-gp`) in `netadapt_mobilenet-0.5mac.sh` and run the script to apply NetAdapt to a pretrained model: 29 | ```bash 30 | sh scripts/netadapt_mobilenet-0.5mac.sh 31 | ``` 32 | 33 | You can see how the model is simplified at each iteration in `models/mobilenet/prune-by-mac/master/history.txt` and 34 | select the one that satisfies the constraints to run long-term fine-tune. 35 | 36 | After obtaining the adapted model, we need to finetune the model (here we select the one after 28 iterations): 37 | ```bash 38 | python train.py data/ --arch mobilenet --resume models/mobilenet/prune-by-mac/master/iter_28_best_model.pth.tar --dir models/mobilenet/prune-by-mac/master/finetune_model.pth.tar --lr 0.001 39 | ``` 40 | 41 |

42 | photo not available 43 |

44 | 45 | 46 | If you want to get a model with 50% latency, please run: 47 | ```bash 48 | sh scripts/netadapt_mobilenet-0.5latency.sh 49 | ``` 50 | 51 | 52 | 4. **Evaluation Using Adapted Models** 53 | 54 | After applying NetAdapt to a pretrained model, we can evaluate this adapted model using: 55 | ```bash 56 | python eval.py data/ --dir models/mobilenet/prune-by-mac/master/finetune_model.pth.tar --arch mobilenet 57 | ``` 58 | 59 | The adapted model can be restored **without modifying the orignal python file**. 60 | 61 | We provide one adapted model [here](https://drive.google.com/file/d/1wkQPolgv34ESb0gyeYdygbuNuNYhIqyx/view?usp=sharing). -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import time 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | import torch.backends.cudnn as cudnn 10 | import pickle 11 | 12 | import nets as models 13 | import functions as fns 14 | 15 | _NUM_CLASSES = 10 16 | 17 | model_names = sorted(name for name in models.__dict__ 18 | if name.islower() and not name.startswith("__") 19 | and callable(models.__dict__[name])) 20 | 21 | 22 | def compute_topk_accuracy(output, target, topk=(1,)): 23 | """Computes the accuracy over the k top predictions for the specified values of k""" 24 | with torch.no_grad(): 25 | maxk = max(topk) 26 | batch_size = target.size(0) 27 | 28 | _, pred = output.topk(maxk, 1, True, True) 29 | pred = pred.t() 30 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 31 | 32 | res = [] 33 | for k in topk: 34 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 35 | res.append(correct_k.mul_(100.0 / batch_size)) 36 | return res 37 | 38 | 39 | def compute_accuracy(output, target): 40 | output = output.argmax(dim=1) 41 | acc = 
0.0 42 | acc = torch.sum(target == output).item() 43 | acc = acc/output.size(0)*100 44 | return acc 45 | 46 | 47 | class AverageMeter(object): 48 | """Computes and stores the average and current value""" 49 | def __init__(self): 50 | self.reset() 51 | 52 | def reset(self): 53 | self.val = 0 54 | self.avg = 0 55 | self.sum = 0 56 | self.count = 0 57 | 58 | def get_avg(self): 59 | return self.avg 60 | 61 | def update(self, val, n=1): 62 | self.val = val 63 | self.sum += val * n 64 | self.count += n 65 | self.avg = self.sum / self.count 66 | 67 | 68 | def eval(test_loader, model, args): 69 | batch_time = AverageMeter() 70 | acc = AverageMeter() 71 | 72 | # switch to eval mode 73 | model.eval() 74 | 75 | end = time.time() 76 | for i, (images, target) in enumerate(test_loader): 77 | if not args.no_cuda: 78 | images = images.cuda() 79 | target = target.cuda() 80 | output = model(images) 81 | batch_acc = compute_accuracy(output, target) 82 | acc.update(batch_acc, images.size(0)) 83 | batch_time.update(time.time() - end) 84 | end = time.time() 85 | 86 | # Update statistics 87 | estimated_time_remained = batch_time.get_avg()*(len(test_loader)-i-1) 88 | fns.update_progress(i, len(test_loader), 89 | ESA='{:8.2f}'.format(estimated_time_remained)+'s', 90 | acc='{:4.2f}'.format(float(batch_acc)) 91 | ) 92 | print() 93 | print('Test accuracy: {:4.2f}% (time = {:8.2f}s)'.format( 94 | float(acc.get_avg()), batch_time.get_avg()*len(test_loader))) 95 | print('===================================================================') 96 | return float(acc.get_avg()) 97 | 98 | 99 | if __name__ == '__main__': 100 | # Parse the input arguments. 
101 | arg_parser = ArgumentParser() 102 | arg_parser.add_argument('data', metavar='DIR', help='path to dataset') 103 | arg_parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 104 | help='number of data loading workers (default: 4)') 105 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet', 106 | choices=model_names, 107 | help='model architecture: ' + 108 | ' | '.join(model_names) + 109 | ' (default: alexnet)') 110 | arg_parser.add_argument('-b', '--batch-size', default=128, type=int, 111 | metavar='N', 112 | help='batch size (default: 128)') 113 | arg_parser.add_argument('--dir', type=str, default='models/', dest='save_dir', 114 | help='path to save models (default: models/') 115 | arg_parser.add_argument('--no-cuda', action='store_true', default=False, dest='no_cuda', 116 | help='disables training on GPU') 117 | 118 | args = arg_parser.parse_args() 119 | print(args) 120 | 121 | # Data loader 122 | test_dataset = datasets.CIFAR10(root=args.data, train=False, download=True, 123 | transform=transforms.Compose([ 124 | transforms.Resize(224), 125 | transforms.ToTensor(), 126 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 127 | ])) 128 | test_loader = torch.utils.data.DataLoader( 129 | test_dataset, batch_size=args.batch_size, shuffle=False, 130 | num_workers=args.workers, pin_memory=True) 131 | 132 | # Network 133 | model_arch = args.arch 134 | cudnn.benchmark = True 135 | num_classes = _NUM_CLASSES 136 | model = models.__dict__[model_arch](num_classes=num_classes) 137 | 138 | if not args.no_cuda: 139 | model = model.cuda() 140 | 141 | # Evaluation 142 | filename = os.path.join(args.save_dir) 143 | 144 | model = torch.load(filename) 145 | print(model) 146 | 147 | best_acc = eval(test_loader, model, args) 148 | print('Testing accuracy:', best_acc) -------------------------------------------------------------------------------- /fig/netadapt_algo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/fig/netadapt_algo.png -------------------------------------------------------------------------------- /fig/netadapt_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/denru01/netadapt/248ae1c607899c5b8c6742ea0c0ba9b58ffcafdf/fig/netadapt_fig.png -------------------------------------------------------------------------------- /latency_lut/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /master.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import pickle 4 | import time 5 | import torch 6 | from shutil import copyfile 7 | import subprocess 8 | import sys 9 | import warnings 10 | import common 11 | import network_utils as networkUtils 12 | 13 | ''' 14 | The main file of NetAdapt. 15 | 16 | Launch workers to simplify and finetune pretrained models. 17 | ''' 18 | 19 | # Define constants. 20 | _MASTER_FOLDER_FILENAME = 'master' 21 | _WORKER_FOLDER_FILENAME = 'worker' 22 | _WORKER_PY_FILENAME = 'worker.py' 23 | _HISTORY_PICKLE_FILENAME = 'history.pickle' 24 | _HISTORY_TEXT_FILENAME = 'history.txt' 25 | _SLEEP_TIME = 1 26 | 27 | # Define keys. 
_KEY_MASTER_ARGS = 'master_args'
_KEY_HISTORY = 'history'
_KEY_RESOURCE = 'resource'
_KEY_ACCURACY = 'accuracy'
_KEY_SOURCE_MODEL_PATH = 'source_model_path'
_KEY_BLOCK = 'block'
_KEY_ITERATION = 'iteration'
_KEY_GPU = 'gpu'
_KEY_MODEL = 'model'
_KEY_NETWORK_DEF = 'network_def'
_KEY_NUM_OUT_CHANNELS = 'num_out_channels'

# Supported network_utils
network_utils_all = sorted(name for name in networkUtils.__dict__
                           if name.islower() and not name.startswith("__")
                           and callable(networkUtils.__dict__[name]))

# Used only when no learning rate is passed in and no module-level `args` exists.
# Same value as the `--finetune_lr` command-line default.
_DEFAULT_FINETUNE_LR = 0.001


def _launch_worker(worker_folder, model_path, block, resource_type, constraint, netadapt_iteration,
                   short_term_fine_tune_iteration, input_data_shape, job_list, available_gpus,
                   lookup_table_path, dataset_path, model_arch, finetune_lr=None):
    '''
        `master.py` launches several `worker.py`.
        Each `worker.py` prunes one specific block and fine-tunes it.
        This function launches one worker to run on one gpu.

        The worker is started with `subprocess.Popen` and is NOT waited for; its stdout and stderr
        are redirected to a per-(iteration, block) log file inside `worker_folder`.
        The worker receives its arguments positionally in this order:
        worker_folder, model_path, block, resource_type, constraint, netadapt_iteration,
        short_term_fine_tune_iteration, gpu, lookup_table_path, dataset_path,
        each entry of input_data_shape, model_arch, finetune_lr.

        Input:
            `worker_folder`: (string) directory where `worker.py` will save models.
            `model_path`: (string) path to model which `worker.py` will load as pretrained model.
            `block`: (int) index of block to be simplified.
            `resource_type`: (string) (e.g. `WEIGHTS`, `FLOPS`, and `LATENCY`).
            `constraint`: (float) the value of constraints (e.g. 10**6 (weights)).
            `netadapt_iteration`: (int) indicates the current iteration of NetAdapt.
            `short_term_fine_tune_iteration`: (int) short-term fine-tune iteration.
            `input_data_shape`: (list) input data shape (C, H, W).
            `job_list`: (list of dict) list of current jobs. Each job is a dict, showing current iteration, block and gpu idx.
            `available_gpus`: (list) list of available gpu idx. Must not be empty (IndexError otherwise).
            `lookup_table_path`: (string) path to lookup table. `None` is forwarded as an empty string.
            `dataset_path`: (string) path to dataset.
            `model_arch`: (string) specifies which network_utils will be used.
            `finetune_lr`: (float, optional) short-term fine-tune learning rate forwarded to the worker.
                When omitted, the value is taken from the module-level `args` (set when this file is
                run as a script); if that does not exist either, `_DEFAULT_FINETUNE_LR` is used.

        Output:
            updated_job_list: (list of dict) copy of `job_list` with the new job appended.
            updated_available_gpus: (list) copy of `available_gpus` without the gpu given to the worker.
    '''
    # Work on copies so the caller's lists are never mutated.
    updated_job_list = job_list.copy()
    updated_available_gpus = available_gpus.copy()
    gpu = updated_available_gpus.pop(0)

    if lookup_table_path is None:
        # The worker takes positional string arguments, so "no table" is an empty string.
        lookup_table_path = ''

    if finetune_lr is None:
        # Previously this function read the global `args` unconditionally, which raised
        # NameError whenever the module was imported instead of executed as a script.
        finetune_lr = getattr(globals().get('args'), 'finetune_lr', _DEFAULT_FINETUNE_LR)

    print(' Launch a worker for block {}'.format(block))
    log_path = os.path.join(worker_folder,
                            common.WORKER_LOG_FILENAME_TEMPLATE.format(netadapt_iteration, block))
    command_list = [sys.executable, _WORKER_PY_FILENAME, worker_folder, model_path, str(block), resource_type,
                    str(constraint), str(netadapt_iteration), str(short_term_fine_tune_iteration), str(gpu),
                    lookup_table_path, dataset_path]
    command_list += [str(e) for e in input_data_shape]
    command_list += [model_arch, str(finetune_lr)]

    with open(log_path, 'w') as file_id:
        print(command_list)
        # The child inherits the log file descriptor, so closing our handle when the
        # `with` block exits does not interrupt the worker's logging.
        subprocess.Popen(command_list, stdout=file_id, stderr=file_id)

    updated_job_list.append({_KEY_ITERATION: netadapt_iteration, _KEY_BLOCK: block, _KEY_GPU: gpu})

    return updated_job_list, updated_available_gpus
def _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus):
    '''
        Refresh the job list and the free-gpu list by checking which workers have finished.

        A job counts as finished when its "finish" file (named from the job's iteration and block
        through `common.WORKER_FINISH_FILENAME_TEMPLATE`) exists inside `worker_folder`.

        Input:
            `worker_folder`: (string) directory where `worker.py` will save models.
            `job_list`: (list of dict) list of current jobs. Each job is a dict, showing current iteration, block and gpu idx.
            `available_gpus`: (list) list of available gpu idx.

        Output:
            `updated_job_list`: (list of dict) the jobs that are still running.
            `updated_available_gpus`: (list) copy of `available_gpus` followed by the gpus released
                                      by finished jobs, in job order.
    '''
    def _is_finished(job):
        # The worker signals completion by creating this file.
        finish_filename = common.WORKER_FINISH_FILENAME_TEMPLATE.format(job[_KEY_ITERATION],
                                                                        job[_KEY_BLOCK])
        return os.path.exists(os.path.join(worker_folder, finish_filename))

    running_jobs = []
    free_gpus = available_gpus.copy()
    for job in job_list:
        if not _is_finished(job):
            running_jobs.append(job)
            continue
        free_gpus.append(job[_KEY_GPU])

    return running_jobs, free_gpus
def _read_float_file(file_path):
    '''Read a text file that holds a single number and return it as a float.'''
    with open(file_path, 'r') as file_id:
        return float(file_id.read())


def _find_best_model(worker_folder, iteration, num_blocks, starting_accuracy, starting_resource):
    '''
        After all workers finish jobs, select the model with the smallest
        accuracy-drop-to-resource-reduction ratio:

            (starting_accuracy - accuracy + 1e-6) / (starting_resource - resource + 1e-5)

        Only models whose resource is strictly below `starting_resource` are candidates.

        Input:
            `worker_folder`: (string) directory where `worker.py` will save models.
            `iteration`: (int) NetAdapt iteration.
            `num_blocks`: (int) num of simplifiable blocks at each iteration.
            `starting_accuracy`: (float) initial accuracy before pruning and fine-tuning.
            `starting_resource`: (float) initial resource consumption.

        Output:
            `best_accuracy`: (float) accuracy of the best pruned model (0.0 if there is no candidate).
            `best_model_path`: (string) path to the best model (None if there is no candidate).
            `best_resource`: (float) resource consumption of the best model (None if there is no candidate).
            `best_block`: (int) block index of the best model (None if there is no candidate).
    '''

    best_ratio = float('Inf')
    best_accuracy = 0.0
    best_model_path = None
    best_resource = None
    best_block = None
    for block_idx in range(num_blocks):
        accuracy = _read_float_file(os.path.join(
            worker_folder, common.WORKER_ACCURACY_FILENAME_TEMPLATE.format(iteration, block_idx)))
        resource = _read_float_file(os.path.join(
            worker_folder, common.WORKER_RESOURCE_FILENAME_TEMPLATE.format(iteration, block_idx)))

        print('Block id {}: resource {}, accuracy {}'.format(block_idx, resource, accuracy))

        # Reject models that do not reduce the resource BEFORE dividing: for them the
        # denominator below can be zero or negative, so the ratio is meaningless.
        if resource >= starting_resource:
            continue

        ratio_resource_accuracy = (starting_accuracy - accuracy + 1e-6) / (starting_resource - resource + 1e-5)
        if ratio_resource_accuracy < best_ratio:
            best_ratio = ratio_resource_accuracy
            best_accuracy = accuracy
            best_model_path = os.path.join(worker_folder,
                                           common.WORKER_MODEL_FILENAME_TEMPLATE.format(iteration, block_idx))
            best_resource = resource
            best_block = block_idx
    print('Best block id: {}\n'.format(best_block))

    return best_accuracy, best_model_path, best_resource, best_block
def _save_and_print_history(network_utils, history, pickle_file_path, text_file_path):
    '''
        Save history info (log: history.txt, history file: history.pickle).

        The text log is rewritten from scratch on every call: one header line followed by one
        comma-separated line per history entry, where the row index is the iteration number.

        Input:
            `network_utils`: (defined in network_utils/network_utils_*) use the .extra_history_info()
                             to get the num of output channels.
            `history`: (dict) records accuracy, resource, block idx, model path for each iteration and
                       input arguments.
            `pickle_file_path`: (string) path to save history dict.
            `text_file_path`: (string) path to save history log.
    '''
    with open(pickle_file_path, 'wb') as file_id:
        pickle.dump(history, file_id)
    with open(text_file_path, 'w') as file_id:
        # Six columns: the header used to list only five while every row wrote six fields.
        file_id.write('Iteration,Accuracy,Resource,Block,Source Model,Extra History Info\n')
        for iteration, record in enumerate(history[_KEY_HISTORY]):
            # assume the extra history info is the # of output channels per layer
            extra_info_str = network_utils.extra_history_info(record[_KEY_NETWORK_DEF])
            file_id.write('{},{},{},{},{},{}\n'.format(iteration,
                                                       record[_KEY_ACCURACY],
                                                       record[_KEY_RESOURCE],
                                                       record[_KEY_BLOCK],
                                                       record[_KEY_SOURCE_MODEL_PATH],
                                                       extra_info_str))
def master(args):
    """
        The main function of the master.

        Either resumes from the history saved under `working_folder/master` or starts from
        `args.init_model_path`, then repeats NetAdapt iterations while
        `current_iter <= args.max_iters` and `current_resource > args.budget`.
        Each iteration launches one worker per simplifiable block, waits for all of them,
        copies the selected model into the master folder and appends an entry to the history.

        Note: iteration 0 means the initial model.

        Input:
            args: input arguments (see the argument parser at the bottom of this file).
                  When `args.resume` is set, `args` is replaced by the arguments stored in the
                  history file; only the gpu list of the current invocation is kept.

        raise:
            ValueError: when:
                (1) no available gpus (i.e. len(args.gpus) == 0)
                (2) resume from previous iteration and required to use lookup table but no lookup table found
                (3) files exist under working_folder/master or working_folder/worker and not use `--resume`
                (4) target resource is not achievable (i.e. no worker produced a model whose resource
                    is lower than that of the previous iteration)
    """

    # Set the important paths.
    master_folder = os.path.join(args.working_folder, _MASTER_FOLDER_FILENAME)
    worker_folder = os.path.join(args.working_folder, _WORKER_FOLDER_FILENAME)
    history_pickle_file = os.path.join(master_folder, _HISTORY_PICKLE_FILENAME)
    history_text_file = os.path.join(master_folder, _HISTORY_TEXT_FILENAME)

    # Get available GPUs.
    available_gpus = args.gpus
    if not available_gpus:
        raise ValueError('At least one gpu must be specified.')

    # Resume or do iteration 0.
    if args.resume:
        # NOTE: pickle executes arbitrary code on load; only resume from a working folder you trust.
        with open(history_pickle_file, 'rb') as file_id:
            history = pickle.load(file_id)
        args = history[_KEY_MASTER_ARGS]

        # Initialize variables.
        last_record = history[_KEY_HISTORY][-1]
        current_iter = len(history[_KEY_HISTORY]) - 1
        current_resource = last_record[_KEY_RESOURCE]
        current_model_path = os.path.join(master_folder,
                                          common.MASTER_MODEL_FILENAME_TEMPLATE.format(current_iter))
        current_accuracy = last_record[_KEY_ACCURACY]

        # Get the network utils.
        model = torch.load(current_model_path, map_location=lambda storage, loc: storage)
        if type(model) is dict:
            # Same unwrapping the adaptation loop below applies to these master checkpoints.
            model = model[_KEY_MODEL]

        # Select network_utils.
        model_arch = args.arch
        network_utils = networkUtils.__dict__[model_arch](model, args.input_data_shape, args.dataset_path)

        if args.lookup_table_path is not None and not os.path.exists(args.lookup_table_path):
            errMsg = 'Resume from a previous task but the {} lookup table is not found.'.format(args.resource_type)
            raise ValueError(errMsg)
        del model

        # Print the message.
        print(('Resume from iteration {:>3}: current_accuracy = {:>8.3f}, '
               'current_resource = {:>8.3f}').format(current_iter, current_accuracy, current_resource))
        print('arguments:', args)

    else:
        # Initialize the iteration.
        current_iter = 0

        # Create the folder structure.
        if not os.path.exists(args.working_folder):
            os.makedirs(args.working_folder)
            print('Create directory', args.working_folder)
        if not os.path.exists(master_folder):
            os.mkdir(master_folder)
            print('Create directory', master_folder)
        elif os.listdir(master_folder):
            errMsg = 'Find previous files in the master directory {}. Please use `--resume` or delete those files'.format(master_folder)
            raise ValueError(errMsg)

        if not os.path.exists(worker_folder):
            os.mkdir(worker_folder)
            print('Create directory', worker_folder)
        elif os.listdir(worker_folder):
            errMsg = 'Find previous files in the worker directory {}. Please use `--resume` or delete those files'.format(worker_folder)
            raise ValueError(errMsg)

        # Backup the initial model.
        current_model_path = os.path.join(master_folder,
                                          common.MASTER_MODEL_FILENAME_TEMPLATE.format(current_iter))
        copyfile(args.init_model_path, current_model_path)

        # Initialize variables.
        model = torch.load(current_model_path)

        # Select network_utils.
        model_arch = args.arch
        network_utils = networkUtils.__dict__[model_arch](model, args.input_data_shape, args.dataset_path)

        network_def = network_utils.get_network_def_from_model(model)
        if args.lookup_table_path is not None and not os.path.exists(args.lookup_table_path):
            warnMsg = 'The {} lookup table is not found and going to be built.'.format(args.resource_type)
            warnings.warn(warnMsg)
            network_utils.build_lookup_table(network_def, args.resource_type, args.lookup_table_path)
        current_resource = network_utils.compute_resource(network_def, args.resource_type, args.lookup_table_path)

        current_accuracy = network_utils.evaluate(model)
        current_block = None

        # Derive absolute values from the ratios; they are stored back on `args` so that
        # the history (and therefore a later `--resume`) sees the resolved numbers.
        if args.init_resource_reduction is None:
            args.init_resource_reduction = args.init_resource_reduction_ratio*current_resource
            print('`--init_resource_reduction` is not specified')
            print('Use `--init_resource_reduction_ratio` ({}) to get `init_resource_reduction` ({})\n'.format(
                args.init_resource_reduction_ratio, args.init_resource_reduction))
        if args.budget is None:
            args.budget = args.budget_ratio*current_resource
            print('`--budget` is not specified')
            print('Use `--budget_ratio` ({}) to get `budget` ({})\n'.format(
                args.budget_ratio, args.budget))

        # Create and save the history.
        history = {_KEY_MASTER_ARGS: args, _KEY_HISTORY: []}
        history[_KEY_HISTORY].append({_KEY_RESOURCE: current_resource,
                                      _KEY_SOURCE_MODEL_PATH: args.init_model_path,
                                      _KEY_ACCURACY: current_accuracy,
                                      _KEY_BLOCK: current_block,
                                      _KEY_NETWORK_DEF: network_def})
        _save_and_print_history(network_utils, history, history_pickle_file, history_text_file)
        del model, network_def

        # Print the message.
        print(('Start from iteration {:>3}: current_accuracy = {:>8.3f}, '
               'current_resource = {:>8.3f}').format(current_iter, current_accuracy, current_resource))

    current_iter += 1

    # Start adaptation.
    while current_iter <= args.max_iters and current_resource > args.budget:

        start_time = time.time()

        # Set the target resource.
        target_resource = current_resource - args.init_resource_reduction * (
            args.resource_reduction_decay ** (current_iter - 1))

        # Print the message.
        print('===================================================================')
        print(
            ('Process iteration {:>3}: current_accuracy = {:>8.3f}, '
             'current_resource = {:>8.3f}, target_resource = {:>8.3f}').format(
                current_iter, current_accuracy, current_resource, target_resource))

        num_blocks = network_utils.get_num_simplifiable_blocks()

        # Launch the workers.
        job_list = []

        # Launch worker for each block
        for block_idx in range(num_blocks):
            # Check and update the gpu availability.
            job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)
            while not available_gpus:
                # Poll until a running worker writes its finish file and frees a gpu.
                time.sleep(_SLEEP_TIME)
                job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)

            # Launch a worker.
            job_list, available_gpus = _launch_worker(worker_folder, current_model_path, block_idx, args.resource_type,
                                                      target_resource, current_iter,
                                                      args.short_term_fine_tune_iteration, args.input_data_shape,
                                                      job_list, available_gpus, args.lookup_table_path,
                                                      args.dataset_path, args.arch)
            print('Update job list: ', job_list)
            print('Update available gpu:', available_gpus, '\n')

        # Wait until all the workers finish.
        job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)
        while job_list:
            time.sleep(_SLEEP_TIME)
            job_list, available_gpus = _update_job_list_and_available_gpus(worker_folder, job_list, available_gpus)

        # Find the best model.
        best_accuracy, best_model_path, best_resource, best_block = (
            _find_best_model(worker_folder, current_iter, num_blocks, current_accuracy,
                             current_resource))

        # Check whether the target_resource is achieved.
        if not best_model_path:
            raise ValueError('target_resource {} is not achievable in iter {}.'.format(target_resource, current_iter))
        if best_resource > target_resource:
            warnMsg = "Iteration {}: target resource {} is not achieved. Current best resource is {}".format(current_iter, target_resource, best_resource)
            warnings.warn(warnMsg)

        # Update the variables.
        current_model_path = os.path.join(master_folder,
                                          common.MASTER_MODEL_FILENAME_TEMPLATE.format(current_iter))
        copyfile(best_model_path, current_model_path)
        current_accuracy = best_accuracy
        current_resource = best_resource
        current_block = best_block

        # Delete the per-block worker models unless this iteration is a "save all" iteration.
        # `0` is treated like `-1` (keep only the best model); it used to raise ZeroDivisionError.
        if args.save_interval in (-1, 0) or (current_iter % args.save_interval != 0):
            for block_idx in range(num_blocks):
                temp_model_path = os.path.join(worker_folder, common.WORKER_MODEL_FILENAME_TEMPLATE.format(current_iter, block_idx))
                os.remove(temp_model_path)
                print('Remove', temp_model_path)
            print(' ')

        # Save and print the history.
        model = torch.load(current_model_path)
        if type(model) is dict:
            model = model[_KEY_MODEL]
        network_def = network_utils.get_network_def_from_model(model)
        history[_KEY_HISTORY].append({_KEY_RESOURCE: current_resource,
                                      _KEY_SOURCE_MODEL_PATH: best_model_path,
                                      _KEY_ACCURACY: current_accuracy,
                                      _KEY_BLOCK: current_block,
                                      _KEY_NETWORK_DEF: network_def})
        _save_and_print_history(network_utils, history, history_pickle_file, history_text_file)
        del model, network_def

        current_iter += 1

        print('Finish iteration {}: time {}'.format(current_iter-1, time.time()-start_time))
def _build_arg_parser():
    '''
        Create the command-line parser of the master.

        Output:
            `parser`: (ArgumentParser) parser with the positional arguments `working_folder` and
                      `input_data_shape` plus all optional NetAdapt settings.
    '''
    parser = ArgumentParser()
    add = parser.add_argument

    # Positional arguments.
    add('working_folder', type=str,
        help='Root folder where models, related files and history information are saved.')
    add('input_data_shape', nargs=3, default=[3, 224, 224], type=int,
        help='Input data shape (C, H, W) (default: 3 224 224).')

    # Run control.
    add('-gp', '--gpus', nargs='+', default=[0], type=int,
        help='Indices of available gpus (default: 0).')
    add('-re', '--resume', action='store_true',
        help='Resume from previous iteration. In order to resume, specify `--resume` and specify `working_folder` as the one you want to resume.')
    add('-im', '--init_model_path',
        help='Path to pretrained model.')
    add('-mi', '--max_iters', type=int, default=10,
        help='Maximum iteration of removing filters and short-term fine-tune (default: 10).')
    add('-lr', '--finetune_lr', type=float, default=0.001,
        help='Short-term fine-tune learning rate (default: 0.001).')

    # Resource budget.
    add('-bu', '--budget', type=float, default=None,
        help='Resource constraint. If resource < `budget`, the process is terminated.')
    add('-bur', '--budget_ratio', type=float, default=0.25,
        help='If `--budget` is not specified, `buget` = `budget_ratio`*(pretrained model resource) (default: 0.25).')

    add('-rt', '--resource_type', type=str, default='FLOPS',
        help='Resource constraint type (default: FLOPS). We currently support `FLOPS`, `WEIGHTS`, and `LATENCY` (device cuda:0). If you want to add other resource types, please modify network_util.')

    # Per-iteration resource reduction schedule.
    add('-ir', '--init_resource_reduction', type=float, default=None,
        help='For each iteration, target resource = current resource - `init_resource_reduction`*(`resource_reduction_decay`**(iteration-1)).')
    add('-irr', '--init_resource_reduction_ratio', type=float, default=0.025,
        help='If `--init_resource_reduction` is not specified, `init_resource_reduction` = `init_resource_reduction_ratio`*(pretrained model resource) (default: 0.025).')
    add('-rd', '--resource_reduction_decay', type=float, default=0.96,
        help='For each iteration, target resource = current resource - `init_resource_reduction`*(`resource_reduction_decay`**(iteration-1)) (default: 0.96).')
    add('-st', '--short_term_fine_tune_iteration', type=int, default=10,
        help='Short-term fine-tune iteration (default: 10).')

    # Paths.
    add('-lt', '--lookup_table_path', type=str, default=None,
        help='Path to lookup table.')
    add('-dp', '--dataset_path', type=str, default='',
        help='Path to dataset.')

    # Network-specific behaviour.
    arch_help = ('network_utils: ' +
                 ' | '.join(network_utils_all) +
                 ' (default: alexnet). Defines how networks are pruned, fine-tuned, and evaluated. If you want to use your own method, please specify here.')
    add('-a', '--arch', metavar='ARCH', default='alexnet',
        choices=network_utils_all,
        help=arch_help)

    add('-si', '--save_interval', type=int, default=-1,
        help='Interval of iterations that all pruned models at the same iteration will be saved. Use `-1` to save only the best model at each iteration. Use `1` to save all models at each iteration. (default: -1).')

    return parser


if __name__ == '__main__':
    # Parse the input arguments.
    arg_parser = _build_arg_parser()

    print(network_utils_all)

    # `args` stays a module-level name: `_launch_worker` reads `args.finetune_lr` from it.
    args = arg_parser.parse_args()

    # Launch the master.
    print(args)
    master(args)
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['AlexNet', 'alexnet']


model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):
    """AlexNet with a 1000-way classifier.

    Sub-modules, in registration order:
        features:   five conv layers with ReLU and three max-pool layers.
        avgpool:    adaptive average pooling to a 6x6 spatial size.
        classifier: three linear layers (the first two preceded by dropout).
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as they are applied, which also
        # fixes their indices inside the nn.Sequential containers.
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        classifier_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 1000),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Run features -> avgpool -> flatten (keeping dim 0) -> classifier."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))


def alexnet(pretrained=False, progress=True, num_classes=1000):
    r"""AlexNet model architecture from the "One weird trick..." paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): If different from 1000, the last classifier layer is replaced
            by a freshly initialised linear layer with this many outputs (after any
            pre-trained weights have been loaded).
    """
    net = AlexNet()
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls['alexnet'], progress=progress))
    if num_classes != 1000:
        previous_head = net.classifier[6]
        net.classifier[6] = nn.Linear(previous_head.in_features, num_classes)
    return net
import torch
import torch.nn as nn


__all__ = ['HelloWorld', 'helloworld']


class HelloWorld(nn.Module):
    """Small fully-convolutional classifier.

    Four bias-free 3x3 convolutions (stride 1, padding 1) with ReLU in between; the last
    convolution has `num_classes` output channels. A 32x32 average pooling (stride 32)
    followed by flattening turns the feature map into the output.
    """

    def __init__(self, num_classes=10):
        super(HelloWorld, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.avgpool = nn.AvgPool2d(32, 32)

    def forward(self, x):
        """Apply the convolutions, average-pool, and flatten everything but dim 0."""
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        return x


def helloworld(num_classes=10):
    """Build a `HelloWorld` network with `num_classes` output channels.

    The argument used to be ignored, so every model was built with 10 classes.
    """
    return HelloWorld(num_classes=num_classes)
import torch
import torch.nn as nn


class MobileNet(nn.Module):
    """MobileNet built from one standard 3x3 conv block and 13 depthwise-separable blocks.

    Sub-modules:
        model: nn.Sequential of the 14 conv blocks followed by a 7x7 average pooling.
        fc:    linear layer from 1024 features to 1000 outputs.

    Args:
        relu6 (bool): use nn.ReLU6 instead of nn.ReLU as the activation.
    """

    # (input channels, output channels, stride) of each depthwise-separable block, in order.
    _DW_BLOCKS = (
        (32, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    )

    def __init__(self, relu6=False):
        super().__init__()

        def activation():
            # A fresh module per call; in-place as in every block of this network.
            return nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True)

        def standard_block(in_ch, out_ch, stride):
            # 3x3 conv -> batch norm -> activation.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                activation(),
            )

        def depthwise_block(in_ch, out_ch, stride):
            # 3x3 depthwise conv (groups == in_ch) followed by a 1x1 pointwise conv,
            # each with batch norm and activation.
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch, 3, stride, 1, groups=in_ch, bias=False),
                nn.BatchNorm2d(in_ch),
                activation(),

                nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False),
                nn.BatchNorm2d(out_ch),
                activation(),
            )

        blocks = [standard_block(3, 32, 2)]
        for in_ch, out_ch, stride in self._DW_BLOCKS:
            blocks.append(depthwise_block(in_ch, out_ch, stride))
        blocks.append(nn.AvgPool2d(7))

        self.model = nn.Sequential(*blocks)
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        """Run the conv stack, flatten everything but dim 0, and apply `fc`."""
        features = self.model(x)
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)


def mobilenet(pretrained=False, progress=False, num_classes=1000):
    """Build a `MobileNet`.

    Args:
        pretrained (bool): no weights are available; when True only a message is printed.
        progress (bool): accepted but not used.
        num_classes (int): if different from 1000, `fc` is replaced by a new linear layer
            with this many outputs.
    """
    net = MobileNet()
    if pretrained:
        print('Cannot download pretrained model')
    if num_classes != 1000:
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
from abc import ABC, abstractmethod


class NetworkUtilsAbstract(ABC):
    '''
        Interface that every network-specific `network_utils` class has to implement.

        All methods below are abstract; the bodies here contain no logic and return None.
    '''

    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_network_def_from_model(self, model):
        '''
            Build the network_def, which contains information about each layer within a model.

            Input:
                `model`: pytorch model (e.g. nn.Sequential())

            Output:
                `network_def`: network_def contains layerwise information (e.g. num of output/input channels).
                               network_def will be used to compute resource and guide pruning models.

            please refer to def get_network_def_from_model() in functions.py
            to see one implementation.
        '''

    @abstractmethod
    def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type,
                                                 lookup_table_path):
        '''
            Derive how much a certain block of layers ('block') should be simplified
            based on resource constraints.

            Input:
                `network_def`: defined in get_network_def_from_model()
                `block`: the block of layers to be simplified
                `constraint`: (float) representing the FLOPs/weights/latency constraint the simplified model should satisfy
                `resource_type`: (string) `FLOPs`, `WEIGHTS`, or `LATENCY`
                `lookup_table_path`: (string) path to latency lookup table. Needed only when resource_type == 'LATENCY'

            Output:
                `simplified_network_def`: simplified network definition. Indicates how much the network should
                                          be simplified/pruned.
                `simplified_resource`: (float) the estimated resource consumption of simplified network_def.

            please refer to def simplify_network_def_based_on_constraint(...) in functions.py
            to see one implementation.
        '''

    @abstractmethod
    def simplify_model_based_on_network_def(self, simplified_network_def, model):
        '''
            Choose which filters to preserve.

            Input:
                `simplified_network_def`: network_def shows how a model will be pruned.
                                          defined in get_network_def_from_model().
                                          Get simplified_network_def from the output `simplified_network_def` of
                                          self.simplify_network_def_based_on_constraint()

                `model`: model to be simplified.

            Output:
                `simplified_model`: simplified model.

            please refer to def simplify_model_based_on_network_def(...) in functions.py
            to see one implementation
        '''

    @abstractmethod
    def extra_history_info(self, network_def):
        '''
            Return the # of output channels per layer.

            Input:
                `network_def`: defined in get_network_def_from_model()

            Output:
                `num_filters_str`: (string) show the num of output channels for each layer.
                                   Or you can define your own log
        '''

    @abstractmethod
    def build_lookup_table(self, network_def, resource_type, lookup_table_path):
        '''
            Build lookup table for layers defined by `network_def`.

            Input:
                `network_def`: defined in get_network_def_from_model()
                `resource_type`: (string) resource type (e.g. 'LATENCY')
                `lookup_table_path`: (string) path to save the file of lookup table
        '''

    @abstractmethod
    def compute_resource(self, network_def, resource_type, lookup_table_path):
        '''
            Compute resource based on resource type.

            Input:
                `network_def`: defined in get_network_def_from_model()
                `resource_type`: (string) resource type (e.g. 'WEIGHTS'/'LATENCY'/'FLOPS')
                `lookup_table_path`: (string) path to lookup table

            Output:
                `resource`: (float)
        '''

    @abstractmethod
    def get_num_simplifiable_blocks(self):
        '''
            Output:
                `num_simplifiable_blocks`: (int) num of blocks whose num of output channels can be reduced.
                                           Note that simplifiable blocks do not include output layer
        '''

    @abstractmethod
    def fine_tune(self, model, iterations):
        '''
            Short-term fine-tune a simplified model.

            Input:
                `model`: model to be fine-tuned
                `iterations`: (int) num of short-term fine-tune iterations

            Output:
                `model`: fine-tuned model
        '''

    @abstractmethod
    def evaluate(self, model):
        '''
            Evaluate the accuracy of the model.

            Input:
                `model`: model to be evaluated

            Output:
                `accuracy`: (float) (0~100)
        '''
def __init__(self, model, input_data_shape, dataset_path, finetune_lr=1e-3):
    '''
        Initialize:
            (1) shape of the input data 'input_data_shape'
            (2) num of simplifiable blocks 'num_simplifiable_blocks'
            (3) fine-tune hyper-parameters (batch size, num of workers, momentum,
                weight decay, learning rate)
            (4) data loaders 'train_loader' and 'val_loader' (CIFAR-10)
            (5) loss function 'criterion'

        Need to be adapted depending on tasks:
            (1) finetune/evaluation data loader
            (2) loss function
            (3) optimizer

        Input:
            `model`: model from which we will get network_def.
            `input_data_shape`: (list) [C, H, W].
            `dataset_path`: (string) path to dataset.
            `finetune_lr`: (float) short-term fine-tune learning rate.
    '''

    super().__init__()

    # The input shape must be set first: get_network_def_from_model() reads it.
    self.input_data_shape = input_data_shape

    # Every non-depthwise layer counts as one block. The output layer is
    # excluded because its number of filters cannot be reduced (hence the -1).
    layer_defs = self.get_network_def_from_model(model)
    num_non_depthwise = sum(1 for layer_properties in layer_defs.values()
                            if not layer_properties[KEY_IS_DEPTHWISE])
    self.num_simplifiable_blocks = num_non_depthwise - 1

    # Hyper-parameters for short-term fine-tuning and evaluation.
    self.batch_size = 128
    self.num_workers = 4
    self.momentum = 0.9
    self.weight_decay = 1e-4
    self.finetune_lr = finetune_lr

    # Per-channel mean/std used for both the training and evaluation pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training pipeline: augmentation on the 32x32 image, then upscale to 224.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    self.train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_path, train=True, download=True,
                         transform=train_transform),
        batch_size=self.batch_size, num_workers=self.num_workers,
        pin_memory=True, shuffle=True)

    # Evaluation pipeline: no augmentation.
    # NOTE(review): this loader also uses train=True, i.e. evaluate() runs on
    # the CIFAR-10 training split -- confirm this is the intended holdout set.
    eval_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])
    self.val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_path, train=True, download=True,
                         transform=eval_transform),
        batch_size=self.batch_size, shuffle=False,
        num_workers=self.num_workers, pin_memory=True)

    # fine_tune() feeds one-hot targets to this loss.
    self.criterion = torch.nn.BCEWithLogitsLoss()
def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type,
                                             lookup_table_path=None):
    '''
        Derive how much a certain block of layers ('block') should be simplified
        based on resource constraints.

        Here we treat one block as one layer although a block can contain several layers.
        The work is delegated to simplify_network_def_based_on_constraint(...) in functions.py.

        Input:
            `network_def`: simplifiable network definition (conv & fc).
                Get network def from self.get_network_def_from_model(...)
            `block`: (int) index of block to simplify
            `constraint`: (float) representing the FLOPs/weights/latency constraint
                the simplified model should satisfy
            `resource_type`: `FLOPs`, `WEIGHTS`, or `LATENCY`
            `lookup_table_path`: (string) path to latency lookup table.
                Needed only when resource_type == 'LATENCY'

        Output:
            `simplified_network_def`: simplified network definition. Indicates how much
                the network should be simplified/pruned.
            `simplified_resource`: (float) the estimated resource consumption of simplified models.
    '''
    simplification = fns.simplify_network_def_based_on_constraint(
        network_def, block, constraint, resource_type, lookup_table_path)
    return simplification
def extra_history_info(self, network_def):
    '''
        Summarize the num of output channels of every layer as one string.

        Input:
            `network_def`: (dict) layer name -> layer properties

        Output:
            `num_filters_str`: (string) num of output channels for each layer,
                separated by single spaces, in the iteration order of `network_def`
    '''
    return ' '.join(str(layer_properties[KEY_NUM_OUT_CHANNELS])
                    for layer_properties in network_def.values())
def fine_tune(self, model, iterations, print_frequency=100):
    '''
        short-term fine-tune a simplified model

        Runs `iterations` SGD steps on batches drawn from self.train_loader,
        restarting the loader whenever it is exhausted. Targets are converted
        to one-hot vectors because self.criterion is applied to (logits, one-hot).

        Input:
            `model`: model to be fine-tuned.
            `iterations`: (int) num of short-term fine-tune iterations.
            `print_frequency`: (int) how often to print fine-tune info.

        Output:
            `model`: fine-tuned model (on GPU when CUDA is available, otherwise on CPU).
    '''

    _NUM_CLASSES = 10
    # Equivalent to `.cuda()` when a GPU is present; otherwise run on CPU
    # instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    optimizer = torch.optim.SGD(model.parameters(), self.finetune_lr,
                                momentum=self.momentum, weight_decay=self.weight_decay)
    model = model.to(device)
    model.train()
    dataloader_iter = iter(self.train_loader)
    for i in range(iterations):
        try:
            inputs, targets = next(dataloader_iter)
        except StopIteration:
            # Only loader exhaustion triggers a restart; any other error
            # (including KeyboardInterrupt) must propagate.
            dataloader_iter = iter(self.train_loader)
            inputs, targets = next(dataloader_iter)

        if i % print_frequency == 0:
            print('Fine-tuning iteration {}'.format(i))
            sys.stdout.flush()

        inputs, targets = inputs.to(device), targets.to(device)
        # Build the one-hot target without mutating the loader's tensor in place.
        targets_onehot = torch.zeros(targets.shape[0], _NUM_CLASSES, device=device)
        targets_onehot.scatter_(1, targets.unsqueeze(1), 1)

        pred = model(inputs)
        loss = self.criterion(pred, targets_onehot)
        optimizer.zero_grad()
        loss.backward()  # compute gradient and do SGD step
        optimizer.step()
    return model
298 | 299 | Output: 300 | accuracy: (float) (0~100) 301 | ''' 302 | 303 | model = model.cuda() 304 | model.eval() 305 | acc = .0 306 | num_samples = .0 307 | with torch.no_grad(): 308 | for i, (input, target) in enumerate(self.val_loader): 309 | input, target = input.cuda(), target.cuda() 310 | pred = model(input) 311 | pred = pred.argmax(dim=1) 312 | batch_acc = torch.sum(target == pred) 313 | acc += batch_acc.item() 314 | num_samples += pred.shape[0] 315 | 316 | if i % print_frequency == 0: 317 | fns.update_progress(i, len(self.val_loader)) 318 | print(' ') 319 | print(' ') 320 | print('Test accuracy: {:4.2f}% '.format(float(acc/num_samples*100))) 321 | print('===================================================================') 322 | return acc/num_samples*100 323 | 324 | 325 | def alexnet(model, input_data_shape, dataset_path, finetune_lr=1e-3): 326 | return networkUtils_alexnet(model, input_data_shape, dataset_path, finetune_lr) -------------------------------------------------------------------------------- /network_utils/network_utils_helloworld.py: -------------------------------------------------------------------------------- 1 | from .network_utils_abstract import NetworkUtilsAbstract 2 | import os 3 | import sys 4 | import torch 5 | import copy 6 | import pickle 7 | import warnings 8 | sys.path.append(os.path.abspath('../')) 9 | from constants import * 10 | 11 | 12 | class networkUtils_helloworld(NetworkUtilsAbstract): 13 | 14 | def __init__(self, model, input_data_shape, dataset_path=None, finetune_lr=1e-3): 15 | super(networkUtils_helloworld).__init__() 16 | ''' 17 | 4 conv layers: 18 | conv1: 3, 16 19 | conv2: 16, 32 20 | conv3: 32, 64 21 | conv4: 64, 10 22 | 23 | Input: 24 | `model`: model from which we will get network_def. 25 | `input_data_shape`: (list) [C, H, W]. 26 | `dataset_path`: (string) path to dataset. 27 | `finetune_lr`: (float) short-term fine-tune learning rate. 
28 | ''' 29 | 30 | self.input_data_shape = input_data_shape 31 | self.lookup_table = None 32 | 33 | 34 | def get_network_def_from_model(self, model): 35 | ''' 36 | return network def (list) of the input model containing layerwise info 37 | 38 | Input: 39 | `model`: model we will get network_def from 40 | 41 | Output: 42 | `network_def`: (list) each element corresponds to one layer and 43 | is a tuple (num_input_channels, num_output_channels) 44 | ''' 45 | network_def = list() 46 | for idx in range(4): 47 | layer = getattr(model.features, str(idx * 2)) 48 | network_def.append((layer.in_channels, layer.out_channels)) 49 | return network_def 50 | 51 | 52 | def simplify_network_def_based_on_constraint(self, network_def, block, constraint, resource_type, 53 | lookup_table_path=None): 54 | ''' 55 | Derive how much a certain block of layers ('block') should be simplified 56 | based on resource constraints. 57 | 58 | Input: 59 | `network_def`: (list) simplifiable network definition (conv). 60 | defined in self.get_network_def_from_model(...) 61 | `block`: (int) index of block to simplify 62 | `constraint`: (float) representing the FLOPs/weights constraint the simplied model should satisfy 63 | `resource_type`: (string) `FLOPS`, `WEIGHTS` 64 | `lookup_table_path`: (string) path to lookup table. Here we construct lookup table for FLOPS and it is needed only when resource_type == 'FLOPS' 65 | 66 | Output: 67 | `simplified_network_def`: (list) simplified network_def whose resource is `simplified_resource` 68 | `simplified_resource`: (float) resource comsumption of `simplified_network_def` 69 | ''' 70 | 71 | assert block < self.get_num_simplifiable_blocks() 72 | 73 | # Determine the number of filters and the resource consumption. 
def simplify_model_based_on_network_def(self, simplified_network_def, model):
    '''
        Choose which filters to preserve.
        Here only the first `num_filters` filters of each layer are kept
        (the selection is NOT based on weight magnitude).

        Input:
            `simplified_network_def`: (list) network_def showing how each layer should be simplified.
            `model`: model to be simplified (left untouched; a deep copy is pruned).

        Output:
            `simplified_model`: simplified model
    '''
    pruned_model = copy.deepcopy(model)
    conv_layers = pruned_model.features

    for block_idx in range(self.get_num_simplifiable_blocks()):
        kept = simplified_network_def[block_idx][1]

        # Conv layers sit at the even indices of `features`.
        current_conv = getattr(conv_layers, str(block_idx * 2))
        next_conv = getattr(conv_layers, str((block_idx + 1) * 2))

        # Shrink the output channels of this layer ...
        out_weight = getattr(current_conv, WEIGHTSTRING)
        setattr(current_conv, WEIGHTSTRING, torch.nn.Parameter(out_weight[0:kept, :, :, :]))
        current_conv.out_channels = kept

        # ... and the matching input channels of the following layer.
        in_weight = getattr(next_conv, WEIGHTSTRING)
        setattr(next_conv, WEIGHTSTRING, torch.nn.Parameter(in_weight[:, 0:kept, :, :]))
        next_conv.in_channels = kept

    return pruned_model
def _compute_flops_from_lookup_table(self, network_def, lookup_table_path):
    '''
        Compute FLOPs from a lookup table.

        Although num of FLOPs can be calculated,
        we use lookup table here to show how NetAdapt framework uses lookup tables
        for resource estimation.

        The table is loaded from `lookup_table_path` on first use and cached in
        self.lookup_table; later calls reuse the cached table and ignore the path.
        Layers whose (num_in_channels, num_out_channels) pair is missing from the
        table contribute nothing to the total.

        Input:
            `network_def`: (list) defined in get_network_def_from_model()
            `lookup_table_path`: (string) path to lookup table

        Output:
            `resource`: (float) num of flops
    '''
    if self.lookup_table is None:
        # NOTE: pickle can execute arbitrary code while loading; only load
        # lookup tables produced by build_lookup_table() or another trusted source.
        with open(lookup_table_path, 'rb') as file_id:
            self.lookup_table = pickle.load(file_id)

    resource = 0
    for layer_idx, layer in enumerate(network_def):
        channels = (layer[0], layer[1])
        resource += self.lookup_table[layer_idx].get(channels, 0)
    return resource
def evaluate(self, model):
    '''
        for simplicity, we return a value determined by the network architecture

        please specify evaluate function in your case

        The value depends only on which of the first three conv layers have been
        pruned, i.e. whose num of output channels differs from the original
        16 / 32 / 64.

        Input:
            `model`: model whose architecture will determine the output value

        Output:
            `acc`: (int) value depends on the input model architecture
    '''

    network_def = self.get_network_def_from_model(model)
    # (conv1 pruned?, conv2 pruned?, conv3 pruned?)
    pruned = (network_def[0][1] != 16,
              network_def[1][1] != 32,
              network_def[2][1] != 64)

    # Returned values are unchanged. A former branch returning 12 repeated the
    # condition of the 85 case and could never run, so it has no entry here.
    acc_by_pruned_layers = {
        (True, False, False): 1,
        (False, True, False): 5,
        (False, False, True): 80,
        (True, False, True): 85,
        (False, True, True): 10,
        (True, True, True): 90,
    }
    # Default covers the unpruned network and (conv1 & conv2 pruned, conv3 intact).
    return acc_by_pruned_layers.get(pruned, 95)
def __init__(self, model, input_data_shape, dataset_path, finetune_lr=1e-3):
    '''
        Initialize:
            (1) shape of the input data 'input_data_shape'
            (2) num of simplifiable blocks 'num_simplifiable_blocks'
            (3) fine-tune hyper-parameters (batch size, num of workers, momentum,
                weight decay, learning rate)
            (4) data loaders 'train_loader' and 'val_loader' (CIFAR-10)
            (5) loss function 'criterion'
            (6) optimizer 'optimizer'

        Need to be adapted depending on tasks:
            (1) finetune/evaluation data loader
            (2) loss function
            (3) optimizer

        Input:
            `model`: model from which we will get network_def.
            `input_data_shape`: (list) [C, H, W].
            `dataset_path`: (string) path to dataset.
            `finetune_lr`: (float) short-term fine-tune learning rate.
    '''

    super().__init__()

    # The input shape must be set first: get_network_def_from_model() reads it.
    self.input_data_shape = input_data_shape

    # Every non-depthwise layer counts as one block. The output layer is
    # excluded because its number of filters cannot be reduced (hence the -1).
    layer_defs = self.get_network_def_from_model(model)
    num_non_depthwise = sum(1 for layer_properties in layer_defs.values()
                            if not layer_properties[KEY_IS_DEPTHWISE])
    self.num_simplifiable_blocks = num_non_depthwise - 1

    # Hyper-parameters for short-term fine-tuning and evaluation.
    self.batch_size = 128
    self.num_workers = 4
    self.momentum = 0.9
    self.weight_decay = 1e-4
    self.finetune_lr = finetune_lr

    # Per-channel mean/std used for both the training and evaluation pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training pipeline: augmentation on the 32x32 image, then upscale to 224.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    self.train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_path, train=True, download=True,
                         transform=train_transform),
        batch_size=self.batch_size, num_workers=self.num_workers,
        pin_memory=True, shuffle=True)

    # Evaluation pipeline: no augmentation.
    # NOTE(review): this loader also uses train=True, i.e. evaluate() runs on
    # the CIFAR-10 training split -- confirm this is the intended holdout set.
    eval_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])
    self.val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_path, train=True, download=True,
                         transform=eval_transform),
        batch_size=self.batch_size, shuffle=False,
        num_workers=self.num_workers, pin_memory=True)

    # Loss applied to (logits, one-hot targets).
    self.criterion = torch.nn.BCEWithLogitsLoss()
    # This optimizer is bound to the parameters of the `model` passed to the constructor.
    self.optimizer = torch.optim.SGD(model.parameters(), finetune_lr,
                                     momentum=self.momentum,
                                     weight_decay=self.weight_decay)
def extra_history_info(self, network_def):
    '''
        Summarize the num of output channels of every layer as one string.

        Input:
            `network_def`: (dict) layer name -> layer properties

        Output:
            `num_filters_str`: (string) num of output channels for each layer,
                separated by single spaces, in the iteration order of `network_def`
    '''
    return ' '.join(str(layer_properties[KEY_NUM_OUT_CHANNELS])
                    for layer_properties in network_def.values())
def fine_tune(self, model, iterations, print_frequency=100):
    '''
        short-term fine-tune a simplified model

        Runs `iterations` SGD steps on batches drawn from self.train_loader,
        restarting the loader whenever it is exhausted. Targets are converted
        to one-hot vectors because self.criterion is applied to (logits, one-hot).

        Input:
            `model`: model to be fine-tuned.
            `iterations`: (int) num of short-term fine-tune iterations.
            `print_frequency`: (int) how often to print fine-tune info.

        Output:
            `model`: fine-tuned model (on GPU when CUDA is available, otherwise on CPU).
    '''

    _NUM_CLASSES = 10
    # Equivalent to `.cuda()` when a GPU is present; otherwise run on CPU
    # instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    optimizer = torch.optim.SGD(model.parameters(), self.finetune_lr,
                                momentum=self.momentum, weight_decay=self.weight_decay)
    model = model.to(device)
    model.train()
    dataloader_iter = iter(self.train_loader)
    for i in range(iterations):
        try:
            inputs, targets = next(dataloader_iter)
        except StopIteration:
            # Only loader exhaustion triggers a restart; any other error
            # (including KeyboardInterrupt) must propagate.
            dataloader_iter = iter(self.train_loader)
            inputs, targets = next(dataloader_iter)

        if i % print_frequency == 0:
            print('Fine-tuning iteration {}'.format(i))
            sys.stdout.flush()

        inputs, targets = inputs.to(device), targets.to(device)
        # Build the one-hot target without mutating the loader's tensor in place.
        targets_onehot = torch.zeros(targets.shape[0], _NUM_CLASSES, device=device)
        targets_onehot.scatter_(1, targets.unsqueeze(1), 1)

        pred = model(inputs)
        loss = self.criterion(pred, targets_onehot)
        optimizer.zero_grad()
        loss.backward()  # compute gradient and do SGD step
        optimizer.step()
    return model
def mobilenet(model, input_data_shape, dataset_path, finetune_lr=1e-3):
    '''
        Factory for the MobileNet network utilities.

        Forwards all four arguments, in order, to `networkUtils_mobilenet`
        and returns the object it constructs.
    '''
    network_utils = networkUtils_mobilenet(model, input_data_shape,
                                           dataset_path, finetune_lr)
    return network_utils
3 224 224 \ 2 | -im models/alexnet/model.pth.tar -gp 0 1 2 3 4 5 6 \ 3 | -mi 30 -bur 0.25 -rt FLOPS -irr 0.025 -rd 0.96 \ 4 | -lr 0.001 -st 500 \ 5 | -dp data/ --arch alexnet -------------------------------------------------------------------------------- /scripts/netadapt_helloworld.sh: -------------------------------------------------------------------------------- 1 | python master.py models/helloworld/netadapt 3 32 32 \ 2 | -gp 0 1 2 -mi 3 -bur 0.25 -rt FLOPS \ 3 | -irr 0.025 -rd 1.0 -lr 0.001 -st 5 \ 4 | -im models/helloworld/model_0.pth.tar \ 5 | -lt models/helloworld/lut.pkl -dp data/ \ 6 | --arch helloworld -si 1 -------------------------------------------------------------------------------- /scripts/netadapt_mobilenet-0.5latency.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python master.py models/mobilenet/prune-by-latency 3 224 224 \ 2 | -im models/mobilenet/model.pth.tar -gp 0 1 2 3 4 5 6 \ 3 | -mi 30 -bur 0.25 -rt LATENCY -irr 0.025 -rd 0.96 \ 4 | -lr 0.001 -st 500 -lt latency_lut/lut_mobilenet.pkl \ 5 | -dp data/ --arch mobilenet -------------------------------------------------------------------------------- /scripts/netadapt_mobilenet-0.5mac.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python master.py models/mobilenet/prune-by-mac 3 224 224 \ 2 | -im models/mobilenet/model.pth.tar -gp 0 1 2 3 4 5 6 \ 3 | -mi 30 -bur 0.25 -rt FLOPS -irr 0.025 -rd 0.96 \ 4 | -lr 0.001 -st 500 \ 5 | -dp data/ --arch mobilenet -------------------------------------------------------------------------------- /scripts/unittest.sh: -------------------------------------------------------------------------------- 1 | cd unittest 2 | python unittest_network_utils_helloworld.py 3 | python unittest_network_utils_alexnet.py 4 | python unittest_network_utils_mobilenet.py 5 | python unittest_worker_helloworld.py 6 | cp 
def adjust_learning_rate(optimizer, epoch, args, decay_every=50, decay_factor=0.1):
    """Set the learning rate to the initial LR decayed step-wise.

    The rate is `args.lr * decay_factor ** (epoch // decay_every)`, i.e. with
    the defaults the initial LR is divided by 10 every 50 epochs, and it is
    written into every parameter group of `optimizer`.

    Args:
        optimizer: object exposing `param_groups`, a list of dicts with an
            `'lr'` entry.
        epoch: (int) zero-based index of the current epoch.
        args: object exposing the initial learning rate as `args.lr`.
        decay_every: (int) number of epochs between two decay steps.
        decay_factor: (float) multiplicative factor applied at each step.
    """
    lr = args.lr * (decay_factor ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def eval(test_loader, model, args):
    """Evaluate `model` on `test_loader` and return its accuracy in percent.

    Args:
        test_loader: iterable yielding `(images, target)` batches.
        model: network to evaluate; it is switched to eval mode.
        args: parsed arguments; batches are moved to the GPU unless
            `args.no_cuda` is set.

    Returns:
        (float) average of the per-batch accuracies, weighted by batch size.
    """
    batch_time = AverageMeter()
    acc = AverageMeter()

    # switch to eval mode
    model.eval()

    end = time.time()
    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for i, (images, target) in enumerate(test_loader):
            if not args.no_cuda:
                images = images.cuda()
                target = target.cuda()
            output = model(images)
            batch_acc = compute_accuracy(output, target)
            acc.update(batch_acc, images.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            # Update statistics
            estimated_time_remained = batch_time.get_avg()*(len(test_loader)-i-1)
            fns.update_progress(i, len(test_loader),
                                ESA='{:8.2f}'.format(estimated_time_remained)+'s',
                                acc='{:4.2f}'.format(float(batch_acc))
                                )
    print()
    print('Test accuracy: {:4.2f}% (time = {:8.2f}s)'.format(
        float(acc.get_avg()), batch_time.get_avg()*len(test_loader)))
    print('===================================================================')
    return float(acc.get_avg())
148 | arg_parser = ArgumentParser() 149 | arg_parser.add_argument('data', metavar='DIR', help='path to dataset') 150 | arg_parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 151 | help='number of data loading workers (default: 4)') 152 | arg_parser.add_argument('--epochs', default=150, type=int, metavar='N', 153 | help='number of total epochs to run (default: 150)') 154 | arg_parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 155 | help='manual epoch number (useful on restarts)') 156 | arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet', 157 | choices=model_names, 158 | help='model architecture: ' + 159 | ' | '.join(model_names) + 160 | ' (default: alexnet)') 161 | arg_parser.add_argument('-b', '--batch-size', default=128, type=int, 162 | metavar='N', 163 | help='batch size (default: 128)') 164 | arg_parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 165 | metavar='LR', help='initial learning rate (defult: 0.1)', dest='lr') 166 | arg_parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 167 | help='momentum (default: 0.9)') 168 | arg_parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, 169 | metavar='W', help='weight decay (default: 5e-4)', 170 | dest='weight_decay') 171 | arg_parser.add_argument('--resume', default='', type=str, metavar='PATH', 172 | help='path to latest checkpoint (default: none)') 173 | arg_parser.add_argument('--dir', type=str, default='models/', dest='save_dir', 174 | help='path to save models (default: models/') 175 | arg_parser.add_argument('--no-cuda', action='store_true', default=False, dest='no_cuda', 176 | help='disables training on GPU') 177 | args = arg_parser.parse_args() 178 | print(args) 179 | 180 | path = os.path.dirname(args.save_dir) 181 | if not os.path.exists(path): 182 | os.makedirs(path) 183 | print('Create new directory `{}`'.format(path)) 184 | 185 | # Data loader 186 | train_dataset = 
datasets.CIFAR10(root=args.data, train=True, download=True, 187 | transform=transforms.Compose([ 188 | transforms.RandomCrop(32, padding=4), 189 | transforms.Resize(224), 190 | transforms.RandomHorizontalFlip(), 191 | transforms.ToTensor(), 192 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 193 | ])) 194 | train_loader = torch.utils.data.DataLoader( 195 | train_dataset, batch_size=args.batch_size, shuffle=True, 196 | num_workers=args.workers, pin_memory=True) 197 | test_dataset = datasets.CIFAR10(root=args.data, train=False, download=True, 198 | transform=transforms.Compose([ 199 | transforms.Resize(224), 200 | transforms.ToTensor(), 201 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 202 | ])) 203 | test_loader = torch.utils.data.DataLoader( 204 | test_dataset, batch_size=args.batch_size, shuffle=False, 205 | num_workers=args.workers, pin_memory=True) 206 | 207 | # Network 208 | cudnn.benchmark = True 209 | num_classes = _NUM_CLASSES 210 | model_arch = args.arch 211 | model = models.__dict__[model_arch](num_classes=num_classes) 212 | criterion = nn.BCEWithLogitsLoss() 213 | if not args.no_cuda: 214 | model = model.cuda() 215 | criterion = criterion.cuda() 216 | 217 | # optionally resume from a checkpoint 218 | if args.resume: 219 | if os.path.isfile(args.resume): 220 | print("Loading checkpoint '{}'".format(args.resume)) 221 | model = torch.load(args.resume) 222 | 223 | else: 224 | print("No checkpoint found at '{}'".format(args.resume)) 225 | 226 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 227 | momentum=args.momentum, 228 | weight_decay=args.weight_decay) 229 | 230 | # Train & evaluation 231 | best_acc = 0 232 | filename = os.path.join(args.save_dir) 233 | 234 | for epoch in range(args.start_epoch, args.epochs): 235 | print('Epoch [{}/{}]'.format(epoch+1, args.epochs - args.start_epoch)) 236 | adjust_learning_rate(optimizer, epoch, args) 237 | # train for one epoch 238 | train(train_loader, 
def run_master(working_folder, resource_type='WEIGHT',
               budget_ratio=BUDGET_RATIO,
               budget=None,
               init_reduction_ratio=INIT_REDUCTION_RATIO,
               init_reduction=None,
               reduction_decay=REDUCTION_DECAY,
               finetune_lr=FINETUNE_LR,
               short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
               max_iters=MAX_ITERS,
               lookup_table_path=None,
               resume=False):
    '''
        Run `master.py` in a subprocess and return its exit code.

        The working folder is created when missing, and both stdout and
        stderr of the subprocess are written to
        `<working_folder>/master_log.txt`.

        Input:
            `working_folder`: (string) folder passed to master.py as its
                working directory.
            `resource_type`: (string) forwarded as `-rt`.
                NOTE(review): the default is 'WEIGHT' while every caller in
                this file passes 'WEIGHTS' or 'FLOPS' -- confirm.
            `budget_ratio`, `init_reduction_ratio`, `reduction_decay`,
            `finetune_lr`, `short_term_fine_tune_iteration`, `max_iters`:
                forwarded as `-bur`, `-irr`, `-rd`, `-lr`, `-st`, `-mi`.
            `budget`: forwarded as `-bu` unless None.
            `init_reduction`: forwarded as `-ir` unless None.
            `lookup_table_path`: forwarded as `-lt` unless None.
            `resume`: (bool) append `--resume` when True.

        Output:
            (int) return code of the master.py subprocess.
    '''
    if not os.path.exists(working_folder):
        os.mkdir(working_folder)
        print('Create directory', working_folder)
    with open(os.path.join(working_folder, 'master_log.txt'), 'w') as file_id:
        command_list = [sys.executable, 'master.py', working_folder, str(3), str(32), str(32),
                        '-im', MODEL_PATH,
                        '-gp', str(0), str(1), str(2),
                        '-mi', str(max_iters),
                        '-bur', str(budget_ratio),
                        '-rt', resource_type,
                        '-irr', str(init_reduction_ratio),
                        '-rd', str(reduction_decay),
                        '-lr', str(finetune_lr),
                        '-st', str(short_term_fine_tune_iteration),
                        '-dp', DATASET_PATH,
                        '--arch', MODEL_ARCH,
                        '-si', str(SAVE_INTERVAL)]
        # Optional flags: identity checks against None so that falsy but
        # meaningful values (0, ' ') are still forwarded.
        if lookup_table_path is not None:
            command_list += ['-lt', lookup_table_path]
        if resume:
            command_list += ['--resume']
        if init_reduction is not None:
            command_list += ['-ir', str(init_reduction)]
        if budget is not None:
            command_list += ['-bu', str(budget)]
        print(command_list)
        return subprocess.call(command_list, stdout=file_id, stderr=file_id)
102 | saved_model = torch.load(load_model_path) 103 | acc = network_utils.evaluate(saved_model) 104 | saved_acc = float(tokens[1]) 105 | self.assertEqual(acc, saved_acc, "The accuracy of saved model is not equal to that in history.txt") 106 | self.assertEqual(acc_gt[i-1], acc, "The accuracy of saved model is incorrect") 107 | 108 | # check resource 109 | saved_network_def = network_utils.get_network_def_from_model(saved_model) 110 | resource = network_utils.compute_resource(saved_network_def, resource_type, lookup_table_path) 111 | saved_resource = float(tokens[2]) 112 | self.assertEqual(resource, saved_resource, "The resource of saved model is not equal to that in history txt") 113 | self.assertEqual(res_gt[i-1], resource, "The resource of saved model is incorrect.") 114 | 115 | # check simplified block idx 116 | if i != 1: 117 | tokens_pre = content[i-1].split(',') 118 | output_features = tokens[5].split(' ') 119 | output_features_pre = tokens_pre[5].split(' ') 120 | find_simplified_block = False 121 | for output_idx in range(len(output_features)): 122 | if output_features[output_idx] != output_features_pre[output_idx]: 123 | if not find_simplified_block: 124 | self.assertEqual(output_idx, int(tokens[3]), "Not simplify the block as described in master/history.txt") 125 | self.assertEqual(output_features[output_idx], str(output_feature_gt[i-1][output_idx]), "Simplified block has incorrect # of output channels") 126 | find_simplified_block = True 127 | else: 128 | self.assertEqual(1, 0, "Simplified block index error") 129 | else: 130 | self.assertTrue(output_idx != int(tokens[3]), "Simplify the incorrect block") 131 | 132 | # check network_def 133 | for idx in range(4): 134 | if idx == 0: 135 | self.assertEqual(saved_network_def[idx], (3, output_feature_gt[i-1][idx]), "network_def of simplified model error") 136 | else: 137 | self.assertEqual(saved_network_def[idx], (output_feature_gt[i-1][idx-1], output_feature_gt[i-1][idx]), "network_def of simplified model 
def test_master_weights(self):
    '''
        Run the master with the WEIGHTS resource type for MAX_ITERS
        iterations and compare its history with the expected values.
    '''
    # expected resource, accuracy and per-layer output channels per iteration
    expected_resource = [29232, 28476, 27531, 26181]
    expected_accuracy = [95, 80, 85, 90]
    expected_features = [
        [16, 32, 64, 10],
        [16, 32, 62, 10],
        [13, 32, 62, 10],
        [13, 30, 62, 10],
    ]
    folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights')

    master_kwargs = dict(
        resource_type='WEIGHTS',
        budget_ratio=BUDGET_RATIO,
        init_reduction_ratio=INIT_REDUCTION_RATIO,
        reduction_decay=REDUCTION_DECAY,
        finetune_lr=FINETUNE_LR,
        short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
        max_iters=MAX_ITERS,
        lookup_table_path=None,
        resume=False,
    )
    run_master(folder, **master_kwargs)

    self.check_master_results(
        folder, expected_accuracy, expected_resource, expected_features,
        resource_type='WEIGHTS',
        short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
        max_iters=MAX_ITERS,
        lookup_table_path=None)
    shutil.rmtree(folder)
def test_master_flops_without_built_lookup_table(self):
    '''
        Run the master with the FLOPS resource type and a lookup-table path
        named `not_built_flops_lut.pkl`, check the history, then remove the
        lookup table and the working folder.
    '''
    model_dir = os.path.join('models', MODEL_ARCH)
    folder = os.path.join(model_dir, 'unittest_master_flops_without_built_lookup_table')
    lut_path = os.path.join(model_dir, 'not_built_flops_lut.pkl')

    # expected values per iteration; FLOPS = weights * 32 * 32
    expected_resource = [weights*32*32 for weights in (29232, 28476, 27531, 26181)]
    expected_accuracy = [95, 80, 85, 90]
    expected_features = [
        [16, 32, 64, 10],
        [16, 32, 62, 10],
        [13, 32, 62, 10],
        [13, 30, 62, 10],
    ]

    run_master(
        folder,
        resource_type='FLOPS',
        budget_ratio=BUDGET_RATIO,
        init_reduction_ratio=INIT_REDUCTION_RATIO,
        reduction_decay=REDUCTION_DECAY,
        finetune_lr=FINETUNE_LR,
        short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
        lookup_table_path=lut_path,
        resume=False)
    self.check_master_results(
        folder, expected_accuracy, expected_resource, expected_features,
        resource_type='FLOPS',
        short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
        max_iters=MAX_ITERS,
        lookup_table_path=lut_path)

    os.remove(lut_path)
    shutil.rmtree(folder)
| output_feature_gt = [[16, 32, 64, 10], 230 | [16, 32, 62, 10], 231 | [13, 32, 62, 10], 232 | [13, 30, 62, 10] 233 | ] 234 | run_master(working_folder, resource_type='FLOPS', 235 | budget_ratio=BUDGET_RATIO, 236 | init_reduction_ratio=INIT_REDUCTION_RATIO, 237 | reduction_decay=REDUCTION_DECAY, 238 | finetune_lr=FINETUNE_LR, 239 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 240 | max_iters=max_iters_before_resume, 241 | lookup_table_path=lookup_table_path, 242 | resume=False) 243 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='FLOPS', 244 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 245 | max_iters=max_iters_before_resume, lookup_table_path=lookup_table_path) 246 | 247 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'rb') as file_id: 248 | history_pkl = pickle.load(file_id) 249 | his_args = history_pkl['master_args'] 250 | his_args.max_iters = MAX_ITERS 251 | history_pkl['master_args'] = his_args 252 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'wb') as file_id: 253 | pickle.dump(history_pkl, file_id) 254 | 255 | run_master(working_folder, resource_type='FLOPS', 256 | budget_ratio=0, 257 | init_reduction_ratio=0, 258 | reduction_decay=0, 259 | finetune_lr=FINETUNE_LR, 260 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 261 | max_iters=max_iters_before_resume, 262 | lookup_table_path=' ', 263 | resume=True) 264 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='FLOPS', 265 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 266 | max_iters=MAX_ITERS, lookup_table_path=lookup_table_path) 267 | 268 | os.remove(lookup_table_path) 269 | 270 | 271 | def test_master_weights_budget_met(self): 272 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights_budget_met') 273 | res_gt = [29232, 28476] 274 | acc_gt = [95, 80] 275 | output_feature_gt = [[16, 
def test_master_weights_use_init_reduction_not_ratio(self):
    '''
        Same scenario as the WEIGHTS run, but the initial reduction is given
        as an absolute value (`init_reduction`) with `init_reduction_ratio`
        set to 0.
    '''
    folder = os.path.join('models', MODEL_ARCH,
                          'unittest_master_weights_use_init_reduction_not_ratio')
    expected_resource = [29232, 28476, 27531, 26181]
    expected_accuracy = [95, 80, 85, 90]
    expected_features = [
        [16, 32, 64, 10],
        [16, 32, 62, 10],
        [13, 32, 62, 10],
        [13, 30, 62, 10],
    ]
    absolute_init_reduction = INIT_REDUCTION_RATIO*29232

    run_master(
        folder,
        resource_type='WEIGHTS',
        budget_ratio=BUDGET_RATIO,
        init_reduction_ratio=0,
        init_reduction=absolute_init_reduction,
        reduction_decay=REDUCTION_DECAY,
        finetune_lr=FINETUNE_LR,
        short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
        lookup_table_path=None,
        resume=False)
    self.check_master_results(
        folder, expected_accuracy, expected_resource, expected_features,
        resource_type='WEIGHTS',
        short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
        max_iters=MAX_ITERS,
        lookup_table_path=None)
    shutil.rmtree(folder)
working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_weights_use_budget_not_ratio') 323 | res_gt = [29232, 28476, 27531, 26181] 324 | acc_gt = [95, 80, 85, 90] 325 | output_feature_gt = [[16, 32, 64, 10], 326 | [16, 32, 62, 10], 327 | [13, 32, 62, 10], 328 | [13, 30, 62, 10] 329 | ] 330 | run_master(working_folder, resource_type='WEIGHTS', 331 | budget_ratio=0, 332 | budget=29232*BUDGET_RATIO, 333 | init_reduction_ratio=INIT_REDUCTION_RATIO, 334 | reduction_decay=REDUCTION_DECAY, 335 | finetune_lr=FINETUNE_LR, 336 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 337 | lookup_table_path=None, 338 | resume=False) 339 | self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS', 340 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 341 | max_iters=MAX_ITERS, lookup_table_path=None) 342 | shutil.rmtree(working_folder) 343 | 344 | 345 | def test_master_delete_previous_files_and_not_resume(self): 346 | working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_delete_previous_files_and_not_resume') 347 | res_gt = [29232, 28476, 27531, 26181] 348 | acc_gt = [95, 80, 85, 90] 349 | output_feature_gt = [[16, 32, 64, 10], 350 | [16, 32, 62, 10], 351 | [13, 32, 62, 10], 352 | [13, 30, 62, 10] 353 | ] 354 | run_master(working_folder, resource_type='WEIGHTS', 355 | budget_ratio=BUDGET_RATIO, 356 | init_reduction_ratio=INIT_REDUCTION_RATIO, 357 | reduction_decay=REDUCTION_DECAY, 358 | finetune_lr=FINETUNE_LR, 359 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 360 | lookup_table_path=None, 361 | resume=False) 362 | for file in os.listdir(os.path.join(working_folder, 'master')): 363 | os.remove(os.path.join(working_folder, 'master', file)) 364 | for file in os.listdir(os.path.join(working_folder, 'worker')): 365 | os.remove(os.path.join(working_folder, 'worker', file)) 366 | returncode = run_master(working_folder, resource_type='WEIGHTS', 367 | 
def test_master_resume(self):
    '''
        Exercise the resume scenario twice: once interrupted after 1
        iteration and once after 2, each in its own working folder that is
        removed afterwards.
    '''
    folder_prefix = 'unittest_master_flops_without_built_lookup_table-resume_'
    for iters_before_resume in (1, 2):
        folder = os.path.join('models', MODEL_ARCH,
                              folder_prefix + str(iters_before_resume))
        self.master_flops_without_built_lookup_table_resume(folder, iters_before_resume)
        shutil.rmtree(folder)
# Normal usage; however, if the constraint is too tight or not achievable,
# it will raise ValueError.
def test_constraint_not_achievable(self):
    """Run the master with ``budget_ratio=0`` and check how it terminates.

    The test asserts that the per-iteration results match the ground truth
    below, that the master process returns code 1, and that the master log
    contains exactly three lines mentioning ``UserWarning``.
    """
    working_folder = os.path.join('models', MODEL_ARCH, 'unittest_master_constraint_not_achievable')
    # Register the removal up front so the folder is deleted even when one of
    # the assertions below fails; a leftover folder would otherwise pollute
    # the next run of this test.
    self.addCleanup(shutil.rmtree, working_folder, ignore_errors=True)

    res_gt = [29232, 5418, 693, 135]
    acc_gt = [95, 80, 85, 90]
    output_feature_gt = [[16, 32, 64, 10],
                         [16, 32, 1, 10],
                         [1, 32, 1, 10],
                         [1, 1, 1, 10]]

    returncode = run_master(working_folder, resource_type='WEIGHTS',
                            budget_ratio=0,
                            init_reduction_ratio=1,
                            reduction_decay=REDUCTION_DECAY,
                            finetune_lr=FINETUNE_LR,
                            max_iters=MAX_ITERS+1,
                            short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
                            lookup_table_path=None,
                            resume=False)
    self.check_master_results(working_folder, acc_gt, res_gt, output_feature_gt, resource_type='WEIGHTS',
                              short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION,
                              max_iters=MAX_ITERS, lookup_table_path=None)
    self.assertEqual(returncode, 1, "Master resource constraint not achievable error")

    # Count the log lines that mention a UserWarning.
    master_log_path = os.path.join(working_folder, 'master_log.txt')
    with open(master_log_path) as log_file:
        warning_counter = sum('UserWarning' in line for line in log_file)
    self.assertEqual(warning_counter, 3, "Target resource warning by master error")
resume=False) 507 | os.remove(lookup_table_path) 508 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'rb') as file_id: 509 | history_pkl = pickle.load(file_id) 510 | his_args = history_pkl['master_args'] 511 | his_args.max_iters = MAX_ITERS 512 | history_pkl['master_args'] = his_args 513 | with open(os.path.join(working_folder, 'master', 'history.pickle'), 'wb') as file_id: 514 | pickle.dump(history_pkl, file_id) 515 | returncode = run_master(working_folder, resource_type='FLOPS', 516 | budget_ratio=BUDGET_RATIO, 517 | init_reduction_ratio=INIT_REDUCTION_RATIO, 518 | reduction_decay=REDUCTION_DECAY, 519 | max_iters=1, 520 | finetune_lr=FINETUNE_LR, 521 | short_term_fine_tune_iteration=SHORT_TERM_FINE_TUNE_ITERATION, 522 | lookup_table_path=lookup_table_path, 523 | resume=True) 524 | self.assertEqual(returncode, 1, "Master does not detect the error incurred when resuming from previous iterations but lookup table not found") 525 | shutil.rmtree(working_folder) 526 | 527 | 528 | 529 | if __name__ == '__main__': 530 | 531 | unittest.main() 532 | -------------------------------------------------------------------------------- /unittest/unittest_network_utils_alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | sys.path.append(os.path.abspath('../')) 5 | import pickle 6 | import network_utils as networkUtils 7 | import unittest 8 | import nets as models 9 | from constants import * 10 | import copy 11 | 12 | MODEL_ARCH = 'alexnet' 13 | INPUT_DATA_SHAPE = (3, 224, 224) 14 | LOOKUP_TABLE_PATH = os.path.join('../models', MODEL_ARCH, 'lut.pkl') 15 | DATASET_PATH = '../data/' 16 | 17 | model = models.__dict__[MODEL_ARCH](num_classes=10) 18 | network_utils = networkUtils.__dict__[MODEL_ARCH](model, INPUT_DATA_SHAPE, DATASET_PATH) 19 | 20 | class TestNetworkUtils_alexnet(unittest.TestCase): 21 | def __init__(self, *args, **kwargs): 22 | super(TestNetworkUtils_alexnet, 
self).__init__(*args, **kwargs) 23 | 24 | 25 | def check_network_def(self, network_def, input_channels, output_channels, only_num_channels=False): 26 | self.assertEqual(len(network_def), 8, "network_def length error") 27 | layer_idx = 0 28 | 29 | kernel_size_gt = [(11, 11), (5, 5), (3, 3), (3, 3), (3, 3), (1, 1), (1, 1), (1, 1)] 30 | stride_gt = [(4, 4), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)] 31 | padding_gt = [(2, 2), (2, 2), (1, 1), (1, 1), (1, 1), (0, 0), (0, 0), (0, 0)] 32 | for layer_name, layer_properties in network_def.items(): 33 | self.assertEqual(layer_properties[KEY_NUM_IN_CHANNELS], input_channels[layer_idx], "network_def num of input channels error") 34 | self.assertEqual(layer_properties[KEY_NUM_OUT_CHANNELS], output_channels[layer_idx], "network_def num of output channels error") 35 | self.assertFalse(layer_properties[KEY_IS_DEPTHWISE], "network_def is_depthwise error") 36 | self.assertEqual(layer_properties[KEY_GROUPS], 1, "network_def group error") 37 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], kernel_size_gt[layer_idx], "network_def kernel size error") 38 | self.assertEqual(layer_properties[KEY_PADDING], padding_gt[layer_idx], "network_def padding error") 39 | self.assertEqual(layer_properties[KEY_STRIDE], stride_gt[layer_idx], "network_def stride error") 40 | 41 | if layer_idx < 5: 42 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Conv2d', "network_def layer type string error") 43 | else: 44 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Linear', "network_def layer type string error") 45 | 46 | input_feature_map_spatial_size = [224, 27, 13, 13, 13, 1, 1, 1] 47 | output_feature_map_spatial_size = [55, 27, 13, 13, 13, 1, 1, 1] 48 | if not only_num_channels: 49 | self.assertEqual(layer_properties[KEY_INPUT_FEATURE_MAP_SIZE], [1, input_channels[layer_idx], 50 | input_feature_map_spatial_size[layer_idx], input_feature_map_spatial_size[layer_idx]], 51 | "network_def input feature map size error") 52 | 
def delta_to_layer_num_channels(self, delta, simp_block_idx):
    """Return the expected (input, output) channel lists after pruning a block.

    Starting from the unpruned per-layer channel counts, `delta` output
    channels are removed from layer `simp_block_idx` and the matching input
    channels are removed from the following layer.

    Returns:
        A tuple ``(input_channels, output_channels)`` of two 8-element lists.
    """
    expected_in = [3, 64, 192, 384, 256, 9216, 4096, 4096]
    expected_out = [64, 192, 384, 256, 256, 4096, 4096, 10]

    # Layer 4 has 256 output channels but layer 5 has 9216 = 256*36 inputs,
    # so each channel removed from layer 4 removes 36 inputs from layer 5.
    inputs_per_channel = 36 if simp_block_idx == 4 else 1

    expected_out[simp_block_idx] -= delta
    expected_in[simp_block_idx + 1] -= delta * inputs_per_channel
    return expected_in, expected_out
simp_block_indices=simp_block_indices, 136 | delta=delta_mac, res_gt=num_mac_gt) 137 | 138 | 139 | def test_simplify_model_based_on_network_def(self): 140 | network_def = network_utils.get_network_def_from_model(model) 141 | total_num_w = 57035456 142 | constraint_num_w = total_num_w*0.975 143 | simp_block_indices = [0, 1, 4, 6] 144 | delta_w = [56, 184, 16, 352] 145 | topk_w = [8, 8, 240, 3744] 146 | 147 | conv_idx = [0, 3, 6, 8, 10] 148 | fc_idx = [1, 4, 6] 149 | 150 | for i in range(len(simp_block_indices)): 151 | simp_block_idx = simp_block_indices[i] 152 | simp_network_def, _ = network_utils.simplify_network_def_based_on_constraint(network_def, 153 | simp_block_idx, constraint_num_w, "WEIGHTS") 154 | simp_model = network_utils.simplify_model_based_on_network_def(simp_network_def, model) 155 | updated_network_def = network_utils.get_network_def_from_model(simp_model) 156 | input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta_w[i], simp_block_idx) 157 | self.check_network_def(updated_network_def, input_channels_gt, output_channels_gt) 158 | 159 | for block_idx in range(7): 160 | if block_idx < 5: # conv 161 | layer = getattr(model, 'features') 162 | layer = getattr(layer, str(conv_idx[block_idx])) 163 | simp_layer = getattr(simp_model, 'features') 164 | simp_layer = getattr(simp_layer, str(conv_idx[block_idx])) 165 | else: 166 | layer = getattr(model, 'classifier') 167 | layer = getattr(layer, str(fc_idx[block_idx - 5])) 168 | simp_layer = getattr(simp_model, 'classifier') 169 | simp_layer = getattr(simp_layer, str(fc_idx[block_idx - 5])) 170 | 171 | if block_idx != simp_block_idx and block_idx != simp_block_idx + 1: 172 | equal_weight = (layer.weight.data == simp_layer.weight.data) 173 | equal_bias = (layer.bias.data == simp_layer.bias.data) 174 | self.assertTrue(equal_weight.min(), "simplify_model_based_on_network_def modify unrelated layers (weights)") 175 | self.assertTrue(equal_bias.min(), "simplify_model_based_on_network_def 
modify unrelated layers (biases)") 176 | elif block_idx == simp_block_idx: 177 | layer_weight = layer.weight.data 178 | layer_weight = layer_weight.view(layer_weight.shape[0], -1) 179 | layer_weight_norm = layer_weight*layer_weight 180 | layer_weight_norm = layer_weight_norm.sum(1) 181 | _, kept_filter_idx = torch.topk(layer_weight_norm, topk_w[i], sorted=False) 182 | kept_filter_idx, _ = torch.sort(kept_filter_idx) 183 | 184 | equal_prune_weights = (layer.weight.data[kept_filter_idx, ::] == simp_layer.weight.data) 185 | self.assertTrue(equal_prune_weights.min(), "Output channels of the pruned layer error") 186 | 187 | equal_prune_biases = (layer.bias.data[kept_filter_idx] == simp_layer.bias.data) 188 | self.assertTrue(equal_prune_biases.min(), "Output channels of the pruned layer error") 189 | 190 | # check the input features of the next layer 191 | if (block_idx + 1) < 5: # conv 192 | next_layer = getattr(model, 'features') 193 | next_layer = getattr(next_layer, str(conv_idx[block_idx + 1])) 194 | simp_next_layer = getattr(simp_model, 'features') 195 | simp_next_layer = getattr(simp_next_layer, str(conv_idx[block_idx + 1])) 196 | else: 197 | next_layer = getattr(model, 'classifier') 198 | next_layer = getattr(next_layer, str(fc_idx[(block_idx + 1) - 5])) 199 | simp_next_layer = getattr(simp_model, 'classifier') 200 | simp_next_layer = getattr(simp_next_layer, str(fc_idx[(block_idx + 1) - 5])) 201 | 202 | if block_idx != 4: 203 | if block_idx < 5: 204 | equal_weights = (next_layer.weight.data[:, kept_filter_idx, ::] == simp_next_layer.weight.data) 205 | else: 206 | equal_weights = (next_layer.weight.data[:, kept_filter_idx] == simp_next_layer.weight.data) 207 | self.assertTrue(equal_weights.min(), "Input channels of the layer after the pruned layer error") 208 | else: # conv -> FC 209 | kept_filter_idx_fc = [] 210 | for filter_idx in kept_filter_idx: 211 | for i in range(36): 212 | kept_filter_idx_fc.append(filter_idx*36 + i) 213 | equal_weights = 
def test_build_latency_lookup_table(self):
    """Build a latency lookup table and check its layout against network_def.

    For every layer in the pickled table the test checks that the stored
    layer properties equal those in `network_def`, and that the number of
    latency entries equals (#input samples) * (#output samples) derived from
    the minimum feature sizes passed to `build_lookup_table`.
    """
    network_def = network_utils.get_network_def_from_model(model)
    lookup_table_path = './unittest_lookup_table.plk'
    min_conv_feature_size = 32
    min_fc_feature_size = 1024
    measure_latency_batch_size = 1
    measure_latency_sample_times = 1

    def remove_lookup_table():
        # Tolerate a missing file: the build step may fail before writing it.
        try:
            os.remove(lookup_table_path)
        except FileNotFoundError:
            pass

    # Registered before the file is created so that a failing assertion
    # below no longer leaves the temporary lookup table on disk.
    self.addCleanup(remove_lookup_table)

    network_utils.build_lookup_table(network_def, 'LATENCY', lookup_table_path, min_conv_feature_size,
                                     min_fc_feature_size, measure_latency_batch_size, measure_latency_sample_times)

    with open(lookup_table_path, 'rb') as file_id:
        lookup_table = pickle.load(file_id)
    self.assertEqual(len(lookup_table), 8, "Lookup table length error")

    input_channels_gt = [3, 64, 192, 384, 256, 9216, 4096, 4096]
    output_channels_gt = [64, 192, 384, 256, 256, 4096, 4096, 10]

    # Properties that must be copied unchanged from network_def, paired with
    # the label used in the failure message.
    mirrored_properties = [
        (KEY_IS_DEPTHWISE, 'is_depthwise'),
        (KEY_NUM_IN_CHANNELS, 'num_in_channels'),
        (KEY_NUM_OUT_CHANNELS, 'num_out_channels'),
        (KEY_KERNEL_SIZE, 'kernel_size'),
        (KEY_STRIDE, 'stride'),
        (KEY_PADDING, 'padding'),
        (KEY_GROUPS, 'groups'),
        (KEY_LAYER_TYPE_STR, 'layer_type_str'),
        (KEY_INPUT_FEATURE_MAP_SIZE, 'input_feature_size'),
    ]

    for layer_idx, (layer_name, layer_properties) in enumerate(lookup_table.items()):
        for key, label in mirrored_properties:
            self.assertEqual(layer_properties[key], network_def[layer_name][key],
                             "lookup table layer properties error ({})".format(label))

        layer_latency_table = layer_properties[KEY_LATENCY]

        num_in_samples = input_channels_gt[layer_idx]
        num_output_samples = output_channels_gt[layer_idx]
        if layer_idx < 5:
            # Layers 0-4: an input narrower than the minimum size counts as
            # a single sample point.
            if num_in_samples < min_conv_feature_size:
                num_in_samples = 1
            else:
                num_in_samples = num_in_samples/min_conv_feature_size
            num_output_samples = num_output_samples/min_conv_feature_size
        else:
            # Layers 5-7: an output narrower than the minimum size counts as
            # a single sample point.
            num_in_samples = num_in_samples/min_fc_feature_size
            if num_output_samples < min_fc_feature_size:
                num_output_samples = 1
            else:
                num_output_samples = num_output_samples/min_fc_feature_size

        self.assertEqual(len(layer_latency_table), num_in_samples*num_output_samples,
                         "Layerwise latency dict length error (layer index: {})".format(layer_idx))
def check_network_def(self, network_def, input_channels, output_channels):
    """Assert that `network_def` has four layers with the given channel counts.

    Args:
        network_def: container indexable by layer position; entry ``i`` holds
            ``(num_input_channels, num_output_channels)`` in positions 0 and 1.
        input_channels: expected number of input channels per layer.
        output_channels: expected number of output channels per layer.
    """
    # Check the length first: with more layers than expected, the loop below
    # would otherwise run off the end of the expectation lists and raise
    # IndexError instead of reporting the intended assertion message.
    self.assertEqual(len(network_def), 4, "Number of layers in network_def is not equal to that in the original model")
    # Index by position rather than iterating, since only positional
    # indexing of network_def is relied on elsewhere in these tests.
    for layer_idx in range(len(network_def)):
        num_in, num_out = network_def[layer_idx][0], network_def[layer_idx][1]
        self.assertEqual(num_in, input_channels[layer_idx], "Number of input channels in network_def not match with that of the model")
        self.assertEqual(num_out, output_channels[layer_idx], "Number of output channels in network_def not match with that of the model")
def test_compute_flops(self):
    """Build the FLOPS lookup table, then check its entries and the total.

    Every table entry keyed by ``(a, b)`` is expected to equal
    ``32*32*9*a*b``, and the total reported by `compute_resource` is
    compared with a hand-computed value.
    """
    network_def = network_utils.get_network_def_from_model(model)
    network_utils.build_lookup_table(network_def, resource_type='FLOPS', lookup_table_path=LOOKUP_TABLE_PATH)

    with open(LOOKUP_TABLE_PATH, 'rb') as table_file:
        lookup_table = pickle.load(table_file)
    self.assertEqual(len(lookup_table), 4, "Lookup table has wrong number of layers")

    for layer_idx in range(4):
        for feature_points, entry in lookup_table[layer_idx].items():
            expected_entry = 32*32*9*feature_points[0]*feature_points[1]
            self.assertEqual(expected_entry, entry, "Lookup table entry error")

    flops_gt = 3*3*(3*16 + 16*32 + 32*64 + 64*10)*32*32
    flops = network_utils.compute_resource(network_def, 'FLOPS', lookup_table_path=LOOKUP_TABLE_PATH)
    self.assertEqual(flops, flops_gt, "FLOPS estimation error")
def test_simplify_model(self):
    """Simplify each block under a weight budget and check the resulting models.

    Three phases, in this order: (1) simplify the network definition and the
    model for each case and check the channel counts before and after
    rebuilding the definition from the simplified model; (2) compare the
    value returned by `evaluate` for each simplified model; (3) run one
    forward pass through each simplified model.
    """
    network_def = network_utils.get_network_def_from_model(model)

    # (block index, weight budget, expected input channels,
    #  expected output channels, expected evaluate() result)
    cases = (
        (0, 29232 - 3*3*(9*3 + 9*32), [3, 7, 32, 64], [7, 32, 64, 10], 1),
        (1, 29232 - 3*3*(19*16 + 19*64), [3, 16, 13, 64], [16, 13, 64, 10], 5),
        (2, 29232 - 3*3*(15*32 + 15*10), [3, 16, 32, 49], [16, 32, 49, 10], 80),
    )

    simplified_models = []
    for block_idx, budget, in_channels, out_channels, _ in cases:
        simp_network_def, _resource = network_utils.simplify_network_def_based_on_constraint(
            network_def, block_idx, budget, "WEIGHTS")
        self.check_network_def(simp_network_def, in_channels, out_channels)
        simp_model = network_utils.simplify_model_based_on_network_def(simp_network_def, model)
        updated_network_def = network_utils.get_network_def_from_model(simp_model)
        self.check_network_def(updated_network_def, in_channels, out_channels)
        simplified_models.append(simp_model)

    for simp_model, case in zip(simplified_models, cases):
        accuracy = network_utils.evaluate(simp_model)
        self.assertEqual(accuracy, case[-1], "Evaluation function error")

    # A forward pass must run without error on every simplified model.
    batch = torch.randn((4, 3, 32, 32)).cuda()
    for simp_model in simplified_models:
        _ = simp_model(batch)
getattr(model_finetune.features, str(layer_idx*2)) 164 | temp = (layer.weight.data == (torch.zeros_like(layer.weight.data).cuda() + layer_idx)) 165 | temp = torch.min(temp) 166 | temp = temp.item() 167 | self.assertTrue(temp, "Fintune function error when iteration = 1") 168 | 169 | model_finetune_0 = network_utils.fine_tune(model_0, iterations=0) 170 | for layer_idx in range(4): 171 | layer = getattr(model_finetune_0.features, str(layer_idx*2)) 172 | temp = (layer.weight.data == (torch.zeros_like(layer.weight.data).cuda())) 173 | temp = torch.min(temp) 174 | temp = temp.item() 175 | self.assertTrue(temp, "Finetune function error when iteration = 0") 176 | 177 | 178 | if __name__ == '__main__': 179 | unittest.main() 180 | -------------------------------------------------------------------------------- /unittest/unittest_network_utils_mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import pickle 5 | sys.path.append(os.path.abspath('../')) 6 | 7 | import network_utils as networkUtils 8 | import unittest 9 | import nets as models 10 | from constants import * 11 | import copy 12 | 13 | MODEL_ARCH = 'mobilenet' 14 | INPUT_DATA_SHAPE = (3, 224, 224) 15 | LOOKUP_TABLE_PATH = os.path.join('../models', MODEL_ARCH, 'lut.pkl') 16 | DATASET_PATH = '../data/' 17 | 18 | model = models.__dict__[MODEL_ARCH](num_classes=10) 19 | network_utils = networkUtils.__dict__[MODEL_ARCH](model, INPUT_DATA_SHAPE, DATASET_PATH) 20 | 21 | class TestNetworkUtils_mobilenet(unittest.TestCase): 22 | def __init__(self, *args, **kwargs): 23 | super(TestNetworkUtils_mobilenet, self).__init__(*args, **kwargs) 24 | 25 | 26 | def check_network_def(self, network_def, input_channels, output_channels, only_num_channels=False): 27 | self.assertEqual(len(network_def), 28, "network_def length error") 28 | layer_idx = 0 29 | for layer_name, layer_properties in network_def.items(): 30 | 
self.assertEqual(layer_properties[KEY_NUM_IN_CHANNELS], input_channels[layer_idx], "network_def num of input channels error") 31 | self.assertEqual(layer_properties[KEY_NUM_OUT_CHANNELS], output_channels[layer_idx], "network_def num of output channels error") 32 | 33 | if layer_idx % 2 == 1 and layer_idx != 27: 34 | self.assertTrue(layer_properties[KEY_IS_DEPTHWISE], "network_def is_depthwise error") 35 | self.assertEqual(layer_properties[KEY_GROUPS], layer_properties[KEY_NUM_IN_CHANNELS], "network_def group error") 36 | else: 37 | self.assertFalse(layer_properties[KEY_IS_DEPTHWISE], "network_def is_depthwise error") 38 | self.assertEqual(layer_properties[KEY_GROUPS], 1, "network_def group error") 39 | if layer_idx == 27 or (layer_idx % 2 == 0 and layer_idx != 0): 40 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], (1, 1), "network_def kernel size error") 41 | self.assertEqual(layer_properties[KEY_PADDING], (0, 0), "network_def padding error") 42 | else: 43 | self.assertEqual(layer_properties[KEY_KERNEL_SIZE], (3, 3), "network_def kernel size error") 44 | self.assertEqual(layer_properties[KEY_PADDING], (1, 1), "network_def padding error") 45 | if layer_idx != 27: 46 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Conv2d', "network_def layer type string error") 47 | else: 48 | self.assertEqual(layer_properties[KEY_LAYER_TYPE_STR], 'Linear', "network_def layer type string error") 49 | if layer_idx in [0, 3, 7, 11, 23]: 50 | self.assertEqual(layer_properties[KEY_STRIDE], (2, 2), "network_def stride error") 51 | else: 52 | self.assertEqual(layer_properties[KEY_STRIDE], (1, 1), "network_def stride error") 53 | input_feature_map_spatial_size = [224, 112, 112, 112, 56, 56, 56, 56, 28, 28, 28, 28, 14, 54 | 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 7, 7, 7, 1] 55 | output_feature_map_spatial_size = [112, 112, 112, 56, 56, 56, 56, 28, 28, 28, 28, 14, 56 | 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 7, 7, 7, 7, 1] 57 | if not only_num_channels: 58 | 
def gen_layer_weight(self, tensor):
    """Return a new tensor like `tensor` whose i-th slice along dim 0 equals i.

    The result has the shape and dtype of `tensor` (via `torch.zeros_like`);
    `tensor` itself is not modified.
    """
    filled = torch.zeros_like(tensor)
    for slice_idx in range(filled.shape[0]):
        # Basic indexing returns a view, so the in-place add writes into `filled`.
        filled[slice_idx, ::].add_(slice_idx)
    return filled
def delta_to_layer_num_channels(self, delta, simp_block_idx):
    """Return the expected (input, output) channel lists after pruning a block.

    Block ``b`` starts at layer ``2*b``. Removing `delta` channels lowers the
    output of layer ``2*b`` and the input of layer ``2*b + 1``. For every
    block except the last one (index 13) it also lowers the output of layer
    ``2*b + 1`` and the input of layer ``2*b + 2``.

    Returns:
        A tuple ``(input_channels, output_channels)`` of two 28-element lists.
    """
    input_channels_gt = [3, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
                         512, 512, 512, 512, 512, 512, 1024, 1024, 1024]
    output_channels_gt = [32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
                          512, 512, 512, 512, 512, 512, 1024, 1024, 1024, 10]

    first_layer_idx = 2*simp_block_idx
    output_channels_gt[first_layer_idx] -= delta
    input_channels_gt[first_layer_idx + 1] -= delta
    if simp_block_idx != 13:
        # Block 0 needs no special case: 2*0 yields the same indices the
        # dedicated branch used to spell out.
        output_channels_gt[first_layer_idx + 1] -= delta
        input_channels_gt[first_layer_idx + 2] -= delta
    return input_channels_gt, output_channels_gt
def test_simplify_model_based_on_network_def(self):
    """Verify that pruning one block rewrites exactly the layers it should.

    For every block index in `simp_block_indices` the network definition is
    simplified under a 97.5% weight budget, the model is pruned accordingly,
    and three things are checked:
        1. the channel counts of the pruned model match the ground truth;
        2. every block other than the pruned one and its successor keeps
           its conv / batchnorm parameters untouched;
        3. the pruned layer keeps the `topk_w[i]` filters with the largest
           squared L2 norm, and the layers consuming its output (the next
           block, or the FC layer for block 13) are sliced with the same
           filter indices.

    Inside a block, sub-modules '0' and '3' are conv layers and '1' and '4'
    are batchnorm layers; block 0 only has '0' and '1'.
    """
    network_def = network_utils.get_network_def_from_model(model)
    total_num_w = 3195328
    constraint_num_w = total_num_w*0.975
    simp_block_indices = [0, 1, 5, 7, 9, 11, 13]
    delta_w = [24, 56, 104, 80, 80, 56, 80]
    topk_w = [8, 8, 152, 432, 432, 456, 944]

    def assert_same_tensor(expected, actual, msg):
        # torch.equal also reports a shape mismatch as a failure instead of
        # relying on element-wise broadcasting.
        self.assertTrue(torch.equal(expected, actual), msg)

    def assert_conv_untouched(layer, simp_layer):
        assert_same_tensor(layer.weight.data, simp_layer.weight.data,
                           "simplify_model_based_on_network_def modify unrelated conv layers")

    def assert_batchnorm_untouched(layer, simp_layer):
        assert_same_tensor(layer.weight.data, simp_layer.weight.data,
                           "simplify_model_based_on_network_def modify unrelated batchnorm layers (weight)")
        assert_same_tensor(layer.bias.data, simp_layer.bias.data,
                           "simplify_model_based_on_network_def modify unrelated batchnorm layers (bias)")
        self.assertEqual(simp_layer.num_features, layer.num_features,
                         "simplify_model_based_on_network_def modify unrelated batchnorm layers (num_features)")

    conv_layers = getattr(model, 'model')

    for simp_block_idx, delta, num_kept in zip(simp_block_indices, delta_w, topk_w):
        simp_network_def, _ = network_utils.simplify_network_def_based_on_constraint(
            network_def, simp_block_idx, constraint_num_w, "WEIGHTS")
        simp_model = network_utils.simplify_model_based_on_network_def(simp_network_def, model)
        updated_network_def = network_utils.get_network_def_from_model(simp_model)
        input_channels_gt, output_channels_gt = self.delta_to_layer_num_channels(delta, simp_block_idx)
        self.check_network_def(updated_network_def, input_channels_gt, output_channels_gt)

        simp_conv_layers = getattr(simp_model, 'model')
        for block_idx in range(14):
            module = getattr(conv_layers, str(block_idx))
            simp_module = getattr(simp_conv_layers, str(block_idx))

            if block_idx == simp_block_idx + 1:
                # The successor block is verified together with the pruned block.
                continue

            if block_idx != simp_block_idx:
                # Unrelated block: nothing may change.
                layer_names = ['0', '1'] if block_idx == 0 else ['0', '1', '3', '4']
                for layer_name in layer_names:
                    layer = getattr(module, layer_name)
                    simp_layer = getattr(simp_module, layer_name)
                    if layer_name in ('0', '3'):
                        assert_conv_untouched(layer, simp_layer)
                    else:
                        assert_batchnorm_untouched(layer, simp_layer)
                continue

            # --- pruned block ---
            # Check the pruned layer's output channels, then the input
            # channels of whatever consumes them.
            if block_idx == 0:
                pruned_layer_name = '0'
            else:  # pointwise
                # First check the depthwise layer within the same block.
                # The pruned model's layer must be read from `simp_module`;
                # reading it from `module` would compare the original with itself.
                assert_same_tensor(getattr(module, '0').weight.data,
                                   getattr(simp_module, '0').weight.data,
                                   "Depthwise layer within the target block error")
                pruned_layer_name = '3'

            layer = getattr(module, pruned_layer_name)
            simp_layer = getattr(simp_module, pruned_layer_name)

            # Rank the original filters by squared L2 norm and keep the
            # `num_kept` largest, in ascending index order.
            layer_weight = layer.weight.data
            weight_vector = layer_weight.view(layer_weight.shape[0], -1)
            weight_norm = torch.sum(weight_vector*weight_vector, dim=1)
            _, kept_filter_idx = torch.topk(weight_norm, num_kept, sorted=False)
            kept_filter_idx, _ = kept_filter_idx.sort()

            assert_same_tensor(layer_weight[kept_filter_idx, :, :, :], simp_layer.weight.data,
                               "Output channels of the pruned layer error")

            if block_idx != 13:  # depthwise -> batchnorm -> pointwise
                next_module = getattr(conv_layers, str(block_idx+1))
                simp_next_module = getattr(simp_conv_layers, str(block_idx+1))

                dep_layer = getattr(next_module, '0')
                simp_dep_layer = getattr(simp_next_module, '0')
                assert_same_tensor(dep_layer.weight.data[kept_filter_idx, :, :, :],
                                   simp_dep_layer.weight.data,
                                   "Input channels of the depthwise layer after pruned layers error")

                batchnorm_layer = getattr(next_module, '1')
                simp_batchnorm_layer = getattr(simp_next_module, '1')
                assert_same_tensor(batchnorm_layer.weight.data[kept_filter_idx],
                                   simp_batchnorm_layer.weight.data,
                                   "Weights of the batchnorm layer after pruned layers error")
                assert_same_tensor(batchnorm_layer.bias.data[kept_filter_idx],
                                   simp_batchnorm_layer.bias.data,
                                   "Biases of the batchnorm layer after pruned layers error")
                self.assertEqual(len(kept_filter_idx), simp_batchnorm_layer.num_features,
                                 "Number of features of the batchnorm layer after pruned layers error")

                pt_layer = getattr(next_module, '3')
                simp_pt_layer = getattr(simp_next_module, '3')
                assert_same_tensor(pt_layer.weight.data[:, kept_filter_idx, :, :],
                                   simp_pt_layer.weight.data,
                                   "Input channels of the pointwise layer after pruned layers error")
            else:  # FC
                fc_layer = getattr(model, 'fc')
                simp_fc_layer = getattr(simp_model, 'fc')
                assert_same_tensor(fc_layer.weight.data[:, kept_filter_idx],
                                   simp_fc_layer.weight.data,
                                   "Input features of FC layer error")
def test_build_latency_lookup_table(self):
    """Build a latency lookup table for the model and validate its contents.

    Checks that
        * the table has one entry per layer (28),
        * each entry repeats the layer properties stored in `network_def`,
        * each per-layer latency table has the expected number of sampled
          (input, output) channel configurations,
        * layers with identical shapes (layer indices 13~22) got identical
          latency tables.

    The temporary lookup-table file is removed when the test ends, whether
    it passed or failed.
    """
    network_def = network_utils.get_network_def_from_model(model)
    lookup_table_path = './unittest_lookup_table.plk'
    min_conv_feature_size = 32
    min_fc_feature_size = 128
    measure_latency_batch_size = 1
    measure_latency_sample_times = 1

    def remove_lookup_table():
        if os.path.exists(lookup_table_path):
            os.remove(lookup_table_path)

    # Registered before the file is created so that a failing assertion
    # (or a failing build) never leaves the file behind.
    self.addCleanup(remove_lookup_table)

    network_utils.build_lookup_table(network_def, 'LATENCY', lookup_table_path, min_conv_feature_size,
                                     min_fc_feature_size, measure_latency_batch_size,
                                     measure_latency_sample_times)

    with open(lookup_table_path, 'rb') as file_id:
        lookup_table = pickle.load(file_id)
    self.assertEqual(len(lookup_table), 28, "Lookup table length error")

    input_channels_gt = [3, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
                         512, 512, 512, 512, 512, 512, 1024, 1024, 1024]
    output_channels_gt = [32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512,
                          512, 512, 512, 512, 512, 512, 1024, 1024, 1024, 10]

    # (key, label used in the failure message)
    checked_properties = [
        (KEY_IS_DEPTHWISE, 'is_depthwise'),
        (KEY_NUM_IN_CHANNELS, 'num_in_channels'),
        (KEY_NUM_OUT_CHANNELS, 'num_out_channels'),
        (KEY_KERNEL_SIZE, 'kernel_size'),
        (KEY_STRIDE, 'stride'),
        (KEY_PADDING, 'padding'),
        (KEY_GROUPS, 'groups'),
        (KEY_LAYER_TYPE_STR, 'layer_type_str'),
        (KEY_INPUT_FEATURE_MAP_SIZE, 'input_feature_size'),
    ]

    dep_layer_latency_dict_list = []
    pt_layer_latency_dict_list = []

    for layer_idx, (layer_name, layer_properties) in enumerate(lookup_table.items()):
        layer_def = network_def[layer_name]
        for key, label in checked_properties:
            self.assertEqual(layer_properties[key], layer_def[key],
                             "lookup table layer properties error ({})".format(label))

        layer_latency_table = layer_properties[KEY_LATENCY]

        # Number of sampled channel counts per dimension: one sample every
        # `min_*_feature_size` channels, and a single sample when the layer
        # has fewer channels than that step.
        num_in_samples = input_channels_gt[layer_idx]
        num_output_samples = output_channels_gt[layer_idx]
        is_fc_layer = (layer_idx == 27)
        if not is_fc_layer:
            if num_in_samples < min_conv_feature_size:
                num_in_samples = 1
            else:
                num_in_samples = num_in_samples // min_conv_feature_size
            num_output_samples = num_output_samples // min_conv_feature_size
        else:
            num_in_samples = num_in_samples // min_fc_feature_size
            if num_output_samples < min_fc_feature_size:
                num_output_samples = 1

        if not is_fc_layer and layer_idx % 2 == 1:
            # Odd conv layers are only sampled along the input dimension.
            expected_num_entries = num_in_samples
        else:
            expected_num_entries = num_in_samples*num_output_samples
        self.assertEqual(len(layer_latency_table), expected_num_entries,
                         "Layerwise latency dict length error (layer index: {})".format(layer_idx))

        if 13 <= layer_idx <= 22:
            if layer_idx % 2 == 0:  # pointwise layer
                pt_layer_latency_dict_list.append(layer_latency_table)
            else:  # depthwise layer
                dep_layer_latency_dict_list.append(layer_latency_table)

    # Check whether same layers have the same results.
    for layer_kind, latency_dict_list in (('depthwise', dep_layer_latency_dict_list),
                                          ('pointwise', pt_layer_latency_dict_list)):
        for i in range(1, len(latency_dict_list)):
            latency_dict_gt = latency_dict_list[0]
            latency_dict = latency_dict_list[i]
            for key in latency_dict_gt:
                self.assertEqual(latency_dict_gt[key], latency_dict[key],
                                 "Lookup talbe of same {} layers ({}) error".format(layer_kind, i))
class TestWorker_helloworld(unittest.TestCase):
    """End-to-end tests that launch `worker.py` as a subprocess on the helloworld model.

    Each test runs the worker for one or more blocks and then inspects the
    files the worker wrote into `WORKER_FOLDER` (model, accuracy, resource).
    """

    def __init__(self, *args, **kwargs):
        super(TestWorker_helloworld, self).__init__(*args, **kwargs)

    def run_worker(self, constraint, netadapt_iteration, block, resource_type, short_term_fine_tune_iteration,
                   finetune_lr=0.001, lookup_table_path=''):
        """Launch `worker.py` in a subprocess and return its exit code.

        Input:
            `constraint`: resource budget passed to the worker.
            `netadapt_iteration`: (int) iteration index used in output filenames.
            `block`: (int) index of the block to simplify.
            `resource_type`: (str) e.g. "WEIGHTS" or "FLOPS".
            `short_term_fine_tune_iteration`: (int) number of fine-tune iterations.
            `finetune_lr`: (float) short-term fine-tune learning rate.
            `lookup_table_path`: (str) path to the lookup table ('' if unused).

        Output:
            (int) return code of the worker process. Its stdout and stderr
            are redirected to the worker log file in `WORKER_FOLDER`.
        """
        gpu = 0
        log_path = os.path.join(WORKER_FOLDER,
                                common.WORKER_LOG_FILENAME_TEMPLATE.format(netadapt_iteration, block))
        with open(log_path, 'w') as file_id:
            # Argument order must match the positional arguments declared in worker.py.
            command_list = [sys.executable, _WORKER_PY_FILENAME, WORKER_FOLDER,
                            MODEL_PATH, str(block), resource_type, str(constraint),
                            str(netadapt_iteration), str(short_term_fine_tune_iteration), str(gpu),
                            lookup_table_path, DATASET_PATH] \
                + [str(e) for e in INPUT_DATA_SHAPE] + [MODEL_ARCH] + [str(finetune_lr)]
            print(command_list)
            return subprocess.call(command_list, stdout=file_id, stderr=file_id)

    def check_worker_simplify_and_finetune(self, constraint, netadapt_iteration, block, resource_type,
                                           short_term_fine_tune_iteration, resource_gt, acc_gt, network_def_gt,
                                           finetune_lr=0.001, lookup_table_path=''):
        """Run the worker on one block and verify everything it saved.

        Checks the return code, the accuracy and resource of the saved model
        against both the ground truth and the values the worker wrote to its
        accuracy / resource files, the fine-tuned weights, and the network
        definition of the saved model.

        Input:
            `resource_gt`: expected resource of the simplified model.
            `acc_gt`: expected evaluation result of the simplified model.
            `network_def_gt`: expected network definition (one entry per layer).
            Other arguments are forwarded to `run_worker`.
        """
        t = time.time()
        # Forward the caller's learning rate instead of a hard-coded 0.001.
        returncode = self.run_worker(constraint, netadapt_iteration, block,
                                     resource_type, short_term_fine_tune_iteration,
                                     finetune_lr=finetune_lr, lookup_table_path=lookup_table_path)
        print('Worker finish time: {}s'.format(time.time() - t))

        # Check return code
        self.assertEqual(returncode, 0, "Normal worker return value error")
        saved_model = torch.load(os.path.join(WORKER_FOLDER,
                                              common.WORKER_MODEL_FILENAME_TEMPLATE.format(netadapt_iteration, block)))
        acc = network_utils.evaluate(saved_model)
        network_def_saved_model = network_utils.get_network_def_from_model(saved_model)
        res = network_utils.compute_resource(network_def_saved_model, resource_type=resource_type,
                                             lookup_table_path=lookup_table_path)

        with open(os.path.join(WORKER_FOLDER,
                               common.WORKER_ACCURACY_FILENAME_TEMPLATE.format(netadapt_iteration, block)),
                  'r') as file_id:
            saved_acc = float(file_id.read())
        with open(os.path.join(WORKER_FOLDER,
                               common.WORKER_RESOURCE_FILENAME_TEMPLATE.format(netadapt_iteration, block)),
                  'r') as file_id:
            saved_res = float(file_id.read())

        self.assertEqual(acc, acc_gt, "Evaluation of simplified model error")
        self.assertEqual(acc, saved_acc, "The value in accuracy file is not equal to accuracy of somplified model")
        self.assertEqual(res, resource_gt, "Resource of simplified model error")
        self.assertEqual(res, saved_res, "The value in resource file is not equal to resource of somplified model")

        for idx in range(4):
            layer = getattr(saved_model.features, str(idx*2))
            # The stored model starts with all-zero weights; after fine-tuning,
            # layer `idx` is expected to hold idx * (number of iterations) everywhere.
            expected_weight = torch.zeros_like(layer.weight.data) + idx*short_term_fine_tune_iteration
            weights_match = torch.min(layer.weight.data == expected_weight).item()
            self.assertTrue(weights_match, "Model weights after short-term fine-tune are incorrect")

            self.assertEqual(network_def_saved_model[idx], network_def_gt[idx],
                             "network_def of simplified model is incorrect")

    def test_worker_simplify_and_finetune_weights(self):
        '''
            Check simplifying and finetuning block 0~2 under a WEIGHTS constraint
        '''

        netadapt_iteration = 2
        all_resource_weights = 29232

        constraint_weights = [all_resource_weights - 3*3*(8*(3 + 32)),
                              all_resource_weights - 3*3*(7*(16 + 64)),
                              all_resource_weights - 3*3*(31*(32 + 10))]
        resource_weights_gt = [all_resource_weights - 3*3*(8*(3 + 32)),
                               all_resource_weights - 3*3*(7*(16 + 64)),
                               all_resource_weights - 3*3*(31*(32 + 10))]
        eval_acc_gt = [1, 5, 80]

        network_def_gt = [
            [(3, 8), (8, 32), (32, 64), (64, 10)],
            [(3, 16), (16, 25), (25, 64), (64, 10)],
            [(3, 16), (16, 32), (32, 33), (33, 10)]
        ]

        for block in range(3):
            self.check_worker_simplify_and_finetune(constraint=constraint_weights[block],
                netadapt_iteration=netadapt_iteration, block=block,
                resource_type="WEIGHTS", short_term_fine_tune_iteration=block, finetune_lr=0.001,
                lookup_table_path='', resource_gt=resource_weights_gt[block], acc_gt=eval_acc_gt[block],
                network_def_gt=network_def_gt[block])

    def test_worker_simplify_and_finetune_flops(self):
        '''
            Check simplifying and finetuning block 0~2 under a FLOPS constraint
        '''

        netadapt_iteration = -1
        all_resource_flops = 29232*32*32

        constraint_flops = [all_resource_flops - 3*3*(8*(3 + 32))*32*32,
                            all_resource_flops - 3*3*(7*(16 + 64))*32*32,
                            all_resource_flops - 3*3*(31*(32 + 10))*32*32]
        resource_flops_gt = [all_resource_flops - 3*3*(8*(3 + 32))*32*32,
                             all_resource_flops - 3*3*(7*(16 + 64))*32*32,
                             all_resource_flops - 3*3*(31*(32 + 10))*32*32]

        eval_acc_gt = [1, 5, 80]

        network_def_gt = [
            [(3, 8), (8, 32), (32, 64), (64, 10)],
            [(3, 16), (16, 25), (25, 64), (64, 10)],
            [(3, 16), (16, 32), (32, 33), (33, 10)]
        ]

        for block in range(3):
            self.check_worker_simplify_and_finetune(constraint=constraint_flops[block],
                netadapt_iteration=netadapt_iteration, block=block,
                resource_type="FLOPS", short_term_fine_tune_iteration=block, finetune_lr=0.001,
                lookup_table_path=FLOPS_LOOKUP_TABLE_PATH, resource_gt=resource_flops_gt[block],
                acc_gt=eval_acc_gt[block], network_def_gt=network_def_gt[block])

    def test_worker_weights_tight_constraint(self):
        '''
            Check block 0~2 with constraints (0, 1, -1) that are far below the
            model size: the expected network definitions have the target
            layer reduced to a single output channel.
        '''

        netadapt_iteration = 2
        all_resource_weights = 29232

        constraint_weights = [0, 1, -1]
        resource_weights_gt = [all_resource_weights - 3*3*(15*(3 + 32)),
                               all_resource_weights - 3*3*(31*(16 + 64)),
                               all_resource_weights - 3*3*(63*(32 + 10))]
        eval_acc_gt = [1, 5, 80]

        network_def_gt = [
            [(3, 1), (1, 32), (32, 64), (64, 10)],
            [(3, 16), (16, 1), (1, 64), (64, 10)],
            [(3, 16), (16, 32), (32, 1), (1, 10)]
        ]

        for block in range(3):
            self.check_worker_simplify_and_finetune(constraint=constraint_weights[block],
                netadapt_iteration=netadapt_iteration, block=block,
                resource_type="WEIGHTS", short_term_fine_tune_iteration=block, finetune_lr=0.001,
                lookup_table_path='', resource_gt=resource_weights_gt[block], acc_gt=eval_acc_gt[block],
                network_def_gt=network_def_gt[block])

    def test_worker_block_out_of_bound(self):
        '''
            A block index that does not exist must make the worker exit with code 1
        '''
        netadapt_iteration = -5
        all_resource_flops = 29232*32*32
        returncode = self.run_worker(constraint=all_resource_flops,
                                     netadapt_iteration=netadapt_iteration, block=5,
                                     resource_type="FLOPS", short_term_fine_tune_iteration=5, finetune_lr=0.001)

        self.assertEqual(returncode, 1, "Abnormal worker not detected")
39 | model = torch.load(args.model_path) 40 | network_utils = networkUtils.__dict__[args.arch](model, args.input_data_shape, args.dataset_path, args.finetune_lr) 41 | 42 | if network_utils.get_num_simplifiable_blocks() <= args.block: 43 | raise ValueError("Block index >= number of simplifiable blocks") 44 | 45 | network_def = network_utils.get_network_def_from_model(model) 46 | simplified_network_def, simplified_resource = ( 47 | network_utils.simplify_network_def_based_on_constraint(network_def, 48 | args.block, 49 | args.constraint, 50 | args.resource_type, 51 | args.lookup_table_path)) 52 | # Choose the filters. 53 | simplified_model = network_utils.simplify_model_based_on_network_def(simplified_network_def, model) 54 | 55 | print('Original model:') 56 | print(model) 57 | print('') 58 | print('Simplified model:') 59 | print(simplified_model) 60 | 61 | fine_tuned_model = network_utils.fine_tune(simplified_model, args.short_term_fine_tune_iteration) 62 | fine_tuned_accuracy = network_utils.evaluate(fine_tuned_model) 63 | print('Accuracy after finetune:', fine_tuned_accuracy) 64 | 65 | # Save the results. 
if __name__ == '__main__':
    # The worker is driven purely by positional arguments; their order below
    # is the order in which they must appear on the command line.
    positional_arguments = [
        ('worker_folder', dict(type=str,
                               help='directory where model and logging information will be saved')),
        ('model_path', dict(type=str, help='path to model which is to be simplified')),
        ('block', dict(type=int, help='index of block to be simplified')),
        ('resource_type', dict(type=str, help='FLOPS/WEIGHTS/LATENCY')),
        ('constraint', dict(type=float, help='floating value specifying resource constraint')),
        ('netadapt_iteration', dict(type=int, help='netadapt iteration')),
        ('short_term_fine_tune_iteration', dict(type=int,
                                                help='number of iterations of fine-tuning after simplification')),
        ('gpu', dict(type=str, help='index of gpu to run short-term fine-tuning')),
        ('lookup_table_path', dict(type=str, default='', help='path to lookup table')),
        ('dataset_path', dict(type=str, default='', help='path to dataset')),
        ('input_data_shape', dict(nargs=3, default=[], type=int,
                                  help='input shape (for ImageNet: `3 224 224`)')),
        ('arch', dict(default='alexnet',
                      choices=network_utils_all,
                      help='network_utils: ' +
                           ' | '.join(network_utils_all) +
                           ' (default: alexnet)')),
        ('finetune_lr', dict(type=float, default=0.001, help='short-term fine-tune learning rate')),
    ]

    # Parse the input arguments.
    parser = ArgumentParser()
    for argument_name, argument_options in positional_arguments:
        parser.add_argument(argument_name, **argument_options)
    parsed_args = parser.parse_args()

    # Launch a worker.
    worker(parsed_args)