├── LICENSE ├── README.md ├── bash ├── set_env.sh ├── train_efficientnet.sh └── train_mobilenetv2.sh ├── display └── toy_regression.gif ├── main.py ├── models ├── __init__.py ├── efficientnet_lite_quantized.py ├── mobilenet_v2.py ├── mobilenet_v2_quantized.py └── resnet_quantized.py ├── quantization ├── __init__.py ├── autoquant_utils.py ├── base_quantized_classes.py ├── base_quantized_model.py ├── hijacker.py ├── quantization_manager.py ├── quantized_folded_bn.py ├── quantizers │ ├── __init__.py │ ├── base_quantizers.py │ ├── rounding_utils.py │ ├── uniform_quantizers.py │ └── utils.py ├── range_estimators.py └── utils.py ├── requirements.txt └── utils ├── __init__.py ├── click_options.py ├── imagenet_dataloaders.py ├── optimizer_utils.py ├── oscillation_tracking_utils.py ├── qat_utils.py ├── stopwatch.py ├── supervised_driver.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer: 8 | 9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overcoming Oscillations in Quantization-Aware Training 2 | This repository contains the implementation and experiments for the paper presented in 3 | 4 | **Markus Nagel\*1, Marios Fournarakis\*1, Yelysei Bondarenko1, 5 | Tijmen Blankevoort1 "Overcoming Oscillations in Quantization-Aware Training", ICML 6 | 2022.** [[ArXiv]](https://arxiv.org/abs/2203.11086) 7 | 8 | *Equal contribution 9 | 1 Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.) 10 | 11 | You can use this code to recreate the results in the paper. 
12 | ## Reference 13 | If you find our work useful, please cite 14 | ``` 15 | @InProceedings{pmlr-v162-nagel22a, 16 | title = {Overcoming Oscillations in Quantization-Aware Training}, 17 | author = {Nagel, Markus and Fournarakis, Marios and Bondarenko, Yelysei and Blankevoort, Tijmen}, 18 | booktitle = {Proceedings of the 39th International Conference on Machine Learning}, 19 | pages = {16318--16330}, 20 | year = {2022}, 21 | editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan}, 22 | volume = {162}, 23 | series = {Proceedings of Machine Learning Research}, 24 | month = {17--23 Jul}, 25 | publisher = {PMLR}, 26 | pdf = {https://proceedings.mlr.press/v162/nagel22a/nagel22a.pdf}, 27 | url = {https://proceedings.mlr.press/v162/nagel22a.html} 28 | } 29 | ``` 30 | 31 | ## Method and Results 32 | 33 | When training neural networks with simulated quantization, we observe that quantized weights can, 34 | rather unexpectedly, oscillate between two grid-points. This is an inherent problem caused 35 | by the straight-through estimator (STE). In our paper, we delve deeper into this little-understood 36 | phenomenon and show that oscillations harm accuracy by corrupting the EMA statistics of the 37 | batch-normalization layers and by preventing convergence to local minima. 38 | 39 |
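The oscillation described above can be reproduced on a one-weight toy problem. The snippet below is a minimal sketch and not code from this repository: the helper names `quantize` and `simulate`, the unit grid step, the target value and the learning rate are all illustrative assumptions. It minimises `(q(w) - target)^2` with plain gradient descent, using the STE to treat `dq/dw` as 1, and records which grid point the weight occupies at every step.

```python
def quantize(w, step=1.0):
    """Round w to the nearest point of a uniform grid with the given step."""
    return step * round(w / step)


def simulate(w_init=0.9, target=0.3, lr=0.1, steps=40):
    """Run gradient descent on (q(w) - target)^2 and return q(w) at every step."""
    w, history = w_init, []
    for _ in range(steps):
        q = quantize(w)
        history.append(q)
        # STE: the rounding op is treated as identity in the backward pass,
        # so dL/dw = 2 * (q - target) regardless of where w sits in its bin.
        w -= lr * 2.0 * (q - target)
    return history


history = simulate()
# Count how often the quantized weight changes its grid point.
flips = sum(a != b for a, b in zip(history, history[1:]))
print(f"grid points visited: {sorted(set(history))}, flips: {flips}")
```

Because the target lies strictly between two grid points, the gradient always pushes the latent weight towards the rounding boundary, so the quantized value keeps switching between the two neighbouring grid points instead of settling on one.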

40 | 41 |

42 | 43 | We propose two novel methods to tackle oscillations at their source: **oscillations dampening** 44 | and **iterative weight freezing**. We demonstrate that our algorithms achieve state-of-the-art 45 | accuracy for low-bit (3 & 4 bits) weight and activation quantization of efficient architectures, 46 | such as MobileNetV2, MobileNetV3, and EfficientNet-Lite on ImageNet. 47 | 48 | 49 | ## How to install 50 | Make sure to have Python ≥3.6 (tested with Python 3.6.8) and 51 | ensure the latest version of `pip` (**tested** with 21.3.1): 52 | ```bash 53 | source env/bin/activate 54 | pip install --upgrade --no-deps pip 55 | ``` 56 | 57 | Next, install PyTorch 1.9.1 with the appropriate CUDA version (tested with CUDA 10.0, CuDNN 7.6.3): 58 | ```bash 59 | pip install torch==1.9.1 torchvision==0.10.1 60 | ``` 61 | 62 | Finally, install the remaining dependencies using pip: 63 | ```bash 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | ## Running experiments 68 | The main run file to reproduce all experiments is `main.py`. 69 | It contains commands for quantization-aware training (QAT) and validating quantized models. 70 | You can see the full list of options for each command using `python main.py [COMMAND] --help`. 71 | ```bash 72 | Usage: main.py [OPTIONS] COMMAND [ARGS]... 73 | 74 | Options: 75 | --help Show this message and exit. 76 | 77 | Commands: 78 | train-quantized 79 | ``` 80 | 81 | ## Quantization-Aware Training (QAT) 82 | All models are fine-tuned starting from pre-trained FP32 weights. 
Pretrained weights may be found here: 83 | 84 | - [MobileNetV2](https://drive.google.com/open?id=1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR) 85 | - EfficientNet-Lite: pretrained weights from [repository](https://github.com/rwightman/pytorch-image-models/) (downloaded at runtime) 86 | 87 | ## MobileNetV2 88 | 89 | To train with **oscillations dampening** run: 90 | ```bash 91 | python main.py train-quantized --architecture mobilenet_v2_quantized 92 | --images-dir path/to/raw_imagenet --act-quant-method MSE --weight-quant-method MSE 93 | --optimizer SGD --weight-decay 2.5e-05 --sep-quant-optimizer 94 | --quant-optimizer Adam --quant-learning-rate 1e-5 --quant-weight-decay 0.0 95 | --model-dir /path/to/mobilenet_v2.pth.tar --learning-rate-schedule cosine:0 96 | # Dampening loss configurations 97 | --oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1 98 | # 4-bit best learning rate 99 | --n-bits 4 --learning-rate 0.0033 100 | # 3-bits best learning rate 101 | --n-bits 3 --learning-rate 0.01 102 | ``` 103 | 104 | To train with **iterative weight freezing** run: 105 | ```bash 106 | python main.py train-quantized --architecture mobilenet_v2_quantized 107 | --images-dir path/to/raw_imagenet --act-quant-method MSE --weight-quant-method MSE 108 | --optimizer SGD --sep-quant-optimizer 109 | --quant-optimizer Adam --quant-learning-rate 1e-5 --quant-weight-decay 0.0 110 | --model-dir /path/to/mobilenet_v2.pth.tar --learning-rate-schedule cosine:0 111 | # Iterative weight freezing configuration 112 | --oscillations-freeze-threshold 0.1 113 | # 4-bit best configuration 114 | --n-bits 4 --learning-rate 0.0033 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.01 115 | # 3-bit best configuration 116 | --n-bits 3 --learning-rate 0.01 --weight-decay 2.5e-05 --oscillations-freeze-threshold-final 0.011 117 | ``` 118 | 119 | For convenience, bash scripts are provided under `/bash/` for reproducing our experiments. 
120 | ```bash 121 | ./bash/train_mobilenetv2.sh --IMAGES_DIR=path_to_raw_imagenet --MODEL_DIR=path_to_pretrained_weights # QAT training of MobileNetV2 with defaults (method 'freeze' and 3 bits) 122 | ./bash/train_efficientnet.sh --IMAGES_DIR=path_to_raw_imagenet --MODEL_DIR=path_to_pretrained_weights --METHOD=damp --N_BITS=4 123 | ``` 124 | -------------------------------------------------------------------------------- /bash/set_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | # Setting up the environment 6 | source env/bin/activate 7 | export LC_ALL=C.UTF-8 8 | export LANG=C.UTF-8 9 | export PYTHONPATH=${PYTHONPATH}:$(realpath "$PWD") 10 | -------------------------------------------------------------------------------- /bash/train_efficientnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | ######################################################################################################### 6 | 7 | # Bash script for running QAT EfficientNet-Lite training configuration. 8 | # IMAGES_DIR: path to local raw imagenet dataset and 9 | # MODEL_DIR: path to model's pretrained weights 10 | # should be at least specified. 
11 | # 12 | # Example of using this script (arguments are parsed as --KEY=VALUE): 13 | # $ ./bash/train_efficientnet.sh --IMAGES_DIR=path_to_imagenet_raw --MODEL_DIR=path_to_weights 14 | # 15 | # For getting usage info: 16 | # $ ./bash/train_efficientnet.sh 17 | # 18 | # Other configurable params: 19 | # N_BITS: 3(default), 4 are currently supported 20 | # METHOD: freeze (default; iterative weight freezing), damp (oscillations dampening) 21 | # 22 | # The script may be extended with further input parameters (please refer to "/utils/click_options.py") 23 | 24 | ######################################################################################################### 25 | 26 | source bash/set_env.sh 27 | 28 | MODEL='efficientnet' 29 | N_BITS=3 30 | METHOD='freeze' 31 | 32 | for ARG in "$@" 33 | do 34 | key=$(echo $ARG | cut -f1 -d=) 35 | value=$(echo $ARG | cut -f2 -d=) 36 | 37 | if [[ $key == *"--"* ]]; then 38 | v="${key/--/}" 39 | declare $v="${value}" 40 | fi 41 | done 42 | 43 | if [[ -z $IMAGES_DIR ]] || [[ -z $MODEL_DIR ]]; then 44 | echo "Usage: $(basename "$0") 45 | --IMAGES_DIR=[path to imagenet_raw] 46 | --MODEL_DIR=[path to model's pretrained weights] 47 | --N_BITS=[3(default), 4] 48 | --METHOD=[freeze(default), damp]" 49 | exit 1 50 | fi 51 | 52 | if [ $N_BITS -ne 3 ] && [ $N_BITS -ne 4 ]; then 53 | echo "Only 3,4 bits configuration currently supported" 54 | exit 1 55 | fi 56 | 57 | if [ "$METHOD" != 'freeze' ] && [ "$METHOD" != 'damp' ]; then 58 | echo "Only methods 'damp' and 'freeze' are currently supported." 
59 | exit 1 60 | fi 61 | 62 | CMD_ARGS='--architecture efficientnet_lite0_quantized 63 | --act-quant-method MSE 64 | --weight-quant-method MSE 65 | --optimizer SGD 66 | --max-epochs 50 67 | --learning-rate-schedule cosine:0 68 | --sep-quant-optimizer 69 | --quant-optimizer Adam 70 | --quant-learning-rate 1e-5 71 | --quant-weight-decay 0.0' 72 | 73 | # QAT methods 74 | if [ $METHOD == 'freeze' ]; then 75 | CMD_QAT='--oscillations-freeze-threshold 0.1' 76 | if [ $N_BITS == 3 ]; then 77 | CMD_BITS='--n-bits 3 --learning-rate 0.01 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.005' 78 | else 79 | CMD_BITS='--n-bits 4 --learning-rate 0.0033 --weight-decay 1e-04 --oscillations-freeze-threshold-final 0.015' 80 | fi 81 | else 82 | CMD_QAT='--oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1' 83 | if [ $N_BITS == 3 ]; then 84 | CMD_BITS='--n-bits 3 --learning-rate 0.01 --weight-decay 5e-5' 85 | else 86 | CMD_BITS='--n-bits 4 --learning-rate 0.0033 --weight-decay 1e-4' 87 | fi 88 | fi 89 | 90 | CMD_ARGS="$CMD_ARGS $CMD_QAT $CMD_BITS" 91 | 92 | python main.py train-quantized \ 93 | --images-dir $IMAGES_DIR \ 94 | $CMD_ARGS 95 | -------------------------------------------------------------------------------- /bash/train_mobilenetv2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | ######################################################################################################### 6 | 7 | # Bash script for running QAT MobileNetV2 training configuration. 8 | # IMAGES_DIR: path to local raw imagenet dataset and 9 | # MODEL_DIR: path to model's pretrained weights 10 | # should be at least specified. 
11 | # 12 | # Example of using this script (arguments are parsed as --KEY=VALUE): 13 | # $ ./bash/train_mobilenetv2.sh --IMAGES_DIR=path_to_imagenet_raw --MODEL_DIR=path_to_weights 14 | # 15 | # For getting usage info: 16 | # $ ./bash/train_mobilenetv2.sh 17 | # 18 | # Other configurable params: 19 | # N_BITS: 3(default), 4 are currently supported 20 | # METHOD: freeze (default; iterative weight freezing), damp (oscillations dampening) 21 | # 22 | # The script may be extended with further input parameters (please refer to "/utils/click_options.py") 23 | 24 | ######################################################################################################### 25 | 26 | source bash/set_env.sh 27 | 28 | MODEL='mobilenetV2' 29 | N_BITS=3 30 | METHOD='freeze' 31 | 32 | for ARG in "$@" 33 | do 34 | key=$(echo $ARG | cut -f1 -d=) 35 | value=$(echo $ARG | cut -f2 -d=) 36 | 37 | if [[ $key == *"--"* ]]; then 38 | v="${key/--/}" 39 | declare $v="${value}" 40 | fi 41 | done 42 | 43 | if [[ -z $IMAGES_DIR ]] || [[ -z $MODEL_DIR ]]; then 44 | echo "Usage: $(basename "$0") 45 | --IMAGES_DIR=[path to imagenet_raw] 46 | --MODEL_DIR=[path to model's pretrained weights] 47 | --N_BITS=[3(default), 4] 48 | --METHOD=[freeze(default), damp]" 49 | exit 1 50 | fi 51 | 52 | if [ $N_BITS -ne 3 ] && [ $N_BITS -ne 4 ]; then 53 | echo 'Only 3,4 bits configuration currently supported' 54 | exit 1 55 | fi 56 | 57 | if [ "$METHOD" != 'freeze' ] && [ "$METHOD" != 'damp' ]; then 58 | echo "Only methods 'damp' and 'freeze' are currently supported." 
59 | exit 1 60 | fi 61 | 62 | CMD_ARGS='--architecture mobilenet_v2_quantized 63 | --act-quant-method MSE 64 | --weight-quant-method MSE 65 | --optimizer SGD 66 | --weight-decay 2.5e-05 67 | --sep-quant-optimizer 68 | --quant-optimizer Adam 69 | --quant-learning-rate 1e-5 70 | --quant-weight-decay 0.0 71 | --learning-rate-schedule cosine:0' 72 | 73 | # QAT methods 74 | if [ $METHOD == "freeze" ]; then 75 | CMD_QAT='--oscillations-freeze-threshold 0.1' 76 | if [ $N_BITS == 3 ]; then 77 | CMD_BITS='--n-bits 3 --learning-rate 0.01 --weight-decay 2.5e-05 --oscillations-freeze-threshold-final 0.011' 78 | else 79 | CMD_BITS='--n-bits 4 --learning-rate 0.0033 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.01' 80 | fi 81 | else 82 | CMD_QAT='--oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1 --weight-decay 2.5e-05' 83 | if [ $N_BITS == 3 ]; then 84 | CMD_BITS='--n-bits 3 --learning-rate 0.01' 85 | else 86 | CMD_BITS='--n-bits 4 --learning-rate 0.0033' 87 | fi 88 | fi 89 | 90 | CMD_ARGS="$CMD_ARGS $CMD_QAT $CMD_BITS" 91 | 92 | python main.py train-quantized \ 93 | --images-dir $IMAGES_DIR \ 94 | --model-dir $MODEL_DIR \ 95 | $CMD_ARGS -------------------------------------------------------------------------------- /display/toy_regression.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qualcomm-AI-research/oscillations-qat/9064d8540c1705242f08b864f06661247012ee4d/display/toy_regression.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | import logging 4 | import os 5 | 6 | import click 7 | from ignite.contrib.handlers import ProgressBar 8 | from ignite.engine import Events, create_supervised_evaluator 9 | from ignite.metrics import Accuracy, TopKCategoricalAccuracy, Loss 10 | from torch.nn import CrossEntropyLoss 11 | 12 | from quantization.utils import ( 13 | pass_data_for_range_estimation, 14 | separate_quantized_model_params, 15 | set_range_estimators, 16 | ) 17 | from utils import DotDict, CosineTempDecay 18 | from utils.click_options import ( 19 | qat_options, 20 | quantization_options, 21 | quant_params_dict, 22 | base_options, 23 | multi_optimizer_options, 24 | ) 25 | from utils.optimizer_utils import optimizer_lr_factory 26 | from utils.oscillation_tracking_utils import add_oscillation_trackers 27 | from utils.qat_utils import ( 28 | get_dataloaders_and_model, 29 | MethodPropagator, 30 | DampeningLoss, 31 | CompositeLoss, 32 | UpdateDampeningLossWeighting, 33 | UpdateFreezingThreshold, 34 | ReestimateBNStats, 35 | ) 36 | from utils.supervised_driver import create_trainer_engine, setup_tensorboard_logger, log_metrics 37 | 38 | 39 | # setup stuff 40 | class Config(DotDict): 41 | pass 42 | 43 | 44 | @click.group() 45 | def oscillations(): 46 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) 47 | 48 | 49 | pass_config = click.make_pass_decorator(Config, ensure=True) 50 | 51 | 52 | @oscillations.command() 53 | @pass_config 54 | @base_options 55 | @multi_optimizer_options() 56 | @quantization_options 57 | @qat_options 58 | def train_quantized(config): 59 | """ 60 | Main QAT function 61 | """ 62 | 63 | print("Setting up network and data loaders") 64 | qparams = quant_params_dict(config) 65 | 66 | dataloaders, model = get_dataloaders_and_model(config, **qparams) 67 | 68 | # Estimate ranges using training data 69 | pass_data_for_range_estimation( 70 | loader=dataloaders.train_loader, 71 | model=model, 72 | act_quant=config.quant.act_quant, 73 | 
weight_quant=config.quant.weight_quant, 74 | max_num_batches=config.quant.num_est_batches, 75 | ) 76 | 77 | # Put quantizers in desirable state 78 | set_range_estimators(config, model) 79 | 80 | print("Loaded model:\n{}".format(model)) 81 | 82 | # Split the model's parameters into subcategories 83 | quantizer_params, model_params, grad_params = separate_quantized_model_params(model) 84 | model_optimizer, quant_optimizer = None, None 85 | if config.qat.sep_quant_optimizer: 86 | # Separate optimizer for model and quantization parameters 87 | model_optimizer, model_lr_scheduler = optimizer_lr_factory( 88 | config.optimizer, model_params, config.base.max_epochs 89 | ) 90 | quant_optimizer, quant_lr_scheduler = optimizer_lr_factory( 91 | config.quant_optimizer, quantizer_params, config.base.max_epochs 92 | ) 93 | 94 | optimizer = MethodPropagator([model_optimizer, quant_optimizer]) 95 | lr_schedulers = [s for s in [model_lr_scheduler, quant_lr_scheduler] if s is not None] 96 | lr_scheduler = MethodPropagator(lr_schedulers) if len(lr_schedulers) else None 97 | else: 98 | optimizer, lr_scheduler = optimizer_lr_factory( 99 | config.optimizer, quantizer_params + model_params, config.base.max_epochs 100 | ) 101 | 102 | print("Optimizer:\n{}".format(optimizer)) 103 | print(f"LR scheduler\n{lr_scheduler}") 104 | 105 | # Define metrics for ignite engine 106 | metrics = {"top_1_accuracy": Accuracy(), "top_5_accuracy": TopKCategoricalAccuracy()} 107 | 108 | # Set-up losses 109 | task_loss_fn = CrossEntropyLoss() 110 | dampening_loss = None 111 | if config.osc_damp.weight is not None: 112 | # Add dampening loss to task loss 113 | dampening_loss = DampeningLoss(model, config.osc_damp.weight, config.osc_damp.aggregation) 114 | loss_dict = {"task_loss": task_loss_fn, "dampening_loss": dampening_loss} 115 | loss_func = CompositeLoss(loss_dict) 116 | loss_metrics = { 117 | "task_loss": Loss(task_loss_fn), 118 | "dampening_loss": Loss(dampening_loss), 119 | "loss": Loss(loss_func), 120 | } 
121 | else: 122 | loss_func = task_loss_fn 123 | loss_metrics = {"loss": Loss(loss_func)} 124 | 125 | metrics.update(loss_metrics) 126 | 127 | # Set up ignite trainer and evaluator 128 | trainer, evaluator = create_trainer_engine( 129 | model=model, 130 | optimizer=optimizer, 131 | criterion=loss_func, 132 | data_loaders=dataloaders, 133 | metrics=metrics, 134 | lr_scheduler=lr_scheduler, 135 | save_checkpoint_dir=config.base.save_checkpoint_dir, 136 | device="cuda" if config.base.cuda else "cpu", 137 | ) 138 | 139 | if config.base.progress_bar: 140 | pbar = ProgressBar() 141 | pbar.attach(trainer) 142 | pbar.attach(evaluator) 143 | 144 | # Create TensorboardLogger 145 | if config.base.tb_logging_dir: 146 | if config.qat.sep_quant_optimizer: 147 | optimizers_dict = {"model": model_optimizer, "quant_params": quant_optimizer} 148 | else: 149 | optimizers_dict = optimizer 150 | tb_logger = setup_tensorboard_logger( 151 | trainer, evaluator, config.base.tb_logging_dir, optimizers_dict 152 | ) 153 | 154 | if config.osc_damp.weight_final: 155 | # Apply cosine annealing of dampening loss 156 | total_iterations = len(dataloaders.train_loader) * config.base.max_epochs 157 | annealing_schedule = CosineTempDecay( 158 | t_max=total_iterations, 159 | temp_range=(config.osc_damp.weight, config.osc_damp.weight_final), 160 | rel_decay_start=config.osc_damp.anneal_start, 161 | ) 162 | print(f"Weight gradient parameter cosine annealing schedule:\n{annealing_schedule}") 163 | trainer.add_event_handler( 164 | Events.ITERATION_STARTED, 165 | UpdateDampeningLossWeighting(dampening_loss, annealing_schedule), 166 | ) 167 | 168 | # Evaluate model 169 | print("Running evaluation before training") 170 | evaluator.run(dataloaders.val_loader) 171 | log_metrics(evaluator.state.metrics, "Evaluation", trainer.state.epoch) 172 | 173 | # BN Re-estimation 174 | if config.qat.reestimate_bn_stats: 175 | evaluator.add_event_handler( 176 | Events.EPOCH_STARTED, ReestimateBNStats(model, 
dataloaders.train_loader) 177 | ) 178 | 179 | # Add oscillation trackers to the model and set up oscillation freezing 180 | if config.osc_freeze.threshold: 181 | oscillation_tracker_dict = add_oscillation_trackers( 182 | model, 183 | max_bits=config.osc_freeze.max_bits, 184 | momentum=config.osc_freeze.ema_momentum, 185 | freeze_threshold=config.osc_freeze.threshold, 186 | use_ema_x_int=config.osc_freeze.use_ema, 187 | ) 188 | 189 | if config.osc_freeze.threshold_final: 190 | # Apply cosine annealing schedule to the freezing threshold 191 | total_iterations = len(dataloaders.train_loader) * config.base.max_epochs 192 | annealing_schedule = CosineTempDecay( 193 | t_max=total_iterations, 194 | temp_range=(config.osc_freeze.threshold, config.osc_freeze.threshold_final), 195 | rel_decay_start=config.osc_freeze.anneal_start, 196 | ) 197 | print(f"Oscillation freezing annealing schedule:\n{annealing_schedule}") 198 | trainer.add_event_handler( 199 | Events.ITERATION_STARTED, 200 | UpdateFreezingThreshold(oscillation_tracker_dict, annealing_schedule), 201 | ) 202 | 203 | print("Starting training") 204 | 205 | trainer.run(dataloaders.train_loader, max_epochs=config.base.max_epochs) 206 | 207 | print("Finished training") 208 | 209 | 210 | @oscillations.command() 211 | @pass_config 212 | @base_options 213 | @quantization_options 214 | @click.option( 215 | "--load-type", 216 | type=click.Choice(["fp32", "quantized"]), 217 | default="quantized", 218 | help='Either "fp32", or "quantized". 
Specify whether to load a quantized or a FP ' "model.", 219 | ) 220 | def validate_quantized(config, load_type): 221 | """ 222 | Function for running validation on pre-trained quantized models 223 | """ 224 | print("Setting up network and data loaders") 225 | qparams = quant_params_dict(config) 226 | 227 | dataloaders, model = get_dataloaders_and_model(config=config, load_type=load_type, **qparams) 228 | 229 | if load_type == "fp32": 230 | # Estimate ranges using training data 231 | pass_data_for_range_estimation( 232 | loader=dataloaders.train_loader, 233 | model=model, 234 | act_quant=config.quant.act_quant, 235 | weight_quant=config.quant.weight_quant, 236 | max_num_batches=config.quant.num_est_batches, 237 | ) 238 | # Ensure we have the desired quant state 239 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant) 240 | 241 | # Fix ranges 242 | model.fix_ranges() 243 | print("Loaded model:\n{}".format(model)) 244 | 245 | # Create evaluator 246 | loss_func = CrossEntropyLoss() 247 | metrics = { 248 | "top_1_accuracy": Accuracy(), 249 | "top_5_accuracy": TopKCategoricalAccuracy(), 250 | "loss": Loss(loss_func), 251 | } 252 | 253 | pbar = ProgressBar() 254 | evaluator = create_supervised_evaluator( 255 | model=model, metrics=metrics, device="cuda" if config.base.cuda else "cpu" 256 | ) 257 | pbar.attach(evaluator) 258 | print("Start quantized validation") 259 | evaluator.run(dataloaders.val_loader) 260 | final_metrics = evaluator.state.metrics 261 | print(final_metrics) 262 | 263 | 264 | if __name__ == "__main__": 265 | oscillations() 266 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | from models.efficientnet_lite_quantized import efficientnet_lite0_quantized 6 | from models.mobilenet_v2_quantized import mobilenetv2_quantized 7 | from models.resnet_quantized import resnet18_quantized, resnet50_quantized 8 | from utils import ClassEnumOptions, MethodMap 9 | 10 | 11 | class QuantArchitectures(ClassEnumOptions): 12 | mobilenet_v2_quantized = MethodMap(mobilenetv2_quantized) 13 | resnet18_quantized = MethodMap(resnet18_quantized) 14 | resnet50_quantized = MethodMap(resnet50_quantized) 15 | efficientnet_lite0_quantized = MethodMap(efficientnet_lite0_quantized) 16 | 17 | -------------------------------------------------------------------------------- /models/efficientnet_lite_quantized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | import torch 5 | from timm.models import create_model 6 | from timm.models.efficientnet_blocks import DepthwiseSeparableConv, InvertedResidual 7 | from torch import nn 8 | 9 | from quantization.autoquant_utils import quantize_sequential, quantize_model 10 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 11 | from quantization.base_quantized_model import QuantizedModel 12 | 13 | 14 | class QuantizedInvertedResidual(QuantizedActivation): 15 | def __init__(self, inv_res_orig, **quant_params): 16 | super().__init__(**quant_params) 17 | 18 | assert inv_res_orig.drop_path_rate == 0.0 19 | assert isinstance(inv_res_orig.se, nn.Identity) 20 | 21 | self.has_residual = inv_res_orig.has_residual 22 | 23 | conv_pw = nn.Sequential(inv_res_orig.conv_pw, inv_res_orig.bn1, inv_res_orig.act1) 24 | self.conv_pw = quantize_sequential(conv_pw, **quant_params)[0] 25 | 26 | conv_dw = nn.Sequential(inv_res_orig.conv_dw, inv_res_orig.bn2, inv_res_orig.act2) 27 | self.conv_dw = quantize_sequential(conv_dw, **quant_params) # [0] 28 | 29 | conv_pwl = 
nn.Sequential(inv_res_orig.conv_pwl, inv_res_orig.bn3) 30 | self.conv_pwl = quantize_sequential(conv_pwl, **quant_params)[0] 31 | 32 | def forward(self, x): 33 | residual = x 34 | # Point-wise expansion 35 | x = self.conv_pw(x) 36 | # Depth-wise convolution 37 | x = self.conv_dw(x) 38 | # Point-wise linear projection 39 | x = self.conv_pwl(x) 40 | 41 | if self.has_residual: 42 | x += residual 43 | x = self.quantize_activations(x) 44 | return x 45 | 46 | 47 | class QuantizedDepthwiseSeparableConv(QuantizedActivation): 48 | def __init__(self, dws_orig, **quant_params): 49 | super().__init__(**quant_params) 50 | 51 | assert dws_orig.drop_path_rate == 0.0 52 | assert isinstance(dws_orig.se, nn.Identity) 53 | 54 | self.has_residual = dws_orig.has_residual 55 | 56 | conv_dw = nn.Sequential(dws_orig.conv_dw, dws_orig.bn1, dws_orig.act1) 57 | self.conv_dw = quantize_sequential(conv_dw, **quant_params)[0] 58 | 59 | conv_pw = nn.Sequential(dws_orig.conv_pw, dws_orig.bn2, dws_orig.act2) 60 | self.conv_pw = quantize_sequential(conv_pw, **quant_params)[0] 61 | 62 | def forward(self, x): 63 | residual = x 64 | # Depth-wise convolution 65 | x = self.conv_dw(x) 66 | # Point-wise projection 67 | x = self.conv_pw(x) 68 | if self.has_residual: 69 | x += residual 70 | x = self.quantize_activations(x) 71 | return x 72 | 73 | 74 | class QuantizedEfficientNetLite(QuantizedModel): 75 | def __init__(self, base_model, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params): 76 | super().__init__(input_size) 77 | 78 | specials = { 79 | InvertedResidual: QuantizedInvertedResidual, 80 | DepthwiseSeparableConv: QuantizedDepthwiseSeparableConv, 81 | } 82 | 83 | conv_stem = nn.Sequential(base_model.conv_stem, base_model.bn1, base_model.act1) 84 | self.conv_stem = quantize_model(conv_stem, specials=specials, **quant_params)[0] 85 | 86 | self.blocks = quantize_model(base_model.blocks, specials=specials, **quant_params) 87 | 88 | conv_head = nn.Sequential(base_model.conv_head, base_model.bn2, 
base_model.act2) 89 | self.conv_head = quantize_model(conv_head, specials=specials, **quant_params)[0] 90 | 91 | self.global_pool = base_model.global_pool 92 | 93 | base_model.classifier.__class__ = nn.Linear # Small hack to work with autoquant 94 | self.classifier = quantize_model(base_model.classifier, **quant_params) 95 | 96 | if quant_setup == "FP_logits": 97 | print("Do not quantize output of FC layer") 98 | self.classifier.activation_quantizer = FP32Acts() # no activation quantization of 99 | # logits 100 | elif quant_setup == "LSQ": 101 | print("Set quantization to LSQ (first+last layer in 8 bits)") 102 | # Weights of the first layer 103 | self.conv_stem.weight_quantizer.quantizer.n_bits = 8 104 | # The quantizer of the last conv_layer layer (input to global) 105 | self.conv_head.activation_quantizer.quantizer.n_bits = 8 106 | # Weights of the last layer 107 | self.classifier.weight_quantizer.quantizer.n_bits = 8 108 | # no activation quantization of logits 109 | self.classifier.activation_quantizer = FP32Acts() 110 | elif quant_setup == "LSQ_paper": 111 | # Weights of the first layer 112 | self.conv_stem.activation_quantizer = FP32Acts() 113 | self.conv_stem.weight_quantizer.quantizer.n_bits = 8 114 | # Weights of the last layer 115 | self.classifier.activation_quantizer.quantizer.n_bits = 8 116 | self.classifier.weight_quantizer.quantizer.n_bits = 8 117 | # Set all QuantizedActivations to FP32 118 | for layer in self.blocks.modules(): 119 | if isinstance(layer, QuantizedActivation): 120 | layer.activation_quantizer = FP32Acts() 121 | elif quant_setup is not None and quant_setup != "all": 122 | raise ValueError( 123 | "Quantization setup '{}' not supported for EfficientNet lite".format(quant_setup) 124 | ) 125 | 126 | def forward(self, x): 127 | # features 128 | x = self.conv_stem(x) 129 | x = self.blocks(x) 130 | x = self.conv_head(x) 131 | 132 | x = self.global_pool(x) 133 | x = x.flatten(1) 134 | return self.classifier(x) 135 | 136 | 137 | def 
efficientnet_lite0_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 138 | if load_type == "fp32": 139 | # Load model from pretrained FP32 weights 140 | fp_model = create_model("efficientnet_lite0", pretrained=pretrained) 141 | quant_model = QuantizedEfficientNetLite(fp_model, **qparams) 142 | elif load_type == "quantized": 143 | # Load pretrained QuantizedModel 144 | print(f"Loading pretrained quantized model from {model_dir}") 145 | state_dict = torch.load(model_dir) 146 | fp_model = create_model("efficientnet_lite0") 147 | quant_model = QuantizedEfficientNetLite(fp_model, **qparams) 148 | quant_model.load_state_dict(state_dict) 149 | else: 150 | raise ValueError("wrong load_type specified") 151 | return quant_model 152 | -------------------------------------------------------------------------------- /models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | # https://github.com/tonylins/pytorch-mobilenet-v2 6 | 7 | import math 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | __all__ = ["MobileNetV2"] 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def conv_1x1_bn(inp, oup): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True) 24 | ) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = round(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 40 | nn.BatchNorm2d(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 54 | nn.BatchNorm2d(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class MobileNetV2(nn.Module): 69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.0, dropout=0.0): 70 | super().__init__() 71 | block = InvertedResidual 72 | input_channel = 32 73 | last_channel = 1280 74 | inverted_residual_setting = [ 75 | # 
t, c, n, s 76 | [1, 16, 1, 1], 77 | [6, 24, 2, 2], 78 | [6, 32, 3, 2], 79 | [6, 64, 4, 2], 80 | [6, 96, 3, 1], 81 | [6, 160, 3, 2], 82 | [6, 320, 1, 1], 83 | ] 84 | 85 | # building first layer 86 | assert input_size % 32 == 0 87 | input_channel = int(input_channel * width_mult) 88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 89 | features = [conv_bn(3, input_channel, 2)] 90 | # building inverted residual blocks 91 | for t, c, n, s in inverted_residual_setting: 92 | output_channel = int(c * width_mult) 93 | for i in range(n): 94 | if i == 0: 95 | features.append(block(input_channel, output_channel, s, expand_ratio=t)) 96 | else: 97 | features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 98 | input_channel = output_channel 99 | # building last several layers 100 | features.append(conv_1x1_bn(input_channel, self.last_channel)) 101 | features.append(nn.AvgPool2d(input_size // 32)) 102 | # make it nn.Sequential 103 | self.features = nn.Sequential(*features) 104 | 105 | # building classifier 106 | self.classifier = nn.Sequential( 107 | nn.Dropout(dropout), 108 | nn.Linear(self.last_channel, n_class), 109 | ) 110 | 111 | self._initialize_weights() 112 | 113 | def forward(self, x): 114 | x = self.features(x) 115 | x = F.adaptive_avg_pool2d(x, 1).flatten(1) # type: ignore[arg-type] # flatten(1) keeps the batch dim when N == 1 116 | x = self.classifier(x) 117 | return x 118 | 119 | def _initialize_weights(self): 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.Linear): 130 | n = m.weight.size(1) 131 | m.weight.data.normal_(0, 0.01) 132 | m.bias.data.zero_() 133 | 
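The `inverted_residual_setting` table in `MobileNetV2.__init__` is compact, so a small standalone sketch may help show how it unrolls into individual blocks. The helper names `expand_block_configs` and `total_stride` are hypothetical and not part of the repository; the loop is meant to mirror the one in the constructor, without building any `torch` modules.

```python
# Each row is (t, c, n, s): expansion ratio, output channels,
# number of repeats, and stride of the first repeat.
INVERTED_RESIDUAL_SETTING = [
    [1, 16, 1, 1],
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 4, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]


def expand_block_configs(settings, input_channel=32, width_mult=1.0):
    """Return one (in_ch, out_ch, stride, expand_ratio) tuple per block.

    Hypothetical helper: it follows the same rule as the constructor, where
    only the first repeat of each row uses stride s and later repeats use 1.
    """
    in_ch = int(input_channel * width_mult)
    configs = []
    for t, c, n, s in settings:
        out_ch = int(c * width_mult)
        for i in range(n):
            stride = s if i == 0 else 1
            configs.append((in_ch, out_ch, stride, t))
            in_ch = out_ch
    return configs


def total_stride(configs, stem_stride=2):
    """Multiply the stem stride by every block stride (hypothetical helper)."""
    total = stem_stride
    for _, _, stride, _ in configs:
        total *= stride
    return total


if __name__ == "__main__":
    for cfg in expand_block_configs(INVERTED_RESIDUAL_SETTING):
        print(cfg)
```

Under these assumptions the overall downsampling factor comes out as 32, which is consistent with the `assert input_size % 32 == 0` check and the `nn.AvgPool2d(input_size // 32)` layer in the constructor.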
-------------------------------------------------------------------------------- /models/mobilenet_v2_quantized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import os 6 | import re 7 | import torch 8 | from collections import OrderedDict 9 | from models.mobilenet_v2 import MobileNetV2, InvertedResidual 10 | from quantization.autoquant_utils import quantize_sequential, Flattener, quantize_model, BNQConv 11 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 12 | from quantization.base_quantized_model import QuantizedModel 13 | 14 | 15 | class QuantizedInvertedResidual(QuantizedActivation): 16 | def __init__(self, inv_res_orig, **quant_params): 17 | super().__init__(**quant_params) 18 | self.use_res_connect = inv_res_orig.use_res_connect 19 | self.conv = quantize_sequential(inv_res_orig.conv, **quant_params) 20 | 21 | def forward(self, x): 22 | if self.use_res_connect: 23 | x = x + self.conv(x) 24 | return self.quantize_activations(x) 25 | else: 26 | return self.conv(x) 27 | 28 | 29 | class QuantizedMobileNetV2(QuantizedModel): 30 | def __init__(self, model_fp, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params): 31 | super().__init__(input_size) 32 | specials = {InvertedResidual: QuantizedInvertedResidual} 33 | # quantize and copy parts from original model 34 | quantize_input = quant_setup and quant_setup == "LSQ_paper" 35 | self.features = quantize_sequential( 36 | model_fp.features, 37 | tie_activation_quantizers=not quantize_input, 38 | specials=specials, 39 | **quant_params, 40 | ) 41 | 42 | self.flattener = Flattener() 43 | self.classifier = quantize_model(model_fp.classifier, **quant_params) 44 | 45 | if quant_setup == "FP_logits": 46 | print("Do not quantize output of FC layer") 47 | self.classifier[1].activation_quantizer = FP32Acts() 48 | # 
self.classifier.activation_quantizer = FP32Acts() # no activation quantization of logits 49 | elif quant_setup == "fc4": 50 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 51 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4 52 | elif quant_setup == "fc4_dw8": 53 | print("\n\n### fc4_dw8 setup ###\n\n") 54 | # FC layer in 4 bits, depth-wise separable ones in 8 bits 55 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 56 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4 57 | for name, module in self.named_modules(): 58 | if isinstance(module, BNQConv) and module.groups == module.in_channels: 59 | module.weight_quantizer.quantizer.n_bits = 8 60 | print(f"Set layer {name} to 8 bits") 61 | elif quant_setup == "LSQ": 62 | print("Set quantization to LSQ (first+last layer in 8 bits)") 63 | # Weights of the first layer 64 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 65 | # The quantizer of the last conv layer (input to avgpool with tied quantizers) 66 | self.features[-2][0].activation_quantizer.quantizer.n_bits = 8 67 | # Weights of the last layer 68 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8 69 | # no activation quantization of logits 70 | self.classifier[1].activation_quantizer = FP32Acts() 71 | elif quant_setup == "LSQ_paper": 72 | # Weights of the first layer 73 | self.features[0][0].activation_quantizer = FP32Acts() 74 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 75 | # Weights of the last layer 76 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8 77 | self.classifier[1].activation_quantizer.quantizer.n_bits = 8 78 | # Set all QuantizedActivations to FP32 79 | for layer in self.features.modules(): 80 | if isinstance(layer, QuantizedActivation): 81 | layer.activation_quantizer = FP32Acts() 82 | elif quant_setup is not None and quant_setup != "all": 83 | raise ValueError( 84 | "Quantization setup '{}' not supported for MobilenetV2".format(quant_setup) 85 | ) 86 | 87 | def 
forward(self, x): 88 | x = self.features(x) 89 | x = self.flattener(x) 90 | x = self.classifier(x) 91 | 92 | return x 93 | 94 | 95 | def mobilenetv2_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 96 | fp_model = MobileNetV2() 97 | if pretrained and load_type == "fp32": 98 | # Load model from pretrained FP32 weights 99 | assert os.path.exists(model_dir) 100 | print(f"Loading pretrained weights from {model_dir}") 101 | state_dict = torch.load(model_dir) 102 | fp_model.load_state_dict(state_dict) 103 | quant_model = QuantizedMobileNetV2(fp_model, **qparams) 104 | elif load_type == "quantized": 105 | # Load pretrained QuantizedModel 106 | print(f"Loading pretrained quantized model from {model_dir}") 107 | state_dict = torch.load(model_dir) 108 | quant_model = QuantizedMobileNetV2(fp_model, **qparams) 109 | quant_model.load_state_dict(state_dict, strict=False) 110 | else: 111 | raise ValueError("wrong load_type specified") 112 | 113 | return quant_model 114 | -------------------------------------------------------------------------------- /models/resnet_quantized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | import torch 5 | from torch import nn 6 | from torchvision.models.resnet import BasicBlock, Bottleneck 7 | from torchvision.models import resnet18, resnet50 8 | 9 | from quantization.autoquant_utils import quantize_model, Flattener, QuantizedActivationWrapper 10 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 11 | from quantization.base_quantized_model import QuantizedModel 12 | 13 | 14 | class QuantizedBlock(QuantizedActivation): 15 | def __init__(self, block, **quant_params): 16 | super().__init__(**quant_params) 17 | 18 | if isinstance(block, Bottleneck): 19 | features = nn.Sequential( 20 | block.conv1, 21 | block.bn1, 22 | block.relu, 23 | block.conv2, 24 | block.bn2, 25 | block.relu, 26 | block.conv3, 27 | block.bn3, 28 | ) 29 | elif isinstance(block, BasicBlock): 30 | features = nn.Sequential(block.conv1, block.bn1, block.relu, block.conv2, block.bn2) 31 | 32 | self.features = quantize_model(features, **quant_params) 33 | self.downsample = ( 34 | quantize_model(block.downsample, **quant_params) if block.downsample else None 35 | ) 36 | 37 | self.relu = block.relu 38 | 39 | def forward(self, x): 40 | residual = x if self.downsample is None else self.downsample(x) 41 | out = self.features(x) 42 | 43 | out += residual 44 | out = self.relu(out) 45 | 46 | return self.quantize_activations(out) 47 | 48 | 49 | class QuantizedResNet(QuantizedModel): 50 | def __init__(self, resnet, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params): 51 | super().__init__(input_size) 52 | specials = {BasicBlock: QuantizedBlock, Bottleneck: QuantizedBlock} 53 | 54 | if hasattr(resnet, "maxpool"): 55 | # ImageNet ResNet case 56 | features = nn.Sequential( 57 | resnet.conv1, 58 | resnet.bn1, 59 | resnet.relu, 60 | resnet.maxpool, 61 | resnet.layer1, 62 | resnet.layer2, 63 | resnet.layer3, 64 | resnet.layer4, 65 | ) 66 | else: 67 | # Tiny ImageNet ResNet case 68 | features = nn.Sequential( 69 | resnet.conv1, 70 | resnet.bn1, 71 | resnet.relu, 
72 | resnet.layer1, 73 | resnet.layer2, 74 | resnet.layer3, 75 | resnet.layer4, 76 | ) 77 | 78 | self.features = quantize_model(features, specials=specials, **quant_params) 79 | 80 | if quant_setup and quant_setup == "LSQ_paper": 81 | # Keep avgpool intact as we quantize the input to the last layer 82 | self.avgpool = resnet.avgpool 83 | else: 84 | self.avgpool = QuantizedActivationWrapper( 85 | resnet.avgpool, 86 | tie_activation_quantizers=True, 87 | input_quantizer=self.features[-1][-1].activation_quantizer, 88 | **quant_params, 89 | ) 90 | self.flattener = Flattener() 91 | self.fc = quantize_model(resnet.fc, **quant_params) 92 | 93 | # Adapt to specific quantization setup 94 | if quant_setup == "LSQ": 95 | print("Set quantization to LSQ (first+last layer in 8 bits)") 96 | # Weights of the first layer 97 | self.features[0].weight_quantizer.quantizer.n_bits = 8 98 | # The quantizer of the residual (input to last layer) 99 | self.features[-1][-1].activation_quantizer.quantizer.n_bits = 8 100 | # Output of the last conv (input to last layer) 101 | self.features[-1][-1].features[-1].activation_quantizer.quantizer.n_bits = 8 102 | # Weights of the last layer 103 | self.fc.weight_quantizer.quantizer.n_bits = 8 104 | # no activation quantization of logits 105 | self.fc.activation_quantizer = FP32Acts() 106 | elif quant_setup == "LSQ_paper": 107 | # Weights of the first layer 108 | self.features[0].activation_quantizer = FP32Acts() 109 | self.features[0].weight_quantizer.quantizer.n_bits = 8 110 | # Weights of the last layer 111 | self.fc.activation_quantizer.quantizer.n_bits = 8 112 | self.fc.weight_quantizer.quantizer.n_bits = 8 113 | # Set all QuantizedActivations to FP32 114 | for layer in self.features.modules(): 115 | if isinstance(layer, QuantizedActivation): 116 | layer.activation_quantizer = FP32Acts() 117 | elif quant_setup == "FP_logits": 118 | print("Do not quantize output of FC layer") 119 | self.fc.activation_quantizer = FP32Acts() # no activation 
quantization of logits 120 | elif quant_setup == "fc4": 121 | self.features[0].weight_quantizer.quantizer.n_bits = 8 122 | self.fc.weight_quantizer.quantizer.n_bits = 4 123 | elif quant_setup is not None and quant_setup != "all": 124 | raise ValueError("Quantization setup '{}' not supported for Resnet".format(quant_setup)) 125 | 126 | def forward(self, x): 127 | x = self.features(x) 128 | x = self.avgpool(x) 129 | # x = x.view(x.size(0), -1) 130 | x = self.flattener(x) 131 | x = self.fc(x) 132 | 133 | return x 134 | 135 | 136 | def resnet18_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 137 | if load_type == "fp32": 138 | # Load model from pretrained FP32 weights 139 | fp_model = resnet18(pretrained=pretrained) 140 | quant_model = QuantizedResNet(fp_model, **qparams) 141 | elif load_type == "quantized": 142 | # Load pretrained QuantizedModel 143 | print(f"Loading pretrained quantized model from {model_dir}") 144 | state_dict = torch.load(model_dir) 145 | fp_model = resnet18() 146 | quant_model = QuantizedResNet(fp_model, **qparams) 147 | quant_model.load_state_dict(state_dict) 148 | else: 149 | raise ValueError("wrong load_type specified") 150 | return quant_model 151 | 152 | 153 | def resnet50_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 154 | if load_type == "fp32": 155 | # Load model from pretrained FP32 weights 156 | fp_model = resnet50(pretrained=pretrained) 157 | quant_model = QuantizedResNet(fp_model, **qparams) 158 | elif load_type == "quantized": 159 | # Load pretrained QuantizedModel 160 | print(f"Loading pretrained quantized model from {model_dir}") 161 | state_dict = torch.load(model_dir) 162 | fp_model = resnet50() 163 | quant_model = QuantizedResNet(fp_model, **qparams) 164 | quant_model.load_state_dict(state_dict) 165 | else: 166 | raise ValueError("wrong load_type specified") 167 | return quant_model 168 | -------------------------------------------------------------------------------- 
/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from . import utils, autoquant_utils, quantized_folded_bn 6 | -------------------------------------------------------------------------------- /quantization/autoquant_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import copy 6 | import warnings 7 | 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.nn.modules.conv import _ConvNd 11 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd 12 | 13 | 14 | from quantization.base_quantized_classes import QuantizedActivation, QuantizedModule 15 | from quantization.hijacker import QuantizationHijacker, activations_set 16 | from quantization.quantization_manager import QuantizationManager 17 | from quantization.quantized_folded_bn import BNFusedHijacker 18 | 19 | 20 | class QuantConv1d(QuantizationHijacker, nn.Conv1d): 21 | def run_forward(self, x, weight, bias, offsets=None): 22 | return F.conv1d( 23 | x.contiguous(), 24 | weight.contiguous(), 25 | bias=bias, 26 | stride=self.stride, 27 | padding=self.padding, 28 | dilation=self.dilation, 29 | groups=self.groups, 30 | ) 31 | 32 | 33 | class QuantConv(QuantizationHijacker, nn.Conv2d): 34 | def run_forward(self, x, weight, bias, offsets=None): 35 | return F.conv2d( 36 | x.contiguous(), 37 | weight.contiguous(), 38 | bias=bias, 39 | stride=self.stride, 40 | padding=self.padding, 41 | dilation=self.dilation, 42 | groups=self.groups, 43 | ) 44 | 45 | 46 | class QuantConvTransposeBase(QuantizationHijacker): 47 | def quantize_weights(self, weights): 48 | if self.per_channel_weights: 49 | # NOTE: ND transpose conv weights are stored as (in_channels, out_channels, 
*) 50 | # instead of (out_channels, in_channels, *) for convs 51 | # and per-channel quantization should be applied to out channels 52 | # transposing before passing to the quantizer is a trick to avoid 53 | # changing logic in range estimators and quantizers 54 | weights = weights.transpose(1, 0).contiguous() 55 | weights = self.weight_quantizer(weights) 56 | if self.per_channel_weights: 57 | weights = weights.transpose(1, 0).contiguous() 58 | return weights 59 | 60 | 61 | class QuantConvTranspose1d(QuantConvTransposeBase, nn.ConvTranspose1d): 62 | def run_forward(self, x, weight, bias, offsets=None): 63 | return F.conv_transpose1d( 64 | x.contiguous(), 65 | weight.contiguous(), 66 | bias=bias, 67 | stride=self.stride, 68 | padding=self.padding, 69 | output_padding=self.output_padding, 70 | dilation=self.dilation, 71 | groups=self.groups, 72 | ) 73 | 74 | 75 | class QuantConvTranspose(QuantConvTransposeBase, nn.ConvTranspose2d): 76 | def run_forward(self, x, weight, bias, offsets=None): 77 | return F.conv_transpose2d( 78 | x.contiguous(), 79 | weight.contiguous(), 80 | bias=bias, 81 | stride=self.stride, 82 | padding=self.padding, 83 | output_padding=self.output_padding, 84 | dilation=self.dilation, 85 | groups=self.groups, 86 | ) 87 | 88 | 89 | class QuantLinear(QuantizationHijacker, nn.Linear): 90 | def run_forward(self, x, weight, bias, offsets=None): 91 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias) 92 | 93 | 94 | class BNQConv1d(BNFusedHijacker, nn.Conv1d): 95 | def run_forward(self, x, weight, bias, offsets=None): 96 | return F.conv1d( 97 | x.contiguous(), 98 | weight.contiguous(), 99 | bias=bias, 100 | stride=self.stride, 101 | padding=self.padding, 102 | dilation=self.dilation, 103 | groups=self.groups, 104 | ) 105 | 106 | 107 | class BNQConv(BNFusedHijacker, nn.Conv2d): 108 | def run_forward(self, x, weight, bias, offsets=None): 109 | return F.conv2d( 110 | x.contiguous(), 111 | weight.contiguous(), 112 | bias=bias, 113 | stride=self.stride, 
114 | padding=self.padding, 115 | dilation=self.dilation, 116 | groups=self.groups, 117 | ) 118 | 119 | 120 | class BNQLinear(BNFusedHijacker, nn.Linear): 121 | def run_forward(self, x, weight, bias, offsets=None): 122 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias) 123 | 124 | 125 | class QuantizedActivationWrapper(QuantizedActivation): 126 | """ 127 | Wraps a layer and quantizes its output activation. 128 | It also allows tying the input and output quantizers, which is helpful 129 | for layers such as Average Pooling 130 | """ 131 | 132 | def __init__( 133 | self, 134 | layer, 135 | tie_activation_quantizers=False, 136 | input_quantizer: QuantizationManager = None, 137 | *args, 138 | **kwargs, 139 | ): 140 | super().__init__(*args, **kwargs) 141 | self.tie_activation_quantizers = tie_activation_quantizers 142 | if input_quantizer: 143 | assert isinstance(input_quantizer, QuantizationManager) 144 | self.activation_quantizer = input_quantizer 145 | self.layer = layer 146 | 147 | def quantize_activations_no_range_update(self, x): 148 | if self._quant_a: 149 | return self.activation_quantizer.quantizer(x) 150 | else: 151 | return x 152 | 153 | def forward(self, x): 154 | x = self.layer(x) 155 | if self.tie_activation_quantizers: 156 | # The input activation quantizer is used to quantize the activation 157 | # but without updating the quantization range 158 | return self.quantize_activations_no_range_update(x) 159 | else: 160 | return self.quantize_activations(x) 161 | 162 | def extra_repr(self): 163 | return f"tie_activation_quantizers={self.tie_activation_quantizers}" 164 | 165 | 166 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm): 167 | def run_forward(self, x, weight, bias, offsets=None): 168 | return F.layer_norm( 169 | input=x.contiguous(), 170 | normalized_shape=self.normalized_shape, 171 | weight=weight.contiguous(), 172 | bias=bias.contiguous(), 173 | eps=self.eps, 174 | ) 175 | 176 | 177 | class Flattener(nn.Module): 178 | def 
forward(self, x): 179 | return x.view(x.shape[0], -1) 180 | 181 | 182 | # Non BN Quant Modules Map 183 | non_bn_module_map = { 184 | nn.Conv1d: QuantConv1d, 185 | nn.Conv2d: QuantConv, 186 | nn.ConvTranspose1d: QuantConvTranspose1d, 187 | nn.ConvTranspose2d: QuantConvTranspose, 188 | nn.Linear: QuantLinear, 189 | nn.LayerNorm: QuantLayerNorm, 190 | } 191 | 192 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd) 193 | # BN Quant Modules Map 194 | bn_module_map = {nn.Conv1d: BNQConv1d, nn.Conv2d: BNQConv, nn.Linear: BNQLinear} 195 | 196 | quant_conv_modules = (QuantConv1d, QuantConv, BNQConv1d, BNQConv) 197 | 198 | 199 | def next_bn(module, i): 200 | return len(module) > i + 1 and isinstance(module[i + 1], (nn.BatchNorm2d, nn.BatchNorm1d)) 201 | 202 | 203 | def get_act(module, i): 204 | # Case 1: conv + act 205 | if len(module) - i > 1 and isinstance(module[i + 1], tuple(activations_set)): 206 | return module[i + 1], i + 1 207 | 208 | # Case 2: conv + bn + act 209 | if ( 210 | len(module) - i > 2 211 | and next_bn(module, i) 212 | and isinstance(module[i + 2], tuple(activations_set)) 213 | ): 214 | return module[i + 2], i + 2 215 | 216 | # Case 3: conv + bn + X -> return false 217 | # Case 4: conv + X -> return false 218 | return None, None 219 | 220 | 221 | def get_conv_args(module): 222 | args = dict( 223 | in_channels=module.in_channels, 224 | out_channels=module.out_channels, 225 | kernel_size=module.kernel_size, 226 | stride=module.stride, 227 | padding=module.padding, 228 | dilation=module.dilation, 229 | groups=module.groups, 230 | bias=module.bias is not None, 231 | ) 232 | if isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d)): 233 | args["output_padding"] = module.output_padding 234 | return args 235 | 236 | 237 | def get_linear_args(module): 238 | args = dict( 239 | in_features=module.in_features, 240 | out_features=module.out_features, 241 | bias=module.bias is not None, 242 | ) 243 | return args 244 | 245 | 246 | def 
get_layernorm_args(module): 247 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps) 248 | return args 249 | 250 | 251 | def get_module_args(mod, act): 252 | if isinstance(mod, _ConvNd): 253 | kwargs = get_conv_args(mod) 254 | elif isinstance(mod, nn.Linear): 255 | kwargs = get_linear_args(mod) 256 | elif isinstance(mod, nn.LayerNorm): 257 | kwargs = get_layernorm_args(mod) 258 | else: 259 | raise ValueError 260 | 261 | kwargs["activation"] = act 262 | 263 | return kwargs 264 | 265 | 266 | def fold_bn(module, i, **quant_params): 267 | bn = next_bn(module, i) 268 | act, act_idx = get_act(module, i) 269 | modmap = bn_module_map if bn else non_bn_module_map 270 | modtype = modmap[type(module[i])] 271 | 272 | kwargs = get_module_args(module[i], act) 273 | new_module = modtype(**kwargs, **quant_params) 274 | new_module.weight.data = module[i].weight.data.clone() 275 | 276 | if bn: 277 | new_module.gamma.data = module[i + 1].weight.data.clone() 278 | new_module.beta.data = module[i + 1].bias.data.clone() 279 | new_module.running_mean.data = module[i + 1].running_mean.data.clone() 280 | new_module.running_var.data = module[i + 1].running_var.data.clone() 281 | if module[i].bias is not None: 282 | new_module.running_mean.data -= module[i].bias.data 283 | print("Warning: bias in conv/linear before batch normalization.") 284 | new_module.epsilon = module[i + 1].eps 285 | 286 | elif module[i].bias is not None: 287 | new_module.bias.data = module[i].bias.data.clone() 288 | 289 | return new_module, i + int(bool(act)) + int(bn) + 1 290 | 291 | 292 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params): 293 | specials = specials or dict() 294 | 295 | i = 0 296 | quant_modules = [] 297 | while i < len(model): 298 | if isinstance(model[i], QuantizedModule): 299 | quant_modules.append(model[i]) 300 | elif type(model[i]) in non_bn_module_map: 301 | new_module, new_i = fold_bn(model, i, **quant_params) 302 | 
quant_modules.append(new_module) 303 | i = new_i 304 | continue 305 | 306 | elif type(model[i]) in specials: 307 | quant_modules.append(specials[type(model[i])](model[i], **quant_params)) 308 | 309 | elif isinstance(model[i], non_param_modules): 310 | # Check for last quantizer 311 | input_quantizer = None 312 | if quant_modules and isinstance(quant_modules[-1], QuantizedModule): 313 | last_layer = quant_modules[-1] 314 | input_quantizer = quant_modules[-1].activation_quantizer 315 | elif ( 316 | quant_modules 317 | and isinstance(quant_modules[-1], nn.Sequential) 318 | and isinstance(quant_modules[-1][-1], QuantizedModule) 319 | ): 320 | last_layer = quant_modules[-1][-1] 321 | input_quantizer = quant_modules[-1][-1].activation_quantizer 322 | 323 | if input_quantizer and tie_activation_quantizers: 324 | # If an input quantizer is found, then tie the input/output act quantizers 325 | print( 326 | f"Tying input quantizer {i-1}^th layer of type {type(last_layer)} to the " 327 | f"quantized {type(model[i])} following it" 328 | ) 329 | quant_modules.append( 330 | QuantizedActivationWrapper( 331 | model[i], 332 | tie_activation_quantizers=tie_activation_quantizers, 333 | input_quantizer=input_quantizer, 334 | **quant_params, 335 | ) 336 | ) 337 | else: 338 | # Input quantizer not found 339 | quant_modules.append(QuantizedActivationWrapper(model[i], **quant_params)) 340 | if tie_activation_quantizers: 341 | warnings.warn("Input quantizer not found, so we do not tie quantizers") 342 | else: 343 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params)) 344 | i += 1 345 | return nn.Sequential(*quant_modules) 346 | 347 | 348 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params): 349 | specials = specials or dict() 350 | 351 | if isinstance(model, nn.Sequential): 352 | quant_model = quantize_sequential( 353 | model, specials, tie_activation_quantizers, **quant_params 354 | ) 355 | 356 | elif type(model) in specials: 357 
| quant_model = specials[type(model)](model, **quant_params) 358 | 359 | elif isinstance(model, non_param_modules): 360 | quant_model = QuantizedActivationWrapper(model, **quant_params) 361 | 362 | elif type(model) in non_bn_module_map: 363 | # If we do isinstance() then we might run into issues with modules that inherit from 364 | # one of these classes, for whatever reason 365 | modtype = non_bn_module_map[type(model)] 366 | kwargs = get_module_args(model, None) 367 | quant_model = modtype(**kwargs, **quant_params) 368 | 369 | quant_model.weight.data = model.weight.data 370 | if getattr(model, "bias", None) is not None: 371 | quant_model.bias.data = model.bias.data 372 | 373 | else: 374 | # Unknown type, try to quantize all child modules 375 | quant_model = copy.deepcopy(model) 376 | for name, module in quant_model._modules.items(): 377 | new_model = quantize_model(module, specials=specials, **quant_params) 378 | if new_model is not None: 379 | setattr(quant_model, name, new_model) 380 | 381 | return quant_model 382 | -------------------------------------------------------------------------------- /quantization/base_quantized_classes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | import torch 6 | from torch import nn 7 | 8 | from quantization.quantization_manager import QuantizationManager 9 | from quantization.quantizers import QuantizerBase, AsymmetricUniformQuantizer 10 | from quantization.quantizers.rounding_utils import round_ste_func 11 | from quantization.range_estimators import ( 12 | RangeEstimatorBase, 13 | CurrentMinMaxEstimator, 14 | RunningMinMaxEstimator, 15 | ) 16 | 17 | 18 | def _set_layer_learn_ranges(layer): 19 | if isinstance(layer, QuantizationManager): 20 | if layer.quantizer.is_initialized: 21 | layer.learn_ranges() 22 | 23 | 24 | def _set_layer_fix_ranges(layer): 25 | if isinstance(layer, QuantizationManager): 26 | if layer.quantizer.is_initialized: 27 | layer.fix_ranges() 28 | 29 | 30 | def _set_layer_estimate_ranges(layer): 31 | if isinstance(layer, QuantizationManager): 32 | layer.estimate_ranges() 33 | 34 | 35 | def _set_layer_estimate_ranges_train(layer): 36 | if isinstance(layer, QuantizationManager): 37 | if layer.quantizer.is_initialized: 38 | layer.estimate_ranges_train() 39 | 40 | 41 | class QuantizedModule(nn.Module): 42 | """ 43 | Parent class for a quantized module. It adds the basic functionality of switching the module 44 | between quantized and full precision mode. It also defines the cached parameters and handles 45 | the reset of the cache properly. 
46 | """ 47 | 48 | def __init__( 49 | self, 50 | *args, 51 | method: QuantizerBase = AsymmetricUniformQuantizer, 52 | act_method=None, 53 | weight_range_method: RangeEstimatorBase = CurrentMinMaxEstimator, 54 | act_range_method: RangeEstimatorBase = RunningMinMaxEstimator, 55 | n_bits=8, 56 | n_bits_act=None, 57 | per_channel_weights=False, 58 | percentile=None, 59 | weight_range_options=None, 60 | act_range_options=None, 61 | scale_domain="linear", 62 | act_quant_kwargs={}, 63 | weight_quant_kwargs={}, 64 | weight_discretizer=round_ste_func, 65 | act_discretizer=round_ste_func, 66 | act_discretizer_args=tuple(), 67 | weight_discretizer_args=tuple(), 68 | quantize_input=False, 69 | **kwargs 70 | ): 71 | kwargs.pop("act_quant_dict", None) 72 | 73 | super().__init__(*args, **kwargs) 74 | 75 | self.method = method 76 | self.act_method = act_method or method 77 | self.n_bits = n_bits 78 | self.n_bits_act = n_bits_act or n_bits 79 | self.per_channel_weights = per_channel_weights 80 | self.percentile = percentile 81 | self.weight_range_method = weight_range_method 82 | self.weight_range_options = weight_range_options if weight_range_options else {} 83 | self.act_range_method = act_range_method 84 | self.act_range_options = act_range_options if act_range_options else {} 85 | self.scale_domain = scale_domain 86 | self.quantize_input = quantize_input 87 | 88 | self.quant_params = None 89 | self.register_buffer("_quant_w", torch.BoolTensor([False])) 90 | self.register_buffer("_quant_a", torch.BoolTensor([False])) 91 | 92 | self.act_qparams = dict( 93 | n_bits=self.n_bits_act, 94 | scale_domain=self.scale_domain, 95 | discretizer=act_discretizer, 96 | discretizer_args=act_discretizer_args, 97 | **act_quant_kwargs 98 | ) 99 | self.weight_qparams = dict( 100 | n_bits=self.n_bits, 101 | scale_domain=self.scale_domain, 102 | discretizer=weight_discretizer, 103 | discretizer_args=weight_discretizer_args, 104 | **weight_quant_kwargs 105 | ) 106 | 107 | def quantized_weights(self): 
108 | self._quant_w = torch.BoolTensor([True]) 109 | 110 | def full_precision_weights(self): 111 | self._quant_w = torch.BoolTensor([False]) 112 | 113 | def quantized_acts(self): 114 | self._quant_a = torch.BoolTensor([True]) 115 | 116 | def full_precision_acts(self): 117 | self._quant_a = torch.BoolTensor([False]) 118 | 119 | def quantized(self): 120 | self.quantized_weights() 121 | self.quantized_acts() 122 | 123 | def full_precision(self): 124 | self.full_precision_weights() 125 | self.full_precision_acts() 126 | 127 | def get_quantizer_status(self): 128 | return dict(quant_a=self._quant_a.item(), quant_w=self._quant_w.item()) 129 | 130 | def set_quantizer_status(self, quantizer_status): 131 | if quantizer_status["quant_a"]: 132 | self.quantized_acts() 133 | else: 134 | self.full_precision_acts() 135 | 136 | if quantizer_status["quant_w"]: 137 | self.quantized_weights() 138 | else: 139 | self.full_precision_weights() 140 | 141 | def learn_ranges(self): 142 | self.apply(_set_layer_learn_ranges) 143 | 144 | def fix_ranges(self): 145 | self.apply(_set_layer_fix_ranges) 146 | 147 | def estimate_ranges(self): 148 | self.apply(_set_layer_estimate_ranges) 149 | 150 | def estimate_ranges_train(self): 151 | self.apply(_set_layer_estimate_ranges_train) 152 | 153 | def extra_repr(self): 154 | quant_state = "weight_quant={}, act_quant={}".format( 155 | self._quant_w.item(), self._quant_a.item() 156 | ) 157 | parent_repr = super().extra_repr() 158 | return "{},\n{}".format(parent_repr, quant_state) if parent_repr else quant_state 159 | 160 | 161 | class QuantizedActivation(QuantizedModule): 162 | def __init__(self, *args, **kwargs): 163 | super().__init__(*args, **kwargs) 164 | self.activation_quantizer = QuantizationManager( 165 | qmethod=self.act_method, 166 | qparams=self.act_qparams, 167 | init=self.act_range_method, 168 | range_estim_params=self.act_range_options, 169 | ) 170 | 171 | def quantize_activations(self, x): 172 | if self._quant_a: 173 | return 
self.activation_quantizer(x) 174 | else: 175 | return x 176 | 177 | def forward(self, x): 178 | return self.quantize_activations(x) 179 | 180 | 181 | class FP32Acts(nn.Module): 182 | def forward(self, x): 183 | return x 184 | 185 | def reset_ranges(self): 186 | pass 187 | -------------------------------------------------------------------------------- /quantization/base_quantized_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | from typing import Union, Dict 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | 9 | from quantization.base_quantized_classes import ( 10 | QuantizedModule, 11 | _set_layer_estimate_ranges, 12 | _set_layer_estimate_ranges_train, 13 | _set_layer_learn_ranges, 14 | _set_layer_fix_ranges, 15 | ) 16 | from quantization.quantizers import QuantizerBase 17 | 18 | 19 | class QuantizedModel(nn.Module): 20 | """ 21 | Parent class for a quantized model. This allows you to have convenience functions to put the 22 | whole model into quantization or full precision. 23 | """ 24 | 25 | def __init__(self, input_size=(1, 3, 224, 224)): 26 | """ 27 | Parameters 28 | ---------- 29 | input_size: Tuple with the input dimension for the model (including batch dimension) 30 | """ 31 | super().__init__() 32 | self.input_size = input_size 33 | 34 | def load_state_dict( 35 | self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True 36 | ): 37 | """ 38 | This function overwrites the load_state_dict of nn.Module to ensure that quantization 39 | parameters are loaded correctly for quantized model. 
40 | 41 | """ 42 | quant_state_dict = { 43 | k: v for k, v in state_dict.items() if k.endswith("_quant_a") or k.endswith("_quant_w") 44 | } 45 | 46 | if quant_state_dict: 47 | super().load_state_dict(quant_state_dict, strict=False) 48 | else: 49 | raise ValueError( 50 | "The quantization states of activations or weights should be " 51 | "included in the state dict " 52 | ) 53 | # Pass dummy data through quantized model to ensure all quantization parameters are 54 | # initialized with the correct dimensions (None tensors will lead to issues in state dict 55 | # loading) 56 | device = next(self.parameters()).device 57 | dummy_input = torch.rand(*self.input_size, device=device) 58 | with torch.no_grad(): 59 | self.forward(dummy_input) 60 | 61 | # Load state dict 62 | super().load_state_dict(state_dict, strict) 63 | 64 | def quantized_weights(self): 65 | def _fn(layer): 66 | if isinstance(layer, QuantizedModule): 67 | layer.quantized_weights() 68 | 69 | self.apply(_fn) 70 | 71 | def full_precision_weights(self): 72 | def _fn(layer): 73 | if isinstance(layer, QuantizedModule): 74 | layer.full_precision_weights() 75 | 76 | self.apply(_fn) 77 | 78 | def quantized_acts(self): 79 | def _fn(layer): 80 | if isinstance(layer, QuantizedModule): 81 | layer.quantized_acts() 82 | 83 | self.apply(_fn) 84 | 85 | def full_precision_acts(self): 86 | def _fn(layer): 87 | if isinstance(layer, QuantizedModule): 88 | layer.full_precision_acts() 89 | 90 | self.apply(_fn) 91 | 92 | def quantized(self): 93 | def _fn(layer): 94 | if isinstance(layer, QuantizedModule): 95 | layer.quantized() 96 | 97 | self.apply(_fn) 98 | 99 | def full_precision(self): 100 | def _fn(layer): 101 | if isinstance(layer, QuantizedModule): 102 | layer.full_precision() 103 | 104 | self.apply(_fn) 105 | 106 | def estimate_ranges(self): 107 | self.apply(_set_layer_estimate_ranges) 108 | 109 | def estimate_ranges_train(self): 110 | self.apply(_set_layer_estimate_ranges_train) 111 | 112 | def set_quant_state(self, 
weight_quant, act_quant): 113 | if act_quant: 114 | self.quantized_acts() 115 | else: 116 | self.full_precision_acts() 117 | 118 | if weight_quant: 119 | self.quantized_weights() 120 | else: 121 | self.full_precision_weights() 122 | 123 | def grad_scaling(self, grad_scaling=True): 124 | def _fn(module): 125 | if isinstance(module, QuantizerBase): 126 | module.grad_scaling = grad_scaling 127 | 128 | self.apply(_fn) 129 | # Methods for switching quantizer quantization states 130 | 131 | def learn_ranges(self): 132 | self.apply(_set_layer_learn_ranges) 133 | 134 | def fix_ranges(self): 135 | self.apply(_set_layer_fix_ranges) 136 | -------------------------------------------------------------------------------- /quantization/hijacker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import copy 6 | 7 | from timm.models.layers.activations import Swish, HardSwish, HardSigmoid 8 | from timm.models.layers.activations_me import SwishMe, HardSwishMe, HardSigmoidMe 9 | from torch import nn 10 | 11 | from quantization.base_quantized_classes import QuantizedModule 12 | from quantization.quantization_manager import QuantizationManager 13 | from quantization.range_estimators import RangeEstimators 14 | 15 | activations_set = [ 16 | nn.ReLU, 17 | nn.ReLU6, 18 | nn.Hardtanh, 19 | nn.Sigmoid, 20 | nn.Tanh, 21 | nn.GELU, 22 | nn.PReLU, 23 | Swish, 24 | SwishMe, 25 | HardSwish, 26 | HardSwishMe, 27 | HardSigmoid, 28 | HardSigmoidMe, 29 | ] 30 | 31 | 32 | class QuantizationHijacker(QuantizedModule): 33 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and 34 | dequantization on the weights and output distributions. 
35 | 36 | Usage: 37 | To make a quantized nn.Linear layer: 38 | class HijackedLinear(QuantizationHijacker, nn.Linear): 39 | pass 40 | """ 41 | 42 | def __init__(self, *args, activation: nn.Module = None, **kwargs): 43 | 44 | super().__init__(*args, **kwargs) 45 | if activation: 46 | assert isinstance(activation, tuple(activations_set)), str(activation) 47 | 48 | self.activation_function = copy.deepcopy(activation) if activation else None 49 | 50 | self.activation_quantizer = QuantizationManager( 51 | qmethod=self.act_method, 52 | init=self.act_range_method, 53 | qparams=self.act_qparams, 54 | range_estim_params=self.act_range_options, 55 | ) 56 | 57 | if self.weight_range_method == RangeEstimators.current_minmax: 58 | weight_init_params = dict(percentile=self.percentile) 59 | else: 60 | weight_init_params = self.weight_range_options 61 | 62 | self.weight_quantizer = QuantizationManager( 63 | qmethod=self.method, 64 | init=self.weight_range_method, 65 | per_channel=self.per_channel_weights, 66 | qparams=self.weight_qparams, 67 | range_estim_params=weight_init_params, 68 | ) 69 | 70 | def forward(self, x, offsets=None): 71 | # Quantize input 72 | if self.quantize_input and self._quant_a: 73 | x = self.activation_quantizer(x) 74 | 75 | # Get quantized weight 76 | weight, bias = self.get_params() 77 | res = self.run_forward(x, weight, bias, offsets=offsets) 78 | 79 | # Apply fused activation function 80 | if self.activation_function is not None: 81 | res = self.activation_function(res) 82 | 83 | # Quantize output 84 | if not self.quantize_input and self._quant_a: 85 | res = self.activation_quantizer(res) 86 | return res 87 | 88 | def get_params(self): 89 | 90 | weight, bias = self.get_weight_bias() 91 | 92 | if self._quant_w: 93 | weight = self.quantize_weights(weight) 94 | 95 | return weight, bias 96 | 97 | def quantize_weights(self, weights): 98 | return self.weight_quantizer(weights) 99 | 100 | def get_weight_bias(self): 101 | bias = None 102 | if hasattr(self, 
"bias"): 103 | bias = self.bias 104 | return self.weight, bias 105 | 106 | def run_forward(self, x, weight, bias, offsets=None): 107 | # Performs the actual linear operation of the layer 108 | raise NotImplementedError() 109 | 110 | def extra_repr(self): 111 | activation = "input" if self.quantize_input else "output" 112 | return f"{super().extra_repr()}-{activation}" 113 | -------------------------------------------------------------------------------- /quantization/quantization_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from enum import auto 6 | 7 | from torch import nn 8 | from quantization.quantizers import QMethods, QuantizerBase 9 | from quantization.quantizers.utils import QuantizerNotInitializedError 10 | from quantization.range_estimators import RangeEstimators, RangeEstimatorBase 11 | from utils import BaseEnumOptions 12 | 13 | 14 | class QuantizationManager(nn.Module): 15 | """Implementation of Quantization and Quantization Range Estimation 16 | 17 | Parameters 18 | ---------- 19 | n_bits: int 20 | Number of bits for the quantization. 21 | qmethod: QMethods member (Enum) 22 | The quantization scheme to use, e.g. symmetric_uniform, asymmetric_uniform, 23 | qmn_uniform etc. 24 | init: RangeEstimators member (Enum) 25 | Initialization method for the grid from 26 | per_channel: bool 27 | If true, will use a separate quantization grid for each kernel/channel. 28 | x_min: float or PyTorch Tensor 29 | The minimum value which needs to be represented. 30 | x_max: float or PyTorch Tensor 31 | The maximum value which needs to be represented. 
32 | qparams: kwargs 33 | dictionary of quantization parameters to be passed to the quantizer instantiation 34 | range_estim_params: kwargs 35 | dictionary of parameters to be passed to the range estimator instantiation 36 | """ 37 | 38 | def __init__( 39 | self, 40 | qmethod: QuantizerBase = QMethods.symmetric_uniform.cls, 41 | init: RangeEstimatorBase = RangeEstimators.current_minmax.cls, 42 | per_channel=False, 43 | x_min=None, 44 | x_max=None, 45 | qparams=None, 46 | range_estim_params=None, 47 | ): 48 | super().__init__() 49 | self.state = Qstates.estimate_ranges 50 | self.qmethod = qmethod 51 | self.init = init 52 | self.per_channel = per_channel 53 | self.qparams = qparams if qparams else {} 54 | self.range_estim_params = range_estim_params if range_estim_params else {} 55 | self.range_estimator = None 56 | 57 | # define quantizer 58 | self.quantizer = self.qmethod(per_channel=self.per_channel, **self.qparams) 59 | self.quantizer.state = self.state 60 | 61 | # define range estimation method for quantizer initialisation 62 | if x_min is not None and x_max is not None: 63 | self.set_quant_range(x_min, x_max) 64 | self.fix_ranges() 65 | else: 66 | # set up the collector function to set the ranges 67 | self.range_estimator = self.init( 68 | per_channel=self.per_channel, quantizer=self.quantizer, **self.range_estim_params 69 | ) 70 | 71 | @property 72 | def n_bits(self): 73 | return self.quantizer.n_bits 74 | 75 | def estimate_ranges(self): 76 | self.state = Qstates.estimate_ranges 77 | self.quantizer.state = self.state 78 | 79 | def fix_ranges(self): 80 | if self.quantizer.is_initialized: 81 | self.state = Qstates.fix_ranges 82 | self.quantizer.state = self.state 83 | else: 84 | raise QuantizerNotInitializedError() 85 | 86 | def learn_ranges(self): 87 | self.quantizer.make_range_trainable() 88 | self.state = Qstates.learn_ranges 89 | self.quantizer.state = self.state 90 | 91 | def estimate_ranges_train(self): 92 | self.state = Qstates.estimate_ranges_train 93 |
self.quantizer.state = self.state 94 | 95 | def reset_ranges(self): 96 | self.range_estimator.reset() 97 | self.quantizer.reset() 98 | self.estimate_ranges() 99 | 100 | def forward(self, x): 101 | if self.state == Qstates.estimate_ranges or ( 102 | self.state == Qstates.estimate_ranges_train and self.training 103 | ): 104 | # Note this can be per tensor or per channel 105 | cur_xmin, cur_xmax = self.range_estimator(x) 106 | self.set_quant_range(cur_xmin, cur_xmax) 107 | 108 | return self.quantizer(x) 109 | 110 | def set_quant_range(self, x_min, x_max): 111 | self.quantizer.set_quant_range(x_min, x_max) 112 | 113 | def extra_repr(self): 114 | return "state={}".format(self.state.name) 115 | 116 | 117 | class Qstates(BaseEnumOptions): 118 | estimate_ranges = auto() # ranges are updated in eval and train mode 119 | fix_ranges = auto() # quantization ranges are fixed for train and eval 120 | learn_ranges = auto() # quantization params are nn.Parameters 121 | estimate_ranges_train = auto() # quantization ranges are updated during train and fixed for 122 | # eval 123 | -------------------------------------------------------------------------------- /quantization/quantized_folded_bn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.nn.modules.conv import _ConvNd 8 | 9 | from quantization.hijacker import QuantizationHijacker 10 | 11 | 12 | class BNFusedHijacker(QuantizationHijacker): 13 | """Extension to the QuantizationHijacker that fuses batch normalization (BN) after a weight 14 | layer into a joined module. The parameters and the statistics of the BN layer remain in 15 | full-precision. 
16 | """ 17 | 18 | def __init__(self, *args, **kwargs): 19 | kwargs.pop("bias", None) # Bias will be learned by BN params 20 | super().__init__(*args, **kwargs, bias=False) 21 | bn_dim = self.get_bn_dim() 22 | self.register_buffer("running_mean", torch.zeros(bn_dim)) 23 | self.register_buffer("running_var", torch.ones(bn_dim)) 24 | self.momentum = kwargs.pop("momentum", 0.1) 25 | self.gamma = nn.Parameter(torch.ones(bn_dim)) 26 | self.beta = nn.Parameter(torch.zeros(bn_dim)) 27 | self.epsilon = kwargs.get("eps", 1e-5) 28 | self.bias = None 29 | 30 | def forward(self, x): 31 | # Quantize input 32 | if self.quantize_input and self._quant_a: 33 | x = self.activation_quantizer(x) 34 | 35 | # Get quantized weight 36 | weight, bias = self.get_params() 37 | res = self.run_forward(x, weight, bias) 38 | 39 | res = F.batch_norm( 40 | res, 41 | self.running_mean, 42 | self.running_var, 43 | self.gamma, 44 | self.beta, 45 | self.training, 46 | self.momentum, 47 | self.epsilon, 48 | ) 49 | # Apply fused activation function 50 | if self.activation_function is not None: 51 | res = self.activation_function(res) 52 | 53 | # Quantize output 54 | if not self.quantize_input and self._quant_a: 55 | res = self.activation_quantizer(res) 56 | return res 57 | 58 | def get_bn_dim(self): 59 | if isinstance(self, nn.Linear): 60 | return self.out_features 61 | elif isinstance(self, _ConvNd): 62 | return self.out_channels 63 | else: 64 | msg = ( 65 | f"Unsupported type used: {self}. Must be a linear or (transpose)-convolutional " 66 | f"nn.Module" 67 | ) 68 | raise NotImplementedError(msg) 69 | -------------------------------------------------------------------------------- /quantization/quantizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | from quantization.quantizers.base_quantizers import QuantizerBase 6 | from quantization.quantizers.uniform_quantizers import ( 7 | SymmetricUniformQuantizer, 8 | AsymmetricUniformQuantizer, 9 | ) 10 | from utils import ClassEnumOptions, MethodMap 11 | 12 | 13 | class QMethods(ClassEnumOptions): 14 | symmetric_uniform = MethodMap(SymmetricUniformQuantizer) 15 | asymmetric_uniform = MethodMap(AsymmetricUniformQuantizer) 16 | -------------------------------------------------------------------------------- /quantization/quantizers/base_quantizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from torch import nn 6 | 7 | 8 | class QuantizerBase(nn.Module): 9 | def __init__(self, n_bits, per_channel=False, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.n_bits = n_bits 12 | self.per_channel = per_channel 13 | self.state = None 14 | self.x_min_fp32 = self.x_max_fp32 = None 15 | 16 | @property 17 | def is_initialized(self): 18 | raise NotImplementedError() 19 | 20 | @property 21 | def x_max(self): 22 | raise NotImplementedError() 23 | 24 | @property 25 | def symmetric(self): 26 | raise NotImplementedError() 27 | 28 | @property 29 | def x_min(self): 30 | raise NotImplementedError() 31 | 32 | def forward(self, x_float): 33 | raise NotImplementedError() 34 | 35 | def _adjust_params_per_channel(self, x): 36 | raise NotImplementedError() 37 | 38 | def set_quant_range(self, x_min, x_max): 39 | raise NotImplementedError() 40 | 41 | def extra_repr(self): 42 | return "n_bits={}, per_channel={}, is_initialized={}".format( 43 | self.n_bits, self.per_channel, self.is_initialized 44 | ) 45 | 46 | def reset(self): 47 | self._delta = None 48 | -------------------------------------------------------------------------------- /quantization/quantizers/rounding_utils.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from torch import nn 5 | import torch 6 | from torch.autograd import Function 7 | 8 | # Functional 9 | from utils import MethodMap, ClassEnumOptions 10 | 11 | 12 | class RoundStraightThrough(Function): 13 | @staticmethod 14 | def forward(ctx, x): 15 | return torch.round(x) 16 | 17 | @staticmethod 18 | def backward(ctx, output_grad): 19 | return output_grad 20 | 21 | 22 | class StochasticRoundSTE(Function): 23 | @staticmethod 24 | def forward(ctx, x): 25 | # Sample noise between [0, 1) 26 | noise = torch.rand_like(x) 27 | return torch.floor(x + noise) 28 | 29 | @staticmethod 30 | def backward(ctx, output_grad): 31 | return output_grad 32 | 33 | 34 | class ScaleGradient(Function): 35 | @staticmethod 36 | def forward(ctx, x, scale): 37 | ctx.scale = scale 38 | return x 39 | 40 | @staticmethod 41 | def backward(ctx, output_grad): 42 | return output_grad * ctx.scale, None 43 | 44 | 45 | class EWGSFunctional(Function): 46 | """ 47 | x_in: float input 48 | scaling_factor: backward scaling factor 49 | x_out: discretized version of x_in within the range of [0,1] 50 | """ 51 | 52 | @staticmethod 53 | def forward(ctx, x_in, scaling_factor): 54 | x_int = torch.round(x_in) 55 | ctx._scaling_factor = scaling_factor 56 | ctx.save_for_backward(x_in - x_int) 57 | return x_int 58 | 59 | @staticmethod 60 | def backward(ctx, g): 61 | diff = ctx.saved_tensors[0] 62 | delta = ctx._scaling_factor 63 | scale = 1 + delta * torch.sign(g) * diff 64 | return g * scale, None, None 65 | 66 | 67 | class StackSigmoidFunctional(Function): 68 | @staticmethod 69 | def forward(ctx, x, alpha): 70 | # Apply round to nearest in the forward pass 71 | ctx.save_for_backward(x, alpha) 72 | return torch.round(x) 73 | 74 | @staticmethod 75 | def backward(ctx, grad_output): 76 | x, alpha = ctx.saved_tensors 77 | sig_min = torch.sigmoid(alpha / 
2) 78 | sig_scale = 1 - 2 * sig_min 79 | x_base = torch.floor(x).detach() 80 | x_rest = x - x_base - 0.5 81 | stacked_sigmoid_grad = ( 82 | torch.sigmoid(x_rest * -alpha) 83 | * (1 - torch.sigmoid(x_rest * -alpha)) 84 | * -alpha 85 | / sig_scale 86 | ) 87 | return stacked_sigmoid_grad * grad_output, None 88 | 89 | 90 | # Parametrized modules 91 | class ParametrizedGradEstimatorBase(nn.Module): 92 | def __init__(self, *args, **kwargs): 93 | super().__init__() 94 | self._trainable = False 95 | 96 | def make_grad_params_trainable(self): 97 | self._trainable = True 98 | for name, buf in self.named_buffers(recurse=False): 99 | setattr(self, name, torch.nn.Parameter(buf)) 100 | 101 | def make_grad_params_tensor(self): 102 | self._trainable = False 103 | for name, param in self.named_parameters(recurse=False): 104 | cur_value = param.data 105 | delattr(self, name) 106 | self.register_buffer(name, cur_value) 107 | 108 | def forward(self, x): 109 | raise NotImplementedError() 110 | 111 | 112 | class StackedSigmoid(ParametrizedGradEstimatorBase): 113 | """ 114 | Stacked sigmoid estimator based on a simulated sigmoid forward pass 115 | """ 116 | 117 | def __init__(self, alpha=1.0): 118 | super().__init__() 119 | self.register_buffer("alpha", torch.tensor(alpha)) 120 | 121 | def forward(self, x): 122 | return stacked_sigmoid_func(x, self.alpha) 123 | 124 | def extra_repr(self): 125 | return f"alpha={self.alpha.item()}" 126 | 127 | 128 | class EWGSDiscretizer(ParametrizedGradEstimatorBase): 129 | def __init__(self, scaling_factor=0.2): 130 | super().__init__() 131 | self.register_buffer("scaling_factor", torch.tensor(scaling_factor)) 132 | 133 | def forward(self, x): 134 | return ewgs_func(x, self.scaling_factor) 135 | 136 | def extra_repr(self): 137 | return f"scaling_factor={self.scaling_factor.item()}" 138 | 139 | 140 | class StochasticRounding(nn.Module): 141 | def __init__(self): 142 | super().__init__() 143 | 144 | def forward(self, x): 145 | if self.training: 146 | 
return stochastic_round_ste_func(x) 147 | else: 148 | return round_ste_func(x) 149 | 150 | 151 | round_ste_func = RoundStraightThrough.apply 152 | stacked_sigmoid_func = StackSigmoidFunctional.apply 153 | scale_grad_func = ScaleGradient.apply 154 | stochastic_round_ste_func = StochasticRoundSTE.apply 155 | ewgs_func = EWGSFunctional.apply 156 | 157 | 158 | class GradientEstimator(ClassEnumOptions): 159 | ste = MethodMap(round_ste_func) 160 | stoch_round = MethodMap(StochasticRounding) 161 | ewgs = MethodMap(EWGSDiscretizer) 162 | stacked_sigmoid = MethodMap(StackedSigmoid) 163 | -------------------------------------------------------------------------------- /quantization/quantizers/uniform_quantizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import inspect 6 | import torch 7 | 8 | from quantization.quantizers.rounding_utils import scale_grad_func, round_ste_func 9 | from .utils import QuantizerNotInitializedError 10 | from .base_quantizers import QuantizerBase 11 | 12 | 13 | class AsymmetricUniformQuantizer(QuantizerBase): 14 | """ 15 | PyTorch Module that implements Asymmetric Uniform Quantization using STE. 16 | Quantizes its argument in the forward pass, passes the gradient 'straight 17 | through' on the backward pass, ignoring the quantization that occurred. 18 | 19 | Parameters 20 | ---------- 21 | n_bits: int 22 | Number of bits for quantization. 
23 | scale_domain: str ('log', 'linear') with default='linear' 24 | Domain of scale factor 25 | per_channel: bool 26 | If True: allows for per-channel quantization 27 | """ 28 | 29 | def __init__( 30 | self, 31 | n_bits, 32 | scale_domain="linear", 33 | discretizer=round_ste_func, 34 | discretizer_args=tuple(), 35 | grad_scaling=False, 36 | eps=1e-8, 37 | **kwargs 38 | ): 39 | super().__init__(n_bits=n_bits, **kwargs) 40 | 41 | assert scale_domain in ("linear", "log") 42 | self.register_buffer("_delta", None) 43 | self.register_buffer("_zero_float", None) 44 | 45 | if inspect.isclass(discretizer): 46 | self.discretizer = discretizer(*discretizer_args) 47 | else: 48 | self.discretizer = discretizer 49 | 50 | self.scale_domain = scale_domain 51 | self.grad_scaling = grad_scaling 52 | self.eps = eps 53 | 54 | # A few useful properties 55 | @property 56 | def delta(self): 57 | if self._delta is not None: 58 | return self._delta 59 | else: 60 | raise QuantizerNotInitializedError() 61 | 62 | @property 63 | def zero_float(self): 64 | if self._zero_float is not None: 65 | return self._zero_float 66 | else: 67 | raise QuantizerNotInitializedError() 68 | 69 | @property 70 | def is_initialized(self): 71 | return self._delta is not None 72 | 73 | @property 74 | def symmetric(self): 75 | return False 76 | 77 | @property 78 | def int_min(self): 79 | # integer grid minimum 80 | return 0.0 81 | 82 | @property 83 | def int_max(self): 84 | # integer grid maximum 85 | return 2.0**self.n_bits - 1 86 | 87 | @property 88 | def scale(self): 89 | if self.scale_domain == "linear": 90 | return torch.clamp(self.delta, min=self.eps) 91 | elif self.scale_domain == "log": 92 | return torch.exp(self.delta) 93 | 94 | @property 95 | def zero_point(self): 96 | zero_point = self.discretizer(self.zero_float) 97 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max) 98 | return zero_point 99 | 100 | @property 101 | def x_max(self): 102 | return self.scale * (self.int_max - self.zero_point)
103 | 104 | @property 105 | def x_min(self): 106 | return self.scale * (self.int_min - self.zero_point) 107 | 108 | def to_integer_forward(self, x_float, *args, **kwargs): 109 | """ 110 | Quantizes the input to its integer representation 111 | Parameters 112 | ---------- 113 | x_float: PyTorch Float Tensor 114 | Full-precision Tensor 115 | 116 | Returns 117 | ------- 118 | x_int: PyTorch Float Tensor of integers 119 | """ 120 | if self.grad_scaling: 121 | grad_scale = self.calculate_grad_scale(x_float) 122 | scale = scale_grad_func(self.scale, grad_scale) 123 | zero_point = ( 124 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale) 125 | ) 126 | else: 127 | scale = self.scale 128 | zero_point = self.zero_point 129 | 130 | x_int = self.discretizer(x_float / scale) + zero_point 131 | x_int = torch.clamp(x_int, self.int_min, self.int_max) 132 | 133 | return x_int 134 | 135 | def forward(self, x_float, *args, **kwargs): 136 | """ 137 | Quantizes the input (maps it to the integer grid and then scales it back to the original domain) 138 | Parameters 139 | ---------- 140 | x_float: PyTorch Float Tensor 141 | Full-precision Tensor 142 | 143 | Returns 144 | ------- 145 | x_quant: PyTorch Float Tensor 146 | Quantized-Dequantized Tensor 147 | """ 148 | if self.per_channel: 149 | self._adjust_params_per_channel(x_float) 150 | 151 | if self.grad_scaling: 152 | grad_scale = self.calculate_grad_scale(x_float) 153 | scale = scale_grad_func(self.scale, grad_scale) 154 | zero_point = ( 155 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale) 156 | ) 157 | else: 158 | scale = self.scale 159 | zero_point = self.zero_point 160 | 161 | x_int = self.to_integer_forward(x_float, *args, **kwargs) 162 | x_quant = scale * (x_int - zero_point) 163 | 164 | return x_quant 165 | 166 | def calculate_grad_scale(self, quant_tensor): 167 | num_pos_levels = self.int_max # Qp in LSQ paper 168 | num_elements = quant_tensor.numel() # nfeatures or nweights in LSQ paper
169 | if self.per_channel: 170 | # In the per-channel case we do not sum the gradients over the output channel dimension 171 | num_elements /= quant_tensor.shape[0] 172 | 173 | return (num_pos_levels * num_elements) ** -0.5 # 1 / sqrt (Qp * nfeatures) 174 | 175 | def _adjust_params_per_channel(self, x): 176 | """ 177 | Adjusts the quantization parameter tensors (delta, zero_float) 178 | to the input tensor shape if they don't match 179 | Parameters 180 | ---------- 181 | x: input tensor 182 | """ 183 | if x.ndim != self.delta.ndim: 184 | new_shape = [-1] + [1] * (len(x.shape) - 1) 185 | self._delta = self.delta.view(new_shape) 186 | if self._zero_float is not None: 187 | self._zero_float = self._zero_float.view(new_shape) 188 | 189 | def _tensorize_min_max(self, x_min, x_max): 190 | """ 191 | Converts provided min max range into tensors 192 | Parameters 193 | ---------- 194 | x_min: float or PyTorch 1D tensor 195 | x_max: float or PyTorch 1D tensor 196 | 197 | Returns 198 | ------- 199 | x_min: PyTorch Tensor 0 or 1-D 200 | x_max: PyTorch Tensor 0 or 1-D 201 | """ 202 | # Ensure a torch tensor 203 | if not torch.is_tensor(x_min): 204 | x_min = torch.tensor(x_min).float() 205 | x_max = torch.tensor(x_max).float() 206 | 207 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel: 208 | print(x_min) 209 | print(self.per_channel) 210 | raise ValueError( 211 | "x_min and x_max must be a float or 1-D Tensor" 212 | " for per-tensor quantization (per_channel=False)" 213 | ) 214 | # Ensure the range always includes zero and avoid division by zero 215 | x_min = torch.min(x_min, torch.zeros_like(x_min)) 216 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps) 217 | 218 | return x_min, x_max 219 | 220 | def set_quant_range(self, x_min, x_max): 221 | """ 222 | Instantiates the quantization parameters based on the provided 223 | min and max range 224 | 225 | Parameters 226 | ---------- 227 | x_min: tensor or float 228 | Quantization range minimum limit 229 | x_max: tensor
or float 230 | Quantization range maximum limit 231 | """ 232 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max 233 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 234 | self._delta = (x_max - x_min) / self.int_max 235 | self._zero_float = (-x_min / self.delta).detach() 236 | 237 | if self.scale_domain == "log": 238 | self._delta = torch.log(self.delta) 239 | 240 | self._delta = self._delta.detach() 241 | 242 | def make_range_trainable(self): 243 | # Converts trainable parameters to nn.Parameters 244 | if self.delta not in self.parameters(): 245 | self._delta = torch.nn.Parameter(self._delta) 246 | self._zero_float = torch.nn.Parameter(self._zero_float) 247 | 248 | def fix_ranges(self): 249 | # Removes trainable quantization params from nn.Parameters 250 | if self.delta in self.parameters(): 251 | _delta = self._delta.data 252 | _zero_float = self._zero_float.data 253 | del self._delta # delete the parameter 254 | del self._zero_float 255 | self.register_buffer("_delta", _delta) 256 | self.register_buffer("_zero_float", _zero_float) 257 | 258 | 259 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer): 260 | """ 261 | PyTorch Module that implements Symmetric Uniform Quantization using STE. 262 | Quantizes its argument in the forward pass, passes the gradient 'straight 263 | through' on the backward pass, ignoring the quantization that occurred. 264 | 265 | Parameters 266 | ---------- 267 | n_bits: int 268 | Number of bits for quantization.
269 | scale_domain: str ('log', 'linear') with default='linear' 270 | Domain of scale factor 271 | per_channel: bool 272 | If True: allows for per-channel quantization 273 | """ 274 | 275 | def __init__(self, *args, **kwargs): 276 | super().__init__(*args, **kwargs) 277 | self.register_buffer("_signed", None) 278 | 279 | @property 280 | def signed(self): 281 | if self._signed is not None: 282 | return self._signed.item() 283 | else: 284 | raise QuantizerNotInitializedError() 285 | 286 | @property 287 | def symmetric(self): 288 | return True 289 | 290 | @property 291 | def int_min(self): 292 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0 293 | 294 | @property 295 | def int_max(self): 296 | pos_n_bits = self.n_bits - self.signed 297 | return 2.0**pos_n_bits - 1 298 | 299 | @property 300 | def zero_point(self): 301 | return 0.0 302 | 303 | def set_quant_range(self, x_min, x_max): 304 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max 305 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 306 | self._signed = x_min.min() < 0 307 | 308 | x_absmax = torch.max(x_min.abs(), x_max) 309 | self._delta = x_absmax / self.int_max 310 | 311 | if self.scale_domain == "log": 312 | self._delta = torch.log(self._delta) 313 | 314 | self._delta = self._delta.detach() 315 | 316 | def make_range_trainable(self): 317 | # Converts trainable parameters to nn.Parameters 318 | if self.delta not in self.parameters(): 319 | self._delta = torch.nn.Parameter(self._delta) 320 | 321 | def fix_ranges(self): 322 | # Removes trainable quantization params from nn.Parameters 323 | if self.delta in self.parameters(): 324 | _delta = self._delta.data 325 | del self._delta # delete the parameter 326 | self.register_buffer("_delta", _delta) 327 | -------------------------------------------------------------------------------- /quantization/quantizers/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022
Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | 6 | class QuantizerNotInitializedError(Exception): 7 | """Raised when a quantizer has not been initialized""" 8 | 9 | def __init__(self): 10 | super(QuantizerNotInitializedError, self).__init__( 11 | "Quantizer has not been initialized yet" 12 | ) 13 | -------------------------------------------------------------------------------- /quantization/range_estimators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | import copy 5 | from enum import auto 6 | 7 | import numpy as np 8 | import torch 9 | from scipy.optimize import minimize_scalar 10 | from torch import nn 11 | 12 | from utils import to_numpy, BaseEnumOptions, MethodMap, ClassEnumOptions 13 | 14 | 15 | class RangeEstimatorBase(nn.Module): 16 | def __init__(self, per_channel=False, quantizer=None, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | self.register_buffer("current_xmin", None) 19 | self.register_buffer("current_xmax", None) 20 | self.per_channel = per_channel 21 | self.quantizer = quantizer 22 | 23 | def forward(self, x): 24 | """ 25 | Accepts an input tensor, updates the current estimates of x_min and x_max 26 | and returns them. 27 | Parameters 28 | ---------- 29 | x: Input tensor 30 | 31 | Returns 32 | ------- 33 | self.current_xmin: tensor 34 | 35 | self.current_xmax: tensor 36 | 37 | """ 38 | raise NotImplementedError() 39 | 40 | def reset(self): 41 | """ 42 | Reset the range estimator. 43 | """ 44 | self.current_xmin = None 45 | self.current_xmax = None 46 | 47 | def __repr__(self): 48 | # We override this from nn.Module as we do not want to have submodules such as 49 | # self.quantizer in the repr. Otherwise it behaves as expected for an nn.Module.
50 | lines = self.extra_repr().split("\n") 51 | extra_str = lines[0] if len(lines) == 1 else "\n " + "\n ".join(lines) + "\n" 52 | 53 | return self._get_name() + "(" + extra_str + ")" 54 | 55 | 56 | class CurrentMinMaxEstimator(RangeEstimatorBase): 57 | def __init__(self, percentile=None, *args, **kwargs): 58 | self.percentile = percentile 59 | super().__init__(*args, **kwargs) 60 | 61 | def forward(self, x): 62 | if self.per_channel: 63 | x = x.view(x.shape[0], -1) 64 | if self.percentile: 65 | axis = -1 if self.per_channel else None 66 | data_np = to_numpy(x) 67 | x_min, x_max = np.percentile( 68 | data_np, (self.percentile, 100 - self.percentile), axis=axis 69 | ) 70 | self.current_xmin = torch.tensor(x_min).to(x.device) 71 | self.current_xmax = torch.tensor(x_max).to(x.device) 72 | else: 73 | self.current_xmin = x.min(-1)[0].detach() if self.per_channel else x.min().detach() 74 | self.current_xmax = x.max(-1)[0].detach() if self.per_channel else x.max().detach() 75 | 76 | return self.current_xmin, self.current_xmax 77 | 78 | 79 | class RunningMinMaxEstimator(RangeEstimatorBase): 80 | def __init__(self, momentum=0.9, *args, **kwargs): 81 | self.momentum = momentum 82 | super().__init__(*args, **kwargs) 83 | 84 | def forward(self, x): 85 | if self.per_channel: 86 | # Along 1st dim 87 | x_flattened = x.view(x.shape[0], -1) 88 | x_min = x_flattened.min(-1)[0].detach() 89 | x_max = x_flattened.max(-1)[0].detach() 90 | else: 91 | x_min = torch.min(x).detach() 92 | x_max = torch.max(x).detach() 93 | 94 | if self.current_xmin is None: 95 | self.current_xmin = x_min 96 | self.current_xmax = x_max 97 | else: 98 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin 99 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax 100 | 101 | return self.current_xmin, self.current_xmax 102 | 103 | 104 | class OptMethod(BaseEnumOptions): 105 | grid = auto() 106 | golden_section = auto() 107 | 108 | 109 | class 
MSE_Estimator(RangeEstimatorBase): 110 | def __init__( 111 | self, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, *args, **kwargs 112 | ): 113 | 114 | super().__init__(*args, **kwargs) 115 | assert opt_method in OptMethod 116 | self.opt_method = opt_method 117 | self.num_candidates = num_candidates 118 | self.loss_array = None 119 | self.max_pos_thr = None 120 | self.max_neg_thr = None 121 | self.max_search_range = None 122 | self.one_sided_dist = None 123 | self.range_margin = range_margin 124 | if self.quantizer is None: 125 | raise NotImplementedError( 126 | "A Quantizer must be given as an argument to the MSE Range Estimator" 127 | ) 128 | self.max_int_skew = (2**self.quantizer.n_bits) // 4 # For asymmetric quantization 129 | 130 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False): 131 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr) 132 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1) 133 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel 134 | # grid search 135 | if per_channel_loss: 136 | return to_numpy(temp_sum) 137 | else: 138 | return to_numpy(torch.sum(temp_sum)) 139 | 140 | @property 141 | def step_size(self): 142 | if self.one_sided_dist is None: 143 | raise NoDataPassedError() 144 | 145 | return self.max_search_range / self.num_candidates 146 | 147 | @property 148 | def optimization_method(self): 149 | if self.one_sided_dist is None: 150 | raise NoDataPassedError() 151 | 152 | if self.opt_method == OptMethod.grid: 153 | # Grid search method 154 | if self.one_sided_dist or self.quantizer.symmetric: 155 | # 1-D grid search 156 | return self._perform_1D_search 157 | else: 158 | # 2-D grid_search 159 | return self._perform_2D_search 160 | elif self.opt_method == OptMethod.golden_section: 161 | # Golden section method 162 | if self.one_sided_dist or self.quantizer.symmetric: 163 | return self._golden_section_symmetric 164 | else: 165 | return 
self._golden_section_asymmetric 166 | else: 167 | raise NotImplementedError("Optimization Method not Implemented") 168 | 169 | def quantize(self, x_float, x_min=None, x_max=None): 170 | temp_q = copy.deepcopy(self.quantizer) 171 | # In the current implementation no optimization procedure requires temp quantizer for 172 | # loss_fx to be per-channel 173 | temp_q.per_channel = False 174 | if x_min or x_max: 175 | temp_q.set_quant_range(x_min, x_max) 176 | return temp_q(x_float) 177 | 178 | def golden_sym_loss(self, range, data): 179 | """ 180 | Loss function passed to the golden section optimizer from scipy in case of symmetric 181 | quantization 182 | """ 183 | neg_thr = 0 if self.one_sided_dist else -range 184 | pos_thr = range 185 | return self.loss_fx(data, neg_thr, pos_thr) 186 | 187 | def golden_asym_shift_loss(self, shift, range, data): 188 | """ 189 | Inner Loss function (shift) passed to the golden section optimizer from scipy 190 | in case of asymmetric quantization 191 | """ 192 | pos_thr = range + shift 193 | neg_thr = -range + shift 194 | return self.loss_fx(data, neg_thr, pos_thr) 195 | 196 | def golden_asym_range_loss(self, range, data): 197 | """ 198 | Outer Loss function (range) passed to the golden section optimizer from scipy in case of 199 | asymmetric quantization 200 | """ 201 | temp_delta = 2 * range / (2**self.quantizer.n_bits - 1) 202 | max_shift = temp_delta * self.max_int_skew 203 | result = minimize_scalar( 204 | self.golden_asym_shift_loss, 205 | args=(range, data), 206 | bounds=(-max_shift, max_shift), 207 | method="Bounded", 208 | ) 209 | return result.fun 210 | 211 | def _define_search_range(self, data): 212 | self.channel_groups = len(data) if self.per_channel else 1 213 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device) 214 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device) 215 | 216 | if self.one_sided_dist or self.quantizer.symmetric: 217 | # 1D search space 218 | self.loss_array = 
np.zeros( 219 | (self.channel_groups, self.num_candidates + 1) 220 | ) # 1D search space 221 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish 222 | # Defining the search range for clipping thresholds 223 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin 224 | self.max_neg_thr = -self.max_pos_thr 225 | self.max_search_range = self.max_pos_thr 226 | else: 227 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth 228 | # index represents whether the skew is positive (0) or negative (1)) 229 | self.loss_array = np.zeros( 230 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2] 231 | ) # 2D search space 232 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish 233 | # Define the search range for clipping thresholds in asymmetric case 234 | self.max_pos_thr = float(data.max()) + self.range_margin 235 | self.max_neg_thr = float(data.min()) - self.range_margin 236 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr)) 237 | 238 | def _perform_1D_search(self, data): 239 | """ 240 | Grid search through all candidate quantizers in 1D to find the best 241 | The loss is accumulated over all batches without any momentum 242 | :param data: input tensor 243 | """ 244 | for cand_index in range(1, self.num_candidates + 1): 245 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index 246 | pos_thr = self.step_size * cand_index 247 | 248 | self.loss_array[:, cand_index] += self.loss_fx( 249 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 250 | ) 251 | # find the best clipping thresholds 252 | min_cand = self.loss_array.argmin(axis=1) 253 | xmin = ( 254 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand 255 | ).astype(np.single) 256 | xmax = (self.step_size * min_cand).astype(np.single) 257 | self.current_xmax = torch.tensor(xmax).to(device=data.device) 258 | self.current_xmin 
= torch.tensor(xmin).to(device=data.device) 259 | 260 | def _perform_2D_search(self, data): 261 | """ 262 | Grid search through all candidate quantizers in 2D (range and skew) to find the best 263 | The loss is accumulated over all batches without any momentum 264 | Parameters 265 | ---------- 266 | data: PyTorch Tensor 267 | Returns 268 | ------- 269 | 270 | """ 271 | for cand_index in range(1, self.num_candidates + 1): 272 | # defining the symmetric quantization range 273 | temp_start = -self.step_size * cand_index 274 | temp_finish = self.step_size * cand_index 275 | temp_delta = float(temp_finish - temp_start) / (2**self.quantizer.n_bits - 1) 276 | for shift in range(self.max_int_skew): 277 | for reverse in range(2): 278 | # introducing asymmetry in the quantization range 279 | skew = ((-1) ** reverse) * shift * temp_delta 280 | neg_thr = max(temp_start + skew, self.max_neg_thr) 281 | pos_thr = min(temp_finish + skew, self.max_pos_thr) 282 | 283 | self.loss_array[:, cand_index, shift, reverse] += self.loss_fx( 284 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 285 | ) 286 | 287 | for channel_index in range(self.channel_groups): 288 | min_cand, min_shift, min_reverse = np.unravel_index( 289 | np.argmin(self.loss_array[channel_index], axis=None), 290 | self.loss_array[channel_index].shape, 291 | ) 292 | min_interval_start = -self.step_size * min_cand 293 | min_interval_finish = self.step_size * min_cand 294 | min_delta = float(min_interval_finish - min_interval_start) / ( 295 | 2**self.quantizer.n_bits - 1 296 | ) 297 | min_skew = ((-1) ** min_reverse) * min_shift * min_delta 298 | xmin = max(min_interval_start + min_skew, self.max_neg_thr) 299 | xmax = min(min_interval_finish + min_skew, self.max_pos_thr) 300 | 301 | self.current_xmin[channel_index] = torch.tensor(xmin).to(device=data.device) 302 | self.current_xmax[channel_index] = torch.tensor(xmax).to(device=data.device) 303 | 304 | def _golden_section_symmetric(self, data): 305 | for channel_index in
range(self.channel_groups): 306 | if channel_index == 0 and not self.per_channel: 307 | data_segment = data 308 | else: 309 | data_segment = data[channel_index] 310 | 311 | self.result = minimize_scalar( 312 | self.golden_sym_loss, 313 | args=data_segment, 314 | bounds=(0.01 * self.max_search_range, self.max_search_range), 315 | method="Bounded", 316 | ) 317 | self.current_xmax[channel_index] = torch.tensor(self.result.x).to(device=data.device) 318 | self.current_xmin[channel_index] = ( 319 | torch.tensor(0.0).to(device=data.device) 320 | if self.one_sided_dist 321 | else -self.current_xmax[channel_index] 322 | ) 323 | 324 | def _golden_section_asymmetric(self, data): 325 | for channel_index in range(self.channel_groups): 326 | if channel_index == 0 and not self.per_channel: 327 | data_segment = data 328 | else: 329 | data_segment = data[channel_index] 330 | 331 | self.result = minimize_scalar( 332 | self.golden_asym_range_loss, 333 | args=data_segment, 334 | bounds=(0.01 * self.max_search_range, self.max_search_range), 335 | method="Bounded", 336 | ) 337 | self.final_range = self.result.x 338 | temp_delta = 2 * self.final_range / (2**self.quantizer.n_bits - 1) 339 | max_shift = temp_delta * self.max_int_skew 340 | self.subresult = minimize_scalar( 341 | self.golden_asym_shift_loss, 342 | args=(self.final_range, data_segment), 343 | bounds=(-max_shift, max_shift), 344 | method="Bounded", 345 | ) 346 | self.final_shift = self.subresult.x 347 | self.current_xmax[channel_index] = torch.tensor(self.final_range + self.final_shift).to( 348 | device=data.device 349 | ) 350 | self.current_xmin[channel_index] = torch.tensor( 351 | -self.final_range + self.final_shift 352 | ).to(device=data.device) 353 | 354 | def forward(self, data): 355 | if self.loss_array is None: 356 | # Initialize search range on first batch, and accumulate losses with subsequent calls 357 | 358 | # Decide whether input distribution is one-sided 359 | if self.one_sided_dist is None: 360 | 
self.one_sided_dist = bool((data.min() >= 0).item()) 361 | 362 | # Define search 363 | self._define_search_range(data) 364 | 365 | # Perform Search/Optimization for Quantization Ranges 366 | self.optimization_method(data) 367 | 368 | return self.current_xmin, self.current_xmax 369 | 370 | def reset(self): 371 | super().reset() 372 | self.loss_array = None 373 | 374 | def extra_repr(self): 375 | repr = "opt_method={}".format(self.opt_method.name) 376 | if self.opt_method == OptMethod.grid: 377 | repr += ", num_candidates={}".format(self.num_candidates) 378 | return repr 379 | 380 | 381 | class NoDataPassedError(Exception): 382 | """Raised when no data has been passed into the Range Estimator""" 383 | 384 | def __init__(self): 385 | super().__init__("Data must be passed through the range estimator to initialize it") 386 | 387 | 388 | class RangeEstimators(ClassEnumOptions): 389 | current_minmax = MethodMap(CurrentMinMaxEstimator) 390 | running_minmax = MethodMap(RunningMinMaxEstimator) 391 | MSE = MethodMap(MSE_Estimator) 392 | -------------------------------------------------------------------------------- /quantization/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | 6 | import torch 7 | import torch.serialization 8 | 9 | from quantization.quantizers import QuantizerBase 10 | from quantization.quantizers.rounding_utils import ParametrizedGradEstimatorBase 11 | from quantization.range_estimators import RangeEstimators 12 | from utils import StopForwardException, get_layer_by_name 13 | 14 | 15 | def separate_quantized_model_params(quant_model): 16 | """ 17 | This method separates the parameters of the quantized model into 3 categories. 18 | Parameters 19 | ---------- 20 | quant_model: (QuantizedModel) 21 | 22 | Returns 23 | ------- 24 | quant_params: (list) 25 | Quantization parameters, e.g.
delta and zero_float 26 | model_params: (list) 27 | The model parameters of the base model without any quantization operations 28 | grad_params: (list) 29 | Parameters found in the gradient estimators (ParametrizedGradEstimatorBase) 30 | ------- 31 | 32 | """ 33 | quant_params, grad_params = [], [] 34 | quant_params_names, grad_params_names = [], [] 35 | for mod_name, module in quant_model.named_modules(): 36 | if isinstance(module, QuantizerBase): 37 | for name, param in module.named_parameters(recurse=False): 38 | quant_params.append(param) 39 | quant_params_names.append(".".join((mod_name, name))) 40 | if isinstance(module, ParametrizedGradEstimatorBase): 41 | # gradient estimator params 42 | for name, param in module.named_parameters(recurse=False): 43 | grad_params.append(param) 44 | grad_params_names.append(".".join((mod_name, name))) 45 | 46 | def tensor_in_list(tensor, lst): 47 | return any([e is tensor for e in lst]) 48 | 49 | found_params = quant_params + grad_params 50 | 51 | model_params = [p for p in quant_model.parameters() if not tensor_in_list(p, found_params)] 52 | model_param_names = [ 53 | n for n, p in quant_model.named_parameters() if not tensor_in_list(p, found_params) 54 | ] 55 | 56 | print("Quantization parameters ({}):".format(len(quant_params_names))) 57 | print(quant_params_names) 58 | 59 | print("Gradient estimator parameters ({}):".format(len(grad_params_names))) 60 | print(grad_params_names) 61 | 62 | print("Other model parameters ({}):".format(len(model_param_names))) 63 | print(model_param_names) 64 | 65 | assert len(model_params + quant_params + grad_params) == len( 66 | list(quant_model.parameters()) 67 | ), "{}; {}; {} -- {}".format( 68 | len(model_params), len(quant_params), len(grad_params), len(list(quant_model.parameters())) 69 | ) 70 | 71 | return quant_params, model_params, grad_params 72 | 73 | 74 | def pass_data_for_range_estimation( 75 | loader, model, act_quant, weight_quant, max_num_batches=20, cross_entropy_layer=None, 
inp_idx=0 76 | ): 77 | print("\nEstimate quantization ranges on training data") 78 | model.set_quant_state(weight_quant, act_quant) 79 | # Put model in eval such that BN EMA does not get updated 80 | model.eval() 81 | 82 | if cross_entropy_layer is not None: 83 | layer_xent = get_layer_by_name(model, cross_entropy_layer) 84 | if layer_xent: 85 | print('Set cross entropy estimator for layer "{}"'.format(cross_entropy_layer)) 86 | act_quant_mgr = layer_xent.activation_quantizer 87 | act_quant_mgr.range_estimator = RangeEstimators.cross_entropy.cls( 88 | per_channel=act_quant_mgr.per_channel, 89 | quantizer=act_quant_mgr.quantizer, 90 | **act_quant_mgr.range_estim_params, 91 | ) 92 | else: 93 | raise ValueError("Cross-entropy layer not found") 94 | 95 | batches = [] 96 | device = next(model.parameters()).device 97 | 98 | with torch.no_grad(): 99 | for i, data in enumerate(loader): 100 | try: 101 | if isinstance(data, (tuple, list)): 102 | x = data[inp_idx].to(device=device) 103 | batches.append(x.data.cpu().numpy()) 104 | model(x) 105 | print(f"processed step={i}") 106 | else: 107 | x = {k: v.to(device=device) for k, v in data.items()} 108 | model(**x) 109 | print(f"processed step={i}") 110 | 111 | if i >= max_num_batches - 1 or not act_quant: 112 | break 113 | except StopForwardException: 114 | pass 115 | return batches 116 | 117 | 118 | def set_range_estimators(config, model): 119 | print("Make quantizers learnable") 120 | model.learn_ranges() 121 | 122 | if config.qat.grad_scaling: 123 | print("Activate gradient scaling") 124 | model.grad_scaling(True) 125 | 126 | # Ensure we have the desired quant state 127 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant) 128 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click>=7.0 2 | pytorch-ignite~=0.4.9 3 | tensorboard>=2.5 4 | scipy==1.3.1 5 | numpy==1.19.5 6 |
pillow==6.2.1 7 | timm~=0.4.12 8 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from .stopwatch import Stopwatch 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /utils/click_options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import click 6 | 7 | from models import QuantArchitectures 8 | from functools import wraps, partial 9 | from quantization.quantizers import QMethods 10 | from quantization.range_estimators import RangeEstimators, OptMethod 11 | from utils import split_dict, DotDict, ClickEnumOption, seed_all 12 | from utils.imagenet_dataloaders import ImageInterpolation 13 | 14 | click.option = partial(click.option, show_default=True) 15 | 16 | _HELP_MSG = ( 17 | "Enforce determinism also on the GPU by disabling CUDNN and setting " 18 | "`torch.set_deterministic(True)`. In many cases this comes at the cost of efficiency " 19 | "and performance." 
20 | ) 21 | 22 | 23 | def base_options(func): 24 | @click.option( 25 | "--images-dir", type=click.Path(exists=True), help="Root directory of images", required=True 26 | ) 27 | @click.option("--max-epochs", default=90, type=int, help="Maximum number of training epochs.") 28 | @click.option( 29 | "--interpolation", 30 | type=ClickEnumOption(ImageInterpolation), 31 | default=ImageInterpolation.bilinear.name, 32 | help="Desired interpolation to use for resizing.", 33 | ) 34 | @click.option( 35 | "--save-checkpoint-dir", 36 | type=click.Path(exists=False), 37 | default=None, 38 | help="Directory where to save checkpoints (model, optimizer, lr_scheduler).", 39 | ) 40 | @click.option( 41 | "--tb-logging-dir", default=None, type=str, help="The logging directory " "for tensorboard" 42 | ) 43 | @click.option("--cuda/--no-cuda", is_flag=True, default=True, help="Use GPU") 44 | @click.option("--batch-size", default=128, type=int, help="Mini-batch size") 45 | @click.option("--num-workers", default=16, type=int, help="Number of workers for data loading") 46 | @click.option("--seed", default=None, type=int, help="Random number generator seed to set") 47 | @click.option("--deterministic/--nondeterministic", default=False, help=_HELP_MSG) 48 | # Architecture related options 49 | @click.option( 50 | "--architecture", 51 | type=ClickEnumOption(QuantArchitectures), 52 | required=True, 53 | help="Quantized architecture", 54 | ) 55 | @click.option( 56 | "--model-dir", 57 | type=click.Path(exists=True), 58 | default=None, 59 | help="Path for model directory. 
If the model does not exist it will be downloaded " 60 | "from a URL", 61 | ) 62 | @click.option( 63 | "--pretrained/--no-pretrained", 64 | is_flag=True, 65 | default=True, 66 | help="Use pretrained model weights", 67 | ) 68 | @click.option( 69 | "--progress-bar/--no-progress-bar", is_flag=True, default=False, help="Show progress bar" 70 | ) 71 | @wraps(func) 72 | def func_wrapper(config, *args, **kwargs): 73 | config.base, remaining_kwargs = split_dict( 74 | kwargs, 75 | [ 76 | "images_dir", 77 | "max_epochs", 78 | "interpolation", 79 | "save_checkpoint_dir", 80 | "tb_logging_dir", 81 | "cuda", 82 | "batch_size", 83 | "num_workers", 84 | "seed", 85 | "model_dir", 86 | "architecture", 87 | "pretrained", 88 | "deterministic", 89 | "progress_bar", 90 | ], 91 | ) 92 | 93 | seed, deterministic = config.base.seed, config.base.deterministic 94 | 95 | if seed is None: 96 | if deterministic is True: 97 | raise ValueError("Enforcing determinism without providing a seed is not supported") 98 | else: 99 | seed_all(seed=seed, deterministic=deterministic) 100 | 101 | return func(config, *args, **remaining_kwargs) 102 | 103 | return func_wrapper 104 | 105 | 106 | class multi_optimizer_options: 107 | """ 108 | An instance of this class is a callable object to serve as a decorator; 109 | hence the lower case class name. 110 | 111 | Among the CLI options defined in the decorator, `--{prefix-}optimizer-type` 112 | requires special attention. Default value for that variable for 113 | {prefix-}optimizer is the value in use by the main optimizer. 114 | 115 | Examples: 116 | @multi_optimizer_options('quant') 117 | @pass_config 118 | def command(config): 119 | ...
120 | """ 121 | 122 | def __init__(self, prefix: str = ""): 123 | self.optimizer_name = prefix + "_optimizer" if prefix else "optimizer" 124 | self.prefix_option = prefix + "-" if prefix else "" 125 | self.prefix_attribute = prefix + "_" if prefix else "" 126 | 127 | def __call__(self, func): 128 | prefix_option = self.prefix_option 129 | prefix_attribute = self.prefix_attribute 130 | 131 | @click.option( 132 | f"--{prefix_option}optimizer", 133 | default="SGD", 134 | type=click.Choice(["SGD", "Adam"], case_sensitive=False), 135 | help=f"Class name of torch Optimizer to be used.", 136 | ) 137 | @click.option( 138 | f"--{prefix_option}learning-rate", 139 | default=None, 140 | type=float, 141 | help="Initial learning rate.", 142 | ) 143 | @click.option( 144 | f"--{prefix_option}momentum", default=0.9, type=float, help=f"Optimizer momentum." 145 | ) 146 | @click.option( 147 | f"--{prefix_option}weight-decay", 148 | default=None, 149 | type=float, 150 | help="Weight decay for the network.", 151 | ) 152 | @click.option( 153 | f"--{prefix_option}learning-rate-schedule", 154 | default=None, 155 | type=str, 156 | help="Learning rate scheduler, 'MultiStepLR:10:20:40' or " 157 | "'cosine:1e-4' for cosine decay", 158 | ) 159 | @wraps(func) 160 | def func_wrapper(config, *args, **kwargs): 161 | base_arg_names = [ 162 | "optimizer", 163 | "learning_rate", 164 | "momentum", 165 | "weight_decay", 166 | "learning_rate_schedule", 167 | ] 168 | 169 | optimizer_opt = DotDict() 170 | 171 | # Collect basic arguments 172 | for arg in base_arg_names: 173 | option_name = prefix_attribute + arg 174 | optimizer_opt[arg] = kwargs.pop(option_name) 175 | 176 | # config.{prefix_attribute}optimizer = optimizer_opt 177 | setattr(config, prefix_attribute + "optimizer", optimizer_opt) 178 | 179 | return func(config, *args, **kwargs) 180 | 181 | return func_wrapper 182 | 183 | 184 | def qat_options(func): 185 | @click.option( 186 | "--reestimate-bn-stats/--no-reestimate-bn-stats", 187 | 
is_flag=True, 188 | default=True, 189 | help="Reestimates the BN stats before every evaluation.", 190 | ) 191 | @click.option( 192 | "--grad-scaling/--no-grad-scaling", 193 | is_flag=True, 194 | default=False, 195 | help="Do gradient scaling as in LSQ paper.", 196 | ) 197 | @click.option( 198 | "--sep-quant-optimizer/--no-sep-quant-optimizer", 199 | is_flag=True, 200 | default=False, 201 | help="Use a separate optimizer for the quantizers.", 202 | ) 203 | @multi_optimizer_options("quant") 204 | @oscillations_dampen_options 205 | @oscillations_freeze_options 206 | @wraps(func) 207 | def func_wrapper(config, *args, **kwargs): 208 | config.qat, remainder_kwargs = split_dict( 209 | kwargs, ["reestimate_bn_stats", "grad_scaling", "sep_quant_optimizer"] 210 | ) 211 | return func(config, *args, **remainder_kwargs) 212 | 213 | return func_wrapper 214 | 215 | 216 | def oscillations_dampen_options(func): 217 | @click.option( 218 | "--oscillations-dampen-weight", 219 | default=None, 220 | type=float, 221 | help="If given, adds an oscillation dampening term to the loss with the given " "weighting.", 222 | ) 223 | @click.option( 224 | "--oscillations-dampen-aggregation", 225 | type=click.Choice(["sum", "mean", "kernel_mean"]), 226 | default="kernel_mean", 227 | help="Aggregation type for bin regularization loss.", 228 | ) 229 | @click.option( 230 | "--oscillations-dampen-weight-final", 231 | type=float, 232 | default=None, 233 | help="Dampening regularization final value for annealing schedule.", 234 | ) 235 | @click.option( 236 | "--oscillations-dampen-anneal-start", 237 | default=0.25, 238 | type=float, 239 | help="Start of annealing (relative to total number of iterations).", 240 | ) 241 | @wraps(func) 242 | def func_wrapper(config, *args, **kwargs): 243 | config.osc_damp, remainder_kwargs = split_dict( 244 | kwargs, 245 | [ 246 | "oscillations_dampen_weight", 247 | "oscillations_dampen_aggregation", 248 | "oscillations_dampen_weight_final", 249 | "oscillations_dampen_anneal_start",
250 | ], 251 | "oscillations_dampen", 252 | ) 253 | 254 | return func(config, *args, **remainder_kwargs) 255 | 256 | return func_wrapper 257 | 258 | 259 | def oscillations_freeze_options(func): 260 | @click.option( 261 | "--oscillations-freeze-threshold", 262 | default=0.0, 263 | type=float, 264 | help="If greater than 0, we will freeze oscillations whose frequency (EMA) is " 265 | "higher than the given threshold. Frequency is defined as 1/period length.", 266 | ) 267 | @click.option( 268 | "--oscillations-freeze-ema-momentum", 269 | default=0.001, 270 | type=float, 271 | help="The momentum to calculate the EMA frequency of the oscillation. In case " 272 | "freezing is used, this should be at least 2-3 times lower than the " 273 | "freeze threshold.", 274 | ) 275 | @click.option( 276 | "--oscillations-freeze-use-ema/--no-oscillation-freeze-use-ema", 277 | is_flag=True, 278 | default=True, 279 | help="Uses an EMA of past x_int to find the correct freezing int value.", 280 | ) 281 | @click.option( 282 | "--oscillations-freeze-max-bits", 283 | default=4, 284 | type=int, 285 | help="Max bit-width for oscillation tracking and freezing.
If a layer's weight is in " 286 | "higher bits, we do not track or freeze oscillations.", 287 | ) 288 | @click.option( 289 | "--oscillations-freeze-threshold-final", 290 | type=float, 291 | default=None, 292 | help="Oscillation freezing final value for annealing schedule.", 293 | ) 294 | @click.option( 295 | "--oscillations-freeze-anneal-start", 296 | default=0.25, 297 | type=float, 298 | help="Start of annealing (relative to total number of iterations).", 299 | ) 300 | @wraps(func) 301 | def func_wrapper(config, *args, **kwargs): 302 | config.osc_freeze, remainder_kwargs = split_dict( 303 | kwargs, 304 | [ 305 | "oscillations_freeze_threshold", 306 | "oscillations_freeze_ema_momentum", 307 | "oscillations_freeze_use_ema", 308 | "oscillations_freeze_max_bits", 309 | "oscillations_freeze_threshold_final", 310 | "oscillations_freeze_anneal_start", 311 | ], 312 | "oscillations_freeze", 313 | ) 314 | 315 | return func(config, *args, **remainder_kwargs) 316 | 317 | return func_wrapper 318 | 319 | 320 | def quantization_options(func): 321 | # Weight quantization options 322 | @click.option( 323 | "--weight-quant/--no-weight-quant", 324 | is_flag=True, 325 | default=True, 326 | help="Run evaluation with weight quantization or use FP32 weights", 327 | ) 328 | @click.option( 329 | "--qmethod", 330 | type=ClickEnumOption(QMethods), 331 | default=QMethods.symmetric_uniform.name, 332 | help="Quantization scheme to use.", 333 | ) 334 | @click.option( 335 | "--weight-quant-method", 336 | default=RangeEstimators.current_minmax.name, 337 | type=ClickEnumOption(RangeEstimators), 338 | help="Method to determine weight quantization clipping thresholds.", 339 | ) 340 | @click.option( 341 | "--weight-opt-method", 342 | default=OptMethod.grid.name, 343 | type=ClickEnumOption(OptMethod), 344 | help="Optimization procedure for weight quantization clipping thresholds", 345 | ) 346 | @click.option( 347 | "--num-candidates", 348 | type=int, 349 | default=None, 350 | help="Number of grid points for
grid search in MSE range method.", 351 | ) 352 | @click.option("--n-bits", default=8, type=int, help="Default number of quantization bits.") 353 | @click.option( 354 | "--per-channel/--no-per-channel", 355 | is_flag=True, 356 | default=False, 357 | help="If given, quantize each channel separately.", 358 | ) 359 | # Activation quantization options 360 | @click.option( 361 | "--act-quant/--no-act-quant", 362 | is_flag=True, 363 | default=True, 364 | help="Run evaluation with activation quantization or use FP32 activations", 365 | ) 366 | @click.option( 367 | "--qmethod-act", 368 | type=ClickEnumOption(QMethods), 369 | default=None, 370 | help="Quantization scheme for activation to use. If not specified `--qmethod` " "is used.", 371 | ) 372 | @click.option( 373 | "--n-bits-act", default=None, type=int, help="Number of quantization bits for activations." 374 | ) 375 | @click.option( 376 | "--act-quant-method", 377 | default=RangeEstimators.running_minmax.name, 378 | type=ClickEnumOption(RangeEstimators), 379 | help="Method to determine activation quantization clipping thresholds", 380 | ) 381 | @click.option( 382 | "--act-opt-method", 383 | default=OptMethod.grid.name, 384 | type=ClickEnumOption(OptMethod), 385 | help="Optimization procedure for activation quantization clipping thresholds", 386 | ) 387 | @click.option( 388 | "--act-num-candidates", 389 | type=int, 390 | default=None, 391 | help="Number of grid points for grid search in MSE/SQNR/Cross-entropy", 392 | ) 393 | @click.option( 394 | "--act-momentum", 395 | type=float, 396 | default=None, 397 | help="Exponential averaging factor for running_minmax", 398 | ) 399 | @click.option( 400 | "--num-est-batches", 401 | type=int, 402 | default=1, 403 | help="Number of training batches to be used for activation range estimation", 404 | ) 405 | # Other options 406 | @click.option( 407 | "--quant-setup", 408 | default="LSQ_paper", 409 | type=click.Choice(["all", "LSQ", "FP_logits", "fc4", "fc4_dw8", "LSQ_paper"]), 410 | 
help="Method to quantize the network.", 411 | ) 412 | @wraps(func) 413 | def func_wrapper(config, *args, **kwargs): 414 | config.quant, remainder_kwargs = split_dict( 415 | kwargs, 416 | [ 417 | "qmethod", 418 | "qmethod_act", 419 | "weight_quant_method", 420 | "weight_opt_method", 421 | "num_candidates", 422 | "n_bits", 423 | "n_bits_act", 424 | "per_channel", 425 | "act_quant", 426 | "weight_quant", 427 | "quant_setup", 428 | "num_est_batches", 429 | "act_momentum", 430 | "act_num_candidates", 431 | "act_opt_method", 432 | "act_quant_method", 433 | ], 434 | ) 435 | 436 | config.quant.qmethod_act = config.quant.qmethod_act or config.quant.qmethod 437 | 438 | return func(config, *args, **remainder_kwargs) 439 | 440 | return func_wrapper 441 | 442 | 443 | def quant_params_dict(config): 444 | weight_range_options = {} 445 | if config.quant.weight_quant_method == RangeEstimators.MSE: 446 | weight_range_options = dict(opt_method=config.quant.weight_opt_method) 447 | if config.quant.num_candidates is not None: 448 | weight_range_options["num_candidates"] = config.quant.num_candidates 449 | 450 | act_range_options = {} 451 | if config.quant.act_quant_method == RangeEstimators.MSE: 452 | act_range_options = dict(opt_method=config.quant.act_opt_method) 453 | if config.quant.act_num_candidates is not None: 454 | # Use the activation-specific candidate count, not the weight one 455 | act_range_options["num_candidates"] = config.quant.act_num_candidates 456 | 457 | qparams = { 458 | "method": config.quant.qmethod.cls, 459 | "n_bits": config.quant.n_bits, 460 | "n_bits_act": config.quant.n_bits_act, 461 | "act_method": config.quant.qmethod_act.cls, 462 | "per_channel_weights": config.quant.per_channel, 463 | "quant_setup": config.quant.quant_setup, 464 | "weight_range_method": config.quant.weight_quant_method.cls, 465 | "weight_range_options": weight_range_options, 466 | "act_range_method": config.quant.act_quant_method.cls, 467 | "act_range_options": act_range_options, 468 | "quantize_input": True if config.quant.quant_setup == "LSQ_paper" else
False, 468 | } 469 | 470 | return qparams 471 | -------------------------------------------------------------------------------- /utils/imagenet_dataloaders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import os 6 | 7 | import torchvision 8 | import torch.utils.data as torch_data 9 | from torchvision import transforms 10 | from utils import BaseEnumOptions 11 | 12 | 13 | class ImageInterpolation(BaseEnumOptions): 14 | nearest = transforms.InterpolationMode.NEAREST 15 | box = transforms.InterpolationMode.BOX 16 | bilinear = transforms.InterpolationMode.BILINEAR 17 | hamming = transforms.InterpolationMode.HAMMING 18 | bicubic = transforms.InterpolationMode.BICUBIC 19 | lanczos = transforms.InterpolationMode.LANCZOS 20 | 21 | 22 | class ImageNetDataLoaders(object): 23 | """ 24 | Data loader provider for ImageNet images, providing a train and a validation loader. 25 | It assumes that the structure of the images is 26 | images_dir 27 | - train 28 | - label1 29 | - label2 30 | - ... 31 | - val 32 | - label1 33 | - label2 34 | - ... 35 | """ 36 | 37 | def __init__( 38 | self, 39 | images_dir: str, 40 | image_size: int, 41 | batch_size: int, 42 | num_workers: int, 43 | interpolation: transforms.InterpolationMode, 44 | ): 45 | """ 46 | Parameters 47 | ---------- 48 | images_dir: str 49 | Root image directory 50 | image_size: int 51 | Number of pixels the image will be re-sized to (square) 52 | batch_size: int 53 | Batch size of both the training and validation loaders 54 | num_workers 55 | Number of parallel workers loading the images 56 | interpolation: transforms.InterpolationMode 57 | Desired interpolation to use for resizing. 
58 | """ 59 | 60 | self.images_dir = images_dir 61 | self.batch_size = batch_size 62 | self.num_workers = num_workers 63 | 64 | # For normalization, mean and std dev values are calculated per channel 65 | # and can be found on the web. 66 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 67 | 68 | self.train_transforms = transforms.Compose( 69 | [ 70 | transforms.RandomResizedCrop(image_size, interpolation=interpolation.value), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | normalize, 74 | ] 75 | ) 76 | 77 | self.val_transforms = transforms.Compose( 78 | [ 79 | transforms.Resize(image_size + 24, interpolation=interpolation.value), 80 | transforms.CenterCrop(image_size), 81 | transforms.ToTensor(), 82 | normalize, 83 | ] 84 | ) 85 | 86 | self._train_loader = None 87 | self._val_loader = None 88 | 89 | @property 90 | def train_loader(self) -> torch_data.DataLoader: 91 | if not self._train_loader: 92 | root = os.path.join(self.images_dir, "train") 93 | train_set = torchvision.datasets.ImageFolder(root, transform=self.train_transforms) 94 | self._train_loader = torch_data.DataLoader( 95 | train_set, 96 | batch_size=self.batch_size, 97 | shuffle=True, 98 | num_workers=self.num_workers, 99 | pin_memory=True, 100 | ) 101 | return self._train_loader 102 | 103 | @property 104 | def val_loader(self) -> torch_data.DataLoader: 105 | if not self._val_loader: 106 | root = os.path.join(self.images_dir, "val") 107 | val_set = torchvision.datasets.ImageFolder(root, transform=self.val_transforms) 108 | self._val_loader = torch_data.DataLoader( 109 | val_set, 110 | batch_size=self.batch_size, 111 | shuffle=False, 112 | num_workers=self.num_workers, 113 | pin_memory=True, 114 | ) 115 | return self._val_loader 116 | -------------------------------------------------------------------------------- /utils/optimizer_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import torch 6 | 7 | 8 | def get_lr_scheduler(optimizer, lr_schedule, epochs): 9 | scheduler = None 10 | if lr_schedule: 11 | if lr_schedule.startswith("multistep"): 12 | epochs = [int(s) for s in lr_schedule.split(":")[1:]] 13 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epochs) 14 | elif lr_schedule.startswith("cosine"): 15 | eta_min = float(lr_schedule.split(":")[1]) 16 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 17 | optimizer, epochs, eta_min=eta_min 18 | ) 19 | return scheduler 20 | 21 | 22 | def optimizer_lr_factory(config_optim, params, epochs): 23 | if config_optim.optimizer.lower() == "sgd": 24 | optimizer = torch.optim.SGD( 25 | params, 26 | lr=config_optim.learning_rate, 27 | momentum=config_optim.momentum, 28 | weight_decay=config_optim.weight_decay, 29 | ) 30 | elif config_optim.optimizer.lower() == "adam": 31 | optimizer = torch.optim.Adam( 32 | params, lr=config_optim.learning_rate, weight_decay=config_optim.weight_decay 33 | ) 34 | else: 35 | raise ValueError(f"Unsupported optimizer '{config_optim.optimizer}' (expected 'sgd' or 'adam').") 36 | 37 | lr_scheduler = get_lr_scheduler(optimizer, config_optim.learning_rate_schedule, epochs) 38 | 39 | return optimizer, lr_scheduler 40 | -------------------------------------------------------------------------------- /utils/oscillation_tracking_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
4 | import torch 5 | 6 | from quantization.hijacker import QuantizationHijacker 7 | 8 | 9 | def add_oscillation_trackers(model, max_bits=4, *args, **kwargs): 10 | tracker_dict = {} 11 | # Add oscillation trackers to all weight quantizers 12 | for name, module in model.named_modules(): 13 | if isinstance(module, QuantizationHijacker): 14 | q = module.weight_quantizer.quantizer 15 | if q.n_bits > max_bits: 16 | print( 17 | f"Skip tracking/freezing for {name}, too high bit {q.n_bits} (max {max_bits})" 18 | ) 19 | continue 20 | int_fwd_wrapper = TrackOscillation(int_fwd=q.to_integer_forward, *args, **kwargs) 21 | q.to_integer_forward = int_fwd_wrapper 22 | tracker_dict[name + ".weight_quantizer"] = int_fwd_wrapper 23 | return tracker_dict 24 | 25 | 26 | class TrackOscillation: 27 | """ 28 | This is a wrapper of the int_forward function of a quantizer. 29 | It tracks the oscillations in integer domain. 30 | """ 31 | 32 | def __init__(self, int_fwd, momentum=0.01, freeze_threshold=0, use_ema_x_int=True): 33 | self.int_fwd = int_fwd 34 | self.momentum = momentum 35 | 36 | self.prev_x_int = None 37 | self.prev_switch_dir = None 38 | 39 | # Statistics to log 40 | self.ema_oscillation = None 41 | self.oscillated_sum = None 42 | self.total_oscillation = None 43 | self.iters_since_reset = 0 44 | 45 | # Extra variables for weight freezing 46 | self.freeze_threshold = freeze_threshold # This should be at least 2-3x the momentum value.
47 | self.use_ema_x_int = use_ema_x_int 48 | self.frozen = None 49 | self.frozen_x_int = None 50 | self.ema_x_int = None 51 | 52 | def __call__(self, x_float, skip_tracking=False, *args, **kwargs): 53 | x_int = self.int_fwd(x_float, *args, **kwargs) 54 | 55 | # Apply weight freezing 56 | if self.frozen is not None: 57 | x_int = ~self.frozen * x_int + self.frozen * self.frozen_x_int 58 | 59 | if skip_tracking: 60 | return x_int 61 | 62 | with torch.no_grad(): 63 | # Check if everything is correctly initialized, otherwise do so 64 | self.check_init(x_int) 65 | 66 | # detect difference in x_int NB we round to avoid int inaccuracies 67 | delta_x_int = torch.round(self.prev_x_int - x_int).detach() # should be {-1, 0, 1} 68 | switch_dir = torch.sign(delta_x_int) # This is {-1, 0, 1} as sign(0) is mapped to 0 69 | # binary mask for switching 70 | switched = delta_x_int != 0 71 | 72 | oscillated = (self.prev_switch_dir * switch_dir) == -1 73 | self.ema_oscillation = ( 74 | self.momentum * oscillated + (1 - self.momentum) * self.ema_oscillation 75 | ) 76 | 77 | # Update prev_switch_dir for the switch variables 78 | self.prev_switch_dir[switched] = switch_dir[switched] 79 | self.prev_x_int = x_int 80 | self.oscillated_sum = oscillated.sum() 81 | self.total_oscillation += oscillated 82 | self.iters_since_reset += 1 83 | 84 | # Freeze some weights 85 | if self.freeze_threshold > 0: 86 | freeze_weights = self.ema_oscillation > self.freeze_threshold 87 | self.frozen[freeze_weights] = True # Set them to frozen 88 | if self.use_ema_x_int: 89 | self.frozen_x_int[freeze_weights] = torch.round(self.ema_x_int[freeze_weights]) 90 | # Update x_int EMA which can be used for freezing 91 | self.ema_x_int = self.momentum * x_int + (1 - self.momentum) * self.ema_x_int 92 | else: 93 | self.frozen_x_int[freeze_weights] = x_int[freeze_weights] 94 | 95 | return x_int 96 | 97 | def check_init(self, x_int): 98 | if self.prev_x_int is None: 99 | # Init prev switch dir to 0 100 | 
self.prev_switch_dir = torch.zeros_like(x_int) 101 | self.prev_x_int = x_int.detach() # Not sure if needed, don't think so 102 | self.ema_oscillation = torch.zeros_like(x_int) 103 | self.oscillated_sum = 0 104 | self.total_oscillation = torch.zeros_like(x_int) 105 | print("Init tracking", x_int.shape) 106 | else: 107 | assert ( 108 | self.prev_x_int.shape == x_int.shape 109 | ), "Tracking shape does not match current tensor shape." 110 | 111 | # For weight freezing 112 | if self.frozen is None and self.freeze_threshold > 0: 113 | self.frozen = torch.zeros_like(x_int, dtype=torch.bool) 114 | self.frozen_x_int = torch.zeros_like(x_int) 115 | if self.use_ema_x_int: 116 | self.ema_x_int = x_int.detach().clone() 117 | print("Init freezing", x_int.shape) 118 | -------------------------------------------------------------------------------- /utils/qat_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | import copy 6 | 7 | import torch 8 | 9 | from quantization.hijacker import QuantizationHijacker 10 | from quantization.quantized_folded_bn import BNFusedHijacker 11 | from utils.imagenet_dataloaders import ImageNetDataLoaders 12 | 13 | 14 | class MethodPropagator: 15 | """Convenience class to allow multiple optimizers or LR schedulers to be used as if they 16 | were one optimizer/scheduler.""" 17 | 18 | def __init__(self, propagatables): 19 | self.propagatables = propagatables 20 | 21 | def __getattr__(self, item): 22 | if callable(getattr(self.propagatables[0], item)): 23 | 24 | def propagate_call(*args, **kwargs): 25 | for prop in self.propagatables: 26 | getattr(prop, item)(*args, **kwargs) 27 | 28 | return propagate_call 29 | else: 30 | return getattr(self.propagatables[0], item) 31 | 32 | def __str__(self): 33 | result = "" 34 | for prop in self.propagatables: 35 | result += str(prop) + "\n" 36 | return result 37 | 38 | def __iter__(self): 39 | for i in self.propagatables: 40 | yield i 41 | 42 | def __contains__(self, item): 43 | return item in self.propagatables 44 | 45 | 46 | def get_dataloaders_and_model(config, load_type="fp32", **qparams): 47 | dataloaders = ImageNetDataLoaders( 48 | config.base.images_dir, 49 | 224, 50 | config.base.batch_size, 51 | config.base.num_workers, 52 | config.base.interpolation, 53 | ) 54 | 55 | model = config.base.architecture( 56 | pretrained=config.base.pretrained, 57 | load_type=load_type, 58 | model_dir=config.base.model_dir, 59 | **qparams, 60 | ) 61 | if config.base.cuda: 62 | model = model.cuda() 63 | 64 | return dataloaders, model 65 | 66 | 67 | class CompositeLoss: 68 | def __init__(self, loss_dict): 69 | """ 70 | Composite loss of N separate loss functions. All functions are summed up. 71 | 72 | Note, each loss function receives (prediction, target) as arguments, even if it does not 73 | need them. Other data-independent instances need to be provided directly to the loss function 74 | (e.g.
the model/weights in case of a regularization term). 75 | 76 | """ 77 | self.loss_dict = loss_dict 78 | 79 | def __call__(self, prediction, target, *args, **kwargs): 80 | total_loss = 0 81 | for loss_func in self.loss_dict.values(): 82 | total_loss += loss_func(prediction, target, *args, **kwargs) 83 | return total_loss 84 | 85 | 86 | class UpdateFreezingThreshold: 87 | def __init__(self, tracker_dict, decay_schedule): 88 | self.tracker_dict = tracker_dict 89 | self.decay_schedule = decay_schedule 90 | 91 | def __call__(self, engine): 92 | if engine.state.iteration < self.decay_schedule.decay_start: 93 | # Put it always to 0 for real warm-start 94 | new_threshold = 0 95 | else: 96 | new_threshold = self.decay_schedule(engine.state.iteration) 97 | 98 | # Update trackers with new threshold 99 | for name, tracker in self.tracker_dict.items(): 100 | tracker.freeze_threshold = new_threshold 101 | # print('Set new freezing threshold', new_threshold) 102 | 103 | 104 | class UpdateDampeningLossWeighting: 105 | def __init__(self, bin_reg_loss, decay_schedule): 106 | self.dampen_loss = bin_reg_loss 107 | self.decay_schedule = decay_schedule 108 | 109 | def __call__(self, engine): 110 | new_weighting = self.decay_schedule(engine.state.iteration) 111 | self.dampen_loss.weighting = new_weighting 112 | # print('Set new bin reg weighting', new_weighting) 113 | 114 | 115 | class DampeningLoss: 116 | def __init__(self, model, weighting=1.0, aggregation="sum"): 117 | """ 118 | Calculates the dampening loss for all weights in a given quantized model. It is 119 | expected that all quantized weights are in a Hijacker module.
120 | 121 | """ 122 | self.model = model 123 | self.weighting = weighting 124 | self.aggregation = aggregation 125 | 126 | def __call__(self, *args, **kwargs): 127 | total_bin_loss = 0 128 | for name, module in self.model.named_modules(): 129 | if isinstance(module, QuantizationHijacker): 130 | # FP32 weight tensor, potentially folded but before quantization 131 | weight, _ = module.get_weight_bias() 132 | # The matching weight quantizer (not manager, direct quantizer class) 133 | quantizer = module.weight_quantizer.quantizer 134 | total_bin_loss += dampening_loss(weight, quantizer, self.aggregation) 135 | return total_bin_loss * self.weighting 136 | 137 | 138 | def dampening_loss(w_fp, quantizer, aggregation="sum"): 139 | # L = (s * w_int - w)^2 140 | # We also need to add clipping for both cases, we can do so by using the forward 141 | w_q = quantizer(w_fp, skip_tracking=True).detach() # this is also clipped and our target 142 | # clamp w in FP32 domain to not change range learning (min(max) is needed for per-channel) 143 | w_fp_clip = torch.min(torch.max(w_fp, quantizer.x_min), quantizer.x_max) 144 | loss = (w_q - w_fp_clip) ** 2 145 | if aggregation == "sum": 146 | return loss.sum() 147 | elif aggregation == "mean": 148 | return loss.mean() 149 | elif aggregation == "kernel_mean": 150 | return loss.sum(0).mean() 151 | else: 152 | raise ValueError(f"Aggregation method '{aggregation}' not implemented.") 153 | 154 | 155 | class ReestimateBNStats: 156 | def __init__(self, model, data_loader, num_batches=50): 157 | super().__init__() 158 | self.model = model 159 | self.data_loader = data_loader 160 | self.num_batches = num_batches 161 | 162 | def __call__(self, engine): 163 | print("-- Reestimate current BN statistics --") 164 | reestimate_BN_stats(self.model, self.data_loader, self.num_batches) 165 | 166 | 167 | def reestimate_BN_stats(model, data_loader, num_batches=50, store_ema_stats=False): 168 | # We set BN momentum to 1 and use train mode 169 | # -> the running
mean/var have the current batch statistics 170 | model.eval() 171 | org_momentum = {} 172 | for name, module in model.named_modules(): 173 | if isinstance(module, BNFusedHijacker): 174 | org_momentum[name] = module.momentum 175 | module.momentum = 1.0 176 | module.running_mean_sum = torch.zeros_like(module.running_mean) 177 | module.running_var_sum = torch.zeros_like(module.running_var) 178 | # Set all BNFusedHijacker modules to train mode, but not their children 179 | module.training = True 180 | 181 | if store_ema_stats: 182 | # Save the original EMA, make sure they are in buffers so they end in the state dict 183 | if not hasattr(module, "running_mean_ema"): 184 | module.register_buffer("running_mean_ema", copy.deepcopy(module.running_mean)) 185 | module.register_buffer("running_var_ema", copy.deepcopy(module.running_var)) 186 | else: 187 | module.running_mean_ema = copy.deepcopy(module.running_mean) 188 | module.running_var_ema = copy.deepcopy(module.running_var) 189 | 190 | # Run data for estimation 191 | device = next(model.parameters()).device 192 | batch_count = 0 193 | with torch.no_grad(): 194 | for x, y in data_loader: 195 | model(x.to(device)) 196 | # We save the running mean/var to a buffer 197 | for name, module in model.named_modules(): 198 | if isinstance(module, BNFusedHijacker): 199 | module.running_mean_sum += module.running_mean 200 | module.running_var_sum += module.running_var 201 | 202 | batch_count += 1 203 | if batch_count == num_batches: 204 | break 205 | # At the end we normalize the buffer and write it into the running mean/var 206 | for name, module in model.named_modules(): 207 | if isinstance(module, BNFusedHijacker): 208 | module.running_mean = module.running_mean_sum / batch_count 209 | module.running_var = module.running_var_sum / batch_count 210 | # We reset the momentum in case it would be used anywhere else 211 | module.momentum = org_momentum[name] 212 | model.eval() 213 |
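For illustration, the quantity computed by `dampening_loss` in utils/qat_utils.py can be reproduced on plain Python floats. This is a minimal sketch under stated assumptions, not the repository's implementation: `toy_dampening_loss`, its scalar `scale`, and the integer bounds `int_min`/`int_max` are hypothetical stand-ins for the quantizer object and its `x_min`/`x_max` range, and the sketch assumes a uniform grid without a zero-point offset.

```python
def toy_dampening_loss(weights, scale, int_min, int_max, aggregation="sum"):
    """Squared distance between each clipped weight and its quantized value.

    Hypothetical pure-Python stand-in for dampening_loss(w_fp, quantizer):
    the quantizer is replaced by the grid {scale * k : int_min <= k <= int_max}.
    """
    x_min, x_max = scale * int_min, scale * int_max
    losses = []
    for w in weights:
        # Quantized target: round to the grid, then clip in the integer domain.
        w_int = min(max(round(w / scale), int_min), int_max)
        w_q = scale * w_int
        # Clip the latent weight to the grid range, so clipped weights add no loss.
        w_clip = min(max(w, x_min), x_max)
        losses.append((w_q - w_clip) ** 2)
    if aggregation == "sum":
        return sum(losses)
    if aggregation == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"Aggregation method '{aggregation}' not implemented.")


# Grid is {-1.0, -0.5, 0.0, 0.5}; only 0.26 sits off-grid inside the range.
example = toy_dampening_loss([0.0, 0.26, 0.5, 2.0], scale=0.5, int_min=-2, int_max=1)
```

Weights that already sit on a grid point, or that lie outside the representable range, contribute nothing in this sketch, so the term only acts on in-range weights that are away from the grid point they round to.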
-------------------------------------------------------------------------------- /utils/stopwatch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import sys 6 | import time 7 | 8 | 9 | class Stopwatch: 10 | """ 11 | A simple cross-platform context-manager stopwatch. 12 | 13 | Examples 14 | -------- 15 | >>> import time 16 | >>> with Stopwatch(verbose=True) as st: 17 | ... time.sleep(0.101) #doctest: +ELLIPSIS 18 | Elapsed time: 0.10... sec 19 | """ 20 | 21 | def __init__(self, name=None, verbose=False): 22 | self._name = name 23 | self._verbose = verbose 24 | 25 | self._start_time_point = 0.0 26 | self._total_duration = 0.0 27 | self._is_running = False 28 | 29 | if sys.platform == "win32": 30 | # on Windows, use time.perf_counter(); time.clock() was removed in Python 3.8 31 | self._timer_fn = time.perf_counter 32 | else: 33 | # on most other platforms, the best timer is time.time() 34 | self._timer_fn = time.time 35 | 36 | def __enter__(self, verbose=False): 37 | return self.start() 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | self.stop() 41 | if self._verbose: 42 | self.print() 43 | 44 | def start(self): 45 | if not self._is_running: 46 | self._start_time_point = self._timer_fn() 47 | self._is_running = True 48 | return self 49 | 50 | def stop(self): 51 | if self._is_running: 52 | self._total_duration += self._timer_fn() - self._start_time_point 53 | self._is_running = False 54 | return self 55 | 56 | def reset(self): 57 | self._start_time_point = 0.0 58 | self._total_duration = 0.0 59 | self._is_running = False 60 | return self 61 | 62 | def _update_state(self): 63 | now = self._timer_fn() 64 | self._total_duration += now - self._start_time_point 65 | self._start_time_point = now 66 | 67 | def _format(self): 68 | prefix = f"[{self._name}]" if self._name is not None else "Elapsed time" 69 | info = f"{prefix}: {self._total_duration:.3f}
sec" 70 | return info 71 | 72 | def format(self): 73 | if self._is_running: 74 | self._update_state() 75 | return self._format() 76 | 77 | def print(self): 78 | print(self.format()) 79 | 80 | def get_total_duration(self): 81 | if self._is_running: 82 | self._update_state() 83 | return self._total_duration 84 | -------------------------------------------------------------------------------- /utils/supervised_driver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from ignite.contrib.handlers import TensorboardLogger 6 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 7 | from ignite.handlers import Checkpoint, global_step_from_engine 8 | from torch.optim import Optimizer 9 | 10 | 11 | def create_trainer_engine( 12 | model, 13 | optimizer, 14 | criterion, 15 | metrics, 16 | data_loaders, 17 | lr_scheduler=None, 18 | save_checkpoint_dir=None, 19 | device="cuda", 20 | ): 21 | # Create trainer 22 | trainer = create_supervised_trainer( 23 | model=model, 24 | optimizer=optimizer, 25 | loss_fn=criterion, 26 | device=device, 27 | output_transform=custom_output_transform, 28 | ) 29 | 30 | for name, metric in metrics.items(): 31 | metric.attach(trainer, name) 32 | 33 | # Add lr_scheduler 34 | if lr_scheduler: 35 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step()) 36 | 37 | # Create evaluator 38 | evaluator = create_supervised_evaluator(model=model, metrics=metrics, device=device) 39 | 40 | # Save model checkpoint 41 | if save_checkpoint_dir: 42 | to_save = {"model": model, "optimizer": optimizer} 43 | if lr_scheduler: 44 | to_save["lr_scheduler"] = lr_scheduler 45 | checkpoint = Checkpoint( 46 | to_save, 47 | save_checkpoint_dir, 48 | n_saved=1, 49 | global_step_transform=global_step_from_engine(trainer), 50 | ) 51 | 
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint) 52 | 53 | # Add hooks for logging metrics 54 | trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results, optimizer) 55 | 56 | trainer.add_event_handler( 57 | Events.EPOCH_COMPLETED, run_evaluation_for_training, evaluator, data_loaders.val_loader 58 | ) 59 | 60 | return trainer, evaluator 61 | 62 | 63 | def custom_output_transform(x, y, y_pred, loss): 64 | return y_pred, y 65 | 66 | 67 | def log_training_results(trainer, optimizer): 68 | learning_rate = optimizer.param_groups[0]["lr"] 69 | log_metrics(trainer.state.metrics, "Training", trainer.state.epoch, learning_rate) 70 | 71 | 72 | def run_evaluation_for_training(trainer, evaluator, val_loader): 73 | evaluator.run(val_loader) 74 | log_metrics(evaluator.state.metrics, "Evaluation", trainer.state.epoch) 75 | 76 | 77 | def log_metrics(metrics, stage: str = "", training_epoch=None, learning_rate=None): 78 | log_text = " {}".format(metrics) if metrics else "" 79 | if training_epoch is not None: 80 | log_text = "Epoch: {}".format(training_epoch) + log_text 81 | if learning_rate and learning_rate > 0.0: 82 | log_text += " Learning rate: {:.2E}".format(learning_rate) 83 | log_text = "Results - " + log_text 84 | if stage: 85 | log_text = "{} ".format(stage) + log_text 86 | print(log_text, flush=True) 87 | 88 | 89 | def setup_tensorboard_logger(trainer, evaluator, output_path, optimizers=None): 90 | logger = TensorboardLogger(logdir=output_path) 91 | 92 | # Attach the logger to log loss and accuracy for both training and validation 93 | for tag, cur_evaluator in [("train", trainer), ("validation", evaluator)]: 94 | logger.attach_output_handler( 95 | cur_evaluator, 96 | event_name=Events.EPOCH_COMPLETED, 97 | tag=tag, 98 | metric_names="all", 99 | global_step_transform=global_step_from_engine(trainer), 100 | ) 101 | 102 | # Log optimizer parameters 103 | if isinstance(optimizers, Optimizer): 104 | optimizers = {None: optimizers} 105 | 106 | for k, 
optimizer in optimizers.items(): 107 | logger.attach_opt_params_handler( 108 | trainer, Events.EPOCH_COMPLETED, optimizer, param_name="lr", tag=k 109 | ) 110 | 111 | return logger 112 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import collections 6 | import os 7 | import random 8 | from collections import namedtuple 9 | from enum import Flag, auto 10 | from functools import partial 11 | 12 | import click 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class DotDict(dict): 19 | """ 20 | A dictionary that allows attribute-style access. 21 | Examples 22 | -------- 23 | >>> config = DotDict(a=None) 24 | >>> config.a = 42 25 | >>> config.b = 'egg' 26 | >>> config # can be used as dict 27 | {'a': 42, 'b': 'egg'} 28 | """ 29 | 30 | def __setattr__(self, key, value): 31 | self.__setitem__(key, value) 32 | 33 | def __delattr__(self, key): 34 | self.__delitem__(key) 35 | 36 | def __getattr__(self, key): 37 | if key in self: 38 | return self.__getitem__(key) 39 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})") 40 | 41 | 42 | def relu(x): 43 | x = np.array(x) 44 | return x * (x > 0) 45 | 46 | 47 | def get_all_layer_names(model, subtypes=None): 48 | if subtypes is None: 49 | return [name for name, module in model.named_modules()][1:] 50 | return [name for name, module in model.named_modules() if isinstance(module, subtypes)] 51 | 52 | 53 | def get_layer_name_to_module_dict(model): 54 | return {name: module for name, module in model.named_modules() if name} 55 | 56 | 57 | def get_module_to_layer_name_dict(model): 58 | modules_to_names = collections.OrderedDict() 59 | for name, module in model.named_modules(): 60 | modules_to_names[module] = name 61 | return 
modules_to_names 62 | 63 | 64 | def get_layer_name(model, layer): 65 | for name, module in model.named_modules(): 66 | if module == layer: 67 | return name 68 | return None 69 | 70 | 71 | def get_layer_by_name(model, layer_name): 72 | for name, module in model.named_modules(): 73 | if name == layer_name: 74 | return module 75 | return None 76 | 77 | 78 | def create_conv_layer_list(cls, model: nn.Module) -> list: 79 | """ 80 | Function finds all prunable layers in the provided model 81 | 82 | Parameters 83 | ---------- 84 | cls: SVD class 85 | model : torch.nn.Module 86 | A pytorch model. 87 | 88 | Returns 89 | ------- 90 | conv_layer_list : list 91 | List of all prunable layers in the given model. 92 | 93 | """ 94 | conv_layer_list = [] 95 | 96 | def fill_list(mod): 97 | if isinstance(mod, tuple(cls.supported_layer_types)): 98 | conv_layer_list.append(mod) 99 | 100 | model.apply(fill_list) 101 | return conv_layer_list 102 | 103 | 104 | def create_linear_layer_list(cls, model: nn.Module) -> list: 105 | """ 106 | Function finds all prunable layers in the provided model 107 | 108 | Parameters 109 | ---------- 110 | model : torch.nn.Module 111 | A pytorch model. 112 | 113 | Returns 114 | ------- 115 | conv_layer_list : list 116 | List of all prunable layers in the given model. 
117 | 118 | """ 119 | conv_layer_list = [] 120 | 121 | def fill_list(mod): 122 | if isinstance(mod, tuple(cls.supported_layer_types)): 123 | conv_layer_list.append(mod) 124 | 125 | model.apply(fill_list) 126 | return conv_layer_list 127 | 128 | 129 | def to_numpy(tensor): 130 | """ 131 | Helper function that turns the given tensor into a numpy array 132 | 133 | Parameters 134 | ---------- 135 | tensor : torch.Tensor 136 | 137 | Returns 138 | ------- 139 | tensor : float or np.array 140 | 141 | """ 142 | if isinstance(tensor, np.ndarray): 143 | return tensor 144 | if hasattr(tensor, "is_cuda"): 145 | if tensor.is_cuda: 146 | return tensor.cpu().detach().numpy() 147 | if hasattr(tensor, "detach"): 148 | return tensor.detach().numpy() 149 | if hasattr(tensor, "numpy"): 150 | return tensor.numpy() 151 | 152 | return np.array(tensor) 153 | 154 | 155 | def set_module_attr(model, layer_name, value): 156 | split = layer_name.split(".") 157 | 158 | this_module = model 159 | for mod_name in split[:-1]: 160 | if mod_name.isdigit(): 161 | this_module = this_module[int(mod_name)] 162 | else: 163 | this_module = getattr(this_module, mod_name) 164 | 165 | last_mod_name = split[-1] 166 | if last_mod_name.isdigit(): 167 | this_module[int(last_mod_name)] = value 168 | else: 169 | setattr(this_module, last_mod_name, value) 170 | 171 | 172 | def search_for_zero_planes(model: torch.nn.Module): 173 | """If list of modules to winnow is empty to start with, search through all modules to check 174 | if any 175 | planes have been zeroed out. Update self._list_of_modules_to_winnow with any findings. 
176 | :param model: torch model to search through modules for zeroed parameters 177 | """ 178 | 179 | list_of_modules_to_winnow = [] 180 | for _, module in model.named_modules(): 181 | if isinstance(module, (torch.nn.Linear, torch.nn.modules.conv.Conv2d)): 182 | in_channels_to_winnow = _assess_weight_and_bias(module.weight, module.bias) 183 | if in_channels_to_winnow: 184 | list_of_modules_to_winnow.append((module, in_channels_to_winnow)) 185 | return list_of_modules_to_winnow 186 | 187 | 188 | def _assess_weight_and_bias(weight: torch.nn.Parameter, _bias: torch.nn.Parameter): 189 | """4-dim weights [CH-out, CH-in, H, W] and 1-dim bias [CH-out]""" 190 | if len(weight.shape) > 2: 191 | input_channels_to_ignore = (weight.sum((0, 2, 3)) == 0).nonzero().squeeze().tolist() 192 | else: 193 | input_channels_to_ignore = (weight.sum(0) == 0).nonzero().squeeze().tolist() 194 | 195 | if type(input_channels_to_ignore) != list: 196 | input_channels_to_ignore = [input_channels_to_ignore] 197 | 198 | return input_channels_to_ignore 199 | 200 | 201 | def seed_all(seed: int = 1029, deterministic: bool = False): 202 | """ 203 | This is our attempt to make experiments reproducible by seeding all known RNGs and setting 204 | appropriate torch directives. 205 | For a general discussion of reproducibility in Pytorch and CUDA and a documentation of the 206 | options we are using see, e.g., 207 | https://pytorch.org/docs/1.7.1/notes/randomness.html 208 | https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility 209 | 210 | As of today (July 2021), even after seeding and setting some directives, 211 | there remain unfortunate contradictions: 212 | 1. CUDNN 213 | - having CUDNN enabled leads to 214 | - non-determinism in Pytorch when using the GPU, cf. MORPH-10999. 215 | - having CUDNN disabled leads to 216 | - most regression tests in Qrunchy failing, cf. 
MORPH-11103 217 | - significantly increased execution time in some cases 218 | - performance degradation in some cases 219 | 2. torch.set_deterministic(d) 220 | - setting d = True leads to errors for Pytorch algorithms that do not (yet) have a deterministic 221 | counterpart, e.g., the layer `adaptive_avg_pool2d_backward_cuda` in vgg16__torchvision. 222 | 223 | Thus, we leave the choice of enforcing determinism by disabling CUDNN and non-deterministic 224 | algorithms to the user. To keep it simple, we only have one switch for both. 225 | This situation could be re-evaluated upon updates of Pytorch, CUDA, CUDNN. 226 | """ 227 | 228 | assert isinstance(seed, int), f"RNG seed must be an integer ({seed})" 229 | assert seed >= 0, f"RNG seed must be a non-negative integer ({seed})" 230 | 231 | # Builtin RNGs 232 | random.seed(seed) 233 | os.environ["PYTHONHASHSEED"] = str(seed) 234 | 235 | # Numpy RNG 236 | np.random.seed(seed) 237 | 238 | # CUDNN determinism (setting those has not led to errors so far) 239 | torch.backends.cudnn.benchmark = False 240 | torch.backends.cudnn.deterministic = True 241 | 242 | # Torch RNGs 243 | torch.manual_seed(seed) 244 | torch.cuda.manual_seed(seed) 245 | torch.cuda.manual_seed_all(seed) 246 | 247 | # Problematic settings, see docstring. Precaution: We do not mutate unless asked to do so 248 | if deterministic is True: 249 | torch.backends.cudnn.enabled = False 250 | 251 | torch.set_deterministic(True) # Use torch.use_deterministic_algorithms(True) in torch 1.8.1 252 | # When using torch.set_deterministic(True), it is advised by Pytorch to set the 253 | # CUBLAS_WORKSPACE_CONFIG variable as follows, see 254 | # https://pytorch.org/docs/1.7.1/notes/randomness.html#avoiding-nondeterministic-algorithms 255 | # and the link to the CUDA homepage on that website.
256 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 257 | 258 | 259 | def assert_allclose(actual, desired, *args, **kwargs): 260 | """A more beautiful version of torch.allclose.""" 261 | np.testing.assert_allclose(to_numpy(actual), to_numpy(desired), *args, **kwargs) 262 | 263 | 264 | def count_params(module): 265 | return len(nn.utils.parameters_to_vector(module.parameters())) 266 | 267 | 268 | class StopForwardException(Exception): 269 | """Used to throw and catch an exception to stop traversing the graph.""" 270 | 271 | pass 272 | 273 | 274 | class StopForwardHook: 275 | def __call__(self, module, *args): 276 | raise StopForwardException 277 | 278 | 279 | def sigmoid(x): 280 | return 1.0 / (1.0 + np.exp(-x)) 281 | 282 | 283 | class CosineTempDecay: 284 | def __init__(self, t_max, temp_range=(20.0, 2.0), rel_decay_start=0): 285 | self.t_max = t_max 286 | self.start_temp, self.end_temp = temp_range 287 | self.decay_start = rel_decay_start * t_max 288 | 289 | def __call__(self, t): 290 | if t < self.decay_start: 291 | return self.start_temp 292 | 293 | rel_t = (t - self.decay_start) / (self.t_max - self.decay_start) 294 | return self.end_temp + 0.5 * (self.start_temp - self.end_temp) * (1 + np.cos(rel_t * np.pi)) 295 | 296 | 297 | class BaseEnumOptions(Flag): 298 | def __str__(self): 299 | return self.name 300 | 301 | @classmethod 302 | def list_names(cls): 303 | return [m.name for m in cls] 304 | 305 | 306 | class ClassEnumOptions(BaseEnumOptions): 307 | @property 308 | def cls(self): 309 | return self.value.cls 310 | 311 | def __call__(self, *args, **kwargs): 312 | return self.value.cls(*args, **kwargs) 313 | 314 | 315 | MethodMap = partial(namedtuple("MethodMap", ["value", "cls"]), auto()) 316 | 317 | 318 | def split_dict(src: dict, include=(), remove_prefix: str = ""): 319 | """ 320 | Splits dictionary into a DotDict and a remainder. 321 | The arguments to be placed in the first DotDict are those listed in `include`.
322 | Parameters 323 | ---------- 324 | src: dict 325 | The source dictionary. 326 | include: 327 | List of keys to be returned in the first DotDict. 328 | remove_prefix: 329 | Prefix (followed by an underscore) to strip from the included keys. 330 | """ 331 | result = DotDict() 332 | 333 | for arg in include: 334 | if remove_prefix: 335 | key = arg.replace(f"{remove_prefix}_", "", 1) 336 | else: 337 | key = arg 338 | result[key] = src[arg] 339 | remainder = {key: val for key, val in src.items() if key not in include} 340 | return result, remainder 341 | 342 | 343 | class ClickEnumOption(click.Choice): 344 | """ 345 | Adjusted click.Choice type for BaseOption which is based on Enum 346 | """ 347 | 348 | def __init__(self, enum_options, case_sensitive=True): 349 | assert issubclass(enum_options, BaseEnumOptions) 350 | self.base_option = enum_options 351 | super().__init__(self.base_option.list_names(), case_sensitive) 352 | 353 | def convert(self, value, param, ctx): 354 | # Exact match 355 | if value in self.choices: 356 | return self.base_option[value] 357 | 358 | # Match through normalization and case sensitivity 359 | # first do token_normalize_func, then lowercase 360 | # preserve original `value` to produce an accurate message in 361 | # `self.fail` 362 | normed_value = value 363 | normed_choices = self.choices 364 | 365 | if ctx is not None and ctx.token_normalize_func is not None: 366 | normed_value = ctx.token_normalize_func(value) 367 | normed_choices = [ctx.token_normalize_func(choice) for choice in self.choices] 368 | 369 | if not self.case_sensitive: 370 | normed_value = normed_value.lower() 371 | normed_choices = [choice.lower() for choice in normed_choices] 372 | 373 | if normed_value in normed_choices: 374 | return self.base_option[normed_value] 375 | 376 | self.fail( 377 | "invalid choice: %s. (choose from %s)" % (value, ", ".join(self.choices)), param, ctx 378 | ) 379 | --------------------------------------------------------------------------------
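To make the schedule produced by `CosineTempDecay` in utils/utils.py concrete, here is a small standard-library sketch of the same formula. `cosine_decay` is a hypothetical helper written only for this illustration (the repository uses the class with `np.cos`), and the values 100, (20.0, 2.0) and 0.2 are arbitrary choices, not settings taken from the paper.

```python
import math


def cosine_decay(t, t_max, value_range=(20.0, 2.0), rel_decay_start=0.0):
    """Hold the start value until decay_start, then follow a half cosine to the end value."""
    start_value, end_value = value_range
    decay_start = rel_decay_start * t_max
    if t < decay_start:
        return start_value
    # Relative position inside the decay window, from 0 (start) to 1 (t_max).
    rel_t = (t - decay_start) / (t_max - decay_start)
    return end_value + 0.5 * (start_value - end_value) * (1 + math.cos(rel_t * math.pi))


# Sample the schedule during warm-up, at the decay start, at the midpoint, and at the end.
schedule = [cosine_decay(t, t_max=100, rel_decay_start=0.2) for t in (10, 20, 60, 100)]
```

With these inputs the value stays at 20.0 through iteration 20, passes roughly 11.0 at the midpoint of the decay window, and reaches 2.0 at `t_max`. The `UpdateFreezingThreshold` and `UpdateDampeningLossWeighting` handlers in utils/qat_utils.py call their `decay_schedule` with `engine.state.iteration`, so a schedule object of this kind is presumably what they expect.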