├── .gitignore
├── LICENSE
├── README.md
├── compute_quant_error.py
├── image_net.py
├── models
├── __init__.py
├── mobilenet_v2.py
├── mobilenet_v2_quantized.py
└── resnet_quantized.py
├── quantization
├── __init__.py
├── autoquant_utils.py
├── base_quantized_classes.py
├── base_quantized_model.py
├── hijacker.py
├── quant_error_estimator.py
├── quantization_manager.py
├── quantized_folded_bn.py
├── quantizers
│ ├── __init__.py
│ ├── base_quantizers.py
│ ├── fp8_quantizer.py
│ ├── rounding_utils.py
│ ├── uniform_quantizers.py
│ └── utils.py
├── range_estimators.py
└── utils.py
├── requirements.txt
└── utils
├── __init__.py
├── click_options.py
├── distributions.py
├── grid.py
├── imagenet_dataloaders.py
├── optimizer_utils.py
├── qat_utils.py
├── stopwatch.py
├── supervised_driver.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | __pycache__/
4 | experiments/
5 | Makefile
6 | resources/
7 | docker/
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022 Qualcomm Technologies, Inc.
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer:
8 |
9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 |
11 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
12 |
13 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FP8 Quantization: The Power of the Exponent
2 | This repository contains the implementation and experiments for the paper presented in
3 |
4 | **Andrey Kuzmin\*1, Mart van Baalen\*1, Yuwei Ren1,
5 | Markus Nagel1, Jorn Peters1, Tijmen Blankevoort1 "FP8 Quantization: The Power of the Exponent", NeurIPS
6 | 2022.** [[ArXiv]](https://arxiv.org/abs/2208.09225)
7 |
8 | *Equal contribution
9 | 1 Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.)
10 |
11 | You can use this code to recreate the results in the paper.
12 |
13 | ## Method and Results
14 |
15 | In this repository we share the code to reproduce analytical and experimental results on performance of FP8 format with different mantissa/exponent division versus INT8. The first part of the repository allows the user to reproduce
16 | analytical computations of SQNR for uniform, Gaussian, and Student's-t distibutions. Varying the mantissa/exponent bit-width division changes the trade-off between accurate representation of the data around mean of the distribution,
17 | and the ability to capture its tails. The more outliers are present in the data, the more exponent bits is useful to allocate for the best results. In the second part we provide the code to reproduce the post-training quantization (PTQ)
18 | results for MobileNetV2, and Resnet-18 pre-trained on ImageNet.
19 |
20 |
21 |
22 | ## How to install
23 | Make sure to have Python ≥3.8 (tested with Python 3.8.10) and
24 | ensure the latest version of `pip` (tested with 21.3.1):
25 | ```bash
26 | python3 -m venv env
27 | source env/bin/activate
28 | pip install --upgrade --no-deps pip
29 | ```
30 |
31 | Next, install PyTorch 1.11.0 with the appropriate CUDA version (tested with CUDA 10.0):
32 | ```bash
33 | pip install torch==1.11.0 torchvision==0.12.0
34 | ```
35 |
36 | Finally, install the remaining dependencies using pip:
37 | ```bash
38 | pip install -r requirements.txt
39 | ```
40 | ## Running experiments
41 | ### Analytical expected SQNR computations
42 | The main run file to compute the expected SQNR for different distributions using different formats is
43 | `compute_quant_error.py`. The script takes no input arguments and computes the SQNR for different distributions and formats:
44 | ```bash
45 | python compute_quant_error.py
46 | ```
47 | ### ImageNet experiments
48 | The main run file to reproduce the ImageNet experiments is `image_net.py`.
49 | It contains commands for validating models quantized with post-training quantization.
50 | You can see the full list of options for each command using `python image_net.py [COMMAND] --help`.
51 | ```bash
52 | Usage: image_net.py [OPTIONS] COMMAND [ARGS]...
53 |
54 | Options:
55 | --help Show this message and exit.
56 |
57 | Commands:
58 | validate-quantized
59 | ```
60 |
61 | To reproduce the experiments run:
62 | ```bash
63 | python image_net.py validate-quantized --images-dir
64 | --architecture --batch-size 64 --seed 10
65 | --model-dir # only needed for MobileNet-V2
66 | --n-bits 8 --cuda --load-type fp32 --quant-setup all --qmethod fp_quantizer --per-channel
67 | --fp8-mantissa-bits=5 --fp8-set-maxval --no-fp8-mse-include-mantissa-bits
68 | --weight-quant-method=current_minmax --act-quant-method=allminmax --num-est-batches=1
69 | ```
70 |
71 | where can be mobilenet_v2_quantized or resnet18_quantized.
72 | Please note that only MobileNet-V2 requires pre-trained weights that can be downloaded here (the tar file is used as it is without a need to untar):
73 | - [MobileNetV2](https://drive.google.com/open?id=1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR)
74 |
75 | ## Reference
76 | If you find our work useful, please cite
77 | ```
78 | @article{kuzmin2022fp8,
79 | title={FP8 Quantization: The Power of the Exponent},
80 | author={Kuzmin, Andrey and Van Baalen, Mart and Ren, Yuwei and Nagel, Markus and Peters, Jorn and Blankevoort, Tijmen},
81 | journal={arXiv preprint arXiv:2208.09225},
82 | year={2022}
83 | }
84 | ```
85 |
--------------------------------------------------------------------------------
/compute_quant_error.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from utils.distributions import ClippedGaussDistr, UniformDistr, ClippedStudentTDistr
8 | from quantization.quant_error_estimator import (
9 | compute_expected_quant_mse,
10 | compute_expected_dot_prod_mse,
11 | )
12 | from quantization.quantizers.fp8_quantizer import FPQuantizer
13 | from quantization.range_estimators import estimate_range_line_search
14 | from quantization.quantizers.uniform_quantizers import SymmetricUniformQuantizer
15 | from utils import seed_all
16 |
17 |
18 | def compute_quant_error(distr, n_bits=8, n_samples=5000000, seed=10):
19 | seed_all(seed)
20 | distr_sample = torch.tensor(distr.sample((n_samples,)))
21 | for exp_bits in [5, 4, 3, 2, 0]:
22 | mantissa_bits = n_bits - 1 - exp_bits
23 | if exp_bits > 0:
24 | quant = FPQuantizer(n_bits=8, mantissa_bits=mantissa_bits, set_maxval=True)
25 | elif exp_bits == 0:
26 | quant = SymmetricUniformQuantizer(n_bits=n_bits)
27 |
28 | (quant_range_min, quant_range_max) = estimate_range_line_search(distr_sample, quant)
29 | quant_expected_mse = compute_expected_quant_mse(
30 | distr, quant, quant_range_min, quant_range_max, n_samples
31 | )
32 | quant_sqnr = -10.0 * np.log10(quant_expected_mse)
33 |
34 | dot_prod_expected_mse = compute_expected_dot_prod_mse(
35 | distr,
36 | distr,
37 | quant,
38 | quant,
39 | quant_range_min,
40 | quant_range_max,
41 | quant_range_min,
42 | quant_range_max,
43 | )
44 |
45 | dot_prod_sqnr = -10.0 * np.log10(dot_prod_expected_mse)
46 |
47 | print(
48 | "FP8 {} E {} M Quantization: expected MSE {:.2e}".format(
49 | exp_bits, mantissa_bits, quant_expected_mse
50 | ),
51 | " SQNR ",
52 | "{:.2e}\n".format(quant_sqnr),
53 | "Dot product:".rjust(23),
54 | " expected MSE {:.2e}".format(dot_prod_expected_mse),
55 | " SQNR ",
56 | "{:.2e}".format(dot_prod_sqnr),
57 | )
58 |
59 |
60 | if __name__ == "__main__":
61 | distr_list = [
62 | UniformDistr(range_min=-1.0, range_max=1.0, params_dict={}),
63 | ClippedGaussDistr(params_dict={"mu": 0.0, "sigma": 1.0}, range_min=-10.0, range_max=10.0),
64 | ClippedStudentTDistr(params_dict={"nu": 8.0}, range_min=-100.0, range_max=100.0),
65 | ]
66 |
67 | for distr in distr_list:
68 | print("*" * 80)
69 | distr.print()
70 | compute_quant_error(distr)
71 |
--------------------------------------------------------------------------------
/image_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import logging
4 | import os
5 |
6 | import click
7 | from ignite.contrib.handlers import ProgressBar
8 | from ignite.engine import create_supervised_evaluator
9 | from ignite.metrics import Accuracy, TopKCategoricalAccuracy, Loss
10 | from torch.nn import CrossEntropyLoss
11 |
12 | from quantization.utils import pass_data_for_range_estimation
13 | from utils import DotDict
14 | from utils.click_options import (
15 | qat_options,
16 | quantization_options,
17 | fp8_options,
18 | quant_params_dict,
19 | base_options,
20 | )
21 | from utils.qat_utils import get_dataloaders_and_model, ReestimateBNStats
22 |
23 |
24 | class Config(DotDict):
25 | pass
26 |
27 |
28 | @click.group()
29 | def fp8_cmd_group():
30 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
31 |
32 |
33 | pass_config = click.make_pass_decorator(Config, ensure=True)
34 |
35 |
36 | @fp8_cmd_group.command()
37 | @pass_config
38 | @base_options
39 | @fp8_options
40 | @quantization_options
41 | @qat_options
42 | @click.option(
43 | "--load-type",
44 | type=click.Choice(["fp32", "quantized"]),
45 | default="quantized",
46 | help='Either "fp32", or "quantized". Specify weather to load a quantized or a FP ' "model.",
47 | )
48 | def validate_quantized(config, load_type):
49 | """
50 | function for running validation on pre-trained quantized models
51 | """
52 | print("Setting up network and data loaders")
53 | qparams = quant_params_dict(config)
54 |
55 | dataloaders, model = get_dataloaders_and_model(config=config, load_type=load_type, **qparams)
56 |
57 | if load_type == "fp32":
58 | # Estimate ranges using training data
59 | pass_data_for_range_estimation(
60 | loader=dataloaders.train_loader,
61 | model=model,
62 | act_quant=config.quant.act_quant,
63 | weight_quant=config.quant.weight_quant,
64 | max_num_batches=config.quant.num_est_batches,
65 | )
66 | # Ensure we have the desired quant state
67 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant)
68 |
69 | # Fix ranges
70 | model.fix_ranges()
71 |
72 | # Create evaluator
73 | loss_func = CrossEntropyLoss()
74 | metrics = {
75 | "top_1_accuracy": Accuracy(),
76 | "top_5_accuracy": TopKCategoricalAccuracy(),
77 | "loss": Loss(loss_func),
78 | }
79 |
80 | pbar = ProgressBar()
81 | evaluator = create_supervised_evaluator(
82 | model=model, metrics=metrics, device="cuda" if config.base.cuda else "cpu"
83 | )
84 | pbar.attach(evaluator)
85 | print("Model with the ranges estimated:\n{}".format(model))
86 |
87 | # BN Re-estimation
88 | if config.qat.reestimate_bn_stats:
89 | ReestimateBNStats(
90 | model, dataloaders.train_loader, num_batches=int(0.02 * len(dataloaders.train_loader))
91 | )(None)
92 |
93 | print("Start quantized validation")
94 | evaluator.run(dataloaders.val_loader)
95 | final_metrics = evaluator.state.metrics
96 | print(final_metrics)
97 |
98 |
99 | if __name__ == "__main__":
100 | fp8_cmd_group()
101 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from models.mobilenet_v2_quantized import mobilenetv2_quantized
6 | from models.resnet_quantized import resnet18_quantized, resnet50_quantized
7 | from utils import ClassEnumOptions, MethodMap
8 |
9 |
10 | class QuantArchitectures(ClassEnumOptions):
11 | mobilenet_v2_quantized = MethodMap(mobilenetv2_quantized)
12 | resnet18_quantized = MethodMap(resnet18_quantized)
13 | resnet50_quantized = MethodMap(resnet50_quantized)
14 |
--------------------------------------------------------------------------------
/models/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | # https://github.com/tonylins/pytorch-mobilenet-v2
6 |
7 | import math
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | __all__ = ["MobileNetV2"]
13 |
14 |
15 | def conv_bn(inp, oup, stride):
16 | return nn.Sequential(
17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)
18 | )
19 |
20 |
21 | def conv_1x1_bn(inp, oup):
22 | return nn.Sequential(
23 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)
24 | )
25 |
26 |
27 | class InvertedResidual(nn.Module):
28 | def __init__(self, inp, oup, stride, expand_ratio):
29 | super(InvertedResidual, self).__init__()
30 | self.stride = stride
31 | assert stride in [1, 2]
32 |
33 | hidden_dim = round(inp * expand_ratio)
34 | self.use_res_connect = self.stride == 1 and inp == oup
35 |
36 | if expand_ratio == 1:
37 | self.conv = nn.Sequential(
38 | # dw
39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
40 | nn.BatchNorm2d(hidden_dim),
41 | nn.ReLU6(inplace=True),
42 | # pw-linear
43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(oup),
45 | )
46 | else:
47 | self.conv = nn.Sequential(
48 | # pw
49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
50 | nn.BatchNorm2d(hidden_dim),
51 | nn.ReLU6(inplace=True),
52 | # dw
53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
54 | nn.BatchNorm2d(hidden_dim),
55 | nn.ReLU6(inplace=True),
56 | # pw-linear
57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
58 | nn.BatchNorm2d(oup),
59 | )
60 |
61 | def forward(self, x):
62 | if self.use_res_connect:
63 | return x + self.conv(x)
64 | else:
65 | return self.conv(x)
66 |
67 |
68 | class MobileNetV2(nn.Module):
69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.0, dropout=0.0):
70 | super().__init__()
71 | block = InvertedResidual
72 | input_channel = 32
73 | last_channel = 1280
74 | inverted_residual_setting = [
75 | # t, c, n, s
76 | [1, 16, 1, 1],
77 | [6, 24, 2, 2],
78 | [6, 32, 3, 2],
79 | [6, 64, 4, 2],
80 | [6, 96, 3, 1],
81 | [6, 160, 3, 2],
82 | [6, 320, 1, 1],
83 | ]
84 |
85 | # building first layer
86 | assert input_size % 32 == 0
87 | input_channel = int(input_channel * width_mult)
88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
89 | features = [conv_bn(3, input_channel, 2)]
90 | # building inverted residual blocks
91 | for t, c, n, s in inverted_residual_setting:
92 | output_channel = int(c * width_mult)
93 | for i in range(n):
94 | if i == 0:
95 | features.append(block(input_channel, output_channel, s, expand_ratio=t))
96 | else:
97 | features.append(block(input_channel, output_channel, 1, expand_ratio=t))
98 | input_channel = output_channel
99 | # building last several layers
100 | features.append(conv_1x1_bn(input_channel, self.last_channel))
101 | features.append(nn.AvgPool2d(input_size // 32))
102 | # make it nn.Sequential
103 | self.features = nn.Sequential(*features)
104 |
105 | # building classifier
106 | self.classifier = nn.Sequential(
107 | nn.Dropout(dropout),
108 | nn.Linear(self.last_channel, n_class),
109 | )
110 |
111 | self._initialize_weights()
112 |
113 | def forward(self, x):
114 | x = self.features(x)
115 | x = F.adaptive_avg_pool2d(x, 1).squeeze() # type: ignore[arg-type] # accepted slang
116 | x = self.classifier(x)
117 | return x
118 |
119 | def _initialize_weights(self):
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
124 | if m.bias is not None:
125 | m.bias.data.zero_()
126 | elif isinstance(m, nn.BatchNorm2d):
127 | m.weight.data.fill_(1)
128 | m.bias.data.zero_()
129 | elif isinstance(m, nn.Linear):
130 | n = m.weight.size(1)
131 | m.weight.data.normal_(0, 0.01)
132 | m.bias.data.zero_()
133 |
--------------------------------------------------------------------------------
/models/mobilenet_v2_quantized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import os
6 | import re
7 | import torch
8 | from collections import OrderedDict
9 | from models.mobilenet_v2 import MobileNetV2, InvertedResidual
10 | from quantization.autoquant_utils import quantize_sequential, Flattener, quantize_model, BNQConv
11 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
12 | from quantization.base_quantized_model import QuantizedModel
13 |
14 |
15 | class QuantizedInvertedResidual(QuantizedActivation):
16 | def __init__(self, inv_res_orig, **quant_params):
17 | super().__init__(**quant_params)
18 | self.use_res_connect = inv_res_orig.use_res_connect
19 | self.conv = quantize_sequential(inv_res_orig.conv, **quant_params)
20 |
21 | def forward(self, x):
22 | if self.use_res_connect:
23 | x = x + self.conv(x)
24 | return self.quantize_activations(x)
25 | else:
26 | return self.conv(x)
27 |
28 |
29 | class QuantizedMobileNetV2(QuantizedModel):
30 | def __init__(self, model_fp, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params):
31 | super().__init__(input_size)
32 | specials = {InvertedResidual: QuantizedInvertedResidual}
33 | # quantize and copy parts from original model
34 | quantize_input = quant_setup and quant_setup == "LSQ_paper"
35 | self.features = quantize_sequential(
36 | model_fp.features,
37 | tie_activation_quantizers=not quantize_input,
38 | specials=specials,
39 | **quant_params,
40 | )
41 |
42 | self.flattener = Flattener()
43 | self.classifier = quantize_model(model_fp.classifier, **quant_params)
44 |
45 | if quant_setup == "FP_logits":
46 | print("Do not quantize output of FC layer")
47 | self.classifier[1].activation_quantizer = FP32Acts()
48 | # self.classifier.activation_quantizer = FP32Acts() # no activation quantization of logits
49 | elif quant_setup == "fc4":
50 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
51 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4
52 | elif quant_setup == "fc4_dw8":
53 | print("\n\n### fc4_dw8 setup ###\n\n")
54 | # FC layer in 4 bits, depth-wise separable once in 8 bit
55 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
56 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4
57 | for name, module in self.named_modules():
58 | if isinstance(module, BNQConv) and module.groups == module.in_channels:
59 | module.weight_quantizer.quantizer.n_bits = 8
60 | print(f"Set layer {name} to 8 bits")
61 | elif quant_setup == "LSQ":
62 | print("Set quantization to LSQ (first+last layer in 8 bits)")
63 | # Weights of the first layer
64 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
65 | # The quantizer of the last conv_layer layer (input to avgpool with tied quantizers)
66 | self.features[-2][0].activation_quantizer.quantizer.n_bits = 8
67 | # Weights of the last layer
68 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8
69 | # no activation quantization of logits
70 | self.classifier[1].activation_quantizer = FP32Acts()
71 | elif quant_setup == "LSQ_paper":
72 | # Weights of the first layer
73 | self.features[0][0].activation_quantizer = FP32Acts()
74 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
75 | # Weights of the last layer
76 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8
77 | self.classifier[1].activation_quantizer.quantizer.n_bits = 8
78 | # Set all QuantizedActivations to FP32
79 | for layer in self.features.modules():
80 | if isinstance(layer, QuantizedActivation):
81 | layer.activation_quantizer = FP32Acts()
82 | elif quant_setup is not None and quant_setup != "all":
83 | raise ValueError(
84 | "Quantization setup '{}' not supported for MobilenetV2".format(quant_setup)
85 | )
86 |
87 | def forward(self, x):
88 | x = self.features(x)
89 | x = self.flattener(x)
90 | x = self.classifier(x)
91 |
92 | return x
93 |
94 |
95 | def mobilenetv2_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
96 | fp_model = MobileNetV2()
97 | if pretrained and load_type == "fp32":
98 | # Load model from pretrained FP32 weights
99 | assert os.path.exists(model_dir)
100 | print(f"Loading pretrained weights from {model_dir}")
101 | state_dict = torch.load(model_dir)
102 | fp_model.load_state_dict(state_dict)
103 | quant_model = QuantizedMobileNetV2(fp_model, **qparams)
104 | elif load_type == "quantized":
105 | # Load pretrained QuantizedModel
106 | print(f"Loading pretrained quantized model from {model_dir}")
107 | state_dict = torch.load(model_dir)
108 | quant_model = QuantizedMobileNetV2(fp_model, **qparams)
109 | quant_model.load_state_dict(state_dict, strict=False)
110 | else:
111 | raise ValueError("wrong load_type specified")
112 |
113 | return quant_model
114 |
--------------------------------------------------------------------------------
/models/resnet_quantized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 | from torch import nn
6 | from torchvision.models.resnet import BasicBlock, Bottleneck
7 | from torchvision.models import resnet18, resnet50
8 |
9 | from quantization.autoquant_utils import quantize_model, Flattener, QuantizedActivationWrapper
10 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
11 | from quantization.base_quantized_model import QuantizedModel
12 |
13 |
14 | class QuantizedBlock(QuantizedActivation):
15 | def __init__(self, block, **quant_params):
16 | super().__init__(**quant_params)
17 |
18 | if isinstance(block, Bottleneck):
19 | features = nn.Sequential(
20 | block.conv1,
21 | block.bn1,
22 | block.relu,
23 | block.conv2,
24 | block.bn2,
25 | block.relu,
26 | block.conv3,
27 | block.bn3,
28 | )
29 | elif isinstance(block, BasicBlock):
30 | features = nn.Sequential(block.conv1, block.bn1, block.relu, block.conv2, block.bn2)
31 |
32 | self.features = quantize_model(features, **quant_params)
33 | self.downsample = (
34 | quantize_model(block.downsample, **quant_params) if block.downsample else None
35 | )
36 |
37 | self.relu = block.relu
38 |
39 | def forward(self, x):
40 | residual = x if self.downsample is None else self.downsample(x)
41 | out = self.features(x)
42 |
43 | out += residual
44 | out = self.relu(out)
45 |
46 | return self.quantize_activations(out)
47 |
48 |
49 | class QuantizedResNet(QuantizedModel):
50 | def __init__(self, resnet, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params):
51 | super().__init__(input_size)
52 | specials = {BasicBlock: QuantizedBlock, Bottleneck: QuantizedBlock}
53 |
54 | if hasattr(resnet, "maxpool"):
55 | # ImageNet ResNet case
56 | features = nn.Sequential(
57 | resnet.conv1,
58 | resnet.bn1,
59 | resnet.relu,
60 | resnet.maxpool,
61 | resnet.layer1,
62 | resnet.layer2,
63 | resnet.layer3,
64 | resnet.layer4,
65 | )
66 | else:
67 | # Tiny ImageNet ResNet case
68 | features = nn.Sequential(
69 | resnet.conv1,
70 | resnet.bn1,
71 | resnet.relu,
72 | resnet.layer1,
73 | resnet.layer2,
74 | resnet.layer3,
75 | resnet.layer4,
76 | )
77 |
78 | self.features = quantize_model(features, specials=specials, **quant_params)
79 |
80 | if quant_setup and quant_setup == "LSQ_paper":
81 | # Keep avgpool intact as we quantize the input the last layer
82 | self.avgpool = resnet.avgpool
83 | else:
84 | self.avgpool = QuantizedActivationWrapper(
85 | resnet.avgpool,
86 | tie_activation_quantizers=True,
87 | input_quantizer=self.features[-1][-1].activation_quantizer,
88 | **quant_params,
89 | )
90 | self.flattener = Flattener()
91 | self.fc = quantize_model(resnet.fc, **quant_params)
92 |
93 | # Adapt to specific quantization setup
94 | if quant_setup == "LSQ":
95 | print("Set quantization to LSQ (first+last layer in 8 bits)")
96 | # Weights of the first layer
97 | self.features[0].weight_quantizer.quantizer.n_bits = 8
98 | # The quantizer of the residual (input to last layer)
99 | self.features[-1][-1].activation_quantizer.quantizer.n_bits = 8
100 | # Output of the last conv (input to last layer)
101 | self.features[-1][-1].features[-1].activation_quantizer.quantizer.n_bits = 8
102 | # Weights of the last layer
103 | self.fc.weight_quantizer.quantizer.n_bits = 8
104 | # no activation quantization of logits
105 | self.fc.activation_quantizer = FP32Acts()
106 | elif quant_setup == "LSQ_paper":
107 | # Weights of the first layer
108 | self.features[0].activation_quantizer = FP32Acts()
109 | self.features[0].weight_quantizer.quantizer.n_bits = 8
110 | # Weights of the last layer
111 | self.fc.activation_quantizer.quantizer.n_bits = 8
112 | self.fc.weight_quantizer.quantizer.n_bits = 8
113 | # Set all QuantizedActivations to FP32
114 | for layer in self.features.modules():
115 | if isinstance(layer, QuantizedActivation):
116 | layer.activation_quantizer = FP32Acts()
117 | elif quant_setup == "FP_logits":
118 | print("Do not quantize output of FC layer")
119 | self.fc.activation_quantizer = FP32Acts() # no activation quantization of logits
120 | elif quant_setup == "fc4":
121 | self.features[0].weight_quantizer.quantizer.n_bits = 8
122 | self.fc.weight_quantizer.quantizer.n_bits = 4
123 | elif quant_setup is not None and quant_setup != "all":
124 | raise ValueError("Quantization setup '{}' not supported for Resnet".format(quant_setup))
125 |
126 | def forward(self, x):
127 | x = self.features(x)
128 | x = self.avgpool(x)
129 | # x = x.view(x.size(0), -1)
130 | x = self.flattener(x)
131 | x = self.fc(x)
132 |
133 | return x
134 |
135 |
136 | def resnet18_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
137 | if load_type == "fp32":
138 | # Load model from pretrained FP32 weights
139 | fp_model = resnet18(pretrained=pretrained)
140 | quant_model = QuantizedResNet(fp_model, **qparams)
141 | elif load_type == "quantized":
142 | # Load pretrained QuantizedModel
143 | print(f"Loading pretrained quantized model from {model_dir}")
144 | state_dict = torch.load(model_dir)
145 | fp_model = resnet18()
146 | quant_model = QuantizedResNet(fp_model, **qparams)
147 | quant_model.load_state_dict(state_dict)
148 | else:
149 | raise ValueError("wrong load_type specified")
150 | return quant_model
151 |
152 |
153 | def resnet50_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
154 | if load_type == "fp32":
155 | # Load model from pretrained FP32 weights
156 | fp_model = resnet50(pretrained=pretrained)
157 | quant_model = QuantizedResNet(fp_model, **qparams)
158 | elif load_type == "quantized":
159 | # Load pretrained QuantizedModel
160 | print(f"Loading pretrained quantized model from {model_dir}")
161 | state_dict = torch.load(model_dir)
162 | fp_model = resnet50()
163 | quant_model = QuantizedResNet(fp_model, **qparams)
164 | quant_model.load_state_dict(state_dict)
165 | else:
166 | raise ValueError("wrong load_type specified")
167 | return quant_model
168 |
--------------------------------------------------------------------------------
/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from . import utils, autoquant_utils, quantized_folded_bn
6 |
--------------------------------------------------------------------------------
/quantization/autoquant_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import copy
6 | import warnings
7 |
8 | from torch import nn
9 | from torch.nn import functional as F
10 | from torch.nn.modules.conv import _ConvNd
11 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd
12 |
13 |
14 | from quantization.base_quantized_classes import QuantizedActivation, QuantizedModule
15 | from quantization.hijacker import QuantizationHijacker, activations_set
16 | from quantization.quantization_manager import QuantizationManager
17 | from quantization.quantized_folded_bn import BNFusedHijacker
18 |
19 |
20 | class QuantConv1d(QuantizationHijacker, nn.Conv1d):
21 | def run_forward(self, x, weight, bias, offsets=None):
22 | return F.conv1d(
23 | x.contiguous(),
24 | weight.contiguous(),
25 | bias=bias,
26 | stride=self.stride,
27 | padding=self.padding,
28 | dilation=self.dilation,
29 | groups=self.groups,
30 | )
31 |
32 |
33 | class QuantConv(QuantizationHijacker, nn.Conv2d):
34 | def run_forward(self, x, weight, bias, offsets=None):
35 | return F.conv2d(
36 | x.contiguous(),
37 | weight.contiguous(),
38 | bias=bias,
39 | stride=self.stride,
40 | padding=self.padding,
41 | dilation=self.dilation,
42 | groups=self.groups,
43 | )
44 |
45 |
46 | class QuantConvTransposeBase(QuantizationHijacker):
47 | def quantize_weights(self, weights):
48 | if self.per_channel_weights:
49 | # NOTE: ND tranpose conv weights are stored as (in_channels, out_channels, *)
50 | # instead of (out_channels, in_channels, *) for convs
51 | # and per-channel quantization should be applied to out channels
52 | # transposing before passing to quantizer is trick to avoid
53 | # changing logic in range estimators and quantizers
54 | weights = weights.transpose(1, 0).contiguous()
55 | weights = self.weight_quantizer(weights)
56 | if self.per_channel_weights:
57 | weights = weights.transpose(1, 0).contiguous()
58 | return weights
59 |
60 |
61 | class QuantConvTranspose1d(QuantConvTransposeBase, nn.ConvTranspose1d):
62 | def run_forward(self, x, weight, bias, offsets=None):
63 | return F.conv_transpose1d(
64 | x.contiguous(),
65 | weight.contiguous(),
66 | bias=bias,
67 | stride=self.stride,
68 | padding=self.padding,
69 | output_padding=self.output_padding,
70 | dilation=self.dilation,
71 | groups=self.groups,
72 | )
73 |
74 |
75 | class QuantConvTranspose(QuantConvTransposeBase, nn.ConvTranspose2d):
76 | def run_forward(self, x, weight, bias, offsets=None):
77 | return F.conv_transpose2d(
78 | x.contiguous(),
79 | weight.contiguous(),
80 | bias=bias,
81 | stride=self.stride,
82 | padding=self.padding,
83 | output_padding=self.output_padding,
84 | dilation=self.dilation,
85 | groups=self.groups,
86 | )
87 |
88 |
89 | class QuantLinear(QuantizationHijacker, nn.Linear):
90 | def run_forward(self, x, weight, bias, offsets=None):
91 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias)
92 |
93 |
94 | class BNQConv1d(BNFusedHijacker, nn.Conv1d):
95 | def run_forward(self, x, weight, bias, offsets=None):
96 | return F.conv1d(
97 | x.contiguous(),
98 | weight.contiguous(),
99 | bias=bias,
100 | stride=self.stride,
101 | padding=self.padding,
102 | dilation=self.dilation,
103 | groups=self.groups,
104 | )
105 |
106 |
107 | class BNQConv(BNFusedHijacker, nn.Conv2d):
108 | def run_forward(self, x, weight, bias, offsets=None):
109 | return F.conv2d(
110 | x.contiguous(),
111 | weight.contiguous(),
112 | bias=bias,
113 | stride=self.stride,
114 | padding=self.padding,
115 | dilation=self.dilation,
116 | groups=self.groups,
117 | )
118 |
119 |
120 | class BNQLinear(BNFusedHijacker, nn.Linear):
121 | def run_forward(self, x, weight, bias, offsets=None):
122 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias)
123 |
124 |
125 | class QuantizedActivationWrapper(QuantizedActivation):
126 | """
127 | Wraps over a layer and quantized the activation.
128 | It also allow for tying the input and output quantizer which is helpful
129 | for layers such Average Pooling
130 | """
131 |
132 | def __init__(
133 | self,
134 | layer,
135 | tie_activation_quantizers=False,
136 | input_quantizer: QuantizationManager = None,
137 | *args,
138 | **kwargs,
139 | ):
140 | super().__init__(*args, **kwargs)
141 | self.tie_activation_quantizers = tie_activation_quantizers
142 | if input_quantizer:
143 | assert isinstance(input_quantizer, QuantizationManager)
144 | self.activation_quantizer = input_quantizer
145 | self.layer = layer
146 |
147 | def quantize_activations_no_range_update(self, x):
148 | if self._quant_a:
149 | return self.activation_quantizer.quantizer(x)
150 | else:
151 | return x
152 |
153 | def forward(self, x):
154 | x = self.layer(x)
155 | if self.tie_activation_quantizers:
156 | # The input activation quantizer is used to quantize the activation
157 | # but without updating the quantization range
158 | return self.quantize_activations_no_range_update(x)
159 | else:
160 | return self.quantize_activations(x)
161 |
162 | def extra_repr(self):
163 | return f"tie_activation_quantizers={self.tie_activation_quantizers}"
164 |
165 |
166 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm):
167 | def run_forward(self, x, weight, bias, offsets=None):
168 | return F.layer_norm(
169 | input=x.contiguous(),
170 | normalized_shape=self.normalized_shape,
171 | weight=weight.contiguous(),
172 | bias=bias.contiguous(),
173 | eps=self.eps,
174 | )
175 |
176 |
177 | class Flattener(nn.Module):
178 | def forward(self, x):
179 | return x.view(x.shape[0], -1)
180 |
181 |
182 | # Non BN Quant Modules Map
183 | non_bn_module_map = {
184 | nn.Conv1d: QuantConv1d,
185 | nn.Conv2d: QuantConv,
186 | nn.ConvTranspose1d: QuantConvTranspose1d,
187 | nn.ConvTranspose2d: QuantConvTranspose,
188 | nn.Linear: QuantLinear,
189 | nn.LayerNorm: QuantLayerNorm,
190 | }
191 |
192 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd)
193 | # BN Quant Modules Map
194 | bn_module_map = {nn.Conv1d: BNQConv1d, nn.Conv2d: BNQConv, nn.Linear: BNQLinear}
195 |
196 | quant_conv_modules = (QuantConv1d, QuantConv, BNQConv1d, BNQConv)
197 |
198 |
199 | def next_bn(module, i):
200 | return len(module) > i + 1 and isinstance(module[i + 1], (nn.BatchNorm2d, nn.BatchNorm1d))
201 |
202 |
203 | def get_act(module, i):
204 | # Case 1: conv + act
205 | if len(module) - i > 1 and isinstance(module[i + 1], tuple(activations_set)):
206 | return module[i + 1], i + 1
207 |
208 | # Case 2: conv + bn + act
209 | if (
210 | len(module) - i > 2
211 | and next_bn(module, i)
212 | and isinstance(module[i + 2], tuple(activations_set))
213 | ):
214 | return module[i + 2], i + 2
215 |
216 | # Case 3: conv + bn + X -> return false
217 | # Case 4: conv + X -> return false
218 | return None, None
219 |
220 |
221 | def get_conv_args(module):
222 | args = dict(
223 | in_channels=module.in_channels,
224 | out_channels=module.out_channels,
225 | kernel_size=module.kernel_size,
226 | stride=module.stride,
227 | padding=module.padding,
228 | dilation=module.dilation,
229 | groups=module.groups,
230 | bias=module.bias is not None,
231 | )
232 | if isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d)):
233 | args["output_padding"] = module.output_padding
234 | return args
235 |
236 |
237 | def get_linear_args(module):
238 | args = dict(
239 | in_features=module.in_features,
240 | out_features=module.out_features,
241 | bias=module.bias is not None,
242 | )
243 | return args
244 |
245 |
246 | def get_layernorm_args(module):
247 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps)
248 | return args
249 |
250 |
251 | def get_module_args(mod, act):
252 | if isinstance(mod, _ConvNd):
253 | kwargs = get_conv_args(mod)
254 | elif isinstance(mod, nn.Linear):
255 | kwargs = get_linear_args(mod)
256 | elif isinstance(mod, nn.LayerNorm):
257 | kwargs = get_layernorm_args(mod)
258 | else:
259 | raise ValueError
260 |
261 | kwargs["activation"] = act
262 |
263 | return kwargs
264 |
265 |
266 | def fold_bn(module, i, **quant_params):
267 | bn = next_bn(module, i)
268 | act, act_idx = get_act(module, i)
269 | modmap = bn_module_map if bn else non_bn_module_map
270 | modtype = modmap[type(module[i])]
271 |
272 | kwargs = get_module_args(module[i], act)
273 | new_module = modtype(**kwargs, **quant_params)
274 | new_module.weight.data = module[i].weight.data.clone()
275 |
276 | if bn:
277 | new_module.gamma.data = module[i + 1].weight.data.clone()
278 | new_module.beta.data = module[i + 1].bias.data.clone()
279 | new_module.running_mean.data = module[i + 1].running_mean.data.clone()
280 | new_module.running_var.data = module[i + 1].running_var.data.clone()
281 | if module[i].bias is not None:
282 | new_module.running_mean.data -= module[i].bias.data
283 | print("Warning: bias in conv/linear before batch normalization.")
284 | new_module.epsilon = module[i + 1].eps
285 |
286 | elif module[i].bias is not None:
287 | new_module.bias.data = module[i].bias.data.clone()
288 |
289 | return new_module, i + int(bool(act)) + int(bn) + 1
290 |
291 |
292 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params):
293 | specials = specials or dict()
294 |
295 | i = 0
296 | quant_modules = []
297 | while i < len(model):
298 | if isinstance(model[i], QuantizedModule):
299 | quant_modules.append(model[i])
300 | elif type(model[i]) in non_bn_module_map:
301 | new_module, new_i = fold_bn(model, i, **quant_params)
302 | quant_modules.append(new_module)
303 | i = new_i
304 | continue
305 |
306 | elif type(model[i]) in specials:
307 | quant_modules.append(specials[type(model[i])](model[i], **quant_params))
308 |
309 | elif isinstance(model[i], non_param_modules):
310 | # Check for last quantizer
311 | input_quantizer = None
312 | if quant_modules and isinstance(quant_modules[-1], QuantizedModule):
313 | last_layer = quant_modules[-1]
314 | input_quantizer = quant_modules[-1].activation_quantizer
315 | elif (
316 | quant_modules
317 | and isinstance(quant_modules[-1], nn.Sequential)
318 | and isinstance(quant_modules[-1][-1], QuantizedModule)
319 | ):
320 | last_layer = quant_modules[-1][-1]
321 | input_quantizer = quant_modules[-1][-1].activation_quantizer
322 |
323 | if input_quantizer and tie_activation_quantizers:
324 | # If input quantizer is found the tie input/output act quantizers
325 | print(
326 | f"Tying input quantizer {i-1}^th layer of type {type(last_layer)} to the "
327 | f"quantized {type(model[i])} following it"
328 | )
329 | quant_modules.append(
330 | QuantizedActivationWrapper(
331 | model[i],
332 | tie_activation_quantizers=tie_activation_quantizers,
333 | input_quantizer=input_quantizer,
334 | **quant_params,
335 | )
336 | )
337 | else:
338 | # Input quantizer not found
339 | quant_modules.append(QuantizedActivationWrapper(model[i], **quant_params))
340 | if tie_activation_quantizers:
341 | warnings.warn("Input quantizer not found, so we do not tie quantizers")
342 | else:
343 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params))
344 | i += 1
345 | return nn.Sequential(*quant_modules)
346 |
347 |
348 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params):
349 | specials = specials or dict()
350 |
351 | if isinstance(model, nn.Sequential):
352 | quant_model = quantize_sequential(
353 | model, specials, tie_activation_quantizers, **quant_params
354 | )
355 |
356 | elif type(model) in specials:
357 | quant_model = specials[type(model)](model, **quant_params)
358 |
359 | elif isinstance(model, non_param_modules):
360 | quant_model = QuantizedActivationWrapper(model, **quant_params)
361 |
362 | elif type(model) in non_bn_module_map:
363 | # If we do isinstance() then we might run into issues with modules that inherit from
364 | # one of these classes, for whatever reason
365 | modtype = non_bn_module_map[type(model)]
366 | kwargs = get_module_args(model, None)
367 | quant_model = modtype(**kwargs, **quant_params)
368 |
369 | quant_model.weight.data = model.weight.data
370 | if getattr(model, "bias", None) is not None:
371 | quant_model.bias.data = model.bias.data
372 |
373 | else:
374 | # Unknown type, try to quantize all child modules
375 | quant_model = copy.deepcopy(model)
376 | for name, module in quant_model._modules.items():
377 | new_model = quantize_model(module, specials=specials, **quant_params)
378 | if new_model is not None:
379 | setattr(quant_model, name, new_model)
380 |
381 | return quant_model
382 |
--------------------------------------------------------------------------------
/quantization/base_quantized_classes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from quantization.quantization_manager import QuantizationManager
9 | from quantization.quantizers import QuantizerBase, AsymmetricUniformQuantizer
10 | from quantization.range_estimators import (
11 | RangeEstimatorBase,
12 | CurrentMinMaxEstimator,
13 | RunningMinMaxEstimator,
14 | )
15 |
16 |
17 | def _set_layer_learn_ranges(layer):
18 | if isinstance(layer, QuantizationManager):
19 | if layer.quantizer.is_initialized:
20 | layer.learn_ranges()
21 |
22 |
23 | def _set_layer_fix_ranges(layer):
24 | if isinstance(layer, QuantizationManager):
25 | if layer.quantizer.is_initialized:
26 | layer.fix_ranges()
27 |
28 |
29 | def _set_layer_estimate_ranges(layer):
30 | if isinstance(layer, QuantizationManager):
31 | layer.estimate_ranges()
32 |
33 |
34 | def _set_layer_estimate_ranges_train(layer):
35 | if isinstance(layer, QuantizationManager):
36 | if layer.quantizer.is_initialized:
37 | layer.estimate_ranges_train()
38 |
39 |
40 | class QuantizedModule(nn.Module):
41 | """
42 | Parent class for a quantized module. It adds the basic functionality of switching the module
43 | between quantized and full precision mode. It also defines the cached parameters and handles
44 | the reset of the cache properly.
45 | """
46 |
47 | def __init__(
48 | self,
49 | *args,
50 | method: QuantizerBase = AsymmetricUniformQuantizer,
51 | act_method=None,
52 | weight_range_method: RangeEstimatorBase = CurrentMinMaxEstimator,
53 | act_range_method: RangeEstimatorBase = RunningMinMaxEstimator,
54 | n_bits=8,
55 | n_bits_act=None,
56 | per_channel_weights=False,
57 | percentile=None,
58 | weight_range_options=None,
59 | act_range_options=None,
60 | scale_domain="linear",
61 | act_quant_kwargs={},
62 | weight_quant_kwargs={},
63 | quantize_input=False,
64 | fp8_kwargs=None,
65 | **kwargs
66 | ):
67 | kwargs.pop("act_quant_dict", None)
68 |
69 | super().__init__(*args, **kwargs)
70 |
71 | self.method = method
72 | self.act_method = act_method or method
73 | self.n_bits = n_bits
74 | self.n_bits_act = n_bits_act or n_bits
75 | self.per_channel_weights = per_channel_weights
76 | self.percentile = percentile
77 | self.weight_range_method = weight_range_method
78 | self.weight_range_options = weight_range_options if weight_range_options else {}
79 | self.act_range_method = act_range_method
80 | self.act_range_options = act_range_options if act_range_options else {}
81 | self.scale_domain = scale_domain
82 | self.quantize_input = quantize_input
83 | self.fp8_kwargs = fp8_kwargs or {}
84 |
85 | self.quant_params = None
86 | self.register_buffer("_quant_w", torch.BoolTensor([False]))
87 | self.register_buffer("_quant_a", torch.BoolTensor([False]))
88 |
89 | self.act_qparams = dict(
90 | n_bits=self.n_bits_act,
91 | scale_domain=self.scale_domain,
92 | **act_quant_kwargs,
93 | **self.fp8_kwargs
94 | )
95 | self.weight_qparams = dict(
96 | n_bits=self.n_bits,
97 | scale_domain=self.scale_domain,
98 | **weight_quant_kwargs,
99 | **self.fp8_kwargs
100 | )
101 |
102 | def quantized_weights(self):
103 | self._quant_w = torch.BoolTensor([True])
104 |
105 | def full_precision_weights(self):
106 | self._quant_w = torch.BoolTensor([False])
107 |
108 | def quantized_acts(self):
109 | self._quant_a = torch.BoolTensor([True])
110 |
111 | def full_precision_acts(self):
112 | self._quant_a = torch.BoolTensor([False])
113 |
114 | def quantized(self):
115 | self.quantized_weights()
116 | self.quantized_acts()
117 |
118 | def full_precision(self):
119 | self.full_precision_weights()
120 | self.full_precision_acts()
121 |
122 | def get_quantizer_status(self):
123 | return dict(quant_a=self._quant_a.item(), quant_w=self._quant_w.item())
124 |
125 | def set_quantizer_status(self, quantizer_status):
126 | if quantizer_status["quant_a"]:
127 | self.quantized_acts()
128 | else:
129 | self.full_precision_acts()
130 |
131 | if quantizer_status["quant_w"]:
132 | self.quantized_weights()
133 | else:
134 | self.full_precision_weights()
135 |
136 | def learn_ranges(self):
137 | self.apply(_set_layer_learn_ranges)
138 |
139 | def fix_ranges(self):
140 | self.apply(_set_layer_fix_ranges)
141 |
142 | def estimate_ranges(self):
143 | self.apply(_set_layer_estimate_ranges)
144 |
145 | def estimate_ranges_train(self):
146 | self.apply(_set_layer_estimate_ranges_train)
147 |
148 | def extra_repr(self):
149 | quant_state = "weight_quant={}, act_quant={}".format(
150 | self._quant_w.item(), self._quant_a.item()
151 | )
152 | parent_repr = super().extra_repr()
153 | return "{},\n{}".format(parent_repr, quant_state) if parent_repr else quant_state
154 |
155 |
156 | class QuantizedActivation(QuantizedModule):
157 | def __init__(self, *args, **kwargs):
158 | super().__init__(*args, **kwargs)
159 | self.activation_quantizer = QuantizationManager(
160 | qmethod=self.act_method,
161 | qparams=self.act_qparams,
162 | init=self.act_range_method,
163 | range_estim_params=self.act_range_options,
164 | )
165 |
166 | def quantize_activations(self, x):
167 | if self._quant_a:
168 | return self.activation_quantizer(x)
169 | else:
170 | return x
171 |
172 | def forward(self, x):
173 | return self.quantize_activations(x)
174 |
175 |
176 | class FP32Acts(nn.Module):
177 | def forward(self, x):
178 | return x
179 |
180 | def reset_ranges(self):
181 | pass
182 |
--------------------------------------------------------------------------------
/quantization/base_quantized_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | from typing import Union, Dict
5 |
6 | import torch
7 | from torch import nn, Tensor
8 |
9 | from quantization.base_quantized_classes import (
10 | QuantizedModule,
11 | _set_layer_estimate_ranges,
12 | _set_layer_estimate_ranges_train,
13 | _set_layer_learn_ranges,
14 | _set_layer_fix_ranges,
15 | )
16 | from quantization.quantizers import QuantizerBase
17 |
18 |
19 | class QuantizedModel(nn.Module):
20 | """
21 | Parent class for a quantized model. This allows you to have convenience functions to put the
22 | whole model into quantization or full precision.
23 | """
24 |
25 | def __init__(self, input_size=(1, 3, 224, 224)):
26 | """
27 | Parameters
28 | ----------
29 | input_size: Tuple with the input dimension for the model (including batch dimension)
30 | """
31 | super().__init__()
32 | self.input_size = input_size
33 |
34 | def load_state_dict(
35 | self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True
36 | ):
37 | """
38 | This function overwrites the load_state_dict of nn.Module to ensure that quantization
39 | parameters are loaded correctly for quantized model.
40 |
41 | """
42 | quant_state_dict = {
43 | k: v for k, v in state_dict.items() if k.endswith("_quant_a") or k.endswith("_quant_w")
44 | }
45 |
46 | if quant_state_dict:
47 | super().load_state_dict(quant_state_dict, strict=False)
48 | else:
49 | raise ValueError(
50 | "The quantization states of activations or weights should be "
51 | "included in the state dict "
52 | )
53 | # Pass dummy data through quantized model to ensure all quantization parameters are
54 | # initialized with the correct dimensions (None tensors will lead to issues in state dict
55 | # loading)
56 | device = next(self.parameters()).device
57 | dummy_input = torch.rand(*self.input_size, device=device)
58 | with torch.no_grad():
59 | self.forward(dummy_input)
60 |
61 | # Load state dict
62 | super().load_state_dict(state_dict, strict)
63 |
64 | def quantized_weights(self):
65 | def _fn(layer):
66 | if isinstance(layer, QuantizedModule):
67 | layer.quantized_weights()
68 |
69 | self.apply(_fn)
70 |
71 | def full_precision_weights(self):
72 | def _fn(layer):
73 | if isinstance(layer, QuantizedModule):
74 | layer.full_precision_weights()
75 |
76 | self.apply(_fn)
77 |
78 | def quantized_acts(self):
79 | def _fn(layer):
80 | if isinstance(layer, QuantizedModule):
81 | layer.quantized_acts()
82 |
83 | self.apply(_fn)
84 |
85 | def full_precision_acts(self):
86 | def _fn(layer):
87 | if isinstance(layer, QuantizedModule):
88 | layer.full_precision_acts()
89 |
90 | self.apply(_fn)
91 |
92 | def quantized(self):
93 | def _fn(layer):
94 | if isinstance(layer, QuantizedModule):
95 | layer.quantized()
96 |
97 | self.apply(_fn)
98 |
99 | def full_precision(self):
100 | def _fn(layer):
101 | if isinstance(layer, QuantizedModule):
102 | layer.full_precision()
103 |
104 | self.apply(_fn)
105 |
106 | def estimate_ranges(self):
107 | self.apply(_set_layer_estimate_ranges)
108 |
109 | def estimate_ranges_train(self):
110 | self.apply(_set_layer_estimate_ranges_train)
111 |
112 | def set_quant_state(self, weight_quant, act_quant):
113 | if act_quant:
114 | self.quantized_acts()
115 | else:
116 | self.full_precision_acts()
117 |
118 | if weight_quant:
119 | self.quantized_weights()
120 | else:
121 | self.full_precision_weights()
122 |
123 | def grad_scaling(self, grad_scaling=True):
124 | def _fn(module):
125 | if isinstance(module, QuantizerBase):
126 | module.grad_scaling = grad_scaling
127 |
128 | self.apply(_fn)
129 | # Methods for switching quantizer quantization states
130 |
131 | def learn_ranges(self):
132 | self.apply(_set_layer_learn_ranges)
133 |
134 | def fix_ranges(self):
135 | self.apply(_set_layer_fix_ranges)
136 |
--------------------------------------------------------------------------------
/quantization/hijacker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import copy
6 |
7 | from timm.models.layers.activations import Swish, HardSwish, HardSigmoid
8 | from timm.models.layers.activations_me import SwishMe, HardSwishMe, HardSigmoidMe
9 | from torch import nn
10 |
11 | from quantization.base_quantized_classes import QuantizedModule
12 | from quantization.quantization_manager import QuantizationManager
13 | from quantization.range_estimators import RangeEstimators
14 |
15 | activations_set = [
16 | nn.ReLU,
17 | nn.ReLU6,
18 | nn.Hardtanh,
19 | nn.Sigmoid,
20 | nn.Tanh,
21 | nn.GELU,
22 | nn.PReLU,
23 | Swish,
24 | SwishMe,
25 | HardSwish,
26 | HardSwishMe,
27 | HardSigmoid,
28 | HardSigmoidMe,
29 | ]
30 |
31 |
32 | class QuantizationHijacker(QuantizedModule):
33 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and
34 | dequantization on the weights and output distributions.
35 |
36 | Usage:
37 | To make a quantized nn.Linear layer:
38 | class HijackedLinear(QuantizationHijacker, nn.Linear):
39 | pass
40 | """
41 |
42 | def __init__(self, *args, activation: nn.Module = None, **kwargs):
43 |
44 | super().__init__(*args, **kwargs)
45 | if activation:
46 | assert isinstance(activation, tuple(activations_set)), str(activation)
47 |
48 | self.activation_function = copy.deepcopy(activation) if activation else None
49 |
50 | self.activation_quantizer = QuantizationManager(
51 | qmethod=self.act_method,
52 | init=self.act_range_method,
53 | qparams=self.act_qparams,
54 | range_estim_params=self.act_range_options,
55 | )
56 |
57 | if self.weight_range_method == RangeEstimators.current_minmax:
58 | weight_init_params = dict(percentile=self.percentile)
59 | else:
60 | weight_init_params = self.weight_range_options
61 |
62 | self.weight_quantizer = QuantizationManager(
63 | qmethod=self.method,
64 | init=self.weight_range_method,
65 | per_channel=self.per_channel_weights,
66 | qparams=self.weight_qparams,
67 | range_estim_params=weight_init_params,
68 | )
69 |
70 | def forward(self, x, offsets=None):
71 | # Quantize input
72 | if self.quantize_input and self._quant_a:
73 | x = self.activation_quantizer(x)
74 |
75 | # Get quantized weight
76 | weight, bias = self.get_params()
77 | res = self.run_forward(x, weight, bias, offsets=offsets)
78 |
79 | # Apply fused activation function
80 | if self.activation_function is not None:
81 | res = self.activation_function(res)
82 |
83 | # Quantize output
84 | if not self.quantize_input and self._quant_a:
85 | res = self.activation_quantizer(res)
86 | return res
87 |
88 | def get_params(self):
89 |
90 | weight, bias = self.get_weight_bias()
91 |
92 | if self._quant_w:
93 | weight = self.quantize_weights(weight)
94 |
95 | return weight, bias
96 |
97 | def quantize_weights(self, weights):
98 | return self.weight_quantizer(weights)
99 |
100 | def get_weight_bias(self):
101 | bias = None
102 | if hasattr(self, "bias"):
103 | bias = self.bias
104 | return self.weight, bias
105 |
106 | def run_forward(self, x, weight, bias, offsets=None):
107 | # Performs the actual linear operation of the layer
108 | raise NotImplementedError()
109 |
110 | def extra_repr(self):
111 | activation = "input" if self.quantize_input else "output"
112 | return f"{super().extra_repr()}-{activation}"
113 |
--------------------------------------------------------------------------------
/quantization/quant_error_estimator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import torch
6 | import numpy as np
7 | from utils.grid import integrate_pdf_grid_func_analyt
8 | from quantization.quantizers.fp8_quantizer import FPQuantizer, generate_all_float_values_scaled
9 |
10 |
11 | def generate_integr_grid_piecewise(integr_discontin, num_intervals_smallest_bin):
12 | bin_widths = np.diff(integr_discontin)
13 | min_bin_width = np.min(bin_widths[bin_widths > 0.0])
14 | integr_min_step = min_bin_width / num_intervals_smallest_bin
15 | grid_list = []
16 | for i in range(len(integr_discontin) - 1):
17 | curr_interv_min = integr_discontin[i]
18 | curr_interv_max = integr_discontin[i + 1]
19 | curr_interv_width = curr_interv_max - curr_interv_min
20 |
21 | if curr_interv_width == 0.0:
22 | continue
23 | assert curr_interv_min < curr_interv_max
24 | curr_interv_n_subintervals = np.ceil(curr_interv_width / integr_min_step).astype("int")
25 | curr_interv_n_pts = curr_interv_n_subintervals + 1
26 | lspace = torch.linspace(curr_interv_min, curr_interv_max, curr_interv_n_pts)
27 | grid_list.append(lspace)
28 |
29 | grid_all = torch.cat(grid_list)
30 | grid_all_no_dup = torch.unique(grid_all)
31 |
32 | return grid_all_no_dup
33 |
34 |
35 | def estimate_rounding_error_analyt(distr, grid):
36 | err = integrate_pdf_grid_func_analyt(distr, grid, "integr_interv_p_sqr_r")
37 | return err
38 |
39 |
40 | def estimate_dot_prod_error_analyt(distr_x, grid_x, distr_y, grid_y):
41 | rounding_err_x = integrate_pdf_grid_func_analyt(distr_x, grid_x, "integr_interv_p_sqr_r")
42 | rounding_err_y = integrate_pdf_grid_func_analyt(distr_y, grid_y, "integr_interv_p_sqr_r")
43 | second_moment_x = distr_x.eval_non_central_second_moment()
44 | second_moment_y = distr_y.eval_non_central_second_moment()
45 | y_p_y_R_y_signed = integrate_pdf_grid_func_analyt(distr_y, grid_y, "integr_interv_x_p_signed_r")
46 | x_p_x_R_x_signed = integrate_pdf_grid_func_analyt(distr_x, grid_x, "integr_interv_x_p_signed_r")
47 |
48 | term_rounding_err_x_moment_y = rounding_err_x * second_moment_y
49 | term_rounding_err_y_moment_x = rounding_err_y * second_moment_x
50 | term_mixed_rounding_err = rounding_err_x * rounding_err_y
51 | term_mixed_R_signed = 2.0 * y_p_y_R_y_signed * x_p_x_R_x_signed
52 | term_rounding_err_x_R_y_signed = 2.0 * rounding_err_x * y_p_y_R_y_signed
53 | term_rounding_err_y_R_x_signed = 2.0 * rounding_err_y * x_p_x_R_x_signed
54 |
55 | total_sum = (
56 | term_rounding_err_x_moment_y
57 | + term_rounding_err_y_moment_x
58 | + term_mixed_R_signed
59 | + term_mixed_rounding_err
60 | + term_rounding_err_x_R_y_signed
61 | + term_rounding_err_y_R_x_signed
62 | )
63 |
64 | return total_sum
65 |
66 |
67 | def estimate_rounding_error_empirical(W, quantizer, range_min, range_max):
68 | quantizer.set_quant_range(range_min, range_max)
69 | W_int_quant = quantizer.forward(W)
70 |
71 | round_err_sqr_int_quant_emp = (W_int_quant - W) ** 2
72 | res = torch.mean(round_err_sqr_int_quant_emp.flatten()).item()
73 | return res
74 |
75 |
76 | def estimate_dot_prod_error_empirical(
77 | x, y, quantizer_x, quantizer_y, x_range_min, x_range_max, y_range_min, y_range_max
78 | ):
79 | quantizer_x.set_quant_range(x_range_min, x_range_max)
80 | quantizer_y.set_quant_range(y_range_min, y_range_max)
81 | x_quant = quantizer_x.forward(x)
82 | y_quant = quantizer_y.forward(y)
83 |
84 | scalar_prod_err_emp = (torch.mul(x, y) - torch.mul(x_quant, y_quant)) ** 2
85 | res = torch.mean(scalar_prod_err_emp.flatten()).item()
86 | return res
87 |
88 |
89 | def compute_expected_dot_prod_mse(
90 | distr_x,
91 | distr_y,
92 | quant_x,
93 | quant_y,
94 | quant_x_range_min,
95 | quant_x_range_max,
96 | quant_y_range_min,
97 | quant_y_range_max,
98 | num_samples=2000000,
99 | ):
100 |
101 | quant_x.set_quant_range(quant_x_range_min, quant_x_range_max)
102 | quant_y.set_quant_range(quant_y_range_min, quant_y_range_max)
103 | if isinstance(quant_x, FPQuantizer):
104 | grid_x = generate_all_float_values_scaled(
105 | quant_x.n_bits, quant_x.ebits, quant_x.default_bias, quant_x_range_max
106 | )
107 | else:
108 | grid_x = quant_x.generate_grid().numpy()
109 |
110 | if isinstance(quant_y, FPQuantizer):
111 | grid_y = generate_all_float_values_scaled(
112 | quant_y.n_bits, quant_y.ebits, quant_y.default_bias, quant_y_range_max
113 | )
114 | else:
115 | grid_y = quant_x.generate_grid().numpy()
116 |
117 | err_analyt = estimate_dot_prod_error_analyt(distr_x, grid_x, distr_y, grid_y)
118 | distr_x_sample = torch.tensor(distr_x.sample((num_samples,)))
119 | distr_y_sample = torch.tensor(distr_x.sample((num_samples,)))
120 | err_emp = estimate_dot_prod_error_empirical(
121 | distr_x_sample,
122 | distr_y_sample,
123 | quant_x,
124 | quant_y,
125 | quant_x_range_min,
126 | quant_x_range_max,
127 | quant_y_range_min,
128 | quant_y_range_max,
129 | )
130 |
131 | rel_err = np.abs((err_emp - err_analyt) / err_analyt)
132 | return err_analyt
133 |
134 |
135 | def compute_expected_quant_mse(distr, quant, quant_range_min, quant_range_max, num_samples):
136 | quant.set_quant_range(quant_range_min, quant_range_max)
137 | if isinstance(quant, FPQuantizer):
138 | grid = generate_all_float_values_scaled(
139 | quant.n_bits, quant.ebits, quant.default_bias, quant_range_max
140 | )
141 | else:
142 | grid = quant.generate_grid().numpy()
143 |
144 | err_analyt = estimate_rounding_error_analyt(distr, grid)
145 | distr_sample = torch.tensor(
146 | distr.sample(
147 | num_samples,
148 | )
149 | )
150 | err_emp = estimate_rounding_error_empirical(
151 | distr_sample, quant, quant_range_min, quant_range_max
152 | )
153 |
154 | rel_err = np.abs((err_emp - err_analyt) / err_analyt)
155 | if rel_err > 0.1:
156 | print(
157 | "Warning: the relative difference between the analytical and empirical error estimate is too high,\n"
158 | "please consider increasing the number of samples for the quantization range estimator."
159 | )
160 |
161 | return err_analyt
162 |
--------------------------------------------------------------------------------
/quantization/quantization_manager.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from enum import auto
6 |
7 | from torch import nn
8 | from quantization.quantizers import QuantizerBase
9 | from quantization.quantizers.utils import QuantizerNotInitializedError
10 | from quantization.range_estimators import RangeEstimators, RangeEstimatorBase
11 | from utils import BaseEnumOptions
12 |
13 | from quantization.quantizers.uniform_quantizers import (
14 | SymmetricUniformQuantizer,
15 | AsymmetricUniformQuantizer,
16 | )
17 | from quantization.quantizers.fp8_quantizer import FPQuantizer
18 |
19 | from utils import ClassEnumOptions, MethodMap
20 |
21 |
22 | class QMethods(ClassEnumOptions):
23 | symmetric_uniform = MethodMap(SymmetricUniformQuantizer)
24 | asymmetric_uniform = MethodMap(AsymmetricUniformQuantizer)
25 | fp_quantizer = MethodMap(FPQuantizer)
26 |
27 |
28 | class QuantizationManager(nn.Module):
29 | """Implementation of Quantization and Quantization Range Estimation
30 |
31 | Parameters
32 | ----------
33 | n_bits: int
34 | Number of bits for the quantization.
35 | qmethod: QMethods member (Enum)
36 | The quantization scheme to use, e.g. symmetric_uniform, asymmetric_uniform,
37 | qmn_uniform etc.
38 | init: RangeEstimators member (Enum)
39 | Initialization method for the grid from
40 | per_channel: bool
41 | If true, will use a separate quantization grid for each kernel/channel.
42 | x_min: float or PyTorch Tensor
43 | The minimum value which needs to be represented.
44 | x_max: float or PyTorch Tensor
45 | The maximum value which needs to be represented.
46 | qparams: kwargs
47 | dictionary of quantization parameters to passed to the quantizer instantiation
48 | range_estim_params: kwargs
49 | dictionary of parameters to passed to the range estimator instantiation
50 | """
51 |
52 | def __init__(
53 | self,
54 | qmethod: QuantizerBase = QMethods.symmetric_uniform.cls,
55 | init: RangeEstimatorBase = RangeEstimators.current_minmax.cls,
56 | per_channel=False,
57 | x_min=None,
58 | x_max=None,
59 | qparams=None,
60 | range_estim_params=None,
61 | ):
62 | super().__init__()
63 | self.state = Qstates.estimate_ranges
64 | self.qmethod = qmethod
65 | self.init = init
66 | self.per_channel = per_channel
67 | self.qparams = qparams if qparams else {}
68 | self.range_estim_params = range_estim_params if range_estim_params else {}
69 | self.range_estimator = None
70 |
71 | # define quantizer
72 | self.quantizer = self.qmethod(per_channel=self.per_channel, **qparams)
73 | self.quantizer.state = self.state
74 |
75 | # define range estimation method for quantizer initialisation
76 | if x_min is not None and x_max is not None:
77 | self.set_quant_range(x_min, x_max)
78 | self.fix_ranges()
79 | else:
80 | # set up the collector function to set the ranges
81 | self.range_estimator = self.init(
82 | per_channel=self.per_channel, quantizer=self.quantizer, **self.range_estim_params
83 | )
84 |
85 | @property
86 | def n_bits(self):
87 | return self.quantizer.n_bits
88 |
89 | def estimate_ranges(self):
90 | self.state = Qstates.estimate_ranges
91 | self.quantizer.state = self.state
92 |
93 | def fix_ranges(self):
94 | if self.quantizer.is_initialized:
95 | self.state = Qstates.fix_ranges
96 | self.quantizer.state = self.state
97 | else:
98 | raise QuantizerNotInitializedError()
99 |
100 | def learn_ranges(self):
101 | self.quantizer.make_range_trainable()
102 | self.state = Qstates.learn_ranges
103 | self.quantizer.state = self.state
104 |
105 | def estimate_ranges_train(self):
106 | self.state = Qstates.estimate_ranges_train
107 | self.quantizer.state = self.state
108 |
109 | def reset_ranges(self):
110 | self.range_estimator.reset()
111 | self.quantizer.reset()
112 | self.estimate_ranges()
113 |
114 | def forward(self, x):
115 | if self.state == Qstates.estimate_ranges or (
116 | self.state == Qstates.estimate_ranges_train and self.training
117 | ):
118 | # Note this can be per tensor or per channel
119 | cur_xmin, cur_xmax = self.range_estimator(x)
120 | self.set_quant_range(cur_xmin, cur_xmax)
121 |
122 | return self.quantizer(x)
123 |
124 | def set_quant_range(self, x_min, x_max):
125 | self.quantizer.set_quant_range(x_min, x_max)
126 |
127 | def extra_repr(self):
128 | return "state={}".format(self.state.name)
129 |
130 |
131 | class Qstates(BaseEnumOptions):
132 | estimate_ranges = auto() # ranges are updated in eval and train mode
133 | fix_ranges = auto() # quantization ranges are fixed for train and eval
134 | learn_ranges = auto() # quantization params are nn.Parameters
135 | estimate_ranges_train = auto() # quantization ranges are updated during train and fixed for
136 | # eval
137 |
--------------------------------------------------------------------------------
/quantization/quantized_folded_bn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch.nn.modules.conv import _ConvNd
8 |
9 | from quantization.hijacker import QuantizationHijacker
10 |
11 |
12 | class BNFusedHijacker(QuantizationHijacker):
13 | """Extension to the QuantizationHijacker that fuses batch normalization (BN) after a weight
14 | layer into a joined module. The parameters and the statistics of the BN layer remain in
15 | full-precision.
16 | """
17 |
18 | def __init__(self, *args, **kwargs):
19 | kwargs.pop("bias", None) # Bias will be learned by BN params
20 | super().__init__(*args, **kwargs, bias=False)
21 | bn_dim = self.get_bn_dim()
22 | self.register_buffer("running_mean", torch.zeros(bn_dim))
23 | self.register_buffer("running_var", torch.ones(bn_dim))
24 | self.momentum = kwargs.pop("momentum", 0.1)
25 | self.gamma = nn.Parameter(torch.ones(bn_dim))
26 | self.beta = nn.Parameter(torch.zeros(bn_dim))
27 | self.epsilon = kwargs.get("eps", 1e-5)
28 | self.bias = None
29 |
30 | def forward(self, x):
31 | # Quantize input
32 | if self.quantize_input and self._quant_a:
33 | x = self.activation_quantizer(x)
34 |
35 | # Get quantized weight
36 | weight, bias = self.get_params()
37 | res = self.run_forward(x, weight, bias)
38 |
39 | res = F.batch_norm(
40 | res,
41 | self.running_mean,
42 | self.running_var,
43 | self.gamma,
44 | self.beta,
45 | self.training,
46 | self.momentum,
47 | self.epsilon,
48 | )
49 | # Apply fused activation function
50 | if self.activation_function is not None:
51 | res = self.activation_function(res)
52 |
53 | # Quantize output
54 | if not self.quantize_input and self._quant_a:
55 | res = self.activation_quantizer(res)
56 | return res
57 |
58 | def get_bn_dim(self):
59 | if isinstance(self, nn.Linear):
60 | return self.out_features
61 | elif isinstance(self, _ConvNd):
62 | return self.out_channels
63 | else:
64 | msg = (
65 | f"Unsupported type used: {self}. Must be a linear or (transpose)-convolutional "
66 | f"nn.Module"
67 | )
68 | raise NotImplementedError(msg)
69 |
--------------------------------------------------------------------------------
/quantization/quantizers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from quantization.quantizers.base_quantizers import QuantizerBase
6 | from quantization.quantizers.fp8_quantizer import FPQuantizer
7 | from quantization.quantizers.uniform_quantizers import (
8 | AsymmetricUniformQuantizer,
9 | SymmetricUniformQuantizer,
10 | )
11 |
--------------------------------------------------------------------------------
/quantization/quantizers/base_quantizers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from torch import nn
6 |
7 |
8 | class QuantizerBase(nn.Module):
9 | def __init__(self, n_bits, per_channel=False, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.n_bits = n_bits
12 | self.per_channel = per_channel
13 | self.state = None
14 | self.x_min_fp32 = self.x_max_fp32 = None
15 |
16 | @property
17 | def is_initialized(self):
18 | raise NotImplementedError()
19 |
20 | @property
21 | def x_max(self):
22 | raise NotImplementedError()
23 |
24 | @property
25 | def symmetric(self):
26 | raise NotImplementedError()
27 |
28 | @property
29 | def x_min(self):
30 | raise NotImplementedError()
31 |
32 | def forward(self, x_float):
33 | raise NotImplementedError()
34 |
35 | def _adjust_params_per_channel(self, x):
36 | raise NotImplementedError()
37 |
38 | def set_quant_range(self, x_min, x_max):
39 | raise NotImplementedError()
40 |
41 | def extra_repr(self):
42 | return "n_bits={}, per_channel={}, is_initalized={}".format(
43 | self.n_bits, self.per_channel, self.is_initialized
44 | )
45 |
46 | def reset(self):
47 | self._delta = None
48 |
--------------------------------------------------------------------------------
/quantization/quantizers/fp8_quantizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 | import torch.nn as nn
6 | from quantization.quantizers.base_quantizers import QuantizerBase
7 | import numpy as np
8 | from itertools import product
9 | from torch.autograd import Function
10 | from quantization.quantizers.rounding_utils import round_ste_func
11 |
12 |
13 | def generate_all_values_fp(
14 | num_total_bits: int = 8, num_exponent_bits: int = 4, bias: int = 8
15 | ) -> list:
16 | num_fraction_bits = num_total_bits - 1 - num_exponent_bits
17 |
18 | all_values = []
19 | exp_lower = -bias
20 | for S in [-1.0, 1.0]:
21 | for E_str_iter in product(*[[0, 1]] * num_exponent_bits):
22 | for F_str_iter in product(*[[0, 1]] * num_fraction_bits):
23 | E_str = "".join(str(i) for i in E_str_iter)
24 | F_str = "".join(str(i) for i in F_str_iter)
25 |
26 | # encoded exponent
27 | E_enc = decode_binary_str(E_str)
28 | E_eff = E_enc - bias
29 | if E_eff == exp_lower:
30 | is_subnormal = 1
31 | else:
32 | is_subnormal = 0
33 |
34 | F_enc = decode_binary_str(F_str) * 2**-num_fraction_bits
35 | F_eff = F_enc + 1 - is_subnormal
36 |
37 | fp8_val = S * 2.0 ** (E_enc - bias + is_subnormal) * F_eff
38 | all_values.append(fp8_val)
39 | res = np.array(all_values)
40 | res = np.sort(res)
41 | return res
42 |
43 |
44 | def generate_all_float_values_scaled(num_total_bits, num_exp_bits, exp_bias, range_limit_fp):
45 | grid = generate_all_values_fp(num_total_bits, num_exp_bits, exp_bias)
46 | float_max_abs_val = np.max(np.abs(grid))
47 |
48 | float_scale = float_max_abs_val / range_limit_fp
49 | floats_all = grid / float_scale
50 | return floats_all
51 |
52 |
53 | def decode_float8(S, E, F, bias=16):
54 | sign = -2 * int(S) + 1
55 | exponent = int(E, 2) if E else 0
56 | # Normal FP8 : exponent > 0 : result = 2^(exponent-bias) * 1.F
57 | # Subnormal FP8: exponent == 0: result = 2^(-bias+1) * 0.F
58 | # Lowest quantization bin: 2^(-bias+1) * {0.0 ... 1 + (2^mantissa-1)/2^mantissa}
59 | # All other bins : 2^(exponent-bias) * {1.0 ... 1 + (2^mantissa-1)/2^mantissa}; exponent > 0
60 | A = int(exponent != 0)
61 | fraction = A + sum([2 ** -(i + 1) * int(a) for i, a in enumerate(F)])
62 | exponent += int(exponent == 0)
63 | return sign * fraction * 2.0 ** (exponent - bias)
64 |
65 |
66 | def i(x):
67 | return np.array([x]).astype(np.int32)
68 |
69 |
70 | def gen(n_bits, exponent_bits, bias):
71 | all_values = []
72 | for s in product(*[[0, 1]] * 1):
73 | for e in product(*[[0, 1]] * exponent_bits):
74 | for m in product(*[[0, 1]] * (n_bits - 1 - exponent_bits)):
75 | s = str(s[0])
76 | e = "".join(str(i) for i in e)
77 | m = "".join(str(i) for i in m)
78 | all_values.append(decode_float8(s, e, m, bias=bias))
79 | return sorted(all_values)
80 |
81 |
82 | def get_max_value(num_exponent_bits: int = 4, bias: int = 8):
83 | num_fraction_bits = 7 - num_exponent_bits
84 | scale = 2**-num_fraction_bits
85 | max_frac = 1 - scale
86 | max_value = 2 ** (2**num_exponent_bits - 1 - bias) * (1 + max_frac)
87 |
88 | return max_value
89 |
90 |
91 | def quantize_to_fp8_ste_MM(
92 | x_float: torch.Tensor,
93 | n_bits: int,
94 | maxval: torch.Tensor,
95 | num_mantissa_bits: torch.Tensor,
96 | sign_bits: int,
97 | ) -> torch.Tensor:
98 | """
99 | Simpler FP8 quantizer that exploits the fact that FP quantization is just INT quantization with
100 | scales that depend on the input.
101 |
102 | This allows to define FP8 quantization using STE rounding functions and thus learn the bias
103 |
104 | """
105 | M = torch.clamp(round_ste_func(num_mantissa_bits), 1, n_bits - sign_bits)
106 | E = n_bits - sign_bits - M
107 |
108 | if maxval.shape[0] != 1 and len(maxval.shape) != len(x_float.shape):
109 | maxval = maxval.view([-1] + [1] * (len(x_float.shape) - 1))
110 | bias = 2**E - torch.log2(maxval) + torch.log2(2 - 2 ** (-M)) - 1
111 |
112 | minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval)
113 | xc = torch.min(torch.max(x_float, minval), maxval)
114 |
115 | """
116 | 2 notes here:
117 | 1: Shifting by bias to ensure data is aligned to the scaled grid in case bias not in Z.
118 | Recall that implicitly bias := bias' - log2(alpha), where bias' in Z. If we assume
119 | alpha in (0.5, 1], then alpha contracts the grid, which is equivalent to translate the
120 | data 'to the right' relative to the grid, which is what the subtraction of log2(alpha)
121 | (which is negative) accomplishes.
122 | 2: Ideally this procedure doesn't affect gradients wrt the input (we want to use the STE).
123 | We can achieve this by detaching log2 of the (absolute) input.
124 |
125 | """
126 |
127 | # log_scales = torch.max((torch.floor(torch.log2(torch.abs(xc)) + bias)).detach(), 1.)
128 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(xc)) + bias)).detach(), 1.0)
129 |
130 | scales = 2.0 ** (log_scales - M - bias)
131 |
132 | result = round_ste_func(xc / scales) * scales
133 | return result
134 |
135 |
136 | class FP8QuantizerFunc(Function):
137 | @staticmethod
138 | def forward(ctx, x_float, bias, num_exponent_bits):
139 | return quantize_to_fp8_ste_MM(x_float, bias, num_exponent_bits)
140 |
141 | @staticmethod
142 | def backward(ctx, grad_output):
143 | return grad_output, None, None
144 |
145 |
146 | def decode_binary_str(F_str):
147 | F = sum([2 ** -(i + 1) * int(a) for i, a in enumerate(F_str)]) * 2 ** len(F_str)
148 | return F
149 |
150 |
151 | class FPQuantizer(QuantizerBase):
152 | """
153 | 8-bit Floating Point Quantizer
154 | """
155 |
156 | def __init__(
157 | self,
158 | *args,
159 | scale_domain=None,
160 | mantissa_bits=4,
161 | maxval=3,
162 | set_maxval=False,
163 | learn_maxval=False,
164 | learn_mantissa_bits=False,
165 | mse_include_mantissa_bits=True,
166 | allow_unsigned=False,
167 | **kwargs,
168 | ):
169 | super().__init__(*args, **kwargs)
170 |
171 | self.mantissa_bits = mantissa_bits
172 |
173 | self.ebits = self.n_bits - self.mantissa_bits - 1
174 | self.default_bias = 2 ** (self.ebits - 1)
175 |
176 | # assume signed, correct when range setting turns out to be unsigned
177 | default_maxval = (2 - 2 ** (-self.mantissa_bits)) * 2 ** (
178 | 2**self.ebits - 1 - self.default_bias
179 | )
180 |
181 | self.maxval = maxval if maxval is not None else default_maxval
182 |
183 | self.maxval = torch.Tensor([self.maxval])
184 | self.mantissa_bits = torch.Tensor([float(self.mantissa_bits)])
185 |
186 | self.set_maxval = set_maxval
187 | self.learning_maxval = learn_maxval
188 | self.learning_mantissa_bits = learn_mantissa_bits
189 | self.mse_include_mantissa_bits = mse_include_mantissa_bits
190 |
191 | self.allow_unsigned = allow_unsigned
192 | self.sign_bits = 1
193 |
194 | def forward(self, x_float):
195 | if self.maxval.device != x_float.device:
196 | self.maxval = self.maxval.to(x_float.device)
197 | if self.mantissa_bits.device != x_float.device:
198 | self.mantissa_bits = self.mantissa_bits.to(x_float.device)
199 |
200 | res = quantize_to_fp8_ste_MM(
201 | x_float, self.n_bits, self.maxval, self.mantissa_bits, self.sign_bits
202 | )
203 |
204 | ebits = self.n_bits - self.mantissa_bits - 1
205 | return res
206 |
207 | def is_initialized(self):
208 | return True
209 |
210 | def symmetric(self):
211 | return False
212 |
213 | def effective_bit_width(self):
214 | return None
215 |
216 | def _make_unsigned(self, x_min):
217 | if isinstance(x_min, torch.Tensor):
218 | return self.allow_unsigned and torch.all(x_min >= 0)
219 | else:
220 | return self.allow_unsigned and x_min >= 0
221 |
222 | def set_quant_range(self, x_min, x_max):
223 |
224 | if self._make_unsigned(x_min):
225 | self.sign_bits = 0
226 |
227 | if self.set_maxval:
228 | if not isinstance(x_max, torch.Tensor):
229 | x_max = torch.Tensor([x_max]).to(self.maxval.device)
230 | x_min = torch.Tensor([x_min]).to(self.maxval.device)
231 | if self.maxval.device != x_max.device:
232 | self.maxval = self.maxval.to(x_max.device)
233 | if self.mantissa_bits.device != x_max.device:
234 | self.mantissa_bits = self.mantissa_bits.to(x_max.device)
235 |
236 | mx = torch.abs(torch.max(torch.abs(x_min), x_max))
237 | self.maxval = mx
238 |
239 | if not isinstance(self.maxval, torch.Tensor) or len(self.maxval.shape) == 0:
240 | self.maxval = torch.Tensor([self.maxval])
241 |
242 | def make_range_trainable(self):
243 | if self.learning_maxval:
244 | self.learn_maxval()
245 | if self.learning_mantissa_bits:
246 | self.learn_mantissa_bits()
247 |
248 | def learn_maxval(self):
249 | self.learning_maxval = True
250 | self.maxval = torch.nn.Parameter(self.maxval)
251 |
252 | def learn_mantissa_bits(self):
253 | self.learning_mantissa_bits = True
254 | self.mantissa_bits = torch.nn.Parameter(self.mantissa_bits)
255 |
256 | def fix_ranges(self):
257 | if isinstance(self.maxval, nn.Parameter):
258 | self.parameter_to_fixed("maxval")
259 | if isinstance(self.mantissa_bits, nn.Parameter):
260 | self.parameter_to_fixed("mantissa_bits")
261 |
262 | def extra_repr(self):
263 | maxval = self.maxval
264 |
265 | M = torch.clamp(torch.round(self.mantissa_bits), 1, 7)
266 | E = 7 - M
267 | maxval = 2**E - torch.log2(self.maxval) + torch.log2(2 - 2 ** (-M)) - 1
268 | if maxval.shape[0] > 1:
269 | bstr = "[per_channel]"
270 | else:
271 | bstr = f"{maxval.item()}"
272 | return f"Exponent: {E.item()} bits; mode: ; bias: {bstr}"
273 |
--------------------------------------------------------------------------------
/quantization/quantizers/rounding_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from torch import nn
5 | import torch
6 | from torch.autograd import Function
7 |
8 | # Functional
9 | from utils import MethodMap, ClassEnumOptions
10 |
11 |
12 | class RoundStraightThrough(Function):
13 | @staticmethod
14 | def forward(ctx, x):
15 | return torch.round(x)
16 |
17 | @staticmethod
18 | def backward(ctx, output_grad):
19 | return output_grad
20 |
21 |
22 | class StochasticRoundSTE(Function):
23 | @staticmethod
24 | def forward(ctx, x):
25 | # Sample noise between [0, 1)
26 | noise = torch.rand_like(x)
27 | return torch.floor(x + noise)
28 |
29 | @staticmethod
30 | def backward(ctx, output_grad):
31 | return output_grad
32 |
33 |
34 | class ScaleGradient(Function):
35 | @staticmethod
36 | def forward(ctx, x, scale):
37 | ctx.scale = scale
38 | return x
39 |
40 | @staticmethod
41 | def backward(ctx, output_grad):
42 | return output_grad * ctx.scale, None
43 |
44 |
45 | class EWGSFunctional(Function):
46 | """
47 | x_in: float input
48 | scaling_factor: backward scaling factor
49 | x_out: discretized version of x_in within the range of [0,1]
50 | """
51 |
52 | @staticmethod
53 | def forward(ctx, x_in, scaling_factor):
54 | x_int = torch.round(x_in)
55 | ctx._scaling_factor = scaling_factor
56 | ctx.save_for_backward(x_in - x_int)
57 | return x_int
58 |
59 | @staticmethod
60 | def backward(ctx, g):
61 | diff = ctx.saved_tensors[0]
62 | delta = ctx._scaling_factor
63 | scale = 1 + delta * torch.sign(g) * diff
64 | return g * scale, None, None
65 |
66 |
67 | class StackSigmoidFunctional(Function):
68 | @staticmethod
69 | def forward(ctx, x, alpha):
70 | # Apply round to nearest in the forward pass
71 | ctx.save_for_backward(x, alpha)
72 | return torch.round(x)
73 |
74 | @staticmethod
75 | def backward(ctx, grad_output):
76 | x, alpha = ctx.saved_tensors
77 | sig_min = torch.sigmoid(alpha / 2)
78 | sig_scale = 1 - 2 * sig_min
79 | x_base = torch.floor(x).detach()
80 | x_rest = x - x_base - 0.5
81 | stacked_sigmoid_grad = (
82 | torch.sigmoid(x_rest * -alpha)
83 | * (1 - torch.sigmoid(x_rest * -alpha))
84 | * -alpha
85 | / sig_scale
86 | )
87 | return stacked_sigmoid_grad * grad_output, None
88 |
89 |
90 | # Parametrized modules
91 | class ParametrizedGradEstimatorBase(nn.Module):
92 | def __init__(self, *args, **kwargs):
93 | super().__init__()
94 | self._trainable = False
95 |
96 | def make_grad_params_trainable(self):
97 | self._trainable = True
98 | for name, buf in self.named_buffers(recurse=False):
99 | setattr(self, name, torch.nn.Parameter(buf))
100 |
101 | def make_grad_params_tensor(self):
102 | self._trainable = False
103 | for name, param in self.named_parameters(recurse=False):
104 | cur_value = param.data
105 | delattr(self, name)
106 | self.register_buffer(name, cur_value)
107 |
108 | def forward(self, x):
109 | raise NotImplementedError()
110 |
111 |
112 | class StackedSigmoid(ParametrizedGradEstimatorBase):
113 | """
114 | Stacked sigmoid estimator based on a simulated sigmoid forward pass
115 | """
116 |
117 | def __init__(self, alpha=1.0):
118 | super().__init__()
119 | self.register_buffer("alpha", torch.tensor(alpha))
120 |
121 | def forward(self, x):
122 | return stacked_sigmoid_func(x, self.alpha)
123 |
124 | def extra_repr(self):
125 | return f"alpha={self.alpha.item()}"
126 |
127 |
128 | class EWGSDiscretizer(ParametrizedGradEstimatorBase):
129 | def __init__(self, scaling_factor=0.2):
130 | super().__init__()
131 | self.register_buffer("scaling_factor", torch.tensor(scaling_factor))
132 |
133 | def forward(self, x):
134 | return ewgs_func(x, self.scaling_factor)
135 |
136 | def extra_repr(self):
137 | return f"scaling_factor={self.scaling_factor.item()}"
138 |
139 |
140 | class StochasticRounding(nn.Module):
141 | def __init__(self):
142 | super().__init__()
143 |
144 | def forward(self, x):
145 | if self.training:
146 | return stochastic_round_ste_func(x)
147 | else:
148 | return round_ste_func(x)
149 |
150 |
151 | round_ste_func = RoundStraightThrough.apply
152 | stacked_sigmoid_func = StackSigmoidFunctional.apply
153 | scale_grad_func = ScaleGradient.apply
154 | stochastic_round_ste_func = StochasticRoundSTE.apply
155 | ewgs_func = EWGSFunctional.apply
156 |
157 |
158 | class GradientEstimator(ClassEnumOptions):
159 | ste = MethodMap(round_ste_func)
160 | stoch_round = MethodMap(StochasticRounding)
161 | ewgs = MethodMap(EWGSDiscretizer)
162 | stacked_sigmoid = MethodMap(StackedSigmoid)
163 |
--------------------------------------------------------------------------------
/quantization/quantizers/uniform_quantizers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import inspect
6 | import torch
7 |
8 | from quantization.quantizers.rounding_utils import scale_grad_func, round_ste_func
9 | from .utils import QuantizerNotInitializedError
10 | from .base_quantizers import QuantizerBase
11 |
12 |
13 | class AsymmetricUniformQuantizer(QuantizerBase):
14 | """
15 | PyTorch Module that implements Asymmetric Uniform Quantization using STE.
16 | Quantizes its argument in the forward pass, passes the gradient 'straight
17 | through' on the backward pass, ignoring the quantization that occurred.
18 |
19 | Parameters
20 | ----------
21 | n_bits: int
22 | Number of bits for quantization.
23 | scale_domain: str ('log', 'linear) with default='linear'
24 | Domain of scale factor
25 | per_channel: bool
26 | If True: allows for per-channel quantization
27 | """
28 |
29 | def __init__(
30 | self,
31 | n_bits,
32 | scale_domain="linear",
33 | discretizer=round_ste_func,
34 | discretizer_args=tuple(),
35 | grad_scaling=False,
36 | eps=1e-8,
37 | **kwargs
38 | ):
39 | super().__init__(n_bits=n_bits, **kwargs)
40 |
41 | assert scale_domain in ("linear", "log")
42 | self.register_buffer("_delta", None)
43 | self.register_buffer("_zero_float", None)
44 |
45 | if inspect.isclass(discretizer):
46 | self.discretizer = discretizer(*discretizer_args)
47 | else:
48 | self.discretizer = discretizer
49 |
50 | self.scale_domain = scale_domain
51 | self.grad_scaling = grad_scaling
52 | self.eps = eps
53 |
54 | # A few useful properties
55 | @property
56 | def delta(self):
57 | if self._delta is not None:
58 | return self._delta
59 | else:
60 | raise QuantizerNotInitializedError()
61 |
62 | @property
63 | def zero_float(self):
64 | if self._zero_float is not None:
65 | return self._zero_float
66 | else:
67 | raise QuantizerNotInitializedError()
68 |
69 | @property
70 | def is_initialized(self):
71 | return self._delta is not None
72 |
73 | @property
74 | def symmetric(self):
75 | return False
76 |
77 | @property
78 | def int_min(self):
79 | # integer grid minimum
80 | return 0.0
81 |
82 | @property
83 | def int_max(self):
84 | # integer grid maximum
85 | return 2.0**self.n_bits - 1
86 |
87 | @property
88 | def scale(self):
89 | if self.scale_domain == "linear":
90 | return torch.clamp(self.delta, min=self.eps)
91 | elif self.scale_domain == "log":
92 | return torch.exp(self.delta)
93 |
94 | @property
95 | def zero_point(self):
96 | zero_point = self.discretizer(self.zero_float)
97 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max)
98 | return zero_point
99 |
100 | @property
101 | def x_max(self):
102 | return self.scale * (self.int_max - self.zero_point)
103 |
104 | @property
105 | def x_min(self):
106 | return self.scale * (self.int_min - self.zero_point)
107 |
108 | def to_integer_forward(self, x_float, *args, **kwargs):
109 | """
110 | Qunatized input to its integer representation
111 | Parameters
112 | ----------
113 | x_float: PyTorch Float Tensor
114 | Full-precision Tensor
115 |
116 | Returns
117 | -------
118 | x_int: PyTorch Float Tensor of integers
119 | """
120 | if self.grad_scaling:
121 | grad_scale = self.calculate_grad_scale(x_float)
122 | scale = scale_grad_func(self.scale, grad_scale)
123 | zero_point = (
124 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale)
125 | )
126 | else:
127 | scale = self.scale
128 | zero_point = self.zero_point
129 |
130 | x_int = self.discretizer(x_float / scale) + zero_point
131 | x_int = torch.clamp(x_int, self.int_min, self.int_max)
132 |
133 | return x_int
134 |
135 | def forward(self, x_float, *args, **kwargs):
136 | """
137 | Quantizes (quantized to integer and the scales back to original domain)
138 | Parameters
139 | ----------
140 | x_float: PyTorch Float Tensor
141 | Full-precision Tensor
142 |
143 | Returns
144 | -------
145 | x_quant: PyTorch Float Tensor
146 | Quantized-Dequantized Tensor
147 | """
148 | if self.per_channel:
149 | self._adjust_params_per_channel(x_float)
150 |
151 | if self.grad_scaling:
152 | grad_scale = self.calculate_grad_scale(x_float)
153 | scale = scale_grad_func(self.scale, grad_scale)
154 | zero_point = (
155 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale)
156 | )
157 | else:
158 | scale = self.scale
159 | zero_point = self.zero_point
160 |
161 | x_int = self.to_integer_forward(x_float, *args, **kwargs)
162 | x_quant = scale * (x_int - zero_point)
163 |
164 | return x_quant
165 |
166 | def calculate_grad_scale(self, quant_tensor):
167 | num_pos_levels = self.int_max # Qp in LSQ paper
168 | num_elements = quant_tensor.numel() # nfeatures or nweights in LSQ paper
169 | if self.per_channel:
170 | # In the per tensor case we do not sum the gradients over the output channel dimension
171 | num_elements /= quant_tensor.shape[0]
172 |
173 | return (num_pos_levels * num_elements) ** -0.5 # 1 / sqrt (Qn * nfeatures)
174 |
175 | def _adjust_params_per_channel(self, x):
176 | """
177 | Adjusts the quantization parameter tensors (delta, zero_float)
178 | to the input tensor shape if they don't match
179 | Parameters
180 | ----------
181 | x: input tensor
182 | """
183 | if x.ndim != self.delta.ndim:
184 | new_shape = [-1] + [1] * (len(x.shape) - 1)
185 | self._delta = self.delta.view(new_shape)
186 | if self._zero_float is not None:
187 | self._zero_float = self._zero_float.view(new_shape)
188 |
189 | def _tensorize_min_max(self, x_min, x_max):
190 | """
191 | Converts provided min max range into tensors
192 | Parameters
193 | ----------
194 | x_min: float or PyTorch 1D tensor
195 | x_max: float or PyTorch 1D tensor
196 |
197 | Returns
198 | -------
199 | x_min: PyTorch Tensor 0 or 1-D
200 | x_max: PyTorch Tensor 0 or 1-D
201 | """
202 | # Ensure a torch tensor
203 | if not torch.is_tensor(x_min):
204 | x_min = torch.tensor(x_min).float()
205 | x_max = torch.tensor(x_max).float()
206 |
207 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel:
208 | print(x_min)
209 | print(self.per_channel)
210 | raise ValueError(
211 | "x_min and x_max must be a float or 1-D Tensor"
212 | " for per-tensor quantization (per_channel=False)"
213 | )
214 | # Ensure we always use zero and avoid division by zero
215 | x_min = torch.min(x_min, torch.zeros_like(x_min))
216 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps)
217 |
218 | return x_min, x_max
219 |
220 | def set_quant_range(self, x_min, x_max):
221 | """
222 | Instantiates the quantization parameters based on the provided
223 | min and max range
224 |
225 | Parameters
226 | ----------
227 | x_min: tensor or float
228 | Quantization range minimum limit
229 | x_max: tensor of float
230 | Quantization range minimum limit
231 | """
232 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max
233 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
234 | self._delta = (x_max - x_min) / self.int_max
235 | self._zero_float = (-x_min / self.delta).detach()
236 |
237 | if self.scale_domain == "log":
238 | self._delta = torch.log(self.delta)
239 |
240 | self._delta = self._delta.detach()
241 |
242 | def make_range_trainable(self):
243 | # Converts trainable parameters to nn.Parameters
244 | if self.delta not in self.parameters():
245 | self._delta = torch.nn.Parameter(self._delta)
246 | self._zero_float = torch.nn.Parameter(self._zero_float)
247 |
248 | def fix_ranges(self):
249 | # Removes trainable quantization params from nn.Parameters
250 | if self.delta in self.parameters():
251 | _delta = self._delta.data
252 | _zero_float = self._zero_float.data
253 | del self._delta # delete the parameter
254 | del self._zero_float
255 | self.register_buffer("_delta", _delta)
256 | self.register_buffer("_zero_float", _zero_float)
257 |
258 |
259 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer):
260 | """
261 | PyTorch Module that implements Symmetric Uniform Quantization using STE.
262 | Quantizes its argument in the forward pass, passes the gradient 'straight
263 | through' on the backward pass, ignoring the quantization that occurred.
264 |
265 | Parameters
266 | ----------
267 | n_bits: int
268 | Number of bits for quantization.
269 | scale_domain: str ('log', 'linear) with default='linear'
270 | Domain of scale factor
271 | per_channel: bool
272 | If True: allows for per-channel quantization
273 | """
274 |
275 | def __init__(self, *args, **kwargs):
276 | super().__init__(*args, **kwargs)
277 | self.register_buffer("_signed", None)
278 |
279 | @property
280 | def signed(self):
281 | if self._signed is not None:
282 | return self._signed.item()
283 | else:
284 | raise QuantizerNotInitializedError()
285 |
286 | @property
287 | def symmetric(self):
288 | return True
289 |
290 | @property
291 | def int_min(self):
292 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0
293 |
294 | @property
295 | def int_max(self):
296 | pos_n_bits = self.n_bits - self.signed
297 | return 2.0**pos_n_bits - 1
298 |
299 | @property
300 | def zero_point(self):
301 | return 0.0
302 |
303 | def set_quant_range(self, x_min, x_max):
304 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max
305 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
306 | self._signed = x_min.min() < 0
307 |
308 | x_absmax = torch.max(x_min.abs(), x_max)
309 | self._delta = x_absmax / self.int_max
310 |
311 | if self.scale_domain == "log":
312 | self._delta = torch.log(self._delta)
313 |
314 | self._delta = self._delta.detach()
315 |
316 | def make_range_trainable(self):
317 | # Converts trainable parameters to nn.Parameters
318 | if self.delta not in self.parameters():
319 | self._delta = torch.nn.Parameter(self._delta)
320 |
321 | def fix_ranges(self):
322 | # Removes trainable quantization params from nn.Parameters
323 | if self.delta in self.parameters():
324 | _delta = self._delta.data
325 | del self._delta # delete the parameter
326 | self.register_buffer("_delta", _delta)
327 |
328 | def generate_grid(self):
329 | x_int_rng = torch.arange(self.int_min, self.int_max + 1)
330 | grid = self.scale * (x_int_rng - self.zero_point)
331 | return grid
332 |
--------------------------------------------------------------------------------
/quantization/quantizers/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 |
6 | class QuantizerNotInitializedError(Exception):
7 | """Raised when a quantizer has not been initialized"""
8 |
9 | def __init__(self):
10 | super(QuantizerNotInitializedError, self).__init__(
11 | "Quantizer has not been initialized yet"
12 | )
13 |
--------------------------------------------------------------------------------
/quantization/range_estimators.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import copy
5 | from enum import auto
6 |
7 | import numpy as np
8 | import torch
9 | from scipy.optimize import minimize_scalar
10 | from torch import nn
11 |
12 | from utils import to_numpy, BaseEnumOptions, MethodMap, ClassEnumOptions
13 |
14 |
15 | class RangeEstimatorBase(nn.Module):
16 | def __init__(self, per_channel=False, quantizer=None, *args, **kwargs):
17 | super().__init__(*args, **kwargs)
18 | self.register_buffer("current_xmin", None)
19 | self.register_buffer("current_xmax", None)
20 | self.per_channel = per_channel
21 | self.quantizer = quantizer
22 |
23 | def forward(self, x):
24 | """
25 | Accepts an input tensor, updates the current estimates of x_min and x_max
26 | and returns them.
27 | Parameters
28 | ----------
29 | x: Input tensor
30 |
31 | Returns
32 | -------
33 | self.current_xmin: tensor
34 |
35 | self.current_xmax: tensor
36 |
37 | """
38 | raise NotImplementedError()
39 |
40 | def reset(self):
41 | """
42 | Reset the range estimator.
43 | """
44 | self.current_xmin = None
45 | self.current_xmax = None
46 |
47 | def __repr__(self):
48 | # We overwrite this from nn.Module as we do not want to have submodules such as
49 | # self.quantizer in the reproduce. Otherwise it behaves as expected for an nn.Module.
50 | lines = self.extra_repr().split("\n")
51 | extra_str = lines[0] if len(lines) == 1 else "\n " + "\n ".join(lines) + "\n"
52 |
53 | return self._get_name() + "(" + extra_str + ")"
54 |
55 |
56 | class CurrentMinMaxEstimator(RangeEstimatorBase):
57 | def __init__(self, percentile=None, *args, **kwargs):
58 | self.percentile = percentile
59 | super().__init__(*args, **kwargs)
60 |
61 | def forward(self, x):
62 | if self.per_channel:
63 | x = x.view(x.shape[0], -1)
64 | if self.percentile:
65 | axis = -1 if self.per_channel else None
66 | data_np = to_numpy(x)
67 | x_min, x_max = np.percentile(
68 | data_np, (self.percentile, 100 - self.percentile), axis=axis
69 | )
70 | self.current_xmin = torch.tensor(x_min).to(x.device)
71 | self.current_xmax = torch.tensor(x_max).to(x.device)
72 | else:
73 | self.current_xmin = x.min(-1)[0].detach() if self.per_channel else x.min().detach()
74 | self.current_xmax = x.max(-1)[0].detach() if self.per_channel else x.max().detach()
75 |
76 | return self.current_xmin, self.current_xmax
77 |
78 |
79 | class AllMinMaxEstimator(RangeEstimatorBase):
80 | def __init__(self, *args, **kwargs):
81 | super().__init__(*args, **kwargs)
82 |
83 | def forward(self, x):
84 | if self.per_channel:
85 | # Along 1st dim
86 | x_flattened = x.view(x.shape[0], -1)
87 | x_min = x_flattened.min(-1)[0].detach()
88 | x_max = x_flattened.max(-1)[0].detach()
89 | else:
90 | x_min = torch.min(x).detach()
91 | x_max = torch.max(x).detach()
92 |
93 | if self.current_xmin is None:
94 | self.current_xmin = x_min
95 | self.current_xmax = x_max
96 | else:
97 | self.current_xmin = torch.min(self.current_xmin, x_min)
98 | self.current_xmax = torch.max(self.current_xmax, x_max)
99 |
100 | return self.current_xmin, self.current_xmax
101 |
102 |
103 | class RunningMinMaxEstimator(RangeEstimatorBase):
104 | def __init__(self, momentum=0.9, *args, **kwargs):
105 | self.momentum = momentum
106 | super().__init__(*args, **kwargs)
107 |
108 | def forward(self, x):
109 | if self.per_channel:
110 | # Along 1st dim
111 | x_flattened = x.view(x.shape[0], -1)
112 | x_min = x_flattened.min(-1)[0].detach()
113 | x_max = x_flattened.max(-1)[0].detach()
114 | else:
115 | x_min = torch.min(x).detach()
116 | x_max = torch.max(x).detach()
117 |
118 | if self.current_xmin is None:
119 | self.current_xmin = x_min
120 | self.current_xmax = x_max
121 | else:
122 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin
123 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax
124 |
125 | return self.current_xmin, self.current_xmax
126 |
127 |
128 | class OptMethod(BaseEnumOptions):
129 | grid = auto()
130 | golden_section = auto()
131 |
132 |
133 | class LineSearchEstimator(RangeEstimatorBase):
134 | def __init__(
135 | self,
136 | num_candidates=1000,
137 | opt_method=OptMethod.grid,
138 | range_margin=0.5,
139 | expand_range=10.0,
140 | *args,
141 | **kwargs,
142 | ):
143 |
144 | super().__init__(*args, **kwargs)
145 | assert opt_method in OptMethod
146 | self.opt_method = opt_method
147 | self.num_candidates = num_candidates
148 | self.expand_range = expand_range
149 | self.loss_array = None
150 | self.max_pos_thr = None
151 | self.max_neg_thr = None
152 | self.max_search_range = None
153 | self.one_sided_dist = None
154 | self.range_margin = range_margin
155 | if self.quantizer is None:
156 | raise NotImplementedError(
157 | "A Quantizer must be given as an argument to the MSE Range" "Estimator"
158 | )
159 | self.max_int_skew = (2**self.quantizer.n_bits) // 4 # For asymmetric quantization
160 |
161 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False):
162 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr)
163 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1)
164 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel
165 | # grid search
166 | if per_channel_loss:
167 | return to_numpy(temp_sum)
168 | else:
169 | return to_numpy(torch.sum(temp_sum))
170 |
171 | @property
172 | def step_size(self):
173 | if self.one_sided_dist is None:
174 | raise NoDataPassedError()
175 |
176 | return self.max_search_range / self.num_candidates
177 |
178 | @property
179 | def optimization_method(self):
180 | if self.one_sided_dist is None:
181 | raise NoDataPassedError()
182 |
183 | if self.opt_method == OptMethod.grid:
184 | # Grid search method
185 | if self.one_sided_dist or self.quantizer.symmetric:
186 | # 1-D grid search
187 | return self._perform_1D_search
188 | else:
189 | # 2-D grid_search
190 | return self._perform_2D_search
191 | elif self.opt_method == OptMethod.golden_section:
192 | # Golden section method
193 | if self.one_sided_dist or self.quantizer.symmetric:
194 | return self._golden_section_symmetric
195 | else:
196 | return self._golden_section_asymmetric
197 | else:
198 | raise NotImplementedError("Optimization Method not Implemented")
199 |
200 | def quantize(self, x_float, x_min=None, x_max=None):
201 | temp_q = copy.deepcopy(self.quantizer)
202 | # In the current implementation no optimization procedure requires temp quantizer for
203 | # loss_fx to be per-channel
204 | temp_q.per_channel = False
205 | if x_min or x_max:
206 | temp_q.set_quant_range(x_min, x_max)
207 | return temp_q(x_float)
208 |
209 | def _define_search_range(self, data):
210 | self.channel_groups = len(data) if self.per_channel else 1
211 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device)
212 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device)
213 |
214 | if self.one_sided_dist or self.quantizer.symmetric:
215 | # 1D search space
216 | self.loss_array = np.zeros(
217 | (self.channel_groups, self.num_candidates + 1)
218 | ) # 1D search space
219 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish
220 | # Defining the search range for clipping thresholds
221 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin
222 | self.max_neg_thr = -self.max_pos_thr * self.expand_range
223 | self.max_search_range = self.max_pos_thr * self.expand_range
224 | else:
225 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth
226 | # index represents whether the skew is positive (0) or negative (1))
227 | self.loss_array = np.zeros(
228 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2]
229 | ) # 2D search space
230 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish
231 | # Define the search range for clipping thresholds in asymmetric case
232 | self.max_pos_thr = float(data.max()) + self.range_margin
233 | self.max_neg_thr = float(data.min()) - self.range_margin
234 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr))
235 |
236 | def _perform_1D_search(self, data):
237 | """
238 | Grid search through all candidate quantizers in 1D to find the best
239 | The loss is accumulated over all batches without any momentum
240 | :param data: input tensor
241 | """
242 | for cand_index in range(1, self.num_candidates + 1):
243 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index
244 | pos_thr = self.step_size * cand_index
245 |
246 | self.loss_array[:, cand_index] += self.loss_fx(
247 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
248 | )
249 |
250 | min_cand = self.loss_array.argmin(axis=1)
251 | xmin = (
252 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand
253 | ).astype(np.single)
254 | xmax = (self.step_size * min_cand).astype(np.single)
255 | self.current_xmax = torch.tensor(xmax).to(device=data.device)
256 | self.current_xmin = torch.tensor(xmin).to(device=data.device)
257 |
258 | def forward(self, data):
259 | if self.loss_array is None:
260 | # Initialize search range on first batch, and accumulate losses with subsequent calls
261 |
262 | # Decide whether input distribution is one-sided
263 | if self.one_sided_dist is None:
264 | self.one_sided_dist = bool((data.min() >= 0).item())
265 |
266 | # Define search
267 | self._define_search_range(data)
268 |
269 | # Perform Search/Optimization for Quantization Ranges
270 | self.optimization_method(data)
271 |
272 | return self.current_xmin, self.current_xmax
273 |
274 | def reset(self):
275 | super().reset()
276 | self.loss_array = None
277 |
278 | def extra_repr(self):
279 | repr = "opt_method={}".format(self.opt_method.name)
280 | if self.opt_method == OptMethod.grid:
281 | repr += " ,num_candidates={}".format(self.num_candidates)
282 | return repr
283 |
284 |
285 | class FP_MSE_Estimator(RangeEstimatorBase):
286 | def __init__(
287 | self, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, *args, **kwargs
288 | ):
289 | super(FP_MSE_Estimator, self).__init__(*args, **kwargs)
290 | assert opt_method == OptMethod.grid
291 |
292 | self.num_candidates = num_candidates
293 | self.mses = self.search_grid = None
294 |
295 | def _define_search_range(self, x, mbit_list):
296 | if self.per_channel:
297 | x = x.view(x.shape[0], -1)
298 | else:
299 | x = x.view(1, -1)
300 | mxs = [torch.max(torch.abs(xc.min()), torch.abs(xc.max())) for xc in x]
301 |
302 | if self.search_grid is None:
303 | assert self.mses is None
304 |
305 | lsp = [torch.linspace(0.1 * mx.item(), 1.2 * mx.item(), 111) for mx in mxs]
306 |
307 | # 111 x n_channels
308 | search_grid = torch.stack(lsp).to(x.device).transpose(0, 1)
309 |
310 | # mbits x 111 x n_channels (or 1 in case not --per-channel)
311 | mses = torch.stack([torch.zeros_like(search_grid) for _ in range(len(mbit_list))])
312 |
313 | self.mses = mses
314 | self.search_grid = search_grid
315 |
316 | return self.search_grid, self.mses
317 |
318 | def forward(self, x):
319 | mbit_list = [float(self.quantizer.mantissa_bits)]
320 |
321 | if self.quantizer.mse_include_mantissa_bits:
322 | # highest possible value is self.n_bits - self.sign_bits - 1
323 | mbit_list = [
324 | float(x) for x in range(1, self.quantizer.n_bits - self.quantizer.sign_bits)
325 | ]
326 |
327 | search_grid, mses = self._define_search_range(x, mbit_list)
328 |
329 | assert mses.shape[1:] == search_grid.shape, f"{mses.shape}, {search_grid.shape}"
330 |
331 | # Need to do this here too to get correct search range
332 | sign_bits = int(torch.any(x < 0)) if self.quantizer.allow_unsigned else 1
333 |
334 | meandims = list(torch.arange(len(x.shape)))
335 | if self.per_channel:
336 | meandims = meandims[1:]
337 | for m, mbits in enumerate(mbit_list):
338 | mbits = torch.Tensor([mbits]).to(x.device)
339 | self.quantizer.mantissa_bits = mbits
340 | for i, maxval in enumerate(search_grid):
341 | x_min, x_max = sign_bits * -1.0 * maxval, maxval
342 | self.quantizer.set_quant_range(x_min, x_max)
343 | xfp = self.quantizer(x)
344 |
345 | # get MSE per channel (mean over all non-channel dims)
346 | mse = ((x - xfp) ** 2).mean(meandims)
347 | mses[m, i, :] += mse
348 |
349 | # Find best mbits per channel
350 | best_mbits_per_channel = mses.min(1)[0].argmin(0)
351 |
352 | # Get plurality vote on mbits
353 | best_mbit_idx = torch.mode(best_mbits_per_channel).values.item()
354 | best_mbits = float(mbit_list[best_mbit_idx])
355 |
356 | # then, find best per-channel scale for best mbit
357 | # first, get the MSES for the best mbit, then argmin over linspace dim to get best index per channel
358 | mses = mses[best_mbit_idx].argmin(0)
359 | # then, for each channel, get the argmin MSE max value
360 | maxval = torch.tensor([search_grid[mses[i], i] for i in range(search_grid.shape[-1])]).to(
361 | x.device
362 | )
363 |
364 | self.quantizer.mantissa_bits = torch.tensor(best_mbits).to(
365 | self.quantizer.mantissa_bits.device
366 | )
367 |
368 | maxval = maxval.to(self.quantizer.maxval.device)
369 | return sign_bits * -1.0 * maxval, maxval
370 |
371 |
372 | def estimate_range_line_search(W, quant, num_candidates=None):
373 | if num_candidates is None:
374 | est_fp = LineSearchEstimator(quantizer=quant)
375 | else:
376 | est_fp = LineSearchEstimator(quantizer=quant, num_candidates=num_candidates)
377 |
378 | mse_range_min_fp, mse_range_max_fp = est_fp.forward(W)
379 | return (mse_range_min_fp, mse_range_max_fp)
380 |
381 |
382 | class NoDataPassedError(Exception):
383 | """Raised data has been passed into the Range Estimator"""
384 |
385 | def __init__(self):
386 | super().__init__("Data must be pass through the range estimator to be initialized")
387 |
388 |
389 | class RangeEstimators(ClassEnumOptions):
390 | current_minmax = MethodMap(CurrentMinMaxEstimator)
391 | allminmax = MethodMap(AllMinMaxEstimator)
392 | running_minmax = MethodMap(RunningMinMaxEstimator)
393 | MSE = MethodMap(FP_MSE_Estimator)
394 |
--------------------------------------------------------------------------------
/quantization/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 |
6 | import torch
7 | import torch.serialization
8 |
9 | from quantization.quantizers import QuantizerBase
10 | from quantization.quantizers.rounding_utils import ParametrizedGradEstimatorBase
11 | from quantization.range_estimators import RangeEstimators
12 | from utils import StopForwardException, get_layer_by_name
13 |
14 |
15 | def separate_quantized_model_params(quant_model):
16 | """
17 | This method separates the parameters of the quantized model to 4 categories.
18 | Parameters
19 | ----------
20 | quant_model: (QuantizedModel)
21 |
22 | Returns
23 | -------
24 | quant_params: (list)
25 | Quantization parameters, e.g. delta and zero_float
26 | model_params: (list)
27 | The model parameters of the base model without any quantization operations
28 | grad_params: (list)
29 | Parameters found in the gradient estimators (ParametrizedGradEstimatorBase)
30 | -------
31 |
32 | """
33 | quant_params, grad_params = [], []
34 | quant_params_names, grad_params_names = [], []
35 | for mod_name, module in quant_model.named_modules():
36 | if isinstance(module, QuantizerBase):
37 | for name, param in module.named_parameters(recurse=False):
38 | quant_params.append(param)
39 | quant_params_names.append(".".join((mod_name, name)))
40 | if isinstance(module, ParametrizedGradEstimatorBase):
41 | # gradient estimator params
42 | for name, param in module.named_parameters(recurse=False):
43 | grad_params.append(param)
44 | grad_params_names.append(".".join((mod_name, name)))
45 |
46 | def tensor_in_list(tensor, lst):
47 | return any([e is tensor for e in lst])
48 |
49 | found_params = quant_params + grad_params
50 |
51 | model_params = [p for p in quant_model.parameters() if not tensor_in_list(p, found_params)]
52 | model_param_names = [
53 | n for n, p in quant_model.named_parameters() if not tensor_in_list(p, found_params)
54 | ]
55 |
56 | print("Quantization parameters ({}):".format(len(quant_params_names)))
57 | print(quant_params_names)
58 |
59 | print("Gradient estimator parameters ({}):".format(len(grad_params_names)))
60 | print(grad_params_names)
61 |
62 | print("Other model parameters ({}):".format(len(model_param_names)))
63 | print(model_param_names)
64 |
65 | assert len(model_params + quant_params + grad_params) == len(
66 | list(quant_model.parameters())
67 | ), "{}; {}; {} -- {}".format(
68 | len(model_params), len(quant_params), len(grad_params), len(list(quant_model.parameters()))
69 | )
70 |
71 | return quant_params, model_params, grad_params
72 |
73 |
74 | def pass_data_for_range_estimation(
75 | loader, model, act_quant, weight_quant, max_num_batches=20, cross_entropy_layer=None, inp_idx=0
76 | ):
77 | print("\nEstimate quantization ranges on training data")
78 | model.set_quant_state(weight_quant, act_quant)
79 | # Put model in eval such that BN EMA does not get updated
80 | model.eval()
81 |
82 | if cross_entropy_layer is not None:
83 | layer_xent = get_layer_by_name(model, cross_entropy_layer)
84 | if layer_xent:
85 | print('Set cross entropy estimator for layer "{}"'.format(cross_entropy_layer))
86 | act_quant_mgr = layer_xent.activation_quantizer
87 | act_quant_mgr.range_estimator = RangeEstimators.cross_entropy.cls(
88 | per_channel=act_quant_mgr.per_channel,
89 | quantizer=act_quant_mgr.quantizer,
90 | **act_quant_mgr.range_estim_params,
91 | )
92 | else:
93 | raise ValueError("Cross-entropy layer not found")
94 |
95 | batches = []
96 | device = next(model.parameters()).device
97 |
98 | with torch.no_grad():
99 | for i, data in enumerate(loader):
100 | try:
101 | if isinstance(data, (tuple, list)):
102 | x = data[inp_idx].to(device=device)
103 | batches.append(x.data.cpu().numpy())
104 | model(x)
105 | print(f"proccesed step={i}")
106 | else:
107 | x = {k: v.to(device=device) for k, v in data.items()}
108 | model(**x)
109 | print(f"proccesed step={i}")
110 |
111 | if i >= max_num_batches - 1 or not act_quant:
112 | break
113 | except StopForwardException:
114 | pass
115 | return batches
116 |
117 |
118 | def set_range_estimators(config, model):
119 | print("Make quantizers learnable")
120 | model.learn_ranges()
121 |
122 | if config.qat.grad_scaling:
123 | print("Activate gradient scaling")
124 | model.grad_scaling(True)
125 |
126 | # Ensure we have the desired quant state
127 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant)
128 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | click==8.1.3
2 | pytorch-ignite~=0.4.9
3 | tensorboard==2.10.1
4 | scipy==1.9.3
5 | numpy==1.22.4
6 | timm~=0.4.12
7 | tqdm==4.49.0
8 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from .stopwatch import Stopwatch
6 | from .utils import *
7 |
--------------------------------------------------------------------------------
/utils/click_options.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import click
6 |
7 | from models import QuantArchitectures
8 | from functools import wraps, partial
9 | from quantization.quantization_manager import QMethods
10 | from quantization.range_estimators import RangeEstimators, OptMethod
11 | from utils import split_dict, DotDict, ClickEnumOption, seed_all
12 | from utils.imagenet_dataloaders import ImageInterpolation
13 |
14 | click.option = partial(click.option, show_default=True)
15 |
16 | _HELP_MSG = (
17 | "Enforce determinism also on the GPU by disabling CUDNN and setting "
18 | "`torch.set_deterministic(True)`. In many cases this comes at the cost of efficiency "
19 | "and performance."
20 | )
21 |
22 |
23 | def base_options(func):
24 | @click.option(
25 | "--images-dir", type=click.Path(exists=True), help="Root directory of images", required=True
26 | )
27 | @click.option("--max-epochs", default=90, type=int, help="Maximum number of training epochs.")
28 | @click.option(
29 | "--interpolation",
30 | type=ClickEnumOption(ImageInterpolation),
31 | default=ImageInterpolation.bilinear.name,
32 | help="Desired interpolation to use for resizing.",
33 | )
34 | @click.option(
35 | "--save-checkpoint-dir",
36 | type=click.Path(exists=False),
37 | default=None,
38 | help="Directory where to save checkpoints (model, optimizer, lr_scheduler).",
39 | )
40 | @click.option(
41 | "--tb-logging-dir", default=None, type=str, help="The logging directory " "for tensorboard"
42 | )
43 | @click.option("--cuda/--no-cuda", is_flag=True, default=True, help="Use GPU")
44 | @click.option("--batch-size", default=128, type=int, help="Mini-batch size")
45 | @click.option("--num-workers", default=16, type=int, help="Number of workers for data loading")
46 | @click.option("--seed", default=None, type=int, help="Random number generator seed to set")
47 | @click.option("--deterministic/--nondeterministic", default=False, help=_HELP_MSG)
48 | # Architecture related options
49 | @click.option(
50 | "--architecture",
51 | type=ClickEnumOption(QuantArchitectures),
52 | required=True,
53 | help="Quantized architecture",
54 | )
55 | @click.option(
56 | "--model-dir",
57 | type=click.Path(exists=True),
58 | default=None,
59 | help="Path for model directory. If the model does not exist it will downloaded "
60 | "from a URL",
61 | )
62 | @click.option(
63 | "--pretrained/--no-pretrained",
64 | is_flag=True,
65 | default=True,
66 | help="Use pretrained model weights",
67 | )
68 | @click.option(
69 | "--progress-bar/--no-progress-bar", is_flag=True, default=False, help="Show progress bar"
70 | )
71 | @wraps(func)
72 | def func_wrapper(config, *args, **kwargs):
73 | config.base, remaining_kwargs = split_dict(
74 | kwargs,
75 | [
76 | "images_dir",
77 | "max_epochs",
78 | "interpolation",
79 | "save_checkpoint_dir",
80 | "tb_logging_dir",
81 | "cuda",
82 | "batch_size",
83 | "num_workers",
84 | "seed",
85 | "model_dir",
86 | "architecture",
87 | "pretrained",
88 | "deterministic",
89 | "progress_bar",
90 | ],
91 | )
92 |
93 | seed, deterministic = config.base.seed, config.base.deterministic
94 |
95 | if seed is None:
96 | if deterministic is True:
97 | raise ValueError("Enforcing determinism without providing a seed is not supported")
98 | else:
99 | seed_all(seed=seed, deterministic=deterministic)
100 |
101 | return func(config, *args, **remaining_kwargs)
102 |
103 | return func_wrapper
104 |
105 |
106 | class multi_optimizer_options:
107 | """
108 | An instance of this class is a callable object to serve as a decorator;
109 | hence the lower case class name.
110 |
111 | Among the CLI options defined in the decorator, `--{prefix-}optimizer-type`
112 | requires special attention. Default value for that variable for
113 | {prefix-}optimizer is the value in use by the main optimizer.
114 |
115 | Examples:
116 | @multi_optimizer_options('quant')
117 | @pass_config
118 | def command(config):
119 | ...
120 | """
121 |
122 | def __init__(self, prefix: str = ""):
123 | self.optimizer_name = prefix + "_optimizer" if prefix else "optimizer"
124 | self.prefix_option = prefix + "-" if prefix else ""
125 | self.prefix_attribute = prefix + "_" if prefix else ""
126 |
127 | def __call__(self, func):
128 | prefix_option = self.prefix_option
129 | prefix_attribute = self.prefix_attribute
130 |
131 | @click.option(
132 | f"--{prefix_option}optimizer",
133 | default="SGD",
134 | type=click.Choice(["SGD", "Adam"], case_sensitive=False),
135 | help=f"Class name of torch Optimizer to be used.",
136 | )
137 | @click.option(
138 | f"--{prefix_option}learning-rate",
139 | default=None,
140 | type=float,
141 | help="Initial learning rate.",
142 | )
143 | @click.option(
144 | f"--{prefix_option}momentum", default=0.9, type=float, help=f"Optimizer momentum."
145 | )
146 | @click.option(
147 | f"--{prefix_option}weight-decay",
148 | default=None,
149 | type=float,
150 | help="Weight decay for the network.",
151 | )
152 | @click.option(
153 | f"--{prefix_option}learning-rate-schedule",
154 | default=None,
155 | type=str,
156 | help="Learning rate scheduler, 'MultiStepLR:10:20:40' or "
157 | "'cosine:1e-4' for cosine decay",
158 | )
159 | @wraps(func)
160 | def func_wrapper(config, *args, **kwargs):
161 | base_arg_names = [
162 | "optimizer",
163 | "learning_rate",
164 | "momentum",
165 | "weight_decay",
166 | "learning_rate_schedule",
167 | ]
168 |
169 | optimizer_opt = DotDict()
170 |
171 | # Collect basic arguments
172 | for arg in base_arg_names:
173 | option_name = prefix_attribute + arg
174 | optimizer_opt[arg] = kwargs.pop(option_name)
175 |
176 | # config.{prefix_attribute}optimizer = optimizer_opt
177 | setattr(config, prefix_attribute + "optimizer", optimizer_opt)
178 |
179 | return func(config, *args, **kwargs)
180 |
181 | return func_wrapper
182 |
183 |
184 | def qat_options(func):
185 | @click.option(
186 | "--reestimate-bn-stats/--no-reestimate-bn-stats",
187 | is_flag=True,
188 | default=True,
189 | help="Reestimates the BN stats before every evaluation.",
190 | )
191 | @click.option(
192 | "--grad-scaling/--no-grad-scaling",
193 | is_flag=True,
194 | default=False,
195 | help="Do gradient scaling as in LSQ paper.",
196 | )
197 | @click.option(
198 | "--sep-quant-optimizer/--no-sep-quant-optimizer",
199 | is_flag=True,
200 | default=False,
201 | help="Use a separate optimizer for the quantizers.",
202 | )
203 | @multi_optimizer_options("quant")
204 | @oscillations_dampen_options
205 | @oscillations_freeze_options
206 | @wraps(func)
207 | def func_wrapper(config, *args, **kwargs):
208 | config.qat, remainder_kwargs = split_dict(
209 | kwargs, ["reestimate_bn_stats", "grad_scaling", "sep_quant_optimizer"]
210 | )
211 | return func(config, *args, **remainder_kwargs)
212 |
213 | return func_wrapper
214 |
215 |
216 | def oscillations_dampen_options(func):
217 | @click.option(
218 | "--oscillations-dampen-weight",
219 | default=None,
220 | type=float,
221 | help="If given, adds a oscillations dampening to the loss with given " "weighting.",
222 | )
223 | @click.option(
224 | "--oscillations-dampen-aggregation",
225 | type=click.Choice(["sum", "mean", "kernel_mean"]),
226 | default="kernel_mean",
227 | help="Aggregation type for bin regularization loss.",
228 | )
229 | @click.option(
230 | "--oscillations-dampen-weight-final",
231 | type=float,
232 | default=None,
233 | help="Dampening regularization final value for annealing schedule.",
234 | )
235 | @click.option(
236 | "--oscillations-dampen-anneal-start",
237 | default=0.25,
238 | type=float,
239 | help="Start of annealing (relative to total number of iterations).",
240 | )
241 | @wraps(func)
242 | def func_wrapper(config, *args, **kwargs):
243 | config.osc_damp, remainder_kwargs = split_dict(
244 | kwargs,
245 | [
246 | "oscillations_dampen_weight",
247 | "oscillations_dampen_aggregation",
248 | "oscillations_dampen_weight_final",
249 | "oscillations_dampen_anneal_start",
250 | ],
251 | "oscillations_dampen",
252 | )
253 |
254 | return func(config, *args, **remainder_kwargs)
255 |
256 | return func_wrapper
257 |
258 |
259 | def oscillations_freeze_options(func):
260 | @click.option(
261 | "--oscillations-freeze-threshold",
262 | default=0.0,
263 | type=float,
264 | help="If greater than 0, we will freeze oscillations which frequency (EMA) is "
265 | "higher than the given threshold. Frequency is defined as 1/period length.",
266 | )
267 | @click.option(
268 | "--oscillations-freeze-ema-momentum",
269 | default=0.001,
270 | type=float,
271 | help="The momentum to calculate the EMA frequency of the oscillation. In case"
272 | "freezing is used, this should be at least 2-3 times lower than the "
273 | "freeze threshold.",
274 | )
275 | @click.option(
276 | "--oscillations-freeze-use-ema/--no-oscillation-freeze-use-ema",
277 | is_flag=True,
278 | default=True,
279 | help="Uses an EMA of past x_int to find the correct freezing int value.",
280 | )
281 | @click.option(
282 | "--oscillations-freeze-max-bits",
283 | default=4,
284 | type=int,
285 | help="Max bit-width for oscillation tracking and freezing. If layers weight is in"
286 | "higher bits we do not track or freeze oscillations.",
287 | )
288 | @click.option(
289 | "--oscillations-freeze-threshold-final",
290 | type=float,
291 | default=None,
292 | help="Oscillation freezing final value for annealing schedule.",
293 | )
294 | @click.option(
295 | "--oscillations-freeze-anneal-start",
296 | default=0.25,
297 | type=float,
298 | help="Start of annealing (relative to total number of iterations).",
299 | )
300 | @wraps(func)
301 | def func_wrapper(config, *args, **kwargs):
302 | config.osc_freeze, remainder_kwargs = split_dict(
303 | kwargs,
304 | [
305 | "oscillations_freeze_threshold",
306 | "oscillations_freeze_ema_momentum",
307 | "oscillations_freeze_use_ema",
308 | "oscillations_freeze_max_bits",
309 | "oscillations_freeze_threshold_final",
310 | "oscillations_freeze_anneal_start",
311 | ],
312 | "oscillations_freeze",
313 | )
314 |
315 | return func(config, *args, **remainder_kwargs)
316 |
317 | return func_wrapper
318 |
319 |
320 | def quantization_options(func):
321 | # Weight quantization options
322 | @click.option(
323 | "--weight-quant/--no-weight-quant",
324 | is_flag=True,
325 | default=True,
326 | help="Run evaluation weight quantization or use FP32 weights",
327 | )
328 | @click.option(
329 | "--qmethod",
330 | type=ClickEnumOption(QMethods),
331 | default=QMethods.symmetric_uniform.name,
332 | help="Quantization scheme to use.",
333 | )
334 | @click.option(
335 | "--weight-quant-method",
336 | default=RangeEstimators.current_minmax.name,
337 | type=ClickEnumOption(RangeEstimators),
338 | help="Method to determine weight quantization clipping thresholds.",
339 | )
340 | @click.option(
341 | "--weight-opt-method",
342 | default=OptMethod.grid.name,
343 | type=ClickEnumOption(OptMethod),
344 | help="Optimization procedure for activation quantization clipping thresholds",
345 | )
346 | @click.option(
347 | "--num-candidates",
348 | type=int,
349 | default=None,
350 | help="Number of grid points for grid search in MSE range method.",
351 | )
352 | @click.option("--n-bits", default=8, type=int, help="Default number of quantization bits.")
353 | @click.option(
354 | "--per-channel/--no-per-channel",
355 | is_flag=True,
356 | default=False,
357 | help="If given, quantize each channel separately.",
358 | )
359 | # Activation quantization options
360 | @click.option(
361 | "--act-quant/--no-act-quant",
362 | is_flag=True,
363 | default=True,
364 | help="Run evaluation with activation quantization or use FP32 activations",
365 | )
366 | @click.option(
367 | "--qmethod-act",
368 | type=ClickEnumOption(QMethods),
369 | default=None,
370 | help="Quantization scheme for activation to use. If not specified `--qmethod` " "is used.",
371 | )
372 | @click.option(
373 | "--n-bits-act", default=None, type=int, help="Number of quantization bits for activations."
374 | )
375 | @click.option(
376 | "--act-quant-method",
377 | default=RangeEstimators.running_minmax.name,
378 | type=ClickEnumOption(RangeEstimators),
379 | help="Method to determine activation quantization clipping thresholds",
380 | )
381 | @click.option(
382 | "--act-opt-method",
383 | default=OptMethod.grid.name,
384 | type=ClickEnumOption(OptMethod),
385 | help="Optimization procedure for activation quantization clipping thresholds",
386 | )
387 | @click.option(
388 | "--act-num-candidates",
389 | type=int,
390 | default=None,
391 | help="Number of grid points for grid search in MSE/SQNR/Cross-entropy",
392 | )
393 | @click.option(
394 | "--act-momentum",
395 | type=float,
396 | default=None,
397 | help="Exponential averaging factor for running_minmax",
398 | )
399 | @click.option(
400 | "--num-est-batches",
401 | type=int,
402 | default=1,
403 | help="Number of training batches to be used for activation range estimation",
404 | )
405 | # Other options
406 | @click.option(
407 | "--quant-setup",
408 | default="all",
409 | type=click.Choice(["all", "LSQ", "FP_logits", "fc4", "fc4_dw8", "LSQ_paper"]),
410 | help="Method to quantize the network.",
411 | )
412 | @wraps(func)
413 | def func_wrapper(config, *args, **kwargs):
414 | config.quant, remainder_kwargs = split_dict(
415 | kwargs,
416 | [
417 | "qmethod",
418 | "qmethod_act",
419 | "weight_quant_method",
420 | "weight_opt_method",
421 | "num_candidates",
422 | "n_bits",
423 | "n_bits_act",
424 | "per_channel",
425 | "act_quant",
426 | "weight_quant",
427 | "quant_setup",
428 | "num_est_batches",
429 | "act_momentum",
430 | "act_num_candidates",
431 | "act_opt_method",
432 | "act_quant_method",
433 | ],
434 | )
435 |
436 | config.quant.qmethod_act = config.quant.qmethod_act or config.quant.qmethod
437 |
438 | return func(config, *args, **remainder_kwargs)
439 |
440 | return func_wrapper
441 |
442 |
443 | def fp8_options(func):
444 | # Weight quantization options
445 | @click.option("--fp8-maxval", type=float, default=None)
446 | @click.option("--fp8-mantissa-bits", type=int, default=4)
447 | @click.option("--fp8-set-maxval/--no-fp8-set-maxval", is_flag=True, default=False)
448 | @click.option("--fp8-learn-maxval/--no-fp8-learn-maxval", is_flag=True, default=False)
449 | @click.option(
450 | "--fp8-learn-mantissa-bits/--no-fp8-learn-mantissa-bits", is_flag=True, default=False
451 | )
452 | @click.option(
453 | "--fp8-mse-include-mantissa-bits/--no-fp8-mse-include-mantissa-bits",
454 | is_flag=True,
455 | default=True,
456 | )
457 | @click.option("--fp8-allow-unsigned/--no-fp8-allow-unsigned", is_flag=True, default=False)
458 | @wraps(func)
459 | def func_wrapper(config, *args, **kwargs):
460 | config.fp8, remainder_kwargs = split_dict(
461 | kwargs,
462 | [
463 | "fp8_maxval",
464 | "fp8_mantissa_bits",
465 | "fp8_set_maxval",
466 | "fp8_learn_maxval",
467 | "fp8_learn_mantissa_bits",
468 | "fp8_mse_include_mantissa_bits",
469 | "fp8_allow_unsigned",
470 | ],
471 | )
472 | return func(config, *args, **remainder_kwargs)
473 |
474 | return func_wrapper
475 |
476 |
477 | def quant_params_dict(config):
478 | weight_range_options = {}
479 | if config.quant.weight_quant_method == RangeEstimators.MSE:
480 | weight_range_options = dict(opt_method=config.quant.weight_opt_method)
481 | if config.quant.num_candidates is not None:
482 | weight_range_options["num_candidates"] = config.quant.num_candidates
483 |
484 | act_range_options = {}
485 | if config.quant.act_quant_method == RangeEstimators.MSE:
486 | act_range_options = dict(opt_method=config.quant.act_opt_method)
487 | if config.quant.act_num_candidates is not None:
488 | act_range_options["num_candidates"] = config.quant.num_candidates
489 |
490 | qparams = {
491 | "method": config.quant.qmethod.cls,
492 | "n_bits": config.quant.n_bits,
493 | "n_bits_act": config.quant.n_bits_act,
494 | "act_method": config.quant.qmethod_act.cls,
495 | "per_channel_weights": config.quant.per_channel,
496 | "quant_setup": config.quant.quant_setup,
497 | "weight_range_method": config.quant.weight_quant_method.cls,
498 | "weight_range_options": weight_range_options,
499 | "act_range_method": config.quant.act_quant_method.cls,
500 | "act_range_options": act_range_options,
501 | "quantize_input": True if config.quant.quant_setup == "LSQ_paper" else False,
502 | }
503 |
504 | if config.quant.qmethod.name.startswith("fp_quantizer"):
505 | fp8_kwargs = {
506 | k.replace("fp8_", ""): v for k, v in config.fp8.items() if k.startswith("fp8")
507 | }
508 | qparams["fp8_kwargs"] = fp8_kwargs
509 |
510 | return qparams
511 |
--------------------------------------------------------------------------------
/utils/distributions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import numpy as np
6 | import scipy.stats as stats
7 | import scipy.integrate
8 | from scipy import special
9 |
10 |
11 | class DistrBase:
12 | def __init__(self, params_dict, range_min, range_max, *args, **kwargs):
13 | self.params_dict = params_dict
14 | assert range_max >= range_min
15 | self.range_min = range_min
16 | self.range_max = range_max
17 |
18 | def pdf(self, x):
19 | raise NotImplementedError()
20 |
21 | def sample(self, shape):
22 | raise NotImplementedError()
23 |
24 | def eval_point_mass_range_min(self):
25 | raise NotImplementedError()
26 |
27 | def eval_point_mass_range_max(self):
28 | raise NotImplementedError()
29 |
30 | def integr_interv_p_sqr_r(self, a, b):
31 | raise NotImplementedError()
32 |
33 | def integr_interv_x_p_r_signed(self, a, b):
34 | raise NotImplementedError()
35 |
36 | def eval_p_sqr_r(self, x, grid):
37 | raise NotImplementedError()
38 |
39 | def eval_x_p_r_signed(self, x, grid):
40 | raise NotImplementedError()
41 |
42 | def eval_non_central_second_moment(self):
43 | raise NotImplementedError()
44 |
45 | def print(self):
46 | raise NotImplementedError()
47 |
48 |
49 | class ClippedGaussDistr(DistrBase):
50 | def __init__(self, *args, **kwargs):
51 | super().__init__(*args, **kwargs)
52 | mu = self.params_dict["mu"]
53 | sigma = self.params_dict["sigma"]
54 | self.point_mass_range_min = stats.norm.cdf(self.range_min, loc=mu, scale=sigma)
55 | self.point_mass_range_max = 1.0 - stats.norm.cdf(self.range_max, loc=mu, scale=sigma)
56 |
57 | def print(self):
58 | print(
59 | "Gaussian distr ",
60 | ", mu = ",
61 | self.params_dict["mu"],
62 | ", sigma = ",
63 | self.params_dict["sigma"],
64 | " clipped at [",
65 | self.range_min,
66 | ",",
67 | self.range_max,
68 | "]",
69 | )
70 |
71 | def cdf(self, x_np):
72 | p = stats.norm.cdf(x_np, self.params_dict["mu"], self.params_dict["sigma"])
73 | return p
74 |
75 | def pdf(self, x):
76 | x_np = x.cpu().numpy()
77 | p = stats.norm.pdf(x_np, self.params_dict["mu"], self.params_dict["sigma"])
78 | return p
79 |
80 | def inverse_cdf(self, x):
81 | res = stats.norm.ppf(x, loc=self.params_dict["mu"], scale=self.params_dict["sigma"])
82 | return res
83 |
84 | def sample(self, shape):
85 | r = np.random.normal(
86 | loc=self.params_dict["mu"], scale=self.params_dict["sigma"], size=shape
87 | )
88 | r = np.clip(r, self.range_min, self.range_max)
89 | return r
90 |
91 | def integr_interv_p_sqr_r(self, a, b, u):
92 | assert b >= a
93 | mu = self.params_dict["mu"]
94 | sigma = self.params_dict["sigma"]
95 | root_half = np.sqrt(0.5)
96 | root_half_pi = np.sqrt(0.5 * np.pi)
97 | t1 = -sigma * (
98 | np.exp((-0.5 * a**2 + 1.0 * a * mu - 0.5 * mu**2) / sigma**2)
99 | * sigma
100 | * (-1.0 * a - 1.0 * mu + 2.0 * u)
101 | + (
102 | -root_half_pi * mu**2
103 | - root_half_pi * sigma**2
104 | + 2.0 * root_half_pi * mu * u
105 | - root_half_pi * u**2
106 | )
107 | * special.erf((-root_half * a + root_half * mu) / sigma)
108 | )
109 | t2 = sigma * (
110 | np.exp((-0.5 * b**2 + 1.0 * b * mu - 0.5 * mu**2) / sigma**2)
111 | * sigma
112 | * (-1.0 * b - 1.0 * mu + 2.0 * u)
113 | + (
114 | -root_half_pi * mu**2
115 | - root_half_pi * sigma**2
116 | + 2.0 * root_half_pi * mu * u
117 | - root_half_pi * u**2
118 | )
119 | * special.erf((-root_half * b + root_half * mu) / sigma)
120 | )
121 | const = 1 / sigma / np.sqrt(2 * np.pi)
122 | return (t1 + t2) * const
123 |
124 | def integr_interv_x_p_signed_r(self, a, b, x0):
125 | assert b >= a
126 | mu = self.params_dict["mu"]
127 | sigma = self.params_dict["sigma"]
128 | root_half = np.sqrt(0.5)
129 | root_half_pi = np.sqrt(0.5 * np.pi)
130 |
131 | res = (
132 | x0
133 | * sigma
134 | * (
135 | np.exp(-((0.5 * mu**2) / sigma**2))
136 | * (
137 | np.exp((a * (-0.5 * a + mu)) / sigma**2)
138 | - np.exp((b * (-0.5 * b + mu)) / sigma**2)
139 | )
140 | * sigma
141 | - root_half_pi * mu * special.erf((root_half * a - root_half * mu) / sigma)
142 | + root_half_pi * mu * special.erf((root_half * b - root_half * mu) / sigma)
143 | )
144 | + sigma
145 | * (
146 | np.exp((-0.5 * a**2 + a * mu - 0.5 * mu**2) / sigma**2)
147 | * (-a * sigma - mu * sigma)
148 | + (-root_half_pi * mu**2 - root_half_pi * sigma**2)
149 | * special.erf((-root_half * a + root_half * mu) / sigma)
150 | )
151 | - sigma
152 | * (
153 | np.exp((-0.5 * b**2 + b * mu - 0.5 * mu**2) / sigma**2)
154 | * (-b * sigma - mu * sigma)
155 | + (-root_half_pi * mu**2 - root_half_pi * sigma**2)
156 | * special.erf((-root_half * b + root_half * mu) / sigma)
157 | )
158 | )
159 |
160 | const = 1 / sigma / np.sqrt(2 * np.pi)
161 | return res * const
162 |
163 | def integr_p_times_x(self, a, b):
164 | assert b >= a
165 | mu = self.params_dict["mu"]
166 | sigma = self.params_dict["sigma"]
167 |
168 | root_half = np.sqrt(0.5)
169 | root_half_pi = np.sqrt(0.5 * np.pi)
170 |
171 | res = sigma * (
172 | np.exp(-((0.5 * mu**2) / sigma**2))
173 | * (
174 | np.exp((a * (-0.5 * a + mu)) / sigma**2)
175 | - np.exp((b * (-0.5 * b + mu)) / sigma**2)
176 | )
177 | * sigma
178 | - root_half_pi * mu * special.erf(root_half * (a - mu) / sigma)
179 | + root_half_pi * mu * special.erf(root_half * (b - mu) / sigma)
180 | )
181 |
182 | scale = 1 / (sigma * np.sqrt(2 * np.pi))
183 | return res * scale
184 |
185 | def eval_non_central_second_moment(self):
186 | term_range_min = self.point_mass_range_min * self.range_min**2
187 | term_range_max = self.point_mass_range_max * self.range_max**2
188 | term_middle_intergral = self.integr_interv_p_sqr_r(self.range_min, self.range_max, 0.0)
189 | return term_range_min + term_middle_intergral + term_range_max
190 |
191 |
192 | class ClippedStudentTDistr(DistrBase):
193 | def __init__(self, *args, **kwargs):
194 | super().__init__(*args, **kwargs)
195 | nu = self.params_dict["nu"]
196 | self.point_mass_range_min = stats.t.cdf(self.range_min, nu)
197 | self.point_mass_range_max = 1.0 - stats.t.cdf(self.range_max, nu)
198 |
199 | def print(self):
200 | print(
201 | "Student's-t distr",
202 | ", nu = ",
203 | self.params_dict["nu"],
204 | " clipped at [",
205 | self.range_min,
206 | ",",
207 | self.range_max,
208 | "]",
209 | )
210 |
211 | def pdf(self, x):
212 | x_np = x.cpu().numpy()
213 | p = stats.t.pdf(x_np, self.params_dict["nu"])
214 | return p
215 |
216 | def cdf(self, x):
217 | p = stats.t.cdf(x, self.params_dict["nu"])
218 | return p
219 |
220 | def inverse_cdf(self, x):
221 | res = stats.t.ppf(x, self.params_dict["nu"])
222 | return res
223 |
224 | def sample(self, shape):
225 | r = np.random.standard_t(self.params_dict["nu"], size=shape)
226 | r = np.clip(r, self.range_min, self.range_max)
227 | return r
228 |
229 | def integr_interv_p_sqr_r(self, a, b, u):
230 | assert b >= a
231 | nu = self.params_dict["nu"]
232 |
233 | first_term = (2.0 * nu * (-1.0 + ((a**2 + nu) / nu) ** (1.0 / 2.0 - nu / 2.0)) * u) / (
234 | 1.0 - nu
235 | )
236 | second_term = -(2 * nu * (-1 + ((b**2 + nu) / nu) ** (1.0 / 2.0 - nu / 2)) * u) / (
237 | 1.0 - nu
238 | )
239 | third_term = (
240 | -a
241 | * u**2
242 | * scipy.special.hyp2f1(1.0 / 2.0, (1.0 + nu) / 2.0, 3.0 / 2.0, -(a**2.0 / nu))
243 | )
244 | forth_term = (
245 | b
246 | * u**2.0
247 | * scipy.special.hyp2f1(1.0 / 2.0, (1.0 + nu) / 2.0, 3.0 / 2.0, -(b**2 / nu))
248 | )
249 | fifth_term = (
250 | -1.0
251 | / 3.0
252 | * a**3
253 | * scipy.special.hyp2f1(3.0 / 2.0, (1.0 + nu) / 2.0, 5.0 / 2.0, -(a**2 / nu))
254 | )
255 | sixth_term = (
256 | 1.0
257 | / 3.0
258 | * b**3
259 | * scipy.special.hyp2f1(3.0 / 2.0, (1.0 + nu) / 2.0, 5.0 / 2.0, -(b**2 / nu))
260 | )
261 | res = first_term + second_term + third_term + forth_term + fifth_term + sixth_term
262 |
263 | const = (
264 | scipy.special.gamma(0.5 * (nu + 1.0))
265 | / np.sqrt(np.pi * nu)
266 | / scipy.special.gamma(0.5 * nu)
267 | )
268 | return res * const
269 |
270 | def integr_p_times_x(self, a, b):
271 | assert b >= a
272 | nu = self.params_dict["nu"]
273 | res = 0.0
274 |
275 | const = (
276 | scipy.special.gamma(0.5 * (nu + 1.0))
277 | / np.sqrt(np.pi * nu)
278 | / scipy.special.gamma(0.5 * nu)
279 | )
280 | return res * const
281 |
282 | def scale(self):
283 | nu = self.params_dict["nu"]
284 | res = (
285 | scipy.special.gamma(0.5 * (nu + 1.0))
286 | / np.sqrt(np.pi * nu)
287 | / scipy.special.gamma(0.5 * nu)
288 | )
289 | return res
290 |
291 | def integr_cubic_root_p(self, a, b):
292 | assert b >= a
293 | nu = self.params_dict["nu"]
294 |
295 | common_mult = 1.0 / (nu - 2.0) * 3.0 * (a * b) ** (-nu / 3.0) * nu ** ((1.0 + nu) / 6.0)
296 | first_term = (
297 | scipy.special.hyp2f1(
298 | 1 / 6.0 * (-2.0 + nu), (1.0 + nu) / 6, (4.0 + nu) / 6.0, -nu / a**2
299 | )
300 | * a ** (2.0 / 3.0)
301 | * b ** (nu / 3.0)
302 | )
303 | second_term = (
304 | -scipy.special.hyp2f1(
305 | 1 / 6.0 * (-2.0 + nu), (1.0 + nu) / 6, (4.0 + nu) / 6.0, -nu / b**2
306 | )
307 | * b ** (2.0 / 3.0)
308 | * a ** (nu / 3.0)
309 | )
310 |
311 | return common_mult * (first_term + second_term)
312 |
313 | def integr_interv_u_t_times_p_no_constant(self, a, b, u):
314 | assert b >= a
315 | df = self.params_dict["nu"]
316 | res = (
317 | df ** ((1.0 + df) / 2.0)
318 | * (-((a**2 + df) ** (1.0 / 2.0 - df / 2.0)) + (b**2 + df) ** (1.0 / 2.0 - df / 2.0))
319 | * u
320 | ) / (1.0 - df)
321 | return res
322 |
323 | def integr_interv_x_p_signed_r(self, a, b, x0):
324 | assert b >= a
325 | nu = self.params_dict["nu"]
326 |
327 | const = (
328 | scipy.special.gamma(0.5 * (nu + 1.0))
329 | / np.sqrt(np.pi * nu)
330 | / scipy.special.gamma(0.5 * nu)
331 | )
332 | r1 = self.integr_interv_u_t_times_p_no_constant(a, b, x0) * const
333 | r2 = self.integr_interv_p_sqr_r(a, b, 0.0)
334 |
335 | res = r1 - r2
336 | return res
337 |
338 | def eval_non_central_second_moment(self):
339 | term_range_min = self.point_mass_range_min * self.range_min**2
340 | term_range_max = self.point_mass_range_max * self.range_max**2
341 | term_middle_intergral = self.integr_interv_p_sqr_r(self.range_min, self.range_max, 0.0)
342 | return term_range_min + term_middle_intergral + term_range_max
343 |
344 |
345 | class UniformDistr(DistrBase):
346 | def __init__(self, *args, **kwargs):
347 | super().__init__(*args, **kwargs)
348 | self.p = 1 / (self.range_max - self.range_min)
349 |
350 | def print(self):
351 | print("Uniform distribution on [", self.range_min, ",", self.range_max, "]")
352 |
353 | def pdf(self, x):
354 | return self.p
355 |
356 | def cdf(self, x):
357 | return (x - self.range_min) * self.p
358 |
359 | def sample(self, shape):
360 | return np.random.uniform(self.range_min, self.range_max, shape)
361 |
362 | def integr_interv_p_sqr_r(self, a, b, u):
363 | assert b >= a
364 | res = -(a**3 / 3.0) + b**3 / 3.0 + a**2 * u - b**2 * u - a * u**2 + b * u**2
365 | return res * self.p
366 |
367 | def eval_non_central_second_moment(self):
368 | if not isinstance(self, UniformDistr):
369 | term_range_min = self.point_mass_range_min() * self.range_min**2
370 | term_range_max = self.point_mass_range_max() * self.range_max**2
371 | else:
372 | term_range_min = 0.0
373 | term_range_max = 0.0
374 | term_middle_intergral = self.integr_interv_p_sqr_r(self.range_min, self.range_max, 0.0)
375 |
376 | return term_range_min + term_middle_intergral + term_range_max
377 |
378 | def integr_p_times_x(self, a, b):
379 | return 0.5 * (b**2 - a**2) * self.p
380 |
381 | def integr_interv_x_p_signed_r(self, a, b, x0):
382 | assert b >= a
383 | res = 0.5 * a**2 - 0.5 * b**2 + (b - a) * x0
384 | return res * self.p
385 |
--------------------------------------------------------------------------------
/utils/grid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import torch
6 | import numpy as np
7 | from utils.distributions import ClippedGaussDistr, ClippedStudentTDistr
8 |
9 |
10 | def rounding_error_abs_nearest(x_float, grid):
11 | n_grid = grid.size
12 | grid_row = np.array(grid).reshape(1, n_grid)
13 | m_vals = x_float.numel()
14 | x_float = x_float.cpu().detach().numpy().reshape(m_vals, 1)
15 |
16 | dist = np.abs(x_float - grid_row)
17 | min_dist = np.min(dist, axis=1)
18 |
19 | return min_dist
20 |
21 |
22 | def quant_scalar_nearest(x_float, grid):
23 | dist = np.abs(x_float - grid)
24 | idx = np.argmin(dist)
25 | q_x = grid[idx]
26 | return q_x
27 |
28 |
29 | def clip_grid_exclude_bounds(grid, min_val, max_val):
30 | idx_subset = torch.logical_and(grid > min_val, grid < max_val)
31 | return grid[idx_subset]
32 |
33 |
34 | def clip_grid_include_bounds(grid, min_val, max_val):
35 | idx_subset = torch.logical_and(grid >= min_val, grid <= max_val)
36 | return grid[idx_subset]
37 |
38 |
39 | def clip_grid_add_bounds(grid, min_val, max_val):
40 | grid_clipped = clip_grid_exclude_bounds(grid, min_val, max_val)
41 | bounds_np = np.array([min_val, max_val])
42 | clipped_with_bounds = np.sort(np.concatenate((grid_clipped, bounds_np)))
43 | return clipped_with_bounds
44 |
45 |
46 | def integrate_pdf_grid_func_analyt(distr, grid, distr_attr_func_name):
47 | grid = np.sort(grid)
48 | interval_integr_func = getattr(distr, distr_attr_func_name)
49 | res = 0.0
50 |
51 | if distr.range_min < grid[0]:
52 | res += interval_integr_func(distr.range_min, grid[0], grid[0])
53 |
54 | for i_interval in range(0, len(grid) - 1):
55 | grid_mid = 0.5 * (grid[i_interval] + grid[i_interval + 1])
56 |
57 | first_half_a = max(grid[i_interval], distr.range_min)
58 | first_half_b = min(grid_mid, distr.range_max)
59 |
60 | second_half_a = max(grid_mid, distr.range_min)
61 | second_half_b = min(grid[i_interval + 1], distr.range_max)
62 |
63 | if first_half_a < first_half_b:
64 | res += interval_integr_func(first_half_a, first_half_b, grid[i_interval])
65 | if second_half_a < second_half_b:
66 | res += interval_integr_func(second_half_a, second_half_b, grid[i_interval + 1])
67 |
68 | if distr.range_max > grid[-1]:
69 | res += interval_integr_func(grid[-1], distr.range_max, grid[-1])
70 |
71 | if (
72 | isinstance(distr, ClippedGaussDistr) or isinstance(distr, ClippedStudentTDistr)
73 | ) and distr_attr_func_name == "integr_interv_x_p_signed_r":
74 | q_range_min = quant_scalar_nearest(torch.Tensor([distr.range_min]), grid)
75 | q_range_max = quant_scalar_nearest(torch.Tensor([distr.range_max]), grid)
76 |
77 | term_point_mass = (
78 | distr.range_min * (q_range_min - distr.range_min) * distr.point_mass_range_min
79 | + distr.range_max * (q_range_max - distr.range_max) * distr.point_mass_range_max
80 | )
81 | res += term_point_mass
82 | elif (
83 | isinstance(distr, ClippedGaussDistr) or isinstance(distr, ClippedStudentTDistr)
84 | ) and distr_attr_func_name == "integr_interv_p_sqr_r":
85 | q_range_min = quant_scalar_nearest(torch.Tensor([distr.range_min]), grid)
86 | q_range_max = quant_scalar_nearest(torch.Tensor([distr.range_max]), grid)
87 |
88 | term_point_mass = (q_range_min - distr.range_min) ** 2 * distr.point_mass_range_min + (
89 | q_range_max - distr.range_max
90 | ) ** 2 * distr.point_mass_range_max
91 | res += term_point_mass
92 |
93 | return res
94 |
--------------------------------------------------------------------------------
/utils/imagenet_dataloaders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import os
6 |
7 | import torchvision
8 | import torch.utils.data as torch_data
9 | from torchvision import transforms
10 | from utils import BaseEnumOptions
11 |
12 |
13 | class ImageInterpolation(BaseEnumOptions):
14 | nearest = transforms.InterpolationMode.NEAREST
15 | box = transforms.InterpolationMode.BOX
16 | bilinear = transforms.InterpolationMode.BILINEAR
17 | hamming = transforms.InterpolationMode.HAMMING
18 | bicubic = transforms.InterpolationMode.BICUBIC
19 | lanczos = transforms.InterpolationMode.LANCZOS
20 |
21 |
22 | class ImageNetDataLoaders(object):
23 | """
24 | Data loader provider for ImageNet images, providing a train and a validation loader.
25 | It assumes that the structure of the images is
26 | images_dir
27 | - train
28 | - label1
29 | - label2
30 | - ...
31 | - val
32 | - label1
33 | - label2
34 | - ...
35 | """
36 |
37 | def __init__(
38 | self,
39 | images_dir: str,
40 | image_size: int,
41 | batch_size: int,
42 | num_workers: int,
43 | interpolation: transforms.InterpolationMode,
44 | ):
45 | """
46 | Parameters
47 | ----------
48 | images_dir: str
49 | Root image directory
50 | image_size: int
51 | Number of pixels the image will be re-sized to (square)
52 | batch_size: int
53 | Batch size of both the training and validation loaders
54 | num_workers
55 | Number of parallel workers loading the images
56 | interpolation: transforms.InterpolationMode
57 | Desired interpolation to use for resizing.
58 | """
59 |
60 | self.images_dir = images_dir
61 | self.batch_size = batch_size
62 | self.num_workers = num_workers
63 |
64 | # For normalization, mean and std dev values are calculated per channel
65 | # and can be found on the web.
66 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
67 |
68 | self.train_transforms = transforms.Compose(
69 | [
70 | transforms.RandomResizedCrop(image_size, interpolation=interpolation.value),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | normalize,
74 | ]
75 | )
76 |
77 | self.val_transforms = transforms.Compose(
78 | [
79 | transforms.Resize(image_size + 24, interpolation=interpolation.value),
80 | transforms.CenterCrop(image_size),
81 | transforms.ToTensor(),
82 | normalize,
83 | ]
84 | )
85 |
86 | self._train_loader = None
87 | self._val_loader = None
88 |
89 | @property
90 | def train_loader(self) -> torch_data.DataLoader:
91 | if not self._train_loader:
92 | root = os.path.join(self.images_dir, "train")
93 | train_set = torchvision.datasets.ImageFolder(root, transform=self.train_transforms)
94 | self._train_loader = torch_data.DataLoader(
95 | train_set,
96 | batch_size=self.batch_size,
97 | shuffle=True,
98 | num_workers=self.num_workers,
99 | pin_memory=True,
100 | )
101 | return self._train_loader
102 |
103 | @property
104 | def val_loader(self) -> torch_data.DataLoader:
105 | if not self._val_loader:
106 | root = os.path.join(self.images_dir, "val")
107 | val_set = torchvision.datasets.ImageFolder(root, transform=self.val_transforms)
108 | self._val_loader = torch_data.DataLoader(
109 | val_set,
110 | batch_size=self.batch_size,
111 | shuffle=False,
112 | num_workers=self.num_workers,
113 | pin_memory=True,
114 | )
115 | return self._val_loader
116 |
--------------------------------------------------------------------------------
/utils/optimizer_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import torch
6 |
7 |
8 | def get_lr_scheduler(optimizer, lr_schedule, epochs):
9 | scheduler = None
10 | if lr_schedule:
11 | if lr_schedule.startswith("multistep"):
12 | epochs = [int(s) for s in lr_schedule.split(":")[1:]]
13 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epochs)
14 | elif lr_schedule.startswith("cosine"):
15 | eta_min = float(lr_schedule.split(":")[1])
16 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
17 | optimizer, epochs, eta_min=eta_min
18 | )
19 | return scheduler
20 |
21 |
22 | def optimizer_lr_factory(config_optim, params, epochs):
23 | if config_optim.optimizer.lower() == "sgd":
24 | optimizer = torch.optim.SGD(
25 | params,
26 | lr=config_optim.learning_rate,
27 | momentum=config_optim.momentum,
28 | weight_decay=config_optim.weight_decay,
29 | )
30 | elif config_optim.optimizer.lower() == "adam":
31 | optimizer = torch.optim.Adam(
32 | params, lr=config_optim.learning_rate, weight_decay=config_optim.weight_decay
33 | )
34 | else:
35 | raise ValueError()
36 |
37 | lr_scheduler = get_lr_scheduler(optimizer, config_optim.learning_rate_schedule, epochs)
38 |
39 | return optimizer, lr_scheduler
40 |
--------------------------------------------------------------------------------
/utils/qat_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import copy
6 | import torch
7 |
8 | from quantization.quantized_folded_bn import BNFusedHijacker
9 | from utils.imagenet_dataloaders import ImageNetDataLoaders
10 |
11 |
12 | def get_dataloaders_and_model(config, load_type="fp32", **qparams):
13 | dataloaders = ImageNetDataLoaders(
14 | config.base.images_dir,
15 | 224,
16 | config.base.batch_size,
17 | config.base.num_workers,
18 | config.base.interpolation,
19 | )
20 |
21 | model = config.base.architecture(
22 | pretrained=config.base.pretrained,
23 | load_type=load_type,
24 | model_dir=config.base.model_dir,
25 | **qparams,
26 | )
27 | if config.base.cuda:
28 | model = model.cuda()
29 |
30 | return dataloaders, model
31 |
32 |
33 | class ReestimateBNStats:
34 | def __init__(self, model, data_loader, num_batches=50):
35 | super().__init__()
36 | self.model = model
37 | self.data_loader = data_loader
38 | self.num_batches = num_batches
39 |
40 | def __call__(self, engine):
41 | print("-- Reestimate current BN statistics --")
42 | reestimate_BN_stats(self.model, self.data_loader, self.num_batches)
43 |
44 |
45 | def reestimate_BN_stats(model, data_loader, num_batches=50, store_ema_stats=False):
46 | # We set BN momentum to 1 an use train mode
47 | # -> the running mean/var have the current batch statistics
48 | model.eval()
49 | org_momentum = {}
50 | for name, module in model.named_modules():
51 | if isinstance(module, BNFusedHijacker):
52 | org_momentum[name] = module.momentum
53 | module.momentum = 1.0
54 | module.running_mean_sum = torch.zeros_like(module.running_mean)
55 | module.running_var_sum = torch.zeros_like(module.running_var)
56 | # Set all BNFusedHijacker modules to train mode for but not its children
57 | module.training = True
58 |
59 | if store_ema_stats:
60 | # Save the original EMA, make sure they are in buffers so they end in the state dict
61 | if not hasattr(module, "running_mean_ema"):
62 | module.register_buffer("running_mean_ema", copy.deepcopy(module.running_mean))
63 | module.register_buffer("running_var_ema", copy.deepcopy(module.running_var))
64 | else:
65 | module.running_mean_ema = copy.deepcopy(module.running_mean)
66 | module.running_var_ema = copy.deepcopy(module.running_var)
67 |
68 | # Run data for estimation
69 | device = next(model.parameters()).device
70 | batch_count = 0
71 | with torch.no_grad():
72 | for x, y in data_loader:
73 | model(x.to(device))
74 | # We save the running mean/var to a buffer
75 | for name, module in model.named_modules():
76 | if isinstance(module, BNFusedHijacker):
77 | module.running_mean_sum += module.running_mean
78 | module.running_var_sum += module.running_var
79 |
80 | batch_count += 1
81 | if batch_count == num_batches:
82 | break
83 | # At the end we normalize the buffer and write it into the running mean/var
84 | for name, module in model.named_modules():
85 | if isinstance(module, BNFusedHijacker):
86 | module.running_mean = module.running_mean_sum / batch_count
87 | module.running_var = module.running_var_sum / batch_count
88 | # We reset the momentum in case it would be used anywhere else
89 | module.momentum = org_momentum[name]
90 | model.eval()
91 |
--------------------------------------------------------------------------------
/utils/stopwatch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import sys
6 | import time
7 |
8 |
9 | class Stopwatch:
10 | """
11 | A simple cross-platform context-manager stopwatch.
12 |
13 | Examples
14 | --------
15 | >>> import time
16 | >>> with Stopwatch(verbose=True) as st:
17 | ... time.sleep(0.101) #doctest: +ELLIPSIS
18 | Elapsed time: 0.10... sec
19 | """
20 |
21 | def __init__(self, name=None, verbose=False):
22 | self._name = name
23 | self._verbose = verbose
24 |
25 | self._start_time_point = 0.0
26 | self._total_duration = 0.0
27 | self._is_running = False
28 |
29 | if sys.platform == "win32":
30 | # on Windows, the best timer is time.clock()
31 | self._timer_fn = time.clock
32 | else:
33 | # on most other platforms, the best timer is time.time()
34 | self._timer_fn = time.time
35 |
36 | def __enter__(self, verbose=False):
37 | return self.start()
38 |
39 | def __exit__(self, exc_type, exc_val, exc_tb):
40 | self.stop()
41 | if self._verbose:
42 | self.print()
43 |
44 | def start(self):
45 | if not self._is_running:
46 | self._start_time_point = self._timer_fn()
47 | self._is_running = True
48 | return self
49 |
50 | def stop(self):
51 | if self._is_running:
52 | self._total_duration += self._timer_fn() - self._start_time_point
53 | self._is_running = False
54 | return self
55 |
56 | def reset(self):
57 | self._start_time_point = 0.0
58 | self._total_duration = 0.0
59 | self._is_running = False
60 | return self
61 |
62 | def _update_state(self):
63 | now = self._timer_fn()
64 | self._total_duration += now - self._start_time_point
65 | self._start_time_point = now
66 |
67 | def _format(self):
68 | prefix = f"[{self._name}]" if self._name is not None else "Elapsed time"
69 | info = f"{prefix}: {self._total_duration:.3f} sec"
70 | return info
71 |
72 | def format(self):
73 | if self._is_running:
74 | self._update_state()
75 | return self._format()
76 |
77 | def print(self):
78 | print(self.format())
79 |
80 | def get_total_duration(self):
81 | if self._is_running:
82 | self._update_state()
83 | return self._total_duration
84 |
--------------------------------------------------------------------------------
/utils/supervised_driver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from ignite.contrib.handlers import TensorboardLogger
6 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
7 | from ignite.handlers import Checkpoint, global_step_from_engine
8 | from torch.optim import Optimizer
9 |
10 |
11 | def create_trainer_engine(
12 | model,
13 | optimizer,
14 | criterion,
15 | metrics,
16 | data_loaders,
17 | lr_scheduler=None,
18 | save_checkpoint_dir=None,
19 | device="cuda",
20 | ):
21 | # Create trainer
22 | trainer = create_supervised_trainer(
23 | model=model,
24 | optimizer=optimizer,
25 | loss_fn=criterion,
26 | device=device,
27 | output_transform=custom_output_transform,
28 | )
29 |
30 | for name, metric in metrics.items():
31 | metric.attach(trainer, name)
32 |
33 | # Add lr_scheduler
34 | if lr_scheduler:
35 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())
36 |
37 | # Create evaluator
38 | evaluator = create_supervised_evaluator(model=model, metrics=metrics, device=device)
39 |
40 | # Save model checkpoint
41 | if save_checkpoint_dir:
42 | to_save = {"model": model, "optimizer": optimizer}
43 | if lr_scheduler:
44 | to_save["lr_scheduler"] = lr_scheduler
45 | checkpoint = Checkpoint(
46 | to_save,
47 | save_checkpoint_dir,
48 | n_saved=1,
49 | global_step_transform=global_step_from_engine(trainer),
50 | )
51 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)
52 |
53 | # Add hooks for logging metrics
54 | trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results, optimizer)
55 |
56 | trainer.add_event_handler(
57 | Events.EPOCH_COMPLETED, run_evaluation_for_training, evaluator, data_loaders.val_loader
58 | )
59 |
60 | return trainer, evaluator
61 |
62 |
63 | def custom_output_transform(x, y, y_pred, loss):
64 | return y_pred, y
65 |
66 |
67 | def log_training_results(trainer, optimizer):
68 | learning_rate = optimizer.param_groups[0]["lr"]
69 | log_metrics(trainer.state.metrics, "Training", trainer.state.epoch, learning_rate)
70 |
71 |
72 | def run_evaluation_for_training(trainer, evaluator, val_loader):
73 | evaluator.run(val_loader)
74 | log_metrics(evaluator.state.metrics, "Evaluation", trainer.state.epoch)
75 |
76 |
77 | def log_metrics(metrics, stage: str = "", training_epoch=None, learning_rate=None):
78 | log_text = " {}".format(metrics) if metrics else ""
79 | if training_epoch is not None:
80 | log_text = "Epoch: {}".format(training_epoch) + log_text
81 | if learning_rate and learning_rate > 0.0:
82 | log_text += " Learning rate: {:.2E}".format(learning_rate)
83 | log_text = "Results - " + log_text
84 | if stage:
85 | log_text = "{} ".format(stage) + log_text
86 | print(log_text, flush=True)
87 |
88 |
89 | def setup_tensorboard_logger(trainer, evaluator, output_path, optimizers=None):
90 | logger = TensorboardLogger(logdir=output_path)
91 |
92 | # Attach the logger to log loss and accuracy for both training and validation
93 | for tag, cur_evaluator in [("train", trainer), ("validation", evaluator)]:
94 | logger.attach_output_handler(
95 | cur_evaluator,
96 | event_name=Events.EPOCH_COMPLETED,
97 | tag=tag,
98 | metric_names="all",
99 | global_step_transform=global_step_from_engine(trainer),
100 | )
101 |
102 | # Log optimizer parameters
103 | if isinstance(optimizers, Optimizer):
104 | optimizers = {None: optimizers}
105 |
106 | for k, optimizer in optimizers.items():
107 | logger.attach_opt_params_handler(
108 | trainer, Events.EPOCH_COMPLETED, optimizer, param_name="lr", tag=k
109 | )
110 |
111 | return logger
112 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import collections
6 | import os
7 | import random
8 | from collections import namedtuple
9 | from enum import Flag, auto
10 | from functools import partial
11 |
12 | import click
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | class DotDict(dict):
19 | """
20 | A dictionary that allows attribute-style access.
21 | Examples
22 | --------
23 | >>> config = DotDict(a=None)
24 | >>> config.a = 42
25 | >>> config.b = 'egg'
26 | >>> config # can be used as dict
27 | {'a': 42, 'b': 'egg'}
28 | """
29 |
30 | def __setattr__(self, key, value):
31 | self.__setitem__(key, value)
32 |
33 | def __delattr__(self, key):
34 | self.__delitem__(key)
35 |
36 | def __getattr__(self, key):
37 | if key in self:
38 | return self.__getitem__(key)
39 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})")
40 |
41 |
42 | def relu(x):
43 | x = np.array(x)
44 | return x * (x > 0)
45 |
46 |
47 | def get_all_layer_names(model, subtypes=None):
48 | if subtypes is None:
49 | return [name for name, module in model.named_modules()][1:]
50 | return [name for name, module in model.named_modules() if isinstance(module, subtypes)]
51 |
52 |
53 | def get_layer_name_to_module_dict(model):
54 | return {name: module for name, module in model.named_modules() if name}
55 |
56 |
57 | def get_module_to_layer_name_dict(model):
58 | modules_to_names = collections.OrderedDict()
59 | for name, module in model.named_modules():
60 | modules_to_names[module] = name
61 | return modules_to_names
62 |
63 |
64 | def get_layer_name(model, layer):
65 | for name, module in model.named_modules():
66 | if module == layer:
67 | return name
68 | return None
69 |
70 |
71 | def get_layer_by_name(model, layer_name):
72 | for name, module in model.named_modules():
73 | if name == layer_name:
74 | return module
75 | return None
76 |
77 |
78 | def create_conv_layer_list(cls, model: nn.Module) -> list:
79 | """
80 | Function finds all prunable layers in the provided model
81 |
82 | Parameters
83 | ----------
84 | cls: SVD class
85 | model : torch.nn.Module
86 | A pytorch model.
87 |
88 | Returns
89 | -------
90 | conv_layer_list : list
91 | List of all prunable layers in the given model.
92 |
93 | """
94 | conv_layer_list = []
95 |
96 | def fill_list(mod):
97 | if isinstance(mod, tuple(cls.supported_layer_types)):
98 | conv_layer_list.append(mod)
99 |
100 | model.apply(fill_list)
101 | return conv_layer_list
102 |
103 |
104 | def create_linear_layer_list(cls, model: nn.Module) -> list:
105 | """
106 | Function finds all prunable layers in the provided model
107 |
108 | Parameters
109 | ----------
110 | model : torch.nn.Module
111 | A pytorch model.
112 |
113 | Returns
114 | -------
115 | conv_layer_list : list
116 | List of all prunable layers in the given model.
117 |
118 | """
119 | conv_layer_list = []
120 |
121 | def fill_list(mod):
122 | if isinstance(mod, tuple(cls.supported_layer_types)):
123 | conv_layer_list.append(mod)
124 |
125 | model.apply(fill_list)
126 | return conv_layer_list
127 |
128 |
129 | def to_numpy(tensor):
130 | """
131 | Helper function that turns the given tensor into a numpy array
132 |
133 | Parameters
134 | ----------
135 | tensor : torch.Tensor
136 |
137 | Returns
138 | -------
139 | tensor : float or np.array
140 |
141 | """
142 | if isinstance(tensor, np.ndarray):
143 | return tensor
144 | if hasattr(tensor, "is_cuda"):
145 | if tensor.is_cuda:
146 | return tensor.cpu().detach().numpy()
147 | if hasattr(tensor, "detach"):
148 | return tensor.detach().numpy()
149 | if hasattr(tensor, "numpy"):
150 | return tensor.numpy()
151 |
152 | return np.array(tensor)
153 |
154 |
155 | def set_module_attr(model, layer_name, value):
156 | split = layer_name.split(".")
157 |
158 | this_module = model
159 | for mod_name in split[:-1]:
160 | if mod_name.isdigit():
161 | this_module = this_module[int(mod_name)]
162 | else:
163 | this_module = getattr(this_module, mod_name)
164 |
165 | last_mod_name = split[-1]
166 | if last_mod_name.isdigit():
167 | this_module[int(last_mod_name)] = value
168 | else:
169 | setattr(this_module, last_mod_name, value)
170 |
171 |
172 | def search_for_zero_planes(model: torch.nn.Module):
173 | """If list of modules to winnow is empty to start with, search through all modules to check
174 | if any
175 | planes have been zeroed out. Update self._list_of_modules_to_winnow with any findings.
176 | :param model: torch model to search through modules for zeroed parameters
177 | """
178 |
179 | list_of_modules_to_winnow = []
180 | for _, module in model.named_modules():
181 | if isinstance(module, (torch.nn.Linear, torch.nn.modules.conv.Conv2d)):
182 | in_channels_to_winnow = _assess_weight_and_bias(module.weight, module.bias)
183 | if in_channels_to_winnow:
184 | list_of_modules_to_winnow.append((module, in_channels_to_winnow))
185 | return list_of_modules_to_winnow
186 |
187 |
188 | def _assess_weight_and_bias(weight: torch.nn.Parameter, _bias: torch.nn.Parameter):
189 | """4-dim weights [CH-out, CH-in, H, W] and 1-dim bias [CH-out]"""
190 | if len(weight.shape) > 2:
191 | input_channels_to_ignore = (weight.sum((0, 2, 3)) == 0).nonzero().squeeze().tolist()
192 | else:
193 | input_channels_to_ignore = (weight.sum(0) == 0).nonzero().squeeze().tolist()
194 |
195 | if type(input_channels_to_ignore) != list:
196 | input_channels_to_ignore = [input_channels_to_ignore]
197 |
198 | return input_channels_to_ignore
199 |
200 |
201 | def seed_all(seed: int = 1029, deterministic: bool = False):
202 | """
203 | This is our attempt to make experiments reproducible by seeding all known RNGs and setting
204 | appropriate torch directives.
205 | For a general discussion of reproducibility in Pytorch and CUDA and a documentation of the
206 | options we are using see, e.g.,
207 | https://pytorch.org/docs/1.7.1/notes/randomness.html
208 | https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
209 |
210 | As of today (July 2021), even after seeding and setting some directives,
211 | there remain unfortunate contradictions:
212 | 1. CUDNN
213 | - having CUDNN enabled leads to
214 | - non-determinism in Pytorch when using the GPU, cf. MORPH-10999.
215 | - having CUDNN disabled leads to
216 | - most regression tests in Qrunchy failing, cf. MORPH-11103
217 | - significantly increased execution time in some cases
218 | - performance degradation in some cases
219 | 2. torch.set_deterministic(d)
220 | - setting d = True leads to errors for Pytorch algorithms that do not (yet) have a deterministic
221 | counterpart, e.g., the layer `adaptive_avg_pool2d_backward_cuda` in vgg16__torchvision.
222 |
223 | Thus, we leave the choice of enforcing determinism by disabling CUDNN and non-deterministic
224 | algorithms to the user. To keep it simple, we only have one switch for both.
225 | This situation could be re-evaluated upon updates of Pytorch, CUDA, CUDNN.
226 | """
227 |
228 | assert isinstance(seed, int), f"RNG seed must be an integer ({seed})"
229 | assert seed >= 0, f"RNG seed must be a positive integer ({seed})"
230 |
231 | # Builtin RNGs
232 | random.seed(seed)
233 | os.environ["PYTHONHASHSEED"] = str(seed)
234 |
235 | # Numpy RNG
236 | np.random.seed(seed)
237 |
238 | # CUDNN determinism (setting those has not lead to errors so far)
239 | torch.backends.cudnn.benchmark = False
240 | torch.backends.cudnn.deterministic = True
241 |
242 | # Torch RNGs
243 | torch.manual_seed(seed)
244 | torch.cuda.manual_seed(seed)
245 | torch.cuda.manual_seed_all(seed)
246 |
247 | # Problematic settings, see docstring. Precaution: We do not mutate unless asked to do so
248 | if deterministic is True:
249 | torch.backends.cudnn.enabled = False
250 |
251 | torch.set_deterministic(True) # Use torch.use_deterministic_algorithms(True) in torch 1.8.1
252 | # When using torch.set_deterministic(True), it is advised by Pytorch to set the
253 | # CUBLAS_WORKSPACE_CONFIG variable as follows, see
254 | # https://pytorch.org/docs/1.7.1/notes/randomness.html#avoiding-nondeterministic-algorithms
255 | # and the link to the CUDA homepage on that website.
256 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
257 |
258 |
259 | def assert_allclose(actual, desired, *args, **kwargs):
260 | """A more beautiful version of torch.all_close."""
261 | np.testing.assert_allclose(to_numpy(actual), to_numpy(desired), *args, **kwargs)
262 |
263 |
264 | def count_params(module):
265 | return len(nn.utils.parameters_to_vector(module.parameters()))
266 |
267 |
268 | class StopForwardException(Exception):
269 | """Used to throw and catch an exception to stop traversing the graph."""
270 |
271 | pass
272 |
273 |
274 | class StopForwardHook:
275 | def __call__(self, module, *args):
276 | raise StopForwardException
277 |
278 |
279 | def sigmoid(x):
280 | return 1.0 / (1.0 + np.exp(-x))
281 |
282 |
283 | class CosineTempDecay:
284 | def __init__(self, t_max, temp_range=(20.0, 2.0), rel_decay_start=0):
285 | self.t_max = t_max
286 | self.start_temp, self.end_temp = temp_range
287 | self.decay_start = rel_decay_start * t_max
288 |
289 | def __call__(self, t):
290 | if t < self.decay_start:
291 | return self.start_temp
292 |
293 | rel_t = (t - self.decay_start) / (self.t_max - self.decay_start)
294 | return self.end_temp + 0.5 * (self.start_temp - self.end_temp) * (1 + np.cos(rel_t * np.pi))
295 |
296 |
297 | class BaseEnumOptions(Flag):
298 | def __str__(self):
299 | return self.name
300 |
301 | @classmethod
302 | def list_names(cls):
303 | return [m.name for m in cls]
304 |
305 |
306 | class ClassEnumOptions(BaseEnumOptions):
307 | @property
308 | def cls(self):
309 | return self.value.cls
310 |
311 | def __call__(self, *args, **kwargs):
312 | return self.value.cls(*args, **kwargs)
313 |
314 |
315 | MethodMap = partial(namedtuple("MethodMap", ["value", "cls"]), auto())
316 |
317 |
318 | def split_dict(src: dict, include=(), remove_prefix: str = ""):
319 | """
320 | Splits dictionary into a DotDict and a remainder.
321 | The arguments to be placed in the first DotDict are those listed in `include`.
322 | Parameters
323 | ----------
324 | src: dict
325 | The source dictionary.
326 | include:
327 | List of keys to be returned in the first DotDict.
328 | remove_suffix:
329 | remove prefix from key
330 | """
331 | result = DotDict()
332 |
333 | for arg in include:
334 | if remove_prefix:
335 | key = arg.replace(f"{remove_prefix}_", "", 1)
336 | else:
337 | key = arg
338 | result[key] = src[arg]
339 | remainder = {key: val for key, val in src.items() if key not in include}
340 | return result, remainder
341 |
342 |
343 | class ClickEnumOption(click.Choice):
344 | """
345 | Adjusted click.Choice type for BaseOption which is based on Enum
346 | """
347 |
348 | def __init__(self, enum_options, case_sensitive=True):
349 | assert issubclass(enum_options, BaseEnumOptions)
350 | self.base_option = enum_options
351 | super().__init__(self.base_option.list_names(), case_sensitive)
352 |
353 | def convert(self, value, param, ctx):
354 | # Exact match
355 | if value in self.choices:
356 | return self.base_option[value]
357 |
358 | # Match through normalization and case sensitivity
359 | # first do token_normalize_func, then lowercase
360 | # preserve original `value` to produce an accurate message in
361 | # `self.fail`
362 | normed_value = value
363 | normed_choices = self.choices
364 |
365 | if ctx is not None and ctx.token_normalize_func is not None:
366 | normed_value = ctx.token_normalize_func(value)
367 | normed_choices = [ctx.token_normalize_func(choice) for choice in self.choices]
368 |
369 | if not self.case_sensitive:
370 | normed_value = normed_value.lower()
371 | normed_choices = [choice.lower() for choice in normed_choices]
372 |
373 | if normed_value in normed_choices:
374 | return self.base_option[normed_value]
375 |
376 | self.fail(
377 | "invalid choice: %s. (choose from %s)" % (value, ", ".join(self.choices)), param, ctx
378 | )
379 |
--------------------------------------------------------------------------------