├── .DS_Store
├── .gitignore
├── README.md
├── choose_strategy.py
├── compress_config
│   ├── mbv1_imagenet.yaml
│   └── res50_imagenet.yaml
├── data
│   ├── __init__.py
│   ├── imagenet.py
│   ├── imagenet_train_val_split.py
│   └── imagenet_train_val_split_idx.pickle
├── distiller
│   ├── __init__.py
│   ├── apputils
│   │   ├── __init__.py
│   │   ├── checkpoint.py
│   │   ├── data_loaders.py
│   │   ├── dataset_summaries.py
│   │   └── execution_env.py
│   ├── config.py
│   ├── data_loggers
│   │   ├── __init__.py
│   │   ├── collector.py
│   │   ├── logger.py
│   │   └── tbbackend.py
│   ├── directives.py
│   ├── knowledge_distillation.py
│   ├── learning_rate.py
│   ├── model_summaries.py
│   ├── models
│   │   ├── __init__.py
│   │   └── imagenet
│   │       ├── __init__.py
│   │       ├── alexnet_batchnorm.py
│   │       ├── mobilenet.py
│   │       ├── preresnet_imagenet.py
│   │       ├── resnet.py
│   │       └── resnet_earlyexit.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── eltwise.py
│   │   ├── grouping.py
│   │   └── rnn.py
│   ├── policy.py
│   ├── pruning
│   │   ├── __init__.py
│   │   ├── automated_gradual_pruner.py
│   │   ├── baidu_rnn_pruner.py
│   │   ├── greedy_filter_pruning.py
│   │   ├── level_pruner.py
│   │   ├── magnitude_pruner.py
│   │   ├── pruner.py
│   │   ├── ranked_structures_pruner.py
│   │   ├── sensitivity_pruner.py
│   │   ├── splicing_pruner.py
│   │   └── structure_pruner.py
│   ├── quantization
│   │   ├── __init__.py
│   │   ├── clipped_linear.py
│   │   ├── q_utils.py
│   │   ├── quantizer.py
│   │   └── range_linear.py
│   ├── regularization
│   │   ├── __init__.py
│   │   ├── drop_filter.py
│   │   ├── group_regularizer.py
│   │   ├── l1_regularizer.py
│   │   └── regularizer.py
│   ├── scheduler.py
│   ├── sensitivity.py
│   ├── summary_graph.py
│   ├── thinning.py
│   ├── thresholding.py
│   └── utils.py
├── docker
│   ├── Dockerfile
│   └── requirements.txt
├── fig
│   ├── cor_fix_flops.png
│   ├── eye.png
│   └── pipeline.png
├── finetune.py
├── inference.py
├── models
│   ├── __init__.py
│   ├── mobilenet.py
│   ├── resnet.py
│   └── wrapper.py
├── options
│   ├── __init__.py
│   └── base_options.py
├── report
│   └── __init__.py
├── requirements.txt
├── scripts
│   ├── mbv1_50flops.sh
│   ├── res50_25flops.sh
│   ├── res50_50flops.sh
│   └── res50_75flops.sh
├── search.py
├── search_results
│   ├── best_strategy_mbv1_50flops.txt
│   ├── best_strategy_res50_25flops.txt
│   ├── best_strategy_res50_50flops.txt
│   └── best_strategy_res50_75flops.txt
└── thinning
    └── __init__.py

--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.gitignore.io/api/python
3 | # Edit at https://www.gitignore.io/?templates=python
4 |
5 | models/ckpt/
6 | ### Python ###
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # pipenv
76 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
77 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
78 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
79 | # install all needed dependencies.
80 | #Pipfile.lock
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Spyder project settings
89 | .spyderproject
90 | .spyproject
91 |
92 | # Rope project settings
93 | .ropeproject
94 |
95 | # Mr Developer
96 | .mr.developer.cfg
97 | .project
98 | .pydevproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .dmypy.json
106 | dmypy.json
107 |
108 | # Pyre type checker
109 | .pyre/
110 |
111 | # End of https://www.gitignore.io/api/python
112 | logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning
2 |
3 | ![Python version support](https://img.shields.io/badge/python-3.6-blue.svg)
4 | ![PyTorch version support](https://img.shields.io/badge/pytorch-1.1.0-red.svg)
5 |
6 | PyTorch implementation for *[EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning](https://arxiv.org/abs/2007.02491)*
7 |
8 | [Bailin Li](https://github.com/bezorro), [Bowen Wu](https://github.com/Bowenwu1), Jiang Su, [Guangrun Wang](https://wanggrun.github.io/projects/zw), [Liang Lin](http://www.linliang.net/)
9 |
10 | Presented at [ECCV 2020 (Oral)](https://eccv2020.eu/accepted-papers/)
11 |
12 | Check the [slides](https://dmmo.dm-ai.cn/eagle_eye/dmai_eagleeye_jiqizhixin202008.pdf) about EagleEye: “High-performance AI on the Edge: from perspectives of model compression and hardware architecture design”, DMAI HiPerAIR, Aug. 2020.
13 |
14 | ![pipeline](fig/eye.png)
15 |
16 | ## Citation
17 |
18 | If you use EagleEye in your research, please consider citing:
19 |
20 | ```
21 | @misc{li2020eagleeye,
22 |   title={EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning},
23 |   author={Bailin Li and Bowen Wu and Jiang Su and Guangrun Wang and Liang Lin},
24 |   year={2020},
25 |   eprint={2007.02491},
26 |   archivePrefix={arXiv},
27 |   primaryClass={cs.CV}
28 | }
29 | ```
30 |
31 | ## Update
32 |
33 | * 2021-11-03: We uploaded a `Dockerfile` for convenient setup.
34 |
35 | * 2021-03-03: We updated the pretrained baseline ResNet50 for ImageNet in [Google Drive](). Before that, an incorrect pretrained model caused lower experimental results.
36 |
37 | ## Adaptive-BN-based Candidate Evaluation
38 |
39 | To ease your own implementation, here we present the key code for the proposed Adaptive-BN-based candidate evaluation. The official implementation will be released soon.
40 |
41 | ```python
42 | def eval_pruning_strategy(model, pruning_strategy, dataloader_train):
43 |     # Apply filter pruning to the trained model (repo-specific helper)
44 |     pruned_model = prune(model, pruning_strategy)
45 |
46 |     # Adaptive-BN: train() mode lets BN layers re-estimate their running statistics
47 |     pruned_model.train()
48 |     max_iter = 100
49 |     with torch.no_grad():  # no gradient updates; only the BN buffers change
50 |         for iter_in_epoch, (images, _) in enumerate(dataloader_train):
51 |             pruned_model(images)
52 |             if iter_in_epoch > max_iter:
53 |                 break
54 |
55 |     # Eval top-1 accuracy of the pruned model (repo-specific evaluation helper)
56 |     acc = pruned_model.get_val_acc()
57 |     return acc
58 | ```
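The snippet relies on two repo-specific helpers (`prune` and `get_val_acc`). If you only need the Adaptive-BN part, the same effect can be obtained in plain PyTorch by explicitly re-estimating the BN running statistics; below is a minimal sketch under that assumption (the `adaptive_bn` name, the `max_iter` default, and the `momentum=None` trick are ours, not part of this repo):

```python
import torch
import torch.nn as nn

def adaptive_bn(pruned_model, dataloader_train, max_iter=100):
    # Reset each BatchNorm's running statistics; momentum=None makes BN
    # accumulate a cumulative average over the calibration batches instead
    # of an exponential moving average.
    for m in pruned_model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None
    pruned_model.train()   # BN buffers are only updated in train mode
    with torch.no_grad():  # weights stay frozen; only BN buffers move
        for i, (images, _) in enumerate(dataloader_train):
            pruned_model(images)  # assumes model and data share a device
            if i >= max_iter:
                break
    return pruned_model
```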
59 |
60 | ## Baseline Model Training
61 |
62 | The code used for training the baseline models (MobileNetV1, ResNet50) will be released at [CNNResearchToolkit](https://github.com/Bowenwu1/CNNResearchToolkit). Everyone is welcome to follow it!
63 |
64 | ## Setup
65 |
66 | 1. **Prepare Data**
67 |
68 |    Download the `ILSVRC2012` dataset from http://image-net.org/challenges/LSVRC/2012/index#introduction
69 |
70 | 2. **Download Pretrained Models**
71 |
72 |    We provide pretrained baseline models and the reported pruned models in [Google Drive](). Please put the downloaded models in `models/ckpt/`.
73 |
74 | 3. **Prepare Runtime Environment**
75 |
76 |    **Via pip/conda**
77 |    ```shell
78 |    pip install -r requirements.txt
79 |    ```
80 |
81 |    **Via Docker**
82 |    ```shell
83 |    # Build image
84 |    docker build docker/ -t eagleeye:[tag]
85 |
86 |    # Launch docker container
87 |    docker run -it --rm \
88 |        -v [PATH-TO-EAGLEEYE]:/workspace/EagleEye \
89 |        -v [PATH-TO-IMAGENET]:/data/imagenet \
90 |        --ipc=host \
91 |        eagleeye:[tag]
92 |    ```
93 |
94 | ## Usage
95 |
96 | Our proposed EagleEye contains 3 steps:
97 |
98 | 1. Adaptive-BN-based Searching for Pruning Strategy
99 | 2. Candidate Selection
100 | 3. Fine-tuning of Pruned Model
101 |
102 | ### 1. Adaptive-BN-based Searching for Pruning Strategy
103 |
104 | In this step, pruning strategies are randomly generated, and the Adaptive-BN-based evaluation is performed on each of them. The strategies and their evaluation scores are saved to `search_results/pruning_strategies.txt`.
105 |
106 | If you do not want to run the search yourself, the provided search results can be found in `search_results/`.
107 |
108 | Parameters involved in this step:
109 |
110 | |Name|Description|
111 | |----|-----------|
112 | |`--flops_target`|The remaining ratio of FLOPs of the pruned model|
113 | |`--max_rate`<br>`--min_rate`|Define the search space. The search space is [min_rate, max_rate]|
114 | |`--output_file`|File that stores the search results.|
115 |
116 | For a sample script, refer to `1. Search` in `scripts/mbv1_50flops.sh`. A sketch of the search loop follows the table below.
117 |
118 | **Search space for different models**
119 |
120 | |Model|Pruned FLOPs|[min_rate, max_rate]|
121 | |-----|-----|--------------------|
122 | |MobileNetV1|-50%|[0, 0.7]|
123 | |ResNet50|-25%|[0, 0.4]|
124 | |ResNet50|-50%|[0, 0.7]|
125 | |ResNet50|-75%|[0, 0.8]|
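For orientation, here is a hypothetical sketch of this search loop (the real logic lives in `search.py`; `prune`, `count_flops`, and `eval_pruning_strategy` stand in for the repo's pruning, FLOPs-counting, and evaluation utilities, with `eval_pruning_strategy` as shown above):

```python
import random

def random_search(model, n_layers, min_rate, max_rate, flops_target,
                  dataloader_train, n_candidates=1000, tolerance=0.01):
    base_flops = count_flops(model)  # hypothetical FLOPs counter
    results = []
    while len(results) < n_candidates:
        # One pruning ratio per layer, drawn uniformly from the search space
        strategy = [random.uniform(min_rate, max_rate) for _ in range(n_layers)]
        pruned_model = prune(model, strategy)
        if abs(count_flops(pruned_model) / base_flops - flops_target) > tolerance:
            continue  # outside the FLOPs budget; resample
        score = eval_pruning_strategy(model, strategy, dataloader_train)
        results.append((strategy, score))  # one entry per line of --output_file
    return results
```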
126 |
127 | ### 2. Candidate Selection
128 |
129 | In this step, the best pruning strategy is picked from the `output_file` generated in step 1.
130 |
131 | The output looks like the following:
132 | ```
133 | ########## pruning_strategies.txt ##########
134 | strategy index:84, score:0.143
135 | strategy index:985, score:0.123
136 | ```
137 |
138 | For a sample script, refer to `2. Selection` in `scripts/mbv1_50flops.sh`.
139 |
140 | ### 3. Fine-tuning of Pruned Model
141 |
142 | This step takes a strategy index as input and fine-tunes the corresponding pruned model.
143 |
144 | Parameters involved in this step:
145 |
146 | |Name|Description|
147 | |----|-----------|
148 | |`--search_result`|File containing the search results|
149 | |`--strategy_id`|Index of the best pruning strategy from step 2|
150 | |`--lr`|Learning rate for fine-tuning|
151 | |`--weight_decay`|Weight decay while fine-tuning|
152 | |`--epoch`|Number of fine-tuning epochs|
153 |
154 | For a sample script, refer to `3. Fine-tuning` in `scripts/mbv1_50flops.sh`.
155 |
156 |
157 |
158 | ## Inference of Pruned Model
159 |
160 | **For ResNet50:**
161 |
162 | ```shell
163 | python3 inference.py \
164 | --model_name resnet50 \
165 | --num_classes 1000 \
166 | --checkpoint models/ckpt/{resnet50_25flops.pth|resnet50_50flops.pth|resnet50_72flops.pth} \
167 | --gpu_ids 4 \
168 | --batch_size 512 \
169 | --dataset_path {PATH_TO_IMAGENET} \
170 | --dataset_name imagenet \
171 | --num_workers 20
172 | ```
173 |
174 | **For MobileNetV1:**
175 |
176 | ```shell
177 | python3 inference.py \
178 | --model_name mobilenetv1 \
179 | --num_classes 1000 \
180 | --checkpoint models/ckpt/mobilenetv1_50flops.pth \
181 | --gpu_ids 4 \
182 | --batch_size 512 \
183 | --dataset_path {PATH_TO_IMAGENET} \
184 | --dataset_name imagenet \
185 | --num_workers 20
186 | ```
187 |
188 | After running the above program, the output looks like this:
189 |
190 | ```
191 | ######### Report #########
192 | Model:resnet50
193 | Checkpoint:models/ckpt/resnet50_50flops_7637.pth
194 | FLOPs of Original Model:4.089G;Params of Original Model:25.50M
195 | FLOPs of Pruned Model:2.057G;Params of Pruned Model:14.37M
196 | Top-1 Acc of Pruned Model on imagenet:0.76366
197 | ##########################
198 | ```
199 |
200 |
201 | ## Results
202 |
203 | ### Quantitative analysis of correlation
204 |
205 | Correlation between evaluation and fine-tuning accuracy under different pruning ratios (MobileNetV1 on ImageNet classification, Top-1 results)
206 |
207 | ![corr](fig/cor_fix_flops.png)
208 |
209 | ### Results on ImageNet
210 |
211 | | Model | FLOPs | Top-1 Acc | Top-5 Acc | Checkpoint |
212 | | --- | ---- | ------- | -------- | ---------------- |
213 | | ResNet-50 | 3G<br>2G<br>1G | 77.1%<br>76.4%<br>74.2% | 93.37%<br>92.89%<br>91.77% | [resnet50_75flops.pth](https://drive.google.com/file/d/1oPQOZJdKwZPXPSLykLkruHAFhxFmdzHp/view?usp=sharing)<br>[resnet50_50flops.pth](https://drive.google.com/file/d/19eOUO0LTzrQ-9izO4OzXcg83XAGwxf7u/view?usp=sharing)<br>
[resnet50_25flops.pth](https://drive.google.com/file/d/1ppBLtajt5xcwa5xoonwn1T98MTB0V9DU/view?usp=sharing) | 214 | | MobileNetV1 | 284M | 70.9% | 89.62% | [mobilenetv1_50flops.pth](https://drive.google.com/file/d/1LZGqk_oPXNYcGa5Gk93fmHxRgPdfzf9p/view?usp=sharing) | 215 | 216 | ### Results on CIFAR-10 217 | 218 | | Model | FLOPs | Top-1 Acc | 219 | | --- | ---- | ----- | 220 | | ResNet-50 | 62.23M | 94.66% | 221 | | MobileNetV1 | 26.5M
12.1M<br>3.3M | 91.89%<br>91.44%<br>
88.01% | 222 | 223 | -------------------------------------------------------------------------------- /choose_strategy.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 14 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import sys 8 | 9 | data_path = sys.argv[1] 10 | 11 | score = [] 12 | 13 | print("#" * 10, data_path, "#" * 10) 14 | with open(data_path) as data: 15 | line = data.readlines() 16 | for l in line: 17 | d = l.split(" ") 18 | score.append(float(d[0])) 19 | 20 | score_sorted_index = sorted(range(len(score)), key=lambda k: score[k], reverse=True) 21 | 22 | for i in range(5): 23 | print( 24 | "strategy index:{}, score:{}".format( 25 | score_sorted_index[i], score[score_sorted_index[i]] 26 | ) 27 | ) 28 | print("\n") 29 | -------------------------------------------------------------------------------- /compress_config/mbv1_imagenet.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | pruners: 3 | 4 | pruner_base_0: 5 | class: 'L1RankedStructureParameterPruner' 6 | group_type: Filters 7 | desired_sparsity: 0 8 | weights: [ 9 | conv1.weight 10 | ] 11 | 12 | pruner_base_1: 13 | class: 'L1RankedStructureParameterPruner' 14 | group_type: Filters 15 | desired_sparsity: 0 16 | weights: [ 17 | layers.0.conv2.weight 18 | ] 19 | pruner_base_2: 20 | class: 'L1RankedStructureParameterPruner' 21 | group_type: Filters 22 | desired_sparsity: 0 23 | weights: [ 24 | layers.1.conv2.weight 25 | ] 26 | pruner_base_3: 27 | class: 'L1RankedStructureParameterPruner' 28 | group_type: Filters 29 | desired_sparsity: 0.1 30 | weights: [ 31 | layers.2.conv2.weight 32 | ] 33 | pruner_base_4: 34 | class: 'L1RankedStructureParameterPruner' 35 | group_type: Filters 36 | desired_sparsity: 0.3 37 | weights: [ 38 | layers.3.conv2.weight 39 | ] 40 | pruner_base_5: 41 | class: 'L1RankedStructureParameterPruner' 42 | group_type: Filters 43 | desired_sparsity: 0.3 44 | weights: [ 45 | layers.4.conv2.weight 46 | ] 47 | pruner_base_6: 48 | class: 'L1RankedStructureParameterPruner' 49 | group_type: Filters 50 | desired_sparsity: 0.2 51 | weights: [ 52 | layers.5.conv2.weight 53 | ] 54 | pruner_base_7: 55 | class: 'L1RankedStructureParameterPruner' 56 | group_type: Filters 57 | desired_sparsity: 0.2 58 | weights: [ 59 | layers.6.conv2.weight 60 | ] 61 | pruner_base_8: 62 | class: 'L1RankedStructureParameterPruner' 63 | group_type: Filters 64 | desired_sparsity: 0.3 65 | weights: [ 66 | layers.7.conv2.weight 67 | ] 68 | pruner_base_9: 69 | class: 'L1RankedStructureParameterPruner' 70 | group_type: Filters 71 | desired_sparsity: 0.2 72 | weights: [ 73 | layers.8.conv2.weight 74 | ] 75 | pruner_base_10: 76 | class: 'L1RankedStructureParameterPruner' 77 | group_type: Filters 78 | desired_sparsity: 0.1 79 | weights: [ 80 | layers.9.conv2.weight 81 | ] 82 | pruner_base_11: 83 | class: 'L1RankedStructureParameterPruner' 84 | group_type: Filters 85 | desired_sparsity: 0.1 86 | weights: [ 87 | layers.10.conv2.weight 88 | ] 89 | pruner_base_12: 90 | class: 'L1RankedStructureParameterPruner' 91 | group_type: Filters 92 | desired_sparsity: 0.3 93 | weights: [ 94 | layers.11.conv2.weight 95 | ] 96 | pruner_base_13: 97 | class: 'L1RankedStructureParameterPruner' 98 | group_type: Filters 99 | desired_sparsity: 0.2 100 | 
weights: [ 101 | layers.12.conv2.weight 102 | ] 103 | policies: 104 | - pruner: 105 | instance_name : pruner_base_0 106 | epochs: [1] 107 | 108 | - pruner: 109 | instance_name : pruner_base_1 110 | epochs: [1] 111 | 112 | - pruner: 113 | instance_name : pruner_base_2 114 | epochs: [1] 115 | 116 | - pruner: 117 | instance_name : pruner_base_3 118 | epochs: [1] 119 | 120 | - pruner: 121 | instance_name : pruner_base_4 122 | epochs: [1] 123 | 124 | - pruner: 125 | instance_name : pruner_base_5 126 | epochs: [1] 127 | 128 | - pruner: 129 | instance_name : pruner_base_6 130 | epochs: [1] 131 | 132 | - pruner: 133 | instance_name : pruner_base_7 134 | epochs: [1] 135 | 136 | - pruner: 137 | instance_name : pruner_base_8 138 | epochs: [1] 139 | 140 | - pruner: 141 | instance_name : pruner_base_9 142 | epochs: [1] 143 | 144 | - pruner: 145 | instance_name : pruner_base_10 146 | epochs: [1] 147 | 148 | - pruner: 149 | instance_name : pruner_base_11 150 | epochs: [1] 151 | 152 | - pruner: 153 | instance_name : pruner_base_12 154 | epochs: [1] 155 | 156 | - pruner: 157 | instance_name : pruner_base_13 158 | epochs: [1] 159 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import os 8 | import torch 9 | import importlib 10 | 11 | 12 | def custom_get_dataloaders(opt): 13 | dataset_filename = "data." + opt.dataset_name 14 | datasetlib = importlib.import_module(dataset_filename) 15 | # find method named `get_dataloaders` 16 | for name, method in datasetlib.__dict__.items(): 17 | if name.lower() == "get_dataloaders": 18 | get_data_func = method 19 | return get_data_func(opt.batch_size, opt.num_workers, path=opt.dataset_path) 20 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import torchvision.datasets as datasets 6 | 7 | 8 | def get_dataloaders(batch_size, n_workers, path=""): 9 | normalize = transforms.Normalize( 10 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 11 | ) 12 | train_dataset = datasets.ImageFolder( 13 | osp.join(path, "train"), 14 | transforms.Compose( 15 | [ 16 | transforms.RandomResizedCrop(224), 17 | transforms.RandomHorizontalFlip(), 18 | transforms.ToTensor(), 19 | normalize, 20 | ] 21 | ), 22 | ) 23 | 24 | test_dataset = datasets.ImageFolder( 25 | osp.join(path, "val"), 26 | transforms.Compose( 27 | [ 28 | transforms.Resize(256), 29 | transforms.CenterCrop(224), 30 | transforms.ToTensor(), 31 | normalize, 32 | ] 33 | ), 34 | ) 35 | 36 | dataloader_train = torch.utils.data.DataLoader( 37 | train_dataset, 38 | batch_size=batch_size, 39 | shuffle=True, 40 | num_workers=n_workers, 41 | pin_memory=True, 42 | ) 43 | dataloader_test = torch.utils.data.DataLoader( 44 | test_dataset, 45 | batch_size=batch_size, 46 | shuffle=False, 47 | num_workers=n_workers, 48 | pin_memory=True, 49 | ) 50 | return dataloader_train, dataloader_test 51 | -------------------------------------------------------------------------------- 
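The two files above make dataset selection purely name-based: `custom_get_dataloaders` imports `data.<opt.dataset_name>` and calls that module's `get_dataloaders`. A minimal usage sketch (the option values here are illustrative, not defaults taken from `options/base_options.py`):

```python
from types import SimpleNamespace
from data import custom_get_dataloaders

# opt only needs the four attributes read by custom_get_dataloaders
opt = SimpleNamespace(
    dataset_name="imagenet",        # resolves to data/imagenet.py
    batch_size=256,
    num_workers=8,
    dataset_path="/data/imagenet",  # directory containing train/ and val/
)
dataloader_train, dataloader_test = custom_get_dataloaders(opt)
```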
/data/imagenet_train_val_split.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import os.path as osp 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as datasets 12 | import numpy as np 13 | 14 | np.random.seed(2019) 15 | import random 16 | 17 | random.seed(2019) 18 | from torch.utils.data.sampler import SubsetRandomSampler 19 | import pickle 20 | 21 | 22 | def get_dataloaders(batch_size, n_workers, path=""): 23 | print("USE PART OF TRAIN SET WITH UNIFORM SPLIT") 24 | normalize = transforms.Normalize( 25 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 26 | ) 27 | train_dataset = datasets.ImageFolder( 28 | osp.join(path, "train"), 29 | transforms.Compose( 30 | [ 31 | transforms.RandomResizedCrop(224), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | normalize, 35 | ] 36 | ), 37 | ) 38 | 39 | val_dataset = datasets.ImageFolder( 40 | osp.join(path, "val"), 41 | transforms.Compose( 42 | [ 43 | transforms.Resize(256), 44 | transforms.CenterCrop(224), 45 | transforms.ToTensor(), 46 | normalize, 47 | ] 48 | ), 49 | ) 50 | val_idx_filename = "data/imagenet_train_val_split_idx.pickle" 51 | print("len(train_dataset)", len(train_dataset)) 52 | if not osp.exists(val_idx_filename): 53 | val_size = 10000 54 | val_idx = [] 55 | cls_start, cls_end = 0, 0 56 | for c_id in range(1000): 57 | for i in range(cls_start, len(train_dataset)): 58 | if train_dataset[i][1] == c_id: 59 | cls_end = i + 1 60 | else: 61 | break 62 | c_list = list(range(cls_start, cls_end)) 63 | print("cid:{}, c_start:{}, c_end:{}".format(c_id, cls_start, cls_end)) 64 | print(int(val_size / 1000)) 65 | c_sample = random.sample(c_list, int(val_size / 1000)) 66 | val_idx += c_sample 67 | cls_start = cls_end 68 | print("len of val_size:{}".format(len(val_idx))) 69 | pickle.dump(val_idx, open(val_idx_filename, "wb")) 70 | else: 71 | val_idx = pickle.load(open(val_idx_filename, "rb")) 72 | val_sampler = SubsetRandomSampler(val_idx) 73 | dataloader_train = torch.utils.data.DataLoader( 74 | train_dataset, 75 | batch_size=batch_size, 76 | shuffle=True, 77 | num_workers=n_workers, 78 | pin_memory=True, 79 | ) 80 | dataloader_test = torch.utils.data.DataLoader( 81 | train_dataset, 82 | batch_size=batch_size, 83 | shuffle=False, 84 | num_workers=n_workers, 85 | pin_memory=True, 86 | sampler=val_sampler, 87 | ) 88 | return dataloader_train, dataloader_test 89 | -------------------------------------------------------------------------------- /data/imagenet_train_val_split_idx.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/data/imagenet_train_val_split_idx.pickle -------------------------------------------------------------------------------- /distiller/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #      http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .utils import *
18 | from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
19 | from .config import file_config, dict_config, config_component_from_file_by_class
20 | from .model_summaries import *
21 | from .scheduler import *
22 | from .sensitivity import *
23 | from .directives import *
24 | from .policy import *
25 | from .thinning import *
26 | from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
27 | from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name
28 |
29 |
30 | del dict_config
31 | del thinning
32 |
33 | # Distiller version
34 | __version__ = "0.4.0-pre"
35 |
36 |
37 | def model_find_param_name(model, param_to_find):
38 |     """Look up the name of a model parameter.
39 |
40 |     Arguments:
41 |         model: the model to search
42 |         param_to_find: the parameter whose name we want to look up
43 |
44 |     Returns:
45 |         The parameter name (string) or None, if the parameter was not found.
46 |     """
47 |     for name, param in model.named_parameters():
48 |         if param is param_to_find:
49 |             return name
50 |     return None
51 |
52 |
53 | def model_find_module_name(model, module_to_find):
54 |     """Look up the name of a module in a model.
55 |
56 |     Arguments:
57 |         model: the model to search
58 |         module_to_find: the module whose name we want to look up
59 |
60 |     Returns:
61 |         The module name (string) or None, if the module was not found.
62 |     """
63 |     for name, m in model.named_modules():
64 |         if m == module_to_find:
65 |             return name
66 |     return None
67 |
68 |
69 | def model_find_param(model, param_to_find_name):
70 |     """Look up a model parameter by its name.
71 |
72 |     Arguments:
73 |         model: the model to search
74 |         param_to_find_name: the name of the parameter that we are searching for
75 |
76 |     Returns:
77 |         The parameter or None, if the parameter name was not found.
78 |     """
79 |     for name, param in model.named_parameters():
80 |         if name == param_to_find_name:
81 |             return param
82 |     return None
83 |
84 |
85 | def model_find_module(model, module_to_find):
86 |     """Given a module name, find the module in the provided model.
87 |
88 |     Arguments:
89 |         model: the model to search
90 |         module_to_find: the module whose name we want to look up
91 |
92 |     Returns:
93 |         The module or None, if the module was not found.
94 |     """
95 |     for name, m in model.named_modules():
96 |         if name == module_to_find:
97 |             return m
98 |     return None
99 |
--------------------------------------------------------------------------------
/distiller/apputils/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains Python code and classes that are meant to make your life easier, 18 | when working with distiller. 19 | 20 | """ 21 | from .data_loaders import * 22 | from .checkpoint import * 23 | from .execution_env import * 24 | from .dataset_summaries import * 25 | 26 | del data_loaders 27 | del checkpoint 28 | del execution_env 29 | del dataset_summaries 30 | -------------------------------------------------------------------------------- /distiller/apputils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ Helper code for checkpointing models, with support for saving the pruning schedule. 18 | 19 | Adding the schedule information in the model checkpoint is helpful in resuming 20 | a pruning session, or for querying the pruning schedule of a sparse model. 21 | """ 22 | 23 | import os 24 | import shutil 25 | from errno import ENOENT 26 | import logging 27 | from numbers import Number 28 | from tabulate import tabulate 29 | import torch 30 | import distiller 31 | from distiller.utils import normalize_module_name 32 | 33 | msglogger = logging.getLogger() 34 | 35 | 36 | def save_checkpoint( 37 | epoch, 38 | arch, 39 | model, 40 | optimizer=None, 41 | scheduler=None, 42 | extras=None, 43 | is_best=False, 44 | name=None, 45 | dir=".", 46 | ): 47 | """Save a pytorch training checkpoint 48 | 49 | Args: 50 | epoch: current epoch number 51 | arch: name of the network architecture/topology 52 | model: a pytorch model 53 | optimizer: the optimizer used in the training session 54 | scheduler: the CompressionScheduler instance used for training, if any 55 | extras: optional dict with additional user-defined data to be saved in the checkpoint. 
56 | Will be saved under the key 'extras' 57 | is_best: If true, will save a copy of the checkpoint with the suffix 'best' 58 | name: the name of the checkpoint file 59 | dir: directory in which to save the checkpoint 60 | """ 61 | if not os.path.isdir(dir): 62 | raise IOError( 63 | ENOENT, "Checkpoint directory does not exist at", os.path.abspath(dir) 64 | ) 65 | 66 | if extras is None: 67 | extras = {} 68 | if not isinstance(extras, dict): 69 | raise TypeError("extras must be either a dict or None") 70 | 71 | filename = "checkpoint.pth.tar" if name is None else name + "_checkpoint.pth.tar" 72 | fullpath = os.path.join(dir, filename) 73 | model_fullpath = fullpath.replace(".pth.tar", ".pth").replace( 74 | "checkpoint", "prunned_model" 75 | ) 76 | msglogger.info("Saving checkpoint to: %s" % fullpath) 77 | filename_best = "best.pth.tar" if name is None else name + "_best.pth.tar" 78 | fullpath_best = os.path.join(dir, filename_best) 79 | 80 | checkpoint = {} 81 | checkpoint["epoch"] = epoch 82 | checkpoint["arch"] = arch 83 | checkpoint["state_dict"] = model.state_dict() 84 | if optimizer is not None: 85 | checkpoint["optimizer_state_dict"] = optimizer.state_dict() 86 | checkpoint["optimizer_type"] = type(optimizer) 87 | if scheduler is not None: 88 | checkpoint["compression_sched"] = scheduler.state_dict() 89 | if hasattr(model, "thinning_recipes"): 90 | checkpoint["thinning_recipes"] = model.thinning_recipes 91 | if hasattr(model, "quantizer_metadata"): 92 | checkpoint["quantizer_metadata"] = model.quantizer_metadata 93 | 94 | checkpoint["extras"] = extras 95 | 96 | torch.save(checkpoint, fullpath) 97 | torch.save(model, model_fullpath) 98 | if is_best: 99 | shutil.copyfile(fullpath, fullpath_best) 100 | 101 | 102 | def load_lean_checkpoint(model, chkpt_file, model_device=None): 103 | return load_checkpoint( 104 | model, chkpt_file, model_device=model_device, lean_checkpoint=True 105 | )[0] 106 | 107 | 108 | def get_contents_table(d): 109 | def inspect_val(val): 110 | if isinstance(val, (Number, str)): 111 | return val 112 | elif isinstance(val, type): 113 | return val.__name__ 114 | return None 115 | 116 | contents = [[k, type(d[k]).__name__, inspect_val(d[k])] for k in d.keys()] 117 | contents = sorted(contents, key=lambda entry: entry[0]) 118 | return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="fancy_grid") 119 | 120 | 121 | def load_checkpoint( 122 | model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False 123 | ): 124 | """Load a pytorch training checkpoint. 125 | 126 | Args: 127 | model: the pytorch model to which we will load the parameters 128 | chkpt_file: the checkpoint file 129 | lean_checkpoint: if set, read into model only 'state_dict' field 130 | optimizer: [deprecated argument] 131 | model_device [str]: if set, call model.to($model_device) 132 | This should be set to either 'cpu' or 'cuda'. 
133 | :returns: updated model, compression_scheduler, optimizer, start_epoch 134 | """ 135 | if not os.path.isfile(chkpt_file): 136 | raise IOError(ENOENT, "Could not find a checkpoint file at", chkpt_file) 137 | 138 | msglogger.info("=> loading checkpoint %s", chkpt_file) 139 | checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage) 140 | 141 | msglogger.info( 142 | "=> Checkpoint contents:\n{}\n".format(get_contents_table(checkpoint)) 143 | ) 144 | if "extras" in checkpoint: 145 | msglogger.info( 146 | "=> Checkpoint['extras'] contents:\n{}\n".format( 147 | get_contents_table(checkpoint["extras"]) 148 | ) 149 | ) 150 | 151 | if "state_dict" not in checkpoint: 152 | raise ValueError( 153 | "Checkpoint must contain the model parameters under the key 'state_dict'" 154 | ) 155 | 156 | checkpoint_epoch = checkpoint.get("epoch", None) 157 | start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0 158 | 159 | compression_scheduler = None 160 | normalize_dataparallel_keys = False 161 | if "compression_sched" in checkpoint: 162 | compression_scheduler = distiller.CompressionScheduler(model) 163 | try: 164 | compression_scheduler.load_state_dict( 165 | checkpoint["compression_sched"], normalize_dataparallel_keys 166 | ) 167 | except KeyError as e: 168 | # A very common source of this KeyError is loading a GPU model on the CPU. 169 | # We rename all of the DataParallel keys because DataParallel does not execute on the CPU. 170 | normalize_dataparallel_keys = True 171 | compression_scheduler.load_state_dict( 172 | checkpoint["compression_sched"], normalize_dataparallel_keys 173 | ) 174 | msglogger.info( 175 | "Loaded compression schedule from checkpoint (epoch {})".format( 176 | checkpoint_epoch 177 | ) 178 | ) 179 | else: 180 | msglogger.info( 181 | "Warning: compression schedule data does not exist in the checkpoint" 182 | ) 183 | 184 | if "thinning_recipes" in checkpoint: 185 | if "compression_sched" not in checkpoint: 186 | raise KeyError( 187 | "Found thinning_recipes key, but missing mandatory key compression_sched" 188 | ) 189 | msglogger.info("Loaded a thinning recipe from the checkpoint") 190 | # Cache the recipes in case we need them later 191 | model.thinning_recipes = checkpoint["thinning_recipes"] 192 | if normalize_dataparallel_keys: 193 | model.thinning_recipes = [ 194 | distiller.get_normalized_recipe(recipe) 195 | for recipe in model.thinning_recipes 196 | ] 197 | distiller.execute_thinning_recipes_list( 198 | model, compression_scheduler.zeros_mask_dict, model.thinning_recipes 199 | ) 200 | 201 | if "quantizer_metadata" in checkpoint: 202 | msglogger.info("Loaded quantizer metadata from the checkpoint") 203 | qmd = checkpoint["quantizer_metadata"] 204 | quantizer = qmd["type"](model, **qmd["params"]) 205 | quantizer.prepare_model() 206 | 207 | if normalize_dataparallel_keys: 208 | checkpoint["state_dict"] = { 209 | normalize_module_name(k): v for k, v in checkpoint["state_dict"].items() 210 | } 211 | model.load_state_dict(checkpoint["state_dict"]) 212 | if model_device is not None: 213 | model.to(model_device) 214 | 215 | if lean_checkpoint: 216 | msglogger.info( 217 | "=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)) 218 | ) 219 | return (model, None, None, 0) 220 | 221 | def _load_optimizer(cls, src_state_dict, model): 222 | """Initiate optimizer with model parameters and load src_state_dict""" 223 | # initiate the dest_optimizer with a dummy learning rate, 224 | # this is required to support SGD.__init__() 225 | 
dest_optimizer = cls(model.parameters(), lr=1)
226 |         dest_optimizer.load_state_dict(src_state_dict)
227 |         return dest_optimizer
228 |
229 |     try:
230 |         optimizer = _load_optimizer(
231 |             checkpoint["optimizer_type"], checkpoint["optimizer_state_dict"], model
232 |         )
233 |     except KeyError:
234 |         # Older checkpoints do not support optimizer loading: they either had an 'optimizer'
235 |         # field (a different name) that was not used during the load, or they did not
236 |         # checkpoint the optimizer at all.
237 |         optimizer = None
238 |
239 |     if optimizer is not None:
240 |         msglogger.info(
241 |             "Optimizer of type {type} was loaded from checkpoint".format(
242 |                 type=type(optimizer)
243 |             )
244 |         )
245 |         msglogger.info(
246 |             "Optimizer Args: {}".format(
247 |                 dict(
248 |                     (k, v)
249 |                     for k, v in optimizer.state_dict()["param_groups"][0].items()
250 |                     if k != "params"
251 |                 )
252 |             )
253 |         )
254 |     else:
255 |         msglogger.warning("Optimizer could not be loaded from checkpoint.")
256 |
257 |     msglogger.info(
258 |         "=> loaded checkpoint '{f}' (epoch {e})".format(
259 |             f=str(chkpt_file), e=checkpoint_epoch
260 |         )
261 |     )
262 |     return (model, compression_scheduler, optimizer, start_epoch)
263 |
--------------------------------------------------------------------------------
/distiller/apputils/data_loaders.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #      http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Helper code for data loading.
18 |
19 | This code will help with the image classification datasets: ImageNet and CIFAR10
20 |
21 | """
22 | import os
23 | import torch
24 | import torchvision.transforms as transforms
25 | import torchvision.datasets as datasets
26 | from torch.utils.data.sampler import Sampler
27 | import numpy as np
28 |
29 | import distiller
30 |
31 | DATASETS_NAMES = ["imagenet", "cifar10"]
32 |
33 |
34 | def load_data(
35 |     dataset,
36 |     data_dir,
37 |     batch_size,
38 |     workers,
39 |     validation_split=0.1,
40 |     deterministic=False,
41 |     effective_train_size=1.0,
42 |     effective_valid_size=1.0,
43 |     effective_test_size=1.0,
44 |     fixed_subset=False,
45 | ):
46 |     """Load a dataset.
47 |
48 |     Args:
49 |         dataset: a string with the name of the dataset to load (cifar10/imagenet)
50 |         data_dir: the directory where the dataset resides
51 |         batch_size: the batch size
52 |         workers: the number of worker threads to use for loading the data
53 |         validation_split: portion of training dataset to set aside for validation
54 |         deterministic: set to True if you want the data loading process to be deterministic.
55 |             Note that deterministic data loading suffers from poor performance.
56 |         effective_train/valid/test_size: portion of the datasets to load on each epoch.
57 |             The subset is chosen randomly each time. For the training and validation sets, this is applied AFTER
58 |             the split to those sets according to the validation_split parameter
59 |         fixed_subset: set to True to keep the same subset of data throughout the run (the size of the subset
60 |             is still determined according to the effective_train/valid/test_size args)
61 |     """
62 |     if dataset not in DATASETS_NAMES:
63 |         raise ValueError("load_data does not support dataset %s" % dataset)
64 |     datasets_fn = (
65 |         cifar10_get_datasets if dataset == "cifar10" else imagenet_get_datasets
66 |     )
67 |     return get_data_loaders(
68 |         datasets_fn,
69 |         data_dir,
70 |         batch_size,
71 |         workers,
72 |         validation_split=validation_split,
73 |         deterministic=deterministic,
74 |         effective_train_size=effective_train_size,
75 |         effective_valid_size=effective_valid_size,
76 |         effective_test_size=effective_test_size,
77 |         fixed_subset=fixed_subset,
78 |     )
79 |
80 |
81 | def cifar10_get_datasets(data_dir):
82 |     """Load the CIFAR10 dataset.
83 |
84 |     The original training dataset is split into training and validation sets (code is
85 |     inspired by https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb).
86 |     By default we use a 90:10 (45K:5K) training:validation split.
87 |
88 |     The output of torchvision datasets are PIL Image images of range [0, 1].
89 |     We transform them to Tensors of normalized range [-1, 1]
90 |     https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py
91 |
92 |     Data augmentation: 4 pixels are padded on each side, and a 32x32 crop is randomly sampled
93 |     from the padded image or its horizontal flip.
94 |     This is similar to [1] and some other work that use CIFAR10.
95 |
96 |     [1] C.-Y. Lee, S. Xie, P. Gallagher, Z. Zhang, and Z. Tu. Deeply Supervised Nets.
97 |     arXiv:1409.5185, 2014
98 |     """
99 |     train_transform = transforms.Compose(
100 |         [
101 |             transforms.RandomCrop(32, padding=4),
102 |             transforms.RandomHorizontalFlip(),
103 |             transforms.ToTensor(),
104 |             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
105 |         ]
106 |     )
107 |
108 |     train_dataset = datasets.CIFAR10(
109 |         root=data_dir, train=True, download=True, transform=train_transform
110 |     )
111 |
112 |     test_transform = transforms.Compose(
113 |         [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
114 |     )
115 |
116 |     test_dataset = datasets.CIFAR10(
117 |         root=data_dir, train=False, download=True, transform=test_transform
118 |     )
119 |
120 |     return train_dataset, test_dataset
121 |
122 |
123 | def imagenet_get_datasets(data_dir):
124 |     """
125 |     Load the ImageNet dataset.
126 | """ 127 | train_dir = os.path.join(data_dir, "train") 128 | test_dir = os.path.join(data_dir, "val") 129 | normalize = transforms.Normalize( 130 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 131 | ) 132 | 133 | train_transform = transforms.Compose( 134 | [ 135 | transforms.RandomResizedCrop(224), 136 | transforms.RandomHorizontalFlip(), 137 | transforms.ToTensor(), 138 | normalize, 139 | ] 140 | ) 141 | 142 | train_dataset = datasets.ImageFolder(train_dir, train_transform) 143 | 144 | test_transform = transforms.Compose( 145 | [ 146 | transforms.Resize(256), 147 | transforms.CenterCrop(224), 148 | transforms.ToTensor(), 149 | normalize, 150 | ] 151 | ) 152 | 153 | test_dataset = datasets.ImageFolder(test_dir, test_transform) 154 | 155 | return train_dataset, test_dataset 156 | 157 | 158 | def __image_size(dataset): 159 | # un-squeeze is used here to add the batch dimension (value=1), which is missing 160 | return dataset[0][0].unsqueeze(0).size() 161 | 162 | 163 | def __deterministic_worker_init_fn(worker_id, seed=0): 164 | import random 165 | import numpy 166 | 167 | random.seed(seed) 168 | numpy.random.seed(seed) 169 | torch.manual_seed(seed) 170 | 171 | 172 | def __split_list(l, ratio): 173 | split_idx = int(np.floor(ratio * len(l))) 174 | return l[:split_idx], l[split_idx:] 175 | 176 | 177 | class SwitchingSubsetRandomSampler(Sampler): 178 | """Samples a random subset of elements from a data source, without replacement. 179 | 180 | The subset of elements is re-chosen randomly each time the sampler is enumerated 181 | 182 | Args: 183 | data_source (Dataset): dataset to sample from 184 | subset_size (float): value in (0..1], representing the portion of dataset to sample at each enumeration. 185 | """ 186 | 187 | def __init__(self, data_source, effective_size): 188 | self.data_source = data_source 189 | self.subset_length = _get_subset_length(data_source, effective_size) 190 | 191 | def __iter__(self): 192 | # Randomizing in the same way as in torch.utils.data.sampler.SubsetRandomSampler to maintain 193 | # reproducibility with the previous data loaders implementation 194 | indices = torch.randperm(len(self.data_source)) 195 | subset_indices = indices[: self.subset_length] 196 | return (self.data_source[i] for i in subset_indices) 197 | 198 | def __len__(self): 199 | return self.subset_length 200 | 201 | 202 | def _get_subset_length(data_source, effective_size): 203 | if effective_size <= 0 or effective_size > 1: 204 | raise ValueError("effective_size must be in (0..1]") 205 | return int(np.floor(len(data_source) * effective_size)) 206 | 207 | 208 | def _get_sampler(data_source, effective_size, fixed_subset=False): 209 | if fixed_subset: 210 | subset_length = _get_subset_length(data_source, effective_size) 211 | indices = np.random.permutation(len(data_source)) 212 | subset_indices = indices[:subset_length] 213 | return torch.utils.data.SubsetRandomSampler(subset_indices) 214 | return SwitchingSubsetRandomSampler(data_source, effective_size) 215 | 216 | 217 | def get_data_loaders( 218 | datasets_fn, 219 | data_dir, 220 | batch_size, 221 | num_workers, 222 | validation_split=0.1, 223 | deterministic=False, 224 | effective_train_size=1.0, 225 | effective_valid_size=1.0, 226 | effective_test_size=1.0, 227 | fixed_subset=False, 228 | ): 229 | train_dataset, test_dataset = datasets_fn(data_dir) 230 | 231 | worker_init_fn = None 232 | if deterministic: 233 | distiller.set_deterministic() 234 | worker_init_fn = __deterministic_worker_init_fn 235 | 236 | num_train = 
len(train_dataset) 237 | indices = list(range(num_train)) 238 | 239 | # TODO: Switch to torch.utils.data.datasets.random_split() 240 | 241 | # We shuffle indices here in case the data is arranged by class, in which case we'd would get mutually 242 | # exclusive datasets if we didn't shuffle 243 | np.random.shuffle(indices) 244 | 245 | valid_indices, train_indices = __split_list(indices, validation_split) 246 | 247 | train_sampler = _get_sampler(train_indices, effective_train_size, fixed_subset) 248 | train_loader = torch.utils.data.DataLoader( 249 | train_dataset, 250 | batch_size=batch_size, 251 | sampler=train_sampler, 252 | num_workers=num_workers, 253 | pin_memory=True, 254 | worker_init_fn=worker_init_fn, 255 | ) 256 | 257 | valid_loader = None 258 | if valid_indices: 259 | valid_sampler = _get_sampler(valid_indices, effective_valid_size, fixed_subset) 260 | valid_loader = torch.utils.data.DataLoader( 261 | train_dataset, 262 | batch_size=batch_size, 263 | sampler=valid_sampler, 264 | num_workers=num_workers, 265 | pin_memory=True, 266 | worker_init_fn=worker_init_fn, 267 | ) 268 | 269 | test_indices = list(range(len(test_dataset))) 270 | test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset) 271 | test_loader = torch.utils.data.DataLoader( 272 | test_dataset, 273 | batch_size=batch_size, 274 | sampler=test_sampler, 275 | num_workers=num_workers, 276 | pin_memory=True, 277 | ) 278 | 279 | input_shape = __image_size(train_dataset) 280 | 281 | # If validation split was 0 we use the test set as the validation set 282 | return train_loader, valid_loader or test_loader, test_loader, input_shape 283 | -------------------------------------------------------------------------------- /distiller/apputils/dataset_summaries.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | 18 | def dataset_summary(data_loader): 19 | """Create a histogram of class membership distribution within a dataset. 20 | 21 | It is important to examine our training, validation, and test 22 | datasets, to make sure that they are balanced. 
23 | """ 24 | print("Analyzing dataset:") 25 | hist = {} 26 | for idx, (input, label_batch) in enumerate(data_loader): 27 | for label in label_batch: 28 | hist[label] = hist.get(label, 0) + 1 29 | if idx % 50 == 0: 30 | print("idx: %d" % idx) 31 | 32 | nclasses = len(hist) 33 | from statistics import mean 34 | 35 | print("Dataset contains {} items".format(len(data_loader.sampler))) 36 | print("Found {} classes".format(nclasses)) 37 | for data_class, size in hist.items(): 38 | print("\tClass {} = {}".format(data_class, size)) 39 | 40 | print("mean: ", mean(list(hist.values()))) 41 | -------------------------------------------------------------------------------- /distiller/apputils/execution_env.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Log information regarding the execution environment. 18 | 19 | This is helpful if you want to recreate an experiment at a later time, or if 20 | you want to understand the environment in which you execute the training. 21 | """ 22 | 23 | import logging 24 | import logging.config 25 | import operator 26 | import os 27 | import platform 28 | import shutil 29 | import sys 30 | import time 31 | import pkg_resources 32 | 33 | from git import Repo, InvalidGitRepositoryError 34 | import numpy as np 35 | import torch 36 | 37 | try: 38 | import lsb_release 39 | 40 | HAVE_LSB = True 41 | except ImportError: 42 | HAVE_LSB = False 43 | 44 | logger = logging.getLogger("app_cfg") 45 | 46 | 47 | def log_execution_env_state(config_paths=None, logdir=None, gitroot="."): 48 | """Log information about the execution environment. 49 | 50 | Files in 'config_paths' will be copied to directory 'logdir'. A common use-case 51 | is passing the path to a (compression) schedule YAML file. Storing a copy 52 | of the schedule file, with the experiment logs, is useful in order to 53 | reproduce experiments. 54 | 55 | Args: 56 | config_paths: path(s) to config file(s), used only when logdir is set 57 | logdir: log directory 58 | git_root: the path to the .git root directory 59 | """ 60 | 61 | def log_git_state(): 62 | """Log the state of the git repository. 63 | 64 | It is useful to know what git tag we're using, and if we have outstanding code. 65 | """ 66 | try: 67 | repo = Repo(gitroot) 68 | assert not repo.bare 69 | except InvalidGitRepositoryError: 70 | logger.debug( 71 | "Cannot find a Git repository. You probably downloaded an archive of Distiller." 
72 | ) 73 | return 74 | 75 | if repo.is_dirty(): 76 | logger.debug("Git is dirty") 77 | try: 78 | branch_name = repo.active_branch.name 79 | except TypeError: 80 | branch_name = "None, Git is in 'detached HEAD' state" 81 | logger.debug("Active Git branch: %s", branch_name) 82 | logger.debug("Git commit: %s" % repo.head.commit.hexsha) 83 | 84 | logger.debug("Number of CPUs: %d", len(os.sched_getaffinity(0))) 85 | logger.debug("Number of GPUs: %d", torch.cuda.device_count()) 86 | logger.debug("CUDA version: %s", torch.version.cuda) 87 | logger.debug("CUDNN version: %s", torch.backends.cudnn.version()) 88 | logger.debug("Kernel: %s", platform.release()) 89 | if HAVE_LSB: 90 | logger.debug("OS: %s", lsb_release.get_lsb_information()["DESCRIPTION"]) 91 | logger.debug("Python: %s", sys.version) 92 | 93 | def _pip_freeze(): 94 | return { 95 | x.key: x.version 96 | for x in sorted(pkg_resources.working_set, key=operator.attrgetter("key")) 97 | } 98 | 99 | logger.debug("pip freeze: {}".format(_pip_freeze())) 100 | log_git_state() 101 | logger.debug("Command line: %s", " ".join(sys.argv)) 102 | 103 | if (logdir is None) or (config_paths is None): 104 | return 105 | 106 | # clone configuration files to output directory 107 | configs_dest = os.path.join(logdir, "configs") 108 | 109 | if isinstance(config_paths, str) or not hasattr(config_paths, "__iter__"): 110 | config_paths = [config_paths] 111 | for cpath in config_paths: 112 | os.makedirs(configs_dest, exist_ok=True) 113 | 114 | if os.path.exists(os.path.join(configs_dest, os.path.basename(cpath))): 115 | logger.debug( 116 | "{} already exists in logdir".format(os.path.basename(cpath) or cpath) 117 | ) 118 | else: 119 | try: 120 | shutil.copy(cpath, configs_dest) 121 | except OSError as e: 122 | logger.debug("Failed to copy of config file: {}".format(str(e))) 123 | 124 | 125 | def config_pylogger(log_cfg_file, experiment_name, output_dir="logs"): 126 | """Configure the Python logger. 127 | 128 | For each execution of the application, we'd like to create a unique log directory. 129 | By default this directory is named using the date and time of day, so that directories 130 | can be sorted by recency. You can also name your experiments and prefix the log 131 | directory with this name. This can be useful when accessing experiment data from 132 | TensorBoard, for example. 
133 | """ 134 | timestr = time.strftime("%Y.%m.%d-%H%M%S") 135 | exp_full_name = ( 136 | timestr if experiment_name is None else experiment_name + "___" + timestr 137 | ) 138 | logdir = os.path.join(output_dir, exp_full_name) 139 | if not os.path.exists(logdir): 140 | os.makedirs(logdir) 141 | log_filename = os.path.join(logdir, exp_full_name + ".log") 142 | if os.path.isfile(log_cfg_file): 143 | logging.config.fileConfig(log_cfg_file, defaults={"logfilename": log_filename}) 144 | msglogger = logging.getLogger() 145 | msglogger.logdir = logdir 146 | msglogger.log_filename = log_filename 147 | msglogger.info("Log file for this run: " + os.path.realpath(log_filename)) 148 | 149 | # Create a symbollic link to the last log file created (for easier access) 150 | try: 151 | os.unlink("latest_log_file") 152 | except FileNotFoundError: 153 | pass 154 | try: 155 | os.unlink("latest_log_dir") 156 | except FileNotFoundError: 157 | pass 158 | try: 159 | os.symlink(logdir, "latest_log_dir") 160 | os.symlink(log_filename, "latest_log_file") 161 | except OSError: 162 | msglogger.debug("Failed to create symlinks to latest logs") 163 | return msglogger 164 | -------------------------------------------------------------------------------- /distiller/data_loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .collector import * 18 | from .logger import PythonLogger, TensorBoardLogger, CsvLogger 19 | 20 | del logger 21 | del collector 22 | -------------------------------------------------------------------------------- /distiller/data_loggers/tbbackend.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | """ A TensorBoard backend. 17 | 18 | Writes logs to a file using a Google's TensorBoard protobuf format. 
19 | See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto 20 | """ 21 | import os 22 | import tensorflow as tf 23 | import numpy as np 24 | 25 | 26 | class TBBackend(object): 27 | def __init__(self, log_dir): 28 | self.writers = [] 29 | self.log_dir = log_dir 30 | self.writers.append(tf.summary.FileWriter(log_dir)) 31 | 32 | def scalar_summary(self, tag, scalar, step): 33 | """From TF documentation: 34 | tag: name for the data. Used by TensorBoard plugins to organize data. 35 | value: value associated with the tag (a float). 36 | """ 37 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)]) 38 | self.writers[0].add_summary(summary, step) 39 | 40 | def list_summary(self, tag, list, step, multi_graphs): 41 | """Log a relatively small list of scalars. 42 | 43 | We want to track the progress of multiple scalar parameters in a single graph. 44 | The list provides a single value for each of the parameters we are tracking. 45 | 46 | NOTE: There are two ways to log multiple values in TB and neither one is optimal. 47 | 1. Use a single writer: in this case all of the parameters use the same color, and 48 | distinguishing between them is difficult. 49 | 2. Use multiple writers: in this case each parameter has its own color which helps 50 | to visually separate the parameters. However, each writer logs to a different 51 | file and this creates a lot of files which slow down the TB load. 52 | """ 53 | for i, scalar in enumerate(list): 54 | if multi_graphs and (i + 1 > len(self.writers)): 55 | self.writers.append( 56 | tf.summary.FileWriter(os.path.join(self.log_dir, str(i))) 57 | ) 58 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)]) 59 | self.writers[0 if not multi_graphs else i].add_summary(summary, step) 60 | 61 | def histogram_summary(self, tag, tensor, step): 62 | """ 63 | From the TF documentation: 64 | tf.summary.histogram takes an arbitrarily sized and shaped Tensor, and 65 | compresses it into a histogram data structure consisting of many bins with 66 | widths and counts. 67 | 68 | TensorFlow uses non-uniformly distributed bins, which is better than using 69 | numpy's uniform bins for activations and parameters which converge around zero, 70 | but we don't add that logic here. 71 | 72 | https://www.tensorflow.org/programmers_guide/tensorboard_histograms 73 | """ 74 | hist, edges = np.histogram(tensor, bins=200) 75 | tfhist = tf.HistogramProto( 76 | min=np.min(tensor), 77 | max=np.max(tensor), 78 | num=int(np.prod(tensor.shape)), 79 | sum=np.sum(tensor), 80 | sum_squares=np.sum(np.square(tensor)), 81 | ) 82 | 83 | # From the TF documentation: 84 | # Parallel arrays encoding the bucket boundaries and the bucket values. 85 | # bucket(i) is the count for the bucket i. The range for a bucket is: 86 | # i == 0: -DBL_MAX .. bucket_limit(0) 87 | # i != 0: bucket_limit(i-1) .. 
bucket_limit(i)
88 | tfhist.bucket_limit.extend(edges[1:])
89 | tfhist.bucket.extend(hist)
90 |
91 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=tfhist)])
92 | self.writers[0].add_summary(summary, step)
93 |
94 | def sync_to_file(self):
95 | for writer in self.writers:
96 | writer.flush()
97 |
--------------------------------------------------------------------------------
/distiller/directives.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Scheduling directives
18 |
19 | Scheduling directives are instructions (directives) that the scheduler can
20 | execute as part of scheduling pruning activities.
21 | """
22 | from __future__ import division
23 | import torch
24 | import numpy as np
25 | from collections import defaultdict
26 | import logging
27 |
28 | msglogger = logging.getLogger()
29 |
30 | from torchnet.meter import AverageValueMeter
31 | from distiller.utils import sparsity, density, model_find_param_name  # model_find_param_name is needed by freeze_training() below
32 |
33 |
34 | class FreezeTraining(object):
35 | def __init__(self, name):
36 | print("------FreezeTraining--------")
37 | self.name = name
38 |
39 |
40 | def freeze_training(model, which_params, freeze):
41 | """This function will freeze/defrost training for certain layers.
42 |
43 | Sometimes, when we prune and retrain a certain layer type,
44 | we'd like to freeze the training of the other layers.
45 | """
46 | for param in model.parameters():
47 | pname = model_find_param_name(model, param.data)
48 | if pname is None:
49 | continue
50 | for ptype in which_params:
51 | if ptype in pname:
52 | # see: http://pytorch.org/docs/master/notes/autograd.html?highlight=grad_fn
53 | param.requires_grad = not freeze
54 | if freeze:
55 | msglogger.info("Freezing: " + pname)
56 | else:
57 | msglogger.info("Defrosting: " + pname)
58 |
59 |
60 | def freeze_all(model, freeze):
61 | msglogger.info("{} all parameters".format("Freezing" if freeze else "Defrosting"))
62 | for param in model.parameters():
63 | param.requires_grad = not freeze
64 |
65 |
66 | def adjust_dropout(module, new_probabilty):
67 | """Replace the dropout probability of dropout layers
68 |
69 | As explained in the paper "Learning both Weights and Connections for
70 | Efficient Neural Networks":
71 | Dropout is widely used to prevent over-fitting, and this also applies to retraining.
72 | During retraining, however, the dropout ratio must be adjusted to account for the
73 | change in model capacity. In dropout, each parameter is probabilistically dropped
74 | during training, but will come back during inference. In pruning, parameters are
75 | dropped forever after pruning and have no chance to come back during both training
76 | and inference. As the parameters get sparse, the classifier will select the most
77 | informative predictors and thus have much less prediction variance, which reduces
78 | over-fitting.
As pruning already reduced model capacity, the retraining dropout ratio 79 | should be smaller. 80 | """ 81 | if type(module) in [ 82 | torch.nn.Dropout, 83 | torch.nn.Dropout2d, 84 | torch.nn.Dropout3d, 85 | torch.nn.AlphaDropout, 86 | ]: 87 | msglogger.info("Adjusting dropout probability") # for {}".format(str(module))) 88 | module.p = new_probabilty 89 | else: 90 | for child in module.children(): 91 | adjust_dropout(child, new_probabilty) 92 | -------------------------------------------------------------------------------- /distiller/knowledge_distillation.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from collections import namedtuple 20 | 21 | from .policy import ScheduledTrainingPolicy, PolicyLoss, LossComponent 22 | 23 | DistillationLossWeights = namedtuple( 24 | "DistillationLossWeights", ["distill", "student", "teacher"] 25 | ) 26 | 27 | 28 | def add_distillation_args(argparser, arch_choices=None, enable_pretrained=False): 29 | """ 30 | Helper function to make it easier to add command line arguments for knowledge distillation to any script 31 | 32 | Arguments: 33 | argparser (argparse.ArgumentParser): Existing parser to which to add the arguments 34 | arch_choices: Optional list of choices to be enforced by the parser for model selection 35 | enable_pretrained (bool): Flag to enable/disable argument for "pre-trained" models. 36 | """ 37 | group = argparser.add_argument_group("Knowledge Distillation Training Arguments") 38 | group.add_argument( 39 | "--kd-teacher", 40 | choices=arch_choices, 41 | metavar="ARCH", 42 | help="Model architecture for teacher model", 43 | ) 44 | if enable_pretrained: 45 | group.add_argument( 46 | "--kd-pretrained", 47 | action="store_true", 48 | help="Use pre-trained model for teacher", 49 | ) 50 | group.add_argument( 51 | "--kd-resume", 52 | type=str, 53 | default="", 54 | metavar="PATH", 55 | help="Path to checkpoint from which to load teacher weights", 56 | ) 57 | group.add_argument( 58 | "--kd-temperature", 59 | "--kd-temp", 60 | dest="kd_temp", 61 | type=float, 62 | default=1.0, 63 | metavar="TEMP", 64 | help="Knowledge distillation softmax temperature", 65 | ) 66 | group.add_argument( 67 | "--kd-distill-wt", 68 | "--kd-dw", 69 | type=float, 70 | default=0.5, 71 | metavar="WEIGHT", 72 | help="Weight for distillation loss (student vs. teacher soft targets)", 73 | ) 74 | group.add_argument( 75 | "--kd-student-wt", 76 | "--kd-sw", 77 | type=float, 78 | default=0.5, 79 | metavar="WEIGHT", 80 | help="Weight for student vs. labels loss", 81 | ) 82 | group.add_argument( 83 | "--kd-teacher-wt", 84 | "--kd-tw", 85 | type=float, 86 | default=0.0, 87 | metavar="WEIGHT", 88 | help="Weight for teacher vs. 
labels loss", 89 | ) 90 | group.add_argument( 91 | "--kd-start-epoch", 92 | type=int, 93 | default=0, 94 | metavar="EPOCH_NUM", 95 | help="Epoch from which to enable distillation", 96 | ) 97 | 98 | 99 | class KnowledgeDistillationPolicy(ScheduledTrainingPolicy): 100 | """ 101 | Policy which enables knowledge distillation from a teacher model to a student model, as presented in [1]. 102 | 103 | Notes: 104 | 1. In addition to the standard policy callbacks, this class also provides a 'forward' function that must 105 | be called instead of calling the student model directly as is usually done. This is needed to facilitate 106 | running the teacher model in addition to the student, and for caching the logits for loss calculation. 107 | 2. [TO BE ENABLED IN THE NEAR FUTURE] Option to train the teacher model in parallel with the student model, 108 | described as "scheme A" in [2]. This can be achieved by passing teacher loss weight > 0. 109 | 3. [1] proposes a weighted average between the different losses. We allow arbitrary weights to be assigned 110 | to each loss. 111 | 112 | Arguments: 113 | student_model (nn.Module): The student model, that is - the main model being trained. If only initialized with 114 | random weights, this matches "scheme B" in [2]. If it has been bootstrapped with trained FP32 weights, 115 | this matches "scheme C". 116 | teacher_model (nn.Module): The teacher model from which soft targets are generated for knowledge distillation. 117 | Usually this is a pre-trained model, however in the future it will be possible to train this model as well 118 | (see Note 1 above) 119 | temperature (float): Temperature value used when calculating soft targets and logits (see [1]). 120 | loss_weights (DistillationLossWeights): Named tuple with 3 loss weights 121 | (a) 'distill' for student predictions (default: 0.5) vs. teacher soft-targets 122 | (b) 'student' for student predictions vs. true labels (default: 0.5) 123 | (c) 'teacher' for teacher predictions vs. true labels (default: 0). Currently this is just a placeholder, 124 | and cannot be set to a non-zero value. 125 | 126 | [1] Hinton et al., Distilling the Knowledge in a Neural Network (https://arxiv.org/abs/1503.02531) 127 | [2] Mishra and Marr, Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy 128 | (https://arxiv.org/abs/1711.05852) 129 | 130 | """ 131 | 132 | def __init__( 133 | self, 134 | student_model, 135 | teacher_model, 136 | temperature=1.0, 137 | loss_weights=DistillationLossWeights(0.5, 0.5, 0), 138 | ): 139 | super(KnowledgeDistillationPolicy, self).__init__() 140 | 141 | if loss_weights.teacher != 0: 142 | raise NotImplementedError( 143 | "Using teacher vs. labels loss is not supported yet, " 144 | "for now teacher loss weight must be set to 0" 145 | ) 146 | 147 | self.active = False 148 | 149 | self.student = student_model 150 | self.teacher = teacher_model 151 | self.temperature = temperature 152 | self.loss_wts = loss_weights 153 | 154 | self.last_students_logits = None 155 | self.last_teacher_logits = None 156 | 157 | def forward(self, *inputs): 158 | """ 159 | Performs forward propagation through both student and teached models and caches the logits. 160 | This function MUST be used instead of calling the student model directly. 
161 |
162 | Returns:
163 | The student model's returned output, to be consistent with what a script using this would expect
164 | """
165 | if not self.active:
166 | return self.student(*inputs)
167 |
168 | if self.loss_wts.teacher == 0:
169 | with torch.no_grad():
170 | self.last_teacher_logits = self.teacher(*inputs)
171 | else:
172 | self.last_teacher_logits = self.teacher(*inputs)
173 |
174 | out = self.student(*inputs)
175 | self.last_students_logits = out.new_tensor(out, requires_grad=True)
176 |
177 | return out
178 |
179 | # Since the "forward" function isn't a policy callback, we use the epoch callbacks to toggle the
180 | # activation of distillation according to the schedule defined by the user
181 | def on_epoch_begin(self, model, zeros_mask_dict, meta):
182 | self.active = True
183 |
184 | def on_epoch_end(self, model, zeros_mask_dict, meta):
185 | self.active = False
186 |
187 | def before_backward_pass(
188 | self,
189 | model,
190 | epoch,
191 | minibatch_id,
192 | minibatches_per_epoch,
193 | loss,
194 | zeros_mask_dict,
195 | optimizer=None,
196 | ):
197 | # TODO: Consider adding 'labels' as an argument to this callback, so we can support teacher vs. labels loss
198 | # (Otherwise we can't do it with a sub-class of ScheduledTrainingPolicy)
199 |
200 | if not self.active:
201 | return None
202 |
203 | if self.last_teacher_logits is None or self.last_students_logits is None:
204 | raise RuntimeError(
205 | "KnowledgeDistillationPolicy: Student and/or teacher logits were not cached. "
206 | "Make sure to call KnowledgeDistillationPolicy.forward() in your script instead of "
207 | "calling the model directly."
208 | )
209 |
210 | # Calculate distillation loss
211 | soft_log_probs = F.log_softmax(
212 | self.last_students_logits / self.temperature, dim=1
213 | )
214 | # soft_targets = F.softmax(self.cached_teacher_logits[minibatch_id] / self.temperature)
215 | soft_targets = F.softmax(self.last_teacher_logits / self.temperature, dim=1)
216 |
217 | # The averaging used in PyTorch KL Div implementation is wrong, so we work around as suggested in
218 | # https://pytorch.org/docs/stable/nn.html#kldivloss
219 | # (Also see https://github.com/pytorch/pytorch/issues/6622, https://github.com/pytorch/pytorch/issues/2259)
220 | distillation_loss = (
221 | F.kl_div(soft_log_probs, soft_targets.detach(), size_average=False)
222 | / soft_targets.shape[0]
223 | )
224 |
225 | # The loss passed to the callback is the student's loss vs. the true labels, so we can use it directly, no
226 | # need to calculate again
227 |
228 | overall_loss = (
229 | self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss
230 | )
231 | return PolicyLoss(
232 | overall_loss, [LossComponent("Distill Loss", distillation_loss)]
233 | )
234 |
--------------------------------------------------------------------------------
/distiller/learning_rate.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from bisect import bisect_right
18 | from torch.optim.lr_scheduler import _LRScheduler
19 |
20 |
21 | class PolynomialLR(_LRScheduler):
22 | """Set the learning rate for each parameter group using a polynomial defined as:
23 | lr = base_lr * (1 - T_cur/T_max) ^ (power), where T_cur is the current epoch and T_max is the maximum number of
24 | epochs.
25 |
26 | Args:
27 | optimizer (Optimizer): Wrapped optimizer.
28 | T_max (int): Maximum number of epochs
29 | power (int): Degree of polynomial
30 | last_epoch (int): The index of last epoch. Default: -1.
31 | """
32 |
33 | def __init__(self, optimizer, T_max, power, last_epoch=-1):
34 | self.T_max = T_max
35 | self.power = power
36 | super(PolynomialLR, self).__init__(optimizer, last_epoch)
37 |
38 | def get_lr(self):
39 | # base_lr * (1 - iter/max_iter) ^ (power)
40 | return [
41 | base_lr * (1 - self.last_epoch / self.T_max) ** self.power
42 | for base_lr in self.base_lrs
43 | ]
44 |
45 |
46 | class MultiStepMultiGammaLR(_LRScheduler):
47 | """Similar to torch.optim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone.
48 |
49 | Args:
50 | optimizer (Optimizer): Wrapped optimizer.
51 | milestones (list): List of epoch indices. Must be increasing.
52 | gammas (list): List of gamma values. Must have same length as milestones.
53 | last_epoch (int): The index of last epoch. Default: -1.
54 | """
55 |
56 | def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
57 | if not list(milestones) == sorted(milestones):
58 | raise ValueError(
59 | "Milestones should be a list of increasing "
60 | "integers. Got {}".format(milestones)
61 | )
62 | if len(milestones) != len(gammas):
63 | raise ValueError("Milestones and Gammas lists should be of same length.")
64 |
65 | self.milestones = milestones
66 | self.multiplicative_gammas = [1]
67 | for idx, gamma in enumerate(gammas):
68 | self.multiplicative_gammas.append(gamma * self.multiplicative_gammas[idx])
69 |
70 | super(MultiStepMultiGammaLR, self).__init__(optimizer, last_epoch)
71 |
72 | def get_lr(self):
73 | idx = bisect_right(self.milestones, self.last_epoch)
74 | return [base_lr * self.multiplicative_gammas[idx] for base_lr in self.base_lrs]
75 |
--------------------------------------------------------------------------------
/distiller/models/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """This package contains ImageNet and CIFAR image classification models for pytorch"""
18 |
19 | import torch
20 | import torchvision.models as torch_models
21 | from . import cifar10 as cifar10_models
22 | from . import imagenet as imagenet_extra_models
23 | import pretrainedmodels
24 |
25 | import logging
26 |
27 | msglogger = logging.getLogger()
28 |
29 | # ResNet special treatment: we have our own version of ResNet, so we need to over-ride
30 | # TorchVision's version.
31 | RESNET_SYMS = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
32 |
33 | IMAGENET_MODEL_NAMES = sorted(
34 | name
35 | for name in torch_models.__dict__
36 | if name.islower()
37 | and not name.startswith("__")
38 | and callable(torch_models.__dict__[name])
39 | )
40 | IMAGENET_MODEL_NAMES.extend(
41 | sorted(
42 | name
43 | for name in imagenet_extra_models.__dict__
44 | if name.islower()
45 | and not name.startswith("__")
46 | and callable(imagenet_extra_models.__dict__[name])
47 | )
48 | )
49 | IMAGENET_MODEL_NAMES.extend(pretrainedmodels.model_names)
50 |
51 | CIFAR10_MODEL_NAMES = sorted(
52 | name
53 | for name in cifar10_models.__dict__
54 | if name.islower()
55 | and not name.startswith("__")
56 | and callable(cifar10_models.__dict__[name])
57 | )
58 |
59 | ALL_MODEL_NAMES = sorted(
60 | map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))
61 | )
62 |
63 |
64 | def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
65 | """Create a pytorch model based on the model architecture and dataset
66 |
67 | Args:
68 | pretrained [boolean]: True if you wish to load a pretrained model.
69 | Some models do not have a pretrained version.
70 | dataset: dataset name (only 'imagenet' and 'cifar10' are supported)
71 | arch: architecture name
72 | parallel [boolean]: if set, use torch.nn.DataParallel
73 | device_ids: Devices on which model should be created -
74 | None - GPU if available, otherwise CPU
75 | -1 - CPU
76 | >=0 - GPU device IDs
77 | """
78 | model = None
79 | dataset = dataset.lower()
80 | if dataset == "imagenet":
81 | if arch in RESNET_SYMS:
82 | model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
83 | elif arch in torch_models.__dict__:
84 | model = torch_models.__dict__[arch](pretrained=pretrained)
85 | elif (arch in imagenet_extra_models.__dict__) and not pretrained:
86 | model = imagenet_extra_models.__dict__[arch]()
87 | elif arch in pretrainedmodels.model_names:
88 | model = pretrainedmodels.__dict__[arch](
89 | num_classes=1000, pretrained=(dataset if pretrained else None)
90 | )
91 | else:
92 | error_message = ""
93 | if arch not in IMAGENET_MODEL_NAMES:
94 | error_message = "Model {} is not supported for dataset ImageNet".format(
95 | arch
96 | )
97 | elif pretrained:
98 | error_message = "Model {} (ImageNet) does not have a pretrained model".format(
99 | arch
100 | )
101 | raise ValueError(error_message or "Failed to find model {}".format(arch))
102 |
103 | msglogger.info(
104 | "=> using {p}{a} model for ImageNet".format(
105 | a=arch, p=("pretrained " if pretrained else "")
106 | )
107 | )
108 | elif dataset == "cifar10":
109 | if pretrained:
110 | raise ValueError(
111 | "Model {} (CIFAR10) does not have a pretrained model".format(arch)
112 | )
113 | try:
114 | model = cifar10_models.__dict__[arch]()
115 | except KeyError:
116 | raise ValueError(
117 | "Model {} is not supported for dataset CIFAR10".format(arch)
118 | )
119 | msglogger.info("=> creating %s model for CIFAR10" % arch)
120 | else:
121 | raise ValueError("Could not recognize dataset {}".format(dataset))
122 |
123 | if torch.cuda.is_available() and device_ids != -1:
124 | device = "cuda"
125 | if (arch.startswith("alexnet") or arch.startswith("vgg")) and parallel:
126 | model.features = torch.nn.DataParallel( 127 | model.features, device_ids=device_ids 128 | ) 129 | elif parallel: 130 | model = torch.nn.DataParallel(model, device_ids=device_ids) 131 | else: 132 | device = "cpu" 133 | 134 | return model.to(device) 135 | -------------------------------------------------------------------------------- /distiller/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains ImageNet image classification models not found in torchvision""" 18 | 19 | from .mobilenet import * 20 | from .preresnet_imagenet import * 21 | from .alexnet_batchnorm import * 22 | from .resnet_earlyexit import * 23 | from .resnet import * 24 | -------------------------------------------------------------------------------- /distiller/models/imagenet/alexnet_batchnorm.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ 18 | AlexNet model with batch-norm layers. 19 | Model configuration based on the AlexNet DoReFa example in TensorPack: 20 | https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py 21 | 22 | Code based on the AlexNet PyTorch sample, with the required changes. 
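A quick construction sketch (illustrative):

    model = alexnet_bn(num_classes=1000)

The network expects 224x224 inputs, matching the per-layer shape comments below.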
23 | """ 24 | 25 | import math 26 | import torch.nn as nn 27 | 28 | __all__ = ["AlexNetBN", "alexnet_bn"] 29 | 30 | 31 | class AlexNetBN(nn.Module): 32 | def __init__(self, num_classes=1000): 33 | super(AlexNetBN, self).__init__() 34 | self.features = nn.Sequential( 35 | nn.Conv2d( 36 | 3, 96, kernel_size=12, stride=4 37 | ), # conv0 (224x224x3) -> (54x54x96) 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d( 40 | 96, 256, kernel_size=5, padding=2, groups=2, bias=False 41 | ), # conv1 (54x54x96) -> (54x54x256) 42 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn1 (54x54x256) 43 | nn.MaxPool2d( 44 | kernel_size=3, stride=2, ceil_mode=True 45 | ), # pool1 (54x54x256) -> (27x27x256) 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d( 48 | 256, 384, kernel_size=3, padding=1, bias=False 49 | ), # conv2 (27x27x256) -> (27x27x384) 50 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn2 (27x27x384) 51 | nn.MaxPool2d( 52 | kernel_size=3, stride=2, padding=1 53 | ), # pool2 (27x27x384) -> (14x14x384) 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d( 56 | 384, 384, kernel_size=3, padding=1, groups=2, bias=False 57 | ), # conv3 (14x14x384) -> (14x14x384) 58 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn3 (14x14x384) 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d( 61 | 384, 256, kernel_size=3, padding=1, groups=2, bias=False 62 | ), # conv4 (14x14x384) -> (14x14x256) 63 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn4 (14x14x256) 64 | nn.MaxPool2d(kernel_size=3, stride=2), # pool4 (14x14x256) -> (6x6x256) 65 | nn.ReLU(inplace=True), 66 | ) 67 | self.classifier = nn.Sequential( 68 | nn.Linear(256 * 6 * 6, 4096, bias=False), # fc0 69 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc0 70 | nn.ReLU(inplace=True), 71 | nn.Linear(4096, 4096, bias=False), # fc1 72 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc1 73 | nn.ReLU(inplace=True), 74 | nn.Linear(4096, num_classes), # fct 75 | ) 76 | 77 | for m in self.modules(): 78 | if isinstance(m, (nn.Conv2d, nn.Linear)): 79 | fan_in, k_size = ( 80 | (m.in_channels, m.kernel_size[0] * m.kernel_size[1]) 81 | if isinstance(m, nn.Conv2d) 82 | else (m.in_features, 1) 83 | ) 84 | n = k_size * fan_in 85 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 86 | if hasattr(m, "bias") and m.bias is not None: 87 | m.bias.data.fill_(0) 88 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 89 | m.weight.data.fill_(1) 90 | m.bias.data.zero_() 91 | 92 | def forward(self, x): 93 | x = self.features(x) 94 | x = x.view(x.size(0), 256 * 6 * 6) 95 | x = self.classifier(x) 96 | return x 97 | 98 | 99 | def alexnet_bn(**kwargs): 100 | r"""AlexNet model with batch-norm layers. 101 | Model configuration based on the AlexNet DoReFa example in `TensorPack 102 | ` 103 | """ 104 | model = AlexNetBN(**kwargs) 105 | return model 106 | -------------------------------------------------------------------------------- /distiller/models/imagenet/mobilenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from math import floor
18 | import torch.nn as nn
19 |
20 | __all__ = ["mobilenet", "mobilenet_025", "mobilenet_050", "mobilenet_075"]
21 |
22 |
23 | class MobileNet(nn.Module):
24 | def __init__(self, channel_multiplier=1.0, min_channels=8):
25 | super(MobileNet, self).__init__()
26 |
27 | if channel_multiplier <= 0:
28 | raise ValueError("channel_multiplier must be > 0")
29 |
30 | def conv_bn_relu(n_ifm, n_ofm, kernel_size, stride=1, padding=0, groups=1):
31 | return [
32 | nn.Conv2d(
33 | n_ifm,
34 | n_ofm,
35 | kernel_size,
36 | stride=stride,
37 | padding=padding,
38 | groups=groups,
39 | bias=False,
40 | ),
41 | nn.BatchNorm2d(n_ofm),
42 | nn.ReLU(inplace=True),
43 | ]
44 |
45 | def depthwise_conv(n_ifm, n_ofm, stride):
46 | return nn.Sequential(
47 | *conv_bn_relu(n_ifm, n_ifm, 3, stride=stride, padding=1, groups=n_ifm),
48 | *conv_bn_relu(n_ifm, n_ofm, 1, stride=1)
49 | )
50 |
51 | base_channels = [32, 64, 128, 256, 512, 1024]
52 | self.channels = [
53 | max(floor(n * channel_multiplier), min_channels) for n in base_channels
54 | ]
55 |
56 | self.model = nn.Sequential(
57 | nn.Sequential(*conv_bn_relu(3, self.channels[0], 3, stride=2, padding=1)),
58 | depthwise_conv(self.channels[0], self.channels[1], 1),
59 | depthwise_conv(self.channels[1], self.channels[2], 2),
60 | depthwise_conv(self.channels[2], self.channels[2], 1),
61 | depthwise_conv(self.channels[2], self.channels[3], 2),
62 | depthwise_conv(self.channels[3], self.channels[3], 1),
63 | depthwise_conv(self.channels[3], self.channels[4], 2),
64 | depthwise_conv(self.channels[4], self.channels[4], 1),
65 | depthwise_conv(self.channels[4], self.channels[4], 1),
66 | depthwise_conv(self.channels[4], self.channels[4], 1),
67 | depthwise_conv(self.channels[4], self.channels[4], 1),
68 | depthwise_conv(self.channels[4], self.channels[4], 1),
69 | depthwise_conv(self.channels[4], self.channels[5], 2),
70 | depthwise_conv(self.channels[5], self.channels[5], 1),
71 | nn.AvgPool2d(7),
72 | )
73 | self.fc = nn.Linear(self.channels[5], 1000)
74 |
75 | def forward(self, x):
76 | x = self.model(x)
77 | x = x.view(-1, x.size(1))
78 | x = self.fc(x)
79 | return x
80 |
81 |
82 | def mobilenet_025():
83 | return MobileNet(channel_multiplier=0.25)
84 |
85 |
86 | def mobilenet_050():
87 | return MobileNet(channel_multiplier=0.5)
88 |
89 |
90 | def mobilenet_075():
91 | return MobileNet(channel_multiplier=0.75)
92 |
93 |
94 | def mobilenet():
95 | return MobileNet()
96 |
--------------------------------------------------------------------------------
/distiller/models/imagenet/preresnet_imagenet.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # 16 | 17 | """Pre-Activation ResNet for ImageNet 18 | 19 | Pre-Activation ResNet for ImageNet, based on "Identity Mappings in Deep Residual Networks". 20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate changes for pre-activation. 21 | 22 | @article{ 23 | He2016, 24 | author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, 25 | title = {Identity Mappings in Deep Residual Networks}, 26 | journal = {arXiv preprint arXiv:1603.05027}, 27 | year = {2016} 28 | } 29 | """ 30 | 31 | import torch.nn as nn 32 | import math 33 | 34 | 35 | __all__ = [ 36 | "PreactResNet", 37 | "preact_resnet18", 38 | "preact_resnet34", 39 | "preact_resnet50", 40 | "preact_resnet101", 41 | "preact_resnet152", 42 | ] 43 | 44 | 45 | def conv3x3(in_planes, out_planes, stride=1): 46 | """3x3 convolution with padding""" 47 | return nn.Conv2d( 48 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 49 | ) 50 | 51 | 52 | class PreactBasicBlock(nn.Module): 53 | expansion = 1 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None, preactivate=True): 56 | super(PreactBasicBlock, self).__init__() 57 | self.pre_bn = self.pre_relu = None 58 | if preactivate: 59 | self.pre_bn = nn.BatchNorm2d(inplanes) 60 | self.pre_relu = nn.ReLU(inplace=True) 61 | self.conv1 = conv3x3(inplanes, planes, stride) 62 | self.bn1_2 = nn.BatchNorm2d(planes) 63 | self.relu1_2 = nn.ReLU(inplace=True) 64 | self.conv2 = conv3x3(planes, planes) 65 | self.downsample = downsample 66 | self.stride = stride 67 | self.preactivate = preactivate 68 | 69 | def forward(self, x): 70 | if self.preactivate: 71 | preact = self.pre_bn(x) 72 | preact = self.pre_relu(preact) 73 | else: 74 | preact = x 75 | 76 | out = self.conv1(preact) 77 | out = self.bn1_2(out) 78 | out = self.relu1_2(out) 79 | out = self.conv2(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(preact) 83 | else: 84 | residual = x 85 | 86 | out += residual 87 | 88 | return out 89 | 90 | 91 | class PreactBottleneck(nn.Module): 92 | expansion = 4 93 | 94 | def __init__(self, inplanes, planes, stride=1, downsample=None, preactivate=True): 95 | super(PreactBottleneck, self).__init__() 96 | self.pre_bn = self.pre_relu = None 97 | if preactivate: 98 | self.pre_bn = nn.BatchNorm2d(inplanes) 99 | self.pre_relu = nn.ReLU(inplace=True) 100 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 101 | self.bn1_2 = nn.BatchNorm2d(planes) 102 | self.relu1_2 = nn.ReLU(inplace=True) 103 | self.conv2 = nn.Conv2d( 104 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 105 | ) 106 | self.bn2_3 = nn.BatchNorm2d(planes) 107 | self.relu2_3 = nn.ReLU(inplace=True) 108 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 109 | self.downsample = downsample 110 | self.stride = stride 111 | self.preactivate = preactivate 112 | 113 | def forward(self, x): 114 | if self.preactivate: 115 | preact = self.pre_bn(x) 116 | preact = self.pre_relu(preact) 117 | else: 118 | preact = x 119 | 120 | out = self.conv1(preact) 121 | out = self.bn1_2(out) 122 | out = self.relu1_2(out) 123 | 124 | out = self.conv2(out) 125 | out = self.bn2_3(out) 126 | out = self.relu2_3(out) 127 | 128 | out = self.conv3(out) 129 | 130 | if self.downsample is not None: 131 | residual = self.downsample(preact) 132 | else: 133 | residual = x 134 | 135 | out += residual 136 | 137 | return out 138 | 139 | 140 | class PreactResNet(nn.Module): 141 | def __init__(self, block, layers, 
num_classes=1000): 142 | self.inplanes = 64 143 | super(PreactResNet, self).__init__() 144 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 145 | self.bn1 = nn.BatchNorm2d(64) 146 | self.relu1 = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 152 | self.final_bn = nn.BatchNorm2d(512 * block.expansion) 153 | self.final_relu = nn.ReLU(inplace=True) 154 | self.avgpool = nn.AvgPool2d(7, stride=1) 155 | self.fc = nn.Linear(512 * block.expansion, num_classes) 156 | 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 160 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 161 | elif isinstance(m, nn.BatchNorm2d): 162 | m.weight.data.fill_(1) 163 | m.bias.data.zero_() 164 | 165 | def _make_layer(self, block, planes, blocks, stride=1): 166 | downsample = None 167 | if stride != 1 or self.inplanes != planes * block.expansion: 168 | downsample = nn.Sequential( 169 | nn.Conv2d( 170 | self.inplanes, 171 | planes * block.expansion, 172 | kernel_size=1, 173 | stride=stride, 174 | bias=False, 175 | ), 176 | ) 177 | 178 | # On the first residual block in the first residual layer we don't pre-activate, 179 | # because we take care of that (+ maxpool) after the initial conv layer 180 | preactivate_first = stride != 1 181 | 182 | layers = [] 183 | layers.append( 184 | block(self.inplanes, planes, stride, downsample, preactivate_first) 185 | ) 186 | self.inplanes = planes * block.expansion 187 | for i in range(1, blocks): 188 | layers.append(block(self.inplanes, planes)) 189 | 190 | return nn.Sequential(*layers) 191 | 192 | def forward(self, x): 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu1(x) 196 | x = self.maxpool(x) 197 | 198 | x = self.layer1(x) 199 | x = self.layer2(x) 200 | x = self.layer3(x) 201 | x = self.layer4(x) 202 | 203 | x = self.final_bn(x) 204 | x = self.final_relu(x) 205 | x = self.avgpool(x) 206 | x = x.view(x.size(0), -1) 207 | x = self.fc(x) 208 | 209 | return x 210 | 211 | 212 | def preact_resnet18(**kwargs): 213 | """Constructs a ResNet-18 model. 214 | """ 215 | model = PreactResNet(PreactBasicBlock, [2, 2, 2, 2], **kwargs) 216 | return model 217 | 218 | 219 | def preact_resnet34(**kwargs): 220 | """Constructs a ResNet-34 model. 221 | """ 222 | model = PreactResNet(PreactBasicBlock, [3, 4, 6, 3], **kwargs) 223 | return model 224 | 225 | 226 | def preact_resnet50(**kwargs): 227 | """Constructs a ResNet-50 model. 228 | """ 229 | model = PreactResNet(PreactBottleneck, [3, 4, 6, 3], **kwargs) 230 | return model 231 | 232 | 233 | def preact_resnet101(**kwargs): 234 | """Constructs a ResNet-101 model. 235 | """ 236 | model = PreactResNet(PreactBottleneck, [3, 4, 23, 3], **kwargs) 237 | return model 238 | 239 | 240 | def preact_resnet152(**kwargs): 241 | """Constructs a ResNet-152 model. 
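(Illustrative: model = preact_resnet152(num_classes=1000); num_classes defaults to 1000.)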
242 | """ 243 | model = PreactResNet(PreactBottleneck, [3, 8, 36, 3], **kwargs) 244 | return model 245 | -------------------------------------------------------------------------------- /distiller/models/imagenet/resnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # This is the same code as in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 18 | # However, it contains one type of change: whenever a ReLU module is used, we make sure to use a different 19 | # instance. This is necessary when we want to collect activation statistics. 20 | 21 | import torch.nn as nn 22 | import math 23 | import torch.utils.model_zoo as model_zoo 24 | 25 | from distiller.modules import EltwiseAdd 26 | 27 | 28 | __all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] 29 | 30 | 31 | model_urls = { 32 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 33 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 34 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 35 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 36 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 37 | } 38 | 39 | 40 | def conv3x3(in_planes, out_planes, stride=1): 41 | """3x3 convolution with padding""" 42 | return nn.Conv2d( 43 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 44 | ) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu1 = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.relu2 = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | # Replace '+=' operator with inplace module 62 | self.add = EltwiseAdd(inplace=True) 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu1(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(x) 76 | 77 | # out += residual 78 | out = self.add(out, residual) 79 | out = self.relu2(out) 80 | 81 | return out 82 | 83 | 84 | class Bottleneck(nn.Module): 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None): 88 | super(Bottleneck, self).__init__() 89 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 90 | self.bn1 = nn.BatchNorm2d(planes) 91 | self.relu1 = nn.ReLU(inplace=True) 92 | self.conv2 = nn.Conv2d( 93 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 94 | ) 
95 | self.bn2 = nn.BatchNorm2d(planes) 96 | self.relu2 = nn.ReLU(inplace=True) 97 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 98 | self.bn3 = nn.BatchNorm2d(planes * 4) 99 | self.relu3 = nn.ReLU(inplace=True) 100 | self.downsample = downsample 101 | self.stride = stride 102 | 103 | # Replace '+=' operator with inplace module 104 | self.add = EltwiseAdd(inplace=True) 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | out = self.conv1(x) 110 | out = self.bn1(out) 111 | out = self.relu1(out) 112 | 113 | out = self.conv2(out) 114 | out = self.bn2(out) 115 | out = self.relu2(out) 116 | 117 | out = self.conv3(out) 118 | out = self.bn3(out) 119 | 120 | if self.downsample is not None: 121 | residual = self.downsample(x) 122 | 123 | # out += residual 124 | out = self.add(out, residual) 125 | out = self.relu3(out) 126 | 127 | return out 128 | 129 | 130 | class ResNet(nn.Module): 131 | def __init__(self, block, layers, num_classes=1000): 132 | self.inplanes = 64 133 | super(ResNet, self).__init__() 134 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 135 | self.bn1 = nn.BatchNorm2d(64) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, 64, layers[0]) 139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 140 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 141 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 142 | self.avgpool = nn.AvgPool2d(7, stride=1) 143 | self.fc = nn.Linear(512 * block.expansion, num_classes) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 148 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 149 | elif isinstance(m, nn.BatchNorm2d): 150 | m.weight.data.fill_(1) 151 | m.bias.data.zero_() 152 | 153 | def _make_layer(self, block, planes, blocks, stride=1): 154 | downsample = None 155 | if stride != 1 or self.inplanes != planes * block.expansion: 156 | downsample = nn.Sequential( 157 | nn.Conv2d( 158 | self.inplanes, 159 | planes * block.expansion, 160 | kernel_size=1, 161 | stride=stride, 162 | bias=False, 163 | ), 164 | nn.BatchNorm2d(planes * block.expansion), 165 | ) 166 | 167 | layers = [] 168 | layers.append(block(self.inplanes, planes, stride, downsample)) 169 | self.inplanes = planes * block.expansion 170 | for i in range(1, blocks): 171 | layers.append(block(self.inplanes, planes)) 172 | 173 | return nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | 186 | x = self.avgpool(x) 187 | x = x.view(x.size(0), -1) 188 | x = self.fc(x) 189 | 190 | return x 191 | 192 | 193 | def resnet18(pretrained=False, **kwargs): 194 | """Constructs a ResNet-18 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls["resnet18"])) 202 | return model 203 | 204 | 205 | def resnet34(pretrained=False, **kwargs): 206 | """Constructs a ResNet-34 model. 
207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls["resnet34"])) 214 | return model 215 | 216 | 217 | def resnet50(pretrained=False, **kwargs): 218 | """Constructs a ResNet-50 model. 219 | 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 224 | if pretrained: 225 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"])) 226 | return model 227 | 228 | 229 | def resnet101(pretrained=False, **kwargs): 230 | """Constructs a ResNet-101 model. 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 236 | if pretrained: 237 | model.load_state_dict(model_zoo.load_url(model_urls["resnet101"])) 238 | return model 239 | 240 | 241 | def resnet152(pretrained=False, **kwargs): 242 | """Constructs a ResNet-152 model. 243 | 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | """ 247 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 248 | if pretrained: 249 | model.load_state_dict(model_zoo.load_url(model_urls["resnet152"])) 250 | return model 251 | -------------------------------------------------------------------------------- /distiller/models/imagenet/resnet_earlyexit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision.models as models 5 | from torchvision.models.resnet import Bottleneck 6 | from torchvision.models.resnet import BasicBlock 7 | 8 | 9 | __all__ = ["resnet50_earlyexit"] 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d( 15 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 16 | ) 17 | 18 | 19 | class ResNetEarlyExit(models.ResNet): 20 | def __init__(self, block, layers, num_classes=1000): 21 | super(ResNetEarlyExit, self).__init__(block, layers, num_classes) 22 | 23 | # Define early exit layers 24 | self.conv1_exit0 = nn.Conv2d( 25 | 256, 50, kernel_size=7, stride=2, padding=3, bias=True 26 | ) 27 | self.conv2_exit0 = nn.Conv2d( 28 | 50, 12, kernel_size=7, stride=2, padding=3, bias=True 29 | ) 30 | self.conv1_exit1 = nn.Conv2d( 31 | 512, 12, kernel_size=7, stride=2, padding=3, bias=True 32 | ) 33 | self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes) 34 | self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes) 35 | 36 | def forward(self, x): 37 | x = self.conv1(x) 38 | x = self.bn1(x) 39 | x = self.relu(x) 40 | x = self.maxpool(x) 41 | 42 | x = self.layer1(x) 43 | 44 | # Add early exit layers 45 | exit0 = self.avgpool(x) 46 | exit0 = self.conv1_exit0(exit0) 47 | exit0 = self.conv2_exit0(exit0) 48 | exit0 = self.avgpool(exit0) 49 | exit0 = exit0.view(exit0.size(0), -1) 50 | exit0 = self.fc_exit0(exit0) 51 | 52 | x = self.layer2(x) 53 | 54 | # Add early exit layers 55 | exit1 = self.conv1_exit1(x) 56 | exit1 = self.avgpool(exit1) 57 | exit1 = exit1.view(exit1.size(0), -1) 58 | exit1 = self.fc_exit1(exit1) 59 | 60 | x = self.layer3(x) 61 | x = self.layer4(x) 62 | 63 | x = self.avgpool(x) 64 | x = x.view(x.size(0), -1) 65 | x = self.fc(x) 66 | 67 | # return a list of probabilities 68 | output = [] 69 
| output.append(exit0) 70 | output.append(exit1) 71 | output.append(x) 72 | return output 73 | 74 | 75 | def resnet50_earlyexit(pretrained=False, **kwargs): 76 | """Constructs a ResNet-50 model. 77 | """ 78 | model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs) 79 | return model 80 | -------------------------------------------------------------------------------- /distiller/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .eltwise import EltwiseAdd, EltwiseMult 18 | from .grouping import * 19 | from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm 20 | 21 | __all__ = [ 22 | "EltwiseAdd", 23 | "EltwiseMult", 24 | "Concat", 25 | "Chunk", 26 | "Split", 27 | "Stack", 28 | "DistillerLSTMCell", 29 | "DistillerLSTM", 30 | "convert_model_to_distiller_lstm", 31 | ] 32 | -------------------------------------------------------------------------------- /distiller/modules/eltwise.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch.nn as nn 18 | 19 | 20 | class EltwiseAdd(nn.Module): 21 | def __init__(self, inplace=False): 22 | super(EltwiseAdd, self).__init__() 23 | self.inplace = inplace 24 | 25 | def forward(self, *input): 26 | res = input[0] 27 | if self.inplace: 28 | for t in input[1:]: 29 | res += t 30 | else: 31 | for t in input[1:]: 32 | res = res + t 33 | return res 34 | 35 | 36 | class EltwiseMult(nn.Module): 37 | def __init__(self, inplace=False): 38 | super(EltwiseMult, self).__init__() 39 | self.inplace = inplace 40 | 41 | def forward(self, *input): 42 | res = input[0] 43 | if self.inplace: 44 | for t in input[1:]: 45 | res *= t 46 | else: 47 | for t in input[1:]: 48 | res = res * t 49 | return res 50 | -------------------------------------------------------------------------------- /distiller/modules/grouping.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class Concat(nn.Module): 22 | def __init__(self, dim=0): 23 | super(Concat, self).__init__() 24 | self.dim = dim 25 | 26 | def forward(self, *seq): 27 | return torch.cat(seq, dim=self.dim) 28 | 29 | 30 | class Chunk(nn.Module): 31 | def __init__(self, chunks, dim=0): 32 | super(Chunk, self).__init__() 33 | self.chunks = chunks 34 | self.dim = dim 35 | 36 | def forward(self, tensor): 37 | return tensor.chunk(self.chunks, dim=self.dim) 38 | 39 | 40 | class Split(nn.Module): 41 | def __init__(self, split_size_or_sections, dim=0): 42 | super(Split, self).__init__() 43 | self.split_size_or_sections = split_size_or_sections 44 | self.dim = dim 45 | 46 | def forward(self, tensor): 47 | return torch.split(tensor, self.split_size_or_sections, dim=self.dim) 48 | 49 | 50 | class Stack(nn.Module): 51 | def __init__(self, dim=0): 52 | super(Stack, self).__init__() 53 | self.dim = dim 54 | 55 | def forward(self, seq): 56 | return torch.stack(seq, dim=self.dim) 57 | -------------------------------------------------------------------------------- /distiller/pruning/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ 18 | :mod:`distiller.pruning` is a package implementing various pruning algorithms. 
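A construction sketch (illustrative; the argument names are assumed from the
pruner classes imported below):

    from distiller.pruning import SparsityLevelParameterPruner
    pruner = SparsityLevelParameterPruner(name='level_pruner',
                                          levels={'module.conv1.weight': 0.5})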
19 | """ 20 | 21 | from .magnitude_pruner import MagnitudeParameterPruner 22 | from .automated_gradual_pruner import ( 23 | AutomatedGradualPruner, 24 | L1RankedStructureParameterPruner_AGP, 25 | L2RankedStructureParameterPruner_AGP, 26 | ActivationAPoZRankedFilterPruner_AGP, 27 | ActivationMeanRankedFilterPruner_AGP, 28 | GradientRankedFilterPruner_AGP, 29 | RandomRankedFilterPruner_AGP, 30 | BernoulliFilterPruner_AGP, 31 | ) 32 | from .level_pruner import SparsityLevelParameterPruner 33 | from .sensitivity_pruner import SensitivityPruner 34 | from .splicing_pruner import SplicingPruner 35 | from .structure_pruner import StructureParameterPruner 36 | from .ranked_structures_pruner import ( 37 | L1RankedStructureParameterPruner, 38 | L2RankedStructureParameterPruner, 39 | ActivationAPoZRankedFilterPruner, 40 | ActivationMeanRankedFilterPruner, 41 | GradientRankedFilterPruner, 42 | RandomRankedFilterPruner, 43 | RandomLevelStructureParameterPruner, 44 | BernoulliFilterPruner, 45 | ) 46 | from .baidu_rnn_pruner import BaiduRNNPruner 47 | from .greedy_filter_pruning import greedy_pruner 48 | 49 | del magnitude_pruner 50 | del automated_gradual_pruner 51 | del level_pruner 52 | del sensitivity_pruner 53 | del structure_pruner 54 | del ranked_structures_pruner 55 | -------------------------------------------------------------------------------- /distiller/pruning/automated_gradual_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | from .level_pruner import SparsityLevelParameterPruner 19 | from .ranked_structures_pruner import * 20 | from distiller.utils import * 21 | from functools import partial 22 | 23 | 24 | class AutomatedGradualPrunerBase(_ParameterPruner): 25 | """Prune to an exact sparsity level specification using a prescribed sparsity 26 | level schedule formula. 27 | 28 | An automated gradual pruning algorithm that prunes the smallest magnitude 29 | weights to achieve a preset level of network sparsity. 
30 | 31 | Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the 32 | efficacy of pruning for model compression", 2017 NIPS Workshop on Machine 33 | Learning of Phones and other Consumer Devices, 34 | (https://arxiv.org/pdf/1710.01878.pdf) 35 | """ 36 | 37 | def __init__(self, name, initial_sparsity, final_sparsity): 38 | super().__init__(name) 39 | self.initial_sparsity = initial_sparsity 40 | self.final_sparsity = final_sparsity 41 | assert final_sparsity > initial_sparsity 42 | 43 | def compute_target_sparsity(self, meta): 44 | starting_epoch = meta["starting_epoch"] 45 | current_epoch = meta["current_epoch"] 46 | ending_epoch = meta["ending_epoch"] 47 | freq = meta["frequency"] 48 | span = ((ending_epoch - starting_epoch - 1) // freq) * freq 49 | assert span > 0 50 | 51 | target_sparsity = ( 52 | self.final_sparsity 53 | + (self.initial_sparsity - self.final_sparsity) 54 | * (1.0 - ((current_epoch - starting_epoch) / span)) ** 3 55 | ) 56 | 57 | return target_sparsity 58 | 59 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 60 | target_sparsity = self.compute_target_sparsity(meta) 61 | self.prune_to_target_sparsity( 62 | param, param_name, zeros_mask_dict, target_sparsity, meta["model"] 63 | ) 64 | 65 | def prune_to_target_sparsity( 66 | self, param, param_name, zeros_mask_dict, target_sparsity, model=None 67 | ): 68 | raise NotImplementedError 69 | 70 | 71 | class AutomatedGradualPruner(AutomatedGradualPrunerBase): 72 | """Fine-grained pruning with an AGP sparsity schedule. 73 | 74 | An automated gradual pruning algorithm that prunes the smallest magnitude 75 | weights to achieve a preset level of network sparsity. 76 | """ 77 | 78 | def __init__(self, name, initial_sparsity, final_sparsity, weights): 79 | super().__init__(name, initial_sparsity, final_sparsity) 80 | self.params_names = weights 81 | assert self.params_names 82 | 83 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 84 | if param_name not in self.params_names: 85 | return 86 | super().set_param_mask(param, param_name, zeros_mask_dict, meta) 87 | 88 | def prune_to_target_sparsity( 89 | self, param, param_name, zeros_mask_dict, target_sparsity, model=None 90 | ): 91 | return SparsityLevelParameterPruner.prune_level( 92 | param, param_name, zeros_mask_dict, target_sparsity 93 | ) 94 | 95 | 96 | class StructuredAGP(AutomatedGradualPrunerBase): 97 | """Structured pruning with an AGP sparsity schedule. 98 | 99 | This is a base-class for structured pruning with an AGP schedule. It is an 100 | extension of the AGP concept introduced by Zhu et al.
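The wrapped structure pruner (self.pruner) performs the actual ranking and masking of structures; this base class only computes the per-invocation sparsity target from the AGP schedule.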
101 | """ 102 | 103 | def __init__(self, name, initial_sparsity, final_sparsity): 104 | super().__init__(name, initial_sparsity, final_sparsity) 105 | self.pruner = None 106 | 107 | def prune_to_target_sparsity( 108 | self, param, param_name, zeros_mask_dict, target_sparsity, model 109 | ): 110 | self.pruner.prune_to_target_sparsity( 111 | param, param_name, zeros_mask_dict, target_sparsity, model 112 | ) 113 | 114 | 115 | # TODO: this class parameterization is cumbersome: the ranking functions (per structure) 116 | # should come from the YAML schedule 117 | class L1RankedStructureParameterPruner_AGP(StructuredAGP): 118 | def __init__( 119 | self, 120 | name, 121 | initial_sparsity, 122 | final_sparsity, 123 | group_type, 124 | weights, 125 | group_dependency=None, 126 | kwargs=None, 127 | ): 128 | super().__init__(name, initial_sparsity, final_sparsity) 129 | self.pruner = L1RankedStructureParameterPruner( 130 | name, 131 | group_type, 132 | desired_sparsity=0, 133 | weights=weights, 134 | group_dependency=group_dependency, 135 | kwargs=kwargs, 136 | ) 137 | 138 | 139 | class L2RankedStructureParameterPruner_AGP(StructuredAGP): 140 | def __init__( 141 | self, 142 | name, 143 | initial_sparsity, 144 | final_sparsity, 145 | group_type, 146 | weights, 147 | group_dependency=None, 148 | kwargs=None, 149 | ): 150 | super().__init__(name, initial_sparsity, final_sparsity) 151 | self.pruner = L2RankedStructureParameterPruner( 152 | name, 153 | group_type, 154 | desired_sparsity=0, 155 | weights=weights, 156 | group_dependency=group_dependency, 157 | kwargs=kwargs, 158 | ) 159 | 160 | 161 | class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP): 162 | def __init__( 163 | self, 164 | name, 165 | initial_sparsity, 166 | final_sparsity, 167 | group_type, 168 | weights, 169 | group_dependency=None, 170 | ): 171 | assert group_type in ["3D", "Filters"] 172 | super().__init__(name, initial_sparsity, final_sparsity) 173 | self.pruner = ActivationAPoZRankedFilterPruner( 174 | name, 175 | group_type, 176 | desired_sparsity=0, 177 | weights=weights, 178 | group_dependency=group_dependency, 179 | ) 180 | 181 | 182 | class ActivationMeanRankedFilterPruner_AGP(StructuredAGP): 183 | def __init__( 184 | self, 185 | name, 186 | initial_sparsity, 187 | final_sparsity, 188 | group_type, 189 | weights, 190 | group_dependency=None, 191 | ): 192 | assert group_type in ["3D", "Filters"] 193 | super().__init__(name, initial_sparsity, final_sparsity) 194 | self.pruner = ActivationMeanRankedFilterPruner( 195 | name, 196 | group_type, 197 | desired_sparsity=0, 198 | weights=weights, 199 | group_dependency=group_dependency, 200 | ) 201 | 202 | 203 | class GradientRankedFilterPruner_AGP(StructuredAGP): 204 | def __init__( 205 | self, 206 | name, 207 | initial_sparsity, 208 | final_sparsity, 209 | group_type, 210 | weights, 211 | group_dependency=None, 212 | ): 213 | assert group_type in ["3D", "Filters"] 214 | super().__init__(name, initial_sparsity, final_sparsity) 215 | self.pruner = GradientRankedFilterPruner( 216 | name, 217 | group_type, 218 | desired_sparsity=0, 219 | weights=weights, 220 | group_dependency=group_dependency, 221 | ) 222 | 223 | 224 | class RandomRankedFilterPruner_AGP(StructuredAGP): 225 | def __init__( 226 | self, 227 | name, 228 | initial_sparsity, 229 | final_sparsity, 230 | group_type, 231 | weights, 232 | group_dependency=None, 233 | ): 234 | assert group_type in ["3D", "Filters"] 235 | super().__init__(name, initial_sparsity, final_sparsity) 236 | self.pruner = RandomRankedFilterPruner( 237 | 
name, 238 | group_type, 239 | desired_sparsity=0, 240 | weights=weights, 241 | group_dependency=group_dependency, 242 | ) 243 | 244 | 245 | class BernoulliFilterPruner_AGP(StructuredAGP): 246 | def __init__( 247 | self, 248 | name, 249 | initial_sparsity, 250 | final_sparsity, 251 | group_type, 252 | weights, 253 | group_dependency=None, 254 | ): 255 | assert group_type in ["3D", "Filters"] 256 | super().__init__(name, initial_sparsity, final_sparsity) 257 | self.pruner = BernoulliFilterPruner( 258 | name, 259 | group_type, 260 | desired_sparsity=0, 261 | weights=weights, 262 | group_dependency=group_dependency, 263 | ) 264 | -------------------------------------------------------------------------------- /distiller/pruning/baidu_rnn_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | from .level_pruner import SparsityLevelParameterPruner 19 | from distiller.utils import * 20 | 21 | import distiller 22 | 23 | 24 | class BaiduRNNPruner(_ParameterPruner): 25 | """An element-wise pruner for RNN networks. 26 | 27 | Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). 28 | Exploring Sparsity in Recurrent Neural Networks. 29 | (https://arxiv.org/abs/1704.05119) 30 | 31 | This implementation differs slightly from the algorithm in the original paper: 32 | the paper adjusts the pruning rate at training-step granularity, while 33 | Distiller controls the pruning rate at epoch granularity. 34 | 35 | Equation (1): 36 | 37 | 2 * q * freq 38 | start_slope = ------------------------------------------------------- 39 | 2 * (ramp_itr - start_itr ) + 3 * (end_itr - ramp_itr ) 40 | 41 | 42 | Pruning algorithm (1): 43 | 44 | if current itr < ramp itr then 45 | threshold = start_slope * (current_itr - start_itr + 1) / freq 46 | else 47 | threshold = (start_slope * (ramp_itr - start_itr + 1) + 48 | ramp_slope * (current_itr - ramp_itr + 1)) / freq 49 | end if 50 | 51 | mask = abs(param) < threshold 52 | """ 53 | 54 | def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights): 55 | # Initialize the pruner, using a configuration that originates from the 56 | # schedule YAML file. 57 | super(BaiduRNNPruner, self).__init__(name) 58 | self.params_names = weights 59 | assert self.params_names 60 | 61 | # This is the 'q' value that appears in equation (1) of the paper 62 | self.q = q 63 | # This is the number of epochs to wait after starting_epoch, before we 64 | # begin ramping up the pruning rate. 65 | # In other words, between epochs 'starting_epoch' and 'starting_epoch'+ 66 | # self.ramp_epoch_offset the pruning slope is 'self.start_slope'.
After 67 | # that, the slope is 'self.ramp_slope' 68 | self.ramp_epoch_offset = ramp_epoch_offset 69 | self.ramp_slope_mult = ramp_slope_mult 70 | self.ramp_slope = None 71 | self.start_slope = None 72 | 73 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 74 | if param_name not in self.params_names: 75 | return 76 | 77 | starting_epoch = meta["starting_epoch"] 78 | current_epoch = meta["current_epoch"] 79 | ending_epoch = meta["ending_epoch"] 80 | freq = meta["frequency"] 81 | 82 | ramp_epoch = self.ramp_epoch_offset + starting_epoch 83 | 84 | # Calculate start slope 85 | if self.start_slope is None: 86 | # We want to calculate these values only once, and then cache them. 87 | self.start_slope = (2 * self.q * freq) / ( 88 | 2 * (ramp_epoch - starting_epoch) + 3 * (ending_epoch - ramp_epoch) 89 | ) 90 | self.ramp_slope = self.start_slope * self.ramp_slope_mult 91 | 92 | if current_epoch < ramp_epoch: 93 | eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq 94 | else: 95 | eps = ( 96 | self.start_slope * (ramp_epoch - starting_epoch + 1) 97 | + self.ramp_slope * (current_epoch - ramp_epoch + 1) 98 | ) / freq 99 | 100 | # After computing the threshold, we can create the mask 101 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps) 102 | -------------------------------------------------------------------------------- /distiller/pruning/level_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | from .pruner import _ParameterPruner 19 | import distiller 20 | 21 | 22 | class SparsityLevelParameterPruner(_ParameterPruner): 23 | """Prune to an exact pruning level specification. 24 | 25 | This pruner is very similar to MagnitudeParameterPruner, but instead of 26 | specifying an absolute threshold for pruning, you specify a target sparsity 27 | level (expressed as a fraction: 0.5 means 50% sparsity.) 28 | 29 | To find the correct threshold, we view the tensor as one large 1D vector, sort 30 | it using the absolute values of the elements, and then take topk elements. 31 | """ 32 | 33 | def __init__(self, name, levels, **kwargs): 34 | super(SparsityLevelParameterPruner, self).__init__(name) 35 | self.levels = levels 36 | assert self.levels 37 | 38 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 39 | # If there is a specific sparsity level specified for this module, then 40 | # use it. Otherwise try to use the default level ('*'). 
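# (Hypothetical example: with levels = {'module.conv1.weight': 0.6, '*': 0.4},
# conv1 is pruned to 60% sparsity and every other parameter to the 40%
# default; a parameter that maps to level 0 is left untouched.)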
41 | desired_sparsity = self.levels.get(param_name, self.levels.get("*", 0)) 42 | if desired_sparsity == 0: 43 | return 44 | 45 | self.prune_level(param, param_name, zeros_mask_dict, desired_sparsity) 46 | 47 | @staticmethod 48 | def prune_level(param, param_name, zeros_mask_dict, desired_sparsity): 49 | bottomk, _ = torch.topk( 50 | param.abs().view(-1), 51 | int(desired_sparsity * param.numel()), 52 | largest=False, 53 | sorted=True, 54 | ) 55 | threshold = bottomk.data[ 56 | -1 57 | ] # This is the largest element from the group of elements that we prune away 58 | zeros_mask_dict[param_name].mask = distiller.threshold_mask( 59 | param.data, threshold 60 | ) 61 | -------------------------------------------------------------------------------- /distiller/pruning/magnitude_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | import distiller 19 | 20 | 21 | class MagnitudeParameterPruner(_ParameterPruner): 22 | """This is the most basic magnitude-based pruner. 23 | 24 | This pruner supports configuring a scalar threshold for each layer. 25 | A default threshold is mandatory and is used for layers without explicit 26 | threshold setting. 27 | 28 | """ 29 | 30 | def __init__(self, name, thresholds, **kwargs): 31 | """ 32 | Usually, a Pruner is constructed by the compression schedule parser 33 | found in distiller/config.py. 34 | The constructor is passed a dictionary of thresholds, as explained below. 35 | 36 | Args: 37 | name (string): the name of the pruner (used only for debugging) 38 | thresholds (dict): a dictionary of thresholds, with the key being the 39 | parameter name. 40 | A special key, '*', represents the default threshold value. If 41 | set_param_mask is invoked on a parameter tensor that does not have 42 | an explicit entry in the 'thresholds' dictionary, then this default 43 | value is used. 44 | Currently it is mandatory to include a '*' key in 'thresholds'. 45 | """ 46 | super(MagnitudeParameterPruner, self).__init__(name) 47 | assert thresholds is not None 48 | # Make sure there is a default threshold to use 49 | assert "*" in thresholds 50 | self.thresholds = thresholds 51 | 52 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 53 | threshold = self.thresholds.get(param_name, self.thresholds["*"]) 54 | zeros_mask_dict[param_name].mask = distiller.threshold_mask( 55 | param.data, threshold 56 | ) 57 | -------------------------------------------------------------------------------- /distiller/pruning/pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | import distiller 19 | 20 | 21 | class _ParameterPruner(object): 22 | """Base class for all pruners. 23 | 24 | Arguments: 25 | name: pruner name is used mainly for debugging. 26 | """ 27 | 28 | def __init__(self, name): 29 | self.name = name 30 | 31 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 32 | raise NotImplementedError 33 | 34 | 35 | def threshold_model(model, threshold): 36 | """Threshold an entire model using the provided threshold. 37 | 38 | This function prunes weights only (biases are left untouched). 39 | """ 40 | for name, p in model.named_parameters(): 41 | if "weight" in name: 42 | mask = distiller.threshold_mask(p.data, threshold) 43 | p.data = p.data.mul_(mask) 44 | -------------------------------------------------------------------------------- /distiller/pruning/sensitivity_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | import distiller 19 | import torch 20 | 21 | 22 | class SensitivityPruner(_ParameterPruner): 23 | """Uses the algorithm from "Learning both Weights and Connections for Efficient 24 | Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf 25 | 26 | I.e.: "The pruning threshold is chosen as a quality parameter multiplied 27 | by the standard deviation of a layer's weights." 28 | In this code, the "quality parameter" is referred to as "sensitivity" and 29 | is based on the values learned from performing sensitivity analysis. 30 | 31 | Note that this implementation deviates slightly from the algorithm Song Han 32 | describes in his PhD dissertation, in that the threshold value is set only 33 | once. In his PhD dissertation, Song Han describes a threshold that grows at 34 | each iteration. This requires n+1 hyper-parameters (n being the number of 35 | pruning iterations we use): the threshold and the threshold increase (delta) 36 | at each pruning iteration. 37 | The implementation that follows takes advantage of the fact that as pruning 38 | progresses, more weights are pulled toward zero, and therefore the threshold 39 | "traps" more weights. Thus, we can use fewer hyper-parameters and achieve the 40 | same results.
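For example (illustrative numbers): if a layer's weights have a standard deviation of 0.02 and the layer's sensitivity is set to 0.6, every weight whose absolute value falls below 0.02 * 0.6 = 0.012 is masked to zero.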
41 | """ 42 | 43 | def __init__(self, name, sensitivities, **kwargs): 44 | super(SensitivityPruner, self).__init__(name) 45 | self.sensitivities = sensitivities 46 | 47 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 48 | if not hasattr(param, "stddev"): 49 | param.stddev = torch.std(param).item() 50 | 51 | if param_name not in self.sensitivities: 52 | if "*" not in self.sensitivities: 53 | return 54 | else: 55 | sensitivity = self.sensitivities["*"] 56 | else: 57 | sensitivity = self.sensitivities[param_name] 58 | 59 | threshold = param.stddev * sensitivity 60 | 61 | # After computing the threshold, we can create the mask 62 | zeros_mask_dict[param_name].mask = distiller.threshold_mask( 63 | param.data, threshold 64 | ) 65 | -------------------------------------------------------------------------------- /distiller/pruning/splicing_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | 18 | from .pruner import _ParameterPruner 19 | import torch 20 | import logging 21 | 22 | msglogger = logging.getLogger() 23 | 24 | 25 | class SplicingPruner(_ParameterPruner): 26 | """A pruner that both prunes and splices connections. 27 | 28 | The idea of pruning and splicing working in tandem was first proposed in the following 29 | NIPS paper from Intel Labs China in 2016: 30 | Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen. 31 | NIPS 2016, https://arxiv.org/abs/1608.04493. 32 | 33 | A SplicingPruner works best with a Dynamic Network Surgery schedule. 34 | The original Caffe code from the authors of the paper is available here: 35 | https://github.com/yiwenguo/Dynamic-Network-Surgery/blob/master/src/caffe/layers/compress_conv_layer.cpp 36 | """ 37 | 38 | def __init__( 39 | self, 40 | name, 41 | sensitivities, 42 | low_thresh_mult, 43 | hi_thresh_mult, 44 | sensitivity_multiplier=0, 45 | ): 46 | """Arguments: 47 | """ 48 | super(SplicingPruner, self).__init__(name) 49 | self.sensitivities = sensitivities 50 | self.low_thresh_mult = low_thresh_mult 51 | self.hi_thresh_mult = hi_thresh_mult 52 | self.sensitivity_multiplier = sensitivity_multiplier 53 | 54 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 55 | if param_name not in self.sensitivities: 56 | if "*" not in self.sensitivities: 57 | return 58 | else: 59 | sensitivity = self.sensitivities["*"] 60 | else: 61 | sensitivity = self.sensitivities[param_name] 62 | 63 | if not hasattr(param, "_std"): 64 | # Compute the mean and standard-deviation once, and cache them. 
65 | param._std = torch.std(param.abs()).item() 66 | param._mean = torch.mean(param.abs()).item() 67 | 68 | if self.sensitivity_multiplier > 0: 69 | # Linearly growing sensitivity - for now this is hard-coded 70 | starting_epoch = meta["starting_epoch"] 71 | current_epoch = meta["current_epoch"] 72 | sensitivity *= ( 73 | current_epoch - starting_epoch 74 | ) * self.sensitivity_multiplier + 1 75 | 76 | threshold_low = (param._mean + param._std * sensitivity) * self.low_thresh_mult 77 | threshold_hi = (param._mean + param._std * sensitivity) * self.hi_thresh_mult 78 | 79 | if zeros_mask_dict[param_name].mask is None: 80 | zeros_mask_dict[param_name].mask = torch.ones_like(param) 81 | 82 | # This code implements equation (3) of the "Dynamic Network Surgery" paper: 83 | # 84 | # 0 if a > |W| 85 | # h(W) = mask if a <= |W| < b 86 | # 1 if b <= |W| 87 | # 88 | # h(W) is the so-called "network surgery function". 89 | # mask is the mask used in the previous iteration. 90 | # a and b are the low and high thresholds, respectively. 91 | # We followed the example implementation from Yiwen Guo in Caffe, and used the 92 | # weight tensor's starting mean and std. 93 | # This is very similar to the initialization performed by distiller.SensitivityPruner. 94 | 95 | mask = zeros_mask_dict[param_name].mask 96 | zeros, ones = ( 97 | torch.tensor([0]).type(mask.type()), 98 | torch.tensor([1]).type(mask.type()), 99 | ) 100 | weights_abs = param.abs() 101 | new_mask = torch.where(threshold_low > weights_abs, zeros, mask) 102 | new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask) 103 | zeros_mask_dict[param_name].mask = new_mask 104 | -------------------------------------------------------------------------------- /distiller/pruning/structure_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | from .pruner import _ParameterPruner 19 | import distiller 20 | 21 | msglogger = logging.getLogger() 22 | 23 | 24 | class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): 25 | """Prune parameter structures. 26 | 27 | Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the elements 28 | in the structure is below the threshold, then the structure is pruned. 29 | 30 | We use the average, instead of the plain L1-norm, because we don't want the threshold to depend on 31 | the structure size.
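For example, with averaging, a filter drawn from a 64-channel layer and one drawn from a 512-channel layer can be judged against the same threshold, even though the larger filter's plain L1-norm would be roughly eight times bigger for identically distributed weights.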
32 | """ 33 | 34 | def __init__(self, name, model, reg_regims, threshold_criteria): 35 | super(StructureParameterPruner, self).__init__(name) 36 | self.name = name 37 | self.model = model 38 | self.reg_regims = reg_regims 39 | self.threshold_criteria = threshold_criteria 40 | assert threshold_criteria in ["Max", "Mean_Abs"] 41 | 42 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 43 | if param_name not in self.reg_regims.keys(): 44 | return 45 | 46 | group_type = self.reg_regims[param_name][1] 47 | threshold = self.reg_regims[param_name][0] 48 | zeros_mask_dict[param_name].mask = self.group_threshold_mask( 49 | param, group_type, threshold, self.threshold_criteria 50 | ) 51 | -------------------------------------------------------------------------------- /distiller/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .quantizer import Quantizer 18 | from .range_linear import ( 19 | RangeLinearQuantWrapper, 20 | RangeLinearQuantParamLayerWrapper, 21 | PostTrainLinearQuantizer, 22 | LinearQuantMode, 23 | QuantAwareTrainRangeLinearQuantizer, 24 | add_post_train_quant_args, 25 | RangeLinearQuantConcatWrapper, 26 | RangeLinearQuantEltwiseAddWrapper, 27 | RangeLinearQuantEltwiseMultWrapper, 28 | ClipMode, 29 | ) 30 | from .clipped_linear import ( 31 | LinearQuantizeSTE, 32 | ClippedLinearQuantization, 33 | WRPNQuantizer, 34 | DorefaQuantizer, 35 | PACTQuantizer, 36 | ) 37 | 38 | del quantizer 39 | del range_linear 40 | del clipped_linear 41 | -------------------------------------------------------------------------------- /distiller/quantization/q_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import torch 18 | 19 | 20 | def _prep_saturation_val_tensor(sat_val): 21 | is_scalar = not isinstance(sat_val, torch.Tensor) 22 | out = torch.tensor(sat_val) if is_scalar else sat_val.clone().detach() 23 | if not out.is_floating_point(): 24 | out = out.to(torch.float32) 25 | if out.dim() == 0: 26 | out = out.unsqueeze(0) 27 | return is_scalar, out 28 | 29 | 30 | def symmetric_linear_quantization_params(num_bits, saturation_val): 31 | is_scalar, sat_val = _prep_saturation_val_tensor(saturation_val) 32 | 33 | if any(sat_val < 0): 34 | raise ValueError("Saturation value must be >= 0") 35 | 36 | # Leave one bit for sign 37 | n = 2 ** (num_bits - 1) - 1 38 | 39 | # If float values are all 0, we just want the quantized values to be 0 as well. So overriding the saturation 40 | # value to 'n', so the scale becomes 1 41 | sat_val[sat_val == 0] = n 42 | scale = n / sat_val 43 | zero_point = torch.zeros_like(scale) 44 | 45 | if is_scalar: 46 | # If input was scalar, return scalars 47 | return scale.item(), zero_point.item() 48 | return scale, zero_point 49 | 50 | 51 | def asymmetric_linear_quantization_params( 52 | num_bits, saturation_min, saturation_max, integral_zero_point=True, signed=False 53 | ): 54 | scalar_min, sat_min = _prep_saturation_val_tensor(saturation_min) 55 | scalar_max, sat_max = _prep_saturation_val_tensor(saturation_max) 56 | is_scalar = scalar_min and scalar_max 57 | 58 | if scalar_max and not scalar_min: 59 | sat_max = sat_max.to(sat_min.device) 60 | elif scalar_min and not scalar_max: 61 | sat_min = sat_min.to(sat_max.device) 62 | 63 | if any(sat_min > sat_max): 64 | raise ValueError("saturation_min must be smaller than saturation_max") 65 | 66 | n = 2 ** num_bits - 1 67 | 68 | # Make sure 0 is in the range 69 | sat_min = torch.min(sat_min, torch.zeros_like(sat_min)) 70 | sat_max = torch.max(sat_max, torch.zeros_like(sat_max)) 71 | 72 | diff = sat_max - sat_min 73 | # If float values are all 0, we just want the quantized values to be 0 as well. 
So overriding the saturation 74 | # value to 'n', so the scale becomes 1 75 | diff[diff == 0] = n 76 | 77 | scale = n / diff 78 | zero_point = scale * sat_min 79 | if integral_zero_point: 80 | zero_point = zero_point.round() 81 | if signed: 82 | zero_point += 2 ** (num_bits - 1) 83 | if is_scalar: 84 | return scale.item(), zero_point.item() 85 | return scale, zero_point 86 | 87 | 88 | def clamp(input, min, max, inplace=False): 89 | if inplace: 90 | input.clamp_(min, max) 91 | return input 92 | return torch.clamp(input, min, max) 93 | 94 | 95 | def linear_quantize(input, scale, zero_point, inplace=False): 96 | if inplace: 97 | input.mul_(scale).sub_(zero_point).round_() 98 | return input 99 | return torch.round(scale * input - zero_point) 100 | 101 | 102 | def linear_quantize_clamp( 103 | input, scale, zero_point, clamp_min, clamp_max, inplace=False 104 | ): 105 | output = linear_quantize(input, scale, zero_point, inplace) 106 | return clamp(output, clamp_min, clamp_max, inplace) 107 | 108 | 109 | def linear_dequantize(input, scale, zero_point, inplace=False): 110 | if inplace: 111 | input.add_(zero_point).div_(scale) 112 | return input 113 | return (input + zero_point) / scale 114 | 115 | 116 | def get_tensor_min_max(t, per_dim=None): 117 | if per_dim is None: 118 | return t.min(), t.max() 119 | if per_dim >= t.dim(): 120 | raise ValueError( 121 | "Got per_dim={0}, but tensor only has {1} dimensions", per_dim, t.dim() 122 | ) 123 | view_dims = [t.shape[i] for i in range(per_dim + 1)] + [-1] 124 | tv = t.view(*view_dims) 125 | return tv.min(dim=-1)[0], tv.max(dim=-1)[0] 126 | 127 | 128 | def get_tensor_avg_min_max(t, across_dim=None): 129 | min_per_dim, max_per_dim = get_tensor_min_max(t, per_dim=across_dim) 130 | return min_per_dim.mean(), max_per_dim.mean() 131 | 132 | 133 | def get_tensor_max_abs(t, per_dim=None): 134 | min_val, max_val = get_tensor_min_max(t, per_dim=per_dim) 135 | return torch.max(min_val.abs_(), max_val.abs_()) 136 | 137 | 138 | def get_tensor_avg_max_abs(t, across_dim=None): 139 | avg_min, avg_max = get_tensor_avg_min_max(t, across_dim=across_dim) 140 | return torch.max(avg_min.abs_(), avg_max.abs_()) 141 | 142 | 143 | def get_tensor_mean_n_stds_min_max(t, dim=None, n_stds=1): 144 | if dim is not None: 145 | raise NotImplementedError("Setting dim != None not supported yet") 146 | if n_stds <= 0: 147 | raise ValueError("n_stds must be > 0, got {}".format(n_stds)) 148 | mean = t.mean() 149 | std = t.std() 150 | min_val, max_val = get_tensor_min_max(t) 151 | min_val = torch.max(min_val, mean - n_stds * std) 152 | max_val = torch.min(max_val, mean + n_stds * std) 153 | return min_val, max_val 154 | 155 | 156 | def get_tensor_mean_n_stds_max_abs(t, dim=None, n_stds=1): 157 | min_val, max_val = get_tensor_mean_n_stds_min_max(t, dim, n_stds) 158 | return torch.max(min_val.abs_(), max_val.abs_()) 159 | 160 | 161 | def get_scale_approximation_shift_bits(fp32_scale, mult_bits, limit=False): 162 | shift_bits = torch.log2((2 ** mult_bits - 1) / fp32_scale).floor() 163 | if limit: 164 | shift_bits = min(mult_bits, shift_bits) 165 | return shift_bits 166 | 167 | 168 | def get_scale_approximation_mult(fp32_scale, shift_bits): 169 | return (fp32_scale * (2 ** shift_bits)).floor() 170 | 171 | 172 | def get_scale_approximation_params(fp32_scale, mult_bits, limit=False): 173 | shift_bits = get_scale_approximation_shift_bits(fp32_scale, mult_bits, limit=limit) 174 | multiplier = get_scale_approximation_mult(fp32_scale, shift_bits) 175 | return multiplier, shift_bits 176 | 177 | 178 | 
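# Worked example for the scale-approximation helpers above (illustrative
# values only): approximating fp32_scale = 0.1234 with mult_bits = 8 gives
#   shift_bits = floor(log2((2**8 - 1) / 0.1234)) = 11
#   multiplier = floor(0.1234 * 2**11)            = 252
# so the integer-only scale is 252 / 2**11 = 0.123046875, within about 0.3%
# of the original floating-point scale.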
def approx_scale_as_mult_and_shift(fp32_scale, mult_bits, limit=False): 179 | multiplier, shift_bits = get_scale_approximation_params( 180 | fp32_scale, mult_bits, limit=limit 181 | ) 182 | return multiplier / (2 ** shift_bits) 183 | 184 | 185 | def get_quantized_range(num_bits, signed=True): 186 | if signed: 187 | n = 2 ** (num_bits - 1) 188 | return -n, n - 1 189 | return 0, 2 ** num_bits - 1 190 | 191 | 192 | class LinearQuantizeSTE(torch.autograd.Function): 193 | @staticmethod 194 | def forward(ctx, input, scale, zero_point, dequantize, inplace): 195 | if inplace: 196 | ctx.mark_dirty(input) 197 | output = linear_quantize(input, scale, zero_point, inplace) 198 | if dequantize: 199 | output = linear_dequantize(output, scale, zero_point, inplace) 200 | return output 201 | 202 | @staticmethod 203 | def backward(ctx, grad_output): 204 | # Straight-through estimator 205 | return grad_output, None, None, None, None 206 | -------------------------------------------------------------------------------- /distiller/regularization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .l1_regularizer import L1Regularizer 18 | from .group_regularizer import GroupLassoRegularizer, GroupVarianceRegularizer 19 | 20 | del l1_regularizer 21 | del group_regularizer 22 | -------------------------------------------------------------------------------- /distiller/regularization/drop_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .regularizer import _Regularizer 6 | 7 | 8 | class Conv2dWithMask(nn.Conv2d): 9 | def __init__( 10 | self, 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | ): 20 | 21 | super(Conv2dWithMask, self).__init__( 22 | in_channels=in_channels, 23 | out_channels=out_channels, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | groups=groups, 29 | bias=bias, 30 | ) 31 | 32 | self.test_mask = None 33 | self.p_mask = 1.0 34 | self.frequency = 16 35 | 36 | def forward(self, input): 37 | if self.training: 38 | self.frequency -= 1 39 | if self.frequency == 0: 40 | sample = np.random.binomial(n=1, p=self.p_mask, size=self.out_channels) 41 | param = self.weight 42 | l1norm = param.detach().view(param.size(0), -1).norm(p=1, dim=1) 43 | mask = torch.tensor(sample) 44 | mask = ( 45 | mask.expand( 46 | param.size(1) * param.size(2) * param.size(3), param.size(0) 47 | ) 48 | .t() 49 | .contiguous() 50 | ) 51 | mask = mask.view(self.weight.shape).to(param.device) 52 | mask = mask.type(param.type()) 53 | masked_weights = self.weight * mask 54 | masked_l1norm = ( 55 | masked_weights.detach().view(param.size(0), -1).norm(p=1, 
dim=1) 56 | ) 57 | pruning_factor = (masked_l1norm.sum() / l1norm.sum()).item() 58 | pruning_factor = max(0.2, pruning_factor) 59 | weight = masked_weights / pruning_factor 60 | self.frequency = 16 61 | else: 62 | weight = self.weight 63 | else: 64 | weight = self.weight 65 | return F.conv2d( 66 | input, 67 | weight, 68 | self.bias, 69 | self.stride, 70 | self.padding, 71 | self.dilation, 72 | self.groups, 73 | ) 74 | 75 | 76 | # Replaces all Conv2d layers in the target model with 'Conv2dWithMask' 77 | def replace_conv2d(container): 78 | for name, module in container.named_children(): 79 | if isinstance(module, nn.Conv2d): 80 | print("replacing: ", name) 81 | new_module = Conv2dWithMask( 82 | in_channels=module.in_channels, 83 | out_channels=module.out_channels, 84 | kernel_size=module.kernel_size, 85 | padding=module.padding, 86 | stride=module.stride, 87 | bias=module.bias is not None, # pass a bool, not the Parameter itself 88 | ) 89 | setattr(container, name, new_module) 90 | replace_conv2d(module) 91 | 92 | 93 | class DropFilterRegularizer(_Regularizer): 94 | def __init__(self, name, model, reg_regims, threshold_criteria=None): 95 | super().__init__(name, model, reg_regims, threshold_criteria) 96 | replace_conv2d(model) 97 | -------------------------------------------------------------------------------- /distiller/regularization/l1_regularizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | """L1-norm regularization""" 18 | 19 | import torch 20 | import math 21 | import numpy as np 22 | import distiller 23 | from .regularizer import _Regularizer, EPSILON 24 | 25 | 26 | class L1Regularizer(_Regularizer): 27 | def __init__(self, name, model, reg_regims, threshold_criteria=None): 28 | super(L1Regularizer, self).__init__(name, model, reg_regims, threshold_criteria) 29 | 30 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): 31 | if param_name in self.reg_regims: 32 | strength = self.reg_regims[param_name] 33 | regularizer_loss += L1Regularizer.__add_l1(param, strength) 34 | 35 | return regularizer_loss 36 | 37 | def threshold(self, param, param_name, zeros_mask_dict): 38 | """Soft threshold for L1-norm regularizer""" 39 | if self.threshold_criteria is None or param_name not in self.reg_regims: 40 | return 41 | 42 | strength = self.reg_regims[param_name] 43 | zeros_mask_dict[param_name].mask = distiller.threshold_mask( 44 | param.data, threshold=strength 45 | ) 46 | zeros_mask_dict[param_name].is_regularization_mask = True 47 | 48 | @staticmethod 49 | def __add_l1(var, strength): 50 | return var.abs().sum() * strength 51 | 52 | @staticmethod 53 | def __add_l1_all(loss, model, reg_regims): 54 | for param_name, param in model.named_parameters(): 55 | if param_name in reg_regims.keys(): 56 | strength = reg_regims[param_name] 57 | loss += L1Regularizer.__add_l1(param, strength) 58 | -------------------------------------------------------------------------------- /distiller/regularization/regularizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | EPSILON = 1e-8 18 | 19 | 20 | class _Regularizer(object): 21 | def __init__(self, name, model, reg_regims, threshold_criteria): 22 | """Regularization base class. 23 | 24 | Args: 25 | reg_regims: regularization regimen. A dictionary of 26 | reg_regims[param_name] = [lambda, structure-type] 27 | """ 28 | self.name = name 29 | self.model = model 30 | self.reg_regims = reg_regims 31 | self.threshold_criteria = threshold_criteria 32 | 33 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): 34 | raise NotImplementedError 35 | 36 | def threshold(self, param, param_name, zeros_mask_dict): 37 | raise NotImplementedError 38 | -------------------------------------------------------------------------------- /distiller/sensitivity.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Perform sensitivity tests on layers and whole networks. 18 | 19 | Construct a schedule for experimenting with network and layer sensitivity 20 | to pruning. 21 | 22 | The idea is to set the pruning level (percentage) of specific layers (or the 23 | entire network), and then to prune once, run an evaluation on the test dataset, 24 | and exit. This should teach us about the "sensitivity" of the network/layers 25 | to pruning. 26 | 27 | This concept is discussed in "Learning both Weights and Connections for 28 | Efficient Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf 29 | """ 30 | 31 | from copy import deepcopy 32 | from collections import OrderedDict 33 | import logging 34 | import csv 35 | import distiller 36 | from .scheduler import CompressionScheduler 37 | 38 | msglogger = logging.getLogger() 39 | 40 | 41 | def perform_sensitivity_analysis(model, net_params, sparsities, test_func, group): 42 | """Perform a sensitivity test for a model's weights parameters. 43 | 44 | The model should be trained to maximum accuracy, because we aim to understand 45 | the behavior of the model's performance in relation to pruning of a specific 46 | weights tensor. 47 | 48 | By default this function will test all of the model's parameters. 49 | 50 | The return value is a complex sensitivities dictionary: the dictionary's 51 | key is the name (string) of the weights tensor. The value is another dictionary, 52 | where the tested sparsity-level is the key, and a (top1, top5, loss) tuple 53 | is the value. 54 | Below is an example of such a dictionary: 55 | 56 | .. code-block:: python 57 | {'features.module.6.weight': {0.0: (56.518, 79.07, 1.9159), 58 | 0.05: (56.492, 79.1, 1.9161), 59 | 0.10: (56.212, 78.854, 1.9315), 60 | 0.15: (35.424, 60.3, 3.0866)}, 61 | 'classifier.module.1.weight': {0.0: (56.518, 79.07, 1.9159), 62 | 0.05: (56.514, 79.07, 1.9159), 63 | 0.10: (56.434, 79.074, 1.9138), 64 | 0.15: (54.454, 77.854, 2.3127)} } 65 | 66 | The test_func is expected to execute the model on a test/validation dataset, 67 | and return the results for top1 and top5 accuracies, and the loss value. 68 | """ 69 | if group not in ["element", "filter", "channel"]: 70 | raise ValueError("group parameter contains an illegal value: {}".format(group)) 71 | sensitivities = OrderedDict() 72 | 73 | for param_name in net_params: 74 | if model.state_dict()[param_name].dim() not in [2, 4]: 75 | continue 76 | 77 | # Make a copy of the model, because when we apply the zeros mask (i.e. 78 | # perform pruning), the model's weights are altered 79 | model_cpy = deepcopy(model) 80 | 81 | sensitivity = OrderedDict() 82 | for sparsity_level in sparsities: 83 | sparsity_level = float(sparsity_level) 84 | msglogger.info( 85 | "Testing sensitivity of %s [%0.1f%% sparsity]" 86 | % (param_name, sparsity_level * 100) 87 | ) 88 | # Create the pruner (a level pruner), the pruning policy and the 89 | # pruning schedule. 
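# ('element' uses a SparsityLevelParameterPruner on individual weights;
#  'filter' and 'channel' use an L1RankedStructureParameterPruner on whole
#  filters / input channels of 4D convolution weights, respectively.)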
90 | if group == "element": 91 | # Element-wise sparsity 92 | sparsity_levels = {param_name: sparsity_level} 93 | pruner = distiller.pruning.SparsityLevelParameterPruner( 94 | name="sensitivity", levels=sparsity_levels 95 | ) 96 | elif group == "filter": 97 | # Filter ranking 98 | if model.state_dict()[param_name].dim() != 4: 99 | continue 100 | pruner = distiller.pruning.L1RankedStructureParameterPruner( 101 | "sensitivity", 102 | group_type="Filters", 103 | desired_sparsity=sparsity_level, 104 | weights=param_name, 105 | ) 106 | elif group == "channel": 107 | # Channel ranking 108 | if model.state_dict()[param_name].dim() != 4: 109 | continue 110 | pruner = distiller.pruning.L1RankedStructureParameterPruner( 111 | "sensitivity", 112 | group_type="Channels", 113 | desired_sparsity=sparsity_level, 114 | weights=param_name, 115 | ) 116 | 117 | policy = distiller.PruningPolicy(pruner, pruner_args=None) 118 | scheduler = CompressionScheduler(model_cpy) 119 | scheduler.add_policy(policy, epochs=[0]) 120 | 121 | # Compute the pruning mask per the pruner and apply the mask on the weights 122 | scheduler.on_epoch_begin(0) 123 | scheduler.mask_all_weights() 124 | 125 | # Test and record the performance of the pruned model 126 | prec1, prec5, loss = test_func(model=model_cpy) 127 | sensitivity[sparsity_level] = (prec1, prec5, loss) 128 | sensitivities[param_name] = sensitivity 129 | return sensitivities 130 | 131 | 132 | def sensitivities_to_png(sensitivities, fname): 133 | """Create a multiplot of the sensitivities. 134 | 135 | The 'sensitivities' argument is expected to have the dict-of-dict structure 136 | described in the documentation of perform_sensitivity_analysis. 137 | """ 138 | try: 139 | # sudo apt-get install python3-tk 140 | import matplotlib 141 | 142 | matplotlib.use("Agg") 143 | import matplotlib.pyplot as plt 144 | except ImportError: 145 | print( 146 | "WARNING: Function sensitivities_to_png requires package matplotlib, which " 147 | "is not installed in your execution environment.\n" 148 | "Skipping the PNG file generation" 149 | ) 150 | return 151 | 152 | msglogger.info("Generating sensitivity graph") 153 | 154 | for param_name, sensitivity in sensitivities.items(): 155 | sense = [values[1] for sparsity, values in sensitivity.items()] 156 | sparsities = [sparsity for sparsity, values in sensitivity.items()] 157 | plt.plot(sparsities, sense, label=param_name) 158 | 159 | plt.ylabel("top5") 160 | plt.xlabel("sparsity") 161 | plt.title("Pruning Sensitivity") 162 | plt.legend(loc="lower center", ncol=2, mode="expand", borderaxespad=0.0) 163 | plt.savefig(fname, format="png") 164 | 165 | 166 | def sensitivities_to_csv(sensitivities, fname): 167 | """Create a CSV file from the sensitivities dictionary. 168 | 169 | The 'sensitivities' argument is expected to have the dict-of-dict structure 170 | described in the documentation of perform_sensitivity_analysis.
171 | """ 172 | with open(fname, "w") as csv_file: 173 | writer = csv.writer(csv_file) 174 | # write the header 175 | writer.writerow(["parameter", "sparsity", "top1", "top5", "loss"]) 176 | for param_name, sensitivity in sensitivities.items(): 177 | for sparsity, values in sensitivity.items(): 178 | writer.writerow([param_name] + [sparsity] + list(values)) 179 | -------------------------------------------------------------------------------- /distiller/thresholding.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Tensor thresholding. 18 | 19 | The code below supports fine-grained tensor thresholding and group-wise thresholding. 20 | """ 21 | import torch 22 | 23 | 24 | def threshold_mask(weights, threshold): 25 | """Create a threshold mask for the provided parameter tensor using 26 | magnitude thresholding. 27 | 28 | Arguments: 29 | weights: a parameter tensor which should be pruned. 30 | threshold: the pruning threshold. 31 | Returns: 32 | prune_mask: The pruning mask. 33 | """ 34 | return torch.gt(torch.abs(weights), threshold).type(weights.type()) 35 | 36 | 37 | class GroupThresholdMixin(object): 38 | """A mixin class to add group thresholding capabilities 39 | 40 | TODO: this does not need to be a mixin - it should be made a simple function. We keep this until we refactor 41 | """ 42 | 43 | def group_threshold_mask(self, param, group_type, threshold, threshold_criteria): 44 | ret = group_threshold_mask(param, group_type, threshold, threshold_criteria) 45 | if isinstance(ret, tuple): 46 | return ret[0] 47 | return ret 48 | 49 | 50 | def group_threshold_binary_map(param, group_type, threshold, threshold_criteria): 51 | """Return a threshold mask for the provided parameter and group type. 52 | 53 | Args: 54 | param: The parameter to mask 55 | group_type: The elements grouping type (structure). 56 | One of:2D, 3D, 4D, Channels, Row, Cols 57 | threshold: The threshold 58 | threshold_criteria: The thresholding criteria. 59 | 'Mean_Abs' thresholds the entire element group using the mean of the 60 | absolute values of the tensor elements. 61 | 'Max' thresholds the entire group using the magnitude of the largest 62 | element in the group. 63 | """ 64 | if group_type == "2D": 65 | assert param.dim() == 4, "This thresholding is only supported for 4D weights" 66 | view_2d = param.view(-1, param.size(2) * param.size(3)) 67 | # 1. Determine if the kernel "value" is below the threshold, by creating a 1D 68 | # thresholds tensor with length = #IFMs * # OFMs 69 | thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).to( 70 | param.device 71 | ) 72 | # 2. Create a binary thresholds mask, where we use the mean of the abs values of the 73 | # elements in each channel as the threshold filter. 74 | # 3. 
Apply the threshold filter 75 | binary_map = threshold_policy(view_2d, thresholds, threshold_criteria) 76 | return binary_map 77 | 78 | elif group_type == "Rows": 79 | assert param.dim() == 2, "This regularization is only supported for 2D weights" 80 | thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device) 81 | binary_map = threshold_policy(param, thresholds, threshold_criteria) 82 | return binary_map 83 | 84 | elif group_type == "Cols": 85 | assert param.dim() == 2, "This regularization is only supported for 2D weights" 86 | thresholds = torch.Tensor([threshold] * param.size(1)).to(param.device) 87 | binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0) 88 | return binary_map 89 | 90 | elif group_type == "3D" or group_type == "Filters": 91 | assert param.dim() == 4, "This thresholding is only supported for 4D weights" 92 | view_filters = param.view(param.size(0), -1) 93 | thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device) 94 | binary_map = threshold_policy(view_filters, thresholds, threshold_criteria) 95 | return binary_map 96 | 97 | elif group_type == "4D": 98 | assert param.dim() == 4, "This thresholding is only supported for 4D weights" 99 | if threshold_criteria == "Mean_Abs": 100 | if param.data.abs().mean() > threshold: 101 | return None 102 | return torch.zeros_like(param.data) 103 | elif threshold_criteria == "Max": 104 | if param.data.abs().max() > threshold: 105 | return None 106 | return torch.zeros_like(param.data) 107 | raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria)) 108 | 109 | elif group_type == "Channels": 110 | assert param.dim() == 4, "This thresholding is only supported for 4D weights" 111 | num_filters = param.size(0) 112 | num_kernels_per_filter = param.size(1) 113 | 114 | view_2d = param.view(-1, param.size(2) * param.size(3)) 115 | # Next, compute the mean of the absolute values of the elements in each row/kernel 116 | kernel_means = view_2d.abs().mean(dim=1) 117 | k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t() 118 | thresholds = torch.Tensor([threshold] * num_kernels_per_filter).to(param.device) 119 | binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type()) 120 | return binary_map 121 | 122 | 123 | def group_threshold_mask( 124 | param, group_type, threshold, threshold_criteria, binary_map=None 125 | ): 126 | """Return a threshold mask for the provided parameter and group type. 127 | 128 | Args: 129 | param: The parameter to mask 130 | group_type: The elements grouping type (structure). 131 | One of: 2D, 3D, 4D, Channels, Rows, Cols 132 | threshold: The threshold 133 | threshold_criteria: The thresholding criteria. 134 | 'Mean_Abs' thresholds the entire element group using the mean of the 135 | absolute values of the tensor elements. 136 | 'Max' thresholds the entire group using the magnitude of the largest 137 | element in the group. 138 | """ 139 | if group_type == "2D": 140 | if binary_map is None: 141 | binary_map = group_threshold_binary_map( 142 | param, group_type, threshold, threshold_criteria 143 | ) 144 | 145 | # 3.
Finally, expand the thresholds and view as a 4D tensor 146 | a = binary_map.expand( 147 | param.size(2) * param.size(3), param.size(0) * param.size(1) 148 | ).t() 149 | return ( 150 | a.view(param.size(0), param.size(1), param.size(2), param.size(3)), 151 | binary_map, 152 | ) 153 | 154 | elif group_type == "Rows": 155 | if binary_map is None: 156 | binary_map = group_threshold_binary_map( 157 | param, group_type, threshold, threshold_criteria 158 | ) 159 | return binary_map.expand(param.size(1), param.size(0)).t(), binary_map 160 | 161 | elif group_type == "Cols": 162 | if binary_map is None: 163 | binary_map = group_threshold_binary_map( 164 | param, group_type, threshold, threshold_criteria 165 | ) 166 | return binary_map.expand(param.size(0), param.size(1)), binary_map 167 | 168 | elif group_type == "3D" or group_type == "Filters": 169 | if binary_map is None: 170 | binary_map = group_threshold_binary_map( 171 | param, group_type, threshold, threshold_criteria 172 | ) 173 | a = binary_map.expand( 174 | param.size(1) * param.size(2) * param.size(3), param.size(0) 175 | ).t() 176 | return a.view(*param.shape), binary_map 177 | 178 | elif group_type == "4D": 179 | assert param.dim() == 4, "This thresholding is only supported for 4D weights" 180 | if threshold_criteria == "Mean_Abs": 181 | if param.data.abs().mean() > threshold: 182 | return None 183 | return torch.zeros_like(param.data) 184 | elif threshold_criteria == "Max": 185 | if param.data.abs().max() > threshold: 186 | return None 187 | return torch.zeros_like(param.data) 188 | raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria)) 189 | 190 | elif group_type == "Channels": 191 | if binary_map is None: 192 | binary_map = group_threshold_binary_map( 193 | param, group_type, threshold, threshold_criteria 194 | ) 195 | num_filters = param.size(0) 196 | num_kernels_per_filter = param.size(1) 197 | 198 | # Now let's expand back up to a 4D mask 199 | a = binary_map.expand(num_filters, num_kernels_per_filter) 200 | c = a.unsqueeze(-1) 201 | d = c.expand( 202 | num_filters, num_kernels_per_filter, param.size(2) * param.size(3) 203 | ).contiguous() 204 | return ( 205 | d.view(param.size(0), param.size(1), param.size(2), param.size(3)), 206 | binary_map, 207 | ) 208 | 209 | 210 | def threshold_policy(weights, thresholds, threshold_criteria, dim=1): 211 | """ 212 | """ 213 | if threshold_criteria in ["Mean_Abs", "Mean_L1"]: 214 | return ( 215 | weights.data.norm(p=1, dim=dim) 216 | .div(weights.size(dim)) 217 | .gt(thresholds) 218 | .type(weights.type()) 219 | ) 220 | if threshold_criteria == "Mean_L2": 221 | return ( 222 | weights.data.norm(p=2, dim=dim) 223 | .div(weights.size(dim)) 224 | .gt(thresholds) 225 | .type(weights.type()) 226 | ) 227 | elif threshold_criteria == "L1": 228 | return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type()) 229 | elif threshold_criteria == "L2": 230 | return weights.data.norm(p=2, dim=dim).gt(thresholds).type(weights.type()) 231 | elif threshold_criteria == "Max": 232 | maxv, _ = weights.data.abs().max(dim=dim) 233 | return maxv.gt(thresholds).type(weights.type()) 234 | raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria)) 235 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.06-py3 2 | WORKDIR /workspace/EagleEye 3 | COPY requirements_v3.txt requirements.txt 4 | RUN pip install 
pip -U 5 | RUN pip install --upgrade setuptools 6 | RUN pip install -U --ignore-installed wrapt==1.11.1 enum34 simplejson netaddr pyyaml msgpack==0.5.6 7 | RUN pip install -r requirements.txt 8 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | astor==0.8.0 3 | astunparse==1.6.3 4 | cachetools==4.1.1 5 | chardet==3.0.4 6 | cycler==0.10.0 7 | gast==0.3.3 8 | gitdb2==2.0.5 9 | GitPython==2.1.11 10 | h5py==2.10.0 11 | idna==2.8 12 | importlib-metadata==1.7.0 13 | Keras-Applications==1.0.8 14 | Keras-Preprocessing==1.1.0 15 | kiwisolver==1.1.0 16 | Markdown==3.2.2 17 | matplotlib==3.1.1 18 | numpy==1.16.4 19 | oauthlib==3.1.0 20 | opt-einsum==3.2.1 21 | pandas==0.25.0 22 | Pillow==6.1.0 23 | protobuf==3.9.1 24 | pyasn1==0.4.8 25 | pyasn1-modules==0.2.8 26 | pydot==1.2.4 27 | pyparsing==2.4.2 28 | python-dateutil==2.8.0 29 | pytz==2019.2 30 | PyYAML==5.1.1 31 | pyzmq==18.1.0 32 | requests==2.22.0 33 | requests-oauthlib==1.3.0 34 | rsa==4.6 35 | scipy==1.4.1 36 | six==1.12.0 37 | smmap2==2.0.5 38 | tabulate==0.8.2 39 | tensorboard==2.2.2 40 | tensorboard-plugin-wit==1.7.0 41 | tensorflow==2.2.0 42 | tensorflow-estimator==2.2.0 43 | termcolor==1.1.0 44 | torchnet==0.0.4 45 | tornado==6.0.3 46 | tqdm==4.35.0 47 | urllib3==1.25.3 48 | visdom==0.1.8.8 49 | websocket-client==0.56.0 50 | Werkzeug==0.15.5 51 | XlsxWriter==1.1.8 52 | zipp==3.1.0 53 | tensorboardX -------------------------------------------------------------------------------- /fig/cor_fix_flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/fig/cor_fix_flops.png -------------------------------------------------------------------------------- /fig/eye.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/fig/eye.png -------------------------------------------------------------------------------- /fig/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/fig/pipeline.png -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import os 8 | import torch 9 | import torch.optim as optim 10 | from options.base_options import BaseOptions 11 | from models.wrapper import ModelWrapper 12 | from report import model_summary, Reporter 13 | from data import custom_get_dataloaders 14 | import torch.nn as nn 15 | from tqdm import tqdm 16 | import random 17 | import numpy as np 18 | import distiller 19 | from thinning import thinning 20 | 21 | 22 | def random_compression_scheduler(compression_scheduler, channel_configuration): 23 | for i, item in enumerate(channel_configuration): 24 | compression_scheduler.policies[1][i].pruner.desired_sparsity = item 25 | return compression_scheduler 
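# NOTE: each line of the search-result file read below is written by search.py
# in the form "<score> <flops_ratio> <rate_1> ... <rate_L>": an evaluation
# score, the achieved FLOPs ratio, then one pruning rate per prunable layer.
# A minimal stand-alone parser is sketched here for clarity;
# `parse_strategy_line` is a hypothetical helper, not part of the original
# pipeline.
def parse_strategy_line(line):
    fields = line.strip().split(" ")
    score, flops_ratio = float(fields[0]), float(fields[1])
    rates = [float(x) for x in fields[2:]]
    return score, flops_ratio, rates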
26 | 27 | 28 | def get_channel_config(path, line_num): 29 | # line_num starts from 0 30 | with open(path) as data: 31 | lines = data.readlines() 32 | i = 0 33 | for l in lines: 34 | if i == line_num: 35 | d = l.strip().split(" ") 36 | channel_config = [] 37 | print("=" * 20, " read config") 38 | for i in range(0, 2): 39 | print("{} ".format(d[i]), end="") 40 | for i in range(2, len(d)): 41 | channel_config.append(float(d[i])) 42 | break 43 | i += 1 44 | return channel_config 45 | 46 | 47 | def train_epoch(model_wrapper, dataloader_train, optimizer): 48 | optimizer.zero_grad() 49 | model_wrapper._net.train() 50 | 51 | loss_total = 0 52 | total = 0 53 | 54 | for iter_in_epoch, sample in enumerate(tqdm(dataloader_train, leave=False)): 55 | loss = model_wrapper.get_loss(sample) 56 | 57 | loss_total += loss.item() 58 | total += 1 59 | 60 | loss.backward() 61 | 62 | optimizer.step() 63 | optimizer.zero_grad() 64 | 65 | return loss_total / total 66 | 67 | 68 | def main(opt): 69 | # basic settings 70 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_ids)[1:-1] 71 | 72 | if torch.cuda.is_available(): 73 | device = "cuda" 74 | torch.backends.cudnn.benchmark = True 75 | else: 76 | device = "cpu" 77 | ##################### Get Dataloader #################### 78 | dataloader_train, dataloader_val = custom_get_dataloaders(opt) 79 | # dummy_input is sample input of dataloaders 80 | if hasattr(dataloader_val, "dataset"): 81 | dummy_input = dataloader_val.dataset.__getitem__(0) 82 | dummy_input = dummy_input[0] 83 | dummy_input = dummy_input.unsqueeze(0) 84 | else: 85 | # for imagenet dali loader 86 | dummy_input = torch.rand(1, 3, 224, 224) 87 | 88 | ##################### Create Baseline Model #################### 89 | net = ModelWrapper(opt) 90 | net.load_checkpoint(opt.checkpoint) 91 | flops_before, params_before = model_summary(net.get_compress_part(), dummy_input) 92 | 93 | ##################### Load Pruning Strategy ############### 94 | compression_scheduler = distiller.file_config( 95 | net.get_compress_part(), net.optimizer, opt.compress_schedule_path 96 | ) 97 | 98 | channel_config = get_channel_config( 99 | opt.search_result, opt.strategy_id 100 | ) # pruning strategy 101 | 102 | compression_scheduler = random_compression_scheduler( 103 | compression_scheduler, channel_config 104 | ) 105 | 106 | ###### Adaptive-BN-based Candidate Evaluation of Pruning Strategy ### 107 | thinning(net, compression_scheduler, input_tensor=dummy_input) 108 | 109 | flops_after, params_after = model_summary(net.get_compress_part(), dummy_input) 110 | ratio = flops_after / flops_before 111 | print("FLOPs ratio:", ratio) 112 | net = net.to(device) 113 | net.parallel(opt.gpu_ids) 114 | net.get_compress_part().train() 115 | with torch.no_grad(): 116 | for index, sample in enumerate(tqdm(dataloader_train, leave=False)): 117 | _ = net.get_loss(sample) 118 | if index > 100: 119 | break 120 | 121 | strategy_score = net.get_eval_scores(dataloader_val)["accuracy"] 122 | 123 | print( 124 | "Result file:{}, Strategy ID:{}, Evaluation score:{}".format( 125 | opt.search_result, opt.strategy_id, strategy_score 126 | ) 127 | ) 128 | 129 | ##################### Fine-tuning ######################### 130 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(net.optimizer, opt.epoch) 131 | reporter = Reporter(opt) 132 | best_acc = 0 133 | net._net.train() 134 | for epoch in range(1, opt.epoch + 1): 135 | reporter.log_metric("lr", net.optimizer.param_groups[0]["lr"], epoch) 136 | train_loss = train_epoch(net, dataloader_train, 
net.optimizer,) 137 | reporter.log_metric("train_loss", train_loss, epoch) 138 | 139 | lr_scheduler.step() 140 | 141 | scores = net.get_eval_scores(dataloader_val) 142 | print("==> Evaluation: Epoch={} Acc={}".format(epoch, str(scores))) 143 | 144 | reporter.log_metric("eval_acc", scores["accuracy"], epoch) 145 | 146 | if scores["accuracy"] > best_acc: 147 | best_acc = scores["accuracy"] 148 | reporter.log_metric("best_acc", best_acc, epoch) 149 | 150 | save_checkpoints( 151 | scores["accuracy"], net._net, reporter, opt.exp_name, epoch, 152 | ) 153 | 154 | print("==> Training epoch %d" % epoch) 155 | 156 | 157 | def save_checkpoints(acc, model, reporter, exp_name, epoch): 158 | if not hasattr(save_checkpoints, "best_acc"): 159 | save_checkpoints.best_acc = 0 160 | 161 | state_dict = model.state_dict() 162 | reporter.save_checkpoint(state_dict, "{}_latest.pth".format(exp_name), epoch) 163 | if acc > save_checkpoints.best_acc: 164 | reporter.save_checkpoint(state_dict, "{}_best.pth".format(exp_name), epoch) 165 | save_checkpoints.best_acc = acc 166 | reporter.save_checkpoint(state_dict, "{}_{}.pth".format(exp_name, epoch), epoch) 167 | 168 | 169 | if __name__ == "__main__": 170 | # get options 171 | opt = BaseOptions().parse() 172 | main(opt) 173 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import os 8 | import torch 9 | from options.base_options import BaseOptions 10 | from models.wrapper import ModelWrapper 11 | from report import model_summary 12 | from data import custom_get_dataloaders 13 | import torch.nn as nn 14 | from tqdm import tqdm 15 | import random 16 | import numpy as np 17 | 18 | 19 | def main(): 20 | # get options 21 | opt = BaseOptions().parse() 22 | # basic settings 23 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_ids)[1:-1] 24 | 25 | if torch.cuda.is_available(): 26 | device = "cuda" 27 | torch.backends.cudnn.benchmark = True 28 | else: 29 | device = "cpu" 30 | ##################### Get Dataloader #################### 31 | _, dataloader_test = custom_get_dataloaders(opt) 32 | # dummy_input is sample input of dataloaders 33 | if hasattr(dataloader_test, "dataset"): 34 | dummy_input = dataloader_test.dataset.__getitem__(0) 35 | dummy_input = dummy_input[0] 36 | dummy_input = dummy_input.unsqueeze(0) 37 | else: 38 | # for imagenet dali loader 39 | dummy_input = torch.rand(1, 3, 224, 224) 40 | ##################### Evaluate Baseline Model #################### 41 | net = ModelWrapper(opt) 42 | net = net.to(device) 43 | net.parallel(opt.gpu_ids) 44 | flops_before, params_before = model_summary(net.get_compress_part(), dummy_input) 45 | 46 | del net 47 | ##################### Evaluate Pruned Model #################### 48 | net = ModelWrapper(opt) 49 | net.load_checkpoint(opt.checkpoint) 50 | net = net.to(device) 51 | flops_after, params_after = model_summary(net.get_compress_part(), dummy_input) 52 | net.parallel(opt.gpu_ids) 53 | 54 | acc_after = net.get_eval_scores(dataloader_test) 55 | 56 | #################### Report ##################### 57 | print("######### Report #########") 58 | print("Model:{}".format(opt.model_name)) 59 | 
print("Checkpoint:{}".format(opt.checkpoint)) 60 | print( 61 | "FLOPs of Original Model:{:.3f}G;Params of Original Model:{:.2f}M".format( 62 | flops_before / 1e9, params_before / 1e6 63 | ) 64 | ) 65 | print( 66 | "FLOPs of Pruned Model:{:.3f}G;Params of Pruned Model:{:.2f}M".format( 67 | flops_after / 1e9, params_after / 1e6 68 | ) 69 | ) 70 | print( 71 | "Top-1 Acc of Pruned Model on {}:{}".format( 72 | opt.dataset_name, acc_after["accuracy"] 73 | ) 74 | ) 75 | print("##########################") 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/models/__init__.py -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | """Depthwise conv + Pointwise conv""" 13 | 14 | def __init__(self, in_planes, out_planes, stride=1): 15 | super(Block, self).__init__() 16 | self.conv1 = nn.Conv2d( 17 | in_planes, 18 | in_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=1, 22 | groups=in_planes, 23 | bias=False, 24 | ) 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv2 = nn.Conv2d( 27 | in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False 28 | ) 29 | self.bn2 = nn.BatchNorm2d(out_planes) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | return out 35 | 36 | 37 | class MobileNet(nn.Module): 38 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 39 | cfg = [ 40 | 64, 41 | (128, 2), 42 | 128, 43 | (256, 2), 44 | 256, 45 | (512, 2), 46 | 512, 47 | 512, 48 | 512, 49 | 512, 50 | 512, 51 | (1024, 2), 52 | 1024, 53 | ] 54 | 55 | def __init__(self, num_classes=10): 56 | super(MobileNet, self).__init__() 57 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(32) 59 | self.layers = self._make_layers(in_planes=32) 60 | self.linear = nn.Linear(1024, num_classes) 61 | self.relu = nn.ReLU() 62 | 63 | def _make_layers(self, in_planes): 64 | layers = [] 65 | for x in self.cfg: 66 | out_planes = x if isinstance(x, int) else x[0] 67 | stride = 1 if isinstance(x, int) else x[1] 68 | layers.append(Block(in_planes, out_planes, stride)) 69 | in_planes = out_planes 70 | return nn.Sequential(*layers) 71 | 72 | def forward(self, x): 73 | out = self.relu(self.bn1(self.conv1(x))) 74 | out = self.layers(out) 75 | out = F.avg_pool2d(out, 7) 76 | out = out.view(out.size(0), -1) 77 | out = self.linear(out) 78 | return out 79 | 80 | 81 | def test(): 82 | net = MobileNet() 83 | x = torch.randn(1, 3, 32, 32) 84 | y = net(x) 85 | print(y.size()) 86 | 87 | 88 | # test() 89 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | __all__ = ["resnet50"] 6 | 7 | 8 | def 
resnet50(**kwargs): 9 | import torchvision 10 | 11 | model = torchvision.models.resnet50(**kwargs) 12 | return model 13 | 14 | 15 | if __name__ == "__main__": 16 | net = resnet50() 17 | image = torch.randn(2, 3, 224, 224) 18 | print(net) 19 | print(net.layer1[1].conv2) 20 | out = net(image) 21 | print(out.size()) 22 | 23 | # print(distiller.weights_sparsity_summary(net)) 24 | -------------------------------------------------------------------------------- /models/wrapper.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import os 8 | import sys 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | import shutil 15 | import torch.optim as optim 16 | import numpy as np 17 | 18 | 19 | class ModelWrapper(nn.Module): 20 | def __init__(self, opt): 21 | super(ModelWrapper, self).__init__() 22 | if opt.model_name == "mobilenetv1": 23 | from .mobilenet import MobileNet 24 | 25 | self._net = MobileNet(num_classes=opt.num_classes) 26 | elif opt.model_name == "resnet50": 27 | from .resnet import resnet50 28 | 29 | self._net = resnet50(num_classes=opt.num_classes) 30 | 31 | self.optimizer = optim.SGD( 32 | self._net.parameters(), 33 | lr=opt.lr, 34 | momentum=opt.momentum, 35 | weight_decay=opt.weight_decay, 36 | ) 37 | self._criterion = nn.CrossEntropyLoss() 38 | 39 | def forward(self, x): # test forward 40 | x, _ = x 41 | 42 | self._net.eval() 43 | device = next(self.parameters()).device 44 | x = x.to(device) 45 | out = self._net(x) 46 | 47 | return out 48 | 49 | def get_compress_part(self): 50 | return self._net 51 | 52 | def parallel(self, gpu_ids): 53 | if len(gpu_ids) > 1: 54 | self._net = nn.DataParallel(self._net) 55 | 56 | def get_loss(self, inputs): 57 | device = next(self.parameters()).device 58 | 59 | self._net.train() 60 | images, targets = inputs 61 | images, targets = images.to(device), targets.to(device) 62 | out = self._net(images) 63 | loss = self._criterion(out, targets) 64 | 65 | return loss 66 | 67 | def get_eval_scores(self, dataloader_test): 68 | from tqdm import tqdm 69 | 70 | device = next(self.parameters()).device 71 | to_cuda = next(self.parameters()).device.type == "cuda" 72 | 73 | total = 0 74 | correct = 0 75 | 76 | top_5_acc = 0 77 | 78 | self._net.eval() 79 | # print('==> evaluating accuracy') 80 | with torch.no_grad(): 81 | for i, sample in enumerate( 82 | tqdm(dataloader_test, leave=False, desc="evaluating accuracy") 83 | ): 84 | outputs = self.forward(sample) 85 | _, predicted = outputs.max(1) 86 | targets = sample[1].to(device) 87 | 88 | prec5 = accuracy(outputs.data, targets.data, topk=(5,)) 89 | prec5 = prec5[0] 90 | top_5_acc += prec5.item() * targets.size(0) 91 | 92 | correct += predicted.eq(targets).sum().item() 93 | total += targets.size(0) 94 | acc = correct / total 95 | top_5_acc /= total 96 | scores = {"accuracy": round(acc, 3)} 97 | 98 | return scores 99 | 100 | def load_checkpoint(self, checkpoint_file): 101 | """ 102 | Function to load pruned model or normal model checkpoint. 
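Conv2d, BatchNorm2d and Linear modules are first resized in place to match
the tensor shapes stored in the checkpoint, so a thinned (pruned) state
dict can be loaded into the full architecture before `load_state_dict`.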
103 | :param str checkpoint_file: path to checkpoint file, such as `models/ckpt/mobilenet.pth` 104 | """ 105 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 106 | net = self.get_compress_part() 107 | #### load pruned model #### 108 | for key, module in net.named_modules(): 109 | # torch.nn.BatchNorm2d 110 | if isinstance(module, nn.BatchNorm2d): 111 | module.weight = torch.nn.Parameter(checkpoint[key + ".weight"]) 112 | module.bias = torch.nn.Parameter(checkpoint[key + ".bias"]) 113 | module.num_features = module.weight.size(0) 114 | module.running_mean = module.running_mean[0 : module.num_features] 115 | module.running_var = module.running_var[0 : module.num_features] 116 | # torch.nn.Conv2d 117 | elif isinstance(module, nn.Conv2d): 118 | # for conv2d layers, both bias and groups must be considered 119 | module.weight = torch.nn.Parameter(checkpoint[key + ".weight"]) 120 | module.out_channels = module.weight.size(0) 121 | module.in_channels = module.weight.size(1) 122 | if module.groups != 1: 123 | # group convolution case 124 | # only supported for MobileNet-style depthwise conv 125 | module.in_channels = module.weight.size(0) 126 | module.groups = module.in_channels 127 | if key + ".bias" in checkpoint: 128 | module.bias = torch.nn.Parameter(checkpoint[key + ".bias"]) 129 | # torch.nn.Linear 130 | elif isinstance(module, nn.Linear): 131 | module.weight = torch.nn.Parameter(checkpoint[key + ".weight"]) 132 | if key + ".bias" in checkpoint: 133 | module.bias = torch.nn.Parameter(checkpoint[key + ".bias"]) 134 | module.out_features = module.weight.size(0) 135 | module.in_features = module.weight.size(1) 136 | 137 | net.load_state_dict(checkpoint) 138 | 139 | 140 | def accuracy(output, target, topk=(1,)): 141 | """Computes the precision@k for the specified values of k""" 142 | batch_size = target.size(0) 143 | num = output.size(1) 144 | target_topk = [] 145 | appendices = [] 146 | for k in topk: 147 | if k <= num: 148 | target_topk.append(k) 149 | else: 150 | appendices.append([0.0]) 151 | topk = target_topk 152 | maxk = max(topk) 153 | _, pred = output.topk(maxk, 1, True, True) 154 | pred = pred.t() 155 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 156 | 157 | res = [] 158 | for k in topk: 159 | correct_k = correct[:k].view(-1).float().sum(0) 160 | res.append(correct_k.mul_(100.0 / batch_size)) 161 | return res + appendices 162 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous47823493/EagleEye/ba312d99587e4c1b9ffeb4f56ab81d909f2021a0/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import argparse 8 | import os 9 | 10 | 11 | class BaseOptions: 12 | def __init__(self): 13 | self.parser = argparse.ArgumentParser() 14 | self.initialized = False 15 | 16 | def initialize(self): 17 | # model params
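# (the options below are shared by search.py, finetune.py and inference.py,
# which all build their configuration via BaseOptions().parse())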
18 | self.parser.add_argument( 19 | "--model_name", 20 | type=str, 21 | default="mobilenetv1", 22 | help="what kind of model you are using. Only support `resnet50` and `mobilenetv1`", 23 | ) 24 | self.parser.add_argument( 25 | "--num_classes", type=int, default=1000, help="number of class labels" 26 | ) 27 | self.parser.add_argument( 28 | "--checkpoint", type=str, default="", help="path to model state dict" 29 | ) 30 | 31 | # env params 32 | self.parser.add_argument( 33 | "--gpu_ids", type=int, default=[0], nargs="+", help="GPU ids." 34 | ) 35 | 36 | # fine-tune params 37 | self.parser.add_argument( 38 | "--batch_size", type=int, default=64, help="batch size while fine-tuning" 39 | ) 40 | self.parser.add_argument( 41 | "--epoch", type=int, default=120, help="number of fine-tuning epochs" 42 | ) 43 | self.parser.add_argument( 44 | "--dataset_path", type=str, default="./cifar10", help="path to dataset" 45 | ) 46 | self.parser.add_argument( 47 | "--dataset_name", 48 | type=str, 49 | default="cifar10_224", 50 | help="filename of the file that contains your own `get_dataloaders` function", 51 | ) 52 | self.parser.add_argument( 53 | "--num_workers", 54 | type=int, 55 | default=16, 56 | help="Number of workers used in dataloading", 57 | ) 58 | self.parser.add_argument( 59 | "--lr", type=float, default=0.01, help="learning rate while fine-tuning" 60 | ) 61 | self.parser.add_argument( 62 | "--weight_decay", 63 | type=float, 64 | default=5e-4, 65 | help="weight decay while fine-tuning", 66 | ) 67 | self.parser.add_argument( 68 | "--momentum", type=float, default=0.9, help="momentum while fine-tuning" 69 | ) 70 | 71 | self.parser.add_argument( 72 | "--search_result", 73 | type=str, 74 | default="mbv1.txt", 75 | help="path to search result", 76 | ) 77 | self.parser.add_argument( 78 | "--strategy_id", type=int, default=0, help="line number in the search-result file" 79 | ) 80 | 81 | self.parser.add_argument("--log_dir", type=str, default="logs/", help="log dir") 82 | self.parser.add_argument( 83 | "--exp_name", type=str, default="mbv1_50flops", help="experiment name" 84 | ) 85 | 86 | # search params 87 | self.parser.add_argument( 88 | "--max_rate", type=float, default=0.7, help="upper bound of the layer-wise pruning-rate search space" 89 | ) 90 | self.parser.add_argument( 91 | "--min_rate", type=float, default=0, help="lower bound of the layer-wise pruning-rate search space" 92 | ) 93 | self.parser.add_argument( 94 | "--compress_schedule_path", 95 | type=str, 96 | default="compress_config/mbv1_imagenet.yaml", 97 | help="path to compression schedule", 98 | ) 99 | self.parser.add_argument( 100 | "--flops_target", 101 | type=float, 102 | default=0.5, 103 | help="FLOPs constraint for pruning (ratio of the original FLOPs)", 104 | ) 105 | self.parser.add_argument( 106 | "--output_file", type=str, default="mbv1.txt", help="path to search result" 107 | ) 108 | 109 | self.initialized = True 110 | 111 | def parse(self, save=True): 112 | if not self.initialized: 113 | self.initialize() 114 | self.opt = self.parser.parse_args() 115 | return self.opt 116 | -------------------------------------------------------------------------------- /report/__init__.py: -------------------------------------------------------------------------------- 1 | import distiller 2 | import torch 3 | import os 4 | import os.path as osp 5 | import pandas as pd 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | from datetime import datetime 9 | import shutil 10 | 11 | def weights_sparsity_summary(model, opt=None): 12 | try: 13 | df = distiller.weights_sparsity_summary( 14 | model.module, return_total_sparsity=True 15 | ) 16 | except AttributeError: 17 | df = distiller.weights_sparsity_summary(model, return_total_sparsity=True) 18 | return df[0]["NNZ (dense)"].sum() // 
2 19 | 20 | 21 | def performance_summary(model, dummy_input, opt=None, prefix=""): 22 | try: 23 | df = distiller.model_performance_summary(model.module, dummy_input) 24 | except AttributeError: 25 | df = distiller.model_performance_summary(model, dummy_input) 26 | new_entry = { 27 | "Name": ["Total"], 28 | "MACs": [df["MACs"].sum()], 29 | } 30 | MAC_total = df["MACs"].sum() 31 | return MAC_total 32 | 33 | 34 | def model_summary(model, dummy_input, opt=None): 35 | return ( 36 | performance_summary(model, dummy_input, opt), 37 | weights_sparsity_summary(model, opt), 38 | ) 39 | 40 | 41 | def _check_mk_path(path): 42 | if not osp.exists(path): 43 | os.makedirs(path) 44 | 45 | 46 | class Reporter: 47 | def __init__(self, opt, use_time=True): 48 | now = datetime.now().strftime("-%Y-%m-%d-%H:%M:%S") 49 | 50 | if use_time: 51 | self.log_dir = osp.join(opt.log_dir, opt.exp_name + now) 52 | else: 53 | self.log_dir = osp.join(opt.log_dir, opt.exp_name) 54 | 55 | _check_mk_path(self.log_dir) 56 | 57 | self.writer = SummaryWriter(self.log_dir) 58 | 59 | self.ckpt_log_dir = osp.join(self.log_dir, "checkpoints") 60 | _check_mk_path(self.ckpt_log_dir) 61 | 62 | self.config_log_dir = osp.join(self.log_dir, "config") 63 | _check_mk_path(self.config_log_dir) 64 | 65 | def log_config(self, path): 66 | target = osp.join(self.config_log_dir, path.split("/")[-1]) 67 | shutil.copyfile(path, target) 68 | 69 | def get_writer(self): 70 | return self.writer 71 | 72 | def log_metric(self, key, value, step): 73 | self.writer.add_scalar("data/" + key, value, step) 74 | 75 | def log_text(self, msg): 76 | print(msg) 77 | 78 | def save_checkpoint(self, state_dict, ckpt_name, epoch=0): 79 | checkpoint = {"state_dict": state_dict, "epoch": epoch} 80 | torch.save(checkpoint, osp.join(self.ckpt_log_dir, ckpt_name)) 81 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | astor==0.8.0 3 | astunparse==1.6.3 4 | cachetools==4.1.1 5 | certifi==2019.6.16 6 | chardet==3.0.4 7 | cycler==0.10.0 8 | gast==0.3.3 9 | gitdb2==2.0.5 10 | GitPython==2.1.11 11 | google-auth==1.18.0 12 | google-auth-oauthlib==0.4.1 13 | google-pasta==0.2.0 14 | grpcio==1.23.0 15 | h5py==2.10.0 16 | idna==2.8 17 | importlib-metadata==1.7.0 18 | Keras-Applications==1.0.8 19 | Keras-Preprocessing==1.1.0 20 | kiwisolver==1.1.0 21 | Markdown==3.2.2 22 | matplotlib==3.1.1 23 | numpy==1.16.4 24 | oauthlib==3.1.0 25 | opt-einsum==3.2.1 26 | pandas==0.25.0 27 | Pillow==6.1.0 28 | protobuf==3.9.1 29 | pyasn1==0.4.8 30 | pyasn1-modules==0.2.8 31 | pydot==1.2.4 32 | pyparsing==2.4.2 33 | python-dateutil==2.8.0 34 | pytz==2019.2 35 | PyYAML==5.1.1 36 | pyzmq==18.1.0 37 | requests==2.22.0 38 | requests-oauthlib==1.3.0 39 | rsa==4.6 40 | scipy==1.4.1 41 | six==1.12.0 42 | smmap2==2.0.5 43 | tabulate==0.8.2 44 | tensorboard==2.2.2 45 | tensorboard-plugin-wit==1.7.0 46 | tensorflow==2.2.0 47 | tensorflow-estimator==2.2.0 48 | termcolor==1.1.0 49 | torch==1.1.0 50 | torchnet==0.0.4 51 | torchvision==0.3.0 52 | tornado==6.0.3 53 | tqdm==4.35.0 54 | urllib3==1.25.3 55 | visdom==0.1.8.8 56 | websocket-client==0.56.0 57 | Werkzeug==0.15.5 58 | wrapt==1.11.1 59 | XlsxWriter==1.1.8 60 | zipp==3.1.0 61 | tensorboardX -------------------------------------------------------------------------------- /scripts/mbv1_50flops.sh: -------------------------------------------------------------------------------- 1 | ###### 1. 
Search ###### 2 | python3 search.py \ 3 | --model_name mobilenetv1 \ 4 | --num_classes 1000 \ 5 | --checkpoint models/ckpt/imagenet_mobilenet_full_model.pth \ 6 | --gpu_ids 0 \ 7 | --batch_size 128 \ 8 | --dataset_path /data/imagenet \ 9 | --dataset_name imagenet_train_val_split \ 10 | --num_workers 4 \ 11 | --flops_target 0.5 \ 12 | --max_rate 0.7 \ 13 | --affine 0 \ 14 | --output_file search_results/mbv1_strategies.txt \ 15 | --compress_schedule_path compress_config/mbv1_imagenet.yaml 16 | 17 | ##### 2. Selection ####### 18 | python choose_strategy.py search_results/mbv1_strategies.txt 19 | 20 | ##### 3. Fine-tuning ####### 21 | python3 finetune.py \ 22 | --model_name mobilenetv1 \ 23 | --num_classes 1000 \ 24 | --checkpoint models/ckpt/imagenet_mobilenet_full_model.pth \ 25 | --gpu_ids [GPU_IDS] \ 26 | --batch_size 512 \ 27 | --dataset_path /data/imagenet \ 28 | --dataset_name imagenet \ 29 | --exp_name mbv1_50flops \ 30 | --search_result search_results/mbv1_strategies.txt \ 31 | --strategy_id 0 \ 32 | --epoch 120 \ 33 | --lr 1e-2 \ 34 | --weight_decay 1e-4 35 | -------------------------------------------------------------------------------- /scripts/res50_25flops.sh: -------------------------------------------------------------------------------- 1 | ###### 1. Search ###### 2 | python3 search.py \ 3 | --model_name resnet50 \ 4 | --num_classes 1000 \ 5 | --checkpoint models/ckpt/imagenet_resnet50_full_model.pth \ 6 | --gpu_ids 0 \ 7 | --batch_size 128 \ 8 | --dataset_path /data/imagenet \ 9 | --dataset_name imagenet_train_val_split \ 10 | --num_workers 4 \ 11 | --flops_target 0.25 \ 12 | --max_rate 0.8 \ 13 | --affine 0 \ 14 | --output_file search_results/res50_25flops_strategies.txt \ 15 | --compress_schedule_path compress_config/res50_imagenet.yaml 16 | 17 | ##### 2. Selection ####### 18 | python choose_strategy.py search_results/res50_25flops_strategies.txt 19 | 20 | ##### 3. Fine-tuning ####### 21 | python3 finetune.py \ 22 | --model_name resnet50 \ 23 | --num_classes 1000 \ 24 | --checkpoint models/ckpt/imagenet_resnet50_full_model.pth \ 25 | --gpu_ids [GPU_IDS] \ 26 | --batch_size 128 \ 27 | --dataset_path /data/imagenet \ 28 | --dataset_name imagenet \ 29 | --exp_name resnet50_25flops \ 30 | --search_result search_results/res50_25flops_strategies.txt \ 31 | --strategy_id 0 \ 32 | --epoch 120 \ 33 | --lr 1e-2 \ 34 | --weight_decay 1e-4 35 | -------------------------------------------------------------------------------- /scripts/res50_50flops.sh: -------------------------------------------------------------------------------- 1 | ###### 1. Search ###### 2 | python3 search.py \ 3 | --model_name resnet50 \ 4 | --num_classes 1000 \ 5 | --checkpoint models/ckpt/imagenet_resnet50_full_model.pth \ 6 | --gpu_ids 0 \ 7 | --batch_size 128 \ 8 | --dataset_path /data/imagenet \ 9 | --dataset_name imagenet_train_val_split \ 10 | --num_workers 4 \ 11 | --flops_target 0.5 \ 12 | --max_rate 0.7 \ 13 | --affine 0 \ 14 | --output_file search_results/res50_50flops_strategies.txt \ 15 | --compress_schedule_path compress_config/res50_imagenet.yaml 16 | 17 | ##### 2. Selection ####### 18 | python choose_strategy.py search_results/res50_50flops_strategies.txt 19 | 20 | ##### 3. 
Fine-tuning ####### 21 | python3 finetune.py \ 22 | --model_name resnet50 \ 23 | --num_classes 1000 \ 24 | --checkpoint models/ckpt/imagenet_resnet50_full_model.pth \ 25 | --gpu_ids [GPU_IDS] \ 26 | --batch_size 128 \ 27 | --dataset_path /data/imagenet \ 28 | --dataset_name imagenet \ 29 | --exp_name resnet50_50flops \ 30 | --search_result search_results/res50_50flops_strategies.txt \ 31 | --strategy_id 0 \ 32 | --epoch 120 \ 33 | --lr 1e-2 \ 34 | --weight_decay 1e-4 35 | -------------------------------------------------------------------------------- /scripts/res50_75flops.sh: -------------------------------------------------------------------------------- 1 | ###### 1. Search ###### 2 | python3 search.py \ 3 | --model_name resnet50 \ 4 | --num_classes 1000 \ 5 | --checkpoint models/ckpt/imagenet_resnet50_full_model.pth \ 6 | --gpu_ids 0 \ 7 | --batch_size 128 \ 8 | --dataset_path /data/imagenet \ 9 | --dataset_name imagenet_train_val_split \ 10 | --num_workers 4 \ 11 | --flops_target 0.75 \ 12 | --max_rate 0.4 \ 13 | --affine 0 \ 14 | --output_file search_results/res50_75flops_strategies.txt \ 15 | --compress_schedule_path compress_config/res50_imagenet.yaml 16 | 17 | ##### 2. Selection ####### 18 | python choose_strategy.py search_results/res50_75flops_strategies.txt 19 | 20 | ##### 3. Fine-tuning ####### 21 | python3 finetune.py \ 22 | --model_name resnet50 \ 23 | --num_classes 1000 \ 24 | --checkpoint models/ckpt/imagenet_resnet50_full_model.pth \ 25 | --gpu_ids [GPU_IDS] \ 26 | --batch_size 128 \ 27 | --dataset_path /data/imagenet \ 28 | --dataset_name imagenet \ 29 | --exp_name resnet50_75flops \ 30 | --search_result search_results/res50_75flops_strategies.txt \ 31 | --strategy_id 0 \ 32 | --epoch 120 \ 33 | --lr 1e-2 \ 34 | --weight_decay 1e-4 35 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import os 8 | import torch 9 | from options.base_options import BaseOptions 10 | from models.wrapper import ModelWrapper 11 | from report import model_summary 12 | from data import custom_get_dataloaders 13 | import torch.nn as nn 14 | from tqdm import tqdm 15 | import random 16 | import numpy as np 17 | import distiller 18 | from thinning import thinning 19 | 20 | 21 | def random_compression_scheduler(compression_scheduler, channel_configuration): 22 | for i, item in enumerate(channel_configuration): 23 | compression_scheduler.policies[1][i].pruner.desired_sparsity = item 24 | return compression_scheduler 25 | 26 | 27 | def get_pruning_strategy(opt, num_layer): 28 | channel_config = np.random.rand(num_layer) 29 | channel_config = channel_config * opt.max_rate 30 | channel_config = channel_config + opt.min_rate 31 | return channel_config 32 | 33 | 34 | def main(opt): 35 | # basic settings 36 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_ids)[1:-1] 37 | 38 | if torch.cuda.is_available(): 39 | device = "cuda" 40 | torch.backends.cudnn.benchmark = True 41 | else: 42 | device = "cpu" 43 | ##################### Get Dataloader #################### 44 | dataloader_train, dataloader_val = custom_get_dataloaders(opt) 45 | # dummy_input is sample input of dataloaders 46 | if 
hasattr(dataloader_val, "dataset"): 47 | dummy_input = dataloader_val.dataset.__getitem__(0) 48 | dummy_input = dummy_input[0] 49 | dummy_input = dummy_input.unsqueeze(0) 50 | else: 51 | # for imagenet dali loader 52 | dummy_input = torch.rand(1, 3, 224, 224) 53 | 54 | ##################### Create Baseline Model #################### 55 | net = ModelWrapper(opt) 56 | net.load_checkpoint(opt.checkpoint) 57 | flops_before, params_before = model_summary(net.get_compress_part(), dummy_input) 58 | 59 | ##################### Pruning Strategy Generation ############### 60 | compression_scheduler = distiller.file_config( 61 | net.get_compress_part(), net.optimizer, opt.compress_schedule_path 62 | ) 63 | num_layer = len(compression_scheduler.policies[1]) 64 | 65 | channel_config = get_pruning_strategy(opt, num_layer) # pruning strategy 66 | 67 | compression_scheduler = random_compression_scheduler( 68 | compression_scheduler, channel_config 69 | ) 70 | 71 | ###### Adaptive-BN-based Candidate Evaluation of Pruning Strategy ### 72 | try: 73 | thinning(net, compression_scheduler, input_tensor=dummy_input) 74 | except Exception as e: 75 | print('[WARNING] This pruning strategy is invalid for distiller thinning module, pass it.') 76 | print(e) 77 | return 78 | 79 | flops_after, params_after = model_summary(net.get_compress_part(), dummy_input) 80 | ratio = flops_after / flops_before 81 | print("FLOPs ratio:", ratio) 82 | if ratio < opt.flops_target - 0.005 or ratio > opt.flops_target + 0.005: 83 | # illegal pruning strategy 84 | return 85 | net = net.to(device) 86 | net.parallel(opt.gpu_ids) 87 | net.get_compress_part().train() 88 | with torch.no_grad(): 89 | for index, sample in enumerate(tqdm(dataloader_train, leave=False)): 90 | _ = net.get_loss(sample) 91 | if index > 100: 92 | break 93 | 94 | strategy_score = net.get_eval_scores(dataloader_val)["accuracy"] 95 | 96 | #################### Save Pruning Strategy and Score ######### 97 | log_file = open(opt.output_file, "a+") 98 | log_file.write("{} {} ".format(strategy_score, ratio)) 99 | 100 | for item in channel_config: 101 | log_file.write("{} ".format(str(item))) 102 | log_file.write("\n") 103 | log_file.close() 104 | print("Eval Score:{}".format(strategy_score)) 105 | 106 | 107 | if __name__ == "__main__": 108 | # get options 109 | opt = BaseOptions().parse() 110 | while True: 111 | main(opt) 112 | -------------------------------------------------------------------------------- /search_results/best_strategy_mbv1_50flops.txt: -------------------------------------------------------------------------------- 1 | 0.1411 0.50805 0.34522834 0.25616112 0.34354134 0.3639849 0.35014121 0.34809587 0.1421153 0.28915166 0.30550963 0.39298936 0.32038612 0.15806762 0.13184949 0.39613 -------------------------------------------------------------------------------- /search_results/best_strategy_res50_25flops.txt: -------------------------------------------------------------------------------- 1 | 0.00526 0.25 0.1939672 0.21797366 0.46062495 0.65938095 0.71535156 0.79856479 0.28947926 0.6437196 0.13621386 0.3310043 0.7447816 0.57806314 0.71184711 0.38868329 0.76312789 0.55029199 0.46928421 0.10625824 0.28823312 0.55573736 0.53709531 0.23434097 0.78005569 0.46100279 0.49465864 0.78342664 0.62700501 0.22615819 0.69553466 0.68402133 0.69224432 0.18955793 0.72702437 0.52134052 0.17007695 0.29444271 0.53916772 2 | -------------------------------------------------------------------------------- /search_results/best_strategy_res50_50flops.txt: 
-------------------------------------------------------------------------------- 1 | 0.12644 0.50 0.10272261 0.2827283 0.46222381 0.46675718 0.39910313 0.46953703 0.45295396 0.21956719 0.13951918 0.39655863 0.04569472 0.31119637 0.30863393 0.4472352 0.38888706 0.38405677 0.29581274 0.1234234 0.23327343 0.1541599 0.37519254 0.19114498 0.30318672 0.32446214 0.37455404 0.04177907 0.29925152 0.40318202 0.41735264 0.44579575 0.20660986 0.39657288 0.25469524 0.41795762 0.30585931 0.14133776 0.04950175 2 | -------------------------------------------------------------------------------- /search_results/best_strategy_res50_75flops.txt: -------------------------------------------------------------------------------- 1 | 0.60264 0.75 0.03805596 0.11375178 0.12906994 0.29345854 0.29473912 0.23196617 0.07760712 0.18177539 0.03299144 0.03357534 0.24132784 0.18612661 0.1085735 0.0887458 0.26656344 0.2919708 0.08160947 0.07458963 0.21765288 0.06074208 0.15334014 0.13102724 0.12475932 0.19935437 0.006374 0.11714018 0.16346813 0.17526527 0.27989531 0.10789936 0.07112449 0.19612177 0.20006662 0.24498201 0.20368413 0.04279182 0.03836601 -------------------------------------------------------------------------------- /thinning/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------ 2 | # Author: Bowen Wu 3 | # Email: wubw6@mail2.sysu.edu.cn 4 | # Affiliation: Sun Yat-sen University, Guangzhou 5 | # Date: 13 JULY 2020 6 | # ------------------------------------------------------------------ 7 | import torch 8 | import distiller 9 | 10 | 11 | def thinning(net, scheduler, input_tensor=None): 12 | scheduler.on_epoch_begin(1) 13 | scheduler.mask_all_weights() 14 | 15 | def create_graph(model): 16 | if input_tensor is not None: 17 | dummy_input = input_tensor 18 | else: 19 | dummy_input = torch.randn(16, 3, 32, 32) 20 | return distiller.SummaryGraph(model, dummy_input) 21 | 22 | sgraph = create_graph(net._net) 23 | from distiller.thinning import create_thinning_recipe_filters, apply_and_save_recipe 24 | 25 | thinning_recipe = create_thinning_recipe_filters( 26 | sgraph, net._net, scheduler.zeros_mask_dict 27 | ) 28 | apply_and_save_recipe( 29 | net._net, scheduler.zeros_mask_dict, thinning_recipe, net.optimizer 30 | ) 31 | net.optimizer.param_groups[0]['params'] = list(net._net.parameters()) 32 | return net 33 | --------------------------------------------------------------------------------
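The `thinning` helper above performs the structural half of candidate evaluation; the other half is the Adaptive-BN recalibration used by both `search.py` and `finetune.py`: the pruned candidate is run in train mode under `torch.no_grad()` for roughly 100 mini-batches so that its BatchNorm running statistics adapt to the new structure, and only then is it scored on the validation split. A minimal self-contained sketch of this step (the `evaluate` callback and the loader/variable names here are illustrative assumptions, not the repository's API):

```python
import torch

def adaptive_bn_score(model, dataloader_train, evaluate, num_batches=100):
    model.train()          # BN layers update running_mean/running_var in train mode
    with torch.no_grad():  # no gradients needed; only BN statistics change
        for index, (images, _) in enumerate(dataloader_train):
            model(images)
            if index > num_batches:
                break
    model.eval()            # freeze the recalibrated statistics
    return evaluate(model)  # e.g. top-1 accuracy on the validation split
```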