├── .gitignore ├── PTQ ├── data │ └── imagenet.py ├── hubconf.py ├── main_imagenet.py ├── models │ ├── __init__.py │ ├── mobilenetv2.py │ └── resnet.py ├── quant │ ├── __init__.py │ ├── adaptive_rounding.py │ ├── block_recon.py │ ├── data_utils.py │ ├── fold_bn.py │ ├── layer_recon.py │ ├── quant_block.py │ ├── quant_layer.py │ └── quant_model.py └── run_scripts │ ├── train_mobilenetv2.sh │ └── train_resnet18.sh ├── QAT ├── bit_config.py ├── quant_train.py ├── run_scripts │ ├── train_resnet18.sh │ ├── train_resnet18_a8_cdp.sh │ ├── train_resnet18_cdp.sh │ ├── train_resnet50.sh │ └── train_resnet50_cdp.sh └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── models │ ├── q_inceptionv3.py │ ├── q_mobilenetv2.py │ └── q_resnet.py │ └── quantization_utils │ ├── quant_modules.py │ └── quant_utils.py ├── README.md ├── mixed_bit ├── ORM.py ├── cdp.py ├── data_providers │ ├── __init__.py │ └── imagenet.py ├── feature_extract.py ├── feature_extract_cdp.py ├── models │ ├── __init__.py │ ├── base_models.py │ ├── mobilenet_imagenet.py │ └── resnet_imagenet.py ├── run_manager.py ├── run_scripts │ ├── PTQ │ │ ├── quant_mobilenetv2.sh │ │ ├── quant_mobilenetv2_2.sh │ │ └── quant_resnet18.sh │ └── QAT │ │ ├── quant_resnet18.sh │ │ └── quant_resnet50.sh └── utils │ ├── __init__.py │ ├── get_data_iter.py │ └── pytorch_utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ## Project specific 2 | mixed_bit/Exp_base/ 3 | QAT/saved_quant_model/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | -------------------------------------------------------------------------------- /PTQ/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import torchvision.datasets as datasets 6 | 7 | 8 | def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4): 9 | print('==> Using Pytorch Dataset') 10 | 11 | traindir = os.path.join(data_path, 'train') 12 | valdir = os.path.join(data_path, 'val') 13 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 14 | std=[0.229, 0.224, 0.225]) 15 | 16 | torchvision.set_image_backend('accimage') 17 | train_dataset = datasets.ImageFolder( 18 | traindir, 19 | transforms.Compose([ 20 | transforms.RandomResizedCrop(input_size), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | normalize, 24 | ])) 25 | 26 | train_loader = torch.utils.data.DataLoader( 27 | train_dataset, batch_size=batch_size, shuffle=True, 28 | num_workers=workers, pin_memory=True) 29 | val_loader = torch.utils.data.DataLoader( 30 | datasets.ImageFolder(valdir, transforms.Compose([ 31 | transforms.Resize(256), 32 | transforms.CenterCrop(input_size), 33 | transforms.ToTensor(), 34 | normalize, 35 | ])), 36 | batch_size=batch_size, shuffle=False, 37 | num_workers=workers, pin_memory=True) 38 | return train_loader, val_loader 39 | -------------------------------------------------------------------------------- /PTQ/hubconf.py: -------------------------------------------------------------------------------- 1 | from models.resnet import resnet18 as _resnet18 2 | from models.mobilenetv2 import mobilenetv2 as _mobilenetv2 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | dependencies = ['torch'] 7 | 8 | def resnet18(pretrained=False, **kwargs): 9 | # Call the model, load pretrained weights 10 | model = _resnet18(**kwargs) 11 | if pretrained: 12 | load_url = 'https://github.com/yhhhli/BRECQ/releases/download/v1.0/resnet18_imagenet.pth.tar' 13 | checkpoint = load_state_dict_from_url(url=load_url, map_location='cpu', progress=True) 14 | model.load_state_dict(checkpoint) 15 | return model 16 | 17 | def mobilenetv2(pretrained=False, **kwargs): 18 | # Call the model, load pretrained weights 19 | model = _mobilenetv2(**kwargs) 20 | if pretrained: 21 | load_url = 'https://github.com/yhhhli/BRECQ/releases/download/v1.0/mobilenetv2.pth.tar' 22 | checkpoint = load_state_dict_from_url(url=load_url, map_location='cpu', progress=True) 23 | model.load_state_dict(checkpoint['model']) 24 | return model 25 | 26 | 27 | -------------------------------------------------------------------------------- /PTQ/main_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import argparse 4 | import os 5 | import random 6 | import numpy as np 7 | import time 8 | import hubconf 9 | from quant import * 10 | from data.imagenet import build_imagenet_data 11 | 12 | 13 | def seed_all(seed=1029): 14 | random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 20 | torch.backends.cudnn.benchmark = False 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | class AverageMeter(object): 25 | """Computes and stores the average and current value""" 26 | def __init__(self, name, fmt=':f'): 27 | self.name = name 28 | self.fmt = fmt 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def update(self, val, n=1): 38 | self.val = val 39 | self.sum += val * n 40 | self.count += n 41 | self.avg = self.sum / self.count 42 | 43 | def __str__(self): 44 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 45 | return fmtstr.format(**self.__dict__) 46 | 47 | 48 | class ProgressMeter(object): 49 | def __init__(self, num_batches, meters, prefix=""): 50 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 51 | self.meters = meters 52 | self.prefix = prefix 53 | 54 | def display(self, batch): 55 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 56 | entries += [str(meter) for meter in self.meters] 57 | print('\t'.join(entries)) 58 | 59 | def _get_batch_fmtstr(self, num_batches): 60 | num_digits = len(str(num_batches // 1)) 61 | fmt = '{:' + str(num_digits) + 'd}' 62 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 63 | 64 | 65 | def accuracy(output, target, topk=(1,)): 66 | """Computes the accuracy over the k top predictions for the specified values of k""" 67 | with torch.no_grad(): 68 | maxk = max(topk) 69 | batch_size = target.size(0) 70 | 71 | _, pred = output.topk(maxk, 1, True, True) 72 | pred = pred.t() 73 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 74 | 75 | res = [] 76 | for k in topk: 77 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 78 | res.append(correct_k.mul_(100.0 / batch_size)) 79 | return res 80 | 81 | 82 | @torch.no_grad() 83 | def validate_model(val_loader, model, device=None, print_freq=100): 84 | if device is None: 85 | device = next(model.parameters()).device 86 | else: 87 | model.to(device) 88 | batch_time = AverageMeter('Time', ':6.3f') 89 | top1 = AverageMeter('Acc@1', ':6.2f') 90 | top5 = AverageMeter('Acc@5', ':6.2f') 91 | progress = ProgressMeter( 92 | len(val_loader), 93 | [batch_time, top1, top5], 94 | prefix='Test: ') 95 | 96 | # switch to evaluate mode 97 | model.eval() 98 | 99 | end = time.time() 100 | for i, (images, target) in enumerate(val_loader): 101 | images = images.to(device) 102 | target = target.to(device) 103 | 104 | # compute output 105 | output = model(images) 106 | 107 | # measure accuracy and record loss 108 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 109 | top1.update(acc1[0], images.size(0)) 110 | top5.update(acc5[0], images.size(0)) 111 | 112 | # measure elapsed time 113 | batch_time.update(time.time() - end) 114 | end = time.time() 115 | 116 | if i % print_freq == 0: 117 | progress.display(i) 118 | 119 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) 120 | 121 | return top1.avg 122 | 123 | 124 | def get_train_samples(train_loader, num_samples): 125 | train_data = 
[] 126 | for batch in train_loader: 127 | train_data.append(batch[0]) 128 | if len(train_data) * batch[0].size(0) >= num_samples: 129 | break 130 | return torch.cat(train_data, dim=0)[:num_samples] 131 | 132 | 133 | if __name__ == '__main__': 134 | 135 | parser = argparse.ArgumentParser(description='running parameters', 136 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 137 | 138 | # general parameters for data and model 139 | parser.add_argument('--seed', default=3, type=int, help='random seed for results reproduction') #1005 140 | parser.add_argument('--arch', default='resnet18', type=str, help='model architecture', 141 | choices=['resnet18', 'mobilenetv2']) 142 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size for data loader') 143 | parser.add_argument('--workers', default=4, type=int, help='number of workers for data loader') 144 | parser.add_argument('--data_path', default='', type=str, help='path to ImageNet data', required=True) 145 | parser.add_argument('--gpu', help='gpu available', default='0') 146 | 147 | # quantization parameters 148 | parser.add_argument('--n_bits_w', default=4, type=int, help='bitwidth for weight quantization') 149 | parser.add_argument('--channel_wise', action='store_true', help='apply channel_wise quantization for weights') 150 | parser.add_argument('--n_bits_a', default=4, type=int, help='bitwidth for activation quantization') 151 | parser.add_argument('--act_quant', action='store_true', help='apply activation quantization') 152 | parser.add_argument('--disable_8bit_head_stem', action='store_true') 153 | parser.add_argument('--test_before_calibration', action='store_true') 154 | parser.add_argument('--bit_cfg', type=str, default="None") # mixed precision 155 | 156 | # weight calibration parameters 157 | parser.add_argument('--num_samples', default=1024, type=int, help='size of the calibration dataset') 158 | parser.add_argument('--iters_w', default=20000, type=int, help='number of iterations for AdaRound') 159 | parser.add_argument('--weight', default=0.01, type=float, help='weight of rounding cost vs the reconstruction loss.') 160 | parser.add_argument('--sym', action='store_true', help='symmetric reconstruction, not recommended') 161 | parser.add_argument('--b_start', default=20, type=int, help='temperature at the beginning of calibration') 162 | parser.add_argument('--b_end', default=2, type=int, help='temperature at the end of calibration') 163 | parser.add_argument('--warmup', default=0.2, type=float, help='in the warmup period no regularization is applied') 164 | parser.add_argument('--step', default=20, type=int, help='record snn output per step') 165 | parser.add_argument('--use_bias', action='store_true', help='fix weight bias and variance after quantization') # newly added 166 | parser.add_argument('--vcorr', action='store_true', help='use variance correction') # newly added 167 | parser.add_argument('--bcorr', action='store_true', help='use bias correction') # newly added 168 | 169 | # activation calibration parameters 170 | parser.add_argument('--iters_a', default=5000, type=int, help='number of iterations for LSQ') 171 | parser.add_argument('--lr', default=4e-4, type=float, help='learning rate for LSQ') 172 | parser.add_argument('--p', default=2.4, type=float, help='L_p norm minimization for LSQ') 173 | 174 | args = parser.parse_args() 175 | 176 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 177 | seed_all(args.seed) 178 | # torch.backends.cudnn.benchmark = True 179 | # build imagenet data loader 180 | train_loader, test_loader = 
build_imagenet_data(batch_size=args.batch_size, workers=args.workers, 181 | data_path=args.data_path) 182 | 183 | 184 | # load model 185 | cnn = eval('hubconf.{}(pretrained=True)'.format(args.arch)) 186 | cnn.cuda() 187 | cnn.eval() 188 | # print('Quantized accuracy before brecq: {}'.format(validate_model(test_loader, cnn))) 189 | # build quantization parameters 190 | wq_params = {'n_bits': args.n_bits_w, 'channel_wise': args.channel_wise, 'scale_method': 'mse'} 191 | aq_params = {'n_bits': args.n_bits_a, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.act_quant} 192 | qnn = QuantModel(model=cnn, weight_quant_params=wq_params, act_quant_params=aq_params) 193 | qnn.cuda() 194 | qnn.eval() 195 | if not args.disable_8bit_head_stem: 196 | print('Setting the first and the last layer to 8-bit') 197 | qnn.set_first_last_layer_to_8bit() 198 | 199 | if args.bit_cfg != "None": # newly added 200 | print('Setting each layer to a different bitwidth') 201 | qnn.set_mixed_precision(eval(args.bit_cfg)) 202 | 203 | cali_data = get_train_samples(train_loader, num_samples=args.num_samples) 204 | device = next(qnn.parameters()).device 205 | 206 | 207 | # Initialize weight quantization parameters 208 | qnn.set_quant_state(True, False) 209 | _ = qnn(cali_data[:64].to(device)) 210 | 211 | if args.test_before_calibration: 212 | # qnn.set_bias_state(args.use_bias, args.vcorr, args.bcorr) 213 | print('Quantized accuracy before brecq: {}'.format(validate_model(test_loader, qnn))) 214 | # qnn.set_bias_state(False, False, False) 215 | 216 | 217 | # Kwargs for weight rounding calibration 218 | kwargs = dict(cali_data=cali_data, iters=args.iters_w, weight=args.weight, asym=True, 219 | b_range=(args.b_start, args.b_end), warmup=args.warmup, act_quant=False, opt_mode='mse') 220 | 221 | 222 | def recon_model(model: nn.Module): 223 | """ 224 | Block reconstruction. For the first and last layers, we can only apply layer reconstruction.
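Other (container) modules are recursed into until a QuantModule or BaseQuantBlock is found.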
225 | """ 226 | for name, module in model.named_children(): 227 | if isinstance(module, QuantModule): 228 | if module.ignore_reconstruction is True: 229 | print('Ignore reconstruction of layer {}'.format(name)) 230 | continue 231 | else: 232 | print('Reconstruction for layer {}'.format(name)) 233 | layer_reconstruction(qnn, module, **kwargs) 234 | elif isinstance(module, BaseQuantBlock): 235 | if module.ignore_reconstruction is True: 236 | print('Ignore reconstruction of block {}'.format(name)) 237 | continue 238 | else: 239 | print('Reconstruction for block {}'.format(name)) 240 | block_reconstruction(qnn, module, **kwargs) 241 | else: 242 | recon_model(module) 243 | 244 | # Start calibration 245 | recon_model(qnn) 246 | qnn.set_quant_state(weight_quant=True, act_quant=False) 247 | qnn.set_bias_state(args.use_bias, args.vcorr, args.bcorr) #新增 248 | print('Weight quantization accuracy: {}'.format(validate_model(test_loader, qnn))) 249 | qnn.set_bias_state(False, False, False) #新增 250 | 251 | if args.act_quant: 252 | # Initialize activation quantization parameters 253 | qnn.set_quant_state(True, True) 254 | with torch.no_grad(): 255 | _ = qnn(cali_data[:64].to(device)) 256 | # Disable output quantization because network output 257 | # does not get involved in further computation 258 | qnn.disable_network_output_quantization() 259 | # Kwargs for activation rounding calibration 260 | kwargs = dict(cali_data=cali_data, iters=args.iters_a, act_quant=True, opt_mode='mse', lr=args.lr, p=args.p) 261 | recon_model(qnn) 262 | qnn.set_quant_state(weight_quant=True, act_quant=True) 263 | qnn.set_bias_state(args.use_bias, args.vcorr, args.bcorr) #新增 264 | print('Full quantization (W{}A{}) accuracy: {}'.format(args.n_bits_w, args.n_bits_a, 265 | validate_model(test_loader, qnn))) 266 | qnn.set_bias_state(False, False, False) #新增 -------------------------------------------------------------------------------- /PTQ/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmkai/CSMPQ/9150057ecec5935daaaa99e8ea08df97edcb59fa/PTQ/models/__init__.py -------------------------------------------------------------------------------- /PTQ/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | 3 | import torch.nn as nn 4 | import math 5 | import torch 6 | 7 | 8 | def conv_bn(inp, oup, stride): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | nn.BatchNorm2d(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def conv_1x1_bn(inp, oup): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU6(inplace=True) 21 | ) 22 | 23 | 24 | class InvertedResidual(nn.Module): 25 | def __init__(self, inp, oup, stride, expand_ratio): 26 | super(InvertedResidual, self).__init__() 27 | self.stride = stride 28 | assert stride in [1, 2] 29 | 30 | hidden_dim = round(inp * expand_ratio) 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | self.expand_ratio = expand_ratio 33 | if expand_ratio == 1: 34 | self.conv = nn.Sequential( 35 | # dw 36 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 37 | nn.BatchNorm2d(hidden_dim), 38 | nn.ReLU6(inplace=True), 39 | # pw-linear 40 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 41 | nn.BatchNorm2d(oup), 42 | ) 43 | else: 44 | self.conv = nn.Sequential( 45 | # pw 46 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 
bias=False), 47 | nn.BatchNorm2d(hidden_dim), 48 | nn.ReLU6(inplace=True), 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 51 | nn.BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 55 | nn.BatchNorm2d(oup), 56 | ) 57 | 58 | def forward(self, x): 59 | if self.use_res_connect: 60 | return x + self.conv(x) 61 | else: 62 | return self.conv(x) 63 | 64 | 65 | class MobileNetV2(nn.Module): 66 | def __init__(self, n_class=1000, input_size=224, width_mult=1., dropout=0.0): 67 | super(MobileNetV2, self).__init__() 68 | block = InvertedResidual 69 | input_channel = 32 70 | last_channel = 1280 71 | interverted_residual_setting = [ 72 | # t, c, n, s 73 | [1, 16, 1, 1], 74 | [6, 24, 2, 2], 75 | [6, 32, 3, 2], 76 | [6, 64, 4, 2], 77 | [6, 96, 3, 1], 78 | [6, 160, 3, 2], 79 | [6, 320, 1, 1], 80 | ] 81 | 82 | # building first layer 83 | assert input_size % 32 == 0 84 | input_channel = int(input_channel * width_mult) 85 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 86 | self.features = [conv_bn(3, input_channel, 2)] 87 | # building inverted residual blocks 88 | for t, c, n, s in interverted_residual_setting: 89 | output_channel = int(c * width_mult) 90 | for i in range(n): 91 | if i == 0: 92 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 93 | else: 94 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 95 | input_channel = output_channel 96 | # building last several layers 97 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 98 | # self.features.append(nn.AvgPool2d(input_size // 32)) 99 | # make it nn.Sequential 100 | self.features = nn.Sequential(*self.features) 101 | 102 | # building classifier 103 | self.classifier = nn.Sequential( 104 | nn.Dropout(dropout), 105 | nn.Linear(self.last_channel, n_class), 106 | ) 107 | 108 | self._initialize_weights() 109 | 110 | def forward(self, x): 111 | x = self.features(x) 112 | x = x.mean([2, 3]) 113 | x = self.classifier(x) 114 | return x 115 | 116 | def _initialize_weights(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. / n)) 121 | if m.bias is not None: 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | n = m.weight.size(1) 128 | m.weight.data.normal_(0, 0.01) 129 | m.bias.data.zero_() 130 | 131 | 132 | def mobilenetv2(**kwargs): 133 | """ 134 | Constructs a MobileNetV2 model. 
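    A minimal usage sketch (keyword arguments such as n_class, input_size,
    width_mult and dropout are simply forwarded to MobileNetV2):

    >>> net = mobilenetv2(width_mult=1.0)
    >>> net(torch.randn(1, 3, 224, 224)).shape
    torch.Size([1, 1000])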
135 | """ 136 | model = MobileNetV2(**kwargs) 137 | return model -------------------------------------------------------------------------------- /PTQ/models/resnet.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=dilation, groups=groups, bias=False, dilation=dilation) 15 | 16 | 17 | def conv1x1(in_planes, out_planes, stride=1): 18 | """1x1 convolution""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | __constants__ = ['downsample'] 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 27 | base_width=64, dilation=1, norm_layer=None): 28 | super(BasicBlock, self).__init__() 29 | if norm_layer is None: 30 | norm_layer = BN 31 | if groups != 1 or base_width != 64: 32 | raise ValueError( 33 | 'BasicBlock only supports groups=1 and base_width=64') 34 | if dilation > 1: 35 | raise NotImplementedError( 36 | "Dilation > 1 not supported in BasicBlock") 37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = norm_layer(planes) 40 | self.relu1 = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = norm_layer(planes) 43 | self.downsample = downsample 44 | self.relu2 = nn.ReLU(inplace=True) 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | identity = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu1(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu2(out) 62 | 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | __constants__ = ['downsample'] 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, dilation=1, norm_layer=None): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = BN 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.relu1 = nn.ReLU(inplace=True) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.relu2 = nn.ReLU(inplace=True) 83 | self.conv3 = conv1x1(width, planes * self.expansion) 84 | self.bn3 = norm_layer(planes * self.expansion) 85 | self.relu3 = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu1(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu2(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = 
self.relu3(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, 115 | block, 116 | layers, 117 | num_classes=1000, 118 | zero_init_residual=False, 119 | groups=1, 120 | width_per_group=64, 121 | replace_stride_with_dilation=None, 122 | deep_stem=False, 123 | avg_down=False): 124 | 125 | super(ResNet, self).__init__() 126 | 127 | global BN 128 | 129 | BN = torch.nn.BatchNorm2d 130 | norm_layer = BN 131 | 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | self.deep_stem = deep_stem 137 | self.avg_down = avg_down 138 | 139 | if replace_stride_with_dilation is None: 140 | # each element in the tuple indicates if we should replace 141 | # the 2x2 stride with a dilated convolution instead 142 | replace_stride_with_dilation = [False, False, False] 143 | if len(replace_stride_with_dilation) != 3: 144 | raise ValueError("replace_stride_with_dilation should be None " 145 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 146 | self.groups = groups 147 | self.base_width = width_per_group 148 | 149 | if self.deep_stem: 150 | self.conv1 = nn.Sequential( 151 | nn.Conv2d(3, 32, kernel_size=3, stride=2, 152 | padding=1, bias=False), 153 | norm_layer(32), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(32, 32, kernel_size=3, stride=1, 156 | padding=1, bias=False), 157 | norm_layer(32), 158 | nn.ReLU(inplace=True), 159 | nn.Conv2d(32, 64, kernel_size=3, stride=1, 160 | padding=1, bias=False), 161 | ) 162 | else: 163 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 164 | stride=2, padding=3, bias=False) 165 | 166 | self.bn1 = norm_layer(self.inplanes) 167 | self.relu = nn.ReLU(inplace=True) 168 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 169 | self.layer1 = self._make_layer(block, 64, layers[0]) 170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 171 | dilate=replace_stride_with_dilation[0]) 172 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 173 | dilate=replace_stride_with_dilation[1]) 174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 175 | dilate=replace_stride_with_dilation[2]) 176 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 177 | self.fc = nn.Linear(512 * block.expansion, num_classes) 178 | 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | nn.init.kaiming_normal_( 182 | m.weight, mode='fan_out', nonlinearity='relu') 183 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 184 | nn.init.constant_(m.weight, 1) 185 | nn.init.constant_(m.bias, 0) 186 | 187 | # Zero-initialize the last BN in each residual branch, 188 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
189 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 190 | if zero_init_residual: 191 | for m in self.modules(): 192 | if isinstance(m, Bottleneck): 193 | nn.init.constant_(m.bn3.weight, 0) 194 | elif isinstance(m, BasicBlock): 195 | nn.init.constant_(m.bn2.weight, 0) 196 | 197 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 198 | norm_layer = self._norm_layer 199 | downsample = None 200 | previous_dilation = self.dilation 201 | if dilate: 202 | self.dilation *= stride 203 | stride = 1 204 | if stride != 1 or self.inplanes != planes * block.expansion: 205 | if self.avg_down: 206 | downsample = nn.Sequential( 207 | nn.AvgPool2d(stride, stride=stride, 208 | ceil_mode=True, count_include_pad=False), 209 | conv1x1(self.inplanes, planes * block.expansion), 210 | norm_layer(planes * block.expansion), 211 | ) 212 | else: 213 | downsample = nn.Sequential( 214 | conv1x1(self.inplanes, planes * block.expansion, stride), 215 | norm_layer(planes * block.expansion), 216 | ) 217 | 218 | layers = [] 219 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 220 | self.base_width, previous_dilation, norm_layer)) 221 | self.inplanes = planes * block.expansion 222 | for _ in range(1, blocks): 223 | layers.append(block(self.inplanes, planes, groups=self.groups, 224 | base_width=self.base_width, dilation=self.dilation, 225 | norm_layer=norm_layer)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def _forward_impl(self, x): 230 | # See note [TorchScript super()] 231 | x = self.conv1(x) 232 | x = self.bn1(x) 233 | x = self.relu(x) 234 | x = self.maxpool(x) 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | 240 | x = self.avgpool(x) 241 | x = torch.flatten(x, 1) 242 | x = self.fc(x) 243 | 244 | return x 245 | 246 | def forward(self, x): 247 | return self._forward_impl(x) 248 | 249 | 250 | def resnet18(**kwargs): 251 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 252 | return model 253 | 254 | 255 | def resnet34(**kwargs): 256 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 257 | return model 258 | 259 | 260 | def resnet50(**kwargs): 261 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 262 | return model 263 | 264 | 265 | def resnet101(**kwargs): 266 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 267 | return model 268 | 269 | 270 | def resnet152(**kwargs): 271 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 272 | return model 273 | 274 | 275 | def resnext50_32x4d(**kwargs): 276 | kwargs['groups'] = 32 277 | kwargs['width_per_group'] = 4 278 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 279 | return model 280 | 281 | 282 | def resnext101_32x8d(**kwargs): 283 | kwargs['groups'] = 32 284 | kwargs['width_per_group'] = 8 285 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 286 | return model 287 | 288 | 289 | def wide_resnet50_2(**kwargs): 290 | kwargs['width_per_group'] = 64 * 2 291 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 292 | return model 293 | 294 | 295 | def wide_resnet101_2(**kwargs): 296 | kwargs['width_per_group'] = 64 * 2 297 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 298 | return model 299 | -------------------------------------------------------------------------------- /PTQ/quant/__init__.py: -------------------------------------------------------------------------------- 1 | from quant.block_recon import block_reconstruction 2 | from quant.layer_recon import layer_reconstruction 3 | from 
quant.quant_block import BaseQuantBlock 4 | from quant.quant_layer import QuantModule 5 | from quant.quant_model import QuantModel 6 | -------------------------------------------------------------------------------- /PTQ/quant/adaptive_rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from quant.quant_layer import UniformAffineQuantizer, round_ste 4 | 5 | 6 | class AdaRoundQuantizer(nn.Module): 7 | """ 8 | Adaptive Rounding Quantizer, used to optimize the rounding policy 9 | by reconstructing the intermediate output. 10 | Based on 11 | Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568 12 | 13 | :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer 14 | :param round_mode: controls the forward pass in this quantizer 15 | :param weight_tensor: initialize alpha 16 | """ 17 | 18 | def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_round_sigmoid'): 19 | super(AdaRoundQuantizer, self).__init__() 20 | # copying all attributes from UniformAffineQuantizer 21 | self.n_bits = uaq.n_bits 22 | self.sym = uaq.sym 23 | self.delta = uaq.delta 24 | self.zero_point = uaq.zero_point 25 | self.n_levels = uaq.n_levels 26 | 27 | self.round_mode = round_mode 28 | self.alpha = None 29 | self.soft_targets = False 30 | 31 | # params for sigmoid function 32 | self.gamma, self.zeta = -0.1, 1.1 33 | self.beta = 2/3 34 | self.init_alpha(x=weight_tensor.clone()) 35 | 36 | def forward(self, x): 37 | if self.round_mode == 'nearest': 38 | x_int = torch.round(x / self.delta) 39 | elif self.round_mode == 'nearest_ste': 40 | x_int = round_ste(x / self.delta) 41 | elif self.round_mode == 'stochastic': 42 | x_floor = torch.floor(x / self.delta) 43 | rest = (x / self.delta) - x_floor # rest of rounding 44 | x_int = x_floor + torch.bernoulli(rest) 45 | print('Draw stochastic sample') 46 | elif self.round_mode == 'learned_hard_sigmoid': 47 | x_floor = torch.floor(x / self.delta) 48 | if self.soft_targets: 49 | x_int = x_floor + self.get_soft_targets() 50 | else: 51 | x_int = x_floor + (self.alpha >= 0).float() 52 | else: 53 | raise ValueError('Wrong rounding mode') 54 | 55 | x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1) 56 | x_float_q = (x_quant - self.zero_point) * self.delta 57 | 58 | return x_float_q 59 | 60 | def get_soft_targets(self): 61 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) 62 | 63 | def init_alpha(self, x: torch.Tensor): 64 | x_floor = torch.floor(x / self.delta) 65 | if self.round_mode == 'learned_hard_sigmoid': 66 | print('Init alpha to be FP32') 67 | rest = (x / self.delta) - x_floor # rest of rounding [0, 1) 68 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest 69 | self.alpha = nn.Parameter(alpha) 70 | else: 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /PTQ/quant/block_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from quant.quant_layer import QuantModule, StraightThrough, lp_loss 3 | from quant.quant_model import QuantModel 4 | from quant.quant_block import BaseQuantBlock 5 | from quant.adaptive_rounding import AdaRoundQuantizer 6 | from quant.data_utils import save_grad_data, save_inp_oup_data 7 | 8 | 9 | def block_reconstruction(model: 
QuantModel, block: BaseQuantBlock, cali_data: torch.Tensor, 10 | batch_size: int = 32, iters: int = 20000, weight: float = 0.01, opt_mode: str = 'mse', 11 | asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2), 12 | warmup: float = 0.0, act_quant: bool = False, lr: float = 4e-5, p: float = 2.0): 13 | """ 14 | Block reconstruction to optimize the output from each block. 15 | 16 | :param model: QuantModel 17 | :param block: BaseQuantBlock that needs to be optimized 18 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound 19 | :param batch_size: mini-batch size for reconstruction 20 | :param iters: optimization iterations for reconstruction, 21 | :param weight: the weight of rounding regularization term 22 | :param opt_mode: optimization mode 23 | :param asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output 24 | :param include_act_func: optimize the output after activation function 25 | :param b_range: temperature range 26 | :param warmup: proportion of iterations that no scheduling for temperature 27 | :param act_quant: use activation quantization or not. 28 | :param lr: learning rate for act delta learning 29 | :param p: L_p norm minimization 30 | """ 31 | model.set_quant_state(False, False) 32 | block.set_quant_state(True, act_quant) 33 | round_mode = 'learned_hard_sigmoid' 34 | 35 | if not include_act_func: 36 | org_act_func = block.activation_function 37 | block.activation_function = StraightThrough() 38 | 39 | if not act_quant: 40 | # Replace weight quantizer to AdaRoundQuantizer 41 | for name, module in block.named_modules(): 42 | if isinstance(module, QuantModule): 43 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode, 44 | weight_tensor=module.org_weight.data) 45 | module.weight_quantizer.soft_targets = True 46 | 47 | # Set up optimizer 48 | opt_params = [] 49 | for name, module in block.named_modules(): 50 | if isinstance(module, QuantModule): 51 | opt_params += [module.weight_quantizer.alpha] 52 | optimizer = torch.optim.Adam(opt_params) 53 | scheduler = None 54 | else: 55 | # Use UniformAffineQuantizer to learn delta 56 | if hasattr(block.act_quantizer, 'delta'): 57 | opt_params = [block.act_quantizer.delta] 58 | else: 59 | opt_params = [] 60 | for name, module in block.named_modules(): 61 | if isinstance(module, QuantModule): 62 | if module.act_quantizer.delta is not None: 63 | opt_params += [module.act_quantizer.delta] 64 | optimizer = torch.optim.Adam(opt_params, lr=lr) 65 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.) 
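    # The two branches above optimize different parameters: the weight pass
    # learns the AdaRound variable `alpha` of every QuantModule with plain
    # Adam, while the activation pass learns each quantizer's step size
    # `delta` with Adam plus cosine learning-rate annealing down to zero.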
66 | 67 | loss_mode = 'none' if act_quant else 'relaxation' 68 | rec_loss = opt_mode 69 | 70 | loss_func = LossFunction(block, round_loss=loss_mode, weight=weight, max_count=iters, rec_loss=rec_loss, 71 | b_range=b_range, decay_start=0, warmup=warmup, p=p) 72 | 73 | # Save data before optimizing the rounding 74 | cached_inps, cached_outs = save_inp_oup_data(model, block, cali_data, asym, act_quant, batch_size) 75 | if opt_mode != 'mse': 76 | cached_grads = save_grad_data(model, block, cali_data, act_quant=act_quant, batch_size=batch_size) 77 | else: 78 | cached_grads = None 79 | device = 'cuda' 80 | for i in range(iters): 81 | idx = torch.randperm(cached_inps.size(0))[:batch_size] 82 | cur_inp = cached_inps[idx].to(device) 83 | cur_out = cached_outs[idx].to(device) 84 | cur_grad = cached_grads[idx].to(device) if opt_mode != 'mse' else None 85 | 86 | optimizer.zero_grad() 87 | out_quant = block(cur_inp) 88 | 89 | err = loss_func(out_quant, cur_out, cur_grad) 90 | 91 | err.backward(retain_graph=True) 92 | optimizer.step() 93 | if scheduler: 94 | scheduler.step() 95 | 96 | torch.cuda.empty_cache() 97 | 98 | # Finish optimization, use hard rounding. 99 | for name, module in block.named_modules(): 100 | if isinstance(module, QuantModule): 101 | module.weight_quantizer.soft_targets = False 102 | 103 | # Reset original activation function 104 | if not include_act_func: 105 | block.activation_function = org_act_func 106 | 107 | 108 | class LossFunction: 109 | def __init__(self, 110 | block: BaseQuantBlock, 111 | round_loss: str = 'relaxation', 112 | weight: float = 1., 113 | rec_loss: str = 'mse', 114 | max_count: int = 2000, 115 | b_range: tuple = (10, 2), 116 | decay_start: float = 0.0, 117 | warmup: float = 0.0, 118 | p: float = 2.): 119 | 120 | self.block = block 121 | self.round_loss = round_loss 122 | self.weight = weight 123 | self.rec_loss = rec_loss 124 | self.loss_start = max_count * warmup 125 | self.p = p 126 | 127 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start, 128 | start_b=b_range[0], end_b=b_range[1]) 129 | self.count = 0 130 | 131 | def __call__(self, pred, tgt, grad=None): 132 | """ 133 | Compute the total loss for adaptive rounding: 134 | rec_loss is the quadratic output reconstruction loss, round_loss is 135 | a regularization term to optimize the rounding policy 136 | 137 | :param pred: output from quantized model 138 | :param tgt: output from FP model 139 | :param grad: gradients to compute Fisher information 140 | :return: total loss function 141 | """ 142 | self.count += 1 143 | if self.rec_loss == 'mse': 144 | rec_loss = lp_loss(pred, tgt, p=self.p) 145 | elif self.rec_loss == 'fisher_diag': 146 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() 147 | elif self.rec_loss == 'fisher_full': 148 | a = (pred - tgt).abs() 149 | grad = grad.abs() 150 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) 151 | rec_loss = (batch_dotprod * a * grad).mean() / 100 152 | else: 153 | raise ValueError('Unsupported reconstruction loss function: {}'.format(self.rec_loss)) 154 | 155 | b = self.temp_decay(self.count) 156 | if self.count < self.loss_start or self.round_loss == 'none': 157 | b = round_loss = 0 158 | elif self.round_loss == 'relaxation': 159 | round_loss = 0 160 | for name, module in self.block.named_modules(): 161 | if isinstance(module, QuantModule): 162 | round_vals = module.weight_quantizer.get_soft_targets() 163 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() 164 | else:
165 | raise NotImplementedError 166 | 167 | total_loss = rec_loss + round_loss 168 | if self.count % 500 == 0: 169 | print('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format( 170 | float(total_loss), float(rec_loss), float(round_loss), b, self.count)) 171 | return total_loss 172 | 173 | 174 | class LinearTempDecay: 175 | def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2): 176 | self.t_max = t_max 177 | self.start_decay = rel_start_decay * t_max 178 | self.start_b = start_b 179 | self.end_b = end_b 180 | 181 | def __call__(self, t): 182 | """ 183 | Linear decay schedule for the temperature b. 184 | :param t: the current time step 185 | :return: scheduled temperature 186 | """ 187 | if t < self.start_decay: 188 | return self.start_b 189 | else: 190 | rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) 191 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) 192 | -------------------------------------------------------------------------------- /PTQ/quant/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from quant.quant_layer import QuantModule, Union 4 | from quant.quant_model import QuantModel 5 | from quant.quant_block import BaseQuantBlock 6 | 7 | 8 | def save_inp_oup_data(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], cali_data: torch.Tensor, 9 | asym: bool = False, act_quant: bool = False, batch_size: int = 32, keep_gpu: bool = True): 10 | """ 11 | Save input data and output data of a particular layer/block over calibration dataset. 12 | 13 | :param model: QuantModel 14 | :param layer: QuantModule or QuantBlock 15 | :param cali_data: calibration data set 16 | :param asym: if True, save quantized input and full precision output 17 | :param act_quant: use activation quantization 18 | :param batch_size: mini-batch size for calibration 19 | :param keep_gpu: put saved data on GPU for faster optimization 20 | :return: input and output data 21 | """ 22 | device = next(model.parameters()).device 23 | get_inp_out = GetLayerInpOut(model, layer, device=device, asym=asym, act_quant=act_quant) 24 | cached_batches = [] 25 | torch.cuda.empty_cache() 26 | 27 | for i in range(int(cali_data.size(0) / batch_size)): 28 | cur_inp, cur_out = get_inp_out(cali_data[i * batch_size:(i + 1) * batch_size]) 29 | cached_batches.append((cur_inp.cpu(), cur_out.cpu())) 30 | 31 | cached_inps = torch.cat([x[0] for x in cached_batches]) 32 | cached_outs = torch.cat([x[1] for x in cached_batches]) 33 | torch.cuda.empty_cache() 34 | if keep_gpu: 35 | cached_inps = cached_inps.to(device) 36 | cached_outs = cached_outs.to(device) 37 | return cached_inps, cached_outs 38 | 39 | 40 | def save_grad_data(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], cali_data: torch.Tensor, 41 | damping: float = 1., act_quant: bool = False, batch_size: int = 32, 42 | keep_gpu: bool = True): 43 | """ 44 | Save gradient data of a particular layer/block over calibration dataset.
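    The gradients are obtained by backpropagating the KL divergence between
    the full-precision and the quantized model outputs (see GetLayerGrad below).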
45 | 46 | :param model: QuantModel 47 | :param layer: QuantModule or QuantBlock 48 | :param cali_data: calibration data set 49 | :param damping: damping the second-order gradient by adding some constant in the FIM diagonal 50 | :param act_quant: use activation quantization 51 | :param batch_size: mini-batch size for calibration 52 | :param keep_gpu: put saved data on GPU for faster optimization 53 | :return: gradient data 54 | """ 55 | device = next(model.parameters()).device 56 | get_grad = GetLayerGrad(model, layer, device, act_quant=act_quant) 57 | cached_batches = [] 58 | torch.cuda.empty_cache() 59 | 60 | for i in range(int(cali_data.size(0) / batch_size)): 61 | cur_grad = get_grad(cali_data[i * batch_size:(i + 1) * batch_size]) 62 | cached_batches.append(cur_grad.cpu()) 63 | 64 | cached_grads = torch.cat([x for x in cached_batches]) 65 | cached_grads = cached_grads.abs() + 1.0 66 | # scaling to make sure its mean is 1 67 | # cached_grads = cached_grads * torch.sqrt(cached_grads.numel() / cached_grads.pow(2).sum()) 68 | torch.cuda.empty_cache() 69 | if keep_gpu: 70 | cached_grads = cached_grads.to(device) 71 | return cached_grads 72 | 73 | 74 | class StopForwardException(Exception): 75 | """ 76 | Used to throw and catch an exception to stop traversing the graph 77 | """ 78 | pass 79 | 80 | 81 | class DataSaverHook: 82 | """ 83 | Forward hook that stores the input and output of a block 84 | """ 85 | def __init__(self, store_input=False, store_output=False, stop_forward=False): 86 | self.store_input = store_input 87 | self.store_output = store_output 88 | self.stop_forward = stop_forward 89 | 90 | self.input_store = None 91 | self.output_store = None 92 | 93 | def __call__(self, module, input_batch, output_batch): 94 | if self.store_input: 95 | self.input_store = input_batch 96 | if self.store_output: 97 | self.output_store = output_batch 98 | if self.stop_forward: 99 | raise StopForwardException 100 | 101 | 102 | class GetLayerInpOut: 103 | def __init__(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], 104 | device: torch.device, asym: bool = False, act_quant: bool = False): 105 | self.model = model 106 | self.layer = layer 107 | self.asym = asym 108 | self.device = device 109 | self.act_quant = act_quant 110 | self.data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=True) 111 | 112 | def __call__(self, model_input): 113 | self.model.eval() 114 | self.model.set_quant_state(False, False) 115 | 116 | handle = self.layer.register_forward_hook(self.data_saver) 117 | with torch.no_grad(): 118 | try: 119 | _ = self.model(model_input.to(self.device)) 120 | except StopForwardException: 121 | pass 122 | 123 | if self.asym: 124 | # Recalculate input with network quantized 125 | self.data_saver.store_output = False 126 | self.model.set_quant_state(weight_quant=True, act_quant=self.act_quant) 127 | try: 128 | _ = self.model(model_input.to(self.device)) 129 | except StopForwardException: 130 | pass 131 | self.data_saver.store_output = True 132 | 133 | handle.remove() 134 | 135 | self.model.set_quant_state(False, False) 136 | self.layer.set_quant_state(True, self.act_quant) 137 | self.model.train() 138 | 139 | return self.data_saver.input_store[0].detach(), self.data_saver.output_store.detach() 140 | 141 | 142 | class GradSaverHook: 143 | def __init__(self, store_grad=True): 144 | self.store_grad = store_grad 145 | self.stop_backward = False 146 | self.grad_out = None 147 | 148 | def __call__(self, module, grad_input, grad_output): 149 | if self.store_grad: 
150 | self.grad_out = grad_output[0] 151 | if self.stop_backward: 152 | raise StopForwardException 153 | 154 | 155 | class GetLayerGrad: 156 | def __init__(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], 157 | device: torch.device, act_quant: bool = False): 158 | self.model = model 159 | self.layer = layer 160 | self.device = device 161 | self.act_quant = act_quant 162 | self.data_saver = GradSaverHook(True) 163 | 164 | def __call__(self, model_input): 165 | """ 166 | Compute the gradients of the block output. Note that we compute the 167 | gradient by calculating the KL loss between the FP model and the quantized model 168 | 169 | :param model_input: calibration data samples 170 | :return: gradients 171 | """ 172 | self.model.eval() 173 | 174 | handle = self.layer.register_backward_hook(self.data_saver) 175 | with torch.enable_grad(): 176 | try: 177 | self.model.zero_grad() 178 | inputs = model_input.to(self.device) 179 | self.model.set_quant_state(False, False) 180 | out_fp = self.model(inputs) 181 | quantize_model_till(self.model, self.layer, self.act_quant) 182 | out_q = self.model(inputs) 183 | loss = F.kl_div(F.log_softmax(out_q, dim=1), F.softmax(out_fp, dim=1), reduction='batchmean') 184 | loss.backward() 185 | except StopForwardException: 186 | pass 187 | 188 | handle.remove() 189 | self.model.set_quant_state(False, False) 190 | self.layer.set_quant_state(True, self.act_quant) 191 | self.model.train() 192 | return self.data_saver.grad_out.data 193 | 194 | 195 | def quantize_model_till(model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], act_quant: bool = False): 196 | """ 197 | We assume modules are registered in execution order; this holds for all models considered 198 | :param model: quantized_model 199 | :param layer: a block or a single layer.
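    :param act_quant: if True, also enable activation quantization for the modules that are switched on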
200 | """ 201 | model.set_quant_state(False, False) 202 | for name, module in model.named_modules(): 203 | if isinstance(module, (QuantModule, BaseQuantBlock)): 204 | module.set_quant_state(True, act_quant) 205 | if module == layer: 206 | break 207 | -------------------------------------------------------------------------------- /PTQ/quant/fold_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | 6 | class StraightThrough(nn.Module): 7 | def __int__(self): 8 | super().__init__() 9 | 10 | def forward(self, input): 11 | return input 12 | 13 | 14 | def _fold_bn(conv_module, bn_module): 15 | w = conv_module.weight.data 16 | y_mean = bn_module.running_mean 17 | y_var = bn_module.running_var 18 | safe_std = torch.sqrt(y_var + bn_module.eps) 19 | w_view = (conv_module.out_channels, 1, 1, 1) 20 | if bn_module.affine: 21 | weight = w * (bn_module.weight / safe_std).view(w_view) 22 | beta = bn_module.bias - bn_module.weight * y_mean / safe_std 23 | if conv_module.bias is not None: 24 | bias = bn_module.weight * conv_module.bias / safe_std + beta 25 | else: 26 | bias = beta 27 | else: 28 | weight = w / safe_std.view(w_view) 29 | beta = -y_mean / safe_std 30 | if conv_module.bias is not None: 31 | bias = conv_module.bias / safe_std + beta 32 | else: 33 | bias = beta 34 | return weight, bias 35 | 36 | 37 | def fold_bn_into_conv(conv_module, bn_module): 38 | w, b = _fold_bn(conv_module, bn_module) 39 | if conv_module.bias is None: 40 | conv_module.bias = nn.Parameter(b) 41 | else: 42 | conv_module.bias.data = b 43 | conv_module.weight.data = w 44 | # set bn running stats 45 | bn_module.running_mean = bn_module.bias.data 46 | bn_module.running_var = bn_module.weight.data ** 2 47 | 48 | 49 | def reset_bn(module: nn.BatchNorm2d): 50 | if module.track_running_stats: 51 | module.running_mean.zero_() 52 | module.running_var.fill_(1-module.eps) 53 | # we do not reset numer of tracked batches here 54 | # self.num_batches_tracked.zero_() 55 | if module.affine: 56 | init.ones_(module.weight) 57 | init.zeros_(module.bias) 58 | 59 | 60 | def is_bn(m): 61 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 62 | 63 | 64 | def is_absorbing(m): 65 | return (isinstance(m, nn.Conv2d)) or isinstance(m, nn.Linear) 66 | 67 | 68 | def search_fold_and_remove_bn(model): 69 | model.eval() 70 | prev = None 71 | for n, m in model.named_children(): 72 | if is_bn(m) and is_absorbing(prev): 73 | fold_bn_into_conv(prev, m) 74 | # set the bn module to straight through 75 | setattr(model, n, StraightThrough()) 76 | elif is_absorbing(m): 77 | prev = m 78 | else: 79 | prev = search_fold_and_remove_bn(m) 80 | return prev 81 | 82 | 83 | def search_fold_and_reset_bn(model): 84 | model.eval() 85 | prev = None 86 | for n, m in model.named_children(): 87 | if is_bn(m) and is_absorbing(prev): 88 | fold_bn_into_conv(prev, m) 89 | # reset_bn(m) 90 | else: 91 | search_fold_and_reset_bn(m) 92 | prev = m 93 | 94 | -------------------------------------------------------------------------------- /PTQ/quant/layer_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from quant.quant_layer import QuantModule, StraightThrough, lp_loss 3 | from quant.quant_model import QuantModel 4 | from quant.block_recon import LinearTempDecay 5 | from quant.adaptive_rounding import AdaRoundQuantizer 6 | from quant.data_utils import save_grad_data, save_inp_oup_data 7 | 8 | 9 
9 | def layer_reconstruction(model: QuantModel, layer: QuantModule, cali_data: torch.Tensor, 10 | batch_size: int = 32, iters: int = 20000, weight: float = 0.001, opt_mode: str = 'mse', 11 | asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2), 12 | warmup: float = 0.0, act_quant: bool = False, lr: float = 4e-5, p: float = 2.0): 13 | """ 14 | Layer reconstruction to optimize the output from each layer. 15 | 16 | :param model: QuantModel 17 | :param layer: QuantModule that needs to be optimized 18 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound 19 | :param batch_size: mini-batch size for reconstruction 20 | :param iters: optimization iterations for reconstruction 21 | :param weight: the weight of rounding regularization term 22 | :param opt_mode: optimization mode 23 | :param asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output 24 | :param include_act_func: optimize the output after activation function 25 | :param b_range: temperature range 26 | :param warmup: proportion of iterations that no scheduling for temperature 27 | :param act_quant: use activation quantization or not. 28 | :param lr: learning rate for act delta learning 29 | :param p: L_p norm minimization 30 | """ 31 | 32 | model.set_quant_state(False, False) 33 | layer.set_quant_state(True, act_quant) 34 | round_mode = 'learned_hard_sigmoid' 35 | 36 | if not include_act_func: 37 | org_act_func = layer.activation_function 38 | layer.activation_function = StraightThrough() 39 | 40 | if not act_quant: 41 | # Replace weight quantizer to AdaRoundQuantizer 42 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode, 43 | weight_tensor=layer.org_weight.data) 44 | layer.weight_quantizer.soft_targets = True 45 | 46 | # Set up optimizer 47 | opt_params = [layer.weight_quantizer.alpha] 48 | optimizer = torch.optim.Adam(opt_params) 49 | scheduler = None 50 | else: 51 | # Use UniformAffineQuantizer to learn delta 52 | opt_params = [layer.act_quantizer.delta] 53 | optimizer = torch.optim.Adam(opt_params, lr=lr) 54 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.) 55 | 56 | loss_mode = 'none' if act_quant else 'relaxation' 57 | rec_loss = opt_mode 58 | 59 | loss_func = LossFunction(layer, round_loss=loss_mode, weight=weight, 60 | max_count=iters, rec_loss=rec_loss, b_range=b_range, 61 | decay_start=0, warmup=warmup, p=p) 62 | 63 | # Save data before optimizing the rounding 64 | cached_inps, cached_outs = save_inp_oup_data(model, layer, cali_data, asym, act_quant, batch_size) 65 | if opt_mode != 'mse': 66 | cached_grads = save_grad_data(model, layer, cali_data, act_quant=act_quant, batch_size=batch_size) 67 | else: 68 | cached_grads = None 69 | device = 'cuda' 70 | for i in range(iters): 71 | idx = torch.randperm(cached_inps.size(0))[:batch_size] 72 | cur_inp = cached_inps[idx] 73 | cur_out = cached_outs[idx] 74 | cur_grad = cached_grads[idx] if opt_mode != 'mse' else None 75 | 76 | optimizer.zero_grad() 77 | out_quant = layer(cur_inp) 78 | 79 | err = loss_func(out_quant, cur_out, cur_grad) 80 | 81 | err.backward(retain_graph=True) 82 | optimizer.step() 83 | if scheduler: 84 | scheduler.step() 85 | 86 | torch.cuda.empty_cache() 87 | 88 | # Finish optimization, use hard rounding.
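    # (with soft_targets disabled, AdaRoundQuantizer.forward rounds each
    # weight hard as x_floor + (alpha >= 0) instead of the sigmoid relaxation)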
89 | layer.weight_quantizer.soft_targets = False 90 | 91 | # Reset original activation function 92 | if not include_act_func: 93 | layer.activation_function = org_act_func 94 | 95 | 96 | class LossFunction: 97 | def __init__(self, 98 | layer: QuantModule, 99 | round_loss: str = 'relaxation', 100 | weight: float = 1., 101 | rec_loss: str = 'mse', 102 | max_count: int = 2000, 103 | b_range: tuple = (10, 2), 104 | decay_start: float = 0.0, 105 | warmup: float = 0.0, 106 | p: float = 2.): 107 | 108 | self.layer = layer 109 | self.round_loss = round_loss 110 | self.weight = weight 111 | self.rec_loss = rec_loss 112 | self.loss_start = max_count * warmup 113 | self.p = p 114 | 115 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start, 116 | start_b=b_range[0], end_b=b_range[1]) 117 | self.count = 0 118 | 119 | def __call__(self, pred, tgt, grad=None): 120 | """ 121 | Compute the total loss for adaptive rounding: 122 | rec_loss is the quadratic output reconstruction loss, round_loss is 123 | a regularization term to optimize the rounding policy 124 | 125 | :param pred: output from the quantized model 126 | :param tgt: output from the FP model 127 | :param grad: gradients used to compute the Fisher information 128 | :return: total loss 129 | """ 130 | self.count += 1 131 | if self.rec_loss == 'mse': 132 | rec_loss = lp_loss(pred, tgt, p=self.p) 133 | elif self.rec_loss == 'fisher_diag': 134 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() 135 | elif self.rec_loss == 'fisher_full': 136 | a = (pred - tgt).abs() 137 | grad = grad.abs() 138 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) 139 | rec_loss = (batch_dotprod * a * grad).mean() / 100 140 | else: 141 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss)) 142 | 143 | b = self.temp_decay(self.count) 144 | if self.count < self.loss_start or self.round_loss == 'none': 145 | b = round_loss = 0 146 | elif self.round_loss == 'relaxation': 147 | round_loss = 0 148 | round_vals = self.layer.weight_quantizer.get_soft_targets() 149 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() 150 | else: 151 | raise NotImplementedError 152 | 153 | total_loss = rec_loss + round_loss 154 | if self.count % 500 == 0: 155 | print('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format( 156 | float(total_loss), float(rec_loss), float(round_loss), b, self.count)) 157 | return total_loss 158 | 159 | -------------------------------------------------------------------------------- /PTQ/quant/quant_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from quant.quant_layer import QuantModule, UniformAffineQuantizer, StraightThrough 6 | from models.resnet import BasicBlock, Bottleneck 7 | from models.mobilenetv2 import InvertedResidual 8 | 9 | 10 | class BaseQuantBlock(nn.Module): 11 | """ 12 | Base implementation of block structures for all networks. 13 | Due to the branch architecture, we have to apply the activation function 14 | and quantization after the element-wise add operation; therefore, we 15 | put this part in this class.
16 | """ 17 | def __init__(self, act_quant_params: dict = {}): 18 | super().__init__() 19 | self.use_weight_quant = False 20 | self.use_act_quant = False 21 | # initialize quantizer 22 | 23 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 24 | self.activation_function = StraightThrough() 25 | 26 | self.ignore_reconstruction = False 27 | 28 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 29 | # setting weight quantization here does not affect actual forward pass 30 | self.use_weight_quant = weight_quant 31 | self.use_act_quant = act_quant 32 | for m in self.modules(): 33 | if isinstance(m, QuantModule): 34 | m.set_quant_state(weight_quant, act_quant) 35 | 36 | 37 | class QuantBasicBlock(BaseQuantBlock): 38 | """ 39 | Implementation of Quantized BasicBlock used in ResNet-18 and ResNet-34. 40 | """ 41 | def __init__(self, basic_block: BasicBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}): 42 | super().__init__(act_quant_params) 43 | self.conv1 = QuantModule(basic_block.conv1, weight_quant_params, act_quant_params) 44 | self.conv1.activation_function = basic_block.relu1 45 | self.conv2 = QuantModule(basic_block.conv2, weight_quant_params, act_quant_params, disable_act_quant=True) 46 | 47 | # modify the activation function to ReLU 48 | self.activation_function = basic_block.relu2 49 | 50 | if basic_block.downsample is None: 51 | self.downsample = None 52 | else: 53 | self.downsample = QuantModule(basic_block.downsample[0], weight_quant_params, act_quant_params, 54 | disable_act_quant=True) 55 | # copying all attributes in original block 56 | self.stride = basic_block.stride 57 | 58 | def forward(self, x): 59 | residual = x if self.downsample is None else self.downsample(x) 60 | out = self.conv1(x) 61 | out = self.conv2(out) 62 | out += residual 63 | out = self.activation_function(out) 64 | if self.use_act_quant: 65 | out = self.act_quantizer(out) 66 | return out 67 | 68 | 69 | class QuantBottleneck(BaseQuantBlock): 70 | """ 71 | Implementation of Quantized Bottleneck Block used in ResNet-50, -101 and -152. 72 | """ 73 | 74 | def __init__(self, bottleneck: Bottleneck, weight_quant_params: dict = {}, act_quant_params: dict = {}): 75 | super().__init__(act_quant_params) 76 | self.conv1 = QuantModule(bottleneck.conv1, weight_quant_params, act_quant_params) 77 | self.conv1.activation_function = bottleneck.relu1 78 | self.conv2 = QuantModule(bottleneck.conv2, weight_quant_params, act_quant_params) 79 | self.conv2.activation_function = bottleneck.relu2 80 | self.conv3 = QuantModule(bottleneck.conv3, weight_quant_params, act_quant_params, disable_act_quant=True) 81 | 82 | # modify the activation function to ReLU 83 | self.activation_function = bottleneck.relu3 84 | 85 | if bottleneck.downsample is None: 86 | self.downsample = None 87 | else: 88 | self.downsample = QuantModule(bottleneck.downsample[0], weight_quant_params, act_quant_params, 89 | disable_act_quant=True) 90 | # copying all attributes in original block 91 | self.stride = bottleneck.stride 92 | 93 | def forward(self, x): 94 | residual = x if self.downsample is None else self.downsample(x) 95 | out = self.conv1(x) 96 | out = self.conv2(out) 97 | out = self.conv3(out) 98 | out += residual 99 | out = self.activation_function(out) 100 | if self.use_act_quant: 101 | out = self.act_quantizer(out) 102 | return out 103 | 104 | 105 | class QuantInvertedResidual(BaseQuantBlock): 106 | """ 107 | Implementation of Quantized Inverted Residual Block used in MobileNetV2. 
108 | The inverted residual block has no activation function after the residual addition. 109 | """ 110 | 111 | def __init__(self, inv_res: InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}): 112 | super().__init__(act_quant_params) 113 | 114 | self.use_res_connect = inv_res.use_res_connect 115 | self.expand_ratio = inv_res.expand_ratio 116 | if self.expand_ratio == 1: 117 | self.conv = nn.Sequential( 118 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), 119 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params, disable_act_quant=True), 120 | ) 121 | self.conv[0].activation_function = nn.ReLU6() 122 | else: 123 | self.conv = nn.Sequential( 124 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), 125 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params), 126 | QuantModule(inv_res.conv[6], weight_quant_params, act_quant_params, disable_act_quant=True), 127 | ) 128 | self.conv[0].activation_function = nn.ReLU6() 129 | self.conv[1].activation_function = nn.ReLU6() 130 | 131 | def forward(self, x): 132 | if self.use_res_connect: 133 | out = x + self.conv(x) 134 | else: 135 | out = self.conv(x) 136 | out = self.activation_function(out) 137 | if self.use_act_quant: 138 | out = self.act_quantizer(out) 139 | return out 140 | 141 | 142 | specials = { 143 | BasicBlock: QuantBasicBlock, 144 | Bottleneck: QuantBottleneck, 145 | InvertedResidual: QuantInvertedResidual, 146 | } 147 | -------------------------------------------------------------------------------- /PTQ/quant/quant_layer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Union 6 | 7 | 8 | class StraightThrough(nn.Module): 9 | def __init__(self, channel_num: int = 1): 10 | super().__init__() 11 | 12 | def forward(self, input): 13 | return input 14 | 15 | 16 | def round_ste(x: torch.Tensor): 17 | """ 18 | Implement the Straight-Through Estimator for the rounding operation. 19 | """ 20 | return (x.round() - x).detach() + x 21 | 22 | 23 | def lp_loss(pred, tgt, p=2.0, reduction='none'): 24 | """ 25 | loss function measured in the L_p norm 26 | """ 27 | if reduction == 'none': 28 | return (pred-tgt).abs().pow(p).sum(1).mean() 29 | else: 30 | return (pred-tgt).abs().pow(p).mean() 31 | 32 | 33 | class UniformAffineQuantizer(nn.Module): 34 | """ 35 | PyTorch module that can be used for asymmetric quantization (also called uniform affine 36 | quantization). Quantizes its argument in the forward pass, passes the gradient 'straight 37 | through' on the backward pass, ignoring the quantization that occurred. 38 | Based on https://arxiv.org/abs/1806.08342.
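Concretely, given the step size delta and the integer zero_point z computed in
init_quantization_scale, the forward pass evaluates
x_int = clamp(round(x / delta) + z, 0, 2 ** n_bits - 1) and returns the dequantized
value (x_int - z) * delta; the gradient of round() is passed through unchanged via
round_ste.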
39 | 40 | :param n_bits: number of bits used for quantization 41 | :param symmetric: if True, the zero_point should always be 0 42 | :param channel_wise: if True, compute a scale and zero_point for each channel 43 | :param scale_method: determines the quantization scale and zero point 44 | """ 45 | def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False, scale_method: str = 'max', 46 | leaf_param: bool = False): 47 | super(UniformAffineQuantizer, self).__init__() 48 | self.sym = symmetric 49 | assert 2 <= n_bits <= 8, 'bitwidth not supported' 50 | self.n_bits = n_bits 51 | self.n_levels = 2 ** self.n_bits 52 | self.delta = None 53 | self.zero_point = None 54 | self.inited = False 55 | self.leaf_param = leaf_param 56 | self.channel_wise = channel_wise 57 | self.scale_method = scale_method 58 | 59 | def forward(self, x: torch.Tensor): 60 | 61 | if self.inited is False: 62 | if self.leaf_param: 63 | delta, self.zero_point = self.init_quantization_scale(x, self.channel_wise) 64 | self.delta = torch.nn.Parameter(delta) 65 | # self.zero_point = torch.nn.Parameter(self.zero_point) 66 | else: 67 | self.delta, self.zero_point = self.init_quantization_scale(x, self.channel_wise) 68 | self.inited = True 69 | 70 | # start quantization 71 | x_int = round_ste(x / self.delta) + self.zero_point 72 | x_quant = torch.clamp(x_int, 0, self.n_levels - 1) 73 | x_dequant = (x_quant - self.zero_point) * self.delta 74 | return x_dequant 75 | 76 | def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False): 77 | delta, zero_point = None, None 78 | if channel_wise: 79 | x_clone = x.clone().detach() 80 | n_channels = x_clone.shape[0] 81 | if len(x.shape) == 4: 82 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0] 83 | else: 84 | x_max = x_clone.abs().max(dim=-1)[0] 85 | delta = x_max.clone() 86 | zero_point = x_max.clone() 87 | # determine the scale and zero point channel-by-channel 88 | for c in range(n_channels): 89 | delta[c], zero_point[c] = self.init_quantization_scale(x_clone[c], channel_wise=False) 90 | if len(x.shape) == 4: 91 | delta = delta.view(-1, 1, 1, 1) 92 | zero_point = zero_point.view(-1, 1, 1, 1) 93 | else: 94 | delta = delta.view(-1, 1) 95 | zero_point = zero_point.view(-1, 1) 96 | else: 97 | if 'max' in self.scale_method: 98 | x_min = min(x.min().item(), 0) 99 | x_max = max(x.max().item(), 0) 100 | if 'scale' in self.scale_method: 101 | x_min = x_min * (self.n_bits + 2) / 8 102 | x_max = x_max * (self.n_bits + 2) / 8 103 | 104 | x_absmax = max(abs(x_min), x_max) 105 | if self.sym: 106 | x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax 107 | 108 | delta = float(x_max - x_min) / (self.n_levels - 1) 109 | if delta < 1e-8: 110 | warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max)) 111 | delta = 1e-8 112 | 113 | zero_point = round(-x_min / delta) 114 | delta = torch.tensor(delta).type_as(x) 115 | 116 | elif self.scale_method == 'mse': 117 | x_max = x.max() 118 | x_min = x.min() 119 | best_score = 1e+10 120 | for i in range(80): 121 | new_max = x_max * (1.0 - (i * 0.01)) 122 | new_min = x_min * (1.0 - (i * 0.01)) 123 | x_q = self.quantize(x, new_max, new_min) 124 | # L_p norm minimization as described in LAPQ 125 | # https://arxiv.org/abs/1911.07190 126 | score = lp_loss(x, x_q, p=2.4, reduction='all') 127 | if score < best_score: 128 | best_score = score 129 | delta = (new_max - new_min) / (2 ** self.n_bits - 1) 130 | zero_point = (- new_min / delta).round() 131 | else: 132 | raise NotImplementedError 133 |
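# Worked example for the 'max' branch above (illustrative numbers, not part of the
# original file): an 8-bit tensor with observed range [-1, 1] gives
# delta = (1 - (-1)) / (2 ** 8 - 1) = 2 / 255 ≈ 0.00784 and
# zero_point = round(1 / delta) = 128, so the real value 0.0 maps to the integer 128.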
134 | return delta, zero_point 135 | 136 | def quantize(self, x, max, min): 137 | delta = (max - min) / (2 ** self.n_bits - 1) 138 | zero_point = (- min / delta).round() 139 | # we assume weight quantization is always signed 140 | x_int = torch.round(x / delta) 141 | x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1) 142 | x_float_q = (x_quant - zero_point) * delta 143 | return x_float_q 144 | 145 | def bitwidth_refactor(self, refactored_bit: int): 146 | assert 2 <= refactored_bit <= 8, 'bitwidth not supported' 147 | self.n_bits = refactored_bit 148 | self.n_levels = 2 ** self.n_bits 149 | 150 | def extra_repr(self): 151 | s = 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' \ 152 | ' leaf_param={leaf_param}' 153 | return s.format(**self.__dict__) 154 | 155 | 156 | class QuantModule(nn.Module): 157 | """ 158 | Quantized Module that can perform quantized convolution or normal convolution. 159 | To activate quantization, please use set_quant_state function. 160 | """ 161 | def __init__(self, org_module: Union[nn.Conv2d, nn.Linear], weight_quant_params: dict = {}, 162 | act_quant_params: dict = {}, disable_act_quant: bool = False, se_module=None): 163 | super(QuantModule, self).__init__() 164 | if isinstance(org_module, nn.Conv2d): 165 | self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding, 166 | dilation=org_module.dilation, groups=org_module.groups) 167 | self.fwd_func = F.conv2d 168 | else: 169 | self.fwd_kwargs = dict() 170 | self.fwd_func = F.linear 171 | self.weight = org_module.weight 172 | self.org_weight = org_module.weight.data.clone() 173 | if org_module.bias is not None: 174 | self.bias = org_module.bias 175 | self.org_bias = org_module.bias.data.clone() 176 | else: 177 | self.bias = None 178 | self.org_bias = None 179 | # de-activate the quantized forward default 180 | self.use_weight_quant = False 181 | self.use_act_quant = False 182 | self.use_bias_corr = False 183 | self.vcorr_weight = False 184 | self.bcorr_weight = False 185 | self.disable_act_quant = disable_act_quant 186 | # initialize quantizer 187 | self.weight_quantizer = UniformAffineQuantizer(**weight_quant_params) 188 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 189 | 190 | self.activation_function = StraightThrough() 191 | self.ignore_reconstruction = False 192 | 193 | self.se_module = se_module 194 | self.extra_repr = org_module.extra_repr 195 | 196 | def forward(self, input: torch.Tensor): 197 | if self.use_weight_quant: 198 | weight = self.weight_quantizer(self.weight) 199 | bias = self.bias 200 | else: 201 | weight = self.org_weight 202 | bias = self.org_bias 203 | 204 | # if self.use_bias_corr and self.use_weight_quant: 205 | # bias_q = weight.view(weight.shape[0], -1).mean(-1) 206 | # bias_q = bias_q.view(bias_q.numel(), 1, 1, 1) if len(weight.shape) == 4 else bias_q.view(bias_q.numel(), 207 | # 1) 208 | # bias_orig = self.org_weight.view(self.org_weight.shape[0], -1).mean(-1) 209 | # bias_orig = bias_orig.view(bias_orig.numel(), 1, 1, 1) if len(self.org_weight.shape) == 4 else bias_orig.view( 210 | # bias_orig.numel(), 1) 211 | # 212 | # if self.vcorr_weight: 213 | # eps = torch.tensor([1e-8]).to(weight.device) 214 | # var_corr = self.org_weight.view(self.org_weight.shape[0], -1).std(dim=-1) / \ 215 | # (weight.view(weight.shape[0], -1).std(dim=-1) + eps) 216 | # var_corr = (var_corr.view(var_corr.numel(), 1, 1, 1) if len(weight.shape) == 4 else var_corr.view( 217 | # var_corr.numel(), 1)) 218 | # 219 | # # 
Correct variance 220 | # weight = (weight - bias_q) * var_corr + bias_q 221 | # 222 | # if self.bcorr_weight: 223 | # # Correct mean 224 | # weight = weight - bias_q + bias_orig 225 | 226 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs) 227 | # disabling act quantization is designed for a convolution preceding an element-wise operation; 228 | # in that case, we apply the activation function and quantization after the element-wise op. 229 | if self.se_module is not None: 230 | out = self.se_module(out) 231 | out = self.activation_function(out) 232 | if self.disable_act_quant: 233 | return out 234 | if self.use_act_quant: 235 | out = self.act_quantizer(out) 236 | return out 237 | 238 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 239 | self.use_weight_quant = weight_quant 240 | self.use_act_quant = act_quant 241 | 242 | def set_bias_state(self, use_bias_corr: bool = False, vcorr_weight: bool = False, bcorr_weight: bool = False): 243 | self.use_bias_corr = use_bias_corr 244 | self.vcorr_weight = vcorr_weight 245 | self.bcorr_weight = bcorr_weight 246 | -------------------------------------------------------------------------------- /PTQ/quant/quant_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from quant.quant_block import specials, BaseQuantBlock 3 | from quant.quant_layer import QuantModule, StraightThrough 4 | from quant.fold_bn import search_fold_and_remove_bn 5 | import sys 6 | 7 | 8 | class QuantModel(nn.Module): 9 | 10 | def __init__(self, model: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}): 11 | super().__init__() 12 | search_fold_and_remove_bn(model) 13 | self.model = model 14 | self.quant_module_refactor(self.model, weight_quant_params, act_quant_params) 15 | 16 | def quant_module_refactor(self, module: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}): 17 | """ 18 | Recursively replace nn.Conv2d and nn.Linear layers with QuantModule 19 | :param module: nn.Module with nn.Conv2d or nn.Linear in its children 20 | :param weight_quant_params: quantization parameters like n_bits for the weight quantizer 21 | :param act_quant_params: quantization parameters like n_bits for the activation quantizer 22 | """ 23 | prev_quantmodule = None 24 | for name, child_module in module.named_children(): 25 | if type(child_module) in specials: 26 | setattr(module, name, specials[type(child_module)](child_module, weight_quant_params, act_quant_params)) 27 | 28 | elif isinstance(child_module, (nn.Conv2d, nn.Linear)): 29 | setattr(module, name, QuantModule(child_module, weight_quant_params, act_quant_params)) 30 | prev_quantmodule = getattr(module, name) 31 | 32 | elif isinstance(child_module, (nn.ReLU, nn.ReLU6)): 33 | if prev_quantmodule is not None: 34 | prev_quantmodule.activation_function = child_module 35 | setattr(module, name, StraightThrough()) 36 | else: 37 | continue 38 | 39 | elif isinstance(child_module, StraightThrough): 40 | continue 41 | 42 | else: 43 | self.quant_module_refactor(child_module, weight_quant_params, act_quant_params) 44 | 45 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 46 | for m in self.model.modules(): 47 | if isinstance(m, (QuantModule, BaseQuantBlock)): 48 | m.set_quant_state(weight_quant, act_quant) 49 | 50 | def set_bias_state(self, use_bias_corr: bool = False, vcorr_weight: bool = False, bcorr_weight: bool = False): # newly added 51 | for m in self.model.modules(): 52 | if isinstance(m,
QuantModule): 53 | m.set_bias_state(use_bias_corr, vcorr_weight, bcorr_weight) 54 | 55 | def forward(self, input): 56 | return self.model(input) 57 | 58 | def set_first_last_layer_to_8bit(self): 59 | module_list = [] 60 | for m in self.model.modules(): 61 | if isinstance(m, QuantModule): 62 | module_list += [m] 63 | 64 | module_list[0].weight_quantizer.bitwidth_refactor(8) 65 | module_list[0].act_quantizer.bitwidth_refactor(8) 66 | module_list[-1].weight_quantizer.bitwidth_refactor(8) 67 | module_list[-2].act_quantizer.bitwidth_refactor(8) 68 | # ignore reconstruction of the first layer 69 | module_list[0].ignore_reconstruction = True 70 | 71 | def disable_network_output_quantization(self): 72 | module_list = [] 73 | for m in self.model.modules(): 74 | if isinstance(m, QuantModule): 75 | module_list += [m] 76 | module_list[-1].disable_act_quant = True 77 | 78 | def set_mixed_precision(self, bit_cfg): 79 | bit_cfgs = [] 80 | if len(bit_cfg) == 8: 81 | for i in range(len(bit_cfg)): 82 | for j in range(2): 83 | bit_cfgs.append(bit_cfg[i]) 84 | if i in [2, 4, 6]: 85 | bit_cfgs.append(bit_cfg[i]) 86 | elif len(bit_cfg) == 20: 87 | for i in range(len(bit_cfg)): 88 | if i == 0 or i == 18 or i == 19: 89 | bit_cfgs.append(bit_cfg[i]) 90 | elif i == 1: 91 | for j in range(2): 92 | bit_cfgs.append(bit_cfg[i]) 93 | else: 94 | for j in range(3): 95 | bit_cfgs.append(bit_cfg[i]) 96 | elif len(bit_cfg) == 53: 97 | bit_cfgs = bit_cfg 98 | else: 99 | for i in range(len(bit_cfg)): 100 | bit_cfgs.append(bit_cfg[i]) 101 | if i in [6, 10, 14]: 102 | bit_cfgs.append(bit_cfg[i]) 103 | module_list = [] 104 | for m in self.model.modules(): 105 | if isinstance(m, QuantModule): 106 | module_list += [m] 107 | # print(module_list) 108 | # print(bit_cfgs) 109 | # sys.exit(1) 110 | 111 | 112 | # for i in range(len(module_list)-2): 113 | # module_list[i+1].weight_quantizer.bitwidth_refactor(bit_cfgs[i]) 114 | for i in range(len(bit_cfgs)): 115 | module_list[i].weight_quantizer.bitwidth_refactor(bit_cfgs[i]) 116 | # module_list[i].act_quantizer.bitwidth_refactor(4) 117 | # for m in self.model.modules(): 118 | # if isinstance(m, (QuantModule, BaseQuantBlock)): 119 | # print(m) 120 | # sys.exit(1) 121 | 122 | -------------------------------------------------------------------------------- /PTQ/run_scripts/train_mobilenetv2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## mobilenetv2 1.5Mb 4 | #python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3]" 5 | 6 | # mobilenetv2 1.3Mb 7 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 2]" 8 | 9 | ## mobilenetv2 1.1Mb 10 | #python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 3, 3, 4, 3, 3, 4, 2, 
2, 2]" 11 | # 12 | ## mobilenetv2 0.9Mb 13 | #python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 3, 3, 4, 3, 3, 4, 3, 3, 4, 3, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 2]" 14 | 15 | #"""CDP""" 16 | ## mobilenetv2 1.5Mb cdp 17 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3]" --gpu '0' 18 | 19 | # mobilenetv2 1.3Mb cdp 20 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2]" --gpu '1' 21 | 22 | ## mobilenetv2 1.1Mb cdp 23 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 4, 2, 4, 4, 2, 4, 2, 2, 2]" --gpu '3' 24 | # 25 | ## mobilenetv2 0.9Mb cdp 26 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 4, 2, 4, 4, 2, 4, 4, 2, 4, 4, 2, 4, 2, 2, 4, 3, 2, 4, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2]" --gpu '3' 27 | 28 | #"""CDP加上beta后""" 29 | ## mobilenetv2 1.5Mb cdp 修改前 前后不变 30 | python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3]" --gpu '0' 31 | 32 | # mobilenetv2 1.3Mb cdp 修改前 33 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2]" --gpu '1' 34 | 35 | ## mobilenetv2 1.1Mb cdp 修改前 36 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 4, 2, 3, 4, 2, 2, 2]" --gpu '1' 37 | # 38 | ## mobilenetv2 0.9Mb cdp 修改前 39 | # python main_imagenet.py --data_path /home/data/imagenet/ --arch mobilenetv2 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --weight 0.1 --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 4, 3, 4, 4, 3, 4, 4, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 4, 
2, 2, 4, 2, 2, 2]" --gpu '3' 40 | 41 | -------------------------------------------------------------------------------- /PTQ/run_scripts/train_resnet18.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## resnet-18 3.0Mb 4 | #python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2]" 5 | # 6 | ## resnet-18 3.5Mb 7 | #python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 3, 2, 2, 2, 3]" 8 | # 9 | ## resnet-18 4.0Mb 10 | #python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 4, 3, 2, 3, 3, 3]" 11 | 12 | # resnet-18 4.5Mb 13 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 4, 3, 3, 3, 3]" 14 | # 15 | ## resnet-18 5.0Mb 16 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 3, 3, 3, 4, 4, 4, 3, 4, 3, 4, 3, 4, 3, 3, 4, 4, 4]" 17 | # 18 | ## resnet-18 5.5Mb 19 | #python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 4]" 20 | 21 | # CDP 22 | 23 | ## resnet-18 3.0Mb cdp 24 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 3, 4, 4, 3, 4, 2, 2, 2, 2, 2, 2, 2, 2]" --gpu '0' 25 | # 26 | ## resnet-18 3.5Mb cdp 27 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 4]" --gpu '0' 28 | # 29 | ## resnet-18 4.0Mb cdp 30 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 2, 2, 3, 3]" --gpu '1' 31 | # 32 | ## resnet-18 4.5Mb cdp 33 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 3, 4]" --gpu '1' 34 | # 35 | ## resnet-18 5.0Mb cdp 36 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 4]" --gpu '3' 37 | # 38 | ### resnet-18 5.5Mb cdp 39 | python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4]" --gpu '3' 40 | 41 | # CDP修改后 beta 3.3 42 | 43 | ## resnet-18 3.0Mb cdp 44 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 
--n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 2, 2, 2, 2, 2, 2, 2]" --gpu '0' 45 | # 46 | ## resnet-18 3.5Mb cdp 47 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4, 2, 2, 2, 2, 4]" --gpu '0' 48 | # 49 | ## resnet-18 4.0Mb cdp 50 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 4]" --gpu '1' 51 | # 52 | ## resnet-18 4.5Mb cdp 53 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 3, 4]" --gpu '1' 54 | # 55 | ## resnet-18 5.0Mb cdp 56 | # python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 4]" --gpu '3' 57 | # 58 | ### resnet-18 5.5Mb cdp 59 | python main_imagenet.py --data_path /home/data/imagenet --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4]" --gpu '3' -------------------------------------------------------------------------------- /QAT/run_scripts/train_resnet18.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python quant_train.py \ 3 | -a resnet18 \ 4 | --epochs 90 \ 5 | --lr 0.0001 \ 6 | --batch_size 128 \ 7 | --data /home/data/imagenet/ \ 8 | --save_path "saved_quant_model/train_resnet18_${RANDOM}" \ 9 | --act_range_momentum=0.99 \ 10 | --wd 1e-4 \ 11 | --data_percentage 1 \ 12 | --pretrained \ 13 | --fix_BN \ 14 | --checkpoint_iter -1 \ 15 | --gpu_id '0,1' \ 16 | --quant_scheme modelsize_6.7_a6_75B -------------------------------------------------------------------------------- /QAT/run_scripts/train_resnet18_a8_cdp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python quant_train.py \ 3 | -a resnet18 \ 4 | --epochs 90 \ 5 | --lr 0.0001 \ 6 | --batch_size 1024 \ 7 | --data /home/data/imagenet/ \ 8 | --save_path "saved_quant_model/train_resnet18_${RANDOM}/" \ 9 | --act_range_momentum=0.99 \ 10 | --wd 1e-4 \ 11 | --data_percentage 1 \ 12 | --pretrained \ 13 | --fix_BN \ 14 | --checkpoint_iter -1 \ 15 | --gpu_id '1' \ 16 | --quant_scheme cdp_modelsize_6.7_a8_84B -------------------------------------------------------------------------------- /QAT/run_scripts/train_resnet18_cdp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python quant_train.py \ 3 | -a resnet18 \ 4 | --epochs 90 \ 5 | --lr 0.0001 \ 6 | --batch_size 1024 \ 7 | --data /home/data/imagenet/ \ 8 | --save_path "saved_quant_model/train_resnet18_${RANDOM}/" \ 9 | --act_range_momentum=0.99 \ 10 | --wd 1e-4 \ 11 | --data_percentage 1 \ 12 | --pretrained \ 13 | --fix_BN \ 14 | --checkpoint_iter -1 \ 15 | --gpu_id '0' \ 16 | --quant_scheme cdp_modelsize_6.7_a6_63B -------------------------------------------------------------------------------- /QAT/run_scripts/train_resnet50.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python quant_train.py \ 3 | -a resnet50 \ 4 | --epochs 90 \ 5 | --lr 0.0001 \ 6 | --batch_size 128 \ 7 | --data /home/data/imagenet/ \ 8 | --save_path "saved_quant_model/train_resnet50_${RANDOM}" \ 9 | --act_range_momentum=0.99 \ 10 | --wd 1e-4 \ 11 | --data_percentage 1 \ 12 | --pretrained \ 13 | --fix_BN \ 14 | --checkpoint_iter -1 \ 15 | --gpu_id '0,1' \ 16 | --quant_scheme modelsize_16.0_a5_141BOP 17 | -------------------------------------------------------------------------------- /QAT/run_scripts/train_resnet50_cdp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python quant_train.py \ 3 | -a resnet50 \ 4 | --epochs 90 \ 5 | --lr 0.0001 \ 6 | --batch_size 1024 \ 7 | --data /home/data/imagenet/ \ 8 | --save_path "saved_quant_model/train_resnet50_${RANDOM}" \ 9 | --act_range_momentum=0.99 \ 10 | --wd 1e-4 \ 11 | --data_percentage 1 \ 12 | --pretrained \ 13 | --fix_BN \ 14 | --checkpoint_iter -1 \ 15 | --gpu_id '0,1' \ 16 | --quant_scheme cdp_modelsize_15.9_a5_143BOP -------------------------------------------------------------------------------- /QAT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import * 2 | from .models.q_mobilenetv2 import * 3 | from .models.q_inceptionv3 import * 4 | from .models.q_resnet import * -------------------------------------------------------------------------------- /QAT/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from torchvision import datasets, transforms 3 | import torch 4 | import os 5 | 6 | class UniformDataset(Dataset): 7 | """ 8 | get random uniform samples with mean 0 and variance 1 9 | """ 10 | def __init__(self, length, size, transform): 11 | self.length = length 12 | self.transform = transform 13 | self.size = size 14 | 15 | def __len__(self): 16 | return self.length 17 | 18 | def __getitem__(self, idx): 19 | # var[U(-128, 127)] = (127 - (-128))**2 / 12 = 5418.75, so divide by the std to get unit variance 20 | sample = (torch.randint(high=255, size=self.size).float() - 127.5) / 5418.75 ** 0.5 21 | return sample 22 | 23 | 24 | def getRandomData(dataset='imagenet', batch_size=512, for_inception=False): 25 | """ 26 | get random sample dataloader 27 | dataset: name of the dataset 28 | batch_size: the batch size of random data 29 | for_inception: whether the data is for Inception, because Inception has input size 299 rather than 224 30 | """ 31 | if dataset == 'cifar10': 32 | size = (3, 32, 32) 33 | num_data = 10000 34 | elif dataset == 'imagenet': 35 | num_data = 10000 36 | if not for_inception: 37 | size = (3, 224, 224) 38 | else: 39 | size = (3, 299, 299) 40 | else: 41 | raise NotImplementedError 42 | dataset = UniformDataset(length=num_data, size=size, transform=None) 43 | data_loader = DataLoader(dataset, 44 | batch_size=batch_size, 45 | shuffle=False, 46 | num_workers=32) 47 | return data_loader 48 | 49 | 50 | def getTestData(dataset='imagenet', 51 | batch_size=1024, 52 | path='data/imagenet', 53 | for_inception=False): 54 | """ 55 | Get dataloader of the test set 56 | dataset: name of the dataset 57 | batch_size: the batch size of random data 58 | path: the path to the data 59 | for_inception: whether the data is for Inception, because Inception has input size 299 rather than 224 60 | """ 61 | if dataset == 'imagenet': 62 | input_size = 299 if
for_inception else 224 63 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 64 | std=[0.229, 0.224, 0.225]) 65 | 66 | test_dataset = datasets.ImageFolder( 67 | os.path.join(path, 'val'), 68 | transforms.Compose([ 69 | transforms.Resize(int(input_size / 0.875)), 70 | transforms.CenterCrop(input_size), 71 | transforms.ToTensor(), 72 | normalize, 73 | ])) 74 | 75 | test_loader = DataLoader(test_dataset, 76 | batch_size=batch_size, 77 | shuffle=False, 78 | num_workers=32) 79 | 80 | return test_loader 81 | elif dataset == 'cifar10': 82 | data_dir = '/rscratch/yaohuic/data/' 83 | normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 84 | std=(0.2023, 0.1994, 0.2010)) 85 | transform_test = transforms.Compose([transforms.ToTensor(), normalize]) 86 | 87 | test_dataset = datasets.CIFAR10(root=data_dir, 88 | train=False, 89 | transform=transform_test) 90 | test_loader = DataLoader(test_dataset, 91 | batch_size=batch_size, 92 | shuffle=False, 93 | num_workers=32) 94 | return test_loader 95 | 96 | 97 | def getTrainData(dataset='imagenet', 98 | batch_size=512, 99 | path='data/imagenet', 100 | for_inception=False, 101 | data_percentage=0.1): 102 | """ 103 | Get dataloader of the training set 104 | dataset: name of the dataset 105 | batch_size: the batch size of random data 106 | path: the path to the data 107 | for_inception: whether the data is for Inception, because Inception has input size 299 rather than 224 108 | """ 109 | if dataset == 'imagenet': 110 | input_size = 299 if for_inception else 224 111 | traindir = os.path.join(path, 'train') 112 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 113 | std=[0.229, 0.224, 0.225]) 114 | 115 | train_dataset = datasets.ImageFolder( 116 | traindir, 117 | transforms.Compose([ 118 | transforms.RandomResizedCrop(input_size), 119 | transforms.RandomHorizontalFlip(), 120 | transforms.ToTensor(), 121 | normalize, 122 | ])) 123 | 124 | dataset_length = int(len(train_dataset) * data_percentage) 125 | partial_train_dataset, _ = torch.utils.data.random_split(train_dataset, [dataset_length, len(train_dataset)-dataset_length]) 126 | 127 | train_loader = torch.utils.data.DataLoader( 128 | partial_train_dataset, batch_size=batch_size, shuffle=True, 129 | num_workers=32, pin_memory=True) 130 | 131 | return train_loader 132 | -------------------------------------------------------------------------------- /QAT/utils/models/q_mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quantized MobileNetV2 for ImageNet-1K, implemented in PyTorch. 3 | Original paper: 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381. 4 | """ 5 | 6 | import os 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | from ..quantization_utils.quant_modules import * 10 | 11 | 12 | class Q_LinearBottleneck(nn.Module): 13 | def __init__(self, 14 | model, 15 | in_channels, 16 | out_channels, 17 | stride, 18 | expansion, 19 | remove_exp_conv): 20 | """ 21 | So-called 'Linear Bottleneck' layer. It is used as a quantized MobileNetV2 unit. 22 | Parameters: 23 | ---------- 24 | model : nn.Module 25 | The pretrained floating-point counterpart of this module with the same structure. 26 | in_channels : int 27 | Number of input channels. 28 | out_channels : int 29 | Number of output channels. 30 | stride : int or tuple/list of 2 int 31 | Strides of the second convolution layer. 32 | expansion : bool 33 | Whether to do expansion of channels. 34 | remove_exp_conv : bool 35 | Whether to remove expansion convolution.
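Note on the data flow in forward(): every QuantBnConv2d returns the scaling factor
of its folded weights, and the following QuantAct consumes both the incoming
activation scaling factor and that weight scaling factor so the activation can be
requantized consistently; quant_act_int32 additionally receives the identity branch
and its scaling factor when the residual connection is used.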
36 | """ 37 | super(Q_LinearBottleneck, self).__init__() 38 | self.residual = (in_channels == out_channels) and (stride == 1) 39 | mid_channels = in_channels * 6 if expansion else in_channels 40 | self.use_exp_conv = (expansion or (not remove_exp_conv)) 41 | self.activatition_func = nn.ReLU6() 42 | 43 | self.quant_act = QuantAct() 44 | 45 | if self.use_exp_conv: 46 | self.conv1 = QuantBnConv2d() 47 | self.conv1.set_param(model.conv1.conv, model.conv1.bn) 48 | self.quant_act1 = QuantAct() 49 | 50 | self.conv2 = QuantBnConv2d() 51 | self.conv2.set_param(model.conv2.conv, model.conv2.bn) 52 | self.quant_act2 = QuantAct() 53 | 54 | self.conv3 = QuantBnConv2d() 55 | self.conv3.set_param(model.conv3.conv, model.conv3.bn) 56 | 57 | self.quant_act_int32 = QuantAct() 58 | 59 | def forward(self, x, scaling_factor_int32=None): 60 | if self.residual: 61 | identity = x 62 | 63 | x, act_scaling_factor = self.quant_act(x, scaling_factor_int32, None, None, None, None) 64 | 65 | if self.use_exp_conv: 66 | x, weight_scaling_factor = self.conv1(x, act_scaling_factor) 67 | x = self.activatition_func(x) 68 | x, self.act_scaling_factor = self.quant_act1(x, act_scaling_factor, weight_scaling_factor, None, None) 69 | 70 | x, weight_scaling_factor = self.conv2(x, act_scaling_factor) 71 | x = self.activatition_func(x) 72 | x, act_scaling_factor = self.quant_act2(x, act_scaling_factor, weight_scaling_factor, None, None) 73 | 74 | # note that, there is no activation for the last conv 75 | x, weight_scaling_factor = self.conv3(x, act_scaling_factor) 76 | else: 77 | x, weight_scaling_factor = self.conv2(x, act_scaling_factor) 78 | x = self.activatition_func(x) 79 | x, act_scaling_factor = self.quant_act2(x, act_scaling_factor, weight_scaling_factor, None, None) 80 | 81 | # note that, there is no activation for the last conv 82 | x, weight_scaling_factor = self.conv3(x, act_scaling_factor) 83 | 84 | if self.residual: 85 | x = x + identity 86 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, identity, scaling_factor_int32, None) 87 | else: 88 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, None, None, None) 89 | 90 | return x, act_scaling_factor 91 | 92 | 93 | class Q_MobileNetV2(nn.Module): 94 | """ 95 | Quantized MobileNetV2 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381. 96 | Parameters: 97 | ---------- 98 | model : nn.Module 99 | The pretrained floating-point MobileNetV2. 100 | channels : list of list of int 101 | Number of output channels for each unit. 102 | init_block_channels : int 103 | Number of output channels for the initial unit. 104 | final_block_channels : int 105 | Number of output channels for the final block of the feature extractor. 106 | remove_exp_conv : bool 107 | Whether to remove expansion convolution. 108 | in_channels : int, default 3 109 | Number of input channels. 110 | in_size : tuple of two ints, default (224, 224) 111 | Spatial size of the expected input image. 112 | num_classes : int, default 1000 113 | Number of classification classes. 
114 | """ 115 | def __init__(self, 116 | model, 117 | channels, 118 | init_block_channels, 119 | final_block_channels, 120 | remove_exp_conv, 121 | in_channels=3, 122 | in_size=(224, 224), 123 | num_classes=1000): 124 | super(Q_MobileNetV2, self).__init__() 125 | self.in_size = in_size 126 | self.num_classes = num_classes 127 | self.channels = channels 128 | self.activatition_func = nn.ReLU6() 129 | 130 | # add input quantization 131 | self.quant_input = QuantAct() 132 | 133 | # change the inital block 134 | self.add_module("init_block", QuantBnConv2d()) 135 | 136 | self.init_block.set_param(model.features.init_block.conv, model.features.init_block.bn) 137 | 138 | self.quant_act_int32 = QuantAct() 139 | 140 | self.features = nn.Sequential() 141 | # change the middle blocks 142 | in_channels = init_block_channels 143 | for i, channels_per_stage in enumerate(channels): 144 | stage = nn.Sequential() 145 | cur_stage = getattr(model.features, f'stage{i+1}') 146 | for j, out_channels in enumerate(channels_per_stage): 147 | cur_unit = getattr(cur_stage, f'unit{j+1}') 148 | 149 | stride = 2 if (j == 0) and (i != 0) else 1 150 | expansion = (i != 0) or (j != 0) 151 | 152 | stage.add_module("unit{}".format(j + 1), Q_LinearBottleneck( 153 | cur_unit, 154 | in_channels=in_channels, 155 | out_channels=out_channels, 156 | stride=stride, 157 | expansion=expansion, 158 | remove_exp_conv=remove_exp_conv, 159 | )) 160 | 161 | in_channels = out_channels 162 | self.features.add_module("stage{}".format(i + 1), stage) 163 | 164 | # change the final block 165 | self.quant_act_before_final_block = QuantAct() 166 | self.features.add_module("final_block", QuantBnConv2d()) 167 | 168 | self.features.final_block.set_param(model.features.final_block.conv, model.features.final_block.bn) 169 | self.quant_act_int32_final = QuantAct() 170 | 171 | in_channels = final_block_channels 172 | 173 | self.features.add_module("final_pool", QuantAveragePool2d()) 174 | self.features.final_pool.set_param(model.features.final_pool) 175 | self.quant_act_output = QuantAct() 176 | 177 | self.output = QuantConv2d() 178 | self.output.set_param(model.output) 179 | 180 | def forward(self, x): 181 | # quantize input 182 | x, act_scaling_factor = self.quant_input(x) 183 | 184 | # the init block 185 | x, weight_scaling_factor = self.init_block(x, act_scaling_factor) 186 | x = self.activatition_func(x) 187 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, None, None) 188 | 189 | # the feature block 190 | for i, channels_per_stage in enumerate(self.channels): 191 | cur_stage = getattr(self.features, f'stage{i+1}') 192 | for j, out_channels in enumerate(channels_per_stage): 193 | cur_unit = getattr(cur_stage, f'unit{j+1}') 194 | 195 | x, act_scaling_factor = cur_unit(x, act_scaling_factor) 196 | x, act_scaling_factor = self.quant_act_before_final_block(x, act_scaling_factor, None, None, None, None) 197 | x, weight_scaling_factor = self.features.final_block(x, act_scaling_factor) 198 | x = self.activatition_func(x) 199 | x, act_scaling_factor = self.quant_act_int32_final(x, act_scaling_factor, weight_scaling_factor, None, None, None) 200 | 201 | # the final pooling 202 | x = self.features.final_pool(x, act_scaling_factor) 203 | 204 | # the output 205 | x, act_scaling_factor = self.quant_act_output(x, act_scaling_factor, None, None, None, None) 206 | x, act_scaling_factor = self.output(x, act_scaling_factor) 207 | 208 | x = x.view(x.size(0), -1) 209 | 210 | return x 211 | 212 | 213 | def 
q_get_mobilenetv2(model, width_scale, remove_exp_conv=False): 214 | """ 215 | Create a quantized MobileNetV2 model with specific parameters. 216 | Parameters: 217 | ---------- 218 | model : nn.Module 219 | The pretrained floating-point MobileNetV2. 220 | width_scale : float 221 | Scale factor for width of layers. 222 | remove_exp_conv : bool, default False 223 | Whether to remove expansion convolution. 224 | """ 225 | 226 | init_block_channels = 32 227 | final_block_channels = 1280 228 | layers = [1, 2, 3, 4, 3, 3, 1] 229 | downsample = [0, 1, 1, 1, 0, 1, 0] 230 | channels_per_layers = [16, 24, 32, 64, 96, 160, 320] 231 | 232 | from functools import reduce 233 | channels = reduce( 234 | lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], 235 | zip(channels_per_layers, layers, downsample), 236 | [[]]) 237 | 238 | if width_scale != 1.0: 239 | channels = [[int(cij * width_scale) for cij in ci] for ci in channels] 240 | init_block_channels = int(init_block_channels * width_scale) 241 | if width_scale > 1.0: 242 | final_block_channels = int(final_block_channels * width_scale) 243 | 244 | net = Q_MobileNetV2( 245 | model, 246 | channels=channels, 247 | init_block_channels=init_block_channels, 248 | final_block_channels=final_block_channels, 249 | remove_exp_conv=remove_exp_conv) 250 | 251 | return net 252 | 253 | 254 | def q_mobilenetv2_w1(model): 255 | """ 256 | Quantized 1.0 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 257 | https://arxiv.org/abs/1801.04381. 258 | Parameters: 259 | model : nn.Module 260 | The pretrained floating-point MobileNetV2. 261 | """ 262 | return q_get_mobilenetv2(model, width_scale=1.0) 263 | -------------------------------------------------------------------------------- /QAT/utils/models/q_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quantized ResNet for ImageNet-1K, implemented in PyTorch. 3 | Original paper: 'Deep Residual Learning for Image Recognition,' https://arxiv.org/abs/1512.03385. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import copy 9 | from ..quantization_utils.quant_modules import * 10 | from pytorchcv.models.common import ConvBlock 11 | from pytorchcv.models.shufflenetv2 import ShuffleUnit, ShuffleInitBlock 12 | import time 13 | import logging 14 | 15 | 16 | class Q_ResNet18(nn.Module): 17 | """ 18 | Quantized ResNet18 model from 'Deep Residual Learning for Image Recognition,' https://arxiv.org/abs/1512.03385.
19 | """ 20 | def __init__(self, model): 21 | super().__init__() 22 | features = getattr(model, 'features') 23 | init_block = getattr(features, 'init_block') 24 | 25 | self.quant_input = QuantAct() 26 | 27 | self.quant_init_block_convbn = QuantBnConv2d() 28 | self.quant_init_block_convbn.set_param(init_block.conv.conv, init_block.conv.bn) 29 | 30 | self.quant_act_int32 = QuantAct() 31 | 32 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 33 | self.act = nn.ReLU() 34 | 35 | self.channel = [2, 2, 2, 2] 36 | 37 | for stage_num in range(0, 4): 38 | stage = getattr(features, "stage{}".format(stage_num + 1)) 39 | for unit_num in range(0, self.channel[stage_num]): 40 | unit = getattr(stage, "unit{}".format(unit_num + 1)) 41 | quant_unit = Q_ResBlockBn() 42 | quant_unit.set_param(unit) 43 | setattr(self, f"stage{stage_num + 1}.unit{unit_num + 1}", quant_unit) 44 | 45 | self.final_pool = QuantAveragePool2d(kernel_size=7, stride=1) 46 | 47 | self.quant_act_output = QuantAct() 48 | 49 | output = getattr(model, 'output') 50 | self.quant_output = QuantLinear() 51 | self.quant_output.set_param(output) 52 | 53 | def forward(self, x): 54 | x, act_scaling_factor = self.quant_input(x) 55 | 56 | x, weight_scaling_factor = self.quant_init_block_convbn(x, act_scaling_factor) 57 | 58 | x = self.pool(x) 59 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor) 60 | 61 | x = self.act(x) 62 | 63 | for stage_num in range(0, 4): 64 | for unit_num in range(0, self.channel[stage_num]): 65 | tmp_func = getattr(self, f"stage{stage_num+1}.unit{unit_num+1}") 66 | x, act_scaling_factor = tmp_func(x, act_scaling_factor) 67 | 68 | x = self.final_pool(x, act_scaling_factor) 69 | 70 | x, act_scaling_factor = self.quant_act_output(x, act_scaling_factor) 71 | x = x.view(x.size(0), -1) 72 | x = self.quant_output(x, act_scaling_factor) 73 | 74 | return x 75 | 76 | 77 | class Q_ResNet50(nn.Module): 78 | """ 79 | Quantized ResNet50 model from 'Deep Residual Learning for Image Recognition,' https://arxiv.org/abs/1512.03385. 
80 | """ 81 | def __init__(self, model): 82 | super().__init__() 83 | 84 | features = getattr(model, 'features') 85 | 86 | init_block = getattr(features, 'init_block') 87 | self.quant_input = QuantAct() 88 | self.quant_init_convbn = QuantBnConv2d() 89 | self.quant_init_convbn.set_param(init_block.conv.conv, init_block.conv.bn) 90 | 91 | self.quant_act_int32 = QuantAct() 92 | 93 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 94 | self.act = nn.ReLU() 95 | 96 | self.channel = [3, 4, 6, 3] 97 | 98 | for stage_num in range(0, 4): 99 | stage = getattr(features, "stage{}".format(stage_num + 1)) 100 | for unit_num in range(0, self.channel[stage_num]): 101 | unit = getattr(stage, "unit{}".format(unit_num + 1)) 102 | quant_unit = Q_ResUnitBn() 103 | quant_unit.set_param(unit) 104 | setattr(self, f"stage{stage_num + 1}.unit{unit_num + 1}", quant_unit) 105 | 106 | self.final_pool = QuantAveragePool2d(kernel_size=7, stride=1) 107 | 108 | self.quant_act_output = QuantAct() 109 | 110 | output = getattr(model, 'output') 111 | self.quant_output = QuantLinear() 112 | self.quant_output.set_param(output) 113 | 114 | def forward(self, x): 115 | x, act_scaling_factor = self.quant_input(x) 116 | 117 | x, weight_scaling_factor = self.quant_init_convbn(x, act_scaling_factor) 118 | 119 | x = self.pool(x) 120 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor) 121 | 122 | x = self.act(x) 123 | 124 | for stage_num in range(0, 4): 125 | for unit_num in range(0, self.channel[stage_num]): 126 | tmp_func = getattr(self, f"stage{stage_num+1}.unit{unit_num+1}") 127 | x, act_scaling_factor = tmp_func(x, act_scaling_factor) 128 | 129 | x = self.final_pool(x, act_scaling_factor) 130 | 131 | x, act_scaling_factor = self.quant_act_output(x, act_scaling_factor) 132 | x = x.view(x.size(0), -1) 133 | x = self.quant_output(x, act_scaling_factor) 134 | 135 | return x 136 | 137 | 138 | class Q_ResNet101(nn.Module): 139 | """ 140 | Quantized ResNet101 model from 'Deep Residual Learning for Image Recognition,' https://arxiv.org/abs/1512.03385. 
141 | """ 142 | def __init__(self, model): 143 | super().__init__() 144 | 145 | features = getattr(model, 'features') 146 | 147 | init_block = getattr(features, 'init_block') 148 | self.quant_input = QuantAct() 149 | self.quant_init_convbn = QuantBnConv2d() 150 | self.quant_init_convbn.set_param(init_block.conv.conv, init_block.conv.bn) 151 | 152 | self.quant_act_int32 = QuantAct() 153 | 154 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 155 | self.act = nn.ReLU() 156 | 157 | self.channel = [3, 4, 23, 3] 158 | 159 | for stage_num in range(0, 4): 160 | stage = getattr(features, "stage{}".format(stage_num + 1)) 161 | for unit_num in range(0, self.channel[stage_num]): 162 | unit = getattr(stage, "unit{}".format(unit_num + 1)) 163 | quant_unit = Q_ResUnitBn() 164 | quant_unit.set_param(unit) 165 | setattr(self, f"stage{stage_num + 1}.unit{unit_num + 1}", quant_unit) 166 | 167 | self.final_pool = QuantAveragePool2d(kernel_size=7, stride=1) 168 | 169 | self.quant_act_output = QuantAct() 170 | 171 | output = getattr(model, 'output') 172 | self.quant_output = QuantLinear() 173 | self.quant_output.set_param(output) 174 | 175 | def forward(self, x): 176 | x, act_scaling_factor = self.quant_input(x) 177 | 178 | x, weight_scaling_factor = self.quant_init_convbn(x, act_scaling_factor) 179 | 180 | x = self.pool(x) 181 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, None, None) 182 | 183 | x = self.act(x) 184 | 185 | for stage_num in range(0, 4): 186 | for unit_num in range(0, self.channel[stage_num]): 187 | tmp_func = getattr(self, f"stage{stage_num+1}.unit{unit_num+1}") 188 | x, act_scaling_factor = tmp_func(x, act_scaling_factor) 189 | 190 | x = self.final_pool(x, act_scaling_factor) 191 | 192 | x, act_scaling_factor = self.quant_act_output(x, act_scaling_factor) 193 | x = x.view(x.size(0), -1) 194 | x = self.quant_output(x, act_scaling_factor) 195 | 196 | return x 197 | 198 | 199 | class Q_ResUnitBn(nn.Module): 200 | """ 201 | Quantized ResNet unit with residual path. 
202 | """ 203 | def __init__(self): 204 | super(Q_ResUnitBn, self).__init__() 205 | 206 | def set_param(self, unit): 207 | self.resize_identity = unit.resize_identity 208 | 209 | self.quant_act = QuantAct() 210 | 211 | convbn1 = unit.body.conv1 212 | self.quant_convbn1 = QuantBnConv2d() 213 | self.quant_convbn1.set_param(convbn1.conv, convbn1.bn) 214 | self.quant_act1 = QuantAct() 215 | 216 | convbn2 = unit.body.conv2 217 | self.quant_convbn2 = QuantBnConv2d() 218 | self.quant_convbn2.set_param(convbn2.conv, convbn2.bn) 219 | self.quant_act2 = QuantAct() 220 | 221 | convbn3 = unit.body.conv3 222 | self.quant_convbn3 = QuantBnConv2d() 223 | self.quant_convbn3.set_param(convbn3.conv, convbn3.bn) 224 | 225 | if self.resize_identity: 226 | self.quant_identity_convbn = QuantBnConv2d() 227 | self.quant_identity_convbn.set_param(unit.identity_conv.conv, unit.identity_conv.bn) 228 | 229 | self.quant_act_int32 = QuantAct() 230 | 231 | def forward(self, x, scaling_factor_int32=None): 232 | # forward using the quantized modules 233 | if self.resize_identity: 234 | x, act_scaling_factor = self.quant_act(x, scaling_factor_int32) 235 | identity_act_scaling_factor = act_scaling_factor.clone() 236 | identity, identity_weight_scaling_factor = self.quant_identity_convbn(x, act_scaling_factor) 237 | else: 238 | identity = x 239 | x, act_scaling_factor = self.quant_act(x, scaling_factor_int32) 240 | 241 | x, weight_scaling_factor = self.quant_convbn1(x, act_scaling_factor) 242 | x = nn.ReLU()(x) 243 | x, act_scaling_factor = self.quant_act1(x, act_scaling_factor, weight_scaling_factor) 244 | 245 | x, weight_scaling_factor = self.quant_convbn2(x, act_scaling_factor) 246 | x = nn.ReLU()(x) 247 | x, act_scaling_factor = self.quant_act2(x, act_scaling_factor, weight_scaling_factor) 248 | 249 | x, weight_scaling_factor = self.quant_convbn3(x, act_scaling_factor) 250 | 251 | x = x + identity 252 | 253 | if self.resize_identity: 254 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, identity, identity_act_scaling_factor, identity_weight_scaling_factor) 255 | else: 256 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, identity, scaling_factor_int32, None) 257 | 258 | x = nn.ReLU()(x) 259 | 260 | return x, act_scaling_factor 261 | 262 | 263 | class Q_ResBlockBn(nn.Module): 264 | """ 265 | Quantized ResNet block with residual path. 
266 | """ 267 | def __init__(self): 268 | super(Q_ResBlockBn, self).__init__() 269 | 270 | def set_param(self, unit): 271 | self.resize_identity = unit.resize_identity 272 | 273 | self.quant_act = QuantAct() 274 | 275 | convbn1 = unit.body.conv1 276 | self.quant_convbn1 = QuantBnConv2d() 277 | self.quant_convbn1.set_param(convbn1.conv, convbn1.bn) 278 | 279 | self.quant_act1 = QuantAct() 280 | 281 | convbn2 = unit.body.conv2 282 | self.quant_convbn2 = QuantBnConv2d() 283 | self.quant_convbn2.set_param(convbn2.conv, convbn2.bn) 284 | 285 | if self.resize_identity: 286 | self.quant_identity_convbn = QuantBnConv2d() 287 | self.quant_identity_convbn.set_param(unit.identity_conv.conv, unit.identity_conv.bn) 288 | 289 | self.quant_act_int32 = QuantAct() 290 | 291 | def forward(self, x, scaling_factor_int32=None): 292 | # forward using the quantized modules 293 | if self.resize_identity: 294 | x, act_scaling_factor = self.quant_act(x, scaling_factor_int32) 295 | identity_act_scaling_factor = act_scaling_factor.clone() 296 | identity, identity_weight_scaling_factor = self.quant_identity_convbn(x, act_scaling_factor) 297 | else: 298 | identity = x 299 | x, act_scaling_factor = self.quant_act(x, scaling_factor_int32) 300 | 301 | x, weight_scaling_factor = self.quant_convbn1(x, act_scaling_factor) 302 | x = nn.ReLU()(x) 303 | x, act_scaling_factor = self.quant_act1(x, act_scaling_factor, weight_scaling_factor) 304 | 305 | x, weight_scaling_factor = self.quant_convbn2(x, act_scaling_factor) 306 | 307 | x = x + identity 308 | 309 | if self.resize_identity: 310 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, identity, identity_act_scaling_factor, identity_weight_scaling_factor) 311 | else: 312 | x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor, identity, scaling_factor_int32, None) 313 | 314 | x = nn.ReLU()(x) 315 | 316 | return x, act_scaling_factor 317 | 318 | 319 | def q_resnet18(model): 320 | net = Q_ResNet18(model) 321 | return net 322 | 323 | 324 | def q_resnet50(model): 325 | net = Q_ResNet50(model) 326 | return net 327 | 328 | 329 | def q_resnet101(model): 330 | net = Q_ResNet101(model) 331 | return net 332 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSMPQ 2 |

3 | 4 | 5 | 6 |

7 | The official implementation of CSMPQ: Class Separability Based Mixed-Precision Quantization.
8 | 
9 | # Requirements
10 | 
11 | * [DALI](https://github.com/NVIDIA/DALI) (for accelerated data processing)
12 | * [Apex](https://github.com/NVIDIA/apex) (for distributed training)
13 | * Other requirements: install them from requirements.txt:
14 | 
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | 
20 | 
21 | # Running
22 | 
23 | 
24 | 
25 | 
26 | **Bit Configuration**
27 | 
28 | 
29 | 
30 | ```bash
31 | #!/usr/bin/env bash
32 | # --path:       pretrained base model
33 | # --save_path:  dataset path
34 | # --beta:       hyper-parameter controlling the bit difference
35 | # --model_size: target model size
36 | # --quant_type: Post-Training Quantization (PTQ) or Quantization-Aware Training (QAT)
37 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract_cdp.py \
38 |         --model "resnet18" \
39 |         --path "/Path/to/Base_model" \
40 |         --dataset "imagenet" \
41 |         --save_path '/Path/to/Dataset/' \
42 |         --beta 10.0 \
43 |         --model_size 6.7 \
44 |         --quant_type "QAT"
45 | ```
46 | 
47 | or
48 | 
49 | ```bash
50 | bash ./mixed_bit/run_scripts/QAT/quant_resnet18.sh
51 | ```
52 | 
53 | 
54 | 
55 | **QAT**
56 | 
57 | Because of random seeds, the bit configuration obtained through feature extraction may differ slightly from ours. Our bit configurations are given in bit_config.py, and our quantized models and logs are available at [this](https://drive.google.com/drive/folders/1q0wtmWNdqPZuZqnSCQLScYFNIYzXKebg?usp=sharing) link.
58 | 
59 | ```bash
60 | #!/usr/bin/env bash
61 | python quant_train.py \
62 |         -a resnet18 \
63 |         --epochs 90 \
64 |         --lr 0.0001 \
65 |         --batch_size 128 \
66 |         --data /Path/to/Dataset/ \
67 |         --save_path /Path/to/Save_quant_model/ \
68 |         --act_range_momentum=0.99 \
69 |         --wd 1e-4 \
70 |         --data_percentage 1 \
71 |         --pretrained \
72 |         --fix_BN \
73 |         --checkpoint_iter -1 \
74 |         --quant_scheme modelsize_6.7_a6_75B
75 | ```
76 | 
77 | or
78 | 
79 | ```bash
80 | bash ./QAT/run_scripts/train_resnet18.sh
81 | ```
82 | 
83 | 
84 | 
85 | **PTQ**
86 | 
87 | Post-training quantization requires only a few GPU hours to produce the quantized model, so we fix the random seed. You can reproduce the accuracy reported in the paper by running:
88 | 
89 | ```bash
90 | python main_imagenet.py --data_path /Path/to/Dataset/ --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant --test_before_calibration --bit_cfg "[4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 4, 3, 3, 3, 3]"
91 | ```
92 | 
93 | or
94 | 
95 | ```bash
96 | bash ./PTQ/run_scripts/train_resnet18.sh
97 | ```
98 | 
99 | 
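**From bit configuration to quantization**

The bit-configuration step prints a Python-style list of per-layer bit-widths (`optimize_cfg` in `mixed_bit/feature_extract_cdp.py`). For PTQ, that printed list is exactly what `--bit_cfg` expects, quoted verbatim. A minimal end-to-end sketch (paths are placeholders):

```bash
# 1) search a bit configuration; the script prints a list such as [4, 3, 3, ...]
bash ./mixed_bit/run_scripts/PTQ/quant_resnet18.sh
# 2) pass the printed list to the PTQ entry point
python main_imagenet.py --data_path /Path/to/Dataset/ --arch resnet18 \
    --n_bits_w 2 --channel_wise --n_bits_a 8 --act_quant \
    --test_before_calibration --bit_cfg "[4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 3, 3, 4, 4, 3, 3, 3, 3]"
```

For QAT, configurations are instead selected by name through `--quant_scheme` (see `QAT/bit_config.py`).
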
100 | ## Related Works
101 | 
102 | - [BRECQ: Pushing the Limit of Post-Training Quantization by Block Reconstruction (ICLR 2021)](https://arxiv.org/abs/2102.05426)
103 | 
--------------------------------------------------------------------------------
/mixed_bit/ORM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def gram_linear(x):
6 |   """Compute Gram (kernel) matrix for a linear kernel.
7 | 
8 |   Args:
9 |     x: A num_examples x num_features matrix of features.
10 | 
11 |   Returns:
12 |     A num_examples x num_examples Gram matrix of examples.
13 |   """
14 |   return x.dot(x.T)
15 | 
16 | 
17 | def center_gram(gram, unbiased=False):
18 |   """Center a symmetric Gram matrix.
19 | 
20 |   This is equivalent to centering the (possibly infinite-dimensional) features
21 |   induced by the kernel before computing the Gram matrix.
22 | 
23 |   Args:
24 |     gram: A num_examples x num_examples symmetric matrix.
25 |     unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
26 |       estimate of HSIC. Note that this estimator may be negative.
27 | 
28 |   Returns:
29 |     A symmetric matrix with centered columns and rows.
30 |   """
31 |   if not np.allclose(gram, gram.T):
32 |     raise ValueError('Input must be a symmetric matrix.')
33 |   gram = gram.copy()
34 | 
35 |   if unbiased:
36 |     # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
37 |     # L. (2014). Partial distance correlation with methods for dissimilarities.
38 |     # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
39 |     # stable than the alternative from Song et al. (2007).
40 |     n = gram.shape[0]
41 |     np.fill_diagonal(gram, 0)
42 |     means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
43 |     means -= np.sum(means) / (2 * (n - 1))
44 |     gram -= means[:, None]
45 |     gram -= means[None, :]
46 |     np.fill_diagonal(gram, 0)
47 |   else:
48 |     means = np.mean(gram, 0, dtype=np.float64)
49 |     means -= np.mean(means) / 2
50 |     gram -= means[:, None]
51 |     gram -= means[None, :]
52 | 
53 |   return gram
54 | 
55 | 
56 | def orm(gram_x, gram_y, debiased=False):
57 |   """Compute ORM.
58 | 
59 |   Args:
60 |     gram_x: A num_examples x num_examples Gram matrix.
61 |     gram_y: A num_examples x num_examples Gram matrix.
62 |     debiased: Use unbiased estimator of HSIC. CKA may still be biased.
63 | 
64 |   Returns:
65 |     The value of ORM between X and Y.
66 |   """
67 |   gram_x = center_gram(gram_x, unbiased=debiased)
68 |   gram_y = center_gram(gram_y, unbiased=debiased)
69 | 
70 |   # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or
71 |   # n*(n-3) (unbiased variant), but this cancels for CKA.
72 |   scaled_hsic = gram_x.ravel().dot(gram_y.ravel())
73 | 
74 |   normalization_x = np.linalg.norm(gram_x)
75 |   normalization_y = np.linalg.norm(gram_y)
76 |   return scaled_hsic / (normalization_x * normalization_y)
77 | 
78 | 
79 | def _debiased_dot_product_similarity_helper(
80 |     xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y,
81 |     n):
82 |   """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
83 |   # This formula can be derived by manipulating the unbiased estimator from
84 |   # Song et al. (2007).
85 |   return (
86 |       xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
87 |       + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))
88 | 
89 | 
90 | # def feature_space_orm(features_x, features_y, debiased=False):
91 | #   """Compute ORM with a linear kernel, in feature space.
92 | #
93 | #   This is typically faster than computing the Gram matrix when there are fewer
94 | #   features than examples.
95 | #
96 | #   Args:
97 | #     features_x: A num_examples x num_features matrix of features.
98 | #     features_y: A num_examples x num_features matrix of features.
99 | #     debiased: Use unbiased estimator of dot product similarity. ORM may still be
100 | #       biased. Note that this estimator may be negative.
101 | #
102 | #   Returns:
103 | #     The value of ORM between X and Y.
104 | # """ 105 | # features_x = features_x - torch.mean(features_x, 0, keepdim=True) 106 | # features_y = features_y - torch.mean(features_y, 0, keepdim=True) 107 | # 108 | # a = torch.mm(features_x.t(), features_y) 109 | # b = torch.mm(features_x.t(), features_x) 110 | # c = torch.mm(features_y.t(), features_y) 111 | # dot_product_similarity = torch.linalg.norm(a) ** 2 112 | # normalization_x = torch.linalg.norm(b) 113 | # normalization_y = torch.linalg.norm(c) 114 | # 115 | # if debiased: 116 | # n = features_x.shape[0] 117 | # # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array. 118 | # sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x) 119 | # sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y) 120 | # squared_norm_x = np.sum(sum_squared_rows_x) 121 | # squared_norm_y = np.sum(sum_squared_rows_y) 122 | # 123 | # dot_product_similarity = _debiased_dot_product_similarity_helper( 124 | # dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y, 125 | # squared_norm_x, squared_norm_y, n) 126 | # normalization_x = np.sqrt(_debiased_dot_product_similarity_helper( 127 | # normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x, 128 | # squared_norm_x, squared_norm_x, n)) 129 | # normalization_y = np.sqrt(_debiased_dot_product_similarity_helper( 130 | # normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y, 131 | # squared_norm_y, squared_norm_y, n)) 132 | # 133 | # return dot_product_similarity / (normalization_x * normalization_y) 134 | 135 | 136 | def feature_space_orm(features_x, features_y, debiased=False): 137 | """Compute ORM with a linear kernel, in feature space. 138 | 139 | This is typically faster than computing the Gram matrix when there are fewer 140 | features than examples. 141 | 142 | Args: 143 | features_x: A num_examples x num_features matrix of features. 144 | features_y: A num_examples x num_features matrix of features. 145 | debiased: Use unbiased estimator of dot product similarity. ORM may still be 146 | biased. Note that this estimator may be negative. 147 | 148 | Returns: 149 | The value of ORM between X and Y. 150 | """ 151 | features_x = features_x - np.mean(features_x, 0, keepdims=True) 152 | features_y = features_y - np.mean(features_y, 0, keepdims=True) 153 | 154 | dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2 155 | normalization_x = np.linalg.norm(features_x.T.dot(features_x)) 156 | normalization_y = np.linalg.norm(features_y.T.dot(features_y)) 157 | 158 | if debiased: 159 | n = features_x.shape[0] 160 | # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array. 
161 |     sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x)
162 |     sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y)
163 |     squared_norm_x = np.sum(sum_squared_rows_x)
164 |     squared_norm_y = np.sum(sum_squared_rows_y)
165 | 
166 |     dot_product_similarity = _debiased_dot_product_similarity_helper(
167 |         dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
168 |         squared_norm_x, squared_norm_y, n)
169 |     normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(
170 |         normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,
171 |         squared_norm_x, squared_norm_x, n))
172 |     normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(
173 |         normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,
174 |         squared_norm_y, squared_norm_y, n))
175 | 
176 |   return dot_product_similarity / (normalization_x * normalization_y)
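

# Minimal self-check (an added sketch, not used by the pipeline): for a linear
# kernel, ORM computed from Gram matrices should match the feature-space
# computation, since <Kx_c, Ky_c>_F = ||X_c^T Y_c||_F^2 for centered features.
if __name__ == '__main__':
  rng = np.random.RandomState(0)
  x = rng.randn(100, 16)
  y = rng.randn(100, 32)
  assert np.allclose(orm(gram_linear(x), gram_linear(y)), feature_space_orm(x, y))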
--------------------------------------------------------------------------------
/mixed_bit/cdp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import random
5 | import argparse
6 | import functools
7 | import numpy as np
8 | from datetime import datetime
9 | import pdb
10 | import torch
11 | import torchvision
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torchvision import datasets, transforms, models
15 | from nni.algorithms.compression.pytorch.pruning \
16 |     import L1FilterPruner, L1FilterPrunerMasker
17 | 
18 | # class TFIDFMasker(L1FilterPrunerMasker):
19 | #     def __init__(self, model, pruner, threshold, tf_idf_map, preserve_round=1, dependency_aware=False):
20 | #         super().__init__(model, pruner, preserve_round, dependency_aware)
21 | #         self.threshold=threshold
22 | #         self.tf_idf_map=tf_idf_map
23 | 
24 | #     def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
25 | #         # get the l1-norm sum for each filter
26 | #         w_tf_idf_structured = self.get_tf_idf_mask(wrapper, wrapper_idx)
27 | 
28 | #         mask_weight = torch.gt(w_tf_idf_structured, self.threshold)[
29 | #             :, None, None, None].expand_as(weight).type_as(weight)
30 | #         mask_bias = torch.gt(w_tf_idf_structured, self.threshold).type_as(
31 | #             weight).detach() if base_mask['bias_mask'] is not None else None
32 | 
33 | #         return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
34 | 
35 | #     def get_tf_idf_mask(self, wrapper, wrapper_idx):
36 | #         name = wrapper.name
37 | #         if wrapper.name.split('.')[-1] == 'module':
38 | #             name = wrapper.name[0:-7]
39 | #         #print(name)
40 | #         w_tf_idf_structured = self.tf_idf_map[name]
41 | #         return w_tf_idf_structured
42 | 
43 | 
44 | # class TFIDFPruner(L1FilterPruner):
45 | #     def __init__(self, model, config_list, cdp_config:dict, pruning_algorithm='l1', optimizer=None, **algo_kwargs):
46 | #         super().__init__(model, config_list, optimizer)
47 | #         self.set_wrappers_attribute("if_calculated", False)
48 | #         self.masker = TFIDFMasker(model, self, threshold=cdp_config["threshold"], tf_idf_map=cdp_config["map"], **algo_kwargs)
49 | #     def update_masker(self,model,threshold,mapper):
50 | #         self.masker = TFIDFMasker(model, self, threshold=threshold, tf_idf_map=mapper)
51 | 
52 | def feature_preprocess(feature):
53 |     for i in range(len(feature)):  # each entry e.g. ([10000, 64, 16, 16])
54 |         # feature[i] = F.relu(feature[i])  # features are already taken after the ReLU, so this is commented out
55 |         # pdb.set_trace()
56 |         if len(feature[i].size()) == 4:  # output of a convolutional layer
57 |             feature[i] = F.relu(feature[i])  # newly added for MobileNet
58 |             feature[i] = F.avg_pool2d(feature[i], feature[i].size()[3])  # ([10000, 64, 1, 1])
59 |         else:
60 |             feature[i] = F.relu(feature[i])
61 |             feature[i].view([feature[i].size()[0], feature[i].size()[1], 1, 1])
62 |             # fc-layer outputs are already two-dimensional, so no reshaping is needed
63 |         feature[i] = feature[i].view(feature[i].size()[0], -1)  # ([10000, 64])
64 |         feature[i] = feature[i].transpose(0, 1)  # ([64, 10000])
65 |     return feature
66 | 
67 | # def acculumate_feature(model, loader, stop:int):
68 | #     model=model.cuda()
69 | #     features = {}
70 | 
71 | #     def hook_func(m, x, y, name, feature_iit):
72 | #         #print(name, y.shape)
73 | #         f = F.relu(y)
74 | #         #f = y
75 | #         feature = F.avg_pool2d(f, f.size()[3])
76 | #         feature = feature.view(f.size()[0], -1)
77 | #         feature = feature.transpose(0, 1)
78 | #         if name not in feature_iit:
79 | #             feature_iit[name] = feature.cpu()
80 | #         else:
81 | #             feature_iit[name] = torch.cat([feature_iit[name], feature.cpu()], 1)
82 | 
83 | #     hook=functools.partial(hook_func, feature_iit=features)
84 | 
85 | #     handler_list=[]
86 | #     for name, m in model.named_modules():
87 | #         if isinstance(m, nn.Conv2d):
88 | #         #if not isinstance(m, nn.Linear):
89 | #             handler = m.register_forward_hook(functools.partial(hook, name=name))
90 | #             handler_list.append(handler)
91 | #     for batch_idx, (inputs, targets) in enumerate(loader):
92 | #         if batch_idx % (stop//10) == 0:
93 | #             print(batch_idx)
94 | #         if batch_idx >= stop:
95 | #             break
96 | #         model.eval()
97 | #         with torch.no_grad():
98 | #             model(inputs.cuda())
99 | 
100 | #     [ k.remove() for k in handler_list]
101 | #     return features
102 | 
103 | def calc_tf_idf(feature:dict, coe:int, tf_idf_map:dict):  # feature = [c, n] ([64, 10000]) name = 'conv_bn.conv'
104 |     # calc tf
105 |     # pdb.set_trace()
106 |     balance_coe = np.log((feature.shape[0]/coe)*np.e) if coe else 1.0  # Eq. (8) in the paper, a float scalar
107 |     # calc idf
108 |     sample_quant = float(feature.shape[1])  # 10000
109 |     sample_mean = feature.mean(dim=1).view(feature.shape[0], 1)  # per-channel mean, ([64, 1])
110 |     sample_inverse = (feature >= sample_mean).sum(dim=1).type(torch.FloatTensor)  # ([64]), S_i* in the paper: per channel, the number of entries above sample_mean
111 | 
112 |     # calc tf mean
113 |     feature_sum = feature.sum(dim=0)  # ([10000]), accumulated over channels
114 |     tf = (feature / feature_sum) * balance_coe  # Eq. (8) in the paper ([64, 10000])
115 |     tf_mean = (tf * ((feature) >= sample_mean)).sum(dim=1)  # Sa ([64])
116 |     tf_mean_new = tf_mean / (((feature) >= sample_mean).sum(dim=1) + 1e-11)  # ([64])
117 | 
118 |     idf = torch.log(sample_quant / (sample_inverse + 1.0))  # Eq. (7) in the paper ([64])
119 |     idf = idf.cuda()  # newly added: idf must be on the GPU together with tf_mean
120 | 
121 |     # pdb.set_trace()
122 |     importance = tf_mean_new * idf  # Eq. (10) in the paper ([64]); one importance value per output channel
123 |     importance = importance.mean().item()  # newly added
124 |     tf_idf_map.append(importance)
125 | 
126 | def calculate_cdp(features:dict, coe:int):
127 |     tf_idf_map = []
128 |     for feature in features:
129 |         # pdb.set_trace()
130 |         calc_tf_idf(feature, coe=coe, tf_idf_map=tf_idf_map)
131 |     return tf_idf_map
132 | 
133 | 
134 | 
--------------------------------------------------------------------------------
/mixed_bit/data_providers/__init__.py:
--------------------------------------------------------------------------------
1 | class DataProvider:
2 |     VALID_SEED = 0  # random seed for the validation set
3 | 
4 |     @staticmethod
5 |     def name():
6 |         """ Return name of the dataset """
7 |         raise NotImplementedError
8 | 
9 |     @property
10 |     def data_shape(self):
11 |         """ Return shape as python list of one data entry """
12 |         raise NotImplementedError
13 | 
14 |     @property
15 |     def n_classes(self):
16 |         """ Return `int` of num classes """
17 |         raise
NotImplementedError 18 | 19 | @property 20 | def save_path(self): 21 | """ local path to save the data """ 22 | raise NotImplementedError 23 | 24 | @property 25 | def data_url(self): 26 | """ link to download the data """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /mixed_bit/data_providers/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import torch.utils.data 5 | 6 | import utils 7 | from data_providers import DataProvider 8 | 9 | 10 | def make_imagenet_subset(path2subset, n_sub_classes, path2imagenet='/userhome/memory_data/imagenet'): 11 | imagenet_train_folder = os.path.join(path2imagenet, 'train') 12 | imagenet_val_folder = os.path.join(path2imagenet, 'val') 13 | 14 | subfolders = sorted([f.path for f in os.scandir(imagenet_train_folder) if f.is_dir()]) 15 | # np.random.seed(DataProvider.VALID_SEED) 16 | np.random.shuffle(subfolders) 17 | 18 | chosen_train_folders = subfolders[:n_sub_classes] 19 | class_name_list = [] 20 | for train_folder in chosen_train_folders: 21 | class_name = train_folder.split('/')[-1] 22 | class_name_list.append(class_name) 23 | 24 | print('=> Start building subset%d' % n_sub_classes) 25 | for cls_name in class_name_list: 26 | src_train_folder = os.path.join(imagenet_train_folder, cls_name) 27 | target_train_folder = os.path.join(path2subset, 'train/%s' % cls_name) 28 | shutil.copytree(src_train_folder, target_train_folder) 29 | print('Train: %s -> %s' % (src_train_folder, target_train_folder)) 30 | 31 | src_val_folder = os.path.join(imagenet_val_folder, cls_name) 32 | target_val_folder = os.path.join(path2subset, 'val/%s' % cls_name) 33 | shutil.copytree(src_val_folder, target_val_folder) 34 | print('Val: %s -> %s' % (src_val_folder, target_val_folder)) 35 | print('=> Finish building subset%d' % n_sub_classes) 36 | 37 | 38 | class ImagenetDataProvider(DataProvider): 39 | def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, 40 | n_worker=24, manual_seed = 12, load_type='dali', local_rank=0, world_size=1, **kwargs): 41 | 42 | self._save_path = save_path 43 | self.valid = None 44 | if valid_size is not None: 45 | pass 46 | else: 47 | self.train = utils.get_imagenet_iter(data_type='train', image_dir=self.train_path, 48 | batch_size=train_batch_size, num_threads=n_worker, 49 | device_id=local_rank, manual_seed=manual_seed, 50 | num_gpus=torch.cuda.device_count(), crop=self.image_size, 51 | val_size=self.image_size, world_size=world_size, local_rank=local_rank) 52 | self.test = utils.get_imagenet_iter(data_type='val', image_dir=self.valid_path, manual_seed=manual_seed, 53 | batch_size=test_batch_size, num_threads=n_worker, device_id=local_rank, 54 | num_gpus=torch.cuda.device_count(), crop=self.image_size, 55 | val_size=256, world_size=world_size, local_rank=local_rank) 56 | if self.valid is None: 57 | self.valid = self.test 58 | 59 | @staticmethod 60 | def name(): 61 | return 'imagenet' 62 | 63 | @property 64 | def data_shape(self): 65 | return 3, self.image_size, self.image_size # C, H, W 66 | 67 | @property 68 | def n_classes(self): 69 | return 1000 70 | 71 | @property 72 | def save_path(self): 73 | if self._save_path is None: 74 | self._save_path = '/userhome/data/imagenet' 75 | return self._save_path 76 | 77 | @property 78 | def data_url(self): 79 | raise ValueError('unable to download ImageNet') 80 | 81 | @property 82 | def train_path(self): 
83 |         return os.path.join(self.save_path, 'train')
84 | 
85 |     @property
86 |     def valid_path(self):
87 |         return os.path.join(self._save_path, 'val')
88 | 
89 |     @property
90 |     def resize_value(self):
91 |         return 256
92 | 
93 |     @property
94 |     def image_size(self):
95 |         return 224
96 | 
--------------------------------------------------------------------------------
/mixed_bit/feature_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import math
5 | from scipy import optimize
6 | import random
7 | import ORM
8 | import torch, time
9 | import torch.nn as nn
10 | from run_manager import RunManager
11 | from models import ResNet_ImageNet, MobileNetV2, TrainRunConfig
12 | from pulp import *
13 | from utils.pytorch_utils import DFS_bit
14 | import pulp
15 | 
16 | parser = argparse.ArgumentParser()
17 | 
18 | """ model config """
19 | parser.add_argument('--path', type=str)
20 | parser.add_argument('--model', type=str, default="resnet18",
21 |                     choices=['resnet50', 'mobilenetv2', 'mobilenet', 'resnet18'])
22 | parser.add_argument('--cfg', type=str, default="None")
23 | parser.add_argument('--manual_seed', default=0, type=int)
24 | parser.add_argument("--model_size", default=0, type=float)
25 | parser.add_argument("--beta", default=1, type=float)
26 | parser.add_argument('--quant_type', type=str, default='QAT',
27 |                     choices=['QAT', 'PTQ'])
28 | 
29 | """ dataset config """
30 | parser.add_argument('--dataset', type=str, default='cifar10',
31 |                     choices=['cifar10', 'imagenet'])
32 | parser.add_argument('--save_path', type=str, default='/Path/to/Dataset')
33 | 
34 | """ runtime config """
35 | parser.add_argument('--gpu', help='gpu available', default='0')
36 | parser.add_argument('--train_batch_size', type=int, default=32)
37 | parser.add_argument('--n_worker', type=int, default=24)
38 | parser.add_argument("--local_rank", default=0, type=int)
39 | 
40 | if __name__ == '__main__':
41 |     args = parser.parse_args()
42 | 
43 |     # cpu_num = 1
44 |     # os.environ['OMP_NUM_THREADS'] = str(cpu_num)
45 |     # os.environ['OPENBLAS_NUM_THREADS'] = str(cpu_num)
46 |     # os.environ['MKL_NUM_THREADS'] = str(cpu_num)
47 |     # os.environ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num)
48 |     # os.environ['NUMEXPR_NUM_THREADS'] = str(cpu_num)
49 |     # torch.set_num_threads(cpu_num)
50 | 
51 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
52 |     torch.cuda.set_device(0)
53 | 
54 |     random.seed(args.manual_seed)
55 |     torch.manual_seed(args.manual_seed)
56 |     torch.cuda.manual_seed_all(args.manual_seed)
57 |     np.random.seed(args.manual_seed)
58 |     # distributed setting
59 |     torch.distributed.init_process_group(backend='nccl',
60 |                                          init_method='env://')
61 |     args.world_size = torch.distributed.get_world_size()
62 | 
63 |     # prepare run config
64 |     run_config_path = '%s/run.config' % args.path
65 | 
66 |     run_config = TrainRunConfig(
67 |         **args.__dict__
68 |     )
69 |     if args.local_rank == 0:
70 |         print('Run config:')
71 |         for k, v in args.__dict__.items():
72 |             print('\t%s: %s' % (k, v))
73 | 
74 |     if args.model == "resnet18":
75 |         assert args.dataset == 'imagenet', 'resnet18 only supports imagenet dataset'
76 |         net = ResNet_ImageNet(
77 |             depth=18, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
78 |     elif args.model == "resnet50":
79 |         assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset'
80 |         net = ResNet_ImageNet(
81 |             depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
82 |     elif args.model ==
"mobilenetv2": 83 | assert args.dataset == 'imagenet', 'mobilenetv2 only supports imagenet dataset' 84 | net = MobileNetV2( 85 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 86 | 87 | 88 | # build run manager 89 | run_manager = RunManager(args.path, net, run_config) 90 | 91 | # load checkpoints 92 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path 93 | assert os.path.isfile(best_model_path), 'wrong path' 94 | if torch.cuda.is_available(): 95 | checkpoint = torch.load(best_model_path) 96 | else: 97 | checkpoint = torch.load(best_model_path, map_location='cpu') 98 | if 'state_dict' in checkpoint: 99 | checkpoint = checkpoint['state_dict'] 100 | run_manager.net.load_state_dict(checkpoint) 101 | output_dict = {} 102 | 103 | # feature extract 104 | # start = time.time() 105 | data_loader = run_manager.run_config.train_loader 106 | data = next(iter(data_loader)) 107 | data = data[0] 108 | n = data.size()[0] 109 | 110 | with torch.no_grad(): 111 | feature = net.feature_extract(data, args.quant_type) 112 | 113 | for i in range(len(feature)): 114 | feature[i] = feature[i].view(n, -1) 115 | feature[i] = feature[i].data.cpu().numpy() 116 | 117 | orthogonal_matrix = np.zeros((len(feature), len(feature))) 118 | 119 | for i in range(len(feature)): 120 | for j in range(len(feature)): 121 | with torch.no_grad(): 122 | orthogonal_matrix[i][j] = ORM.orm(ORM.gram_linear(feature[i]), ORM.gram_linear(feature[j])) 123 | 124 | def sum_list(a, j): 125 | b = 0 126 | for i in range(len(a)): 127 | if i != j: 128 | b += a[i] 129 | return b 130 | 131 | theta = [] 132 | gamma = [] 133 | flops = [] 134 | 135 | for i in range(len(feature)): 136 | gamma.append( sum_list(orthogonal_matrix[i], i) ) 137 | 138 | # e^-x 139 | for i in range(len(feature)): 140 | theta.append( 1 * math.exp(-1* args.beta *gamma[i]) ) 141 | theta = np.array(theta) 142 | theta = np.negative(theta) 143 | 144 | length = len(feature) 145 | # layerwise 146 | params, first_last_size = net.cfg2params_perlayer(net.cfg, length, args.quant_type) 147 | FLOPs, first_last_flops = net.cfg2flops_layerwise(net.cfg, length, args.quant_type) 148 | params = [i/(1024*1024) for i in params] 149 | first_last_size = first_last_size/(1024*1024) 150 | 151 | 152 | # Objective function 153 | def func(x, sign=1.0): 154 | """ Objective function """ 155 | global theta,length 156 | sum_fuc =[] 157 | for i in range(length): 158 | temp = 0. 159 | for j in range(i,length): 160 | temp += theta[j] 161 | sum_fuc.append( x[i] * (sign * temp / (length-i)) ) 162 | 163 | return sum(sum_fuc) 164 | 165 | # Derivative function of objective function 166 | def func_deriv(x, sign=1.0): 167 | """ Derivative of objective function """ 168 | global theta, length 169 | diff = [] 170 | for i in range(length): 171 | temp1 = 0. 
172 |             for j in range(i, length):
173 |                 temp1 += theta[j]
174 |             diff.append(sign * temp1 / (length - i))
175 | 
176 |         return np.array(diff)
177 | 
178 |     # Constraint function
179 |     def constrain_func(x):
180 |         """ constrain function """
181 |         global params, length
182 |         a = []
183 |         for i in range(length):
184 |             a.append(x[i] * params[i])
185 |         return np.array([args.model_size - first_last_size - sum(a)])
186 | 
187 |     bnds = []  # bit search space: (0.25, 0.5) for PTQ and (0.5, 1.0) for QAT; 0.25 means 2 bit, 0.5 means 4 bit, 1.0 means 8 bit
188 |     if args.quant_type == 'PTQ':  # for PTQ every layer is searched in 2~4 bit
189 |         for i in range(length):
190 |             bnds.append((0.25, 0.5))
191 |     else:
192 |         for i in range(length):  # for QAT the first and last layers are fixed to 8 bit; all other layers are searched in 4~8 bit
193 |             bnds.append((0.5, 1.0))
194 | 
195 |     bnds = tuple(bnds)
196 |     cons = ({'type': 'ineq',
197 |              'fun': constrain_func}
198 |             )
199 | 
200 |     result = optimize.minimize(func, x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons)
201 | 
202 |     if args.model == "resnet18":  # resnet18 needs special handling
203 |         prun_bitcfg, _ = DFS_bit(result.x[::-1] * 8, [params[length - i - 1] for i in range(length)])
204 |         prun_bitcfg = [prun_bitcfg[length - i - 1] for i in range(length)]
205 |     else:
206 |         prun_bitcfg = np.around(result.x * 8)  # for other networks the bit config is obtained directly by rounding result.x * 8
207 |     # end = time.time()
208 |     # print("Use", end - start, "seconds. ")
209 | 
210 | 
211 |     optimize_cfg = []
212 |     if type(prun_bitcfg[0]) != int:
213 |         for i in range(len(prun_bitcfg)):
214 |             b = list(prun_bitcfg)[i].tolist()
215 |             optimize_cfg.append(int(b))
216 |     else:
217 |         optimize_cfg = prun_bitcfg
218 |     # print(result.x)
219 |     print(optimize_cfg)
220 |     print("Quantization model is", np.sum(np.array(optimize_cfg) * np.array(params) / 8) + first_last_size, "Mb")
221 |     print("Original model is", np.sum(np.array(params)) / 8 * 32 + first_last_size / 8 * 32, "Mb")
222 |     print('Quantization model BOPs is',
223 |           (first_last_flops * 8 * 8 + sum([FLOPs[i] * optimize_cfg[i] * 5 for i in range(length)])) / 1e9)
224 | 
225 | 
226 | 
227 | 
228 | 
229 | 
230 | 
231 | 
232 | 
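    # Illustration (added sketch, not part of the original pipeline): the SLSQP
    # variables x[i] are a continuous relaxation of the per-layer bit-width,
    # recovered as bit[i] = round(8 * x[i]); the bounds above therefore map as
    # 0.25 -> 2 bit, 0.5 -> 4 bit and 1.0 -> 8 bit, e.g.
    #   np.around(np.array([0.25, 0.5, 1.0]) * 8)  ->  array([2., 4., 8.])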
--------------------------------------------------------------------------------
/mixed_bit/feature_extract_cdp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import math
5 | from scipy import optimize
6 | import random
7 | import ORM
8 | import torch, time
9 | import torch.nn as nn
10 | from run_manager import RunManager
11 | from models import ResNet_ImageNet, MobileNetV2, TrainRunConfig
12 | from pulp import *
13 | from utils.pytorch_utils import DFS_bit
14 | import pulp
15 | from cdp import calculate_cdp, feature_preprocess
16 | # get_threshold_by_flops, get_threshold_by_sparsity
17 | # TFIDFPruner,acculumate_feature
18 | import pdb
19 | import copy
20 | parser = argparse.ArgumentParser()
21 | 
22 | """ model config """
23 | parser.add_argument('--path', type=str)
24 | parser.add_argument('--model', type=str, default="resnet18",
25 |                     choices=['resnet50', 'mobilenetv2', 'mobilenet', 'resnet18'])
26 | parser.add_argument('--cfg', type=str, default="None")
27 | parser.add_argument('--manual_seed', default=0, type=int)
28 | parser.add_argument("--model_size", default=0, type=float)
29 | parser.add_argument("--beta", default=1, type=float)
30 | parser.add_argument('--quant_type', type=str, default='QAT',
31 |                     choices=['QAT', 'PTQ'])
32 | 
33 | """ dataset config """
34 | parser.add_argument('--dataset', type=str, default='imagenet',
35 |                     choices=['cifar10', 'imagenet'])
36 | parser.add_argument('--save_path', type=str, default='/home/data/imagenet')
37 | 
38 | """ runtime config """
39 | parser.add_argument('--gpu', help='gpu available', default='0')
40 | parser.add_argument('--train_batch_size', type=int, default=32)
41 | parser.add_argument('--n_worker', type=int, default=24)
42 | parser.add_argument("--local_rank", default=0, type=int)
43 | 
44 | """ cdp config """
45 | parser.add_argument('--coe', type=int, help='whether to use balance coefficient')
46 | 
47 | if __name__ == '__main__':
48 |     args = parser.parse_args()
49 | 
50 |     # cpu_num = 1
51 |     # os.environ['OMP_NUM_THREADS'] = str(cpu_num)
52 |     # os.environ['OPENBLAS_NUM_THREADS'] = str(cpu_num)
53 |     # os.environ['MKL_NUM_THREADS'] = str(cpu_num)
54 |     # os.environ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num)
55 |     # os.environ['NUMEXPR_NUM_THREADS'] = str(cpu_num)
56 |     # torch.set_num_threads(cpu_num)
57 | 
58 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
59 |     torch.cuda.set_device(0)
60 | 
61 | 
62 |     random.seed(args.manual_seed)
63 |     torch.manual_seed(args.manual_seed)
64 |     torch.cuda.manual_seed_all(args.manual_seed)
65 |     np.random.seed(args.manual_seed)
66 |     # distributed setting
67 |     torch.distributed.init_process_group(backend='nccl',
68 |                                          init_method='env://')
69 |     args.world_size = torch.distributed.get_world_size()
70 | 
71 |     # prepare run config
72 |     run_config_path = '%s/run.config' % args.path
73 | 
74 |     run_config = TrainRunConfig(
75 |         **args.__dict__
76 |     )
77 |     if args.local_rank == 0:
78 |         print('Run config:')
79 |         for k, v in args.__dict__.items():
80 |             print('\t%s: %s' % (k, v))
81 | 
82 |     if args.model == "resnet18":
83 |         assert args.dataset == 'imagenet', 'resnet18 only supports imagenet dataset'
84 |         net = ResNet_ImageNet(
85 |             depth=18, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
86 |     elif args.model == "resnet50":
87 |         assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset'
88 |         net = ResNet_ImageNet(
89 |             depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
90 |     elif args.model == "mobilenetv2":
91 |         assert args.dataset == 'imagenet', 'mobilenetv2 only supports imagenet dataset'
92 |         net = MobileNetV2(
93 |             num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
94 | 
95 | 
96 |     # build run manager
97 |     run_manager = RunManager(args.path, net, run_config)
98 | 
99 |     # load checkpoints
100 |     best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path
101 |     assert os.path.isfile(best_model_path), 'wrong path'
102 |     if torch.cuda.is_available():
103 |         checkpoint = torch.load(best_model_path)
104 |     else:
105 |         checkpoint = torch.load(best_model_path, map_location='cpu')
106 |     if 'state_dict' in checkpoint:
107 |         checkpoint = checkpoint['state_dict']
108 |     run_manager.net.load_state_dict(checkpoint)
109 |     output_dict = {}
110 | 
111 |     # feature extract
112 |     # start = time.time()
113 |     data_loader = run_manager.run_config.train_loader
114 |     data = next(iter(data_loader))
115 |     data = data[0]
116 |     n = data.size()[0]
117 |     print(net)
118 |     with torch.no_grad():
119 |         feature = net.feature_extract(data, args.quant_type)
120 | 
121 |     # for i in range(len(feature)):
122 |     #     feature[i] = feature[i].view(n, -1)
123 |     #     feature[i] = feature[i].data.cpu().numpy()
124 |     # pdb.set_trace()
125 |     feature_iit = feature_preprocess(feature)
126 |     tf_idf_map = calculate_cdp(feature_iit, args.coe)
127 |     # pdb.set_trace()
128 |     # threshold = get_threshold_by_sparsity(tf_idf_map,sparsity)
129 | 
130 |     # orthogonal_matrix = np.zeros((len(feature), len(feature)))
131 | 
132 |     # for i in range(len(feature)):
133 |     #     for j in range(len(feature)):
134 |     #         with torch.no_grad():
135 |     #             orthogonal_matrix[i][j] = ORM.orm(ORM.gram_linear(feature[i]), ORM.gram_linear(feature[j]))
136 | 
137 |     # def sum_list(a, j):
138 |     #     b = 0
139 |     #     for i in range(len(a)):
140 |     #         if i != j:
141 |     #             b += a[i]
142 |     #     return b
143 | 
144 |     theta = []
145 |     gamma = []
146 |     flops = []
147 | 
148 |     # for i in range(len(feature)):
149 |     #     gamma.append( sum_list(orthogonal_matrix[i], i) )
150 | 
151 |     # e^-x
152 |     # for i in range(len(feature)):
153 |     #     theta.append( 1 * math.exp(-1* args.beta *gamma[i]) )
154 |     # theta = np.array(theta)
155 |     # theta = np.negative(theta)
156 |     gamma = np.array(tf_idf_map)
157 | 
158 |     if args.quant_type == 'QAT':
159 |         # for i in range(len(gamma)):
160 |         #     theta.append( 1 * math.exp(-1* args.beta *gamma[i]) )
161 |         theta = copy.deepcopy(gamma)
162 |     elif args.quant_type == 'PTQ':
163 |         for i in range(len(gamma)):
164 |             theta.append(1 * math.exp(-1 * args.beta * gamma[i]))
165 |         # theta = copy.deepcopy(gamma)
166 |     theta = np.array(theta)
167 |     # for x in gamma:
168 |     #     print('%.4f ' % x)
169 | 
170 |     # theta = np.array(tf_idf_map)
171 |     for x in theta:
172 |         print('%.4f ' % x)
173 |     theta = np.negative(theta)
174 | 
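    # theta now holds one weight per layer: the raw tf-idf class-separability
    # score for QAT, or exp(-beta * gamma) for PTQ. Negating it turns the SLSQP
    # minimization below into pushing x[i] (and hence the bit-width) up for
    # layers with larger weights, subject to the model-size constraint.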
175 |     # theta = array([0.01322103, 0.01361509, 0.01278478, 0.01267193, 0.00762108, 0.00755147, 0.00914501, 0.00717208, 0.00505306, 0.00479445,
176 |     #                0.00673114, 0.00456354, 0.00451991, 0.00461229, 0.00489852, 0.00488713])
177 | 
178 |     length = len(feature)
179 |     # layerwise
180 |     params, first_last_size = net.cfg2params_perlayer(net.cfg, length, args.quant_type)
181 |     FLOPs, first_last_flops = net.cfg2flops_layerwise(net.cfg, length, args.quant_type)
182 |     params = [i / (1024 * 1024) for i in params]
183 |     first_last_size = first_last_size / (1024 * 1024)
184 | 
185 | 
186 |     # Objective function
187 |     # def func(x, sign=1.0):
188 |     #     """ Objective function """
189 |     #     global theta,length
190 |     #     sum_fuc =[]
191 |     #     for i in range(length):
192 |     #         temp = 0.
193 |     #         for j in range(i,length):
194 |     #             temp += theta[j]
195 |     #         sum_fuc.append( x[i] * (sign * temp / (length-i)) )
196 | 
197 |     #     return sum(sum_fuc)
198 |     # modified version
199 |     def func(x, sign=1.0):
200 |         """ Objective function """
201 |         global theta, length
202 |         sum_fuc = []
203 |         for i in range(length):
204 |             sum_fuc.append(x[i] * (sign * theta[i]))
205 |         return sum(sum_fuc)
206 | 
207 |     # Derivative function of objective function
208 |     # def func_deriv(x, sign=1.0):
209 |     #     """ Derivative of objective function """
210 |     #     global theta, length
211 |     #     diff = []
212 |     #     for i in range(length):
213 |     #         temp1 = 0.
214 |     #         for j in range(i, length):
215 |     #             temp1 += theta[j]
216 |     #         diff.append(sign * temp1 / (length - i))
217 | 
218 |     #     return np.array(diff)
219 |     # modified version
220 |     def func_deriv(x, sign=1.0):
221 |         """ Derivative of objective function """
222 |         global theta, length
223 |         diff = []
224 |         for i in range(length):
225 |             diff.append(sign * theta[i])
226 | 
227 |         return np.array(diff)
228 | 
229 | 
230 |     # Constraint function
231 |     def constrain_func(x):
232 |         """ constrain function """
233 |         global params, length
234 |         a = []
235 |         for i in range(length):
236 |             a.append(x[i] * params[i])
237 |         return np.array([args.model_size - first_last_size - sum(a)])
238 | 
239 |     bnds = []  # bit search space: (0.25, 0.5) for PTQ and (0.5, 1.0) for QAT; 0.25 means 2 bit, 0.5 means 4 bit, 1.0 means 8 bit
240 |     if args.quant_type == 'PTQ':  # for PTQ every layer is searched in 2~4 bit
241 |         for i in range(length):
242 |             bnds.append((0.25, 0.5))
243 |     else:
244 |         for i in range(length):  # for QAT the first and last layers are fixed to 8 bit; all other layers are searched in 4~8 bit
245 |             bnds.append((0.5, 1.0))
246 | 
247 |     bnds = tuple(bnds)
248 |     cons = ({'type': 'ineq',
249 |              'fun': constrain_func}
250 |             )
251 | 
252 |     result = optimize.minimize(func, x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons)
253 |     # pdb.set_trace()
254 | 
255 |     if args.model == "resnet18":  # resnet18 needs special handling
256 |         prun_bitcfg, _ = DFS_bit(result.x[::-1] * 8, [params[length - i - 1] for i in range(length)])
257 |         prun_bitcfg = [prun_bitcfg[length - i - 1] for i in range(length)]
258 |     else:
259 |         prun_bitcfg = np.around(result.x * 8)  # for other networks the bit config is obtained directly by rounding result.x * 8
260 |     # end = time.time()
261 |     # print("Use", end - start, "seconds. ")
262 | 
263 | 
264 |     optimize_cfg = []
265 |     if type(prun_bitcfg[0]) != int:
266 |         for i in range(len(prun_bitcfg)):
267 |             b = list(prun_bitcfg)[i].tolist()
268 |             optimize_cfg.append(int(b))
269 |     else:
270 |         optimize_cfg = prun_bitcfg
271 |     # print(result.x)
272 |     # pdb.set_trace()
273 |     print(optimize_cfg)
274 |     print("Quantization model is", np.sum(np.array(optimize_cfg) * np.array(params) / 8) + first_last_size, "Mb")
275 |     print("Original model is", np.sum(np.array(params)) / 8 * 32 + first_last_size / 8 * 32, "Mb")
276 |     print('Quantization model BOPs is',
277 |           (first_last_flops * 8 * 8 + sum([FLOPs[i] * optimize_cfg[i] * 5 for i in range(length)])) / 1e9)
278 | 
279 | 
280 | 
281 | 
282 | 
283 | 
284 | 
285 | 
286 | 
--------------------------------------------------------------------------------
/mixed_bit/models/__init__.py:
--------------------------------------------------------------------------------
1 | from run_manager import RunConfig
2 | 
3 | from .base_models import *
4 | from .resnet_imagenet import *
5 | from .mobilenet_imagenet import *
6 | 
7 | 
8 | class TrainRunConfig(RunConfig):
9 |     def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
10 |                  dataset='imagenet', train_batch_size=256, test_batch_size=500,
11 |                  opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
12 |                  model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
13 |                  n_worker=32, local_rank=0, world_size=1, sync_bn=True,
14 |                  warm_epoch=5, save_path=None, base_path=None, **kwargs):
15 |         super(TrainRunConfig, self).__init__(
16 |             n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
17 |             dataset, train_batch_size, test_batch_size,
18 |             opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
19 |             model_init, init_div_groups, validation_frequency, print_frequency, local_rank,
world_size, sync_bn, 20 | warm_epoch 21 | ) 22 | 23 | self.n_worker = n_worker 24 | self.save_path = save_path 25 | self.base_path = base_path 26 | 27 | print(kwargs.keys()) 28 | 29 | @property 30 | def data_config(self): 31 | return { 32 | 'train_batch_size': self.train_batch_size, 33 | 'test_batch_size': self.test_batch_size, 34 | 'n_worker': self.n_worker, 35 | 'local_rank': self.local_rank, 36 | 'world_size': self.world_size, 37 | 'save_path': self.save_path, 38 | } 39 | 40 | class SearchRunConfig(RunConfig): 41 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 42 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 43 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 44 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 45 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, save_path=None, 46 | search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8, **kwargs): 47 | super(SearchRunConfig, self).__init__( 48 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 49 | dataset, train_batch_size, test_batch_size, 50 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 51 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 52 | warm_epoch 53 | ) 54 | 55 | self.div = div 56 | self.n_worker = n_worker 57 | self.save_path = save_path 58 | self.search_epoch = search_epoch 59 | self.target_flops = target_flops 60 | self.n_remove = n_remove 61 | self.n_best = n_best 62 | self.n_populations = n_populations 63 | self.n_generations = n_generations 64 | 65 | print(kwargs.keys()) 66 | 67 | @property 68 | def data_config(self): 69 | return { 70 | 'train_batch_size': self.train_batch_size, 71 | 'test_batch_size': self.test_batch_size, 72 | 'n_worker': self.n_worker, 73 | 'local_rank': self.local_rank, 74 | 'world_size': self.world_size, 75 | 'save_path': self.save_path 76 | } -------------------------------------------------------------------------------- /mixed_bit/models/base_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from utils import count_parameters 5 | 6 | class MyNetwork(nn.Module): 7 | def forward(self, x): 8 | raise NotImplementedError 9 | 10 | def feature_extract(self, x): 11 | raise NotImplementedError 12 | 13 | @property 14 | def config(self): # should include name/cfg/cfg_base/dataset 15 | raise NotImplementedError 16 | 17 | def cfg2params(self, cfg): 18 | raise NotImplementedError 19 | 20 | def cfg2flops(self, cfg): 21 | raise NotImplementedError 22 | 23 | def set_bn_param(self, momentum, eps): 24 | for m in self.modules(): 25 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 26 | m.momentum = momentum 27 | m.eps = eps 28 | return 29 | 30 | def get_bn_param(self): 31 | for m in self.modules(): 32 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 33 | return { 34 | 'momentum': m.momentum, 35 | 'eps': m.eps, 36 | } 37 | return None 38 | 39 | def init_model(self, model_init, init_div_groups=False): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | if model_init == 'he_fout': 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | if init_div_groups: 45 | n /= m.groups 46 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 47 | elif model_init == 'he_fin': 48 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 49 | if init_div_groups: 50 | n /= m.groups 51 | m.weight.data.normal_(0, math.sqrt(2. / n)) 52 | elif model_init == 'xavier_normal': 53 | nn.init.xavier_normal_(m.weight.data) 54 | elif model_init == 'xavier_uniform': 55 | nn.init.xavier_uniform_(m.weight.data) 56 | else: 57 | raise NotImplementedError 58 | elif isinstance(m, nn.BatchNorm2d): 59 | m.weight.data.fill_(1) 60 | m.bias.data.zero_() 61 | elif isinstance(m, nn.Linear): 62 | stdv = 1. / math.sqrt(m.weight.size(1)) 63 | m.weight.data.uniform_(-stdv, stdv) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.BatchNorm1d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | def get_parameters(self, keys=None, mode='include'): 71 | if keys is None: 72 | for name, param in self.named_parameters(): 73 | yield param 74 | elif mode == 'include': 75 | for name, param in self.named_parameters(): 76 | flag = False 77 | for key in keys: 78 | if key in name: 79 | flag = True 80 | break 81 | if flag: 82 | yield param 83 | elif mode == 'exclude': 84 | for name, param in self.named_parameters(): 85 | flag = True 86 | for key in keys: 87 | if key in name: 88 | flag = False 89 | break 90 | if flag: 91 | yield param 92 | else: 93 | raise ValueError('do not support: %s' % mode) 94 | 95 | def weight_parameters(self): 96 | return self.get_parameters() 97 | -------------------------------------------------------------------------------- /mixed_bit/models/mobilenet_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.base_models import * 3 | from collections import OrderedDict 4 | import pdb 5 | 6 | def conv_bn(inp, oup, stride): 7 | return nn.Sequential(OrderedDict([('conv', nn.Conv2d(inp, oup, 3, stride, 1, bias=False)), 8 | ('bn', nn.BatchNorm2d(oup)), 9 | ('relu', nn.ReLU6(inplace=True))])) 10 | 11 | 12 | def conv_1x1_bn(inp, oup): 13 | return nn.Sequential(OrderedDict([('conv', nn.Conv2d(inp, oup, 1, 1, 0, bias=False)), 14 | ('bn', nn.BatchNorm2d(oup)), 15 | ('relu', nn.ReLU6(inplace=True))])) 16 | 17 | def conv_dw(inp, oup, stride): 18 | conv1 = nn.Sequential(OrderedDict([('conv', nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False)), 19 | ('bn', nn.BatchNorm2d(inp)), 20 | ('relu', nn.ReLU(inplace=True))])) 21 | conv2 = nn.Sequential(OrderedDict([('conv', nn.Conv2d(inp, oup, 1, 1, 0, bias=False)), 22 | ('bn', nn.BatchNorm2d(oup)), 23 | ('relu', nn.ReLU(inplace=True))])) 24 | return nn.Sequential(conv1, conv2) 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | assert stride in [1, 2] 31 | 32 | hidden_dim = round(inp * expand_ratio) 33 | self.use_res_connect = self.stride == 1 and inp == oup 34 | 35 | if expand_ratio == 1: 36 | dw = nn.Sequential(OrderedDict([('conv', nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False)), 37 | ('bn', nn.BatchNorm2d(hidden_dim)), 38 | ('relu', nn.ReLU6(inplace=True))])) 39 | pw = nn.Sequential(OrderedDict([('conv', nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)), 40 | ('bn', nn.BatchNorm2d(oup))])) 41 | self.conv = nn.Sequential(dw, pw) 42 | else: 43 | pw = nn.Sequential(OrderedDict([('conv', nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False)), 44 | ('bn', nn.BatchNorm2d(hidden_dim)), 45 | ('relu', nn.ReLU6(inplace=True))])) 46 | dw = 
nn.Sequential(OrderedDict([('conv', nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False)), 47 | ('bn', nn.BatchNorm2d(hidden_dim)), 48 | ('relu', nn.ReLU6(inplace=True))])) 49 | pwl = nn.Sequential(OrderedDict([('conv', nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)), 50 | ('bn', nn.BatchNorm2d(oup))])) 51 | self.conv = nn.Sequential(pw, dw, pwl) 52 | 53 | def forward(self, x): 54 | if self.use_res_connect: 55 | return x + self.conv(x) 56 | else: 57 | return self.conv(x) 58 | 59 | class MobileNetV2(MyNetwork): 60 | def __init__(self, cfg=None, num_classes=1000, dropout=0.2): 61 | super(MobileNetV2, self).__init__() 62 | block = InvertedResidual 63 | if cfg==None: 64 | cfg = [32, 16, 24, 32, 64, 96, 160, 320] 65 | input_channel = cfg[0] 66 | interverted_residual_setting = [ 67 | # t, c, n, s 68 | [1, cfg[1], 1, 1], 69 | [6, cfg[2], 2, 2], 70 | [6, cfg[3], 3, 2], 71 | [6, cfg[4], 4, 2], 72 | [6, cfg[5], 3, 1], 73 | [6, cfg[6], 3, 2], 74 | [6, cfg[7], 1, 1], 75 | ] 76 | 77 | # building first layer 78 | # input_channel = int(input_channel * width_mult) 79 | self.cfg = cfg 80 | self.cfgs_base = [32, 16, 24, 32, 64, 96, 160, 320] 81 | self.dropout = dropout 82 | self.last_channel = 1280 83 | self.num_classes = num_classes 84 | self.features = [conv_bn(3, input_channel, 2)] 85 | self.interverted_residual_setting = interverted_residual_setting 86 | # building inverted residual blocks 87 | for t, c, n, s in interverted_residual_setting: 88 | output_channel = c 89 | for i in range(n): 90 | if i == 0: 91 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 92 | else: 93 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 94 | input_channel = output_channel 95 | # building last several layers 96 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 97 | # make it nn.Sequential 98 | self.features = nn.Sequential(*self.features) 99 | 100 | # building classifier 101 | self.classifier = nn.Sequential( 102 | # nn.Dropout(self.dropout), 103 | nn.Linear(self.last_channel, num_classes), 104 | ) 105 | 106 | def cfg2params_perlayer(self, cfg, length, quant_type = 'PTQ'): 107 | params = [0. for j in range(length)] 108 | first_last_size = 0. 
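        # Per-layer parameter counts: k*k*in*out weights for a standard conv,
        # k*k*channels for a depthwise conv, plus 2*channels for the BatchNorm
        # scale/shift; the final fc layer counts weights plus biases.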
109 | count = 0 110 | params[count] += (3 * 3 * 3 * cfg[0] + 2 * cfg[0]) # first layer 111 | input_channel = cfg[0] 112 | count += 1 113 | for t, c, n, s in self.interverted_residual_setting: 114 | output_channel = c 115 | for i in range(n): 116 | hidden_dim = round(input_channel * t) 117 | if i == 0: 118 | # self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 119 | if t==1: 120 | params[count] += (3 * 3 * hidden_dim + 2 * hidden_dim) 121 | params[count+1] += (1 * 1 * hidden_dim * output_channel + 2 * output_channel) 122 | count += 2 123 | else: 124 | params[count] += (1 * 1 * input_channel * hidden_dim + 2 * hidden_dim) 125 | params[count+1] += (3 * 3 * hidden_dim + 2 * hidden_dim) 126 | params[count+2] += (1 * 1 * hidden_dim * output_channel + 2 * output_channel) 127 | count += 3 128 | else: 129 | # self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 130 | if t==1: 131 | params[count] += (3 * 3 * hidden_dim + 2 * hidden_dim) 132 | params[count+1] += (1 * 1 * hidden_dim * output_channel + 2 * output_channel) 133 | count += 2 134 | else: 135 | params[count] += (1 * 1 * input_channel * hidden_dim + 2 * hidden_dim) 136 | params[count+1] += (3 * 3 * hidden_dim + 2 * hidden_dim) 137 | params[count+2] += (1 * 1 * hidden_dim * output_channel + 2 * output_channel) 138 | count += 3 139 | input_channel = output_channel 140 | params[count] += (1 * 1 * input_channel * self.last_channel + 2 * self.last_channel) # final 1x1 conv 141 | count += 1 142 | params[count] += ((self.last_channel + 1) * self.num_classes) # fc layer 143 | return params, first_last_size 144 | 145 | def cfg2flops_layerwise(self, cfg, length, quant_type = 'PTQ'): # to simplify, only count convolution flops 146 | interverted_residual_setting = [ 147 | # t, c, n, s 148 | [1, cfg[1], 1, 1], 149 | [6, cfg[2], 2, 2], 150 | [6, cfg[3], 3, 2], 151 | [6, cfg[4], 4, 2], 152 | [6, cfg[5], 3, 1], 153 | [6, cfg[6], 3, 2], 154 | [6, cfg[7], 1, 1], 155 | ] 156 | size = 224 157 | flops = [0 for j in range(length)] 158 | count = 0 159 | first_last_flops = 0. 
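        # Per-layer FLOPs: one multiply-accumulate per conv weight per output
        # position (hence weight_count * size * size); the "+ 0 * c" terms make
        # explicit that BN and activation costs are ignored, as stated above.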
160 | size = size//2 161 | flops[count] += (3 * 3 * 3 * cfg[0] + 0 * cfg[0]) * size * size # first layer 162 | count += 1 163 | input_channel = cfg[0] 164 | for t, c, n, s in interverted_residual_setting: 165 | output_channel = c 166 | for i in range(n): 167 | hidden_dim = round(input_channel * t) 168 | if i == 0: 169 | if s==2: 170 | size = size//2 171 | # self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 172 | if t==1: 173 | flops[count] += (3 * 3 * hidden_dim + 0 * hidden_dim) * size * size 174 | flops[count+1] += (1 * 1 * hidden_dim * output_channel + 0 * output_channel) * size * size 175 | count += 2 176 | else: 177 | size = size * s 178 | flops[count] += (1 * 1 * input_channel * hidden_dim + 0 * hidden_dim) * size * size 179 | size = size // s 180 | flops[count+1] += (3 * 3 * hidden_dim + 0 * hidden_dim) * size * size 181 | flops[count+2] += (1 * 1 * hidden_dim * output_channel + 0 * output_channel) * size * size 182 | count += 3 183 | else: 184 | # self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 185 | if t==1: 186 | flops[count] += (3 * 3 * hidden_dim + 0 * hidden_dim) * size * size 187 | flops[count+1] += (1 * 1 * hidden_dim * output_channel + 0 * output_channel) * size * size 188 | count += 2 189 | else: 190 | flops[count] += (1 * 1 * input_channel * hidden_dim + 0 * hidden_dim) * size * size 191 | flops[count+1] += (3 * 3 * hidden_dim + 0 * hidden_dim) * size * size 192 | flops[count+2] += (1 * 1 * hidden_dim * output_channel + 0 * output_channel) * size * size 193 | count += 3 194 | input_channel = output_channel 195 | flops[count] += (1 * 1 * input_channel * self.last_channel + 0 * self.last_channel) * size * size # final 1x1 conv 196 | count += 1 197 | flops[count] += ((2 * self.last_channel - 1) * self.num_classes) # fc layer 198 | return flops, first_last_flops 199 | 200 | def forward(self, x): 201 | x = self.features(x) 202 | x = x.mean(3).mean(2) 203 | x = self.classifier(x) 204 | return x 205 | 206 | def feature_extract(self, x, quant_type = 'PTQ'): 207 | # layerwise 208 | tensor = [] 209 | for _layer in self.features: 210 | if type(_layer) is not InvertedResidual: 211 | x = _layer(x) 212 | tensor.append(x) 213 | elif len(_layer.conv) == 2: 214 | tensor.append(_layer.conv[0](x)) 215 | tensor.append(_layer.conv(x)) 216 | x = _layer(x) 217 | else: 218 | tensor.append(_layer.conv[0](x)) 219 | tensor.append(_layer.conv[1](_layer.conv[0](x))) 220 | tensor.append(_layer(x)) 221 | x = _layer(x) 222 | x = x.mean(3).mean(2) 223 | x = self.classifier(x) 224 | tensor.append(x) 225 | # pdb.set_trace() 226 | return tensor 227 | 228 | @property 229 | def config(self): 230 | return { 231 | 'name': self.__class__.__name__, 232 | 'cfg': self.cfg, 233 | 'cfg_base': self.cfgs_base, 234 | 'dataset': 'ImageNet', 235 | } 236 | -------------------------------------------------------------------------------- /mixed_bit/run_scripts/PTQ/quant_mobilenetv2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract_cdp.py \ 3 | --model "mobilenetv2" \ 4 | --path "Exp_base/mobilenetv2_base" \ 5 | --dataset "imagenet" \ 6 | --save_path '/home/data/imagenet/' \ 7 | --beta 3.3 \ 8 | --model_size 1.5 \ 9 | --quant_type "PTQ" -------------------------------------------------------------------------------- /mixed_bit/run_scripts/PTQ/quant_mobilenetv2_2.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract_cdp.py \ 3 | --model "mobilenetv2" \ 4 | --path "Exp_base/mobilenetv2_base" \ 5 | --dataset "imagenet" \ 6 | --save_path '/home/data/imagenet/' \ 7 | --beta 0.000000001 \ 8 | --model_size 0.9 \ 9 | --quant_type "PTQ" -------------------------------------------------------------------------------- /mixed_bit/run_scripts/PTQ/quant_resnet18.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract_cdp.py \ 3 | --model "resnet18" \ 4 | --path "Exp_base/resnet18_base" \ 5 | --dataset "imagenet" \ 6 | --save_path '/home/data/imagenet' \ 7 | --beta 1.0 \ 8 | --model_size 5.5 \ 9 | --quant_type "PTQ" 10 | -------------------------------------------------------------------------------- /mixed_bit/run_scripts/QAT/quant_resnet18.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract_cdp.py \ 3 | --model "resnet18" \ 4 | --path "Exp_base/resnet18_base" \ 5 | --dataset "imagenet" \ 6 | --save_path '/home/data/imagenet' \ 7 | --beta 10.0 \ 8 | --model_size 6.7 \ 9 | --quant_type "QAT" -------------------------------------------------------------------------------- /mixed_bit/run_scripts/QAT/quant_resnet50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract_cdp.py \ 3 | --model "resnet50" \ 4 | --path "Exp_base/resnet50_base" \ 5 | --dataset "imagenet" \ 6 | --save_path '/home/data/imagenet' \ 7 | --beta 3.3 \ 8 | --model_size 16.0 \ 9 | --quant_type "QAT" -------------------------------------------------------------------------------- /mixed_bit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_utils import * 2 | from .get_data_iter import * 3 | -------------------------------------------------------------------------------- /mixed_bit/utils/get_data_iter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import pickle 5 | import numpy as np 6 | import nvidia.dali.ops as ops 7 | import nvidia.dali.types as types 8 | from sklearn.utils import shuffle 9 | from nvidia.dali.pipeline import Pipeline 10 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIterator 11 | 12 | IMAGENET_IMAGES_NUM_TRAIN = 1281167 13 | IMAGENET_IMAGES_NUM_TEST = 50000 14 | CIFAR_IMAGES_NUM_TRAIN = 50000 15 | CIFAR_IMAGES_NUM_TEST = 10000 16 | 17 | 18 | class Cutout(object): 19 | def __init__(self, length): 20 | self.length = length 21 | 22 | def __call__(self, img): 23 | h, w = img.size(1), img.size(2) 24 | mask = np.ones((h, w), np.float32) 25 | y = np.random.randint(h) 26 | x = np.random.randint(w) 27 | 28 | y1 = np.clip(y - self.length // 2, 0, h) 29 | y2 = np.clip(y + self.length // 2, 0, h) 30 | x1 = np.clip(x - self.length // 2, 0, w) 31 | x2 = np.clip(x + self.length // 2, 0, w) 32 | 33 | mask[y1: y2, x1: x2] = 0. 
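# note: np.clip keeps the square inside the image, so the zeroed patch can be
# smaller than length x length when the random center falls near a border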
34 | mask = torch.from_numpy(mask)
35 | mask = mask.expand_as(img)
36 | img *= mask
37 | return img
38 |
39 |
40 | def cutout_func(img, length=16):
41 | h, w = img.size(1), img.size(2)
42 | mask = np.ones((h, w), np.float32)
43 | y = np.random.randint(h)
44 | x = np.random.randint(w)
45 |
46 | y1 = np.clip(y - length // 2, 0, h)
47 | y2 = np.clip(y + length // 2, 0, h)
48 | x1 = np.clip(x - length // 2, 0, w)
49 | x2 = np.clip(x + length // 2, 0, w)
50 |
51 | mask[y1: y2, x1: x2] = 0.
52 | mask = torch.from_numpy(mask)
53 | mask = mask.expand_as(img) # broadcast the (h, w) mask over channels; reshape(img.shape) would fail for multi-channel input
54 | img *= mask
55 | return img
56 |
57 |
58 | def cutout_batch(img, length=16):
59 | h, w = img.size(2), img.size(3)
60 | masks = []
61 | for i in range(img.size(0)):
62 | mask = np.ones((h, w), np.float32)
63 | y = np.random.randint(h)
64 | x = np.random.randint(w)
65 |
66 | y1 = np.clip(y - length // 2, 0, h)
67 | y2 = np.clip(y + length // 2, 0, h)
68 | x1 = np.clip(x - length // 2, 0, w)
69 | x2 = np.clip(x + length // 2, 0, w)
70 |
71 | mask[y1: y2, x1: x2] = 0.
72 | mask = torch.from_numpy(mask)
73 | mask = mask.expand_as(img[0]).unsqueeze(0)
74 | masks.append(mask)
75 | masks = torch.cat(masks).cuda()
76 | img *= masks
77 | return img
78 |
79 |
80 | class DALIDataloader(DALIGenericIterator):
81 | def __init__(self, pipeline, size, batch_size, output_map=["data", "label"], auto_reset=True, onehot_label=False, dataset='imagenet'):
82 | self._size_all = size
83 | self.batch_size = batch_size
84 | self.onehot_label = onehot_label
85 | self.output_map = output_map
86 | if dataset != 'cifar10':
87 | super().__init__(pipelines=pipeline, reader_name="Reader",
88 | fill_last_batch=False, output_map=output_map)
89 | else:
90 | super().__init__(pipelines=pipeline, size=size, auto_reset=auto_reset,
91 | output_map=output_map, fill_last_batch=True, last_batch_padded=False)
92 |
93 | def __next__(self):
94 | if self._first_batch is not None:
95 | batch = self._first_batch
96 | self._first_batch = None
97 | return [batch[0]['data'], batch[0]['label'].squeeze()]
98 | data = super().__next__()[0]
99 | if self.onehot_label:
100 | return [data[self.output_map[0]], data[self.output_map[1]].squeeze().long()]
101 | else:
102 | return [data[self.output_map[0]], data[self.output_map[1]]]
103 |
104 | def __len__(self):
105 | if self._size_all % self.batch_size == 0:
106 | return self._size_all//self.batch_size
107 | else:
108 | return self._size_all//self.batch_size+1
109 |
110 |
111 | class HybridTrainPipe(Pipeline):
112 | def __init__(self, batch_size, num_threads, device_id, data_dir, crop, manual_seed, dali_cpu=False, local_rank=0, world_size=1):
113 | super(HybridTrainPipe, self).__init__(batch_size,
114 | num_threads, device_id, seed=manual_seed)
115 | self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank,
116 | num_shards=world_size, random_shuffle=True, pad_last_batch=True)
117 | # let the user decide which pipeline (CPU or GPU decoding) works best for the network being run
118 | if dali_cpu:
119 | dali_device = "cpu"
120 | self.decode = ops.HostDecoderRandomCrop(device=dali_device, output_type=types.RGB,
121 | random_aspect_ratio=[
122 | 0.8, 1.25],
123 | random_area=[0.1, 1.0],
124 | num_attempts=100)
125 | else:
126 | dali_device = "gpu"
127 | # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
128 | # without additional reallocations
129 | self.decode = ops.ImageDecoderRandomCrop(device="mixed", output_type=types.RGB,
130 | device_memory_padding=211025920, host_memory_padding=140544512,
131 | random_aspect_ratio=[ 132 | 0.8, 1.25], 133 | random_area=[0.1, 1.0], 134 | num_attempts=100) 135 | self.res = ops.Resize(device=dali_device, resize_x=crop, 136 | resize_y=crop, interp_type=types.INTERP_TRIANGULAR) 137 | self.cmnp = ops.CropMirrorNormalize(device="gpu", 138 | dtype=types.FLOAT, 139 | output_layout=types.NCHW, 140 | crop=(crop, crop), 141 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 142 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) 143 | self.coin = ops.CoinFlip(probability=0.5) 144 | print('DALI "{0}" variant'.format(dali_device)) 145 | 146 | def define_graph(self): 147 | rng = self.coin() 148 | self.jpegs, self.labels = self.input(name="Reader") 149 | images = self.decode(self.jpegs) 150 | images = self.res(images) 151 | output = self.cmnp(images.gpu(), mirror=rng) 152 | return [output, self.labels] 153 | 154 | 155 | class HybridValPipe(Pipeline): 156 | def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, manual_seed, local_rank=0, world_size=1): 157 | super(HybridValPipe, self).__init__(batch_size, 158 | num_threads, device_id, seed=manual_seed) 159 | self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, 160 | random_shuffle=True) 161 | self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB) 162 | self.res = ops.Resize(device="gpu", resize_shorter=size, 163 | interp_type=types.INTERP_TRIANGULAR) 164 | self.cmnp = ops.CropMirrorNormalize(device="gpu", 165 | dtype=types.FLOAT, 166 | output_layout=types.NCHW, 167 | crop=(crop, crop), 168 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 169 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) 170 | 171 | def define_graph(self): 172 | self.jpegs, self.labels = self.input(name="Reader") 173 | images = self.decode(self.jpegs) 174 | images = self.res(images) 175 | output = self.cmnp(images) 176 | return [output, self.labels] 177 | 178 | 179 | def get_imagenet_iter(data_type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, manual_seed, val_size=256, world_size=1, 180 | local_rank=0): 181 | if data_type == 'train': 182 | pip_train = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank, manual_seed=manual_seed, 183 | data_dir=image_dir, crop=crop, world_size=world_size, local_rank=local_rank) 184 | pip_train.build() 185 | dali_iter_train = DALIDataloader( 186 | pipeline=pip_train, size=IMAGENET_IMAGES_NUM_TRAIN // world_size, batch_size=batch_size, onehot_label=True) 187 | return dali_iter_train 188 | elif data_type == 'val': 189 | pip_val = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank, manual_seed=manual_seed, 190 | data_dir=image_dir, crop=crop, size=val_size, world_size=world_size, local_rank=local_rank) 191 | pip_val.build() 192 | dali_iter_val = DALIDataloader( 193 | pipeline=pip_val, size=IMAGENET_IMAGES_NUM_TEST // world_size, batch_size=batch_size, onehot_label=True) 194 | return dali_iter_val 195 | 196 | 197 | class HybridTrainPipe_CIFAR(Pipeline): 198 | def __init__(self, batch_size, num_threads, device_id, data_dir, crop, manual_seed, dali_cpu=False, local_rank=0, world_size=1, 199 | cutout=0): 200 | super(HybridTrainPipe_CIFAR, self).__init__( 201 | batch_size, num_threads, device_id, seed=manual_seed) 202 | self.iterator = iter(CIFAR_INPUT_ITER( 203 | batch_size, 'train', root=data_dir)) 204 | dali_device = "gpu" 205 | self.input = ops.ExternalSource() 206 | self.input_label = ops.ExternalSource() 207 | self.pad = ops.Paste(device=dali_device, ratio=1.25, 
fill_value=0)
208 | self.uniform = ops.Uniform(range=(0., 1.))
209 | self.crop = ops.Crop(device=dali_device, crop_h=32, crop_w=32)
210 | self.cmnp = ops.CropMirrorNormalize(device="gpu",
211 | output_layout=types.NCHW,
212 | mean=[125.31, 122.95, 113.87],
213 | std=[63.0, 62.09, 66.70]
214 | )
215 | self.coin = ops.CoinFlip(probability=0.5)
216 | self.flip = ops.Flip(device="gpu")
217 |
218 | def iter_setup(self):
219 | (images, labels) = self.iterator.next()
220 | self.feed_input(self.jpegs, images)
221 | self.feed_input(self.labels, labels)
222 |
223 | def define_graph(self):
224 | rng = self.coin()
225 | self.jpegs = self.input(name="Reader")
226 | self.labels = self.input_label()
227 | output = self.jpegs
228 | output = self.pad(output.gpu())
229 | output = self.crop(output, crop_pos_x=self.uniform(),
230 | crop_pos_y=self.uniform())
231 | output = self.flip(output, horizontal=rng)
232 | output = self.cmnp(output)
233 | return [output, self.labels]
234 |
235 |
236 | class HybridValPipe_CIFAR(Pipeline):
237 | def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, manual_seed, local_rank=0, world_size=1):
238 | super(HybridValPipe_CIFAR, self).__init__(
239 | batch_size, num_threads, device_id, seed=manual_seed)
240 | self.iterator = iter(CIFAR_INPUT_ITER(
241 | batch_size, 'val', root=data_dir))
242 | self.input = ops.ExternalSource()
243 | self.input_label = ops.ExternalSource()
244 | self.pad = ops.Paste(device="gpu", ratio=1., fill_value=0)
245 | self.uniform = ops.Uniform(range=(0., 1.))
246 | self.crop = ops.Crop(device="gpu", crop_h=32, crop_w=32)
247 | self.coin = ops.CoinFlip(probability=0.5)
248 | self.flip = ops.Flip(device="gpu")
249 | self.cmnp = ops.CropMirrorNormalize(device="gpu",
250 | output_layout=types.NCHW,
251 | mean=[125.31, 122.95, 113.87],
252 | std=[63.0, 62.09, 66.70]
253 | )
254 |
255 | def iter_setup(self):
256 | (images, labels) = self.iterator.next()
257 | self.feed_input(self.jpegs, images) # images must be fed in HWC order
258 | self.feed_input(self.labels, labels)
259 |
260 | def define_graph(self):
261 | self.jpegs = self.input(name="Reader")
262 | self.labels = self.input_label()
263 | # rng = self.coin()
264 | output = self.jpegs
265 | output = self.pad(output.gpu())
266 | output = self.cmnp(output) # already on the GPU after pad
267 | return [output, self.labels]
268 |
269 |
270 | class CIFAR_INPUT_ITER():
271 | base_folder = 'cifar-10-batches-py'
272 | train_list = [
273 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
274 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
275 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
276 | ['data_batch_4', '634d18415352ddfa80567beed471001a'],
277 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
278 | ]
279 |
280 | test_list = [
281 | ['test_batch', '40351d587109b95175f43aff81a1287e'],
282 | ]
283 |
284 | def __init__(self, batch_size, data_type='train', root='/userhome/data/cifar10'):
285 | self.root = root
286 | self.batch_size = batch_size
287 | self.train = (data_type == 'train')
288 | if self.train:
289 | downloaded_list = self.train_list
290 | else:
291 | downloaded_list = self.test_list
292 |
293 | self.data = []
294 | self.targets = []
295 | for file_name, checksum in downloaded_list:
296 | file_path = os.path.join(self.root, self.base_folder, file_name)
297 | with open(file_path, 'rb') as f:
298 | if sys.version_info[0] == 2:
299 | entry = pickle.load(f)
300 | else:
301 | entry = pickle.load(f, encoding='latin1')
302 | self.data.append(entry['data'])
303 | if 'labels' in entry:
304 |
self.targets.extend(entry['labels']) 305 | else: 306 | self.targets.extend(entry['fine_labels']) 307 | 308 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 309 | self.targets = np.vstack(self.targets) 310 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 311 | np.save("cifar.npy", self.data) 312 | self.data = np.load('cifar.npy') # to serialize, increase locality 313 | 314 | def __iter__(self): 315 | self.i = 0 316 | self.n = len(self.data) 317 | return self 318 | 319 | def __next__(self): 320 | batch = [] 321 | labels = [] 322 | for _ in range(self.batch_size): 323 | if self.train and self.i % self.n == 0: 324 | self.data, self.targets = shuffle( 325 | self.data, self.targets, random_state=0) 326 | img, label = self.data[self.i], self.targets[self.i] 327 | batch.append(img) 328 | labels.append(label) 329 | self.i = (self.i + 1) % self.n 330 | return (batch, labels) 331 | 332 | next = __next__ 333 | 334 | 335 | def get_cifar_iter(data_type, image_dir, batch_size, num_threads, manual_seed, local_rank=0, world_size=1, val_size=32, cutout=0): 336 | if data_type == 'train': 337 | pip_train = HybridTrainPipe_CIFAR(batch_size=batch_size, num_threads=num_threads, device_id=local_rank, 338 | data_dir=image_dir, 339 | crop=32, world_size=world_size, local_rank=local_rank, cutout=cutout, manual_seed=manual_seed) 340 | pip_train.build() 341 | dali_iter_train = DALIDataloader(pipeline=pip_train, size=CIFAR_IMAGES_NUM_TRAIN // world_size, batch_size=batch_size, onehot_label=True, dataset='cifar10') 342 | return dali_iter_train 343 | 344 | elif data_type == 'val': 345 | pip_val = HybridValPipe_CIFAR(batch_size=batch_size, num_threads=num_threads, device_id=local_rank, 346 | data_dir=image_dir, 347 | crop=32, size=val_size, world_size=world_size, local_rank=local_rank, manual_seed=manual_seed) 348 | pip_val.build() 349 | dali_iter_val = DALIDataloader(pipeline=pip_val, size=CIFAR_IMAGES_NUM_TEST // world_size, batch_size=batch_size, onehot_label=True, dataset='cifar10') 350 | return dali_iter_val 351 | -------------------------------------------------------------------------------- /mixed_bit/utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch,math 2 | import numpy as np 3 | import torch.nn as nn 4 | import sys 5 | 6 | def _make_divisible(v, divisor=8, min_value=None): 7 | """ 8 | This function is taken from the original tf repo. 9 | It ensures that all layers have a channel number that is divisible by 8 10 | It can be seen here: 11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 12 | :param v: 13 | :param divisor: 14 | :param min_value: 15 | :return: 16 | """ 17 | if min_value is None: 18 | min_value = divisor 19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 20 | # Make sure that round down does not go down by more than 10%. 
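# e.g. _make_divisible(19) first rounds to 16; 16 < 0.9 * 19 = 17.1, so the
# check below bumps it to 24, while _make_divisible(30) returns 32 directly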
21 | if new_v < 0.9 * v:
22 | new_v += divisor
23 | return new_v
24 |
25 | def build_activation(act_func, inplace=True):
26 | if act_func == 'relu':
27 | return nn.ReLU(inplace=inplace)
28 | elif act_func == 'relu6':
29 | return nn.ReLU6(inplace=inplace)
30 | elif act_func == 'tanh':
31 | return nn.Tanh()
32 | elif act_func == 'sigmoid':
33 | return nn.Sigmoid()
34 | elif act_func is None:
35 | return None
36 | else:
37 | raise ValueError('unsupported act_func: %s' % act_func)
38 |
39 | def raw2cfg(model, raw_ratios, flops, p=False, div=8): # bisect a global scale on the layer ratios until the cfg hits the FLOPs target
40 | left = 0
41 | right = 50
42 | scale = 0
43 | cfg = None
44 | current_flops = 0
45 | base_channels = model.config['cfg_base']
46 | cnt = 0
47 | while True:
48 | cnt += 1
49 | scale = (left + right) / 2
50 | scaled_ratios = raw_ratios * scale
51 | for i in range(len(scaled_ratios)):
52 | scaled_ratios[i] = max(0.1, scaled_ratios[i])
53 | scaled_ratios[i] = min(1, scaled_ratios[i])
54 | cfg = (base_channels * scaled_ratios).astype(int).tolist()
55 | for i in range(len(cfg)):
56 | cfg[i] = _make_divisible(cfg[i], div) # keep channels divisible by div
57 | current_flops = model.cfg2flops(cfg)
58 | if cnt > 20:
59 | break
60 | if abs(current_flops - flops) / flops < 0.01:
61 | break
62 | if p:
63 | print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg))
64 | if current_flops < flops:
65 | left = scale
66 | elif current_flops > flops:
67 | right = scale
68 | else:
69 | break
70 | return cfg
71 |
72 | def weight2mask(weight, keep_c): # simple L1 pruning
73 | weight_copy = weight.abs().clone()
74 | L1_norm = torch.sum(weight_copy, dim=(1,2,3))
75 | arg_max = torch.argsort(L1_norm, descending=True)
76 | arg_max_rev = arg_max[:keep_c].tolist()
77 | mask = np.zeros(weight.shape[0])
78 | mask[arg_max_rev] = 1
79 | return mask
80 |
81 | def get_unpruned_weights(model, model_origin):
82 | masks = []
83 | for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()):
84 | if isinstance(m0[1], nn.Conv2d):
85 | if m0[1].weight.data.shape!=m1[1].weight.data.shape:
86 | flag = False
87 | if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]:
88 | assert len(masks)>0, "masks is empty!"
89 | if m0[0].endswith('downsample.conv'):
90 | if model.config['depth']>=50:
91 | mask = masks[-4]
92 | else:
93 | mask = masks[-3]
94 | else:
95 | mask = masks[-1]
96 | idx = np.squeeze(np.argwhere(mask))
97 | if idx.size == 1:
98 | idx = np.resize(idx, (1,))
99 | w = m0[1].weight.data[:, idx.tolist(), :, :].clone()
100 | flag = True
101 | if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]:
102 | masks.append(None)
103 | if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]:
104 | if m0[0].endswith('downsample.conv'):
105 | mask = masks[-1]
106 | else:
107 | if flag:
108 | mask = weight2mask(w.clone(), m1[1].weight.data.shape[0])
109 | else:
110 | mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0])
111 | idx = np.squeeze(np.argwhere(mask))
112 | if idx.size == 1:
113 | idx = np.resize(idx, (1,))
114 | if flag:
115 | w = w[idx.tolist(), :, :, :].clone()
116 | else:
117 | w = m0[1].weight.data[idx.tolist(), :, :, :].clone()
118 | m1[1].weight.data = w.clone()
119 | masks.append(mask)
120 | continue
121 | else:
122 | m1[1].weight.data = m0[1].weight.data.clone()
123 | masks.append(None)
124 | elif isinstance(m0[1], nn.BatchNorm2d):
125 | assert isinstance(m1[1], nn.BatchNorm2d), "Expected a matching BatchNorm2d layer here."
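# when shapes differ, reuse the channel mask chosen for the preceding conv so this
# BN's affine parameters and running statistics are pruned consistently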
126 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 127 | mask = masks[-1] 128 | idx = np.squeeze(np.argwhere(mask)) 129 | if idx.size == 1: 130 | idx = np.resize(idx, (1,)) 131 | m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone() 132 | m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone() 133 | m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone() 134 | m1[1].running_var = m0[1].running_var[idx.tolist()].clone() 135 | continue 136 | m1[1].weight.data = m0[1].weight.data.clone() 137 | m1[1].bias.data = m0[1].bias.data.clone() 138 | m1[1].running_mean = m0[1].running_mean.clone() 139 | m1[1].running_var = m0[1].running_var.clone() 140 | 141 | # noinspection PyUnresolvedReferences 142 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1): 143 | logsoftmax = nn.LogSoftmax(dim=1) 144 | n_classes = pred.size(1) 145 | # convert to one-hot 146 | target = torch.unsqueeze(target, 1) 147 | soft_target = torch.zeros_like(pred) 148 | soft_target.scatter_(1, target, 1) 149 | # label smoothing 150 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 151 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) 152 | 153 | 154 | def count_parameters(model): 155 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 156 | return total_params 157 | 158 | 159 | def detach_variable(inputs): 160 | if isinstance(inputs, tuple): 161 | return tuple([detach_variable(x) for x in inputs]) 162 | else: 163 | x = inputs.detach() 164 | x.requires_grad = inputs.requires_grad 165 | return x 166 | 167 | 168 | def accuracy(output, target, topk=(1,)): 169 | """ Computes the precision@k for the specified values of k """ 170 | maxk = max(topk) 171 | batch_size = target.size(0) 172 | 173 | _, pred = output.topk(maxk, 1, True, True) 174 | pred = pred.t() 175 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 176 | 177 | res = [] 178 | for k in topk: 179 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 180 | res.append(correct_k.mul_(100.0 / batch_size)) 181 | return res 182 | 183 | 184 | class AverageMeter(object): 185 | """ 186 | Computes and stores the average and current value 187 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 188 | """ 189 | 190 | def __init__(self): 191 | self.val = 0 192 | self.avg = 0 193 | self.sum = 0 194 | self.count = 0 195 | 196 | def reset(self): 197 | self.val = 0 198 | self.avg = 0 199 | self.sum = 0 200 | self.count = 0 201 | 202 | def update(self, val, n=1): 203 | self.val = val 204 | self.sum += val * n 205 | self.count += n 206 | self.avg = self.sum / self.count 207 | 208 | 209 | 210 | def DFS_bit(value, weight): 211 | # value = [1.2, 3.8, 4.3, 7.9] 212 | # weight = [100, 10, 50, 30] 213 | thresh = sum(a * b for a, b in zip(value, weight)) 214 | best = thresh 215 | ans = None 216 | 217 | def dfs(index, way, cur_value, cur_queue): 218 | nonlocal best, thresh, ans 219 | if index > len(value) - 1: 220 | return 221 | if way == "ceil": 222 | v = math.ceil(value[index]) 223 | # elif way == "ceil+1": 224 | # v = math.ceil(value[index])+1 225 | elif way == "floor": 226 | v = math.floor(value[index]) 227 | # elif way == "floor-1": 228 | # v = math.floor(value[index])-1 229 | 230 | cur_value += v * weight[index] 231 | if cur_value > thresh: 232 | return 233 | 234 | cur_queue.append(v) 235 | if index == len(value) - 1: 236 | if abs(cur_value - thresh) < best: 237 | # print("find a solution:") 238 | # print(cur_queue, abs(cur_value - 
thresh)) 239 | ans = cur_queue.copy() 240 | best = abs(cur_value - thresh) 241 | 242 | # dfs(index + 1, "ceil+1", cur_value, cur_queue) 243 | dfs(index + 1, "ceil", cur_value, cur_queue) 244 | dfs(index + 1, "floor", cur_value, cur_queue) 245 | # dfs(index + 1, "floor-1", cur_value, cur_queue) 246 | cur_queue.pop() 247 | 248 | # print(f"the thresh of our problem is {thresh}") 249 | # dfs(0, "ceil+1", 0, []) 250 | dfs(0, "ceil", 0, []) 251 | dfs(0, "floor", 0, []) 252 | # dfs(0, "floor-1", 0, []) 253 | # print("-"*100) 254 | # print("the result is:") 255 | # print(ans) 256 | # print(best) 257 | return ans, best -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | Pillow==8.0.1 3 | tensorboardX==2.1 4 | scipy==1.5.3 5 | scikit-learn==0.23.2 6 | torchvision 7 | pytorchcv 8 | torchprofile 9 | PuLP --------------------------------------------------------------------------------