├── __init__.py ├── examples ├── __init__.py ├── classifier_compression │ ├── __init__.py │ ├── logging.conf │ └── inspect_ckpt.py ├── quantization │ ├── alexnet_bn_base_fp32.yaml │ ├── preact_resnet18_imagenet_base_fp32.yaml │ ├── alexnet_bn_dorefa.yaml │ ├── preact_resnet18_imagenet_dorefa.yaml │ ├── preact_resnet20_cifar_dorefa.yaml │ ├── preact_resnet20_cifar_base_fp32.yaml │ └── preact_resnet20_cifar_pact.yaml ├── sensitivity-pruning │ ├── alexnet.schedule_sensitivity_direct.yaml │ └── alexnet.schedule_sensitivity.yaml ├── network_slimming │ └── resnet20_slimming.yaml ├── pruning_filters_for_efficient_convnets │ └── vgg19.schedule_filter_rank.yaml ├── gss │ └── gss_channels-removal_training.yaml ├── baidu-rnn-pruning │ └── word_lang_model.schedule_baidu_rnn.yaml └── hybrid │ └── alexnet.schedule_sensitivity_2D-reg.yaml ├── imgs ├── banner1.png ├── ch_compute_stats.png ├── ch_sparsity_stats.png ├── simplenet_training.png ├── resnet18-sensitivity.png ├── resnet18_summary_png.png ├── ch_sparsity_stats_barchart.png ├── resnet18_summary_compute.png ├── resnet18_summary_modules.png ├── resnet18_summary_sparsity.png └── wiki │ └── word_lang_model_performance.png ├── pylintrc ├── .gitignore ├── OCS-CNN ├── example_train.sh ├── config │ ├── alexnet_bn_base_fp32.yaml │ ├── preact_resnet18_imagenet_base_fp32.yaml │ ├── mobilenet_cifar_base_fp32.yaml │ ├── resnet20_cifar_base_fp32.yaml │ ├── alexnet_bn_dorefa.yaml │ ├── preact_resnet18_imagenet_dorefa.yaml │ ├── preact_resnet20_cifar_dorefa.yaml │ ├── preact_resnet20_cifar_base_fp32.yaml │ └── preact_resnet20_cifar_pact.yaml ├── example.sh ├── logging.conf ├── scripts │ ├── clip_script.py │ ├── ocs_script.py │ ├── parse_ocs.py │ └── parse_clip.py └── inspect_ckpt.py ├── requirements.txt ├── distiller ├── data_loggers │ ├── __init__.py │ ├── tbbackend.py │ └── collector.py ├── regularization │ ├── __init__.py │ ├── regularizer.py │ └── l1_regularizer.py ├── quantization │ ├── __init__.py │ ├── q_utils.py │ └── ocs_impl.py 
├── pruning │ ├── __init__.py │ ├── pruner.py │ ├── structure_pruner.py │ ├── level_pruner.py │ ├── magnitude_pruner.py │ ├── sensitivity_pruner.py │ ├── baidu_rnn_pruner.py │ ├── automated_gradual_pruner.py │ └── ranked_structures_pruner.py ├── learning_rate.py ├── __init__.py ├── directives.py └── thresholding.py ├── models ├── imagenet │ ├── __init__.py │ ├── resnet_earlyexit.py │ ├── mobilenet.py │ └── alexnet_batchnorm.py ├── cifar10 │ ├── __init__.py │ ├── simplenet_cifar.py │ ├── mobilenet_cifar.py │ ├── resnet_cifar_earlyexit.py │ ├── vgg_cifar.py │ └── resnet_cifar.py └── __init__.py ├── tests ├── full_flow_tests │ └── preact_resnet20_cifar_pact_test.yaml ├── test_thresholding.py ├── test_infra.py ├── common.py ├── test_basic.py ├── test_model_summary.py ├── test_learning_rate.py ├── test_loss.py ├── test_ranking.py └── test_summarygraph.py ├── apputils ├── __init__.py ├── dataset_summaries.py ├── execution_env.py └── checkpoint.py └── jupyter ├── compare_executions.ipynb ├── interactive_lr_scheduler.ipynb ├── parameter_histograms.ipynb └── performance.ipynb /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/classifier_compression/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/banner1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/banner1.png -------------------------------------------------------------------------------- /pylintrc: 
-------------------------------------------------------------------------------- 1 | [FORMAT] 2 | 3 | # Maximum number of characters on a single line. 4 | max-line-length=120 5 | -------------------------------------------------------------------------------- /imgs/ch_compute_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/ch_compute_stats.png -------------------------------------------------------------------------------- /imgs/ch_sparsity_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/ch_sparsity_stats.png -------------------------------------------------------------------------------- /imgs/simplenet_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/simplenet_training.png -------------------------------------------------------------------------------- /imgs/resnet18-sensitivity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/resnet18-sensitivity.png -------------------------------------------------------------------------------- /imgs/resnet18_summary_png.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/resnet18_summary_png.png -------------------------------------------------------------------------------- /imgs/ch_sparsity_stats_barchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/ch_sparsity_stats_barchart.png -------------------------------------------------------------------------------- 
/imgs/resnet18_summary_compute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/resnet18_summary_compute.png -------------------------------------------------------------------------------- /imgs/resnet18_summary_modules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/resnet18_summary_modules.png -------------------------------------------------------------------------------- /imgs/resnet18_summary_sparsity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/resnet18_summary_sparsity.png -------------------------------------------------------------------------------- /imgs/wiki/word_lang_model_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-quant-ocs/HEAD/imgs/wiki/word_lang_model_performance.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.pyc 3 | __pycache__/ 4 | .pytest_cache 5 | .cache 6 | site/ 7 | env/ 8 | .env/ 9 | .idea/ 10 | logs/ 11 | .DS_Store 12 | -------------------------------------------------------------------------------- /OCS-CNN/example_train.sh: -------------------------------------------------------------------------------- 1 | python compress_classifier.py \ 2 | %DATA_DIR% \ 3 | -a resnet20_cifar \ 4 | --lr 0.1 -p 50 -b 128 -j 1 --epochs 200 \ 5 | --compress=./config/resnet20_cifar_base_fp32.yaml \ 6 | --out-dir="logs_resnet20_cifar_train/" \ 7 | --wd=0.0002 --vs=0 8 | -------------------------------------------------------------------------------- 
/OCS-CNN/config/alexnet_bn_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | lr_schedulers: 2 | training_lr: 3 | class: MultiStepLR 4 | milestones: [60, 75] 5 | gamma: 0.2 6 | 7 | policies: 8 | - lr_scheduler: 9 | instance_name: training_lr 10 | starting_epoch: 0 11 | ending_epoch: 200 12 | frequency: 1 13 | -------------------------------------------------------------------------------- /examples/quantization/alexnet_bn_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | lr_schedulers: 2 | training_lr: 3 | class: MultiStepLR 4 | milestones: [60, 75] 5 | gamma: 0.2 6 | 7 | policies: 8 | - lr_scheduler: 9 | instance_name: training_lr 10 | starting_epoch: 0 11 | ending_epoch: 200 12 | frequency: 1 13 | -------------------------------------------------------------------------------- /OCS-CNN/config/preact_resnet18_imagenet_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | lr_schedulers: 2 | training_lr: 3 | class: MultiStepLR 4 | milestones: [30, 60, 90, 100] 5 | gamma: 0.1 6 | 7 | policies: 8 | - lr_scheduler: 9 | instance_name: training_lr 10 | starting_epoch: 0 11 | ending_epoch: 200 12 | frequency: 1 13 | -------------------------------------------------------------------------------- /examples/quantization/preact_resnet18_imagenet_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | lr_schedulers: 2 | training_lr: 3 | class: MultiStepLR 4 | milestones: [30, 60, 90, 100] 5 | gamma: 0.1 6 | 7 | policies: 8 | - lr_scheduler: 9 | instance_name: training_lr 10 | starting_epoch: 0 11 | ending_epoch: 200 12 | frequency: 1 13 | -------------------------------------------------------------------------------- /OCS-CNN/config/mobilenet_cifar_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | lr_schedulers: 2 | training_lr: 3 | 
class: MultiStepMultiGammaLR 4 | milestones: [80, 120, 160] 5 | gammas: [0.1, 0.1, 0.2] 6 | 7 | policies: 8 | - lr_scheduler: 9 | instance_name: training_lr 10 | starting_epoch: 0 11 | ending_epoch: 200 12 | frequency: 1 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==0.4.0 2 | numpy==1.14.3 3 | torchvision==0.2.1 4 | scipy==1.1.0 5 | gitpython 6 | torchnet==0.0.4 7 | tensorflow==1.7.0 8 | pydot==1.2.4 9 | tabulate==0.8.2 10 | pandas==0.22.0 11 | jupyter==1.0.0 12 | matplotlib==2.2.2 13 | qgrid==1.0.2 14 | graphviz==0.8.2 15 | ipywidgets==7.1.2 16 | bqplot==0.10.5 17 | pyyaml 18 | pytest==3.5.1 19 | -------------------------------------------------------------------------------- /OCS-CNN/example.sh: -------------------------------------------------------------------------------- 1 | python compress_classifier.py \ 2 | %DATA_DIR% \ 3 | -a resnet50 \ 4 | -b 128 -j 1 --vs 0 \ 5 | --evaluate --pretrained \ 6 | --act-bits 8 --weight-bits 6 \ 7 | --quantize-method ocs \ 8 | --weight-expand-ratio 0.02 \ 9 | --weight-clip-threshold 1.0 \ 10 | --act-clip-threshold 1.0 \ 11 | --profile-batches 4 12 | -------------------------------------------------------------------------------- /OCS-CNN/config/resnet20_cifar_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | 2 | # python compress_classifier.py ./data.cifar10/ -a resnet20_cifar --lr 0.1 -p 50 -b 128 -j 1 3 | # --epochs 200 --compress=./config/resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0 4 | 5 | lr_schedulers: 6 | training_lr: 7 | class: MultiStepMultiGammaLR 8 | milestones: [80, 120, 160] 9 | gammas: [0.1, 0.1, 0.2] 10 | 11 | policies: 12 | - lr_scheduler: 13 | instance_name: training_lr 14 | starting_epoch: 0 15 | ending_epoch: 200 16 | frequency: 1 17 | 
-------------------------------------------------------------------------------- /distiller/data_loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .collector import ActivationSparsityCollector 18 | from .logger import PythonLogger, TensorBoardLogger, CsvLogger 19 | 20 | del logger 21 | del collector 22 | -------------------------------------------------------------------------------- /distiller/regularization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | from .l1_regularizer import L1Regularizer 18 | from .group_regularizer import GroupLassoRegularizer, GroupVarianceRegularizer 19 | 20 | del l1_regularizer 21 | del group_regularizer 22 | -------------------------------------------------------------------------------- /models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | """This package contains ImageNet image classification models not found in torchvision""" 18 | 19 | from .mobilenet import * 20 | from .preresnet_imagenet import * 21 | from .alexnet_batchnorm import * 22 | from .resnet_earlyexit import * 23 | -------------------------------------------------------------------------------- /OCS-CNN/config/alexnet_bn_dorefa.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | dorefa_quantizer: 3 | class: DorefaQuantizer 4 | bits_activations: 8 5 | bits_weights: 3 6 | bits_overrides: 7 | # Don't quantize first and last layer 8 | features.0: 9 | wts: null 10 | acts: null 11 | features.1: 12 | wts: null 13 | acts: null 14 | classifier.5: 15 | wts: null 16 | acts: null 17 | classifier.6: 18 | wts: null 19 | acts: null 20 | 21 | lr_schedulers: 22 | training_lr: 23 | class: MultiStepLR 24 | milestones: [60, 75] 25 | gamma: 0.2 26 | 27 | policies: 28 | - quantizer: 29 | instance_name: dorefa_quantizer 30 | starting_epoch: 0 31 | ending_epoch: 200 32 | frequency: 1 33 | 34 | - lr_scheduler: 35 | instance_name: training_lr 36 | starting_epoch: 0 37 | ending_epoch: 200 38 | frequency: 1 39 | -------------------------------------------------------------------------------- /OCS-CNN/config/preact_resnet18_imagenet_dorefa.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | dorefa_quantizer: 3 | class: DorefaQuantizer 4 | bits_activations: 8 5 | bits_weights: 3 6 | bits_overrides: 7 | # Don't quantize first and last layer 8 | conv1: 9 | wts: null 10 | acts: null 11 | relu1: 12 | wts: null 13 | acts: null 14 | final_relu: 15 | wts: null 16 | acts: null 17 | fc: 18 | wts: null 19 | acts: null 20 | 21 | lr_schedulers: 22 | training_lr: 23 | class: MultiStepLR 24 | milestones: [30, 60, 90, 100] 25 | gamma: 0.1 26 | 27 | policies: 28 | - quantizer: 29 | instance_name: dorefa_quantizer 30 | starting_epoch: 0 31 | ending_epoch: 
200 32 | frequency: 1 33 | 34 | - lr_scheduler: 35 | instance_name: training_lr 36 | starting_epoch: 0 37 | ending_epoch: 200 38 | frequency: 1 39 | -------------------------------------------------------------------------------- /examples/quantization/alexnet_bn_dorefa.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | dorefa_quantizer: 3 | class: DorefaQuantizer 4 | bits_activations: 8 5 | bits_weights: 3 6 | bits_overrides: 7 | # Don't quantize first and last layer 8 | features.0: 9 | wts: null 10 | acts: null 11 | features.1: 12 | wts: null 13 | acts: null 14 | classifier.5: 15 | wts: null 16 | acts: null 17 | classifier.6: 18 | wts: null 19 | acts: null 20 | 21 | lr_schedulers: 22 | training_lr: 23 | class: MultiStepLR 24 | milestones: [60, 75] 25 | gamma: 0.2 26 | 27 | policies: 28 | - quantizer: 29 | instance_name: dorefa_quantizer 30 | starting_epoch: 0 31 | ending_epoch: 200 32 | frequency: 1 33 | 34 | - lr_scheduler: 35 | instance_name: training_lr 36 | starting_epoch: 0 37 | ending_epoch: 200 38 | frequency: 1 39 | -------------------------------------------------------------------------------- /examples/quantization/preact_resnet18_imagenet_dorefa.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | dorefa_quantizer: 3 | class: DorefaQuantizer 4 | bits_activations: 8 5 | bits_weights: 3 6 | bits_overrides: 7 | # Don't quantize first and last layer 8 | conv1: 9 | wts: null 10 | acts: null 11 | relu1: 12 | wts: null 13 | acts: null 14 | final_relu: 15 | wts: null 16 | acts: null 17 | fc: 18 | wts: null 19 | acts: null 20 | 21 | lr_schedulers: 22 | training_lr: 23 | class: MultiStepLR 24 | milestones: [30, 60, 90, 100] 25 | gamma: 0.1 26 | 27 | policies: 28 | - quantizer: 29 | instance_name: dorefa_quantizer 30 | starting_epoch: 0 31 | ending_epoch: 200 32 | frequency: 1 33 | 34 | - lr_scheduler: 35 | instance_name: training_lr 36 | starting_epoch: 
0 37 | ending_epoch: 200 38 | frequency: 1 39 | -------------------------------------------------------------------------------- /models/cifar10/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains CIFAR image classification models for pytorch""" 18 | 19 | from .simplenet_cifar import * 20 | from .resnet_cifar import * 21 | from .preresnet_cifar import * 22 | from .vgg_cifar import * 23 | from .resnet_cifar_earlyexit import * 24 | from .mobilenet_cifar import * 25 | -------------------------------------------------------------------------------- /OCS-CNN/config/preact_resnet20_cifar_dorefa.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | dorefa_quantizer: 3 | class: DorefaQuantizer 4 | bits_activations: 8 5 | bits_weights: 3 6 | bits_overrides: 7 | # Don't quantize first and last layer 8 | conv1: 9 | wts: null 10 | acts: null 11 | layer1.0.pre_relu: 12 | wts: null 13 | acts: null 14 | final_relu: 15 | wts: null 16 | acts: null 17 | fc: 18 | wts: null 19 | acts: null 20 | 21 | lr_schedulers: 22 | training_lr: 23 | class: MultiStepMultiGammaLR 24 | milestones: [80, 120, 160] 25 | gammas: [0.1, 0.1, 0.2] 26 | 27 | policies: 28 | - quantizer: 29 | instance_name: dorefa_quantizer 30 | starting_epoch: 
0 31 | ending_epoch: 200 32 | frequency: 1 33 | 34 | - lr_scheduler: 35 | instance_name: training_lr 36 | starting_epoch: 0 37 | ending_epoch: 161 38 | frequency: 1 39 | -------------------------------------------------------------------------------- /examples/quantization/preact_resnet20_cifar_dorefa.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | dorefa_quantizer: 3 | class: DorefaQuantizer 4 | bits_activations: 8 5 | bits_weights: 3 6 | bits_overrides: 7 | # Don't quantize first and last layer 8 | conv1: 9 | wts: null 10 | acts: null 11 | layer1.0.pre_relu: 12 | wts: null 13 | acts: null 14 | final_relu: 15 | wts: null 16 | acts: null 17 | fc: 18 | wts: null 19 | acts: null 20 | 21 | lr_schedulers: 22 | training_lr: 23 | class: MultiStepMultiGammaLR 24 | milestones: [80, 120, 160] 25 | gammas: [0.1, 0.1, 0.2] 26 | 27 | policies: 28 | - quantizer: 29 | instance_name: dorefa_quantizer 30 | starting_epoch: 0 31 | ending_epoch: 200 32 | frequency: 1 33 | 34 | - lr_scheduler: 35 | instance_name: training_lr 36 | starting_epoch: 0 37 | ending_epoch: 161 38 | frequency: 1 39 | -------------------------------------------------------------------------------- /tests/full_flow_tests/preact_resnet20_cifar_pact_test.yaml: -------------------------------------------------------------------------------- 1 | quantizers: 2 | pact_quantizer: 3 | class: PACTQuantizer 4 | act_clip_init_val: 8.0 5 | bits_activations: 4 6 | bits_weights: 3 7 | bits_overrides: 8 | # Don't quantize first and last layers 9 | conv1: 10 | wts: null 11 | acts: null 12 | layer1.0.pre_relu: 13 | wts: null 14 | acts: null 15 | final_relu: 16 | wts: null 17 | acts: null 18 | fc: 19 | wts: null 20 | acts: null 21 | 22 | lr_schedulers: 23 | training_lr: 24 | class: MultiStepLR 25 | milestones: [60, 120] 26 | gammas: 0.1 27 | 28 | policies: 29 | - quantizer: 30 | instance_name: pact_quantizer 31 | starting_epoch: 0 32 | ending_epoch: 200 33 | frequency: 1 
34 | 35 | - lr_scheduler: 36 | instance_name: training_lr 37 | starting_epoch: 0 38 | ending_epoch: 121 39 | frequency: 1 40 | -------------------------------------------------------------------------------- /OCS-CNN/logging.conf: -------------------------------------------------------------------------------- 1 | [formatters] 2 | keys: simple, time_simple 3 | 4 | [handlers] 5 | keys: console, file 6 | 7 | [loggers] 8 | keys: root, app_cfg 9 | 10 | [formatter_simple] 11 | format: %(message)s 12 | 13 | [formatter_time_simple] 14 | format: %(asctime)s - %(message)s 15 | 16 | [handler_console] 17 | class: StreamHandler 18 | propagate: 0 19 | args: [] 20 | formatter: simple 21 | 22 | [handler_file] 23 | class: FileHandler 24 | mode: 'w' 25 | args=('%(logfilename)s', 'w') 26 | formatter: time_simple 27 | 28 | [logger_root] 29 | level: INFO 30 | propagate: 1 31 | handlers: console, file 32 | 33 | [logger_app_cfg] 34 | # Use this logger to log the application configuration and execution environment 35 | level: DEBUG 36 | qualname: app_cfg 37 | propagate: 0 38 | handlers: file 39 | 40 | # Example of adding a module-specific logger 41 | # Do not forget to add apputils.model_summaries to the list of keys in section [loggers] 42 | # [logger_apputils.model_summaries] 43 | # level: DEBUG 44 | # qualname: apputils.model_summaries 45 | # propagate: 0 46 | # handlers: console 47 | -------------------------------------------------------------------------------- /apputils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains Python code and classes that are meant to make your life easier, 18 | when working with distiller. 19 | 20 | """ 21 | from .data_loaders import * 22 | from .model_summaries import * 23 | from .checkpoint import * 24 | from .execution_env import * 25 | from .dataset_summaries import * 26 | 27 | del data_loaders 28 | del model_summaries 29 | del checkpoint 30 | del execution_env 31 | del dataset_summaries 32 | -------------------------------------------------------------------------------- /examples/classifier_compression/logging.conf: -------------------------------------------------------------------------------- 1 | [formatters] 2 | keys: simple, time_simple 3 | 4 | [handlers] 5 | keys: console, file 6 | 7 | [loggers] 8 | keys: root, app_cfg 9 | 10 | [formatter_simple] 11 | format: %(message)s 12 | 13 | [formatter_time_simple] 14 | format: %(asctime)s - %(message)s 15 | 16 | [handler_console] 17 | class: StreamHandler 18 | propagate: 0 19 | args: [] 20 | formatter: simple 21 | 22 | [handler_file] 23 | class: FileHandler 24 | mode: 'w' 25 | args=('%(logfilename)s', 'w') 26 | formatter: time_simple 27 | 28 | [logger_root] 29 | level: INFO 30 | propagate: 1 31 | handlers: console, file 32 | 33 | [logger_app_cfg] 34 | # Use this logger to log the application configuration and execution environment 35 | level: DEBUG 36 | qualname: app_cfg 37 | propagate: 0 38 | handlers: file 39 | 40 | # Example of adding a module-specific logger 41 | # Do not forget to add apputils.model_summaries to the 
list of keys in section [loggers] 42 | # [logger_apputils.model_summaries] 43 | # level: DEBUG 44 | # qualname: apputils.model_summaries 45 | # propagate: 0 46 | # handlers: console 47 | -------------------------------------------------------------------------------- /distiller/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .quantizer import Quantizer 18 | from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, SymmetricLinearQuantizer 19 | from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer 20 | from .ocs import OCSQuantizer, ocs_set_profile_mode 21 | 22 | del quantizer 23 | del range_linear 24 | del clipped_linear 25 | del ocs 26 | del ocs_impl 27 | del clip 28 | -------------------------------------------------------------------------------- /examples/sensitivity-pruning/alexnet.schedule_sensitivity_direct.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # This schedule is an example of "Direct Pruning" for Alexnet/ImageNet, as 3 | # described in chapter 3 of Song Han's PhD dissertation: "EFFICIENT METHODS AND 4 | # HARDWARE FOR DEEP LEARNING" 5 | # 6 | # The pruning policy describes a single pruning phase. 
7 | # The sensitivity values are the sensitivities from examples/sensitivity-pruning/alexnet.schedule_sensitivity.yaml 8 | # multiplied by 1.9. The multiplication raises the threshold to its final value. 9 | # The overall model sparsity is 90.36% 10 | # 11 | version: 1 12 | pruners: 13 | pruner1: 14 | class: 'SensitivityPruner' 15 | sensitivities: 16 | 'features.module.0.weight': 0.475 17 | 'features.module.3.weight': 0.665 18 | 'features.module.6.weight': 0.76 19 | 'features.module.8.weight': 0.855 20 | 'features.module.10.weight': 1.045 21 | 'classifier.1.weight': 1.663 22 | 'classifier.4.weight': 1.663 23 | 'classifier.6.weight': 1.188 24 | 25 | policies: 26 | - pruner: 27 | instance_name : 'pruner1' 28 | starting_epoch: 0 29 | ending_epoch: 1 30 | frequency: 1 31 | -------------------------------------------------------------------------------- /OCS-CNN/config/preact_resnet20_cifar_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | 2 | # time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1 3 | # --epochs 200 --compress=../quantization/preact_resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0 4 | 5 | 6 | #2018-07-18 12:25:56,477 - --- validate (epoch=199)----------- 7 | #2018-07-18 12:25:56,477 - 10000 samples (128 per mini-batch) 8 | #2018-07-18 12:25:57,810 - Epoch: [199][ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625 9 | #2018-07-18 12:25:58,402 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307 10 | # 11 | #2018-07-18 12:25:58,404 - ==> Best validation Top1: 92.560 Epoch: 127 12 | #2018-07-18 12:25:58,404 - Saving checkpoint to: logs/checkpoint.pth.tar 13 | #2018-07-18 12:25:58,418 - --- test --------------------- 14 | #2018-07-18 12:25:58,418 - 10000 samples (128 per mini-batch) 15 | #2018-07-18 12:25:59,664 - Test: [ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625 16 | #2018-07-18 12:26:00,248 - ==> Top1: 92.270 Top5: 99.800 Loss: 
0.307 17 | 18 | 19 | lr_schedulers: 20 | training_lr: 21 | class: MultiStepMultiGammaLR 22 | milestones: [80, 120, 160] 23 | gammas: [0.1, 0.1, 0.2] 24 | 25 | policies: 26 | - lr_scheduler: 27 | instance_name: training_lr 28 | starting_epoch: 0 29 | ending_epoch: 200 30 | frequency: 1 31 | -------------------------------------------------------------------------------- /examples/network_slimming/resnet20_slimming.yaml: -------------------------------------------------------------------------------- 1 | # Schedule based on: 2 | # Learning Efficient Convolutional Networks through Network Slimming 3 | # Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, Changshui Zhang 4 | # arXiv:1708.06519v1 5 | # Aug 2017 6 | # 7 | # time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../network_slimming/resnet20_slimming.yaml -j=1 --deterministic 8 | 9 | lr_schedulers: 10 | training_lr: 11 | class: StepLR 12 | step_size: 45 13 | gamma: 0.10 14 | 15 | regularizers: 16 | Channels_groups_regularizer: 17 | class: L1Regularizer 18 | reg_regims: 19 | module.layer1.0.bn2.weight: 0.007 20 | module.layer1.1.bn2.weight: 0.007 21 | module.layer1.2.bn2.weight: 0.007 22 | module.layer2.0.bn2.weight: 0.007 23 | module.layer2.1.bn2.weight: 0.007 24 | module.layer2.2.bn2.weight: 0.007 25 | module.layer3.0.bn2.weight: 0.007 26 | module.layer3.1.bn2.weight: 0.007 27 | module.layer3.2.bn2.weight: 0.007 28 | 29 | policies: 30 | - lr_scheduler: 31 | instance_name: training_lr 32 | starting_epoch: 45 33 | ending_epoch: 300 34 | frequency: 1 35 | 36 | # - regularizer: 37 | # instance_name: Channels_groups_regularizer 38 | # starting_epoch: 0 39 | # ending_epoch: 180 40 | # frequency: 1 41 | -------------------------------------------------------------------------------- /distiller/pruning/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel 
Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ 18 | :mod:`distiller.pruning` is a package implementing various pruning algorithms. 19 | """ 20 | 21 | from .magnitude_pruner import MagnitudeParameterPruner 22 | from .automated_gradual_pruner import AutomatedGradualPruner, StructuredAutomatedGradualPruner 23 | from .level_pruner import SparsityLevelParameterPruner 24 | from .sensitivity_pruner import SensitivityPruner 25 | from .structure_pruner import StructureParameterPruner 26 | from .ranked_structures_pruner import L1RankedStructureParameterPruner 27 | from .baidu_rnn_pruner import BaiduRNNPruner 28 | 29 | del magnitude_pruner 30 | del automated_gradual_pruner 31 | del level_pruner 32 | del sensitivity_pruner 33 | del structure_pruner 34 | del ranked_structures_pruner 35 | -------------------------------------------------------------------------------- /examples/quantization/preact_resnet20_cifar_base_fp32.yaml: -------------------------------------------------------------------------------- 1 | 2 | # time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1 3 | # --epochs 200 --compress=../quantization/preact_resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0 4 | 5 | 6 | #2018-07-18 12:25:56,477 - --- validate (epoch=199)----------- 7 | #2018-07-18 12:25:56,477 - 10000 samples (128 per mini-batch) 8 | #2018-07-18 
12:25:57,810 - Epoch: [199][ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625 9 | #2018-07-18 12:25:58,402 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307 10 | # 11 | #2018-07-18 12:25:58,404 - ==> Best validation Top1: 92.560 Epoch: 127 12 | #2018-07-18 12:25:58,404 - Saving checkpoint to: logs/checkpoint.pth.tar 13 | #2018-07-18 12:25:58,418 - --- test --------------------- 14 | #2018-07-18 12:25:58,418 - 10000 samples (128 per mini-batch) 15 | #2018-07-18 12:25:59,664 - Test: [ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625 16 | #2018-07-18 12:26:00,248 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307 17 | 18 | 19 | lr_schedulers: 20 | training_lr: 21 | class: MultiStepMultiGammaLR 22 | milestones: [80, 120, 160] 23 | gammas: [0.1, 0.1, 0.2] 24 | 25 | policies: 26 | - lr_scheduler: 27 | instance_name: training_lr 28 | starting_epoch: 0 29 | ending_epoch: 200 30 | frequency: 1 31 | -------------------------------------------------------------------------------- /OCS-CNN/scripts/clip_script.py: -------------------------------------------------------------------------------- 1 | import sys, os, re 2 | import subprocess as sp 3 | import numpy as np 4 | import timeit 5 | 6 | DATA_DIR = None 7 | MODELS = ['resnet50', 'densenet121', 'inception_v3'] 8 | RATIOS = [1.0, 0, -1, -2] 9 | WBITS = [8,7,6,5] 10 | ABITS = 8 11 | 12 | if __name__ == "__main__": 13 | if DATA_DIR is None: 14 | print('Add the ImageNet data dir to the script.') 15 | sys.exit(0) 16 | 17 | for model in MODELS: 18 | out_dir = 'logs_' + model 19 | 20 | for r in RATIOS: 21 | for wbits in WBITS: 22 | 23 | exp_name = 'clip-%da%dw-r%s' % (ABITS, wbits, r) 24 | 25 | args = [ "%s" % DATA_DIR, 26 | "--arch=%s" % model, 27 | "--evaluate", 28 | "--pretrained", 29 | "--act-bits=%d" % ABITS, 30 | "--weight-bits=%d" % wbits, 31 | "--quantize-method=%s" % "ocs", 32 | "--weight-expand-ratio=0.0", 33 | "--weight-clip-threshold=%5.3f" % r, 34 | "--act-clip-threshold=1.0", 35 | "--profile-batches=4", 36 | "-b 128", 37 | 
"-j 1", 38 | "--vs=0", 39 | "--out-dir=%s" % out_dir, 40 | "--name=%s" % exp_name] 41 | 42 | print("Args:") 43 | print(args) 44 | 45 | sp.call(["python", "compress_classifier.py"] + args) 46 | 47 | -------------------------------------------------------------------------------- /OCS-CNN/scripts/ocs_script.py: -------------------------------------------------------------------------------- 1 | import sys, os, re 2 | import subprocess as sp 3 | import numpy as np 4 | import timeit 5 | 6 | DATA_DIR = None 7 | MODELS = ['resnet50', 'densenet121', 'inception_v3'] 8 | RATIOS = [0, 0.01, 0.02, 0.05] 9 | WBITS = [8,7,6,5] 10 | ABITS = 8 11 | 12 | if __name__ == "__main__": 13 | if DATA_DIR is None: 14 | print('Add the ImageNet data dir to the script.') 15 | sys.exit(0) 16 | 17 | for model in MODELS: 18 | out_dir = 'logs_' + model 19 | 20 | for r in RATIOS: 21 | for wbits in WBITS: 22 | 23 | exp_name = 'ocs-%da%dw-r%s' % (ABITS, wbits, r) 24 | 25 | args = [ "%s" % DATA_DIR, 26 | "--arch=%s" % model, 27 | "--evaluate", 28 | "--pretrained", 29 | "--act-bits=%d" % ABITS, 30 | "--weight-bits=%d" % wbits, 31 | "--quantize-method=%s" % "ocs", 32 | "--weight-expand-ratio=%5.3f" % r, 33 | "--weight-clip-threshold=1.0", 34 | "--act-clip-threshold=1.0", 35 | "--profile-batches=4", 36 | "-b 128", 37 | "-j 1", 38 | "--vs=0", 39 | "--out-dir=%s" % out_dir, 40 | "--name=%s" % exp_name] 41 | 42 | print("Args:") 43 | print(args) 44 | 45 | sp.call(["python", "compress_classifier.py"] + args) 46 | 47 | -------------------------------------------------------------------------------- /distiller/regularization/regularizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

EPSILON = 1e-8


class _Regularizer:
    """Abstract base class for regularizers.

    Concrete subclasses implement ``loss`` (add this regularizer's term to a
    running loss) and ``threshold`` (update the zeros-mask of a parameter).

    Args:
        name: instance name of the regularizer.
        model: the model being regularized.
        reg_regims: regularization regimen, a dict keyed by parameter name.
            The layout of each value is defined by the subclass (the original
            documentation describes it as ``[lambda, structure-type]``).
        threshold_criteria: thresholding criterion, stored for subclasses.
    """

    def __init__(self, name, model, reg_regims, threshold_criteria):
        self.name, self.model = name, model
        self.reg_regims = reg_regims
        self.threshold_criteria = threshold_criteria

    def loss(self, param, param_name, regularizer_loss, zeros_mask_dict):
        """Subclasses return ``regularizer_loss`` updated with their term for ``param``."""
        raise NotImplementedError

    def threshold(self, param, param_name, zeros_mask_dict):
        """Subclasses update ``zeros_mask_dict[param_name]`` for ``param``."""
        raise NotImplementedError

# ------------------------------------------------------------------------------
# /distiller/pruning/pruner.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import distiller

class _ParameterPruner(object):
    """Base class for all pruners.

    Arguments:
        name: pruner name is used mainly for debugging.
    """
    def __init__(self, name):
        self.name = name

    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        raise NotImplementedError

def threshold_model(model, threshold):
    """Threshold an entire model using the provided threshold.

    This function prunes weights only (biases are left untouched): every
    parameter whose name contains 'weight' is multiplied in place by
    ``distiller.threshold_mask(p.data, threshold)``.

    Args:
        model: a module exposing ``named_parameters()``.
        threshold: value forwarded to ``distiller.threshold_mask``.
    """
    for name, p in model.named_parameters():
        if 'weight' in name:
            mask = distiller.threshold_mask(p.data, threshold)
            # mul_ works in place, so p.data needs no re-assignment.
            p.data.mul_(mask)

# ------------------------------------------------------------------------------
# /tests/test_thresholding.py:
# ------------------------------------------------------------------------------
import torch
import os
import sys
import pytest
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import distiller

# Use the GPU when present, otherwise fall back to the CPU so these tests can
# also run on CPU-only machines.
_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_test_tensor():
    return torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0],
                         [7.0, 8.0, 9.0],
                         [10., 11., 12.]])


def test_row_thresholding():
    p = get_test_tensor().to(_DEVICE)
    group_th = distiller.GroupThresholdMixin()
    mask = group_th.group_threshold_mask(p, 'Rows', 7, 'Max')
    assert torch.eq(mask, torch.tensor([[ 0.,  0.,  0.],
                                        [ 0.,  0.,  0.],
                                        [ 1.,  1.,  1.],
                                        [ 1.,  1.,  1.]], device=mask.device)).all()
    return mask


def test_col_thresholding():
    p = get_test_tensor().to(_DEVICE)
    group_th = distiller.GroupThresholdMixin()
    mask = group_th.group_threshold_mask(p, 'Cols', 11, 'Max')
    assert torch.eq(mask, torch.tensor([[ 0.,  0.,  1.],
                                        [ 0.,  0.,  1.],
                                        [ 0.,  0.,  1.],
                                        [ 0.,  0.,  1.]], device=mask.device)).all()
    return mask

if __name__ == '__main__':
    m = test_col_thresholding()
    print(m)

# ------------------------------------------------------------------------------
# /apputils/dataset_summaries.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def dataset_summary(data_loader):
    """Create a histogram of class membership distribution within a dataset.

    It is important to examine our training, validation, and test
    datasets, to make sure that they are balanced.

    Args:
        data_loader: an iterable of ``(inputs, label_batch)`` pairs that also
            exposes ``.sampler`` (used only to report the item count).
    """
    from statistics import mean

    print("Analyzing dataset:")
    hist = {}
    for idx, (_, label_batch) in enumerate(data_loader):
        for label in label_batch:
            # Tensors hash by object identity, so convert 0-d tensors to plain
            # Python values; otherwise each label gets its own bucket.
            key = label.item() if hasattr(label, 'item') else label
            hist[key] = hist.get(key, 0) + 1
        if idx % 50 == 0:
            print("idx: %d" % idx)

    nclasses = len(hist)
    print('Dataset contains {} items'.format(len(data_loader.sampler)))
    print('Found {} classes'.format(nclasses))
    for data_class, size in hist.items():
        print('\tClass {} = {}'.format(data_class, size))

    # statistics.mean raises on an empty sequence; skip it for an empty loader.
    if hist:
        print('mean: ', mean(list(hist.values())))

# ------------------------------------------------------------------------------
# /models/cifar10/simplenet_cifar.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch.nn as nn
import torch.nn.functional as F

__all__ = ['simplenet_cifar']

class Simplenet(nn.Module):
    """Small LeNet-style CNN producing 10 logits.

    The flatten size 16 * 5 * 5 corresponds to a 3x32x32 input: two 5x5 convs,
    each followed by a 2x2 max-pool.
    """

    def __init__(self):
        super(Simplenet, self).__init__()
        # Registration order is kept as-is: it fixes the parameter / state_dict
        # order and the order in which layers draw their random initial weights.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> shared 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Classifier head: two ReLU hidden layers, then raw logits.
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

def simplenet_cifar():
    """Build an untrained Simplenet."""
    return Simplenet()

# ------------------------------------------------------------------------------
# /tests/test_infra.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import os
import sys
import pytest
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from models import create_model
from apputils import load_checkpoint

_DENSE_CKPT = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
_MISSING_CKPT = 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar'


def test_load():
    """A stored ResNet-20 checkpoint restores its compression scheduler and epoch."""
    log = logging.getLogger('simple_example')
    log.setLevel(logging.INFO)

    net = create_model(False, 'cifar10', 'resnet20_cifar')
    net, scheduler, epoch = load_checkpoint(net, _DENSE_CKPT)
    assert scheduler is not None
    assert epoch == 180


def test_load_negative():
    """Loading from a non-existent path raises FileNotFoundError."""
    with pytest.raises(FileNotFoundError):
        net = create_model(False, 'cifar10', 'resnet20_cifar')
        net, scheduler, epoch = load_checkpoint(net, _MISSING_CKPT)

# ------------------------------------------------------------------------------
# /tests/common.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | # 16 | 17 | import os 18 | import sys 19 | import torch 20 | module_path = os.path.abspath(os.path.join('..')) 21 | if module_path not in sys.path: 22 | sys.path.append(module_path) 23 | import distiller 24 | from models import create_model 25 | 26 | 27 | def setup_test(arch, dataset, parallel): 28 | model = create_model(False, dataset, arch, parallel=parallel) 29 | assert model is not None 30 | 31 | # Create the masks 32 | zeros_mask_dict = {} 33 | for name, param in model.named_parameters(): 34 | masker = distiller.ParameterMasker(name) 35 | zeros_mask_dict[name] = masker 36 | return model, zeros_mask_dict 37 | 38 | 39 | def find_module_by_name(model, module_to_find): 40 | for name, m in model.named_modules(): 41 | if name == module_to_find: 42 | return m 43 | return None 44 | 45 | 46 | def get_dummy_input(dataset): 47 | if dataset == "imagenet": 48 | return torch.randn(1, 3, 224, 224).cuda() 49 | elif dataset == "cifar10": 50 | return torch.randn(1, 3, 32, 32).cuda() 51 | raise ValueError("Trying to use an unknown dataset " + dataset) 52 | 53 | 54 | def almost_equal(a , b, max_diff=0.000001): 55 | return abs(a - b) <= max_diff 56 | -------------------------------------------------------------------------------- /OCS-CNN/scripts/parse_ocs.py: -------------------------------------------------------------------------------- 1 | import sys, os, re 2 | import subprocess as sp 3 | import numpy as np 4 | import timeit 5 | 6 | from os.path import isfile, join 7 | 8 | def parse_dir(dir_name): 9 | """ Get relevant info from directory name where the format 10 | is 8a4w-r0.4__stuff """ 11 | m = re.match('ocs-(\d+)a(\d+)w-r([\d|.]+)', dir_name) 12 | if m: 13 | abits = int(m.group(1)) 14 | wbits = int(m.group(2)) 15 | r = float(m.group(3)) 16 | else: 17 | #print('Cannot match directory %s' % dir_name) 18 | return None 19 | return abits, wbits, r 20 | 21 | def dir_key(dir_name): 22 | """ Used for sorting """ 23 | p = parse_dir(dir_name) 24 | if p is None: 25 | return 0 26 
| abits, wbits, r = p 27 | return 1000*abits + 100*wbits + int(100*r) 28 | 29 | def find_acc(fname): 30 | """ Find the Top1 and Top5 accuracy """ 31 | with open(fname, 'r') as f: 32 | for l in f: 33 | m = re.search("==> Top1: ([\d|.]+)\W*Top5: ([\d|.]+)", l) 34 | if m: 35 | return float(m.group(1)), float(m.group(2)) 36 | return None 37 | 38 | if __name__ == "__main__": 39 | if len(sys.argv) < 2: 40 | print("Give directory name") 41 | sys.exit(0) 42 | 43 | top_dir = sys.argv[1] 44 | 45 | # Find each experiment dir 46 | dirs = sorted( os.listdir(top_dir), key=dir_key, reverse=True ) 47 | 48 | for d in dirs: 49 | # Get the bitwidths and ratio 50 | p = parse_dir(d) 51 | if p is None: 52 | continue 53 | abits, wbits, r = p 54 | 55 | # There should be exactly 1 logfile in each dir 56 | d = join(top_dir, d) 57 | logs = [f for f in os.listdir(d) if isfile(join(d,f)) and f[-4:]=='.log'] 58 | assert(len(logs) == 1) 59 | logfile = join(d, logs[0]) 60 | 61 | # Read the accuracy 62 | acc = find_acc(logfile) 63 | if acc: 64 | print("%d, %d, %4.2f, %6.3f, %6.3f" % (abits, wbits, r, acc[0], acc[1])) 65 | 66 | 67 | -------------------------------------------------------------------------------- /OCS-CNN/scripts/parse_clip.py: -------------------------------------------------------------------------------- 1 | import sys, os, re 2 | import subprocess as sp 3 | import numpy as np 4 | import timeit 5 | 6 | from os.path import isfile, join 7 | 8 | def parse_dir(dir_name): 9 | """ Get relevant info from directory name where the format 10 | is 8a4w-r0.4__stuff """ 11 | m = re.match('clip-(\d+)a(\d+)w-r([-|\d|.]+)', dir_name) 12 | if m: 13 | abits = int(m.group(1)) 14 | wbits = int(m.group(2)) 15 | r = float(m.group(3)) 16 | else: 17 | #print('Cannot match directory %s' % dir_name) 18 | return None 19 | return abits, wbits, r 20 | 21 | def dir_key(dir_name): 22 | """ Used for sorting """ 23 | p = parse_dir(dir_name) 24 | if p is None: 25 | return 0 26 | abits, wbits, r = p 27 | return 
1000*abits + 100*wbits + int(100*r) 28 | 29 | def find_acc(fname): 30 | """ Find the Top1 and Top5 accuracy """ 31 | with open(fname, 'r') as f: 32 | for l in f: 33 | m = re.search("==> Top1: ([\d|.]+)\W*Top5: ([\d|.]+)", l) 34 | if m: 35 | return float(m.group(1)), float(m.group(2)) 36 | return None 37 | 38 | if __name__ == "__main__": 39 | if len(sys.argv) < 2: 40 | print("Give directory name") 41 | sys.exit(0) 42 | 43 | top_dir = sys.argv[1] 44 | 45 | # Find each experiment dir 46 | dirs = sorted( os.listdir(top_dir), key=dir_key, reverse=True ) 47 | 48 | for d in dirs: 49 | # Get the bitwidths and ratio 50 | p = parse_dir(d) 51 | if p is None: 52 | continue 53 | abits, wbits, r = p 54 | 55 | # There should be exactly 1 logfile in each dir 56 | d = join(top_dir, d) 57 | logs = [f for f in os.listdir(d) if isfile(join(d,f)) and f[-4:]=='.log'] 58 | assert(len(logs) == 1) 59 | logfile = join(d, logs[0]) 60 | 61 | # Read the accuracy 62 | acc = find_acc(logfile) 63 | if acc: 64 | print("%d, %d, %4.2f, %6.3f, %6.3f" % (abits, wbits, r, acc[0], acc[1])) 65 | 66 | 67 | -------------------------------------------------------------------------------- /OCS-CNN/config/preact_resnet20_cifar_pact.yaml: -------------------------------------------------------------------------------- 1 | 2 | # time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1 3 | # --epochs 200 --compress=../quantization/preact_resnet20_cifar_pact.yaml --out-dir="logs/" --wd=0.0002 --vs=0 4 | 5 | 6 | #2018-07-18 17:28:56,710 - --- validate (epoch=199)----------- 7 | #2018-07-18 17:28:56,710 - 10000 samples (128 per mini-batch) 8 | #2018-07-18 17:28:58,070 - Epoch: [199][ 50/ 78] Loss 0.349229 Top1 91.140625 Top5 99.671875 9 | #2018-07-18 17:28:58,670 - ==> Top1: 91.440 Top5: 99.680 Loss: 0.348 10 | # 11 | #2018-07-18 17:28:58,671 - ==> Best validation Top1: 91.860 Epoch: 147 12 | #2018-07-18 17:28:58,672 - Saving checkpoint to: 
logs/checkpoint.pth.tar 13 | #2018-07-18 17:28:58,687 - --- test --------------------- 14 | #2018-07-18 17:28:58,687 - 10000 samples (128 per mini-batch) 15 | #2018-07-18 17:29:00,006 - Test: [ 50/ 78] Loss 0.349229 Top1 91.140625 Top5 99.671875 16 | #2018-07-18 17:29:00,560 - ==> Top1: 91.440 Top5: 99.680 Loss: 0.348 17 | 18 | 19 | quantizers: 20 | pact_quantizer: 21 | class: PACTQuantizer 22 | act_clip_init_val: 8.0 23 | bits_activations: 4 24 | bits_weights: 3 25 | bits_overrides: 26 | # Don't quantize first and last layers 27 | conv1: 28 | wts: null 29 | acts: null 30 | layer1.0.pre_relu: 31 | wts: null 32 | acts: null 33 | final_relu: 34 | wts: null 35 | acts: null 36 | fc: 37 | wts: null 38 | acts: null 39 | 40 | lr_schedulers: 41 | training_lr: 42 | class: MultiStepLR 43 | milestones: [60, 120] 44 | gammas: 0.1 45 | 46 | policies: 47 | - quantizer: 48 | instance_name: pact_quantizer 49 | starting_epoch: 0 50 | ending_epoch: 200 51 | frequency: 1 52 | 53 | - lr_scheduler: 54 | instance_name: training_lr 55 | starting_epoch: 0 56 | ending_epoch: 121 57 | frequency: 1 58 | -------------------------------------------------------------------------------- /examples/quantization/preact_resnet20_cifar_pact.yaml: -------------------------------------------------------------------------------- 1 | 2 | # time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1 3 | # --epochs 200 --compress=../quantization/preact_resnet20_cifar_pact.yaml --out-dir="logs/" --wd=0.0002 --vs=0 4 | 5 | 6 | #2018-07-18 17:28:56,710 - --- validate (epoch=199)----------- 7 | #2018-07-18 17:28:56,710 - 10000 samples (128 per mini-batch) 8 | #2018-07-18 17:28:58,070 - Epoch: [199][ 50/ 78] Loss 0.349229 Top1 91.140625 Top5 99.671875 9 | #2018-07-18 17:28:58,670 - ==> Top1: 91.440 Top5: 99.680 Loss: 0.348 10 | # 11 | #2018-07-18 17:28:58,671 - ==> Best validation Top1: 91.860 Epoch: 147 12 | #2018-07-18 17:28:58,672 - Saving checkpoint to: 
logs/checkpoint.pth.tar 13 | #2018-07-18 17:28:58,687 - --- test --------------------- 14 | #2018-07-18 17:28:58,687 - 10000 samples (128 per mini-batch) 15 | #2018-07-18 17:29:00,006 - Test: [ 50/ 78] Loss 0.349229 Top1 91.140625 Top5 99.671875 16 | #2018-07-18 17:29:00,560 - ==> Top1: 91.440 Top5: 99.680 Loss: 0.348 17 | 18 | 19 | quantizers: 20 | pact_quantizer: 21 | class: PACTQuantizer 22 | act_clip_init_val: 8.0 23 | bits_activations: 4 24 | bits_weights: 3 25 | bits_overrides: 26 | # Don't quantize first and last layers 27 | conv1: 28 | wts: null 29 | acts: null 30 | layer1.0.pre_relu: 31 | wts: null 32 | acts: null 33 | final_relu: 34 | wts: null 35 | acts: null 36 | fc: 37 | wts: null 38 | acts: null 39 | 40 | lr_schedulers: 41 | training_lr: 42 | class: MultiStepLR 43 | milestones: [60, 120] 44 | gammas: 0.1 45 | 46 | policies: 47 | - quantizer: 48 | instance_name: pact_quantizer 49 | starting_epoch: 0 50 | ending_epoch: 200 51 | frequency: 1 52 | 53 | - lr_scheduler: 54 | instance_name: training_lr 55 | starting_epoch: 0 56 | ending_epoch: 121 57 | frequency: 1 58 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
#

import torch
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import distiller
import models


def test_sparsity():
    """Sparsity/density helpers on all-zero and all-one tensors."""
    all_zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(all_zeros))
    assert distiller.sparsity(all_zeros) == 1.0
    assert distiller.sparsity_3D(all_zeros) == 1.0
    assert distiller.density_3D(all_zeros) == 0.0

    all_ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(all_ones) == 0.0


def test_utils():
    """Parameter and module lookup helpers on a non-parallel ResNet-20."""
    model = models.create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    assert model is not None

    # An empty name must not resolve to any parameter.
    assert distiller.model_find_param(model, "") is None

    # Search for a parameter by its "non-parallel" name.
    assert distiller.model_find_param(model, "layer1.0.conv1.weight") is not None

    # Locate the module object first, then check the reverse lookup.
    target = next((mod for mod_name, mod in model.named_modules()
                   if mod_name == "layer1.0.conv1"), None)
    assert target is not None
    assert distiller.model_find_module_name(model, target) == "layer1.0.conv1"

# ------------------------------------------------------------------------------
# /distiller/quantization/q_utils.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch


def symmetric_linear_quantization_scale_factor(num_bits, saturation_val):
    """Scale mapping +/-saturation_val onto the largest signed level.

    One bit is reserved for the sign, so the top level is 2**(num_bits-1) - 1.
    """
    top_level = 2 ** (num_bits - 1) - 1
    return top_level / saturation_val


def asymmetric_linear_quantization_scale_factor(num_bits, saturation_min, saturation_max):
    """Scale mapping [saturation_min, saturation_max] onto 2**num_bits - 1 steps."""
    steps = 2 ** num_bits - 1
    value_range = saturation_max - saturation_min
    return steps / value_range


def clamp(input, min, max, inplace=False):
    """Clamp ``input`` to [min, max]; when ``inplace`` the same tensor is modified and returned."""
    if not inplace:
        return torch.clamp(input, min, max)
    input.clamp_(min, max)
    return input


def linear_quantize(input, scale_factor, inplace=False):
    """Return round(scale_factor * input); ``inplace`` modifies and returns ``input``."""
    if not inplace:
        return torch.round(scale_factor * input)
    # mul_ and round_ both return the tensor they modified.
    return input.mul_(scale_factor).round_()


def linear_quantize_clamp(input, scale_factor, clamp_min, clamp_max, inplace=False):
    """Quantize with ``linear_quantize`` and then clamp to [clamp_min, clamp_max]."""
    return clamp(linear_quantize(input, scale_factor, inplace), clamp_min, clamp_max, inplace)


def linear_dequantize(input, scale_factor, inplace=False):
    """Return input / scale_factor; ``inplace`` modifies and returns ``input``."""
    if not inplace:
        return input / scale_factor
    return input.div_(scale_factor)


def get_tensor_max_abs(tensor):
    """Largest absolute value in ``tensor``, as a Python number."""
    extremes = (tensor.max().item(), tensor.min().item())
    return max(abs(v) for v in extremes)


def get_quantized_range(num_bits, signed=True):
    """Inclusive (low, high) integer range representable with ``num_bits`` bits."""
    if not signed:
        return 0, 2 ** num_bits - 1
    half = 2 ** (num_bits - 1)
    return -half, half - 1

# ------------------------------------------------------------------------------
# /distiller/pruning/structure_pruner.py:
# ------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | from .pruner import _ParameterPruner 19 | import distiller 20 | msglogger = logging.getLogger() 21 | 22 | class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): 23 | """Prune parameter structures. 24 | 25 | Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the eleements 26 | in the structure is below threshold, then the structure is pruned. 27 | 28 | We use the average, instead of plain L1-norm, because we don't want the threshold to depend on 29 | the structure size. 
30 | """ 31 | def __init__(self, name, model, reg_regims, threshold_criteria): 32 | super(StructureParameterPruner, self).__init__(name) 33 | self.name = name 34 | self.model = model 35 | self.reg_regims = reg_regims 36 | self.threshold_criteria = threshold_criteria 37 | assert threshold_criteria in ["Max", "Mean_Abs"] 38 | 39 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 40 | if param_name not in self.reg_regims.keys(): 41 | return 42 | 43 | group_type = self.reg_regims[param_name][1] 44 | threshold = self.reg_regims[param_name][0] 45 | zeros_mask_dict[param_name].mask = self.group_threshold_mask(param, 46 | group_type, 47 | threshold, 48 | self.threshold_criteria) 49 | -------------------------------------------------------------------------------- /distiller/regularization/l1_regularizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | """L1-norm regularization""" 18 | 19 | import torch 20 | import math 21 | import numpy as np 22 | import distiller 23 | from .regularizer import _Regularizer, EPSILON 24 | 25 | class L1Regularizer(_Regularizer): 26 | def __init__(self, name, model, reg_regims, threshold_criteria=None): 27 | super(L1Regularizer, self).__init__(name, model, reg_regims, threshold_criteria) 28 | 29 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): 30 | if param_name in self.reg_regims: 31 | strength = self.reg_regims[param_name] 32 | regularizer_loss += L1Regularizer.__add_l1(param, strength) 33 | 34 | return regularizer_loss 35 | 36 | def threshold(self, param, param_name, zeros_mask_dict): 37 | """Soft threshold for L1-norm regularizer""" 38 | if self.threshold_criteria is None or param_name not in self.reg_regims: 39 | return 40 | 41 | strength = self.reg_regims[param_name] 42 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength) 43 | zeros_mask_dict[param_name].is_regularization_mask = True 44 | 45 | @staticmethod 46 | def __add_l1(var, strength): 47 | return var.abs().sum() * strength 48 | 49 | @staticmethod 50 | def __add_l1_all(loss, model, reg_regims): 51 | for param_name, param in model.named_parameters(): 52 | if param_name in reg_regims.keys(): 53 | strength = reg_regims[param_name] 54 | loss += L1Regularizer.__add_l1(param, strength) 55 | -------------------------------------------------------------------------------- /models/cifar10/mobilenet_cifar.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | __all__ = ['mobilenet_cifar'] 11 | 12 | class Block(nn.Module): 13 | '''Depthwise conv + Pointwise conv''' 14 | def __init__(self, in_planes, out_planes, stride=1): 15 | super(Block, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 17 | self.bn1 = nn.BatchNorm2d(in_planes) 18 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn2 = nn.BatchNorm2d(out_planes) 20 | 21 | def forward(self, x): 22 | out = F.relu(self.bn1(self.conv1(x))) 23 | out = F.relu(self.bn2(self.conv2(out))) 24 | return out 25 | 26 | 27 | class MobileNet(nn.Module): 28 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 29 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 30 | 31 | def __init__(self, num_classes=10): 32 | super(MobileNet, self).__init__() 33 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(32) 35 | self.layers = self._make_layers(in_planes=32) 36 | self.linear = nn.Linear(1024, num_classes) 37 | 38 | def _make_layers(self, in_planes): 39 | layers = [] 40 | for x in self.cfg: 41 | out_planes = x if isinstance(x, int) else x[0] 42 | stride = 1 if isinstance(x, int) else x[1] 43 | layers.append(Block(in_planes, out_planes, stride)) 44 | in_planes = out_planes 45 | return nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.layers(out) 50 | out = F.avg_pool2d(out, 2) 51 | out = out.view(out.size(0), -1) 52 | out = self.linear(out) 53 | return out 54 | 55 | def mobilenet_cifar(**kwargs): 56 | model = MobileNet(**kwargs) 57 | return model 58 | 59 | -------------------------------------------------------------------------------- /tests/test_model_summary.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | import torch 19 | import os 20 | import sys 21 | module_path = os.path.abspath(os.path.join('..')) 22 | if module_path not in sys.path: 23 | sys.path.append(module_path) 24 | import distiller 25 | import pytest 26 | import common # common test code 27 | import apputils 28 | 29 | # Logging configuration 30 | logging.basicConfig(level=logging.INFO) 31 | fh = logging.FileHandler('test.log') 32 | logger = logging.getLogger() 33 | logger.addHandler(fh) 34 | 35 | 36 | def test_png_generation(): 37 | DATASET = "cifar10" 38 | ARCH = "resnet20_cifar" 39 | model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True) 40 | # 2 different ways to create a PNG 41 | apputils.draw_img_classifier_to_file(model, 'model.png', DATASET, True) 42 | apputils.draw_img_classifier_to_file(model, 'model.png', DATASET, False) 43 | 44 | 45 | def test_negative(): 46 | DATASET = "cifar10" 47 | ARCH = "resnet20_cifar" 48 | model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True) 49 | 50 | with pytest.raises(ValueError): 51 | # png is not a supported summary type, so we expect this to fail with a ValueError 52 | distiller.model_summary(model, what='png', dataset=DATASET) 53 | 54 | 55 | def test_summary(): 56 | DATASET = 
"cifar10" 57 | ARCH = "resnet20_cifar" 58 | model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True) 59 | 60 | distiller.model_summary(model, what='sparsity', dataset=DATASET) 61 | distiller.model_summary(model, what='compute', dataset=DATASET) 62 | distiller.model_summary(model, what='model', dataset=DATASET) 63 | distiller.model_summary(model, what='modules', dataset=DATASET) 64 | -------------------------------------------------------------------------------- /distiller/pruning/level_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | from .pruner import _ParameterPruner 19 | import distiller 20 | 21 | class SparsityLevelParameterPruner(_ParameterPruner): 22 | """Prune to an exact pruning level specification. 23 | 24 | This pruner is very similar to MagnitudeParameterPruner, but instead of 25 | specifying an absolute threshold for pruning, you specify a target sparsity 26 | level (expressed as a fraction: 0.5 means 50% sparsity.) 27 | 28 | To find the correct threshold, we view the tensor as one large 1D vector, sort 29 | it using the absolute values of the elements, and then take topk elements. 
30 | """ 31 | 32 | def __init__(self, name, levels, **kwargs): 33 | super(SparsityLevelParameterPruner, self).__init__(name) 34 | self.levels = levels 35 | assert self.levels 36 | 37 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 38 | # If there is a specific sparsity level specified for this module, then 39 | # use it. Otherwise try to use the default level ('*'). 40 | desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0)) 41 | if desired_sparsity == 0: 42 | return 43 | 44 | self.prune_level(param, param_name, zeros_mask_dict, desired_sparsity) 45 | 46 | @staticmethod 47 | def prune_level(param, param_name, zeros_mask_dict, desired_sparsity): 48 | bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True) 49 | threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away 50 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold) 51 | -------------------------------------------------------------------------------- /distiller/pruning/magnitude_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | import distiller 19 | 20 | 21 | class MagnitudeParameterPruner(_ParameterPruner): 22 | """This is the most basic magnitude-based pruner. 23 | 24 | This pruner supports configuring a scalar threshold for each layer. 25 | A default threshold is mandatory and is used for layers without explicit 26 | threshold setting. 27 | 28 | """ 29 | def __init__(self, name, thresholds, **kwargs): 30 | """ 31 | Usually, a Pruner is constructed by the compression schedule parser 32 | found in distiller/config.py. 33 | The constructor is passed a dictionary of thresholds, as explained below. 34 | 35 | Args: 36 | name (string): the name of the pruner (used only for debug) 37 | thresholds (dict): a disctionary of thresholds, with the key being the 38 | parameter name. 39 | A special key, '*', represents the default threshold value. If 40 | set_param_mask is invoked on a parameter tensor that does not have 41 | an explicit entry in the 'thresholds' dictionary, then this default 42 | value is used. 43 | Currently it is mandatory to include a '*' key in 'thresholds'. 44 | """ 45 | super(MagnitudeParameterPruner, self).__init__(name) 46 | assert thresholds is not None 47 | # Make sure there is a default threshold to use 48 | assert '*' in thresholds 49 | self.thresholds = thresholds 50 | 51 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 52 | threshold = self.thresholds.get(param_name, self.thresholds['*']) 53 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold) 54 | -------------------------------------------------------------------------------- /examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # This schedule performs 3D (filter-wise) regularization of some of the convolution layers, together with 3 | # element-wise pruning using sensitivity-pruning. 
4 | # 5 | # time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained 6 | # 7 | 8 | 9 | version: 1 10 | pruners: 11 | vgg_manual: 12 | class: 'L1RankedStructureParameterPruner' 13 | reg_regims: 14 | # 'features.module.0.weight': [0.1, '3D'] 15 | 'features.module.2.weight': [0.1, '3D'] 16 | 'features.module.5.weight': [0.1, '3D'] 17 | 'features.module.7.weight': [0.1, '3D'] 18 | 'features.module.10.weight': [0.1, '3D'] 19 | 'features.module.12.weight': [0.1, '3D'] 20 | 'features.module.14.weight': [0.1, '3D'] 21 | 'features.module.16.weight': [0.1, '3D'] 22 | 'features.module.19.weight': [0.1, '3D'] 23 | 24 | vgg_manual2: 25 | class: 'L1RankedStructureParameterPruner' 26 | reg_regims: 27 | 'features.module.21.weight': [0.1, '3D'] 28 | 'features.module.23.weight': [0.1, '3D'] 29 | 'features.module.25.weight': [0.1, '3D'] 30 | 'features.module.28.weight': [0.1, '3D'] 31 | 'features.module.30.weight': [0.1, '3D'] 32 | 'features.module.32.weight': [0.1, '3D'] 33 | 'features.module.34.weight': [0.1, '3D'] 34 | 35 | extensions: 36 | net_thinner: 37 | class: 'FilterRemover' 38 | thinning_func_str: remove_filters 39 | arch: 'vgg19' 40 | dataset: 'imagenet' 41 | 42 | lr_schedulers: 43 | # Learning rate decay scheduler 44 | pruning_lr: 45 | class: StepLR 46 | step_size: 50 47 | gamma: 0.10 48 | 49 | 50 | policies: 51 | - pruner: 52 | instance_name: vgg_manual 53 | epochs: [0] 54 | 55 | - extension: 56 | instance_name: net_thinner 57 | epochs: [0] 58 | 59 | - pruner: 60 | instance_name: vgg_manual2 61 | epochs: [1] 62 | 63 | - extension: 64 | instance_name: net_thinner 65 | epochs: [1] 66 | 67 | - pruner: 68 | instance_name: vgg_manual 69 | epochs: [4] 70 | 71 | - extension: 72 | instance_name: net_thinner 73 | epochs: [4] 74 | 75 | - pruner: 76 | instance_name: vgg_manual2 77 | epochs: [5] 78 | 79 | - extension: 80 | instance_name: 
net_thinner 81 | epochs: [5] 82 | -------------------------------------------------------------------------------- /tests/test_learning_rate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import os 18 | import sys 19 | import pytest 20 | module_path = os.path.abspath(os.path.join('..')) 21 | if module_path not in sys.path: 22 | sys.path.append(module_path) 23 | 24 | import torch 25 | from torch.optim import Optimizer 26 | from distiller.learning_rate import MultiStepMultiGammaLR 27 | 28 | 29 | def test_multi_step_multi_gamma_lr(): 30 | dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True) 31 | dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1}) 32 | 33 | # Test input checks 34 | with pytest.raises(ValueError): 35 | lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[60, 30, 80], gammas=[0.1, 0.1, 0.2]) 36 | with pytest.raises(ValueError): 37 | lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60], gammas=[0.1, 0.1, 0.2]) 38 | with pytest.raises(ValueError): 39 | lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1]) 40 | 41 | # Test functionality 42 | lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2]) 43 | expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 
0.1 * 0.2] 44 | expected_lrs = [0.1 * gamma for gamma in expected_gammas] 45 | assert lr_sched.multiplicative_gammas == expected_gammas 46 | lr_sched.step(0) 47 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[0] 48 | lr_sched.step(15) 49 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[0] 50 | lr_sched.step(30) 51 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[1] 52 | lr_sched.step(33) 53 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[1] 54 | lr_sched.step(60) 55 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[2] 56 | lr_sched.step(79) 57 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[2] 58 | lr_sched.step(80) 59 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[3] 60 | lr_sched.step(100) 61 | assert dummy_optimizer.param_groups[0]['lr'] == expected_lrs[3] 62 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import torch 18 | import os 19 | import sys 20 | import torch.nn as nn 21 | from copy import deepcopy 22 | import pytest 23 | 24 | module_path = os.path.abspath(os.path.join('..')) 25 | if module_path not in sys.path: 26 | sys.path.append(module_path) 27 | from distiller import ScheduledTrainingPolicy, CompressionScheduler 28 | from distiller.policy import PolicyLoss, LossComponent 29 | 30 | 31 | class DummyPolicy(ScheduledTrainingPolicy): 32 | def __init__(self, idx): 33 | super(DummyPolicy, self).__init__() 34 | self.loss_val = torch.randint(0, 10000, (1,)) 35 | self.idx = idx 36 | 37 | def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, 38 | zeros_mask_dict, optimizer=None): 39 | return PolicyLoss(loss + self.loss_val, [LossComponent('Dummy Loss ' + str(self.idx), self.loss_val)]) 40 | 41 | 42 | @pytest.mark.parametrize("check_loss_components", [False, True]) 43 | def test_multiple_policies_loss(check_loss_components): 44 | model = nn.Module() 45 | scheduler = CompressionScheduler(model, device=torch.device('cpu')) 46 | num_policies = 3 47 | expected_overall_loss = 0 48 | expected_policy_losses = [] 49 | for i in range(num_policies): 50 | policy = DummyPolicy(i) 51 | expected_overall_loss += policy.loss_val 52 | expected_policy_losses.append(policy.loss_val) 53 | scheduler.add_policy(policy, epochs=[0]) 54 | 55 | main_loss = torch.randint(0, 10000, (1,)) 56 | expected_overall_loss += main_loss 57 | main_loss_before = deepcopy(main_loss) 58 | 59 | policies_loss = scheduler.before_backward_pass(0, 0, 1, main_loss, return_loss_components=check_loss_components) 60 | 61 | assert main_loss_before == main_loss 62 | if check_loss_components: 63 | assert expected_overall_loss == policies_loss.overall_loss 64 | for idx, lc in enumerate(policies_loss.loss_components): 65 | assert lc.name == 'Dummy Loss ' + str(idx) 66 | assert expected_policy_losses[idx] == lc.value.item() 67 | else: 68 | assert expected_overall_loss == 
policies_loss 69 | -------------------------------------------------------------------------------- /distiller/pruning/sensitivity_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | import distiller 19 | import torch 20 | 21 | class SensitivityPruner(_ParameterPruner): 22 | """Use algorithm from "Learning both Weights and Connections for Efficient 23 | Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf 24 | 25 | I.e.: "The pruning threshold is chosen as a quality parameter multiplied 26 | by the standard deviation of a layers weights." 27 | In this code, the "quality parameter" is referred to as "sensitivity" and 28 | is based on the values learned from performing sensitivity analysis. 29 | 30 | Note that this implementation deviates slightly from the algorithm Song Han 31 | describes in his PhD dissertation, in that the threshold value is set only 32 | once. In his PhD dissertation, Song Han describes a growing threshold, at 33 | each iteration. This requires n+1 hyper-parameters (n being the number of 34 | pruning iterations we use): the threshold and the threshold increase (delta) 35 | at each pruning iteration. 
36 | The implementation that follows, takes advantage of the fact that as pruning 37 | progresses, more weights are pulled toward zero, and therefore the threshold 38 | "traps" more weights. Thus, we can use less hyper-parameters and achieve the 39 | same results. 40 | """ 41 | 42 | def __init__(self, name, sensitivities, **kwargs): 43 | super(SensitivityPruner, self).__init__(name) 44 | self.sensitivities = sensitivities 45 | 46 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 47 | if not hasattr(param, 'stddev'): 48 | param.stddev = torch.std(param).item() 49 | 50 | if param_name not in self.sensitivities: 51 | if '*' not in self.sensitivities: 52 | return 53 | else: 54 | sensitivity = self.sensitivities['*'] 55 | else: 56 | sensitivity = self.sensitivities[param_name] 57 | 58 | threshold = param.stddev * sensitivity 59 | 60 | # After computing the threshold, we can create the mask 61 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold) 62 | -------------------------------------------------------------------------------- /distiller/quantization/ocs_impl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from .q_utils import * 6 | 7 | #------------------------------------------------------------------------- 8 | # For weights we can use numpy since we perform quantization once 9 | # at the beginning of testing 10 | #------------------------------------------------------------------------- 11 | def ocs_wts(weight_np, expand_ratio, axis=1, split_threshold=0.5, w_scale=None, grid_aware=False): 12 | """ Basic net2net and split the channels into 2 equal parts """ 13 | assert(axis == 1) 14 | assert(grid_aware == False or w_scale is not None) 15 | weight_np = weight_np.copy() 16 | 17 | # Identify channel to split 18 | # weight layout is O x I x h x w 19 | num_channels = weight_np.shape[axis] 20 | ocs_channels = 
int(np.ceil(expand_ratio * num_channels)) 21 | 22 | if ocs_channels == 0: 23 | return weight_np, [] 24 | 25 | # Which act channels to copy 26 | in_channels_to_copy = [] 27 | # Mapping from newly added channels to the orig channels they split from 28 | orig_idx_dict = {} 29 | 30 | for c in range(ocs_channels): 31 | # pick the channels with the largest max values 32 | axes = list(range(weight_np.ndim)) 33 | axes.remove(axis) 34 | max_per_channel = np.max(np.abs(weight_np), axis=tuple(axes)) 35 | # Sort and compute which channel to split 36 | idxs = np.flip(np.argsort(max_per_channel), axis=0) 37 | split_idx = idxs[0] 38 | 39 | # Split channel 40 | ch_slice = weight_np[:, split_idx:(split_idx+1), :, :].copy() 41 | 42 | ch_slice_half = ch_slice / 2. 43 | ch_slice_zero = np.zeros_like(ch_slice) 44 | split_value = np.max(ch_slice) * split_threshold 45 | 46 | if not grid_aware: 47 | ch_slice_1 = np.where(np.abs(ch_slice) > split_value, ch_slice_half, ch_slice) 48 | ch_slice_2 = np.where(np.abs(ch_slice) > split_value, ch_slice_half, ch_slice_zero) 49 | else: 50 | ch_slice_half *= w_scale 51 | ch_slice_1 = np.where(np.abs(ch_slice) > split_value, ch_slice_half-0.25, ch_slice*w_scale) / w_scale 52 | ch_slice_2 = np.where(np.abs(ch_slice) > split_value, ch_slice_half+0.25, ch_slice_zero) / w_scale 53 | 54 | weight_np[:, split_idx:(split_idx+1), :, :] = ch_slice_1 55 | weight_np = np.concatenate((weight_np, ch_slice_2), axis=axis) 56 | 57 | # Record which channel was split 58 | if split_idx < num_channels: 59 | in_channels_to_copy.append(split_idx) 60 | orig_idx_dict[num_channels+c] = split_idx 61 | else: 62 | idx_to_copy = orig_idx_dict[split_idx] 63 | in_channels_to_copy.append(idx_to_copy) 64 | orig_idx_dict[num_channels+c] = idx_to_copy 65 | 66 | return weight_np, in_channels_to_copy 67 | -------------------------------------------------------------------------------- /distiller/data_loggers/tbbackend.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | """ A TensorBoard backend. 17 | 18 | Writes logs to a file using a Google's TensorBoard protobuf format. 19 | See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto 20 | """ 21 | import tensorflow as tf 22 | import numpy as np 23 | 24 | class TBBackend(object): 25 | 26 | def __init__(self, log_dir): 27 | self.writer = tf.summary.FileWriter(log_dir) 28 | 29 | def scalar_summary(self, tag, scalar, step): 30 | """From TF documentation: 31 | tag: name for the data. Used by TensorBoard plugins to organize data. 32 | value: value associated with the tag (a float). 33 | """ 34 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)]) 35 | self.writer.add_summary(summary, step) 36 | 37 | def histogram_summary(self, tag, tensor, step): 38 | """ 39 | From the TF documentation: 40 | tf.summary.histogram takes an arbitrarily sized and shaped Tensor, and 41 | compresses it into a histogram data structure consisting of many bins with 42 | widths and counts. 43 | 44 | TensorFlow uses non-uniformly distributed bins, which is better than using 45 | numpy's uniform bins for activations and parameters which converge around zero, 46 | but we don't add that logic here. 
47 | 48 | https://www.tensorflow.org/programmers_guide/tensorboard_histograms 49 | """ 50 | hist, edges = np.histogram(tensor, bins=200) 51 | tfhist = tf.HistogramProto( 52 | min = np.min(tensor), 53 | max = np.max(tensor), 54 | num = int(np.prod(tensor.shape)), 55 | sum = np.sum(tensor), 56 | sum_squares = np.sum(np.square(tensor))) 57 | 58 | # From the TF documentation: 59 | # Parallel arrays encoding the bucket boundaries and the bucket values. 60 | # bucket(i) is the count for the bucket i. The range for a bucket is: 61 | # i == 0: -DBL_MAX .. bucket_limit(0) 62 | # i != 0: bucket_limit(i-1) .. bucket_limit(i) 63 | tfhist.bucket_limit.extend(edges[1:]) 64 | tfhist.bucket.extend(hist) 65 | 66 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=tfhist)]) 67 | self.writer.add_summary(summary, step) 68 | 69 | def sync_to_file(self): 70 | self.writer.flush() 71 | -------------------------------------------------------------------------------- /jupyter/compare_executions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Compare experiment executions\n", 8 | "\n", 9 | "This notebook let's you qickly compare the training progress of your experiments.\n", 10 | "You will need to have the tfevents files (these are TensorBoard formatted log files that Distiller creates)." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "scrolled": false 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import tensorflow as tf\n", 22 | "import numpy as np\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "\n", 25 | "def get_performance_data(path_to_events_file, tag):\n", 26 | " \"\"\"Extract the performance history of data named 'tag'\n", 27 | "\n", 28 | " Based on sample code from TF:\n", 29 | " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/summary_iterator.py\n", 30 | " \"\"\"\n", 31 | " data = []\n", 32 | " steps = []\n", 33 | " for e in tf.train.summary_iterator(path_to_events_file):\n", 34 | " for v in e.summary.value:\n", 35 | " if v.tag == tag:\n", 36 | " data.append(v.simple_value)\n", 37 | " steps.append(e.step)\n", 38 | " return steps, data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "scrolled": false 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "# Here insert your own tfevents files to compare\n", 50 | "# WARNING: these files do not exist in the repository (too large) and will give you an error\n", 51 | "experiment_files = [('events.out.tfevents.1523290172.one-machine', 'experiment 1'),\n", 52 | " ('events.out.tfevents.1520430112.one-machine', 'experiment 2')]\n", 53 | "\n", 54 | "# Choose which performance indicators you wish to graph\n", 55 | "tags = ['Peformance/Validation/Top1', 'Peformance/Validation/Loss', \n", 56 | " 'sprasity/weights/total', 'Peformance/Training/Reg Loss']\n", 57 | "\n", 58 | "f, axs = plt.subplots(2, 2, figsize=(20,20))\n", 59 | "f.suptitle('Performance')\n", 60 | "\n", 61 | "for experiment in experiment_files:\n", 62 | " add_experiment(axs, tags, experiment[0], label=experiment[1])\n", 63 | "plt.tight_layout()\n", 64 | "\n", 65 | "plt.show()" 66 | ] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "display_name": "Python 3", 72 | "language": "python",
73 | "name": "python3" 74 | }, 75 | "language_info": { 76 | "codemirror_mode": { 77 | "name": "ipython", 78 | "version": 3 79 | }, 80 | "file_extension": ".py", 81 | "mimetype": "text/x-python", 82 | "name": "python", 83 | "nbconvert_exporter": "python", 84 | "pygments_lexer": "ipython3", 85 | "version": "3.5.2" 86 | } 87 | }, 88 | "nbformat": 4, 89 | "nbformat_minor": 2 90 | } 91 | -------------------------------------------------------------------------------- /distiller/learning_rate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from bisect import bisect_right 18 | from torch.optim.lr_scheduler import _LRScheduler 19 | 20 | 21 | class PolynomialLR(_LRScheduler): 22 | """Set the learning rate for each parameter group using a polynomial defined as: 23 | lr = base_lr * (1 - T_cur/T_max) ^ (power), where T_cur is the current epoch and T_max is the maximum number of 24 | epochs. 25 | 26 | Args: 27 | optimizer (Optimizer): Wrapped optimizer. 28 | T_max (int): Maximum number of epochs 29 | power (int): Degree of polynomial 30 | last_epoch (int): The index of last epoch. Default: -1. 
31 | """ 32 | def __init__(self, optimizer, T_max, power, last_epoch=-1): 33 | self.T_max = T_max 34 | self.power = power 35 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | # base_lr * (1 - iter/max_iter) ^ (power) 39 | return [base_lr * (1 - self.last_epoch / self.T_max) ** self.power 40 | for base_lr in self.base_lrs] 41 | 42 | 43 | class MultiStepMultiGammaLR(_LRScheduler): 44 | """Similar to torch.otpim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone. 45 | 46 | Args: 47 | optimizer (Optimizer): Wrapped optimizer. 48 | milestones (list): List of epoch indices. Must be increasing. 49 | gammas (list): List of gamma values. Must have same length as milestones. 50 | last_epoch (int): The index of last epoch. Default: -1. 51 | """ 52 | def __init__(self, optimizer, milestones, gammas, last_epoch=-1): 53 | if not list(milestones) == sorted(milestones): 54 | raise ValueError('Milestones should be a list of' 55 | ' increasing integers. Got {}', milestones) 56 | if len(milestones) != len(gammas): 57 | raise ValueError('Milestones and Gammas lists should be of same length.') 58 | 59 | self.milestones = milestones 60 | self.multiplicative_gammas = [1] 61 | for idx, gamma in enumerate(gammas): 62 | self.multiplicative_gammas.append(gamma * self.multiplicative_gammas[idx]) 63 | 64 | super(MultiStepMultiGammaLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | idx = bisect_right(self.milestones, self.last_epoch) 68 | return [base_lr * self.multiplicative_gammas[idx] for base_lr in self.base_lrs] 69 | -------------------------------------------------------------------------------- /examples/gss/gss_channels-removal_training.yaml: -------------------------------------------------------------------------------- 1 | # GSS (Guided Structured Sparsity). 2 | # "Attention-Based Guided Structured Sparsity of Deep Neural Networks", 3 | # Amirsina Torfi, Rouzbeh A. 
Shirvani, Sobhan Soleymani, Nasser M. Nasrabadi 4 | # ICLR 2018 5 | # https://arxiv.org/abs/1802.09902 6 | # 7 | # Add group variance regularization to SSL. 8 | # So far I haven't produced results better than SSL. The regularization strength of the variance does not come into play 9 | # because it seems like the variance cost diminishes very quickly. 10 | # 11 | # time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../gss/gss_channels-removal_training.yaml -j=1 --deterministic 12 | # 13 | 14 | 15 | lr_schedulers: 16 | training_lr: 17 | class: StepLR 18 | step_size: 45 19 | gamma: 0.10 20 | 21 | regularizers: 22 | Channels_l2_regularizer: 23 | class: GroupLassoRegularizer 24 | reg_regims: 25 | module.layer1.0.conv2.weight: [0.0028, Channels] 26 | module.layer1.1.conv2.weight: [0.0028, Channels] 27 | module.layer1.2.conv2.weight: [0.0024, Channels] 28 | module.layer2.0.conv2.weight: [0.0016, Channels] # sensitive 29 | module.layer2.1.conv2.weight: [0.0028, Channels] 30 | module.layer2.2.conv2.weight: [0.0028, Channels] 31 | module.layer3.0.conv2.weight: [0.0008, Channels] # sensitive 32 | module.layer3.1.conv2.weight: [0.0028, Channels] 33 | #module.layer3.2.conv2.weight: [0.0006, Channels] # very sensitive 34 | threshold_criteria: Mean_Abs 35 | 36 | Channels_variance_reguralizer: 37 | class: GroupVarianceRegularizer 38 | reg_regims: 39 | module.layer1.0.conv2.weight: [0.000008, Channels] 40 | module.layer1.1.conv2.weight: [0.000008, Channels] 41 | module.layer1.2.conv2.weight: [0.000008, Channels] 42 | module.layer2.0.conv2.weight: [0.000008, Channels] 43 | module.layer2.1.conv2.weight: [0.000008, Channels] 44 | module.layer2.2.conv2.weight: [0.000008, Channels] 45 | module.layer3.0.conv2.weight: [0.000008, Channels] 46 | module.layer3.1.conv2.weight: [0.000008, Channels] 47 | #module.layer3.2.conv2.weight: [0.000008, Channels] 48 | 49 | extensions: 50 | net_thinner: 51 | class: 'ChannelRemover' 52 | 
thinning_func_str: remove_channels 53 | arch: 'resnet20_cifar' 54 | dataset: 'cifar10' 55 | 56 | policies: 57 | - lr_scheduler: 58 | instance_name: training_lr 59 | starting_epoch: 45 60 | ending_epoch: 300 61 | frequency: 1 62 | 63 | # After completing the regularization, we perform network thinning and exit. 64 | - extension: 65 | instance_name: net_thinner 66 | epochs: [179] 67 | 68 | - regularizer: 69 | instance_name: Channels_l2_regularizer 70 | args: 71 | keep_mask: True 72 | starting_epoch: 0 73 | ending_epoch: 180 74 | frequency: 1 75 | 76 | - regularizer: 77 | instance_name: Channels_variance_reguralizer 78 | args: 79 | keep_mask: True 80 | starting_epoch: 0 81 | ending_epoch: 180 82 | frequency: 1 83 | -------------------------------------------------------------------------------- /distiller/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | from .utils import * 18 | from .thresholding import GroupThresholdMixin, threshold_mask 19 | from .config import file_config, dict_config 20 | from .model_summaries import * 21 | from .scheduler import * 22 | from .sensitivity import * 23 | from .directives import * 24 | from .policy import * 25 | from .thinning import * 26 | from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights 27 | 28 | #del utils 29 | del dict_config 30 | del thinning 31 | #del model_summaries 32 | #del scheduler 33 | #del sensitivity 34 | #del directives 35 | #del thresholding 36 | #del policy 37 | 38 | # Distiller version 39 | __version__ = "0.3.0-pre" 40 | 41 | def model_find_param_name(model, param_to_find): 42 | """Look up the name of a model parameter. 43 | 44 | Arguments: 45 | model: the model to search 46 | param_to_find: the parameter whose name we want to look up 47 | 48 | Returns: 49 | The parameter name (string) or None, if the parameter was not found. 50 | """ 51 | for name, param in model.named_parameters(): 52 | if param is param_to_find: 53 | return name 54 | return None 55 | 56 | 57 | def model_find_module_name(model, module_to_find): 58 | """Look up the name of a module in a model. 59 | 60 | Arguments: 61 | model: the model to search 62 | module_to_find: the module whose name we want to look up 63 | 64 | Returns: 65 | The module name (string) or None, if the module was not found. 66 | """ 67 | for name, m in model.named_modules(): 68 | if m == module_to_find: 69 | return name 70 | return None 71 | 72 | def model_find_param(model, param_to_find_name): 73 | """Look a model parameter by its name 74 | 75 | Arguments: 76 | model: the model to search 77 | param_to_find_name: the name of the parameter that we are searching for 78 | 79 | Returns: 80 | The parameter or None, if the paramter name was not found. 
81 | """ 82 | for name, param in model.named_parameters(): 83 | if name == param_to_find_name: 84 | return param 85 | return None 86 | 87 | 88 | def model_find_module(model, module_to_find): 89 | """Given a module name, find the module in the provided model. 90 | 91 | Arguments: 92 | model: the model to search 93 | module_to_find: the module whose name we want to look up 94 | 95 | Returns: 96 | The module or None, if the module was not found. 97 | """ 98 | for name, m in model.named_modules(): 99 | if name == module_to_find: 100 | return m 101 | return None 102 | -------------------------------------------------------------------------------- /models/imagenet/resnet_earlyexit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision.models as models 5 | from torchvision.models.resnet import Bottleneck 6 | from torchvision.models.resnet import BasicBlock 7 | 8 | 9 | __all__ = ['resnet18_earlyexit', 'resnet34_earlyexit', 'resnet50_earlyexit', 'resnet101_earlyexit', 'resnet152_earlyexit'] 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class ResNetEarlyExit(models.ResNet): 19 | 20 | def __init__(self, block, layers, num_classes=1000): 21 | super(ResNetEarlyExit, self).__init__(block, layers, num_classes) 22 | 23 | # Define early exit layers 24 | self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True) 25 | self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True) 26 | self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True) 27 | self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes) 28 | self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes) 29 | 30 | def forward(self, 
x): 31 | x = self.conv1(x) 32 | x = self.bn1(x) 33 | x = self.relu(x) 34 | x = self.maxpool(x) 35 | 36 | x = self.layer1(x) 37 | 38 | # Add early exit layers 39 | exit0 = self.avgpool(x) 40 | exit0 = self.conv1_exit0(exit0) 41 | exit0 = self.conv2_exit0(exit0) 42 | exit0 = self.avgpool(exit0) 43 | exit0 = exit0.view(exit0.size(0), -1) 44 | exit0 = self.fc_exit0(exit0) 45 | 46 | x = self.layer2(x) 47 | 48 | # Add early exit layers 49 | exit1 = self.conv1_exit1(x) 50 | exit1 = self.avgpool(exit1) 51 | exit1 = exit1.view(exit1.size(0), -1) 52 | exit1 = self.fc_exit1(exit1) 53 | 54 | x = self.layer3(x) 55 | x = self.layer4(x) 56 | 57 | x = self.avgpool(x) 58 | x = x.view(x.size(0), -1) 59 | x = self.fc(x) 60 | 61 | # return a list of probabilities 62 | output = [] 63 | output.append(exit0) 64 | output.append(exit1) 65 | output.append(x) 66 | return output 67 | 68 | 69 | def resnet18_earlyexit(**kwargs): 70 | """Constructs a ResNet-18 model. 71 | """ 72 | model = ResNetEarlyExit(BasicBlock, [2, 2, 2, 2], **kwargs) 73 | return model 74 | 75 | 76 | def resnet34_earlyexit(**kwargs): 77 | """Constructs a ResNet-34 model. 78 | """ 79 | model = ResNetEarlyExit(BasicBlock, [3, 4, 6, 3], **kwargs) 80 | return model 81 | 82 | 83 | def resnet50_earlyexit(**kwargs): 84 | """Constructs a ResNet-50 model. 85 | """ 86 | model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs) 87 | return model 88 | 89 | 90 | def resnet101_earlyexit(**kwargs): 91 | """Constructs a ResNet-101 model. 92 | """ 93 | model = ResNetEarlyExit(Bottleneck, [3, 4, 23, 3], **kwargs) 94 | return model 95 | 96 | 97 | def resnet152_earlyexit(**kwargs): 98 | """Constructs a ResNet-152 model. 
99 | """ 100 | model = ResNetEarlyExit(Bottleneck, [3, 8, 36, 3], **kwargs) 101 | return model 102 | -------------------------------------------------------------------------------- /models/imagenet/mobilenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from math import floor 18 | import torch.nn as nn 19 | 20 | __all__ = ['mobilenet', 'mobilenet_025', 'mobilenet_050', 'mobilenet_075'] 21 | 22 | 23 | class MobileNet(nn.Module): 24 | def __init__(self, channel_multiplier=1.0, min_channels=8): 25 | super(MobileNet, self).__init__() 26 | 27 | if channel_multiplier <= 0: 28 | raise ValueError('channel_multiplier must be >= 0') 29 | 30 | def conv_bn_relu(n_ifm, n_ofm, kernel_size, stride=1, padding=0, groups=1): 31 | return [ 32 | nn.Conv2d(n_ifm, n_ofm, kernel_size, stride=stride, padding=padding, groups=groups, bias=False), 33 | nn.BatchNorm2d(n_ofm), 34 | nn.ReLU(inplace=True) 35 | ] 36 | 37 | def depthwise_conv(n_ifm, n_ofm, stride): 38 | return nn.Sequential( 39 | *conv_bn_relu(n_ifm, n_ifm, 3, stride=stride, padding=1, groups=n_ifm), 40 | *conv_bn_relu(n_ifm, n_ofm, 1, stride=1) 41 | ) 42 | 43 | base_channels = [32, 64, 128, 256, 512, 1024] 44 | self.channels = [max(floor(n * channel_multiplier), min_channels) for n in base_channels] 45 | 46 | self.model = nn.Sequential( 47 | 
nn.Sequential(*conv_bn_relu(3, self.channels[0], 3, stride=2, padding=1)), 48 | depthwise_conv(self.channels[0], self.channels[1], 1), 49 | depthwise_conv(self.channels[1], self.channels[2], 2), 50 | depthwise_conv(self.channels[2], self.channels[2], 1), 51 | depthwise_conv(self.channels[2], self.channels[3], 2), 52 | depthwise_conv(self.channels[3], self.channels[3], 1), 53 | depthwise_conv(self.channels[3], self.channels[4], 2), 54 | depthwise_conv(self.channels[4], self.channels[4], 1), 55 | depthwise_conv(self.channels[4], self.channels[4], 1), 56 | depthwise_conv(self.channels[4], self.channels[4], 1), 57 | depthwise_conv(self.channels[4], self.channels[4], 1), 58 | depthwise_conv(self.channels[4], self.channels[4], 1), 59 | depthwise_conv(self.channels[4], self.channels[5], 2), 60 | depthwise_conv(self.channels[5], self.channels[5], 1), 61 | nn.AvgPool2d(7), 62 | ) 63 | self.fc = nn.Linear(self.channels[5], 1000) 64 | 65 | def forward(self, x): 66 | x = self.model(x) 67 | x = x.view(-1, self.channels[-1]) 68 | x = self.fc(x) 69 | return x 70 | 71 | 72 | def mobilenet_025(): 73 | return MobileNet(channel_multiplier=0.25) 74 | 75 | 76 | def mobilenet_050(): 77 | return MobileNet(channel_multiplier=0.5) 78 | 79 | 80 | def mobilenet_075(): 81 | return MobileNet(channel_multiplier=0.75) 82 | 83 | 84 | def mobilenet(): 85 | return MobileNet() 86 | -------------------------------------------------------------------------------- /OCS-CNN/inspect_ckpt.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """A small utility to inspect the contents of checkpoint files. 18 | 19 | Sometimes it is useful to look at the contents of a checkpoint file, and this utility is meant to help 20 | with this. 21 | By default this utility will print just the names and types of the keys it finds in the checkpoint 22 | file. If the value type is simple (i.e. integer, float, or string), then the value is printed as well. 23 | 24 | You can also print the model keys (i.e. the names of the parameter tensors in the model), and the 25 | weight tensor masks in the schedule.
26 | 27 | $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule 28 | """ 29 | import torch 30 | import argparse 31 | from tabulate import tabulate 32 | import sys 33 | import os 34 | script_dir = os.path.dirname(__file__) 35 | module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) 36 | try: 37 | import distiller 38 | except ImportError: 39 | sys.path.append(module_path) 40 | import distiller 41 | 42 | 43 | def inspect_checkpoint(chkpt_file, args): 44 | def inspect_val(val): 45 | if isinstance(val, (int, float, str)): 46 | return val 47 | return None 48 | 49 | print("Inspecting checkpoint file: ", chkpt_file) 50 | checkpoint = torch.load(chkpt_file) 51 | 52 | chkpt_keys = [[k, type(checkpoint[k]).__name__, inspect_val(checkpoint[k])] for k in checkpoint.keys()] 53 | print(tabulate(chkpt_keys, headers=["Key", "Type", "Value"], tablefmt="fancy_grid")) 54 | 55 | if args.model and "state_dict" in checkpoint: 56 | print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys())))) 57 | 58 | if args.schedule and "compression_sched" in checkpoint: 59 | compression_sched = checkpoint["compression_sched"] 60 | print("\nSchedule keys (compression_sched):\n{}\n".format("\n\t".join(list(compression_sched.keys())))) 61 | sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()] 62 | print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid")) 63 | if "masks_dict" in checkpoint["compression_sched"]: 64 | print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join( 65 | list(compression_sched["masks_dict"].keys())))) 66 | 67 | if args.thinning and "thinning_recipes" in checkpoint: 68 | for recipe in checkpoint["thinning_recipes"]: 69 | print(recipe) 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser(description='Distiller checkpoint inspection') 73 | parser.add_argument('chkpt_file', help='path to the checkpoint file') 74 | parser.add_argument('-m', 
'--model', action='store_true', help='print the model keys') 75 | parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys') 76 | parser.add_argument('-t', '--thinning', action='store_true', help='print the thinning keys') 77 | args = parser.parse_args() 78 | inspect_checkpoint(args.chkpt_file, args) 79 | -------------------------------------------------------------------------------- /examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml: -------------------------------------------------------------------------------- 1 | # Fine grained (element-wise) pruning using RNN pruning scheduling for PyTorch's example Word Language model. 2 | # The pruning schedule is based on the following paper from ICLR 2017: 3 | # Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). 4 | # Exploring Sparsity in Recurrent Neural Networks. 5 | # (https://arxiv.org/abs/1704.05119) 6 | # 7 | # The README of PyTorch's word language model example code, promises that this configuration will produce a Test perplexity 8 | # of 72.30, while I was only able to get 84.23, so I use that as the baseline for comparison. 9 | # 10 | # Baseline generation: 11 | # time python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied 12 | # 13 | # Pruning: 14 | # python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied --compress=../../examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml 15 | # 16 | # The Baidu pruner uses a value they refer to as 'q': 17 | # "In order to determine q in equation 1, we use an existing weight array from a previously trained 18 | # model. The weights are sorted using absolute values and we pick the weight corresponding to the 19 | # 90th percentile as q." 
20 | # 21 | # To determine this 'q' value we first train the baseline network (saved in model.emsize1500.nhid1500.dropout065.tied.pt), 22 | # and then extract the statistics: 23 | # python3 main.py --cuda --resume=model.emsize1500.nhid1500.dropout065.tied.pt --summary=percentile 24 | # 25 | # parameter encoder.weight: q = 0.16 26 | # parameter rnn.weight_ih_l0: q = 0.17 27 | # parameter rnn.weight_hh_l0: q = 0.11 28 | # parameter rnn.weight_ih_l1: q = 0.18 29 | # parameter rnn.weight_hh_l1: q = 0.15 30 | # parameter decoder.weight: q = 0.16 31 | # 32 | # To save you time, you can download a pretrained model from here: 33 | # https://s3-us-west-1.amazonaws.com/nndistiller/agp-pruning/word_language_model/model.emsize1500.nhid1500.dropout065.tied.pt 34 | # 35 | 36 | version: 1 37 | pruners: 38 | ih_l0_rnn_pruner: 39 | class: BaiduRNNPruner 40 | q: 0.17 41 | ramp_epoch_offset: 3 42 | ramp_slope_mult: 2 43 | weights: [rnn.weight_ih_l0] 44 | 45 | hh_l0_rnn_pruner: 46 | class: BaiduRNNPruner 47 | q: 0.11 48 | ramp_epoch_offset: 3 49 | ramp_slope_mult: 2 50 | weights: [rnn.weight_hh_l0] 51 | 52 | ih_l1_rnn_pruner: 53 | class: BaiduRNNPruner 54 | q: 0.18 55 | ramp_epoch_offset: 3 56 | ramp_slope_mult: 2 57 | weights: [rnn.weight_ih_l1] 58 | 59 | hh_l1_rnn_pruner: 60 | class: BaiduRNNPruner 61 | q: 0.15 62 | ramp_epoch_offset: 3 63 | ramp_slope_mult: 2 64 | weights: [rnn.weight_hh_l1] 65 | 66 | embedding_pruner: 67 | class: BaiduRNNPruner 68 | q: 0.16 69 | ramp_epoch_offset: 3 70 | ramp_slope_mult: 2 71 | weights: [encoder.weight] 72 | 73 | policies: 74 | - pruner: 75 | instance_name : ih_l0_rnn_pruner 76 | starting_epoch: 4 77 | ending_epoch: 21 78 | frequency: 3 79 | 80 | - pruner: 81 | instance_name : hh_l0_rnn_pruner 82 | starting_epoch: 4 83 | ending_epoch: 21 84 | frequency: 3 85 | 86 | - pruner: 87 | instance_name : ih_l1_rnn_pruner 88 | starting_epoch: 5 89 | ending_epoch: 22 90 | frequency: 3 91 | 92 | - pruner: 93 | instance_name : hh_l1_rnn_pruner 94 | 
starting_epoch: 5 95 | ending_epoch: 22 96 | frequency: 3 97 | 98 | - pruner: 99 | instance_name : embedding_pruner 100 | starting_epoch: 6 101 | ending_epoch: 23 102 | frequency: 3 103 | -------------------------------------------------------------------------------- /examples/classifier_compression/inspect_ckpt.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """A small utility to inspect the contents of checkpoint files. 18 | 19 | Sometimes it is useful to look at the contents of a checkpoint file, and this utility is meant to help 20 | with this. 21 | By default this utility will print just the names and types of the keys it finds in the checkpoint 22 | file. If the value type is simple (i.e. integer, float, or string), then the value is printed as well. 23 | 24 | You can also print the model keys (i.e. the names of the parameter tensors in the model), and the 25 | weight tensor masks in the schedule.
26 | 27 | $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule 28 | """ 29 | import torch 30 | import argparse 31 | from tabulate import tabulate 32 | import sys 33 | import os 34 | script_dir = os.path.dirname(__file__) 35 | module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) 36 | try: 37 | import distiller 38 | except ImportError: 39 | sys.path.append(module_path) 40 | import distiller 41 | 42 | 43 | def inspect_checkpoint(chkpt_file, args): 44 | def inspect_val(val): 45 | if isinstance(val, (int, float, str)): 46 | return val 47 | return None 48 | 49 | print("Inspecting checkpoint file: ", chkpt_file) 50 | checkpoint = torch.load(chkpt_file) 51 | 52 | chkpt_keys = [[k, type(checkpoint[k]).__name__, inspect_val(checkpoint[k])] for k in checkpoint.keys()] 53 | print(tabulate(chkpt_keys, headers=["Key", "Type", "Value"], tablefmt="fancy_grid")) 54 | 55 | if args.model and "state_dict" in checkpoint: 56 | print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys())))) 57 | 58 | if args.schedule and "compression_sched" in checkpoint: 59 | compression_sched = checkpoint["compression_sched"] 60 | print("\nSchedule keys (compression_sched):\n{}\n".format("\n\t".join(list(compression_sched.keys())))) 61 | sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()] 62 | print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid")) 63 | if "masks_dict" in checkpoint["compression_sched"]: 64 | print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join( 65 | list(compression_sched["masks_dict"].keys())))) 66 | 67 | if args.thinning and "thinning_recipes" in checkpoint: 68 | for recipe in checkpoint["thinning_recipes"]: 69 | print(recipe) 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser(description='Distiller checkpoint inspection') 73 | parser.add_argument('chkpt_file', help='path to the checkpoint file') 74 | parser.add_argument('-m', 
'--model', action='store_true', help='print the model keys') 75 | parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys') 76 | parser.add_argument('-t', '--thinning', action='store_true', help='print the thinning keys') 77 | args = parser.parse_args() 78 | inspect_checkpoint(args.chkpt_file, args) 79 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | """This package contains ImageNet and CIFAR image classification models for pytorch""" 18 | 19 | import torch 20 | import torchvision.models as torch_models 21 | import models.cifar10 as cifar10_models 22 | import models.imagenet as imagenet_extra_models 23 | 24 | import logging 25 | msglogger = logging.getLogger() 26 | 27 | IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__ 28 | if name.islower() and not name.startswith("__") 29 | and callable(torch_models.__dict__[name])) 30 | IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__ 31 | if name.islower() and not name.startswith("__") 32 | and callable(imagenet_extra_models.__dict__[name]))) 33 | 34 | CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__ 35 | if name.islower() and not name.startswith("__") 36 | and callable(cifar10_models.__dict__[name])) 37 | 38 | ALL_MODEL_NAMES = sorted(set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)) 39 | 40 | 41 | def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): 42 | """Create a pytorch model based on the model architecture and dataset 43 | 44 | Args: 45 | pretrained: True is you wish to load a pretrained model. Only torchvision models 46 | have a pretrained model. 
47 | dataset: 48 | arch: 49 | parallel: 50 | """ 51 | msglogger.info('==> using %s dataset' % dataset) 52 | 53 | model = None 54 | if dataset == 'imagenet': 55 | str_pretrained = 'pretrained ' if pretrained else '' 56 | msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch)) 57 | assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \ 58 | "Model %s is not supported for dataset %s" % (arch, 'ImageNet') 59 | if arch in torch_models.__dict__: 60 | model = torch_models.__dict__[arch](pretrained=pretrained) 61 | else: 62 | assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch 63 | model = imagenet_extra_models.__dict__[arch]() 64 | elif dataset == 'cifar10': 65 | msglogger.info("=> creating %s model for CIFAR10" % arch) 66 | assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch 67 | assert not pretrained, "Model %s (CIFAR10) does not have a pretrained model" % arch 68 | model = cifar10_models.__dict__[arch]() 69 | else: 70 | print("FATAL ERROR: create_model does not support models for dataset %s" % dataset) 71 | exit() 72 | 73 | if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel: 74 | model.features = torch.nn.DataParallel(model.features, device_ids=device_ids) 75 | elif parallel: 76 | model = torch.nn.DataParallel(model, device_ids=device_ids) 77 | 78 | model.cuda() 79 | return model 80 | -------------------------------------------------------------------------------- /distiller/directives.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Scheduling directives 18 | 19 | Scheduling directives are instructions (directives) that the scheduler can 20 | execute as part of scheduling pruning activities. 21 | """ 22 | from __future__ import division 23 | import torch 24 | import numpy as np 25 | from collections import defaultdict 26 | import logging 27 | msglogger = logging.getLogger() 28 | 29 | from torchnet.meter import AverageValueMeter 30 | from distiller.utils import sparsity, density 31 | 32 | 33 | class FreezeTraining(object): 34 | def __init__(self, name): 35 | print("------FreezeTraining--------") 36 | self.name = name 37 | 38 | def freeze_training(model, which_params, freeze): 39 | """This function will freeze/defrost training for certain layers. 40 | 41 | Sometimes, when we prune and retrain a certain layer type, 42 | we'd like to freeze the training of the other layers. 
43 | """ 44 | for param in model.parameters(): 45 | pname = model_find_param_name(model, param.data) 46 | if pname is None: 47 | continue 48 | for ptype in which_params: 49 | if ptype in pname: 50 | # see: http://pytorch.org/docs/master/notes/autograd.html?highlight=grad_fn 51 | param.requires_grad = not freeze 52 | if freeze: 53 | msglogger.info('Freezing: ' + pname) 54 | else: 55 | msglogger.info('Defrosting: ' + pname) 56 | 57 | 58 | def freeze_all(model, freeze): 59 | msglogger.info('{} all parameters'.format('Freezing' if freeze else 'Defrosting')) 60 | for param in model.parameters(): 61 | param.requires_grad = not freeze 62 | 63 | 64 | def adjust_dropout(module, new_probabilty): 65 | """Replace the dropout probability of dropout layers 66 | 67 | As explained in the paper "Learning both Weights and Connections for 68 | Efficient Neural Networks": 69 | Dropout is widely used to prevent over-fitting, and this also applies to retraining. 70 | During retraining, however, the dropout ratio must be adjusted to account for the 71 | change in model capacity. In dropout, each parameter is probabilistically dropped 72 | during training, but will come back during inference. In pruning, parameters are 73 | dropped forever after pruning and have no chance to come back during both training 74 | and inference. As the parameters get sparse, the classifier will select the most 75 | informative predictors and thus have much less prediction variance, which reduces 76 | over-fitting. As pruning already reduced model capacity, the retraining dropout ratio 77 | should be smaller. 
78 | """ 79 | if type(module) in [torch.nn.Dropout, 80 | torch.nn.Dropout2d, 81 | torch.nn.Dropout3d, 82 | torch.nn.AlphaDropout]: 83 | msglogger.info("Adjusting dropout probability")# for {}".format(str(module))) 84 | module.p = new_probabilty 85 | else: 86 | for child in module.children(): 87 | adjust_dropout(child, new_probabilty) 88 | -------------------------------------------------------------------------------- /models/cifar10/resnet_cifar_earlyexit.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Resnet for CIFAR10 18 | 19 | Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition". 20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate 21 | changes for the 10-class Cifar-10 dataset. 22 | This ResNet also has layer gates, to be able to dynamically remove layers. 
23 | 24 | @inproceedings{DBLP:conf/cvpr/HeZRS16, 25 | author = {Kaiming He and 26 | Xiangyu Zhang and 27 | Shaoqing Ren and 28 | Jian Sun}, 29 | title = {Deep Residual Learning for Image Recognition}, 30 | booktitle = {{CVPR}}, 31 | pages = {770--778}, 32 | publisher = {{IEEE} Computer Society}, 33 | year = {2016} 34 | } 35 | 36 | """ 37 | import torch.nn as nn 38 | import math 39 | import torch.utils.model_zoo as model_zoo 40 | import torchvision.models as models 41 | from .resnet_cifar import BasicBlock 42 | from .resnet_cifar import ResNetCifar 43 | 44 | 45 | __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit', 46 | 'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit'] 47 | 48 | NUM_CLASSES = 10 # CIFAR-10 has 10 classes 49 | 50 | def conv3x3(in_planes, out_planes, stride=1): # NOTE(review): not referenced in this module (BasicBlock is imported from resnet_cifar) 51 | """3x3 convolution with padding""" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | 55 | 56 | class ResNetCifarEarlyExit(ResNetCifar): # ResNetCifar plus one early-exit classifier on the layer1 features 57 | 58 | def __init__(self, block, layers, num_classes=NUM_CLASSES): 59 | super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes) 60 | 61 | # Define early exit layers 62 | self.linear_exit0 = nn.Linear(1600, num_classes) # NOTE(review): 1600 presumably = 16 channels x 10 x 10 (avg_pool2d(x, 3) uses stride 3, so a 32x32 layer1 output becomes 10x10) - confirm against ResNetCifar 63 | 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | 70 | x = self.layer1(x) 71 | 72 | # Early exit: average-pool the layer1 features (3x3 window), flatten, and classify with linear_exit0 73 | exit0 = nn.functional.avg_pool2d(x, 3) 74 | exit0 = exit0.view(exit0.size(0), -1) 75 | exit0 = self.linear_exit0(exit0) 76 | 77 | x = self.layer2(x) 78 | x = self.layer3(x) 79 | 80 | x = self.avgpool(x) 81 | x = x.view(x.size(0), -1) 82 | x = self.fc(x) 83 | 84 | # Return the raw outputs (no softmax is applied here) of every exit: [early exit, final fc] 85 | output = [] 86 | output.append(exit0) 87 | output.append(x) 88 | return output 89 | 90 | 91 | def resnet20_cifar_earlyexit(**kwargs): # depth = 6n + 2 for n blocks per stage (n=3 -> 20) 92 | model = ResNetCifarEarlyExit(BasicBlock, [3, 3, 3], **kwargs) 93 | return model 94 | 95 | def
resnet32_cifar_earlyexit(**kwargs): 96 | model = ResNetCifarEarlyExit(BasicBlock, [5, 5, 5], **kwargs) 97 | return model 98 | 99 | def resnet44_cifar_earlyexit(**kwargs): 100 | model = ResNetCifarEarlyExit(BasicBlock, [7, 7, 7], **kwargs) 101 | return model 102 | 103 | def resnet56_cifar_earlyexit(**kwargs): 104 | model = ResNetCifarEarlyExit(BasicBlock, [9, 9, 9], **kwargs) 105 | return model 106 | 107 | def resnet110_cifar_earlyexit(**kwargs): 108 | model = ResNetCifarEarlyExit(BasicBlock, [18, 18, 18], **kwargs) 109 | return model 110 | 111 | def resnet1202_cifar_earlyexit(**kwargs): 112 | model = ResNetCifarEarlyExit(BasicBlock, [200, 200, 200], **kwargs) 113 | return model -------------------------------------------------------------------------------- /tests/test_ranking.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | import logging 18 | import torch 19 | import os 20 | import sys 21 | try: 22 | import distiller 23 | except ImportError: 24 | module_path = os.path.abspath(os.path.join('..')) 25 | if module_path not in sys.path: 26 | sys.path.append(module_path) 27 | import distiller 28 | import common # common test code 29 | 30 | # Logging configuration: INFO-level root logging to the console, plus a test.log file handler 31 | logging.basicConfig(level=logging.INFO) 32 | fh = logging.FileHandler('test.log') 33 | logger = logging.getLogger() 34 | logger.addHandler(fh) 35 | 36 | 37 | def test_ch_ranking(): # Check rank_channels on a small hand-built weight tensor 38 | # Tensor with shape [3, 2, 2, 2] -- 3 filters, 2 channels, 2x2 kernels 39 | param = torch.tensor([[[[11., 12], # Filter #1 40 | [13, 14]], 41 | 42 | [[15., 16], 43 | [17, 18]]], 44 | # Filter #2 45 | [[[21., 22], 46 | [23, 24]], 47 | 48 | [[25., 26], 49 | [27, 28]]], 50 | # Filter #3 51 | [[[31., 32], 52 | [33, 34]], 53 | 54 | [[35., 36], 55 | [37, 38]]]]) 56 | 57 | fraction_to_prune = 0.5 58 | bottomk_channels, channel_mags = distiller.pruning.L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param) 59 | logger.info("bottom {}% channels: {}".format(fraction_to_prune*100, bottomk_channels)) 60 | assert bottomk_channels == torch.tensor([90.]) # presumably the mean per-filter L1 norm of channel 0: (50 + 90 + 130) / 3 = 90 - confirm against rank_channels 61 | 62 | 63 | def test_ranked_channel_pruning(): # End-to-end: rank-prune input channels of layer1.0.conv1, apply the mask, then thin the model 64 | model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10", parallel=False) 65 | 66 | # Test that we can access the weights tensor of the first convolution in layer 1 67 | conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight") 68 | assert conv1_p is not None 69 | 70 | # Test that there are no zero-channels 71 | assert distiller.sparsity_ch(conv1_p) == 0.0 72 | 73 | # Create a channel-ranking pruner for 10% of the channels of layer1.0.conv1.weight 74 | reg_regims = {"layer1.0.conv1.weight": [0.1, "Channels"]} 75 | pruner = distiller.pruning.L1RankedStructureParameterPruner("channel_pruner", reg_regims) 76 | pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None) 77 | 78 | conv1 = common.find_module_by_name(model, "layer1.0.conv1") 79 |
assert conv1 is not None 80 | 81 | # Test that the mask has the correct fraction of channels pruned. 82 | # We asked for 10%, but there are only 16 channels, so we have to settle for 1/16 channels 83 | logger.info("layer1.0.conv1 = {}".format(conv1)) 84 | expected_pruning = int(0.1 * conv1.in_channels) / conv1.in_channels # needs true division (Python 3): this file has no `from __future__ import division` 85 | assert distiller.sparsity_ch(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning 86 | 87 | # Use the mask to prune 88 | assert distiller.sparsity_ch(conv1_p) == 0 89 | zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p) 90 | assert distiller.sparsity_ch(conv1_p) == expected_pruning 91 | 92 | # Remove channels (and filters) 93 | conv0 = common.find_module_by_name(model, "conv1") 94 | assert conv0 is not None 95 | assert conv0.out_channels == 16 96 | assert conv1.in_channels == 16 97 | 98 | # Test thinning: dropping 1 of 16 channels shrinks conv0's outputs and conv1's inputs to 15 99 | distiller.remove_channels(model, zeros_mask_dict, "resnet20_cifar", "cifar10", optimizer=None) 100 | assert conv0.out_channels == 15 101 | assert conv1.in_channels == 15 102 | -------------------------------------------------------------------------------- /distiller/pruning/baidu_rnn_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | from .level_pruner import SparsityLevelParameterPruner # NOTE(review): not used in this module 19 | from distiller.utils import * 20 | 21 | import distiller 22 | 23 | class BaiduRNNPruner(_ParameterPruner): 24 | """An element-wise pruner for RNN networks. 25 | 26 | Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). 27 | Exploring Sparsity in Recurrent Neural Networks. 28 | (https://arxiv.org/abs/1704.05119) 29 | 30 | This implementation differs slightly from the algorithm in the original paper in that 31 | the paper changes the pruning rate at training-step granularity, while 32 | Distiller controls the pruning rate at epoch granularity. 33 | 34 | Equation (1): 35 | 36 | start_slope = (2 * q * freq) / 37 | (2 * (ramp_itr - start_itr) + 38 | 3 * (end_itr - ramp_itr)) 39 | 40 | 41 | Pruning algorithm (1): 42 | 43 | if current itr < ramp itr then 44 | threshold = start_slope * (current_itr - start_itr + 1) / freq 45 | else 46 | threshold = (start_slope * (ramp_itr - start_itr + 1) + 47 | ramp_slope * (current_itr - ramp_itr + 1)) / freq 48 | end if 49 | 50 | mask = abs(param) < threshold 51 | """ 52 | 53 | def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights): # weights: names of the parameters this pruner applies to 54 | # Initialize the pruner, using a configuration that originates from the 55 | # schedule YAML file. 56 | super(BaiduRNNPruner, self).__init__(name) 57 | self.params_names = weights 58 | assert self.params_names 59 | 60 | # This is the 'q' value that appears in equation (1) of the paper 61 | self.q = q 62 | # This is the number of epochs to wait after starting_epoch, before we 63 | # begin ramping up the pruning rate. 64 | # In other words, between epochs 'starting_epoch' and 'starting_epoch'+ 65 | # self.ramp_epoch_offset the pruning slope is 'self.start_slope'.
After 66 | that, the slope is 'self.ramp_slope' 67 | self.ramp_epoch_offset = ramp_epoch_offset 68 | self.ramp_slope_mult = ramp_slope_mult # ramp_slope = start_slope * ramp_slope_mult (computed in set_param_mask) 69 | self.ramp_slope = None 70 | self.start_slope = None 71 | 72 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): # meta must supply 'starting_epoch', 'current_epoch', 'ending_epoch' and 'frequency' 73 | if param_name not in self.params_names: 74 | return 75 | 76 | starting_epoch = meta['starting_epoch'] 77 | current_epoch = meta['current_epoch'] 78 | ending_epoch = meta['ending_epoch'] 79 | freq = meta['frequency'] 80 | 81 | ramp_epoch = self.ramp_epoch_offset + starting_epoch 82 | 83 | # Calculate start slope 84 | if self.start_slope is None: 85 | # We want to calculate these values only once, and then cache them. 86 | self.start_slope = (2 * self.q * freq) / (2*(ramp_epoch - starting_epoch) + 3*(ending_epoch - ramp_epoch)) # NOTE(review): uses the first call's meta only; raises ZeroDivisionError if the denominator is 0 87 | self.ramp_slope = self.start_slope * self.ramp_slope_mult 88 | 89 | if current_epoch < ramp_epoch: 90 | eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq 91 | else: 92 | eps = (self.start_slope * (ramp_epoch - starting_epoch + 1) + 93 | self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq 94 | 95 | # After computing the threshold (eps), we can create the mask 96 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps) 97 | -------------------------------------------------------------------------------- /apputils/execution_env.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Log information regarding the execution environment. 18 | 19 | This is helpful if you want to recreate an experiment at a later time, or if 20 | you want to understand the environment in which you execute the training. 21 | """ 22 | 23 | import sys 24 | import os 25 | import time 26 | import platform 27 | import logging 28 | import logging.config 29 | import numpy as np 30 | import torch 31 | from git import Repo 32 | HAVE_LSB = True 33 | try: 34 | import lsb_release 35 | except ImportError: 36 | HAVE_LSB = False 37 | 38 | logger = logging.getLogger("app_cfg") 39 | 40 | def log_execution_env_state(app_args, gitroot='.'): 41 | """Log information about the execution environment. 42 | 43 | It is recommeneded to log this information so it can be used for referencing 44 | at a later time. 45 | 46 | Args: 47 | app_args (dict): the command line arguments passed to the application 48 | git_root: the path to the .git root directory 49 | """ 50 | 51 | def log_git_state(): 52 | """Log the state of the git repository. 53 | 54 | It is useful to know what git tag we're using, and if we have outstanding code. 
55 | """ 56 | repo = Repo(gitroot) 57 | assert not repo.bare 58 | 59 | if repo.is_dirty(): 60 | logger.debug("Git is dirty") 61 | try: 62 | branch_name = repo.active_branch.name 63 | except TypeError: 64 | branch_name = "None, Git is in 'detached HEAD' state" 65 | logger.debug("Active Git branch: %s", branch_name) 66 | logger.debug("Git commit: %s" % repo.head.commit.hexsha) 67 | 68 | logger.debug("Number of CPUs: %d", len(os.sched_getaffinity(0))) 69 | logger.debug("Number of GPUs: %d", torch.cuda.device_count()) 70 | logger.debug("CUDA version: %s", torch.version.cuda) 71 | logger.debug("CUDNN version: %s", torch.backends.cudnn.version()) 72 | logger.debug("Kernel: %s", platform.release()) 73 | if HAVE_LSB: 74 | logger.debug("OS: %s", lsb_release.get_lsb_information()['DESCRIPTION']) 75 | logger.debug("Python: %s", sys.version) 76 | logger.debug("PyTorch: %s", torch.__version__) 77 | logger.debug("Numpy: %s", np.__version__) 78 | log_git_state() 79 | logger.debug("App args: %s", app_args) 80 | 81 | 82 | def config_pylogger(log_cfg_file, experiment_name, output_dir='logs'): 83 | """Configure the Python logger. 84 | 85 | For each execution of the application, we'd like to create a unique log directory. 86 | By default this library is named using the date and time of day, to that directories 87 | can be sorted by recency. You can also name yor experiments and prefix the log 88 | directory with this name. This can be useful when accessing experiment data from 89 | TensorBoard, for example. 
90 | """ 91 | timestr = time.strftime("%Y.%m.%d-%H%M%S") 92 | exp_full_name = timestr if experiment_name is None else experiment_name + '___' + timestr 93 | logdir = os.path.join(output_dir, exp_full_name) 94 | if not os.path.exists(logdir): 95 | os.makedirs(logdir) 96 | log_filename = os.path.join(logdir, exp_full_name + '.log') 97 | if os.path.isfile(log_cfg_file): 98 | logging.config.fileConfig(log_cfg_file, defaults={'logfilename': log_filename}) 99 | msglogger = logging.getLogger() 100 | msglogger.logdir = logdir 101 | msglogger.log_filename = log_filename 102 | msglogger.info('Log file for this run: ' + os.path.realpath(log_filename)) 103 | return msglogger 104 | -------------------------------------------------------------------------------- /distiller/pruning/automated_gradual_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | from .level_pruner import SparsityLevelParameterPruner 19 | from .ranked_structures_pruner import L1RankedStructureParameterPruner 20 | from distiller.utils import * 21 | # import logging 22 | # msglogger = logging.getLogger() 23 | 24 | 25 | class AutomatedGradualPruner(_ParameterPruner): 26 | """Prune to an exact pruning level specification. 
27 | 28 | An automated gradual pruning algorithm that prunes the smallest magnitude 29 | weights to achieve a preset level of network sparsity. 30 | 31 | Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the 32 | efficacy of pruning for model compression", 2017 NIPS Workshop on Machine 33 | Learning of Phones and other Consumer Devices, 34 | (https://arxiv.org/pdf/1710.01878.pdf) 35 | """ 36 | 37 | def __init__(self, name, initial_sparsity, final_sparsity, weights, 38 | pruning_fn=None): 39 | super(AutomatedGradualPruner, self).__init__(name) 40 | self.initial_sparsity = initial_sparsity 41 | self.final_sparsity = final_sparsity 42 | assert final_sparsity > initial_sparsity 43 | self.params_names = weights 44 | assert self.params_names 45 | if pruning_fn is None: 46 | self.pruning_fn = self.prune_to_target_sparsity 47 | else: 48 | self.pruning_fn = pruning_fn 49 | 50 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 51 | if param_name not in self.params_names: 52 | return 53 | 54 | starting_epoch = meta['starting_epoch'] 55 | current_epoch = meta['current_epoch'] 56 | ending_epoch = meta['ending_epoch'] 57 | freq = meta['frequency'] 58 | span = ((ending_epoch - starting_epoch - 1) // freq) * freq 59 | assert span > 0 60 | 61 | target_sparsity = (self.final_sparsity + 62 | (self.initial_sparsity-self.final_sparsity) * 63 | (1.0 - ((current_epoch-starting_epoch)/span))**3) 64 | self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity) 65 | 66 | @staticmethod 67 | def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity): 68 | return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity) 69 | 70 | 71 | class StructuredAutomatedGradualPruner(AutomatedGradualPruner): 72 | def __init__(self, name, initial_sparsity, final_sparsity, reg_regims): 73 | self.reg_regims = reg_regims 74 | weights = [weight for weight in reg_regims.keys()] 75 | if not all([group in 
['3D', 'Filters', 'Channels'] for group in reg_regims.values()]): 76 | raise ValueError("Currently only filter (3D) and channel pruning is supported") 77 | super(StructuredAutomatedGradualPruner, self).__init__(name, initial_sparsity, 78 | final_sparsity, weights, 79 | pruning_fn=self.prune_to_target_sparsity) 80 | 81 | def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity): 82 | if self.reg_regims[param_name] in ['3D', 'Filters']: 83 | L1RankedStructureParameterPruner.rank_prune_filters(target_sparsity, param, 84 | param_name, zeros_mask_dict) 85 | else: 86 | if self.reg_regims[param_name] == 'Channels': 87 | L1RankedStructureParameterPruner.rank_prune_channels(target_sparsity, param, 88 | param_name, zeros_mask_dict) 89 | -------------------------------------------------------------------------------- /models/imagenet/alexnet_batchnorm.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ 18 | AlexNet model with batch-norm layers. 19 | Model configuration based on the AlexNet DoReFa example in TensorPack: 20 | https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py 21 | 22 | Code based on the AlexNet PyTorch sample, with the required changes. 
23 | """ 24 | 25 | import math 26 | import torch.nn as nn 27 | 28 | __all__ = ['AlexNetBN', 'alexnet_bn'] 29 | 30 | 31 | class AlexNetBN(nn.Module): # AlexNet with BatchNorm after conv1-conv4 and fc0/fc1 (conv0 and fct have none) 32 | 33 | def __init__(self, num_classes=1000): # num_classes sets the output size of the final 'fct' Linear layer 34 | super(AlexNetBN, self).__init__() 35 | self.features = nn.Sequential( 36 | nn.Conv2d(3, 96, kernel_size=12, stride=4), # conv0 (224x224x3) -> (54x54x96) 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2, bias=False), # conv1 (54x54x96) -> (54x54x256) 39 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn1 (54x54x256). NOTE(review): PyTorch's momentum is the weight given to the *new* batch statistic (default 0.1), whereas TensorPack's momentum=0.9 is a decay factor that corresponds to 0.1 here - confirm 0.9 is intended (same for every BatchNorm below) 40 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # pool1 (54x54x256) -> (27x27x256) 41 | nn.ReLU(inplace=True), 42 | 43 | nn.Conv2d(256, 384, kernel_size=3, padding=1, bias=False), # conv2 (27x27x256) -> (27x27x384) 44 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn2 (27x27x384) 45 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # pool2 (27x27x384) -> (14x14x384) 46 | nn.ReLU(inplace=True), 47 | 48 | nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False), # conv3 (14x14x384) -> (14x14x384) 49 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn3 (14x14x384) 50 | nn.ReLU(inplace=True), 51 | 52 | nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False), # conv4 (14x14x384) -> (14x14x256) 53 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn4 (14x14x256) 54 | nn.MaxPool2d(kernel_size=3, stride=2), # pool4 (14x14x256) -> (6x6x256) 55 | nn.ReLU(inplace=True), 56 | ) 57 | self.classifier = nn.Sequential( 58 | nn.Linear(256 * 6 * 6, 4096, bias=False), # fc0 59 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc0 60 | nn.ReLU(inplace=True), 61 | nn.Linear(4096, 4096, bias=False), # fc1 62 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc1 63 | nn.ReLU(inplace=True), 64 | nn.Linear(4096, num_classes), # fct 65 | ) 66 | 67 | for m in self.modules(): # Init: conv/linear weights ~ N(mean=0, std=sqrt(2 / n)) with n = fan_in * kernel area, biases 0; BN weight 1, bias 0 68 | if isinstance(m, (nn.Conv2d, nn.Linear)): 69 | fan_in, k_size = (m.in_channels, m.kernel_size[0] * m.kernel_size[1]) if
isinstance(m, nn.Conv2d) \ 70 | else (m.in_features, 1) 71 | n = k_size * fan_in # NOTE(review): for grouped convs (groups=2 above) each filter sees only in_channels / groups inputs, so this n is groups times the true fan-in - confirm intended 72 | m.weight.data.normal_(0, math.sqrt(2. / n)) 73 | if hasattr(m, 'bias') and m.bias is not None: 74 | m.bias.data.fill_(0) 75 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 76 | m.weight.data.fill_(1) 77 | m.bias.data.zero_() 78 | 79 | def forward(self, x): 80 | x = self.features(x) 81 | x = x.view(x.size(0), 256 * 6 * 6) # flatten the features to (batch, 256 * 6 * 6) for the classifier 82 | x = self.classifier(x) 83 | return x 84 | 85 | 86 | def alexnet_bn(**kwargs): 87 | r"""AlexNet model with batch-norm layers; keyword arguments are forwarded to AlexNetBN. 88 | Model configuration based on the AlexNet DoReFa example in `TensorPack 89 | <https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py>`_ 90 | """ 91 | model = AlexNetBN(**kwargs) 92 | return model 93 | -------------------------------------------------------------------------------- /jupyter/interactive_lr_scheduler.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Learning rate decay scheduling\n", 8 | "\n", 9 | "This notebook is not specific to Distiller.\n", 10 | "\n", 11 | "When fine-tuning or training a model, you may want to try different LR-decay policies. This notebook shows how the different policies work."
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from ipywidgets import widgets, interact\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import torch\n", 23 | "from torch.optim.lr_scheduler import *" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import torchvision\n", 33 | "model = torchvision.models.alexnet(pretrained=True).cuda()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "@interact(first_epoch=(0,100), last_epoch=(1,100), step_size=(1, 30, 1), gamma=(0, 1, 0.05), lr='0.001', T_max=(1,10),\n", 43 | " enable_steplr=True, \n", 44 | " enable_explr=True,\n", 45 | " enable_cosinelr=False,\n", 46 | " enable_multisteplr=True)\n", 47 | "def draw_schedules(first_epoch=0, last_epoch=50, step_size=3, gamma=0.9, lr=0.001, T_max=1, \n", 48 | " enable_steplr=True,\n", 49 | " enable_explr=True,\n", 50 | " enable_cosinelr=False,\n", 51 | " enable_multisteplr=True):\n", 52 | " lr = float(lr)\n", 53 | " optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=0.0001)\n", 54 | "\n", 55 | " schedulers = {}\n", 56 | " if enable_explr:\n", 57 | " schedulers['ExponentialLR'] = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)\n", 58 | " if enable_steplr:\n", 59 | " schedulers['StepLR'] = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma)\n", 60 | " if enable_cosinelr:\n", 61 | " schedulers['CosineAnnealingLR'] = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)\n", 62 | " if enable_multisteplr:\n", 63 | " schedulers['MultiStepLR'] = MultiStepLR(optimizer, milestones=[30,80], gamma=gamma)\n", 64 | " \n", 65 | " epochs = []\n", 66 | " lr_values = {}\n", 67 | " for name in schedulers.keys():\n", 68 | " lr_values[name] = []\n", 69 | "\n", 70 | " for 
epoch in range(first_epoch, last_epoch):\n", 71 | " epochs.append(epoch)\n", 72 | " for name, scheduler in schedulers.items():\n", 73 | " scheduler.step(epoch)\n", 74 | " lr = scheduler.get_lr()\n", 75 | " lr_values[name].append(lr) \n", 76 | "\n", 77 | " for name in schedulers.keys():\n", 78 | " plt.plot(epochs, lr_values[name])\n", 79 | " plt.ylabel('LR')\n", 80 | " plt.xlabel('epoch')\n", 81 | " plt.title('Learning Rate Schedulers')\n", 82 | " plt.show()\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "# References\n", 90 | "\n", 91 | " 1.
**http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html**\n", 92 | " 2.
**http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate** " 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "Python 3", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.5.2" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 2 124 | } 125 | -------------------------------------------------------------------------------- /distiller/data_loggers/collector.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import torch 18 | from distiller.utils import sparsity 19 | from torchnet.meter import AverageValueMeter 20 | import logging 21 | msglogger = logging.getLogger() 22 | 23 | __all__ = ['ActivationSparsityCollector'] 24 | # Base class of the data collectors in this module. 25 | class DataCollector(object): 26 | def __init__(self): 27 | pass 28 | 29 | 30 | class ActivationSparsityCollector(DataCollector): 31 | """Collect model activation sparsity information. 32 | 33 | CNN models with ReLU layers exhibit sparse activations. 34 | ActivationSparsityCollector will collect information about this sparsity. 35 | Currently we only record the mean sparsity of the activations, but this can be expanded 36 | to collect std and other statistics. 37 | 38 | The current activation sparsity collection implementation has a few caveats: 39 | * It is slow 40 | * It can't access the activations of torch.Functions, only torch.Modules. 41 | * The layer names are mangled 42 | 43 | ActivationSparsityCollector uses the forward hook of modules in order to access the 44 | feature-maps. This is both slow and limits us to seeing only the outputs of torch.Modules. 45 | We can remove some of the slowness by choosing to log only specific layers. By default, 46 | we only log torch.nn.ReLU activations. 47 | 48 | The layer names are mangled, because torch.Modules don't have names and we need to invent 49 | a unique name per layer.
50 | """ 51 | def __init__(self, model, classes=None): 52 | """Since only specific layers produce sparse feature-maps, the 53 | ActivationSparsityCollector constructor accepts an optional list of layer classes to log (default: [torch.nn.ReLU]).""" 54 | 55 | super(ActivationSparsityCollector, self).__init__() 56 | self.model = model 57 | self.classes = [torch.nn.ReLU] if classes is None else classes # fresh default list per instance (no shared mutable default) 58 | self._init_activations_sparsity(model) 59 | 60 | def value(self): 61 | """Return a dictionary mapping each hooked module's sparsity-meter name to its mean activation sparsity""" 62 | activation_sparsity = {} 63 | self._collect_activations_sparsity(self.model, activation_sparsity) 64 | return activation_sparsity 65 | 66 | # Recursively walk 'module' and attach a sparsity meter + forward hook to every leaf module whose type is in self.classes. 67 | def _init_activations_sparsity(self, module, name=''): 68 | def __activation_sparsity_cb(module, input, output): 69 | """Record the activation sparsity of 'module' 70 | 71 | This is a callback from the forward() of 'module'. 72 | """ 73 | module.sparsity.add(sparsity(output.data)) 74 | 75 | has_children = False 76 | for name, sub_module in module._modules.items(): 77 | self._init_activations_sparsity(sub_module, name) 78 | has_children = True 79 | if not has_children: # only leaf modules are hooked; 'name' is still the argument here because the loop did not run 80 | if type(module) in self.classes: # exact type match: subclasses of the listed classes are not hooked 81 | module.register_forward_hook(__activation_sparsity_cb) 82 | module.sparsity = AverageValueMeter() 83 | if hasattr(module, 'ref_name'): 84 | module.sparsity.name = 'sparsity_' + module.ref_name 85 | else: 86 | module.sparsity.name = 'sparsity_' + name + '_' + module.__class__.__name__ + '_' + str(id(module)) 87 | # Recursively gather {meter name: mean sparsity} from every module that carries a 'sparsity' meter. 88 | @staticmethod 89 | def _collect_activations_sparsity(model, activation_sparsity, name=''): 90 | for name, module in model._modules.items(): 91 | ActivationSparsityCollector._collect_activations_sparsity(module, activation_sparsity, name) 92 | 93 | if hasattr(model, 'sparsity'): 94 | activation_sparsity[model.sparsity.name] = model.sparsity.mean 95 | 96 | # Dict-backed collector: attribute writes and reads are redirected to the _stats dict. 97 | class TrainingProgressCollector(DataCollector): 98 | def __init__(self, stats=None): 99 | super(TrainingProgressCollector, self).__init__() 100 | object.__setattr__(self, '_stats', {} if stats is None else stats) # bypass our own __setattr__, which writes into _stats 101 | 102 | def
__setattr__(self, name, value): 103 | stats = self.__dict__.get('_stats') 104 | stats[name] = value 105 | 106 | def __getattr__(self, name): 107 | if name in self.__dict__.get('_stats', ()): # .get: raise AttributeError (not KeyError) on an instance created without __init__ 108 | return self.__dict__['_stats'][name] 109 | raise AttributeError("'{}' object has no attribute '{}'".format( 110 | type(self).__name__, name)) 111 | 112 | def value(self): 113 | return self._stats 114 | -------------------------------------------------------------------------------- /jupyter/parameter_histograms.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Parameter Histograms\n", 8 | "\n", 9 | "This notebook loads a model and draws the histograms of the parameters tensors." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import torchvision\n", 20 | "import torch.nn as nn\n", 21 | "from torch.autograd import Variable\n", 22 | "import scipy.stats as ss\n", 23 | "\n", 24 | "# Relative import of code from distiller, w/o installing the package\n", 25 | "import os\n", 26 | "import sys\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "module_path = os.path.abspath(os.path.join('..'))\n", 30 | "if module_path not in sys.path:\n", 31 | " sys.path.append(module_path)\n", 32 | "\n", 33 | "import distiller\n", 34 | "import models\n", 35 | "from apputils import *\n", 36 | "\n", 37 | "plt.style.use('seaborn') # pretty matplotlib plots" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Load your model" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "scrolled": false 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# It is interesting to compare the distribution of non-pretrained model
(Normally-distributed)\n", 56 | "# vs. the distribution of the pretrained model.\n", 57 | "model = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=True)\n", 58 | "\n", 59 | "# Optionally load your compressed model \n", 60 | "# load_checkpoint(model, );" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Plot the distributions\n", 68 | "\n", 69 | "We plot the distributions of the weights of each convolution layer, and we also plot the fitted Gaussian and Laplacian distributions." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "scrolled": false 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "def flatten(weights):\n", 81 | " weights = weights.view(weights.numel())\n", 82 | " weights = weights.data.cpu().numpy()\n", 83 | " return weights\n", 84 | "\n", 85 | "REMOVE_ZEROS = False\n", 86 | "nbins = 500\n", 87 | "for name, weights in model.named_parameters():\n", 88 | " if weights.dim() == 4:\n", 89 | " size_str = \"x\".join([str(s) for s in weights.size()])\n", 90 | " weights = flatten(weights)\n", 91 | " \n", 92 | " if REMOVE_ZEROS:\n", 93 | " # Optionally remove zeros (lots of zeros will dominate the histogram and the \n", 94 | " # other data will be hard to see)\n", 95 | " weights = weights[weights!=0]\n", 96 | " \n", 97 | " # Fit the data to the Normal distribution\n", 98 | " (mean_fitted, std_fitted) = ss.norm.fit(weights)\n", 99 | " x = np.linspace(weights.min(), weights.max(), nbins)\n", 100 | " weights_gauss_fitted = ss.norm.pdf(x, loc=mean_fitted, scale=std_fitted)\n", 101 | "\n", 102 | " # Fit the data to the Laplacian distribution (laplace.fit returns loc and scale)\n", 103 | " (loc_fitted, scale_fitted) = ss.laplace.fit(weights)\n", 104 | " weights_laplace_fitted = ss.laplace.pdf(x, loc=loc_fitted, scale=scale_fitted)\n", 105 | "\n", 106 | " n, bins, patches = plt.hist(weights, histtype='stepfilled', \n", 107 | " cumulative=False, bins=nbins, density=True)\n", 108
| " plt.plot(x, weights_gauss_fitted, label='gauss')\n", 109 | " plt.plot(x, weights_laplace_fitted, label='laplace')\n", 110 | " plt.title(name + \" - \" +size_str)\n", 111 | " #plt.figure(figsize=(10,5))\n", 112 | " plt.legend()\n", 113 | " plt.show()" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "Python 3", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.5.2" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 2 145 | } 146 | -------------------------------------------------------------------------------- /models/cifar10/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """VGG for CIFAR10 18 | 19 | VGG for CIFAR10, based on "Very Deep Convolutional Networks for Large-Scale 20 | Image Recognition".
21 | This is based on TorchVision's implementation of VGG for ImageNet, with 22 | appropriate changes for the 10-class Cifar-10 dataset. 23 | We replaced the three linear classifiers with a single one. 24 | """ 25 | 26 | import torch.nn as nn 27 | 28 | __all__ = [ 29 | 'VGGCifar', 'vgg11_cifar', 'vgg11_bn_cifar', 'vgg13_cifar', 'vgg13_bn_cifar', 'vgg16_cifar', 'vgg16_bn_cifar', 30 | 'vgg19_bn_cifar', 'vgg19_cifar', 31 | ] 32 | 33 | 34 | class VGGCifar(nn.Module): 35 | def __init__(self, features, num_classes=10, init_weights=True): 36 | super(VGGCifar, self).__init__() 37 | self.features = features 38 | self.classifier = nn.Linear(512, num_classes) 39 | if init_weights: 40 | self._initialize_weights() 41 | 42 | def forward(self, x): 43 | x = self.features(x) 44 | x = x.view(x.size(0), -1) 45 | x = self.classifier(x) 46 | return x 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | elif isinstance(m, nn.BatchNorm2d): 55 | nn.init.constant_(m.weight, 1) 56 | nn.init.constant_(m.bias, 0) 57 | elif isinstance(m, nn.Linear): 58 | nn.init.normal_(m.weight, 0, 0.01) 59 | nn.init.constant_(m.bias, 0) 60 | 61 | 62 | def make_layers(cfg, batch_norm=False): 63 | layers = [] 64 | in_channels = 3 65 | for v in cfg: 66 | if v == 'M': 67 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 68 | else: 69 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 70 | if batch_norm: 71 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 72 | else: 73 | layers += [conv2d, nn.ReLU(inplace=True)] 74 | in_channels = v 75 | return nn.Sequential(*layers) 76 | 77 | 78 | cfg = { 79 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 80 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 81 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 
256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 82 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 83 | } 84 | 85 | 86 | def vgg11_cifar(**kwargs): 87 | """VGG 11-layer model (configuration "A")""" 88 | model = VGGCifar(make_layers(cfg['A']), **kwargs) 89 | return model 90 | 91 | 92 | def vgg11_bn_cifar(**kwargs): 93 | """VGG 11-layer model (configuration "A") with batch normalization""" 94 | model = VGGCifar(make_layers(cfg['A'], batch_norm=True), **kwargs) 95 | return model 96 | 97 | 98 | def vgg13_cifar(**kwargs): 99 | """VGG 13-layer model (configuration "B")""" 100 | model = VGGCifar(make_layers(cfg['B']), **kwargs) 101 | return model 102 | 103 | 104 | def vgg13_bn_cifar(**kwargs): 105 | """VGG 13-layer model (configuration "B") with batch normalization""" 106 | model = VGGCifar(make_layers(cfg['B'], batch_norm=True), **kwargs) 107 | return model 108 | 109 | 110 | def vgg16_cifar(**kwargs): 111 | """VGG 16-layer model (configuration "D") 112 | """ 113 | model = VGGCifar(make_layers(cfg['D']), **kwargs) 114 | return model 115 | 116 | 117 | def vgg16_bn_cifar(**kwargs): 118 | """VGG 16-layer model (configuration "D") with batch normalization""" 119 | model = VGGCifar(make_layers(cfg['D'], batch_norm=True), **kwargs) 120 | return model 121 | 122 | 123 | def vgg19_cifar(**kwargs): 124 | """VGG 19-layer model (configuration "E") 125 | """ 126 | model = VGGCifar(make_layers(cfg['E']), **kwargs) 127 | return model 128 | 129 | 130 | def vgg19_bn_cifar(**kwargs): 131 | """VGG 19-layer model (configuration 'E') with batch normalization""" 132 | model = VGGCifar(make_layers(cfg['E'], batch_norm=True), **kwargs) 133 | return model 134 | -------------------------------------------------------------------------------- /distiller/pruning/ranked_structures_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | 
# Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | import torch 19 | import distiller 20 | from .pruner import _ParameterPruner 21 | msglogger = logging.getLogger() 22 | 23 | 24 | # TODO: support different policies for ranking structures 25 | class L1RankedStructureParameterPruner(_ParameterPruner): 26 | """Rank structures (filters or channels) by their L1 magnitude and prune a specified fraction of them. 27 | reg_regims maps a parameter name to [fraction_to_prune, group_type]; see set_param_mask for the supported group types. """ 28 | def __init__(self, name, reg_regims): 29 | super(L1RankedStructureParameterPruner, self).__init__(name) 30 | self.name = name 31 | self.reg_regims = reg_regims 32 | 33 | # Mask 'param' according to its regime; params without a regime, or with a zero fraction, are left untouched. 34 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 35 | if param_name not in self.reg_regims: 36 | return 37 | 38 | group_type = self.reg_regims[param_name][1] 39 | fraction_to_prune = self.reg_regims[param_name][0] 40 | if fraction_to_prune == 0: 41 | return 42 | 43 | if group_type in ['3D', 'Filters']: 44 | return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict) 45 | elif group_type == 'Channels': 46 | return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict) 47 | else: 48 | raise ValueError("Currently only filter (3D) and channel ranking is supported") 49 | # Return (bottomk, channel_mags): the k smallest mean-kernel-L1 channel magnitudes (ascending) and all channel magnitudes, or (None, None) when k == 0. 50 | @staticmethod 51 | def rank_channels(fraction_to_prune, param): 52 | num_filters = param.size(0) 53 | num_channels = param.size(1) 54 | kernel_size = param.size(2) * param.size(3) 55 |
56 | # First, reshape the weights tensor such that each kernel (the 2D slice of one filter/channel pair) in the original 57 | # tensor, is now a row in the 2D tensor. 58 | view_2d = param.data.view(-1, kernel_size) # .data: ranking needs no autograd graph 59 | # Next, compute the sums of each kernel 60 | kernel_sums = view_2d.abs().sum(dim=1) 61 | # Now group by channels 62 | k_sums_mat = kernel_sums.view(num_filters, num_channels).t() 63 | channel_mags = k_sums_mat.mean(dim=1) 64 | k = int(fraction_to_prune * channel_mags.size(0)) 65 | if k == 0: 66 | msglogger.info("Too few channels (%d)- can't prune %.1f%% channels", 67 | num_channels, 100*fraction_to_prune) 68 | return None, None 69 | 70 | bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True) 71 | return bottomk, channel_mags 72 | 73 | @staticmethod 74 | def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict): 75 | bottomk_channels, channel_mags = L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param) 76 | if bottomk_channels is None: 77 | # None means that fraction_to_prune is too low to prune anything 78 | return 79 | 80 | num_filters = param.size(0) 81 | num_channels = param.size(1) 82 | 83 | threshold = bottomk_channels[-1] # largest of the k smallest magnitudes 84 | binary_map = channel_mags.gt(threshold).type(param.data.type()) # keeps only channels > threshold, so ties can prune more than k 85 | a = binary_map.expand(num_filters, num_channels) 86 | c = a.unsqueeze(-1) 87 | d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous() 88 | zeros_mask_dict[param_name].mask = d.view(num_filters, num_channels, param.size(2), param.size(3)) 89 | 90 | msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, 91 | distiller.sparsity_ch(zeros_mask_dict[param_name].mask), 92 | fraction_to_prune, len(bottomk_channels), num_channels) 93 | 94 | @staticmethod 95 | def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict): 96 | assert param.dim() == 4, "This thresholding is only supported for 4D weights" 97 | view_filters = param.view(param.size(0), -1)
98 | filter_mags = view_filters.data.norm(1, dim=1) # same as view_filters.data.abs().sum(dim=1) 99 | topk_filters = int(fraction_to_prune * filter_mags.size(0)) 100 | if topk_filters == 0: 101 | msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) 102 | return 103 | 104 | bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True) 105 | threshold = bottomk[-1] 106 | binary_map = filter_mags.gt(threshold).type(param.data.type()) # keeps only filters > threshold, so ties can prune more than topk_filters 107 | expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() 108 | zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3)) 109 | msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, 110 | distiller.sparsity(zeros_mask_dict[param_name].mask), 111 | fraction_to_prune, topk_filters, filter_mags.size(0)) 112 | -------------------------------------------------------------------------------- /examples/sensitivity-pruning/alexnet.schedule_sensitivity.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # This schedule is an example of "Iterative Pruning" for Alexnet/ImageNet, as 3 | # described in chapter 3 of Song Han's PhD dissertation: "EFFICIENT METHODS AND 4 | # HARDWARE FOR DEEP LEARNING" 5 | # 6 | # The pruning policy uses multiple pruning phases. Each pruning phase is 7 | # followed by a retraining phase. 8 | # In this particular policy, pruning is scheduled every 2 epochs. 9 | # After 38/2 pruning phases, pruning ends and only retraining continues.
10 | # 11 | # time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j=44 --epochs=90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml 12 | # 13 | # Parameters: 14 | # 15 | # +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ 16 | # | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | 17 | # |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| 18 | # | 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13373 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.43716 | 0.14381 | -0.00002 | 0.08794 | 19 | # | 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115322 | 0.00000 | 0.00000 | 0.00000 | 2.04264 | 0.00000 | 62.46029 | 0.04702 | -0.00248 | 0.02286 | 20 | # | 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256454 | 0.00000 | 0.00000 | 0.00000 | 6.13742 | 0.00000 | 61.35133 | 0.03354 | -0.00184 | 0.01803 | 21 | # | 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315278 | 0.00000 | 0.00000 | 0.00000 | 7.02922 | 0.00000 | 64.36474 | 0.02647 | -0.00168 | 0.01423 | 22 | # | 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186861 | 0.00000 | 0.00000 | 0.00000 | 15.72266 | 0.00000 | 68.31919 | 0.02714 | -0.00245 | 0.01408 | 23 | # | 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3395124 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 91.00599 | 0.00589 | -0.00020 | 0.00168 | 24 | # | 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1783541 | 0.21973 | 3.49121 | 0.00000 | 3.49121 | 0.00000 | 89.36927 | 0.00849 | -0.00066 | 0.00263 | 25 | # | 7 | classifier.6.weight | (1000, 4096) | 4096000 
| 993134 | 3.39355 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.75356 | 0.01718 | 0.00029 | 0.00777 | 26 | # | 8 | Total sparsity: | - | 61090496 | 7059087 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.44487 | 0.00000 | 0.00000 | 0.00000 | 27 | # +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ 28 | # Total sparsity: 88.44 29 | # 30 | # --- validate (epoch=89)----------- 31 | # 128116 samples (256 per mini-batch) 32 | # Epoch: [89][ 50/ 500] Loss 2.149753 Top1 51.976562 Top5 74.859375 33 | # Epoch: [89][ 100/ 500] Loss 2.154934 Top1 51.941406 Top5 74.550781 34 | # Epoch: [89][ 150/ 500] Loss 2.159868 Top1 51.880208 Top5 74.513021 35 | # Epoch: [89][ 200/ 500] Loss 2.158245 Top1 51.875000 Top5 74.597656 36 | # Epoch: [89][ 250/ 500] Loss 2.150266 Top1 51.920313 Top5 74.667187 37 | # Epoch: [89][ 300/ 500] Loss 2.152199 Top1 51.933594 Top5 74.682292 38 | # Epoch: [89][ 350/ 500] Loss 2.152126 Top1 51.952009 Top5 74.684152 39 | # Epoch: [89][ 400/ 500] Loss 2.153599 Top1 51.949219 Top5 74.648438 40 | # Epoch: [89][ 450/ 500] Loss 2.151281 Top1 52.046875 Top5 74.703993 41 | # Epoch: [89][ 500/ 500] Loss 2.149620 Top1 52.032031 Top5 74.765625 42 | # ==> Top1: 52.029 Top5: 74.767 Loss: 2.150 43 | # 44 | # Saving checkpoint 45 | # --- test --------------------- 46 | # 50000 samples (256 per mini-batch) 47 | # Test: [ 50/ 195] Loss 1.484814 Top1 63.328125 Top5 85.820312 48 | # Test: [ 100/ 195] Loss 1.636993 Top1 60.835938 Top5 83.617188 49 | # Test: [ 150/ 195] Loss 1.832027 Top1 57.713542 Top5 80.330729 50 | # ==> Top1: 56.762 Top5: 79.340 Loss: 1.892 51 | # 52 | # 53 | # Log file for this run: /data/home/cvds_lab/nzmora/private-distiller/examples/classifier_compression/logs/2018.04.08-154509/2018.04.08-154509.log 54 | # 55 | # real 646m54.061s 56 | # user 14899m29.068s 57 | # sys 1901m19.958s 58 | 59 | version: 1 60 | 
pruners: 61 | pruner1: 62 | class: 'SensitivityPruner' 63 | sensitivities: 64 | 'features.module.0.weight': 0.25 65 | 'features.module.3.weight': 0.35 66 | 'features.module.6.weight': 0.40 67 | 'features.module.8.weight': 0.45 68 | 'features.module.10.weight': 0.55 69 | 'classifier.1.weight': 0.875 70 | 'classifier.4.weight': 0.875 71 | 'classifier.6.weight': 0.625 72 | 73 | lr_schedulers: 74 | pruning_lr: 75 | class: ExponentialLR 76 | gamma: 0.9 77 | 78 | policies: 79 | - pruner: 80 | instance_name : 'pruner1' 81 | starting_epoch: 0 82 | ending_epoch: 38 83 | frequency: 2 84 | 85 | - lr_scheduler: 86 | instance_name: pruning_lr 87 | starting_epoch: 24 88 | ending_epoch: 200 89 | frequency: 1 90 | -------------------------------------------------------------------------------- /apputils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ Helper code for checkpointing models, with support for saving the pruning schedule. 18 | 19 | Adding the schedule information in the model checkpoint is helpful in resuming 20 | a pruning session, or for querying the pruning schedule of a sparse model. 
21 | """ 22 | 23 | import os 24 | import shutil 25 | from errno import ENOENT 26 | import logging 27 | import torch 28 | import distiller 29 | msglogger = logging.getLogger() 30 | 31 | 32 | def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None, 33 | best_top1=None, is_best=False, name=None, dir='.'): 34 | """Save a pytorch training checkpoint 35 | 36 | Args: 37 | epoch: current epoch 38 | arch: name of the network architecture/topology 39 | model: a pytorch model 40 | optimizer: the optimizer used in the training session 41 | scheduler: the CompressionScheduler instance used for training, if any 42 | best_top1: the best top1 score seen so far 43 | is_best: True if this is the best (top1 accuracy) model so far; the checkpoint is then also copied to the 'best' file 44 | name: optional prefix of the checkpoint file names ('<name>_checkpoint.pth.tar' / '<name>_best.pth.tar') 45 | dir: directory in which to save the checkpoint (IOError is raised if it does not exist) 46 | """ 47 | if not os.path.isdir(dir): 48 | raise IOError(ENOENT, 'Checkpoint directory does not exist at', os.path.abspath(dir)) 49 | 50 | filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar' 51 | fullpath = os.path.join(dir, filename) 52 | msglogger.info("Saving checkpoint to: %s", fullpath) 53 | filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar' 54 | fullpath_best = os.path.join(dir, filename_best) 55 | checkpoint = {} 56 | checkpoint['epoch'] = epoch 57 | checkpoint['arch'] = arch 58 | checkpoint['state_dict'] = model.state_dict() 59 | if best_top1 is not None: 60 | checkpoint['best_top1'] = best_top1 61 | if optimizer is not None: 62 | checkpoint['optimizer'] = optimizer.state_dict() 63 | if scheduler is not None: 64 | checkpoint['compression_sched'] = scheduler.state_dict() 65 | if hasattr(model, 'thinning_recipes'): 66 | checkpoint['thinning_recipes'] = model.thinning_recipes 67 | if hasattr(model, 'quantizer_metadata'): 68 | checkpoint['quantizer_metadata'] = model.quantizer_metadata 69 | 70 | torch.save(checkpoint, fullpath) 71 | if is_best: 72 | shutil.copyfile(fullpath, fullpath_best)
73 | 74 | 75 | def load_checkpoint(model, chkpt_file, optimizer=None): 76 | """Load a pytorch training checkpoint 77 | 78 | Args: 79 | model: the pytorch model to which we will load the parameters 80 | chkpt_file: the checkpoint file 81 | optimizer: the optimizer to which we will load the serialized state; the state is 82 | loaded only if the checkpoint contains an 'optimizer' entry 83 | Returns: 84 | a (model, compression_scheduler, start_epoch) tuple. compression_scheduler is None 85 | if the checkpoint has no 'compression_sched' entry; start_epoch is the saved epoch + 1 86 | Raises: 87 | IOError: if chkpt_file is not an existing file 88 | KeyError: if the checkpoint has 'thinning_recipes' but no 'compression_sched' 89 | """ 90 | compression_scheduler = None 91 | start_epoch = 0 92 | 93 | if os.path.isfile(chkpt_file): 94 | msglogger.info("=> loading checkpoint %s", chkpt_file) 95 | # NOTE: torch.load unpickles arbitrary objects (qmd['type'] below is instantiated straight 96 | # from the file), so only load checkpoints that come from a trusted source. 97 | checkpoint = torch.load(chkpt_file) 98 | msglogger.info("Checkpoint keys:\n\t%s", "\n\t".join(checkpoint.keys())) 99 | start_epoch = checkpoint['epoch'] + 1 100 | best_top1 = checkpoint.get('best_top1', None) 101 | if best_top1 is not None: 102 | msglogger.info(" best top@1: %.3f", best_top1) 103 | 104 | if 'compression_sched' in checkpoint: 105 | compression_scheduler = distiller.CompressionScheduler(model) 106 | compression_scheduler.load_state_dict(checkpoint['compression_sched']) 107 | msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", 108 | checkpoint['epoch']) 109 | else: 110 | msglogger.info("Warning: compression schedule data does not exist in the checkpoint") 111 | 112 | if 'thinning_recipes' in checkpoint: 113 | if 'compression_sched' not in checkpoint: 114 | raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched") 115 | msglogger.info("Loaded a thinning recipe from the checkpoint") 116 | # Cache the recipes in case we need them later 117 | model.thinning_recipes = checkpoint['thinning_recipes'] 118 | distiller.execute_thinning_recipes_list(model, 119 | compression_scheduler.zeros_mask_dict, 120 | model.thinning_recipes) 121 | 122 | if 'quantizer_metadata' in checkpoint: 123 | msglogger.info('Loaded quantizer metadata from the checkpoint') 124 | qmd = checkpoint['quantizer_metadata'] 125 | quantizer = qmd['type'](model, **qmd['params']) 126 | quantizer.prepare_model() 127 | 128 | msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
129 | 130 | model.load_state_dict(checkpoint['state_dict']) 131 | if optimizer is not None and 'optimizer' in checkpoint: 132 | # Restored last, after any thinning/quantizer preparation above has been applied to the model. 133 | optimizer.load_state_dict(checkpoint['optimizer']) 134 | msglogger.info("Loaded optimizer state from the checkpoint") 135 | return model, compression_scheduler, start_epoch 136 | else: 137 | raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file) 138 | -------------------------------------------------------------------------------- /models/cifar10/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Resnet for CIFAR10 18 | 19 | Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition". 20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate 21 | changes for the 10-class Cifar-10 dataset. 22 | This ResNet also has layer gates, to be able to dynamically remove layers.
23 | 24 | @inproceedings{DBLP:conf/cvpr/HeZRS16, 25 | author = {Kaiming He and 26 | Xiangyu Zhang and 27 | Shaoqing Ren and 28 | Jian Sun}, 29 | title = {Deep Residual Learning for Image Recognition}, 30 | booktitle = {{CVPR}}, 31 | pages = {770--778}, 32 | publisher = {{IEEE} Computer Society}, 33 | year = {2016} 34 | } 35 | 36 | """ 37 | import torch.nn as nn 38 | import math 39 | import torch.utils.model_zoo as model_zoo 40 | 41 | 42 | __all__ = ['resnet20_cifar', 'resnet32_cifar', 'resnet44_cifar', 'resnet56_cifar'] 43 | 44 | NUM_CLASSES = 10 45 | 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | """3x3 convolution with padding""" 48 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 49 | padding=1, bias=False) 50 | 51 | class BasicBlock(nn.Module): 52 | expansion = 1 53 | # block_gates is [gate0, gate1]; a False gate removes that conv/bn stage from forward() 54 | def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None): 55 | super(BasicBlock, self).__init__() 56 | self.block_gates = block_gates 57 | self.conv1 = conv3x3(inplanes, planes, stride) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.relu1 = nn.ReLU(inplace=False) # To enable layer removal inplace must be False 60 | self.conv2 = conv3x3(planes, planes) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.relu2 = nn.ReLU(inplace=False) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = out = x # a disabled stage leaves 'out' unchanged, so 'out' may still be x below 68 | 69 | if self.block_gates[0]: 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu1(out) 73 | 74 | if self.block_gates[1]: 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out = out + residual # not in-place: with both gates off 'out' is the caller's tensor x 82 | out = self.relu2(out) 83 | 84 | return out 85 | 86 | 87 | class ResNetCifar(nn.Module): 88 |
89 | def __init__(self, block, layers, num_classes=NUM_CLASSES): 90 | super(ResNetCifar, self).__init__() # initialize nn.Module before assigning any attributes 91 | self.nlayers = 0 92 | # Each layer manages its own gates 93 | self.layer_gates = [] 94 | for layer in range(3): 95 | # For each of the 3 layers, create block gates: each block has two layers 96 | self.layer_gates.append([]) 97 | for blk in range(layers[layer]): 98 | self.layer_gates[layer].append([True, True]) 99 | 100 | self.inplanes = 16 # 64 in the ImageNet ResNet this is based on 101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 102 | self.bn1 = nn.BatchNorm2d(self.inplanes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0]) 105 | self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2) 106 | self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2) 107 | self.avgpool = nn.AvgPool2d(8, stride=1) # 8x8 maps after two stride-2 stages, assuming 32x32 inputs 108 | self.fc = nn.Linear(64 * block.expansion, num_classes) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, math.sqrt(2. / n))
114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | # Build one stage of 'blocks' residual blocks; only the first block strides and (if needed) downsamples. 118 | def _make_layer(self, layer_gates, block, planes, blocks, stride=1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(planes * block.expansion), 125 | ) 126 | 127 | layers = [] 128 | layers.append(block(layer_gates[0], self.inplanes, planes, stride, downsample)) 129 | self.inplanes = planes * block.expansion 130 | for i in range(1, blocks): 131 | layers.append(block(layer_gates[i], self.inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | x = self.conv1(x) 137 | x = self.bn1(x) 138 | x = self.relu(x) 139 | 140 | x = self.layer1(x) 141 | x = self.layer2(x) 142 | x = self.layer3(x) 143 | 144 | x = self.avgpool(x) 145 | x = x.view(x.size(0), -1) 146 | x = self.fc(x) 147 | 148 | return x 149 | 150 | # Depth = 6n+2 with n BasicBlocks per stage: 20, 32, 44, 56 use n = 3, 5, 7, 9. 151 | def resnet20_cifar(**kwargs): 152 | model = ResNetCifar(BasicBlock, [3, 3, 3], **kwargs) 153 | return model 154 | 155 | def resnet32_cifar(**kwargs): 156 | model = ResNetCifar(BasicBlock, [5, 5, 5], **kwargs) 157 | return model 158 | 159 | def resnet44_cifar(**kwargs): 160 | model = ResNetCifar(BasicBlock, [7, 7, 7], **kwargs) 161 | return model 162 | 163 | def resnet56_cifar(**kwargs): 164 | model = ResNetCifar(BasicBlock, [9, 9, 9], **kwargs) 165 | return model 166 | -------------------------------------------------------------------------------- /examples/hybrid/alexnet.schedule_sensitivity_2D-reg.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # This schedule is an example of "Iterative Pruning" for Alexnet/ImageNet, combined 3 | # with 2D structure regularization for the Convolution weights.
4 | # 5 | # time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 24 --epochs 90 --pretrained --compress=../hybrid/alexnet.schedule_sensitivity_2D-reg.yaml 6 | # time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 24 --epochs 90 --pretrained --compress=../hybrid/alexnet.schedule_sensitivity_2D-reg.yaml 7 | # Parameters: 8 | # 9 | # +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ 10 | # | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | 11 | # |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| 12 | # | 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13380 | 0.00000 | 0.00000 | 0.00000 | 0.52083 | 0.00000 | 42.40702 | 0.15288 | 0.00001 | 0.09403 | 13 | # | 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 102744 | 0.00000 | 0.00000 | 0.00000 | 9.17969 | 0.00000 | 66.55469 | 0.04458 | -0.00215 | 0.02018 | 14 | # | 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 176986 | 0.00000 | 0.00000 | 0.00000 | 29.33757 | 0.00000 | 73.32734 | 0.02720 | -0.00124 | 0.01197 | 15 | # | 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 199956 | 0.00000 | 0.00000 | 0.00000 | 35.29867 | 0.00000 | 77.39936 | 0.02040 | -0.00092 | 0.00869 | 16 | # | 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 131286 | 0.00000 | 0.00000 | 0.00000 | 43.33954 | 0.00000 | 77.74150 | 0.02280 | -0.00154 | 0.00987 | 17 | # | 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3643767 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.34731 | 0.00603 | -0.00019 | 0.00178 | 18 | # | 6 | classifier.4.weight | (4096, 4096) | 
16777216 | 1892052 | 0.21973 | 3.56445 | 0.00000 | 3.56445 | 0.00000 | 88.72249 | 0.00879 | -0.00067 | 0.00280 | 19 | # | 7 | classifier.6.weight | (1000, 4096) | 4096000 | 1022778 | 3.44238 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.02983 | 0.01783 | 0.00039 | 0.00816 | 20 | # | 8 | Total sparsity: | - | 61090496 | 7182950 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.24212 | 0.00000 | 0.00000 | 0.00000 | 21 | # +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ 22 | # Total sparsity: 88.24 23 | # 24 | # --- validate (epoch=89)----------- 25 | # 128116 samples (256 per mini-batch) 26 | # Epoch: [89][ 50/ 500] Loss 2.179626 Top1 51.585938 Top5 73.976562 27 | # Epoch: [89][ 100/ 500] Loss 2.188121 Top1 51.261719 Top5 74.054688 28 | # Epoch: [89][ 150/ 500] Loss 2.186677 Top1 51.302083 Top5 74.096354 29 | # Epoch: [89][ 200/ 500] Loss 2.188725 Top1 51.195312 Top5 74.007812 30 | # Epoch: [89][ 250/ 500] Loss 2.184323 Top1 51.342188 Top5 74.150000 31 | # Epoch: [89][ 300/ 500] Loss 2.181935 Top1 51.441406 Top5 74.194010 32 | # Epoch: [89][ 350/ 500] Loss 2.180590 Top1 51.477679 Top5 74.223214 33 | # Epoch: [89][ 400/ 500] Loss 2.177557 Top1 51.538086 Top5 74.300781 34 | # Epoch: [89][ 450/ 500] Loss 2.178948 Top1 51.572049 Top5 74.275174 35 | # Epoch: [89][ 500/ 500] Loss 2.178128 Top1 51.576563 Top5 74.308594 36 | # ==> Top1: 51.577 Top5: 74.305 Loss: 2.178 37 | # 38 | # Saving checkpoint 39 | # --- test --------------------- 40 | # 50000 samples (256 per mini-batch) 41 | # Test: [ 50/ 195] Loss 1.514649 Top1 62.546875 Top5 85.429688 42 | # Test: [ 100/ 195] Loss 1.659908 Top1 60.261719 Top5 83.367188 43 | # Test: [ 150/ 195] Loss 1.852519 Top1 57.171875 Top5 80.187500 44 | # ==> Top1: 56.240 Top5: 79.246 Loss: 1.911 45 | # 46 | # 47 | # Log file for this run: 
/data/home/cvds_lab/nzmora/pytorch_workspace/private-distiller/examples/classifier_compression/logs/2018.04.10-030052/2018.04.10-030052.log 48 | # 49 | # real 1031m47.668s 50 | # user 23604m55.588s 51 | # sys 1772m45.380s 52 | 53 | version: 1 54 | pruners: 55 | my_pruner: 56 | class: 'SensitivityPruner' 57 | sensitivities: 58 | 'features.module.0.weight': 0.25 59 | 'features.module.3.weight': 0.35 60 | 'features.module.6.weight': 0.40 61 | 'features.module.8.weight': 0.45 62 | 'features.module.10.weight': 0.55 63 | 'classifier.1.weight': 0.875 64 | 'classifier.4.weight': 0.875 65 | 'classifier.6.weight': 0.625 66 | 67 | regularizers: 68 | 2d_groups_regularizer: 69 | class: GroupLassoRegularizer 70 | reg_regims: 71 | 'features.module.0.weight': [0.000012, '2D'] 72 | 'features.module.3.weight': [0.000012, '2D'] 73 | 'features.module.6.weight': [0.000012, '2D'] 74 | 'features.module.8.weight': [0.000012, '2D'] 75 | 'features.module.10.weight': [0.000012, '2D'] 76 | #'classifier.1.weight': [0.000012, '2D'] 77 | #'classifier.4.weight': [0.000012, '2D'] 78 | #'classifier.6.weight': [0.000012, '2D'] 79 | 80 | 81 | lr_schedulers: 82 | # Learning rate decay scheduler 83 | pruning_lr: 84 | class: ExponentialLR 85 | gamma: 0.9 86 | 87 | policies: 88 | - pruner: 89 | instance_name : 'my_pruner' 90 | starting_epoch: 0 91 | ending_epoch: 38 92 | frequency: 2 93 | 94 | - regularizer: 95 | instance_name: '2d_groups_regularizer' 96 | starting_epoch: 0 97 | ending_epoch: 38 98 | frequency: 1 99 | 100 | - lr_scheduler: 101 | instance_name: pruning_lr 102 | starting_epoch: 24 103 | ending_epoch: 200 104 | frequency: 1 105 | -------------------------------------------------------------------------------- /distiller/thresholding.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance 
# with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tensor thresholding.

