├── requirements.txt ├── models ├── __init__.py ├── modules │ ├── checkpoint.py │ ├── se.py │ ├── birelu.py │ ├── batch_norm.py │ ├── fixup.py │ ├── fixed_proj.py │ ├── bwn.py │ ├── evolved_modules.py │ └── lp_norm.py ├── mobilenet_v2.py ├── mobilenet_v2_old.py └── resnet.py ├── scripts ├── advanced_pipeline.sh ├── light_pipeline.sh ├── bn_tuning.sh ├── bias_tuning.sh ├── adaquant.sh └── integer-programing.sh ├── utils ├── functions.py ├── correct_bias.py ├── convert_pytcv_model.py ├── LICENSE ├── mixup.py ├── quantize.py ├── cross_entropy.py ├── misc.py ├── meters.py ├── regime.py ├── model_manipulation.py ├── param_filter.py ├── layer_sensativity.py ├── dataset.py ├── absorb_bn.py ├── adaquant.py ├── log.py ├── optim.py └── regularization.py ├── BERT-base ├── scripts │ ├── clone_copy_build.sh │ ├── create_calib_data.py │ ├── bert-base-squad1.1-384.sh │ └── bert-base-adaquant.sh ├── README.md └── src │ └── layer_sensativity.py ├── ip_config_parser.py ├── create_calib_folder.py ├── LICENSE ├── README.md ├── preprocess.py ├── evaluate.py ├── data.py └── mpip_compression_pytorch_multi.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | bokeh 4 | pandas 5 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .mobilenet_v2 import * 3 | from .resnet_pytcv import * 4 | -------------------------------------------------------------------------------- /scripts/advanced_pipeline.sh: -------------------------------------------------------------------------------- 1 | sh scripts/adaquant.sh resnet resnet50 4 4 True 2 | sh scripts/adaquant.sh resnet resnet50 8 8 True 3 | sh scripts/integer-programing.sh resnet resnet50 4 4 8 8 50 loss True 4 | 5 | # Uncomment to run first configuration only 6 | #for cfg_idx in 0 7 | for cfg_idx in 0 1 2 3 4 5 6 7 8 9 10 11 8 | do 9 | # TODO: run bn and bias tuning in loop on all configurations 10 | sh scripts/bn_tuning.sh resnet resnet50 8 8 $cfg_idx 11 | sh scripts/bias_tuning.sh resnet resnet50 8 8 $cfg_idx 12 | done 13 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function 3 | 4 | class ScaleGrad(Function): 5 | 6 | @staticmethod 7 | def forward(ctx, input, scale): 8 | ctx.scale = scale 9 | return input 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): 13 | grad_input = ctx.scale * grad_output 14 | return grad_input, None 15 | 16 | 17 | def scale_grad(x, scale): 18 | return ScaleGrad().apply(x, scale) 19 | 20 | def negate_grad(x): 21 | return scale_grad(x, -1) 22 | -------------------------------------------------------------------------------- /BERT-base/scripts/clone_copy_build.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/huggingface/transformers 2 | cd transformers 3 | git checkout 5620033115e571013325d017bcca92991b0a4ace 4 | cd .. 
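# copy the AdaQuant sources into the pinned transformers checkout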
5 | cp src/layer_sensativity.py transformers/ 6 | cp src/run_squad_adaquant.py transformers/examples/question-answering/ 7 | mv transformers/src/transformers/modeling_bert.py transformers/src/transformers/modeling_bert_fp32.py 8 | cp src/modeling_* transformers/src/transformers/ 9 | cp src/mse_optimization.py transformers/src/transformers/ 10 | cd transformers 11 | pip install . 12 | 13 | 14 | -------------------------------------------------------------------------------- /BERT-base/scripts/create_calib_data.py: -------------------------------------------------------------------------------- 1 | train_file_path='/media/drive/Datasets/squad/train-v1.1.json' 2 | 3 | with open (train_file_path, 'rt') as myfile: # Open file lorem.txt for reading text 4 | for txt_data in myfile: # For each line, read it to a string 5 | id=txt_data.find('"title"',101385+12) 6 | new_line= txt_data[:id-3]+txt_data[-20:] 7 | import pdb; pdb.set_trace() 8 | calib_path= train_file_path.replace('train','calib') 9 | calib_file = open(calib_path, "wt") 10 | n = calib_file.write(new_line) 11 | calib_file.close() -------------------------------------------------------------------------------- /scripts/light_pipeline.sh: -------------------------------------------------------------------------------- 1 | export model=${1:-"resnet"} 2 | export model_vis=${2:-"resnet50"} 3 | 4 | sh scripts/adaquant.sh $model $model_vis 4 4 false 5 | sh scripts/adaquant.sh $model $model_vis 8 8 false 6 | sh scripts/integer-programing.sh $model $model_vis 4 4 8 8 50 loss false 7 | 8 | # Uncomment to run first configuration only 9 | #for cfg_idx in 0 10 | for cfg_idx in 0 1 2 3 4 5 6 7 8 9 10 11 11 | do 12 | # TODO: run bn tuning in loop on all configurations 13 | echo "Running configuration $cfg_idx" 14 | sh scripts/bn_tuning.sh resnet resnet50 8 8 $cfg_idx 15 | done 16 | -------------------------------------------------------------------------------- /ip_config_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import csv 4 | import os 5 | 6 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Training') 7 | parser.add_argument('--config-file', type=str, metavar='FILE', help='path to config file') 8 | parser.add_argument('--cfg-idx', default=0, type=int, help='Index of precision configuration to get') 9 | parser.add_argument('--column', type=str, help='path to config file') # [Configuration, state_dict_path] 10 | args = parser.parse_args() 11 | 12 | 13 | df = pd.read_csv(args.config_file, sep='\t') 14 | res = df[args.column][args.cfg_idx] 15 | print(res) 16 | -------------------------------------------------------------------------------- /create_calib_folder.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import shutil 4 | 5 | basepath = '/home/Datasets/imagenet/train/' 6 | basepath_calib = '/home/Datasets/imagenet/calib/' 7 | 8 | directory = os.fsencode(basepath) 9 | os.mkdir(basepath_calib) 10 | for d in os.listdir(directory): 11 | dir_name = os.fsdecode(d) 12 | dir_path = os.path.join(basepath,dir_name) 13 | dir_copy_path = os.path.join(basepath_calib,dir_name) 14 | os.mkdir(dir_copy_path) 15 | for f in os.listdir(dir_path): 16 | file_path = os.path.join(dir_path,f) 17 | copy_file_path = os.path.join(dir_copy_path,f) 18 | shutil.copyfile(file_path, copy_file_path) 19 | break -------------------------------------------------------------------------------- /utils/correct_bias.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def correct_bias(path_bias_quant,path_bias_measure,path_model): 4 | model=torch.load(path_model) 5 | bias_quant=torch.load(path_bias_quant) 6 | bias_measure=torch.load(path_bias_measure) 7 | for key in bias_quant: 8 | if 'count' not in key: 9 | #import pdb; pdb.set_trace() 10 | diff_bias = (bias_quant[key]-bias_measure[key]).div(bias_measure[key+'.count']*10) 11 | model[key+'.bias'] -= diff_bias.to(model[key+'.bias']) 12 | torch.save(model,path_model+'_bias_correct') 13 | 14 | correct_bias('bias_mean_quant','bias_mean_measure','./results/resnet50/resnet.absorb_bn.measure_v2') -------------------------------------------------------------------------------- /models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 4 | 5 | 6 | class CheckpointModule(nn.Module): 7 | def __init__(self, module, num_segments=1): 8 | super(CheckpointModule, self).__init__() 9 | assert num_segments == 1 or isinstance(module, nn.Sequential) 10 | self.module = module 11 | self.num_segments = num_segments 12 | 13 | def forward(self, *inputs): 14 | if self.num_segments > 1: 15 | return checkpoint_sequential(self.module, self.num_segments, *inputs) 16 | else: 17 | return checkpoint(self.module, *inputs) 18 | -------------------------------------------------------------------------------- /BERT-base/scripts/bert-base-squad1.1-384.sh: -------------------------------------------------------------------------------- 1 | export SQUAD_DIR=/media/drive/Datasets/squad 2 | CUDA_VISIBLE_DEVICES=2 python ./examples/run_squad.py \ 3 | --model_type bert \ 4 | --model_name_or_path bert-base-uncased \ 5 | --do_train \ 6 | --do_eval \ 7 | --train_file $SQUAD_DIR/train-v1.1.json \ 8 | --predict_file $SQUAD_DIR/dev-v1.1.json \ 9 | --learning_rate 3e-5 \ 10 | --num_train_epochs 2 \ 11 | --max_seq_length 384 \ 12 | --doc_stride 128 \ 13 | --output_dir ./examples/models/bert_base_uncased_finetuned_squad_base_384 \ 14 | --per_gpu_eval_batch_size=12 \ 15 | --per_gpu_train_batch_size=12 \ 16 | --do_lower_case \ 17 | --local_rank=-1 \ 18 | 19 | -------------------------------------------------------------------------------- /utils/convert_pytcv_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def convert_pytcv_model(model,model_pytcv): 4 | sd=model.state_dict() 5 | sd_pytcv=model_pytcv.state_dict() 6 | convert_dict={} 7 | for key,key_pytcv in zip(sd.keys(),sd_pytcv.keys()): 8 | clean_key='.'.join(key.split('.')[:-1]) 9 | clean_key_pytcv='.'.join(key_pytcv.split('.')[:-1]) 10 | convert_dict[clean_key]=clean_key_pytcv 11 | if sd[key].shape != sd_pytcv[key_pytcv].shape: 12 | print(key,sd[key].shape,key_pytcv,sd_pytcv[key_pytcv].shape) 13 | import pdb; pdb.set_trace() 14 | else: 15 | sd[key].copy_(sd_pytcv[key_pytcv]) 16 | return model 17 | -------------------------------------------------------------------------------- /models/modules/se.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SEBlock(nn.Module): 5 | def __init__(self, in_channels, out_channels=None, ratio=16): 6 | super(SEBlock, self).__init__() 7 | self.in_channels = in_channels 8 | if out_channels is None: 9 | out_channels = in_channels 10 | self.ratio = 
ratio
11 | self.relu = nn.ReLU(True)
12 | self.global_pool = nn.AdaptiveAvgPool2d(1)
13 | self.transform = nn.Sequential(
14 | nn.Linear(in_channels, in_channels // ratio),
15 | nn.ReLU(inplace=True),
16 | nn.Linear(in_channels // ratio, out_channels),
17 | nn.Sigmoid()
18 | )
19 | 
20 | def forward(self, x):
21 | x_avg = self.global_pool(x).view(x.size(0), -1)
22 | mask = self.transform(x_avg)
23 | return x * mask.view(x.size(0), -1, 1, 1)
24 | 
25 | 
-------------------------------------------------------------------------------- /BERT-base/README.md: --------------------------------------------------------------------------------
1 | # BERT-BASE over SQuAD1.1
2 | 
3 | This repository adds files to the huggingface/transformers repo to enable AdaQuant optimization.
4 | The repo is tested on Python 3.6+, PyTorch 1.0.0+ (PyTorch 1.3.1+ for examples) and TensorFlow 2.0.
5 | As suggested by [🤗 Transformers](https://github.com/huggingface/transformers) you should install it in a virtual environment.
6 | 
7 | In your virtual environment please use the following script to clone and copy the relevant files:
8 | ```bash
9 | sh scripts/clone_copy_build.sh
10 | ```
11 | To reproduce the results, please make sure that you have the [SQuAD-v1.1](https://rajpurkar.github.io/SQuAD-explorer/) dataset and a pretrained base model. You can finetune FP32 BERT base on SQuAD using:
12 | ```bash
13 | sh ../scripts/bert-base-squad1.1-384.sh
14 | ```
15 | Next, create the calibration dataset from the training file and run AdaQuant:
16 | ```bash
17 | python ../scripts/create_calib_data.py
18 | sh ../scripts/bert-base-adaquant.sh
19 | ```
20 | 
21 | 
-------------------------------------------------------------------------------- /scripts/bn_tuning.sh: --------------------------------------------------------------------------------
1 | export datasets_dir=/media/drive/Datasets
2 | export model=${1:-"resnet"}
3 | export model_vis=${2:-"resnet50"}
4 | export nbits_weight=${3:-4}
5 | export nbits_act=${4:-4}
6 | export perC=True
7 | export num_sp_layers=-1
8 | export perC_suffix=''
9 | export adaquant=True
10 | if [ "$adaquant" = True ]; then
11 | export adaquant_suffix='.adaquant'
12 | fi
13 | export perC_suffix=''
14 | if [ "$perC" = True ]; then
15 | export perC_suffix='_perC'
16 | fi
17 | export workdir=${model_vis}_w$nbits_weight'a'$nbits_act$adaquant_suffix
18 | 
19 | cfg_idx=${5:-0}
20 | prec_dict=$(python ip_config_parser.py --cfg-idx $cfg_idx --config-file results/$workdir/IP_${model_vis}_loss.txt --column Configuration)
21 | ckp=$(python ip_config_parser.py --cfg-idx $cfg_idx --config-file results/$workdir/IP_${model_vis}_loss.txt --column state_dict_path)
22 | 
23 | # Run bn tuning
24 | python main.py -lpd "$prec_dict" --batch-norn-tuning --model $model -lfv $model_vis -b 200 --evaluate $ckp --model-config "{'batch_norm': False,'measure': False, 'perC': $perC}" --dataset imagenet_calib --datasets-dir $datasets_dir
25 | 
26 | 
-------------------------------------------------------------------------------- /scripts/bias_tuning.sh: --------------------------------------------------------------------------------
1 | export datasets_dir=/media/drive/Datasets
2 | export model=${1:-"resnet"}
3 | export model_vis=${2:-"resnet50"}
4 | export nbits_weight=${3:-4}
5 | export nbits_act=${4:-4}
6 | export perC=True
7 | export num_sp_layers=-1
8 | export perC_suffix=''
9 | export adaquant=True
10 | if [ "$adaquant" = True ]; then
11 | export adaquant_suffix='.adaquant'
12 | fi
13 | export perC_suffix=''
14 | if [ "$perC" = True ]; then
15 | export perC_suffix='_perC'
16 | fi
17 | export workdir=${model_vis}_w$nbits_weight'a'$nbits_act$adaquant_suffix
18 | 
19 | export cmp_idx=${5:-0}
20 | prec_dict=$(python ip_config_parser.py --cfg-idx $cmp_idx --config-file results/$workdir/IP_${model_vis}_loss.txt --column Configuration)
21 | export ckp_name=resnet.absorb_bn.mixed-ip-results.comp_0.13_loss
22 | 
23 | # Run bias tuning
24 | python main.py -lpd "$prec_dict" --bias-tuning --model $model --evaluate results/$workdir/$ckp_name --model-config "{'batch_norm': False,'measure': False, 'perC': $perC}" --dataset imagenet_calib --datasets-dir $datasets_dir --save results/$workdir/bias_ft -b 50 --fine-tune --update_only_th --kld_loss --epochs 10
25 | 
26 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Elad Hoffer
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /utils/LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Elad Hoffer
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /models/modules/birelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import InplaceFunction 3 | import torch.nn as nn 4 | 5 | 6 | class BiReLUFunction(InplaceFunction): 7 | 8 | @staticmethod 9 | def forward(ctx, input, inplace=False): 10 | if input.size(1) % 2 != 0: 11 | raise RuntimeError("dimension 1 of input must be multiple of 2, " 12 | "but got {}".format(input.size(1))) 13 | ctx.inplace = inplace 14 | 15 | if ctx.inplace: 16 | ctx.mark_dirty(input) 17 | output = input 18 | else: 19 | output = input.clone() 20 | 21 | pos, neg = output.chunk(2, dim=1) 22 | pos.clamp_(min=0) 23 | neg.clamp_(max=0) 24 | ctx.save_for_backward(output) 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | output, = ctx.saved_variables 30 | grad_input = grad_output.masked_fill(output.eq(0), 0) 31 | return grad_input, None 32 | 33 | 34 | def birelu(x, inplace=False): 35 | return BiReLUFunction().apply(x, inplace) 36 | 37 | 38 | class BiReLU(nn.Module): 39 | """docstring for BiReLU.""" 40 | 41 | def __init__(self, inplace=False): 42 | super(BiReLU, self).__init__() 43 | self.inplace = inplace 44 | 45 | def forward(self, inputs): 46 | return birelu(inputs, inplace=self.inplace) 47 | 48 | -------------------------------------------------------------------------------- /scripts/adaquant.sh: -------------------------------------------------------------------------------- 1 | export datasets_dir=/media/drive/Datasets 2 | export model=${1:-"resnet"} 3 | export model_vis=${2:-"resnet50"} 4 | export nbits_weight=${3:-4} 5 | export nbits_act=${4:-4} 6 | export adaquant_suffix='' 7 | if [ "$5" = True ]; then 8 | export adaquant_suffix='.adaquant' 9 | fi 10 | export workdir=${model_vis}_w$nbits_weight'a'$nbits_act$adaquant_suffix 11 | export perC=True 12 | export num_sp_layers=-1 13 | export perC_suffix='' 14 | if [ "$perC" = True ] ; then 15 | export perC_suffix='_perC' 16 | fi 17 | # download and absorb_bn resnet50 and 18 | python main.py --model $model --save $workdir -b 128 -lfv $model_vis --model-config "{'batch_norm': False}" 19 | 20 | # measure range and zero point on calibset 21 | python main.py --model $model --nbits_weight $nbits_weight --nbits_act $nbits_act --num-sp-layers $num_sp_layers --evaluate results/$workdir/$model.absorb_bn --model-config "{'batch_norm': False,'measure': True, 'perC': $perC}" -b 100 --rec --dataset imagenet_calib --datasets-dir $datasets_dir 22 | if [ "$5" = True ]; then 23 | # Run adaquant to minimize MSE of the output with respect to range, zero point and small perturations in parameters 24 | python main.py --optimize-weights --nbits_weight $nbits_weight --nbits_act $nbits_act --num-sp-layers $num_sp_layers --model $model -b 200 --evaluate results/$workdir/$model.absorb_bn.measure$perC_suffix --model-config "{'batch_norm': False,'measure': False, 'perC': $perC}" --dataset imagenet_calib --datasets-dir $datasets_dir --adaquant 25 | fi 26 | -------------------------------------------------------------------------------- /BERT-base/scripts/bert-base-adaquant.sh: -------------------------------------------------------------------------------- 1 | export SQUAD_DIR=/media/drive/Datasets/squad 2 | export FP_MODEL_DIR=/media/data/transformers/ 3 | export SQUAD_BERT_NBITS=8 4 | # measure 5 | CUDA_VISIBLE_DEVICES=0 python ./examples/question-answering/run_squad_adaquant.py \ 6 | 
--model_type bert \
7 | --model_name_or_path $FP_MODEL_DIR/examples/models/bert_base_uncased_finetuned_squad_base_384 \
8 | --do_eval \
9 | --calib_file $SQUAD_DIR/calib-v1.1.json \
10 | --predict_file $SQUAD_DIR/dev-v1.1.json \
11 | --learning_rate 3e-5 \
12 | --max_seq_length 384 \
13 | --doc_stride 128 \
14 | --output_dir ./examples/models/bert-base-measure-perc-w$SQUAD_BERT_NBITS'a'$SQUAD_BERT_NBITS \
15 | --per_gpu_eval_batch_size=12 \
16 | --per_gpu_train_batch_size=12 \
17 | --do_lower_case \
18 | --local_rank=-1 \
19 | --quant-config "{'quantize': True, 'measure': True,'num_bits': $SQUAD_BERT_NBITS, 'num_bits_weight': $SQUAD_BERT_NBITS, 'perC': True, 'cal_qparams': False}" \
20 | 
21 | # adaquant
22 | CUDA_VISIBLE_DEVICES=0 python ./examples/question-answering/run_squad_adaquant.py \
23 | --model_type bert \
24 | --model_name_or_path ./examples/models/bert-base-measure-perc-w$SQUAD_BERT_NBITS'a'$SQUAD_BERT_NBITS \
25 | --do_eval \
26 | --calib_file $SQUAD_DIR/calib-v1.1.json \
27 | --predict_file $SQUAD_DIR/dev-v1.1.json \
28 | --learning_rate 3e-5 \
29 | --max_seq_length 384 \
30 | --doc_stride 128 \
31 | --output_dir ./examples/models/bert-base-adaquant-perc-w$SQUAD_BERT_NBITS'a'$SQUAD_BERT_NBITS \
32 | --per_gpu_eval_batch_size=12 \
33 | --per_gpu_train_batch_size=12 \
34 | --do_lower_case \
35 | --local_rank=-1 \
36 | --optimize_weights \
37 | --quant-config "{'quantize': True, 'measure': False,'num_bits': $SQUAD_BERT_NBITS, 'num_bits_weight': $SQUAD_BERT_NBITS, 'perC': True, 'cal_qparams': False}" \
38 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Improving Post Training Neural Quantization: Layer-wise Calibration and Integer Programming
2 | Most of the literature on neural network quantization requires some training of the quantized model (fine-tuning). However, this training is not always possible in real-world scenarios, as it requires the full dataset. Lately, post-training quantization methods have gained considerable attention, as they are simple to use and require only a small, unlabeled calibration set. Yet, they usually incur significant accuracy degradation when quantized below 8 bits. This paper seeks to address this problem by introducing two pipelines, advanced and light, where the former involves: (i) minimizing the quantization errors of each layer by optimizing its parameters over the calibration set; (ii) using integer programming to optimally allocate the desired bit-width for each layer while constraining accuracy degradation or model compression; and (iii) tuning the mixed-precision model statistics to correct biases introduced during quantization. While the light pipeline, which invokes only (ii) and (iii), obtains surprisingly accurate results, the advanced pipeline yields state-of-the-art accuracy-compression ratios for both vision and text models. For instance, on ResNet50, we obtain less than 1% accuracy degradation while compressing the model to 13% of its original size. Our code is available in the supplementary material and would be open-sourced upon acceptance.
3 | ## Reproducing the results
4 | 
5 | This repository is based on the [convNet.pytorch](https://github.com/eladhoffer/convNet.pytorch) repo. Please ensure that you are using PyTorch 1.3+.
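The bit-width allocation of step (ii) above reduces to a small integer program: every layer picks one precision, and each choice has a measured size and loss degradation (the repo reads these from the `.per_layer_accuracy` CSV files and solves the program in `mpip_compression_pytorch_multi.py`). A minimal sketch of that formulation, with made-up numbers and exhaustive search in place of an IP solver:
```python
from itertools import product

# Hypothetical per-layer measurements: bits -> (relative size, added loss).
layers = {
    "conv1":  {4: (0.25, 0.031), 8: (0.5, 0.002)},
    "layer1": {4: (0.25, 0.012), 8: (0.5, 0.001)},
    "fc":     {4: (0.25, 0.090), 8: (0.5, 0.004)},
}

def allocate(layers, max_compression=0.40):
    """Pick one precision per layer, minimizing the summed loss degradation
    while the average relative size stays under the compression budget
    (layers are assumed equal-sized here for simplicity)."""
    names, best = list(layers), None
    for bits in product(*(layers[n] for n in names)):
        size = sum(layers[n][b][0] for n, b in zip(names, bits)) / len(names)
        loss = sum(layers[n][b][1] for n, b in zip(names, bits))
        if size <= max_compression and (best is None or loss < best[0]):
            best = (loss, dict(zip(names, bits)))
    return best

print(allocate(layers))  # (0.047, {'conv1': 4, 'layer1': 4, 'fc': 8})
```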
6 | To reproduce the results, run:
7 | ```bash
8 | sh scripts/advanced_pipeline.sh
9 | sh scripts/light_pipeline.sh
10 | ```
11 | To reproduce the BERT-base results, please follow the instructions in the BERT-base folder.
12 | 
-------------------------------------------------------------------------------- /scripts/integer-programing.sh: --------------------------------------------------------------------------------
1 | export datasets_dir=/media/drive/Datasets
2 | export model=${1:-"resnet"}
3 | export model_vis=${2:-"resnet50"}
4 | export nbits_weight_m1=${3:-4}
5 | export nbits_act_m1=${4:-4}
6 | export nbits_weight_m2=${5:-8}
7 | export nbits_act_m2=${6:-8}
8 | export perC=True
9 | export adaquant=True
10 | export precisions="$nbits_weight_m2;$nbits_weight_m1"
11 | export min_compression=0.13
12 | export max_compression=0.25
13 | 
14 | export adaquant_suffix=''
15 | export do_not_use_adaquant=--do_not_use_adaquant
16 | if [ "$9" = True ]; then
17 | export adaquant_suffix='.adaquant'
18 | export do_not_use_adaquant=''
19 | fi
20 | export perC_suffix=''
21 | if [ "$perC" = True ]; then
22 | export perC_suffix='_perC'
23 | fi
24 | export workdir_m1=${model_vis}_w$nbits_weight_m1'a'$nbits_act_m1$adaquant_suffix
25 | export workdir_m2=${model_vis}_w$nbits_weight_m2'a'$nbits_act_m2$adaquant_suffix
26 | 
27 | export num_sp_layers=-1
28 | export depth=${7:-50}
29 | export loss=${8:-'loss'}
30 | export layer_by_layer=results/$workdir_m2/$model.absorb_bn.measure$perC_suffix$adaquant_suffix.per_layer_accuracy.A$nbits_weight_m1.W$nbits_weight_m1.csv
31 | 
32 | #Extract per layer loss delta
33 | python main.py --model $model --evaluate results/$workdir_m2/$model.absorb_bn'.measure'$perC_suffix$adaquant_suffix --model-config "{'batch_norm': False,'measure': False, 'perC': $perC, 'depth': $depth}" -b 100 --dataset imagenet_calib --datasets-dir $datasets_dir --int8_opt_model_path results/$workdir_m2/$model.absorb_bn'.measure'$perC_suffix$adaquant_suffix --int4_opt_model_path results/$workdir_m1/$model.absorb_bn'.measure'$perC_suffix$adaquant_suffix --names-sp-layers '' --per-layer
34 | 
35 | #Run IP algorithm to obtain best topology
36 | python mpip_compression_pytorch_multi.py --model $model --model_vis $model_vis --ip_method $loss --precisions $precisions --layer_by_layer_files $layer_by_layer --min_compression $min_compression --max_compression $max_compression $do_not_use_adaquant --datasets-dir $datasets_dir
37 | 
-------------------------------------------------------------------------------- /utils/mixup.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from numpy.random import beta
4 | from .misc import onehot
5 | 
6 | 
7 | class MixUp(nn.Module):
8 | def __init__(self, batch_dim=0):
9 | super(MixUp, self).__init__()
10 | self.batch_dim = batch_dim
11 | self.reset()
12 | 
13 | def reset(self):
14 | self.enabled = False
15 | self.mix_values = None
16 | self.mix_index = None
17 | 
18 | def mix(self, x1, x2):
19 | if not torch.is_tensor(self.mix_values): # scalar
20 | return x2.lerp(x1, self.mix_values)
21 | else:
22 | view = [1] * int(x1.dim())
23 | view[self.batch_dim] = -1
24 | mix_val = self.mix_values.to(device=x1.device).view(*view)
25 | return mix_val * x1 + (1.-mix_val) * x2
26 | 
27 | def sample(self, alpha, batch_size, sample_batch=False):
28 | self.mix_index = torch.randperm(batch_size)
29 | if sample_batch:
30 | values = beta(alpha, alpha, size=batch_size)
31 | self.mix_values = torch.tensor(values, dtype=torch.float)
32 | else:
33 | self.mix_values = torch.tensor([beta(alpha, alpha)],
34 | dtype=torch.float)
35 | 
36 | def mix_target(self, y, n_class):
37 | if not self.training or \
38 | self.mix_values is None or\
39 | self.mix_index is None:
40 | return y
41 | y = onehot(y, n_class).to(dtype=torch.float)
42 | idx = self.mix_index.to(device=y.device)
43 | y_mix = y.index_select(self.batch_dim, idx)
44 | return self.mix(y, y_mix)
45 | 
46 | def forward(self, x):
47 | if not self.training or \
48 | self.mix_values is None or\
49 | self.mix_index is None:
50 | return x
51 | idx = self.mix_index.to(device=x.device)
52 | x_mix = x.index_select(self.batch_dim, idx)
53 | return self.mix(x, x_mix)
54 | 
-------------------------------------------------------------------------------- /utils/quantize.py: --------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | import torch.nn as nn
4 | 
5 | QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])
6 | 
7 | 
8 | def quantize_tensor(x, num_bits=8):
9 | qmin = 0.
10 | qmax = 2.**num_bits - 1.
11 | min_val, max_val = x.min(), x.max()
12 | 
13 | scale = (max_val - min_val) / (qmax - qmin)
14 | 
15 | initial_zero_point = qmin - min_val / scale
16 | 
17 | zero_point = 0
18 | if initial_zero_point < qmin:
19 | zero_point = qmin
20 | elif initial_zero_point > qmax:
21 | zero_point = qmax
22 | else:
23 | zero_point = initial_zero_point
24 | 
25 | zero_point = int(zero_point)
26 | q_x = zero_point + x / scale
27 | q_x.clamp_(qmin, qmax).round_()
28 | q_x = q_x.round().byte()
29 | return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)
30 | 
31 | 
32 | def dequantize_tensor(q_x):
33 | return q_x.scale * (q_x.tensor.float() - q_x.zero_point)
34 | 
35 | 
36 | def quantize_model(model):
37 | qparams = {}
38 | 
39 | for n, p in model.state_dict().items():
40 | qp = quantize_tensor(p)
41 | qparams[n + '.quantization.scale'] = torch.FloatTensor([qp.scale])
42 | qparams[
43 | n + '.quantization.zero_point'] = torch.ByteTensor([qp.zero_point])
44 | p.copy_(qp.tensor)
45 | model.type('torch.ByteTensor')
46 | for n, p in qparams.items():
47 | model.register_buffer(n, p)
48 | model.quantized = True
49 | 
50 | 
51 | def dequantize_model(model):
52 | model.float()
53 | params = model.state_dict()
54 | for n, p in params.items():
55 | if 'quantization' not in n:
56 | qp = QTensor(tensor=p,
57 | scale=params[n + '.quantization.scale'][0],
58 | zero_point=params[n + '.quantization.zero_point'][0])
59 | p.copy_(dequantize_tensor(qp))
60 | model.register_buffer(n + '.quantization.scale', None)
61 | model.register_buffer(n + '.quantization.zero_point', None)
62 | model.quantized = None
63 | 
-------------------------------------------------------------------------------- /models/modules/batch_norm.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import BatchNorm1d as _BatchNorm1d
4 | from torch.nn import BatchNorm2d as _BatchNorm2d
5 | from torch.nn import BatchNorm3d as _BatchNorm3d
6 | 
7 | """
8 | BatchNorm variants that can be disabled by removing all parameters and running stats
9 | """
10 | 
11 | 
12 | def has_running_stats(m):
13 | return getattr(m, 'running_mean', None) is not None\
14 | or getattr(m, 'running_var', None) is not None
15 | 
16 | 
17 | def has_parameters(m):
18 | return getattr(m, 'weight', None) is not None\
19 | or getattr(m, 'bias', None) is not None
20 | 
21 | 
22 | class BatchNorm1d(_BatchNorm1d):
23 |
def forward(self, inputs): 24 | if not (has_parameters(self) or has_running_stats(self)): 25 | return inputs 26 | return super(BatchNorm1d, self).forward(inputs) 27 | 28 | 29 | class BatchNorm2d(_BatchNorm2d): 30 | def forward(self, inputs): 31 | if not (has_parameters(self) or has_running_stats(self)): 32 | return inputs 33 | return super(BatchNorm2d, self).forward(inputs) 34 | 35 | 36 | class BatchNorm3d(_BatchNorm3d): 37 | def forward(self, inputs): 38 | if not (has_parameters(self) or has_running_stats(self)): 39 | return inputs 40 | return super(BatchNorm3d, self).forward(inputs) 41 | 42 | 43 | class MeanBatchNorm2d(nn.BatchNorm2d): 44 | """BatchNorm with mean-only normalization""" 45 | 46 | def __init__(self, num_features, momentum=0.1, bias=True): 47 | nn.Module.__init__(self) 48 | self.register_buffer('running_mean', torch.zeros(num_features)) 49 | self.momentum = momentum 50 | self.num_features = num_features 51 | if bias: 52 | self.bias = nn.Parameter(torch.zeros(num_features)) 53 | else: 54 | self.register_parameter('bias', None) 55 | 56 | def forward(self, x): 57 | if not (has_parameters(self) or has_running_stats(self)): 58 | return x 59 | if self.training: 60 | numel = x.size(0) * x.size(2) * x.size(3) 61 | mean = x.sum((0, 2, 3)) / numel 62 | with torch.no_grad(): 63 | self.running_mean.mul_(self.momentum)\ 64 | .add_(1 - self.momentum, mean) 65 | else: 66 | mean = self.running_mean 67 | if self.bias is not None: 68 | mean = mean - self.bias 69 | return x - mean.view(1, -1, 1, 1) 70 | 71 | def extra_repr(self): 72 | return '{num_features}, momentum={momentum}, bias={has_bias}'.format( 73 | has_bias=self.bias is not None, **self.__dict__) 74 | -------------------------------------------------------------------------------- /models/modules/fixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def _sum_tensor_scalar(tensor, scalar, expand_size): 6 | if scalar is not None: 7 | scalar = scalar.expand(expand_size).contiguous() 8 | else: 9 | return tensor 10 | if tensor is None: 11 | return scalar 12 | return tensor + scalar 13 | 14 | 15 | class ZIConv2d(nn.Conv2d): 16 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 17 | padding=0, dilation=1, groups=1, bias=False, 18 | multiplier=False, pre_bias=True, post_bias=True): 19 | super(ZIConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 20 | padding, dilation, groups, bias) 21 | if pre_bias: 22 | self.pre_bias = nn.Parameter(torch.tensor([0.])) 23 | else: 24 | self.register_parameter('pre_bias', None) 25 | if post_bias: 26 | self.post_bias = nn.Parameter(torch.tensor([0.])) 27 | else: 28 | self.register_parameter('post_bias', None) 29 | if multiplier: 30 | self.multiplier = nn.Parameter(torch.tensor([1.])) 31 | else: 32 | self.register_parameter('multiplier', None) 33 | 34 | def forward(self, x): 35 | if self.pre_bias is not None: 36 | x = x + self.pre_bias 37 | weight = self.weight if self.multiplier is None\ 38 | else self.weight * self.multiplier 39 | bias = _sum_tensor_scalar(self.bias, self.post_bias, self.out_channels) 40 | return nn.functional.conv2d(x, weight, bias, self.stride, 41 | self.padding, self.dilation, self.groups) 42 | 43 | 44 | class ZILinear(nn.Linear): 45 | def __init__(self, in_features, out_features, bias=False, 46 | multiplier=False, pre_bias=True, post_bias=True): 47 | super(ZILinear, self).__init__(in_features, out_features, bias) 48 | if pre_bias: 49 | self.pre_bias = 
nn.Parameter(torch.tensor([0.])) 50 | else: 51 | self.register_parameter('pre_bias', None) 52 | if post_bias: 53 | self.post_bias = nn.Parameter(torch.tensor([0.])) 54 | else: 55 | self.register_parameter('post_bias', None) 56 | if multiplier: 57 | self.multiplier = nn.Parameter(torch.tensor([1.])) 58 | else: 59 | self.register_parameter('multiplier', None) 60 | 61 | def forward(self, x): 62 | if self.pre_bias is not None: 63 | x = x + self.pre_bias 64 | weight = self.weight if self.multiplier is None\ 65 | else self.weight * self.multiplier 66 | bias = _sum_tensor_scalar(self.bias, self.post_bias, self.out_features) 67 | return nn.functional.linear(x, weight, bias) 68 | -------------------------------------------------------------------------------- /utils/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .misc import onehot 6 | 7 | 8 | def _is_long(x): 9 | if hasattr(x, 'data'): 10 | x = x.data 11 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor) 12 | 13 | 14 | def cross_entropy(logits, target, weight=None, ignore_index=-100, reduction='mean', 15 | smooth_eps=None, smooth_dist=None): 16 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567""" 17 | smooth_eps = smooth_eps or 0 18 | 19 | # ordinary log-liklihood - use cross_entropy from nn 20 | if _is_long(target) and smooth_eps == 0: 21 | return F.cross_entropy(logits, target, weight, ignore_index=ignore_index, reduction=reduction) 22 | 23 | masked_indices = None 24 | num_classes = logits.size(-1) 25 | 26 | if _is_long(target) and ignore_index >= 0: 27 | masked_indices = target.eq(ignore_index) 28 | 29 | if smooth_eps > 0 and smooth_dist is not None: 30 | if _is_long(target): 31 | target = onehot(target, num_classes).type_as(logits) 32 | if smooth_dist.dim() < target.dim(): 33 | smooth_dist = smooth_dist.unsqueeze(0) 34 | target.lerp_(smooth_dist, smooth_eps) 35 | 36 | # log-softmax of logits 37 | lsm = F.log_softmax(logits, dim=-1) 38 | 39 | if weight is not None: 40 | lsm = lsm * weight.unsqueeze(0) 41 | 42 | if _is_long(target): 43 | eps = smooth_eps / (num_classes - 1) 44 | nll = -lsm.gather(dim=-1, index=target.unsqueeze(-1)) 45 | loss = (1. 
- smooth_eps - eps) * nll - eps * lsm.sum(-1)
46 | else:
47 | loss = -(target * lsm).sum(-1)
48 | 
49 | if masked_indices is not None:
50 | loss.masked_fill_(masked_indices, 0)
51 | 
52 | if reduction == 'sum':
53 | loss = loss.sum()
54 | elif reduction == 'mean':
55 | if masked_indices is None:
56 | loss = loss.mean()
57 | else:
58 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum())
59 | 
60 | return loss
61 | 
62 | 
63 | class CrossEntropyLoss(nn.CrossEntropyLoss):
64 | """CrossEntropyLoss - with the ability to receive distributions as targets, and optional label smoothing"""
65 | 
66 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None):
67 | super(CrossEntropyLoss, self).__init__(weight=weight,
68 | ignore_index=ignore_index, reduction=reduction)
69 | self.smooth_eps = smooth_eps
70 | self.smooth_dist = smooth_dist
71 | 
72 | def forward(self, input, target, smooth_dist=None):
73 | if smooth_dist is None:
74 | smooth_dist = self.smooth_dist
75 | return cross_entropy(input, target, self.weight, self.ignore_index, self.reduction, self.smooth_eps, smooth_dist)
76 | 
-------------------------------------------------------------------------------- /utils/misc.py: --------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
6 | 
7 | torch_dtypes = {
8 | 'float': torch.float,
9 | 'float32': torch.float32,
10 | 'float64': torch.float64,
11 | 'double': torch.double,
12 | 'float16': torch.float16,
13 | 'half': torch.half,
14 | 'uint8': torch.uint8,
15 | 'int8': torch.int8,
16 | 'int16': torch.int16,
17 | 'short': torch.short,
18 | 'int32': torch.int32,
19 | 'int': torch.int,
20 | 'int64': torch.int64,
21 | 'long': torch.long
22 | }
23 | 
24 | 
25 | def onehot(indexes, N=None, ignore_index=None):
26 | """
27 | Creates a one-hot representation of indexes with N possible entries.
28 | If N is not specified, it will suit the maximum index appearing.
29 | indexes is a long-tensor of indexes
30 | ignore_index will be zero in the one-hot representation
31 | """
32 | if N is None:
33 | N = indexes.max() + 1
34 | sz = list(indexes.size())
35 | output = indexes.new().byte().resize_(*sz, N).zero_()
36 | output.scatter_(-1, indexes.unsqueeze(-1), 1)
37 | if ignore_index is not None and ignore_index >= 0:
38 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
39 | return output
40 | 
41 | 
42 | def set_global_seeds(i):
43 | try:
44 | import torch
45 | except ImportError:
46 | pass
47 | else:
48 | torch.manual_seed(i)
49 | if torch.cuda.is_available():
50 | torch.cuda.manual_seed_all(i)
51 | np.random.seed(i)
52 | random.seed(i)
53 | 
54 | 
55 | class CheckpointModule(nn.Module):
56 | def __init__(self, module, num_segments=1):
57 | super(CheckpointModule, self).__init__()
58 | assert num_segments == 1 or isinstance(module, nn.Sequential)
59 | self.module = module
60 | self.num_segments = num_segments
61 | 
62 | def forward(self, x):
63 | if self.num_segments > 1:
64 | return checkpoint_sequential(self.module, self.num_segments, x)
65 | else:
66 | return checkpoint(self.module, x)
67 | 
68 | 
69 | def normalize_module_name(layer_name):
70 | """Normalize a module's name.
71 | 
72 | PyTorch lets you parallelize the computation of a model by wrapping it with a
73 | DataParallel module. Unfortunately, this changes the fully-qualified name of a module,
74 | even though the actual functionality of the module doesn't change.
75 | Many times, when we search for modules by name, we are indifferent to the DataParallel
76 | module and want to use the same module name whether the module is parallel or not.
77 | We call this module name normalization, and this is implemented here.
78 | """
79 | modules = layer_name.split('.')
80 | try:
81 | idx = modules.index('module')
82 | except ValueError:
83 | return layer_name
84 | del modules[idx]
85 | return '.'.join(modules)
86 | 
-------------------------------------------------------------------------------- /models/modules/fixed_proj.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 | from torch.autograd import Variable
5 | from scipy.linalg import hadamard
6 | 
7 | class HadamardProj(nn.Module):
8 | 
9 | def __init__(self, input_size, output_size, bias=True, fixed_weights=True, fixed_scale=None):
10 | super(HadamardProj, self).__init__()
11 | self.output_size = output_size
12 | self.input_size = input_size
13 | sz = 2 ** int(math.ceil(math.log(max(input_size, output_size), 2)))
14 | mat = torch.from_numpy(hadamard(sz))
15 | if fixed_weights:
16 | self.proj = Variable(mat, requires_grad=False)
17 | else:
18 | self.proj = nn.Parameter(mat)
19 | 
20 | init_scale = 1. / math.sqrt(self.output_size)
21 | 
22 | if fixed_scale is not None:
23 | self.scale = Variable(torch.Tensor(
24 | [fixed_scale]), requires_grad=False)
25 | else:
26 | self.scale = nn.Parameter(torch.Tensor([init_scale]))
27 | 
28 | if bias:
29 | self.bias = nn.Parameter(torch.Tensor(
30 | output_size).uniform_(-init_scale, init_scale))
31 | else:
32 | self.register_parameter('bias', None)
33 | 
34 | self.eps = 1e-8
35 | 
36 | def forward(self, x):
37 | if not isinstance(self.scale, nn.Parameter):
38 | self.scale = self.scale.type_as(x)
39 | x = x / (x.norm(2, -1, keepdim=True) + self.eps)
40 | w = self.proj.type_as(x)
41 | 
42 | out = -self.scale * \
43 | nn.functional.linear(x, w[:self.output_size, :self.input_size])
44 | if self.bias is not None:
45 | out = out + self.bias.view(1, -1)
46 | return out
47 | 
48 | 
49 | class Proj(nn.Module):
50 | 
51 | def __init__(self, input_size, output_size, bias=True, init_scale=10):
52 | super(Proj, self).__init__()
53 | if init_scale is not None:
54 | self.weight = nn.Parameter(torch.Tensor(1).fill_(init_scale))
55 | if bias:
56 | self.bias = nn.Parameter(torch.Tensor(output_size).fill_(0))
57 | self.proj = Variable(torch.Tensor(
58 | output_size, input_size), requires_grad=False)
59 | torch.manual_seed(123)
60 | nn.init.orthogonal_(self.proj)
61 | 
62 | def forward(self, x):
63 | w = self.proj.type_as(x)
64 | x = x / x.norm(2, -1, keepdim=True)
65 | out = nn.functional.linear(x, w)
66 | if hasattr(self, 'weight'):
67 | out = out * self.weight
68 | if hasattr(self, 'bias'):
69 | out = out + self.bias.view(1, -1)
70 | return out
71 | 
72 | class LinearFixed(nn.Linear):
73 | 
74 | def __init__(self, input_size, output_size, bias=True, init_scale=10):
75 | super(LinearFixed, self).__init__(input_size, output_size, bias)
76 | self.scale = nn.Parameter(torch.Tensor(1).fill_(init_scale))
77 | 
78 | def forward(self, x):
79 | w = self.weight / self.weight.norm(2, -1, keepdim=True)
80 | x = x / x.norm(2, -1, keepdim=True)
81 | out = nn.functional.linear(x, w, self.bias)
82 | return out
83 | 
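# Usage sketch (illustrative, not part of the repo): HadamardProj acts as a
# classifier head that projects L2-normalized features through a fixed
# Hadamard matrix, e.g.
#   head = HadamardProj(input_size=2048, output_size=1000)
#   logits = head(torch.randn(8, 2048))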
-------------------------------------------------------------------------------- /utils/meters.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 | 
7 | def __init__(self):
8 | self.reset()
9 | 
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 | 
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 | 
22 | 
23 | class OnlineMeter(object):
24 | """Computes and stores the average and variance/std values of tensor"""
25 | 
26 | def __init__(self):
27 | self.mean = torch.FloatTensor(1).fill_(-1)
28 | self.M2 = torch.FloatTensor(1).zero_()
29 | self.count = 0.
30 | self.needs_init = True
31 | 
32 | def reset(self, x):
33 | self.mean = x.new(x.size()).zero_()
34 | self.M2 = x.new(x.size()).zero_()
35 | self.count = 0.
36 | self.needs_init = False
37 | 
38 | def update(self, x):
39 | self.val = x
40 | if self.needs_init:
41 | self.reset(x)
42 | self.count += 1
43 | delta = x - self.mean
44 | self.mean.add_(delta / self.count)
45 | delta2 = x - self.mean
46 | self.M2.add_(delta * delta2)
47 | 
48 | @property
49 | def var(self):
50 | if self.count < 2:
51 | return self.M2.clone().zero_()
52 | return self.M2 / (self.count - 1)
53 | 
54 | @property
55 | def std(self):
56 | return self.var.sqrt()
57 | 
58 | 
59 | def accuracy(output, target, topk=(1,)):
60 | """Computes the precision@k for the specified values of k"""
61 | maxk = max(topk)
62 | batch_size = target.size(0)
63 | 
64 | _, pred = output.topk(maxk, 1, True, True)
65 | pred = pred.t().type_as(target)
66 | correct = pred.eq(target.view(1, -1).expand_as(pred))
67 | 
68 | res = []
69 | for k in topk:
70 | correct_k = correct[:k].view(-1).float().sum(0)
71 | res.append(correct_k.mul_(100.0 / batch_size))
72 | return res
73 | 
74 | 
75 | class AccuracyMeter(object):
76 | """Computes and stores the average and current topk accuracy"""
77 | 
78 | def __init__(self, topk=(1,)):
79 | self.topk = topk
80 | self.reset()
81 | 
82 | def reset(self):
83 | self._meters = {}
84 | for k in self.topk:
85 | self._meters[k] = AverageMeter()
86 | 
87 | def update(self, output, target):
88 | n = target.nelement()
89 | acc_vals = accuracy(output, target, self.topk)
90 | for i, k in enumerate(self.topk):
91 | self._meters[k].update(acc_vals[i])
92 | 
93 | @property
94 | def val(self):
95 | return {n: meter.val for (n, meter) in self._meters.items()}
96 | 
97 | @property
98 | def avg(self):
99 | return {n: meter.avg for (n, meter) in self._meters.items()}
100 | 
101 | @property
102 | def avg_error(self):
103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()}
104 | 
-------------------------------------------------------------------------------- /utils/regime.py: --------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | from six import string_types
4 | 
5 | 
6 | def eval_func(f, x):
7 | if isinstance(f, string_types):
8 | f = eval(f)
9 | return f(x)
10 | 
11 | 
12 | class Regime(object):
13 | """
14 | Examples for regime:
15 | 
16 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
17 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
18 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4},
19 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5}
20 | ]"
21 | 2)
22 | "[{'step_lambda':
23 | "lambda t: {
24 | 'optimizer': 'Adam',
25 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5),
26 | 'betas': (0.9, 0.98), 'eps':1e-9}
27 | }]"
28 | """
29 | 
30 | def __init__(self, regime, defaults={}):
31 | self.regime = regime
32 | self.current_regime_phase = None
33 | self.setting = defaults
34 | 
35 | def update(self, epoch=None, train_steps=None):
36 | """adjusts according to current epoch or steps and regime.
37 | """
38 | if self.regime is None:
39 | return False
40 | epoch = -1 if epoch is None else epoch
41 | train_steps = -1 if train_steps is None else train_steps
42 | setting = deepcopy(self.setting)
43 | if self.current_regime_phase is None:
44 | # Find the first entry where the epoch is smaller than the current one
45 | for regime_phase, regime_setting in enumerate(self.regime):
46 | start_epoch = regime_setting.get('epoch', 0)
47 | start_step = regime_setting.get('step', 0)
48 | if epoch >= start_epoch or train_steps >= start_step:
49 | self.current_regime_phase = regime_phase
50 | break
51 | # each entry is updated from previous
52 | setting.update(regime_setting)
53 | if len(self.regime) > self.current_regime_phase + 1:
54 | next_phase = self.current_regime_phase + 1
55 | # Any more regime steps?
56 | start_epoch = self.regime[next_phase].get('epoch', float('inf')) 57 | start_step = self.regime[next_phase].get('step', float('inf')) 58 | if epoch >= start_epoch or train_steps >= start_step: 59 | self.current_regime_phase = next_phase 60 | setting.update(self.regime[self.current_regime_phase]) 61 | 62 | if 'lr_decay_rate' in setting and 'lr' in setting: 63 | decay_steps = setting.pop('lr_decay_steps', 100) 64 | if train_steps % decay_steps == 0: 65 | decay_rate = setting.pop('lr_decay_rate') 66 | setting['lr'] *= decay_rate ** (train_steps / decay_steps) 67 | elif 'step_lambda' in setting: 68 | setting.update(eval_func(setting.pop('step_lambda'), train_steps)) 69 | elif 'epoch_lambda' in setting: 70 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch)) 71 | 72 | if 'execute' in setting: 73 | setting.pop('execute')() 74 | 75 | if 'execute_once' in setting: 76 | setting.pop('execute_once')() 77 | # remove from regime, so won't happen again 78 | self.regime[self.current_regime_phase].pop('execute_once', None) 79 | 80 | if setting == self.setting: 81 | return False 82 | else: 83 | self.setting = setting 84 | return True 85 | -------------------------------------------------------------------------------- /utils/model_manipulation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import torch.nn as nn 4 | from models.modules.quantize import QConv2d,QLinear 5 | 6 | def replace_layer(model,layer,exec_str_m,new_layer_str='=QConv2d(',num_bits=8,num_bits_weight=8,forced_bias=False): 7 | new_layer_str='=QConv2d(' if is_Conv(layer) else '=QLinear(' 8 | bias_str = 'True,' if forced_bias else exec_str_m+'.bias is not None,' 9 | #import pdb; pdb.set_trace() 10 | if 'Linear' in new_layer_str: 11 | exec_str = exec_str_m+new_layer_str + exec_str_m+'.in_features,'+exec_str_m+'.out_features,'+bias_str+'num_bits=num_bits,'+'num_bits_weight=num_bits_weight)' 12 | elif 'Conv' in new_layer_str: 13 | print(exec_str_m) 14 | exec_str = exec_str_m+new_layer_str+exec_str_m+'.in_channels,'+exec_str_m+'.out_channels,'+exec_str_m+'.kernel_size,'+exec_str_m+'.stride,'+exec_str_m+'.padding,'+exec_str_m+'.dilation,'+exec_str_m+'.groups,'+bias_str+'num_bits=num_bits,'+'num_bits_weight=num_bits_weight)' 15 | else: 16 | import pdb; pdb.set_trace() 17 | exec(exec_str) 18 | return model 19 | 20 | def is_Linear(m): 21 | return isinstance(m, nn.Linear) or isinstance(m, QLinear) 22 | 23 | def is_Conv(m): 24 | return isinstance(m, nn.Conv2d) or isinstance(m, QConv2d) 25 | 26 | def is_bn(m): 27 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 28 | 29 | def is_absorbing(m): 30 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, QConv2d) or isinstance(m, QLinear) 31 | 32 | def search_replace_conv_linear(model,name_model='',arr=[]): 33 | prev = None 34 | for i,m in enumerate(model.children()): 35 | modules_names=[key for key in model._modules.keys()] 36 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i] 37 | exec_str_m = 'model._modules[\'%s\']'%layer_name 38 | if is_Conv(m) or is_Linear(m): 39 | model=replace_layer(model,m,exec_str_m) 40 | #arr= search_replace_conv_linear(m,layer_name,arr) 41 | prev = m 42 | return model 43 | 44 | def search_delete_bn(model,name_model='',arr=[]): 45 | prev = None; prev_name = None; bn_absorbed_layers=[]; 46 | for i,m in enumerate(model.children()): 47 | modules_names=[key for key in model._modules.keys()] 
48 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i]
49 | exec_str_m = 'model._modules[\'%s\']'%layer_name
50 | if is_bn(m) and is_absorbing(prev):
51 | if is_Conv(prev) or is_Linear(prev):
52 | if prev.bias is None and prev_name is not None:
53 | model = replace_layer(model,prev,prev_name,forced_bias=True)
54 | bn_absorbed_layers.append(layer_name)
55 | prev = m
56 | prev_name=exec_str_m
57 | deps_keys = [key for key in model.deps.keys()]
58 | for layer_name in bn_absorbed_layers:
59 | model._modules.pop(layer_name)
60 | pop_name=deps_keys[int(layer_name)]
61 | new_name = model.deps[deps_keys[int(layer_name)]][1][0]
62 | for key in model.deps:
63 | for id_name,name in enumerate(model.deps[key][1]):
64 | if name==pop_name:
65 | model.deps[key][1][id_name]=new_name
66 | #import pdb; pdb.set_trace()
67 | model.deps.pop(deps_keys[int(layer_name)])
68 | #import pdb; pdb.set_trace()
69 | return model
70 | 
-------------------------------------------------------------------------------- /utils/param_filter.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | def is_not_bias(name):
6 | return not name.endswith('bias')
7 | 
8 | 
9 | def is_bn(module):
10 | return isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d)
11 | 
12 | 
13 | def is_not_bn(module):
14 | return not is_bn(module)
15 | 
16 | 
17 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None):
18 | if memo is None:
19 | memo = set()
20 | 
21 | for module_name, module in model.named_modules():
22 | if module_fn is not None and not module_fn(module):
23 | continue
24 | if module_name_fn is not None and not module_name_fn(module_name):
25 | continue
26 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False):
27 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name):
28 | continue
29 | if param not in memo:
30 | memo.add(param)
31 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)}
32 | 
33 | 
34 | class FilterParameters(object):
35 | def __init__(self, source, module=None, module_name=None, parameter_name=None):
36 | if isinstance(source, FilterParameters):
37 | self._filtered_parameter_info = list(source.filter(
38 | module=module,
39 | module_name=module_name,
40 | parameter_name=parameter_name))
41 | elif isinstance(source, torch.nn.Module): # source is a model
42 | self._filtered_parameter_info = list(filtered_parameter_info(source,
43 | module_fn=module,
44 | module_name_fn=module_name,
45 | parameter_name_fn=parameter_name))
46 | 
47 | def named_parameters(self):
48 | for p in self._filtered_parameter_info:
49 | yield p['named_parameter']
50 | 
51 | def parameters(self):
52 | for _, p in self.named_parameters():
53 | yield p
54 | 
55 | def filter(self, module=None, module_name=None, parameter_name=None):
56 | for p_info in self._filtered_parameter_info:
57 | if ((module is None or module(p_info['named_module'][1]))
58 | and (module_name is None or module_name(p_info['named_module'][0]))
59 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))):
60 | yield p_info
61 | 
62 | def named_modules(self):
63 | for m in self._filtered_parameter_info:
64 | yield m['named_module']
65 | 
66 | def modules(self):
67 | for _, m in self.named_modules():
68 | yield m
69 | 
70 | def to(self, *kargs, **kwargs):
71 | for m in self.modules(): 72 | m.to(*kargs, **kwargs) 73 | 74 | 75 | class FilterModules(FilterParameters): 76 | pass 77 | 78 | if __name__ == '__main__': 79 | from torchvision.models import resnet50 80 | model = resnet50() 81 | filterd_params = FilterParameters(model, 82 | module=lambda m: isinstance( 83 | m, torch.nn.Linear), 84 | parameter_name=lambda n: 'bias' in n) 85 | -------------------------------------------------------------------------------- /BERT-base/src/layer_sensativity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import torch.nn as nn 4 | from transformers.modeling_quantize import calculate_qparams, quantize, QConv2d, QLinear, QMatmul, QEmbedding 5 | 6 | def search_replace_layer(model,all_names,num_bits_activation,num_bits_weight,name_model=''): 7 | for i,m in enumerate(model.children()): 8 | modules_names=[key for key in model._modules.keys()] 9 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i] 10 | m.name=layer_name 11 | if layer_name in all_names: 12 | if isinstance(all_names,dict): 13 | num_bits_activation,num_bits_weight = all_names[layer_name] 14 | print("Layer {}, precision switch from {}-bit to {}-bit weight, {}-bit activation.".format( 15 | layer_name, m.num_bits, num_bits_weight, num_bits_activation)) 16 | m.num_bits=num_bits_activation 17 | m.num_bits_weight = num_bits_weight 18 | if isinstance(m, QLinear): 19 | m.quantize_input.num_bits=num_bits_activation 20 | m.quantize_weight.num_bits=num_bits_weight 21 | if isinstance(m, QMatmul): 22 | m.quantize_input1.num_bits=num_bits_activation 23 | m.quantize_input2.num_bits=num_bits_activation 24 | if isinstance(m, QEmbedding): 25 | m.quantize_weight.num_bits=num_bits_weight 26 | search_replace_layer(m,all_names,num_bits_activation,num_bits_weight,layer_name) 27 | return model 28 | 29 | def is_q_module(m): 30 | return isinstance(m, QConv2d) or isinstance(m, QLinear) or isinstance(m, QMatmul) or isinstance(m, QEmbedding) 31 | 32 | def extract_all_quant_layers_names(model,q_names=[],name_model=''): 33 | for i,m in enumerate(model.children()): 34 | modules_names=[key for key in model._modules.keys()] 35 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i] 36 | m.name=layer_name 37 | if is_q_module(m): 38 | q_names.append(m.name) 39 | q_names = extract_all_quant_layers_names(m,q_names,layer_name) 40 | return q_names 41 | 42 | def check_quantized_model(model,fp_names=[],name_model=''): 43 | for i,m in enumerate(model.children()): 44 | modules_names=[key for key in model._modules.keys()] 45 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i] 46 | m.name=layer_name 47 | if (is_q_module(m) and m.measure) or not is_q_module: 48 | fp_names.append(m.name) 49 | print("Layer {}, if in FP32.".format(layer_name)) 50 | fp_names = check_quantized_model(m,fp_names,layer_name) 51 | return fp_names 52 | 53 | def extract_save_quant_state_dict(model,all_names,filename='int_state_dict.pth.tar'): 54 | 55 | state_dict=model.state_dict() 56 | 57 | for key in state_dict.keys(): 58 | #import pdb; pdb.set_trace() 59 | val=state_dict[key] 60 | if 'weight' in key: 61 | num_bits = 4 if key[:-7] in all_names else 8 62 | if num_bits==4: 63 | import pdb; pdb.set_trace() 64 | weight_qparams = calculate_qparams(val, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None) 65 | val_q=quantize(val, 
66 | zero_point=(-weight_qparams[1]/weight_qparams[0]*(2**weight_qparams[2]-1)).round()
67 | val_q=val_q-zero_point
68 | print(val_q.eq(0).sum().float().div(val_q.numel()))
69 | if 'bias' in key:
70 | val_q = quantize(val, num_bits=num_bits*2,flatten_dims=(0, -1))
71 |
72 | state_dict[key] = val_q
73 | torch.save(state_dict,filename)
74 | return state_dict
75 |
-------------------------------------------------------------------------------- /models/modules/bwn.py: --------------------------------------------------------------------------------
1 | """
2 | Weight Normalization from https://arxiv.org/abs/1602.07868
3 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
4 | """
5 | import torch
6 | from torch.nn.parameter import Parameter
7 | from torch.autograd import Function
8 | import torch.nn as nn
9 |
10 |
11 | def _norm(x, dim, p=2):
12 | """Computes the norm over all dimensions except dim"""
13 | if p == -1:
14 | def func(x, dim): return x.max(dim=dim)[0] - x.min(dim=dim)[0]
15 | elif p == float('inf'):
16 | def func(x, dim): return x.max(dim=dim)[0]
17 | else:
18 | def func(x, dim): return torch.norm(x, dim=dim, p=p)
19 | if dim is None:
20 | return x.norm(p=p)
21 | elif dim == 0:
22 | output_size = (x.size(0),) + (1,) * (x.dim() - 1)
23 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size)
24 | elif dim == x.dim() - 1:
25 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),)
26 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size)
27 | else:
28 | return _norm(x.transpose(0, dim), 0, p=p).transpose(0, dim)
29 |
30 |
31 | def _mean(p, dim):
32 | """Computes the mean over all dimensions except dim"""
33 | if dim is None:
34 | return p.mean()
35 | elif dim == 0:
36 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
37 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
38 | elif dim == p.dim() - 1:
39 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
40 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
41 | else:
42 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
43 |
44 |
45 | class BoundedWeightNorm(object):
46 |
47 | def __init__(self, name, dim, p):
48 | self.name = name
49 | self.dim = dim
    self.p = p  # keep the norm order so compute_weight uses the same p as the stored initial norm
50 |
51 | def compute_weight(self, module):
52 |
53 | v = getattr(module, self.name + '_v')
54 | v.data.div_(_norm(v, self.dim, p=self.p))  # renormalize v in place, then rescale to the stored initial norm
55 | init_norm = getattr(module, self.name + '_init_norm')
56 | return v * (init_norm / _norm(v, self.dim, p=self.p))
57 |
58 | @staticmethod
59 | def apply(module, name, dim, p):
60 | fn = BoundedWeightNorm(name, dim, p)
61 |
62 | weight = getattr(module, name)
63 |
64 | # remove w from parameter list
65 | del module._parameters[name]
66 | module.register_buffer(
67 | name + '_init_norm', torch.Tensor([_norm(weight, dim, p=p).data.mean()]))
68 | module.register_parameter(name + '_v', Parameter(weight.data))
69 | setattr(module, name, fn.compute_weight(module))
70 |
71 | # recompute weight before every forward()
72 | module.register_forward_pre_hook(fn)
73 | return fn
74 |
75 | def remove(self, module):
76 | weight = self.compute_weight(module)
77 | delattr(module, self.name)
78 | del module._parameters[self.name + '_v']
79 | module.register_parameter(self.name, Parameter(weight.data))
80 |
81 | def __call__(self, module, inputs):
82 | setattr(module, self.name, self.compute_weight(module))
83 |
84 |
85 | def weight_norm(module, name='weight', dim=0, p=2):
86 | r"""Applies weight normalization to a
parameter in the given module. 87 | 88 | .. math:: 89 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} 90 | 91 | Weight normalization is a reparameterization that decouples the magnitude 92 | of a weight tensor from its direction. This replaces the parameter specified 93 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude 94 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). 95 | Weight normalization is implemented via a hook that recomputes the weight 96 | tensor from the magnitude and direction before every :meth:`~Module.forward` 97 | call. 98 | 99 | By default, with `dim=0`, the norm is computed independently per output 100 | channel/plane. To compute a norm over the entire weight tensor, use 101 | `dim=None`. 102 | 103 | See https://arxiv.org/abs/1602.07868 104 | 105 | Args: 106 | module (nn.Module): containing module 107 | name (str, optional): name of weight parameter 108 | dim (int, optional): dimension over which to compute the norm 109 | 110 | Returns: 111 | The original module with the weight norm hook 112 | 113 | Example:: 114 | 115 | >>> m = weight_norm(nn.Linear(20, 40), name='weight') 116 | Linear (20 -> 40) 117 | >>> m.weight_g.size() 118 | torch.Size([40, 1]) 119 | >>> m.weight_v.size() 120 | torch.Size([40, 20]) 121 | 122 | """ 123 | BoundedWeightNorm.apply(module, name, dim, p) 124 | return module 125 | 126 | 127 | def remove_weight_norm(module, name='weight'): 128 | r"""Removes the weight normalization reparameterization from a module. 129 | 130 | Args: 131 | module (nn.Module): containing module 132 | name (str, optional): name of weight parameter 133 | 134 | Example: 135 | >>> m = weight_norm(nn.Linear(20, 40)) 136 | >>> remove_weight_norm(m) 137 | """ 138 | for k, hook in module._forward_pre_hooks.items(): 139 | if isinstance(hook, BoundedWeightNorm) and hook.name == name: 140 | hook.remove(module) 141 | del module._forward_pre_hooks[k] 142 | return module 143 | 144 | raise ValueError("weight_norm of '{}' not found in {}" 145 | .format(name, module)) 146 | -------------------------------------------------------------------------------- /utils/layer_sensativity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import torch.nn as nn 4 | from models.modules.quantize import calculate_qparams, quantize, QConv2d,QLinear 5 | 6 | 7 | def search_replace_layer(model,all_names,num_bits_activation,num_bits_weight,name_model=''): 8 | for i,m in enumerate(model.children()): 9 | modules_names=[key for key in model._modules.keys()] 10 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i] 11 | m.name=layer_name 12 | if layer_name in all_names: 13 | print("Layer {}, precision switch from {}-bit to {}-bit weight, {}-bit activation.".format( 14 | layer_name, m.num_bits, num_bits_weight, num_bits_activation)) 15 | m.num_bits=num_bits_activation 16 | m.num_bits_weight = num_bits_weight 17 | m.quantize_input.num_bits=num_bits_activation 18 | m.quantize_weight.num_bits=num_bits_weight 19 | search_replace_layer(m,all_names,num_bits_activation,num_bits_weight,layer_name) 20 | return model 21 | 22 | 23 | # "{'conv1': [8, 8], 'layer1.0.conv1': [8, 8], 'layer1.0.conv2': [4, 4], 'layer1.0.conv3': [4, 4], 'layer1.0.downsample.0': [8, 8], 'layer1.1.conv1': [4, 4], 'layer1.1.conv2': [4, 4], 'layer1.1.conv3': [4, 4], 'layer1.2.conv1': [4, 4], 'layer1.2.conv2': [4, 4], 'layer1.2.conv3': [4, 4], 'layer2.0.conv1': [4, 4], 
'layer2.0.conv2': [2, 2], 'layer2.0.conv3': [4, 4], 'layer2.0.downsample.0': [4, 4], 'layer2.1.conv1': [2, 2], 'layer2.1.conv2': [4, 4], 'layer2.1.conv3': [4, 4], 'layer2.2.conv1': [2, 2], 'layer2.2.conv2': [2, 2], 'layer2.2.conv3': [4, 4], 'layer2.3.conv1': [2, 2], 'layer2.3.conv2': [2, 2], 'layer2.3.conv3': [4, 4], 'layer3.0.conv1': [4, 4], 'layer3.0.conv2': [2, 2], 'layer3.0.conv3': [2, 2], 'layer3.0.downsample.0': [2, 2], 'layer3.1.conv1': [2, 2], 'layer3.1.conv2': [2, 2], 'layer3.1.conv3': [2, 2], 'layer3.2.conv1': [2, 2], 'layer3.2.conv2': [2, 2], 'layer3.2.conv3': [2, 2], 'layer3.3.conv1': [2, 2], 'layer3.3.conv2': [2, 2], 'layer3.3.conv3': [2, 2], 'layer3.4.conv1': [2, 2], 'layer3.4.conv2': [2, 2], 'layer3.4.conv3': [2, 2], 'layer3.5.conv1': [2, 2], 'layer3.5.conv2': [2, 2], 'layer3.5.conv3': [2, 2], 'layer4.0.conv1': [2, 2], 'layer4.0.conv2': [2, 2], 'layer4.0.conv3': [2, 2], 'layer4.0.downsample.0': [2, 2], 'layer4.1.conv1': [2, 2], 'layer4.1.conv2': [2, 2], 'layer4.1.conv3': [2, 2], 'layer4.2.conv1': [2, 2], 'layer4.2.conv2': [2, 2], 'layer4.2.conv3': [2, 2], 'fc': [4, 4]}"
24 | def search_replace_layer_from_dict(model, layers_precision_dict, name_model=''):
25 | for i,m in enumerate(model.children()):
26 | modules_names=[key for key in model._modules.keys()]
27 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i]
28 | m.name=layer_name
29 | if layer_name in layers_precision_dict:
30 | new_prec = layers_precision_dict[layer_name]
31 | print("Layer {}, precision switch from {}-bit to {}-bit activation, {}-bit weight.".format(
32 | layer_name, m.num_bits, new_prec[0], new_prec[1]))
    # assumption: dict values are [activation_bits, weight_bits], the same convention as search_replace_layer above
33 | m.num_bits=new_prec[0]
34 | m.num_bits_weight = new_prec[1]
35 | m.quantize_input.num_bits=new_prec[0]
36 | m.quantize_weight.num_bits=new_prec[1]
37 | search_replace_layer_from_dict(m,layers_precision_dict,layer_name)
38 | return model
39 |
40 |
41 | def is_q_module(m):
42 | return isinstance(m, QConv2d) or isinstance(m, QLinear)
43 |
44 | def extract_all_quant_layers_names(model,q_names=None,name_model=''):
    if q_names is None: q_names = []  # avoid sharing one mutable default list across calls
45 | for i,m in enumerate(model.children()):
46 | modules_names=[key for key in model._modules.keys()]
47 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i]
48 | m.name=layer_name
49 | if is_q_module(m):
50 | q_names.append(m.name)
51 | print("Found quantized layer {}.".format(layer_name))
52 | q_names = extract_all_quant_layers_names(m,q_names,layer_name)
53 | return q_names
54 |
55 | def check_quantized_model(model,fp_names=None,name_model=''):
    if fp_names is None: fp_names = []  # avoid sharing one mutable default list across calls
56 | for i,m in enumerate(model.children()):
57 | modules_names=[key for key in model._modules.keys()]
58 | layer_name=name_model+'.'+modules_names[i] if name_model !='' else name_model+modules_names[i]
59 | m.name=layer_name
60 | if (is_q_module(m) and m.measure) or not is_q_module(m):
61 | fp_names.append(m.name)
62 | print("Layer {} is in FP32.".format(layer_name))
63 | fp_names = check_quantized_model(m,fp_names,layer_name)
64 | return fp_names
65 |
66 | def extract_save_quant_state_dict(model,all_names,filename='int_state_dict.pth.tar'):
67 |
68 | state_dict=model.state_dict()
69 |
    num_bits = 8  # fallback; a bias entry normally reuses the bit-width of the weight handled just before it
70 | for key in state_dict.keys():
71 | #import pdb; pdb.set_trace()
72 | val=state_dict[key]
    val_q = val  # default: entries that are neither weights nor biases are kept unchanged
73 | if 'weight' in key:
74 | num_bits = 4 if key[:-7] in all_names else 8
75 | #if num_bits==4:
76 | #    import pdb; pdb.set_trace()
77 | weight_qparams = calculate_qparams(val, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None)
78 | val_q=quantize(val, qparams=weight_qparams,dequantize=False)
79 |
zero_point=(-weight_qparams[1]/weight_qparams[0]*(2**weight_qparams[2]-1)).round() 80 | val_q=val_q-zero_point 81 | print(val_q.eq(0).sum().float().div(val_q.numel())) 82 | if 'bias' in key: 83 | val_q = quantize(val, num_bits=num_bits*2,flatten_dims=(0, -1)) 84 | 85 | state_dict[key] = val_q 86 | torch.save(state_dict,filename) 87 | return state_dict 88 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import pickle 3 | import PIL 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.sampler import Sampler, RandomSampler, BatchSampler, _int_classes 7 | from numpy.random import choice 8 | 9 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler): 10 | """Samples elements randomly, with replacement. 11 | Arguments: 12 | data_source (Dataset): dataset to sample from 13 | """ 14 | 15 | def __init__(self, data_source): 16 | self.num_samples = len(data_source) 17 | 18 | def __iter__(self): 19 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True))) 20 | 21 | def __len__(self): 22 | return self.num_samples 23 | 24 | 25 | class LimitDataset(Dataset): 26 | 27 | def __init__(self, dset, max_len): 28 | self.dset = dset 29 | self.max_len = max_len 30 | 31 | def __len__(self): 32 | return min(len(self.dset), self.max_len) 33 | 34 | def __getitem__(self, index): 35 | return self.dset[index] 36 | 37 | class ByClassDataset(Dataset): 38 | 39 | def __init__(self, ds): 40 | self.dataset = ds 41 | self.idx_by_class = {} 42 | for idx, (_, c) in enumerate(ds): 43 | self.idx_by_class.setdefault(c, []) 44 | self.idx_by_class[c].append(idx) 45 | 46 | def __len__(self): 47 | return min([len(d) for d in self.idx_by_class.values()]) 48 | 49 | def __getitem__(self, idx): 50 | idx_per_class = [self.idx_by_class[c][idx] 51 | for c in range(len(self.idx_by_class))] 52 | labels = torch.LongTensor([self.dataset[i][1] 53 | for i in idx_per_class]) 54 | items = [self.dataset[i][0] for i in idx_per_class] 55 | if torch.is_tensor(items[0]): 56 | items = torch.stack(items) 57 | 58 | return (items, labels) 59 | 60 | 61 | class IdxDataset(Dataset): 62 | """docstring for IdxDataset.""" 63 | 64 | def __init__(self, dset): 65 | super(IdxDataset, self).__init__() 66 | self.dset = dset 67 | self.idxs = range(len(self.dset)) 68 | 69 | def __getitem__(self, idx): 70 | data, labels = self.dset[self.idxs[idx]] 71 | return (idx, data, labels) 72 | 73 | def __len__(self): 74 | return len(self.idxs) 75 | 76 | 77 | def image_loader(imagebytes): 78 | img = PIL.Image.open(BytesIO(imagebytes)) 79 | return img.convert('RGB') 80 | 81 | 82 | class IndexedFileDataset(Dataset): 83 | """ A dataset that consists of an indexed file (with sample offsets in 84 | another file). For example, a .tar that contains image files. 85 | The dataset does not extract the samples, but works with the indexed 86 | file directly. 87 | NOTE: The index file is assumed to be a pickled list of 3-tuples: 88 | (name, offset, size). 
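An index entry might look like ('n01440764/img_001.JPEG', 512, 10973) (illustrative values).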
89 | """ 90 | def __init__(self, filename, index_filename=None, extract_target_fn=None, 91 | transform=None, target_transform=None, loader=image_loader): 92 | super(IndexedFileDataset, self).__init__() 93 | 94 | # Defaults 95 | if index_filename is None: 96 | index_filename = filename + '.index' 97 | if extract_target_fn is None: 98 | extract_target_fn = lambda *args: args 99 | 100 | # Read index 101 | with open(index_filename, 'rb') as index_fp: 102 | sample_list = pickle.load(index_fp) 103 | 104 | # Collect unique targets (sorted by name) 105 | targetset = set(extract_target_fn(target) for target, _, _ in sample_list) 106 | targetmap = {target: i for i, target in enumerate(sorted(targetset))} 107 | 108 | self.samples = [(targetmap[extract_target_fn(target)], offset, size) 109 | for target, offset, size in sample_list] 110 | self.filename = filename 111 | 112 | self.loader = loader 113 | self.transform = transform 114 | self.target_transform = target_transform 115 | 116 | def _get_sample(self, fp, idx): 117 | target, offset, size = self.samples[idx] 118 | fp.seek(offset) 119 | sample = self.loader(fp.read(size)) 120 | 121 | if self.transform is not None: 122 | sample = self.transform(sample) 123 | if self.target_transform is not None: 124 | target = self.target_transform(target) 125 | 126 | return sample, target 127 | 128 | def __getitem__(self, index): 129 | with open(self.filename, 'rb') as fp: 130 | # Handle slices 131 | if isinstance(index, slice): 132 | return [self._get_sample(fp, subidx) for subidx in 133 | range(index.start or 0, index.stop or len(self), 134 | index.step or 1)] 135 | 136 | return self._get_sample(fp, index) 137 | 138 | def __len__(self): 139 | return len(self.samples) 140 | 141 | 142 | class DuplicateBatchSampler(Sampler): 143 | def __init__(self, sampler, batch_size, duplicates, drop_last): 144 | if not isinstance(sampler, Sampler): 145 | raise ValueError("sampler should be an instance of " 146 | "torch.utils.data.Sampler, but got sampler={}" 147 | .format(sampler)) 148 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 149 | batch_size <= 0: 150 | raise ValueError("batch_size should be a positive integeral value, " 151 | "but got batch_size={}".format(batch_size)) 152 | if not isinstance(drop_last, bool): 153 | raise ValueError("drop_last should be a boolean value, but got " 154 | "drop_last={}".format(drop_last)) 155 | self.sampler = sampler 156 | self.batch_size = batch_size 157 | self.drop_last = drop_last 158 | self.duplicates = duplicates 159 | 160 | def __iter__(self): 161 | batch = [] 162 | for idx in self.sampler: 163 | batch.append(idx) 164 | if len(batch) == self.batch_size: 165 | yield batch * self.duplicates 166 | batch = [] 167 | if len(batch) > 0 and not self.drop_last: 168 | yield batch * self.duplicates 169 | 170 | def __len__(self): 171 | if self.drop_last: 172 | return len(self.sampler) // self.batch_size 173 | else: 174 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 175 | -------------------------------------------------------------------------------- /utils/absorb_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | # from efficientnet_pytorch.utils import Conv2dSamePadding 5 | 6 | def remove_bn_params(bn_module): 7 | bn_module.register_buffer('running_mean', None) 8 | bn_module.register_buffer('running_var', None) 9 | bn_module.register_parameter('weight', None) 10 | 
bn_module.register_parameter('bias', None)
11 |
12 |
13 | def init_bn_params(bn_module):
14 | bn_module.running_mean.fill_(0)
15 | bn_module.running_var.fill_(1)
16 | if bn_module.affine:
17 | bn_module.weight.fill_(1)
18 | bn_module.bias.fill_(0)
19 |
20 |
21 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False):
22 | with torch.no_grad():
23 | w = module.weight
24 | if module.bias is None:
25 | zeros = torch.zeros(module.out_channels,
26 | dtype=w.dtype, device=w.device)
27 | bias = nn.Parameter(zeros)
28 | module.register_parameter('bias', bias)
29 | b = module.bias
30 |
31 | if hasattr(bn_module, 'running_mean'):
32 | b.add_(-bn_module.running_mean)
33 | if hasattr(bn_module, 'running_var'):
34 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
35 | w.mul_(invstd.view(w.size(0), 1, 1, 1))
36 | b.mul_(invstd)
37 | if hasattr(module, 'quantize_weight'):
38 | module.quantize_weight.running_range.mul_(invstd.view(w.size(0), 1, 1, 1))
39 | module.quantize_weight.running_zero_point.mul_(invstd.view(w.size(0), 1, 1, 1))
40 |
41 | if hasattr(bn_module, 'weight'):
42 | w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
43 | b.mul_(bn_module.weight)
44 | module.register_parameter('gamma', nn.Parameter(bn_module.weight.data.clone()))
45 | if hasattr(module, 'quantize_weight'):
46 | module.quantize_weight.running_range.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
47 | module.quantize_weight.running_zero_point.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
48 | if hasattr(bn_module, 'bias'):
49 | b.add_(bn_module.bias)
50 | module.register_parameter('beta', nn.Parameter(bn_module.bias.data.clone()))
51 |
52 | if remove_bn:
53 | remove_bn_params(bn_module)
54 | else:
55 | init_bn_params(bn_module)
56 |
57 | if verbose:
58 | logging.info('BN module %s was absorbed into layer %s' %
59 | (bn_module, module))
60 |
61 |
62 | def is_bn(m):
63 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
64 |
65 |
66 | def is_absorbing(m):
    # Conv2dSamePadding (efficientnet_pytorch) is not checked here because its import is commented out above
67 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear)
68 |
69 |
70 | def search_absorbe_bn(model, prev=None, remove_bn=True, verbose=False):
71 | with torch.no_grad():
72 | for m in model.children():
73 | if is_bn(m) and is_absorbing(prev):
74 | # print(prev,m)
75 | absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose)
76 | search_absorbe_bn(m, remove_bn=remove_bn, verbose=verbose)
77 | prev = m
78 |
79 |
80 | def absorb_fake_bn(module, bn_module, verbose=False):
81 | with torch.no_grad():
82 | w = module.weight
83 | if module.bias is None:
84 | zeros = torch.zeros(module.out_channels,
85 | dtype=w.dtype, device=w.device)
86 | bias = nn.Parameter(zeros)
87 | module.register_parameter('bias', bias)
88 |
89 | if verbose:
90 | logging.info('BN module %s was absorbed into layer %s' %
91 | (bn_module, module))
92 |
93 |
94 | def is_fake_bn(m):
95 | from models.resnet import Lambda
96 | return isinstance(m, Lambda)
97 |
98 |
99 | def search_absorbe_fake_bn(model, prev=None, remove_bn=True, verbose=False):
100 | with torch.no_grad():
101 | for m in model.children():
102 | if is_fake_bn(m) and is_absorbing(prev):
103 | # print(prev,m)
104 | absorb_fake_bn(prev, m, verbose=verbose)
105 | search_absorbe_fake_bn(m, remove_bn=remove_bn, verbose=verbose)
106 | prev = m
107 |
108 |
109 | def add_bn(module, bn_module, verbose=False):
110 | bn = nn.BatchNorm2d(module.out_channels)
111 |
112 | def bn_forward(bn, x):
113 | res = bn(x)
114 | return res
115 |
116 | bn_module.forward_orig = bn_module.forward
117 | bn_module.forward = lambda x: bn_forward(bn, x)
118 | bn.to(module.weight.device)
119 |
120 | bn.register_buffer('running_var', module.gamma**2)
121 | bn.register_buffer('running_mean', module.beta.clone())
122 | bn.register_parameter('weight', nn.Parameter(torch.sqrt(bn.running_var + bn.eps)))
123 | bn.register_parameter('bias', nn.Parameter(bn.running_mean.clone()))
124 |
125 | bn_module.bn = bn
126 |
127 |
128 | def need_tuning(module):
129 | return hasattr(module, 'num_bits') #and module.groups == 1
130 |
131 |
132 | def search_add_bn(model, prev=None, remove_bn=True, verbose=False):
133 | with torch.no_grad():
134 | for m in model.children():
135 | if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev):
136 | # print(prev,m)
137 | add_bn(prev, m, verbose=verbose)
138 | search_add_bn(m, remove_bn=remove_bn, verbose=verbose)
139 | prev = m
140 |
141 |
142 | def search_absorbe_tuning_bn(model, prev=None, remove_bn=True, verbose=False):
143 | with torch.no_grad():
144 | for m in model.children():
145 | if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev):
146 | # print(prev,m)
147 | absorb_bn(prev, m.bn, remove_bn=remove_bn, verbose=verbose)
148 | m.forward = m.forward_orig
149 | m.bn = None
150 | search_absorbe_tuning_bn(m, remove_bn=remove_bn, verbose=verbose)
151 | prev = m
152 |
153 |
154 | def copy_bn_params(module, bn_module, remove_bn=True, verbose=False):
155 | with torch.no_grad():
156 | if hasattr(bn_module, 'weight'):
157 | module.register_parameter('gamma', nn.Parameter(bn_module.weight.data.clone()))
158 |
159 | if hasattr(bn_module, 'bias'):
160 | module.register_parameter('beta', nn.Parameter(bn_module.bias.data.clone()))
161 |
162 |
163 | def search_copy_bn_params(model, prev=None, remove_bn=True, verbose=False):
164 | with torch.no_grad():
165 | for m in model.children():
166 | if is_bn(m) and is_absorbing(prev):
167 | # print(prev,m)
168 | copy_bn_params(prev, m, remove_bn=remove_bn, verbose=verbose)
169 | search_copy_bn_params(m, remove_bn=remove_bn, verbose=verbose)
170 | prev = m
171 |
172 |
173 | # def recalibrate_bn(module, bn_module, verbose=False):
174 | #     bn = bn_module.bn
175 | #     bn.register_parameter('weight', nn.Parameter(torch.sqrt(bn.running_var + bn.eps)))
176 | #     bn.register_parameter('bias', nn.Parameter(bn.running_mean.clone()))
177 | #
178 | #
179 | # def search_bn_recalibrate(model, prev=None, remove_bn=True, verbose=False):
180 | #     with torch.no_grad():
181 | #         for m in model.children():
182 | #             if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev):
183 | #                 recalibrate_bn(prev, m, verbose=verbose)
184 | #             search_bn_recalibrate(m, remove_bn=remove_bn, verbose=verbose)
185 | #             prev = m
186 |
-------------------------------------------------------------------------------- /utils/adaquant.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | import scipy.optimize as opt
7 | import math
import os
import shutil  # os and shutil are needed by dump() below
8 |
9 |
10 | def optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out, batch_size=100):
11 | print("\nOptimize quantization params")
12 | w_range_orig = layer.quantize_weight.running_range.data.clone()
13 | w_zp_orig = layer.quantize_weight.running_zero_point.data.clone()
14 | inp_range_orig = layer.quantize_input.running_range.data.clone()
15 | inp_zp_orig = layer.quantize_input.running_zero_point.data.clone()
16 |
17 | def layer_err(p, inp, out):
18 | layer.quantize_weight.running_range.data = w_range_orig * p[0]
19 | layer.quantize_weight.running_zero_point.data = w_zp_orig + p[1]
20 | layer.quantize_input.running_range.data = inp_range_orig * p[2]
21 | layer.quantize_input.running_zero_point.data = inp_zp_orig + p[3]
22 | yq = layer(inp)
23 | return F.mse_loss(yq, out).item()
24 |
25 | init = np.array([1, 0, 1, 0])
26 | results = []
27 | for i in tqdm(range(int(cached_inps.size(0) / batch_size))):
28 | cur_inp = cached_inps[i * batch_size:(i + 1) * batch_size]
29 | cur_out = cached_outs[i * batch_size:(i + 1) * batch_size]
30 |
31 | # print("init:")
32 | # print(init)
33 | res = opt.minimize(lambda p: layer_err(p, cur_inp, cur_out), init, method='Nelder-Mead')  # assumption: a derivative-free scipy method; any gradient-free option fits here
34 | results.append(res.x)
35 |
36 | mean_res = np.array(results).mean(axis=0)
37 | print(mean_res)
38 | mse_before = layer_err(init, test_inp, test_out)
39 | mse_after = layer_err(mean_res, test_inp, test_out)
40 | return mse_before, mse_after
41 |
42 |
43 | def adaquant(layer, cached_inps, cached_outs, test_inp, test_out, lr1=1e-4, lr2=1e-2, iters=100, progress=True, batch_size=50):
    # NOTE: lr1 and lr2 are accepted for compatibility but currently unused; the learning rates below are hardcoded
44 | print("\nRun adaquant")
45 | mse_before = F.mse_loss(layer(test_inp), test_out)
46 |
47 | # lr_factor = 1e-2
48 | # These hyperparameters were tuned for 8-bit and checked on mobilenet_v2 and resnet50;
49 | # they still have to be verified for other bit-widths and other models
50 | lr_qpin = 1e-1#lr_factor * (test_inp.max() - test_inp.min()).item() # 1e-1
51 | lr_qpw = 1e-3#lr_factor * (layer.weight.max() - layer.weight.min()).item() # 1e-3
52 | lr_w = 1e-5#lr_factor * layer.weight.std().item() # 1e-5
53 | lr_b = 1e-3#lr_factor * layer.bias.std().item() # 1e-3
54 |
55 | opt_w = torch.optim.Adam([layer.weight], lr=lr_w)
56 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias = torch.optim.Adam([layer.bias], lr=lr_b)
57 | opt_qparams_in = torch.optim.Adam([layer.quantize_input.running_range,
58 | layer.quantize_input.running_zero_point], lr=lr_qpin)
59 | opt_qparams_w = torch.optim.Adam([layer.quantize_weight.running_range,
60 | layer.quantize_weight.running_zero_point], lr=lr_qpw)
61 |
62 | losses = []
63 | for j in (tqdm(range(iters)) if progress else range(iters)):
64 | idx = torch.randperm(cached_inps.size(0))[:batch_size]
65 |
66 | train_inp = cached_inps[idx]#.cuda()
67 | train_out = cached_outs[idx]#.cuda()
68 |
69 | qout = layer(train_inp)
70 | loss = F.mse_loss(qout, train_out)
71 |
72 | losses.append(loss.item())
73 | opt_w.zero_grad()
74 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias.zero_grad()
75 | opt_qparams_in.zero_grad()
76 | opt_qparams_w.zero_grad()
77 | loss.backward()
78 | opt_w.step()
79 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias.step()
80 | opt_qparams_in.step()
81 | opt_qparams_w.step()
82 |
83 | # if len(losses) < 10:
84 | #     total_loss = loss.item()
85 | # else:
86 | #     total_loss = np.mean(losses[-10:])
87 | # print("mse out: {}, pc mean loss: {}, total: {}".format(mse_out.item(), mean_loss.item(), total_loss))
88 |
89 | mse_after = F.mse_loss(layer(test_inp), test_out)
90 | return mse_before.item(), mse_after.item()
91 |
92 |
93 | def optimize_layer(layer, in_out, optimize_weights=False):
94 | batch_size = 100
95 |
96 | # if layer.name == 'features.17.conv.0.0' or layer.name == 'features.17.conv.1.0':
97 | #     dump("mobilenet_v2", layer, in_out)
98 | #     return 0, 0, 0, 0, 0, 0
99 |
100 | cached_inps = torch.cat([x[0] for x in in_out]).to(layer.weight.device)
101 | cached_outs = torch.cat([x[1] for x in in_out]).to(layer.weight.device)
102 |
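# Hold out one fixed random mini-batch of the cached activations as the test split for the before/after MSE report.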
idx = torch.randperm(cached_inps.size(0))[:batch_size] 104 | 105 | test_inp = cached_inps[idx] 106 | test_out = cached_outs[idx] 107 | 108 | # mse_before, mse_after = optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out) 109 | # mse_before_opt = mse_before 110 | # print("MSE before qparams: {}".format(mse_before)) 111 | # print("MSE after qparams: {}".format(mse_after)) 112 | 113 | if optimize_weights: 114 | mse_before, mse_after = adaquant(layer, cached_inps, cached_outs, test_inp, test_out, iters=100, lr1=1e-5, lr2=1e-4) 115 | mse_before_opt = mse_before 116 | print("MSE before adaquant: {}".format(mse_before)) 117 | print("MSE after adaquant: {}".format(mse_after)) 118 | torch.cuda.empty_cache() 119 | else: 120 | mse_before, mse_after = optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out) 121 | mse_before_opt = mse_before 122 | print("MSE before qparams: {}".format(mse_before)) 123 | print("MSE after qparams: {}".format(mse_after)) 124 | 125 | mse_after_opt = mse_after 126 | 127 | with torch.no_grad(): 128 | N = test_out.numel() 129 | snr_before = (1/math.sqrt(N)) * math.sqrt(N * mse_before_opt) / torch.norm(test_out).item() 130 | snr_after = (1/math.sqrt(N)) * math.sqrt(N * mse_after_opt) / torch.norm(test_out).item() 131 | 132 | # optimize_rounding(layer, cached_inps, cached_outs, test_inp, test_out, iters=7000) 133 | # optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out) 134 | # optimize_rounding(layer, cached_inps, cached_outs, test_inp, test_out, iters=2000) 135 | # optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out) 136 | # optimize_rounding(layer, cached_inps, cached_outs, test_inp, test_out, iters=2000) 137 | # optimize_qparams(layer, test_inp, test_out) 138 | 139 | kurt_in = kurtosis(test_inp).item() 140 | kurt_w = kurtosis(layer.weight).item() 141 | 142 | del cached_inps 143 | del cached_outs 144 | torch.cuda.empty_cache() 145 | 146 | return mse_before_opt, mse_after_opt, snr_before, snr_after, kurt_in, kurt_w 147 | 148 | 149 | def kurtosis(x): 150 | var = torch.mean((x - x.mean())**2) 151 | return torch.mean((x - x.mean())**4 / var**2) 152 | 153 | 154 | def dump(model_name, layer, in_out): 155 | path = os.path.join("dump", model_name, layer.name) 156 | if os.path.exists(path): 157 | shutil.rmtree(path) 158 | os.makedirs(path) 159 | 160 | if hasattr(layer, 'groups'): 161 | f = open(os.path.join(path, "groups_{}".format(layer.groups)), 'x') 162 | f.close() 163 | 164 | cached_inps = torch.cat([x[0] for x in in_out]) 165 | cached_outs = torch.cat([x[1] for x in in_out]) 166 | torch.save(cached_inps, os.path.join(path, "input.pt")) 167 | torch.save(cached_outs, os.path.join(path, "output.pt")) 168 | torch.save(layer.weight, os.path.join(path, 'weight.pt')) 169 | if layer.bias is not None: 170 | torch.save(layer.bias, os.path.join(path, 'bias.pt')) 171 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.transforms as transforms 4 | import random 5 | import PIL 6 | 7 | 8 | _IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406], 9 | 'std': [0.229, 0.224, 0.225]} 10 | 11 | _IMAGENET_PCA = { 12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 13 | 'eigvec': torch.Tensor([ 14 | [-0.5675, 0.7192, 0.4009], 15 | [-0.5808, -0.0045, -0.8140], 16 | [-0.5836, -0.6948, 0.4203], 17 | ]) 18 | } 19 | 20 | 21 | def scale_crop(input_size, 
scale_size=None, num_crops=1, normalize=_IMAGENET_STATS):
22 | assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}"
23 | convert_tensor = transforms.Compose([transforms.ToTensor(),
24 | transforms.Normalize(**normalize)])
25 | if num_crops == 1:
26 | t_list = [
27 | transforms.CenterCrop(input_size),
28 | convert_tensor
29 | ]
30 | else:
31 | if num_crops == 5:
32 | t_list = [transforms.FiveCrop(input_size)]
33 | elif num_crops == 10:
34 | t_list = [transforms.TenCrop(input_size)]
35 | # returns a 4D tensor
36 | t_list.append(transforms.Lambda(lambda crops:
37 | torch.stack([convert_tensor(crop) for crop in crops])))
38 |
39 | if scale_size != input_size:
40 | t_list = [transforms.Resize(scale_size)] + t_list
41 |
42 | return transforms.Compose(t_list)
43 |
44 |
45 | def scale_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
46 | t_list = [
47 | transforms.RandomCrop(input_size),
48 | transforms.ToTensor(),
49 | transforms.Normalize(**normalize),
50 | ]
51 | if scale_size != input_size:
52 | t_list = [transforms.Resize(scale_size)] + t_list
53 |
54 | return transforms.Compose(t_list)
55 |
56 |
57 | def pad_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
58 | padding = int((scale_size - input_size) / 2)
59 | return transforms.Compose([
60 | transforms.RandomCrop(input_size, padding=padding),
61 | transforms.RandomHorizontalFlip(),
62 | transforms.ToTensor(),
63 | transforms.Normalize(**normalize),
64 | ])
65 |
66 |
67 |
68 | def inception_preproccess(input_size, normalize=_IMAGENET_STATS):
69 | return transforms.Compose([
70 | transforms.RandomResizedCrop(input_size),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | transforms.Normalize(**normalize)
74 | ])
75 |
76 |
77 | def inception_color_preproccess(input_size, normalize=_IMAGENET_STATS):
78 | return transforms.Compose([
79 | transforms.RandomResizedCrop(input_size),
80 | transforms.RandomHorizontalFlip(),
81 | transforms.ColorJitter(
82 | brightness=0.4,
83 | contrast=0.4,
84 | saturation=0.4,
85 | ),
86 | transforms.ToTensor(),
87 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
88 | transforms.Normalize(**normalize)
89 | ])
90 |
91 |
92 | def multi_transform(transform_fn, duplicates=1, dim=0):
93 | """performs multiple transforms; useful to implement inference-time augmentation or
94 | "batch augmentation" from https://openreview.net/forum?id=H1V4QhAqYQ&noteId=BylUSs_3Y7
95 | """
96 | if duplicates > 1:
97 | return transforms.Lambda(lambda x: torch.stack([transform_fn(x) for _ in range(duplicates)], dim=dim))
98 | else:
99 | return transform_fn
100 |
101 |
102 | def get_transform(transform_name='imagenet', input_size=None, scale_size=None,
103 | normalize=None, augment=True, cutout=None, autoaugment=False,
104 | duplicates=1, num_crops=1):
105 | normalize = normalize or _IMAGENET_STATS
106 | transform_fn = None
107 |
108 | if 'imagenet' in transform_name:  # inception augmentation is default for imagenet
109 | scale_size = scale_size or 256
110 | input_size = input_size or 224
111 | if augment:
112 | transform_fn = inception_preproccess(input_size,
113 | normalize=normalize)
114 | else:
115 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
116 | num_crops=num_crops, normalize=normalize)
117 | elif 'cifar' in transform_name:  # resnet augmentation is default for cifar
118 | input_size = input_size or 32
119 | if augment:
120 | scale_size = scale_size or 40
121 | if autoaugment:
    # NOTE: cifar_autoaugment is not defined in this file; it is assumed to be provided by an autoaugment helper module
122 | transform_fn = cifar_autoaugment(input_size,
scale_size=scale_size, 123 | normalize=normalize) 124 | else: 125 | transform_fn = pad_random_crop(input_size, scale_size=scale_size, 126 | normalize=normalize) 127 | else: 128 | scale_size = scale_size or 32 129 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 130 | num_crops=num_crops, normalize=normalize) 131 | elif transform_name == 'mnist': 132 | normalize = {'mean': [0.5], 'std': [0.5]} 133 | input_size = input_size or 28 134 | if augment: 135 | scale_size = scale_size or 32 136 | transform_fn = pad_random_crop(input_size, scale_size=scale_size, 137 | normalize=normalize) 138 | else: 139 | scale_size = scale_size or 32 140 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 141 | num_crops=num_crops, normalize=normalize) 142 | if cutout is not None: 143 | transform_fn.transforms.append(Cutout(**cutout)) 144 | return multi_transform(transform_fn, duplicates) 145 | 146 | 147 | class Lighting(object): 148 | """Lighting noise(AlexNet - style PCA - based noise)""" 149 | 150 | def __init__(self, alphastd, eigval, eigvec): 151 | self.alphastd = alphastd 152 | self.eigval = eigval 153 | self.eigvec = eigvec 154 | 155 | def __call__(self, img): 156 | if self.alphastd == 0: 157 | return img 158 | 159 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 160 | rgb = self.eigvec.type_as(img).clone()\ 161 | .mul(alpha.view(1, 3).expand(3, 3))\ 162 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 163 | .sum(1).squeeze() 164 | 165 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 166 | 167 | 168 | class Cutout(object): 169 | """ 170 | Randomly mask out one or more patches from an image. 171 | taken from https://github.com/uoguelph-mlrg/Cutout 172 | 173 | 174 | Args: 175 | holes (int): Number of patches to cut out of each image. 176 | length (int): The length (in pixels) of each square patch. 177 | """ 178 | 179 | def __init__(self, holes, length): 180 | self.holes = holes 181 | self.length = length 182 | 183 | def __call__(self, img): 184 | """ 185 | Args: 186 | img (Tensor): Tensor image of size (C, H, W). 187 | Returns: 188 | Tensor: Image with holes of dimension length x length cut out of it. 189 | """ 190 | h = img.size(1) 191 | w = img.size(2) 192 | 193 | mask = np.ones((h, w), np.float32) 194 | 195 | for n in range(self.holes): 196 | y = np.random.randint(h) 197 | x = np.random.randint(w) 198 | 199 | y1 = np.clip(y - self.length // 2, 0, h) 200 | y2 = np.clip(y + self.length // 2, 0, h) 201 | x1 = np.clip(x - self.length // 2, 0, w) 202 | x2 = np.clip(x + self.length // 2, 0, w) 203 | 204 | mask[y1: y2, x1: x2] = 0. 
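# np.clip keeps the square inside the image, so patches that fall near a border cover fewer than length x length pixels.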
205 | 206 | mask = torch.from_numpy(mask) 207 | mask = mask.expand_as(img) 208 | img = img * mask 209 | 210 | return img 211 | -------------------------------------------------------------------------------- /models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.utils import _single, _pair, _triple 4 | import math 5 | import torch.nn.functional as F 6 | from torch.nn.modules.utils import _pair 7 | import torchvision.transforms as transforms 8 | from models.resnet import depBatchNorm2d 9 | from .modules.quantize import QConv2d, QLinear, RangeBN 10 | __all__ = ['mobilenet_v2'] 11 | 12 | 13 | class ConvBNReLU(nn.Sequential): 14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, 15 | batch_norm=True,num_bits=8, num_bits_weight=8, perC=True, measure=False, cal_qparams=False): 16 | padding = (kernel_size - 1) // 2 17 | super(ConvBNReLU, self).__init__( 18 | QConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False, 19 | num_bits=num_bits, num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams), 20 | depBatchNorm2d(batch_norm, out_planes), 21 | nn.ReLU6(inplace=False) 22 | ) 23 | 24 | 25 | def _make_divisible(v, divisor, min_value=None): 26 | """ 27 | This function is taken from the original tf repo. 28 | It ensures that all layers have a channel number that is divisible by 8 29 | It can be seen here: 30 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 31 | :param v: 32 | :param divisor: 33 | :param min_value: 34 | :return: 35 | """ 36 | if min_value is None: 37 | min_value = divisor 38 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 39 | # Make sure that round down does not go down by more than 10%. 
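# e.g. _make_divisible(10, 8) first rounds to 8, which is more than 10% below 10, so it is bumped to 16; _make_divisible(20, 8) rounds straight to 24.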
40 | if new_v < 0.9 * v: 41 | new_v += divisor 42 | return new_v 43 | 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, expand_ratio, batch_norm=True,num_bits=8, num_bits_weight=8, 47 | perC=True, measure=False, cal_qparams=False): 48 | super(InvertedResidual, self).__init__() 49 | self.stride = stride 50 | assert stride in [1, 2] 51 | 52 | hidden_dim = int(round(inp * expand_ratio)) 53 | self.use_res_connect = self.stride == 1 and inp == oup 54 | 55 | layers = [] 56 | if expand_ratio != 1: 57 | # pw 58 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, batch_norm=batch_norm, num_bits=num_bits, 59 | num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams)) 60 | layers.extend([ 61 | # dw 62 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, batch_norm=batch_norm, num_bits=num_bits, 63 | num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams), 64 | # pw-linear 65 | QConv2d(hidden_dim, oup, 1, 1, 0, bias=False, num_bits=num_bits, 66 | num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams), 67 | depBatchNorm2d(batch_norm, oup), 68 | ]) 69 | self.conv = nn.Sequential(*layers) 70 | 71 | def forward(self, x): 72 | if self.use_res_connect: 73 | return x + self.conv(x) 74 | else: 75 | return self.conv(x) 76 | 77 | 78 | class MobileNetV2(nn.Module): 79 | def __init__(self, 80 | num_classes=1000, 81 | width_mult=1.0, 82 | inverted_residual_setting=None, 83 | round_nearest=8, 84 | block=None, batch_norm=True,num_bits=8, num_bits_weight=8, perC=True, measure=False, cal_qparams=False): 85 | """ 86 | MobileNet V2 main class 87 | Args: 88 | num_classes (int): Number of classes 89 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 90 | inverted_residual_setting: Network structure 91 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 92 | Set to 1 to turn off rounding 93 | block: Module specifying inverted residual building block for mobilenet 94 | """ 95 | super(MobileNetV2, self).__init__() 96 | 97 | if block is None: 98 | block = InvertedResidual 99 | input_channel = 32 100 | last_channel = 1280 101 | 102 | if inverted_residual_setting is None: 103 | inverted_residual_setting = [ 104 | # t, c, n, s 105 | [1, 16, 1, 1], 106 | [6, 24, 2, 2], 107 | [6, 32, 3, 2], 108 | [6, 64, 4, 2], 109 | [6, 96, 3, 1], 110 | [6, 160, 3, 2], 111 | [6, 320, 1, 1], 112 | ] 113 | 114 | # only check the first element, assuming user knows t,c,n,s are required 115 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 116 | raise ValueError("inverted_residual_setting should be non-empty " 117 | "or a 4-element list, got {}".format(inverted_residual_setting)) 118 | 119 | # building first layer 120 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 121 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 122 | features = [ConvBNReLU(3, input_channel, stride=2, batch_norm=batch_norm, num_bits=num_bits, 123 | num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams)] 124 | # building inverted residual blocks 125 | for t, c, n, s in inverted_residual_setting: 126 | output_channel = _make_divisible(c * width_mult, round_nearest) 127 | for i in range(n): 128 | stride = s if i == 0 else 1 129 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, 
batch_norm=batch_norm, num_bits=num_bits,
130 | num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams))
131 | input_channel = output_channel
132 | # building last several layers
133 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, batch_norm=batch_norm, num_bits=num_bits,
134 | num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams))
135 | # make it nn.Sequential
136 | self.features = nn.Sequential(*features)
137 |
138 | # building classifier
139 | self.classifier = nn.Sequential(
140 | nn.Dropout(0.2),
141 | QLinear(self.last_channel, num_classes, num_bits=num_bits, num_bits_weight=num_bits_weight,
142 | perC=perC, measure=measure, cal_qparams=cal_qparams),
143 | )
144 |
145 | # weight initialization
146 | for m in self.modules():
147 | if isinstance(m, nn.Conv2d):
148 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
149 | if m.bias is not None:
150 | nn.init.zeros_(m.bias)
151 | elif isinstance(m, nn.BatchNorm2d):
152 | nn.init.ones_(m.weight)
153 | nn.init.zeros_(m.bias)
154 | elif isinstance(m, nn.Linear):
155 | nn.init.normal_(m.weight, 0, 0.01)
156 | nn.init.zeros_(m.bias)
157 |
158 | def _forward_impl(self, x):
159 | # This exists since TorchScript doesn't support inheritance, so the superclass method
160 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
161 | x = self.features(x)
162 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
163 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
164 | x = self.classifier(x)
165 | return x
166 |
167 | def forward(self, x):
168 | return self._forward_impl(x)
169 |
170 |
171 | def mobilenet_v2(**config):
172 | r"""MobileNet v2 model architecture from the `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
173 | <https://arxiv.org/abs/1801.04381>`_ paper.
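Example (illustrative; extra keyword args are forwarded to MobileNetV2)::

>>> model = mobilenet_v2(num_bits=8, num_bits_weight=8, measure=True)
>>> out = model(torch.randn(1, 3, 224, 224))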
174 | """ 175 | dataset = config.pop('dataset', 'imagenet') 176 | assert dataset == 'imagenet' 177 | return MobileNetV2(**config) 178 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import models 12 | import torch.distributed as dist 13 | from data import DataRegime 14 | from utils.log import setup_logging, ResultsLog, save_checkpoint 15 | from utils.optim import OptimRegime 16 | from utils.cross_entropy import CrossEntropyLoss 17 | from utils.misc import torch_dtypes 18 | from utils.param_filter import FilterModules, is_bn 19 | from datetime import datetime 20 | from ast import literal_eval 21 | from trainer import Trainer 22 | 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Evaluation') 28 | parser.add_argument('evaluate', type=str, 29 | help='evaluate model FILE on validation set') 30 | parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results', 31 | help='results dir') 32 | parser.add_argument('--save', metavar='SAVE', default='', 33 | help='saved folder') 34 | parser.add_argument('--datasets-dir', metavar='DATASETS_DIR', default='~/Datasets', 35 | help='datasets dir') 36 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 37 | help='dataset name or folder') 38 | parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet', 39 | choices=model_names, 40 | help='model architecture: ' + 41 | ' | '.join(model_names) + 42 | ' (default: alexnet)') 43 | parser.add_argument('--input-size', type=int, default=None, 44 | help='image input size') 45 | parser.add_argument('--model-config', default='', 46 | help='additional architecture configuration') 47 | parser.add_argument('--dtype', default='float', 48 | help='type of tensor: ' + 49 | ' | '.join(torch_dtypes.keys()) + 50 | ' (default: float)') 51 | parser.add_argument('--device', default='cuda', 52 | help='device assignment ("cpu" or "cuda")') 53 | parser.add_argument('--device-ids', default=[0], type=int, nargs='+', 54 | help='device ids assignment (e.g 0 1 2 3') 55 | parser.add_argument('--world-size', default=-1, type=int, 56 | help='number of distributed processes') 57 | parser.add_argument('--local_rank', default=-1, type=int, 58 | help='rank of distributed processes') 59 | parser.add_argument('--dist-init', default='env://', type=str, 60 | help='init used to set up distributed training') 61 | parser.add_argument('--dist-backend', default='nccl', type=str, 62 | help='distributed backend') 63 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 64 | help='number of data loading workers (default: 8)') 65 | parser.add_argument('-b', '--batch-size', default=256, type=int, 66 | metavar='N', help='mini-batch size (default: 256)') 67 | parser.add_argument('--label-smoothing', default=0, type=float, 68 | help='label smoothing coefficient - default 0') 69 | parser.add_argument('--mixup', default=None, type=float, 70 | help='mixup alpha coefficient - default None') 71 | parser.add_argument('--duplicates', default=1, type=int, 72 | 
help='number of augmentations over single example')
73 | parser.add_argument('--chunk-batch', default=1, type=int,
74 | help='chunk batch size for multiple passes (training)')
75 | parser.add_argument('--augment', action='store_true', default=False,
76 | help='perform augmentations')
77 | parser.add_argument('--cutout', action='store_true', default=False,
78 | help='cutout augmentations')
79 | parser.add_argument('--autoaugment', action='store_true', default=False,
80 | help='use autoaugment policies')
81 | parser.add_argument('--avg-out', action='store_true', default=False,
82 | help='average outputs')
83 | parser.add_argument('--print-freq', '-p', default=10, type=int,
84 | metavar='N', help='print frequency (default: 10)')
85 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
86 | help='path to latest checkpoint (default: none)')
87 |
88 | parser.add_argument('--seed', default=123, type=int,
89 | help='random seed (default: 123)')
90 |
91 |
92 | def main():
93 | args = parser.parse_args()
94 | main_worker(args)
95 |
96 |
97 | def main_worker(args):
98 | global best_prec1, dtype
99 | best_prec1 = 0
100 | dtype = torch_dtypes.get(args.dtype)
101 | torch.manual_seed(args.seed)
102 | time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
103 | if args.evaluate:
104 | args.results_dir = '/tmp'
105 | if args.save == '':
106 | args.save = time_stamp
107 | save_path = os.path.join(args.results_dir, args.save)
108 |
109 | args.distributed = args.local_rank >= 0 or args.world_size > 1
110 |
111 | if not os.path.exists(save_path) and not (args.distributed and args.local_rank > 0):
112 | os.makedirs(save_path)
113 |
114 | setup_logging(os.path.join(save_path, 'log.txt'),
115 | resume=args.resume != '',
116 | dummy=args.distributed and args.local_rank > 0)
117 |
118 | results_path = os.path.join(save_path, 'results')
119 | results = ResultsLog(
120 | results_path, title='Training Results - %s' % args.save)
121 |
122 | if 'cuda' in args.device and torch.cuda.is_available():
123 | torch.cuda.manual_seed_all(args.seed)
124 | torch.cuda.set_device(args.device_ids[0])
125 | cudnn.benchmark = True
126 | else:
127 | args.device_ids = None
128 |
129 | if not os.path.isfile(args.evaluate):
130 | parser.error('invalid checkpoint: {}'.format(args.evaluate))
131 | checkpoint = torch.load(args.evaluate, map_location="cpu")
132 | # Override configuration with checkpoint info
133 | args.model = checkpoint.get('model', args.model)
134 | args.model_config = checkpoint.get('config', args.model_config)
135 |
136 | logging.info("saving to %s", save_path)
137 | logging.debug("run arguments: %s", args)
138 | logging.info("creating model %s", args.model)
139 |
140 | # create model
141 | model = models.__dict__[args.model]
142 | model_config = {'dataset': args.dataset}
143 |
144 | if args.model_config != '':
145 | model_config = dict(model_config, **literal_eval(args.model_config))
146 |
147 | model = model(**model_config)
148 | logging.info("created model with configuration: %s", model_config)
149 | num_parameters = sum([l.nelement() for l in model.parameters()])
150 | logging.info("number of parameters: %d", num_parameters)
151 |
152 | # load checkpoint
153 | model.load_state_dict(checkpoint['state_dict'])
154 | logging.info("loaded checkpoint '%s' (epoch %s)",
155 | args.evaluate, checkpoint['epoch'])
156 |
157 | # define loss function (criterion) and optimizer
158 | loss_params = {}
159 | if args.label_smoothing > 0:
160 | loss_params['smooth_eps'] = args.label_smoothing
161 | criterion
= getattr(model, 'criterion', nn.NLLLoss)(**loss_params) 162 | criterion.to(args.device, dtype) 163 | model.to(args.device, dtype) 164 | 165 | # Batch-norm should always be done in float 166 | if 'half' in args.dtype: 167 | FilterModules(model, module=is_bn).to(dtype=torch.float) 168 | 169 | trainer = Trainer(model, criterion, 170 | device_ids=args.device_ids, device=args.device, dtype=dtype, 171 | mixup=args.mixup, print_freq=args.print_freq) 172 | 173 | # Evaluation Data loading code 174 | val_data = DataRegime(getattr(model, 'data_eval_regime', None), 175 | defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val', 'augment': args.augment, 176 | 'input_size': args.input_size, 'batch_size': args.batch_size, 'shuffle': False, 'duplicates': args.duplicates, 'autoaugment': args.autoaugment, 177 | 'cutout': {'holes': 1, 'length': 16} if args.cutout else None, 'num_workers': args.workers, 'pin_memory': True, 'drop_last': False}) 178 | 179 | results = trainer.validate(val_data.get_loader(), 180 | duplicates=val_data.get('duplicates'), 181 | average_output=args.avg_out) 182 | logging.info(results) 183 | return results 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from itertools import cycle 4 | import torch 5 | import logging.config 6 | from datetime import datetime 7 | import json 8 | 9 | import pandas as pd 10 | from bokeh.io import output_file, save, show 11 | from bokeh.plotting import figure 12 | from bokeh.layouts import column 13 | from bokeh.models import Div 14 | 15 | try: 16 | import hyperdash 17 | HYPERDASH_AVAILABLE = True 18 | except ImportError: 19 | HYPERDASH_AVAILABLE = False 20 | 21 | 22 | def export_args_namespace(args, filename): 23 | """ 24 | args: argparse.Namespace 25 | arguments to save 26 | filename: string 27 | filename to save at 28 | """ 29 | with open(filename, 'w') as fp: 30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) 31 | 32 | 33 | def setup_logging(log_file='log.txt', resume=False, dummy=False): 34 | """ 35 | Setup logging configuration 36 | """ 37 | if dummy: 38 | logging.getLogger('dummy') 39 | else: 40 | if os.path.isfile(log_file) and resume: 41 | file_mode = 'a' 42 | else: 43 | file_mode = 'w' 44 | 45 | root_logger = logging.getLogger() 46 | if root_logger.handlers: 47 | root_logger.removeHandler(root_logger.handlers[0]) 48 | logging.basicConfig(level=logging.DEBUG, 49 | format="%(asctime)s - %(levelname)s - %(message)s", 50 | datefmt="%Y-%m-%d %H:%M:%S", 51 | filename=log_file, 52 | filemode=file_mode) 53 | console = logging.StreamHandler() 54 | console.setLevel(logging.INFO) 55 | formatter = logging.Formatter('%(message)s') 56 | console.setFormatter(formatter) 57 | logging.getLogger('').addHandler(console) 58 | 59 | 60 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None, 61 | x_axis_type='linear', y_axis_type='linear', 62 | width=800, height=400, line_width=2, 63 | colors=['red', 'green', 'blue', 'orange', 64 | 'black', 'purple', 'brown'], 65 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save', 66 | append_figure=None): 67 | """ 68 | creates a new plot figures 69 | example: 70 | plot_figure(x='epoch', y=['train_loss', 'val_loss'], 71 | 'title='Loss', 'ylabel'='loss') 72 | """ 73 | if not isinstance(y, list): 74 | y = [y] 75 | xlabel = xlabel or 
x 76 | legend = legend or y 77 | assert len(legend) == len(y) 78 | if append_figure is not None: 79 | f = append_figure 80 | else: 81 | f = figure(title=title, tools=tools, 82 | width=width, height=height, 83 | x_axis_label=xlabel or x, 84 | y_axis_label=ylabel or '', 85 | x_axis_type=x_axis_type, 86 | y_axis_type=y_axis_type) 87 | colors = cycle(colors) 88 | for i, yi in enumerate(y): 89 | f.line(data[x], data[yi], 90 | line_width=line_width, 91 | line_color=next(colors), legend=legend[i]) 92 | f.legend.click_policy = "hide" 93 | return f 94 | 95 | 96 | class ResultsLog(object): 97 | 98 | supported_data_formats = ['csv', 'json'] 99 | 100 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'): 101 | """ 102 | Parameters 103 | ---------- 104 | path: string 105 | path to directory to save data files 106 | plot_path: string 107 | path to directory to save plot files 108 | title: string 109 | title of HTML file 110 | params: Namespace 111 | optionally save parameters for results 112 | resume: bool 113 | resume previous logging 114 | data_format: str('csv'|'json') 115 | which file format to use to save the data 116 | """ 117 | if data_format not in ResultsLog.supported_data_formats: 118 | raise ValueError('data_format must of the following: ' + 119 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats])) 120 | 121 | if data_format == 'json': 122 | self.data_path = '{}.json'.format(path) 123 | else: 124 | self.data_path = '{}.csv'.format(path) 125 | if params is not None: 126 | export_args_namespace(params, '{}.json'.format(path)) 127 | self.plot_path = '{}.html'.format(path) 128 | self.results = None 129 | self.clear() 130 | self.first_save = True 131 | if os.path.isfile(self.data_path): 132 | if resume: 133 | self.load(self.data_path) 134 | self.first_save = False 135 | else: 136 | os.remove(self.data_path) 137 | self.results = pd.DataFrame() 138 | else: 139 | self.results = pd.DataFrame() 140 | 141 | self.title = title 142 | self.data_format = data_format 143 | 144 | if HYPERDASH_AVAILABLE: 145 | name = self.title if title != '' else path 146 | self.hd_experiment = hyperdash.Experiment(name) 147 | if params is not None: 148 | for k, v in params._get_kwargs(): 149 | self.hd_experiment.param(k, v, log=False) 150 | 151 | def clear(self): 152 | self.figures = [] 153 | 154 | def add(self, **kwargs): 155 | """Add a new row to the dataframe 156 | example: 157 | resultsLog.add(epoch=epoch_num, train_loss=loss, 158 | test_loss=test_loss) 159 | """ 160 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 161 | self.results = self.results.append(df, ignore_index=True) 162 | if hasattr(self, 'hd_experiment'): 163 | for k, v in kwargs.items(): 164 | self.hd_experiment.metric(k, v, log=False) 165 | 166 | def smooth(self, column_name, window): 167 | """Select an entry to smooth over time""" 168 | # TODO: smooth only new data 169 | smoothed_column = self.results[column_name].rolling( 170 | window=window, center=False).mean() 171 | self.results[column_name + '_smoothed'] = smoothed_column 172 | 173 | def save(self, title=None): 174 | """save the json file. 
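Writes any pending figures to the HTML plot file and the accumulated results to the data file (csv or json).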
175 | Parameters 176 | ---------- 177 | title: string 178 | title of the HTML file 179 | """ 180 | title = title or self.title 181 | if len(self.figures) > 0: 182 | if os.path.isfile(self.plot_path): 183 | os.remove(self.plot_path) 184 | if self.first_save: 185 | self.first_save = False 186 | logging.info('Plot file saved at: {}'.format( 187 | os.path.abspath(self.plot_path))) 188 | 189 | output_file(self.plot_path, title=title) 190 | plot = column( 191 | Div(text='
<h1 align="center">{}</h1>
'.format(title)), *self.figures) 192 | save(plot) 193 | self.clear() 194 | 195 | if self.data_format == 'json': 196 | self.results.to_json(self.data_path, orient='records', lines=True) 197 | else: 198 | self.results.to_csv(self.data_path, index=False, index_label=False) 199 | 200 | def load(self, path=None): 201 | """load the data file 202 | Parameters 203 | ---------- 204 | path: 205 | path to load the json|csv file from 206 | """ 207 | path = path or self.data_path 208 | if os.path.isfile(path): 209 | if self.data_format == 'json': 210 | self.results = pd.read_json(path, orient='records', lines=True) 211 | else: 212 | self.results = pd.read_csv(path) 213 | else: 214 | raise ValueError("{} isn't a file".format(path)) 215 | 216 | def show(self, title=None): 217 | title = title or self.title 218 | if len(self.figures) > 0: 219 | plot = column( 220 | Div(text='
<h1 align="center">{}</h1>
'.format(title)), *self.figures) 221 | show(plot) 222 | 223 | def plot(self, *kargs, **kwargs): 224 | """ 225 | add a new plot to the HTML file 226 | example: 227 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 228 | title='Loss', ylabel='loss') 229 | """ 230 | f = plot_figure(self.results, *kargs, **kwargs) 231 | self.figures.append(f) 232 | 233 | def image(self, *kargs, **kwargs): 234 | fig = figure() 235 | fig.image(*kargs, **kwargs) 236 | self.figures.append(fig) 237 | 238 | def end(self): 239 | if hasattr(self, 'hd_experiment'): 240 | self.hd_experiment.end() 241 | 242 | 243 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 244 | filename = os.path.join(path, filename) 245 | torch.save(state, filename) 246 | if is_best: 247 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 248 | if save_all: 249 | shutil.copyfile(filename, os.path.join( 250 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 251 |
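A minimal usage sketch for the logger above (editor's annotation; the path, title and metric names are illustrative, not taken from this repo):

    from utils.log import ResultsLog, setup_logging
    setup_logging('log.txt')
    log = ResultsLog(path='results/run1', title='Run 1', data_format='csv')
    for epoch in range(3):
        log.add(epoch=epoch, train_loss=1.0 / (epoch + 1))
    log.plot(x='epoch', y=['train_loss'], title='Loss', ylabel='loss')
    log.save()  # writes results/run1.csv and results/run1.html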
-------------------------------------------------------------------------------- /models/mobilenet_v2_old.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.utils import _single, _pair, _triple 4 | import math 5 | import torch.nn.functional as F 6 | from torch.nn.modules.utils import _pair 7 | import torchvision.transforms as transforms 8 | from .modules.quantize import QConv2d, QLinear, RangeBN 9 | __all__ = ['mobilenet_v2'] 10 | 11 | 12 | def nearby_int(n): 13 | return int(round(n)) 14 | 15 | 16 | def init_model(model): 17 | for m in model.modules(): 18 | if isinstance(m, QConv2d): 19 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 20 | m.weight.data.normal_(0, math.sqrt(2. / n)) 21 | elif isinstance(m, nn.BatchNorm2d): 22 | m.weight.data.fill_(1) 23 | m.bias.data.zero_() 24 | 25 | 26 | def weight_decay_config(value=1e-4, log=False): 27 | def regularize_layer(m): 28 | non_depthwise_conv = isinstance(m, QConv2d) \ 29 | and m.groups != m.in_channels 30 | return isinstance(m, nn.Linear) or non_depthwise_conv 31 | 32 | return {'name': 'WeightDecay', 33 | 'value': value, 34 | 'log': log, 35 | 'filter': {'parameter_name': lambda n: not n.endswith('bias'), 36 | 'module': regularize_layer} 37 | } 38 | 39 | 40 | class ExpandedConv2d(nn.Module): 41 | 42 | def __init__(self, in_channels, out_channels, expansion=1, kernel_size=3, 43 | stride=1, padding=1, residual_block=None, batch_norm=True, 44 | num_bits=8, num_bits_weight=8, measure=False, perC=True, cal_qparams=False): 45 | expanded = in_channels * expansion 46 | super(ExpandedConv2d, self).__init__() 47 | self.add_res = stride == 1 and in_channels == out_channels 48 | self.residual_block = residual_block 49 | if expanded == in_channels: 50 | block = [] 51 | else: 52 | if batch_norm: 53 | block = [ 54 | QConv2d(in_channels, expanded, 1, bias=False, 55 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 56 | nn.BatchNorm2d(expanded), 57 | nn.ReLU6(inplace=True), 58 | ] 59 | else: 60 | block = [ 61 | QConv2d(in_channels, expanded, 1, bias=True, 62 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 63 | nn.ReLU6(inplace=True) 64 | ] 65 | 66 | if batch_norm: 67 | block += [ 68 | QConv2d(expanded, expanded, kernel_size, 69 | stride=stride, padding=padding, groups=expanded, bias=False, 70 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 71 | nn.BatchNorm2d(expanded), 72 | nn.ReLU6(inplace=True), 73 | QConv2d(expanded, out_channels, 1, bias=False, 74 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 75 | nn.BatchNorm2d(out_channels) 76 | ] 77 | else: 78 | block += [ 79 | QConv2d(expanded, expanded, kernel_size, 80 | stride=stride, padding=padding, groups=expanded, bias=True, 81 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 82 | nn.ReLU6(inplace=True), 83 | QConv2d(expanded, out_channels, 1, bias=True, 84 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams) 85 | ] 86 | 87 | 88 | self.block = nn.Sequential(*block) 89 | 90 | def forward(self, x): 91 | out = self.block(x) 92 | if self.add_res: 93 | if self.residual_block is not None: 94 | x = self.residual_block(x) 95 | out += x 96 | return out 97 | 98 | 99 | def conv(in_channels, out_channels, kernel=3, stride=1, padding=1,batch_norm=True,num_bits=8, num_bits_weight=8, measure=False, perC=True, cal_qparams=False): 100 | if batch_norm: 101 | return nn.Sequential( 102 | QConv2d(in_channels, out_channels, kernel, 103 | stride, padding, bias=False, 104 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 105 | nn.BatchNorm2d(out_channels), 106 | nn.ReLU6(inplace=True) 107 | ) 108 | else: 109 | return nn.Sequential( 110 | QConv2d(in_channels, out_channels, kernel, 111 | stride, padding, bias=True, 112 | num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, perC=perC, cal_qparams=cal_qparams), 113 | nn.ReLU6(inplace=True) 114 | ) 115 | 116 | 117 | class MobileNet_v2(nn.Module): 118 | 119 | def __init__(self,
width=1., regime=None, num_classes=1000, scale_lr=1, batch_norm=True,num_bits=8, num_bits_weight=8, perC=True, measure=False, cal_qparams=False): 120 | super(MobileNet_v2, self).__init__() 121 | in_channels = nearby_int(width * 32) 122 | 123 | layers_config = [ 124 | dict(expansion=1, stride=1, out_channels=nearby_int(width * 16)), 125 | dict(expansion=6, stride=2, out_channels=nearby_int(width * 24)), 126 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 24)), 127 | dict(expansion=6, stride=2, out_channels=nearby_int(width * 32)), 128 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 32)), 129 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 32)), 130 | dict(expansion=6, stride=2, out_channels=nearby_int(width * 64)), 131 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 64)), 132 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 64)), 133 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 64)), 134 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 96)), 135 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 96)), 136 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 96)), 137 | dict(expansion=6, stride=2, out_channels=nearby_int(width * 160)), 138 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 160)), 139 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 160)), 140 | dict(expansion=6, stride=1, out_channels=nearby_int(width * 320)), 141 | ] 142 | 143 | self.features = nn.Sequential() 144 | self.features.add_module('conv0', conv(3, in_channels, 145 | kernel=3, stride=2, padding=1, batch_norm=batch_norm,num_bits=8, num_bits_weight=8, perC=True, measure=False, cal_qparams=False)) 146 | 147 | for i, layer in enumerate(layers_config): 148 | layer['in_channels'] = in_channels 149 | layer['batch_norm'] = batch_norm 150 | layer['num_bits'] = num_bits 151 | layer['num_bits_weight'] = num_bits_weight 152 | layer['perC'] = perC 153 | layer['measure'] = measure 154 | layer['cal_qparams'] = cal_qparams 155 | in_channels = layer['out_channels'] 156 | self.features.add_module( 157 | 'bottleneck' + str(i), ExpandedConv2d(**layer)) 158 | 159 | out_channels = nearby_int(width * 1280) 160 | self.features.add_module('conv1', conv(in_channels, out_channels, 161 | kernel=1, stride=1, padding=0, batch_norm=batch_norm,num_bits=num_bits, num_bits_weight=num_bits_weight, perC=perC, measure=measure, cal_qparams=cal_qparams)) 162 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 163 | self.classifier = nn.Sequential( 164 | nn.Dropout(0.2, True), 165 | nn.Linear(out_channels, num_classes) 166 | ) 167 | init_model(self) 168 | 169 | if regime == 'small': 170 | scale_lr *= 4 171 | self.data_regime = [ 172 | {'epoch': 0, 'input_size': 128, 'batch_size': 512}, 173 | {'epoch': 80, 'input_size': 224, 'batch_size': 128}, 174 | ] 175 | self.data_eval_regime = [ 176 | {'epoch': 0, 'input_size': 128, 177 | 'scale_size': 160, 'batch_size': 1024}, 178 | {'epoch': 80, 'input_size': 224, 'batch_size': 512}, 179 | ] 180 | 181 | self.regime = [ 182 | {'epoch': 0, 'optimizer': 'SGD', 'momentum': 0.9, 'lr': scale_lr * 1e-1, 183 | 'regularizer': weight_decay_config(1e-4)}, 184 | {'epoch': 30, 'lr': scale_lr * 1e-2}, 185 | {'epoch': 60, 'lr': scale_lr * 1e-3}, 186 | {'epoch': 80, 'lr': scale_lr * 1e-4} 187 | ] 188 | 189 | def forward(self, x): 190 | x = self.features(x) 191 | x = self.avg_pool(x) 192 | x = x.view(x.size(0), -1) 193 | x = self.classifier(x) 194 | return x 195 | 196 | 197 | def 
mobilenet_v2(**config): 198 | r"""MobileNet v2 model architecture from the `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" 199 | <https://arxiv.org/abs/1801.04381>`_ paper. 200 | """ 201 | dataset = config.pop('dataset', 'imagenet') 202 | assert dataset == 'imagenet' 203 | return MobileNet_v2(**config) 204 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging.config 3 | from copy import deepcopy 4 | from six import string_types 5 | from .regime import Regime 6 | from .param_filter import FilterParameters 7 | from . import regularization 8 | import torch.nn as nn 9 | 10 | _OPTIMIZERS = {name: func for name, func in torch.optim.__dict__.items()} 11 | 12 | try: 13 | from adabound import AdaBound 14 | _OPTIMIZERS['AdaBound'] = AdaBound 15 | except ImportError: 16 | pass 17 | 18 | 19 | def copy_params(param_target, param_src): 20 | with torch.no_grad(): 21 | for p_src, p_target in zip(param_src, param_target): 22 | p_target.copy_(p_src) 23 | 24 | 25 | def copy_params_grad(param_target, param_src): 26 | for p_src, p_target in zip(param_src, param_target): 27 | if p_target.grad is None: 28 | p_target.backward(p_src.grad.to(dtype=p_target.dtype)) 29 | else: 30 | p_target.grad.detach().copy_(p_src.grad) 31 | 32 | 33 | class ModuleFloatShadow(nn.Module): 34 | def __init__(self, module): 35 | super(ModuleFloatShadow, self).__init__() 36 | self.original_module = module 37 | self.float_module = deepcopy(module) 38 | self.float_module.to(dtype=torch.float) 39 | 40 | def parameters(self, *kargs, **kwargs): 41 | return self.float_module.parameters(*kargs, **kwargs) 42 | 43 | def named_parameters(self, *kargs, **kwargs): 44 | return self.float_module.named_parameters(*kargs, **kwargs) 45 | 46 | def modules(self, *kargs, **kwargs): 47 | return self.float_module.modules(*kargs, **kwargs) 48 | 49 | def named_modules(self, *kargs, **kwargs): 50 | return self.float_module.named_modules(*kargs, **kwargs) 51 | 52 | def original_parameters(self, *kargs, **kwargs): 53 | return self.original_module.parameters(*kargs, **kwargs) 54 | 55 | def original_named_parameters(self, *kargs, **kwargs): 56 | return self.original_module.named_parameters(*kargs, **kwargs) 57 | 58 | def original_modules(self, *kargs, **kwargs): 59 | return self.original_module.modules(*kargs, **kwargs) 60 | 61 | def original_named_modules(self, *kargs, **kwargs): 62 | return self.original_module.named_modules(*kargs, **kwargs) 63 | 64 | 65 | class OptimRegime(Regime): 66 | """ 67 | Reconfigures the optimizer according to setting list.
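A rough usage sketch (editor's annotation; names such as `num_epochs`, `criterion`, `inputs` and `target` are illustrative, not from this file):

        optim = OptimRegime(model, regime=model.regime)
        for epoch in range(num_epochs):
            optim.update(epoch)
            optim.zero_grad()
            loss = criterion(model(inputs), target)
            loss.backward()
            optim.step()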
68 | Exposes optimizer methods - state, step, zero_grad, add_param_group 69 | 70 | Examples for regime: 71 | 72 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3}, 73 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 74 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 75 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 76 | ]" 77 | 2) 78 | "[{'step_lambda': 79 | "lambda t: { 80 | 'optimizer': 'Adam', 81 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 82 | 'betas': (0.9, 0.98), 'eps':1e-9} 83 | }]" 84 | """ 85 | 86 | def __init__(self, model, regime, defaults={}, filter=None, use_float_copy=False): 87 | super(OptimRegime, self).__init__(regime, defaults) 88 | if filter is not None: 89 | model = FilterParameters(model, **filter) 90 | if use_float_copy: 91 | model = ModuleFloatShadow(model) 92 | self._original_parameters = list(model.original_parameters()) 93 | 94 | self.parameters = list(model.parameters()) 95 | self.optimizer = torch.optim.SGD(self.parameters, lr=0) 96 | self.regularizer = regularization.Regularizer(model) 97 | self.use_float_copy = use_float_copy 98 | 99 | def update(self, epoch=None, train_steps=None): 100 | """adjusts optimizer according to current epoch or steps and training regime. 101 | """ 102 | if super(OptimRegime, self).update(epoch, train_steps): 103 | self.adjust(self.setting) 104 | return True 105 | else: 106 | return False 107 | 108 | def adjust(self, setting): 109 | """adjusts optimizer according to a setting dict. 110 | e.g: setting={optimizer': 'Adam', 'lr': 5e-4} 111 | """ 112 | if 'optimizer' in setting: 113 | optim_method = _OPTIMIZERS[setting['optimizer']] 114 | if not isinstance(self.optimizer, optim_method): 115 | self.optimizer = optim_method(self.optimizer.param_groups) 116 | logging.debug('OPTIMIZER - setting method = %s' % 117 | setting['optimizer']) 118 | for param_group in self.optimizer.param_groups: 119 | for key in param_group.keys(): 120 | if key in setting: 121 | new_val = setting[key] 122 | if new_val != param_group[key]: 123 | logging.debug('OPTIMIZER - setting %s = %s' % 124 | (key, setting[key])) 125 | param_group[key] = setting[key] 126 | # fix for AdaBound 127 | if key == 'lr' and hasattr(self.optimizer, 'base_lrs'): 128 | self.optimizer.base_lrs = list( 129 | map(lambda group: group['lr'], self.optimizer.param_groups)) 130 | 131 | if 'regularizer' in setting: 132 | reg_list = deepcopy(setting['regularizer']) 133 | if not (isinstance(reg_list, list) or isinstance(reg_list, tuple)): 134 | reg_list = (reg_list,) 135 | regularizers = [] 136 | for reg in reg_list: 137 | if isinstance(reg, dict): 138 | logging.debug('OPTIMIZER - Regularization - %s' % reg) 139 | name = reg.pop('name') 140 | regularizers.append((regularization.__dict__[name], reg)) 141 | elif isinstance(reg, regularization.Regularizer): 142 | regularizers.append(reg) 143 | else: # callable on model 144 | regularizers.append(reg(self.regularizer._model)) 145 | self.regularizer = regularization.RegularizerList(self.regularizer._model, 146 | regularizers) 147 | 148 | def __getstate__(self): 149 | return { 150 | 'optimizer_state': self.optimizer.__getstate__(), 151 | 'regime': self.regime, 152 | } 153 | 154 | def __setstate__(self, state): 155 | self.regime = state.get('regime') 156 | self.optimizer.__setstate__(state.get('optimizer_state')) 157 | 158 | def state_dict(self): 159 | """Returns the state of the optimizer as a :class:`dict`. 
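Round-trip sketch (editor's annotation; assumes `optim_regime` is an OptimRegime instance and the checkpoint filename is illustrative):

            ckpt = optim_regime.state_dict()
            torch.save(ckpt, 'optim_state.pth.tar')
            optim_regime.load_state_dict(torch.load('optim_state.pth.tar'))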
160 | """ 161 | return { 162 | 'optimizer_state': self.optimizer.state_dict(), 163 | 'regime': self.regime, 164 | } 165 | 166 | def load_state_dict(self, state_dict): 167 | """Loads the optimizer state. 168 | 169 | Arguments: 170 | state_dict (dict): optimizer state. Should be an object returned 171 | from a call to :meth:`state_dict`. 172 | """ 173 | # deepcopy, to be consistent with module API 174 | optimizer_state_dict = state_dict['optimizer_state'] 175 | 176 | self.__setstate__({'optimizer_state': optimizer_state_dict, 177 | 'regime': state_dict['regime']}) 178 | 179 | def zero_grad(self): 180 | """Clears the gradients of all optimized :class:`Variable` s.""" 181 | self.optimizer.zero_grad() 182 | if self.use_float_copy: 183 | for p in self._original_parameters: 184 | if p.grad is not None: 185 | p.grad.detach().zero_() 186 | 187 | def step(self, closure=None): 188 | """Performs a single optimization step (parameter update). 189 | 190 | Arguments: 191 | closure (callable): A closure that reevaluates the model and 192 | returns the loss. Optional for most optimizers. 193 | """ 194 | if self.use_float_copy: 195 | copy_params_grad(self.parameters, self._original_parameters) 196 | self.regularizer.pre_step() 197 | self.optimizer.step(closure) 198 | self.regularizer.post_step() 199 | if self.use_float_copy: 200 | copy_params(self._original_parameters, self.parameters) 201 | 202 | def pre_forward(self): 203 | """ allows modification pre-forward pass - e.g for regularization 204 | """ 205 | self.regularizer.pre_forward() 206 | 207 | def pre_backward(self): 208 | """ allows modification post-forward pass and pre-backward - e.g for regularization 209 | """ 210 | self.regularizer.pre_backward() 211 | 212 | 213 | class MultiOptimRegime(OptimRegime): 214 | 215 | def __init__(self, *optim_regime_list): 216 | self.optim_regime_list = [] 217 | for optim_regime in optim_regime_list: 218 | assert isinstance(optim_regime, OptimRegime) 219 | self.optim_regime_list.append(optim_regime) 220 | 221 | def update(self, epoch=None, train_steps=None): 222 | """adjusts optimizer according to current epoch or steps and training regime. 223 | """ 224 | updated = False 225 | for i, optim in enumerate(self.optim_regime_list): 226 | current_updated = optim.update(epoch, train_steps) 227 | if current_updated: 228 | logging.debug('OPTIMIZER #%s was updated' % i) 229 | updated = updated or current_updated 230 | return updated 231 | 232 | def zero_grad(self): 233 | """Clears the gradients of all optimized :class:`Variable` s.""" 234 | for optim in self.optim_regime_list: 235 | optim.zero_grad() 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step (parameter update). 239 | 240 | Arguments: 241 | closure (callable): A closure that reevaluates the model and 242 | returns the loss. Optional for most optimizers. 
243 | """ 244 | for optim in self.optim_regime_list: 245 | optim.step(closure) 246 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from torch.utils.data.distributed import DistributedSampler 5 | from torch.utils.data.sampler import RandomSampler 6 | from torch.utils.data import Subset 7 | from torch._utils import _accumulate 8 | from utils.regime import Regime 9 | from utils.dataset import IndexedFileDataset 10 | from preprocess import get_transform 11 | from itertools import chain 12 | from copy import deepcopy 13 | import warnings 14 | import numpy as np 15 | from PIL import Image 16 | 17 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 18 | 19 | 20 | def get_dataset(name, split='train', transform=None, 21 | target_transform=None, download=True, datasets_path='~/Datasets'): 22 | train = (split == 'train') 23 | root = os.path.join(os.path.expanduser(datasets_path), name) 24 | if name == 'cifar10': 25 | return datasets.CIFAR10(root=root, 26 | train=train, 27 | transform=transform, 28 | target_transform=target_transform, 29 | download=download) 30 | elif name == 'cifar100': 31 | return datasets.CIFAR100(root=root, 32 | train=train, 33 | transform=transform, 34 | target_transform=target_transform, 35 | download=download) 36 | elif name == 'mnist': 37 | return datasets.MNIST(root=root, 38 | train=train, 39 | transform=transform, 40 | target_transform=target_transform, 41 | download=download) 42 | elif name == 'stl10': 43 | return datasets.STL10(root=root, 44 | split=split, 45 | transform=transform, 46 | target_transform=target_transform, 47 | download=download) 48 | elif name == 'imagenet': 49 | if train: 50 | root = os.path.join(root, 'train') 51 | else: 52 | root = os.path.join(root, 'val') 53 | return datasets.ImageFolder(root=root, 54 | transform=transform, 55 | target_transform=target_transform) 56 | elif name == 'imagenet_calib': 57 | if train: 58 | root = os.path.join(root.replace('imagenet_calib','imagenet'), 'calib') 59 | else: 60 | root = os.path.join(root, 'val') 61 | return datasets.ImageFolder(root=root, 62 | transform=transform, 63 | target_transform=target_transform) 64 | elif name == 'imagenet_calib_10K': 65 | if train: 66 | root = os.path.join(root.replace('imagenet_calib_10K','imagenet'), 'calib_10K') 67 | else: 68 | root = os.path.join(root, 'val') 69 | return datasets.ImageFolder(root=root, 70 | transform=transform, 71 | target_transform=target_transform) 72 | elif name == 'imagenet_tar': 73 | if train: 74 | root = os.path.join(root, 'imagenet_train.tar') 75 | else: 76 | root = os.path.join(root, 'imagenet_validation.tar') 77 | return IndexedFileDataset(root, extract_target_fn=( 78 | lambda fname: fname.split('/')[0]), 79 | transform=transform, 80 | target_transform=target_transform) 81 | 82 | 83 | _DATA_ARGS = {'name', 'split', 'transform', 84 | 'target_transform', 'download', 'datasets_path'} 85 | _DATALOADER_ARGS = {'batch_size', 'shuffle', 'sampler', 'batch_sampler', 86 | 'num_workers', 'collate_fn', 'pin_memory', 'drop_last', 87 | 'timeout', 'worker_init_fn'} 88 | _TRANSFORM_ARGS = {'transform_name', 'input_size', 'scale_size', 'normalize', 'augment', 89 | 'cutout', 'duplicates', 'num_crops', 'autoaugment'} 90 | _OTHER_ARGS = {'distributed'} 91 | 92 | 93 | #class ImageNetCalib(datasets.ImageFolder): 94 | # """Small calibration dataset 
taken from training.""" 95 | # 96 | # def __init__(self, root,transform=None, target_transform=None): 97 | # """ 98 | # Args: 99 | # csv_file (string): Path to the csv file with annotations. 100 | # root_dir (string): Directory with all the images. 101 | # transform (callable, optional): Optional transform to be applied 102 | # on a sample. 103 | # """ 104 | # self.samples,self.target = torch.load(root) 105 | # self.samples = Image.fromarray(np.uint8(self.samples.permute(0,2,3,1).contiguous().numpy())) 106 | # self.root = root 107 | # self.transform = transform 108 | # self.target_transform = target_transform 109 | # 110 | # 111 | # def __len__(self): 112 | # return len(self.target) 113 | # 114 | # def __getitem__(self, idx): 115 | # samples = self.samples[idx] 116 | # target = self.target[idx] 117 | # if self.transform is not None: 118 | # #import pdb; pdb.set_trace() 119 | # print(samples.shape) 120 | # samples = self.transform(samples) 121 | # if self.target_transform is not None: 122 | # target = self.target_transform(target) 123 | # return samples,target #,idx 124 | 125 | 126 | class DataRegime(object): 127 | def __init__(self, regime, defaults={}): 128 | self.regime = Regime(regime, deepcopy(defaults)) 129 | self.epoch = 0 130 | self.steps = None 131 | self.get_loader(True) 132 | 133 | def get_setting(self): 134 | setting = self.regime.setting 135 | loader_setting = {k: v for k, 136 | v in setting.items() if k in _DATALOADER_ARGS} 137 | data_setting = {k: v for k, v in setting.items() if k in _DATA_ARGS} 138 | transform_setting = { 139 | k: v for k, v in setting.items() if k in _TRANSFORM_ARGS} 140 | other_setting = {k: v for k, v in setting.items() if k in _OTHER_ARGS} 141 | transform_setting.setdefault('transform_name', data_setting['name']) 142 | return {'data': data_setting, 'loader': loader_setting, 143 | 'transform': transform_setting, 'other': other_setting} 144 | 145 | def get(self, key, default=None): 146 | return self.regime.setting.get(key, default) 147 | 148 | def get_loader(self, force_update=False, override_settings=None, subset_indices=None): 149 | if force_update or self.regime.update(self.epoch, self.steps): 150 | setting = self.get_setting() 151 | if override_settings is not None: 152 | setting.update(override_settings) 153 | self._transform = get_transform(**setting['transform']) 154 | setting['data'].setdefault('transform', self._transform) 155 | self._data = get_dataset(**setting['data']) 156 | if subset_indices is not None: 157 | self._data = Subset(self._data, subset_indices) 158 | if setting['other'].get('distributed', False): 159 | setting['loader']['sampler'] = DistributedSampler(self._data) 160 | setting['loader']['shuffle'] = None 161 | # pin-memory currently broken for distributed 162 | setting['loader']['pin_memory'] = False 163 | self._sampler = setting['loader'].get('sampler', None) 164 | self._loader = torch.utils.data.DataLoader( 165 | self._data, **setting['loader']) 166 | return self._loader 167 | 168 | def set_epoch(self, epoch): 169 | self.epoch = epoch 170 | if self._sampler is not None and hasattr(self._sampler, 'set_epoch'): 171 | self._sampler.set_epoch(epoch) 172 | 173 | def __len__(self): 174 | return len(self._data) 175 | 176 | 177 | class SampledDataLoader(object): 178 | def __init__(self, dl_list): 179 | self.dl_list = dl_list 180 | self.epoch = 0 181 | 182 | def generate_order(self): 183 | 184 | order = [[idx]*len(dl) for idx, dl in enumerate(self.dl_list)] 185 | order = list(chain(*order)) 186 | g = torch.Generator() 187 | 
g.manual_seed(self.epoch) 188 | return torch.tensor(order)[torch.randperm(len(order), generator=g)].tolist() 189 | 190 | def __len__(self): 191 | return sum([len(dl) for dl in self.dl_list]) 192 | 193 | def __iter__(self): 194 | order = self.generate_order() 195 | 196 | iterators = [iter(dl) for dl in self.dl_list] 197 | for idx in order: 198 | yield next(iterators[idx]) 199 | return 200 | 201 | 202 | class SampledDataRegime(DataRegime): 203 | def __init__(self, data_regime_list, probs, split_data=True): 204 | self.probs = probs 205 | self.data_regime_list = data_regime_list 206 | self.split_data = split_data 207 | self.epoch = 0  # DataRegime.__init__ is not called here; get_loader() reads self.epoch 208 | def get_setting(self): 209 | return [data_regime.get_setting() for data_regime in self.data_regime_list] 210 | 211 | def get(self, key, default=None): 212 | return [data_regime.get(key, default) for data_regime in self.data_regime_list] 213 | 214 | def get_loader(self, force_update=False): 215 | settings = self.get_setting() 216 | if self.split_data: 217 | dset_sizes = [len(get_dataset(**s['data'])) for s in settings] 218 | assert len(set(dset_sizes)) == 1, \ 219 | "all datasets should be same size" 220 | dset_size = dset_sizes[0] 221 | lengths = [int(prob * dset_size) for prob in self.probs] 222 | lengths[-1] = dset_size - sum(lengths[:-1]) 223 | indices = torch.randperm(dset_size).tolist() 224 | indices_split = [indices[offset - length:offset] 225 | for offset, length in zip(_accumulate(lengths), lengths)] 226 | loaders = [data_regime.get_loader(force_update=True, subset_indices=indices_split[i]) 227 | for i, data_regime in enumerate(self.data_regime_list)] 228 | else: 229 | loaders = [data_regime.get_loader( 230 | force_update=force_update) for data_regime in self.data_regime_list] 231 | self._loader = SampledDataLoader(loaders) 232 | self._loader.epoch = self.epoch 233 | 234 | return self._loader 235 | 236 | def set_epoch(self, epoch): 237 | self.epoch = epoch 238 | if hasattr(self, '_loader'): 239 | self._loader.epoch = epoch 240 | for data_regime in self.data_regime_list: 241 | if data_regime._sampler is not None and hasattr(data_regime._sampler, 'set_epoch'): 242 | data_regime._sampler.set_epoch(epoch) 243 | 244 | def __len__(self): 245 | return sum([len(data_regime._data) 246 | for data_regime in self.data_regime_list]) 247 | 248 | 249 | if __name__ == '__main__': 250 | reg1 = DataRegime(None, {'name': 'imagenet', 'batch_size': 16}) 251 | reg2 = DataRegime(None, {'name': 'imagenet', 'batch_size': 32}) 252 | reg1.set_epoch(0) 253 | reg2.set_epoch(0) 254 | mreg = SampledDataRegime([reg1, reg2], [0.5, 0.5])  # probs is a required argument 255 | 256 | for x, _ in mreg.get_loader(): 257 | print(x.shape) 258 | -------------------------------------------------------------------------------- /models/modules/evolved_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from https://github.com/quark0/darts 3 | """ 4 | from collections import namedtuple 5 | import torch 6 | import torch.nn as nn 7 | 8 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 9 | 10 | OPS = { 11 | 'avg_pool_3x3': lambda channels, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 12 | 'max_pool_3x3': lambda channels, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 13 | 'skip_connect': lambda channels, stride, affine: Identity() if stride == 1 else FactorizedReduce(channels, channels, affine=affine), 14 | 'sep_conv_3x3': lambda channels, stride, affine: SepConv(channels, channels, 3, stride, 1, affine=affine), 15
| 'sep_conv_5x5': lambda channels, stride, affine: SepConv(channels, channels, 5, stride, 2, affine=affine), 16 | 'sep_conv_7x7': lambda channels, stride, affine: SepConv(channels, channels, 7, stride, 3, affine=affine), 17 | 'dil_conv_3x3': lambda channels, stride, affine: DilConv(channels, channels, 3, stride, 2, 2, affine=affine), 18 | 'dil_conv_5x5': lambda channels, stride, affine: DilConv(channels, channels, 5, stride, 4, 2, affine=affine), 19 | 'conv_7x1_1x7': lambda channels, stride, affine: nn.Sequential( 20 | nn.ReLU(inplace=False), 21 | nn.Conv2d(channels, channels, (1, 7), stride=(1, stride), 22 | padding=(0, 3), bias=False), 23 | nn.Conv2d(channels, channels, (7, 1), stride=(stride, 1), 24 | padding=(3, 0), bias=False), 25 | nn.BatchNorm2d(channels, affine=affine) 26 | ), 27 | } 28 | 29 | 30 | # genotypes 31 | GENOTYPES = dict( 32 | NASNet=Genotype( 33 | normal=[ 34 | ('sep_conv_5x5', 1), 35 | ('sep_conv_3x3', 0), 36 | ('sep_conv_5x5', 0), 37 | ('sep_conv_3x3', 0), 38 | ('avg_pool_3x3', 1), 39 | ('skip_connect', 0), 40 | ('avg_pool_3x3', 0), 41 | ('avg_pool_3x3', 0), 42 | ('sep_conv_3x3', 1), 43 | ('skip_connect', 1), 44 | ], 45 | normal_concat=[2, 3, 4, 5, 6], 46 | reduce=[ 47 | ('sep_conv_5x5', 1), 48 | ('sep_conv_7x7', 0), 49 | ('max_pool_3x3', 1), 50 | ('sep_conv_7x7', 0), 51 | ('avg_pool_3x3', 1), 52 | ('sep_conv_5x5', 0), 53 | ('skip_connect', 3), 54 | ('avg_pool_3x3', 2), 55 | ('sep_conv_3x3', 2), 56 | ('max_pool_3x3', 1), 57 | ], 58 | reduce_concat=[4, 5, 6], 59 | ), 60 | 61 | AmoebaNet=Genotype( 62 | normal=[ 63 | ('avg_pool_3x3', 0), 64 | ('max_pool_3x3', 1), 65 | ('sep_conv_3x3', 0), 66 | ('sep_conv_5x5', 2), 67 | ('sep_conv_3x3', 0), 68 | ('avg_pool_3x3', 3), 69 | ('sep_conv_3x3', 1), 70 | ('skip_connect', 1), 71 | ('skip_connect', 0), 72 | ('avg_pool_3x3', 1), 73 | ], 74 | normal_concat=[4, 5, 6], 75 | reduce=[ 76 | ('avg_pool_3x3', 0), 77 | ('sep_conv_3x3', 1), 78 | ('max_pool_3x3', 0), 79 | ('sep_conv_7x7', 2), 80 | ('sep_conv_7x7', 0), 81 | ('avg_pool_3x3', 1), 82 | ('max_pool_3x3', 0), 83 | ('max_pool_3x3', 1), 84 | ('conv_7x1_1x7', 0), 85 | ('sep_conv_3x3', 5), 86 | ], 87 | reduce_concat=[3, 4, 6] 88 | ), 89 | 90 | DARTS_V1=Genotype( 91 | normal=[ 92 | ('sep_conv_3x3', 1), 93 | ('sep_conv_3x3', 0), 94 | ('skip_connect', 0), 95 | ('sep_conv_3x3', 1), 96 | ('skip_connect', 0), 97 | ('sep_conv_3x3', 1), 98 | ('sep_conv_3x3', 0), 99 | ('skip_connect', 2)], 100 | normal_concat=[2, 3, 4, 5], 101 | reduce=[('max_pool_3x3', 0), 102 | ('max_pool_3x3', 1), 103 | ('skip_connect', 2), 104 | ('max_pool_3x3', 0), 105 | ('max_pool_3x3', 0), 106 | ('skip_connect', 2), 107 | ('skip_connect', 2), 108 | ('avg_pool_3x3', 0)], 109 | reduce_concat=[2, 3, 4, 5]), 110 | DARTS=Genotype(normal=[('sep_conv_3x3', 0), 111 | ('sep_conv_3x3', 1), 112 | ('sep_conv_3x3', 0), 113 | ('sep_conv_3x3', 1), 114 | ('sep_conv_3x3', 1), 115 | ('skip_connect', 0), 116 | ('skip_connect', 0), 117 | ('dil_conv_3x3', 2)], 118 | normal_concat=[2, 3, 4, 5], 119 | reduce=[('max_pool_3x3', 0), 120 | ('max_pool_3x3', 1), 121 | ('skip_connect', 2), 122 | ('max_pool_3x3', 1), 123 | ('max_pool_3x3', 0), 124 | ('skip_connect', 2), 125 | ('skip_connect', 2), 126 | ('max_pool_3x3', 1)], 127 | reduce_concat=[2, 3, 4, 5]), 128 | ) 129 | 130 | 131 | class ReLUConvBN(nn.Module): 132 | 133 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 134 | super(ReLUConvBN, self).__init__() 135 | self.op = nn.Sequential( 136 | nn.ReLU(inplace=False), 137 | nn.Conv2d(C_in, C_out, kernel_size, 
stride=stride, 138 | padding=padding, bias=False), 139 | nn.BatchNorm2d(C_out, affine=affine) 140 | ) 141 | 142 | def forward(self, x): 143 | return self.op(x) 144 | 145 | 146 | class DilConv(nn.Module): 147 | 148 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 149 | super(DilConv, self).__init__() 150 | self.op = nn.Sequential( 151 | nn.ReLU(inplace=False), 152 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, 153 | padding=padding, dilation=dilation, groups=C_in, bias=False), 154 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 155 | nn.BatchNorm2d(C_out, affine=affine), 156 | ) 157 | 158 | def forward(self, x): 159 | return self.op(x) 160 | 161 | 162 | class SepConv(nn.Module): 163 | 164 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 165 | super(SepConv, self).__init__() 166 | self.op = nn.Sequential( 167 | nn.ReLU(inplace=False), 168 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, 169 | padding=padding, groups=C_in, bias=False), 170 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 171 | nn.BatchNorm2d(C_in, affine=affine), 172 | nn.ReLU(inplace=False), 173 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, 174 | padding=padding, groups=C_in, bias=False), 175 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 176 | nn.BatchNorm2d(C_out, affine=affine), 177 | ) 178 | 179 | def forward(self, x): 180 | return self.op(x) 181 | 182 | 183 | class Identity(nn.Module): 184 | 185 | def __init__(self): 186 | super(Identity, self).__init__() 187 | 188 | def forward(self, x): 189 | return x 190 | 191 | 192 | class FactorizedReduce(nn.Module): 193 | 194 | def __init__(self, C_in, C_out, affine=True): 195 | super(FactorizedReduce, self).__init__() 196 | assert C_out % 2 == 0 197 | self.relu = nn.ReLU(inplace=False) 198 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, 199 | stride=2, padding=0, bias=False) 200 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, 201 | stride=2, padding=0, bias=False) 202 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 203 | 204 | def forward(self, x): 205 | x = self.relu(x) 206 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) 207 | out = self.bn(out) 208 | return out 209 | 210 | 211 | def drop_path(x, drop_prob): 212 | if drop_prob > 0.: 213 | keep_prob = 1.-drop_prob 214 | mask = x.new(x.size(0), 1, 1, 1).bernoulli_(keep_prob) 215 | x.div_(keep_prob) 216 | x.mul_(mask) 217 | return x 218 | 219 | 220 | class Cell(nn.Module): 221 | 222 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): 223 | super(Cell, self).__init__() 224 | if reduction_prev: 225 | self.preprocess0 = FactorizedReduce(C_prev_prev, C) 226 | else: 227 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) 228 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) 229 | 230 | if reduction: 231 | op_names, indices = zip(*genotype.reduce) 232 | concat = genotype.reduce_concat 233 | else: 234 | op_names, indices = zip(*genotype.normal) 235 | concat = genotype.normal_concat 236 | self._compile(C, op_names, indices, concat, reduction) 237 | 238 | def _compile(self, C, op_names, indices, concat, reduction): 239 | assert len(op_names) == len(indices) 240 | self._steps = len(op_names) // 2 241 | self._concat = concat 242 | self.multiplier = len(concat) 243 | 244 | self._ops = nn.ModuleList() 245 | for name, index in zip(op_names, indices): 246 | stride = 2 if reduction and index < 2 else 1 247 | op = OPS[name](C, 
stride, True) 248 | self._ops += [op] 249 | self._indices = indices 250 | 251 | def forward(self, s0, s1, drop_prob): 252 | s0 = self.preprocess0(s0) 253 | s1 = self.preprocess1(s1) 254 | 255 | states = [s0, s1] 256 | for i in range(self._steps): 257 | h1 = states[self._indices[2*i]] 258 | h2 = states[self._indices[2*i+1]] 259 | op1 = self._ops[2*i] 260 | op2 = self._ops[2*i+1] 261 | h1 = op1(h1) 262 | h2 = op2(h2) 263 | if self.training and drop_prob > 0.: 264 | if not isinstance(op1, Identity): 265 | h1 = drop_path(h1, drop_prob) 266 | if not isinstance(op2, Identity): 267 | h2 = drop_path(h2, drop_prob) 268 | s = h1 + h2 269 | states += [s] 270 | return torch.cat([states[i] for i in self._concat], dim=1) 271 | 272 | 273 | class NasNetCell(Cell): 274 | def __init__(self, *kargs, **kwargs): 275 | super(NasNetCell, self).__init__(GENOTYPES['NASNet'], *kargs, **kwargs) 276 | 277 | 278 | class AmoebaNetCell(Cell): 279 | def __init__(self, *kargs, **kwargs): 280 | super(AmoebaNetCell, self).__init__( 281 | GENOTYPES['AmoebaNet'], *kargs, **kwargs) 282 | 283 | 284 | class DARTSCell(Cell): 285 | def __init__(self, *kargs, **kwargs): 286 | super(DARTSCell, self).__init__(GENOTYPES['DARTS'], *kargs, **kwargs) 287 | -------------------------------------------------------------------------------- /utils/regularization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .param_filter import FilterParameters, is_not_bn, is_not_bias 3 | from .absorb_bn import search_absorbe_bn 4 | from torch.nn.utils import clip_grad_norm_ 5 | import logging 6 | 7 | 8 | def sparsity(p): 9 | return float(p.eq(0).sum()) / p.nelement() 10 | 11 | 12 | def _norm_exclude_dim(x, dim=0, keepdim=False): 13 | dims = tuple(set(range(x.dim())) - set([dim])) 14 | return x.pow(2).sum(dims, keepdim=keepdim).sqrt() 15 | 16 | 17 | def _renorm(x, dim=0, inplace=False, eps=1e-12): 18 | if not inplace: 19 | x = x.clone() 20 | return x.div_(_norm_exclude_dim(x, dim, keepdim=True)) 21 | 22 | 23 | def _norm(x, dim, p=2): 24 | """Computes the norm over all dimensions except dim""" 25 | if p == -1: 26 | def func(x, dim): return x.max(dim=dim)[0] - x.min(dim=dim)[0] 27 | elif p == float('inf'): 28 | def func(x, dim): return x.max(dim=dim)[0] 29 | else: 30 | def func(x, dim): return torch.norm(x, dim=dim, p=p) 31 | if dim is None: 32 | return x.norm(p=p) 33 | elif dim == 0: 34 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 35 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 36 | elif dim == x.dim() - 1: 37 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 38 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 39 | else: 40 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 41 | 42 | 43 | class Regularizer(object): 44 | def __init__(self, model, value=1e-3, filter={}, log=False): 45 | self._model = model 46 | self._named_parameters = list( 47 | FilterParameters(model, **filter).named_parameters()) 48 | self.value = value 49 | self.log = log 50 | if self.log: 51 | logging.debug('Applying regularization to parameters: %s', 52 | [n for n, _ in self._named_parameters]) 53 | 54 | def named_parameters(self): 55 | for n, p in self._named_parameters: 56 | yield n, p 57 | 58 | def parameters(self): 59 | for _, p in self.named_parameters(): 60 | yield p 61 | 62 | def _pre_parameter_step(self, parameter): 63 | pass 64 | 65 | def _post_parameter_step(self, parameter): 66 | pass 67 | 68 | def pre_step(self): 69 | pass 70 | 71 | def 
post_step(self): 72 | pass 73 | 74 | def pre_forward(self): 75 | pass 76 | 77 | def pre_backward(self): 78 | pass 79 | 80 | 81 | class RegularizerList(Regularizer): 82 | def __init__(self, model, regularization_list): 83 | """each item must of of format (RegClass, **kwargs) or instance of Regularizer""" 84 | super(RegularizerList, self).__init__(model) 85 | self.regularization_list = [] 86 | for regularizer in regularization_list: 87 | if not isinstance(regularizer, Regularizer): 88 | reg, reg_params = regularizer 89 | regularizer = reg(model=model, **reg_params) 90 | self.regularization_list.append(regularizer) 91 | 92 | def pre_step(self): 93 | for reg in self.regularization_list: 94 | reg.pre_step() 95 | 96 | def post_step(self): 97 | for reg in self.regularization_list: 98 | reg.post_step() 99 | 100 | def pre_forward(self): 101 | for reg in self.regularization_list: 102 | reg.pre_forward() 103 | 104 | def pre_backward(self): 105 | for reg in self.regularization_list: 106 | reg.pre_backward() 107 | 108 | 109 | class L2Regularization(Regularizer): 110 | def __init__(self, model, value=1e-3, 111 | filter={'parameter_name': is_not_bias, 112 | 'module': is_not_bn}, 113 | pre_op=True, post_op=False, **kwargs): 114 | super(L2Regularization, self).__init__( 115 | model, value, filter=filter, **kwargs) 116 | self.pre_op = pre_op 117 | self.post_op = post_op 118 | 119 | def pre_step(self): 120 | if self.pre_op: 121 | with torch.no_grad(): 122 | for _, p in self._named_parameters: 123 | p.grad.add_(self.value, p) 124 | if self.log: 125 | logging.debug('L2 penalty of %s was applied pre optimization step', 126 | self.value) 127 | 128 | def post_step(self): 129 | if self.post_op: 130 | with torch.no_grad(): 131 | for _, p in self._named_parameters: 132 | p.add_(-self.value, p) 133 | if self.log: 134 | logging.debug('L2 penalty of %s was applied post optimization step', 135 | self.value) 136 | 137 | 138 | class WeightDecay(L2Regularization): 139 | def __init__(self, *kargs, **kwargs): 140 | super(WeightDecay, self).__init__(*kargs, **kwargs) 141 | 142 | 143 | class GradClip(Regularizer): 144 | def __init__(self, *kargs, **kwargs): 145 | super(GradClip, self).__init__(*kargs, **kwargs) 146 | 147 | def pre_step(self): 148 | if self.value > 0: 149 | with torch.no_grad(): 150 | grad = clip_grad_norm_(self.parameters(), self.value) 151 | if self.log: 152 | logging.debug('Gradient value was clipped from %s to %s', 153 | grad, self.value) 154 | 155 | 156 | class L1Regularization(Regularizer): 157 | def __init__(self, model, value=1e-3, 158 | filter={'parameter_name': is_not_bias, 159 | 'module': is_not_bn}, 160 | pre_op=False, post_op=True, report_sparsity=False, **kwargs): 161 | super(L1Regularization, self).__init__( 162 | model, value, filter=filter, **kwargs) 163 | self.pre_op = pre_op 164 | self.post_op = post_op 165 | self.report_sparsity = report_sparsity 166 | 167 | def pre_step(self): 168 | if self.pre_op: 169 | with torch.no_grad(): 170 | for n, p in self._named_parameters: 171 | p.grad.add_(self.value, p.sign()) 172 | if self.report_sparsity: 173 | logging.debug('Sparsity for %s is %s', n, sparsity(p)) 174 | if self.log: 175 | logging.debug('L1 penalty of %s was applied pre optimization step', 176 | self.value) 177 | 178 | def post_step(self): 179 | if self.post_op: 180 | with torch.no_grad(): 181 | for n, p in self._named_parameters: 182 | p.copy_(torch.nn.functional.softshrink(p, self.value)) 183 | if self.report_sparsity: 184 | logging.debug('Sparsity for %s is %s', n, sparsity(p)) 185 | if 
self.log: 186 | logging.debug('L1 penalty of %s was applied post optimization step', 187 | self.value) 188 | 189 | 190 | class BoundedWeightNorm(Regularizer): 191 | def __init__(self, model, 192 | filter={'parameter_name': is_not_bias, 193 | 'module': is_not_bn}, 194 | dim=0, p=2, **kwargs): 195 | super(BoundedWeightNorm, self).__init__( 196 | model, 0, filter=filter, **kwargs) 197 | self.dim = dim 198 | self.init_norms = None 199 | self.p = p 200 | 201 | def _gather_init_norm(self): 202 | self.init_norms = {} 203 | with torch.no_grad(): 204 | for n, p in self._named_parameters: 205 | self.init_norms[n] = _norm( 206 | p, self.dim, p=self.p).detach().mean() 207 | 208 | def pre_forward(self): 209 | if self.init_norms is None: 210 | self._gather_init_norm() 211 | with torch.no_grad(): 212 | for n, p in self._named_parameters: 213 | init_norm = self.init_norms[n] 214 | new_norm = _norm(p, self.dim, p=self.p) 215 | p.mul_(init_norm / new_norm) 216 | 217 | def pre_step(self): 218 | for n, p in self._named_parameters: 219 | init_norm = self.init_norms[n] 220 | norm = _norm(p, self.dim, p=self.p) 221 | curr_grad = p.grad.data.clone() 222 | p.grad.data.zero_() 223 | p_normed = p * (init_norm / norm) 224 | p_normed.backward(curr_grad) 225 | 226 | 227 | class LARS(Regularizer): 228 | """Large Batch Training of Convolutional Networks - https://arxiv.org/abs/1708.03888 229 | """ 230 | 231 | def __init__(self, model, value=0.01, weight_decay=0, dim=None, p=2, min_scale=None, max_scale=None, 232 | filter={'parameter_name': is_not_bias, 233 | 'module': is_not_bn}, 234 | **kwargs): 235 | super(LARS, self).__init__(model, value, filter=filter, **kwargs) 236 | self.weight_decay = weight_decay 237 | self.dim = dim 238 | self.p = p 239 | self.min_scale = min_scale 240 | self.max_scale = max_scale 241 | 242 | def pre_step(self): 243 | with torch.no_grad(): 244 | for _, param in self._named_parameters: 245 | param.grad.add_(self.weight_decay, param) 246 | if self.dim is not None: 247 | norm = _norm(param, dim=self.dim, p=self.p) 248 | grad_norm = _norm(param.grad, dim=self.dim, p=self.p) 249 | else: 250 | norm = param.norm(p=self.p) 251 | grad_norm = param.grad.norm(p=self.p) 252 | scale = self.value * norm/grad_norm 253 | if self.min_scale is not None or self.max_scale is not None: 254 | scale.clamp_(min=self.min_scale, max=self.max_scale) 255 | param.grad.mul_(scale) 256 | 257 | 258 | class DropConnect(Regularizer): 259 | def __init__(self, model, value=0, 260 | filter={'parameter_name': is_not_bias, 261 | 'module': is_not_bn}, 262 | shakeshake=False, **kwargs): 263 | super(DropConnect, self).__init__( 264 | model, value=value, filter=filter, **kwargs) 265 | self.shakeshake = shakeshake 266 | 267 | def _drop_parameters(self): 268 | self.parameter_copy = {} 269 | with torch.no_grad(): 270 | for n, p in self._named_parameters: 271 | self.parameter_copy[n] = p.clone() 272 | torch.nn.functional.dropout(p, self.value, 273 | training=True, inplace=True) 274 | 275 | def _reassign_parameters(self): 276 | with torch.no_grad(): 277 | for n, p in self._named_parameters: 278 | p.copy_(self.parameter_copy.pop(n)) 279 | 280 | def pre_forward(self): 281 | self._drop_parameters() 282 | 283 | def pre_backward(self): 284 | if self.shakeshake: 285 | self._reassign_parameters() 286 | 287 | def pre_step(self): 288 | if not self.shakeshake: 289 | self._reassign_parameters() 290 | 291 | 292 | class AbsorbBN(Regularizer): 293 | def __init__(self, model, remove_bn=False): 294 | self._model = model 295 | if not remove_bn: 296 | for m 
in model.modules(): 297 | if isinstance(m, torch.nn.BatchNorm2d): 298 | m.momentum = 1 299 | self.remove_bn = remove_bn 300 | self._removed = False 301 | 302 | def pre_forward(self): 303 | if self._removed: 304 | return 305 | search_absorbe_bn(self._model, remove_bn=self.remove_bn, verbose=False) 306 | self._removed = self.remove_bn 307 | -------------------------------------------------------------------------------- /models/modules/lp_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parameter import Parameter 3 | from torch.autograd import Variable, Function 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | 8 | def _norm(x, dim, p=2): 9 | """Computes the norm over all dimensions except dim""" 10 | if p == -1: 11 | func = lambda x, dim: x.max(dim=dim)[0] - x.min(dim=dim)[0] 12 | elif p == float('inf'): 13 | func = lambda x, dim: x.max(dim=dim)[0] 14 | else: 15 | func = lambda x, dim: torch.norm(x, dim=dim, p=p) 16 | if dim is None: 17 | return x.norm(p=p) 18 | elif dim == 0: 19 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 20 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 21 | elif dim == x.dim() - 1: 22 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 23 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 24 | else: 25 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 26 | 27 | 28 | def _mean(p, dim): 29 | """Computes the mean over all dimensions except dim""" 30 | if dim is None: 31 | return p.mean() 32 | elif dim == 0: 33 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 34 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 35 | elif dim == p.dim() - 1: 36 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 37 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 38 | else: 39 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 40 | 41 | 42 | def _std(p, dim): 43 | """Computes the mean over all dimensions except dim""" 44 | if dim is None: 45 | return p.std() 46 | elif dim == 0: 47 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 48 | return p.contiguous().view(p.size(0), -1).std(dim=1).view(*output_size) 49 | elif dim == p.dim() - 1: 50 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 51 | return p.contiguous().view(-1, p.size(-1)).std(dim=0).view(*output_size) 52 | else: 53 | return _std(p.transpose(0, dim), 0).transpose(0, dim) 54 | 55 | # L2 56 | 57 | 58 | class LpBatchNorm2d(nn.Module): 59 | # This is L2 Baseline 60 | 61 | def __init__(self, num_features, dim=1, p=2, momentum=0.1, bias=True, eps=1e-5, noise=False): 62 | super(LpBatchNorm2d, self).__init__() 63 | self.register_buffer('running_mean', torch.zeros(num_features)) 64 | self.register_buffer('running_var', torch.zeros(num_features)) 65 | self.momentum = momentum 66 | self.dim = dim 67 | self.noise = noise 68 | self.p = p 69 | self.eps = eps 70 | self.bias = Parameter(torch.Tensor(num_features)) 71 | self.weight = Parameter(torch.Tensor(num_features)) 72 | 73 | def forward(self, x): 74 | p = self.p 75 | if self.training: 76 | mean = x.view(x.size(0), x.size(self.dim), -1).mean(-1).mean(0) 77 | y = x.transpose(0, 1) 78 | z = y.contiguous() 79 | t = z.view(z.size(0), -1) 80 | Var = (torch.abs((t.transpose(1, 0) - mean))**p).mean(0) 81 | 82 | scale = (Var + self.eps)**(-1 / p) 83 | 84 | self.running_mean.mul_(self.momentum).add_( 85 | mean.data * (1 - self.momentum)) 86 | 87 | self.running_var.mul_(self.momentum).add_( 88 
| scale.data * (1 - self.momentum)) 89 | else: 90 | mean = torch.autograd.Variable(self.running_mean) 91 | scale = torch.autograd.Variable(self.running_var) 92 | 93 | out = (x - mean.view(1, mean.size(0), 1, 1)) * \ 94 | scale.view(1, scale.size(0), 1, 1) 95 | 96 | if self.noise and self.training: 97 | std = 0.1 * _std(x, self.dim).data 98 | ones = torch.ones_like(x.data) 99 | std_noise = Variable(torch.normal(ones, ones) * std) 100 | out = out * std_noise 101 | 102 | if self.weight is not None: 103 | out = out * self.weight.view(1, self.weight.size(0), 1, 1) 104 | 105 | if self.bias is not None: 106 | out = out + self.bias.view(1, self.bias.size(0), 1, 1) 107 | return out 108 | 109 | 110 | class TopkBatchNorm2d(nn.Module): 111 | # this is normalized L_inf 112 | 113 | def __init__(self, num_features, k=10, dim=1, momentum=0.1, bias=True, eps=1e-5, noise=False): 114 | super(TopkBatchNorm2d, self).__init__() 115 | self.register_buffer('running_mean', torch.zeros(num_features)) 116 | self.register_buffer('running_var', torch.zeros(num_features)) 117 | 118 | self.momentum = momentum 119 | self.dim = dim 120 | self.noise = noise 121 | self.k = k 122 | self.eps = eps 123 | self.bias = Parameter(torch.Tensor(num_features)) 124 | self.weight = Parameter(torch.Tensor(num_features)) 125 | 126 | def forward(self, x): 127 | if self.training: 128 | mean = x.view(x.size(0), x.size(self.dim), -1).mean(-1).mean(0) 129 | y = x.transpose(0, 1) 130 | z = y.contiguous() 131 | t = z.view(z.size(0), -1) 132 | A = torch.abs(t.transpose(1, 0) - mean) 133 | 134 | const = 0.5 * (1 + (np.pi * np.log(4)) ** 0.5) / \ 135 | ((2 * np.log(A.size(0))) ** 0.5) 136 | 137 | MeanTOPK = (torch.topk(A, self.k, dim=0)[0].mean(0)) * const 138 | scale = 1 / (MeanTOPK + self.eps) 139 | 140 | self.running_mean.mul_(self.momentum).add_( 141 | mean.data * (1 - self.momentum)) 142 | 143 | self.running_var.mul_(self.momentum).add_( 144 | scale.data * (1 - self.momentum)) 145 | else: 146 | mean = torch.autograd.Variable(self.running_mean) 147 | scale = torch.autograd.Variable(self.running_var) 148 | 149 | out = (x - mean.view(1, mean.size(0), 1, 1)) * \ 150 | scale.view(1, scale.size(0), 1, 1) 151 | 152 | if self.noise and self.training: 153 | std = 0.1 * _std(x, self.dim).data 154 | ones = torch.ones_like(x.data) 155 | std_noise = Variable(torch.normal(ones, ones) * std) 156 | out = out * std_noise 157 | 158 | if self.weight is not None: 159 | out = out * self.weight.view(1, self.weight.size(0), 1, 1) 160 | 161 | if self.bias is not None: 162 | out = out + self.bias.view(1, self.bias.size(0), 1, 1) 163 | return out 164 | 165 | # Top10 166 | 167 | 168 | class GhostTopkBatchNorm2d(nn.Module): 169 | # This is normalized Top10 batch norm 170 | 171 | def __init__(self, num_features, k=10, dim=1, momentum=0.1, bias=True, eps=1e-5, beta=0.75, noise=False): 172 | super(GhostTopkBatchNorm2d, self).__init__() 173 | self.register_buffer('running_mean', torch.zeros(num_features)) 174 | self.register_buffer('running_var', torch.zeros(num_features)) 175 | 176 | self.momentum = momentum 177 | self.dim = dim 178 | self.register_buffer('biasTOPK', torch.zeros(num_features))  # buffer is referenced as self.biasTOPK in forward() 179 | self.noise = noise 180 | self.k = k 181 | self.beta = beta 182 | self.eps = eps 183 | self.bias = Parameter(torch.Tensor(num_features)) 184 | self.weight = Parameter(torch.Tensor(num_features)) 185 | 186 | def forward(self, x): 187 | # p=5 188 | if self.training: 189 | mean = x.view(x.size(0), x.size(self.dim), -1).mean(-1).mean(0) 190 | y = x.transpose(0, 1) 191 | z =
y.contiguous() 192 | t = z.view(z.size(0), -1) 193 | A = torch.abs(t.transpose(1, 0) - mean) 194 | beta = 0.75 195 | 196 | MeanTOPK = torch.topk(A, self.k, dim=0)[0].mean(0) 197 | meanTOPK = beta * \ 198 | torch.autograd.variable.Variable( 199 | self.biasTOPK) + (1 - beta) * MeanTOPK 200 | 201 | const = 0.5 * (1 + (np.pi * np.log(4)) ** 0.5) / \ 202 | ((2 * np.log(A.size(0))) ** 0.5) 203 | meanTOPK = meanTOPK * const 204 | 205 | # print(self.biasTOPK) 206 | self.biasTOPK.copy_(meanTOPK.data) 207 | # self.biasTOPK = MeanTOPK.data 208 | scale = 1 / (meanTOPK + self.eps) 209 | 210 | self.running_mean.mul_(self.momentum).add_( 211 | mean.data * (1 - self.momentum)) 212 | 213 | self.running_var.mul_(self.momentum).add_( 214 | scale.data * (1 - self.momentum)) 215 | else: 216 | mean = torch.autograd.Variable(self.running_mean) 217 | scale = torch.autograd.Variable(self.running_var) 218 | 219 | out = (x - mean.view(1, mean.size(0), 1, 1)) * \ 220 | scale.view(1, scale.size(0), 1, 1) 221 | # out = (x - mean.view(1, mean.size(0), 1, 1)) * final_scale.view(1, scale.size(0), 1, 1) 222 | 223 | if self.noise and self.training: 224 | std = 0.1 * _std(x, self.dim).data 225 | ones = torch.ones_like(x.data) 226 | std_noise = Variable(torch.normal(ones, ones) * std) 227 | out = out * std_noise 228 | 229 | if self.weight is not None: 230 | out = out * self.weight.view(1, self.weight.size(0), 1, 1) 231 | 232 | if self.bias is not None: 233 | out = out + self.bias.view(1, self.bias.size(0), 1, 1) 234 | return out 235 | 236 | 237 | # L1 238 | class L1BatchNorm2d(nn.Module): 239 | # This is normalized L1 Batch norm; note the normalization term (np.pi / 2) ** 0.5) when multiplying by Var: 240 | # scale = ((Var * (np.pi / 2) ** 0.5) + self.eps) ** (-1) 241 | 242 | """docstring for L1BatchNorm2d.""" 243 | 244 | def __init__(self, num_features, dim=1, momentum=0.1, bias=True, normalized=True, eps=1e-5, noise=False): 245 | super(L1BatchNorm2d, self).__init__() 246 | self.register_buffer('running_mean', torch.zeros(num_features)) 247 | self.register_buffer('running_var', torch.zeros(num_features)) 248 | self.momentum = momentum 249 | self.dim = dim 250 | self.noise = noise 251 | self.bias = Parameter(torch.Tensor(num_features)) 252 | self.weight = Parameter(torch.Tensor(num_features)) 253 | self.eps = eps 254 | if normalized: 255 | self.weight_fix = (np.pi / 2) ** 0.5 256 | else: 257 | self.weight_fix = 1 258 | 259 | def forward(self, x): 260 | p = 1 261 | if self.training: 262 | mean = x.view(x.size(0), x.size(self.dim), -1).mean(-1).mean(0) 263 | y = x.transpose(0, 1) 264 | z = y.contiguous() 265 | t = z.view(z.size(0), -1) 266 | Var = (torch.abs((t.transpose(1, 0) - mean))).mean(0) 267 | scale = (Var * self.weight_fix + self.eps) ** (-1) 268 | self.running_mean.mul_(self.momentum).add_( 269 | mean.data * (1 - self.momentum)) 270 | 271 | self.running_var.mul_(self.momentum).add_( 272 | scale.data * (1 - self.momentum)) 273 | else: 274 | mean = torch.autograd.Variable(self.running_mean) 275 | scale = torch.autograd.Variable(self.running_var) 276 | 277 | out = (x - mean.view(1, mean.size(0), 1, 1)) * \ 278 | scale.view(1, scale.size(0), 1, 1) 279 | 280 | if self.noise and self.training: 281 | std = 0.1 * _std(x, self.dim).data 282 | ones = torch.ones_like(x.data) 283 | std_noise = Variable(torch.normal(ones, ones) * std) 284 | out = out * std_noise 285 | 286 | if self.weight is not None: 287 | out = out * self.weight.view(1, self.weight.size(0), 1, 1) 288 | 289 | if self.bias is not None: 290 | out = out + 
/mpip_compression_pytorch_multi.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pulp import *
3 | import numpy as np
4 | import argparse
5 | from main import main_with_args as main_per_layer
6 | import os
7 | Debug = False
8 | 
9 | def mpip_compression(files=None, replace_precisions=None, Degradation=None, noise=None, method='acc', base_precision=8):
10 | 
11 |     data = {}
12 |     for f, prec in zip(files, replace_precisions):
13 |         data[prec] = pd.read_csv(f)
14 | 
15 |     if Degradation is None:
16 |         Degradation = 0.18
17 | 
18 |     bops = False
19 |     metric = 'MACs' if bops else 'Parameters Size [Elements]'
20 | 
21 |     if method == 'acc':
22 |         acc = True
23 |     elif method == 'loss':
24 |         acc = False
25 |     measurement = 'accuracy' if acc else 'loss'
26 | 
27 |     po = 2 if bops else 1
28 |     prob = LpProblem('BitAllocationProblem', LpMinimize)
29 |     Combinations = {}; accLoss = {}; memorySaved = {}; Indicators = {}; S = {}; DeltaL = {}
30 | 
31 |     num_layers = len(data[replace_precisions[0]]['base precision']) - 1
32 | 
33 |     base_accuracy = data[replace_precisions[0]][measurement][0]
34 |     total_mac = 0
35 |     for l in range(1, num_layers + 1):
36 |         layer = data[replace_precisions[0]]['replaced layer'][l]
37 |         total_mac += int(data[replace_precisions[0]][metric][l])
38 |         base_performance = int(data[replace_precisions[0]][metric][l]) * (base_precision ** po)
39 |         acc_layer = {}
40 |         performance = {}
41 |         Combinations[layer] = []
42 |         accLoss[layer] = {}
43 |         memorySaved[layer] = {}
44 |         for prec in replace_precisions:
45 |             acc_layer[prec] = data[prec][measurement][l]
46 |             performance[prec] = int(data[prec][metric][l]) * (prec ** po)
47 |             Combinations[layer].append(layer + '_{}W_{}A'.format(prec, prec))
48 |             if acc:
49 |                 accLoss[layer][layer + '_{}W_{}A'.format(prec, prec)] = max(base_accuracy - acc_layer[prec], 1e-6)
50 |             else:
51 |                 accLoss[layer][layer + '_{}W_{}A'.format(prec, prec)] = max(acc_layer[prec] - base_accuracy, 1e-6)
52 |             if noise is not None:
53 |                 accLoss[layer][layer + '_{}W_{}A'.format(prec, prec)] += noise * np.random.normal() * accLoss[layer][layer + '_{}W_{}A'.format(prec, prec)]
54 |             memorySaved[layer][layer + '_{}W_{}A'.format(prec, prec)] = base_performance - performance[prec]
55 |         Combinations[layer].append(layer + '_{}W_{}A'.format(base_precision, base_precision))
56 |         accLoss[layer][layer + '_{}W_{}A'.format(base_precision, base_precision)] = 0
57 |         memorySaved[layer][layer + '_{}W_{}A'.format(base_precision, base_precision)] = 0
58 |         Indicators[layer] = LpVariable.dicts("indicator" + layer, Combinations[layer], 0, 1, LpInteger)
59 |         S[layer] = LpVariable("S" + layer, 0)
60 |         DeltaL[layer] = LpVariable("DeltaL" + layer, 0)
61 | 
62 |     prob += lpSum([S[layer] for layer in S.keys()])  # Objective: minimize total degradation
63 | 
64 |     total_performance = total_mac * base_precision ** po
65 | 
66 |     for l in range(1, num_layers + 1):
67 |         layer = data[replace_precisions[0]]['replaced layer'][l]
68 |         prob += lpSum([Indicators[layer][i] * accLoss[layer][i] for i in Combinations[layer]]) == S[layer]  # Degradation incurred by this layer's choice
69 |         prob += lpSum([Indicators[layer][i] for i in Combinations[layer]]) == 1  # Exactly one precision per layer
70 |         prob += lpSum([Indicators[layer][i] * memorySaved[layer][i] for i in Combinations[layer]]) == DeltaL[layer]  # Memory saved by this layer's choice
71 | 
72 |     prob += lpSum([DeltaL[layer] for layer in DeltaL.keys()]) >= total_performance * (1 - Degradation * (32 / base_precision))  # Minimum total memory savings (compression target)
73 | 
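    # Editor's summary of the program assembled above: each layer gets a
    # one-hot choice over bit-width configurations ("exactly one precision
    # per layer"); S[layer] is the degradation that choice costs and
    # DeltaL[layer] the memory it saves. Minimizing sum(S) subject to
    #     sum(DeltaL) >= total_performance * (1 - Degradation * 32 / base_precision)
    # is a multiple-choice knapsack: meet the requested compression while
    # giving up as little accuracy (or loss) as possible.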
74 |     prob.solve()
75 |     print('Solver status:', LpStatus[prob.status])
76 | 
77 |     print('optimal solution for total degradation D = ' + str(Degradation) + ':')
78 |     if Debug:
79 |         for v in prob.variables():
80 |             print(v.name, "=", v.varValue)
81 | 
82 |     print('Objective value:', value(prob.objective))
83 | 
84 |     if prob.status == -1:
85 |         print('Infeasible')
86 | 
87 |     expected_acc_deg = sum([S[layer].varValue for layer in S.keys()])
88 |     reduced_performance = sum([DeltaL[layer].varValue for layer in DeltaL.keys()])
89 | 
90 |     sol = {}
91 |     memory_reduced = 0
92 |     acc_deg = 0
93 |     policy = []
94 |     all_precisions = replace_precisions + [base_precision]
95 |     total_params = {}
96 |     for prec in all_precisions:
97 |         total_params[prec] = 0
98 |     for l in range(1, num_layers + 1):
99 |         layer = data[replace_precisions[0]]['replaced layer'][l]
100 |         for prec in all_precisions:
101 |             if Indicators[layer][layer + '_{}W_{}A'.format(prec, prec)].varValue:
102 |                 policy.append(prec)
103 |                 sol[layer] = [prec, prec]
104 |                 memory_reduced += memorySaved[layer][layer + '_{}W_{}A'.format(prec, prec)]
105 |                 acc_deg += accLoss[layer][layer + '_{}W_{}A'.format(prec, prec)]
106 |                 total_params[prec] += int(data[replace_precisions[0]][metric][l])
107 | 
108 |     print('Final Solution: ', sol)
109 |     print('Policy: ', policy)
110 |     print('Achieved compression: ', (total_performance - memory_reduced) / (total_performance * (32 / base_precision)))
111 |     if acc:
112 |         expected_acc = base_accuracy - acc_deg
113 |     else:
114 |         expected_acc = base_accuracy + acc_deg
115 |     print('Expected acc: ', expected_acc)
116 |     for prec in all_precisions:
117 |         print('Params % in int {} = {}'.format(prec, total_params[prec] / total_mac))
118 | 
119 |     return sol, expected_acc, (total_performance - reduced_performance) / (total_performance * (32 / base_precision)), policy
120 | 
121 | 
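# Usage sketch (editorial; the CSV path below is hypothetical): each input
# file is a per-layer sensitivity table with columns 'base precision',
# 'replaced layer', 'Parameters Size [Elements]' and 'accuracy'/'loss',
# where row 0 holds the unperturbed baseline, e.g.
#
#     sol, acc, comp, policy = mpip_compression(
#         files=['results/resnet50_w4a4/per_layer.csv'],
#         replace_precisions=[4], Degradation=0.18, method='loss')
#
# `sol` maps each layer to its chosen [weight, activation] bit-widths.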
122 | def get_args():
123 |     parser = argparse.ArgumentParser(description='Mixed-precision integer-programming (IP) compression')
124 | 
125 |     parser.add_argument('--device-ids', default=[0], type=int, nargs='+',
126 |                         help='device ids assignment (e.g. 0 1 2 3)')
127 |     parser.add_argument('--ip_method', type=str, default='loss', help='IP optimization target, loss / acc')
128 |     parser.add_argument('--model', type=str, default='resnet', help='model to use')
129 |     parser.add_argument('--model_vis', type=str, default='resnet50', help='torchvision model name')
130 |     parser.add_argument('--num_exp', default=1, type=int, help='number of experiments per compression level')
131 |     parser.add_argument('--sigma', default=None, type=float, help='sigma noise to add to measurements')
132 |     parser.add_argument('--layer_by_layer_files', type=str, default='./results/resnet50_w8a8_adaquant/resnet.absorb_bn.measure.adaquant.per_layer_accuracy.csv', help='layer degradation csv file')
133 |     parser.add_argument('--datasets-dir', type=str, default='/media/drive/Datasets', help='dataset dir')
134 |     parser.add_argument('--precisions', type=str, default='8;4', help='precisions, base first, separated by ;')
135 |     parser.add_argument('--max_compression', type=float, default=0.25, help='max compression to test')
136 |     parser.add_argument('--min_compression', type=float, default=0.13, help='min compression to test')
137 |     parser.add_argument('--suffix', type=str, default='', help='suffix to add to all outputs')
138 |     parser.add_argument('--do_not_use_adaquant', action='store_true', default=False,
139 |                         help='use non optimized model')
140 |     parser.add_argument('--eval_on_train', action='store_true', default=False,
141 |                         help='evaluate on calibration data')
142 | 
143 |     args = parser.parse_args()
144 |     return args
145 | 
146 | 
147 | args = get_args()
148 | 
149 | compressions = np.arange(args.min_compression, args.max_compression, 0.01)
150 | sigma = args.sigma
151 | num_exp = args.num_exp
152 | ip_method = args.ip_method
153 | files = args.layer_by_layer_files.split(';')
154 | precisions = [int(i) for i in args.precisions.split(';')]
155 | replace_precisions = precisions[1:]
156 | datasets_dir = args.datasets_dir
157 | model = args.model
158 | model_vis = args.model_vis
159 | if args.do_not_use_adaquant:
160 |     workdirs = [os.path.join('results', model_vis + '_w{}a{}'.format(i, i)) for i in precisions]
161 | else:
162 |     workdirs = [os.path.join('results', model_vis + '_w{}a{}.adaquant'.format(i, i)) for i in precisions]
163 | eval_dir = os.path.join(workdirs[0], model + '.absorb_bn')
164 | 
165 | perC = True
166 | num_sp_layers = 0
167 | model_config = {'batch_norm': False, 'measure': False, 'perC': perC}
168 | if model_vis == 'resnet18':
169 |     model_config['depth'] = 18
170 | 
171 | output_fname = os.path.join(workdirs[0], 'IP_{}_{}{}.txt'.format(model_vis, ip_method, args.suffix))
172 | 
173 | eval_dict = {'model': model,
174 |              'evaluate': eval_dir,
175 |              'dataset': 'imagenet_calib',
176 |              'datasets_dir': datasets_dir,
177 |              'b': 100,
178 |              'model_config': model_config,
179 |              'mixed_builder': True,
180 |              'device_ids': args.device_ids,
181 |              'precisions': precisions}
182 | 
183 | if args.do_not_use_adaquant:
184 |     eval_dict['opt_model_paths'] = [os.path.join(dd, model + '.absorb_bn.measure_perC') for dd in workdirs]
185 | else:
186 |     eval_dict['opt_model_paths'] = [os.path.join(dd, model + '.absorb_bn.measure_perC.adaquant') for dd in workdirs]
187 | 
188 | if args.eval_on_train:
189 |     eval_dict['eval_on_train'] = True
190 | 
191 | solutions = []
192 | expected_accuracies = []
193 | state_dict_path = []
194 | actual_compressions = []
195 | actual_accuracies = []
196 | actual_losses = []
197 | policies = []
198 | completed = 0
199 | start_from = 0
200 | for Deg in compressions:
201 |     if completed < start_from:
202 |         completed += 1
203 |         solutions.append('')
204 |         state_dict_path.append('')
205 |         policies.append([])
206 |         expected_accuracies.append(0)
207 |         actual_compressions.append(0)
208 |         actual_accuracies.append(0)
209 |         actual_losses.append(0)
210 |         continue
211 |     attempted_policies = {}
212 |     valid_exp = 0
213 |     while valid_exp < num_exp:
214 |         if Debug:
215 |             import pdb; pdb.set_trace()
216 |         sol, expect_acc, comp, policy = mpip_compression(files=files, replace_precisions=replace_precisions, Degradation=Deg, noise=sigma, method=ip_method)
217 |         if str(policy) in attempted_policies:
218 |             continue
219 |         valid_exp += 1
220 | 
221 |         eval_dict['names_sp_layers'] = sol
222 |         eval_dict['suffix'] = 'comp_{}_{}{}'.format("{:.2f}".format(Deg), ip_method, args.suffix)
223 |         acc, loss = main_per_layer(**eval_dict)
224 | 
225 | 
226 | 
227 |         attempted_policies[str(policy)] = acc
228 | 
229 |         solutions.append(sol)
230 |         policies.append(policy.copy())
231 |         expected_accuracies.append(expect_acc)
232 |         actual_compressions.append(comp)
233 |         actual_accuracies.append(acc)
234 |         actual_losses.append(loss)
235 |         state_dict_path.append(eval_dict['evaluate'] + '.mixed-ip-results.' + eval_dict['suffix'])
236 |     completed += 1
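# Editor's note: the two reporting loops below emit the same records twice by
# design -- once to the console and once to a tab-separated file
# (output_fname) whose 'state_dict_path' / 'Configuration' columns are the
# ones read back by ip_config_parser.py.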
237 | c = 0
238 | for d in compressions:
239 |     for exp in range(num_exp):
240 |         if c >= completed:
241 |             break
242 |         print('Compression thr {}, experiment {}, state_dict_path {}, compression {}, expected {} {}, actual acc {}, actual loss {}'.format(
243 |             "{:.2f}".format(d), exp, state_dict_path[c], actual_compressions[c], ip_method, expected_accuracies[c],
244 |             actual_accuracies[c], actual_losses[c]))
245 |         print('Policy: {}'.format(policies[c]))
246 |         print('Configuration = {}'.format(solutions[c]))
247 |         c += 1
248 | 
249 | 
250 | with open(output_fname, 'w') as pid:
251 |     line = 'Compression thr\tExperiment\tstate_dict_path\tActual compression\tExpected {}\tActual Accuracy\tActual loss\tPolicy\tConfiguration\n'.format(ip_method)
252 |     pid.write(line)
253 |     c = 0
254 |     for Deg in compressions:
255 |         for exp in range(num_exp):
256 |             print('Compression thr {}, experiment {}, state_dict_path {}, actual_compression {}, expected {} {}, actual acc {}, actual loss {}'.format(Deg, exp, state_dict_path[c], actual_compressions[c], ip_method, expected_accuracies[c], actual_accuracies[c], actual_losses[c]))
257 |             print('Policy: {}'.format(policies[c]))
258 |             line = '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format("{:.2f}".format(Deg), exp, state_dict_path[c], actual_compressions[c], expected_accuracies[c], actual_accuracies[c], actual_losses[c], policies[c], solutions[c])
259 |             pid.write(line)
260 |             c += 1
261 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | import math
5 | from .modules.se import SEBlock
6 | from .modules.quantize import QConv2d, QLinear, RangeBN
7 | #from .modules.quantize import QConv2d_o as QConv2d
8 | #from .modules.quantize import QLinear_o as QLinear
9 | #from .modules.quantize import RangeBN
10 | __all__ = ['resnet', 'resnet_se']
11 | 
12 | class Lambda(nn.Module):
13 |     def __init__(self):
14 |         super(Lambda, self).__init__()
15 | 
16 |     def forward(self, x):  # identity, used when batch norm is disabled
17 |         return x
18 | 
19 | def depBatchNorm2d(exists, *kargs, **kwargs):
20 |     if exists:
21 |         return nn.BatchNorm2d(*kargs, **kwargs)
22 |     else:
23 |         return Lambda()
24 | 
25 | 
26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, bias=False, num_bits=8, num_bits_weight=8, measure=False, cal_qparams=False):
27 |     "3x3 convolution with padding"
28 |     return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 |                    padding=1, groups=groups, bias=bias, num_bits=num_bits, num_bits_weight=num_bits_weight, measure=measure, cal_qparams=cal_qparams)
30 | 
31 | 
32 | def init_model(model):
33 |     for m in model.modules():
34 |         if isinstance(m, QConv2d):
35 |             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
36 |             m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 37 | elif isinstance(m, nn.BatchNorm2d): 38 | m.weight.data.fill_(1) 39 | m.bias.data.zero_() 40 | for m in model.modules(): 41 | if isinstance(m, Bottleneck): 42 | nn.init.constant_(m.bn3.weight, 0) 43 | elif isinstance(m, BasicBlock): 44 | nn.init.constant_(m.bn2.weight, 0) 45 | 46 | model.fc.weight.data.normal_(0, 0.01) 47 | model.fc.bias.data.zero_() 48 | 49 | 50 | class BasicBlock(nn.Module): 51 | 52 | def __init__(self, inplanes, planes, stride=1, expansion=1, downsample=None, groups=1, residual_block=None,batch_norm=True,measure=False,num_bits=8,num_bits_weight=8, cal_qparams=False): 53 | super(BasicBlock, self).__init__() 54 | self.conv1 = conv3x3(inplanes, planes, stride, groups=groups,bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams) 55 | self.bn1 = depBatchNorm2d(batch_norm,planes) 56 | self.relu1 = nn.ReLU(inplace=False) 57 | self.conv2 = conv3x3(planes, expansion * planes, groups=groups,bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams) 58 | self.bn2 = depBatchNorm2d(batch_norm, expansion * planes) 59 | self.relu2 = nn.ReLU(inplace=False) 60 | self.downsample = downsample 61 | self.residual_block = residual_block 62 | self.stride = stride 63 | self.expansion = expansion 64 | 65 | def forward(self, x): 66 | residual = x 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | #import pdb; pdb.set_trace() 70 | out = self.relu1(out) 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(residual) 76 | 77 | if self.residual_block is not None: 78 | residual = self.residual_block(residual) 79 | out += residual 80 | out = self.relu2(out) 81 | return out 82 | 83 | 84 | class Bottleneck(nn.Module): 85 | 86 | def __init__(self, inplanes, planes, stride=1, expansion=4, downsample=None, groups=1, residual_block=None,batch_norm=True,measure=False,num_bits=8,num_bits_weight=8, cal_qparams=False): 87 | super(Bottleneck, self).__init__() 88 | self.conv1 = QConv2d( 89 | inplanes, planes, kernel_size=1, bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams) 90 | self.bn1 = depBatchNorm2d(batch_norm, planes) 91 | self.conv2 = conv3x3(planes, planes, stride=stride, groups=groups,bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams) 92 | self.bn2 = depBatchNorm2d(batch_norm, planes) 93 | self.conv3 = QConv2d( 94 | planes, planes * expansion, kernel_size=1, bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams) 95 | self.bn3 = depBatchNorm2d(batch_norm, planes * expansion) 96 | self.relu1 = nn.ReLU(inplace=False) 97 | self.relu2 = nn.ReLU(inplace=False) 98 | self.relu3 = nn.ReLU(inplace=False) 99 | self.downsample = downsample 100 | self.residual_block = residual_block 101 | self.stride = stride 102 | self.expansion = expansion 103 | 104 | def forward(self, x): 105 | residual = x 106 | out = self.conv1(x) 107 | out = self.bn1(out) 108 | out = self.relu1(out) 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | out = self.relu2(out) 112 | 113 | out = self.conv3(out) 114 | out = self.bn3(out) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(residual) 118 | 119 | if self.residual_block is not None: 120 | residual = self.residual_block(residual) 121 | 122 | out += residual 123 | out = self.relu3(out) 124 
| 125 | return out 126 | 127 | 128 | class ResNet(nn.Module): 129 | 130 | def __init__(self): 131 | super(ResNet, self).__init__() 132 | 133 | def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1, residual_block=None, batch_norm=True, num_bits=8, num_bits_weight=8, perC=True,measure=False, cal_qparams=False): 134 | downsample = None 135 | out_planes = planes * expansion 136 | if stride != 1 or self.inplanes != out_planes: 137 | downsample = nn.Sequential( 138 | QConv2d(self.inplanes, out_planes, 139 | kernel_size=1, stride=stride, bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams), 140 | depBatchNorm2d(batch_norm,planes * expansion), 141 | ) 142 | if residual_block is not None: 143 | residual_block = residual_block(out_planes) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, expansion=expansion, 147 | downsample=downsample, groups=groups, residual_block=residual_block,batch_norm=batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams)) 148 | self.inplanes = planes * expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups, 151 | residual_block=residual_block,batch_norm=batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,measure=measure, cal_qparams=cal_qparams)) 152 | 153 | return nn.Sequential(*layers) 154 | 155 | def features(self, x): 156 | x = self.conv1(x) 157 | x = self.bn1(x) 158 | x = self.relu(x) 159 | x = self.maxpool(x) 160 | 161 | x = self.layer1(x) 162 | x = self.layer2(x) 163 | x = self.layer3(x) 164 | x = self.layer4(x) 165 | 166 | x = self.avgpool(x) 167 | return x.view(x.size(0), -1) 168 | 169 | def forward(self, x): 170 | x = self.features(x) 171 | x = self.fc(x) 172 | return x 173 | 174 | @staticmethod 175 | def regularization_pre_step(model, weight_decay=1e-4): 176 | with torch.no_grad(): 177 | for m in model.modules(): 178 | if isinstance(m, QConv2d) or isinstance(m, nn.Linear): 179 | if m.weight.grad is not None: 180 | m.weight.grad.add_(weight_decay * m.weight) 181 | return 0 182 | 183 | 184 | class ResNet_imagenet(ResNet): 185 | 186 | def __init__(self, num_classes=1000, inplanes=64, 187 | block=Bottleneck, residual_block=None, layers=[3, 4, 23, 3], 188 | width=[64, 128, 256, 512], expansion=4, groups=[1, 1, 1, 1], 189 | regime='normal', scale_lr=1,batch_norm=True,num_bits=8,num_bits_weight=8, perC=True, measure=False, cal_qparams=False): 190 | super(ResNet_imagenet, self).__init__() 191 | self.inplanes = inplanes 192 | self.conv1 = QConv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 193 | bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight, perC=perC ,measure=measure, cal_qparams=cal_qparams) 194 | self.bn1 = depBatchNorm2d(batch_norm,self.inplanes) 195 | self.relu = nn.ReLU(inplace=False) 196 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 197 | for i in range(len(layers)): 198 | #if i==2 or i==1: 199 | # print(i) 200 | # num_bits = 4 201 | # num_bits_weight = 4 202 | setattr(self, 'layer%s' % str(i + 1), 203 | self._make_layer(block=block, planes=width[i], blocks=layers[i], expansion=expansion, 204 | stride=1 if i == 0 else 2, residual_block=residual_block, groups=groups[i],batch_norm=batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,perC=perC, measure=measure, cal_qparams=cal_qparams)) 205 | 206 | self.avgpool = nn.AdaptiveAvgPool2d(1) 207 | self.fc = QLinear(width[-1] 
* expansion, num_classes,num_bits_weight=num_bits_weight,perC=perC,measure=measure, cal_qparams=cal_qparams) 208 | if batch_norm: 209 | init_model(self) 210 | 211 | def ramp_up_lr(lr0, lrT, T): 212 | rate = (lrT - lr0) / T 213 | return "lambda t: {'lr': %s + t * %s}" % (lr0, rate) 214 | if regime == 'normal': 215 | self.regime = [ 216 | {'epoch': 0, 'optimizer': 'SGD', 'momentum': 0.9, 217 | 'step_lambda': ramp_up_lr(0.1, 0.1 * scale_lr, 5004 * 5 / scale_lr)}, 218 | {'epoch': 5, 'lr': scale_lr * 1e-1}, 219 | {'epoch': 30, 'lr': scale_lr * 1e-2}, 220 | {'epoch': 60, 'lr': scale_lr * 1e-3}, 221 | {'epoch': 80, 'lr': scale_lr * 1e-4} 222 | ] 223 | elif regime == 'fast': 224 | self.regime = [ 225 | {'epoch': 0, 'optimizer': 'SGD', 'momentum': 0.9, 226 | 'step_lambda': ramp_up_lr(0.1, 0.1 * 4 * scale_lr, 5004 * 4 / (4 * scale_lr))}, 227 | {'epoch': 4, 'lr': 4 * scale_lr * 1e-1}, 228 | {'epoch': 18, 'lr': scale_lr * 1e-1}, 229 | {'epoch': 21, 'lr': scale_lr * 1e-2}, 230 | {'epoch': 35, 'lr': scale_lr * 1e-3}, 231 | {'epoch': 43, 'lr': scale_lr * 1e-4}, 232 | ] 233 | self.data_regime = [ 234 | {'epoch': 0, 'input_size': 128, 'batch_size': 256}, 235 | {'epoch': 18, 'input_size': 224, 'batch_size': 64}, 236 | {'epoch': 41, 'input_size': 288, 'batch_size': 32}, 237 | ] 238 | elif regime == 'small': 239 | scale_lr *= 4 240 | self.regime = [ 241 | {'epoch': 0, 'optimizer': 'SGD', 242 | 'momentum': 0.9, 'lr': scale_lr * 1e-1}, 243 | {'epoch': 30, 'lr': scale_lr * 1e-2}, 244 | {'epoch': 60, 'lr': scale_lr * 1e-3}, 245 | {'epoch': 80, 'lr': scale_lr * 1e-4} 246 | ] 247 | self.data_regime = [ 248 | {'epoch': 0, 'input_size': 128, 'batch_size': 256}, 249 | {'epoch': 80, 'input_size': 224, 'batch_size': 64}, 250 | ] 251 | self.data_eval_regime = [ 252 | {'epoch': 0, 'input_size': 128, 'batch_size': 1024}, 253 | {'epoch': 80, 'input_size': 224, 'batch_size': 512}, 254 | ] 255 | 256 | 257 | class ResNet_cifar(ResNet): 258 | 259 | def __init__(self, num_classes=10, inplanes=16, 260 | block=BasicBlock, depth=18, width=[16, 32, 64], 261 | groups=[1, 1, 1], residual_block=None,batch_norm=True, num_bits=8, num_bits_weight=8, perC=True, measure=False, cal_qparams=False): 262 | super(ResNet_cifar, self).__init__() 263 | #inplanes=4 264 | #width=[4, 8, 16] 265 | self.inplanes = inplanes 266 | n = int((depth - 2) / 6) 267 | self.conv1 = QConv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 268 | bias=not batch_norm,num_bits=num_bits,num_bits_weight=num_bits_weight,perC=perC, measure=measure) 269 | self.bn1 = depBatchNorm2d(batch_norm,self.inplanes) 270 | self.relu = nn.ReLU(inplace=False) 271 | self.maxpool = lambda x: x 272 | self.layer1 = self._make_layer(block, width[0], n, groups=groups[ 273 | 0], residual_block=residual_block,batch_norm=batch_norm, num_bits=num_bits,num_bits_weight=num_bits_weight,perC=perC, measure=measure, cal_qparams=cal_qparams) 274 | self.layer2 = self._make_layer( 275 | block, width[1], n, stride=2, groups=groups[1], residual_block=residual_block,batch_norm=batch_norm, num_bits=num_bits,num_bits_weight=num_bits_weight,perC=perC, measure=measure, cal_qparams=cal_qparams) 276 | self.layer3 = self._make_layer( 277 | block, width[2], n, stride=2, groups=groups[2], residual_block=residual_block,batch_norm=batch_norm, num_bits=num_bits,num_bits_weight=num_bits_weight,perC=perC, measure=measure, cal_qparams=cal_qparams) 278 | self.layer4 = lambda x: x 279 | self.avgpool = nn.AvgPool2d(8) 280 | self.fc = nn.Linear(width[-1], num_classes) 281 | if batch_norm: 282 | init_model(self) 283 | 
self.regime = [ 284 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 285 | 'weight_decay': 0, 'momentum': 0.9}, 286 | {'epoch': 81, 'lr': 1e-2}, 287 | {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0}, 288 | {'epoch': 164, 'lr': 1e-4} 289 | ] 290 | 291 | 292 | def resnet(**config): 293 | dataset = config.pop('dataset', 'imagenet') 294 | 295 | bn_norm = config.pop('bn_norm', None) 296 | if bn_norm is not None: 297 | from .modules.lp_norm import L1BatchNorm2d, TopkBatchNorm2d 298 | if bn_norm == 'L1': 299 | torch.nn.BatchNorm2d = L1BatchNorm2d 300 | if bn_norm == 'TopK': 301 | torch.nn.BatchNorm2d = TopkBatchNorm2d 302 | 303 | if dataset == 'imagenet': 304 | config.setdefault('num_classes', 1000) 305 | depth = config.pop('depth', 50) 306 | if depth == 18: 307 | config.update(dict(block=BasicBlock, 308 | layers=[2, 2, 2, 2], 309 | expansion=1)) 310 | if depth == 34: 311 | config.update(dict(block=BasicBlock, 312 | layers=[3, 4, 6, 3], 313 | expansion=1)) 314 | if depth == 50: 315 | config.update(dict(block=Bottleneck, layers=[3, 4, 6, 3])) 316 | if depth == 101: 317 | config.update(dict(block=Bottleneck, layers=[3, 4, 23, 3])) 318 | if depth == 152: 319 | config.update(dict(block=Bottleneck, layers=[3, 8, 36, 3])) 320 | 321 | return ResNet_imagenet(**config) 322 | 323 | elif dataset == 'cifar10': 324 | config.setdefault('num_classes', 10) 325 | config.setdefault('depth', 44) 326 | return ResNet_cifar(block=BasicBlock, **config) 327 | 328 | elif dataset == 'cifar100': 329 | config.setdefault('num_classes', 100) 330 | config.setdefault('depth', 44) 331 | return ResNet_cifar(block=BasicBlock, **config) 332 | 333 | 334 | def resnet_se(**config): 335 | config['residual_block'] = SEBlock 336 | return resnet(**config) 337 | 338 | --------------------------------------------------------------------------------
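A minimal construction sketch (editorial, not part of the repository; it assumes the quantized modules imported by models/resnet.py are importable, and the values are illustrative): building the quantized ResNet-50 used throughout the pipeline above and running a dummy forward pass.

import torch
from models.resnet import resnet

# ResNet-50 with 8-bit weights/activations and per-channel quantization
# parameters, matching the defaults in the scripts above.
model = resnet(dataset='imagenet', depth=50, batch_norm=True,
               num_bits=8, num_bits_weight=8, perC=True, measure=False)
x = torch.randn(2, 3, 224, 224)   # dummy ImageNet-sized batch
print(model(x).shape)             # torch.Size([2, 1000])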