├── .gitignore ├── LICENSE ├── README.md ├── doc └── ops.md ├── examples ├── ckpt │ ├── mnist_cnn.pth │ ├── mnist_cnn_bn.pth │ ├── mnist_cnn_bn_qat.pth │ └── mnist_cnn_qat.pth ├── models │ ├── mobilenetv2.py │ ├── model.py │ ├── resnet.py │ └── vgg.py ├── ptq │ ├── ptq.py │ └── tflite_ptq.py ├── qat │ └── qat.py ├── script │ ├── ptq_cifar10.sh │ ├── qat_cifar10.sh │ └── train_cifar10.sh ├── train.py └── utils.py ├── img ├── IMG_1294.png └── resnet_onnx.png ├── setup.py ├── test ├── calc_scale_zeropoint_test.py ├── qadd_test.py ├── qconcat_test.py ├── qdiv_test.py ├── qlayernorm_test.py ├── qmatmul_test.py ├── qmean_test.py ├── qmul_test.py ├── qsoftmax_test.py ├── qsqrt_test.py ├── qsub_test.py └── sqrt_interger_test.py └── torchquanter ├── __init__.py ├── nn ├── __init__.py ├── base.py ├── qadd.py ├── qavgpool2d.py ├── qconcat.py ├── qconv2d.py ├── qconvbnrelu.py ├── qdiv.py ├── qlayernorm.py ├── qlinear.py ├── qmatmul.py ├── qmaxpool2d.py ├── qmean.py ├── qmul.py ├── qnorm.py ├── qrelu.py ├── qsigmoid.py ├── qsoftmax.py ├── qsoftmax_w_policy.py ├── qsqrt.py └── qsub.py └── utils ├── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | ckpt 4 | .vscode 5 | *.egg-info 6 | 7 | train_demo.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchQuanter 2 | 3 | TorchQuanter is designed to quantize PyTorch models. 4 | 5 | ## Install 6 | ``` 7 | pip install -e . 
8 | ``` 9 | 10 | ## Quantization specification 11 | ```python 12 | signed 13 | True: int8 14 | False: uint8 15 | 16 | symmetric_weight: 17 | True: int8, zero_point=0 18 | 19 | qmode: 20 | per_tensor: 21 | * "module weights share the same scale and zero_point" 22 | 23 | per_channel: 24 | * "each output channel of module weights has a scale and a zero_point" 25 | * support op: Conv, Depthwise-conv 26 | ``` 27 | 28 | ## Support operations 29 | ``` 30 | Conv2d, Conv2d + BatchNorm2d + ReLU 31 | Linear, Linear + ReLU 32 | ReLU 33 | MaxPool2d 34 | ``` 35 | 36 | ## Support Export ONNX 37 | An ONNX export demo can be found in [examples/ptq/ptq.py](examples/ptq/ptq.py) 38 | -------------------------------------------------------------------------------- /doc/ops.md: -------------------------------------------------------------------------------- 1 | # TorchQuanter operators 2 | 3 | ## LayerNorm 4 | 5 | LayerNorm in PyTorch: 6 | ```python 7 | import torch 8 | 9 | x = torch.rand(32, 64) 10 | mean_ = x.mean(dim=-1, keepdim=True) 11 | var_ = torch.sum((x - mean_)**2, dim=-1, keepdim=True) / x.shape[-1] 12 | 13 | output1 = (x - mean_) / torch.sqrt(var_) 14 | output2 = torch.nn.functional.layer_norm(x, (64,)) 15 | 16 | print(output1.shape) 17 | print(output2.shape) 18 | print(output1[0][0:10]) 19 | print(output2[0][0:10]) 20 | ``` 21 | 22 | Here is the output: 23 | ``` 24 | torch.Size([32, 64]) 25 | torch.Size([32, 64]) 26 | tensor([ 0.3549, -0.2236, 0.0240, -0.1606, 0.7370, 0.9559, 1.2528, 0.6228, 27 | 0.3453, 0.8060]) 28 | tensor([ 0.3549, -0.2236, 0.0240, -0.1605, 0.7370, 0.9558, 1.2528, 0.6228, 29 | 0.3453, 0.8060]) 30 | ``` 31 | 32 | Note that the denominator of `var_` is `n` rather than `n-1`; otherwise it will not produce the same result as `layer_norm`. 33 | 34 | The explanation is that PyTorch uses the biased variance estimator, i.e. `torch.var(x, unbiased=False)`, when calculating the variance. Details can be seen in [link](https://stackoverflow.com/questions/66289517/layer-normalization-in-pytorch) 35 | 36 | ### Quantization inference 37 | 38 | 105 | 106 | Inference consists of normalization * W + bias. Because the division inside the normalization produces fractional values, 107 | the output is scaled up by a factor of $2^{8-1}$; to stay mathematically equivalent, the result must later be divided by $2^{8-1}$, 108 | and that division by $2^{8-1}$ can be folded into the output scale. 109 | 110 | Note that the normalization step cancels the effect of the `input scale` and `input zero_point`, so in the end 111 | ```python 112 | self.M = self.qw.scale / (self.qo.scale * 2**(8 - 1)) 113 | ``` 114 | 115 | QLayerNorm quantized inference: 116 | ```python 117 | def quantize_inference(self, x, mode=None): 118 | x = x - self.qi.zero_point 119 | 120 | # Integer-only LayerNorm 121 | mean_ = x.mean(dim=-1, keepdim=True) # int16 122 | sum_ = torch.sum((x - mean_)**2, dim=-1, keepdim=True).clamp(*get_qmin_qmax(self.max_bits, signed=True)) # clamp to the 32-bit range 123 | var_ = torch.floor(sum_ / x.shape[-1]) 124 | var_[var_ == 0.] = 1. 
# prevent division by zero when the variance is 0 125 | # std_ = sqrt_interger(var_) # time-consuming; not needed for this quick evaluation 126 | std_ = torch.sqrt(var_).floor() 127 | factor = torch.floor(2**(8 - 1) / std_) 128 | x = torch.floor(torch.clamp((x - mean_) * factor, *get_qmin_qmax(16, signed=True))) 129 | 130 | if self.layernorm_module.elementwise_affine: 131 | x = x * self.layernorm_module.weight.data + self.layernorm_module.bias.data 132 | x = x.clamp(*get_qmin_qmax(self.max_bits, signed=True)) 133 | 134 | if mode is None: 135 | x = self.M * x 136 | x.round_() 137 | elif mode == 'cmsis_nn': 138 | multiplier, shift = approximate_float(self.M) 139 | round_ = 1 << (shift - 1) 140 | x = (x * multiplier + round_) >> (31 - shift) 141 | else: 142 | raise Exception(f'Unknown mode {mode}') 143 | x = x + self.qo.zero_point 144 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 145 | return x 146 | ``` 147 | 148 | The complete implementation can be found in: 149 | 1. [QLayerNorm](../torchquanter/nn/qlayernorm.py) 150 | 151 | --------------------------------- 152 | 153 | 154 | ## Softmax 155 | The input and output of Softmax are `int8` or `uint8`. 156 | On a PC it is computed in floating point, while on a microcontroller it is computed in fixed point. 157 | 158 | The input zero_point is mathematically irrelevant for softmax. 159 | 160 | **output scale is fixed to** `1/256`**, zero_point is fixed to** `-128`. 161 | [link1](https://stackoverflow.com/questions/54052091/softmax-tensorflow-lite-not-behaving-properly/54584333#54584333) 162 | 163 | ### Softmax in CMSIS-NN 164 | Note that `arm_softmax_s8` in CMSIS-NN needs the parameters `mult` and `shift`, which are generated differently than for other layers. 165 | 166 | other layers: 167 | ```python 168 | approximate_float(input_scale) 169 | ``` 170 | 171 | generating the `scale` for softmax in CMSIS-NN: 172 | ```python 173 | softmax_input_integer_bits = 5 # the integer part occupies 5 bits of the 8-bit fixed-point format; should not need to change 174 | 175 | # regenerate the softmax input scale into the form the interface expects 176 | input_scale = min(input_scale * (1 << (31 - softmax_input_integer_bits)), 177 | (1 << 31) - 1) 178 | # use this function to obtain the mult and shift required by arm_softmax_s8 179 | approximate_float(input_scale) 180 | ``` 181 | [reference link](https://github.com/ARM-software/CMSIS_5/blob/cf675280148688a50834e7b0496022360e5431cd/CMSIS/NN/Tests/UnitTest/generate_test_data.py#L781) 182 | 183 | 184 | example: 185 | ```python 186 | def test_softmax_s8(): 187 | # data from the official test case 188 | input_data = torch.tensor([-80, -48, 16, 0, -96], dtype=torch.float32) 189 | gold_output = torch.tensor([-128, -125, 56, -60, -128], dtype=torch.float32) 190 | input_mult = 1077952576 191 | input_left_shift = 23 192 | diff_min = -248 # purpose currently unclear 193 | 194 | # softmax does not need input_zero_point; mathematically it does not affect the result 195 | x = input_data - input_data.max() 196 | 197 | # this appears to be the official int8 -> fixed-point conversion 198 | x = ((x * input_mult) >> (31 - input_left_shift)) / (1 << (31 - 5)) 199 | 200 | # feed the fixed-point values directly into the softmax function for testing; the result is correct 201 | out1 = F.softmax(x, dim=-1) 202 | out1 = out1 / (1 / 256.) - 128 # output scale and zero_point are fixed 203 | out1.round_() 204 | assert (out1 == gold_output).all(), print(out1) 205 | ```
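To make the `mult`/`shift` pair less magical, here is a minimal, self-contained sketch of how a float scale is commonly decomposed into a Q31 multiplier and a shift. torchquanter's `approximate_float` may round differently, and the `input_scale` value below is made up for illustration; the result is consumed exactly like the requantization above, `q = (x * mult) >> (31 - shift)`:

```python
import math

def approx_mult_shift(scale: float):
    # scale = mantissa * 2**shift, with 0.5 <= mantissa < 1
    mantissa, shift = math.frexp(scale)
    mult = int(round(mantissa * (1 << 31)))   # Q31 fixed-point multiplier
    return mult, shift

# rescale a hypothetical softmax input scale the way CMSIS-NN expects
input_scale = 0.0627                          # made-up example value
softmax_input_integer_bits = 5
scaled = min(input_scale * (1 << (31 - softmax_input_integer_bits)), (1 << 31) - 1)
mult, shift = approx_mult_shift(scaled)
```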
-------------------------------------------------------------------------------- /examples/ckpt/mnist_cnn.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/examples/ckpt/mnist_cnn.pth -------------------------------------------------------------------------------- /examples/ckpt/mnist_cnn_bn.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/examples/ckpt/mnist_cnn_bn.pth -------------------------------------------------------------------------------- /examples/ckpt/mnist_cnn_bn_qat.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/examples/ckpt/mnist_cnn_bn_qat.pth -------------------------------------------------------------------------------- /examples/ckpt/mnist_cnn_qat.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/examples/ckpt/mnist_cnn_qat.pth -------------------------------------------------------------------------------- /examples/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchquanter.nn import QConvBNReLU, QAdd, QReLU, QMean, QLinear, QAdaptiveAvgPool2d 7 | 8 | 9 | class BaseBlock(nn.Module): 10 | alpha = 1 11 | def __init__(self, input_channel, output_channel, t = 6, downsample = False, **kwargs): 12 | """ 13 | t: expansion factor, t*input_channel is the channel count of the expansion layer 14 | alpha: width multiplier, to get thinner models 15 | rho: resolution multiplier, to get reduced representation 16 | """ 17 | super(BaseBlock, self).__init__() 18 | self.stride = 2 if downsample else 1 19 | self.downsample = downsample 20 | self.shortcut = (not downsample) and (input_channel == output_channel) 21 | # apply alpha 22 | input_channel = int(self.alpha * input_channel) 23 | output_channel = int(self.alpha * output_channel) 24 | # for main path: 25 | c = t * input_channel 26 | # 1x1 point wise conv 27 | self.conv1 = nn.Conv2d(input_channel, c, kernel_size = 1, bias = False) 28 | self.bn1 = nn.BatchNorm2d(c) 29 | # 3x3 depth wise conv 30 | self.conv2 = nn.Conv2d(c, c, kernel_size = 3, stride = self.stride, padding = 1, groups = c, bias = False) 31 | self.bn2 = nn.BatchNorm2d(c) 32 | # 1x1 point wise conv 33 | self.conv3 = nn.Conv2d(c, output_channel, kernel_size = 1, bias = False) 34 | self.bn3 = nn.BatchNorm2d(output_channel) 35 | 36 | def forward(self, inputs): 37 | # main path 38 | x = F.relu(self.bn1(self.conv1(inputs)), inplace = True) 39 | x = F.relu(self.bn2(self.conv2(x)), inplace = True) 40 | x = self.bn3(self.conv3(x)) 41 | # shortcut path 42 | x = x + inputs if self.shortcut else x 43 | return x 44 | 45 | def quantize(self, num_bits=8, signed=True, symmetric_feature=False): 46 | self.qconv1 = QConvBNReLU(self.conv1, self.bn1, qi=False, qo=True, num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 47 | self.qconv2 = QConvBNReLU(self.conv2, self.bn2, qi=False, qo=True, num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature)
48 | self.qconv3 = QConvBNReLU(self.conv3, self.bn3, relu=False, qi=False, qo=True, num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 49 | self.qadd = QAdd(qi1=False, qi2=False, qo=True, num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 50 | 51 | def quantize_forward(self, inputs): 52 | x = inputs 53 | x = self.qconv1(x) 54 | x = self.qconv2(x) 55 | x = self.qconv3(x) 56 | if self.shortcut: 57 | x = self.qadd(x, inputs) 58 | return x 59 | 60 | def freeze(self, input_qi): 61 | self.qconv1.freeze(input_qi) 62 | self.qconv2.freeze(self.qconv1.qo) 63 | self.qconv3.freeze(self.qconv2.qo) 64 | if self.shortcut: 65 | self.qadd.freeze(self.qconv3.qo, input_qi) 66 | return self.qadd.qo 67 | return self.qconv3.qo 68 | 69 | def quantize_inference(self, qx_in, mode=None): 70 | qx = qx_in 71 | qx = self.qconv1.quantize_inference(qx, mode=mode) 72 | qx = self.qconv2.quantize_inference(qx, mode=mode) 73 | qx = self.qconv3.quantize_inference(qx, mode=mode) 74 | if self.shortcut: 75 | qx = self.qadd.quantize_inference(qx, qx_in, mode=mode) 76 | return qx 77 | 78 | 79 | class MobileNetV2(nn.Module): 80 | def __init__(self, output_size = 10, alpha = 1, **kwargs): 81 | super(MobileNetV2, self).__init__() 82 | self.output_size = output_size 83 | 84 | # first conv layer 85 | self.conv0 = nn.Conv2d(3, int(32*alpha), kernel_size = 3, stride = 1, padding = 1, bias = False) 86 | self.bn0 = nn.BatchNorm2d(int(32*alpha)) 87 | 88 | # build bottlenecks 89 | BaseBlock.alpha = alpha 90 | self.bottlenecks = nn.Sequential( 91 | BaseBlock(32, 16, t = 1, downsample = False), 92 | BaseBlock(16, 24, downsample = False), 93 | BaseBlock(24, 24), 94 | BaseBlock(24, 32, downsample = False), 95 | BaseBlock(32, 32), 96 | BaseBlock(32, 32), 97 | BaseBlock(32, 64, downsample = True), 98 | BaseBlock(64, 64), 99 | BaseBlock(64, 64), 100 | BaseBlock(64, 64), 101 | BaseBlock(64, 96, downsample = False), 102 | BaseBlock(96, 96), 103 | BaseBlock(96, 96), 104 | BaseBlock(96, 160, downsample = True), 105 | BaseBlock(160, 160), 106 | BaseBlock(160, 160), 107 | BaseBlock(160, 320, downsample = False)) 108 | 109 | # last conv layers and fc layer 110 | self.conv1 = nn.Conv2d(int(320*alpha), 1280, kernel_size = 1, bias = False) 111 | self.bn1 = nn.BatchNorm2d(1280) 112 | self.fc = nn.Linear(1280, output_size) 113 | 114 | # weights init 115 | self.weights_init() 116 | 117 | 118 | def weights_init(self): 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 122 | m.weight.data.normal_(0, math.sqrt(2. / n))
123 | 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | 128 | 129 | def forward(self, inputs): 130 | # first conv layer 131 | x = F.relu(self.bn0(self.conv0(inputs)), inplace = True) 132 | # assert x.shape[1:] == torch.Size([32, 32, 32]) 133 | # bottlenecks 134 | x = self.bottlenecks(x) 135 | # assert x.shape[1:] == torch.Size([320, 8, 8]) 136 | # last conv layer 137 | x = F.relu(self.bn1(self.conv1(x)), inplace = True) 138 | # assert x.shape[1:] == torch.Size([1280,8,8]) 139 | # global pooling and fc (in place of conv 1x1 in paper) 140 | x = F.adaptive_avg_pool2d(x, 1) 141 | x = x.view(x.shape[0], -1) 142 | x = self.fc(x) 143 | return x 144 | 145 | def quantize(self, num_bits=8, signed=True, symmetric_feature=False): 146 | self.qconv0 = QConvBNReLU(self.conv0, self.bn0, qi=True, qo=True, 147 | num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 148 | for i in range(len(self.bottlenecks)): 149 | self.bottlenecks[i].quantize(num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 150 | self.qconv1 = QConvBNReLU(self.conv1, self.bn1, qi=False, qo=True, 151 | num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 152 | self.qavg = QAdaptiveAvgPool2d(1, qi=False, qo=True, 153 | num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 154 | # self.qavg = QMean(dim=[-1, -2], keepdim=True, qi=False, qo=True, num_bits=num_bits, signed=signed) 155 | self.qfc = QLinear(self.fc, qi=False, qo=True, relu=False, 156 | num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 157 | 158 | def quantize_forward(self, x): 159 | x = self.qconv0(x) 160 | for i in range(len(self.bottlenecks)): 161 | x = self.bottlenecks[i].quantize_forward(x) 162 | x = self.qconv1(x) 163 | x = self.qavg(x) 164 | x = x.view(x.shape[0], -1) 165 | x = self.qfc(x) 166 | return x 167 | 168 | def freeze(self): 169 | """ 170 | After the min/max statistics have been collected, convert the network fully to int8, e.g. turn the weights and biases into int8 171 | """ 172 | self.qconv0.freeze() 173 | tmp_qo = self.qconv0.qo 174 | for i in range(len(self.bottlenecks)): 175 | tmp_qo = self.bottlenecks[i].freeze(tmp_qo) 176 | self.qconv1.freeze(tmp_qo) 177 | self.qavg.freeze(self.qconv1.qo) 178 | # self.qfc.freeze(self.qconv1.qo) 179 | self.qfc.freeze(self.qavg.qo) 180 | 181 | def quantize_inference(self, x, mode='cmsis_nn'): 182 | """ 183 | True quantized inference, using int8 184 | """ 185 | qx = self.qconv0.qi.quantize_tensor(x) 186 | qx = self.qconv0.quantize_inference(qx, mode=mode) 187 | 188 | for i in range(len(self.bottlenecks)): 189 | qx = self.bottlenecks[i].quantize_inference(qx, mode=mode) 190 | qx = self.qconv1.quantize_inference(qx, mode=mode) 191 | qx = self.qavg.quantize_inference(qx, mode=mode) 192 | # out = self.qavg.qo.dequantize_tensor(qx) 193 | qx = qx.view(qx.shape[0], -1) 194 | qx = self.qfc.quantize_inference(qx, mode=mode) 195 | out = self.qfc.qo.dequantize_tensor(qx) 196 | return out 197 | 198 | 199 | def mobilenetv2_quant(pretrained=False, **kwargs): 200 | num_classes = kwargs.get('num_classes', 10) 201 | return MobileNetV2(num_classes) 202 |
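A minimal usage sketch for the quantization hooks above, mirroring the PTQ driver in examples/ptq/ptq.py (the calibration loader `calib_loader` and input batch `x` are assumptions here and must yield normalized image batches):

```python
import torch
from models.mobilenetv2 import mobilenetv2_quant

model = mobilenetv2_quant(num_classes=10).eval()
model.quantize(num_bits=8, signed=True, symmetric_feature=True)  # attach Q-wrappers
with torch.no_grad():
    for data, _ in calib_loader:       # calibration: collect per-layer min/max stats
        model.quantize_forward(data)
model.freeze()                         # fold scales/zero-points, weights -> int8
with torch.no_grad():
    out = model.quantize_inference(x, mode='cmsis_nn')  # integer-only path
```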
-------------------------------------------------------------------------------- /examples/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch 2 | 3 | 4 | 5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 6 | 7 | Deep Residual Learning for Image Recognition 8 | https://arxiv.org/abs/1512.03385v1 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import LayerNorm 13 | from torch.nn import functional as F 14 | from torchquanter.nn import QConvBNReLU, QAdd, QReLU, QMean, QLinear, QAdaptiveAvgPool2d 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | """Basic Block for resnet 18 and resnet 34 19 | 20 | """ 21 | 22 | # BasicBlock and BottleNeck blocks 23 | # have different output sizes; 24 | # we use the class attribute expansion 25 | # to distinguish them 26 | expansion = 1 27 | 28 | def __init__(self, in_channels, out_channels, stride=1): 29 | super().__init__() 30 | 31 | # residual function 32 | self.residual_function = nn.Sequential( 33 | nn.Conv2d(in_channels, 34 | out_channels, 35 | kernel_size=3, 36 | stride=stride, 37 | padding=1, 38 | bias=True), nn.BatchNorm2d(out_channels), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(out_channels, 41 | out_channels * BasicBlock.expansion, 42 | kernel_size=3, 43 | padding=1, 44 | bias=False), 45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)) 46 | 47 | # shortcut 48 | self.shortcut = nn.Sequential() 49 | 50 | # if the shortcut output dimension is not the same as the residual function's, 51 | # use a 1x1 convolution to match the dimension 52 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_channels, 55 | out_channels * BasicBlock.expansion, 56 | kernel_size=1, 57 | stride=stride, 58 | bias=False), 59 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)) 60 | 61 | def forward(self, x): 62 | return nn.ReLU(inplace=True)(self.residual_function(x) + 63 | self.shortcut(x)) 64 | 65 | def quantize(self, first_qi=True, num_bits=8, signed=True, symmetric_feature=False): 66 | self.qresidual_function = nn.Sequential( 67 | QConvBNReLU( 68 | self.residual_function[0], 69 | self.residual_function[1], 70 | qi=first_qi, qo=True, 71 | num_bits=num_bits, signed=signed, 72 | symmetric_feature=symmetric_feature 73 | ), 74 | QConvBNReLU( 75 | self.residual_function[3], 76 | self.residual_function[4], 77 | relu=False, qi=False, qo=True, 78 | num_bits=num_bits, signed=signed, 79 | symmetric_feature=symmetric_feature 80 | ) 81 | ) 82 | if len(self.shortcut) > 0: 83 | self.qshortcut = QConvBNReLU( 84 | self.shortcut[0], self.shortcut[1], relu=False, 85 | qi=False, qo=True, 86 | num_bits=num_bits, signed=signed, 87 | symmetric_feature=symmetric_feature 88 | ) 89 | self.qadd = QAdd(qi1=False, qi2=False, qo=True, num_bits=num_bits, 90 | signed=signed, symmetric_feature=symmetric_feature) 91 | self.qrelu = QReLU(qi=False, num_bits=num_bits, signed=signed, symmetric_feature=symmetric_feature) 92 | 93 | def quantize_forward(self, x): 94 | x1 = self.qresidual_function(x) 95 | if len(self.shortcut) > 0: 96 | x2 = self.qshortcut(x) 97 | else: 98 | x2 = x 99 | x = self.qadd(x1, x2) 100 | x = self.qrelu(x) 101 | return x 102 | 103 | def freeze(self, qi=None): 104 | qo_1 = self.qresidual_function[0].freeze(qi=qi) 105 | qo_1 = self.qresidual_function[1].freeze(qi=qo_1) 106 | if len(self.shortcut) > 0: 107 | qo_2 = self.qshortcut.freeze(qi=qi) 108 | qo = self.qadd.freeze(qi1=qo_1, qi2=qo_2) 109 | else: 110 | qo = self.qadd.freeze(qi1=qo_1, qi2=qi) 111 | qo = self.qrelu.freeze(qi=qo) 112 | return qo 113 | 114 | def quantize_inference(self, qx, mode='cmsis_nn'): 115 | qx1 = self.qresidual_function[0].quantize_inference(qx, mode=mode) 116 | qx1 = self.qresidual_function[1].quantize_inference(qx1, mode=mode)
117 | if len(self.shortcut) > 0: 118 | qx2 = self.qshortcut.quantize_inference(qx, mode=mode) 119 | else: 120 | qx2 = qx 121 | qx = self.qadd.quantize_inference(qx1, qx2, mode=mode) 122 | qx = self.qrelu.quantize_inference(qx, mode=mode) 123 | return qx 124 | 125 | 126 | class BottleNeck(nn.Module): 127 | """Residual block for resnet over 50 layers 128 | 129 | """ 130 | expansion = 4 131 | 132 | def __init__(self, in_channels, out_channels, stride=1): 133 | super().__init__() 134 | self.residual_function = nn.Sequential( 135 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 136 | nn.BatchNorm2d(out_channels), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(out_channels, 139 | out_channels, 140 | stride=stride, 141 | kernel_size=3, 142 | padding=1, 143 | bias=False), 144 | nn.BatchNorm2d(out_channels), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(out_channels, 147 | out_channels * BottleNeck.expansion, 148 | kernel_size=1, 149 | bias=False), 150 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 151 | ) 152 | 153 | self.shortcut = nn.Sequential() 154 | 155 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 156 | self.shortcut = nn.Sequential( 157 | nn.Conv2d(in_channels, 158 | out_channels * BottleNeck.expansion, 159 | stride=stride, 160 | kernel_size=1, 161 | bias=False), 162 | nn.BatchNorm2d(out_channels * BottleNeck.expansion)) 163 | 164 | def forward(self, x): 165 | return nn.ReLU(inplace=True)(self.residual_function(x) + 166 | self.shortcut(x)) 167 | 168 | 169 | class ResNet(nn.Module): 170 | def __init__(self, block, num_block, num_classes=100): 171 | super().__init__() 172 | 173 | self.in_channels = 64 174 | 175 | self.conv1 = nn.Sequential( 176 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 177 | nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 178 | # we use a different input size than the original paper, 179 | # so conv2_x's stride is 1 180 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 181 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 182 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 183 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 184 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 185 | self.fc = nn.Linear(512 * block.expansion, num_classes) 186 | 187 | def _make_layer(self, block, out_channels, num_blocks, stride): 188 | """make resnet layers (by 'layer' I do not mean the same thing 189 | as a neural network layer, e.g. a conv layer); one layer may 190 | contain more than one residual block 191 | 192 | Args: 193 | block: block type, basic block or bottle neck block 194 | out_channels: output depth channel number of this layer 195 | num_blocks: how many blocks per layer 196 | stride: the stride of the first block of this layer 197 | 198 | Return: 199 | return a resnet layer 200 | """ 201 | 202 | # we have num_block blocks per layer; the stride of the first block 203 | # could be 1 or 2, all other blocks always use stride 1 204 | strides = [stride] + [1] * (num_blocks - 1) 205 | layers = [] 206 | for stride in strides: 207 | layers.append(block(self.in_channels, out_channels, stride)) 208 | self.in_channels = out_channels * block.expansion 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | output = self.conv1(x) 214 | output = self.conv2_x(output) 215 | output = self.conv3_x(output) 216 | output = self.conv4_x(output) 217 | output = self.conv5_x(output) 218 | output = self.avg_pool(output) 219 | output = output.view(output.size(0), -1) 220 | output = self.fc(output) 221 | 222 | return output 223 | 224 | def quantize(self, num_bits=8, signed=True): 225 | self.qconv1 = QConvBNReLU(self.conv1[0], self.conv1[1], qi=True, 226 | num_bits=num_bits, signed=signed) 227 | # conv2_x 228 | for block in self.conv2_x: 229 | block.quantize(first_qi=False, num_bits=num_bits, signed=signed) 230 | # conv3_x 231 | for block in self.conv3_x: 232 | block.quantize(first_qi=False, num_bits=num_bits, signed=signed) 233 | # conv4_x 234 | for block in self.conv4_x: 235 | block.quantize(first_qi=False, num_bits=num_bits, signed=signed) 236 | # conv5_x 237 | for block in self.conv5_x: 238 | block.quantize(first_qi=False, num_bits=num_bits, signed=signed) 239 | self.qavg_pool = QAdaptiveAvgPool2d((1, 1), qi=False, qo=True, 240 | num_bits=num_bits, signed=signed) 241 | self.qfc = QLinear(self.fc, relu=False, qi=False, qo=True, 242 | num_bits=num_bits, signed=signed) 243 | 244 | def quantize_forward(self, x): 245 | x = self.qconv1(x) 246 | # conv2_x 247 | for block in self.conv2_x: 248 | x = block.quantize_forward(x) 249 | # conv3_x 250 | for block in self.conv3_x: 251 | x = block.quantize_forward(x) 252 | # conv4_x 253 | for block in self.conv4_x: 254 | x = block.quantize_forward(x) 255 | # conv5_x 256 | for block in self.conv5_x: 257 | x = block.quantize_forward(x) 258 | x = self.qavg_pool(x) 259 | x = x.view(x.size(0), -1) 260 | x = self.qfc(x) 261 | return x 262 | 263 | def freeze(self): 264 | qo = self.qconv1.freeze() 265 | # conv2_x 266 | for block in self.conv2_x: 267 | qo = block.freeze(qi=qo) 268 | # conv3_x 269 | for block in self.conv3_x: 270 | qo = block.freeze(qi=qo) 271 | # conv4_x 272 | for block in self.conv4_x: 273 | qo = block.freeze(qi=qo) 274 | # conv5_x 275 | for block in self.conv5_x: 276 | qo = block.freeze(qi=qo) 277 | qo = self.qavg_pool.freeze(qi=qo) 278 | qo = self.qfc.freeze(qi=qo) 279 | return qo 280 | 281 | def quantize_inference(self, x, mode='cmsis_nn'): 282 | qx = self.qconv1.qi.quantize_tensor(x) 283 | qx = self.qconv1.quantize_inference(qx, mode=mode) 284 | # conv2_x 285 | for block in self.conv2_x: 286 | qx = block.quantize_inference(qx, mode=mode) 287 | # conv3_x 288 | for block in self.conv3_x: 289 | qx = block.quantize_inference(qx, mode=mode) 290 | # conv4_x 291 | for block in self.conv4_x: 292 | qx = block.quantize_inference(qx, mode=mode) 293 | # conv5_x 294 | for block in self.conv5_x: 295 | qx = block.quantize_inference(qx, mode=mode) 296 | qx = self.qavg_pool.quantize_inference(qx, mode=mode)
297 | qx = qx.view(qx.size(0), -1) 298 | qx = self.qfc.quantize_inference(qx, mode=mode) 299 | x = self.qfc.qo.dequantize_tensor(qx) 300 | return x 301 | 302 | 303 | def resnet18_quant(pretrained=False, **kwargs): 304 | """ return a ResNet 18 object 305 | """ 306 | num_classes = 10 307 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes) 308 | 309 | 310 | def resnet34_quant(pretrained=False, **kwargs): 311 | """ return a ResNet 34 object 312 | """ 313 | num_classes = 10 314 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes) 315 | 316 | 317 | def resnet50_quant(pretrained=False, **kwargs): 318 | """ return a ResNet 50 object 319 | """ 320 | num_classes = 10 321 | return ResNet(BottleNeck, [3, 4, 6, 3], num_classes) 322 | 323 | 324 | def resnet101_quant(pretrained=False, **kwargs): 325 | """ return a ResNet 101 object 326 | """ 327 | num_classes = 10 328 | return ResNet(BottleNeck, [3, 4, 23, 3], num_classes) 329 | 330 | 331 | def resnet152_quant(pretrained=False, **kwargs): 332 | """ return a ResNet 152 object 333 | """ 334 | num_classes = 10 335 | return ResNet(BottleNeck, [3, 8, 36, 3], num_classes) 336 | -------------------------------------------------------------------------------- /examples/models/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torchquanter.nn import * 9 | 10 | 11 | __all__ = [ 12 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 13 | 'vgg19_bn', 'vgg19', 14 | ] 15 | 16 | 17 | class Feature(nn.Module): 18 | def __init__(self, cfg, batch_norm=False): 19 | super(Feature, self).__init__() 20 | self.cfg = cfg 21 | self.batch_norm = batch_norm 22 | layers = [] 23 | in_channels = 3 24 | for v in self.cfg: 25 | if v == 'M': 26 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 27 | else: 28 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 29 | if batch_norm: 30 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 31 | else: 32 | layers += [conv2d, nn.ReLU(inplace=True)] 33 | in_channels = v 34 | self.layers = nn.Sequential(*layers) 35 | 36 | def forward(self, x): 37 | return self.layers(x) 38 | 39 | def quantize(self, first_qi=True, num_bits=8, signed=True): 40 | qlayers = [] 41 | i = 0 42 | for v in self.cfg: 43 | if v == 'M': 44 | qlayers += [QMaxPool2d( 45 | self.layers[i], qi=False, 46 | num_bits=num_bits, signed=signed)] 47 | i += 1 48 | elif self.batch_norm: 49 | qlayers += [QConvBNReLU( 50 | self.layers[i], self.layers[i + 1], 51 | relu=True, qi=first_qi, num_bits=num_bits, signed=True)] 52 | i += 3 53 | first_qi = False 54 | else: 55 | qlayers += [QConv2d( 56 | self.layers[i], relu=True, qi=first_qi, 57 | num_bits=num_bits, signed=True)] 58 | i += 2 59 | first_qi = False 60 | self.qlayers = nn.Sequential(*qlayers) 61 | 62 | def quantize_forward(self, x): 63 | return self.qlayers(x) 64 | 65 | def freeze(self, qi=None): 66 | for op in self.qlayers: 67 | qi = op.freeze(qi=qi) 68 | return qi # return last op qo 69 | 70 | def quantize_inference(self, qx, mode='cmsis_nn'): 71 | for op in self.qlayers: 72 | qx = op.quantize_inference(qx, mode=mode) 73 | return qx 74 | 75 | 76 | class VGG(nn.Module): 77 | ''' 78 | VGG model 79 | ''' 80 | def __init__(self, features: Feature): 81 | super(VGG, self).__init__() 82 | self.features = features 83 | self.classifier = nn.Sequential( 84 | nn.Dropout(), 85 | nn.Linear(512, 512), 86 | 
nn.ReLU(True), 87 | nn.Dropout(), 88 | nn.Linear(512, 512), 89 | nn.ReLU(True), 90 | nn.Linear(512, 10), 91 | ) 92 | # Initialize weights 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 96 | m.weight.data.normal_(0, math.sqrt(2. / n)) 97 | m.bias.data.zero_() 98 | 99 | def forward(self, x): 100 | x = self.features(x) 101 | x = x.view(x.size(0), -1) 102 | x = self.classifier(x) 103 | return x 104 | 105 | def quantize(self, num_bits=8, signed=True): 106 | self.features.quantize(first_qi=True, num_bits=num_bits, signed=signed) 107 | self.qclassifier = nn.Sequential( 108 | nn.Dropout(), 109 | QLinear(self.classifier[1], relu=True, qi=False, num_bits=num_bits, signed=signed), 110 | nn.Dropout(), 111 | QLinear(self.classifier[4], relu=True, qi=False, num_bits=num_bits, signed=signed), 112 | QLinear(self.classifier[6], relu=False, qi=False, num_bits=num_bits, signed=signed), 113 | ) 114 | 115 | def quantize_forward(self, x): 116 | x = self.features.quantize_forward(x) 117 | x = x.view(x.size(0), -1) 118 | x = self.qclassifier(x) 119 | return x 120 | 121 | def freeze(self): 122 | qo = self.features.freeze() 123 | for op in self.qclassifier: 124 | if not isinstance(op, nn.Dropout): 125 | qo = op.freeze(qi=qo) 126 | return qo 127 | 128 | def quantize_inference(self, x, mode='cmsis_nn'): 129 | qx = self.features.qlayers[0].qi.quantize_tensor(x) 130 | qx = self.features.quantize_inference(qx, mode=mode) 131 | qx = qx.view(qx.size(0), -1) 132 | for op in self.qclassifier: 133 | if not isinstance(op, nn.Dropout): 134 | qx = op.quantize_inference(qx, mode=mode) 135 | x = self.qclassifier[-1].qo.dequantize_tensor(qx) 136 | return x 137 | 138 | 139 | cfg = { 140 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 141 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 142 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 143 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 144 | 512, 512, 512, 512, 'M'], 145 | } 146 | 147 | 148 | def vgg11(pretrained=False, **kwargs): 149 | """VGG 11-layer model (configuration "A")""" 150 | return VGG(Feature(cfg['A'])) 151 | 152 | 153 | def vgg11_bn(pretrained=False, **kwargs): 154 | """VGG 11-layer model (configuration "A") with batch normalization""" 155 | return VGG(Feature(cfg['A'], batch_norm=True)) 156 | 157 | 158 | def vgg13(pretrained=False, **kwargs): 159 | """VGG 13-layer model (configuration "B")""" 160 | return VGG(Feature(cfg['B'])) 161 | 162 | 163 | def vgg13_bn(pretrained=False, **kwargs): 164 | """VGG 13-layer model (configuration "B") with batch normalization""" 165 | return VGG(Feature(cfg['B'], batch_norm=True)) 166 | 167 | 168 | def vgg16(pretrained=False, **kwargs): 169 | """VGG 16-layer model (configuration "D")""" 170 | return VGG(Feature(cfg['D'])) 171 | 172 | 173 | def vgg16_bn(pretrained=False, **kwargs): 174 | """VGG 16-layer model (configuration "D") with batch normalization""" 175 | return VGG(Feature(cfg['D'], batch_norm=True)) 176 | 177 | 178 | def vgg19(pretrained=False, **kwargs): 179 | """VGG 19-layer model (configuration "E")""" 180 | return VGG(Feature(cfg['E'])) 181 | 182 | 183 | def vgg19_bn(pretrained=False, **kwargs): 184 | """VGG 19-layer model (configuration 'E') with batch normalization""" 185 | return VGG(Feature(cfg['E'], batch_norm=True)) 186 | 
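Before the PTQ driver below, a short sketch of the min/max affine quantization arithmetic that the `per_tensor`/`per_channel` modes in the README imply. This is the textbook formula under stated assumptions, not necessarily torchquanter's exact rounding (its own version lives in torchquanter/utils and is exercised by test/calc_scale_zeropoint_test.py):

```python
import torch

def calc_scale_zero_point(min_val, max_val, num_bits=8, signed=True):
    # asymmetric (affine) quantization: r = scale * (q - zero_point)
    qmin = -(1 << (num_bits - 1)) if signed else 0
    qmax = (1 << (num_bits - 1)) - 1 if signed else (1 << num_bits) - 1
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)  # range must contain 0
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    return scale, zero_point

w = torch.randn(8, 3, 3, 3)  # toy conv weight: 8 output channels
# per_tensor: one (scale, zero_point) shared by the whole tensor
scale, zp = calc_scale_zero_point(w.min().item(), w.max().item())
# per_channel + symmetric_weight: one scale per output channel, zero_point = 0
per_channel_scale = w.abs().amax(dim=(1, 2, 3)) / 127.0
```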
-------------------------------------------------------------------------------- /examples/ptq/ptq.py: -------------------------------------------------------------------------------- 1 | """Post-training quantization (PTQ)""" 2 | 3 | import os, sys 4 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | 11 | from models.model import ( 12 | Model, ModelBN, ModelLinear, ModelShortCut, ModelBNNoReLU, 13 | ModelLayerNorm, ModelAttention, ModelMV2, ModelMV2Naive, ModelDepthwise, 14 | ModelMV2ShortCut, ModelTransformerEncoder, ModelConvEncoder, 15 | TinyFormerSupernetDMTPOnePath 16 | ) 17 | from torchquanter.utils import random_seed 18 | from models.resnet import resnet18_quant 19 | from models.mobilenetv2 import mobilenetv2_quant 20 | from utils import get_loader, export_onnx 21 | 22 | 23 | def _args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--dataset', default='mnist', help='Type of dataset') 26 | parser.add_argument('--dataset-dir', metavar='DIR', default='/tmp', 27 | help='Path to dataset') 28 | parser.add_argument('--mean', type=float, nargs='+', default=[0.1307,], metavar='MEAN', 29 | help='Override mean pixel value of dataset') 30 | parser.add_argument('--std', type=float, nargs='+', default=[0.3081,], metavar='STD', 31 | help='Override std deviation of dataset') 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def full_inference(model, test_loader): 37 | correct = 0 38 | for i, (data, target) in enumerate(test_loader, 1): 39 | data, target = data.to(device), target.to(device) 40 | output = model(data) 41 | pred = output.argmax(dim=1, keepdim=True) 42 | correct += pred.eq(target.view_as(pred)).sum().item() 43 | print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset))) 44 | 45 | def quantize(model: Model, loader): 46 | correct = 0 47 | for i, (data, target) in enumerate(loader, 1): 48 | data, target = data.to(device), target.to(device) 49 | output = model.quantize_forward(data) 50 | pred = output.argmax(dim=1, keepdim=True) 51 | correct += pred.eq(target.view_as(pred)).sum().item() 52 | print('quantization finish') 53 | print('Train set: quantize_forward Accuracy: {:.2f}%\n'.format(100. * correct / len(loader.dataset))) 54 | 55 | def quantize_inference(model, test_loader): 56 | correct = 0 57 | for i, (data, target) in enumerate(test_loader, 1): 58 | data, target = data.to(device), target.to(device) 59 | output = model.quantize_inference(data) 60 | pred = output.argmax(dim=1, keepdim=True) 61 | correct += pred.eq(target.view_as(pred)).sum().item() 62 | print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
63 | 64 | if __name__ == "__main__": 65 | random_seed(seed=42) 66 | args = _args() 67 | 68 | # parameters 69 | batch_size = 64 70 | save_model_dir = 'examples/ckpt' 71 | 72 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 73 | 74 | # dataset 75 | train_loader, test_loader = get_loader(args, batch_size) 76 | 77 | # load model 78 | # model = Model() 79 | # model = ModelBN() 80 | # model = ModelBNNoReLU() 81 | # model = ModelLinear() 82 | # model = ModelShortCut() 83 | # model = ModelLayerNorm() 84 | # model = ModelAttention() 85 | # model = ModelDepthwise() 86 | # model = ModelMV2Naive() 87 | # model = ModelMV2() 88 | # model = ModelMV2ShortCut() 89 | # model = ModelTransformerEncoder() 90 | # model = ModelConvEncoder() 91 | # model = TinyFormerSupernetDMTPOnePath( 92 | # num_classes=10, downsample_layers=1, mv2block_layers=1, 93 | # transformer_layers=1, channel=[8, 8, 8], last_channel=8, 94 | # transformer0_embedding_dim=[16], transformer0_dim_feedforward=[16], 95 | # transformer1_embedding_dim=[16], transformer1_dim_feedforward=[16], 96 | # choice=[1,0,0,0], first_channel=1 97 | # ) 98 | # model = resnet18_quant() 99 | model = mobilenetv2_quant() 100 | 101 | model = model.to(device) 102 | state_dict = torch.load(os.path.join(save_model_dir, f'{args.dataset}_{model._get_name()}.pth'), map_location=device) 103 | model.load_state_dict(state_dict) 104 | 105 | model.eval() 106 | full_inference(model, test_loader) # evaluate the accuracy of the full-precision model 107 | 108 | # quantization 109 | num_bits = 8 110 | print('Quantization bit: %d' % num_bits) 111 | model.quantize(num_bits=num_bits, signed=True, symmetric_feature=True) 112 | model = model.to(device) 113 | model.eval() 114 | quantize(model, train_loader) 115 | model.freeze() 116 | 117 | # quantized inference 118 | quantize_inference(model, test_loader) 119 | 120 | # save parameters 121 | save_path = os.path.join(save_model_dir, f'{args.dataset}_{model._get_name()}_ptq.pth') 122 | torch.save(model.state_dict(), save_path) 123 | 124 | # load parameters 125 | state_dict = torch.load(save_path) 126 | model.load_state_dict(state_dict) 127 | 128 | # quantized inference 129 | quantize_inference(model, test_loader) 130 | 131 | try: 132 | export_onnx(args, model) 133 | except torch.onnx.CheckerError: 134 | pass 135 | except Exception as e: 136 | raise e 137 |
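A convenient sanity check between the two paths exercised above, once the model has been calibrated and frozen (a sketch; `model`, `test_loader` and `device` as in `__main__`):

```python
import torch

@torch.no_grad()
def compare_paths(model, loader, device):
    # compare the float forward pass with the dequantized integer path on one batch
    data, _ = next(iter(loader))
    data = data.to(device)
    out_fp = model(data)                    # full-precision reference
    out_q = model.quantize_inference(data)  # int8 inference, dequantized output
    print('max abs diff:', (out_fp - out_q).abs().max().item())

compare_paths(model, test_loader, device)
```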
-------------------------------------------------------------------------------- /examples/ptq/tflite_ptq.py: -------------------------------------------------------------------------------- 1 | """TFLite post-training quantization""" 2 | 3 | import os, sys 4 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import datasets, transforms 9 | 10 | from models.model import ( 11 | Model, ModelBN, ModelLinear, ModelShortCut, ModelBNNoReLU, 12 | ModelLayerNorm, ModelAttention, ModelMV2, ModelMV2Naive, ModelDepthwise, 13 | ModelMV2ShortCut, ModelTransformerEncoder, ModelConvEncoder, 14 | TinyFormerSupernetDMTPOnePath 15 | ) 16 | from torchquanter.utils import random_seed 17 | from converter import Torch2TFLiteConverter 18 | 19 | 20 | def _args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dataset-dir', metavar='DIR', default='/tmp', 23 | help='path to dataset') 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def full_inference(model, test_loader): 29 | correct = 0 30 | for i, (data, target) in enumerate(test_loader, 1): 31 | data, target = data.to(device), target.to(device) 32 | output = model(data) 33 | pred = output.argmax(dim=1, keepdim=True) 34 | correct += pred.eq(target.view_as(pred)).sum().item() 35 | print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset))) 36 | 37 | def representative_dataset(): 38 | for data, label in test_loader.dataset: 39 | data = data.unsqueeze(0) 40 | yield [data] 41 | 42 | if __name__ == "__main__": 43 | random_seed(seed=42) 44 | args = _args() 45 | 46 | # parameters 47 | batch_size = 64 48 | save_model_dir = 'examples/ckpt' 49 | 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | # dataset 53 | test_loader = torch.utils.data.DataLoader( 54 | datasets.MNIST(args.dataset_dir, train=False, download=False, 55 | transform=transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.1307,), (0.3081,)) 58 | ])), 59 | batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True 60 | ) 61 | 62 | # load model 63 | # model = Model() 64 | # model = ModelBN() 65 | model = ModelBNNoReLU() 66 | # model = ModelLinear() 67 | # model = ModelShortCut() 68 | # model = ModelLayerNorm() 69 | # model = ModelAttention() 70 | # model = ModelDepthwise() 71 | # model = ModelMV2Naive() 72 | # model = ModelMV2() 73 | # model = ModelMV2ShortCut() 74 | # model = ModelTransformerEncoder() 75 | # model = ModelConvEncoder() 76 | # model = TinyFormerSupernetDMTPOnePath( 77 | # num_classes=10, downsample_layers=1, mv2block_layers=1, 78 | # transformer_layers=1, channel=[8, 8, 8], last_channel=8, 79 | # transformer0_embedding_dim=[16], transformer0_dim_feedforward=[16], 80 | # transformer1_embedding_dim=[16], transformer1_dim_feedforward=[16], 81 | # choice=[1,0,0,0], first_channel=1 82 | # ) 83 | 84 | model = model.to(device) 85 | state_dict = torch.load(os.path.join(save_model_dir, f'mnist_{model._get_name()}.pth'), map_location=device) 86 | model.load_state_dict(state_dict) 87 | 88 | model.eval() 89 | full_inference(model, test_loader) # evaluate the accuracy of the full-precision model 90 | 91 | # quantization 92 | tmp_path = '/tmp/model.pth' 93 | torch.save(model, tmp_path) 94 | converter = Torch2TFLiteConverter( 95 | torch_model_path=tmp_path, 96 | tf_model_path='/tmp/model_converter/tf_model', 97 | tflite_model_save_path=os.path.join(save_model_dir, f'mnist_{model._get_name()}.lite'), 98 | target_shape=(28,28,1), 99 | representative_dataset=representative_dataset, 100 | evaluate_loader=test_loader 101 | ) 102 | converter.convert()
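`Torch2TFLiteConverter` comes from an external `converter` module that is not part of this repository. For reference, the full-integer TFLite PTQ it presumably wraps boils down to the standard tf.lite recipe below (a sketch; the saved-model path is a placeholder, and a real `representative_dataset` must yield numpy arrays rather than torch tensors):

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/model_converter/tf_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset  # calibration generator
# force full-integer ops and int8 input/output tensors
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
```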
-------------------------------------------------------------------------------- /examples/qat/qat.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | 10 | from models.model import ( 11 | Model, ModelBN, ModelLinear, ModelShortCut, ModelBNNoReLU, 12 | ModelLayerNorm, ModelAttention, ModelMV2, ModelMV2Naive, ModelDepthwise, 13 | ModelMV2ShortCut, ModelTransformerEncoder, ModelConvEncoder, 14 | TinyFormerSupernetDMTPOnePath, ModelBNSymmetric 15 | ) 16 | from torchquanter.utils import random_seed 17 | from models.resnet import resnet18_quant 18 | from utils import get_loader 19 | 20 | 21 | def _args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', default='mnist', help='Type of dataset') 24 | parser.add_argument('--dataset-dir', metavar='DIR', default='/tmp', 25 | help='Path to dataset') 26 | parser.add_argument('--mean', type=float, nargs='+', default=[0.1307,], metavar='MEAN', 27 | help='Override mean pixel value of dataset') 28 | parser.add_argument('--std', type=float, nargs='+', default=[0.3081,], metavar='STD', 29 | help='Override std deviation of dataset') 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def quantize_aware_training(model: Model, device, train_loader, optimizer, epoch): 35 | model.train() 36 | lossLayer = torch.nn.CrossEntropyLoss() 37 | for batch_idx, (data, target) in enumerate(train_loader, 1): 38 | data, target = data.to(device), target.to(device) 39 | optimizer.zero_grad() 40 | output = model.quantize_forward(data) 41 | assert not torch.isnan(output).any() 42 | loss = lossLayer(output, target) 43 | loss.backward() 44 | optimizer.step() 45 | 46 | if batch_idx % 50 == 0: 47 | print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format( 48 | epoch, batch_idx * len(data), len(train_loader.dataset), loss.item() 49 | )) 50 | 51 | 52 | def quantize_validate(model: Model, device, test_loader): 53 | model.eval() 54 | test_loss = 0 55 | correct = 0 56 | lossLayer = torch.nn.CrossEntropyLoss(reduction='sum') 57 | for data, target in test_loader: 58 | data, target = data.to(device), target.to(device) 59 | output = model.quantize_forward(data) 60 | test_loss += lossLayer(output, target).item() 61 | pred = output.argmax(dim=1, keepdim=True) 62 | correct += pred.eq(target.view_as(pred)).sum().item() 63 | 64 | test_loss /= len(test_loader.dataset) 65 | 66 | print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format( 67 | test_loss, 100. * correct / len(test_loader.dataset) 68 | )) 69 | 70 | 71 | def full_inference(model, test_loader): 72 | correct = 0 73 | for i, (data, target) in enumerate(test_loader, 1): 74 | data, target = data.to(device), target.to(device) 75 | output = model(data) 76 | pred = output.argmax(dim=1, keepdim=True) 77 | correct += pred.eq(target.view_as(pred)).sum().item() 78 | print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset))) 79 | 80 | 81 | def quantize_inference(model, test_loader): 82 | correct = 0 83 | for i, (data, target) in enumerate(test_loader, 1): 84 | data, target = data.to(device), target.to(device) 85 | output = model.quantize_inference(data) 86 | pred = output.argmax(dim=1, keepdim=True) 87 | correct += pred.eq(target.view_as(pred)).sum().item() 88 | print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
* correct / len(test_loader.dataset))) 89 | 90 | 91 | if __name__ == "__main__": 92 | random_seed(seed=42) 93 | args = _args() 94 | 95 | batch_size = 64 96 | epochs = 10 97 | lr = 0.0001 98 | momentum = 0.5 99 | save_model_dir = 'examples/ckpt' 100 | 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | 103 | # dataset 104 | train_loader, test_loader = get_loader(args, batch_size) 105 | 106 | # load the trained full-precision model 107 | # model = Model() 108 | # model = ModelBN() 109 | # model = ModelBNNoReLU() 110 | # model = ModelLinear() 111 | # model = ModelShortCut() 112 | # model = ModelLayerNorm() 113 | # model = ModelAttention() 114 | # model = ModelDepthwise() 115 | # model = ModelMV2Naive() 116 | # model = ModelMV2() 117 | # model = ModelMV2ShortCut() 118 | # model = ModelTransformerEncoder() 119 | # model = ModelConvEncoder() 120 | # model = TinyFormerSupernetDMTPOnePath( 121 | # num_classes=10, downsample_layers=1, mv2block_layers=1, 122 | # transformer_layers=1, channel=[8, 8, 8], last_channel=8, 123 | # transformer0_embedding_dim=[16], transformer0_dim_feedforward=[16], 124 | # transformer1_embedding_dim=[16], transformer1_dim_feedforward=[16], 125 | # choice=[1,0,0,0], first_channel=1 126 | # ) # very sensitive to the learning rate, which needs to be set very small 127 | model = resnet18_quant() 128 | 129 | model = model.to(device) 130 | state_dict = torch.load(os.path.join(save_model_dir, f'{args.dataset}_{model._get_name()}.pth'), map_location=device) 131 | model.load_state_dict(state_dict) 132 | 133 | model.eval() 134 | # full_inference(model, test_loader) # evaluate the full-precision model accuracy 135 | 136 | # init 137 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) 138 | num_bits = 8 139 | model.quantize(num_bits=num_bits, signed=True) 140 | print('Quantization bit: %d' % num_bits) 141 | model = model.to(device) 142 | 143 | # train 144 | model.train() 145 | for epoch in range(1, epochs + 1): 146 | quantize_aware_training(model, device, train_loader, optimizer, epoch) 147 | quantize_validate(model, device, test_loader) 148 | 149 | # save qat model 150 | model.eval() 151 | torch.save(model.state_dict(), os.path.join(save_model_dir, f'{args.dataset}_{model._get_name()}_qat.pth')) 152 | 153 | # fp32 -> int8/uint8 154 | model.freeze() 155 | 156 | # quantized inference 157 | quantize_inference(model, test_loader) 158 | -------------------------------------------------------------------------------- /examples/script/ptq_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 examples/ptq/ptq.py --dataset "cifar10" \ 4 | --dataset-dir "/home/LAB/leifd/dataset/cifar/cifar-10" \ 5 | --mean 0.4914 0.4822 0.4465 --std 0.2470 0.2435 0.2616 6 | -------------------------------------------------------------------------------- /examples/script/qat_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 examples/qat/qat.py --dataset "cifar10" \ 4 | --dataset-dir "/home/LAB/leifd/dataset/cifar/cifar-10" \ 5 | --mean 0.4914 0.4822 0.4465 --std 0.2470 0.2435 0.2616 6 | -------------------------------------------------------------------------------- /examples/script/train_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 examples/train.py --dataset "cifar10" \ 4 | --dataset-dir "/home/LAB/leifd/dataset/cifar/cifar-10" \ 5 | --mean 0.4914 0.4822 0.4465 --std 0.2470 0.2435 0.2616 6 | 
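The `--mean`/`--std` values passed to the CIFAR-10 scripts above are the standard per-channel statistics of the CIFAR-10 training split. A minimal sketch of how they can be recomputed (assuming torchvision is available; `train_set` and `data` are illustrative names, not part of this repo):

import torch
from torchvision import datasets, transforms

# Load the CIFAR-10 training split as float tensors in [0, 1]
train_set = datasets.CIFAR10('/tmp', train=True, download=True,
                             transform=transforms.ToTensor())
# Stack into one (N, 3, H, W) tensor, then reduce over image and spatial dims
data = torch.stack([img for img, _ in train_set])
print(data.mean(dim=(0, 2, 3)))  # ≈ tensor([0.4914, 0.4822, 0.4465])
print(data.std(dim=(0, 2, 3)))   # ≈ tensor([0.2470, 0.2435, 0.2616])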
-------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | """Train the full-precision model""" 2 | 3 | import os, sys 4 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | 11 | from models.model import ( 12 | Model, ModelBN, ModelLinear, ModelShortCut, ModelBNNoReLU, 13 | ModelLayerNorm, ModelAttention, ModelMV2, ModelMV2Naive, ModelDepthwise, 14 | ModelMV2ShortCut, ModelTransformerEncoder, ModelConvEncoder, 15 | TinyFormerSupernetDMTPOnePath 16 | ) 17 | from torchquanter.utils import random_seed 18 | from models.resnet import resnet18_quant 19 | from models.mobilenetv2 import mobilenetv2_quant 20 | from utils import get_loader 21 | 22 | 23 | def _args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--dataset', default='mnist', help='Type of dataset') 26 | parser.add_argument('--dataset-dir', metavar='DIR', default='/tmp', 27 | help='Path to dataset') 28 | parser.add_argument('--mean', type=float, nargs='+', default=[0.1307,], metavar='MEAN', 29 | help='Override mean pixel value of dataset') 30 | parser.add_argument('--std', type=float, nargs='+', default=[0.3081,], metavar='STD', 31 | help='Override std deviation of dataset') 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def train_one_epoch(model, device, train_loader, optimizer, epoch): 37 | model.train() 38 | lossLayer = torch.nn.CrossEntropyLoss() 39 | for batch_idx, (data, target) in enumerate(train_loader): 40 | data, target = data.to(device), target.to(device) 41 | optimizer.zero_grad() 42 | output = model(data) 43 | loss = lossLayer(output, target) 44 | loss.backward() 45 | optimizer.step() 46 | 47 | if batch_idx % 50 == 0: 48 | print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format( 49 | epoch, batch_idx * len(data), len(train_loader.dataset), loss.item() 50 | )) 51 | 52 | def validate(model: nn.Module, device, test_loader): 53 | model.eval() 54 | test_loss = 0 55 | correct = 0 56 | lossLayer = torch.nn.CrossEntropyLoss(reduction='sum') 57 | for data, target in test_loader: 58 | data, target = data.to(device), target.to(device) 59 | output = model(data) 60 | test_loss += lossLayer(output, target).item() 61 | pred = output.argmax(dim=1, keepdim=True) 62 | correct += pred.eq(target.view_as(pred)).sum().item() 63 | 64 | test_loss /= len(test_loader.dataset) 65 | 66 | print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format( 67 | test_loss, 100. 
* correct / len(test_loader.dataset) 68 | )) 69 | 70 | 71 | if __name__ == "__main__": 72 | random_seed(seed=42) 73 | args = _args() 74 | 75 | # parameters 76 | batch_size = 64 77 | test_batch_size = 64 78 | seed = 1 79 | epochs = 10 80 | lr = 0.01 81 | momentum = 0.5 82 | save_model_dir = 'examples/ckpt' 83 | 84 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 85 | 86 | # dataset 87 | train_loader, test_loader = get_loader(args, batch_size) 88 | 89 | # choose model 90 | # model = Model() 91 | # model = ModelBN() 92 | # model = ModelBNNoReLU() 93 | # model = ModelLinear() 94 | # model = ModelShortCut() 95 | # model = ModelLayerNorm() 96 | # model = ModelAttention() 97 | # model = ModelDepthwise() 98 | # model = ModelMV2Naive() 99 | # model = ModelMV2() 100 | # model = ModelMV2ShortCut() 101 | # model = ModelTransformerEncoder() 102 | # model = ModelConvEncoder() 103 | # model = TinyFormerSupernetDMTPOnePath( 104 | # num_classes=10, downsample_layers=1, mv2block_layers=1, 105 | # transformer_layers=1, channel=[8, 8, 8], last_channel=8, 106 | # transformer0_embedding_dim=[16], transformer0_dim_feedforward=[16], 107 | # transformer1_embedding_dim=[16], transformer1_dim_feedforward=[16], 108 | # choice=[1,0,0,0], first_channel=1 109 | # ) 110 | # model = resnet18_quant() 111 | model = mobilenetv2_quant() 112 | 113 | model = model.to(device) 114 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) 115 | 116 | for epoch in range(1, epochs + 1): 117 | train_one_epoch(model, device, train_loader, optimizer, epoch) 118 | validate(model, device, test_loader) 119 | 120 | if save_model_dir is not None: 121 | if not os.path.exists(save_model_dir): 122 | os.makedirs(save_model_dir) 123 | 124 | model_save_path = os.path.join(save_model_dir, f'{args.dataset}_{model._get_name()}.pth') 125 | torch.save(model.state_dict(), model_save_path) 126 | print(f'model is saved to {model_save_path}') 127 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | 5 | def get_loader(args, batch_size): 6 | if args.dataset == 'mnist': 7 | train_loader = torch.utils.data.DataLoader( 8 | datasets.MNIST(args.dataset_dir, train=True, download=True, 9 | transform=transforms.Compose([ 10 | transforms.ToTensor(), 11 | transforms.Normalize(args.mean, args.std) 12 | ])), 13 | batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True 14 | ) 15 | 16 | test_loader = torch.utils.data.DataLoader( 17 | datasets.MNIST(args.dataset_dir, train=False, download=True, 18 | transform=transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize(args.mean, args.std) 21 | ])), 22 | batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True 23 | ) 24 | elif args.dataset == 'cifar10': 25 | train_loader = torch.utils.data.DataLoader( 26 | datasets.CIFAR10(args.dataset_dir, train=True, download=True, 27 | transform=transforms.Compose([ 28 | transforms.ToTensor(), 29 | transforms.Normalize(args.mean, args.std) 30 | ])), 31 | batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True 32 | ) 33 | 34 | test_loader = torch.utils.data.DataLoader( 35 | datasets.CIFAR10(args.dataset_dir, train=False, download=True, 36 | transform=transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize(args.mean, args.std) 39 | ])), 40 | batch_size=batch_size, shuffle=False, 
num_workers=1, pin_memory=True 41 | ) 42 | else: 43 | raise ValueError(f"Unsupported dataset type {args.dataset}") 44 | return train_loader, test_loader 45 | 46 | def export_onnx(args, model): 47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | if args.dataset == 'mnist': 49 | dummy_input = torch.rand(1, 1, 28, 28).to(device) 50 | elif args.dataset == 'cifar10': 51 | dummy_input = torch.rand(1, 3, 32, 32).to(device) 52 | else: 53 | raise ValueError(f"Unsupported dataset type {args.dataset}") 54 | 55 | forward_bk = model.forward 56 | model.forward = model.quantize_inference 57 | torch.onnx.export(model, dummy_input, 58 | 'test.onnx', opset_version=11) 59 | model.forward = forward_bk -------------------------------------------------------------------------------- /img/IMG_1294.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/img/IMG_1294.png -------------------------------------------------------------------------------- /img/resnet_onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/img/resnet_onnx.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", 'r') as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="torchquanter", # package name 8 | version="1.0", # current version 9 | author="roxbili", # author 10 | description="quant for torch", # short description 11 | long_description=long_description, # long description 12 | long_description_content_type="text/markdown", # long description format 13 | url="https://github.com/Roxbili/TorchQuanter", # GitHub URL of the package 14 | packages=setuptools.find_packages(), # automatically find the packages in the project 15 | # package metadata (additional descriptors) 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: OS Independent", 20 | ], 21 | # dependencies 22 | install_requires=[ 23 | "torch" 24 | ], 25 | # required Python version 26 | python_requires=">=3", 27 | ) -------------------------------------------------------------------------------- /test/calc_scale_zeropoint_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchquanter.utils import calcScaleZeroPoint, approximate_float 3 | 4 | if __name__ == '__main__': 5 | tensor = torch.randn(8, 3, 3, 3) 6 | 7 | # per_channel 8 | max_ = tensor.flatten(1).max(dim=-1)[0] 9 | min_ = tensor.flatten(1).min(dim=-1)[0] 10 | print(min_) 11 | print(max_) 12 | 13 | scale, zero_point = calcScaleZeroPoint(min_, max_, symmetric=False) 14 | multiplier, shift = approximate_float(scale) 15 | err = abs(scale - multiplier * 2.0 ** (shift - 31)) # error of the fixed-point approximation scale ≈ multiplier * 2**(shift - 31) 16 | print(f'multiplier: {multiplier}, shift: {shift}, err: {err}') 17 | 18 | scale, zero_point = calcScaleZeroPoint(min_, max_, symmetric=True) 19 | multiplier, shift = approximate_float(scale) 20 | err = abs(scale - multiplier * 2.0 ** (shift - 31)) 21 | print(f'multiplier: {multiplier}, shift: {shift}, err: {err}') 22 | 23 | 24 | # per_tensor 25 | max_ = tensor.max() 26 | min_ = tensor.min() 27 | print(min_) 28 | print(max_) 29 | 30 | scale, zero_point = calcScaleZeroPoint(min_, max_, symmetric=False) 31 | multiplier, shift = approximate_float(scale) 32 | err = abs(scale - multiplier * 2.0 ** (shift - 31)) 
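# --- illustrative aside (not part of the original file) ---
# The error checks in this test assume approximate_float follows the common
# CMSIS-NN/TFLite convention: a float multiplier M is encoded as a signed
# 32-bit `multiplier` plus a `shift` such that M ≈ multiplier * 2**(shift - 31).
# A minimal sketch of such a decomposition (hypothetical helper name):
#
#     import math
#     def approximate_float_sketch(M):
#         significand, shift = math.frexp(M)           # M = significand * 2**shift
#         multiplier = round(significand * (1 << 31))  # int32 mantissa
#         return multiplier, shift
#
# --- end aside ---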
33 | print(f'multiplier: {multiplier}, shift: {shift}, err: {err}') 34 | 35 | scale, zero_point = calcScaleZeroPoint(min_, max_, symmetric=True) 36 | multiplier, shift = approximate_float(scale) 37 | err = abs(scale - multiplier * 2.0 ** (shift - 31)) 38 | print(f'multiplier: {multiplier}, shift: {shift}, err: {err}') -------------------------------------------------------------------------------- /test/qadd_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchquanter.nn import QAdd, QConv2d 7 | from models.model import ModelShortCut 8 | 9 | 10 | torch.manual_seed(0) 11 | 12 | class TestAdd(nn.Module): 13 | def __init__(self): 14 | super(TestAdd, self).__init__() 15 | self.conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) + x 19 | return x 20 | 21 | def quantize(self): 22 | self.qconv = QConv2d(self.conv, qi=True, qo=True, qmode='per_channel') 23 | self.qadd = QAdd(qi1=False, qi2=False, qo=True) 24 | 25 | def quantize_forward(self, x): 26 | x_ = self.qconv(x) 27 | x = self.qadd(x_, x) 28 | return x 29 | 30 | def freeze(self): 31 | self.qconv.freeze() 32 | self.qadd.freeze(self.qconv.qo, self.qconv.qi) 33 | 34 | def quantize_inference(self, x, mode='cmsis_nn'): 35 | qx = self.qconv.qi.quantize_tensor(x) 36 | qx_ = self.qconv.quantize_inference(qx) 37 | qx = self.qadd.quantize_inference(qx_, qx, mode=mode) 38 | out = self.qadd.qo.dequantize_tensor(qx) 39 | return out 40 | 41 | def test_qadd1(): 42 | data = torch.rand(1,1,5,5) 43 | 44 | model = TestAdd() 45 | model(data) 46 | out = model(data).flatten() 47 | 48 | model.eval() 49 | model.quantize() 50 | for _ in range(10): 51 | model.quantize_forward(data) 52 | model.freeze() 53 | 54 | qout_float = model.quantize_inference(data, mode=None).flatten() 55 | err = (out - qout_float).abs().mean() 56 | assert err < 0.1, f'err: {err}' 57 | 58 | def test_qadd2(): 59 | data = torch.rand(1,1,28,28) 60 | model = ModelShortCut() 61 | out = model(data).flatten() 62 | 63 | model.eval() 64 | model.quantize() 65 | for _ in range(10): 66 | model.quantize_forward(data) 67 | model.freeze() 68 | 69 | qout_float = model.quantize_inference(data).flatten() 70 | err = (out - qout_float).abs().mean() 71 | assert err < 0.1, f'err: {err}' 72 | 73 | 74 | if __name__ == '__main__': 75 | test_qadd1() 76 | # test_qadd2() -------------------------------------------------------------------------------- /test/qconcat_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchquanter.nn import QConcat, QConv2d 7 | 8 | 9 | torch.manual_seed(0) 10 | 11 | class TestConcat(nn.Module): 12 | def __init__(self): 13 | super(TestConcat, self).__init__() 14 | self.conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1) 15 | 16 | def forward(self, x): 17 | out = self.conv(x) 18 | out = torch.cat((out, x), dim=1) 19 | return out 20 | 21 | def quantize(self): 22 | self.qconv = QConv2d(self.conv, qi=True, qo=True, qmode='per_channel') 23 | self.qcat = QConcat(dim=1, qi1=False, qi2=False, qo=True) 24 | 25 | def quantize_forward(self, x): 26 | x_ = self.qconv(x) 27 | x = self.qcat(x_, x) 28 | return x 29 | 30 | def freeze(self): 31 | self.qconv.freeze() 32 | self.qcat.freeze(self.qconv.qo, 
self.qconv.qi) 33 | 34 | def quantize_inference(self, x, mode='cmsis_nn'): 35 | qx = self.qconv.qi.quantize_tensor(x) 36 | qx_ = self.qconv.quantize_inference(qx) 37 | qx = self.qcat.quantize_inference(qx_, qx, mode=mode) 38 | out = self.qcat.qo.dequantize_tensor(qx) 39 | return out 40 | 41 | def test_qconcat(): 42 | data = torch.rand(1,1,5,5) 43 | 44 | model = TestConcat() 45 | model(data) 46 | out = model(data).flatten() 47 | 48 | model.eval() 49 | model.quantize() 50 | for _ in range(10): 51 | model.quantize_forward(data) 52 | model.freeze() 53 | 54 | qout_float = model.quantize_inference(data, mode=None).flatten() 55 | err = (out - qout_float).abs().mean() 56 | assert err < 0.1, f'err: {err}' 57 | 58 | 59 | if __name__ == '__main__': 60 | test_qconcat() 61 | -------------------------------------------------------------------------------- /test/qdiv_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchquanter.nn import QDiv 8 | 9 | torch.manual_seed(0) 10 | 11 | class TestDiv(nn.Module): 12 | def __init__(self): 13 | super(TestDiv, self).__init__() 14 | 15 | def forward(self, x1, x2): 16 | x = torch.div(x1, x2) 17 | return x 18 | 19 | def quantize(self): 20 | self.qdiv = QDiv() 21 | 22 | def quantize_forward(self, x1, x2): 23 | x = self.qdiv(x1, x2) 24 | return x 25 | 26 | def freeze(self): 27 | self.qdiv.freeze() 28 | 29 | def quantize_inference(self, x1, x2, mode='cmsis_nn'): 30 | qx1 = self.qdiv.qi1.quantize_tensor(x1) 31 | qx2 = self.qdiv.qi2.quantize_tensor(x2) 32 | qx = self.qdiv.quantize_inference(qx1, qx2, mode=mode) 33 | out = self.qdiv.qo.dequantize_tensor(qx) 34 | return out 35 | 36 | def test_qdiv(): 37 | data1 = torch.rand(1,10) 38 | data2 = torch.rand(1,10) 39 | 40 | model = TestDiv() 41 | model(data1, data2) 42 | out = model(data1, data2).flatten() 43 | 44 | model.eval() 45 | model.quantize() 46 | for _ in range(10): 47 | simulate_out = model.quantize_forward(data1, data2) 48 | model.freeze() 49 | 50 | qout_float = model.quantize_inference(data1, data2).flatten() 51 | err = (out - qout_float).abs().mean() 52 | assert err < 0.1, f'err: {err}' 53 | 54 | if __name__ == '__main__': 55 | test_qdiv() -------------------------------------------------------------------------------- /test/qlayernorm_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchquanter.nn import QLayerNorm, QLayerNormTFLite 8 | # from models.model import ModelShortCut 9 | 10 | torch.manual_seed(0) 11 | 12 | class TestQlayernorm(nn.Module): 13 | def __init__(self): 14 | super(TestQlayernorm, self).__init__() 15 | self.layernorm = nn.LayerNorm(10) 16 | 17 | def forward(self, x): 18 | x = self.layernorm(x) 19 | return x 20 | 21 | def quantize(self): 22 | self.qlayernorm = QLayerNorm(self.layernorm, qi=True, qo=True) 23 | # self.qlayernorm = QLayerNormTFLite(self.layernorm, qi=True, qo=True) 24 | 25 | def quantize_forward(self, x): 26 | x = self.qlayernorm(x) 27 | return x 28 | 29 | def freeze(self): 30 | self.qlayernorm.freeze() 31 | 32 | def quantize_inference(self, x, mode='cmsis_nn'): 33 | qx = self.qlayernorm.qi.quantize_tensor(x) 34 | qx = self.qlayernorm.quantize_inference(qx, mode=mode) 35 | out = self.qlayernorm.qo.dequantize_tensor(qx) 36 | 
return out 37 | 38 | def test_qlayernorm(): 39 | data = torch.rand(1,10) 40 | 41 | model = TestQlayernorm() 42 | model(data) 43 | out = model(data).flatten() 44 | 45 | model.eval() 46 | model.quantize() 47 | for _ in range(10): 48 | simulate_out = model.quantize_forward(data) 49 | model.freeze() 50 | 51 | qout_float = model.quantize_inference(data).flatten() 52 | err = (out - qout_float).abs().mean() 53 | assert err < 0.1, f'err: {err}' 54 | 55 | if __name__ == '__main__': 56 | test_qlayernorm() -------------------------------------------------------------------------------- /test/qmatmul_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchquanter.nn import QLinear, QMatmul 7 | 8 | torch.manual_seed(0) 9 | 10 | class TestMul(nn.Module): 11 | def __init__(self): 12 | super(TestMul, self).__init__() 13 | self.fc = nn.Linear(2, 4) 14 | 15 | def forward(self, x): 16 | x = self.fc(x) 17 | x = torch.matmul(x, x.permute(1,0)) 18 | return x 19 | 20 | def quantize(self): 21 | self.qfc = QLinear(self.fc, qi=True, qo=True) 22 | self.qmatmul = QMatmul(qi1=False, qi2=False) 23 | 24 | def quantize_forward(self, x): 25 | x = self.qfc(x) 26 | x = self.qmatmul(x, x.permute(1,0)) 27 | return x 28 | 29 | def freeze(self): 30 | self.qfc.freeze() 31 | self.qmatmul.freeze(self.qfc.qo, self.qfc.qo) 32 | 33 | def quantize_inference(self, x, mode='cmsis_nn'): 34 | qx = self.qfc.qi.quantize_tensor(x) 35 | qx = self.qfc.quantize_inference(qx, mode=mode) 36 | qx = self.qmatmul.quantize_inference(qx, qx.permute(1,0), mode=mode) 37 | out = self.qmatmul.qo.dequantize_tensor(qx) 38 | return out 39 | 40 | def test_qmatmul(): 41 | data = torch.rand(10,2) 42 | 43 | model = TestMul() 44 | model(data) 45 | out = model(data).flatten() 46 | 47 | model.eval() 48 | model.quantize() 49 | for _ in range(10): 50 | model.quantize_forward(data) 51 | model.freeze() 52 | 53 | qout_float = model.quantize_inference(data, mode=None).flatten() 54 | err = (out - qout_float).abs().mean() 55 | assert err < 0.3, f'err: {err}' 56 | 57 | 58 | if __name__ == '__main__': 59 | test_qmatmul() -------------------------------------------------------------------------------- /test/qmean_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchquanter.nn import QMean 8 | 9 | torch.manual_seed(0) 10 | 11 | class TestMean(nn.Module): 12 | def __init__(self): 13 | super(TestMean, self).__init__() 14 | 15 | def forward(self, x): 16 | x = torch.mean(x) 17 | return x 18 | 19 | def quantize(self): 20 | self.qmean = QMean(dim=-1, qi=True, qo=True) 21 | 22 | def quantize_forward(self, x): 23 | x = self.qmean(x) 24 | return x 25 | 26 | def freeze(self): 27 | self.qmean.freeze() 28 | 29 | def quantize_inference(self, x, mode='cmsis_nn'): 30 | qx = self.qmean.qi.quantize_tensor(x) 31 | qx = self.qmean.quantize_inference(qx) 32 | out = self.qmean.qo.dequantize_tensor(qx) 33 | return out 34 | 35 | def test_mean(): 36 | data = torch.rand(1,512) 37 | 38 | model = TestMean() 39 | model(data) 40 | out = model(data).flatten() 41 | 42 | model.eval() 43 | model.quantize() 44 | for _ in range(10): 45 | simulate_out = model.quantize_forward(data) 46 | model.freeze() 47 | 48 | qout_float = model.quantize_inference(data).flatten() 49 | err = (out - qout_float).abs().mean() 50 | assert err < 0.1, f'err: 
{err}' 51 | 52 | if __name__ == '__main__': 53 | test_mean() -------------------------------------------------------------------------------- /test/qmul_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchquanter.nn import QLinear, QMul 7 | 8 | torch.manual_seed(0) 9 | 10 | class TestMul(nn.Module): 11 | def __init__(self): 12 | super(TestMul, self).__init__() 13 | self.fc = nn.Linear(2, 4) 14 | 15 | def forward(self, x): 16 | x = self.fc(x) 17 | x = torch.mul(x, x) 18 | return x 19 | 20 | def quantize(self): 21 | self.qfc = QLinear(self.fc, qi=True, qo=True) 22 | self.qmul = QMul(qi1=False, qi2=False) 23 | 24 | def quantize_forward(self, x): 25 | x = self.qfc(x) 26 | x = self.qmul(x, x) 27 | return x 28 | 29 | def freeze(self): 30 | self.qfc.freeze() 31 | self.qmul.freeze(self.qfc.qo, self.qfc.qo) 32 | 33 | def quantize_inference(self, x, mode='cmsis_nn'): 34 | qx = self.qfc.qi.quantize_tensor(x) 35 | qx = self.qfc.quantize_inference(qx, mode=mode) 36 | qx = self.qmul.quantize_inference(qx, qx, mode=mode) 37 | out = self.qmul.qo.dequantize_tensor(qx) 38 | return out 39 | 40 | def test_qmul(): 41 | data = torch.rand(1,2) 42 | 43 | model = TestMul() 44 | model(data) 45 | out = model(data).flatten() 46 | 47 | model.eval() 48 | model.quantize() 49 | for _ in range(10): 50 | model.quantize_forward(data) 51 | model.freeze() 52 | 53 | qout_float = model.quantize_inference(data, mode=None).flatten() 54 | err = (out - qout_float).abs().mean() 55 | assert err < 0.3, f'err: {err}' 56 | 57 | 58 | if __name__ == '__main__': 59 | test_qmul() -------------------------------------------------------------------------------- /test/qsoftmax_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torchquanter.nn import QSoftmax 9 | # from models.model import ModelShortCut 10 | 11 | torch.manual_seed(0) 12 | 13 | class TestSoftmax(nn.Module): 14 | def __init__(self): 15 | super(TestSoftmax, self).__init__() 16 | self.softmax = nn.Softmax(dim=-1) 17 | 18 | def forward(self, x): 19 | x = self.softmax(x) 20 | return x 21 | 22 | def quantize(self): 23 | self.qsoftmax = QSoftmax(dim=-1, qi=True, qo=True) 24 | 25 | def quantize_forward(self, x): 26 | x = self.qsoftmax(x) 27 | return x 28 | 29 | def freeze(self): 30 | self.qsoftmax.freeze() 31 | 32 | def quantize_inference(self, x, mode='cmsis_nn'): 33 | qx = self.qsoftmax.qi.quantize_tensor(x) 34 | qx = self.qsoftmax.quantize_inference(qx) 35 | out = self.qsoftmax.qo.dequantize_tensor(qx) 36 | return out 37 | 38 | def test_qsoftmax(): 39 | data = torch.rand(1,4) 40 | 41 | model = TestSoftmax() 42 | model(data) 43 | out = model(data).flatten() 44 | 45 | model.eval() 46 | model.quantize() 47 | for _ in range(10): 48 | simulate_out = model.quantize_forward(data) 49 | model.freeze() 50 | 51 | qout_float = model.quantize_inference(data).flatten() 52 | err = (out - qout_float).abs().mean() 53 | assert err < 0.1, f'err: {err}' 54 | 55 | def test_softmax_s8(): 56 | # data taken from the official test case 57 | input_data = torch.tensor([-80, -48, 16, 0, -96], dtype=torch.float32) 58 | gold_output = torch.tensor([-128, -125, 56, -60, -128], dtype=torch.float32) 59 | input_mult = 1077952576 60 | input_left_shift = 23 61 | diff_min = -248 # purpose unclear for now 62 | 63 | # softmax 
does not need input_zero_point; mathematically it does not affect the result 64 | x = input_data - input_data.max() 65 | 66 | # this should be the official method for converting int8 -> fixed point 67 | x = ((x * input_mult) >> (31 - input_left_shift)) / (1 << (31 - 5)) 68 | 69 | # after converting to fixed point, feed it straight into the softmax function for testing; the result matches 70 | out1 = F.softmax(x, dim=-1) 71 | out1 = out1 / (1 / 256.) - 128 # output scale and zero_point are fixed constants 72 | out1.round_() 73 | assert (out1 == gold_output).all(), print(out1) 74 | 75 | 76 | if __name__ == '__main__': 77 | test_qsoftmax() 78 | test_softmax_s8() -------------------------------------------------------------------------------- /test/qsqrt_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchquanter.nn import QSqrt 8 | 9 | torch.manual_seed(0) 10 | 11 | class TestSqrt(nn.Module): 12 | def __init__(self): 13 | super(TestSqrt, self).__init__() 14 | 15 | def forward(self, x): 16 | x = torch.sqrt(x) 17 | return x 18 | 19 | def quantize(self): 20 | self.qsqrt = QSqrt() 21 | 22 | def quantize_forward(self, x): 23 | x = self.qsqrt(x) 24 | return x 25 | 26 | def freeze(self): 27 | self.qsqrt.freeze() 28 | 29 | def quantize_inference(self, x, mode='cmsis_nn'): 30 | qx = self.qsqrt.qi.quantize_tensor(x) 31 | qx = self.qsqrt.quantize_inference(qx, mode=mode) 32 | out = self.qsqrt.qo.dequantize_tensor(qx) 33 | return out 34 | 35 | def test_qsqrt(): 36 | data = torch.rand(1,10) 37 | 38 | model = TestSqrt() 39 | model(data) 40 | out = model(data).flatten() 41 | 42 | model.eval() 43 | model.quantize() 44 | for _ in range(10): 45 | simulate_out = model.quantize_forward(data) 46 | model.freeze() 47 | 48 | qout_float = model.quantize_inference(data).flatten() 49 | err = (out - qout_float).abs().mean() 50 | assert err < 0.1, f'err: {err}' 51 | 52 | if __name__ == '__main__': 53 | test_qsqrt() -------------------------------------------------------------------------------- /test/qsub_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.getcwd(), 'examples/')) 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchquanter.nn import QSub 8 | 9 | torch.manual_seed(0) 10 | 11 | class TestSub(nn.Module): 12 | def __init__(self): 13 | super(TestSub, self).__init__() 14 | 15 | def forward(self, x1, x2): 16 | x = torch.sub(x1, x2) 17 | return x 18 | 19 | def quantize(self): 20 | self.qsub = QSub() 21 | 22 | def quantize_forward(self, x1, x2): 23 | x = self.qsub(x1, x2) 24 | return x 25 | 26 | def freeze(self): 27 | self.qsub.freeze() 28 | 29 | def quantize_inference(self, x1, x2, mode='cmsis_nn'): 30 | qx1 = self.qsub.qi1.quantize_tensor(x1) 31 | qx2 = self.qsub.qi2.quantize_tensor(x2) 32 | qx = self.qsub.quantize_inference(qx1, qx2) 33 | out = self.qsub.qo.dequantize_tensor(qx) 34 | return out 35 | 36 | def test_qsub(): 37 | data1 = torch.rand(1,4) 38 | data2 = torch.rand(1,4) 39 | 40 | model = TestSub() 41 | model(data1, data2) 42 | out = model(data1, data2).flatten() 43 | 44 | model.eval() 45 | model.quantize() 46 | for _ in range(10): 47 | simulate_out = model.quantize_forward(data1, data2) 48 | model.freeze() 49 | 50 | qout_float = model.quantize_inference(data1, data2).flatten() 51 | err = (out - qout_float).abs().mean() 52 | assert err < 0.1, f'err: {err}' 53 | 54 | if __name__ == '__main__': 55 | test_qsub() -------------------------------------------------------------------------------- 
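For reference, `sqrt_interger` (exercised by the next test file) computes an element-wise floor integer square root. A minimal sketch of one standard way to implement such a function, using integer Newton iteration (the helper name is illustrative; the repo's actual implementation is not shown here):

import torch

def isqrt_floor_sketch(t: torch.Tensor) -> torch.Tensor:
    # Element-wise floor(sqrt(n)) via integer Newton iteration, which avoids
    # float rounding problems for large 32-bit inputs.
    flat = t.flatten().to(torch.int64)
    out = torch.empty_like(flat)
    for i, n in enumerate(flat.tolist()):
        if n < 2:
            out[i] = n
            continue
        x, y = n, (n + 1) // 2
        while y < x:                       # converges to floor(sqrt(n))
            x, y = y, (y + n // y) // 2
        out[i] = x
    return out.reshape(t.shape).float()

# e.g. isqrt_floor_sketch(torch.tensor([15., 16., 17.])) -> tensor([3., 4., 4.])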
/test/sqrt_interger_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchquanter.utils import sqrt_interger 4 | 5 | def sqrt_interger_test(tensor: torch.Tensor): 6 | out1 = torch.sqrt(tensor).floor() 7 | out2 = sqrt_interger(tensor) 8 | 9 | if not (out1 == out2).all(): 10 | print(f' tensor: {tensor}') 11 | print(f' gold: {out1}') 12 | print(f'sqrt_interger: {out2}') 13 | raise Exception 14 | 15 | def data_generator(tensor_size, low=0, high=2**32): 16 | """ 17 | low(inclusive), high(exclusive) 18 | """ 19 | return torch.randint(low=low, high=high, size=tensor_size).float() 20 | 21 | if __name__ == '__main__': 22 | for i in range(1000): 23 | tensor = data_generator((10,)) 24 | sqrt_interger_test(tensor) -------------------------------------------------------------------------------- /torchquanter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roxbili/TorchQuanter/7ed0a9ffb043d9c46231383ead83aaaef97e77a3/torchquanter/__init__.py -------------------------------------------------------------------------------- /torchquanter/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .qconv2d import QConv2d 2 | from .qrelu import QReLU 3 | from .qlinear import QLinear 4 | from .qmaxpool2d import QMaxPool2d 5 | from .qconvbnrelu import QConvBNReLU 6 | from .qadd import QAdd 7 | from .qlayernorm import QLayerNormTFLite, QLayerNorm, QLayerNormFP32 8 | from .qsoftmax import QSoftmax 9 | from .qmul import QMul 10 | from .qmatmul import QMatmul 11 | from .qnorm import QNorm 12 | from .qmean import QMean 13 | from .qsub import QSub 14 | from .qsqrt import QSqrt 15 | from .qdiv import QDiv 16 | from .qavgpool2d import QAvgPool2d, QAdaptiveAvgPool2d 17 | from .qsigmoid import QSigmoid 18 | from .qconcat import QConcat 19 | from .qsoftmax_w_policy import QSoftmax_W_Policy -------------------------------------------------------------------------------- /torchquanter/nn/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | from torchquanter.utils import ( 5 | calcScaleZeroPoint, quantize_tensor, dequantize_tensor, 6 | get_qmin_qmax, approximate_float, broadcast_dim_as) 7 | from torch.autograd import Function 8 | 9 | 10 | class QParam(nn.Module): 11 | """ 12 | Quantization parameters recorder 13 | """ 14 | def __init__(self, num_bits=8, signed=True, symmetric=False, momentum=0.9): 15 | """ 16 | Args 17 | ---------- 18 | signed: bool, True for int8, False for uint8 19 | symmetric: bool, True for symmetric quantization(zero_point=0) 20 | momentum: the value used for the running_mean and running_var computation. 
Default: 0.9 21 | """ 22 | super(QParam, self).__init__() 23 | self.num_bits = num_bits 24 | self.signed = signed 25 | self.symmetric = symmetric 26 | self.momentum = momentum 27 | self.qmin, self.qmax = get_qmin_qmax(num_bits, signed) 28 | 29 | scale = torch.tensor([], requires_grad=False) 30 | zero_point = torch.tensor([], requires_grad=False) 31 | 32 | # register for saving parameters when calling torch.save API 33 | self.register_buffer('scale', scale) 34 | self.register_buffer('zero_point', zero_point) 35 | 36 | # check validity of initial parameters 37 | self._init_check() 38 | 39 | def _init_check(self): 40 | assert not (self.signed == False and self.symmetric == True), \ 41 | 'Symmetric quantization is only supported with signed quantization parameters.' 42 | 43 | def update(self, tensor): 44 | """ 45 | update the max and min from the tensor, 46 | calculate the scale and zero point 47 | """ 48 | raise NotImplementedError('update function is not implemented') 49 | 50 | def quantize_tensor(self, tensor): 51 | return quantize_tensor(tensor, self.scale, self.zero_point, num_bits=self.num_bits, signed=self.signed) 52 | 53 | def dequantize_tensor(self, q_x): 54 | return dequantize_tensor(q_x, self.scale, self.zero_point) 55 | 56 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 57 | """ 58 | load parameters from state_dict 59 | """ 60 | pass 61 | 62 | def __str__(self): 63 | pass 64 | 65 | 66 | class QParamIO(QParam): 67 | def __init__(self, num_bits=8, signed=True, symmetric=False, momentum=0.1): 68 | """ 69 | Args 70 | ---------- 71 | signed: bool, True for int8, False for uint8 72 | symmetric: bool, True for symmetric quantization(zero_point=0) 73 | momentum: the value used for the running_mean and running_var computation. 
Default: 0.1 74 | """ 75 | super(QParamIO, self).__init__(num_bits=num_bits, signed=signed, symmetric=symmetric, momentum=momentum) 76 | 77 | running_min = torch.tensor([], requires_grad=False) 78 | running_max = torch.tensor([], requires_grad=False) 79 | 80 | # register for saving parameters when calling torch.save API 81 | self.register_buffer('running_min', running_min) 82 | self.register_buffer('running_max', running_max) 83 | 84 | def update(self, tensor): 85 | """ 86 | update the max and min from the tensor, 87 | calculate the scale and zero point 88 | """ 89 | if self.running_max.nelement() == 0: 90 | self.running_max.data = tensor.max().data 91 | else: # exponential moving average update min and max 92 | self.running_max.data = (1.0 - self.momentum) * self.running_max.data + self.momentum * tensor.max().data 93 | 94 | if self.running_min.nelement() == 0: 95 | self.running_min.data = tensor.min().data 96 | else: 97 | self.running_min.data = (1.0 - self.momentum) * self.running_min.data + self.momentum * tensor.min().data 98 | self.scale, self.zero_point = calcScaleZeroPoint(self.running_min, self.running_max, 99 | self.num_bits, signed=self.signed, symmetric=self.symmetric) 100 | 101 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 102 | key_names = ['scale', 'zero_point', 'running_min', 'running_max'] 103 | for key in key_names: 104 | value = getattr(self, key) 105 | value.data = state_dict[prefix + key].data 106 | state_dict.pop(prefix + key) 107 | 108 | def __str__(self): 109 | # info = 'scale: %.10f ' % self.scale 110 | # info += 'zero_point: %d ' % self.zero_point 111 | # info += 'running_min: %.6f ' % self.running_min 112 | # info += 'running_max: %.6f ' % self.running_max 113 | info = f'scale: {self.scale} zero_point: {self.zero_point}\n' \ 114 | f'running_min: {self.running_min} running_max: {self.running_max}' 115 | return info 116 | 117 | 118 | class QParamW(QParam): 119 | def __init__(self, num_bits=8, signed=True, symmetric=True, momentum=0.1, qmode='per_channel'): 120 | """ 121 | Args 122 | ---------- 123 | signed: bool, True for int8, False for uint8 124 | symmetric: bool, True for symmetric quantization(zero_point=0) 125 | momentum: the value used for the running_mean and running_var computation. Default: 0.1 126 | qmode: str, per_tensor or per_channel. per_channel quantize along tensor axis 0. 
127 | """ 128 | self.qmode = qmode 129 | super(QParamW, self).__init__(num_bits=num_bits, signed=signed, symmetric=symmetric, momentum=momentum) 130 | 131 | min = torch.tensor([], requires_grad=False) 132 | max = torch.tensor([], requires_grad=False) 133 | 134 | # register for saving parameters when calling torch.save API 135 | self.register_buffer('min', min) 136 | self.register_buffer('max', max) 137 | 138 | def _init_check(self): 139 | super(QParamW, self)._init_check() 140 | assert self.qmode in ['per_tensor', 'per_channel'], \ 141 | f"Only 'per_tensor' or 'per_channel' mode is supported" 142 | 143 | def update(self, tensor): 144 | """ 145 | update the max and min from the tensor, 146 | calculate the scale and zero point 147 | """ 148 | if self.qmode == 'per_tensor': 149 | tensor_max = tensor.max() 150 | tensor_min = tensor.min() 151 | elif self.qmode == 'per_channel': 152 | tensor_max = tensor.flatten(1).max(dim=-1)[0] # .view(-1, *[1 for _ in range(tensor.dim() - 1)]) 153 | tensor_min = tensor.flatten(1).min(dim=-1)[0] # .view(-1, *[1 for _ in range(tensor.dim() - 1)]) 154 | else: 155 | raise Exception("Only 'per_tensor' or 'per_channel' mode is supported") 156 | 157 | if self.max.nelement() == 0: 158 | self.max.data = tensor_max.data 159 | else: # exponential moving average update min and max 160 | self.max.data = (1.0 - self.momentum) * self.max.data + self.momentum * tensor_max.data 161 | 162 | if self.min.nelement() == 0: 163 | self.min.data = tensor_min.data 164 | else: # exponential moving average update min and max 165 | self.min.data = (1.0 - self.momentum) * self.min.data + self.momentum * tensor_min.data 166 | 167 | self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, 168 | self.num_bits, signed=self.signed, symmetric=self.symmetric) 169 | 170 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 171 | key_names = ['scale', 'zero_point', 'min', 'max'] 172 | for key in key_names: 173 | value = getattr(self, key) 174 | value.data = state_dict[prefix + key].data 175 | state_dict.pop(prefix + key) 176 | 177 | def __str__(self): 178 | # info = 'scale: %.10f ' % self.scale 179 | # info += 'zero_point: %d ' % self.zero_point 180 | # info += 'min: %.6f ' % self.min 181 | # info += 'max: %.6f ' % self.max 182 | info = f'scale: {self.scale} zero_point: {self.zero_point}\n' \ 183 | f'min: {self.min} max: {self.max}' 184 | return info 185 | 186 | 187 | class QModule(nn.Module): 188 | def __init__(self, qi=True, qo=True, num_bits=8, signed=True, symmetric=False): 189 | super(QModule, self).__init__() 190 | if qi: 191 | self.qi = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric) 192 | if qo: 193 | self.qo = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric) 194 | 195 | M = torch.tensor([], requires_grad=False) 196 | self.register_buffer('M', M) 197 | self.freeze_flag = False 198 | 199 | @abstractmethod 200 | def freeze(self): 201 | raise NotImplementedError('freeze should be implemented.') 202 | 203 | @abstractmethod 204 | def quantize_inference(self, x, mode=None): 205 | """ 206 | Args 207 | ---------- 208 | x: float 209 | mode: None or cmsis_nn. Inference mode, None means use float multiplying. default None. 
210 | """ 211 | raise NotImplementedError('quantize_inference should be implemented.') 212 | 213 | 214 | class FakeQuantize(Function): 215 | 216 | @staticmethod 217 | def forward(ctx, x, qparam: QParam): 218 | x = qparam.quantize_tensor(x) 219 | x = qparam.dequantize_tensor(x) 220 | return x 221 | 222 | @staticmethod 223 | def backward(ctx, grad_output): 224 | return grad_output, None 225 | 226 | class QuantizeTensor(Function): 227 | 228 | @staticmethod 229 | def forward(ctx, x, qparam: QParam): 230 | qx = qparam.quantize_tensor(x) 231 | return qx 232 | @staticmethod 233 | def backward(ctx, grad_output): 234 | return grad_output, None 235 | 236 | class DequantizeTensor(Function): 237 | 238 | @staticmethod 239 | def forward(ctx, qx, qparam: QParam): 240 | x = qparam.dequantize_tensor(qx) 241 | return x 242 | @staticmethod 243 | def backward(ctx, grad_output): 244 | return grad_output, None 245 | 246 | class FloorSTE(Function): 247 | """ 248 | Straight-through Estimator(STE) for torch.floor() 249 | """ 250 | 251 | @staticmethod 252 | def forward(ctx, x): 253 | return torch.floor(x) 254 | 255 | @staticmethod 256 | def backward(ctx, grad_output): 257 | return grad_output.clone() 258 | 259 | class RoundSTE(Function): 260 | 261 | @staticmethod 262 | def forward(ctx, x): 263 | return torch.round(x) 264 | 265 | @staticmethod 266 | def backward(ctx, grad_output): 267 | return grad_output.clone() 268 | 269 | class ClampSTE(Function): 270 | 271 | @staticmethod 272 | def forward(ctx, x, qmin, qmax): 273 | return torch.clamp(x, qmin, qmax) 274 | 275 | @staticmethod 276 | def backward(ctx, grad_output): 277 | return grad_output, None, None 278 | -------------------------------------------------------------------------------- /torchquanter/nn/qadd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, QParamIO 6 | from torchquanter.utils import broadcast_dim_as, approximate_float 7 | 8 | class QAdd(QModule): 9 | 10 | def __init__(self, qi1=True, qi2=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 11 | super(QAdd, self).__init__(qi=False, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 12 | if qi1: 13 | self.qi1 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 14 | if qi2: 15 | self.qi2 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 16 | self.num_bits = num_bits 17 | self.signed = signed 18 | 19 | M1 = torch.tensor([], requires_grad=False) 20 | self.register_buffer('M1', M1) 21 | M2 = torch.tensor([], requires_grad=False) 22 | self.register_buffer('M2', M2) 23 | 24 | def freeze(self, qi1=None, qi2=None, qo=None): 25 | 26 | if hasattr(self, 'qi1') and qi1 is not None: 27 | raise ValueError('qi1 has been provided in init function.') 28 | if not hasattr(self, 'qi1') and qi1 is None: 29 | raise ValueError('qi1 does not exist and should be provided.') 30 | 31 | if hasattr(self, 'qi2') and qi2 is not None: 32 | raise ValueError('qi2 has been provided in init function.') 33 | if not hasattr(self, 'qi2') and qi2 is None: 34 | raise ValueError('qi2 does not exist and should be provided.') 35 | 36 | if hasattr(self, 'qo') and qo is not None: 37 | raise ValueError('qo has been provided in init function.') 38 | if not hasattr(self, 'qo') and qo is None: 39 | raise ValueError('qo does not exist and should be provided.') 40 | self.freeze_flag = True 41 | 42 | if qi1 is not None: 43 | self.qi1 = qi1 44 | if qi2 is 
not None: 45 | self.qi2 = qi2 46 | if qo is not None: 47 | self.qo = qo 48 | self.M1 = self.qi1.scale / self.qo.scale 49 | self.M2 = self.qi2.scale / self.qo.scale 50 | return self.qo 51 | 52 | def forward(self, x1, x2): 53 | if hasattr(self, 'qi1'): 54 | self.qi1.update(x1) 55 | x1 = FakeQuantize.apply(x1, self.qi1) 56 | if hasattr(self, 'qi2'): 57 | self.qi2.update(x2) 58 | x2 = FakeQuantize.apply(x2, self.qi2) 59 | if self.freeze_flag: 60 | raise Exception(f'{self._get_name()} has been frozen') 61 | 62 | out = x1 + x2 63 | 64 | if hasattr(self, 'qo'): 65 | self.qo.update(out) 66 | out = FakeQuantize.apply(out, self.qo) 67 | 68 | return out 69 | 70 | def quantize_inference(self, x1, x2, mode=None): 71 | x1 = x1 - self.qi1.zero_point 72 | x2 = x2 - self.qi2.zero_point 73 | if mode is None: 74 | x1 = self.M1 * x1 75 | x2 = self.M2 * x2 76 | out = x1 + x2 77 | out.round_() 78 | elif mode == 'cmsis_nn': 79 | out = ReScaleAdd.apply(x1, x2, self.M1, self.M2) 80 | else: 81 | raise Exception(f'Unknown mode {mode}') 82 | out = out + self.qo.zero_point 83 | out.clamp_(self.qo.qmin, self.qo.qmax).round_() 84 | return out 85 | 86 | 87 | class ReScaleAdd(torch.autograd.Function): 88 | @staticmethod 89 | def symbolic(g, x1, x2, M1, M2): 90 | return g.op("ReScale", x1, x2, M1, M2) 91 | 92 | @staticmethod 93 | def forward(ctx, x1, x2, M1, M2): 94 | multiplier1, shift1 = approximate_float(M1) 95 | round1 = 1 << (shift1 - 1) 96 | multiplier2, shift2 = approximate_float(M2) 97 | round2 = 1 << (shift2 - 1) 98 | 99 | x1 = (x1 * multiplier1 + round1) >> (31 - shift1) 100 | x2 = (x2 * multiplier2 + round2) >> (31 - shift2) 101 | out = x1 + x2 102 | return out 103 | -------------------------------------------------------------------------------- /torchquanter/nn/qavgpool2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .base import QModule, QParam, FakeQuantize, FloorSTE 5 | from torchquanter.utils import quantize_tensor, approximate_float 6 | 7 | class QAvgPool2d(QModule): 8 | 9 | def __init__(self, avgpool_module, qi=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 10 | super(QAvgPool2d, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 11 | self.avgpool_module = avgpool_module 12 | self.num_bits = num_bits 13 | self.signed = signed 14 | 15 | def freeze(self, qi=None, qo=None): 16 | 17 | if hasattr(self, 'qi') and qi is not None: 18 | raise ValueError('qi has been provided in init function.') 19 | if not hasattr(self, 'qi') and qi is None: 20 | raise ValueError('qi does not exist and should be provided.') 21 | 22 | if hasattr(self, 'qo') and qo is not None: 23 | raise ValueError('qo has been provided in init function.') 24 | if not hasattr(self, 'qo') and qo is None: 25 | raise ValueError('qo does not exist and should be provided.') 26 | self.freeze_flag = True  # added for parity with QAdaptiveAvgPool2d below 27 | if qi is not None: 28 | self.qi = qi 29 | if qo is not None: 30 | self.qo = qo 31 | self.M = self.qi.scale / self.qo.scale 32 | return self.qo 33 | 34 | def forward(self, x): 35 | if hasattr(self, 'qi'): 36 | self.qi.update(x) 37 | x = FakeQuantize.apply(x, self.qi) 38 | if self.freeze_flag: raise Exception(f'{self._get_name()} has been frozen')  # guard added for parity with QAdaptiveAvgPool2d 39 | x = self.avgpool_module(x) 40 | 41 | if hasattr(self, 'qo'): 42 | self.qo.update(x) 43 | x = FakeQuantize.apply(x, self.qo) 44 | return x 45 | 46 | def quantize_inference(self, x, mode=None): 47 | x = x - self.qi.zero_point 48 | x = self.avgpool_module(x).floor() 49 | if mode is None: 50 | x = self.M * x 51 | x.round_() 52 | elif mode == 
'cmsis_nn': 53 | x = ReScaleAvgPool.apply(x, self.M) 54 | else: 55 | raise Exception(f'Unknown mode {mode}') 56 | x = x + self.qo.zero_point 57 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 58 | return x 59 | 60 | 61 | class QAdaptiveAvgPool2d(QModule): 62 | 63 | def __init__(self, output_size, qi=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 64 | super(QAdaptiveAvgPool2d, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 65 | self.output_size = output_size 66 | self.num_bits = num_bits 67 | self.signed = signed 68 | 69 | def freeze(self, qi=None, qo=None): 70 | 71 | if hasattr(self, 'qi') and qi is not None: 72 | raise ValueError('qi has been provided in init function.') 73 | if not hasattr(self, 'qi') and qi is None: 74 | raise ValueError('qi does not exist and should be provided.') 75 | 76 | if hasattr(self, 'qo') and qo is not None: 77 | raise ValueError('qo has been provided in init function.') 78 | if not hasattr(self, 'qo') and qo is None: 79 | raise ValueError('qo does not exist and should be provided.') 80 | self.freeze_flag = True 81 | 82 | if qi is not None: 83 | self.qi = qi 84 | if qo is not None: 85 | self.qo = qo 86 | self.M = self.qi.scale / self.qo.scale 87 | return self.qo 88 | 89 | def forward(self, x): 90 | if hasattr(self, 'qi'): 91 | self.qi.update(x) 92 | x = FakeQuantize.apply(x, self.qi) 93 | if self.freeze_flag: 94 | raise Exception(f'{self._get_name()} has been frozen') 95 | 96 | x = F.adaptive_avg_pool2d(x, self.output_size) 97 | 98 | if hasattr(self, 'qo'): 99 | self.qo.update(x) 100 | x = FakeQuantize.apply(x, self.qo) 101 | return x 102 | 103 | def quantize_inference(self, x, mode=None): 104 | x = x - self.qi.zero_point 105 | x = F.adaptive_avg_pool2d(x, self.output_size).floor() 106 | if mode is None: 107 | x = self.M * x 108 | x.round_() 109 | elif mode == 'cmsis_nn': 110 | x = ReScaleAvgPool.apply(x, self.M) 111 | else: 112 | raise Exception(f'Unknown mode {mode}') 113 | x = x + self.qo.zero_point 114 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 115 | return x 116 | 117 | 118 | class ReScaleAvgPool(torch.autograd.Function): 119 | @staticmethod 120 | def symbolic(g, x, M): 121 | return g.op("ReScale", x, M) 122 | 123 | @staticmethod 124 | def forward(ctx, x, M): 125 | multiplier, shift = approximate_float(M) 126 | round_ = 1 << (shift - 1) 127 | x = (x * multiplier + round_) >> (31 - shift) 128 | return x 129 | -------------------------------------------------------------------------------- /torchquanter/nn/qconcat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, QParamIO 6 | from torchquanter.utils import broadcast_dim_as, approximate_float 7 | 8 | class QConcat(QModule): 9 | 10 | def __init__(self, dim, qi1=True, qi2=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 11 | super(QConcat, self).__init__(qi=False, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 12 | if qi1: 13 | self.qi1 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 14 | if qi2: 15 | self.qi2 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 16 | self.num_bits = num_bits 17 | self.signed = signed 18 | self.dim = dim 19 | 20 | M1 = torch.tensor([], requires_grad=False) 21 | self.register_buffer('M1', M1) 22 | M2 = torch.tensor([], requires_grad=False) 23 | self.register_buffer('M2', 
M2) 24 | 25 | def freeze(self, qi1=None, qi2=None, qo=None): 26 | 27 | if hasattr(self, 'qi1') and qi1 is not None: 28 | raise ValueError('qi1 has been provided in init function.') 29 | if not hasattr(self, 'qi1') and qi1 is None: 30 | raise ValueError('qi1 does not exist and should be provided.') 31 | 32 | if hasattr(self, 'qi2') and qi2 is not None: 33 | raise ValueError('qi2 has been provided in init function.') 34 | if not hasattr(self, 'qi2') and qi2 is None: 35 | raise ValueError('qi2 does not exist and should be provided.') 36 | 37 | if hasattr(self, 'qo') and qo is not None: 38 | raise ValueError('qo has been provided in init function.') 39 | if not hasattr(self, 'qo') and qo is None: 40 | raise ValueError('qo does not exist and should be provided.') 41 | self.freeze_flag = True 42 | 43 | if qi1 is not None: 44 | self.qi1 = qi1 45 | if qi2 is not None: 46 | self.qi2 = qi2 47 | if qo is not None: 48 | self.qo = qo 49 | self.M1 = self.qi1.scale / self.qo.scale 50 | self.M2 = self.qi2.scale / self.qo.scale 51 | return self.qo 52 | 53 | def forward(self, x1, x2): 54 | if hasattr(self, 'qi1'): 55 | self.qi1.update(x1) 56 | x1 = FakeQuantize.apply(x1, self.qi1) 57 | if hasattr(self, 'qi2'): 58 | self.qi2.update(x2) 59 | x2 = FakeQuantize.apply(x2, self.qi2) 60 | if self.freeze_flag: 61 | raise Exception(f'{self._get_name()} has been frozen') 62 | 63 | out = torch.cat((x1, x2), self.dim) 64 | 65 | if hasattr(self, 'qo'): 66 | self.qo.update(out) 67 | out = FakeQuantize.apply(out, self.qo) 68 | 69 | return out 70 | 71 | def quantize_inference(self, x1, x2, mode=None): 72 | x1 = x1 - self.qi1.zero_point 73 | x2 = x2 - self.qi2.zero_point 74 | if mode is None: 75 | x1 = self.M1 * x1 76 | x2 = self.M2 * x2 77 | out = torch.cat((x1, x2), dim=self.dim) 78 | out.round_() 79 | elif mode == 'cmsis_nn': 80 | out = ReScaleConcat.apply(x1, x2, self.M1, self.M2, self.dim) 81 | else: 82 | raise Exception(f'Unknown mode {mode}') 83 | out = out + self.qo.zero_point 84 | out.clamp_(self.qo.qmin, self.qo.qmax).round_() 85 | return out 86 | 87 | 88 | class ReScaleConcat(torch.autograd.Function): 89 | @staticmethod 90 | def symbolic(g, x1, x2, M1, M2, dim): 91 | return g.op("ReScale", x1, x2, M1, M2, dim) 92 | 93 | @staticmethod 94 | def forward(ctx, x1, x2, M1, M2, dim): 95 | multiplier1, shift1 = approximate_float(M1) 96 | round1 = 1 << (shift1 - 1) 97 | multiplier2, shift2 = approximate_float(M2) 98 | round2 = 1 << (shift2 - 1) 99 | 100 | x1 = (x1 * multiplier1 + round1) >> (31 - shift1) 101 | x2 = (x2 * multiplier2 + round2) >> (31 - shift2) 102 | out = torch.cat((x1, x2), dim=dim) 103 | return out 104 | --------------------------------------------------------------------------------
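All of the ReScale* helpers in these modules share the same integer "multiply, add a rounding constant, arithmetic right shift" requantization pattern. A small standalone sketch of that arithmetic (pure Python; the function name is illustrative, and the rounding convention here is the textbook one, which is only assumed to match approximate_float, whose source is not shown in this section):

import math

def requantize_sketch(x: int, M: float) -> int:
    # Approximate round(x * M) using only integer operations, with M encoded
    # frexp-style as multiplier * 2**(shift - 31).
    significand, shift = math.frexp(M)            # M = significand * 2**shift
    multiplier = round(significand * (1 << 31))   # int32 mantissa
    n = 31 - shift                                # total right shift
    return (x * multiplier + (1 << (n - 1))) >> n

# e.g. requantize_sketch(100, 0.25) == 25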
/torchquanter/nn/qconv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, QParamW, FakeQuantize 6 | from torchquanter.utils import quantize_tensor, broadcast_dim_as, approximate_float 7 | 8 | class QConv2d(QModule): 9 | 10 | def __init__(self, conv_module: nn.Conv2d, relu=False, qi=True, qo=True, num_bits=8, 11 | signed=True, symmetric_feature=False, symmetric_weight=True, qmode='per_channel'): 12 | super(QConv2d, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 13 | self.num_bits = num_bits 14 | self.signed = signed 15 | self.conv_module = conv_module 16 | self.relu = relu 17 | self.qw = QParamW(num_bits=num_bits, signed=signed, symmetric=symmetric_weight, qmode=qmode) 18 | 19 | def freeze(self, qi=None, qo=None): 20 | 21 | if hasattr(self, 'qi') and qi is not None: 22 | raise ValueError('qi has been provided in init function.') 23 | if not hasattr(self, 'qi') and qi is None: 24 | raise ValueError('qi does not exist and should be provided.') 25 | 26 | if hasattr(self, 'qo') and qo is not None: 27 | raise ValueError('qo has been provided in init function.') 28 | if not hasattr(self, 'qo') and qo is None: 29 | raise ValueError('qo does not exist and should be provided.') 30 | self.freeze_flag = True 31 | 32 | if qi is not None: 33 | self.qi = qi 34 | if qo is not None: 35 | self.qo = qo 36 | self.M = self.qw.scale * self.qi.scale / self.qo.scale 37 | 38 | self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data) 39 | self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point.view(-1,1,1,1) # after this subtraction the values may no longer be guaranteed to fit in 8 bits 40 | 41 | if self.conv_module.bias is not None: 42 | self.conv_module.bias.data = quantize_tensor(self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale, 43 | zero_point=0, num_bits=32, signed=True) 44 | return self.qo 45 | 46 | def forward(self, x): 47 | if hasattr(self, 'qi'): 48 | self.qi.update(x) 49 | x = FakeQuantize.apply(x, self.qi) 50 | if self.freeze_flag: 51 | raise Exception(f'{self._get_name()} has been frozen') 52 | 53 | self.qw.update(self.conv_module.weight.data) # track min/max and compute scale and zero_point 54 | 55 | # cannot simply call x = self.conv_module(x): assigning to conv_module.weight raises an error, 56 | # and modifying conv_module.weight.data would not backpropagate gradients properly (unverified) 57 | x = F.conv2d(x, FakeQuantize.apply(self.conv_module.weight, self.qw), self.conv_module.bias, 58 | stride=self.conv_module.stride, 59 | padding=self.conv_module.padding, dilation=self.conv_module.dilation, 60 | groups=self.conv_module.groups) 61 | if self.relu: 62 | x = F.relu(x) 63 | 64 | if hasattr(self, 'qo'): 65 | self.qo.update(x) 66 | x = FakeQuantize.apply(x, self.qo) 67 | 68 | return x 69 | 70 | def quantize_inference(self, x, mode=None): 71 | x = x - self.qi.zero_point 72 | x = self.conv_module(x) 73 | if mode is None: 74 | x = broadcast_dim_as(self.M, x, dim=1) * x 75 | x.round_() 76 | elif mode == 'cmsis_nn': 77 | x = ReScaleConv.apply(x, self.M) 78 | else: 79 | raise Exception(f'Unknown mode {mode}') 80 | x = x + self.qo.zero_point 81 | x.clamp_(0 if self.qo.symmetric and self.relu else self.qo.qmin, 82 | self.qo.qmax).round_() 83 | return x 84 | 85 | 86 | class ReScaleConv(torch.autograd.Function): 87 | @staticmethod 88 | def symbolic(g, x, M): 89 | return g.op("ReScale", x, M) 90 | 91 | @staticmethod 92 | def forward(ctx, x, M): 93 | multiplier, shift = approximate_float(M) 94 | round_ = 1 << (shift - 1) 95 | x = (x * broadcast_dim_as(multiplier, x, dim=1) + broadcast_dim_as(round_, x, dim=1)) \ 96 | >> (31 - broadcast_dim_as(shift, x, dim=1)) 97 | return x 98 | -------------------------------------------------------------------------------- /torchquanter/nn/qconvbnrelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from .base import QModule, QParamW, FakeQuantize 7 | from .qconv2d import ReScaleConv 8 | from torchquanter.utils import quantize_tensor, broadcast_dim_as, approximate_float 9 | 10 | class QConvBNReLU(QModule): 11 | 12 | def __init__(self, conv_module: nn.Conv2d, bn_module: nn.BatchNorm2d, relu=True, qi=True, qo=True, 13 | num_bits=8, signed=True, 
symmetric_feature=False, symmetric_weight=True, qmode='per_channel'): 14 | super(QConvBNReLU, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 15 | self.num_bits = num_bits 16 | self.signed = signed 17 | self.conv_module = conv_module 18 | self.bn_module = bn_module 19 | self.relu = relu 20 | self.qw = QParamW(num_bits=num_bits, signed=signed, symmetric=symmetric_weight, qmode=qmode) 21 | # self.qb = QParam(num_bits=32, signed=signed) 22 | 23 | def fold_bn(self, mean, std): 24 | if self.bn_module.affine: 25 | gamma_ = self.bn_module.weight / std 26 | weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1) 27 | if self.conv_module.bias is not None: 28 | bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias 29 | else: 30 | bias = self.bn_module.bias - gamma_ * mean 31 | else: 32 | gamma_ = 1 / std 33 | weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1) 34 | if self.conv_module.bias is not None: 35 | bias = gamma_ * self.conv_module.bias - gamma_ * mean 36 | else: 37 | bias = -gamma_ * mean 38 | 39 | return weight, bias 40 | 41 | def forward(self, x): 42 | 43 | if hasattr(self, 'qi'): 44 | self.qi.update(x) 45 | x = FakeQuantize.apply(x, self.qi) 46 | if self.freeze_flag: 47 | raise Exception(f'{self._get_name()} has been frozen') 48 | 49 | if self.training: 50 | y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias, 51 | stride=self.conv_module.stride, 52 | padding=self.conv_module.padding, 53 | dilation=self.conv_module.dilation, 54 | groups=self.conv_module.groups) 55 | y = y.permute(1, 0, 2, 3) # NCHW -> CNHW 56 | y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW 57 | mean = y.mean(1).detach() 58 | var = y.var(1).detach() 59 | self.bn_module.running_mean = \ 60 | (1 - self.bn_module.momentum) * self.bn_module.running_mean + \ 61 | self.bn_module.momentum * mean # PyTorch convention: momentum weights the new batch statistic 62 | self.bn_module.running_var = \ 63 | (1 - self.bn_module.momentum) * self.bn_module.running_var + \ 64 | self.bn_module.momentum * var 65 | else: 66 | mean = Variable(self.bn_module.running_mean) 67 | var = Variable(self.bn_module.running_var) 68 | 69 | std = torch.sqrt(var + self.bn_module.eps) 70 | 71 | weight, bias = self.fold_bn(mean, std) 72 | 73 | self.qw.update(weight.data) 74 | 75 | x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias, 76 | stride=self.conv_module.stride, 77 | padding=self.conv_module.padding, dilation=self.conv_module.dilation, 78 | groups=self.conv_module.groups) 79 | 80 | if self.relu: 81 | x = F.relu(x) 82 | 83 | if hasattr(self, 'qo'): 84 | self.qo.update(x) 85 | x = FakeQuantize.apply(x, self.qo) 86 | 87 | return x 88 | 89 | def freeze(self, qi=None, qo=None): 90 | if hasattr(self, 'qi') and qi is not None: 91 | raise ValueError('qi has been provided in init function.') 92 | if not hasattr(self, 'qi') and qi is None: 93 | raise ValueError('qi is not existed, should be provided.') 94 | 95 | if hasattr(self, 'qo') and qo is not None: 96 | raise ValueError('qo has been provided in init function.') 97 | if not hasattr(self, 'qo') and qo is None: 98 | raise ValueError('qo is not existed, should be provided.') 99 | self.freeze_flag = True 100 | 101 | if qi is not None: 102 | self.qi = qi 103 | if qo is not None: 104 | self.qo = qo 105 | self.M = self.qw.scale * self.qi.scale / self.qo.scale 106 | 107 | std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps) 108 | 109 | weight, bias = self.fold_bn(self.bn_module.running_mean, 
std) 110 | self.conv_module.weight.data = self.qw.quantize_tensor(weight.data) 111 | self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point.view(-1,1,1,1) 112 | 113 | if self.conv_module.bias is not None: 114 | self.conv_module.bias.data = quantize_tensor(bias, scale=self.qi.scale * self.qw.scale, 115 | zero_point=0, num_bits=32, signed=True) 116 | else: 117 | self.conv_module.bias = nn.Parameter(quantize_tensor(bias, scale=self.qi.scale * self.qw.scale, 118 | zero_point=0, num_bits=32, signed=True), requires_grad=True) 119 | return self.qo 120 | 121 | def quantize_inference(self, x, mode=None): 122 | x = x - self.qi.zero_point 123 | x = self.conv_module(x) 124 | if mode is None: 125 | x = broadcast_dim_as(self.M, x, dim=1) * x 126 | x.round_() 127 | elif mode == 'cmsis_nn': 128 | x = ReScaleConv.apply(x, self.M) 129 | else: 130 | raise Exception(f'Unknown mode {mode}') 131 | x = x + self.qo.zero_point 132 | x.clamp_(0 if self.qo.symmetric and self.relu else self.qo.qmin, 133 | self.qo.qmax).round_() 134 | return x 135 | -------------------------------------------------------------------------------- /torchquanter/nn/qdiv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, QParamIO, FloorSTE, QuantizeTensor, DequantizeTensor 6 | from torchquanter.utils import broadcast_dim_as, approximate_float 7 | 8 | class QDiv(QModule): 9 | # High error 10 | pass 11 | 12 | ''' 13 | def __init__(self, mul_const=None, qi1=True, qi2=True, qo=True, num_bits=8, signed=True): 14 | """ 15 | Args 16 | ---------- 17 | mul_const: if not None, x1 / x2 * mul_const 18 | """ 19 | super(QDiv, self).__init__(qi=False, qo=qo, num_bits=num_bits, signed=signed) 20 | if qi1: 21 | self.qi1 = QParamIO(num_bits=num_bits, signed=signed, symmetric=False) 22 | if qi2: 23 | self.qi2 = QParamIO(num_bits=num_bits, signed=signed, symmetric=False) 24 | self.mul_const = mul_const if mul_const is not None else 1. 
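# [Editor's hedged example, not part of the original qdiv.py] Every
# quantize_inference in this package ends with the same requantize step:
# multiply an int32 accumulator by a float factor M and round. In
# mode='cmsis_nn' the float multiply is replaced by a fixed-point multiplier
# plus shift obtained from torchquanter.utils.approximate_float. A minimal
# self-contained sketch of the idea (the exact multiplier/shift convention of
# approximate_float may differ from this single-shift demo):
#
#     def requantize_demo(x_int: int, M: float) -> int:
#         multiplier = round(M * (1 << 31))              # Q31 mantissa, assumes 0 < M < 1
#         return (x_int * multiplier + (1 << 30)) >> 31  # ~= round(M * x_int)
#
#     assert requantize_demo(1000, 0.1) == 100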
25 | self.num_bits = num_bits 26 | self.signed = signed 27 | self.first_time = True 28 | 29 | def freeze(self, qi1=None, qi2=None, qo=None): 30 | 31 | if hasattr(self, 'qi1') and qi1 is not None: 32 | raise ValueError('qi has been provided in init function.') 33 | if not hasattr(self, 'qi1') and qi1 is None: 34 | raise ValueError('qi is not existed, should be provided.') 35 | 36 | if hasattr(self, 'qi2') and qi2 is not None: 37 | raise ValueError('qi has been provided in init function.') 38 | if not hasattr(self, 'qi2') and qi2 is None: 39 | raise ValueError('qi is not existed, should be provided.') 40 | 41 | if hasattr(self, 'qo') and qo is not None: 42 | raise ValueError('qo has been provided in init function.') 43 | if not hasattr(self, 'qo') and qo is None: 44 | raise ValueError('qo is not existed, should be provided.') 45 | 46 | if qi1 is not None: 47 | self.qi1 = qi1 48 | if qi2 is not None: 49 | self.qi2 = qi2 50 | if qo is not None: 51 | self.qo = qo 52 | self.M = self.qi1.scale * self.mul_const / (self.qo.scale * self.qi2.scale) 53 | 54 | def forward(self, x1, x2, qi1=None, qi2=None): 55 | """ 56 | here need to get before_layer.qo as here.qi 57 | """ 58 | if not hasattr(self, 'qi1') and qi1 is None: 59 | raise ValueError('qi1 is not existed, should be provided.') 60 | if hasattr(self, 'qi1') and qi1 is not None: 61 | raise ValueError('qi1 has been provided in init function.') 62 | if not hasattr(self, 'qi2') and qi2 is None: 63 | raise ValueError('qi2 is not existed, should be provided.') 64 | if hasattr(self, 'qi2') and qi2 is not None: 65 | raise ValueError('qi2 has been provided in init function.') 66 | 67 | if hasattr(self, 'qi1'): 68 | qi1 = self.qi1 69 | qi1.update(x1) 70 | x1 = FakeQuantize.apply(x1, self.qi1) 71 | if hasattr(self, 'qi2'): 72 | qi2 = self.qi2 73 | qi2.update(x2) 74 | x1 = FakeQuantize.apply(x1, self.qi2) 75 | 76 | if self.first_time: 77 | out = torch.div(x1, x2) 78 | self.qo.update(out) 79 | self.first_time=False 80 | else: 81 | qx1 = QuantizeTensor.apply(x1, qi1) 82 | qx2 = QuantizeTensor.apply(x2, qi2) 83 | 84 | qx1 = qx1 - qi1.zero_point 85 | qx2 = qx2 - qi2.zero_point 86 | 87 | qx = FloorSTE.apply(qx1 / qx2) 88 | out = self.mul_const * qi1.scale * qx / (self.qo.scale * qi2.scale) 89 | 90 | if hasattr(self, 'qo'): 91 | self.qo.update(out) 92 | out = FakeQuantize.apply(out, self.qo) 93 | 94 | return out 95 | 96 | def quantize_inference(self, x1, x2, mode=None): 97 | x1 = x1 - self.qi1.zero_point 98 | x2 = x2 - self.qi2.zero_point 99 | if mode is None: 100 | out = torch.div(x1, x2).floor() 101 | out = out * self.M 102 | out.round_() 103 | elif mode == 'cmsis_nn': 104 | multiplier, shift = approximate_float(self.M) 105 | round_ = 1 << (shift - 1) 106 | out = torch.div(x1, x2).floor() 107 | out = (out * multiplier + round_) >> (31 - shift) 108 | else: 109 | raise Exception(f'Unknown mode {mode}') 110 | out = out + self.qo.zero_point 111 | out.clamp_(self.qo.qmin, self.qo.qmax).round_() 112 | return out 113 | ''' -------------------------------------------------------------------------------- /torchquanter/nn/qlayernorm.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .base import QModule, QParamW, FakeQuantize, FloorSTE, QuantizeTensor, DequantizeTensor, RoundSTE, ClampSTE 8 | from .qnorm import QNorm 9 | from .qmean import QMean 10 | from .qadd import QAdd 11 | from .qsub import QSub 12 | from 
.qmul import QMul 13 | from .qsqrt import QSqrt 14 | from .qdiv import QDiv 15 | from torchquanter.utils import quantize_tensor, broadcast_dim_as, approximate_float, sqrt_interger, get_qmin_qmax 16 | 17 | #### Deprecated! 18 | # class QLayerNorm(QModule): 19 | # """ 20 | # QNorm * weight + bias 21 | # """ 22 | 23 | # def __init__(self, layernorm_module: nn.LayerNorm, qi=True, qo=True, num_bits=8, max_bits=32, 24 | # signed=True, symmetric_weight=True): 25 | # qlayernorm_qo = qo if layernorm_module.elementwise_affine else False # if affine is False, just reuse QNorm's qo directly 26 | # super(QLayerNorm, self).__init__(qi=qi, qo=qlayernorm_qo, num_bits=num_bits, signed=signed) 27 | # self.num_bits = num_bits 28 | # self.max_bits = max_bits 29 | # self.signed = signed 30 | # self.layernorm_module = layernorm_module 31 | # self.qnorm = QNorm(qi=False, qo=True, num_bits=num_bits, max_bits=32, signed=signed) 32 | 33 | # if self.layernorm_module.elementwise_affine: 34 | # self.qw = QParamW(num_bits=num_bits, signed=signed, symmetric=symmetric_weight, qmode='per_tensor') 35 | 36 | # def freeze(self, qi=None, qo=None): 37 | 38 | # if hasattr(self, 'qi') and qi is not None: 39 | # raise ValueError('qi has been provided in init function.') 40 | # if not hasattr(self, 'qi') and qi is None: 41 | # raise ValueError('qi is not existed, should be provided.') 42 | 43 | # if hasattr(self, 'qo') and qo is not None: 44 | # raise ValueError('qo has been provided in init function.') 45 | 46 | # if qi is not None: 47 | # self.qi = qi 48 | # if qo is not None: 49 | # self.qo = qo 50 | 51 | # self.qnorm.freeze(qi=self.qi) 52 | 53 | # if self.layernorm_module.elementwise_affine: 54 | # self.M = self.qnorm.qo.scale * self.qw.scale / self.qo.scale # very special case: no self.qi.scale here, because standardizing the input cancels qi.scale completely, so it can no longer be factored out afterwards 55 | 56 | # self.layernorm_module.weight.data = self.qw.quantize_tensor(self.layernorm_module.weight.data) 57 | # self.layernorm_module.weight.data = self.layernorm_module.weight.data - self.qw.zero_point # after this subtraction the values may no longer fit in 8 bits 58 | 59 | # self.layernorm_module.bias.data = quantize_tensor(self.layernorm_module.bias.data, scale=self.qnorm.qo.scale * self.qw.scale, 60 | # zero_point=0, num_bits=32, signed=True) 61 | # else: 62 | # self.qo = self.qnorm.qo 63 | 64 | # def forward(self, x, qi=None): 65 | # """ 66 | # here need to get before_layer.qo as layernorm.qi 67 | # """ 68 | # if not hasattr(self, 'qi') and qi is None: 69 | # raise ValueError('qi is not existed, should be provided.') 70 | # if hasattr(self, 'qi') and qi is not None: # for test without before_layer.qo 71 | # raise ValueError('qi has been provided in init function.') 72 | # if hasattr(self, 'qi') and qi is None: # for test without before_layer.qo 73 | # qi = self.qi 74 | # qi.update(x) 75 | 76 | # x = self.qnorm(x, qi) 77 | 78 | # if self.layernorm_module.elementwise_affine: 79 | # self.qw.update(self.layernorm_module.weight.data) # track min/max and compute scale and zero_point 80 | # x = torch.mul(x, FakeQuantize.apply(self.layernorm_module.weight, self.qw)) + self.layernorm_module.bias 81 | 82 | # if hasattr(self, 'qo'): 83 | # self.qo.update(x) 84 | # x = FakeQuantize.apply(x, self.qo) 85 | 86 | # return x 87 | 88 | # def quantize_inference(self, x, mode=None): 89 | # x = self.qnorm.quantize_inference(x, mode=mode) 90 | 91 | # if self.layernorm_module.elementwise_affine: 92 | # x = x - self.qnorm.qo.zero_point 93 | 94 | # x = x * self.layernorm_module.weight.data + self.layernorm_module.bias.data 95 | 96 | # if mode is None: 97 | # x = self.M * x 98 | # x.round_() 99 | # elif 
mode == 'cmsis_nn': 100 | # multiplier, shift = approximate_float(self.M) 101 | # round_ = 1 << (shift - 1) 102 | # x = (x * multiplier + round_) >> (31 - shift) 103 | # else: 104 | # raise Exception(f'Unknown mode {mode}') 105 | # x = x + self.qo.zero_point 106 | # x.clamp_(self.qo.qmin, self.qo.qmax).round_() 107 | # return x 108 | 109 | 110 | class QLayerNorm(QModule): 111 | """ 112 | Quantize the whole LayerNorm as one fused op 113 | """ 114 | 115 | def __init__(self, layernorm_module: nn.LayerNorm, qi=True, qo=True, num_bits=8, max_bits=32, 116 | signed=True, symmetric_feature=False, symmetric_weight=True): 117 | super(QLayerNorm, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 118 | self.num_bits = num_bits 119 | self.max_bits = max_bits 120 | self.layernorm_module = layernorm_module 121 | self.signed = signed 122 | self.scale = 2**(8 - 1) 123 | self.first_time = True 124 | 125 | if self.layernorm_module.elementwise_affine: 126 | self.qw = QParamW(num_bits=num_bits, signed=signed, symmetric=symmetric_weight, qmode='per_tensor') 127 | 128 | def freeze(self, qi=None, qo=None): 129 | 130 | if hasattr(self, 'qi') and qi is not None: 131 | raise ValueError('qi has been provided in init function.') 132 | if not hasattr(self, 'qi') and qi is None: 133 | raise ValueError('qi is not existed, should be provided.') 134 | 135 | if hasattr(self, 'qo') and qo is not None: 136 | raise ValueError('qo has been provided in init function.') 137 | if not hasattr(self, 'qo') and qo is None: 138 | raise ValueError('qo is not existed, should be provided.') 139 | self.freeze_flag = True 140 | 141 | if qi is not None: 142 | self.qi = qi 143 | if qo is not None: 144 | self.qo = qo 145 | 146 | if self.layernorm_module.elementwise_affine: 147 | self.M = self.qw.scale / (self.qo.scale * self.scale) # very special case: no self.qi.scale here, because standardizing the input cancels qi.scale completely, so it can no longer be factored out afterwards 148 | self.layernorm_module.weight.data = self.qw.quantize_tensor(self.layernorm_module.weight.data) 149 | self.layernorm_module.weight.data = self.layernorm_module.weight.data - self.qw.zero_point # after this subtraction the values may no longer fit in 8 bits 150 | 151 | self.layernorm_module.bias.data = quantize_tensor(self.layernorm_module.bias.data, scale=self.qw.scale / self.scale, 152 | zero_point=0, num_bits=32, signed=True) 153 | else: 154 | self.M = 1 / (self.qo.scale * self.scale) # very special case: no self.qi.scale here, because standardizing the input cancels qi.scale completely, so it can no longer be factored out afterwards 155 | return self.qo 156 | 157 | def forward(self, x, qi=None): 158 | """ 159 | here need to get before_layer.qo as norm.qi 160 | """ 161 | if not hasattr(self, 'qi') and qi is None: 162 | raise ValueError('qi is not existed, should be provided.') 163 | if hasattr(self, 'qi') and qi is not None: # for test without before_layer.qo 164 | raise ValueError('qi has been provided in init function.') 165 | if hasattr(self, 'qi') and qi is None: # for test without before_layer.qo 166 | qi = self.qi 167 | qi.update(x) 168 | if self.freeze_flag: 169 | raise Exception(f'{self._get_name()} has been frozen') 170 | 171 | if self.qo.scale.numel() == 0: 172 | x = F.layer_norm(x, self.layernorm_module.normalized_shape, 173 | self.layernorm_module.weight, self.layernorm_module.bias, 174 | self.layernorm_module.eps) 175 | else: 176 | qx = QuantizeTensor.apply(x, qi) 177 | qx = qx - qi.zero_point 178 | 179 | # Integer-only Norm 180 | mean_ = qx.mean(dim=-1, keepdim=True).clamp(*get_qmin_qmax(16, signed=True)) 181 | sum_ = ClampSTE.apply(torch.sum((qx - mean_)**2, dim=-1, keepdim=True), 182 | *get_qmin_qmax(self.max_bits, signed=True)) # int32; clipping the overflow directly here may be worse than shifting the range, but keep it simple for now
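# [Editor's hedged note] FloorSTE and ClampSTE above are straight-through
# estimators: the forward pass applies the non-differentiable integer op,
# while the backward pass lets the gradient through unchanged so QAT can
# still train through it. A minimal sketch of the pattern (the real classes
# live in .base and may differ in detail):
#
#     class FloorSTEDemo(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return torch.floor(x)
#         @staticmethod
#         def backward(ctx, grad_output):
#             return grad_output  # identity gradient, hence "straight-through"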
183 | var_ = FloorSTE.apply(sum_ / qx.shape[-1]) 184 | var_[var_ == 0.] = 1. # prevent overflow 185 | std_ = FloorSTE.apply(torch.sqrt(var_)) 186 | factor = FloorSTE.apply(2**(8 - 1) / std_) 187 | qx = FloorSTE.apply(ClampSTE.apply((qx - mean_) * factor, *get_qmin_qmax(16, signed=True))) 188 | x = qx / 2**(8 - 1) # no floor needed: this division is folded into M 189 | 190 | if self.layernorm_module.elementwise_affine: 191 | self.qw.update(self.layernorm_module.weight.data) # track min/max and compute scale and zero_point 192 | x = torch.mul(x, FakeQuantize.apply(self.layernorm_module.weight, self.qw)) + self.layernorm_module.bias 193 | x = x.clamp(*get_qmin_qmax(self.max_bits, signed=True)) 194 | 195 | self.qo.update(x) 196 | x = FakeQuantize.apply(x, self.qo) 197 | return x 198 | 199 | def quantize_inference(self, x, mode=None): 200 | x = x - self.qi.zero_point 201 | 202 | if self.layernorm_module.elementwise_affine: 203 | x = QLayerNormAffine.apply( 204 | x, 205 | self.layernorm_module.weight, 206 | self.layernorm_module.bias, 207 | self.layernorm_module.eps, 208 | list(range(-1, -len(self.layernorm_module.normalized_shape) - 1, -1)), 209 | self.scale, 210 | self.max_bits 211 | ) 212 | else: 213 | x = QLayerNormNoAffine.apply( 214 | x, 215 | self.layernorm_module.eps, 216 | list(range(-1, -len(self.layernorm_module.normalized_shape) - 1, -1)), 217 | self.scale, 218 | self.max_bits 219 | ) 220 | 221 | if mode is None: 222 | x = self.M * x 223 | x.round_() 224 | elif mode == 'cmsis_nn': 225 | x = ReScaleLayerNorm.apply(x, self.M) 226 | else: 227 | raise Exception(f'Unknown mode {mode}') 228 | x = x + self.qo.zero_point 229 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 230 | return x 231 | 232 | 233 | class ReScaleLayerNorm(torch.autograd.Function): 234 | @staticmethod 235 | def symbolic(g, x, M): 236 | return g.op("ReScale", x, M) 237 | 238 | @staticmethod 239 | def forward(ctx, x, M): 240 | multiplier, shift = approximate_float(M) 241 | round_ = 1 << (shift - 1) 242 | x = (x * multiplier + round_) >> (31 - shift) 243 | return x 244 | 245 | 246 | class QLayerNormAffine(torch.autograd.Function): 247 | @staticmethod 248 | def symbolic(g, x, weight, bias, eps, axis, scale, max_bits): 249 | return g.op( 250 | "QLayerNorm", 251 | x, 252 | weight, 253 | bias, 254 | epsilon_f=eps, 255 | axis_i=axis, 256 | scale_i=scale, 257 | max_bits_i=max_bits 258 | ) 259 | 260 | @staticmethod 261 | def forward(ctx, x, weight, bias, eps, axis, scale, max_bits): 262 | # Integer-only LayerNorm 263 | mean_ = x.mean(dim=-1, keepdim=True).clamp(*get_qmin_qmax(16, signed=True)) # int16 264 | mean_ = mean_.round() 265 | sum_ = torch.sum((x - mean_)**2, dim=-1, keepdim=True).clamp(*get_qmin_qmax(max_bits, signed=True)) # clamp into the 32-bit range 266 | var_ = torch.floor(sum_ / x.shape[-1]) 267 | var_[var_ == 0.] = 1. # prevent overflow
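# [Editor's hedged worked example] Numbers for the integer-only normalization
# above, assuming the default scale = 2**(8 - 1) = 128:
#     x_int = [-20, 4, 16] -> mean_ = 0, sum_ = 672, var_ = floor(672 / 3) = 224
#     std_ = floor(sqrt(224)) = 14, factor = floor(128 / 14) = 9
#     x_hat = (x_int - 0) * 9 = [-180, 36, 144], i.e. roughly 128 * (x - mean) / std
# The leftover 1/128 is not applied here; it is folded into
# M = qw.scale / (qo.scale * scale) inside QLayerNorm.freeze().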
268 | # std_ = sqrt_interger(var_) # exact integer sqrt (sqrt_interger) is slow; float sqrt is fine for this quick evaluation 269 | std_ = torch.sqrt(var_).floor() 270 | factor = torch.floor(scale / std_) 271 | x = torch.floor(torch.clamp((x - mean_) * factor, *get_qmin_qmax(16, signed=True))) 272 | 273 | x = x * weight.data + bias.data 274 | x = x.clamp(*get_qmin_qmax(max_bits, signed=True)) 275 | return x 276 | 277 | 278 | class QLayerNormNoAffine(torch.autograd.Function): 279 | @staticmethod 280 | def symbolic(g, x, eps, axis, scale, max_bits): 281 | return g.op( 282 | "QLayerNorm", 283 | x, 284 | epsilon_f=eps, 285 | axis_i=axis, 286 | scale_i=scale, 287 | max_bits_i=max_bits 288 | ) 289 | 290 | @staticmethod 291 | def forward(ctx, x, eps, axis, scale, max_bits): 292 | # Integer-only LayerNorm 293 | mean_ = x.mean(dim=-1, keepdim=True).clamp(*get_qmin_qmax(16, signed=True)) # int16 294 | mean_ = mean_.round() 295 | sum_ = torch.sum((x - mean_)**2, dim=-1, keepdim=True).clamp(*get_qmin_qmax(max_bits, signed=True)) # clamp into the 32-bit range 296 | var_ = torch.floor(sum_ / x.shape[-1]) 297 | var_[var_ == 0.] = 1. # prevent overflow 298 | # std_ = sqrt_interger(var_) # exact integer sqrt (sqrt_interger) is slow; float sqrt is fine for this quick evaluation 299 | std_ = torch.sqrt(var_).floor() 300 | factor = torch.floor(scale / std_) 301 | x = torch.floor(torch.clamp((x - mean_) * factor, *get_qmin_qmax(max_bits, signed=True))) 302 | return x 303 | 304 | 305 | 306 | class QLayerNormFP32(QModule): 307 | """ 308 | Compute LayerNorm in FP32 309 | """ 310 | 311 | def __init__(self, layernorm_module: nn.LayerNorm, qi=True, qo=True, num_bits=8, max_bits=32, 312 | signed=True, symmetric_weight=True): 313 | super(QLayerNormFP32, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed) 314 | self.num_bits = num_bits 315 | self.max_bits = max_bits 316 | self.layernorm_module = layernorm_module 317 | self.signed = signed 318 | 319 | def freeze(self, qi=None, qo=None): 320 | 321 | if hasattr(self, 'qi') and qi is not None: 322 | raise ValueError('qi has been provided in init function.') 323 | if not hasattr(self, 'qi') and qi is None: 324 | raise ValueError('qi is not existed, should be provided.') 325 | 326 | if hasattr(self, 'qo') and qo is not None: 327 | raise ValueError('qo has been provided in init function.') 328 | if not hasattr(self, 'qo') and qo is None: 329 | raise ValueError('qo is not existed, should be provided.') 330 | 331 | if qi is not None: 332 | self.qi = qi 333 | if qo is not None: 334 | self.qo = qo 335 | 336 | def forward(self, x, qi=None): 337 | """ 338 | qi=None keeps the signature consistent with the other layers so the code can be reused 339 | """ 340 | if hasattr(self, 'qi'): 341 | self.qi.update(x) 342 | x = FakeQuantize.apply(x, self.qi) 343 | 344 | x = self.layernorm_module(x) 345 | 346 | if hasattr(self, 'qo'): 347 | self.qo.update(x) 348 | x = FakeQuantize.apply(x, self.qo) 349 | return x 350 | 351 | def quantize_inference(self, x, mode=None): 352 | x = self.qi.dequantize_tensor(x) 353 | x = self.layernorm_module(x) 354 | x = self.qo.quantize_tensor(x) 355 | return x 356 | 357 | 358 | class QLayerNormTFLite(QModule): 359 | pass 360 | ''' 361 | # High error 362 | def __init__(self, layernorm_module: nn.LayerNorm, qi=True, qo=True, num_bits=8, max_bits=32, 363 | signed=True, symmetric_weight=True): 364 | qlayernorm_qo = qo if layernorm_module.elementwise_affine else False # if affine is False, just reuse QMul's qo directly 365 | super(QLayerNormTFLite, self).__init__(qi=qi, qo=qlayernorm_qo, num_bits=num_bits, signed=signed) 366 | self.num_bits = num_bits 367 | self.max_bits = max_bits 368 | self.signed = signed 369 | self.layernorm_module = layernorm_module 370 | 371 | # numerator
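# [Editor's hedged note] This disabled variant decomposes LayerNorm into
# separately quantized primitives, mirroring the float graph
#     y = (x - mean(x)) / sqrt(mean((x - mean(x))**2))
# with QMean/QSub for the numerator, QMul/QMean/QSqrt for the denominator,
# and QDiv at the end. Each stage carries its own (scale, zero_point), so
# quantization error compounds stage by stage, which is presumably why the
# class is kept disabled and marked "High error".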
372 | self.mean = QMean(dim=-1, keepdim=True, qi=False, qo=True, num_bits=num_bits, signed=signed) 373 | self.sub = QSub(qi1=False, qi2=False, qo=True, num_bits=num_bits, signed=signed) 374 | 375 | # denominator 376 | self.var_mul = QMul(qi1=False, qi2=False, qo=True, num_bits=num_bits, signed=signed) 377 | self.var_mean = QMean(dim=-1, keepdim=True, qi=False, qo=True, num_bits=num_bits, signed=signed) 378 | self.var_sqrt = QSqrt(qi=False, qo=True, num_bits=num_bits, signed=signed) 379 | 380 | # numerator / denominator 381 | self.div = QDiv(qi1=False, qi2=False, qo=True, num_bits=num_bits, signed=signed) 382 | 383 | if self.layernorm_module.elementwise_affine: 384 | self.qw = QParamW(num_bits=num_bits, signed=signed, symmetric=symmetric_weight, qmode='per_tensor') 385 | 386 | def freeze(self, qi=None, qo=None): 387 | 388 | if hasattr(self, 'qi') and qi is not None: 389 | raise ValueError('qi has been provided in init function.') 390 | if not hasattr(self, 'qi') and qi is None: 391 | raise ValueError('qi is not existed, should be provided.') 392 | 393 | if hasattr(self,'qo') and qo is not None: 394 | raise ValueError('qo has been provided in init function.') 395 | 396 | if qi is not None: 397 | self.qi = qi 398 | if qo is not None: 399 | self.qo = qo 400 | 401 | # numerator 402 | self.mean.freeze(qi=self.qi) 403 | self.sub.freeze(qi1=self.qi, qi2=self.mean.qo) 404 | 405 | # denominator 406 | self.var_mul.freeze(qi1=self.sub.qo, qi2=self.sub.qo) 407 | self.var_mean.freeze(qi=self.var_mul.qo) 408 | self.var_sqrt.freeze(qi=self.var_mean.qo) 409 | 410 | # numerator / denominator 411 | self.div.freeze(qi1=self.sub.qo, qi2=self.var_sqrt.qo) 412 | 413 | if self.layernorm_module.elementwise_affine: 414 | self.M = self.div.qo.scale * self.qw.scale / self.qo.scale 415 | 416 | self.layernorm_module.weight.data = self.qw.quantize_tensor(self.layernorm_module.weight.data) 417 | self.layernorm_module.weight.data = self.layernorm_module.weight.data - self.qw.zero_point # after this subtraction the values may no longer fit in 8 bits 418 | 419 | self.layernorm_module.bias.data = quantize_tensor(self.layernorm_module.bias.data, scale=self.div.qo.scale * self.qw.scale, 420 | zero_point=0, num_bits=32, signed=True) 421 | else: 422 | self.qo = self.div.qo 423 | 424 | def forward(self, x): 425 | """ 426 | here need to get before_layer.qo as layernorm.qi 427 | """ 428 | if hasattr(self, 'qi'): 429 | self.qi.update(x) 430 | x = FakeQuantize.apply(x, self.qi) 431 | 432 | # numerator 433 | mean_ = self.mean(x) 434 | sub_ = self.sub(x, mean_) 435 | 436 | # denominator 437 | var_mul_ = self.var_mul(sub_, sub_) 438 | var_mean_ = self.var_mean(var_mul_) 439 | var_sqrt_ = self.var_sqrt(var_mean_, qi=self.var_mean.qo) 440 | 441 | # numerator / denominator 442 | div_ = self.div(sub_, var_sqrt_, qi1=self.sub.qo, qi2=self.var_sqrt.qo) 443 | x = div_ 444 | 445 | if self.layernorm_module.elementwise_affine: 446 | self.qw.update(self.layernorm_module.weight.data) # track min/max and compute scale and zero_point 447 | x = torch.mul(x, FakeQuantize.apply(self.layernorm_module.weight, self.qw)) + self.layernorm_module.bias 448 | 449 | if hasattr(self, 'qo'): 450 | self.qo.update(x) 451 | x = FakeQuantize.apply(x, self.qo) 452 | 453 | return x 454 | 455 | def quantize_inference(self, x, mode=None): 456 | x = self.qi.dequantize_tensor(x) # int8 -> float32 457 | 458 | # numerator 459 | qmean = self.mean.quantize_inference(x, mode=mode) 460 | qsub = self.sub.quantize_inference(x, qmean, mode=mode) 461 | 462 | # denominator 463 | qvar_mul = self.var_mul.quantize_inference(qsub, qsub, mode=mode) 464 
| qvar_mean = self.var_mean.quantize_inference(qvar_mul, mode=mode) 465 | qvar_sqrt = self.var_sqrt.quantize_inference(qvar_mean, mode=mode) 466 | 467 | # numerator / denominator 468 | qdiv = self.div.quantize_inference(qsub, qvar_sqrt, mode=mode) 469 | x = qdiv 470 | 471 | if self.layernorm_module.elementwise_affine: 472 | x = x - self.div.qo.zero_point 473 | 474 | x = x * self.layernorm_module.weight.data + self.layernorm_module.bias.data 475 | 476 | if mode is None: 477 | x = self.M * x 478 | x.round_() 479 | elif mode == 'cmsis_nn': 480 | multiplier, shift = approximate_float(self.M) 481 | round_ = 1 << (shift - 1) 482 | x = (x * multiplier + round_) >> (31 - shift) 483 | else: 484 | raise Exception(f'Unknown mode {mode}') 485 | x = x + self.qo.zero_point 486 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 487 | return x 488 | ''' -------------------------------------------------------------------------------- /torchquanter/nn/qlinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, QParamW, FakeQuantize 6 | from torchquanter.utils import quantize_tensor, approximate_float 7 | 8 | class QLinear(QModule): 9 | 10 | def __init__(self, fc_module: nn.Linear, relu=False, qi=True, qo=True, num_bits=8, 11 | signed=True, symmetric_feature=False, symmetric_weight=True): 12 | super(QLinear, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 13 | self.num_bits = num_bits 14 | self.signed = signed 15 | self.fc_module = fc_module 16 | self.relu = relu 17 | self.qw = QParamW(num_bits=num_bits, signed=signed, symmetric=symmetric_weight, qmode='per_tensor') 18 | 19 | def freeze(self, qi=None, qo=None): 20 | 21 | if hasattr(self, 'qi') and qi is not None: 22 | raise ValueError('qi has been provided in init function.') 23 | if not hasattr(self, 'qi') and qi is None: 24 | raise ValueError('qi is not existed, should be provided.') 25 | 26 | if hasattr(self, 'qo') and qo is not None: 27 | raise ValueError('qo has been provided in init function.') 28 | if not hasattr(self, 'qo') and qo is None: 29 | raise ValueError('qo is not existed, should be provided.') 30 | self.freeze_flag = True 31 | 32 | if qi is not None: 33 | self.qi = qi 34 | if qo is not None: 35 | self.qo = qo 36 | self.M = self.qw.scale * self.qi.scale / self.qo.scale 37 | 38 | self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data) 39 | self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point.view(-1,1) 40 | 41 | if self.fc_module.bias is not None: 42 | self.fc_module.bias.data = quantize_tensor(self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale, 43 | zero_point=0, num_bits=32, signed=True) 44 | return self.qo 45 | 46 | def forward(self, x): 47 | if hasattr(self, 'qi'): 48 | self.qi.update(x) 49 | x = FakeQuantize.apply(x, self.qi) 50 | if self.freeze_flag: 51 | raise Exception(f'{self._get_name()} has been frozen') 52 | 53 | self.qw.update(self.fc_module.weight.data) 54 | 55 | x = F.linear(x, FakeQuantize.apply(self.fc_module.weight, self.qw), self.fc_module.bias) 56 | if self.relu: 57 | x = F.relu(x) 58 | 59 | if hasattr(self, 'qo'): 60 | self.qo.update(x) 61 | x = FakeQuantize.apply(x, self.qo) 62 | 63 | return x 64 | 65 | def quantize_inference(self, x, mode=None): 66 | x = x - self.qi.zero_point 67 | x = self.fc_module(x) 68 | if mode is None: 69 | x = self.M * x 70 | x.round_() 71 | elif mode == 
'cmsis_nn': 72 | x = ReScaleLinear.apply(x, self.M) 73 | else: 74 | raise Exception(f'Unknown mode {mode}') 75 | x = x + self.qo.zero_point 76 | 77 | x.clamp_(0 if self.qo.symmetric and self.relu else self.qo.qmin, 78 | self.qo.qmax).round_() 79 | return x 80 | 81 | 82 | class ReScaleLinear(torch.autograd.Function): 83 | @staticmethod 84 | def symbolic(g, x, M): 85 | return g.op("ReScale", x, M) 86 | 87 | @staticmethod 88 | def forward(ctx, x, M): 89 | multiplier, shift = approximate_float(M) 90 | round_ = 1 << (shift - 1) 91 | x = (x * multiplier + round_) >> (31 - shift) 92 | return x 93 | -------------------------------------------------------------------------------- /torchquanter/nn/qmatmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, QParamIO 6 | from torchquanter.utils import broadcast_dim_as, approximate_float 7 | 8 | class QMatmul(QModule): 9 | """ 10 | Dot product function 11 | """ 12 | 13 | def __init__(self, qi1=True, qi2=True, qo=True, mul_const=None, num_bits=8, signed=True, symmetric_feature=False): 14 | """ 15 | Args 16 | ---------- 17 | mul_const: if not None, torch.matmul(x1, x2) * mul_const 18 | """ 19 | super(QMatmul, self).__init__(qi=False, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 20 | if qi1: 21 | self.qi1 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 22 | if qi2: 23 | self.qi2 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 24 | self.mul_const = mul_const if mul_const is not None else 1. 25 | self.num_bits = num_bits 26 | self.signed = signed 27 | 28 | def freeze(self, qi1=None, qi2=None, qo=None): 29 | 30 | if hasattr(self, 'qi1') and qi1 is not None: 31 | raise ValueError('qi has been provided in init function.') 32 | if not hasattr(self, 'qi1') and qi1 is None: 33 | raise ValueError('qi is not existed, should be provided.') 34 | 35 | if hasattr(self, 'qi2') and qi2 is not None: 36 | raise ValueError('qi has been provided in init function.') 37 | if not hasattr(self, 'qi2') and qi2 is None: 38 | raise ValueError('qi is not existed, should be provided.') 39 | 40 | if hasattr(self, 'qo') and qo is not None: 41 | raise ValueError('qo has been provided in init function.') 42 | if not hasattr(self, 'qo') and qo is None: 43 | raise ValueError('qo is not existed, should be provided.') 44 | self.freeze_flag = True 45 | 46 | if qi1 is not None: 47 | self.qi1 = qi1 48 | if qi2 is not None: 49 | self.qi2 = qi2 50 | if qo is not None: 51 | self.qo = qo 52 | self.M = self.qi1.scale * self.qi2.scale * self.mul_const / self.qo.scale 53 | return self.qo 54 | 55 | def forward(self, x1, x2): 56 | if hasattr(self, 'qi1'): 57 | self.qi1.update(x1) 58 | x1 = FakeQuantize.apply(x1, self.qi1) 59 | if hasattr(self, 'qi2'): 60 | self.qi2.update(x2) 61 | x2 = FakeQuantize.apply(x2, self.qi2) 62 | if self.freeze_flag: 63 | raise Exception(f'{self._get_name()} has been frozen') 64 | 65 | out = torch.matmul(x1, x2) * self.mul_const 66 | 67 | if hasattr(self, 'qo'): 68 | self.qo.update(out) 69 | out = FakeQuantize.apply(out, self.qo) 70 | 71 | return out 72 | 73 | def quantize_inference(self, x1, x2, mode=None): 74 | x1 = x1 - self.qi1.zero_point 75 | x2 = x2 - self.qi2.zero_point 76 | if mode is None: 77 | out = torch.matmul(x1, x2) 78 | out = out * self.M 79 | out.round_() 80 | elif mode == 'cmsis_nn': 81 | 
out = ReScaleMatMul.apply(x1, x2, self.M) 82 | else: 83 | raise Exception(f'Unknown mode {mode}') 84 | out = out + self.qo.zero_point 85 | out.clamp_(self.qo.qmin, self.qo.qmax).round_() 86 | return out 87 | 88 | 89 | class ReScaleMatMul(torch.autograd.Function): 90 | @staticmethod 91 | def symbolic(g, x1, x2, M): 92 | return g.op("ReScale", x1, x2, M) 93 | 94 | @staticmethod 95 | def forward(ctx, x1, x2, M): 96 | multiplier, shift = approximate_float(M) 97 | round_ = 1 << (shift - 1) 98 | out = torch.matmul(x1, x2) 99 | out = (out * multiplier + round_) >> (31 - shift) 100 | return out 101 | -------------------------------------------------------------------------------- /torchquanter/nn/qmaxpool2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, QParam, FakeQuantize 6 | from torchquanter.utils import quantize_tensor 7 | 8 | class QMaxPool2d(QModule): 9 | 10 | def __init__(self, maxpool2d_module: nn.MaxPool2d, qi=False, num_bits=8, signed=True, symmetric_feature=False): 11 | super().__init__(qi=qi, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 12 | self.maxpool2d_module = maxpool2d_module 13 | 14 | def freeze(self, qi=None): 15 | if hasattr(self, 'qi') and qi is not None: 16 | raise ValueError('qi has been provided in init function.') 17 | if not hasattr(self, 'qi') and qi is None: 18 | raise ValueError('qi is not existed, should be provided.') 19 | self.freeze_flag = True 20 | 21 | if qi is not None: 22 | self.qi = qi 23 | return self.qi 24 | 25 | def forward(self, x): 26 | if hasattr(self, 'qi'): 27 | self.qi.update(x) 28 | x = FakeQuantize.apply(x, self.qi) 29 | if self.freeze_flag: 30 | raise Exception(f'{self._get_name()} has been frozen') 31 | 32 | x = F.max_pool2d(x, self.maxpool2d_module.kernel_size, 33 | self.maxpool2d_module.stride, self.maxpool2d_module.padding) 34 | 35 | return x 36 | 37 | def quantize_inference(self, x, **kwargs): 38 | return F.max_pool2d(x, self.maxpool2d_module.kernel_size, 39 | self.maxpool2d_module.stride, self.maxpool2d_module.padding) -------------------------------------------------------------------------------- /torchquanter/nn/qmean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .base import QModule, QParam, FakeQuantize, FloorSTE 5 | from torchquanter.utils import quantize_tensor, approximate_float 6 | 7 | class QMean(QModule): 8 | 9 | def __init__(self, dim=-1, keepdim=False, qi=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 10 | super(QMean, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 11 | self.dim = dim 12 | self.keepdim = keepdim 13 | self.num_bits = num_bits 14 | self.signed = signed 15 | 16 | def freeze(self, qi=None, qo=None): 17 | 18 | if hasattr(self, 'qi') and qi is not None: 19 | raise ValueError('qi has been provided in init function.') 20 | if not hasattr(self, 'qi') and qi is None: 21 | raise ValueError('qi is not existed, should be provided.') 22 | 23 | if hasattr(self, 'qo') and qo is not None: 24 | raise ValueError('qo has been provided in init function.') 25 | if not hasattr(self, 'qo') and qo is None: 26 | raise ValueError('qo is not existed, should be provided.') 27 | self.freeze_flag = True 28 | 29 | if qi is not None: 30 | self.qi = qi 31 | if qo is not None: 32 | self.qo = qo 33 | self.M = 
self.qi.scale / self.qo.scale 34 | return self.qo 35 | 36 | def forward(self, x): 37 | if hasattr(self, 'qi'): 38 | self.qi.update(x) 39 | x = FakeQuantize.apply(x, self.qi) 40 | if self.freeze_flag: 41 | raise Exception(f'{self._get_name()} has been frozen') 42 | 43 | x = torch.mean(x, dim=self.dim, keepdim=self.keepdim) 44 | 45 | if hasattr(self, 'qo'): 46 | self.qo.update(x) 47 | x = FakeQuantize.apply(x, self.qo) 48 | 49 | return x 50 | 51 | def quantize_inference(self, x, mode=None): 52 | x = x - self.qi.zero_point 53 | x = torch.mean(x, dim=self.dim, keepdim=self.keepdim).floor() 54 | if mode is None: 55 | x = self.M * x 56 | x.round_() 57 | elif mode == 'cmsis_nn': 58 | x = ReScaleMean.apply(x, self.M) 59 | else: 60 | raise Exception(f'Unknown mode {mode}') 61 | x = x + self.qo.zero_point 62 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 63 | return x 64 | 65 | 66 | class ReScaleMean(torch.autograd.Function): 67 | @staticmethod 68 | def symbolic(g, x, M): 69 | return g.op("ReScale", x, M) 70 | 71 | @staticmethod 72 | def forward(ctx, x, M): 73 | multiplier, shift = approximate_float(M) 74 | round_ = 1 << (shift - 1) 75 | x = (x * multiplier + round_) >> (31 - shift) 76 | return x 77 | -------------------------------------------------------------------------------- /torchquanter/nn/qmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, QParamIO 6 | from torchquanter.utils import broadcast_dim_as, approximate_float 7 | 8 | class QMul(QModule): 9 | """ 10 | Element-wise product function 11 | """ 12 | 13 | def __init__(self, mul_const=None, qi1=True, qi2=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 14 | """ 15 | Args 16 | ---------- 17 | mul_const: if not None, x1 * x2 * mul_const 18 | """ 19 | super(QMul, self).__init__(qi=False, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 20 | if qi1: 21 | self.qi1 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 22 | if qi2: 23 | self.qi2 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 24 | self.mul_const = mul_const if mul_const is not None else 1. 
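# [Editor's hedged note] Why freeze() below sets
# M = qi1.scale * qi2.scale * mul_const / qo.scale: with
#     x1 = S1 * (q1 - Z1) and x2 = S2 * (q2 - Z2),
#     x1 * x2 * mul_const = S1 * S2 * mul_const * (q1 - Z1) * (q2 - Z2)
# so the integer product only needs a single requantization factor M to land
# in the output scale So. For example, S1 = S2 = 0.02, mul_const = 1 and
# So = 0.004 give M = 0.02 * 0.02 / 0.004 = 0.1.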
25 | self.num_bits = num_bits 26 | self.signed = signed 27 | 28 | def freeze(self, qi1=None, qi2=None, qo=None): 29 | 30 | if hasattr(self, 'qi1') and qi1 is not None: 31 | raise ValueError('qi has been provided in init function.') 32 | if not hasattr(self, 'qi1') and qi1 is None: 33 | raise ValueError('qi is not existed, should be provided.') 34 | 35 | if hasattr(self, 'qi2') and qi2 is not None: 36 | raise ValueError('qi has been provided in init function.') 37 | if not hasattr(self, 'qi2') and qi2 is None: 38 | raise ValueError('qi is not existed, should be provided.') 39 | 40 | if hasattr(self, 'qo') and qo is not None: 41 | raise ValueError('qo has been provided in init function.') 42 | if not hasattr(self, 'qo') and qo is None: 43 | raise ValueError('qo is not existed, should be provided.') 44 | self.freeze_flag = True 45 | 46 | if qi1 is not None: 47 | self.qi1 = qi1 48 | if qi2 is not None: 49 | self.qi2 = qi2 50 | if qo is not None: 51 | self.qo = qo 52 | self.M = self.qi1.scale * self.qi2.scale * self.mul_const / self.qo.scale 53 | return self.qo 54 | 55 | def forward(self, x1, x2): 56 | if hasattr(self, 'qi1'): 57 | self.qi1.update(x1) 58 | x1 = FakeQuantize.apply(x1, self.qi1) 59 | if hasattr(self, 'qi2'): 60 | self.qi2.update(x2) 61 | x2 = FakeQuantize.apply(x2, self.qi2) 62 | if self.freeze_flag: 63 | raise Exception(f'{self._get_name()} has been frozen') 64 | 65 | out = torch.mul(x1, x2) * self.mul_const 66 | 67 | if hasattr(self, 'qo'): 68 | self.qo.update(out) 69 | out = FakeQuantize.apply(out, self.qo) 70 | 71 | return out 72 | 73 | def quantize_inference(self, x1, x2, mode=None): 74 | x1 = x1 - self.qi1.zero_point 75 | x2 = x2 - self.qi2.zero_point 76 | if mode is None: 77 | out = torch.mul(x1, x2) 78 | out = out * self.M 79 | out.round_() 80 | elif mode == 'cmsis_nn': 81 | out = ReScaleMul.apply(x1, x2, self.M) 82 | else: 83 | raise Exception(f'Unknown mode {mode}') 84 | out = out + self.qo.zero_point 85 | out.clamp_(self.qo.qmin, self.qo.qmax).round_() 86 | return out 87 | 88 | 89 | class ReScaleMul(torch.autograd.Function): 90 | @staticmethod 91 | def symbolic(g, x1, x2, M): 92 | return g.op("ReScale", x1, x2, M) 93 | 94 | @staticmethod 95 | def forward(ctx, x1, x2, M): 96 | multiplier, shift = approximate_float(M) 97 | round_ = 1 << (shift - 1) 98 | out = torch.mul(x1, x2) 99 | out = (out * multiplier + round_) >> (31 - shift) 100 | return out 101 | -------------------------------------------------------------------------------- /torchquanter/nn/qnorm.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .base import QModule, QParamW, FakeQuantize, FloorSTE, QuantizeTensor, DequantizeTensor, RoundSTE, ClampSTE 7 | from torchquanter.utils import quantize_tensor, broadcast_dim_as, approximate_float, sqrt_interger, get_qmin_qmax 8 | 9 | class QNorm(QModule): 10 | """ 11 | (x - mean) / std 12 | """ 13 | 14 | def __init__(self, eps=1e-5, qi=True, qo=True, num_bits=8, max_bits=32, signed=True, symmetric_feature=False): 15 | super(QNorm, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 16 | self.num_bits = num_bits 17 | self.max_bits = max_bits 18 | self.signed = signed 19 | self.first_time = True 20 | self.eps = eps 21 | 22 | def freeze(self, qi=None, qo=None): 23 | 24 | if hasattr(self, 'qi') and qi is not None: 25 | raise ValueError('qi has been provided in init 
function.') 26 | if not hasattr(self, 'qi') and qi is None: 27 | raise ValueError('qi is not existed, should be provided.') 28 | 29 | if hasattr(self, 'qo') and qo is not None: 30 | raise ValueError('qo has been provided in init function.') 31 | if not hasattr(self, 'qo') and qo is None: 32 | raise ValueError('qo is not existed, should be provided.') 33 | self.freeze_flag = True 34 | 35 | if qi is not None: 36 | self.qi = qi 37 | if qo is not None: 38 | self.qo = qo 39 | 40 | self.M = 1 / (self.qo.scale * 2**(self.max_bits - 2)) 41 | return self.qo 42 | 43 | def forward(self, x, qi=None): 44 | """ 45 | here need to get before_layer.qo as norm.qi 46 | """ 47 | if not hasattr(self, 'qi') and qi is None: 48 | raise ValueError('qi is not existed, should be provided.') 49 | if hasattr(self, 'qi') and qi is not None: # for test without before_layer.qo 50 | raise ValueError('qi has been provided in init function.') 51 | if hasattr(self, 'qi') and qi is None: # for test without before_layer.qo 52 | qi = self.qi 53 | qi.update(x) 54 | if self.freeze_flag: 55 | raise Exception(f'{self._get_name()} has been frozen') 56 | 57 | if self.qo.scale.numel() == 0: 58 | mean_ = torch.mean(x, dim=-1, keepdims=True) 59 | var_ = torch.var(x, dim=-1, keepdims=True) 60 | std_ = torch.sqrt(var_ + self.eps) 61 | x = (x - mean_) / std_ 62 | else: 63 | qx = QuantizeTensor.apply(x, qi) 64 | qx = qx - qi.zero_point 65 | 66 | # Integer-only Norm 67 | mean_ = qx.mean(dim=-1, keepdim=True) 68 | sum_ = ClampSTE.apply(torch.sum((qx - mean_)**2, dim=-1, keepdim=True), 69 | *get_qmin_qmax(self.max_bits, signed=True)) # int32; clipping the overflow directly here may be worse than shifting the range, but keep it simple for now 70 | var_ = FloorSTE.apply(sum_ / qx.shape[-1]) 71 | var_[var_ == 0.] = 1. # prevent overflow 72 | std_ = FloorSTE.apply(torch.sqrt(var_)) 73 | factor = FloorSTE.apply(2**(self.max_bits - 1) / std_) 74 | qx = FloorSTE.apply(ClampSTE.apply(((qx - mean_) * factor / 2), *get_qmin_qmax(self.max_bits, signed=True))) 75 | x = qx / 2**(self.max_bits - 2) # no floor needed: this division is folded into M 76 | 77 | self.qo.update(x) 78 | x = FakeQuantize.apply(x, self.qo) 79 | return x 80 | 81 | def quantize_inference(self, x, mode=None): 82 | x = x - self.qi.zero_point 83 | 84 | # Integer-only Norm 85 | mean_ = x.mean(dim=-1, keepdim=True) # int16 86 | sum_ = torch.sum((x - mean_)**2, dim=-1, keepdim=True).clamp(*get_qmin_qmax(self.max_bits, signed=True)) # clamp into the 32-bit range 87 | var_ = torch.floor(sum_ / x.shape[-1]) 88 | var_[var_ == 0.] = 1. # prevent overflow
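# [Editor's hedged note] With max_bits = 32 the normalized value is carried
# at a fixed-point scale of 2**30: factor ~= 2**31 / std_, and the extra
# division by 2 below keeps (x - mean_) * factor inside the int32 range.
# The residual 1/2**30 is not applied here; it is folded into
# M = 1 / (qo.scale * 2**(max_bits - 2)) computed in freeze().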
89 | # std_ = sqrt_interger(var_) # exact integer sqrt (sqrt_interger) is slow; float sqrt is fine for this quick evaluation 90 | std_ = torch.sqrt(var_).floor() 91 | factor = torch.floor(2**(self.max_bits - 1) / std_) 92 | x = torch.floor(torch.clamp((x - mean_) * factor / 2, *get_qmin_qmax(self.max_bits, signed=True))) 93 | 94 | if mode is None: 95 | x = self.M * x 96 | x.round_() 97 | elif mode == 'cmsis_nn': 98 | x = ReScaleNorm.apply(x, self.M) 99 | else: 100 | raise Exception(f'Unknown mode {mode}') 101 | x = x + self.qo.zero_point 102 | x.clamp_(self.qo.qmin, self.qo.qmax).round_() 103 | return x 104 | 105 | 106 | class ReScaleNorm(torch.autograd.Function): 107 | @staticmethod 108 | def symbolic(g, x, M): 109 | return g.op("ReScale", x, M) 110 | 111 | @staticmethod 112 | def forward(ctx, x, M): 113 | multiplier, shift = approximate_float(M) 114 | round_ = 1 << (shift - 1) 115 | x = (x * multiplier + round_) >> (31 - shift) 116 | return x 117 | -------------------------------------------------------------------------------- /torchquanter/nn/qrelu.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from .base import QModule, QParam, FakeQuantize 4 | from torchquanter.utils import quantize_tensor 5 | 6 | class QReLU(QModule): 7 | 8 | def __init__(self, qi=False, num_bits=8, signed=True, symmetric_feature=False): 9 | super(QReLU, self).__init__(qi=qi, qo=False, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 10 | 11 | def freeze(self, qi=None): 12 | 13 | if hasattr(self, 'qi') and qi is not None: 14 | raise ValueError('qi has been provided in init function.') 15 | if not hasattr(self, 'qi') and qi is None: 16 | raise ValueError('qi is not existed, should be provided.') 17 | self.freeze_flag = True 18 | 19 | if qi is not None: 20 | self.qi = qi 21 | return self.qi 22 | 23 | def forward(self, x): 24 | if hasattr(self, 'qi'): 25 | self.qi.update(x) 26 | x = FakeQuantize.apply(x, self.qi) 27 | if self.freeze_flag: 28 | raise Exception(f'{self._get_name()} has been frozen') 29 | 30 | x = F.relu(x) 31 | 32 | return x 33 | 34 | def quantize_inference(self, x, **kwargs): 35 | x = x.clone() 36 | x[x < self.qi.zero_point] = self.qi.zero_point 37 | return x -------------------------------------------------------------------------------- /torchquanter/nn/qsigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, FloorSTE, QuantizeTensor, DequantizeTensor, RoundSTE 6 | from torchquanter.utils import quantize_tensor, approximate_float 7 | 8 | class QSigmoid(QModule): 9 | 10 | def __init__(self, qi=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 11 | super(QSigmoid, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 12 | self.num_bits = num_bits 13 | self.signed = signed 14 | 15 | def freeze(self, qi=None, qo=None): 16 | 17 | if hasattr(self, 'qi') and qi is not None: 18 | raise ValueError('qi has been provided in init function.') 19 | if not hasattr(self, 'qi') and qi is None: 20 | raise ValueError('qi is not existed, should be provided.') 21 | 22 | if hasattr(self, 'qo') and qo is not None: 23 | raise ValueError('qo has been provided in init function.') 24 | if not hasattr(self, 'qo') and qo is None: 25 | raise ValueError('qo is not existed, should be provided.') 26 | self.freeze_flag = True 27 | 28 | if qi is not None: 29 | self.qi = qi 30 | 
if qo is not None: 31 | self.qo = qo 32 | return self.qo 33 | 34 | def forward(self, x): 35 | if hasattr(self, 'qi'): 36 | self.qi.update(x) 37 | x = FakeQuantize.apply(x, self.qi) 38 | if self.freeze_flag: 39 | raise Exception(f'{self._get_name()} has been frozen') 40 | 41 | x = F.sigmoid(x) 42 | 43 | if hasattr(self, 'qo'): 44 | self.qo.update(x) 45 | x = FakeQuantize.apply(x, self.qo) 46 | return x 47 | 48 | def quantize_inference(self, x, **kwargs): 49 | x = self.qi.dequantize_tensor(x) 50 | x = F.sigmoid(x) 51 | x = self.qo.quantize_tensor(x) 52 | return x 53 | 54 | -------------------------------------------------------------------------------- /torchquanter/nn/qsoftmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, FloorSTE, QuantizeTensor, DequantizeTensor, RoundSTE 6 | from torchquanter.utils import quantize_tensor, approximate_float 7 | 8 | class QSoftmax(QModule): 9 | 10 | def __init__(self, dim=-1, qi=True, qo=True, num_bits=8, signed=True, symmetric_feature=False): 11 | super(QSoftmax, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 12 | self.dim = dim 13 | self.num_bits = num_bits 14 | self.signed = signed 15 | self._init_qo(qo) 16 | 17 | def _init_qo(self, qo): 18 | if qo is True: 19 | self.qo.scale = torch.tensor(1 / 256., dtype=torch.float32, device=self.qo.scale.device) 20 | self.qo.zero_point = torch.tensor(-128., dtype=torch.float32, device=self.qo.scale.device) 21 | 22 | def freeze(self, qi=None, qo=None): 23 | 24 | if hasattr(self, 'qi') and qi is not None: 25 | raise ValueError('qi has been provided in init function.') 26 | if not hasattr(self, 'qi') and qi is None: 27 | raise ValueError('qi is not existed, should be provided.') 28 | 29 | if hasattr(self, 'qo') and qo is not None: 30 | raise ValueError('qo has been provided in init function.') 31 | if not hasattr(self, 'qo') and qo is None: 32 | raise ValueError('qo is not existed, should be provided.') 33 | self.freeze_flag = True 34 | 35 | if qi is not None: 36 | self.qi = qi 37 | if qo is not None: 38 | self.qo = qo 39 | return self.qo 40 | 41 | def forward(self, x): 42 | if hasattr(self, 'qi'): 43 | self.qi.update(x) 44 | x = FakeQuantize.apply(x, self.qi) 45 | if self.freeze_flag: 46 | raise Exception(f'{self._get_name()} has been frozen') 47 | 48 | x = F.softmax(x, dim=self.dim) 49 | 50 | # default qo.scale = 1/256, qo.zero_point=-128 51 | if hasattr(self, 'qo'): 52 | x = FakeQuantize.apply(x, self.qo) 53 | 54 | return x 55 | 56 | def quantize_inference(self, x, **kwargs): 57 | x = self.qi.dequantize_tensor(x) 58 | x = F.softmax(x, dim=self.dim) 59 | x = self.qo.quantize_tensor(x) 60 | return x 61 | 62 | 63 | 64 | ''' 65 | # i-bert softmax quant(error) 66 | # (https://github.com/huggingface/transformers/blob/main/src/transformers/models/ibert/quant_modules.py) 67 | 68 | class QSoftmax(QModule): 69 | 70 | def __init__(self, qi=True, qo=True, num_bits=8, signed=True): 71 | super(QSoftmax, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed) 72 | assert qo == True, "qo must be True in QSoftmax" 73 | self.num_bits = num_bits 74 | self.signed = signed 75 | self.max_bit = 32 76 | self.x0 = -0.6931 # -ln2 77 | self.const = 30 # dummy integer constant 78 | self.coef = [0.35815147, 0.96963238, 1.0] # ax**2 + bx + c 79 | self.coef[1] /= self.coef[0] 80 | self.coef[2] /= self.coef[0] 81 | 82 | def 
freeze(self, qi=None, qo=None): 83 | 84 | if hasattr(self, 'qi') and qi is not None: 85 | raise ValueError('qi has been provided in init function.') 86 | if not hasattr(self, 'qi') and qi is None: 87 | raise ValueError('qi is not existed, should be provided.') 88 | 89 | if hasattr(self, 'qo') and qo is not None: 90 | raise ValueError('qo has been provided in init function.') 91 | if not hasattr(self, 'qo') and qo is None: 92 | raise ValueError('qo is not existed, should be provided.') 93 | 94 | if qi is not None: 95 | self.qi = qi 96 | if qo is not None: 97 | self.qo = qo 98 | 99 | def int_polynomial(self, x_int, scaling_factor): 100 | with torch.no_grad(): 101 | b_int = torch.floor(self.coef[1] / scaling_factor) 102 | c_int = torch.floor(self.coef[2] / scaling_factor**2) 103 | z = (x_int + b_int) * x_int + c_int 104 | scaling_factor = self.coef[0] * scaling_factor**2 105 | return z, scaling_factor 106 | 107 | def int_exp(self, x_int, scaling_factor): 108 | with torch.no_grad(): 109 | x0_int = torch.floor(self.x0 / scaling_factor) 110 | x_int = torch.max(x_int, self.const * x0_int) 111 | 112 | q = FloorSTE.apply(x_int / x0_int) 113 | r = x_int - x0_int * q 114 | exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor) 115 | exp_int = torch.clamp(FloorSTE.apply(exp_int * 2 ** (self.const - q)), min=0) 116 | scaling_factor = exp_scaling_factor / 2**self.const 117 | return exp_int, scaling_factor 118 | 119 | def forward(self, x, qi=None): 120 | if not hasattr(self, 'qi') and qi is None: 121 | raise ValueError('qi is not existed, should be provided.') 122 | if hasattr(self, 'qi') and qi is None: # for test without before_layer.qo 123 | qi = self.qi 124 | qi.update(x) 125 | 126 | x_int = QuantizeTensor.apply(x, qi) 127 | 128 | x_int_max, _ = x_int.max(dim=-1, keepdim=True) 129 | x_int = x_int - x_int_max 130 | exp_int, exp_scaling_factor = self.int_exp(x_int, qi.scale) 131 | 132 | exp_int_sum = exp_int.sum(dim=-1, keepdim=True) 133 | factor = FloorSTE.apply(2**self.max_bit / exp_int_sum) 134 | exp_int = FloorSTE.apply(exp_int * factor / 2 ** (self.max_bit - self.num_bits)) 135 | scaling_factor = 1 / 2**self.num_bits 136 | 137 | # update qo 138 | self.qo.scale = torch.tensor(scaling_factor, dtype=qi.scale.dtype, device=qi.scale.device) 139 | self.qo.zero_point = torch.tensor(0., dtype=qi.scale.dtype, device=qi.zero_point.device) 140 | 141 | x = DequantizeTensor.apply(exp_int, self.qo) 142 | return x 143 | 144 | def quantize_inference(self, x, mode=None): 145 | pass 146 | # x = x - self.qi.zero_point 147 | # x = self.fc_module(x) 148 | # if mode is None: 149 | # x = self.M * x 150 | # x.round_() 151 | # elif mode == 'cmsis_nn': 152 | # multiplier, shift = approximate_float(self.M) 153 | # x = (x * multiplier) >> (31 - shift) 154 | # else: 155 | # raise Exception(f'Unknown mode {mode}') 156 | # x = x + self.qo.zero_point 157 | # x.clamp_(self.qo.qmin, self.qo.qmax).round_() 158 | # return x 159 | ''' -------------------------------------------------------------------------------- /torchquanter/nn/qsoftmax_w_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import QModule, FakeQuantize, FloorSTE, QuantizeTensor, DequantizeTensor, RoundSTE 6 | from torchquanter.utils import quantize_tensor, approximate_float 7 | 8 | class QSoftmax_W_Policy(QModule): 9 | 10 | def __init__(self, dim=-1, qi=True, qo=True, num_bits=8, signed=True, 
symmetric_feature=False): 11 | super(QSoftmax_W_Policy, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature) 12 | self.dim = dim 13 | self.num_bits = num_bits 14 | self.signed = signed 15 | self._init_qo(qo) 16 | 17 | def _init_qo(self, qo): 18 | if qo is True: 19 | self.qo.scale = torch.tensor(1 / 256., dtype=torch.float32, device=self.qo.scale.device) 20 | self.qo.zero_point = torch.tensor(-128., dtype=torch.float32, device=self.qo.scale.device) 21 | 22 | def freeze(self, qi=None, qo=None): 23 | 24 | if hasattr(self, 'qi') and qi is not None: 25 | raise ValueError('qi has been provided in init function.') 26 | if not hasattr(self, 'qi') and qi is None: 27 | raise ValueError('qi is not existed, should be provided.') 28 | 29 | if hasattr(self, 'qo') and qo is not None: 30 | raise ValueError('qo has been provided in init function.') 31 | if not hasattr(self, 'qo') and qo is None: 32 | raise ValueError('qo is not existed, should be provided.') 33 | self.freeze_flag = True 34 | 35 | if qi is not None: 36 | self.qi = qi 37 | if qo is not None: 38 | self.qo = qo 39 | return self.qo 40 | 41 | def forward(self, x, policy): 42 | if hasattr(self, 'qi'): 43 | self.qi.update(x) 44 | x = FakeQuantize.apply(x, self.qi) 45 | if self.freeze_flag: 46 | raise Exception(f'{self._get_name()} has been frozen') 47 | if policy is None: 48 | x = F.softmax(x, dim=self.dim) 49 | else: 50 | x = self.softmax_with_policy(x, policy) 51 | 52 | # default qo.scale = 1/256, qo.zero_point=-128 53 | if hasattr(self, 'qo'): 54 | x = FakeQuantize.apply(x, self.qo) 55 | 56 | return x 57 | 58 | def quantize_inference(self, x, policy, **kwargs): 59 | x = self.qi.dequantize_tensor(x) 60 | if policy is None: 61 | x = F.softmax(x, dim=self.dim) 62 | else: 63 | x = self.softmax_with_policy(x, policy) 64 | x = self.qo.quantize_tensor(x) 65 | return x 66 | 67 | def softmax_with_policy(self, attn, policy, eps=1e-6): 68 | B, N, _ = policy.size() 69 | B, H, N, N = attn.size() 70 | attn_policy = policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1) 71 | eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N) 72 | attn_policy = attn_policy + (1.0 - attn_policy) * eye 73 | max_att = torch.max(attn, dim=-1, keepdim=True)[0] 74 | attn = attn - max_att 75 | # attn = attn.exp_() * attn_policy 76 | # return attn / attn.sum(dim=-1, keepdim=True) 77 | 78 | # for stable training 79 | attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32) 80 | attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps) 81 | return attn.type_as(max_att) 82 | 83 | 84 | ''' 85 | # i-bert softmax quant(error) 86 | # (https://github.com/huggingface/transformers/blob/main/src/transformers/models/ibert/quant_modules.py) 87 | 88 | class QSoftmax(QModule): 89 | 90 | def __init__(self, qi=True, qo=True, num_bits=8, signed=True): 91 | super(QSoftmax, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed) 92 | assert qo == True, "qo must be True in QSoftmax" 93 | self.num_bits = num_bits 94 | self.signed = signed 95 | self.max_bit = 32 96 | self.x0 = -0.6931 # -ln2 97 | self.const = 30 # dummy integer constant 98 | self.coef = [0.35815147, 0.96963238, 1.0] # ax**2 + bx + c 99 | self.coef[1] /= self.coef[0] 100 | self.coef[2] /= self.coef[0] 101 | 102 | def freeze(self, qi=None, qo=None): 103 | 104 | if hasattr(self, 'qi') and qi is not None: 105 | raise ValueError('qi has been provided in init function.') 106 | if not hasattr(self, 'qi') and qi is None: 107 | raise ValueError('qi 
108 | 
109 |         if hasattr(self, 'qo') and qo is not None:
110 |             raise ValueError('qo was already provided in __init__.')
111 |         if not hasattr(self, 'qo') and qo is None:
112 |             raise ValueError('qo does not exist and should be provided.')
113 | 
114 |         if qi is not None:
115 |             self.qi = qi
116 |         if qo is not None:
117 |             self.qo = qo
118 | 
119 |     def int_polynomial(self, x_int, scaling_factor):
120 |         with torch.no_grad():
121 |             b_int = torch.floor(self.coef[1] / scaling_factor)
122 |             c_int = torch.floor(self.coef[2] / scaling_factor**2)
123 |         z = (x_int + b_int) * x_int + c_int
124 |         scaling_factor = self.coef[0] * scaling_factor**2
125 |         return z, scaling_factor
126 | 
127 |     def int_exp(self, x_int, scaling_factor):
128 |         with torch.no_grad():
129 |             x0_int = torch.floor(self.x0 / scaling_factor)
130 |         x_int = torch.max(x_int, self.const * x0_int)
131 | 
132 |         q = FloorSTE.apply(x_int / x0_int)
133 |         r = x_int - x0_int * q
134 |         exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
135 |         exp_int = torch.clamp(FloorSTE.apply(exp_int * 2 ** (self.const - q)), min=0)
136 |         scaling_factor = exp_scaling_factor / 2**self.const
137 |         return exp_int, scaling_factor
138 | 
139 |     def forward(self, x, qi=None):
140 |         if not hasattr(self, 'qi') and qi is None:
141 |             raise ValueError('qi does not exist and should be provided.')
142 |         if hasattr(self, 'qi') and qi is None:  # for test without before_layer.qo
143 |             qi = self.qi
144 |             qi.update(x)
145 | 
146 |         x_int = QuantizeTensor.apply(x, qi)
147 | 
148 |         x_int_max, _ = x_int.max(dim=-1, keepdim=True)
149 |         x_int = x_int - x_int_max
150 |         exp_int, exp_scaling_factor = self.int_exp(x_int, qi.scale)
151 | 
152 |         exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
153 |         factor = FloorSTE.apply(2**self.max_bit / exp_int_sum)
154 |         exp_int = FloorSTE.apply(exp_int * factor / 2 ** (self.max_bit - self.num_bits))
155 |         scaling_factor = 1 / 2**self.num_bits
156 | 
157 |         # update qo
158 |         self.qo.scale = torch.tensor(scaling_factor, dtype=qi.scale.dtype, device=qi.scale.device)
159 |         self.qo.zero_point = torch.tensor(0., dtype=qi.scale.dtype, device=qi.zero_point.device)
160 | 
161 |         x = DequantizeTensor.apply(exp_int, self.qo)
162 |         return x
163 | 
164 |     def quantize_inference(self, x, mode=None):
165 |         pass
166 |         # x = x - self.qi.zero_point
167 |         # x = self.fc_module(x)
168 |         # if mode is None:
169 |         #     x = self.M * x
170 |         #     x.round_()
171 |         # elif mode == 'cmsis_nn':
172 |         #     multiplier, shift = approximate_float(self.M)
173 |         #     x = (x * multiplier) >> (31 - shift)
174 |         # else:
175 |         #     raise Exception(f'Unknown mode {mode}')
176 |         # x = x + self.qo.zero_point
177 |         # x.clamp_(self.qo.qmin, self.qo.qmax).round_()
178 |         # return x
179 | '''
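For orientation, a minimal QAT-style usage sketch for QSoftmax_W_Policy follows. It is a sketch under assumptions: that the module is exported from `torchquanter.nn`, and that the constructor accepts the keyword arguments set in `__init__` above. The tensor shapes follow `softmax_with_policy` (attention logits of shape `(B, H, N, N)`, keep-mask policy of shape `(B, N, 1)`).

```python
import torch
from torchquanter.nn import QSoftmax_W_Policy  # assumed export path

B, H, N = 2, 4, 8                       # batch, heads, tokens (arbitrary)
attn = torch.randn(B, H, N, N)          # attention logits
policy = torch.ones(B, N, 1)            # 1 = keep token, 0 = pruned token
policy[:, N // 2:] = 0.                 # prune the second half of the tokens

# hypothetical constructor call; the keyword names mirror __init__ above
qsoftmax = QSoftmax_W_Policy(dim=-1, qi=True, qo=True, num_bits=8, signed=True)
out = qsoftmax(attn, policy)            # fake-quantized probabilities (qo fixed: scale=1/256, zp=-128)
print(out.shape)                        # torch.Size([2, 4, 8, 8])
```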
--------------------------------------------------------------------------------
/torchquanter/nn/qsqrt.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | from .base import DequantizeTensor, QModule, QParam, FakeQuantize, FloorSTE, QuantizeTensor, RoundSTE
6 | from torchquanter.utils import quantize_tensor, approximate_float, sqrt_interger
7 | 
8 | class QSqrt(QModule):
9 | 
10 |     def __init__(self, qi=True, qo=True, num_bits=8, signed=True, symmetric_feature=False):
11 |         assert qo is True, 'qo must be True in QSqrt'
12 |         super(QSqrt, self).__init__(qi=qi, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature)
13 |         self.num_bits = num_bits
14 |         self.signed = signed
15 |         self.first_time = True
16 | 
17 |     def freeze(self, qi=None, qo=None):
18 | 
19 |         if hasattr(self, 'qi') and qi is not None:
20 |             raise ValueError('qi was already provided in __init__.')
21 |         if not hasattr(self, 'qi') and qi is None:
22 |             raise ValueError('qi does not exist and should be provided.')
23 | 
24 |         if hasattr(self, 'qo') and qo is not None:
25 |             raise ValueError('qo was already provided in __init__.')
26 |         if not hasattr(self, 'qo') and qo is None:
27 |             raise ValueError('qo does not exist and should be provided.')
28 |         self.freeze_flag = True
29 | 
30 |         if qi is not None:
31 |             self.qi = qi
32 |         if qo is not None:
33 |             self.qo = qo
34 |         self.M = torch.sqrt(self.qi.scale) / self.qo.scale
35 |         return self.qo
36 | 
37 |     def forward(self, x, qi=None):
38 |         if not hasattr(self, 'qi') and qi is None:
39 |             raise ValueError('qi does not exist and should be provided.')
40 |         if hasattr(self, 'qi') and qi is not None:
41 |             raise ValueError('qi was already provided in __init__.')
42 | 
43 |         if hasattr(self, 'qi'):
44 |             qi = self.qi
45 |             qi.update(x)
46 |         if self.freeze_flag:
47 |             raise Exception(f'{self._get_name()} has been frozen')
48 | 
49 |         if self.first_time:
50 |             x = torch.sqrt(x)  # first call runs in float so that qo statistics can be initialized
51 |             self.qo.update(x)
52 |             self.first_time = False
53 |         else:
54 |             qx = QuantizeTensor.apply(x, qi)
55 | 
56 |             qx = qx - qi.zero_point
57 |             qx = FloorSTE.apply(torch.sqrt(qx))
58 |             qx = RoundSTE.apply(torch.sqrt(qi.scale) / self.qo.scale * qx)
59 |             qx = qx + self.qo.zero_point
60 | 
61 |             x = DequantizeTensor.apply(qx, self.qo)
62 | 
63 |         if hasattr(self, 'qo'):
64 |             self.qo.update(x)
65 |             x = FakeQuantize.apply(x, self.qo)
66 | 
67 |         return x
68 | 
69 |     def quantize_inference(self, x, mode=None):
70 |         x = x - self.qi.zero_point
71 |         x = sqrt_interger(x)
72 |         if mode is None:
73 |             x = self.M * x
74 |             x.round_()
75 |         elif mode == 'cmsis_nn':
76 |             x = ReScaleSqrt.apply(x, self.M)
77 |         else:
78 |             raise Exception(f'Unknown mode {mode}')
79 |         x = x + self.qo.zero_point
80 |         x.clamp_(self.qo.qmin, self.qo.qmax).round_()
81 |         return x
82 | 
83 | 
84 | class ReScaleSqrt(torch.autograd.Function):
85 |     @staticmethod
86 |     def symbolic(g, x, M):
87 |         return g.op("ReScale", x, M)
88 | 
89 |     @staticmethod
90 |     def forward(ctx, x, M):  # inference-only fixed-point rescale; no backward needed
91 |         multiplier, shift = approximate_float(M)
92 |         round_ = 1 << (shift - 1)
93 |         x = (x * multiplier + round_) >> (31 - shift)
94 |         return x
95 | 
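A minimal sketch of QSqrt's calibrate, freeze, then integer-inference flow. It assumes QSqrt is exported from `torchquanter.nn` and that the quantization parameters expose the `quantize_tensor`/`dequantize_tensor` helpers used by the modules above.

```python
import torch
from torchquanter.nn import QSqrt  # assumed export path

qsqrt = QSqrt(qi=True, qo=True, num_bits=8, signed=True)

# calibration / QAT: the first call runs in float to initialize qo
for _ in range(4):
    _ = qsqrt(torch.rand(16) * 100.)     # sqrt expects non-negative inputs

qsqrt.freeze()                           # fixes qi/qo and sets M = sqrt(Si) / So

x = torch.rand(16) * 100.
qx = qsqrt.qi.quantize_tensor(x)         # integer values stored as float
q_out = qsqrt.quantize_inference(qx)     # integer-only path via sqrt_interger
print(qsqrt.qo.dequantize_tensor(q_out)) # ~= torch.sqrt(x)
```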
--------------------------------------------------------------------------------
/torchquanter/nn/qsub.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .base import QModule, FakeQuantize, QParamIO
6 | from torchquanter.utils import broadcast_dim_as, approximate_float
7 | 
8 | class QSub(QModule):
9 | 
10 |     def __init__(self, qi1=True, qi2=True, qo=True, num_bits=8, signed=True, symmetric_feature=False):
11 |         """
12 |         return
13 |         ----------
14 |         output: x1 - x2
15 |         """
16 |         super(QSub, self).__init__(qi=False, qo=qo, num_bits=num_bits, signed=signed, symmetric=symmetric_feature)
17 |         if qi1:
18 |             self.qi1 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature)
19 |         if qi2:
20 |             self.qi2 = QParamIO(num_bits=num_bits, signed=signed, symmetric=symmetric_feature)
21 |         self.num_bits = num_bits
22 |         self.signed = signed
23 | 
24 |         M1 = torch.tensor([], requires_grad=False)
25 |         self.register_buffer('M1', M1)
26 |         M2 = torch.tensor([], requires_grad=False)
27 |         self.register_buffer('M2', M2)
28 | 
29 |     def freeze(self, qi1=None, qi2=None, qo=None):
30 | 
31 |         if hasattr(self, 'qi1') and qi1 is not None:
32 |             raise ValueError('qi1 was already provided in __init__.')
33 |         if not hasattr(self, 'qi1') and qi1 is None:
34 |             raise ValueError('qi1 does not exist and should be provided.')
35 | 
36 |         if hasattr(self, 'qi2') and qi2 is not None:
37 |             raise ValueError('qi2 was already provided in __init__.')
38 |         if not hasattr(self, 'qi2') and qi2 is None:
39 |             raise ValueError('qi2 does not exist and should be provided.')
40 | 
41 |         if hasattr(self, 'qo') and qo is not None:
42 |             raise ValueError('qo was already provided in __init__.')
43 |         if not hasattr(self, 'qo') and qo is None:
44 |             raise ValueError('qo does not exist and should be provided.')
45 |         self.freeze_flag = True
46 | 
47 |         if qi1 is not None:
48 |             self.qi1 = qi1
49 |         if qi2 is not None:
50 |             self.qi2 = qi2
51 |         if qo is not None:
52 |             self.qo = qo
53 |         self.M1 = self.qi1.scale / self.qo.scale
54 |         self.M2 = self.qi2.scale / self.qo.scale
55 |         return self.qo
56 | 
57 |     def forward(self, x1, x2):
58 |         if hasattr(self, 'qi1'):
59 |             self.qi1.update(x1)
60 |             x1 = FakeQuantize.apply(x1, self.qi1)
61 |         if hasattr(self, 'qi2'):
62 |             self.qi2.update(x2)
63 |             x2 = FakeQuantize.apply(x2, self.qi2)
64 |         if self.freeze_flag:
65 |             raise Exception(f'{self._get_name()} has been frozen')
66 | 
67 |         out = x1 - x2
68 | 
69 |         if hasattr(self, 'qo'):
70 |             self.qo.update(out)
71 |             out = FakeQuantize.apply(out, self.qo)
72 | 
73 |         return out
74 | 
75 |     def quantize_inference(self, x1, x2, mode=None):
76 |         x1 = x1 - self.qi1.zero_point
77 |         x2 = x2 - self.qi2.zero_point
78 |         if mode is None:
79 |             x1 = self.M1 * x1
80 |             x2 = self.M2 * x2
81 |             out = x1 - x2
82 |             out.round_()
83 |         elif mode == 'cmsis_nn':
84 |             multiplier1, shift1 = approximate_float(self.M1)
85 |             round1 = 1 << (shift1 - 1)
86 |             multiplier2, shift2 = approximate_float(self.M2)
87 |             round2 = 1 << (shift2 - 1)
88 | 
89 |             x1 = (x1 * multiplier1 + round1) >> (31 - shift1)
90 |             x2 = (x2 * multiplier2 + round2) >> (31 - shift2)
91 |             out = x1 - x2
92 |         else:
93 |             raise Exception(f'Unknown mode {mode}')
94 |         out = out + self.qo.zero_point
95 |         out.clamp_(self.qo.qmin, self.qo.qmax).round_()
96 |         return out
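The same flow applies to the two-input QSub. A sketch under the same assumptions (export path and `QParam` helper methods as above):

```python
import torch
from torchquanter.nn import QSub  # assumed export path

qsub = QSub(qi1=True, qi2=True, qo=True, num_bits=8, signed=True)

x1, x2 = torch.randn(32), torch.randn(32)
_ = qsub(x1, x2)                          # calibrates qi1, qi2 and qo
qsub.freeze()                             # sets M1 = S1/So and M2 = S2/So

q1 = qsub.qi1.quantize_tensor(x1)
q2 = qsub.qi2.quantize_tensor(x2)
q_out = qsub.quantize_inference(q1, q2)   # default mode: float-multiplier rescale
print(qsub.qo.dequantize_tensor(q_out))   # ~= x1 - x2
```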
--------------------------------------------------------------------------------
/torchquanter/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
--------------------------------------------------------------------------------
/torchquanter/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | 
6 | def random_seed(seed=42, rank=0):
7 |     torch.manual_seed(seed + rank)
8 |     np.random.seed(seed + rank)
9 |     random.seed(seed + rank)
10 | 
11 | def get_qmin_qmax(num_bits, signed):
12 |     if signed:
13 |         qmin = - 2. ** (num_bits - 1)
14 |         qmax = 2. ** (num_bits - 1) - 1
15 |     else:
16 |         qmin = 0.
17 |         qmax = 2.**num_bits - 1.
18 |     return qmin, qmax
19 | 
20 | def calcScaleZeroPoint(min_val: torch.Tensor, max_val: torch.Tensor, num_bits=8, signed=True, symmetric=False, eps=1e-6):
21 |     """
22 |     calculate scale and zero point for quantization
23 |     """
24 |     qmin, qmax = get_qmin_qmax(num_bits, signed)
25 | 
26 |     if not symmetric:
27 |         if max_val != min_val:
28 |             scale = (max_val - min_val) / (qmax - qmin)   # S = (rmax - rmin) / (qmax - qmin)
29 |         else:
30 |             scale = max_val / (qmax - qmin)
31 |         zero_point = qmax - max_val / scale   # Z = round(qmax - rmax / S)
32 | 
33 |         zero_point = torch.where(zero_point < qmin, torch.tensor(qmin, dtype=zero_point.dtype, device=zero_point.device), zero_point)
34 |         zero_point = torch.where(zero_point > qmax, torch.tensor(qmax, dtype=zero_point.dtype, device=zero_point.device), zero_point)
35 |     else:
36 |         scale = torch.where(max_val.abs().data > min_val.abs().data, max_val.abs().data, min_val.abs().data) / max(abs(qmax), abs(qmin))
37 |         scale[scale == 0] = eps
38 |         zero_point = torch.zeros(max_val.shape, dtype=min_val.dtype, device=max_val.device)
39 | 
40 |     zero_point.round_()
41 |     assert 0 not in scale
42 |     return scale.to(min_val.device), zero_point.to(min_val.device)
43 | 
44 | def quantize_tensor(x: torch.Tensor, scale, zero_point, num_bits=8, signed=True):
45 |     """
46 |     use scale and zero_point to quantize tensor
47 |     """
48 |     if not isinstance(scale, torch.Tensor):
49 |         scale = torch.tensor(scale, dtype=x.dtype, device=x.device)
50 |     if not isinstance(zero_point, torch.Tensor):
51 |         zero_point = torch.tensor(zero_point, dtype=x.dtype, device=x.device)
52 | 
53 |     qmin, qmax = get_qmin_qmax(num_bits, signed)
54 |     scale_ = broadcast_dim_as(scale, x)
55 |     zero_point_ = broadcast_dim_as(zero_point, x)
56 | 
57 |     q_x = zero_point_ + x / scale_
58 |     q_x.clamp_(qmin, qmax).round_()   # q = round(clip(r/S + Z))
59 | 
60 |     return q_x.float()  # PyTorch does not support these ops on integer types, so integers are still stored as float
61 | 
62 | def dequantize_tensor(q_x, scale, zero_point):
63 |     scale_ = broadcast_dim_as(scale, q_x)
64 |     zero_point_ = broadcast_dim_as(zero_point, q_x)
65 |     return scale_ * (q_x - zero_point_)   # r = S * (q - Z)
66 | 
67 | def broadcast_dim_as(tensor: torch.Tensor, x: torch.Tensor, dim=0):
68 |     """
69 |     broadcast tensor to x's dimension.
70 | 
71 |     Args
72 |     ----------
73 |     tensor: the tensor to be broadcast
74 |     x: tensor with the target dimensions
75 |     dim: which dimension to keep
76 | 
77 |     e.g.:
78 |         x.shape   (32, 1, 3, 3)
79 |         tensor.shape   (32)
80 | 
81 |     after broadcast: tensor shape is (32, 1, 1, 1)
82 |     """
83 |     assert tensor.dim() <= 1, 'tensor dimension must be 0 or 1'
84 |     dims = [1 if i != dim else -1 for i in range(x.dim())]
85 |     return tensor.view(dims)
86 | 
87 | def approximate_float(M):
88 |     """
89 |     approximate a float with an integer multiplier and a shift:
90 |     ```
91 |     M ~= multiplier / 2^(31 - shift)
92 |     x * M ~= (x * multiplier) >> (31 - shift)
93 |     ```
94 | 
95 |     Args
96 |     ----------
97 |     M: float number
98 | 
99 |     Return
100 |     ----------
101 |     multiplier: torch.float32 (holds an int32 value)
102 |     shift: torch.int32 (holds an int8 value)
103 |     """
104 |     significand, shift = torch.frexp(M)
105 |     significand_q31 = torch.round(significand * (1 << 31))
106 | 
107 |     # significand is in [0.5, 1), so significand_q31 fits in int32
108 |     return significand_q31, shift
109 | 
110 | def sqrt_interger(tensor: torch.Tensor):
111 |     """
112 |     Newton's method to compute the integer square root of each element of the tensor
113 |     """
114 |     tensor.round_()   # make sure the elements of the tensor are integers (rounds in place)
115 | 
116 |     std_ = torch.zeros(tensor.flatten().shape, dtype=tensor.dtype, device=tensor.device)
117 |     for i, n in enumerate(tensor.flatten()):
118 |         x = n   # initial guess: sqrt(n) ~= n
119 |         if x == 0.:
120 |             continue
121 | 
122 |         while True:
123 |             root = ((x + (n / x).floor()) / 2).floor()   # next Newton iterate: floor((x + n // x) / 2)
124 |             if root >= x:   # iterates stopped decreasing: converged
125 |                 std_[i] = x
126 |                 break
127 |             x = root   # update the estimate
128 | 
129 |     std_ = std_.view_as(tensor)
130 |     return std_
--------------------------------------------------------------------------------
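To tie the helpers together, a small self-contained sketch of the quantize/dequantize round trip and of the fixed-point multiplier produced by `approximate_float` (all functions are defined in `utils.py` above; the input values are arbitrary):

```python
import torch
from torchquanter.utils import (calcScaleZeroPoint, quantize_tensor,
                                dequantize_tensor, approximate_float)

x = torch.randn(100)
scale, zero_point = calcScaleZeroPoint(x.min(), x.max(), num_bits=8, signed=True)
q = quantize_tensor(x, scale, zero_point, num_bits=8, signed=True)
x_hat = dequantize_tensor(q, scale, zero_point)
print((x - x_hat).abs().max())              # round-trip error, roughly <= scale / 2

M = torch.tensor(0.0123)                    # a typical rescale factor, e.g. Si*Sw/So
multiplier, shift = approximate_float(M)
print(M, multiplier / 2.0 ** (31 - shift))  # the two values should nearly match
```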