The code below supports fine-grained tensor thresholding and group-wise thresholding.
"""
import torch


def threshold_mask(weights, threshold):
    """Create a threshold mask for the provided parameter tensor using
    magnitude thresholding.

    Arguments:
        weights: a parameter tensor which should be pruned.
        threshold: the pruning threshold.
    Returns:
        prune_mask: The pruning mask; 1 where ``|weights| > threshold`` and 0
            elsewhere, with the same tensor type as ``weights``.
    """
    return torch.gt(torch.abs(weights), threshold).type(weights.type())


class GroupThresholdMixin(object):
    """A mixin class to add group thresholding capabilities"""

    @staticmethod
    def _thresholds_like(param, length, threshold):
        """Return a 1D tensor holding ``length`` copies of ``threshold``.

        The tensor takes ``param``'s dtype and device, so thresholding works for
        CPU and GPU parameters alike (previously the tensor was always moved to
        CUDA, which failed on CPU-only setups).
        """
        return param.new_full((length,), threshold)

    def group_threshold_mask(self, param, group_type, threshold, threshold_criteria):
        """Return a threshold mask for the provided parameter and group type.

        Args:
            param: The parameter to mask
            group_type: The elements grouping type (structure).
                One of: 2D, 3D (alias: Filters), 4D, Channels, Rows, Cols
            threshold: The threshold
            threshold_criteria: The thresholding criteria.
                'Mean_Abs' thresholds the entire element group using the mean of the
                absolute values of the tensor elements.
                'Max' thresholds the entire group using the magnitude of the largest
                element in the group.
                Note: the 'Channels' grouping always uses the mean-abs criterion.
        Returns:
            A 0/1 mask with the same shape as ``param``; a group gets 1 when its
            score is strictly greater than ``threshold``.
            For '4D', ``None`` is returned when the whole tensor passes the
            threshold, and an all-zeros tensor otherwise.
            ``None`` is also returned for an unrecognized ``group_type``.
        Raises:
            ValueError: if a grouping that consults ``threshold_criteria`` receives
                a value other than 'Mean_Abs' or 'Max'.
        """
        if group_type == '2D':
            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
            # One row per 2D kernel: (#OFMs * #IFMs) rows of (kernel_h * kernel_w) elements.
            view_2d = param.view(-1, param.size(2) * param.size(3))
            # 1. One threshold per kernel (length = #OFMs * #IFMs).
            thresholds = self._thresholds_like(param, param.size(0) * param.size(1), threshold)
            # 2. Score each kernel with the chosen criterion and compare it to the threshold.
            binary_map = self.threshold_policy(view_2d, thresholds, threshold_criteria)
            # 3. Broadcast each kernel's 0/1 decision over its elements and view as 4D.
            a = binary_map.expand(param.size(2) * param.size(3),
                                  param.size(0) * param.size(1)).t()
            return a.view(param.size(0), param.size(1), param.size(2), param.size(3))

        elif group_type == 'Rows':
            assert param.dim() == 2, "This regularization is only supported for 2D weights"
            thresholds = self._thresholds_like(param, param.size(0), threshold)
            binary_map = self.threshold_policy(param, thresholds, threshold_criteria)
            # One decision per row, broadcast across all columns.
            return binary_map.expand(param.size(1), param.size(0)).t()

        elif group_type == 'Cols':
            assert param.dim() == 2, "This regularization is only supported for 2D weights"
            thresholds = self._thresholds_like(param, param.size(1), threshold)
            # Reduce over dim 0 so that each column is scored.
            binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0)
            return binary_map.expand(param.size(0), param.size(1))

        elif group_type == '3D' or group_type == 'Filters':
            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
            # One row per filter (all of its input channels and kernel elements).
            view_filters = param.view(param.size(0), -1)
            thresholds = self._thresholds_like(param, param.size(0), threshold)
            binary_map = self.threshold_policy(view_filters, thresholds, threshold_criteria)
            a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
            return a.view(param.size(0), param.size(1), param.size(2), param.size(3))

        elif group_type == '4D':
            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
            if threshold_criteria == 'Mean_Abs':
                score = param.data.abs().mean()
            elif threshold_criteria == 'Max':
                score = param.data.abs().max()
            else:
                # Report the argument actually passed in; the mixin has no
                # ``threshold_criteria`` attribute of its own.
                raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
            if score > threshold:
                return None  # the whole tensor survives: no mask needed
            return torch.zeros_like(param.data)

        elif group_type == 'Channels':
            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
            num_filters = param.size(0)
            num_kernels_per_filter = param.size(1)

            view_2d = param.view(-1, param.size(2) * param.size(3))
            # Mean of the absolute values of the elements in each kernel.
            kernel_means = view_2d.abs().mean(dim=1)
            # Shape (#channels, #filters): row c holds channel c's kernel means across filters.
            k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
            thresholds = self._thresholds_like(param, num_kernels_per_filter, threshold)
            # Keep a channel when the average of its kernels' mean-abs exceeds the threshold.
            binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())

            # Now let's expand back up to a 4D mask
            a = binary_map.expand(num_filters, num_kernels_per_filter)
            c = a.unsqueeze(-1)
            d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
            return d.view(param.size(0), param.size(1), param.size(2), param.size(3))

        # Unrecognized group types produce no mask (unchanged behavior).
        return None

    def threshold_policy(self, weights, thresholds, threshold_criteria, dim=1):
        """Score groups of ``weights`` along ``dim`` and compare them to ``thresholds``.

        Args:
            weights: the tensor whose groups are scored.
            thresholds: threshold values, compared element-wise (with
                broadcasting) against the per-group scores.
            threshold_criteria: 'Mean_Abs' (mean of absolute values) or
                'Max' (largest absolute value).
            dim: the dimension reduced to produce one score per group.
        Returns:
            A 0/1 tensor of the same type as ``weights``: 1 where score > threshold.
        Raises:
            ValueError: for an unknown ``threshold_criteria``.
        """
        if threshold_criteria == 'Mean_Abs':
            return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
        elif threshold_criteria == 'Max':
            maxv, _ = weights.data.abs().max(dim=dim)
            return maxv.gt(thresholds).type(weights.type())
        raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
# -------------------------------------------------------------------------------- /tests/test_summarygraph.py: -------------------------------------------------------------------------------- 1
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for SummaryGraph construction, graph traversal and module-name helpers."""

