├── models ├── __init__.py ├── mobilenet_tf │ └── download_pretrained_mobilenet.sh ├── mobilenet.py ├── linear_quantized_modules.py └── preload_mobilenet.py ├── quantization ├── __init__.py ├── quant_auto.py └── quantop.py ├── data.py ├── utils.py ├── README.md ├── preprocess.py └── main_binary.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenet import * 2 | 3 | -------------------------------------------------------------------------------- /quantization/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantop import * 2 | from .quant_auto import * 3 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | 5 | _IMAGENET_MAIN_PATH = '' 6 | _DATASETS_MAIN_PATH = '' 7 | 8 | _dataset_path = { 9 | 'cifar10': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR10'), 10 | 'cifar100': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR100'), 11 | 'stl10': os.path.join(_DATASETS_MAIN_PATH, 'STL10'), 12 | 'mnist': os.path.join(_DATASETS_MAIN_PATH, 'MNIST'), 13 | 'imagenet': { 14 | 'train': os.path.join(_IMAGENET_MAIN_PATH, 'train'), 15 | 'val': os.path.join(_IMAGENET_MAIN_PATH, 'val') 16 | }, 17 | } 18 | 19 | 20 | def get_dataset(name, split='train', transform=None, 21 | target_transform=None, download=True): 22 | train = (split == 'train') 23 | if name == 'mnist': 24 | return datasets.MNIST( root=_dataset_path['mnist'], 25 | train=train, 26 | transform=transform, 27 | target_transform=target_transform, 28 | download=download) 29 | elif name == 'cifar10': 30 | return datasets.CIFAR10(root=_dataset_path['cifar10'], 31 | train=train, 32 | transform=transform, 33 | target_transform=target_transform, 34 | download=download) 35 | elif name == 'cifar100': 36 | return datasets.CIFAR100(root=_dataset_path['cifar100'], 37 | train=train, 38 | transform=transform, 39 | target_transform=target_transform, 40 | download=download) 41 | elif name == 'imagenet': 42 | path = _dataset_path[name][split] 43 | return datasets.ImageFolder(root=path, 44 | transform=transform, 45 | target_transform=target_transform) 46 | 47 | def get_num_classes(name): 48 | if name == 'mnist': 49 | return 10 50 | elif name == 'cifar10': 51 | return 10 52 | elif name == 'cifar100': 53 | return 100 54 | elif name == 'imagenet': 55 | return 1000 -------------------------------------------------------------------------------- /models/mobilenet_tf/download_pretrained_mobilenet.sh: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | ########### 3 | ########################################################################### 4 | 5 | 6 | # 128 7 | mkdir 128_1.0/ 8 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128.tgz 9 | tar -xvf mobilenet_v1_1.0_128.tgz 10 | mv mobilenet_v1_1.0_128.ckpt.* 128_1.0/ 11 | rm -rf mobilenet_v1_1.0_128* 12 | 13 | mkdir 128_0.75/ 14 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128.tgz 15 | tar -xvf mobilenet_v1_0.75_128.tgz 16 | mv mobilenet_v1_0.75_128.ckpt.* 128_0.75/ 17 | rm -rf mobilenet_v1_0.75_128* 18 | 19 | mkdir 128_0.5/ 20 | wget 
http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128.tgz 21 | tar -xvf mobilenet_v1_0.5_128.tgz 22 | mv mobilenet_v1_0.5_128.ckpt.* 128_0.5/ 23 | rm -rf mobilenet_v1_0.5_128* 24 | 25 | mkdir 128_0.25/ 26 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128.tgz 27 | tar -xvf mobilenet_v1_0.25_128.tgz 28 | mv mobilenet_v1_0.25_128.ckpt.* 128_0.25/ 29 | rm -rf mobilenet_v1_0.25_128* 30 | 31 | 32 | # 160 33 | mkdir 160_1.0/ 34 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160.tgz 35 | tar -xvf mobilenet_v1_1.0_160.tgz 36 | mv mobilenet_v1_1.0_160.ckpt.* 160_1.0/ 37 | rm -rf mobilenet_v1_1.0_160* 38 | 39 | mkdir 160_0.75/ 40 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160.tgz 41 | tar -xvf mobilenet_v1_0.75_160.tgz 42 | mv mobilenet_v1_0.75_160.ckpt.* 160_0.75/ 43 | rm -rf mobilenet_v1_0.75_160* 44 | 45 | mkdir 160_0.5/ 46 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160.tgz 47 | tar -xvf mobilenet_v1_0.5_160.tgz 48 | mv mobilenet_v1_0.5_160.ckpt.* 160_0.5/ 49 | rm -rf mobilenet_v1_0.5_160* 50 | 51 | mkdir 160_0.25/ 52 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160.tgz 53 | tar -xvf mobilenet_v1_0.25_160.tgz 54 | mv mobilenet_v1_0.25_160.ckpt.* 160_0.25/ 55 | rm -rf mobilenet_v1_0.25_160* 56 | 57 | # 192 58 | mkdir 192_1.0/ 59 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192.tgz 60 | tar -xvf mobilenet_v1_1.0_192.tgz 61 | mv mobilenet_v1_1.0_192.ckpt.* 192_1.0/ 62 | rm -rf mobilenet_v1_1.0_192* 63 | 64 | mkdir 192_0.75/ 65 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192.tgz 66 | tar -xvf mobilenet_v1_0.75_192.tgz 67 | mv mobilenet_v1_0.75_192.ckpt.* 192_0.75/ 68 | rm -rf mobilenet_v1_0.75_192* 69 | 70 | mkdir 192_0.5/ 71 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192.tgz 72 | tar -xvf mobilenet_v1_0.5_192.tgz 73 | mv mobilenet_v1_0.5_192.ckpt.* 192_0.5/ 74 | rm -rf mobilenet_v1_0.5_192* 75 | 76 | mkdir 192_0.25/ 77 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192.tgz 78 | tar -xvf mobilenet_v1_0.25_192.tgz 79 | mv mobilenet_v1_0.25_192.ckpt.* 192_0.25/ 80 | rm -rf mobilenet_v1_0.25_192* 81 | 82 | 83 | # 224 84 | mkdir 224_1.0/ 85 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz 86 | tar -xvf mobilenet_v1_1.0_224.tgz 87 | mv mobilenet_v1_1.0_224.ckpt.* 224_1.0/ 88 | rm -rf mobilenet_v1_1.0_224* 89 | 90 | mkdir 224_0.75/ 91 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224.tgz 92 | tar -xvf mobilenet_v1_0.75_224.tgz 93 | mv mobilenet_v1_0.75_224.ckpt.* 224_0.75/ 94 | rm -rf mobilenet_v1_0.75_224* 95 | 96 | mkdir 224_0.5/ 97 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224.tgz 98 | tar -xvf mobilenet_v1_0.5_224.tgz 99 | mv mobilenet_v1_0.5_224.ckpt.* 224_0.5/ 100 | rm -rf mobilenet_v1_0.5_224* 101 | 102 | mkdir 224_0.25/ 103 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224.tgz 104 | tar -xvf mobilenet_v1_0.25_224.tgz 105 | mv mobilenet_v1_0.25_224.ckpt.* 224_0.25/ 106 | rm -rf mobilenet_v1_0.25_224* 107 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging.config 4 | import shutil 5 | import pandas as pd 6 | from bokeh.io import output_file, save, show 7 | from bokeh.plotting import figure 8 | from bokeh.layouts import column 9 | #from bokeh.charts import Line, defaults 10 | # 11 | #defaults.width = 800 12 | #defaults.height = 400 13 | #defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save' 14 | 15 | 16 | def setup_logging(log_file='log.txt'): 17 | """Setup logging configuration 18 | """ 19 | logging.basicConfig(level=logging.DEBUG, 20 | format="%(asctime)s - %(levelname)s - %(message)s", 21 | datefmt="%Y-%m-%d %H:%M:%S", 22 | filename=log_file, 23 | filemode='w') 24 | console = logging.StreamHandler() 25 | console.setLevel(logging.INFO) 26 | formatter = logging.Formatter('%(message)s') 27 | console.setFormatter(formatter) 28 | logging.getLogger('').addHandler(console) 29 | 30 | 31 | class ResultsLog(object): 32 | 33 | def __init__(self, path='results.csv', plot_path=None): 34 | self.path = path 35 | self.plot_path = plot_path or (self.path + '.html') 36 | self.figures = [] 37 | self.results = pd.DataFrame() 38 | 39 | def add(self, **kwargs): 40 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 41 | if self.results is None: 42 | self.results = df 43 | else: 44 | self.results = self.results.append(df, ignore_index=True) 45 | 46 | def save(self, title='Training Results'): 47 | if len(self.figures) > 0: 48 | if os.path.isfile(self.plot_path): 49 | os.remove(self.plot_path) 50 | output_file(self.plot_path, title=title) 51 | plot = column(*self.figures) 52 | save(plot) 53 | self.figures = [] 54 | self.results.to_csv(self.path, index=False, index_label=False) 55 | 56 | def load(self, path=None): 57 | path = path or self.path 58 | if os.path.isfile(path): 59 | self.results = pd.read_csv(path) 60 | 61 | def show(self): 62 | if len(self.figures) > 0: 63 | plot = column(*self.figures) 64 | show(plot) 65 | 66 | #def plot(self, *kargs, **kwargs): 67 | # line = Line(data=self.results, *kargs, **kwargs) 68 | # self.figures.append(line) 69 | 70 | def image(self, *kargs, **kwargs): 71 | fig = figure() 72 | fig.image(*kargs, **kwargs) 73 | self.figures.append(fig) 74 | 75 | 76 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 77 | filename = os.path.join(path, filename) 78 | torch.save(state, filename) 79 | if is_best: 80 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 81 | if save_all: 82 | shutil.copyfile(filename, os.path.join( 83 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 84 | 85 | 86 | class AverageMeter(object): 87 | """Computes and stores the average and current value""" 88 | 89 | def __init__(self): 90 | self.reset() 91 | 92 | def reset(self): 93 | self.val = 0 94 | self.avg = 0 95 | self.sum = 0 96 | self.count = 0 97 | 98 | def update(self, val, n=1): 99 | self.val = val 100 | self.sum += val * n 101 | self.count += n 102 | self.avg = self.sum / self.count 103 | 104 | __optimizers = { 105 | 'SGD': torch.optim.SGD, 106 | 'ASGD': torch.optim.ASGD, 107 | 'Adam': torch.optim.Adam, 108 | 'Adamax': torch.optim.Adamax, 109 | 'Adagrad': torch.optim.Adagrad, 110 | 'Adadelta': torch.optim.Adadelta, 111 | 'Rprop': torch.optim.Rprop, 112 | 'RMSprop': torch.optim.RMSprop 113 | } 114 | 115 | 116 | def adjust_optimizer(optimizer, epoch, config): 117 | """Reconfigures the optimizer according to epoch and config dict""" 
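    # The `config` argument is either a callable mapping epoch -> settings dict, or a dict
    # keyed by epoch; dict settings are "sticky", i.e. they stay in effect until a later epoch
    # overrides them. Illustrative usage (the values below are an assumption, modelled on the
    # `regime` dicts defined in models/mobilenet.py):
    #   regime = {0: {'optimizer': 'SGD', 'lr': 1e-1, 'momentum': 0.9}, 30: {'lr': 1e-2}}
    #   optimizer = adjust_optimizer(optimizer, epoch, regime)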
118 | def modify_optimizer(optimizer, setting): 119 | if 'optimizer' in setting: 120 | optimizer = __optimizers[setting['optimizer']]( 121 | optimizer.param_groups) 122 | logging.debug('OPTIMIZER - setting method = %s' % 123 | setting['optimizer']) 124 | for param_group in optimizer.param_groups: 125 | for key in param_group.keys(): 126 | if key in setting: 127 | logging.debug('OPTIMIZER - setting %s = %s' % 128 | (key, setting[key])) 129 | param_group[key] = setting[key] 130 | return optimizer 131 | 132 | if callable(config): 133 | optimizer = modify_optimizer(optimizer, config(epoch)) 134 | else: 135 | for e in range(epoch + 1): # run over all epochs - sticky setting 136 | if e in config: 137 | optimizer = modify_optimizer(optimizer, config[e]) 138 | 139 | return optimizer 140 | 141 | 142 | def accuracy(output, target, topk=(1,)): 143 | """Computes the precision@k for the specified values of k""" 144 | maxk = max(topk) 145 | batch_size = target.size(0) 146 | 147 | _, pred = output.float().topk(maxk, 1, True, True) 148 | pred = pred.t() 149 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 150 | 151 | res = [] 152 | for k in topk: 153 | correct_k = correct[:k].view(-1).float().sum(0) 154 | res.append(correct_k.mul_(100.0 / batch_size)) 155 | return res 156 | 157 | # kernel_img = model.features[0][0].kernel.data.clone() 158 | # kernel_img.add_(-kernel_img.min()) 159 | # kernel_img.mul_(255 / kernel_img.max()) 160 | # save_image(kernel_img, 'kernel%s.jpg' % epoch) 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training Mixed-Precision Quantized Neural Networks for microcontroller deployments 2 | 3 | ### Description 4 | This project targets quantization-aware training methodologies on Pytorch for microcontroller deployment of quantized neural networks. The featured mixed-precision quantization techniques aim at byte or sub-byte quantization, i.e. INT8, INT4, INT2. The generated network for deployment supports integer arithmetic only. Optionally, the selection of individual per-tensor bit precision is driven by the device memory constraints. 5 | 6 | ### Reference 7 | Please, cite this paper [arXiv](https://arxiv.org/abs/1905.13082) when using the code. 8 | ``` 9 | @article{rusci2019memory, 10 | title={Memory-Driven Mixed Low Precision Quantization For Enabling Deep Network Inference On Microcontrollers}, 11 | author={Rusci, Manuele and Capotondi, Alessandro and Benini, Luca}, 12 | journal={arXiv preprint arXiv:1905.13082}, 13 | year={2019} 14 | } 15 | 16 | ``` 17 | 18 | ### Questions 19 | For any question just drop me an [email](mailto:manuele.rusci@unibo.it). 20 | 21 | ## Getting Started 22 | 23 | ### Prerequisites 24 | - The code is tested with PyTorch 0.4.1 and Python 3.5 25 | - Tensorflow package is needed to load pretrained tensorflow model weights 26 | 27 | ### Setup 28 | Set the correct dataset paths inside `data.py` . 
As an example: 29 | ``` 30 | _IMAGENET_MAIN_PATH = '/home/user/ImagenetDataset/' 31 | _DATASETS_MAIN_PATH = './datasets/' 32 | ``` 33 | To download the pretrained MobileNet weights: 34 | ``` 35 | $ cd models/mobilenet_tf/ 36 | $ source download_pretrained_mobilenet.sh 37 | ``` 38 | 39 | ### Quickstart 40 | For quantization-aware retraining of an 8-bit integer-only MobileNet model, type: 41 | ``` 42 | $ python3 main_binary.py -a mobilenet --mobilenet_width 1.0 --mobilenet_input 224 --save Imagenet/mobilenet_224_1.0_w8a8 --dataset imagenet --type_quant 'PerLayerAsymPACT' --weight_bits 8 --activ_bits 8 --activ_type learned --gpus 0,1,2,3 -j 8 --epochs 12 -b 128 --save_check --quantizer --batch_fold_delay 1 --batch_fold_type folding_weights 43 | ``` 44 | 45 | ### Quantization Options 46 | 47 | - *quantizer*: enables quantization when True 48 | - *type_quant*: type of weight quantization method to apply (see below) 49 | - *weight_bits*: number of bits for weight quantization 50 | - *activ_bits*: number of activation bits 51 | - *activ_type*: type of quantized activation layers 52 | - *batch_fold_delay*: number of epochs before freezing the batch norm parameters 53 | - *batch_fold_type*: how to handle the folding of batch norm parameters (or any other scalar params). \[Supported: 'folding_weights' | 'ICN'\] 54 | - *quant_add_config*: optional list of per-layer configurations, which override the previous settings on a per-layer basis 55 | - *mobilenet_width*: MobileNet width multiplier ( default=1.0; supported \[ 0.25, 0.5, 0.75, 1.0 \] ) 56 | - *mobilenet_input*: MobileNet input resolution ( default=224; supported \[ 128, 160, 192, 224 \] ) 57 | - *mem_constraint*: memory constraints of the target device. Must be provided as a string '\[ROM_SIZE,RAM_SIZE\]' 58 | - *mixed_prec_quant*: mixed per-layer ('MixPL') or mixed per-channel ('MixPC') precision selection 59 | 60 | 61 | ### Reproducing paper results 62 | For any given MobileNet model, run the script with: 63 | 64 | - memory constraints of 512 kB of RAM and 2 MB of flash: *--mem_constraint \[2048000,512000\]* 65 | - mixed precision per layer or per channel: *--mixed_prec_quant MixPL* (or MixPC) 66 | 67 | As an example: 68 | ``` 69 | $ python3 main_binary.py --model mobilenet --save Imagenet_ARM/mobilenet_128_0.75_quant_auto_tt --mobilenet_width 0.75 --mobilenet_input 128 --dataset imagenet -j 32 --epochs 10 -b 128 --save_check --gpus 0,1,2,3 --type_quant PerLayerAsymPACT --activ_type learned --quantizer --batch_fold_delay 1 --batch_fold_type folding_weights --mem_constraint [2048000,512000] --mixed_prec_quant MixPL 70 | ``` 71 | 72 | 73 | ## Quantization Strategy Guide 74 | 75 | ### Overview 76 | 77 | The quantization functions are located in `quantization/quantop.py`. The operator `QuantOp` wraps the full-precision model to handle weight quantization. As a usage example: 78 | ``` 79 | import quantization 80 | quantizer = quantization.QuantOp(model, type_quant, weight_bits, \ 81 | batch_fold_type=args.batch_fold_type, batch_fold_delay=batch_fold_delay, \ 82 | act_bits=activ_bits, add_config = quant_add_config ) 83 | ``` 84 | 85 | After wrapping a full-precision model, the *QuantOp* operator: 86 | 87 | - generates the deployment integer-only graph **quantizer.deployment_model**, based on the full-precision graph *model*. 
88 | - updates the quantized parameters of the deployment model from the current full-precision graph parameters, via **quantizer.generate_deployment_model()** 89 | - provides methods to support quantization-aware retraining of the full-precision model 90 | 91 | At training time, the quantizer works in combination with the optimizer: 92 | ``` 93 | # weight quantization before the forward pass 94 | quantizer.store_and_quantize() # save the real-valued weights and quantize the ones in use 95 | 96 | # forward pass 97 | output = model(input) # compute output 98 | loss = criterion(output, target) # compute loss 99 | 100 | if training: 101 | # backward pass 102 | optimizer.zero_grad() 103 | loss.backward() 104 | 105 | quantizer.restore_real_value() # restore the real-valued parameters 106 | quantizer.backprop_quant_gradients() # compute gradients w.r.t. the real-valued weights 107 | 108 | optimizer.step() # update the real-valued weights 109 | 110 | else: 111 | quantizer.restore_real_value() # restore the real-valued weights after the forward pass 112 | ``` 113 | 114 | 115 | ### Weight Quantization 116 | Currently, the following quantization schemes are supported: 117 | 118 | - *PerLayerAsymPACT*: per-layer asymmetric quantization; the quantization range is learned with the PACT method 119 | - *PerChannelsAsymMinMax*: per-channel asymmetric quantization; the quantization range is given by the min/max of each weight channel 120 | - *PerLayerAsymMinMax*: per-layer asymmetric quantization; the quantization range is given by the min/max of the whole weight tensor (not fully tested) 121 | 122 | 123 | ### Activation Quantization 124 | At the present stage, the quantized activation layers must be part of the model definition itself; this is why the input model is already a fake-quantized model. See 'models/mobilenet.py' as an example. This part will be improved with automatic graph analysis and parsing, to turn a full-precision input model into a fake-quantized one. 125 | 126 | ### Limitations 127 | This project does not include any graph analysis tools. Hence, the graph parser (see the \_\_init\_\_ of the *QuantOp* operator) is specific to the tested model 'models/mobilenet.py', which already includes quantized activation layers. A rework of this part may be necessary to apply the implemented techniques to other models. 
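### End-to-end example
As a recap of the flow described above, here is a minimal sketch of saving the integer-only deployment network after quantization-aware retraining. The argument values and output filename are illustrative; it assumes the `QuantOp` interface shown above, default values for the omitted `QuantOp` arguments, and that `quantizer.deployment_model` is a regular PyTorch module.
```
import torch
import models, quantization

# hypothetical configuration: MobileNet 224_1.0 with 8-bit weights and activations
model = models.mobilenet(type_quant='PerLayerAsymPACT', weight_bits=8, activ_bits=8,
                         activ_type='learned', width_mult=1.0, input_dim=224)
quantizer = quantization.QuantOp(model, 'PerLayerAsymPACT', 8, act_bits=8)

# ... quantization-aware retraining loop as in the snippet above ...

quantizer.generate_deployment_model()  # refresh the integer-only graph from the trained weights
torch.save(quantizer.deployment_model.state_dict(), 'mobilenet_224_1.0_w8a8_deploy.pth')
```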
-------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | 5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 6 | 'std': [0.229, 0.224, 0.225]} 7 | 8 | __imagenet_pca = { 9 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 10 | 'eigvec': torch.Tensor([ 11 | [-0.5675, 0.7192, 0.4009], 12 | [-0.5808, -0.0045, -0.8140], 13 | [-0.5836, -0.6948, 0.4203], 14 | ]) 15 | } 16 | 17 | 18 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 19 | t_list = [ 20 | transforms.CenterCrop(input_size), 21 | transforms.ToTensor(), 22 | transforms.Normalize(**normalize), 23 | ] 24 | if scale_size != input_size: 25 | t_list = [transforms.Scale(scale_size)] + t_list 26 | 27 | return transforms.Compose(t_list) 28 | 29 | 30 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 31 | t_list = [ 32 | transforms.RandomCrop(input_size), 33 | transforms.ToTensor(), 34 | transforms.Normalize(**normalize), 35 | ] 36 | if scale_size != input_size: 37 | t_list = [transforms.Scale(scale_size)] + t_list 38 | 39 | transforms.Compose(t_list) 40 | 41 | def scale_crop_nonorm(input_size): 42 | t_list = [ 43 | transforms.CenterCrop(input_size), 44 | transforms.ToTensor(), 45 | ] 46 | return transforms.Compose(t_list) 47 | 48 | 49 | def scale_random_nonorm( input_size ): 50 | t_list = [ 51 | transforms.RandomCrop(input_size), 52 | transforms.ToTensor(), 53 | ] 54 | return transforms.Compose(t_list) 55 | 56 | 57 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 58 | padding = int((scale_size - input_size) / 2) 59 | return transforms.Compose([ 60 | transforms.RandomCrop(input_size, padding=padding), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ToTensor(), 63 | transforms.Normalize(**normalize), 64 | ]) 65 | 66 | 67 | def inception_preproccess(input_size, normalize=__imagenet_stats): 68 | return transforms.Compose([ 69 | transforms.RandomSizedCrop(input_size), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | transforms.Normalize(**normalize) 73 | ]) 74 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 75 | return transforms.Compose([ 76 | transforms.RandomSizedCrop(input_size), 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | ColorJitter( 80 | brightness=0.4, 81 | contrast=0.4, 82 | saturation=0.4, 83 | ), 84 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 85 | transforms.Normalize(**normalize) 86 | ]) 87 | 88 | 89 | def get_transform(name='imagenet', input_size=None, 90 | scale_size=None, normalize=None, augment=True): 91 | normalize = normalize or __imagenet_stats 92 | if name == 'imagenet': 93 | scale_size = scale_size or 256 94 | input_size = input_size or 224 95 | if augment: 96 | return inception_preproccess(input_size, normalize=normalize) 97 | else: 98 | return scale_crop(input_size=input_size, 99 | scale_size=scale_size, normalize=normalize) 100 | elif 'cifar' in name: 101 | input_size = input_size or 32 102 | if augment: 103 | scale_size = scale_size or 40 104 | return pad_random_crop(input_size, scale_size=scale_size, 105 | normalize=normalize) 106 | else: 107 | scale_size = scale_size or 32 108 | return scale_crop(input_size=input_size, 109 | scale_size=scale_size, normalize=normalize) 110 | elif name == 'mnist': 111 | normalize = 
{'mean': [0.5], 'std': [0.5]} 112 | input_size = input_size or 28 113 | if augment: 114 | scale_size = scale_size or 32 115 | return pad_random_crop(input_size, scale_size=scale_size, 116 | normalize=normalize) 117 | else: 118 | scale_size = scale_size or 32 119 | return scale_crop(input_size=input_size, 120 | scale_size=scale_size, normalize=normalize) 121 | 122 | elif 'inria' in name: 123 | input_size = input_size or 64 124 | if augment: 125 | return scale_random_nonorm(input_size=input_size) 126 | else: 127 | return scale_crop_nonorm(input_size=input_size) 128 | 129 | 130 | class Lighting(object): 131 | """Lighting noise(AlexNet - style PCA - based noise)""" 132 | 133 | def __init__(self, alphastd, eigval, eigvec): 134 | self.alphastd = alphastd 135 | self.eigval = eigval 136 | self.eigvec = eigvec 137 | 138 | def __call__(self, img): 139 | if self.alphastd == 0: 140 | return img 141 | 142 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 143 | rgb = self.eigvec.type_as(img).clone()\ 144 | .mul(alpha.view(1, 3).expand(3, 3))\ 145 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 146 | .sum(1).squeeze() 147 | 148 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 149 | 150 | 151 | class Grayscale(object): 152 | 153 | def __call__(self, img): 154 | gs = img.clone() 155 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 156 | gs[1].copy_(gs[0]) 157 | gs[2].copy_(gs[0]) 158 | return gs 159 | 160 | 161 | class Saturation(object): 162 | 163 | def __init__(self, var): 164 | self.var = var 165 | 166 | def __call__(self, img): 167 | gs = Grayscale()(img) 168 | alpha = random.uniform(0, self.var) 169 | return img.lerp(gs, alpha) 170 | 171 | 172 | class Brightness(object): 173 | 174 | def __init__(self, var): 175 | self.var = var 176 | 177 | def __call__(self, img): 178 | gs = img.new().resize_as_(img).zero_() 179 | alpha = random.uniform(0, self.var) 180 | return img.lerp(gs, alpha) 181 | 182 | 183 | class Contrast(object): 184 | 185 | def __init__(self, var): 186 | self.var = var 187 | 188 | def __call__(self, img): 189 | gs = Grayscale()(img) 190 | gs.fill_(gs.mean()) 191 | alpha = random.uniform(0, self.var) 192 | return img.lerp(gs, alpha) 193 | 194 | 195 | class RandomOrder(object): 196 | """ Composes several transforms together in random order. 197 | """ 198 | 199 | def __init__(self, transforms): 200 | self.transforms = transforms 201 | 202 | def __call__(self, img): 203 | if self.transforms is None: 204 | return img 205 | order = torch.randperm(len(self.transforms)) 206 | for i in order: 207 | img = self.transforms[i](img) 208 | return img 209 | 210 | 211 | class ColorJitter(RandomOrder): 212 | 213 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 214 | self.transforms = [] 215 | if brightness != 0: 216 | self.transforms.append(Brightness(brightness)) 217 | if contrast != 0: 218 | self.transforms.append(Contrast(contrast)) 219 | if saturation != 0: 220 | self.transforms.append(Saturation(saturation)) 221 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # mobilenet.py 3 | # Manuele Rusci 4 | # 5 | # Copyright (C) 2019 University of Bologna 6 | # All rights reserved. 
7 | # 8 | # This is an implementation of the quantized mobilenet built from 9 | # https://github.com/marvis/pytorch-mobilenet/blob/master/main.py 10 | # 11 | 12 | import torch.nn as nn 13 | import torch.utils.model_zoo as model_zoo 14 | import math 15 | import torchvision.transforms as transforms 16 | import torch.nn.functional as F 17 | 18 | from .linear_quantized_modules import ClippedLinearQuantization, LearnedClippedLinearQuantization, Conv2d_SAME 19 | from .preload_mobilenet import preload_mobilenet_tf 20 | 21 | 22 | ###### Full Precision Blocks ############# 23 | def conv_dw(inp, oup, stride, pad1=0, bias_ena=False): 24 | if pad1 == 1: 25 | return nn.Sequential( 26 | nn.ConstantPad2d(1, -1), 27 | nn.Conv2d(inp, inp, 3, stride, 0, groups=inp, bias=bias_ena), 28 | nn.BatchNorm2d(inp), 29 | nn.ReLU6(inplace=False) 30 | ) 31 | else: 32 | return nn.Sequential( 33 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=bias_ena), 34 | nn.BatchNorm2d(inp), 35 | nn.ReLU6(inplace=False) 36 | ) 37 | 38 | def conv_pw(inp, oup, stride,bias_ena=False): 39 | return nn.Sequential( 40 | nn.Conv2d(inp, oup, 1, 1, 0, bias=bias_ena), 41 | nn.BatchNorm2d(oup), 42 | nn.ReLU6(inplace=False) 43 | ) 44 | 45 | def conv_bn(inp, oup, stride): 46 | return nn.Sequential( 47 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 48 | nn.BatchNorm2d(oup), 49 | nn.ReLU6(inplace=False) 50 | ) 51 | 52 | ###### Quantized Blocks ############################## 53 | # this should be removed and automatized 54 | def get_quant_activ_layer(activ_type=None, act_bits=8): 55 | if activ_type == 'clipped': 56 | act_layer = ClippedLinearQuantization(num_bits=act_bits, clip_val=6) 57 | elif activ_type == 'learned': 58 | act_layer = LearnedClippedLinearQuantization(num_bits=act_bits, init_act_clip_val=6) 59 | else: 60 | act_layer = nn.ReLU6(inplace=True) 61 | 62 | return act_layer 63 | 64 | 65 | def conv_dw_quant(inp, oup, stride, activ_type=None, act_bits=8, bias_ena=False): 66 | return nn.Sequential( 67 | Conv2d_SAME(inp, inp, 3, stride, 1, groups=inp, bias=bias_ena), 68 | nn.BatchNorm2d(inp), 69 | get_quant_activ_layer(activ_type, act_bits) 70 | ) 71 | 72 | def conv_pw_quant(inp, oup, stride, activ_type=None, act_bits=8, bias_ena=False): 73 | return nn.Sequential( 74 | Conv2d_SAME(inp, oup, 1, 1, 0, bias=bias_ena), 75 | nn.BatchNorm2d(oup), 76 | get_quant_activ_layer(activ_type, act_bits) 77 | ) 78 | 79 | def conv_bn_quant(inp, oup, stride, activ_type=None, act_bits=8): 80 | return nn.Sequential( 81 | Conv2d_SAME(inp, oup, 3, stride, 1, bias=False), 82 | nn.BatchNorm2d(oup), 83 | get_quant_activ_layer(activ_type, act_bits) 84 | ) 85 | 86 | 87 | class mobilenet_real(nn.Module): 88 | def __init__(self, width_mult=1.0, input_dim = 224): 89 | super(mobilenet_real, self).__init__() 90 | print(width_mult, input_dim) 91 | 92 | if input_dim == 224: 93 | avg_size = 7 94 | crop_size = 256 95 | elif input_dim == 192: 96 | avg_size = 6 97 | crop_size = 220 98 | elif input_dim == 160: 99 | avg_size = 5 100 | crop_size = 180 101 | elif input_dim == 128: 102 | avg_size = 4 103 | crop_size = 146 104 | else: 105 | return -1 106 | self.width_mult = width_mult 107 | self.model = nn.Sequential( 108 | conv_bn( 3, int(width_mult* 32), 2), 109 | conv_dw( int(width_mult* 32), int(width_mult* 64), 1), 110 | conv_pw( int(width_mult* 32), int(width_mult* 64), 1), 111 | conv_dw( int(width_mult* 64), int(width_mult*128), 2), 112 | conv_pw( int(width_mult* 64), int(width_mult*128), 2), 113 | conv_dw( int(width_mult*128), int(width_mult*128), 1), 114 | conv_pw( 
int(width_mult*128), int(width_mult*128), 1), 115 | conv_dw( int(width_mult*128), int(width_mult*256), 2), 116 | conv_pw( int(width_mult*128), int(width_mult*256), 2), 117 | conv_dw( int(width_mult*256), int(width_mult*256), 1), 118 | conv_pw( int(width_mult*256), int(width_mult*256), 1), 119 | conv_dw( int(width_mult*256), int(width_mult*512), 2), 120 | conv_pw( int(width_mult*256), int(width_mult*512), 2), 121 | conv_dw( int(width_mult*512), int(width_mult*512), 1), 122 | conv_pw( int(width_mult*512), int(width_mult*512), 1), 123 | conv_dw( int(width_mult*512), int(width_mult*512), 1), 124 | conv_pw( int(width_mult*512), int(width_mult*512), 1), 125 | conv_dw( int(width_mult*512), int(width_mult*512), 1), 126 | conv_pw( int(width_mult*512), int(width_mult*512), 1), 127 | conv_dw( int(width_mult*512), int(width_mult*512), 1), 128 | conv_pw( int(width_mult*512), int(width_mult*512), 1), 129 | conv_dw( int(width_mult*512), int(width_mult*512), 1), 130 | conv_pw( int(width_mult*512), int(width_mult*512), 1), 131 | conv_dw( int(width_mult*512), int(width_mult*1024), 2), 132 | conv_pw( int(width_mult*512), int(width_mult*1024), 2), 133 | conv_dw( int(width_mult*1024), int(width_mult*1024), 1), 134 | conv_pw( int(width_mult*1024), int(width_mult*1024), 1), 135 | nn.AvgPool2d(avg_size), 136 | ) 137 | self.fc = nn.Linear( int(width_mult*1024), 1000) 138 | 139 | self.regime = { 140 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 141 | 'weight_decay': 1e-4, 'momentum': 0.9}, 142 | 10: {'lr': 5e-2}, 143 | 20: {'lr': 1e-2}, #, 'weight_decay': 0}, 144 | 30: {'lr': 5e-3}, 145 | 40: {'lr': 1e-3}, 146 | 50: {'lr': 5e-4}, 147 | 60: {'lr': 1e-4} 148 | } 149 | 150 | 151 | #prepocessing 152 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 153 | 154 | self.input_transform = { 155 | 'train': transforms.Compose([ 156 | transforms.Scale(crop_size), 157 | transforms.RandomCrop(input_dim), 158 | transforms.RandomHorizontalFlip(), 159 | transforms.ToTensor(), 160 | normalize 161 | ]), 162 | 'eval': transforms.Compose([ 163 | transforms.Scale(crop_size), 164 | transforms.CenterCrop(input_dim), 165 | transforms.ToTensor(), 166 | normalize 167 | ]) 168 | } 169 | 170 | def forward(self, x): 171 | x = self.model(x) 172 | x = x.view(-1, int(self.width_mult*1024)) 173 | x = self.fc(x) 174 | return x 175 | 176 | class mobilenet_quant_devel(nn.Module): 177 | def __init__(self, width_mult=1.0, input_dim = 224, activ_bits =None, weight_bits= None, activ_type='relu'): 178 | super(mobilenet_quant_devel, self).__init__() 179 | 180 | print('This is a quantized Mobilenet with alpha= ', width_mult, \ 181 | ' input_size: ', input_dim, \ 182 | ' activation bits: ', activ_bits, \ 183 | ' weight bits: ', weight_bits, \ 184 | ' activation type: ', activ_type ) 185 | 186 | if input_dim == 224: 187 | avg_size = 7 188 | crop_size = 256 189 | elif input_dim == 192: 190 | avg_size = 6 191 | crop_size = 220 192 | elif input_dim == 160: 193 | avg_size = 5 194 | crop_size = 180 195 | elif input_dim == 128: 196 | avg_size = 4 197 | crop_size = 146 198 | else: 199 | return -1 200 | 201 | 202 | self.width_mult = width_mult 203 | self.model = nn.Sequential( 204 | conv_bn_quant( 3, int(width_mult* 32), 2 , activ_type, activ_bits ), 205 | conv_dw_quant( int(width_mult* 32), int(width_mult* 64), 1 , activ_type, activ_bits ), 206 | conv_pw_quant( int(width_mult* 32), int(width_mult* 64), 1 , activ_type, activ_bits ), 207 | conv_dw_quant( int(width_mult* 64), int(width_mult*128), 2 , activ_type, activ_bits ), 208 | conv_pw_quant( 
int(width_mult* 64), int(width_mult*128), 2 , activ_type, activ_bits ), 209 | conv_dw_quant( int(width_mult*128), int(width_mult*128), 1 , activ_type, activ_bits ), 210 | conv_pw_quant( int(width_mult*128), int(width_mult*128), 1 , activ_type, activ_bits ), 211 | conv_dw_quant( int(width_mult*128), int(width_mult*256), 2 , activ_type, activ_bits ), 212 | conv_pw_quant( int(width_mult*128), int(width_mult*256), 2 , activ_type, activ_bits ), 213 | conv_dw_quant( int(width_mult*256), int(width_mult*256), 1 , activ_type, activ_bits ), 214 | conv_pw_quant( int(width_mult*256), int(width_mult*256), 1 , activ_type, activ_bits ), 215 | conv_dw_quant( int(width_mult*256), int(width_mult*512), 2 , activ_type, activ_bits ), 216 | conv_pw_quant( int(width_mult*256), int(width_mult*512), 2 , activ_type, activ_bits ), 217 | conv_dw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 218 | conv_pw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 219 | conv_dw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 220 | conv_pw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 221 | conv_dw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 222 | conv_pw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 223 | conv_dw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 224 | conv_pw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 225 | conv_dw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 226 | conv_pw_quant( int(width_mult*512), int(width_mult*512), 1 , activ_type, activ_bits ), 227 | conv_dw_quant( int(width_mult*512), int(width_mult*1024), 2, activ_type, activ_bits ), 228 | conv_pw_quant( int(width_mult*512), int(width_mult*1024), 2, activ_type, activ_bits ), 229 | conv_dw_quant( int(width_mult*1024), int(width_mult*1024),1, activ_type, activ_bits ), 230 | conv_pw_quant( int(width_mult*1024), int(width_mult*1024),1, activ_type, activ_bits ), 231 | nn.AvgPool2d(avg_size), 232 | ) 233 | self.fc = nn.Linear( int(width_mult*1024), 1000) 234 | 235 | for m in self.model.modules(): 236 | if 'BatchNorm' in m.__class__.__name__: 237 | m.eps = 0.001 238 | m.momentum = 0.003 239 | 240 | 241 | self.regime = { 242 | 0: {'optimizer': 'Adam', 'lr': 1e-4 }, 243 | 5: {'lr': 5e-5}, 244 | 8: {'lr': 1e-5 } 245 | } 246 | #prepocess as BNN 247 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 248 | 249 | self.input_transform = { 250 | 'train': transforms.Compose([ 251 | transforms.Scale(crop_size), 252 | transforms.RandomCrop(input_dim), 253 | transforms.RandomHorizontalFlip(), 254 | transforms.ToTensor(), 255 | normalize 256 | ]), 257 | 'eval': transforms.Compose([ 258 | transforms.Scale(crop_size), 259 | transforms.CenterCrop(input_dim), 260 | transforms.ToTensor(), 261 | normalize 262 | ]) 263 | } 264 | 265 | def forward(self, x): 266 | x = self.model(x) 267 | x = x.view(-1, int(self.width_mult*1024)) 268 | x = F.dropout(x, 0.2, self.training) 269 | x = self.fc(x) 270 | return x 271 | 272 | 273 | def mobilenet(type_quant= None, activ_bits =None, weight_bits= None, activ_type='hardtanh', width_mult=1.0, input_dim = 224,**kwargs): 274 | print(','.join('{0}={1!r}'.format(k,v) for k,v in kwargs.items())) 275 | 276 | print(activ_bits, weight_bits, type_quant) 277 | 278 | 279 | if type_quant == 'PerLayerAsymMinMax' or type_quant == 
'PerLayerAsymPACT' or type_quant == 'PerChannelsAsymMinMax': 280 | model = mobilenet_quant_devel(width_mult, input_dim, activ_bits, weight_bits, activ_type) 281 | preload_mobilenet_tf( model , input_dim, width_mult) 282 | return model 283 | 284 | else: 285 | return mobilenet_real(width_mult, input_dim) 286 | 287 | -------------------------------------------------------------------------------- /models/linear_quantized_modules.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # 18 | # Modified by Manuele Rusci - manuele.rusci@unibo.it 19 | # Copyright (c) 2018 Università di Bologna 20 | # 21 | 22 | import torch 23 | import torch.nn as nn 24 | import math 25 | 26 | from torch.nn.parameter import Parameter 27 | from torch.nn.functional import pad 28 | from torch.nn.modules import Module 29 | from torch.nn.modules.utils import _single, _pair, _triple 30 | import torch.nn.functional as F 31 | 32 | ### 33 | # Utils 34 | ### 35 | 36 | def symmetric_linear_quantization_scale_factor(num_bits, saturation_val): 37 | # Leave one bit for sign 38 | n = 2 ** (num_bits - 1) - 1 39 | return n / saturation_val 40 | 41 | 42 | def asymmetric_linear_quantization_scale_factor(num_bits, saturation_min, saturation_max): 43 | n = 2 ** num_bits - 1 44 | return n / (saturation_max - saturation_min) 45 | 46 | 47 | def clamp(input, min, max, inplace=False): 48 | if inplace: 49 | input.clamp_(min, max) 50 | return input 51 | return torch.clamp(input, min, max) 52 | 53 | 54 | def linear_quantize(input, scale_factor, inplace=False): 55 | if inplace: 56 | input.mul_(scale_factor).floor_() 57 | return input 58 | return torch.floor(scale_factor * input) 59 | 60 | 61 | def linear_quantize_clamp(input, scale_factor, clamp_min, clamp_max, inplace=False): 62 | output = linear_quantize(input, scale_factor, inplace) 63 | return clamp(output, clamp_min, clamp_max, inplace) 64 | 65 | 66 | def linear_dequantize(input, scale_factor, inplace=False): 67 | if inplace: 68 | input.div_(scale_factor) 69 | return input 70 | return input / scale_factor 71 | 72 | 73 | def get_tensor_max_abs(tensor): 74 | return max(abs(tensor.max().item()), abs(tensor.min().item())) 75 | 76 | 77 | def get_tensor_avg_max_abs_across_batch(tensor): 78 | # Assume batch is at dim 0 79 | tv = tensor.view(tensor.size()[0], -1) 80 | avg_max = tv.max(dim=1)[0].mean().item() 81 | avg_min = tv.min(dim=1)[0].mean().item() 82 | return max(abs(avg_max), abs(avg_min)) 83 | 84 | 85 | def get_quantized_range(num_bits, signed=True): 86 | if signed: 87 | n = 2 ** (num_bits - 1) 88 | return -n, n - 1 89 | return 0, 2 ** num_bits - 1 90 | 91 | 92 | ### 93 | # Clipping-based linear quantization (e.g. 
DoReFa, WRPN) 94 | ### 95 | 96 | 97 | class LinearQuantizeSTE(torch.autograd.Function): 98 | @staticmethod 99 | def forward(ctx, input, scale_factor, dequantize, inplace): 100 | if inplace: 101 | ctx.mark_dirty(input) 102 | output = linear_quantize(input, scale_factor, inplace) 103 | if dequantize: 104 | output = linear_dequantize(output, scale_factor, inplace) 105 | return output 106 | 107 | @staticmethod 108 | def backward(ctx, grad_output): 109 | # Straight-through estimator 110 | return grad_output, None, None, None 111 | 112 | 113 | class LearnedClippedLinearQuantizeSTE(torch.autograd.Function): 114 | @staticmethod 115 | def forward(ctx, input, clip_val, num_bits, dequantize, inplace): 116 | ctx.save_for_backward(input, clip_val) 117 | if inplace: 118 | ctx.mark_dirty(input) 119 | scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val.data[0]) 120 | output = clamp(input, 0, clip_val.data[0], inplace) 121 | output = linear_quantize(output, scale_factor, inplace) 122 | if dequantize: 123 | output = linear_dequantize(output, scale_factor, inplace) 124 | return output 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | input, clip_val = ctx.saved_tensors 129 | grad_input = grad_output.clone() 130 | grad_input.masked_fill_(input.le(0), 0) 131 | grad_input.masked_fill_(input.ge(clip_val.data[0]), 0) 132 | 133 | grad_alpha = grad_output.clone() 134 | grad_alpha.masked_fill_(input.lt(clip_val.data[0]), 0) 135 | # grad_alpha[input.lt(clip_val.data[0])] = 0 136 | grad_alpha = grad_alpha.sum().expand_as(clip_val) 137 | 138 | # Straight-through estimator for the scale factor calculation 139 | return grad_input, grad_alpha, None, None, None 140 | 141 | 142 | ### 143 | # Layers 144 | ### 145 | 146 | # Missing online estimation of min/max -- se below: TODO 147 | class QuantMeasure(nn.Module): 148 | """docstring for QuantMeasure.""" 149 | 150 | def __init__(self, num_bits=8, momentum=0.1): 151 | super(QuantMeasure, self).__init__() 152 | self.register_buffer('running_min', torch.zeros(1)) 153 | self.register_buffer('running_max', torch.zeros(1)) 154 | self.momentum = momentum 155 | self.num_bits = num_bits 156 | 157 | def forward(self, input): 158 | if self.training: 159 | min_value = input.detach().view( 160 | input.size(0), -1).min(-1)[0].mean() 161 | max_value = input.detach().view( 162 | input.size(0), -1).max(-1)[0].mean() 163 | self.running_min.mul_(self.momentum).add_( 164 | min_value * (1 - self.momentum)) 165 | self.running_max.mul_(self.momentum).add_( 166 | max_value * (1 - self.momentum)) 167 | else: 168 | min_value = self.running_min 169 | max_value = self.running_max 170 | return quantize(input, self.num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16) 171 | 172 | 173 | class ClippedLinearQuantization(nn.Module): 174 | def __init__(self, num_bits, clip_val, dequantize=True, inplace=False): 175 | super(ClippedLinearQuantization, self).__init__() 176 | self.num_bits = num_bits 177 | self.clip_val = clip_val 178 | self.scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val) 179 | self.dequantize = dequantize 180 | self.inplace = inplace 181 | 182 | def forward(self, input): 183 | input = clamp(input, 0, self.clip_val, self.inplace) 184 | input = LinearQuantizeSTE.apply(input, self.scale_factor, self.dequantize, self.inplace) 185 | return input 186 | 187 | def __repr__(self): 188 | inplace_str = ', inplace' if self.inplace else '' 189 | return '{0}(num_bits={1}, 
clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val, 190 | inplace_str) 191 | 192 | 193 | class LearnedClippedLinearQuantization(nn.Module): 194 | def __init__(self, num_bits, init_act_clip_val, dequantize=True, inplace=True): 195 | super(LearnedClippedLinearQuantization, self).__init__() 196 | self.num_bits = num_bits 197 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val])) 198 | self.dequantize = dequantize 199 | self.inplace = inplace 200 | 201 | def forward(self, input): 202 | input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits, self.dequantize, self.inplace) 203 | return input 204 | 205 | def __repr__(self): 206 | inplace_str = ', inplace' if self.inplace else '' 207 | return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val, 208 | inplace_str) 209 | 210 | 211 | ''' 212 | Quantized Layers for deployment purpose 213 | ''' 214 | class ScaledClippedLinearQuantization(nn.Module): 215 | def __init__(self, clip_val=1, M=1, inplace=True): 216 | super(ScaledClippedLinearQuantization, self).__init__() 217 | self.clip_val = clip_val 218 | self.M_ZERO = M 219 | self.N_ZERO = M 220 | self.inplace = inplace 221 | 222 | def forward(self, input): 223 | input = input.mul(self.M_ZERO).mul(2**self.N_ZERO).floor() 224 | input = clamp(input, 0, self.clip_val, self.inplace) 225 | return input 226 | 227 | def __repr__(self): 228 | inplace_str = ', inplace' if self.inplace else '' 229 | return '{0}(M0={1}, N0={2}, clip_val={3}{4})'.format(self.__class__.__name__, self.M_ZERO, self.N_ZERO, self.clip_val, 230 | inplace_str) 231 | 232 | 233 | class ScaledClippedLinearQuantizationChannel(nn.Module): 234 | def __init__(self, n_channels, clip_val=1, M=1, inplace=True): 235 | super(ScaledClippedLinearQuantizationChannel, self).__init__() 236 | self.clip_val = clip_val 237 | self.M_ZERO = torch.Tensor(n_channels).fill_(M) 238 | self.N_ZERO = torch.Tensor(n_channels).fill_(M) 239 | self.inplace = inplace 240 | 241 | def forward(self, input): 242 | if self.clip_val is not False: 243 | input = input.mul(self.M_ZERO.unsqueeze(0).unsqueeze(2).unsqueeze(2).expand(input.size())) 244 | input = input.mul(self.N_ZERO.unsqueeze(0).unsqueeze(2).unsqueeze(2).expand(input.size())).floor() #round 245 | input = clamp(input, 0, self.clip_val, self.inplace) 246 | else: 247 | input = input.mul(self.M_ZERO).mul(self.N_ZERO) 248 | return input 249 | 250 | def __repr__(self): 251 | inplace_str = ', inplace' if self.inplace else '' 252 | return '{0}(clip_val={1}{2})'.format(self.__class__.__name__, self.clip_val, inplace_str) 253 | 254 | 255 | 256 | # Tensorflow convolution with SAME padding 257 | class _ConvNd(Module): 258 | 259 | def __init__(self, in_channels, out_channels, kernel_size, stride, 260 | padding, dilation, transposed, output_padding, groups, bias): 261 | super(_ConvNd, self).__init__() 262 | if in_channels % groups != 0: 263 | raise ValueError('in_channels must be divisible by groups') 264 | if out_channels % groups != 0: 265 | raise ValueError('out_channels must be divisible by groups') 266 | self.in_channels = in_channels 267 | self.out_channels = out_channels 268 | self.kernel_size = kernel_size 269 | self.stride = stride 270 | self.padding = padding 271 | self.dilation = dilation 272 | self.transposed = transposed 273 | self.output_padding = output_padding 274 | self.groups = groups 275 | if transposed: 276 | self.weight = Parameter(torch.Tensor( 277 | in_channels, out_channels // groups, 
*kernel_size)) 278 | else: 279 | self.weight = Parameter(torch.Tensor( 280 | out_channels, in_channels // groups, *kernel_size)) 281 | if bias: 282 | self.bias = Parameter(torch.Tensor(out_channels)) 283 | else: 284 | self.register_parameter('bias', None) 285 | self.reset_parameters() 286 | 287 | def reset_parameters(self): 288 | n = self.in_channels 289 | for k in self.kernel_size: 290 | n *= k 291 | stdv = 1. / math.sqrt(n) 292 | self.weight.data.uniform_(-stdv, stdv) 293 | if self.bias is not None: 294 | self.bias.data.uniform_(-stdv, stdv) 295 | 296 | def __repr__(self): 297 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 298 | ', stride={stride}') 299 | if self.padding != (0,) * len(self.padding): 300 | s += ', padding={padding}' 301 | if self.dilation != (1,) * len(self.dilation): 302 | s += ', dilation={dilation}' 303 | if self.output_padding != (0,) * len(self.output_padding): 304 | s += ', output_padding={output_padding}' 305 | if self.groups != 1: 306 | s += ', groups={groups}' 307 | if self.bias is None: 308 | s += ', bias=False' 309 | s += ')' 310 | return s.format(name=self.__class__.__name__, **self.__dict__) 311 | 312 | 313 | class Conv2d_SAME(_ConvNd): 314 | 315 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 316 | padding=0, dilation=1, groups=1, bias=True): 317 | kernel_size = _pair(kernel_size) 318 | stride = _pair(stride) 319 | padding = _pair(padding) 320 | dilation = _pair(dilation) 321 | super(Conv2d_SAME, self).__init__( 322 | in_channels, out_channels, kernel_size, stride, padding, dilation, 323 | False, _pair(0), groups, bias) 324 | 325 | def forward(self, input): 326 | return conv2d_same_padding(input, self.weight, self.bias, self.stride, 327 | self.padding, self.dilation, self.groups) 328 | 329 | 330 | # custom con2d, because pytorch don't have "padding='same'" option. 331 | def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1): 332 | 333 | input_rows = input.size(2) 334 | filter_rows = weight.size(2) 335 | effective_filter_size_rows = (filter_rows - 1) * dilation[0] + 1 336 | out_rows = (input_rows + stride[0] - 1) // stride[0] 337 | padding_needed = max(0, (out_rows - 1) * stride[0] + effective_filter_size_rows - 338 | input_rows) 339 | padding_rows = max(0, (out_rows - 1) * stride[0] + 340 | (filter_rows - 1) * dilation[0] + 1 - input_rows) 341 | rows_odd = (padding_rows % 2 != 0) 342 | padding_cols = max(0, (out_rows - 1) * stride[0] + 343 | (filter_rows - 1) * dilation[0] + 1 - input_rows) 344 | cols_odd = (padding_rows % 2 != 0) 345 | 346 | #print('cols_odd', int(cols_odd), 'rows_odd', int(rows_odd), 'padding',(padding_rows // 2, padding_cols // 2), 'dialtion', dilation) 347 | if rows_odd or cols_odd: 348 | input = pad(input, [0, int(cols_odd), 0, int(rows_odd)]) 349 | 350 | return F.conv2d(input, weight, bias, stride, 351 | padding=(padding_rows // 2, padding_cols // 2), 352 | dilation=dilation, groups=groups) -------------------------------------------------------------------------------- /models/preload_mobilenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # preload_mobilenet.py 3 | # Manuele Rusci 4 | # 5 | # Copyright (C) 2019 University of Bologna 6 | # All rights reserved. 
7 | # 8 | # Adapted from: https://github.com/ruotianluo/pytorch-mobilenet-from-tf 9 | # 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.utils.model_zoo as model_zoo 14 | import os 15 | import sys 16 | 17 | import tensorflow as tf 18 | from tensorflow.python import pywrap_tensorflow 19 | from collections import OrderedDict 20 | import re 21 | import torch 22 | import numpy as np 23 | 24 | dummy_replace = OrderedDict([\ 25 | ('features/Logits/Conv2d_1c_1x1/biases','fc.bias' ),\ 26 | ('features/Logits/Conv2d_1c_1x1/weights','fc.weight'),\ 27 | ('features/Conv2d_0/weights','model.0.0.weight' ), \ 28 | ('features/Conv2d_0/BatchNorm/beta','model.0.1.bias' ),\ 29 | ('features/Conv2d_0/BatchNorm/moving_mean','model.0.1.running_mean' ), \ 30 | ('features/Conv2d_0/BatchNorm/moving_variance','model.0.1.running_var' ), \ 31 | ('features/Conv2d_0/BatchNorm/gamma','model.0.1.weight' ),\ 32 | ('features/Conv2d_1_depthwise/depthwise_weights','model.1.0.weight' ),\ 33 | ('features/Conv2d_1_depthwise/BatchNorm/beta','model.1.1.bias' ),\ 34 | ('features/Conv2d_1_depthwise/BatchNorm/moving_mean','model.1.1.running_mean' ), \ 35 | ('features/Conv2d_1_depthwise/BatchNorm/moving_variance','model.1.1.running_var' ), \ 36 | ('features/Conv2d_1_depthwise/BatchNorm/gamma','model.1.1.weight' ), \ 37 | ('features/Conv2d_1_pointwise/weights','model.2.0.weight' ), \ 38 | ('features/Conv2d_1_pointwise/BatchNorm/beta','model.2.1.bias' ), \ 39 | ('features/Conv2d_1_pointwise/BatchNorm/moving_mean','model.2.1.running_mean' ), \ 40 | ('features/Conv2d_1_pointwise/BatchNorm/moving_variance','model.2.1.running_var' ), \ 41 | ('features/Conv2d_1_pointwise/BatchNorm/gamma','model.2.1.weight' ), \ 42 | ('features/Conv2d_2_depthwise/depthwise_weights','model.3.0.weight' ), \ 43 | ('features/Conv2d_2_depthwise/BatchNorm/beta','model.3.1.bias' ), \ 44 | ('features/Conv2d_2_depthwise/BatchNorm/moving_mean','model.3.1.running_mean' ), \ 45 | ('features/Conv2d_2_depthwise/BatchNorm/moving_variance','model.3.1.running_var' ), \ 46 | ('features/Conv2d_2_depthwise/BatchNorm/gamma','model.3.1.weight' ), \ 47 | ('features/Conv2d_2_pointwise/weights','model.4.0.weight' ), \ 48 | ('features/Conv2d_2_pointwise/BatchNorm/beta','model.4.1.bias' ), \ 49 | ('features/Conv2d_2_pointwise/BatchNorm/moving_mean','model.4.1.running_mean' ), \ 50 | ('features/Conv2d_2_pointwise/BatchNorm/moving_variance','model.4.1.running_var' ), \ 51 | ('features/Conv2d_2_pointwise/BatchNorm/gamma','model.4.1.weight' ), \ 52 | ('features/Conv2d_3_depthwise/depthwise_weights','model.5.0.weight' ), \ 53 | ('features/Conv2d_3_depthwise/BatchNorm/beta','model.5.1.bias' ), \ 54 | ('features/Conv2d_3_depthwise/BatchNorm/moving_mean','model.5.1.running_mean' ), \ 55 | ('features/Conv2d_3_depthwise/BatchNorm/moving_variance','model.5.1.running_var' ), \ 56 | ('features/Conv2d_3_depthwise/BatchNorm/gamma','model.5.1.weight' ), \ 57 | ('features/Conv2d_3_pointwise/weights','model.6.0.weight' ), \ 58 | ('features/Conv2d_3_pointwise/BatchNorm/beta','model.6.1.bias' ), \ 59 | ('features/Conv2d_3_pointwise/BatchNorm/moving_mean','model.6.1.running_mean' ), \ 60 | ('features/Conv2d_3_pointwise/BatchNorm/moving_variance','model.6.1.running_var' ), \ 61 | ('features/Conv2d_3_pointwise/BatchNorm/gamma','model.6.1.weight' ), \ 62 | ('features/Conv2d_4_depthwise/depthwise_weights','model.7.0.weight' ), \ 63 | ('features/Conv2d_4_depthwise/BatchNorm/beta','model.7.1.bias' ), \ 64 | ('features/Conv2d_4_depthwise/BatchNorm/moving_mean','model.7.1.running_mean' ), \ 65 | 
('features/Conv2d_4_depthwise/BatchNorm/moving_variance','model.7.1.running_var' ), \ 66 | ('features/Conv2d_4_depthwise/BatchNorm/gamma','model.7.1.weight' ), \ 67 | ('features/Conv2d_4_pointwise/weights','model.8.0.weight' ), \ 68 | ('features/Conv2d_4_pointwise/BatchNorm/beta','model.8.1.bias' ), \ 69 | ('features/Conv2d_4_pointwise/BatchNorm/moving_mean','model.8.1.running_mean' ), \ 70 | ('features/Conv2d_4_pointwise/BatchNorm/moving_variance','model.8.1.running_var' ), \ 71 | ('features/Conv2d_4_pointwise/BatchNorm/gamma','model.8.1.weight' ), \ 72 | ('features/Conv2d_5_depthwise/depthwise_weights','model.9.0.weight' ), \ 73 | ('features/Conv2d_5_depthwise/BatchNorm/beta','model.9.1.bias' ), \ 74 | ('features/Conv2d_5_depthwise/BatchNorm/moving_mean','model.9.1.running_mean' ), \ 75 | ('features/Conv2d_5_depthwise/BatchNorm/moving_variance','model.9.1.running_var' ), \ 76 | ('features/Conv2d_5_depthwise/BatchNorm/gamma','model.9.1.weight' ), \ 77 | ('features/Conv2d_5_pointwise/weights','model.10.0.weight' ), \ 78 | ('features/Conv2d_5_pointwise/BatchNorm/beta','model.10.1.bias' ), \ 79 | ('features/Conv2d_5_pointwise/BatchNorm/moving_mean','model.10.1.running_mean' ), \ 80 | ('features/Conv2d_5_pointwise/BatchNorm/moving_variance','model.10.1.running_var' ), \ 81 | ('features/Conv2d_5_pointwise/BatchNorm/gamma','model.10.1.weight' ), \ 82 | ('features/Conv2d_6_depthwise/depthwise_weights','model.11.0.weight' ), \ 83 | ('features/Conv2d_6_depthwise/BatchNorm/beta','model.11.1.bias' ), \ 84 | ('features/Conv2d_6_depthwise/BatchNorm/moving_mean','model.11.1.running_mean' ), \ 85 | ('features/Conv2d_6_depthwise/BatchNorm/moving_variance','model.11.1.running_var' ), \ 86 | ('features/Conv2d_6_depthwise/BatchNorm/gamma','model.11.1.weight' ), \ 87 | ('features/Conv2d_6_pointwise/weights','model.12.0.weight' ), \ 88 | ('features/Conv2d_6_pointwise/BatchNorm/beta','model.12.1.bias' ), \ 89 | ('features/Conv2d_6_pointwise/BatchNorm/moving_mean','model.12.1.running_mean' ), \ 90 | ('features/Conv2d_6_pointwise/BatchNorm/moving_variance','model.12.1.running_var' ), \ 91 | ('features/Conv2d_6_pointwise/BatchNorm/gamma','model.12.1.weight' ), \ 92 | ('features/Conv2d_7_depthwise/depthwise_weights','model.13.0.weight' ), \ 93 | ('features/Conv2d_7_depthwise/BatchNorm/beta','model.13.1.bias' ), \ 94 | ('features/Conv2d_7_depthwise/BatchNorm/moving_mean','model.13.1.running_mean' ), \ 95 | ('features/Conv2d_7_depthwise/BatchNorm/moving_variance','model.13.1.running_var' ), \ 96 | ('features/Conv2d_7_depthwise/BatchNorm/gamma','model.13.1.weight' ), \ 97 | ('features/Conv2d_7_pointwise/weights','model.14.0.weight' ), \ 98 | ('features/Conv2d_7_pointwise/BatchNorm/beta','model.14.1.bias' ), \ 99 | ('features/Conv2d_7_pointwise/BatchNorm/moving_mean','model.14.1.running_mean' ), \ 100 | ('features/Conv2d_7_pointwise/BatchNorm/moving_variance','model.14.1.running_var' ), \ 101 | ('features/Conv2d_7_pointwise/BatchNorm/gamma','model.14.1.weight' ), \ 102 | ('features/Conv2d_8_depthwise/depthwise_weights','model.15.0.weight' ), \ 103 | ('features/Conv2d_8_depthwise/BatchNorm/beta','model.15.1.bias' ), \ 104 | ('features/Conv2d_8_depthwise/BatchNorm/moving_mean','model.15.1.running_mean' ), \ 105 | ('features/Conv2d_8_depthwise/BatchNorm/moving_variance','model.15.1.running_var' ), \ 106 | ('features/Conv2d_8_depthwise/BatchNorm/gamma','model.15.1.weight' ), \ 107 | ('features/Conv2d_8_pointwise/weights','model.16.0.weight' ), \ 108 | ('features/Conv2d_8_pointwise/BatchNorm/beta','model.16.1.bias' ), \ 
109 | ('features/Conv2d_8_pointwise/BatchNorm/moving_mean','model.16.1.running_mean' ), \ 110 | ('features/Conv2d_8_pointwise/BatchNorm/moving_variance','model.16.1.running_var' ), \ 111 | ('features/Conv2d_8_pointwise/BatchNorm/gamma','model.16.1.weight' ), \ 112 | ('features/Conv2d_9_depthwise/depthwise_weights','model.17.0.weight' ), \ 113 | ('features/Conv2d_9_depthwise/BatchNorm/beta','model.17.1.bias' ), \ 114 | ('features/Conv2d_9_depthwise/BatchNorm/moving_mean','model.17.1.running_mean' ), \ 115 | ('features/Conv2d_9_depthwise/BatchNorm/moving_variance','model.17.1.running_var' ), \ 116 | ('features/Conv2d_9_depthwise/BatchNorm/gamma','model.17.1.weight' ), \ 117 | ('features/Conv2d_9_pointwise/weights','model.18.0.weight' ), \ 118 | ('features/Conv2d_9_pointwise/BatchNorm/beta','model.18.1.bias' ), \ 119 | ('features/Conv2d_9_pointwise/BatchNorm/moving_mean','model.18.1.running_mean' ), \ 120 | ('features/Conv2d_9_pointwise/BatchNorm/moving_variance','model.18.1.running_var' ), \ 121 | ('features/Conv2d_9_pointwise/BatchNorm/gamma','model.18.1.weight' ), \ 122 | ('features/Conv2d_10_depthwise/depthwise_weights','model.19.0.weight' ), \ 123 | ('features/Conv2d_10_depthwise/BatchNorm/beta','model.19.1.bias' ), \ 124 | ('features/Conv2d_10_depthwise/BatchNorm/moving_mean','model.19.1.running_mean' ), \ 125 | ('features/Conv2d_10_depthwise/BatchNorm/moving_variance','model.19.1.running_var' ), \ 126 | ('features/Conv2d_10_depthwise/BatchNorm/gamma','model.19.1.weight' ), \ 127 | ('features/Conv2d_10_pointwise/weights','model.20.0.weight' ), \ 128 | ('features/Conv2d_10_pointwise/BatchNorm/beta','model.20.1.bias' ), \ 129 | ('features/Conv2d_10_pointwise/BatchNorm/moving_mean','model.20.1.running_mean' ), \ 130 | ('features/Conv2d_10_pointwise/BatchNorm/moving_variance','model.20.1.running_var' ), \ 131 | ('features/Conv2d_10_pointwise/BatchNorm/gamma','model.20.1.weight' ), \ 132 | ('features/Conv2d_11_depthwise/depthwise_weights','model.21.0.weight' ), \ 133 | ('features/Conv2d_11_depthwise/BatchNorm/beta','model.21.1.bias' ), \ 134 | ('features/Conv2d_11_depthwise/BatchNorm/moving_mean','model.21.1.running_mean' ), \ 135 | ('features/Conv2d_11_depthwise/BatchNorm/moving_variance','model.21.1.running_var' ), \ 136 | ('features/Conv2d_11_depthwise/BatchNorm/gamma','model.21.1.weight' ), \ 137 | ('features/Conv2d_11_pointwise/weights','model.22.0.weight' ), \ 138 | ('features/Conv2d_11_pointwise/BatchNorm/beta','model.22.1.bias' ), \ 139 | ('features/Conv2d_11_pointwise/BatchNorm/moving_mean','model.22.1.running_mean' ), \ 140 | ('features/Conv2d_11_pointwise/BatchNorm/moving_variance','model.22.1.running_var' ), \ 141 | ('features/Conv2d_11_pointwise/BatchNorm/gamma','model.22.1.weight' ), \ 142 | ('features/Conv2d_12_depthwise/depthwise_weights','model.23.0.weight' ), \ 143 | ('features/Conv2d_12_depthwise/BatchNorm/beta','model.23.1.bias' ), \ 144 | ('features/Conv2d_12_depthwise/BatchNorm/moving_mean','model.23.1.running_mean' ), \ 145 | ('features/Conv2d_12_depthwise/BatchNorm/moving_variance','model.23.1.running_var' ), \ 146 | ('features/Conv2d_12_depthwise/BatchNorm/gamma','model.23.1.weight' ), \ 147 | ('features/Conv2d_12_pointwise/weights','model.24.0.weight' ), \ 148 | ('features/Conv2d_12_pointwise/BatchNorm/beta','model.24.1.bias' ), \ 149 | ('features/Conv2d_12_pointwise/BatchNorm/moving_mean','model.24.1.running_mean' ), \ 150 | ('features/Conv2d_12_pointwise/BatchNorm/moving_variance','model.24.1.running_var' ), \ 151 | 
('features/Conv2d_12_pointwise/BatchNorm/gamma','model.24.1.weight' ), \ 152 | ('features/Conv2d_13_depthwise/depthwise_weights','model.25.0.weight' ), \ 153 | ('features/Conv2d_13_depthwise/BatchNorm/beta','model.25.1.bias' ), \ 154 | ('features/Conv2d_13_depthwise/BatchNorm/moving_mean','model.25.1.running_mean' ), \ 155 | ('features/Conv2d_13_depthwise/BatchNorm/moving_variance','model.25.1.running_var' ), \ 156 | ('features/Conv2d_13_depthwise/BatchNorm/gamma','model.25.1.weight' ), \ 157 | ('features/Conv2d_13_pointwise/weights','model.26.0.weight' ), \ 158 | ('features/Conv2d_13_pointwise/BatchNorm/beta','model.26.1.bias' ), \ 159 | ('features/Conv2d_13_pointwise/BatchNorm/moving_mean','model.26.1.running_mean' ), \ 160 | ('features/Conv2d_13_pointwise/BatchNorm/moving_variance','model.26.1.running_var' ), \ 161 | ('features/Conv2d_13_pointwise/BatchNorm/gamma','model.26.1.weight' ), \ 162 | ]) 163 | 164 | 165 | def path_pretrained_tf(input_size, depth_multiplier): 166 | folder_path = './models/mobilenet_tf/' 167 | 168 | if input_size == 224: 169 | if depth_multiplier == 1.0: 170 | tensorflow_model='224_1.0/mobilenet_v1_1.0_224.ckpt' 171 | elif depth_multiplier == 0.75: 172 | tensorflow_model='224_0.75/mobilenet_v1_0.75_224.ckpt' 173 | elif depth_multiplier == 0.5: 174 | tensorflow_model='224_0.5/mobilenet_v1_0.5_224.ckpt' 175 | elif depth_multiplier == 0.25: 176 | tensorflow_model='224_0.25/mobilenet_v1_0.25_224.ckpt' 177 | elif input_size == 192: 178 | if depth_multiplier == 1.0: 179 | tensorflow_model='192_1.0/mobilenet_v1_1.0_192.ckpt' 180 | elif depth_multiplier == 0.75: 181 | tensorflow_model='192_0.75/mobilenet_v1_0.75_192.ckpt' 182 | elif depth_multiplier == 0.5: 183 | tensorflow_model='192_0.5/mobilenet_v1_0.5_192.ckpt' 184 | elif depth_multiplier == 0.25: 185 | tensorflow_model='192_0.25/mobilenet_v1_0.25_192.ckpt' 186 | elif input_size == 160: 187 | if depth_multiplier == 1.0: 188 | tensorflow_model='160_1.0/mobilenet_v1_1.0_160.ckpt' 189 | elif depth_multiplier == 0.75: 190 | tensorflow_model='160_0.75/mobilenet_v1_0.75_160.ckpt' 191 | elif depth_multiplier == 0.5: 192 | tensorflow_model='160_0.5/mobilenet_v1_0.5_160.ckpt' 193 | elif depth_multiplier == 0.25: 194 | tensorflow_model='160_0.25/mobilenet_v1_0.25_160.ckpt' 195 | elif input_size == 128: 196 | if depth_multiplier == 1.0: 197 | tensorflow_model='128_1.0/mobilenet_v1_1.0_128.ckpt' 198 | elif depth_multiplier == 0.75: 199 | tensorflow_model='128_0.75/mobilenet_v1_0.75_128.ckpt' 200 | elif depth_multiplier == 0.5: 201 | tensorflow_model='128_0.5/mobilenet_v1_0.5_128.ckpt' 202 | elif depth_multiplier == 0.25: 203 | tensorflow_model='128_0.25/mobilenet_v1_0.25_128.ckpt' 204 | 205 | return folder_path+tensorflow_model 206 | 207 | 208 | 209 | def preload_mobilenet_tf(model, input_size, depth_multiplier): 210 | 211 | #Read TensorFlow Model 212 | tensorflow_model = path_pretrained_tf(input_size, depth_multiplier) 213 | reader = pywrap_tensorflow.NewCheckpointReader(tensorflow_model) 214 | var_to_shape_map = reader.get_variable_to_shape_map() 215 | var_dict = {k:reader.get_tensor(k) for k in var_to_shape_map.keys()} 216 | 217 | # Take PyTorch Model 218 | x = model.state_dict() 219 | 220 | # Filtering Tensorflow params from the model 221 | for k in list(var_dict.keys()): 222 | if var_dict[k].ndim == 4: 223 | if 'depthwise' in k: 224 | var_dict[k] = var_dict[k].transpose((2, 3, 0, 1)).copy(order='C') 225 | else: 226 | var_dict[k] = var_dict[k].transpose((3, 2, 0, 1)).copy(order='C') 227 | if var_dict[k].ndim == 2: 228 | 
var_dict[k] = var_dict[k].transpose((1, 0)).copy(order='C') 229 | 230 | for k in list(var_dict.keys()): 231 | if 'Momentum' in k or 'ExponentialMovingAverage' in k or 'RMSProp' in k or 'global_step' in k : 232 | del var_dict[k] 233 | 234 | for k in list(var_dict.keys()): 235 | if k.find('/') >= 0: 236 | var_dict['features'+k[k.find('/'):]] = var_dict[k] 237 | del var_dict[k] 238 | 239 | # Adapt from tensorflow to pytorch 240 | for a, b in dummy_replace.items(): 241 | for k in list(var_dict.keys()): 242 | if a in k: 243 | var_dict[k.replace(a,b)] = var_dict[k] 244 | del var_dict[k] 245 | 246 | # print('In var_dict but not in x_dict') 247 | # print(set(var_dict.keys()) - set(x.keys())) 248 | # print('In x_dict but not in var_dict') 249 | # print(set(x.keys()) - set(var_dict.keys())) 250 | for k in set(x.keys()) - set(var_dict.keys()): 251 | del x[k] 252 | 253 | assert len(set(x.keys()) - set(var_dict.keys())) == 0 254 | for k in set(var_dict.keys()) - set(x.keys()): 255 | del var_dict[k] 256 | 257 | for k in list(var_dict.keys()): 258 | if x[k].shape != var_dict[k].shape: 259 | print(k, 'Error') 260 | 261 | for k in list(var_dict.keys()): 262 | var_dict[k] = torch.from_numpy(var_dict[k]) 263 | 264 | # remove 1001-th class 265 | org_weight = var_dict['fc.weight'].clone() 266 | org_bias = var_dict['fc.bias'].clone() 267 | START = 1 268 | var_dict['fc.weight'] = org_weight.narrow(0,START,1000).squeeze() 269 | var_dict['fc.bias'] = org_bias.narrow(0,START,1000) 270 | 271 | # load model 272 | model.load_state_dict(var_dict, strict=False) 273 | 274 | 275 | print('Model Pretrained Loaded') -------------------------------------------------------------------------------- /quantization/quant_auto.py: -------------------------------------------------------------------------------- 1 | # 2 | # quant_auto.py 3 | # Manuele Rusci 4 | # 5 | # Copyright (C) 2019 University of Bologna 6 | # All rights reserved. 
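# A minimal, illustrative sketch (the helper below is hypothetical and not used anywhere
# in this repo) of the layout conversion performed by preload_mobilenet_tf() above:
# TensorFlow stores regular conv kernels as [H, W, in, out], depthwise kernels as
# [H, W, ch, multiplier] and fully connected weights as [in, out], while the PyTorch
# model expects [out, in, H, W], [ch, multiplier, H, W] and [out, in]. Shapes are examples only.
def _tf_to_pytorch_layout_sketch():
    import numpy as np
    tf_conv = np.zeros((3, 3, 32, 64))           # TF conv kernel: [H, W, in, out]
    tf_depthwise = np.zeros((3, 3, 32, 1))       # TF depthwise kernel: [H, W, ch, mult]
    tf_fc = np.zeros((1024, 1001))               # TF fully connected: [in, classes]
    pt_conv = tf_conv.transpose((3, 2, 0, 1))             # -> (64, 32, 3, 3)
    pt_depthwise = tf_depthwise.transpose((2, 3, 0, 1))   # -> (32, 1, 3, 3)
    pt_fc = tf_fc.transpose((1, 0))                       # -> (1001, 1024)
    return pt_conv.shape, pt_depthwise.shape, pt_fc.shape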
7 | # 8 | 9 | import torch, math 10 | import torch.nn as nn 11 | import numpy as np 12 | from scipy.linalg import lstsq 13 | from torch.autograd import Variable 14 | import copy 15 | 16 | import sys 17 | sys.path.append('../') 18 | from models.linear_quantized_modules import Conv2d_SAME 19 | 20 | 21 | def memory_driven_quant(model, x, ONLY_READ_MEM = 512*1024, READWRITE_MEM = 64*1024, cfg_quant='MixPL' ): 22 | 23 | #################################################################################### 24 | ###################### compute activation footprint ################################ 25 | #################################################################################### 26 | def compute_activation_footprint(param_list): 27 | i_params_mem = [] 28 | i_bits_mem = [] 29 | o_params_mem = [] 30 | o_bits_mem = [] 31 | 32 | for i,item in enumerate(param_list): 33 | o_size = activation_vector_o[i] 34 | i_size = activation_vector_i[i] 35 | o_params = 1 36 | for v in range(len(o_size)): 37 | o_params *= o_size[v] 38 | i_params = 1 39 | for v in range(len(i_size)): 40 | i_params *= i_size[v] 41 | 42 | i_bits = param_list[i]['act_i_bits'] 43 | o_bits = param_list[i]['act_o_bits'] 44 | 45 | i_params_mem.append(i_params) 46 | i_bits_mem.append(i_bits) 47 | o_params_mem.append(o_params) 48 | o_bits_mem.append(o_bits) 49 | 50 | #print('Input = ', i_params,'(bits = ', i_bits, ') | Output Params = ', o_params,'(bits = ', o_bits,')') 51 | 52 | return i_params_mem, i_bits_mem, o_params_mem, o_bits_mem 53 | 54 | def cut_activation_footprint(param_list, act_mem, MIN_ACT_BITS=2 ): 55 | 56 | i_params_mem, i_bits_mem, o_params_mem, o_bits_mem = compute_activation_footprint(param_list) 57 | LAST = len(o_bits_mem)-1 58 | 59 | errQuant = False 60 | 61 | completeQuant = False 62 | 63 | #for i in range(len(i_params_mem)): 64 | #start from layer 0 in forward direction 65 | i = 0 66 | forward = True 67 | okQuantNet = False 68 | 69 | print('Init: ', i_bits_mem, o_bits_mem) 70 | 71 | while completeQuant is False: 72 | 73 | okQuant = True # flag is set to False if one of the layer quantization results uncomplete during forward or backward 74 | 75 | print('*********** Layer ',i, 'Start ************** | completeQuant is ', completeQuant,'| okQuantNet is ',okQuantNet) 76 | tot_layer_mem = ((i_params_mem[i] * i_bits_mem[i]) + (o_params_mem[i] * o_bits_mem[i]) )/8 # to bytes 77 | n_iter = 1 78 | 79 | while tot_layer_mem > act_mem and okQuant == True: 80 | 81 | print('Iteration ',n_iter, 'with bits: ',i_bits_mem[i],o_bits_mem[i],'with forward as ', forward) 82 | 83 | i_mem = i_params_mem[i] * i_bits_mem[i] 84 | o_mem = o_params_mem[i] * o_bits_mem[i] 85 | 86 | if i == 0: # first layer 87 | 88 | if forward == True: 89 | if o_bits_mem[i] > MIN_ACT_BITS: #forward 90 | param_list[i]['act_o_bits'] = int(o_bits_mem[i] / 2) 91 | param_list[i+1]['act_i_bits'] = param_list[i]['act_o_bits'] 92 | else: #error 93 | errQuant = True 94 | print('No way to find a solution because of the first layer!') 95 | break 96 | else: 97 | 98 | if o_bits_mem[i] > MIN_ACT_BITS: #forward 99 | okQuant = False 100 | 101 | else: #error 102 | errQuant = True 103 | print('No way to find a solution because of the first layer!') 104 | break 105 | 106 | 107 | elif i == LAST: # last layer 108 | 109 | if forward == True: 110 | 111 | if i_bits_mem[i] > MIN_ACT_BITS: #backward 112 | okQuant = False 113 | 114 | else: #error 115 | errQuant = True 116 | print('No way to find a solution because of the last layer!') 117 | break 118 | else: 119 | 120 | if i_bits_mem[i] > 
MIN_ACT_BITS: #backward 121 | param_list[i-1]['act_o_bits'] = int(i_bits_mem[i] / 2) 122 | param_list[i]['act_i_bits'] = param_list[i-1]['act_o_bits'] 123 | 124 | else: #error 125 | errQuant = True 126 | print('No way to find a solution because of the last layer!') 127 | break 128 | else: 129 | 130 | if forward == True: 131 | 132 | print('We are in forward!') 133 | 134 | if i_bits_mem[i] == MIN_ACT_BITS and o_bits_mem[i] == MIN_ACT_BITS : #error 135 | errQuant = True 136 | print('No way to find a solution: layer ',i,'with i_bits,o_bits =',MIN_ACT_BITS) 137 | break 138 | 139 | elif i_bits_mem[i] > o_bits_mem[i] and i_bits_mem[i] > MIN_ACT_BITS: #backward 140 | okQuant = False 141 | 142 | 143 | elif i_bits_mem[i] < o_bits_mem[i] and o_bits_mem[i] > MIN_ACT_BITS: #forward 144 | param_list[i]['act_o_bits'] = int(o_bits_mem[i] / 2) 145 | param_list[i+1]['act_i_bits'] = param_list[i]['act_o_bits'] 146 | 147 | 148 | elif i_bits_mem[i] == o_bits_mem[i]: 149 | if i_mem > o_mem and i_bits_mem[i] > MIN_ACT_BITS : #backward 150 | okQuant = False 151 | 152 | elif o_bits_mem[i] > MIN_ACT_BITS: #forward 153 | param_list[i]['act_o_bits'] = int(o_bits_mem[i] / 2) 154 | param_list[i+1]['act_i_bits'] = param_list[i]['act_o_bits'] 155 | 156 | else: #error 157 | errQuant = True 158 | print('Corner case!') 159 | break 160 | 161 | elif o_bits_mem[i] > MIN_ACT_BITS : #forward 162 | param_list[i]['act_o_bits'] = int(o_bits_mem[i] / 2) 163 | param_list[i+1]['act_i_bits'] = param_list[i]['act_o_bits'] 164 | 165 | elif i_bits_mem[i] > MIN_ACT_BITS : #backward 166 | okQuant = False 167 | 168 | else: #error 169 | errQuant = True 170 | print('No way to find a solution!') 171 | break 172 | 173 | else: 174 | print('We are in backward!') 175 | 176 | if i_bits_mem[i] == MIN_ACT_BITS and o_bits_mem[i] == MIN_ACT_BITS : #error 177 | errQuant = True 178 | print('No way to find a solution: layer ',i,'with i_bits,o_bits =',MIN_ACT_BITS) 179 | break 180 | 181 | elif i_bits_mem[i] > o_bits_mem[i] and i_bits_mem[i] > MIN_ACT_BITS: #backward 182 | param_list[i-1]['act_o_bits'] = int(i_bits_mem[i] / 2) 183 | param_list[i]['act_i_bits'] = param_list[i-1]['act_o_bits'] 184 | 185 | 186 | elif i_bits_mem[i] < o_bits_mem[i] and o_bits_mem[i] > MIN_ACT_BITS: #forward 187 | okQuant = False 188 | 189 | 190 | elif i_bits_mem[i] == o_bits_mem[i]: 191 | if i_mem > o_mem and i_bits_mem[i] > MIN_ACT_BITS : #backward 192 | param_list[i-1]['act_o_bits'] = int(i_bits_mem[i] / 2) 193 | param_list[i]['act_i_bits'] = param_list[i-1]['act_o_bits'] 194 | 195 | elif o_bits_mem[i] > MIN_ACT_BITS: #forward 196 | okQuant = False 197 | 198 | else: #error 199 | errQuant = True 200 | print('Corner case!') 201 | break 202 | 203 | elif o_bits_mem[i] > MIN_ACT_BITS : #forward 204 | okQuant = False 205 | 206 | elif i_bits_mem[i] > MIN_ACT_BITS : #backward 207 | param_list[i-1]['act_o_bits'] = int(i_bits_mem[i] / 2) 208 | param_list[i]['act_i_bits'] = param_list[i-1]['act_o_bits'] 209 | 210 | else: #error 211 | errQuant = True 212 | print('No way to find a solution!') 213 | break 214 | 215 | 216 | i_params_mem, i_bits_mem, o_params_mem, o_bits_mem = compute_activation_footprint(param_list) 217 | tot_layer_mem = ((i_params_mem[i] * i_bits_mem[i]) + (o_params_mem[i] * o_bits_mem[i]) )/ 8 # to bytes 218 | n_iter +=1 219 | 220 | else: 221 | 222 | if okQuant == True: 223 | print('Layer ',i,' is OK - \t Total Mem:', tot_layer_mem, 'with bits: ',i_bits_mem[i], o_bits_mem[i],i_params_mem[i],o_params_mem[i] ) 224 | else: 225 | print('Layer ',i,' is NOK - \t') 226 | 
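                    # okQuant was cleared above because this layer needs a bit-width
                    # reduction that can only be applied by the sweep running in the
                    # opposite direction; clearing okQuantNet forces another full pass.
                    # The search stops only once an entire forward or backward pass
                    # leaves every layer within the act_mem budget.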
okQuantNet = False 227 | 228 | if forward == True: 229 | if i == LAST: 230 | if okQuantNet == True: 231 | completeQuant = True # stop iteration 232 | else: 233 | forward = False # continue in opposite direction 234 | okQuantNet = True # reset of okQuantNet 235 | else: 236 | i += 1 237 | else: 238 | if i == 0: 239 | if okQuantNet == True: 240 | completeQuant = True # stop iteration 241 | else: 242 | forward = True # continue in opposite direction 243 | okQuantNet = True # reset of okQuantNet 244 | else: 245 | i -= 1 246 | 247 | 248 | # final print 249 | i_params_mem, i_bits_mem, o_params_mem, o_bits_mem = compute_activation_footprint(param_list) 250 | print('Input: ',i_bits_mem,'\n Output:', o_bits_mem) 251 | if errQuant is True: 252 | print('No way for quantization') 253 | return -1 254 | 255 | 256 | #################################################################################### 257 | ###################### compute params footprint #################################### 258 | #################################################################################### 259 | def compute_params_footprint(param_list): 260 | LAST = len(param_list) 261 | weight_footprint = [0 for x in range(LAST)] 262 | weight_bits = [0 for x in range(LAST)] 263 | for i,item in enumerate(param_list): 264 | #number of bits 265 | w_bits = item['w_bits'] 266 | 267 | # size of convolutional parameters 268 | len_w = len(item['weight'].size()) 269 | size_param = 1 270 | for v in range(len_w): 271 | size_param *= item['weight'].size(v) 272 | w_footprint = (size_param * w_bits) / 8 #measure in bytes 273 | weight_footprint[i] = w_footprint 274 | weight_bits[i] = w_bits 275 | 276 | return weight_footprint,sum(weight_footprint), weight_bits 277 | 278 | 279 | def compute_bias_footprint(param_list, cfg_quant='MixPL'): 280 | BIAS_PerLayerAsymPACT_BITS = 32 281 | BIAS_CH_PerLayerAsymPACT_BITS = 16 282 | BIAS_CH_ZW_BITS = 8 283 | BIAS_M0_BITS = 32 284 | BIAS_N0_BITS = 8 285 | 286 | LAST = len(param_list) 287 | weight_footprint = [0 for x in range(LAST)] 288 | for i,item in enumerate(param_list): 289 | 290 | #number of bits 291 | w_bits = item['w_bits'] 292 | n_out_ch = item['weight'].size(0) 293 | a_o_bits = item['act_o_bits'] 294 | # size of extra parametes 295 | quant_type = item['quant_type'] 296 | fold_type = item['fold_type'] 297 | 298 | if cfg_quant == 'MixPL': 299 | if quant_type is None: 300 | quant_type = 'PerLayerAsymPACT' 301 | fold_type = 'folding_weights' 302 | elif w_bits < 8: 303 | quant_type = 'PerLayerAsymPACT' 304 | fold_type = 'ICN' 305 | elif a_o_bits < 8: 306 | fold_type = 'ICN' 307 | quant_type = 'PerLayerAsymPACT' 308 | else: 309 | if quant_type is None: 310 | quant_type = 'PerChannelsAsymMinMax' 311 | fold_type = 'ICN' 312 | elif w_bits < 8: 313 | quant_type = 'PerChannelsAsymMinMax' 314 | fold_type = 'ICN' 315 | elif a_o_bits < 8: 316 | fold_type = 'ICN' 317 | quant_type = 'PerChannelsAsymMinMax' 318 | 319 | item['quant_type'] = quant_type 320 | item['fold_type'] = fold_type 321 | 322 | # compute number of bias 323 | if i == LAST-1: 324 | if quant_type == 'PerLayerAsymPACT': 325 | bias_size = int(n_out_ch * BIAS_PerLayerAsymPACT_BITS / 8 ) # bias 326 | bias_size += int(5 * BIAS_PerLayerAsymPACT_BITS / 8 ) # Zi,Zo,Zw,M0,N0 per layer 327 | elif quant_type == 'PerChannelsAsymMinMax': 328 | bias_size = int(n_out_ch * BIAS_PerLayerAsymPACT_BITS / 8 ) # bias 329 | bias_size += int(n_out_ch * BIAS_CH_ZW_BITS / 8 ) # Zw per channel 330 | bias_size += int(1 * BIAS_PerLayerAsymPACT_BITS / 8 ) # Zi,Zo per layer 331 | 
bias_size += int(2*n_out_ch * BIAS_CH_ZW_BITS / 8 ) # M0,N0 last layer 332 | 333 | elif i < LAST-1: 334 | if quant_type == 'PerLayerAsymPACT': 335 | if fold_type == 'folding_weights': 336 | bias_size = int(n_out_ch * BIAS_PerLayerAsymPACT_BITS / 8 ) # bias 337 | bias_size += int(5 * BIAS_PerLayerAsymPACT_BITS / 8 ) # Zi,Zo,Zw,M0,N0 per layer 338 | elif fold_type == 'ICN': 339 | bias_size = int(n_out_ch * BIAS_M0_BITS / 8 ) # M0 per channel 340 | bias_size += int(n_out_ch * BIAS_N0_BITS / 8 ) # N0 per channel 341 | bias_size += int(n_out_ch * BIAS_PerLayerAsymPACT_BITS / 8 ) # bias 342 | bias_size += int(3 * BIAS_PerLayerAsymPACT_BITS / 8 ) # Zw,Zi,Zo per layer 343 | else: 344 | print('Error!') 345 | exit(0) 346 | 347 | elif quant_type == 'PerChannelsAsymMinMax': 348 | bias_size = int(n_out_ch * BIAS_M0_BITS / 8 ) # M0 per channel 349 | bias_size += int(n_out_ch * BIAS_N0_BITS / 8 ) # N0 per channel 350 | bias_size += int(n_out_ch * BIAS_PerLayerAsymPACT_BITS / 8 ) # bias 351 | bias_size += int(n_out_ch * BIAS_CH_ZW_BITS / 8 ) # Zw per channel 352 | bias_size += int(2 * BIAS_PerLayerAsymPACT_BITS / 8 ) # Zi,Zo per layer 353 | 354 | 355 | else: 356 | print('ERROR!!!') 357 | 358 | weight_footprint[i] = bias_size 359 | 360 | return weight_footprint 361 | 362 | def compute_footprint(param_list, cfg_quant='MixPL'): 363 | weight_footprint, _, _ = compute_params_footprint(param_list) 364 | bias_footprint = compute_bias_footprint(param_list, cfg_quant) 365 | ratio = np.array(bias_footprint) / np.array(weight_footprint) 366 | weight_footprint = sum(weight_footprint) 367 | bias_footprint = sum(bias_footprint) 368 | print('Total Weights: ',int(weight_footprint/1024),'kbytes | bias =' ,int(bias_footprint/1024),'kbytes ('\ 369 | ,100*bias_footprint/weight_footprint ,'%)') 370 | return bias_footprint+weight_footprint 371 | 372 | 373 | def compute_next_cut(param_list, MAX = 0.05, MIN_BIT=2): 374 | 375 | f,_,w_b = compute_params_footprint(param_list) 376 | weight = np.array(f) 377 | total = weight.sum() 378 | perc = weight / total 379 | arr = np.sort(perc)[::-1] 380 | ind = np.argsort(perc)[::-1] 381 | 382 | # remove w_bit = 1 from search 383 | rm = [] 384 | for i,item in enumerate(w_b): 385 | if item == MIN_BIT: 386 | for i_x,x in enumerate(ind): 387 | if x == i: 388 | rm.append(i_x) 389 | continue 390 | 391 | arr = np.delete(arr, rm) 392 | ind = np.delete(ind, rm) 393 | 394 | if len(arr) == 0: 395 | return -1 396 | 397 | #print(arr, ind) 398 | LAST = len(arr) 399 | 400 | cut_we = arr[0] 401 | cut_i = ind[0] 402 | print('most footprint on layer: ',cut_i, 'equal to', cut_we, 'ratio' ) 403 | thr = cut_we - MAX 404 | for i in range(1, LAST): 405 | if arr[i] > thr and ind[i] < cut_i: 406 | cut_i = ind[i] 407 | print(ind[i],arr[i]) 408 | 409 | return cut_i 410 | 411 | 412 | def weight_quantization(param_list, read_only_mem, cfg_quant='MixPL'): 413 | 414 | s = compute_footprint(param_list, cfg_quant) 415 | print('Total Footprint: ', s ) 416 | 417 | while (s > read_only_mem): 418 | 419 | print('***************************************************************') 420 | c = compute_next_cut(param_list) 421 | if c == -1: 422 | print('No way for weight quantization') 423 | return -1 424 | 425 | nb = int(param_list[c]['w_bits'] / 2) 426 | print('Layer to cut: ', c, 'to ', nb ,'bits') 427 | param_list[c]['w_bits'] = nb 428 | s = compute_footprint(param_list, cfg_quant) 429 | 430 | return(param_list) 431 | 432 | #################################################################################### 433 | ############ parse 
full precision network to map convolutional layers ############## 434 | #################################################################################### 435 | def print_size(model, input, output): 436 | global si, so 437 | si = input[0].size() 438 | so = output[0].size() 439 | 440 | activation_vector_i = [] 441 | activation_vector_o = [] 442 | param_list = [] 443 | model.eval() 444 | 445 | for i,module in enumerate(model.modules()): 446 | if type(module) in [ Conv2d_SAME, nn.Conv2d]: 447 | print('Conv2d (Same) =: ',i) 448 | hook = module.register_forward_hook(print_size) 449 | model(x) 450 | hook.remove() 451 | activation_vector_i.append(si) 452 | activation_vector_o.append(so) 453 | param_list.append({'act_o_bits':8, 'act_i_bits':8, 'w_bits':8 , 'weight': module.weight.data, 'quant_type': None, 'fold_type': None}) 454 | 455 | elif type(module) is nn.Linear: 456 | print('Linear =: ',i) 457 | hook = module.register_forward_hook(print_size) 458 | model(x) 459 | hook.remove() 460 | activation_vector_i.append(si) 461 | activation_vector_o.append(so) 462 | param_list.append({'act_o_bits':8, 'act_i_bits':8, 'w_bits':8 , 'weight': module.weight.data, 'quant_type': None, 'fold_type': None}) 463 | 464 | LAST = len(activation_vector_i)-1 465 | 466 | for i,item in enumerate(activation_vector_i): 467 | print('Size of layer ', i , ': ', item) 468 | 469 | 470 | #################################################################################### 471 | ###################### cut weight & activation bits ################################ 472 | #################################################################################### 473 | ret = cut_activation_footprint(param_list, READWRITE_MEM) 474 | if ret == -1: 475 | return -1 476 | ret = weight_quantization(param_list, ONLY_READ_MEM, cfg_quant = cfg_quant) 477 | if ret == -1: 478 | return -1 479 | 480 | 481 | add_config = [] 482 | LAST = len(param_list)-1 483 | for i,item in enumerate(param_list): 484 | print('***** Layer ', i , '**********') 485 | print('| i_bits : ',item['act_i_bits'], '\t| o_bits : ',item['act_o_bits'],'\t| w_bits' ,item['w_bits'],'\t| quant_type : ',item['quant_type'],'\t| fold_type : ',item['fold_type']) 486 | if i == LAST: 487 | add_config.append({'layer':i, 'w_bits': item['w_bits'], 'quant_type': item['quant_type'] }) 488 | else: 489 | add_config.append({'layer':i, 'w_bits': item['w_bits'], 'quant_type': item['quant_type'] , 'fold_type': item['fold_type'], 'a_bits': item['act_o_bits'] }) 490 | 491 | 492 | return add_config -------------------------------------------------------------------------------- /main_binary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import models 12 | from torch.autograd import Variable 13 | from data import get_dataset, get_num_classes 14 | from preprocess import get_transform 15 | from utils import * 16 | from datetime import datetime 17 | from ast import literal_eval 18 | import json 19 | from torchvision.utils import save_image 20 | import quantization 21 | from quantization.quant_auto import memory_driven_quant 22 | 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = 
argparse.ArgumentParser(description='PyTorch ConvNet Training') 28 | 29 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', 30 | help='results dir') 31 | parser.add_argument('--save', metavar='SAVE', default='', 32 | help='saved folder') 33 | parser.add_argument('--dataset', metavar='DATASET', default='cifar10', 34 | help='dataset name or folder') 35 | parser.add_argument('--model', '-a', metavar='MODEL', default='vgg_cifar10_binary', 36 | choices=model_names, 37 | help='model architecture: ' + 38 | ' | '.join(model_names) + 39 | ' (default: vgg_cifar10_binary)') 40 | parser.add_argument('--input_size', type=int, default=None, 41 | help='image input size') 42 | parser.add_argument('--model_config', default='', 43 | help='additional architecture configuration') 44 | parser.add_argument('--type', default='torch.cuda.FloatTensor', 45 | help='type of tensor - e.g. torch.cuda.HalfTensor') 46 | parser.add_argument('--gpus', default='0,1,2,3', 47 | help='gpus used for training - e.g. 0,1,3') 48 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 49 | help='number of data loading workers (default: 8)') 50 | parser.add_argument('--epochs', default=150, type=int, metavar='N', 51 | help='number of total epochs to run') 52 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 53 | help='manual epoch number (useful on restarts)') 54 | parser.add_argument('-b', '--batch-size', default=256, type=int, 55 | metavar='N', help='mini-batch size (default: 256)') 56 | parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', 57 | help='optimizer function used') 58 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, 59 | metavar='LR', help='initial learning rate') 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum') 62 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 63 | metavar='W', help='weight decay (default: 1e-4)') 64 | parser.add_argument('--print-freq', '-p', default=100, type=int, 65 | metavar='N', help='print frequency (default: 100)') 66 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 67 | help='path to latest checkpoint (default: none)') 68 | parser.add_argument('-e', '--evaluate', action='store_true', 69 | help='run model on validation set') 70 | parser.add_argument('--save_check', action='store_true', 71 | help='save the checkpoint') 72 | #binarization parameters 73 | parser.add_argument('--quantizer', action='store_true', 74 | help='use the quantizer flow') 75 | parser.add_argument('--type_quant', default=None, 76 | help='Type of binarization process') 77 | parser.add_argument('--weight_bits', default=1, 78 | help='Number of bits for the weights') 79 | parser.add_argument('--activ_bits', default=1, 80 | help='Number of bits for the activations') 81 | parser.add_argument('--activ_type', default='hardtanh', 82 | help='Type of the quantized activation layers') 83 | 84 | parser.add_argument('--batch_fold_delay', default=0, type=int, 85 | help='Number of epochs before folding batch norm layers into convolutions') 86 | parser.add_argument('--batch_fold_type', default='folding_weights', type=str, 87 | help='Type of folding for batch norm layers: folding_weights | ICN') 88 | parser.add_argument('--quant_add_config', default='', type=str, 89 | help='Additional config of per-layer quantization') 90 | 91 | #mobilenet params 92 | parser.add_argument('--mobilenet_width', default=1.0, type=float, 93 | help='Mobilenet Width Multiplier')
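# A hedged usage sketch (the helper below is illustrative and not referenced elsewhere)
# of how the --mem_constraint string defined just below is consumed in main(): it is
# parsed with json.loads() into a two-element list, where the first entry is the
# read-only (flash) budget for weights and the second the read-write (RAM) budget for
# activations, both in bytes. The concrete numbers are examples only.
def _mem_constraint_sketch():
    import json
    budgets = json.loads('[524288, 65536]')   # e.g. 512 KiB of flash, 64 KiB of RAM
    read_only_mem, read_write_mem = budgets
    return read_only_mem, read_write_mem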
94 | parser.add_argument('--mobilenet_input', default=224, type=int, 95 | help='Mobilenet input resolution ') 96 | 97 | #mixed-precision params 98 | parser.add_argument('--mem_constraint', default='', type=str, 99 | help='Memory constraints for automatic bitwidth quantization') 100 | parser.add_argument('--mixed_prec_quant', default='MixPL', type=str, 101 | help='Type of quantization for mixed-precision low bitwidth: MixPL | MixPC') 102 | 103 | 104 | def main(): 105 | global args, best_prec1 106 | best_prec1 = 0 107 | args = parser.parse_args() 108 | 109 | weight_bits = int(args.weight_bits) 110 | activ_bits = int(args.activ_bits) 111 | 112 | 113 | if args.save is '': 114 | args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 115 | save_path = os.path.join(args.results_dir, args.save) 116 | if not os.path.exists(save_path): 117 | os.makedirs(save_path) 118 | 119 | setup_logging(os.path.join(save_path, 'log.txt')) 120 | results_file = os.path.join(save_path, 'results.%s') 121 | results = ResultsLog(results_file % 'csv', results_file % 'html') 122 | 123 | logging.info("saving to %s", save_path) 124 | logging.debug("run arguments: %s", args) 125 | 126 | if 'cuda' in args.type: 127 | args.gpus = [int(i) for i in args.gpus.split(',')] 128 | print('Selected GPUs: ', args.gpus) 129 | torch.cuda.set_device(args.gpus[0]) 130 | cudnn.benchmark = True 131 | else: 132 | args.gpus = None 133 | 134 | # create model 135 | logging.info("creating model %s", args.model) 136 | model = models.__dict__[args.model] 137 | nClasses = get_num_classes(args.dataset) 138 | model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': nClasses, \ 139 | 'type_quant': args.type_quant, 'weight_bits': weight_bits, 'activ_bits': activ_bits,\ 140 | 'activ_type': args.activ_type, 'width_mult': float(args.mobilenet_width), 'input_dim': float(args.mobilenet_input) } 141 | 142 | if args.model_config is not '': 143 | model_config = dict(model_config, **literal_eval(args.model_config)) 144 | 145 | model = model(**model_config) 146 | logging.info("created model with configuration: %s", model_config) 147 | print(model) 148 | 149 | num_parameters = sum([l.nelement() for l in model.parameters()]) 150 | logging.info("number of parameters: %d", num_parameters) 151 | 152 | # Data loading code 153 | default_transform = { 154 | 'train': get_transform(args.dataset, 155 | input_size=args.input_size, augment=True), 156 | 'eval': get_transform(args.dataset, 157 | input_size=args.input_size, augment=False) 158 | } 159 | transform = getattr(model, 'input_transform', default_transform) 160 | regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer, 161 | 'lr': args.lr, 162 | 'momentum': args.momentum, 163 | 'weight_decay': args.weight_decay}}) 164 | print(transform) 165 | # define loss function (criterion) and optimizer 166 | criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)() 167 | criterion.type(args.type) 168 | 169 | 170 | val_data = get_dataset(args.dataset, 'val', transform['eval']) 171 | val_loader = torch.utils.data.DataLoader( 172 | val_data, 173 | batch_size=args.batch_size, shuffle=False, 174 | num_workers=args.workers, pin_memory=True) 175 | 176 | if args.quantizer: 177 | val_quant_loader = torch.utils.data.DataLoader( 178 | val_data, 179 | batch_size=32, shuffle=False, 180 | num_workers=args.workers, pin_memory=True) 181 | 182 | 183 | train_data = get_dataset(args.dataset, 'train', transform['train']) 184 | train_loader = torch.utils.data.DataLoader( 185 | train_data, 186 | 
batch_size=args.batch_size, shuffle=True, 187 | num_workers=args.workers, pin_memory=True) 188 | 189 | 190 | #define optimizer 191 | params_dict = dict(model.named_parameters()) 192 | params = [] 193 | for key, value in params_dict.items(): 194 | if 'clip_val' in key: 195 | params += [{'params':value,'weight_decay': 1e-4}] 196 | else: 197 | params += [{'params':value}] 198 | optimizer = torch.optim.SGD(params, lr=0.1) 199 | logging.info('training regime: %s', regime) 200 | 201 | #define quantizer 202 | if args.quantizer: 203 | if args.mem_constraint is not '': 204 | mem_contraints = json.loads(args.mem_constraint) 205 | print('This is the memory constraint:', mem_contraints ) 206 | if mem_contraints is not None: 207 | x_test = torch.Tensor(1,3,args.mobilenet_input,args.mobilenet_input) 208 | add_config = memory_driven_quant(model, x_test, mem_contraints[0], mem_contraints[1], args.mixed_prec_quant) 209 | if add_config == -1: 210 | print('The quantization process failed!') 211 | else: 212 | add_config = [] 213 | else: 214 | mem_constraint = None 215 | if args.quant_add_config is not '': 216 | add_config = json.loads(args.quant_add_config) 217 | 218 | else: 219 | add_config = [] 220 | 221 | quantizer = quantization.QuantOp(model, args.type_quant, weight_bits, \ 222 | batch_fold_type=args.batch_fold_type, batch_fold_delay=args.batch_fold_delay, act_bits=activ_bits, \ 223 | add_config = add_config ) 224 | quantizer.deployment_model.type(args.type) 225 | quantizer.add_params_to_optimizer(optimizer) 226 | 227 | else: 228 | quantizer = None 229 | 230 | #exit(0) 231 | 232 | 233 | #multi gpus 234 | if args.gpus and len(args.gpus) > 1: 235 | model = torch.nn.DataParallel(model).cuda() 236 | else: 237 | model.type(args.type) 238 | 239 | 240 | if args.resume: 241 | checkpoint_file = args.resume 242 | if os.path.isdir(checkpoint_file): 243 | checkpoint_file = os.path.join( 244 | checkpoint_file, 'model_best.pth.tar') 245 | if os.path.isfile(checkpoint_file): 246 | logging.info("loading checkpoint '%s'", args.resume) 247 | checkpoint_loaded = torch.load(checkpoint_file) 248 | checkpoint = checkpoint_loaded['state_dict'] 249 | model.load_state_dict(checkpoint, strict=False) 250 | print('Model pretrained') 251 | else: 252 | logging.error("no checkpoint found at '%s'", args.resume) 253 | 254 | if args.quantizer: 255 | quantizer.init_parameters() 256 | 257 | if args.evaluate: 258 | # evaluate on validation set 259 | 260 | if args.quantizer: 261 | # evaluate deployment model on validation set 262 | quantizer.generate_deployment_model() 263 | val_quant_loss, val_quant_prec1, val_quant_prec5 = validate( 264 | val_quant_loader, quantizer.deployment_model, criterion, 0, 'deployment' ) 265 | else: 266 | val_quant_loss, val_quant_prec1, val_quant_prec5 = 0, 0, 0 267 | 268 | val_loss, val_prec1, val_prec5 = validate( 269 | val_loader, model, criterion, 0, quantizer) 270 | 271 | logging.info('\n This is the results from evaluation only: ' 272 | 'Validation Prec@1 {val_prec1:.3f} \t' 273 | 'Validation Prec@5 {val_prec5:.3f} \t' 274 | 'Validation Quant Prec@1 {val_quant_prec1:.3f} \t' 275 | 'Validation Quant Prec@5 {val_quant_prec5:.3f} \n' 276 | .format(val_prec1=val_prec1, val_prec5=val_prec5, 277 | val_quant_prec1=val_quant_prec1, val_quant_prec5=val_quant_prec5)) 278 | exit(0) 279 | 280 | 281 | 282 | for epoch in range(args.start_epoch, args.epochs): 283 | optimizer = adjust_optimizer(optimizer, epoch, regime) 284 | 285 | # train for one epoch 286 | train_loss, train_prec1, train_prec5 = train( 287 | train_loader, 
model, criterion, epoch, optimizer, quantizer) 288 | 289 | # evaluate on validation set 290 | val_loss, val_prec1, val_prec5 = validate( 291 | val_loader, model, criterion, epoch, quantizer) 292 | 293 | if args.quantizer: 294 | # evaluate deployment model on validation set 295 | quantizer.generate_deployment_model() 296 | val_quant_loss, val_quant_prec1, val_quant_prec5 = validate( 297 | val_quant_loader, quantizer.deployment_model, criterion, epoch, 'deployment' ) 298 | else: 299 | val_quant_loss, val_quant_prec1, val_quant_prec5 = 0, 0, 0 300 | 301 | 302 | # remember best prec@1 and save checkpoint 303 | is_best = val_prec1 > best_prec1 304 | best_prec1 = max(val_prec1, best_prec1) 305 | 306 | #save_model 307 | if args.save_check: 308 | 309 | print('Saving Model!! Accuracy : ', best_prec1) 310 | save_checkpoint({ 311 | 'epoch': epoch + 1, 312 | 'model': args.model, 313 | 'config': model_config, 314 | 'state_dict': model.state_dict(), 315 | 'best_prec1': best_prec1, 316 | 'regime': regime , 317 | 'quantizer': quantizer, 318 | 'add_config': add_config, 319 | 'fold_type': args.batch_fold_type 320 | }, is_best, path=save_path) 321 | 322 | 323 | logging.info('\n Epoch: {0}\t' 324 | 'Training Loss {train_loss:.4f} \t' 325 | 'Training Prec@1 {train_prec1:.3f} \t' 326 | 'Training Prec@5 {train_prec5:.3f} \t' 327 | 'Validation Loss {val_loss:.4f} \t' 328 | 'Validation Prec@1 {val_prec1:.3f} \t' 329 | 'Validation Prec@5 {val_prec5:.3f} \t' 330 | 'Validation Quant Prec@1 {val_quant_prec1:.3f} \t' 331 | 'Validation Quant Prec@5 {val_quant_prec5:.3f} \n' 332 | .format(epoch + 1, train_loss=train_loss, val_loss=val_loss, 333 | train_prec1=train_prec1, val_prec1=val_prec1, 334 | train_prec5=train_prec5, val_prec5=val_prec5, 335 | val_quant_prec1=val_quant_prec1, val_quant_prec5=val_quant_prec5)) 336 | 337 | 338 | results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss, 339 | train_error1=100 - train_prec1, val_error1=100 - val_prec1, 340 | train_error5=100 - train_prec5, val_error5=100 - val_prec5, 341 | val_quant_error1=100 - val_quant_prec1, val_quant_error5=100 - val_quant_prec5) 342 | results.save() 343 | 344 | 345 | 346 | 347 | def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None, quantizer=None ): 348 | 349 | # if args.gpus and len(args.gpus) > 1: 350 | # model = torch.nn.DataParallel(model, args.gpus) 351 | 352 | batch_time = AverageMeter() 353 | data_time = AverageMeter() 354 | losses = AverageMeter() 355 | top1 = AverageMeter() 356 | top5 = AverageMeter() 357 | 358 | end = time.time() 359 | 360 | # apply transforms at the beginning of each epoch 361 | 362 | print('Training: ',training ) 363 | 364 | if quantizer is not None and quantizer != 'deployment': 365 | quantizer.freeze_BN_and_fold(epoch) 366 | 367 | # input quantization 368 | n_bits_inpt = 8 #retrieve from quantizer in future version 369 | max_inpt, min_inpt = 1, -1 #retrieve from quantizer in future version 370 | n = 2 ** n_bits_inpt - 1 371 | scale_factor = n / (max_inpt - min_inpt) 372 | 373 | 374 | for i, (inputs, target) in enumerate(data_loader): 375 | # measure data loading time 376 | data_time.update(time.time() - end) 377 | if args.gpus is not None: 378 | target = target.cuda(non_blocking=True) 379 | 380 | with torch.no_grad(): 381 | input_var = Variable(inputs.type(args.type)) 382 | target_var = Variable(target) 383 | 384 | # quantization before computing output 385 | if quantizer == 'deployment': 386 | input_var = input_var.clamp(min_inpt, max_inpt).mul(scale_factor).round() 387 | elif
quantizer is not None: 388 | input_var = input_var.clamp(min_inpt, max_inpt).mul(scale_factor).round().div(scale_factor) 389 | quantizer.store_and_quantize(training=training ) 390 | 391 | # compute output 392 | output = model(input_var) 393 | 394 | loss = criterion(output, target_var) 395 | if type(output) is list: 396 | output = output[0] 397 | 398 | # measure accuracy and record loss 399 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 400 | losses.update(loss.data.item(), inputs.size(0)) 401 | top1.update(prec1.item(), inputs.size(0)) 402 | top5.update(prec5.item(), inputs.size(0)) 403 | 404 | if training: 405 | # compute gradient and do SGD step 406 | optimizer.zero_grad() 407 | loss.backward() 408 | 409 | # restore real value parameters before update 410 | if quantizer is not None: 411 | quantizer.backprop_quant_gradients() 412 | quantizer.restore_real_value() 413 | 414 | optimizer.step() 415 | 416 | elif quantizer is not None and quantizer is not 'deployment': 417 | quantizer.restore_real_value() 418 | 419 | 420 | # measure elapsed time 421 | batch_time.update(time.time() - end) 422 | end = time.time() 423 | 424 | if i % args.print_freq == 0: 425 | logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t' 426 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 427 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 428 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 429 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 430 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 431 | epoch, i, len(data_loader), 432 | phase='TRAINING' if training else 'EVALUATING', 433 | batch_time=batch_time, 434 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 435 | 436 | return losses.avg, top1.avg, top5.avg 437 | 438 | 439 | def train(data_loader, model, criterion, epoch, optimizer, quantizer): 440 | 441 | # switch to train mode 442 | model.train() 443 | return forward(data_loader, model, criterion, epoch, 444 | training=True, optimizer=optimizer, quantizer=quantizer ) 445 | 446 | 447 | def validate(data_loader, model, criterion, epoch, quantizer ): 448 | 449 | # switch to evaluate mode 450 | model.eval() 451 | return forward(data_loader, model, criterion, epoch, 452 | training=False, optimizer=None, quantizer=quantizer) 453 | 454 | 455 | if __name__ == '__main__': 456 | main() 457 | -------------------------------------------------------------------------------- /quantization/quantop.py: -------------------------------------------------------------------------------- 1 | # 2 | # quantop.py 3 | # Manuele Rusci 4 | # 5 | # Copyright (C) 2019 University of Bologna 6 | # All rights reserved. 
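# A minimal sketch (illustrative helper, not used by this module) of the 8-bit input
# quantization applied in forward() of main_binary.py above: inputs are clamped to
# [min_inpt, max_inpt], scaled to 2**bits - 1 integer levels and rounded; the deployment
# graph keeps the integer levels, while the fake-quantized training graph divides by the
# scale again so the tensor stays in the original floating-point range.
def _input_fake_quant_sketch():
    import torch
    n_bits, min_inpt, max_inpt = 8, -1.0, 1.0
    scale = (2 ** n_bits - 1) / (max_inpt - min_inpt)
    x = torch.randn(4, 3, 224, 224)
    x_int = x.clamp(min_inpt, max_inpt).mul(scale).round()   # fed to the deployment model
    x_fake = x_int.div(scale)                                # fed to the fake-quantized model
    return x_int, x_fake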
7 | # 8 | 9 | import torch, math 10 | import torch.nn as nn 11 | import numpy as np 12 | from scipy.linalg import lstsq 13 | from torch.autograd import Variable 14 | import copy 15 | 16 | from .quant_auto import memory_driven_quant 17 | 18 | import sys 19 | sys.path.append('../') 20 | 21 | 22 | from models.linear_quantized_modules import ClippedLinearQuantization, LearnedClippedLinearQuantization, ScaledClippedLinearQuantization, Conv2d_SAME, ScaledClippedLinearQuantizationChannel 23 | 24 | 25 | 26 | ############################################################################ 27 | #### Utils 28 | ############################################################################ 29 | 30 | def has_children(module): 31 | try: 32 | next(module.children()) 33 | return True 34 | except StopIteration: 35 | return False 36 | 37 | 38 | def get_activ_stats(layer_int): 39 | a_bits = layer_int['act_o_bits'] 40 | n_levels_a = (2**a_bits)-1 41 | if type(layer_int['act']) == ClippedLinearQuantization: 42 | a_min = 0 43 | a_max = layer_int['act'].clip_val 44 | S_a = a_max / n_levels_a 45 | Z_a = 0 46 | elif type(layer_int['act']) == LearnedClippedLinearQuantization: 47 | a_min = 0 48 | a_max = layer_int['act'].clip_val.data.item() 49 | S_a = a_max / n_levels_a 50 | Z_a = 0 51 | return S_a, Z_a, n_levels_a, a_min, a_max 52 | 53 | 54 | def get_first_layer_stats(): 55 | # corresponds to the 8 bit quatized inputs of the imagenet dataset 56 | a_min, a_max = -1, 1 57 | n_levels_a = (2**8)-1 58 | S_a = (a_max - a_min) / n_levels_a 59 | Z_a = 0.0 60 | return S_a, Z_a, n_levels_a, a_min, a_max 61 | 62 | 63 | ############################################################################ 64 | #### Quantization and Binarization Library 65 | ############################################################################ 66 | 67 | 68 | ''' 69 | The QuantOp operator handles the quantization process of model's parameters. 
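Concretely, it keeps a full-precision copy of each layer's parameters, swaps quantized
values in before every forward pass (store_and_quantize), restores the real-valued
parameters before the optimizer update (restore_real_value), and mirrors the network
into a separate integer 'deployment' graph used for validation.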
70 | This is inspired from the code in https://github.com/jiecaoyu/XNOR-Net-PyTorch 71 | 72 | Implemented quantization approaches: 73 | - 'PerLayerAsymMinMax', apply per-layer asymmetric weight quantization w/ min-max range 74 | - 'PerLayerAsymPACT', apply per-layer asymmetric weight quantization w/ PACT strategy 75 | - 'PerChannelsAsymMinMax', apply per-channel asymmetric weight quantization w/ min-max range 76 | ''' 77 | 78 | _Availble_Quant_Type = ['PerLayerAsymMinMax', 'PerLayerAsymPACT', 'PerChannelsAsymMinMax', None] 79 | _Supported_Activ_Funct = [nn.ReLU6, nn.ReLU, ClippedLinearQuantization, LearnedClippedLinearQuantization] 80 | class QuantOp(): 81 | 82 | 83 | 84 | def __init__(self, 85 | model, 86 | quant_type = None, 87 | weight_bits= 8, 88 | bias_bits = 32, 89 | batch_fold_type = 'folding_weights', 90 | batch_fold_delay = 0, 91 | act_bits = 8, 92 | add_config = [] ): 93 | ''' 94 | Description: Wraps the full-precision model to apply weight quantization and build the deployments graph 95 | Arguments: 96 | - model: model to quantize 97 | - quant_type: 'None' default type of per-layer quantization for the weights 98 | - weight_bits: default number of bits for the weights 99 | - batch_fold_delay: number of epochs before folding batch normalization layers 100 | - act_bits: default number of bits for the activations 101 | ''' 102 | 103 | # generate quantized model 104 | self.deployment_model = copy.deepcopy(model) 105 | self.param_to_quantize = [] 106 | self.batch_fold = False 107 | self.batch_fold_delay = batch_fold_delay 108 | self.batch_fold_type = batch_fold_type 109 | 110 | print('Batch Folding: ', self.batch_fold, 'Batch Folding Delay: ', batch_fold_delay, 'Type', batch_fold_type ) 111 | 112 | #start 113 | last_layer = None 114 | modules_quant = [ ] 115 | 116 | print(add_config) 117 | 118 | 119 | 120 | for name_sub,submodel in model.named_children(): 121 | 122 | # map the sub-graph 123 | for name, module in submodel.named_modules(): 124 | if has_children(module) is False: 125 | 126 | layer_quant_descr = {} 127 | 128 | if type(module) in [nn.Conv2d, nn.Linear, Conv2d_SAME]: 129 | 130 | # saving per-node characteristics 131 | layer_quant_descr['name'] = name 132 | layer_quant_descr['w_bits'] = weight_bits 133 | layer_quant_descr['fold_type'] = self.batch_fold_type 134 | 135 | temp = getattr(submodel, 'quant_type', quant_type) 136 | if temp in _Availble_Quant_Type : 137 | layer_quant_descr['quant_type'] = temp 138 | else: 139 | print('Type of quantization not recognized') 140 | return -1 141 | layer_quant_descr['conv'] = module 142 | layer_quant_descr['weight'] = module.weight.data.clone() 143 | layer_quant_descr['w_clip'] = None 144 | 145 | # import per-layer config from external config file 146 | idx_layer = len(self.param_to_quantize) 147 | for item_dict in add_config: 148 | if item_dict['layer'] == idx_layer: 149 | if 'w_bits' in item_dict.keys(): 150 | layer_quant_descr['w_bits'] = item_dict['w_bits'] 151 | 152 | if 'fold_type' in item_dict.keys(): 153 | layer_quant_descr['fold_type'] = item_dict['fold_type'] 154 | 155 | if 'quant_type' in item_dict.keys(): 156 | layer_quant_descr['quant_type'] = item_dict['quant_type'] 157 | 158 | if module.bias is None: 159 | layer_quant_descr['bias'] = False 160 | else: 161 | layer_quant_descr['bias'] = module.bias.data.clone() 162 | layer_quant_descr['bias_bits'] = bias_bits 163 | layer_quant_descr['batch_norm'] = None 164 | layer_quant_descr['act'] = None 165 | layer_quant_descr['quant_act'] = None 166 | 167 | # append into deployment 
graph 168 | quant_layer = copy.deepcopy(module) 169 | modules_quant.append(quant_layer) 170 | layer_quant_descr['quant_conv'] = quant_layer 171 | 172 | #PACT needs extra parameters for learning quantization range 173 | if layer_quant_descr['quant_type'] == 'PerLayerAsymPACT': 174 | layer_quant_descr['w_max_thr'] = nn.Parameter( module.weight.data.max().cuda(),requires_grad=True ) 175 | layer_quant_descr['w_min_thr'] = nn.Parameter( module.weight.data.min().cuda(),requires_grad=True ) 176 | layer_quant_descr['w_max_thr'].add(1).sum().backward() #fake to create grad 177 | layer_quant_descr['w_min_thr'].add(1).sum().backward() #fake to create grad 178 | 179 | elif layer_quant_descr['quant_type'] == 'PerChannelsAsymMinMax': 180 | weight = module.weight.data 181 | if type(module) == nn.Linear: 182 | print('PerLayerAsymPACT chennel with ICN on last layer') 183 | out_ch = weight.size(0) 184 | quant_act = ScaledClippedLinearQuantizationChannel(out_ch, clip_val=False) 185 | layer_quant_descr['quant_act'] = quant_act 186 | modules_quant.append(quant_act) 187 | 188 | # append the latest node 189 | self.param_to_quantize.append(layer_quant_descr) 190 | last_layer = self.param_to_quantize[-1] 191 | 192 | elif (type(module) in [nn.BatchNorm2d]): 193 | if last_layer is not None: 194 | last_layer['batch_norm'] = module 195 | 196 | elif type(module) in [nn.AvgPool2d]: 197 | modules_quant.append(module) # temporary - this should me merged into or previous layer 198 | 199 | elif type(module) in _Supported_Activ_Funct: 200 | 201 | # check if a quantized features an activation function 202 | if last_layer is not None: 203 | 204 | last_layer['act'] = module 205 | 206 | # quantized activations 207 | if type(module) in [ClippedLinearQuantization, LearnedClippedLinearQuantization]: 208 | 209 | act_bits = module.num_bits 210 | #check if number of bits need to be changed 211 | for item_dict in add_config: 212 | if item_dict['layer'] == idx_layer: 213 | if 'a_bits' in item_dict.keys(): 214 | act_bits = item_dict['a_bits'] 215 | module.num_bits = act_bits 216 | 217 | if last_layer['fold_type'] == 'ICN': 218 | out_ch = last_layer['conv'].out_channels 219 | quant_act = ScaledClippedLinearQuantizationChannel(out_ch,clip_val=2**act_bits -1) 220 | 221 | else: 222 | quant_act = ScaledClippedLinearQuantization(clip_val=2**act_bits -1) 223 | 224 | last_layer['act_o_bits'] = act_bits 225 | 226 | # full-precision activations 227 | elif type(module) is nn.ReLU6: 228 | quant_act = nn.ReLU6(inplace=True) 229 | 230 | 231 | else: 232 | print('Supported activation layer but no method is here yet!') 233 | return 0 234 | 235 | 236 | last_layer['quant_act'] = quant_act 237 | 238 | modules_quant.append(quant_act) 239 | 240 | else: 241 | print(type(module), 'not supported yet!') 242 | 243 | else: 244 | #if type(module) is not nn.Sequential: 245 | last_layer = None 246 | 247 | self.deployment_model._modules[name_sub] = nn.Sequential(*modules_quant) 248 | modules_quant = [ ] 249 | 250 | 251 | self.n_quantize_layers = len(self.param_to_quantize) 252 | 253 | print('This is the quantized network: ') 254 | print(self.deployment_model) 255 | 256 | def _get_m0_nexp_vect (self, data_vec): 257 | dim = data_vec.size(0) 258 | M0 = torch.Tensor(dim).type_as(data_vec) 259 | N0 = torch.Tensor(dim).type_as(data_vec) 260 | for i in range(dim): 261 | M0[i],N0[i] = self._get_m0_nexp(data_vec[i]) 262 | 263 | return M0, N0 264 | 265 | def _get_m0_nexp(self, data_s): 266 | ''' Decompose data_s as data*2**n_exp ''' 267 | n_exp = 0 268 | data = abs(data_s) 269 | 
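        # The loop below normalizes |data_s| into [0.5, 1) by repeated halving/doubling
        # while tracking the power of two, so that data_s == data * 2**n_exp on return.
        # Example: 0.075 -> (0.6, -3), since 0.6 * 2**-3 = 0.075.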
if data == 0.0: 270 | print('Data is 0') 271 | return data, n_exp 272 | 273 | while data >= 1 or data < 0.5: 274 | if data >= 1: 275 | data = data / 2 276 | n_exp += 1 277 | else: 278 | data = data * 2 279 | n_exp -= 1 280 | 281 | if data_s < 0: 282 | data = -data 283 | 284 | return data, n_exp 285 | 286 | 287 | def _linear_asymmetric_quant_args(self, data_tensor, n_bits ): 288 | ''' 289 | Quantize data_tensor in min/max quantization range 290 | and return quantized tensor 291 | ''' 292 | max_value = data_tensor.max() 293 | min_value = data_tensor.min() 294 | range_values = max_value - min_value 295 | 296 | # to prevent division by 0 impose a quantization range of 1 297 | if range_values == 0: 298 | range_values = 1 299 | 300 | n_steps = (2**n_bits)-1 301 | eps = range_values / n_steps 302 | data_quant = data_tensor.div(eps).round().mul(eps) 303 | 304 | return data_quant, min_value, max_value 305 | 306 | def _clipped_linear_asymmetric_quant_args(self, data_tensor, n_bits, min_value, max_value): 307 | ''' Applying clipping before linear quantization ''' 308 | data_tensor.clamp_(min_value, max_value) 309 | return self._linear_asymmetric_quant_args(data_tensor, n_bits) 310 | 311 | 312 | def _get_BN_scale_factors(self, batch_layer): 313 | 314 | eps = batch_layer.eps 315 | gamma_tensor = batch_layer.weight.data 316 | beta_tensor = batch_layer.bias.data 317 | mu_tensor = batch_layer.running_mean 318 | var_tensor = batch_layer.running_var 319 | 320 | # check for instabilities 321 | if var_tensor.le(0).sum() >0: 322 | print('erorr') 323 | return 324 | 325 | sigma_tensor = var_tensor.add(eps).sqrt() 326 | gamma_over_sigma = gamma_tensor / sigma_tensor 327 | 328 | return gamma_over_sigma, mu_tensor, beta_tensor 329 | 330 | 331 | 332 | def _batch_fold(self,conv_layer, batch_layer ): 333 | ''' 334 | Folding of batch normalization parameters into conv_layer params 335 | Return folded conv_layer tensors 336 | ''' 337 | weight_tensor = conv_layer.weight.data.clone() 338 | weight_tensor_size = weight_tensor.size() 339 | eps = batch_layer.eps 340 | gamma_tensor = batch_layer.weight.data 341 | beta_tensor = batch_layer.bias.data 342 | mu_tensor = batch_layer.running_mean 343 | var_tensor = batch_layer.running_var 344 | 345 | #assuming convolution bias == False 346 | bias_tensor = mu_tensor.clone() 347 | bias_tensor = bias_tensor.mul(-1) 348 | 349 | # check for instabilities 350 | if var_tensor.le(0).sum() >0: 351 | print('erorr') 352 | return 353 | 354 | sigma_tensor = var_tensor.add(eps).sqrt() 355 | gamma_over_sigma = gamma_tensor / sigma_tensor 356 | gamma_over_sigma_tensor = gamma_over_sigma.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(weight_tensor_size ) 357 | weight_tensor = weight_tensor * gamma_over_sigma_tensor 358 | bias_tensor = (bias_tensor*gamma_over_sigma)+beta_tensor 359 | 360 | return weight_tensor, bias_tensor 361 | 362 | def _batch_defold(self, weight_tensor, bias_tensor, batch_layer): 363 | ''' 364 | De-Foldingconv_layer params from batch-norm parameters 365 | This is needed for fake-quantized graph 366 | ''' 367 | weight_tensor_size = weight_tensor.size() 368 | eps = batch_layer.eps 369 | gamma_tensor = batch_layer.weight.data 370 | beta_tensor = batch_layer.bias.data 371 | mu_tensor = batch_layer.running_mean 372 | var_tensor = batch_layer.running_var 373 | 374 | # check for instabilities 375 | if var_tensor.le(0).sum() >0: 376 | print('erorr') 377 | return 378 | 379 | sigma_tensor = var_tensor.add(eps).sqrt() 380 | gamma_over_sigma = gamma_tensor / sigma_tensor 381 | 
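        # _batch_fold() above computes, per output channel, w_fold = w * gamma/sigma and
        # (with no conv bias) b_fold = beta - mu * gamma/sigma, where sigma = sqrt(var + eps).
        # De-folding inverts that mapping so the fake-quantized graph can keep its separate
        # BatchNorm layer; torch.where() guards the channels where gamma/sigma == 0 to avoid
        # dividing by zero.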
gamma_over_sigma_tensor = gamma_over_sigma.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(weight_tensor_size ) 382 | 383 | weight_tensor = torch.where(gamma_over_sigma_tensor.eq(0.0), gamma_over_sigma_tensor, weight_tensor / gamma_over_sigma_tensor) 384 | if bias_tensor is not False: 385 | bias_tensor = torch.where(gamma_over_sigma.eq(0.0), bias_tensor, ((bias_tensor-beta_tensor)/ gamma_over_sigma)+mu_tensor) 386 | 387 | return weight_tensor, bias_tensor 388 | 389 | 390 | def add_params_to_optimizer(self, optimizer ): 391 | ''' 392 | Adding PACT parameters to the optimizer 393 | ''' 394 | for layer in self.param_to_quantize: 395 | quant_type = layer['quant_type'] 396 | w_bits = layer['w_bits'] 397 | conv_layer = layer['conv'] 398 | if quant_type in ['PerLayerAsymPACT']: 399 | optimizer.add_param_group({"params": layer['w_min_thr'], 'weight_decay': 5e-4}) 400 | optimizer.add_param_group({"params": layer['w_max_thr'], 'weight_decay': 5e-4}) 401 | 402 | 403 | def init_parameters(self ): 404 | ''' 405 | Init PACT parameters 406 | ''' 407 | for layer in self.param_to_quantize: 408 | quant_type = layer['quant_type'] 409 | w_bits = layer['w_bits'] 410 | conv_layer = layer['conv'] 411 | 412 | if quant_type == 'PerLayerAsymPACT': 413 | 414 | if layer['batch_norm'] is not None and layer['fold_type'] == 'folding_weights': 415 | batch_layer = layer['batch_norm'] 416 | weight_tensor, bias_tensor = self._batch_fold(conv_layer, batch_layer) 417 | layer['w_max_thr'].data = weight_tensor.max().cuda() 418 | layer['w_min_thr'].data = weight_tensor.min().cuda() 419 | 420 | ## PPQ 421 | elif layer['fold_type'] == 'ICN': 422 | weights = conv_layer.weight.data 423 | w_levels = 2**w_bits -1 424 | mult = (w_levels/2) 425 | mean_w = weights.mean() 426 | gamma= weights.add(-mean_w).abs().max().cuda().mul(2/w_levels) 427 | for i in range(10): 428 | gamma_zero = gamma 429 | 430 | mult = (w_levels/2) 431 | max_v = max(mean_w + gamma_zero*mult, 0) 432 | min_v = min(mean_w - gamma_zero*mult ,0) 433 | quant_weights = weights.clamp(min_v,max_v).div(gamma_zero).round() 434 | sum_quant = quant_weights.mul(quant_weights).sum() 435 | sum_proj = quant_weights.mul(weights).sum() 436 | gamma = sum_proj / sum_quant 437 | 438 | layer['w_max_thr'].data.fill_(max(mean_w + gamma_zero*mult, 0)) 439 | layer['w_min_thr'].data.fill_(min(mean_w - gamma_zero*mult ,0)) 440 | 441 | else: 442 | layer['w_max_thr'].data = conv_layer.weight.data.max().cuda() 443 | layer['w_min_thr'].data = conv_layer.weight.data.min().cuda() 444 | 445 | def freeze_BN_and_fold(self, epoch): 446 | ''' 447 | Freeze BN layer params after batch_fold_delay epochs 448 | When batch_fold is True, BN params are frozen and folding is applied if requested 449 | ''' 450 | 451 | old_batch_fold = self.batch_fold 452 | 453 | #enable batch-folding 454 | if epoch >= self.batch_fold_delay: 455 | self.batch_fold = True 456 | print("Frozen Batch Normalization Layers statistics!") 457 | 458 | change_fold_mode = self.batch_fold and (not old_batch_fold) 459 | print('Changing Folding Mode: ', change_fold_mode) 460 | 461 | 462 | for i_l,layer in enumerate(self.param_to_quantize): 463 | quant_type = layer['quant_type'] 464 | w_bits = layer['w_bits'] 465 | conv_layer = layer['conv'] 466 | 467 | # freeze BN parametes 468 | if quant_type in ['PerLayerAsymMinMax','PerLayerAsymPACT','PerChannelsAsymMinMax'] and layer['batch_norm'] is not None and self.batch_fold: 469 | 470 | batch_layer = layer['batch_norm'] 471 | batch_layer.eval() 472 | batch_layer.weight.requires_grad = False 473 | 
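                # With folding enabled, the BN running statistics and affine parameters
                # are kept fixed so that the folded weights seen by the quantizer (and by
                # the deployment graph) stop drifting from one batch to the next.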
batch_layer.bias.requires_grad = False 474 | 475 | 476 | # manage min/max params when activate batch folding into weights 477 | if quant_type == 'PerLayerAsymPACT' and change_fold_mode and layer['fold_type'] == 'folding_weights' and layer['batch_norm'] is not None: 478 | if self.batch_fold and layer['batch_norm'] is not None: 479 | weight_tensor, bias_tensor = self._batch_fold(conv_layer, batch_layer) 480 | layer['w_max_thr'].data = weight_tensor.max().cuda() 481 | layer['w_min_thr'].data = weight_tensor.min().cuda() 482 | 483 | 484 | def _quant_PerLayerAsymMinMax(self, conv_layer, batch_norm_layer, batch_fold_type, w_bits, idx_layer, \ 485 | training=False, get_quantized=True): 486 | 487 | # 1- folding of batch normalization layers before quantization if enabled and for validation 488 | if (self.batch_fold ) and ( batch_norm_layer is not None) and (batch_fold_type == 'folding_weights'): 489 | #folding batch_norm into convolution and quantize 490 | batch_layer = batch_norm_layer 491 | weight_tensor, bias_tensor = self._batch_fold(conv_layer, batch_layer) 492 | 493 | else: # weights and bias as-it-is 494 | weight_tensor = conv_layer.weight.data.clone() 495 | if conv_layer.bias is None: 496 | bias_tensor = False 497 | else: 498 | bias_tensor = conv_layer.bias.data.clone() 499 | 500 | # return tensor data before quantization 501 | if not get_quantized: 502 | return weight_tensor, bias_tensor 503 | 504 | # 2- quantization 505 | weight_quant_tensor, w_min_value, w_max_value = self._linear_asymmetric_quant_args(weight_tensor, w_bits) 506 | 507 | if bias_tensor is not False: 508 | S_w = ( weight_tensor.max()-weight_tensor.min() ) / ((2**w_bits)-1) 509 | if idx_layer == 0 : 510 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_first_layer_stats() 511 | else : 512 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_activ_stats(self.param_to_quantize[idx_layer-1]) 513 | 514 | eps = S_a_i * S_w 515 | bias_quant_tensor = bias_tensor.div(eps).round().mul(eps) 516 | else: 517 | bias_quant_tensor = False 518 | 519 | 520 | # 3- batch norm defolding 521 | if (self.batch_fold ) and (batch_norm_layer is not None) and (batch_fold_type == 'folding_weights'): 522 | weight_quant_tensor, bias_quant_tensor = self._batch_defold(weight_quant_tensor, bias_quant_tensor, batch_layer) 523 | 524 | return weight_quant_tensor, bias_quant_tensor 525 | 526 | def _quant_PerLayerAsymPACT(self, conv_layer, batch_norm_layer, batch_fold_type, w_bits, idx_layer, \ 527 | w_min_ext, w_max_ext, training=False, get_quantized=True): 528 | 529 | # 1- folding of batch normalization layers before quantization if enabled and for validation 530 | if self.batch_fold and ( batch_norm_layer is not None) and (batch_fold_type == 'folding_weights'): 531 | #folding batch_norm into convolution and quantize 532 | batch_layer = batch_norm_layer 533 | weight_tensor, bias_tensor = self._batch_fold(conv_layer, batch_layer) 534 | else: # weights and bias as-it-is 535 | weight_tensor = conv_layer.weight.data.clone() 536 | if conv_layer.bias is None: 537 | bias_tensor = False 538 | else: 539 | bias_tensor = conv_layer.bias.data.clone() 540 | 541 | # return tensor data before quantization 542 | if self.batch_fold and (batch_fold_type == 'folding_weights') and ( batch_norm_layer is not None): 543 | weight_tensor = weight_tensor.clamp_(w_min_ext, w_max_ext) 544 | 545 | if not get_quantized: 546 | return weight_tensor, bias_tensor 547 | 548 | # 2- quantization 549 | weight_quant_tensor, w_min_value, w_max_value = self._linear_asymmetric_quant_args(weight_tensor, w_bits) 550 | 
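        # The per-layer weight scale S_w below comes from the clipped weight range,
        # S_w = (w_max - w_min) / (2**w_bits - 1). The bias is then quantized with
        # step eps = S_a_i * S_w, the product of the input-activation scale (first-layer
        # stats for idx_layer == 0, otherwise the previous layer's activation stats)
        # and the weight scale, i.e. the scale at which an integer kernel accumulates.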
551 | range_w = w_max_value-w_min_value 552 | if range_w == 0: 553 | range_w = 1 554 | S_w = range_w / ((2**w_bits)-1) 555 | 556 | if bias_tensor is not False: 557 | 558 | if idx_layer == 0 : 559 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_first_layer_stats() 560 | else : 561 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_activ_stats(self.param_to_quantize[idx_layer-1]) 562 | eps = S_a_i * S_w 563 | bias_quant_tensor = bias_tensor.div(eps).round().mul(eps) 564 | else: 565 | bias_quant_tensor = False 566 | 567 | 568 | # 3- batch norm defolding 569 | if self.batch_fold and (batch_norm_layer is not None) and (batch_fold_type == 'folding_weights'): 570 | weight_quant_tensor, bias_quant_tensor = self._batch_defold(weight_quant_tensor, bias_quant_tensor, batch_layer) 571 | 572 | return weight_quant_tensor, bias_quant_tensor 573 | 574 | 575 | def _quant_PerChannelsAsymMinMax(self, conv_layer, batch_norm_layer, batch_fold_type, w_bits, idx_layer, \ 576 | w_min_ext, w_max_ext, training=False, get_quantized=True): 577 | 578 | # 0 -clamping value 579 | w_max,_ = conv_layer.weight.data.reshape(conv_layer.weight.data.size(0), -1).max(1) 580 | w_min,_ = conv_layer.weight.data.reshape(conv_layer.weight.data.size(0), -1).min(1) 581 | if len(conv_layer.weight.data.size()) == 4: 582 | w_min_value_mat = w_min.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(conv_layer.weight.data.size()) 583 | w_max_value_mat = w_max.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(conv_layer.weight.data.size()) 584 | elif len(conv_layer.weight.data.size()) == 2: 585 | w_min_value_mat = w_min.unsqueeze(1).expand(conv_layer.weight.data.size()) 586 | w_max_value_mat = w_max.unsqueeze(1).expand(conv_layer.weight.data.size()) 587 | else: 588 | return 1 589 | 590 | if batch_fold_type != 'folding_weights' : 591 | conv_layer.weight.data = torch.max(conv_layer.weight.data, w_min_value_mat) 592 | conv_layer.weight.data = torch.min(conv_layer.weight.data, w_max_value_mat) 593 | 594 | # 1- folding of batch normalization layers before quantization if enabled and for validation 595 | if (self.batch_fold ) and ( batch_norm_layer is not None) and (batch_fold_type == 'folding_weights'): 596 | #folding batch_norm into convolution and quantize 597 | batch_layer = batch_norm_layer 598 | weight_tensor, bias_tensor = self._batch_fold(conv_layer, batch_layer) 599 | 600 | w_max,_ = weight_tensor.reshape(weight_tensor.size(0), -1).max(1) 601 | w_min,_ = weight_tensor.reshape(weight_tensor.size(0), -1).min(1) 602 | if len(weight_tensor.size()) == 4: 603 | w_min_value_mat = w_min.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(weight_tensor.size()) 604 | w_max_value_mat = w_max.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(weight_tensor.size()) 605 | elif len(weight_tensor.size()) == 2: 606 | w_min_value_mat = w_min.unsqueeze(1).expand(weight_tensor.size()) 607 | w_max_value_mat = w_max.unsqueeze(1).expand(weight_tensor.size()) 608 | else: 609 | return 1 610 | else: # weights and bias as-it-is 611 | weight_tensor = conv_layer.weight.data.clone() 612 | if conv_layer.bias is None: 613 | bias_tensor = False 614 | else: 615 | bias_tensor = conv_layer.bias.data.clone() 616 | 617 | # return tensor data before quantization 618 | if self.batch_fold and (batch_fold_type == 'folding_weights') and batch_norm_layer is not None : 619 | for v in range(weight_tensor.size(0) ): 620 | weight_tensor[v] = weight_tensor[v].clamp(w_min_ext[v].item(), w_max_ext[v].item() ) 621 | 622 | if not get_quantized: 623 | return weight_tensor, bias_tensor 624 | 625 | # 2- quantization 626 | 
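        # Per-channel quantization: every output channel v gets its own step size
        # S_w[v] = (w_max[v] - w_min[v]) / (2**w_bits - 1), built as a matrix
        # (range_mat / n_steps) broadcast to the weight shape; channels with zero
        # range are forced to 1 to avoid division by zero.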
        range_mat = w_max_value_mat - w_min_value_mat
627 |         range_mat.masked_fill_(range_mat.eq(0), 1)
628 |         n_steps = (2**w_bits)-1
629 |         S_w = range_mat.div(n_steps)
630 |
631 |         weight_quant_tensor = weight_tensor.div(S_w).round().clone()
632 |         weight_quant_tensor = weight_quant_tensor.mul(S_w)
633 |
634 |
635 |         if bias_tensor is not False:
636 |             range_ext = w_max - w_min
637 |             S_w = range_ext.masked_fill_(range_ext.eq(0), 1) / n_steps
638 |             if idx_layer == 0 :
639 |                 S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_first_layer_stats()
640 |             else :
641 |                 S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_activ_stats(self.param_to_quantize[idx_layer-1])
642 |             eps = S_a_i * S_w
643 |             bias_quant_tensor = bias_tensor.div(eps).round().mul(eps)
644 |             if eps.abs().min() == 0:
645 |                 print('Eps = 0 in layer ', idx_layer, '| bias is ', bias_quant_tensor)
646 |         else:
647 |             bias_quant_tensor = False
648 |
649 |
650 |         # 3- batch norm defolding
651 |         if (self.batch_fold ) and (batch_norm_layer is not None) and (batch_fold_type == 'folding_weights'):
652 |             weight_quant_tensor, bias_quant_tensor = self._batch_defold(weight_quant_tensor, bias_quant_tensor, batch_layer)
653 |
654 |         return weight_quant_tensor, bias_quant_tensor
655 |
656 |
657 |
658 |     def generate_deployment_model(self):
659 |         '''
660 |         Generate integer-only parameters from the fake-quantized model
661 |         '''
662 |
663 |         for i_l,layer in enumerate(self.param_to_quantize):
664 |             quant_type = layer['quant_type']
665 |             w_bits = layer['w_bits']
666 |             conv_layer = layer['conv']
667 |             bias_bits = layer['bias_bits']
668 |
669 |             if quant_type == 'PerLayerAsymMinMax':
670 |                 weight_tensor, bias_tensor = self._quant_PerLayerAsymMinMax( \
671 |                     conv_layer, layer['batch_norm'], layer['fold_type'], w_bits, i_l, training=False, get_quantized=False)
672 |
673 |             elif quant_type == 'PerLayerAsymPACT':
674 |
675 |                 weight_org = conv_layer.weight.data.clone()
676 |
677 |                 layer['w_min_thr'].data.clamp_(max= 0)
678 |                 layer['w_max_thr'].data.clamp_(min= 0)
679 |                 w_min_value, w_max_value = layer['w_min_thr'].data.item() , layer['w_max_thr'].data.item()
680 |
681 |                 if layer['fold_type'] != 'folding_weights' or layer['batch_norm'] is None:
682 |                     conv_layer.weight.data.clamp_(w_min_value, w_max_value)
683 |
684 |                 if layer['fold_type'] == 'ICN':
685 |                     weight_tensor = conv_layer.weight.data.clone()
686 |                     if conv_layer.bias is None:
687 |                         bias_tensor = False
688 |                     else:
689 |                         bias_tensor = conv_layer.bias.data.clone()
690 |                 else:
691 |                     weight_tensor, bias_tensor = self._quant_PerLayerAsymPACT( \
692 |                         conv_layer, layer['batch_norm'], layer['fold_type'], w_bits, i_l, w_min_value, w_max_value, training=False, get_quantized=False)
693 |
694 |                 conv_layer.weight.data.copy_(weight_org)
695 |
696 |             elif quant_type == 'PerChannelsAsymMinMax':
697 |                 weight_org = conv_layer.weight.data.clone()
698 |
699 |                 weight_tensor, bias_tensor = self._quant_PerChannelsAsymMinMax( \
700 |                     conv_layer, layer['batch_norm'], layer['fold_type'], w_bits, i_l, None, None, training=False, get_quantized=False )
701 |
702 |                 conv_layer.weight.data.copy_(weight_org)
703 |
704 |             else:
705 |                 print('Unrecognized quantization scheme!')
706 |                 exit()
707 |
708 |             # compute S, Z parameters of the input and output layers
709 |
710 |             ################################### PARAMS BITS #############################################
711 |             BIAS_BITS = 32      # INT32
712 |             M0_BITS = 32
713 |             N0_BITS = 8
714 |             BIAS_CH_BITS = 16
715 |             M0_BITS_LAST = 8
716 |             BIAS_FIXED_BITS = 0
717 |
718 |
719 |             ################################### LAST layer
############################################# 720 | if layer['act'] is None: 721 | n_levels_w = (2**w_bits)-1 722 | 723 | if quant_type == 'PerChannelsAsymMinMax': # per channel quantization 724 | 725 | weight_quant_tensor = weight_tensor.clone() 726 | if bias_tensor is not False: 727 | bias_quant_tensor = bias_tensor 728 | else: 729 | bias_quant_tensor = False 730 | 731 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_activ_stats(self.param_to_quantize[i_l-1]) 732 | 733 | for v in range(weight_tensor.size(0)) : 734 | w_min_value,w_max_value= weight_tensor[v].min(), weight_tensor[v].max() 735 | range_w = w_max_value-w_min_value 736 | if range_w == 0: 737 | range_w = 1 738 | S_w = (range_w) / n_levels_w 739 | 740 | weight_quant_tensor[v] = weight_tensor[v].div(S_w).round() 741 | 742 | if bias_tensor is not False: 743 | bias_quant_tensor[v] = bias_tensor[v].div( S_w * S_a_i ).mul(2**BIAS_FIXED_BITS).round().clamp(-2**(BIAS_BITS-1), 2**(BIAS_BITS-1)-1 ).div(2**BIAS_FIXED_BITS) 744 | 745 | w_max_value , _ = weight_tensor.reshape(weight_tensor.size(0), -1).max(1) 746 | w_min_value , _ = weight_tensor.reshape(weight_tensor.size(0), -1).min(1) 747 | 748 | range_w = w_max_value-w_min_value 749 | range_w.masked_fill_(range_w.eq(0), 1) 750 | S_w = (range_w) / n_levels_w 751 | M0, n_exp = self._get_m0_nexp_vect(S_w) 752 | M0 = M0.mul(2**(M0_BITS_LAST-1)).round().clamp(-2**(M0_BITS_LAST-1),2**(M0_BITS_LAST-1)-1).div(2**(M0_BITS_LAST-1)) #assume Q.1.31 753 | n_exp = n_exp.clamp(-2**(N0_BITS-1),2**(N0_BITS-1)-1 ) 754 | layer['quant_act'].M_ZERO = M0 755 | layer['quant_act'].N_ZERO = 2**n_exp 756 | 757 | layer['quant_conv'].weight.data.copy_( weight_quant_tensor ) 758 | if bias_quant_tensor is not False: 759 | layer['quant_conv'].bias = nn.Parameter(bias_quant_tensor, requires_grad = False) 760 | 761 | 762 | else: 763 | 764 | w_min_value,w_max_value= weight_tensor.min(), weight_tensor.max() 765 | range_w = w_max_value-w_min_value 766 | if range_w == 0: 767 | range_w = 1 768 | S_w = range_w / n_levels_w 769 | 770 | 771 | weight_quant_tensor = weight_tensor.div(S_w).round() 772 | layer['quant_conv'].weight.data.copy_( weight_quant_tensor ) 773 | 774 | S_a_i, Z_a_i, nl_a_i, a_min, a_max= get_activ_stats(self.param_to_quantize[i_l-1]) 775 | if bias_tensor is not False: 776 | bias_tensor = bias_tensor.div(S_w*S_a_i) 777 | bias_quant_tensor = bias_tensor.round().clamp(-2**(BIAS_BITS-1),2**(BIAS_BITS-1)-1) 778 | 779 | layer['quant_conv'].bias = nn.Parameter(bias_quant_tensor, requires_grad = False) 780 | 781 | 782 | ################################### OTHERS layers ############################################# 783 | 784 | elif type(layer['act']) in [ClippedLinearQuantization, LearnedClippedLinearQuantization] : 785 | 786 | # get S_a parameters 787 | if i_l == 0 : 788 | #first layer 789 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_first_layer_stats() 790 | S_a_o, Z_a_o, nl_a_o, a_min, a_max= get_activ_stats(layer) 791 | else: 792 | S_a_i, Z_a_i, nl_a_i, a_min, a_max = get_activ_stats(self.param_to_quantize[i_l-1]) 793 | S_a_o, Z_a_o, nl_a_o, a_min, a_max = get_activ_stats(layer) 794 | 795 | n_levels_w = (2**w_bits)-1 796 | 797 | if quant_type == 'PerChannelsAsymMinMax': # per channel quantization 798 | 799 | if layer['fold_type'] == 'ICN': 800 | 801 | batch_layer = layer['batch_norm'] 802 | 803 | 804 | gamma_over_sigma, mu_tensor, beta_tensor = self._get_BN_scale_factors(batch_layer) 805 | if bias_tensor is not False: 806 | bias_tensor = bias_tensor.add(-mu_tensor) 807 | print('hereeee') 808 | else: 809 | bias_tensor = 
-mu_tensor.clone() 810 | 811 | w_max_value , _ = weight_tensor.reshape(weight_tensor.size(0), -1).max(1) 812 | w_min_value , _ = weight_tensor.reshape(weight_tensor.size(0), -1).min(1) 813 | range_w = w_max_value-w_min_value 814 | range_w.masked_fill_(range_w.eq(0), 1) 815 | S_w = (range_w) / n_levels_w 816 | bias_tensor = bias_tensor.add(beta_tensor/gamma_over_sigma).div(S_a_i*S_w).mul(2**BIAS_FIXED_BITS).round().clamp(-2**(BIAS_BITS-1),2**(BIAS_BITS-1)-1).div(2**BIAS_FIXED_BITS) #assume INT8 817 | layer['quant_conv'].bias = nn.Parameter(bias_tensor, requires_grad = False) 818 | 819 | 820 | weight_quant_tensor = weight_tensor.clone() 821 | for v in range(weight_tensor.size(0)) : 822 | w_min_value,w_max_value= weight_tensor[v].min(), weight_tensor[v].max() 823 | range_w = w_max_value-w_min_value 824 | if range_w == 0: 825 | range_w = 1 826 | 827 | S_w = range_w / n_levels_w 828 | weight_quant_tensor[v] = weight_tensor[v].div(S_w).round() 829 | gamma_over_sigma[v] = gamma_over_sigma[v].mul((S_a_i*S_w)/S_a_o) 830 | 831 | M0, n_exp = self._get_m0_nexp_vect(gamma_over_sigma) 832 | M0 = M0.mul(2**(M0_BITS-1)).round().clamp(-2**(M0_BITS-1),2**(M0_BITS-1)-1).div(2**(M0_BITS-1)) #assume Q.1.31 833 | n_exp = n_exp.clamp(-2**(N0_BITS-1),2**(N0_BITS-1)-1 ) 834 | layer['quant_act'].M_ZERO = M0 835 | layer['quant_act'].N_ZERO = 2**n_exp 836 | layer['quant_act'].clip_val = nl_a_o 837 | layer['quant_conv'].weight.data.copy_( weight_quant_tensor ) 838 | 839 | else: 840 | print('Not supported yet!!') 841 | return 1 842 | 843 | else: #### PerLayerAsymPACT ################ 844 | 845 | w_min_value,w_max_value= weight_tensor.min(), weight_tensor.max() 846 | range_w = w_max_value-w_min_value 847 | if range_w == 0: 848 | range_w = 1 849 | S_w = range_w / n_levels_w 850 | weight_quant_tensor = weight_tensor.div(S_w).round() 851 | layer['quant_conv'].weight.data.copy_( weight_quant_tensor ) 852 | 853 | 854 | if layer['fold_type'] == 'ICN': 855 | 856 | batch_layer = layer['batch_norm'] 857 | 858 | gamma_over_sigma, mu_tensor, beta_tensor = self._get_BN_scale_factors(batch_layer) 859 | if bias_tensor is not False: 860 | bias_tensor = bias_tensor.add(-mu_tensor) 861 | print('hereeee') 862 | else: 863 | bias_tensor = -mu_tensor 864 | 865 | 866 | bias_tensor = bias_tensor.add(beta_tensor/gamma_over_sigma).div(S_a_i*S_w) #assume INT8 867 | bias_quant_tensor = bias_tensor.round().clamp(-2**(BIAS_BITS-1),2**(BIAS_BITS-1)-1) 868 | 869 | layer['quant_conv'].bias = nn.Parameter(bias_quant_tensor, requires_grad = False) 870 | 871 | gamma_over_sigma = gamma_over_sigma.mul((S_a_i*S_w)/S_a_o) 872 | M0, n_exp = self._get_m0_nexp_vect(gamma_over_sigma ) 873 | M0 = M0.mul(2**(M0_BITS-1)).round().clamp(-2**(M0_BITS-1),2**(M0_BITS-1)-1).div(2**(M0_BITS-1)) #assume Q.1.31 874 | n_exp = n_exp.clamp(-2**(N0_BITS-1),2**(N0_BITS-1)-1 ) 875 | 876 | 877 | layer['quant_act'].M_ZERO = M0 878 | layer['quant_act'].N_ZERO = 2**n_exp 879 | layer['quant_act'].clip_val = nl_a_o 880 | 881 | 882 | else: #folding weights 883 | 884 | if bias_tensor is not False: 885 | bias_tensor = bias_tensor.div(S_w*S_a_i) 886 | bias_quant_tensor = bias_tensor.round().clamp(-2**(BIAS_BITS-1),2**(BIAS_BITS-1)-1) 887 | layer['quant_conv'].bias = nn.Parameter(bias_quant_tensor, requires_grad = False) 888 | 889 | else: 890 | layer['quant_conv'].bias = None 891 | 892 | Z_w = -weight_quant_tensor.min().item() 893 | 894 | M0, n_exp = self._get_m0_nexp((S_a_i*S_w)/S_a_o ) 895 | M0 = np.clip(np.round(M0*2**(M0_BITS-1)),-2**(M0_BITS-1),2**(M0_BITS-1)-1) / 2**(M0_BITS-1) 896 | 
                        n_exp = np.clip(n_exp,-2**(N0_BITS-1),2**(N0_BITS-1)-1 )
897 |
898 |                         layer['quant_act'].M_ZERO = M0.tolist()
899 |                         layer['quant_act'].N_ZERO = n_exp.tolist()
900 |                         layer['quant_act'].clip_val = nl_a_o
901 |
902 |             else: # relu for instance
903 |                 print('Error')
904 |
905 |                 layer['quant_conv'].weight.data.copy_( weight_quant_tensor )
906 |                 layer['quant_conv'].bias = nn.Parameter(bias_tensor, requires_grad = False)
907 |
908 |
909 |     def store_and_quantize(self, training=False):
910 |         # store the actual weights and replace them with the quantized version before the forward pass
911 |
912 |         for i_l,layer in enumerate(self.param_to_quantize):
913 |             quant_type = layer['quant_type']
914 |             w_bits = layer['w_bits']
915 |             conv_layer = layer['conv']
916 |
917 |
918 |             if quant_type == 'PerLayerAsymMinMax':
919 |
920 |                 # 1. store the floating-point values
921 |                 layer['weight'] = conv_layer.weight.data.clone()
922 |                 if conv_layer.bias is None:
923 |                     layer['bias'] = False
924 |                 else:
925 |                     layer['bias'] = conv_layer.bias.data.clone()
926 |                     bias_bits = layer['bias_bits']
927 |
928 |                 # 2. quantization
929 |                 weight_quant_tensor, bias_quant_tensor = self._quant_PerLayerAsymMinMax( \
930 |                     conv_layer, layer['batch_norm'], layer['fold_type'], w_bits, i_l, training=training )
931 |
932 |                 # 3. push the quantized values into the fake-quantized network
933 |                 conv_layer.weight.data.copy_( weight_quant_tensor )
934 |                 if bias_quant_tensor is not False:
935 |                     if conv_layer.bias is not None:
936 |                         conv_layer.bias.data.copy_( bias_quant_tensor )
937 |                     else:
938 |                         conv_layer.bias = nn.Parameter(bias_quant_tensor, requires_grad = False)
939 |                 elif conv_layer.bias is not None:
940 |                     print('Error! Something is wrong with the bias')
941 |                     return -1
942 |
943 |             elif quant_type in ['PerLayerAsymPACT','PerChannelsAsymMinMax']:
944 |
945 |                 # 1. store the floating-point values
946 |                 layer['weight'] = conv_layer.weight.data.clone()
947 |                 if conv_layer.bias is None:
948 |                     layer['bias'] = False
949 |                 else:
950 |                     layer['bias'] = conv_layer.bias.data.clone()
951 |                     bias_bits = layer['bias_bits']
952 |
953 |                 # 2. quantization
954 |                 if quant_type == 'PerLayerAsymPACT':
955 |
956 |                     layer['w_min_thr'].data.clamp_(max= 0)
957 |                     layer['w_max_thr'].data.clamp_(min= 0)
958 |                     w_min_value, w_max_value = layer['w_min_thr'].data.item() , layer['w_max_thr'].data.item()
959 |
960 |                     if layer['fold_type'] != 'folding_weights' or layer['batch_norm'] is None:
961 |                         conv_layer.weight.data.clamp_(w_min_value, w_max_value)
962 |
963 |                     weight_quant_tensor, bias_quant_tensor = self._quant_PerLayerAsymPACT( \
964 |                         conv_layer, layer['batch_norm'], layer['fold_type'], w_bits, i_l, w_min_value, w_max_value, training=training )
965 |
966 |                 elif quant_type == 'PerChannelsAsymMinMax':
967 |
968 |                     weight_quant_tensor, bias_quant_tensor = self._quant_PerChannelsAsymMinMax( \
969 |                         conv_layer, layer['batch_norm'], layer['fold_type'], w_bits, i_l, None, None, training=training )
970 |
971 |                 # 3. push the quantized values into the fake-quantized network
972 |                 conv_layer.weight.data.copy_( weight_quant_tensor )
973 |
974 |                 if bias_quant_tensor is not False:
975 |                     if conv_layer.bias is not None:
976 |                         conv_layer.bias.data.copy_( bias_quant_tensor )
977 |                     else:
978 |                         conv_layer.bias = nn.Parameter(bias_quant_tensor, requires_grad = False)
979 |                 elif conv_layer.bias is not None:
980 |                     print('Error! Something is wrong with the bias')
981 |                     return -1
982 |
983 |             else:
984 |                 print('Unrecognized quantization scheme!')
985 |                 exit()
986 |
987 |
988 |     def restore_real_value(self):
989 |         # restore the real values into the fake-quantized model before the weight update
990 |
991 |         for layer in self.param_to_quantize:
992 |             quant_type = layer['quant_type']
993 |             w_bits = layer['w_bits']
994 |             conv_layer = layer['conv']
995 |
996 |
997 |             if quant_type in ['PerLayerAsymMinMax', 'PerLayerAsymPACT', 'PerChannelsAsymMinMax']:
998 |
999 |                 # restore data
1000 |                 conv_layer.weight.data.copy_(layer['weight'])
1001 |                 if layer['w_clip'] is not None:
1002 |                     w_min, w_max = layer['w_clip']
1003 |                     conv_layer.weight.data = conv_layer.weight.data.clamp(w_min, w_max)
1004 |
1005 |                 if layer['bias'] is not False:
1006 |                     conv_layer.bias.data.copy_(layer['bias'])
1007 |                 else:
1008 |                     conv_layer.bias = None
1009 |
1010 |
1011 |     def backprop_quant_gradients(self):
1012 |         # propagate gradients from the quantized weights back to the real-valued weights
1013 |
1014 |
1015 |         for i_l, layer in enumerate(self.param_to_quantize):
1016 |             quant_type = layer['quant_type']
1017 |             w_bits = layer['w_bits']
1018 |             conv_layer = layer['conv']
1019 |
1020 |             if quant_type in ['PerLayerAsymMinMax', 'PerChannelsAsymMinMax']:
1021 |                 self.__ste_correction(conv_layer, clip=False)
1022 |
1023 |             elif quant_type == 'PerLayerAsymPACT':
1024 |                 w_min_param, w_max_param = layer['w_min_thr'], layer['w_max_thr']
1025 |                 w_min_value, w_max_value = w_min_param.data.item(), w_max_param.data.item()
1026 |
1027 |                 grad_a_max = conv_layer.weight.grad.data.clone()
1028 |                 grad_a_min = conv_layer.weight.grad.data.clone()
1029 |
1030 |                 weight_tensor = conv_layer.weight.data
1031 |                 dLdw = conv_layer.weight.grad.data
1032 |                 dLdw.masked_fill_(weight_tensor.le(w_min_value), 0)
1033 |                 dLdw.masked_fill_(weight_tensor.ge(w_max_value), 0)
1034 |
1035 |                 grad_a_max.masked_fill_(layer['weight'].lt(w_max_value), 0)
1036 |                 grad_a_max = grad_a_max.sum().expand_as(w_max_param.data)
1037 |                 w_max_param.grad.data = grad_a_max
1038 |
1039 |                 grad_a_min.masked_fill_(layer['weight'].gt(w_min_value), 0)
1040 |                 grad_a_min = grad_a_min.sum().expand_as(w_min_param.data)
1041 |                 w_min_param.grad.data = grad_a_min
1042 |
1043 |
1044 |     def __ste_correction(self, conv_layer, clip=True):
1045 |         # gradient correction according to the straight-through estimator (STE)
1046 |         # update the convolution weights
1047 |         weight = conv_layer.weight.data
1048 |         m = conv_layer.weight.grad.data
1049 |         if clip:
1050 |             m.masked_fill_(weight.lt(-1.0), 0)
1051 |             m.masked_fill_(weight.gt(1.0), 0)
1052 |
--------------------------------------------------------------------------------
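For orientation, here is a minimal, hypothetical sketch of how these hooks could be wired into a fake-quantized training run. The names quant_op, model, optimizer, criterion, train_loader and num_epochs are placeholders, and the exact call order used by main_binary.py may differ:

# Hypothetical usage sketch; not taken verbatim from main_binary.py.
quant_op.init_parameters()                     # set the initial PACT / min-max thresholds
quant_op.add_params_to_optimizer(optimizer)    # let the optimizer update w_min_thr / w_max_thr

for epoch in range(num_epochs):
    quant_op.freeze_BN_and_fold(epoch)         # freeze BN statistics (and fold) after batch_fold_delay
    for images, targets in train_loader:
        quant_op.store_and_quantize(training=True)   # swap the real weights for quantized ones
        loss = criterion(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        quant_op.restore_real_value()          # put the full-precision weights back
        quant_op.backprop_quant_gradients()    # STE: mask weight grads and fill the threshold grads
        optimizer.step()

quant_op.generate_deployment_model()           # export integer-only parameters at the end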