├── LICENSE
├── README.md
├── main.py
├── models
│   ├── __init__.py
│   ├── quantized_bert.py
│   ├── quantized_mobilebert.py
│   └── quantized_roberta.py
├── quantization
│   ├── __init__.py
│   ├── adaround
│   │   ├── __init__.py
│   │   ├── adaround.py
│   │   ├── config.py
│   │   ├── quantizer.py
│   │   └── utils.py
│   ├── autoquant_utils.py
│   ├── base_quantized_classes.py
│   ├── base_quantized_model.py
│   ├── hijacker.py
│   ├── quantization_manager.py
│   ├── quantizers.py
│   ├── range_estimators.py
│   └── utils.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── adaround_utils.py
    ├── glue_tasks.py
    ├── hf_models.py
    ├── per_embd_quant_utils.py
    ├── qat_utils.py
    ├── quant_click_options.py
    ├── tb_utils.py
    ├── transformer_click_options.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021 Qualcomm Technologies, Inc.
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification, are permitted
6 | (subject to the limitations in the disclaimer below) provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this list of conditions
9 | and the following disclaimer:
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions
12 | and the following disclaimer in the documentation and/or other materials provided with the
13 |   distribution.
14 |
15 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to
16 | endorse or promote products derived from this software without specific prior written permission.
17 |
18 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS
19 | SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
20 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
26 | THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer Quantization
2 | This repository contains the implementation and experiments for the paper presented in
3 |
4 | **Yelysei Bondarenko<sup>1</sup>, Markus Nagel<sup>1</sup>, Tijmen Blankevoort<sup>1</sup>,
5 | "Understanding and Overcoming the Challenges of Efficient Transformer Quantization", EMNLP 2021.** [[ACL Anthology]](https://aclanthology.org/2021.emnlp-main.627/) [[ArXiv]](https://arxiv.org/abs/2109.12948)
6 |
7 | <sup>1</sup> Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.)
8 |
9 |
10 | ## Reference
11 | If you find our work useful, please cite
12 | ```
13 | @inproceedings{bondarenko-etal-2021-understanding,
14 | title = "Understanding and Overcoming the Challenges of Efficient Transformer Quantization",
15 | author = "Bondarenko, Yelysei and
16 | Nagel, Markus and
17 | Blankevoort, Tijmen",
18 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
19 | month = nov,
20 | year = "2021",
21 | address = "Online and Punta Cana, Dominican Republic",
22 | publisher = "Association for Computational Linguistics",
23 | url = "https://aclanthology.org/2021.emnlp-main.627",
24 | pages = "7947--7969",
25 | abstract = "Transformer-based architectures have become the de-facto standard models for a wide range of Natural Language Processing tasks. However, their memory footprint and high latency are prohibitive for efficient deployment and inference on resource-limited devices. In this work, we explore quantization for transformers. We show that transformers have unique quantization challenges {--} namely, high dynamic activation ranges that are difficult to represent with a low bit fixed-point format. We establish that these activations contain structured outliers in the residual connections that encourage specific attention patterns, such as attending to the special separator token. To combat these challenges, we present three solutions based on post-training quantization and quantization-aware training, each with a different set of compromises for accuracy, model size, and ease of use. In particular, we introduce a novel quantization scheme {--} per-embedding-group quantization. We demonstrate the effectiveness of our methods on the GLUE benchmark using BERT, establishing state-of-the-art results for post-training quantization. Finally, we show that transformer weights and embeddings can be quantized to ultra-low bit-widths, leading to significant memory savings with a minimum accuracy loss. Our source code is available at \url{https://github.com/qualcomm-ai-research/transformer-quantization}.",
26 | }
27 | ```
28 |
29 | ## How to install
30 | First, ensure locale variables are set as follows:
31 | ```bash
32 | export LC_ALL=C.UTF-8
33 | export LANG=C.UTF-8
34 | ```
35 |
36 | Second, make sure you have Python ≥3.6 (tested with Python 3.6.8) and
37 | upgrade `pip` to the latest version (tested with 21.2.4):
38 | ```bash
39 | pip install --upgrade --no-deps pip
40 | ```
41 |
42 | Next, install PyTorch 1.4.0 with the appropriate CUDA version (tested with CUDA 10.0, CuDNN 7.6.3):
43 | ```bash
44 | pip install torch==1.4.0 torchvision==0.5.0 -f https://download.pytorch.org/whl/torch_stable.html
45 | ```
46 |
47 | Finally, install the remaining dependencies using pip:
48 | ```bash
49 | pip install -r requirements.txt
50 | ```
51 |
52 | To run the code, the project root directory needs to be added to your `PYTHONPATH`:
53 | ```bash
54 | export PYTHONPATH="${PYTHONPATH}:/path/to/this/dir"
55 | ```
56 |
57 | ## Running experiments
58 | The main run file to reproduce all experiments is `main.py`.
59 | It contains four commands to train and validate the FP32 and quantized models:
60 | ```bash
61 | Usage: main.py [OPTIONS] COMMAND [ARGS]...
62 |
63 | Options:
64 | --help Show this message and exit.
65 |
66 | Commands:
67 | train-baseline
68 | train-quantized
69 | validate-baseline
70 | validate-quantized
71 | ```
72 | You can see the full list of options for each command using `python main.py [COMMAND] --help`.
73 |
74 | ### A. FP32 fine-tuning
75 | To start with, you need to get the fine-tuned model(s) for the GLUE task of interest.
76 | Example run command for fine-tuning:
77 | ```bash
78 | python main.py train-baseline --cuda --save-model --model-name bert_base_uncased --task rte \
79 | --learning-rate 3e-05 --batch-size 8 --eval-batch-size 8 --num-epochs 3 --max-seq-length 128 \
80 | --seed 1000 --output-dir /path/to/output/dir/
81 | ```
82 | You can also do this directly using the HuggingFace library [[examples]](https://github.com/huggingface/transformers/tree/master/examples/pytorch/text-classification).
83 | In all experiments we used seeds 1000-1004 and reported the median score.
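
For example, a minimal sketch of such a seed sweep (the per-seed output directory naming below is only an illustration, not something the repository prescribes):
```bash
for seed in 1000 1001 1002 1003 1004; do
    python main.py train-baseline --cuda --save-model --model-name bert_base_uncased --task rte \
        --learning-rate 3e-05 --batch-size 8 --eval-batch-size 8 --num-epochs 3 \
        --max-seq-length 128 --seed ${seed} --output-dir /path/to/output/dir/seed_${seed}
done
```
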
84 | The sample output directory looks as follows:
85 | ```bash
86 | /path/to/output/dir
87 | ├── config.out
88 | ├── eval_results_rte.txt
89 | ├── final_score.txt
90 | ├── out
91 | │ ├── config.json # Huggingface model config
92 | │ ├── pytorch_model.bin # PyTorch model checkpoint
93 | │ ├── special_tokens_map.json
94 | │ ├── tokenizer_config.json # Huggingface tokenizer config
95 | │ ├── training_args.bin
96 | │ └── vocab.txt # Vocabulary
97 | └── tb_logs # TensorBoard logs
98 | ├── 1632747625.1250594
99 | │ └── events.out.tfevents.*
100 | └── events.out.tfevents.*
101 | ```
102 |
103 | For validation (both full-precision and quantized), it is assumed that these output directories with the fine-tuned
104 | checkpoints are arranged as follows (you can also use a subset of GLUE tasks):
105 | ```bash
106 | /path/to/saved_models/
107 | ├── rte/rte_model_dir
108 | │ ├── out
109 | │ │ ├── config.json # Huggingface model config
110 | │ │ ├── pytorch_model.bin # PyTorch model checkpoint
111 | │ │ ├── tokenizer_config.json # Huggingface tokenizer config
112 | │ │ ├── vocab.txt # Vocabulary
113 | │ │ ├── (...)
114 | ├── cola/cola_model_dir
115 | │ ├── out
116 | │ │ ├── (...)
117 | ├── mnli/mnli_model_dir
118 | │ ├── out
119 | │ │ ├── (...)
120 | ├── mrpc/mrpc_model_dir
121 | │ ├── out
122 | │ │ ├── (...)
123 | ├── qnli/qnli_model_dir
124 | │ ├── out
125 | │ │ ├── (...)
126 | ├── qqp/qqp_model_dir
127 | │ ├── out
128 | │ │ ├── (...)
129 | ├── sst2/sst2_model_dir
130 | │ ├── out
131 | │ │ ├── (...)
132 | └── stsb/stsb_model_dir
133 | ├── out
134 | │ ├── (...)
135 | ```
136 | Note that you have to create this file structure manually.
137 |
138 | The model can then be validated as follows:
139 | ```bash
140 | python main.py validate-baseline --eval-batch-size 32 --seed 1000 --model-name bert_base_uncased \
141 | --model-path /path/to/saved_models/ --task rte
142 | ```
143 | You can also validate multiple or all checkpoints by specifying
144 | `--task --task [...]` or `--task all`, respectively.
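
For example, validating the RTE, MRPC and STS-B checkpoints in one run (paths are placeholders):
```bash
python main.py validate-baseline --eval-batch-size 32 --seed 1000 --model-name bert_base_uncased \
    --model-path /path/to/saved_models/ --task rte --task mrpc --task stsb
```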
145 |
146 | ### B. Post-training quantization (PTQ)
147 |
148 | #### 1) Standard (naïve) W8A8 per-tensor PTQ / base run command for all PTQ experiments
149 | ```bash
150 | python main.py validate-quantized --act-quant --weight-quant --no-pad-to-max-length \
151 | --est-ranges-no-pad --eval-batch-size 16 --seed 1000 --model-path /path/to/saved_models/ \
152 | --task rte --n-bits 8 --n-bits-act 8 --qmethod symmetric_uniform \
153 | --qmethod-act asymmetric_uniform --weight-quant-method MSE --weight-opt-method golden_section \
154 | --act-quant-method current_minmax --est-ranges-batch-size 1 --num-est-batches 1 \
155 | --quant-setup all
156 | ```
157 | Note that the range estimation settings are slightly different for each task.
158 |
159 | #### 2) Mixed precision W8A{8,16} PTQ
160 | Specify `--quant-dict "{'y': 16, 'h': 16, 'x': 16}"`:
161 | * `'x': 16` will set FFN's input to 16-bit
162 | * `'h': 16` will set FFN's output to 16-bit
163 | * `'y': 16` will set FFN's residual sum to 16-bit
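
For example, a W8A{8,16} evaluation on RTE can be launched by adding this flag to the base PTQ command from 1) (paths are placeholders):
```bash
python main.py validate-quantized --act-quant --weight-quant --no-pad-to-max-length \
    --est-ranges-no-pad --eval-batch-size 16 --seed 1000 --model-path /path/to/saved_models/ \
    --task rte --n-bits 8 --n-bits-act 8 --qmethod symmetric_uniform \
    --qmethod-act asymmetric_uniform --weight-quant-method MSE --weight-opt-method golden_section \
    --act-quant-method current_minmax --est-ranges-batch-size 1 --num-est-batches 1 \
    --quant-setup all --quant-dict "{'y': 16, 'h': 16, 'x': 16}"
```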
164 |
165 | For the STS-B regression task, you will need to specify `--quant-dict "{'y': 16, 'h': 16, 'x': 16, 'P': 16, 'C': 16}"`
166 | and `--quant-setup MSE_logits`, which will also quantize the pooler and the final classifier to 16-bit and use an MSE estimator for the output.
167 |
168 | #### 3) Per-embedding and per-embedding-group (PEG) activation quantization
169 | * `--per-embd` -- Per-embedding quantization for all activations
170 | * `--per-groups [N_GROUPS]` -- PEG quantization for all activations, no permutation
171 | * `--per-groups [N_GROUPS] --per-groups-permute` -- PEG quantization for all activations, apply range-based permutation (separate for each quantizer)
172 | * `--quant-dict "{'y': 'ng6', 'h': 'ng6', 'x': 'ng6'}"` -- PEG quantization using 6 groups for FFN's input, output and residual sum, no permutation
173 | * `--quant-dict "{'y': 'ngp6', 'h': 'ngp6', 'x': 'ngp6'}" --per-groups-permute-shared-h` -- PEG quantization using 6 groups for FFN's input, output and residual sum, apply range-based permutation (shared between tensors in the same layer)
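
Any of these options can be appended to the base PTQ command from 1); for instance, a sketch of the 6-group PEG setting with range-based permutation (here `[BASE_PTQ_OPTIONS]` is a placeholder for the options of that base command):
```bash
python main.py validate-quantized [BASE_PTQ_OPTIONS] --per-groups 6 --per-groups-permute
```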
174 |
175 | #### 4) W4A32 PTQ with AdaRound
176 | ```bash
177 | python main.py validate-quantized --weight-quant --no-act-quant --no-pad-to-max-length \
178 | --est-ranges-no-pad --eval-batch-size 16 --seed 1000 --model-path /path/to/saved_models/ \
179 | --task rte --qmethod symmetric_uniform --qmethod-act asymmetric_uniform --n-bits 4 \
180 | --weight-quant-method MSE --weight-opt-method grid --num-candidates 100 --quant-setup all \
181 | --adaround all --adaround-num-samples 1024 --adaround-init range_estimator \
182 | --adaround-mode learned_hard_sigmoid --adaround-asym --adaround-iters 10000 \
183 | --adaround-act-quant no_act_quant
184 | ```
185 |
186 | ### C. Quantization-aware training (QAT)
187 | Base run command for QAT experiments (using W4A8 for example):
188 | ```bash
189 | python main.py train-quantized --cuda --do-eval --logging-first-step --weight-quant --act-quant \
190 | --pad-to-max-length --learn-ranges --tqdm --batch-size 8 --seed 1000 \
191 | --model-name bert_base_uncased --learning-rate 5e-05 --num-epochs 6 --warmup-steps 186 \
192 | --weight-decay 0.0 --attn-dropout 0.0 --hidden-dropout 0.0 --max-seq-length 128 --n-bits 4 \
193 | --n-bits-act 8 --qmethod symmetric_uniform --qmethod-act asymmetric_uniform \
194 | --weight-quant-method MSE --weight-opt-method golden_section --act-quant-method current_minmax \
195 | --est-ranges-batch-size 16 --num-est-batches 1 --quant-setup all \
196 | --model-path /path/to/saved_models/rte/out --task rte --output-dir /path/to/qat_output/dir
197 | ```
198 | Note that the settings are slightly different for each task (see Appendix).
199 |
200 | To run mixed-precision QAT with 2-bit embeddings and 4-bit weights, add `--quant-dict "{'Et': 2}"`.
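
For example, a QAT run on RTE with 2-bit embeddings, 4-bit weights and 8-bit activations, building on the base QAT command above (paths are placeholders):
```bash
python main.py train-quantized --cuda --do-eval --logging-first-step --weight-quant --act-quant \
    --pad-to-max-length --learn-ranges --tqdm --batch-size 8 --seed 1000 \
    --model-name bert_base_uncased --learning-rate 5e-05 --num-epochs 6 --warmup-steps 186 \
    --weight-decay 0.0 --attn-dropout 0.0 --hidden-dropout 0.0 --max-seq-length 128 --n-bits 4 \
    --n-bits-act 8 --qmethod symmetric_uniform --qmethod-act asymmetric_uniform \
    --weight-quant-method MSE --weight-opt-method golden_section --act-quant-method current_minmax \
    --est-ranges-batch-size 16 --num-est-batches 1 --quant-setup all --quant-dict "{'Et': 2}" \
    --model-path /path/to/saved_models/rte/out --task rte --output-dir /path/to/qat_output/dir
```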
201 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from .quantized_bert import QuantizedBertForSequenceClassification
5 | from .quantized_mobilebert import QuantizedMobileBertForSequenceClassification
6 | from .quantized_roberta import QuantizedRobertaForSequenceClassification
7 |
--------------------------------------------------------------------------------
/models/quantized_bert.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn import CrossEntropyLoss, MSELoss
10 | from transformers.models.bert.modeling_bert import (
11 | BertLayer,
12 | BertSelfAttention,
13 | BertSelfOutput,
14 | BaseModelOutputWithPoolingAndCrossAttentions,
15 | )
16 | from transformers.modeling_outputs import SequenceClassifierOutput
17 | from transformers.modeling_utils import ModuleUtilsMixin, apply_chunking_to_forward
18 |
19 | from quantization.autoquant_utils import quantize_model
20 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
21 | from quantization.base_quantized_model import QuantizedModel
22 | from quantization.range_estimators import RangeEstimators, OptMethod
23 | from utils import _tb_advance_global_step, _tb_advance_token_counters, _tb_hist
24 |
25 |
26 | class QuantizedBertEmbeddings(QuantizedModel):
27 | def __init__(self, org_model, **quant_params):
28 | self.quant_dict = quant_params['quant_dict']
29 |
30 | super().__init__()
31 |
32 | quant_params_ = quant_params.copy()
33 | if 'Et' in self.quant_dict:
34 | quant_params_['weight_range_method'] = RangeEstimators.MSE
35 | quant_params_['weight_range_options'] = dict(opt_method=OptMethod.golden_section)
36 | self.word_embeddings = quantize_model(org_model.word_embeddings, **quant_params_)
37 |
38 | self.position_embeddings = quantize_model(org_model.position_embeddings, **quant_params)
39 | self.token_type_embeddings = quantize_model(org_model.token_type_embeddings, **quant_params)
40 |
41 | self.dropout = org_model.dropout
42 |
43 | position_ids = org_model.position_ids
44 | if position_ids is not None:
45 | self.register_buffer('position_ids', position_ids)
46 | else:
47 | self.position_ids = position_ids
48 |
49 | self.position_embedding_type = getattr(org_model, 'position_embedding_type', 'absolute')
50 |
51 | # Activation quantizers
52 | self.sum_input_token_type_embd_act_quantizer = QuantizedActivation(**quant_params)
53 | self.sum_pos_embd_act_quantizer = QuantizedActivation(**quant_params)
54 |
55 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be
56 | # able to load any TensorFlow checkpoint file
57 | self.LayerNorm = quantize_model(org_model.LayerNorm, **quant_params)
58 |
59 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
60 | if input_ids is not None:
61 | input_shape = input_ids.size()
62 | else:
63 | input_shape = inputs_embeds.size()[:-1]
64 |
65 | seq_length = input_shape[1]
66 |
67 | if position_ids is None:
68 | position_ids = self.position_ids[:, :seq_length]
69 |
70 | if token_type_ids is None:
71 | token_type_ids = torch.zeros(
72 | input_shape, dtype=torch.long, device=self.position_ids.device
73 | )
74 | if inputs_embeds is None:
75 | inputs_embeds = self.word_embeddings(input_ids)
76 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
77 |
78 | embeddings = inputs_embeds + token_type_embeddings
79 | embeddings = self.sum_input_token_type_embd_act_quantizer(embeddings)
80 |
81 | if self.position_embedding_type == 'absolute':
82 | position_embeddings = self.position_embeddings(position_ids)
83 | embeddings += position_embeddings
84 | embeddings = self.sum_pos_embd_act_quantizer(embeddings)
85 |
86 | embeddings = self.LayerNorm(embeddings)
87 | embeddings = self.dropout(embeddings)
88 | return embeddings
89 |
90 |
91 | class QuantizedBertSelfAttention(QuantizedModel):
92 | def __init__(self, org_model, **quant_params):
93 | self.quant_dict = quant_params['quant_dict']
94 |
95 | super().__init__()
96 |
97 | # copy attributes
98 | self.num_attention_heads = org_model.num_attention_heads
99 | self.attention_head_size = org_model.attention_head_size
100 | self.all_head_size = org_model.all_head_size
101 |
102 | self.position_embedding_type = getattr(org_model, 'position_embedding_type', None)
103 | if self.position_embedding_type in ('relative_key', 'relative_key_query'):
104 | raise NotImplementedError('current branch of computation is not yet supported')
105 |
106 | self.max_position_embeddings = org_model.max_position_embeddings
107 | self.distance_embedding = org_model.distance_embedding
108 |
109 | # quantized modules
110 | self.query = quantize_model(org_model.query, **quant_params)
111 | self.key = quantize_model(org_model.key, **quant_params)
112 | self.value = quantize_model(org_model.value, **quant_params)
113 | self.dropout = org_model.dropout
114 |
115 | # Activation quantizers
116 | self.attn_scores_act_quantizer = QuantizedActivation(**quant_params)
117 | self.attn_probs_act_quantizer = QuantizedActivation(**quant_params)
118 | self.context_act_quantizer = QuantizedActivation(**quant_params)
119 |
120 | def transpose_for_scores(self, x):
121 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
122 | x = x.view(*new_x_shape)
123 | return x.permute(0, 2, 1, 3)
124 |
125 | def forward(
126 | self,
127 | hidden_states,
128 | attention_mask=None,
129 | head_mask=None,
130 | encoder_hidden_states=None,
131 | encoder_attention_mask=None,
132 | past_key_value=None,
133 | output_attentions=False,
134 | ):
135 | mixed_query_layer = self.query(hidden_states)
136 |
137 | # If this is instantiated as a cross-attention module, the keys
138 | # and values come from an encoder; the attention mask needs to be
139 | # such that the encoder's padding tokens are not attended to.
140 | if encoder_hidden_states is not None:
141 | mixed_key_layer = self.key(encoder_hidden_states)
142 | mixed_value_layer = self.value(encoder_hidden_states)
143 | attention_mask = encoder_attention_mask
144 | else:
145 | mixed_key_layer = self.key(hidden_states)
146 | mixed_value_layer = self.value(hidden_states)
147 |
148 | query_layer = self.transpose_for_scores(mixed_query_layer)
149 | key_layer = self.transpose_for_scores(mixed_key_layer)
150 | value_layer = self.transpose_for_scores(mixed_value_layer)
151 |
152 | # Take the dot product between "query" and "key" to get the raw attention scores.
153 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
154 | attention_scores = self.attn_scores_act_quantizer(attention_scores)
155 |
156 | if self.position_embedding_type in ('relative_key', 'relative_key_query'):
157 | raise NotImplementedError('current branch of computation is not yet supported')
158 |
159 | seq_length = hidden_states.size()[1]
160 | position_ids_l = torch.arange(
161 | seq_length, dtype=torch.long, device=hidden_states.device
162 | ).view(-1, 1)
163 | position_ids_r = torch.arange(
164 | seq_length, dtype=torch.long, device=hidden_states.device
165 | ).view(1, -1)
166 | distance = position_ids_l - position_ids_r
167 | positional_embedding = self.distance_embedding(
168 | distance + self.max_position_embeddings - 1
169 | )
170 | # fp16 compatibility:
171 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype)
172 |
173 | if self.position_embedding_type == "relative_key":
174 | relative_position_scores = torch.einsum(
175 | "bhld,lrd->bhlr", query_layer, positional_embedding
176 | )
177 | attention_scores = attention_scores + relative_position_scores
178 | elif self.position_embedding_type == "relative_key_query":
179 | relative_position_scores_query = torch.einsum(
180 | "bhld,lrd->bhlr", query_layer, positional_embedding
181 | )
182 | relative_position_scores_key = torch.einsum(
183 | "bhrd,lrd->bhlr", key_layer, positional_embedding
184 | )
185 | attention_scores = (
186 | attention_scores + relative_position_scores_query + relative_position_scores_key
187 | )
188 |
189 | # NOTE: factor 1/d^0.5 can be absorbed into the previous act. quant. delta
190 | attention_scores /= math.sqrt(self.attention_head_size)
191 |
192 | if attention_mask is not None:
193 |             # Apply the attention mask (precomputed for all layers in BertModel forward() fn)
194 | attention_scores += attention_mask
195 |
196 | # Normalize the attention scores to probabilities.
197 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
198 | attention_probs = self.attn_probs_act_quantizer(attention_probs)
199 |
200 | # This is actually dropping out entire tokens to attend to, which might
201 | # seem a bit unusual, but is taken from the original Transformer paper.
202 | attention_probs = self.dropout(attention_probs)
203 |
204 | # Mask heads if we want to
205 | if head_mask is not None:
206 | attention_probs = attention_probs * head_mask
207 |
208 | context_layer = torch.matmul(attention_probs, value_layer)
209 |
210 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
211 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
212 | context_layer = context_layer.view(*new_context_layer_shape)
213 | context_layer = self.context_act_quantizer(context_layer)
214 |
215 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
216 |
217 | _tb_advance_global_step(self)
218 | return outputs
219 |
220 |
221 | class QuantizedBertSelfOutput(QuantizedModel):
222 | def __init__(self, org_model, **quant_params):
223 | self.quant_dict = quant_params['quant_dict']
224 |
225 | # Exact same structure as for BertOutput.
226 | # Kept in order to be able to disable activation quantizer.
227 | super().__init__()
228 |
229 | self.dense = quantize_model(org_model.dense, **quant_params)
230 | self.dropout = org_model.dropout
231 |
232 | # Activation quantizer
233 | self.res_act_quantizer = QuantizedActivation(**quant_params)
234 |
235 | # LN
236 | self.LayerNorm = quantize_model(org_model.LayerNorm, **quant_params)
237 |
238 | def forward(self, hidden_states, input_tensor):
239 | hidden_states = self.dense(hidden_states)
240 | hidden_states = self.dropout(hidden_states)
241 |
242 | hidden_states = hidden_states + input_tensor
243 | hidden_states = self.res_act_quantizer(hidden_states)
244 |
245 | hidden_states = self.LayerNorm(hidden_states)
246 |
247 | _tb_advance_global_step(self)
248 | return hidden_states
249 |
250 |
251 | class QuantizedBertOutput(QuantizedModel):
252 | def __init__(self, org_model, **quant_params):
253 | self.quant_dict = quant_params['quant_dict']
254 |
255 | super().__init__()
256 |
257 | self.dense = quantize_model(org_model.dense, **quant_params)
258 | self.dropout = org_model.dropout
259 | self.res_act_quantizer = QuantizedActivation(**quant_params)
260 |
261 | # LN
262 | self.LayerNorm = quantize_model(org_model.LayerNorm, **quant_params)
263 |
264 | def forward(self, hidden_states, input_tensor):
265 | hidden_states = self.dense(hidden_states)
266 | hidden_states = self.dropout(hidden_states)
267 |
268 | _tb_advance_token_counters(self, input_tensor)
269 | _tb_hist(self, input_tensor, 'res_output_x')
270 | _tb_hist(self, hidden_states, 'res_output_h')
271 |
272 | hidden_states = hidden_states + input_tensor
273 |
274 | _tb_hist(self, hidden_states, 'res_output_x_h')
275 |
276 | hidden_states = self.res_act_quantizer(hidden_states)
277 | hidden_states = self.LayerNorm(hidden_states)
278 |
279 | _tb_advance_global_step(self)
280 | return hidden_states
281 |
282 |
283 | def quantize_intermediate(org_module, **quant_params):
284 | m_dense = org_module.dense
285 | m_act = org_module.intermediate_act_fn
286 | if not isinstance(m_act, nn.Module):
287 | if m_act == F.gelu:
288 | m_act = nn.GELU()
289 | else:
290 | raise NotImplementedError()
291 | return quantize_model(nn.Sequential(m_dense, m_act), **quant_params)
292 |
293 |
294 | class QuantizedBertLayer(QuantizedModel):
295 | def __init__(self, org_model, **quant_params):
296 | self.quant_dict = quant_params['quant_dict']
297 |
298 | super().__init__()
299 |
300 | # copy attributes
301 | self.chunk_size_feed_forward = org_model.chunk_size_feed_forward
302 | self.seq_len_dim = org_model.seq_len_dim
303 | self.is_decoder = org_model.is_decoder
304 | self.add_cross_attention = org_model.add_cross_attention
305 |
306 | # quantized components
307 | attention_specials = {
308 | BertSelfAttention: QuantizedBertSelfAttention,
309 | BertSelfOutput: QuantizedBertSelfOutput,
310 | }
311 | self.attention = quantize_model(
312 | org_model.attention, specials=attention_specials, **quant_params
313 | )
314 | if self.add_cross_attention:
315 | self.crossattention = quantize_model(
316 | org_model.crossattention, specials=attention_specials, **quant_params
317 | )
318 | self.intermediate = quantize_intermediate(org_model.intermediate, **quant_params)
319 | self.output = QuantizedBertOutput(org_model.output, **quant_params)
320 |
321 | def forward(
322 | self,
323 | hidden_states,
324 | attention_mask=None,
325 | head_mask=None,
326 | encoder_hidden_states=None,
327 | encoder_attention_mask=None,
328 | past_key_value=None,
329 | output_attentions=False,
330 | ):
331 | attn_args = (hidden_states, attention_mask, head_mask)
332 | attn_kw = dict(output_attentions=output_attentions)
333 |
334 | self_attention_outputs = self.attention(*attn_args, **attn_kw)
335 |
336 | attention_output = self_attention_outputs[0]
337 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
338 |
339 | if self.is_decoder and encoder_hidden_states is not None:
340 | raise NotImplementedError('current branch of computation is not yet supported')
341 |
342 | assert hasattr(self, "crossattention"), (
343 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
344 | f"cross-attention layers by setting `config.add_cross_attention=True`"
345 | )
346 | cross_attention_outputs = self.crossattention(
347 | attention_output,
348 | attention_mask,
349 | head_mask,
350 | encoder_hidden_states,
351 | encoder_attention_mask,
352 | output_attentions,
353 | )
354 | attention_output = cross_attention_outputs[0]
355 | # add cross attentions if we output attention weights:
356 | outputs = outputs + cross_attention_outputs[1:]
357 |
358 | assert self.chunk_size_feed_forward == 0 # below call is a no-op in that case
359 | layer_output = apply_chunking_to_forward(
360 | self.feed_forward_chunk,
361 | self.chunk_size_feed_forward,
362 | self.seq_len_dim,
363 | attention_output,
364 | )
365 | outputs = (layer_output,) + outputs
366 | return outputs
367 |
368 | def feed_forward_chunk(self, attention_output):
369 | intermediate_output = self.intermediate(attention_output)
370 | layer_output = self.output(intermediate_output, attention_output)
371 | return layer_output
372 |
373 |
374 | class QuantizedBertPooler(QuantizedModel):
375 | def __init__(self, org_model, **quant_params):
376 | super().__init__()
377 |
378 | self.dense_act = quantize_model(
379 | nn.Sequential(org_model.dense, org_model.activation), **quant_params
380 | )
381 |
382 | def forward(self, hidden_states):
383 | # We "pool" the model by simply taking the hidden state corresponding
384 | # to the first token.
385 | first_token_tensor = hidden_states[:, 0]
386 | pooled_output = self.dense_act(first_token_tensor)
387 |
388 | _tb_advance_global_step(self)
389 | return pooled_output
390 |
391 |
392 | class QuantizedBertModel(QuantizedModel, ModuleUtilsMixin):
393 | def __init__(self, org_model, **quant_params):
394 | super().__init__()
395 |
396 | self.config = org_model.config
397 |
398 | self.embeddings = QuantizedBertEmbeddings(org_model.embeddings, **quant_params)
399 | self.encoder = quantize_model(
400 | org_model.encoder, specials={BertLayer: QuantizedBertLayer}, **quant_params
401 | )
402 | self.pooler = (
403 | QuantizedBertPooler(org_model.pooler, **quant_params)
404 | if org_model.pooler is not None
405 | else None
406 | )
407 |
408 | def get_input_embeddings(self):
409 | return self.embeddings.word_embeddings
410 |
411 | def set_input_embeddings(self, value):
412 | self.embeddings.word_embeddings = value
413 |
414 | def _prune_heads(self, heads_to_prune):
415 | """
416 | Prunes heads of the model.
417 |
418 | Parameters
419 | ----------
420 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
421 | See base class PreTrainedModel.
422 | """
423 | for layer, heads in heads_to_prune.items():
424 | self.encoder.layer[layer].attention.prune_heads(heads)
425 |
426 | def forward(
427 | self,
428 | input_ids=None,
429 | attention_mask=None,
430 | token_type_ids=None,
431 | position_ids=None,
432 | head_mask=None,
433 | inputs_embeds=None,
434 | encoder_hidden_states=None,
435 | encoder_attention_mask=None,
436 | output_attentions=None,
437 | output_hidden_states=None,
438 | return_dict=None,
439 | ):
440 | output_attentions = (
441 | output_attentions if output_attentions is not None else self.config.output_attentions
442 | )
443 | output_hidden_states = (
444 | output_hidden_states
445 | if output_hidden_states is not None
446 | else self.config.output_hidden_states
447 | )
448 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
449 |
450 | if input_ids is not None and inputs_embeds is not None:
451 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
452 | elif input_ids is not None:
453 | input_shape = input_ids.size()
454 | elif inputs_embeds is not None:
455 | input_shape = inputs_embeds.size()[:-1]
456 | else:
457 | raise ValueError("You have to specify either input_ids or inputs_embeds")
458 |
459 | device = input_ids.device if input_ids is not None else inputs_embeds.device
460 |
461 | if attention_mask is None:
462 | attention_mask = torch.ones(input_shape, device=device)
463 | if token_type_ids is None:
464 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
465 |
466 | # We can provide a self-attention mask of dimensions
467 | # [batch_size, from_seq_length, to_seq_length]
468 | # ourselves in which case we just need to make it broadcastable to all heads.
469 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
470 | attention_mask, input_shape, device
471 | )
472 |
473 | # If a 2D or 3D attention mask is provided for the cross-attention
474 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
475 | if self.config.is_decoder and encoder_hidden_states is not None:
476 | raise NotImplementedError('current branch of computation is not yet supported')
477 |
478 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
479 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
480 | if encoder_attention_mask is None:
481 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
482 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
483 | else:
484 | encoder_extended_attention_mask = None
485 |
486 | # Prepare head mask if needed
487 | # 1.0 in head_mask indicate we keep the head
488 | # attention_probs has shape bsz x n_heads x N x N
489 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
490 | # and head_mask is converted to shape
491 | # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
492 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
493 |
494 | embedding_output = self.embeddings(
495 | input_ids=input_ids,
496 | position_ids=position_ids,
497 | token_type_ids=token_type_ids,
498 | inputs_embeds=inputs_embeds,
499 | )
500 | encoder_outputs = self.encoder(
501 | embedding_output,
502 | attention_mask=extended_attention_mask,
503 | head_mask=head_mask,
504 | encoder_hidden_states=encoder_hidden_states,
505 | encoder_attention_mask=encoder_extended_attention_mask,
506 | output_attentions=output_attentions,
507 | output_hidden_states=output_hidden_states,
508 | return_dict=return_dict,
509 | )
510 | sequence_output = encoder_outputs[0]
511 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
512 |
513 | if not return_dict:
514 | return (sequence_output, pooled_output) + encoder_outputs[1:]
515 |
516 | return BaseModelOutputWithPoolingAndCrossAttentions(
517 | last_hidden_state=sequence_output,
518 | pooler_output=pooled_output,
519 | hidden_states=encoder_outputs.hidden_states,
520 | attentions=encoder_outputs.attentions,
521 | cross_attentions=encoder_outputs.cross_attentions,
522 | )
523 |
524 |
525 | class QuantizedBertForSequenceClassification(QuantizedModel):
526 | def __init__(self, org_model, quant_setup=None, **quant_params):
527 | super().__init__()
528 |
529 | self.num_labels = org_model.num_labels
530 | self.config = org_model.config
531 |
532 | if hasattr(org_model, 'bert'):
533 | self.bert = QuantizedBertModel(org_model=org_model.bert, **quant_params)
534 | if hasattr(org_model, 'dropout'):
535 | self.dropout = org_model.dropout
536 |
537 | quant_params_ = quant_params.copy()
538 |
539 | if quant_setup == 'MSE_logits':
540 | quant_params_['act_range_method'] = RangeEstimators.MSE
541 | quant_params_['act_range_options'] = dict(opt_method=OptMethod.golden_section)
542 | self.classifier = quantize_model(org_model.classifier, **quant_params_)
543 |
544 | elif quant_setup == 'FP_logits':
545 | print('Do not quantize output of FC layer')
546 |
547 | self.classifier = quantize_model(org_model.classifier, **quant_params_)
548 | # no activation quantization of logits:
549 | self.classifier.activation_quantizer = FP32Acts()
550 |
551 | elif quant_setup == 'all':
552 | self.classifier = quantize_model(org_model.classifier, **quant_params_)
553 |
554 | else:
555 | raise ValueError("Quantization setup '{}' not supported.".format(quant_setup))
556 |
557 | def forward(
558 | self,
559 | input_ids=None,
560 | attention_mask=None,
561 | token_type_ids=None,
562 | position_ids=None,
563 | head_mask=None,
564 | inputs_embeds=None,
565 | labels=None,
566 | output_attentions=None,
567 | output_hidden_states=None,
568 | return_dict=None,
569 | ):
570 | if isinstance(input_ids, tuple):
571 | if len(input_ids) == 2:
572 | input_ids, attention_mask = input_ids
573 | elif len(input_ids) == 3:
574 | input_ids, attention_mask, token_type_ids = input_ids
575 | elif len(input_ids) == 4:
576 | input_ids, attention_mask, token_type_ids, labels = input_ids
577 | else:
578 | raise ValueError('cannot interpret input tuple, use dict instead')
579 |
580 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
581 |
582 | outputs = self.bert(
583 | input_ids,
584 | attention_mask=attention_mask,
585 | token_type_ids=token_type_ids,
586 | position_ids=position_ids,
587 | head_mask=head_mask,
588 | inputs_embeds=inputs_embeds,
589 | output_attentions=output_attentions,
590 | output_hidden_states=output_hidden_states,
591 | return_dict=return_dict,
592 | )
593 |
594 | pooled_output = outputs[1]
595 | pooled_output = self.dropout(pooled_output)
596 |
597 | logits = self.classifier(pooled_output)
598 |
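        # Regression head (e.g. STS-B): clamp predictions to the valid score range [0, 5].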
599 | if self.num_labels == 1:
600 | logits = torch.clamp(logits, 0.0, 5.0)
601 |
602 | loss = None
603 | if labels is not None:
604 | if self.num_labels == 1:
605 | # We are doing regression
606 | loss_fct = MSELoss()
607 | loss = loss_fct(logits.view(-1), labels.view(-1))
608 | else:
609 | loss_fct = CrossEntropyLoss()
610 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
611 |
612 | if not return_dict:
613 | output = (logits,) + outputs[2:]
614 | return ((loss,) + output) if loss is not None else output
615 |
616 | _tb_advance_global_step(self)
617 | return SequenceClassifierOutput(
618 | loss=loss,
619 | logits=logits,
620 | hidden_states=outputs.hidden_states,
621 | attentions=outputs.attentions,
622 | )
623 |
--------------------------------------------------------------------------------
/models/quantized_mobilebert.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn import CrossEntropyLoss, MSELoss
10 |
11 | from transformers.models.mobilebert.modeling_mobilebert import (
12 | BaseModelOutputWithPooling,
13 | BottleneckLayer,
14 | FFNLayer,
15 | MobileBertLayer,
16 | MobileBertSelfAttention,
17 | MobileBertSelfOutput,
18 | NoNorm,
19 | )
20 | from transformers.modeling_outputs import SequenceClassifierOutput
21 | from transformers.modeling_utils import ModuleUtilsMixin
22 |
23 | from quantization.autoquant_utils import quantize_model, quantize_module_list
24 | from quantization.base_quantized_classes import QuantizedActivation, FP32Acts
25 | from quantization.base_quantized_model import QuantizedModel
26 | from quantization.hijacker import QuantizationHijacker
27 | from quantization.range_estimators import RangeEstimators, OptMethod
28 | from utils import DotDict, _tb_advance_global_step, _tb_advance_token_counters, _tb_hist
29 |
30 |
31 | DEFAULT_QUANT_DICT = {
32 | # Embeddings
33 | 'sum_input_pos_embd': True,
34 | 'sum_token_type_embd': True,
35 |
36 | # Attention
37 | 'attn_scores': True,
38 | 'attn_probs': True,
39 | 'attn_probs_n_bits_act': None,
40 | 'attn_probs_act_range_method': None,
41 | 'attn_probs_act_range_options': None,
42 | 'attn_output': True,
43 |
44 | # Residual connections
45 | 'res_self_output': True,
46 | 'res_output': True,
47 | 'res_output_bottleneck': True,
48 | 'res_ffn_output': True,
49 | }
50 |
51 |
52 | def _make_quant_dict(partial_dict):
53 | quant_dict = DEFAULT_QUANT_DICT.copy()
54 | quant_dict.update(partial_dict)
55 | return DotDict(quant_dict)
56 |
57 |
58 | class QuantNoNorm(QuantizationHijacker):
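    # Quantized counterpart of MobileBERT's NoNorm: an element-wise affine transform
    # y = x * weight + bias, with weight/bias and output quantization provided by the
    # QuantizationHijacker base class.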
59 | def __init__(self, org_model, *args, activation=None, **kwargs):
60 | super().__init__(*args, activation=activation, **kwargs)
61 | self.weight = org_model.weight
62 | self.bias = org_model.bias
63 |
64 | def forward(self, x, offsets=None):
65 | weight, bias = self.weight, self.bias
66 | if self._quant_w:
67 | weight = self.weight_quantizer(weight)
68 | bias = self.weight_quantizer(bias)
69 |
70 | res = x * weight + bias
71 | res = self.quantize_activations(res)
72 | return res
73 |
74 |
75 | class QuantizedMobileBertEmbeddings(QuantizedModel):
76 | def __init__(self, org_model, **quant_params):
77 | super().__init__()
78 |
79 | # copy attributes
80 | self.trigram_input = org_model.trigram_input
81 | self.embedding_size = org_model.embedding_size
82 | self.hidden_size = org_model.hidden_size
83 |
84 | # quantized modules
85 | self.word_embeddings = quantize_model(org_model.word_embeddings, **quant_params)
86 | self.position_embeddings = quantize_model(org_model.position_embeddings, **quant_params)
87 | self.token_type_embeddings = quantize_model(org_model.token_type_embeddings, **quant_params)
88 |
89 | self.embedding_transformation = quantize_model(
90 | org_model.embedding_transformation, **quant_params
91 | )
92 |
93 | assert isinstance(org_model.LayerNorm, NoNorm)
94 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params)
95 |
96 | self.dropout = org_model.dropout
97 |
98 | position_ids = org_model.position_ids
99 | if position_ids is not None:
100 | self.register_buffer('position_ids', position_ids)
101 | else:
102 | self.position_ids = position_ids
103 |
104 | # activation quantizers
105 | self.quant_dict = _make_quant_dict(quant_params['quant_dict'])
106 | self.sum_input_pos_embd_act_quantizer = (
107 | QuantizedActivation(**quant_params)
108 | if self.quant_dict.sum_input_pos_embd
109 | else FP32Acts()
110 | )
111 | self.sum_token_type_embd_act_quantizer = (
112 | QuantizedActivation(**quant_params)
113 | if self.quant_dict.sum_token_type_embd
114 | else FP32Acts()
115 | )
116 |
117 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
118 | if input_ids is not None:
119 | input_shape = input_ids.size()
120 | else:
121 | input_shape = inputs_embeds.size()[:-1]
122 |
123 | seq_length = input_shape[1]
124 |
125 | if position_ids is None:
126 | position_ids = self.position_ids[:, :seq_length]
127 |
128 | if token_type_ids is None:
129 | token_type_ids = torch.zeros(
130 | input_shape, dtype=torch.long, device=self.position_ids.device
131 | )
132 | if inputs_embeds is None:
133 | inputs_embeds = self.word_embeddings(input_ids) # (B, T, 128)
134 |
135 | if self.trigram_input:
136 | # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited
137 | # Devices (https://arxiv.org/abs/2004.02984)
138 | #
139 | # The embedding table in BERT models accounts for a substantial proportion of model size. To compress
140 | # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT.
141 | # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512
142 | # dimensional output.
143 | inputs_embeds = torch.cat(
144 | [
145 | F.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
146 | inputs_embeds,
147 | F.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
148 | ],
149 | dim=2,
150 | ) # (B, T, 384)
151 |
152 | if self.trigram_input or self.embedding_size != self.hidden_size:
153 | inputs_embeds = self.embedding_transformation(inputs_embeds) # (B, T, 512)
154 |
155 |         # Add positional embeddings and token type embeddings, then layer normalize and
156 |         # perform dropout.
157 | position_embeddings = self.position_embeddings(position_ids)
158 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
159 |
160 | embeddings = self.sum_input_pos_embd_act_quantizer(inputs_embeds + position_embeddings)
161 | embeddings = self.sum_token_type_embd_act_quantizer(embeddings + token_type_embeddings)
162 | embeddings = self.LayerNorm(embeddings)
163 | embeddings = self.dropout(embeddings)
164 | return embeddings
165 |
166 |
167 | class QuantizedMobileBertSelfAttention(QuantizedModel):
168 | def __init__(self, org_model, **quant_params):
169 | super().__init__()
170 |
171 | # copy attributes
172 | self.num_attention_heads = org_model.num_attention_heads
173 | self.attention_head_size = org_model.attention_head_size
174 | self.all_head_size = org_model.all_head_size
175 |
176 | # quantized modules
177 | self.query = quantize_model(org_model.query, **quant_params)
178 | self.key = quantize_model(org_model.key, **quant_params)
179 | self.value = quantize_model(org_model.value, **quant_params)
180 | self.dropout = org_model.dropout
181 |
182 | # activation quantizers
183 | self.quant_dict = _make_quant_dict(quant_params['quant_dict'])
184 | self.attn_scores_act_quantizer = (
185 | QuantizedActivation(**quant_params) if self.quant_dict.attn_scores else FP32Acts()
186 | )
187 |
188 | quant_params_ = quant_params.copy()
189 | if self.quant_dict.attn_probs_n_bits_act is not None:
190 | quant_params_['n_bits_act'] = self.quant_dict.attn_probs_n_bits_act
191 | if self.quant_dict.attn_probs_act_range_method is not None:
192 | quant_params_['act_range_method'] = RangeEstimators[
193 | self.quant_dict.attn_probs_act_range_method
194 | ]
195 | if self.quant_dict.attn_probs_act_range_options is not None:
196 | act_range_options = self.quant_dict.attn_probs_act_range_options
197 | if 'opt_method' in act_range_options and not isinstance(act_range_options['opt_method'],
198 | OptMethod):
199 | act_range_options['opt_method'] = OptMethod[act_range_options['opt_method']]
200 | quant_params_['act_range_options'] = act_range_options
201 | self.attn_probs_act_quantizer = (
202 | QuantizedActivation(**quant_params_) if self.quant_dict.attn_probs else FP32Acts()
203 | )
204 |
205 | self.attn_output_act_quantizer = (
206 | QuantizedActivation(**quant_params) if self.quant_dict.attn_output else FP32Acts()
207 | )
208 |
209 | def transpose_for_scores(self, x):
210 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
211 | x = x.view(*new_x_shape)
212 | return x.permute(0, 2, 1, 3)
213 |
214 | def forward(
215 | self,
216 | query_tensor,
217 | key_tensor,
218 | value_tensor,
219 | attention_mask=None,
220 | head_mask=None,
221 | output_attentions=None,
222 | ):
223 | mixed_query_layer = self.query(query_tensor)
224 | mixed_key_layer = self.key(key_tensor)
225 | mixed_value_layer = self.value(value_tensor)
226 |
227 | query_layer = self.transpose_for_scores(mixed_query_layer)
228 | key_layer = self.transpose_for_scores(mixed_key_layer)
229 | value_layer = self.transpose_for_scores(mixed_value_layer)
230 |
231 | # Take the dot product between "query" and "key" to get the raw attention scores.
232 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
233 | attention_scores = self.attn_scores_act_quantizer(attention_scores)
234 |
235 | # NOTE: factor 1/d^0.5 can be absorbed into the previous act. quant. delta
236 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
237 |
238 | if attention_mask is not None:
239 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
240 | attention_scores = attention_scores + attention_mask
241 |
242 | # Normalize the attention scores to probabilities.
243 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
244 | attention_probs = self.attn_probs_act_quantizer(attention_probs)
245 |
246 | # This is actually dropping out entire tokens to attend to, which might
247 | # seem a bit unusual, but is taken from the original Transformer paper.
248 | attention_probs = self.dropout(attention_probs)
249 |
250 | # Mask heads if we want to
251 | if head_mask is not None:
252 | attention_probs = attention_probs * head_mask
253 |
254 | context_layer = torch.matmul(attention_probs, value_layer)
255 | context_layer = self.attn_output_act_quantizer(context_layer)
256 |
257 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
258 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
259 | context_layer = context_layer.view(*new_context_layer_shape)
260 |
261 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
262 | return outputs
263 |
264 |
265 | class QuantizedMobileBertSelfOutput(QuantizedModel):
266 | def __init__(self, org_model, **quant_params):
267 | super().__init__()
268 |
269 | # copy attributes
270 | self.use_bottleneck = org_model.use_bottleneck
271 |
272 | # quantized modules
273 | self.dense = quantize_model(org_model.dense, **quant_params)
274 |
275 | assert isinstance(org_model.LayerNorm, NoNorm)
276 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params)
277 |
278 | if not self.use_bottleneck:
279 | self.dropout = org_model.dropout
280 |
281 | # activation quantizers
282 | self.quant_dict = _make_quant_dict(quant_params['quant_dict'])
283 | self.res_act_quantizer = (
284 | QuantizedActivation(**quant_params) if self.quant_dict.res_self_output else FP32Acts()
285 | )
286 |
287 | def forward(self, hidden_states, residual_tensor):
288 | layer_outputs = self.dense(hidden_states)
289 | if not self.use_bottleneck:
290 | layer_outputs = self.dropout(layer_outputs)
291 |
292 | _tb_advance_token_counters(self, layer_outputs)
293 | _tb_hist(self, layer_outputs, 'res_self_output_h')
294 | _tb_hist(self, residual_tensor, 'res_self_output_x')
295 |
296 | layer_outputs = layer_outputs + residual_tensor
297 |
298 | _tb_hist(self, residual_tensor, 'res_self_output_x_h')
299 |
300 | layer_outputs = self.res_act_quantizer(layer_outputs)
301 | layer_outputs = self.LayerNorm(layer_outputs)
302 |
303 | _tb_advance_global_step(self)
304 | return layer_outputs
305 |
306 |
307 | def quantize_intermediate(org_module, **quant_params):
308 | m_dense = org_module.dense
309 | m_act = org_module.intermediate_act_fn
310 | if not isinstance(m_act, nn.Module):
311 | if m_act == F.gelu:
312 | m_act = nn.GELU()
313 | elif m_act == F.relu:
314 | m_act = nn.ReLU()
315 | else:
316 | raise NotImplementedError()
317 | return quantize_model(nn.Sequential(m_dense, m_act), **quant_params)
318 |
319 |
320 | class QuantizedOutputBottleneck(QuantizedModel):
321 | def __init__(self, org_model, **quant_params):
322 | super().__init__()
323 |
324 | self.dense = quantize_model(org_model.dense, **quant_params)
325 | assert isinstance(org_model.LayerNorm, NoNorm)
326 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params)
327 | self.dropout = org_model.dropout
328 |
329 | # activation quantizers
330 | self.quant_dict = _make_quant_dict(quant_params['quant_dict'])
331 | self.res_act_quantizer = (
332 | QuantizedActivation(**quant_params)
333 | if self.quant_dict.res_output_bottleneck
334 | else FP32Acts()
335 | )
336 |
337 | def forward(self, hidden_states, residual_tensor):
338 | layer_outputs = self.dense(hidden_states)
339 | layer_outputs = self.dropout(layer_outputs)
340 |
341 | _tb_advance_token_counters(self, layer_outputs)
342 | _tb_hist(self, layer_outputs, 'res_layer_h')
343 | _tb_hist(self, residual_tensor, 'res_layer_x')
344 |
345 | layer_outputs = layer_outputs + residual_tensor
346 |
347 | _tb_hist(self, layer_outputs, 'res_layer_x_h')
348 |
349 | layer_outputs = self.res_act_quantizer(layer_outputs)
350 | layer_outputs = self.LayerNorm(layer_outputs)
351 |
352 | _tb_advance_global_step(self)
353 | return layer_outputs
354 |
355 |
356 | class QuantizedMobileBertOutput(QuantizedModel):
357 | def __init__(self, org_model, **quant_params):
358 | super().__init__()
359 |
360 | # copy attributes
361 | self.use_bottleneck = org_model.use_bottleneck
362 |
363 | # quantized modules
364 | self.dense = quantize_model(org_model.dense, **quant_params)
365 | assert isinstance(org_model.LayerNorm, NoNorm)
366 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params)
367 |
368 | if not self.use_bottleneck:
369 | self.dropout = org_model.dropout
370 | else:
371 | self.bottleneck = QuantizedOutputBottleneck(
372 | org_model=org_model.bottleneck, **quant_params
373 | )
374 |
375 | # activation quantizers
376 | self.quant_dict = _make_quant_dict(quant_params['quant_dict'])
377 | self.res_act_quantizer = (
378 | QuantizedActivation(**quant_params) if self.quant_dict.res_output else FP32Acts()
379 | )
380 |
381 | def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2):
382 | layer_output = self.dense(intermediate_states)
383 | if not self.use_bottleneck:
384 | layer_output = self.dropout(layer_output)
385 | layer_output = layer_output + residual_tensor_1
386 | layer_output = self.res_act_quantizer(layer_output)
387 | layer_output = self.LayerNorm(layer_output)
388 | else:
389 | _tb_advance_token_counters(self, layer_output)
390 | _tb_hist(self, layer_output, 'res_interm_h')
391 | _tb_hist(self, residual_tensor_1, 'res_interm_x')
392 |
393 | layer_output = layer_output + residual_tensor_1
394 |
395 | _tb_hist(self, layer_output, 'res_interm_x_h')
396 |
397 | layer_output = self.res_act_quantizer(layer_output)
398 | layer_output = self.LayerNorm(layer_output)
399 | layer_output = self.bottleneck(layer_output, residual_tensor_2)
400 |
401 | _tb_advance_global_step(self)
402 | return layer_output
403 |
404 |
405 | class QuantizedBottleneckLayer(QuantizedModel):
406 | def __init__(self, org_model, **quant_params):
407 | super().__init__()
408 |
409 | self.dense = quantize_model(org_model.dense, **quant_params)
410 | assert isinstance(org_model.LayerNorm, NoNorm)
411 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params)
412 |
413 | def forward(self, hidden_states):
414 | layer_input = self.dense(hidden_states)
415 | layer_input = self.LayerNorm(layer_input)
416 | return layer_input
417 |
418 |
419 | class QuantizedFFNOutput(QuantizedModel):
420 | def __init__(self, org_model, **quant_params):
421 | super().__init__()
422 |
423 | self.dense = quantize_model(org_model.dense, **quant_params)
424 | assert isinstance(org_model.LayerNorm, NoNorm)
425 | self.LayerNorm = QuantNoNorm(org_model.LayerNorm, **quant_params)
426 |
427 | # activation quantizers
428 | self.quant_dict = _make_quant_dict(quant_params['quant_dict'])
429 | self.res_act_quantizer = (
430 | QuantizedActivation(**quant_params) if self.quant_dict.res_ffn_output else FP32Acts()
431 | )
432 |
433 | def forward(self, hidden_states, residual_tensor):
434 | layer_outputs = self.dense(hidden_states)
435 |
436 | _tb_advance_token_counters(self, layer_outputs)
437 | num_ffn = self.ffn_idx + 1
438 | _tb_hist(self, layer_outputs, f'res_ffn{num_ffn}_h')
439 | _tb_hist(self, residual_tensor, f'res_ffn{num_ffn}_x')
440 |
441 | layer_outputs = layer_outputs + residual_tensor
442 |
443 | _tb_hist(self, layer_outputs, f'res_ffn{num_ffn}_x_h')
444 |
445 | layer_outputs = self.res_act_quantizer(layer_outputs)
446 | layer_outputs = self.LayerNorm(layer_outputs)
447 |
448 | _tb_advance_global_step(self)
449 | return layer_outputs
450 |
451 |
452 | class QuantizedFFNLayer(QuantizedModel):
453 | def __init__(self, org_model, **quant_params):
454 | super().__init__()
455 |
456 | self.intermediate = quantize_intermediate(org_model.intermediate, **quant_params)
457 | self.output = QuantizedFFNOutput(org_model.output, **quant_params)
458 |
459 | def forward(self, hidden_states):
460 | intermediate_output = self.intermediate(hidden_states)
461 | layer_outputs = self.output(intermediate_output, hidden_states)
462 | return layer_outputs
463 |
464 |
465 | class QuantizedMobileBertLayer(QuantizedModel):
466 | def __init__(self, org_model, **quant_params):
467 | super().__init__()
468 |
469 | # copy
470 | self.use_bottleneck = org_model.use_bottleneck
471 | self.num_feedforward_networks = org_model.num_feedforward_networks
472 |
473 | # quantized modules
474 | attention_specials = {
475 | MobileBertSelfAttention: QuantizedMobileBertSelfAttention,
476 | MobileBertSelfOutput: QuantizedMobileBertSelfOutput,
477 | }
478 | self.attention = quantize_model(
479 | org_model.attention, specials=attention_specials, **quant_params
480 | )
481 | self.intermediate = quantize_intermediate(org_model.intermediate, **quant_params)
482 | self.output = QuantizedMobileBertOutput(org_model.output, **quant_params)
483 |
484 | if self.use_bottleneck:
485 | self.bottleneck = quantize_model(
486 | org_model.bottleneck,
487 | specials={BottleneckLayer: QuantizedBottleneckLayer},
488 | **quant_params,
489 | )
490 | if getattr(org_model, 'ffn', None) is not None:
491 | self.ffn = quantize_module_list(
492 | org_model.ffn, specials={FFNLayer: QuantizedFFNLayer}, **quant_params
493 | )
494 |
495 | def forward(
496 | self,
497 | hidden_states,
498 | attention_mask=None,
499 | head_mask=None,
500 | output_attentions=None,
501 | ):
502 | if self.use_bottleneck:
503 | query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
504 | else:
505 | query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
506 |
507 | self_attention_outputs = self.attention(
508 | query_tensor,
509 | key_tensor,
510 | value_tensor,
511 | layer_input,
512 | attention_mask,
513 | head_mask,
514 | output_attentions=output_attentions,
515 | )
516 | attention_output = self_attention_outputs[0]
517 | s = (attention_output,)
518 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
519 |
520 | if self.num_feedforward_networks != 1:
521 | for i, ffn_module in enumerate(self.ffn):
522 |                 # attach the FFN index to submodules for TensorBoard visualization
523 | for m in ffn_module.modules():
524 | m.ffn_idx = i
525 |
526 | attention_output = ffn_module(attention_output)
527 | s += (attention_output,)
528 |
529 | intermediate_output = self.intermediate(attention_output)
530 | layer_output = self.output(intermediate_output, attention_output, hidden_states)
531 | outputs = (
532 | (layer_output,)
533 | + outputs
534 | + (
535 | torch.tensor(1000),
536 | query_tensor,
537 | key_tensor,
538 | value_tensor,
539 | layer_input,
540 | attention_output,
541 | intermediate_output,
542 | )
543 | + s
544 | )
545 | return outputs
546 |
547 |
548 | class QuantizedMobileBertPooler(QuantizedModel):
549 | def __init__(self, org_model, **quant_params):
550 | super().__init__()
551 |
552 | self.do_activate = org_model.do_activate
553 | if self.do_activate:
554 | self.dense_act = quantize_model(
555 | nn.Sequential(org_model.dense, nn.Tanh()), **quant_params
556 | )
557 |
558 | def forward(self, hidden_states):
559 | # We "pool" the model by simply taking the hidden state corresponding
560 | # to the first token.
561 | first_token_tensor = hidden_states[:, 0]
562 | if not self.do_activate:
563 | return first_token_tensor
564 | else:
565 | pooled_output = self.dense_act(first_token_tensor)
566 | return pooled_output
567 |
568 |
569 | class QuantizedMobileBertModel(QuantizedModel, ModuleUtilsMixin):
570 | def __init__(self, org_model, **quant_params):
571 | super().__init__()
572 |
573 | self.config = org_model.config
574 |
575 | self.embeddings = QuantizedMobileBertEmbeddings(org_model.embeddings, **quant_params)
576 | self.encoder = quantize_model(
577 | org_model.encoder, specials={MobileBertLayer: QuantizedMobileBertLayer}, **quant_params
578 | )
579 | self.pooler = (
580 | QuantizedMobileBertPooler(org_model.pooler, **quant_params)
581 | if org_model.pooler is not None
582 | else None
583 | )
584 |
585 | def get_input_embeddings(self):
586 | return self.embeddings.word_embeddings
587 |
588 | def set_input_embeddings(self, value):
589 | self.embeddings.word_embeddings = value
590 |
591 | def _prune_heads(self, heads_to_prune):
592 | """
593 | Prunes heads of the model.
594 |
595 | Parameters
596 | ----------
597 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
598 | See base class PreTrainedModel.
599 | """
600 | for layer, heads in heads_to_prune.items():
601 | self.encoder.layer[layer].attention.prune_heads(heads)
602 |
603 | def forward(
604 | self,
605 | input_ids=None,
606 | attention_mask=None,
607 | token_type_ids=None,
608 | position_ids=None,
609 | head_mask=None,
610 | inputs_embeds=None,
611 | output_hidden_states=None,
612 | output_attentions=None,
613 | return_dict=None,
614 | ):
615 | output_attentions = (
616 | output_attentions if output_attentions is not None else self.config.output_attentions
617 | )
618 | output_hidden_states = (
619 | output_hidden_states
620 | if output_hidden_states is not None
621 | else self.config.output_hidden_states
622 | )
623 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
624 |
625 | if input_ids is not None and inputs_embeds is not None:
626 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
627 | elif input_ids is not None:
628 | input_shape = input_ids.size()
629 | elif inputs_embeds is not None:
630 | input_shape = inputs_embeds.size()[:-1]
631 | else:
632 | raise ValueError("You have to specify either input_ids or inputs_embeds")
633 |
634 | device = input_ids.device if input_ids is not None else inputs_embeds.device
635 |
636 | if attention_mask is None:
637 | attention_mask = torch.ones(input_shape, device=device)
638 | if token_type_ids is None:
639 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
640 |
641 | # We can provide a self-attention mask of dimensions
642 | # [batch_size, from_seq_length, to_seq_length]
643 | # ourselves in which case we just need to make it broadcastable to all heads.
644 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
645 | attention_mask, input_shape, self.device
646 | )
647 |
648 | # Prepare head mask if needed
649 |         # 1.0 in head_mask indicates we keep the head
650 | # attention_probs has shape bsz x n_heads x N x N
651 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
652 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
653 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
654 |
655 | embedding_output = self.embeddings(
656 | input_ids=input_ids,
657 | position_ids=position_ids,
658 | token_type_ids=token_type_ids,
659 | inputs_embeds=inputs_embeds,
660 | )
661 | encoder_outputs = self.encoder(
662 | embedding_output,
663 | attention_mask=extended_attention_mask,
664 | head_mask=head_mask,
665 | output_attentions=output_attentions,
666 | output_hidden_states=output_hidden_states,
667 | return_dict=return_dict,
668 | )
669 |
670 | sequence_output = encoder_outputs[0]
671 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
672 |
673 | if not return_dict:
674 | return (sequence_output, pooled_output) + encoder_outputs[1:]
675 |
676 | return BaseModelOutputWithPooling(
677 | last_hidden_state=sequence_output,
678 | pooler_output=pooled_output,
679 | hidden_states=encoder_outputs.hidden_states,
680 | attentions=encoder_outputs.attentions,
681 | )
682 |
683 |
684 | class QuantizedMobileBertForSequenceClassification(QuantizedModel):
685 | def __init__(self, org_model, quant_setup=None, **quant_params):
686 | super().__init__()
687 |
688 | self.num_labels = org_model.num_labels
689 | self.config = org_model.config
690 |
691 | self.mobilebert = QuantizedMobileBertModel(org_model=org_model.mobilebert, **quant_params)
692 | self.dropout = org_model.dropout
693 | self.classifier = quantize_model(org_model.classifier, **quant_params)
694 |
695 | if quant_setup == 'FP_logits':
696 | print('Do not quantize output of FC layer')
697 | # no activation quantization of logits:
698 | self.classifier.activation_quantizer = FP32Acts()
699 | elif quant_setup is not None and quant_setup != 'all':
700 | raise ValueError("Quantization setup '{}' not supported.".format(quant_setup))
701 |
702 | def forward(
703 | self,
704 | input_ids=None,
705 | attention_mask=None,
706 | token_type_ids=None,
707 | position_ids=None,
708 | head_mask=None,
709 | inputs_embeds=None,
710 | labels=None,
711 | output_attentions=None,
712 | output_hidden_states=None,
713 | return_dict=None,
714 | ):
715 | r"""
716 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
717 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
718 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
719 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
720 | """
721 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
722 |
723 | outputs = self.mobilebert(
724 | input_ids,
725 | attention_mask=attention_mask,
726 | token_type_ids=token_type_ids,
727 | position_ids=position_ids,
728 | head_mask=head_mask,
729 | inputs_embeds=inputs_embeds,
730 | output_attentions=output_attentions,
731 | output_hidden_states=output_hidden_states,
732 | return_dict=return_dict,
733 | )
734 | pooled_output = outputs[1]
735 | pooled_output = self.dropout(pooled_output)
736 |
737 |         # NB: the final logits can optionally be kept un-quantized if they are only used for prediction
738 | # (can be enabled via --quant-setup FP_logits)
739 | logits = self.classifier(pooled_output)
740 |
741 | loss = None
742 | if labels is not None:
743 | if self.num_labels == 1:
744 | # We are doing regression
745 | loss_fct = MSELoss()
746 | loss = loss_fct(logits.view(-1), labels.view(-1))
747 | else:
748 | loss_fct = CrossEntropyLoss()
749 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
750 |
751 | if not return_dict:
752 | output = (logits,) + outputs[2:]
753 | return ((loss,) + output) if loss is not None else output
754 |
755 | return SequenceClassifierOutput(
756 | loss=loss,
757 | logits=logits,
758 | hidden_states=outputs.hidden_states,
759 | attentions=outputs.attentions,
760 | )
761 |
--------------------------------------------------------------------------------
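The residual pattern implemented by `QuantizedFFNOutput` above (dense output, add the residual, optionally fake-quantize the sum, then apply MobileBERT's `NoNorm`) can be sketched standalone. The snippet below is illustrative only and not part of the repository: `NoNorm` and `fake_quant_8bit` are simplified stand-ins for the HuggingFace `NoNorm` and the repository's `QuantizedActivation`.

```python
import torch
import torch.nn as nn


class NoNorm(nn.Module):
    """MobileBERT-style LayerNorm replacement: element-wise affine, no normalization."""
    def __init__(self, feat_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(feat_size))
        self.bias = nn.Parameter(torch.zeros(feat_size))

    def forward(self, x):
        return x * self.weight + self.bias


def fake_quant_8bit(x):
    # toy asymmetric uniform fake-quantization with a per-tensor min/max range
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min).clamp(min=1e-8) / 255.0
    return torch.round((x - x_min) / scale).clamp(0, 255) * scale + x_min


hidden, residual = torch.randn(2, 4, 128), torch.randn(2, 4, 128)
out = fake_quant_8bit(hidden + residual)  # quantize the residual sum ...
out = NoNorm(128)(out)                    # ... then apply NoNorm, mirroring forward() above
print(out.shape)                          # torch.Size([2, 4, 128])
```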
/models/quantized_roberta.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import torch
5 | from torch.nn import CrossEntropyLoss, MSELoss
6 | from transformers.modeling_outputs import SequenceClassifierOutput
7 | from transformers.models.roberta.modeling_roberta import (
8 | RobertaSelfAttention,
9 | RobertaSelfOutput,
10 | RobertaLayer,
11 | )
12 |
13 | from models.quantized_bert import (
14 | QuantizedBertEmbeddings,
15 | QuantizedBertSelfAttention,
16 | QuantizedBertSelfOutput,
17 | QuantizedBertOutput,
18 | QuantizedBertLayer,
19 | QuantizedBertPooler,
20 | QuantizedBertModel,
21 | QuantizedBertForSequenceClassification,
22 | )
23 | from quantization.autoquant_utils import quantize_model
24 |
25 |
26 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
27 | """
28 | Replace non-padding symbols with their position numbers. Position numbers begin at
29 | padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
30 | `utils.make_positions`.
31 |
32 | Args:
33 | x: torch.Tensor x:
34 |
35 | Returns: torch.Tensor
36 | """
37 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX
38 | # export and XLA.
39 | mask = input_ids.ne(padding_idx).int()
40 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
41 | return incremental_indices.long() + padding_idx
42 |
43 |
44 | class QuantizedRobertaEmbeddings(QuantizedBertEmbeddings):
45 | def __init__(self, org_model, **quant_params):
46 | super().__init__(org_model, **quant_params)
47 |
48 | self.padding_idx = org_model.padding_idx
49 |
50 | def create_position_ids_from_inputs_embeds(self, inputs_embeds):
51 | """We are provided embeddings directly. We cannot infer which are padded so just generate
52 | sequential position ids.
53 | """
54 | input_shape = inputs_embeds.size()[:-1]
55 | sequence_length = input_shape[1]
56 |
57 | position_ids = torch.arange(
58 | self.padding_idx + 1,
59 | sequence_length + self.padding_idx + 1,
60 | dtype=torch.long,
61 | device=inputs_embeds.device,
62 | )
63 | return position_ids.unsqueeze(0).expand(input_shape)
64 |
65 | def forward(
66 | self,
67 | input_ids=None,
68 | token_type_ids=None,
69 | position_ids=None,
70 | inputs_embeds=None,
71 | past_key_values_length=0,
72 | ):
73 | if position_ids is None:
74 | if input_ids is not None:
75 | # Create the position ids from the input token ids. Any padded tokens remain padded.
76 | position_ids = create_position_ids_from_input_ids(
77 | input_ids, self.padding_idx, past_key_values_length
78 | ).to(input_ids.device)
79 | else:
80 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
81 |
82 | if input_ids is not None:
83 | input_shape = input_ids.size()
84 | else:
85 | input_shape = inputs_embeds.size()[:-1]
86 |
87 | if token_type_ids is None:
88 | token_type_ids = torch.zeros(
89 | input_shape, dtype=torch.long, device=self.position_ids.device
90 | )
91 |
92 | if inputs_embeds is None:
93 | inputs_embeds = self.word_embeddings(input_ids)
94 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
95 |
96 | embeddings = inputs_embeds + token_type_embeddings
97 | embeddings = self.sum_input_token_type_embd_act_quantizer(embeddings)
98 |
99 | if self.position_embedding_type == "absolute":
100 | position_embeddings = self.position_embeddings(position_ids)
101 | embeddings += position_embeddings
102 |
103 | embeddings = self.sum_pos_embd_act_quantizer(embeddings)
104 |
105 | embeddings = self.LayerNorm(embeddings)
106 | embeddings = self.dropout(embeddings)
107 | return embeddings
108 |
109 |
110 | class QuantizedRobertaSelfAttention(QuantizedBertSelfAttention):
111 | pass
112 |
113 |
114 | class QuantizedRobertaSelfOutput(QuantizedBertSelfOutput):
115 | pass
116 |
117 |
118 | class QuantizedRobertaOutput(QuantizedBertOutput):
119 | pass
120 |
121 |
122 | class QuantizedRobertaLayer(QuantizedBertLayer):
123 | def __init__(self, org_model, **quant_params):
124 | super().__init__(org_model, **quant_params)
125 |
126 | # update quantized components
127 | attention_specials = {
128 | RobertaSelfAttention: QuantizedRobertaSelfAttention,
129 | RobertaSelfOutput: QuantizedRobertaSelfOutput,
130 | }
131 | self.attention = quantize_model(
132 | org_model.attention, specials=attention_specials, **quant_params
133 | )
134 | if self.add_cross_attention:
135 | self.crossattention = quantize_model(
136 | org_model.crossattention, specials=attention_specials, **quant_params
137 | )
138 | self.output = QuantizedRobertaOutput(org_model.output, **quant_params)
139 |
140 |
141 | class QuantizedRobertaPooler(QuantizedBertPooler):
142 | pass
143 |
144 |
145 | class QuantizedRobertaModel(QuantizedBertModel):
146 | def __init__(self, org_model, **quant_params):
147 | super().__init__(org_model, **quant_params)
148 |
149 | # update quantized components
150 | self.embeddings = QuantizedRobertaEmbeddings(org_model.embeddings, **quant_params)
151 | self.encoder = quantize_model(
152 | org_model.encoder, specials={RobertaLayer: QuantizedRobertaLayer}, **quant_params
153 | )
154 | self.pooler = (
155 | QuantizedRobertaPooler(org_model.pooler, **quant_params)
156 | if org_model.pooler is not None
157 | else None
158 | )
159 |
160 |
161 | class QuantizedRobertaForSequenceClassification(QuantizedBertForSequenceClassification):
162 | def __init__(self, org_model, quant_setup=None, **quant_params):
163 | super().__init__(org_model, quant_setup=quant_setup, **quant_params)
164 |
165 | # update quantization components
166 | self.roberta = QuantizedRobertaModel(org_model=org_model.roberta, **quant_params)
167 | self.classifier = quantize_model(org_model.classifier, **quant_params)
168 |
169 | def forward(
170 | self,
171 | input_ids=None,
172 | attention_mask=None,
173 | token_type_ids=None,
174 | position_ids=None,
175 | head_mask=None,
176 | inputs_embeds=None,
177 | labels=None,
178 | output_attentions=None,
179 | output_hidden_states=None,
180 | return_dict=None,
181 | ):
182 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
183 |
184 | outputs = self.roberta(
185 | input_ids,
186 | attention_mask=attention_mask,
187 | token_type_ids=token_type_ids,
188 | position_ids=position_ids,
189 | head_mask=head_mask,
190 | inputs_embeds=inputs_embeds,
191 | output_attentions=output_attentions,
192 | output_hidden_states=output_hidden_states,
193 | return_dict=return_dict,
194 | )
195 | sequence_output = outputs[0]
196 |
197 |         # NOTE: the final logits can optionally be kept un-quantized if they are only used for prediction
198 | # (can be enabled via --quant-setup FP_logits)
199 | logits = self.classifier(sequence_output)
200 |
201 | loss = None
202 | if labels is not None:
203 | if self.num_labels == 1:
204 | # We are doing regression
205 | loss_fct = MSELoss()
206 | loss = loss_fct(logits.view(-1), labels.view(-1))
207 | else:
208 | loss_fct = CrossEntropyLoss()
209 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
210 |
211 | if not return_dict:
212 | output = (logits,) + outputs[2:]
213 | return ((loss,) + output) if loss is not None else output
214 |
215 | return SequenceClassifierOutput(
216 | loss=loss,
217 | logits=logits,
218 | hidden_states=outputs.hidden_states,
219 | attentions=outputs.attentions,
220 | )
221 |
--------------------------------------------------------------------------------
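A quick check of `create_position_ids_from_input_ids` defined at the top of this file (assumes the repository root is on `PYTHONPATH`): padded positions keep `padding_idx`, while real tokens count up from `padding_idx + 1`.

```python
import torch

from models.quantized_roberta import create_position_ids_from_input_ids

input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # 1 is RoBERTa's pad token id
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]])
```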
/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
--------------------------------------------------------------------------------
/quantization/adaround/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from quantization.adaround.adaround import apply_adaround_to_layer
5 | from quantization.adaround.utils import (
6 | AdaRoundInitMode,
7 | AdaRoundMode,
8 | AdaRoundActQuantMode,
9 | AdaRoundLossType,
10 | AdaRoundTempDecayType,
11 | )
12 |
--------------------------------------------------------------------------------
/quantization/adaround/adaround.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 | from math import ceil
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | from quantization.adaround.quantizer import ADAROUND_QUANTIZER_MAP
12 | from quantization.adaround.utils import (
13 | MODE_TO_LOSS_TYPE,
14 | AdaRoundInitMode,
15 | CombinedLoss,
16 | GetLayerInpOut,
17 | LayerOutputMSE,
18 | )
19 | from utils.utils import DotDict
20 |
21 |
22 | # setup logger
23 | logger = logging.getLogger('AdaRound')
24 | logger.setLevel(logging.INFO)
25 |
26 |
27 | def apply_adaround_to_layer(model, layer, data_tensor, batch_size, act_quant, adaround_config,
28 | keep_gpu=True):
29 | """Apply AdaRound to a `layer` in the `model`."""
30 |
31 | # disable caching of quantized params
32 | layer.caching = False
33 |
34 | # grid initialization
35 | if adaround_config.init == AdaRoundInitMode.range_estimator:
36 | pass # already initialized
37 | elif adaround_config.init == AdaRoundInitMode.mse:
38 | apply_mse_init(layer)
39 | elif adaround_config.init == AdaRoundInitMode.mse_out:
40 | apply_mse_out_init(model, layer, data_tensor, batch_size)
41 | elif adaround_config.init == AdaRoundInitMode.mse_out_asym:
42 | apply_mse_out_init(model, layer, data_tensor, batch_size, asym=True)
43 | else:
44 | raise ValueError(f'Unknown initialization for AdaRound: {adaround_config.init}')
45 |
46 | # activation function
47 | if not adaround_config.include_act_func:
48 | org_act_func = layer.activation_function
49 | layer.activation_function = None
50 |
51 | # replace quantizer with AdaRound quantizer
52 | org_w_quantizer = layer.weight_quantizer.quantizer
53 | org_w_quant_cls = org_w_quantizer.__class__
54 |
55 |     if org_w_quant_cls not in ADAROUND_QUANTIZER_MAP:
56 | raise NotImplementedError(f'AdaRound is not supported for "{org_w_quant_cls}"')
57 |
58 | new_w_quant_cls = ADAROUND_QUANTIZER_MAP[org_w_quant_cls]
59 | w_quantizer = new_w_quant_cls(
60 | n_bits=org_w_quantizer.n_bits,
61 | scale_domain=org_w_quantizer.scale_domain,
62 | per_channel=org_w_quantizer.per_channel,
63 | eps=org_w_quantizer.eps,
64 | )
65 | w_quantizer.register_buffer('_delta', org_w_quantizer._delta)
66 | w_quantizer.register_buffer('_zero_float', org_w_quantizer._zero_float)
67 | if hasattr(org_w_quantizer, '_signed'):
68 | w_quantizer.register_buffer('_signed', org_w_quantizer._signed)
69 | layer.weight_quantizer.quantizer = w_quantizer
70 |
71 | # set AdaRound attributes
72 | w_quantizer.round_mode = adaround_config.round_mode
73 | w_quantizer.temperature = adaround_config.annealing[0]
74 |
75 | # single test (and init alpha)
76 | get_inp_out = GetLayerInpOut(model, layer, asym=adaround_config.asym, act_quant=act_quant)
77 | inp, out = get_inp_out(data_tensor[:batch_size])
78 | loss_soft_before, loss_hard_before = _compute_and_display_local_losses(
79 | w_quantizer, layer, inp, out, infix='before optimization'
80 | )
81 | w_quantizer.soft_targets = True
82 |
83 | # define loss
84 | loss_type = MODE_TO_LOSS_TYPE[w_quantizer.round_mode]
85 | loss_fn = CombinedLoss(
86 | quantizer=w_quantizer,
87 | loss_type=loss_type,
88 | weight=adaround_config.weight,
89 | max_count=adaround_config.iters,
90 | b_range=adaround_config.annealing,
91 | warmup=adaround_config.warmup,
92 | decay_type=adaround_config.decay_type,
93 | decay_shape=adaround_config.decay_shape,
94 | decay_start=adaround_config.decay_start,
95 | )
96 |
97 | # define optimizer
98 | opt_params = [w_quantizer.alpha]
99 | optimizer = torch.optim.Adam(opt_params, lr=adaround_config.lr)
100 |
101 | # main loop
102 | optimize_local_loss(
103 | layer,
104 | get_inp_out,
105 | data_tensor,
106 | optimizer,
107 | loss_fn,
108 | batch_size,
109 | adaround_config.iters,
110 | keep_gpu=keep_gpu,
111 | )
112 |
113 | # check afterwards
114 | logger.info(f'Local loss before optimization (hard quant): {loss_hard_before:.7f}')
115 | loss_soft_after, loss_hard_after = _compute_and_display_local_losses(
116 | w_quantizer, layer, inp, out, infix='after optimization'
117 | )
118 |
119 | # set to hard decision up/down
120 | w_quantizer.soft_targets = False
121 |
122 | # restore original activation function
123 | if not adaround_config.include_act_func:
124 | layer.activation_function = org_act_func
125 |
126 | # restore caching of quantized params
127 | layer.caching = True
128 |
129 | # prepare output
130 | out = DotDict(
131 | loss_soft_before=loss_soft_before,
132 | loss_hard_before=loss_hard_before,
133 | loss_soft_after=loss_soft_after,
134 | loss_hard_after=loss_hard_after,
135 | )
136 | return out
137 |
138 |
139 | def _compute_and_display_local_losses(quantizer, layer, inp, out, infix=''):
140 | org_soft_targets = quantizer.soft_targets
141 |
142 | quantizer.soft_targets = True
143 | out_soft_quant = layer(inp)
144 | quantizer.soft_targets = False
145 | out_hard_quant = layer(inp)
146 |
147 | soft_quant_loss = F.mse_loss(out_soft_quant, out)
148 | hard_quant_loss = F.mse_loss(out_hard_quant, out)
149 |
150 | if infix:
151 | infix = infix.strip() + ' '
152 |
153 | logger.info(f'Local loss {infix}(soft quant): {soft_quant_loss:.7f}')
154 | logger.info(f'Local loss {infix}(hard quant): {hard_quant_loss:.7f}')
155 |
156 | quantizer.soft_targets = org_soft_targets
157 | return float(soft_quant_loss), float(hard_quant_loss)
158 |
159 |
160 | def apply_mse_init(layer):
161 | w = layer.weight
162 | q = layer.weight_quantizer.quantizer
163 |
164 | with torch.no_grad():
165 | w_absmax = torch.max(w.max(), torch.abs(w.min()))
166 | best_score = np.inf
167 | best_max = w_absmax
168 | for i in range(80):
169 | s = w_absmax * (1.0 - 0.01 * i)
170 | q.set_quant_range(-s, s)
171 | score = F.mse_loss(w, q(w)).item()
172 |
173 | if score < best_score:
174 | best_score = score
175 | best_max = s
176 |
177 |         logger.info(f'Finished: set max={best_max:.3f} (mse={best_score:.7f})')
178 | q.set_quant_range(-best_max, best_max)
179 |
180 |
181 | def apply_mse_out_init(model, layer, data_tensor, batch_size, asym=False):
182 | w = layer.weight
183 | q = layer.weight_quantizer.quantizer
184 |
185 | get_inp_out = GetLayerInpOut(model, layer, asym=asym)
186 | loss_fn = LayerOutputMSE(layer, get_inp_out, data_tensor, batch_size)
187 |
188 | with torch.no_grad():
189 | w_absmax = torch.max(w.max(), torch.abs(w.min()))
190 | best_score = np.inf
191 | best_max = w_absmax
192 | for i in range(80):
193 | s = w_absmax * (1.0 - 0.01 * i)
194 | q.set_quant_range(-s, s)
195 | score = loss_fn()
196 |
197 | if score < best_score:
198 | best_score = score
199 | best_max = s
200 | logger.info(f'Finished: set max={best_max:.3f} (mse={best_score:.7f})')
201 | q.set_quant_range(-best_max, best_max)
202 |
203 |
204 | def optimize_local_loss(layer, get_inp_out, data_tensor, optimizer, loss_fn, batch_size, iters,
205 | use_cached_data=True, keep_gpu=True):
206 | """AdaRound optimization loop."""
207 | if use_cached_data:
208 | logger.info('Caching data for local loss optimization')
209 |
210 | cached_batches = []
211 | if keep_gpu:
212 | torch.cuda.empty_cache()
213 | with torch.no_grad():
214 | for i in range(ceil(data_tensor.size(0) / batch_size)):
215 | cur_inp, cur_out = get_inp_out(data_tensor[i * batch_size:(i + 1) * batch_size])
216 | cached_batches.append((cur_inp.cpu(), cur_out.cpu()))
217 |
218 | cached_inps = torch.cat([x[0] for x in cached_batches])
219 | cached_outs = torch.cat([x[1] for x in cached_batches])
220 | device = cur_inp.device
221 |
222 | del cached_batches
223 | if keep_gpu: # put all cached data on GPU for faster optimization
224 | torch.cuda.empty_cache()
225 | try:
226 | cached_inps = cached_inps.to(device)
227 | cached_outs = cached_outs.to(device)
228 | except RuntimeError as e:
229 | logger.warning(
230 |                     f"Could not cache the training data on the GPU, keeping it on the CPU ({e})"
231 | )
232 | cached_inps = cached_inps.cpu()
233 | cached_outs = cached_outs.cpu()
234 |
235 | for i in range(iters):
236 |         idx = torch.randperm(cached_inps.size(0) if use_cached_data else data_tensor.size(0))[:batch_size]
237 | if use_cached_data:
238 | cur_inp = cached_inps[idx].to(device)
239 | cur_out = cached_outs[idx].to(device)
240 | else:
241 | cur_inp, cur_out = get_inp_out(data_tensor[idx])
242 |
243 | optimizer.zero_grad()
244 |
245 | try:
246 | out_quant = layer(cur_inp)
247 | loss = loss_fn(out_quant, cur_out)
248 | loss.backward()
249 | except RuntimeError as e:
250 | if use_cached_data and 'cuda' in str(cached_inps.device):
251 | logger.warning(
252 |                     f"Not enough CUDA memory for the forward pass, "
253 |                     f"moving the cached data to the CPU ({e})"
254 | )
255 | cached_inps = cached_inps.cpu()
256 | cached_outs = cached_outs.cpu()
257 | else:
258 | raise e
259 |
260 | optimizer.step()
261 |
--------------------------------------------------------------------------------
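The grid search used by `apply_mse_init` above can be reproduced standalone: shrink a symmetric clipping range in 1% steps and keep the threshold with the lowest weight-reconstruction MSE. The sketch below uses a toy symmetric 8-bit fake-quantizer rather than the repository's `SymmetricUniformQuantizer`.

```python
import torch
import torch.nn.functional as F


def fake_quant_sym(w, max_abs, n_bits=8):
    # toy symmetric uniform fake-quantization
    scale = max_abs / (2 ** (n_bits - 1) - 1)
    return torch.round(w / scale).clamp(-2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1) * scale


w = torch.randn(256, 256)
w_absmax = w.abs().max()
best_score, best_max = float('inf'), w_absmax
for i in range(80):
    s = w_absmax * (1.0 - 0.01 * i)                     # candidate clipping threshold
    score = F.mse_loss(w, fake_quant_sym(w, s)).item()
    if score < best_score:
        best_score, best_max = score, s
print(f'best max = {float(best_max):.3f}, mse = {best_score:.7f}')
```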
/quantization/adaround/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from quantization.adaround.utils import (
5 | AdaRoundActQuantMode,
6 | AdaRoundInitMode,
7 | AdaRoundMode,
8 | AdaRoundTempDecayType,
9 | )
10 | from utils.utils import DotDict
11 |
12 |
13 | class AdaRoundConfig(DotDict):
14 | pass
15 |
16 |
17 | DEFAULT_ADAROUND_CONFIG = AdaRoundConfig(
18 | # Base options
19 | layers=('all',),
20 | num_samples=1024,
21 | init=AdaRoundInitMode.range_estimator,
22 |
23 | # Method and continuous relaxation options
24 | round_mode=AdaRoundMode.learned_hard_sigmoid,
25 | asym=True,
26 | include_act_func=True,
27 | lr=1e-3,
28 | iters=1000,
29 | weight=0.01,
30 | annealing=(20, 2),
31 | decay_type=AdaRoundTempDecayType.cosine,
32 | decay_shape=1.0,
33 | decay_start=0.0,
34 | warmup=0.2,
35 |
36 | # Activation quantization
37 | act_quant_mode=AdaRoundActQuantMode.post_adaround,
38 | )
39 |
--------------------------------------------------------------------------------
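To run AdaRound with different hyper-parameters, the defaults above can be copied and overridden. A minimal sketch (assumes the repository root is on `PYTHONPATH` and that `DotDict` behaves as a standard `dict` subclass, as its use in `adaround.py` suggests); the particular overrides are illustrative only.

```python
from quantization.adaround.config import AdaRoundConfig, DEFAULT_ADAROUND_CONFIG
from quantization.adaround.utils import AdaRoundInitMode

custom_config = AdaRoundConfig(
    **{**DEFAULT_ADAROUND_CONFIG,      # start from the defaults above ...
       'iters': 10000,                 # ... run a longer optimization
       'init': AdaRoundInitMode.mse}   # ... with MSE-based grid initialization
)
print(custom_config['iters'], custom_config['init'])  # 10000 mse
```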
/quantization/adaround/quantizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from quantization.quantizers import (
10 | QuantizerBase,
11 | AsymmetricUniformQuantizer,
12 | SymmetricUniformQuantizer,
13 | )
14 | from quantization.adaround.utils import AdaRoundMode
15 |
16 |
17 | # setup logger
18 | logger = logging.getLogger('AdaRound')
19 | logger.setLevel(logging.INFO)
20 |
21 |
22 | def logit(p, eps=1e-16):
23 | p = torch.clamp(p, eps, 1 - eps)
24 | return -torch.log(1 / p - 1)
25 |
26 |
27 | def hard_sigmoid(x, zeta=1.1, gamma=-0.1):
28 | p = torch.sigmoid(x)
29 | return torch.clamp(p * (zeta - gamma) + gamma, 0.0, 1.0)
30 |
31 |
32 | def hard_logit(p, zeta=1.1, gamma=-0.1):
33 | # NOTE: argument of log is between 1/11 and 11 (for default values of zeta and gamma)
34 | return -torch.log((zeta - p) / (p - gamma))
35 |
36 |
37 | class AdaRoundQuantizer(QuantizerBase):
38 | def __init__(self, *args, **kwargs):
39 | super().__init__(*args, **kwargs)
40 |
41 | self.alpha = None
42 | self.round_mode = AdaRoundMode.nearest
43 | self.soft_targets = False
44 | self.temperature = None # for sigmoid temperature annealing
45 |
46 | def to_integer_forward(self, x_float):
47 | if self.round_mode == AdaRoundMode.nearest:
48 | return super().to_integer_forward(x_float)
49 |
50 | if self.round_mode not in AdaRoundMode.RELAXATION:
51 | raise ValueError(f'Unknown rounding mode: {self.round_mode}')
52 |
53 | # cont. relaxation
54 | x = x_float / self.scale
55 | x_floor = torch.floor(x)
56 |
57 | # initialize alpha, if needed
58 | if self.alpha is None:
59 |             logger.info('Initializing alpha from the FP32 rounding remainder')
60 |
61 | rest = x - x_floor # rest of rounding [0, 1)
62 | if self.round_mode == AdaRoundMode.learned_sigmoid:
63 | alpha = logit(rest) # => sigmoid(alpha) = rest
64 | elif self.round_mode == AdaRoundMode.learned_hard_sigmoid:
65 | alpha = hard_logit(rest) # => hard_sigmoid(alpha) = rest
66 | elif self.round_mode == AdaRoundMode.sigmoid_temp_decay:
67 | alpha = self.temperature * logit(rest) # => sigmoid(alpha/temperature) = rest
68 | else:
69 | raise ValueError(f'Unknown rounding mode: {self.round_mode}')
70 |
71 | self.alpha = nn.Parameter(alpha, requires_grad=True)
72 |
73 | # compute final x_int
74 | x_int = x_floor + (self.get_rest() if self.soft_targets else (self.alpha >= 0).float())
75 |
76 | if not self.symmetric:
77 | x_int += self.zero_point
78 |
79 | x_int = torch.clamp(x_int, self.int_min, self.int_max)
80 | return x_int
81 |
82 | def get_rest(self):
83 | if self.round_mode == AdaRoundMode.learned_sigmoid:
84 | return torch.sigmoid(self.alpha)
85 | elif self.round_mode == AdaRoundMode.learned_hard_sigmoid:
86 | return hard_sigmoid(self.alpha)
87 | elif self.round_mode == AdaRoundMode.sigmoid_temp_decay:
88 | return torch.sigmoid(self.alpha / self.temperature)
89 | else:
90 | raise ValueError(f'Unknown rounding mode: {self.round_mode}')
91 |
92 | def extra_repr(self):
93 | return ', '.join([
94 | f'n_bits={self.n_bits}',
95 | f'per_channel={self.per_channel}',
96 | f'is_initialized={self.is_initialized}',
97 | f'round_mode={self.round_mode}',
98 | f'soft_targets={self.soft_targets}',
99 | f'temperature={self.temperature}',
100 | ])
101 |
102 |
103 | class AdaRoundSymmetricUniformQuantizer(AdaRoundQuantizer, SymmetricUniformQuantizer):
104 | pass
105 |
106 |
107 | class AdaRoundAsymmetricUniformQuantizer(AdaRoundQuantizer, AsymmetricUniformQuantizer):
108 | pass
109 |
110 |
111 | ADAROUND_QUANTIZER_MAP = {
112 | SymmetricUniformQuantizer: AdaRoundSymmetricUniformQuantizer,
113 | AsymmetricUniformQuantizer: AdaRoundAsymmetricUniformQuantizer,
114 | }
115 |
--------------------------------------------------------------------------------
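A standalone check of the rectified-sigmoid pair used above: `hard_logit` is the inverse of `hard_sigmoid` on (0, 1), which is why initializing `alpha = hard_logit(rest)` makes the soft rounding value start exactly at the FP32 rounding remainder.

```python
import torch

zeta, gamma = 1.1, -0.1  # same stretch parameters as in the quantizer above


def hard_sigmoid(x):
    return torch.clamp(torch.sigmoid(x) * (zeta - gamma) + gamma, 0.0, 1.0)


def hard_logit(p):
    return -torch.log((zeta - p) / (p - gamma))


rest = torch.tensor([0.1, 0.5, 0.9])
print(hard_sigmoid(hard_logit(rest)))  # tensor([0.1000, 0.5000, 0.9000])
```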
/quantization/adaround/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 | from enum import Flag, auto
6 | from math import ceil
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from utils.utils import StopForwardException
13 |
14 |
15 | # setup logger
16 | logger = logging.getLogger('AdaRound')
17 | logger.setLevel(logging.INFO)
18 |
19 |
20 | def sigmoid(x):
21 | return (1.0 + np.exp(-x)) ** -1.0
22 |
23 |
24 | class BaseOption(Flag):
25 | def __str__(self):
26 | return self.name
27 |
28 | @property
29 | def cls(self):
30 | return self.value.cls
31 |
32 | @classmethod
33 | def list_names(cls):
34 | return [m.name for m in cls]
35 |
36 |
37 | class AdaRoundActQuantMode(BaseOption):
38 | # Activation quantization is disabled
39 | no_act_quant = auto()
40 |
41 | # AdaRound with FP32 acts, activations are quantized afterwards if applicable (default):
42 | post_adaround = auto()
43 |
44 |
45 | class AdaRoundInitMode(BaseOption):
46 | """Weight quantization grid initialization."""
47 |
48 | range_estimator = auto()
49 | mse = auto() # old implementation
50 | mse_out = auto()
51 | mse_out_asym = auto()
52 |
53 |
54 | class AdaRoundLossType(BaseOption):
55 | """Regularization terms."""
56 |
57 | relaxation = auto()
58 | temp_decay = auto()
59 |
60 |
61 | class AdaRoundMode(BaseOption):
62 | nearest = auto() # (default)
63 |
64 | # original AdaRound relaxation methods
65 | learned_sigmoid = auto()
66 | learned_hard_sigmoid = auto()
67 | sigmoid_temp_decay = auto()
68 |
69 | RELAXATION = learned_sigmoid | learned_hard_sigmoid | sigmoid_temp_decay
70 |
71 | @classmethod
72 | def list_names(cls):
73 | exclude = (AdaRoundMode.nearest, AdaRoundMode.RELAXATION)
74 |         return [m.name for m in cls if m not in exclude]
75 |
76 |
77 | MODE_TO_LOSS_TYPE = {
78 | AdaRoundMode.learned_hard_sigmoid: AdaRoundLossType.relaxation,
79 | AdaRoundMode.learned_sigmoid: AdaRoundLossType.relaxation,
80 | AdaRoundMode.sigmoid_temp_decay: AdaRoundLossType.temp_decay,
81 | }
82 |
83 |
84 | class AdaRoundTempDecayType(BaseOption):
85 | linear = auto()
86 | cosine = auto()
87 | sigmoid = auto() # https://arxiv.org/abs/1811.09332
88 | power = auto()
89 | exp = auto()
90 | log = auto()
91 |
92 |
93 | class TempDecay:
94 | def __init__(self, t_max, b_range=(20.0, 2.0), rel_decay_start=0.0,
95 | decay_type=AdaRoundTempDecayType.linear, decay_shape=1.0):
96 | self.t_max = t_max
97 | self.start_b, self.end_b = b_range
98 | self.decay_type = decay_type
99 | self.decay_shape = decay_shape
100 | self.decay_start = rel_decay_start * t_max
101 |
102 | def __call__(self, t):
103 | if t < self.decay_start:
104 | return self.start_b
105 |
106 | rel_t = (t - self.decay_start) / (self.t_max - self.decay_start)
107 | if self.decay_type == AdaRoundTempDecayType.linear:
108 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
109 | elif self.decay_type == AdaRoundTempDecayType.cosine:
110 | return self.end_b + 0.5 * (self.start_b - self.end_b) * (1 + np.cos(rel_t * np.pi))
111 | elif self.decay_type == AdaRoundTempDecayType.sigmoid:
112 | d = self.decay_shape
113 | offset = sigmoid(-d / 2)
114 | rel_progress = (sigmoid(d * (rel_t - 0.5)) - offset) / (1 - 2 * offset)
115 | return self.start_b + (self.end_b - self.start_b) * rel_progress
116 | elif self.decay_type == AdaRoundTempDecayType.power:
117 | return self.end_b + (self.start_b - self.end_b) * (1 - rel_t ** self.decay_shape)
118 | elif self.decay_type == AdaRoundTempDecayType.exp:
119 | r = self.decay_shape
120 | rel_progress = (1.0 - np.exp(-r * rel_t)) / (1.0 - np.exp(-r))
121 | return self.start_b + (self.end_b - self.start_b) * rel_progress
122 | elif self.decay_type == AdaRoundTempDecayType.log:
123 | r = self.decay_shape
124 | C = np.exp(self.end_b / r)
125 | c = np.exp(self.start_b / r)
126 | return r * np.log((C - c) * rel_t + c)
127 | else:
128 | raise ValueError(f'Unknown temp decay type {self.decay_type}')
129 |
130 |
131 | class CombinedLoss:
132 | def __init__(self, quantizer, loss_type=AdaRoundLossType.relaxation, weight=0.01,
133 | max_count=1000, b_range=(20, 2), warmup=0.0, decay_start=0.0, **temp_decay_kw):
134 | self.quantizer = quantizer
135 | self.loss_type = loss_type
136 | self.weight = weight
137 |
138 | self.loss_start = max_count * warmup
139 | self.temp_decay = TempDecay(
140 | max_count,
141 | b_range=b_range,
142 | rel_decay_start=warmup + (1.0 - warmup) * decay_start,
143 | **temp_decay_kw,
144 | )
145 | self.iter = 0
146 |
147 | def __call__(self, pred, tgt, *args, **kwargs):
148 | self.iter += 1
149 |
150 | rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean()
151 |
152 | if self.iter < self.loss_start:
153 | b = self.temp_decay(self.iter)
154 | round_loss = 0
155 | elif self.loss_type == AdaRoundLossType.temp_decay:
156 | b = self.temp_decay(self.iter)
157 | self.quantizer.temperature = b
158 | round_loss = 0
159 | elif self.loss_type == AdaRoundLossType.relaxation: # 1 - |(h-0.5)*2|**b
160 | b = self.temp_decay(self.iter)
161 | round_vals = self.quantizer.get_rest().view(-1)
162 | round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum()
163 | else:
164 | raise ValueError(f'Unknown loss type {self.loss_type}')
165 |
166 | total_loss = rec_loss + round_loss
167 | if self.iter == 1 or self.iter % 100 == 0:
168 | logger.info(
169 | f'Total loss:\t{total_loss:.4f} (rec:{rec_loss:.4f}, '
170 | f'round:{round_loss:.3f})\tb={b:.2f}\titer={self.iter}'
171 | )
172 | return total_loss
173 |
174 |
175 | class StopForwardHook:
176 | def __call__(self, module, *args):
177 | raise StopForwardException
178 |
179 |
180 | class DataSaverHook:
181 | def __init__(self, store_input=False, store_output=False, stop_forward=False):
182 | self.store_input = store_input
183 | self.store_output = store_output
184 | self.stop_forward = stop_forward
185 |
186 | self.input_store = None
187 | self.output_store = None
188 |
189 | def __call__(self, module, input_batch, output_batch):
190 | if self.store_input:
191 | self.input_store = input_batch
192 | if self.store_output:
193 | self.output_store = output_batch
194 | if self.stop_forward:
195 | raise StopForwardException
196 |
197 |
198 | class GetLayerInpOut:
199 | def __init__(self, model, layer, asym=False, act_quant=False, store_output=True):
200 | self.model = model
201 | self.layer = layer
202 | self.asym = asym
203 | self.device = layer.weight.device
204 | self.act_quant = act_quant
205 | self.store_output = store_output
206 | self.data_saver = DataSaverHook(
207 | store_input=True, store_output=self.store_output, stop_forward=True
208 | )
209 |
210 | def __call__(self, model_input):
211 | self.model.full_precision()
212 | handle = self.layer.register_forward_hook(self.data_saver)
213 |
214 | with torch.no_grad():
215 | try:
216 | _ = self.model(model_input.to(self.device))
217 | except StopForwardException:
218 | pass
219 |
220 | if self.asym: # recalculate input with network quantized
221 | self.data_saver.store_output = False
222 | self.model.set_quant_state(weight_quant=True, act_quant=self.act_quant)
223 | try:
224 | _ = self.model(model_input.to(self.device))
225 | except StopForwardException:
226 | pass
227 | self.data_saver.store_output = True
228 |
229 | handle.remove()
230 |
231 | self.model.full_precision()
232 | self.layer.quantized_weights()
233 | return self.data_saver.input_store[0].detach(), self.data_saver.output_store.detach()
234 |
235 |
236 | class LayerOutputMSE:
237 | def __init__(self, layer, get_inp_out, data_tensor, batch_size, name='mse_out'):
238 | cur_inp, cur_out = get_inp_out(data_tensor)
239 | self.input = cur_inp
240 | self.exp_out = cur_out
241 | self.layer = layer
242 | self.batch_size = batch_size
243 | self.name = name
244 |
245 | def __call__(self):
246 | loss = 0.0
247 | x = self.input
248 | for i in range(ceil(x.size(0) / self.batch_size)):
249 | cur_out = self.layer(x[i * self.batch_size:(i + 1) * self.batch_size])
250 | exp_out = self.exp_out[i * self.batch_size:(i + 1) * self.batch_size]
251 | loss += F.mse_loss(cur_out, exp_out).item()
252 | return loss
253 |
--------------------------------------------------------------------------------
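A quick look at the temperature schedule that `CombinedLoss` feeds into the rounding regularizer (assumes the repository root is on `PYTHONPATH`): with the cosine decay, `b` stays at 20 until the relative decay start and then anneals down to 2.

```python
from quantization.adaround.utils import AdaRoundTempDecayType, TempDecay

decay = TempDecay(t_max=1000, b_range=(20.0, 2.0), rel_decay_start=0.2,
                  decay_type=AdaRoundTempDecayType.cosine)
print([round(float(decay(t)), 2) for t in (0, 200, 500, 800, 1000)])
# approximately [20.0, 20.0, 14.44, 4.64, 2.0]
```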
/quantization/autoquant_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import copy
5 | import warnings
6 |
7 | from torch.nn import functional as F
8 | from torch import nn
9 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd
10 |
11 | from quantization.base_quantized_classes import FP32Acts, QuantizedActivation, QuantizedModule
12 | from quantization.hijacker import QuantizationHijacker, activations_list
13 | from quantization.quantization_manager import QuantizationManager
14 |
15 |
16 | class QuantLinear(QuantizationHijacker, nn.Linear):
17 | def __init__(self, *args, **kwargs):
18 | super().__init__(*args, **kwargs)
19 |
20 | def run_forward(self, x, weight, bias, offsets=None):
21 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias)
22 |
23 |
24 | class QuantizedActivationWrapper(QuantizedActivation):
25 | """
26 |     Wraps a layer and quantizes its output activation.
27 |     It also allows tying the input and output quantizers, which is helpful
28 |     for layers such as average pooling.
29 | """
30 | def __init__(self, layer, tie_activation_quantizers=False,
31 | input_quantizer: QuantizationManager = None, *args, **kwargs):
32 | super().__init__(*args, **kwargs)
33 | self.tie_activation_quantizers = tie_activation_quantizers
34 | if input_quantizer:
35 | assert isinstance(input_quantizer, QuantizationManager)
36 | self.activation_quantizer = input_quantizer
37 | self.layer = layer
38 |
39 | def quantize_activations_no_range_update(self, x):
40 | if self._quant_a:
41 | return self.activation_quantizer.quantizer(x)
42 | else:
43 | return x
44 |
45 | def forward(self, x):
46 | x = self.layer(x)
47 | if self.tie_activation_quantizers:
48 | # The input activation quantizer is used to quantize the activation
49 | # but without updating the quantization range
50 | return self.quantize_activations_no_range_update(x)
51 | else:
52 | return self.quantize_activations(x)
53 |
54 |
55 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm):
56 | def __init__(self, *args, activation=None, **kwargs):
57 | super().__init__(*args, activation=activation, **kwargs)
58 |
59 | def run_forward(self, x, weight, bias, offsets=None):
60 | return F.layer_norm(
61 | input=x.contiguous(),
62 | normalized_shape=self.normalized_shape,
63 | weight=weight.contiguous(),
64 | bias=bias.contiguous(),
65 | eps=self.eps,
66 | )
67 |
68 |
69 | class QuantEmbedding(QuantizationHijacker, nn.Embedding):
70 | def __init__(self, *args, activation=None, **kwargs):
71 | super().__init__(*args, activation=activation, **kwargs)
72 | # NB: Embedding should not quantize activations, as it is simply a lookup table,
73 | # which is already quantized.
74 | self.activation_quantizer = FP32Acts()
75 |
76 | def run_forward(self, x, weight, bias, offsets=None):
77 | return F.embedding(
78 | input=x.contiguous(),
79 | weight=weight.contiguous(),
80 | padding_idx=self.padding_idx,
81 | max_norm=self.max_norm,
82 | norm_type=self.norm_type,
83 | scale_grad_by_freq=self.scale_grad_by_freq,
84 | sparse=self.sparse,
85 | )
86 |
87 |
88 | module_map = {
89 | nn.Linear: QuantLinear,
90 | nn.LayerNorm: QuantLayerNorm,
91 | nn.Embedding: QuantEmbedding
92 | }
93 |
94 |
95 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd)
96 |
97 |
98 | def get_act(module, i):
99 | result, act_idx = None, None
100 | for i in range(i + 1, len(module)):
101 | if isinstance(module[i], tuple(activations_list)):
102 | result = module[i]
103 | act_idx = i
104 | break
105 | return result, act_idx
106 |
107 |
108 | def get_linear_args(module):
109 | args = dict(
110 | in_features=module.in_features,
111 | out_features=module.out_features,
112 | bias=module.bias is not None,
113 | )
114 | return args
115 |
116 |
117 | def get_layernorm_args(module):
118 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps)
119 | return args
120 |
121 |
122 | def get_embedding_args(module):
123 | args = dict(
124 | num_embeddings=module.num_embeddings,
125 | embedding_dim=module.embedding_dim,
126 | padding_idx=module.padding_idx,
127 | max_norm=module.max_norm,
128 | norm_type=module.norm_type,
129 | scale_grad_by_freq=module.scale_grad_by_freq,
130 | sparse=module.sparse,
131 | )
132 | return args
133 |
134 |
135 | def get_module_args(mod, act):
136 | if isinstance(mod, nn.Linear):
137 | kwargs = get_linear_args(mod)
138 | elif isinstance(mod, nn.LayerNorm):
139 | kwargs = get_layernorm_args(mod)
140 | elif isinstance(mod, nn.Embedding):
141 | kwargs = get_embedding_args(mod)
142 | else:
143 | raise ValueError
144 |
145 | kwargs['activation'] = act
146 | return kwargs
147 |
148 |
149 | def quant_module(module, i, **quant_params):
150 | act, act_idx = get_act(module, i)
151 | modtype = module_map[type(module[i])]
152 |
153 | kwargs = get_module_args(module[i], act)
154 | new_module = modtype(**kwargs, **quant_params)
155 | new_module.weight.data = module[i].weight.data.clone()
156 |
157 | if module[i].bias is not None:
158 | new_module.bias.data = module[i].bias.data.clone()
159 |
160 | return new_module, i + int(bool(act)) + 1
161 |
162 |
163 | def quantize_sequence(model, specials=None, tie_activation_quantizers=False, **quant_params):
164 | specials = specials or dict()
165 |
166 | i = 0
167 | quant_modules = []
168 | while i < len(model):
169 | if isinstance(model[i], QuantizedModule):
170 | quant_modules.append(model[i])
171 | elif type(model[i]) in module_map:
172 | new_module, new_i = quant_module(model, i, **quant_params)
173 | quant_modules.append(new_module)
174 | i = new_i
175 | continue
176 |
177 | elif type(model[i]) in specials:
178 | quant_modules.append(specials[type(model[i])](model[i], **quant_params))
179 |
180 | elif isinstance(model[i], non_param_modules):
181 | # check for last quantizer
182 | input_quantizer = None
183 | if (
184 | quant_modules
185 | and isinstance(quant_modules[-1], QuantizedModule)
186 | and tie_activation_quantizers
187 | ):
188 | input_quantizer = quant_modules[-1].activation_quantizer
189 | warnings.warn(
190 |                     f'Tying the input quantizer of the {i}-th layer ({type(model[i])}) '
191 |                     f'to the activation quantizer of the preceding quantized '
192 |                     f'{type(quant_modules[-1])}'
193 | )
194 | quant_modules.append(
195 | QuantizedActivationWrapper(
196 | model[i],
197 | tie_activation_quantizers=tie_activation_quantizers,
198 | input_quantizer=input_quantizer,
199 | **quant_params,
200 | )
201 | )
202 |
203 | else:
204 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params))
205 | i += 1
206 | return quant_modules
207 |
208 |
209 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params):
210 | quant_modules = quantize_sequence(model, specials, tie_activation_quantizers, **quant_params)
211 | return nn.Sequential(*quant_modules)
212 |
213 |
214 | def quantize_module_list(model, specials=None, tie_activation_quantizers=False, **quant_params):
215 | quant_modules = quantize_sequence(model, specials, tie_activation_quantizers, **quant_params)
216 | return nn.ModuleList(quant_modules)
217 |
218 |
219 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params):
220 | specials = specials or dict()
221 |
222 | if isinstance(model, nn.Sequential):
223 | quant_model = quantize_sequential(
224 | model, specials, tie_activation_quantizers, **quant_params
225 | )
226 |
227 | elif type(model) in specials:
228 | quant_model = specials[type(model)](model, **quant_params)
229 |
230 | elif isinstance(model, non_param_modules):
231 | quant_model = QuantizedActivationWrapper(model, **quant_params)
232 |
233 | elif type(model) in module_map:
234 |         # use an exact type match rather than isinstance(): subclasses of these classes may
235 |         # require their own special handling
236 | modtype = module_map[type(model)]
237 | kwargs = get_module_args(model, None)
238 | quant_model = modtype(**kwargs, **quant_params)
239 |
240 | quant_model.weight.data = model.weight.data
241 | if getattr(model, 'bias', None) is not None:
242 | quant_model.bias.data = model.bias.data
243 |
244 | else:
245 | # unknown type, try to quantize all child modules
246 | quant_model = copy.deepcopy(model)
247 | for name, module in quant_model._modules.items():
248 | new_model = quantize_model(module, specials=specials, **quant_params)
249 | if new_model is not None:
250 | setattr(quant_model, name, new_model)
251 |
252 | return quant_model
253 |
--------------------------------------------------------------------------------
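A minimal sketch of `quantize_model` on a plain `nn.Sequential` (assumes the repository root is on `PYTHONPATH` and that the default quantization settings of `QuantizedModule` are acceptable): the Linear + ReLU pair is folded into a single `QuantLinear` with the activation attached, so the activation quantizer sees the post-ReLU tensor. Weight and activation quantization stay off until switched on via the `QuantizedModule` / `QuantizedModel` helpers.

```python
import torch
from torch import nn

from quantization.autoquant_utils import quantize_model

fp32_head = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))
quant_head = quantize_model(fp32_head, n_bits=8, n_bits_act=8)

print(type(quant_head[0]).__name__)          # QuantLinear (Linear + ReLU folded together)
print(quant_head(torch.randn(4, 16)).shape)  # torch.Size([4, 2]), still in FP32 mode
```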
/quantization/base_quantized_classes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from torch import nn
5 |
6 | from quantization.quantization_manager import QuantizationManager
7 | from quantization.quantizers import QMethods
8 | from quantization.range_estimators import RangeEstimators
9 |
10 |
11 | def _set_layer_learn_ranges(layer):
12 | if isinstance(layer, QuantizationManager):
13 | if layer.quantizer.is_initialized:
14 | layer.learn_ranges()
15 |
16 |
17 | def _set_layer_fix_ranges(layer):
18 | if isinstance(layer, QuantizationManager):
19 | if layer.quantizer.is_initialized:
20 | layer.fix_ranges()
21 |
22 |
23 | def _set_layer_estimate_ranges(layer):
24 | if isinstance(layer, QuantizationManager):
25 | if layer.quantizer.is_initialized:
26 | layer.estimate_ranges()
27 |
28 |
29 | def _set_layer_estimate_ranges_train(layer):
30 | if isinstance(layer, QuantizationManager):
31 | if layer.quantizer.is_initialized:
32 | layer.estimate_ranges_train()
33 |
34 |
35 | class QuantizedModule(nn.Module):
36 | """
37 |     Parent class for a quantized module. It adds the basic functionality for switching the module
38 |     between quantized and full-precision mode, and it manages the cache of quantized parameters,
39 |     resetting it whenever the mode, device, or training state changes.
40 | """
41 | def __init__(self, *args, method=QMethods.asymmetric_uniform, act_method=None, n_bits=8,
42 | n_bits_act=None, per_channel_weights=False, per_channel_acts=False,
43 | percentile=None, weight_range_method=RangeEstimators.current_minmax,
44 | weight_range_options=None, act_range_method=RangeEstimators.running_minmax,
45 | act_range_options=None, scale_domain='linear', **kwargs):
46 | kwargs.pop('quant_dict', None)
47 | super().__init__(*args, **kwargs)
48 |
49 | self.method = method
50 | self.act_method = act_method or method
51 | self.n_bits = n_bits
52 | self.n_bits_act = n_bits_act or n_bits
53 | self.per_channel_weights = per_channel_weights
54 | self.per_channel_acts = per_channel_acts
55 | self.percentile = percentile
56 | self.weight_range_method = weight_range_method
57 | self.weight_range_options = weight_range_options if weight_range_options else {}
58 | self.act_range_method = act_range_method
59 | self.act_range_options = act_range_options if act_range_options else {}
60 | self.scale_domain = scale_domain
61 |
62 | self.cached_params = None
63 | self._caching = True
64 |
65 | self.quant_params = None
66 | self._quant_w = False
67 | self._quant_a = False
68 |
69 | @property
70 | def caching(self):
71 | return self._caching
72 |
73 | @caching.setter
74 | def caching(self, value: bool):
75 | self._caching = value
76 | if not value:
77 | self.cached_params = None
78 |
79 | def quantized_weights(self):
80 | self.cached_params = None
81 | self._quant_w = True
82 |
83 | def full_precision_weights(self):
84 | self.cached_params = None
85 | self._quant_w = False
86 |
87 | def quantized_acts(self):
88 | self._quant_a = True
89 |
90 | def full_precision_acts(self):
91 | self._quant_a = False
92 |
93 | def quantized(self):
94 | self.quantized_weights()
95 | self.quantized_acts()
96 |
97 | def full_precision(self):
98 | self.full_precision_weights()
99 | self.full_precision_acts()
100 |
101 | def learn_ranges(self):
102 | self.apply(_set_layer_learn_ranges)
103 |
104 | def fix_ranges(self):
105 | self.apply(_set_layer_fix_ranges)
106 |
107 | def estimate_ranges(self):
108 | self.apply(_set_layer_estimate_ranges)
109 |
110 | def estimate_ranges_train(self):
111 | self.apply(_set_layer_estimate_ranges_train)
112 |
113 | def train(self, mode=True):
114 | super().train(mode)
115 | if mode:
116 | self.cached_params = None
117 | return self
118 |
119 | def _apply(self, *args, **kwargs):
120 | self.cached_params = None
121 | return super(QuantizedModule, self)._apply(*args, **kwargs)
122 |
123 | def extra_repr(self):
124 | quant_state = 'weight_quant={}, act_quant={}'.format(self._quant_w, self._quant_a)
125 | parent_repr = super().extra_repr()
126 | return '{},\n{}'.format(parent_repr, quant_state) if parent_repr else quant_state
127 |
128 |
129 | class QuantizedActivation(QuantizedModule):
130 | def __init__(self, *args, **kwargs):
131 | super().__init__(*args, **kwargs)
132 | act_qparams = dict(n_bits=self.n_bits_act, scale_domain=self.scale_domain)
133 | self.activation_quantizer = QuantizationManager(
134 | qmethod=self.act_method,
135 | qparams=act_qparams,
136 | init=self.act_range_method,
137 | init_params=self.act_range_options,
138 | )
139 |
140 | def quantize_activations(self, x):
141 | if self._quant_a:
142 | return self.activation_quantizer(x)
143 | else:
144 | return x
145 |
146 | def forward(self, x):
147 | return self.quantize_activations(x)
148 |
149 |
150 | class FP32Acts(nn.Module):
151 | def forward(self, x):
152 | return x
153 |
154 | def reset_ranges(self):
155 | pass
156 |
--------------------------------------------------------------------------------
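A small sketch of the mode switching described in the `QuantizedModule` docstring, using `QuantizedActivation` as the simplest concrete case (assumes the repository root is on `PYTHONPATH` and the default quantizer settings): the module passes tensors through untouched until `quantized_acts()` is called.

```python
import torch

from quantization.base_quantized_classes import QuantizedActivation

q_act = QuantizedActivation(n_bits_act=8)
x = torch.randn(2, 8)
print(torch.equal(q_act(x), x))  # True: activation quantization is off by default

q_act.quantized_acts()           # switch this node into quantized (fake-quant) mode
print(q_act(x).shape)            # torch.Size([2, 8]), now 8-bit fake-quantized
```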
/quantization/base_quantized_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from torch import nn
5 |
6 | from quantization.base_quantized_classes import (
7 | QuantizedModule,
8 | _set_layer_learn_ranges,
9 | _set_layer_fix_ranges,
10 | _set_layer_estimate_ranges,
11 | _set_layer_estimate_ranges_train,
12 | )
13 |
14 |
15 | class QuantizedModel(nn.Module):
16 | """
17 |     Parent class for a quantized model. It provides convenience functions for putting the whole
18 |     model into quantized or full-precision mode and for switching the quantizer range states.
19 |     It adds no further functionality, so a quantized model is not required to use this class.
20 | """
21 | def quantized_weights(self):
22 | def _fn(layer):
23 | if isinstance(layer, QuantizedModule):
24 | layer.quantized_weights()
25 |
26 | self.apply(_fn)
27 |
28 | def full_precision_weights(self):
29 | def _fn(layer):
30 | if isinstance(layer, QuantizedModule):
31 | layer.full_precision_weights()
32 |
33 | self.apply(_fn)
34 |
35 | def quantized_acts(self):
36 | def _fn(layer):
37 | if isinstance(layer, QuantizedModule):
38 | layer.quantized_acts()
39 |
40 | self.apply(_fn)
41 |
42 | def full_precision_acts(self):
43 | def _fn(layer):
44 | if isinstance(layer, QuantizedModule):
45 | layer.full_precision_acts()
46 |
47 | self.apply(_fn)
48 |
49 | def quantized(self):
50 | def _fn(layer):
51 | if isinstance(layer, QuantizedModule):
52 | layer.quantized()
53 |
54 | self.apply(_fn)
55 |
56 | def full_precision(self):
57 | def _fn(layer):
58 | if isinstance(layer, QuantizedModule):
59 | layer.full_precision()
60 |
61 | self.apply(_fn)
62 |
63 | # Methods for switching quantizer quantization states
64 | def learn_ranges(self):
65 | self.apply(_set_layer_learn_ranges)
66 |
67 | def fix_ranges(self):
68 | self.apply(_set_layer_fix_ranges)
69 |
70 | def fix_act_ranges(self):
71 | def _fn(module):
72 | if isinstance(module, QuantizedModule) and hasattr(module, 'activation_quantizer'):
73 | _set_layer_fix_ranges(module.activation_quantizer)
74 |
75 | self.apply(_fn)
76 |
77 | def fix_weight_ranges(self):
78 | def _fn(module):
79 | if isinstance(module, QuantizedModule) and hasattr(module, 'weight_quantizer'):
80 | _set_layer_fix_ranges(module.weight_quantizer)
81 |
82 | self.apply(_fn)
83 |
84 | def estimate_ranges(self):
85 | self.apply(_set_layer_estimate_ranges)
86 |
87 | def estimate_act_ranges(self):
88 | def _fn(module):
89 | if isinstance(module, QuantizedModule) and hasattr(module, 'activation_quantizer'):
90 | _set_layer_estimate_ranges(module.activation_quantizer)
91 |
92 | self.apply(_fn)
93 |
94 | def estimate_ranges_train(self):
95 | self.apply(_set_layer_estimate_ranges_train)
96 |
97 | def reset_act_ranges(self):
98 | def _fn(module):
99 | if isinstance(module, QuantizedModule) and hasattr(module, 'activation_quantizer'):
100 | module.activation_quantizer.reset_ranges()
101 |
102 | self.apply(_fn)
103 |
104 | def set_quant_state(self, weight_quant, act_quant):
105 | if act_quant:
106 | self.quantized_acts()
107 | else:
108 | self.full_precision_acts()
109 |
110 | if weight_quant:
111 | self.quantized_weights()
112 | else:
113 | self.full_precision_weights()
114 |
--------------------------------------------------------------------------------
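Putting the helpers above together, a typical post-training calibration flow might look like the following hypothetical sketch (assumes the repository root is on `PYTHONPATH`; `ToyQuantModel`, the random calibration data, and the chosen sizes are placeholders rather than repository objects, and the default range estimators are assumed to initialize themselves on the first forward pass, as `hijacker.py` suggests).

```python
import torch
from torch import nn

from quantization.autoquant_utils import quantize_model
from quantization.base_quantized_model import QuantizedModel


class ToyQuantModel(QuantizedModel):
    """Hypothetical wrapper, only to illustrate the state-switching helpers."""
    def __init__(self):
        super().__init__()
        self.body = quantize_model(nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)))

    def forward(self, x):
        return self.body(x)


model = ToyQuantModel().eval()
model.quantized()                          # quantize weights and activations
with torch.no_grad():
    for _ in range(10):                    # tiny "calibration" loop on random data
        model(torch.randn(4, 8))
model.fix_act_ranges()                     # freeze the estimated activation ranges
with torch.no_grad():
    print(model(torch.randn(4, 8)).shape)  # torch.Size([4, 2])
```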
/quantization/hijacker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import copy
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from quantization.base_quantized_classes import QuantizedModule
10 | from quantization.quantization_manager import QuantizationManager
11 | from quantization.range_estimators import RangeEstimators
12 | from quantization.utils import to_numpy
13 |
14 |
15 | activations_list = [nn.ReLU, nn.ReLU6, nn.Hardtanh, nn.Sigmoid, nn.Tanh, nn.PReLU, nn.GELU]
16 |
17 |
18 | class QuantizationHijacker(QuantizedModule):
19 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and
20 | dequantization on the weights and output distributions.
21 |
22 | Usage:
23 | To make a quantized nn.Linear layer:
24 | ```
25 | >>> class QuantLinear(QuantizationHijacker, nn.Linear):
26 | ... pass
27 | ```
28 |
29 | It is vital that QuantizationHijacker is the first parent class, and that the second parent
30 | class derives from nn.Module, otherwise it will not be reached by a super(., .) call.
31 |
32 | NB: this implementation (for now) assumes that there will always be some training involved,
33 | e.g. to estimate the activation ranges.
34 | """
35 | def __init__(self, *args, activation: nn.Module = None, **kwargs):
36 | super().__init__(*args, **kwargs)
37 | if activation:
38 | assert isinstance(activation, tuple(activations_list))
39 | self.activation_function = copy.deepcopy(activation) if activation else None
40 |
41 | weight_qparams = dict(n_bits=self.n_bits, scale_domain=self.scale_domain)
42 | act_qparams = dict(n_bits=self.n_bits_act, scale_domain=self.scale_domain)
43 |
44 | self.activation_quantizer = QuantizationManager(
45 | qmethod=self.act_method,
46 | init=self.act_range_method,
47 | per_channel=self.per_channel_acts,
48 | qparams=act_qparams,
49 | init_params=self.act_range_options,
50 | )
51 |
52 | if self.weight_range_method == RangeEstimators.current_minmax:
53 | weight_init_params = dict(percentile=self.percentile)
54 | else:
55 | weight_init_params = self.weight_range_options
56 | self.weight_quantizer = QuantizationManager(
57 | qmethod=self.method,
58 | init=self.weight_range_method,
59 | per_channel=self.per_channel_weights,
60 | qparams=weight_qparams,
61 | init_params=weight_init_params,
62 | )
63 | self.activation_save_target = None
64 | self.activation_save_name = None
65 |
66 | def forward(self, x, offsets=None):
67 | weight, bias = self.get_params()
68 | res = self.run_forward(x, weight, bias, offsets=offsets)
69 | res = self.quantize_activations(res)
70 | return res
71 |
72 | def get_params(self):
73 | if not self.training and self.cached_params:
74 | return self.cached_params
75 |
76 | weight, bias = self.get_weight_bias()
77 |
78 | if self._quant_w:
79 | weight = self.weight_quantizer(weight)
80 |
81 | if self._caching and not self.training and self.cached_params is None:
82 | self.cached_params = (
83 | torch.Tensor(to_numpy(weight)).to(weight.device),
84 | torch.Tensor(to_numpy(bias)).to(bias.device) if bias is not None else None,
85 | )
86 | return weight, bias
87 |
88 | def get_weight_bias(self):
89 | bias = None
90 | if hasattr(self, "bias"):
91 | bias = self.bias
92 | return self.weight, bias
93 |
94 | def run_forward(self, x, weight, bias, offsets=None):
95 | # Performs the actual (e.g., linear) operation of the layer
96 | raise NotImplementedError()
97 |
98 | def quantize_activations(self, activations):
99 | """Quantize a single activation tensor or all activations from a layer. I'm assuming that
100 | we should quantize all outputs for a layer with the same quantization scheme.
101 | """
102 | if self.activation_function is not None:
103 | activations = self.activation_function(activations)
104 |
105 | if self.activation_save_target is not None:
106 | self.activation_save_target[self.activation_save_name] = activations.data.cpu().numpy()
107 |
108 | if self._quant_a:
109 | activations = self.activation_quantizer(activations)
110 |
111 | if self.activation_save_target is not None:
112 | self.activation_save_target[self.activation_save_name + '_Q'] = (
113 | activations.data.cpu().numpy()
114 | )
115 |
116 | return activations
117 |
--------------------------------------------------------------------------------
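
To make the mixin pattern concrete, here is a minimal sketch of a quantized linear layer. It assumes that `run_forward` is the hook concrete layers implement (as the `NotImplementedError` above suggests) and that the quantization keyword arguments (`n_bits`, `act_method`, etc.) are consumed by `QuantizedModule` in `base_quantized_classes.py`, which is not shown in this excerpt.

```python
import torch.nn.functional as F
from torch import nn

from quantization.hijacker import QuantizationHijacker


class QuantLinear(QuantizationHijacker, nn.Linear):
    """Sketch of a linear layer with weight and output quantization."""

    def run_forward(self, x, weight, bias, offsets=None):
        # Called by the hijacked forward() with (possibly quantized) weights.
        return F.linear(x, weight, bias)
```
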
/quantization/quantization_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from enum import Enum
5 |
6 | from torch import nn
7 |
8 | from quantization.quantizers import QMethods, QuantizerNotInitializedError
9 | from quantization.range_estimators import RangeEstimators
10 |
11 |
12 | class Qstates(Enum):
13 | estimate_ranges = 0 # ranges are updated in eval and train mode
14 | fix_ranges = 1 # quantization ranges are fixed for train and eval
15 | learn_ranges = 2 # quantization params are nn.Parameters
16 | estimate_ranges_train = 3 # quantization ranges are updated during train and fixed for eval
17 |
18 |
19 | class QuantizationManager(nn.Module):
20 | """Implementation of Quantization and Quantization Range Estimation
21 |
22 | Parameters
23 | ----------
24 | n_bits: int
25 | Number of bits for the quantization.
26 | qmethod: QMethods member (Enum)
27 | The quantization scheme to use, e.g. symmetric_uniform or
28 | asymmetric_uniform.
29 | init: RangeEstimators member (Enum)
30 | Range estimation method used to initialize the quantization grid.
31 | per_channel: bool
32 | If True, use a separate quantization grid for each kernel/channel.
33 | x_min: float or PyTorch Tensor
34 | The minimum value which needs to be represented.
35 | x_max: float or PyTorch Tensor
36 | The maximum value which needs to be represented.
37 | """
38 | def __init__(self, qmethod=QMethods.symmetric_uniform, init=RangeEstimators.current_minmax,
39 | per_channel=False, axis=None, n_groups=None, x_min=None, x_max=None, qparams=None,
40 | init_params=None):
41 | super().__init__()
42 | self.state = Qstates.estimate_ranges
43 | self.qmethod = qmethod
44 | self.init = init
45 | self.per_channel = per_channel
46 | self.axis = axis
47 | self.n_groups = n_groups
48 | self.qparams = qparams if qparams else {}
49 | self.init_params = init_params if init_params else {}
50 | self.range_estimator = None
51 |
52 | # define quantizer
53 | self.quantizer = self.qmethod.cls(per_channel=per_channel, axis=axis, **qparams)
54 |
55 | # define range estimation method for quantizer initialisation
56 | if x_min is not None and x_max is not None:
57 | self.set_quant_range(x_min, x_max)
58 | self.state = Qstates.fix_ranges
59 | else:
60 | # set up the collector function to set the ranges
61 | self.range_estimator = self.init.cls(
62 | per_channel=self.per_channel,
63 | quantizer=self.quantizer,
64 | axis=self.axis,
65 | n_groups=self.n_groups,
66 | **self.init_params
67 | )
68 |
69 | @property
70 | def n_bits(self):
71 | return self.quantizer.n_bits
72 |
73 | def estimate_ranges(self):
74 | self.state = Qstates.estimate_ranges
75 |
76 | def fix_ranges(self):
77 | if self.quantizer.is_initialized:
78 | self.state = Qstates.fix_ranges
79 | else:
80 | raise QuantizerNotInitializedError()
81 |
82 | def learn_ranges(self):
83 | self.quantizer.make_range_trainable()
84 | self.state = Qstates.learn_ranges
85 |
86 | def estimate_ranges_train(self):
87 | self.state = Qstates.estimate_ranges_train
88 |
89 | def reset_ranges(self):
90 | self.range_estimator.reset()
91 | self.quantizer.reset()
92 | self.estimate_ranges()
93 |
94 | def forward(self, x):
95 | if self.range_estimator.per_group_range_estimation:
96 | self.range_estimator(x)
97 | return x
98 |
99 | if self.state == Qstates.estimate_ranges or (
100 | self.state == Qstates.estimate_ranges_train and self.training
101 | ):
102 | # Note this can be per tensor or per channel
103 | cur_xmin, cur_xmax = self.range_estimator(x)
104 | self.set_quant_range(cur_xmin, cur_xmax)
105 |
106 | return self.quantizer(x)
107 |
108 | def set_quant_range(self, x_min, x_max):
109 | self.quantizer.set_quant_range(x_min, x_max)
110 |
111 | def extra_repr(self):
112 | return 'state={}'.format(self.state.name)
113 |
--------------------------------------------------------------------------------
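
As a usage sketch (not taken from the repository), a `QuantizationManager` can be exercised directly on a tensor: in its default `estimate_ranges` state every forward call re-estimates the clipping range before fake-quantizing, and `fix_ranges()` freezes it.

```python
import torch

from quantization.quantization_manager import QuantizationManager
from quantization.quantizers import QMethods
from quantization.range_estimators import RangeEstimators

qm = QuantizationManager(
    qmethod=QMethods.symmetric_uniform,
    init=RangeEstimators.running_minmax,
    qparams=dict(n_bits=8),
)

x = torch.randn(4, 16)
x_q = qm(x)      # estimates the range from x, then fake-quantizes it
qm.fix_ranges()  # freeze the estimated range for subsequent calls
x_q2 = qm(x)
```
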
/quantization/quantizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from collections import namedtuple
5 | from enum import Enum
6 |
7 | import torch
8 | from torch.autograd import Function
9 | from torch import nn
10 |
11 |
12 | class RoundStraightThrough(Function):
13 | @staticmethod
14 | def forward(ctx, x):
15 | return torch.round(x)
16 |
17 | @staticmethod
18 | def backward(ctx, output_grad):
19 | return output_grad
20 |
21 |
22 | class FloorStraightThrough(Function):
23 | @staticmethod
24 | def forward(ctx, x):
25 | return torch.floor(x)
26 |
27 | @staticmethod
28 | def backward(ctx, output_grad):
29 | return output_grad
30 |
31 |
32 | round_ste_func = RoundStraightThrough.apply
33 | floor_ste_func = FloorStraightThrough.apply
34 |
35 |
36 | class QuantizerBase(nn.Module):
37 | def __init__(self, n_bits, per_channel=False, axis=None, *args, **kwargs):
38 | super().__init__(*args, **kwargs)
39 | self.n_bits = n_bits
40 | self.per_channel = per_channel
41 | self.axis = axis
42 |
43 | @property
44 | def is_initialized(self):
45 | raise NotImplementedError()
46 |
47 | @property
48 | def x_max(self):
49 | raise NotImplementedError()
50 |
51 | @property
52 | def symmetric(self):
53 | raise NotImplementedError()
54 |
55 | @property
56 | def x_min(self):
57 | raise NotImplementedError()
58 |
59 | def forward(self, x_float):
60 | raise NotImplementedError()
61 |
62 | def _adjust_params_per_axis(self, x):
63 | raise NotImplementedError()
64 |
65 | def _adjust_params_per_channel(self, x):
66 | raise NotImplementedError()
67 |
68 | def set_quant_range(self, x_min, x_max):
69 | raise NotImplementedError()
70 |
71 | def extra_repr(self):
72 | return (
73 | f'n_bits={self.n_bits}, per_channel={self.per_channel}, axis={self.axis}, '
74 | f'is_initalized={self.is_initialized}'
75 | )
76 |
77 | def reset(self):
78 | self._delta = None
79 |
80 |
81 | class AsymmetricUniformQuantizer(QuantizerBase):
82 | """
83 | PyTorch Module that implements Asymmetric Uniform Quantization using STE.
84 | Quantizes its argument in the forward pass, passes the gradient 'straight
85 | through' on the backward pass, ignoring the quantization that occurred.
86 |
87 | Parameters
88 | ----------
89 | n_bits: int
90 | Number of bits for quantization.
91 | scale_domain: str ('log', 'linear'), default='linear'
92 | Domain of scale factor
93 | per_channel: bool
94 | If True: allows for per-channel quantization
95 | """
96 | def __init__(self, n_bits, scale_domain='linear', per_channel=False, axis=None, eps=1e-8):
97 |
98 | super().__init__(n_bits, per_channel)
99 |
100 | assert scale_domain in ('linear', 'log')
101 | self.register_buffer('_delta', None)
102 | self.register_buffer('_zero_float', None)
103 | self.n_bits = n_bits
104 | self.scale_domain = scale_domain
105 | self.per_channel = per_channel
106 | self.axis = axis
107 | self.eps = eps
108 |
109 | # A few useful properties
110 | @property
111 | def delta(self):
112 | if self._delta is not None:
113 | return self._delta
114 | else:
115 | raise QuantizerNotInitializedError()
116 |
117 | @property
118 | def zero_float(self):
119 | if self._zero_float is not None:
120 | return self._zero_float
121 | else:
122 | raise QuantizerNotInitializedError()
123 |
124 | @property
125 | def is_initialized(self):
126 | return self._delta is not None
127 |
128 | @property
129 | def symmetric(self):
130 | return False
131 |
132 | @property
133 | def int_min(self):
134 | # integer grid minimum
135 | return 0.0
136 |
137 | @property
138 | def int_max(self):
139 | # integer grid maximum
140 | return 2.0 ** self.n_bits - 1
141 |
142 | @property
143 | def scale(self):
144 | if self.scale_domain == 'linear':
145 | return torch.clamp(self.delta, min=self.eps)
146 | elif self.scale_domain == 'log':
147 | return torch.exp(self.delta)
148 |
149 | @property
150 | def zero_point(self):
151 | zero_point = round_ste_func(self.zero_float)
152 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max)
153 | return zero_point
154 |
155 | @property
156 | def x_max(self):
157 | return self.scale * (self.int_max - self.zero_point)
158 |
159 | @property
160 | def x_min(self):
161 | return self.scale * (self.int_min - self.zero_point)
162 |
163 | def _clamp(self, x_int):
164 | with torch.no_grad():
165 | clamped_above = (x_int > self.int_max).float().sum()
166 | clamped_below = (x_int < self.int_min).float().sum()
167 | self._clamped = (clamped_above + clamped_below) / x_int.numel()
168 | x_clamped = torch.clamp(x_int, self.int_min, self.int_max)
169 |
170 | return x_clamped
171 |
172 | def to_integer_forward(self, x_float):
173 | """
174 | Qunatized input to its integer represantion
175 | Parameters
176 | ----------
177 | x_float: PyTorch Float Tensor
178 | Full-precision Tensor
179 |
180 | Returns
181 | -------
182 | x_int: PyTorch Float Tensor of integers
183 | """
184 | x_int = round_ste_func(x_float / self.scale) + self.zero_point
185 | x_int = torch.clamp(x_int, self.int_min, self.int_max)
186 |
187 | return x_int
188 |
189 | def forward(self, x_float):
190 | """
191 | Quantizes the input (rounds to the integer grid, then scales back to the original domain).
192 | Parameters
193 | ----------
194 | x_float: PyTorch Float Tensor
195 | Full-precision Tensor
196 |
197 | Returns
198 | -------
199 | x_quant: PyTorch Float Tensor
200 | Quantized-Dequantized Tensor
201 | """
202 | if self.axis is not None:
203 | self._adjust_params_per_axis(x_float)
204 |
205 | if self.per_channel:
206 | self._adjust_params_per_channel(x_float)
207 |
208 | x_int = self.to_integer_forward(x_float)
209 | x_quant = self.scale * (x_int - self.zero_point)
210 |
211 | return x_quant
212 |
213 | def _adjust_params_per_axis(self, x_float):
214 | r = len(x_float.size())
215 | new_shape = [1] * self.axis + [-1] + [1] * (r - self.axis - 1)
216 | self._delta = self._delta.view(new_shape)
217 | self._zero_float = self._zero_float.view(new_shape)
218 |
219 | def _adjust_params_per_channel(self, x):
220 | """
221 | Adjusts the quantization parameter tensors (delta, zero_float)
222 | to the input tensor shape if they don't match
223 |
224 | Parameters
225 | ----------
226 | x: input tensor
227 | """
228 | if x.ndim != self.delta.ndim:
229 | new_shape = [-1] + [1] * (len(x.shape) - 1)
230 | self._delta = self.delta.view(new_shape)
231 | if self._zero_float is not None:
232 | self._zero_float = self._zero_float.view(new_shape)
233 |
234 | def _tensorize_min_max(self, x_min, x_max):
235 | """
236 | Converts provided min max range into tensors
237 | Parameters
238 | ----------
239 | x_min: float or PyTorch 1D tensor
240 | x_max: float or PyTorch 1D tensor
241 |
242 | Returns
243 | -------
244 | x_min: PyTorch Tensor 0 or 1-D
245 | x_max: PyTorch Tensor 0 or 1-D
246 | """
247 | # Ensure a torch tensor
248 | if not torch.is_tensor(x_min):
249 | x_min = torch.tensor(x_min).float()
250 | x_max = torch.tensor(x_max).float()
251 |
252 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel and self.axis is None:
253 | raise ValueError(
254 | 'x_min and x_max must be a float or a single-element Tensor'
255 | ' for per-tensor quantization (per_channel=False)'
256 | )
257 | # Ensure the quantization range always covers zero and avoid division by zero
258 | x_min = torch.min(x_min, torch.zeros_like(x_min))
259 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps)
260 |
261 | return x_min, x_max
262 |
263 | def set_quant_range(self, x_min, x_max):
264 | """
265 | Instantiates the quantization parameters based on the provided
266 | min and max range
267 |
268 | Parameters
269 | ----------
270 | x_min: tensor or float
271 | Quantization range minimum limit
272 | x_max: tensor or float
273 | Quantization range maximum limit
274 | """
275 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
276 | self._delta = (x_max - x_min) / self.int_max
277 | self._zero_float = (-x_min / self.delta).detach()
278 |
279 | if self.scale_domain == 'log':
280 | self._delta = torch.log(self.delta)
281 |
282 | self._delta = self._delta.detach()
283 |
284 | def make_range_trainable(self):
285 | # Converts trainable parameters to nn.Parameters
286 | if self.delta not in self.parameters():
287 | self._delta = torch.nn.Parameter(self._delta)
288 | self._zero_float = torch.nn.Parameter(self._zero_float)
289 |
290 |
291 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer):
292 | """
293 | PyTorch Module that implements Symmetric Uniform Quantization using STE.
294 | Quantizes its argument in the forward pass, passes the gradient 'straight
295 | through' on the backward pass, ignoring the quantization that occurred.
296 |
297 | Parameters
298 | ----------
299 | n_bits: int
300 | Number of bits for quantization.
301 | scale_domain: str ('log', 'linear'), default='linear'
302 | Domain of scale factor
303 | per_channel: bool
304 | If True: allows for per-channel quantization
305 | """
306 | def __init__(self, *args, **kwargs):
307 | super().__init__(*args, **kwargs)
308 | self.register_buffer('_signed', None)
309 |
310 | @property
311 | def signed(self):
312 | if self._signed is not None:
313 | return self._signed.item()
314 | else:
315 | raise QuantizerNotInitializedError()
316 |
317 | @property
318 | def symmetric(self):
319 | return True
320 |
321 | @property
322 | def int_min(self):
323 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0
324 |
325 | @property
326 | def int_max(self):
327 | pos_n_bits = self.n_bits - self.signed
328 | return 2.0 ** pos_n_bits - 1
329 |
330 | @property
331 | def zero_point(self):
332 | return 0.0
333 |
334 | def set_quant_range(self, x_min, x_max):
335 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
336 | self._signed = x_min.min() < 0
337 |
338 | x_absmax = torch.max(x_min.abs(), x_max)
339 | self._delta = x_absmax / self.int_max
340 |
341 | if self.scale_domain == 'log':
342 | self._delta = torch.log(self._delta)
343 |
344 | self._delta = self._delta.detach()
345 |
346 | def make_range_trainable(self):
347 | # Converts trainable parameters to nn.Parameters
348 | if self.delta not in self.parameters():
349 | self._delta = torch.nn.Parameter(self._delta)
350 |
351 |
352 | QMethodMap = namedtuple('QMethodMap', ['value', 'cls'])
353 |
354 |
355 | class QMethods(Enum):
356 | symmetric_uniform = QMethodMap(0, SymmetricUniformQuantizer)
357 | asymmetric_uniform = QMethodMap(1, AsymmetricUniformQuantizer)
358 |
359 | @property
360 | def cls(self):
361 | return self.value.cls
362 |
363 | @classmethod
364 | def list(cls):
365 | return [m.name for m in cls]
366 |
367 |
368 | class QuantizerNotInitializedError(Exception):
369 | """Raised when a quantizer has not initialized"""
370 |
371 | def __init__(self):
372 | super(QuantizerNotInitializedError, self).__init__('Quantizer has not been initialized yet')
373 |
--------------------------------------------------------------------------------
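
A short sketch of using a quantizer on its own, based only on the classes defined above; the input tensor is arbitrary.

```python
import torch

from quantization.quantizers import AsymmetricUniformQuantizer

q = AsymmetricUniformQuantizer(n_bits=8)
x = torch.randn(2, 8)

# Initialize the grid from an explicit range, then fake-quantize.
q.set_quant_range(x.min(), x.max())
x_q = q(x)

# Optionally turn the range parameters into nn.Parameters for training.
q.make_range_trainable()
```
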
/quantization/range_estimators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import copy
5 | from collections import namedtuple
6 | from enum import Enum
7 |
8 | import numpy as np
9 | import torch
10 | from scipy.optimize import minimize_scalar
11 | from torch import nn
12 | from torch.nn import functional as F
13 |
14 | from quantization.utils import to_numpy
15 |
16 |
17 | class RangeEstimatorBase(nn.Module):
18 | def __init__(self, per_channel=False, quantizer=None, axis=None, n_groups=None, *args,
19 | **kwargs):
20 | super().__init__(*args, **kwargs)
21 | self.register_buffer('current_xmin', None)
22 | self.register_buffer('current_xmax', None)
23 | self.per_channel = per_channel
24 | self.quantizer = quantizer
25 | self.axis = axis
26 | self.n_groups = n_groups
27 |
28 | self.per_group_range_estimation = False
29 | self.ranges = None
30 |
31 | def forward(self, x):
32 | """
33 | Accepts an input tensor, updates the current estimates of x_min and x_max
34 | and returns them.
35 | Parameters
36 | ----------
37 | x: Input tensor
38 |
39 | Returns
40 | -------
41 | self.current_xmin: tensor
42 | self.current_xmax: tensor
43 | """
44 | raise NotImplementedError()
45 |
46 | def reset(self):
47 | """
48 | Reset the range estimator.
49 | """
50 | self.current_xmin = None
51 | self.current_xmax = None
52 |
53 | def __repr__(self):
54 | # We overwrite this from nn.Module as we do not want to have submodules such as
55 | # self.quantizer in the repr. Otherwise it behaves as expected for an nn.Module.
56 | lines = self.extra_repr().split('\n')
57 | extra_str = lines[0] if len(lines) == 1 else '\n ' + '\n '.join(lines) + '\n'
58 |
59 | return self._get_name() + '(' + extra_str + ')'
60 |
61 |
62 | class CurrentMinMaxEstimator(RangeEstimatorBase):
63 | def __init__(self, percentile=None, *args, **kwargs):
64 | self.percentile = percentile
65 | super().__init__(*args, **kwargs)
66 |
67 | def forward(self, x):
68 | if self.per_group_range_estimation:
69 | assert self.axis != 0
70 | x = x.transpose(0, self.axis).contiguous()
71 | x = x.view(x.size(0), -1)
72 |
73 | ranges = x.max(-1)[0].detach() - x.min(-1)[0].detach()
74 |
75 | if self.ranges is None:
76 | self.ranges = ranges
77 | else:
78 | momentum = 0.1
79 | self.ranges = momentum * ranges + (1 - momentum) * self.ranges
80 | return
81 |
82 | if self.axis is not None:
83 | if self.axis != 0:
84 | x = x.transpose(0, self.axis).contiguous()
85 | x = x.view(x.size(0), -1)
86 |
87 | if self.n_groups is not None:
88 | ng = self.n_groups
89 | assert ng > 0 and x.size(0) % ng == 0
90 | gs = x.size(0) // ng
91 |
92 | # permute
93 | if self.ranges is not None:
94 | i = torch.argsort(self.ranges)
95 | I = torch.eye(len(i), device=self.ranges.device)
96 | P = I[i]
97 | x = P.mm(x)
98 |
99 | x = x.view(ng, -1)
100 | m = x.min(-1)[0].detach()
101 | M = x.max(-1)[0].detach()
102 |
103 | m = m.repeat_interleave(gs)
104 | M = M.repeat_interleave(gs)
105 |
106 | # permute back
107 | if self.ranges is not None:
108 | m = P.T.mv(m)
109 | M = P.T.mv(M)
110 |
111 | self.current_xmin = m
112 | self.current_xmax = M
113 |
114 | else:
115 | self.current_xmin = x.min(-1)[0].detach()
116 | self.current_xmax = x.max(-1)[0].detach()
117 |
118 | elif self.per_channel:
119 | # Along 1st dim
120 | x_flattened = x.view(x.shape[0], -1)
121 | if self.percentile:
122 | data_np = to_numpy(x_flattened)
123 | x_min, x_max = np.percentile(
124 | data_np, (self.percentile, 100 - self.percentile), axis=-1
125 | )
126 | self.current_xmin = torch.Tensor(x_min)
127 | self.current_xmax = torch.Tensor(x_max)
128 | else:
129 | self.current_xmin = x_flattened.min(-1)[0].detach()
130 | self.current_xmax = x_flattened.max(-1)[0].detach()
131 |
132 | else:
133 | if self.percentile:
134 | device = x.device
135 | data_np = to_numpy(x)
136 | x_min, x_max = np.percentile(data_np, (self.percentile, 100))
137 | x_min = np.atleast_1d(x_min)
138 | x_max = np.atleast_1d(x_max)
139 | self.current_xmin = torch.Tensor(x_min).to(device).detach()
140 | self.current_xmax = torch.Tensor(x_max).to(device).detach()
141 | else:
142 | self.current_xmin = torch.min(x).detach()
143 | self.current_xmax = torch.max(x).detach()
144 |
145 | return self.current_xmin, self.current_xmax
146 |
147 |
148 | class AllMinMaxEstimator(RangeEstimatorBase):
149 | def __init__(self, *args, **kwargs):
150 | super().__init__(*args, **kwargs)
151 |
152 | def forward(self, x):
153 | if self.per_channel:
154 | # Along 1st dim
155 | x_flattened = x.view(x.shape[0], -1)
156 | x_min = x_flattened.min(-1)[0].detach()
157 | x_max = x_flattened.max(-1)[0].detach()
158 | else:
159 | x_min = torch.min(x).detach()
160 | x_max = torch.max(x).detach()
161 |
162 | if self.current_xmin is None:
163 | self.current_xmin = x_min
164 | self.current_xmax = x_max
165 | else:
166 | self.current_xmin = torch.min(self.current_xmin, x_min)
167 | self.current_xmax = torch.max(self.current_xmax, x_max)
168 |
169 | return self.current_xmin, self.current_xmax
170 |
171 |
172 | class RunningMinMaxEstimator(RangeEstimatorBase):
173 | def __init__(self, momentum=0.9, *args, **kwargs):
174 | self.momentum = momentum
175 | super().__init__(*args, **kwargs)
176 |
177 | def forward(self, x):
178 | if self.axis is not None:
179 | if self.axis != 0:
180 | x = x.transpose(0, self.axis).contiguous()
181 | x = x.view(x.size(0), -1)
182 |
183 | if self.n_groups is not None:
184 | ng = self.n_groups
185 | assert ng > 0 and x.size(0) % ng == 0
186 | gs = x.size(0) // ng
187 |
188 | x = x.view(ng, -1)
189 | m = x.min(-1)[0].detach()
190 | M = x.max(-1)[0].detach()
191 |
192 | x_min = m.repeat_interleave(gs)
193 | x_max = M.repeat_interleave(gs)
194 |
195 | else:
196 | x_min = x.min(-1)[0].detach()
197 | x_max = x.max(-1)[0].detach()
198 |
199 | elif self.per_channel:
200 | # Along 1st dim
201 | x_flattened = x.view(x.shape[0], -1)
202 | x_min = x_flattened.min(-1)[0].detach()
203 | x_max = x_flattened.max(-1)[0].detach()
204 |
205 | else:
206 | x_min = torch.min(x).detach()
207 | x_max = torch.max(x).detach()
208 |
209 | if self.current_xmin is None:
210 | self.current_xmin = x_min
211 | self.current_xmax = x_max
212 | else:
213 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin
214 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax
215 |
216 | return self.current_xmin, self.current_xmax
217 |
218 |
219 | class OptMethod(Enum):
220 | grid = 1
221 | golden_section = 2
222 |
223 | @classmethod
224 | def list(cls):
225 | return [m.name for m in cls]
226 |
227 |
228 | class MSE_Estimator(RangeEstimatorBase):
229 | def __init__(self, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, *args,
230 | **kwargs):
231 | super().__init__(*args, **kwargs)
232 | assert opt_method in OptMethod
233 |
234 | self.opt_method = opt_method
235 | self.num_candidates = num_candidates
236 | self.loss_array = None
237 | self.max_pos_thr = None
238 | self.max_neg_thr = None
239 | self.max_search_range = None
240 | self.one_sided_dist = None
241 | self.range_margin = range_margin
242 | if self.quantizer is None:
243 | raise NotImplementedError(
244 | 'A Quantizer must be given as an argument to the MSE Range Estimator'
245 | )
246 | self.max_int_skew = (2 ** self.quantizer.n_bits) // 4 # for asymmetric quantization
247 |
248 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False):
249 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr)
250 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1)
251 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel
252 | # grid search
253 | if per_channel_loss:
254 | return to_numpy(temp_sum)
255 | else:
256 | return to_numpy(torch.sum(temp_sum))
257 |
258 | @property
259 | def step_size(self):
260 | if self.one_sided_dist is None:
261 | raise NoDataPassedError()
262 |
263 | return self.max_search_range / self.num_candidates
264 |
265 | @property
266 | def optimization_method(self):
267 | if self.one_sided_dist is None:
268 | raise NoDataPassedError()
269 |
270 | if self.opt_method == OptMethod.grid:
271 | # Grid search method
272 | if self.one_sided_dist or self.quantizer.symmetric:
273 | # 1-D grid search
274 | return self._perform_1D_search
275 | else:
276 | # 2-D grid_search
277 | return self._perform_2D_search
278 | elif self.opt_method == OptMethod.golden_section:
279 | # Golden section method
280 | if self.one_sided_dist or self.quantizer.symmetric:
281 | return self._golden_section_symmetric
282 | else:
283 | return self._golden_section_asymmetric
284 | else:
285 | raise NotImplementedError('Optimization Method not Implemented')
286 |
287 | def quantize(self, x_float, x_min=None, x_max=None):
288 | temp_q = copy.deepcopy(self.quantizer)
289 | # In the current implementation no optimization procedure requires temp quantizer for
290 | # loss_fx to be per-channel
291 | temp_q.per_channel = False
292 | if x_min or x_max:
293 | temp_q.set_quant_range(x_min, x_max)
294 | return temp_q(x_float)
295 |
296 | def golden_sym_loss(self, range, data):
297 | """
298 | Loss function passed to the golden section optimizer from scipy in case of symmetric
299 | quantization
300 | """
301 | neg_thr = 0 if self.one_sided_dist else -range
302 | pos_thr = range
303 | return self.loss_fx(data, neg_thr, pos_thr)
304 |
305 | def golden_asym_shift_loss(self, shift, range, data):
306 | """
307 | Inner Loss function (shift) passed to the golden section optimizer from scipy
308 | in case of asymmetric quantization
309 | """
310 | pos_thr = range + shift
311 | neg_thr = -range + shift
312 | return self.loss_fx(data, neg_thr, pos_thr)
313 |
314 | def golden_asym_range_loss(self, range, data):
315 | """
316 | Outer Loss function (range) passed to the golden section optimizer from scipy in case of
317 | asymmetric quantization
318 | """
319 | temp_delta = 2 * range / (2 ** self.quantizer.n_bits - 1)
320 | max_shift = temp_delta * self.max_int_skew
321 | result = minimize_scalar(
322 | self.golden_asym_shift_loss,
323 | args=(range, data),
324 | bounds=(-max_shift, max_shift),
325 | method='Bounded',
326 | )
327 | return result.fun
328 |
329 | def _define_search_range(self, data):
330 | self.channel_groups = len(data) if self.per_channel else 1
331 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device)
332 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device)
333 |
334 | if self.one_sided_dist or self.quantizer.symmetric:
335 | # 1D search space
336 | self.loss_array = np.zeros(
337 | (self.channel_groups, self.num_candidates + 1)
338 | ) # 1D search space
339 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish
340 | # Defining the search range for clipping thresholds
341 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin
342 | self.max_neg_thr = -self.max_pos_thr
343 | self.max_search_range = self.max_pos_thr
344 | else:
345 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth
346 | # index represents whether the skew is positive (0) or negative (1))
347 | self.loss_array = np.zeros(
348 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2]
349 | ) # 2D search space
350 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish
351 | # Define the search range for clipping thresholds in asymmetric case
352 | self.max_pos_thr = float(data.max()) + self.range_margin
353 | self.max_neg_thr = float(data.min()) - self.range_margin
354 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr))
355 |
356 | def _perform_1D_search(self, data):
357 | """
358 | Grid search through all candidate quantizers in 1D to find the best
359 | The loss is accumulated over all batches without any momentum
360 | :param data: input tensor
361 | """
362 | for cand_index in range(1, self.num_candidates + 1):
363 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index
364 | pos_thr = self.step_size * cand_index
365 |
366 | self.loss_array[:, cand_index] += self.loss_fx(
367 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
368 | )
369 | # find the best clipping thresholds
370 | min_cand = self.loss_array.argmin(axis=1)
371 | xmin = (
372 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand
373 | ).astype(np.single)
374 | xmax = (self.step_size * min_cand).astype(np.single)
375 | self.current_xmax = torch.tensor(xmax).to(device=data.device)
376 | self.current_xmin = torch.tensor(xmin).to(device=data.device)
377 |
378 | def _perform_2D_search(self, data):
379 | """
380 | Grid search through all candidate quantizers in 2D to find the best
381 | The loss is accumulated over all batches without any momentum
382 | Parameters
383 | ----------
384 | data: PyTorch Tensor
385 | Returns
386 | -------
387 |
388 | """
389 | for cand_index in range(1, self.num_candidates + 1):
390 | # defining the symmetric quantization range
391 | temp_start = -self.step_size * cand_index
392 | temp_finish = self.step_size * cand_index
393 | temp_delta = float(temp_finish - temp_start) / (2 ** self.quantizer.n_bits - 1)
394 | for shift in range(self.max_int_skew):
395 | for reverse in range(2):
396 | # introducing asymmetry in the quantization range
397 | skew = ((-1) ** reverse) * shift * temp_delta
398 | neg_thr = max(temp_start + skew, self.max_neg_thr)
399 | pos_thr = min(temp_finish + skew, self.max_pos_thr)
400 |
401 | self.loss_array[:, cand_index, shift, reverse] += self.loss_fx(
402 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
403 | )
404 |
405 | for channel_index in range(self.channel_groups):
406 | min_cand, min_shift, min_reverse = np.unravel_index(
407 | np.argmin(self.loss_array[channel_index], axis=None),
408 | self.loss_array[channel_index].shape,
409 | )
410 | min_interval_start = -self.step_size * min_cand
411 | min_interval_finish = self.step_size * min_cand
412 | min_delta = float(min_interval_finish - min_interval_start) / (
413 | 2 ** self.quantizer.n_bits - 1
414 | )
415 | min_skew = ((-1) ** min_reverse) * min_shift * min_delta
416 | xmin = max(min_interval_start + min_skew, self.max_neg_thr)
417 | xmax = min(min_interval_finish + min_skew, self.max_pos_thr)
418 |
419 | self.current_xmin[channel_index] = torch.tensor(xmin).to(device=data.device)
420 | self.current_xmax[channel_index] = torch.tensor(xmax).to(device=data.device)
421 |
422 | def _golden_section_symmetric(self, data):
423 | for channel_index in range(self.channel_groups):
424 | if channel_index == 0 and not self.per_channel:
425 | data_segment = data
426 | else:
427 | data_segment = data[channel_index]
428 |
429 | self.result = minimize_scalar(
430 | self.golden_sym_loss,
431 | args=data_segment,
432 | bounds=(0.01 * self.max_search_range, self.max_search_range),
433 | method='Bounded',
434 | )
435 | self.current_xmax[channel_index] = torch.tensor(self.result.x).to(device=data.device)
436 | self.current_xmin[channel_index] = (
437 | torch.tensor(0.0).to(device=data.device)
438 | if self.one_sided_dist
439 | else -self.current_xmax[channel_index]
440 | )
441 |
442 | def _golden_section_asymmetric(self, data):
443 | for channel_index in range(self.channel_groups):
444 | if channel_index == 0 and not self.per_channel:
445 | data_segment = data
446 | else:
447 | data_segment = data[channel_index]
448 |
449 | self.result = minimize_scalar(
450 | self.golden_asym_range_loss,
451 | args=data_segment,
452 | bounds=(0.01 * self.max_search_range, self.max_search_range),
453 | method='Bounded',
454 | )
455 | self.final_range = self.result.x
456 | temp_delta = 2 * self.final_range / (2 ** self.quantizer.n_bits - 1)
457 | max_shift = temp_delta * self.max_int_skew
458 | self.subresult = minimize_scalar(
459 | self.golden_asym_shift_loss,
460 | args=(self.final_range, data_segment),
461 | bounds=(-max_shift, max_shift),
462 | method='Bounded',
463 | )
464 | self.final_shift = self.subresult.x
465 | self.current_xmax[channel_index] = torch.tensor(self.final_range + self.final_shift).to(
466 | device=data.device
467 | )
468 | self.current_xmin[channel_index] = torch.tensor(
469 | -self.final_range + self.final_shift
470 | ).to(device=data.device)
471 |
472 | def forward(self, data):
473 | if self.loss_array is None:
474 | # Initialize search range on first batch, and accumulate losses with subsequent calls
475 |
476 | # Decide whether input distribution is one-sided
477 | if self.one_sided_dist is None:
478 | self.one_sided_dist = bool((data.min() >= 0).item())
479 |
480 | # Define search
481 | self._define_search_range(data)
482 |
483 | # Perform Search/Optimization for Quantization Ranges
484 | self.optimization_method(data)
485 |
486 | return self.current_xmin, self.current_xmax
487 |
488 | def reset(self):
489 | super().reset()
490 | self.loss_array = None
491 |
492 |
493 | class CrossEntropyEstimator(MSE_Estimator):
494 | def __init__(self, *args, **kwargs):
495 | super().__init__(*args, **kwargs)
496 |
497 | # per_channel_loss argument is here only to be consistent in definition with other loss fxs
498 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False):
499 | quantized_data = self.quantize(data, neg_thr, pos_thr)
500 | log_quantized_probs = F.log_softmax(quantized_data, dim=1)
501 | unquantized_probs = F.softmax(data, dim=1)
502 | return to_numpy(torch.sum(-unquantized_probs * log_quantized_probs))
503 |
504 |
505 | class NoDataPassedError(Exception):
506 | """Raised data has been passed inot the Range Estimator."""
507 |
508 | def __init__(self):
509 | super().__init__('Data must be passed through the range estimator to initialize it')
510 |
511 |
512 | RangeEstimatorMap = namedtuple('RangeEstimatorMap', ['value', 'cls'])
513 |
514 |
515 | class RangeEstimators(Enum):
516 | current_minmax = RangeEstimatorMap(0, CurrentMinMaxEstimator)
517 | allminmax = RangeEstimatorMap(1, AllMinMaxEstimator)
518 | running_minmax = RangeEstimatorMap(2, RunningMinMaxEstimator)
519 | MSE = RangeEstimatorMap(3, MSE_Estimator)
520 | cross_entropy = RangeEstimatorMap(4, CrossEntropyEstimator)
521 |
522 | @property
523 | def cls(self):
524 | return self.value.cls
525 |
526 | @classmethod
527 | def list(cls):
528 | return [m.name for m in cls]
529 |
--------------------------------------------------------------------------------
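
A brief sketch of how an estimator is paired with a quantizer (in the pipeline this wiring is done by `QuantizationManager`; here the MSE grid search is called directly on a random tensor for illustration).

```python
import torch

from quantization.quantizers import SymmetricUniformQuantizer
from quantization.range_estimators import MSE_Estimator

x = torch.randn(16, 64)

# MSE-based range search needs access to the quantizer it calibrates.
quantizer = SymmetricUniformQuantizer(n_bits=8)
estimator = MSE_Estimator(quantizer=quantizer, num_candidates=50)

x_min, x_max = estimator(x)             # grid search over clipping thresholds
quantizer.set_quant_range(x_min, x_max)
x_q = quantizer(x)
```
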
/quantization/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def to_numpy(tensor):
9 | """
10 | Helper function that turns the given tensor into a numpy array.
11 |
12 | Parameters
13 | ----------
14 | tensor : torch.Tensor
15 |
16 | Returns
17 | -------
18 | tensor : float or np.array
19 | """
20 | if isinstance(tensor, np.ndarray):
21 | return tensor
22 | if hasattr(tensor, 'is_cuda'):
23 | if tensor.is_cuda:
24 | return tensor.cpu().detach().numpy()
25 | if hasattr(tensor, 'detach'):
26 | return tensor.detach().numpy()
27 | if hasattr(tensor, 'numpy'):
28 | return tensor.numpy()
29 |
30 | return np.array(tensor)
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | click~=7.0
2 | datasets~=1.1.3
3 | numpy~=1.17.0
4 | scikit-learn~=0.19.1
5 | scipy~=1.3.1
6 | tensorboardX~=1.7
7 | torchsummary
8 | tqdm
9 | transformers~=4.1.0
10 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from utils.adaround_utils import apply_adaround_to_model
5 | from utils.glue_tasks import GLUE_Task, TASK_TO_FINAL_METRIC, load_task_data, make_compute_metric_fn
6 | from utils.hf_models import HF_Models, load_model_and_tokenizer
7 | from utils.per_embd_quant_utils import (
8 | hijack_act_quant,
9 | hijack_weight_quant,
10 | hijack_act_quant_modules,
11 | set_act_quant_axis_and_groups,
12 | )
13 | from utils.qat_utils import prepare_model_for_quantization
14 | from utils.quant_click_options import (
15 | quantization_options,
16 | activation_quantization_options,
17 | qat_options,
18 | adaround_options,
19 | make_qparams,
20 | split_dict,
21 | )
22 | from utils.tb_utils import _tb_advance_global_step, _tb_advance_token_counters, _tb_hist
23 | from utils.transformer_click_options import (
24 | glue_options,
25 | transformer_base_options,
26 | transformer_data_options,
27 | transformer_model_options,
28 | transformer_training_options,
29 | transformer_progress_options,
30 | transformer_quant_options,
31 | )
32 | from utils.utils import (
33 | seed_all,
34 | count_params,
35 | count_embedding_params,
36 | pass_data_for_range_estimation,
37 | DotDict,
38 | Stopwatch,
39 | StopForwardException,
40 | )
41 |
--------------------------------------------------------------------------------
/utils/adaround_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 |
6 | import torch
7 |
8 | from quantization.adaround import apply_adaround_to_layer
9 | from quantization.adaround.utils import AdaRoundActQuantMode
10 | from quantization.base_quantized_classes import QuantizedModule
11 | from utils.utils import pass_data_for_range_estimation, Stopwatch
12 |
13 |
14 | # setup logger
15 | logger = logging.getLogger('AdaRound')
16 | logger.setLevel(logging.INFO)
17 |
18 |
19 | def get_train_samples(data_loader, num_samples, return_labels=False, inp_idx=0, lbl_idx=1):
20 | X, y = [], []
21 | for data in data_loader:
22 | X.append(data[inp_idx])
23 | if return_labels:
24 | y.append(data[lbl_idx])
25 | if len(X) * data[inp_idx].size(0) >= num_samples:
26 | break
27 |
28 | X = torch.cat(X, dim=0)[:num_samples]
29 | if return_labels:
30 | y = torch.cat(y, dim=0)[:num_samples]
31 | return X, y
32 | return X
33 |
34 |
35 | def apply_adaround_to_model(config, model, data_loader, range_est_data_loader, batch_size,
36 | driver=None, get_samples_fn=get_train_samples, inp_idx=0):
37 | """
38 | Apply AdaRound to `model`.
39 |
40 | Parameters
41 | ----------
42 | config : DotDict
43 | DotDict with quantization parameters
44 | model : QuantizedModel
45 | model to apply AdaRound on
46 | data_loader : torch.utils.data.DataLoader
47 | Training or other data used for AdaRound optimization
48 | driver : SupervisedDriver
49 | Used for validation. This is only used for reporting accuracy, not for optimization.
50 | inp_idx : int, str
51 | batch index in the input data from the dataloader
52 | """
53 | train_data = get_samples_fn(data_loader, num_samples=config.adaround.num_samples)
54 |
55 | device = next(model.parameters()).device
56 | train_data = train_data.to(device)
57 |
58 | # check and prepare list of layers to optimize for
59 | all_layer_names = []
60 | for name, module in model.named_modules():
61 | if isinstance(module, QuantizedModule) and hasattr(module, 'weight'):
62 | all_layer_names.append(name)
63 |
64 | if 'all' in config.adaround.layers:
65 | adaround_layer_names = all_layer_names
66 | else:
67 | adaround_layer_names = []
68 | for name in config.adaround.layers:
69 | if name in all_layer_names:
70 | adaround_layer_names.append(name)
71 | else:
72 | logger.warning(f'skipping unknown layer {name}')
73 |
74 | if not len(adaround_layer_names):
75 | logger.warning('No layers to apply AdaRound for, exiting...')
76 | return
77 |
78 | # deal with activation quantization
79 | if config.adaround.act_quant_mode in (
80 | AdaRoundActQuantMode.no_act_quant,
81 | AdaRoundActQuantMode.post_adaround,
82 | ):
83 | config.quant.act_quant = False
84 | model.reset_act_ranges()
85 | model.full_precision_acts()
86 | else:
87 | raise NotImplementedError(f"act mode '{config.adaround.act_quant_mode}' is not implemented")
88 |
89 | # main loop
90 | s_all = Stopwatch()
91 | for name, module in model.named_modules():
92 | if name not in adaround_layer_names:
93 | continue
94 |
95 | logger.info(f'Started AdaRound for layer {name}')
96 |
97 | model.full_precision()
98 | module.quantized_weights()
99 |
100 | s_all.start()
101 | with Stopwatch() as s_layer:
102 | apply_adaround_to_layer(
103 | model,
104 | module,
105 | train_data,
106 | batch_size=batch_size,
107 | act_quant=config.quant.act_quant,
108 | adaround_config=config.adaround,
109 | )
110 | logger.info(f'Done AdaRound for layer {name}. {s_layer.format()}\n')
111 | s_all.stop()
112 |
113 | s_all.stop()
114 | logger.info(f'Done optimizing all layers. {s_all.format()}')
115 |
116 | if config.adaround.act_quant_mode == AdaRoundActQuantMode.post_adaround:
117 | if driver is not None:
118 | # validate before activation quantization
119 | model.quantized_weights()
120 | state = driver.validate()
121 | acc_quant = state.metrics['top_1_accuracy']
122 | logger.info(f'FINAL res (without acts quant):\t{acc_quant * 100:.2f}%')
123 |
124 | # activate activation quantization and estimate ranges
125 | config.quant.act_quant = True
126 | model.estimate_act_ranges()
127 | pass_data_for_range_estimation(
128 | loader=range_est_data_loader,
129 | model=model,
130 | act_quant=True,
131 | weight_quant=True,
132 | max_num_batches=config.act_quant.num_batches,
133 | cross_entropy_layer=config.act_quant.cross_entropy_layer,
134 | inp_idx=inp_idx,
135 | )
136 | model.fix_act_ranges()
137 |
138 | # set state
139 | model.set_quant_state(weight_quant=True, act_quant=config.quant.act_quant)
140 |
--------------------------------------------------------------------------------
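
A small illustration of how the AdaRound calibration samples are gathered; the toy dataset and loader below are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from utils.adaround_utils import get_train_samples

# Toy loader yielding (input, label) batches.
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32)

X = get_train_samples(loader, num_samples=128)                          # inputs only
X, y = get_train_samples(loader, num_samples=128, return_labels=True)   # with labels
```
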
/utils/glue_tasks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 | from enum import Flag, auto
6 | from functools import reduce
7 | from operator import or_
8 |
9 | import numpy as np
10 | from datasets import load_dataset, load_metric
11 | from transformers import EvalPrediction
12 |
13 | from utils.utils import DotDict
14 |
15 |
16 | # setup logger
17 | logger = logging.getLogger('GLUE')
18 | logger.setLevel(logging.INFO)
19 |
20 |
21 | class GLUE_Task(Flag):
22 | cola = auto()
23 | sst2 = auto()
24 | mrpc = auto()
25 | stsb = auto()
26 | qqp = auto()
27 | mnli = auto()
28 | qnli = auto()
29 | rte = auto()
30 | wnli = auto()
31 | all = cola | sst2 | mrpc | stsb | qqp | mnli | qnli | rte | wnli
32 |
33 | def __contains__(self, item):
34 | return (self.value & item.value) == item.value
35 |
36 | @classmethod
37 | def from_str(cls, *names):
38 | """Construct flag from strings."""
39 | assert len(names)
40 | return reduce(or_, map(cls.__getattr__, names))
41 |
42 | @classmethod
43 | def list_names(cls):
44 | """List all flags, including `all`."""
45 | return [m.name for m in cls]
46 |
47 | def iter(self):
48 | """List all member flags which are set (excluding `all`)."""
49 | for x in self.__class__.__members__.values():
50 | if x in self and x != self.__class__.all:
51 | yield x
52 |
53 | def iter_names(self):
54 | """List all member flag names which are set (excluding `all`)."""
55 | for x in self.iter():
56 | yield x.name
57 |
58 |
59 | TASK_TO_SENTENCE_KEYS = {
60 | GLUE_Task.cola: ('sentence', None),
61 | GLUE_Task.sst2: ('sentence', None),
62 | GLUE_Task.mrpc: ('sentence1', 'sentence2'),
63 | GLUE_Task.stsb: ('sentence1', 'sentence2'),
64 | GLUE_Task.qqp: ('question1', 'question2'),
65 | GLUE_Task.mnli: ('premise', 'hypothesis'),
66 | GLUE_Task.qnli: ('question', 'sentence'),
67 | GLUE_Task.rte: ('sentence1', 'sentence2'),
68 | GLUE_Task.wnli: ('sentence1', 'sentence2'),
69 | }
70 |
71 |
72 | TASK_TO_FINAL_METRIC = {
73 | GLUE_Task.cola: 'matthews_correlation',
74 | GLUE_Task.sst2: 'accuracy',
75 | GLUE_Task.mrpc: 'combined_score',
76 | GLUE_Task.stsb: 'combined_score',
77 | GLUE_Task.qqp: 'combined_score',
78 | GLUE_Task.mnli: 'accuracy',
79 | GLUE_Task.qnli: 'accuracy',
80 | GLUE_Task.rte: 'accuracy',
81 | GLUE_Task.wnli: 'accuracy',
82 | }
83 |
84 |
85 | TASK_N = {
86 | GLUE_Task.mnli: 392702,
87 | GLUE_Task.qqp: 363846,
88 | GLUE_Task.qnli: 104743,
89 | GLUE_Task.sst2: 67349,
90 | GLUE_Task.cola: 8551,
91 | GLUE_Task.stsb: 5749,
92 | GLUE_Task.mrpc: 3665,
93 | GLUE_Task.rte: 2490,
94 | GLUE_Task.wnli: 635,
95 | }
96 |
97 |
98 | def load_task_data(task: GLUE_Task, data_dir: str):
99 | out = DotDict()
100 |
101 | # download and load data
102 | logger.info(f'Getting {task.name} dataset ...\n')
103 | out.datasets = load_dataset('glue', task.name, cache_dir=data_dir)
104 |
105 | # determine number of labels
106 | logger.info('Determine labels ...\n')
107 | if task == GLUE_Task.stsb: # regression
108 | out.num_labels = 1
109 | logger.info(f'{task.name}: 1 label -- ')
110 | else:
111 | label_list = out.datasets["train"].features["label"].names
112 | out.num_labels = n_labels = len(label_list)
113 | logger.info(f'{task.name}: {n_labels} labels -- {label_list}')
114 |
115 | # store sentence keys
116 | out.sentence1_key, out.sentence2_key = TASK_TO_SENTENCE_KEYS[task]
117 | return out
118 |
119 |
120 | def make_compute_metric_fn(task: GLUE_Task):
121 | metric = load_metric('glue', task.name)
122 | logger.info('Metric:')
123 | logger.info(metric)
124 |
125 | def fn(p: EvalPrediction):
126 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
127 | preds = np.squeeze(preds) if task == GLUE_Task.stsb else np.argmax(preds, axis=1)
128 | result = metric.compute(predictions=preds, references=p.label_ids)
129 | if len(result) > 1:
130 | result['combined_score'] = np.mean(list(result.values())).item()
131 | return result
132 |
133 | return fn
134 |
--------------------------------------------------------------------------------
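
A brief illustration of the flag semantics of `GLUE_Task`:

```python
from utils.glue_tasks import GLUE_Task, TASK_TO_FINAL_METRIC

tasks = GLUE_Task.from_str('rte', 'mrpc')   # combine several tasks into one flag
print(GLUE_Task.rte in tasks)               # True
print(list(tasks.iter_names()))             # ['mrpc', 'rte'] (declaration order)

for task in tasks.iter():
    print(task.name, TASK_TO_FINAL_METRIC[task])
```
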
/utils/hf_models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 | from enum import Enum
6 |
7 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
8 |
9 | from utils.utils import count_embedding_params, count_params, DotDict
10 |
11 |
12 | logger = logging.getLogger('GLUE')
13 | logger.setLevel(logging.ERROR)
14 |
15 |
16 | class HF_Models(Enum):
17 | # vanilla BERT
18 | bert_base_uncased = 'bert-base-uncased'
19 | bert_large_uncased = 'bert-large-uncased'
20 | bert_base_cased = 'bert-base-cased'
21 |
22 | # RoBERTa
23 | roberta_base = 'roberta-base'
24 |
25 | # Distilled vanilla models
26 | distilbert_base_uncased = 'distilbert-base-uncased'
27 | distilroberta_base = 'distilroberta-base'
28 |
29 | # Models optimized for runtime on mobile devices
30 | mobilebert_uncased = 'google/mobilebert-uncased'
31 | squeezebert_uncased = 'squeezebert/squeezebert-uncased'
32 |
33 | # ALBERT: very small set of models (optimized for memory)
34 | albert_base_v2 = 'albert-base-v2'
35 | albert_large_v2 = 'albert-large-v2'
36 |
37 | @classmethod
38 | def list_names(cls):
39 | return [m.name for m in cls]
40 |
41 |
42 | MODEL_TO_BACKBONE_ATTR = {  # maps model enum -> attribute name of its backbone module
43 | HF_Models.bert_base_uncased: 'bert',
44 | HF_Models.bert_large_uncased: 'bert',
45 | HF_Models.bert_base_cased: 'bert',
46 | HF_Models.distilroberta_base: 'roberta',
47 | HF_Models.roberta_base: 'roberta',
48 | HF_Models.mobilebert_uncased: 'mobilebert',
49 | }
50 |
51 |
52 | def load_model_and_tokenizer(model_name, model_path, use_fast_tokenizer, cache_dir, attn_dropout,
53 | hidden_dropout, num_labels, **kw):
54 | del kw # unused
55 |
56 | out = DotDict()
57 |
58 | # Config
59 | if model_path is not None:
60 | model_name_or_path = model_path
61 | else:
62 | model_name_or_path = HF_Models[model_name].value # use HF identifier
63 |
64 | config = AutoConfig.from_pretrained(
65 | model_name_or_path,
66 | num_labels=num_labels,
67 | cache_dir=cache_dir,
68 | )
69 |
70 | # set dropout rates
71 | if attn_dropout is not None:
72 | logger.info(f'Setting attn dropout to {attn_dropout}')
73 | if hasattr(config, 'attention_probs_dropout_prob'):
74 | setattr(config, 'attention_probs_dropout_prob', attn_dropout)
75 |
76 | if hidden_dropout is not None:
77 | logger.info(f'Setting hidden dropout to {hidden_dropout}')
78 | if hasattr(config, 'hidden_dropout_prob'):
79 | setattr(config, 'hidden_dropout_prob', hidden_dropout)
80 |
81 | logger.info('HuggingFace model config:')
82 | logger.info(config)
83 | out.config = config
84 | out.model_name_or_path = model_name_or_path
85 |
86 | # Tokenizer
87 | tokenizer = AutoTokenizer.from_pretrained(
88 | model_name_or_path,
89 | use_fast=use_fast_tokenizer,
90 | cache_dir=cache_dir,
91 | )
92 | logger.info('Tokenizer:')
93 | logger.info(tokenizer)
94 | out.tokenizer = tokenizer
95 |
96 | # Model
97 | model = AutoModelForSequenceClassification.from_pretrained(
98 | model_name_or_path,
99 | from_tf=False,
100 | config=config,
101 | cache_dir=cache_dir,
102 | )
103 | logger.info('Model:')
104 | logger.info(model)
105 | out.model = model
106 |
107 | # Parameter counts
108 | total_params = count_params(model)
109 | embedding_params = count_embedding_params(model)
110 | non_embedding_params = total_params - embedding_params
111 | logger.info(f'Parameters (embedding): {embedding_params}')
112 | logger.info(f'Parameters (non-embedding): {non_embedding_params}')
113 | logger.info(f'Parameters (total): {total_params}')
114 | out.total_params = total_params
115 | out.embedding_params = embedding_params
116 | out.non_embedding_params = non_embedding_params
117 |
118 | # Additional attributes
119 | out.model_enum = HF_Models[model_name]
120 | out.backbone_attr = MODEL_TO_BACKBONE_ATTR.get(out.model_enum, None)
121 | return out
122 |
--------------------------------------------------------------------------------
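
A usage sketch for the loader above; the argument values are illustrative and the parameter count in the comment is approximate.

```python
from utils.hf_models import HF_Models, load_model_and_tokenizer

out = load_model_and_tokenizer(
    model_name=HF_Models.bert_base_uncased.name,  # 'bert_base_uncased'
    model_path=None,                              # or a local checkpoint directory
    use_fast_tokenizer=True,
    cache_dir=None,
    attn_dropout=None,
    hidden_dropout=None,
    num_labels=2,
)
model, tokenizer = out.model, out.tokenizer
print(out.total_params, out.backbone_attr)        # roughly 110M params, 'bert'
```
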
/utils/per_embd_quant_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from quantization.base_quantized_classes import FP32Acts
5 |
6 |
7 | def _hijack_act_quant(module, value):
8 | if value is None:
9 | return
10 |
11 | if isinstance(value, int):
12 | module.activation_quantizer.quantizer.n_bits = value
13 | elif value == 'fp32':
14 | module.activation_quantizer = FP32Acts()
15 | elif value == 'per_embd':
16 | set_act_quant_axis_and_groups(module, axis=2, n_groups=None)
17 | elif value.startswith('ngp'):
18 | set_act_quant_axis_and_groups(module, axis=2, n_groups=int(value[3:]), permute=True)
19 | elif value.startswith('ng'):
20 | set_act_quant_axis_and_groups(module, axis=2, n_groups=int(value[2:]), permute=False)
21 | else:
22 | raise NotImplementedError(f'Unknown value "{value}" in quant_dict')
23 |
24 |
25 | def _hijack_weight_quant(module, value):
26 | if value is None:
27 | return
28 |
29 | if isinstance(value, int):
30 | module.weight_quantizer.quantizer.n_bits = value
31 | elif value == 'fp32':
32 | module.weight_quantizer = FP32Acts()
33 | else:
34 | raise NotImplementedError(f'Unknown value "{value}" in quant_dict')
35 |
36 |
37 | def hijack_act_quant(quant_dict, name, m):
38 | value = quant_dict.get(name, None)
39 | _hijack_act_quant(m, value)
40 |
41 |
42 | def hijack_weight_quant(quant_dict, name, m):
43 | value = quant_dict.get(name, None)
44 | _hijack_weight_quant(m, value)
45 |
46 |
47 | def hijack_act_quant_modules(quant_dict, name, m):
48 | value = quant_dict.get(name, None)
49 | for m_ in m.modules():
50 | if hasattr(m_, 'activation_quantizer'):
51 | _hijack_act_quant(m_, value)
52 |
53 |
54 | def set_act_quant_axis_and_groups(module, axis, n_groups, permute=False):
55 | if hasattr(module, 'activation_quantizer'):
56 | module = module.activation_quantizer
57 |
58 | module.axis = axis
59 | module.quantizer.axis = axis
60 | module.range_estimator.axis = axis
61 |
62 | module.n_groups = n_groups
63 | module.range_estimator.n_groups = n_groups
64 |
65 | if permute:
66 | module.range_estimator.per_group_range_estimation = True
67 |
68 | return module
69 |
--------------------------------------------------------------------------------
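
The `quant_dict` consumed by these helpers maps module names to per-module overrides. The sketch below is illustrative: `model` is a hypothetical quantized model and the module names are made up; they must refer to modules that actually carry an `activation_quantizer`.

```python
from utils.per_embd_quant_utils import hijack_act_quant

# Per-module overrides: an int changes the bitwidth, 'fp32' disables
# quantization, 'per_embd' enables per-embedding activation quantization,
# 'ng16'/'ngp16' use 16 range groups (the latter with range-based permutation).
quant_dict = {
    'encoder.layer.11.output.dense': 'fp32',       # hypothetical module name
    'encoder.layer.11.attention.self.query': 16,   # hypothetical module name
}

for name, module in model.named_modules():
    hijack_act_quant(quant_dict, name, module)
```

`hijack_weight_quant` is applied in the same way to override weight quantizers.
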
/utils/qat_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import logging
5 |
6 | from utils.utils import pass_data_for_range_estimation
7 |
8 |
9 | # setup logger
10 | logger = logging.getLogger('QAT')
11 | logger.setLevel('INFO')
12 |
13 |
14 | def prepare_model_for_quantization(config, model, loader):
15 |
16 | # estimate ranges using training data
17 | pass_data_for_range_estimation(
18 | loader=loader,
19 | model=model,
20 | act_quant=config.quant.act_quant,
21 | weight_quant=config.quant.weight_quant,
22 | max_num_batches=config.act_quant.num_batches,
23 | cross_entropy_layer=config.act_quant.cross_entropy_layer,
24 | )
25 |
26 | # put quantizers in desirable state
27 | if config.qat.learn_ranges:
28 | logger.info('Make quantizers learnable')
29 | model.learn_ranges()
30 | else:
31 | logger.info(
32 | f'Fix quantizer ranges to fixW={config.qat.fix_weight_ranges} and '
33 | f'fixA={config.qat.fix_act_ranges}'
34 | )
35 |
36 | # freeze quantization ranges if applicable
37 | model.estimate_ranges_train() # we use updating ranges in training as default
38 | if config.qat.fix_weight_ranges:
39 | model.fix_weight_ranges()
40 | if config.qat.fix_act_ranges:
41 | model.fix_act_ranges()
42 |
43 | # ensure we have the desired quant state
44 | model.set_quant_state(config.quant.weight_quant, config.quant.act_quant)
45 | return model
46 |
--------------------------------------------------------------------------------
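
A minimal config sketch showing only the fields this function reads, assuming `DotDict` accepts keyword arguments like a plain dict; `model` and `train_loader` are hypothetical, and the values are illustrative rather than recommended defaults.

```python
from utils.qat_utils import prepare_model_for_quantization
from utils.utils import DotDict

config = DotDict(
    quant=DotDict(act_quant=True, weight_quant=True),
    act_quant=DotDict(num_batches=16, cross_entropy_layer=None),
    qat=DotDict(learn_ranges=False, fix_weight_ranges=True, fix_act_ranges=False),
)

model = prepare_model_for_quantization(config, model, train_loader)
```
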
/utils/quant_click_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from functools import wraps
5 |
6 | import click
7 |
8 | from quantization.adaround.config import AdaRoundConfig, DEFAULT_ADAROUND_CONFIG as C
9 | from quantization.adaround.utils import (
10 | AdaRoundActQuantMode,
11 | AdaRoundInitMode,
12 | AdaRoundMode,
13 | AdaRoundTempDecayType,
14 | )
15 | from quantization.quantizers import QMethods
16 | from quantization.range_estimators import RangeEstimators, OptMethod
17 | from utils.utils import DotDict
18 |
19 |
20 | class StrTuple(click.ParamType):
21 | name = 'str_sequence'
22 |
23 | def convert(self, value, param, ctx):
24 | values = value.split(',')
25 | values = tuple(map(lambda s: s.strip(), values))
26 | return values
27 |
28 |
29 | def split_dict(src: dict, include=()):
30 |     """
31 |     Splits a dictionary into a DotDict and a remainder dict.
32 |     The keys placed in the DotDict are those listed in `include`.
33 |
34 |     Parameters
35 |     ----------
36 |     src: dict
37 |         The source dictionary.
38 |     include:
39 |         Iterable of keys to be placed in the returned DotDict.
40 |     """
41 | result = DotDict()
42 |
43 | for arg in include:
44 | result[arg] = src[arg]
45 | remainder = {key: val for key, val in src.items() if key not in include}
46 | return result, remainder
47 |
48 |
49 | def quantization_options(func):
50 | @click.option(
51 | '--qmethod',
52 | type=click.Choice(QMethods.list()),
53 | required=True,
54 | help='Quantization scheme to use.',
55 | )
56 | @click.option(
57 | '--qmethod-act',
58 | type=click.Choice(QMethods.list()),
59 | default=None,
60 |         help='Quantization scheme to use for activations. If not specified, `--qmethod` is used.',
61 | )
62 | @click.option(
63 | '--weight-quant-method',
64 | default=RangeEstimators.current_minmax.name,
65 | type=click.Choice(RangeEstimators.list()),
66 | help='Method to determine weight quantization clipping thresholds.',
67 | )
68 | @click.option(
69 | '--weight-opt-method',
70 | default=OptMethod.grid.name,
71 | type=click.Choice(OptMethod.list()),
72 |         help='Optimization procedure for weight quantization clipping thresholds.',
73 | )
74 | @click.option(
75 | '--num-candidates',
76 | type=int,
77 | default=None,
78 | help='Number of grid points for grid search in MSE range method.',
79 | )
80 | @click.option('--n-bits', default=8, type=int, help='Default number of quantization bits.')
81 | @click.option(
82 | '--n-bits-act', default=None, type=int, help='Number of quantization bits for activations.'
83 | )
84 | @click.option('--per-channel', is_flag=True, help='If given, quantize each channel separately.')
85 | @click.option(
86 | '--percentile',
87 | type=float,
88 | default=None,
89 | help='Percentile clipping parameter (weights and activations)',
90 | )
91 | @click.option(
92 | '--act-quant/--no-act-quant',
93 | is_flag=True,
94 | default=True,
95 | help='Run evaluation with activation quantization or use FP32 activations',
96 | )
97 | @click.option(
98 | '--weight-quant/--no-weight-quant',
99 | is_flag=True,
100 | default=True,
101 | help='Run evaluation weight quantization or use FP32 weights',
102 | )
103 | @click.option(
104 | '--quant-setup',
105 | default='all',
106 | type=click.Choice(['all', 'FP_logits', 'MSE_logits']),
107 | help='Method to quantize the network.',
108 | )
109 | @wraps(func)
110 | def func_wrapper(config, *args, **kwargs):
111 | config.quant, remainder_kwargs = split_dict(kwargs, [
112 | 'qmethod',
113 | 'qmethod_act',
114 | 'weight_quant_method',
115 | 'weight_opt_method',
116 | 'num_candidates',
117 | 'n_bits',
118 | 'n_bits_act',
119 | 'per_channel',
120 | 'percentile',
121 | 'act_quant',
122 | 'weight_quant',
123 | 'quant_setup',
124 | ])
125 |
126 | config.quant.qmethod_act = config.quant.qmethod_act or config.quant.qmethod
127 |
128 | return func(config, *args, **remainder_kwargs)
129 |
130 | return func_wrapper
131 |
132 |
133 | def activation_quantization_options(func):
134 | @click.option(
135 | '--act-quant-method',
136 | default=RangeEstimators.running_minmax.name,
137 | type=click.Choice(RangeEstimators.list()),
138 | help='Method to determine activation quantization clipping thresholds',
139 | )
140 | @click.option(
141 | '--act-opt-method',
142 | default=OptMethod.grid.name,
143 | type=click.Choice(OptMethod.list()),
144 | help='Optimization procedure for activation quantization clipping thresholds',
145 | )
146 | @click.option(
147 | '--act-num-candidates',
148 | type=int,
149 | default=None,
150 | help='Number of grid points for grid search in MSE/Cross-entropy',
151 | )
152 | @click.option(
153 | '--act-momentum',
154 | type=float,
155 | default=None,
156 | help='Exponential averaging factor for running_minmax',
157 | )
158 | @click.option(
159 | '--cross-entropy-layer',
160 | default=None,
161 | type=str,
162 |         help='Layer for which to use cross-entropy range estimation (often valuable for the last layer)',
163 | )
164 | @click.option(
165 | '--num-est-batches',
166 | type=int,
167 | default=1,
168 | help='Number of training batches to be used for activation range estimation',
169 | )
170 | @wraps(func)
171 | def func_wrapper(config, act_quant_method, act_opt_method, act_num_candidates, act_momentum,
172 | cross_entropy_layer, num_est_batches, *args, **kwargs):
173 | config.act_quant = DotDict()
174 | config.act_quant.quant_method = act_quant_method
175 | config.act_quant.cross_entropy_layer = cross_entropy_layer
176 | config.act_quant.num_batches = num_est_batches
177 |
178 | config.act_quant.options = {}
179 |
180 | if act_num_candidates is not None:
181 | if act_quant_method != 'MSE':
182 |                 raise ValueError("Option '--act-num-candidates' requires --act-quant-method=MSE")
183 | else:
184 | config.act_quant.options['num_candidates'] = act_num_candidates
185 |
186 | if act_momentum is not None:
187 | if act_quant_method != 'running_minmax':
188 |                 raise ValueError("Option '--act-momentum' requires --act-quant-method=running_minmax")
189 | else:
190 | config.act_quant.options['momentum'] = act_momentum
191 |
192 | if act_opt_method != 'grid':
193 | config.act_quant.options['opt_method'] = OptMethod[act_opt_method]
194 | return func(config, *args, **kwargs)
195 |
196 | return func_wrapper
197 |
198 |
199 | def qat_options(func):
200 | @click.option(
201 | '--learn-ranges',
202 | is_flag=True,
203 | default=False,
204 |         help='Learn quantization ranges; if set, the fix-ranges options are ignored.',
205 | )
206 | @click.option(
207 | '--fix-act-ranges/--no-fix-act-ranges',
208 | is_flag=True,
209 | default=False,
210 | help='Fix all activation quantization ranges for stable training',
211 | )
212 | @click.option(
213 | '--fix-weight-ranges/--no-fix-weight-ranges',
214 | is_flag=True,
215 | default=False,
216 | help='Fix all weight quantization ranges for stable training',
217 | )
218 | @wraps(func)
219 | def func_wrapper(config, *args, **kwargs):
220 | config.qat, remainder_kwargs = split_dict(
221 | kwargs, ['learn_ranges', 'fix_act_ranges', 'fix_weight_ranges']
222 | )
223 |
224 | return func(config, *args, **remainder_kwargs)
225 |
226 | return func_wrapper
227 |
228 |
229 | def adaround_options(func):
230 | # Base options
231 | @click.option(
232 | '--adaround',
233 | default=None,
234 | type=StrTuple(),
235 |         help="Apply AdaRound: either to the full model ('all') or to any number of layers, "
236 |         "specified by comma-separated names.",
237 | )
238 | @click.option(
239 | '--adaround-num-samples',
240 | default=C.num_samples,
241 | type=int,
242 | help='Number of samples to use for learning the rounding.',
243 | )
244 | @click.option(
245 | '--adaround-init',
246 | default=C.init.name,
247 | type=click.Choice(AdaRoundInitMode.list_names(), case_sensitive=False),
248 | help='Method to initialize the quantization grid for weights.',
249 | )
250 |
251 | # Method and continuous relaxation options
252 | @click.option(
253 | '--adaround-mode',
254 | default=C.round_mode.name,
255 | type=click.Choice(AdaRoundMode.list_names(), case_sensitive=False),
256 | help='Method to learn the rounding.',
257 | )
258 | @click.option(
259 | '--adaround-asym/--no-adaround-asym',
260 | is_flag=True,
261 | default=C.asym,
262 | help='Whether to use asymmetric reconstruction for AdaRound.',
263 | )
264 | @click.option(
265 | '--adaround-include-act-func/--adaround-no-act-func',
266 | is_flag=True,
267 | default=C.include_act_func,
268 | help='Include activation function into AdaRound.',
269 | )
270 | @click.option(
271 | '--adaround-lr',
272 | default=C.lr,
273 | type=float,
274 | help='Learning rate for continuous relaxation in AdaRound.',
275 | )
276 | @click.option(
277 | '--adaround-iters',
278 | default=C.iters,
279 | type=int,
280 | help='Number of iterations to train each layer.',
281 | )
282 | @click.option(
283 | '--adaround-weight',
284 | default=C.weight,
285 | type=float,
286 | help='Weight of rounding cost vs the reconstruction loss.',
287 | )
288 | @click.option(
289 | '--adaround-annealing',
290 | default=C.annealing,
291 | nargs=2,
292 | type=float,
293 | help='Annealing of regularization function temperature (tuple: start, end).',
294 | )
295 | @click.option(
296 | '--adaround-decay-type',
297 | default=C.decay_type.name,
298 | type=click.Choice(AdaRoundTempDecayType.list_names(), case_sensitive=False),
299 | help='Type of temperature annealing schedule.',
300 | )
301 | @click.option(
302 | '--adaround-decay-shape',
303 | default=C.decay_shape,
304 | type=float,
305 | help="Positive "
306 | "scalar value that controls the shape of decay schedules 'sigmoid', 'power', "
307 | "'exp', 'log'. Sensible values to try: sigmoid{10}, power{4,6,8}, exp{4,6,8}, "
308 | "log{1,2,3}.",
309 | )
310 | @click.option(
311 | '--adaround-decay-start',
312 | default=C.decay_start,
313 | type=float,
314 |         help='Start of temperature annealing (fraction of --adaround-iters).',
315 | )
316 | @click.option(
317 | '--adaround-warmup',
318 | default=C.warmup,
319 | type=float,
320 |         help='Warmup period with no regularization applied (fraction of --adaround-iters).',
321 | )
322 |
323 | # Activation quantization
324 | @click.option(
325 | '--adaround-act-quant',
326 | default=C.act_quant_mode.name,
327 | type=click.Choice(AdaRoundActQuantMode.list_names(), case_sensitive=False),
328 | help='Method to deal with activation quantization during AdaRound.',
329 | )
330 | @wraps(func)
331 | def func_wrapper(config, *args, **kwargs):
332 | config.adaround = AdaRoundConfig(**C)
333 |
334 | config.adaround.layers = kwargs.pop('adaround')
335 | config.adaround.num_samples = kwargs.pop('adaround_num_samples')
336 | config.adaround.init = AdaRoundInitMode[kwargs.pop('adaround_init')]
337 |
338 | config.adaround.round_mode = AdaRoundMode[kwargs.pop('adaround_mode')]
339 | config.adaround.asym = kwargs.pop('adaround_asym')
340 | config.adaround.include_act_func = kwargs.pop('adaround_include_act_func')
341 | config.adaround.lr = kwargs.pop('adaround_lr')
342 | config.adaround.iters = kwargs.pop('adaround_iters')
343 | config.adaround.weight = kwargs.pop('adaround_weight')
344 | config.adaround.annealing = kwargs.pop('adaround_annealing')
345 | config.adaround.decay_type = AdaRoundTempDecayType[kwargs.pop('adaround_decay_type')]
346 | config.adaround.decay_shape = kwargs.pop('adaround_decay_shape')
347 | config.adaround.decay_start = kwargs.pop('adaround_decay_start')
348 | config.adaround.warmup = kwargs.pop('adaround_warmup')
349 |
350 | config.adaround.act_quant_mode = AdaRoundActQuantMode[kwargs.pop('adaround_act_quant')]
351 | return func(config, *args, **kwargs)
352 |
353 | return func_wrapper
354 |
355 |
356 | def make_qparams(config):
357 | weight_range_options = {}
358 | if config.quant.weight_quant_method in ['MSE', 'cross_entropy']:
359 | weight_range_options = dict(opt_method=OptMethod[config.quant.weight_opt_method])
360 | if config.quant.num_candidates is not None:
361 | weight_range_options['num_candidates'] = config.quant.num_candidates
362 |
363 | act_range_options = config.act_quant.options
364 | if config.quant.percentile is not None:
365 | act_range_options['percentile'] = config.quant.percentile
366 |
367 | params = {
368 | 'method': QMethods[config.quant.qmethod],
369 | 'act_method': QMethods[config.quant.qmethod_act],
370 | 'n_bits': config.quant.n_bits,
371 | 'n_bits_act': config.quant.n_bits_act,
372 | 'per_channel_weights': config.quant.per_channel,
373 | 'percentile': config.quant.percentile,
374 | 'quant_setup': config.quant.quant_setup,
375 | 'weight_range_method': RangeEstimators[config.quant.weight_quant_method],
376 | 'weight_range_options': weight_range_options,
377 | 'act_range_method': RangeEstimators[config.act_quant.quant_method],
378 | 'act_range_options': config.act_quant.options,
379 | }
380 | return params
381 |
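382 | # Illustrative sketch (not executed): how these option groups are meant to be
383 | # stacked on a click command. The wrappers above all expect a `config` object
384 | # (a DotDict) as the first argument, so an outer decorator (the hypothetical
385 | # `pass_config` below stands in for however the entry point supplies it) has to
386 | # inject it; the wrappers then populate config.quant, config.act_quant and
387 | # config.qat, which make_qparams() assembles into a single dict of quantizer
388 | # parameters:
389 | #
390 | #     @click.command()
391 | #     @pass_config                      # hypothetical: injects a DotDict `config`
392 | #     @quantization_options
393 | #     @activation_quantization_options
394 | #     @qat_options
395 | #     def my_command(config, **kwargs):
396 | #         qparams = make_qparams(config)
397 | 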
--------------------------------------------------------------------------------
/utils/tb_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 |
5 | def _tb_advance_global_step(module):
6 | if hasattr(module, 'global_step'):
7 | module.global_step += 1
8 | return module
9 |
10 |
11 | def _tb_advance_token_counters(module, tensor, verbose=False):
12 | token_count = getattr(module, 'tb_token_count', None)
13 | if token_count is not None:
14 | T = tensor.size(1)
15 | if token_count.last != T:
16 | if token_count.last != 0:
17 | token_count.total += token_count.last
18 | token_count.sample_idx += 1
19 | token_count.last = T
20 |
21 | if verbose:
22 | print(f'>>> T={T}\tlast_T={token_count.last}\tcumsum_T={token_count.total}')
23 | return module
24 |
25 |
26 | def _tb_hist(module, tensor, name, verbose=False):
27 | hist_kw = dict(bins='auto')
28 |
29 | tb_writer = getattr(module, 'tb_writer', None)
30 | if tb_writer is not None:
31 | if module.layer_idx == module.num_layers - 1:
32 | tensor = tensor[:, 0]
33 |
34 | # per-tensor
35 | layer_s = str(1 + module.layer_idx).zfill(2)
36 | full_name = f'{layer_s}/layer/{name}'
37 | global_step = module.global_step
38 | if verbose:
39 | stats = f'min={tensor.min():.1f}, max={tensor.max():.1f}'
40 | info = (
41 | f'TB logging {full_name}\t{tuple(tensor.size())}\t({stats})\t'
42 | f'[global_step={global_step}] ...'
43 | )
44 | print(info)
45 | tb_writer.add_histogram(full_name, tensor, global_step=global_step, **hist_kw)
46 |
47 | # per-token
48 | sample_idx_s = str(module.tb_token_count.sample_idx + 1).zfill(2)
49 | T = tensor.size(1)
50 | full_name = f'{layer_s}/token/{sample_idx_s}/{name}'
51 | for i in range(T):
52 | tb_writer.add_histogram(full_name, tensor[0, i], global_step=i, **hist_kw)
53 |
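54 | # Illustrative sketch (not executed): the helpers above assume that the caller
55 | # has attached a few attributes to the module beforehand, roughly as follows
56 | # (DotDict comes from utils.utils; SummaryWriter from torch.utils.tensorboard):
57 | #
58 | #     module.tb_writer = SummaryWriter(log_dir='...')
59 | #     module.global_step = 0
60 | #     module.layer_idx, module.num_layers = 0, 12
61 | #     module.tb_token_count = DotDict(last=0, total=0, sample_idx=0)
62 | #
63 | # and that, e.g. inside a forward hook, they are called as:
64 | #
65 | #     _tb_advance_token_counters(module, hidden_states)
66 | #     _tb_hist(module, hidden_states, 'hidden_states')
67 | #     _tb_advance_global_step(module)
68 | 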
--------------------------------------------------------------------------------
/utils/transformer_click_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | from functools import wraps
5 | from pathlib import Path
6 |
7 | import click
8 | import torch
9 | from transformers.trainer_utils import EvaluationStrategy
10 |
11 | from utils.hf_models import HF_Models
12 | from utils.glue_tasks import GLUE_Task
13 | from utils.quant_click_options import split_dict
14 | from utils.utils import seed_all
15 |
16 |
17 | def transformer_base_options(func):
18 | @click.option(
19 | '--cuda/--no-cuda', is_flag=True, default=torch.cuda.is_available(), help='Use GPU'
20 | )
21 | @click.option('--seed', type=int, default=1000, help='Random number generator seed to set.')
22 | @click.option(
23 | '--num-workers',
24 | type=int,
25 | default=0,
26 | help='Number of PyTorch data loader subprocesses. 0 means that the data will be '
27 | 'loaded in the main process.',
28 | )
29 | @click.option(
30 | '--output-dir',
31 | default=None,
32 | type=click.Path(file_okay=False, writable=True, resolve_path=True),
33 | help='The output directory where the model predictions and checkpoints will be written. '
34 | 'The model and the tokenizer will be saved in the `out` sub-folder.',
35 | )
36 | @click.option(
37 | '--overwrite-output',
38 | is_flag=True,
39 | default=False,
40 | help='Overwrite the content of the output directory and log file. Use this to '
41 |         'continue training if the output directory contains a model checkpoint.',
42 | )
43 | @wraps(func)
44 | def func_wrapper(config, *args, **kwargs):
45 | attrs = ['cuda', 'seed', 'num_workers', 'output_dir', 'overwrite_output']
46 | config.base, other_kw = split_dict(kwargs, attrs)
47 | seed = config.base.seed
48 | if seed is not None:
49 | seed_all(seed) # must be set before initializing the model
50 | return func(config, *args, **other_kw)
51 |
52 | return func_wrapper
53 |
54 |
55 | def transformer_data_options(func):
56 | @click.option(
57 | '--max-seq-length',
58 | type=int,
59 | default=128,
60 | help='The maximum total input sequence length after tokenization. Sequences '
61 | 'longer than this will be truncated, sequences shorter will be padded.',
62 | )
63 | @click.option(
64 | '--pad-to-max-length/--no-pad-to-max-length',
65 | is_flag=True,
66 | default=True,
67 | help='Whether to pad all samples to `max_seq_length`. If False, will pad the '
68 | 'samples dynamically when batching to the maximum length in the batch.',
69 | )
70 | @click.option('--line-by-line', is_flag=True, default=False)
71 | @click.option(
72 | '--overwrite-cache',
73 | is_flag=True,
74 | default=False,
75 | help='Overwrite the cached preprocessed datasets or not.',
76 | )
77 | @click.option('--num-train-samples', type=int, default=None)
78 | @click.option(
79 | '--num-val-samples',
80 | type=int,
81 | default=None,
82 | help='Number of samples to use for validation. If not specified, '
83 | 'validate on the entire set(s).',
84 | )
85 | @wraps(func)
86 | def func_wrapper(config, *args, **kwargs):
87 | attrs = [
88 | 'max_seq_length',
89 | 'pad_to_max_length',
90 | 'line_by_line',
91 | 'overwrite_cache',
92 | 'num_train_samples',
93 | 'num_val_samples',
94 | ]
95 |
96 | config.data, other_kw = split_dict(kwargs, attrs)
97 | return func(config, *args, **other_kw)
98 |
99 | return func_wrapper
100 |
101 |
102 | def glue_options(func):
103 | @click.option(
104 | '--task',
105 | type=click.Choice(GLUE_Task.list_names(), case_sensitive=False),
106 | default=(GLUE_Task.mrpc.name,),
107 | multiple=True,
108 | help='The name of the task to train on.',
109 | )
110 | @click.option(
111 | '--data-dir',
112 | default=str(Path.home() / '.glue_data'),
113 | type=click.Path(file_okay=False, writable=True, resolve_path=True),
114 | help='Directory where both raw and preprocessed GLUE datasets are stored.',
115 | )
116 | @wraps(func)
117 | def func_wrapper(config, *args, **kwargs):
118 | attrs = ['task', 'data_dir']
119 |
120 | config.glue, other_kw = split_dict(kwargs, attrs)
121 | return func(config, *args, **other_kw)
122 |
123 | return func_wrapper
124 |
125 |
126 | def transformer_model_options(func):
127 | # GLUE
128 | @click.option(
129 | '--model-name',
130 | type=click.Choice(HF_Models.list_names(), case_sensitive=False),
131 | default=HF_Models.bert_base_uncased.name,
132 | help='Model identifier from huggingface.co/models.',
133 | )
134 | @click.option(
135 | '--model-path',
136 | default=None,
137 | type=click.Path(exists=True, file_okay=False, resolve_path=True),
138 | help='For training (both FP32 and quantized), it is a path to a pretrained model, together '
139 | 'with a tokenizer (can be used to resume training). For validation, it is a path that '
140 | 'should contain fine-tuned checkpoints for all the requested tasks (each in a separate '
141 | 'sub-folder named as a corresponding task).',
142 | )
143 | @click.option(
144 | '--quant-model-path',
145 | default=None,
146 | type=click.Path(exists=True, file_okay=False, resolve_path=True),
147 | help='State dict of quantized model.',
148 | )
149 | @click.option(
150 | '--use-fast-tokenizer',
151 | is_flag=True,
152 | default=True,
153 |         help='Whether to use one of the fast tokenizers (backed by the HuggingFace '
154 | 'tokenizers library) or not.',
155 | )
156 | @click.option(
157 | '--cache-dir',
158 | default=str(Path.home() / '.hf_cache'),
159 | type=click.Path(file_okay=False, writable=True, resolve_path=True),
160 | help='Where to store downloaded pretrained HuggingFace models (together with '
161 | 'respective config and a tokenizer).',
162 | )
163 | @click.option(
164 | '--attn-dropout', default=None, type=float, help='Dropout rate to set for attention probs.'
165 | )
166 | @click.option(
167 | '--hidden-dropout', default=None, type=float, help='Dropout rate to set for hidden states.'
168 | )
169 | @wraps(func)
170 | def func_wrapper(config, *args, **kwargs):
171 | attrs = [
172 | 'model_name',
173 | 'model_path',
174 | 'quant_model_path',
175 | 'use_fast_tokenizer',
176 | 'cache_dir',
177 | 'attn_dropout',
178 | 'hidden_dropout',
179 | ]
180 |
181 | config.model, other_kw = split_dict(kwargs, attrs)
182 | return func(config, *args, **other_kw)
183 |
184 | return func_wrapper
185 |
186 |
187 | def transformer_training_options(func):
188 | # standard settings
189 | @click.option(
190 | '--do-eval/--no-eval',
191 | is_flag=True,
192 | default=True,
193 | help='Whether to run eval on the dev set after training.',
194 | )
195 | @click.option('--batch-size', type=int, default=8, help='Batch size for training.')
196 | @click.option(
197 | '--eval-batch-size',
198 | type=int,
199 | default=None,
200 | help='Batch size for evaluation. Defaults to the batch size for training.',
201 | )
202 | @click.option(
203 | '--learning-rate', type=float, default=5e-5, help='The initial learning rate for Adam.'
204 | )
205 | @click.option(
206 | '--lr-scheduler-type',
207 | default='cosine',
208 | type=click.Choice(['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant',
209 | 'constant_with_warmup']),
210 | help='The scheduler type to use.',
211 | )
212 | @click.option('--weight-decay', type=float, default=0.0, help='Weight decay for AdamW.')
213 | @click.option(
214 | '--max-grad-norm',
215 | type=float,
216 | default=None,
217 | help='Max gradient norm. If set to 0, no clipping will be applied.',
218 | )
219 | @click.option(
220 | '--num-epochs', type=int, default=3, help='Total number of training epochs to perform.'
221 | )
222 | @click.option(
223 | '--max-steps',
224 | type=int,
225 | default=0,
226 | help='If > 0, set total number of training steps to perform. Overrides `num_epochs`.',
227 | )
228 | @click.option('--warmup-steps', type=int, default=0, help='Linear warmup over `warmup_steps`.')
229 |
230 | # hw optimizations
231 | @click.option(
232 | '--gradient-accumulation-steps',
233 | type=int,
234 | default=1,
235 |         help="Number of update steps to accumulate before performing a backward/update pass.",
236 | )
237 | @click.option('--amp', is_flag=True, default=False, help='Whether to use Apex AMP.')
238 | @click.option(
239 | '--amp-opt-level',
240 | type=click.Choice(('O0', 'O1', 'O2', 'O3')),
241 | default='O2',
242 | help='Apex AMP optimization level.',
243 | )
244 |
245 | # custom regularization
246 | @click.option('--ffn-weight-decay', type=float, default=0, help='Weight decay for FFN weights.')
247 | @click.option('--gamma', type=float, default=0, help='Activation regularization strength.')
248 | @click.option('--margin', type=float, default=0, help='Activation regularization margin.')
249 |
250 | # custom functionality
251 | @click.option(
252 | '--save-attn',
253 | is_flag=True,
254 | default=False,
255 | help='Save attention probabilities from the training set and skip training.',
256 | )
257 | @wraps(func)
258 | def func_wrapper(config, *args, **kwargs):
259 | attrs = [
260 | 'do_eval',
261 | 'batch_size',
262 | 'eval_batch_size',
263 | 'learning_rate',
264 | 'lr_scheduler_type',
265 | 'weight_decay',
266 | 'max_grad_norm',
267 | 'num_epochs',
268 | 'max_steps',
269 | 'warmup_steps',
270 | 'gradient_accumulation_steps',
271 | 'amp',
272 | 'amp_opt_level',
273 | 'ffn_weight_decay',
274 | 'gamma',
275 | 'margin',
276 | 'save_attn',
277 | ]
278 |
279 | config.training, other_kw = split_dict(kwargs, attrs)
280 | if config.training.eval_batch_size is None:
281 | config.training.eval_batch_size = config.training.batch_size
282 |
283 | return func(config, *args, **other_kw)
284 |
285 | return func_wrapper
286 |
287 |
288 | def transformer_progress_options(func):
289 | @click.option('--tqdm/--no-tqdm', default=True)
290 | @click.option(
291 | '--eval-during-training',
292 | is_flag=True,
293 | default=False,
294 | help='Run evaluation during training at each logging step.',
295 | )
296 | @click.option(
297 | '--eval-strategy',
298 | default=EvaluationStrategy.NO.value,
299 | type=click.Choice([m.value for m in EvaluationStrategy], case_sensitive=False),
300 | help='Evaluation frequency level.',
301 | )
302 | @click.option(
303 | '--eval-steps', type=int, default=None, help='Run an evaluation every `eval_steps` steps.'
304 | )
305 | @click.option(
306 | '--tb-logging-dir',
307 | default=None,
308 | type=click.Path(exists=False, writable=True, resolve_path=True),
309 | help='Tensorboard log dir.',
310 | )
311 | @click.option(
312 | '--tb',
313 | is_flag=True,
314 | default=False,
315 |         help='Whether to create a TensorBoard writer and log additional information to it',
316 | )
317 | @click.option(
318 | '--tb-graph',
319 | is_flag=True,
320 | default=False,
321 | help='Whether to log computational graph into the TensorBoard writer',
322 | )
323 | @click.option(
324 | '--logging-first-step',
325 | is_flag=True,
326 | default=False,
327 | help='Log and eval the first global_step.',
328 | )
329 | @click.option(
330 |         '--logging-steps', type=int, default=500, help='Log every `logging_steps` update steps.'
331 | )
332 | @click.option(
333 | '--save-steps',
334 | type=int,
335 | default=0,
336 |         help='Save checkpoint every `save_steps` update steps.',
337 | )
338 | @click.option(
339 | '--save-total-limit',
340 | type=int,
341 | default=None,
342 |         help='Limit the total number of checkpoints. Deletes older checkpoints in '
343 |         'the `output_dir`. Default is an unlimited number of checkpoints.',
344 | )
345 | @click.option(
346 | '--save-model',
347 | is_flag=True,
348 | default=False,
349 | help='Whether to save model and tokenizer after the training.',
350 | )
351 | @click.option(
352 | '--run-name',
353 | type=str,
354 | default=None,
355 | help='An optional descriptor for the run. Notably used for wandb logging.',
356 | )
357 | @click.option(
358 | '--load-best-model-at-end',
359 | is_flag=True,
360 | default=False,
361 | help='Whether or not to load the best model found during training at the end of '
362 | 'training.',
363 | )
364 | @click.option(
365 | '--metric-for-best-model',
366 | type=str,
367 | default=None,
368 | help='The metric to use to compare two different models.',
369 | )
370 | @click.option(
371 | '--greater-is-better',
372 | type=bool,
373 | default=None,
374 | help='Whether the `metric_for_best_model` should be maximized or not.',
375 | )
376 | @wraps(func)
377 | def func_wrapper(config, *args, **kwargs):
378 | attrs = [
379 | 'tqdm',
380 | 'eval_during_training',
381 | 'eval_strategy',
382 | 'eval_steps',
383 | 'tb_logging_dir',
384 | 'tb',
385 | 'tb_graph',
386 | 'logging_first_step',
387 | 'logging_steps',
388 | 'save_steps',
389 | 'save_total_limit',
390 | 'save_model',
391 | 'run_name',
392 | 'load_best_model_at_end',
393 | 'metric_for_best_model',
394 | 'greater_is_better',
395 | ]
396 |
397 | config.progress, other_kw = split_dict(kwargs, attrs)
398 | return func(config, *args, **other_kw)
399 |
400 | return func_wrapper
401 |
402 |
403 | def transformer_quant_options(func):
404 | @click.option(
405 | '--est-ranges-pad/--est-ranges-no-pad',
406 | is_flag=True,
407 | default=None,
408 |         help='Specify whether to pad to max sequence length during range estimation. '
409 | 'If None, inherit the value of --pad-to-max-length.',
410 | )
411 | @click.option(
412 | '--est-ranges-batch-size',
413 | type=int,
414 | default=None,
415 | help='Batch size for range estimation. Defaults to the batch size for training.',
416 | )
417 | @click.option('--quant-dict', type=str, default=None)
418 | @click.option('--double', is_flag=True)
419 | @click.option('--dynamic', is_flag=True)
420 | @click.option('--per-token', is_flag=True)
421 | @click.option('--per-embd', is_flag=True)
422 | @click.option('--per-groups', type=int, default=None)
423 | @click.option('--per-groups-permute', is_flag=True)
424 | @click.option('--per-groups-permute-shared-h', is_flag=True)
425 | @wraps(func)
426 | def func_wrapper(config, est_ranges_pad, est_ranges_batch_size, quant_dict, double, dynamic,
427 | per_token, per_embd, per_groups, per_groups_permute,
428 | per_groups_permute_shared_h, *a, **kw):
429 | config.quant.est_ranges_pad = est_ranges_pad
430 | config.quant.est_ranges_batch_size = (
431 | est_ranges_batch_size if est_ranges_batch_size is not None
432 | else config.training.batch_size
433 | )
434 |
435 | if quant_dict is not None:
436 |             quant_dict = eval(quant_dict)  # parse the dict literal passed via --quant-dict
437 | config.quant.quant_dict = quant_dict
438 | config.double = double
439 |
440 | config.quant.dynamic = dynamic
441 | config.quant.per_token = per_token
442 | if config.quant.per_token:
443 | config.quant.dynamic = True
444 |
445 | config.quant.per_embd = per_embd
446 | config.quant.per_groups = per_groups
447 | config.quant.per_groups_permute = per_groups_permute
448 | config.quant.per_groups_permute_shared_h = per_groups_permute_shared_h
449 |
450 | return func(config, *a, **kw)
451 |
452 | return func_wrapper
453 |
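454 | # Illustrative sketch (not executed): `--quant-dict` is eval'd into a Python
455 | # dict, so it is passed on the command line as a dict literal whose keys are
456 | # module names and whose values are the per-module overrides understood by the
457 | # hijack helpers in utils/per_embd_quant_utils.py. The name below is
458 | # hypothetical:
459 | #
460 | #     --quant-dict "{'encoder.layer.11.output': 'per_embd'}"
461 | 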
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
4 | import os
5 | import sys
6 | import time
7 | import random
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 |
13 | from quantization.range_estimators import RangeEstimators
14 |
15 |
16 | def seed_all(seed=1029):
17 | random.seed(seed)
18 | os.environ['PYTHONHASHSEED'] = str(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU
23 | torch.backends.cudnn.benchmark = False
24 | torch.backends.cudnn.deterministic = True
25 |
26 |
27 | def count_params(module):
28 | return len(nn.utils.parameters_to_vector(module.parameters()))
29 |
30 |
31 | def count_embedding_params(model):
32 | return sum(count_params(m) for m in model.modules() if isinstance(m, nn.Embedding))
33 |
34 |
35 | def get_layer_by_name(model, layer_name):
36 | for name, module in model.named_modules():
37 | if name == layer_name:
38 | return module
39 | return None
40 |
41 |
42 | class StopForwardException(Exception):
43 | """Used to throw and catch an exception to stop traversing the graph."""
44 | pass
45 |
46 |
47 | def pass_data_for_range_estimation(
48 | loader, model, act_quant, weight_quant, max_num_batches=20, cross_entropy_layer=None, inp_idx=0
49 | ):
50 | model.set_quant_state(weight_quant, act_quant)
51 | model.eval()
52 |
53 | if cross_entropy_layer is not None:
54 | layer_xent = get_layer_by_name(model, cross_entropy_layer)
55 | if layer_xent:
56 | print(f'Set cross entropy estimator for layer "{cross_entropy_layer}"')
57 | act_quant_mgr = layer_xent.activation_quantizer
58 | act_quant_mgr.range_estimator = RangeEstimators.cross_entropy.cls(
59 | per_channel=act_quant_mgr.per_channel,
60 | quantizer=act_quant_mgr.quantizer,
61 | **act_quant_mgr.init_params,
62 | )
63 | else:
64 | raise ValueError('Cross-entropy layer not found')
65 |
66 | device = next(model.parameters()).device
67 | for i, data in enumerate(loader):
68 | try:
69 | if isinstance(data, (tuple, list)):
70 | x = data[inp_idx].to(device=device)
71 | model(x)
72 | else:
73 | x = {k: v.to(device=device) for k, v in data.items()}
74 | model(**x)
75 | except StopForwardException:
76 | pass
77 |
78 | if i >= max_num_batches - 1 or not act_quant:
79 | break
80 |
81 |
82 | class DotDict(dict):
83 | """
84 | A dictionary that allows attribute-style access.
85 |
86 | Examples
87 | --------
88 | >>> config = DotDict(a=None)
89 | >>> config.a = 42
90 | >>> config.b = 'egg'
91 | >>> config # can be used as dict
92 | {'a': 42, 'b': 'egg'}
93 | """
94 | def __setattr__(self, key, value):
95 | self.__setitem__(key, value)
96 |
97 | def __delattr__(self, key):
98 | self.__delitem__(key)
99 |
100 | def __getattr__(self, key):
101 | if key in self:
102 | return self.__getitem__(key)
103 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})")
104 |
105 |
106 | class Stopwatch:
107 | """
108 | A simple cross-platform context-manager stopwatch.
109 |
110 | Examples
111 | --------
112 | >>> import time
113 | >>> with Stopwatch(verbose=True) as st:
114 | ... time.sleep(0.101) #doctest: +ELLIPSIS
115 | Elapsed time: 0.10... sec
116 | """
117 | def __init__(self, name=None, verbose=False):
118 | self._name = name
119 | self._verbose = verbose
120 |
121 | self._start_time_point = 0.0
122 | self._total_duration = 0.0
123 | self._is_running = False
124 |
125 |         if sys.platform == 'win32':
126 |             # time.clock() was removed in Python 3.8; use perf_counter instead
127 |             self._timer_fn = time.perf_counter
128 |         else:
129 |             # on most other platforms, time.time() has sufficient resolution
130 |             self._timer_fn = time.time
131 |
132 | def __enter__(self, verbose=False):
133 | return self.start()
134 |
135 | def __exit__(self, exc_type, exc_val, exc_tb):
136 | self.stop()
137 | if self._verbose:
138 | self.print()
139 |
140 | def start(self):
141 | if not self._is_running:
142 | self._start_time_point = self._timer_fn()
143 | self._is_running = True
144 | return self
145 |
146 | def stop(self):
147 | if self._is_running:
148 | self._total_duration += self._timer_fn() - self._start_time_point
149 | self._is_running = False
150 | return self
151 |
152 | def reset(self):
153 | self._start_time_point = 0.0
154 | self._total_duration = 0.0
155 | self._is_running = False
156 | return self
157 |
158 | def _update_state(self):
159 | now = self._timer_fn()
160 | self._total_duration += now - self._start_time_point
161 | self._start_time_point = now
162 |
163 | def _format(self):
164 | prefix = f'[{self._name}]' if self._name is not None else 'Elapsed time'
165 | info = f'{prefix}: {self._total_duration:.3f} sec'
166 | return info
167 |
168 | def format(self):
169 | if self._is_running:
170 | self._update_state()
171 | return self._format()
172 |
173 | def print(self):
174 | print(self.format())
175 |
176 | def get_total_duration(self):
177 | if self._is_running:
178 | self._update_state()
179 | return self._total_duration
180 |
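181 | 
182 | 
183 | # Illustrative sketch (not executed): typical use of pass_data_for_range_estimation,
184 | # assuming `quantized_model` exposes set_quant_state() and `train_loader` yields
185 | # either (input, ...) tuples or dicts of tensors (as HuggingFace dataloaders do):
186 | #
187 | #     pass_data_for_range_estimation(
188 | #         loader=train_loader,
189 | #         model=quantized_model,
190 | #         act_quant=True,
191 | #         weight_quant=True,
192 | #         max_num_batches=4,
193 | #         cross_entropy_layer=None,
194 | #     )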
--------------------------------------------------------------------------------