import logging
import torch
import os
import sys
import pytest
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import distiller
from models import ALL_MODEL_NAMES, create_model
from apputils import *
from distiller import normalize_module_name, denormalize_module_name

# Logging configuration: DEBUG-level console output, and the root logger
# additionally writes its records to test.log.
logging.basicConfig(level=logging.DEBUG)
fh = logging.FileHandler('test.log')
logger = logging.getLogger()
logger.addHandler(fh)


def get_input(dataset):
    """Return a random single-sample input tensor for *dataset*, or None if unsupported."""
    if dataset == 'cifar10':
        return torch.randn((1, 3, 32, 32))
    if dataset == 'imagenet':
        return torch.randn((1, 3, 224, 224), requires_grad=False)
    return None


def create_graph(dataset, arch):
    """Build a SummaryGraph of a non-pretrained *arch* model fed with a dummy *dataset* input."""
    sample = get_input(dataset)
    assert sample is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)

    net = create_model(False, dataset, arch, parallel=False)
    assert net is not None
    return SummaryGraph(net, sample)


def test_graph():
    graph = create_graph('cifar10', 'resnet20_cifar')
    assert graph is not None


def test_connectivity():
    graph = create_graph('cifar10', 'resnet20_cifar')
    assert graph is not None

    names = [node['name'] for node in graph.ops.values()]
    assert len(names) == 73

    first_edge = graph.edges[0]
    assert first_edge.src == '0' and first_edge.dst == 'conv1'

    # Two back-to-back predecessor queries must give the same answer
    # (regression check: this was a bug once).
    graph.predecessors(graph.find_op('bn1'), 1)
    preds = graph.predecessors(graph.find_op('bn1'), 1)
    assert preds == ['108', '2', '3', '4', '5']

    # Successor query
    assert graph.successors(graph.find_op('bn1'), 2) == ['relu']

    block_op = graph.find_op('layer1.0')
    assert block_op is not None
    assert graph.predecessors(block_op, 2) == ['layer1.0.bn2', 'relu']

    relu_op = graph.find_op('layer1.0.relu2')
    assert relu_op is not None
    assert graph.successors(relu_op, 4) == ['layer1.1.bn1', 'layer1.1.relu2']

    assert graph.predecessors(graph.find_op('bn1'), 10) == []
    assert graph.predecessors(graph.find_op('bn1'), 3) == ['0', '1']


