├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── imagenet.py
├── hubconf.py
├── main_imagenet.py
├── models
│   ├── __init__.py
│   ├── mnasnet.py
│   ├── mobilenetv2.py
│   ├── regnet.py
│   └── resnet.py
├── quant
│   ├── __init__.py
│   ├── adaptive_rounding.py
│   ├── block_recon.py
│   ├── data_utils.py
│   ├── fold_bn.py
│   ├── layer_recon.py
│   ├── quant_block.py
│   ├── quant_layer.py
│   ├── quant_model.py
│   ├── set_act_quantize_params.py
│   └── set_weight_quantize_params.py
└── run_script.py
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.egg 3 | *.egg-info 4 | *.log 5 | *.pyc 6 | *.npy 7 | *.pth 8 | .DS_Store 9 | __pycache__ 10 | model_zoo -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # PD-Quant
3 |
4 | **PD-Quant: Post-Training Quantization Based on Prediction Difference Metric**
5 |
6 | [[arXiv]](https://arxiv.org/abs/2212.07048)
7 |
8 |
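Once the checkpoint paths in ```hubconf.py``` are configured (see Usage below), the full-precision models can be loaded directly for a quick sanity check. A minimal sketch, assuming ```model_path['resnet18']``` points to a valid BRECQ checkpoint:

```python
import torch
import hubconf

# Load the full-precision ResNet-18 (weights released by BRECQ).
model = hubconf.resnet18(pretrained=True)
model.eval()

# One dummy forward pass on an ImageNet-sized input.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])
```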
9 |
10 | ## Usage
11 | ### 1. Download the pre-trained FP models.
12 | The pre-trained FP models used in our experiments come from [BRECQ](https://github.com/yhhhli/BRECQ); they can be downloaded from this [link](https://github.com/yhhhli/BRECQ/releases/tag/v1.0).
13 | Then modify the paths of the pre-trained models in ```hubconf.py```.
14 |
15 | ### 2. Installation.
16 | ```
17 | python >= 3.7.13
18 | numpy >= 1.21.6
19 | torch >= 1.11.0
20 | torchvision >= 0.12.0
21 | ```
22 |
23 | ### 3. Run experiments
24 | You can run ```run_script.py``` for different models, including ResNet18, ResNet50, RegNetX-600M, RegNetX-3200M, MobileNetV2, and MNASNet.
25 | It runs experiments on four bit-width settings: W2A2, W4A2, W2A4, and W4A4.
26 |
27 | Take ResNet18 as an example:
28 | ```
29 | python run_script.py resnet18
30 | ```
31 |
32 | ## Results
33 |
34 | | Methods | Bits (W/A) | Res18 | Res50 | MNV2 | Reg600M | Reg3.2G | MNasx2 |
35 | | ------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
36 | | Full Prec. | 32/32 | 71.01 | 76.63 | 72.62 | 73.52 | 78.46 | 76.52 |
37 | | PD-Quant | 4/4 | 69.30 | 75.09 | 68.33 | 71.04 | 76.57 | 73.30 |
38 | | PD-Quant | 2/4 | 65.07 | 70.92 | 55.27 | 64.00 | 72.43 | 63.33 |
39 | | PD-Quant | 4/2 | 58.65 | 64.18 | 20.40 | 51.29 | 62.76 | 38.89 |
40 | | PD-Quant | 2/2 | 53.08 | 56.98 | 14.17 | 40.92 | 55.13 | 28.03 |
41 |
42 | ## Reference
43 | ```
44 | @article{liu2022pd,
45 | title={PD-Quant: Post-Training Quantization based on Prediction Difference Metric},
46 | author={Liu, Jiawei and Niu, Lin and Yuan, Zhihang and Yang, Dawei and Wang, Xinggang and Liu, Wenyu},
47 | journal={arXiv preprint arXiv:2212.07048},
48 | year={2022}
49 | }
50 | ```
51 |
52 | ## Thanks
53 | Our code is based on [QDrop](https://github.com/wimh966/QDrop) by @wimh966.
54 |
-------------------------------------------------------------------------------- /data/imagenet.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 |
6 |
7 | def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4):
8 | print('==> Using PyTorch Dataset')
9 |
10 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
11 | std=[0.229, 0.224, 0.225])
12 | traindir = os.path.join(data_path, 'train')
13 | valdir = os.path.join(data_path, 'val')
14 | train_dataset = datasets.ImageFolder(
15 | traindir,
16 | transforms.Compose([
17 | transforms.RandomResizedCrop(input_size),
18 | transforms.RandomHorizontalFlip(),
19 | transforms.ToTensor(),
20 | normalize,
21 | ]))
22 |
23 | train_loader = torch.utils.data.DataLoader(
24 | train_dataset, batch_size=batch_size, shuffle=True,
25 | num_workers=workers, pin_memory=True)
26 | val_loader = torch.utils.data.DataLoader(
27 | datasets.ImageFolder(valdir, transforms.Compose([
28 | transforms.Resize(256),
29 | transforms.CenterCrop(input_size),
30 | transforms.ToTensor(),
31 | normalize,
32 | ])),
33 | batch_size=batch_size, shuffle=False,
34 | num_workers=workers, pin_memory=True)
35 | return train_loader, val_loader
36 |
-------------------------------------------------------------------------------- /hubconf.py: --------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from models.resnet import resnet18 as _resnet18
3 | from models.resnet import resnet50 as _resnet50
4 | from models.mobilenetv2 import mobilenetv2 as _mobilenetv2
5 | from
models.mnasnet import mnasnet as _mnasnet 6 | from models.regnet import regnetx_600m as _regnetx_600m 7 | from models.regnet import regnetx_3200m as _regnetx_3200m 8 | import torch 9 | dependencies = ['torch'] 10 | model_path = { 11 | 'resnet18': '/home/tmp/resnet18_imagenet.pth.tar', 12 | 'resnet50': '/home/tmp/resnet50_imagenet.pth.tar', 13 | 'mbv2': '/home/tmp/mobilenetv2.pth.tar', 14 | 'reg600m': '/home/tmp/regnet_600m.pth.tar', 15 | 'reg3200m': '/home/tmp/regnet_3200m.pth.tar', 16 | 'mnasnet': '/home/tmp/mnasnet.pth.tar', 17 | } 18 | 19 | 20 | def resnet18(pretrained=False, **kwargs): 21 | # Call the model, load pretrained weights 22 | model = _resnet18(**kwargs) 23 | if pretrained: 24 | checkpoint = torch.load(model_path['resnet18'], map_location='cpu') 25 | model.load_state_dict(checkpoint) 26 | return model 27 | 28 | 29 | def resnet50(pretrained=False, **kwargs): 30 | # Call the model, load pretrained weights 31 | model = _resnet50(**kwargs) 32 | if pretrained: 33 | checkpoint = torch.load(model_path['resnet50'], map_location='cpu') 34 | model.load_state_dict(checkpoint) 35 | return model 36 | 37 | 38 | def mobilenetv2(pretrained=False, **kwargs): 39 | # Call the model, load pretrained weights 40 | model = _mobilenetv2(**kwargs) 41 | if pretrained: 42 | checkpoint = torch.load(model_path['mbv2'], map_location='cpu') 43 | model.load_state_dict(checkpoint['model']) 44 | return model 45 | 46 | 47 | def regnetx_600m(pretrained=False, **kwargs): 48 | # Call the model, load pretrained weights 49 | model = _regnetx_600m(**kwargs) 50 | if pretrained: 51 | checkpoint = torch.load(model_path['reg600m'], map_location='cpu') 52 | model.load_state_dict(checkpoint) 53 | return model 54 | 55 | 56 | def regnetx_3200m(pretrained=False, **kwargs): 57 | # Call the model, load pretrained weights 58 | model = _regnetx_3200m(**kwargs) 59 | if pretrained: 60 | checkpoint = torch.load(model_path['reg3200m'], map_location='cpu') 61 | model.load_state_dict(checkpoint) 62 | return model 63 | 64 | 65 | def mnasnet(pretrained=False, **kwargs): 66 | # Call the model, load pretrained weights 67 | model = _mnasnet(**kwargs) 68 | if pretrained: 69 | checkpoint = torch.load(model_path['mnasnet'], map_location='cpu') 70 | model.load_state_dict(checkpoint) 71 | return model 72 | -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import argparse 5 | import os 6 | import random 7 | import time 8 | import hubconf # noqa: F401 9 | import copy 10 | from quant import ( 11 | block_reconstruction, 12 | layer_reconstruction, 13 | BaseQuantBlock, 14 | QuantModule, 15 | QuantModel, 16 | set_weight_quantize_params, 17 | ) 18 | from data.imagenet import build_imagenet_data 19 | 20 | 21 | def seed_all(seed=1029): 22 | random.seed(seed) 23 | os.environ['PYTHONHASHSEED'] = str(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
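# The two cuDNN flags below complete the seeding: benchmark=False stops cuDNN from auto-selecting (possibly nondeterministic) fastest kernels, and deterministic=True forces deterministic implementations, trading some speed for reproducibility.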
28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cudnn.deterministic = True 30 | 31 | 32 | class AverageMeter(object): 33 | """Computes and stores the average and current value""" 34 | def __init__(self, name, fmt=':f'): 35 | self.name = name 36 | self.fmt = fmt 37 | self.reset() 38 | 39 | def reset(self): 40 | self.val = 0 41 | self.avg = 0 42 | self.sum = 0 43 | self.count = 0 44 | 45 | def update(self, val, n=1): 46 | self.val = val 47 | self.sum += val * n 48 | self.count += n 49 | self.avg = self.sum / self.count 50 | 51 | def __str__(self): 52 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 53 | return fmtstr.format(**self.__dict__) 54 | 55 | 56 | class ProgressMeter(object): 57 | def __init__(self, num_batches, meters, prefix=""): 58 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 59 | self.meters = meters 60 | self.prefix = prefix 61 | 62 | def display(self, batch): 63 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 64 | entries += [str(meter) for meter in self.meters] 65 | print('\t'.join(entries)) 66 | 67 | def _get_batch_fmtstr(self, num_batches): 68 | num_digits = len(str(num_batches // 1)) 69 | fmt = '{:' + str(num_digits) + 'd}' 70 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 71 | 72 | 73 | def accuracy(output, target, topk=(1,)): 74 | """Computes the accuracy over the k top predictions for the specified values of k""" 75 | with torch.no_grad(): 76 | maxk = max(topk) 77 | batch_size = target.size(0) 78 | 79 | _, pred = output.topk(maxk, 1, True, True) 80 | pred = pred.t() 81 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 82 | 83 | res = [] 84 | for k in topk: 85 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 86 | res.append(correct_k.mul_(100.0 / batch_size)) 87 | return res 88 | 89 | @torch.no_grad() 90 | def validate_model(val_loader, model, device=None, print_freq=100): 91 | if device is None: 92 | device = next(model.parameters()).device 93 | else: 94 | model.to(device) 95 | batch_time = AverageMeter('Time', ':6.3f') 96 | top1 = AverageMeter('Acc@1', ':6.2f') 97 | top5 = AverageMeter('Acc@5', ':6.2f') 98 | progress = ProgressMeter( 99 | len(val_loader), 100 | [batch_time, top1, top5], 101 | prefix='Test: ') 102 | 103 | # switch to evaluate mode 104 | model.eval() 105 | 106 | end = time.time() 107 | for i, (images, target) in enumerate(val_loader): 108 | images = images.to(device) 109 | target = target.to(device) 110 | 111 | # compute output 112 | output = model(images) 113 | 114 | # measure accuracy and record loss 115 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 116 | top1.update(acc1[0], images.size(0)) 117 | top5.update(acc5[0], images.size(0)) 118 | 119 | # measure elapsed time 120 | batch_time.update(time.time() - end) 121 | end = time.time() 122 | 123 | if i % print_freq == 0: 124 | progress.display(i) 125 | 126 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) 127 | 128 | return top1.avg 129 | 130 | def get_train_samples(train_loader, num_samples): 131 | train_data, target = [], [] 132 | for batch in train_loader: 133 | train_data.append(batch[0]) 134 | target.append(batch[1]) 135 | if len(train_data) * batch[0].size(0) >= num_samples: 136 | break 137 | return torch.cat(train_data, dim=0)[:num_samples], torch.cat(target, dim=0)[:num_samples] 138 | 139 | 140 | if __name__ == '__main__': 141 | 142 | parser = argparse.ArgumentParser(description='running parameters', 143 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 144 
| # general parameters for data and model 145 | parser.add_argument('--seed', default=1005, type=int, help='random seed for results reproduction') 146 | parser.add_argument('--arch', default='resnet18', type=str, help='model name', 147 | choices=['resnet18', 'resnet50', 'mobilenetv2', 'regnetx_600m', 'regnetx_3200m', 'mnasnet']) 148 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size for data loader') 149 | parser.add_argument('--workers', default=4, type=int, help='number of workers for data loader') 150 | parser.add_argument('--data_path', default='/datasets-to-imagenet', type=str, help='path to ImageNet data') 151 | 152 | # quantization parameters 153 | parser.add_argument('--n_bits_w', default=4, type=int, help='bitwidth for weight quantization') 154 | parser.add_argument('--channel_wise', default=True, help='apply channel_wise quantization for weights') 155 | parser.add_argument('--n_bits_a', default=4, type=int, help='bitwidth for activation quantization') 156 | parser.add_argument('--disable_8bit_head_stem', action='store_true') 157 | 158 | # weight calibration parameters 159 | parser.add_argument('--num_samples', default=1024, type=int, help='size of the calibration dataset') 160 | parser.add_argument('--iters_w', default=20000, type=int, help='number of iteration for adaround') 161 | parser.add_argument('--weight', default=0.01, type=float, help='weight of rounding cost vs the reconstruction loss.') 162 | parser.add_argument('--keep_cpu', action='store_true', help='keep the calibration data on cpu') 163 | 164 | parser.add_argument('--b_start', default=20, type=int, help='temperature at the beginning of calibration') 165 | parser.add_argument('--b_end', default=2, type=int, help='temperature at the end of calibration') 166 | parser.add_argument('--warmup', default=0.2, type=float, help='in the warmup period no regularization is applied') 167 | 168 | # activation calibration parameters 169 | parser.add_argument('--lr', default=4e-5, type=float, help='learning rate for LSQ') 170 | 171 | parser.add_argument('--init_wmode', default='mse', type=str, choices=['minmax', 'mse', 'minmax_scale'], 172 | help='init opt mode for weight') 173 | parser.add_argument('--init_amode', default='mse', type=str, choices=['minmax', 'mse', 'minmax_scale'], 174 | help='init opt mode for activation') 175 | 176 | parser.add_argument('--prob', default=0.5, type=float) 177 | parser.add_argument('--input_prob', default=0.5, type=float) 178 | parser.add_argument('--lamb_r', default=0.1, type=float, help='hyper-parameter for regularization') 179 | parser.add_argument('--T', default=4.0, type=float, help='temperature coefficient for KL divergence') 180 | parser.add_argument('--bn_lr', default=1e-3, type=float, help='learning rate for DC') 181 | parser.add_argument('--lamb_c', default=0.02, type=float, help='hyper-parameter for DC') 182 | args = parser.parse_args() 183 | 184 | seed_all(args.seed) 185 | # build imagenet data loader 186 | train_loader, test_loader = build_imagenet_data(batch_size=args.batch_size, workers=args.workers, 187 | data_path=args.data_path) 188 | # load model 189 | cnn = eval('hubconf.{}(pretrained=True)'.format(args.arch)) 190 | cnn.cuda() 191 | cnn.eval() 192 | fp_model = copy.deepcopy(cnn) 193 | fp_model.cuda() 194 | fp_model.eval() 195 | 196 | # build quantization parameters 197 | wq_params = {'n_bits': args.n_bits_w, 'channel_wise': args.channel_wise, 'scale_method': args.init_wmode} 198 | aq_params = {'n_bits': args.n_bits_a, 'channel_wise': False, 
'scale_method': args.init_amode, 199 | 'leaf_param': True, 'prob': args.prob} 200 | 201 | fp_model = QuantModel(model=fp_model, weight_quant_params=wq_params, act_quant_params=aq_params, is_fusing=False) 202 | fp_model.cuda() 203 | fp_model.eval() 204 | fp_model.set_quant_state(False, False) 205 | qnn = QuantModel(model=cnn, weight_quant_params=wq_params, act_quant_params=aq_params) 206 | qnn.cuda() 207 | qnn.eval() 208 | if not args.disable_8bit_head_stem: 209 | print('Setting the first and the last layer to 8-bit') 210 | qnn.set_first_last_layer_to_8bit() 211 | 212 | qnn.disable_network_output_quantization() 213 | print('the quantized model is below!') 214 | print(qnn) 215 | cali_data, cali_target = get_train_samples(train_loader, num_samples=args.num_samples) 216 | device = next(qnn.parameters()).device 217 | 218 | # Kwargs for weight rounding calibration 219 | kwargs = dict(cali_data=cali_data, iters=args.iters_w, weight=args.weight, 220 | b_range=(args.b_start, args.b_end), warmup=args.warmup, opt_mode='mse', 221 | lr=args.lr, input_prob=args.input_prob, keep_gpu=not args.keep_cpu, 222 | lamb_r=args.lamb_r, T=args.T, bn_lr=args.bn_lr, lamb_c=args.lamb_c) 223 | 224 | 225 | '''init weight quantizer''' 226 | set_weight_quantize_params(qnn) 227 | 228 | def set_weight_act_quantize_params(module, fp_module): 229 | if isinstance(module, QuantModule): 230 | layer_reconstruction(qnn, fp_model, module, fp_module, **kwargs) 231 | elif isinstance(module, BaseQuantBlock): 232 | block_reconstruction(qnn, fp_model, module, fp_module, **kwargs) 233 | else: 234 | raise NotImplementedError 235 | def recon_model(model: nn.Module, fp_model: nn.Module): 236 | """ 237 | Block reconstruction. For the first and last layers, we can only apply layer reconstruction. 238 | """ 239 | for (name, module), (_, fp_module) in zip(model.named_children(), fp_model.named_children()): 240 | if isinstance(module, QuantModule): 241 | print('Reconstruction for layer {}'.format(name)) 242 | set_weight_act_quantize_params(module, fp_module) 243 | elif isinstance(module, BaseQuantBlock): 244 | print('Reconstruction for block {}'.format(name)) 245 | set_weight_act_quantize_params(module, fp_module) 246 | else: 247 | recon_model(module, fp_module) 248 | # Start calibration 249 | recon_model(qnn, fp_model) 250 | 251 | qnn.set_quant_state(weight_quant=True, act_quant=True) 252 | print('Full quantization (W{}A{}) accuracy: {}'.format(args.n_bits_w, args.n_bits_a, 253 | validate_model(test_loader, qnn))) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/PD-Quant/e12038edd0c39d09772279dba90c43ce9cb5170d/models/__init__.py -------------------------------------------------------------------------------- /models/mnasnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['mnasnet'] 5 | 6 | 7 | class _InvertedResidual(nn.Module): 8 | 9 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor): 10 | super(_InvertedResidual, self).__init__() 11 | assert stride in [1, 2] 12 | assert kernel_size in [3, 5] 13 | mid_ch = in_ch * expansion_factor 14 | self.apply_residual = (in_ch == out_ch and stride == 1) 15 | self.layers = nn.Sequential( 16 | # Pointwise 17 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 18 | BN(mid_ch), 19 | nn.ReLU(inplace=True), 20 | # Depthwise 21 
| nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
22 | stride=stride, groups=mid_ch, bias=False),
23 | BN(mid_ch),
24 | nn.ReLU(inplace=True),
25 | # Linear pointwise. Note that there's no activation.
26 | nn.Conv2d(mid_ch, out_ch, 1, bias=False),
27 | BN(out_ch))
28 |
29 | def forward(self, input):
30 | if self.apply_residual:
31 | return self.layers(input) + input
32 | else:
33 | return self.layers(input)
34 |
35 |
36 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats):
37 | """ Creates a stack of inverted residuals. """
38 | assert repeats >= 1
39 | # First one has no skip, because feature map size changes.
40 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor)
41 | remaining = []
42 | for _ in range(1, repeats):
43 | remaining.append(
44 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor))
45 | return nn.Sequential(first, *remaining)
46 |
47 |
48 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
49 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default
50 | bias, will round up, unless the number is no more than 10% greater than the
51 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
52 | assert 0.0 < round_up_bias < 1.0
53 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
54 | return new_val if new_val >= round_up_bias * val else new_val + divisor
55 |
56 |
57 | def _get_depths(scale):
58 | """ Scales tensor depths as in reference MobileNet code, prefers rounding up
59 | rather than down. """
60 | depths = [32, 16, 24, 40, 80, 96, 192, 320]
61 | return [_round_to_multiple_of(depth * scale, 8) for depth in depths]
62 |
63 |
64 | class MNASNet(torch.nn.Module):
65 | # Version 2 adds depth scaling in the initial stages of the network.
66 | _version = 2
67 |
68 | def __init__(self, scale=2.0, num_classes=1000, dropout=0.0):
69 | super(MNASNet, self).__init__()
70 |
71 | global BN
72 | BN = nn.BatchNorm2d
73 |
74 | assert scale > 0.0
75 | self.scale = scale
76 | self.num_classes = num_classes
77 | depths = _get_depths(scale)
78 | layers = [
79 | # First layer: regular conv.
80 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
81 | BN(depths[0]),
82 | nn.ReLU(inplace=True),
83 | # Depthwise separable, no skip.
84 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
85 | groups=depths[0], bias=False),
86 | BN(depths[0]),
87 | nn.ReLU(inplace=True),
88 | nn.Conv2d(depths[0], depths[1], 1,
89 | padding=0, stride=1, bias=False),
90 | BN(depths[1]),
91 | # MNASNet blocks: stacks of inverted residuals.
92 | _stack(depths[1], depths[2], 3, 2, 3, 3),
93 | _stack(depths[2], depths[3], 5, 2, 3, 3),
94 | _stack(depths[3], depths[4], 5, 2, 6, 3),
95 | _stack(depths[4], depths[5], 3, 1, 6, 2),
96 | _stack(depths[5], depths[6], 5, 2, 6, 4),
97 | _stack(depths[6], depths[7], 3, 1, 6, 1),
98 | # Final mapping to classifier input.
99 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
100 | BN(1280),
101 | nn.ReLU(inplace=True),
102 | ]
103 | self.layers = nn.Sequential(*layers)
104 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
105 | nn.Linear(1280, num_classes))
106 | self._initialize_weights()
107 |
108 | def forward(self, x):
109 | x = self.layers(x)
110 | # Equivalent to global avgpool and removing H and W dimensions.
111 | x = x.mean([2, 3]) 112 | x = self.classifier(x) 113 | return x 114 | 115 | def _initialize_weights(self): 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 119 | nonlinearity="relu") 120 | if m.bias is not None: 121 | nn.init.zeros_(m.bias) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | nn.init.ones_(m.weight) 124 | nn.init.zeros_(m.bias) 125 | elif isinstance(m, nn.Linear): 126 | nn.init.kaiming_uniform_(m.weight, mode="fan_out", 127 | nonlinearity="sigmoid") 128 | nn.init.zeros_(m.bias) 129 | 130 | 131 | def mnasnet(**kwargs): 132 | model = MNASNet(**kwargs) 133 | return model 134 | 135 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | 5 | 6 | def conv_bn(inp, oup, stride): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.ReLU6(inplace=True) 11 | ) 12 | 13 | 14 | def conv_1x1_bn(inp, oup): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, stride, expand_ratio): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | assert stride in [1, 2] 27 | 28 | hidden_dim = round(inp * expand_ratio) 29 | self.use_res_connect = self.stride == 1 and inp == oup 30 | self.expand_ratio = expand_ratio 31 | if expand_ratio == 1: 32 | self.conv = nn.Sequential( 33 | # dw 34 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 35 | nn.BatchNorm2d(hidden_dim), 36 | nn.ReLU6(inplace=True), 37 | # pw-linear 38 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(oup), 40 | ) 41 | else: 42 | self.conv = nn.Sequential( 43 | # pw 44 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 45 | nn.BatchNorm2d(hidden_dim), 46 | nn.ReLU6(inplace=True), 47 | # dw 48 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 49 | nn.BatchNorm2d(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # pw-linear 52 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 53 | nn.BatchNorm2d(oup), 54 | ) 55 | 56 | def forward(self, x): 57 | if self.use_res_connect: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class MobileNetV2(nn.Module): 64 | def __init__(self, n_class=1000, input_size=224, width_mult=1., dropout=0.0): 65 | super(MobileNetV2, self).__init__() 66 | block = InvertedResidual 67 | input_channel = 32 68 | last_channel = 1280 69 | interverted_residual_setting = [ 70 | # t, c, n, s 71 | [1, 16, 1, 1], 72 | [6, 24, 2, 2], 73 | [6, 32, 3, 2], 74 | [6, 64, 4, 2], 75 | [6, 96, 3, 1], 76 | [6, 160, 3, 2], 77 | [6, 320, 1, 1], 78 | ] 79 | 80 | # building first layer 81 | assert input_size % 32 == 0 82 | input_channel = int(input_channel * width_mult) 83 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 84 | self.features = [conv_bn(3, input_channel, 2)] 85 | # building inverted residual blocks 86 | for t, c, n, s in interverted_residual_setting: 87 | output_channel = int(c * width_mult) 88 | for i in range(n): 89 | if i == 0: 90 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 91 | else: 92 | self.features.append(block(input_channel, 
output_channel, 1, expand_ratio=t)) 93 | input_channel = output_channel 94 | # building last several layers 95 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 96 | # self.features.append(nn.AvgPool2d(input_size // 32)) 97 | # make it nn.Sequential 98 | self.features = nn.Sequential(*self.features) 99 | 100 | # building classifier 101 | self.classifier = nn.Sequential( 102 | nn.Dropout(dropout), 103 | nn.Linear(self.last_channel, n_class), 104 | ) 105 | 106 | self._initialize_weights() 107 | 108 | def forward(self, x): 109 | x = self.features(x) 110 | x = x.mean([2, 3]) 111 | x = self.classifier(x) 112 | return x 113 | 114 | def _initialize_weights(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. / n)) 119 | if m.bias is not None: 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.Linear): 125 | n = m.weight.size(1) 126 | m.weight.data.normal_(0, 0.01) 127 | m.bias.data.zero_() 128 | 129 | 130 | def mobilenetv2(**kwargs): 131 | """ 132 | Constructs a MobileNetV2 model. 133 | """ 134 | model = MobileNetV2(**kwargs) 135 | return model -------------------------------------------------------------------------------- /models/regnet.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | import math 6 | 7 | regnetX_200M_config = {'WA': 36.44, 'W0': 24, 'WM': 2.49, 'DEPTH': 13, 'GROUP_W': 8, 'SE_ON': False} 8 | regnetX_400M_config = {'WA': 24.48, 'W0': 24, 'WM': 2.54, 'DEPTH': 22, 'GROUP_W': 16, 'SE_ON': False} 9 | regnetX_600M_config = {'WA': 36.97, 'W0': 48, 'WM': 2.24, 'DEPTH': 16, 'GROUP_W': 24, 'SE_ON': False} 10 | regnetX_800M_config = {'WA': 35.73, 'W0': 56, 'WM': 2.28, 'DEPTH': 16, 'GROUP_W': 16, 'SE_ON': False} 11 | regnetX_1600M_config = {'WA': 34.01, 'W0': 80, 'WM': 2.25, 'DEPTH': 18, 'GROUP_W': 24, 'SE_ON': False} 12 | regnetX_3200M_config = {'WA': 26.31, 'W0': 88, 'WM': 2.25, 'DEPTH': 25, 'GROUP_W': 48, 'SE_ON': False} 13 | regnetX_4000M_config = {'WA': 38.65, 'W0': 96, 'WM': 2.43, 'DEPTH': 23, 'GROUP_W': 40, 'SE_ON': False} 14 | regnetX_6400M_config = {'WA': 60.83, 'W0': 184, 'WM': 2.07, 'DEPTH': 17, 'GROUP_W': 56, 'SE_ON': False} 15 | regnetY_200M_config = {'WA': 36.44, 'W0': 24, 'WM': 2.49, 'DEPTH': 13, 'GROUP_W': 8, 'SE_ON': True} 16 | regnetY_400M_config = {'WA': 27.89, 'W0': 48, 'WM': 2.09, 'DEPTH': 16, 'GROUP_W': 8, 'SE_ON': True} 17 | regnetY_600M_config = {'WA': 32.54, 'W0': 48, 'WM': 2.32, 'DEPTH': 15, 'GROUP_W': 16, 'SE_ON': True} 18 | regnetY_800M_config = {'WA': 38.84, 'W0': 56, 'WM': 2.4, 'DEPTH': 14, 'GROUP_W': 16, 'SE_ON': True} 19 | regnetY_1600M_config = {'WA': 20.71, 'W0': 48, 'WM': 2.65, 'DEPTH': 27, 'GROUP_W': 24, 'SE_ON': True} 20 | regnetY_3200M_config = {'WA': 42.63, 'W0': 80, 'WM': 2.66, 'DEPTH': 21, 'GROUP_W': 24, 'SE_ON': True} 21 | regnetY_4000M_config = {'WA': 31.41, 'W0': 96, 'WM': 2.24, 'DEPTH': 22, 'GROUP_W': 64, 'SE_ON': True} 22 | regnetY_6400M_config = {'WA': 33.22, 'W0': 112, 'WM': 2.27, 'DEPTH': 25, 'GROUP_W': 72, 'SE_ON': True} 23 | 24 | 25 | BN = nn.BatchNorm2d 26 | 27 | __all__ = ['regnetx_200m', 'regnetx_400m', 'regnetx_600m', 'regnetx_800m', 28 | 'regnetx_1600m', 'regnetx_3200m', 'regnetx_4000m', 'regnetx_6400m', 29 | 'regnety_200m', 'regnety_400m', 'regnety_600m', 'regnety_800m', 30 
| 'regnety_1600m', 'regnety_3200m', 'regnety_4000m', 'regnety_6400m']
31 |
32 |
33 | class SimpleStemIN(nn.Module):
34 | """Simple stem for ImageNet."""
35 |
36 | def __init__(self, in_w, out_w):
37 | super(SimpleStemIN, self).__init__()
38 | self._construct(in_w, out_w)
39 |
40 | def _construct(self, in_w, out_w):
41 | # 3x3, BN, ReLU
42 | self.conv = nn.Conv2d(
43 | in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False
44 | )
45 | self.bn = BN(out_w)
46 | self.relu = nn.ReLU(True)
47 |
48 | def forward(self, x):
49 | for layer in self.children():
50 | x = layer(x)
51 | return x
52 |
53 |
54 | class SE(nn.Module):
55 | """Squeeze-and-Excitation (SE) block"""
56 |
57 | def __init__(self, w_in, w_se):
58 | super(SE, self).__init__()
59 | self._construct(w_in, w_se)
60 |
61 | def _construct(self, w_in, w_se):
62 | # AvgPool
63 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
64 | # FC, Activation, FC, Sigmoid
65 | self.f_ex = nn.Sequential(
66 | nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
67 | nn.ReLU(inplace=True),
68 | nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
69 | nn.Sigmoid(),
70 | )
71 |
72 | def forward(self, x):
73 | return x * self.f_ex(self.avg_pool(x))
74 |
75 |
76 | class BottleneckTransform(nn.Module):
77 | """Bottleneck transformation: 1x1, 3x3, 1x1"""
78 |
79 | def __init__(self, w_in, w_out, stride, bm, gw, se_r):
80 | super(BottleneckTransform, self).__init__()
81 | self._construct(w_in, w_out, stride, bm, gw, se_r)
82 |
83 | def _construct(self, w_in, w_out, stride, bm, gw, se_r):
84 | # Compute the bottleneck width
85 | w_b = int(round(w_out * bm))
86 | # Compute the number of groups
87 | num_gs = w_b // gw
88 | # 1x1, BN, ReLU
89 | self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
90 | self.a_bn = BN(w_b)
91 | self.a_relu = nn.ReLU(True)
92 | # 3x3, BN, ReLU
93 | self.b = nn.Conv2d(
94 | w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False
95 | )
96 | self.b_bn = BN(w_b)
97 | self.b_relu = nn.ReLU(True)
98 | # Squeeze-and-Excitation (SE)
99 | if se_r:
100 | w_se = int(round(w_in * se_r))
101 | self.se = SE(w_b, w_se)
102 | # 1x1, BN
103 | self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
104 | self.c_bn = BN(w_out)
105 | self.c_bn.final_bn = True
106 |
107 | def forward(self, x):
108 | for layer in self.children():
109 | x = layer(x)
110 | return x
111 |
112 |
113 | class ResBottleneckBlock(nn.Module):
114 | """Residual bottleneck block: x + F(x), F = bottleneck transform"""
115 |
116 | def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
117 | super(ResBottleneckBlock, self).__init__()
118 | self._construct(w_in, w_out, stride, bm, gw, se_r)
119 |
120 | def _add_skip_proj(self, w_in, w_out, stride):
121 | self.proj = nn.Conv2d(
122 | w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
123 | )
124 | self.bn = BN(w_out)
125 |
126 | def _construct(self, w_in, w_out, stride, bm, gw, se_r):
127 | # Use skip connection with projection if shape changes
128 | self.proj_block = (w_in != w_out) or (stride != 1)
129 | if self.proj_block:
130 | self._add_skip_proj(w_in, w_out, stride)
131 | self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
132 | self.relu = nn.ReLU(True)
133 |
134 | def forward(self, x):
135 | if self.proj_block:
136 | x = self.bn(self.proj(x)) + self.f(x)
137 | else:
138 | x = x + self.f(x)
139 | x = self.relu(x)
140 | return x
141 |
142 |
143 | class AnyHead(nn.Module):
144 | """AnyNet head."""
145 |
146 | def __init__(self, w_in,
nc): 147 | super(AnyHead, self).__init__() 148 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 149 | self.fc = nn.Linear(w_in, nc, bias=True) 150 | 151 | def forward(self, x): 152 | x = self.avg_pool(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.fc(x) 155 | return x 156 | 157 | 158 | class AnyStage(nn.Module): 159 | """AnyNet stage (sequence of blocks w/ the same output shape).""" 160 | 161 | def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r): 162 | super(AnyStage, self).__init__() 163 | self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r) 164 | 165 | def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r): 166 | # Construct the blocks 167 | for i in range(d): 168 | # Stride and w_in apply to the first block of the stage 169 | b_stride = stride if i == 0 else 1 170 | b_w_in = w_in if i == 0 else w_out 171 | # Construct the block 172 | self.add_module( 173 | "b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bm, gw, se_r) 174 | ) 175 | 176 | def forward(self, x): 177 | for block in self.children(): 178 | x = block(x) 179 | return x 180 | 181 | 182 | class AnyNet(nn.Module): 183 | """AnyNet model.""" 184 | 185 | def __init__(self, **kwargs): 186 | super(AnyNet, self).__init__() 187 | if kwargs: 188 | self._construct( 189 | stem_w=kwargs["stem_w"], 190 | ds=kwargs["ds"], 191 | ws=kwargs["ws"], 192 | ss=kwargs["ss"], 193 | bms=kwargs["bms"], 194 | gws=kwargs["gws"], 195 | se_r=kwargs["se_r"], 196 | nc=kwargs["nc"], 197 | ) 198 | for m in self.modules(): 199 | if isinstance(m, nn.Conv2d): 200 | # Note that there is no bias due to BN 201 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 202 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 203 | elif isinstance(m, nn.BatchNorm2d): 204 | m.weight.data.fill_(1) 205 | m.bias.data.zero_() 206 | elif isinstance(m, nn.Linear): 207 | n = m.weight.size(1) 208 | m.weight.data.normal_(0, 1.0 / float(n)) 209 | m.bias.data.zero_() 210 | 211 | def _construct(self, stem_w, ds, ws, ss, bms, gws, se_r, nc): 212 | # self.logger.info("Constructing AnyNet: ds={}, ws={}".format(ds, ws)) 213 | # Generate dummy bot muls and gs for models that do not use them 214 | bms = bms if bms else [1.0 for _d in ds] 215 | gws = gws if gws else [1 for _d in ds] 216 | # Group params by stage 217 | stage_params = list(zip(ds, ws, ss, bms, gws)) 218 | # Construct the stem 219 | self.stem = SimpleStemIN(3, stem_w) 220 | # Construct the stages 221 | block_fun = ResBottleneckBlock 222 | prev_w = stem_w 223 | for i, (d, w, s, bm, gw) in enumerate(stage_params): 224 | self.add_module( 225 | "s{}".format(i + 1), AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r) 226 | ) 227 | prev_w = w 228 | # Construct the head 229 | self.head = AnyHead(w_in=prev_w, nc=nc) 230 | 231 | def forward(self, x): 232 | for module in self.children(): 233 | x = module(x) 234 | return x 235 | 236 | 237 | def quantize_float(f, q): 238 | """Converts a float to closest non-zero int divisible by q.""" 239 | return int(round(f / q) * q) 240 | 241 | 242 | def adjust_ws_gs_comp(ws, bms, gs): 243 | """Adjusts the compatibility of widths and groups.""" 244 | ws_bot = [int(w * b) for w, b in zip(ws, bms)] 245 | gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)] 246 | ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)] 247 | ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)] 248 | return ws, gs 249 | 250 | 251 | def get_stages_from_blocks(ws, rs): 252 | """Gets ws/ds of network at each stage from per block values.""" 253 | 
ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
254 | ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
255 | s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
256 | s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
257 | return s_ws, s_ds
258 |
259 |
260 | def generate_regnet(w_a, w_0, w_m, d, q=8):
261 | """Generates per-block widths ws from RegNet parameters.
262 |
263 | args:
264 | w_a(float): slope of the linear width parameterization
265 | w_0(int): initial width
266 | w_m(float): width multiplier between consecutive stages (controls width quantization)
267 | d(int): depth, i.e. the total number of blocks
268 | q(int): widths are quantized to multiples of q
269 |
270 | procedure:
271 | 1. generate a linear parameterization for block widths. Eq. (2)
272 | 2. compute the corresponding stage for each block, $s_j = \log_{w_m}(w_j / w_0)$. Eq. (3)
273 | 3. compute each per-block width via $w_0 \cdot w_m^{s_j}$ and quantize it so that it is divisible by q. Eq. (4)
274 |
275 | return:
276 | ws(list of quantized float): quantized width list for blocks in different stages
277 | num_stages(int): total number of stages
278 | max_stage(float): the maximal stage index
279 | ws_cont(list of float): original width list for blocks in different stages
280 | """
281 | assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
282 | ws_cont = np.arange(d) * w_a + w_0
283 | ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
284 | ws = w_0 * np.power(w_m, ks)
285 | ws = np.round(np.divide(ws, q)) * q
286 | num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
287 | ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
288 | return ws, num_stages, max_stage, ws_cont
289 |
290 |
291 | class RegNet(AnyNet):
292 | """RegNet model class, based on
293 | `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_
294 | """
295 |
296 | def __init__(self, cfg, bn=None):
297 | # Generate RegNet ws per block
298 | b_ws, num_s, _, _ = generate_regnet(
299 | cfg['WA'], cfg['W0'], cfg['WM'], cfg['DEPTH']
300 | )
301 | # Convert to per stage format
302 | ws, ds = get_stages_from_blocks(b_ws, b_ws)
303 | # Generate group widths and bot muls
304 | gws = [cfg['GROUP_W'] for _ in range(num_s)]
305 | bms = [1 for _ in range(num_s)]
306 | # Adjust the compatibility of ws and gws
307 | ws, gws = adjust_ws_gs_comp(ws, bms, gws)
308 | # Use the same stride for each stage, stride set to 2
309 | ss = [2 for _ in range(num_s)]
310 | # Use SE for RegNetY
311 | se_r = 0.25 if cfg['SE_ON'] else None
312 | # Construct the model
313 | STEM_W = 32
314 |
315 | global BN
316 |
317 | kwargs = {
318 | "stem_w": STEM_W,
319 | "ss": ss,
320 | "ds": ds,
321 | "ws": ws,
322 | "bms": bms,
323 | "gws": gws,
324 | "se_r": se_r,
325 | "nc": 1000,
326 | }
327 | super(RegNet, self).__init__(**kwargs)
328 |
329 |
330 | def regnetx_200m(**kwargs):
331 | """
332 | Constructs a RegNet-X model under 200M FLOPs.
333 | """
334 | model = RegNet(regnetX_200M_config, **kwargs)
335 | return model
336 |
337 |
338 | def regnetx_400m(**kwargs):
339 | """
340 | Constructs a RegNet-X model under 400M FLOPs.
341 | """
342 | model = RegNet(regnetX_400M_config, **kwargs)
343 | return model
344 |
345 |
346 | def regnetx_600m(**kwargs):
347 | """
348 | Constructs a RegNet-X model under 600M FLOPs.
349 | """
350 | model = RegNet(regnetX_600M_config, **kwargs)
351 | return model
352 |
353 |
354 | def regnetx_800m(**kwargs):
355 | """
356 | Constructs a RegNet-X model under 800M FLOPs.
357 | """
358 | model = RegNet(regnetX_800M_config, **kwargs)
359 | return model
360 |
361 |
362 | def regnetx_1600m(**kwargs):
363 | """
364 | Constructs a RegNet-X model under 1600M FLOPs.
365 | """ 366 | model = RegNet(regnetX_1600M_config, **kwargs) 367 | return model 368 | 369 | 370 | def regnetx_3200m(**kwargs): 371 | """ 372 | Constructs a RegNet-X model under 3200M FLOPs. 373 | """ 374 | model = RegNet(regnetX_3200M_config, **kwargs) 375 | return model 376 | 377 | 378 | def regnetx_4000m(**kwargs): 379 | """ 380 | Constructs a RegNet-X model under 4000M FLOPs. 381 | """ 382 | model = RegNet(regnetX_4000M_config, **kwargs) 383 | return model 384 | 385 | 386 | def regnetx_6400m(**kwargs): 387 | """ 388 | Constructs a RegNet-X model under 6400M FLOPs. 389 | """ 390 | model = RegNet(regnetX_6400M_config, **kwargs) 391 | return model 392 | 393 | 394 | def regnety_200m(**kwargs): 395 | """ 396 | Constructs a RegNet-Y model under 200M FLOPs. 397 | """ 398 | model = RegNet(regnetY_200M_config, **kwargs) 399 | return model 400 | 401 | 402 | def regnety_400m(**kwargs): 403 | """ 404 | Constructs a RegNet-Y model under 400M FLOPs. 405 | """ 406 | model = RegNet(regnetY_400M_config, **kwargs) 407 | return model 408 | 409 | 410 | def regnety_600m(**kwargs): 411 | """ 412 | Constructs a RegNet-Y model under 600M FLOPs. 413 | """ 414 | model = RegNet(regnetY_600M_config, **kwargs) 415 | return model 416 | 417 | 418 | def regnety_800m(**kwargs): 419 | """ 420 | Constructs a RegNet-Y model under 800M FLOPs. 421 | """ 422 | model = RegNet(regnetY_800M_config, **kwargs) 423 | return model 424 | 425 | 426 | def regnety_1600m(**kwargs): 427 | """ 428 | Constructs a RegNet-Y model under 1600M FLOPs. 429 | """ 430 | model = RegNet(regnetY_1600M_config, **kwargs) 431 | return model 432 | 433 | 434 | def regnety_3200m(**kwargs): 435 | """ 436 | Constructs a RegNet-Y model under 3200M FLOPs. 437 | """ 438 | model = RegNet(regnetY_3200M_config, **kwargs) 439 | return model 440 | 441 | 442 | def regnety_4000m(**kwargs): 443 | """ 444 | Constructs a RegNet-Y model under 4000M FLOPs. 445 | """ 446 | model = RegNet(regnetY_4000M_config, **kwargs) 447 | return model 448 | 449 | 450 | def regnety_6400m(**kwargs): 451 | """ 452 | Constructs a RegNet-Y model under 6400M FLOPs. 
453 | """ 454 | model = RegNet(regnetY_6400M_config, **kwargs) 455 | return model 456 | 457 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=dilation, groups=groups, bias=False, dilation=dilation) 15 | 16 | 17 | def conv1x1(in_planes, out_planes, stride=1): 18 | """1x1 convolution""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | __constants__ = ['downsample'] 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 27 | base_width=64, dilation=1, norm_layer=None): 28 | super(BasicBlock, self).__init__() 29 | if norm_layer is None: 30 | norm_layer = BN 31 | if groups != 1 or base_width != 64: 32 | raise ValueError( 33 | 'BasicBlock only supports groups=1 and base_width=64') 34 | if dilation > 1: 35 | raise NotImplementedError( 36 | "Dilation > 1 not supported in BasicBlock") 37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = norm_layer(planes) 40 | self.relu1 = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = norm_layer(planes) 43 | self.downsample = downsample 44 | self.relu2 = nn.ReLU(inplace=True) 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | identity = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu1(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu2(out) 62 | 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | __constants__ = ['downsample'] 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, dilation=1, norm_layer=None): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = BN 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.relu1 = nn.ReLU(inplace=True) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.relu2 = nn.ReLU(inplace=True) 83 | self.conv3 = conv1x1(width, planes * self.expansion) 84 | self.bn3 = norm_layer(planes * self.expansion) 85 | self.relu3 = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu1(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu2(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += 
identity 107 | out = self.relu3(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, 115 | block, 116 | layers, 117 | num_classes=1000, 118 | zero_init_residual=False, 119 | groups=1, 120 | width_per_group=64, 121 | replace_stride_with_dilation=None, 122 | deep_stem=False, 123 | avg_down=False): 124 | 125 | super(ResNet, self).__init__() 126 | 127 | global BN 128 | 129 | BN = torch.nn.BatchNorm2d 130 | norm_layer = BN 131 | 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | self.deep_stem = deep_stem 137 | self.avg_down = avg_down 138 | 139 | if replace_stride_with_dilation is None: 140 | # each element in the tuple indicates if we should replace 141 | # the 2x2 stride with a dilated convolution instead 142 | replace_stride_with_dilation = [False, False, False] 143 | if len(replace_stride_with_dilation) != 3: 144 | raise ValueError("replace_stride_with_dilation should be None " 145 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 146 | self.groups = groups 147 | self.base_width = width_per_group 148 | 149 | if self.deep_stem: 150 | self.conv1 = nn.Sequential( 151 | nn.Conv2d(3, 32, kernel_size=3, stride=2, 152 | padding=1, bias=False), 153 | norm_layer(32), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(32, 32, kernel_size=3, stride=1, 156 | padding=1, bias=False), 157 | norm_layer(32), 158 | nn.ReLU(inplace=True), 159 | nn.Conv2d(32, 64, kernel_size=3, stride=1, 160 | padding=1, bias=False), 161 | ) 162 | else: 163 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 164 | stride=2, padding=3, bias=False) 165 | 166 | self.bn1 = norm_layer(self.inplanes) 167 | self.relu = nn.ReLU(inplace=True) 168 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 169 | self.layer1 = self._make_layer(block, 64, layers[0]) 170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 171 | dilate=replace_stride_with_dilation[0]) 172 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 173 | dilate=replace_stride_with_dilation[1]) 174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 175 | dilate=replace_stride_with_dilation[2]) 176 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 177 | self.fc = nn.Linear(512 * block.expansion, num_classes) 178 | 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | nn.init.kaiming_normal_( 182 | m.weight, mode='fan_out', nonlinearity='relu') 183 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 184 | nn.init.constant_(m.weight, 1) 185 | nn.init.constant_(m.bias, 0) 186 | 187 | # Zero-initialize the last BN in each residual branch, 188 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
189 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 190 | if zero_init_residual: 191 | for m in self.modules(): 192 | if isinstance(m, Bottleneck): 193 | nn.init.constant_(m.bn3.weight, 0) 194 | elif isinstance(m, BasicBlock): 195 | nn.init.constant_(m.bn2.weight, 0) 196 | 197 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 198 | norm_layer = self._norm_layer 199 | downsample = None 200 | previous_dilation = self.dilation 201 | if dilate: 202 | self.dilation *= stride 203 | stride = 1 204 | if stride != 1 or self.inplanes != planes * block.expansion: 205 | if self.avg_down: 206 | downsample = nn.Sequential( 207 | nn.AvgPool2d(stride, stride=stride, 208 | ceil_mode=True, count_include_pad=False), 209 | conv1x1(self.inplanes, planes * block.expansion), 210 | norm_layer(planes * block.expansion), 211 | ) 212 | else: 213 | downsample = nn.Sequential( 214 | conv1x1(self.inplanes, planes * block.expansion, stride), 215 | norm_layer(planes * block.expansion), 216 | ) 217 | 218 | layers = [] 219 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 220 | self.base_width, previous_dilation, norm_layer)) 221 | self.inplanes = planes * block.expansion 222 | for _ in range(1, blocks): 223 | layers.append(block(self.inplanes, planes, groups=self.groups, 224 | base_width=self.base_width, dilation=self.dilation, 225 | norm_layer=norm_layer)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def _forward_impl(self, x): 230 | # See note [TorchScript super()] 231 | x = self.conv1(x) 232 | x = self.bn1(x) 233 | x = self.relu(x) 234 | x = self.maxpool(x) 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | 240 | x = self.avgpool(x) 241 | x = torch.flatten(x, 1) 242 | x = self.fc(x) 243 | 244 | return x 245 | 246 | def forward(self, x): 247 | return self._forward_impl(x) 248 | 249 | 250 | def resnet18(**kwargs): 251 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 252 | return model 253 | 254 | 255 | def resnet34(**kwargs): 256 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 257 | return model 258 | 259 | 260 | def resnet50(**kwargs): 261 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 262 | return model 263 | 264 | 265 | def resnet101(**kwargs): 266 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 267 | return model 268 | 269 | 270 | def resnet152(**kwargs): 271 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 272 | return model 273 | 274 | 275 | def resnext50_32x4d(**kwargs): 276 | kwargs['groups'] = 32 277 | kwargs['width_per_group'] = 4 278 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 279 | return model 280 | 281 | 282 | def resnext101_32x8d(**kwargs): 283 | kwargs['groups'] = 32 284 | kwargs['width_per_group'] = 8 285 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 286 | return model 287 | 288 | 289 | def wide_resnet50_2(**kwargs): 290 | kwargs['width_per_group'] = 64 * 2 291 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 292 | return model 293 | 294 | 295 | def wide_resnet101_2(**kwargs): 296 | kwargs['width_per_group'] = 64 * 2 297 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 298 | return model 299 | -------------------------------------------------------------------------------- /quant/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_recon import block_reconstruction 2 | from .layer_recon import layer_reconstruction 3 | from .quant_block import 
BaseQuantBlock 4 | from .quant_layer import QuantModule 5 | from .quant_model import QuantModel 6 | from .set_weight_quantize_params import set_weight_quantize_params, get_init, save_quantized_weight 7 | from .set_act_quantize_params import set_act_quantize_params 8 | -------------------------------------------------------------------------------- /quant/adaptive_rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .quant_layer import UniformAffineQuantizer, round_ste 4 | 5 | 6 | class AdaRoundQuantizer(nn.Module): 7 | """ 8 | Adaptive Rounding Quantizer, used to optimize the rounding policy 9 | by reconstructing the intermediate output. 10 | Based on 11 | Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568 12 | 13 | :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer 14 | :param round_mode: controls the forward pass in this quantizer 15 | :param weight_tensor: initialize alpha 16 | """ 17 | 18 | def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_round_sigmoid'): 19 | super(AdaRoundQuantizer, self).__init__() 20 | # copying all attributes from UniformAffineQuantizer 21 | self.n_bits = uaq.n_bits 22 | self.sym = uaq.sym 23 | self.delta = uaq.delta 24 | self.zero_point = uaq.zero_point 25 | self.n_levels = uaq.n_levels 26 | 27 | self.round_mode = round_mode 28 | self.alpha = None 29 | self.soft_targets = False 30 | 31 | # params for sigmoid function 32 | self.gamma, self.zeta = -0.1, 1.1 33 | self.beta = 2/3 34 | self.init_alpha(x=weight_tensor.clone()) 35 | 36 | def forward(self, x): 37 | if self.round_mode == 'nearest': 38 | x_int = torch.round(x / self.delta) 39 | elif self.round_mode == 'nearest_ste': 40 | x_int = round_ste(x / self.delta) 41 | elif self.round_mode == 'stochastic': 42 | x_floor = torch.floor(x / self.delta) 43 | rest = (x / self.delta) - x_floor # rest of rounding 44 | x_int = x_floor + torch.bernoulli(rest) 45 | print('Draw stochastic sample') 46 | elif self.round_mode == 'learned_hard_sigmoid': 47 | x_floor = torch.floor(x / self.delta) 48 | if self.soft_targets: 49 | x_int = x_floor + self.get_soft_targets() 50 | else: 51 | x_int = x_floor + (self.alpha >= 0).float() 52 | else: 53 | raise ValueError('Wrong rounding mode') 54 | 55 | x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1) 56 | x_float_q = (x_quant - self.zero_point) * self.delta 57 | 58 | return x_float_q 59 | 60 | def get_soft_targets(self): 61 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) 62 | 63 | def init_alpha(self, x: torch.Tensor): 64 | x_floor = torch.floor(x / self.delta) 65 | if self.round_mode == 'learned_hard_sigmoid': 66 | print('Init alpha to be FP32') 67 | rest = (x / self.delta) - x_floor # rest of rounding [0, 1) 68 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest 69 | self.alpha = nn.Parameter(alpha) 70 | else: 71 | raise NotImplementedError 72 | 73 | @torch.jit.export 74 | def extra_repr(self): 75 | return 'bit={}'.format(self.n_bits) 76 | -------------------------------------------------------------------------------- /quant/block_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .quant_layer import QuantModule, lp_loss 4 | from .quant_model import 
QuantModel
5 | from .quant_block import BaseQuantBlock, specials_unquantized
6 | from .adaptive_rounding import AdaRoundQuantizer
7 | from .set_weight_quantize_params import get_init, get_dc_fp_init
8 | from .set_act_quantize_params import set_act_quantize_params
9 | 
10 | include = False
11 | def find_unquantized_module(model: torch.nn.Module, module_list: list = [], name_list: list = []):
12 |     """Store subsequent unquantized modules in a list"""
13 |     global include
14 |     for name, module in model.named_children():
15 |         if isinstance(module, (QuantModule, BaseQuantBlock)):
16 |             if not module.trained:
17 |                 include = True
18 |                 module.set_quant_state(False, False)
19 |                 name_list.append(name)
20 |                 module_list.append(module)
21 |         elif include and type(module) in specials_unquantized:
22 |             name_list.append(name)
23 |             module_list.append(module)
24 |         else:
25 |             find_unquantized_module(module, module_list, name_list)
26 |     return module_list[1:], name_list[1:]
27 | 
28 | def block_reconstruction(model: QuantModel, fp_model: QuantModel, block: BaseQuantBlock, fp_block: BaseQuantBlock,
29 |                          cali_data: torch.Tensor, batch_size: int = 32, iters: int = 20000, weight: float = 0.01,
30 |                          opt_mode: str = 'mse', b_range: tuple = (20, 2),
31 |                          warmup: float = 0.0, p: float = 2.0, lr: float = 4e-5,
32 |                          input_prob: float = 1.0, keep_gpu: bool = True,
33 |                          lamb_r: float = 0.2, T: float = 7.0, bn_lr: float = 1e-3, lamb_c=0.02):
34 |     """
35 |     Reconstruction to optimize the output from each block.
36 | 
37 |     :param model: QuantModel
38 |     :param block: BaseQuantBlock that needs to be optimized
39 |     :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
40 |     :param batch_size: mini-batch size for reconstruction
41 |     :param iters: optimization iterations for reconstruction
42 |     :param weight: weight of the rounding regularization term
43 |     :param opt_mode: optimization mode
44 |     :param input_prob: probability of keeping the quantized input (QDrop-style input mixing)
45 |     :param keep_gpu: keep the cached calibration data on GPU for faster optimization
46 |     :param b_range: temperature range
47 |     :param warmup: proportion of iterations during which the temperature is not scheduled
48 |     :param lr: learning rate for act delta learning
49 |     :param p: L_p norm minimization
50 |     :param lamb_r: hyper-parameter for regularization
51 |     :param T: temperature coefficient for KL divergence
52 |     :param bn_lr: learning rate for DC
53 |     :param lamb_c: hyper-parameter for DC
54 |     """
55 | 
56 |     '''get input and set scale'''
57 |     cached_inps = get_init(model, block, cali_data, batch_size=batch_size,
58 |                            input_prob=True, keep_gpu=keep_gpu)
59 |     cached_outs, cached_output, cur_syms = get_dc_fp_init(fp_model, fp_block, cali_data, batch_size=batch_size,
60 |                                                           input_prob=True, keep_gpu=keep_gpu, bn_lr=bn_lr, lamb=lamb_c)
61 |     set_act_quantize_params(block, cali_data=cached_inps[:min(256, cached_inps.size(0))])
62 | 
63 |     '''set state'''
64 |     cur_weight, cur_act = True, True
65 | 
66 |     global include
67 |     module_list, name_list, include = [], [], False
68 |     module_list, name_list = find_unquantized_module(model, module_list, name_list)
69 |     block.set_quant_state(cur_weight, cur_act)
70 |     for para in model.parameters():
71 |         para.requires_grad = False
72 | 
73 |     '''set quantizer'''
74 |     round_mode = 'learned_hard_sigmoid'
75 |     # Replace the weight quantizer with AdaRoundQuantizer
76 |     w_para, a_para = [], []
77 |     w_opt, a_opt = None, None
78 |     scheduler, a_scheduler = None, None
79 | 
80 |     for module in block.modules():
81 | 
'''weight''' 82 | if isinstance(module, QuantModule): 83 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode, 84 | weight_tensor=module.org_weight.data) 85 | module.weight_quantizer.soft_targets = True 86 | w_para += [module.weight_quantizer.alpha] 87 | '''activation''' 88 | if isinstance(module, (QuantModule, BaseQuantBlock)): 89 | if module.act_quantizer.delta is not None: 90 | module.act_quantizer.delta = torch.nn.Parameter(torch.tensor(module.act_quantizer.delta)) 91 | a_para += [module.act_quantizer.delta] 92 | '''set up drop''' 93 | module.act_quantizer.is_training = True 94 | 95 | if len(w_para) != 0: 96 | w_opt = torch.optim.Adam(w_para, lr=3e-3) 97 | if len(a_para) != 0: 98 | a_opt = torch.optim.Adam(a_para, lr=lr) 99 | a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=iters, eta_min=0.) 100 | 101 | loss_mode = 'relaxation' 102 | rec_loss = opt_mode 103 | loss_func = LossFunction(block, round_loss=loss_mode, weight=weight, max_count=iters, rec_loss=rec_loss, 104 | b_range=b_range, decay_start=0, warmup=warmup, p=p, lam=lamb_r, T=T) 105 | device = 'cuda' 106 | sz = cached_inps.size(0) 107 | for i in range(iters): 108 | idx = torch.randint(0, sz, (batch_size,)) 109 | cur_inp = cached_inps[idx].to(device) 110 | cur_sym = cur_syms[idx].to(device) 111 | output_fp = cached_output[idx].to(device) 112 | cur_out = cached_outs[idx].to(device) 113 | if input_prob < 1.0: 114 | drop_inp = torch.where(torch.rand_like(cur_inp) < input_prob, cur_inp, cur_sym) 115 | 116 | cur_inp = torch.cat((drop_inp, cur_inp)) 117 | 118 | if w_opt: 119 | w_opt.zero_grad() 120 | if a_opt: 121 | a_opt.zero_grad() 122 | 123 | out_all = block(cur_inp) 124 | 125 | '''forward for prediction difference''' 126 | out_drop = out_all[:batch_size] 127 | out_quant = out_all[batch_size:] 128 | output = out_quant 129 | for num, module in enumerate(module_list): 130 | # for ResNet and RegNet 131 | if name_list[num] == 'fc': 132 | output = torch.flatten(output, 1) 133 | # for MobileNet and MNasNet 134 | if isinstance(module, torch.nn.Dropout): 135 | output = output.mean([2, 3]) 136 | output = module(output) 137 | err = loss_func(out_drop, cur_out, output, output_fp) 138 | 139 | err.backward(retain_graph=True) 140 | if w_opt: 141 | w_opt.step() 142 | if a_opt: 143 | a_opt.step() 144 | if scheduler: 145 | scheduler.step() 146 | if a_scheduler: 147 | a_scheduler.step() 148 | torch.cuda.empty_cache() 149 | 150 | for module in block.modules(): 151 | if isinstance(module, QuantModule): 152 | '''weight ''' 153 | module.weight_quantizer.soft_targets = False 154 | '''activation''' 155 | if isinstance(module, (QuantModule, BaseQuantBlock)): 156 | module.act_quantizer.is_training = False 157 | module.trained = True 158 | for module in fp_block.modules(): 159 | if isinstance(module, (QuantModule, BaseQuantBlock)): 160 | module.trained = True 161 | 162 | class LossFunction: 163 | def __init__(self, 164 | block: BaseQuantBlock, 165 | round_loss: str = 'relaxation', 166 | weight: float = 1., 167 | rec_loss: str = 'mse', 168 | max_count: int = 2000, 169 | b_range: tuple = (10, 2), 170 | decay_start: float = 0.0, 171 | warmup: float = 0.0, 172 | p: float = 2., 173 | lam: float = 1.0, 174 | T: float = 7.0): 175 | 176 | self.block = block 177 | self.round_loss = round_loss 178 | self.weight = weight 179 | self.rec_loss = rec_loss 180 | self.loss_start = max_count * warmup 181 | self.p = p 182 | self.lam = lam 183 | self.T = T 184 | 185 | self.temp_decay = LinearTempDecay(max_count, 
rel_start_decay=warmup + (1 - warmup) * decay_start,
186 |                                           start_b=b_range[0], end_b=b_range[1])
187 |         self.count = 0
188 |         self.pd_loss = torch.nn.KLDivLoss(reduction='batchmean')
189 | 
190 |     def __call__(self, pred, tgt, output, output_fp):
191 |         """
192 |         Compute the total loss for adaptive rounding:
193 |         rec_loss is the quadratic output reconstruction loss, round_loss is
194 |         a regularization term to optimize the rounding policy, pd_loss is the
195 |         prediction difference loss.
196 | 
197 |         :param pred: output from quantized model
198 |         :param tgt: output from FP model
199 |         :param output: prediction from quantized model
200 |         :param output_fp: prediction from FP model
201 |         :return: the total loss
202 |         """
203 |         self.count += 1
204 |         if self.rec_loss == 'mse':
205 |             rec_loss = lp_loss(pred, tgt, p=self.p)
206 |         else:
207 |             raise ValueError('Unsupported reconstruction loss function: {}'.format(self.rec_loss))
208 | 
209 |         pd_loss = self.pd_loss(F.log_softmax(output / self.T, dim=1), F.softmax(output_fp / self.T, dim=1)) / self.lam
210 | 
211 |         b = self.temp_decay(self.count)
212 |         if self.count < self.loss_start or self.round_loss == 'none':
213 |             b = round_loss = 0
214 |         elif self.round_loss == 'relaxation':
215 |             round_loss = 0
216 |             for name, module in self.block.named_modules():
217 |                 if isinstance(module, QuantModule):
218 |                     round_vals = module.weight_quantizer.get_soft_targets()
219 |                     round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
220 |         else:
221 |             raise NotImplementedError
222 | 
223 |         total_loss = rec_loss + round_loss + pd_loss
224 |         if self.count % 500 == 0:
225 |             print('Total loss:\t{:.3f} (rec:{:.3f}, pd:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
226 |                 float(total_loss), float(rec_loss), float(pd_loss), float(round_loss), b, self.count))
227 |         return total_loss
228 | 
229 | 
230 | class LinearTempDecay:
231 |     def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
232 |         self.t_max = t_max
233 |         self.start_decay = rel_start_decay * t_max
234 |         self.start_b = start_b
235 |         self.end_b = end_b
236 | 
237 |     def __call__(self, t):
238 |         """
239 |         Linear decay scheduler for temperature b.
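        A quick worked example of the schedule (the values assume the arguments
        this class receives from block reconstruction: t_max=20000,
        rel_start_decay=0, start_b=20, end_b=2):

            decay = LinearTempDecay(20000, rel_start_decay=0.0, start_b=20, end_b=2)
            decay(0)      # 20.0 (b starts at start_b)
            decay(10000)  # 11.0 (halfway: 2 + 18 * 0.5)
            decay(20000)  # 2.0  (fully decayed to end_b)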
240 |         :param t: the current time step
241 |         :return: scheduled temperature
242 |         """
243 |         if t < self.start_decay:
244 |             return self.start_b
245 |         else:
246 |             rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
247 |             # return self.end_b + 0.5 * (self.start_b - self.end_b) * (1 + np.cos(rel_t * np.pi))
248 |             return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
249 | 
--------------------------------------------------------------------------------
/quant/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | from .quant_layer import QuantModule, Union, lp_loss
6 | from .quant_model import QuantModel
7 | from .quant_block import BaseQuantBlock
8 | from tqdm import trange
9 | 
10 | 
11 | def save_dc_fp_data(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], cali_data: torch.Tensor,
12 |                     batch_size: int = 32, keep_gpu: bool = True,
13 |                     input_prob: bool = False, lamb=50, bn_lr=1e-3):
14 |     """Activation after correction"""
15 |     device = next(model.parameters()).device
16 |     get_inp_out = GetDcFpLayerInpOut(model, layer, device=device, input_prob=input_prob, lamb=lamb, bn_lr=bn_lr)
17 |     cached_batches = []
18 | 
19 |     print("Start correcting {} batches of data!".format(int(cali_data.size(0) / batch_size)))
20 |     for i in trange(int(cali_data.size(0) / batch_size)):
21 |         if input_prob:
22 |             cur_out, out_fp, cur_sym = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size])
23 |             cached_batches.append((cur_out.cpu(), out_fp.cpu(), cur_sym.cpu()))
24 |         else:
25 |             cur_out, out_fp = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size])
26 |             cached_batches.append((cur_out.cpu(), out_fp.cpu()))
27 |     cached_outs = torch.cat([x[0] for x in cached_batches])
28 |     cached_outputs = torch.cat([x[1] for x in cached_batches])
29 |     if input_prob:
30 |         cached_sym = torch.cat([x[2] for x in cached_batches])
31 |     torch.cuda.empty_cache()
32 |     if keep_gpu:
33 |         cached_outs = cached_outs.to(device)
34 |         cached_outputs = cached_outputs.to(device)
35 |         if input_prob:
36 |             cached_sym = cached_sym.to(device)
37 |     if input_prob:
38 |         cached_outs.requires_grad = False
39 |         cached_sym.requires_grad = False
40 |         return cached_outs, cached_outputs, cached_sym
41 |     return cached_outs, cached_outputs
42 | 
43 | 
44 | def save_inp_oup_data(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], cali_data: torch.Tensor,
45 |                       batch_size: int = 32, keep_gpu: bool = True,
46 |                       input_prob: bool = False):
47 |     """
48 |     Save the input data of a particular layer/block over the calibration dataset.
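    A minimal usage sketch (``qmodel`` and ``cali_data`` are placeholder names
    for a QuantModel and a tensor of calibration images):

        cached_inps = save_inp_oup_data(qmodel, qmodel.model.layer1[0], cali_data,
                                        batch_size=32, keep_gpu=True)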
49 | 
50 |     :param model: QuantModel
51 |     :param layer: QuantModule or QuantBlock
52 |     :param cali_data: calibration data set
53 |     :param input_prob: kept for interface symmetry with save_dc_fp_data;
54 |         it does not change what is cached here
55 |     :param batch_size: mini-batch size for calibration
56 |     :param keep_gpu: put saved data on GPU for faster optimization
57 |     :return: the cached input data
58 |     """
59 |     device = next(model.parameters()).device
60 |     get_inp_out = GetLayerInpOut(model, layer, device=device, input_prob=input_prob)
61 |     cached_batches = []
62 | 
63 |     for i in range(int(cali_data.size(0) / batch_size)):
64 |         cur_inp = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size])
65 |         cached_batches.append(cur_inp.cpu())
66 |     cached_inps = torch.cat([x for x in cached_batches])
67 |     torch.cuda.empty_cache()
68 |     if keep_gpu:
69 |         cached_inps = cached_inps.to(device)
70 | 
71 |     return cached_inps
72 | 
73 | 
74 | class StopForwardException(Exception):
75 |     """
76 |     Used to throw and catch an exception to stop traversing the graph
77 |     """
78 |     pass
79 | 
80 | 
81 | class DataSaverHook:
82 |     """
83 |     Forward hook that stores the input and output of a block
84 |     """
85 | 
86 |     def __init__(self, store_input=False, store_output=False, stop_forward=False):
87 |         self.store_input = store_input
88 |         self.store_output = store_output
89 |         self.stop_forward = stop_forward
90 | 
91 |         self.input_store = None
92 |         self.output_store = None
93 | 
94 |     def __call__(self, module, input_batch, output_batch):
95 |         if self.store_input:
96 |             self.input_store = input_batch
97 |         if self.store_output:
98 |             self.output_store = output_batch
99 |         if self.stop_forward:
100 |             raise StopForwardException
101 | 
102 | 
103 | class input_hook(object):
104 |     """
105 |     Forward hook used to capture the input of the module it is attached to (here the input of each BatchNorm layer, i.e. the preceding layer's output).
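    A registration sketch (``bn`` stands for one of the nn.BatchNorm2d modules
    inside the layer being corrected, ``x`` for its input batch):

        hook = input_hook()
        handle = bn.register_forward_hook(hook.hook)
        _ = layer(x)      # afterwards hook.inputs[0] is the tensor fed into ``bn``
        handle.remove()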
106 | """ 107 | def __init__(self, stop_forward=False): 108 | super(input_hook, self).__init__() 109 | self.inputs = None 110 | 111 | def hook(self, module, input, output): 112 | self.inputs = input 113 | 114 | def clear(self): 115 | self.inputs = None 116 | 117 | class GetLayerInpOut: 118 | def __init__(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], 119 | device: torch.device, input_prob: bool = False): 120 | self.model = model 121 | self.layer = layer 122 | self.device = device 123 | self.data_saver = DataSaverHook(store_input=True, store_output=False, stop_forward=True) 124 | self.input_prob = input_prob 125 | 126 | def __call__(self, model_input): 127 | 128 | handle = self.layer.register_forward_hook(self.data_saver) 129 | with torch.no_grad(): 130 | self.model.set_quant_state(weight_quant=True, act_quant=True) 131 | try: 132 | _ = self.model(model_input.to(self.device)) 133 | except StopForwardException: 134 | pass 135 | 136 | handle.remove() 137 | 138 | return self.data_saver.input_store[0].detach() 139 | 140 | class GetDcFpLayerInpOut: 141 | def __init__(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], 142 | device: torch.device, input_prob: bool = False, lamb=50, bn_lr=1e-3): 143 | self.model = model 144 | self.layer = layer 145 | self.device = device 146 | self.data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=False) 147 | self.input_prob = input_prob 148 | self.bn_stats = [] 149 | self.eps = 1e-6 150 | self.lamb=lamb 151 | self.bn_lr=bn_lr 152 | for n, m in self.layer.named_modules(): 153 | if isinstance(m, nn.BatchNorm2d): 154 | # get the statistics in the BatchNorm layers 155 | self.bn_stats.append( 156 | (m.running_mean.detach().clone().flatten().cuda(), 157 | torch.sqrt(m.running_var + 158 | self.eps).detach().clone().flatten().cuda())) 159 | 160 | def own_loss(self, A, B): 161 | return (A - B).norm()**2 / B.size(0) 162 | 163 | def relative_loss(self, A, B): 164 | return (A-B).abs().mean()/A.abs().mean() 165 | 166 | def __call__(self, model_input): 167 | self.model.set_quant_state(False, False) 168 | handle = self.layer.register_forward_hook(self.data_saver) 169 | hooks = [] 170 | hook_handles = [] 171 | for name, module in self.layer.named_modules(): 172 | if isinstance(module, nn.BatchNorm2d): 173 | hook = input_hook() 174 | hooks.append(hook) 175 | hook_handles.append(module.register_forward_hook(hook.hook)) 176 | assert len(hooks) == len(self.bn_stats) 177 | 178 | with torch.no_grad(): 179 | try: 180 | output_fp = self.model(model_input.to(self.device)) 181 | except StopForwardException: 182 | pass 183 | if self.input_prob: 184 | input_sym = self.data_saver.input_store[0].detach() 185 | 186 | handle.remove() 187 | para_input = input_sym.data.clone() 188 | para_input = para_input.to(self.device) 189 | para_input.requires_grad = True 190 | optimizer = optim.Adam([para_input], lr=self.bn_lr) 191 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 192 | min_lr=1e-5, 193 | verbose=False, 194 | patience=100) 195 | iters=500 196 | for iter in range(iters): 197 | self.layer.zero_grad() 198 | optimizer.zero_grad() 199 | for hook in hooks: 200 | hook.clear() 201 | _ = self.layer(para_input) 202 | mean_loss = 0 203 | std_loss = 0 204 | for num, (bn_stat, hook) in enumerate(zip(self.bn_stats, hooks)): 205 | tmp_input = hook.inputs[0] 206 | bn_mean, bn_std = bn_stat[0], bn_stat[1] 207 | tmp_mean = torch.mean(tmp_input.view(tmp_input.size(0), 208 | tmp_input.size(1), -1), 209 | dim=2) 210 | tmp_std = 
torch.sqrt(
211 |                 torch.var(tmp_input.view(tmp_input.size(0),
212 |                                           tmp_input.size(1), -1),
213 |                           dim=2) + self.eps)
214 |                 mean_loss += self.own_loss(bn_mean, tmp_mean)
215 |                 std_loss += self.own_loss(bn_std, tmp_std)
216 |             constraint_loss = lp_loss(para_input, input_sym) / self.lamb
217 |             total_loss = mean_loss + std_loss + constraint_loss
218 |             total_loss.backward()
219 |             optimizer.step()
220 |             scheduler.step(total_loss.item())
221 |             # if (iter+1) % 500 == 0:
222 |             #     print('Total loss:\t{:.3f} (mse:{:.3f}, mean:{:.3f}, std:{:.3f})\tcount={}'.format(
223 |             #         float(total_loss), float(constraint_loss), float(mean_loss), float(std_loss), iter))
224 | 
225 |         with torch.no_grad():
226 |             out_fp = self.layer(para_input)
227 | 
228 |         if self.input_prob:
229 |             return out_fp.detach(), output_fp.detach(), para_input.detach()
230 |         return out_fp.detach(), output_fp.detach()
--------------------------------------------------------------------------------
/quant/fold_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | 
5 | 
6 | class StraightThrough(nn.Module):
7 |     def __init__(self):
8 |         super().__init__()
9 | 
10 |     def forward(self, input):
11 |         return input
12 | 
13 | 
14 | def _fold_bn(conv_module, bn_module):
15 |     w = conv_module.weight.data
16 |     y_mean = bn_module.running_mean
17 |     y_var = bn_module.running_var
18 |     safe_std = torch.sqrt(y_var + bn_module.eps)
19 |     w_view = (conv_module.out_channels, 1, 1, 1)
20 |     if bn_module.affine:
21 |         weight = w * (bn_module.weight / safe_std).view(w_view)
22 |         beta = bn_module.bias - bn_module.weight * y_mean / safe_std
23 |         if conv_module.bias is not None:
24 |             bias = bn_module.weight * conv_module.bias / safe_std + beta
25 |         else:
26 |             bias = beta
27 |     else:
28 |         weight = w / safe_std.view(w_view)
29 |         beta = -y_mean / safe_std
30 |         if conv_module.bias is not None:
31 |             bias = conv_module.bias / safe_std + beta
32 |         else:
33 |             bias = beta
34 |     return weight, bias
35 | 
36 | 
37 | def fold_bn_into_conv(conv_module, bn_module):
38 |     w, b = _fold_bn(conv_module, bn_module)
39 |     if conv_module.bias is None:
40 |         conv_module.bias = nn.Parameter(b)
41 |     else:
42 |         conv_module.bias.data = b
43 |     conv_module.weight.data = w
44 |     # set bn running stats
45 |     bn_module.running_mean = bn_module.bias.data
46 |     bn_module.running_var = bn_module.weight.data ** 2
47 | 
48 | 
49 | def reset_bn(module: nn.BatchNorm2d):
50 |     if module.track_running_stats:
51 |         module.running_mean.zero_()
52 |         module.running_var.fill_(1-module.eps)
53 |         # we do not reset the number of tracked batches here
54 |         # self.num_batches_tracked.zero_()
55 |     if module.affine:
56 |         init.ones_(module.weight)
57 |         init.zeros_(module.bias)
58 | 
59 | 
60 | def is_bn(m):
61 |     return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
62 | 
63 | 
64 | def is_absorbing(m):
65 |     return (isinstance(m, nn.Conv2d)) or isinstance(m, nn.Linear)
66 | 
67 | 
68 | def search_fold_and_remove_bn(model):
69 |     model.eval()
70 |     prev = None
71 |     for n, m in model.named_children():
72 |         if is_bn(m) and is_absorbing(prev):
73 |             fold_bn_into_conv(prev, m)
74 |             # set the bn module to straight through
75 |             setattr(model, n, StraightThrough())
76 |         elif is_absorbing(m):
77 |             prev = m
78 |         else:
79 |             prev = search_fold_and_remove_bn(m)
80 |     return prev
81 | 
82 | 
83 | def search_fold_and_reset_bn(model):
84 |     model.eval()
85 |     prev = None
86 |     for n, m in model.named_children():
87 |         if is_bn(m) and is_absorbing(prev):
88 | 
fold_bn_into_conv(prev, m) 89 | # reset_bn(m) 90 | else: 91 | search_fold_and_reset_bn(m) 92 | prev = m 93 | 94 | -------------------------------------------------------------------------------- /quant/layer_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .quant_layer import QuantModule, lp_loss 4 | from .quant_model import QuantModel 5 | from .block_recon import LinearTempDecay 6 | from .adaptive_rounding import AdaRoundQuantizer 7 | from .set_weight_quantize_params import get_init, get_dc_fp_init 8 | from .set_act_quantize_params import set_act_quantize_params 9 | from .quant_block import BaseQuantBlock, specials_unquantized 10 | 11 | include = False 12 | def find_unquantized_module(model: torch.nn.Module, module_list: list = [], name_list: list = []): 13 | """Store subsequent unquantized modules in a list""" 14 | global include 15 | for name, module in model.named_children(): 16 | if isinstance(module, (QuantModule, BaseQuantBlock)): 17 | if not module.trained: 18 | include = True 19 | module.set_quant_state(False,False) 20 | name_list.append(name) 21 | module_list.append(module) 22 | elif include and type(module) in specials_unquantized: 23 | name_list.append(name) 24 | module_list.append(module) 25 | else: 26 | find_unquantized_module(module, module_list, name_list) 27 | return module_list[1:], name_list[1:] 28 | 29 | def layer_reconstruction(model: QuantModel, fp_model: QuantModel, layer: QuantModule, fp_layer: QuantModule, 30 | cali_data: torch.Tensor,batch_size: int = 32, iters: int = 20000, weight: float = 0.001, 31 | opt_mode: str = 'mse', b_range: tuple = (20, 2), 32 | warmup: float = 0.0, p: float = 2.0, lr: float = 4e-5, input_prob: float = 1.0, 33 | keep_gpu: bool = True, lamb_r: float = 0.2, T: float = 7.0, bn_lr: float = 1e-3, lamb_c=0.02): 34 | """ 35 | Reconstruction to optimize the output from each layer. 
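    A call sketch (``qnn``/``fp_model`` stand for the quantized and FP QuantModel,
    ``layer``/``fp_layer`` for the matching QuantModule pair; all four names are
    placeholders):

        layer_reconstruction(qnn, fp_model, layer, fp_layer, cali_data,
                             batch_size=32, iters=20000, weight=0.001,
                             input_prob=0.5, lamb_r=0.2, T=7.0)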
36 | 
37 |     :param model: QuantModel
38 |     :param layer: QuantModule that needs to be optimized
39 |     :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
40 |     :param batch_size: mini-batch size for reconstruction
41 |     :param iters: optimization iterations for reconstruction
42 |     :param weight: weight of the rounding regularization term
43 |     :param opt_mode: optimization mode
44 |     :param input_prob: probability of keeping the quantized input (QDrop-style input mixing)
45 |     :param keep_gpu: keep the cached calibration data on GPU for faster optimization
46 |     :param b_range: temperature range
47 |     :param warmup: proportion of iterations during which the temperature is not scheduled
48 |     :param lr: learning rate for act delta learning
49 |     :param p: L_p norm minimization
50 |     :param lamb_r: hyper-parameter for regularization
51 |     :param T: temperature coefficient for KL divergence
52 |     :param bn_lr: learning rate for DC
53 |     :param lamb_c: hyper-parameter for DC
54 |     """
55 | 
56 |     '''get input and set scale'''
57 |     cached_inps = get_init(model, layer, cali_data, batch_size=batch_size,
58 |                            input_prob=True, keep_gpu=keep_gpu)
59 |     cached_outs, cached_output, cur_syms = get_dc_fp_init(fp_model, fp_layer, cali_data, batch_size=batch_size,
60 |                                                           input_prob=True, keep_gpu=keep_gpu, bn_lr=bn_lr, lamb=lamb_c)
61 |     set_act_quantize_params(layer, cali_data=cached_inps[:min(256, cached_inps.size(0))])
62 | 
63 |     '''set state'''
64 |     cur_weight, cur_act = True, True
65 | 
66 |     global include
67 |     module_list, name_list, include = [], [], False
68 |     module_list, name_list = find_unquantized_module(model, module_list, name_list)
69 |     layer.set_quant_state(cur_weight, cur_act)
70 |     for para in model.parameters():
71 |         para.requires_grad = False
72 | 
73 |     '''set quantizer'''
74 |     round_mode = 'learned_hard_sigmoid'
75 |     # Replace the weight quantizer with AdaRoundQuantizer
76 |     w_para, a_para = [], []
77 |     w_opt, a_opt = None, None
78 |     scheduler, a_scheduler = None, None
79 | 
80 |     '''weight'''
81 |     layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode,
82 |                                                weight_tensor=layer.org_weight.data)
83 |     layer.weight_quantizer.soft_targets = True
84 |     w_para += [layer.weight_quantizer.alpha]
85 | 
86 |     '''activation'''
87 |     if layer.act_quantizer.delta is not None:
88 |         layer.act_quantizer.delta = torch.nn.Parameter(torch.tensor(layer.act_quantizer.delta))
89 |         a_para += [layer.act_quantizer.delta]
90 |     '''set up drop'''
91 |     layer.act_quantizer.is_training = True
92 | 
93 |     if len(w_para) != 0:
94 |         w_opt = torch.optim.Adam(w_para, lr=3e-3)
95 |     if len(a_para) != 0:
96 |         a_opt = torch.optim.Adam(a_para, lr=lr)
97 |         a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=iters, eta_min=0.)
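    # The loop below jointly optimizes the AdaRound alphas (w_opt) and the
    # activation step size (a_opt): every step draws a random calibration
    # mini-batch, optionally mixes the DC-corrected input in (QDrop-style,
    # only when input_prob < 1.0), and backpropagates the reconstruction +
    # rounding + prediction-difference loss through the unquantized tail.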
98 | 99 | loss_mode = 'relaxation' 100 | rec_loss = opt_mode 101 | loss_func = LossFunction(layer, round_loss=loss_mode, weight=weight, 102 | max_count=iters, rec_loss=rec_loss, b_range=b_range, 103 | decay_start=0, warmup=warmup, p=p, lam=lamb_r, T=T) 104 | device = 'cuda' 105 | sz = cached_inps.size(0) 106 | for i in range(iters): 107 | idx = torch.randint(0, sz, (batch_size,)) 108 | cur_inp = cached_inps[idx].to(device) 109 | cur_sym = cur_syms[idx].to(device) 110 | output_fp = cached_output[idx].to(device) 111 | cur_out = cached_outs[idx].to(device) 112 | if input_prob < 1.0: 113 | drop_inp = torch.where(torch.rand_like(cur_inp) < input_prob, cur_inp, cur_sym) 114 | 115 | cur_inp = torch.cat((drop_inp, cur_inp)) 116 | 117 | if w_opt: 118 | w_opt.zero_grad() 119 | if a_opt: 120 | a_opt.zero_grad() 121 | out_all = layer(cur_inp) 122 | 123 | '''forward for prediction difference''' 124 | out_drop = out_all[:batch_size] 125 | out_quant = out_all[batch_size:] 126 | output = out_quant 127 | for num, module in enumerate(module_list): 128 | # for ResNet and RegNet 129 | if name_list[num] == 'fc': 130 | output = torch.flatten(output, 1) 131 | # for MobileNet and MNasNet 132 | if isinstance(module, torch.nn.Dropout): 133 | output = output.mean([2, 3]) 134 | output = module(output) 135 | err = loss_func(out_drop, cur_out, output, output_fp) 136 | 137 | err.backward(retain_graph=True) 138 | if w_opt: 139 | w_opt.step() 140 | if a_opt: 141 | a_opt.step() 142 | if scheduler: 143 | scheduler.step() 144 | if a_scheduler: 145 | a_scheduler.step() 146 | torch.cuda.empty_cache() 147 | 148 | layer.weight_quantizer.soft_targets = False 149 | layer.act_quantizer.is_training = False 150 | layer.trained = True 151 | 152 | 153 | class LossFunction: 154 | def __init__(self, 155 | layer: QuantModule, 156 | round_loss: str = 'relaxation', 157 | weight: float = 1., 158 | rec_loss: str = 'mse', 159 | max_count: int = 2000, 160 | b_range: tuple = (10, 2), 161 | decay_start: float = 0.0, 162 | warmup: float = 0.0, 163 | p: float = 2., 164 | lam: float = 1.0, 165 | T: float = 7.0): 166 | 167 | self.layer = layer 168 | self.round_loss = round_loss 169 | self.weight = weight 170 | self.rec_loss = rec_loss 171 | self.loss_start = max_count * warmup 172 | self.p = p 173 | self.lam = lam 174 | self.T = T 175 | 176 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start, 177 | start_b=b_range[0], end_b=b_range[1]) 178 | self.count = 0 179 | self.pd_loss = torch.nn.KLDivLoss(reduction='batchmean') 180 | 181 | def __call__(self, pred, tgt, output, output_fp): 182 | """ 183 | Compute the total loss for adaptive rounding: 184 | rec_loss is the quadratic output reconstruction loss, round_loss is 185 | a regularization term to optimize the rounding policy, pd_loss is the 186 | prediction difference loss. 
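        In symbols, at step t the returned value is approximately

            total = lp_loss(pred, tgt, p)                                   # output reconstruction
                  + weight * sum(1 - |2*h(alpha) - 1| ** b_t)               # AdaRound relaxation, b_t from LinearTempDecay
                  + KL(softmax(output_fp / T) || softmax(output / T)) / lam # prediction difference

        where h(alpha) is the rectified sigmoid used by get_soft_targets().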
187 | 
188 |         :param pred: output from quantized model
189 |         :param tgt: output from FP model
190 |         :param output: prediction from quantized model
191 |         :param output_fp: prediction from FP model
192 |         :return: the total loss
193 |         """
194 |         self.count += 1
195 |         if self.rec_loss == 'mse':
196 |             rec_loss = lp_loss(pred, tgt, p=self.p)
197 |         else:
198 |             raise ValueError('Unsupported reconstruction loss function: {}'.format(self.rec_loss))
199 | 
200 |         pd_loss = self.pd_loss(F.log_softmax(output / self.T, dim=1), F.softmax(output_fp / self.T, dim=1)) / self.lam
201 | 
202 |         b = self.temp_decay(self.count)
203 |         if self.count < self.loss_start or self.round_loss == 'none':
204 |             b = round_loss = 0
205 |         elif self.round_loss == 'relaxation':
206 |             round_loss = 0
207 |             round_vals = self.layer.weight_quantizer.get_soft_targets()
208 |             round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
209 |         else:
210 |             raise NotImplementedError
211 |         total_loss = rec_loss + round_loss + pd_loss
212 |         if self.count % 500 == 0:
213 |             print('Total loss:\t{:.3f} (rec:{:.3f}, pd:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
214 |                 float(total_loss), float(rec_loss), float(pd_loss), float(round_loss), b, self.count))
215 |         return total_loss
216 | 
--------------------------------------------------------------------------------
/quant/quant_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .quant_layer import QuantModule, UniformAffineQuantizer
3 | from models.resnet import BasicBlock, Bottleneck
4 | from models.regnet import ResBottleneckBlock
5 | from models.mobilenetv2 import InvertedResidual
6 | from models.mnasnet import _InvertedResidual
7 | 
8 | 
9 | class BaseQuantBlock(nn.Module):
10 |     """
11 |     Base implementation of block structures for all networks.
12 |     Because of the branching architecture, the activation function and its
13 |     quantization must be applied after the element-wise add operation, so
14 |     that logic lives in this class.
15 |     """
16 |     def __init__(self):
17 |         super().__init__()
18 |         self.use_weight_quant = False
19 |         self.use_act_quant = False
20 |         self.ignore_reconstruction = False
21 |         self.trained = False
22 | 
23 |     def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
24 |         # setting weight quantization here does not affect actual forward pass
25 |         self.use_weight_quant = weight_quant
26 |         self.use_act_quant = act_quant
27 |         for m in self.modules():
28 |             if isinstance(m, QuantModule):
29 |                 m.set_quant_state(weight_quant, act_quant)
30 | 
31 | 
32 | class QuantBasicBlock(BaseQuantBlock):
33 |     """
34 |     Implementation of Quantized BasicBlock used in ResNet-18 and ResNet-34.
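    A construction sketch (``fp_block`` stands for a models.resnet.BasicBlock
    taken from the FP network; the quant-param dicts are illustrative):

        q_block = QuantBasicBlock(fp_block,
                                  weight_quant_params={'n_bits': 4, 'channel_wise': True},
                                  act_quant_params={'n_bits': 4, 'leaf_param': True})
        q_block.set_quant_state(weight_quant=True, act_quant=True)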
35 |     """
36 |     def __init__(self, basic_block: BasicBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}):
37 |         super().__init__()
38 |         self.conv1 = QuantModule(basic_block.conv1, weight_quant_params, act_quant_params)
39 |         self.conv1.norm_function = basic_block.bn1
40 |         self.conv1.activation_function = basic_block.relu1
41 |         self.conv2 = QuantModule(basic_block.conv2, weight_quant_params, act_quant_params, disable_act_quant=True)
42 |         self.conv2.norm_function = basic_block.bn2
43 | 
44 |         if basic_block.downsample is None:
45 |             self.downsample = None
46 |         else:
47 |             self.downsample = QuantModule(basic_block.downsample[0], weight_quant_params, act_quant_params,
48 |                                           disable_act_quant=True)
49 |             self.downsample.norm_function = basic_block.downsample[1]
50 |         self.activation_function = basic_block.relu2
51 |         self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
52 | 
53 |     def forward(self, x):
54 |         residual = x if self.downsample is None else self.downsample(x)
55 |         out = self.conv1(x)
56 |         out = self.conv2(out)
57 |         out += residual
58 |         out = self.activation_function(out)
59 |         if self.use_act_quant:
60 |             out = self.act_quantizer(out)
61 |         return out
62 | 
63 | 
64 | class QuantBottleneck(BaseQuantBlock):
65 |     """
66 |     Implementation of Quantized Bottleneck Block used in ResNet-50, -101 and -152.
67 |     """
68 | 
69 |     def __init__(self, bottleneck: Bottleneck, weight_quant_params: dict = {}, act_quant_params: dict = {}):
70 |         super().__init__()
71 |         self.conv1 = QuantModule(bottleneck.conv1, weight_quant_params, act_quant_params)
72 |         self.conv1.norm_function = bottleneck.bn1
73 |         self.conv1.activation_function = bottleneck.relu1
74 |         self.conv2 = QuantModule(bottleneck.conv2, weight_quant_params, act_quant_params)
75 |         self.conv2.norm_function = bottleneck.bn2
76 |         self.conv2.activation_function = bottleneck.relu2
77 |         self.conv3 = QuantModule(bottleneck.conv3, weight_quant_params, act_quant_params, disable_act_quant=True)
78 |         self.conv3.norm_function = bottleneck.bn3
79 | 
80 |         if bottleneck.downsample is None:
81 |             self.downsample = None
82 |         else:
83 |             self.downsample = QuantModule(bottleneck.downsample[0], weight_quant_params, act_quant_params,
84 |                                           disable_act_quant=True)
85 |             self.downsample.norm_function = bottleneck.downsample[1]
86 |         # the final activation applied after the residual add
87 |         self.activation_function = bottleneck.relu3
88 |         self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
89 | 
90 |     def forward(self, x):
91 |         residual = x if self.downsample is None else self.downsample(x)
92 |         out = self.conv1(x)
93 |         out = self.conv2(out)
94 |         out = self.conv3(out)
95 |         out += residual
96 |         out = self.activation_function(out)
97 |         if self.use_act_quant:
98 |             out = self.act_quantizer(out)
99 |         return out
100 | 
101 | 
102 | class QuantResBottleneckBlock(BaseQuantBlock):
103 |     """
104 |     Implementation of Quantized Bottleneck Block used in RegNetX (no SE module).
105 | """ 106 | 107 | def __init__(self, bottleneck: ResBottleneckBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}): 108 | super().__init__() 109 | self.conv1 = QuantModule(bottleneck.f.a, weight_quant_params, act_quant_params) 110 | self.conv1.norm_function = bottleneck.f.a_bn 111 | self.conv1.activation_function = bottleneck.f.a_relu 112 | self.conv2 = QuantModule(bottleneck.f.b, weight_quant_params, act_quant_params) 113 | self.conv2.norm_function = bottleneck.f.b_bn 114 | self.conv2.activation_function = bottleneck.f.b_relu 115 | self.conv3 = QuantModule(bottleneck.f.c, weight_quant_params, act_quant_params, disable_act_quant=True) 116 | self.conv3.norm_function = bottleneck.f.c_bn 117 | 118 | if bottleneck.proj_block: 119 | self.downsample = QuantModule(bottleneck.proj, weight_quant_params, act_quant_params, 120 | disable_act_quant=True) 121 | self.downsample.norm_function = bottleneck.bn 122 | else: 123 | self.downsample = None 124 | # copying all attributes in original block 125 | self.proj_block = bottleneck.proj_block 126 | 127 | self.activation_function = bottleneck.relu 128 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 129 | 130 | def forward(self, x): 131 | residual = x if not self.proj_block else self.downsample(x) 132 | out = self.conv1(x) 133 | out = self.conv2(out) 134 | out = self.conv3(out) 135 | out += residual 136 | out = self.activation_function(out) 137 | if self.use_act_quant: 138 | out = self.act_quantizer(out) 139 | return out 140 | 141 | 142 | class QuantInvertedResidual(BaseQuantBlock): 143 | """ 144 | Implementation of Quantized Inverted Residual Block used in MobileNetV2. 145 | Inverted Residual does not have activation function. 146 | """ 147 | 148 | def __init__(self, inv_res: InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}): 149 | super().__init__() 150 | 151 | self.use_res_connect = inv_res.use_res_connect 152 | self.expand_ratio = inv_res.expand_ratio 153 | if self.expand_ratio == 1: 154 | self.conv = nn.Sequential( 155 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), 156 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params, disable_act_quant=True), 157 | ) 158 | self.conv[0].norm_function = inv_res.conv[1] 159 | self.conv[0].activation_function = nn.ReLU6() 160 | self.conv[1].norm_function = inv_res.conv[4] 161 | else: 162 | self.conv = nn.Sequential( 163 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), 164 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params), 165 | QuantModule(inv_res.conv[6], weight_quant_params, act_quant_params, disable_act_quant=True), 166 | ) 167 | self.conv[0].norm_function = inv_res.conv[1] 168 | self.conv[0].activation_function = nn.ReLU6() 169 | self.conv[1].norm_function = inv_res.conv[4] 170 | self.conv[1].activation_function = nn.ReLU6() 171 | self.conv[2].norm_function = inv_res.conv[7] 172 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 173 | 174 | def forward(self, x): 175 | if self.use_res_connect: 176 | out = x + self.conv(x) 177 | else: 178 | out = self.conv(x) 179 | if self.use_act_quant: 180 | out = self.act_quantizer(out) 181 | return out 182 | 183 | 184 | class _QuantInvertedResidual(BaseQuantBlock): 185 | def __init__(self, _inv_res: _InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}): 186 | super().__init__() 187 | 188 | self.apply_residual = _inv_res.apply_residual 189 | self.conv = nn.Sequential( 190 | 
QuantModule(_inv_res.layers[0], weight_quant_params, act_quant_params), 191 | QuantModule(_inv_res.layers[3], weight_quant_params, act_quant_params), 192 | QuantModule(_inv_res.layers[6], weight_quant_params, act_quant_params, disable_act_quant=True), 193 | ) 194 | self.conv[0].activation_function = nn.ReLU() 195 | self.conv[0].norm_function = _inv_res.layers[1] 196 | self.conv[1].activation_function = nn.ReLU() 197 | self.conv[1].norm_function = _inv_res.layers[4] 198 | self.conv[2].norm_function = _inv_res.layers[7] 199 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 200 | 201 | def forward(self, x): 202 | if self.apply_residual: 203 | out = x + self.conv(x) 204 | else: 205 | out = self.conv(x) 206 | if self.use_act_quant: 207 | out = self.act_quantizer(out) 208 | return out 209 | 210 | 211 | specials = { 212 | BasicBlock: QuantBasicBlock, 213 | Bottleneck: QuantBottleneck, 214 | ResBottleneckBlock: QuantResBottleneckBlock, 215 | InvertedResidual: QuantInvertedResidual, 216 | _InvertedResidual: _QuantInvertedResidual, 217 | } 218 | 219 | specials_unquantized = [nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.Dropout] 220 | -------------------------------------------------------------------------------- /quant/quant_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Union 5 | 6 | 7 | class StraightThrough(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, input): 12 | return input 13 | 14 | 15 | def round_ste(x: torch.Tensor): 16 | """ 17 | Implement Straight-Through Estimator for rounding operation. 18 | """ 19 | return (x.round() - x).detach() + x 20 | 21 | 22 | def lp_loss(pred, tgt, p=2.0, reduction='none'): 23 | """ 24 | loss function measured in L_p Norm 25 | """ 26 | if reduction == 'none': 27 | return (pred - tgt).abs().pow(p).sum(1).mean() 28 | else: 29 | return (pred - tgt).abs().pow(p).mean() 30 | 31 | 32 | class UniformAffineQuantizer(nn.Module): 33 | """ 34 | PyTorch Function that can be used for asymmetric quantization (also called uniform affine 35 | quantization). Quantizes its argument in the forward pass, passes the gradient 'straight 36 | through' on the backward pass, ignoring the quantization that occurred. 37 | Based on https://arxiv.org/abs/1806.08342. 
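    In symbols, matching forward() below, with step size delta and zero point z:

        x_int   = round(x / delta) + z
        x_quant = clamp(x_int, 0, 2**n_bits - 1)
        x_hat   = (x_quant - z) * delta

    e.g. with n_bits=2, delta=0.5, z=2: x = -0.7 maps to round(-1.4) + 2 = 1,
    which stays inside [0, 3] and dequantizes to (1 - 2) * 0.5 = -0.5.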
38 | 
39 |     :param n_bits: number of bits used for quantization
40 |     :param symmetric: if True, the zero_point should always be 0
41 |     :param channel_wise: if True, compute scale and zero_point in each channel
42 |     :param scale_method: determines the quantization scale and zero point
43 |     :param prob: probability of keeping the quantized value (QDrop-style random dropping)
44 |     """
45 | 
46 |     def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False,
47 |                  scale_method: str = 'minmax',
48 |                  leaf_param: bool = False, prob: float = 1.0):
49 |         super(UniformAffineQuantizer, self).__init__()
50 |         self.sym = symmetric
51 |         if self.sym:
52 |             raise NotImplementedError
53 |         assert 2 <= n_bits <= 8, 'bitwidth not supported'
54 |         self.n_bits = n_bits
55 |         self.n_levels = 2 ** self.n_bits
56 |         self.delta = 1.0
57 |         self.zero_point = 0.0
58 |         self.inited = True
59 | 
60 |         '''if leaf_param, use EMA to set scale'''
61 |         self.leaf_param = leaf_param
62 |         self.channel_wise = channel_wise
63 |         self.eps = torch.tensor(1e-8, dtype=torch.float32)
64 | 
65 |         '''mse params'''
66 |         self.scale_method = 'mse'  # note: hard-coded, overriding the scale_method argument
67 |         self.one_side_dist = None
68 |         self.num = 100
69 | 
70 |         '''for activation quantization'''
71 |         self.running_min = None
72 |         self.running_max = None
73 | 
74 |         '''do like dropout'''
75 |         self.prob = prob
76 |         self.is_training = False
77 | 
78 |     def set_inited(self, inited: bool = True):  # inited manually
79 |         self.inited = inited
80 | 
81 |     def update_quantize_range(self, x_min, x_max):
82 |         if self.running_min is None:
83 |             self.running_min = x_min
84 |             self.running_max = x_max
85 |         self.running_min = 0.1 * x_min + 0.9 * self.running_min
86 |         self.running_max = 0.1 * x_max + 0.9 * self.running_max
87 |         return self.running_min, self.running_max
88 | 
89 |     def forward(self, x: torch.Tensor):
90 |         if self.inited is False:
91 |             if self.leaf_param:
92 |                 self.delta, self.zero_point = self.init_quantization_scale(x.clone().detach(), self.channel_wise)
93 |             else:
94 |                 self.delta, self.zero_point = self.init_quantization_scale(x.clone().detach(), self.channel_wise)
95 | 
96 |         # start quantization
97 |         x_int = round_ste(x / self.delta) + self.zero_point
98 |         x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
99 |         x_dequant = (x_quant - self.zero_point) * self.delta
100 |         if self.is_training and self.prob < 1.0:
101 |             x_ans = torch.where(torch.rand_like(x) < self.prob, x_dequant, x)
102 |         else:
103 |             x_ans = x_dequant
104 |         return x_ans
105 | 
106 |     def lp_loss(self, pred, tgt, p=2.0):
107 |         x = (pred - tgt).abs().pow(p)
108 |         if not self.channel_wise:
109 |             return x.mean()
110 |         else:
111 |             y = torch.flatten(x, 1)
112 |             return y.mean(1)
113 | 
114 |     def calculate_qparams(self, min_val, max_val):
115 |         # one_dim or one element
116 |         quant_min, quant_max = 0, self.n_levels - 1
117 |         min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
118 |         max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
119 | 
120 |         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
121 |         scale = torch.max(scale, self.eps)
122 |         zero_point = quant_min - torch.round(min_val_neg / scale)
123 |         zero_point = torch.clamp(zero_point, quant_min, quant_max)
124 |         return scale, zero_point
125 | 
126 |     def quantize(self, x: torch.Tensor, x_max, x_min):
127 |         delta, zero_point = self.calculate_qparams(x_min, x_max)
128 |         if self.channel_wise:
129 |             new_shape = [1] * len(x.shape)
130 |             new_shape[0] = x.shape[0]
131 |             delta = delta.reshape(new_shape)
132 |             zero_point = zero_point.reshape(new_shape)
133 |         x_int = torch.round(x / delta)
134 |         x_quant = torch.clamp(x_int + 
zero_point, 0, self.n_levels - 1) 135 | x_float_q = (x_quant - zero_point) * delta 136 | return x_float_q 137 | 138 | def perform_2D_search(self, x): 139 | if self.channel_wise: 140 | y = torch.flatten(x, 1) 141 | x_min, x_max = torch._aminmax(y, 1) 142 | # may also have the one side distribution in some channels 143 | x_max = torch.max(x_max, torch.zeros_like(x_max)) 144 | x_min = torch.min(x_min, torch.zeros_like(x_min)) 145 | else: 146 | x_min, x_max = torch._aminmax(x) 147 | xrange = x_max - x_min 148 | best_score = torch.zeros_like(x_min) + (1e+10) 149 | best_min = x_min.clone() 150 | best_max = x_max.clone() 151 | # enumerate xrange 152 | for i in range(1, self.num + 1): 153 | tmp_min = torch.zeros_like(x_min) 154 | tmp_max = xrange / self.num * i 155 | tmp_delta = (tmp_max - tmp_min) / (2 ** self.n_bits - 1) 156 | # enumerate zp 157 | for zp in range(0, self.n_levels): 158 | new_min = tmp_min - zp * tmp_delta 159 | new_max = tmp_max - zp * tmp_delta 160 | x_q = self.quantize(x, new_max, new_min) 161 | score = self.lp_loss(x, x_q, 2.4) 162 | best_min = torch.where(score < best_score, new_min, best_min) 163 | best_max = torch.where(score < best_score, new_max, best_max) 164 | best_score = torch.min(best_score, score) 165 | return best_min, best_max 166 | 167 | def perform_1D_search(self, x): 168 | if self.channel_wise: 169 | y = torch.flatten(x, 1) 170 | x_min, x_max = torch._aminmax(y, 1) 171 | else: 172 | x_min, x_max = torch._aminmax(x) 173 | xrange = torch.max(x_min.abs(), x_max) 174 | best_score = torch.zeros_like(x_min) + (1e+10) 175 | best_min = x_min.clone() 176 | best_max = x_max.clone() 177 | # enumerate xrange 178 | for i in range(1, self.num + 1): 179 | thres = xrange / self.num * i 180 | new_min = torch.zeros_like(x_min) if self.one_side_dist == 'pos' else -thres 181 | new_max = torch.zeros_like(x_max) if self.one_side_dist == 'neg' else thres 182 | x_q = self.quantize(x, new_max, new_min) 183 | score = self.lp_loss(x, x_q, 2.4) 184 | best_min = torch.where(score < best_score, new_min, best_min) 185 | best_max = torch.where(score < best_score, new_max, best_max) 186 | best_score = torch.min(score, best_score) 187 | return best_min, best_max 188 | 189 | def get_x_min_x_max(self, x): 190 | if self.scale_method != 'mse': 191 | raise NotImplementedError 192 | if self.one_side_dist is None: 193 | self.one_side_dist = 'pos' if x.min() >= 0.0 else 'neg' if x.max() <= 0.0 else 'no' 194 | if self.one_side_dist != 'no' or self.sym: # one-side distribution or symmetric value for 1-d search 195 | best_min, best_max = self.perform_1D_search(x) 196 | else: # 2-d search 197 | best_min, best_max = self.perform_2D_search(x) 198 | if self.leaf_param: 199 | return self.update_quantize_range(best_min, best_max) 200 | return best_min, best_max 201 | 202 | def init_quantization_scale_channel(self, x: torch.Tensor): 203 | x_min, x_max = self.get_x_min_x_max(x) 204 | return self.calculate_qparams(x_min, x_max) 205 | 206 | def init_quantization_scale(self, x_clone: torch.Tensor, channel_wise: bool = False): 207 | if channel_wise: 208 | # determine the scale and zero point channel-by-channel 209 | delta, zero_point = self.init_quantization_scale_channel(x_clone) 210 | new_shape = [1] * len(x_clone.shape) 211 | new_shape[0] = x_clone.shape[0] 212 | delta = delta.reshape(new_shape) 213 | zero_point = zero_point.reshape(new_shape) 214 | else: 215 | delta, zero_point = self.init_quantization_scale_channel(x_clone) 216 | return delta, zero_point 217 | 218 | def bitwidth_refactor(self, refactored_bit: 
int): 219 | assert 2 <= refactored_bit <= 8, 'bitwidth not supported' 220 | self.n_bits = refactored_bit 221 | self.n_levels = 2 ** self.n_bits 222 | 223 | @torch.jit.export 224 | def extra_repr(self): 225 | return 'bit={}, is_training={}, inited={}'.format( 226 | self.n_bits, self.is_training, self.inited 227 | ) 228 | 229 | 230 | class QuantModule(nn.Module): 231 | """ 232 | Quantized Module that can perform quantized convolution or normal convolution. 233 | To activate quantization, please use set_quant_state function. 234 | """ 235 | 236 | def __init__(self, org_module: Union[nn.Conv2d, nn.Linear], weight_quant_params: dict = {}, 237 | act_quant_params: dict = {}, disable_act_quant=False): 238 | super(QuantModule, self).__init__() 239 | if isinstance(org_module, nn.Conv2d): 240 | self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding, 241 | dilation=org_module.dilation, groups=org_module.groups) 242 | self.fwd_func = F.conv2d 243 | else: 244 | self.fwd_kwargs = dict() 245 | self.fwd_func = F.linear 246 | self.weight = org_module.weight 247 | self.org_weight = org_module.weight.data.clone() 248 | if org_module.bias is not None: 249 | self.bias = org_module.bias 250 | self.org_bias = org_module.bias.data.clone() 251 | else: 252 | self.bias = None 253 | self.org_bias = None 254 | # de-activate the quantized forward default 255 | self.use_weight_quant = False 256 | self.use_act_quant = False 257 | # initialize quantizer 258 | self.weight_quantizer = UniformAffineQuantizer(**weight_quant_params) 259 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 260 | 261 | self.norm_function = StraightThrough() 262 | self.activation_function = StraightThrough() 263 | self.ignore_reconstruction = False 264 | self.disable_act_quant = disable_act_quant 265 | self.trained = False 266 | 267 | def forward(self, input: torch.Tensor): 268 | if self.use_weight_quant: 269 | weight = self.weight_quantizer(self.weight) 270 | bias = self.bias 271 | else: 272 | weight = self.org_weight 273 | bias = self.org_bias 274 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs) 275 | # disable act quantization is designed for convolution before elemental-wise operation, 276 | # in that case, we apply activation function and quantization after ele-wise op. 
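# The order below is conv/linear -> (folded) norm -> activation -> optional
# activation quantization, so the act quantizer always sees exactly the
# tensor that the next layer will consume.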
277 | out = self.norm_function(out) 278 | out = self.activation_function(out) 279 | if self.disable_act_quant: 280 | return out 281 | if self.use_act_quant: 282 | out = self.act_quantizer(out) 283 | return out 284 | 285 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 286 | self.use_weight_quant = weight_quant 287 | self.use_act_quant = act_quant 288 | 289 | @torch.jit.export 290 | def extra_repr(self): 291 | return 'wbit={}, abit={}, disable_act_quant={}'.format( 292 | self.weight_quantizer.n_bits, self.act_quantizer.n_bits, self.disable_act_quant 293 | ) 294 | -------------------------------------------------------------------------------- /quant/quant_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .quant_block import specials, BaseQuantBlock 3 | from .quant_layer import QuantModule, StraightThrough, UniformAffineQuantizer 4 | from .fold_bn import search_fold_and_remove_bn 5 | 6 | 7 | class QuantModel(nn.Module): 8 | 9 | def __init__(self, model: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}, is_fusing=True): 10 | super().__init__() 11 | if is_fusing: 12 | search_fold_and_remove_bn(model) 13 | self.model = model 14 | self.quant_module_refactor(self.model, weight_quant_params, act_quant_params) 15 | else: 16 | self.model = model 17 | self.quant_module_refactor_wo_fuse(self.model, weight_quant_params, act_quant_params) 18 | 19 | def quant_module_refactor(self, module: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}): 20 | """ 21 | Recursively replace the normal conv2d and Linear layer to QuantModule 22 | :param module: nn.Module with nn.Conv2d or nn.Linear in its children 23 | :param weight_quant_params: quantization parameters like n_bits for weight quantizer 24 | :param act_quant_params: quantization parameters like n_bits for activation quantizer 25 | """ 26 | prev_quantmodule = None 27 | for name, child_module in module.named_children(): 28 | if type(child_module) in specials: 29 | setattr(module, name, specials[type(child_module)](child_module, weight_quant_params, act_quant_params)) 30 | elif isinstance(child_module, (nn.Conv2d, nn.Linear)): 31 | setattr(module, name, QuantModule(child_module, weight_quant_params, act_quant_params)) 32 | prev_quantmodule = getattr(module, name) 33 | 34 | elif isinstance(child_module, (nn.ReLU, nn.ReLU6)): 35 | if prev_quantmodule is not None: 36 | prev_quantmodule.activation_function = child_module 37 | setattr(module, name, StraightThrough()) 38 | else: 39 | continue 40 | 41 | elif isinstance(child_module, StraightThrough): 42 | continue 43 | 44 | else: 45 | self.quant_module_refactor(child_module, weight_quant_params, act_quant_params) 46 | 47 | def quant_module_refactor_wo_fuse(self, module: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}): 48 | """ 49 | Recursively replace the normal conv2d and Linear layer to QuantModule 50 | :param module: nn.Module with nn.Conv2d or nn.Linear in its children 51 | :param weight_quant_params: quantization parameters like n_bits for weight quantizer 52 | :param act_quant_params: quantization parameters like n_bits for activation quantizer 53 | """ 54 | prev_quantmodule = None 55 | for name, child_module in module.named_children(): 56 | if type(child_module) in specials: 57 | setattr(module, name, specials[type(child_module)](child_module, weight_quant_params, act_quant_params)) 58 | elif isinstance(child_module, (nn.Conv2d, 
59 |                 setattr(module, name, QuantModule(child_module, weight_quant_params, act_quant_params))
60 |                 prev_quantmodule = getattr(module, name)
61 | 
62 |             elif isinstance(child_module, nn.BatchNorm2d):
63 |                 if prev_quantmodule is not None:
64 |                     prev_quantmodule.norm_function = child_module
65 |                     setattr(module, name, StraightThrough())
66 |                 else:
67 |                     continue
68 | 
69 |             elif isinstance(child_module, (nn.ReLU, nn.ReLU6)):
70 |                 if prev_quantmodule is not None:
71 |                     prev_quantmodule.activation_function = child_module
72 |                     setattr(module, name, StraightThrough())
73 |                 else:
74 |                     continue
75 | 
76 |             elif isinstance(child_module, StraightThrough):
77 |                 continue
78 | 
79 |             else:
80 |                 self.quant_module_refactor_wo_fuse(child_module, weight_quant_params, act_quant_params)
81 | 
82 |     def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
83 |         for m in self.model.modules():
84 |             if isinstance(m, (QuantModule, BaseQuantBlock)):
85 |                 m.set_quant_state(weight_quant, act_quant)
86 | 
87 |     def forward(self, input):
88 |         return self.model(input)
89 | 
90 |     def set_first_last_layer_to_8bit(self):
91 |         w_list, a_list = [], []
92 |         for module in self.model.modules():
93 |             if isinstance(module, UniformAffineQuantizer):
94 |                 if module.leaf_param:
95 |                     a_list.append(module)
96 |                 else:
97 |                     w_list.append(module)
98 |         w_list[0].bitwidth_refactor(8)
99 |         w_list[-1].bitwidth_refactor(8)
100 |         # the image input is already in 0~255, so set the last layer's input to 8-bit
101 |         a_list[-2].bitwidth_refactor(8)
102 |         # a_list[0].bitwidth_refactor(8)
103 | 
104 |     def disable_network_output_quantization(self):
105 |         module_list = []
106 |         for m in self.model.modules():
107 |             if isinstance(m, QuantModule):
108 |                 module_list += [m]
109 |         module_list[-1].disable_act_quant = True
--------------------------------------------------------------------------------
/quant/set_act_quantize_params.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .quant_layer import QuantModule
3 | from .quant_block import BaseQuantBlock
4 | from .quant_model import QuantModel
5 | from typing import Union
6 | 
7 | def set_act_quantize_params(module: Union[QuantModel, QuantModule, BaseQuantBlock],
8 |                             cali_data, batch_size: int = 256):
9 |     module.set_quant_state(True, True)
10 | 
11 |     for t in module.modules():
12 |         if isinstance(t, (QuantModule, BaseQuantBlock)):
13 |             t.act_quantizer.set_inited(False)
14 | 
15 |     # set or initialize the step size and zero point in the activation quantizers
16 |     batch_size = min(batch_size, cali_data.size(0))
17 |     with torch.no_grad():
18 |         for i in range(cali_data.size(0) // batch_size):
19 |             module(cali_data[i * batch_size:(i + 1) * batch_size].cuda())
20 |     torch.cuda.empty_cache()
21 | 
22 |     for t in module.modules():
23 |         if isinstance(t, (QuantModule, BaseQuantBlock)):
24 |             t.act_quantizer.set_inited(True)
--------------------------------------------------------------------------------
/quant/set_weight_quantize_params.py:
--------------------------------------------------------------------------------
1 | from .quant_layer import QuantModule
2 | from .data_utils import save_inp_oup_data, save_dc_fp_data
3 | 
4 | 
5 | def get_init(model, block, cali_data, batch_size, input_prob: bool = False, keep_gpu: bool = True):
6 |     cached_inps = save_inp_oup_data(model, block, cali_data, batch_size, input_prob=input_prob, keep_gpu=keep_gpu)
7 |     return cached_inps
8 | 
9 | def get_dc_fp_init(model, block, cali_data, batch_size, input_prob: bool = False, keep_gpu: bool = True, lamb=50, bn_lr=1e-3):
10 |     cached_outs, cached_outputs, cached_sym = save_dc_fp_data(model, block, cali_data, batch_size, input_prob=input_prob, keep_gpu=keep_gpu, lamb=lamb, bn_lr=bn_lr)
11 |     return cached_outs, cached_outputs, cached_sym
12 | 
13 | def set_weight_quantize_params(model):
14 |     for module in model.modules():
15 |         if isinstance(module, QuantModule):
16 |             module.weight_quantizer.set_inited(False)
17 |             # calculate the step size and zero point for the weight quantizer
18 |             module.weight_quantizer(module.weight)
19 |             module.weight_quantizer.set_inited(True)
20 | 
21 | def save_quantized_weight(model):
22 |     for module in model.modules():
23 |         if isinstance(module, QuantModule):
24 |             module.weight.data = module.weight_quantizer(module.weight)
--------------------------------------------------------------------------------
/run_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 | 
5 | # per-architecture hyperparameters: (weight, T, lamb_c)
6 | HPARAMS = {
7 |     'resnet18':      (0.01, 4.0, 0.02),
8 |     'resnet50':      (0.01, 4.0, 0.02),
9 |     'mobilenetv2':   (0.1, 1.0, 0.005),
10 |     'regnetx_600m':  (0.01, 4.0, 0.01),
11 |     'regnetx_3200m': (0.01, 4.0, 0.01),
12 |     'mnasnet':       (0.2, 1.0, 0.001),
13 | }
14 | 
15 | if __name__ == "__main__":
16 |     parser = argparse.ArgumentParser()
17 |     parser.add_argument("exp_name", type=str, choices=list(HPARAMS.keys()))
18 |     args = parser.parse_args()
19 |     # sweep the four weight/activation bitwidth settings: W2A2, W4A2, W2A4, W4A4
20 |     w_bits = [2, 4, 2, 4]
21 |     a_bits = [2, 2, 4, 4]
22 | 
23 |     weight, T, lamb_c = HPARAMS[args.exp_name]
24 |     for i in range(4):
25 |         os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch {args.exp_name} "
26 |                   f"--n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight {weight} --T {T} --lamb_c {lamb_c}")
27 |         time.sleep(0.5)
--------------------------------------------------------------------------------
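
For reference, the pieces above compose into a simple post-training quantization flow: wrap the network in QuantModel, initialize the weight quantizers, then calibrate the activation quantizers on real data. The following is a minimal usage sketch, not a file from this repository; the torchvision model, the random calibration tensor, and the quantizer keyword arguments (n_bits, leaf_param, inferred from the attributes used above) are illustrative assumptions.

import torch
from torchvision.models import resnet18  # assumed model source; the repo ships its own models/
from quant.quant_model import QuantModel
from quant.set_weight_quantize_params import set_weight_quantize_params
from quant.set_act_quantize_params import set_act_quantize_params

# stand-in calibration batch; in practice these are images sampled from ImageNet
cali_data = torch.randn(32, 3, 224, 224)

model = resnet18(pretrained=True).cuda().eval()
# fold BatchNorm into the preceding convs and wrap Conv2d/Linear layers in QuantModule
qmodel = QuantModel(model,
                    weight_quant_params={'n_bits': 4},                   # kwargs assumed, forwarded to UniformAffineQuantizer
                    act_quant_params={'n_bits': 4, 'leaf_param': True},  # leaf_param distinguishes activation quantizers
                    is_fusing=True).cuda().eval()

# keep the first/last weight quantizers at 8 bit and skip quantizing the network output
qmodel.set_first_last_layer_to_8bit()
qmodel.disable_network_output_quantization()

# initialize weight quantizer step sizes, then calibrate activation quantizers on data
set_weight_quantize_params(qmodel)
set_act_quantize_params(qmodel, cali_data, batch_size=32)

# run fully quantized inference
qmodel.set_quant_state(weight_quant=True, act_quant=True)
with torch.no_grad():
    preds = qmodel(cali_data.cuda())

In the full pipeline (main_imagenet.py), this calibration step is followed by block-wise reconstruction (block_recon.py / layer_recon.py), which the sketch omits.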