├── .gitignore ├── LICENSE ├── README.md ├── compute_quant_error.py ├── image_net.py ├── models ├── __init__.py ├── mobilenet_v2.py ├── mobilenet_v2_quantized.py └── resnet_quantized.py ├── quantization ├── __init__.py ├── autoquant_utils.py ├── base_quantized_classes.py ├── base_quantized_model.py ├── hijacker.py ├── quant_error_estimator.py ├── quantization_manager.py ├── quantized_folded_bn.py ├── quantizers │ ├── __init__.py │ ├── base_quantizers.py │ ├── fp8_quantizer.py │ ├── rounding_utils.py │ ├── uniform_quantizers.py │ └── utils.py ├── range_estimators.py └── utils.py ├── requirements.txt └── utils ├── __init__.py ├── click_options.py ├── distributions.py ├── grid.py ├── imagenet_dataloaders.py ├── optimizer_utils.py ├── qat_utils.py ├── stopwatch.py ├── supervised_driver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | __pycache__/ 4 | experiments/ 5 | Makefile 6 | resources/ 7 | docker/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer: 8 | 9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | * Neither the name of Qualcomm Technologies, Inc. 
nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FP8 Quantization: The Power of the Exponent 2 | This repository contains the implementation and experiments for the paper presented in 3 | 4 | **Andrey Kuzmin\*1, Mart van Baalen\*1, Yuwei Ren1, 5 | Markus Nagel1, Jorn Peters1, Tijmen Blankevoort1 "FP8 Quantization: The Power of the Exponent", NeurIPS 6 | 2022.** [[ArXiv]](https://arxiv.org/abs/2208.09225) 7 | 8 | *Equal contribution 9 | 1 Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.) 10 | 11 | You can use this code to recreate the results in the paper. 12 | 13 | ## Method and Results 14 | 15 | In this repository we share the code to reproduce analytical and experimental results on the performance of the FP8 format with different mantissa/exponent bit-width divisions versus INT8.
The first part of the repository allows the user to reproduce 16 | analytical computations of the SQNR (signal-to-quantization-noise ratio) for uniform, Gaussian, and Student's t distributions. Varying the mantissa/exponent bit-width division changes the trade-off between accurately representing the data around the mean of the distribution 17 | and capturing its tails: the more outliers the data contains, the more exponent bits it pays to allocate. In the second part we provide the code to reproduce the post-training quantization (PTQ) 18 | results for MobileNetV2 and ResNet-18 pre-trained on ImageNet. 19 | 20 | 21 | 22 | ## How to install 23 | Make sure you have Python ≥3.8 (tested with Python 3.8.10) and 24 | upgrade `pip` to the latest version (tested with 21.3.1): 25 | ```bash 26 | python3 -m venv env 27 | source env/bin/activate 28 | pip install --upgrade --no-deps pip 29 | ``` 30 | 31 | Next, install PyTorch 1.11.0 with the appropriate CUDA version (tested with CUDA 10.0): 32 | ```bash 33 | pip install torch==1.11.0 torchvision==0.12.0 34 | ``` 35 | 36 | Finally, install the remaining dependencies using pip: 37 | ```bash 38 | pip install -r requirements.txt 39 | ``` 40 | ## Running experiments 41 | ### Analytical expected SQNR computations 42 | The main run file to compute the expected SQNR for different distributions using different formats is 43 | `compute_quant_error.py`. The script takes no input arguments. For a uniform, a clipped Gaussian, and a clipped Student's t distribution, it prints the expected quantization MSE and SQNR of individual values and of dot products, for FP8 formats with 5, 4, 3, and 2 exponent bits and for INT8: 44 | ```bash 45 | python compute_quant_error.py 46 | ``` 47 | ### ImageNet experiments 48 | The main run file to reproduce the ImageNet experiments is `image_net.py`. 49 | It contains commands for validating models quantized with post-training quantization. 50 | You can see the full list of options for each command using `python image_net.py [COMMAND] --help`. 51 | ```bash 52 | Usage: image_net.py [OPTIONS] COMMAND [ARGS]... 53 | 54 | Options: 55 | --help Show this message and exit.
56 | 57 | Commands: 58 | validate-quantized 59 | ``` 60 | 61 | To reproduce the experiments, run: 62 | ```bash 63 | python image_net.py validate-quantized --images-dir <path-to-imagenet> \ 64 | --architecture <architecture-name> --batch-size 64 --seed 10 \ 65 | --model-dir <path-to-mobilenet-v2-weights> \ 66 | --n-bits 8 --cuda --load-type fp32 --quant-setup all --qmethod fp_quantizer --per-channel \ 67 | --fp8-mantissa-bits=5 --fp8-set-maxval --no-fp8-mse-include-mantissa-bits \ 68 | --weight-quant-method=current_minmax --act-quant-method=allminmax --num-est-batches=1 69 | ``` 70 | 71 | where `<architecture-name>` is either `mobilenet_v2_quantized` or `resnet18_quantized`; the `--model-dir` option is only needed for MobileNet-V2. 72 | Only MobileNet-V2 requires separately downloaded pre-trained weights (ResNet-18 uses the torchvision ImageNet weights). The downloaded tar file is passed to `--model-dir` as is, without untarring it: 73 | - [MobileNetV2](https://drive.google.com/open?id=1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR) 74 | 75 | ## Reference 76 | If you find our work useful, please cite: 77 | ``` 78 | @article{kuzmin2022fp8, 79 | title={FP8 Quantization: The Power of the Exponent}, 80 | author={Kuzmin, Andrey and Van Baalen, Mart and Ren, Yuwei and Nagel, Markus and Peters, Jorn and Blankevoort, Tijmen}, 81 | journal={arXiv preprint arXiv:2208.09225}, 82 | year={2022} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /compute_quant_error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | # All Rights Reserved.
3 | 4 | import numpy as np 5 | import torch 6 | 7 | from utils.distributions import ClippedGaussDistr, UniformDistr, ClippedStudentTDistr 8 | from quantization.quant_error_estimator import ( 9 | compute_expected_quant_mse, 10 | compute_expected_dot_prod_mse, 11 | ) 12 | from quantization.quantizers.fp8_quantizer import FPQuantizer 13 | from quantization.range_estimators import estimate_range_line_search 14 | from quantization.quantizers.uniform_quantizers import SymmetricUniformQuantizer 15 | from utils import seed_all 16 | 17 | 18 | def compute_quant_error(distr, n_bits=8, n_samples=5000000, seed=10): 19 | seed_all(seed) 20 | distr_sample = torch.tensor(distr.sample((n_samples,))) 21 | for exp_bits in [5, 4, 3, 2, 0]: 22 | mantissa_bits = n_bits - 1 - exp_bits 23 | if exp_bits > 0: 24 | quant = FPQuantizer(n_bits=n_bits, mantissa_bits=mantissa_bits, set_maxval=True) 25 | elif exp_bits == 0: 26 | quant = SymmetricUniformQuantizer(n_bits=n_bits) 27 | 28 | (quant_range_min, quant_range_max) = estimate_range_line_search(distr_sample, quant) 29 | quant_expected_mse = compute_expected_quant_mse( 30 | distr, quant, quant_range_min, quant_range_max, n_samples 31 | ) 32 | quant_sqnr = -10.0 * np.log10(quant_expected_mse) 33 | 34 | dot_prod_expected_mse = compute_expected_dot_prod_mse( 35 | distr, 36 | distr, 37 | quant, 38 | quant, 39 | quant_range_min, 40 | quant_range_max, 41 | quant_range_min, 42 | quant_range_max, 43 | ) 44 | 45 | dot_prod_sqnr = -10.0 * np.log10(dot_prod_expected_mse) 46 | 47 | print( 48 | "FP8 {} E {} M Quantization: expected MSE {:.2e}".format( 49 | exp_bits, mantissa_bits, quant_expected_mse 50 | ), 51 | " SQNR ", 52 | "{:.2e}\n".format(quant_sqnr), 53 | "Dot product:".rjust(23), 54 | " expected MSE {:.2e}".format(dot_prod_expected_mse), 55 | " SQNR ", 56 | "{:.2e}".format(dot_prod_sqnr), 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | distr_list = [ 62 | UniformDistr(range_min=-1.0, range_max=1.0, params_dict={}), 63 |
ClippedGaussDistr(params_dict={"mu": 0.0, "sigma": 1.0}, range_min=-10.0, range_max=10.0), 64 | ClippedStudentTDistr(params_dict={"nu": 8.0}, range_min=-100.0, range_max=100.0), 65 | ] 66 | 67 | for distr in distr_list: 68 | print("*" * 80) 69 | distr.print() 70 | compute_quant_error(distr) 71 | -------------------------------------------------------------------------------- /image_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import logging 4 | import os 5 | 6 | import click 7 | from ignite.contrib.handlers import ProgressBar 8 | from ignite.engine import create_supervised_evaluator 9 | from ignite.metrics import Accuracy, TopKCategoricalAccuracy, Loss 10 | from torch.nn import CrossEntropyLoss 11 | 12 | from quantization.utils import pass_data_for_range_estimation 13 | from utils import DotDict 14 | from utils.click_options import ( 15 | qat_options, 16 | quantization_options, 17 | fp8_options, 18 | quant_params_dict, 19 | base_options, 20 | ) 21 | from utils.qat_utils import get_dataloaders_and_model, ReestimateBNStats 22 | 23 | 24 | class Config(DotDict): 25 | pass 26 | 27 | 28 | @click.group() 29 | def fp8_cmd_group(): 30 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) 31 | 32 | 33 | pass_config = click.make_pass_decorator(Config, ensure=True) 34 | 35 | 36 | @fp8_cmd_group.command() 37 | @pass_config 38 | @base_options 39 | @fp8_options 40 | @quantization_options 41 | @qat_options 42 | @click.option( 43 | "--load-type", 44 | type=click.Choice(["fp32", "quantized"]), 45 | default="quantized", 46 | help='Either "fp32", or "quantized". 
Specify whether to load a quantized or an FP ' "model.", 47 | ) 48 | def validate_quantized(config, load_type): 49 | """ 50 | Run validation on pre-trained quantized models. 51 | """ 52 | print("Setting up network and data loaders") 53 | qparams = quant_params_dict(config) 54 | 55 | dataloaders, model = get_dataloaders_and_model(config=config, load_type=load_type, **qparams) 56 | 57 | if load_type == "fp32": 58 | # Estimate ranges using training data 59 | pass_data_for_range_estimation( 60 | loader=dataloaders.train_loader, 61 | model=model, 62 | act_quant=config.quant.act_quant, 63 | weight_quant=config.quant.weight_quant, 64 | max_num_batches=config.quant.num_est_batches, 65 | ) 66 | # Ensure we have the desired quant state 67 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant) 68 | 69 | # Fix ranges 70 | model.fix_ranges() 71 | 72 | # Create evaluator 73 | loss_func = CrossEntropyLoss() 74 | metrics = { 75 | "top_1_accuracy": Accuracy(), 76 | "top_5_accuracy": TopKCategoricalAccuracy(), 77 | "loss": Loss(loss_func), 78 | } 79 | 80 | pbar = ProgressBar() 81 | evaluator = create_supervised_evaluator( 82 | model=model, metrics=metrics, device="cuda" if config.base.cuda else "cpu" 83 | ) 84 | pbar.attach(evaluator) 85 | print("Model with the ranges estimated:\n{}".format(model)) 86 | 87 | # BN Re-estimation 88 | if config.qat.reestimate_bn_stats: 89 | ReestimateBNStats( 90 | model, dataloaders.train_loader, num_batches=int(0.02 * len(dataloaders.train_loader)) 91 | )(None) 92 | 93 | print("Start quantized validation") 94 | evaluator.run(dataloaders.val_loader) 95 | final_metrics = evaluator.state.metrics 96 | print(final_metrics) 97 | 98 | 99 | if __name__ == "__main__": 100 | fp8_cmd_group() 101 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022
Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from models.mobilenet_v2_quantized import mobilenetv2_quantized 6 | from models.resnet_quantized import resnet18_quantized, resnet50_quantized 7 | from utils import ClassEnumOptions, MethodMap 8 | 9 | 10 | class QuantArchitectures(ClassEnumOptions): 11 | mobilenet_v2_quantized = MethodMap(mobilenetv2_quantized) 12 | resnet18_quantized = MethodMap(resnet18_quantized) 13 | resnet50_quantized = MethodMap(resnet50_quantized) 14 | -------------------------------------------------------------------------------- /models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | # https://github.com/tonylins/pytorch-mobilenet-v2 6 | 7 | import math 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | __all__ = ["MobileNetV2"] 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def conv_1x1_bn(inp, oup): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True) 24 | ) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = round(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 40 | nn.BatchNorm2d(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 
| # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 54 | nn.BatchNorm2d(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class MobileNetV2(nn.Module): 69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.0, dropout=0.0): 70 | super().__init__() 71 | block = InvertedResidual 72 | input_channel = 32 73 | last_channel = 1280 74 | inverted_residual_setting = [ 75 | # t, c, n, s 76 | [1, 16, 1, 1], 77 | [6, 24, 2, 2], 78 | [6, 32, 3, 2], 79 | [6, 64, 4, 2], 80 | [6, 96, 3, 1], 81 | [6, 160, 3, 2], 82 | [6, 320, 1, 1], 83 | ] 84 | 85 | # building first layer 86 | assert input_size % 32 == 0 87 | input_channel = int(input_channel * width_mult) 88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 89 | features = [conv_bn(3, input_channel, 2)] 90 | # building inverted residual blocks 91 | for t, c, n, s in inverted_residual_setting: 92 | output_channel = int(c * width_mult) 93 | for i in range(n): 94 | if i == 0: 95 | features.append(block(input_channel, output_channel, s, expand_ratio=t)) 96 | else: 97 | features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 98 | input_channel = output_channel 99 | # building last several layers 100 | features.append(conv_1x1_bn(input_channel, self.last_channel)) 101 | features.append(nn.AvgPool2d(input_size // 32)) 102 | # make it nn.Sequential 103 | self.features = nn.Sequential(*features) 104 | 105 | # building classifier 106 | self.classifier = nn.Sequential( 107 | nn.Dropout(dropout), 108 | nn.Linear(self.last_channel, n_class), 109 | ) 110 | 111 | 
self._initialize_weights() 112 | 113 | def forward(self, x): 114 | x = self.features(x) 115 | x = F.adaptive_avg_pool2d(x, 1).flatten(1) # type: ignore[arg-type] # flatten(1) keeps the batch dim even for batch size 1 116 | x = self.classifier(x) 117 | return x 118 | 119 | def _initialize_weights(self): 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.Linear): 130 | n = m.weight.size(1) 131 | m.weight.data.normal_(0, 0.01) 132 | m.bias.data.zero_() 133 | -------------------------------------------------------------------------------- /models/mobilenet_v2_quantized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
4 | 5 | import os 6 | import re 7 | import torch 8 | from collections import OrderedDict 9 | from models.mobilenet_v2 import MobileNetV2, InvertedResidual 10 | from quantization.autoquant_utils import quantize_sequential, Flattener, quantize_model, BNQConv 11 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 12 | from quantization.base_quantized_model import QuantizedModel 13 | 14 | 15 | class QuantizedInvertedResidual(QuantizedActivation): 16 | def __init__(self, inv_res_orig, **quant_params): 17 | super().__init__(**quant_params) 18 | self.use_res_connect = inv_res_orig.use_res_connect 19 | self.conv = quantize_sequential(inv_res_orig.conv, **quant_params) 20 | 21 | def forward(self, x): 22 | if self.use_res_connect: 23 | x = x + self.conv(x) 24 | return self.quantize_activations(x) 25 | else: 26 | return self.conv(x) 27 | 28 | 29 | class QuantizedMobileNetV2(QuantizedModel): 30 | def __init__(self, model_fp, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params): 31 | super().__init__(input_size) 32 | specials = {InvertedResidual: QuantizedInvertedResidual} 33 | # quantize and copy parts from original model 34 | quantize_input = quant_setup and quant_setup == "LSQ_paper" 35 | self.features = quantize_sequential( 36 | model_fp.features, 37 | tie_activation_quantizers=not quantize_input, 38 | specials=specials, 39 | **quant_params, 40 | ) 41 | 42 | self.flattener = Flattener() 43 | self.classifier = quantize_model(model_fp.classifier, **quant_params) 44 | 45 | if quant_setup == "FP_logits": 46 | print("Do not quantize output of FC layer") 47 | self.classifier[1].activation_quantizer = FP32Acts() 48 | # self.classifier.activation_quantizer = FP32Acts() # no activation quantization of logits 49 | elif quant_setup == "fc4": 50 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 51 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4 52 | elif quant_setup == "fc4_dw8": 53 | print("\n\n### fc4_dw8 setup ###\n\n") 54 | 
# FC layer in 4 bits, depth-wise separable ones in 8 bits 55 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 56 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4 57 | for name, module in self.named_modules(): 58 | if isinstance(module, BNQConv) and module.groups == module.in_channels: 59 | module.weight_quantizer.quantizer.n_bits = 8 60 | print(f"Set layer {name} to 8 bits") 61 | elif quant_setup == "LSQ": 62 | print("Set quantization to LSQ (first+last layer in 8 bits)") 63 | # Weights of the first layer 64 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 65 | # The quantizer of the last conv layer (input to avgpool with tied quantizers) 66 | self.features[-2][0].activation_quantizer.quantizer.n_bits = 8 67 | # Weights of the last layer 68 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8 69 | # no activation quantization of logits 70 | self.classifier[1].activation_quantizer = FP32Acts() 71 | elif quant_setup == "LSQ_paper": 72 | # Weights of the first layer 73 | self.features[0][0].activation_quantizer = FP32Acts() 74 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8 75 | # Weights of the last layer 76 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8 77 | self.classifier[1].activation_quantizer.quantizer.n_bits = 8 78 | # Set all QuantizedActivations to FP32 79 | for layer in self.features.modules(): 80 | if isinstance(layer, QuantizedActivation): 81 | layer.activation_quantizer = FP32Acts() 82 | elif quant_setup is not None and quant_setup != "all": 83 | raise ValueError( 84 | "Quantization setup '{}' not supported for MobilenetV2".format(quant_setup) 85 | ) 86 | 87 | def forward(self, x): 88 | x = self.features(x) 89 | x = self.flattener(x) 90 | x = self.classifier(x) 91 | 92 | return x 93 | 94 | 95 | def mobilenetv2_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 96 | fp_model = MobileNetV2() 97 | if pretrained and load_type == "fp32": 98 | # Load model from pretrained
FP32 weights 99 | assert os.path.exists(model_dir) 100 | print(f"Loading pretrained weights from {model_dir}") 101 | state_dict = torch.load(model_dir) 102 | fp_model.load_state_dict(state_dict) 103 | quant_model = QuantizedMobileNetV2(fp_model, **qparams) 104 | elif load_type == "quantized": 105 | # Load pretrained QuantizedModel 106 | print(f"Loading pretrained quantized model from {model_dir}") 107 | state_dict = torch.load(model_dir) 108 | quant_model = QuantizedMobileNetV2(fp_model, **qparams) 109 | quant_model.load_state_dict(state_dict, strict=False) 110 | else: 111 | raise ValueError("wrong load_type specified") 112 | 113 | return quant_model 114 | -------------------------------------------------------------------------------- /models/resnet_quantized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | import torch 5 | from torch import nn 6 | from torchvision.models.resnet import BasicBlock, Bottleneck 7 | from torchvision.models import resnet18, resnet50 8 | 9 | from quantization.autoquant_utils import quantize_model, Flattener, QuantizedActivationWrapper 10 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 11 | from quantization.base_quantized_model import QuantizedModel 12 | 13 | 14 | class QuantizedBlock(QuantizedActivation): 15 | def __init__(self, block, **quant_params): 16 | super().__init__(**quant_params) 17 | 18 | if isinstance(block, Bottleneck): 19 | features = nn.Sequential( 20 | block.conv1, 21 | block.bn1, 22 | block.relu, 23 | block.conv2, 24 | block.bn2, 25 | block.relu, 26 | block.conv3, 27 | block.bn3, 28 | ) 29 | elif isinstance(block, BasicBlock): 30 | features = nn.Sequential(block.conv1, block.bn1, block.relu, block.conv2, block.bn2) 31 | 32 | self.features = quantize_model(features, **quant_params) 33 | self.downsample = ( 34 | quantize_model(block.downsample, 
**quant_params) if block.downsample else None 35 | ) 36 | 37 | self.relu = block.relu 38 | 39 | def forward(self, x): 40 | residual = x if self.downsample is None else self.downsample(x) 41 | out = self.features(x) 42 | 43 | out += residual 44 | out = self.relu(out) 45 | 46 | return self.quantize_activations(out) 47 | 48 | 49 | class QuantizedResNet(QuantizedModel): 50 | def __init__(self, resnet, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params): 51 | super().__init__(input_size) 52 | specials = {BasicBlock: QuantizedBlock, Bottleneck: QuantizedBlock} 53 | 54 | if hasattr(resnet, "maxpool"): 55 | # ImageNet ResNet case 56 | features = nn.Sequential( 57 | resnet.conv1, 58 | resnet.bn1, 59 | resnet.relu, 60 | resnet.maxpool, 61 | resnet.layer1, 62 | resnet.layer2, 63 | resnet.layer3, 64 | resnet.layer4, 65 | ) 66 | else: 67 | # Tiny ImageNet ResNet case 68 | features = nn.Sequential( 69 | resnet.conv1, 70 | resnet.bn1, 71 | resnet.relu, 72 | resnet.layer1, 73 | resnet.layer2, 74 | resnet.layer3, 75 | resnet.layer4, 76 | ) 77 | 78 | self.features = quantize_model(features, specials=specials, **quant_params) 79 | 80 | if quant_setup and quant_setup == "LSQ_paper": 81 | # Keep avgpool intact as we quantize the input of the last layer 82 | self.avgpool = resnet.avgpool 83 | else: 84 | self.avgpool = QuantizedActivationWrapper( 85 | resnet.avgpool, 86 | tie_activation_quantizers=True, 87 | input_quantizer=self.features[-1][-1].activation_quantizer, 88 | **quant_params, 89 | ) 90 | self.flattener = Flattener() 91 | self.fc = quantize_model(resnet.fc, **quant_params) 92 | 93 | # Adapt to specific quantization setup 94 | if quant_setup == "LSQ": 95 | print("Set quantization to LSQ (first+last layer in 8 bits)") 96 | # Weights of the first layer 97 | self.features[0].weight_quantizer.quantizer.n_bits = 8 98 | # The quantizer of the residual (input to last layer) 99 | self.features[-1][-1].activation_quantizer.quantizer.n_bits = 8 100 | # Output of the last conv
(input to last layer) 101 | self.features[-1][-1].features[-1].activation_quantizer.quantizer.n_bits = 8 102 | # Weights of the last layer 103 | self.fc.weight_quantizer.quantizer.n_bits = 8 104 | # no activation quantization of logits 105 | self.fc.activation_quantizer = FP32Acts() 106 | elif quant_setup == "LSQ_paper": 107 | # Weights of the first layer 108 | self.features[0].activation_quantizer = FP32Acts() 109 | self.features[0].weight_quantizer.quantizer.n_bits = 8 110 | # Weights of the last layer 111 | self.fc.activation_quantizer.quantizer.n_bits = 8 112 | self.fc.weight_quantizer.quantizer.n_bits = 8 113 | # Set all QuantizedActivations to FP32 114 | for layer in self.features.modules(): 115 | if isinstance(layer, QuantizedActivation): 116 | layer.activation_quantizer = FP32Acts() 117 | elif quant_setup == "FP_logits": 118 | print("Do not quantize output of FC layer") 119 | self.fc.activation_quantizer = FP32Acts() # no activation quantization of logits 120 | elif quant_setup == "fc4": 121 | self.features[0].weight_quantizer.quantizer.n_bits = 8 122 | self.fc.weight_quantizer.quantizer.n_bits = 4 123 | elif quant_setup is not None and quant_setup != "all": 124 | raise ValueError("Quantization setup '{}' not supported for Resnet".format(quant_setup)) 125 | 126 | def forward(self, x): 127 | x = self.features(x) 128 | x = self.avgpool(x) 129 | # x = x.view(x.size(0), -1) 130 | x = self.flattener(x) 131 | x = self.fc(x) 132 | 133 | return x 134 | 135 | 136 | def resnet18_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 137 | if load_type == "fp32": 138 | # Load model from pretrained FP32 weights 139 | fp_model = resnet18(pretrained=pretrained) 140 | quant_model = QuantizedResNet(fp_model, **qparams) 141 | elif load_type == "quantized": 142 | # Load pretrained QuantizedModel 143 | print(f"Loading pretrained quantized model from {model_dir}") 144 | state_dict = torch.load(model_dir) 145 | fp_model = resnet18() 146 | quant_model = 
QuantizedResNet(fp_model, **qparams) 147 | quant_model.load_state_dict(state_dict) 148 | else: 149 | raise ValueError("wrong load_type specified") 150 | return quant_model 151 | 152 | 153 | def resnet50_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams): 154 | if load_type == "fp32": 155 | # Load model from pretrained FP32 weights 156 | fp_model = resnet50(pretrained=pretrained) 157 | quant_model = QuantizedResNet(fp_model, **qparams) 158 | elif load_type == "quantized": 159 | # Load pretrained QuantizedModel 160 | print(f"Loading pretrained quantized model from {model_dir}") 161 | state_dict = torch.load(model_dir) 162 | fp_model = resnet50() 163 | quant_model = QuantizedResNet(fp_model, **qparams) 164 | quant_model.load_state_dict(state_dict) 165 | else: 166 | raise ValueError("wrong load_type specified") 167 | return quant_model 168 | -------------------------------------------------------------------------------- /quantization/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from . import utils, autoquant_utils, quantized_folded_bn 6 | -------------------------------------------------------------------------------- /quantization/autoquant_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | import copy 6 | import warnings 7 | 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.nn.modules.conv import _ConvNd 11 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd 12 | 13 | 14 | from quantization.base_quantized_classes import QuantizedActivation, QuantizedModule 15 | from quantization.hijacker import QuantizationHijacker, activations_set 16 | from quantization.quantization_manager import QuantizationManager 17 | from quantization.quantized_folded_bn import BNFusedHijacker 18 | 19 | 20 | class QuantConv1d(QuantizationHijacker, nn.Conv1d): 21 | def run_forward(self, x, weight, bias, offsets=None): 22 | return F.conv1d( 23 | x.contiguous(), 24 | weight.contiguous(), 25 | bias=bias, 26 | stride=self.stride, 27 | padding=self.padding, 28 | dilation=self.dilation, 29 | groups=self.groups, 30 | ) 31 | 32 | 33 | class QuantConv(QuantizationHijacker, nn.Conv2d): 34 | def run_forward(self, x, weight, bias, offsets=None): 35 | return F.conv2d( 36 | x.contiguous(), 37 | weight.contiguous(), 38 | bias=bias, 39 | stride=self.stride, 40 | padding=self.padding, 41 | dilation=self.dilation, 42 | groups=self.groups, 43 | ) 44 | 45 | 46 | class QuantConvTransposeBase(QuantizationHijacker): 47 | def quantize_weights(self, weights): 48 | if self.per_channel_weights: 49 | # NOTE: N-d transposed-conv weights are stored as (in_channels, out_channels, *) 50 | # instead of (out_channels, in_channels, *) as for regular convs, 51 | # and per-channel quantization should be applied to the output channels. 52 | # Transposing before passing the weights to the quantizer is a trick to avoid 53 | # changing the logic in the range estimators and quantizers 54 | weights = weights.transpose(1, 0).contiguous() 55 | weights = self.weight_quantizer(weights) 56 | if self.per_channel_weights: 57 | weights = weights.transpose(1, 0).contiguous() 58 | return weights 59 | 60 | 61 | class QuantConvTranspose1d(QuantConvTransposeBase, nn.ConvTranspose1d): 62 | def run_forward(self, x, weight,
bias, offsets=None): 63 | return F.conv_transpose1d( 64 | x.contiguous(), 65 | weight.contiguous(), 66 | bias=bias, 67 | stride=self.stride, 68 | padding=self.padding, 69 | output_padding=self.output_padding, 70 | dilation=self.dilation, 71 | groups=self.groups, 72 | ) 73 | 74 | 75 | class QuantConvTranspose(QuantConvTransposeBase, nn.ConvTranspose2d): 76 | def run_forward(self, x, weight, bias, offsets=None): 77 | return F.conv_transpose2d( 78 | x.contiguous(), 79 | weight.contiguous(), 80 | bias=bias, 81 | stride=self.stride, 82 | padding=self.padding, 83 | output_padding=self.output_padding, 84 | dilation=self.dilation, 85 | groups=self.groups, 86 | ) 87 | 88 | 89 | class QuantLinear(QuantizationHijacker, nn.Linear): 90 | def run_forward(self, x, weight, bias, offsets=None): 91 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias) 92 | 93 | 94 | class BNQConv1d(BNFusedHijacker, nn.Conv1d): 95 | def run_forward(self, x, weight, bias, offsets=None): 96 | return F.conv1d( 97 | x.contiguous(), 98 | weight.contiguous(), 99 | bias=bias, 100 | stride=self.stride, 101 | padding=self.padding, 102 | dilation=self.dilation, 103 | groups=self.groups, 104 | ) 105 | 106 | 107 | class BNQConv(BNFusedHijacker, nn.Conv2d): 108 | def run_forward(self, x, weight, bias, offsets=None): 109 | return F.conv2d( 110 | x.contiguous(), 111 | weight.contiguous(), 112 | bias=bias, 113 | stride=self.stride, 114 | padding=self.padding, 115 | dilation=self.dilation, 116 | groups=self.groups, 117 | ) 118 | 119 | 120 | class BNQLinear(BNFusedHijacker, nn.Linear): 121 | def run_forward(self, x, weight, bias, offsets=None): 122 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias) 123 | 124 | 125 | class QuantizedActivationWrapper(QuantizedActivation): 126 | """ 127 | Wraps a layer and quantizes its output activation.
128 | It also allows tying the input and output quantizers, which is helpful 129 | for layers such as average pooling. 130 | """ 131 | 132 | def __init__( 133 | self, 134 | layer, 135 | tie_activation_quantizers=False, 136 | input_quantizer: QuantizationManager = None, 137 | *args, 138 | **kwargs, 139 | ): 140 | super().__init__(*args, **kwargs) 141 | self.tie_activation_quantizers = tie_activation_quantizers 142 | if input_quantizer: 143 | assert isinstance(input_quantizer, QuantizationManager) 144 | self.activation_quantizer = input_quantizer 145 | self.layer = layer 146 | 147 | def quantize_activations_no_range_update(self, x): 148 | if self._quant_a: 149 | return self.activation_quantizer.quantizer(x) 150 | else: 151 | return x 152 | 153 | def forward(self, x): 154 | x = self.layer(x) 155 | if self.tie_activation_quantizers: 156 | # The input activation quantizer is used to quantize the activation 157 | # but without updating the quantization range 158 | return self.quantize_activations_no_range_update(x) 159 | else: 160 | return self.quantize_activations(x) 161 | 162 | def extra_repr(self): 163 | return f"tie_activation_quantizers={self.tie_activation_quantizers}" 164 | 165 | 166 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm): 167 | def run_forward(self, x, weight, bias, offsets=None): 168 | return F.layer_norm( 169 | input=x.contiguous(), 170 | normalized_shape=self.normalized_shape, 171 | weight=weight.contiguous(), 172 | bias=bias.contiguous(), 173 | eps=self.eps, 174 | ) 175 | 176 | 177 | class Flattener(nn.Module): 178 | def forward(self, x): 179 | return x.view(x.shape[0], -1) 180 | 181 | 182 | # Non BN Quant Modules Map 183 | non_bn_module_map = { 184 | nn.Conv1d: QuantConv1d, 185 | nn.Conv2d: QuantConv, 186 | nn.ConvTranspose1d: QuantConvTranspose1d, 187 | nn.ConvTranspose2d: QuantConvTranspose, 188 | nn.Linear: QuantLinear, 189 | nn.LayerNorm: QuantLayerNorm, 190 | } 191 | 192 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd) 193 | # BN
Quant Modules Map 194 | bn_module_map = {nn.Conv1d: BNQConv1d, nn.Conv2d: BNQConv, nn.Linear: BNQLinear} 195 | 196 | quant_conv_modules = (QuantConv1d, QuantConv, BNQConv1d, BNQConv) 197 | 198 | 199 | def next_bn(module, i): 200 | return len(module) > i + 1 and isinstance(module[i + 1], (nn.BatchNorm2d, nn.BatchNorm1d)) 201 | 202 | 203 | def get_act(module, i): 204 | # Case 1: conv + act 205 | if len(module) - i > 1 and isinstance(module[i + 1], tuple(activations_set)): 206 | return module[i + 1], i + 1 207 | 208 | # Case 2: conv + bn + act 209 | if ( 210 | len(module) - i > 2 211 | and next_bn(module, i) 212 | and isinstance(module[i + 2], tuple(activations_set)) 213 | ): 214 | return module[i + 2], i + 2 215 | 216 | # Case 3: conv + bn + X -> no fused activation, return (None, None) 217 | # Case 4: conv + X -> no fused activation, return (None, None) 218 | return None, None 219 | 220 | 221 | def get_conv_args(module): 222 | args = dict( 223 | in_channels=module.in_channels, 224 | out_channels=module.out_channels, 225 | kernel_size=module.kernel_size, 226 | stride=module.stride, 227 | padding=module.padding, 228 | dilation=module.dilation, 229 | groups=module.groups, 230 | bias=module.bias is not None, 231 | ) 232 | if isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d)): 233 | args["output_padding"] = module.output_padding 234 | return args 235 | 236 | 237 | def get_linear_args(module): 238 | args = dict( 239 | in_features=module.in_features, 240 | out_features=module.out_features, 241 | bias=module.bias is not None, 242 | ) 243 | return args 244 | 245 | 246 | def get_layernorm_args(module): 247 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps) 248 | return args 249 | 250 | 251 | def get_module_args(mod, act): 252 | if isinstance(mod, _ConvNd): 253 | kwargs = get_conv_args(mod) 254 | elif isinstance(mod, nn.Linear): 255 | kwargs = get_linear_args(mod) 256 | elif isinstance(mod, nn.LayerNorm): 257 | kwargs = get_layernorm_args(mod) 258 | else: 259 | raise ValueError(f"Unsupported module type: {type(mod)}") 260 | 261 |
kwargs["activation"] = act 262 | 263 | return kwargs 264 | 265 | 266 | def fold_bn(module, i, **quant_params): 267 | bn = next_bn(module, i) 268 | act, act_idx = get_act(module, i) 269 | modmap = bn_module_map if bn else non_bn_module_map 270 | modtype = modmap[type(module[i])] 271 | 272 | kwargs = get_module_args(module[i], act) 273 | new_module = modtype(**kwargs, **quant_params) 274 | new_module.weight.data = module[i].weight.data.clone() 275 | 276 | if bn: 277 | new_module.gamma.data = module[i + 1].weight.data.clone() 278 | new_module.beta.data = module[i + 1].bias.data.clone() 279 | new_module.running_mean.data = module[i + 1].running_mean.data.clone() 280 | new_module.running_var.data = module[i + 1].running_var.data.clone() 281 | if module[i].bias is not None: 282 | new_module.running_mean.data -= module[i].bias.data 283 | print("Warning: bias in conv/linear before batch normalization.") 284 | new_module.epsilon = module[i + 1].eps 285 | 286 | elif module[i].bias is not None: 287 | new_module.bias.data = module[i].bias.data.clone() 288 | 289 | return new_module, i + int(bool(act)) + int(bn) + 1 290 | 291 | 292 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params): 293 | specials = specials or dict() 294 | 295 | i = 0 296 | quant_modules = [] 297 | while i < len(model): 298 | if isinstance(model[i], QuantizedModule): 299 | quant_modules.append(model[i]) 300 | elif type(model[i]) in non_bn_module_map: 301 | new_module, new_i = fold_bn(model, i, **quant_params) 302 | quant_modules.append(new_module) 303 | i = new_i 304 | continue 305 | 306 | elif type(model[i]) in specials: 307 | quant_modules.append(specials[type(model[i])](model[i], **quant_params)) 308 | 309 | elif isinstance(model[i], non_param_modules): 310 | # Check for last quantizer 311 | input_quantizer = None 312 | if quant_modules and isinstance(quant_modules[-1], QuantizedModule): 313 | last_layer = quant_modules[-1] 314 | input_quantizer = 
quant_modules[-1].activation_quantizer 315 | elif ( 316 | quant_modules 317 | and isinstance(quant_modules[-1], nn.Sequential) 318 | and isinstance(quant_modules[-1][-1], QuantizedModule) 319 | ): 320 | last_layer = quant_modules[-1][-1] 321 | input_quantizer = quant_modules[-1][-1].activation_quantizer 322 | 323 | if input_quantizer and tie_activation_quantizers: 324 | # If an input quantizer is found, tie the input/output activation quantizers 325 | print( 326 | f"Tying the input quantizer from layer {i-1} of type {type(last_layer)} to the " 327 | f"quantized {type(model[i])} that follows it" 328 | ) 329 | quant_modules.append( 330 | QuantizedActivationWrapper( 331 | model[i], 332 | tie_activation_quantizers=tie_activation_quantizers, 333 | input_quantizer=input_quantizer, 334 | **quant_params, 335 | ) 336 | ) 337 | else: 338 | # Input quantizer not found, or tying is disabled 339 | quant_modules.append(QuantizedActivationWrapper(model[i], **quant_params)) 340 | if tie_activation_quantizers: 341 | warnings.warn("Input quantizer not found, so we do not tie quantizers") 342 | else: 343 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params)) 344 | i += 1 345 | return nn.Sequential(*quant_modules) 346 | 347 | 348 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params): 349 | specials = specials or dict() 350 | 351 | if isinstance(model, nn.Sequential): 352 | quant_model = quantize_sequential( 353 | model, specials, tie_activation_quantizers, **quant_params 354 | ) 355 | 356 | elif type(model) in specials: 357 | quant_model = specials[type(model)](model, **quant_params) 358 | 359 | elif isinstance(model, non_param_modules): 360 | quant_model = QuantizedActivationWrapper(model, **quant_params) 361 | 362 | elif type(model) in non_bn_module_map: 363 | # If we do isinstance() then we might run into issues with modules that inherit from 364 | # one of these classes, for whatever reason 365 | modtype = non_bn_module_map[type(model)] 366 | kwargs =
get_module_args(model, None) 367 | quant_model = modtype(**kwargs, **quant_params) 368 | 369 | quant_model.weight.data = model.weight.data 370 | if getattr(model, "bias", None) is not None: 371 | quant_model.bias.data = model.bias.data 372 | 373 | else: 374 | # Unknown type, try to quantize all child modules 375 | quant_model = copy.deepcopy(model) 376 | for name, module in quant_model._modules.items(): 377 | new_model = quantize_model(module, specials=specials, **quant_params) 378 | if new_model is not None: 379 | setattr(quant_model, name, new_model) 380 | 381 | return quant_model 382 | -------------------------------------------------------------------------------- /quantization/base_quantized_classes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from quantization.quantization_manager import QuantizationManager 9 | from quantization.quantizers import QuantizerBase, AsymmetricUniformQuantizer 10 | from quantization.range_estimators import ( 11 | RangeEstimatorBase, 12 | CurrentMinMaxEstimator, 13 | RunningMinMaxEstimator, 14 | ) 15 | 16 | 17 | def _set_layer_learn_ranges(layer): 18 | if isinstance(layer, QuantizationManager): 19 | if layer.quantizer.is_initialized: 20 | layer.learn_ranges() 21 | 22 | 23 | def _set_layer_fix_ranges(layer): 24 | if isinstance(layer, QuantizationManager): 25 | if layer.quantizer.is_initialized: 26 | layer.fix_ranges() 27 | 28 | 29 | def _set_layer_estimate_ranges(layer): 30 | if isinstance(layer, QuantizationManager): 31 | layer.estimate_ranges() 32 | 33 | 34 | def _set_layer_estimate_ranges_train(layer): 35 | if isinstance(layer, QuantizationManager): 36 | if layer.quantizer.is_initialized: 37 | layer.estimate_ranges_train() 38 | 39 | 40 | class QuantizedModule(nn.Module): 41 | """ 42 | Parent class for a quantized module. 
It adds the basic functionality of switching the module 43 | between quantized and full precision mode. It also defines the cached parameters and handles 44 | the reset of the cache properly. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | *args, 50 | method: QuantizerBase = AsymmetricUniformQuantizer, 51 | act_method=None, 52 | weight_range_method: RangeEstimatorBase = CurrentMinMaxEstimator, 53 | act_range_method: RangeEstimatorBase = RunningMinMaxEstimator, 54 | n_bits=8, 55 | n_bits_act=None, 56 | per_channel_weights=False, 57 | percentile=None, 58 | weight_range_options=None, 59 | act_range_options=None, 60 | scale_domain="linear", 61 | act_quant_kwargs={}, 62 | weight_quant_kwargs={}, 63 | quantize_input=False, 64 | fp8_kwargs=None, 65 | **kwargs 66 | ): 67 | kwargs.pop("act_quant_dict", None) 68 | 69 | super().__init__(*args, **kwargs) 70 | 71 | self.method = method 72 | self.act_method = act_method or method 73 | self.n_bits = n_bits 74 | self.n_bits_act = n_bits_act or n_bits 75 | self.per_channel_weights = per_channel_weights 76 | self.percentile = percentile 77 | self.weight_range_method = weight_range_method 78 | self.weight_range_options = weight_range_options if weight_range_options else {} 79 | self.act_range_method = act_range_method 80 | self.act_range_options = act_range_options if act_range_options else {} 81 | self.scale_domain = scale_domain 82 | self.quantize_input = quantize_input 83 | self.fp8_kwargs = fp8_kwargs or {} 84 | 85 | self.quant_params = None 86 | self.register_buffer("_quant_w", torch.BoolTensor([False])) 87 | self.register_buffer("_quant_a", torch.BoolTensor([False])) 88 | 89 | self.act_qparams = dict( 90 | n_bits=self.n_bits_act, 91 | scale_domain=self.scale_domain, 92 | **act_quant_kwargs, 93 | **self.fp8_kwargs 94 | ) 95 | self.weight_qparams = dict( 96 | n_bits=self.n_bits, 97 | scale_domain=self.scale_domain, 98 | **weight_quant_kwargs, 99 | **self.fp8_kwargs 100 | ) 101 | 102 | def quantized_weights(self): 103 | 
self._quant_w = torch.BoolTensor([True]) 104 | 105 | def full_precision_weights(self): 106 | self._quant_w = torch.BoolTensor([False]) 107 | 108 | def quantized_acts(self): 109 | self._quant_a = torch.BoolTensor([True]) 110 | 111 | def full_precision_acts(self): 112 | self._quant_a = torch.BoolTensor([False]) 113 | 114 | def quantized(self): 115 | self.quantized_weights() 116 | self.quantized_acts() 117 | 118 | def full_precision(self): 119 | self.full_precision_weights() 120 | self.full_precision_acts() 121 | 122 | def get_quantizer_status(self): 123 | return dict(quant_a=self._quant_a.item(), quant_w=self._quant_w.item()) 124 | 125 | def set_quantizer_status(self, quantizer_status): 126 | if quantizer_status["quant_a"]: 127 | self.quantized_acts() 128 | else: 129 | self.full_precision_acts() 130 | 131 | if quantizer_status["quant_w"]: 132 | self.quantized_weights() 133 | else: 134 | self.full_precision_weights() 135 | 136 | def learn_ranges(self): 137 | self.apply(_set_layer_learn_ranges) 138 | 139 | def fix_ranges(self): 140 | self.apply(_set_layer_fix_ranges) 141 | 142 | def estimate_ranges(self): 143 | self.apply(_set_layer_estimate_ranges) 144 | 145 | def estimate_ranges_train(self): 146 | self.apply(_set_layer_estimate_ranges_train) 147 | 148 | def extra_repr(self): 149 | quant_state = "weight_quant={}, act_quant={}".format( 150 | self._quant_w.item(), self._quant_a.item() 151 | ) 152 | parent_repr = super().extra_repr() 153 | return "{},\n{}".format(parent_repr, quant_state) if parent_repr else quant_state 154 | 155 | 156 | class QuantizedActivation(QuantizedModule): 157 | def __init__(self, *args, **kwargs): 158 | super().__init__(*args, **kwargs) 159 | self.activation_quantizer = QuantizationManager( 160 | qmethod=self.act_method, 161 | qparams=self.act_qparams, 162 | init=self.act_range_method, 163 | range_estim_params=self.act_range_options, 164 | ) 165 | 166 | def quantize_activations(self, x): 167 | if self._quant_a: 168 | return 
self.activation_quantizer(x) 169 | else: 170 | return x 171 | 172 | def forward(self, x): 173 | return self.quantize_activations(x) 174 | 175 | 176 | class FP32Acts(nn.Module): 177 | def forward(self, x): 178 | return x 179 | 180 | def reset_ranges(self): 181 | pass 182 | -------------------------------------------------------------------------------- /quantization/base_quantized_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | from typing import Union, Dict 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | 9 | from quantization.base_quantized_classes import ( 10 | QuantizedModule, 11 | _set_layer_estimate_ranges, 12 | _set_layer_estimate_ranges_train, 13 | _set_layer_learn_ranges, 14 | _set_layer_fix_ranges, 15 | ) 16 | from quantization.quantizers import QuantizerBase 17 | 18 | 19 | class QuantizedModel(nn.Module): 20 | """ 21 | Parent class for a quantized model. This allows you to have convenience functions to put the 22 | whole model into quantization or full precision. 23 | """ 24 | 25 | def __init__(self, input_size=(1, 3, 224, 224)): 26 | """ 27 | Parameters 28 | ---------- 29 | input_size: Tuple with the input dimension for the model (including batch dimension) 30 | """ 31 | super().__init__() 32 | self.input_size = input_size 33 | 34 | def load_state_dict( 35 | self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True 36 | ): 37 | """ 38 | This function overwrites the load_state_dict of nn.Module to ensure that quantization 39 | parameters are loaded correctly for quantized model. 
40 | 41 | """ 42 | quant_state_dict = { 43 | k: v for k, v in state_dict.items() if k.endswith("_quant_a") or k.endswith("_quant_w") 44 | } 45 | 46 | if quant_state_dict: 47 | super().load_state_dict(quant_state_dict, strict=False) 48 | else: 49 | raise ValueError( 50 | "The quantization states of activations or weights should be " 51 | "included in the state dict " 52 | ) 53 | # Pass dummy data through quantized model to ensure all quantization parameters are 54 | # initialized with the correct dimensions (None tensors will lead to issues in state dict 55 | # loading) 56 | device = next(self.parameters()).device 57 | dummy_input = torch.rand(*self.input_size, device=device) 58 | with torch.no_grad(): 59 | self.forward(dummy_input) 60 | 61 | # Load state dict 62 | super().load_state_dict(state_dict, strict) 63 | 64 | def quantized_weights(self): 65 | def _fn(layer): 66 | if isinstance(layer, QuantizedModule): 67 | layer.quantized_weights() 68 | 69 | self.apply(_fn) 70 | 71 | def full_precision_weights(self): 72 | def _fn(layer): 73 | if isinstance(layer, QuantizedModule): 74 | layer.full_precision_weights() 75 | 76 | self.apply(_fn) 77 | 78 | def quantized_acts(self): 79 | def _fn(layer): 80 | if isinstance(layer, QuantizedModule): 81 | layer.quantized_acts() 82 | 83 | self.apply(_fn) 84 | 85 | def full_precision_acts(self): 86 | def _fn(layer): 87 | if isinstance(layer, QuantizedModule): 88 | layer.full_precision_acts() 89 | 90 | self.apply(_fn) 91 | 92 | def quantized(self): 93 | def _fn(layer): 94 | if isinstance(layer, QuantizedModule): 95 | layer.quantized() 96 | 97 | self.apply(_fn) 98 | 99 | def full_precision(self): 100 | def _fn(layer): 101 | if isinstance(layer, QuantizedModule): 102 | layer.full_precision() 103 | 104 | self.apply(_fn) 105 | 106 | def estimate_ranges(self): 107 | self.apply(_set_layer_estimate_ranges) 108 | 109 | def estimate_ranges_train(self): 110 | self.apply(_set_layer_estimate_ranges_train) 111 | 112 | def set_quant_state(self, 
weight_quant, act_quant): 113 | if act_quant: 114 | self.quantized_acts() 115 | else: 116 | self.full_precision_acts() 117 | 118 | if weight_quant: 119 | self.quantized_weights() 120 | else: 121 | self.full_precision_weights() 122 | 123 | def grad_scaling(self, grad_scaling=True): 124 | def _fn(module): 125 | if isinstance(module, QuantizerBase): 126 | module.grad_scaling = grad_scaling 127 | 128 | self.apply(_fn) 129 | # Methods for switching quantizer quantization states 130 | 131 | def learn_ranges(self): 132 | self.apply(_set_layer_learn_ranges) 133 | 134 | def fix_ranges(self): 135 | self.apply(_set_layer_fix_ranges) 136 | -------------------------------------------------------------------------------- /quantization/hijacker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import copy 6 | 7 | from timm.models.layers.activations import Swish, HardSwish, HardSigmoid 8 | from timm.models.layers.activations_me import SwishMe, HardSwishMe, HardSigmoidMe 9 | from torch import nn 10 | 11 | from quantization.base_quantized_classes import QuantizedModule 12 | from quantization.quantization_manager import QuantizationManager 13 | from quantization.range_estimators import RangeEstimators 14 | 15 | activations_set = [ 16 | nn.ReLU, 17 | nn.ReLU6, 18 | nn.Hardtanh, 19 | nn.Sigmoid, 20 | nn.Tanh, 21 | nn.GELU, 22 | nn.PReLU, 23 | Swish, 24 | SwishMe, 25 | HardSwish, 26 | HardSwishMe, 27 | HardSigmoid, 28 | HardSigmoidMe, 29 | ] 30 | 31 | 32 | class QuantizationHijacker(QuantizedModule): 33 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and 34 | dequantization on the weights and output distributions. 
35 | 36 | Usage: 37 | To make a quantized nn.Linear layer, inherit from both classes and implement run_forward: 38 | class HijackedLinear(QuantizationHijacker, nn.Linear): 39 | def run_forward(self, x, weight, bias, offsets=None): return nn.functional.linear(x, weight, bias) 40 | """ 41 | 42 | def __init__(self, *args, activation: nn.Module = None, **kwargs): 43 | 44 | super().__init__(*args, **kwargs) 45 | if activation: 46 | assert isinstance(activation, tuple(activations_set)), str(activation) 47 | 48 | self.activation_function = copy.deepcopy(activation) if activation else None 49 | 50 | self.activation_quantizer = QuantizationManager( 51 | qmethod=self.act_method, 52 | init=self.act_range_method, 53 | qparams=self.act_qparams, 54 | range_estim_params=self.act_range_options, 55 | ) 56 | 57 | if self.weight_range_method == RangeEstimators.current_minmax: 58 | weight_init_params = dict(percentile=self.percentile) 59 | else: 60 | weight_init_params = self.weight_range_options 61 | 62 | self.weight_quantizer = QuantizationManager( 63 | qmethod=self.method, 64 | init=self.weight_range_method, 65 | per_channel=self.per_channel_weights, 66 | qparams=self.weight_qparams, 67 | range_estim_params=weight_init_params, 68 | ) 69 | 70 | def forward(self, x, offsets=None): 71 | # Quantize input 72 | if self.quantize_input and self._quant_a: 73 | x = self.activation_quantizer(x) 74 | 75 | # Get quantized weight 76 | weight, bias = self.get_params() 77 | res = self.run_forward(x, weight, bias, offsets=offsets) 78 | 79 | # Apply fused activation function 80 | if self.activation_function is not None: 81 | res = self.activation_function(res) 82 | 83 | # Quantize output 84 | if not self.quantize_input and self._quant_a: 85 | res = self.activation_quantizer(res) 86 | return res 87 | 88 | def get_params(self): 89 | 90 | weight, bias = self.get_weight_bias() 91 | 92 | if self._quant_w: 93 | weight = self.quantize_weights(weight) 94 | 95 | return weight, bias 96 | 97 | def quantize_weights(self, weights): 98 | return self.weight_quantizer(weights) 99 | 100 | def get_weight_bias(self): 101 | bias = None 102 | if hasattr(self,
"bias"): 103 | bias = self.bias 104 | return self.weight, bias 105 | 106 | def run_forward(self, x, weight, bias, offsets=None): 107 | # Performs the actual linear operation of the layer 108 | raise NotImplementedError() 109 | 110 | def extra_repr(self): 111 | activation = "input" if self.quantize_input else "output" 112 | return f"{super().extra_repr()}-{activation}" 113 | -------------------------------------------------------------------------------- /quantization/quant_error_estimator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import torch 6 | import numpy as np 7 | from utils.grid import integrate_pdf_grid_func_analyt 8 | from quantization.quantizers.fp8_quantizer import FPQuantizer, generate_all_float_values_scaled 9 | 10 | 11 | def generate_integr_grid_piecewise(integr_discontin, num_intervals_smallest_bin): 12 | bin_widths = np.diff(integr_discontin) 13 | min_bin_width = np.min(bin_widths[bin_widths > 0.0]) 14 | integr_min_step = min_bin_width / num_intervals_smallest_bin 15 | grid_list = [] 16 | for i in range(len(integr_discontin) - 1): 17 | curr_interv_min = integr_discontin[i] 18 | curr_interv_max = integr_discontin[i + 1] 19 | curr_interv_width = curr_interv_max - curr_interv_min 20 | 21 | if curr_interv_width == 0.0: 22 | continue 23 | assert curr_interv_min < curr_interv_max 24 | curr_interv_n_subintervals = np.ceil(curr_interv_width / integr_min_step).astype("int") 25 | curr_interv_n_pts = curr_interv_n_subintervals + 1 26 | lspace = torch.linspace(curr_interv_min, curr_interv_max, curr_interv_n_pts) 27 | grid_list.append(lspace) 28 | 29 | grid_all = torch.cat(grid_list) 30 | grid_all_no_dup = torch.unique(grid_all) 31 | 32 | return grid_all_no_dup 33 | 34 | 35 | def estimate_rounding_error_analyt(distr, grid): 36 | err = integrate_pdf_grid_func_analyt(distr, grid, "integr_interv_p_sqr_r") 37 | 
return err 38 | 39 | 40 | def estimate_dot_prod_error_analyt(distr_x, grid_x, distr_y, grid_y): 41 | rounding_err_x = integrate_pdf_grid_func_analyt(distr_x, grid_x, "integr_interv_p_sqr_r") 42 | rounding_err_y = integrate_pdf_grid_func_analyt(distr_y, grid_y, "integr_interv_p_sqr_r") 43 | second_moment_x = distr_x.eval_non_central_second_moment() 44 | second_moment_y = distr_y.eval_non_central_second_moment() 45 | y_p_y_R_y_signed = integrate_pdf_grid_func_analyt(distr_y, grid_y, "integr_interv_x_p_signed_r") 46 | x_p_x_R_x_signed = integrate_pdf_grid_func_analyt(distr_x, grid_x, "integr_interv_x_p_signed_r") 47 | 48 | term_rounding_err_x_moment_y = rounding_err_x * second_moment_y 49 | term_rounding_err_y_moment_x = rounding_err_y * second_moment_x 50 | term_mixed_rounding_err = rounding_err_x * rounding_err_y 51 | term_mixed_R_signed = 2.0 * y_p_y_R_y_signed * x_p_x_R_x_signed 52 | term_rounding_err_x_R_y_signed = 2.0 * rounding_err_x * y_p_y_R_y_signed 53 | term_rounding_err_y_R_x_signed = 2.0 * rounding_err_y * x_p_x_R_x_signed 54 | 55 | total_sum = ( 56 | term_rounding_err_x_moment_y 57 | + term_rounding_err_y_moment_x 58 | + term_mixed_R_signed 59 | + term_mixed_rounding_err 60 | + term_rounding_err_x_R_y_signed 61 | + term_rounding_err_y_R_x_signed 62 | ) 63 | 64 | return total_sum 65 | 66 | 67 | def estimate_rounding_error_empirical(W, quantizer, range_min, range_max): 68 | quantizer.set_quant_range(range_min, range_max) 69 | W_int_quant = quantizer.forward(W) 70 | 71 | round_err_sqr_int_quant_emp = (W_int_quant - W) ** 2 72 | res = torch.mean(round_err_sqr_int_quant_emp.flatten()).item() 73 | return res 74 | 75 | 76 | def estimate_dot_prod_error_empirical( 77 | x, y, quantizer_x, quantizer_y, x_range_min, x_range_max, y_range_min, y_range_max 78 | ): 79 | quantizer_x.set_quant_range(x_range_min, x_range_max) 80 | quantizer_y.set_quant_range(y_range_min, y_range_max) 81 | x_quant = quantizer_x.forward(x) 82 | y_quant = quantizer_y.forward(y) 83 | 84 | 
scalar_prod_err_emp = (torch.mul(x, y) - torch.mul(x_quant, y_quant)) ** 2 85 | res = torch.mean(scalar_prod_err_emp.flatten()).item() 86 | return res 87 | 88 | 89 | def compute_expected_dot_prod_mse( 90 | distr_x, 91 | distr_y, 92 | quant_x, 93 | quant_y, 94 | quant_x_range_min, 95 | quant_x_range_max, 96 | quant_y_range_min, 97 | quant_y_range_max, 98 | num_samples=2000000, 99 | ): 100 | 101 | quant_x.set_quant_range(quant_x_range_min, quant_x_range_max) 102 | quant_y.set_quant_range(quant_y_range_min, quant_y_range_max) 103 | if isinstance(quant_x, FPQuantizer): 104 | grid_x = generate_all_float_values_scaled( 105 | quant_x.n_bits, quant_x.ebits, quant_x.default_bias, quant_x_range_max 106 | ) 107 | else: 108 | grid_x = quant_x.generate_grid().numpy() 109 | 110 | if isinstance(quant_y, FPQuantizer): 111 | grid_y = generate_all_float_values_scaled( 112 | quant_y.n_bits, quant_y.ebits, quant_y.default_bias, quant_y_range_max 113 | ) 114 | else: 115 | grid_y = quant_y.generate_grid().numpy() 116 | 117 | err_analyt = estimate_dot_prod_error_analyt(distr_x, grid_x, distr_y, grid_y) 118 | distr_x_sample = torch.tensor(distr_x.sample((num_samples,))) 119 | distr_y_sample = torch.tensor(distr_y.sample((num_samples,))) 120 | err_emp = estimate_dot_prod_error_empirical( 121 | distr_x_sample, 122 | distr_y_sample, 123 | quant_x, 124 | quant_y, 125 | quant_x_range_min, 126 | quant_x_range_max, 127 | quant_y_range_min, 128 | quant_y_range_max, 129 | ) 130 | 131 | rel_err = np.abs((err_emp - err_analyt) / err_analyt) 132 | return err_analyt 133 | 134 | 135 | def compute_expected_quant_mse(distr, quant, quant_range_min, quant_range_max, num_samples): 136 | quant.set_quant_range(quant_range_min, quant_range_max) 137 | if isinstance(quant, FPQuantizer): 138 | grid = generate_all_float_values_scaled( 139 | quant.n_bits, quant.ebits, quant.default_bias, quant_range_max 140 | ) 141 | else: 142 | grid = quant.generate_grid().numpy() 143 | 144 | err_analyt =
estimate_rounding_error_analyt(distr, grid) 145 | distr_sample = torch.tensor( 146 | distr.sample( 147 | num_samples, 148 | ) 149 | ) 150 | err_emp = estimate_rounding_error_empirical( 151 | distr_sample, quant, quant_range_min, quant_range_max 152 | ) 153 | 154 | rel_err = np.abs((err_emp - err_analyt) / err_analyt) 155 | if rel_err > 0.1: 156 | print( 157 | "Warning: the relative difference between the analytical and empirical error estimates is too high,\n" 158 | "please consider increasing num_samples for the empirical estimate." 159 | ) 160 | 161 | return err_analyt 162 | -------------------------------------------------------------------------------- /quantization/quantization_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from enum import auto 6 | 7 | from torch import nn 8 | from quantization.quantizers import QuantizerBase 9 | from quantization.quantizers.utils import QuantizerNotInitializedError 10 | from quantization.range_estimators import RangeEstimators, RangeEstimatorBase 11 | from utils import BaseEnumOptions 12 | 13 | from quantization.quantizers.uniform_quantizers import ( 14 | SymmetricUniformQuantizer, 15 | AsymmetricUniformQuantizer, 16 | ) 17 | from quantization.quantizers.fp8_quantizer import FPQuantizer 18 | 19 | from utils import ClassEnumOptions, MethodMap 20 | 21 | 22 | class QMethods(ClassEnumOptions): 23 | symmetric_uniform = MethodMap(SymmetricUniformQuantizer) 24 | asymmetric_uniform = MethodMap(AsymmetricUniformQuantizer) 25 | fp_quantizer = MethodMap(FPQuantizer) 26 | 27 | 28 | class QuantizationManager(nn.Module): 29 | """Implementation of Quantization and Quantization Range Estimation 30 | 31 | Parameters 32 | ---------- 33 | n_bits: int 34 | Number of bits for the quantization.
35 | qmethod: QMethods member (Enum) 36 | The quantization scheme to use: symmetric_uniform, asymmetric_uniform, 37 | or fp_quantizer. 38 | init: RangeEstimators member (Enum) 39 | Range estimation method used to initialize the quantization grid. 40 | per_channel: bool 41 | If true, will use a separate quantization grid for each kernel/channel. 42 | x_min: float or PyTorch Tensor 43 | The minimum value which needs to be represented. 44 | x_max: float or PyTorch Tensor 45 | The maximum value which needs to be represented. 46 | qparams: kwargs 47 | Dictionary of quantization parameters passed to the quantizer constructor. 48 | range_estim_params: kwargs 49 | Dictionary of parameters passed to the range estimator constructor. 50 | """ 51 | 52 | def __init__( 53 | self, 54 | qmethod: QuantizerBase = QMethods.symmetric_uniform.cls, 55 | init: RangeEstimatorBase = RangeEstimators.current_minmax.cls, 56 | per_channel=False, 57 | x_min=None, 58 | x_max=None, 59 | qparams=None, 60 | range_estim_params=None, 61 | ): 62 | super().__init__() 63 | self.state = Qstates.estimate_ranges 64 | self.qmethod = qmethod 65 | self.init = init 66 | self.per_channel = per_channel 67 | self.qparams = qparams if qparams else {} 68 | self.range_estim_params = range_estim_params if range_estim_params else {} 69 | self.range_estimator = None 70 | 71 | # define quantizer 72 | self.quantizer = self.qmethod(per_channel=self.per_channel, **self.qparams) 73 | self.quantizer.state = self.state 74 | 75 | # define range estimation method for quantizer initialisation 76 | if x_min is not None and x_max is not None: 77 | self.set_quant_range(x_min, x_max) 78 | self.fix_ranges() 79 | else: 80 | # set up the collector function to set the ranges 81 | self.range_estimator = self.init( 82 | per_channel=self.per_channel, quantizer=self.quantizer, **self.range_estim_params 83 | ) 84 | 85 | @property 86 | def n_bits(self): 87 | return self.quantizer.n_bits 88 | 89 | def estimate_ranges(self): 90 | self.state =
Qstates.estimate_ranges 91 | self.quantizer.state = self.state 92 | 93 | def fix_ranges(self): 94 | if self.quantizer.is_initialized: 95 | self.state = Qstates.fix_ranges 96 | self.quantizer.state = self.state 97 | else: 98 | raise QuantizerNotInitializedError() 99 | 100 | def learn_ranges(self): 101 | self.quantizer.make_range_trainable() 102 | self.state = Qstates.learn_ranges 103 | self.quantizer.state = self.state 104 | 105 | def estimate_ranges_train(self): 106 | self.state = Qstates.estimate_ranges_train 107 | self.quantizer.state = self.state 108 | 109 | def reset_ranges(self): 110 | self.range_estimator.reset() 111 | self.quantizer.reset() 112 | self.estimate_ranges() 113 | 114 | def forward(self, x): 115 | if self.state == Qstates.estimate_ranges or ( 116 | self.state == Qstates.estimate_ranges_train and self.training 117 | ): 118 | # Note this can be per tensor or per channel 119 | cur_xmin, cur_xmax = self.range_estimator(x) 120 | self.set_quant_range(cur_xmin, cur_xmax) 121 | 122 | return self.quantizer(x) 123 | 124 | def set_quant_range(self, x_min, x_max): 125 | self.quantizer.set_quant_range(x_min, x_max) 126 | 127 | def extra_repr(self): 128 | return "state={}".format(self.state.name) 129 | 130 | 131 | class Qstates(BaseEnumOptions): 132 | estimate_ranges = auto() # ranges are updated in eval and train mode 133 | fix_ranges = auto() # quantization ranges are fixed for train and eval 134 | learn_ranges = auto() # quantization params are nn.Parameters 135 | estimate_ranges_train = auto() # quantization ranges are updated during train and fixed for 136 | # eval 137 | -------------------------------------------------------------------------------- /quantization/quantized_folded_bn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.nn.modules.conv import _ConvNd 8 | 9 | from quantization.hijacker import QuantizationHijacker 10 | 11 | 12 | class BNFusedHijacker(QuantizationHijacker): 13 | """Extension to the QuantizationHijacker that fuses batch normalization (BN) after a weight 14 | layer into a joint module. The parameters and the statistics of the BN layer remain in 15 | full precision. 16 | """ 17 | 18 | def __init__(self, *args, **kwargs): 19 | kwargs.pop("bias", None) # Bias will be learned by BN params 20 | super().__init__(*args, **kwargs, bias=False) 21 | bn_dim = self.get_bn_dim() 22 | self.register_buffer("running_mean", torch.zeros(bn_dim)) 23 | self.register_buffer("running_var", torch.ones(bn_dim)) 24 | self.momentum = kwargs.pop("momentum", 0.1) 25 | self.gamma = nn.Parameter(torch.ones(bn_dim)) 26 | self.beta = nn.Parameter(torch.zeros(bn_dim)) 27 | self.epsilon = kwargs.get("eps", 1e-5) 28 | self.bias = None 29 | 30 | def forward(self, x): 31 | # Quantize input 32 | if self.quantize_input and self._quant_a: 33 | x = self.activation_quantizer(x) 34 | 35 | # Get quantized weight 36 | weight, bias = self.get_params() 37 | res = self.run_forward(x, weight, bias) 38 | 39 | res = F.batch_norm( 40 | res, 41 | self.running_mean, 42 | self.running_var, 43 | self.gamma, 44 | self.beta, 45 | self.training, 46 | self.momentum, 47 | self.epsilon, 48 | ) 49 | # Apply fused activation function 50 | if self.activation_function is not None: 51 | res = self.activation_function(res) 52 | 53 | # Quantize output 54 | if not self.quantize_input and self._quant_a: 55 | res = self.activation_quantizer(res) 56 | return res 57 | 58 | def get_bn_dim(self): 59 | if isinstance(self, nn.Linear): 60 | return self.out_features 61 | elif isinstance(self, _ConvNd): 62 | return self.out_channels 63 | else: 64 | msg = ( 65 | f"Unsupported type used: {self}.
Must be a linear or (transpose)-convolutional " 66 | f"nn.Module" 67 | ) 68 | raise NotImplementedError(msg) 69 | -------------------------------------------------------------------------------- /quantization/quantizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from quantization.quantizers.base_quantizers import QuantizerBase 6 | from quantization.quantizers.fp8_quantizer import FPQuantizer 7 | from quantization.quantizers.uniform_quantizers import ( 8 | AsymmetricUniformQuantizer, 9 | SymmetricUniformQuantizer, 10 | ) 11 | -------------------------------------------------------------------------------- /quantization/quantizers/base_quantizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
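As written, `BNFusedHijacker.forward` runs the layer with the quantized parameters from `self.get_params()`. It then applies `F.batch_norm` with full-precision running statistics and affine parameters. In eval mode, that BN step is a fixed per-channel affine map, which is why it can be folded into the preceding weight and bias at inference time. The sketch below checks this standard folding identity for one output channel of a bias-free linear layer in plain Python. `fold_bn`, the helper functions, and the numbers are illustrative assumptions, not code from this repository.

```python
import math


def linear(x, w, b):
    # One output channel of a linear layer: y = w . x + b
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def batch_norm_eval(y, mean, var, gamma, beta, eps=1e-5):
    # Eval-mode BN with running statistics: a fixed affine map of y
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, mean, var, gamma, beta, eps=1e-5):
    # Absorb that affine map into the layer: w' = s * w, b' = s * (b - mean) + beta
    s = gamma / math.sqrt(var + eps)
    return [s * wi for wi in w], s * (b - mean) + beta


w, b = [0.5, -1.25, 2.0], 0.0      # bias-free layer, as BNFusedHijacker uses bias=False
x = [1.0, 2.0, -0.5]
mean, var, gamma, beta = 0.3, 4.0, 1.5, -0.2

unfused = batch_norm_eval(linear(x, w, b), mean, var, gamma, beta)
w_f, b_f = fold_bn(w, b, mean, var, gamma, beta)
fused = linear(x, w_f, b_f)        # equals `unfused` up to floating-point error
```

The `eps` default of `1e-5` matches the default `self.epsilon` above. In training mode, BN uses batch statistics instead, so the identity only holds with the running statistics.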
4 | 5 | from torch import nn 6 | 7 | 8 | class QuantizerBase(nn.Module): 9 | def __init__(self, n_bits, per_channel=False, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.n_bits = n_bits 12 | self.per_channel = per_channel 13 | self.state = None 14 | self.x_min_fp32 = self.x_max_fp32 = None 15 | 16 | @property 17 | def is_initialized(self): 18 | raise NotImplementedError() 19 | 20 | @property 21 | def x_max(self): 22 | raise NotImplementedError() 23 | 24 | @property 25 | def symmetric(self): 26 | raise NotImplementedError() 27 | 28 | @property 29 | def x_min(self): 30 | raise NotImplementedError() 31 | 32 | def forward(self, x_float): 33 | raise NotImplementedError() 34 | 35 | def _adjust_params_per_channel(self, x): 36 | raise NotImplementedError() 37 | 38 | def set_quant_range(self, x_min, x_max): 39 | raise NotImplementedError() 40 | 41 | def extra_repr(self): 42 | return "n_bits={}, per_channel={}, is_initialized={}".format( 43 | self.n_bits, self.per_channel, self.is_initialized 44 | ) 45 | 46 | def reset(self): 47 | self._delta = None 48 | -------------------------------------------------------------------------------- /quantization/quantizers/fp8_quantizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
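The functions in `fp8_quantizer.py` enumerate and quantize to a low-bit floating-point grid. A dependency-free sketch of the same arithmetic may help when reading them. It assumes the conventions visible in `generate_all_values_fp` and `quantize_to_fp8_ste_MM`: one sign bit, exponent code 0 encoding subnormals, no codes reserved for Inf/NaN, and a possibly non-integer bias derived from `maxval`. All helper names are hypothetical. The scalar `fp_quantize` also omits the per-channel and straight-through-gradient handling of the tensor version.

```python
import math


def fp_grid(n_bits=8, exp_bits=4, bias=8):
    """All values of a signed format with exp_bits exponent bits and
    n_bits - 1 - exp_bits mantissa bits; exponent code 0 is subnormal."""
    man_bits = n_bits - 1 - exp_bits
    values = set()  # +0.0 and -0.0 compare equal, so zero is kept once
    for sign in (-1.0, 1.0):
        for e in range(2 ** exp_bits):
            for m in range(2 ** man_bits):
                frac = m / 2 ** man_bits
                if e == 0:  # subnormal: 0.F * 2^(1 - bias)
                    v = frac * 2.0 ** (1 - bias)
                else:  # normal: 1.F * 2^(e - bias)
                    v = (1.0 + frac) * 2.0 ** (e - bias)
                values.add(sign * v)
    return sorted(values)


def fp_maxval(exp_bits, man_bits, bias):
    # Largest magnitude: every exponent and mantissa bit set
    return (2 - 2.0 ** -man_bits) * 2.0 ** (2 ** exp_bits - 1 - bias)


def bias_from_maxval(maxval, exp_bits, man_bits):
    # Inverts fp_maxval; same expression as `bias` in quantize_to_fp8_ste_MM
    return 2 ** exp_bits - math.log2(maxval) + math.log2(2 - 2.0 ** -man_bits) - 1


def fp_quantize(x, maxval, exp_bits, man_bits):
    # Scalar version of the scale-based rounding in quantize_to_fp8_ste_MM
    bias = bias_from_maxval(maxval, exp_bits, man_bits)
    xc = min(max(x, -maxval), maxval)  # clip to [-maxval, maxval]
    if xc == 0.0:
        return 0.0  # the tensor version also yields 0 (log2(0) = -inf is clamped to 1)
    # Binade exponent of xc, clamped to 1 so all subnormals share one step size
    log_scale = max(math.floor(math.log2(abs(xc)) + bias), 1)
    step = 2.0 ** (log_scale - man_bits - bias)
    return round(xc / step) * step
```

With 4 exponent bits, 3 mantissa bits and bias 8, `fp_grid` yields 255 distinct values (the two zero codes coincide) with a largest magnitude of 240. `bias_from_maxval(240.0, 4, 3)` recovers a bias of 8, up to floating-point error. Within one binade the grid is uniform, which is why rounding `x / step` to an integer reaches the nearest grid point.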
4 | import torch 5 | import torch.nn as nn 6 | from quantization.quantizers.base_quantizers import QuantizerBase 7 | import numpy as np 8 | from itertools import product 9 | from torch.autograd import Function 10 | from quantization.quantizers.rounding_utils import round_ste_func 11 | 12 | 13 | def generate_all_values_fp( 14 | num_total_bits: int = 8, num_exponent_bits: int = 4, bias: int = 8 15 | ) -> list: 16 | num_fraction_bits = num_total_bits - 1 - num_exponent_bits 17 | 18 | all_values = [] 19 | exp_lower = -bias 20 | for S in [-1.0, 1.0]: 21 | for E_str_iter in product(*[[0, 1]] * num_exponent_bits): 22 | for F_str_iter in product(*[[0, 1]] * num_fraction_bits): 23 | E_str = "".join(str(i) for i in E_str_iter) 24 | F_str = "".join(str(i) for i in F_str_iter) 25 | 26 | # encoded exponent 27 | E_enc = decode_binary_str(E_str) 28 | E_eff = E_enc - bias 29 | if E_eff == exp_lower: 30 | is_subnormal = 1 31 | else: 32 | is_subnormal = 0 33 | 34 | F_enc = decode_binary_str(F_str) * 2**-num_fraction_bits 35 | F_eff = F_enc + 1 - is_subnormal 36 | 37 | fp8_val = S * 2.0 ** (E_enc - bias + is_subnormal) * F_eff 38 | all_values.append(fp8_val) 39 | res = np.array(all_values) 40 | res = np.sort(res) 41 | return res 42 | 43 | 44 | def generate_all_float_values_scaled(num_total_bits, num_exp_bits, exp_bias, range_limit_fp): 45 | grid = generate_all_values_fp(num_total_bits, num_exp_bits, exp_bias) 46 | float_max_abs_val = np.max(np.abs(grid)) 47 | 48 | float_scale = float_max_abs_val / range_limit_fp 49 | floats_all = grid / float_scale 50 | return floats_all 51 | 52 | 53 | def decode_float8(S, E, F, bias=16): 54 | sign = -2 * int(S) + 1 55 | exponent = int(E, 2) if E else 0 56 | # Normal FP8 : exponent > 0 : result = 2^(exponent-bias) * 1.F 57 | # Subnormal FP8: exponent == 0: result = 2^(-bias+1) * 0.F 58 | # Lowest quantization bin: 2^(-bias+1) * {0.0 ... 1 + (2^mantissa-1)/2^mantissa} 59 | # All other bins : 2^(exponent-bias) * {1.0 ... 
1 + (2^mantissa-1)/2^mantissa}; exponent > 0 60 | A = int(exponent != 0) 61 | fraction = A + sum([2 ** -(i + 1) * int(a) for i, a in enumerate(F)]) 62 | exponent += int(exponent == 0) 63 | return sign * fraction * 2.0 ** (exponent - bias) 64 | 65 | 66 | def i(x): 67 | return np.array([x]).astype(np.int32) 68 | 69 | 70 | def gen(n_bits, exponent_bits, bias): 71 | all_values = [] 72 | for s in product(*[[0, 1]] * 1): 73 | for e in product(*[[0, 1]] * exponent_bits): 74 | for m in product(*[[0, 1]] * (n_bits - 1 - exponent_bits)): 75 | s = str(s[0]) 76 | e = "".join(str(i) for i in e) 77 | m = "".join(str(i) for i in m) 78 | all_values.append(decode_float8(s, e, m, bias=bias)) 79 | return sorted(all_values) 80 | 81 | 82 | def get_max_value(num_exponent_bits: int = 4, bias: int = 8): 83 | num_fraction_bits = 7 - num_exponent_bits 84 | scale = 2**-num_fraction_bits 85 | max_frac = 1 - scale 86 | max_value = 2 ** (2**num_exponent_bits - 1 - bias) * (1 + max_frac) 87 | 88 | return max_value 89 | 90 | 91 | def quantize_to_fp8_ste_MM( 92 | x_float: torch.Tensor, 93 | n_bits: int, 94 | maxval: torch.Tensor, 95 | num_mantissa_bits: torch.Tensor, 96 | sign_bits: int, 97 | ) -> torch.Tensor: 98 | """ 99 | Simpler FP8 quantizer that exploits the fact that FP quantization is just INT quantization with 100 | scales that depend on the input. 
101 | 102 | This allows us to define FP8 quantization using STE rounding functions and thus to learn the bias 103 | 104 | """ 105 | M = torch.clamp(round_ste_func(num_mantissa_bits), 1, n_bits - sign_bits) 106 | E = n_bits - sign_bits - M 107 | 108 | if maxval.shape[0] != 1 and len(maxval.shape) != len(x_float.shape): 109 | maxval = maxval.view([-1] + [1] * (len(x_float.shape) - 1)) 110 | bias = 2**E - torch.log2(maxval) + torch.log2(2 - 2 ** (-M)) - 1 111 | 112 | minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval) 113 | xc = torch.min(torch.max(x_float, minval), maxval) 114 | 115 | """ 116 | Two notes: 117 | 1: We shift by the bias to ensure the data is aligned to the scaled grid in case the bias is not in Z. 118 | Recall that implicitly bias := bias' - log2(alpha), where bias' in Z. If we assume 119 | alpha in (0.5, 1], then alpha contracts the grid, which is equivalent to translating the 120 | data 'to the right' relative to the grid, which is what the subtraction of log2(alpha) 121 | (which is negative) accomplishes. 122 | 2: Ideally this procedure doesn't affect gradients w.r.t. the input (we want to use the STE). 123 | We can achieve this by detaching log2 of the (absolute) input. 124 | 125 | """ 126 | 127 | # log_scales = torch.max((torch.floor(torch.log2(torch.abs(xc)) + bias)).detach(), 1.)
128 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(xc)) + bias)).detach(), 1.0) 129 | 130 | scales = 2.0 ** (log_scales - M - bias) 131 | 132 | result = round_ste_func(xc / scales) * scales 133 | return result 134 | 135 | 136 | class FP8QuantizerFunc(Function): 137 | @staticmethod 138 | def forward(ctx, x_float, n_bits, maxval, num_mantissa_bits, sign_bits): 139 | return quantize_to_fp8_ste_MM(x_float, n_bits, maxval, num_mantissa_bits, sign_bits) 140 | 141 | @staticmethod 142 | def backward(ctx, grad_output): 143 | return grad_output, None, None, None, None 144 | 145 | 146 | def decode_binary_str(F_str): 147 | F = sum([2 ** -(i + 1) * int(a) for i, a in enumerate(F_str)]) * 2 ** len(F_str) 148 | return F 149 | 150 | 151 | class FPQuantizer(QuantizerBase): 152 | """ 153 | 8-bit Floating Point Quantizer 154 | """ 155 | 156 | def __init__( 157 | self, 158 | *args, 159 | scale_domain=None, 160 | mantissa_bits=4, 161 | maxval=3, 162 | set_maxval=False, 163 | learn_maxval=False, 164 | learn_mantissa_bits=False, 165 | mse_include_mantissa_bits=True, 166 | allow_unsigned=False, 167 | **kwargs, 168 | ): 169 | super().__init__(*args, **kwargs) 170 | 171 | self.mantissa_bits = mantissa_bits 172 | 173 | self.ebits = self.n_bits - self.mantissa_bits - 1 174 | self.default_bias = 2 ** (self.ebits - 1) 175 | 176 | # assume signed, correct when range setting turns out to be unsigned 177 | default_maxval = (2 - 2 ** (-self.mantissa_bits)) * 2 ** ( 178 | 2**self.ebits - 1 - self.default_bias 179 | ) 180 | 181 | self.maxval = maxval if maxval is not None else default_maxval 182 | 183 | self.maxval = torch.Tensor([self.maxval]) 184 | self.mantissa_bits = torch.Tensor([float(self.mantissa_bits)]) 185 | 186 | self.set_maxval = set_maxval 187 | self.learning_maxval = learn_maxval 188 | self.learning_mantissa_bits = learn_mantissa_bits 189 | self.mse_include_mantissa_bits = mse_include_mantissa_bits 190 | 191 | self.allow_unsigned = allow_unsigned 192 | self.sign_bits = 1 193 | 194 | def forward(self, x_float): 195 | if
self.maxval.device != x_float.device: 196 | self.maxval = self.maxval.to(x_float.device) 197 | if self.mantissa_bits.device != x_float.device: 198 | self.mantissa_bits = self.mantissa_bits.to(x_float.device) 199 | 200 | res = quantize_to_fp8_ste_MM( 201 | x_float, self.n_bits, self.maxval, self.mantissa_bits, self.sign_bits 202 | ) 203 | 204 | ebits = self.n_bits - self.mantissa_bits - 1 205 | return res 206 | 207 | def is_initialized(self): 208 | return True 209 | 210 | def symmetric(self): 211 | return False 212 | 213 | def effective_bit_width(self): 214 | return None 215 | 216 | def _make_unsigned(self, x_min): 217 | if isinstance(x_min, torch.Tensor): 218 | return self.allow_unsigned and torch.all(x_min >= 0) 219 | else: 220 | return self.allow_unsigned and x_min >= 0 221 | 222 | def set_quant_range(self, x_min, x_max): 223 | 224 | if self._make_unsigned(x_min): 225 | self.sign_bits = 0 226 | 227 | if self.set_maxval: 228 | if not isinstance(x_max, torch.Tensor): 229 | x_max = torch.Tensor([x_max]).to(self.maxval.device) 230 | x_min = torch.Tensor([x_min]).to(self.maxval.device) 231 | if self.maxval.device != x_max.device: 232 | self.maxval = self.maxval.to(x_max.device) 233 | if self.mantissa_bits.device != x_max.device: 234 | self.mantissa_bits = self.mantissa_bits.to(x_max.device) 235 | 236 | mx = torch.abs(torch.max(torch.abs(x_min), x_max)) 237 | self.maxval = mx 238 | 239 | if not isinstance(self.maxval, torch.Tensor) or len(self.maxval.shape) == 0: 240 | self.maxval = torch.Tensor([self.maxval]) 241 | 242 | def make_range_trainable(self): 243 | if self.learning_maxval: 244 | self.learn_maxval() 245 | if self.learning_mantissa_bits: 246 | self.learn_mantissa_bits() 247 | 248 | def learn_maxval(self): 249 | self.learning_maxval = True 250 | self.maxval = torch.nn.Parameter(self.maxval) 251 | 252 | def learn_mantissa_bits(self): 253 | self.learning_mantissa_bits = True 254 | self.mantissa_bits = torch.nn.Parameter(self.mantissa_bits) 255 | 256 | def 
fix_ranges(self): 257 | if isinstance(self.maxval, nn.Parameter): 258 | self.parameter_to_fixed("maxval") 259 | if isinstance(self.mantissa_bits, nn.Parameter): 260 | self.parameter_to_fixed("mantissa_bits") 261 | 262 | def extra_repr(self): 263 | n_eff_bits = self.n_bits - self.sign_bits # bits shared by exponent and mantissa 264 | 265 | M = torch.clamp(torch.round(self.mantissa_bits), 1, n_eff_bits) 266 | E = n_eff_bits - M 267 | bias = 2**E - torch.log2(self.maxval) + torch.log2(2 - 2 ** (-M)) - 1 268 | if bias.shape[0] > 1: 269 | bstr = "[per_channel]" 270 | else: 271 | bstr = f"{bias.item()}" 272 | return f"Exponent: {E.item()} bits; bias: {bstr}" 273 | 274 | def parameter_to_fixed(self, name): 275 | # Replace the nn.Parameter `name` by a plain (non-trainable) tensor with the same value 276 | value = getattr(self, name).data 277 | delattr(self, name) 278 | setattr(self, name, value) 279 | -------------------------------------------------------------------------------- /quantization/quantizers/rounding_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from torch import nn 5 | import torch 6 | from torch.autograd import Function 7 | 8 | # Functional 9 | from utils import MethodMap, ClassEnumOptions 10 | 11 | 12 | class RoundStraightThrough(Function): 13 | @staticmethod 14 | def forward(ctx, x): 15 | return torch.round(x) 16 | 17 | @staticmethod 18 | def backward(ctx, output_grad): 19 | return output_grad 20 | 21 | 22 | class StochasticRoundSTE(Function): 23 | @staticmethod 24 | def forward(ctx, x): 25 | # Sample noise between [0, 1) 26 | noise = torch.rand_like(x) 27 | return torch.floor(x + noise) 28 | 29 | @staticmethod 30 | def backward(ctx, output_grad): 31 | return output_grad 32 | 33 | 34 | class ScaleGradient(Function): 35 | @staticmethod 36 | def forward(ctx, x, scale): 37 | ctx.scale = scale 38 | return x 39 | 40 | @staticmethod 41 | def backward(ctx, output_grad): 42 | return output_grad * ctx.scale, None 43 | 44 | 45 | class EWGSFunctional(Function): 46 | """ 47 | x_in: float input 48 | scaling_factor: backward scaling factor 49 | x_out: discretized version of x_in within the range of [0,1] 50 | """ 51 | 52 | @staticmethod
53 | def forward(ctx, x_in, scaling_factor): 54 | x_int = torch.round(x_in) 55 | ctx._scaling_factor = scaling_factor 56 | ctx.save_for_backward(x_in - x_int) 57 | return x_int 58 | 59 | @staticmethod 60 | def backward(ctx, g): 61 | diff = ctx.saved_tensors[0] 62 | delta = ctx._scaling_factor 63 | scale = 1 + delta * torch.sign(g) * diff 64 | return g * scale, None 65 | 66 | 67 | class StackSigmoidFunctional(Function): 68 | @staticmethod 69 | def forward(ctx, x, alpha): 70 | # Apply round to nearest in the forward pass 71 | ctx.save_for_backward(x, alpha) 72 | return torch.round(x) 73 | 74 | @staticmethod 75 | def backward(ctx, grad_output): 76 | x, alpha = ctx.saved_tensors 77 | sig_min = torch.sigmoid(alpha / 2) 78 | sig_scale = 1 - 2 * sig_min 79 | x_base = torch.floor(x).detach() 80 | x_rest = x - x_base - 0.5 81 | stacked_sigmoid_grad = ( 82 | torch.sigmoid(x_rest * -alpha) 83 | * (1 - torch.sigmoid(x_rest * -alpha)) 84 | * -alpha 85 | / sig_scale 86 | ) 87 | return stacked_sigmoid_grad * grad_output, None 88 | 89 | 90 | # Parametrized modules 91 | class ParametrizedGradEstimatorBase(nn.Module): 92 | def __init__(self, *args, **kwargs): 93 | super().__init__() 94 | self._trainable = False 95 | 96 | def make_grad_params_trainable(self): 97 | self._trainable = True 98 | for name, buf in self.named_buffers(recurse=False): 99 | setattr(self, name, torch.nn.Parameter(buf)) 100 | 101 | def make_grad_params_tensor(self): 102 | self._trainable = False 103 | for name, param in self.named_parameters(recurse=False): 104 | cur_value = param.data 105 | delattr(self, name) 106 | self.register_buffer(name, cur_value) 107 | 108 | def forward(self, x): 109 | raise NotImplementedError() 110 | 111 | 112 | class StackedSigmoid(ParametrizedGradEstimatorBase): 113 | """ 114 | Stacked sigmoid estimator based on a simulated sigmoid forward pass 115 | """ 116 | 117 | def __init__(self, alpha=1.0): 118 | super().__init__() 119 | self.register_buffer("alpha", torch.tensor(alpha))
120 | 121 | def forward(self, x): 122 | return stacked_sigmoid_func(x, self.alpha) 123 | 124 | def extra_repr(self): 125 | return f"alpha={self.alpha.item()}" 126 | 127 | 128 | class EWGSDiscretizer(ParametrizedGradEstimatorBase): 129 | def __init__(self, scaling_factor=0.2): 130 | super().__init__() 131 | self.register_buffer("scaling_factor", torch.tensor(scaling_factor)) 132 | 133 | def forward(self, x): 134 | return ewgs_func(x, self.scaling_factor) 135 | 136 | def extra_repr(self): 137 | return f"scaling_factor={self.scaling_factor.item()}" 138 | 139 | 140 | class StochasticRounding(nn.Module): 141 | def __init__(self): 142 | super().__init__() 143 | 144 | def forward(self, x): 145 | if self.training: 146 | return stochastic_round_ste_func(x) 147 | else: 148 | return round_ste_func(x) 149 | 150 | 151 | round_ste_func = RoundStraightThrough.apply 152 | stacked_sigmoid_func = StackSigmoidFunctional.apply 153 | scale_grad_func = ScaleGradient.apply 154 | stochastic_round_ste_func = StochasticRoundSTE.apply 155 | ewgs_func = EWGSFunctional.apply 156 | 157 | 158 | class GradientEstimator(ClassEnumOptions): 159 | ste = MethodMap(round_ste_func) 160 | stoch_round = MethodMap(StochasticRounding) 161 | ewgs = MethodMap(EWGSDiscretizer) 162 | stacked_sigmoid = MethodMap(StackedSigmoid) 163 | -------------------------------------------------------------------------------- /quantization/quantizers/uniform_quantizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
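The arithmetic behind `AsymmetricUniformQuantizer.set_quant_range`/`forward` and `SymmetricUniformQuantizer.set_quant_range` in `uniform_quantizers.py` reduces to a few scalar formulas. The following is a simplified, hypothetical scalar sketch; the helper names are illustrative. It omits per-channel shapes, the `log` scale domain, the `eps` clamp applied in the `scale` property, and the STE/LSQ gradient handling.

```python
def asym_params(x_min, x_max, n_bits, eps=1e-8):
    # Widen the range so that 0.0 is exactly representable (cf. _tensorize_min_max)
    x_min, x_max = min(x_min, 0.0), max(x_max, eps)
    int_max = 2 ** n_bits - 1
    delta = (x_max - x_min) / int_max  # step size ("scale")
    zero_point = min(max(round(-x_min / delta), 0), int_max)
    return delta, zero_point, int_max


def asym_fake_quant(x, delta, zero_point, int_max):
    # Round to the integer grid, clip to [0, int_max], then map back to floats
    x_int = min(max(round(x / delta) + zero_point, 0), int_max)
    return delta * (x_int - zero_point)


def sym_params(x_min, x_max, n_bits, eps=1e-8):
    # Zero point fixed at 0; a sign bit is spent only if the range has negatives
    x_min, x_max = min(x_min, 0.0), max(x_max, eps)
    signed = x_min < 0
    int_min = -(2 ** (n_bits - 1)) if signed else 0
    int_max = 2 ** (n_bits - int(signed)) - 1
    delta = max(abs(x_min), x_max) / int_max
    return delta, int_min, int_max


delta, zp, int_max = asym_params(-1.0, 3.0, n_bits=8)  # delta = 4/255, zero point 64
```

Because the zero point is an integer, 0.0 is reproduced exactly. Any input inside the calibrated range is off by at most `delta / 2` after fake quantization; inputs outside it are clipped.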
4 | 5 | import inspect 6 | import torch 7 | 8 | from quantization.quantizers.rounding_utils import scale_grad_func, round_ste_func 9 | from .utils import QuantizerNotInitializedError 10 | from .base_quantizers import QuantizerBase 11 | 12 | 13 | class AsymmetricUniformQuantizer(QuantizerBase): 14 | """ 15 | PyTorch Module that implements Asymmetric Uniform Quantization using STE. 16 | Quantizes its argument in the forward pass, passes the gradient 'straight 17 | through' on the backward pass, ignoring the quantization that occurred. 18 | 19 | Parameters 20 | ---------- 21 | n_bits: int 22 | Number of bits for quantization. 23 | scale_domain: str ('log', 'linear') with default='linear' 24 | Domain of the scale factor 25 | per_channel: bool 26 | If True: allows for per-channel quantization 27 | """ 28 | 29 | def __init__( 30 | self, 31 | n_bits, 32 | scale_domain="linear", 33 | discretizer=round_ste_func, 34 | discretizer_args=tuple(), 35 | grad_scaling=False, 36 | eps=1e-8, 37 | **kwargs 38 | ): 39 | super().__init__(n_bits=n_bits, **kwargs) 40 | 41 | assert scale_domain in ("linear", "log") 42 | self.register_buffer("_delta", None) 43 | self.register_buffer("_zero_float", None) 44 | 45 | if inspect.isclass(discretizer): 46 | self.discretizer = discretizer(*discretizer_args) 47 | else: 48 | self.discretizer = discretizer 49 | 50 | self.scale_domain = scale_domain 51 | self.grad_scaling = grad_scaling 52 | self.eps = eps 53 | 54 | # A few useful properties 55 | @property 56 | def delta(self): 57 | if self._delta is not None: 58 | return self._delta 59 | else: 60 | raise QuantizerNotInitializedError() 61 | 62 | @property 63 | def zero_float(self): 64 | if self._zero_float is not None: 65 | return self._zero_float 66 | else: 67 | raise QuantizerNotInitializedError() 68 | 69 | @property 70 | def is_initialized(self): 71 | return self._delta is not None 72 | 73 | @property 74 | def symmetric(self): 75 | return False 76 | 77 | @property 78 | def int_min(self): 79 | # integer
grid minimum 80 | return 0.0 81 | 82 | @property 83 | def int_max(self): 84 | # integer grid maximum 85 | return 2.0**self.n_bits - 1 86 | 87 | @property 88 | def scale(self): 89 | if self.scale_domain == "linear": 90 | return torch.clamp(self.delta, min=self.eps) 91 | elif self.scale_domain == "log": 92 | return torch.exp(self.delta) 93 | 94 | @property 95 | def zero_point(self): 96 | zero_point = self.discretizer(self.zero_float) 97 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max) 98 | return zero_point 99 | 100 | @property 101 | def x_max(self): 102 | return self.scale * (self.int_max - self.zero_point) 103 | 104 | @property 105 | def x_min(self): 106 | return self.scale * (self.int_min - self.zero_point) 107 | 108 | def to_integer_forward(self, x_float, *args, **kwargs): 109 | """ 110 | Quantizes the input to its integer representation 111 | Parameters 112 | ---------- 113 | x_float: PyTorch Float Tensor 114 | Full-precision Tensor 115 | 116 | Returns 117 | ------- 118 | x_int: PyTorch Float Tensor of integers 119 | """ 120 | if self.grad_scaling: 121 | grad_scale = self.calculate_grad_scale(x_float) 122 | scale = scale_grad_func(self.scale, grad_scale) 123 | zero_point = ( 124 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale) 125 | ) 126 | else: 127 | scale = self.scale 128 | zero_point = self.zero_point 129 | 130 | x_int = self.discretizer(x_float / scale) + zero_point 131 | x_int = torch.clamp(x_int, self.int_min, self.int_max) 132 | 133 | return x_int 134 | 135 | def forward(self, x_float, *args, **kwargs): 136 | """ 137 | Quantizes the input (maps it to the integer grid, then scales it back to the original domain) 138 | Parameters 139 | ---------- 140 | x_float: PyTorch Float Tensor 141 | Full-precision Tensor 142 | 143 | Returns 144 | ------- 145 | x_quant: PyTorch Float Tensor 146 | Quantized-Dequantized Tensor 147 | """ 148 | if self.per_channel: 149 | self._adjust_params_per_channel(x_float) 150 | 151 | if self.grad_scaling:
152 | grad_scale = self.calculate_grad_scale(x_float) 153 | scale = scale_grad_func(self.scale, grad_scale) 154 | zero_point = ( 155 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale) 156 | ) 157 | else: 158 | scale = self.scale 159 | zero_point = self.zero_point 160 | 161 | x_int = self.to_integer_forward(x_float, *args, **kwargs) 162 | x_quant = scale * (x_int - zero_point) 163 | 164 | return x_quant 165 | 166 | def calculate_grad_scale(self, quant_tensor): 167 | num_pos_levels = self.int_max # Qp in LSQ paper 168 | num_elements = quant_tensor.numel() # nfeatures or nweights in LSQ paper 169 | if self.per_channel: 170 | # In the per-channel case the gradients are not summed over the output channel dimension 171 | num_elements /= quant_tensor.shape[0] 172 | 173 | return (num_pos_levels * num_elements) ** -0.5 # 1 / sqrt(Qp * nfeatures) 174 | 175 | def _adjust_params_per_channel(self, x): 176 | """ 177 | Adjusts the quantization parameter tensors (delta, zero_float) 178 | to the input tensor shape if they don't match 179 | Parameters 180 | ---------- 181 | x: input tensor 182 | """ 183 | if x.ndim != self.delta.ndim: 184 | new_shape = [-1] + [1] * (len(x.shape) - 1) 185 | self._delta = self.delta.view(new_shape) 186 | if self._zero_float is not None: 187 | self._zero_float = self._zero_float.view(new_shape) 188 | 189 | def _tensorize_min_max(self, x_min, x_max): 190 | """ 191 | Converts provided min max range into tensors 192 | Parameters 193 | ---------- 194 | x_min: float or PyTorch 1D tensor 195 | x_max: float or PyTorch 1D tensor 196 | 197 | Returns 198 | ------- 199 | x_min: PyTorch Tensor 0 or 1-D 200 | x_max: PyTorch Tensor 0 or 1-D 201 | """ 202 | # Ensure a torch tensor 203 | if not torch.is_tensor(x_min): 204 | x_min = torch.tensor(x_min).float() 205 | x_max = torch.tensor(x_max).float() 206 | 207 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel: 208 | print(x_min) 209 | print(self.per_channel) 210 | raise
ValueError( 211 | "x_min and x_max must be a float or 1-D Tensor" 212 | " for per-tensor quantization (per_channel=False)" 213 | ) 214 | # Ensure the range always includes zero and avoid division by zero 215 | x_min = torch.min(x_min, torch.zeros_like(x_min)) 216 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps) 217 | 218 | return x_min, x_max 219 | 220 | def set_quant_range(self, x_min, x_max): 221 | """ 222 | Instantiates the quantization parameters based on the provided 223 | min and max range 224 | 225 | Parameters 226 | ---------- 227 | x_min: tensor or float 228 | Quantization range minimum limit 229 | x_max: tensor or float 230 | Quantization range maximum limit 231 | """ 232 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max 233 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 234 | self._delta = (x_max - x_min) / self.int_max 235 | self._zero_float = (-x_min / self.delta).detach() 236 | 237 | if self.scale_domain == "log": 238 | self._delta = torch.log(self.delta) 239 | 240 | self._delta = self._delta.detach() 241 | 242 | def make_range_trainable(self): 243 | # Converts trainable parameters to nn.Parameters 244 | if self.delta not in self.parameters(): 245 | self._delta = torch.nn.Parameter(self._delta) 246 | self._zero_float = torch.nn.Parameter(self._zero_float) 247 | 248 | def fix_ranges(self): 249 | # Removes trainable quantization params from nn.Parameters 250 | if self.delta in self.parameters(): 251 | _delta = self._delta.data 252 | _zero_float = self._zero_float.data 253 | del self._delta # delete the parameter 254 | del self._zero_float 255 | self.register_buffer("_delta", _delta) 256 | self.register_buffer("_zero_float", _zero_float) 257 | 258 | 259 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer): 260 | """ 261 | PyTorch Module that implements Symmetric Uniform Quantization using STE.
262 | Quantizes its argument in the forward pass, passes the gradient 'straight 263 | through' on the backward pass, ignoring the quantization that occurred. 264 | 265 | Parameters 266 | ---------- 267 | n_bits: int 268 | Number of bits for quantization. 269 | scale_domain: str ('log', 'linear') with default='linear' 270 | Domain of the scale factor 271 | per_channel: bool 272 | If True: allows for per-channel quantization 273 | """ 274 | 275 | def __init__(self, *args, **kwargs): 276 | super().__init__(*args, **kwargs) 277 | self.register_buffer("_signed", None) 278 | 279 | @property 280 | def signed(self): 281 | if self._signed is not None: 282 | return self._signed.item() 283 | else: 284 | raise QuantizerNotInitializedError() 285 | 286 | @property 287 | def symmetric(self): 288 | return True 289 | 290 | @property 291 | def int_min(self): 292 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0 293 | 294 | @property 295 | def int_max(self): 296 | pos_n_bits = self.n_bits - self.signed 297 | return 2.0**pos_n_bits - 1 298 | 299 | @property 300 | def zero_point(self): 301 | return 0.0 302 | 303 | def set_quant_range(self, x_min, x_max): 304 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max 305 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 306 | self._signed = x_min.min() < 0 307 | 308 | x_absmax = torch.max(x_min.abs(), x_max) 309 | self._delta = x_absmax / self.int_max 310 | 311 | if self.scale_domain == "log": 312 | self._delta = torch.log(self._delta) 313 | 314 | self._delta = self._delta.detach() 315 | 316 | def make_range_trainable(self): 317 | # Converts trainable parameters to nn.Parameters 318 | if self.delta not in self.parameters(): 319 | self._delta = torch.nn.Parameter(self._delta) 320 | 321 | def fix_ranges(self): 322 | # Removes trainable quantization params from nn.Parameters 323 | if self.delta in self.parameters(): 324 | _delta = self._delta.data 325 | del self._delta # delete the parameter 326 | self.register_buffer("_delta", _delta) 327 |
| 328 | def generate_grid(self): 329 | x_int_rng = torch.arange(self.int_min, self.int_max + 1) 330 | grid = self.scale * (x_int_rng - self.zero_point) 331 | return grid 332 | -------------------------------------------------------------------------------- /quantization/quantizers/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | 6 | class QuantizerNotInitializedError(Exception): 7 | """Raised when a quantizer has not been initialized""" 8 | 9 | def __init__(self): 10 | super(QuantizerNotInitializedError, self).__init__( 11 | "Quantizer has not been initialized yet" 12 | ) 13 | -------------------------------------------------------------------------------- /quantization/range_estimators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | import copy 5 | from enum import auto 6 | 7 | import numpy as np 8 | import torch 9 | from scipy.optimize import minimize_scalar 10 | from torch import nn 11 | 12 | from utils import to_numpy, BaseEnumOptions, MethodMap, ClassEnumOptions 13 | 14 | 15 | class RangeEstimatorBase(nn.Module): 16 | def __init__(self, per_channel=False, quantizer=None, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | self.register_buffer("current_xmin", None) 19 | self.register_buffer("current_xmax", None) 20 | self.per_channel = per_channel 21 | self.quantizer = quantizer 22 | 23 | def forward(self, x): 24 | """ 25 | Accepts an input tensor, updates the current estimates of x_min and x_max 26 | and returns them. 
27 | Parameters 28 | ---------- 29 | x: Input tensor 30 | 31 | Returns 32 | ------- 33 | self.current_xmin: tensor 34 | 35 | self.current_xmax: tensor 36 | 37 | """ 38 | raise NotImplementedError() 39 | 40 | def reset(self): 41 | """ 42 | Reset the range estimator. 43 | """ 44 | self.current_xmin = None 45 | self.current_xmax = None 46 | 47 | def __repr__(self): 48 | # We override this from nn.Module as we do not want to have submodules such as 49 | # self.quantizer in the repr. Otherwise it behaves as expected for an nn.Module. 50 | lines = self.extra_repr().split("\n") 51 | extra_str = lines[0] if len(lines) == 1 else "\n " + "\n ".join(lines) + "\n" 52 | 53 | return self._get_name() + "(" + extra_str + ")" 54 | 55 | 56 | class CurrentMinMaxEstimator(RangeEstimatorBase): 57 | def __init__(self, percentile=None, *args, **kwargs): 58 | self.percentile = percentile 59 | super().__init__(*args, **kwargs) 60 | 61 | def forward(self, x): 62 | if self.per_channel: 63 | x = x.view(x.shape[0], -1) 64 | if self.percentile: 65 | axis = -1 if self.per_channel else None 66 | data_np = to_numpy(x) 67 | x_min, x_max = np.percentile( 68 | data_np, (self.percentile, 100 - self.percentile), axis=axis 69 | ) 70 | self.current_xmin = torch.tensor(x_min).to(x.device) 71 | self.current_xmax = torch.tensor(x_max).to(x.device) 72 | else: 73 | self.current_xmin = x.min(-1)[0].detach() if self.per_channel else x.min().detach() 74 | self.current_xmax = x.max(-1)[0].detach() if self.per_channel else x.max().detach() 75 | 76 | return self.current_xmin, self.current_xmax 77 | 78 | 79 | class AllMinMaxEstimator(RangeEstimatorBase): 80 | def __init__(self, *args, **kwargs): 81 | super().__init__(*args, **kwargs) 82 | 83 | def forward(self, x): 84 | if self.per_channel: 85 | # Along 1st dim 86 | x_flattened = x.view(x.shape[0], -1) 87 | x_min = x_flattened.min(-1)[0].detach() 88 | x_max = x_flattened.max(-1)[0].detach() 89 | else: 90 | x_min = torch.min(x).detach() 91 | x_max =
torch.max(x).detach() 92 | 93 | if self.current_xmin is None: 94 | self.current_xmin = x_min 95 | self.current_xmax = x_max 96 | else: 97 | self.current_xmin = torch.min(self.current_xmin, x_min) 98 | self.current_xmax = torch.max(self.current_xmax, x_max) 99 | 100 | return self.current_xmin, self.current_xmax 101 | 102 | 103 | class RunningMinMaxEstimator(RangeEstimatorBase): 104 | def __init__(self, momentum=0.9, *args, **kwargs): 105 | self.momentum = momentum 106 | super().__init__(*args, **kwargs) 107 | 108 | def forward(self, x): 109 | if self.per_channel: 110 | # Along 1st dim 111 | x_flattened = x.view(x.shape[0], -1) 112 | x_min = x_flattened.min(-1)[0].detach() 113 | x_max = x_flattened.max(-1)[0].detach() 114 | else: 115 | x_min = torch.min(x).detach() 116 | x_max = torch.max(x).detach() 117 | 118 | if self.current_xmin is None: 119 | self.current_xmin = x_min 120 | self.current_xmax = x_max 121 | else: 122 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin 123 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax 124 | 125 | return self.current_xmin, self.current_xmax 126 | 127 | 128 | class OptMethod(BaseEnumOptions): 129 | grid = auto() 130 | golden_section = auto() 131 | 132 | 133 | class LineSearchEstimator(RangeEstimatorBase): 134 | def __init__( 135 | self, 136 | num_candidates=1000, 137 | opt_method=OptMethod.grid, 138 | range_margin=0.5, 139 | expand_range=10.0, 140 | *args, 141 | **kwargs, 142 | ): 143 | 144 | super().__init__(*args, **kwargs) 145 | assert opt_method in OptMethod 146 | self.opt_method = opt_method 147 | self.num_candidates = num_candidates 148 | self.expand_range = expand_range 149 | self.loss_array = None 150 | self.max_pos_thr = None 151 | self.max_neg_thr = None 152 | self.max_search_range = None 153 | self.one_sided_dist = None 154 | self.range_margin = range_margin 155 | if self.quantizer is None: 156 | raise NotImplementedError( 157 | "A Quantizer must be 
given as an argument to the MSE Range" "Estimator" 158 | ) 159 | self.max_int_skew = (2**self.quantizer.n_bits) // 4 # For asymmetric quantization 160 | 161 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False): 162 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr) 163 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1) 164 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel 165 | # grid search 166 | if per_channel_loss: 167 | return to_numpy(temp_sum) 168 | else: 169 | return to_numpy(torch.sum(temp_sum)) 170 | 171 | @property 172 | def step_size(self): 173 | if self.one_sided_dist is None: 174 | raise NoDataPassedError() 175 | 176 | return self.max_search_range / self.num_candidates 177 | 178 | @property 179 | def optimization_method(self): 180 | if self.one_sided_dist is None: 181 | raise NoDataPassedError() 182 | 183 | if self.opt_method == OptMethod.grid: 184 | # Grid search method 185 | if self.one_sided_dist or self.quantizer.symmetric: 186 | # 1-D grid search 187 | return self._perform_1D_search 188 | else: 189 | # 2-D grid_search 190 | return self._perform_2D_search 191 | elif self.opt_method == OptMethod.golden_section: 192 | # Golden section method 193 | if self.one_sided_dist or self.quantizer.symmetric: 194 | return self._golden_section_symmetric 195 | else: 196 | return self._golden_section_asymmetric 197 | else: 198 | raise NotImplementedError("Optimization Method not Implemented") 199 | 200 | def quantize(self, x_float, x_min=None, x_max=None): 201 | temp_q = copy.deepcopy(self.quantizer) 202 | # In the current implementation no optimization procedure requires temp quantizer for 203 | # loss_fx to be per-channel 204 | temp_q.per_channel = False 205 | if x_min or x_max: 206 | temp_q.set_quant_range(x_min, x_max) 207 | return temp_q(x_float) 208 | 209 | def _define_search_range(self, data): 210 | self.channel_groups = len(data) if self.per_channel else 1 211 | 
self.current_xmax = torch.zeros(self.channel_groups, device=data.device) 212 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device) 213 | 214 | if self.one_sided_dist or self.quantizer.symmetric: 215 | # 1D search space 216 | self.loss_array = np.zeros( 217 | (self.channel_groups, self.num_candidates + 1) 218 | ) # 1D search space 219 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish 220 | # Defining the search range for clipping thresholds 221 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin 222 | self.max_neg_thr = -self.max_pos_thr * self.expand_range 223 | self.max_search_range = self.max_pos_thr * self.expand_range 224 | else: 225 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth 226 | # index represents whether the skew is positive (0) or negative (1)) 227 | self.loss_array = np.zeros( 228 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2] 229 | ) # 2D search space 230 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish 231 | # Define the search range for clipping thresholds in asymmetric case 232 | self.max_pos_thr = float(data.max()) + self.range_margin 233 | self.max_neg_thr = float(data.min()) - self.range_margin 234 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr)) 235 | 236 | def _perform_1D_search(self, data): 237 | """ 238 | Grid search through all candidate quantizers in 1D to find the best 239 | The loss is accumulated over all batches without any momentum 240 | :param data: input tensor 241 | """ 242 | for cand_index in range(1, self.num_candidates + 1): 243 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index 244 | pos_thr = self.step_size * cand_index 245 | 246 | self.loss_array[:, cand_index] += self.loss_fx( 247 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 248 | ) 249 | 250 | min_cand = self.loss_array.argmin(axis=1) 251 | xmin = ( 
252 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand 253 | ).astype(np.single) 254 | xmax = (self.step_size * min_cand).astype(np.single) 255 | self.current_xmax = torch.tensor(xmax).to(device=data.device) 256 | self.current_xmin = torch.tensor(xmin).to(device=data.device) 257 | 258 | def forward(self, data): 259 | if self.loss_array is None: 260 | # Initialize search range on first batch, and accumulate losses with subsequent calls 261 | 262 | # Decide whether input distribution is one-sided 263 | if self.one_sided_dist is None: 264 | self.one_sided_dist = bool((data.min() >= 0).item()) 265 | 266 | # Define search range 267 | self._define_search_range(data) 268 | 269 | # Perform Search/Optimization for Quantization Ranges 270 | self.optimization_method(data) 271 | 272 | return self.current_xmin, self.current_xmax 273 | 274 | def reset(self): 275 | super().reset() 276 | self.loss_array = None 277 | 278 | def extra_repr(self): 279 | repr = "opt_method={}".format(self.opt_method.name) 280 | if self.opt_method == OptMethod.grid: 281 | repr += ", num_candidates={}".format(self.num_candidates) 282 | return repr 283 | 284 | 285 | class FP_MSE_Estimator(RangeEstimatorBase): 286 | def __init__( 287 | self, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, *args, **kwargs 288 | ): 289 | super(FP_MSE_Estimator, self).__init__(*args, **kwargs) 290 | assert opt_method == OptMethod.grid 291 | 292 | self.num_candidates = num_candidates 293 | self.mses = self.search_grid = None 294 | 295 | def _define_search_range(self, x, mbit_list): 296 | if self.per_channel: 297 | x = x.view(x.shape[0], -1) 298 | else: 299 | x = x.view(1, -1) 300 | mxs = [torch.max(torch.abs(xc.min()), torch.abs(xc.max())) for xc in x] 301 | 302 | if self.search_grid is None: 303 | assert self.mses is None 304 | 305 | lsp = [torch.linspace(0.1 * mx.item(), 1.2 * mx.item(), 111) for mx in mxs] 306 | 307 | # 111 x n_channels 308 | search_grid =
torch.stack(lsp).to(x.device).transpose(0, 1) 309 | 310 | # mbits x 111 x n_channels (or 1 in case not --per-channel) 311 | mses = torch.stack([torch.zeros_like(search_grid) for _ in range(len(mbit_list))]) 312 | 313 | self.mses = mses 314 | self.search_grid = search_grid 315 | 316 | return self.search_grid, self.mses 317 | 318 | def forward(self, x): 319 | mbit_list = [float(self.quantizer.mantissa_bits)] 320 | 321 | if self.quantizer.mse_include_mantissa_bits: 322 | # highest possible value is self.n_bits - self.sign_bits - 1 323 | mbit_list = [ 324 | float(x) for x in range(1, self.quantizer.n_bits - self.quantizer.sign_bits) 325 | ] 326 | 327 | search_grid, mses = self._define_search_range(x, mbit_list) 328 | 329 | assert mses.shape[1:] == search_grid.shape, f"{mses.shape}, {search_grid.shape}" 330 | 331 | # Need to do this here too to get correct search range 332 | sign_bits = int(torch.any(x < 0)) if self.quantizer.allow_unsigned else 1 333 | 334 | meandims = list(torch.arange(len(x.shape))) 335 | if self.per_channel: 336 | meandims = meandims[1:] 337 | for m, mbits in enumerate(mbit_list): 338 | mbits = torch.Tensor([mbits]).to(x.device) 339 | self.quantizer.mantissa_bits = mbits 340 | for i, maxval in enumerate(search_grid): 341 | x_min, x_max = sign_bits * -1.0 * maxval, maxval 342 | self.quantizer.set_quant_range(x_min, x_max) 343 | xfp = self.quantizer(x) 344 | 345 | # get MSE per channel (mean over all non-channel dims) 346 | mse = ((x - xfp) ** 2).mean(meandims) 347 | mses[m, i, :] += mse 348 | 349 | # Find best mbits per channel 350 | best_mbits_per_channel = mses.min(1)[0].argmin(0) 351 | 352 | # Get plurality vote on mbits 353 | best_mbit_idx = torch.mode(best_mbits_per_channel).values.item() 354 | best_mbits = float(mbit_list[best_mbit_idx]) 355 | 356 | # then, find best per-channel scale for best mbit 357 | # first, get the MSES for the best mbit, then argmin over linspace dim to get best index per channel 358 | mses = 
mses[best_mbit_idx].argmin(0) 359 | # then, for each channel, get the argmin MSE max value 360 | maxval = torch.tensor([search_grid[mses[i], i] for i in range(search_grid.shape[-1])]).to( 361 | x.device 362 | ) 363 | 364 | self.quantizer.mantissa_bits = torch.tensor(best_mbits).to( 365 | self.quantizer.mantissa_bits.device 366 | ) 367 | 368 | maxval = maxval.to(self.quantizer.maxval.device) 369 | return sign_bits * -1.0 * maxval, maxval 370 | 371 | 372 | def estimate_range_line_search(W, quant, num_candidates=None): 373 | if num_candidates is None: 374 | est_fp = LineSearchEstimator(quantizer=quant) 375 | else: 376 | est_fp = LineSearchEstimator(quantizer=quant, num_candidates=num_candidates) 377 | 378 | mse_range_min_fp, mse_range_max_fp = est_fp.forward(W) 379 | return (mse_range_min_fp, mse_range_max_fp) 380 | 381 | 382 | class NoDataPassedError(Exception): 383 | """Raised when no data has been passed into the Range Estimator""" 384 | 385 | def __init__(self): 386 | super().__init__("Data must be passed through the range estimator to be initialized") 387 | 388 | 389 | class RangeEstimators(ClassEnumOptions): 390 | current_minmax = MethodMap(CurrentMinMaxEstimator) 391 | allminmax = MethodMap(AllMinMaxEstimator) 392 | running_minmax = MethodMap(RunningMinMaxEstimator) 393 | MSE = MethodMap(FP_MSE_Estimator) 394 | -------------------------------------------------------------------------------- /quantization/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
4 | 5 | 6 | import torch 7 | import torch.serialization 8 | 9 | from quantization.quantizers import QuantizerBase 10 | from quantization.quantizers.rounding_utils import ParametrizedGradEstimatorBase 11 | from quantization.range_estimators import RangeEstimators 12 | from utils import StopForwardException, get_layer_by_name 13 | 14 | 15 | def separate_quantized_model_params(quant_model): 16 | """ 17 | This method separates the parameters of the quantized model into 3 categories. 18 | Parameters 19 | ---------- 20 | quant_model: (QuantizedModel) 21 | 22 | Returns 23 | ------- 24 | quant_params: (list) 25 | Quantization parameters, e.g. delta and zero_float 26 | model_params: (list) 27 | The model parameters of the base model without any quantization operations 28 | grad_params: (list) 29 | Parameters found in the gradient estimators (ParametrizedGradEstimatorBase) 30 | ------- 31 | 32 | """ 33 | quant_params, grad_params = [], [] 34 | quant_params_names, grad_params_names = [], [] 35 | for mod_name, module in quant_model.named_modules(): 36 | if isinstance(module, QuantizerBase): 37 | for name, param in module.named_parameters(recurse=False): 38 | quant_params.append(param) 39 | quant_params_names.append(".".join((mod_name, name))) 40 | if isinstance(module, ParametrizedGradEstimatorBase): 41 | # gradient estimator params 42 | for name, param in module.named_parameters(recurse=False): 43 | grad_params.append(param) 44 | grad_params_names.append(".".join((mod_name, name))) 45 | 46 | def tensor_in_list(tensor, lst): 47 | return any([e is tensor for e in lst]) 48 | 49 | found_params = quant_params + grad_params 50 | 51 | model_params = [p for p in quant_model.parameters() if not tensor_in_list(p, found_params)] 52 | model_param_names = [ 53 | n for n, p in quant_model.named_parameters() if not tensor_in_list(p, found_params) 54 | ] 55 | 56 | print("Quantization parameters ({}):".format(len(quant_params_names))) 57 | print(quant_params_names) 58 | 59 | print("Gradient
estimator parameters ({}):".format(len(grad_params_names))) 60 | print(grad_params_names) 61 | 62 | print("Other model parameters ({}):".format(len(model_param_names))) 63 | print(model_param_names) 64 | 65 | assert len(model_params + quant_params + grad_params) == len( 66 | list(quant_model.parameters()) 67 | ), "{}; {}; {} -- {}".format( 68 | len(model_params), len(quant_params), len(grad_params), len(list(quant_model.parameters())) 69 | ) 70 | 71 | return quant_params, model_params, grad_params 72 | 73 | 74 | def pass_data_for_range_estimation( 75 | loader, model, act_quant, weight_quant, max_num_batches=20, cross_entropy_layer=None, inp_idx=0 76 | ): 77 | print("\nEstimate quantization ranges on training data") 78 | model.set_quant_state(weight_quant, act_quant) 79 | # Put model in eval such that BN EMA does not get updated 80 | model.eval() 81 | 82 | if cross_entropy_layer is not None: 83 | layer_xent = get_layer_by_name(model, cross_entropy_layer) 84 | if layer_xent: 85 | print('Set cross entropy estimator for layer "{}"'.format(cross_entropy_layer)) 86 | act_quant_mgr = layer_xent.activation_quantizer 87 | act_quant_mgr.range_estimator = RangeEstimators.cross_entropy.cls( 88 | per_channel=act_quant_mgr.per_channel, 89 | quantizer=act_quant_mgr.quantizer, 90 | **act_quant_mgr.range_estim_params, 91 | ) 92 | else: 93 | raise ValueError("Cross-entropy layer not found") 94 | 95 | batches = [] 96 | device = next(model.parameters()).device 97 | 98 | with torch.no_grad(): 99 | for i, data in enumerate(loader): 100 | try: 101 | if isinstance(data, (tuple, list)): 102 | x = data[inp_idx].to(device=device) 103 | batches.append(x.data.cpu().numpy()) 104 | model(x) 105 | print(f"processed step={i}") 106 | else: 107 | x = {k: v.to(device=device) for k, v in data.items()} 108 | model(**x) 109 | print(f"processed step={i}") 110 | 111 | if i >= max_num_batches - 1 or not act_quant: 112 | break 113 | except StopForwardException: 114 | pass 115 | return batches 116 |
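The range estimators that `pass_data_for_range_estimation` drives accumulate statistics batch by batch. What follows is a minimal, dependency-free sketch using hypothetical helper names (`running_minmax`, `fake_quant_symmetric`, `mse_grid_search`); none of these are part of this repository. `running_minmax` mirrors the EMA update in `RunningMinMaxEstimator`. `mse_grid_search` is a simplified version of the symmetric 1-D search in `LineSearchEstimator._perform_1D_search`. The sketch assumes a signed grid with `scale = threshold / (2**(n_bits - 1) - 1)`, and it leaves out the `range_margin`/`expand_range` padding and the per-channel bookkeeping of the real estimator.

```python
from typing import Optional, Sequence, Tuple


def running_minmax(
    batches: Sequence[Sequence[float]], momentum: float = 0.9
) -> Tuple[float, float]:
    """EMA of per-batch min/max, following the update rule of RunningMinMaxEstimator."""
    cur_min: Optional[float] = None
    cur_max: Optional[float] = None
    for batch in batches:
        b_min, b_max = min(batch), max(batch)
        if cur_min is None:
            # The first batch initialises the estimate directly.
            cur_min, cur_max = b_min, b_max
        else:
            cur_min = (1 - momentum) * b_min + momentum * cur_min
            cur_max = (1 - momentum) * b_max + momentum * cur_max
    return cur_min, cur_max


def fake_quant_symmetric(x: float, threshold: float, n_bits: int) -> float:
    """Quantize-dequantize x on a signed symmetric grid (assumed scale convention)."""
    int_max = 2 ** (n_bits - 1) - 1
    int_min = -(2 ** (n_bits - 1))
    scale = threshold / int_max
    q = min(max(round(x / scale), int_min), int_max)  # round, then clip to the int range
    return q * scale


def mse_grid_search(data: Sequence[float], n_bits: int = 8, num_candidates: int = 100) -> float:
    """Try thresholds k * max|x| / num_candidates and keep the one with the lowest summed squared error."""
    max_abs = max(abs(v) for v in data)
    if max_abs == 0:
        return 0.0  # all-zero input: any threshold is exact
    best_thr, best_loss = max_abs, float("inf")
    for k in range(1, num_candidates + 1):
        thr = max_abs * k / num_candidates
        loss = sum((v - fake_quant_symmetric(v, thr, n_bits)) ** 2 for v in data)
        if loss < best_loss:
            best_thr, best_loss = thr, loss
    return best_thr


if __name__ == "__main__":
    print(running_minmax([[-1.0, 2.0], [-3.0, 4.0]]))
    data = [0.01 * i for i in range(-100, 101)] + [10.0]
    print(mse_grid_search(data, n_bits=4))
```

For the two batches in the demo, the EMA gives roughly (-1.2, 2.2), whereas an all-time min/max estimate would give (-3.0, 4.0). With a single outlier at 10.0 and 4 bits, the grid search picks a threshold below 10.0. It accepts a clipping error on that one value in return for finer resolution on the other 201 values.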
118 | def set_range_estimators(config, model): 119 | print("Make quantizers learnable") 120 | model.learn_ranges() 121 | 122 | if config.qat.grad_scaling: 123 | print("Activate gradient scaling") 124 | model.grad_scaling(True) 125 | 126 | # Ensure we have the desired quant state 127 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant) 128 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click==8.1.3 2 | pytorch-ignite~=0.4.9 3 | tensorboard==2.10.1 4 | scipy==1.9.3 5 | numpy==1.22.4 6 | timm~=0.4.12 7 | tqdm==4.49.0 8 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | from .stopwatch import Stopwatch 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /utils/click_options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import click 6 | 7 | from models import QuantArchitectures 8 | from functools import wraps, partial 9 | from quantization.quantization_manager import QMethods 10 | from quantization.range_estimators import RangeEstimators, OptMethod 11 | from utils import split_dict, DotDict, ClickEnumOption, seed_all 12 | from utils.imagenet_dataloaders import ImageInterpolation 13 | 14 | click.option = partial(click.option, show_default=True) 15 | 16 | _HELP_MSG = ( 17 | "Enforce determinism also on the GPU by disabling CUDNN and setting " 18 | "`torch.set_deterministic(True)`. In many cases this comes at the cost of efficiency " 19 | "and performance." 
20 | ) 21 | 22 | 23 | def base_options(func): 24 | @click.option( 25 | "--images-dir", type=click.Path(exists=True), help="Root directory of images", required=True 26 | ) 27 | @click.option("--max-epochs", default=90, type=int, help="Maximum number of training epochs.") 28 | @click.option( 29 | "--interpolation", 30 | type=ClickEnumOption(ImageInterpolation), 31 | default=ImageInterpolation.bilinear.name, 32 | help="Desired interpolation to use for resizing.", 33 | ) 34 | @click.option( 35 | "--save-checkpoint-dir", 36 | type=click.Path(exists=False), 37 | default=None, 38 | help="Directory where to save checkpoints (model, optimizer, lr_scheduler).", 39 | ) 40 | @click.option( 41 | "--tb-logging-dir", default=None, type=str, help="The logging directory " "for tensorboard" 42 | ) 43 | @click.option("--cuda/--no-cuda", is_flag=True, default=True, help="Use GPU") 44 | @click.option("--batch-size", default=128, type=int, help="Mini-batch size") 45 | @click.option("--num-workers", default=16, type=int, help="Number of workers for data loading") 46 | @click.option("--seed", default=None, type=int, help="Random number generator seed to set") 47 | @click.option("--deterministic/--nondeterministic", default=False, help=_HELP_MSG) 48 | # Architecture related options 49 | @click.option( 50 | "--architecture", 51 | type=ClickEnumOption(QuantArchitectures), 52 | required=True, 53 | help="Quantized architecture", 54 | ) 55 | @click.option( 56 | "--model-dir", 57 | type=click.Path(exists=True), 58 | default=None, 59 | help="Path for model directory. 
If the model does not exist it will be downloaded " 60 | "from a URL", 61 | ) 62 | @click.option( 63 | "--pretrained/--no-pretrained", 64 | is_flag=True, 65 | default=True, 66 | help="Use pretrained model weights", 67 | ) 68 | @click.option( 69 | "--progress-bar/--no-progress-bar", is_flag=True, default=False, help="Show progress bar" 70 | ) 71 | @wraps(func) 72 | def func_wrapper(config, *args, **kwargs): 73 | config.base, remaining_kwargs = split_dict( 74 | kwargs, 75 | [ 76 | "images_dir", 77 | "max_epochs", 78 | "interpolation", 79 | "save_checkpoint_dir", 80 | "tb_logging_dir", 81 | "cuda", 82 | "batch_size", 83 | "num_workers", 84 | "seed", 85 | "model_dir", 86 | "architecture", 87 | "pretrained", 88 | "deterministic", 89 | "progress_bar", 90 | ], 91 | ) 92 | 93 | seed, deterministic = config.base.seed, config.base.deterministic 94 | 95 | if seed is None: 96 | if deterministic is True: 97 | raise ValueError("Enforcing determinism without providing a seed is not supported") 98 | else: 99 | seed_all(seed=seed, deterministic=deterministic) 100 | 101 | return func(config, *args, **remaining_kwargs) 102 | 103 | return func_wrapper 104 | 105 | 106 | class multi_optimizer_options: 107 | """ 108 | An instance of this class is a callable object to serve as a decorator; 109 | hence the lower case class name. 110 | 111 | Among the CLI options defined in the decorator, `--{prefix-}optimizer` 112 | requires special attention: its default value is "SGD", independent of 113 | the optimizer chosen for the main optimizer. 114 | 115 | Examples: 116 | @multi_optimizer_options('quant') 117 | @pass_config 118 | def command(config): 119 | ...
120 | """ 121 | 122 | def __init__(self, prefix: str = ""): 123 | self.optimizer_name = prefix + "_optimizer" if prefix else "optimizer" 124 | self.prefix_option = prefix + "-" if prefix else "" 125 | self.prefix_attribute = prefix + "_" if prefix else "" 126 | 127 | def __call__(self, func): 128 | prefix_option = self.prefix_option 129 | prefix_attribute = self.prefix_attribute 130 | 131 | @click.option( 132 | f"--{prefix_option}optimizer", 133 | default="SGD", 134 | type=click.Choice(["SGD", "Adam"], case_sensitive=False), 135 | help=f"Class name of torch Optimizer to be used.", 136 | ) 137 | @click.option( 138 | f"--{prefix_option}learning-rate", 139 | default=None, 140 | type=float, 141 | help="Initial learning rate.", 142 | ) 143 | @click.option( 144 | f"--{prefix_option}momentum", default=0.9, type=float, help=f"Optimizer momentum." 145 | ) 146 | @click.option( 147 | f"--{prefix_option}weight-decay", 148 | default=None, 149 | type=float, 150 | help="Weight decay for the network.", 151 | ) 152 | @click.option( 153 | f"--{prefix_option}learning-rate-schedule", 154 | default=None, 155 | type=str, 156 | help="Learning rate scheduler, 'MultiStepLR:10:20:40' or " 157 | "'cosine:1e-4' for cosine decay", 158 | ) 159 | @wraps(func) 160 | def func_wrapper(config, *args, **kwargs): 161 | base_arg_names = [ 162 | "optimizer", 163 | "learning_rate", 164 | "momentum", 165 | "weight_decay", 166 | "learning_rate_schedule", 167 | ] 168 | 169 | optimizer_opt = DotDict() 170 | 171 | # Collect basic arguments 172 | for arg in base_arg_names: 173 | option_name = prefix_attribute + arg 174 | optimizer_opt[arg] = kwargs.pop(option_name) 175 | 176 | # config.{prefix_attribute}optimizer = optimizer_opt 177 | setattr(config, prefix_attribute + "optimizer", optimizer_opt) 178 | 179 | return func(config, *args, **kwargs) 180 | 181 | return func_wrapper 182 | 183 | 184 | def qat_options(func): 185 | @click.option( 186 | "--reestimate-bn-stats/--no-reestimate-bn-stats", 187 | 
is_flag=True, 188 | default=True, 189 | help="Reestimates the BN stats before every evaluation.", 190 | ) 191 | @click.option( 192 | "--grad-scaling/--no-grad-scaling", 193 | is_flag=True, 194 | default=False, 195 | help="Do gradient scaling as in LSQ paper.", 196 | ) 197 | @click.option( 198 | "--sep-quant-optimizer/--no-sep-quant-optimizer", 199 | is_flag=True, 200 | default=False, 201 | help="Use a separate optimizer for the quantizers.", 202 | ) 203 | @multi_optimizer_options("quant") 204 | @oscillations_dampen_options 205 | @oscillations_freeze_options 206 | @wraps(func) 207 | def func_wrapper(config, *args, **kwargs): 208 | config.qat, remainder_kwargs = split_dict( 209 | kwargs, ["reestimate_bn_stats", "grad_scaling", "sep_quant_optimizer"] 210 | ) 211 | return func(config, *args, **remainder_kwargs) 212 | 213 | return func_wrapper 214 | 215 | 216 | def oscillations_dampen_options(func): 217 | @click.option( 218 | "--oscillations-dampen-weight", 219 | default=None, 220 | type=float, 221 | help="If given, adds an oscillation dampening term to the loss with the given " "weighting.", 222 | ) 223 | @click.option( 224 | "--oscillations-dampen-aggregation", 225 | type=click.Choice(["sum", "mean", "kernel_mean"]), 226 | default="kernel_mean", 227 | help="Aggregation type for bin regularization loss.", 228 | ) 229 | @click.option( 230 | "--oscillations-dampen-weight-final", 231 | type=float, 232 | default=None, 233 | help="Dampening regularization final value for annealing schedule.", 234 | ) 235 | @click.option( 236 | "--oscillations-dampen-anneal-start", 237 | default=0.25, 238 | type=float, 239 | help="Start of annealing (relative to total number of iterations).", 240 | ) 241 | @wraps(func) 242 | def func_wrapper(config, *args, **kwargs): 243 | config.osc_damp, remainder_kwargs = split_dict( 244 | kwargs, 245 | [ 246 | "oscillations_dampen_weight", 247 | "oscillations_dampen_aggregation", 248 | "oscillations_dampen_weight_final", 249 | "oscillations_dampen_anneal_start",
250 | ], 251 | "oscillations_dampen", 252 | ) 253 | 254 | return func(config, *args, **remainder_kwargs) 255 | 256 | return func_wrapper 257 | 258 | 259 | def oscillations_freeze_options(func): 260 | @click.option( 261 | "--oscillations-freeze-threshold", 262 | default=0.0, 263 | type=float, 264 | help="If greater than 0, we will freeze oscillations whose frequency (EMA) is " 265 | "higher than the given threshold. Frequency is defined as 1/period length.", 266 | ) 267 | @click.option( 268 | "--oscillations-freeze-ema-momentum", 269 | default=0.001, 270 | type=float, 271 | help="The momentum to calculate the EMA frequency of the oscillation. If " 272 | "freezing is used, this should be at least 2-3 times lower than the " 273 | "freeze threshold.", 274 | ) 275 | @click.option( 276 | "--oscillations-freeze-use-ema/--no-oscillation-freeze-use-ema", 277 | is_flag=True, 278 | default=True, 279 | help="Uses an EMA of past x_int to find the correct freezing int value.", 280 | ) 281 | @click.option( 282 | "--oscillations-freeze-max-bits", 283 | default=4, 284 | type=int, 285 | help="Max bit-width for oscillation tracking and freezing.
If a layer's weights use a" 286 | " higher bit-width, oscillations are not tracked or frozen.", 287 | ) 288 | @click.option( 289 | "--oscillations-freeze-threshold-final", 290 | type=float, 291 | default=None, 292 | help="Oscillation freezing final value for annealing schedule.", 293 | ) 294 | @click.option( 295 | "--oscillations-freeze-anneal-start", 296 | default=0.25, 297 | type=float, 298 | help="Start of annealing (relative to total number of iterations).", 299 | ) 300 | @wraps(func) 301 | def func_wrapper(config, *args, **kwargs): 302 | config.osc_freeze, remainder_kwargs = split_dict( 303 | kwargs, 304 | [ 305 | "oscillations_freeze_threshold", 306 | "oscillations_freeze_ema_momentum", 307 | "oscillations_freeze_use_ema", 308 | "oscillations_freeze_max_bits", 309 | "oscillations_freeze_threshold_final", 310 | "oscillations_freeze_anneal_start", 311 | ], 312 | "oscillations_freeze", 313 | ) 314 | 315 | return func(config, *args, **remainder_kwargs) 316 | 317 | return func_wrapper 318 | 319 | 320 | def quantization_options(func): 321 | # Weight quantization options 322 | @click.option( 323 | "--weight-quant/--no-weight-quant", 324 | is_flag=True, 325 | default=True, 326 | help="Run evaluation with weight quantization or use FP32 weights", 327 | ) 328 | @click.option( 329 | "--qmethod", 330 | type=ClickEnumOption(QMethods), 331 | default=QMethods.symmetric_uniform.name, 332 | help="Quantization scheme to use.", 333 | ) 334 | @click.option( 335 | "--weight-quant-method", 336 | default=RangeEstimators.current_minmax.name, 337 | type=ClickEnumOption(RangeEstimators), 338 | help="Method to determine weight quantization clipping thresholds.", 339 | ) 340 | @click.option( 341 | "--weight-opt-method", 342 | default=OptMethod.grid.name, 343 | type=ClickEnumOption(OptMethod), 344 | help="Optimization procedure for weight quantization clipping thresholds", 345 | ) 346 | @click.option( 347 | "--num-candidates", 348 | type=int, 349 | default=None, 350 | help="Number of grid points for
grid search in MSE range method.", 351 | ) 352 | @click.option("--n-bits", default=8, type=int, help="Default number of quantization bits.") 353 | @click.option( 354 | "--per-channel/--no-per-channel", 355 | is_flag=True, 356 | default=False, 357 | help="If given, quantize each channel separately.", 358 | ) 359 | # Activation quantization options 360 | @click.option( 361 | "--act-quant/--no-act-quant", 362 | is_flag=True, 363 | default=True, 364 | help="Run evaluation with activation quantization or use FP32 activations", 365 | ) 366 | @click.option( 367 | "--qmethod-act", 368 | type=ClickEnumOption(QMethods), 369 | default=None, 370 | help="Quantization scheme for activation to use. If not specified `--qmethod` " "is used.", 371 | ) 372 | @click.option( 373 | "--n-bits-act", default=None, type=int, help="Number of quantization bits for activations." 374 | ) 375 | @click.option( 376 | "--act-quant-method", 377 | default=RangeEstimators.running_minmax.name, 378 | type=ClickEnumOption(RangeEstimators), 379 | help="Method to determine activation quantization clipping thresholds", 380 | ) 381 | @click.option( 382 | "--act-opt-method", 383 | default=OptMethod.grid.name, 384 | type=ClickEnumOption(OptMethod), 385 | help="Optimization procedure for activation quantization clipping thresholds", 386 | ) 387 | @click.option( 388 | "--act-num-candidates", 389 | type=int, 390 | default=None, 391 | help="Number of grid points for grid search in MSE/SQNR/Cross-entropy", 392 | ) 393 | @click.option( 394 | "--act-momentum", 395 | type=float, 396 | default=None, 397 | help="Exponential averaging factor for running_minmax", 398 | ) 399 | @click.option( 400 | "--num-est-batches", 401 | type=int, 402 | default=1, 403 | help="Number of training batches to be used for activation range estimation", 404 | ) 405 | # Other options 406 | @click.option( 407 | "--quant-setup", 408 | default="all", 409 | type=click.Choice(["all", "LSQ", "FP_logits", "fc4", "fc4_dw8", "LSQ_paper"]), 410 | 
help="Method to quantize the network.", 411 | ) 412 | @wraps(func) 413 | def func_wrapper(config, *args, **kwargs): 414 | config.quant, remainder_kwargs = split_dict( 415 | kwargs, 416 | [ 417 | "qmethod", 418 | "qmethod_act", 419 | "weight_quant_method", 420 | "weight_opt_method", 421 | "num_candidates", 422 | "n_bits", 423 | "n_bits_act", 424 | "per_channel", 425 | "act_quant", 426 | "weight_quant", 427 | "quant_setup", 428 | "num_est_batches", 429 | "act_momentum", 430 | "act_num_candidates", 431 | "act_opt_method", 432 | "act_quant_method", 433 | ], 434 | ) 435 | 436 | config.quant.qmethod_act = config.quant.qmethod_act or config.quant.qmethod 437 | 438 | return func(config, *args, **remainder_kwargs) 439 | 440 | return func_wrapper 441 | 442 | 443 | def fp8_options(func): 444 | # Weight quantization options 445 | @click.option("--fp8-maxval", type=float, default=None) 446 | @click.option("--fp8-mantissa-bits", type=int, default=4) 447 | @click.option("--fp8-set-maxval/--no-fp8-set-maxval", is_flag=True, default=False) 448 | @click.option("--fp8-learn-maxval/--no-fp8-learn-maxval", is_flag=True, default=False) 449 | @click.option( 450 | "--fp8-learn-mantissa-bits/--no-fp8-learn-mantissa-bits", is_flag=True, default=False 451 | ) 452 | @click.option( 453 | "--fp8-mse-include-mantissa-bits/--no-fp8-mse-include-mantissa-bits", 454 | is_flag=True, 455 | default=True, 456 | ) 457 | @click.option("--fp8-allow-unsigned/--no-fp8-allow-unsigned", is_flag=True, default=False) 458 | @wraps(func) 459 | def func_wrapper(config, *args, **kwargs): 460 | config.fp8, remainder_kwargs = split_dict( 461 | kwargs, 462 | [ 463 | "fp8_maxval", 464 | "fp8_mantissa_bits", 465 | "fp8_set_maxval", 466 | "fp8_learn_maxval", 467 | "fp8_learn_mantissa_bits", 468 | "fp8_mse_include_mantissa_bits", 469 | "fp8_allow_unsigned", 470 | ], 471 | ) 472 | return func(config, *args, **remainder_kwargs) 473 | 474 | return func_wrapper 475 | 476 | 477 | def quant_params_dict(config): 478 | 
weight_range_options = {} 479 | if config.quant.weight_quant_method == RangeEstimators.MSE: 480 | weight_range_options = dict(opt_method=config.quant.weight_opt_method) 481 | if config.quant.num_candidates is not None: 482 | weight_range_options["num_candidates"] = config.quant.num_candidates 483 | 484 | act_range_options = {} 485 | if config.quant.act_quant_method == RangeEstimators.MSE: 486 | act_range_options = dict(opt_method=config.quant.act_opt_method) 487 | if config.quant.act_num_candidates is not None: 488 | act_range_options["num_candidates"] = config.quant.act_num_candidates 489 | 490 | qparams = { 491 | "method": config.quant.qmethod.cls, 492 | "n_bits": config.quant.n_bits, 493 | "n_bits_act": config.quant.n_bits_act, 494 | "act_method": config.quant.qmethod_act.cls, 495 | "per_channel_weights": config.quant.per_channel, 496 | "quant_setup": config.quant.quant_setup, 497 | "weight_range_method": config.quant.weight_quant_method.cls, 498 | "weight_range_options": weight_range_options, 499 | "act_range_method": config.quant.act_quant_method.cls, 500 | "act_range_options": act_range_options, 501 | "quantize_input": True if config.quant.quant_setup == "LSQ_paper" else False, 502 | } 503 | 504 | if config.quant.qmethod.name.startswith("fp_quantizer"): 505 | fp8_kwargs = { 506 | k.replace("fp8_", ""): v for k, v in config.fp8.items() if k.startswith("fp8") 507 | } 508 | qparams["fp8_kwargs"] = fp8_kwargs 509 | 510 | return qparams 511 | -------------------------------------------------------------------------------- /utils/distributions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
4 | 5 | import numpy as np 6 | import scipy.stats as stats 7 | import scipy.integrate 8 | from scipy import special 9 | 10 | 11 | class DistrBase: 12 | def __init__(self, params_dict, range_min, range_max, *args, **kwargs): 13 | self.params_dict = params_dict 14 | assert range_max >= range_min 15 | self.range_min = range_min 16 | self.range_max = range_max 17 | 18 | def pdf(self, x): 19 | raise NotImplementedError() 20 | 21 | def sample(self, shape): 22 | raise NotImplementedError() 23 | 24 | def eval_point_mass_range_min(self): 25 | raise NotImplementedError() 26 | 27 | def eval_point_mass_range_max(self): 28 | raise NotImplementedError() 29 | 30 | def integr_interv_p_sqr_r(self, a, b): 31 | raise NotImplementedError() 32 | 33 | def integr_interv_x_p_r_signed(self, a, b): 34 | raise NotImplementedError() 35 | 36 | def eval_p_sqr_r(self, x, grid): 37 | raise NotImplementedError() 38 | 39 | def eval_x_p_r_signed(self, x, grid): 40 | raise NotImplementedError() 41 | 42 | def eval_non_central_second_moment(self): 43 | raise NotImplementedError() 44 | 45 | def print(self): 46 | raise NotImplementedError() 47 | 48 | 49 | class ClippedGaussDistr(DistrBase): 50 | def __init__(self, *args, **kwargs): 51 | super().__init__(*args, **kwargs) 52 | mu = self.params_dict["mu"] 53 | sigma = self.params_dict["sigma"] 54 | self.point_mass_range_min = stats.norm.cdf(self.range_min, loc=mu, scale=sigma) 55 | self.point_mass_range_max = 1.0 - stats.norm.cdf(self.range_max, loc=mu, scale=sigma) 56 | 57 | def print(self): 58 | print( 59 | "Gaussian distr ", 60 | ", mu = ", 61 | self.params_dict["mu"], 62 | ", sigma = ", 63 | self.params_dict["sigma"], 64 | " clipped at [", 65 | self.range_min, 66 | ",", 67 | self.range_max, 68 | "]", 69 | ) 70 | 71 | def cdf(self, x_np): 72 | p = stats.norm.cdf(x_np, self.params_dict["mu"], self.params_dict["sigma"]) 73 | return p 74 | 75 | def pdf(self, x): 76 | x_np = x.cpu().numpy() 77 | p = stats.norm.pdf(x_np, self.params_dict["mu"], 
self.params_dict["sigma"]) 78 | return p 79 | 80 | def inverse_cdf(self, x): 81 | res = stats.norm.ppf(x, loc=self.params_dict["mu"], scale=self.params_dict["sigma"]) 82 | return res 83 | 84 | def sample(self, shape): 85 | r = np.random.normal( 86 | loc=self.params_dict["mu"], scale=self.params_dict["sigma"], size=shape 87 | ) 88 | r = np.clip(r, self.range_min, self.range_max) 89 | return r 90 | 91 | def integr_interv_p_sqr_r(self, a, b, u): 92 | assert b >= a 93 | mu = self.params_dict["mu"] 94 | sigma = self.params_dict["sigma"] 95 | root_half = np.sqrt(0.5) 96 | root_half_pi = np.sqrt(0.5 * np.pi) 97 | t1 = -sigma * ( 98 | np.exp((-0.5 * a**2 + 1.0 * a * mu - 0.5 * mu**2) / sigma**2) 99 | * sigma 100 | * (-1.0 * a - 1.0 * mu + 2.0 * u) 101 | + ( 102 | -root_half_pi * mu**2 103 | - root_half_pi * sigma**2 104 | + 2.0 * root_half_pi * mu * u 105 | - root_half_pi * u**2 106 | ) 107 | * special.erf((-root_half * a + root_half * mu) / sigma) 108 | ) 109 | t2 = sigma * ( 110 | np.exp((-0.5 * b**2 + 1.0 * b * mu - 0.5 * mu**2) / sigma**2) 111 | * sigma 112 | * (-1.0 * b - 1.0 * mu + 2.0 * u) 113 | + ( 114 | -root_half_pi * mu**2 115 | - root_half_pi * sigma**2 116 | + 2.0 * root_half_pi * mu * u 117 | - root_half_pi * u**2 118 | ) 119 | * special.erf((-root_half * b + root_half * mu) / sigma) 120 | ) 121 | const = 1 / sigma / np.sqrt(2 * np.pi) 122 | return (t1 + t2) * const 123 | 124 | def integr_interv_x_p_signed_r(self, a, b, x0): 125 | assert b >= a 126 | mu = self.params_dict["mu"] 127 | sigma = self.params_dict["sigma"] 128 | root_half = np.sqrt(0.5) 129 | root_half_pi = np.sqrt(0.5 * np.pi) 130 | 131 | res = ( 132 | x0 133 | * sigma 134 | * ( 135 | np.exp(-((0.5 * mu**2) / sigma**2)) 136 | * ( 137 | np.exp((a * (-0.5 * a + mu)) / sigma**2) 138 | - np.exp((b * (-0.5 * b + mu)) / sigma**2) 139 | ) 140 | * sigma 141 | - root_half_pi * mu * special.erf((root_half * a - root_half * mu) / sigma) 142 | + root_half_pi * mu * special.erf((root_half * b - root_half * mu) 
/ sigma) 143 | ) 144 | + sigma 145 | * ( 146 | np.exp((-0.5 * a**2 + a * mu - 0.5 * mu**2) / sigma**2) 147 | * (-a * sigma - mu * sigma) 148 | + (-root_half_pi * mu**2 - root_half_pi * sigma**2) 149 | * special.erf((-root_half * a + root_half * mu) / sigma) 150 | ) 151 | - sigma 152 | * ( 153 | np.exp((-0.5 * b**2 + b * mu - 0.5 * mu**2) / sigma**2) 154 | * (-b * sigma - mu * sigma) 155 | + (-root_half_pi * mu**2 - root_half_pi * sigma**2) 156 | * special.erf((-root_half * b + root_half * mu) / sigma) 157 | ) 158 | ) 159 | 160 | const = 1 / sigma / np.sqrt(2 * np.pi) 161 | return res * const 162 | 163 | def integr_p_times_x(self, a, b): 164 | assert b >= a 165 | mu = self.params_dict["mu"] 166 | sigma = self.params_dict["sigma"] 167 | 168 | root_half = np.sqrt(0.5) 169 | root_half_pi = np.sqrt(0.5 * np.pi) 170 | 171 | res = sigma * ( 172 | np.exp(-((0.5 * mu**2) / sigma**2)) 173 | * ( 174 | np.exp((a * (-0.5 * a + mu)) / sigma**2) 175 | - np.exp((b * (-0.5 * b + mu)) / sigma**2) 176 | ) 177 | * sigma 178 | - root_half_pi * mu * special.erf(root_half * (a - mu) / sigma) 179 | + root_half_pi * mu * special.erf(root_half * (b - mu) / sigma) 180 | ) 181 | 182 | scale = 1 / (sigma * np.sqrt(2 * np.pi)) 183 | return res * scale 184 | 185 | def eval_non_central_second_moment(self): 186 | term_range_min = self.point_mass_range_min * self.range_min**2 187 | term_range_max = self.point_mass_range_max * self.range_max**2 188 | term_middle_intergral = self.integr_interv_p_sqr_r(self.range_min, self.range_max, 0.0) 189 | return term_range_min + term_middle_intergral + term_range_max 190 | 191 | 192 | class ClippedStudentTDistr(DistrBase): 193 | def __init__(self, *args, **kwargs): 194 | super().__init__(*args, **kwargs) 195 | nu = self.params_dict["nu"] 196 | self.point_mass_range_min = stats.t.cdf(self.range_min, nu) 197 | self.point_mass_range_max = 1.0 - stats.t.cdf(self.range_max, nu) 198 | 199 | def print(self): 200 | print( 201 | "Student's-t distr", 202 | ", nu = ", 203 
| self.params_dict["nu"], 204 | " clipped at [", 205 | self.range_min, 206 | ",", 207 | self.range_max, 208 | "]", 209 | ) 210 | 211 | def pdf(self, x): 212 | x_np = x.cpu().numpy() 213 | p = stats.t.pdf(x_np, self.params_dict["nu"]) 214 | return p 215 | 216 | def cdf(self, x): 217 | p = stats.t.cdf(x, self.params_dict["nu"]) 218 | return p 219 | 220 | def inverse_cdf(self, x): 221 | res = stats.t.ppf(x, self.params_dict["nu"]) 222 | return res 223 | 224 | def sample(self, shape): 225 | r = np.random.standard_t(self.params_dict["nu"], size=shape) 226 | r = np.clip(r, self.range_min, self.range_max) 227 | return r 228 | 229 | def integr_interv_p_sqr_r(self, a, b, u): 230 | assert b >= a 231 | nu = self.params_dict["nu"] 232 | 233 | first_term = (2.0 * nu * (-1.0 + ((a**2 + nu) / nu) ** (1.0 / 2.0 - nu / 2.0)) * u) / ( 234 | 1.0 - nu 235 | ) 236 | second_term = -(2 * nu * (-1 + ((b**2 + nu) / nu) ** (1.0 / 2.0 - nu / 2)) * u) / ( 237 | 1.0 - nu 238 | ) 239 | third_term = ( 240 | -a 241 | * u**2 242 | * scipy.special.hyp2f1(1.0 / 2.0, (1.0 + nu) / 2.0, 3.0 / 2.0, -(a**2.0 / nu)) 243 | ) 244 | fourth_term = ( 245 | b 246 | * u**2.0 247 | * scipy.special.hyp2f1(1.0 / 2.0, (1.0 + nu) / 2.0, 3.0 / 2.0, -(b**2 / nu)) 248 | ) 249 | fifth_term = ( 250 | -1.0 251 | / 3.0 252 | * a**3 253 | * scipy.special.hyp2f1(3.0 / 2.0, (1.0 + nu) / 2.0, 5.0 / 2.0, -(a**2 / nu)) 254 | ) 255 | sixth_term = ( 256 | 1.0 257 | / 3.0 258 | * b**3 259 | * scipy.special.hyp2f1(3.0 / 2.0, (1.0 + nu) / 2.0, 5.0 / 2.0, -(b**2 / nu)) 260 | ) 261 | res = first_term + second_term + third_term + fourth_term + fifth_term + sixth_term 262 | 263 | const = ( 264 | scipy.special.gamma(0.5 * (nu + 1.0)) 265 | / np.sqrt(np.pi * nu) 266 | / scipy.special.gamma(0.5 * nu) 267 | ) 268 | return res * const 269 | 270 | def integr_p_times_x(self, a, b): 271 | assert b >= a 272 | nu = self.params_dict["nu"] 273 | res = self.integr_interv_u_t_times_p_no_constant(a, b, 1.0) # integral of x * p(x) on [a, b], without the normalization constant 274 | 275 | const = ( 276 | scipy.special.gamma(0.5 * (nu + 1.0)) 277 | / np.sqrt(np.pi * nu) 278
| / scipy.special.gamma(0.5 * nu) 279 | ) 280 | return res * const 281 | 282 | def scale(self): 283 | nu = self.params_dict["nu"] 284 | res = ( 285 | scipy.special.gamma(0.5 * (nu + 1.0)) 286 | / np.sqrt(np.pi * nu) 287 | / scipy.special.gamma(0.5 * nu) 288 | ) 289 | return res 290 | 291 | def integr_cubic_root_p(self, a, b): 292 | assert b >= a 293 | nu = self.params_dict["nu"] 294 | 295 | common_mult = 1.0 / (nu - 2.0) * 3.0 * (a * b) ** (-nu / 3.0) * nu ** ((1.0 + nu) / 6.0) 296 | first_term = ( 297 | scipy.special.hyp2f1( 298 | 1 / 6.0 * (-2.0 + nu), (1.0 + nu) / 6, (4.0 + nu) / 6.0, -nu / a**2 299 | ) 300 | * a ** (2.0 / 3.0) 301 | * b ** (nu / 3.0) 302 | ) 303 | second_term = ( 304 | -scipy.special.hyp2f1( 305 | 1 / 6.0 * (-2.0 + nu), (1.0 + nu) / 6, (4.0 + nu) / 6.0, -nu / b**2 306 | ) 307 | * b ** (2.0 / 3.0) 308 | * a ** (nu / 3.0) 309 | ) 310 | 311 | return common_mult * (first_term + second_term) 312 | 313 | def integr_interv_u_t_times_p_no_constant(self, a, b, u): 314 | assert b >= a 315 | df = self.params_dict["nu"] 316 | res = ( 317 | df ** ((1.0 + df) / 2.0) 318 | * (-((a**2 + df) ** (1.0 / 2.0 - df / 2.0)) + (b**2 + df) ** (1.0 / 2.0 - df / 2.0)) 319 | * u 320 | ) / (1.0 - df) 321 | return res 322 | 323 | def integr_interv_x_p_signed_r(self, a, b, x0): 324 | assert b >= a 325 | nu = self.params_dict["nu"] 326 | 327 | const = ( 328 | scipy.special.gamma(0.5 * (nu + 1.0)) 329 | / np.sqrt(np.pi * nu) 330 | / scipy.special.gamma(0.5 * nu) 331 | ) 332 | r1 = self.integr_interv_u_t_times_p_no_constant(a, b, x0) * const 333 | r2 = self.integr_interv_p_sqr_r(a, b, 0.0) 334 | 335 | res = r1 - r2 336 | return res 337 | 338 | def eval_non_central_second_moment(self): 339 | term_range_min = self.point_mass_range_min * self.range_min**2 340 | term_range_max = self.point_mass_range_max * self.range_max**2 341 | term_middle_intergral = self.integr_interv_p_sqr_r(self.range_min, self.range_max, 0.0) 342 | return term_range_min + term_middle_intergral + 
term_range_max 343 | 344 | 345 | class UniformDistr(DistrBase): 346 | def __init__(self, *args, **kwargs): 347 | super().__init__(*args, **kwargs) 348 | self.p = 1 / (self.range_max - self.range_min) 349 | 350 | def print(self): 351 | print("Uniform distribution on [", self.range_min, ",", self.range_max, "]") 352 | 353 | def pdf(self, x): 354 | return self.p 355 | 356 | def cdf(self, x): 357 | return (x - self.range_min) * self.p 358 | 359 | def sample(self, shape): 360 | return np.random.uniform(self.range_min, self.range_max, shape) 361 | 362 | def integr_interv_p_sqr_r(self, a, b, u): 363 | assert b >= a 364 | res = -(a**3 / 3.0) + b**3 / 3.0 + a**2 * u - b**2 * u - a * u**2 + b * u**2 365 | return res * self.p 366 | 367 | def eval_non_central_second_moment(self): 368 | # Unlike the clipped Gaussian and Student's-t distributions, a uniform distribution 369 | # on [range_min, range_max] has no probability mass outside that range, so there 370 | # are no point masses at the clipping bounds. 371 | term_range_min = 0.0 372 | term_range_max = 0.0 373 | 374 | term_middle_integral = self.integr_interv_p_sqr_r(self.range_min, self.range_max, 0.0) 375 | 376 | return term_range_min + term_middle_integral + term_range_max 377 | 378 | def integr_p_times_x(self, a, b): 379 | return 0.5 * (b**2 - a**2) * self.p 380 | 381 | def integr_interv_x_p_signed_r(self, a, b, x0): 382 | assert b >= a 383 | res = 0.5 * (b**2 - a**2) * x0 - (b**3 - a**3) / 3.0 # integral of x * (x0 - x) on [a, b] 384 | return res * self.p 385 | -------------------------------------------------------------------------------- /utils/grid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
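`ClippedGaussDistr.eval_non_central_second_moment` in `utils/distributions.py` adds the two point masses that clipping creates at `range_min` and `range_max` to the integral of x² p(x) over the unclipped interval. The sketch below computes the same quantity, E[clip(X, lo, hi)²] for X ~ N(mu, sigma²), using the standard truncated-moment identities of a standard normal instead of the expanded `erf` expression used in the repository. It is an independent derivation meant for cross-checking, not the authors' code, and it uses only the Python standard library.

```python
import math


def std_normal_pdf(z: float) -> float:
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z: float) -> float:
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def clipped_normal_second_moment(mu: float, sigma: float, lo: float, hi: float) -> float:
    """E[clip(X, lo, hi)^2] for X ~ N(mu, sigma^2), assuming lo <= hi."""
    a = (lo - mu) / sigma  # standardized clipping bounds
    b = (hi - mu) / sigma
    p_lo = std_normal_cdf(a)  # probability mass that clipping moves onto lo
    p_hi = 1.0 - std_normal_cdf(b)  # probability mass that clipping moves onto hi

    # Truncated moments of Z ~ N(0, 1) over [a, b]: integrals of phi, z*phi and z^2*phi.
    m0 = std_normal_cdf(b) - std_normal_cdf(a)
    m1 = std_normal_pdf(a) - std_normal_pdf(b)
    m2 = m0 + a * std_normal_pdf(a) - b * std_normal_pdf(b)

    # Substitute x = mu + sigma*z, so x^2 = mu^2 + 2*mu*sigma*z + sigma^2*z^2.
    middle = mu * mu * m0 + 2.0 * mu * sigma * m1 + sigma * sigma * m2
    return p_lo * lo * lo + middle + p_hi * hi * hi
```

As a sanity check, clipping far out in the tails should change almost nothing, so `clipped_normal_second_moment(0.3, 1.5, -50.0, 50.0)` is very close to mu² + sigma² = 2.34.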
4 | 5 | import torch 6 | import numpy as np 7 | from utils.distributions import ClippedGaussDistr, ClippedStudentTDistr 8 | 9 | 10 | def rounding_error_abs_nearest(x_float, grid): 11 | n_grid = grid.size 12 | grid_row = np.array(grid).reshape(1, n_grid) 13 | m_vals = x_float.numel() 14 | x_float = x_float.cpu().detach().numpy().reshape(m_vals, 1) 15 | 16 | dist = np.abs(x_float - grid_row) 17 | min_dist = np.min(dist, axis=1) 18 | 19 | return min_dist 20 | 21 | 22 | def quant_scalar_nearest(x_float, grid): 23 | dist = np.abs(x_float - grid) 24 | idx = np.argmin(dist) 25 | q_x = grid[idx] 26 | return q_x 27 | 28 | 29 | def clip_grid_exclude_bounds(grid, min_val, max_val): 30 | idx_subset = torch.logical_and(grid > min_val, grid < max_val) 31 | return grid[idx_subset] 32 | 33 | 34 | def clip_grid_include_bounds(grid, min_val, max_val): 35 | idx_subset = torch.logical_and(grid >= min_val, grid <= max_val) 36 | return grid[idx_subset] 37 | 38 | 39 | def clip_grid_add_bounds(grid, min_val, max_val): 40 | grid_clipped = clip_grid_exclude_bounds(grid, min_val, max_val) 41 | bounds_np = np.array([min_val, max_val]) 42 | clipped_with_bounds = np.sort(np.concatenate((grid_clipped, bounds_np))) 43 | return clipped_with_bounds 44 | 45 | 46 | def integrate_pdf_grid_func_analyt(distr, grid, distr_attr_func_name): 47 | grid = np.sort(grid) 48 | interval_integr_func = getattr(distr, distr_attr_func_name) 49 | res = 0.0 50 | 51 | if distr.range_min < grid[0]: 52 | res += interval_integr_func(distr.range_min, grid[0], grid[0]) 53 | 54 | for i_interval in range(0, len(grid) - 1): 55 | grid_mid = 0.5 * (grid[i_interval] + grid[i_interval + 1]) 56 | 57 | first_half_a = max(grid[i_interval], distr.range_min) 58 | first_half_b = min(grid_mid, distr.range_max) 59 | 60 | second_half_a = max(grid_mid, distr.range_min) 61 | second_half_b = min(grid[i_interval + 1], distr.range_max) 62 | 63 | if first_half_a < first_half_b: 64 | res += interval_integr_func(first_half_a, first_half_b, 
grid[i_interval]) 65 | if second_half_a < second_half_b: 66 | res += interval_integr_func(second_half_a, second_half_b, grid[i_interval + 1]) 67 | 68 | if distr.range_max > grid[-1]: 69 | res += interval_integr_func(grid[-1], distr.range_max, grid[-1]) 70 | 71 | if ( 72 | isinstance(distr, ClippedGaussDistr) or isinstance(distr, ClippedStudentTDistr) 73 | ) and distr_attr_func_name == "integr_interv_x_p_signed_r": 74 | q_range_min = quant_scalar_nearest(torch.Tensor([distr.range_min]), grid) 75 | q_range_max = quant_scalar_nearest(torch.Tensor([distr.range_max]), grid) 76 | 77 | term_point_mass = ( 78 | distr.range_min * (q_range_min - distr.range_min) * distr.point_mass_range_min 79 | + distr.range_max * (q_range_max - distr.range_max) * distr.point_mass_range_max 80 | ) 81 | res += term_point_mass 82 | elif ( 83 | isinstance(distr, ClippedGaussDistr) or isinstance(distr, ClippedStudentTDistr) 84 | ) and distr_attr_func_name == "integr_interv_p_sqr_r": 85 | q_range_min = quant_scalar_nearest(torch.Tensor([distr.range_min]), grid) 86 | q_range_max = quant_scalar_nearest(torch.Tensor([distr.range_max]), grid) 87 | 88 | term_point_mass = (q_range_min - distr.range_min) ** 2 * distr.point_mass_range_min + ( 89 | q_range_max - distr.range_max 90 | ) ** 2 * distr.point_mass_range_max 91 | res += term_point_mass 92 | 93 | return res 94 | -------------------------------------------------------------------------------- /utils/imagenet_dataloaders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
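`integrate_pdf_grid_func_analyt` in `utils/grid.py` relies on round-to-nearest: every value between two neighbouring grid points is mapped to the closer one, so each grid interval is split at its midpoint and the closed-form interval integral is evaluated with the matching grid point as `u`, while values outside the grid map to the outermost points. A minimal pure-Python sketch of that splitting for a uniform distribution follows; it reuses the closed form of `UniformDistr.integr_interv_p_sqr_r`, and the name `uniform_rounding_mse` is hypothetical. A uniform distribution has no clipping point masses, so the extra point-mass terms that the repository adds for the clipped Gaussian and Student's-t cases do not appear here.

```python
from typing import Sequence


def uniform_rounding_mse(grid: Sequence[float], lo: float, hi: float) -> float:
    """Expected (x - q(x))^2 for x ~ U[lo, hi], where q rounds to the nearest grid point."""
    grid = sorted(grid)
    p = 1.0 / (hi - lo)  # uniform density

    def seg(a: float, b: float, u: float) -> float:
        # Integral of (x - u)^2 * p over [a, b]; same closed form as integr_interv_p_sqr_r.
        return ((b - u) ** 3 - (a - u) ** 3) / 3.0 * p

    total = 0.0
    if lo < grid[0]:  # mass below the grid rounds up to the first grid point
        total += seg(lo, grid[0], grid[0])
    for g0, g1 in zip(grid, grid[1:]):
        mid = 0.5 * (g0 + g1)  # values below mid round to g0, above mid to g1
        a, b = max(g0, lo), min(mid, hi)
        if a < b:
            total += seg(a, b, g0)
        a, b = max(mid, lo), min(g1, hi)
        if a < b:
            total += seg(a, b, g1)
    if hi > grid[-1]:  # mass above the grid rounds down to the last grid point
        total += seg(grid[-1], hi, grid[-1])
    return total
```

For an evenly spaced grid with step Δ whose end points coincide with `lo` and `hi`, this recovers the familiar Δ²/12; for example, a step of 0.25 on [0, 1] gives 0.0625 / 12.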
4 | 5 | import os 6 | 7 | import torchvision 8 | import torch.utils.data as torch_data 9 | from torchvision import transforms 10 | from utils import BaseEnumOptions 11 | 12 | 13 | class ImageInterpolation(BaseEnumOptions): 14 | nearest = transforms.InterpolationMode.NEAREST 15 | box = transforms.InterpolationMode.BOX 16 | bilinear = transforms.InterpolationMode.BILINEAR 17 | hamming = transforms.InterpolationMode.HAMMING 18 | bicubic = transforms.InterpolationMode.BICUBIC 19 | lanczos = transforms.InterpolationMode.LANCZOS 20 | 21 | 22 | class ImageNetDataLoaders(object): 23 | """ 24 | Data loader provider for ImageNet images, providing a train and a validation loader. 25 | It assumes the following directory structure: 26 | images_dir 27 | - train 28 | - label1 29 | - label2 30 | - ... 31 | - val 32 | - label1 33 | - label2 34 | - ... 35 | """ 36 | 37 | def __init__( 38 | self, 39 | images_dir: str, 40 | image_size: int, 41 | batch_size: int, 42 | num_workers: int, 43 | interpolation: ImageInterpolation, 44 | ): 45 | """ 46 | Parameters 47 | ---------- 48 | images_dir: str 49 | Root image directory 50 | image_size: int 51 | Side length in pixels of the square images returned by both loaders 52 | batch_size: int 53 | Batch size of both the training and validation loaders 54 | num_workers: int 55 | Number of parallel workers loading the images 56 | interpolation: ImageInterpolation 57 | Interpolation used for resizing; its value is a transforms.InterpolationMode. 58 | """ 59 | 60 | self.images_dir = images_dir 61 | self.batch_size = batch_size 62 | self.num_workers = num_workers 63 | 64 | # Standard per-channel ImageNet mean and std (RGB order), the same values 65 | # used for the torchvision models pretrained on ImageNet.
66 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 67 | 68 | self.train_transforms = transforms.Compose( 69 | [ 70 | transforms.RandomResizedCrop(image_size, interpolation=interpolation.value), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | normalize, 74 | ] 75 | ) 76 | 77 | self.val_transforms = transforms.Compose( 78 | [ 79 | transforms.Resize(image_size + 24, interpolation=interpolation.value), 80 | transforms.CenterCrop(image_size), 81 | transforms.ToTensor(), 82 | normalize, 83 | ] 84 | ) 85 | 86 | self._train_loader = None 87 | self._val_loader = None 88 | 89 | @property 90 | def train_loader(self) -> torch_data.DataLoader: 91 | if not self._train_loader: 92 | root = os.path.join(self.images_dir, "train") 93 | train_set = torchvision.datasets.ImageFolder(root, transform=self.train_transforms) 94 | self._train_loader = torch_data.DataLoader( 95 | train_set, 96 | batch_size=self.batch_size, 97 | shuffle=True, 98 | num_workers=self.num_workers, 99 | pin_memory=True, 100 | ) 101 | return self._train_loader 102 | 103 | @property 104 | def val_loader(self) -> torch_data.DataLoader: 105 | if not self._val_loader: 106 | root = os.path.join(self.images_dir, "val") 107 | val_set = torchvision.datasets.ImageFolder(root, transform=self.val_transforms) 108 | self._val_loader = torch_data.DataLoader( 109 | val_set, 110 | batch_size=self.batch_size, 111 | shuffle=False, 112 | num_workers=self.num_workers, 113 | pin_memory=True, 114 | ) 115 | return self._val_loader 116 | -------------------------------------------------------------------------------- /utils/optimizer_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 
4 | 5 | import torch 6 | 7 | 8 | def get_lr_scheduler(optimizer, lr_schedule, epochs): 9 | scheduler = None 10 | if lr_schedule: 11 | if lr_schedule.startswith("multistep"): 12 | milestones = [int(s) for s in lr_schedule.split(":")[1:]] 13 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones) 14 | elif lr_schedule.startswith("cosine"): 15 | eta_min = float(lr_schedule.split(":")[1]) 16 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 17 | optimizer, epochs, eta_min=eta_min 18 | ) 19 | return scheduler 20 | 21 | 22 | def optimizer_lr_factory(config_optim, params, epochs): 23 | if config_optim.optimizer.lower() == "sgd": 24 | optimizer = torch.optim.SGD( 25 | params, 26 | lr=config_optim.learning_rate, 27 | momentum=config_optim.momentum, 28 | weight_decay=config_optim.weight_decay, 29 | ) 30 | elif config_optim.optimizer.lower() == "adam": 31 | optimizer = torch.optim.Adam( 32 | params, lr=config_optim.learning_rate, weight_decay=config_optim.weight_decay 33 | ) 34 | else: 35 | raise ValueError(f"Unsupported optimizer: {config_optim.optimizer}") 36 | 37 | lr_scheduler = get_lr_scheduler(optimizer, config_optim.learning_rate_schedule, epochs) 38 | 39 | return optimizer, lr_scheduler 40 | -------------------------------------------------------------------------------- /utils/qat_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
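`get_lr_scheduler` in `utils/optimizer_utils.py` encodes the schedule as a colon-separated string: `multistep:<m1>:<m2>:...` yields the milestone epochs passed to `torch.optim.lr_scheduler.MultiStepLR`, and `cosine:<eta_min>` yields the final learning rate passed to `CosineAnnealingLR`, whose period is the number of epochs. The sketch below isolates that parsing so it can be tested without PyTorch; `parse_lr_schedule` and its return format are illustrative assumptions, not repository API.

```python
from typing import List, Optional, Tuple, Union


def parse_lr_schedule(
    lr_schedule: Optional[str],
) -> Optional[Tuple[str, Union[List[int], float]]]:
    """Return ("multistep", milestones), ("cosine", eta_min) or None (hypothetical helper)."""
    if not lr_schedule:
        return None  # no schedule: the learning rate stays constant
    if lr_schedule.startswith("multistep"):
        # e.g. "multistep:30:60:90" -> milestone epochs [30, 60, 90]
        return "multistep", [int(s) for s in lr_schedule.split(":")[1:]]
    if lr_schedule.startswith("cosine"):
        # e.g. "cosine:1e-5" -> eta_min = 1e-5
        return "cosine", float(lr_schedule.split(":")[1])
    return None  # unrecognised strings produce no scheduler, as in get_lr_scheduler


print(parse_lr_schedule("multistep:30:60:90"))  # ('multistep', [30, 60, 90])
print(parse_lr_schedule("cosine:1e-5"))  # ('cosine', 1e-05)
```

The last branch makes one design choice of the original function visible: a misspelled schedule string silently results in no scheduler rather than an error.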
4 | 5 | import copy 6 | import torch 7 | 8 | from quantization.quantized_folded_bn import BNFusedHijacker 9 | from utils.imagenet_dataloaders import ImageNetDataLoaders 10 | 11 | 12 | def get_dataloaders_and_model(config, load_type="fp32", **qparams): 13 | dataloaders = ImageNetDataLoaders( 14 | config.base.images_dir, 15 | 224, 16 | config.base.batch_size, 17 | config.base.num_workers, 18 | config.base.interpolation, 19 | ) 20 | 21 | model = config.base.architecture( 22 | pretrained=config.base.pretrained, 23 | load_type=load_type, 24 | model_dir=config.base.model_dir, 25 | **qparams, 26 | ) 27 | if config.base.cuda: 28 | model = model.cuda() 29 | 30 | return dataloaders, model 31 | 32 | 33 | class ReestimateBNStats: 34 | def __init__(self, model, data_loader, num_batches=50): 35 | super().__init__() 36 | self.model = model 37 | self.data_loader = data_loader 38 | self.num_batches = num_batches 39 | 40 | def __call__(self, engine): 41 | print("-- Reestimate current BN statistics --") 42 | reestimate_BN_stats(self.model, self.data_loader, self.num_batches) 43 | 44 | 45 | def reestimate_BN_stats(model, data_loader, num_batches=50, store_ema_stats=False): 46 | # We set the BN momentum to 1 and put only the BN modules in train mode 47 | # -> after each forward pass the running mean/var hold the current batch statistics 48 | model.eval() 49 | org_momentum = {} 50 | for name, module in model.named_modules(): 51 | if isinstance(module, BNFusedHijacker): 52 | org_momentum[name] = module.momentum 53 | module.momentum = 1.0 54 | module.running_mean_sum = torch.zeros_like(module.running_mean) 55 | module.running_var_sum = torch.zeros_like(module.running_var) 56 | # Set the BNFusedHijacker module itself, but not its children, to train mode 57 | module.training = True 58 | 59 | if store_ema_stats: 60 | # Save the original EMA statistics as buffers so they end up in the state dict 61 | if not hasattr(module, "running_mean_ema"): 62 | module.register_buffer("running_mean_ema", copy.deepcopy(module.running_mean)) 63 |
module.register_buffer("running_var_ema", copy.deepcopy(module.running_var)) 64 | else: 65 | module.running_mean_ema = copy.deepcopy(module.running_mean) 66 | module.running_var_ema = copy.deepcopy(module.running_var) 67 | 68 | # Run data for estimation 69 | device = next(model.parameters()).device 70 | batch_count = 0 71 | with torch.no_grad(): 72 | for x, y in data_loader: 73 | model(x.to(device)) 74 | # We save the running mean/var to a buffer 75 | for name, module in model.named_modules(): 76 | if isinstance(module, BNFusedHijacker): 77 | module.running_mean_sum += module.running_mean 78 | module.running_var_sum += module.running_var 79 | 80 | batch_count += 1 81 | if batch_count == num_batches: 82 | break 83 | # At the end we normalize the buffer and write it into the running mean/var 84 | for name, module in model.named_modules(): 85 | if isinstance(module, BNFusedHijacker): 86 | module.running_mean = module.running_mean_sum / batch_count 87 | module.running_var = module.running_var_sum / batch_count 88 | # We reset the momentum in case it would be used anywhere else 89 | module.momentum = org_momentum[name] 90 | model.eval() 91 | -------------------------------------------------------------------------------- /utils/stopwatch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved. 4 | 5 | import sys 6 | import time 7 | 8 | 9 | class Stopwatch: 10 | """ 11 | A simple cross-platform context-manager stopwatch. 12 | 13 | Examples 14 | -------- 15 | >>> import time 16 | >>> with Stopwatch(verbose=True) as st: 17 | ... time.sleep(0.101) #doctest: +ELLIPSIS 18 | Elapsed time: 0.10... 
sec 19 | """ 20 | 21 | def __init__(self, name=None, verbose=False): 22 | self._name = name 23 | self._verbose = verbose 24 | 25 | self._start_time_point = 0.0 26 | self._total_duration = 0.0 27 | self._is_running = False 28 | 29 | # time.clock() (formerly used here on Windows) was removed in Python 3.8. 30 | # time.perf_counter() is a monotonic clock with the highest available 31 | # resolution on every platform, Windows included, so it replaces both 32 | # platform-specific branches and, unlike time.time(), is not affected 33 | # by adjustments of the system clock. 34 | self._timer_fn = time.perf_counter 35 | 36 | def __enter__(self): 37 | return self.start() 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | self.stop() 41 | if self._verbose: 42 | self.print() 43 | 44 | def start(self): 45 | if not self._is_running: 46 | self._start_time_point = self._timer_fn() 47 | self._is_running = True 48 | return self 49 | 50 | def stop(self): 51 | if self._is_running: 52 | self._total_duration += self._timer_fn() - self._start_time_point 53 | self._is_running = False 54 | return self 55 | 56 | def reset(self): 57 | self._start_time_point = 0.0 58 | self._total_duration = 0.0 59 | self._is_running = False 60 | return self 61 | 62 | def _update_state(self): 63 | now = self._timer_fn() 64 | self._total_duration += now - self._start_time_point 65 | self._start_time_point = now 66 | 67 | def _format(self): 68 | prefix = f"[{self._name}]" if self._name is not None else "Elapsed time" 69 | info = f"{prefix}: {self._total_duration:.3f} sec" 70 | return info 71 | 72 | def format(self): 73 | if self._is_running: 74 | self._update_state() 75 | return self._format() 76 | 77 | def print(self): 78 | print(self.format()) 79 | 80 | def get_total_duration(self): 81 | if self._is_running: 82 | self._update_state() 83 | return self._total_duration 84 | -------------------------------------------------------------------------------- /utils/supervised_driver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved. 4 | 5 | from ignite.contrib.handlers import TensorboardLogger 6 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 7 | from ignite.handlers import Checkpoint, global_step_from_engine 8 | from torch.optim import Optimizer 9 | 10 | 11 | def create_trainer_engine( 12 | model, 13 | optimizer, 14 | criterion, 15 | metrics, 16 | data_loaders, 17 | lr_scheduler=None, 18 | save_checkpoint_dir=None, 19 | device="cuda", 20 | ): 21 | # Create trainer 22 | trainer = create_supervised_trainer( 23 | model=model, 24 | optimizer=optimizer, 25 | loss_fn=criterion, 26 | device=device, 27 | output_transform=custom_output_transform, 28 | ) 29 | 30 | for name, metric in metrics.items(): 31 | metric.attach(trainer, name) 32 | 33 | # Add lr_scheduler 34 | if lr_scheduler: 35 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step()) 36 | 37 | # Create evaluator 38 | evaluator = create_supervised_evaluator(model=model, metrics=metrics, device=device) 39 | 40 | # Save model checkpoint 41 | if save_checkpoint_dir: 42 | to_save = {"model": model, "optimizer": optimizer} 43 | if lr_scheduler: 44 | to_save["lr_scheduler"] = lr_scheduler 45 | checkpoint = Checkpoint( 46 | to_save, 47 | save_checkpoint_dir, 48 | n_saved=1, 49 | global_step_transform=global_step_from_engine(trainer), 50 | ) 51 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint) 52 | 53 | # Add hooks for logging metrics 54 | trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results, optimizer) 55 | 56 | trainer.add_event_handler( 57 | Events.EPOCH_COMPLETED, run_evaluation_for_training, evaluator, data_loaders.val_loader 58 | ) 59 | 60 | return trainer, evaluator 61 | 62 | 63 | def custom_output_transform(x, y, y_pred, loss): 64 | return y_pred, y 65 | 66 | 67 | def log_training_results(trainer, optimizer): 68 | learning_rate = optimizer.param_groups[0]["lr"] 69 | log_metrics(trainer.state.metrics, "Training", 
trainer.state.epoch, learning_rate) 70 | 71 | 72 | def run_evaluation_for_training(trainer, evaluator, val_loader): 73 | evaluator.run(val_loader) 74 | log_metrics(evaluator.state.metrics, "Evaluation", trainer.state.epoch) 75 | 76 | 77 | def log_metrics(metrics, stage: str = "", training_epoch=None, learning_rate=None): 78 | log_text = " {}".format(metrics) if metrics else "" 79 | if training_epoch is not None: 80 | log_text = "Epoch: {}".format(training_epoch) + log_text 81 | if learning_rate and learning_rate > 0.0: 82 | log_text += " Learning rate: {:.2E}".format(learning_rate) 83 | log_text = "Results - " + log_text 84 | if stage: 85 | log_text = "{} ".format(stage) + log_text 86 | print(log_text, flush=True) 87 | 88 | 89 | def setup_tensorboard_logger(trainer, evaluator, output_path, optimizers=None): 90 | logger = TensorboardLogger(logdir=output_path) 91 | 92 | # Log all attached metrics for both the training and the validation engine 93 | for tag, cur_evaluator in [("train", trainer), ("validation", evaluator)]: 94 | logger.attach_output_handler( 95 | cur_evaluator, 96 | event_name=Events.EPOCH_COMPLETED, 97 | tag=tag, 98 | metric_names="all", 99 | global_step_transform=global_step_from_engine(trainer), 100 | ) 101 | 102 | # Log the learning rate; optimizers may be a single Optimizer, a dict of them, or None 103 | if isinstance(optimizers, Optimizer): 104 | optimizers = {None: optimizers} 105 | 106 | for k, optimizer in (optimizers or {}).items(): 107 | logger.attach_opt_params_handler( 108 | trainer, Events.EPOCH_COMPLETED, optimizer, param_name="lr", tag=k 109 | ) 110 | 111 | return logger 112 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 3 | # All Rights Reserved.
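`create_trainer_engine` in `utils/supervised_driver.py` calls `lr_scheduler.step()` once per `Events.EPOCH_COMPLETED`, so a `cosine:<eta_min>` schedule built by `get_lr_scheduler` advances once per epoch with a period equal to the number of epochs. The sketch below evaluates the closed-form cosine annealing formula given in the PyTorch `CosineAnnealingLR` documentation, eta_t = eta_min + (eta_0 - eta_min)(1 + cos(pi t / T_max)) / 2, to show the learning rate in effect during each epoch. It illustrates that formula under this per-epoch stepping assumption and is not a reimplementation of PyTorch's recursive update; `cosine_lr_per_epoch` is a hypothetical name.

```python
import math
from typing import List


def cosine_lr_per_epoch(base_lr: float, eta_min: float, epochs: int) -> List[float]:
    """Learning rate used in epoch t = 0..epochs-1 when the scheduler steps once per epoch."""
    return [
        eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * t / epochs)) / 2.0
        for t in range(epochs)
    ]


lrs = cosine_lr_per_epoch(base_lr=0.01, eta_min=0.0, epochs=4)
# Epoch 0 trains with base_lr; the rate then decays monotonically and would
# reach eta_min only at t == epochs, i.e. after the last training epoch.
```

Because the step happens at the end of each epoch, the final training epoch still uses a learning rate slightly above `eta_min`.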
4 | 5 | import collections 6 | import os 7 | import random 8 | from collections import namedtuple 9 | from enum import Flag, auto 10 | from functools import partial 11 | 12 | import click 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class DotDict(dict): 19 | """ 20 | A dictionary that allows attribute-style access. 21 | Examples 22 | -------- 23 | >>> config = DotDict(a=None) 24 | >>> config.a = 42 25 | >>> config.b = 'egg' 26 | >>> config # can be used as dict 27 | {'a': 42, 'b': 'egg'} 28 | """ 29 | 30 | def __setattr__(self, key, value): 31 | self.__setitem__(key, value) 32 | 33 | def __delattr__(self, key): 34 | self.__delitem__(key) 35 | 36 | def __getattr__(self, key): 37 | if key in self: 38 | return self.__getitem__(key) 39 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})") 40 | 41 | 42 | def relu(x): 43 | x = np.array(x) 44 | return x * (x > 0) 45 | 46 | 47 | def get_all_layer_names(model, subtypes=None): 48 | if subtypes is None: 49 | return [name for name, module in model.named_modules()][1:] 50 | return [name for name, module in model.named_modules() if isinstance(module, subtypes)] 51 | 52 | 53 | def get_layer_name_to_module_dict(model): 54 | return {name: module for name, module in model.named_modules() if name} 55 | 56 | 57 | def get_module_to_layer_name_dict(model): 58 | modules_to_names = collections.OrderedDict() 59 | for name, module in model.named_modules(): 60 | modules_to_names[module] = name 61 | return modules_to_names 62 | 63 | 64 | def get_layer_name(model, layer): 65 | for name, module in model.named_modules(): 66 | if module == layer: 67 | return name 68 | return None 69 | 70 | 71 | def get_layer_by_name(model, layer_name): 72 | for name, module in model.named_modules(): 73 | if name == layer_name: 74 | return module 75 | return None 76 | 77 | 78 | def create_conv_layer_list(cls, model: nn.Module) -> list: 79 | """ 80 | Function finds all prunable layers in the provided 
model 81 | 82 | Parameters 83 | ---------- 84 | cls: SVD class 85 | model : torch.nn.Module 86 | A pytorch model. 87 | 88 | Returns 89 | ------- 90 | conv_layer_list : list 91 | List of all prunable layers in the given model. 92 | 93 | """ 94 | conv_layer_list = [] 95 | 96 | def fill_list(mod): 97 | if isinstance(mod, tuple(cls.supported_layer_types)): 98 | conv_layer_list.append(mod) 99 | 100 | model.apply(fill_list) 101 | return conv_layer_list 102 | 103 | 104 | def create_linear_layer_list(cls, model: nn.Module) -> list: 105 | """ 106 | Function finds all prunable layers in the provided model 107 | 108 | Parameters 109 | ---------- 110 | model : torch.nn.Module 111 | A pytorch model. 112 | 113 | Returns 114 | ------- 115 | conv_layer_list : list 116 | List of all prunable layers in the given model. 117 | 118 | """ 119 | conv_layer_list = [] 120 | 121 | def fill_list(mod): 122 | if isinstance(mod, tuple(cls.supported_layer_types)): 123 | conv_layer_list.append(mod) 124 | 125 | model.apply(fill_list) 126 | return conv_layer_list 127 | 128 | 129 | def to_numpy(tensor): 130 | """ 131 | Helper function that turns the given tensor into a numpy array 132 | 133 | Parameters 134 | ---------- 135 | tensor : torch.Tensor 136 | 137 | Returns 138 | ------- 139 | tensor : float or np.array 140 | 141 | """ 142 | if isinstance(tensor, np.ndarray): 143 | return tensor 144 | if hasattr(tensor, "is_cuda"): 145 | if tensor.is_cuda: 146 | return tensor.cpu().detach().numpy() 147 | if hasattr(tensor, "detach"): 148 | return tensor.detach().numpy() 149 | if hasattr(tensor, "numpy"): 150 | return tensor.numpy() 151 | 152 | return np.array(tensor) 153 | 154 | 155 | def set_module_attr(model, layer_name, value): 156 | split = layer_name.split(".") 157 | 158 | this_module = model 159 | for mod_name in split[:-1]: 160 | if mod_name.isdigit(): 161 | this_module = this_module[int(mod_name)] 162 | else: 163 | this_module = getattr(this_module, mod_name) 164 | 165 | last_mod_name = split[-1] 
    if last_mod_name.isdigit():
        this_module[int(last_mod_name)] = value
    else:
        setattr(this_module, last_mod_name, value)


def search_for_zero_planes(model: torch.nn.Module):
    """Search through all Linear and Conv2d modules and check whether any of their input
    planes have been zeroed out.

    :param model: torch model to search through modules for zeroed parameters
    :return: list of (module, in_channels_to_winnow) tuples, one for every module that has
        at least one zeroed-out input channel
    """

    list_of_modules_to_winnow = []
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.modules.conv.Conv2d)):
            in_channels_to_winnow = _assess_weight_and_bias(module.weight, module.bias)
            if in_channels_to_winnow:
                list_of_modules_to_winnow.append((module, in_channels_to_winnow))
    return list_of_modules_to_winnow


def _assess_weight_and_bias(weight: torch.nn.Parameter, _bias: torch.nn.Parameter):
    """Return the indices of input channels whose weights are all zero.

    Expects 4-dim conv weights [CH-out, CH-in, H, W] or 2-dim linear weights
    [CH-out, CH-in]; the 1-dim bias [CH-out] is currently ignored.
    """
    # Sum absolute values so that positive and negative weights cannot cancel out
    if len(weight.shape) > 2:
        channel_magnitude = weight.abs().sum((0, 2, 3))
    else:
        channel_magnitude = weight.abs().sum(0)

    # flatten() always yields a 1-d tensor, so tolist() always returns a list
    return (channel_magnitude == 0).nonzero().flatten().tolist()


def seed_all(seed: int = 1029, deterministic: bool = False):
    """
    This is our attempt to make experiments reproducible by seeding all known RNGs and setting
    appropriate torch directives.

    For a general discussion of reproducibility in PyTorch and CUDA and documentation of the
    options we are using, see, e.g.,
    https://pytorch.org/docs/1.7.1/notes/randomness.html
    https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility

    As of today (July 2021), even after seeding and setting some directives,
    there remain unfortunate contradictions:
    1. CUDNN
        - having CUDNN enabled leads to
            - non-determinism in PyTorch when using the GPU, cf. MORPH-10999.
        - having CUDNN disabled leads to
            - most regression tests in Qrunchy failing, cf. MORPH-11103
            - significantly increased execution time in some cases
            - performance degradation in some cases
    2. torch.set_deterministic(d) (torch.use_deterministic_algorithms(d) from torch 1.8 on)
        - setting d = True leads to errors for PyTorch algorithms that do not (yet) have a
          deterministic counterpart, e.g., the layer `adaptive_avg_pool2d_backward_cuda` in
          vgg16__torchvision.

    Thus, we leave the choice of enforcing determinism by disabling CUDNN and non-deterministic
    algorithms to the user. To keep it simple, we only have one switch for both.
    This situation could be re-evaluated upon updates of PyTorch, CUDA, CUDNN.
    """

    assert isinstance(seed, int), f"RNG seed must be an integer ({seed})"
    assert seed >= 0, f"RNG seed must be a non-negative integer ({seed})"

    # Builtin RNGs
    random.seed(seed)
    # Note: PYTHONHASHSEED only affects Python subprocesses started after this point,
    # not hash randomization in the already-running interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Numpy RNG
    np.random.seed(seed)

    # CUDNN determinism (setting these has not led to errors so far)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Torch RNGs
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Problematic settings, see docstring. Precaution: we do not mutate unless asked to do so.
    if deterministic is True:
        torch.backends.cudnn.enabled = False

        # When enforcing deterministic algorithms, PyTorch advises setting the
        # CUBLAS_WORKSPACE_CONFIG variable as follows, see
        # https://pytorch.org/docs/1.7.1/notes/randomness.html#avoiding-nondeterministic-algorithms
        # and the link to the CUDA homepage on that website.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

        # torch.set_deterministic was superseded by torch.use_deterministic_algorithms in torch 1.8
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(True)
        else:
            torch.set_deterministic(True)


def assert_allclose(actual, desired, *args, **kwargs):
    """Assert that two tensors/arrays are element-wise close.

    A friendlier alternative to torch.allclose: both inputs are converted to numpy and passed
    to np.testing.assert_allclose, which reports the mismatching elements on failure.
    """
    np.testing.assert_allclose(to_numpy(actual), to_numpy(desired), *args, **kwargs)


def count_params(module):
    # Summing numel() avoids copying all parameters into one vector and also works for
    # modules without any parameters (parameters_to_vector fails on an empty list)
    return sum(p.numel() for p in module.parameters())


class StopForwardException(Exception):
    """Used to throw and catch an exception to stop traversing the graph."""


class StopForwardHook:
    """Forward hook that aborts the forward pass by raising StopForwardException."""

    def __call__(self, module, *args):
        raise StopForwardException


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class CosineTempDecay:
    """Cosine-annealed temperature schedule.

    The temperature stays at ``start_temp`` until ``rel_decay_start * t_max`` and then decays
    along a half cosine to ``end_temp`` at ``t_max``.
    """

    def __init__(self, t_max, temp_range=(20.0, 2.0), rel_decay_start=0):
        self.t_max = t_max
        self.start_temp, self.end_temp = temp_range
        self.decay_start = rel_decay_start * t_max

    def __call__(self, t):
        if t < self.decay_start:
            return self.start_temp

        rel_t = (t - self.decay_start) / (self.t_max - self.decay_start)
        return self.end_temp + 0.5 * (self.start_temp - self.end_temp) * (1 + np.cos(rel_t * np.pi))


class BaseEnumOptions(Flag):
    def __str__(self):
        return self.name

    @classmethod
    def list_names(cls):
        return [m.name for m in cls]


class ClassEnumOptions(BaseEnumOptions):
    @property
    def cls(self):
        return self.value.cls

    def __call__(self, *args, **kwargs):
        return self.value.cls(*args, **kwargs)


MethodMap = partial(namedtuple("MethodMap", ["value", "cls"]), auto())


def split_dict(src: dict, include=(), remove_prefix: str = ""):
    """
    Split a dictionary into a DotDict and a remainder.

    The arguments placed in the first DotDict are those listed in `include`.

    Parameters
    ----------
    src: dict
        The source dictionary.
    include:
        List of keys to be returned in the first DotDict.
    remove_prefix: str
        If given, the first occurrence of f"{remove_prefix}_" is removed from each key
        placed in the first DotDict.

    Returns
    -------
    result: DotDict
        The selected entries.
    remainder: dict
        All entries of `src` whose keys are not listed in `include`.
    """
    result = DotDict()

    for arg in include:
        if remove_prefix:
            key = arg.replace(f"{remove_prefix}_", "", 1)
        else:
            key = arg
        result[key] = src[arg]
    remainder = {key: val for key, val in src.items() if key not in include}
    return result, remainder


class ClickEnumOption(click.Choice):
    """
    Adjusted click.Choice type for BaseEnumOptions, which is based on Enum.
    """

    def __init__(self, enum_options, case_sensitive=True):
        assert issubclass(enum_options, BaseEnumOptions)
        self.base_option = enum_options
        super().__init__(self.base_option.list_names(), case_sensitive)

    def convert(self, value, param, ctx):
        # Click may call convert() on values that are already converted (e.g. defaults)
        if isinstance(value, self.base_option):
            return value

        # Exact match
        if value in self.choices:
            return self.base_option[value]

        # Match through normalization and case sensitivity:
        # first do token_normalize_func, then lowercase.
        # Preserve the original `value` to produce an accurate message in `self.fail`.
        normed_value = value
        normed_choices = self.choices

        if ctx is not None and ctx.token_normalize_func is not None:
            normed_value = ctx.token_normalize_func(value)
            normed_choices = [ctx.token_normalize_func(choice) for choice in self.choices]

        if not self.case_sensitive:
            normed_value = normed_value.lower()
            normed_choices = [choice.lower() for choice in normed_choices]

        if normed_value in normed_choices:
            # Map the normalized match back to the original member name; the normalized
            # string itself (e.g. lowercased) is not necessarily a valid enum key
            original_choice = self.choices[list(normed_choices).index(normed_value)]
            return self.base_option[original_choice]

        self.fail(
            "invalid choice: %s. (choose from %s)" % (value, ", ".join(self.choices)), param, ctx
        )

--------------------------------------------------------------------------------
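`CosineTempDecay` in `utils/utils.py` holds the temperature at `start_temp` until `rel_decay_start * t_max` and then follows a half cosine from `start_temp` down to `end_temp` at `t_max`. The sketch below re-implements the same formula with the standard library only, so the shape of the schedule can be checked without torch or numpy. The function name `cosine_temp` is hypothetical and not part of the repository.

```python
import math


def cosine_temp(t, t_max, temp_range=(20.0, 2.0), rel_decay_start=0.0):
    """Hypothetical stdlib mirror of CosineTempDecay.__call__ (same formula)."""
    start_temp, end_temp = temp_range
    decay_start = rel_decay_start * t_max
    if t < decay_start:
        # Constant phase before the decay begins
        return start_temp
    # Fraction of the decay phase that has elapsed: 0 at decay_start, 1 at t_max
    rel_t = (t - decay_start) / (t_max - decay_start)
    return end_temp + 0.5 * (start_temp - end_temp) * (1 + math.cos(rel_t * math.pi))


# Decay starts after 20% of a 100-step run
schedule = [cosine_temp(t, 100, rel_decay_start=0.2) for t in (0, 20, 60, 100)]
# -> [20.0, 20.0, 11.0, 2.0]
```

Halfway through the decay phase the temperature equals the mean of the two endpoints (11.0 for the default `temp_range=(20.0, 2.0)`), because the cosine term is zero there.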
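`split_dict` pulls the keys listed in `include` out of a flat options dictionary and optionally strips a `<prefix>_` from them, returning the rest untouched. The following stdlib-only sketch mirrors that logic but returns a plain `dict` instead of a `DotDict`; the function name `split_options` and the option keys in the example are hypothetical illustrations, not names used by the repository.

```python
def split_options(src, include=(), remove_prefix=""):
    """Hypothetical plain-dict mirror of split_dict; returns (selected, remainder)."""
    selected = {}
    for arg in include:
        # Like split_dict, remove only the first occurrence of "<prefix>_"
        key = arg.replace(f"{remove_prefix}_", "", 1) if remove_prefix else arg
        selected[key] = src[arg]
    # Everything not explicitly selected is passed through unchanged
    remainder = {k: v for k, v in src.items() if k not in include}
    return selected, remainder


options = {"qparams_n_bits": 8, "qparams_per_channel": True, "lr": 1e-3}
qparams, rest = split_options(
    options, include=("qparams_n_bits", "qparams_per_channel"), remove_prefix="qparams"
)
# qparams -> {"n_bits": 8, "per_channel": True}; rest -> {"lr": 0.001}
```

Because the prefix is removed with `str.replace(..., 1)`, the first match is stripped wherever it occurs in the key, not only at the start.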