def test_layer_search():
    graph = create_graph('cifar10', 'resnet20_cifar')
    assert graph is not None

    assert graph.find_op('layer1.0.conv1') is not None

    # (start op, target op type, expected result)
    successor_cases = [
        ('layer1.0.conv1', 'Conv', ['layer1.0.conv2']),
        ('relu', 'Conv', ['layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1',
                          'layer2.0.conv1', 'layer2.0.downsample.0']),
        ('relu', 'Gemm', ['fc']),
        ('layer3.2', 'Conv', []),
    ]
    for start, op_type, expected in successor_cases:
        assert graph.successors_f(start, op_type, [], logging) == expected

    predecessor_cases = [
        ('conv1', 'Conv', []),
        ('layer1.0.conv2', 'Conv', ['layer1.0.conv1']),
        ('layer1.0.conv1', 'Conv', ['conv1']),
        ('layer1.1.conv1', 'Conv', ['layer1.0.conv2', 'conv1']),
    ]
    for start, op_type, expected in predecessor_cases:
        assert graph.predecessors_f(start, op_type, [], logging) == expected


def test_vgg():
    graph = create_graph('imagenet', 'vgg19')
    assert graph is not None
    logging.debug(graph.successors_f('features.32', 'Conv'))
    graph.successors_f('features.34', 'Conv')


