├── LICENSE
├── README.md
├── bash
│   ├── set_env.sh
│   ├── train_efficientnet.sh
│   └── train_mobilenetv2.sh
├── display
│   └── toy_regression.gif
├── main.py
├── models
│   ├── __init__.py
│   ├── efficientnet_lite_quantized.py
│   ├── mobilenet_v2.py
│   ├── mobilenet_v2_quantized.py
│   └── resnet_quantized.py
├── quantization
│   ├── __init__.py
│   ├── autoquant_utils.py
│   ├── base_quantized_classes.py
│   ├── base_quantized_model.py
│   ├── hijacker.py
│   ├── quantization_manager.py
│   ├── quantized_folded_bn.py
│   ├── quantizers
│   │   ├── __init__.py
│   │   ├── base_quantizers.py
│   │   ├── rounding_utils.py
│   │   ├── uniform_quantizers.py
│   │   └── utils.py
│   ├── range_estimators.py
│   └── utils.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── click_options.py
    ├── imagenet_dataloaders.py
    ├── optimizer_utils.py
    ├── oscillation_tracking_utils.py
    ├── qat_utils.py
    ├── stopwatch.py
    ├── supervised_driver.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022 Qualcomm Technologies, Inc.
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer:
8 |
9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 |
11 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
12 |
13 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overcoming Oscillations in Quantization-Aware Training
2 | This repository contains the implementation and experiments for the paper presented in
3 |
4 | **Markus Nagel\*<sup>1</sup>, Marios Fournarakis\*<sup>1</sup>, Yelysei Bondarenko<sup>1</sup>,
5 | Tijmen Blankevoort<sup>1</sup> "Overcoming Oscillations in Quantization-Aware Training", ICML
6 | 2022.** [[ArXiv]](https://arxiv.org/abs/2203.11086)
7 |
8 | \*Equal contribution
9 | <sup>1</sup> Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.)
10 |
11 | You can use this code to recreate the results in the paper.
12 | ## Reference
13 | If you find our work useful, please cite
14 | ```
15 | @InProceedings{pmlr-v162-nagel22a,
16 | title = {Overcoming Oscillations in Quantization-Aware Training},
17 | author = {Nagel, Markus and Fournarakis, Marios and Bondarenko, Yelysei and Blankevoort, Tijmen},
18 | booktitle = {Proceedings of the 39th International Conference on Machine Learning},
19 | pages = {16318--16330},
20 | year = {2022},
21 | editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
22 | volume = {162},
23 | series = {Proceedings of Machine Learning Research},
24 | month = {17--23 Jul},
25 | publisher = {PMLR},
26 | pdf = {https://proceedings.mlr.press/v162/nagel22a/nagel22a.pdf},
27 | url = {https://proceedings.mlr.press/v162/nagel22a.html}
28 | }
29 | ```
30 |
31 | ## Method and Results
32 |
33 | When training neural networks with simulated quantization, we observe that quantized weights can,
34 | rather unexpectedly, oscillate between two grid points. This is an inherent problem caused
35 | by the straight-through estimator (STE). In our paper, we delve deeper into this little-understood
36 | phenomenon and show that oscillations harm accuracy by corrupting the EMA statistics of the
37 | batch-normalization layers and by preventing convergence to local minima.
38 |
39 |
40 |
41 |
42 |
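To build intuition, here is a minimal, self-contained sketch in the spirit of the toy regression example (`display/toy_regression.gif`); it is not part of this repository, and the step size, learning rate, and target value are arbitrary choices. A single latent weight is trained with the STE, and once it approaches a task optimum lying between two grid points, its quantized value starts oscillating between the neighbouring grid values:

```python
import torch

step_size = 0.1                              # quantization step size (arbitrary for this sketch)
target = 0.33                                # optimum lies between grid points 0.3 and 0.4
w = torch.tensor(0.6, requires_grad=True)    # latent (shadow) weight
opt = torch.optim.SGD([w], lr=0.05)

for it in range(40):
    w_q = (w / step_size).round() * step_size   # round-to-nearest quantization
    w_q = w + (w_q - w).detach()                # straight-through estimator: identity gradient
    loss = (w_q - target) ** 2                  # toy regression loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(f"iter {it:2d}  latent w = {w.item():+.3f}  quantized w = {w_q.item():+.2f}")
```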
43 | We propose two novel methods to tackle oscillations at their source: **oscillations dampening**
44 | and **iterative weight freezing**. We demonstrate that our algorithms achieve state-of-the-art
45 | accuracy for low-bit (3 & 4 bits) weight and activation quantization of efficient architectures,
46 | such as MobileNetV2, MobileNetV3, and EfficientNet-Lite on ImageNet.
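Conceptually, oscillation dampening adds a regularization term that pulls latent weights toward their quantized values, which act as attractors and receive no gradient. The snippet below is only a rough sketch under that description; the repository's `DampeningLoss` in `utils/qat_utils.py` is the actual implementation and may aggregate the term differently:

```python
import torch


def dampening_term(weight: torch.Tensor, step_size: float) -> torch.Tensor:
    """Penalize the distance of latent weights from their (detached) quantized values."""
    w_q = ((weight / step_size).round() * step_size).detach()  # treated as a constant attractor
    return ((w_q - weight) ** 2).sum()


# Total loss = task loss + lambda * dampening term, where lambda is controlled by the
# --oscillations-dampen-weight / --oscillations-dampen-weight-final options below.
w = torch.randn(8, requires_grad=True)
dampening_term(w, step_size=0.1).backward()
print(w.grad)
```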
47 |
48 |
49 | ## How to install
50 | Make sure to have Python ≥3.6 (tested with Python 3.6.8) and
51 | the latest version of `pip` (tested with 21.3.1):
52 | ```bash
53 | source env/bin/activate
54 | pip install --upgrade --no-deps pip
55 | ```
56 |
57 | Next, install PyTorch 1.9.1 with the appropriate CUDA version (tested with CUDA 10.0, CuDNN 7.6.3):
58 | ```bash
59 | pip install torch==1.9.1 torchvision==0.10.1
60 | ```
61 |
62 | Finally, install the remaining dependencies using pip:
63 | ```bash
64 | pip install -r requirements.txt
65 | ```
66 |
67 | ## Running experiments
68 | The main run file to reproduce all experiments is `main.py`.
69 | It contains commands for quantization-aware training (QAT) and validating quantized models.
70 | You can see the full list of options for each command using `python main.py [COMMAND] --help`.
71 | ```bash
72 | Usage: main.py [OPTIONS] COMMAND [ARGS]...
73 |
74 | Options:
75 | --help Show this message and exit.
76 |
77 | Commands:
78 | train-quantized
79 | ```
80 |
81 | ## Quantization-Aware Training (QAT)
82 | All models are fine-tuned starting from pre-trained FP32 weights. Pretrained weights can be found here:
83 |
84 | - [MobileNetV2](https://drive.google.com/open?id=1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR)
85 | - EfficientNet-Lite: pretrained weights from [repository](https://github.com/rwightman/pytorch-image-models/) (downloaded at runtime)
86 |
87 | ### MobileNetV2
88 |
89 | To train with **oscillations dampening** run:
90 | ```bash
91 | python main.py train-quantized --architecture mobilenet_v2_quantized
92 | --images-dir path/to/raw_imagenet --act-quant-method MSE --weight-quant-method MSE
93 | --optimizer SGD --weight-decay 2.5e-05 --sep-quant-optimizer
94 | --quant-optimizer Adam --quant-learning-rate 1e-5 --quant-weight-decay 0.0
95 | --model-dir /path/to/mobilenet_v2.pth.tar --learning-rate-schedule cosine:0
96 | # Dampening loss configurations
97 | --oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1
98 | # 4-bit best learning rate
99 | --n-bits 4 --learning-rate 0.0033
100 | # 3-bits best learning rate
101 | --n-bits 3 --learning-rate 0.01
102 | ```
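In the command above, `--oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1` anneals the dampening weight from 0 to 0.1 over training with a cosine schedule (`CosineTempDecay`, used in `main.py`). The snippet below is only an illustrative sketch of such a schedule, not the repository's implementation:

```python
import math


def cosine_anneal(t: int, t_max: int, start: float = 0.0, end: float = 0.1) -> float:
    """Cosine interpolation from `start` at t=0 to `end` at t=t_max."""
    return end + 0.5 * (start - end) * (1 + math.cos(math.pi * t / t_max))


total_iters = 1000
print([round(cosine_anneal(t, total_iters), 4) for t in (0, 250, 500, 750, 1000)])
```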
103 |
104 | To train with **iterative weight freezing** run:
105 | ```bash
106 | python main.py train-quantized --architecture mobilenet_v2_quantized
107 | --images-dir path/to/raw_imagenet --act-quant-method MSE --weight-quant-method MSE
108 | --optimizer SGD --sep-quant-optimizer
109 | --quant-optimizer Adam --quant-learning-rate 1e-5 --quant-weight-decay 0.0
110 | --model-dir /path/to/mobilenet_v2.pth.tar --learning-rate-schedule cosine:0
111 | # Iterative weight freezing configuration
112 | --oscillations-freeze-threshold 0.1
113 | # 4-bit best configuration
114 | --n-bits 4 --learning-rate 0.0033 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.01
115 | # 3-bit best configuration
116 | --n-bits 3 --learning-rate 0.01 --weight-decay 2.5e-05 --oscillations-freeze-threshold-final 0.011
117 | ```
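Here, `--oscillations-freeze-threshold` sets the oscillation frequency above which a weight is frozen, and `--oscillations-freeze-threshold-final` anneals that threshold over training. The sketch below only illustrates the idea of tracking per-weight oscillations with an EMA and freezing frequent oscillators; the trackers in `utils/oscillation_tracking_utils.py` are the actual implementation, and the class below is hypothetical:

```python
import torch


class ToyOscillationTracker:
    """Illustrative only: track an EMA of integer-weight flips and freeze frequent oscillators."""

    def __init__(self, shape, momentum=0.001, freeze_threshold=0.1):
        self.prev_int = torch.zeros(shape)
        self.freq = torch.zeros(shape)                    # EMA of grid-point changes per weight
        self.frozen = torch.zeros(shape, dtype=torch.bool)
        self.momentum = momentum
        self.freeze_threshold = freeze_threshold

    def update(self, w_int: torch.Tensor) -> torch.Tensor:
        flipped = (w_int != self.prev_int).float()        # proxy for an oscillation event
        self.freq = (1 - self.momentum) * self.freq + self.momentum * flipped
        self.frozen |= self.freq > self.freeze_threshold
        self.prev_int = w_int.clone()
        return self.frozen                                # mask of weights to stop updating

# At each optimizer step: mask = tracker.update(current integer weights),
# then zero the updates/gradients of the frozen weights.
```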
118 |
119 | For convenience, bash scripts for reproducing our experiments are provided under `/bash/`.
120 | ```bash
121 | ./bash/train_mobilenetv2.sh --IMAGES_DIR path_to_raw_imagenet --MODEL_DIR path_to_pretrained_weights # QAT training of MobileNetV2 with defaults (method 'freeze' and 3 bits)
122 | ./bash/train_efficientnet.sh --IMAGES_DIR path_to_raw_imagenet --METHOD damp --N_BITS 4
123 | ```
124 |
--------------------------------------------------------------------------------
/bash/set_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | # Setting up the environment
6 | source env/bin/activate
7 | export LC_ALL=C.UTF-8
8 | export LANG=C.UTF-8
9 | export PYTHONPATH=${PYTHONPATH}:$(realpath "$PWD")
10 |
--------------------------------------------------------------------------------
/bash/train_efficientnet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | #########################################################################################################
6 |
7 | # Bash script for running QAT EfficientNet-Lite training configuration.
8 | # IMAGES_DIR: path to local raw imagenet dataset and
9 | # MODEL_DIR: path to model's pretrained weights
10 | # must both be specified.
11 | #
12 | # Example of using this script:
13 | # $ ./bash/train_efficientnet.sh --IMAGES_DIR path_to_imagenet_raw --MODEL_DIR path_to_weights
14 | #
15 | # For getting usage info:
16 | # $ ./bash/train_efficientnet.sh
17 | #
18 | # Other configurable params:
19 | # N_BITS: 3(default), 4 are currently supported
20 | # METHOD: freeze (default; iterative weight freezing), damp (oscillations dampening)
21 | #
22 | # The script may be extended with further input parameters (please refer to "/utils/click_options.py")
23 |
24 | #########################################################################################################
25 |
26 | source bash/set_env.sh
27 |
28 | MODEL='efficientnet'
29 | N_BITS=3
30 | METHOD='freeze'
31 |
32 | for ARG in "$@"
33 | do
34 | key=$(echo $ARG | cut -f1 -d=)
35 | value=$(echo $ARG | cut -f2 -d=)
36 |
37 | if [[ $key == *"--"* ]]; then
38 | v="${key/--/}"
39 | declare $v="${value}"
40 | fi
41 | done
42 |
43 | if [[ -z $IMAGES_DIR ]] || [[ -z $MODEL_DIR ]]; then
44 | echo "Usage: $(basename "$0")
45 | --IMAGES_DIR=[path to imagenet_raw]
46 | --MODEL_DIR=[path to model's pretrained weights]
47 | --N_BITS=[3(default), 4]
48 | --METHOD=[freeze(default), damp]"
49 | exit 1
50 | fi
51 |
52 | if [ $N_BITS -ne 3 ] && [ $N_BITS -ne 4 ]; then
53 | echo "Only 3,4 bits configuration currently supported"
54 | exit 1
55 | fi
56 |
57 | if [ "$METHOD" != 'freeze' ] && [ "$METHOD" != 'damp' ]; then
58 | echo "Only methods 'damp' and 'freeze' are currently supported."
59 | exit 1
60 | fi
61 |
62 | CMD_ARGS='--architecture efficientnet_lite0_quantized
63 | --act-quant-method MSE
64 | --weight-quant-method MSE
65 | --optimizer SGD
66 | --max-epochs 50
67 | --learning-rate-schedule cosine:0
68 | --sep-quant-optimizer
69 | --quant-optimizer Adam
70 | --quant-learning-rate 1e-5
71 | --quant-weight-decay 0.0'
72 |
73 | # QAT methods
74 | if [ $METHOD == 'freeze' ]; then
75 | CMD_QAT='--oscillations-freeze-threshold 0.1'
76 | if [ $N_BITS == 3 ]; then
77 | CMD_BITS='--n-bits 3 --learning-rate 0.01 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.005'
78 | else
79 | CMD_BITS='--n-bits 4 --learning-rate 0.0033 --weight-decay 1e-04 --oscillations-freeze-threshold-final 0.015'
80 | fi
81 | else
82 | CMD_QAT='--oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1'
83 | if [ $N_BITS == 3 ]; then
84 | CMD_BITS='--n-bits 3 --learning-rate 0.01 --weight-decay 5e-5'
85 | else
86 | CMD_BITS='--n-bits 4 --learning-rate 0.0033 --weight-decay 1e-4'
87 | fi
88 | fi
89 |
90 | CMD_ARGS="$CMD_ARGS $CMD_QAT $CMD_BITS"
91 |
92 | python main.py train-quantized \
93 | --images-dir $IMAGES_DIR \
94 | $CMD_ARGS
95 |
--------------------------------------------------------------------------------
/bash/train_mobilenetv2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | #########################################################################################################
6 |
7 | # Bash script for running QAT MobileNetV2 training configuration.
8 | # IMAGES_DIR: path to local raw imagenet dataset and
9 | # MODEL_DIR: path to model's pretrained weights
10 | # must both be specified.
11 | #
12 | # Example of using this script:
13 | # $ ./bash/train_mobilenetv2.sh --IMAGES_DIR path_to_imagenet_raw --MODEL_DIR path_to_weights
14 | #
15 | # For getting usage info:
16 | # $ ./bash/train_mobilenetv2.sh
17 | #
18 | # Other configurable params:
19 | # N_BITS: 3(default), 4 are currently supported
20 | # METHOD: freeze (default; iterative weight freezing), damp (oscillations dampening)
21 | #
22 | # The script may be extended with further input parameters (please refer to "/utils/click_options.py")
23 |
24 | #########################################################################################################
25 |
26 | source bash/set_env.sh
27 |
28 | MODEL='mobilenetV2'
29 | N_BITS=3
30 | METHOD='freeze'
31 |
32 | for ARG in "$@"
33 | do
34 | key=$(echo $ARG | cut -f1 -d=)
35 | value=$(echo $ARG | cut -f2 -d=)
36 |
37 | if [[ $key == *"--"* ]]; then
38 | v="${key/--/}"
39 | declare $v="${value}"
40 | fi
41 | done
42 |
43 | if [[ -z $IMAGES_DIR ]] || [[ -z $MODEL_DIR ]]; then
44 | echo "Usage: $(basename "$0")
45 | --IMAGES_DIR=[path to imagenet_raw]
46 | --MODEL_DIR=[path to model's pretrained weights]
47 | --N_BITS=[3(default), 4]
48 | --METHOD=[freeze(default), damp]"
49 | exit 1
50 | fi
51 |
52 | if [ $N_BITS -ne 3 ] && [ $N_BITS -ne 4 ]; then
53 | echo 'Only 3,4 bits configuration currently supported'
54 | exit 1
55 | fi
56 |
57 | if [ "$METHOD" != 'freeze' ] && [ "$METHOD" != 'damp' ]; then
58 | echo "Only methods 'damp' and 'freeze' are currently supported."
59 | exit 1
60 | fi
61 |
62 | CMD_ARGS='--architecture mobilenet_v2_quantized
63 | --act-quant-method MSE
64 | --weight-quant-method MSE
65 | --optimizer SGD
66 | --weight-decay 2.5e-05
67 | --sep-quant-optimizer
68 | --quant-optimizer Adam
69 | --quant-learning-rate 1e-5
70 | --quant-weight-decay 0.0
71 | --learning-rate-schedule cosine:0'
72 |
73 | # QAT methods
74 | if [ $METHOD == "freeze" ]; then
75 | CMD_QAT='--oscillations-freeze-threshold 0.1'
76 | if [ $N_BITS == 3 ]; then
77 | CMD_BITS='--n-bits 3 --learning-rate 0.01 --weight-decay 2.5e-05 --oscillations-freeze-threshold-final 0.011'
78 | else
79 | CMD_BITS='--n-bits 4 --learning-rate 0.0033 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.01'
80 | fi
81 | else
82 | CMD_QAT='--oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1 --weight-decay 2.5e-05'
83 | if [ $N_BITS == 3 ]; then
84 | CMD_BITS='--n-bits 3 --learning-rate 0.01'
85 | else
86 | CMD_BITS='--n-bits 4 --learning-rate 0.0033'
87 | fi
88 | fi
89 |
90 | CMD_ARGS="$CMD_ARGS $CMD_QAT $CMD_BITS"
91 |
92 | python main.py train-quantized \
93 | --images-dir $IMAGES_DIR \
94 | --model-dir $MODEL_DIR \
95 | $CMD_ARGS
--------------------------------------------------------------------------------
/display/toy_regression.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qualcomm-AI-research/oscillations-qat/9064d8540c1705242f08b864f06661247012ee4d/display/toy_regression.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import logging
4 | import os
5 |
6 | import click
7 | from ignite.contrib.handlers import ProgressBar
8 | from ignite.engine import Events, create_supervised_evaluator
9 | from ignite.metrics import Accuracy, TopKCategoricalAccuracy, Loss
10 | from torch.nn import CrossEntropyLoss
11 |
12 | from quantization.utils import (
13 | pass_data_for_range_estimation,
14 | separate_quantized_model_params,
15 | set_range_estimators,
16 | )
17 | from utils import DotDict, CosineTempDecay
18 | from utils.click_options import (
19 | qat_options,
20 | quantization_options,
21 | quant_params_dict,
22 | base_options,
23 | multi_optimizer_options,
24 | )
25 | from utils.optimizer_utils import optimizer_lr_factory
26 | from utils.oscillation_tracking_utils import add_oscillation_trackers
27 | from utils.qat_utils import (
28 | get_dataloaders_and_model,
29 | MethodPropagator,
30 | DampeningLoss,
31 | CompositeLoss,
32 | UpdateDampeningLossWeighting,
33 | UpdateFreezingThreshold,
34 | ReestimateBNStats,
35 | )
36 | from utils.supervised_driver import create_trainer_engine, setup_tensorboard_logger, log_metrics
37 |
38 |
39 | # setup stuff
40 | class Config(DotDict):
41 | pass
42 |
43 |
44 | @click.group()
45 | def oscillations():
46 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
47 |
48 |
49 | pass_config = click.make_pass_decorator(Config, ensure=True)
50 |
51 |
52 | @oscillations.command()
53 | @pass_config
54 | @base_options
55 | @multi_optimizer_options()
56 | @quantization_options
57 | @qat_options
58 | def train_quantized(config):
59 | """
60 | Main QAT function
61 | """
62 |
63 | print("Setting up network and data loaders")
64 | qparams = quant_params_dict(config)
65 |
66 | dataloaders, model = get_dataloaders_and_model(config, **qparams)
67 |
68 | # Estimate ranges using training data
69 | pass_data_for_range_estimation(
70 | loader=dataloaders.train_loader,
71 | model=model,
72 | act_quant=config.quant.act_quant,
73 | weight_quant=config.quant.weight_quant,
74 | max_num_batches=config.quant.num_est_batches,
75 | )
76 |
77 | # Put quantizers in desirable state
78 | set_range_estimators(config, model)
79 |
80 | print("Loaded model:\n{}".format(model))
81 |
82 | # Get all models parameters in subcategories
83 | quantizer_params, model_params, grad_params = separate_quantized_model_params(model)
84 | model_optimizer, quant_optimizer = None, None
85 | if config.qat.sep_quant_optimizer:
86 | # Separate optimizer for model and quantization parameters
87 | model_optimizer, model_lr_scheduler = optimizer_lr_factory(
88 | config.optimizer, model_params, config.base.max_epochs
89 | )
90 | quant_optimizer, quant_lr_scheduler = optimizer_lr_factory(
91 | config.quant_optimizer, quantizer_params, config.base.max_epochs
92 | )
93 |
94 | optimizer = MethodPropagator([model_optimizer, quant_optimizer])
95 | lr_schedulers = [s for s in [model_lr_scheduler, quant_lr_scheduler] if s is not None]
96 | lr_scheduler = MethodPropagator(lr_schedulers) if len(lr_schedulers) else None
97 | else:
98 | optimizer, lr_scheduler = optimizer_lr_factory(
99 | config.optimizer, quantizer_params + model_params, config.base.max_epochs
100 | )
101 |
102 | print("Optimizer:\n{}".format(optimizer))
103 | print(f"LR scheduler\n{lr_scheduler}")
104 |
105 | # Define metrics for ignite engine
106 | metrics = {"top_1_accuracy": Accuracy(), "top_5_accuracy": TopKCategoricalAccuracy()}
107 |
108 | # Set-up losses
109 | task_loss_fn = CrossEntropyLoss()
110 | dampening_loss = None
111 | if config.osc_damp.weight is not None:
112 | # Add dampening loss to task loss
113 | dampening_loss = DampeningLoss(model, config.osc_damp.weight, config.osc_damp.aggregation)
114 | loss_dict = {"task_loss": task_loss_fn, "dampening_loss": dampening_loss}
115 | loss_func = CompositeLoss(loss_dict)
116 | loss_metrics = {
117 | "task_loss": Loss(task_loss_fn),
118 | "dampening_loss": Loss(dampening_loss),
119 | "loss": Loss(loss_func),
120 | }
121 | else:
122 | loss_func = task_loss_fn
123 | loss_metrics = {"loss": Loss(loss_func)}
124 |
125 | metrics.update(loss_metrics)
126 |
127 | # Set up ignite trainer and evaluator
128 | trainer, evaluator = create_trainer_engine(
129 | model=model,
130 | optimizer=optimizer,
131 | criterion=loss_func,
132 | data_loaders=dataloaders,
133 | metrics=metrics,
134 | lr_scheduler=lr_scheduler,
135 | save_checkpoint_dir=config.base.save_checkpoint_dir,
136 | device="cuda" if config.base.cuda else "cpu",
137 | )
138 |
139 | if config.base.progress_bar:
140 | pbar = ProgressBar()
141 | pbar.attach(trainer)
142 | pbar.attach(evaluator)
143 |
144 | # Create TensorboardLogger
145 | if config.base.tb_logging_dir:
146 | if config.qat.sep_quant_optimizer:
147 | optimizers_dict = {"model": model_optimizer, "quant_params": quant_optimizer}
148 | else:
149 | optimizers_dict = optimizer
150 | tb_logger = setup_tensorboard_logger(
151 | trainer, evaluator, config.base.tb_logging_dir, optimizers_dict
152 | )
153 |
154 | if config.osc_damp.weight_final:
155 | # Apply cosine annealing of dampening loss
156 | total_iterations = len(dataloaders.train_loader) * config.base.max_epochs
157 | annealing_schedule = CosineTempDecay(
158 | t_max=total_iterations,
159 | temp_range=(config.osc_damp.weight, config.osc_damp.weight_final),
160 | rel_decay_start=config.osc_damp.anneal_start,
161 | )
162 | print(f"Weight gradient parameter cosine annealing schedule:\n{annealing_schedule}")
163 | trainer.add_event_handler(
164 | Events.ITERATION_STARTED,
165 | UpdateDampeningLossWeighting(dampening_loss, annealing_schedule),
166 | )
167 |
168 | # Evaluate model
169 | print("Running evaluation before training")
170 | evaluator.run(dataloaders.val_loader)
171 | log_metrics(evaluator.state.metrics, "Evaluation", trainer.state.epoch)
172 |
173 | # BN Re-estimation
174 | if config.qat.reestimate_bn_stats:
175 | evaluator.add_event_handler(
176 | Events.EPOCH_STARTED, ReestimateBNStats(model, dataloaders.train_loader)
177 | )
178 |
179 | # Add oscillation trackers to the model and set up oscillation freezing
180 | if config.osc_freeze.threshold:
181 | oscillation_tracker_dict = add_oscillation_trackers(
182 | model,
183 | max_bits=config.osc_freeze.max_bits,
184 | momentum=config.osc_freeze.ema_momentum,
185 | freeze_threshold=config.osc_freeze.threshold,
186 | use_ema_x_int=config.osc_freeze.use_ema,
187 | )
188 |
189 | if config.osc_freeze.threshold_final:
190 | # Apply cosine annealing schedule to the freezing threshold
191 | total_iterations = len(dataloaders.train_loader) * config.base.max_epochs
192 | annealing_schedule = CosineTempDecay(
193 | t_max=total_iterations,
194 | temp_range=(config.osc_freeze.threshold, config.osc_freeze.threshold_final),
195 | rel_decay_start=config.osc_freeze.anneal_start,
196 | )
197 | print(f"Oscillation freezing annealing schedule:\n{annealing_schedule}")
198 | trainer.add_event_handler(
199 | Events.ITERATION_STARTED,
200 | UpdateFreezingThreshold(oscillation_tracker_dict, annealing_schedule),
201 | )
202 |
203 | print("Starting training")
204 |
205 | trainer.run(dataloaders.train_loader, max_epochs=config.base.max_epochs)
206 |
207 | print("Finished training")
208 |
209 |
210 | @oscillations.command()
211 | @pass_config
212 | @base_options
213 | @quantization_options
214 | @click.option(
215 | "--load-type",
216 | type=click.Choice(["fp32", "quantized"]),
217 | default="quantized",
218 | help='Either "fp32" or "quantized". Specify whether to load a quantized or an FP32 model.',
219 | )
220 | def validate_quantized(config, load_type):
221 | """
222 | function for running validation on pre-trained quantized models
223 | """
224 | print("Setting up network and data loaders")
225 | qparams = quant_params_dict(config)
226 |
227 | dataloaders, model = get_dataloaders_and_model(config=config, load_type=load_type, **qparams)
228 |
229 | if load_type == "fp32":
230 | # Estimate ranges using training data
231 | pass_data_for_range_estimation(
232 | loader=dataloaders.train_loader,
233 | model=model,
234 | act_quant=config.quant.act_quant,
235 | weight_quant=config.quant.weight_quant,
236 | max_num_batches=config.quant.num_est_batches,
237 | )
238 | # Ensure we have the desired quant state
239 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant)
240 |
241 | # Fix ranges
242 | model.fix_ranges()
243 | print("Loaded model:\n{}".format(model))
244 |
245 | # Create evaluator
246 | loss_func = CrossEntropyLoss()
247 | metrics = {
248 | "top_1_accuracy": Accuracy(),
249 | "top_5_accuracy": TopKCategoricalAccuracy(),
250 | "loss": Loss(loss_func),
251 | }
252 |
253 | pbar = ProgressBar()
254 | evaluator = create_supervised_evaluator(
255 | model=model, metrics=metrics, device="cuda" if config.base.cuda else "cpu"
256 | )
257 | pbar.attach(evaluator)
258 | print("Start quantized validation")
259 | evaluator.run(dataloaders.val_loader)
260 | final_metrics = evaluator.state.metrics
261 | print(final_metrics)
262 |
263 |
264 | if __name__ == "__main__":
265 | oscillations()
266 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from models.efficientnet_lite_quantized import efficientnet_lite0_quantized
6 | from models.mobilenet_v2_quantized import mobilenetv2_quantized
7 | from models.resnet_quantized import resnet18_quantized, resnet50_quantized
8 | from utils import ClassEnumOptions, MethodMap
9 |
10 |
11 | class QuantArchitectures(ClassEnumOptions):
12 | mobilenet_v2_quantized = MethodMap(mobilenetv2_quantized)
13 | resnet18_quantized = MethodMap(resnet18_quantized)
14 | resnet50_quantized = MethodMap(resnet50_quantized)
15 | efficientnet_lite0_quantized = MethodMap(efficientnet_lite0_quantized)
16 |
17 |
--------------------------------------------------------------------------------
/models/efficientnet_lite_quantized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 | from timm.models import create_model
6 | from timm.models.efficientnet_blocks import DepthwiseSeparableConv, InvertedResidual
7 | from torch import nn
8 |
9 | from quantization.autoquant_utils import quantize_sequential, quantize_model
10 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
11 | from quantization.base_quantized_model import QuantizedModel
12 |
13 |
14 | class QuantizedInvertedResidual(QuantizedActivation):
15 | def __init__(self, inv_res_orig, **quant_params):
16 | super().__init__(**quant_params)
17 |
18 | assert inv_res_orig.drop_path_rate == 0.0
19 | assert isinstance(inv_res_orig.se, nn.Identity)
20 |
21 | self.has_residual = inv_res_orig.has_residual
22 |
23 | conv_pw = nn.Sequential(inv_res_orig.conv_pw, inv_res_orig.bn1, inv_res_orig.act1)
24 | self.conv_pw = quantize_sequential(conv_pw, **quant_params)[0]
25 |
26 | conv_dw = nn.Sequential(inv_res_orig.conv_dw, inv_res_orig.bn2, inv_res_orig.act2)
27 | self.conv_dw = quantize_sequential(conv_dw, **quant_params) # [0]
28 |
29 | conv_pwl = nn.Sequential(inv_res_orig.conv_pwl, inv_res_orig.bn3)
30 | self.conv_pwl = quantize_sequential(conv_pwl, **quant_params)[0]
31 |
32 | def forward(self, x):
33 | residual = x
34 | # Point-wise expansion
35 | x = self.conv_pw(x)
36 | # Depth-wise convolution
37 | x = self.conv_dw(x)
38 | # Point-wise linear projection
39 | x = self.conv_pwl(x)
40 |
41 | if self.has_residual:
42 | x += residual
43 | x = self.quantize_activations(x)
44 | return x
45 |
46 |
47 | class QuantizedDepthwiseSeparableConv(QuantizedActivation):
48 | def __init__(self, dws_orig, **quant_params):
49 | super().__init__(**quant_params)
50 |
51 | assert dws_orig.drop_path_rate == 0.0
52 | assert isinstance(dws_orig.se, nn.Identity)
53 |
54 | self.has_residual = dws_orig.has_residual
55 |
56 | conv_dw = nn.Sequential(dws_orig.conv_dw, dws_orig.bn1, dws_orig.act1)
57 | self.conv_dw = quantize_sequential(conv_dw, **quant_params)[0]
58 |
59 | conv_pw = nn.Sequential(dws_orig.conv_pw, dws_orig.bn2, dws_orig.act2)
60 | self.conv_pw = quantize_sequential(conv_pw, **quant_params)[0]
61 |
62 | def forward(self, x):
63 | residual = x
64 | # Depth-wise convolution
65 | x = self.conv_dw(x)
66 | # Point-wise projection
67 | x = self.conv_pw(x)
68 | if self.has_residual:
69 | x += residual
70 | x = self.quantize_activations(x)
71 | return x
72 |
73 |
74 | class QuantizedEfficientNetLite(QuantizedModel):
75 | def __init__(self, base_model, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params):
76 | super().__init__(input_size)
77 |
78 | specials = {
79 | InvertedResidual: QuantizedInvertedResidual,
80 | DepthwiseSeparableConv: QuantizedDepthwiseSeparableConv,
81 | }
82 |
83 | conv_stem = nn.Sequential(base_model.conv_stem, base_model.bn1, base_model.act1)
84 | self.conv_stem = quantize_model(conv_stem, specials=specials, **quant_params)[0]
85 |
86 | self.blocks = quantize_model(base_model.blocks, specials=specials, **quant_params)
87 |
88 | conv_head = nn.Sequential(base_model.conv_head, base_model.bn2, base_model.act2)
89 | self.conv_head = quantize_model(conv_head, specials=specials, **quant_params)[0]
90 |
91 | self.global_pool = base_model.global_pool
92 |
93 | base_model.classifier.__class__ = nn.Linear # Small hack to work with autoquant
94 | self.classifier = quantize_model(base_model.classifier, **quant_params)
95 |
96 | if quant_setup == "FP_logits":
97 | print("Do not quantize output of FC layer")
98 | self.classifier.activation_quantizer = FP32Acts() # no activation quantization of
99 | # logits
100 | elif quant_setup == "LSQ":
101 | print("Set quantization to LSQ (first+last layer in 8 bits)")
102 | # Weights of the first layer
103 | self.conv_stem.weight_quantizer.quantizer.n_bits = 8
104 | # The activation quantizer of the last conv layer (input to global pooling)
105 | self.conv_head.activation_quantizer.quantizer.n_bits = 8
106 | # Weights of the last layer
107 | self.classifier.weight_quantizer.quantizer.n_bits = 8
108 | # no activation quantization of logits
109 | self.classifier.activation_quantizer = FP32Acts()
110 | elif quant_setup == "LSQ_paper":
111 | # Weights of the first layer
112 | self.conv_stem.activation_quantizer = FP32Acts()
113 | self.conv_stem.weight_quantizer.quantizer.n_bits = 8
114 | # Weights of the last layer
115 | self.classifier.activation_quantizer.quantizer.n_bits = 8
116 | self.classifier.weight_quantizer.quantizer.n_bits = 8
117 | # Set all QuantizedActivations to FP32
118 | for layer in self.blocks.modules():
119 | if isinstance(layer, QuantizedActivation):
120 | layer.activation_quantizer = FP32Acts()
121 | elif quant_setup is not None and quant_setup != "all":
122 | raise ValueError(
123 | "Quantization setup '{}' not supported for EfficientNet lite".format(quant_setup)
124 | )
125 |
126 | def forward(self, x):
127 | # features
128 | x = self.conv_stem(x)
129 | x = self.blocks(x)
130 | x = self.conv_head(x)
131 |
132 | x = self.global_pool(x)
133 | x = x.flatten(1)
134 | return self.classifier(x)
135 |
136 |
137 | def efficientnet_lite0_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
138 | if load_type == "fp32":
139 | # Load model from pretrained FP32 weights
140 | fp_model = create_model("efficientnet_lite0", pretrained=pretrained)
141 | quant_model = QuantizedEfficientNetLite(fp_model, **qparams)
142 | elif load_type == "quantized":
143 | # Load pretrained QuantizedModel
144 | print(f"Loading pretrained quantized model from {model_dir}")
145 | state_dict = torch.load(model_dir)
146 | fp_model = create_model("efficientnet_lite0")
147 | quant_model = QuantizedEfficientNetLite(fp_model, **qparams)
148 | quant_model.load_state_dict(state_dict)
149 | else:
150 | raise ValueError("wrong load_type specified")
151 | return quant_model
152 |
--------------------------------------------------------------------------------
/models/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | # https://github.com/tonylins/pytorch-mobilenet-v2
6 |
7 | import math
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | __all__ = ["MobileNetV2"]
13 |
14 |
15 | def conv_bn(inp, oup, stride):
16 | return nn.Sequential(
17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)
18 | )
19 |
20 |
21 | def conv_1x1_bn(inp, oup):
22 | return nn.Sequential(
23 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)
24 | )
25 |
26 |
27 | class InvertedResidual(nn.Module):
28 | def __init__(self, inp, oup, stride, expand_ratio):
29 | super(InvertedResidual, self).__init__()
30 | self.stride = stride
31 | assert stride in [1, 2]
32 |
33 | hidden_dim = round(inp * expand_ratio)
34 | self.use_res_connect = self.stride == 1 and inp == oup
35 |
36 | if expand_ratio == 1:
37 | self.conv = nn.Sequential(
38 | # dw
39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
40 | nn.BatchNorm2d(hidden_dim),
41 | nn.ReLU6(inplace=True),
42 | # pw-linear
43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(oup),
45 | )
46 | else:
47 | self.conv = nn.Sequential(
48 | # pw
49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
50 | nn.BatchNorm2d(hidden_dim),
51 | nn.ReLU6(inplace=True),
52 | # dw
53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
54 | nn.BatchNorm2d(hidden_dim),
55 | nn.ReLU6(inplace=True),
56 | # pw-linear
57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
58 | nn.BatchNorm2d(oup),
59 | )
60 |
61 | def forward(self, x):
62 | if self.use_res_connect:
63 | return x + self.conv(x)
64 | else:
65 | return self.conv(x)
66 |
67 |
68 | class MobileNetV2(nn.Module):
69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.0, dropout=0.0):
70 | super().__init__()
71 | block = InvertedResidual
72 | input_channel = 32
73 | last_channel = 1280
74 | inverted_residual_setting = [
75 | # t, c, n, s
76 | [1, 16, 1, 1],
77 | [6, 24, 2, 2],
78 | [6, 32, 3, 2],
79 | [6, 64, 4, 2],
80 | [6, 96, 3, 1],
81 | [6, 160, 3, 2],
82 | [6, 320, 1, 1],
83 | ]
84 |
85 | # building first layer
86 | assert input_size % 32 == 0
87 | input_channel = int(input_channel * width_mult)
88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
89 | features = [conv_bn(3, input_channel, 2)]
90 | # building inverted residual blocks
91 | for t, c, n, s in inverted_residual_setting:
92 | output_channel = int(c * width_mult)
93 | for i in range(n):
94 | if i == 0:
95 | features.append(block(input_channel, output_channel, s, expand_ratio=t))
96 | else:
97 | features.append(block(input_channel, output_channel, 1, expand_ratio=t))
98 | input_channel = output_channel
99 | # building last several layers
100 | features.append(conv_1x1_bn(input_channel, self.last_channel))
101 | features.append(nn.AvgPool2d(input_size // 32))
102 | # make it nn.Sequential
103 | self.features = nn.Sequential(*features)
104 |
105 | # building classifier
106 | self.classifier = nn.Sequential(
107 | nn.Dropout(dropout),
108 | nn.Linear(self.last_channel, n_class),
109 | )
110 |
111 | self._initialize_weights()
112 |
113 | def forward(self, x):
114 | x = self.features(x)
115 | x = F.adaptive_avg_pool2d(x, 1).squeeze() # type: ignore[arg-type] # accepted slang
116 | x = self.classifier(x)
117 | return x
118 |
119 | def _initialize_weights(self):
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
124 | if m.bias is not None:
125 | m.bias.data.zero_()
126 | elif isinstance(m, nn.BatchNorm2d):
127 | m.weight.data.fill_(1)
128 | m.bias.data.zero_()
129 | elif isinstance(m, nn.Linear):
130 | n = m.weight.size(1)
131 | m.weight.data.normal_(0, 0.01)
132 | m.bias.data.zero_()
133 |
--------------------------------------------------------------------------------
/models/mobilenet_v2_quantized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import os
6 | import re
7 | import torch
8 | from collections import OrderedDict
9 | from models.mobilenet_v2 import MobileNetV2, InvertedResidual
10 | from quantization.autoquant_utils import quantize_sequential, Flattener, quantize_model, BNQConv
11 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
12 | from quantization.base_quantized_model import QuantizedModel
13 |
14 |
15 | class QuantizedInvertedResidual(QuantizedActivation):
16 | def __init__(self, inv_res_orig, **quant_params):
17 | super().__init__(**quant_params)
18 | self.use_res_connect = inv_res_orig.use_res_connect
19 | self.conv = quantize_sequential(inv_res_orig.conv, **quant_params)
20 |
21 | def forward(self, x):
22 | if self.use_res_connect:
23 | x = x + self.conv(x)
24 | return self.quantize_activations(x)
25 | else:
26 | return self.conv(x)
27 |
28 |
29 | class QuantizedMobileNetV2(QuantizedModel):
30 | def __init__(self, model_fp, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params):
31 | super().__init__(input_size)
32 | specials = {InvertedResidual: QuantizedInvertedResidual}
33 | # quantize and copy parts from original model
34 | quantize_input = quant_setup and quant_setup == "LSQ_paper"
35 | self.features = quantize_sequential(
36 | model_fp.features,
37 | tie_activation_quantizers=not quantize_input,
38 | specials=specials,
39 | **quant_params,
40 | )
41 |
42 | self.flattener = Flattener()
43 | self.classifier = quantize_model(model_fp.classifier, **quant_params)
44 |
45 | if quant_setup == "FP_logits":
46 | print("Do not quantize output of FC layer")
47 | self.classifier[1].activation_quantizer = FP32Acts()
48 | # self.classifier.activation_quantizer = FP32Acts() # no activation quantization of logits
49 | elif quant_setup == "fc4":
50 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
51 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4
52 | elif quant_setup == "fc4_dw8":
53 | print("\n\n### fc4_dw8 setup ###\n\n")
54 | # FC layer in 4 bits, depth-wise separable convolutions in 8 bits
55 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
56 | self.classifier[1].weight_quantizer.quantizer.n_bits = 4
57 | for name, module in self.named_modules():
58 | if isinstance(module, BNQConv) and module.groups == module.in_channels:
59 | module.weight_quantizer.quantizer.n_bits = 8
60 | print(f"Set layer {name} to 8 bits")
61 | elif quant_setup == "LSQ":
62 | print("Set quantization to LSQ (first+last layer in 8 bits)")
63 | # Weights of the first layer
64 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
65 | # The activation quantizer of the last conv layer (input to avgpool with tied quantizers)
66 | self.features[-2][0].activation_quantizer.quantizer.n_bits = 8
67 | # Weights of the last layer
68 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8
69 | # no activation quantization of logits
70 | self.classifier[1].activation_quantizer = FP32Acts()
71 | elif quant_setup == "LSQ_paper":
72 | # Weights of the first layer
73 | self.features[0][0].activation_quantizer = FP32Acts()
74 | self.features[0][0].weight_quantizer.quantizer.n_bits = 8
75 | # Weights of the last layer
76 | self.classifier[1].weight_quantizer.quantizer.n_bits = 8
77 | self.classifier[1].activation_quantizer.quantizer.n_bits = 8
78 | # Set all QuantizedActivations to FP32
79 | for layer in self.features.modules():
80 | if isinstance(layer, QuantizedActivation):
81 | layer.activation_quantizer = FP32Acts()
82 | elif quant_setup is not None and quant_setup != "all":
83 | raise ValueError(
84 | "Quantization setup '{}' not supported for MobilenetV2".format(quant_setup)
85 | )
86 |
87 | def forward(self, x):
88 | x = self.features(x)
89 | x = self.flattener(x)
90 | x = self.classifier(x)
91 |
92 | return x
93 |
94 |
95 | def mobilenetv2_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
96 | fp_model = MobileNetV2()
97 | if pretrained and load_type == "fp32":
98 | # Load model from pretrained FP32 weights
99 | assert os.path.exists(model_dir)
100 | print(f"Loading pretrained weights from {model_dir}")
101 | state_dict = torch.load(model_dir)
102 | fp_model.load_state_dict(state_dict)
103 | quant_model = QuantizedMobileNetV2(fp_model, **qparams)
104 | elif load_type == "quantized":
105 | # Load pretrained QuantizedModel
106 | print(f"Loading pretrained quantized model from {model_dir}")
107 | state_dict = torch.load(model_dir)
108 | quant_model = QuantizedMobileNetV2(fp_model, **qparams)
109 | quant_model.load_state_dict(state_dict, strict=False)
110 | else:
111 | raise ValueError("wrong load_type specified")
112 |
113 | return quant_model
114 |
--------------------------------------------------------------------------------
/models/resnet_quantized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 | from torch import nn
6 | from torchvision.models.resnet import BasicBlock, Bottleneck
7 | from torchvision.models import resnet18, resnet50
8 |
9 | from quantization.autoquant_utils import quantize_model, Flattener, QuantizedActivationWrapper
10 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
11 | from quantization.base_quantized_model import QuantizedModel
12 |
13 |
14 | class QuantizedBlock(QuantizedActivation):
15 | def __init__(self, block, **quant_params):
16 | super().__init__(**quant_params)
17 |
18 | if isinstance(block, Bottleneck):
19 | features = nn.Sequential(
20 | block.conv1,
21 | block.bn1,
22 | block.relu,
23 | block.conv2,
24 | block.bn2,
25 | block.relu,
26 | block.conv3,
27 | block.bn3,
28 | )
29 | elif isinstance(block, BasicBlock):
30 | features = nn.Sequential(block.conv1, block.bn1, block.relu, block.conv2, block.bn2)
31 |
32 | self.features = quantize_model(features, **quant_params)
33 | self.downsample = (
34 | quantize_model(block.downsample, **quant_params) if block.downsample else None
35 | )
36 |
37 | self.relu = block.relu
38 |
39 | def forward(self, x):
40 | residual = x if self.downsample is None else self.downsample(x)
41 | out = self.features(x)
42 |
43 | out += residual
44 | out = self.relu(out)
45 |
46 | return self.quantize_activations(out)
47 |
48 |
49 | class QuantizedResNet(QuantizedModel):
50 | def __init__(self, resnet, input_size=(1, 3, 224, 224), quant_setup=None, **quant_params):
51 | super().__init__(input_size)
52 | specials = {BasicBlock: QuantizedBlock, Bottleneck: QuantizedBlock}
53 |
54 | if hasattr(resnet, "maxpool"):
55 | # ImageNet ResNet case
56 | features = nn.Sequential(
57 | resnet.conv1,
58 | resnet.bn1,
59 | resnet.relu,
60 | resnet.maxpool,
61 | resnet.layer1,
62 | resnet.layer2,
63 | resnet.layer3,
64 | resnet.layer4,
65 | )
66 | else:
67 | # Tiny ImageNet ResNet case
68 | features = nn.Sequential(
69 | resnet.conv1,
70 | resnet.bn1,
71 | resnet.relu,
72 | resnet.layer1,
73 | resnet.layer2,
74 | resnet.layer3,
75 | resnet.layer4,
76 | )
77 |
78 | self.features = quantize_model(features, specials=specials, **quant_params)
79 |
80 | if quant_setup and quant_setup == "LSQ_paper":
81 | # Keep avgpool intact as we quantize the input to the last layer
82 | self.avgpool = resnet.avgpool
83 | else:
84 | self.avgpool = QuantizedActivationWrapper(
85 | resnet.avgpool,
86 | tie_activation_quantizers=True,
87 | input_quantizer=self.features[-1][-1].activation_quantizer,
88 | **quant_params,
89 | )
90 | self.flattener = Flattener()
91 | self.fc = quantize_model(resnet.fc, **quant_params)
92 |
93 | # Adapt to specific quantization setup
94 | if quant_setup == "LSQ":
95 | print("Set quantization to LSQ (first+last layer in 8 bits)")
96 | # Weights of the first layer
97 | self.features[0].weight_quantizer.quantizer.n_bits = 8
98 | # The quantizer of the residual (input to last layer)
99 | self.features[-1][-1].activation_quantizer.quantizer.n_bits = 8
100 | # Output of the last conv (input to last layer)
101 | self.features[-1][-1].features[-1].activation_quantizer.quantizer.n_bits = 8
102 | # Weights of the last layer
103 | self.fc.weight_quantizer.quantizer.n_bits = 8
104 | # no activation quantization of logits
105 | self.fc.activation_quantizer = FP32Acts()
106 | elif quant_setup == "LSQ_paper":
107 | # Weights of the first layer
108 | self.features[0].activation_quantizer = FP32Acts()
109 | self.features[0].weight_quantizer.quantizer.n_bits = 8
110 | # Weights of the last layer
111 | self.fc.activation_quantizer.quantizer.n_bits = 8
112 | self.fc.weight_quantizer.quantizer.n_bits = 8
113 | # Set all QuantizedActivations to FP32
114 | for layer in self.features.modules():
115 | if isinstance(layer, QuantizedActivation):
116 | layer.activation_quantizer = FP32Acts()
117 | elif quant_setup == "FP_logits":
118 | print("Do not quantize output of FC layer")
119 | self.fc.activation_quantizer = FP32Acts() # no activation quantization of logits
120 | elif quant_setup == "fc4":
121 | self.features[0].weight_quantizer.quantizer.n_bits = 8
122 | self.fc.weight_quantizer.quantizer.n_bits = 4
123 | elif quant_setup is not None and quant_setup != "all":
124 | raise ValueError("Quantization setup '{}' not supported for Resnet".format(quant_setup))
125 |
126 | def forward(self, x):
127 | x = self.features(x)
128 | x = self.avgpool(x)
129 | # x = x.view(x.size(0), -1)
130 | x = self.flattener(x)
131 | x = self.fc(x)
132 |
133 | return x
134 |
135 |
136 | def resnet18_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
137 | if load_type == "fp32":
138 | # Load model from pretrained FP32 weights
139 | fp_model = resnet18(pretrained=pretrained)
140 | quant_model = QuantizedResNet(fp_model, **qparams)
141 | elif load_type == "quantized":
142 | # Load pretrained QuantizedModel
143 | print(f"Loading pretrained quantized model from {model_dir}")
144 | state_dict = torch.load(model_dir)
145 | fp_model = resnet18()
146 | quant_model = QuantizedResNet(fp_model, **qparams)
147 | quant_model.load_state_dict(state_dict)
148 | else:
149 | raise ValueError("wrong load_type specified")
150 | return quant_model
151 |
152 |
153 | def resnet50_quantized(pretrained=True, model_dir=None, load_type="fp32", **qparams):
154 | if load_type == "fp32":
155 | # Load model from pretrained FP32 weights
156 | fp_model = resnet50(pretrained=pretrained)
157 | quant_model = QuantizedResNet(fp_model, **qparams)
158 | elif load_type == "quantized":
159 | # Load pretrained QuantizedModel
160 | print(f"Loading pretrained quantized model from {model_dir}")
161 | state_dict = torch.load(model_dir)
162 | fp_model = resnet50()
163 | quant_model = QuantizedResNet(fp_model, **qparams)
164 | quant_model.load_state_dict(state_dict)
165 | else:
166 | raise ValueError("wrong load_type specified")
167 | return quant_model
168 |
--------------------------------------------------------------------------------
/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from . import utils, autoquant_utils, quantized_folded_bn
6 |
--------------------------------------------------------------------------------
/quantization/autoquant_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import copy
6 | import warnings
7 |
8 | from torch import nn
9 | from torch.nn import functional as F
10 | from torch.nn.modules.conv import _ConvNd
11 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd
12 |
13 |
14 | from quantization.base_quantized_classes import QuantizedActivation, QuantizedModule
15 | from quantization.hijacker import QuantizationHijacker, activations_set
16 | from quantization.quantization_manager import QuantizationManager
17 | from quantization.quantized_folded_bn import BNFusedHijacker
18 |
19 |
20 | class QuantConv1d(QuantizationHijacker, nn.Conv1d):
21 | def run_forward(self, x, weight, bias, offsets=None):
22 | return F.conv1d(
23 | x.contiguous(),
24 | weight.contiguous(),
25 | bias=bias,
26 | stride=self.stride,
27 | padding=self.padding,
28 | dilation=self.dilation,
29 | groups=self.groups,
30 | )
31 |
32 |
33 | class QuantConv(QuantizationHijacker, nn.Conv2d):
34 | def run_forward(self, x, weight, bias, offsets=None):
35 | return F.conv2d(
36 | x.contiguous(),
37 | weight.contiguous(),
38 | bias=bias,
39 | stride=self.stride,
40 | padding=self.padding,
41 | dilation=self.dilation,
42 | groups=self.groups,
43 | )
44 |
45 |
46 | class QuantConvTransposeBase(QuantizationHijacker):
47 | def quantize_weights(self, weights):
48 | if self.per_channel_weights:
49 | # NOTE: ND transpose conv weights are stored as (in_channels, out_channels, *)
50 | # instead of (out_channels, in_channels, *) as for convs,
51 | # and per-channel quantization should be applied to the out channels;
52 | # transposing before passing to the quantizer is a trick to avoid
53 | # changing the logic in the range estimators and quantizers
54 | weights = weights.transpose(1, 0).contiguous()
55 | weights = self.weight_quantizer(weights)
56 | if self.per_channel_weights:
57 | weights = weights.transpose(1, 0).contiguous()
58 | return weights
59 |
60 |
61 | class QuantConvTranspose1d(QuantConvTransposeBase, nn.ConvTranspose1d):
62 | def run_forward(self, x, weight, bias, offsets=None):
63 | return F.conv_transpose1d(
64 | x.contiguous(),
65 | weight.contiguous(),
66 | bias=bias,
67 | stride=self.stride,
68 | padding=self.padding,
69 | output_padding=self.output_padding,
70 | dilation=self.dilation,
71 | groups=self.groups,
72 | )
73 |
74 |
75 | class QuantConvTranspose(QuantConvTransposeBase, nn.ConvTranspose2d):
76 | def run_forward(self, x, weight, bias, offsets=None):
77 | return F.conv_transpose2d(
78 | x.contiguous(),
79 | weight.contiguous(),
80 | bias=bias,
81 | stride=self.stride,
82 | padding=self.padding,
83 | output_padding=self.output_padding,
84 | dilation=self.dilation,
85 | groups=self.groups,
86 | )
87 |
88 |
89 | class QuantLinear(QuantizationHijacker, nn.Linear):
90 | def run_forward(self, x, weight, bias, offsets=None):
91 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias)
92 |
93 |
94 | class BNQConv1d(BNFusedHijacker, nn.Conv1d):
95 | def run_forward(self, x, weight, bias, offsets=None):
96 | return F.conv1d(
97 | x.contiguous(),
98 | weight.contiguous(),
99 | bias=bias,
100 | stride=self.stride,
101 | padding=self.padding,
102 | dilation=self.dilation,
103 | groups=self.groups,
104 | )
105 |
106 |
107 | class BNQConv(BNFusedHijacker, nn.Conv2d):
108 | def run_forward(self, x, weight, bias, offsets=None):
109 | return F.conv2d(
110 | x.contiguous(),
111 | weight.contiguous(),
112 | bias=bias,
113 | stride=self.stride,
114 | padding=self.padding,
115 | dilation=self.dilation,
116 | groups=self.groups,
117 | )
118 |
119 |
120 | class BNQLinear(BNFusedHijacker, nn.Linear):
121 | def run_forward(self, x, weight, bias, offsets=None):
122 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias)
123 |
124 |
125 | class QuantizedActivationWrapper(QuantizedActivation):
126 | """
127 | Wraps a layer and quantizes its output activation.
128 | It also allows tying the input and output quantizers, which is helpful
129 | for layers such as average pooling.
130 | """
131 |
132 | def __init__(
133 | self,
134 | layer,
135 | tie_activation_quantizers=False,
136 | input_quantizer: QuantizationManager = None,
137 | *args,
138 | **kwargs,
139 | ):
140 | super().__init__(*args, **kwargs)
141 | self.tie_activation_quantizers = tie_activation_quantizers
142 | if input_quantizer:
143 | assert isinstance(input_quantizer, QuantizationManager)
144 | self.activation_quantizer = input_quantizer
145 | self.layer = layer
146 |
147 | def quantize_activations_no_range_update(self, x):
148 | if self._quant_a:
149 | return self.activation_quantizer.quantizer(x)
150 | else:
151 | return x
152 |
153 | def forward(self, x):
154 | x = self.layer(x)
155 | if self.tie_activation_quantizers:
156 | # The input activation quantizer is used to quantize the activation
157 | # but without updating the quantization range
158 | return self.quantize_activations_no_range_update(x)
159 | else:
160 | return self.quantize_activations(x)
161 |
162 | def extra_repr(self):
163 | return f"tie_activation_quantizers={self.tie_activation_quantizers}"
164 |
165 |
166 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm):
167 | def run_forward(self, x, weight, bias, offsets=None):
168 | return F.layer_norm(
169 | input=x.contiguous(),
170 | normalized_shape=self.normalized_shape,
171 | weight=weight.contiguous(),
172 | bias=bias.contiguous(),
173 | eps=self.eps,
174 | )
175 |
176 |
177 | class Flattener(nn.Module):
178 | def forward(self, x):
179 | return x.view(x.shape[0], -1)
180 |
181 |
182 | # Non BN Quant Modules Map
183 | non_bn_module_map = {
184 | nn.Conv1d: QuantConv1d,
185 | nn.Conv2d: QuantConv,
186 | nn.ConvTranspose1d: QuantConvTranspose1d,
187 | nn.ConvTranspose2d: QuantConvTranspose,
188 | nn.Linear: QuantLinear,
189 | nn.LayerNorm: QuantLayerNorm,
190 | }
191 |
192 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd)
193 | # BN Quant Modules Map
194 | bn_module_map = {nn.Conv1d: BNQConv1d, nn.Conv2d: BNQConv, nn.Linear: BNQLinear}
195 |
196 | quant_conv_modules = (QuantConv1d, QuantConv, BNQConv1d, BNQConv)
197 |
198 |
199 | def next_bn(module, i):
200 | return len(module) > i + 1 and isinstance(module[i + 1], (nn.BatchNorm2d, nn.BatchNorm1d))
201 |
202 |
203 | def get_act(module, i):
204 | # Case 1: conv + act
205 | if len(module) - i > 1 and isinstance(module[i + 1], tuple(activations_set)):
206 | return module[i + 1], i + 1
207 |
208 | # Case 2: conv + bn + act
209 | if (
210 | len(module) - i > 2
211 | and next_bn(module, i)
212 | and isinstance(module[i + 2], tuple(activations_set))
213 | ):
214 | return module[i + 2], i + 2
215 |
216 | # Case 3: conv + bn + X -> return false
217 | # Case 4: conv + X -> return false
218 | return None, None
219 |
220 |
221 | def get_conv_args(module):
222 | args = dict(
223 | in_channels=module.in_channels,
224 | out_channels=module.out_channels,
225 | kernel_size=module.kernel_size,
226 | stride=module.stride,
227 | padding=module.padding,
228 | dilation=module.dilation,
229 | groups=module.groups,
230 | bias=module.bias is not None,
231 | )
232 | if isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d)):
233 | args["output_padding"] = module.output_padding
234 | return args
235 |
236 |
237 | def get_linear_args(module):
238 | args = dict(
239 | in_features=module.in_features,
240 | out_features=module.out_features,
241 | bias=module.bias is not None,
242 | )
243 | return args
244 |
245 |
246 | def get_layernorm_args(module):
247 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps)
248 | return args
249 |
250 |
251 | def get_module_args(mod, act):
252 | if isinstance(mod, _ConvNd):
253 | kwargs = get_conv_args(mod)
254 | elif isinstance(mod, nn.Linear):
255 | kwargs = get_linear_args(mod)
256 | elif isinstance(mod, nn.LayerNorm):
257 | kwargs = get_layernorm_args(mod)
258 | else:
259 | raise ValueError
260 |
261 | kwargs["activation"] = act
262 |
263 | return kwargs
264 |
265 |
266 | def fold_bn(module, i, **quant_params):
267 | bn = next_bn(module, i)
268 | act, act_idx = get_act(module, i)
269 | modmap = bn_module_map if bn else non_bn_module_map
270 | modtype = modmap[type(module[i])]
271 |
272 | kwargs = get_module_args(module[i], act)
273 | new_module = modtype(**kwargs, **quant_params)
274 | new_module.weight.data = module[i].weight.data.clone()
275 |
276 | if bn:
277 | new_module.gamma.data = module[i + 1].weight.data.clone()
278 | new_module.beta.data = module[i + 1].bias.data.clone()
279 | new_module.running_mean.data = module[i + 1].running_mean.data.clone()
280 | new_module.running_var.data = module[i + 1].running_var.data.clone()
281 | if module[i].bias is not None:
282 | new_module.running_mean.data -= module[i].bias.data
283 | print("Warning: bias in conv/linear before batch normalization.")
284 | new_module.epsilon = module[i + 1].eps
285 |
286 | elif module[i].bias is not None:
287 | new_module.bias.data = module[i].bias.data.clone()
288 |
289 | return new_module, i + int(bool(act)) + int(bn) + 1
290 |
291 |
292 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params):
293 | specials = specials or dict()
294 |
295 | i = 0
296 | quant_modules = []
297 | while i < len(model):
298 | if isinstance(model[i], QuantizedModule):
299 | quant_modules.append(model[i])
300 | elif type(model[i]) in non_bn_module_map:
301 | new_module, new_i = fold_bn(model, i, **quant_params)
302 | quant_modules.append(new_module)
303 | i = new_i
304 | continue
305 |
306 | elif type(model[i]) in specials:
307 | quant_modules.append(specials[type(model[i])](model[i], **quant_params))
308 |
309 | elif isinstance(model[i], non_param_modules):
310 | # Check for last quantizer
311 | input_quantizer = None
312 | if quant_modules and isinstance(quant_modules[-1], QuantizedModule):
313 | last_layer = quant_modules[-1]
314 | input_quantizer = quant_modules[-1].activation_quantizer
315 | elif (
316 | quant_modules
317 | and isinstance(quant_modules[-1], nn.Sequential)
318 | and isinstance(quant_modules[-1][-1], QuantizedModule)
319 | ):
320 | last_layer = quant_modules[-1][-1]
321 | input_quantizer = quant_modules[-1][-1].activation_quantizer
322 |
323 | if input_quantizer and tie_activation_quantizers:
324 |                 # If an input quantizer is found, tie the input/output activation quantizers
325 | print(
326 | f"Tying input quantizer {i-1}^th layer of type {type(last_layer)} to the "
327 | f"quantized {type(model[i])} following it"
328 | )
329 | quant_modules.append(
330 | QuantizedActivationWrapper(
331 | model[i],
332 | tie_activation_quantizers=tie_activation_quantizers,
333 | input_quantizer=input_quantizer,
334 | **quant_params,
335 | )
336 | )
337 | else:
338 | # Input quantizer not found
339 | quant_modules.append(QuantizedActivationWrapper(model[i], **quant_params))
340 | if tie_activation_quantizers:
341 | warnings.warn("Input quantizer not found, so we do not tie quantizers")
342 | else:
343 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params))
344 | i += 1
345 | return nn.Sequential(*quant_modules)
346 |
347 |
348 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params):
349 | specials = specials or dict()
350 |
351 | if isinstance(model, nn.Sequential):
352 | quant_model = quantize_sequential(
353 | model, specials, tie_activation_quantizers, **quant_params
354 | )
355 |
356 | elif type(model) in specials:
357 | quant_model = specials[type(model)](model, **quant_params)
358 |
359 | elif isinstance(model, non_param_modules):
360 | quant_model = QuantizedActivationWrapper(model, **quant_params)
361 |
362 | elif type(model) in non_bn_module_map:
363 |         # Use exact type matching rather than isinstance(): subclasses of these layers may
364 |         # need custom handling and should not be converted implicitly
365 | modtype = non_bn_module_map[type(model)]
366 | kwargs = get_module_args(model, None)
367 | quant_model = modtype(**kwargs, **quant_params)
368 |
369 | quant_model.weight.data = model.weight.data
370 | if getattr(model, "bias", None) is not None:
371 | quant_model.bias.data = model.bias.data
372 |
373 | else:
374 | # Unknown type, try to quantize all child modules
375 | quant_model = copy.deepcopy(model)
376 | for name, module in quant_model._modules.items():
377 | new_model = quantize_model(module, specials=specials, **quant_params)
378 | if new_model is not None:
379 | setattr(quant_model, name, new_model)
380 |
381 | return quant_model
382 |
--------------------------------------------------------------------------------
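
The conversion utilities above are easiest to see end to end on a toy network. The following is a minimal sketch (not part of the repository) that assumes the import path `quantization.autoquant_utils` and the `BNQConv`/`QuantConv` mappings defined above: a conv + BN + ReLU triple is folded into one BN-fused quantized layer, and the first forward pass initializes the quantizer ranges.

```
import torch
from torch import nn

from quantization.autoquant_utils import quantize_model
from quantization.base_quantized_classes import QuantizedModule

# Toy full-precision network: conv + BN + ReLU folds into a single BNQConv,
# the final conv becomes a QuantConv (no BN to fold, no activation to fuse).
fp_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=1),
)
qmodel = quantize_model(fp_model, n_bits=4, n_bits_act=8)

# Enable weight and activation quantization on every converted layer
for m in qmodel.modules():
    if isinstance(m, QuantizedModule):
        m.quantized()

# The first forward pass runs the range estimators and initializes all quantizers
with torch.no_grad():
    qmodel(torch.randn(2, 3, 32, 32))
```
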
/quantization/base_quantized_classes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from quantization.quantization_manager import QuantizationManager
9 | from quantization.quantizers import QuantizerBase, AsymmetricUniformQuantizer
10 | from quantization.quantizers.rounding_utils import round_ste_func
11 | from quantization.range_estimators import (
12 | RangeEstimatorBase,
13 | CurrentMinMaxEstimator,
14 | RunningMinMaxEstimator,
15 | )
16 |
17 |
18 | def _set_layer_learn_ranges(layer):
19 | if isinstance(layer, QuantizationManager):
20 | if layer.quantizer.is_initialized:
21 | layer.learn_ranges()
22 |
23 |
24 | def _set_layer_fix_ranges(layer):
25 | if isinstance(layer, QuantizationManager):
26 | if layer.quantizer.is_initialized:
27 | layer.fix_ranges()
28 |
29 |
30 | def _set_layer_estimate_ranges(layer):
31 | if isinstance(layer, QuantizationManager):
32 | layer.estimate_ranges()
33 |
34 |
35 | def _set_layer_estimate_ranges_train(layer):
36 | if isinstance(layer, QuantizationManager):
37 | if layer.quantizer.is_initialized:
38 | layer.estimate_ranges_train()
39 |
40 |
41 | class QuantizedModule(nn.Module):
42 | """
43 | Parent class for a quantized module. It adds the basic functionality of switching the module
44 | between quantized and full precision mode. It also defines the cached parameters and handles
45 | the reset of the cache properly.
46 | """
47 |
48 | def __init__(
49 | self,
50 | *args,
51 | method: QuantizerBase = AsymmetricUniformQuantizer,
52 | act_method=None,
53 | weight_range_method: RangeEstimatorBase = CurrentMinMaxEstimator,
54 | act_range_method: RangeEstimatorBase = RunningMinMaxEstimator,
55 | n_bits=8,
56 | n_bits_act=None,
57 | per_channel_weights=False,
58 | percentile=None,
59 | weight_range_options=None,
60 | act_range_options=None,
61 | scale_domain="linear",
62 | act_quant_kwargs={},
63 | weight_quant_kwargs={},
64 | weight_discretizer=round_ste_func,
65 | act_discretizer=round_ste_func,
66 | act_discretizer_args=tuple(),
67 | weight_discretizer_args=tuple(),
68 | quantize_input=False,
69 | **kwargs
70 | ):
71 | kwargs.pop("act_quant_dict", None)
72 |
73 | super().__init__(*args, **kwargs)
74 |
75 | self.method = method
76 | self.act_method = act_method or method
77 | self.n_bits = n_bits
78 | self.n_bits_act = n_bits_act or n_bits
79 | self.per_channel_weights = per_channel_weights
80 | self.percentile = percentile
81 | self.weight_range_method = weight_range_method
82 | self.weight_range_options = weight_range_options if weight_range_options else {}
83 | self.act_range_method = act_range_method
84 | self.act_range_options = act_range_options if act_range_options else {}
85 | self.scale_domain = scale_domain
86 | self.quantize_input = quantize_input
87 |
88 | self.quant_params = None
89 | self.register_buffer("_quant_w", torch.BoolTensor([False]))
90 | self.register_buffer("_quant_a", torch.BoolTensor([False]))
91 |
92 | self.act_qparams = dict(
93 | n_bits=self.n_bits_act,
94 | scale_domain=self.scale_domain,
95 | discretizer=act_discretizer,
96 | discretizer_args=act_discretizer_args,
97 | **act_quant_kwargs
98 | )
99 | self.weight_qparams = dict(
100 | n_bits=self.n_bits,
101 | scale_domain=self.scale_domain,
102 | discretizer=weight_discretizer,
103 | discretizer_args=weight_discretizer_args,
104 | **weight_quant_kwargs
105 | )
106 |
107 | def quantized_weights(self):
108 | self._quant_w = torch.BoolTensor([True])
109 |
110 | def full_precision_weights(self):
111 | self._quant_w = torch.BoolTensor([False])
112 |
113 | def quantized_acts(self):
114 | self._quant_a = torch.BoolTensor([True])
115 |
116 | def full_precision_acts(self):
117 | self._quant_a = torch.BoolTensor([False])
118 |
119 | def quantized(self):
120 | self.quantized_weights()
121 | self.quantized_acts()
122 |
123 | def full_precision(self):
124 | self.full_precision_weights()
125 | self.full_precision_acts()
126 |
127 | def get_quantizer_status(self):
128 | return dict(quant_a=self._quant_a.item(), quant_w=self._quant_w.item())
129 |
130 | def set_quantizer_status(self, quantizer_status):
131 | if quantizer_status["quant_a"]:
132 | self.quantized_acts()
133 | else:
134 | self.full_precision_acts()
135 |
136 | if quantizer_status["quant_w"]:
137 | self.quantized_weights()
138 | else:
139 | self.full_precision_weights()
140 |
141 | def learn_ranges(self):
142 | self.apply(_set_layer_learn_ranges)
143 |
144 | def fix_ranges(self):
145 | self.apply(_set_layer_fix_ranges)
146 |
147 | def estimate_ranges(self):
148 | self.apply(_set_layer_estimate_ranges)
149 |
150 | def estimate_ranges_train(self):
151 | self.apply(_set_layer_estimate_ranges_train)
152 |
153 | def extra_repr(self):
154 | quant_state = "weight_quant={}, act_quant={}".format(
155 | self._quant_w.item(), self._quant_a.item()
156 | )
157 | parent_repr = super().extra_repr()
158 | return "{},\n{}".format(parent_repr, quant_state) if parent_repr else quant_state
159 |
160 |
161 | class QuantizedActivation(QuantizedModule):
162 | def __init__(self, *args, **kwargs):
163 | super().__init__(*args, **kwargs)
164 | self.activation_quantizer = QuantizationManager(
165 | qmethod=self.act_method,
166 | qparams=self.act_qparams,
167 | init=self.act_range_method,
168 | range_estim_params=self.act_range_options,
169 | )
170 |
171 | def quantize_activations(self, x):
172 | if self._quant_a:
173 | return self.activation_quantizer(x)
174 | else:
175 | return x
176 |
177 | def forward(self, x):
178 | return self.quantize_activations(x)
179 |
180 |
181 | class FP32Acts(nn.Module):
182 | def forward(self, x):
183 | return x
184 |
185 | def reset_ranges(self):
186 | pass
187 |
--------------------------------------------------------------------------------
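
As a quick illustration of the switching logic above, the sketch below (a minimal example using only the defaults defined in this file, not repository training code) instantiates a stand-alone `QuantizedActivation`, enables activation quantization, and freezes the ranges after one estimation pass.

```
import torch
from quantization.base_quantized_classes import QuantizedActivation

qact = QuantizedActivation(n_bits_act=8)
qact.quantized_acts()        # flips the _quant_a buffer to True
x = torch.randn(4, 16)
y = qact(x)                  # first call estimates ranges (RunningMinMaxEstimator) and quantizes
qact.fix_ranges()            # keep the estimated grid fixed from now on
```
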
/quantization/base_quantized_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | from typing import Dict
5 |
6 | import torch
7 | from torch import nn, Tensor
8 |
9 | from quantization.base_quantized_classes import (
10 | QuantizedModule,
11 | _set_layer_estimate_ranges,
12 | _set_layer_estimate_ranges_train,
13 | _set_layer_learn_ranges,
14 | _set_layer_fix_ranges,
15 | )
16 | from quantization.quantizers import QuantizerBase
17 |
18 |
19 | class QuantizedModel(nn.Module):
20 | """
21 | Parent class for a quantized model. This allows you to have convenience functions to put the
22 | whole model into quantization or full precision.
23 | """
24 |
25 | def __init__(self, input_size=(1, 3, 224, 224)):
26 | """
27 | Parameters
28 | ----------
29 | input_size: Tuple with the input dimension for the model (including batch dimension)
30 | """
31 | super().__init__()
32 | self.input_size = input_size
33 |
34 | def load_state_dict(
35 |         self, state_dict: Dict[str, Tensor], strict: bool = True
36 | ):
37 | """
38 | This function overwrites the load_state_dict of nn.Module to ensure that quantization
39 |         parameters are loaded correctly for a quantized model.
40 |
41 | """
42 | quant_state_dict = {
43 | k: v for k, v in state_dict.items() if k.endswith("_quant_a") or k.endswith("_quant_w")
44 | }
45 |
46 | if quant_state_dict:
47 | super().load_state_dict(quant_state_dict, strict=False)
48 | else:
49 | raise ValueError(
50 | "The quantization states of activations or weights should be "
51 | "included in the state dict "
52 | )
53 | # Pass dummy data through quantized model to ensure all quantization parameters are
54 | # initialized with the correct dimensions (None tensors will lead to issues in state dict
55 | # loading)
56 | device = next(self.parameters()).device
57 | dummy_input = torch.rand(*self.input_size, device=device)
58 | with torch.no_grad():
59 | self.forward(dummy_input)
60 |
61 | # Load state dict
62 | super().load_state_dict(state_dict, strict)
63 |
64 | def quantized_weights(self):
65 | def _fn(layer):
66 | if isinstance(layer, QuantizedModule):
67 | layer.quantized_weights()
68 |
69 | self.apply(_fn)
70 |
71 | def full_precision_weights(self):
72 | def _fn(layer):
73 | if isinstance(layer, QuantizedModule):
74 | layer.full_precision_weights()
75 |
76 | self.apply(_fn)
77 |
78 | def quantized_acts(self):
79 | def _fn(layer):
80 | if isinstance(layer, QuantizedModule):
81 | layer.quantized_acts()
82 |
83 | self.apply(_fn)
84 |
85 | def full_precision_acts(self):
86 | def _fn(layer):
87 | if isinstance(layer, QuantizedModule):
88 | layer.full_precision_acts()
89 |
90 | self.apply(_fn)
91 |
92 | def quantized(self):
93 | def _fn(layer):
94 | if isinstance(layer, QuantizedModule):
95 | layer.quantized()
96 |
97 | self.apply(_fn)
98 |
99 | def full_precision(self):
100 | def _fn(layer):
101 | if isinstance(layer, QuantizedModule):
102 | layer.full_precision()
103 |
104 | self.apply(_fn)
105 |
106 | def estimate_ranges(self):
107 | self.apply(_set_layer_estimate_ranges)
108 |
109 | def estimate_ranges_train(self):
110 | self.apply(_set_layer_estimate_ranges_train)
111 |
112 | def set_quant_state(self, weight_quant, act_quant):
113 | if act_quant:
114 | self.quantized_acts()
115 | else:
116 | self.full_precision_acts()
117 |
118 | if weight_quant:
119 | self.quantized_weights()
120 | else:
121 | self.full_precision_weights()
122 |
123 | def grad_scaling(self, grad_scaling=True):
124 | def _fn(module):
125 | if isinstance(module, QuantizerBase):
126 | module.grad_scaling = grad_scaling
127 |
128 | self.apply(_fn)
129 |
130 |     # Methods for switching quantizer quantization states
131 | def learn_ranges(self):
132 | self.apply(_set_layer_learn_ranges)
133 |
134 | def fix_ranges(self):
135 | self.apply(_set_layer_fix_ranges)
136 |
--------------------------------------------------------------------------------
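
A typical pattern is to wrap a converted backbone in a `QuantizedModel` subclass so that the convenience toggles apply to the whole network. The wrapper class below is hypothetical (the repository's real wrappers live under `models/`), but it follows the interface defined above.

```
import torch
from torch import nn

from quantization.autoquant_utils import quantize_model
from quantization.base_quantized_model import QuantizedModel


class QuantizedToyNet(QuantizedModel):  # hypothetical example wrapper
    def __init__(self, **quant_params):
        super().__init__(input_size=(1, 3, 32, 32))
        backbone = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Flatten(), nn.Linear(8 * 32 * 32, 10),
        )
        self.net = quantize_model(backbone, **quant_params)

    def forward(self, x):
        return self.net(x)


model = QuantizedToyNet(n_bits=4, n_bits_act=8)
model.set_quant_state(weight_quant=True, act_quant=True)
with torch.no_grad():
    model(torch.rand(*model.input_size))  # initializes all quantizer ranges
model.fix_ranges()                        # then freeze them for evaluation
```
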
/quantization/hijacker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import copy
6 |
7 | from timm.models.layers.activations import Swish, HardSwish, HardSigmoid
8 | from timm.models.layers.activations_me import SwishMe, HardSwishMe, HardSigmoidMe
9 | from torch import nn
10 |
11 | from quantization.base_quantized_classes import QuantizedModule
12 | from quantization.quantization_manager import QuantizationManager
13 | from quantization.range_estimators import RangeEstimators
14 |
15 | activations_set = [
16 | nn.ReLU,
17 | nn.ReLU6,
18 | nn.Hardtanh,
19 | nn.Sigmoid,
20 | nn.Tanh,
21 | nn.GELU,
22 | nn.PReLU,
23 | Swish,
24 | SwishMe,
25 | HardSwish,
26 | HardSwishMe,
27 | HardSigmoid,
28 | HardSigmoidMe,
29 | ]
30 |
31 |
32 | class QuantizationHijacker(QuantizedModule):
33 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and
34 | dequantization on the weights and output distributions.
35 |
36 | Usage:
37 | To make a quantized nn.Linear layer:
38 | class HijackedLinear(QuantizationHijacker, nn.Linear):
39 | pass
40 | """
41 |
42 | def __init__(self, *args, activation: nn.Module = None, **kwargs):
43 |
44 | super().__init__(*args, **kwargs)
45 | if activation:
46 | assert isinstance(activation, tuple(activations_set)), str(activation)
47 |
48 | self.activation_function = copy.deepcopy(activation) if activation else None
49 |
50 | self.activation_quantizer = QuantizationManager(
51 | qmethod=self.act_method,
52 | init=self.act_range_method,
53 | qparams=self.act_qparams,
54 | range_estim_params=self.act_range_options,
55 | )
56 |
57 | if self.weight_range_method == RangeEstimators.current_minmax:
58 | weight_init_params = dict(percentile=self.percentile)
59 | else:
60 | weight_init_params = self.weight_range_options
61 |
62 | self.weight_quantizer = QuantizationManager(
63 | qmethod=self.method,
64 | init=self.weight_range_method,
65 | per_channel=self.per_channel_weights,
66 | qparams=self.weight_qparams,
67 | range_estim_params=weight_init_params,
68 | )
69 |
70 | def forward(self, x, offsets=None):
71 | # Quantize input
72 | if self.quantize_input and self._quant_a:
73 | x = self.activation_quantizer(x)
74 |
75 | # Get quantized weight
76 | weight, bias = self.get_params()
77 | res = self.run_forward(x, weight, bias, offsets=offsets)
78 |
79 | # Apply fused activation function
80 | if self.activation_function is not None:
81 | res = self.activation_function(res)
82 |
83 | # Quantize output
84 | if not self.quantize_input and self._quant_a:
85 | res = self.activation_quantizer(res)
86 | return res
87 |
88 | def get_params(self):
89 |
90 | weight, bias = self.get_weight_bias()
91 |
92 | if self._quant_w:
93 | weight = self.quantize_weights(weight)
94 |
95 | return weight, bias
96 |
97 | def quantize_weights(self, weights):
98 | return self.weight_quantizer(weights)
99 |
100 | def get_weight_bias(self):
101 | bias = None
102 | if hasattr(self, "bias"):
103 | bias = self.bias
104 | return self.weight, bias
105 |
106 | def run_forward(self, x, weight, bias, offsets=None):
107 | # Performs the actual linear operation of the layer
108 | raise NotImplementedError()
109 |
110 | def extra_repr(self):
111 | activation = "input" if self.quantize_input else "output"
112 | return f"{super().extra_repr()}-{activation}"
113 |
--------------------------------------------------------------------------------
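
The usage note in the docstring leaves out one detail: `run_forward` must be supplied by the concrete layer, since the mixin only handles (de)quantization and activation fusion. A minimal sketch, assuming nothing beyond the mixin above:

```
import torch
import torch.nn.functional as F
from torch import nn

from quantization.hijacker import QuantizationHijacker


class HijackedLinear(QuantizationHijacker, nn.Linear):
    def run_forward(self, x, weight, bias, offsets=None):
        # The hijacker hands us (possibly quantized) weights; do the actual op here
        return F.linear(x.contiguous(), weight.contiguous(), bias)


layer = HijackedLinear(16, 4, activation=nn.ReLU(), n_bits=8)
layer.quantized()                  # quantize both weights and activations
y = layer(torch.randn(2, 16))      # ranges are estimated on this first call
```
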
/quantization/quantization_manager.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from enum import auto
6 |
7 | from torch import nn
8 | from quantization.quantizers import QMethods, QuantizerBase
9 | from quantization.quantizers.utils import QuantizerNotInitializedError
10 | from quantization.range_estimators import RangeEstimators, RangeEstimatorBase
11 | from utils import BaseEnumOptions
12 |
13 |
14 | class QuantizationManager(nn.Module):
15 | """Implementation of Quantization and Quantization Range Estimation
16 |
17 | Parameters
18 | ----------
19 |     n_bits: int
20 |         Number of bits for the quantization (passed to the quantizer via qparams).
21 | qmethod: QMethods member (Enum)
22 |         The quantization scheme to use, e.g. symmetric_uniform or
23 |         asymmetric_uniform.
24 | init: RangeEstimators member (Enum)
25 |         Range estimator used to initialize the quantization grid.
26 | per_channel: bool
27 | If true, will use a separate quantization grid for each kernel/channel.
28 | x_min: float or PyTorch Tensor
29 | The minimum value which needs to be represented.
30 | x_max: float or PyTorch Tensor
31 | The maximum value which needs to be represented.
32 | qparams: kwargs
33 |         dictionary of quantization parameters to be passed to the quantizer instantiation
34 |     range_estim_params: kwargs
35 |         dictionary of parameters to be passed to the range estimator instantiation
36 | """
37 |
38 | def __init__(
39 | self,
40 | qmethod: QuantizerBase = QMethods.symmetric_uniform.cls,
41 | init: RangeEstimatorBase = RangeEstimators.current_minmax.cls,
42 | per_channel=False,
43 | x_min=None,
44 | x_max=None,
45 | qparams=None,
46 | range_estim_params=None,
47 | ):
48 | super().__init__()
49 | self.state = Qstates.estimate_ranges
50 | self.qmethod = qmethod
51 | self.init = init
52 | self.per_channel = per_channel
53 | self.qparams = qparams if qparams else {}
54 | self.range_estim_params = range_estim_params if range_estim_params else {}
55 | self.range_estimator = None
56 |
57 | # define quantizer
58 |         self.quantizer = self.qmethod(per_channel=self.per_channel, **self.qparams)
59 | self.quantizer.state = self.state
60 |
61 | # define range estimation method for quantizer initialisation
62 | if x_min is not None and x_max is not None:
63 | self.set_quant_range(x_min, x_max)
64 | self.fix_ranges()
65 | else:
66 | # set up the collector function to set the ranges
67 | self.range_estimator = self.init(
68 | per_channel=self.per_channel, quantizer=self.quantizer, **self.range_estim_params
69 | )
70 |
71 | @property
72 | def n_bits(self):
73 | return self.quantizer.n_bits
74 |
75 | def estimate_ranges(self):
76 | self.state = Qstates.estimate_ranges
77 | self.quantizer.state = self.state
78 |
79 | def fix_ranges(self):
80 | if self.quantizer.is_initialized:
81 | self.state = Qstates.fix_ranges
82 | self.quantizer.state = self.state
83 | else:
84 | raise QuantizerNotInitializedError()
85 |
86 | def learn_ranges(self):
87 | self.quantizer.make_range_trainable()
88 | self.state = Qstates.learn_ranges
89 | self.quantizer.state = self.state
90 |
91 | def estimate_ranges_train(self):
92 | self.state = Qstates.estimate_ranges_train
93 | self.quantizer.state = self.state
94 |
95 | def reset_ranges(self):
96 | self.range_estimator.reset()
97 | self.quantizer.reset()
98 | self.estimate_ranges()
99 |
100 | def forward(self, x):
101 | if self.state == Qstates.estimate_ranges or (
102 | self.state == Qstates.estimate_ranges_train and self.training
103 | ):
104 | # Note this can be per tensor or per channel
105 | cur_xmin, cur_xmax = self.range_estimator(x)
106 | self.set_quant_range(cur_xmin, cur_xmax)
107 |
108 | return self.quantizer(x)
109 |
110 | def set_quant_range(self, x_min, x_max):
111 | self.quantizer.set_quant_range(x_min, x_max)
112 |
113 | def extra_repr(self):
114 | return "state={}".format(self.state.name)
115 |
116 |
117 | class Qstates(BaseEnumOptions):
118 | estimate_ranges = auto() # ranges are updated in eval and train mode
119 | fix_ranges = auto() # quantization ranges are fixed for train and eval
120 | learn_ranges = auto() # quantization params are nn.Parameters
121 | estimate_ranges_train = auto() # quantization ranges are updated during train and fixed for
122 | # eval
123 |
--------------------------------------------------------------------------------
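
A stand-alone sketch of the state machine (defaults only, nothing repository-specific): the manager starts in `estimate_ranges`, so the first call both sets the grid and quantizes; afterwards the ranges can be frozen or made trainable.

```
import torch
from quantization.quantization_manager import QuantizationManager

qm = QuantizationManager(qparams=dict(n_bits=8))  # symmetric uniform + current min-max defaults
x = torch.randn(64)
x_q = qm(x)         # estimate_ranges: grid is set from this batch, then x is fake-quantized
qm.fix_ranges()     # freeze the grid (requires an initialized quantizer)
# qm.learn_ranges() # alternatively, turn delta into an nn.Parameter for gradient-based ranges
```
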
/quantization/quantized_folded_bn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch.nn.modules.conv import _ConvNd
8 |
9 | from quantization.hijacker import QuantizationHijacker
10 |
11 |
12 | class BNFusedHijacker(QuantizationHijacker):
13 | """Extension to the QuantizationHijacker that fuses batch normalization (BN) after a weight
14 | layer into a joined module. The parameters and the statistics of the BN layer remain in
15 | full-precision.
16 | """
17 |
18 | def __init__(self, *args, **kwargs):
19 | kwargs.pop("bias", None) # Bias will be learned by BN params
20 | super().__init__(*args, **kwargs, bias=False)
21 | bn_dim = self.get_bn_dim()
22 | self.register_buffer("running_mean", torch.zeros(bn_dim))
23 | self.register_buffer("running_var", torch.ones(bn_dim))
24 | self.momentum = kwargs.pop("momentum", 0.1)
25 | self.gamma = nn.Parameter(torch.ones(bn_dim))
26 | self.beta = nn.Parameter(torch.zeros(bn_dim))
27 | self.epsilon = kwargs.get("eps", 1e-5)
28 | self.bias = None
29 |
30 | def forward(self, x):
31 | # Quantize input
32 | if self.quantize_input and self._quant_a:
33 | x = self.activation_quantizer(x)
34 |
35 | # Get quantized weight
36 | weight, bias = self.get_params()
37 | res = self.run_forward(x, weight, bias)
38 |
39 | res = F.batch_norm(
40 | res,
41 | self.running_mean,
42 | self.running_var,
43 | self.gamma,
44 | self.beta,
45 | self.training,
46 | self.momentum,
47 | self.epsilon,
48 | )
49 | # Apply fused activation function
50 | if self.activation_function is not None:
51 | res = self.activation_function(res)
52 |
53 | # Quantize output
54 | if not self.quantize_input and self._quant_a:
55 | res = self.activation_quantizer(res)
56 | return res
57 |
58 | def get_bn_dim(self):
59 | if isinstance(self, nn.Linear):
60 | return self.out_features
61 | elif isinstance(self, _ConvNd):
62 | return self.out_channels
63 | else:
64 | msg = (
65 | f"Unsupported type used: {self}. Must be a linear or (transpose)-convolutional "
66 | f"nn.Module"
67 | )
68 | raise NotImplementedError(msg)
69 |
--------------------------------------------------------------------------------
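
The repository's `BNQConv`/`BNQLinear` layers (created by the folding code in `autoquant_utils.py`) are presumably built from this mixin; the sketch below shows the same construction by hand, under a hypothetical class name.

```
import torch
import torch.nn.functional as F
from torch import nn

from quantization.quantized_folded_bn import BNFusedHijacker


class BNQConv2dSketch(BNFusedHijacker, nn.Conv2d):  # hypothetical name, for illustration only
    def run_forward(self, x, weight, bias, offsets=None):
        return F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)


layer = BNQConv2dSketch(3, 8, kernel_size=3, padding=1, n_bits=4, n_bits_act=8)
layer.quantized()
y = layer(torch.randn(2, 3, 16, 16))  # conv -> full-precision BN -> output quantization
```
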
/quantization/quantizers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from quantization.quantizers.base_quantizers import QuantizerBase
6 | from quantization.quantizers.uniform_quantizers import (
7 | SymmetricUniformQuantizer,
8 | AsymmetricUniformQuantizer,
9 | )
10 | from utils import ClassEnumOptions, MethodMap
11 |
12 |
13 | class QMethods(ClassEnumOptions):
14 | symmetric_uniform = MethodMap(SymmetricUniformQuantizer)
15 | asymmetric_uniform = MethodMap(AsymmetricUniformQuantizer)
16 |
--------------------------------------------------------------------------------
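
`QMethods` above acts as a small registry: each member's `.cls` attribute recovers the quantizer class so it can be selected by name (for example from command-line options). A one-line sketch, assuming only the enum above:

```
from quantization.quantizers import QMethods

quantizer = QMethods.asymmetric_uniform.cls(n_bits=8)  # look the class up by name, then instantiate
```
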
/quantization/quantizers/base_quantizers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from torch import nn
6 |
7 |
8 | class QuantizerBase(nn.Module):
9 | def __init__(self, n_bits, per_channel=False, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.n_bits = n_bits
12 | self.per_channel = per_channel
13 | self.state = None
14 | self.x_min_fp32 = self.x_max_fp32 = None
15 |
16 | @property
17 | def is_initialized(self):
18 | raise NotImplementedError()
19 |
20 | @property
21 | def x_max(self):
22 | raise NotImplementedError()
23 |
24 | @property
25 | def symmetric(self):
26 | raise NotImplementedError()
27 |
28 | @property
29 | def x_min(self):
30 | raise NotImplementedError()
31 |
32 | def forward(self, x_float):
33 | raise NotImplementedError()
34 |
35 | def _adjust_params_per_channel(self, x):
36 | raise NotImplementedError()
37 |
38 | def set_quant_range(self, x_min, x_max):
39 | raise NotImplementedError()
40 |
41 | def extra_repr(self):
42 | return "n_bits={}, per_channel={}, is_initalized={}".format(
43 | self.n_bits, self.per_channel, self.is_initialized
44 | )
45 |
46 | def reset(self):
47 | self._delta = None
48 |
--------------------------------------------------------------------------------
/quantization/quantizers/rounding_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from torch import nn
5 | import torch
6 | from torch.autograd import Function
7 |
8 | # Functional
9 | from utils import MethodMap, ClassEnumOptions
10 |
11 |
12 | class RoundStraightThrough(Function):
13 | @staticmethod
14 | def forward(ctx, x):
15 | return torch.round(x)
16 |
17 | @staticmethod
18 | def backward(ctx, output_grad):
19 | return output_grad
20 |
21 |
22 | class StochasticRoundSTE(Function):
23 | @staticmethod
24 | def forward(ctx, x):
25 | # Sample noise between [0, 1)
26 | noise = torch.rand_like(x)
27 | return torch.floor(x + noise)
28 |
29 | @staticmethod
30 | def backward(ctx, output_grad):
31 | return output_grad
32 |
33 |
34 | class ScaleGradient(Function):
35 | @staticmethod
36 | def forward(ctx, x, scale):
37 | ctx.scale = scale
38 | return x
39 |
40 | @staticmethod
41 | def backward(ctx, output_grad):
42 | return output_grad * ctx.scale, None
43 |
44 |
45 | class EWGSFunctional(Function):
46 | """
47 | x_in: float input
48 | scaling_factor: backward scaling factor
49 | x_out: discretized version of x_in within the range of [0,1]
50 | """
51 |
52 | @staticmethod
53 | def forward(ctx, x_in, scaling_factor):
54 | x_int = torch.round(x_in)
55 | ctx._scaling_factor = scaling_factor
56 | ctx.save_for_backward(x_in - x_int)
57 | return x_int
58 |
59 | @staticmethod
60 | def backward(ctx, g):
61 | diff = ctx.saved_tensors[0]
62 | delta = ctx._scaling_factor
63 | scale = 1 + delta * torch.sign(g) * diff
64 | return g * scale, None, None
65 |
66 |
67 | class StackSigmoidFunctional(Function):
68 | @staticmethod
69 | def forward(ctx, x, alpha):
70 | # Apply round to nearest in the forward pass
71 | ctx.save_for_backward(x, alpha)
72 | return torch.round(x)
73 |
74 | @staticmethod
75 | def backward(ctx, grad_output):
76 | x, alpha = ctx.saved_tensors
77 | sig_min = torch.sigmoid(alpha / 2)
78 | sig_scale = 1 - 2 * sig_min
79 | x_base = torch.floor(x).detach()
80 | x_rest = x - x_base - 0.5
81 | stacked_sigmoid_grad = (
82 | torch.sigmoid(x_rest * -alpha)
83 | * (1 - torch.sigmoid(x_rest * -alpha))
84 | * -alpha
85 | / sig_scale
86 | )
87 | return stacked_sigmoid_grad * grad_output, None
88 |
89 |
90 | # Parametrized modules
91 | class ParametrizedGradEstimatorBase(nn.Module):
92 | def __init__(self, *args, **kwargs):
93 | super().__init__()
94 | self._trainable = False
95 |
96 | def make_grad_params_trainable(self):
97 | self._trainable = True
98 | for name, buf in self.named_buffers(recurse=False):
99 | setattr(self, name, torch.nn.Parameter(buf))
100 |
101 | def make_grad_params_tensor(self):
102 | self._trainable = False
103 | for name, param in self.named_parameters(recurse=False):
104 | cur_value = param.data
105 | delattr(self, name)
106 | self.register_buffer(name, cur_value)
107 |
108 | def forward(self, x):
109 | raise NotImplementedError()
110 |
111 |
112 | class StackedSigmoid(ParametrizedGradEstimatorBase):
113 | """
114 | Stacked sigmoid estimator based on a simulated sigmoid forward pass
115 | """
116 |
117 | def __init__(self, alpha=1.0):
118 | super().__init__()
119 | self.register_buffer("alpha", torch.tensor(alpha))
120 |
121 | def forward(self, x):
122 | return stacked_sigmoid_func(x, self.alpha)
123 |
124 | def extra_repr(self):
125 | return f"alpha={self.alpha.item()}"
126 |
127 |
128 | class EWGSDiscretizer(ParametrizedGradEstimatorBase):
129 | def __init__(self, scaling_factor=0.2):
130 | super().__init__()
131 | self.register_buffer("scaling_factor", torch.tensor(scaling_factor))
132 |
133 | def forward(self, x):
134 | return ewgs_func(x, self.scaling_factor)
135 |
136 | def extra_repr(self):
137 | return f"scaling_factor={self.scaling_factor.item()}"
138 |
139 |
140 | class StochasticRounding(nn.Module):
141 | def __init__(self):
142 | super().__init__()
143 |
144 | def forward(self, x):
145 | if self.training:
146 | return stochastic_round_ste_func(x)
147 | else:
148 | return round_ste_func(x)
149 |
150 |
151 | round_ste_func = RoundStraightThrough.apply
152 | stacked_sigmoid_func = StackSigmoidFunctional.apply
153 | scale_grad_func = ScaleGradient.apply
154 | stochastic_round_ste_func = StochasticRoundSTE.apply
155 | ewgs_func = EWGSFunctional.apply
156 |
157 |
158 | class GradientEstimator(ClassEnumOptions):
159 | ste = MethodMap(round_ste_func)
160 | stoch_round = MethodMap(StochasticRounding)
161 | ewgs = MethodMap(EWGSDiscretizer)
162 | stacked_sigmoid = MethodMap(StackedSigmoid)
163 |
--------------------------------------------------------------------------------
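
The essential property shared by these estimators is that rounding is non-differentiable in the forward pass but passes a (possibly rescaled) gradient in the backward pass. A quick check of the plain straight-through estimator:

```
import torch
from quantization.quantizers.rounding_utils import round_ste_func

x = torch.randn(5, requires_grad=True)
y = round_ste_func(x)   # forward: round to nearest integer
y.sum().backward()
print(x.grad)           # all ones: the gradient passes straight through the rounding
```
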
/quantization/quantizers/uniform_quantizers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import inspect
6 | import torch
7 |
8 | from quantization.quantizers.rounding_utils import scale_grad_func, round_ste_func
9 | from .utils import QuantizerNotInitializedError
10 | from .base_quantizers import QuantizerBase
11 |
12 |
13 | class AsymmetricUniformQuantizer(QuantizerBase):
14 | """
15 | PyTorch Module that implements Asymmetric Uniform Quantization using STE.
16 | Quantizes its argument in the forward pass, passes the gradient 'straight
17 | through' on the backward pass, ignoring the quantization that occurred.
18 |
19 | Parameters
20 | ----------
21 | n_bits: int
22 | Number of bits for quantization.
23 |     scale_domain: str ('log', 'linear') with default='linear'
24 | Domain of scale factor
25 | per_channel: bool
26 | If True: allows for per-channel quantization
27 | """
28 |
29 | def __init__(
30 | self,
31 | n_bits,
32 | scale_domain="linear",
33 | discretizer=round_ste_func,
34 | discretizer_args=tuple(),
35 | grad_scaling=False,
36 | eps=1e-8,
37 | **kwargs
38 | ):
39 | super().__init__(n_bits=n_bits, **kwargs)
40 |
41 | assert scale_domain in ("linear", "log")
42 | self.register_buffer("_delta", None)
43 | self.register_buffer("_zero_float", None)
44 |
45 | if inspect.isclass(discretizer):
46 | self.discretizer = discretizer(*discretizer_args)
47 | else:
48 | self.discretizer = discretizer
49 |
50 | self.scale_domain = scale_domain
51 | self.grad_scaling = grad_scaling
52 | self.eps = eps
53 |
54 | # A few useful properties
55 | @property
56 | def delta(self):
57 | if self._delta is not None:
58 | return self._delta
59 | else:
60 | raise QuantizerNotInitializedError()
61 |
62 | @property
63 | def zero_float(self):
64 | if self._zero_float is not None:
65 | return self._zero_float
66 | else:
67 | raise QuantizerNotInitializedError()
68 |
69 | @property
70 | def is_initialized(self):
71 | return self._delta is not None
72 |
73 | @property
74 | def symmetric(self):
75 | return False
76 |
77 | @property
78 | def int_min(self):
79 | # integer grid minimum
80 | return 0.0
81 |
82 | @property
83 | def int_max(self):
84 | # integer grid maximum
85 | return 2.0**self.n_bits - 1
86 |
87 | @property
88 | def scale(self):
89 | if self.scale_domain == "linear":
90 | return torch.clamp(self.delta, min=self.eps)
91 | elif self.scale_domain == "log":
92 | return torch.exp(self.delta)
93 |
94 | @property
95 | def zero_point(self):
96 | zero_point = self.discretizer(self.zero_float)
97 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max)
98 | return zero_point
99 |
100 | @property
101 | def x_max(self):
102 | return self.scale * (self.int_max - self.zero_point)
103 |
104 | @property
105 | def x_min(self):
106 | return self.scale * (self.int_min - self.zero_point)
107 |
108 | def to_integer_forward(self, x_float, *args, **kwargs):
109 | """
110 |         Quantizes the input to its integer representation
111 | Parameters
112 | ----------
113 | x_float: PyTorch Float Tensor
114 | Full-precision Tensor
115 |
116 | Returns
117 | -------
118 | x_int: PyTorch Float Tensor of integers
119 | """
120 | if self.grad_scaling:
121 | grad_scale = self.calculate_grad_scale(x_float)
122 | scale = scale_grad_func(self.scale, grad_scale)
123 | zero_point = (
124 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale)
125 | )
126 | else:
127 | scale = self.scale
128 | zero_point = self.zero_point
129 |
130 | x_int = self.discretizer(x_float / scale) + zero_point
131 | x_int = torch.clamp(x_int, self.int_min, self.int_max)
132 |
133 | return x_int
134 |
135 | def forward(self, x_float, *args, **kwargs):
136 | """
137 |         Quantizes the input (to integer, then scales back to the original domain)
138 | Parameters
139 | ----------
140 | x_float: PyTorch Float Tensor
141 | Full-precision Tensor
142 |
143 | Returns
144 | -------
145 | x_quant: PyTorch Float Tensor
146 | Quantized-Dequantized Tensor
147 | """
148 | if self.per_channel:
149 | self._adjust_params_per_channel(x_float)
150 |
151 | if self.grad_scaling:
152 | grad_scale = self.calculate_grad_scale(x_float)
153 | scale = scale_grad_func(self.scale, grad_scale)
154 | zero_point = (
155 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale)
156 | )
157 | else:
158 | scale = self.scale
159 | zero_point = self.zero_point
160 |
161 | x_int = self.to_integer_forward(x_float, *args, **kwargs)
162 | x_quant = scale * (x_int - zero_point)
163 |
164 | return x_quant
165 |
166 | def calculate_grad_scale(self, quant_tensor):
167 | num_pos_levels = self.int_max # Qp in LSQ paper
168 | num_elements = quant_tensor.numel() # nfeatures or nweights in LSQ paper
169 | if self.per_channel:
170 |             # Per-channel case: do not sum gradients over the output channel dimension
171 | num_elements /= quant_tensor.shape[0]
172 |
173 | return (num_pos_levels * num_elements) ** -0.5 # 1 / sqrt (Qn * nfeatures)
174 |
175 | def _adjust_params_per_channel(self, x):
176 | """
177 | Adjusts the quantization parameter tensors (delta, zero_float)
178 | to the input tensor shape if they don't match
179 | Parameters
180 | ----------
181 | x: input tensor
182 | """
183 | if x.ndim != self.delta.ndim:
184 | new_shape = [-1] + [1] * (len(x.shape) - 1)
185 | self._delta = self.delta.view(new_shape)
186 | if self._zero_float is not None:
187 | self._zero_float = self._zero_float.view(new_shape)
188 |
189 | def _tensorize_min_max(self, x_min, x_max):
190 | """
191 | Converts provided min max range into tensors
192 | Parameters
193 | ----------
194 | x_min: float or PyTorch 1D tensor
195 | x_max: float or PyTorch 1D tensor
196 |
197 | Returns
198 | -------
199 | x_min: PyTorch Tensor 0 or 1-D
200 | x_max: PyTorch Tensor 0 or 1-D
201 | """
202 | # Ensure a torch tensor
203 | if not torch.is_tensor(x_min):
204 | x_min = torch.tensor(x_min).float()
205 | x_max = torch.tensor(x_max).float()
206 |
207 |         if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel:
208 |             raise ValueError(
209 |                 "x_min and x_max must be a float or 1-D Tensor"
210 |                 " for per-tensor quantization (per_channel=False),"
211 |                 " got x_min={}".format(x_min)
212 |             )
213 |
214 | # Ensure we always use zero and avoid division by zero
215 | x_min = torch.min(x_min, torch.zeros_like(x_min))
216 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps)
217 |
218 | return x_min, x_max
219 |
220 | def set_quant_range(self, x_min, x_max):
221 | """
222 | Instantiates the quantization parameters based on the provided
223 | min and max range
224 |
225 | Parameters
226 | ----------
227 | x_min: tensor or float
228 | Quantization range minimum limit
229 |         x_max: tensor or float
230 |             Quantization range maximum limit
231 | """
232 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max
233 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
234 | self._delta = (x_max - x_min) / self.int_max
235 | self._zero_float = (-x_min / self.delta).detach()
236 |
237 | if self.scale_domain == "log":
238 | self._delta = torch.log(self.delta)
239 |
240 | self._delta = self._delta.detach()
241 |
242 | def make_range_trainable(self):
243 | # Converts trainable parameters to nn.Parameters
244 | if self.delta not in self.parameters():
245 | self._delta = torch.nn.Parameter(self._delta)
246 | self._zero_float = torch.nn.Parameter(self._zero_float)
247 |
248 | def fix_ranges(self):
249 | # Removes trainable quantization params from nn.Parameters
250 | if self.delta in self.parameters():
251 | _delta = self._delta.data
252 | _zero_float = self._zero_float.data
253 | del self._delta # delete the parameter
254 | del self._zero_float
255 | self.register_buffer("_delta", _delta)
256 | self.register_buffer("_zero_float", _zero_float)
257 |
258 |
259 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer):
260 | """
261 | PyTorch Module that implements Symmetric Uniform Quantization using STE.
262 | Quantizes its argument in the forward pass, passes the gradient 'straight
263 | through' on the backward pass, ignoring the quantization that occurred.
264 |
265 | Parameters
266 | ----------
267 | n_bits: int
268 | Number of bits for quantization.
269 |     scale_domain: str ('log', 'linear') with default='linear'
270 | Domain of scale factor
271 | per_channel: bool
272 | If True: allows for per-channel quantization
273 | """
274 |
275 | def __init__(self, *args, **kwargs):
276 | super().__init__(*args, **kwargs)
277 | self.register_buffer("_signed", None)
278 |
279 | @property
280 | def signed(self):
281 | if self._signed is not None:
282 | return self._signed.item()
283 | else:
284 | raise QuantizerNotInitializedError()
285 |
286 | @property
287 | def symmetric(self):
288 | return True
289 |
290 | @property
291 | def int_min(self):
292 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0
293 |
294 | @property
295 | def int_max(self):
296 | pos_n_bits = self.n_bits - self.signed
297 | return 2.0**pos_n_bits - 1
298 |
299 | @property
300 | def zero_point(self):
301 | return 0.0
302 |
303 | def set_quant_range(self, x_min, x_max):
304 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max
305 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
306 | self._signed = x_min.min() < 0
307 |
308 | x_absmax = torch.max(x_min.abs(), x_max)
309 | self._delta = x_absmax / self.int_max
310 |
311 | if self.scale_domain == "log":
312 | self._delta = torch.log(self._delta)
313 |
314 | self._delta = self._delta.detach()
315 |
316 | def make_range_trainable(self):
317 | # Converts trainable parameters to nn.Parameters
318 | if self.delta not in self.parameters():
319 | self._delta = torch.nn.Parameter(self._delta)
320 |
321 | def fix_ranges(self):
322 | # Removes trainable quantization params from nn.Parameters
323 | if self.delta in self.parameters():
324 | _delta = self._delta.data
325 | del self._delta # delete the parameter
326 | self.register_buffer("_delta", _delta)
327 |
--------------------------------------------------------------------------------
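
A minimal sketch of the quantizer on its own (no range estimator involved): set the range explicitly, then fake-quantize a tensor.

```
import torch
from quantization.quantizers import AsymmetricUniformQuantizer

q = AsymmetricUniformQuantizer(n_bits=4)
x = torch.randn(1000)
q.set_quant_range(x.min(), x.max())   # 2**4 - 1 = 15 steps between the (zero-clamped) min and max
x_q = q(x)                            # quantize to the integer grid, then dequantize
print(q.delta, q.zero_point, (x - x_q).abs().max())
```
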
/quantization/quantizers/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 |
6 | class QuantizerNotInitializedError(Exception):
7 | """Raised when a quantizer has not been initialized"""
8 |
9 | def __init__(self):
10 | super(QuantizerNotInitializedError, self).__init__(
11 | "Quantizer has not been initialized yet"
12 | )
13 |
--------------------------------------------------------------------------------
/quantization/range_estimators.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import copy
5 | from enum import auto
6 |
7 | import numpy as np
8 | import torch
9 | from scipy.optimize import minimize_scalar
10 | from torch import nn
11 |
12 | from utils import to_numpy, BaseEnumOptions, MethodMap, ClassEnumOptions
13 |
14 |
15 | class RangeEstimatorBase(nn.Module):
16 | def __init__(self, per_channel=False, quantizer=None, *args, **kwargs):
17 | super().__init__(*args, **kwargs)
18 | self.register_buffer("current_xmin", None)
19 | self.register_buffer("current_xmax", None)
20 | self.per_channel = per_channel
21 | self.quantizer = quantizer
22 |
23 | def forward(self, x):
24 | """
25 | Accepts an input tensor, updates the current estimates of x_min and x_max
26 | and returns them.
27 | Parameters
28 | ----------
29 | x: Input tensor
30 |
31 | Returns
32 | -------
33 | self.current_xmin: tensor
34 |
35 | self.current_xmax: tensor
36 |
37 | """
38 | raise NotImplementedError()
39 |
40 | def reset(self):
41 | """
42 | Reset the range estimator.
43 | """
44 | self.current_xmin = None
45 | self.current_xmax = None
46 |
47 | def __repr__(self):
48 | # We overwrite this from nn.Module as we do not want to have submodules such as
49 |         # self.quantizer in the repr. Otherwise it behaves as expected for an nn.Module.
50 | lines = self.extra_repr().split("\n")
51 | extra_str = lines[0] if len(lines) == 1 else "\n " + "\n ".join(lines) + "\n"
52 |
53 | return self._get_name() + "(" + extra_str + ")"
54 |
55 |
56 | class CurrentMinMaxEstimator(RangeEstimatorBase):
57 | def __init__(self, percentile=None, *args, **kwargs):
58 | self.percentile = percentile
59 | super().__init__(*args, **kwargs)
60 |
61 | def forward(self, x):
62 | if self.per_channel:
63 | x = x.view(x.shape[0], -1)
64 | if self.percentile:
65 | axis = -1 if self.per_channel else None
66 | data_np = to_numpy(x)
67 | x_min, x_max = np.percentile(
68 | data_np, (self.percentile, 100 - self.percentile), axis=axis
69 | )
70 | self.current_xmin = torch.tensor(x_min).to(x.device)
71 | self.current_xmax = torch.tensor(x_max).to(x.device)
72 | else:
73 | self.current_xmin = x.min(-1)[0].detach() if self.per_channel else x.min().detach()
74 | self.current_xmax = x.max(-1)[0].detach() if self.per_channel else x.max().detach()
75 |
76 | return self.current_xmin, self.current_xmax
77 |
78 |
79 | class RunningMinMaxEstimator(RangeEstimatorBase):
80 | def __init__(self, momentum=0.9, *args, **kwargs):
81 | self.momentum = momentum
82 | super().__init__(*args, **kwargs)
83 |
84 | def forward(self, x):
85 | if self.per_channel:
86 | # Along 1st dim
87 | x_flattened = x.view(x.shape[0], -1)
88 | x_min = x_flattened.min(-1)[0].detach()
89 | x_max = x_flattened.max(-1)[0].detach()
90 | else:
91 | x_min = torch.min(x).detach()
92 | x_max = torch.max(x).detach()
93 |
94 | if self.current_xmin is None:
95 | self.current_xmin = x_min
96 | self.current_xmax = x_max
97 | else:
98 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin
99 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax
100 |
101 | return self.current_xmin, self.current_xmax
102 |
103 |
104 | class OptMethod(BaseEnumOptions):
105 | grid = auto()
106 | golden_section = auto()
107 |
108 |
109 | class MSE_Estimator(RangeEstimatorBase):
110 | def __init__(
111 | self, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, *args, **kwargs
112 | ):
113 |
114 | super().__init__(*args, **kwargs)
115 | assert opt_method in OptMethod
116 | self.opt_method = opt_method
117 | self.num_candidates = num_candidates
118 | self.loss_array = None
119 | self.max_pos_thr = None
120 | self.max_neg_thr = None
121 | self.max_search_range = None
122 | self.one_sided_dist = None
123 | self.range_margin = range_margin
124 | if self.quantizer is None:
125 | raise NotImplementedError(
126 | "A Quantizer must be given as an argument to the MSE Range Estimator"
127 | )
128 | self.max_int_skew = (2**self.quantizer.n_bits) // 4 # For asymmetric quantization
129 |
130 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False):
131 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr)
132 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1)
133 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel
134 | # grid search
135 | if per_channel_loss:
136 | return to_numpy(temp_sum)
137 | else:
138 | return to_numpy(torch.sum(temp_sum))
139 |
140 | @property
141 | def step_size(self):
142 | if self.one_sided_dist is None:
143 | raise NoDataPassedError()
144 |
145 | return self.max_search_range / self.num_candidates
146 |
147 | @property
148 | def optimization_method(self):
149 | if self.one_sided_dist is None:
150 | raise NoDataPassedError()
151 |
152 | if self.opt_method == OptMethod.grid:
153 | # Grid search method
154 | if self.one_sided_dist or self.quantizer.symmetric:
155 | # 1-D grid search
156 | return self._perform_1D_search
157 | else:
158 | # 2-D grid_search
159 | return self._perform_2D_search
160 | elif self.opt_method == OptMethod.golden_section:
161 | # Golden section method
162 | if self.one_sided_dist or self.quantizer.symmetric:
163 | return self._golden_section_symmetric
164 | else:
165 | return self._golden_section_asymmetric
166 | else:
167 | raise NotImplementedError("Optimization Method not Implemented")
168 |
169 | def quantize(self, x_float, x_min=None, x_max=None):
170 | temp_q = copy.deepcopy(self.quantizer)
171 | # In the current implementation no optimization procedure requires temp quantizer for
172 | # loss_fx to be per-channel
173 | temp_q.per_channel = False
174 | if x_min or x_max:
175 | temp_q.set_quant_range(x_min, x_max)
176 | return temp_q(x_float)
177 |
178 | def golden_sym_loss(self, range, data):
179 | """
180 | Loss function passed to the golden section optimizer from scipy in case of symmetric
181 | quantization
182 | """
183 | neg_thr = 0 if self.one_sided_dist else -range
184 | pos_thr = range
185 | return self.loss_fx(data, neg_thr, pos_thr)
186 |
187 | def golden_asym_shift_loss(self, shift, range, data):
188 | """
189 | Inner Loss function (shift) passed to the golden section optimizer from scipy
190 | in case of asymmetric quantization
191 | """
192 | pos_thr = range + shift
193 | neg_thr = -range + shift
194 | return self.loss_fx(data, neg_thr, pos_thr)
195 |
196 | def golden_asym_range_loss(self, range, data):
197 | """
198 | Outer Loss function (range) passed to the golden section optimizer from scipy in case of
199 | asymmetric quantization
200 | """
201 | temp_delta = 2 * range / (2**self.quantizer.n_bits - 1)
202 | max_shift = temp_delta * self.max_int_skew
203 | result = minimize_scalar(
204 | self.golden_asym_shift_loss,
205 | args=(range, data),
206 | bounds=(-max_shift, max_shift),
207 | method="Bounded",
208 | )
209 | return result.fun
210 |
211 | def _define_search_range(self, data):
212 | self.channel_groups = len(data) if self.per_channel else 1
213 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device)
214 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device)
215 |
216 | if self.one_sided_dist or self.quantizer.symmetric:
217 | # 1D search space
218 | self.loss_array = np.zeros(
219 | (self.channel_groups, self.num_candidates + 1)
220 | ) # 1D search space
221 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish
222 | # Defining the search range for clipping thresholds
223 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin
224 | self.max_neg_thr = -self.max_pos_thr
225 | self.max_search_range = self.max_pos_thr
226 | else:
227 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth
228 | # index represents whether the skew is positive (0) or negative (1))
229 | self.loss_array = np.zeros(
230 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2]
231 | ) # 2D search space
232 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish
233 | # Define the search range for clipping thresholds in asymmetric case
234 | self.max_pos_thr = float(data.max()) + self.range_margin
235 | self.max_neg_thr = float(data.min()) - self.range_margin
236 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr))
237 |
238 | def _perform_1D_search(self, data):
239 | """
240 | Grid search through all candidate quantizers in 1D to find the best
241 | The loss is accumulated over all batches without any momentum
242 | :param data: input tensor
243 | """
244 | for cand_index in range(1, self.num_candidates + 1):
245 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index
246 | pos_thr = self.step_size * cand_index
247 |
248 | self.loss_array[:, cand_index] += self.loss_fx(
249 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
250 | )
251 | # find the best clipping thresholds
252 | min_cand = self.loss_array.argmin(axis=1)
253 | xmin = (
254 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand
255 | ).astype(np.single)
256 | xmax = (self.step_size * min_cand).astype(np.single)
257 | self.current_xmax = torch.tensor(xmax).to(device=data.device)
258 | self.current_xmin = torch.tensor(xmin).to(device=data.device)
259 |
260 | def _perform_2D_search(self, data):
261 | """
262 |         Grid search through all candidate quantizers in 2D to find the best
263 | The loss is accumulated over all batches without any momentum
264 | Parameters
265 | ----------
266 | data: PyTorch Tensor
267 | Returns
268 | -------
269 |
270 | """
271 | for cand_index in range(1, self.num_candidates + 1):
272 | # defining the symmetric quantization range
273 | temp_start = -self.step_size * cand_index
274 | temp_finish = self.step_size * cand_index
275 | temp_delta = float(temp_finish - temp_start) / (2**self.quantizer.n_bits - 1)
276 | for shift in range(self.max_int_skew):
277 | for reverse in range(2):
278 | # introducing asymmetry in the quantization range
279 | skew = ((-1) ** reverse) * shift * temp_delta
280 | neg_thr = max(temp_start + skew, self.max_neg_thr)
281 | pos_thr = min(temp_finish + skew, self.max_pos_thr)
282 |
283 | self.loss_array[:, cand_index, shift, reverse] += self.loss_fx(
284 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
285 | )
286 |
287 | for channel_index in range(self.channel_groups):
288 | min_cand, min_shift, min_reverse = np.unravel_index(
289 | np.argmin(self.loss_array[channel_index], axis=None),
290 | self.loss_array[channel_index].shape,
291 | )
292 | min_interval_start = -self.step_size * min_cand
293 | min_interval_finish = self.step_size * min_cand
294 | min_delta = float(min_interval_finish - min_interval_start) / (
295 | 2**self.quantizer.n_bits - 1
296 | )
297 | min_skew = ((-1) ** min_reverse) * min_shift * min_delta
298 | xmin = max(min_interval_start + min_skew, self.max_neg_thr)
299 | xmax = min(min_interval_finish + min_skew, self.max_pos_thr)
300 |
301 | self.current_xmin[channel_index] = torch.tensor(xmin).to(device=data.device)
302 | self.current_xmax[channel_index] = torch.tensor(xmax).to(device=data.device)
303 |
304 | def _golden_section_symmetric(self, data):
305 | for channel_index in range(self.channel_groups):
306 | if channel_index == 0 and not self.per_channel:
307 | data_segment = data
308 | else:
309 | data_segment = data[channel_index]
310 |
311 | self.result = minimize_scalar(
312 | self.golden_sym_loss,
313 | args=data_segment,
314 | bounds=(0.01 * self.max_search_range, self.max_search_range),
315 | method="Bounded",
316 | )
317 | self.current_xmax[channel_index] = torch.tensor(self.result.x).to(device=data.device)
318 | self.current_xmin[channel_index] = (
319 | torch.tensor(0.0).to(device=data.device)
320 | if self.one_sided_dist
321 | else -self.current_xmax[channel_index]
322 | )
323 |
324 | def _golden_section_asymmetric(self, data):
325 | for channel_index in range(self.channel_groups):
326 | if channel_index == 0 and not self.per_channel:
327 | data_segment = data
328 | else:
329 | data_segment = data[channel_index]
330 |
331 | self.result = minimize_scalar(
332 | self.golden_asym_range_loss,
333 | args=data_segment,
334 | bounds=(0.01 * self.max_search_range, self.max_search_range),
335 | method="Bounded",
336 | )
337 | self.final_range = self.result.x
338 | temp_delta = 2 * self.final_range / (2**self.quantizer.n_bits - 1)
339 | max_shift = temp_delta * self.max_int_skew
340 | self.subresult = minimize_scalar(
341 | self.golden_asym_shift_loss,
342 | args=(self.final_range, data_segment),
343 | bounds=(-max_shift, max_shift),
344 | method="Bounded",
345 | )
346 | self.final_shift = self.subresult.x
347 | self.current_xmax[channel_index] = torch.tensor(self.final_range + self.final_shift).to(
348 | device=data.device
349 | )
350 | self.current_xmin[channel_index] = torch.tensor(
351 | -self.final_range + self.final_shift
352 | ).to(device=data.device)
353 |
354 | def forward(self, data):
355 | if self.loss_array is None:
356 | # Initialize search range on first batch, and accumulate losses with subsequent calls
357 |
358 | # Decide whether input distribution is one-sided
359 | if self.one_sided_dist is None:
360 | self.one_sided_dist = bool((data.min() >= 0).item())
361 |
362 | # Define search
363 | self._define_search_range(data)
364 |
365 | # Perform Search/Optimization for Quantization Ranges
366 | self.optimization_method(data)
367 |
368 | return self.current_xmin, self.current_xmax
369 |
370 | def reset(self):
371 | super().reset()
372 | self.loss_array = None
373 |
374 | def extra_repr(self):
375 | repr = "opt_method={}".format(self.opt_method.name)
376 | if self.opt_method == OptMethod.grid:
377 |             repr += ", num_candidates={}".format(self.num_candidates)
378 | return repr
379 |
380 |
381 | class NoDataPassedError(Exception):
382 | """Raised data has been passed into the Range Estimator"""
383 |
384 | def __init__(self):
385 | super().__init__("Data must be pass through the range estimator to be initialized")
386 |
387 |
388 | class RangeEstimators(ClassEnumOptions):
389 | current_minmax = MethodMap(CurrentMinMaxEstimator)
390 | running_minmax = MethodMap(RunningMinMaxEstimator)
391 | MSE = MethodMap(MSE_Estimator)
392 |
--------------------------------------------------------------------------------
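
The estimators can also be exercised directly; the sketch below (defaults only, not repository code) contrasts the running min-max EMA with the MSE-based grid search, which needs the target quantizer in order to score candidate ranges.

```
import torch
from quantization.quantizers import QMethods
from quantization.range_estimators import RangeEstimators

# Running min-max: exponential moving average over batches
running = RangeEstimators.running_minmax.cls(momentum=0.9)
for _ in range(3):
    x_min, x_max = running(torch.randn(4, 32, 8, 8))

# MSE search: scores candidate clipping thresholds with the quantizer itself
quantizer = QMethods.symmetric_uniform.cls(n_bits=4)
mse = RangeEstimators.MSE.cls(quantizer=quantizer, num_candidates=50)
x_min, x_max = mse(torch.randn(4, 32, 8, 8))
```
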
/quantization/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 |
6 | import torch
7 | import torch.serialization
8 |
9 | from quantization.quantizers import QuantizerBase
10 | from quantization.quantizers.rounding_utils import ParametrizedGradEstimatorBase
11 | from quantization.range_estimators import RangeEstimators
12 | from utils import StopForwardException, get_layer_by_name
13 |
14 |
15 | def separate_quantized_model_params(quant_model):
16 | """
17 |     This method separates the parameters of the quantized model into three categories.
18 | Parameters
19 | ----------
20 | quant_model: (QuantizedModel)
21 |
22 | Returns
23 | -------
24 | quant_params: (list)
25 | Quantization parameters, e.g. delta and zero_float
26 | model_params: (list)
27 | The model parameters of the base model without any quantization operations
28 | grad_params: (list)
29 | Parameters found in the gradient estimators (ParametrizedGradEstimatorBase)
30 | -------
31 |
32 | """
33 | quant_params, grad_params = [], []
34 | quant_params_names, grad_params_names = [], []
35 | for mod_name, module in quant_model.named_modules():
36 | if isinstance(module, QuantizerBase):
37 | for name, param in module.named_parameters(recurse=False):
38 | quant_params.append(param)
39 | quant_params_names.append(".".join((mod_name, name)))
40 | if isinstance(module, ParametrizedGradEstimatorBase):
41 | # gradient estimator params
42 | for name, param in module.named_parameters(recurse=False):
43 | grad_params.append(param)
44 | grad_params_names.append(".".join((mod_name, name)))
45 |
46 | def tensor_in_list(tensor, lst):
47 | return any([e is tensor for e in lst])
48 |
49 | found_params = quant_params + grad_params
50 |
51 | model_params = [p for p in quant_model.parameters() if not tensor_in_list(p, found_params)]
52 | model_param_names = [
53 | n for n, p in quant_model.named_parameters() if not tensor_in_list(p, found_params)
54 | ]
55 |
56 | print("Quantization parameters ({}):".format(len(quant_params_names)))
57 | print(quant_params_names)
58 |
59 | print("Gradient estimator parameters ({}):".format(len(grad_params_names)))
60 | print(grad_params_names)
61 |
62 | print("Other model parameters ({}):".format(len(model_param_names)))
63 | print(model_param_names)
64 |
65 | assert len(model_params + quant_params + grad_params) == len(
66 | list(quant_model.parameters())
67 | ), "{}; {}; {} -- {}".format(
68 | len(model_params), len(quant_params), len(grad_params), len(list(quant_model.parameters()))
69 | )
70 |
71 | return quant_params, model_params, grad_params
72 |
73 |
74 | def pass_data_for_range_estimation(
75 | loader, model, act_quant, weight_quant, max_num_batches=20, cross_entropy_layer=None, inp_idx=0
76 | ):
77 | print("\nEstimate quantization ranges on training data")
78 | model.set_quant_state(weight_quant, act_quant)
79 | # Put model in eval such that BN EMA does not get updated
80 | model.eval()
81 |
82 | if cross_entropy_layer is not None:
83 | layer_xent = get_layer_by_name(model, cross_entropy_layer)
84 | if layer_xent:
85 | print('Set cross entropy estimator for layer "{}"'.format(cross_entropy_layer))
86 | act_quant_mgr = layer_xent.activation_quantizer
87 | act_quant_mgr.range_estimator = RangeEstimators.cross_entropy.cls(
88 | per_channel=act_quant_mgr.per_channel,
89 | quantizer=act_quant_mgr.quantizer,
90 | **act_quant_mgr.range_estim_params,
91 | )
92 | else:
93 | raise ValueError("Cross-entropy layer not found")
94 |
95 | batches = []
96 | device = next(model.parameters()).device
97 |
98 | with torch.no_grad():
99 | for i, data in enumerate(loader):
100 | try:
101 | if isinstance(data, (tuple, list)):
102 | x = data[inp_idx].to(device=device)
103 | batches.append(x.data.cpu().numpy())
104 | model(x)
105 |                     print(f"processed step={i}")
106 | else:
107 | x = {k: v.to(device=device) for k, v in data.items()}
108 | model(**x)
109 |                     print(f"processed step={i}")
110 |
111 | if i >= max_num_batches - 1 or not act_quant:
112 | break
113 | except StopForwardException:
114 | pass
115 | return batches
116 |
117 |
118 | def set_range_estimators(config, model):
119 | print("Make quantizers learnable")
120 | model.learn_ranges()
121 |
122 | if config.qat.grad_scaling:
123 | print("Activate gradient scaling")
124 | model.grad_scaling(True)
125 |
126 | # Ensure we have the desired quant state
127 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant)
128 |
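129 |
130 | # Example sketch (appended by the editor, not part of the original module): a plausible call
131 | # order when preparing a quantized model with a separate quantizer optimizer
132 | # (--sep-quant-optimizer). `quant_model`, `train_loader` and `config` are assumed to exist
133 | # already; the learning rates are illustrative only.
134 | #
135 | #   pass_data_for_range_estimation(train_loader, quant_model, act_quant=True, weight_quant=True,
136 | #                                  max_num_batches=config.quant.num_est_batches)
137 | #   set_range_estimators(config, quant_model)
138 | #   quant_params, model_params, grad_params = separate_quantized_model_params(quant_model)
139 | #   model_optimizer = torch.optim.SGD(model_params, lr=0.01, momentum=0.9, weight_decay=1e-5)
140 | #   quant_optimizer = torch.optim.SGD(quant_params + grad_params, lr=1e-3, momentum=0.0)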
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | click>=7.0
2 | pytorch-ignite~=0.4.9
3 | tensorboard>=2.5
4 | scipy==1.3.1
5 | numpy==1.19.5
6 | pillow==6.2.1
7 | timm~=0.4.12
8 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from .stopwatch import Stopwatch
6 | from .utils import *
7 |
--------------------------------------------------------------------------------
/utils/click_options.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import click
6 |
7 | from models import QuantArchitectures
8 | from functools import wraps, partial
9 | from quantization.quantizers import QMethods
10 | from quantization.range_estimators import RangeEstimators, OptMethod
11 | from utils import split_dict, DotDict, ClickEnumOption, seed_all
12 | from utils.imagenet_dataloaders import ImageInterpolation
13 |
14 | click.option = partial(click.option, show_default=True)
15 |
16 | _HELP_MSG = (
17 | "Enforce determinism also on the GPU by disabling CUDNN and setting "
18 | "`torch.set_deterministic(True)`. In many cases this comes at the cost of efficiency "
19 | "and performance."
20 | )
21 |
22 |
23 | def base_options(func):
24 | @click.option(
25 | "--images-dir", type=click.Path(exists=True), help="Root directory of images", required=True
26 | )
27 | @click.option("--max-epochs", default=90, type=int, help="Maximum number of training epochs.")
28 | @click.option(
29 | "--interpolation",
30 | type=ClickEnumOption(ImageInterpolation),
31 | default=ImageInterpolation.bilinear.name,
32 | help="Desired interpolation to use for resizing.",
33 | )
34 | @click.option(
35 | "--save-checkpoint-dir",
36 | type=click.Path(exists=False),
37 | default=None,
38 | help="Directory where to save checkpoints (model, optimizer, lr_scheduler).",
39 | )
40 | @click.option(
41 | "--tb-logging-dir", default=None, type=str, help="The logging directory " "for tensorboard"
42 | )
43 | @click.option("--cuda/--no-cuda", is_flag=True, default=True, help="Use GPU")
44 | @click.option("--batch-size", default=128, type=int, help="Mini-batch size")
45 | @click.option("--num-workers", default=16, type=int, help="Number of workers for data loading")
46 | @click.option("--seed", default=None, type=int, help="Random number generator seed to set")
47 | @click.option("--deterministic/--nondeterministic", default=False, help=_HELP_MSG)
48 | # Architecture related options
49 | @click.option(
50 | "--architecture",
51 | type=ClickEnumOption(QuantArchitectures),
52 | required=True,
53 | help="Quantized architecture",
54 | )
55 | @click.option(
56 | "--model-dir",
57 | type=click.Path(exists=True),
58 | default=None,
59 |         help="Path for model directory. If the model does not exist it will be downloaded "
60 | "from a URL",
61 | )
62 | @click.option(
63 | "--pretrained/--no-pretrained",
64 | is_flag=True,
65 | default=True,
66 | help="Use pretrained model weights",
67 | )
68 | @click.option(
69 | "--progress-bar/--no-progress-bar", is_flag=True, default=False, help="Show progress bar"
70 | )
71 | @wraps(func)
72 | def func_wrapper(config, *args, **kwargs):
73 | config.base, remaining_kwargs = split_dict(
74 | kwargs,
75 | [
76 | "images_dir",
77 | "max_epochs",
78 | "interpolation",
79 | "save_checkpoint_dir",
80 | "tb_logging_dir",
81 | "cuda",
82 | "batch_size",
83 | "num_workers",
84 | "seed",
85 | "model_dir",
86 | "architecture",
87 | "pretrained",
88 | "deterministic",
89 | "progress_bar",
90 | ],
91 | )
92 |
93 | seed, deterministic = config.base.seed, config.base.deterministic
94 |
95 | if seed is None:
96 | if deterministic is True:
97 | raise ValueError("Enforcing determinism without providing a seed is not supported")
98 | else:
99 | seed_all(seed=seed, deterministic=deterministic)
100 |
101 | return func(config, *args, **remaining_kwargs)
102 |
103 | return func_wrapper
104 |
105 |
106 | class multi_optimizer_options:
107 | """
108 | An instance of this class is a callable object to serve as a decorator;
109 | hence the lower case class name.
110 |
111 | Among the CLI options defined in the decorator, `--{prefix-}optimizer-type`
112 | requires special attention. Default value for that variable for
113 | {prefix-}optimizer is the value in use by the main optimizer.
114 |
115 | Examples:
116 | @multi_optimizer_options('quant')
117 | @pass_config
118 | def command(config):
119 | ...
120 | """
121 |
122 | def __init__(self, prefix: str = ""):
123 | self.optimizer_name = prefix + "_optimizer" if prefix else "optimizer"
124 | self.prefix_option = prefix + "-" if prefix else ""
125 | self.prefix_attribute = prefix + "_" if prefix else ""
126 |
127 | def __call__(self, func):
128 | prefix_option = self.prefix_option
129 | prefix_attribute = self.prefix_attribute
130 |
131 | @click.option(
132 | f"--{prefix_option}optimizer",
133 | default="SGD",
134 | type=click.Choice(["SGD", "Adam"], case_sensitive=False),
135 |             help="Class name of torch Optimizer to be used.",
136 | )
137 | @click.option(
138 | f"--{prefix_option}learning-rate",
139 | default=None,
140 | type=float,
141 | help="Initial learning rate.",
142 | )
143 | @click.option(
144 | f"--{prefix_option}momentum", default=0.9, type=float, help=f"Optimizer momentum."
145 | )
146 | @click.option(
147 | f"--{prefix_option}weight-decay",
148 | default=None,
149 | type=float,
150 | help="Weight decay for the network.",
151 | )
152 | @click.option(
153 | f"--{prefix_option}learning-rate-schedule",
154 | default=None,
155 | type=str,
156 |             help="Learning rate scheduler, e.g. 'multistep:10:20:40' for MultiStepLR or "
157 | "'cosine:1e-4' for cosine decay",
158 | )
159 | @wraps(func)
160 | def func_wrapper(config, *args, **kwargs):
161 | base_arg_names = [
162 | "optimizer",
163 | "learning_rate",
164 | "momentum",
165 | "weight_decay",
166 | "learning_rate_schedule",
167 | ]
168 |
169 | optimizer_opt = DotDict()
170 |
171 | # Collect basic arguments
172 | for arg in base_arg_names:
173 | option_name = prefix_attribute + arg
174 | optimizer_opt[arg] = kwargs.pop(option_name)
175 |
176 | # config.{prefix_attribute}optimizer = optimizer_opt
177 | setattr(config, prefix_attribute + "optimizer", optimizer_opt)
178 |
179 | return func(config, *args, **kwargs)
180 |
181 | return func_wrapper
182 |
183 |
184 | def qat_options(func):
185 | @click.option(
186 | "--reestimate-bn-stats/--no-reestimate-bn-stats",
187 | is_flag=True,
188 | default=True,
189 | help="Reestimates the BN stats before every evaluation.",
190 | )
191 | @click.option(
192 | "--grad-scaling/--no-grad-scaling",
193 | is_flag=True,
194 | default=False,
195 | help="Do gradient scaling as in LSQ paper.",
196 | )
197 | @click.option(
198 | "--sep-quant-optimizer/--no-sep-quant-optimizer",
199 | is_flag=True,
200 | default=False,
201 | help="Use a separate optimizer for the quantizers.",
202 | )
203 | @multi_optimizer_options("quant")
204 | @oscillations_dampen_options
205 | @oscillations_freeze_options
206 | @wraps(func)
207 | def func_wrapper(config, *args, **kwargs):
208 | config.qat, remainder_kwargs = split_dict(
209 | kwargs, ["reestimate_bn_stats", "grad_scaling", "sep_quant_optimizer"]
210 | )
211 | return func(config, *args, **remainder_kwargs)
212 |
213 | return func_wrapper
214 |
215 |
216 | def oscillations_dampen_options(func):
217 | @click.option(
218 | "--oscillations-dampen-weight",
219 | default=None,
220 | type=float,
221 |         help="If given, adds an oscillation dampening term to the loss with the given " "weighting.",
222 | )
223 | @click.option(
224 | "--oscillations-dampen-aggregation",
225 | type=click.Choice(["sum", "mean", "kernel_mean"]),
226 | default="kernel_mean",
227 | help="Aggregation type for bin regularization loss.",
228 | )
229 | @click.option(
230 | "--oscillations-dampen-weight-final",
231 | type=float,
232 | default=None,
233 | help="Dampening regularization final value for annealing schedule.",
234 | )
235 | @click.option(
236 | "--oscillations-dampen-anneal-start",
237 | default=0.25,
238 | type=float,
239 | help="Start of annealing (relative to total number of iterations).",
240 | )
241 | @wraps(func)
242 | def func_wrapper(config, *args, **kwargs):
243 | config.osc_damp, remainder_kwargs = split_dict(
244 | kwargs,
245 | [
246 | "oscillations_dampen_weight",
247 | "oscillations_dampen_aggregation",
248 | "oscillations_dampen_weight_final",
249 | "oscillations_dampen_anneal_start",
250 | ],
251 | "oscillations_dampen",
252 | )
253 |
254 | return func(config, *args, **remainder_kwargs)
255 |
256 | return func_wrapper
257 |
258 |
259 | def oscillations_freeze_options(func):
260 | @click.option(
261 | "--oscillations-freeze-threshold",
262 | default=0.0,
263 | type=float,
264 |         help="If greater than 0, we freeze weights whose oscillation frequency (EMA) is "
265 | "higher than the given threshold. Frequency is defined as 1/period length.",
266 | )
267 | @click.option(
268 | "--oscillations-freeze-ema-momentum",
269 | default=0.001,
270 | type=float,
271 |         help="The momentum to calculate the EMA frequency of the oscillation. In case "
272 | "freezing is used, this should be at least 2-3 times lower than the "
273 | "freeze threshold.",
274 | )
275 | @click.option(
276 | "--oscillations-freeze-use-ema/--no-oscillation-freeze-use-ema",
277 | is_flag=True,
278 | default=True,
279 | help="Uses an EMA of past x_int to find the correct freezing int value.",
280 | )
281 | @click.option(
282 | "--oscillations-freeze-max-bits",
283 | default=4,
284 | type=int,
285 |         help="Max bit-width for oscillation tracking and freezing. If a layer's weights use "
286 |         "more bits we do not track or freeze oscillations.",
287 | )
288 | @click.option(
289 | "--oscillations-freeze-threshold-final",
290 | type=float,
291 | default=None,
292 | help="Oscillation freezing final value for annealing schedule.",
293 | )
294 | @click.option(
295 | "--oscillations-freeze-anneal-start",
296 | default=0.25,
297 | type=float,
298 | help="Start of annealing (relative to total number of iterations).",
299 | )
300 | @wraps(func)
301 | def func_wrapper(config, *args, **kwargs):
302 | config.osc_freeze, remainder_kwargs = split_dict(
303 | kwargs,
304 | [
305 | "oscillations_freeze_threshold",
306 | "oscillations_freeze_ema_momentum",
307 | "oscillations_freeze_use_ema",
308 | "oscillations_freeze_max_bits",
309 | "oscillations_freeze_threshold_final",
310 | "oscillations_freeze_anneal_start",
311 | ],
312 | "oscillations_freeze",
313 | )
314 |
315 | return func(config, *args, **remainder_kwargs)
316 |
317 | return func_wrapper
318 |
319 |
320 | def quantization_options(func):
321 | # Weight quantization options
322 | @click.option(
323 | "--weight-quant/--no-weight-quant",
324 | is_flag=True,
325 | default=True,
326 |         help="Run evaluation with weight quantization or use FP32 weights",
327 | )
328 | @click.option(
329 | "--qmethod",
330 | type=ClickEnumOption(QMethods),
331 | default=QMethods.symmetric_uniform.name,
332 | help="Quantization scheme to use.",
333 | )
334 | @click.option(
335 | "--weight-quant-method",
336 | default=RangeEstimators.current_minmax.name,
337 | type=ClickEnumOption(RangeEstimators),
338 | help="Method to determine weight quantization clipping thresholds.",
339 | )
340 | @click.option(
341 | "--weight-opt-method",
342 | default=OptMethod.grid.name,
343 | type=ClickEnumOption(OptMethod),
344 | help="Optimization procedure for activation quantization clipping thresholds",
345 | )
346 | @click.option(
347 | "--num-candidates",
348 | type=int,
349 | default=None,
350 | help="Number of grid points for grid search in MSE range method.",
351 | )
352 | @click.option("--n-bits", default=8, type=int, help="Default number of quantization bits.")
353 | @click.option(
354 | "--per-channel/--no-per-channel",
355 | is_flag=True,
356 | default=False,
357 | help="If given, quantize each channel separately.",
358 | )
359 | # Activation quantization options
360 | @click.option(
361 | "--act-quant/--no-act-quant",
362 | is_flag=True,
363 | default=True,
364 | help="Run evaluation with activation quantization or use FP32 activations",
365 | )
366 | @click.option(
367 | "--qmethod-act",
368 | type=ClickEnumOption(QMethods),
369 | default=None,
370 |         help="Quantization scheme to use for activations. If not specified, `--qmethod` " "is used.",
371 | )
372 | @click.option(
373 | "--n-bits-act", default=None, type=int, help="Number of quantization bits for activations."
374 | )
375 | @click.option(
376 | "--act-quant-method",
377 | default=RangeEstimators.running_minmax.name,
378 | type=ClickEnumOption(RangeEstimators),
379 | help="Method to determine activation quantization clipping thresholds",
380 | )
381 | @click.option(
382 | "--act-opt-method",
383 | default=OptMethod.grid.name,
384 | type=ClickEnumOption(OptMethod),
385 | help="Optimization procedure for activation quantization clipping thresholds",
386 | )
387 | @click.option(
388 | "--act-num-candidates",
389 | type=int,
390 | default=None,
391 | help="Number of grid points for grid search in MSE/SQNR/Cross-entropy",
392 | )
393 | @click.option(
394 | "--act-momentum",
395 | type=float,
396 | default=None,
397 | help="Exponential averaging factor for running_minmax",
398 | )
399 | @click.option(
400 | "--num-est-batches",
401 | type=int,
402 | default=1,
403 | help="Number of training batches to be used for activation range estimation",
404 | )
405 | # Other options
406 | @click.option(
407 | "--quant-setup",
408 | default="LSQ_paper",
409 | type=click.Choice(["all", "LSQ", "FP_logits", "fc4", "fc4_dw8", "LSQ_paper"]),
410 | help="Method to quantize the network.",
411 | )
412 | @wraps(func)
413 | def func_wrapper(config, *args, **kwargs):
414 | config.quant, remainder_kwargs = split_dict(
415 | kwargs,
416 | [
417 | "qmethod",
418 | "qmethod_act",
419 | "weight_quant_method",
420 | "weight_opt_method",
421 | "num_candidates",
422 | "n_bits",
423 | "n_bits_act",
424 | "per_channel",
425 | "act_quant",
426 | "weight_quant",
427 | "quant_setup",
428 | "num_est_batches",
429 | "act_momentum",
430 | "act_num_candidates",
431 | "act_opt_method",
432 | "act_quant_method",
433 | ],
434 | )
435 |
436 | config.quant.qmethod_act = config.quant.qmethod_act or config.quant.qmethod
437 |
438 | return func(config, *args, **remainder_kwargs)
439 |
440 | return func_wrapper
441 |
442 |
443 | def quant_params_dict(config):
444 | weight_range_options = {}
445 | if config.quant.weight_quant_method == RangeEstimators.MSE:
446 | weight_range_options = dict(opt_method=config.quant.weight_opt_method)
447 | if config.quant.num_candidates is not None:
448 | weight_range_options["num_candidates"] = config.quant.num_candidates
449 |
450 | act_range_options = {}
451 | if config.quant.act_quant_method == RangeEstimators.MSE:
452 | act_range_options = dict(opt_method=config.quant.act_opt_method)
453 | if config.quant.act_num_candidates is not None:
454 |             act_range_options["num_candidates"] = config.quant.act_num_candidates
455 |
456 | qparams = {
457 | "method": config.quant.qmethod.cls,
458 | "n_bits": config.quant.n_bits,
459 | "n_bits_act": config.quant.n_bits_act,
460 | "act_method": config.quant.qmethod_act.cls,
461 | "per_channel_weights": config.quant.per_channel,
462 | "quant_setup": config.quant.quant_setup,
463 | "weight_range_method": config.quant.weight_quant_method.cls,
464 | "weight_range_options": weight_range_options,
465 | "act_range_method": config.quant.act_quant_method.cls,
466 | "act_range_options": act_range_options,
467 |         "quantize_input": config.quant.quant_setup == "LSQ_paper",
468 | }
469 |
470 | return qparams
471 |
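472 |
473 | if __name__ == "__main__":
474 |     # Example sketch (appended by the editor, not part of the original module): building the
475 |     # qparams dictionary by hand, bypassing the click decorators. The values below are
476 |     # illustrative only; normally config.quant is populated by quantization_options.
477 |     config = DotDict()
478 |     config.quant = DotDict(
479 |         qmethod=QMethods.symmetric_uniform,
480 |         qmethod_act=QMethods.symmetric_uniform,
481 |         weight_quant_method=RangeEstimators.current_minmax,
482 |         act_quant_method=RangeEstimators.running_minmax,
483 |         n_bits=4,
484 |         n_bits_act=8,
485 |         per_channel=True,
486 |         quant_setup="LSQ_paper",
487 |     )
488 |     print(quant_params_dict(config))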
--------------------------------------------------------------------------------
/utils/imagenet_dataloaders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import os
6 |
7 | import torchvision
8 | import torch.utils.data as torch_data
9 | from torchvision import transforms
10 | from utils import BaseEnumOptions
11 |
12 |
13 | class ImageInterpolation(BaseEnumOptions):
14 | nearest = transforms.InterpolationMode.NEAREST
15 | box = transforms.InterpolationMode.BOX
16 | bilinear = transforms.InterpolationMode.BILINEAR
17 | hamming = transforms.InterpolationMode.HAMMING
18 | bicubic = transforms.InterpolationMode.BICUBIC
19 | lanczos = transforms.InterpolationMode.LANCZOS
20 |
21 |
22 | class ImageNetDataLoaders(object):
23 | """
24 | Data loader provider for ImageNet images, providing a train and a validation loader.
25 | It assumes that the structure of the images is
26 | images_dir
27 | - train
28 | - label1
29 | - label2
30 | - ...
31 | - val
32 | - label1
33 | - label2
34 | - ...
35 | """
36 |
37 | def __init__(
38 | self,
39 | images_dir: str,
40 | image_size: int,
41 | batch_size: int,
42 | num_workers: int,
43 | interpolation: transforms.InterpolationMode,
44 | ):
45 | """
46 | Parameters
47 | ----------
48 | images_dir: str
49 | Root image directory
50 | image_size: int
51 | Number of pixels the image will be re-sized to (square)
52 | batch_size: int
53 | Batch size of both the training and validation loaders
54 | num_workers
55 | Number of parallel workers loading the images
56 | interpolation: transforms.InterpolationMode
57 | Desired interpolation to use for resizing.
58 | """
59 |
60 | self.images_dir = images_dir
61 | self.batch_size = batch_size
62 | self.num_workers = num_workers
63 |
64 | # For normalization, mean and std dev values are calculated per channel
65 | # and can be found on the web.
66 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
67 |
68 | self.train_transforms = transforms.Compose(
69 | [
70 | transforms.RandomResizedCrop(image_size, interpolation=interpolation.value),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | normalize,
74 | ]
75 | )
76 |
77 | self.val_transforms = transforms.Compose(
78 | [
79 | transforms.Resize(image_size + 24, interpolation=interpolation.value),
80 | transforms.CenterCrop(image_size),
81 | transforms.ToTensor(),
82 | normalize,
83 | ]
84 | )
85 |
86 | self._train_loader = None
87 | self._val_loader = None
88 |
89 | @property
90 | def train_loader(self) -> torch_data.DataLoader:
91 | if not self._train_loader:
92 | root = os.path.join(self.images_dir, "train")
93 | train_set = torchvision.datasets.ImageFolder(root, transform=self.train_transforms)
94 | self._train_loader = torch_data.DataLoader(
95 | train_set,
96 | batch_size=self.batch_size,
97 | shuffle=True,
98 | num_workers=self.num_workers,
99 | pin_memory=True,
100 | )
101 | return self._train_loader
102 |
103 | @property
104 | def val_loader(self) -> torch_data.DataLoader:
105 | if not self._val_loader:
106 | root = os.path.join(self.images_dir, "val")
107 | val_set = torchvision.datasets.ImageFolder(root, transform=self.val_transforms)
108 | self._val_loader = torch_data.DataLoader(
109 | val_set,
110 | batch_size=self.batch_size,
111 | shuffle=False,
112 | num_workers=self.num_workers,
113 | pin_memory=True,
114 | )
115 | return self._val_loader
116 |
--------------------------------------------------------------------------------
/utils/optimizer_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import torch
6 |
7 |
8 | def get_lr_scheduler(optimizer, lr_schedule, epochs):
9 | scheduler = None
10 | if lr_schedule:
11 | if lr_schedule.startswith("multistep"):
12 |             milestones = [int(s) for s in lr_schedule.split(":")[1:]]
13 |             scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones)
14 | elif lr_schedule.startswith("cosine"):
15 | eta_min = float(lr_schedule.split(":")[1])
16 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
17 | optimizer, epochs, eta_min=eta_min
18 | )
19 | return scheduler
20 |
21 |
22 | def optimizer_lr_factory(config_optim, params, epochs):
23 | if config_optim.optimizer.lower() == "sgd":
24 | optimizer = torch.optim.SGD(
25 | params,
26 | lr=config_optim.learning_rate,
27 | momentum=config_optim.momentum,
28 | weight_decay=config_optim.weight_decay,
29 | )
30 | elif config_optim.optimizer.lower() == "adam":
31 | optimizer = torch.optim.Adam(
32 | params, lr=config_optim.learning_rate, weight_decay=config_optim.weight_decay
33 | )
34 | else:
35 | raise ValueError()
36 |
37 | lr_scheduler = get_lr_scheduler(optimizer, config_optim.learning_rate_schedule, epochs)
38 |
39 | return optimizer, lr_scheduler
40 |
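41 |
42 | if __name__ == "__main__":
43 |     # Example sketch (appended by the editor, not part of the original module): building an SGD
44 |     # optimizer with a cosine LR schedule from a minimal config object. The fields mirror what
45 |     # multi_optimizer_options collects; the values are illustrative only.
46 |     from types import SimpleNamespace
47 |
48 |     import torch.nn as nn
49 |
50 |     config_optim = SimpleNamespace(
51 |         optimizer="SGD",
52 |         learning_rate=0.01,
53 |         momentum=0.9,
54 |         weight_decay=1e-4,
55 |         learning_rate_schedule="cosine:1e-6",
56 |     )
57 |     optimizer, lr_scheduler = optimizer_lr_factory(config_optim, nn.Linear(8, 2).parameters(), 90)
58 |     print(optimizer, lr_scheduler)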
--------------------------------------------------------------------------------
/utils/oscillation_tracking_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 | import torch
5 |
6 | from quantization.hijacker import QuantizationHijacker
7 |
8 |
9 | def add_oscillation_trackers(model, max_bits=4, *args, **kwargs):
10 | tracker_dict = {}
11 | # Add oscillation trackers to all weight quantizers
12 | for name, module in model.named_modules():
13 | if isinstance(module, QuantizationHijacker):
14 | q = module.weight_quantizer.quantizer
15 | if q.n_bits > max_bits:
16 | print(
17 | f"Skip tracking/freezing for {name}, too high bit {q.n_bits} (max {max_bits})"
18 | )
19 | continue
20 |             int_fwd_wrapper = TrackOscillation(int_fwd=q.to_integer_forward, *args, **kwargs)
21 | q.to_integer_forward = int_fwd_wrapper
22 | tracker_dict[name + ".weight_quantizer"] = int_fwd_wrapper
23 | return tracker_dict
24 |
25 |
26 | class TrackOscillation:
27 | """
28 | This is a wrapper of the int_forward function of a quantizer.
29 | It tracks the oscillations in integer domain.
30 | """
31 |
32 | def __init__(self, int_fwd, momentum=0.01, freeze_threshold=0, use_ema_x_int=True):
33 | self.int_fwd = int_fwd
34 | self.momentum = momentum
35 |
36 | self.prev_x_int = None
37 | self.prev_switch_dir = None
38 |
39 | # Statistics to log
40 | self.ema_oscillation = None
41 | self.oscillated_sum = None
42 | self.total_oscillation = None
43 | self.iters_since_reset = 0
44 |
45 | # Extra variables for weight freezing
46 | self.freeze_threshold = freeze_threshold # This should be at least 2-3x the momentum value.
47 | self.use_ema_x_int = use_ema_x_int
48 | self.frozen = None
49 | self.frozen_x_int = None
50 | self.ema_x_int = None
51 |
52 | def __call__(self, x_float, skip_tracking=False, *args, **kwargs):
53 | x_int = self.int_fwd(x_float, *args, **kwargs)
54 |
55 | # Apply weight freezing
56 | if self.frozen is not None:
57 | x_int = ~self.frozen * x_int + self.frozen * self.frozen_x_int
58 |
59 | if skip_tracking:
60 | return x_int
61 |
62 | with torch.no_grad():
63 | # Check if everything is correctly initialized, otherwise do so
64 | self.check_init(x_int)
65 |
66 | # detect difference in x_int NB we round to avoid int inaccuracies
67 | delta_x_int = torch.round(self.prev_x_int - x_int).detach() # should be {-1, 0, 1}
68 | switch_dir = torch.sign(delta_x_int) # This is {-1, 0, 1} as sign(0) is mapped to 0
69 | # binary mask for switching
70 | switched = delta_x_int != 0
71 |
72 | oscillated = (self.prev_switch_dir * switch_dir) == -1
73 | self.ema_oscillation = (
74 | self.momentum * oscillated + (1 - self.momentum) * self.ema_oscillation
75 | )
76 |
77 | # Update prev_switch_dir for the switch variables
78 | self.prev_switch_dir[switched] = switch_dir[switched]
79 | self.prev_x_int = x_int
80 | self.oscillated_sum = oscillated.sum()
81 | self.total_oscillation += oscillated
82 | self.iters_since_reset += 1
83 |
84 | # Freeze some weights
85 | if self.freeze_threshold > 0:
86 | freeze_weights = self.ema_oscillation > self.freeze_threshold
87 | self.frozen[freeze_weights] = True # Set them to frozen
88 | if self.use_ema_x_int:
89 | self.frozen_x_int[freeze_weights] = torch.round(self.ema_x_int[freeze_weights])
90 | # Update x_int EMA which can be used for freezing
91 | self.ema_x_int = self.momentum * x_int + (1 - self.momentum) * self.ema_x_int
92 | else:
93 | self.frozen_x_int[freeze_weights] = x_int[freeze_weights]
94 |
95 | return x_int
96 |
97 | def check_init(self, x_int):
98 | if self.prev_x_int is None:
99 | # Init prev switch dir to 0
100 | self.prev_switch_dir = torch.zeros_like(x_int)
101 | self.prev_x_int = x_int.detach() # Not sure if needed, don't think so
102 | self.ema_oscillation = torch.zeros_like(x_int)
103 | self.oscillated_sum = 0
104 | self.total_oscillation = torch.zeros_like(x_int)
105 | print("Init tracking", x_int.shape)
106 | else:
107 | assert (
108 | self.prev_x_int.shape == x_int.shape
109 | ), "Tracking shape does not match current tensor shape."
110 |
111 | # For weight freezing
112 | if self.frozen is None and self.freeze_threshold > 0:
113 | self.frozen = torch.zeros_like(x_int, dtype=torch.bool)
114 | self.frozen_x_int = torch.zeros_like(x_int)
115 | if self.use_ema_x_int:
116 | self.ema_x_int = x_int.detach().clone()
117 | print("Init freezing", x_int.shape)
118 |
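119 |
120 | if __name__ == "__main__":
121 |     # Example sketch (appended by the editor, not part of the original module): tracking an
122 |     # oscillation on a single toy weight. torch.round stands in for a quantizer's
123 |     # to_integer_forward; in the repository the wrapper is attached via add_oscillation_trackers.
124 |     tracker = TrackOscillation(int_fwd=torch.round, momentum=0.5)
125 |     for w in (0.4, 0.6, 0.4):  # rounds to 0 -> 1 -> 0, i.e. one full oscillation
126 |         tracker(torch.tensor([w]))
127 |     print("EMA oscillation frequency:", tracker.ema_oscillation)  # -> tensor([0.5000])
128 |     print("total oscillations:", tracker.total_oscillation)  # -> tensor([1.])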
--------------------------------------------------------------------------------
/utils/qat_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import copy
6 |
7 | import torch
8 |
9 | from quantization.hijacker import QuantizationHijacker
10 | from quantization.quantized_folded_bn import BNFusedHijacker
11 | from utils.imagenet_dataloaders import ImageNetDataLoaders
12 |
13 |
14 | class MethodPropagator:
15 |     """Convenience class that allows multiple optimizers or LR schedulers to be used as if they
16 |     were one optimizer/scheduler."""
17 |
18 | def __init__(self, propagatables):
19 | self.propagatables = propagatables
20 |
21 | def __getattr__(self, item):
22 | if callable(getattr(self.propagatables[0], item)):
23 |
24 | def propagate_call(*args, **kwargs):
25 | for prop in self.propagatables:
26 | getattr(prop, item)(*args, **kwargs)
27 |
28 | return propagate_call
29 | else:
30 | return getattr(self.propagatables[0], item)
31 |
32 | def __str__(self):
33 | result = ""
34 | for prop in self.propagatables:
35 | result += str(prop) + "\n"
36 | return result
37 |
38 | def __iter__(self):
39 | for i in self.propagatables:
40 | yield i
41 |
42 | def __contains__(self, item):
43 | return item in self.propagatables
44 |
45 |
46 | def get_dataloaders_and_model(config, load_type="fp32", **qparams):
47 | dataloaders = ImageNetDataLoaders(
48 | config.base.images_dir,
49 | 224,
50 | config.base.batch_size,
51 | config.base.num_workers,
52 | config.base.interpolation,
53 | )
54 |
55 | model = config.base.architecture(
56 | pretrained=config.base.pretrained,
57 | load_type=load_type,
58 | model_dir=config.base.model_dir,
59 | **qparams,
60 | )
61 | if config.base.cuda:
62 | model = model.cuda()
63 |
64 | return dataloaders, model
65 |
66 |
67 | class CompositeLoss:
68 | def __init__(self, loss_dict):
69 | """
70 | Composite loss of N separate loss functions. All functions are summed up.
71 |
72 |         Note: each loss function is called with (prediction, target), even if it does not
73 |         need them. Any data-independent inputs must be provided to the loss function directly
74 |         (e.g. the model/weights in the case of a regularization term).
75 |
76 | """
77 | self.loss_dict = loss_dict
78 |
79 | def __call__(self, prediction, target, *args, **kwargs):
80 | total_loss = 0
81 | for loss_func in self.loss_dict.values():
82 | total_loss += loss_func(prediction, target, *args, **kwargs)
83 | return total_loss
84 |
85 |
86 | class UpdateFreezingThreshold:
87 | def __init__(self, tracker_dict, decay_schedule):
88 | self.tracker_dict = tracker_dict
89 | self.decay_schedule = decay_schedule
90 |
91 | def __call__(self, engine):
92 | if engine.state.iteration < self.decay_schedule.decay_start:
93 | # Put it always to 0 for real warm-start
94 | new_threshold = 0
95 | else:
96 | new_threshold = self.decay_schedule(engine.state.iteration)
97 |
98 | # Update trackers with new threshold
99 | for name, tracker in self.tracker_dict.items():
100 | tracker.freeze_threshold = new_threshold
101 | # print('Set new freezing threshold', new_threshold)
102 |
103 |
104 | class UpdateDampeningLossWeighting:
105 | def __init__(self, bin_reg_loss, decay_schedule):
106 | self.dampen_loss = bin_reg_loss
107 | self.decay_schedule = decay_schedule
108 |
109 | def __call__(self, engine):
110 | new_weighting = self.decay_schedule(engine.state.iteration)
111 | self.dampen_loss.weighting = new_weighting
112 | # print('Set new bin reg weighting', new_weighting)
113 |
114 |
115 | class DampeningLoss:
116 | def __init__(self, model, weighting=1.0, aggregation="sum"):
117 | """
118 | Calculates the dampening loss for all weights in a given quantized model. It is
119 | expected that all quantized weights are in a Hijacker module.
120 |
121 | """
122 | self.model = model
123 | self.weighting = weighting
124 | self.aggregation = aggregation
125 |
126 | def __call__(self, *args, **kwargs):
127 | total_bin_loss = 0
128 | for name, module in self.model.named_modules():
129 | if isinstance(module, QuantizationHijacker):
130 | # FP32 weight tensor, potential folded but before quantization
131 | weight, _ = module.get_weight_bias()
132 | # The matching weight quantizer (not manager, direct quantizer class)
133 | quantizer = module.weight_quantizer.quantizer
134 | total_bin_loss += dampening_loss(weight, quantizer, self.aggregation)
135 | return total_bin_loss * self.weighting
136 |
137 |
138 | def dampening_loss(w_fp, quantizer, aggregation="sum"):
139 |     # L = (s * w_int - w)^2
140 | # We also need to add clipping for both cases, we can do so by using the forward
141 | w_q = quantizer(w_fp, skip_tracking=True).detach() # this is also clipped and our target
142 | # clamp w in FP32 domain to not change range learning (min(max) is needed for per-channel)
143 | w_fp_clip = torch.min(torch.max(w_fp, quantizer.x_min), quantizer.x_max)
144 | loss = (w_q - w_fp_clip) ** 2
145 | if aggregation == "sum":
146 | return loss.sum()
147 | elif aggregation == "mean":
148 | return loss.mean()
149 | elif aggregation == "kernel_mean":
150 | return loss.sum(0).mean()
151 | else:
152 | raise ValueError(f"Aggregation method '{aggregation}' not implemented.")
153 |
154 |
155 | class ReestimateBNStats:
156 | def __init__(self, model, data_loader, num_batches=50):
157 | super().__init__()
158 | self.model = model
159 | self.data_loader = data_loader
160 | self.num_batches = num_batches
161 |
162 | def __call__(self, engine):
163 | print("-- Reestimate current BN statistics --")
164 | reestimate_BN_stats(self.model, self.data_loader, self.num_batches)
165 |
166 |
167 | def reestimate_BN_stats(model, data_loader, num_batches=50, store_ema_stats=False):
168 |     # We set BN momentum to 1 and use train mode
169 | # -> the running mean/var have the current batch statistics
170 | model.eval()
171 | org_momentum = {}
172 | for name, module in model.named_modules():
173 | if isinstance(module, BNFusedHijacker):
174 | org_momentum[name] = module.momentum
175 | module.momentum = 1.0
176 | module.running_mean_sum = torch.zeros_like(module.running_mean)
177 | module.running_var_sum = torch.zeros_like(module.running_var)
178 |             # Set all BNFusedHijacker modules to train mode, but not their children
179 | module.training = True
180 |
181 | if store_ema_stats:
182 |                 # Save the original EMA, make sure they are in buffers so they end up in the state dict
183 | if not hasattr(module, "running_mean_ema"):
184 | module.register_buffer("running_mean_ema", copy.deepcopy(module.running_mean))
185 | module.register_buffer("running_var_ema", copy.deepcopy(module.running_var))
186 | else:
187 | module.running_mean_ema = copy.deepcopy(module.running_mean)
188 | module.running_var_ema = copy.deepcopy(module.running_var)
189 |
190 | # Run data for estimation
191 | device = next(model.parameters()).device
192 | batch_count = 0
193 | with torch.no_grad():
194 | for x, y in data_loader:
195 | model(x.to(device))
196 | # We save the running mean/var to a buffer
197 | for name, module in model.named_modules():
198 | if isinstance(module, BNFusedHijacker):
199 | module.running_mean_sum += module.running_mean
200 | module.running_var_sum += module.running_var
201 |
202 | batch_count += 1
203 | if batch_count == num_batches:
204 | break
205 | # At the end we normalize the buffer and write it into the running mean/var
206 | for name, module in model.named_modules():
207 | if isinstance(module, BNFusedHijacker):
208 | module.running_mean = module.running_mean_sum / batch_count
209 | module.running_var = module.running_var_sum / batch_count
210 | # We reset the momentum in case it would be used anywhere else
211 | module.momentum = org_momentum[name]
212 | model.eval()
213 |
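214 |
215 | if __name__ == "__main__":
216 |     # Example sketch (appended by the editor, not part of the original module): a toy
217 |     # CompositeLoss that sums a task loss and a data-independent regularizer, with two
218 |     # optimizers driven through a single MethodPropagator handle. The model and data below
219 |     # are stand-ins, not the repository's QAT setup (there the second term would be a
220 |     # DampeningLoss and the second optimizer would hold the quantization parameters).
221 |     model = torch.nn.Linear(8, 4)
222 |     extra_param = torch.nn.Parameter(torch.zeros(1))
223 |
224 |     def l2_reg(prediction, target):
225 |         # Data-independent term: ignores (prediction, target) on purpose.
226 |         return 1e-4 * sum(p.pow(2).sum() for p in model.parameters())
227 |
228 |     criterion = CompositeLoss({"task": torch.nn.CrossEntropyLoss(), "reg": l2_reg})
229 |     optimizer = MethodPropagator(
230 |         [
231 |             torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
232 |             torch.optim.Adam([extra_param], lr=1e-3),
233 |         ]
234 |     )
235 |
236 |     x, y = torch.randn(2, 8), torch.tensor([0, 3])
237 |     optimizer.zero_grad()
238 |     loss = criterion(model(x), y)
239 |     loss.backward()
240 |     optimizer.step()
241 |     print("composite loss:", loss.item())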
--------------------------------------------------------------------------------
/utils/stopwatch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import sys
6 | import time
7 |
8 |
9 | class Stopwatch:
10 | """
11 | A simple cross-platform context-manager stopwatch.
12 |
13 | Examples
14 | --------
15 | >>> import time
16 | >>> with Stopwatch(verbose=True) as st:
17 | ... time.sleep(0.101) #doctest: +ELLIPSIS
18 | Elapsed time: 0.10... sec
19 | """
20 |
21 | def __init__(self, name=None, verbose=False):
22 | self._name = name
23 | self._verbose = verbose
24 |
25 | self._start_time_point = 0.0
26 | self._total_duration = 0.0
27 | self._is_running = False
28 |
29 |         if sys.platform == "win32":
30 |             # time.clock() was removed in Python 3.8; perf_counter is the recommended timer
31 |             self._timer_fn = time.perf_counter
32 |         else:
33 |             # on most other platforms, time.time() is sufficient
34 |             self._timer_fn = time.time
35 |
36 | def __enter__(self, verbose=False):
37 | return self.start()
38 |
39 | def __exit__(self, exc_type, exc_val, exc_tb):
40 | self.stop()
41 | if self._verbose:
42 | self.print()
43 |
44 | def start(self):
45 | if not self._is_running:
46 | self._start_time_point = self._timer_fn()
47 | self._is_running = True
48 | return self
49 |
50 | def stop(self):
51 | if self._is_running:
52 | self._total_duration += self._timer_fn() - self._start_time_point
53 | self._is_running = False
54 | return self
55 |
56 | def reset(self):
57 | self._start_time_point = 0.0
58 | self._total_duration = 0.0
59 | self._is_running = False
60 | return self
61 |
62 | def _update_state(self):
63 | now = self._timer_fn()
64 | self._total_duration += now - self._start_time_point
65 | self._start_time_point = now
66 |
67 | def _format(self):
68 | prefix = f"[{self._name}]" if self._name is not None else "Elapsed time"
69 | info = f"{prefix}: {self._total_duration:.3f} sec"
70 | return info
71 |
72 | def format(self):
73 | if self._is_running:
74 | self._update_state()
75 | return self._format()
76 |
77 | def print(self):
78 | print(self.format())
79 |
80 | def get_total_duration(self):
81 | if self._is_running:
82 | self._update_state()
83 | return self._total_duration
84 |
--------------------------------------------------------------------------------
/utils/supervised_driver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | from ignite.contrib.handlers import TensorboardLogger
6 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
7 | from ignite.handlers import Checkpoint, global_step_from_engine
8 | from torch.optim import Optimizer
9 |
10 |
11 | def create_trainer_engine(
12 | model,
13 | optimizer,
14 | criterion,
15 | metrics,
16 | data_loaders,
17 | lr_scheduler=None,
18 | save_checkpoint_dir=None,
19 | device="cuda",
20 | ):
21 | # Create trainer
22 | trainer = create_supervised_trainer(
23 | model=model,
24 | optimizer=optimizer,
25 | loss_fn=criterion,
26 | device=device,
27 | output_transform=custom_output_transform,
28 | )
29 |
30 | for name, metric in metrics.items():
31 | metric.attach(trainer, name)
32 |
33 | # Add lr_scheduler
34 | if lr_scheduler:
35 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())
36 |
37 | # Create evaluator
38 | evaluator = create_supervised_evaluator(model=model, metrics=metrics, device=device)
39 |
40 | # Save model checkpoint
41 | if save_checkpoint_dir:
42 | to_save = {"model": model, "optimizer": optimizer}
43 | if lr_scheduler:
44 | to_save["lr_scheduler"] = lr_scheduler
45 | checkpoint = Checkpoint(
46 | to_save,
47 | save_checkpoint_dir,
48 | n_saved=1,
49 | global_step_transform=global_step_from_engine(trainer),
50 | )
51 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)
52 |
53 | # Add hooks for logging metrics
54 | trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results, optimizer)
55 |
56 | trainer.add_event_handler(
57 | Events.EPOCH_COMPLETED, run_evaluation_for_training, evaluator, data_loaders.val_loader
58 | )
59 |
60 | return trainer, evaluator
61 |
62 |
63 | def custom_output_transform(x, y, y_pred, loss):
64 | return y_pred, y
65 |
66 |
67 | def log_training_results(trainer, optimizer):
68 | learning_rate = optimizer.param_groups[0]["lr"]
69 | log_metrics(trainer.state.metrics, "Training", trainer.state.epoch, learning_rate)
70 |
71 |
72 | def run_evaluation_for_training(trainer, evaluator, val_loader):
73 | evaluator.run(val_loader)
74 | log_metrics(evaluator.state.metrics, "Evaluation", trainer.state.epoch)
75 |
76 |
77 | def log_metrics(metrics, stage: str = "", training_epoch=None, learning_rate=None):
78 | log_text = " {}".format(metrics) if metrics else ""
79 | if training_epoch is not None:
80 | log_text = "Epoch: {}".format(training_epoch) + log_text
81 | if learning_rate and learning_rate > 0.0:
82 | log_text += " Learning rate: {:.2E}".format(learning_rate)
83 | log_text = "Results - " + log_text
84 | if stage:
85 | log_text = "{} ".format(stage) + log_text
86 | print(log_text, flush=True)
87 |
88 |
89 | def setup_tensorboard_logger(trainer, evaluator, output_path, optimizers=None):
90 | logger = TensorboardLogger(logdir=output_path)
91 |
92 | # Attach the logger to log loss and accuracy for both training and validation
93 | for tag, cur_evaluator in [("train", trainer), ("validation", evaluator)]:
94 | logger.attach_output_handler(
95 | cur_evaluator,
96 | event_name=Events.EPOCH_COMPLETED,
97 | tag=tag,
98 | metric_names="all",
99 | global_step_transform=global_step_from_engine(trainer),
100 | )
101 |
102 | # Log optimizer parameters
103 | if isinstance(optimizers, Optimizer):
104 | optimizers = {None: optimizers}
105 |
106 |     for k, optimizer in (optimizers or {}).items():
107 | logger.attach_opt_params_handler(
108 | trainer, Events.EPOCH_COMPLETED, optimizer, param_name="lr", tag=k
109 | )
110 |
111 | return logger
112 |
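113 |
114 | if __name__ == "__main__":
115 |     # Example sketch (appended by the editor, not part of the original module): running the
116 |     # trainer/evaluator pair on a toy dataset. In the repository the model, criterion and data
117 |     # loaders come from qat_utils / ImageNetDataLoaders; everything below is a stand-in.
118 |     from types import SimpleNamespace
119 |
120 |     import torch
121 |     from ignite.metrics import Accuracy, Loss
122 |     from torch.utils.data import DataLoader, TensorDataset
123 |
124 |     model = torch.nn.Linear(16, 4)
125 |     criterion = torch.nn.CrossEntropyLoss()
126 |     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
127 |
128 |     toy = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
129 |     data_loaders = SimpleNamespace(
130 |         train_loader=DataLoader(toy, batch_size=8),
131 |         val_loader=DataLoader(toy, batch_size=8),
132 |     )
133 |     metrics = {"top_1_accuracy": Accuracy(), "loss": Loss(criterion)}
134 |
135 |     trainer, evaluator = create_trainer_engine(
136 |         model, optimizer, criterion, metrics, data_loaders, device="cpu"
137 |     )
138 |     trainer.run(data_loaders.train_loader, max_epochs=1)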
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
3 | # All Rights Reserved.
4 |
5 | import collections
6 | import os
7 | import random
8 | from collections import namedtuple
9 | from enum import Flag, auto
10 | from functools import partial
11 |
12 | import click
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | class DotDict(dict):
19 | """
20 | A dictionary that allows attribute-style access.
21 | Examples
22 | --------
23 | >>> config = DotDict(a=None)
24 | >>> config.a = 42
25 | >>> config.b = 'egg'
26 | >>> config # can be used as dict
27 | {'a': 42, 'b': 'egg'}
28 | """
29 |
30 | def __setattr__(self, key, value):
31 | self.__setitem__(key, value)
32 |
33 | def __delattr__(self, key):
34 | self.__delitem__(key)
35 |
36 | def __getattr__(self, key):
37 | if key in self:
38 | return self.__getitem__(key)
39 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})")
40 |
41 |
42 | def relu(x):
43 | x = np.array(x)
44 | return x * (x > 0)
45 |
46 |
47 | def get_all_layer_names(model, subtypes=None):
48 | if subtypes is None:
49 | return [name for name, module in model.named_modules()][1:]
50 | return [name for name, module in model.named_modules() if isinstance(module, subtypes)]
51 |
52 |
53 | def get_layer_name_to_module_dict(model):
54 | return {name: module for name, module in model.named_modules() if name}
55 |
56 |
57 | def get_module_to_layer_name_dict(model):
58 | modules_to_names = collections.OrderedDict()
59 | for name, module in model.named_modules():
60 | modules_to_names[module] = name
61 | return modules_to_names
62 |
63 |
64 | def get_layer_name(model, layer):
65 | for name, module in model.named_modules():
66 | if module == layer:
67 | return name
68 | return None
69 |
70 |
71 | def get_layer_by_name(model, layer_name):
72 | for name, module in model.named_modules():
73 | if name == layer_name:
74 | return module
75 | return None
76 |
77 |
78 | def create_conv_layer_list(cls, model: nn.Module) -> list:
79 | """
80 | Function finds all prunable layers in the provided model
81 |
82 | Parameters
83 | ----------
84 | cls: SVD class
85 | model : torch.nn.Module
86 | A pytorch model.
87 |
88 | Returns
89 | -------
90 | conv_layer_list : list
91 | List of all prunable layers in the given model.
92 |
93 | """
94 | conv_layer_list = []
95 |
96 | def fill_list(mod):
97 | if isinstance(mod, tuple(cls.supported_layer_types)):
98 | conv_layer_list.append(mod)
99 |
100 | model.apply(fill_list)
101 | return conv_layer_list
102 |
103 |
104 | def create_linear_layer_list(cls, model: nn.Module) -> list:
105 | """
106 | Function finds all prunable layers in the provided model
107 |
108 | Parameters
109 | ----------
110 | model : torch.nn.Module
111 | A pytorch model.
112 |
113 | Returns
114 | -------
115 | conv_layer_list : list
116 | List of all prunable layers in the given model.
117 |
118 | """
119 | conv_layer_list = []
120 |
121 | def fill_list(mod):
122 | if isinstance(mod, tuple(cls.supported_layer_types)):
123 | conv_layer_list.append(mod)
124 |
125 | model.apply(fill_list)
126 | return conv_layer_list
127 |
128 |
129 | def to_numpy(tensor):
130 | """
131 | Helper function that turns the given tensor into a numpy array
132 |
133 | Parameters
134 | ----------
135 | tensor : torch.Tensor
136 |
137 | Returns
138 | -------
139 | tensor : float or np.array
140 |
141 | """
142 | if isinstance(tensor, np.ndarray):
143 | return tensor
144 | if hasattr(tensor, "is_cuda"):
145 | if tensor.is_cuda:
146 | return tensor.cpu().detach().numpy()
147 | if hasattr(tensor, "detach"):
148 | return tensor.detach().numpy()
149 | if hasattr(tensor, "numpy"):
150 | return tensor.numpy()
151 |
152 | return np.array(tensor)
153 |
154 |
155 | def set_module_attr(model, layer_name, value):
156 | split = layer_name.split(".")
157 |
158 | this_module = model
159 | for mod_name in split[:-1]:
160 | if mod_name.isdigit():
161 | this_module = this_module[int(mod_name)]
162 | else:
163 | this_module = getattr(this_module, mod_name)
164 |
165 | last_mod_name = split[-1]
166 | if last_mod_name.isdigit():
167 | this_module[int(last_mod_name)] = value
168 | else:
169 | setattr(this_module, last_mod_name, value)
170 |
171 |
172 | def search_for_zero_planes(model: torch.nn.Module):
173 |     """Search through all Linear/Conv2d modules in the model and check whether any input
174 |     planes have been zeroed out.
175 |     Returns a list of (module, input_channels_to_winnow) tuples with any findings.
176 |     :param model: torch model to search through modules for zeroed parameters
177 |     """
178 |
179 | list_of_modules_to_winnow = []
180 | for _, module in model.named_modules():
181 | if isinstance(module, (torch.nn.Linear, torch.nn.modules.conv.Conv2d)):
182 | in_channels_to_winnow = _assess_weight_and_bias(module.weight, module.bias)
183 | if in_channels_to_winnow:
184 | list_of_modules_to_winnow.append((module, in_channels_to_winnow))
185 | return list_of_modules_to_winnow
186 |
187 |
188 | def _assess_weight_and_bias(weight: torch.nn.Parameter, _bias: torch.nn.Parameter):
189 | """4-dim weights [CH-out, CH-in, H, W] and 1-dim bias [CH-out]"""
190 | if len(weight.shape) > 2:
191 | input_channels_to_ignore = (weight.sum((0, 2, 3)) == 0).nonzero().squeeze().tolist()
192 | else:
193 | input_channels_to_ignore = (weight.sum(0) == 0).nonzero().squeeze().tolist()
194 |
195 | if type(input_channels_to_ignore) != list:
196 | input_channels_to_ignore = [input_channels_to_ignore]
197 |
198 | return input_channels_to_ignore
199 |
200 |
201 | def seed_all(seed: int = 1029, deterministic: bool = False):
202 | """
203 | This is our attempt to make experiments reproducible by seeding all known RNGs and setting
204 | appropriate torch directives.
205 | For a general discussion of reproducibility in Pytorch and CUDA and a documentation of the
206 | options we are using see, e.g.,
207 | https://pytorch.org/docs/1.7.1/notes/randomness.html
208 | https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
209 |
210 | As of today (July 2021), even after seeding and setting some directives,
211 | there remain unfortunate contradictions:
212 | 1. CUDNN
213 | - having CUDNN enabled leads to
214 | - non-determinism in Pytorch when using the GPU, cf. MORPH-10999.
215 | - having CUDNN disabled leads to
216 | - most regression tests in Qrunchy failing, cf. MORPH-11103
217 | - significantly increased execution time in some cases
218 | - performance degradation in some cases
219 | 2. torch.set_deterministic(d)
220 | - setting d = True leads to errors for Pytorch algorithms that do not (yet) have a deterministic
221 | counterpart, e.g., the layer `adaptive_avg_pool2d_backward_cuda` in vgg16__torchvision.
222 |
223 | Thus, we leave the choice of enforcing determinism by disabling CUDNN and non-deterministic
224 | algorithms to the user. To keep it simple, we only have one switch for both.
225 | This situation could be re-evaluated upon updates of Pytorch, CUDA, CUDNN.
226 | """
227 |
228 | assert isinstance(seed, int), f"RNG seed must be an integer ({seed})"
229 |     assert seed >= 0, f"RNG seed must be a non-negative integer ({seed})"
230 |
231 | # Builtin RNGs
232 | random.seed(seed)
233 | os.environ["PYTHONHASHSEED"] = str(seed)
234 |
235 | # Numpy RNG
236 | np.random.seed(seed)
237 |
238 |     # CUDNN determinism (setting these has not led to errors so far)
239 | torch.backends.cudnn.benchmark = False
240 | torch.backends.cudnn.deterministic = True
241 |
242 | # Torch RNGs
243 | torch.manual_seed(seed)
244 | torch.cuda.manual_seed(seed)
245 | torch.cuda.manual_seed_all(seed)
246 |
247 | # Problematic settings, see docstring. Precaution: We do not mutate unless asked to do so
248 | if deterministic is True:
249 | torch.backends.cudnn.enabled = False
250 |
251 | torch.set_deterministic(True) # Use torch.use_deterministic_algorithms(True) in torch 1.8.1
252 | # When using torch.set_deterministic(True), it is advised by Pytorch to set the
253 | # CUBLAS_WORKSPACE_CONFIG variable as follows, see
254 | # https://pytorch.org/docs/1.7.1/notes/randomness.html#avoiding-nondeterministic-algorithms
255 | # and the link to the CUDA homepage on that website.
256 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
257 |
258 |
259 | def assert_allclose(actual, desired, *args, **kwargs):
260 |     """A more beautiful version of torch.allclose."""
261 | np.testing.assert_allclose(to_numpy(actual), to_numpy(desired), *args, **kwargs)
262 |
263 |
264 | def count_params(module):
265 | return len(nn.utils.parameters_to_vector(module.parameters()))
266 |
267 |
268 | class StopForwardException(Exception):
269 | """Used to throw and catch an exception to stop traversing the graph."""
270 |
271 | pass
272 |
273 |
274 | class StopForwardHook:
275 | def __call__(self, module, *args):
276 | raise StopForwardException
277 |
278 |
279 | def sigmoid(x):
280 | return 1.0 / (1.0 + np.exp(-x))
281 |
282 |
283 | class CosineTempDecay:
284 | def __init__(self, t_max, temp_range=(20.0, 2.0), rel_decay_start=0):
285 | self.t_max = t_max
286 | self.start_temp, self.end_temp = temp_range
287 | self.decay_start = rel_decay_start * t_max
288 |
289 | def __call__(self, t):
290 | if t < self.decay_start:
291 | return self.start_temp
292 |
293 | rel_t = (t - self.decay_start) / (self.t_max - self.decay_start)
294 | return self.end_temp + 0.5 * (self.start_temp - self.end_temp) * (1 + np.cos(rel_t * np.pi))
295 |
296 |
297 | class BaseEnumOptions(Flag):
298 | def __str__(self):
299 | return self.name
300 |
301 | @classmethod
302 | def list_names(cls):
303 | return [m.name for m in cls]
304 |
305 |
306 | class ClassEnumOptions(BaseEnumOptions):
307 | @property
308 | def cls(self):
309 | return self.value.cls
310 |
311 | def __call__(self, *args, **kwargs):
312 | return self.value.cls(*args, **kwargs)
313 |
314 |
315 | MethodMap = partial(namedtuple("MethodMap", ["value", "cls"]), auto())
316 |
317 |
318 | def split_dict(src: dict, include=(), remove_prefix: str = ""):
319 | """
320 | Splits dictionary into a DotDict and a remainder.
321 | The arguments to be placed in the first DotDict are those listed in `include`.
322 | Parameters
323 | ----------
324 | src: dict
325 | The source dictionary.
326 | include:
327 | List of keys to be returned in the first DotDict.
328 |     remove_prefix:
329 |         Prefix to remove from the keys placed in the first DotDict.
330 | """
331 | result = DotDict()
332 |
333 | for arg in include:
334 | if remove_prefix:
335 | key = arg.replace(f"{remove_prefix}_", "", 1)
336 | else:
337 | key = arg
338 | result[key] = src[arg]
339 | remainder = {key: val for key, val in src.items() if key not in include}
340 | return result, remainder
341 |
342 |
343 | class ClickEnumOption(click.Choice):
344 | """
345 | Adjusted click.Choice type for BaseOption which is based on Enum
346 | """
347 |
348 | def __init__(self, enum_options, case_sensitive=True):
349 | assert issubclass(enum_options, BaseEnumOptions)
350 | self.base_option = enum_options
351 | super().__init__(self.base_option.list_names(), case_sensitive)
352 |
353 | def convert(self, value, param, ctx):
354 | # Exact match
355 | if value in self.choices:
356 | return self.base_option[value]
357 |
358 | # Match through normalization and case sensitivity
359 | # first do token_normalize_func, then lowercase
360 | # preserve original `value` to produce an accurate message in
361 | # `self.fail`
362 | normed_value = value
363 | normed_choices = self.choices
364 |
365 | if ctx is not None and ctx.token_normalize_func is not None:
366 | normed_value = ctx.token_normalize_func(value)
367 | normed_choices = [ctx.token_normalize_func(choice) for choice in self.choices]
368 |
369 | if not self.case_sensitive:
370 | normed_value = normed_value.lower()
371 | normed_choices = [choice.lower() for choice in normed_choices]
372 |
373 | if normed_value in normed_choices:
374 | return self.base_option[normed_value]
375 |
376 | self.fail(
377 | "invalid choice: %s. (choose from %s)" % (value, ", ".join(self.choices)), param, ctx
378 | )
379 |
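380 |
381 | if __name__ == "__main__":
382 |     # Example sketch (appended by the editor, not part of the original module): the two helpers
383 |     # most relevant to the QAT annealing schedules and CLI plumbing. All numbers are illustrative.
384 |
385 |     # CosineTempDecay: anneal a value from 1e-4 down to 1e-5 over 10000 iterations, keeping it
386 |     # constant for the first 25% of training.
387 |     schedule = CosineTempDecay(t_max=10000, temp_range=(1e-4, 1e-5), rel_decay_start=0.25)
388 |     for t in (0, 2500, 5000, 7500, 10000):
389 |         print(t, schedule(t))
390 |
391 |     # split_dict: route a subset of CLI kwargs into a DotDict, stripping the given prefix.
392 |     cli_kwargs = {"quant_learning_rate": 0.01, "quant_momentum": 0.9, "batch_size": 128}
393 |     quant_cfg, remainder = split_dict(cli_kwargs, ["quant_learning_rate", "quant_momentum"], "quant")
394 |     print(quant_cfg.learning_rate, remainder)  # 0.01 {'batch_size': 128}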
--------------------------------------------------------------------------------