├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── imagenet.py
├── hubconf.py
├── main_imagenet.py
├── models
│   ├── __init__.py
│   ├── mnasnet.py
│   ├── mobilenetv2.py
│   ├── regnet.py
│   └── resnet.py
├── quant
│   ├── __init__.py
│   ├── adaptive_rounding.py
│   ├── block_recon.py
│   ├── data_utils.py
│   ├── fold_bn.py
│   ├── layer_recon.py
│   ├── quant_block.py
│   ├── quant_layer.py
│   ├── quant_model.py
│   ├── set_act_quantize_params.py
│   └── set_weight_quantize_params.py
└── run_script.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.egg
3 | *.egg-info
4 | *.log
5 | *.pyc
6 | *.npy
7 | *.pth
8 | .DS_Store
9 | __pycache__
10 | model_zoo
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PD-Quant
2 | 
3 | PD-Quant: Post-Training Quantization Based on Prediction Difference Metric
4 | 
10 | ## Usage
11 | ### 1. Download pre-trained FP model.
12 | The pre-trained FP models used in our experiments come from [BRECQ](https://github.com/yhhhli/BRECQ); they can be downloaded from this [link](https://github.com/yhhhli/BRECQ/releases/tag/v1.0).
13 | Then modify the paths to the pre-trained models in ```hubconf.py```, as shown below.
14 |
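For example, the ```model_path``` dictionary at the top of ```hubconf.py``` should point at the downloaded checkpoints (the paths below are placeholders):
```
model_path = {
    'resnet18': '/path/to/resnet18_imagenet.pth.tar',
    'resnet50': '/path/to/resnet50_imagenet.pth.tar',
    'mbv2': '/path/to/mobilenetv2.pth.tar',
    'reg600m': '/path/to/regnet_600m.pth.tar',
    'reg3200m': '/path/to/regnet_3200m.pth.tar',
    'mnasnet': '/path/to/mnasnet.pth.tar',
}
```
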
15 | ### 2. Installation.
16 | ```
17 | python >= 3.7.13
18 | numpy >= 1.21.6
19 | torch >= 1.11.0
20 | torchvision >= 0.12.0
21 | ```
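
For example, the dependencies can be installed with pip (pick a ```torch``` build that matches your CUDA version):
```
pip install "numpy>=1.21.6" "torch>=1.11.0" "torchvision>=0.12.0"
```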
22 |
23 | ### 3. Run experiments
24 | You can run ```run_script.py``` for different models, including ResNet18, ResNet50, RegNet600, RegNet3200, MobileNetV2, and MNasNet.
25 | It runs experiments on four weight/activation bit-width settings: W4A4, W2A4, W4A2, and W2A2.
26 |
27 | Take ResNet18 as an example:
28 | ```
29 | python run_script.py resnet18
30 | ```
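
A single setting can also be launched directly through ```main_imagenet.py```, whose argument parser exposes the bit-widths and the data path (the ImageNet path below is a placeholder):
```
python main_imagenet.py --arch resnet18 --n_bits_w 4 --n_bits_a 4 --data_path /path/to/imagenet
```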
31 |
32 | ## Results
33 | All results are ImageNet top-1 accuracy (%).
34 | 
34 | | Methods | Bits (W/A) | Res18 |Res50 | MNV2 | Reg600M | Reg3.2G | MNasx2 |
35 | | ------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
36 | | Full Prec. | 32/32 | 71.01 | 76.63 | 72.62 | 73.52 | 78.46 | 76.52 |
37 | |PD-Quant| 4/4 | 69.30 | 75.09 | 68.33 | 71.04 | 76.57 | 73.30 |
38 | |PD-Quant| 2/4 | 65.07 | 70.92 | 55.27 | 64.00 | 72.43 | 63.33|
39 | |PD-Quant| 4/2 | 58.65 | 64.18 | 20.40 | 51.29 | 62.76 | 38.89 |
40 | |PD-Quant| 2/2 | 53.08 | 56.98 | 14.17 | 40.92 | 55.13 | 28.03|
41 |
42 | ## Reference
43 | ```
44 | @article{liu2022pd,
45 | title={PD-Quant: Post-Training Quantization based on Prediction Difference Metric},
46 | author={Liu, Jiawei and Niu, Lin and Yuan, Zhihang and Yang, Dawei and Wang, Xinggang and Liu, Wenyu},
47 | journal={arXiv preprint arXiv:2212.07048},
48 | year={2022}
49 | }
50 | ```
51 |
52 | ## Thanks
53 | Our code is based on [QDrop](https://github.com/wimh966/QDrop) by @wimh966.
54 |
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 |
6 |
7 | def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4):
8 | print('==> Using Pytorch Dataset')
9 |
10 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
11 | std=[0.229, 0.224, 0.225])
12 | traindir = os.path.join(data_path, 'train')
13 | valdir = os.path.join(data_path, 'val')
14 | train_dataset = datasets.ImageFolder(
15 | traindir,
16 | transforms.Compose([
17 | transforms.RandomResizedCrop(input_size),
18 | transforms.RandomHorizontalFlip(),
19 | transforms.ToTensor(),
20 | normalize,
21 | ]))
22 |
23 | train_loader = torch.utils.data.DataLoader(
24 | train_dataset, batch_size=batch_size, shuffle=True,
25 | num_workers=workers, pin_memory=True)
26 | val_loader = torch.utils.data.DataLoader(
27 | datasets.ImageFolder(valdir, transforms.Compose([
28 | transforms.Resize(256),
29 | transforms.CenterCrop(input_size),
30 | transforms.ToTensor(),
31 | normalize,
32 | ])),
33 | batch_size=batch_size, shuffle=False,
34 | num_workers=workers, pin_memory=True)
35 | return train_loader, val_loader
36 |
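# Usage sketch, assuming a standard ImageNet layout with train/ and val/
# subfolders (the data path is a placeholder):
#
#   train_loader, val_loader = build_imagenet_data(data_path='/path/to/imagenet',
#                                                  input_size=224, batch_size=64)
#   images, labels = next(iter(train_loader))  # images: [64, 3, 224, 224]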
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from models.resnet import resnet18 as _resnet18
3 | from models.resnet import resnet50 as _resnet50
4 | from models.mobilenetv2 import mobilenetv2 as _mobilenetv2
5 | from models.mnasnet import mnasnet as _mnasnet
6 | from models.regnet import regnetx_600m as _regnetx_600m
7 | from models.regnet import regnetx_3200m as _regnetx_3200m
8 | import torch
9 | dependencies = ['torch']
10 | model_path = {
11 | 'resnet18': '/home/tmp/resnet18_imagenet.pth.tar',
12 | 'resnet50': '/home/tmp/resnet50_imagenet.pth.tar',
13 | 'mbv2': '/home/tmp/mobilenetv2.pth.tar',
14 | 'reg600m': '/home/tmp/regnet_600m.pth.tar',
15 | 'reg3200m': '/home/tmp/regnet_3200m.pth.tar',
16 | 'mnasnet': '/home/tmp/mnasnet.pth.tar',
17 | }
18 |
19 |
20 | def resnet18(pretrained=False, **kwargs):
21 | # Call the model, load pretrained weights
22 | model = _resnet18(**kwargs)
23 | if pretrained:
24 | checkpoint = torch.load(model_path['resnet18'], map_location='cpu')
25 | model.load_state_dict(checkpoint)
26 | return model
27 |
28 |
29 | def resnet50(pretrained=False, **kwargs):
30 | # Call the model, load pretrained weights
31 | model = _resnet50(**kwargs)
32 | if pretrained:
33 | checkpoint = torch.load(model_path['resnet50'], map_location='cpu')
34 | model.load_state_dict(checkpoint)
35 | return model
36 |
37 |
38 | def mobilenetv2(pretrained=False, **kwargs):
39 | # Call the model, load pretrained weights
40 | model = _mobilenetv2(**kwargs)
41 | if pretrained:
42 | checkpoint = torch.load(model_path['mbv2'], map_location='cpu')
43 | model.load_state_dict(checkpoint['model'])
44 | return model
45 |
46 |
47 | def regnetx_600m(pretrained=False, **kwargs):
48 | # Call the model, load pretrained weights
49 | model = _regnetx_600m(**kwargs)
50 | if pretrained:
51 | checkpoint = torch.load(model_path['reg600m'], map_location='cpu')
52 | model.load_state_dict(checkpoint)
53 | return model
54 |
55 |
56 | def regnetx_3200m(pretrained=False, **kwargs):
57 | # Call the model, load pretrained weights
58 | model = _regnetx_3200m(**kwargs)
59 | if pretrained:
60 | checkpoint = torch.load(model_path['reg3200m'], map_location='cpu')
61 | model.load_state_dict(checkpoint)
62 | return model
63 |
64 |
65 | def mnasnet(pretrained=False, **kwargs):
66 | # Call the model, load pretrained weights
67 | model = _mnasnet(**kwargs)
68 | if pretrained:
69 | checkpoint = torch.load(model_path['mnasnet'], map_location='cpu')
70 | model.load_state_dict(checkpoint)
71 | return model
72 |
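# Usage sketch: these entry points follow the torch.hub convention, so once
# model_path points at the downloaded checkpoints a model can be built with:
#
#   import hubconf
#   model = hubconf.resnet18(pretrained=True)
#   model.eval()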
--------------------------------------------------------------------------------
/main_imagenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | import os
6 | import random
7 | import time
8 | import hubconf # noqa: F401
9 | import copy
10 | from quant import (
11 | block_reconstruction,
12 | layer_reconstruction,
13 | BaseQuantBlock,
14 | QuantModule,
15 | QuantModel,
16 | set_weight_quantize_params,
17 | )
18 | from data.imagenet import build_imagenet_data
19 |
20 |
21 | def seed_all(seed=1029):
22 | random.seed(seed)
23 | os.environ['PYTHONHASHSEED'] = str(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed(seed)
27 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
28 | torch.backends.cudnn.benchmark = False
29 | torch.backends.cudnn.deterministic = True
30 |
31 |
32 | class AverageMeter(object):
33 | """Computes and stores the average and current value"""
34 | def __init__(self, name, fmt=':f'):
35 | self.name = name
36 | self.fmt = fmt
37 | self.reset()
38 |
39 | def reset(self):
40 | self.val = 0
41 | self.avg = 0
42 | self.sum = 0
43 | self.count = 0
44 |
45 | def update(self, val, n=1):
46 | self.val = val
47 | self.sum += val * n
48 | self.count += n
49 | self.avg = self.sum / self.count
50 |
51 | def __str__(self):
52 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
53 | return fmtstr.format(**self.__dict__)
54 |
55 |
56 | class ProgressMeter(object):
57 | def __init__(self, num_batches, meters, prefix=""):
58 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
59 | self.meters = meters
60 | self.prefix = prefix
61 |
62 | def display(self, batch):
63 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
64 | entries += [str(meter) for meter in self.meters]
65 | print('\t'.join(entries))
66 |
67 | def _get_batch_fmtstr(self, num_batches):
68 |         num_digits = len(str(num_batches))
69 | fmt = '{:' + str(num_digits) + 'd}'
70 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
71 |
72 |
73 | def accuracy(output, target, topk=(1,)):
74 | """Computes the accuracy over the k top predictions for the specified values of k"""
75 | with torch.no_grad():
76 | maxk = max(topk)
77 | batch_size = target.size(0)
78 |
79 | _, pred = output.topk(maxk, 1, True, True)
80 | pred = pred.t()
81 | correct = pred.eq(target.view(1, -1).expand_as(pred))
82 |
83 | res = []
84 | for k in topk:
85 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
86 | res.append(correct_k.mul_(100.0 / batch_size))
87 | return res
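# Worked example: for output = torch.tensor([[0.1, 0.7, 0.2]]) and
# target = torch.tensor([1]), accuracy(output, target, topk=(1,)) returns
# [tensor([100.])], since the top-1 prediction (index 1) matches the target.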
88 |
89 | @torch.no_grad()
90 | def validate_model(val_loader, model, device=None, print_freq=100):
91 | if device is None:
92 | device = next(model.parameters()).device
93 | else:
94 | model.to(device)
95 | batch_time = AverageMeter('Time', ':6.3f')
96 | top1 = AverageMeter('Acc@1', ':6.2f')
97 | top5 = AverageMeter('Acc@5', ':6.2f')
98 | progress = ProgressMeter(
99 | len(val_loader),
100 | [batch_time, top1, top5],
101 | prefix='Test: ')
102 |
103 | # switch to evaluate mode
104 | model.eval()
105 |
106 | end = time.time()
107 | for i, (images, target) in enumerate(val_loader):
108 | images = images.to(device)
109 | target = target.to(device)
110 |
111 | # compute output
112 | output = model(images)
113 |
114 | # measure accuracy and record loss
115 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
116 | top1.update(acc1[0], images.size(0))
117 | top5.update(acc5[0], images.size(0))
118 |
119 | # measure elapsed time
120 | batch_time.update(time.time() - end)
121 | end = time.time()
122 |
123 | if i % print_freq == 0:
124 | progress.display(i)
125 |
126 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
127 |
128 | return top1.avg
129 |
130 | def get_train_samples(train_loader, num_samples):
131 | train_data, target = [], []
132 | for batch in train_loader:
133 | train_data.append(batch[0])
134 | target.append(batch[1])
135 | if len(train_data) * batch[0].size(0) >= num_samples:
136 | break
137 | return torch.cat(train_data, dim=0)[:num_samples], torch.cat(target, dim=0)[:num_samples]
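# e.g. with batch_size=64 and num_samples=1024, this concatenates 16 batches and
# returns exactly 1024 calibration images together with their labels.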
138 |
139 |
140 | if __name__ == '__main__':
141 |
142 | parser = argparse.ArgumentParser(description='running parameters',
143 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
144 | # general parameters for data and model
145 | parser.add_argument('--seed', default=1005, type=int, help='random seed for results reproduction')
146 | parser.add_argument('--arch', default='resnet18', type=str, help='model name',
147 | choices=['resnet18', 'resnet50', 'mobilenetv2', 'regnetx_600m', 'regnetx_3200m', 'mnasnet'])
148 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size for data loader')
149 | parser.add_argument('--workers', default=4, type=int, help='number of workers for data loader')
150 | parser.add_argument('--data_path', default='/datasets-to-imagenet', type=str, help='path to ImageNet data')
151 |
152 | # quantization parameters
153 | parser.add_argument('--n_bits_w', default=4, type=int, help='bitwidth for weight quantization')
154 | parser.add_argument('--channel_wise', default=True, help='apply channel_wise quantization for weights')
155 | parser.add_argument('--n_bits_a', default=4, type=int, help='bitwidth for activation quantization')
156 | parser.add_argument('--disable_8bit_head_stem', action='store_true')
157 |
158 | # weight calibration parameters
159 | parser.add_argument('--num_samples', default=1024, type=int, help='size of the calibration dataset')
160 | parser.add_argument('--iters_w', default=20000, type=int, help='number of iteration for adaround')
161 | parser.add_argument('--weight', default=0.01, type=float, help='weight of rounding cost vs the reconstruction loss.')
162 | parser.add_argument('--keep_cpu', action='store_true', help='keep the calibration data on cpu')
163 |
164 | parser.add_argument('--b_start', default=20, type=int, help='temperature at the beginning of calibration')
165 | parser.add_argument('--b_end', default=2, type=int, help='temperature at the end of calibration')
166 | parser.add_argument('--warmup', default=0.2, type=float, help='in the warmup period no regularization is applied')
167 |
168 | # activation calibration parameters
169 | parser.add_argument('--lr', default=4e-5, type=float, help='learning rate for LSQ')
170 |
171 | parser.add_argument('--init_wmode', default='mse', type=str, choices=['minmax', 'mse', 'minmax_scale'],
172 | help='init opt mode for weight')
173 | parser.add_argument('--init_amode', default='mse', type=str, choices=['minmax', 'mse', 'minmax_scale'],
174 | help='init opt mode for activation')
175 |
176 | parser.add_argument('--prob', default=0.5, type=float)
177 | parser.add_argument('--input_prob', default=0.5, type=float)
178 | parser.add_argument('--lamb_r', default=0.1, type=float, help='hyper-parameter for regularization')
179 | parser.add_argument('--T', default=4.0, type=float, help='temperature coefficient for KL divergence')
180 | parser.add_argument('--bn_lr', default=1e-3, type=float, help='learning rate for DC')
181 | parser.add_argument('--lamb_c', default=0.02, type=float, help='hyper-parameter for DC')
182 | args = parser.parse_args()
183 |
184 | seed_all(args.seed)
185 | # build imagenet data loader
186 | train_loader, test_loader = build_imagenet_data(batch_size=args.batch_size, workers=args.workers,
187 | data_path=args.data_path)
188 | # load model
189 |     cnn = getattr(hubconf, args.arch)(pretrained=True)
190 | cnn.cuda()
191 | cnn.eval()
192 | fp_model = copy.deepcopy(cnn)
193 | fp_model.cuda()
194 | fp_model.eval()
195 |
196 | # build quantization parameters
197 | wq_params = {'n_bits': args.n_bits_w, 'channel_wise': args.channel_wise, 'scale_method': args.init_wmode}
198 | aq_params = {'n_bits': args.n_bits_a, 'channel_wise': False, 'scale_method': args.init_amode,
199 | 'leaf_param': True, 'prob': args.prob}
200 |
201 | fp_model = QuantModel(model=fp_model, weight_quant_params=wq_params, act_quant_params=aq_params, is_fusing=False)
202 | fp_model.cuda()
203 | fp_model.eval()
204 | fp_model.set_quant_state(False, False)
205 | qnn = QuantModel(model=cnn, weight_quant_params=wq_params, act_quant_params=aq_params)
206 | qnn.cuda()
207 | qnn.eval()
208 | if not args.disable_8bit_head_stem:
209 | print('Setting the first and the last layer to 8-bit')
210 | qnn.set_first_last_layer_to_8bit()
211 |
212 | qnn.disable_network_output_quantization()
213 |     print('The quantized model is shown below:')
214 | print(qnn)
215 | cali_data, cali_target = get_train_samples(train_loader, num_samples=args.num_samples)
216 | device = next(qnn.parameters()).device
217 |
218 | # Kwargs for weight rounding calibration
219 | kwargs = dict(cali_data=cali_data, iters=args.iters_w, weight=args.weight,
220 | b_range=(args.b_start, args.b_end), warmup=args.warmup, opt_mode='mse',
221 | lr=args.lr, input_prob=args.input_prob, keep_gpu=not args.keep_cpu,
222 | lamb_r=args.lamb_r, T=args.T, bn_lr=args.bn_lr, lamb_c=args.lamb_c)
223 |
224 |
225 |     # init weight quantizer
226 | set_weight_quantize_params(qnn)
227 |
228 | def set_weight_act_quantize_params(module, fp_module):
229 | if isinstance(module, QuantModule):
230 | layer_reconstruction(qnn, fp_model, module, fp_module, **kwargs)
231 | elif isinstance(module, BaseQuantBlock):
232 | block_reconstruction(qnn, fp_model, module, fp_module, **kwargs)
233 | else:
234 | raise NotImplementedError
235 | def recon_model(model: nn.Module, fp_model: nn.Module):
236 | """
237 | Block reconstruction. For the first and last layers, we can only apply layer reconstruction.
238 | """
239 | for (name, module), (_, fp_module) in zip(model.named_children(), fp_model.named_children()):
240 | if isinstance(module, QuantModule):
241 | print('Reconstruction for layer {}'.format(name))
242 | set_weight_act_quantize_params(module, fp_module)
243 | elif isinstance(module, BaseQuantBlock):
244 | print('Reconstruction for block {}'.format(name))
245 | set_weight_act_quantize_params(module, fp_module)
246 | else:
247 | recon_model(module, fp_module)
248 | # Start calibration
249 | recon_model(qnn, fp_model)
250 |
251 | qnn.set_quant_state(weight_quant=True, act_quant=True)
252 | print('Full quantization (W{}A{}) accuracy: {}'.format(args.n_bits_w, args.n_bits_a,
253 | validate_model(test_loader, qnn)))
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/PD-Quant/e12038edd0c39d09772279dba90c43ce9cb5170d/models/__init__.py
--------------------------------------------------------------------------------
/models/mnasnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __all__ = ['mnasnet']
5 |
6 |
7 | class _InvertedResidual(nn.Module):
8 |
9 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor):
10 | super(_InvertedResidual, self).__init__()
11 | assert stride in [1, 2]
12 | assert kernel_size in [3, 5]
13 | mid_ch = in_ch * expansion_factor
14 | self.apply_residual = (in_ch == out_ch and stride == 1)
15 | self.layers = nn.Sequential(
16 | # Pointwise
17 | nn.Conv2d(in_ch, mid_ch, 1, bias=False),
18 | BN(mid_ch),
19 | nn.ReLU(inplace=True),
20 | # Depthwise
21 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
22 | stride=stride, groups=mid_ch, bias=False),
23 | BN(mid_ch),
24 | nn.ReLU(inplace=True),
25 | # Linear pointwise. Note that there's no activation.
26 | nn.Conv2d(mid_ch, out_ch, 1, bias=False),
27 | BN(out_ch))
28 |
29 | def forward(self, input):
30 | if self.apply_residual:
31 | return self.layers(input) + input
32 | else:
33 | return self.layers(input)
34 |
35 |
36 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats):
37 | """ Creates a stack of inverted residuals. """
38 | assert repeats >= 1
39 | # First one has no skip, because feature map size changes.
40 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor)
41 | remaining = []
42 | for _ in range(1, repeats):
43 | remaining.append(
44 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor))
45 | return nn.Sequential(first, *remaining)
46 |
47 |
48 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
49 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default
50 | bias, will round up, unless the number is no more than 10% greater than the
51 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
52 | assert 0.0 < round_up_bias < 1.0
53 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
54 | return new_val if new_val >= round_up_bias * val else new_val + divisor
55 |
56 |
57 | def _get_depths(scale):
58 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up
59 | rather than down. """
60 | depths = [32, 16, 24, 40, 80, 96, 192, 320]
61 | return [_round_to_multiple_of(depth * scale, 8) for depth in depths]
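# e.g. _get_depths(2.0) -> [64, 32, 48, 80, 160, 192, 384, 640], matching the
# default scale=2.0 ("MNasx2") configuration used below.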
62 |
63 |
64 | class MNASNet(torch.nn.Module):
65 | # Version 2 adds depth scaling in the initial stages of the network.
66 | _version = 2
67 |
68 | def __init__(self, scale=2.0, num_classes=1000, dropout=0.0):
69 | super(MNASNet, self).__init__()
70 |
71 | global BN
72 | BN = nn.BatchNorm2d
73 |
74 | assert scale > 0.0
75 | self.scale = scale
76 | self.num_classes = num_classes
77 | depths = _get_depths(scale)
78 | layers = [
79 | # First layer: regular conv.
80 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
81 | BN(depths[0]),
82 | nn.ReLU(inplace=True),
83 | # Depthwise separable, no skip.
84 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
85 | groups=depths[0], bias=False),
86 | BN(depths[0]),
87 | nn.ReLU(inplace=True),
88 | nn.Conv2d(depths[0], depths[1], 1,
89 | padding=0, stride=1, bias=False),
90 | BN(depths[1]),
91 | # MNASNet blocks: stacks of inverted residuals.
92 | _stack(depths[1], depths[2], 3, 2, 3, 3),
93 | _stack(depths[2], depths[3], 5, 2, 3, 3),
94 | _stack(depths[3], depths[4], 5, 2, 6, 3),
95 | _stack(depths[4], depths[5], 3, 1, 6, 2),
96 | _stack(depths[5], depths[6], 5, 2, 6, 4),
97 | _stack(depths[6], depths[7], 3, 1, 6, 1),
98 | # Final mapping to classifier input.
99 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
100 | BN(1280),
101 | nn.ReLU(inplace=True),
102 | ]
103 | self.layers = nn.Sequential(*layers)
104 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
105 | nn.Linear(1280, num_classes))
106 | self._initialize_weights()
107 |
108 | def forward(self, x):
109 | x = self.layers(x)
110 | # Equivalent to global avgpool and removing H and W dimensions.
111 | x = x.mean([2, 3])
112 | x = self.classifier(x)
113 | return x
114 |
115 | def _initialize_weights(self):
116 | for m in self.modules():
117 | if isinstance(m, nn.Conv2d):
118 | nn.init.kaiming_normal_(m.weight, mode="fan_out",
119 | nonlinearity="relu")
120 | if m.bias is not None:
121 | nn.init.zeros_(m.bias)
122 | elif isinstance(m, nn.BatchNorm2d):
123 | nn.init.ones_(m.weight)
124 | nn.init.zeros_(m.bias)
125 | elif isinstance(m, nn.Linear):
126 | nn.init.kaiming_uniform_(m.weight, mode="fan_out",
127 | nonlinearity="sigmoid")
128 | nn.init.zeros_(m.bias)
129 |
130 |
131 | def mnasnet(**kwargs):
132 | model = MNASNet(**kwargs)
133 | return model
134 |
135 |
--------------------------------------------------------------------------------
/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 |
5 |
6 | def conv_bn(inp, oup, stride):
7 | return nn.Sequential(
8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
9 | nn.BatchNorm2d(oup),
10 | nn.ReLU6(inplace=True)
11 | )
12 |
13 |
14 | def conv_1x1_bn(inp, oup):
15 | return nn.Sequential(
16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
17 | nn.BatchNorm2d(oup),
18 | nn.ReLU6(inplace=True)
19 | )
20 |
21 |
22 | class InvertedResidual(nn.Module):
23 | def __init__(self, inp, oup, stride, expand_ratio):
24 | super(InvertedResidual, self).__init__()
25 | self.stride = stride
26 | assert stride in [1, 2]
27 |
28 | hidden_dim = round(inp * expand_ratio)
29 | self.use_res_connect = self.stride == 1 and inp == oup
30 | self.expand_ratio = expand_ratio
31 | if expand_ratio == 1:
32 | self.conv = nn.Sequential(
33 | # dw
34 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
35 | nn.BatchNorm2d(hidden_dim),
36 | nn.ReLU6(inplace=True),
37 | # pw-linear
38 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
39 | nn.BatchNorm2d(oup),
40 | )
41 | else:
42 | self.conv = nn.Sequential(
43 | # pw
44 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
45 | nn.BatchNorm2d(hidden_dim),
46 | nn.ReLU6(inplace=True),
47 | # dw
48 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
49 | nn.BatchNorm2d(hidden_dim),
50 | nn.ReLU6(inplace=True),
51 | # pw-linear
52 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
53 | nn.BatchNorm2d(oup),
54 | )
55 |
56 | def forward(self, x):
57 | if self.use_res_connect:
58 | return x + self.conv(x)
59 | else:
60 | return self.conv(x)
61 |
62 |
63 | class MobileNetV2(nn.Module):
64 | def __init__(self, n_class=1000, input_size=224, width_mult=1., dropout=0.0):
65 | super(MobileNetV2, self).__init__()
66 | block = InvertedResidual
67 | input_channel = 32
68 | last_channel = 1280
69 |         inverted_residual_setting = [
70 |             # t: expansion factor, c: output channels, n: number of blocks, s: stride of the first block
71 | [1, 16, 1, 1],
72 | [6, 24, 2, 2],
73 | [6, 32, 3, 2],
74 | [6, 64, 4, 2],
75 | [6, 96, 3, 1],
76 | [6, 160, 3, 2],
77 | [6, 320, 1, 1],
78 | ]
79 |
80 | # building first layer
81 | assert input_size % 32 == 0
82 | input_channel = int(input_channel * width_mult)
83 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
84 | self.features = [conv_bn(3, input_channel, 2)]
85 | # building inverted residual blocks
86 |         for t, c, n, s in inverted_residual_setting:
87 | output_channel = int(c * width_mult)
88 | for i in range(n):
89 | if i == 0:
90 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
91 | else:
92 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
93 | input_channel = output_channel
94 | # building last several layers
95 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
96 | # self.features.append(nn.AvgPool2d(input_size // 32))
97 | # make it nn.Sequential
98 | self.features = nn.Sequential(*self.features)
99 |
100 | # building classifier
101 | self.classifier = nn.Sequential(
102 | nn.Dropout(dropout),
103 | nn.Linear(self.last_channel, n_class),
104 | )
105 |
106 | self._initialize_weights()
107 |
108 | def forward(self, x):
109 | x = self.features(x)
110 | x = x.mean([2, 3])
111 | x = self.classifier(x)
112 | return x
113 |
114 | def _initialize_weights(self):
115 | for m in self.modules():
116 | if isinstance(m, nn.Conv2d):
117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118 | m.weight.data.normal_(0, math.sqrt(2. / n))
119 | if m.bias is not None:
120 | m.bias.data.zero_()
121 | elif isinstance(m, nn.BatchNorm2d):
122 | m.weight.data.fill_(1)
123 | m.bias.data.zero_()
124 | elif isinstance(m, nn.Linear):
125 | n = m.weight.size(1)
126 | m.weight.data.normal_(0, 0.01)
127 | m.bias.data.zero_()
128 |
129 |
130 | def mobilenetv2(**kwargs):
131 | """
132 | Constructs a MobileNetV2 model.
133 | """
134 | model = MobileNetV2(**kwargs)
135 | return model
--------------------------------------------------------------------------------
/models/regnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import numpy as np
4 | import torch.nn as nn
5 | import math
6 |
7 | regnetX_200M_config = {'WA': 36.44, 'W0': 24, 'WM': 2.49, 'DEPTH': 13, 'GROUP_W': 8, 'SE_ON': False}
8 | regnetX_400M_config = {'WA': 24.48, 'W0': 24, 'WM': 2.54, 'DEPTH': 22, 'GROUP_W': 16, 'SE_ON': False}
9 | regnetX_600M_config = {'WA': 36.97, 'W0': 48, 'WM': 2.24, 'DEPTH': 16, 'GROUP_W': 24, 'SE_ON': False}
10 | regnetX_800M_config = {'WA': 35.73, 'W0': 56, 'WM': 2.28, 'DEPTH': 16, 'GROUP_W': 16, 'SE_ON': False}
11 | regnetX_1600M_config = {'WA': 34.01, 'W0': 80, 'WM': 2.25, 'DEPTH': 18, 'GROUP_W': 24, 'SE_ON': False}
12 | regnetX_3200M_config = {'WA': 26.31, 'W0': 88, 'WM': 2.25, 'DEPTH': 25, 'GROUP_W': 48, 'SE_ON': False}
13 | regnetX_4000M_config = {'WA': 38.65, 'W0': 96, 'WM': 2.43, 'DEPTH': 23, 'GROUP_W': 40, 'SE_ON': False}
14 | regnetX_6400M_config = {'WA': 60.83, 'W0': 184, 'WM': 2.07, 'DEPTH': 17, 'GROUP_W': 56, 'SE_ON': False}
15 | regnetY_200M_config = {'WA': 36.44, 'W0': 24, 'WM': 2.49, 'DEPTH': 13, 'GROUP_W': 8, 'SE_ON': True}
16 | regnetY_400M_config = {'WA': 27.89, 'W0': 48, 'WM': 2.09, 'DEPTH': 16, 'GROUP_W': 8, 'SE_ON': True}
17 | regnetY_600M_config = {'WA': 32.54, 'W0': 48, 'WM': 2.32, 'DEPTH': 15, 'GROUP_W': 16, 'SE_ON': True}
18 | regnetY_800M_config = {'WA': 38.84, 'W0': 56, 'WM': 2.4, 'DEPTH': 14, 'GROUP_W': 16, 'SE_ON': True}
19 | regnetY_1600M_config = {'WA': 20.71, 'W0': 48, 'WM': 2.65, 'DEPTH': 27, 'GROUP_W': 24, 'SE_ON': True}
20 | regnetY_3200M_config = {'WA': 42.63, 'W0': 80, 'WM': 2.66, 'DEPTH': 21, 'GROUP_W': 24, 'SE_ON': True}
21 | regnetY_4000M_config = {'WA': 31.41, 'W0': 96, 'WM': 2.24, 'DEPTH': 22, 'GROUP_W': 64, 'SE_ON': True}
22 | regnetY_6400M_config = {'WA': 33.22, 'W0': 112, 'WM': 2.27, 'DEPTH': 25, 'GROUP_W': 72, 'SE_ON': True}
23 |
24 |
25 | BN = nn.BatchNorm2d
26 |
27 | __all__ = ['regnetx_200m', 'regnetx_400m', 'regnetx_600m', 'regnetx_800m',
28 | 'regnetx_1600m', 'regnetx_3200m', 'regnetx_4000m', 'regnetx_6400m',
29 | 'regnety_200m', 'regnety_400m', 'regnety_600m', 'regnety_800m',
30 | 'regnety_1600m', 'regnety_3200m', 'regnety_4000m', 'regnety_6400m']
31 |
32 |
33 | class SimpleStemIN(nn.Module):
34 | """Simple stem for ImageNet."""
35 |
36 | def __init__(self, in_w, out_w):
37 | super(SimpleStemIN, self).__init__()
38 | self._construct(in_w, out_w)
39 |
40 | def _construct(self, in_w, out_w):
41 | # 3x3, BN, ReLU
42 | self.conv = nn.Conv2d(
43 | in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False
44 | )
45 | self.bn = BN(out_w)
46 | self.relu = nn.ReLU(True)
47 |
48 | def forward(self, x):
49 | for layer in self.children():
50 | x = layer(x)
51 | return x
52 |
53 |
54 | class SE(nn.Module):
55 | """Squeeze-and-Excitation (SE) block"""
56 |
57 | def __init__(self, w_in, w_se):
58 | super(SE, self).__init__()
59 | self._construct(w_in, w_se)
60 |
61 | def _construct(self, w_in, w_se):
62 | # AvgPool
63 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
64 | # FC, Activation, FC, Sigmoid
65 | self.f_ex = nn.Sequential(
66 | nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
67 | nn.ReLU(inplace=True),
68 | nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
69 | nn.Sigmoid(),
70 | )
71 |
72 | def forward(self, x):
73 | return x * self.f_ex(self.avg_pool(x))
74 |
75 |
76 | class BottleneckTransform(nn.Module):
77 |     """Bottleneck transformation: 1x1, 3x3, 1x1"""
78 |
79 | def __init__(self, w_in, w_out, stride, bm, gw, se_r):
80 | super(BottleneckTransform, self).__init__()
81 | self._construct(w_in, w_out, stride, bm, gw, se_r)
82 |
83 | def _construct(self, w_in, w_out, stride, bm, gw, se_r):
84 | # Compute the bottleneck width
85 | w_b = int(round(w_out * bm))
86 | # Compute the number of groups
87 | num_gs = w_b // gw
88 | # 1x1, BN, ReLU
89 | self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
90 | self.a_bn = BN(w_b)
91 | self.a_relu = nn.ReLU(True)
92 | # 3x3, BN, ReLU
93 | self.b = nn.Conv2d(
94 | w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False
95 | )
96 | self.b_bn = BN(w_b)
97 | self.b_relu = nn.ReLU(True)
98 | # Squeeze-and-Excitation (SE)
99 | if se_r:
100 | w_se = int(round(w_in * se_r))
101 | self.se = SE(w_b, w_se)
102 | # 1x1, BN
103 | self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
104 | self.c_bn = BN(w_out)
105 | self.c_bn.final_bn = True
106 |
107 | def forward(self, x):
108 | for layer in self.children():
109 | x = layer(x)
110 | return x
111 |
112 |
113 | class ResBottleneckBlock(nn.Module):
114 | """Residual bottleneck block: x + F(x), F = bottleneck transform"""
115 |
116 | def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
117 | super(ResBottleneckBlock, self).__init__()
118 | self._construct(w_in, w_out, stride, bm, gw, se_r)
119 |
120 | def _add_skip_proj(self, w_in, w_out, stride):
121 | self.proj = nn.Conv2d(
122 | w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
123 | )
124 | self.bn = BN(w_out)
125 |
126 | def _construct(self, w_in, w_out, stride, bm, gw, se_r):
127 | # Use skip connection with projection if shape changes
128 | self.proj_block = (w_in != w_out) or (stride != 1)
129 | if self.proj_block:
130 | self._add_skip_proj(w_in, w_out, stride)
131 | self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
132 | self.relu = nn.ReLU(True)
133 |
134 | def forward(self, x):
135 | if self.proj_block:
136 | x = self.bn(self.proj(x)) + self.f(x)
137 | else:
138 | x = x + self.f(x)
139 | x = self.relu(x)
140 | return x
141 |
142 |
143 | class AnyHead(nn.Module):
144 | """AnyNet head."""
145 |
146 | def __init__(self, w_in, nc):
147 | super(AnyHead, self).__init__()
148 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
149 | self.fc = nn.Linear(w_in, nc, bias=True)
150 |
151 | def forward(self, x):
152 | x = self.avg_pool(x)
153 | x = x.view(x.size(0), -1)
154 | x = self.fc(x)
155 | return x
156 |
157 |
158 | class AnyStage(nn.Module):
159 | """AnyNet stage (sequence of blocks w/ the same output shape)."""
160 |
161 | def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
162 | super(AnyStage, self).__init__()
163 | self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r)
164 |
165 | def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
166 | # Construct the blocks
167 | for i in range(d):
168 | # Stride and w_in apply to the first block of the stage
169 | b_stride = stride if i == 0 else 1
170 | b_w_in = w_in if i == 0 else w_out
171 | # Construct the block
172 | self.add_module(
173 | "b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bm, gw, se_r)
174 | )
175 |
176 | def forward(self, x):
177 | for block in self.children():
178 | x = block(x)
179 | return x
180 |
181 |
182 | class AnyNet(nn.Module):
183 | """AnyNet model."""
184 |
185 | def __init__(self, **kwargs):
186 | super(AnyNet, self).__init__()
187 | if kwargs:
188 | self._construct(
189 | stem_w=kwargs["stem_w"],
190 | ds=kwargs["ds"],
191 | ws=kwargs["ws"],
192 | ss=kwargs["ss"],
193 | bms=kwargs["bms"],
194 | gws=kwargs["gws"],
195 | se_r=kwargs["se_r"],
196 | nc=kwargs["nc"],
197 | )
198 | for m in self.modules():
199 | if isinstance(m, nn.Conv2d):
200 | # Note that there is no bias due to BN
201 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
202 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
203 | elif isinstance(m, nn.BatchNorm2d):
204 | m.weight.data.fill_(1)
205 | m.bias.data.zero_()
206 | elif isinstance(m, nn.Linear):
207 | n = m.weight.size(1)
208 | m.weight.data.normal_(0, 1.0 / float(n))
209 | m.bias.data.zero_()
210 |
211 | def _construct(self, stem_w, ds, ws, ss, bms, gws, se_r, nc):
212 | # self.logger.info("Constructing AnyNet: ds={}, ws={}".format(ds, ws))
213 | # Generate dummy bot muls and gs for models that do not use them
214 | bms = bms if bms else [1.0 for _d in ds]
215 | gws = gws if gws else [1 for _d in ds]
216 | # Group params by stage
217 | stage_params = list(zip(ds, ws, ss, bms, gws))
218 | # Construct the stem
219 | self.stem = SimpleStemIN(3, stem_w)
220 | # Construct the stages
221 | block_fun = ResBottleneckBlock
222 | prev_w = stem_w
223 | for i, (d, w, s, bm, gw) in enumerate(stage_params):
224 | self.add_module(
225 | "s{}".format(i + 1), AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r)
226 | )
227 | prev_w = w
228 | # Construct the head
229 | self.head = AnyHead(w_in=prev_w, nc=nc)
230 |
231 | def forward(self, x):
232 | for module in self.children():
233 | x = module(x)
234 | return x
235 |
236 |
237 | def quantize_float(f, q):
238 | """Converts a float to closest non-zero int divisible by q."""
239 | return int(round(f / q) * q)
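#   e.g. quantize_float(104.0, 24) -> 96 (4 groups of width 24)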
240 |
241 |
242 | def adjust_ws_gs_comp(ws, bms, gs):
243 | """Adjusts the compatibility of widths and groups."""
244 | ws_bot = [int(w * b) for w, b in zip(ws, bms)]
245 | gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
246 | ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
247 | ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
248 | return ws, gs
249 |
250 |
251 | def get_stages_from_blocks(ws, rs):
252 | """Gets ws/ds of network at each stage from per block values."""
253 | ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
254 | ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
255 | s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
256 | s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
257 | return s_ws, s_ds
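#   e.g. ws = [48, 96, 96, 240], rs = [1, 1, 1, 1]
#        -> s_ws = [48, 96, 240] and s_ds = [1, 2, 1]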
258 |
259 |
260 | def generate_regnet(w_a, w_0, w_m, d, q=8):
261 |     """Generates per-block widths from RegNet parameters.
262 | 
263 |     args:
264 |         w_a (float): slope of the linear width parameterization
265 |         w_0 (int): initial width
266 |         w_m (float): width multiplier that controls quantization across stages
267 |         d (int): network depth (number of blocks)
268 |         q (int): widths are rounded to multiples of q
269 | 
270 |     procedure:
271 |         1. generate a linear parameterization for block widths: u_j = w_0 + w_a * j. Eq. (2)
272 |         2. compute the stage index of each block: s_j = log_{w_m}(u_j / w_0). Eq. (3)
273 |         3. compute each per-block width as w_0 * w_m^round(s_j) and quantize it to a multiple of q. Eq. (4)
274 | 
275 |     return:
276 |         ws (list of int): quantized per-block widths, grouped by stage
277 |         num_stages (int): total number of stages
278 |         max_stage (float): the maximal stage index
279 |         ws_cont (list of float): continuous (unquantized) per-block widths
280 |     """
281 | assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
282 | ws_cont = np.arange(d) * w_a + w_0
283 | ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
284 | ws = w_0 * np.power(w_m, ks)
285 | ws = np.round(np.divide(ws, q)) * q
286 | num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
287 | ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
288 | return ws, num_stages, max_stage, ws_cont
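# Worked example (a sketch; values follow from regnetX_600M_config above):
# generate_regnet(36.97, 48, 2.24, 16) groups the 16 blocks into 4 stages;
# after adjust_ws_gs_comp with group width 24, the per-stage widths become
# [48, 96, 240, 528] with depths [1, 3, 5, 7].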
289 |
290 |
291 | class RegNet(AnyNet):
292 |     """RegNet model class, based on
293 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_
294 |     """
295 |
296 | def __init__(self, cfg, bn=None):
297 | # Generate RegNet ws per block
298 | b_ws, num_s, _, _ = generate_regnet(
299 | cfg['WA'], cfg['W0'], cfg['WM'], cfg['DEPTH']
300 | )
301 | # Convert to per stage format
302 | ws, ds = get_stages_from_blocks(b_ws, b_ws)
303 | # Generate group widths and bot muls
304 | gws = [cfg['GROUP_W'] for _ in range(num_s)]
305 | bms = [1 for _ in range(num_s)]
306 | # Adjust the compatibility of ws and gws
307 | ws, gws = adjust_ws_gs_comp(ws, bms, gws)
308 | # Use the same stride for each stage, stride set to 2
309 | ss = [2 for _ in range(num_s)]
310 | # Use SE for RegNetY
311 | se_r = 0.25 if cfg['SE_ON'] else None
312 | # Construct the model
313 | STEM_W = 32
314 |
315 | global BN
316 |
317 | kwargs = {
318 | "stem_w": STEM_W,
319 | "ss": ss,
320 | "ds": ds,
321 | "ws": ws,
322 | "bms": bms,
323 | "gws": gws,
324 | "se_r": se_r,
325 | "nc": 1000,
326 | }
327 | super(RegNet, self).__init__(**kwargs)
328 |
329 |
330 | def regnetx_200m(**kwargs):
331 | """
332 | Constructs a RegNet-X model under 200M FLOPs.
333 | """
334 | model = RegNet(regnetX_200M_config, **kwargs)
335 | return model
336 |
337 |
338 | def regnetx_400m(**kwargs):
339 | """
340 | Constructs a RegNet-X model under 400M FLOPs.
341 | """
342 | model = RegNet(regnetX_400M_config, **kwargs)
343 | return model
344 |
345 |
346 | def regnetx_600m(**kwargs):
347 | """
348 | Constructs a RegNet-X model under 600M FLOPs.
349 | """
350 | model = RegNet(regnetX_600M_config, **kwargs)
351 | return model
352 |
353 |
354 | def regnetx_800m(**kwargs):
355 | """
356 | Constructs a RegNet-X model under 800M FLOPs.
357 | """
358 | model = RegNet(regnetX_800M_config, **kwargs)
359 | return model
360 |
361 |
362 | def regnetx_1600m(**kwargs):
363 | """
364 | Constructs a RegNet-X model under 1600M FLOPs.
365 | """
366 | model = RegNet(regnetX_1600M_config, **kwargs)
367 | return model
368 |
369 |
370 | def regnetx_3200m(**kwargs):
371 | """
372 | Constructs a RegNet-X model under 3200M FLOPs.
373 | """
374 | model = RegNet(regnetX_3200M_config, **kwargs)
375 | return model
376 |
377 |
378 | def regnetx_4000m(**kwargs):
379 | """
380 | Constructs a RegNet-X model under 4000M FLOPs.
381 | """
382 | model = RegNet(regnetX_4000M_config, **kwargs)
383 | return model
384 |
385 |
386 | def regnetx_6400m(**kwargs):
387 | """
388 | Constructs a RegNet-X model under 6400M FLOPs.
389 | """
390 | model = RegNet(regnetX_6400M_config, **kwargs)
391 | return model
392 |
393 |
394 | def regnety_200m(**kwargs):
395 | """
396 | Constructs a RegNet-Y model under 200M FLOPs.
397 | """
398 | model = RegNet(regnetY_200M_config, **kwargs)
399 | return model
400 |
401 |
402 | def regnety_400m(**kwargs):
403 | """
404 | Constructs a RegNet-Y model under 400M FLOPs.
405 | """
406 | model = RegNet(regnetY_400M_config, **kwargs)
407 | return model
408 |
409 |
410 | def regnety_600m(**kwargs):
411 | """
412 | Constructs a RegNet-Y model under 600M FLOPs.
413 | """
414 | model = RegNet(regnetY_600M_config, **kwargs)
415 | return model
416 |
417 |
418 | def regnety_800m(**kwargs):
419 | """
420 | Constructs a RegNet-Y model under 800M FLOPs.
421 | """
422 | model = RegNet(regnetY_800M_config, **kwargs)
423 | return model
424 |
425 |
426 | def regnety_1600m(**kwargs):
427 | """
428 | Constructs a RegNet-Y model under 1600M FLOPs.
429 | """
430 | model = RegNet(regnetY_1600M_config, **kwargs)
431 | return model
432 |
433 |
434 | def regnety_3200m(**kwargs):
435 | """
436 | Constructs a RegNet-Y model under 3200M FLOPs.
437 | """
438 | model = RegNet(regnetY_3200M_config, **kwargs)
439 | return model
440 |
441 |
442 | def regnety_4000m(**kwargs):
443 | """
444 | Constructs a RegNet-Y model under 4000M FLOPs.
445 | """
446 | model = RegNet(regnetY_4000M_config, **kwargs)
447 | return model
448 |
449 |
450 | def regnety_6400m(**kwargs):
451 | """
452 | Constructs a RegNet-Y model under 6400M FLOPs.
453 | """
454 | model = RegNet(regnetY_6400M_config, **kwargs)
455 | return model
456 |
457 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
12 | """3x3 convolution with padding"""
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=dilation, groups=groups, bias=False, dilation=dilation)
15 |
16 |
17 | def conv1x1(in_planes, out_planes, stride=1):
18 | """1x1 convolution"""
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 | __constants__ = ['downsample']
25 |
26 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
27 | base_width=64, dilation=1, norm_layer=None):
28 | super(BasicBlock, self).__init__()
29 | if norm_layer is None:
30 | norm_layer = BN
31 | if groups != 1 or base_width != 64:
32 | raise ValueError(
33 | 'BasicBlock only supports groups=1 and base_width=64')
34 | if dilation > 1:
35 | raise NotImplementedError(
36 | "Dilation > 1 not supported in BasicBlock")
37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
38 | self.conv1 = conv3x3(inplanes, planes, stride)
39 | self.bn1 = norm_layer(planes)
40 | self.relu1 = nn.ReLU(inplace=True)
41 | self.conv2 = conv3x3(planes, planes)
42 | self.bn2 = norm_layer(planes)
43 | self.downsample = downsample
44 | self.relu2 = nn.ReLU(inplace=True)
45 | self.stride = stride
46 |
47 | def forward(self, x):
48 | identity = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu1(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 |
57 | if self.downsample is not None:
58 | identity = self.downsample(x)
59 |
60 | out += identity
61 | out = self.relu2(out)
62 |
63 | return out
64 |
65 |
66 | class Bottleneck(nn.Module):
67 | expansion = 4
68 | __constants__ = ['downsample']
69 |
70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
71 | base_width=64, dilation=1, norm_layer=None):
72 | super(Bottleneck, self).__init__()
73 | if norm_layer is None:
74 | norm_layer = BN
75 | width = int(planes * (base_width / 64.)) * groups
76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
77 | self.conv1 = conv1x1(inplanes, width)
78 | self.bn1 = norm_layer(width)
79 | self.relu1 = nn.ReLU(inplace=True)
80 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
81 | self.bn2 = norm_layer(width)
82 | self.relu2 = nn.ReLU(inplace=True)
83 | self.conv3 = conv1x1(width, planes * self.expansion)
84 | self.bn3 = norm_layer(planes * self.expansion)
85 | self.relu3 = nn.ReLU(inplace=True)
86 | self.downsample = downsample
87 | self.stride = stride
88 |
89 | def forward(self, x):
90 | identity = x
91 |
92 | out = self.conv1(x)
93 | out = self.bn1(out)
94 | out = self.relu1(out)
95 |
96 | out = self.conv2(out)
97 | out = self.bn2(out)
98 | out = self.relu2(out)
99 |
100 | out = self.conv3(out)
101 | out = self.bn3(out)
102 |
103 | if self.downsample is not None:
104 | identity = self.downsample(x)
105 |
106 | out += identity
107 | out = self.relu3(out)
108 |
109 | return out
110 |
111 |
112 | class ResNet(nn.Module):
113 |
114 | def __init__(self,
115 | block,
116 | layers,
117 | num_classes=1000,
118 | zero_init_residual=False,
119 | groups=1,
120 | width_per_group=64,
121 | replace_stride_with_dilation=None,
122 | deep_stem=False,
123 | avg_down=False):
124 |
125 | super(ResNet, self).__init__()
126 |
127 |         # BN is a module-level global used by BasicBlock/Bottleneck as their default norm layer
128 |         global BN
129 |         BN = torch.nn.BatchNorm2d
130 |         norm_layer = BN
131 |
132 | self._norm_layer = norm_layer
133 |
134 | self.inplanes = 64
135 | self.dilation = 1
136 | self.deep_stem = deep_stem
137 | self.avg_down = avg_down
138 |
139 | if replace_stride_with_dilation is None:
140 | # each element in the tuple indicates if we should replace
141 | # the 2x2 stride with a dilated convolution instead
142 | replace_stride_with_dilation = [False, False, False]
143 | if len(replace_stride_with_dilation) != 3:
144 | raise ValueError("replace_stride_with_dilation should be None "
145 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
146 | self.groups = groups
147 | self.base_width = width_per_group
148 |
149 | if self.deep_stem:
150 | self.conv1 = nn.Sequential(
151 | nn.Conv2d(3, 32, kernel_size=3, stride=2,
152 | padding=1, bias=False),
153 | norm_layer(32),
154 | nn.ReLU(inplace=True),
155 | nn.Conv2d(32, 32, kernel_size=3, stride=1,
156 | padding=1, bias=False),
157 | norm_layer(32),
158 | nn.ReLU(inplace=True),
159 | nn.Conv2d(32, 64, kernel_size=3, stride=1,
160 | padding=1, bias=False),
161 | )
162 | else:
163 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
164 | stride=2, padding=3, bias=False)
165 |
166 | self.bn1 = norm_layer(self.inplanes)
167 | self.relu = nn.ReLU(inplace=True)
168 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
169 | self.layer1 = self._make_layer(block, 64, layers[0])
170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
171 | dilate=replace_stride_with_dilation[0])
172 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
173 | dilate=replace_stride_with_dilation[1])
174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
175 | dilate=replace_stride_with_dilation[2])
176 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
177 | self.fc = nn.Linear(512 * block.expansion, num_classes)
178 |
179 | for m in self.modules():
180 | if isinstance(m, nn.Conv2d):
181 | nn.init.kaiming_normal_(
182 | m.weight, mode='fan_out', nonlinearity='relu')
183 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
184 | nn.init.constant_(m.weight, 1)
185 | nn.init.constant_(m.bias, 0)
186 |
187 | # Zero-initialize the last BN in each residual branch,
188 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
189 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
190 | if zero_init_residual:
191 | for m in self.modules():
192 | if isinstance(m, Bottleneck):
193 | nn.init.constant_(m.bn3.weight, 0)
194 | elif isinstance(m, BasicBlock):
195 | nn.init.constant_(m.bn2.weight, 0)
196 |
197 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
198 | norm_layer = self._norm_layer
199 | downsample = None
200 | previous_dilation = self.dilation
201 | if dilate:
202 | self.dilation *= stride
203 | stride = 1
204 | if stride != 1 or self.inplanes != planes * block.expansion:
205 | if self.avg_down:
206 | downsample = nn.Sequential(
207 | nn.AvgPool2d(stride, stride=stride,
208 | ceil_mode=True, count_include_pad=False),
209 | conv1x1(self.inplanes, planes * block.expansion),
210 | norm_layer(planes * block.expansion),
211 | )
212 | else:
213 | downsample = nn.Sequential(
214 | conv1x1(self.inplanes, planes * block.expansion, stride),
215 | norm_layer(planes * block.expansion),
216 | )
217 |
218 | layers = []
219 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
220 | self.base_width, previous_dilation, norm_layer))
221 | self.inplanes = planes * block.expansion
222 | for _ in range(1, blocks):
223 | layers.append(block(self.inplanes, planes, groups=self.groups,
224 | base_width=self.base_width, dilation=self.dilation,
225 | norm_layer=norm_layer))
226 |
227 | return nn.Sequential(*layers)
228 |
229 | def _forward_impl(self, x):
230 | # See note [TorchScript super()]
231 | x = self.conv1(x)
232 | x = self.bn1(x)
233 | x = self.relu(x)
234 | x = self.maxpool(x)
235 | x = self.layer1(x)
236 | x = self.layer2(x)
237 | x = self.layer3(x)
238 | x = self.layer4(x)
239 |
240 | x = self.avgpool(x)
241 | x = torch.flatten(x, 1)
242 | x = self.fc(x)
243 |
244 | return x
245 |
246 | def forward(self, x):
247 | return self._forward_impl(x)
248 |
249 |
250 | def resnet18(**kwargs):
251 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
252 | return model
253 |
254 |
255 | def resnet34(**kwargs):
256 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
257 | return model
258 |
259 |
260 | def resnet50(**kwargs):
261 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
262 | return model
263 |
264 |
265 | def resnet101(**kwargs):
266 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
267 | return model
268 |
269 |
270 | def resnet152(**kwargs):
271 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
272 | return model
273 |
274 |
275 | def resnext50_32x4d(**kwargs):
276 | kwargs['groups'] = 32
277 | kwargs['width_per_group'] = 4
278 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
279 | return model
280 |
281 |
282 | def resnext101_32x8d(**kwargs):
283 | kwargs['groups'] = 32
284 | kwargs['width_per_group'] = 8
285 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
286 | return model
287 |
288 |
289 | def wide_resnet50_2(**kwargs):
290 | kwargs['width_per_group'] = 64 * 2
291 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
292 | return model
293 |
294 |
295 | def wide_resnet101_2(**kwargs):
296 | kwargs['width_per_group'] = 64 * 2
297 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
298 | return model
299 |
--------------------------------------------------------------------------------
/quant/__init__.py:
--------------------------------------------------------------------------------
1 | from .block_recon import block_reconstruction
2 | from .layer_recon import layer_reconstruction
3 | from .quant_block import BaseQuantBlock
4 | from .quant_layer import QuantModule
5 | from .quant_model import QuantModel
6 | from .set_weight_quantize_params import set_weight_quantize_params, get_init, save_quantized_weight
7 | from .set_act_quantize_params import set_act_quantize_params
8 |
--------------------------------------------------------------------------------
/quant/adaptive_rounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .quant_layer import UniformAffineQuantizer, round_ste
4 |
5 |
6 | class AdaRoundQuantizer(nn.Module):
7 | """
8 | Adaptive Rounding Quantizer, used to optimize the rounding policy
9 | by reconstructing the intermediate output.
10 | Based on
11 | Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568
12 |
13 | :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer
14 | :param round_mode: controls the forward pass in this quantizer
15 |     :param weight_tensor: weight tensor used to initialize alpha
16 | """
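    # Summary of the scheme implemented below: the soft rounding target is the
    # rectified sigmoid h(alpha) = clip(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1)
    # with (gamma, zeta) = (-0.1, 1.1). While soft_targets is True the quantizer
    # computes floor(w / delta) + h(alpha); once hardened, floor(w / delta) + (alpha >= 0).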
17 |
18 |     def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_hard_sigmoid'):
19 | super(AdaRoundQuantizer, self).__init__()
20 | # copying all attributes from UniformAffineQuantizer
21 | self.n_bits = uaq.n_bits
22 | self.sym = uaq.sym
23 | self.delta = uaq.delta
24 | self.zero_point = uaq.zero_point
25 | self.n_levels = uaq.n_levels
26 |
27 | self.round_mode = round_mode
28 | self.alpha = None
29 | self.soft_targets = False
30 |
31 | # params for sigmoid function
32 | self.gamma, self.zeta = -0.1, 1.1
33 | self.beta = 2/3
34 | self.init_alpha(x=weight_tensor.clone())
35 |
36 | def forward(self, x):
37 | if self.round_mode == 'nearest':
38 | x_int = torch.round(x / self.delta)
39 | elif self.round_mode == 'nearest_ste':
40 | x_int = round_ste(x / self.delta)
41 | elif self.round_mode == 'stochastic':
42 | x_floor = torch.floor(x / self.delta)
43 | rest = (x / self.delta) - x_floor # rest of rounding
44 | x_int = x_floor + torch.bernoulli(rest)
45 | print('Draw stochastic sample')
46 | elif self.round_mode == 'learned_hard_sigmoid':
47 | x_floor = torch.floor(x / self.delta)
48 | if self.soft_targets:
49 | x_int = x_floor + self.get_soft_targets()
50 | else:
51 | x_int = x_floor + (self.alpha >= 0).float()
52 | else:
53 | raise ValueError('Wrong rounding mode')
54 |
55 | x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1)
56 | x_float_q = (x_quant - self.zero_point) * self.delta
57 |
58 | return x_float_q
59 |
60 | def get_soft_targets(self):
61 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1)
62 |
63 | def init_alpha(self, x: torch.Tensor):
64 | x_floor = torch.floor(x / self.delta)
65 | if self.round_mode == 'learned_hard_sigmoid':
66 |             print('Init alpha so that soft rounding reproduces the FP32 weights')
67 | rest = (x / self.delta) - x_floor # rest of rounding [0, 1)
68 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest
69 | self.alpha = nn.Parameter(alpha)
70 | else:
71 | raise NotImplementedError
72 |
73 | @torch.jit.export
74 | def extra_repr(self):
75 | return 'bit={}'.format(self.n_bits)
76 |
--------------------------------------------------------------------------------
/quant/block_recon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .quant_layer import QuantModule, lp_loss
4 | from .quant_model import QuantModel
5 | from .quant_block import BaseQuantBlock, specials_unquantized
6 | from .adaptive_rounding import AdaRoundQuantizer
7 | from .set_weight_quantize_params import get_init, get_dc_fp_init
8 | from .set_act_quantize_params import set_act_quantize_params
9 |
10 | include = False
11 | def find_unquantized_module(model: torch.nn.Module, module_list: list = [], name_list: list = []):
12 |     """Collect the modules that follow the block being reconstructed; untrained quant modules are switched to run unquantized."""
13 | global include
14 | for name, module in model.named_children():
15 | if isinstance(module, (QuantModule, BaseQuantBlock)):
16 | if not module.trained:
17 | include = True
18 | module.set_quant_state(False,False)
19 | name_list.append(name)
20 | module_list.append(module)
21 | elif include and type(module) in specials_unquantized:
22 | name_list.append(name)
23 | module_list.append(module)
24 | else:
25 | find_unquantized_module(module, module_list, name_list)
26 |     return module_list[1:], name_list[1:]  # drop the first entry: the module currently being reconstructed
27 |
28 | def block_reconstruction(model: QuantModel, fp_model: QuantModel, block: BaseQuantBlock, fp_block: BaseQuantBlock,
29 | cali_data: torch.Tensor, batch_size: int = 32, iters: int = 20000, weight: float = 0.01,
30 | opt_mode: str = 'mse', b_range: tuple = (20, 2),
31 | warmup: float = 0.0, p: float = 2.0, lr: float = 4e-5,
32 | input_prob: float = 1.0, keep_gpu: bool = True,
33 | lamb_r: float = 0.2, T: float = 7.0, bn_lr: float = 1e-3, lamb_c=0.02):
34 | """
35 | Reconstruction to optimize the output from each block.
36 |
37 |     :param model: QuantModel under quantization
38 |     :param fp_model: full-precision reference model; fp_block is its block matching `block`
39 |     :param block: BaseQuantBlock that needs to be optimized
40 |     :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
41 |     :param batch_size: mini-batch size for reconstruction
42 |     :param iters: optimization iterations for reconstruction
43 |     :param weight: weight of the rounding regularization term
44 |     :param opt_mode: optimization mode (only 'mse' is supported)
45 |     :param b_range: temperature range
46 |     :param warmup: proportion of iterations without temperature scheduling
47 |     :param p: order of the L_p norm used for the reconstruction loss
48 |     :param lr: learning rate for learning the activation step size (delta)
49 |     :param input_prob: probability of keeping the quantized input (QDrop-style input mixing)
50 |     :param keep_gpu: keep cached data on GPU for faster optimization
51 |     :param lamb_r: hyper-parameter for the prediction-difference (KL) regularization
52 |     :param T: temperature coefficient for the KL divergence
53 |     :param bn_lr: learning rate for distribution correction (DC); lamb_c weights its constraint term
54 | """
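    # A hypothetical usage sketch (the loop and variable names are illustrative
    # assumptions, not this repo's actual driver code):
    #
    #   for blk, fp_blk in zip(quant_blocks, fp_blocks):
    #       block_reconstruction(model, fp_model, blk, fp_blk, cali_data,
    #                            batch_size=32, iters=20000, input_prob=0.5)
    #
    # where model / fp_model are QuantModel wrappers of the same network and
    # cali_data holds the calibration images.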
55 |
56 | '''get input and set scale'''
57 | cached_inps = get_init(model, block, cali_data, batch_size=batch_size,
58 | input_prob=True, keep_gpu=keep_gpu)
59 | cached_outs, cached_output, cur_syms = get_dc_fp_init(fp_model, fp_block, cali_data, batch_size=batch_size,
60 | input_prob=True, keep_gpu=keep_gpu, bn_lr=bn_lr, lamb=lamb_c)
61 | set_act_quantize_params(block, cali_data=cached_inps[:min(256, cached_inps.size(0))])
62 |
63 | '''set state'''
64 | cur_weight, cur_act = True, True
65 |
66 | global include
67 | module_list, name_list, include = [], [], False
68 | module_list, name_list = find_unquantized_module(model, module_list, name_list)
69 | block.set_quant_state(cur_weight, cur_act)
70 | for para in model.parameters():
71 | para.requires_grad = False
72 |
73 | '''set quantizer'''
74 | round_mode = 'learned_hard_sigmoid'
75 | # Replace weight quantizer to AdaRoundQuantizer
76 | w_para, a_para = [], []
77 | w_opt, a_opt = None, None
78 | scheduler, a_scheduler = None, None
79 |
80 | for module in block.modules():
81 | '''weight'''
82 | if isinstance(module, QuantModule):
83 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode,
84 | weight_tensor=module.org_weight.data)
85 | module.weight_quantizer.soft_targets = True
86 | w_para += [module.weight_quantizer.alpha]
87 | '''activation'''
88 | if isinstance(module, (QuantModule, BaseQuantBlock)):
89 | if module.act_quantizer.delta is not None:
90 | module.act_quantizer.delta = torch.nn.Parameter(torch.tensor(module.act_quantizer.delta))
91 | a_para += [module.act_quantizer.delta]
92 | '''set up drop'''
93 | module.act_quantizer.is_training = True
94 |
95 | if len(w_para) != 0:
96 | w_opt = torch.optim.Adam(w_para, lr=3e-3)
97 | if len(a_para) != 0:
98 | a_opt = torch.optim.Adam(a_para, lr=lr)
99 | a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=iters, eta_min=0.)
100 |
101 | loss_mode = 'relaxation'
102 | rec_loss = opt_mode
103 | loss_func = LossFunction(block, round_loss=loss_mode, weight=weight, max_count=iters, rec_loss=rec_loss,
104 | b_range=b_range, decay_start=0, warmup=warmup, p=p, lam=lamb_r, T=T)
105 | device = 'cuda'
106 | sz = cached_inps.size(0)
107 | for i in range(iters):
108 | idx = torch.randint(0, sz, (batch_size,))
109 | cur_inp = cached_inps[idx].to(device)
110 | cur_sym = cur_syms[idx].to(device)
111 | output_fp = cached_output[idx].to(device)
112 | cur_out = cached_outs[idx].to(device)
113 |         # QDrop-style input mixing; the batch is always doubled so the slicing below works
114 |         drop_inp = torch.where(torch.rand_like(cur_inp) < input_prob, cur_inp, cur_sym) if input_prob < 1.0 else cur_inp
115 |
116 |         cur_inp = torch.cat((drop_inp, cur_inp))
117 |
118 | if w_opt:
119 | w_opt.zero_grad()
120 | if a_opt:
121 | a_opt.zero_grad()
122 |
123 | out_all = block(cur_inp)
124 |
125 | '''forward for prediction difference'''
126 | out_drop = out_all[:batch_size]
127 | out_quant = out_all[batch_size:]
128 | output = out_quant
129 | for num, module in enumerate(module_list):
130 | # for ResNet and RegNet
131 | if name_list[num] == 'fc':
132 | output = torch.flatten(output, 1)
133 | # for MobileNet and MNasNet
134 | if isinstance(module, torch.nn.Dropout):
135 | output = output.mean([2, 3])
136 | output = module(output)
137 | err = loss_func(out_drop, cur_out, output, output_fp)
138 |
139 | err.backward(retain_graph=True)
140 | if w_opt:
141 | w_opt.step()
142 | if a_opt:
143 | a_opt.step()
144 | if scheduler:
145 | scheduler.step()
146 | if a_scheduler:
147 | a_scheduler.step()
148 | torch.cuda.empty_cache()
149 |
150 | for module in block.modules():
151 | if isinstance(module, QuantModule):
152 | '''weight '''
153 | module.weight_quantizer.soft_targets = False
154 | '''activation'''
155 | if isinstance(module, (QuantModule, BaseQuantBlock)):
156 | module.act_quantizer.is_training = False
157 | module.trained = True
158 | for module in fp_block.modules():
159 | if isinstance(module, (QuantModule, BaseQuantBlock)):
160 | module.trained = True
161 |
162 | class LossFunction:
163 | def __init__(self,
164 | block: BaseQuantBlock,
165 | round_loss: str = 'relaxation',
166 | weight: float = 1.,
167 | rec_loss: str = 'mse',
168 | max_count: int = 2000,
169 | b_range: tuple = (10, 2),
170 | decay_start: float = 0.0,
171 | warmup: float = 0.0,
172 | p: float = 2.,
173 | lam: float = 1.0,
174 | T: float = 7.0):
175 |
176 | self.block = block
177 | self.round_loss = round_loss
178 | self.weight = weight
179 | self.rec_loss = rec_loss
180 | self.loss_start = max_count * warmup
181 | self.p = p
182 | self.lam = lam
183 | self.T = T
184 |
185 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start,
186 | start_b=b_range[0], end_b=b_range[1])
187 | self.count = 0
188 | self.pd_loss = torch.nn.KLDivLoss(reduction='batchmean')
189 |
190 | def __call__(self, pred, tgt, output, output_fp):
191 | """
192 | Compute the total loss for adaptive rounding:
193 | rec_loss is the quadratic output reconstruction loss, round_loss is
194 | a regularization term to optimize the rounding policy, pd_loss is the
195 | prediction difference loss.
196 |
197 | :param pred: output from quantized model
198 | :param tgt: output from FP model
199 | :param output: prediction from quantized model
200 | :param output_fp: prediction from FP model
201 | :return: total loss function
202 | """
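        # In the notation of the code below, the total objective is
        #   total = lp_loss(pred, tgt, p)                                    (reconstruction)
        #         + weight * sum(1 - |2 * h(alpha) - 1|^b)                   (rounding relaxation)
        #         + KL(softmax(output_fp / T) || softmax(output / T)) / lam  (prediction difference)
        # with the temperature b annealed linearly by LinearTempDecay.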
203 | self.count += 1
204 | if self.rec_loss == 'mse':
205 | rec_loss = lp_loss(pred, tgt, p=self.p)
206 | else:
207 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))
208 |
209 | pd_loss = self.pd_loss(F.log_softmax(output / self.T, dim=1), F.softmax(output_fp / self.T, dim=1)) / self.lam
210 |
211 | b = self.temp_decay(self.count)
212 | if self.count < self.loss_start or self.round_loss == 'none':
213 | b = round_loss = 0
214 | elif self.round_loss == 'relaxation':
215 | round_loss = 0
216 | for name, module in self.block.named_modules():
217 | if isinstance(module, QuantModule):
218 | round_vals = module.weight_quantizer.get_soft_targets()
219 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
220 | else:
221 | raise NotImplementedError
222 |
223 | total_loss = rec_loss + round_loss + pd_loss
224 | if self.count % 500 == 0:
225 | print('Total loss:\t{:.3f} (rec:{:.3f}, pd:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
226 | float(total_loss), float(rec_loss), float(pd_loss), float(round_loss), b, self.count))
227 | return total_loss
228 |
229 |
230 | class LinearTempDecay:
231 | def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
232 | self.t_max = t_max
233 | self.start_decay = rel_start_decay * t_max
234 | self.start_b = start_b
235 | self.end_b = end_b
236 |
237 | def __call__(self, t):
238 | """
239 |         Linear decay schedule for temperature b: b(t) = end_b + (start_b - end_b) * max(0, 1 - rel_t).
240 | :param t: the current time step
241 | :return: scheduled temperature
242 | """
243 | if t < self.start_decay:
244 | return self.start_b
245 | else:
246 | rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
247 | # return self.end_b + 0.5 * (self.start_b - self.end_b) * (1 + np.cos(rel_t * np.pi))
248 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
249 |
--------------------------------------------------------------------------------
/quant/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | from .quant_layer import QuantModule, Union, lp_loss
6 | from .quant_model import QuantModel
7 | from .quant_block import BaseQuantBlock
8 | from tqdm import trange
9 |
10 |
11 | def save_dc_fp_data(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], cali_data: torch.Tensor,
12 | batch_size: int = 32, keep_gpu: bool = True,
13 | input_prob: bool = False, lamb=50, bn_lr=1e-3):
14 |     """Cache layer/block outputs after distribution correction (DC), the FP model predictions, and (optionally) the corrected inputs."""
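    # GetDcFpLayerInpOut (below) corrects the cached input x by approximately solving,
    # per batch (notation paraphrased from its optimization loop):
    #   min_x  sum_l ( ||mu_l(x) - mu_l^BN||^2 + ||sigma_l(x) - sigma_l^BN||^2 )
    #          + lp_loss(x, x_sym) / lamb
    # where (mu_l^BN, sigma_l^BN) are the running statistics stored in BatchNorm layer l.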
15 | device = next(model.parameters()).device
16 | get_inp_out = GetDcFpLayerInpOut(model, layer, device=device, input_prob=input_prob, lamb=lamb, bn_lr=bn_lr)
17 | cached_batches = []
18 |
19 | print("Start correcting {} batches of data!".format(int(cali_data.size(0) / batch_size)))
20 | for i in trange(int(cali_data.size(0) / batch_size)):
21 | if input_prob:
22 | cur_out, out_fp, cur_sym = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size])
23 | cached_batches.append((cur_out.cpu(), out_fp.cpu(), cur_sym.cpu()))
24 | else:
25 | cur_out, out_fp = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size])
26 | cached_batches.append((cur_out.cpu(), out_fp.cpu()))
27 | cached_outs = torch.cat([x[0] for x in cached_batches])
28 | cached_outputs = torch.cat([x[1] for x in cached_batches])
29 | if input_prob:
30 | cached_sym = torch.cat([x[2] for x in cached_batches])
31 | torch.cuda.empty_cache()
32 | if keep_gpu:
33 | cached_outs = cached_outs.to(device)
34 | cached_outputs = cached_outputs.to(device)
35 | if input_prob:
36 | cached_sym = cached_sym.to(device)
37 | if input_prob:
38 | cached_outs.requires_grad = False
39 | cached_sym.requires_grad = False
40 | return cached_outs, cached_outputs, cached_sym
41 | return cached_outs, cached_outputs
42 |
43 |
44 | def save_inp_oup_data(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], cali_data: torch.Tensor,
45 | batch_size: int = 32, keep_gpu: bool = True,
46 | input_prob: bool = False):
47 | """
48 |     Save the input data of a particular layer/block over the calibration dataset
49 |     (the model runs with both weight and activation quantization enabled).
50 |
51 |     :param model: QuantModel
52 |     :param layer: QuantModule or QuantBlock
53 |     :param cali_data: calibration data set
54 |     :param batch_size: mini-batch size for calibration
55 |     :param keep_gpu: put saved data on GPU for faster optimization
56 |     :return: cached input data
57 |     """
59 | device = next(model.parameters()).device
60 | get_inp_out = GetLayerInpOut(model, layer, device=device, input_prob=input_prob)
61 | cached_batches = []
62 |
63 | for i in range(int(cali_data.size(0) / batch_size)):
64 | cur_inp = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size])
65 | cached_batches.append(cur_inp.cpu())
66 | cached_inps = torch.cat([x for x in cached_batches])
67 | torch.cuda.empty_cache()
68 | if keep_gpu:
69 | cached_inps = cached_inps.to(device)
70 |
71 | return cached_inps
72 |
73 |
74 | class StopForwardException(Exception):
75 | """
76 | Used to throw and catch an exception to stop traversing the graph
77 | """
78 | pass
79 |
80 |
81 | class DataSaverHook:
82 | """
83 | Forward hook that stores the input and output of a block
84 | """
85 |
86 | def __init__(self, store_input=False, store_output=False, stop_forward=False):
87 | self.store_input = store_input
88 | self.store_output = store_output
89 | self.stop_forward = stop_forward
90 |
91 | self.input_store = None
92 | self.output_store = None
93 |
94 | def __call__(self, module, input_batch, output_batch):
95 | if self.store_input:
96 | self.input_store = input_batch
97 | if self.store_output:
98 | self.output_store = output_batch
99 | if self.stop_forward:
100 | raise StopForwardException
101 |
102 |
103 | class input_hook(object):
104 | """
105 |     Forward hook used to store the input of an intermediate layer.
106 | """
107 | def __init__(self, stop_forward=False):
108 | super(input_hook, self).__init__()
109 | self.inputs = None
110 |
111 | def hook(self, module, input, output):
112 | self.inputs = input
113 |
114 | def clear(self):
115 | self.inputs = None
116 |
117 | class GetLayerInpOut:
118 | def __init__(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock],
119 | device: torch.device, input_prob: bool = False):
120 | self.model = model
121 | self.layer = layer
122 | self.device = device
123 | self.data_saver = DataSaverHook(store_input=True, store_output=False, stop_forward=True)
124 | self.input_prob = input_prob
125 |
126 | def __call__(self, model_input):
127 |
128 | handle = self.layer.register_forward_hook(self.data_saver)
129 | with torch.no_grad():
130 | self.model.set_quant_state(weight_quant=True, act_quant=True)
131 | try:
132 | _ = self.model(model_input.to(self.device))
133 | except StopForwardException:
134 | pass
135 |
136 | handle.remove()
137 |
138 | return self.data_saver.input_store[0].detach()
139 |
140 | class GetDcFpLayerInpOut:
141 | def __init__(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock],
142 | device: torch.device, input_prob: bool = False, lamb=50, bn_lr=1e-3):
143 | self.model = model
144 | self.layer = layer
145 | self.device = device
146 | self.data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=False)
147 | self.input_prob = input_prob
148 | self.bn_stats = []
149 | self.eps = 1e-6
150 | self.lamb=lamb
151 | self.bn_lr=bn_lr
152 | for n, m in self.layer.named_modules():
153 | if isinstance(m, nn.BatchNorm2d):
154 | # get the statistics in the BatchNorm layers
155 | self.bn_stats.append(
156 | (m.running_mean.detach().clone().flatten().cuda(),
157 | torch.sqrt(m.running_var +
158 | self.eps).detach().clone().flatten().cuda()))
159 |
160 | def own_loss(self, A, B):
161 | return (A - B).norm()**2 / B.size(0)
162 |
163 | def relative_loss(self, A, B):
164 | return (A-B).abs().mean()/A.abs().mean()
165 |
166 | def __call__(self, model_input):
167 | self.model.set_quant_state(False, False)
168 | handle = self.layer.register_forward_hook(self.data_saver)
169 | hooks = []
170 | hook_handles = []
171 | for name, module in self.layer.named_modules():
172 | if isinstance(module, nn.BatchNorm2d):
173 | hook = input_hook()
174 | hooks.append(hook)
175 | hook_handles.append(module.register_forward_hook(hook.hook))
176 | assert len(hooks) == len(self.bn_stats)
177 |
178 | with torch.no_grad():
179 | try:
180 | output_fp = self.model(model_input.to(self.device))
181 | except StopForwardException:
182 | pass
183 | if self.input_prob:
184 | input_sym = self.data_saver.input_store[0].detach()
185 |
186 | handle.remove()
187 | para_input = input_sym.data.clone()
188 | para_input = para_input.to(self.device)
189 | para_input.requires_grad = True
190 | optimizer = optim.Adam([para_input], lr=self.bn_lr)
191 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
192 | min_lr=1e-5,
193 | verbose=False,
194 | patience=100)
195 | iters=500
196 | for iter in range(iters):
197 | self.layer.zero_grad()
198 | optimizer.zero_grad()
199 | for hook in hooks:
200 | hook.clear()
201 | _ = self.layer(para_input)
202 | mean_loss = 0
203 | std_loss = 0
204 | for num, (bn_stat, hook) in enumerate(zip(self.bn_stats, hooks)):
205 | tmp_input = hook.inputs[0]
206 | bn_mean, bn_std = bn_stat[0], bn_stat[1]
207 | tmp_mean = torch.mean(tmp_input.view(tmp_input.size(0),
208 | tmp_input.size(1), -1),
209 | dim=2)
210 | tmp_std = torch.sqrt(
211 | torch.var(tmp_input.view(tmp_input.size(0),
212 | tmp_input.size(1), -1),
213 | dim=2) + self.eps)
214 | mean_loss += self.own_loss(bn_mean, tmp_mean)
215 | std_loss += self.own_loss(bn_std, tmp_std)
216 | constraint_loss = lp_loss(para_input, input_sym) / self.lamb
217 | total_loss = mean_loss + std_loss + constraint_loss
218 | total_loss.backward()
219 | optimizer.step()
220 | scheduler.step(total_loss.item())
221 | # if (iter+1) % 500 == 0:
222 | # print('Total loss:\t{:.3f} (mse:{:.3f}, mean:{:.3f}, std:{:.3f})\tcount={}'.format(
223 | # float(total_loss), float(constraint_loss), float(mean_loss), float(std_loss), iter))
224 |
225 | with torch.no_grad():
226 | out_fp = self.layer(para_input)
227 |
228 | if self.input_prob:
229 | return out_fp.detach(), output_fp.detach(), para_input.detach()
230 | return out_fp.detach(), output_fp.detach()
--------------------------------------------------------------------------------
/quant/fold_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 |
5 |
6 | class StraightThrough(nn.Module):
7 |     def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, input):
11 | return input
12 |
13 |
14 | def _fold_bn(conv_module, bn_module):
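    # Standard BN-folding identities implemented below, with sigma = sqrt(running_var + eps):
    #   affine BN:  W' = W * gamma / sigma,  b' = beta + (b - mu) * gamma / sigma
    #   plain BN:   W' = W / sigma,          b' = (b - mu) / sigma
    # so that conv(x; W', b') == bn(conv(x; W, b)) in eval mode.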
15 | w = conv_module.weight.data
16 | y_mean = bn_module.running_mean
17 | y_var = bn_module.running_var
18 | safe_std = torch.sqrt(y_var + bn_module.eps)
19 | w_view = (conv_module.out_channels, 1, 1, 1)
20 | if bn_module.affine:
21 | weight = w * (bn_module.weight / safe_std).view(w_view)
22 | beta = bn_module.bias - bn_module.weight * y_mean / safe_std
23 | if conv_module.bias is not None:
24 | bias = bn_module.weight * conv_module.bias / safe_std + beta
25 | else:
26 | bias = beta
27 | else:
28 | weight = w / safe_std.view(w_view)
29 | beta = -y_mean / safe_std
30 | if conv_module.bias is not None:
31 | bias = conv_module.bias / safe_std + beta
32 | else:
33 | bias = beta
34 | return weight, bias
35 |
36 |
37 | def fold_bn_into_conv(conv_module, bn_module):
38 | w, b = _fold_bn(conv_module, bn_module)
39 | if conv_module.bias is None:
40 | conv_module.bias = nn.Parameter(b)
41 | else:
42 | conv_module.bias.data = b
43 | conv_module.weight.data = w
44 | # set bn running stats
45 | bn_module.running_mean = bn_module.bias.data
46 | bn_module.running_var = bn_module.weight.data ** 2
47 |
48 |
49 | def reset_bn(module: nn.BatchNorm2d):
50 | if module.track_running_stats:
51 | module.running_mean.zero_()
52 | module.running_var.fill_(1-module.eps)
53 |         # we do not reset the number of tracked batches here
54 | # self.num_batches_tracked.zero_()
55 | if module.affine:
56 | init.ones_(module.weight)
57 | init.zeros_(module.bias)
58 |
59 |
60 | def is_bn(m):
61 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
62 |
63 |
64 | def is_absorbing(m):
65 | return (isinstance(m, nn.Conv2d)) or isinstance(m, nn.Linear)
66 |
67 |
68 | def search_fold_and_remove_bn(model):
69 | model.eval()
70 | prev = None
71 | for n, m in model.named_children():
72 | if is_bn(m) and is_absorbing(prev):
73 | fold_bn_into_conv(prev, m)
74 | # set the bn module to straight through
75 | setattr(model, n, StraightThrough())
76 | elif is_absorbing(m):
77 | prev = m
78 | else:
79 | prev = search_fold_and_remove_bn(m)
80 | return prev
81 |
82 |
83 | def search_fold_and_reset_bn(model):
84 | model.eval()
85 | prev = None
86 | for n, m in model.named_children():
87 | if is_bn(m) and is_absorbing(prev):
88 | fold_bn_into_conv(prev, m)
89 | # reset_bn(m)
90 | else:
91 | search_fold_and_reset_bn(m)
92 | prev = m
93 |
94 |
--------------------------------------------------------------------------------
/quant/layer_recon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .quant_layer import QuantModule, lp_loss
4 | from .quant_model import QuantModel
5 | from .block_recon import LinearTempDecay
6 | from .adaptive_rounding import AdaRoundQuantizer
7 | from .set_weight_quantize_params import get_init, get_dc_fp_init
8 | from .set_act_quantize_params import set_act_quantize_params
9 | from .quant_block import BaseQuantBlock, specials_unquantized
10 |
11 | include = False
12 | def find_unquantized_module(model: torch.nn.Module, module_list: list = [], name_list: list = []):
13 |     """Collect the modules that follow the layer being reconstructed; untrained quant modules are switched to run unquantized."""
14 | global include
15 | for name, module in model.named_children():
16 | if isinstance(module, (QuantModule, BaseQuantBlock)):
17 | if not module.trained:
18 | include = True
19 | module.set_quant_state(False,False)
20 | name_list.append(name)
21 | module_list.append(module)
22 | elif include and type(module) in specials_unquantized:
23 | name_list.append(name)
24 | module_list.append(module)
25 | else:
26 | find_unquantized_module(module, module_list, name_list)
27 |     return module_list[1:], name_list[1:]  # drop the first entry: the module currently being reconstructed
28 |
29 | def layer_reconstruction(model: QuantModel, fp_model: QuantModel, layer: QuantModule, fp_layer: QuantModule,
30 |                          cali_data: torch.Tensor, batch_size: int = 32, iters: int = 20000, weight: float = 0.001,
31 | opt_mode: str = 'mse', b_range: tuple = (20, 2),
32 | warmup: float = 0.0, p: float = 2.0, lr: float = 4e-5, input_prob: float = 1.0,
33 | keep_gpu: bool = True, lamb_r: float = 0.2, T: float = 7.0, bn_lr: float = 1e-3, lamb_c=0.02):
34 | """
35 | Reconstruction to optimize the output from each layer.
36 |
37 |     :param model: QuantModel under quantization
38 |     :param fp_model: full-precision reference model; fp_layer is its layer matching `layer`
39 |     :param layer: QuantModule that needs to be optimized
40 |     :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
41 |     :param batch_size: mini-batch size for reconstruction
42 |     :param iters: optimization iterations for reconstruction
43 |     :param weight: weight of the rounding regularization term
44 |     :param opt_mode: optimization mode (only 'mse' is supported)
45 |     :param b_range: temperature range
46 |     :param warmup: proportion of iterations without temperature scheduling
47 |     :param p: order of the L_p norm used for the reconstruction loss
48 |     :param lr: learning rate for learning the activation step size (delta)
49 |     :param input_prob: probability of keeping the quantized input (QDrop-style input mixing)
50 |     :param keep_gpu: keep cached data on GPU for faster optimization
51 |     :param lamb_r: hyper-parameter for the prediction-difference (KL) regularization
52 |     :param T: temperature coefficient for the KL divergence
53 |     :param bn_lr: learning rate for distribution correction (DC); lamb_c weights its constraint term
54 | """
55 |
56 | '''get input and set scale'''
57 | cached_inps = get_init(model, layer, cali_data, batch_size=batch_size,
58 | input_prob=True, keep_gpu=keep_gpu)
59 | cached_outs, cached_output, cur_syms = get_dc_fp_init(fp_model, fp_layer, cali_data, batch_size=batch_size,
60 | input_prob=True, keep_gpu=keep_gpu, bn_lr=bn_lr, lamb=lamb_c)
61 | set_act_quantize_params(layer, cali_data=cached_inps[:min(256, cached_inps.size(0))])
62 |
63 | '''set state'''
64 | cur_weight, cur_act = True, True
65 |
66 | global include
67 | module_list, name_list, include = [], [], False
68 | module_list, name_list = find_unquantized_module(model, module_list, name_list)
69 | layer.set_quant_state(cur_weight, cur_act)
70 | for para in model.parameters():
71 | para.requires_grad = False
72 |
73 | '''set quantizer'''
74 | round_mode = 'learned_hard_sigmoid'
75 | # Replace weight quantizer to AdaRoundQuantizer
76 | w_para, a_para = [], []
77 | w_opt, a_opt = None, None
78 | scheduler, a_scheduler = None, None
79 |
80 | '''weight'''
81 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode,
82 | weight_tensor=layer.org_weight.data)
83 | layer.weight_quantizer.soft_targets = True
84 | w_para += [layer.weight_quantizer.alpha]
85 |
86 | '''activation'''
87 | if layer.act_quantizer.delta is not None:
88 | layer.act_quantizer.delta = torch.nn.Parameter(torch.tensor(layer.act_quantizer.delta))
89 | a_para += [layer.act_quantizer.delta]
90 | '''set up drop'''
91 | layer.act_quantizer.is_training = True
92 |
93 | if len(w_para) != 0:
94 | w_opt = torch.optim.Adam(w_para, lr=3e-3)
95 | if len(a_para) != 0:
96 | a_opt = torch.optim.Adam(a_para, lr=lr)
97 | a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=iters, eta_min=0.)
98 |
99 | loss_mode = 'relaxation'
100 | rec_loss = opt_mode
101 | loss_func = LossFunction(layer, round_loss=loss_mode, weight=weight,
102 | max_count=iters, rec_loss=rec_loss, b_range=b_range,
103 | decay_start=0, warmup=warmup, p=p, lam=lamb_r, T=T)
104 | device = 'cuda'
105 | sz = cached_inps.size(0)
106 | for i in range(iters):
107 | idx = torch.randint(0, sz, (batch_size,))
108 | cur_inp = cached_inps[idx].to(device)
109 | cur_sym = cur_syms[idx].to(device)
110 | output_fp = cached_output[idx].to(device)
111 | cur_out = cached_outs[idx].to(device)
112 |         # QDrop-style input mixing; the batch is always doubled so the slicing below works
113 |         drop_inp = torch.where(torch.rand_like(cur_inp) < input_prob, cur_inp, cur_sym) if input_prob < 1.0 else cur_inp
114 |
115 |         cur_inp = torch.cat((drop_inp, cur_inp))
116 |
117 | if w_opt:
118 | w_opt.zero_grad()
119 | if a_opt:
120 | a_opt.zero_grad()
121 | out_all = layer(cur_inp)
122 |
123 | '''forward for prediction difference'''
124 | out_drop = out_all[:batch_size]
125 | out_quant = out_all[batch_size:]
126 | output = out_quant
127 | for num, module in enumerate(module_list):
128 | # for ResNet and RegNet
129 | if name_list[num] == 'fc':
130 | output = torch.flatten(output, 1)
131 | # for MobileNet and MNasNet
132 | if isinstance(module, torch.nn.Dropout):
133 | output = output.mean([2, 3])
134 | output = module(output)
135 | err = loss_func(out_drop, cur_out, output, output_fp)
136 |
137 | err.backward(retain_graph=True)
138 | if w_opt:
139 | w_opt.step()
140 | if a_opt:
141 | a_opt.step()
142 | if scheduler:
143 | scheduler.step()
144 | if a_scheduler:
145 | a_scheduler.step()
146 | torch.cuda.empty_cache()
147 |
148 | layer.weight_quantizer.soft_targets = False
149 | layer.act_quantizer.is_training = False
150 | layer.trained = True
151 |
152 |
153 | class LossFunction:
154 | def __init__(self,
155 | layer: QuantModule,
156 | round_loss: str = 'relaxation',
157 | weight: float = 1.,
158 | rec_loss: str = 'mse',
159 | max_count: int = 2000,
160 | b_range: tuple = (10, 2),
161 | decay_start: float = 0.0,
162 | warmup: float = 0.0,
163 | p: float = 2.,
164 | lam: float = 1.0,
165 | T: float = 7.0):
166 |
167 | self.layer = layer
168 | self.round_loss = round_loss
169 | self.weight = weight
170 | self.rec_loss = rec_loss
171 | self.loss_start = max_count * warmup
172 | self.p = p
173 | self.lam = lam
174 | self.T = T
175 |
176 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start,
177 | start_b=b_range[0], end_b=b_range[1])
178 | self.count = 0
179 | self.pd_loss = torch.nn.KLDivLoss(reduction='batchmean')
180 |
181 | def __call__(self, pred, tgt, output, output_fp):
182 | """
183 | Compute the total loss for adaptive rounding:
184 | rec_loss is the quadratic output reconstruction loss, round_loss is
185 | a regularization term to optimize the rounding policy, pd_loss is the
186 | prediction difference loss.
187 |
188 | :param pred: output from quantized model
189 | :param tgt: output from FP model
190 | :param output: prediction from quantized model
191 | :param output_fp: prediction from FP model
192 | :return: total loss function
193 | """
194 | self.count += 1
195 | if self.rec_loss == 'mse':
196 | rec_loss = lp_loss(pred, tgt, p=self.p)
197 | else:
198 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))
199 |
200 | pd_loss = self.pd_loss(F.log_softmax(output / self.T, dim=1), F.softmax(output_fp / self.T, dim=1)) / self.lam
201 |
202 | b = self.temp_decay(self.count)
203 | if self.count < self.loss_start or self.round_loss == 'none':
204 | b = round_loss = 0
205 | elif self.round_loss == 'relaxation':
206 | round_loss = 0
207 | round_vals = self.layer.weight_quantizer.get_soft_targets()
208 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
209 | else:
210 | raise NotImplementedError
211 | total_loss = rec_loss + round_loss + pd_loss
212 | if self.count % 500 == 0:
213 | print('Total loss:\t{:.3f} (rec:{:.3f}, pd:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
214 | float(total_loss), float(rec_loss), float(pd_loss), float(round_loss), b, self.count))
215 | return total_loss
216 |
--------------------------------------------------------------------------------
/quant/quant_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .quant_layer import QuantModule, UniformAffineQuantizer
3 | from models.resnet import BasicBlock, Bottleneck
4 | from models.regnet import ResBottleneckBlock
5 | from models.mobilenetv2 import InvertedResidual
6 | from models.mnasnet import _InvertedResidual
7 |
8 |
9 | class BaseQuantBlock(nn.Module):
10 | """
11 | Base implementation of block structures for all networks.
12 |     Because of the branch architecture, the activation function and its
13 |     quantization must be applied after the element-wise add operation, so
14 |     that logic lives in this class.
15 | """
16 | def __init__(self):
17 | super().__init__()
18 | self.use_weight_quant = False
19 | self.use_act_quant = False
20 | self.ignore_reconstruction = False
21 | self.trained = False
22 |
23 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
24 | # setting weight quantization here does not affect actual forward pass
25 | self.use_weight_quant = weight_quant
26 | self.use_act_quant = act_quant
27 | for m in self.modules():
28 | if isinstance(m, QuantModule):
29 | m.set_quant_state(weight_quant, act_quant)
30 |
31 |
32 | class QuantBasicBlock(BaseQuantBlock):
33 | """
34 | Implementation of Quantized BasicBlock used in ResNet-18 and ResNet-34.
35 | """
36 | def __init__(self, basic_block: BasicBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}):
37 | super().__init__()
38 | self.conv1 = QuantModule(basic_block.conv1, weight_quant_params, act_quant_params)
39 | self.conv1.norm_function = basic_block.bn1
40 | self.conv1.activation_function = basic_block.relu1
41 | self.conv2 = QuantModule(basic_block.conv2, weight_quant_params, act_quant_params, disable_act_quant=True)
42 | self.conv2.norm_function = basic_block.bn2
43 |
44 | if basic_block.downsample is None:
45 | self.downsample = None
46 | else:
47 | self.downsample = QuantModule(basic_block.downsample[0], weight_quant_params, act_quant_params,
48 | disable_act_quant=True)
49 | self.downsample.norm_function = basic_block.downsample[1]
50 | self.activation_function = basic_block.relu2
51 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
52 |
53 | def forward(self, x):
54 | residual = x if self.downsample is None else self.downsample(x)
55 | out = self.conv1(x)
56 | out = self.conv2(out)
57 | out += residual
58 | out = self.activation_function(out)
59 | if self.use_act_quant:
60 | out = self.act_quantizer(out)
61 | return out
62 |
63 |
64 | class QuantBottleneck(BaseQuantBlock):
65 | """
66 | Implementation of Quantized Bottleneck Block used in ResNet-50, -101 and -152.
67 | """
68 |
69 | def __init__(self, bottleneck: Bottleneck, weight_quant_params: dict = {}, act_quant_params: dict = {}):
70 | super().__init__()
71 | self.conv1 = QuantModule(bottleneck.conv1, weight_quant_params, act_quant_params)
72 | self.conv1.norm_function = bottleneck.bn1
73 | self.conv1.activation_function = bottleneck.relu1
74 | self.conv2 = QuantModule(bottleneck.conv2, weight_quant_params, act_quant_params)
75 | self.conv2.norm_function = bottleneck.bn2
76 | self.conv2.activation_function = bottleneck.relu2
77 | self.conv3 = QuantModule(bottleneck.conv3, weight_quant_params, act_quant_params, disable_act_quant=True)
78 | self.conv3.norm_function = bottleneck.bn3
79 |
80 | if bottleneck.downsample is None:
81 | self.downsample = None
82 | else:
83 | self.downsample = QuantModule(bottleneck.downsample[0], weight_quant_params, act_quant_params,
84 | disable_act_quant=True)
85 | self.downsample.norm_function = bottleneck.downsample[1]
86 |         # the activation function applied after the residual add
87 | self.activation_function = bottleneck.relu3
88 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
89 |
90 | def forward(self, x):
91 | residual = x if self.downsample is None else self.downsample(x)
92 | out = self.conv1(x)
93 | out = self.conv2(out)
94 | out = self.conv3(out)
95 | out += residual
96 | out = self.activation_function(out)
97 | if self.use_act_quant:
98 | out = self.act_quantizer(out)
99 | return out
100 |
101 |
102 | class QuantResBottleneckBlock(BaseQuantBlock):
103 | """
104 |     Implementation of the Quantized Bottleneck Block used in RegNetX (no SE module).
105 | """
106 |
107 | def __init__(self, bottleneck: ResBottleneckBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}):
108 | super().__init__()
109 | self.conv1 = QuantModule(bottleneck.f.a, weight_quant_params, act_quant_params)
110 | self.conv1.norm_function = bottleneck.f.a_bn
111 | self.conv1.activation_function = bottleneck.f.a_relu
112 | self.conv2 = QuantModule(bottleneck.f.b, weight_quant_params, act_quant_params)
113 | self.conv2.norm_function = bottleneck.f.b_bn
114 | self.conv2.activation_function = bottleneck.f.b_relu
115 | self.conv3 = QuantModule(bottleneck.f.c, weight_quant_params, act_quant_params, disable_act_quant=True)
116 | self.conv3.norm_function = bottleneck.f.c_bn
117 |
118 | if bottleneck.proj_block:
119 | self.downsample = QuantModule(bottleneck.proj, weight_quant_params, act_quant_params,
120 | disable_act_quant=True)
121 | self.downsample.norm_function = bottleneck.bn
122 | else:
123 | self.downsample = None
124 | # copying all attributes in original block
125 | self.proj_block = bottleneck.proj_block
126 |
127 | self.activation_function = bottleneck.relu
128 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
129 |
130 | def forward(self, x):
131 | residual = x if not self.proj_block else self.downsample(x)
132 | out = self.conv1(x)
133 | out = self.conv2(out)
134 | out = self.conv3(out)
135 | out += residual
136 | out = self.activation_function(out)
137 | if self.use_act_quant:
138 | out = self.act_quantizer(out)
139 | return out
140 |
141 |
142 | class QuantInvertedResidual(BaseQuantBlock):
143 | """
144 | Implementation of Quantized Inverted Residual Block used in MobileNetV2.
145 |     The inverted residual has no activation function after the residual add.
146 | """
147 |
148 | def __init__(self, inv_res: InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}):
149 | super().__init__()
150 |
151 | self.use_res_connect = inv_res.use_res_connect
152 | self.expand_ratio = inv_res.expand_ratio
153 | if self.expand_ratio == 1:
154 | self.conv = nn.Sequential(
155 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params),
156 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params, disable_act_quant=True),
157 | )
158 | self.conv[0].norm_function = inv_res.conv[1]
159 | self.conv[0].activation_function = nn.ReLU6()
160 | self.conv[1].norm_function = inv_res.conv[4]
161 | else:
162 | self.conv = nn.Sequential(
163 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params),
164 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params),
165 | QuantModule(inv_res.conv[6], weight_quant_params, act_quant_params, disable_act_quant=True),
166 | )
167 | self.conv[0].norm_function = inv_res.conv[1]
168 | self.conv[0].activation_function = nn.ReLU6()
169 | self.conv[1].norm_function = inv_res.conv[4]
170 | self.conv[1].activation_function = nn.ReLU6()
171 | self.conv[2].norm_function = inv_res.conv[7]
172 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
173 |
174 | def forward(self, x):
175 | if self.use_res_connect:
176 | out = x + self.conv(x)
177 | else:
178 | out = self.conv(x)
179 | if self.use_act_quant:
180 | out = self.act_quantizer(out)
181 | return out
182 |
183 |
184 | class _QuantInvertedResidual(BaseQuantBlock):
185 | def __init__(self, _inv_res: _InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}):
186 | super().__init__()
187 |
188 | self.apply_residual = _inv_res.apply_residual
189 | self.conv = nn.Sequential(
190 | QuantModule(_inv_res.layers[0], weight_quant_params, act_quant_params),
191 | QuantModule(_inv_res.layers[3], weight_quant_params, act_quant_params),
192 | QuantModule(_inv_res.layers[6], weight_quant_params, act_quant_params, disable_act_quant=True),
193 | )
194 | self.conv[0].activation_function = nn.ReLU()
195 | self.conv[0].norm_function = _inv_res.layers[1]
196 | self.conv[1].activation_function = nn.ReLU()
197 | self.conv[1].norm_function = _inv_res.layers[4]
198 | self.conv[2].norm_function = _inv_res.layers[7]
199 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
200 |
201 | def forward(self, x):
202 | if self.apply_residual:
203 | out = x + self.conv(x)
204 | else:
205 | out = self.conv(x)
206 | if self.use_act_quant:
207 | out = self.act_quantizer(out)
208 | return out
209 |
210 |
211 | specials = {
212 | BasicBlock: QuantBasicBlock,
213 | Bottleneck: QuantBottleneck,
214 | ResBottleneckBlock: QuantResBottleneckBlock,
215 | InvertedResidual: QuantInvertedResidual,
216 | _InvertedResidual: _QuantInvertedResidual,
217 | }
218 |
219 | specials_unquantized = [nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.Dropout]
220 |
--------------------------------------------------------------------------------
/quant/quant_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Union
5 |
6 |
7 | class StraightThrough(nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def forward(self, input):
12 | return input
13 |
14 |
15 | def round_ste(x: torch.Tensor):
16 | """
17 | Implement Straight-Through Estimator for rounding operation.
18 | """
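    # The returned value equals x.round(); because (x.round() - x) is detached,
    # the gradient w.r.t. x is exactly 1, i.e. rounding acts as identity in backward.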
19 | return (x.round() - x).detach() + x
20 |
21 |
22 | def lp_loss(pred, tgt, p=2.0, reduction='none'):
23 | """
24 |     loss function measured in the L_p norm ('none' sums over dim 1 and averages the rest; otherwise a plain mean)
25 | """
26 | if reduction == 'none':
27 | return (pred - tgt).abs().pow(p).sum(1).mean()
28 | else:
29 | return (pred - tgt).abs().pow(p).mean()
30 |
31 |
32 | class UniformAffineQuantizer(nn.Module):
33 | """
34 | PyTorch Function that can be used for asymmetric quantization (also called uniform affine
35 | quantization). Quantizes its argument in the forward pass, passes the gradient 'straight
36 | through' on the backward pass, ignoring the quantization that occurred.
37 | Based on https://arxiv.org/abs/1806.08342.
38 |
39 |     :param n_bits: number of bits for quantization
40 |     :param symmetric: if True, the zero_point should always be 0
41 |     :param channel_wise: if True, compute scale and zero_point per channel
42 |     :param scale_method: determines the quantization scale and zero point (forced to 'mse' in __init__)
43 |     :param prob: probability of quantizing each element during training (QDrop)
44 | """
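    # The forward pass below computes, with step size delta and zero point z:
    #   x_q   = clamp(round(x / delta) + z, 0, 2**n_bits - 1)
    #   x_hat = (x_q - z) * delta
    # Worked example with illustrative values n_bits=2, delta=0.5, z=1, x=0.7:
    #   round(1.4) + 1 = 2 (within [0, 3]), dequantized to (2 - 1) * 0.5 = 0.5.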
45 |
46 | def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False,
47 | scale_method: str = 'minmax',
48 | leaf_param: bool = False, prob: float = 1.0):
49 | super(UniformAffineQuantizer, self).__init__()
50 | self.sym = symmetric
51 | if self.sym:
52 | raise NotImplementedError
53 | assert 2 <= n_bits <= 8, 'bitwidth not supported'
54 | self.n_bits = n_bits
55 | self.n_levels = 2 ** self.n_bits
56 | self.delta = 1.0
57 | self.zero_point = 0.0
58 | self.inited = True
59 |
60 | '''if leaf_param, use EMA to set scale'''
61 | self.leaf_param = leaf_param
62 | self.channel_wise = channel_wise
63 | self.eps = torch.tensor(1e-8, dtype=torch.float32)
64 |
65 | '''mse params'''
66 |         self.scale_method = 'mse'  # the scale_method argument is overridden; only the MSE-based search is implemented
67 | self.one_side_dist = None
68 | self.num = 100
69 |
70 | '''for activation quantization'''
71 | self.running_min = None
72 | self.running_max = None
73 |
74 | '''do like dropout'''
75 | self.prob = prob
76 | self.is_training = False
77 |
78 | def set_inited(self, inited: bool = True): # inited manually
79 | self.inited = inited
80 |
81 | def update_quantize_range(self, x_min, x_max):
82 | if self.running_min is None:
83 | self.running_min = x_min
84 | self.running_max = x_max
85 | self.running_min = 0.1 * x_min + 0.9 * self.running_min
86 | self.running_max = 0.1 * x_max + 0.9 * self.running_max
87 | return self.running_min, self.running_max
88 |
89 | def forward(self, x: torch.Tensor):
90 |         if self.inited is False:
91 |             self.delta, self.zero_point = self.init_quantization_scale(x.clone().detach(), self.channel_wise)
95 |
96 | # start quantization
97 | x_int = round_ste(x / self.delta) + self.zero_point
98 | x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
99 | x_dequant = (x_quant - self.zero_point) * self.delta
100 | if self.is_training and self.prob < 1.0:
101 | x_ans = torch.where(torch.rand_like(x) < self.prob, x_dequant, x)
102 | else:
103 | x_ans = x_dequant
104 | return x_ans
105 |
106 | def lp_loss(self, pred, tgt, p=2.0):
107 | x = (pred - tgt).abs().pow(p)
108 | if not self.channel_wise:
109 | return x.mean()
110 | else:
111 | y = torch.flatten(x, 1)
112 | return y.mean(1)
113 |
114 | def calculate_qparams(self, min_val, max_val):
115 |         # min_val / max_val are scalars or per-channel vectors
116 | quant_min, quant_max = 0, self.n_levels - 1
117 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
118 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
119 |
120 | scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
121 | scale = torch.max(scale, self.eps)
122 | zero_point = quant_min - torch.round(min_val_neg / scale)
123 | zero_point = torch.clamp(zero_point, quant_min, quant_max)
124 | return scale, zero_point
125 |
126 | def quantize(self, x: torch.Tensor, x_max, x_min):
127 | delta, zero_point = self.calculate_qparams(x_min, x_max)
128 | if self.channel_wise:
129 | new_shape = [1] * len(x.shape)
130 | new_shape[0] = x.shape[0]
131 | delta = delta.reshape(new_shape)
132 | zero_point = zero_point.reshape(new_shape)
133 | x_int = torch.round(x / delta)
134 | x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1)
135 | x_float_q = (x_quant - zero_point) * delta
136 | return x_float_q
137 |
138 | def perform_2D_search(self, x):
139 | if self.channel_wise:
140 | y = torch.flatten(x, 1)
141 | x_min, x_max = torch._aminmax(y, 1)
142 | # may also have the one side distribution in some channels
143 | x_max = torch.max(x_max, torch.zeros_like(x_max))
144 | x_min = torch.min(x_min, torch.zeros_like(x_min))
145 | else:
146 | x_min, x_max = torch._aminmax(x)
147 | xrange = x_max - x_min
148 | best_score = torch.zeros_like(x_min) + (1e+10)
149 | best_min = x_min.clone()
150 | best_max = x_max.clone()
151 | # enumerate xrange
152 | for i in range(1, self.num + 1):
153 | tmp_min = torch.zeros_like(x_min)
154 | tmp_max = xrange / self.num * i
155 | tmp_delta = (tmp_max - tmp_min) / (2 ** self.n_bits - 1)
156 | # enumerate zp
157 | for zp in range(0, self.n_levels):
158 | new_min = tmp_min - zp * tmp_delta
159 | new_max = tmp_max - zp * tmp_delta
160 | x_q = self.quantize(x, new_max, new_min)
161 | score = self.lp_loss(x, x_q, 2.4)
162 | best_min = torch.where(score < best_score, new_min, best_min)
163 | best_max = torch.where(score < best_score, new_max, best_max)
164 | best_score = torch.min(best_score, score)
165 | return best_min, best_max
166 |
167 | def perform_1D_search(self, x):
168 | if self.channel_wise:
169 | y = torch.flatten(x, 1)
170 | x_min, x_max = torch._aminmax(y, 1)
171 | else:
172 | x_min, x_max = torch._aminmax(x)
173 | xrange = torch.max(x_min.abs(), x_max)
174 | best_score = torch.zeros_like(x_min) + (1e+10)
175 | best_min = x_min.clone()
176 | best_max = x_max.clone()
177 | # enumerate xrange
178 | for i in range(1, self.num + 1):
179 | thres = xrange / self.num * i
180 | new_min = torch.zeros_like(x_min) if self.one_side_dist == 'pos' else -thres
181 | new_max = torch.zeros_like(x_max) if self.one_side_dist == 'neg' else thres
182 | x_q = self.quantize(x, new_max, new_min)
183 | score = self.lp_loss(x, x_q, 2.4)
184 | best_min = torch.where(score < best_score, new_min, best_min)
185 | best_max = torch.where(score < best_score, new_max, best_max)
186 | best_score = torch.min(score, best_score)
187 | return best_min, best_max
188 |
189 | def get_x_min_x_max(self, x):
190 | if self.scale_method != 'mse':
191 | raise NotImplementedError
192 | if self.one_side_dist is None:
193 | self.one_side_dist = 'pos' if x.min() >= 0.0 else 'neg' if x.max() <= 0.0 else 'no'
194 | if self.one_side_dist != 'no' or self.sym: # one-side distribution or symmetric value for 1-d search
195 | best_min, best_max = self.perform_1D_search(x)
196 | else: # 2-d search
197 | best_min, best_max = self.perform_2D_search(x)
198 | if self.leaf_param:
199 | return self.update_quantize_range(best_min, best_max)
200 | return best_min, best_max
201 |
202 | def init_quantization_scale_channel(self, x: torch.Tensor):
203 | x_min, x_max = self.get_x_min_x_max(x)
204 | return self.calculate_qparams(x_min, x_max)
205 |
206 | def init_quantization_scale(self, x_clone: torch.Tensor, channel_wise: bool = False):
207 | if channel_wise:
208 | # determine the scale and zero point channel-by-channel
209 | delta, zero_point = self.init_quantization_scale_channel(x_clone)
210 | new_shape = [1] * len(x_clone.shape)
211 | new_shape[0] = x_clone.shape[0]
212 | delta = delta.reshape(new_shape)
213 | zero_point = zero_point.reshape(new_shape)
214 | else:
215 | delta, zero_point = self.init_quantization_scale_channel(x_clone)
216 | return delta, zero_point
217 |
218 | def bitwidth_refactor(self, refactored_bit: int):
219 | assert 2 <= refactored_bit <= 8, 'bitwidth not supported'
220 | self.n_bits = refactored_bit
221 | self.n_levels = 2 ** self.n_bits
222 |
223 | @torch.jit.export
224 | def extra_repr(self):
225 | return 'bit={}, is_training={}, inited={}'.format(
226 | self.n_bits, self.is_training, self.inited
227 | )
228 |
229 |
230 | class QuantModule(nn.Module):
231 | """
232 | Quantized module that can perform quantized or full-precision convolution (or linear) forward passes.
233 | To toggle quantization, use the set_quant_state function.
234 | """
235 |
236 | def __init__(self, org_module: Union[nn.Conv2d, nn.Linear], weight_quant_params: dict = {},
237 | act_quant_params: dict = {}, disable_act_quant=False):
238 | super(QuantModule, self).__init__()
239 | if isinstance(org_module, nn.Conv2d):
240 | self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
241 | dilation=org_module.dilation, groups=org_module.groups)
242 | self.fwd_func = F.conv2d
243 | else:
244 | self.fwd_kwargs = dict()
245 | self.fwd_func = F.linear
246 | self.weight = org_module.weight
247 | self.org_weight = org_module.weight.data.clone()
248 | if org_module.bias is not None:
249 | self.bias = org_module.bias
250 | self.org_bias = org_module.bias.data.clone()
251 | else:
252 | self.bias = None
253 | self.org_bias = None
254 | # deactivate the quantized forward pass by default
255 | self.use_weight_quant = False
256 | self.use_act_quant = False
257 | # initialize quantizer
258 | self.weight_quantizer = UniformAffineQuantizer(**weight_quant_params)
259 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
260 |
261 | self.norm_function = StraightThrough()
262 | self.activation_function = StraightThrough()
263 | self.ignore_reconstruction = False
264 | self.disable_act_quant = disable_act_quant
265 | self.trained = False
266 |
267 | def forward(self, input: torch.Tensor):
268 | if self.use_weight_quant:
269 | weight = self.weight_quantizer(self.weight)
270 | bias = self.bias
271 | else:
272 | weight = self.org_weight
273 | bias = self.org_bias
274 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
275 | # disabling act quantization is designed for a convolution that precedes an element-wise
276 | # operation; in that case, activation and quantization are applied after the element-wise op.
277 | out = self.norm_function(out)
278 | out = self.activation_function(out)
279 | if self.disable_act_quant:
280 | return out
281 | if self.use_act_quant:
282 | out = self.act_quantizer(out)
283 | return out
284 |
285 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
286 | self.use_weight_quant = weight_quant
287 | self.use_act_quant = act_quant
288 |
289 | @torch.jit.export
290 | def extra_repr(self):
291 | return 'wbit={}, abit={}, disable_act_quant={}'.format(
292 | self.weight_quantizer.n_bits, self.act_quantizer.n_bits, self.disable_act_quant
293 | )
294 |
--------------------------------------------------------------------------------
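Usage note: the quantizer above derives its clipping range with an MSE-style grid search (perform_1D_search for one-sided or symmetric cases, perform_2D_search otherwise, both scored by an Lp loss with p = 2.4). Below is a minimal, hedged sketch of wrapping a single layer in QuantModule. It assumes the package is importable as quant.quant_layer, that UniformAffineQuantizer accepts the n_bits / channel_wise / scale_method keywords implied by the attributes used above, and that a quantizer initializes its scale on the first call while uninited (as set_weight_quantize_params suggests).

    import torch
    import torch.nn as nn
    from quant.quant_layer import QuantModule

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    wq_params = {'n_bits': 4, 'channel_wise': True, 'scale_method': 'mse'}
    aq_params = {'n_bits': 4, 'channel_wise': False, 'scale_method': 'mse'}
    qconv = QuantModule(conv, wq_params, aq_params)

    x = torch.randn(1, 3, 32, 32)
    qconv.set_quant_state(weight_quant=False, act_quant=False)
    out_fp = qconv(x)  # full-precision path via org_weight / org_bias
    qconv.set_quant_state(weight_quant=True, act_quant=True)
    out_q = qconv(x)   # fake-quantized weights and activations
    print((out_fp - out_q).abs().max())  # per-layer quantization error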
/quant/quant_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .quant_block import specials, BaseQuantBlock
3 | from .quant_layer import QuantModule, StraightThrough, UniformAffineQuantizer
4 | from .fold_bn import search_fold_and_remove_bn
5 |
6 |
7 | class QuantModel(nn.Module):
8 |
9 | def __init__(self, model: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}, is_fusing=True):
10 | super().__init__()
11 | if is_fusing:
12 | search_fold_and_remove_bn(model)
13 | self.model = model
14 | self.quant_module_refactor(self.model, weight_quant_params, act_quant_params)
15 | else:
16 | self.model = model
17 | self.quant_module_refactor_wo_fuse(self.model, weight_quant_params, act_quant_params)
18 |
19 | def quant_module_refactor(self, module: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}):
20 | """
21 | Recursively replace nn.Conv2d and nn.Linear layers with QuantModule
22 | :param module: nn.Module with nn.Conv2d or nn.Linear in its children
23 | :param weight_quant_params: quantization parameters like n_bits for weight quantizer
24 | :param act_quant_params: quantization parameters like n_bits for activation quantizer
25 | """
26 | prev_quantmodule = None
27 | for name, child_module in module.named_children():
28 | if type(child_module) in specials:
29 | setattr(module, name, specials[type(child_module)](child_module, weight_quant_params, act_quant_params))
30 | elif isinstance(child_module, (nn.Conv2d, nn.Linear)):
31 | setattr(module, name, QuantModule(child_module, weight_quant_params, act_quant_params))
32 | prev_quantmodule = getattr(module, name)
33 |
34 | elif isinstance(child_module, (nn.ReLU, nn.ReLU6)):
35 | if prev_quantmodule is not None:
36 | prev_quantmodule.activation_function = child_module
37 | setattr(module, name, StraightThrough())
38 | else:
39 | continue
40 |
41 | elif isinstance(child_module, StraightThrough):
42 | continue
43 |
44 | else:
45 | self.quant_module_refactor(child_module, weight_quant_params, act_quant_params)
46 |
47 | def quant_module_refactor_wo_fuse(self, module: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}):
48 | """
49 | Recursively replace nn.Conv2d and nn.Linear layers with QuantModule, keeping BatchNorm2d unfolded (attached as norm_function)
50 | :param module: nn.Module with nn.Conv2d or nn.Linear in its children
51 | :param weight_quant_params: quantization parameters like n_bits for weight quantizer
52 | :param act_quant_params: quantization parameters like n_bits for activation quantizer
53 | """
54 | prev_quantmodule = None
55 | for name, child_module in module.named_children():
56 | if type(child_module) in specials:
57 | setattr(module, name, specials[type(child_module)](child_module, weight_quant_params, act_quant_params))
58 | elif isinstance(child_module, (nn.Conv2d, nn.Linear)):
59 | setattr(module, name, QuantModule(child_module, weight_quant_params, act_quant_params))
60 | prev_quantmodule = getattr(module, name)
61 |
62 | elif isinstance(child_module, nn.BatchNorm2d):
63 | if prev_quantmodule is not None:
64 | prev_quantmodule.norm_function = child_module
65 | setattr(module, name, StraightThrough())
66 | else:
67 | continue
68 |
69 | elif isinstance(child_module, (nn.ReLU, nn.ReLU6)):
70 | if prev_quantmodule is not None:
71 | prev_quantmodule.activation_function = child_module
72 | setattr(module, name, StraightThrough())
73 | else:
74 | continue
75 |
76 | elif isinstance(child_module, StraightThrough):
77 | continue
78 |
79 | else:
80 | self.quant_module_refactor_wo_fuse(child_module, weight_quant_params, act_quant_params)
81 |
82 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
83 | for m in self.model.modules():
84 | if isinstance(m, (QuantModule, BaseQuantBlock)):
85 | m.set_quant_state(weight_quant, act_quant)
86 |
87 | def forward(self, input):
88 | return self.model(input)
89 |
90 | def set_first_last_layer_to_8bit(self):
91 | w_list, a_list = [], []
92 | for module in self.model.modules():
93 | if isinstance(module, UniformAffineQuantizer):
94 | if module.leaf_param:
95 | a_list.append(module)
96 | else:
97 | w_list.append(module)
98 | w_list[0].bitwidth_refactor(8)
99 | w_list[-1].bitwidth_refactor(8)
100 | # the image input is already 8-bit (0~255), so set the last layer's input to 8-bit
101 | a_list[-2].bitwidth_refactor(8)
102 | # a_list[0].bitwidth_refactor(8)
103 |
104 | def disable_network_output_quantization(self):
105 | module_list = []
106 | for m in self.model.modules():
107 | if isinstance(m, QuantModule):
108 | module_list += [m]
109 | module_list[-1].disable_act_quant = True
110 |
--------------------------------------------------------------------------------
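Usage note: QuantModel folds batch norms (with is_fusing=True) and recursively swaps nn.Conv2d / nn.Linear for QuantModule, absorbing trailing ReLUs into the preceding module. A hedged end-to-end sketch follows; it assumes torchvision is installed (the repo's own models/ package would work the same way) and that leaf_param is a valid activation-quantizer keyword, since it is the attribute set_first_last_layer_to_8bit keys on to tell activation quantizers from weight quantizers.

    import torch
    from torchvision.models import resnet18
    from quant.quant_model import QuantModel

    wq = {'n_bits': 4, 'channel_wise': True, 'scale_method': 'mse'}
    aq = {'n_bits': 4, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': True}
    qmodel = QuantModel(resnet18().eval(), wq, aq)  # folds BN, wraps Conv/Linear
    qmodel.set_first_last_layer_to_8bit()           # keep stem and classifier at 8 bits
    qmodel.set_quant_state(weight_quant=True, act_quant=True)
    with torch.no_grad():
        logits = qmodel(torch.randn(1, 3, 224, 224))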
/quant/set_act_quantize_params.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .quant_layer import QuantModule
3 | from .quant_block import BaseQuantBlock
4 | from .quant_model import QuantModel
5 | from typing import Union
6 |
7 | def set_act_quantize_params(module: Union[QuantModel, QuantModule, BaseQuantBlock],
8 | cali_data, batch_size: int = 256):
9 | module.set_quant_state(True, True)
10 |
11 | for t in module.modules():
12 | if isinstance(t, (QuantModule, BaseQuantBlock)):
13 | t.act_quantizer.set_inited(False)
14 |
15 | # set or initialize the step size and zero point in each activation quantizer
16 | batch_size = min(batch_size, cali_data.size(0))
17 | with torch.no_grad():
18 | for i in range(cali_data.size(0) // batch_size):
19 | module(cali_data[i * batch_size:(i + 1) * batch_size].cuda())
20 | torch.cuda.empty_cache()
21 |
22 | for t in module.modules():
23 | if isinstance(t, (QuantModule, BaseQuantBlock)):
24 | t.act_quantizer.set_inited(True)
25 |
--------------------------------------------------------------------------------
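Usage note: a sketch of activation calibration, reusing the qmodel built in the previous sketch. It assumes a CUDA device, since the function moves each batch with .cuda(); the random tensor is a stand-in for real calibration images sampled from the training set.

    import torch
    from quant.set_act_quantize_params import set_act_quantize_params

    cali_data = torch.randn(64, 3, 224, 224)  # stand-in for sampled ImageNet images
    qmodel.cuda()
    set_act_quantize_params(qmodel, cali_data, batch_size=32)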
/quant/set_weight_quantize_params.py:
--------------------------------------------------------------------------------
1 | from .quant_layer import QuantModule
2 | from .data_utils import save_inp_oup_data, save_dc_fp_data
3 |
4 |
5 | def get_init(model, block, cali_data, batch_size, input_prob: bool = False, keep_gpu: bool=True):
6 | cached_inps = save_inp_oup_data(model, block, cali_data, batch_size, input_prob=input_prob, keep_gpu=keep_gpu)
7 | return cached_inps
8 |
9 | def get_dc_fp_init(model, block, cali_data, batch_size, input_prob: bool = False, keep_gpu: bool=True, lamb=50, bn_lr=1e-3):
10 | cached_outs, cached_outputs, cached_sym = save_dc_fp_data(model, block, cali_data, batch_size, input_prob=input_prob, keep_gpu=keep_gpu, lamb=lamb, bn_lr=bn_lr)
11 | return cached_outs, cached_outputs, cached_sym
12 |
13 | def set_weight_quantize_params(model):
14 | for module in model.modules():
15 | if isinstance(module, QuantModule):
16 | module.weight_quantizer.set_inited(False)
17 | # calculate the step size and zero point for the weight quantizer
18 | module.weight_quantizer(module.weight)
19 | module.weight_quantizer.set_inited(True)
20 |
21 | def save_quantized_weight(model):
22 | for module in model.modules():
23 | if isinstance(module, QuantModule):
24 | module.weight.data = module.weight_quantizer(module.weight)
25 |
--------------------------------------------------------------------------------
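Usage note: weight quantizers need no data beyond the weights themselves; each one computes its step size and zero point the first time it is called while marked uninited. A sketch, again reusing qmodel from the sketches above:

    from quant.set_weight_quantize_params import set_weight_quantize_params, save_quantized_weight

    set_weight_quantize_params(qmodel)  # one pass through each weight quantizer initializes delta / zero_point
    save_quantized_weight(qmodel)       # optionally bake the fake-quantized values into weight.data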
/run_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 |
5 | if __name__ == "__main__":
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("exp_name", type=str, choices=['resnet18', 'resnet50', 'mobilenetv2', 'regnetx_600m', 'regnetx_3200m', 'mnasnet'])
8 | args = parser.parse_args()
9 | w_bits = [2, 4, 2, 4]
10 | a_bits = [2, 2, 4, 4]
11 |
12 | if args.exp_name == "resnet18":
13 | for i in range(4):
14 | os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch resnet18 --n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight 0.01 --T 4.0 --lamb_c 0.02")
15 | time.sleep(0.5)
16 |
17 | if args.exp_name == "resnet50":
18 | for i in range(4):
19 | os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch resnet50 --n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight 0.01 --T 4.0 --lamb_c 0.02")
20 | time.sleep(0.5)
21 |
22 | if args.exp_name == "regnetx_600m":
23 | for i in range(4):
24 | os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch regnetx_600m --n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight 0.01 --T 4.0 --lamb_c 0.01")
25 | time.sleep(0.5)
26 |
27 | if args.exp_name == "regnetx_3200m":
28 | for i in range(4):
29 | os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch regnetx_3200m --n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight 0.01 --T 4.0 --lamb_c 0.01")
30 | time.sleep(0.5)
31 |
32 | if args.exp_name == "mobilenetv2":
33 | for i in range(4):
34 | os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch mobilenetv2 --n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight 0.1 --T 1.0 --lamb_c 0.005")
35 | time.sleep(0.5)
36 |
37 | if args.exp_name == "mnasnet":
38 | for i in range(4):
39 | os.system(f"python main_imagenet.py --data_path /datasets/imagenet --arch mnasnet --n_bits_w {w_bits[i]} --n_bits_a {a_bits[i]} --weight 0.2 --T 1.0 --lamb_c 0.001")
40 | time.sleep(0.5)
41 |
42 |
--------------------------------------------------------------------------------
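Usage note: each experiment name sweeps the four weight/activation settings W2A2, W4A2, W2A4 and W4A4 (from the w_bits and a_bits lists above). For example, python run_script.py resnet18 launches the four resnet18 runs in sequence, each invoking main_imagenet.py with the architecture-specific --weight, --T and --lamb_c hyperparameters; note that the data path /datasets/imagenet is hard-coded and may need adapting to your setup.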