def test_simplenet():
    graph = create_graph('cifar10', 'simplenet_cifar')
    assert graph is not None

    conv1_preds = graph.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
    logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(conv1_preds))
    assert len(conv1_preds) == 0

    conv2_preds = graph.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
    logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(conv2_preds))
    assert len(conv2_preds) == 1


def name_test(dataset, arch):
    """Check that module names of the plain and parallel models map onto each other."""
    plain = create_model(False, dataset, arch, parallel=False)
    parallel = create_model(False, dataset, arch, parallel=True)
    assert plain is not None and parallel is not None

    plain_names = [name for name, _ in plain.named_modules()]
    parallel_names = [name for name, _ in parallel.named_modules()]
    assert plain_names is not None and parallel_names is not None
    # The parallel model is expected to contain exactly one extra (wrapper) module.
    assert len(plain_names) + 1 == len(parallel_names)

    # Skip the root entry of both lists and the extra wrapper entry of the parallel list.
    for plain_name, parallel_name in zip(plain_names[1:], parallel_names[2:]):
        assert plain_name == normalize_module_name(parallel_name)
        logging.debug("{} {} {}".format(parallel_name, plain_name, normalize_module_name(parallel_name)))
        assert parallel_name == denormalize_module_name(parallel, plain_name)


def test_normalize_module_name():
    name_cases = [
        ("features.0", "features.module.0"),
        ("features.0", "module.features.0"),
        ("features", "features.module"),
    ]
    for expected, raw in name_cases:
        assert expected == normalize_module_name(raw)

    for dataset, arch in [('imagenet', 'vgg19'), ('cifar10', 'resnet20_cifar'), ('imagenet', 'alexnet')]:
        name_test(dataset, arch)


