├── .gitignore ├── README.md ├── configs ├── BasePTQ.py └── PTQ4ViT.py ├── example ├── get_int.py ├── test_ablation.py ├── test_all.py └── test_vit.py ├── quant_layers ├── conv.py ├── linear.py └── matmul.py └── utils ├── datasets.py ├── integer.py ├── models.py ├── net_wrap.py └── quant_calib.py /.gitignore: -------------------------------------------------------------------------------- 1 | tmp 2 | *.pyc 3 | __pycache__ 4 | *.pth 5 | .vscode 6 | checkpoints 7 | *.log 8 | *.csv 9 | *.png 10 | *.jpg 11 | output 12 | *.weights 13 | *.tmp.* 14 | data 15 | ckt 16 | *.out 17 | *.zip 18 | *.json 19 | test.ipynb 20 | int_weights -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PTQ4ViT 2 | Post-Training Quantization Framework for Vision Transformers. 3 | We use the twin uniform quantization method to reduce the quantization error on these activation values. 4 | And we use a Hessian guided metric to evaluate different scaling factors, which improves the accuracy of calibration with a small cost. 5 | The quantized vision transformers (ViT, DeiT, and Swin) achieve near-lossless prediction accuracy (less than 0.5\% drop at 8-bit quantization) on the ImageNet classification task. Please read the [paper](https://arxiv.org/abs/2111.12293) for details. 6 | 7 | ## Updates 8 | 9 | *19/07/2022* 10 | Add discussion on Base PTQ, and provide more ablation study results. 11 | 12 | ### Number of Calibration Images 13 | 14 | | Model | W8A8 #ims=32 | W6A6 #ims=32 | W8A8 #ims=128 | W6A6 #ims=128 | 15 | |:------------:|:------------:|:------------:|:-------------:|:-------------:| 16 | | ViT-S/224/32 | 75.58 | 71.91 | 75.54 | 72.29 | 17 | | ViT-S/224 | 81.00 | 78.63 | 80.99 | 78.44 | 18 | | ViT-B/224 | 84.25 | 81.65 | 84.27 | 81.84 | 19 | | ViT-B/384 | 85.83 | 83.35 | 85.81 | 83.84 | 20 | | DeiT-S/224 | 79.47 | 76.28 | 79.41 | 76.51 | 21 | | DeiT-B/224 | 81.48 | 80.25 | 81.54 | 80.30 | 22 | | DeiT-B/384 | 82.97 | 81.55 | 83.01 | 81.67 | 23 | | Swin-T/224 | 81.25 | 80.47 | 81.27 | 80.30 | 24 | | Swin-S/224 | 83.11 | 82.38 | 83.15 | 82.38 | 25 | | Swin-B/224 | 85.15 | 84.01 | 85.17 | 84.15 | 26 | | Swin-B/384 | 86.39 | 85.39 | 86.36 | 85.45 | 27 | 28 | | Model | Time #ims=32 | Time #ims=128 | 29 | |:------------:|:------------:|:-------------:| 30 | | ViT-S/224/32 | 2 min | 5 min | 31 | | ViT-S/224 | 3 min | 7 min | 32 | | ViT-B/224 | 4 min | 13 min | 33 | | ViT-B/384 | 12 min | 43 min | 34 | | DeiT-S/224 | 3 min | 7 min | 35 | | DeiT-B/224 | 4 min | 16 min | 36 | | DeiT-B/384 | 14 min | 52 min | 37 | | Swin-T/224 | 3 min | 9 min | 38 | | Swin-S/224 | 8 min | 17 min | 39 | | Swin-B/224 | 10 min | 23 min | 40 | | Swin-B/384 | 25 min | 69 min | 41 | 42 | One of the targets of PTQ4ViT is to quickly quantize a vision transformer. 43 | We have proposed to pre-compute the output and gradient of each layer and compute the influence of scaling factor candidates in batches to reduce the quantization time. 44 | As demonstrated in the second table, PTQ4ViT can quantize most vision transformers in several minutes using 32 calibration images. 45 | Using 128 calibration images significantly increases the quantization time. 46 | We observe the Top-1 accuracy varies slightly in the first table, demonstrating PTQ4ViT is not very sensitive to the number of calibration images. 47 | 48 | ### Base PTQ 49 | Base PTQ is a simple quantization strategy and serves as a benchmark for our experiments. 
50 | Like PTQ4ViT, we quantize all weights and inputs of fully-connected layers (including the first projection layer and the last prediction layer), as well as all input matrices of matrix multiplication operations.
51 | For fully-connected layers, we use layerwise scaling factors $\Delta_W$ for weight quantization and $\Delta_X$ for input quantization, while for matrix multiplication operations we use $\Delta_A$ and $\Delta_B$ for the quantization of A and B respectively.
52 | 
53 | To get the best scaling factors, we apply a linear grid search over the search space.
54 | Following EasyQuant and Liu et al., we take hyper-parameters $\alpha=0.5$, $\beta = 1.2$, one search round, and use cosine distance as the metric.
55 | Note that in PTQ4ViT, we change the hyper-parameters to $\alpha=0$, $\beta = 1.2$ and three search rounds, which slightly improves the performance.
56 | 
57 | It should be noted that Base PTQ adopts a parallel quantization paradigm, which makes it essentially different from sequential quantization paradigms such as EasyQuant.
58 | In sequential quantization, the input data of the layer currently being quantized is generated with the weights and activations of all previous layers already quantized.
59 | In parallel quantization, by contrast, the input data of the layer currently being quantized is simply the raw output of the previous layer.
60 | 
61 | In practice, we found that sequential quantization of vision transformers suffers from significant accuracy degradation on small calibration datasets.
62 | Parallel quantization, in contrast, remains robust on small calibration datasets.
63 | Therefore, we choose parallel quantization for both Base PTQ and PTQ4ViT.
64 | 
65 | ### More Ablation Study
66 | 
67 | We provide more ablation studies on the hyper-parameters.
68 | It is enough to set the number of quantization intervals $\ge$ 20 (accuracy change $< 0.3\%$).
69 | It is enough to set the upper bound of m $\ge$ 15 (no accuracy change).
70 | The best settings of $\alpha$ and $\beta$ vary across layers.
71 | It is appropriate to set $\alpha=0$ and $\beta=1/2^{k-1}$, which has little impact on search efficiency.
72 | We observe that the number of search rounds has little impact on the prediction accuracy (accuracy change $<$ 0.05\% when search rounds $>1$).
73 | 
74 | We randomly take 32 calibration images and quantize each model 20 times; the observed fluctuation in accuracy is not significant.
75 | The mean/std of accuracies are: ViT-S/32 $75.55\%/0.055\%$, ViT-S $80.96\%/0.046\%$, ViT-B $84.12\%/0.068\%$, DeiT-S $79.45\%/0.094\%$, and Swin-S $83.11\%/0.035\%$.
76 | 
77 | 
78 | *15/01/2022*
79 | Add saved models quantized with PTQ4ViT.
80 | | model | link | 81 | |:------------:|:--------:| 82 | | ViT-S/224/32 | [Google](https://drive.google.com/file/d/195JJJKULvaukte6PA9U08oezjd176CTs/view?usp=sharing) | 83 | | ViT-S/224 | [Google](https://drive.google.com/file/d/14uEDgRmDBYoKoZtpO9IWMfG8Uvkt_OuL/view?usp=sharing) | 84 | | ViT-B/224 | [Google](https://drive.google.com/file/d/1ou6s9Vd-_iyQ7sj7VYET-pRvJA6WMMLA/view?usp=sharing) | 85 | | ViT-B/384 | [Google](https://drive.google.com/file/d/1tuU8or8SfQomtoWam7WFTnUxtuw3n7fs/view?usp=sharing) | 86 | | DeiT-S/224 | [Google](https://drive.google.com/file/d/1673fX-SuiRlHhm7k0Yyyx_3ynwtvUPyf/view?usp=sharing) | 87 | | DeiT-B/224 | [Google](https://drive.google.com/file/d/1WRAtmPF0kDR9iTLc9gv_63aEkOCZ_zOI/view?usp=sharing) | 88 | | DeiT-B/384 | [Google](https://drive.google.com/file/d/1mPPlM2ioe4zts_rdKdjZTCUj8KcbquyA/view?usp=sharing) | 89 | | Swin-T/224 | [Google](https://drive.google.com/file/d/1bSahHgtL3yFaHPlG-SDtu__YY0zJ8lxr/view?usp=sharing) | 90 | | Swin-S/224 | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing) | 91 | | Swin-B/224 | [Google](https://drive.google.com/file/d/19UUUQYJGs5SQaDe27PjY3x1QTBU5hwXm/view?usp=sharing) | 92 | | Swin-B/384 | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing) | 93 | 94 | *10/12/2021* 95 | Add `utils/integer.py`, you can now: 96 | 1. convert calibrated fp32 model into int8 97 | 2. register pre-forward hook in the model, and fetch activation in int8. (We use uint8 to store results 98 | of twin quantization, please refer to the paper to see the bits' layout). 99 | 100 | ## Install 101 | 102 | ### Requirement 103 | - python>=3.5 104 | - pytorch>=1.5 105 | - matplotlib 106 | - pandas 107 | - timm 108 | 109 | ### Datasets 110 | To run example testing, you should put your ImageNet2012 dataset in path `/datasets/imagenet`. 111 | 112 | We use `ViTImageNetLoaderGenerator` in `utils/datasets.py` to initialize our DataLoader. 113 | If your Imagenet datasets are stored elsewhere, you'll need to manually pass its root as an argument when instantiating a `ViTImageNetLoaderGenerator`. 114 | 115 | ## Usage 116 | 117 | ### 1. Run example quantization 118 | To test on all models with BasePTQ/PTQ4ViT, run 119 | ```bash 120 | python example/test_all.py 121 | ``` 122 | 123 | To run ablation testing, run 124 | ```bash 125 | python example/test_ablation.py 126 | ``` 127 | 128 | You can run the testing scripts with multiple GPUs. For example, calling 129 | ```bash 130 | python example/test_all.py --multigpu --n_gpu 6 131 | ``` 132 | will use 6 gpus to run the test. 133 | 134 | ### 2. 
Download quantized model checkpoints 135 | (Coming soon) 136 | 137 | ## Results 138 | ### Results of BasePTQ 139 | 140 | | model | original | w8a8 | w6a6 | 141 | |:------------:|:--------:|:------:|:-------:| 142 | | ViT-S/224/32 | 75.99 | 73.61 | 60.144 | 143 | | ViT-S/224 | 81.39 | 80.468 | 70.244 | 144 | | ViT-B/224 | 84.54 | 83.896 | 75.668 | 145 | | ViT-B/384 | 86.00 | 85.352 | 46.886 | 146 | | DeiT-S/224 | 79.80 | 77.654 | 72.268 | 147 | | DeiT-B/224 | 81.80 | 80.946 | 78.786 | 148 | | DeiT-B/384 | 83.11 | 82.33 | 68.442 | 149 | | Swin-T/224 | 81.39 | 80.962 | 78.456 | 150 | | Swin-S/224 | 83.23 | 82.758 | 81.742 | 151 | | Swin-B/224 | 85.27 | 84.792 | 83.354 | 152 | | Swin-B/384 | 86.44 | 86.168 | 85.226 | 153 | 154 | Results of PTQ4ViT 155 | 156 | | model | original | w8a8 | w6a6 | 157 | |:------------:|:--------:|:------:|:-------:| 158 | | ViT-S/224/32 | 75.99 | 75.582 | 71.908 | 159 | | ViT-S/224 | 81.39 | 81.002 | 78.63 | 160 | | ViT-B/224 | 84.54 | 84.25 | 81.65 | 161 | | ViT-B/384 | 86.00 | 85.828 | 83.348 | 162 | | DeiT-S/224 | 79.80 | 79.474 | 76.282 | 163 | | DeiT-B/224 | 81.80 | 81.482 | 80.25 | 164 | | DeiT-B/384 | 83.11 | 82.974 | 81.55 | 165 | | Swin-T/224 | 81.39 | 81.246 | 80.47 | 166 | | Swin-S/224 | 83.23 | 83.106 | 82.38 | 167 | | Swin-B/224 | 85.27 | 85.146 | 84.012 | 168 | | Swin-B/384 | 86.44 | 86.394 | 85.388 | 169 | 170 | ### Results of Ablation 171 | - ViT-S/224 (original top-1 accuracy 81.39%) 172 | 173 | | Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 | 174 | |:--------------:|:------------:|:---------:|:------:|:-------:| 175 | | | | | 80.47 | 70.24 | 176 | | ✓ | | | 80.93 | 77.20 | 177 | | ✓ | ✓ | | 81.11 | 78.57 | 178 | | ✓ | | ✓ | 80.84 | 76.93 | 179 | | | ✓ | ✓ | 79.25 | 74.07 | 180 | | ✓ | ✓ | ✓ | 81.00 | 78.63 | 181 | 182 | - ViT-B/224 (original top-1 accuracy 84.54%) 183 | 184 | | Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 | 185 | |:--------------:|:------------:|:---------:|:------:|:-------:| 186 | | | | | 83.90 | 75.67 | 187 | | ✓ | | | 83.97 | 79.90 | 188 | | ✓ | ✓ | | 84.07 | 80.76 | 189 | | ✓ | | ✓ | 84.10 | 80.82 | 190 | | | ✓ | ✓ | 83.40 | 78.86 | 191 | | ✓ | ✓ | ✓ | 84.25 | 81.65 | 192 | 193 | - ViT-B/384 (original top-1 accuracy 86.00%) 194 | 195 | | Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 | 196 | |:--------------:|:------------:|:---------:|:------:|:-------:| 197 | | | | | 85.35 | 46.89 | 198 | | ✓ | | | 85.42 | 79.99 | 199 | | ✓ | ✓ | | 85.67 | 82.01 | 200 | | ✓ | | ✓ | 85.60 | 82.21 | 201 | | | ✓ | ✓ | 84.35 | 80.86 | 202 | | ✓ | ✓ | ✓ | 85.89 | 83.19 | 203 | 204 | ## Citation 205 | ``` 206 | @article{PTQ4ViT_arixv2022, 207 | title={PTQ4ViT: Post-Training Quantization Framework for Vision Transformers}, 208 | author={Zhihang Yuan, Chenhao Xue, Yiqi Chen, Qiang Wu, Guangyu Sun}, 209 | journal={arXiv preprint arXiv:2111.12293}, 210 | year={2022}, 211 | } 212 | ``` 213 | -------------------------------------------------------------------------------- /configs/BasePTQ.py: -------------------------------------------------------------------------------- 1 | from quant_layers.conv import PTQSLQuantConv2d, BatchingEasyQuantConv2d 2 | from quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear 3 | from quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul 4 | 5 | bit = 8 6 | conv_fc_name_list = ["qconv", "qlinear_qkv", "qlinear_proj", "qlinear_MLP_1", "qlinear_MLP_2", "qlinear_classifier", "qlinear_reduction"] 7 | matmul_name_list = [ "qmatmul_qk", 
"qmatmul_scorev"] 8 | w_bit = {name: bit for name in conv_fc_name_list} 9 | a_bit = {name: bit for name in conv_fc_name_list} 10 | A_bit = {name: bit for name in matmul_name_list} 11 | B_bit = {name: bit for name in matmul_name_list} 12 | 13 | ptqsl_conv2d_kwargs = { 14 | "metric": "cosine", 15 | "eq_alpha": 0.5, 16 | "eq_beta": 1.2, 17 | "eq_n": 100, 18 | 'search_round': 1, 19 | "n_V": 1, 20 | "n_H": 1, 21 | } 22 | ptqsl_linear_kwargs = { 23 | "metric": "cosine", 24 | "eq_alpha": 0.5, 25 | "eq_beta": 1.2, 26 | "eq_n": 100, 27 | 'search_round': 1, 28 | "n_V": 1, 29 | "n_H": 1, 30 | "n_a": 1, 31 | } 32 | ptqsl_matmul_kwargs = { 33 | "metric": "cosine", 34 | "eq_alpha": 0.5, 35 | "eq_beta": 1.2, 36 | "eq_n": 100, 37 | 'search_round': 1, 38 | "n_G_A": 1, 39 | "n_V_A": 1, 40 | "n_H_A": 1, 41 | "n_G_B": 1, 42 | "n_V_B": 1, 43 | "n_H_B": 1, 44 | } 45 | 46 | 47 | def get_module(module_type, *args, **kwargs): 48 | if module_type == "qconv": 49 | kwargs.update(ptqsl_conv2d_kwargs) 50 | module=BatchingEasyQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization 51 | # module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization 52 | elif "qlinear" in module_type: 53 | kwargs.update(ptqsl_linear_kwargs) 54 | if module_type == "qlinear_qkv": 55 | kwargs["n_V"] *= 3 # q, k, v 56 | module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 57 | else: 58 | module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 59 | elif "qmatmul" in module_type: 60 | kwargs.update(ptqsl_matmul_kwargs) 61 | module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) 62 | return module -------------------------------------------------------------------------------- /configs/PTQ4ViT.py: -------------------------------------------------------------------------------- 1 | from quant_layers.conv import PTQSLQuantConv2d, ChannelwiseBatchingQuantConv2d 2 | from quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear 3 | from quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul 4 | 5 | no_softmax = False 6 | no_postgelu = False 7 | 8 | bit = 8 9 | conv_fc_name_list = ["qconv", "qlinear_qkv", "qlinear_proj", "qlinear_MLP_1", "qlinear_MLP_2", "qlinear_classifier", "qlinear_reduction"] 10 | matmul_name_list = [ "qmatmul_qk", "qmatmul_scorev"] 11 | w_bit = {name: bit for name in conv_fc_name_list} 12 | a_bit = {name: bit for name in conv_fc_name_list} 13 | A_bit = {name: bit for name in matmul_name_list} 14 | B_bit = {name: bit for name in matmul_name_list} 15 | 16 | ptqsl_conv2d_kwargs = { 17 | "metric": "hessian", 18 | "eq_alpha": 0.01, 19 | "eq_beta": 1.2, 20 | "eq_n": 100, 21 | 'search_round': 3, 22 | "n_V": 1, 23 | "n_H": 1, 24 | } 25 | ptqsl_linear_kwargs = { 26 | "metric": "hessian", 27 | "eq_alpha": 0.01, 28 | "eq_beta": 1.2, 29 | "eq_n": 100, 30 | 'search_round': 3, 31 | "n_V": 1, 32 | "n_H": 1, 33 | "n_a": 1, 34 | "bias_correction":True # Conventionally I'll not add an actual bias correction in linear 35 | } 36 | ptqsl_matmul_kwargs = { 37 | "metric": "hessian", 38 | "eq_alpha": 0.01, 39 | "eq_beta": 1.2, 40 | "eq_n": 100, 41 | 'search_round': 3, 42 | "n_G_A": 1, 43 | "n_V_A": 1, 44 | "n_H_A": 1, 45 | "n_G_B": 1, 46 | "n_V_B": 1, 47 | "n_H_B": 1, 48 | } 49 | 50 | 51 | def get_module(module_type, *args, **kwargs): 52 | if module_type == "qconv": 53 | 
kwargs.update(ptqsl_conv2d_kwargs) 54 | module=ChannelwiseBatchingQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization 55 | # module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization 56 | elif "qlinear" in module_type: 57 | kwargs.update(ptqsl_linear_kwargs) 58 | if module_type == "qlinear_qkv": 59 | kwargs["n_V"] *= 3 # q, k, v 60 | module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 61 | elif module_type == "qlinear_MLP_2": 62 | if no_postgelu: 63 | module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 64 | else: 65 | module=PostGeluPTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 66 | elif module_type == "qlinear_classifier": 67 | kwargs["n_V"] = 1 68 | module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 69 | else: 70 | module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) 71 | elif "qmatmul" in module_type: 72 | kwargs.update(ptqsl_matmul_kwargs) 73 | if module_type == "qmatmul_qk": 74 | module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) 75 | elif module_type == "qmatmul_scorev": 76 | if no_softmax: 77 | module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) 78 | else: 79 | module=SoSPTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) 80 | return module -------------------------------------------------------------------------------- /example/get_int.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | sys.path.insert(0,'.') 4 | from example.test_vit import * 5 | import utils.net_wrap as net_wrap 6 | import utils.datasets as datasets 7 | import utils.integer as integer 8 | from utils.quant_calib import HessianQuantCalibrator 9 | 10 | from itertools import product 11 | 12 | def get_int_weights(name, config_name): 13 | quant_cfg = init_config(config_name) 14 | 15 | net = get_net(name) 16 | 17 | wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg) 18 | 19 | g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net}) 20 | test_loader=g.test_loader() 21 | calib_loader=g.calib_loader(num=32) 22 | 23 | quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 24 | quant_calibrator.batching_quant_calib() 25 | 26 | int_weights = integer.get_model_int_weight(wrapped_modules) 27 | torch.save(int_weights, f"./int_weights/{name}.pth") 28 | 29 | 30 | if __name__ == "__main__": 31 | args = parse_args() 32 | 33 | names = [ 34 | # "vit_tiny_patch16_224", 35 | # "vit_small_patch32_224", 36 | # "vit_small_patch16_224", 37 | # "vit_base_patch16_224", 38 | "vit_base_patch16_384", 39 | 40 | # "deit_tiny_patch16_224", 41 | # "deit_small_patch16_224", 42 | # "deit_base_patch16_224", 43 | # "deit_base_patch16_384", 44 | 45 | # "swin_tiny_patch4_window7_224", 46 | # "swin_small_patch4_window7_224", 47 | # "swin_base_patch4_window7_224", 48 | # "swin_base_patch4_window12_384", 49 | ] 50 | config_names = ["PTQ4ViT", "BasePTQ"] 51 | 52 | cfg_list = [] 53 | for name, config in product(names, config_names): 54 | cfg_list.append({"name":name, "config_name":config}) 55 | 56 | if 
args.multiprocess: 57 | multiprocess(get_int_weights, cfg_list, n_gpu=args.n_gpu) 58 | else: 59 | for cfg in cfg_list: 60 | get_int_weights(**cfg) -------------------------------------------------------------------------------- /example/test_ablation.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import module 2 | from test_vit import * 3 | from quant_layers.conv import MinMaxQuantConv2d 4 | from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear 5 | from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul 6 | import matplotlib.pyplot as plt 7 | from utils.net_wrap import wrap_certain_modules_in_net 8 | from tqdm import tqdm 9 | import torch.nn.functional as F 10 | import pickle as pkl 11 | from itertools import product 12 | import types 13 | from utils.quant_calib import HessianQuantCalibrator, QuantCalibrator 14 | from utils.models import get_net 15 | import time 16 | 17 | def test_all_ablation(name, cfg_modifier=lambda x: x, calib_size=32): 18 | quant_cfg = init_config("PTQ4ViT") 19 | quant_cfg = cfg_modifier(quant_cfg) 20 | 21 | net = get_net(name) 22 | 23 | wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg) 24 | 25 | g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net}) 26 | test_loader=g.test_loader() 27 | calib_loader=g.calib_loader(num=calib_size) 28 | 29 | quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 30 | quant_calibrator.batching_quant_calib() 31 | 32 | acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs["metric"]) 33 | 34 | print(f"model: {name} \n") 35 | print(f"calibration size: {calib_size} \n") 36 | print(f"bit settings: {quant_cfg.bit} \n") 37 | print(f"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \n") 38 | print(f"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \n") 39 | print(f"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \n") 40 | print(f"accuracy: {acc} \n\n") 41 | 42 | class cfg_modifier(): 43 | def __init__(self, **kwargs): 44 | for name, value in kwargs.items(): 45 | setattr(self,name,value) 46 | 47 | def __call__(self, cfg): 48 | # bit setting 49 | cfg.bit = self.bit_setting 50 | cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list} 51 | cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list} 52 | cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} 53 | cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} 54 | 55 | # conv2d configs 56 | cfg.ptqsl_conv2d_kwargs["n_V"] = self.linear_ptq_setting[0] 57 | cfg.ptqsl_conv2d_kwargs["n_H"] = self.linear_ptq_setting[1] 58 | cfg.ptqsl_conv2d_kwargs["metric"] = self.metric 59 | cfg.ptqsl_conv2d_kwargs["search_round"] = self.search_round 60 | cfg.ptqsl_conv2d_kwargs["parallel_eq_n"] = 1 # maximum 7 , reserve 4Gb for gradient 61 | cfg.ptqsl_conv2d_kwargs["init_layerwise"] = False 62 | 63 | # linear configs 64 | cfg.ptqsl_linear_kwargs["n_V"] = self.linear_ptq_setting[0] 65 | cfg.ptqsl_linear_kwargs["n_H"] = self.linear_ptq_setting[1] 66 | cfg.ptqsl_linear_kwargs["n_a"] = self.linear_ptq_setting[2] 67 | cfg.ptqsl_linear_kwargs["metric"] = self.metric 68 | cfg.ptqsl_linear_kwargs["search_round"] = self.search_round 69 | cfg.ptqsl_linear_kwargs["parallel_eq_n"] = 1 # maximum 7, reserve 4Gb for gradient 70 | cfg.ptqsl_linear_kwargs["init_layerwise"] = False 71 | 72 | # 
matmul configs 73 | cfg.ptqsl_matmul_kwargs["metric"] = self.metric 74 | cfg.ptqsl_matmul_kwargs["search_round"] = self.search_round 75 | cfg.ptqsl_matmul_kwargs["parallel_eq_n"] = 1 # maximum 3! 76 | cfg.ptqsl_matmul_kwargs["init_layerwise"] = False 77 | 78 | # ablation 79 | cfg.no_softmax = self.no_softmax 80 | cfg.no_postgelu = self.no_postgelu 81 | 82 | return cfg 83 | 84 | if __name__=='__main__': 85 | args = parse_args() 86 | 87 | names = [ 88 | "vit_small_patch16_224", 89 | "vit_base_patch16_224", 90 | "vit_base_patch16_384", 91 | ] 92 | metrics = ["hessian", "cosine"] 93 | linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a 94 | search_rounds = [3] 95 | calib_sizes = [32] 96 | bit_settings = [(8,8), (6,6)] # weight, activation 97 | no_softmaxs = [True, False] 98 | no_postgelus = [True, False] 99 | 100 | cfg_list = [] 101 | for name, metric, linear_ptq_setting, search_round, calib_size, bit_setting, no_softmax, no_postgelu in product(names, metrics, linear_ptq_settings, search_rounds, calib_sizes, bit_settings, no_softmaxs, no_postgelus): 102 | cfg_list.append({ 103 | "name": name, 104 | "cfg_modifier":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, search_round=search_round, bit_setting=bit_setting, no_softmax=no_softmax, no_postgelu=no_postgelu), 105 | "calib_size":calib_size, 106 | }) 107 | 108 | if args.multiprocess: 109 | multiprocess(test_all_ablation, cfg_list, n_gpu=args.n_gpu) 110 | else: 111 | for cfg in cfg_list: 112 | test_all_ablation(**cfg) -------------------------------------------------------------------------------- /example/test_all.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers import config 2 | from torch.nn.modules import module 3 | from test_vit import * 4 | from quant_layers.conv import MinMaxQuantConv2d 5 | from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear 6 | from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul 7 | import matplotlib.pyplot as plt 8 | from utils.net_wrap import wrap_certain_modules_in_net 9 | from tqdm import tqdm 10 | import torch.nn.functional as F 11 | import pickle as pkl 12 | from itertools import product 13 | import types 14 | from utils.quant_calib import HessianQuantCalibrator, QuantCalibrator 15 | from utils.models import get_net 16 | import time 17 | 18 | def test_all(name, cfg_modifier=lambda x: x, calib_size=32, config_name="PTQ4ViT"): 19 | quant_cfg = init_config(config_name) 20 | quant_cfg = cfg_modifier(quant_cfg) 21 | 22 | net = get_net(name) 23 | 24 | wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg) 25 | 26 | g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net}) 27 | test_loader=g.test_loader() 28 | calib_loader=g.calib_loader(num=calib_size) 29 | 30 | # add timing 31 | calib_start_time = time.time() 32 | quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 33 | quant_calibrator.batching_quant_calib() 34 | calib_end_time = time.time() 35 | 36 | acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs["metric"]) 37 | 38 | print(f"model: {name} \n") 39 | print(f"calibration size: {calib_size} \n") 40 | print(f"bit settings: {quant_cfg.bit} \n") 41 | print(f"config: {config_name} \n") 42 | print(f"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \n") 43 | print(f"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \n") 44 | 
print(f"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \n") 45 | print(f"calibration time: {(calib_end_time-calib_start_time)/60}min \n") 46 | print(f"accuracy: {acc} \n\n") 47 | 48 | class cfg_modifier(): 49 | def __init__(self, **kwargs): 50 | for name, value in kwargs.items(): 51 | setattr(self,name,value) 52 | 53 | def __call__(self, cfg): 54 | # bit setting 55 | cfg.bit = self.bit_setting 56 | cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list} 57 | cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list} 58 | cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} 59 | cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} 60 | 61 | # conv2d configs 62 | cfg.ptqsl_conv2d_kwargs["n_V"] = self.linear_ptq_setting[0] 63 | cfg.ptqsl_conv2d_kwargs["n_H"] = self.linear_ptq_setting[1] 64 | cfg.ptqsl_conv2d_kwargs["metric"] = self.metric 65 | cfg.ptqsl_conv2d_kwargs["init_layerwise"] = False 66 | 67 | # linear configs 68 | cfg.ptqsl_linear_kwargs["n_V"] = self.linear_ptq_setting[0] 69 | cfg.ptqsl_linear_kwargs["n_H"] = self.linear_ptq_setting[1] 70 | cfg.ptqsl_linear_kwargs["n_a"] = self.linear_ptq_setting[2] 71 | cfg.ptqsl_linear_kwargs["metric"] = self.metric 72 | cfg.ptqsl_linear_kwargs["init_layerwise"] = False 73 | 74 | # matmul configs 75 | cfg.ptqsl_matmul_kwargs["metric"] = self.metric 76 | cfg.ptqsl_matmul_kwargs["init_layerwise"] = False 77 | 78 | return cfg 79 | 80 | if __name__=='__main__': 81 | args = parse_args() 82 | 83 | names = [ 84 | "vit_tiny_patch16_224", 85 | "vit_small_patch32_224", 86 | "vit_small_patch16_224", 87 | "vit_base_patch16_224", 88 | "vit_base_patch16_384", 89 | 90 | "deit_tiny_patch16_224", 91 | "deit_small_patch16_224", 92 | "deit_base_patch16_224", 93 | "deit_base_patch16_384", 94 | 95 | "swin_tiny_patch4_window7_224", 96 | "swin_small_patch4_window7_224", 97 | "swin_base_patch4_window7_224", 98 | "swin_base_patch4_window12_384", 99 | ] 100 | metrics = ["hessian"] 101 | linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a 102 | calib_sizes = [32,128] 103 | bit_settings = [(8,8), (6,6)] # weight, activation 104 | config_names = ["PTQ4ViT", "BasePTQ"] 105 | 106 | cfg_list = [] 107 | for name, metric, linear_ptq_setting, calib_size, bit_setting, config_name in product(names, metrics, linear_ptq_settings, calib_sizes, bit_settings, config_names): 108 | cfg_list.append({ 109 | "name": name, 110 | "cfg_modifier":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, bit_setting=bit_setting), 111 | "calib_size":calib_size, 112 | "config_name": config_name 113 | }) 114 | 115 | if args.multiprocess: 116 | multiprocess(test_all, cfg_list, n_gpu=args.n_gpu) 117 | else: 118 | for cfg in cfg_list: 119 | test_all(**cfg) -------------------------------------------------------------------------------- /example/test_vit.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | sys.path.insert(0,'.') 4 | import torch 5 | import torch.nn as nn 6 | from tqdm import tqdm 7 | import argparse 8 | from importlib import reload,import_module 9 | import multiprocessing 10 | import os 11 | import time 12 | from itertools import product 13 | 14 | import utils.datasets as datasets 15 | import utils.net_wrap as net_wrap 16 | from utils.quant_calib import QuantCalibrator, HessianQuantCalibrator 17 | from utils.models import get_net 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | 
parser.add_argument("--n_gpu", type=int, default=6) 22 | parser.add_argument("--multiprocess", action='store_true') 23 | args = parser.parse_args() 24 | return args 25 | 26 | def test_classification(net,test_loader,max_iteration=None, description=None): 27 | pos=0 28 | tot=0 29 | i = 0 30 | max_iteration = len(test_loader) if max_iteration is None else max_iteration 31 | with torch.no_grad(): 32 | q=tqdm(test_loader, desc=description) 33 | for inp,target in q: 34 | i+=1 35 | inp=inp.cuda() 36 | target=target.cuda() 37 | out=net(inp) 38 | pos_num=torch.sum(out.argmax(1)==target).item() 39 | pos+=pos_num 40 | tot+=inp.size(0) 41 | q.set_postfix({"acc":pos/tot}) 42 | if i >= max_iteration: 43 | break 44 | print(pos/tot) 45 | return pos/tot 46 | 47 | def process(pid, experiment_process, args_queue, n_gpu): 48 | """ 49 | worker process. 50 | """ 51 | gpu_id=pid%n_gpu 52 | os.environ['CUDA_VISIBLE_DEVICES']=f'{gpu_id}' 53 | 54 | tot_run=0 55 | while args_queue.qsize(): 56 | test_args=args_queue.get() 57 | print(f"Run {test_args} on pid={pid} gpu_id={gpu_id}") 58 | experiment_process(**test_args) 59 | time.sleep(0.5) 60 | tot_run+=1 61 | # run_experiment(**args) 62 | print(f"{pid} tot_run {tot_run}") 63 | 64 | 65 | def multiprocess(experiment_process, cfg_list=None, n_gpu=6): 66 | """ 67 | run experiment processes on "n_gpu" cards via "n_gpu" worker process. 68 | "cfg_list" arranges kwargs for each test point, and worker process will fetch kwargs and carry out an experiment. 69 | """ 70 | args_queue = multiprocessing.Queue() 71 | for cfg in cfg_list: 72 | args_queue.put(cfg) 73 | 74 | ps=[] 75 | for pid in range(n_gpu): 76 | p=multiprocessing.Process(target=process,args=(pid,experiment_process,args_queue,n_gpu)) 77 | p.start() 78 | ps.append(p) 79 | for p in ps: 80 | p.join() 81 | 82 | def init_config(config_name): 83 | """initialize the config. Use reload to make sure it's fresh one!""" 84 | _,_,files = next(os.walk("./configs")) 85 | if config_name+".py" in files: 86 | quant_cfg = import_module(f"configs.{config_name}") 87 | else: 88 | raise NotImplementedError(f"Invalid config name {config_name}") 89 | reload(quant_cfg) 90 | return quant_cfg 91 | 92 | 93 | def experiment_basic(net='vit_base_patch16_384', config="PTQ4ViT"): 94 | """ 95 | A basic testbench. 
96 | """ 97 | quant_cfg = init_config(config) 98 | net = get_net(net) 99 | wrapped_modules = net_wrap.wrap_modules_in_net(net,quant_cfg) 100 | 101 | g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16,kwargs={"model":net}) 102 | test_loader=g.test_loader() 103 | calib_loader=g.calib_loader(num=32) 104 | 105 | quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 106 | quant_calibrator.batching_quant_calib() 107 | 108 | test_classification(net,test_loader) 109 | 110 | if __name__=='__main__': 111 | args = parse_args() 112 | cfg_list = [] 113 | 114 | nets = ['vit_tiny_patch16_224', "deit_base_patch16_384"] 115 | configs= ['PTQ4ViT'] 116 | 117 | cfg_list = [{ 118 | "net":net, 119 | "config":config, 120 | } 121 | for net, config in product(nets, configs) 122 | ] 123 | 124 | if args.multiprocess: 125 | multiprocess(experiment_basic, cfg_list, n_gpu=args.n_gpu) 126 | else: 127 | for cfg in cfg_list: 128 | experiment_basic(**cfg) 129 | 130 | 131 | -------------------------------------------------------------------------------- /quant_layers/conv.py: -------------------------------------------------------------------------------- 1 | from numpy import not_equal 2 | from torch import tensor 3 | from quant_layers.linear import MinMaxQuantLinear 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from itertools import product 8 | 9 | class MinMaxQuantConv2d(nn.Conv2d): 10 | """ 11 | MinMax quantize weight and output 12 | """ 13 | def __init__(self,in_channels: int, 14 | out_channels: int, 15 | kernel_size, 16 | stride = 1, 17 | padding = 0, 18 | dilation = 1, 19 | groups: int = 1, 20 | bias: bool = True, 21 | padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None): 22 | super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode) 23 | self.n_calibration_steps=2 24 | self.mode=mode 25 | self.w_bit=w_bit 26 | self.a_bit=a_bit 27 | self.bias_bit=bias_bit 28 | assert bias_bit is None,"No support bias bit now" 29 | self.w_interval=None 30 | self.a_interval=None 31 | self.bias_interval=None 32 | self.raw_input=None 33 | self.raw_out=None 34 | self.metric=None 35 | self.next_nodes=[] 36 | self.w_qmax=2**(self.w_bit-1) 37 | self.a_qmax=2**(self.a_bit-1) 38 | # self.bias_qmax=2**(self.bias_bit-1) 39 | 40 | def forward(self, x): 41 | if self.mode=='raw': 42 | out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 43 | elif self.mode=="quant_forward": 44 | out=self.quant_forward(x) 45 | elif self.mode=="calibration_step1": 46 | out=self.calibration_step1(x) 47 | elif self.mode=="calibration_step2": 48 | out=self.calibration_step2(x) 49 | else: 50 | raise NotImplementedError 51 | return out 52 | 53 | def quant_weight_bias(self): 54 | w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1) 55 | w_sim=w.mul_(self.w_interval) 56 | if self.bias is not None: 57 | return w_sim,self.bias 58 | # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1) 59 | # bias_sim=bias*self.bias_interval 60 | # return w_sim,bias_sim 61 | else: 62 | return w_sim,None 63 | 64 | def quant_input(self,x): 65 | x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1) 66 | x_sim.mul_(self.a_interval) 67 | return x_sim 68 | 69 | def quant_forward(self,x): 70 | assert self.calibrated is not None,f"You should run calibrate_forward before run 
quant_forward for {self}" 71 | w_sim,bias_sim=self.quant_weight_bias() 72 | x_sim=self.quant_input(x) 73 | out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) 74 | return out 75 | 76 | def calibration_step1(self,x): 77 | # step1: collection the FP32 values 78 | out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 79 | self.raw_input=x.cpu().detach() 80 | self.raw_out=out.cpu().detach() 81 | return out 82 | 83 | def calibration_step2(self,x): 84 | # step2: search for the best S^w and S^a of each layer 85 | self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach() 86 | self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach() 87 | self.calibrated=True 88 | out=self.quant_forward(x) 89 | return out 90 | 91 | class QuantileQuantConv2d(MinMaxQuantConv2d): 92 | """ 93 | Quantile quantize weight and output 94 | """ 95 | def __init__(self, 96 | in_channels: int, 97 | out_channels: int, 98 | kernel_size, 99 | stride = 1, 100 | padding = 0, 101 | dilation = 1, 102 | groups: int = 1, 103 | bias: bool = True, 104 | padding_mode: str = 'zeros', 105 | mode='raw',w_bit=8,a_bit=8,bias_bit=None, 106 | w_quantile=0.9999,a_quantile=0.9999): 107 | super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode,mode,w_bit,a_bit,bias_bit) 108 | self.w_quantile = w_quantile 109 | self.a_quantile = a_quantile 110 | 111 | def _quantile(self, tensor, quantile): 112 | if tensor.numel() >= 16777216: 113 | n = tensor.numel()//16777216 114 | return torch.quantile(tensor.view(-1)[:16777216*n].view(n,16777216),quantile,1).mean() 115 | else: 116 | return torch.quantile(tensor,quantile) 117 | 118 | def calibration_step2(self,x): 119 | # step2: search for the best S^w and S^o of each layer 120 | self.w_interval=(self._quantile(self.weight.data.abs(),self.w_quantile)/(self.w_qmax-0.5)).detach() 121 | self.a_interval=(self._quantile(x.abs(),self.a_quantile)/(self.a_qmax-0.5)).detach() 122 | self.calibrated=True 123 | out=self.quant_forward(x) 124 | return out 125 | 126 | class PTQSLQuantConv2d(MinMaxQuantConv2d): 127 | """ 128 | PTQSL on Conv2d 129 | weight: (oc,ic,kw,kh) -> (oc,ic*kw*kh) -> divide into sub-matrixs and quantize 130 | input: (B,ic,W,H), keep this shape 131 | 132 | Only support SL quantization on weights. 133 | """ 134 | def __init__(self, in_channels: int, 135 | out_channels: int, 136 | kernel_size, 137 | stride = 1, 138 | padding = 0, 139 | dilation = 1, 140 | groups: int = 1, 141 | bias: bool = True, 142 | padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None, 143 | metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 144 | n_V=1, n_H=1, init_layerwise=False): 145 | super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit) 146 | self.metric = metric 147 | self.search_round = search_round 148 | self.eq_alpha = eq_alpha 149 | self.eq_beta = eq_beta 150 | self.eq_n = eq_n 151 | self.parallel_eq_n = parallel_eq_n 152 | self.n_H = n_H 153 | self.n_V = n_V 154 | self.init_layerwise = init_layerwise 155 | self.raw_grad = None 156 | 157 | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1): 158 | """ 159 | tensor_raw: *, features 160 | tensor_sim: *, features 161 | similarity: * 162 | It's your job to calculate mean on * dims! 
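Supported metrics: "cosine", "L1_norm", "L2_norm", "linear_weighted_L2_norm", "square_weighted_L2_norm", and "hessian" (which weights the raw-vs-quantized output difference by self.raw_grad).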
163 | """ 164 | if metric == "cosine": 165 | similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) 166 | else: 167 | if metric == "L1_norm": 168 | similarity = -torch.abs(tensor_raw - tensor_sim) 169 | elif metric == "L2_norm": 170 | similarity = -(tensor_raw - tensor_sim) ** 2 171 | elif metric == "linear_weighted_L2_norm": 172 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 173 | elif metric == "square_weighted_L2_norm": 174 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 175 | elif metric == "hessian": 176 | raw_grad = self.raw_grad.reshape_as(tensor_raw) 177 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 178 | else: 179 | raise NotImplementedError(f"metric {metric} not implemented!") 180 | similarity = torch.mean(similarity, dim=dim) 181 | return similarity 182 | 183 | def quant_weight_bias(self): 184 | # self.weight_interval shape: n_V, 1, n_H, 1 185 | oc,ic,kw,kh=self.weight.data.shape 186 | w_sim = self.weight.view(self.n_V, oc//self.n_V, self.n_H, (ic*kw*kh)//self.n_H) 187 | w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval) 188 | w_sim = w_sim.view(oc,ic,kw,kh) 189 | return w_sim, self.bias 190 | 191 | def _search_best_w_interval(self, x, weight_interval_candidates): 192 | """ 193 | Modularization of searching best weight intervals 194 | """ 195 | tmp_w_interval = self.w_interval.unsqueeze(0) 196 | for v,h in product(range(self.n_V), range(self.n_H)): 197 | similarities = [] 198 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 199 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 200 | cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1) 201 | cur_w_interval[:,v:v+1,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,v:v+1,:,h:h+1,:] 202 | # quantize weight and bias 203 | oc,ic,kw,kh=self.weight.data.shape 204 | w_sim = self.weight.view(self.n_V,oc//self.n_V,self.n_H,-1).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols 205 | w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols 206 | w_sim = w_sim.view(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh 207 | bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None 208 | # quantize input 209 | x_sim = self.quant_input(x) 210 | # calculate similarity and store them 211 | out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: B,parallel_eq_n*oc,fw,fh 212 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: B,parallel_eq_n,oc,fw,fh 213 | similarity = self._get_similarity(self.raw_out, out_sim, self.metric, dim=2) # shape: B,parallel_eq_n,fw,fh 214 | similarity = torch.mean(similarity, [0,2,3]) # shape: parallel_eq_n 215 | similarities.append(similarity) 216 | # store best weight interval of h into tmp_w_interval 217 | similarities = torch.cat(similarities, dim=0) # shape: eq_n 218 | best_index = similarities.argmax(dim=0).reshape(-1,1,1,1,1) 219 | tmp_w_interval[:,v:v+1,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,v:v+1,:,h:h+1,:],dim=0,index=best_index) 220 | self.w_interval = tmp_w_interval.squeeze(dim=0) 221 | 222 | def _search_best_a_interval(self, x, input_interval_candidates): 223 | similarities = [] 224 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 225 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 226 | cur_a_interval = input_interval_candidates[p_st:p_ed] 227 | # quantize weight and bias 228 | 
w_sim, bias_sim = self.quant_weight_bias() 229 | # quantize input 230 | B,ic,iw,ih = x.shape 231 | x_sim=x.unsqueeze(0) # shape: 1,B,ic,iw,ih 232 | x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,B,ic,iw,ih 233 | x_sim=x_sim.view(-1,ic,iw,ih) 234 | # calculate similarity and store them 235 | out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*B,oc,fw,fh 236 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,B,oc,fw,fh 237 | similarity = self._get_similarity(self.raw_out.transpose(0,1), out_sim, self.metric, dim=2) # shape: parallel_eq_n,B,fw,fh 238 | similarity = torch.mean(similarity, dim=[1,2,3]) # shape: parallel_eq_n 239 | similarities.append(similarity) 240 | # store best input interval and store in tmp_a_interval 241 | similarities = torch.cat(similarities, dim=0) # shape: eq_n 242 | a_best_index = similarities.argmax(dim=0).view(1,1,1,1,1) 243 | self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze() 244 | 245 | 246 | def _initialize_intervals(self, x): 247 | self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach() 248 | if self.init_layerwise: 249 | self.w_interval = ((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) 250 | else: 251 | self.w_interval = (self.weight.view(self.n_V,self.out_channels//self.n_V,self.n_H,-1).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) 252 | 253 | def calibration_step2(self, x): 254 | # initialize intervals with minmax intervals 255 | self._initialize_intervals(x) 256 | 257 | # put raw outs on GPU 258 | self.raw_out = self.raw_out.to(x.device).unsqueeze(1) # shape: B,1,oc,W,H 259 | 260 | # put raw grad on GPU 261 | self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None 262 | 263 | # prepare weight intervals and similarities 264 | weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1 265 | input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: nq_n,1,1,1,1 266 | for e in range(self.search_round): 267 | # search for best weight interval 268 | self._search_best_w_interval(x, weight_interval_candidates) 269 | # search for best input interval 270 | self._search_best_a_interval(x, input_interval_candidates) 271 | 272 | self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None 273 | 274 | self.calibrated = True 275 | out=self.quant_forward(x) 276 | del self.raw_input, self.raw_out, self.raw_grad 277 | return out 278 | 279 | class BatchingEasyQuantConv2d(PTQSLQuantConv2d): 280 | """An agile implementation of Layerwise Easyquant""" 281 | def __init__(self, in_channels: int, 282 | out_channels: int, 283 | kernel_size, 284 | stride = 1, 285 | padding = 0, 286 | dilation = 1, 287 | groups: int = 1, 288 | bias: bool = True, 289 | padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None, 290 | metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 291 | n_V=1, n_H=1, init_layerwise=False): 292 | super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, 
padding_mode=padding_mode, 293 | mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_V=n_V, n_H=n_H, init_layerwise=init_layerwise) 294 | self.n_V = 1 295 | self.n_H = 1 296 | 297 | def _initialize_calib_parameters(self): 298 | """ 299 | set parameters for feeding calibration data 300 | """ 301 | self.calib_size = int(self.raw_input.shape[0]) 302 | self.calib_batch_size = int(self.raw_input.shape[0]) 303 | while True: 304 | numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU 305 | self.parallel_eq_n = int((15*1024*1024*1024/4)//numel) 306 | if self.parallel_eq_n <= 1: 307 | self.calib_need_batching = True 308 | self.calib_batch_size //= 2 309 | else: 310 | break 311 | 312 | def _initialize_intervals(self): 313 | self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach() 314 | tmp_a_intervals = [] 315 | for b_st in range(0,self.calib_size,self.calib_batch_size): 316 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 317 | x_ = self.raw_input[b_st:b_ed].cuda() 318 | a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1) 319 | tmp_a_intervals.append(a_interval_) 320 | self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False) 321 | 322 | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None): 323 | """ 324 | tensor_raw: *, features 325 | tensor_sim: *, features 326 | similarity: * 327 | It's your job to calculate mean on * dims! 328 | """ 329 | if metric == "cosine": 330 | similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) 331 | elif metric == "pearson": 332 | # calculate similarity w.r.t complete feature map, but maintain dimension requirement 333 | b, parallel_eq_n = tensor_sim.shape[0], tensor_sim.shape[1] 334 | similarity = F.cosine_similarity(tensor_raw.view(b,1,-1), tensor_sim.view(b,parallel_eq_n,-1), dim=dim).view(b,parallel_eq_n,1,1) 335 | else: 336 | if metric == "L1_norm": 337 | similarity = -torch.abs(tensor_raw - tensor_sim) 338 | elif metric == "L2_norm": 339 | similarity = -(tensor_raw - tensor_sim) ** 2 340 | elif metric == "linear_weighted_L2_norm": 341 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 342 | elif metric == "square_weighted_L2_norm": 343 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 344 | elif metric == "hessian": 345 | assert raw_grad != None, f"No raw grad!" 
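# Hessian-guided metric: the output quantization error is weighted by the gradient of the task loss w.r.t. the raw output, so -(raw_grad * (tensor_raw - tensor_sim))**2 approximates how much this candidate interval perturbs the task loss (cf. the Hessian-guided metric described in the README/paper).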
346 | raw_grad = raw_grad.reshape_as(tensor_raw) 347 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 348 | else: 349 | raise NotImplementedError(f"metric {metric} not implemented!") 350 | similarity = torch.mean(similarity, dim=dim) 351 | return similarity 352 | 353 | def quant_weight_bias(self): 354 | w_sim = self.weight 355 | w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval) 356 | return w_sim, self.bias 357 | 358 | def quant_forward(self, x): 359 | assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" 360 | w_sim,bias_sim=self.quant_weight_bias() 361 | x_sim=self.quant_input(x) if self.a_bit < 32 else x 362 | out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) 363 | return out 364 | 365 | def _search_best_w_interval(self, weight_interval_candidates): 366 | batch_similarities = [] 367 | for b_st in range(0,self.calib_size,self.calib_batch_size): 368 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 369 | x = self.raw_input[b_st:b_ed].cuda() 370 | raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh 371 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 372 | similarities = [] 373 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 374 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 375 | cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1 376 | # quantize weight and bias 377 | oc,ic,kw,kh = self.weight.data.shape 378 | w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh 379 | w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh 380 | w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh 381 | bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None 382 | # quantize input 383 | x_sim = self.quant_input(x) 384 | # calculate similarity and store them 385 | out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh 386 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh 387 | similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: b,parallel_eq_n,fw,fh 388 | similarity = torch.mean(similarity, [2,3]) # shape: b,parallel_eq_n 389 | similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n 390 | similarities.append(similarity) 391 | # store best weight interval of h into tmp_w_interval 392 | similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n 393 | batch_similarities.append(similarities) 394 | batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n 395 | best_index = batch_similarities.argmax(dim=0).reshape(1,1,1,1,1) # shape: 1,1,1,1,1 396 | self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0) 397 | 398 | def _search_best_a_interval(self, input_interval_candidates): 399 | batch_similarities = [] 400 | for b_st in range(0,self.calib_size,self.calib_batch_size): 401 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 402 | x = self.raw_input[b_st:b_ed].cuda() 403 | raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(0) # shape: 1,b,oc,fw,fh 404 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 405 | similarities = [] 406 | for p_st in 
range(0,self.eq_n,self.parallel_eq_n): 407 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 408 | cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1 409 | # quantize weight and bias 410 | w_sim, bias_sim = self.quant_weight_bias() 411 | # quantize input 412 | B,ic,iw,ih = x.shape 413 | x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih 414 | x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih 415 | x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih 416 | # calculate similarity and store them 417 | out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh 418 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh 419 | similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: parallel_eq_n,b,fw,fh 420 | similarity = torch.mean(similarity, dim=[3,4]) # shape: parallel_eq_n,b 421 | similarity = torch.sum(similarity, dim=1, keepdim=True) # shape: parallel_eq_n,1 422 | similarities.append(similarity) 423 | similarities = torch.cat(similarities, dim=0) # shape: eq_n, 1 424 | batch_similarities.append(similarities) 425 | batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n 426 | a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1) 427 | self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze() 428 | 429 | def calibration_step2(self): 430 | self._initialize_calib_parameters() 431 | self._initialize_intervals() 432 | weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval # shape: eq_n,1,1,1,1 433 | input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1 434 | for e in range(self.search_round): 435 | # search for best weight interval 436 | self._search_best_w_interval(weight_interval_candidates) 437 | # search for best input interval 438 | if self.a_bit < 32: 439 | self._search_best_a_interval(input_interval_candidates) 440 | self.calibrated = True 441 | del self.raw_input, self.raw_out, self.raw_grad 442 | 443 | 444 | class ChannelwiseBatchingQuantConv2d(PTQSLQuantConv2d): 445 | """ 446 | Only implemented acceleration with batching_calibration_step2 447 | 448 | setting a_bit to >= 32 will use minmax quantization, which means turning off activation quantization 449 | """ 450 | def __init__(self, in_channels: int, 451 | out_channels: int, 452 | kernel_size, 453 | stride = 1, 454 | padding = 0, 455 | dilation = 1, 456 | groups: int = 1, 457 | bias: bool = True, 458 | padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None, 459 | metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 460 | n_V=1, n_H=1, init_layerwise=False): 461 | super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, 462 | mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, 463 | n_V=n_V, n_H=n_H, init_layerwise=init_layerwise) 
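# Channel-wise weight quantization: use one vertical slice per output channel so each output channel searches its own weight scaling factor; horizontal slicing is disabled.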
464 | self.n_V = self.out_channels 465 | self.n_H = 1 466 | 467 | def _initialize_calib_parameters(self): 468 | """ 469 | set parameters for feeding calibration data 470 | """ 471 | self.calib_size = int(self.raw_input.shape[0]) 472 | self.calib_batch_size = int(self.raw_input.shape[0]) 473 | while True: 474 | numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU 475 | self.parallel_eq_n = int((15*1024*1024*1024/4)//numel) 476 | if self.parallel_eq_n <= 1: 477 | self.calib_need_batching = True 478 | self.calib_batch_size //= 2 479 | else: 480 | break 481 | 482 | def _initialize_intervals(self): 483 | # weight intervals: shape oc,1,1,1 484 | if self.init_layerwise: 485 | self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.out_channels,1,1,1) 486 | else: 487 | self.w_interval=((self.weight.abs().amax([1,2,3],keepdim=True))/(self.w_qmax-0.5)) 488 | 489 | # activation intervals: shape 1 490 | tmp_a_intervals = [] 491 | for b_st in range(0,self.calib_size,self.calib_batch_size): 492 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 493 | x_ = self.raw_input[b_st:b_ed].cuda() 494 | a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1) 495 | tmp_a_intervals.append(a_interval_) 496 | self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False) 497 | 498 | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None): 499 | """ 500 | tensor_raw: *, features 501 | tensor_sim: *, features 502 | similarity: *, features 503 | """ 504 | if metric == "cosine": 505 | # support cosine on patch dim, which is sub-optimal 506 | # not supporting search best a interval 507 | b, parallel_eq_n, oc = tensor_sim.shape[0], tensor_sim.shape[1], tensor_sim.shape[2] 508 | similarity = F.cosine_similarity(tensor_raw.view(b,1,oc,-1), tensor_sim.view(b,parallel_eq_n,oc,-1), dim=-1).view(b,parallel_eq_n,oc,1,1) 509 | else: 510 | if metric == "L1_norm": 511 | similarity = -torch.abs(tensor_raw - tensor_sim) 512 | elif metric == "L2_norm": 513 | similarity = -(tensor_raw - tensor_sim) ** 2 514 | elif metric == "linear_weighted_L2_norm": 515 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 516 | elif metric == "square_weighted_L2_norm": 517 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 518 | elif metric == "hessian": 519 | assert raw_grad != None, f"raw_grad is None in _get_similarity!" 
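# Unlike the layerwise variant above, no reduction over channels is applied here: per-channel similarity is kept so _search_best_w_interval can pick a separate best interval for every output channel.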
520 | raw_grad = raw_grad.reshape_as(tensor_raw) 521 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 522 | else: 523 | raise NotImplementedError(f"metric {metric} not implemented!") 524 | return similarity 525 | 526 | def _search_best_w_interval(self, weight_interval_candidates): 527 | batch_similarities = [] 528 | for b_st in range(0,self.calib_size,self.calib_batch_size): 529 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 530 | x = self.raw_input[b_st:b_ed].cuda() 531 | raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh 532 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 533 | similarities = [] 534 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 535 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 536 | cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,oc,1,1,1 537 | # quantize weight and bias 538 | oc,ic,kw,kh = self.weight.data.shape 539 | w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh 540 | w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh 541 | w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh 542 | bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None 543 | # quantize input 544 | x_sim = self.quant_input(x) if self.a_bit < 32 else x 545 | # calculate similarity and store them 546 | out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh 547 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh 548 | similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad) # shape: b,parallel_eq_n,oc,fw,fh 549 | similarity = torch.mean(similarity, [3,4]) # shape: b,parallel_eq_n,oc 550 | similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n, oc 551 | similarities.append(similarity) 552 | # store best weight interval of h into tmp_w_interval 553 | similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n,oc 554 | batch_similarities.append(similarities) 555 | batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n,oc 556 | best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,oc,1,1,1 557 | self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0) 558 | 559 | def _search_best_a_interval(self, input_interval_candidates): 560 | batch_similarities = [] 561 | for b_st in range(0,self.calib_size,self.calib_batch_size): 562 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 563 | x = self.raw_input[b_st:b_ed].cuda() 564 | raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh 565 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 566 | similarities = [] 567 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 568 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 569 | cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1 570 | # quantize weight and bias 571 | w_sim, bias_sim = self.quant_weight_bias() 572 | # quantize input 573 | B,ic,iw,ih = x.shape 574 | x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih 575 | x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih 576 | x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih 577 | # calculate similarity and store them 578 | out_sim = 
F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh 579 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh 580 | out_sim = out_sim.transpose_(0, 1) # shape: b,parallel_eq_n,oc,fw,fh 581 | similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: b,parallel_eq_n,oc,fw,fh 582 | similarity = torch.mean(similarity, dim=[2,3,4]) # shape: b,parallel_eq_n 583 | similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1,parallel_eq_n 584 | similarities.append(similarity) 585 | similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n 586 | batch_similarities.append(similarities) 587 | batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n 588 | a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1) 589 | self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze() 590 | 591 | def calibration_step2(self): 592 | self._initialize_calib_parameters() 593 | self._initialize_intervals() 594 | weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,oc,1,1,1 595 | input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1 596 | for e in range(self.search_round): 597 | # search for best weight interval 598 | self._search_best_w_interval(weight_interval_candidates) 599 | # search for best input interval 600 | if self.a_bit < 32: 601 | self._search_best_a_interval(input_interval_candidates) 602 | self.calibrated = True 603 | del self.raw_input, self.raw_out, self.raw_grad 604 | 605 | def quant_weight_bias(self): 606 | w_sim = (self.weight/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval) 607 | return w_sim, self.bias 608 | 609 | def quant_forward(self, x): 610 | assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" 611 | w_sim,bias_sim=self.quant_weight_bias() 612 | x_sim=self.quant_input(x) if self.a_bit < 32 else x 613 | out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) 614 | return out -------------------------------------------------------------------------------- /quant_layers/linear.py: -------------------------------------------------------------------------------- 1 | from quant_layers.matmul import PTQSLBatchingQuantMatMul 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class MinMaxQuantLinear(nn.Linear): 7 | def __init__(self, 8 | in_features: int, 9 | out_features: int, 10 | bias: bool = True, 11 | mode = "raw", 12 | w_bit = 8, 13 | a_bit = 8, 14 | bias_bit = None, 15 | bias_correction=False): 16 | super().__init__(in_features,out_features,bias) 17 | self.n_calibration_step=2 18 | self.mode = mode 19 | self.w_bit = w_bit 20 | self.a_bit = a_bit 21 | self.bias_bit=bias_bit 22 | assert bias_bit is None,"No support bias bit now" 23 | self.w_interval=None 24 | self.a_interval=None 25 | self.raw_input=None 26 | self.raw_out=None 27 | self.metric=None 28 | self.next_nodes=[] 29 | self.w_qmax=2**(self.w_bit-1) 30 | self.a_qmax=2**(self.a_bit-1) 31 | self.bias_correction = bias_correction 32 | 33 
| def forward(self, x): 34 | if self.mode=='raw': 35 | out=F.linear(x, self.weight, self.bias) 36 | elif self.mode=="quant_forward": 37 | out=self.quant_forward(x) 38 | elif self.mode=="calibration_step1": 39 | out=self.calibration_step1(x) 40 | elif self.mode=="calibration_step2": 41 | out=self.calibration_step2(x) 42 | else: 43 | raise NotImplementedError 44 | return out 45 | 46 | def quant_weight_bias(self): 47 | w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1) 48 | w_sim=w.mul_(self.w_interval) 49 | if self.bias is not None: 50 | return w_sim,self.bias 51 | # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1) 52 | # bias_sim=bias*self.bias_interval 53 | # return w_sim,bias_sim 54 | else: 55 | return w_sim,None 56 | 57 | def quant_input(self, x): 58 | x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1) 59 | x_sim.mul_(self.a_interval) 60 | return x_sim 61 | 62 | def quant_forward(self,x): 63 | assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" 64 | w_sim,bias_sim=self.quant_weight_bias() 65 | x_sim=self.quant_input(x) 66 | out=F.linear(x_sim, w_sim, bias_sim) 67 | return out 68 | 69 | def _bias_correction_quant_forward(self, x): 70 | if self.bias_correction and self.bias != None: 71 | w_sim = self.quant_weight_bias()[0] 72 | x_sim = self.quant_input(x) 73 | eps = F.linear(x_sim, w_sim-self.weight.data, None) 74 | eps = torch.mean(eps, dim=(list(range(len(eps.shape)-1))), keepdim=False) 75 | self.bias -= eps 76 | self.bias_correction = False 77 | return self.quant_forward(x) 78 | 79 | def calibration_step1(self,x): 80 | # step1: collection the FP32 values 81 | out=F.linear(x, self.weight, self.bias) 82 | self.raw_input=x.cpu().detach() 83 | self.raw_out=out.cpu().detach() 84 | return out 85 | 86 | def calibration_step2(self,x): 87 | # step2: search for the best S^w and S^o of each layer 88 | self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach() 89 | self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach() 90 | self.calibrated=True 91 | out=self._bias_correction_quant_forward(x) 92 | return out 93 | 94 | class PTQSLQuantLinear(MinMaxQuantLinear): 95 | """ 96 | PTQSL on linear modules. 97 | """ 98 | def __init__(self, 99 | in_features: int, 100 | out_features: int, 101 | bias: bool = True, 102 | mode = "raw", 103 | w_bit = 8, 104 | a_bit = 8, 105 | bias_bit = None, 106 | bias_correction = False, 107 | metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False): 108 | super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction) 109 | self.metric = metric 110 | self.search_round = search_round 111 | self.eq_alpha = eq_alpha 112 | self.eq_beta = eq_beta 113 | self.eq_n = eq_n 114 | self.n_H = n_H 115 | self.n_V = n_V 116 | self.n_a = n_a 117 | self.crb_rows = out_features // n_V 118 | self.crb_cols = in_features // n_H # ignore remnent != 0 situations 119 | self.crb_acts = in_features // n_a 120 | self.parallel_eq_n = parallel_eq_n 121 | self.init_layerwise = init_layerwise 122 | self.raw_grad = None 123 | 124 | def _get_similarity(self, tensor_raw, tensor_sim, metric=None): 125 | """ 126 | tensor_raw: *, features 127 | tensor_sim: *, features 128 | similarity: * 129 | It's your job to calculate mean on * dims! 
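        For example, with the default metric="L2_norm" this returns
        -(tensor_raw - tensor_sim)**2 averaged over the last (feature) dim only,
        so values closer to zero indicate a better scaling-factor candidate;
        callers still reduce the remaining batch/token dims themselves.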
130 | """ 131 | if metric == "cosine": 132 | similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1) 133 | elif metric == "pearson": 134 | similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=-1,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=-1,keepdim=True), dim=-1) 135 | else: 136 | if metric == "L1_norm": 137 | similarity = -torch.abs(tensor_raw - tensor_sim) 138 | elif metric == "L2_norm": 139 | similarity = -(tensor_raw - tensor_sim) ** 2 140 | elif metric == "linear_weighted_L2_norm": 141 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 142 | elif metric == "square_weighted_L2_norm": 143 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 144 | elif metric == "hessian": 145 | raw_grad = self.raw_grad.reshape_as(tensor_raw) 146 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 147 | else: 148 | raise NotImplementedError(f"metric {metric} not implemented!") 149 | similarity = torch.mean(similarity, dim=-1) 150 | return similarity 151 | 152 | def quant_weight_bias(self): 153 | # self.w_interval shape: n_V, 1, n_H, 1 154 | w=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols)/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1) 155 | w_sim=w.mul_(self.w_interval).view(self.out_features,self.in_features) 156 | if self.bias is not None: 157 | return w_sim,self.bias 158 | # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1) 159 | # bias_sim=bias*self.bias_interval 160 | # return w_sim,bias_sim 161 | else: 162 | return w_sim,None 163 | 164 | def quant_input(self, x): 165 | # self.a_interval shape: n_a,1 166 | x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2) 167 | x_sim=(x_sim.div_(self.a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1) 168 | x_sim = x_sim.mul_(self.a_interval).reshape_as(x) 169 | return x_sim 170 | 171 | def _search_best_w_interval(self, x, weight_interval_candidates, raw_out_expanded_chunked): 172 | """ 173 | Modularization of searching best weight intervals 174 | """ 175 | tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1 176 | for h in range(self.n_H): 177 | similarities = [] 178 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 179 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 180 | cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1) 181 | cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:] 182 | # quantize weight and bias 183 | w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols 184 | w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols 185 | w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic 186 | bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None 187 | # quantize input 188 | x_sim = self.quant_input(x) 189 | # calculate similarity and store them 190 | out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n*oc 191 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,oc 192 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,n_V,crb_rows 193 | similarity = self._get_similarity(raw_out_expanded_chunked, out_sim, self.metric) # shape: B,*,parallel_eq_n,n_V 194 | similarity = torch.mean(similarity, 
dim=list(range(len(similarity.shape)-2))) # shape: parallel_eq_n, n_V 195 | similarities.append(similarity) 196 | # store best weight interval of h into tmp_w_interval 197 | similarities = torch.cat(similarities, dim=0) # shape: eq_n, n_V 198 | h_best_index = similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1 199 | tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index) 200 | self.w_interval = tmp_w_interval.squeeze(dim=0) 201 | 202 | def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded): 203 | tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1 204 | for a in range(self.n_a): 205 | similarities = [] 206 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 207 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 208 | cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n 209 | cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed] 210 | # quantize weight and bias 211 | w_sim, bias_sim = self.quant_weight_bias() 212 | # quantize input 213 | x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1) 214 | x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n 215 | x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic 216 | # calculate similarity and store them 217 | out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc 218 | similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n 219 | similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n 220 | similarities.append(similarity) 221 | # store best input interval and store in tmp_a_interval 222 | similarities = torch.cat(similarities, dim=0) # shape: eq_n 223 | a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1) 224 | tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index) 225 | self.a_interval = tmp_a_interval.squeeze(-1) 226 | 227 | def _initialize_intervals(self, x): 228 | if self.init_layerwise: 229 | self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) 230 | self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1) 231 | else: 232 | self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) 233 | self.a_interval=((x.view(*x.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1) 234 | 235 | def calibration_step2(self,x): 236 | # initialize intervals with minmax intervals 237 | self._initialize_intervals(x) 238 | 239 | # put raw outs on GPU 240 | raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2) # shape: B,*,1,oc 241 | raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows 242 | 243 | # put raw grad on GPU 244 | self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None 245 | 246 | # prepare weight intervals and similarities 247 | weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 
1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1 248 | input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n 249 | for e in range(self.search_round): 250 | # search for best weight interval 251 | self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked) 252 | # search for best input interval 253 | self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded) 254 | 255 | self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None 256 | 257 | self.calibrated = True 258 | out=self._bias_correction_quant_forward(x) 259 | del self.raw_input, self.raw_out, self.raw_grad 260 | return out 261 | 262 | class PostGeluPTQSLQuantLinear(PTQSLQuantLinear): 263 | def __init__(self, 264 | in_features: int, 265 | out_features: int, 266 | bias: bool = True, 267 | mode = "raw", 268 | w_bit = 8, 269 | a_bit = 8, 270 | bias_bit = None, 271 | bias_correction = False, 272 | metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False): 273 | super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, 274 | metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise) 275 | 276 | def quant_input(self, x): 277 | """ 278 | self.a_interval = [a_interval_pos, a_interval_neg] 279 | """ 280 | # self.a_interval[0] shape: n_a,1 281 | # self.a_interval[1] shape: 1 282 | x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2) 283 | x_pos=(x_/(self.a_interval[0])).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval[0]) 284 | x_neg=(x_/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0).mul_(self.a_interval[1]) 285 | return (x_pos + x_neg).reshape_as(x) 286 | 287 | def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded): 288 | tmp_a_interval = self.a_interval[0].unsqueeze(-1) # shape: n_a,1,1 289 | for a in range(self.n_a): 290 | similarities = [] 291 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 292 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 293 | cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n 294 | cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed] 295 | # quantize weight and bias 296 | w_sim, bias_sim = self.quant_weight_bias() 297 | # quantize input 298 | x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1) 299 | x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n 300 | x_neg=(x_sim/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0)*(self.a_interval[1]) # shape: B,*,n_a,crb_acts,1 301 | x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic 302 | # calculate similarity and store them 303 | out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc 304 | similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n 305 | similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n 306 | 
similarities.append(similarity) 307 | # store best input interval and store in tmp_a_interval 308 | similarities = torch.cat(similarities, dim=0) # shape: eq_n 309 | a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1) 310 | tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index) 311 | self.a_interval[0] = tmp_a_interval.squeeze(-1) 312 | 313 | def _initialize_intervals(self, x): 314 | if self.init_layerwise: 315 | self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) 316 | self.a_interval=[(x.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)] 317 | else: 318 | self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) 319 | self.a_interval=[((x.view(*x.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)] 320 | self.a_interval.append(0.16997124254703522/self.a_qmax) 321 | 322 | def calibration_step2(self,x): 323 | # initialize intervals with minmax intervals 324 | self._initialize_intervals(x) 325 | 326 | # put raw outs on GPU 327 | raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2) # shape: B,*,1,oc 328 | raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows 329 | 330 | # put raw grad on GPU 331 | self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None 332 | 333 | # prepare weight intervals and similarities 334 | weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1 335 | input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval[0].unsqueeze(-1) # shape: n_a,1,eq_n 336 | for e in range(self.search_round): 337 | # search for best weight interval 338 | self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked) 339 | # search for best input interval 340 | self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded) 341 | 342 | self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None 343 | 344 | self.calibrated = True 345 | out=self._bias_correction_quant_forward(x) 346 | del self.raw_input, self.raw_out, self.raw_grad 347 | return out 348 | 349 | class PTQSLBatchingQuantLinear(PTQSLQuantLinear): 350 | def __init__(self, 351 | in_features: int, 352 | out_features: int, 353 | bias: bool = True, 354 | mode = "raw", 355 | w_bit = 8, 356 | a_bit = 8, 357 | bias_bit = None, 358 | bias_correction = False, 359 | metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False): 360 | super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise) 361 | self.calib_size = None 362 | self.calib_batch_size = None 363 | self.calib_need_batching = False 364 | 365 | def _initialize_calib_parameters(self): 366 | """ 367 | set parameters for feeding calibration data 368 | """ 369 | 
self.calib_size = int(self.raw_input.shape[0]) 370 | self.calib_batch_size = int(self.raw_input.shape[0]) 371 | while True: 372 | numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU 373 | self.parallel_eq_n = int((3*1024*1024*1024/4)//numel) 374 | if self.parallel_eq_n <= 1: 375 | self.calib_need_batching = True 376 | self.calib_batch_size //= 2 377 | else: 378 | break 379 | 380 | def _initialize_intervals(self): 381 | # weight intervals 382 | if self.init_layerwise: 383 | self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) 384 | else: 385 | self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) 386 | 387 | # activation intervals 388 | tmp_a_intervals = [] 389 | for b_st in range(0,self.calib_size,self.calib_batch_size): 390 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 391 | x_ = self.raw_input[b_st:b_ed].cuda() 392 | if self.init_layerwise: 393 | a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1) 394 | else: 395 | a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1) 396 | tmp_a_intervals.append(a_interval_) 397 | self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True) 398 | 399 | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None): 400 | """ 401 | tensor_raw: *, features 402 | tensor_sim: *, features 403 | similarity: * 404 | It's your job to calculate mean on * dims! 405 | """ 406 | if metric == "cosine": 407 | similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1) 408 | else: 409 | if metric == "L1_norm": 410 | similarity = -torch.abs(tensor_raw - tensor_sim) 411 | elif metric == "L2_norm": 412 | similarity = -(tensor_raw - tensor_sim) ** 2 413 | elif metric == "linear_weighted_L2_norm": 414 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 415 | elif metric == "square_weighted_L2_norm": 416 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 417 | elif metric == "hessian": 418 | assert raw_grad != None, f"raw_grad is None in _get_similarity!" 
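                # Unlike the non-batching linear classes above, which read self.raw_grad directly,
                # the batching variant receives the gradient slice of the current calibration
                # mini-batch as an argument, so it must be provided whenever metric == "hessian";
                # it is reshaped to raw_out's shape and used as -(raw_grad * (raw - sim))**2 below.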
419 | raw_grad = raw_grad.reshape_as(tensor_raw) 420 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 421 | else: 422 | raise NotImplementedError(f"metric {metric} not implemented!") 423 | similarity = torch.mean(similarity, dim=-1) 424 | return similarity 425 | 426 | def _get_pearson_w(self, tensor_raw, tensor_sim): 427 | """ 428 | Quick implementation of similarity-aware linear quantization 429 | tensor_sim: b,*,parallel_eq_n,n_V,crb_rows 430 | tensor_raw: b,*,1,n_V,crb_rows 431 | """ 432 | b, parallel_eq_n, n_V = tensor_sim.shape[0],tensor_sim.shape[-3],tensor_sim.shape[-2] 433 | tensor_sim = tensor_sim.transpose(-1,-3).contiguous_().view(b,-1,n_V,parallel_eq_n) 434 | tensor_raw = tensor_raw.transpose(-1,-3).view(b,-1,n_V,1) 435 | tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True) 436 | tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True) 437 | similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,n_V,parallel_eq_n 438 | similarity = similarity.permute(0,2,1).contiguous_() 439 | return similarity 440 | 441 | def _get_pearson_a(self, tensor_raw, tensor_sim): 442 | """ 443 | Quick implementation of similarity-aware linear quantization 444 | tensor_sim: b,*,parallel_eq_n,oc 445 | tensor_raw: b,*,1,oc 446 | """ 447 | b, parallel_eq_n = tensor_sim.shape[0],tensor_sim.shape[-2] 448 | tensor_sim = tensor_sim.transpose(-1,-2).contiguous_().view(b,-1,parallel_eq_n) 449 | tensor_raw = tensor_raw.transpose(-1,-2).view(b,-1,1) 450 | tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True) 451 | tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True) 452 | similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,parallel_eq_n 453 | return similarity 454 | 455 | def _search_best_w_interval(self, weight_interval_candidates): 456 | tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1 457 | for h in range(self.n_H): 458 | batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax) 459 | for b_st in range(0, self.calib_size, self.calib_batch_size): 460 | b_ed = min(self.calib_size, b_st + self.calib_batch_size) 461 | x = self.raw_input[b_st:b_ed].cuda() 462 | raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc 463 | raw_out_expanded = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,1,n_V,crb_rows 464 | raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later 465 | similarities = [] 466 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 467 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 468 | cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1) 469 | cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:] 470 | # quantize weight and bias 471 | w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols 472 | w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols 473 | w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic 474 | bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None 475 | # quantize input 476 | x_sim = self.quant_input(x) 477 | # calculate similarity and store them 478 | out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n*oc 479 | out_sim = 
torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,oc 480 | out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,n_V,crb_rows 481 | if self.metric != "pearson": 482 | similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n,n_V 483 | if len(similarity.shape) > 3: 484 | similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-2))) # shape: b, parallel_eq_n, n_V 485 | else: 486 | similarity = self._get_pearson_w(raw_out_expanded, out_sim) 487 | similarity = similarity.sum(dim=0, keepdim=True) # shape: 1, parallel_eq_n, n_V 488 | similarities.append(similarity) 489 | # store best weight interval of h into tmp_w_interval 490 | similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n, n_V 491 | batch_similarities.append(similarities) 492 | batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n, n_V 493 | h_best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1 494 | tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index) 495 | self.w_interval = tmp_w_interval.squeeze(dim=0) 496 | 497 | def _search_best_a_interval(self, input_interval_candidates): 498 | tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1 499 | for a in range(self.n_a): 500 | batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax) 501 | for b_st in range(0, self.calib_size, self.calib_batch_size): 502 | b_ed = min(self.calib_size, b_st + self.calib_batch_size) 503 | x = self.raw_input[b_st:b_ed].cuda() 504 | raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc 505 | raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later 506 | similarities = [] 507 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 508 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 509 | cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n 510 | cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed] 511 | # quantize weight and bias 512 | w_sim, bias_sim = self.quant_weight_bias() 513 | # quantize input 514 | x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1) 515 | x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n 516 | x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic 517 | # calculate similarity and store them 518 | out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc 519 | if self.metric != "pearson": 520 | similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n 521 | if len(similarity.shape) > 2: 522 | similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b, parallel_eq_n 523 | else: 524 | similarity = self._get_pearson_a(raw_out_expanded, out_sim) 525 | similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n 526 | similarities.append(similarity) 527 | # store best input interval and store in tmp_a_interval 528 | similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n 529 | batch_similarities.append(similarities) 530 | 
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n 531 | a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1) 532 | tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index) 533 | self.a_interval = tmp_a_interval.squeeze(-1) 534 | 535 | 536 | def calibration_step2(self): 537 | """ 538 | Only use cached raw inputs/outs/grads 539 | """ 540 | self._initialize_calib_parameters() 541 | self._initialize_intervals() 542 | 543 | # prepare weight intervals and similarities 544 | weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1 545 | input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n 546 | for e in range(self.search_round): 547 | # search for best weight interval 548 | self._search_best_w_interval(weight_interval_candidates) 549 | # search for best input interval 550 | self._search_best_a_interval(input_interval_candidates) 551 | 552 | self.calibrated = True 553 | # self._bias_correction_quant_forward(self.raw_input.cuda()) # debugging 554 | del self.raw_input, self.raw_out, self.raw_grad 555 | return None 556 | 557 | class PostGeluPTQSLBatchingQuantLinear(PTQSLBatchingQuantLinear): 558 | """ 559 | An Agile implementation of PostGeluPTQSLBatchingQuantLinear 560 | use a_interval for positive activation quantization and a_neg_interval for negative activation quantization 561 | """ 562 | def __init__(self, 563 | in_features: int, 564 | out_features: int, 565 | bias: bool = True, 566 | mode = "raw", 567 | w_bit = 8, 568 | a_bit = 8, 569 | bias_bit = None, 570 | bias_correction = False, 571 | metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False): 572 | super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, 573 | metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise) 574 | self.a_neg_interval = 0.16997124254703522/self.a_qmax 575 | 576 | def _initialize_intervals(self): 577 | # weight intervals 578 | if self.init_layerwise: 579 | self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) 580 | else: 581 | self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) 582 | 583 | # activation intervals (for positive parts) 584 | if self.init_layerwise: 585 | tmp_a_intervals = [] 586 | for b_st in range(0,self.calib_size,self.calib_batch_size): 587 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 588 | x_ = self.raw_input[b_st:b_ed].cuda() 589 | a_interval_=(x_.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1) 590 | tmp_a_intervals.append(a_interval_) 591 | self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True) 592 | else: 593 | tmp_a_intervals = [] 594 | for b_st in range(0,self.calib_size,self.calib_batch_size): 595 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 596 | x_ = self.raw_input[b_st:b_ed].cuda() 597 | 
a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1) 598 | tmp_a_intervals.append(a_interval_) 599 | self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True) 600 | 601 | def quant_input(self, x): 602 | # self.a_interval shape: n_a,1 603 | # self.a_neg_interval shape: 1 604 | x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2) 605 | x_pos=(x_/(self.a_interval)).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval) 606 | x_neg=(x_/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0).mul_(self.a_neg_interval) 607 | return (x_pos + x_neg).reshape_as(x) 608 | 609 | def _search_best_a_interval(self, input_interval_candidates): 610 | tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1 611 | for a in range(self.n_a): 612 | batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax) 613 | for b_st in range(0, self.calib_size, self.calib_batch_size): 614 | b_ed = min(self.calib_size, b_st + self.calib_batch_size) 615 | x = self.raw_input[b_st:b_ed].cuda() 616 | raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc 617 | raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later 618 | similarities = [] 619 | for p_st in range(0,self.eq_n,self.parallel_eq_n): 620 | p_ed = min(self.eq_n, p_st+self.parallel_eq_n) 621 | cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n 622 | cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed] 623 | # quantize weight and bias 624 | w_sim, bias_sim = self.quant_weight_bias() 625 | # quantize input 626 | x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1) 627 | x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n 628 | x_neg=(x_sim/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0)*(self.a_neg_interval) # shape: b,*,n_a,crb_acts,1 629 | x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic 630 | # calculate similarity and store them 631 | out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc 632 | similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n 633 | similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b, parallel_eq_n 634 | similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n 635 | similarities.append(similarity) 636 | # store best input interval and store in tmp_a_interval 637 | similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n 638 | batch_similarities.append(similarities) 639 | batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n 640 | a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1) 641 | tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index) 642 | self.a_interval = tmp_a_interval.squeeze(-1) -------------------------------------------------------------------------------- /quant_layers/matmul.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch import Tensor 5 | from torch.nn import 
functional as F 6 | from itertools import product 7 | 8 | class MinMaxQuantMatMul(nn.Module): 9 | """Matrix Multiplication base class""" 10 | def __init__(self, A_bit=8, B_bit=8, mode="raw"): 11 | super().__init__() 12 | self.A_bit=A_bit 13 | self.B_bit=B_bit 14 | self.A_interval=None 15 | self.B_interval=None 16 | self.A_qmax=2**(self.A_bit-1) 17 | self.B_qmax=2**(self.B_bit-1) 18 | self.mode=mode 19 | self.raw_input = None 20 | self.raw_out = None 21 | 22 | def forward(self, A,B): 23 | if self.mode=='raw': 24 | out=A @ B 25 | elif self.mode=="quant_forward": 26 | out=self.quant_forward(A,B) 27 | elif self.mode=="calibration_step1": 28 | out=self.calibration_step1(A,B) 29 | elif self.mode=="calibration_step2": 30 | out=self.calibration_step2(A,B) 31 | else: 32 | raise NotImplementedError 33 | return out 34 | 35 | def quant_input(self,x,interval,qmax): 36 | x_sim=(x/interval).round_().clamp_(-qmax,qmax-1) 37 | x_sim.mul_(interval) 38 | return x_sim 39 | 40 | def quant_forward(self,A,B): 41 | assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" 42 | A_sim=self.quant_input(A,self.A_interval,self.A_qmax) 43 | B_sim=self.quant_input(B,self.B_interval,self.B_qmax) 44 | out=A_sim@B_sim 45 | return out 46 | 47 | def calibration_step1(self,A,B): 48 | # step1: collection the FP32 values 49 | self.raw_input=A.cpu().detach(), B.cpu().detach() 50 | out=A@B 51 | self.raw_out=out.cpu().detach() 52 | return out 53 | 54 | def calibration_step2(self,A,B): 55 | # step2: search for the best S^w and S^o of each layer 56 | self.A_interval=(A.data.abs().max()/(self.A_qmax-0.5)).detach() 57 | self.B_interval=(B.data.abs().max()/(self.B_qmax-0.5)).detach() 58 | self.calibrated=True 59 | out=self.quant_forward(A,B) 60 | return out 61 | 62 | class PTQSLQuantMatMul(MinMaxQuantMatMul): 63 | """ 64 | Chunk matrix into blockes and quantize. 65 | Chunking follows naive padding strategy. 66 | Alternately search for best intervals of each individual blocks for A and B. 
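    Block sizes follow a ceiling division: e.g. with n_V_A=3 and 197 rows in A,
    crb_rows_A = ceil(197/3) = 66 and one zero row is padded so the matrix splits
    evenly into 3 row blocks of 66 (see _get_padding_parameters below).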
67 | 68 | two different scenarios: 69 | - Q @ K: 70 | - A's shape: B,H,S,W 71 | - B's shape: B,H,W,S 72 | - scores @ V: 73 | - A's shape: B,H,S,S 74 | - B's shape: B,H,S,W 75 | - interval shape: 1,n_G,1,n_V,1,n_H,1 76 | """ 77 | def __init__(self, A_bit=8, B_bit=8, mode="raw", 78 | metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 79 | n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False): 80 | super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode) 81 | self.metric = metric 82 | self.search_round = search_round 83 | self.eq_alpha = eq_alpha 84 | self.eq_beta = eq_beta 85 | self.eq_n = eq_n 86 | self.parallel_eq_n = parallel_eq_n 87 | self.n_G_A = n_G_A 88 | self.n_V_A = n_V_A 89 | self.n_H_A = n_H_A 90 | self.n_G_B = n_G_B 91 | self.n_V_B = n_V_B 92 | self.n_H_B = n_H_B 93 | # init these parameters in self.calibration_step1 94 | self.crb_groups_A = None 95 | self.crb_groups_B = None 96 | self.crb_rows_A = None 97 | self.crb_cols_A = None 98 | self.crb_rows_B = None 99 | self.crb_cols_B = None 100 | self.pad_groups_A = None 101 | self.pad_groups_B = None 102 | self.pad_rows_A = None 103 | self.pad_rows_B = None 104 | self.pad_cols_A = None 105 | self.pad_cols_B = None 106 | self.raw_grad = None 107 | self.init_layerwise = init_layerwise 108 | 109 | def _get_padding_parameters(self, A, B): 110 | self.crb_groups_A = (A.shape[1]+self.n_G_A-1) // self.n_G_A 111 | self.crb_groups_B = (B.shape[1]+self.n_G_B-1) // self.n_G_B 112 | self.crb_rows_A = (A.shape[2]+self.n_V_A-1) // self.n_V_A 113 | self.crb_cols_A = (A.shape[3]+self.n_H_A-1) // self.n_H_A 114 | self.crb_rows_B = (B.shape[2]+self.n_V_B-1) // self.n_V_B 115 | self.crb_cols_B = (B.shape[3]+self.n_H_B-1) // self.n_H_B 116 | 117 | self.pad_groups_A = self.crb_groups_A*self.n_G_A - A.shape[1] 118 | self.pad_rows_A = self.crb_rows_A*self.n_V_A - A.shape[2] 119 | self.pad_cols_A = self.crb_cols_A*self.n_H_A - A.shape[3] 120 | self.pad_groups_B = self.crb_groups_B*self.n_G_B - B.shape[1] 121 | self.pad_rows_B = self.crb_rows_B*self.n_V_B - B.shape[2] 122 | self.pad_cols_B = self.crb_cols_B*self.n_H_B - B.shape[3] 123 | 124 | def quant_input_A(self, x): 125 | x = F.pad(x, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]) 126 | x = x.view(-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) 127 | x = (x/self.A_interval).round_().clamp(-self.A_qmax,self.A_qmax-1).mul_(self.A_interval) 128 | x = x.view(-1,self.n_G_A*self.crb_groups_A,self.n_V_A*self.crb_rows_A,self.n_H_A*self.crb_cols_A) 129 | x = x[:,:x.shape[1]-self.pad_groups_A,:x.shape[2]-self.pad_rows_A,:x.shape[3]-self.pad_cols_A] 130 | return x 131 | 132 | def quant_input_B(self, x): 133 | x = F.pad(x, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]) 134 | x = x.view(-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) 135 | x = (x/self.B_interval).round_().clamp(-self.B_qmax,self.B_qmax-1).mul_(self.B_interval) 136 | x = x.view(-1,self.n_G_B*self.crb_groups_B,self.n_V_B*self.crb_rows_B,self.n_H_B*self.crb_cols_B) 137 | x = x[:,:x.shape[1]-self.pad_groups_B,:x.shape[2]-self.pad_rows_B,:x.shape[3]-self.pad_cols_B] 138 | return x 139 | 140 | def quant_forward(self, A, B): 141 | assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" 142 | A_sim=self.quant_input_A(A) 143 | B_sim=self.quant_input_B(B) 144 | out=A_sim@B_sim 145 | return out 146 | 147 | def _get_similarity(self, tensor_raw, tensor_sim, 
metric=None, dim=-1): 148 | """ 149 | tensor_raw: *, features, * 150 | tensor_sim: *, features, * 151 | similarity: * 152 | It's your job to calculate mean on non-feature * dims! 153 | 154 | Similarity without inherent feature structure is more welcome to parallelism. 155 | """ 156 | if metric == "cosine": 157 | similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # should only support dim=-1 and cannot be paralleled 158 | elif metric == "pearson": 159 | similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw), tensor_sim-torch.mean(tensor_sim), dim=dim) 160 | else: 161 | if metric == "L1_norm": 162 | similarity = -torch.abs(tensor_raw - tensor_sim) 163 | elif metric == "L2_norm": 164 | similarity = -(tensor_raw - tensor_sim) ** 2 165 | elif metric == "linear_weighted_L2_norm": 166 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 167 | elif metric == "square_weighted_L2_norm": 168 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 169 | elif metric == "hessian": 170 | raw_grad = self.raw_grad.reshape_as(tensor_raw) 171 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 172 | else: 173 | raise NotImplementedError(f"metric {metric} not implemented!") 174 | similarity = torch.mean(similarity, dim=dim) 175 | return similarity 176 | 177 | def _search_best_A_interval(self, A, B, A_interval_candidates): 178 | """ 179 | Modularization of searching best interval 180 | """ 181 | # recalculate A_pad 182 | A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) 183 | 184 | tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1 185 | # out-of-loop optimization 186 | B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3 187 | for v, h in product(range(self.n_V_A), range(self.n_H_A)): 188 | similarities = [] 189 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 190 | p_ed = min(self.eq_n,p_st+self.parallel_eq_n) 191 | # quantize A 192 | cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1) 193 | cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:] 194 | A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval) 195 | A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding) 196 | A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,B,H,dim1,dim2 197 | # quantize B, this quantization is optimized out of loop 198 | # calculate similarity and store them 199 | out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3 200 | similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1 201 | similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (remaining mean operation will be done later on) 202 | similarities.append(similarity) 203 | # calculate best similarity for this block 204 | similarities = torch.cat(similarities, 0) # shape: eq_n,H 205 | similarities = F.pad(similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A 206 | best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1) 207 | tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index) 208 | 
self.A_interval = tmp_A_interval.squeeze(0) 209 | 210 | def _search_best_B_interval(self, A, B, B_interval_candidates): 211 | """ 212 | Modularization of searching best interval 213 | """ 214 | # recalculate B_pad 215 | B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) 216 | 217 | tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1 218 | # out-of-loop optimization 219 | A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2 220 | for v, h in product(range(self.n_V_B), range(self.n_H_B)): 221 | similarities = [] 222 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 223 | p_ed = min(self.eq_n,p_st+self.parallel_eq_n) 224 | # quantize A, this quantization is optimized out of loop 225 | # quantize B 226 | cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1) 227 | cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:] 228 | B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval) 229 | B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,B,H*,dim2*,dim3* (* stand for padding) 230 | B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,B,H,dim2,dim3 231 | # calculate similarity and store them 232 | out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3 233 | similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1 234 | similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (remaining mean operation will be done later on) 235 | similarities.append(similarity) 236 | # calculate best similarity for this block 237 | similarities = torch.cat(similarities, 0) # shape: eq_n,H 238 | similarities = F.pad(similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B 239 | best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1) 240 | tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index) 241 | self.B_interval = tmp_B_interval.squeeze(0) 242 | 243 | def _initialize_intervals(self, A, B): 244 | # pad A and B for future quantization 245 | self._get_padding_parameters(A, B) # put it here because hessian does not use calibration step 1 246 | A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) # shape: 1,B,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols 247 | B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) 248 | 249 | # initialize intervals with minmax intervals 250 | if self.init_layerwise: 251 | self.A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1) 252 | self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1) 253 | else: 254 | self.A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1 255 | self.B_interval=(B_pad.abs().amax([0,1,3,5,7], 
keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1 256 | 257 | def calibration_step2(self, A, B): 258 | # put raw outs/grads on GPU 259 | self.raw_out = self.raw_out.unsqueeze(0).to(A.device) 260 | self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad != None else None 261 | 262 | self._initialize_intervals(A, B) 263 | 264 | # prepare weight intervals and similarities 265 | A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0) 266 | B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0) 267 | 268 | for e in range(self.search_round): 269 | # search for best A interval 270 | self._search_best_A_interval(A, B, A_interval_candidates) 271 | # search for best B interval 272 | self._search_best_B_interval(A, B, B_interval_candidates) 273 | 274 | # put raw data back to cpu 275 | self.raw_out = self.raw_out.squeeze(0).to("cpu") 276 | self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None 277 | 278 | # finish calibration and output the result 279 | self.calibrated = True 280 | del self.raw_input, self.raw_out, self.raw_grad 281 | out=self.quant_forward(A,B) 282 | return out 283 | 284 | class SoSPTQSLQuantMatMul(PTQSLQuantMatMul): 285 | """ 286 | Sublayerwise PTQ on matmul modules with Split-of-Softmax (SoS) on score matrix. 287 | 288 | Data after softmaxing has highly biased distribution, making it difficult to quantize with uniform quantization. 289 | An elegant tradeoff between great majority of unimportant values and few crucial values is impossible under low bit quantization. 290 | Therefore, we propose to split complete interval of (0, 1) into several smaller intervals and perform uniform quantization on each. 291 | We could manually assgin or search for the best split point. 292 | Currently, we only consider single split point scenarios, since this proves to be effective enough. 293 | 294 | The algorithm no longer requires PTQSL on score matrix, and will ignore relevant parameters. 295 | 296 | with proper hardware implementation, we don't need to use a sign bit anymore. 
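    As implemented below, a k-bit budget quantizes the small-value range (0, split)
    with interval split/(2^(k-1)-1) (see quant_input_A), and calibration_step2 searches
    the split point over powers of two {2^-i : i = 0..19} instead of the usual
    linear grid of eq_n candidates.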
297 | """ 298 | def __init__(self, A_bit=8, B_bit=8, mode="raw", 299 | metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 300 | n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False, 301 | split=None): 302 | super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, 303 | metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, 304 | n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise) 305 | self.n_G_A = 1 306 | self.n_V_A = 1 307 | self.n_H_A = 1 308 | self.A_qmax = 2**(self.A_bit-1) # well, still need it 309 | self.split = split 310 | if split != None: 311 | self.A_interval = self.split/(self.A_qmax-1) 312 | 313 | def quant_input_A(self, x): 314 | x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1) 315 | x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval 316 | return x_high + x_low 317 | 318 | def _search_best_A_interval(self, A, B, split_candidates): 319 | """ 320 | search for best split point 321 | """ 322 | # out-of-loop optimization 323 | A_ = A.unsqueeze(0) 324 | # B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3 325 | B_sim = B.unsqueeze(0) 326 | 327 | similarities = [] 328 | for i in range(len(split_candidates)): 329 | # quantize A 330 | cur_A_interval = split_candidates[i]/(self.A_qmax-1) 331 | A_high = (A_.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1) 332 | A_low =( A_.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval 333 | A_sim = A_high + A_low # shape: 1,B,H,S,S 334 | # quantize B, this quantization is optimized out of loop 335 | # calculate similarity and store them (dim1=dim2=S, dim3=W) 336 | out_sim = A_sim @ B_sim # shape: 1,B,H,dim1,dim3 337 | similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1 338 | similarity = similarity.mean([1,2,3]) # shape: 1 339 | similarities.append(similarity) 340 | # calculate best similarity for this block 341 | similarities = torch.cat(similarities, 0) # shape: eq_n 342 | best_index = torch.argmax(similarities, dim=0, keepdim=False) 343 | self.split = split_candidates[best_index] 344 | self.A_interval = self.split/(self.A_qmax-1) 345 | # debugging 346 | # print(f"best split: {self.split}") 347 | 348 | def _initialize_intervals(self, A, B): 349 | # pad A and B for future quantization 350 | self._get_padding_parameters(A, B) 351 | B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) 352 | 353 | # initialize intervals with minmax intervals 354 | self.split = 0.01 355 | self.A_interval = self.split/(self.A_qmax-1) 356 | if self.init_layerwise: 357 | self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1) 358 | else: 359 | self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1 360 | 361 | def calibration_step2(self, A, B): 362 | # put raw outs/grads on GPU 363 | self.raw_out = self.raw_out.unsqueeze(0).to(A.device) 364 | self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad != None else None 365 | 366 | self._initialize_intervals(A, B) 367 | 368 | # 
prepare weight intervals and similarities 369 | A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda() 370 | # split_eq_alpha, split_eq_beta, split_eq_n = 0.002, 0.03, 50 371 | # A_split_candidates = torch.tensor([split_eq_alpha + (split_eq_beta- split_eq_alpha)*i/split_eq_n for i in range(split_eq_n + 1)]).cuda() 372 | B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0) 373 | 374 | for e in range(self.search_round): 375 | # search for best A interval 376 | self._search_best_A_interval(A, B, A_split_candidates) 377 | # search for best B interval 378 | self._search_best_B_interval(A, B, B_interval_candidates) 379 | 380 | # put raw data back to cpu 381 | self.raw_out = self.raw_out.squeeze(0).to("cpu") 382 | self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None 383 | 384 | # finish calibration and output the result 385 | self.calibrated = True 386 | del self.raw_input, self.raw_out, self.raw_grad 387 | out=self.quant_forward(A,B) 388 | return out 389 | 390 | class PTQSLBatchingQuantMatMul(PTQSLQuantMatMul): 391 | def __init__(self, A_bit=8, B_bit=8, mode="raw", 392 | metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 393 | n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False): 394 | super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise) 395 | 396 | def _initialize_calib_parameters(self): 397 | """ 398 | set parameters for feeding calibration data 399 | """ 400 | self.calib_size = int(self.raw_input[0].shape[0]) 401 | self.calib_batch_size = int(self.raw_input[0].shape[0]) 402 | while True: 403 | numel = ((self.raw_input[0].numel()+self.raw_input[1].numel()+2*self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU 404 | self.parallel_eq_n = int((3*1024*1024*1024/4)//numel) 405 | if self.parallel_eq_n <= 1: 406 | self.calib_need_batching = True 407 | self.calib_batch_size //= 2 408 | else: 409 | break 410 | 411 | def _get_padding_parameters(self, A, B): 412 | """ 413 | We adopt a head-wise quantization here 414 | """ 415 | self.n_G_A = A.shape[1] 416 | self.n_G_B = B.shape[1] 417 | super()._get_padding_parameters(A,B) 418 | 419 | def _initialize_intervals(self): 420 | # pad A and B for future quantization 421 | self._get_padding_parameters(self.raw_input[0], self.raw_input[1]) # put it here because hessian does not use calibration step 1 422 | 423 | # initialize intervals with minmax intervals 424 | tmp_A_intervals = [] 425 | tmp_B_intervals = [] 426 | for b_st in range(0,self.calib_size,self.calib_batch_size): 427 | b_ed = min(self.calib_size, b_st+self.calib_batch_size) 428 | A, B = self.raw_input[0][b_st:b_ed].cuda(), self.raw_input[1][b_st:b_ed].cuda() 429 | if self.init_layerwise: 430 | A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1) 431 | B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1) 432 | else: 433 | A_pad = F.pad(A, 
[0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) 434 | B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) 435 | A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1 436 | B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1 437 | tmp_A_intervals.append(A_interval) 438 | tmp_B_intervals.append(B_interval) 439 | self.A_interval = torch.cat(tmp_A_intervals, dim=0).amax(0, keepdim=True) 440 | self.B_interval = torch.cat(tmp_B_intervals, dim=0).amax(0, keepdim=True) 441 | 442 | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None): 443 | """ 444 | tensor_raw: *, features, * 445 | tensor_sim: *, features, * 446 | similarity: * 447 | It's your job to calculate mean on non-feature * dims! 448 | 449 | Similarity without inherent feature structure is more welcome to parallelism. 450 | """ 451 | if metric == "cosine": 452 | similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # should only support dim=-1 and cannot be paralleled 453 | elif metric == "pearson": 454 | similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=dim,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=dim,keepdim=True), dim=dim) # should only support dim=-1 and cannot be paralleled 455 | # a quick implementation of pearson similarity 456 | # tensor_raw: 1,B,H,dim1,dim3 457 | # tensor_sim: parallel_eq_n,B,H,dim1,dim3 458 | # parallel_eq_n,B,H,dim1,dim3 = tensor_sim.shape 459 | # tensor_sim = tensor_sim.view(parallel_eq_n,B,-1) 460 | # tensor_raw = tensor_raw.view(1,B,-1) 461 | # tensor_sim_mean = tensor_sim.mean(dim=[1,2],keepdim=True) 462 | # tensor_raw_mean = tensor_raw.mean(dim=[1,2],keepdim=True) 463 | # similarity = F.cosine_similarity(tensor_raw-tensor_raw_mean,tensor_sim-tensor_sim_mean,dim=-1) # shape: parallel_eq_n,B 464 | # similarity = similarity.reshape(parallel_eq_n,B,1,1) # restore two dims 465 | else: 466 | if metric == "L1_norm": 467 | similarity = -torch.abs(tensor_raw - tensor_sim) 468 | elif metric == "L2_norm": 469 | similarity = -(tensor_raw - tensor_sim) ** 2 470 | elif metric == "linear_weighted_L2_norm": 471 | similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 472 | elif metric == "square_weighted_L2_norm": 473 | similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 474 | elif metric == "hessian": 475 | assert raw_grad != None, f"No raw_grad in PTQSLBatchingQuantMatMul!" 
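                # Hessian-guided metric: the output perturbation (tensor_raw - tensor_sim) is weighted
                # by the gradient of the task loss w.r.t. the raw output (cached by grad_hook during
                # calibration), so candidates are ranked by their estimated impact on the loss rather
                # than by plain output distance; the later argmax picks the least harmful scaling factor.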
476 | raw_grad = raw_grad.reshape_as(tensor_raw) 477 | similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 478 | else: 479 | raise NotImplementedError(f"metric {metric} not implemented!") 480 | similarity = torch.mean(similarity, dim=dim) 481 | return similarity 482 | 483 | def _search_best_A_interval(self, A_interval_candidates): 484 | """ 485 | Modularization of searching best interval 486 | """ 487 | tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1 488 | # out-of-loop optimization 489 | for v, h in product(range(self.n_V_A), range(self.n_H_A)): 490 | batch_similarities = [] # similarities, need to concatenate and calculate sum 491 | for b_st in range(0, self.calib_size, self.calib_batch_size): 492 | b_ed = min(self.calib_size, b_st + self.calib_batch_size) 493 | A = self.raw_input[0][b_st:b_ed].cuda() 494 | A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) 495 | B = self.raw_input[1][b_st:b_ed].cuda() 496 | B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,b,H,dim2,dim3 497 | raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda() 498 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 499 | similarities = [] 500 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 501 | p_ed = min(self.eq_n,p_st+self.parallel_eq_n) 502 | # quantize A 503 | cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1) 504 | cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:] 505 | A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval) 506 | A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding) 507 | A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,b,H,dim1,dim2 508 | # quantize B, this quantization is optimized out of loop 509 | # calculate similarity and store them 510 | out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3 511 | similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1 512 | similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on) 513 | similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H 514 | similarities.append(similarity) 515 | # calculate best similarity for this block 516 | similarities = torch.cat(similarities, 0) # shape: eq_n,1,H 517 | batch_similarities.append(similarities) 518 | batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H 519 | batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A 520 | best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1) 521 | tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index) 522 | self.A_interval = tmp_A_interval.squeeze(0) 523 | 524 | def _search_best_B_interval(self, B_interval_candidates): 525 | """ 526 | Modularization of searching best interval 527 | """ 528 | tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1 529 | # out-of-loop optimization 530 | for v, h in product(range(self.n_V_B), range(self.n_H_B)): 531 | batch_similarities = 
[] # similarities, need to concatenate and calculate sum 532 | for b_st in range(0, self.calib_size, self.calib_batch_size): 533 | b_ed = min(self.calib_size, b_st + self.calib_batch_size) 534 | A = self.raw_input[0][b_st:b_ed].cuda() 535 | A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2 536 | B = self.raw_input[1][b_st:b_ed].cuda() 537 | B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) 538 | raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda() 539 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 540 | similarities = [] 541 | for p_st in range(0, self.eq_n, self.parallel_eq_n): 542 | p_ed = min(self.eq_n,p_st+self.parallel_eq_n) 543 | # quantize A, this quantization is optimized out of loop 544 | # quantize B 545 | cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1) 546 | cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:] 547 | B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval) 548 | B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,b,H*,dim2*,dim3* (* stand for padding) 549 | B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,b,H,dim2,dim3 550 | # calculate similarity and store them 551 | out_sim = A_sim @ B_sim # shape: parallel_eq_n,b,H,dim1,dim3 552 | similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1 553 | similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on) 554 | similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H 555 | similarities.append(similarity) 556 | # calculate best similarity for this block 557 | similarities = torch.cat(similarities, 0) # shape: eq_n,1,H 558 | batch_similarities.append(similarities) 559 | batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H 560 | batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B 561 | best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1) 562 | tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index) 563 | self.B_interval = tmp_B_interval.squeeze(0) 564 | 565 | def calibration_step2(self): 566 | self._initialize_calib_parameters() 567 | self._initialize_intervals() 568 | A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0) 569 | B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0) 570 | for e in range(self.search_round): 571 | # search for best A interval 572 | self._search_best_A_interval(A_interval_candidates) 573 | # search for best B interval 574 | self._search_best_B_interval(B_interval_candidates) 575 | self.calibrated = True 576 | del self.raw_input, self.raw_out, self.raw_grad 577 | 578 | class SoSPTQSLBatchingQuantMatMul(PTQSLBatchingQuantMatMul): 579 | def __init__(self, A_bit=8, B_bit=8, mode="raw", 580 | 
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, 581 | n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False, 582 | split=None): 583 | super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, 584 | metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, 585 | n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise) 586 | self.n_G_A = 1 587 | self.n_V_A = 1 588 | self.n_H_A = 1 589 | # with proper hardware implementation, we don't need to use a sign bit anymore 590 | self.A_qmax = 2**(self.A_bit-1) 591 | self.split = split 592 | if split != None: 593 | self.A_interval = self.split/(self.A_qmax-1) 594 | 595 | def quant_input_A(self, x): 596 | x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1) 597 | x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval 598 | return x_high + x_low 599 | 600 | def _search_best_A_interval(self, split_candidates): 601 | batch_similarities = [] 602 | for b_st in range(0, self.calib_size, self.calib_batch_size): 603 | b_ed = min(self.calib_size, b_st + self.calib_batch_size) 604 | A = self.raw_input[0][b_st:b_ed].unsqueeze(0).cuda() 605 | B = self.raw_input[1][b_st:b_ed].unsqueeze(0).cuda() 606 | B_sim = B 607 | raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda() 608 | raw_grad = self.raw_grad[b_st:b_ed].cuda() 609 | similarities = [] 610 | for i in range(len(split_candidates)): 611 | # quantize A 612 | cur_A_interval = split_candidates[i]/(self.A_qmax-1) 613 | A_high = (A.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1) 614 | A_low =( A.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval 615 | A_sim = A_high + A_low # shape: 1,b,H,S,S 616 | # quantize B, this quantization is optimized out of loop 617 | # calculate similarity and store them (dim1=dim2=S, dim3=W) 618 | out_sim = A_sim @ B_sim # shape: 1,b,H,dim1,dim3 619 | similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1 620 | similarity = similarity.mean([2,3]) # shape: parallel_eq_n, b 621 | similarity = similarity.sum(dim=1,keepdim=True) # parallel_eq_n, 1 622 | similarities.append(similarity) 623 | # calculate best similarity for this block 624 | similarities = torch.cat(similarities, 0) # shape: eq_n, 1 625 | batch_similarities.append(similarities) 626 | batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n 627 | best_index = torch.argmax(batch_similarities, dim=0, keepdim=False) 628 | self.split = split_candidates[best_index] 629 | self.A_interval = self.split/(self.A_qmax-1) 630 | # debugging 631 | # print(f"best split: {self.split}") 632 | 633 | def calibration_step2(self): 634 | self._initialize_calib_parameters() 635 | self._initialize_intervals() 636 | A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda() 637 | B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0) 638 | for e in range(self.search_round): 639 | # search for best A interval 640 | self._search_best_A_interval(A_split_candidates) 641 | # search for best B interval 642 | self._search_best_B_interval(B_interval_candidates) 643 | 
self.calibrated = True 644 | del self.raw_input, self.raw_out, self.raw_grad -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reuse version v4 3 | Author: Hahn Yuan 4 | """ 5 | import PIL 6 | import torch 7 | import argparse 8 | import numpy as np 9 | import os 10 | import copy 11 | import torchvision.transforms as transforms 12 | import torchvision.datasets as datasets 13 | from torchvision.datasets import ImageFolder,DatasetFolder 14 | import torch.utils.data 15 | import re 16 | import warnings 17 | from PIL import Image 18 | from PIL import ImageFile 19 | import random 20 | import torch.nn.functional as F 21 | from torch.utils.data import Dataset 22 | 23 | def calculate_n_correct(outputs,targets): 24 | _, predicted = outputs.max(1) 25 | n_correct= predicted.eq(targets).sum().item() 26 | return n_correct 27 | 28 | class SetSplittor(): 29 | def __init__(self,fraction=0.2): 30 | self.fraction=fraction 31 | 32 | def split(self,dataset): 33 | pass 34 | 35 | class LoaderGenerator(): 36 | """ 37 | """ 38 | def __init__(self,root,dataset_name,train_batch_size=1,test_batch_size=1,num_workers=0,kwargs={}): 39 | self.root=root 40 | self.dataset_name=str.lower(dataset_name) 41 | self.train_batch_size=train_batch_size 42 | self.test_batch_size=test_batch_size 43 | self.num_workers=num_workers 44 | self.kwargs=kwargs 45 | self.items=[] 46 | self._train_set=None 47 | self._test_set=None 48 | self._calib_set=None 49 | self.train_transform=None 50 | self.test_transform=None 51 | self.train_loader_kwargs = { 52 | 'num_workers': self.num_workers , 53 | 'pin_memory': kwargs.get('pin_memory',True), 54 | 'drop_last':kwargs.get('drop_last',False) 55 | } 56 | self.test_loader_kwargs=self.train_loader_kwargs.copy() 57 | self.load() 58 | 59 | @property 60 | def train_set(self): 61 | pass 62 | 63 | @property 64 | def test_set(self): 65 | pass 66 | 67 | def load(self): 68 | pass 69 | 70 | def train_loader(self): 71 | assert self.train_set is not None 72 | return torch.utils.data.DataLoader(self.train_set, batch_size=self.train_batch_size, shuffle=True, **self.train_loader_kwargs) 73 | 74 | def test_loader(self,shuffle=False,batch_size=None): 75 | assert self.test_set is not None 76 | if batch_size is None: 77 | batch_size=self.test_batch_size 78 | return torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs) 79 | 80 | def val_loader(self): 81 | assert self.val_set is not None 82 | return torch.utils.data.DataLoader(self.val_set, batch_size=self.test_batch_size, shuffle=False, **self.test_loader_kwargs) 83 | 84 | def trainval_loader(self): 85 | assert self.trainval_set is not None 86 | return torch.utils.data.DataLoader(self.trainval_set, batch_size=self.train_batch_size, shuffle=True, **self.train_loader_kwargs) 87 | 88 | def calib_loader(self,num=1024,seed=3): 89 | if self._calib_set is None: 90 | np.random.seed(seed) 91 | inds=np.random.permutation(len(self.train_set))[:num] 92 | self._calib_set=torch.utils.data.Subset(copy.deepcopy(self.train_set),inds) 93 | self._calib_set.dataset.transform=self.test_transform 94 | return torch.utils.data.DataLoader(self._calib_set, batch_size=num, shuffle=False, **self.train_loader_kwargs) 95 | 96 | class CIFARLoaderGenerator(LoaderGenerator): 97 | def load(self): 98 | if self.dataset_name=='cifar100': 99 | self.dataset_fn=datasets.CIFAR100 100 | normalize = 
transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], 101 | std=[0.2673, 0.2564, 0.2762]) 102 | elif self.dataset_name=='cifar10': 103 | self.dataset_fn=datasets.CIFAR10 104 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 105 | std=[0.2470, 0.2435, 0.2616]) 106 | else: 107 | raise NotImplementedError 108 | self.train_transform = transforms.Compose([ 109 | transforms.RandomCrop(32,padding=4), 110 | transforms.RandomHorizontalFlip(), 111 | transforms.ToTensor(), 112 | normalize, 113 | ]) 114 | self.test_transform = transforms.Compose([ 115 | transforms.ToTensor(), 116 | normalize, 117 | ]) 118 | @property 119 | def train_set(self): 120 | if self._train_set is None: 121 | self._train_set=self.dataset_fn(self.root, train=True, download=True, transform=self.train_transform) 122 | return self._train_set 123 | 124 | @property 125 | def test_set(self): 126 | if self._test_set is None: 127 | self._test_set=self.dataset_fn(self.root, train=False, transform=self.test_transform) 128 | return self._test_set 129 | 130 | class COCOLoaderGenerator(LoaderGenerator): 131 | def load(self): 132 | # download from https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh 133 | self.train_set = DetectionListDataset(os.path.join(self.root,'trainvalno5k.txt'),transform=augmentation_detection_tansforms) 134 | self.test_set = DetectionListDataset(os.path.join(self.root,'5k.txt'),transform=detection_tansforms,multiscale=False) 135 | self.train_loader_kwargs={"collate_fn":self.train_set.collate_fn} 136 | self.test_loader_kwargs={"collate_fn":self.test_set.collate_fn} 137 | 138 | class DetectionListDataset(Dataset): 139 | def __init__(self, list_path, img_size=416, multiscale=True, transform=None): 140 | with open(list_path, "r") as file: 141 | self.img_files = [path for path in file.readlines()] 142 | self.label_files = [ 143 | path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") 144 | for path in self.img_files 145 | ] 146 | self.img_size = img_size 147 | self.max_objects = 100 148 | self.multiscale = multiscale 149 | self.min_size = self.img_size - 3 * 32 150 | self.max_size = self.img_size + 3 * 32 151 | self.batch_count = 0 152 | self.transform = transform 153 | 154 | def __getitem__(self, index): 155 | try: 156 | img_path = self.img_files[index % len(self.img_files)].rstrip() 157 | img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8) 158 | except Exception as e: 159 | print(f"Could not read image '{img_path}'.") 160 | return 161 | try: 162 | label_path = self.label_files[index % len(self.img_files)].rstrip() 163 | # Ignore warning if file is empty 164 | with warnings.catch_warnings(): 165 | warnings.simplefilter("ignore") 166 | boxes = np.loadtxt(label_path).reshape(-1, 5) 167 | except Exception as e: 168 | print(f"Could not read label '{label_path}'.") 169 | return 170 | if self.transform: 171 | try: 172 | img, bb_targets = self.transform((img, boxes)) 173 | except: 174 | print(f"Could not apply transform.") 175 | return 176 | return img_path, img, bb_targets 177 | 178 | def collate_fn(self, batch): 179 | self.batch_count += 1 180 | # Drop invalid images 181 | batch = [data for data in batch if data is not None] 182 | 183 | paths, imgs, bb_targets = list(zip(*batch)) 184 | # Selects new image size every tenth batch 185 | if self.multiscale and self.batch_count % 10 == 0: 186 | self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32)) 187 | # Resize images to input shape 188 | imgs = 
torch.stack([F.interpolate(img.unsqueeze(0), size=self.img_size, mode="nearest").squeeze(0) for img in imgs]) 189 | # Add sample index to targets 190 | for i, boxes in enumerate(bb_targets): 191 | boxes[:, 0] = i 192 | bb_targets = torch.cat(bb_targets, 0) 193 | return paths, imgs, bb_targets 194 | 195 | def __len__(self): 196 | return len(self.img_files) 197 | 198 | # def faster_im_loader(path): 199 | # with open(path,'rb') as f: 200 | # bgr_array = TurboJPEG().decode(f.read()) 201 | # rgb_array=np.concatenate([bgr_array[:,:,2:3],bgr_array[:,:,1:2],bgr_array[:,:,0:1]],-1) 202 | # return torch.Tensor(rgb_array)/255 203 | 204 | class ImageNetLoaderGenerator(LoaderGenerator): 205 | def load(self): 206 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 207 | std=[0.229, 0.224, 0.225]) 208 | self.train_transform = transforms.Compose([ 209 | transforms.Resize(256), 210 | transforms.RandomResizedCrop(224), 211 | transforms.RandomHorizontalFlip(), 212 | transforms.ToTensor(), 213 | normalize, 214 | ]) 215 | 216 | self.test_transform = transforms.Compose([ 217 | transforms.Resize(256), 218 | transforms.CenterCrop(224), 219 | transforms.ToTensor(), 220 | normalize, 221 | ]) 222 | 223 | @property 224 | def train_set(self): 225 | if self._train_set is None: 226 | self._train_set=ImageFolder(os.path.join(self.root,'train'), self.train_transform) 227 | return self._train_set 228 | 229 | @property 230 | def test_set(self): 231 | if self._test_set is None: 232 | self._test_set=ImageFolder(os.path.join(self.root,'val'), self.test_transform) 233 | return self._test_set 234 | 235 | class CacheDataset(Dataset): 236 | def __init__(self,datas,targets) -> None: 237 | super().__init__() 238 | self.datas=datas 239 | self.targets=targets 240 | 241 | def __getitem__(self,idx): 242 | return self.datas[idx],self.targets[idx] 243 | 244 | def __len__(self): 245 | return len(self.datas) 246 | 247 | class FasterImageNetLoaderGenerator(ImageNetLoaderGenerator): 248 | def test_loader(self,shuffle=False,batch_size=None): 249 | cache='/dev/shm/imagenet.pkl' 250 | assert self.test_set is not None 251 | if batch_size is None: 252 | batch_size=self.test_batch_size 253 | if os.path.exists(cache): 254 | print("Loading the dataset from shared memory") 255 | datas,targets=torch.load(cache) 256 | else: 257 | print("Preprocessing the dataset and save it to shared memory") 258 | loader=torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs) 259 | datas=[] 260 | targets=[] 261 | for data,target in loader: 262 | datas.append(data) 263 | targets.append(target) 264 | datas=torch.cat(datas,0) 265 | targets=torch.cat(targets,0) 266 | torch.save([datas,targets],cache) 267 | dataset=CacheDataset(datas,targets) 268 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs) 269 | 270 | class DebugLoaderGenerator(LoaderGenerator): 271 | 272 | def load(self): 273 | version=re.findall("\d+",self.dataset_name)[0] 274 | class DebugSet(torch.utils.data.Dataset): 275 | def __getitem__(self,idx): 276 | if version=='0': 277 | return torch.ones([1,4,4]),0 278 | if version=='1': 279 | return torch.ones([1,8,8]),0 280 | if version=='2': 281 | return torch.ones([1,1,1]),0 282 | if version=='3': 283 | return torch.ones([1,3,3]),0 284 | else: 285 | raise NotImplementedError(f"version {version} of Debug dataset is not supported") 286 | def __len__(self): return 1 287 | self.train_set=DebugSet() 288 | self.test_set=DebugSet() 289 | 290 | def 
get_dataset(args:argparse.Namespace): 291 | """ Preparing Datasets, args: 292 | dataset (required): MNIST, cifar10/100, ImageNet, coco 293 | dataset_root: str, default='./datasets' 294 | num_workers: int 295 | batch_size: int 296 | test_batch_size: int 297 | val_fraction: float, default=0 298 | 299 | """ 300 | dataset_name=str.lower(args.dataset) 301 | dataset_root=getattr(args,'dataset_root','./datasets') 302 | num_workers=args.num_workers if hasattr(args,'num_workers') else 4 303 | batch_size=args.batch_size if hasattr(args,'batch_size') else 64 304 | test_batch_size=args.test_batch_size if hasattr(args,'test_batch_size') else batch_size 305 | val_fraction=args.val_fraction if hasattr(args,"val_fraction") else 0 306 | if "cifar" in dataset_name: 307 | # Data loading code 308 | g=CIFARLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) 309 | elif "coco" in dataset_name: 310 | g=COCOLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) 311 | elif "debug" in dataset_name: 312 | g=DebugLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) 313 | elif args.dataset=='ImageNet': 314 | g=ImageNetLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) 315 | else: 316 | raise NotImplementedError 317 | return g.train_loader(),g.test_loader() 318 | 319 | 320 | import timm 321 | from timm.models.vision_transformer import VisionTransformer 322 | from timm.data import resolve_data_config 323 | from timm.data.transforms_factory import create_transform 324 | 325 | class ViTImageNetLoaderGenerator(ImageNetLoaderGenerator): 326 | """ 327 | DataLoader for Vision Transformer. 328 | To comply with timm's framework, we use the model's corresponding transform. 329 | """ 330 | def __init__(self, root, dataset_name, train_batch_size, test_batch_size, num_workers, kwargs={}): 331 | kwargs.update({"pin_memory":False}) 332 | super().__init__(root, dataset_name, train_batch_size=train_batch_size, test_batch_size=test_batch_size, num_workers=num_workers, kwargs=kwargs) 333 | 334 | def load(self): 335 | model = self.kwargs.get("model", None) 336 | assert model != None, f"No model in ViTImageNetLoaderGenerator!" 337 | 338 | config = resolve_data_config({}, model=model) 339 | self.train_transform = create_transform(**config, is_training=True) 340 | self.test_transform = create_transform(**config) 341 | 342 | -------------------------------------------------------------------------------- /utils/integer.py: -------------------------------------------------------------------------------- 1 | from numpy import dtype 2 | from quant_layers.matmul import MinMaxQuantMatMul, PTQSLBatchingQuantMatMul, PTQSLQuantMatMul, SoSPTQSLBatchingQuantMatMul, SoSPTQSLQuantMatMul 3 | from quant_layers.linear import MinMaxQuantLinear, PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear, PostGeluPTQSLQuantLinear 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def quantize_int_weight(module): 9 | """ 10 | get weight of type 'uint8' of a quantized module. 11 | Bias are not quantized and you can use raw bias. 
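
    A minimal usage sketch (illustrative; assumes `module` is a calibrated quantized module
    with w_bit == 8):

        w_int = quantize_int_weight(module)            # 8-bit integer tensor on CPU
        w_sim = dequantize_int_weight(module, w_int)   # ~= module.weight, up to rounding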
12 | """ 13 | assert hasattr(module, 'weight'), f"module {module} does not have weight" 14 | assert module.w_bit == 8, f"module {module}'s weight is quantized with {module.w_bit} bits" 15 | 16 | w_int = (module.weight/module.w_interval).round_().clamp_(-module.w_qmax, module.w_qmax-1) 17 | w_int = w_int.cpu().detach().to(torch.int8) 18 | return w_int 19 | 20 | def dequantize_int_weight(module, w_int): 21 | """ 22 | Make sure it is the same module that generated w_int. 23 | """ 24 | w_sim = module.w_interval.cpu() * w_int.float() 25 | return w_sim 26 | 27 | def quantize_matmul_input(input, interval, qmax, n_G, n_V, n_H, crb_groups, crb_rows, crb_cols): 28 | """ 29 | Quantize an input matrix of a matmul operation, respecting the sublayerwise padding settings. 30 | """ 31 | pad_groups = crb_groups*n_G - input.shape[1] 32 | pad_rows = crb_rows*n_V - input.shape[2] 33 | pad_cols = crb_cols*n_H - input.shape[3] 34 | 35 | x = F.pad(input, [0,pad_cols,0,pad_rows,0,pad_groups]) 36 | x = x.view(-1,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols) 37 | x = (x/interval).round_().clamp(-qmax,qmax-1) 38 | x = x.view(-1,n_G*crb_groups,n_V*crb_rows,n_H*crb_cols) 39 | x = x[:,:x.shape[1]-pad_groups,:x.shape[2]-pad_rows,:x.shape[3]-pad_cols] 40 | 41 | return x 42 | 43 | 44 | def quantize_int_activation(module, input): 45 | """ 46 | Quantize the current inputs into 8-bit integers and store them as an attribute of the module. 47 | 48 | The function is a forward pre-hook that needs to be manually added to the calibrated model. 49 | You need to collect the cached data before feeding in another batch of images, since it is overwritten on each forward pass. 50 | Currently only 8-bit quantization is supported (int8 for standard quantizers, uint8 for twin quantization). 51 | 52 | For twin quantization: 53 | - For softmax, the MSB being 1 means using the large interval, while the MSB being 0 means using the small interval. 54 | - For post-GELU, the MSB serves as a sign bit: 1 for positive values and 0 for negative values. 
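
    A minimal usage sketch (illustrative; `model` and `images` are hypothetical, and `module`
    is assumed to be one of the calibrated quantized modules handled below):

        handle = module.register_forward_pre_hook(quantize_int_activation)
        model(images)      # after this forward pass, module.int_input holds the integer inputs
        handle.remove()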
55 | """ 56 | if isinstance(module, PostGeluPTQSLQuantLinear) or isinstance(module, PostGeluPTQSLBatchingQuantLinear): 57 | assert module.a_bit == 8, f"module {module}'s activation is quantized with {module.a_bit} bits" 58 | 59 | x = input[0] 60 | 61 | int_input_pos = (x/module.a_interval).round_().clamp_(0, module.a_qmax-1) 62 | int_input_pos = int_input_pos.detach().to(torch.uint8) + 128 63 | 64 | int_input_neg = (x/module.a_neg_interval).round_().clamp_(-module.a_qmax+1, 0).abs() 65 | int_input_neg = int_input_neg.detach().to(torch.uint8) 66 | 67 | int_input = (int_input_pos + int_input_neg).cpu() 68 | module.int_input = [int_input] 69 | 70 | elif isinstance(module, MinMaxQuantLinear): 71 | assert module.a_bit == 8, f"module {module}'s activation is quantized with {module.a_bit} bits" 72 | 73 | x = input[0] 74 | int_input = (x/module.a_interval).round_().clamp_(-module.a_qmax, module.a_qmax-1) 75 | int_input = int_input.cpu().detach().to(torch.int8) 76 | 77 | module.int_input = [int_input] 78 | 79 | elif isinstance(module, SoSPTQSLQuantMatMul) or isinstance(module, SoSPTQSLBatchingQuantMatMul): 80 | assert module.A_bit == 8, f"module {module}'s matrix A is quantized with {module.A_bit} bits" 81 | assert module.B_bit == 8, f"module {module}'s matrix B is quantized with {module.B_bit} bits" 82 | 83 | A, B = input[0], input[1] 84 | 85 | A_high = (A.clamp(module.split, 1)*(module.A_qmax-1)).round_().clamp_(0,module.A_qmax-1) 86 | A_high = A_high.detach().to(torch.uint8) + 128 87 | 88 | A_low = (A.clamp(0, module.split)/module.A_interval).round_().clamp_(0,module.A_qmax-1) 89 | A_low = A_low.detach().to(torch.uint8) 90 | 91 | A_int = (A_high + A_low).cpu() 92 | 93 | B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B) 94 | B_int = B_int.cpu().detach().to(torch.int8) 95 | 96 | module.int_input = [A_int, B_int] 97 | 98 | elif isinstance(module, PTQSLQuantMatMul) or isinstance(module, PTQSLBatchingQuantMatMul): 99 | assert module.A_bit == 8, f"module {module}'s matrix A is quantized with {module.A_bit} bits" 100 | assert module.B_bit == 8, f"module {module}'s matrix B is quantized with {module.B_bit} bits" 101 | 102 | A, B = input[0], input[1] 103 | 104 | A_int = quantize_matmul_input(A,module.A_interval,module.A_qmax,module.n_G_A,module.n_V_A,module.n_H_A,module.crb_groups_A,module.crb_rows_A,module.crb_cols_A) 105 | A_int = A_int.cpu().detach().to(torch.int8) 106 | 107 | B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B) 108 | B_int = B_int.cpu().detach().to(torch.int8) 109 | 110 | module.int_input = [A_int, B_int] 111 | 112 | 113 | def get_model_int_weight(wrapped_modules): 114 | """ 115 | Get quantized weights (in int8) of a model. 116 | 117 | Return: 118 | A dict, with modules' names as keys, and int weights as values. 
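
    A minimal usage sketch (illustrative; `wrapped_modules` is the dict returned by
    utils.net_wrap.wrap_modules_in_net, and the file name is hypothetical):

        int_weights = get_model_int_weight(wrapped_modules)
        torch.save(int_weights, "int_weights.pth")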
119 | """ 120 | 121 | int_weights = {} 122 | 123 | for name, m in wrapped_modules.items(): 124 | try: 125 | int_weights[name] = quantize_int_weight(m) 126 | except: 127 | pass 128 | 129 | return int_weights 130 | -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | from types import MethodType 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import timm 6 | from timm.models import vision_transformer 7 | from timm.models.vision_transformer import Attention 8 | from timm.models.swin_transformer import WindowAttention 9 | 10 | def attention_forward(self, x): 11 | B, N, C = x.shape 12 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 13 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 14 | 15 | # attn = (q @ k.transpose(-2, -1)) * self.scale 16 | attn = self.matmul1(q, k.transpose(-2, -1)) * self.scale 17 | attn = attn.softmax(dim=-1) 18 | attn = self.attn_drop(attn) 19 | del q, k 20 | 21 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C) 22 | x = self.matmul2(attn, v).transpose(1, 2).reshape(B, N, C) 23 | del attn, v 24 | x = self.proj(x) 25 | x = self.proj_drop(x) 26 | return x 27 | 28 | def window_attention_forward(self, x, mask = None): 29 | B_, N, C = x.shape 30 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 31 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 32 | 33 | q = q * self.scale 34 | # attn = (q @ k.transpose(-2, -1)) 35 | attn = self.matmul1(q, k.transpose(-2,-1)) 36 | 37 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 38 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 39 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 40 | attn = attn + relative_position_bias.unsqueeze(0) 41 | 42 | if mask is not None: 43 | nW = mask.shape[0] 44 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 45 | attn = attn.view(-1, self.num_heads, N, N) 46 | attn = self.softmax(attn) 47 | else: 48 | attn = self.softmax(attn) 49 | 50 | attn = self.attn_drop(attn) 51 | 52 | # x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 53 | x = self.matmul2(attn, v).transpose(1, 2).reshape(B_, N, C) 54 | x = self.proj(x) 55 | x = self.proj_drop(x) 56 | return x 57 | 58 | class MatMul(nn.Module): 59 | def forward(self, A, B): 60 | return A @ B 61 | 62 | def get_net(name): 63 | """ 64 | Get a vision transformer model. 65 | This will replace matrix multiplication operations with matmul modules in the model. 66 | 67 | Currently support almost all models in timm.models.transformers, including: 68 | - vit_tiny/small/base/large_patch16/patch32_224/384, 69 | - deit_tiny/small/base(_distilled)_patch16_224, 70 | - deit_base(_distilled)_patch16_384, 71 | - swin_tiny/small/base/large_patch4_window7_224, 72 | - swin_base/large_patch4_window12_384 73 | 74 | These models are finetuned on imagenet-1k and should use ViTImageNetLoaderGenerator 75 | for calibration and testing. 
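
    A minimal usage sketch (illustrative):

        net = get_net("vit_base_patch16_224")   # pretrained timm model, moved to GPU in eval mode
        # every Attention / WindowAttention module now owns matmul1 (q @ k^T) and
        # matmul2 (attention scores @ v) as explicit MatMul submodules that can be wrapped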
76 | """ 77 | net = timm.create_model(name, pretrained=True) 78 | 79 | for name, module in net.named_modules(): 80 | if isinstance(module, Attention): 81 | setattr(module, "matmul1", MatMul()) 82 | setattr(module, "matmul2", MatMul()) 83 | module.forward = MethodType(attention_forward, module) 84 | if isinstance(module, WindowAttention): 85 | setattr(module, "matmul1", MatMul()) 86 | setattr(module, "matmul2", MatMul()) 87 | module.forward = MethodType(window_attention_forward, module) 88 | 89 | net.cuda() 90 | net.eval() 91 | return net 92 | -------------------------------------------------------------------------------- /utils/net_wrap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.models import MatMul 5 | import re 6 | 7 | 8 | def _fold_bn(conv_module, bn_module): 9 | w = conv_module.weight.data 10 | y_mean = bn_module.running_mean 11 | y_var = bn_module.running_var 12 | safe_std = torch.sqrt(y_var + bn_module.eps) 13 | w_view = (conv_module.out_channels, 1, 1, 1) 14 | if bn_module.affine: 15 | weight = w * (bn_module.weight / safe_std).view(w_view) 16 | beta = bn_module.bias - bn_module.weight * y_mean / safe_std 17 | if conv_module.bias is not None: 18 | bias = bn_module.weight * conv_module.bias / safe_std + beta 19 | else: 20 | bias = beta 21 | else: 22 | weight = w / safe_std.view(w_view) 23 | beta = -y_mean / safe_std 24 | if conv_module.bias is not None: 25 | bias = conv_module.bias / safe_std + beta 26 | else: 27 | bias = beta 28 | return weight, bias 29 | 30 | def fold_bn_into_conv(conv_module, bn_module): 31 | w, b = _fold_bn(conv_module, bn_module) 32 | if conv_module.bias is None: 33 | conv_module.bias = nn.Parameter(b.data) 34 | else: 35 | conv_module.bias.data = b.data 36 | conv_module.weight.data = w.data 37 | 38 | 39 | def wrap_modules_in_net(net,cfg): 40 | wrapped_modules={} 41 | module_dict={} 42 | module_types = {"qkv":"qlinear_qkv", "proj":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':"qlinear_MLP_2", 'head':'qlinear_classifier','matmul1':"qmatmul_qk", 'matmul2':"qmatmul_scorev", "reduction": "qlinear_reduction"} 43 | 44 | it=[(name,m) for name,m in net.named_modules()] 45 | for name,m in it: 46 | module_dict[name]=m 47 | idx=name.rfind('.') 48 | if idx==-1: 49 | idx=0 50 | father_name=name[:idx] 51 | if father_name in module_dict: 52 | father_module=module_dict[father_name] 53 | else: 54 | raise RuntimeError(f"father module {father_name} not found") 55 | if isinstance(m,nn.Conv2d): 56 | # Embedding Layer 57 | idx = idx+1 if idx != 0 else idx 58 | new_m=cfg.get_module("qconv",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode) 59 | new_m.weight.data=m.weight.data 60 | new_m.bias=m.bias 61 | replace_m=new_m 62 | wrapped_modules[name] = new_m 63 | setattr(father_module,name[idx:],replace_m) 64 | elif isinstance(m,nn.Linear): 65 | # Linear Layer 66 | idx = idx+1 if idx != 0 else idx 67 | new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features) 68 | new_m.weight.data=m.weight.data 69 | new_m.bias=m.bias 70 | replace_m=new_m 71 | wrapped_modules[name] = new_m 72 | setattr(father_module,name[idx:],replace_m) 73 | elif isinstance(m,MatMul): 74 | # Matmul Layer 75 | idx = idx+1 if idx != 0 else idx 76 | new_m = cfg.get_module(module_types[name[idx:]]) 77 | replace_m=new_m 78 | wrapped_modules[name] = new_m 79 | setattr(father_module,name[idx:],replace_m) 80 | 
print("Completed net wrap.") 81 | return wrapped_modules 82 | 83 | def wrap_certain_modules_in_net(net,cfg,layers,modules_to_wrap,wrap_embedding=False): 84 | """ 85 | wrap specific module inside transformer block of specific layer 86 | layers: list of integers, indicating layers to wrap 87 | modules_to_wrap: list of modules to wrap 88 | """ 89 | wrapped_modules={} 90 | module_dict={} 91 | module_types = {"qkv":"qlinear_qkv", "proj":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':"qlinear_MLP_2", 'head':'qlinear_classifier','matmul1':"qmatmul_qk", 'matmul2':"qmatmul_scorev"} 92 | 93 | it=[(name,m) for name,m in net.named_modules()] 94 | for name,m in it: 95 | module_dict[name]=m 96 | idx=name.rfind('.') 97 | if idx==-1: 98 | idx=0 99 | father_name=name[:idx] 100 | if father_name in module_dict: 101 | father_module=module_dict[father_name] 102 | else: 103 | raise RuntimeError(f"father module {father_name} not found") 104 | layer = re.search('\d+', name) 105 | if layer is not None: # inside a transformer block 106 | layer = int(name[layer.span()[0]:layer.span()[1]]) 107 | if layer not in layers: continue 108 | if isinstance(m,nn.Conv2d): 109 | # Embedding Layer 110 | idx = idx+1 if idx != 0 else idx 111 | if not wrap_embedding: 112 | continue # timm patch_embed use proj as well... 113 | # if name[idx:] not in modules_to_wrap: continue 114 | new_m=cfg.get_module("qconv",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode) 115 | new_m.weight.data=m.weight.data 116 | new_m.bias=m.bias 117 | replace_m=new_m 118 | wrapped_modules[name] = new_m 119 | setattr(father_module,name[idx:],replace_m) 120 | elif isinstance(m,nn.Linear): 121 | # Linear Layer 122 | idx = idx+1 if idx != 0 else idx 123 | if name[idx:] not in modules_to_wrap: continue 124 | new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features) 125 | new_m.weight.data=m.weight.data 126 | new_m.bias=m.bias 127 | replace_m=new_m 128 | wrapped_modules[name] = new_m 129 | setattr(father_module,name[idx:],replace_m) 130 | elif isinstance(m,MatMul): 131 | # Matmul Layer 132 | idx = idx+1 if idx != 0 else idx 133 | if name[idx:] not in modules_to_wrap: continue 134 | new_m = cfg.get_module(module_types[name[idx:]]) 135 | replace_m=new_m 136 | wrapped_modules[name] = new_m 137 | setattr(father_module,name[idx:],replace_m) 138 | print("Completed net wrap.") 139 | return wrapped_modules 140 | -------------------------------------------------------------------------------- /utils/quant_calib.py: -------------------------------------------------------------------------------- 1 | from numpy import isin 2 | import torch 3 | from quant_layers.conv import MinMaxQuantConv2d 4 | from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear 5 | from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | class QuantCalibrator(): 10 | """ 11 | Modularization of quant calib. 12 | 13 | Notice: 14 | all quant modules has method "calibration_step1" that should only store raw inputs and outputs 15 | all quant modules has method "calibration_step2" that should only quantize its intervals 16 | and we assume we could feed in all calibration data in one batch, without backward propagations 17 | 18 | sequential calibration is memory-friendly, while parallel calibration may consume 19 | hundreds of GB of memory. 
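
    A minimal usage sketch (illustrative; `wrapped_modules` comes from wrap_modules_in_net and
    `calib_loader` from a LoaderGenerator's calib_loader()):

        calibrator = QuantCalibrator(net, wrapped_modules, calib_loader, sequential=True)
        calibrator.quant_calib()   # afterwards all wrapped modules run in 'quant_forward' mode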
20 | """ 21 | def __init__(self, net, wrapped_modules, calib_loader, sequential=True): 22 | self.net = net 23 | self.wrapped_modules = wrapped_modules 24 | self.calib_loader = calib_loader 25 | self.sequential = sequential 26 | self.calibrated = False 27 | 28 | def sequential_quant_calib(self): 29 | """ 30 | A quick implementation of calibration. 31 | Assume calibration dataset could be fed at once. 32 | """ 33 | # run calibration 34 | n_calibration_steps=2 35 | for step in range(n_calibration_steps): 36 | print(f"Start calibration step={step+1}") 37 | for name,module in self.wrapped_modules.items(): 38 | # corner cases for calibrated modules 39 | if hasattr(module, "calibrated"): 40 | if step == 1: 41 | module.mode = "raw" 42 | elif step == 2: 43 | module.mode = "quant_forward" 44 | else: 45 | module.mode=f'calibration_step{step+1}' 46 | with torch.no_grad(): 47 | for inp,target in self.calib_loader: 48 | inp=inp.cuda() 49 | self.net(inp) 50 | 51 | # finish calibration 52 | for name,module in self.wrapped_modules.items(): 53 | module.mode='quant_forward' 54 | torch.cuda.empty_cache() # memory footprint cleanup 55 | print("sequential calibration finished") 56 | 57 | def parallel_quant_calib(self): 58 | """ 59 | A quick implementation of parallel quant calib 60 | Assume calibration dataset could be fed at once, and memory could hold all raw inputs/outs 61 | """ 62 | # calibration step1: collect raw data 63 | print(f"Start calibration step=1") 64 | for name,module in self.wrapped_modules.items(): 65 | # corner cases for calibrated modules 66 | if hasattr(module, "calibrated"): 67 | module.mode = "raw" 68 | else: 69 | module.mode=f'calibration_step1' 70 | with torch.no_grad(): 71 | for inp,target in self.calib_loader: 72 | inp=inp.cuda() 73 | self.net(inp) 74 | # calibration step2: each module run calibration with collected raw data 75 | for name,module in self.wrapped_modules.items(): 76 | if hasattr(module, "calibrated"): 77 | continue 78 | else: 79 | module.mode=f"calibration_step2" 80 | with torch.no_grad(): 81 | if isinstance(module, MinMaxQuantLinear): 82 | module.forward(module.raw_input.cuda()) 83 | elif isinstance(module, MinMaxQuantConv2d): 84 | module.forward(module.raw_input.cuda()) 85 | elif isinstance(module, MinMaxQuantMatMul): 86 | module.forward(module.raw_input[0].cuda(), module.raw_input[1].cuda()) 87 | torch.cuda.empty_cache() 88 | 89 | # finish calibration 90 | for name,module in self.wrapped_modules.items(): 91 | module.mode='quant_forward' 92 | torch.cuda.empty_cache() # memory footprint cleanup 93 | print("calibration finished") 94 | 95 | def quant_calib(self): 96 | calib_layers=[] 97 | for name,module in self.wrapped_modules.items(): 98 | calib_layers.append(name) 99 | print(f"prepare parallel calibration for {calib_layers}") 100 | if self.sequential: 101 | self.sequential_quant_calib() 102 | else: 103 | self.parallel_quant_calib() 104 | self.calibrated = True 105 | 106 | def batching_quant_calib(self): 107 | calib_layers=[] 108 | for name,module in self.wrapped_modules.items(): 109 | calib_layers.append(name) 110 | print(f"prepare parallel calibration for {calib_layers}") 111 | 112 | print("start calibration") 113 | 114 | # assume wrapped modules are in order (true for dict in python>=3.5) 115 | q = tqdm(self.wrapped_modules.items(), desc="Brecq") 116 | for name, module in q: 117 | q.set_postfix_str(name) 118 | 119 | # add fp and bp hooks to current modules, which bypass calibration step 1 120 | # precedent modules are using quant forward 121 | hooks = [] 122 | if 
isinstance(module, MinMaxQuantLinear): 123 | hooks.append(module.register_forward_hook(linear_forward_hook)) 124 | if isinstance(module, MinMaxQuantConv2d): 125 | hooks.append(module.register_forward_hook(conv2d_forward_hook)) 126 | if isinstance(module, MinMaxQuantMatMul): 127 | hooks.append(module.register_forward_hook(matmul_forward_hook)) 128 | 129 | # feed in calibration data, and store the data 130 | for inp, target in self.calib_loader: 131 | for batch_st in range(0,self.calib_loader.batch_size,self.batch_size): 132 | self.net.zero_grad() 133 | inp_ = inp[batch_st:batch_st+self.batch_size].cuda() 134 | self.net(inp_) 135 | del inp, target 136 | torch.cuda.empty_cache() 137 | 138 | # replace cached raw_inputs, raw_outs 139 | if isinstance(module, MinMaxQuantLinear): 140 | module.raw_input = torch.cat(module.raw_input, dim=0) 141 | module.raw_out = torch.cat(module.raw_out, dim=0) 142 | if isinstance(module, MinMaxQuantConv2d): 143 | module.raw_input = torch.cat(module.raw_input, dim=0) 144 | module.raw_out = torch.cat(module.raw_out, dim=0) 145 | if isinstance(module, MinMaxQuantMatMul): 146 | module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input] 147 | module.raw_out = torch.cat(module.raw_out, dim=0) 148 | for hook in hooks: 149 | hook.remove() 150 | 151 | # run calibration step2 152 | with torch.no_grad(): 153 | if isinstance(module, MinMaxQuantLinear): 154 | module.calibration_step2() 155 | if isinstance(module, MinMaxQuantConv2d): 156 | module.calibration_step2() 157 | if isinstance(module, MinMaxQuantMatMul): 158 | module.calibration_step2() 159 | torch.cuda.empty_cache() 160 | 161 | # finishing up current module calibration 162 | if self.sequential: 163 | module.mode = "quant_forward" 164 | else: 165 | module.mode = "raw" 166 | 167 | # finish calibration 168 | for name, module in self.wrapped_modules.items(): 169 | module.mode = "quant_forward" 170 | 171 | print("calibration finished") 172 | 173 | def grad_hook(module, grad_input, grad_output): 174 | if module.raw_grad is None: 175 | module.raw_grad = [] 176 | module.raw_grad.append(grad_output[0].cpu().detach()) # that's a tuple! 177 | 178 | def linear_forward_hook(module, input, output): 179 | if module.raw_input is None: 180 | module.raw_input = [] 181 | if module.raw_out is None: 182 | module.raw_out = [] 183 | module.raw_input.append(input[0].cpu().detach()) 184 | module.raw_out.append(output.cpu().detach()) 185 | 186 | def conv2d_forward_hook(module, input, output): 187 | if module.raw_input is None: 188 | module.raw_input = [] 189 | if module.raw_out is None: 190 | module.raw_out = [] 191 | module.raw_input.append(input[0].cpu().detach()) 192 | module.raw_out.append(output.cpu().detach()) 193 | 194 | def matmul_forward_hook(module, input, output): 195 | if module.raw_input is None: 196 | module.raw_input = [[],[]] 197 | if module.raw_out is None: 198 | module.raw_out = [] 199 | module.raw_input[0].append(input[0].cpu().detach()) 200 | module.raw_input[1].append(input[1].cpu().detach()) 201 | module.raw_out.append(output.cpu().detach()) 202 | 203 | class HessianQuantCalibrator(QuantCalibrator): 204 | """ 205 | Modularization of hessian_quant_calib 206 | 207 | Hessian metric needs gradients of layer outputs to weigh the loss, 208 | which calls for back propagation in calibration, both sequentially 209 | and parallelly. Despite the complexity of bp, hessian quant calibrator 210 | is compatible with other non-gradient quantization metrics. 
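
    A minimal usage sketch (illustrative; batch_size is the micro-batch used for the backward
    passes that collect output gradients, not the calib_loader batch size):

        calibrator = HessianQuantCalibrator(net, wrapped_modules, calib_loader,
                                            sequential=False, batch_size=4)
        calibrator.batching_quant_calib()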
211 | """ 212 | def __init__(self, net, wrapped_modules, calib_loader, sequential=False, batch_size=1): 213 | super().__init__(net, wrapped_modules, calib_loader, sequential=sequential) 214 | self.batch_size = batch_size 215 | 216 | def quant_calib(self): 217 | """ 218 | An implementation of original hessian calibration. 219 | """ 220 | 221 | calib_layers=[] 222 | for name,module in self.wrapped_modules.items(): 223 | calib_layers.append(name) 224 | print(f"prepare parallel calibration for {calib_layers}") 225 | 226 | print("start hessian calibration") 227 | 228 | # get raw_pred as target distribution 229 | with torch.no_grad(): 230 | for inp, _ in self.calib_loader: 231 | raw_pred = self.net(inp.cuda()) 232 | raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach() 233 | torch.cuda.empty_cache() 234 | 235 | # assume wrapped modules are in order (true for dict in python>=3.5) 236 | q = tqdm(self.wrapped_modules.items(), desc="Brecq") 237 | for name, module in q: 238 | q.set_postfix_str(name) 239 | 240 | # add fp and bp hooks to current modules, which bypass calibration step 1 241 | # precedent modules are using quant forward 242 | hooks = [] 243 | if isinstance(module, MinMaxQuantLinear): 244 | hooks.append(module.register_forward_hook(linear_forward_hook)) 245 | if isinstance(module, MinMaxQuantConv2d): 246 | hooks.append(module.register_forward_hook(conv2d_forward_hook)) 247 | if isinstance(module, MinMaxQuantMatMul): 248 | hooks.append(module.register_forward_hook(matmul_forward_hook)) 249 | if hasattr(module, "metric") and module.metric == "hessian": 250 | hooks.append(module.register_backward_hook(grad_hook)) 251 | 252 | # feed in calibration data, and store the data 253 | for inp, target in self.calib_loader: 254 | for batch_st in range(0,self.calib_loader.batch_size,self.batch_size): 255 | self.net.zero_grad() 256 | inp_ = inp[batch_st:batch_st+self.batch_size].cuda() 257 | pred = self.net(inp_) 258 | loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction="batchmean") 259 | loss.backward() 260 | del inp, target, pred, loss 261 | torch.cuda.empty_cache() 262 | 263 | # replace cached raw_inputs, raw_outs 264 | if isinstance(module, MinMaxQuantLinear): 265 | module.raw_input = torch.cat(module.raw_input, dim=0) 266 | module.raw_out = torch.cat(module.raw_out, dim=0) 267 | if isinstance(module, MinMaxQuantConv2d): 268 | module.raw_input = torch.cat(module.raw_input, dim=0) 269 | module.raw_out = torch.cat(module.raw_out, dim=0) 270 | if isinstance(module, MinMaxQuantMatMul): 271 | module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input] 272 | module.raw_out = torch.cat(module.raw_out, dim=0) 273 | if hasattr(module, "metric") and module.metric == "hessian": 274 | module.raw_grad = torch.cat(module.raw_grad, dim=0) 275 | for hook in hooks: 276 | hook.remove() 277 | 278 | # run calibration step2 279 | with torch.no_grad(): 280 | if isinstance(module, MinMaxQuantLinear): 281 | module.calibration_step2(module.raw_input.cuda()) 282 | if isinstance(module, MinMaxQuantConv2d): 283 | module.calibration_step2(module.raw_input.cuda()) 284 | if isinstance(module, MinMaxQuantMatMul): 285 | module.calibration_step2(module.raw_input[0].cuda(), module.raw_input[1].cuda()) 286 | torch.cuda.empty_cache() 287 | 288 | # finishing up current module calibration 289 | if self.sequential: 290 | module.mode = "quant_forward" 291 | else: 292 | module.mode = "raw" 293 | 294 | # finish calibration 295 | for name, module in self.wrapped_modules.items(): 
296 | module.mode = "quant_forward" 297 | 298 | print("hessian calibration finished") 299 | 300 | def batching_quant_calib(self): 301 | calib_layers=[] 302 | for name,module in self.wrapped_modules.items(): 303 | calib_layers.append(name) 304 | print(f"prepare parallel calibration for {calib_layers}") 305 | 306 | print("start hessian calibration") 307 | 308 | # get raw_pred as target distribution 309 | with torch.no_grad(): 310 | for inp, _ in self.calib_loader: 311 | raw_pred = self.net(inp.cuda()) 312 | raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach() 313 | torch.cuda.empty_cache() 314 | 315 | # assume wrapped modules are in order (true for dict in python>=3.5) 316 | q = tqdm(self.wrapped_modules.items(), desc="Hessian") 317 | for name, module in q: 318 | q.set_postfix_str(name) 319 | 320 | # add fp and bp hooks to current modules, which bypass calibration step 1 321 | # precedent modules are using quant forward 322 | hooks = [] 323 | if isinstance(module, MinMaxQuantLinear): 324 | hooks.append(module.register_forward_hook(linear_forward_hook)) 325 | if isinstance(module, MinMaxQuantConv2d): 326 | hooks.append(module.register_forward_hook(conv2d_forward_hook)) 327 | if isinstance(module, MinMaxQuantMatMul): 328 | hooks.append(module.register_forward_hook(matmul_forward_hook)) 329 | if hasattr(module, "metric"): 330 | hooks.append(module.register_backward_hook(grad_hook)) 331 | 332 | # feed in calibration data, and store the data 333 | for inp, target in self.calib_loader: 334 | for batch_st in range(0,self.calib_loader.batch_size,self.batch_size): 335 | self.net.zero_grad() 336 | inp_ = inp[batch_st:batch_st+self.batch_size].cuda() 337 | pred = self.net(inp_) 338 | loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction="batchmean") 339 | loss.backward() 340 | del inp, target, pred, loss 341 | torch.cuda.empty_cache() 342 | 343 | # replace cached raw_inputs, raw_outs 344 | if isinstance(module, MinMaxQuantLinear): 345 | module.raw_input = torch.cat(module.raw_input, dim=0) 346 | module.raw_out = torch.cat(module.raw_out, dim=0) 347 | if isinstance(module, MinMaxQuantConv2d): 348 | module.raw_input = torch.cat(module.raw_input, dim=0) 349 | module.raw_out = torch.cat(module.raw_out, dim=0) 350 | if isinstance(module, MinMaxQuantMatMul): 351 | module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input] 352 | module.raw_out = torch.cat(module.raw_out, dim=0) 353 | if hasattr(module, "metric"): 354 | module.raw_grad = torch.cat(module.raw_grad, dim=0) 355 | for hook in hooks: 356 | hook.remove() 357 | 358 | # run calibration step2 359 | with torch.no_grad(): 360 | if isinstance(module, MinMaxQuantLinear): 361 | module.calibration_step2() 362 | if isinstance(module, MinMaxQuantConv2d): 363 | module.calibration_step2() 364 | if isinstance(module, MinMaxQuantMatMul): 365 | module.calibration_step2() 366 | torch.cuda.empty_cache() 367 | 368 | # finishing up current module calibration 369 | if self.sequential: 370 | module.mode = "quant_forward" 371 | else: 372 | module.mode = "raw" 373 | 374 | # finish calibration 375 | for name, module in self.wrapped_modules.items(): 376 | module.mode = "quant_forward" 377 | 378 | print("hessian calibration finished") --------------------------------------------------------------------------------
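
How the utilities above fit together, as a rough end-to-end sketch (illustrative only: `cfg` is a placeholder for a quantization config object exposing get_module() as expected by wrap_modules_in_net, e.g. one built from configs/PTQ4ViT.py, and the dataset path is hypothetical; the scripts under example/ are the repository's actual entry points):

    from utils.models import get_net
    from utils.net_wrap import wrap_modules_in_net
    from utils.datasets import ViTImageNetLoaderGenerator
    from utils.quant_calib import HessianQuantCalibrator

    net = get_net("vit_base_patch16_224")                     # timm ViT with MatMul submodules
    wrapped_modules = wrap_modules_in_net(net, cfg)           # cfg: assumed quantization config
    loaders = ViTImageNetLoaderGenerator("data/imagenet", "imagenet", 32, 32, 16,
                                         kwargs={"model": net})
    calib_loader = loaders.calib_loader(num=32)               # 32 calibration images
    HessianQuantCalibrator(net, wrapped_modules, calib_loader,
                           sequential=False, batch_size=4).batching_quant_calib()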