├── .gitattributes ├── .gitignore ├── README.md ├── __init__.py ├── figures ├── inceptionV3_fig1.png ├── inceptionV3_fig2.png ├── opt_alpha_gaussian.png ├── opt_alpha_laplace.png ├── resnet101_fig1.png ├── resnet101_fig2.png ├── resnet18_fig1.png ├── resnet18_fig2.png ├── resnet50_fig1.png ├── resnet50_fig2.png ├── vgg16_bn_fig1.png ├── vgg16_bn_fig2.png ├── vgg16_fig1.png └── vgg16_fig2.png ├── inference-sim.py ├── kernels ├── build_all.sh ├── build_int_quantization.py ├── gemmlowp.cu └── int_quantization.cpp ├── mse_analysis.py ├── optimal_alpha.ipynb ├── pytorch_quantizer ├── __init__.py └── quantization │ ├── __init__.py │ ├── inference │ ├── __init__.py │ ├── inference_quantization_manager.py │ └── statistic_manager.py │ ├── qtypes │ ├── __init__.py │ └── int_quantizer.py │ └── quantization_manager.py └── utils ├── __init__.py ├── absorb_bn.py ├── attacher.py ├── dataset.py ├── dump_manager.py ├── log.py ├── meters.py ├── misc.py ├── monitor.py ├── optim.py └── preprocess.py /.gitattributes: -------------------------------------------------------------------------------- 1 | jupyter/* linguist-detectable=false 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.pyc 3 | __pycache__/ 4 | .pytest_cache 5 | *.tar 6 | venv/ 7 | venv3/ 8 | .env/ 9 | .idea/ 10 | logs/ 11 | results/ 12 | .ipynb_checkpoints/ 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACIQ: ANALYTICAL CLIPPING FOR INTEGER QUANTIZATION OF NEURAL NETWORKS 2 | This is a complete example of applying Laplace and Gaussian clipping to the activations of a CNN. 3 | [ACIQ](https://openreview.net/pdf?id=B1x33sC9KQ) 4 | 5 | ## Dependencies 6 | - python3.x 7 | - [pytorch]() 8 | - [torchvision]() to load the datasets and perform image transforms 9 | - [pandas]() for logging to csv 10 | - [bokeh]() for training visualization 11 | 12 | ## Data 13 | - To run this code you need the validation set from the ILSVRC2012 dataset 14 | - Configure your dataset path by providing --data "PATH_TO_ILSVRC" or copy the ILSVRC dir to ~/datasets/ILSVRC2012. 15 | - To get the ILSVRC2012 data, you should register on their site for access: 16 | 17 | ## Building CUDA kernels for GEMMLOWP 18 | To improve performance, GEMMLOWP quantization is implemented in CUDA, which requires compiling the kernels. 19 | 20 | - Create a virtual environment for python3 and activate it: 21 | ``` 22 | virtualenv --system-site-packages -p python3 venv3 23 | . ./venv3/bin/activate 24 | ``` 25 | - Build the kernels: 26 | ``` 27 | cd kernels 28 | ./build_all.sh 29 | ``` 30 | 31 | ## Prepare setup for inference 32 | Low-precision inference requires finding the scale of the low-precision tensors ahead of time. To calculate the scale, we need to collect activation statistics for the specific topology and dataset. 33 | ### Collect statistics 34 | ``` 35 | python inference-sim.py -a resnet18 -b 512 --qtype int8 -sm collect 36 | ``` 37 | Statistics will be saved under the ~/asiq_data/statistics folder.
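The per-layer summary that the statistics manager writes can be inspected directly with pandas. A minimal sketch (assumes resnet18 statistics were already collected with `-sm collect`; the file layout and column names follow `statistic_manager.py`):
```python
import os
from pathlib import Path
import pandas as pd

arch = 'resnet18'  # same value as passed to -a
summary_csv = os.path.join(str(Path.home()), 'asiq_data/statistics', arch, '%s_summary.csv' % arch)
stats = pd.read_csv(summary_csv, index_col=0)
# Per-layer aggregates, e.g. mean min/max, std and the Laplace scale 'b',
# which are later used to derive the clipping values.
print(stats[['mean_min', 'mean_max', 'mean_std', 'mean_b']].head())
```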
38 | ### Run inference experiment 39 | The following command line evaluates resnet18 with 4-bit activations and Laplace clipping: 40 | ``` 41 | python inference-sim.py -a resnet18 -b 512 --qtype int4 -sm use -th laplace 42 | ``` 43 | `* Prec@1 65.728 Prec@5 86.706` 44 | 45 | To evaluate the non-clipped version, just omit -th or set "-th no": 46 | ``` 47 | python inference-sim.py -a resnet18 -b 512 --qtype int4 -sm use -th no 48 | ``` 49 | `* Prec@1 53.206 Prec@5 76.860` 50 | 51 | Laplace clipping improves top-1 accuracy by 12.5% without retraining. 52 | 53 | ## Solution for optimal clipping 54 | 55 | To the best of our knowledge, the differentiable equations presented in the paper do not have an analytical solution. We solve them numerically using the scipy library and find the optimal alpha values for the Gaussian and Laplace cases. 56 | We show a linear dependency between the optimal alpha and sigma in the Gaussian case, and between the optimal alpha and b in the Laplace case. 57 | 58 | [optimal_alpha.ipynb](optimal_alpha.ipynb) 59 | 60 | Gaussian case, linear dependency 61 | ![Gaussian case](figures/opt_alpha_gaussian.png) 62 | 63 | ## Quantization with optimal clipping 64 | To quantize a tensor to M bits with optimal clipping, we use GEMMLOWP quantization with a small modification: the dynamic range in the scale computation is replaced by 2*alpha, where alpha is the optimal clipping value. A minimal sketch of this idea is shown below. 65 | 66 | Quantization code can be found here: 67 | [int_quantizer.py](pytorch_quantizer/quantization/qtypes/int_quantizer.py)
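A minimal NumPy sketch of the modified scale computation (illustrative only; `quantize_clipped` is our name for it, and the real implementation is the CUDA kernel in `kernels/gemmlowp.cu` driven by `int_quantizer.py`):
```python
import numpy as np

def quantize_clipped(x, alpha, num_bits):
    qmax = 2 ** num_bits - 1
    scale = (2 * alpha) / qmax        # dynamic range replaced by 2*alpha
    offset = x.mean() - alpha         # clip window centered on the mean
    # (the actual code additionally never extends the window below the
    # observed min; see alpha2DeltaOffset in int_quantizer.py)
    q = np.clip(np.round((x - offset) / scale), 0, qmax)
    return q * scale + offset         # dequantized, clipped tensor
```
For the Laplace case, alpha is simply the tabulated per-bit-width coefficient multiplied by the measured b of the layer (see `get_alpha_laplace` in `int_quantizer.py`).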
-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/__init__.py -------------------------------------------------------------------------------- /figures/inceptionV3_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/inceptionV3_fig1.png -------------------------------------------------------------------------------- /figures/inceptionV3_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/inceptionV3_fig2.png -------------------------------------------------------------------------------- /figures/opt_alpha_gaussian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/opt_alpha_gaussian.png -------------------------------------------------------------------------------- /figures/opt_alpha_laplace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/opt_alpha_laplace.png -------------------------------------------------------------------------------- /figures/resnet101_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/resnet101_fig1.png -------------------------------------------------------------------------------- /figures/resnet101_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/resnet101_fig2.png -------------------------------------------------------------------------------- /figures/resnet18_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/resnet18_fig1.png -------------------------------------------------------------------------------- /figures/resnet18_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/resnet18_fig2.png -------------------------------------------------------------------------------- /figures/resnet50_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/resnet50_fig1.png -------------------------------------------------------------------------------- /figures/resnet50_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/resnet50_fig2.png -------------------------------------------------------------------------------- /figures/vgg16_bn_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/vgg16_bn_fig1.png -------------------------------------------------------------------------------- /figures/vgg16_bn_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/vgg16_bn_fig2.png -------------------------------------------------------------------------------- /figures/vgg16_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/vgg16_fig1.png -------------------------------------------------------------------------------- /figures/vgg16_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/figures/vgg16_fig2.png -------------------------------------------------------------------------------- /inference-sim.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import random 6 | import shutil 7 | import time 8 | import collections 9 | import warnings 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 |
import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from utils.meters import AverageMeter, accuracy 22 | from pytorch_quantizer.quantization.inference.inference_quantization_manager import QuantizationManagerInference as QM 23 | from utils.log import EvalLog 24 | from utils.absorb_bn import search_absorbe_bn 25 | from utils.dump_manager import DumpManager as DM 26 | from pathlib import Path 27 | 28 | 29 | torch.backends.cudnn.deterministic = True 30 | 31 | home = str(Path.home()) 32 | IMAGENET_FOR_INFERENCE = os.path.join(home, 'datasets/ILSVRC2012/') 33 | 34 | model_names = sorted(name for name in models.__dict__ 35 | if name.islower() and not name.startswith("__") 36 | and callable(models.__dict__[name])) 37 | model_names.append('shufflenet') 38 | 39 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 40 | parser.add_argument('--data', metavar='DIR', default=IMAGENET_FOR_INFERENCE, 41 | help='path to dataset') 42 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 43 | choices=model_names, 44 | help='model architecture: ' + 45 | ' | '.join(model_names) + 46 | ' (default: resnet18)') 47 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 48 | help='number of data loading workers (default: 4)') 49 | parser.add_argument('-b', '--batch-size', default=256, type=int, 50 | metavar='N', help='mini-batch size (default: 256)') 51 | parser.add_argument('--print-freq', '-p', default=10, type=int, 52 | metavar='N', help='print frequency (default: 10)') 53 | parser.add_argument('--seed', default=None, type=int, 54 | help='seed for initializing training. 
') 55 | parser.add_argument('--device', default='cuda', 56 | help='device assignment ("cpu" or "cuda")') 57 | parser.add_argument('--device_ids', default=[0], type=int, nargs='+', 58 | help='device ids assignment (e.g 0 1 2 3') 59 | 60 | parser.add_argument('--qtype', default=None, help='data type: bfloat[N], int[N]') 61 | parser.add_argument('--stochastic', '-s', action='store_true', help='Stochastic rounding.', default=False) 62 | parser.add_argument('--hw_scale', '-hs', action='store_true', help='Force scale to be HW compatible', default=False) 63 | parser.add_argument('--preserve_zero', '-pz', action='store_true', help='Preserve zero during quantization', default=False) 64 | parser.add_argument('--eval_precision', '-ep', action='store_true', default=False, help='Evaluate different precisions, to csv.') 65 | parser.add_argument('--threshold', '-th', default='no', help='Threshold for integer quantization: [no, gaus, exp, laplace]') 66 | parser.add_argument('--stats_mode', '-sm', default='no', help='Specify if collect stats, use or not stats: [collect, use, no]') 67 | parser.add_argument('--stats_folder', '-sf', default=None, help='Specify directory of for statistics') 68 | parser.add_argument('--custom_test', '-ct', action='store_true', default=False, help='Perform some custom test.') 69 | parser.add_argument('--dump_dir', '-dd', default=None, help='Directory to dump tensors') 70 | args = parser.parse_args() 71 | 72 | if args.arch == 'resnet50': 73 | max_mse_order_id = ['linear0_activation', 'conv52_activation', 'conv49_activation', 'conv46_activation', 'conv43_activation', 'conv2_activation', 'conv25_activation', 'conv5_activation', 'conv1_activation', 'conv3_activation', 'conv9_activation', 'conv50_activation', 'conv12_activation', 'conv6_activation', 'conv13_activation', 'conv51_activation', 'conv44_activation', 'conv48_activation', 'conv22_activation', 'conv8_activation', 'conv41_activation', 'conv29_activation', 'conv26_activation', 'conv19_activation', 'conv47_activation', 'conv40_activation', 'conv32_activation', 'conv45_activation', 'conv38_activation', 'conv18_activation', 'conv35_activation', 'conv37_activation', 'conv21_activation', 'conv16_activation', 'conv34_activation', 'conv28_activation', 'conv4_activation', 'conv31_activation', 'conv11_activation', 'conv27_activation', 'conv15_activation', 'conv14_activation', 'conv42_activation', 'conv17_activation', 'conv20_activation', 'conv10_activation', 'conv24_activation', 'conv23_activation', 'conv30_activation', 'conv39_activation', 'conv7_activation', 'conv36_activation', 'conv33_activation'] 74 | if args.arch == 'resnet18': 75 | max_mse_order_id = ['linear0_activation', 'conv19_activation', 'conv4_activation', 'conv17_activation', 'conv1_activation', 'conv2_activation', 'conv3_activation', 'conv7_activation', 'conv12_activation', 'conv8_activation', 'conv6_activation', 'conv9_activation', 'conv11_activation', 'conv14_activation', 'conv13_activation', 'conv18_activation', 'conv16_activation', 'conv15_activation', 'conv5_activation', 'conv10_activation'] 76 | elif args.arch == 'vgg16': 77 | max_mse_order_id = ['conv7_activation', 'conv8_activation', 'conv6_activation', 'conv5_activation', 'conv9_activation', 'conv4_activation', 'conv10_activation', 'conv11_activation', 'conv3_activation', 'conv12_activation', 'linear0_activation', 'conv2_activation', 'linear2_activation', 'linear1_activation', 'conv1_activation'] 78 | elif args.arch == 'vgg16_bn': 79 | max_mse_order_id = ['linear2_activation', 'linear0_activation', 
'linear1_activation', 'conv12_activation', 'conv1_activation', 'conv3_activation', 'conv2_activation', 'conv10_activation', 'conv11_activation', 'conv6_activation', 'conv4_activation', 'conv8_activation', 'conv5_activation', 'conv7_activation', 'conv9_activation'] 80 | elif args.arch == 'resnet101': 81 | max_mse_order_id = ['linear0_activation', 'conv103_activation', 'conv100_activation', 'conv97_activation', 'conv94_activation', 'conv2_activation', 'conv3_activation', 'conv25_activation', 'conv1_activation', 'conv102_activation', 'conv13_activation', 'conv95_activation', 'conv9_activation', 'conv99_activation', 'conv101_activation', 'conv22_activation', 'conv8_activation', 'conv26_activation', 'conv98_activation', 'conv12_activation', 'conv96_activation', 'conv19_activation', 'conv91_activation', 'conv21_activation', 'conv92_activation', 'conv88_activation', 'conv18_activation', 'conv85_activation', 'conv82_activation', 'conv86_activation', 'conv56_activation', 'conv59_activation', 'conv89_activation', 'conv67_activation', 'conv4_activation', 'conv27_activation', 'conv83_activation', 'conv14_activation', 'conv5_activation', 'conv11_activation', 'conv53_activation', 'conv16_activation', 'conv6_activation', 'conv62_activation', 'conv64_activation', 'conv77_activation', 'conv47_activation', 'conv50_activation', 'conv68_activation', 'conv79_activation', 'conv65_activation', 'conv80_activation', 'conv61_activation', 'conv73_activation', 'conv76_activation', 'conv55_activation', 'conv32_activation', 'conv58_activation', 'conv71_activation', 'conv46_activation', 'conv49_activation', 'conv70_activation', 'conv74_activation', 'conv15_activation', 'conv24_activation', 'conv44_activation', 'conv41_activation', 'conv43_activation', 'conv52_activation', 'conv40_activation', 'conv31_activation', 'conv93_activation', 'conv23_activation', 'conv38_activation', 'conv20_activation', 'conv17_activation', 'conv90_activation', 'conv87_activation', 'conv35_activation', 'conv37_activation', 'conv84_activation', 'conv81_activation', 'conv10_activation', 'conv78_activation', 'conv34_activation', 'conv60_activation', 'conv63_activation', 'conv69_activation', 'conv7_activation', 'conv29_activation', 'conv51_activation', 'conv54_activation', 'conv75_activation', 'conv66_activation', 'conv72_activation', 'conv48_activation', 'conv57_activation', 'conv28_activation', 'conv33_activation', 'conv45_activation', 'conv42_activation', 'conv39_activation', 'conv36_activation', 'conv30_activation'] 82 | elif args.arch == 'inception_v3': 83 | max_mse_order_id = ['conv5_activation', 'conv12_activation', 'conv1_activation', 'conv7_activation', 'conv4_activation', 'conv2_activation', 'conv14_activation', 'conv19_activation', 'conv10_activation', 'conv92_activation', 'conv21_activation', 'conv22_activation', 'conv9_activation', 'conv77_activation', 'conv16_activation', 'conv47_activation', 'conv48_activation', 'conv17_activation', 'conv58_activation', 'conv8_activation', 'conv55_activation', 'conv56_activation', 'conv40_activation', 'conv63_activation', 'conv15_activation', 'conv62_activation', 'conv84_activation', 'conv54_activation', 'conv57_activation', 'conv52_activation', 'conv65_activation', 'conv91_activation', 'conv76_activation', 'conv34_activation', 'conv51_activation', 'conv85_activation', 'conv53_activation', 'conv83_activation', 'conv35_activation', 'conv50_activation', 'conv46_activation', 'conv82_activation', 'conv61_activation', 'conv30_activation', 'conv37_activation', 'conv67_activation', 'conv75_activation', 
'conv64_activation', 'conv29_activation', 'conv66_activation', 'conv44_activation', 'conv33_activation', 'conv43_activation', 'conv38_activation', 'conv45_activation', 'conv42_activation', 'conv23_activation', 'conv36_activation', 'conv60_activation', 'conv32_activation', 'conv41_activation', 'conv79_activation', 'conv6_activation', 'conv13_activation', 'conv78_activation', 'conv20_activation', 'conv73_activation', 'conv74_activation', 'conv80_activation', 'conv31_activation', 'conv27_activation', 'conv81_activation', 'conv88_activation', 'conv68_activation', 'conv28_activation', 'conv26_activation', 'conv89_activation', 'conv72_activation', 'conv93_activation', 'conv90_activation', 'conv94_activation', 'conv3_activation', 'conv24_activation', 'conv87_activation', 'conv18_activation', 'conv69_activation', 'conv59_activation', 'conv25_activation', 'conv49_activation', 'linear1_activation', 'conv39_activation', 'conv86_activation', 'conv11_activation', 'conv95_activation'] 84 | 85 | 86 | def main(): 87 | global args, best_prec1 88 | 89 | if args.seed is not None: 90 | random.seed(args.seed) 91 | torch.manual_seed(args.seed) 92 | cudnn.deterministic = True 93 | warnings.warn('You have chosen to seed training. ' 94 | 'This will turn on the CUDNN deterministic setting, ' 95 | 'which can slow down your training considerably! ' 96 | 'You may see unexpected behavior when restarting ' 97 | 'from checkpoints.') 98 | 99 | if 'cuda' in args.device and torch.cuda.is_available(): 100 | if args.seed is not None: 101 | torch.cuda.manual_seed_all(args.seed) 102 | torch.cuda.set_device(args.device_ids[0]) 103 | cudnn.benchmark = True 104 | else: 105 | args.device_ids = None 106 | 107 | # create model 108 | print("=> using pre-trained model '{}'".format(args.arch)) 109 | model = models.__dict__[args.arch](pretrained=True) 110 | model.to(args.device) 111 | 112 | if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet': 113 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 114 | model.features = torch.nn.DataParallel(model.features, args.device_ids) 115 | else: 116 | model = torch.nn.DataParallel(model, args.device_ids) 117 | 118 | # BatchNorm folding 119 | if 'resnet' in args.arch or args.arch == 'vgg16_bn' or args.arch == 'inception_v3': 120 | print("Perform BN folding") 121 | search_absorbe_bn(model) 122 | QM().bn_folding = True 123 | 124 | # define loss function (criterion) and optimizer 125 | criterion = nn.CrossEntropyLoss() 126 | criterion.to(args.device) 127 | 128 | cudnn.benchmark = True 129 | 130 | # Data loading code 131 | valdir = os.path.join(args.data, 'val') 132 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 133 | std=[0.229, 0.224, 0.225]) 134 | 135 | resize = 256 if args.arch != 'inception_v3' else 299 136 | crop_size = 224 if args.arch != 'inception_v3' else 299 137 | 138 | val_loader = torch.utils.data.DataLoader( 139 | datasets.ImageFolder(valdir, transforms.Compose([ 140 | transforms.Resize(resize), 141 | transforms.CenterCrop(crop_size), 142 | transforms.ToTensor(), 143 | normalize, 144 | ])), 145 | batch_size=args.batch_size, shuffle=False, 146 | num_workers=args.workers, pin_memory=True) 147 | 148 | if args.eval_precision: 149 | elog = EvalLog(['dtype', 'val_prec1', 'val_prec5']) 150 | print("\nFloat32 no quantization") 151 | QM().disable() 152 | val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion) 153 | elog.log('fp32', val_prec1, val_prec5) 154 | logging.info('\nValidation Loss {val_loss:.4f} \t' 155 | 
'Validation Prec@1 {val_prec1:.3f} \t' 156 | 'Validation Prec@5 {val_prec5:.3f} \n' 157 | .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5)) 158 | print("--------------------------------------------------------------------------") 159 | 160 | for q in [8, 7, 6, 5, 4]: 161 | args.qtype = 'int{}'.format(q) 162 | print("\nQuantize to %s" % args.qtype) 163 | QM().quantize = True 164 | QM().reload(args, qparams) 165 | val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion) 166 | elog.log(args.qtype, val_prec1, val_prec5) 167 | logging.info('\nValidation Loss {val_loss:.4f} \t' 168 | 'Validation Prec@1 {val_prec1:.3f} \t' 169 | 'Validation Prec@5 {val_prec5:.3f} \n' 170 | .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5)) 171 | print("--------------------------------------------------------------------------") 172 | print(elog) 173 | elog.save('results/precision/%s_%s_clipping.csv' % (args.arch, args.threshold)) 174 | elif args.custom_test: 175 | log_name = 'results/custom_test/%s_max_mse_%s_cliping_layer_selection.csv' % (args.arch, args.threshold) 176 | elog = EvalLog(['num_8bit_layers', 'indexes', 'val_prec1', 'val_prec5'], log_name, auto_save=True) 177 | for i in range(len(max_mse_order_id)+1): 178 | _8bit_layers = ['conv0_activation'] + max_mse_order_id[0:i] 179 | print("it: %d, 8 bit layers: %d" % (i, len(_8bit_layers))) 180 | QM().set_8bit_list(_8bit_layers) 181 | val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion) 182 | elog.log(i+1, str(_8bit_layers), val_prec1, val_prec5) 183 | print(elog) 184 | else: 185 | val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion) 186 | 187 | 188 | 189 | def validate(val_loader, model, criterion): 190 | batch_time = AverageMeter() 191 | losses = AverageMeter() 192 | top1 = AverageMeter() 193 | top5 = AverageMeter() 194 | 195 | # switch to evaluate mode 196 | model.eval() 197 | 198 | if args.dump_dir is not None: 199 | QM().disable() 200 | DM(args.dump_dir) 201 | 202 | with torch.no_grad(): 203 | end = time.time() 204 | for i, (input, target) in enumerate(val_loader): 205 | input = input.to(args.device) 206 | target = target.to(args.device) 207 | if args.dump_dir is not None and i == 5: 208 | with DM(args.dump_dir): 209 | DM().set_tag('batch%d'%i) 210 | # compute output 211 | output = model(input) 212 | break 213 | else: 214 | output = model(input) 215 | 216 | loss = criterion(output, target) 217 | 218 | # measure accuracy and record loss 219 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 220 | losses.update(loss.item(), input.size(0)) 221 | top1.update(float(prec1), input.size(0)) 222 | top5.update(float(prec5), input.size(0)) 223 | 224 | # measure elapsed time 225 | batch_time.update(time.time() - end) 226 | end = time.time() 227 | 228 | if i % args.print_freq == 0: 229 | print('Test: [{0}/{1}]\t' 230 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 231 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 232 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 233 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 234 | i, len(val_loader), batch_time=batch_time, loss=losses, 235 | top1=top1, top5=top5)) 236 | 237 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 238 | .format(top1=top1, top5=top5)) 239 | 240 | return losses.avg, top1.avg, top5.avg 241 | 242 | qparams = {'int': { 243 | 'threshold': args.threshold, 244 | 'true_zero': args.preserve_zero 245 | }} # TODO: add params for bfloat 246 | if __name__ == '__main__': 247 | with QM(args, qparams): 248 | 
main() 249 | -------------------------------------------------------------------------------- /kernels/build_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm -r build 3 | echo "**************************************************************" 4 | echo "Building int quantization kernels" 5 | echo "**************************************************************" 6 | python build_int_quantization.py install 7 | echo "Done" 8 | echo "**************************************************************" 9 | -------------------------------------------------------------------------------- /kernels/build_int_quantization.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | 4 | 5 | 6 | setup(name='int_quantization', 7 | ext_modules=[CUDAExtension('int_quantization', ['int_quantization.cpp', 8 | 'gemmlowp.cu' 9 | ])], 10 | cmdclass={'build_ext': BuildExtension}) 11 | 12 | # for installation execute: 13 | # > python build_int_quantization.py install 14 | # record list of all installed files: 15 | # > python build_int_quantization.py install --record files.txt
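After a successful build, the extension can be smoke-tested from Python. A minimal sketch (assumes a CUDA device is available; the argument order is taken from the `float2gemmlowp` declaration in `int_quantization.cpp` below):
```python
import torch
import int_quantization

x = torch.randn(1000, device='cuda')
noise = torch.zeros_like(x)  # zero noise = deterministic (non-stochastic) rounding
# args: tensor, range, offset, num_bits, int_exp, enforce_true_zero, noise
q = int_quantization.float2gemmlowp(x.contiguous(), float(x.max() - x.min()),
                                    float(x.min()), 8, False, True, noise)
print('8-bit quantization MSE:', (x - q).pow(2).mean().item())
```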
-------------------------------------------------------------------------------- /kernels/gemmlowp.cu: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | 9 | __global__ void GEMMLowpKernel(const float* in, const int N, float* out, 10 | float scale, float shift, long long qmax, const float* noise, bool enforce_true_zero) { 11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { 12 | out[i] = in[i]; 13 | if (enforce_true_zero) 14 | out[i] = out[i] / scale - shift; 15 | else 16 | out[i] = (out[i] - shift) / scale; 17 | out[i] += noise[i]; 18 | out[i] = fminf(out[i], qmax); 19 | out[i] = fmaxf(out[i], 0.); 20 | out[i] = roundf(out[i]); 21 | if (enforce_true_zero) 22 | out[i] = (out[i] + shift) * scale; 23 | else 24 | out[i] = out[i] * scale + shift; 25 | } 26 | } 27 | 28 | #define block_count 32 29 | #define thread_per_block 1024 30 | // Wrapper for ATen 31 | at::Tensor float2gemmlowp(at::Tensor in, float range, float offset, int num_bits, bool int_exp, bool enforce_true_zero, at::Tensor noise) { 32 | if (range <= 0) 33 | return in; 34 | 35 | int N = in.numel(); 36 | auto out = at::zeros_like(in); 37 | long long qmax = (0x1l << num_bits) - 1; 38 | float scale = range / qmax; 39 | if (int_exp) 40 | scale = powf(2, int(ceilf(log2f(scale)))); 41 | float zero_point = roundf(offset / scale); 42 | float shift = enforce_true_zero ? zero_point : offset; 43 | switch (in.type().scalarType()) { 44 | case at::ScalarType::Float: 45 | GEMMLowpKernel<<<block_count, thread_per_block>>>(in.data<float>(), N, out.data<float>(), scale, shift, qmax, noise.data<float>(), enforce_true_zero); 46 | break; 47 | default: 48 | out = in; 49 | } 50 | 51 | return out; 52 | } 53 | 54 | -------------------------------------------------------------------------------- /kernels/int_quantization.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | // CUDA declarations 5 | at::Tensor float2gemmlowp(at::Tensor in, float range, float offset, int num_bits, bool int_exp, 6 | bool enforce_true_zero, at::Tensor noise); 7 | 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("float2gemmlowp", &float2gemmlowp, "Convert float 32 to gemmlowp"); 11 | } -------------------------------------------------------------------------------- /mse_analysis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import math 4 | 5 | def uniform_midtread_quantizer(x, Q): 6 | xQ = np.round(x / Q) * Q 7 | return xQ 8 | 9 | 10 | def GaussianClippingAnalysis(Alpha, sigma, bitWidth): 11 | Analysis = [] 12 | for alpha in Alpha: 13 | clipping_mse = (sigma**2 + (alpha ** 2)) * (1 - math.erf(alpha / (sigma*np.sqrt(2.0)))) - np.sqrt(2.0/np.pi) * alpha * sigma*(np.e ** ((-1)*(0.5* (alpha ** 2))/sigma**2)) 14 | quant_mse = (alpha ** 2) / (3 * (2 ** (2 * bitWidth))) 15 | mse = clipping_mse + quant_mse 16 | Analysis.append(mse) 17 | return Analysis 18 | 19 | def GaussianClippingSimulation(Alpha, sigma, bitWidth): 20 | highPrecision = np.random.normal(0, sigma, size=100000) 21 | simulations = [] 22 | for alpha in Alpha: 23 | s = np.copy(highPrecision) 24 | Q = (2*alpha)/(2**bitWidth) 25 | # clipping 26 | s[s > alpha] = alpha 27 | s[s < -alpha] = -alpha 28 | # quantization 29 | s = uniform_midtread_quantizer(s, Q) 30 | 31 | mse = ((s - highPrecision) ** 2).mean() 32 | simulations.append(mse) 33 | return simulations 34 | 35 | 36 | 37 | def LaplacianClippingAnalysis(Alpha, b, bitWidth): 38 | Analysis = [] 39 | for alpha in Alpha: 40 | mse = 2 * (b ** 2) * ((np.e) ** (-alpha / b)) + ((alpha ** 2) / (3 * (2 ** (2 * bitWidth)))) 41 | Analysis.append(mse) 42 | return Analysis 43 | 44 | def LaplacianClippingSimulation(Alpha, b, bitWidth): 45 | simulations = [] 46 | highPrecision = np.random.laplace(scale=b, size=100000, loc = 0) 47 | for alpha in Alpha: 48 | s = np.copy(highPrecision) 49 | Q = (2*alpha)/(2**bitWidth) 50 | 51 | # clipping 52 | s[s > alpha] = alpha 53 | s[s < -alpha] = -alpha 54 | # quantization 55 | s = uniform_midtread_quantizer(s, Q) 56 | 57 | mse = ((s - highPrecision) ** 2).mean() 58 | simulations.append(mse) 59 | return simulations 60 | 61 | 62 | if __name__ == "__main__": 63 | Alpha = np.arange(0, 15, 0.1) 64 | 65 | # Experiment parameters 66 | bitWidth = 4 67 | sigma = 2 # standard deviation 68 | 69 | # Gauss 70 | simulation = GaussianClippingSimulation(Alpha, sigma, bitWidth) 71 | analysis = GaussianClippingAnalysis(Alpha, sigma, bitWidth) 72 | 73 | # Laplace 74 | # simulation = LaplacianClippingSimulation(Alpha, sigma, bitWidth) 75 | # analysis = LaplacianClippingAnalysis(Alpha, sigma, bitWidth) 76 | 77 | 78 | plt.plot(Alpha, simulation, 'b', linewidth=3) 79 | plt.plot(Alpha, analysis, 'r', linewidth=3) 80 | plt.legend(('simulation', 'analysis')); plt.ylabel('Mean Square Error', size=20) ; plt.xlabel('Clipping Value', size=20) 81 | plt.title('Bit Width='+ str(bitWidth), size=20) 82 | plt.show()
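The analytical MSE expressions above are the quantities that optimal_alpha.ipynb minimizes to obtain the optimal clipping values. A minimal standalone sketch of that numerical solution for the Laplace case (not taken from the notebook; b is normalized to 1):
```python
import numpy as np
from scipy.optimize import minimize_scalar

def laplace_mse(alpha, b, num_bits):
    # same expression as LaplacianClippingAnalysis above
    return 2 * (b ** 2) * np.exp(-alpha / b) + (alpha ** 2) / (3 * 2 ** (2 * num_bits))

for num_bits in range(2, 9):
    res = minimize_scalar(laplace_mse, bounds=(0, 100), args=(1.0, num_bits), method='bounded')
    print('bits=%d  optimal alpha=%.2f' % (num_bits, res.x))
# Expected: 2.83, 3.89, 5.03, 6.20, 7.41, 8.64, 9.89 (times b),
# matching the alpha_laplace table in int_quantizer.py.
```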
-------------------------------------------------------------------------------- /pytorch_quantizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/pytorch_quantizer/__init__.py -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/pytorch_quantizer/quantization/__init__.py -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/pytorch_quantizer/quantization/inference/__init__.py -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/inference/inference_quantization_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_quantizer.quantization import qtypes 4 | from utils.misc import Singleton 5 | from utils import attacher 6 | from utils.monitor import Monitor 7 | from .statistic_manager import StatisticManager 8 | from pytorch_quantizer.quantization.quantization_manager import QuantizationManagerBase 9 | from enum import Enum 10 | from itertools import count 11 | import os 12 | import numpy as np 13 | from utils.dump_manager import DumpManager as DM 14 | 15 | 16 | class StatsMode(Enum): 17 | no_stats = 1 18 | collect_stats = 2 19 | use_stats = 3 20 | 21 | 22 | class Conv2dWithId(nn.Conv2d): 23 | _id = count(0) 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dWithId, self).__init__(in_channels, out_channels, kernel_size, stride, 27 | padding, dilation, groups, bias) 28 | self.id = next(self._id) 29 | # print('conv_%d' % self.id) 30 | 31 | # TODO: handle quantization of activations per model 32 | def forward(self, input): 33 | activation_id = 'conv%d_activation' % self.id 34 | if not QMI().enabled: 35 | out = super(Conv2dWithId, self).forward(input) 36 | else: 37 | if QMI().stats_mode is not StatsMode.collect_stats: 38 | self.weight.data = QMI().quantize_instant(self.weight, "weight") 39 | if self.bias is not None: 40 | self.bias.data = QMI().quantize_instant(self.bias, "bias") 41 | out = super(Conv2dWithId, self).forward(input) 42 | if QMI().stats_mode is StatsMode.collect_stats: 43 | QMI().stats_manager.save_tensor_stats(out, activation_id) 44 | elif QMI().stats_mode is StatsMode.use_stats: 45 | # Quantize using statistics 46 | out = QMI().quantize_instant(out, "activation", stat_id=activation_id) 47 | else: 48 | # No stats, quantize using actual values 49 | out = QMI().quantize_instant(out, "activation") 50 | 51 | return out 52 | 53 | 54 | class LinearWithId(nn.Linear): 55 | _id = count(0) 56 | def __init__(self, in_features, out_features, bias=True): 57 | super(LinearWithId, self).__init__(in_features, out_features, bias) 58 | self.id = next(self._id) 59 | 60 | def forward(self, input):
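        # Same flow as Conv2dWithId.forward above: quantize weights and bias
        # (unless statistics are being collected), run the float op, then either
        # record activation statistics or quantize the output activation.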
61 | if not QMI().enabled: 62 | return super(LinearWithId, self).forward(input) 63 | else: 64 | if QMI().stats_mode is not StatsMode.collect_stats: 65 | self.weight.data = QMI().quantize_instant(self.weight, "weight") 66 | if self.bias is not None: 67 | self.bias.data = QMI().quantize_instant(self.bias, "bias") 68 | out = super(LinearWithId, self).forward(input) 69 | activation_id = 'linear%d_activation' % self.id 70 | if QMI().stats_mode is StatsMode.collect_stats: 71 | QMI().stats_manager.save_tensor_stats(out, activation_id) 72 | elif QMI().stats_mode is StatsMode.use_stats: 73 | out = QMI().quantize_instant(out, "activation_linear", stat_id=activation_id) 74 | else: 75 | out = QMI().quantize_instant(out, "activation_linear") 76 | return out 77 | 78 | 79 | # TODO: batch norm folding 80 | class BatchNorm2dWithId(nn.BatchNorm2d): 81 | _id = count(0) 82 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 83 | track_running_stats=True): 84 | super(BatchNorm2dWithId, self).__init__(num_features, eps, momentum, affine, track_running_stats) 85 | self.id = next(self._id) 86 | # print('bn_%d' % self.id) 87 | 88 | def forward(self, input): 89 | if not QMI().enabled: 90 | return super(BatchNorm2dWithId, self).forward(input) 91 | else: 92 | if QMI().bn_folding: 93 | # Do regular BN if floding is set 94 | return super(BatchNorm2dWithId, self).forward(input) 95 | 96 | if QMI().stats_mode is not StatsMode.collect_stats: 97 | self.weight.data = QMI().quantize_instant(self.weight, "weight") 98 | if self.bias is not None: 99 | self.bias.data = QMI().quantize_instant(self.bias, "bias") 100 | 101 | out = super(BatchNorm2dWithId, self).forward(input) 102 | activation_id = 'bn%d_activation' % self.id 103 | if QMI().stats_mode is StatsMode.collect_stats: 104 | QMI().stats_manager.save_tensor_stats(out, activation_id) 105 | elif QMI().stats_mode is StatsMode.use_stats: 106 | # Quantize using statistics 107 | out = QMI().quantize_instant(out, "activation", stat_id=activation_id) 108 | else: 109 | # No stats, quantize using actual values 110 | out = QMI().quantize_instant(out, "activation") 111 | return out 112 | 113 | 114 | class QuantizationManagerInference(QuantizationManagerBase): 115 | def __init__(self, args, qparams): 116 | super(QuantizationManagerInference, self).__init__() 117 | self.quantize = args.qtype is not None 118 | if self.quantize: 119 | print("Quantize to %s" % args.qtype) 120 | self.op_manager = self.createTruncationManager(args, qparams) 121 | self.enabled = False 122 | self.bn_folding = False 123 | sf = args.stats_folder if args.stats_folder is not None else args.arch 124 | if args.stats_mode == 'collect': 125 | self.stats_mode = StatsMode.collect_stats 126 | self.stats_manager = StatisticManager(sf, load_stats=False) 127 | elif args.stats_mode == 'use': 128 | self.stats_mode = StatsMode.use_stats 129 | self.stats_manager = StatisticManager(sf, load_stats=True) 130 | else: 131 | self.stats_mode = StatsMode.no_stats 132 | self.stats_manager = None 133 | 134 | def __exit__(self, *args): 135 | if self.stats_manager is not None: 136 | self.stats_manager.__exit__() 137 | super(QuantizationManagerInference, self).__exit__(args) 138 | 139 | def createTruncationManager(self, args, qparams): 140 | op_manager = TruncationOpManagerInference(args, qparams) 141 | op_manager.set_8bit_list(['conv0_activation']) 142 | return op_manager 143 | 144 | def quantize_instant(self, tensor, tag="", stat_id=None): 145 | return self.op_manager.quantize_instant(tensor, tag, stat_id) 146 | 147 | 
def set_8bit_list(self, ignore_ids): 148 | self.op_manager.set_8bit_list(ignore_ids) 149 | 150 | 151 | # Alias 152 | QMI = QuantizationManagerInference 153 | 154 | 155 | class TruncationOpManagerInference: 156 | def __load_quantizer__(self, qtype, qparams): 157 | qtype_name = qtype.rstrip('1234567890') 158 | quant_params = qparams[qtype_name] if qtype_name in qparams else {} 159 | quantizer = qtypes.__dict__[qtype_name + "_quantizer"](qtype, quant_params) 160 | return quantizer, quant_params 161 | 162 | def __init__(self, args, qparams): 163 | self.verbose = False 164 | self.activation_quantizer = None 165 | 166 | self.origin_linear = nn.Linear 167 | self.origin_conv2d = nn.Conv2d 168 | self.origin_batch_norm = nn.BatchNorm2d 169 | 170 | if args.qtype is not None: 171 | self.quantize = True 172 | self.activation_quantizer, _ = self.__load_quantizer__(args.qtype, qparams) 173 | self.linear_layer_quantizer, _ = self.__load_quantizer__('int8', qparams) 174 | # self.weights_quantizer = self.activation_quantizer 175 | self.weights_quantizer, _ = self.__load_quantizer__('int8', qparams) 176 | self.quantizer_4bit, _ = self.__load_quantizer__('int4', qparams) 177 | self.quantizer_8bit, _ = self.__load_quantizer__('int8', qparams) 178 | 179 | def set_8bit_list(self, ignore_list): 180 | self.ignore_ids = ignore_list 181 | 182 | def enable(self): 183 | # self.quantize_matmul() 184 | self.quantize_linear() 185 | self.quantize_conv2d() 186 | self.quantize_batch_norm() 187 | 188 | def disable(self): 189 | nn.Linear = self.origin_linear 190 | nn.Conv2d = self.origin_conv2d 191 | nn.BatchNorm2d = self.origin_batch_norm 192 | 193 | # quantizes origin matmul 194 | def quantize_matmul(self): 195 | def quantized_matmul(tensor1, tensor2): 196 | tensor1_ = attacher.pytorch_attach(tensor1, self.activation_quantizer, None) 197 | tensor2_ = attacher.pytorch_attach(tensor2, self.activation_quantizer, None) 198 | res = self.origin_matmul(tensor1_, tensor2_) 199 | return attacher.pytorch_attach(res, self.activation_quantizer, None) 200 | 201 | torch.Tensor.matmul = quantized_matmul 202 | 203 | # quantizes origin linear 204 | def quantize_linear(self): 205 | nn.Linear = LinearWithId 206 | 207 | # quantizes origin conv2d 208 | def quantize_conv2d(self): 209 | nn.Conv2d = Conv2dWithId 210 | 211 | def quantize_batch_norm(self): 212 | nn.BatchNorm2d = BatchNorm2dWithId 213 | 214 | 215 | def quantize_tensor(self, tensor, fprop=True, bprop=True): 216 | fprop = self.activation_quantizer if fprop else None 217 | return attacher.pytorch_attach(tensor, fprop, None) 218 | 219 | def quantize_instant(self, tensor, tag="", stat_id=None): 220 | # ignore quantization of first and last layer 221 | ignore_cond = False 222 | if stat_id is not None: 223 | ignore_cond = np.array([l == stat_id for l in self.ignore_ids]).any() 224 | if ignore_cond: 225 | return self.quantizer_8bit(tensor, tag, stat_id) 226 | # Leave classifier layer in 8 bit 227 | elif tag == 'activation_linear' and tensor.shape[1] == 1000: 228 | return self.linear_layer_quantizer(tensor, tag, stat_id) 229 | elif tag == 'activation': 230 | return self.activation_quantizer(tensor, tag, stat_id) 231 | else: # weight, bias 232 | return self.weights_quantizer(tensor, tag, stat_id) 233 | -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/inference/statistic_manager.py: -------------------------------------------------------------------------------- 1 | from utils.misc import Singleton 2 | import numpy as np 3 | import 
pandas as pd 4 | import os 5 | import shutil 6 | from utils.misc import sorted_nicely 7 | import torch 8 | from pathlib import Path 9 | 10 | class StatisticManager(metaclass=Singleton): 11 | def __init__(self, folder, load_stats, stats = ['max', 'min', 'std', 'mean', 'kurtosis', 'mean_abs', 'b']): 12 | self.name = folder 13 | home = str(Path.home()) 14 | self.folder = os.path.join(home, 'asiq_data/statistics', folder) 15 | if not load_stats: 16 | print("Saving statistics to %s" % self.folder) 17 | self.stats_names = stats 18 | self.stats = {} 19 | self.save_stats = not load_stats 20 | if load_stats: 21 | stats_file = os.path.join(self.folder, '%s_summary.csv' % self.name) 22 | assert os.path.exists(stats_file), "Statistics not found, please run with '-sm collect' first" 23 | self.stats_df = pd.read_csv(stats_file, index_col=0) 24 | else: 25 | self.stats_df = None 26 | pass 27 | 28 | def save_tensor_stats(self, tensor, id): 29 | stat_arr = [] 30 | # Calculate tensor stats 31 | for sn in self.stats_names: 32 | if sn == 'kurtosis': 33 | t = tensor.view(tensor.shape[0], -1) 34 | st = torch.mean(((t - t.mean(-1).unsqueeze(-1)) / t.std(-1).unsqueeze(-1))**4, dim=-1) - 3 35 | elif sn == 'mean_abs': 36 | t = tensor.view(tensor.shape[0], -1) 37 | st = torch.mean(t.abs(), dim=-1) 38 | # st = torch.mean((t - t.mean(-1).unsqueeze(-1)).abs(), dim=-1) 39 | elif sn == 'b': 40 | t = tensor.view(tensor.shape[0], -1) 41 | st = torch.mean(torch.abs(t - t.mean(-1).unsqueeze(-1)), dim=-1) 42 | else: 43 | # collect statistics for entire mini batch 44 | st = getattr(tensor.view(tensor.shape[0], -1), sn)(-1) 45 | if type(st) is tuple: 46 | st = st[0] 47 | stat_arr.append(st.cpu().numpy()) 48 | 49 | # Add to stats dictionary 50 | if id in self.stats: 51 | stat_arr = np.vstack(stat_arr).transpose() 52 | s = np.concatenate([self.stats[id], stat_arr]) 53 | self.stats[id] = s 54 | else: 55 | self.stats[id] = np.vstack(stat_arr).transpose() 56 | 57 | def get_tensor_stats(self, id, kind={'min':'mean', 'max':'mean', 'mean': 'mean','std':'mean', 'range':'mean', 'mean_abs':'mean', 'b':'mean'}): 58 | if self.stats_df is not None: 59 | # TODO: add different options for min/max 60 | min_ = self.stats_df.loc[id, '%s_min' % kind['min']] 61 | max_ = self.stats_df.loc[id, '%s_max' % kind['max']] 62 | mean_ = self.stats_df.loc[id, '%s_mean' % kind['mean']] 63 | std_ = self.stats_df.loc[id, '%s_std' % kind['std']] 64 | range_ = self.stats_df.loc[id, '%s_range' % kind['range']] 65 | mean_abs_ = self.stats_df.loc[id, '%s_mean_abs' % kind['mean_abs']] 66 | b_ = self.stats_df.loc[id, '%s_b' % kind['b']] 67 | return min_, max_, mean_, std_, range_, mean_abs_, b_ 68 | else: 69 | return None, None, None, None, None, None, None 70 | 71 | def __exit__(self, *args): 72 | if self.save_stats: 73 | # Save statistics 74 | if os.path.exists(self.folder): 75 | shutil.rmtree(self.folder) 76 | if not os.path.exists(self.folder): 77 | os.makedirs(self.folder) 78 | all_stats_df = {} 79 | for s_id in self.stats: 80 | path = os.path.join(self.folder, '%s.csv' % s_id) 81 | df = pd.DataFrame(columns=self.stats_names, data=self.stats[s_id]) 82 | df.to_csv(path, index=False) 83 | all_stats_df[s_id] = df 84 | self.__save_summry(all_stats_df) 85 | 86 | def __save_summry(self, all_stats_df): 87 | columns = [] 88 | c_names = ['max', 'min', 'mean', 'std', 'range', 'kurtosis', 'mean_abs', 'b'] 89 | for c in c_names: 90 | columns.append('min_%s' % c) 91 | columns.append('mean_%s' % c) 92 | columns.append('max_%s' % c) 93 | 94 | df_summary = 
pd.DataFrame(columns=columns) 95 | for s_id in sorted_nicely(all_stats_df.keys()): 96 | all_stats_df[s_id]['range'] = all_stats_df[s_id]['max'] - all_stats_df[s_id]['min'] 97 | for c in c_names: 98 | df_summary.loc[s_id, 'min_%s' % c] = all_stats_df[s_id][c].min() 99 | df_summary.loc[s_id, 'mean_%s' % c] = all_stats_df[s_id][c].mean() 100 | df_summary.loc[s_id, 'max_%s' % c] = all_stats_df[s_id][c].max() 101 | path = os.path.join(self.folder, '%s_summary.csv' % self.name) 102 | df_summary.to_csv(path, index=True) 103 | -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/qtypes/__init__.py: -------------------------------------------------------------------------------- 1 | from .int_quantizer import int_quantizer 2 | -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/qtypes/int_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import numpy as np 4 | import int_quantization 5 | import math 6 | from utils.monitor import Monitor 7 | from pytorch_quantizer.quantization.inference.statistic_manager import StatisticManager as SM 8 | 9 | # Alpha coeficients for for gaussian clipping 10 | # [1.71063519 2.15159277 2.55913646 2.93620062 3.28691474 3.6151146 3.92403714] 11 | 12 | # Alpha coeficients for for laplace clipping 13 | # [2.83068299 3.89722946 5.02864014 6.20476633 7.41312622 8.64561995 9.89675982] 14 | 15 | 16 | count = 0 17 | class IntQuantizer(Function): 18 | def __init__(self, size, params): 19 | self.num_bits = size 20 | # TODO: expose as cmd line parameters 21 | self.stochastic = False 22 | self.int_exp = False 23 | self.enforce_true_zero = True #params['true_zero'] 24 | self.clipping = params['threshold'] 25 | self.alpha_gaus = {2: 1.71, 3: 2.15, 4: 2.55, 5: 2.93, 6: 3.28, 7: 3.61, 8: 3.92} 26 | self.alpha_laplace = {2: 2.83, 3: 3.89, 4: 5.03, 5: 6.2, 6: 7.41, 7: 8.64, 8: 9.89} 27 | self.gaussian_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5) 28 | 29 | def __call__(self, tensor, tag="", stat_id=None): 30 | if self.clipping == 'no' or tag != 'activation': 31 | min_, max_, mean_ = None, None, None 32 | if stat_id is not None: 33 | # kind = {'min': 'mean', 'max': 'mean', 'mean': 'mean', 'std': 'mean', 'range': 'mean', 'mean_abs': 'mean', 'b': 'mean'} 34 | kind = {'min': 'mean', 'max': 'mean', 'mean': 'mean', 'std': 'mean', 'range': 'mean', 'mean_abs': 'mean', 'b': 'mean'} 35 | # Hack: handle classifier layer differently 36 | if tag == 'activation_linear' and tensor.shape[1] == 1000: 37 | kind['min'] = 'min' 38 | kind['max'] = 'max' 39 | kind['range'] = 'max' 40 | min_, max_, mean_, _, _, _, _ = SM().get_tensor_stats(stat_id, kind) 41 | # print("use stats for %s, min %f, max %f" % (stat_id, min_, max_)) 42 | return self.gemmlowpQuantize(tensor, min_value=min_, max_value=max_, mean=mean_) 43 | else: 44 | return self.gemmlowpClippingQuantize(tensor, tag, stat_id=stat_id, clip_type=self.clipping) 45 | 46 | def get_alpha_laplace(self, tensor, stat_id=None): 47 | if stat_id is not None: 48 | kind = {'min': 'mean', 'max': 'mean', 'mean': 'mean', 'std': 'mean', 'range': 'mean', 'mean_abs': 'mean', 49 | 'b': 'mean'} 50 | _, _, _, _, _, _, b = SM().get_tensor_stats(stat_id, kind) 51 | else: 52 | b = torch.mean(torch.abs(tensor - tensor.mean())).cpu().numpy() 53 | return self.alpha_laplace[self.num_bits] * b 54 | 55 | def get_alpha_gaus(self, tensor, tag, 
stat_id=None): 56 | if tag == 'activation' and len(tensor.shape) == 4: 57 | N = tensor.shape[1]*tensor.shape[2]*tensor.shape[3] 58 | else: 59 | N = tensor.view(-1).size()[0] 60 | if stat_id is not None: 61 | kind = {'min': 'mean', 'max': 'mean', 'mean': 'mean', 'std': 'mean', 'range': 'mean', 'mean_abs': 'mean', 62 | 'b': 'mean'} 63 | min_value, max_value, _, std, _, _, _ = SM().get_tensor_stats(stat_id, kind) 64 | else: 65 | # TODO: Average over batch 66 | min_value = tensor.min() 67 | max_value = tensor.max() 68 | 69 | std = ((max_value - min_value) * self.gaussian_const) / ((2 * math.log(N)) ** 0.5) 70 | return self.alpha_gaus[self.num_bits] * std 71 | 72 | def alpha2DeltaOffset(self, alpha, max_value, min_value, mean): 73 | max_range = max_value - min_value 74 | if alpha <= 0 or alpha >= max_range / 2: 75 | delta = max_range 76 | else: 77 | delta = 2 * alpha 78 | min_value = max(min_value, mean - delta / 2) 79 | 80 | return delta, min_value 81 | 82 | # Python implementation of gemmlowp quantization. Use for quick poc and experiments 83 | def gemmlowpClippingQuantize(self, input, tag="", stat_id=None, clip_type='laplace'): 84 | if stat_id is not None: 85 | kind = {'min': 'mean', 'max': 'mean', 'mean': 'mean', 'std': 'mean', 'range': 'mean', 'mean_abs': 'mean', 86 | 'b': 'mean'} 87 | min_value, max_value, mean, std, range, mean_abs, b = SM().get_tensor_stats(stat_id, kind) 88 | else: 89 | min_value = input.min() 90 | max_value = input.max() 91 | mean = input.mean() 92 | 93 | max_range = max_value - min_value 94 | 95 | if clip_type == 'laplace': 96 | alpha = self.get_alpha_laplace(input, stat_id) # laplace clipping 97 | elif clip_type == 'gaus': 98 | alpha = self.get_alpha_gaus(input, tag, stat_id) # gaussian clipping 99 | elif clip_type == 'exp': 100 | alpha = self.get_alpha_exp(input, stat_id) # exponential clipping 101 | elif clip_type == 'mix': 102 | alpha_laplace = self.get_alpha_laplace(input, stat_id) # laplace clipping 103 | alpha_gause = self.get_alpha_gaus(input, tag, stat_id) # gaussian clipping 104 | mse_est_laplace = IntQuantizer.mse_laplace(b, alpha_laplace, self.num_bits) 105 | mse_est_gaus = IntQuantizer.mse_gaus(std, alpha_gause, self.num_bits) 106 | if mse_est_laplace < mse_est_gaus: 107 | alpha = alpha_laplace 108 | else: 109 | alpha = alpha_gause 110 | elif clip_type == 'test': 111 | mse_laplace, mse_laplace_est = self.__clip_and_mse_mesure(input, tag, stat_id, 'laplace', max_value, min_value, mean, std, b) 112 | mse_gaus, mse_gaus_est = self.__clip_and_mse_mesure(input, tag, stat_id, 'gaus', max_value, min_value, mean, std, b) 113 | mse_no_clip, _ = self.__clip_and_mse_mesure(input, tag, stat_id, 'no', max_value, min_value, mean, std, b) 114 | min_mse_id = np.argmin([mse_no_clip, mse_laplace, mse_gaus]) 115 | clippings = ['no clipping', 'laplace', 'gaussian'] 116 | 117 | print("%s - MSE no clipping: %f, laplace: %f, laplace est: %f, gaussian: %f, gaussian est: %f, min mse: %s, bits: %d, std: %f, b: %f" % 118 | (stat_id, mse_no_clip, mse_laplace, mse_laplace_est, mse_gaus, mse_gaus_est, clippings[min_mse_id], self.num_bits, std, b)) 119 | return input 120 | else: 121 | # no clipping 122 | alpha = max_range/2 123 | 124 | delta, min_value = self.alpha2DeltaOffset(alpha, max_value, min_value, mean) 125 | res = self.__gemmlowpQuantize__(input.contiguous(), delta, min_value) 126 | 127 | return res 128 | 129 | def gemmlowpQuantize(self, tensor, min_value=None, max_value=None, mean=None): 130 | # TODO: Average over batch 131 | if min_value is None: 132 | min_value = 
tensor.detach().min() 133 | if max_value is None: 134 | max_value = tensor.detach().max() 135 | 136 | range = max_value - min_value 137 | return self.__gemmlowpQuantize__(tensor, range, min_value) 138 | 139 | @staticmethod 140 | def mse_laplace(b, alpha, num_bits): 141 | return 2 * (b ** 2) * np.exp(-alpha / b) + ((alpha ** 2) / (3 * 2 ** (2 * num_bits))) 142 | 143 | @staticmethod 144 | def mse_gaus(sigma, alpha, num_bits): 145 | clipping_err = (sigma ** 2 + (alpha ** 2)) * (1 - math.erf(alpha / (sigma * np.sqrt(2.0)))) - \ 146 | np.sqrt(2.0 / np.pi) * alpha * sigma * (np.e ** ((-1) * (0.5 * (alpha ** 2)) / sigma ** 2)) 147 | quant_err = (alpha ** 2) / (3 * (2 ** (2 * num_bits))) 148 | return clipping_err + quant_err 149 | 150 | def __clip_and_mse_mesure(self, tensor, tag, stat_id, clip_type, max_value, min_value, mean, std, b): 151 | if clip_type == 'laplace': 152 | alpha = self.get_alpha_laplace(tensor, stat_id) # laplace clipping 153 | mse_est = IntQuantizer.mse_laplace(b, alpha, self.num_bits) 154 | elif clip_type == 'gaus': 155 | alpha = self.get_alpha_gaus(tensor, tag, stat_id) # gaussian clipping 156 | mse_est = IntQuantizer.mse_gaus(std, alpha, self.num_bits) 157 | else: # no clipping 158 | alpha = (max_value - min_value)/2 159 | mse_est = -1 160 | 161 | delta, min_value = self.alpha2DeltaOffset(alpha, max_value, min_value, mean) 162 | res = self.__gemmlowpQuantize__(tensor.contiguous(), delta, min_value) 163 | mse = torch.mean((tensor - res)**2) 164 | del res 165 | return mse, mse_est 166 | 167 | 168 | def __gemmlowpQuantize__(self, tensor, delta, offset): 169 | if self.stochastic: 170 | # Generate noise for stochastic rounding 171 | noise = tensor.new(tensor.shape).uniform_(-0.5, 0.5) 172 | else: 173 | noise = torch.cuda.FloatTensor(tensor.shape).fill_(0) 174 | 175 | # if enforce_true_zero and zero in range 176 | preserve_zero = self.enforce_true_zero and (offset + delta) > 0 and offset < 0 177 | return int_quantization.float2gemmlowp(tensor.contiguous(), delta, offset, self.num_bits, self.int_exp, preserve_zero, noise) 178 | 179 | def int_quantizer(qtype, quant_params): 180 | if len(qtype) > len('int'): 181 | size = int(qtype[len('int'):]) 182 | else: 183 | size = 32 184 | 185 | return IntQuantizer(size, quant_params) 186 | -------------------------------------------------------------------------------- /pytorch_quantizer/quantization/quantization_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import qtypes 3 | from utils.misc import Singleton 4 | from utils import attacher 5 | from utils.monitor import Monitor 6 | import abc 7 | 8 | INFERENCE_ONLY = False 9 | 10 | class QuantizationManagerBase(metaclass=Singleton): 11 | def __init__(self): 12 | pass 13 | 14 | def __enter__(self): 15 | self.enable() 16 | return self 17 | 18 | def __exit__(self, *args): 19 | self.disable() 20 | 21 | @abc.abstractclassmethod 22 | def createTruncationManager(self, args, qparams): 23 | return 24 | 25 | def enable(self): 26 | if self.quantize: 27 | self.enabled = True 28 | self.op_manager.enable() 29 | 30 | def disable(self): 31 | self.enabled = False 32 | self.op_manager.disable() 33 | 34 | def reload(self, args, qparams={}): 35 | self.disable() 36 | self.op_manager = self.createTruncationManager(args, qparams) 37 | self.enable() 38 | 39 | def reduce_logging_verbosity(self): 40 | self.op_manager.verbose = False 41 | 42 | def quantize_tensor(self, tensor, fprop=True, bprop=True, quantize_tensor=False): 43 | if self.enabled and quantize_tensor: 44 | return self.op_manager.quantize_tensor(tensor, fprop, bprop) 45 | else: 46 | return tensor 47 | 48 | def quantize_fprop(self, tensor): 49 | return self.quantize_tensor(tensor, fprop=True, bprop=False) 50 | 51 | def quantize_bprop(self, tensor): 52 | return self.quantize_tensor(tensor, fprop=False, bprop=True) 53 | 54 | def quantize_instant(self, tensor, tag="", quantize_tensor=False): 55 | if self.quantize and quantize_tensor: 56 | return self.op_manager.quantize_instant(tensor, tag) 57 | else: 58 | return tensor 59 | 60 | class QuantizationManager(QuantizationManagerBase): 61 | def __init__(self, args, qparams): 62 | super(QuantizationManager, self).__init__() 63 | self.inference_only = INFERENCE_ONLY 64 | self.dual_precision = False 65 | self.quantize_batchnorm = args.quantize_bn 66 | self.quantize = args.qtype_fprop is not None or args.qtype_bprop is not None or args.qtype is not None 67 | self.op_manager = TruncationOpManager(args, qparams, self.inference_only, self.dual_precision) 68 | self.enabled = False 69 | 70 | def createTruncationManager(self, args, qparams): 71 | return TruncationOpManager(args, qparams) 72 | 73 | 74 | class TruncationOpManager: 75 | def __load_quantizer__(self, qtype, qparams): 76 | qtype_name = qtype.rstrip('1234567890') 77 | quant_params = qparams[qtype_name] if qtype_name in qparams else {} 78 | quantizer = qtypes.__dict__[qtype_name + "_quantizer"](qtype, quant_params) 79 | return quantizer, quant_params 80 | 81 | def __init__(self, args, qparams, inference_only=False, dual_precision=False): 82 | self.inference_only = inference_only 83 | self.dual_precision = dual_precision 84 | self.verbose = False 85 | self.bprop_quantizer = self.fprop_quantizer = None 86 | 87 | self.origin_matmul = torch.Tensor.matmul 88 | self.origin_linear = torch.nn.functional.linear 89 | self.origin_conv2d = torch.nn.functional.conv2d 90 | 91 | if args.qtype_fprop is not None: 92 | self.quantize = True 93 | self.fprop_quantizer, self.fprop_qparams = self.__load_quantizer__(args.qtype_fprop, qparams) 94 | if args.qtype_bprop is not None: 95 | self.quantize = True 96 | self.bprop_quantizer, self.bprop_qparams = self.__load_quantizer__(args.qtype_bprop, qparams) 97 | 98 | if args.qtype_fprop is None and args.qtype_bprop is None and args.qtype is not None: 99 | self.quantize = True 100 | self.fprop_quantizer, self.fprop_qparams = self.__load_quantizer__(args.qtype, qparams) 101 | self.bprop_quantizer, self.bprop_qparams = 
self.__load_quantizer__(args.qtype, qparams) 102 | 103 | def enable(self): 104 | self.quantize_matmul() 105 | # self.quantize_linear() 106 | self.quantize_conv2d() 107 | 108 | def disable(self): 109 | torch.Tensor.matmul = self.origin_matmul 110 | torch.nn.functional.linear = self.origin_linear 111 | torch.nn.functional.conv2d = self.origin_conv2d 112 | 113 | # quantizes origin matmul 114 | def quantize_matmul(self): 115 | def quantized_matmul(tensor1, tensor2): 116 | assert False 117 | tensor1_ = attacher.pytorch_attach(tensor1, self.fprop_quantizer, self.bprop_quantizer) 118 | tensor2_ = attacher.pytorch_attach(tensor2, self.fprop_quantizer, self.bprop_quantizer) 119 | res = self.origin_matmul(tensor1_, tensor2_) 120 | return attacher.pytorch_attach(res, self.fprop_quantizer, self.bprop_quantizer) 121 | 122 | torch.Tensor.matmul = quantized_matmul 123 | 124 | # quantizes origin linear 125 | def quantize_linear(self): 126 | def quantized_linear(input, weight, bias=None): 127 | if self.inference_only: 128 | weight_ = self.quantize_instant(weight, "weight") 129 | res = self.origin_linear(input, weight_, bias) 130 | return self.quantize_instant(res, "activation_linear") 131 | elif self.dual_precision: 132 | return self.dual_prec_linear(input, weight, bias) 133 | else: 134 | input_ = attacher.pytorch_attach(input, self.fprop_quantizer, self.bprop_quantizer, tag='activation/in') 135 | weight_ = attacher.pytorch_attach(weight, self.fprop_quantizer, self.bprop_quantizer, tag='weight') 136 | if bias is not None: 137 | bias_ = attacher.pytorch_attach(bias, self.fprop_quantizer, self.bprop_quantizer, tag='bias') 138 | else: 139 | bias_ = bias 140 | 141 | res = self.origin_linear(input_, weight_, bias_) 142 | return attacher.pytorch_attach(res, self.fprop_quantizer, self.bprop_quantizer, tag='activation_linear') 143 | 144 | torch.nn.functional.linear = quantized_linear 145 | 146 | # quantizes origin conv2d 147 | def quantize_conv2d(self): 148 | def quantized_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 149 | if self.inference_only: 150 | weight_ = self.quantize_instant(weight, "weight") 151 | return self.origin_conv2d(input, weight_, bias, stride, padding, dilation, groups) 152 | elif self.dual_precision: 153 | return self.dual_prec_conv2d(input, weight, bias, stride, padding, dilation, groups) 154 | else: 155 | input_ = attacher.pytorch_attach(input, self.fprop_quantizer, self.bprop_quantizer, tag='activation/in') 156 | weight_ = attacher.pytorch_attach(weight, self.fprop_quantizer, self.bprop_quantizer, tag='weight') 157 | if bias is not None: 158 | bias_ = attacher.pytorch_attach(bias, self.fprop_quantizer, self.bprop_quantizer, tag='bias') 159 | else: 160 | bias_ = bias 161 | 162 | res = self.origin_conv2d(input_, weight_, bias_, stride, padding, dilation, groups) 163 | return attacher.pytorch_attach(res, self.fprop_quantizer, self.bprop_quantizer, tag='activation') 164 | 165 | torch.nn.functional.conv2d = quantized_conv2d 166 | 167 | def quantize_tensor(self, tensor, fprop=True, bprop=True): 168 | fprop = self.fprop_quantizer if fprop else None 169 | bprop = self.bprop_quantizer if bprop else None 170 | return attacher.pytorch_attach(tensor, fprop, bprop) 171 | 172 | def quantize_instant(self, tensor, tag=""): 173 | return self.fprop_quantizer(tensor, tag) 174 | 175 | def dual_prec_conv2d(self, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 176 | # fprop conv2d quantized by fprop_quantizer 177 | input_fprop = attacher.pytorch_attach(input, 
self.fprop_quantizer, None, tag='activation/in') 178 | weight_fprop = attacher.pytorch_attach(weight, self.fprop_quantizer, None, tag='weight') 179 | if bias is not None: 180 | bias_fprop = attacher.pytorch_attach(bias, self.fprop_quantizer, None, tag='bias') 181 | else: 182 | bias_fprop = bias 183 | conv_fprop = self.origin_conv2d(input_fprop, weight_fprop, bias_fprop, stride, padding, dilation, groups) 184 | conv_fprop = attacher.pytorch_attach(conv_fprop, self.fprop_quantizer, None, tag='activation') 185 | 186 | # bprop conv2d quantized by bprop_quantizer 187 | input_bprop = attacher.pytorch_attach(input, None, self.bprop_quantizer, tag='activation/in') 188 | weight_bprop = attacher.pytorch_attach(weight, None, self.bprop_quantizer, tag='weight') 189 | if bias is not None: 190 | bias_bprop = attacher.pytorch_attach(bias, None, self.bprop_quantizer, tag='bias') 191 | else: 192 | bias_bprop = bias 193 | conv_bprop = self.origin_conv2d(input_bprop, weight_bprop, bias_bprop, stride, padding, dilation, groups) 194 | conv_bprop = attacher.pytorch_attach(conv_bprop, None, self.bprop_quantizer, tag='activation') 195 | return conv_fprop.detach() + conv_bprop - conv_bprop.detach()  # value comes from the fprop path; gradients flow only through the bprop path 196 | 197 | def dual_prec_linear(self, input, weight, bias=None): 198 | # fprop linear quantized by fprop_quantizer 199 | input_fprop = attacher.pytorch_attach(input, self.fprop_quantizer, None, tag='activation/in') 200 | weight_fprop = attacher.pytorch_attach(weight, self.fprop_quantizer, None, tag='weight') 201 | if bias is not None: 202 | bias_fprop = attacher.pytorch_attach(bias, self.fprop_quantizer, None, tag='bias') 203 | else: 204 | bias_fprop = bias 205 | linear_fprop = self.origin_linear(input_fprop, weight_fprop, bias_fprop) 206 | linear_fprop = attacher.pytorch_attach(linear_fprop, self.fprop_quantizer, None, tag='activation_linear') 207 | 208 | # bprop linear quantized by bprop_quantizer 209 | input_bprop = attacher.pytorch_attach(input, None, self.bprop_quantizer, tag='activation/in') 210 | weight_bprop = attacher.pytorch_attach(weight, None, self.bprop_quantizer, tag='weight') 211 | if bias is not None: 212 | bias_bprop = attacher.pytorch_attach(bias, None, self.bprop_quantizer, tag='bias') 213 | else: 214 | bias_bprop = bias 215 | linear_bprop = self.origin_linear(input_bprop, weight_bprop, bias_bprop) 216 | linear_bprop = attacher.pytorch_attach(linear_bprop, None, self.bprop_quantizer, tag='activation_linear') 217 | return linear_fprop.detach() + linear_bprop - linear_bprop.detach()  # same straight-through combination as dual_prec_conv2d 218 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submission2019/AnalyticalScaleForIntegerQuantization/3246ee8cbfb747d7ef821c8cecc50283a73eaf92/utils/__init__.py -------------------------------------------------------------------------------- /utils/absorb_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def absorb_bn(module, bn_module): 6 | w = module.weight.data 7 | if module.bias is None: 8 | zeros = torch.Tensor(module.out_channels).zero_().type(w.type()) 9 | module.bias = nn.Parameter(zeros) 10 | b = module.bias.data 11 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5) 12 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) 13 | b.add_(-bn_module.running_mean).mul_(invstd) 14 | 15 | if bn_module.affine: 16 |
w.mul_(bn_module.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) 17 | b.mul_(bn_module.weight.data).add_(bn_module.bias.data) 18 | 19 | bn_module.register_buffer('running_mean', torch.zeros(module.out_channels).cuda()) 20 | bn_module.register_buffer('running_var', torch.ones(module.out_channels).cuda()) 21 | bn_module.register_parameter('weight', None) 22 | bn_module.register_parameter('bias', None) 23 | bn_module.affine = False 24 | 25 | 26 | def is_bn(m): 27 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 28 | 29 | 30 | def is_absorbing(m): 31 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) 32 | 33 | 34 | def search_absorbe_bn(model): 35 | prev = None 36 | for m in model.children(): 37 | if is_bn(m) and is_absorbing(prev): 38 | absorb_bn(prev, m) 39 | search_absorbe_bn(m) 40 | prev = m 41 | -------------------------------------------------------------------------------- /utils/attacher.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | 3 | # f is any callable object 4 | 5 | # attacher to forward 6 | class attach_to_forward_class(Function): 7 | @staticmethod 8 | def forward(ctx, tensor, f, tag): 9 | # print('forward') 10 | # we want the output to have a different id from the input 11 | ctx.tag = tag 12 | return 1*f(tensor, tag) 13 | 14 | @staticmethod 15 | def backward(ctx, grad_output): 16 | return grad_output, None, None 17 | 18 | 19 | # attacher to backward 20 | class attach_to_backward_class(Function): 21 | @staticmethod 22 | def forward(ctx, tensor, f, tag): 23 | ctx.f = f 24 | ctx.tag = tag 25 | return 1*tensor 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | # print('backward') 30 | f = ctx.f 31 | return f(grad_output, ctx.tag), None, None 32 | 33 | 34 | # attacher to both forward and backward 35 | class attach_to_forward_backward_class(Function): 36 | @staticmethod 37 | def forward(ctx, tensor, f, b, tag): 38 | # print('forward') 39 | ctx.f = f 40 | ctx.b = b 41 | ctx.tag = tag 42 | return f(tensor, tag) 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | # print('backward') 47 | return ctx.b(grad_output, ctx.tag), None, None, None 48 | 49 | 50 | # attacher to forward and backward 51 | def pytorch_attach(tensor, f=None, b=None, tag=''): 52 | if f is not None and b is not None: 53 | tensor = attach_to_forward_backward_class.apply(tensor, f, b, tag) 54 | elif f is not None: 55 | tensor = attach_to_forward_class.apply(tensor, f, tag) 56 | elif b is not None: 57 | tensor = attach_to_backward_class.apply(tensor, b, tag) 58 | return tensor 59 | 60 | 61 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from numpy.random import choice 4 | 5 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler): 6 | """Samples elements randomly, with replacement.
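One pass yields len(data_source) i.i.d. draws, so an index can repeat within an epoch while others are skipped.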
7 | Arguments: 8 | data_source (Dataset): dataset to sample from 9 | """ 10 | 11 | def __init__(self, data_source): 12 | self.num_samples = len(data_source) 13 | 14 | def __iter__(self): 15 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True))) 16 | 17 | def __len__(self): 18 | return self.num_samples 19 | 20 | 21 | class LimitDataset(Dataset): 22 | 23 | def __init__(self, dset, max_len): 24 | self.dset = dset 25 | self.max_len = max_len 26 | 27 | def __len__(self): 28 | return min(len(self.dset), self.max_len) 29 | 30 | def __getitem__(self, index): 31 | return self.dset[index] 32 | 33 | class ByClassDataset(Dataset): 34 | 35 | def __init__(self, ds): 36 | self.dataset = ds 37 | self.idx_by_class = {} 38 | for idx, (_, c) in enumerate(ds): 39 | self.idx_by_class.setdefault(c, []) 40 | self.idx_by_class[c].append(idx) 41 | 42 | def __len__(self): 43 | return min([len(d) for d in self.idx_by_class.values()]) 44 | 45 | def __getitem__(self, idx): 46 | idx_per_class = [self.idx_by_class[c][idx] 47 | for c in range(len(self.idx_by_class))] 48 | labels = torch.LongTensor([self.dataset[i][1] 49 | for i in idx_per_class]) 50 | items = [self.dataset[i][0] for i in idx_per_class] 51 | if torch.is_tensor(items[0]): 52 | items = torch.stack(items) 53 | 54 | return (items, labels) 55 | 56 | 57 | class IdxDataset(Dataset): 58 | """docstring for IdxDataset.""" 59 | 60 | def __init__(self, dset): 61 | super(IdxDataset, self).__init__() 62 | self.dset = dset 63 | self.idxs = range(len(self.dset)) 64 | 65 | def __getitem__(self, idx): 66 | data, labels = self.dset[self.idxs[idx]] 67 | return (idx, data, labels) 68 | 69 | def __len__(self): 70 | return len(self.idxs) 71 | -------------------------------------------------------------------------------- /utils/dump_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | from utils.misc import Singleton 5 | import os 6 | import shutil 7 | import uuid 8 | 9 | class DumpManager(metaclass=Singleton): 10 | def __init__(self, dump_dir=None): 11 | if dump_dir is None: 12 | raise Exception('dump_dir must be provided') 13 | 14 | self.dump_dir = dump_dir 15 | if os.path.exists(dump_dir): 16 | shutil.rmtree(dump_dir) 17 | os.makedirs(dump_dir) 18 | self.enabled = False 19 | self.tag = '' 20 | 21 | def __enter__(self): 22 | self.enabled = True 23 | return self 24 | 25 | def __exit__(self, *args): 26 | self.enabled = False 27 | 28 | def set_tag(self, tag): 29 | self.tag = tag 30 | 31 | def dump(self, tensor, name): 32 | if self.enabled: 33 | f = os.path.join(self.dump_dir, name + '_' + self.tag) 34 | print("dumping: %s" % f) 35 | np.save(f, tensor.cpu().numpy()) 36 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from itertools import cycle 4 | import torch 5 | import logging.config 6 | from datetime import datetime 7 | import json 8 | 9 | import pandas as pd 10 | from bokeh.io import output_file, save, show 11 | from bokeh.plotting import figure 12 | from bokeh.layouts import column 13 | from bokeh.models import Div 14 | 15 | try: 16 | import hyperdash 17 | HYPERDASH_AVAILABLE = True 18 | except ImportError: 19 | HYPERDASH_AVAILABLE = False 20 | 21 | 22 | def export_args_namespace(args, filename): 23 | """ 24 | args: argparse.Namespace 25 | arguments to 
save 26 | filename: string 27 | filename to save at 28 | """ 29 | with open(filename, 'w') as fp: 30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) 31 | 32 | class logfile_filter: 33 | def filter(self, record): 34 | return record.levelname == 'DEBUG' 35 | 36 | def setup_logging(log_file='log.txt', resume=False): 37 | """ 38 | Setup logging configuration 39 | """ 40 | if os.path.isfile(log_file) and resume: 41 | file_mode = 'a' 42 | else: 43 | file_mode = 'w' 44 | 45 | root_logger = logging.getLogger() 46 | if root_logger.handlers: 47 | root_logger.removeHandler(root_logger.handlers[0]) 48 | logging.basicConfig(level=logging.DEBUG, 49 | format="%(asctime)s - %(levelname)s - %(message)s", 50 | datefmt="%Y-%m-%d %H:%M:%S", 51 | filename=log_file, 52 | filemode=file_mode) 53 | console = logging.StreamHandler() 54 | console.setLevel(logging.INFO) 55 | formatter = logging.Formatter('%(message)s') 56 | console.setFormatter(formatter) 57 | logging.getLogger('').addHandler(console) 58 | 59 | handler = logging.FileHandler(os.path.join(os.path.dirname(log_file), "quantizer-debug.log"), "w") 60 | handler.setLevel(logging.DEBUG) 61 | formatter = logging.Formatter("%(message)s") 62 | handler.setFormatter(formatter) 63 | handler.addFilter(logfile_filter()) 64 | logging.getLogger('').addHandler(handler) 65 | 66 | 67 | class ResultsLog(object): 68 | 69 | supported_data_formats = ['csv', 'json'] 70 | 71 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'): 72 | """ 73 | Parameters 74 | ---------- 75 | path: string 76 | path to directory to save data files 77 | plot_path: string 78 | path to directory to save plot files 79 | title: string 80 | title of HTML file 81 | params: Namespace 82 | optionally save parameters for results 83 | resume: bool 84 | resume previous logging 85 | data_format: str('csv'|'json') 86 | which file format to use to save the data 87 | """ 88 | if data_format not in ResultsLog.supported_data_formats: 89 | raise ValueError('data_format must of the following: ' + 90 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats])) 91 | 92 | if data_format == 'json': 93 | self.data_path = '{}.json'.format(path) 94 | else: 95 | self.data_path = '{}.csv'.format(path) 96 | if params is not None: 97 | export_args_namespace(params, '{}.json'.format(path)) 98 | self.plot_path = '{}.html'.format(path) 99 | self.results = None 100 | self.clear() 101 | self.first_save = True 102 | if os.path.isfile(self.data_path): 103 | if resume: 104 | self.load(self.data_path) 105 | self.first_save = False 106 | else: 107 | os.remove(self.data_path) 108 | self.results = pd.DataFrame() 109 | else: 110 | self.results = pd.DataFrame() 111 | 112 | self.title = title 113 | self.data_format = data_format 114 | 115 | if HYPERDASH_AVAILABLE: 116 | name = self.title if title != '' else path 117 | self.hd_experiment = hyperdash.Experiment(name) 118 | if params is not None: 119 | for k, v in params._get_kwargs(): 120 | self.hd_experiment.param(k, v, log=False) 121 | 122 | def clear(self): 123 | self.figures = [] 124 | 125 | def add(self, **kwargs): 126 | """Add a new row to the dataframe 127 | example: 128 | resultsLog.add(epoch=epoch_num, train_loss=loss, 129 | test_loss=test_loss) 130 | """ 131 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 132 | self.results = self.results.append(df, ignore_index=True) 133 | if hasattr(self, 'hd_experiment'): 134 | for k, v in kwargs.items(): 135 | self.hd_experiment.metric(k, v, log=False) 136 | 137 | def 
smooth(self, column_name, window): 138 | """Select an entry to smooth over time""" 139 | # TODO: smooth only new data 140 | smoothed_column = self.results[column_name].rolling( 141 | window=window, center=False).mean() 142 | self.results[column_name + '_smoothed'] = smoothed_column 143 | 144 | def save(self, title=None): 145 | """save the data file and render accumulated figures to HTML. 146 | Parameters 147 | ---------- 148 | title: string 149 | title of the HTML file 150 | """ 151 | title = title or self.title 152 | if len(self.figures) > 0: 153 | if os.path.isfile(self.plot_path): 154 | os.remove(self.plot_path) 155 | if self.first_save: 156 | self.first_save = False 157 | logging.info('Plot file saved at: {}'.format( 158 | os.path.abspath(self.plot_path))) 159 | 160 | output_file(self.plot_path, title=title) 161 | plot = column( 162 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures) 163 | save(plot) 164 | self.clear() 165 | 166 | if self.data_format == 'json': 167 | self.results.to_json(self.data_path, orient='records', lines=True) 168 | else: 169 | self.results.to_csv(self.data_path, index=False, index_label=False) 170 | 171 | def load(self, path=None): 172 | """load the data file 173 | Parameters 174 | ---------- 175 | path: 176 | path to load the json|csv file from 177 | """ 178 | path = path or self.data_path 179 | if os.path.isfile(path): 180 | if self.data_format == 'json': 181 | self.results = pd.read_json(path, orient='records', lines=True) 182 | else: 183 | self.results = pd.read_csv(path) 184 | else: 185 | raise ValueError("{} isn't a file".format(path)) 186 | 187 | def show(self, title=None): 188 | title = title or self.title 189 | if len(self.figures) > 0: 190 | plot = column(
191 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures) 192 | show(plot) 193 | 194 | def plot(self, x, y, title=None, xlabel=None, ylabel=None, legend=None, 195 | width=800, height=400, line_width=2, 196 | colors=['red', 'green', 'blue', 'orange', 197 | 'black', 'purple', 'brown'], 198 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save'): 199 | """ 200 | add a new plot to the HTML file 201 | example: 202 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 203 | title='Loss', ylabel='loss') 204 | """ 205 | if not isinstance(y, list): 206 | y = [y] 207 | xlabel = xlabel or x 208 | legend = legend or y 209 | assert len(legend) == len(y) 210 | f = figure(title=title, tools=tools, 211 | width=width, height=height, 212 | x_axis_label=xlabel or x, 213 | y_axis_label=ylabel or '') 214 | colors = cycle(colors) 215 | for i, yi in enumerate(y): 216 | f.line(self.results[x], self.results[yi], 217 | line_width=line_width, 218 | line_color=next(colors), legend=legend[i]) 219 | f.legend.click_policy = "hide" 220 | self.figures.append(f) 221 | 222 | def image(self, *kargs, **kwargs): 223 | fig = figure() 224 | fig.image(*kargs, **kwargs) 225 | self.figures.append(fig) 226 | 227 | def end(self): 228 | if hasattr(self, 'hd_experiment'): 229 | self.hd_experiment.end() 230 | 231 | 232 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 233 | filename = os.path.join(path, filename) 234 | torch.save(state, filename) 235 | if is_best: 236 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 237 | if save_all: 238 | shutil.copyfile(filename, os.path.join( 239 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 240 | 241 | class EvalLog: 242 | def __init__(self, headers, f_name=None, auto_save=False): 243 | if auto_save and f_name is None: 244 | raise Exception('auto_save option requires a file name') 245 | self.df = pd.DataFrame(columns=headers) 246 | self.file_name = f_name 247 | self.auto_save = auto_save 248 | 249 | def log(self, *kargs): 250 | v = {} 251 | for i, arg in enumerate(kargs): 252 | v[self.df.columns[i]] = arg 253 | self.df.loc[len(self.df)] = ([arg for arg in kargs]) 254 | if self.auto_save: 255 | self.df.to_csv(self.file_name, index=False) 256 | 257 | def save(self, fpath): 258 | if not self.auto_save: 259 | self.df.to_csv(fpath, index=False) 260 | 261 | def __str__(self): 262 | return self.df.__str__() -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class OnlineMeter(object): 24 | """Computes and stores the average and variance/std values of tensor""" 25 | 26 | def __init__(self): 27 | self.mean = torch.FloatTensor(1).fill_(-1) 28 | self.M2 = torch.FloatTensor(1).zero_() 29 | self.count = 0. 30 | self.needs_init = True 31 | 32 | def reset(self, x): 33 | self.mean = x.new(x.size()).zero_() 34 | self.M2 = x.new(x.size()).zero_() 35 | self.count = 0.
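# mean and M2 (the running sum of squared deviations) are maintained with Welford's online algorithm in update() below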
36 | self.needs_init = False 37 | 38 | def update(self, x): 39 | self.val = x 40 | if self.needs_init: 41 | self.reset(x) 42 | self.count += 1 43 | delta = x - self.mean 44 | self.mean.add_(delta / self.count) 45 | delta2 = x - self.mean 46 | self.M2.add_(delta * delta2) 47 | 48 | @property 49 | def var(self): 50 | if self.count < 2: 51 | return self.M2.clone().zero_() 52 | return self.M2 / (self.count - 1) 53 | 54 | @property 55 | def std(self): 56 | return self.var.sqrt()  # var is a property, not a callable 57 | 58 | 59 | def accuracy(output, target, topk=(1,)): 60 | """Computes the precision@k for the specified values of k""" 61 | maxk = max(topk) 62 | batch_size = target.size(0) 63 | 64 | _, pred = output.topk(maxk, 1, True, True) 65 | pred = pred.t().type_as(target) 66 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 67 | 68 | res = [] 69 | for k in topk: 70 | correct_k = correct[:k].view(-1).float().sum(0) 71 | res.append(correct_k.mul_(100.0 / batch_size)) 72 | return res 73 | 74 | 75 | class AccuracyMeter(object): 76 | """Computes and stores the average and current topk accuracy""" 77 | 78 | def __init__(self, topk=(1,)): 79 | self.topk = topk 80 | self.reset() 81 | 82 | def reset(self): 83 | self._meters = {} 84 | for k in self.topk: 85 | self._meters[k] = AverageMeter() 86 | 87 | def update(self, output, target): 88 | n = target.nelement() 89 | acc_vals = accuracy(output, target, self.topk) 90 | for i, k in enumerate(self.topk): 91 | self._meters[k].update(acc_vals[i]) 92 | 93 | @property 94 | def val(self): 95 | return {n: meter.val for (n, meter) in self._meters.items()} 96 | 97 | @property 98 | def avg(self): 99 | return {n: meter.avg for (n, meter) in self._meters.items()} 100 | 101 | @property 102 | def avg_error(self): 103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()} 104 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | torch_dtypes = { 6 | 'float': torch.float, 7 | 'float32': torch.float32, 8 | 'float64': torch.float64, 9 | 'double': torch.double, 10 | 'float16': torch.float16, 11 | 'half': torch.half, 12 | 'uint8': torch.uint8, 13 | 'int8': torch.int8, 14 | 'int16': torch.int16, 15 | 'short': torch.short, 16 | 'int32': torch.int32, 17 | 'int': torch.int, 18 | 'int64': torch.int64, 19 | 'long': torch.long 20 | } 21 | 22 | 23 | def onehot(indexes, N=None, ignore_index=None): 24 | """ 25 | Creates a one-hot representation of indexes with N possible entries 26 | if N is not specified, it is set to the maximum index appearing + 1.
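e.g. onehot(torch.LongTensor([1, 0]), N=3) -> [[0, 1, 0], [1, 0, 0]] (illustrative)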
27 | indexes is a long-tensor of indexes 28 | ignore_index will be zero in onehot representation 29 | """ 30 | if N is None: 31 | N = indexes.max() + 1 32 | sz = list(indexes.size()) 33 | output = indexes.new().byte().resize_(*sz, N).zero_() 34 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 35 | if ignore_index is not None and ignore_index >= 0: 36 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 37 | return output 38 | 39 | 40 | def set_global_seeds(i): 41 | try: 42 | import torch 43 | except ImportError: 44 | pass 45 | else: 46 | torch.manual_seed(i) 47 | if torch.cuda.is_available(): 48 | torch.cuda.manual_seed_all(i) 49 | np.random.seed(i) 50 | random.seed(i) 51 | 52 | # The following is for monitoring 53 | class Singleton(type): 54 | _instances = {} 55 | 56 | def __call__(cls, *args, **kwargs): 57 | if cls not in cls._instances: 58 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 59 | return cls._instances[cls] 60 | 61 | 62 | import re 63 | 64 | 65 | def sorted_nicely(l): 66 | """ Sorts the given iterable in the way that is expected. 67 | 68 | Required arguments: 69 | l -- The iterable to be sorted. 70 | 71 | """ 72 | convert = lambda text: int(text) if text.isdigit() else text 73 | alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] 74 | return sorted(l, key=alphanum_key) 75 | -------------------------------------------------------------------------------- /utils/monitor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | from utils.misc import Singleton 5 | import os 6 | import shutil 7 | import uuid 8 | 9 | 10 | def patch_call(instance, func): 11 | class _(type(instance)): 12 | def __call__(self, *arg, **kwarg): 13 | return func(*arg, **kwarg) 14 | instance.__class__ = _ 15 | 16 | 17 | 18 | class Monitor(metaclass=Singleton): 19 | def __init__(self, dump_dir=None): 20 | if dump_dir is None: 21 | raise Exception('dump_dir must be provided') 22 | 23 | self.dump_dir = dump_dir 24 | if os.path.exists(dump_dir): 25 | shutil.rmtree(dump_dir) 26 | os.makedirs(dump_dir) 27 | 28 | self.observed_tensors = dict() 29 | self.observed_operations = dict() 30 | 31 | def register_tensor(self, tensor, key, retain_grad=False): 32 | if retain_grad: 33 | tensor.retain_grad() 34 | self.observed_tensors[key] = tensor 35 | 36 | def dump_tensors(self, epoch, step): 37 | grad_keys = [] 38 | for key in self.observed_tensors.keys(): 39 | tensor = self.observed_tensors[key] 40 | if tensor.grad is not None: 41 | grad_keys.append(key) 42 | for key in grad_keys: 43 | self.observed_tensors[key + '_grad'] = self.observed_tensors[key].grad 44 | self.observed_tensors[key] = self.observed_tensors[key].detach() 45 | for key in self.observed_tensors.keys(): 46 | self.observed_tensors[key] = self.observed_tensors[key].cpu() 47 | fname = 'epoch_' + str(epoch) + '_step_' + str(step) + '.pt' 48 | torch.save(self.observed_tensors, os.path.join(self.dump_dir, fname)) 49 | self.observed_tensors.clear() 50 | 51 | def clear_tensors(self): 52 | self.observed_tensors.clear() 53 | 54 | def register_operation(self, operation, key): 55 | self.observed_operations[key] = operation 56 | 57 | def dump_operations(self, epoch, step): 58 | for op_key in self.observed_operations.keys(): 59 | grad_keys = [] 60 | operation = self.observed_operations[op_key] 61 | for key in operation.keys(): 62 | if isinstance(operation[key], torch.Tensor): 63 | tensor = 
operation[key] 64 | if tensor.grad is not None: 65 | grad_keys.append(key) 66 | for key in grad_keys: 67 | operation[key + '_grad'] = operation[key].grad 68 | operation[key] = operation[key].detach() 69 | for key in operation.keys(): 70 | if isinstance(operation[key], torch.Tensor): 71 | operation[key] = operation[key].cpu() 72 | 73 | fname = 'epoch_' + str(epoch) + '_step_' + str(step) + '.pt' 74 | torch.save(self.observed_operations, os.path.join(self.dump_dir, fname)) 75 | self.observed_operations.clear() 76 | 77 | def clear_operations(self): 78 | self.observed_operations.clear() 79 | 80 | def register_Conv2d(self, Conv2d, retain_grad=False): 81 | Conv2d_dict = self.observed_operations[id(Conv2d)] = dict() 82 | Conv2d_dict['in_channels'] = Conv2d.in_channels 83 | Conv2d_dict['out_channels'] = Conv2d.out_channels 84 | Conv2d_dict['kernel_size'] = Conv2d.kernel_size 85 | Conv2d_dict['stride'] = Conv2d.stride 86 | Conv2d_dict['padding'] = Conv2d.padding 87 | Conv2d_dict['dilation'] = Conv2d.dilation 88 | Conv2d_dict['groups'] = Conv2d.groups 89 | if Conv2d.bias is not None: 90 | Conv2d_dict['bias'] = Conv2d.bias 91 | Conv2d_dict['weight'] = Conv2d.weight 92 | __call__ = Conv2d.__call__ 93 | 94 | def call_wrapper(input): 95 | Conv2d_dict['input'] = input 96 | if retain_grad: 97 | input.retain_grad() 98 | 99 | output = __call__(input) 100 | 101 | Conv2d_dict['output'] = output 102 | if retain_grad: 103 | output.retain_grad() 104 | 105 | return output 106 | 107 | patch_call(Conv2d, call_wrapper) 108 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging.config 3 | from copy import deepcopy 4 | from six import string_types 5 | 6 | 7 | def eval_func(f, x): 8 | if isinstance(f, string_types): 9 | f = eval(f) 10 | return f(x) 11 | 12 | 13 | class OptimRegime(object): 14 | """ 15 | Reconfigures the optimizer according to setting list. 16 | Exposes optimizer methods - state, step, zero_grad, add_param_group 17 | 18 | Examples for regime: 19 | 20 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3}, 21 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 22 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 23 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 24 | ]" 25 | 2) 26 | "[{'step_lambda': 27 | "lambda t: { 28 | 'optimizer': 'Adam', 29 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 30 | 'betas': (0.9, 0.98), 'eps':1e-9} 31 | }]" 32 | """ 33 | 34 | def __init__(self, params, regime): 35 | self.optimizer = torch.optim.SGD(params, lr=0) 36 | self.regime = regime 37 | self.current_regime_phase = None 38 | self.setting = {} 39 | 40 | def update(self, epoch, train_steps): 41 | """adjusts optimizer according to current epoch or steps and training regime. 42 | """ 43 | if self.regime is None: 44 | return 45 | update_optimizer = False 46 | if self.current_regime_phase is None: 47 | update_optimizer = True 48 | setting = {} 49 | # Find the first phase whose start epoch or step has already been reached 50 | for regime_phase, regime_setting in enumerate(self.regime): 51 | start_epoch = regime_setting.get('epoch', 0) 52 | start_step = regime_setting.get('step', 0) 53 | if epoch >= start_epoch or train_steps >= start_step: 54 | self.current_regime_phase = regime_phase 55 | break 56 | if len(self.regime) > self.current_regime_phase + 1: 57 | next_phase = self.current_regime_phase + 1 58 | # Any more regime steps?
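# peek at the next phase's start epoch/step (default inf) and advance once the current epoch or step reaches it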
59 | start_epoch = self.regime[next_phase].get('epoch', float('inf')) 60 | start_step = self.regime[next_phase].get('step', float('inf')) 61 | if epoch >= start_epoch or train_steps >= start_step: 62 | self.current_regime_phase = next_phase 63 | update_optimizer = True 64 | 65 | setting = deepcopy(self.regime[self.current_regime_phase]) 66 | 67 | if 'lr_decay_rate' in setting and 'lr' in setting: 68 | decay_steps = setting.get('lr_decay_steps', 100) 69 | if train_steps % decay_steps == 0: 70 | decay_rate = setting['lr_decay_rate'] 71 | setting['lr'] *= decay_rate ** (train_steps / decay_steps) 72 | update_optimizer = True 73 | elif 'step_lambda' in setting: 74 | setting.update(eval_func(setting['step_lambda'], train_steps)) 75 | update_optimizer = True 76 | elif 'epoch_lambda' in setting: 77 | setting.update(eval_func(setting['epoch_lambda'], epoch)) 78 | update_optimizer = True 79 | 80 | if update_optimizer: 81 | self.adjust(setting) 82 | 83 | def adjust(self, setting): 84 | """adjusts optimizer according to a setting dict. 85 | e.g: setting={'optimizer': 'Adam', 'lr': 5e-4} 86 | """ 87 | if 'optimizer' in setting: 88 | optim_method = torch.optim.__dict__[setting['optimizer']] 89 | if not isinstance(self.optimizer, optim_method): 90 | self.optimizer = optim_method(self.optimizer.param_groups) 91 | logging.debug('OPTIMIZER - setting method = %s' % 92 | setting['optimizer']) 93 | for param_group in self.optimizer.param_groups: 94 | for key in param_group.keys(): 95 | if key in setting: 96 | new_val = setting[key] 97 | if new_val != param_group[key]: 98 | logging.debug('OPTIMIZER - setting %s = %s' % 99 | (key, setting[key])) 100 | param_group[key] = setting[key] 101 | self.setting = deepcopy(setting) 102 | 103 | def __getstate__(self): 104 | return { 105 | 'optimizer_state': self.optimizer.__getstate__(), 106 | 'regime': self.regime, 107 | } 108 | 109 | def __setstate__(self, state): 110 | self.regime = state.get('regime') 111 | self.optimizer.__setstate__(state.get('optimizer_state')) 112 | 113 | def state_dict(self): 114 | """Returns the state of the optimizer as a :class:`dict`. 115 | """ 116 | return { 117 | 'optimizer_state': self.optimizer.state_dict(), 118 | 'regime': self.regime, 119 | } 120 | 121 | def load_state_dict(self, state_dict): 122 | """Loads the optimizer state. 123 | 124 | Arguments: 125 | state_dict (dict): optimizer state. Should be an object returned 126 | from a call to :meth:`state_dict`. 127 | """ 128 | # deepcopy, to be consistent with module API 129 | optimizer_state_dict = state_dict['optimizer_state'] 130 | 131 | self.__setstate__({'optimizer_state': optimizer_state_dict, 132 | 'regime': state_dict['regime']}) 133 | 134 | def zero_grad(self): 135 | """Clears the gradients of all optimized :class:`Variable` s.""" 136 | self.optimizer.zero_grad() 137 | 138 | def step(self, closure=None): 139 | """Performs a single optimization step (parameter update). 140 | 141 | Arguments: 142 | closure (callable): A closure that reevaluates the model and 143 | returns the loss. Optional for most optimizers. 144 | """ 145 | self.optimizer.step(closure) 146 | 147 | def add_param_group(self, param_group): 148 | """Add a param group to the :class:`Optimizer` s `param_groups`. 149 | 150 | This can be useful when fine tuning a pre-trained network as frozen layers can be made 151 | trainable and added to the :class:`Optimizer` as training progresses.
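Example (illustrative; model.classifier is a placeholder): regime.add_param_group({'params': model.classifier.parameters(), 'lr': 1e-4})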
152 | 153 | Arguments: 154 | param_group (dict): Specifies what Variables should be optimized along with group 155 | specific optimization options. 156 | """ 157 | self.optimizer.add_param_group(param_group) 158 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | 5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 6 | 'std': [0.229, 0.224, 0.225]} 7 | 8 | __imagenet_pca = { 9 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 10 | 'eigvec': torch.Tensor([ 11 | [-0.5675, 0.7192, 0.4009], 12 | [-0.5808, -0.0045, -0.8140], 13 | [-0.5836, -0.6948, 0.4203], 14 | ]) 15 | } 16 | 17 | 18 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 19 | t_list = [ 20 | transforms.CenterCrop(input_size), 21 | transforms.ToTensor(), 22 | transforms.Normalize(**normalize), 23 | ] 24 | if scale_size != input_size: 25 | t_list = [transforms.Resize(scale_size)] + t_list 26 | 27 | return transforms.Compose(t_list) 28 | 29 | 30 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 31 | t_list = [ 32 | transforms.RandomCrop(input_size), 33 | transforms.ToTensor(), 34 | transforms.Normalize(**normalize), 35 | ] 36 | if scale_size != input_size: 37 | t_list = [transforms.Resize(scale_size)] + t_list 38 | 39 | return transforms.Compose(t_list) 40 | 41 | 42 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 43 | padding = int((scale_size - input_size) / 2) 44 | return transforms.Compose([ 45 | transforms.RandomCrop(input_size, padding=padding), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | transforms.Normalize(**normalize), 49 | ]) 50 | 51 | 52 | def inception_preproccess(input_size, normalize=__imagenet_stats): 53 | return transforms.Compose([ 54 | transforms.RandomResizedCrop(input_size), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | transforms.Normalize(**normalize) 58 | ]) 59 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 60 | return transforms.Compose([ 61 | transforms.RandomResizedCrop(input_size), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | ColorJitter( 65 | brightness=0.4, 66 | contrast=0.4, 67 | saturation=0.4, 68 | ), 69 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 70 | transforms.Normalize(**normalize) 71 | ]) 72 | 73 | 74 | def get_transform(name='imagenet', input_size=None, 75 | scale_size=None, normalize=None, augment=True): 76 | normalize = normalize or __imagenet_stats 77 | if name == 'imagenet': 78 | scale_size = scale_size or 256 79 | input_size = input_size or 224 80 | if augment: 81 | return inception_preproccess(input_size, normalize=normalize) 82 | else: 83 | return scale_crop(input_size=input_size, 84 | scale_size=scale_size, normalize=normalize) 85 | elif 'cifar' in name: 86 | input_size = input_size or 32 87 | if augment: 88 | scale_size = scale_size or 40 89 | return pad_random_crop(input_size, scale_size=scale_size, 90 | normalize=normalize) 91 | else: 92 | scale_size = scale_size or 32 93 | return scale_crop(input_size=input_size, 94 | scale_size=scale_size, normalize=normalize) 95 | elif name == 'mnist': 96 | normalize = {'mean': [0.5], 'std': [0.5]} 97 | input_size = input_size or 28 98 | if augment: 99 | scale_size = scale_size or 32 100 | return
pad_random_crop(input_size, scale_size=scale_size, 101 | normalize=normalize) 102 | else: 103 | scale_size = scale_size or 32 104 | return scale_crop(input_size=input_size, 105 | scale_size=scale_size, normalize=normalize) 106 | 107 | 108 | class Lighting(object): 109 | """Lighting noise(AlexNet - style PCA - based noise)""" 110 | 111 | def __init__(self, alphastd, eigval, eigvec): 112 | self.alphastd = alphastd 113 | self.eigval = eigval 114 | self.eigvec = eigvec 115 | 116 | def __call__(self, img): 117 | if self.alphastd == 0: 118 | return img 119 | 120 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 121 | rgb = self.eigvec.type_as(img).clone()\ 122 | .mul(alpha.view(1, 3).expand(3, 3))\ 123 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 124 | .sum(1).squeeze() 125 | 126 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 127 | 128 | 129 | class Grayscale(object): 130 | 131 | def __call__(self, img): 132 | gs = img.clone() 133 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 134 | gs[1].copy_(gs[0]) 135 | gs[2].copy_(gs[0]) 136 | return gs 137 | 138 | 139 | class Saturation(object): 140 | 141 | def __init__(self, var): 142 | self.var = var 143 | 144 | def __call__(self, img): 145 | gs = Grayscale()(img) 146 | alpha = random.uniform(0, self.var) 147 | return img.lerp(gs, alpha) 148 | 149 | 150 | class Brightness(object): 151 | 152 | def __init__(self, var): 153 | self.var = var 154 | 155 | def __call__(self, img): 156 | gs = img.new().resize_as_(img).zero_() 157 | alpha = random.uniform(0, self.var) 158 | return img.lerp(gs, alpha) 159 | 160 | 161 | class Contrast(object): 162 | 163 | def __init__(self, var): 164 | self.var = var 165 | 166 | def __call__(self, img): 167 | gs = Grayscale()(img) 168 | gs.fill_(gs.mean()) 169 | alpha = random.uniform(0, self.var) 170 | return img.lerp(gs, alpha) 171 | 172 | 173 | class RandomOrder(object): 174 | """ Composes several transforms together in random order. 175 | """ 176 | 177 | def __init__(self, transforms): 178 | self.transforms = transforms 179 | 180 | def __call__(self, img): 181 | if self.transforms is None: 182 | return img 183 | order = torch.randperm(len(self.transforms)) 184 | for i in order: 185 | img = self.transforms[i](img) 186 | return img 187 | 188 | 189 | class ColorJitter(RandomOrder): 190 | 191 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 192 | self.transforms = [] 193 | if brightness != 0: 194 | self.transforms.append(Brightness(brightness)) 195 | if contrast != 0: 196 | self.transforms.append(Contrast(contrast)) 197 | if saturation != 0: 198 | self.transforms.append(Saturation(saturation)) 199 | --------------------------------------------------------------------------------