├── LICENSE ├── README.md ├── main.py ├── models ├── __init__.py ├── quantized_bert.py ├── quantized_mobilebert.py └── quantized_roberta.py ├── quantization ├── __init__.py ├── adaround │ ├── __init__.py │ ├── adaround.py │ ├── config.py │ ├── quantizer.py │ └── utils.py ├── autoquant_utils.py ├── base_quantized_classes.py ├── base_quantized_model.py ├── hijacker.py ├── quantization_manager.py ├── quantizers.py ├── range_estimators.py └── utils.py ├── requirements.txt └── utils ├── __init__.py ├── adaround_utils.py ├── glue_tasks.py ├── hf_models.py ├── per_embd_quant_utils.py ├── qat_utils.py ├── quant_click_options.py ├── tb_utils.py ├── transformer_click_options.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted 6 | (subject to the limitations in the disclaimer below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this list of conditions 9 | and the following disclaimer: 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions 12 | and the following disclaimer in the documentation and/or other materials provided with the 13 | istribution. 14 | 15 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to 16 | endorse or promote products derived from this software without specific prior written permission. 17 | 18 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS 19 | SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED 20 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF 26 | THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer Quantization 2 | This repository contains the implementation and experiments for the paper presented in 3 | 4 | **Yelysei Bondarenko1, Markus Nagel1, Tijmen Blankevoort1, 5 | "Understanding and Overcoming the Challenges of Efficient Transformer Quantization", EMNLP 2021.** [[ACL Anthology]](https://aclanthology.org/2021.emnlp-main.627/) [[ArXiv]](https://arxiv.org/abs/2109.12948) 6 | 7 | 1 Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.) 
8 | 9 | 10 | ## Reference 11 | If you find our work useful, please cite 12 | ``` 13 | @inproceedings{bondarenko-etal-2021-understanding, 14 | title = "Understanding and Overcoming the Challenges of Efficient Transformer Quantization", 15 | author = "Bondarenko, Yelysei and 16 | Nagel, Markus and 17 | Blankevoort, Tijmen", 18 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing", 19 | month = nov, 20 | year = "2021", 21 | address = "Online and Punta Cana, Dominican Republic", 22 | publisher = "Association for Computational Linguistics", 23 | url = "https://aclanthology.org/2021.emnlp-main.627", 24 | pages = "7947--7969", 25 | abstract = "Transformer-based architectures have become the de-facto standard models for a wide range of Natural Language Processing tasks. However, their memory footprint and high latency are prohibitive for efficient deployment and inference on resource-limited devices. In this work, we explore quantization for transformers. We show that transformers have unique quantization challenges {--} namely, high dynamic activation ranges that are difficult to represent with a low bit fixed-point format. We establish that these activations contain structured outliers in the residual connections that encourage specific attention patterns, such as attending to the special separator token. To combat these challenges, we present three solutions based on post-training quantization and quantization-aware training, each with a different set of compromises for accuracy, model size, and ease of use. In particular, we introduce a novel quantization scheme {--} per-embedding-group quantization. We demonstrate the effectiveness of our methods on the GLUE benchmark using BERT, establishing state-of-the-art results for post-training quantization. Finally, we show that transformer weights and embeddings can be quantized to ultra-low bit-widths, leading to significant memory savings with a minimum accuracy loss. Our source code is available at \url{https://github.com/qualcomm-ai-research/transformer-quantization}.", 26 | } 27 | ``` 28 | 29 | ## How to install 30 | First, ensure locale variables are set as follows: 31 | ```bash 32 | export LC_ALL=C.UTF-8 33 | export LANG=C.UTF-8 34 | ``` 35 | 36 | Second, make sure to have Python ≥3.6 (tested with Python 3.6.8) and 37 | ensure the latest version of `pip` (tested with 21.2.4): 38 | ```bash 39 | pip install --upgrade --no-deps pip 40 | ``` 41 | 42 | Next, install PyTorch 1.4.0 with the appropriate CUDA version (tested with CUDA 10.0, CuDNN 7.6.3): 43 | ```bash 44 | pip install torch==1.4.0 torchvision==0.5.0 -f https://download.pytorch.org/whl/torch_stable.html 45 | ``` 46 | 47 | Finally, install the remaining dependencies using pip: 48 | ```bash 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | To run the code, the project root directory needs to be added to your pythonpath: 53 | ```bash 54 | export PYTHONPATH="${PYTHONPATH}:/path/to/this/dir" 55 | ``` 56 | 57 | ## Running experiments 58 | The main run file to reproduce all experiments is `main.py`. 59 | It contains 4 commands to train and validate FP32 and quantized model: 60 | ```bash 61 | Usage: main.py [OPTIONS] COMMAND [ARGS]... 62 | 63 | Options: 64 | --help Show this message and exit. 65 | 66 | Commands: 67 | train-baseline 68 | train-quantized 69 | validate-baseline 70 | validate-quantized 71 | ``` 72 | You can see the full list of options for each command using `python main.py [COMMAND] --help`. 73 | 74 | ### A. 
FP32 fine-tuning 75 | To start with, you need to get the fine-tuned model(s) for the GLUE task of interest. 76 | Example run command for fine-tuning: 77 | ```bash 78 | python main.py train-baseline --cuda --save-model --model-name bert_base_uncased --task rte \ 79 | --learning-rate 3e-05 --batch-size 8 --eval-batch-size 8 --num-epochs 3 --max-seq-length 128 \ 80 | --seed 1000 --output-dir /path/to/output/dir/ 81 | ``` 82 | You can also do it directly using the HuggingFace library [[examples]](https://github.com/huggingface/transformers/tree/master/examples/pytorch/text-classification). 83 | In all experiments we used seeds 1000 - 1004 and reported the median score. 84 | The sample output directory looks as follows: 85 | ```bash 86 | /path/to/output/dir 87 | ├── config.out 88 | ├── eval_results_rte.txt 89 | ├── final_score.txt 90 | ├── out 91 | │   ├── config.json # Huggingface model config 92 | │   ├── pytorch_model.bin # PyTorch model checkpoint 93 | │   ├── special_tokens_map.json 94 | │   ├── tokenizer_config.json # Huggingface tokenizer config 95 | │   ├── training_args.bin 96 | │   └── vocab.txt # Vocabulary 97 | └── tb_logs # TensorBoard logs 98 | ├── 1632747625.1250594 99 | │   └── events.out.tfevents.* 100 | └── events.out.tfevents.* 101 | ``` 102 | 103 | For validation (both full-precision and quantized), it is assumed that these output directories with the fine-tuned 104 | checkpoints are arranged as follows (you can also use a subset of GLUE tasks): 105 | ```bash 106 | /path/to/saved_models/ 107 | ├── rte/rte_model_dir 108 | │   ├── out 109 | │   │   ├── config.json # Huggingface model config 110 | │   │   ├── pytorch_model.bin # PyTorch model checkpoint 111 | │   │   ├── tokenizer_config.json # Huggingface tokenizer config 112 | │   │   ├── vocab.txt # Vocabulary 113 | │   │   ├── (...) 114 | ├── cola/cola_model_dir 115 | │   ├── out 116 | │   │   ├── (...) 117 | ├── mnli/mnli_model_dir 118 | │   ├── out 119 | │   │   ├── (...) 120 | ├── mrpc/mrpc_model_dir 121 | │   ├── out 122 | │   │   ├── (...) 123 | ├── qnli/qnli_model_dir 124 | │   ├── out 125 | │   │   ├── (...) 126 | ├── qqp/qqp_model_dir 127 | │   ├── out 128 | │   │   ├── (...) 129 | ├── sst2/sst2_model_dir 130 | │   ├── out 131 | │   │   ├── (...) 132 | └── stsb/stsb_model_dir 133 | ├── out 134 | │   ├── (...) 135 | ``` 136 | Note that you have to create this file structure manually. 137 | 138 | The model can then be validated as follows: 139 | ```bash 140 | python main.py validate-baseline --eval-batch-size 32 --seed 1000 --model-name bert_base_uncased \ 141 | --model-path /path/to/saved_models/ --task rte 142 | ``` 143 | You can also validate multiple or all checkpoints by specifying 144 | `--task --task [...]` or `--task all`, respectively. 145 | 146 | ### B. Post-training quantization (PTQ) 147 | 148 | #### 1) Standard (naïve) W8A8 per-tensor PTQ / base run command for all PTQ experiments 149 | ```bash 150 | python main.py validate-quantized --act-quant --weight-quant --no-pad-to-max-length \ 151 | --est-ranges-no-pad --eval-batch-size 16 --seed 1000 --model-path /path/to/saved_models/ \ 152 | --task rte --n-bits 8 --n-bits-act 8 --qmethod symmetric_uniform \ 153 | --qmethod-act asymmetric_uniform --weight-quant-method MSE --weight-opt-method golden_section \ 154 | --act-quant-method current_minmax --est-ranges-batch-size 1 --num-est-batches 1 \ 155 | --quant-setup all 156 | ``` 157 | Note that the range estimation settings are slightly different for each task. 
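For intuition on what the `--qmethod symmetric_uniform` and `--qmethod-act asymmetric_uniform` options refer to, the sketch below shows per-tensor uniform fake quantization with a simple min/max range estimate. This is only a conceptual illustration with made-up helper names, not the repository's implementation; the actual quantizers and range estimators live in `quantization/quantizers.py` and `quantization/range_estimators.py`.
```python
import torch

def symmetric_uniform_fake_quant(w, n_bits=8):
    # Per-tensor symmetric uniform quantization, roughly what
    # `--qmethod symmetric_uniform` denotes for weights.
    n_pos = 2 ** (n_bits - 1) - 1                      # e.g. 127 for 8 bits
    scale = w.abs().max().clamp(min=1e-8) / n_pos
    return torch.clamp(torch.round(w / scale), -n_pos - 1, n_pos) * scale

def asymmetric_uniform_fake_quant(x, n_bits=8):
    # Per-tensor asymmetric uniform quantization with a current min/max
    # range estimate, roughly what `--qmethod-act asymmetric_uniform`
    # together with `--act-quant-method current_minmax` denotes for activations.
    x_min, x_max = x.min(), x.max()
    n_levels = 2 ** n_bits - 1                         # e.g. 255 for 8 bits
    scale = (x_max - x_min).clamp(min=1e-8) / n_levels
    zero_point = torch.round(-x_min / scale)
    x_int = torch.clamp(torch.round(x / scale) + zero_point, 0, n_levels)
    return (x_int - zero_point) * scale                # dequantize back to float

# Example: worst-case rounding error on a random activation-shaped tensor
x = torch.randn(4, 128, 768)
print((x - asymmetric_uniform_fake_quant(x)).abs().max())
```
With 8 bits the rounding error of such a scheme is small for well-behaved tensors; the difficulty described in the paper is that structured outliers in the residual activations inflate the min/max range, which is what the mixed-precision and per-embedding-group options below address.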
158 | 159 | #### 2) Mixed precision W8A{8,16} PTQ 160 | Specify `--quant-dict "{'y': 16, 'h': 16, 'x': 16}"`: 161 | * `'x': 16` will set FFN's input to 16-bit 162 | * `'h': 16` will set FFN's output to 16-bit 163 | * `'y': 16` will set FFN's residual sum to 16-bit 164 | 165 | For STS-B regression task, you will need to specify `--quant-dict "{'y': 16, 'h': 16, 'x': 16, 'P': 16, 'C': 16}"` 166 | and `--quant-setup MSE_logits`, which will also quantize pooler and the final classifier to 16-bit and use MSE estimator for the output. 167 | 168 | #### 3) Per-embedding and per-embedding-group (PEG) activation quantization 169 | * `--per-embd` -- Per-embedding quantization for all activations 170 | * `--per-groups [N_GROUPS]` -- PEG quantization for all activations, no permutation 171 | * `--per-groups [N_GROUPS] --per-groups-permute` -- PEG quantization for all activations, apply range-based permutation (separate for each quantizer) 172 | * `--quant-dict "{'y': 'ng6', 'h': 'ng6', 'x': 'ng6'}"` -- PEG quantization using 6 groups for FFN's input, output and residual sum, no permutation 173 | * `--quant-dict "{'y': 'ngp6', 'h': 'ngp6', 'x': 'ngp6'}" --per-groups-permute-shared-h` -- PEG quantization using 6 groups for FFN's input, output and residual sum, apply range-based permutation (shared between tensors in the same layer) 174 | 175 | #### 4) W4A32 PTQ with AdaRound 176 | ```bash 177 | python main.py validate-quantized --weight-quant --no-act-quant --no-pad-to-max-length \ 178 | --est-ranges-no-pad --eval-batch-size 16 --seed 1000 --model-path /path/to/saved_models/ \ 179 | --task rte --qmethod symmetric_uniform --qmethod-act asymmetric_uniform --n-bits 4 \ 180 | --weight-quant-method MSE --weight-opt-method grid --num-candidates 100 --quant-setup all \ 181 | --adaround all --adaround-num-samples 1024 --adaround-init range_estimator \ 182 | --adaround-mode learned_hard_sigmoid --adaround-asym --adaround-iters 10000 \ 183 | --adaround-act-quant no_act_quant 184 | ``` 185 | 186 | ### C. Quantization-aware training (QAT) 187 | Base run command for QAT experiments (using W4A8 for example): 188 | ```bash 189 | python main.py train-quantized --cuda --do-eval --logging-first-step --weight-quant --act-quant \ 190 | --pad-to-max-length --learn-ranges --tqdm --batch-size 8 --seed 1000 \ 191 | --model-name bert_base_uncased --learning-rate 5e-05 --num-epochs 6 --warmup-steps 186 \ 192 | --weight-decay 0.0 --attn-dropout 0.0 --hidden-dropout 0.0 --max-seq-length 128 --n-bits 4 \ 193 | --n-bits-act 8 --qmethod symmetric_uniform --qmethod-act asymmetric_uniform \ 194 | --weight-quant-method MSE --weight-opt-method golden_section --act-quant-method current_minmax \ 195 | --est-ranges-batch-size 16 --num-est-batches 1 --quant-setup all \ 196 | --model-path /path/to/saved_models/rte/out --task rte --output-dir /path/to/qat_output/dir 197 | ``` 198 | Note that the settings are slightly different for each task (see Appendix). 199 | 200 | To run mixed-precision QAT with 2-bit embeddings and 4-bit weights, add `--quant-dict "{'Et': 2}"`. 201 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from .quantized_bert import QuantizedBertForSequenceClassification 5 | from .quantized_mobilebert import QuantizedMobileBertForSequenceClassification 6 | from .quantized_roberta import QuantizedRobertaForSequenceClassification 7 | -------------------------------------------------------------------------------- /models/quantized_bert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import CrossEntropyLoss, MSELoss 10 | from transformers.models.bert.modeling_bert import ( 11 | BertLayer, 12 | BertSelfAttention, 13 | BertSelfOutput, 14 | BaseModelOutputWithPoolingAndCrossAttentions, 15 | ) 16 | from transformers.modeling_outputs import SequenceClassifierOutput 17 | from transformers.modeling_utils import ModuleUtilsMixin, apply_chunking_to_forward 18 | 19 | from quantization.autoquant_utils import quantize_model 20 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 21 | from quantization.base_quantized_model import QuantizedModel 22 | from quantization.range_estimators import RangeEstimators, OptMethod 23 | from utils import _tb_advance_global_step, _tb_advance_token_counters, _tb_hist 24 | 25 | 26 | class QuantizedBertEmbeddings(QuantizedModel): 27 | def __init__(self, org_model, **quant_params): 28 | self.quant_dict = quant_params['quant_dict'] 29 | 30 | super().__init__() 31 | 32 | quant_params_ = quant_params.copy() 33 | if 'Et' in self.quant_dict: 34 | quant_params_['weight_range_method'] = RangeEstimators.MSE 35 | quant_params_['weight_range_options'] = dict(opt_method=OptMethod.golden_section) 36 | self.word_embeddings = quantize_model(org_model.word_embeddings, **quant_params_) 37 | 38 | self.position_embeddings = quantize_model(org_model.position_embeddings, **quant_params) 39 | self.token_type_embeddings = quantize_model(org_model.token_type_embeddings, **quant_params) 40 | 41 | self.dropout = org_model.dropout 42 | 43 | position_ids = org_model.position_ids 44 | if position_ids is not None: 45 | self.register_buffer('position_ids', position_ids) 46 | else: 47 | self.position_ids = position_ids 48 | 49 | self.position_embedding_type = getattr(org_model, 'position_embedding_type', 'absolute') 50 | 51 | # Activation quantizers 52 | self.sum_input_token_type_embd_act_quantizer = QuantizedActivation(**quant_params) 53 | self.sum_pos_embd_act_quantizer = QuantizedActivation(**quant_params) 54 | 55 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be 56 | # able to load any TensorFlow checkpoint file 57 | self.LayerNorm = quantize_model(org_model.LayerNorm, **quant_params) 58 | 59 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): 60 | if input_ids is not None: 61 | input_shape = input_ids.size() 62 | else: 63 | input_shape = inputs_embeds.size()[:-1] 64 | 65 | seq_length = input_shape[1] 66 | 67 | if position_ids is None: 68 | position_ids = self.position_ids[:, :seq_length] 69 | 70 | if token_type_ids is None: 71 | token_type_ids = torch.zeros( 72 | input_shape, dtype=torch.long, device=self.position_ids.device 73 | ) 74 | if inputs_embeds is None: 75 | inputs_embeds = self.word_embeddings(input_ids) 76 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 77 | 78 | embeddings = inputs_embeds + token_type_embeddings 79 | 
embeddings = self.sum_input_token_type_embd_act_quantizer(embeddings) 80 | 81 | if self.position_embedding_type == 'absolute': 82 | position_embeddings = self.position_embeddings(position_ids) 83 | embeddings += position_embeddings 84 | embeddings = self.sum_pos_embd_act_quantizer(embeddings) 85 | 86 | embeddings = self.LayerNorm(embeddings) 87 | embeddings = self.dropout(embeddings) 88 | return embeddings 89 | 90 | 91 | class QuantizedBertSelfAttention(QuantizedModel): 92 | def __init__(self, org_model, **quant_params): 93 | self.quant_dict = quant_params['quant_dict'] 94 | 95 | super().__init__() 96 | 97 | # copy attributes 98 | self.num_attention_heads = org_model.num_attention_heads 99 | self.attention_head_size = org_model.attention_head_size 100 | self.all_head_size = org_model.all_head_size 101 | 102 | self.position_embedding_type = getattr(org_model, 'position_embedding_type', None) 103 | if self.position_embedding_type in ('relative_key', 'relative_key_query'): 104 | raise NotImplementedError('current branch of computation is not yet supported') 105 | 106 | self.max_position_embeddings = org_model.max_position_embeddings 107 | self.distance_embedding = org_model.distance_embedding 108 | 109 | # quantized modules 110 | self.query = quantize_model(org_model.query, **quant_params) 111 | self.key = quantize_model(org_model.key, **quant_params) 112 | self.value = quantize_model(org_model.value, **quant_params) 113 | self.dropout = org_model.dropout 114 | 115 | # Activation quantizers 116 | self.attn_scores_act_quantizer = QuantizedActivation(**quant_params) 117 | self.attn_probs_act_quantizer = QuantizedActivation(**quant_params) 118 | self.context_act_quantizer = QuantizedActivation(**quant_params) 119 | 120 | def transpose_for_scores(self, x): 121 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 122 | x = x.view(*new_x_shape) 123 | return x.permute(0, 2, 1, 3) 124 | 125 | def forward( 126 | self, 127 | hidden_states, 128 | attention_mask=None, 129 | head_mask=None, 130 | encoder_hidden_states=None, 131 | encoder_attention_mask=None, 132 | past_key_value=None, 133 | output_attentions=False, 134 | ): 135 | mixed_query_layer = self.query(hidden_states) 136 | 137 | # If this is instantiated as a cross-attention module, the keys 138 | # and values come from an encoder; the attention mask needs to be 139 | # such that the encoder's padding tokens are not attended to. 140 | if encoder_hidden_states is not None: 141 | mixed_key_layer = self.key(encoder_hidden_states) 142 | mixed_value_layer = self.value(encoder_hidden_states) 143 | attention_mask = encoder_attention_mask 144 | else: 145 | mixed_key_layer = self.key(hidden_states) 146 | mixed_value_layer = self.value(hidden_states) 147 | 148 | query_layer = self.transpose_for_scores(mixed_query_layer) 149 | key_layer = self.transpose_for_scores(mixed_key_layer) 150 | value_layer = self.transpose_for_scores(mixed_value_layer) 151 | 152 | # Take the dot product between "query" and "key" to get the raw attention scores. 
153 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 154 | attention_scores = self.attn_scores_act_quantizer(attention_scores) 155 | 156 | if self.position_embedding_type in ('relative_key', 'relative_key_query'): 157 | raise NotImplementedError('current branch of computation is not yet supported') 158 | 159 | seq_length = hidden_states.size()[1] 160 | position_ids_l = torch.arange( 161 | seq_length, dtype=torch.long, device=hidden_states.device 162 | ).view(-1, 1) 163 | position_ids_r = torch.arange( 164 | seq_length, dtype=torch.long, device=hidden_states.device 165 | ).view(1, -1) 166 | distance = position_ids_l - position_ids_r 167 | positional_embedding = self.distance_embedding( 168 | distance + self.max_position_embeddings - 1 169 | ) 170 | # fp16 compatibility: 171 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) 172 | 173 | if self.position_embedding_type == "relative_key": 174 | relative_position_scores = torch.einsum( 175 | "bhld,lrd->bhlr", query_layer, positional_embedding 176 | ) 177 | attention_scores = attention_scores + relative_position_scores 178 | elif self.position_embedding_type == "relative_key_query": 179 | relative_position_scores_query = torch.einsum( 180 | "bhld,lrd->bhlr", query_layer, positional_embedding 181 | ) 182 | relative_position_scores_key = torch.einsum( 183 | "bhrd,lrd->bhlr", key_layer, positional_embedding 184 | ) 185 | attention_scores = ( 186 | attention_scores + relative_position_scores_query + relative_position_scores_key 187 | ) 188 | 189 | # NOTE: factor 1/d^0.5 can be absorbed into the previous act. quant. delta 190 | attention_scores /= math.sqrt(self.attention_head_size) 191 | 192 | if attention_mask is not None: 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() fn) 194 | attention_scores += attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | attention_probs = self.attn_probs_act_quantizer(attention_probs) 199 | 200 | # This is actually dropping out entire tokens to attend to, which might 201 | # seem a bit unusual, but is taken from the original Transformer paper. 202 | attention_probs = self.dropout(attention_probs) 203 | 204 | # Mask heads if we want to 205 | if head_mask is not None: 206 | attention_probs = attention_probs * head_mask 207 | 208 | context_layer = torch.matmul(attention_probs, value_layer) 209 | 210 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 211 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 212 | context_layer = context_layer.view(*new_context_layer_shape) 213 | context_layer = self.context_act_quantizer(context_layer) 214 | 215 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 216 | 217 | _tb_advance_global_step(self) 218 | return outputs 219 | 220 | 221 | class QuantizedBertSelfOutput(QuantizedModel): 222 | def __init__(self, org_model, **quant_params): 223 | self.quant_dict = quant_params['quant_dict'] 224 | 225 | # Exact same structure as for BertOutput. 226 | # Kept in order to be able to disable activation quantizer. 
227 | super().__init__() 228 | 229 | self.dense = quantize_model(org_model.dense, **quant_params) 230 | self.dropout = org_model.dropout 231 | 232 | # Activation quantizer 233 | self.res_act_quantizer = QuantizedActivation(**quant_params) 234 | 235 | # LN 236 | self.LayerNorm = quantize_model(org_model.LayerNorm, **quant_params) 237 | 238 | def forward(self, hidden_states, input_tensor): 239 | hidden_states = self.dense(hidden_states) 240 | hidden_states = self.dropout(hidden_states) 241 | 242 | hidden_states = hidden_states + input_tensor 243 | hidden_states = self.res_act_quantizer(hidden_states) 244 | 245 | hidden_states = self.LayerNorm(hidden_states) 246 | 247 | _tb_advance_global_step(self) 248 | return hidden_states 249 | 250 | 251 | class QuantizedBertOutput(QuantizedModel): 252 | def __init__(self, org_model, **quant_params): 253 | self.quant_dict = quant_params['quant_dict'] 254 | 255 | super().__init__() 256 | 257 | self.dense = quantize_model(org_model.dense, **quant_params) 258 | self.dropout = org_model.dropout 259 | self.res_act_quantizer = QuantizedActivation(**quant_params) 260 | 261 | # LN 262 | self.LayerNorm = quantize_model(org_model.LayerNorm, **quant_params) 263 | 264 | def forward(self, hidden_states, input_tensor): 265 | hidden_states = self.dense(hidden_states) 266 | hidden_states = self.dropout(hidden_states) 267 | 268 | _tb_advance_token_counters(self, input_tensor) 269 | _tb_hist(self, input_tensor, 'res_output_x') 270 | _tb_hist(self, hidden_states, 'res_output_h') 271 | 272 | hidden_states = hidden_states + input_tensor 273 | 274 | _tb_hist(self, hidden_states, 'res_output_x_h') 275 | 276 | hidden_states = self.res_act_quantizer(hidden_states) 277 | hidden_states = self.LayerNorm(hidden_states) 278 | 279 | _tb_advance_global_step(self) 280 | return hidden_states 281 | 282 | 283 | def quantize_intermediate(org_module, **quant_params): 284 | m_dense = org_module.dense 285 | m_act = org_module.intermediate_act_fn 286 | if not isinstance(m_act, nn.Module): 287 | if m_act == F.gelu: 288 | m_act = nn.GELU() 289 | else: 290 | raise NotImplementedError() 291 | return quantize_model(nn.Sequential(m_dense, m_act), **quant_params) 292 | 293 | 294 | class QuantizedBertLayer(QuantizedModel): 295 | def __init__(self, org_model, **quant_params): 296 | self.quant_dict = quant_params['quant_dict'] 297 | 298 | super().__init__() 299 | 300 | # copy attributes 301 | self.chunk_size_feed_forward = org_model.chunk_size_feed_forward 302 | self.seq_len_dim = org_model.seq_len_dim 303 | self.is_decoder = org_model.is_decoder 304 | self.add_cross_attention = org_model.add_cross_attention 305 | 306 | # quantized components 307 | attention_specials = { 308 | BertSelfAttention: QuantizedBertSelfAttention, 309 | BertSelfOutput: QuantizedBertSelfOutput, 310 | } 311 | self.attention = quantize_model( 312 | org_model.attention, specials=attention_specials, **quant_params 313 | ) 314 | if self.add_cross_attention: 315 | self.crossattention = quantize_model( 316 | org_model.crossattention, specials=attention_specials, **quant_params 317 | ) 318 | self.intermediate = quantize_intermediate(org_model.intermediate, **quant_params) 319 | self.output = QuantizedBertOutput(org_model.output, **quant_params) 320 | 321 | def forward( 322 | self, 323 | hidden_states, 324 | attention_mask=None, 325 | head_mask=None, 326 | encoder_hidden_states=None, 327 | encoder_attention_mask=None, 328 | past_key_value=None, 329 | output_attentions=False, 330 | ): 331 | attn_args = (hidden_states, attention_mask, 
head_mask) 332 | attn_kw = dict(output_attentions=output_attentions) 333 | 334 | self_attention_outputs = self.attention(*attn_args, **attn_kw) 335 | 336 | attention_output = self_attention_outputs[0] 337 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 338 | 339 | if self.is_decoder and encoder_hidden_states is not None: 340 | raise NotImplementedError('current branch of computation is not yet supported') 341 | 342 | assert hasattr(self, "crossattention"), ( 343 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 344 | f"cross-attention layers by setting `config.add_cross_attention=True`" 345 | ) 346 | cross_attention_outputs = self.crossattention( 347 | attention_output, 348 | attention_mask, 349 | head_mask, 350 | encoder_hidden_states, 351 | encoder_attention_mask, 352 | output_attentions, 353 | ) 354 | attention_output = cross_attention_outputs[0] 355 | # add cross attentions if we output attention weights: 356 | outputs = outputs + cross_attention_outputs[1:] 357 | 358 | assert self.chunk_size_feed_forward == 0 # below call is a no-op in that case 359 | layer_output = apply_chunking_to_forward( 360 | self.feed_forward_chunk, 361 | self.chunk_size_feed_forward, 362 | self.seq_len_dim, 363 | attention_output, 364 | ) 365 | outputs = (layer_output,) + outputs 366 | return outputs 367 | 368 | def feed_forward_chunk(self, attention_output): 369 | intermediate_output = self.intermediate(attention_output) 370 | layer_output = self.output(intermediate_output, attention_output) 371 | return layer_output 372 | 373 | 374 | class QuantizedBertPooler(QuantizedModel): 375 | def __init__(self, org_model, **quant_params): 376 | super().__init__() 377 | 378 | self.dense_act = quantize_model( 379 | nn.Sequential(org_model.dense, org_model.activation), **quant_params 380 | ) 381 | 382 | def forward(self, hidden_states): 383 | # We "pool" the model by simply taking the hidden state corresponding 384 | # to the first token. 385 | first_token_tensor = hidden_states[:, 0] 386 | pooled_output = self.dense_act(first_token_tensor) 387 | 388 | _tb_advance_global_step(self) 389 | return pooled_output 390 | 391 | 392 | class QuantizedBertModel(QuantizedModel, ModuleUtilsMixin): 393 | def __init__(self, org_model, **quant_params): 394 | super().__init__() 395 | 396 | self.config = org_model.config 397 | 398 | self.embeddings = QuantizedBertEmbeddings(org_model.embeddings, **quant_params) 399 | self.encoder = quantize_model( 400 | org_model.encoder, specials={BertLayer: QuantizedBertLayer}, **quant_params 401 | ) 402 | self.pooler = ( 403 | QuantizedBertPooler(org_model.pooler, **quant_params) 404 | if org_model.pooler is not None 405 | else None 406 | ) 407 | 408 | def get_input_embeddings(self): 409 | return self.embeddings.word_embeddings 410 | 411 | def set_input_embeddings(self, value): 412 | self.embeddings.word_embeddings = value 413 | 414 | def _prune_heads(self, heads_to_prune): 415 | """ 416 | Prunes heads of the model. 417 | 418 | Parameters 419 | ---------- 420 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 421 | See base class PreTrainedModel. 
422 | """ 423 | for layer, heads in heads_to_prune.items(): 424 | self.encoder.layer[layer].attention.prune_heads(heads) 425 | 426 | def forward( 427 | self, 428 | input_ids=None, 429 | attention_mask=None, 430 | token_type_ids=None, 431 | position_ids=None, 432 | head_mask=None, 433 | inputs_embeds=None, 434 | encoder_hidden_states=None, 435 | encoder_attention_mask=None, 436 | output_attentions=None, 437 | output_hidden_states=None, 438 | return_dict=None, 439 | ): 440 | output_attentions = ( 441 | output_attentions if output_attentions is not None else self.config.output_attentions 442 | ) 443 | output_hidden_states = ( 444 | output_hidden_states 445 | if output_hidden_states is not None 446 | else self.config.output_hidden_states 447 | ) 448 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 449 | 450 | if input_ids is not None and inputs_embeds is not None: 451 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 452 | elif input_ids is not None: 453 | input_shape = input_ids.size() 454 | elif inputs_embeds is not None: 455 | input_shape = inputs_embeds.size()[:-1] 456 | else: 457 | raise ValueError("You have to specify either input_ids or inputs_embeds") 458 | 459 | device = input_ids.device if input_ids is not None else inputs_embeds.device 460 | 461 | if attention_mask is None: 462 | attention_mask = torch.ones(input_shape, device=device) 463 | if token_type_ids is None: 464 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 465 | 466 | # We can provide a self-attention mask of dimensions 467 | # [batch_size, from_seq_length, to_seq_length] 468 | # ourselves in which case we just need to make it broadcastable to all heads. 469 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( 470 | attention_mask, input_shape, device 471 | ) 472 | 473 | # If a 2D or 3D attention mask is provided for the cross-attention 474 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 475 | if self.config.is_decoder and encoder_hidden_states is not None: 476 | raise NotImplementedError('current branch of computation is not yet supported') 477 | 478 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 479 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 480 | if encoder_attention_mask is None: 481 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 482 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 483 | else: 484 | encoder_extended_attention_mask = None 485 | 486 | # Prepare head mask if needed 487 | # 1.0 in head_mask indicate we keep the head 488 | # attention_probs has shape bsz x n_heads x N x N 489 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 490 | # and head_mask is converted to shape 491 | # [num_hidden_layers x batch x num_heads x seq_length x seq_length] 492 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 493 | 494 | embedding_output = self.embeddings( 495 | input_ids=input_ids, 496 | position_ids=position_ids, 497 | token_type_ids=token_type_ids, 498 | inputs_embeds=inputs_embeds, 499 | ) 500 | encoder_outputs = self.encoder( 501 | embedding_output, 502 | attention_mask=extended_attention_mask, 503 | head_mask=head_mask, 504 | encoder_hidden_states=encoder_hidden_states, 505 | encoder_attention_mask=encoder_extended_attention_mask, 506 | output_attentions=output_attentions, 
507 | output_hidden_states=output_hidden_states, 508 | return_dict=return_dict, 509 | ) 510 | sequence_output = encoder_outputs[0] 511 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 512 | 513 | if not return_dict: 514 | return (sequence_output, pooled_output) + encoder_outputs[1:] 515 | 516 | return BaseModelOutputWithPoolingAndCrossAttentions( 517 | last_hidden_state=sequence_output, 518 | pooler_output=pooled_output, 519 | hidden_states=encoder_outputs.hidden_states, 520 | attentions=encoder_outputs.attentions, 521 | cross_attentions=encoder_outputs.cross_attentions, 522 | ) 523 | 524 | 525 | class QuantizedBertForSequenceClassification(QuantizedModel): 526 | def __init__(self, org_model, quant_setup=None, **quant_params): 527 | super().__init__() 528 | 529 | self.num_labels = org_model.num_labels 530 | self.config = org_model.config 531 | 532 | if hasattr(org_model, 'bert'): 533 | self.bert = QuantizedBertModel(org_model=org_model.bert, **quant_params) 534 | if hasattr(org_model, 'dropout'): 535 | self.dropout = org_model.dropout 536 | 537 | quant_params_ = quant_params.copy() 538 | 539 | if quant_setup == 'MSE_logits': 540 | quant_params_['act_range_method'] = RangeEstimators.MSE 541 | quant_params_['act_range_options'] = dict(opt_method=OptMethod.golden_section) 542 | self.classifier = quantize_model(org_model.classifier, **quant_params_) 543 | 544 | elif quant_setup == 'FP_logits': 545 | print('Do not quantize output of FC layer') 546 | 547 | self.classifier = quantize_model(org_model.classifier, **quant_params_) 548 | # no activation quantization of logits: 549 | self.classifier.activation_quantizer = FP32Acts() 550 | 551 | elif quant_setup == 'all': 552 | self.classifier = quantize_model(org_model.classifier, **quant_params_) 553 | 554 | else: 555 | raise ValueError("Quantization setup '{}' not supported.".format(quant_setup)) 556 | 557 | def forward( 558 | self, 559 | input_ids=None, 560 | attention_mask=None, 561 | token_type_ids=None, 562 | position_ids=None, 563 | head_mask=None, 564 | inputs_embeds=None, 565 | labels=None, 566 | output_attentions=None, 567 | output_hidden_states=None, 568 | return_dict=None, 569 | ): 570 | if isinstance(input_ids, tuple): 571 | if len(input_ids) == 2: 572 | input_ids, attention_mask = input_ids 573 | elif len(input_ids) == 3: 574 | input_ids, attention_mask, token_type_ids = input_ids 575 | elif len(input_ids) == 4: 576 | input_ids, attention_mask, token_type_ids, labels = input_ids 577 | else: 578 | raise ValueError('cannot interpret input tuple, use dict instead') 579 | 580 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 581 | 582 | outputs = self.bert( 583 | input_ids, 584 | attention_mask=attention_mask, 585 | token_type_ids=token_type_ids, 586 | position_ids=position_ids, 587 | head_mask=head_mask, 588 | inputs_embeds=inputs_embeds, 589 | output_attentions=output_attentions, 590 | output_hidden_states=output_hidden_states, 591 | return_dict=return_dict, 592 | ) 593 | 594 | pooled_output = outputs[1] 595 | pooled_output = self.dropout(pooled_output) 596 | 597 | logits = self.classifier(pooled_output) 598 | 599 | if self.num_labels == 1: 600 | logits = torch.clamp(logits, 0.0, 5.0) 601 | 602 | loss = None 603 | if labels is not None: 604 | if self.num_labels == 1: 605 | # We are doing regression 606 | loss_fct = MSELoss() 607 | loss = loss_fct(logits.view(-1), labels.view(-1)) 608 | else: 609 | loss_fct = CrossEntropyLoss() 610 | loss = loss_fct(logits.view(-1, 
self.num_labels), labels.view(-1)) 611 | 612 | if not return_dict: 613 | output = (logits,) + outputs[2:] 614 | return ((loss,) + output) if loss is not None else output 615 | 616 | _tb_advance_global_step(self) 617 | return SequenceClassifierOutput( 618 | loss=loss, 619 | logits=logits, 620 | hidden_states=outputs.hidden_states, 621 | attentions=outputs.attentions, 622 | ) 623 | -------------------------------------------------------------------------------- /models/quantized_mobilebert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import CrossEntropyLoss, MSELoss 10 | 11 | from transformers.models.mobilebert.modeling_mobilebert import ( 12 | BaseModelOutputWithPooling, 13 | BottleneckLayer, 14 | FFNLayer, 15 | MobileBertLayer, 16 | MobileBertSelfAttention, 17 | MobileBertSelfOutput, 18 | NoNorm, 19 | ) 20 | from transformers.modeling_outputs import SequenceClassifierOutput 21 | from transformers.modeling_utils import ModuleUtilsMixin 22 | 23 | from quantization.autoquant_utils import quantize_model, quantize_module_list 24 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts 25 | from quantization.base_quantized_model import QuantizedModel 26 | from quantization.hijacker import QuantizationHijacker 27 | from quantization.range_estimators import RangeEstimators, OptMethod 28 | from utils import DotDict, _tb_advance_global_step, _tb_advance_token_counters, _tb_hist 29 | 30 | 31 | DEFAULT_QUANT_DICT = { 32 | # Embeddings 33 | 'sum_input_pos_embd': True, 34 | 'sum_token_type_embd': True, 35 | 36 | # Attention 37 | 'attn_scores': True, 38 | 'attn_probs': True, 39 | 'attn_probs_n_bits_act': None, 40 | 'attn_probs_act_range_method': None, 41 | 'attn_probs_act_range_options': None, 42 | 'attn_output': True, 43 | 44 | # Residual connections 45 | 'res_self_output': True, 46 | 'res_output': True, 47 | 'res_output_bottleneck': True, 48 | 'res_ffn_output': True, 49 | } 50 | 51 | 52 | def _make_quant_dict(partial_dict): 53 | quant_dict = DEFAULT_QUANT_DICT.copy() 54 | quant_dict.update(partial_dict) 55 | return DotDict(quant_dict) 56 | 57 | 58 | class QuantNoNorm(QuantizationHijacker): 59 | def __init__(self, org_model, *args, activation=None, **kwargs): 60 | super().__init__(*args, activation=activation, **kwargs) 61 | self.weight = org_model.weight 62 | self.bias = org_model.bias 63 | 64 | def forward(self, x, offsets=None): 65 | weight, bias = self.weight, self.bias 66 | if self._quant_w: 67 | weight = self.weight_quantizer(weight) 68 | bias = self.weight_quantizer(bias) 69 | 70 | res = x * weight + bias 71 | res = self.quantize_activations(res) 72 | return res 73 | 74 | 75 | class QuantizedMobileBertEmbeddings(QuantizedModel): 76 | def __init__(self, org_model, **quant_params): 77 | super().__init__() 78 | 79 | # copy attributes 80 | self.trigram_input = org_model.trigram_input 81 | self.embedding_size = org_model.embedding_size 82 | self.hidden_size = org_model.hidden_size 83 | 84 | # quantized modules 85 | self.word_embeddings = quantize_model(org_model.word_embeddings, **quant_params) 86 | self.position_embeddings = quantize_model(org_model.position_embeddings, **quant_params) 87 | self.token_type_embeddings = quantize_model(org_model.token_type_embeddings, **quant_params) 88 | 89 | self.embedding_transformation = quantize_model( 90 
| org_model.embedding_transformation, **quant_params 91 | ) 92 | 93 | assert isinstance(org_model.LayerNorm, NoNorm) 94 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params) 95 | 96 | self.dropout = org_model.dropout 97 | 98 | position_ids = org_model.position_ids 99 | if position_ids is not None: 100 | self.register_buffer('position_ids', position_ids) 101 | else: 102 | self.position_ids = position_ids 103 | 104 | # activation quantizers 105 | self.quant_dict = _make_quant_dict(quant_params['quant_dict']) 106 | self.sum_input_pos_embd_act_quantizer = ( 107 | QuantizedActivation(**quant_params) 108 | if self.quant_dict.sum_input_pos_embd 109 | else FP32Acts() 110 | ) 111 | self.sum_token_type_embd_act_quantizer = ( 112 | QuantizedActivation(**quant_params) 113 | if self.quant_dict.sum_token_type_embd 114 | else FP32Acts() 115 | ) 116 | 117 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): 118 | if input_ids is not None: 119 | input_shape = input_ids.size() 120 | else: 121 | input_shape = inputs_embeds.size()[:-1] 122 | 123 | seq_length = input_shape[1] 124 | 125 | if position_ids is None: 126 | position_ids = self.position_ids[:, :seq_length] 127 | 128 | if token_type_ids is None: 129 | token_type_ids = torch.zeros( 130 | input_shape, dtype=torch.long, device=self.position_ids.device 131 | ) 132 | if inputs_embeds is None: 133 | inputs_embeds = self.word_embeddings(input_ids) # (B, T, 128) 134 | 135 | if self.trigram_input: 136 | # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited 137 | # Devices (https://arxiv.org/abs/2004.02984) 138 | # 139 | # The embedding table in BERT models accounts for a substantial proportion of model size. To compress 140 | # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT. 141 | # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512 142 | # dimensional output. 143 | inputs_embeds = torch.cat( 144 | [ 145 | F.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0), 146 | inputs_embeds, 147 | F.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0), 148 | ], 149 | dim=2, 150 | ) # (B, T, 384) 151 | 152 | if self.trigram_input or self.embedding_size != self.hidden_size: 153 | inputs_embeds = self.embedding_transformation(inputs_embeds) # (B, T, 512) 154 | 155 | # Add positional embeddings and token type embeddings, then layer # normalize and 156 | # perform dropout. 
157 | position_embeddings = self.position_embeddings(position_ids) 158 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 159 | 160 | embeddings = self.sum_input_pos_embd_act_quantizer(inputs_embeds + position_embeddings) 161 | embeddings = self.sum_token_type_embd_act_quantizer(embeddings + token_type_embeddings) 162 | embeddings = self.LayerNorm(embeddings) 163 | embeddings = self.dropout(embeddings) 164 | return embeddings 165 | 166 | 167 | class QuantizedMobileBertSelfAttention(QuantizedModel): 168 | def __init__(self, org_model, **quant_params): 169 | super().__init__() 170 | 171 | # copy attributes 172 | self.num_attention_heads = org_model.num_attention_heads 173 | self.attention_head_size = org_model.attention_head_size 174 | self.all_head_size = org_model.all_head_size 175 | 176 | # quantized modules 177 | self.query = quantize_model(org_model.query, **quant_params) 178 | self.key = quantize_model(org_model.key, **quant_params) 179 | self.value = quantize_model(org_model.value, **quant_params) 180 | self.dropout = org_model.dropout 181 | 182 | # activation quantizers 183 | self.quant_dict = _make_quant_dict(quant_params['quant_dict']) 184 | self.attn_scores_act_quantizer = ( 185 | QuantizedActivation(**quant_params) if self.quant_dict.attn_scores else FP32Acts() 186 | ) 187 | 188 | quant_params_ = quant_params.copy() 189 | if self.quant_dict.attn_probs_n_bits_act is not None: 190 | quant_params_['n_bits_act'] = self.quant_dict.attn_probs_n_bits_act 191 | if self.quant_dict.attn_probs_act_range_method is not None: 192 | quant_params_['act_range_method'] = RangeEstimators[ 193 | self.quant_dict.attn_probs_act_range_method 194 | ] 195 | if self.quant_dict.attn_probs_act_range_options is not None: 196 | act_range_options = self.quant_dict.attn_probs_act_range_options 197 | if 'opt_method' in act_range_options and not isinstance(act_range_options['opt_method'], 198 | OptMethod): 199 | act_range_options['opt_method'] = OptMethod[act_range_options['opt_method']] 200 | quant_params_['act_range_options'] = act_range_options 201 | self.attn_probs_act_quantizer = ( 202 | QuantizedActivation(**quant_params_) if self.quant_dict.attn_probs else FP32Acts() 203 | ) 204 | 205 | self.attn_output_act_quantizer = ( 206 | QuantizedActivation(**quant_params) if self.quant_dict.attn_output else FP32Acts() 207 | ) 208 | 209 | def transpose_for_scores(self, x): 210 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 211 | x = x.view(*new_x_shape) 212 | return x.permute(0, 2, 1, 3) 213 | 214 | def forward( 215 | self, 216 | query_tensor, 217 | key_tensor, 218 | value_tensor, 219 | attention_mask=None, 220 | head_mask=None, 221 | output_attentions=None, 222 | ): 223 | mixed_query_layer = self.query(query_tensor) 224 | mixed_key_layer = self.key(key_tensor) 225 | mixed_value_layer = self.value(value_tensor) 226 | 227 | query_layer = self.transpose_for_scores(mixed_query_layer) 228 | key_layer = self.transpose_for_scores(mixed_key_layer) 229 | value_layer = self.transpose_for_scores(mixed_value_layer) 230 | 231 | # Take the dot product between "query" and "key" to get the raw attention scores. 232 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 233 | attention_scores = self.attn_scores_act_quantizer(attention_scores) 234 | 235 | # NOTE: factor 1/d^0.5 can be absorbed into the previous act. quant. 
delta 236 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 237 | 238 | if attention_mask is not None: 239 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 240 | attention_scores = attention_scores + attention_mask 241 | 242 | # Normalize the attention scores to probabilities. 243 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 244 | attention_probs = self.attn_probs_act_quantizer(attention_probs) 245 | 246 | # This is actually dropping out entire tokens to attend to, which might 247 | # seem a bit unusual, but is taken from the original Transformer paper. 248 | attention_probs = self.dropout(attention_probs) 249 | 250 | # Mask heads if we want to 251 | if head_mask is not None: 252 | attention_probs = attention_probs * head_mask 253 | 254 | context_layer = torch.matmul(attention_probs, value_layer) 255 | context_layer = self.attn_output_act_quantizer(context_layer) 256 | 257 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 258 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 259 | context_layer = context_layer.view(*new_context_layer_shape) 260 | 261 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 262 | return outputs 263 | 264 | 265 | class QuantizedMobileBertSelfOutput(QuantizedModel): 266 | def __init__(self, org_model, **quant_params): 267 | super().__init__() 268 | 269 | # copy attributes 270 | self.use_bottleneck = org_model.use_bottleneck 271 | 272 | # quantized modules 273 | self.dense = quantize_model(org_model.dense, **quant_params) 274 | 275 | assert isinstance(org_model.LayerNorm, NoNorm) 276 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params) 277 | 278 | if not self.use_bottleneck: 279 | self.dropout = org_model.dropout 280 | 281 | # activation quantizers 282 | self.quant_dict = _make_quant_dict(quant_params['quant_dict']) 283 | self.res_act_quantizer = ( 284 | QuantizedActivation(**quant_params) if self.quant_dict.res_self_output else FP32Acts() 285 | ) 286 | 287 | def forward(self, hidden_states, residual_tensor): 288 | layer_outputs = self.dense(hidden_states) 289 | if not self.use_bottleneck: 290 | layer_outputs = self.dropout(layer_outputs) 291 | 292 | _tb_advance_token_counters(self, layer_outputs) 293 | _tb_hist(self, layer_outputs, 'res_self_output_h') 294 | _tb_hist(self, residual_tensor, 'res_self_output_x') 295 | 296 | layer_outputs = layer_outputs + residual_tensor 297 | 298 | _tb_hist(self, residual_tensor, 'res_self_output_x_h') 299 | 300 | layer_outputs = self.res_act_quantizer(layer_outputs) 301 | layer_outputs = self.LayerNorm(layer_outputs) 302 | 303 | _tb_advance_global_step(self) 304 | return layer_outputs 305 | 306 | 307 | def quantize_intermediate(org_module, **quant_params): 308 | m_dense = org_module.dense 309 | m_act = org_module.intermediate_act_fn 310 | if not isinstance(m_act, nn.Module): 311 | if m_act == F.gelu: 312 | m_act = nn.GELU() 313 | elif m_act == F.relu: 314 | m_act = nn.ReLU() 315 | else: 316 | raise NotImplementedError() 317 | return quantize_model(nn.Sequential(m_dense, m_act), **quant_params) 318 | 319 | 320 | class QuantizedOutputBottleneck(QuantizedModel): 321 | def __init__(self, org_model, **quant_params): 322 | super().__init__() 323 | 324 | self.dense = quantize_model(org_model.dense, **quant_params) 325 | assert isinstance(org_model.LayerNorm, NoNorm) 326 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params) 327 | self.dropout = 
org_model.dropout 328 | 329 | # activation quantizers 330 | self.quant_dict = _make_quant_dict(quant_params['quant_dict']) 331 | self.res_act_quantizer = ( 332 | QuantizedActivation(**quant_params) 333 | if self.quant_dict.res_output_bottleneck 334 | else FP32Acts() 335 | ) 336 | 337 | def forward(self, hidden_states, residual_tensor): 338 | layer_outputs = self.dense(hidden_states) 339 | layer_outputs = self.dropout(layer_outputs) 340 | 341 | _tb_advance_token_counters(self, layer_outputs) 342 | _tb_hist(self, layer_outputs, 'res_layer_h') 343 | _tb_hist(self, residual_tensor, 'res_layer_x') 344 | 345 | layer_outputs = layer_outputs + residual_tensor 346 | 347 | _tb_hist(self, layer_outputs, 'res_layer_x_h') 348 | 349 | layer_outputs = self.res_act_quantizer(layer_outputs) 350 | layer_outputs = self.LayerNorm(layer_outputs) 351 | 352 | _tb_advance_global_step(self) 353 | return layer_outputs 354 | 355 | 356 | class QuantizedMobileBertOutput(QuantizedModel): 357 | def __init__(self, org_model, **quant_params): 358 | super().__init__() 359 | 360 | # copy attributes 361 | self.use_bottleneck = org_model.use_bottleneck 362 | 363 | # quantized modules 364 | self.dense = quantize_model(org_model.dense, **quant_params) 365 | assert isinstance(org_model.LayerNorm, NoNorm) 366 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params) 367 | 368 | if not self.use_bottleneck: 369 | self.dropout = org_model.dropout 370 | else: 371 | self.bottleneck = QuantizedOutputBottleneck( 372 | org_model=org_model.bottleneck, **quant_params 373 | ) 374 | 375 | # activation quantizers 376 | self.quant_dict = _make_quant_dict(quant_params['quant_dict']) 377 | self.res_act_quantizer = ( 378 | QuantizedActivation(**quant_params) if self.quant_dict.res_output else FP32Acts() 379 | ) 380 | 381 | def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2): 382 | layer_output = self.dense(intermediate_states) 383 | if not self.use_bottleneck: 384 | layer_output = self.dropout(layer_output) 385 | layer_output = layer_output + residual_tensor_1 386 | layer_output = self.res_act_quantizer(layer_output) 387 | layer_output = self.LayerNorm(layer_output) 388 | else: 389 | _tb_advance_token_counters(self, layer_output) 390 | _tb_hist(self, layer_output, 'res_interm_h') 391 | _tb_hist(self, residual_tensor_1, 'res_interm_x') 392 | 393 | layer_output = layer_output + residual_tensor_1 394 | 395 | _tb_hist(self, layer_output, 'res_interm_x_h') 396 | 397 | layer_output = self.res_act_quantizer(layer_output) 398 | layer_output = self.LayerNorm(layer_output) 399 | layer_output = self.bottleneck(layer_output, residual_tensor_2) 400 | 401 | _tb_advance_global_step(self) 402 | return layer_output 403 | 404 | 405 | class QuantizedBottleneckLayer(QuantizedModel): 406 | def __init__(self, org_model, **quant_params): 407 | super().__init__() 408 | 409 | self.dense = quantize_model(org_model.dense, **quant_params) 410 | assert isinstance(org_model.LayerNorm, NoNorm) 411 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params) 412 | 413 | def forward(self, hidden_states): 414 | layer_input = self.dense(hidden_states) 415 | layer_input = self.LayerNorm(layer_input) 416 | return layer_input 417 | 418 | 419 | class QuantizedFFNOutput(QuantizedModel): 420 | def __init__(self, org_model, **quant_params): 421 | super().__init__() 422 | 423 | self.dense = quantize_model(org_model.dense, **quant_params) 424 | assert isinstance(org_model.LayerNorm, NoNorm) 425 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, 
**quant_params) 426 | 427 | # activation quantizers 428 | self.quant_dict = _make_quant_dict(quant_params['quant_dict']) 429 | self.res_act_quantizer = ( 430 | QuantizedActivation(**quant_params) if self.quant_dict.res_ffn_output else FP32Acts() 431 | ) 432 | 433 | def forward(self, hidden_states, residual_tensor): 434 | layer_outputs = self.dense(hidden_states) 435 | 436 | _tb_advance_token_counters(self, layer_outputs) 437 | num_ffn = self.ffn_idx + 1 438 | _tb_hist(self, layer_outputs, f'res_ffn{num_ffn}_h') 439 | _tb_hist(self, residual_tensor, f'res_ffn{num_ffn}_x') 440 | 441 | layer_outputs = layer_outputs + residual_tensor 442 | 443 | _tb_hist(self, layer_outputs, f'res_ffn{num_ffn}_x_h') 444 | 445 | layer_outputs = self.res_act_quantizer(layer_outputs) 446 | layer_outputs = self.LayerNorm(layer_outputs) 447 | 448 | _tb_advance_global_step(self) 449 | return layer_outputs 450 | 451 | 452 | class QuantizedFFNLayer(QuantizedModel): 453 | def __init__(self, org_model, **quant_params): 454 | super().__init__() 455 | 456 | self.intermediate = quantize_intermediate(org_model.intermediate, **quant_params) 457 | self.output = QuantizedFFNOutput(org_model.output, **quant_params) 458 | 459 | def forward(self, hidden_states): 460 | intermediate_output = self.intermediate(hidden_states) 461 | layer_outputs = self.output(intermediate_output, hidden_states) 462 | return layer_outputs 463 | 464 | 465 | class QuantizedMobileBertLayer(QuantizedModel): 466 | def __init__(self, org_model, **quant_params): 467 | super().__init__() 468 | 469 | # copy 470 | self.use_bottleneck = org_model.use_bottleneck 471 | self.num_feedforward_networks = org_model.num_feedforward_networks 472 | 473 | # quantized modules 474 | attention_specials = { 475 | MobileBertSelfAttention: QuantizedMobileBertSelfAttention, 476 | MobileBertSelfOutput: QuantizedMobileBertSelfOutput, 477 | } 478 | self.attention = quantize_model( 479 | org_model.attention, specials=attention_specials, **quant_params 480 | ) 481 | self.intermediate = quantize_intermediate(org_model.intermediate, **quant_params) 482 | self.output = QuantizedMobileBertOutput(org_model.output, **quant_params) 483 | 484 | if self.use_bottleneck: 485 | self.bottleneck = quantize_model( 486 | org_model.bottleneck, 487 | specials={BottleneckLayer: QuantizedBottleneckLayer}, 488 | **quant_params, 489 | ) 490 | if getattr(org_model, 'ffn', None) is not None: 491 | self.ffn = quantize_module_list( 492 | org_model.ffn, specials={FFNLayer: QuantizedFFNLayer}, **quant_params 493 | ) 494 | 495 | def forward( 496 | self, 497 | hidden_states, 498 | attention_mask=None, 499 | head_mask=None, 500 | output_attentions=None, 501 | ): 502 | if self.use_bottleneck: 503 | query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) 504 | else: 505 | query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4 506 | 507 | self_attention_outputs = self.attention( 508 | query_tensor, 509 | key_tensor, 510 | value_tensor, 511 | layer_input, 512 | attention_mask, 513 | head_mask, 514 | output_attentions=output_attentions, 515 | ) 516 | attention_output = self_attention_outputs[0] 517 | s = (attention_output,) 518 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 519 | 520 | if self.num_feedforward_networks != 1: 521 | for i, ffn_module in enumerate(self.ffn): 522 | # attach index for TB vis 523 | for m in ffn_module.modules(): 524 | m.ffn_idx = i 525 | 526 | attention_output = ffn_module(attention_output) 527 | s += 
(attention_output,) 528 | 529 | intermediate_output = self.intermediate(attention_output) 530 | layer_output = self.output(intermediate_output, attention_output, hidden_states) 531 | outputs = ( 532 | (layer_output,) 533 | + outputs 534 | + ( 535 | torch.tensor(1000), 536 | query_tensor, 537 | key_tensor, 538 | value_tensor, 539 | layer_input, 540 | attention_output, 541 | intermediate_output, 542 | ) 543 | + s 544 | ) 545 | return outputs 546 | 547 | 548 | class QuantizedMobileBertPooler(QuantizedModel): 549 | def __init__(self, org_model, **quant_params): 550 | super().__init__() 551 | 552 | self.do_activate = org_model.do_activate 553 | if self.do_activate: 554 | self.dense_act = quantize_model( 555 | nn.Sequential(org_model.dense, nn.Tanh()), **quant_params 556 | ) 557 | 558 | def forward(self, hidden_states): 559 | # We "pool" the model by simply taking the hidden state corresponding 560 | # to the first token. 561 | first_token_tensor = hidden_states[:, 0] 562 | if not self.do_activate: 563 | return first_token_tensor 564 | else: 565 | pooled_output = self.dense_act(first_token_tensor) 566 | return pooled_output 567 | 568 | 569 | class QuantizedMobileBertModel(QuantizedModel, ModuleUtilsMixin): 570 | def __init__(self, org_model, **quant_params): 571 | super().__init__() 572 | 573 | self.config = org_model.config 574 | 575 | self.embeddings = QuantizedMobileBertEmbeddings(org_model.embeddings, **quant_params) 576 | self.encoder = quantize_model( 577 | org_model.encoder, specials={MobileBertLayer: QuantizedMobileBertLayer}, **quant_params 578 | ) 579 | self.pooler = ( 580 | QuantizedMobileBertPooler(org_model.pooler, **quant_params) 581 | if org_model.pooler is not None 582 | else None 583 | ) 584 | 585 | def get_input_embeddings(self): 586 | return self.embeddings.word_embeddings 587 | 588 | def set_input_embeddings(self, value): 589 | self.embeddings.word_embeddings = value 590 | 591 | def _prune_heads(self, heads_to_prune): 592 | """ 593 | Prunes heads of the model. 594 | 595 | Parameters 596 | ---------- 597 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 598 | See base class PreTrainedModel. 
599 | """ 600 | for layer, heads in heads_to_prune.items(): 601 | self.encoder.layer[layer].attention.prune_heads(heads) 602 | 603 | def forward( 604 | self, 605 | input_ids=None, 606 | attention_mask=None, 607 | token_type_ids=None, 608 | position_ids=None, 609 | head_mask=None, 610 | inputs_embeds=None, 611 | output_hidden_states=None, 612 | output_attentions=None, 613 | return_dict=None, 614 | ): 615 | output_attentions = ( 616 | output_attentions if output_attentions is not None else self.config.output_attentions 617 | ) 618 | output_hidden_states = ( 619 | output_hidden_states 620 | if output_hidden_states is not None 621 | else self.config.output_hidden_states 622 | ) 623 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 624 | 625 | if input_ids is not None and inputs_embeds is not None: 626 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 627 | elif input_ids is not None: 628 | input_shape = input_ids.size() 629 | elif inputs_embeds is not None: 630 | input_shape = inputs_embeds.size()[:-1] 631 | else: 632 | raise ValueError("You have to specify either input_ids or inputs_embeds") 633 | 634 | device = input_ids.device if input_ids is not None else inputs_embeds.device 635 | 636 | if attention_mask is None: 637 | attention_mask = torch.ones(input_shape, device=device) 638 | if token_type_ids is None: 639 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 640 | 641 | # We can provide a self-attention mask of dimensions 642 | # [batch_size, from_seq_length, to_seq_length] 643 | # ourselves in which case we just need to make it broadcastable to all heads. 644 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( 645 | attention_mask, input_shape, self.device 646 | ) 647 | 648 | # Prepare head mask if needed 649 | # 1.0 in head_mask indicate we keep the head 650 | # attention_probs has shape bsz x n_heads x N x N 651 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 652 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 653 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 654 | 655 | embedding_output = self.embeddings( 656 | input_ids=input_ids, 657 | position_ids=position_ids, 658 | token_type_ids=token_type_ids, 659 | inputs_embeds=inputs_embeds, 660 | ) 661 | encoder_outputs = self.encoder( 662 | embedding_output, 663 | attention_mask=extended_attention_mask, 664 | head_mask=head_mask, 665 | output_attentions=output_attentions, 666 | output_hidden_states=output_hidden_states, 667 | return_dict=return_dict, 668 | ) 669 | 670 | sequence_output = encoder_outputs[0] 671 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 672 | 673 | if not return_dict: 674 | return (sequence_output, pooled_output) + encoder_outputs[1:] 675 | 676 | return BaseModelOutputWithPooling( 677 | last_hidden_state=sequence_output, 678 | pooler_output=pooled_output, 679 | hidden_states=encoder_outputs.hidden_states, 680 | attentions=encoder_outputs.attentions, 681 | ) 682 | 683 | 684 | class QuantizedMobileBertForSequenceClassification(QuantizedModel): 685 | def __init__(self, org_model, quant_setup=None, **quant_params): 686 | super().__init__() 687 | 688 | self.num_labels = org_model.num_labels 689 | self.config = org_model.config 690 | 691 | self.mobilebert = QuantizedMobileBertModel(org_model=org_model.mobilebert, **quant_params) 692 | self.dropout 
= org_model.dropout 693 | self.classifier = quantize_model(org_model.classifier, **quant_params) 694 | 695 | if quant_setup == 'FP_logits': 696 | print('Do not quantize output of FC layer') 697 | # no activation quantization of logits: 698 | self.classifier.activation_quantizer = FP32Acts() 699 | elif quant_setup is not None and quant_setup != 'all': 700 | raise ValueError("Quantization setup '{}' not supported.".format(quant_setup)) 701 | 702 | def forward( 703 | self, 704 | input_ids=None, 705 | attention_mask=None, 706 | token_type_ids=None, 707 | position_ids=None, 708 | head_mask=None, 709 | inputs_embeds=None, 710 | labels=None, 711 | output_attentions=None, 712 | output_hidden_states=None, 713 | return_dict=None, 714 | ): 715 | r""" 716 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 717 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 718 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 719 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 720 | """ 721 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 722 | 723 | outputs = self.mobilebert( 724 | input_ids, 725 | attention_mask=attention_mask, 726 | token_type_ids=token_type_ids, 727 | position_ids=position_ids, 728 | head_mask=head_mask, 729 | inputs_embeds=inputs_embeds, 730 | output_attentions=output_attentions, 731 | output_hidden_states=output_hidden_states, 732 | return_dict=return_dict, 733 | ) 734 | pooled_output = outputs[1] 735 | pooled_output = self.dropout(pooled_output) 736 | 737 | # NB: optionally can keep final logits un-quantized, if only used for prediction 738 | # (can be enabled via --quant-setup FP_logits) 739 | logits = self.classifier(pooled_output) 740 | 741 | loss = None 742 | if labels is not None: 743 | if self.num_labels == 1: 744 | # We are doing regression 745 | loss_fct = MSELoss() 746 | loss = loss_fct(logits.view(-1), labels.view(-1)) 747 | else: 748 | loss_fct = CrossEntropyLoss() 749 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 750 | 751 | if not return_dict: 752 | output = (logits,) + outputs[2:] 753 | return ((loss,) + output) if loss is not None else output 754 | 755 | return SequenceClassifierOutput( 756 | loss=loss, 757 | logits=logits, 758 | hidden_states=outputs.hidden_states, 759 | attentions=outputs.attentions, 760 | ) 761 | -------------------------------------------------------------------------------- /models/quantized_roberta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
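The classification heads above (and the RoBERTa head later in this file) share the same loss convention: `config.num_labels == 1` is treated as regression with an MSE loss on the squeezed logits, anything larger as classification with cross-entropy. A minimal, self-contained illustration of the shapes involved (the GLUE task names are only for orientation; batch size and label count are arbitrary):

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss

batch = 4

# num_labels == 1 -> regression head (e.g. STS-B): logits (batch, 1), float labels (batch,)
logits_reg = torch.randn(batch, 1)
labels_reg = torch.randn(batch)
loss_reg = MSELoss()(logits_reg.view(-1), labels_reg.view(-1))

# num_labels > 1 -> classification head (e.g. MNLI, 3 classes): logits (batch, 3), int labels (batch,)
num_labels = 3
logits_cls = torch.randn(batch, num_labels)
labels_cls = torch.randint(num_labels, (batch,))
loss_cls = CrossEntropyLoss()(logits_cls.view(-1, num_labels), labels_cls.view(-1))
```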
3 | 4 | import torch 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | from transformers.modeling_outputs import SequenceClassifierOutput 7 | from transformers.models.roberta.modeling_roberta import ( 8 | RobertaSelfAttention, 9 | RobertaSelfOutput, 10 | RobertaLayer, 11 | ) 12 | 13 | from models.quantized_bert import ( 14 | QuantizedBertEmbeddings, 15 | QuantizedBertSelfAttention, 16 | QuantizedBertSelfOutput, 17 | QuantizedBertOutput, 18 | QuantizedBertLayer, 19 | QuantizedBertPooler, 20 | QuantizedBertModel, 21 | QuantizedBertForSequenceClassification, 22 | ) 23 | from quantization.autoquant_utils import quantize_model 24 | 25 | 26 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): 27 | """ 28 | Replace non-padding symbols with their position numbers. Position numbers begin at 29 | padding_idx+1. Padding symbols are ignored. This is modified from fairseq's 30 | `utils.make_positions`. 31 | 32 | Args: 33 | x: torch.Tensor x: 34 | 35 | Returns: torch.Tensor 36 | """ 37 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX 38 | # export and XLA. 39 | mask = input_ids.ne(padding_idx).int() 40 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 41 | return incremental_indices.long() + padding_idx 42 | 43 | 44 | class QuantizedRobertaEmbeddings(QuantizedBertEmbeddings): 45 | def __init__(self, org_model, **quant_params): 46 | super().__init__(org_model, **quant_params) 47 | 48 | self.padding_idx = org_model.padding_idx 49 | 50 | def create_position_ids_from_inputs_embeds(self, inputs_embeds): 51 | """We are provided embeddings directly. We cannot infer which are padded so just generate 52 | sequential position ids. 53 | """ 54 | input_shape = inputs_embeds.size()[:-1] 55 | sequence_length = input_shape[1] 56 | 57 | position_ids = torch.arange( 58 | self.padding_idx + 1, 59 | sequence_length + self.padding_idx + 1, 60 | dtype=torch.long, 61 | device=inputs_embeds.device, 62 | ) 63 | return position_ids.unsqueeze(0).expand(input_shape) 64 | 65 | def forward( 66 | self, 67 | input_ids=None, 68 | token_type_ids=None, 69 | position_ids=None, 70 | inputs_embeds=None, 71 | past_key_values_length=0, 72 | ): 73 | if position_ids is None: 74 | if input_ids is not None: 75 | # Create the position ids from the input token ids. Any padded tokens remain padded. 
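                # Worked example (illustrative, not part of the original source), with
                # RoBERTa's padding_idx = 1:
                #   input_ids    = [[5, 7, 9, 1, 1]]
                #   position_ids = [[2, 3, 4, 1, 1]]
                # i.e. positions start at padding_idx + 1 and padded slots keep padding_idx.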
76 | position_ids = create_position_ids_from_input_ids( 77 | input_ids, self.padding_idx, past_key_values_length 78 | ).to(input_ids.device) 79 | else: 80 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) 81 | 82 | if input_ids is not None: 83 | input_shape = input_ids.size() 84 | else: 85 | input_shape = inputs_embeds.size()[:-1] 86 | 87 | if token_type_ids is None: 88 | token_type_ids = torch.zeros( 89 | input_shape, dtype=torch.long, device=self.position_ids.device 90 | ) 91 | 92 | if inputs_embeds is None: 93 | inputs_embeds = self.word_embeddings(input_ids) 94 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 95 | 96 | embeddings = inputs_embeds + token_type_embeddings 97 | embeddings = self.sum_input_token_type_embd_act_quantizer(embeddings) 98 | 99 | if self.position_embedding_type == "absolute": 100 | position_embeddings = self.position_embeddings(position_ids) 101 | embeddings += position_embeddings 102 | 103 | embeddings = self.sum_pos_embd_act_quantizer(embeddings) 104 | 105 | embeddings = self.LayerNorm(embeddings) 106 | embeddings = self.dropout(embeddings) 107 | return embeddings 108 | 109 | 110 | class QuantizedRobertaSelfAttention(QuantizedBertSelfAttention): 111 | pass 112 | 113 | 114 | class QuantizedRobertaSelfOutput(QuantizedBertSelfOutput): 115 | pass 116 | 117 | 118 | class QuantizedRobertaOutput(QuantizedBertOutput): 119 | pass 120 | 121 | 122 | class QuantizedRobertaLayer(QuantizedBertLayer): 123 | def __init__(self, org_model, **quant_params): 124 | super().__init__(org_model, **quant_params) 125 | 126 | # update quantized components 127 | attention_specials = { 128 | RobertaSelfAttention: QuantizedRobertaSelfAttention, 129 | RobertaSelfOutput: QuantizedRobertaSelfOutput, 130 | } 131 | self.attention = quantize_model( 132 | org_model.attention, specials=attention_specials, **quant_params 133 | ) 134 | if self.add_cross_attention: 135 | self.crossattention = quantize_model( 136 | org_model.crossattention, specials=attention_specials, **quant_params 137 | ) 138 | self.output = QuantizedRobertaOutput(org_model.output, **quant_params) 139 | 140 | 141 | class QuantizedRobertaPooler(QuantizedBertPooler): 142 | pass 143 | 144 | 145 | class QuantizedRobertaModel(QuantizedBertModel): 146 | def __init__(self, org_model, **quant_params): 147 | super().__init__(org_model, **quant_params) 148 | 149 | # update quantized components 150 | self.embeddings = QuantizedRobertaEmbeddings(org_model.embeddings, **quant_params) 151 | self.encoder = quantize_model( 152 | org_model.encoder, specials={RobertaLayer: QuantizedRobertaLayer}, **quant_params 153 | ) 154 | self.pooler = ( 155 | QuantizedRobertaPooler(org_model.pooler, **quant_params) 156 | if org_model.pooler is not None 157 | else None 158 | ) 159 | 160 | 161 | class QuantizedRobertaForSequenceClassification(QuantizedBertForSequenceClassification): 162 | def __init__(self, org_model, quant_setup=None, **quant_params): 163 | super().__init__(org_model, quant_setup=quant_setup, **quant_params) 164 | 165 | # update quantization components 166 | self.roberta = QuantizedRobertaModel(org_model=org_model.roberta, **quant_params) 167 | self.classifier = quantize_model(org_model.classifier, **quant_params) 168 | 169 | def forward( 170 | self, 171 | input_ids=None, 172 | attention_mask=None, 173 | token_type_ids=None, 174 | position_ids=None, 175 | head_mask=None, 176 | inputs_embeds=None, 177 | labels=None, 178 | output_attentions=None, 179 | output_hidden_states=None, 180 | return_dict=None, 
181 | ): 182 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 183 | 184 | outputs = self.roberta( 185 | input_ids, 186 | attention_mask=attention_mask, 187 | token_type_ids=token_type_ids, 188 | position_ids=position_ids, 189 | head_mask=head_mask, 190 | inputs_embeds=inputs_embeds, 191 | output_attentions=output_attentions, 192 | output_hidden_states=output_hidden_states, 193 | return_dict=return_dict, 194 | ) 195 | sequence_output = outputs[0] 196 | 197 | # NOTE: optionally can keep final logits un-quantized, if only used for prediction 198 | # (can be enabled via --quant-setup FP_logits) 199 | logits = self.classifier(sequence_output) 200 | 201 | loss = None 202 | if labels is not None: 203 | if self.num_labels == 1: 204 | # We are doing regression 205 | loss_fct = MSELoss() 206 | loss = loss_fct(logits.view(-1), labels.view(-1)) 207 | else: 208 | loss_fct = CrossEntropyLoss() 209 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 210 | 211 | if not return_dict: 212 | output = (logits,) + outputs[2:] 213 | return ((loss,) + output) if loss is not None else output 214 | 215 | return SequenceClassifierOutput( 216 | loss=loss, 217 | logits=logits, 218 | hidden_states=outputs.hidden_states, 219 | attentions=outputs.attentions, 220 | ) 221 | -------------------------------------------------------------------------------- /quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | -------------------------------------------------------------------------------- /quantization/adaround/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from quantization.adaround.adaround import apply_adaround_to_layer 5 | from quantization.adaround.utils import ( 6 | AdaRoundInitMode, 7 | AdaRoundMode, 8 | AdaRoundActQuantMode, 9 | AdaRoundLossType, 10 | AdaRoundTempDecayType, 11 | ) 12 | -------------------------------------------------------------------------------- /quantization/adaround/adaround.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
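Before the AdaRound implementation, it is worth spelling out the `specials` dispatch pattern that the quantized model classes above rely on: `quantize_model` (defined further down in `quantization/autoquant_utils.py`) walks a module tree, replaces any class found in the `specials` mapping with its quantized wrapper, and handles standard layers (`nn.Linear`, `nn.LayerNorm`, `nn.Embedding`) generically. The sketch below uses the hypothetical names `MyBlock`/`QuantizedMyBlock` purely for illustration; it is not part of the repository.

```python
from torch import nn

from quantization.autoquant_utils import quantize_model
from quantization.base_quantized_model import QuantizedModel


class MyBlock(nn.Module):
    """A hypothetical FP32 block."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


class QuantizedMyBlock(QuantizedModel):
    """Its hand-written quantized counterpart, for when the generic traversal is not
    enough (e.g. extra activation quantizers at specific points, as in the files above)."""
    def __init__(self, org_model, **quant_params):
        super().__init__()
        self.proj = quantize_model(org_model.proj, **quant_params)

    def forward(self, x):
        return self.proj(x)


# Registering the pair makes the generic traversal pick up the custom wrapper:
#   quantize_model(model, specials={MyBlock: QuantizedMyBlock}, **quant_params)
```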
3 | 4 | import logging 5 | from math import ceil 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from quantization.adaround.quantizer import ADAROUND_QUANTIZER_MAP 12 | from quantization.adaround.utils import ( 13 | MODE_TO_LOSS_TYPE, 14 | AdaRoundInitMode, 15 | CombinedLoss, 16 | GetLayerInpOut, 17 | LayerOutputMSE, 18 | ) 19 | from utils.utils import DotDict 20 | 21 | 22 | # setup logger 23 | logger = logging.getLogger('AdaRound') 24 | logger.setLevel(logging.INFO) 25 | 26 | 27 | def apply_adaround_to_layer(model, layer, data_tensor, batch_size, act_quant, adaround_config, 28 | keep_gpu=True): 29 | """Apply AdaRound to a `layer` in the `model`.""" 30 | 31 | # disable caching of quantized params 32 | layer.caching = False 33 | 34 | # grid initialization 35 | if adaround_config.init == AdaRoundInitMode.range_estimator: 36 | pass # already initialized 37 | elif adaround_config.init == AdaRoundInitMode.mse: 38 | apply_mse_init(layer) 39 | elif adaround_config.init == AdaRoundInitMode.mse_out: 40 | apply_mse_out_init(model, layer, data_tensor, batch_size) 41 | elif adaround_config.init == AdaRoundInitMode.mse_out_asym: 42 | apply_mse_out_init(model, layer, data_tensor, batch_size, asym=True) 43 | else: 44 | raise ValueError(f'Unknown initialization for AdaRound: {adaround_config.init}') 45 | 46 | # activation function 47 | if not adaround_config.include_act_func: 48 | org_act_func = layer.activation_function 49 | layer.activation_function = None 50 | 51 | # replace quantizer with AdaRound quantizer 52 | org_w_quantizer = layer.weight_quantizer.quantizer 53 | org_w_quant_cls = org_w_quantizer.__class__ 54 | 55 | if not org_w_quant_cls in ADAROUND_QUANTIZER_MAP: 56 | raise NotImplementedError(f'AdaRound is not supported for "{org_w_quant_cls}"') 57 | 58 | new_w_quant_cls = ADAROUND_QUANTIZER_MAP[org_w_quant_cls] 59 | w_quantizer = new_w_quant_cls( 60 | n_bits=org_w_quantizer.n_bits, 61 | scale_domain=org_w_quantizer.scale_domain, 62 | per_channel=org_w_quantizer.per_channel, 63 | eps=org_w_quantizer.eps, 64 | ) 65 | w_quantizer.register_buffer('_delta', org_w_quantizer._delta) 66 | w_quantizer.register_buffer('_zero_float', org_w_quantizer._zero_float) 67 | if hasattr(org_w_quantizer, '_signed'): 68 | w_quantizer.register_buffer('_signed', org_w_quantizer._signed) 69 | layer.weight_quantizer.quantizer = w_quantizer 70 | 71 | # set AdaRound attributes 72 | w_quantizer.round_mode = adaround_config.round_mode 73 | w_quantizer.temperature = adaround_config.annealing[0] 74 | 75 | # single test (and init alpha) 76 | get_inp_out = GetLayerInpOut(model, layer, asym=adaround_config.asym, act_quant=act_quant) 77 | inp, out = get_inp_out(data_tensor[:batch_size]) 78 | loss_soft_before, loss_hard_before = _compute_and_display_local_losses( 79 | w_quantizer, layer, inp, out, infix='before optimization' 80 | ) 81 | w_quantizer.soft_targets = True 82 | 83 | # define loss 84 | loss_type = MODE_TO_LOSS_TYPE[w_quantizer.round_mode] 85 | loss_fn = CombinedLoss( 86 | quantizer=w_quantizer, 87 | loss_type=loss_type, 88 | weight=adaround_config.weight, 89 | max_count=adaround_config.iters, 90 | b_range=adaround_config.annealing, 91 | warmup=adaround_config.warmup, 92 | decay_type=adaround_config.decay_type, 93 | decay_shape=adaround_config.decay_shape, 94 | decay_start=adaround_config.decay_start, 95 | ) 96 | 97 | # define optimizer 98 | opt_params = [w_quantizer.alpha] 99 | optimizer = torch.optim.Adam(opt_params, lr=adaround_config.lr) 100 | 101 | # main loop 102 | 
optimize_local_loss( 103 | layer, 104 | get_inp_out, 105 | data_tensor, 106 | optimizer, 107 | loss_fn, 108 | batch_size, 109 | adaround_config.iters, 110 | keep_gpu=keep_gpu, 111 | ) 112 | 113 | # check afterwards 114 | logger.info(f'Local loss before optimization (hard quant): {loss_hard_before:.7f}') 115 | loss_soft_after, loss_hard_after = _compute_and_display_local_losses( 116 | w_quantizer, layer, inp, out, infix='after optimization' 117 | ) 118 | 119 | # set to hard decision up/down 120 | w_quantizer.soft_targets = False 121 | 122 | # restore original activation function 123 | if not adaround_config.include_act_func: 124 | layer.activation_function = org_act_func 125 | 126 | # restore caching of quantized params 127 | layer.caching = True 128 | 129 | # prepare output 130 | out = DotDict( 131 | loss_soft_before=loss_soft_before, 132 | loss_hard_before=loss_hard_before, 133 | loss_soft_after=loss_soft_after, 134 | loss_hard_after=loss_hard_after, 135 | ) 136 | return out 137 | 138 | 139 | def _compute_and_display_local_losses(quantizer, layer, inp, out, infix=''): 140 | org_soft_targets = quantizer.soft_targets 141 | 142 | quantizer.soft_targets = True 143 | out_soft_quant = layer(inp) 144 | quantizer.soft_targets = False 145 | out_hard_quant = layer(inp) 146 | 147 | soft_quant_loss = F.mse_loss(out_soft_quant, out) 148 | hard_quant_loss = F.mse_loss(out_hard_quant, out) 149 | 150 | if infix: 151 | infix = infix.strip() + ' ' 152 | 153 | logger.info(f'Local loss {infix}(soft quant): {soft_quant_loss:.7f}') 154 | logger.info(f'Local loss {infix}(hard quant): {hard_quant_loss:.7f}') 155 | 156 | quantizer.soft_targets = org_soft_targets 157 | return float(soft_quant_loss), float(hard_quant_loss) 158 | 159 | 160 | def apply_mse_init(layer): 161 | w = layer.weight 162 | q = layer.weight_quantizer.quantizer 163 | 164 | with torch.no_grad(): 165 | w_absmax = torch.max(w.max(), torch.abs(w.min())) 166 | best_score = np.inf 167 | best_max = w_absmax 168 | for i in range(80): 169 | s = w_absmax * (1.0 - 0.01 * i) 170 | q.set_quant_range(-s, s) 171 | score = F.mse_loss(w, q(w)).item() 172 | 173 | if score < best_score: 174 | best_score = score 175 | best_max = s 176 | 177 | logger.info(f'Finished: set max={best_max:.3f} (mse={best_score:.7f})') 178 | q.set_quant_range(-best_max, best_max) 179 | 180 | 181 | def apply_mse_out_init(model, layer, data_tensor, batch_size, asym=False): 182 | w = layer.weight 183 | q = layer.weight_quantizer.quantizer 184 | 185 | get_inp_out = GetLayerInpOut(model, layer, asym=asym) 186 | loss_fn = LayerOutputMSE(layer, get_inp_out, data_tensor, batch_size) 187 | 188 | with torch.no_grad(): 189 | w_absmax = torch.max(w.max(), torch.abs(w.min())) 190 | best_score = np.inf 191 | best_max = w_absmax 192 | for i in range(80): 193 | s = w_absmax * (1.0 - 0.01 * i) 194 | q.set_quant_range(-s, s) 195 | score = loss_fn() 196 | 197 | if score < best_score: 198 | best_score = score 199 | best_max = s 200 | logger.info(f'Finished: set max={best_max:.3f} (mse={best_score:.7f})') 201 | q.set_quant_range(-best_max, best_max) 202 | 203 | 204 | def optimize_local_loss(layer, get_inp_out, data_tensor, optimizer, loss_fn, batch_size, iters, 205 | use_cached_data=True, keep_gpu=True): 206 | """AdaRound optimization loop.""" 207 | if use_cached_data: 208 | logger.info('Caching data for local loss optimization') 209 | 210 | cached_batches = [] 211 | if keep_gpu: 212 | torch.cuda.empty_cache() 213 | with torch.no_grad(): 214 | for i in range(ceil(data_tensor.size(0) / batch_size)): 215 | 
cur_inp, cur_out = get_inp_out(data_tensor[i * batch_size:(i + 1) * batch_size]) 216 | cached_batches.append((cur_inp.cpu(), cur_out.cpu())) 217 | 218 | cached_inps = torch.cat([x[0] for x in cached_batches]) 219 | cached_outs = torch.cat([x[1] for x in cached_batches]) 220 | device = cur_inp.device 221 | 222 | del cached_batches 223 | if keep_gpu: # put all cached data on GPU for faster optimization 224 | torch.cuda.empty_cache() 225 | try: 226 | cached_inps = cached_inps.to(device) 227 | cached_outs = cached_outs.to(device) 228 | except RuntimeError as e: 229 | logger.warning( 230 | f"WARNING: could not cache training data on GPU, keep on CPU ({e})" 231 | ) 232 | cached_inps = cached_inps.cpu() 233 | cached_outs = cached_outs.cpu() 234 | 235 | for i in range(iters): 236 | idx = torch.randperm(cached_inps.size(0))[:batch_size] 237 | if use_cached_data: 238 | cur_inp = cached_inps[idx].to(device) 239 | cur_out = cached_outs[idx].to(device) 240 | else: 241 | cur_inp, cur_out = get_inp_out(data_tensor[idx]) 242 | 243 | optimizer.zero_grad() 244 | 245 | try: 246 | out_quant = layer(cur_inp) 247 | loss = loss_fn(out_quant, cur_out) 248 | loss.backward() 249 | except RuntimeError as e: 250 | if use_cached_data and 'cuda' in str(cached_inps.device): 251 | logger.warning( 252 | f"WARNING: not enough CUDA memory for forward pass, " 253 | f"move cached data to CPU ({e})" 254 | ) 255 | cached_inps = cached_inps.cpu() 256 | cached_outs = cached_outs.cpu() 257 | else: 258 | raise e 259 | 260 | optimizer.step() 261 | -------------------------------------------------------------------------------- /quantization/adaround/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from quantization.adaround.utils import ( 5 | AdaRoundActQuantMode, 6 | AdaRoundInitMode, 7 | AdaRoundMode, 8 | AdaRoundTempDecayType, 9 | ) 10 | from utils.utils import DotDict 11 | 12 | 13 | class AdaRoundConfig(DotDict): 14 | pass 15 | 16 | 17 | DEFAULT_ADAROUND_CONFIG = AdaRoundConfig( 18 | # Base options 19 | layers=('all',), 20 | num_samples=1024, 21 | init=AdaRoundInitMode.range_estimator, 22 | 23 | # Method and continuous relaxation options 24 | round_mode=AdaRoundMode.learned_hard_sigmoid, 25 | asym=True, 26 | include_act_func=True, 27 | lr=1e-3, 28 | iters=1000, 29 | weight=0.01, 30 | annealing=(20, 2), 31 | decay_type=AdaRoundTempDecayType.cosine, 32 | decay_shape=1.0, 33 | decay_start=0.0, 34 | warmup=0.2, 35 | 36 | # Activation quantization 37 | act_quant_mode=AdaRoundActQuantMode.post_adaround, 38 | ) 39 | -------------------------------------------------------------------------------- /quantization/adaround/quantizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
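A short, hedged sketch of how the pieces above fit together: `AdaRoundConfig` is a `DotDict` (a dictionary-style container with attribute access, from `utils/utils.py`), so a copy of `DEFAULT_ADAROUND_CONFIG` can be adjusted and passed straight to `apply_adaround_to_layer`. The `model`, `layer` and `data_tensor` arguments are placeholders for objects produced elsewhere in the pipeline (see `utils/adaround_utils.py`), so the call is wrapped in a hypothetical helper rather than executed directly.

```python
from copy import deepcopy

from quantization.adaround import AdaRoundInitMode, apply_adaround_to_layer
from quantization.adaround.config import DEFAULT_ADAROUND_CONFIG


def adaround_one_layer(model, layer, data_tensor, batch_size=16):
    """Hypothetical helper: run AdaRound on a single quantized layer."""
    config = deepcopy(DEFAULT_ADAROUND_CONFIG)
    config['iters'] = 2000                  # more optimization steps than the default 1000
    config['init'] = AdaRoundInitMode.mse   # MSE-based grid init instead of the range estimator

    result = apply_adaround_to_layer(
        model, layer, data_tensor, batch_size,
        act_quant=False,                    # FP32 activations during AdaRound
        adaround_config=config,
    )
    # `result` is a DotDict with loss_soft_before / loss_hard_before / ..._after entries.
    return result
```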
3 | 4 | import logging 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from quantization.quantizers import ( 10 | QuantizerBase, 11 | AsymmetricUniformQuantizer, 12 | SymmetricUniformQuantizer, 13 | ) 14 | from quantization.adaround.utils import AdaRoundMode 15 | 16 | 17 | # setup logger 18 | logger = logging.getLogger('AdaRound') 19 | logger.setLevel(logging.INFO) 20 | 21 | 22 | def logit(p, eps=1e-16): 23 | p = torch.clamp(p, eps, 1 - eps) 24 | return -torch.log(1 / p - 1) 25 | 26 | 27 | def hard_sigmoid(x, zeta=1.1, gamma=-0.1): 28 | p = torch.sigmoid(x) 29 | return torch.clamp(p * (zeta - gamma) + gamma, 0.0, 1.0) 30 | 31 | 32 | def hard_logit(p, zeta=1.1, gamma=-0.1): 33 | # NOTE: argument of log is between 1/11 and 11 (for default values of zeta and gamma) 34 | return -torch.log((zeta - p) / (p - gamma)) 35 | 36 | 37 | class AdaRoundQuantizer(QuantizerBase): 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | 41 | self.alpha = None 42 | self.round_mode = AdaRoundMode.nearest 43 | self.soft_targets = False 44 | self.temperature = None # for sigmoid temperature annealing 45 | 46 | def to_integer_forward(self, x_float): 47 | if self.round_mode == AdaRoundMode.nearest: 48 | return super().to_integer_forward(x_float) 49 | 50 | if self.round_mode not in AdaRoundMode.RELAXATION: 51 | raise ValueError(f'Unknown rounding mode: {self.round_mode}') 52 | 53 | # cont. relaxation 54 | x = x_float / self.scale 55 | x_floor = torch.floor(x) 56 | 57 | # initialize alpha, if needed 58 | if self.alpha is None: 59 | logger.info('Init alpha to be FP32') 60 | 61 | rest = x - x_floor # rest of rounding [0, 1) 62 | if self.round_mode == AdaRoundMode.learned_sigmoid: 63 | alpha = logit(rest) # => sigmoid(alpha) = rest 64 | elif self.round_mode == AdaRoundMode.learned_hard_sigmoid: 65 | alpha = hard_logit(rest) # => hard_sigmoid(alpha) = rest 66 | elif self.round_mode == AdaRoundMode.sigmoid_temp_decay: 67 | alpha = self.temperature * logit(rest) # => sigmoid(alpha/temperature) = rest 68 | else: 69 | raise ValueError(f'Unknown rounding mode: {self.round_mode}') 70 | 71 | self.alpha = nn.Parameter(alpha, requires_grad=True) 72 | 73 | # compute final x_int 74 | x_int = x_floor + (self.get_rest() if self.soft_targets else (self.alpha >= 0).float()) 75 | 76 | if not self.symmetric: 77 | x_int += self.zero_point 78 | 79 | x_int = torch.clamp(x_int, self.int_min, self.int_max) 80 | return x_int 81 | 82 | def get_rest(self): 83 | if self.round_mode == AdaRoundMode.learned_sigmoid: 84 | return torch.sigmoid(self.alpha) 85 | elif self.round_mode == AdaRoundMode.learned_hard_sigmoid: 86 | return hard_sigmoid(self.alpha) 87 | elif self.round_mode == AdaRoundMode.sigmoid_temp_decay: 88 | return torch.sigmoid(self.alpha / self.temperature) 89 | else: 90 | raise ValueError(f'Unknown rounding mode: {self.round_mode}') 91 | 92 | def extra_repr(self): 93 | return ', '.join([ 94 | f'n_bits={self.n_bits}', 95 | f'per_channel={self.per_channel}', 96 | f'is_initialized={self.is_initialized}', 97 | f'round_mode={self.round_mode}', 98 | f'soft_targets={self.soft_targets}', 99 | f'temperature={self.temperature}', 100 | ]) 101 | 102 | 103 | class AdaRoundSymmetricUniformQuantizer(AdaRoundQuantizer, SymmetricUniformQuantizer): 104 | pass 105 | 106 | 107 | class AdaRoundAsymmetricUniformQuantizer(AdaRoundQuantizer, AsymmetricUniformQuantizer): 108 | pass 109 | 110 | 111 | ADAROUND_QUANTIZER_MAP = { 112 | SymmetricUniformQuantizer: AdaRoundSymmetricUniformQuantizer, 113 | 
AsymmetricUniformQuantizer: AdaRoundAsymmetricUniformQuantizer, 114 | } 115 | -------------------------------------------------------------------------------- /quantization/adaround/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | from enum import Flag, auto 6 | from math import ceil 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from utils.utils import StopForwardException 13 | 14 | 15 | # setup logger 16 | logger = logging.getLogger('AdaRound') 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | def sigmoid(x): 21 | return (1.0 + np.exp(-x)) ** -1.0 22 | 23 | 24 | class BaseOption(Flag): 25 | def __str__(self): 26 | return self.name 27 | 28 | @property 29 | def cls(self): 30 | return self.value.cls 31 | 32 | @classmethod 33 | def list_names(cls): 34 | return [m.name for m in cls] 35 | 36 | 37 | class AdaRoundActQuantMode(BaseOption): 38 | # Activation quantization is disabled 39 | no_act_quant = auto() 40 | 41 | # AdaRound with FP32 acts, activations are quantized afterwards if applicable (default): 42 | post_adaround = auto() 43 | 44 | 45 | class AdaRoundInitMode(BaseOption): 46 | """Weight quantization grid initialization.""" 47 | 48 | range_estimator = auto() 49 | mse = auto() # old implementation 50 | mse_out = auto() 51 | mse_out_asym = auto() 52 | 53 | 54 | class AdaRoundLossType(BaseOption): 55 | """Regularization terms.""" 56 | 57 | relaxation = auto() 58 | temp_decay = auto() 59 | 60 | 61 | class AdaRoundMode(BaseOption): 62 | nearest = auto() # (default) 63 | 64 | # original AdaRound relaxation methods 65 | learned_sigmoid = auto() 66 | learned_hard_sigmoid = auto() 67 | sigmoid_temp_decay = auto() 68 | 69 | RELAXATION = learned_sigmoid | learned_hard_sigmoid | sigmoid_temp_decay 70 | 71 | @classmethod 72 | def list_names(cls): 73 | exclude = (AdaRoundMode.nearest, AdaRoundMode.RELAXATION) 74 | return [m.name for m in cls if not m in exclude] 75 | 76 | 77 | MODE_TO_LOSS_TYPE = { 78 | AdaRoundMode.learned_hard_sigmoid: AdaRoundLossType.relaxation, 79 | AdaRoundMode.learned_sigmoid: AdaRoundLossType.relaxation, 80 | AdaRoundMode.sigmoid_temp_decay: AdaRoundLossType.temp_decay, 81 | } 82 | 83 | 84 | class AdaRoundTempDecayType(BaseOption): 85 | linear = auto() 86 | cosine = auto() 87 | sigmoid = auto() # https://arxiv.org/abs/1811.09332 88 | power = auto() 89 | exp = auto() 90 | log = auto() 91 | 92 | 93 | class TempDecay: 94 | def __init__(self, t_max, b_range=(20.0, 2.0), rel_decay_start=0.0, 95 | decay_type=AdaRoundTempDecayType.linear, decay_shape=1.0): 96 | self.t_max = t_max 97 | self.start_b, self.end_b = b_range 98 | self.decay_type = decay_type 99 | self.decay_shape = decay_shape 100 | self.decay_start = rel_decay_start * t_max 101 | 102 | def __call__(self, t): 103 | if t < self.decay_start: 104 | return self.start_b 105 | 106 | rel_t = (t - self.decay_start) / (self.t_max - self.decay_start) 107 | if self.decay_type == AdaRoundTempDecayType.linear: 108 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) 109 | elif self.decay_type == AdaRoundTempDecayType.cosine: 110 | return self.end_b + 0.5 * (self.start_b - self.end_b) * (1 + np.cos(rel_t * np.pi)) 111 | elif self.decay_type == AdaRoundTempDecayType.sigmoid: 112 | d = self.decay_shape 113 | offset = sigmoid(-d / 2) 114 | rel_progress = (sigmoid(d * (rel_t - 0.5)) - offset) / (1 - 2 * offset) 115 | return 
self.start_b + (self.end_b - self.start_b) * rel_progress 116 | elif self.decay_type == AdaRoundTempDecayType.power: 117 | return self.end_b + (self.start_b - self.end_b) * (1 - rel_t ** self.decay_shape) 118 | elif self.decay_type == AdaRoundTempDecayType.exp: 119 | r = self.decay_shape 120 | rel_progress = (1.0 - np.exp(-r * rel_t)) / (1.0 - np.exp(-r)) 121 | return self.start_b + (self.end_b - self.start_b) * rel_progress 122 | elif self.decay_type == AdaRoundTempDecayType.log: 123 | r = self.decay_shape 124 | C = np.exp(self.end_b / r) 125 | c = np.exp(self.start_b / r) 126 | return r * np.log((C - c) * rel_t + c) 127 | else: 128 | raise ValueError(f'Unknown temp decay type {self.decay_type}') 129 | 130 | 131 | class CombinedLoss: 132 | def __init__(self, quantizer, loss_type=AdaRoundLossType.relaxation, weight=0.01, 133 | max_count=1000, b_range=(20, 2), warmup=0.0, decay_start=0.0, **temp_decay_kw): 134 | self.quantizer = quantizer 135 | self.loss_type = loss_type 136 | self.weight = weight 137 | 138 | self.loss_start = max_count * warmup 139 | self.temp_decay = TempDecay( 140 | max_count, 141 | b_range=b_range, 142 | rel_decay_start=warmup + (1.0 - warmup) * decay_start, 143 | **temp_decay_kw, 144 | ) 145 | self.iter = 0 146 | 147 | def __call__(self, pred, tgt, *args, **kwargs): 148 | self.iter += 1 149 | 150 | rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() 151 | 152 | if self.iter < self.loss_start: 153 | b = self.temp_decay(self.iter) 154 | round_loss = 0 155 | elif self.loss_type == AdaRoundLossType.temp_decay: 156 | b = self.temp_decay(self.iter) 157 | self.quantizer.temperature = b 158 | round_loss = 0 159 | elif self.loss_type == AdaRoundLossType.relaxation: # 1 - |(h-0.5)*2|**b 160 | b = self.temp_decay(self.iter) 161 | round_vals = self.quantizer.get_rest().view(-1) 162 | round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum() 163 | else: 164 | raise ValueError(f'Unknown loss type {self.loss_type}') 165 | 166 | total_loss = rec_loss + round_loss 167 | if self.iter == 1 or self.iter % 100 == 0: 168 | logger.info( 169 | f'Total loss:\t{total_loss:.4f} (rec:{rec_loss:.4f}, ' 170 | f'round:{round_loss:.3f})\tb={b:.2f}\titer={self.iter}' 171 | ) 172 | return total_loss 173 | 174 | 175 | class StopForwardHook: 176 | def __call__(self, module, *args): 177 | raise StopForwardException 178 | 179 | 180 | class DataSaverHook: 181 | def __init__(self, store_input=False, store_output=False, stop_forward=False): 182 | self.store_input = store_input 183 | self.store_output = store_output 184 | self.stop_forward = stop_forward 185 | 186 | self.input_store = None 187 | self.output_store = None 188 | 189 | def __call__(self, module, input_batch, output_batch): 190 | if self.store_input: 191 | self.input_store = input_batch 192 | if self.store_output: 193 | self.output_store = output_batch 194 | if self.stop_forward: 195 | raise StopForwardException 196 | 197 | 198 | class GetLayerInpOut: 199 | def __init__(self, model, layer, asym=False, act_quant=False, store_output=True): 200 | self.model = model 201 | self.layer = layer 202 | self.asym = asym 203 | self.device = layer.weight.device 204 | self.act_quant = act_quant 205 | self.store_output = store_output 206 | self.data_saver = DataSaverHook( 207 | store_input=True, store_output=self.store_output, stop_forward=True 208 | ) 209 | 210 | def __call__(self, model_input): 211 | self.model.full_precision() 212 | handle = self.layer.register_forward_hook(self.data_saver) 213 | 214 | with torch.no_grad(): 215 
| try: 216 | _ = self.model(model_input.to(self.device)) 217 | except StopForwardException: 218 | pass 219 | 220 | if self.asym: # recalculate input with network quantized 221 | self.data_saver.store_output = False 222 | self.model.set_quant_state(weight_quant=True, act_quant=self.act_quant) 223 | try: 224 | _ = self.model(model_input.to(self.device)) 225 | except StopForwardException: 226 | pass 227 | self.data_saver.store_output = True 228 | 229 | handle.remove() 230 | 231 | self.model.full_precision() 232 | self.layer.quantized_weights() 233 | return self.data_saver.input_store[0].detach(), self.data_saver.output_store.detach() 234 | 235 | 236 | class LayerOutputMSE: 237 | def __init__(self, layer, get_inp_out, data_tensor, batch_size, name='mse_out'): 238 | cur_inp, cur_out = get_inp_out(data_tensor) 239 | self.input = cur_inp 240 | self.exp_out = cur_out 241 | self.layer = layer 242 | self.batch_size = batch_size 243 | self.name = name 244 | 245 | def __call__(self): 246 | loss = 0.0 247 | x = self.input 248 | for i in range(ceil(x.size(0) / self.batch_size)): 249 | cur_out = self.layer(x[i * self.batch_size:(i + 1) * self.batch_size]) 250 | exp_out = self.exp_out[i * self.batch_size:(i + 1) * self.batch_size] 251 | loss += F.mse_loss(cur_out, exp_out).item() 252 | return loss 253 | -------------------------------------------------------------------------------- /quantization/autoquant_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import copy 5 | import warnings 6 | 7 | from torch.nn import functional as F 8 | from torch import nn 9 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd 10 | 11 | from quantization.base_quantized_classes import FP32Acts, QuantizedActivation, QuantizedModule 12 | from quantization.hijacker import QuantizationHijacker, activations_list 13 | from quantization.quantization_manager import QuantizationManager 14 | 15 | 16 | class QuantLinear(QuantizationHijacker, nn.Linear): 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | 20 | def run_forward(self, x, weight, bias, offsets=None): 21 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias) 22 | 23 | 24 | class QuantizedActivationWrapper(QuantizedActivation): 25 | """ 26 | Wraps over a layer and quantized the activation. 27 | It also allow for tying the input and output quantizer which is helpful 28 | for layers such Average Pooling. 
29 | """ 30 | def __init__(self, layer, tie_activation_quantizers=False, 31 | input_quantizer: QuantizationManager = None, *args, **kwargs): 32 | super().__init__(*args, **kwargs) 33 | self.tie_activation_quantizers = tie_activation_quantizers 34 | if input_quantizer: 35 | assert isinstance(input_quantizer, QuantizationManager) 36 | self.activation_quantizer = input_quantizer 37 | self.layer = layer 38 | 39 | def quantize_activations_no_range_update(self, x): 40 | if self._quant_a: 41 | return self.activation_quantizer.quantizer(x) 42 | else: 43 | return x 44 | 45 | def forward(self, x): 46 | x = self.layer(x) 47 | if self.tie_activation_quantizers: 48 | # The input activation quantizer is used to quantize the activation 49 | # but without updating the quantization range 50 | return self.quantize_activations_no_range_update(x) 51 | else: 52 | return self.quantize_activations(x) 53 | 54 | 55 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm): 56 | def __init__(self, *args, activation=None, **kwargs): 57 | super().__init__(*args, activation=activation, **kwargs) 58 | 59 | def run_forward(self, x, weight, bias, offsets=None): 60 | return F.layer_norm( 61 | input=x.contiguous(), 62 | normalized_shape=self.normalized_shape, 63 | weight=weight.contiguous(), 64 | bias=bias.contiguous(), 65 | eps=self.eps, 66 | ) 67 | 68 | 69 | class QuantEmbedding(QuantizationHijacker, nn.Embedding): 70 | def __init__(self, *args, activation=None, **kwargs): 71 | super().__init__(*args, activation=activation, **kwargs) 72 | # NB: Embedding should not quantize activations, as it is simply a lookup table, 73 | # which is already quantized. 74 | self.activation_quantizer = FP32Acts() 75 | 76 | def run_forward(self, x, weight, bias, offsets=None): 77 | return F.embedding( 78 | input=x.contiguous(), 79 | weight=weight.contiguous(), 80 | padding_idx=self.padding_idx, 81 | max_norm=self.max_norm, 82 | norm_type=self.norm_type, 83 | scale_grad_by_freq=self.scale_grad_by_freq, 84 | sparse=self.sparse, 85 | ) 86 | 87 | 88 | module_map = { 89 | nn.Linear: QuantLinear, 90 | nn.LayerNorm: QuantLayerNorm, 91 | nn.Embedding: QuantEmbedding 92 | } 93 | 94 | 95 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd) 96 | 97 | 98 | def get_act(module, i): 99 | result, act_idx = None, None 100 | for i in range(i + 1, len(module)): 101 | if isinstance(module[i], tuple(activations_list)): 102 | result = module[i] 103 | act_idx = i 104 | break 105 | return result, act_idx 106 | 107 | 108 | def get_linear_args(module): 109 | args = dict( 110 | in_features=module.in_features, 111 | out_features=module.out_features, 112 | bias=module.bias is not None, 113 | ) 114 | return args 115 | 116 | 117 | def get_layernorm_args(module): 118 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps) 119 | return args 120 | 121 | 122 | def get_embedding_args(module): 123 | args = dict( 124 | num_embeddings=module.num_embeddings, 125 | embedding_dim=module.embedding_dim, 126 | padding_idx=module.padding_idx, 127 | max_norm=module.max_norm, 128 | norm_type=module.norm_type, 129 | scale_grad_by_freq=module.scale_grad_by_freq, 130 | sparse=module.sparse, 131 | ) 132 | return args 133 | 134 | 135 | def get_module_args(mod, act): 136 | if isinstance(mod, nn.Linear): 137 | kwargs = get_linear_args(mod) 138 | elif isinstance(mod, nn.LayerNorm): 139 | kwargs = get_layernorm_args(mod) 140 | elif isinstance(mod, nn.Embedding): 141 | kwargs = get_embedding_args(mod) 142 | else: 143 | raise ValueError 144 | 145 | kwargs['activation'] = act 146 
| return kwargs 147 | 148 | 149 | def quant_module(module, i, **quant_params): 150 | act, act_idx = get_act(module, i) 151 | modtype = module_map[type(module[i])] 152 | 153 | kwargs = get_module_args(module[i], act) 154 | new_module = modtype(**kwargs, **quant_params) 155 | new_module.weight.data = module[i].weight.data.clone() 156 | 157 | if module[i].bias is not None: 158 | new_module.bias.data = module[i].bias.data.clone() 159 | 160 | return new_module, i + int(bool(act)) + 1 161 | 162 | 163 | def quantize_sequence(model, specials=None, tie_activation_quantizers=False, **quant_params): 164 | specials = specials or dict() 165 | 166 | i = 0 167 | quant_modules = [] 168 | while i < len(model): 169 | if isinstance(model[i], QuantizedModule): 170 | quant_modules.append(model[i]) 171 | elif type(model[i]) in module_map: 172 | new_module, new_i = quant_module(model, i, **quant_params) 173 | quant_modules.append(new_module) 174 | i = new_i 175 | continue 176 | 177 | elif type(model[i]) in specials: 178 | quant_modules.append(specials[type(model[i])](model[i], **quant_params)) 179 | 180 | elif isinstance(model[i], non_param_modules): 181 | # check for last quantizer 182 | input_quantizer = None 183 | if ( 184 | quant_modules 185 | and isinstance(quant_modules[-1], QuantizedModule) 186 | and tie_activation_quantizers 187 | ): 188 | input_quantizer = quant_modules[-1].activation_quantizer 189 | warnings.warn( 190 | f'Tying input quantizer {i}^th layer of type ' 191 | f'{type(quant_modules[-1])} to the quantized {type(model[i])} ' 192 | f'following it' 193 | ) 194 | quant_modules.append( 195 | QuantizedActivationWrapper( 196 | model[i], 197 | tie_activation_quantizers=tie_activation_quantizers, 198 | input_quantizer=input_quantizer, 199 | **quant_params, 200 | ) 201 | ) 202 | 203 | else: 204 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params)) 205 | i += 1 206 | return quant_modules 207 | 208 | 209 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params): 210 | quant_modules = quantize_sequence(model, specials, tie_activation_quantizers, **quant_params) 211 | return nn.Sequential(*quant_modules) 212 | 213 | 214 | def quantize_module_list(model, specials=None, tie_activation_quantizers=False, **quant_params): 215 | quant_modules = quantize_sequence(model, specials, tie_activation_quantizers, **quant_params) 216 | return nn.ModuleList(quant_modules) 217 | 218 | 219 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params): 220 | specials = specials or dict() 221 | 222 | if isinstance(model, nn.Sequential): 223 | quant_model = quantize_sequential( 224 | model, specials, tie_activation_quantizers, **quant_params 225 | ) 226 | 227 | elif type(model) in specials: 228 | quant_model = specials[type(model)](model, **quant_params) 229 | 230 | elif isinstance(model, non_param_modules): 231 | quant_model = QuantizedActivationWrapper(model, **quant_params) 232 | 233 | elif type(model) in module_map: 234 | # if we do isinstance() then we might run into issues with modules that inherit from 235 | # one of these classes, for whatever reason 236 | modtype = module_map[type(model)] 237 | kwargs = get_module_args(model, None) 238 | quant_model = modtype(**kwargs, **quant_params) 239 | 240 | quant_model.weight.data = model.weight.data 241 | if getattr(model, 'bias', None) is not None: 242 | quant_model.bias.data = model.bias.data 243 | 244 | else: 245 | # unknown type, try to quantize all child modules 246 | 
quant_model = copy.deepcopy(model) 247 | for name, module in quant_model._modules.items(): 248 | new_model = quantize_model(module, specials=specials, **quant_params) 249 | if new_model is not None: 250 | setattr(quant_model, name, new_model) 251 | 252 | return quant_model 253 | -------------------------------------------------------------------------------- /quantization/base_quantized_classes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from torch import nn 5 | 6 | from quantization.quantization_manager import QuantizationManager 7 | from quantization.quantizers import QMethods 8 | from quantization.range_estimators import RangeEstimators 9 | 10 | 11 | def _set_layer_learn_ranges(layer): 12 | if isinstance(layer, QuantizationManager): 13 | if layer.quantizer.is_initialized: 14 | layer.learn_ranges() 15 | 16 | 17 | def _set_layer_fix_ranges(layer): 18 | if isinstance(layer, QuantizationManager): 19 | if layer.quantizer.is_initialized: 20 | layer.fix_ranges() 21 | 22 | 23 | def _set_layer_estimate_ranges(layer): 24 | if isinstance(layer, QuantizationManager): 25 | if layer.quantizer.is_initialized: 26 | layer.estimate_ranges() 27 | 28 | 29 | def _set_layer_estimate_ranges_train(layer): 30 | if isinstance(layer, QuantizationManager): 31 | if layer.quantizer.is_initialized: 32 | layer.estimate_ranges_train() 33 | 34 | 35 | class QuantizedModule(nn.Module): 36 | """ 37 | Parent class for a quantized module. It adds the basic functionality of switching the module 38 | between quantized and full precision mode. It also defines the cached parameters and handles 39 | the reset of the cache properly. 40 | """ 41 | def __init__(self, *args, method=QMethods.asymmetric_uniform, act_method=None, n_bits=8, 42 | n_bits_act=None, per_channel_weights=False, per_channel_acts=False, 43 | percentile=None, weight_range_method=RangeEstimators.current_minmax, 44 | weight_range_options=None, act_range_method=RangeEstimators.running_minmax, 45 | act_range_options=None, scale_domain='linear', **kwargs): 46 | kwargs.pop('quant_dict', None) 47 | super().__init__(*args, **kwargs) 48 | 49 | self.method = method 50 | self.act_method = act_method or method 51 | self.n_bits = n_bits 52 | self.n_bits_act = n_bits_act or n_bits 53 | self.per_channel_weights = per_channel_weights 54 | self.per_channel_acts = per_channel_acts 55 | self.percentile = percentile 56 | self.weight_range_method = weight_range_method 57 | self.weight_range_options = weight_range_options if weight_range_options else {} 58 | self.act_range_method = act_range_method 59 | self.act_range_options = act_range_options if act_range_options else {} 60 | self.scale_domain = scale_domain 61 | 62 | self.cached_params = None 63 | self._caching = True 64 | 65 | self.quant_params = None 66 | self._quant_w = False 67 | self._quant_a = False 68 | 69 | @property 70 | def caching(self): 71 | return self._caching 72 | 73 | @caching.setter 74 | def caching(self, value: bool): 75 | self._caching = value 76 | if not value: 77 | self.cached_params = None 78 | 79 | def quantized_weights(self): 80 | self.cached_params = None 81 | self._quant_w = True 82 | 83 | def full_precision_weights(self): 84 | self.cached_params = None 85 | self._quant_w = False 86 | 87 | def quantized_acts(self): 88 | self._quant_a = True 89 | 90 | def full_precision_acts(self): 91 | self._quant_a = False 92 | 93 | def quantized(self): 94 | self.quantized_weights() 95 | 
self.quantized_acts() 96 | 97 | def full_precision(self): 98 | self.full_precision_weights() 99 | self.full_precision_acts() 100 | 101 | def learn_ranges(self): 102 | self.apply(_set_layer_learn_ranges) 103 | 104 | def fix_ranges(self): 105 | self.apply(_set_layer_fix_ranges) 106 | 107 | def estimate_ranges(self): 108 | self.apply(_set_layer_estimate_ranges) 109 | 110 | def estimate_ranges_train(self): 111 | self.apply(_set_layer_estimate_ranges_train) 112 | 113 | def train(self, mode=True): 114 | super().train(mode) 115 | if mode: 116 | self.cached_params = None 117 | return self 118 | 119 | def _apply(self, *args, **kwargs): 120 | self.cached_params = None 121 | return super(QuantizedModule, self)._apply(*args, **kwargs) 122 | 123 | def extra_repr(self): 124 | quant_state = 'weight_quant={}, act_quant={}'.format(self._quant_w, self._quant_a) 125 | parent_repr = super().extra_repr() 126 | return '{},\n{}'.format(parent_repr, quant_state) if parent_repr else quant_state 127 | 128 | 129 | class QuantizedActivation(QuantizedModule): 130 | def __init__(self, *args, **kwargs): 131 | super().__init__(*args, **kwargs) 132 | act_qparams = dict(n_bits=self.n_bits_act, scale_domain=self.scale_domain) 133 | self.activation_quantizer = QuantizationManager( 134 | qmethod=self.act_method, 135 | qparams=act_qparams, 136 | init=self.act_range_method, 137 | init_params=self.act_range_options, 138 | ) 139 | 140 | def quantize_activations(self, x): 141 | if self._quant_a: 142 | return self.activation_quantizer(x) 143 | else: 144 | return x 145 | 146 | def forward(self, x): 147 | return self.quantize_activations(x) 148 | 149 | 150 | class FP32Acts(nn.Module): 151 | def forward(self, x): 152 | return x 153 | 154 | def reset_ranges(self): 155 | pass 156 | -------------------------------------------------------------------------------- /quantization/base_quantized_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from torch import nn 5 | 6 | from quantization.base_quantized_classes import ( 7 | QuantizedModule, 8 | _set_layer_learn_ranges, 9 | _set_layer_fix_ranges, 10 | _set_layer_estimate_ranges, 11 | _set_layer_estimate_ranges_train, 12 | ) 13 | 14 | 15 | class QuantizedModel(nn.Module): 16 | """ 17 | Parent class for a quantized model. This allows you to have convenience functions to put the 18 | whole model into quantization or full precision or to freeze BN. Otherwise it does not add any 19 | further functionality, so it is not a necessity that a quantized model uses this class. 
20 | """ 21 | def quantized_weights(self): 22 | def _fn(layer): 23 | if isinstance(layer, QuantizedModule): 24 | layer.quantized_weights() 25 | 26 | self.apply(_fn) 27 | 28 | def full_precision_weights(self): 29 | def _fn(layer): 30 | if isinstance(layer, QuantizedModule): 31 | layer.full_precision_weights() 32 | 33 | self.apply(_fn) 34 | 35 | def quantized_acts(self): 36 | def _fn(layer): 37 | if isinstance(layer, QuantizedModule): 38 | layer.quantized_acts() 39 | 40 | self.apply(_fn) 41 | 42 | def full_precision_acts(self): 43 | def _fn(layer): 44 | if isinstance(layer, QuantizedModule): 45 | layer.full_precision_acts() 46 | 47 | self.apply(_fn) 48 | 49 | def quantized(self): 50 | def _fn(layer): 51 | if isinstance(layer, QuantizedModule): 52 | layer.quantized() 53 | 54 | self.apply(_fn) 55 | 56 | def full_precision(self): 57 | def _fn(layer): 58 | if isinstance(layer, QuantizedModule): 59 | layer.full_precision() 60 | 61 | self.apply(_fn) 62 | 63 | # Methods for switching quantizer quantization states 64 | def learn_ranges(self): 65 | self.apply(_set_layer_learn_ranges) 66 | 67 | def fix_ranges(self): 68 | self.apply(_set_layer_fix_ranges) 69 | 70 | def fix_act_ranges(self): 71 | def _fn(module): 72 | if isinstance(module, QuantizedModule) and hasattr(module, 'activation_quantizer'): 73 | _set_layer_fix_ranges(module.activation_quantizer) 74 | 75 | self.apply(_fn) 76 | 77 | def fix_weight_ranges(self): 78 | def _fn(module): 79 | if isinstance(module, QuantizedModule) and hasattr(module, 'weight_quantizer'): 80 | _set_layer_fix_ranges(module.weight_quantizer) 81 | 82 | self.apply(_fn) 83 | 84 | def estimate_ranges(self): 85 | self.apply(_set_layer_estimate_ranges) 86 | 87 | def estimate_act_ranges(self): 88 | def _fn(module): 89 | if isinstance(module, QuantizedModule) and hasattr(module, 'activation_quantizer'): 90 | _set_layer_estimate_ranges(module.activation_quantizer) 91 | 92 | self.apply(_fn) 93 | 94 | def estimate_ranges_train(self): 95 | self.apply(_set_layer_estimate_ranges_train) 96 | 97 | def reset_act_ranges(self): 98 | def _fn(module): 99 | if isinstance(module, QuantizedModule) and hasattr(module, 'activation_quantizer'): 100 | module.activation_quantizer.reset_ranges() 101 | 102 | self.apply(_fn) 103 | 104 | def set_quant_state(self, weight_quant, act_quant): 105 | if act_quant: 106 | self.quantized_acts() 107 | else: 108 | self.full_precision_acts() 109 | 110 | if weight_quant: 111 | self.quantized_weights() 112 | else: 113 | self.full_precision_weights() 114 | -------------------------------------------------------------------------------- /quantization/hijacker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import copy 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from quantization.base_quantized_classes import QuantizedModule 10 | from quantization.quantization_manager import QuantizationManager 11 | from quantization.range_estimators import RangeEstimators 12 | from quantization.utils import to_numpy 13 | 14 | 15 | activations_list = [nn.ReLU, nn.ReLU6, nn.Hardtanh, nn.Sigmoid, nn.Tanh, nn.PReLU, nn.GELU] 16 | 17 | 18 | class QuantizationHijacker(QuantizedModule): 19 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and 20 | dequantization on the weights and output distributions. 
21 | 22 | Usage: 23 | To make a quantized nn.Linear layer: 24 | ``` 25 | >>> class QuantLinear(QuantizationHijacker, nn.Linear): 26 | ... pass 27 | ``` 28 | 29 | It is vital that QSchemeForwardHijacker is the first parent class, and that the second parent 30 | class derives from nn.Module, otherwise it will not be reached by a super(., .) call. 31 | 32 | NB: this implementation (for now) assumes that there will always be some training involved, 33 | e.g. to estimate the activation ranges. 34 | """ 35 | def __init__(self, *args, activation: nn.Module = None, **kwargs): 36 | super().__init__(*args, **kwargs) 37 | if activation: 38 | assert isinstance(activation, tuple(activations_list)) 39 | self.activation_function = copy.deepcopy(activation) if activation else None 40 | 41 | weight_qparams = dict(n_bits=self.n_bits, scale_domain=self.scale_domain) 42 | act_qparams = dict(n_bits=self.n_bits_act, scale_domain=self.scale_domain) 43 | 44 | self.activation_quantizer = QuantizationManager( 45 | qmethod=self.act_method, 46 | init=self.act_range_method, 47 | per_channel=self.per_channel_acts, 48 | qparams=act_qparams, 49 | init_params=self.act_range_options, 50 | ) 51 | 52 | if self.weight_range_method == RangeEstimators.current_minmax: 53 | weight_init_params = dict(percentile=self.percentile) 54 | else: 55 | weight_init_params = self.weight_range_options 56 | self.weight_quantizer = QuantizationManager( 57 | qmethod=self.method, 58 | init=self.weight_range_method, 59 | per_channel=self.per_channel_weights, 60 | qparams=weight_qparams, 61 | init_params=weight_init_params, 62 | ) 63 | self.activation_save_target = None 64 | self.activation_save_name = None 65 | 66 | def forward(self, x, offsets=None): 67 | weight, bias = self.get_params() 68 | res = self.run_forward(x, weight, bias, offsets=offsets) 69 | res = self.quantize_activations(res) 70 | return res 71 | 72 | def get_params(self): 73 | if not self.training and self.cached_params: 74 | return self.cached_params 75 | 76 | weight, bias = self.get_weight_bias() 77 | 78 | if self._quant_w: 79 | weight = self.weight_quantizer(weight) 80 | 81 | if self._caching and not self.training and self.cached_params is None: 82 | self.cached_params = ( 83 | torch.Tensor(to_numpy(weight)).to(weight.device), 84 | torch.Tensor(to_numpy(bias)).to(bias.device) if bias is not None else None, 85 | ) 86 | return weight, bias 87 | 88 | def get_weight_bias(self): 89 | bias = None 90 | if hasattr(self, "bias"): 91 | bias = self.bias 92 | return self.weight, bias 93 | 94 | def run_forward(self, x, weight, bias, offsets=None): 95 | # Performs the actual (e.g., linear) operation of the layer 96 | raise NotImplementedError() 97 | 98 | def quantize_activations(self, activations): 99 | """Quantize a single activation tensor or all activations from a layer. I'm assuming that 100 | we should quantize all outputs for a layer with the same quantization scheme. 
101 |         """
102 |         if self.activation_function is not None:
103 |             activations = self.activation_function(activations)
104 | 
105 |         if self.activation_save_target is not None:
106 |             self.activation_save_target[self.activation_save_name] = activations.data.cpu().numpy()
107 | 
108 |         if self._quant_a:
109 |             activations = self.activation_quantizer(activations)
110 | 
111 |         if self.activation_save_target is not None:
112 |             self.activation_save_target[self.activation_save_name + '_Q'] = (
113 |                 activations.data.cpu().numpy()
114 |             )
115 | 
116 |         return activations
117 | 
-------------------------------------------------------------------------------- /quantization/quantization_manager.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | 
4 | from enum import Enum
5 | 
6 | from torch import nn
7 | 
8 | from quantization.quantizers import QMethods, QuantizerNotInitializedError
9 | from quantization.range_estimators import RangeEstimators
10 | 
11 | 
12 | class Qstates(Enum):
13 |     estimate_ranges = 0  # ranges are updated in eval and train mode
14 |     fix_ranges = 1  # quantization ranges are fixed for train and eval
15 |     learn_ranges = 2  # quantization params are nn.Parameters
16 |     estimate_ranges_train = 3  # quantization ranges are updated during train and fixed for eval
17 | 
18 | 
19 | class QuantizationManager(nn.Module):
20 |     """Implementation of Quantization and Quantization Range Estimation
21 | 
22 |     Parameters
23 |     ----------
24 |     n_bits: int
25 |         Number of bits for the quantization (passed via `qparams`).
26 |     qmethod: QMethods member (Enum)
27 |         The quantization scheme to use, e.g. symmetric_uniform or
28 |         asymmetric_uniform.
29 |     init: RangeEstimators member (Enum)
30 |         Range estimation method used to initialize the quantization grid.
31 |     per_channel: bool
32 |         If True, a separate quantization grid is used for each kernel/channel.
33 |     x_min: float or PyTorch Tensor
34 |         The minimum value which needs to be represented.
35 |     x_max: float or PyTorch Tensor
36 |         The maximum value which needs to be represented.
37 | """ 38 | def __init__(self, qmethod=QMethods.symmetric_uniform, init=RangeEstimators.current_minmax, 39 | per_channel=False, axis=None, n_groups=None, x_min=None, x_max=None, qparams=None, 40 | init_params=None): 41 | super().__init__() 42 | self.state = Qstates.estimate_ranges 43 | self.qmethod = qmethod 44 | self.init = init 45 | self.per_channel = per_channel 46 | self.axis = axis 47 | self.n_groups = n_groups 48 | self.qparams = qparams if qparams else {} 49 | self.init_params = init_params if init_params else {} 50 | self.range_estimator = None 51 | 52 | # define quantizer 53 | self.quantizer = self.qmethod.cls(per_channel=per_channel, axis=axis, **qparams) 54 | 55 | # define range estimation method for quantizer initialisation 56 | if x_min is not None and x_max is not None: 57 | self.set_quant_range(x_min, x_max) 58 | self.state = Qstates.fix_ranges 59 | else: 60 | # set up the collector function to set the ranges 61 | self.range_estimator = self.init.cls( 62 | per_channel=self.per_channel, 63 | quantizer=self.quantizer, 64 | axis=self.axis, 65 | n_groups=self.n_groups, 66 | **self.init_params 67 | ) 68 | 69 | @property 70 | def n_bits(self): 71 | return self.quantizer.n_bits 72 | 73 | def estimate_ranges(self): 74 | self.state = Qstates.estimate_ranges 75 | 76 | def fix_ranges(self): 77 | if self.quantizer.is_initialized: 78 | self.state = Qstates.fix_ranges 79 | else: 80 | raise QuantizerNotInitializedError() 81 | 82 | def learn_ranges(self): 83 | self.quantizer.make_range_trainable() 84 | self.state = Qstates.learn_ranges 85 | 86 | def estimate_ranges_train(self): 87 | self.state = Qstates.estimate_ranges_train 88 | 89 | def reset_ranges(self): 90 | self.range_estimator.reset() 91 | self.quantizer.reset() 92 | self.estimate_ranges() 93 | 94 | def forward(self, x): 95 | if self.range_estimator.per_group_range_estimation: 96 | self.range_estimator(x) 97 | return x 98 | 99 | if self.state == Qstates.estimate_ranges or ( 100 | self.state == Qstates.estimate_ranges_train and self.training 101 | ): 102 | # Note this can be per tensor or per channel 103 | cur_xmin, cur_xmax = self.range_estimator(x) 104 | self.set_quant_range(cur_xmin, cur_xmax) 105 | 106 | return self.quantizer(x) 107 | 108 | def set_quant_range(self, x_min, x_max): 109 | self.quantizer.set_quant_range(x_min, x_max) 110 | 111 | def extra_repr(self): 112 | return 'state={}'.format(self.state.name) 113 | -------------------------------------------------------------------------------- /quantization/quantizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
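Before the quantizer implementations below, a minimal illustrative sketch (not taken from the repository's run scripts) of how the `QuantizationManager` defined above is typically driven: while in `estimate_ranges` the clipping range is updated from data on every call, and `fix_ranges()` then freezes the grid. The bit-width, scale domain, and input shape here are arbitrary example choices.

```python
import torch

from quantization.quantization_manager import QuantizationManager
from quantization.quantizers import QMethods
from quantization.range_estimators import RangeEstimators

# 8-bit asymmetric quantizer whose clipping range is estimated with running min-max.
qm = QuantizationManager(
    qmethod=QMethods.asymmetric_uniform,
    init=RangeEstimators.running_minmax,
    qparams=dict(n_bits=8, scale_domain='linear'),
)

x = torch.randn(4, 16)
x_q = qm(x)       # Qstates.estimate_ranges: update min/max, then quantize-dequantize
qm.fix_ranges()   # freeze the estimated grid for training/evaluation
y_q = qm(x)       # quantize-dequantize with the now-fixed grid
```

Passing `x_min`/`x_max` at construction time skips estimation entirely and starts the manager in the `fix_ranges` state, as the constructor above shows.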
3 | 4 | from collections import namedtuple 5 | from enum import Enum 6 | 7 | import torch 8 | from torch.autograd import Function 9 | from torch import nn 10 | 11 | 12 | class RoundStraightThrough(Function): 13 | @staticmethod 14 | def forward(ctx, x): 15 | return torch.round(x) 16 | 17 | @staticmethod 18 | def backward(ctx, output_grad): 19 | return output_grad 20 | 21 | 22 | class FloorStraightThrough(Function): 23 | @staticmethod 24 | def forward(ctx, x): 25 | return torch.floor(x) 26 | 27 | @staticmethod 28 | def backward(ctx, output_grad): 29 | return output_grad 30 | 31 | 32 | round_ste_func = RoundStraightThrough.apply 33 | floor_ste_func = FloorStraightThrough.apply 34 | 35 | 36 | class QuantizerBase(nn.Module): 37 | def __init__(self, n_bits, per_channel=False, axis=None, *args, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | self.n_bits = n_bits 40 | self.per_channel = per_channel 41 | self.axis = axis 42 | 43 | @property 44 | def is_initialized(self): 45 | raise NotImplementedError() 46 | 47 | @property 48 | def x_max(self): 49 | raise NotImplementedError() 50 | 51 | @property 52 | def symmetric(self): 53 | raise NotImplementedError() 54 | 55 | @property 56 | def x_min(self): 57 | raise NotImplementedError() 58 | 59 | def forward(self, x_float): 60 | raise NotImplementedError() 61 | 62 | def _adjust_params_per_axis(self, x): 63 | raise NotImplementedError() 64 | 65 | def _adjust_params_per_channel(self, x): 66 | raise NotImplementedError() 67 | 68 | def set_quant_range(self, x_min, x_max): 69 | raise NotImplementedError() 70 | 71 | def extra_repr(self): 72 | return ( 73 | f'n_bits={self.n_bits}, per_channel={self.per_channel}, axis={self.axis}, ' 74 | f'is_initalized={self.is_initialized}' 75 | ) 76 | 77 | def reset(self): 78 | self._delta = None 79 | 80 | 81 | class AsymmetricUniformQuantizer(QuantizerBase): 82 | """ 83 | PyTorch Module that implements Asymmetric Uniform Quantization using STE. 84 | Quantizes its argument in the forward pass, passes the gradient 'straight 85 | through' on the backward pass, ignoring the quantization that occurred. 86 | 87 | Parameters 88 | ---------- 89 | n_bits: int 90 | Number of bits for quantization. 
91 | scale_domain: str ('log', 'linear) with default='linear' 92 | Domain of scale factor 93 | per_channel: bool 94 | If True: allows for per-channel quantization 95 | """ 96 | def __init__(self, n_bits, scale_domain='linear', per_channel=False, axis=None, eps=1e-8): 97 | 98 | super().__init__(n_bits, per_channel) 99 | 100 | assert scale_domain in ('linear', 'log') 101 | self.register_buffer('_delta', None) 102 | self.register_buffer('_zero_float', None) 103 | self.n_bits = n_bits 104 | self.scale_domain = scale_domain 105 | self.per_channel = per_channel 106 | self.axis = axis 107 | self.eps = eps 108 | 109 | # A few useful properties 110 | @property 111 | def delta(self): 112 | if self._delta is not None: 113 | return self._delta 114 | else: 115 | raise QuantizerNotInitializedError() 116 | 117 | @property 118 | def zero_float(self): 119 | if self._zero_float is not None: 120 | return self._zero_float 121 | else: 122 | raise QuantizerNotInitializedError() 123 | 124 | @property 125 | def is_initialized(self): 126 | return self._delta is not None 127 | 128 | @property 129 | def symmetric(self): 130 | return False 131 | 132 | @property 133 | def int_min(self): 134 | # integer grid minimum 135 | return 0.0 136 | 137 | @property 138 | def int_max(self): 139 | # integer grid maximum 140 | return 2.0 ** self.n_bits - 1 141 | 142 | @property 143 | def scale(self): 144 | if self.scale_domain == 'linear': 145 | return torch.clamp(self.delta, min=self.eps) 146 | elif self.scale_domain == 'log': 147 | return torch.exp(self.delta) 148 | 149 | @property 150 | def zero_point(self): 151 | zero_point = round_ste_func(self.zero_float) 152 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max) 153 | return zero_point 154 | 155 | @property 156 | def x_max(self): 157 | return self.scale * (self.int_max - self.zero_point) 158 | 159 | @property 160 | def x_min(self): 161 | return self.scale * (self.int_min - self.zero_point) 162 | 163 | def _clamp(self, x_int): 164 | with torch.no_grad(): 165 | clampled_left = (x_int > self.int_max).float().sum() 166 | clampled_right = (x_int < self.int_min).float().sum() 167 | self._clamped = (clampled_left + clampled_right) / x_int.numel() 168 | x_clamped = torch.clamp(x_int, self.int_min, self.int_max) 169 | 170 | return x_clamped 171 | 172 | def to_integer_forward(self, x_float): 173 | """ 174 | Qunatized input to its integer represantion 175 | Parameters 176 | ---------- 177 | x_float: PyTorch Float Tensor 178 | Full-precision Tensor 179 | 180 | Returns 181 | ------- 182 | x_int: PyTorch Float Tensor of integers 183 | """ 184 | x_int = round_ste_func(x_float / self.scale) + self.zero_point 185 | x_int = torch.clamp(x_int, self.int_min, self.int_max) 186 | 187 | return x_int 188 | 189 | def forward(self, x_float): 190 | """ 191 | Quantizes (quantized to integer and the scales back to original domain) 192 | Parameters 193 | ---------- 194 | x_float: PyTorch Float Tensor 195 | Full-precision Tensor 196 | 197 | Returns 198 | ------- 199 | x_quant: PyTorch Float Tensor 200 | Quantized-Dequantized Tensor 201 | """ 202 | if self.axis is not None: 203 | self._adjust_params_per_axis(x_float) 204 | 205 | if self.per_channel: 206 | self._adjust_params_per_channel(x_float) 207 | 208 | x_int = self.to_integer_forward(x_float) 209 | x_quant = self.scale * (x_int - self.zero_point) 210 | 211 | return x_quant 212 | 213 | def _adjust_params_per_axis(self, x_float): 214 | r = len(x_float.size()) 215 | new_shape = [1] * self.axis + [-1] + [1] * (r - self.axis - 1) 216 | 
self._delta = self._delta.view(new_shape) 217 | self._zero_float = self._zero_float.view(new_shape) 218 | 219 | def _adjust_params_per_channel(self, x): 220 | """ 221 | Adjusts the quantization parameter tensors (delta, zero_float) 222 | to the input tensor shape if they don't match 223 | 224 | Parameters 225 | ---------- 226 | x: input tensor 227 | """ 228 | if x.ndim != self.delta.ndim: 229 | new_shape = [-1] + [1] * (len(x.shape) - 1) 230 | self._delta = self.delta.view(new_shape) 231 | if self._zero_float is not None: 232 | self._zero_float = self._zero_float.view(new_shape) 233 | 234 | def _tensorize_min_max(self, x_min, x_max): 235 | """ 236 | Converts provided min max range into tensors 237 | Parameters 238 | ---------- 239 | x_min: float or PyTorch 1D tensor 240 | x_max: float or PyTorch 1D tensor 241 | 242 | Returns 243 | ------- 244 | x_min: PyTorch Tensor 0 or 1-D 245 | x_max: PyTorch Tensor 0 or 1-D 246 | """ 247 | # Ensure a torch tensor 248 | if not torch.is_tensor(x_min): 249 | x_min = torch.tensor(x_min).float() 250 | x_max = torch.tensor(x_max).float() 251 | 252 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel and self.axis is None: 253 | raise ValueError( 254 | 'x_min and x_max must be a float or 1-D Tensor' 255 | ' for per-tensor quantization (per_channel=False)' 256 | ) 257 | # Ensure we always use zero and avoid division by zero 258 | x_min = torch.min(x_min, torch.zeros_like(x_min)) 259 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps) 260 | 261 | return x_min, x_max 262 | 263 | def set_quant_range(self, x_min, x_max): 264 | """ 265 | Instantiates the quantization parameters based on the provided 266 | min and max range 267 | 268 | Parameters 269 | ---------- 270 | x_min: tensor or float 271 | Quantization range minimum limit 272 | x_max: tensor of float 273 | Quantization range minimum limit 274 | """ 275 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 276 | self._delta = (x_max - x_min) / self.int_max 277 | self._zero_float = (-x_min / self.delta).detach() 278 | 279 | if self.scale_domain == 'log': 280 | self._delta = torch.log(self.delta) 281 | 282 | self._delta = self._delta.detach() 283 | 284 | def make_range_trainable(self): 285 | # Converts trainable parameters to nn.Parameters 286 | if self.delta not in self.parameters(): 287 | self._delta = torch.nn.Parameter(self._delta) 288 | self._zero_float = torch.nn.Parameter(self._zero_float) 289 | 290 | 291 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer): 292 | """ 293 | PyTorch Module that implements Symmetric Uniform Quantization using STE. 294 | Quantizes its argument in the forward pass, passes the gradient 'straight 295 | through' on the backward pass, ignoring the quantization that occurred. 296 | 297 | Parameters 298 | ---------- 299 | n_bits: int 300 | Number of bits for quantization. 
301 | scale_domain: str ('log', 'linear) with default='linear' 302 | Domain of scale factor 303 | per_channel: bool 304 | If True: allows for per-channel quantization 305 | """ 306 | def __init__(self, *args, **kwargs): 307 | super().__init__(*args, **kwargs) 308 | self.register_buffer('_signed', None) 309 | 310 | @property 311 | def signed(self): 312 | if self._signed is not None: 313 | return self._signed.item() 314 | else: 315 | raise QuantizerNotInitializedError() 316 | 317 | @property 318 | def symmetric(self): 319 | return True 320 | 321 | @property 322 | def int_min(self): 323 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0 324 | 325 | @property 326 | def int_max(self): 327 | pos_n_bits = self.n_bits - self.signed 328 | return 2.0 ** pos_n_bits - 1 329 | 330 | @property 331 | def zero_point(self): 332 | return 0.0 333 | 334 | def set_quant_range(self, x_min, x_max): 335 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 336 | self._signed = x_min.min() < 0 337 | 338 | x_absmax = torch.max(x_min.abs(), x_max) 339 | self._delta = x_absmax / self.int_max 340 | 341 | if self.scale_domain == 'log': 342 | self._delta = torch.log(self._delta) 343 | 344 | self._delta = self._delta.detach() 345 | 346 | def make_range_trainable(self): 347 | # Converts trainable parameters to nn.Parameters 348 | if self.delta not in self.parameters(): 349 | self._delta = torch.nn.Parameter(self._delta) 350 | 351 | 352 | QMethodMap = namedtuple('QMethodMap', ['value', 'cls']) 353 | 354 | 355 | class QMethods(Enum): 356 | symmetric_uniform = QMethodMap(0, SymmetricUniformQuantizer) 357 | asymmetric_uniform = QMethodMap(1, AsymmetricUniformQuantizer) 358 | 359 | @property 360 | def cls(self): 361 | return self.value.cls 362 | 363 | @classmethod 364 | def list(cls): 365 | return [m.name for m in cls] 366 | 367 | 368 | class QuantizerNotInitializedError(Exception): 369 | """Raised when a quantizer has not initialized""" 370 | 371 | def __init__(self): 372 | super(QuantizerNotInitializedError, self).__init__('Quantizer has not been initialized yet') 373 | -------------------------------------------------------------------------------- /quantization/range_estimators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import copy 5 | from collections import namedtuple 6 | from enum import Enum 7 | 8 | import numpy as np 9 | import torch 10 | from scipy.optimize import minimize_scalar 11 | from torch import nn 12 | from torch.nn import functional as F 13 | 14 | from quantization.utils import to_numpy 15 | 16 | 17 | class RangeEstimatorBase(nn.Module): 18 | def __init__(self, per_channel=False, quantizer=None, axis=None, n_groups=None, *args, 19 | **kwargs): 20 | super().__init__(*args, **kwargs) 21 | self.register_buffer('current_xmin', None) 22 | self.register_buffer('current_xmax', None) 23 | self.per_channel = per_channel 24 | self.quantizer = quantizer 25 | self.axis = axis 26 | self.n_groups = n_groups 27 | 28 | self.per_group_range_estimation = False 29 | self.ranges = None 30 | 31 | def forward(self, x): 32 | """ 33 | Accepts an input tensor, updates the current estimates of x_min and x_max 34 | and retruns them. 35 | Parameters 36 | ---------- 37 | x: Input tensor 38 | 39 | Returns 40 | ------- 41 | self.current_xmin: tensor 42 | self.current_xmax: tensor 43 | """ 44 | raise NotImplementedError() 45 | 46 | def reset(self): 47 | """ 48 | Reset the range estimator. 
49 |         """
50 |         self.current_xmin = None
51 |         self.current_xmax = None
52 | 
53 |     def __repr__(self):
54 |         # We overwrite this from nn.Module as we do not want to have submodules such as
55 |         # self.quantizer in the repr. Otherwise it behaves as expected for an nn.Module.
56 |         lines = self.extra_repr().split('\n')
57 |         extra_str = lines[0] if len(lines) == 1 else '\n ' + '\n '.join(lines) + '\n'
58 | 
59 |         return self._get_name() + '(' + extra_str + ')'
60 | 
61 | 
62 | class CurrentMinMaxEstimator(RangeEstimatorBase):
63 |     def __init__(self, percentile=None, *args, **kwargs):
64 |         self.percentile = percentile
65 |         super().__init__(*args, **kwargs)
66 | 
67 |     def forward(self, x):
68 |         if self.per_group_range_estimation:
69 |             assert self.axis != 0
70 |             x = x.transpose(0, self.axis).contiguous()
71 |             x = x.view(x.size(0), -1)
72 | 
73 |             ranges = x.max(-1)[0].detach() - x.min(-1)[0].detach()
74 | 
75 |             if self.ranges is None:
76 |                 self.ranges = ranges
77 |             else:
78 |                 momentum = 0.1  # weight of the newly observed ranges
79 |                 self.ranges = momentum * ranges + (1 - momentum) * self.ranges
80 |             return
81 | 
82 |         if self.axis is not None:
83 |             if self.axis != 0:
84 |                 x = x.transpose(0, self.axis).contiguous()
85 |             x = x.view(x.size(0), -1)
86 | 
87 |             if self.n_groups is not None:
88 |                 ng = self.n_groups
89 |                 assert ng > 0 and x.size(0) % ng == 0
90 |                 gs = x.size(0) // ng
91 | 
92 |                 # permute
93 |                 if self.ranges is not None:
94 |                     i = torch.argsort(self.ranges)
95 |                     I = torch.eye(len(i), device=self.ranges.device)
96 |                     P = I[i]
97 |                     x = P.mm(x)
98 | 
99 |                 x = x.view(ng, -1)
100 |                 m = x.min(-1)[0].detach()
101 |                 M = x.max(-1)[0].detach()
102 | 
103 |                 m = m.repeat_interleave(gs)
104 |                 M = M.repeat_interleave(gs)
105 | 
106 |                 # permute back
107 |                 if self.ranges is not None:
108 |                     m = P.T.mv(m)
109 |                     M = P.T.mv(M)
110 | 
111 |                 self.current_xmin = m
112 |                 self.current_xmax = M
113 | 
114 |             else:
115 |                 self.current_xmin = x.min(-1)[0].detach()
116 |                 self.current_xmax = x.max(-1)[0].detach()
117 | 
118 |         elif self.per_channel:
119 |             # Along 1st dim
120 |             x_flattened = x.view(x.shape[0], -1)
121 |             if self.percentile:
122 |                 data_np = to_numpy(x_flattened)
123 |                 x_min, x_max = np.percentile(
124 |                     data_np, (self.percentile, 100 - self.percentile), axis=-1
125 |                 )
126 |                 self.current_xmin = torch.Tensor(x_min)
127 |                 self.current_xmax = torch.Tensor(x_max)
128 |             else:
129 |                 self.current_xmin = x_flattened.min(-1)[0].detach()
130 |                 self.current_xmax = x_flattened.max(-1)[0].detach()
131 | 
132 |         else:
133 |             if self.percentile:
134 |                 device = x.device
135 |                 data_np = to_numpy(x)
136 |                 x_min, x_max = np.percentile(data_np, (self.percentile, 100))
137 |                 x_min = np.atleast_1d(x_min)
138 |                 x_max = np.atleast_1d(x_max)
139 |                 self.current_xmin = torch.Tensor(x_min).to(device).detach()
140 |                 self.current_xmax = torch.Tensor(x_max).to(device).detach()
141 |             else:
142 |                 self.current_xmin = torch.min(x).detach()
143 |                 self.current_xmax = torch.max(x).detach()
144 | 
145 |         return self.current_xmin, self.current_xmax
146 | 
147 | 
148 | class AllMinMaxEstimator(RangeEstimatorBase):
149 |     def __init__(self, *args, **kwargs):
150 |         super().__init__(*args, **kwargs)
151 | 
152 |     def forward(self, x):
153 |         if self.per_channel:
154 |             # Along 1st dim
155 |             x_flattened = x.view(x.shape[0], -1)
156 |             x_min = x_flattened.min(-1)[0].detach()
157 |             x_max = x_flattened.max(-1)[0].detach()
158 |         else:
159 |             x_min = torch.min(x).detach()
160 |             x_max = torch.max(x).detach()
161 | 
162 |         if self.current_xmin is None:
163 |             self.current_xmin = x_min
164 |             self.current_xmax = x_max
165 |         else:
166 | 
self.current_xmin = torch.min(self.current_xmin, x_min) 167 | self.current_xmax = torch.max(self.current_xmax, x_max) 168 | 169 | return self.current_xmin, self.current_xmax 170 | 171 | 172 | class RunningMinMaxEstimator(RangeEstimatorBase): 173 | def __init__(self, momentum=0.9, *args, **kwargs): 174 | self.momentum = momentum 175 | super().__init__(*args, **kwargs) 176 | 177 | def forward(self, x): 178 | if self.axis is not None: 179 | if self.axis != 0: 180 | x = x.transpose(0, self.axis).contiguous() 181 | x = x.view(x.size(0), -1) 182 | 183 | if self.n_groups is not None: 184 | ng = self.n_groups 185 | assert ng > 0 and x.size(0) % ng == 0 186 | gs = x.size(0) // ng 187 | 188 | x = x.view(ng, -1) 189 | m = x.min(-1)[0].detach() 190 | M = x.max(-1)[0].detach() 191 | 192 | x_min = m.repeat_interleave(gs) 193 | x_max = M.repeat_interleave(gs) 194 | 195 | else: 196 | x_min = x.min(-1)[0].detach() 197 | x_max = x.max(-1)[0].detach() 198 | 199 | elif self.per_channel: 200 | # Along 1st dim 201 | x_flattened = x.view(x.shape[0], -1) 202 | x_min = x_flattened.min(-1)[0].detach() 203 | x_max = x_flattened.max(-1)[0].detach() 204 | 205 | else: 206 | x_min = torch.min(x).detach() 207 | x_max = torch.max(x).detach() 208 | 209 | if self.current_xmin is None: 210 | self.current_xmin = x_min 211 | self.current_xmax = x_max 212 | else: 213 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin 214 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax 215 | 216 | return self.current_xmin, self.current_xmax 217 | 218 | 219 | class OptMethod(Enum): 220 | grid = 1 221 | golden_section = 2 222 | 223 | @classmethod 224 | def list(cls): 225 | return [m.name for m in cls] 226 | 227 | 228 | class MSE_Estimator(RangeEstimatorBase): 229 | def __init__(self, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, *args, 230 | **kwargs): 231 | super().__init__(*args, **kwargs) 232 | assert opt_method in OptMethod 233 | 234 | self.opt_method = opt_method 235 | self.num_candidates = num_candidates 236 | self.loss_array = None 237 | self.max_pos_thr = None 238 | self.max_neg_thr = None 239 | self.max_search_range = None 240 | self.one_sided_dist = None 241 | self.range_margin = range_margin 242 | if self.quantizer is None: 243 | raise NotImplementedError( 244 | 'A Quantizer must be given as an argument to the MSE Range' 'Estimator' 245 | ) 246 | self.max_int_skew = (2 ** self.quantizer.n_bits) // 4 # for asymmetric quantization 247 | 248 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False): 249 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr) 250 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1) 251 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel 252 | # grid search 253 | if per_channel_loss: 254 | return to_numpy(temp_sum) 255 | else: 256 | return to_numpy(torch.sum(temp_sum)) 257 | 258 | @property 259 | def step_size(self): 260 | if self.one_sided_dist is None: 261 | raise NoDataPassedError() 262 | 263 | return self.max_search_range / self.num_candidates 264 | 265 | @property 266 | def optimization_method(self): 267 | if self.one_sided_dist is None: 268 | raise NoDataPassedError() 269 | 270 | if self.opt_method == OptMethod.grid: 271 | # Grid search method 272 | if self.one_sided_dist or self.quantizer.symmetric: 273 | # 1-D grid search 274 | return self._perform_1D_search 275 | else: 276 | # 2-D grid_search 277 | return self._perform_2D_search 278 | elif 
self.opt_method == OptMethod.golden_section: 279 | # Golden section method 280 | if self.one_sided_dist or self.quantizer.symmetric: 281 | return self._golden_section_symmetric 282 | else: 283 | return self._golden_section_asymmetric 284 | else: 285 | raise NotImplementedError('Optimization Method not Implemented') 286 | 287 | def quantize(self, x_float, x_min=None, x_max=None): 288 | temp_q = copy.deepcopy(self.quantizer) 289 | # In the current implementation no optimization procedure requires temp quantizer for 290 | # loss_fx to be per-channel 291 | temp_q.per_channel = False 292 | if x_min or x_max: 293 | temp_q.set_quant_range(x_min, x_max) 294 | return temp_q(x_float) 295 | 296 | def golden_sym_loss(self, range, data): 297 | """ 298 | Loss function passed to the golden section optimizer from scipy in case of symmetric 299 | quantization 300 | """ 301 | neg_thr = 0 if self.one_sided_dist else -range 302 | pos_thr = range 303 | return self.loss_fx(data, neg_thr, pos_thr) 304 | 305 | def golden_asym_shift_loss(self, shift, range, data): 306 | """ 307 | Inner Loss function (shift) passed to the golden section optimizer from scipy 308 | in case of asymmetric quantization 309 | """ 310 | pos_thr = range + shift 311 | neg_thr = -range + shift 312 | return self.loss_fx(data, neg_thr, pos_thr) 313 | 314 | def golden_asym_range_loss(self, range, data): 315 | """ 316 | Outer Loss function (range) passed to the golden section optimizer from scipy in case of 317 | asymmetric quantization 318 | """ 319 | temp_delta = 2 * range / (2 ** self.quantizer.n_bits - 1) 320 | max_shift = temp_delta * self.max_int_skew 321 | result = minimize_scalar( 322 | self.golden_asym_shift_loss, 323 | args=(range, data), 324 | bounds=(-max_shift, max_shift), 325 | method='Bounded', 326 | ) 327 | return result.fun 328 | 329 | def _define_search_range(self, data): 330 | self.channel_groups = len(data) if self.per_channel else 1 331 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device) 332 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device) 333 | 334 | if self.one_sided_dist or self.quantizer.symmetric: 335 | # 1D search space 336 | self.loss_array = np.zeros( 337 | (self.channel_groups, self.num_candidates + 1) 338 | ) # 1D search space 339 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish 340 | # Defining the search range for clipping thresholds 341 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin 342 | self.max_neg_thr = -self.max_pos_thr 343 | self.max_search_range = self.max_pos_thr 344 | else: 345 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth 346 | # index represents whether the skew is positive (0) or negative (1)) 347 | self.loss_array = np.zeros( 348 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2] 349 | ) # 2D search space 350 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish 351 | # Define the search range for clipping thresholds in asymmetric case 352 | self.max_pos_thr = float(data.max()) + self.range_margin 353 | self.max_neg_thr = float(data.min()) - self.range_margin 354 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr)) 355 | 356 | def _perform_1D_search(self, data): 357 | """ 358 | Grid search through all candidate quantizers in 1D to find the best 359 | The loss is accumulated over all batches without any momentum 360 | :param data: input tensor 361 | """ 362 | for cand_index in range(1, 
self.num_candidates + 1): 363 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index 364 | pos_thr = self.step_size * cand_index 365 | 366 | self.loss_array[:, cand_index] += self.loss_fx( 367 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 368 | ) 369 | # find the best clipping thresholds 370 | min_cand = self.loss_array.argmin(axis=1) 371 | xmin = ( 372 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand 373 | ).astype(np.single) 374 | xmax = (self.step_size * min_cand).astype(np.single) 375 | self.current_xmax = torch.tensor(xmax).to(device=data.device) 376 | self.current_xmin = torch.tensor(xmin).to(device=data.device) 377 | 378 | def _perform_2D_search(self, data): 379 | """ 380 | Grid search through all candidate quantizers in 1D to find the best 381 | The loss is accumulated over all batches without any momentum 382 | Parameters 383 | ---------- 384 | data: PyTorch Tensor 385 | Returns 386 | ------- 387 | 388 | """ 389 | for cand_index in range(1, self.num_candidates + 1): 390 | # defining the symmetric quantization range 391 | temp_start = -self.step_size * cand_index 392 | temp_finish = self.step_size * cand_index 393 | temp_delta = float(temp_finish - temp_start) / (2 ** self.quantizer.n_bits - 1) 394 | for shift in range(self.max_int_skew): 395 | for reverse in range(2): 396 | # introducing asymmetry in the quantization range 397 | skew = ((-1) ** reverse) * shift * temp_delta 398 | neg_thr = max(temp_start + skew, self.max_neg_thr) 399 | pos_thr = min(temp_finish + skew, self.max_pos_thr) 400 | 401 | self.loss_array[:, cand_index, shift, reverse] += self.loss_fx( 402 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 403 | ) 404 | 405 | for channel_index in range(self.channel_groups): 406 | min_cand, min_shift, min_reverse = np.unravel_index( 407 | np.argmin(self.loss_array[channel_index], axis=None), 408 | self.loss_array[channel_index].shape, 409 | ) 410 | min_interval_start = -self.step_size * min_cand 411 | min_interval_finish = self.step_size * min_cand 412 | min_delta = float(min_interval_finish - min_interval_start) / ( 413 | 2 ** self.quantizer.n_bits - 1 414 | ) 415 | min_skew = ((-1) ** min_reverse) * min_shift * min_delta 416 | xmin = max(min_interval_start + min_skew, self.max_neg_thr) 417 | xmax = min(min_interval_finish + min_skew, self.max_pos_thr) 418 | 419 | self.current_xmin[channel_index] = torch.tensor(xmin).to(device=data.device) 420 | self.current_xmax[channel_index] = torch.tensor(xmax).to(device=data.device) 421 | 422 | def _golden_section_symmetric(self, data): 423 | for channel_index in range(self.channel_groups): 424 | if channel_index == 0 and not self.per_channel: 425 | data_segment = data 426 | else: 427 | data_segment = data[channel_index] 428 | 429 | self.result = minimize_scalar( 430 | self.golden_sym_loss, 431 | args=data_segment, 432 | bounds=(0.01 * self.max_search_range, self.max_search_range), 433 | method='Bounded', 434 | ) 435 | self.current_xmax[channel_index] = torch.tensor(self.result.x).to(device=data.device) 436 | self.current_xmin[channel_index] = ( 437 | torch.tensor(0.0).to(device=data.device) 438 | if self.one_sided_dist 439 | else -self.current_xmax[channel_index] 440 | ) 441 | 442 | def _golden_section_asymmetric(self, data): 443 | for channel_index in range(self.channel_groups): 444 | if channel_index == 0 and not self.per_channel: 445 | data_segment = data 446 | else: 447 | data_segment = data[channel_index] 448 | 449 | self.result = minimize_scalar( 450 
| self.golden_asym_range_loss, 451 | args=data_segment, 452 | bounds=(0.01 * self.max_search_range, self.max_search_range), 453 | method='Bounded', 454 | ) 455 | self.final_range = self.result.x 456 | temp_delta = 2 * self.final_range / (2 ** self.quantizer.n_bits - 1) 457 | max_shift = temp_delta * self.max_int_skew 458 | self.subresult = minimize_scalar( 459 | self.golden_asym_shift_loss, 460 | args=(self.final_range, data_segment), 461 | bounds=(-max_shift, max_shift), 462 | method='Bounded', 463 | ) 464 | self.final_shift = self.subresult.x 465 | self.current_xmax[channel_index] = torch.tensor(self.final_range + self.final_shift).to( 466 | device=data.device 467 | ) 468 | self.current_xmin[channel_index] = torch.tensor( 469 | -self.final_range + self.final_shift 470 | ).to(device=data.device) 471 | 472 | def forward(self, data): 473 | if self.loss_array is None: 474 | # Initialize search range on first batch, and accumulate losses with subsequent calls 475 | 476 | # Decide whether input distribution is one-sided 477 | if self.one_sided_dist is None: 478 | self.one_sided_dist = bool((data.min() >= 0).item()) 479 | 480 | # Define search 481 | self._define_search_range(data) 482 | 483 | # Perform Search/Optimization for Quantization Ranges 484 | self.optimization_method(data) 485 | 486 | return self.current_xmin, self.current_xmax 487 | 488 | def reset(self): 489 | super().reset() 490 | self.loss_array = None 491 | 492 | 493 | class CrossEntropyEstimator(MSE_Estimator): 494 | def __init__(self, *args, **kwargs): 495 | super().__init__(*args, **kwargs) 496 | 497 | # per_channel_loss argument is here only to be consistent in definition with other loss fxs 498 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False): 499 | quantized_data = self.quantize(data, neg_thr, pos_thr) 500 | log_quantized_probs = F.log_softmax(quantized_data, dim=1) 501 | unquantized_probs = F.softmax(data, dim=1) 502 | return to_numpy(torch.sum(-unquantized_probs * log_quantized_probs)) 503 | 504 | 505 | class NoDataPassedError(Exception): 506 | """Raised data has been passed inot the Range Estimator.""" 507 | 508 | def __init__(self): 509 | super().__init__('Data must be pass through the range estimator to be initialized') 510 | 511 | 512 | RangeEstimatorMap = namedtuple('RangeEstimatorMap', ['value', 'cls']) 513 | 514 | 515 | class RangeEstimators(Enum): 516 | current_minmax = RangeEstimatorMap(0, CurrentMinMaxEstimator) 517 | allminmax = RangeEstimatorMap(1, AllMinMaxEstimator) 518 | running_minmax = RangeEstimatorMap(2, RunningMinMaxEstimator) 519 | MSE = RangeEstimatorMap(3, MSE_Estimator) 520 | cross_entropy = RangeEstimatorMap(4, CrossEntropyEstimator) 521 | 522 | @property 523 | def cls(self): 524 | return self.value.cls 525 | 526 | @classmethod 527 | def list(cls): 528 | return [m.name for m in cls] 529 | -------------------------------------------------------------------------------- /quantization/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def to_numpy(tensor): 9 | """ 10 | Helper function that turns the given tensor into a numpy array. 
11 | 12 | Parameters 13 | ---------- 14 | tensor : torch.Tensor 15 | 16 | Returns 17 | ------- 18 | tensor : float or np.array 19 | """ 20 | if isinstance(tensor, np.ndarray): 21 | return tensor 22 | if hasattr(tensor, 'is_cuda'): 23 | if tensor.is_cuda: 24 | return tensor.cpu().detach().numpy() 25 | if hasattr(tensor, 'detach'): 26 | return tensor.detach().numpy() 27 | if hasattr(tensor, 'numpy'): 28 | return tensor.numpy() 29 | 30 | return np.array(tensor) 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click~=7.0 2 | datasets~=1.1.3 3 | numpy~=1.17.0 4 | scikit-learn~=0.19.1 5 | scipy~=1.3.1 6 | tensorboardX~=1.7 7 | torchsummary 8 | tqdm 9 | transformers~=4.1.0 10 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from utils.adaround_utils import apply_adaround_to_model 5 | from utils.glue_tasks import GLUE_Task, TASK_TO_FINAL_METRIC, load_task_data, make_compute_metric_fn 6 | from utils.hf_models import HF_Models, load_model_and_tokenizer 7 | from utils.per_embd_quant_utils import ( 8 | hijack_act_quant, 9 | hijack_weight_quant, 10 | hijack_act_quant_modules, 11 | set_act_quant_axis_and_groups, 12 | ) 13 | from utils.qat_utils import prepare_model_for_quantization 14 | from utils.quant_click_options import ( 15 | quantization_options, 16 | activation_quantization_options, 17 | qat_options, 18 | adaround_options, 19 | make_qparams, 20 | split_dict, 21 | ) 22 | from utils.tb_utils import _tb_advance_global_step, _tb_advance_token_counters, _tb_hist 23 | from utils.transformer_click_options import ( 24 | glue_options, 25 | transformer_base_options, 26 | transformer_data_options, 27 | transformer_model_options, 28 | transformer_training_options, 29 | transformer_progress_options, 30 | transformer_quant_options, 31 | ) 32 | from utils.utils import ( 33 | seed_all, 34 | count_params, 35 | count_embedding_params, 36 | pass_data_for_range_estimation, 37 | DotDict, 38 | Stopwatch, 39 | StopForwardException, 40 | ) 41 | -------------------------------------------------------------------------------- /utils/adaround_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
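Before the AdaRound utilities below, a short illustration of the namedtuple-backed enums exported by the quantization package above (`QMethods`, `RangeEstimators`): each member wraps a `(value, cls)` pair, so a user-facing name, e.g. from the click options later in `utils/`, can be resolved to an implementation class. The 8-bit setting is just an example value.

```python
from quantization.quantizers import QMethods
from quantization.range_estimators import RangeEstimators

# Resolve names to implementation classes via the `.cls` property.
quantizer_cls = QMethods['symmetric_uniform'].cls       # SymmetricUniformQuantizer
estimator_cls = RangeEstimators['running_minmax'].cls   # RunningMinMaxEstimator

quantizer = quantizer_cls(n_bits=8)
estimator = estimator_cls(quantizer=quantizer)

print(QMethods.list())         # ['symmetric_uniform', 'asymmetric_uniform']
print(RangeEstimators.list())  # ['current_minmax', 'allminmax', 'running_minmax', 'MSE', 'cross_entropy']
```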
3 | 4 | import logging 5 | 6 | import torch 7 | 8 | from quantization.adaround import apply_adaround_to_layer 9 | from quantization.adaround.utils import AdaRoundActQuantMode 10 | from quantization.base_quantized_classes import QuantizedModule 11 | from utils.utils import pass_data_for_range_estimation, Stopwatch 12 | 13 | 14 | # setup logger 15 | logger = logging.getLogger('AdaRound') 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | def get_train_samples(data_loader, num_samples, return_labels=False, inp_idx=0, lbl_idx=1): 20 | X, y = [], [] 21 | for data in data_loader: 22 | X.append(data[inp_idx]) 23 | if return_labels: 24 | y.append(data[lbl_idx]) 25 | if len(X) * data[inp_idx].size(0) >= num_samples: 26 | break 27 | 28 | X = torch.cat(X, dim=0)[:num_samples] 29 | if return_labels: 30 | y = torch.cat(y, dim=0)[:num_samples] 31 | return X, y 32 | return X 33 | 34 | 35 | def apply_adaround_to_model(config, model, data_loader, range_est_data_loader, batch_size, 36 | driver=None, get_samples_fn=get_train_samples, inp_idx=0): 37 | """ 38 | Apply AdaRound to `model`. 39 | 40 | Parameters 41 | ---------- 42 | config : DotDict 43 | DotDict with quantization parameters 44 | model : QuantizedModel 45 | model to apply AdaRound on 46 | data_loader : torch.utils.data.DataLoader 47 | Training or other data used for AdaRound optimization 48 | driver : SupervisedDriver 49 | Used for validation. This is only used fore reporting accuracy, not for optimization 50 | inp_idx : int, str 51 | batch index in the input data from the dataloader 52 | """ 53 | train_data = get_samples_fn(data_loader, num_samples=config.adaround.num_samples) 54 | 55 | device = next(model.parameters()).device 56 | train_data = train_data.to(device) 57 | 58 | # check and prepare list of layers to optimize for 59 | all_layer_names = [] 60 | for name, module in model.named_modules(): 61 | if isinstance(module, QuantizedModule) and hasattr(module, 'weight'): 62 | all_layer_names.append(name) 63 | 64 | if 'all' in config.adaround.layers: 65 | adaround_layer_names = all_layer_names 66 | else: 67 | adaround_layer_names = [] 68 | for name in config.adaround.layers: 69 | if name in all_layer_names: 70 | adaround_layer_names.append(name) 71 | else: 72 | logger.warning(f'skipping unknown layer {name}') 73 | 74 | if not len(adaround_layer_names): 75 | logger.warning('No layers to apply AdaRound for, exiting...') 76 | return 77 | 78 | # deal with activation quantization 79 | if config.adaround.act_quant_mode in ( 80 | AdaRoundActQuantMode.no_act_quant, 81 | AdaRoundActQuantMode.post_adaround, 82 | ): 83 | config.quant.act_quant = False 84 | model.reset_act_ranges() 85 | model.full_precision_acts() 86 | else: 87 | raise NotImplementedError(f"act mode '{config.adaround.act_quant_mode}' is not implemented") 88 | 89 | # main loop 90 | s_all = Stopwatch() 91 | for name, module in model.named_modules(): 92 | if not name in adaround_layer_names: 93 | continue 94 | 95 | logger.info(f'Started AdaRound for layer {name}') 96 | 97 | model.full_precision() 98 | module.quantized_weights() 99 | 100 | s_all.start() 101 | with Stopwatch() as s_layer: 102 | apply_adaround_to_layer( 103 | model, 104 | module, 105 | train_data, 106 | batch_size=batch_size, 107 | act_quant=config.quant.act_quant, 108 | adaround_config=config.adaround, 109 | ) 110 | logger.info(f'Done AdaRound for layer {name}. {s_layer.format()}\n') 111 | s_all.stop() 112 | 113 | s_all.stop() 114 | logger.info(f'Done optimizing all layers. 
{s_all.format()}') 115 | 116 | if config.adaround.act_quant_mode == AdaRoundActQuantMode.post_adaround: 117 | if driver is not None: 118 | # validate before activation quantization 119 | model.quantized_weights() 120 | state = driver.validate() 121 | acc_quant = state.metrics['top_1_accuracy'] 122 | logger.info(f'FINAL res (without acts quant):\t{acc_quant * 100:.2f}%') 123 | 124 | # activate activation quantization and estimate ranges 125 | config.quant.act_quant = True 126 | model.estimate_act_ranges() 127 | pass_data_for_range_estimation( 128 | loader=range_est_data_loader, 129 | model=model, 130 | act_quant=True, 131 | weight_quant=True, 132 | max_num_batches=config.act_quant.num_batches, 133 | cross_entropy_layer=config.act_quant.cross_entropy_layer, 134 | inp_idx=inp_idx, 135 | ) 136 | model.fix_act_ranges() 137 | 138 | # set state 139 | model.set_quant_state(weight_quant=True, act_quant=config.quant.act_quant) 140 | -------------------------------------------------------------------------------- /utils/glue_tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | from enum import Flag, auto 6 | from functools import reduce 7 | from operator import or_ 8 | 9 | import numpy as np 10 | from datasets import load_dataset, load_metric 11 | from transformers import EvalPrediction 12 | 13 | from utils.utils import DotDict 14 | 15 | 16 | # setup logger 17 | logger = logging.getLogger('GLUE') 18 | logger.setLevel(logging.INFO) 19 | 20 | 21 | class GLUE_Task(Flag): 22 | cola = auto() 23 | sst2 = auto() 24 | mrpc = auto() 25 | stsb = auto() 26 | qqp = auto() 27 | mnli = auto() 28 | qnli = auto() 29 | rte = auto() 30 | wnli = auto() 31 | all = cola | sst2 | mrpc | stsb | qqp | mnli | qnli | rte | wnli 32 | 33 | def __contains__(self, item): 34 | return (self.value & item.value) == item.value 35 | 36 | @classmethod 37 | def from_str(cls, *names): 38 | """Construct flag from strings.""" 39 | assert len(names) 40 | return reduce(or_, map(cls.__getattr__, names)) 41 | 42 | @classmethod 43 | def list_names(cls): 44 | """List all flags, including `all`.""" 45 | return [m.name for m in cls] 46 | 47 | def iter(self): 48 | """List all member flags which are set (excluding `all`).""" 49 | for x in self.__class__.__members__.values(): 50 | if x in self and x != self.__class__.all: 51 | yield x 52 | 53 | def iter_names(self): 54 | """List all member flag names which are set (excluding `all`).""" 55 | for x in self.iter(): 56 | yield x.name 57 | 58 | 59 | TASK_TO_SENTENCE_KEYS = { 60 | GLUE_Task.cola: ('sentence', None), 61 | GLUE_Task.sst2: ('sentence', None), 62 | GLUE_Task.mrpc: ('sentence1', 'sentence2'), 63 | GLUE_Task.stsb: ('sentence1', 'sentence2'), 64 | GLUE_Task.qqp: ('question1', 'question2'), 65 | GLUE_Task.mnli: ('premise', 'hypothesis'), 66 | GLUE_Task.qnli: ('question', 'sentence'), 67 | GLUE_Task.rte: ('sentence1', 'sentence2'), 68 | GLUE_Task.wnli: ('sentence1', 'sentence2'), 69 | } 70 | 71 | 72 | TASK_TO_FINAL_METRIC = { 73 | GLUE_Task.cola: 'matthews_correlation', 74 | GLUE_Task.sst2: 'accuracy', 75 | GLUE_Task.mrpc: 'combined_score', 76 | GLUE_Task.stsb: 'combined_score', 77 | GLUE_Task.qqp: 'combined_score', 78 | GLUE_Task.mnli: 'accuracy', 79 | GLUE_Task.qnli: 'accuracy', 80 | GLUE_Task.rte: 'accuracy', 81 | GLUE_Task.wnli: 'accuracy', 82 | } 83 | 84 | 85 | TASK_N = { 86 | GLUE_Task.mnli: 392702, 87 | GLUE_Task.qqp: 363846, 88 | GLUE_Task.qnli: 
104743, 89 | GLUE_Task.sst2: 67349, 90 | GLUE_Task.cola: 8551, 91 | GLUE_Task.stsb: 5749, 92 | GLUE_Task.mrpc: 3665, 93 | GLUE_Task.rte: 2490, 94 | GLUE_Task.wnli: 635, 95 | } 96 | 97 | 98 | def load_task_data(task: GLUE_Task, data_dir: str): 99 | out = DotDict() 100 | 101 | # download and load data 102 | logger.info(f'Getting {task.name} dataset ...\n') 103 | out.datasets = load_dataset('glue', task.name, cache_dir=data_dir) 104 | 105 | # determine number of labels 106 | logger.info('Determine labels ...\n') 107 | if task == GLUE_Task.stsb: # regression 108 | out.num_labels = 1 109 | logger.info(f'{task.name}: 1 label -- ') 110 | else: 111 | label_list = out.datasets["train"].features["label"].names 112 | out.num_labels = n_labels = len(label_list) 113 | logger.info(f'{task.name}: {n_labels} labels -- {label_list}') 114 | 115 | # store sentence keys 116 | out.sentence1_key, out.sentence2_key = TASK_TO_SENTENCE_KEYS[task] 117 | return out 118 | 119 | 120 | def make_compute_metric_fn(task: GLUE_Task): 121 | metric = load_metric('glue', task.name) 122 | logger.info('Metric:') 123 | logger.info(metric) 124 | 125 | def fn(p: EvalPrediction): 126 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 127 | preds = np.squeeze(preds) if task == GLUE_Task.stsb else np.argmax(preds, axis=1) 128 | result = metric.compute(predictions=preds, references=p.label_ids) 129 | if len(result) > 1: 130 | result['combined_score'] = np.mean(list(result.values())).item() 131 | return result 132 | 133 | return fn 134 | -------------------------------------------------------------------------------- /utils/hf_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | from enum import Enum 6 | 7 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer 8 | 9 | from utils.utils import count_embedding_params, count_params, DotDict 10 | 11 | 12 | logger = logging.getLogger('GLUE') 13 | logger.setLevel(logging.ERROR) 14 | 15 | 16 | class HF_Models(Enum): 17 | # vanilla BERT 18 | bert_base_uncased = 'bert-base-uncased' 19 | bert_large_uncased = 'bert-large-uncased' 20 | bert_base_cased = 'bert-base-cased' 21 | 22 | # RoBERTa 23 | roberta_base = 'roberta-base' 24 | 25 | # Distilled vanilla models 26 | distilbert_base_uncased = 'distilbert-base-uncased' 27 | distilroberta_base = 'distilroberta-base' 28 | 29 | # Models optimized for runtime on mobile devices 30 | mobilebert_uncased = 'google/mobilebert-uncased' 31 | squeezebert_uncased = 'squeezebert/squeezebert-uncased' 32 | 33 | # ALBERT: very small set of models (optimized for memory) 34 | albert_base_v2 = 'albert-base-v2' 35 | albert_large_v2 = 'albert-large-v2' 36 | 37 | @classmethod 38 | def list_names(cls): 39 | return [m.name for m in cls] 40 | 41 | 42 | MODEL_TO_BACKBONE_ATTR = { # model.. 
43 |     HF_Models.bert_base_uncased: 'bert',
44 |     HF_Models.bert_large_uncased: 'bert',
45 |     HF_Models.bert_base_cased: 'bert',
46 |     HF_Models.distilroberta_base: 'roberta',
47 |     HF_Models.roberta_base: 'roberta',
48 |     HF_Models.mobilebert_uncased: 'mobilebert',
49 | }
50 | 
51 | 
52 | def load_model_and_tokenizer(model_name, model_path, use_fast_tokenizer, cache_dir, attn_dropout,
53 |                              hidden_dropout, num_labels, **kw):
54 |     del kw  # unused
55 | 
56 |     out = DotDict()
57 | 
58 |     # Config
59 |     if model_path is not None:
60 |         model_name_or_path = model_path
61 |     else:
62 |         model_name_or_path = HF_Models[model_name].value  # use HF identifier
63 | 
64 |     config = AutoConfig.from_pretrained(
65 |         model_name_or_path,
66 |         num_labels=num_labels,
67 |         cache_dir=cache_dir,
68 |     )
69 | 
70 |     # set dropout rates
71 |     if attn_dropout is not None:
72 |         logger.info(f'Setting attn dropout to {attn_dropout}')
73 |         if hasattr(config, 'attention_probs_dropout_prob'):
74 |             setattr(config, 'attention_probs_dropout_prob', attn_dropout)
75 | 
76 |     if hidden_dropout is not None:
77 |         logger.info(f'Setting hidden dropout to {hidden_dropout}')
78 |         if hasattr(config, 'hidden_dropout_prob'):
79 |             setattr(config, 'hidden_dropout_prob', hidden_dropout)
80 | 
81 |     logger.info('HuggingFace model config:')
82 |     logger.info(config)
83 |     out.config = config
84 |     out.model_name_or_path = model_name_or_path
85 | 
86 |     # Tokenizer
87 |     tokenizer = AutoTokenizer.from_pretrained(
88 |         model_name_or_path,
89 |         use_fast=use_fast_tokenizer,
90 |         cache_dir=cache_dir,
91 |     )
92 |     logger.info('Tokenizer:')
93 |     logger.info(tokenizer)
94 |     out.tokenizer = tokenizer
95 | 
96 |     # Model
97 |     model = AutoModelForSequenceClassification.from_pretrained(
98 |         model_name_or_path,
99 |         from_tf=False,
100 |         config=config,
101 |         cache_dir=cache_dir,
102 |     )
103 |     logger.info('Model:')
104 |     logger.info(model)
105 |     out.model = model
106 | 
107 |     # Parameter counts
108 |     total_params = count_params(model)
109 |     embedding_params = count_embedding_params(model)
110 |     non_embedding_params = total_params - embedding_params
111 |     logger.info(f'Parameters (embedding): {embedding_params}')
112 |     logger.info(f'Parameters (non-embedding): {non_embedding_params}')
113 |     logger.info(f'Parameters (total): {total_params}')
114 |     out.total_params = total_params
115 |     out.embedding_params = embedding_params
116 |     out.non_embedding_params = non_embedding_params
117 | 
118 |     # Additional attributes
119 |     out.model_enum = HF_Models[model_name]
120 |     out.backbone_attr = MODEL_TO_BACKBONE_ATTR.get(out.model_enum, None)
121 |     return out
122 | 
-------------------------------------------------------------------------------- /utils/per_embd_quant_utils.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
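Before the per-embedding quantization helpers below, a hedged sketch of calling `load_model_and_tokenizer` above directly. The argument values are illustrative placeholders (in the repository they come from the click options), and `num_labels=2` simply assumes a binary GLUE task.

```python
from utils.hf_models import load_model_and_tokenizer

loaded = load_model_and_tokenizer(
    model_name='bert_base_uncased',  # key of the HF_Models enum
    model_path=None,                 # or a path to a fine-tuned checkpoint
    use_fast_tokenizer=True,
    cache_dir=None,
    attn_dropout=None,               # None keeps the pretrained config's dropout rates
    hidden_dropout=None,
    num_labels=2,
)
model, tokenizer = loaded.model, loaded.tokenizer
print(loaded.total_params, loaded.backbone_attr)  # parameter count and e.g. 'bert'
```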
3 | 4 | from quantization.base_quantized_classes import FP32Acts 5 | 6 | 7 | def _hijack_act_quant(module, value): 8 | if value is None: 9 | return 10 | 11 | if isinstance(value, int): 12 | module.activation_quantizer.quantizer.n_bits = value 13 | elif value == 'fp32': 14 | module.activation_quantizer = FP32Acts() 15 | elif value == 'per_embd': 16 | set_act_quant_axis_and_groups(module, axis=2, n_groups=None) 17 | elif value.startswith('ngp'): 18 | set_act_quant_axis_and_groups(module, axis=2, n_groups=int(value[3:]), permute=True) 19 | elif value.startswith('ng'): 20 | set_act_quant_axis_and_groups(module, axis=2, n_groups=int(value[2:]), permute=False) 21 | else: 22 | raise NotImplementedError(f'Unknown value "{value}" in quant_dict') 23 | 24 | 25 | def _hijack_weight_quant(module, value): 26 | if value is None: 27 | return 28 | 29 | if isinstance(value, int): 30 | module.weight_quantizer.quantizer.n_bits = value 31 | elif value == 'fp32': 32 | module.weight_quantizer = FP32Acts() 33 | else: 34 | raise NotImplementedError(f'Unknown value "{value}" in quant_dict') 35 | 36 | 37 | def hijack_act_quant(quant_dict, name, m): 38 | value = quant_dict.get(name, None) 39 | _hijack_act_quant(m, value) 40 | 41 | 42 | def hijack_weight_quant(quant_dict, name, m): 43 | value = quant_dict.get(name, None) 44 | _hijack_weight_quant(m, value) 45 | 46 | 47 | def hijack_act_quant_modules(quant_dict, name, m): 48 | value = quant_dict.get(name, None) 49 | for m_ in m.modules(): 50 | if hasattr(m_, 'activation_quantizer'): 51 | _hijack_act_quant(m_, value) 52 | 53 | 54 | def set_act_quant_axis_and_groups(module, axis, n_groups, permute=False): 55 | if hasattr(module, 'activation_quantizer'): 56 | module = module.activation_quantizer 57 | 58 | module.axis = axis 59 | module.quantizer.axis = axis 60 | module.range_estimator.axis = axis 61 | 62 | module.n_groups = n_groups 63 | module.range_estimator.n_groups = n_groups 64 | 65 | if permute: 66 | module.range_estimator.per_group_range_estimation = True 67 | 68 | return module 69 | -------------------------------------------------------------------------------- /utils/qat_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
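Before the QAT utilities below, a sketch of the `quant_dict` convention consumed by the hijack helpers above: values map module names to an integer bit-width, `'fp32'`, `'per_embd'`, or an `'ng<K>'`/`'ngp<K>'` group specification. The module names, the dictionaries, and the helper `apply_quant_overrides` are hypothetical examples; `model` is assumed to be an already-quantized model whose layers expose `activation_quantizer`/`weight_quantizer`.

```python
from utils.per_embd_quant_utils import hijack_act_quant, hijack_weight_quant

# Hypothetical per-layer overrides, following the value conventions parsed above.
ACT_QUANT_DICT = {
    'bert.encoder.layer.0.output.dense': 'per_embd',  # per-embedding activation ranges
    'bert.encoder.layer.11.output.dense': 'ng16',     # 16 range groups along the embedding axis
    'classifier': 'fp32',                             # keep classifier activations in FP32
}
WEIGHT_QUANT_DICT = {
    'classifier': 8,                                  # 8-bit weights for the classifier head
}


def apply_quant_overrides(model):
    # Walk the (already quantized) model and apply any per-module overrides;
    # modules not listed in the dictionaries are left untouched.
    for name, module in model.named_modules():
        hijack_act_quant(ACT_QUANT_DICT, name, module)
        hijack_weight_quant(WEIGHT_QUANT_DICT, name, module)
```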
3 | 4 | import logging 5 | 6 | from utils.utils import pass_data_for_range_estimation 7 | 8 | 9 | # setup logger 10 | logger = logging.getLogger('QAT') 11 | logger.setLevel('INFO') 12 | 13 | 14 | def prepare_model_for_quantization(config, model, loader): 15 | 16 | # estimate ranges using training data 17 | pass_data_for_range_estimation( 18 | loader=loader, 19 | model=model, 20 | act_quant=config.quant.act_quant, 21 | weight_quant=config.quant.weight_quant, 22 | max_num_batches=config.act_quant.num_batches, 23 | cross_entropy_layer=config.act_quant.cross_entropy_layer, 24 | ) 25 | 26 | # put quantizers in desirable state 27 | if config.qat.learn_ranges: 28 | logger.info('Make quantizers learnable') 29 | model.learn_ranges() 30 | else: 31 | logger.info( 32 | f'Fix quantizer ranges to fixW={config.qat.fix_weight_ranges} and ' 33 | f'fixA={config.qat.fix_act_ranges}' 34 | ) 35 | 36 | # freeze quantization ranges if applicable 37 | model.estimate_ranges_train() # we use updating ranges in training as default 38 | if config.qat.fix_weight_ranges: 39 | model.fix_weight_ranges() 40 | if config.qat.fix_act_ranges: 41 | model.fix_act_ranges() 42 | 43 | # ensure we have the desired quant state 44 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant) 45 | return model 46 | -------------------------------------------------------------------------------- /utils/quant_click_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from functools import wraps 5 | 6 | import click 7 | 8 | from quantization.adaround.config import AdaRoundConfig, DEFAULT_ADAROUND_CONFIG as C 9 | from quantization.adaround.utils import ( 10 | AdaRoundActQuantMode, 11 | AdaRoundInitMode, 12 | AdaRoundMode, 13 | AdaRoundTempDecayType, 14 | ) 15 | from quantization.quantizers import QMethods 16 | from quantization.range_estimators import RangeEstimators, OptMethod 17 | from utils.utils import DotDict 18 | 19 | 20 | class StrTuple(click.ParamType): 21 | name = 'str_sequence' 22 | 23 | def convert(self, value, param, ctx): 24 | values = value.split(',') 25 | values = tuple(map(lambda s: s.strip(), values)) 26 | return values 27 | 28 | 29 | def split_dict(src: dict, include=()): 30 | """ 31 | Splits dictionary into a DotDict and a remainder. 32 | The arguments to be placed in the first DotDict are those listed in `include`. 33 | 34 | Parameters 35 | ---------- 36 | src: dict 37 | The source dictionary. 38 | include: 39 | List of keys to be returned in the first DotDict. 40 | """ 41 | result = DotDict() 42 | 43 | for arg in include: 44 | result[arg] = src[arg] 45 | remainder = {key: val for key, val in src.items() if key not in include} 46 | return result, remainder 47 | 48 | 49 | def quantization_options(func): 50 | @click.option( 51 | '--qmethod', 52 | type=click.Choice(QMethods.list()), 53 | required=True, 54 | help='Quantization scheme to use.', 55 | ) 56 | @click.option( 57 | '--qmethod-act', 58 | type=click.Choice(QMethods.list()), 59 | default=None, 60 | help='Quantization scheme for activation to use. 
If not specified `--qmethod` is used.', 61 | ) 62 | @click.option( 63 | '--weight-quant-method', 64 | default=RangeEstimators.current_minmax.name, 65 | type=click.Choice(RangeEstimators.list()), 66 | help='Method to determine weight quantization clipping thresholds.', 67 | ) 68 | @click.option( 69 | '--weight-opt-method', 70 | default=OptMethod.grid.name, 71 | type=click.Choice(OptMethod.list()), 72 | help='Optimization procedure for activation quantization clipping thresholds', 73 | ) 74 | @click.option( 75 | '--num-candidates', 76 | type=int, 77 | default=None, 78 | help='Number of grid points for grid search in MSE range method.', 79 | ) 80 | @click.option('--n-bits', default=8, type=int, help='Default number of quantization bits.') 81 | @click.option( 82 | '--n-bits-act', default=None, type=int, help='Number of quantization bits for activations.' 83 | ) 84 | @click.option('--per-channel', is_flag=True, help='If given, quantize each channel separately.') 85 | @click.option( 86 | '--percentile', 87 | type=float, 88 | default=None, 89 | help='Percentile clipping parameter (weights and activations)', 90 | ) 91 | @click.option( 92 | '--act-quant/--no-act-quant', 93 | is_flag=True, 94 | default=True, 95 | help='Run evaluation with activation quantization or use FP32 activations', 96 | ) 97 | @click.option( 98 | '--weight-quant/--no-weight-quant', 99 | is_flag=True, 100 | default=True, 101 | help='Run evaluation weight quantization or use FP32 weights', 102 | ) 103 | @click.option( 104 | '--quant-setup', 105 | default='all', 106 | type=click.Choice(['all', 'FP_logits', 'MSE_logits']), 107 | help='Method to quantize the network.', 108 | ) 109 | @wraps(func) 110 | def func_wrapper(config, *args, **kwargs): 111 | config.quant, remainder_kwargs = split_dict(kwargs, [ 112 | 'qmethod', 113 | 'qmethod_act', 114 | 'weight_quant_method', 115 | 'weight_opt_method', 116 | 'num_candidates', 117 | 'n_bits', 118 | 'n_bits_act', 119 | 'per_channel', 120 | 'percentile', 121 | 'act_quant', 122 | 'weight_quant', 123 | 'quant_setup', 124 | ]) 125 | 126 | config.quant.qmethod_act = config.quant.qmethod_act or config.quant.qmethod 127 | 128 | return func(config, *args, **remainder_kwargs) 129 | 130 | return func_wrapper 131 | 132 | 133 | def activation_quantization_options(func): 134 | @click.option( 135 | '--act-quant-method', 136 | default=RangeEstimators.running_minmax.name, 137 | type=click.Choice(RangeEstimators.list()), 138 | help='Method to determine activation quantization clipping thresholds', 139 | ) 140 | @click.option( 141 | '--act-opt-method', 142 | default=OptMethod.grid.name, 143 | type=click.Choice(OptMethod.list()), 144 | help='Optimization procedure for activation quantization clipping thresholds', 145 | ) 146 | @click.option( 147 | '--act-num-candidates', 148 | type=int, 149 | default=None, 150 | help='Number of grid points for grid search in MSE/Cross-entropy', 151 | ) 152 | @click.option( 153 | '--act-momentum', 154 | type=float, 155 | default=None, 156 | help='Exponential averaging factor for running_minmax', 157 | ) 158 | @click.option( 159 | '--cross-entropy-layer', 160 | default=None, 161 | type=str, 162 | help='Cross-entropy for activation range setting (often valuable for last layer)', 163 | ) 164 | @click.option( 165 | '--num-est-batches', 166 | type=int, 167 | default=1, 168 | help='Number of training batches to be used for activation range estimation', 169 | ) 170 | @wraps(func) 171 | def func_wrapper(config, act_quant_method, act_opt_method, act_num_candidates, act_momentum, 172 | 
cross_entropy_layer, num_est_batches, *args, **kwargs): 173 | config.act_quant = DotDict() 174 | config.act_quant.quant_method = act_quant_method 175 | config.act_quant.cross_entropy_layer = cross_entropy_layer 176 | config.act_quant.num_batches = num_est_batches 177 | 178 | config.act_quant.options = {} 179 | 180 | if act_num_candidates is not None: 181 | if act_quant_method != 'MSE': 182 | raise ValueError('Wrong option num_candidates passed') 183 | else: 184 | config.act_quant.options['num_candidates'] = act_num_candidates 185 | 186 | if act_momentum is not None: 187 | if act_quant_method != 'running_minmax': 188 | raise ValueError('Wrong option momentum passed') 189 | else: 190 | config.act_quant.options['momentum'] = act_momentum 191 | 192 | if act_opt_method != 'grid': 193 | config.act_quant.options['opt_method'] = OptMethod[act_opt_method] 194 | return func(config, *args, **kwargs) 195 | 196 | return func_wrapper 197 | 198 | 199 | def qat_options(func): 200 | @click.option( 201 | '--learn-ranges', 202 | is_flag=True, 203 | default=False, 204 | help='Learn quantization ranges, in that case fix ranges will be ignored.', 205 | ) 206 | @click.option( 207 | '--fix-act-ranges/--no-fix-act-ranges', 208 | is_flag=True, 209 | default=False, 210 | help='Fix all activation quantization ranges for stable training', 211 | ) 212 | @click.option( 213 | '--fix-weight-ranges/--no-fix-weight-ranges', 214 | is_flag=True, 215 | default=False, 216 | help='Fix all weight quantization ranges for stable training', 217 | ) 218 | @wraps(func) 219 | def func_wrapper(config, *args, **kwargs): 220 | config.qat, remainder_kwargs = split_dict( 221 | kwargs, ['learn_ranges', 'fix_act_ranges', 'fix_weight_ranges'] 222 | ) 223 | 224 | return func(config, *args, **remainder_kwargs) 225 | 226 | return func_wrapper 227 | 228 | 229 | def adaround_options(func): 230 | # Base options 231 | @click.option( 232 | '--adaround', 233 | default=None, 234 | type=StrTuple(), 235 | help="Apply AdaRound: for full model ('all'), or any number of layers, " 236 | "specified by comma-separated names.", 237 | ) 238 | @click.option( 239 | '--adaround-num-samples', 240 | default=C.num_samples, 241 | type=int, 242 | help='Number of samples to use for learning the rounding.', 243 | ) 244 | @click.option( 245 | '--adaround-init', 246 | default=C.init.name, 247 | type=click.Choice(AdaRoundInitMode.list_names(), case_sensitive=False), 248 | help='Method to initialize the quantization grid for weights.', 249 | ) 250 | 251 | # Method and continuous relaxation options 252 | @click.option( 253 | '--adaround-mode', 254 | default=C.round_mode.name, 255 | type=click.Choice(AdaRoundMode.list_names(), case_sensitive=False), 256 | help='Method to learn the rounding.', 257 | ) 258 | @click.option( 259 | '--adaround-asym/--no-adaround-asym', 260 | is_flag=True, 261 | default=C.asym, 262 | help='Whether to use asymmetric reconstruction for AdaRound.', 263 | ) 264 | @click.option( 265 | '--adaround-include-act-func/--adaround-no-act-func', 266 | is_flag=True, 267 | default=C.include_act_func, 268 | help='Include activation function into AdaRound.', 269 | ) 270 | @click.option( 271 | '--adaround-lr', 272 | default=C.lr, 273 | type=float, 274 | help='Learning rate for continuous relaxation in AdaRound.', 275 | ) 276 | @click.option( 277 | '--adaround-iters', 278 | default=C.iters, 279 | type=int, 280 | help='Number of iterations to train each layer.', 281 | ) 282 | @click.option( 283 | '--adaround-weight', 284 | default=C.weight, 285 | type=float, 286 | 
help='Weight of rounding cost vs the reconstruction loss.', 287 | ) 288 | @click.option( 289 | '--adaround-annealing', 290 | default=C.annealing, 291 | nargs=2, 292 | type=float, 293 | help='Annealing of regularization function temperature (tuple: start, end).', 294 | ) 295 | @click.option( 296 | '--adaround-decay-type', 297 | default=C.decay_type.name, 298 | type=click.Choice(AdaRoundTempDecayType.list_names(), case_sensitive=False), 299 | help='Type of temperature annealing schedule.', 300 | ) 301 | @click.option( 302 | '--adaround-decay-shape', 303 | default=C.decay_shape, 304 | type=float, 305 | help="Positive " 306 | "scalar value that controls the shape of decay schedules 'sigmoid', 'power', " 307 | "'exp', 'log'. Sensible values to try: sigmoid{10}, power{4,6,8}, exp{4,6,8}, " 308 | "log{1,2,3}.", 309 | ) 310 | @click.option( 311 | '--adaround-decay-start', 312 | default=C.decay_start, 313 | type=float, 314 | help='Start of annealing (relative to --ltr-iters).', 315 | ) 316 | @click.option( 317 | '--adaround-warmup', 318 | default=C.warmup, 319 | type=float, 320 | help='In the warmup period no regularization is applied (relative to --ltr-iters).', 321 | ) 322 | 323 | # Activation quantization 324 | @click.option( 325 | '--adaround-act-quant', 326 | default=C.act_quant_mode.name, 327 | type=click.Choice(AdaRoundActQuantMode.list_names(), case_sensitive=False), 328 | help='Method to deal with activation quantization during AdaRound.', 329 | ) 330 | @wraps(func) 331 | def func_wrapper(config, *args, **kwargs): 332 | config.adaround = AdaRoundConfig(**C) 333 | 334 | config.adaround.layers = kwargs.pop('adaround') 335 | config.adaround.num_samples = kwargs.pop('adaround_num_samples') 336 | config.adaround.init = AdaRoundInitMode[kwargs.pop('adaround_init')] 337 | 338 | config.adaround.round_mode = AdaRoundMode[kwargs.pop('adaround_mode')] 339 | config.adaround.asym = kwargs.pop('adaround_asym') 340 | config.adaround.include_act_func = kwargs.pop('adaround_include_act_func') 341 | config.adaround.lr = kwargs.pop('adaround_lr') 342 | config.adaround.iters = kwargs.pop('adaround_iters') 343 | config.adaround.weight = kwargs.pop('adaround_weight') 344 | config.adaround.annealing = kwargs.pop('adaround_annealing') 345 | config.adaround.decay_type = AdaRoundTempDecayType[kwargs.pop('adaround_decay_type')] 346 | config.adaround.decay_shape = kwargs.pop('adaround_decay_shape') 347 | config.adaround.decay_start = kwargs.pop('adaround_decay_start') 348 | config.adaround.warmup = kwargs.pop('adaround_warmup') 349 | 350 | config.adaround.act_quant_mode = AdaRoundActQuantMode[kwargs.pop('adaround_act_quant')] 351 | return func(config, *args, **kwargs) 352 | 353 | return func_wrapper 354 | 355 | 356 | def make_qparams(config): 357 | weight_range_options = {} 358 | if config.quant.weight_quant_method in ['MSE', 'cross_entropy']: 359 | weight_range_options = dict(opt_method=OptMethod[config.quant.weight_opt_method]) 360 | if config.quant.num_candidates is not None: 361 | weight_range_options['num_candidates'] = config.quant.num_candidates 362 | 363 | act_range_options = config.act_quant.options 364 | if config.quant.percentile is not None: 365 | act_range_options['percentile'] = config.quant.percentile 366 | 367 | params = { 368 | 'method': QMethods[config.quant.qmethod], 369 | 'act_method': QMethods[config.quant.qmethod_act], 370 | 'n_bits': config.quant.n_bits, 371 | 'n_bits_act': config.quant.n_bits_act, 372 | 'per_channel_weights': config.quant.per_channel, 373 | 'percentile': 
config.quant.percentile, 374 | 'quant_setup': config.quant.quant_setup, 375 | 'weight_range_method': RangeEstimators[config.quant.weight_quant_method], 376 | 'weight_range_options': weight_range_options, 377 | 'act_range_method': RangeEstimators[config.act_quant.quant_method], 378 | 'act_range_options': config.act_quant.options, 379 | } 380 | return params 381 | -------------------------------------------------------------------------------- /utils/tb_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | 5 | def _tb_advance_global_step(module): 6 | if hasattr(module, 'global_step'): 7 | module.global_step += 1 8 | return module 9 | 10 | 11 | def _tb_advance_token_counters(module, tensor, verbose=False): 12 | token_count = getattr(module, 'tb_token_count', None) 13 | if token_count is not None: 14 | T = tensor.size(1) 15 | if token_count.last != T: 16 | if token_count.last != 0: 17 | token_count.total += token_count.last 18 | token_count.sample_idx += 1 19 | token_count.last = T 20 | 21 | if verbose: 22 | print(f'>>> T={T}\tlast_T={token_count.last}\tcumsum_T={token_count.total}') 23 | return module 24 | 25 | 26 | def _tb_hist(module, tensor, name, verbose=False): 27 | hist_kw = dict(bins='auto') 28 | 29 | tb_writer = getattr(module, 'tb_writer', None) 30 | if tb_writer is not None: 31 | if module.layer_idx == module.num_layers - 1: 32 | tensor = tensor[:, 0] 33 | 34 | # per-tensor 35 | layer_s = str(1 + module.layer_idx).zfill(2) 36 | full_name = f'{layer_s}/layer/{name}' 37 | global_step = module.global_step 38 | if verbose: 39 | stats = f'min={tensor.min():.1f}, max={tensor.max():.1f}' 40 | info = ( 41 | f'TB logging {full_name}\t{tuple(tensor.size())}\t({stats})\t' 42 | f'[global_step={global_step}] ...' 43 | ) 44 | print(info) 45 | tb_writer.add_histogram(full_name, tensor, global_step=global_step, **hist_kw) 46 | 47 | # per-token 48 | sample_idx_s = str(module.tb_token_count.sample_idx + 1).zfill(2) 49 | T = tensor.size(1) 50 | full_name = f'{layer_s}/token/{sample_idx_s}/{name}' 51 | for i in range(T): 52 | tb_writer.add_histogram(full_name, tensor[0, i], global_step=i, **hist_kw) 53 | -------------------------------------------------------------------------------- /utils/transformer_click_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from functools import wraps 5 | from pathlib import Path 6 | 7 | import click 8 | import torch 9 | from transformers.trainer_utils import EvaluationStrategy 10 | 11 | from utils.hf_models import HF_Models 12 | from utils.glue_tasks import GLUE_Task 13 | from utils.quant_click_options import split_dict 14 | from utils.utils import seed_all 15 | 16 | 17 | def transformer_base_options(func): 18 | @click.option( 19 | '--cuda/--no-cuda', is_flag=True, default=torch.cuda.is_available(), help='Use GPU' 20 | ) 21 | @click.option('--seed', type=int, default=1000, help='Random number generator seed to set.') 22 | @click.option( 23 | '--num-workers', 24 | type=int, 25 | default=0, 26 | help='Number of PyTorch data loader subprocesses. 
0 means that the data will be ' 27 | 'loaded in the main process.', 28 | ) 29 | @click.option( 30 | '--output-dir', 31 | default=None, 32 | type=click.Path(file_okay=False, writable=True, resolve_path=True), 33 | help='The output directory where the model predictions and checkpoints will be written. ' 34 | 'The model and the tokenizer will be saved in the `out` sub-folder.', 35 | ) 36 | @click.option( 37 | '--overwrite-output', 38 | is_flag=True, 39 | default=False, 40 | help='Overwrite the content of the output directory and log file. Use this to ' 41 | 'continue training if output directory contains model checkpoint.', 42 | ) 43 | @wraps(func) 44 | def func_wrapper(config, *args, **kwargs): 45 | attrs = ['cuda', 'seed', 'num_workers', 'output_dir', 'overwrite_output'] 46 | config.base, other_kw = split_dict(kwargs, attrs) 47 | seed = config.base.seed 48 | if seed is not None: 49 | seed_all(seed) # must be set before initializing the model 50 | return func(config, *args, **other_kw) 51 | 52 | return func_wrapper 53 | 54 | 55 | def transformer_data_options(func): 56 | @click.option( 57 | '--max-seq-length', 58 | type=int, 59 | default=128, 60 | help='The maximum total input sequence length after tokenization. Sequences ' 61 | 'longer than this will be truncated, sequences shorter will be padded.', 62 | ) 63 | @click.option( 64 | '--pad-to-max-length/--no-pad-to-max-length', 65 | is_flag=True, 66 | default=True, 67 | help='Whether to pad all samples to `max_seq_length`. If False, will pad the ' 68 | 'samples dynamically when batching to the maximum length in the batch.', 69 | ) 70 | @click.option('--line-by-line', is_flag=True, default=False) 71 | @click.option( 72 | '--overwrite-cache', 73 | is_flag=True, 74 | default=False, 75 | help='Overwrite the cached preprocessed datasets or not.', 76 | ) 77 | @click.option('--num-train-samples', type=int, default=None) 78 | @click.option( 79 | '--num-val-samples', 80 | type=int, 81 | default=None, 82 | help='Number of samples to use for validation. 
If not specified, ' 83 | 'validate on the entire set(s).', 84 | ) 85 | @wraps(func) 86 | def func_wrapper(config, *args, **kwargs): 87 | attrs = [ 88 | 'max_seq_length', 89 | 'pad_to_max_length', 90 | 'line_by_line', 91 | 'overwrite_cache', 92 | 'num_train_samples', 93 | 'num_val_samples', 94 | ] 95 | 96 | config.data, other_kw = split_dict(kwargs, attrs) 97 | return func(config, *args, **other_kw) 98 | 99 | return func_wrapper 100 | 101 | 102 | def glue_options(func): 103 | @click.option( 104 | '--task', 105 | type=click.Choice(GLUE_Task.list_names(), case_sensitive=False), 106 | default=(GLUE_Task.mrpc.name,), 107 | multiple=True, 108 | help='The name of the task to train on.', 109 | ) 110 | @click.option( 111 | '--data-dir', 112 | default=str(Path.home() / '.glue_data'), 113 | type=click.Path(file_okay=False, writable=True, resolve_path=True), 114 | help='Directory where both raw and preprocessed GLUE datasets are stored.', 115 | ) 116 | @wraps(func) 117 | def func_wrapper(config, *args, **kwargs): 118 | attrs = ['task', 'data_dir'] 119 | 120 | config.glue, other_kw = split_dict(kwargs, attrs) 121 | return func(config, *args, **other_kw) 122 | 123 | return func_wrapper 124 | 125 | 126 | def transformer_model_options(func): 127 | # GLUE 128 | @click.option( 129 | '--model-name', 130 | type=click.Choice(HF_Models.list_names(), case_sensitive=False), 131 | default=HF_Models.bert_base_uncased.name, 132 | help='Model identifier from huggingface.co/models.', 133 | ) 134 | @click.option( 135 | '--model-path', 136 | default=None, 137 | type=click.Path(exists=True, file_okay=False, resolve_path=True), 138 | help='For training (both FP32 and quantized), it is a path to a pretrained model, together ' 139 | 'with a tokenizer (can be used to resume training). For validation, it is a path that ' 140 | 'should contain fine-tuned checkpoints for all the requested tasks (each in a separate ' 141 | 'sub-folder named as a corresponding task).', 142 | ) 143 | @click.option( 144 | '--quant-model-path', 145 | default=None, 146 | type=click.Path(exists=True, file_okay=False, resolve_path=True), 147 | help='State dict of quantized model.', 148 | ) 149 | @click.option( 150 | '--use-fast-tokenizer', 151 | is_flag=True, 152 | default=True, 153 | help='Whether to use one of the fast tokenizer (backed by the HuggingFace ' 154 | 'tokenizers library) or not.', 155 | ) 156 | @click.option( 157 | '--cache-dir', 158 | default=str(Path.home() / '.hf_cache'), 159 | type=click.Path(file_okay=False, writable=True, resolve_path=True), 160 | help='Where to store downloaded pretrained HuggingFace models (together with ' 161 | 'respective config and a tokenizer).', 162 | ) 163 | @click.option( 164 | '--attn-dropout', default=None, type=float, help='Dropout rate to set for attention probs.' 165 | ) 166 | @click.option( 167 | '--hidden-dropout', default=None, type=float, help='Dropout rate to set for hidden states.' 
168 | ) 169 | @wraps(func) 170 | def func_wrapper(config, *args, **kwargs): 171 | attrs = [ 172 | 'model_name', 173 | 'model_path', 174 | 'quant_model_path', 175 | 'use_fast_tokenizer', 176 | 'cache_dir', 177 | 'attn_dropout', 178 | 'hidden_dropout', 179 | ] 180 | 181 | config.model, other_kw = split_dict(kwargs, attrs) 182 | return func(config, *args, **other_kw) 183 | 184 | return func_wrapper 185 | 186 | 187 | def transformer_training_options(func): 188 | # standard settings 189 | @click.option( 190 | '--do-eval/--no-eval', 191 | is_flag=True, 192 | default=True, 193 | help='Whether to run eval on the dev set after training.', 194 | ) 195 | @click.option('--batch-size', type=int, default=8, help='Batch size for training.') 196 | @click.option( 197 | '--eval-batch-size', 198 | type=int, 199 | default=None, 200 | help='Batch size for evaluation. Defaults to the batch size for training.', 201 | ) 202 | @click.option( 203 | '--learning-rate', type=float, default=5e-5, help='The initial learning rate for Adam.' 204 | ) 205 | @click.option( 206 | '--lr-scheduler-type', 207 | default='cosine', 208 | type=click.Choice(['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 209 | 'constant_with_warmup']), 210 | help='The scheduler type to use.', 211 | ) 212 | @click.option('--weight-decay', type=float, default=0.0, help='Weight decay for AdamW.') 213 | @click.option( 214 | '--max-grad-norm', 215 | type=float, 216 | default=None, 217 | help='Max gradient norm. If set to 0, no clipping will be applied.', 218 | ) 219 | @click.option( 220 | '--num-epochs', type=int, default=3, help='Total number of training epochs to perform.' 221 | ) 222 | @click.option( 223 | '--max-steps', 224 | type=int, 225 | default=0, 226 | help='If > 0, set total number of training steps to perform. 
Overrides `num_epochs`.', 227 | ) 228 | @click.option('--warmup-steps', type=int, default=0, help='Linear warmup over `warmup_steps`.') 229 | 230 | # hw optimizations 231 | @click.option( 232 | '--gradient-accumulation-steps', 233 | type=int, 234 | default=1, 235 | help="Number of updates steps to accumulate before performing a backward/update pass.", 236 | ) 237 | @click.option('--amp', is_flag=True, default=False, help='Whether to use Apex AMP.') 238 | @click.option( 239 | '--amp-opt-level', 240 | type=click.Choice(('O0', 'O1', 'O2', 'O3')), 241 | default='O2', 242 | help='Apex AMP optimization level.', 243 | ) 244 | 245 | # custom regularization 246 | @click.option('--ffn-weight-decay', type=float, default=0, help='Weight decay for FFN weights.') 247 | @click.option('--gamma', type=float, default=0, help='Activation regularization strength.') 248 | @click.option('--margin', type=float, default=0, help='Activation regularization margin.') 249 | 250 | # custom functionality 251 | @click.option( 252 | '--save-attn', 253 | is_flag=True, 254 | default=False, 255 | help='Save attention probabilities from the training set and skip training.', 256 | ) 257 | @wraps(func) 258 | def func_wrapper(config, *args, **kwargs): 259 | attrs = [ 260 | 'do_eval', 261 | 'batch_size', 262 | 'eval_batch_size', 263 | 'learning_rate', 264 | 'lr_scheduler_type', 265 | 'weight_decay', 266 | 'max_grad_norm', 267 | 'num_epochs', 268 | 'max_steps', 269 | 'warmup_steps', 270 | 'gradient_accumulation_steps', 271 | 'amp', 272 | 'amp_opt_level', 273 | 'ffn_weight_decay', 274 | 'gamma', 275 | 'margin', 276 | 'save_attn', 277 | ] 278 | 279 | config.training, other_kw = split_dict(kwargs, attrs) 280 | if config.training.eval_batch_size is None: 281 | config.training.eval_batch_size = config.training.batch_size 282 | 283 | return func(config, *args, **other_kw) 284 | 285 | return func_wrapper 286 | 287 | 288 | def transformer_progress_options(func): 289 | @click.option('--tqdm/--no-tqdm', default=True) 290 | @click.option( 291 | '--eval-during-training', 292 | is_flag=True, 293 | default=False, 294 | help='Run evaluation during training at each logging step.', 295 | ) 296 | @click.option( 297 | '--eval-strategy', 298 | default=EvaluationStrategy.NO.value, 299 | type=click.Choice([m.value for m in EvaluationStrategy], case_sensitive=False), 300 | help='Evaluation frequency level.', 301 | ) 302 | @click.option( 303 | '--eval-steps', type=int, default=None, help='Run an evaluation every `eval_steps` steps.' 304 | ) 305 | @click.option( 306 | '--tb-logging-dir', 307 | default=None, 308 | type=click.Path(exists=False, writable=True, resolve_path=True), 309 | help='Tensorboard log dir.', 310 | ) 311 | @click.option( 312 | '--tb', 313 | is_flag=True, 314 | default=False, 315 | help='Whether to create and log (additional) stuff to the TensorBoard writer', 316 | ) 317 | @click.option( 318 | '--tb-graph', 319 | is_flag=True, 320 | default=False, 321 | help='Whether to log computational graph into the TensorBoard writer', 322 | ) 323 | @click.option( 324 | '--logging-first-step', 325 | is_flag=True, 326 | default=False, 327 | help='Log and eval the first global_step.', 328 | ) 329 | @click.option( 330 | '--logging-steps', type=int, default=500, help='Log every `logging_steps` updates steps.' 
331 | ) 332 | @click.option( 333 | '--save-steps', 334 | type=int, 335 | default=0, 336 | help='Save checkpoint every `save_steps` updates steps.', 337 | ) 338 | @click.option( 339 | '--save-total-limit', 340 | type=int, 341 | default=None, 342 | help='Limit the total amount of checkpoints. Deletes the older checkpoints in ' 343 | 'the `output_dir`. Default is unlimited checkpoints.', 344 | ) 345 | @click.option( 346 | '--save-model', 347 | is_flag=True, 348 | default=False, 349 | help='Whether to save model and tokenizer after the training.', 350 | ) 351 | @click.option( 352 | '--run-name', 353 | type=str, 354 | default=None, 355 | help='An optional descriptor for the run. Notably used for wandb logging.', 356 | ) 357 | @click.option( 358 | '--load-best-model-at-end', 359 | is_flag=True, 360 | default=False, 361 | help='Whether or not to load the best model found during training at the end of ' 362 | 'training.', 363 | ) 364 | @click.option( 365 | '--metric-for-best-model', 366 | type=str, 367 | default=None, 368 | help='The metric to use to compare two different models.', 369 | ) 370 | @click.option( 371 | '--greater-is-better', 372 | type=bool, 373 | default=None, 374 | help='Whether the `metric_for_best_model` should be maximized or not.', 375 | ) 376 | @wraps(func) 377 | def func_wrapper(config, *args, **kwargs): 378 | attrs = [ 379 | 'tqdm', 380 | 'eval_during_training', 381 | 'eval_strategy', 382 | 'eval_steps', 383 | 'tb_logging_dir', 384 | 'tb', 385 | 'tb_graph', 386 | 'logging_first_step', 387 | 'logging_steps', 388 | 'save_steps', 389 | 'save_total_limit', 390 | 'save_model', 391 | 'run_name', 392 | 'load_best_model_at_end', 393 | 'metric_for_best_model', 394 | 'greater_is_better', 395 | ] 396 | 397 | config.progress, other_kw = split_dict(kwargs, attrs) 398 | return func(config, *args, **other_kw) 399 | 400 | return func_wrapper 401 | 402 | 403 | def transformer_quant_options(func): 404 | @click.option( 405 | '--est-ranges-pad/--est-ranges-no-pad', 406 | is_flag=True, 407 | default=None, 408 | help='Specify whether to pad to max sequence length during range estimation.' 409 | 'If None, inherit the value of --pad-to-max-length.', 410 | ) 411 | @click.option( 412 | '--est-ranges-batch-size', 413 | type=int, 414 | default=None, 415 | help='Batch size for range estimation. 
Defaults to the batch size for training.', 416 | ) 417 | @click.option('--quant-dict', type=str, default=None) 418 | @click.option('--double', is_flag=True) 419 | @click.option('--dynamic', is_flag=True) 420 | @click.option('--per-token', is_flag=True) 421 | @click.option('--per-embd', is_flag=True) 422 | @click.option('--per-groups', type=int, default=None) 423 | @click.option('--per-groups-permute', is_flag=True) 424 | @click.option('--per-groups-permute-shared-h', is_flag=True) 425 | @wraps(func) 426 | def func_wrapper(config, est_ranges_pad, est_ranges_batch_size, quant_dict, double, dynamic, 427 | per_token, per_embd, per_groups, per_groups_permute, 428 | per_groups_permute_shared_h, *a, **kw): 429 | config.quant.est_ranges_pad = est_ranges_pad 430 | config.quant.est_ranges_batch_size = ( 431 | est_ranges_batch_size if est_ranges_batch_size is not None 432 | else config.training.batch_size 433 | ) 434 | 435 | if quant_dict is not None: 436 | quant_dict = eval(quant_dict) 437 | config.quant.quant_dict = quant_dict 438 | config.double = double 439 | 440 | config.quant.dynamic = dynamic 441 | config.quant.per_token = per_token 442 | if config.quant.per_token: 443 | config.quant.dynamic = True 444 | 445 | config.quant.per_embd = per_embd 446 | config.quant.per_groups = per_groups 447 | config.quant.per_groups_permute = per_groups_permute 448 | config.quant.per_groups_permute_shared_h = per_groups_permute_shared_h 449 | 450 | return func(config, *a, **kw) 451 | 452 | return func_wrapper 453 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import os 5 | import sys 6 | import time 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | from quantization.range_estimators import RangeEstimators 14 | 15 | 16 | def seed_all(seed=1029): 17 | random.seed(seed) 18 | os.environ['PYTHONHASHSEED'] = str(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU 23 | torch.backends.cudnn.benchmark = False 24 | torch.backends.cudnn.deterministic = True 25 | 26 | 27 | def count_params(module): 28 | return len(nn.utils.parameters_to_vector(module.parameters())) 29 | 30 | 31 | def count_embedding_params(model): 32 | return sum(count_params(m) for m in model.modules() if isinstance(m, nn.Embedding)) 33 | 34 | 35 | def get_layer_by_name(model, layer_name): 36 | for name, module in model.named_modules(): 37 | if name == layer_name: 38 | return module 39 | return None 40 | 41 | 42 | class StopForwardException(Exception): 43 | """Used to throw and catch an exception to stop traversing the graph.""" 44 | pass 45 | 46 | 47 | def pass_data_for_range_estimation( 48 | loader, model, act_quant, weight_quant, max_num_batches=20, cross_entropy_layer=None, inp_idx=0 49 | ): 50 | model.set_quant_state(weight_quant, act_quant) 51 | model.eval() 52 | 53 | if cross_entropy_layer is not None: 54 | layer_xent = get_layer_by_name(model, cross_entropy_layer) 55 | if layer_xent: 56 | print(f'Set cross entropy estimator for layer "{cross_entropy_layer}"') 57 | act_quant_mgr = layer_xent.activation_quantizer 58 | act_quant_mgr.range_estimator = RangeEstimators.cross_entropy.cls( 59 | per_channel=act_quant_mgr.per_channel, 60 | quantizer=act_quant_mgr.quantizer, 61 | 
**act_quant_mgr.init_params, 62 | ) 63 | else: 64 | raise ValueError('Cross-entropy layer not found') 65 | 66 | device = next(model.parameters()).device 67 | for i, data in enumerate(loader): 68 | try: 69 | if isinstance(data, (tuple, list)): 70 | x = data[inp_idx].to(device=device) 71 | model(x) 72 | else: 73 | x = {k: v.to(device=device) for k, v in data.items()} 74 | model(**x) 75 | except StopForwardException: 76 | pass 77 | 78 | if i >= max_num_batches - 1 or not act_quant: 79 | break 80 | 81 | 82 | class DotDict(dict): 83 | """ 84 | A dictionary that allows attribute-style access. 85 | 86 | Examples 87 | -------- 88 | >>> config = DotDict(a=None) 89 | >>> config.a = 42 90 | >>> config.b = 'egg' 91 | >>> config # can be used as dict 92 | {'a': 42, 'b': 'egg'} 93 | """ 94 | def __setattr__(self, key, value): 95 | self.__setitem__(key, value) 96 | 97 | def __delattr__(self, key): 98 | self.__delitem__(key) 99 | 100 | def __getattr__(self, key): 101 | if key in self: 102 | return self.__getitem__(key) 103 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})") 104 | 105 | 106 | class Stopwatch: 107 | """ 108 | A simple cross-platform context-manager stopwatch. 109 | 110 | Examples 111 | -------- 112 | >>> import time 113 | >>> with Stopwatch(verbose=True) as st: 114 | ... time.sleep(0.101) #doctest: +ELLIPSIS 115 | Elapsed time: 0.10... sec 116 | """ 117 | def __init__(self, name=None, verbose=False): 118 | self._name = name 119 | self._verbose = verbose 120 | 121 | self._start_time_point = 0.0 122 | self._total_duration = 0.0 123 | self._is_running = False 124 | 125 | if sys.platform == 'win32': 126 | # on Windows, the best timer is time.clock() 127 | self._timer_fn = time.clock 128 | else: 129 | # on most other platforms, the best timer is time.time() 130 | self._timer_fn = time.time 131 | 132 | def __enter__(self, verbose=False): 133 | return self.start() 134 | 135 | def __exit__(self, exc_type, exc_val, exc_tb): 136 | self.stop() 137 | if self._verbose: 138 | self.print() 139 | 140 | def start(self): 141 | if not self._is_running: 142 | self._start_time_point = self._timer_fn() 143 | self._is_running = True 144 | return self 145 | 146 | def stop(self): 147 | if self._is_running: 148 | self._total_duration += self._timer_fn() - self._start_time_point 149 | self._is_running = False 150 | return self 151 | 152 | def reset(self): 153 | self._start_time_point = 0.0 154 | self._total_duration = 0.0 155 | self._is_running = False 156 | return self 157 | 158 | def _update_state(self): 159 | now = self._timer_fn() 160 | self._total_duration += now - self._start_time_point 161 | self._start_time_point = now 162 | 163 | def _format(self): 164 | prefix = f'[{self._name}]' if self._name is not None else 'Elapsed time' 165 | info = f'{prefix}: {self._total_duration:.3f} sec' 166 | return info 167 | 168 | def format(self): 169 | if self._is_running: 170 | self._update_state() 171 | return self._format() 172 | 173 | def print(self): 174 | print(self.format()) 175 | 176 | def get_total_duration(self): 177 | if self._is_running: 178 | self._update_state() 179 | return self._total_duration 180 | --------------------------------------------------------------------------------
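
Note: the option-group decorators defined above (`quantization_options`, `activation_quantization_options`, `qat_options`, `transformer_training_options`, and friends) all follow the same pattern — each pulls its own keys out of the wrapped command's `**kwargs` via `split_dict`, stores them on the shared `config` object as a `DotDict`, and forwards the remainder to the next decorator or to the command itself. The snippet below is a minimal illustrative sketch of that mechanism only; it is not part of the repository, and the sample option values are made up:

```python
from utils.quant_click_options import split_dict
from utils.utils import DotDict

config = DotDict()
kwargs = {'n_bits': 8, 'n_bits_act': 4, 'learning_rate': 5e-5}  # illustrative values

# e.g. quantization_options keeps its own keys under config.quant ...
config.quant, remainder = split_dict(kwargs, ['n_bits', 'n_bits_act'])

assert config.quant.n_bits == 8               # attribute-style access via DotDict
assert remainder == {'learning_rate': 5e-5}   # ... and forwards the rest downstream
```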