def test_onnx_name_2_pytorch_name():
    onnx_cases = [
        ("layer3.0.relu1", "ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu'),
        ("features.34", 'VGG/[features]/Sequential/Conv2d[34]', 'Conv'),
        ("Relu3", 'NameWithNoModule.3', 'Relu'),
    ]
    for expected, onnx_name, op_type in onnx_cases:
        assert expected == onnx_name_2_pytorch_name(onnx_name, op_type)
    #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')


def test_connectivity_summary():
    graph = create_graph('cifar10', 'resnet20_cifar')
    assert graph is not None

    assert len(connectivity_summary(graph)) == 73
    assert len(connectivity_summary_verbose(graph)) == 73


if __name__ == '__main__':
    test_connectivity_summary()
# -------------------------------------------------------------------------------- /jupyter/performance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torchvision\n", 11 | "import torch.nn as nn\n", 12 | "from torch.autograd import Variable\n", 13 | "\n", 14 | "# Relative import of code from distiller, w/o installing the package\n", 15 | "import os\n", 16 | "import sys\n", 17 | "module_path = os.path.abspath(os.path.join('..'))\n", 18 | "if module_path not in sys.path:\n", 19 | " sys.path.append(module_path)\n", 20 | "\n", 21 | "import pandas as pd\n", 22 | "import distiller\n", 23 | "import models\n", 24 | "from apputils import *" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Performance overview\n", 32 | "\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38
| "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "model = models.create_model(pretrained=False, dataset='imagenet', arch='resnet50', parallel=False)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)\n", 51 | "df = distiller.model_performance_summary(model, dummy_input, batch_size=1)\n", 52 | "display(df)\n", 53 | "\n", 54 | "total_macs = df['MACs'].sum()\n", 55 | "\n", 56 | "print(\"Total MACs: \" + \"{:,}\".format(total_macs))" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "### Let's take a look at how our compute is distributed:" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "print(\"MAC distribution:\")\n", 73 | "counts = df['MACs'].value_counts()\n", 74 | "print(counts)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Let's look at which convolution kernel sizes we're using, and how many instances:" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "print(\"Convolution kernel size distribution:\")\n", 91 | "counts = df['Attrs'].value_counts()\n", 92 | "print(counts)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "### Let's look at how the MACs are distributed between the layers and the convolution kernel sizes" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "scrolled": false 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "def get_layer_color(layer_type, attrs):\n", 111 | " if layer_type == \"Conv2d\":\n", 112 | " if attrs == 'k=(1, 1)':\n", 113 | " return 'tomato'\n", 114 | " elif attrs ==
'k=(3, 3)':\n", 115 | " return 'limegreen'\n", 116 | " else:\n", 117 | " return 'steelblue'\n", 118 | " return 'indigo'\n", 119 | "\n", 120 | "df_compute = df['MACs']\n", 121 | "ax = df_compute.plot.bar(figsize=[15,10], title=\"MACs\", \n", 122 | " color=[get_layer_color(layer_type, attrs) for layer_type,attrs in zip(df['Type'], df['Attrs'])])\n", 123 | "\n", 124 | "ax.set_xticklabels(df.Name, rotation=90);" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "### How do the Weights and Feature-maps footprints distribute across the layers:" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "df['FM volume'] = df['IFM volume'] + df['OFM volume']\n", 141 | "df_footprint = df[['FM volume', 'Weights volume']]\n", 142 | "ax = df_footprint.plot.bar(figsize=[15,10], title=\"Footprint\");\n", 143 | "ax.set_xticklabels(df.Name, rotation=90);" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "### How the Arithmetic Intensity distributes across the layers:" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "df_performance = df\n", 160 | "df_performance['raw traffic'] = df_footprint['FM volume'] + df_footprint['Weights volume']\n", 161 | "df_performance['arithmetic intensity'] = df['MACs'] / df_performance['raw traffic']\n", 162 | "df_performance2 = df_performance['arithmetic intensity']\n", 163 | "ax = df_performance2.plot.bar(figsize=[15,10], title=\"Arithmetic Intensity\");\n", 164 | "ax.set_xticklabels(df.Name, rotation=90);" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "## ResNet20 channel pruning using SSL\n", 172 | "\n", 173 | "Let's see how many MACs we saved by using SSL to prune filters from ResNet20:" 174 | ] 175 
| }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "resnet20_dense = models.create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=True)\n", 183 | "resnet20_sparse = models.create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=True)\n", 184 | "checkpoint_file = \"../examples/ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20_finetuned.pth.tar\" \n", 185 | "load_checkpoint(resnet20_sparse, checkpoint_file);" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "dummy_input = Variable(torch.randn(1, 3, 32, 32), requires_grad=False)\n", 195 | "df_dense = distiller.model_performance_summary(resnet20_dense, dummy_input, batch_size=1)\n", 196 | "df_sparse = distiller.model_performance_summary(resnet20_sparse, dummy_input, batch_size=1)\n", 197 | "\n", 198 | "dense_macs = df_dense['MACs'].sum()\n", 199 | "sparse_macs = df_sparse['MACs'].sum()\n", 200 | "\n", 201 | "print(\"Dense MACs: \" + \"{:,}\".format(int(dense_macs)))\n", 202 | "print(\"Sparse MACs: \" + \"{:,}\".format(int(sparse_macs)))\n", 203 | "print(\"Saved MACs: %.2f%%\" % ((1 - sparse_macs/dense_macs)*100))" 204 | ] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.5.2" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 2 228 | } 229 | --------------------------------------------------------------------------------