├── .gitignore
├── README.md
├── __init__.py
├── clustering
│   ├── __init__.py
│   ├── kmeans.py
│   └── pairwise.py
├── data
│   ├── resnet18_l0l1_WNoneA2.pkl
│   ├── resnet18_l0l1_WNoneA3.pkl
│   ├── resnet18_l0l1_WNoneA4.pkl
│   └── results
│       └── cd_vs_powell.csv
├── experimets
│   ├── ablation_cd.py
│   ├── ablation_powell.py
│   ├── ablation_powell_l3.5.py
│   ├── ablation_powell_l3.py
│   ├── ablation_powell_l4.py
│   ├── ablation_separability.py
│   ├── ablation_study.py
│   ├── cd_vs_powell.py
│   ├── cd_vs_powell_l1.py
│   ├── cd_vs_powell_l3.py
│   ├── lapq_init_study.py
│   └── moments.py
├── fig
│   ├── cd_vs_powell_32W2A.png
│   ├── cd_vs_powell_3W32A.png
│   ├── cd_vs_powell_4W32A.png
│   ├── contour_2bit.png
│   ├── contour_4bit.png
│   ├── heatmap_legend.png
│   ├── lapq_intuition.pdf
│   ├── lpnorm_err_2bit.pdf
│   ├── lpnorm_err_2bit.png
│   ├── lpnorm_err_4bit.pdf
│   ├── lpnorm_err_4bit.png
│   ├── p_vs_lpnorm.pdf
│   ├── p_vs_opt_clipping.pdf
│   ├── quant_matmul_sep_2bit.pdf
│   ├── quant_matmul_sep_4bit.pdf
│   ├── res18_CalibrationVsAcc.pdf
│   ├── res18_l14l15_contour.pdf
│   ├── res18_l14l15_contour.png
│   ├── res18_weights_thresholds.png
│   ├── res50_p_vs_acc.pdf
│   ├── res50_p_vs_acc.png
│   ├── res50_res18_pnorms.pdf
│   ├── res50_res18_pnorms.png
│   ├── resnet18_A2_3d.png
│   ├── resnet18_A2_3d_colorfull.pdf
│   ├── resnet18_A2_3d_colorfull.png
│   ├── resnet18_A2_3d_zoom.png
│   ├── resnet18_act_2bit_hessian.pdf
│   ├── resnet18_act_2bit_hessian.png
│   ├── resnet18_act_2bit_hessian_norm.png
│   ├── resnet18_act_4bit_hessian.pdf
│   ├── resnet18_act_4bit_hessian.png
│   ├── resnet18_act_4bit_hessian_norm.png
│   ├── resnet18_l0l1_WNoneA2.pdf
│   ├── resnet18_l0l1_WNoneA2.png
│   ├── resnet18_l0l1_WNoneA2_3d.png
│   ├── resnet18_l0l1_WNoneA2_topographic_zoom.png
│   ├── resnet18_l0l1_WNoneA3.pdf
│   ├── resnet18_l0l1_WNoneA3.png
│   ├── resnet18_l0l1_WNoneA4.pdf
│   ├── resnet18_l0l1_WNoneA4.png
│   ├── resnet18_loss_vs_p.pdf
│   ├── resnet18_quadratic_loss_vs_distance.pdf
│   └── resnet50_loss_vs_p.pdf
├── jupyter
│   ├── Ablation_CalibrationVsAcc.ipynb
│   ├── anova_separability.ipynb
│   ├── cvpr2020
│   │   ├── CDvsPowell.ipynb
│   │   ├── hessian_matrix.ipynb
│   │   └── scale_comparison.ipynb
│   ├── error_separability.ipynb
│   ├── lapq_intuition.ipynb
│   ├── loss_mesh_grid_act.ipynb
│   ├── loss_mesh_grid_weights.ipynb
│   ├── loss_visualization.ipynb
│   ├── loss_vs_p_quadratic.ipynb
│   ├── lp_norm_vs_clipping.ipynb
│   ├── lp_vs_acc.ipynb
│   ├── mse_analysis.ipynb
│   ├── mse_optimization_sgd.ipynb
│   ├── ncf_data_proccessing.ipynb
│   ├── quadratic_loss_opt.ipynb
│   ├── quant_sep.ipynb
│   ├── separability_of_quantization_act.ipynb
│   └── separability_of_quantization_weight.ipynb
├── lapq
│   ├── README.md
│   ├── lapq_linear_parametrization.py
│   ├── lapq_v2.py
│   ├── layer_scale_optimization.py
│   └── layer_scale_optimization_opt.py
├── models
│   ├── ShuffleNet.py
│   ├── __init__.py
│   ├── inception.py
│   └── resnet.py
├── quantization
│   ├── __init__.py
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── loss_data_generation.py
│   │   ├── loss_parametrization.py
│   │   ├── loss_parametrization1.py
│   │   ├── loss_vs_p.py
│   │   └── separability_index.py
│   ├── methods
│   │   ├── __init__.py
│   │   ├── clipped_uniform.py
│   │   ├── non_uniform.py
│   │   ├── stochastic.py
│   │   └── uniform.py
│   ├── posttraining
│   │   ├── __init__.py
│   │   ├── cnn_classifier.py
│   │   ├── cnn_classifier_inference.py
│   │   └── module_wrapper.py
│   ├── qat
│   │   ├── __init__.py
│   │   ├── cnn_classifier_train.py
│   │   └── module_wrapper.py
│   └── quantizer.py
└── utils
    ├── __init__.py
    ├── absorb_bn.py
    ├── data.py
    ├── entropy.py
    ├── experiments_log.py
    ├── log.py
    ├── mark_relu.py
    ├── meters.py
    ├── misc.py
    ├── mllog.py
    ├── model_naming.py
    ├── preprocess.py
    └── stats_trucker.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | .pytest_cache
4 | *.tar
5 | venv/
6 | venv3/
7 | .env/
8 | .idea/
9 | .idea/*
10 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nn-quantization-pytorch
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/__init__.py
--------------------------------------------------------------------------------
/clustering/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/clustering/__init__.py
--------------------------------------------------------------------------------
/clustering/kmeans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from clustering.pairwise import pairwise_distance
4 | 
5 | def forgy(X, n_clusters):
6 |     _len = len(X)
7 |     indices = np.random.choice(_len, n_clusters)  # note: samples with replacement
8 |     initial_state = X[indices]
9 |     return initial_state
10 | 
11 | 
12 | def lloyd1d(X, n_clusters, tol=1e-4, device=None, max_iter=100, init_state=None):
13 |     if device is not None:
14 |         X = X.to(device)
15 | 
16 |     if init_state is None:
17 |         initial_state = forgy(X, n_clusters).flatten()
18 |     else:
19 |         initial_state = init_state.clone()
20 | 
21 |     iter = 0
22 |     dis = X.new_empty((n_clusters, X.numel()))
23 |     choice_cluster = X.new_empty(X.numel()).int()
24 |     centers = torch.arange(n_clusters, device=X.device).view(-1, 1).int()
25 |     initial_state_pre = initial_state.clone()
26 |     # temp = X.new_empty((n_clusters, X.numel()))
27 |     while iter < max_iter:
28 |         iter += 1
29 | 
30 |         # Calculate pair wise distance
31 |         dis[:, ] = X.view(1, -1)
32 |         dis.sub_(initial_state.view(-1, 1))
33 |         dis.pow_(2)
34 | 
35 |         choice_cluster[:] = torch.argmin(dis, dim=0).int()
36 | 
37 |         initial_state_pre[:] = initial_state
38 | 
39 |         temp = X.view(1, -1) * (choice_cluster == centers).float()
40 |         initial_state[:] = temp.sum(1) / (temp != 0).sum(1).float()  # cluster mean over non-zero members; an empty cluster yields NaN here
41 | 
42 |         # center_shift = torch.sum(torch.sqrt(torch.sum((initial_state - initial_state_pre) ** 2, dim=1)))
43 |         center_shift = torch.sqrt(torch.sum((initial_state - initial_state_pre) ** 2))
44 | 
45 |         if center_shift < tol:
46 |             break
47 | 
48 |     return choice_cluster, initial_state
49 | 
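A minimal usage sketch of lloyd1d (the tensor sizes, values, and variable names below are illustrative, not from this repository):

    import torch
    from clustering.kmeans import lloyd1d

    # cluster a flattened weight tensor into 4 centroids,
    # e.g. candidate levels for a 2-bit non-uniform quantizer
    w = torch.randn(10000)
    assignments, centroids = lloyd1d(w, n_clusters=4, tol=1e-4, max_iter=100)
    # assignments[i] is the centroid index of w[i]; quantization is then a lookup
    w_q = centroids[assignments.long()]

Note that forgy draws initial centers with replacement, so duplicate initial centers are possible for small inputs.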
--------------------------------------------------------------------------------
/clustering/pairwise.py:
--------------------------------------------------------------------------------
1 | r'''
2 | calculation of pairwise distance, and return a condensed result, i.e. we omit the diagonal and duplicate entries and store everything in a one-dimensional array (note: pairwise_distance below actually returns the full, squeezed N*N matrix)
3 | '''
4 | import torch
5 | 
6 | def pairwise_distance(data1, data2=None, device=-1):
7 |     r'''
8 |     using the broadcast mechanism to calculate the pairwise (squared) euclidean distance of the data
9 |     the input data is an N*M matrix, where M is the dimension
10 |     we first expand the N*M matrix into an N*1*M matrix A and a 1*N*M matrix B
11 |     then a simple elementwise operation on A and B handles the pairwise operation over the points represented by data
12 |     '''
13 |     if data2 is None:
14 |         data2 = data1
15 | 
16 |     if device!=-1:
17 |         data1, data2 = data1.cuda(device), data2.cuda(device)
18 | 
19 |     #N*1*M
20 |     A = data1.unsqueeze(dim=1)
21 | 
22 |     #1*N*M
23 |     B = data2.unsqueeze(dim=0)
24 | 
25 |     dis = (A-B)**2.0
26 |     #return N*N matrix for pairwise distance
27 |     dis = dis.sum(dim=-1).squeeze()
28 |     return dis
29 | 
30 | def group_pairwise(X, groups, device=0, fun=lambda r,c: pairwise_distance(r, c).cpu()):
31 |     group_dict = {}
32 |     for group_index_r, group_r in enumerate(groups):
33 |         for group_index_c, group_c in enumerate(groups):
34 |             R, C = X[group_r], X[group_c]
35 |             if device!=-1:
36 |                 R = R.cuda(device)
37 |                 C = C.cuda(device)
38 |             group_dict[(group_index_r, group_index_c)] = fun(R, C)
39 |     return group_dict
40 | 
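A minimal usage sketch of pairwise_distance (the shapes, values, and variable names are illustrative, not from this repository):

    import torch
    from clustering.pairwise import pairwise_distance

    # two point sets with M=3 features; broadcasting N1*1*M against 1*N2*M
    # yields the full N1*N2 matrix of squared euclidean distances
    a = torch.randn(5, 3)
    b = torch.randn(8, 3)
    d = pairwise_distance(a, b)  # shape: (5, 8)

Because of the trailing squeeze(), a single-row input collapses the result to a 1-D tensor.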
--------------------------------------------------------------------------------
/data/resnet18_l0l1_WNoneA2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/data/resnet18_l0l1_WNoneA2.pkl
--------------------------------------------------------------------------------
/data/resnet18_l0l1_WNoneA3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/data/resnet18_l0l1_WNoneA3.pkl
--------------------------------------------------------------------------------
/data/resnet18_l0l1_WNoneA4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/data/resnet18_l0l1_WNoneA4.pkl
--------------------------------------------------------------------------------
/data/results/cd_vs_powell.csv:
--------------------------------------------------------------------------------
1 | it,Bits weights,Bits act,l1.CD,l1.Powell,l2.CD,l2.Powell,l3.CD,l3.Powell
2 | init,32,4,58.60%,58.60%,68.10%,68.10%,68.20%,68.60%
3 | iter 1,32,4,67.80%,67.80%,68.70%,68.10%,68.80%,68.20%
4 | iter 2,32,4,68.40%,68.60%,68.70%,68.20%,68.20%,68.20%
5 | init,32,3,38.50%,38.50%,63.30%,63.30%,65.70%,65.70%
6 | iter 1,32,3,64.30%,64.60%,66.40%,64.80%,65.70%,66.20%
7 | iter 2,32,3,66.20%,64.60%,65.40%,65.70%,65.00%,66.20%
8 | init,32,2,4.70%,4.70%,33.00%,33.00%,47.80%,47.80%
9 | iter 1,32,2,23.90%,23.90%,45.60%,47.80%,51.20%,49.00%
10 | iter 2,32,2,41.70%,41.70%,51.80%,51.00%,48.90%,49.10%
11 | init,4,32,0.50%,0.50%,48.70%,48.70%,57.20%,57.20%
12 | iter 1,4,32,34.20%,36.20%,50.00%,61.20%,56.20%,62.20%
13 | iter 2,4,32,54.80%,58.50%,59.70%,62.10%,53.30%,62.60%
14 | init,3,32,0.10%,0.10%,4.00%,4.00%,19.80%,19.80%
15 | iter 1,3,32,0.10%,0.10%,36.30%,36.20%,0.30%,37.90%
16 | iter 2,3,32,0.10%,0.10%,33.00%,42.30%,0.70%,40.70%
17 | init,4,4,0.30%,0.30%,43.60%,43.60%,55.40%,55.40%
18 | iter 1,4,4,0.10%,20.70%,54.30%,60.20%,56.30%,58.50%
19 | iter 2,4,4,0.10%,45.00%,55.30%,60.20%,57.40%,58.60%
--------------------------------------------------------------------------------
/experimets/ablation_cd.py:
--------------------------------------------------------------------------------
1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | {'model': 'resnet18', 'bs': 256, 'dev': [0]}, 11 | # {'model': 'resnet50', 'bs': 128, 'dev': [5]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'coord_2it' 16 | # qtypes = ['l2_norm'] 17 | qtypes = ['l2_norm', 'aciq_laplace', 'l3_norm', 'max_static'] 18 | 19 | for mset in models_set: 20 | for qt in qtypes: 21 | for bits in [2, 3, 4]: 22 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 23 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 24 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 25 | + ['-ba', str(bits)] + ['--qtype', qt] + "--min_method CD -maxi 2 --init_method dynamic".split(" ") 26 | ) 27 | 28 | for bits in [3, 4]: 29 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 30 | mset['bs'])] 31 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 32 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 33 | + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method CD -maxi 2 --init_method dynamic".split( 34 | " ") 35 | ) 36 | 37 | for bits in [4]: 38 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 39 | mset['bs'])] 40 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 41 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 42 | + ['-ba', str(bits)] + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method CD -maxi 2 --init_method dynamic".split( 43 | " ") 44 | ) 45 | 
--------------------------------------------------------------------------------
/experimets/ablation_powell.py:
--------------------------------------------------------------------------------
1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | # {'model': 'resnet18', 'bs': 512, 'dev': [0]}, 11 | {'model': 'resnet50', 'bs': 128, 'dev': [0]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'powell_res50_l2' 16 | qtypes = ['l2_norm'] 17 | # qtypes = ['l2_norm', 'aciq_laplace', 'l3_norm'] 18 | maxiter = 1 19 | 20 | for mset in models_set: 21 | for qt in qtypes: 22 | for bits in [2, 3, 4]: 23 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 24 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 25 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 26 | + ['-ba', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512".split(" ") 27 
| ) 28 | 29 | for bits in [3, 4]: 30 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 31 | mset['bs'])] 32 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 33 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 34 | + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512".split( 35 | " ") 36 | ) 37 | 38 | for bits in [4]: 39 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 40 | mset['bs'])] 41 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 42 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 43 | + ['-ba', str(bits)] + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512".split( 44 | " ") 45 | ) 46 | -------------------------------------------------------------------------------- /experimets/ablation_powell_l3.5.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | # {'model': 'resnet18', 'bs': 512, 'dev': [0]}, 11 | {'model': 'resnet50', 'bs': 128, 'dev': [0]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'powell_res50_l3.5' 16 | qtypes = ['lp_norm'] 17 | # qtypes = ['l2_norm', 'aciq_laplace', 'l3_norm'] 18 | maxiter = 1 19 | 20 | for mset in models_set: 21 | for qt in qtypes: 22 | for bits in [2, 3, 4]: 23 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 24 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 25 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 26 | + ['-ba', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512 -lp 3.5".split(" ") 27 | ) 28 | 29 | for bits in [3, 4]: 30 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 31 | mset['bs'])] 32 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 33 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 34 | + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512 -lp 3.5".split( 35 | " ") 36 | ) 37 | 38 | for bits in [4]: 39 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 40 | mset['bs'])] 41 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 42 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 43 | + ['-ba', str(bits)] + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512 -lp 3.5".split( 44 | " ") 45 | ) 46 | -------------------------------------------------------------------------------- /experimets/ablation_powell_l3.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | 
import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | # {'model': 'resnet18', 'bs': 512, 'dev': [0]}, 11 | {'model': 'resnet50', 'bs': 128, 'dev': [0]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'powell_res50_l3' 16 | qtypes = ['l3_norm'] 17 | # qtypes = ['l2_norm', 'aciq_laplace', 'l3_norm'] 18 | maxiter = 1 19 | 20 | for mset in models_set: 21 | for qt in qtypes: 22 | for bits in [2, 3, 4]: 23 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 24 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 25 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 26 | + ['-ba', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512".split(" ") 27 | ) 28 | 29 | for bits in [3, 4]: 30 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 31 | mset['bs'])] 32 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 33 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 34 | + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512".split( 35 | " ") 36 | ) 37 | 38 | for bits in [4]: 39 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 40 | mset['bs'])] 41 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 42 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 43 | + ['-ba', str(bits)] + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512".split( 44 | " ") 45 | ) 46 | -------------------------------------------------------------------------------- /experimets/ablation_powell_l4.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | # {'model': 'resnet18', 'bs': 512, 'dev': [0]}, 11 | {'model': 'resnet50', 'bs': 128, 'dev': [0]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'powell_res50_l4' 16 | qtypes = ['lp_norm'] 17 | # qtypes = ['l2_norm', 'aciq_laplace', 'l3_norm'] 18 | maxiter = 1 19 | 20 | for mset in models_set: 21 | for qt in qtypes: 22 | for bits in [2, 3, 4]: 23 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 24 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 25 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 26 | + ['-ba', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512 -lp 4".split(" ") 27 | ) 28 | 29 | for bits in [3, 4]: 30 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 31 | mset['bs'])] 32 | + 
['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 33 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 34 | + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512 -lp 4".split( 35 | " ") 36 | ) 37 | 38 | for bits in [4]: 39 | run(["python", "quantization/posttraining/layer_scale_optimization.py"] + ['-a', mset['model']] + ['-b', str( 40 | mset['bs'])] 41 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 42 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] + ['-maxi', str(maxiter)] 43 | + ['-ba', str(bits)] + ['-bw', str(bits)] + ['--qtype', qt] + "--min_method Powell --init_method dynamic -cs 512 -lp 4".split( 44 | " ") 45 | ) 46 | -------------------------------------------------------------------------------- /experimets/ablation_separability.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--act_or_weights', '-aow', help='Quantize activations or weights [a, w]', default='a') 8 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 9 | args = parser.parse_args() 10 | 11 | models_set = [ 12 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 13 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 14 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 15 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 16 | {'model': 'resnet18', 'bs': 64, 'dev': [0]}, 17 | # {'model': 'resnet50', 'bs': 128, 'dev': [5]}, 18 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 19 | ] 20 | 21 | layer_type = '-ba' if args.act_or_weights == 'a' else '-bw' 22 | for mset in models_set: 23 | for bits in [7, 8]: 24 | run(["python", "quantization/analysis/separability_index.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 25 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 26 | + "--pretrained --custom_resnet".split(" ") + ['-exp', args.experiment] 27 | + [layer_type, str(bits)] + "-i 7 -n 1000".split(" ") 28 | ) 29 | -------------------------------------------------------------------------------- /experimets/ablation_study.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | 6 | models_set = [ 7 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 9 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 10 | # {'model': 'alexnet', 'bs': 512, 'dev': [5]}, 11 | {'model': 'resnet18', 'bs': 512, 'dev': [5]}, 12 | # {'model': 'resnet50', 'bs': 128, 'dev': [5]}, 13 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 14 | ] 15 | 16 | mset = models_set[0] 17 | exp_name = 'res18_u_eval_lp' 18 | 19 | qtypes = ['lp_norm'] 20 | # qtypes = ['aciq_laplace', 'aciq_gaus', 'mse_direct', 'mse_uniform_prior', 'mse_direct_no_prior'] 21 | # qtypes = ['aciq_laplace', 'aciq_gaus', 'mse_direct', 'mse_decomp', 'mse_quant_est', 'max_static'] 22 | 23 | for p in np.linspace(2, 4, 21): 24 | for bits in [2]: 25 | for qt in qtypes: 26 | run(["python", "quantization/posttraining/cnn_classifier_inference.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 27 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" 
") 28 | + "--pretrained --custom_resnet -sh -q".split(" ") + ['-exp', exp_name] 29 | + ['--qtype', qt] + ['-ba', str(bits)] + ['-lp', str(p)] 30 | ) 31 | -------------------------------------------------------------------------------- /experimets/cd_vs_powell.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | {'model': 'resnet18', 'bs': 256, 'dev': [0]}, 11 | # {'model': 'resnet50', 'bs': 128, 'dev': [5]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'cd_vs_powell_res18_bcorr' 16 | 17 | for mset in models_set: 18 | # for bits in [2, 3, 4]: 19 | # run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 20 | # + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 21 | # + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 22 | # + ['-ba', str(bits)] + "--min_method Powell -maxi 2 -cs 512".split(" ") 23 | # ) 24 | 25 | 26 | for bits in [3, 4]: 27 | run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str( 28 | mset['bs'])] 29 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 30 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 31 | + ['-bw', str(bits)] + "--min_method Powell -maxi 2 -cs 512 -bcw".split( 32 | " ") 33 | ) 34 | 35 | for bits in [4]: 36 | run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str( 37 | mset['bs'])] 38 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 39 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 40 | + ['-ba', str(bits)] + ['-bw', str(bits)] + "--min_method Powell -maxi 2 -cs 512 -bcw".split( 41 | " ") 42 | ) 43 | -------------------------------------------------------------------------------- /experimets/cd_vs_powell_l1.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | # {'model': 'resnet18', 'bs': 256, 'dev': [0]}, 11 | {'model': 'resnet50', 'bs': 128, 'dev': [0]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'cd_vs_powell_res50_bcorr' 16 | 17 | for mset in models_set: 18 | # for bits in [2, 4]: 19 | # run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 20 | # + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 21 | # + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 22 | # + ['-ba', str(bits)] + "--min_method Powell -maxi 1 -cs 512".split(" ") 23 | # ) 24 | 25 | for bits in [4]: 26 | run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str( 27 | mset['bs'])] 28 | + ['--dataset', 
'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 29 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 30 | + ['-bw', str(bits)] + "--min_method Powell -maxi 1 -cs 512 -bcw".split( 31 | " ") 32 | ) 33 | 34 | for bits in [4]: 35 | run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str( 36 | mset['bs'])] 37 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 38 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 39 | + ['-ba', str(bits)] + ['-bw', str(bits)] + "--min_method Powell -maxi 1 -cs 512 -bcw".split( 40 | " ") 41 | ) 42 | -------------------------------------------------------------------------------- /experimets/cd_vs_powell_l3.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | {'model': 'mobilenet_v2', 'bs': 128, 'dev': [0]}, 10 | # {'model': 'resnet18', 'bs': 256, 'dev': [0]}, 11 | # {'model': 'resnet50', 'bs': 128, 'dev': [5]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'cd_vs_powell_mobilv2_bcorr' 16 | 17 | for mset in models_set: 18 | # for bits in [2, 4]: 19 | # run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 20 | # + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 21 | # + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 22 | # + ['-ba', str(bits)] + "--min_method Powell -maxi 1 -cs 512".split(" ") 23 | # ) 24 | 25 | for bits in [4]: 26 | run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str( 27 | mset['bs'])] 28 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 29 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 30 | + ['-bw', str(bits)] + "--min_method Powell -maxi 1 -cs 512 -bcw".split( 31 | " ") 32 | ) 33 | 34 | for bits in [4]: 35 | run(["python", "quantization/posttraining/layer_scale_optimization_opt.py"] + ['-a', mset['model']] + ['-b', str( 36 | mset['bs'])] 37 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 38 | + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 39 | + ['-ba', str(bits)] + ['-bw', str(bits)] + "--min_method Powell -maxi 1 -cs 512 -bcw".split( 40 | " ") 41 | ) 42 | -------------------------------------------------------------------------------- /experimets/lapq_init_study.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | proj_root_dir = os.path.join(os.path.dirname(__file__), os.pardir) 3 | sys.path.append(proj_root_dir) 4 | import argparse 5 | import torch 6 | import torchvision.models as models 7 | import scipy.optimize as opt 8 | from pathlib import Path 9 | import numpy as np 10 | import torch.nn as nn 11 | from itertools import count 12 | import torch.backends.cudnn as cudnn 13 | from quantization.quantizer import ModelQuantizer 14 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 15 | from quantization.methods.clipped_uniform import 
FixedClipValueQuantization 16 | from utils.mllog import MLlogger 17 | from quantization.posttraining.cnn_classifier import CnnModel 18 | import pickle 19 | from tqdm import tqdm 20 | 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | home = str(Path.home()) 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 29 | choices=model_names, 30 | help='model architecture: ' + 31 | ' | '.join(model_names) + 32 | ' (default: resnet18)') 33 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 34 | help='dataset name') 35 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None, 36 | help='dataset folder') 37 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N', 38 | help='number of data loading workers (default: 25)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('-cb', '--cal-batch-size', default=None, type=int, help='Batch size for calibration') 45 | parser.add_argument('-cs', '--cal-set-size', default=None, type=int, help='Calibration set size') 46 | parser.add_argument('-p', '--print-freq', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation') 55 | parser.add_argument('--custom_inception', action='store_true', help='use custom inception implementation') 56 | parser.add_argument('--seed', default=0, type=int, 57 | help='seed for initializing training. ') 58 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+', 59 | help='GPU ids to use (e.g 0 1 2 3)') 60 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data') 61 | 62 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 63 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None) 64 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None) 65 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization') 66 | parser.add_argument('--qtype', default='max_static', help='Type of quantization method') 67 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=3.)
68 | 69 | parser.add_argument('--min_method', '-mm', help='Minimization method to use [Nelder-Mead, Powell, COBYLA]', default='Powell') 70 | parser.add_argument('--maxiter', '-maxi', type=int, help='Maximum number of iterations to minimize algo', default=None) 71 | parser.add_argument('--maxfev', '-maxf', type=int, help='Maximum number of function evaluations of minimize algo', default=None) 72 | 73 | parser.add_argument('--init_method', default='static', 74 | help='Scale initialization method [static, dynamic, random], default=static') 75 | parser.add_argument('-siv', type=float, help='Value for static initialization', default=1.) 76 | 77 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified') 78 | parser.add_argument('--bcorr_w', '-bcw', action='store_true', help='Bias correction for weights', default=False) 79 | parser.add_argument('--tag', help='Tag for logging purposes', default='n/a') 80 | parser.add_argument('--bn_folding', '-bnf', action='store_true', help='Apply Batch Norm folding', default=False) 81 | 82 | 83 | # TODO: refactor this 84 | _eval_count = count(0) 85 | _min_loss = 1e6 86 | 87 | 88 | def evaluate_calibration_clipped(scales, model, mq): 89 | global _eval_count, _min_loss 90 | eval_count = next(_eval_count) 91 | 92 | mq.set_clipping(scales, model.device) 93 | loss = model.evaluate_calibration().item() 94 | 95 | if loss < _min_loss: 96 | _min_loss = loss 97 | 98 | print_freq = 20 99 | if eval_count % 20 == 0: 100 | print("func eval iteration: {}, minimum loss of last {} iterations: {:.4f}".format( 101 | eval_count, print_freq, _min_loss)) 102 | 103 | return loss 104 | 105 | 106 | def coord_descent(fun, init, args, **kwargs): 107 | maxiter = kwargs['maxiter'] 108 | x = init.copy() 109 | 110 | def coord_opt(alpha, scales, i): 111 | if alpha < 0: 112 | result = 1e6 113 | else: 114 | scales[i] = alpha 115 | result = fun(scales) 116 | 117 | return result 118 | 119 | nfev = 0 120 | for j in range(maxiter): 121 | for i in range(len(x)): 122 | print("Optimizing variable {}".format(i)) 123 | r = opt.minimize_scalar(lambda alpha: coord_opt(alpha, x, i)) 124 | nfev += r.nfev 125 | opt_alpha = r.x 126 | x[i] = opt_alpha 127 | 128 | if 'callback' in kwargs: 129 | kwargs['callback'](x) 130 | 131 | res = opt.OptimizeResult() 132 | res.x = x 133 | res.nit = maxiter 134 | res.nfev = nfev 135 | res.fun = np.array([r.fun]) 136 | res.success = True 137 | 138 | return res 139 | 140 | 141 | def main(args, ml_logger): 142 | # Fix the seed 143 | random.seed(args.seed) 144 | if not args.dont_fix_np_seed: 145 | np.random.seed(args.seed) 146 | torch.manual_seed(args.seed) 147 | torch.cuda.manual_seed_all(args.seed) 148 | cudnn.deterministic = True 149 | torch.backends.cudnn.benchmark = False 150 | 151 | if args.tag is not None: 152 | ml_logger.mlflow.log_param('tag', args.tag) 153 | 154 | enable_bcorr = False 155 | if args.bcorr_w: 156 | args.bcorr_w = False 157 | enable_bcorr = True 158 | 159 | if args.init_method == 'random': 160 | args.qtype = 'max_static' 161 | 162 | # create model 163 | # Always enable shuffling to avoid issues where we get bad results due to weak statistics 164 | custom_resnet = True 165 | custom_inception = True 166 | inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 167 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 168 | cal_batch_size=args.cal_batch_size, 
cal_set_size=args.cal_set_size, args=args) 169 | 170 | layers = [] 171 | # TODO: make it more generic 172 | if args.bit_weights is not None: 173 | layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 174 | if args.bit_act is not None: 175 | layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 176 | if args.bit_act is not None and 'mobilenet' in args.arch: 177 | layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 178 | 179 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 180 | nn.ReLU6: ActivationModuleWrapperPost, 181 | nn.Conv2d: ParameterModuleWrapperPost} 182 | 183 | mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory) 184 | init_loss = inf_model.evaluate_calibration() 185 | 186 | if args.init_method == 'random': 187 | clip = mq.get_clipping() 188 | for i, c in enumerate(clip.cpu()): 189 | clip[i] = np.random.uniform(0, c) 190 | print("Randomize initial clipping") 191 | print(clip) 192 | mq.set_clipping(clip, inf_model.device) 193 | init_loss = inf_model.evaluate_calibration() 194 | 195 | print("init loss: {:.4f}".format(init_loss.item())) 196 | ml_logger.log_metric('Init loss', init_loss.item(), step='auto') 197 | 198 | acc = inf_model.validate() 199 | ml_logger.log_metric('Acc init', acc, step='auto') 200 | 201 | init = mq.get_clipping() 202 | 203 | global _eval_count, _min_loss 204 | _min_loss = init_loss.item() 205 | 206 | # if enable_bcorr: 207 | # args.bcorr_w = True 208 | # inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 209 | # batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 210 | # cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 211 | # 212 | # mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory) 213 | 214 | # run optimizer 215 | min_options = {} 216 | if args.maxiter is not None: 217 | min_options['maxiter'] = args.maxiter 218 | if args.maxfev is not None: 219 | min_options['maxfev'] = args.maxfev 220 | 221 | _iter = count(0) 222 | 223 | def local_search_callback(x): 224 | it = next(_iter) 225 | mq.set_clipping(x, inf_model.device) 226 | loss = inf_model.evaluate_calibration() 227 | print("\n[{}]: Local search callback".format(it)) 228 | print("loss: {:.4f}\n".format(loss.item())) 229 | print(x) 230 | ml_logger.log_metric('Loss {}'.format(args.min_method), loss.item(), step='auto') 231 | 232 | # evaluate 233 | acc = inf_model.validate() 234 | ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto') 235 | 236 | args.min_method = "Powell" 237 | method = coord_descent if args.min_method == 'CD' else args.min_method 238 | res = opt.minimize(lambda scales: evaluate_calibration_clipped(scales, inf_model, mq), init.cpu().numpy(), 239 | method=method, options=min_options, callback=local_search_callback) 240 | 241 | print(res) 242 | 243 | scales = res.x 244 | mq.set_clipping(scales, inf_model.device) 245 | loss = inf_model.evaluate_calibration() 246 | ml_logger.log_metric('Loss {}'.format(args.min_method), loss.item(), step='auto') 247 | 248 | # evaluate 249 | acc = inf_model.validate() 250 | ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto') 251 | # data['powell'] = {'alpha': scales, 'loss': loss.item(), 'acc': acc} 252 | 253 | # save scales 254 | # f_name = "scales_{}_W{}A{}.pkl".format(args.arch, args.bit_weights, args.bit_act) 255 
| # f = open(os.path.join(proj_root_dir, 'data', f_name), 'wb') 256 | # pickle.dump(data, f) 257 | # f.close() 258 | # print("Data saved to {}".format(f_name)) 259 | 260 | 261 | if __name__ == '__main__': 262 | args = parser.parse_args() 263 | if args.cal_batch_size is None: 264 | args.cal_batch_size = args.batch_size 265 | if args.cal_batch_size > args.batch_size: 266 | print("Changing cal_batch_size parameter from {} to {}".format(args.cal_batch_size, args.batch_size)) 267 | args.cal_batch_size = args.batch_size 268 | if args.cal_set_size is None: 269 | args.cal_set_size = args.batch_size 270 | 271 | with MLlogger(os.path.join(home, 'mxt-sim/mllog_runs'), args.experiment, args, 272 | name_args=[args.arch, args.dataset, "W{}A{}".format(args.bit_weights, args.bit_act)]) as ml_logger: 273 | main(args, ml_logger) 274 | -------------------------------------------------------------------------------- /experimets/moments.py: -------------------------------------------------------------------------------- 1 | from subprocess import run 2 | import mlflow 3 | import numpy as np 4 | 5 | models_set = [ 6 | # {'model': 'vgg16', 'bs': 128, 'dev': [5]}, 7 | # {'model': 'vgg16_bn', 'bs': 128, 'dev': [5]}, 8 | # {'model': 'inception_v3', 'bs': 256, 'dev': [5]}, 9 | # {'model': 'mobilenet_v2', 'bs': 128, 'dev': [5]}, 10 | {'model': 'resnet18', 'bs': 256, 'dev': [0]}, 11 | # {'model': 'resnet50', 'bs': 128, 'dev': [5]}, 12 | # {'model': 'resnet101', 'bs': 512, 'dev': [5]} 13 | ] 14 | 15 | exp_name = 'moments' 16 | qtypes = ['max_static'] 17 | # qtypes = ['l2_norm', 'aciq_laplace', 'l3_norm', 'max_static'] 18 | 19 | for mset in models_set: 20 | for qt in qtypes: 21 | for bits in [16, 8, 6, 5, 4]: 22 | run(["python", "quantization/qat/cnn_classifier_train.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 23 | + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 24 | + "--custom_resnet".split(" ") + ['-exp', exp_name] 25 | + ['-bw', str(bits)] + ['--qtype', qt] + "-q -e -ls".split(" ") + ['--resume', '/data/home/cvds_lab/mxt-sim/ckpt/res18_kurtosis/epoch_90_checkpoint.pth.tar'] 26 | ) 27 | 28 | # run(["python", "quantization/qat/cnn_classifier_train.py"] + ['-a', mset['model']] + ['-b', str(mset['bs'])] 29 | # + ['--dataset', 'imagenet'] + ['--gpu_ids'] + " ".join(map(str, mset['dev'])).split(" ") 30 | # + "--pretrained --custom_resnet".split(" ") + ['-exp', exp_name] 31 | # + ['-bw', str(bits)] + ['--qtype', qt] + "-q -e -ls".split(" ") 32 | # ) 33 | -------------------------------------------------------------------------------- /fig/cd_vs_powell_32W2A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/cd_vs_powell_32W2A.png -------------------------------------------------------------------------------- /fig/cd_vs_powell_3W32A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/cd_vs_powell_3W32A.png -------------------------------------------------------------------------------- /fig/cd_vs_powell_4W32A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/cd_vs_powell_4W32A.png 
-------------------------------------------------------------------------------- /fig/contour_2bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/contour_2bit.png -------------------------------------------------------------------------------- /fig/contour_4bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/contour_4bit.png -------------------------------------------------------------------------------- /fig/heatmap_legend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/heatmap_legend.png -------------------------------------------------------------------------------- /fig/lapq_intuition.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/lapq_intuition.pdf -------------------------------------------------------------------------------- /fig/lpnorm_err_2bit.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/lpnorm_err_2bit.pdf -------------------------------------------------------------------------------- /fig/lpnorm_err_2bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/lpnorm_err_2bit.png -------------------------------------------------------------------------------- /fig/lpnorm_err_4bit.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/lpnorm_err_4bit.pdf -------------------------------------------------------------------------------- /fig/lpnorm_err_4bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/lpnorm_err_4bit.png -------------------------------------------------------------------------------- /fig/p_vs_lpnorm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/p_vs_lpnorm.pdf -------------------------------------------------------------------------------- /fig/p_vs_opt_clipping.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/p_vs_opt_clipping.pdf -------------------------------------------------------------------------------- /fig/quant_matmul_sep_2bit.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/quant_matmul_sep_2bit.pdf 
-------------------------------------------------------------------------------- /fig/quant_matmul_sep_4bit.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/quant_matmul_sep_4bit.pdf -------------------------------------------------------------------------------- /fig/res18_CalibrationVsAcc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res18_CalibrationVsAcc.pdf -------------------------------------------------------------------------------- /fig/res18_l14l15_contour.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res18_l14l15_contour.pdf -------------------------------------------------------------------------------- /fig/res18_l14l15_contour.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res18_l14l15_contour.png -------------------------------------------------------------------------------- /fig/res18_weights_thresholds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res18_weights_thresholds.png -------------------------------------------------------------------------------- /fig/res50_p_vs_acc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res50_p_vs_acc.pdf -------------------------------------------------------------------------------- /fig/res50_p_vs_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res50_p_vs_acc.png -------------------------------------------------------------------------------- /fig/res50_res18_pnorms.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res50_res18_pnorms.pdf -------------------------------------------------------------------------------- /fig/res50_res18_pnorms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/res50_res18_pnorms.png -------------------------------------------------------------------------------- /fig/resnet18_A2_3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_A2_3d.png -------------------------------------------------------------------------------- /fig/resnet18_A2_3d_colorfull.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_A2_3d_colorfull.pdf -------------------------------------------------------------------------------- /fig/resnet18_A2_3d_colorfull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_A2_3d_colorfull.png -------------------------------------------------------------------------------- /fig/resnet18_A2_3d_zoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_A2_3d_zoom.png -------------------------------------------------------------------------------- /fig/resnet18_act_2bit_hessian.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_act_2bit_hessian.pdf -------------------------------------------------------------------------------- /fig/resnet18_act_2bit_hessian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_act_2bit_hessian.png -------------------------------------------------------------------------------- /fig/resnet18_act_2bit_hessian_norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_act_2bit_hessian_norm.png -------------------------------------------------------------------------------- /fig/resnet18_act_4bit_hessian.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_act_4bit_hessian.pdf -------------------------------------------------------------------------------- /fig/resnet18_act_4bit_hessian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_act_4bit_hessian.png -------------------------------------------------------------------------------- /fig/resnet18_act_4bit_hessian_norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_act_4bit_hessian_norm.png -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA2.pdf -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA2.png 
-------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA2_3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA2_3d.png -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA2_topographic_zoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA2_topographic_zoom.png -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA3.pdf -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA3.png -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA4.pdf -------------------------------------------------------------------------------- /fig/resnet18_l0l1_WNoneA4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_l0l1_WNoneA4.png -------------------------------------------------------------------------------- /fig/resnet18_loss_vs_p.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_loss_vs_p.pdf -------------------------------------------------------------------------------- /fig/resnet18_quadratic_loss_vs_distance.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet18_quadratic_loss_vs_distance.pdf -------------------------------------------------------------------------------- /fig/resnet50_loss_vs_p.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/fig/resnet50_loss_vs_p.pdf -------------------------------------------------------------------------------- /jupyter/error_separability.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "\n" 12 | ] 13 | } 14 | ], 15 | "metadata": { 16 | "kernelspec": { 17 | "display_name": "Python 3", 18 | "language": "python", 19 | "name": 
"python3" 20 | }, 21 | "language_info": { 22 | "codemirror_mode": { 23 | "name": "ipython", 24 | "version": 2 25 | }, 26 | "file_extension": ".py", 27 | "mimetype": "text/x-python", 28 | "name": "python", 29 | "nbconvert_exporter": "python", 30 | "pygments_lexer": "ipython2", 31 | "version": "2.7.6" 32 | }, 33 | "pycharm": { 34 | "stem_cell": { 35 | "cell_type": "raw", 36 | "source": [], 37 | "metadata": { 38 | "collapsed": false 39 | } 40 | } 41 | } 42 | }, 43 | "nbformat": 4, 44 | "nbformat_minor": 0 45 | } 46 | -------------------------------------------------------------------------------- /jupyter/lapq_intuition.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "" 12 | ] 13 | } 14 | ], 15 | "metadata": { 16 | "language_info": { 17 | "codemirror_mode": { 18 | "name": "ipython", 19 | "version": 2 20 | }, 21 | "file_extension": ".py", 22 | "mimetype": "text/x-python", 23 | "name": "python", 24 | "nbconvert_exporter": "python", 25 | "pygments_lexer": "ipython2", 26 | "version": "2.7.6" 27 | } 28 | }, 29 | "nbformat": 4, 30 | "nbformat_minor": 0 31 | } 32 | -------------------------------------------------------------------------------- /lapq/README.md: -------------------------------------------------------------------------------- 1 | # Loss Aware Post-training Quantization 2 | 3 | ## Dependencies 4 | - python3.x 5 | - [pytorch]() 6 | - [torchvision]() to load the datasets, perform image transforms 7 | - [pandas]() for logging to csv 8 | - [bokeh]() for training visualization 9 | - [scikit-learn](https://scikit-learn.org) for kmeans clustering 10 | - [mlflow](https://mlflow.org/) for logging 11 | - [tqdm](https://tqdm.github.io/) for progress 12 | - [scipy](https://scipy.org/) for Powell and Brent 13 | 14 | 15 | ## HW requirements 16 | NVIDIA GPU / cuda support 17 | 18 | ## Data 19 | - To run this code you need validation set from ILSVRC2012 data 20 | - Configure your dataset path by providing --data "PATH_TO_ILSVRC" or copy ILSVRC dir to ~/datasets/ILSVRC2012. 21 | - To get the ILSVRC2012 data, you should register on their site for access: 22 | 23 | ## Prepare environment 24 | - Clone source code 25 | ``` 26 | git clone https://github.com/ynahshan/nn-quantization-pytorch.git 27 | cd cnn-quantization 28 | ``` 29 | - Create virtual environment for python3 and activate: 30 | ``` 31 | virtualenv --system-site-packages -p python3 venv3 32 | . 
./venv3/bin/activate 33 | ``` 34 | - Install dependencies 35 | ``` 36 | pip install torch torchvision bokeh pandas scikit-learn mlflow tqdm scipy 37 | ``` 38 | 39 | ### Run experiments 40 | - To reproduce the resnet18 experiment run: 41 | ``` 42 | cd nn-quantization-pytorch 43 | python lapq/lapq_v2.py -a resnet18 -b 512 --dataset imagenet --pretrained --custom_resnet --min_method Powell -maxi 2 -cs 512 -exp lapq_v2 -ba 4 -bw 4 -bcw 44 | ``` 45 | 46 | - To reproduce the resnet50 experiment run: 47 | ``` 48 | cd nn-quantization-pytorch 49 | python lapq/lapq_v2.py -a resnet50 --dataset imagenet -b 128 --pretrained --custom_resnet -ba 4 -bw 4 --min_method Powell -maxi 1 -cs 512 -bcw 50 | ``` 51 | 52 | - To reproduce the inception_v3 experiment run: 53 | ``` 54 | cd nn-quantization-pytorch 55 | python lapq/lapq_v2.py -a inception_v3 --dataset imagenet -b 128 --pretrained --custom_inception -ba 4 -bw 4 --min_method Powell -maxi 1 -cs 512 -bcw 56 | ``` 57 | 58 | To reproduce results for other models, change the model name after "-a". All other arguments are the same as for resnet50. 59 | -------------------------------------------------------------------------------- /lapq/layer_scale_optimization.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir)) 3 | import argparse 4 | import torch 5 | import torchvision.models as models 6 | import scipy.optimize as opt 7 | from pathlib import Path 8 | import numpy as np 9 | import torch.nn as nn 10 | from itertools import count 11 | import torch.backends.cudnn as cudnn 12 | from quantization.quantizer import ModelQuantizer 13 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 14 | from quantization.methods.clipped_uniform import FixedClipValueQuantization 15 | from utils.mllog import MLlogger 16 | from quantization.posttraining.cnn_classifier import CnnModel 17 | 18 | 19 | model_names = sorted(name for name in models.__dict__ 20 | if name.islower() and not name.startswith("__") 21 | and callable(models.__dict__[name])) 22 | 23 | home = str(Path.home()) 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 26 | choices=model_names, 27 | help='model architecture: ' + 28 | ' | '.join(model_names) + 29 | ' (default: resnet18)') 30 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 31 | help='dataset name') 32 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None, 33 | help='dataset folder') 34 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N', 35 | help='number of data loading workers (default: 25)') 36 | parser.add_argument('-b', '--batch-size', default=256, type=int, 37 | metavar='N', 38 | help='mini-batch size (default: 256), this is the total ' 39 | 'batch size of all GPUs on the current node when ' 40 | 'using Data Parallel or Distributed Data Parallel') 41 | parser.add_argument('-cb', '--cal-batch-size', default=None, type=int, help='Batch size for calibration') 42 | parser.add_argument('-cs', '--cal-set-size', default=None, type=int, help='Calibration set size') 43 | parser.add_argument('-p', '--print-freq', default=10, type=int, 44 | metavar='N', help='print frequency (default: 10)') 45 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 46 | help='path to latest checkpoint (default: none)') 47 | parser.add_argument('-e', 
'--evaluate', dest='evaluate', action='store_true', 48 | help='evaluate model on validation set') 49 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 50 | help='use pre-trained model') 51 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation') 52 | parser.add_argument('--custom_inception', action='store_true', help='use custom inception implementation') 53 | 54 | parser.add_argument('--seed', default=0, type=int, 55 | help='seed for initializing training. ') 56 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+', 57 | help='GPU ids to use (e.g. 0 1 2 3)') 58 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data') 59 | 60 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 61 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None) 62 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None) 63 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization') 64 | parser.add_argument('--qtype', default='aciq_laplace', help='Type of quantization method') 65 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=3.) 66 | 67 | parser.add_argument('--min_method', '-mm', help='Minimization method to use [Nelder-Mead, Powell, COBYLA]', default='Powell') 68 | parser.add_argument('--maxiter', '-maxi', type=int, help='Maximum number of iterations of the minimization algorithm', default=None) 69 | parser.add_argument('--maxfev', '-maxf', type=int, help='Maximum number of function evaluations of the minimization algorithm', default=None) 70 | 71 | parser.add_argument('--init_method', default='static', 72 | help='Scale initialization method [static, dynamic, random], default=static') 73 | parser.add_argument('-siv', type=float, help='Value for static initialization', default=1.) 
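# NOTE: --bcorr_w below enables bias correction for weights. A rough sketch of the usual
# per-output-channel correction from the post-training quantization literature
# (illustrative only; the behavior actually applied is whatever ParameterModuleWrapperPost
# implements), after quantizing w to w_q:
#
#   mu, mu_q = w.mean(dim=(1, 2, 3), keepdim=True), w_q.mean(dim=(1, 2, 3), keepdim=True)
#   w_q = w_q + (mu - mu_q)  # remove the mean shift introduced by quantization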
74 | 75 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified') 76 | parser.add_argument('--bcorr_w', '-bcw', action='store_true', help='Bias correction for weights', default=False) 77 | 78 | 79 | # TODO: refactor this 80 | _eval_count = count(0) 81 | _min_loss = 1e6 82 | 83 | 84 | def evaluate_calibration_clipped(scales, model, mq): 85 | global _eval_count, _min_loss 86 | eval_count = next(_eval_count) 87 | 88 | mq.set_clipping(scales, model.device) 89 | loss = model.evaluate_calibration().item() 90 | 91 | if loss < _min_loss: 92 | _min_loss = loss 93 | 94 | print_freq = 20 95 | if eval_count % print_freq == 0: 96 | print("func eval iteration: {}, minimum loss of last {} iterations: {:.4f}".format( 97 | eval_count, print_freq, _min_loss)) 98 | 99 | return loss 100 | 101 | 102 | def coord_descent(fun, init, args, **kwargs):  # cyclic coordinate descent over the clipping scales 103 | maxiter = kwargs['maxiter'] 104 | x = init.copy() 105 | 106 | def coord_opt(alpha, scales, i): 107 | if alpha < 0: 108 | result = 1e6  # large penalty keeps the scalar search in the positive range 109 | else: 110 | scales[i] = alpha 111 | result = fun(scales) 112 | 113 | return result 114 | 115 | nfev = 0 116 | for j in range(maxiter): 117 | for i in range(len(x)): 118 | print("Optimizing variable {}".format(i)) 119 | r = opt.minimize_scalar(lambda alpha: coord_opt(alpha, x, i)) 120 | nfev += r.nfev 121 | opt_alpha = r.x 122 | x[i] = opt_alpha 123 | 124 | if 'callback' in kwargs: 125 | kwargs['callback'](x) 126 | 127 | res = opt.OptimizeResult() 128 | res.x = x 129 | res.nit = maxiter 130 | res.nfev = nfev 131 | res.fun = np.array([r.fun]) 132 | res.success = True 133 | 134 | return res 135 | 136 | 137 | def main(args, ml_logger): 138 | # Fix the seed 139 | random.seed(args.seed) 140 | if not args.dont_fix_np_seed: 141 | np.random.seed(args.seed) 142 | torch.manual_seed(args.seed) 143 | torch.cuda.manual_seed_all(args.seed) 144 | cudnn.deterministic = True 145 | torch.backends.cudnn.benchmark = False 146 | 147 | # create model 148 | # Always enable shuffling to avoid issues where we get bad results due to weak statistics 149 | inf_model = CnnModel(args.arch, args.custom_resnet, args.custom_inception, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 150 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 151 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 152 | 153 | layers = [] 154 | # TODO: make it more generic 155 | if args.bit_weights is not None: 156 | layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 157 | if args.bit_act is not None: 158 | layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 159 | if args.bit_act is not None and 'mobilenet' in args.arch: 160 | layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 161 | 162 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 163 | nn.ReLU6: ActivationModuleWrapperPost, 164 | nn.Conv2d: ParameterModuleWrapperPost} 165 | mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory) 166 | # mq.log_quantizer_state(ml_logger, -1) 167 | 168 | print("init_method: {}, qtype {}".format(args.init_method, args.qtype)) 169 | # initialize scales 170 | if args.init_method == 'dynamic': 171 | # evaluate to initialize dynamic clipping 172 | loss = inf_model.evaluate_calibration() 173 | print("Initial loss: {:.4f}".format(loss.item())) 174 | 175 | # get clipping values 176 | init = mq.get_clipping() 177 | else: 
178 | if args.init_method == 'static': 179 | init = np.array([args.siv] * len(layers)) 180 | elif args.init_method == 'random': 181 | init = np.random.uniform(0.5, 1., size=len(layers)) # TODO: pass range by argument 182 | else: 183 | raise RuntimeError("Invalid argument init_method {}".format(args.init_method)) 184 | 185 | # set clip value to qwrappers 186 | mq.set_clipping(init, inf_model.device) 187 | print("scales initialization: {}".format(str(init))) 188 | 189 | # evaluate with clipping 190 | loss = inf_model.evaluate_calibration() 191 | print("Initial loss: {:.4f}".format(loss.item())) 192 | 193 | ml_logger.log_metric('Loss init', loss.item(), step='auto') 194 | 195 | global _min_loss 196 | _min_loss = loss.item() 197 | 198 | # evaluate 199 | acc = inf_model.validate() 200 | ml_logger.log_metric('Acc init', acc, step='auto') 201 | 202 | # run optimizer 203 | min_options = {} 204 | if args.maxiter is not None: 205 | min_options['maxiter'] = args.maxiter 206 | if args.maxfev is not None: 207 | min_options['maxfev'] = args.maxfev 208 | 209 | _iter = count(0) 210 | 211 | def local_search_callback(x): 212 | it = next(_iter) 213 | mq.set_clipping(x, inf_model.device) 214 | loss = inf_model.evaluate_calibration() 215 | print("\n[{}]: Local search callback".format(it)) 216 | print("loss: {:.4f}\n".format(loss.item())) 217 | print(x) 218 | ml_logger.log_metric('Loss {}'.format(args.min_method), loss.item(), step='auto') 219 | 220 | # evaluate 221 | acc = inf_model.validate() 222 | ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto') 223 | 224 | method = coord_descent if args.min_method == 'CD' else args.min_method 225 | res = opt.minimize(lambda scales: evaluate_calibration_clipped(scales, inf_model, mq), init.cpu().numpy() if torch.is_tensor(init) else init,  # dynamic init is a tensor; static/random is already a numpy array 226 | method=method, options=min_options, callback=local_search_callback) 227 | 228 | print(res) 229 | 230 | scales = res.x 231 | mq.set_clipping(scales, inf_model.device) 232 | loss = inf_model.evaluate_calibration() 233 | ml_logger.log_metric('Loss {}'.format(args.min_method), loss.item(), step='auto') 234 | 235 | # evaluate 236 | acc = inf_model.validate() 237 | ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto') 238 | # save scales 239 | 240 | 241 | if __name__ == '__main__': 242 | args = parser.parse_args() 243 | if args.cal_batch_size is None: 244 | args.cal_batch_size = args.batch_size 245 | if args.cal_batch_size > args.batch_size: 246 | print("Changing cal_batch_size parameter from {} to {}".format(args.cal_batch_size, args.batch_size)) 247 | args.cal_batch_size = args.batch_size 248 | if args.cal_set_size is None: 249 | args.cal_set_size = args.batch_size 250 | 251 | with MLlogger(os.path.join(home, 'mxt-sim/mllog_runs'), args.experiment, args, 252 | name_args=[args.arch, args.dataset, "W{}A{}".format(args.bit_weights, args.bit_act)]) as ml_logger: 253 | main(args, ml_logger) 254 | -------------------------------------------------------------------------------- /models/ShuffleNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | from torch.nn import init 7 | 8 | 9 | def conv3x3(in_channels, out_channels, stride=1, 10 | padding=1, bias=True, groups=1): 11 | """3x3 convolution with padding 12 | """ 13 | return nn.Conv2d( 14 | in_channels, 15 | out_channels, 16 | kernel_size=3, 17 | stride=stride, 18 | 
padding=padding, 19 | bias=bias, 20 | groups=groups) 21 | 22 | 23 | def conv1x1(in_channels, out_channels, groups=1): 24 | """1x1 convolution 25 | - Normal pointwise convolution when groups == 1 26 | - Grouped pointwise convolution when groups > 1 27 | """ 28 | return nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size=1, 32 | groups=groups, 33 | stride=1) 34 | 35 | 36 | def channel_shuffle(x, groups): 37 | batchsize, num_channels, height, width = x.data.size() 38 | 39 | channels_per_group = num_channels // groups 40 | 41 | # reshape 42 | x = x.view(batchsize, groups, 43 | channels_per_group, height, width) 44 | 45 | # transpose 46 | # - contiguous() required if transpose() is used before view(). 47 | # See https://github.com/pytorch/pytorch/issues/764 48 | x = torch.transpose(x, 1, 2).contiguous() 49 | 50 | # flatten 51 | x = x.view(batchsize, -1, height, width) 52 | 53 | return x 54 | 55 | 56 | class ShuffleUnit(nn.Module): 57 | def __init__(self, in_channels, out_channels, groups=3, 58 | grouped_conv=True, combine='add'): 59 | 60 | super(ShuffleUnit, self).__init__() 61 | 62 | self.in_channels = in_channels 63 | self.out_channels = out_channels 64 | self.grouped_conv = grouped_conv 65 | self.combine = combine 66 | self.groups = groups 67 | self.bottleneck_channels = self.out_channels // 4 68 | 69 | # define the type of ShuffleUnit 70 | if self.combine == 'add': 71 | # ShuffleUnit Figure 2b 72 | self.depthwise_stride = 1 73 | self._combine_func = self._add 74 | elif self.combine == 'concat': 75 | # ShuffleUnit Figure 2c 76 | self.depthwise_stride = 2 77 | self._combine_func = self._concat 78 | 79 | # ensure output of concat has the same channels as 80 | # original output channels. 81 | self.out_channels -= self.in_channels 82 | else: 83 | raise ValueError("Cannot combine tensors with \"{}\". " \ 84 | "Only \"add\" and \"concat\" are " \ 85 | "supported".format(self.combine)) 86 | 87 | # Use a 1x1 grouped or non-grouped convolution to reduce input channels 88 | # to bottleneck channels, as in a ResNet bottleneck module. 89 | # NOTE: Do not use group convolution for the first conv1x1 in Stage 2. 
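# (Per the ShuffleNet paper, the input to Stage 2 has only 24 channels, which is too few to
# usefully split into groups, so the first 1x1 of Stage 2 stays non-grouped; see
# first_1x1_groups below.)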
90 | self.first_1x1_groups = self.groups if grouped_conv else 1 91 | 92 | self.g_conv_1x1_compress = self._make_grouped_conv1x1( 93 | self.in_channels, 94 | self.bottleneck_channels, 95 | self.first_1x1_groups, 96 | batch_norm=True, 97 | relu=True 98 | ) 99 | 100 | # 3x3 depthwise convolution followed by batch normalization 101 | self.depthwise_conv3x3 = conv3x3( 102 | self.bottleneck_channels, self.bottleneck_channels, 103 | stride=self.depthwise_stride, groups=self.bottleneck_channels) 104 | self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels) 105 | 106 | # Use 1x1 grouped convolution to expand from 107 | # bottleneck_channels to out_channels 108 | self.g_conv_1x1_expand = self._make_grouped_conv1x1( 109 | self.bottleneck_channels, 110 | self.out_channels, 111 | self.groups, 112 | batch_norm=True, 113 | relu=False 114 | ) 115 | 116 | 117 | @staticmethod 118 | def _add(x, out): 119 | # residual connection 120 | return x + out 121 | 122 | 123 | @staticmethod 124 | def _concat(x, out): 125 | # concatenate along channel axis 126 | return torch.cat((x, out), 1) 127 | 128 | 129 | def _make_grouped_conv1x1(self, in_channels, out_channels, groups, 130 | batch_norm=True, relu=False): 131 | 132 | modules = OrderedDict() 133 | 134 | conv = conv1x1(in_channels, out_channels, groups=groups) 135 | modules['conv1x1'] = conv 136 | 137 | if batch_norm: 138 | modules['batch_norm'] = nn.BatchNorm2d(out_channels) 139 | if relu: 140 | modules['relu'] = nn.ReLU() 141 | if len(modules) > 1: 142 | return nn.Sequential(modules) 143 | else: 144 | return conv 145 | 146 | 147 | def forward(self, x): 148 | # save for combining later with output 149 | residual = x 150 | 151 | if self.combine == 'concat': 152 | residual = F.avg_pool2d(residual, kernel_size=3, 153 | stride=2, padding=1) 154 | 155 | out = self.g_conv_1x1_compress(x) 156 | out = channel_shuffle(out, self.groups) 157 | out = self.depthwise_conv3x3(out) 158 | out = self.bn_after_depthwise(out) 159 | out = self.g_conv_1x1_expand(out) 160 | 161 | out = self._combine_func(residual, out) 162 | return F.relu(out) 163 | 164 | 165 | class ShuffleNet(nn.Module): 166 | """ShuffleNet implementation. 167 | """ 168 | 169 | def __init__(self, groups=3, in_channels=3, num_classes=1000): 170 | """ShuffleNet constructor. 171 | 172 | Arguments: 173 | groups (int, optional): number of groups to be used in grouped 174 | 1x1 convolutions in each ShuffleUnit. Default is 3 for best 175 | performance according to original paper. 176 | in_channels (int, optional): number of channels in the input tensor. 177 | Default is 3 for RGB image inputs. 178 | num_classes (int, optional): number of classes to predict. Default 179 | is 1000 for ImageNet. 180 | 181 | """ 182 | super(ShuffleNet, self).__init__() 183 | 184 | self.groups = groups 185 | self.stage_repeats = [3, 7, 3] 186 | self.in_channels = in_channels 187 | self.num_classes = num_classes 188 | 189 | # index 0 is invalid and should never be called. 190 | # only used for indexing convenience. 
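# (The stage widths below follow Table 1 of the ShuffleNet paper, Zhang et al. 2017, at the
# 1x width multiplier; the -1 placeholder lets stage_out_channels[stage] be indexed directly
# by stage number.)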
191 | if groups == 1: 192 | self.stage_out_channels = [-1, 24, 144, 288, 576] 193 | elif groups == 2: 194 | self.stage_out_channels = [-1, 24, 200, 400, 800] 195 | elif groups == 3: 196 | self.stage_out_channels = [-1, 24, 240, 480, 960] 197 | elif groups == 4: 198 | self.stage_out_channels = [-1, 24, 272, 544, 1088] 199 | elif groups == 8: 200 | self.stage_out_channels = [-1, 24, 384, 768, 1536] 201 | else: 202 | raise ValueError( 203 | """{} groups is not supported for 204 | 1x1 Grouped Convolutions""".format(groups)) 205 | 206 | # Stage 1 always has 24 output channels 207 | self.conv1 = conv3x3(self.in_channels, 208 | self.stage_out_channels[1], # stage 1 209 | stride=2) 210 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 211 | 212 | # Stage 2 213 | self.stage2 = self._make_stage(2) 214 | # Stage 3 215 | self.stage3 = self._make_stage(3) 216 | # Stage 4 217 | self.stage4 = self._make_stage(4) 218 | 219 | # Global pooling: 220 | # Undefined as PyTorch's functional API can be used for on-the-fly 221 | # shape inference if input size is not ImageNet's 224x224 222 | 223 | # Fully-connected classification layer 224 | num_inputs = self.stage_out_channels[-1] 225 | self.fc = nn.Linear(num_inputs, self.num_classes) 226 | 227 | 228 | def _make_stage(self, stage): 229 | modules = OrderedDict() 230 | stage_name = "ShuffleUnit_Stage{}".format(stage) 231 | 232 | # First ShuffleUnit in the stage 233 | # 1. non-grouped 1x1 convolution (i.e. pointwise convolution) 234 | # is used in Stage 2. Group convolutions used everywhere else. 235 | grouped_conv = stage > 2 236 | 237 | # 2. concatenation unit is always used. 238 | first_module = ShuffleUnit( 239 | self.stage_out_channels[stage-1], 240 | self.stage_out_channels[stage], 241 | groups=self.groups, 242 | grouped_conv=grouped_conv, 243 | combine='concat' 244 | ) 245 | modules[stage_name+"_0"] = first_module 246 | 247 | # add more ShuffleUnits depending on pre-defined number of repeats 248 | for i in range(self.stage_repeats[stage-2]): 249 | name = stage_name + "_{}".format(i+1) 250 | module = ShuffleUnit( 251 | self.stage_out_channels[stage], 252 | self.stage_out_channels[stage], 253 | groups=self.groups, 254 | grouped_conv=True, 255 | combine='add' 256 | ) 257 | modules[name] = module 258 | 259 | return nn.Sequential(modules) 260 | 261 | 262 | def forward(self, x): 263 | x = self.conv1(x) 264 | x = self.maxpool(x) 265 | 266 | x = self.stage2(x) 267 | x = self.stage3(x) 268 | x = self.stage4(x) 269 | 270 | # global average pooling layer 271 | x = F.avg_pool2d(x, x.data.size()[-2:]) 272 | 273 | # flatten for input to fully-connected layer 274 | x = x.view(x.size(0), -1) 275 | x = self.fc(x) 276 | 277 | return F.log_softmax(x, dim=1) 278 | 279 | 280 | if __name__ == "__main__": 281 | """Testing 282 | """ 283 | model = ShuffleNet() 284 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .inception import * -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from torchvision.models.utils import load_state_dict_from_url 4 | from torchvision.models.resnet import model_urls 5 | 6 | __all__ = ['resnet'] 7 | 8 | 9 | def BN(num_features, eps=1e-5, momentum=0.1, affine=True): 10 | return 
nn.BatchNorm2d(num_features, eps, momentum, affine) 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | def init_model(model): 20 | for m in model.modules(): 21 | if isinstance(m, nn.Conv2d): 22 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 23 | m.weight.data.normal_(0, math.sqrt(2. / n)) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | m.weight.data.fill_(1) 26 | m.bias.data.zero_() 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = BN(planes) 36 | self.relu1 = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = BN(planes) 39 | self.relu2 = nn.ReLU(inplace=True) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu1(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu2(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BN(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = BN(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = BN(planes * 4) 74 | self.relu1 = nn.ReLU(inplace=True) 75 | self.relu2 = nn.ReLU(inplace=True) 76 | self.relu3 = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu1(out) 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu2(out) 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu3(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self): 103 | super(ResNet, self).__init__() 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv2d(self.inplanes, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | BN(planes * block.expansion), 112 | ) 113 | 114 | layers = [] 115 | layers.append(block(self.inplanes, planes, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | x = self.relu(x) 126 | x = self.maxpool(x) 127 | 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | x = self.avgpool(x) 134 | x = x.view(x.size(0), -1) 135 | x = self.fc(x) 136 | 137 | return x 138 | 139 | 140 | class 
ResNet_imagenet(ResNet): 141 | 142 | def __init__(self, num_classes=1000, 143 | block=Bottleneck, layers=[3, 4, 23, 3]): 144 | super(ResNet_imagenet, self).__init__() 145 | self.inplanes = 64 146 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 147 | bias=False) 148 | self.bn1 = BN(64) 149 | self.relu = nn.ReLU(inplace=True) 150 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 151 | self.layer1 = self._make_layer(block, 64, layers[0]) 152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 155 | self.avgpool = nn.AvgPool2d(7) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | init_model(self) 159 | self.regime = [ 160 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 161 | 'weight_decay': 1e-4, 'momentum': 0.9}, 162 | {'epoch': 30, 'lr': 1e-2}, 163 | {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0}, 164 | {'epoch': 90, 'lr': 1e-4} 165 | ] 166 | 167 | 168 | class ResNet_cifar10(ResNet): 169 | 170 | def __init__(self, num_classes=10, 171 | block=BasicBlock, depth=18): 172 | super(ResNet_cifar10, self).__init__() 173 | self.inplanes = 16 174 | n = int((depth - 2) / 6) 175 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, 176 | bias=False) 177 | self.bn1 = BN(16) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.maxpool = lambda x: x 180 | self.layer1 = self._make_layer(block, 16, n) 181 | self.layer2 = self._make_layer(block, 32, n, stride=2) 182 | self.layer3 = self._make_layer(block, 64, n, stride=2) 183 | self.layer4 = lambda x: x 184 | self.avgpool = nn.AvgPool2d(8) 185 | self.fc = nn.Linear(64, num_classes) 186 | 187 | init_model(self) 188 | self.regime = [ 189 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 190 | 'weight_decay': 1e-4, 'momentum': 0.9}, 191 | {'epoch': 81, 'lr': 1e-2}, 192 | {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0}, 193 | {'epoch': 164, 'lr': 1e-4} 194 | ] 195 | 196 | 197 | def resnet(**kwargs): 198 | num_classes, depth, dataset = map( 199 | kwargs.get, ['num_classes', 'depth', 'dataset']) 200 | if dataset == 'imagenet': 201 | num_classes = num_classes or 1000 202 | depth = depth or 50 203 | if depth == 18: 204 | model = ResNet_imagenet(num_classes=num_classes, 205 | block=BasicBlock, layers=[2, 2, 2, 2]) 206 | if depth == 34: 207 | model = ResNet_imagenet(num_classes=num_classes, 208 | block=BasicBlock, layers=[3, 4, 6, 3]) 209 | if depth == 50: 210 | model = ResNet_imagenet(num_classes=num_classes, 211 | block=Bottleneck, layers=[3, 4, 6, 3]) 212 | if depth == 101: 213 | model = ResNet_imagenet(num_classes=num_classes, 214 | block=Bottleneck, layers=[3, 4, 23, 3]) 215 | if depth == 152: 216 | model = ResNet_imagenet(num_classes=num_classes, 217 | block=Bottleneck, layers=[3, 8, 36, 3]) 218 | 219 | elif dataset == 'cifar10': 220 | num_classes = num_classes or 10 221 | depth = depth or 56 222 | model = ResNet_cifar10(num_classes=num_classes, 223 | block=BasicBlock, depth=depth) 224 | 225 | if 'pretrained' in kwargs and kwargs['pretrained'] and dataset == 'imagenet': 226 | arch = kwargs['arch'] 227 | progress = kwargs['progress'] if 'progress' in kwargs else True 228 | state_dict = load_state_dict_from_url(model_urls[arch], 229 | progress=progress) 230 | model.load_state_dict(state_dict) 231 | 232 | return model 233 | -------------------------------------------------------------------------------- /quantization/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/quantization/__init__.py -------------------------------------------------------------------------------- /quantization/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/quantization/analysis/__init__.py -------------------------------------------------------------------------------- /quantization/analysis/loss_data_generation.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | proj_root_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) 3 | sys.path.append(proj_root_dir) 4 | import argparse 5 | import torch 6 | import torchvision.models as models 7 | import scipy.optimize as opt 8 | from pathlib import Path 9 | import numpy as np 10 | import torch.nn as nn 11 | from itertools import count 12 | import torch.backends.cudnn as cudnn 13 | from quantization.quantizer import ModelQuantizer 14 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 15 | from quantization.methods.clipped_uniform import FixedClipValueQuantization 16 | from utils.mllog import MLlogger 17 | from quantization.posttraining.cnn_classifier import CnnModel 18 | from tqdm import tqdm 19 | import pickle 20 | 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | home = str(Path.home()) 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 29 | choices=model_names, 30 | help='model architecture: ' + 31 | ' | '.join(model_names) + 32 | ' (default: resnet18)') 33 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 34 | help='dataset name') 35 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None, 36 | help='dataset folder') 37 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N', 38 | help='number of data loading workers (default: 25)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('-cb', '--cal-batch-size', default=256, type=int, help='Batch size for calibration') 45 | parser.add_argument('-cs', '--cal-set-size', default=256, type=int, help='Calibration set size') 46 | parser.add_argument('-p', '--print-freq', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation') 55 | parser.add_argument('--seed', default=0, type=int, 56 | help='seed for initializing training. 
') 57 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+', 58 | help='GPU ids to use (e.g. 0 1 2 3)') 59 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data') 60 | 61 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 62 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None) 63 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None) 64 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization') 65 | parser.add_argument('--qtype', default='max_static', help='Type of quantization method') 66 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=2.) 67 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified') 68 | 69 | parser.add_argument('--grid_resolution', '-gr', type=int, help='Number of intervals in the grid, one coordinate.', default=11) 70 | 71 | 72 | def main(args): 73 | # Fix the seed 74 | random.seed(args.seed) 75 | if not args.dont_fix_np_seed: 76 | np.random.seed(args.seed) 77 | torch.manual_seed(args.seed) 78 | torch.cuda.manual_seed_all(args.seed) 79 | cudnn.deterministic = True 80 | torch.backends.cudnn.benchmark = False 81 | 82 | args.qtype = 'max_static' 83 | # create model 84 | # Always enable shuffling to avoid issues where we get bad results due to weak statistics 85 | custom_resnet = True 86 | inf_model = CnnModel(args.arch, custom_resnet, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 87 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 88 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 89 | 90 | all_layers = [] 91 | if args.bit_weights is not None: 92 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 93 | if args.bit_act is not None: 94 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 95 | if args.bit_act is not None and 'mobilenet' in args.arch: 96 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 97 | 98 | id1 = 0 99 | id2 = 1 100 | layers = [all_layers[id1], all_layers[id2]] 101 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 102 | nn.ReLU6: ActivationModuleWrapperPost, 103 | nn.Conv2d: ParameterModuleWrapperPost} 104 | 105 | mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory) 106 | 107 | max_loss = inf_model.evaluate_calibration()  # loss at the maximal (unclipped) scales 108 | print("loss: {:.4f}".format(max_loss.item())) 109 | max_point = mq.get_clipping() 110 | 111 | n = args.grid_resolution 112 | x = np.linspace(0.01, max_point[0].item(), n) 113 | y = np.linspace(0.01, max_point[1].item(), n) 114 | X, Y = np.meshgrid(x, y) 115 | Z = np.empty((n, n)) 116 | for i, x_ in enumerate(tqdm(x)): 117 | for j, y_ in enumerate(y): 118 | # set clip value to qwrappers 119 | scales = np.array([X[i, j], Y[i, j]]) 120 | mq.set_clipping(scales, inf_model.device) 121 | 122 | # evaluate with clipping 123 | loss = inf_model.evaluate_calibration() 124 | Z[i][j] = loss.item() 125 | 126 | max_point = np.concatenate([max_point.cpu().numpy(), max_loss.cpu().numpy()])  # pair the max scales with their own loss, not the last grid cell's 127 | 128 | def eval_pnorm(p): 129 | args.qtype = 'lp_norm' 130 | args.lp = p 131 | # Fix the seed 132 | random.seed(args.seed) 133 | if not args.dont_fix_np_seed: 134 | 
np.random.seed(args.seed) 135 | torch.manual_seed(args.seed) 136 | torch.cuda.manual_seed_all(args.seed) 137 | cudnn.deterministic = True 138 | torch.backends.cudnn.benchmark = False 139 | inf_model = CnnModel(args.arch, custom_resnet, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 140 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 141 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 142 | 143 | mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory) 144 | loss = inf_model.evaluate_calibration() 145 | point = mq.get_clipping() 146 | point = np.concatenate([point.cpu().numpy(), loss.cpu().numpy()]) 147 | 148 | del inf_model 149 | del mq 150 | return point 151 | 152 | del inf_model 153 | del mq 154 | l1_point = eval_pnorm(1.) 155 | print("loss l1: {:.4f}".format(l1_point[2])) 156 | 157 | l1_5_point = eval_pnorm(1.5) 158 | print("loss l1.5: {:.4f}".format(l1_5_point[2])) 159 | 160 | l2_point = eval_pnorm(2.) 161 | print("loss l2: {:.4f}".format(l2_point[2])) 162 | 163 | l2_5_point = eval_pnorm(2.5) 164 | print("loss l2.5: {:.4f}".format(l2_5_point[2])) 165 | 166 | l3_point = eval_pnorm(3.) 167 | print("loss l3: {:.4f}".format(l3_point[2])) 168 | 169 | f_name = "{}_l{}l{}_W{}A{}.pkl".format(args.arch, id1, id2, args.bit_weights, args.bit_act) 170 | f = open(os.path.join(proj_root_dir, 'data', f_name), 'wb') 171 | data = {'X': X, 'Y': Y, 'Z': Z, 172 | 'max_point': max_point, 'l1_point': l1_point, 'l1.5_point': l1_5_point, 'l2_point': l2_point, 173 | 'l2.5_point': l2_5_point, 'l3_point': l3_point} 174 | pickle.dump(data, f) 175 | f.close() 176 | print("Data saved to {}".format(f_name)) 177 | 178 | 179 | if __name__ == '__main__': 180 | args = parser.parse_args() 181 | main(args) 182 | -------------------------------------------------------------------------------- /quantization/analysis/loss_parametrization.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | proj_root_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) 3 | sys.path.append(proj_root_dir) 4 | import argparse 5 | import torch 6 | import torchvision.models as models 7 | import scipy.optimize as opt 8 | from pathlib import Path 9 | import numpy as np 10 | import torch.nn as nn 11 | from itertools import count 12 | import torch.backends.cudnn as cudnn 13 | from quantization.quantizer import ModelQuantizer 14 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 15 | from quantization.methods.clipped_uniform import FixedClipValueQuantization 16 | from utils.mllog import MLlogger 17 | from quantization.posttraining.cnn_classifier import CnnModel 18 | from tqdm import tqdm 19 | import pickle 20 | 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | home = str(Path.home()) 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 29 | choices=model_names, 30 | help='model architecture: ' + 31 | ' | '.join(model_names) + 32 | ' (default: resnet18)') 33 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 34 | help='dataset name') 35 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None, 36 | help='dataset folder') 37 | parser.add_argument('-j', '--workers', default=25, 
type=int, metavar='N', 38 | help='number of data loading workers (default: 25)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('-cb', '--cal-batch-size', default=256, type=int, help='Batch size for calibration') 45 | parser.add_argument('-cs', '--cal-set-size', default=256, type=int, help='Calibration set size') 46 | parser.add_argument('-p', '--print-freq', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation') 55 | parser.add_argument('--seed', default=0, type=int, 56 | help='seed for initializing training. ') 57 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+', 58 | help='GPU ids to use (e.g. 0 1 2 3)') 59 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data') 60 | 61 | parser.add_argument('--bn_folding', '-bnf', action='store_true', help='Apply Batch Norm folding', default=False) 62 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 63 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None) 64 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None) 65 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization') 66 | parser.add_argument('--qtype', default='max_static', help='Type of quantization method') 67 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=2.) 
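# A minimal sketch of the Lp-norm clipping criterion that qtype 'lp_norm' and the -lp
# argument refer to (illustrative only; the project's actual implementation, presumably in
# quantization/methods/clipped_uniform.py, may differ in details such as signed vs.
# unsigned ranges):
def _lp_quant_err_sketch(alpha, x, p, n_bits=4):
    # clip to [-alpha, alpha], quantize uniformly, and measure the mean Lp error
    xc = np.clip(x, -alpha, alpha)
    delta = 2 * alpha / (2 ** n_bits - 1)  # uniform quantization step
    xq = delta * np.round(xc / delta)
    return np.mean(np.abs(xq - x) ** p)
# The clipping scale is then the alpha minimizing this error, e.g.
#   alpha_opt = opt.minimize_scalar(lambda a: _lp_quant_err_sketch(a, w, args.lp)).x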
68 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified') 69 | parser.add_argument('--bcorr_w', '-bcw', action='store_true', help='Bias correction for weights', default=False) 70 | parser.add_argument('--grid_resolution', '-gr', type=int, help='Number of intervals in the grid, one coordinate.', default=11) 71 | 72 | 73 | def main(args): 74 | # Fix the seed 75 | random.seed(args.seed) 76 | if not args.dont_fix_np_seed: 77 | np.random.seed(args.seed) 78 | torch.manual_seed(args.seed) 79 | torch.cuda.manual_seed_all(args.seed) 80 | cudnn.deterministic = True 81 | torch.backends.cudnn.benchmark = False 82 | 83 | args.qtype = 'max_static' 84 | # create model 85 | # Always enable shuffling to avoid issues where we get bad results due to weak statistics 86 | custom_resnet = True 87 | custom_inception = True 88 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 89 | nn.ReLU6: ActivationModuleWrapperPost, 90 | nn.Conv2d: ParameterModuleWrapperPost} 91 | 92 | def eval_pnorm(p): 93 | args.qtype = 'lp_norm' 94 | args.lp = p 95 | # Fix the seed 96 | random.seed(args.seed) 97 | if not args.dont_fix_np_seed: 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.cuda.manual_seed_all(args.seed) 101 | cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | 104 | inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 105 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 106 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 107 | 108 | all_layers = [] 109 | if args.bit_weights is not None: 110 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 111 | if args.bit_act is not None: 112 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 113 | if args.bit_act is not None and 'mobilenet' in args.arch: 114 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 115 | 116 | mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory) 117 | loss = inf_model.evaluate_calibration() 118 | point = mq.get_clipping() 119 | 120 | del inf_model 121 | del mq 122 | 123 | return point, loss 124 | 125 | random.seed(args.seed) 126 | if not args.dont_fix_np_seed: 127 | np.random.seed(args.seed) 128 | torch.manual_seed(args.seed) 129 | torch.cuda.manual_seed_all(args.seed) 130 | cudnn.deterministic = True 131 | torch.backends.cudnn.benchmark = False 132 | 133 | inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, 134 | args.datapath, 135 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 136 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 137 | 138 | all_layers = [] 139 | if args.bit_weights is not None: 140 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 141 | if args.bit_act is not None: 142 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 143 | if args.bit_act is not None and 'mobilenet' in args.arch: 144 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 145 | 146 | mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory) 147 | 
148 | start_point, start_loss = eval_pnorm(2) 149 | end_point, end_loss = eval_pnorm(4.5) 150 | k = 50 151 | step = (end_point - start_point) / k 152 | 153 | print("start") 154 | print(start_point) 155 | print("end") 156 | print(end_point) 157 | losses = [] 158 | points = [] 159 | for i in range(k+1): 160 | point = start_point + i * step 161 | mq.set_clipping(point, inf_model.device) 162 | loss = inf_model.evaluate_calibration() 163 | losses.append(loss.item()) 164 | points.append(point.cpu().numpy()) 165 | print("({}: loss) - {}".format(i, loss.item())) 166 | 167 | f_name = "{}_W{}A{}_loss_vs_clipping.pkl".format(args.arch, args.bit_weights, args.bit_act) 168 | f = open(os.path.join(proj_root_dir, 'data', f_name), 'wb') 169 | data = {'start': start_point.cpu().numpy(), 'end': end_point.cpu().numpy(), 'loss': losses, 'points': points} 170 | pickle.dump(data, f) 171 | f.close() 172 | print("Data saved to {}".format(f_name)) 173 | 174 | 175 | if __name__ == '__main__': 176 | args = parser.parse_args() 177 | main(args) 178 | -------------------------------------------------------------------------------- /quantization/analysis/loss_parametrization1.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | proj_root_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) 3 | sys.path.append(proj_root_dir) 4 | import argparse 5 | import torch 6 | import torchvision.models as models 7 | import scipy.optimize as opt 8 | from pathlib import Path 9 | import numpy as np 10 | import torch.nn as nn 11 | from itertools import count 12 | import torch.backends.cudnn as cudnn 13 | from quantization.quantizer import ModelQuantizer 14 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 15 | from quantization.methods.clipped_uniform import FixedClipValueQuantization 16 | from utils.mllog import MLlogger 17 | from quantization.posttraining.cnn_classifier import CnnModel 18 | from tqdm import tqdm 19 | import pickle 20 | 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | home = str(Path.home()) 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 29 | choices=model_names, 30 | help='model architecture: ' + 31 | ' | '.join(model_names) + 32 | ' (default: resnet18)') 33 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 34 | help='dataset name') 35 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None, 36 | help='dataset folder') 37 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N', 38 | help='number of data loading workers (default: 25)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('-cb', '--cal-batch-size', default=256, type=int, help='Batch size for calibration') 45 | parser.add_argument('-cs', '--cal-set-size', default=256, type=int, help='Calibration set size') 46 | parser.add_argument('-p', '--print-freq', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to 
latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation') 55 | parser.add_argument('--seed', default=0, type=int, 56 | help='seed for initializing training. ') 57 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+', 58 | help='GPU ids to use (e.g 0 1 2 3)') 59 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data') 60 | 61 | parser.add_argument('--bn_folding', '-bnf', action='store_true', help='Apply Batch Norm folding', default=False) 62 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 63 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None) 64 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None) 65 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization') 66 | parser.add_argument('--qtype', default='max_static', help='Type of quantization method') 67 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=2.) 68 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified') 69 | parser.add_argument('--bcorr_w', '-bcw', action='store_true', help='Bias correction for weights', default=False) 70 | parser.add_argument('--grid_resolution', '-gr', type=int, help='Number of intervals in the grid, one coordinate.', default=11) 71 | 72 | 73 | def main(args): 74 | # Fix the seed 75 | random.seed(args.seed) 76 | if not args.dont_fix_np_seed: 77 | np.random.seed(args.seed) 78 | torch.manual_seed(args.seed) 79 | torch.cuda.manual_seed_all(args.seed) 80 | cudnn.deterministic = True 81 | torch.backends.cudnn.benchmark = False 82 | 83 | args.qtype = 'max_static' 84 | # create model 85 | # Always enable shuffling to avoid issues where we get bad results due to weak statistics 86 | custom_resnet = True 87 | custom_inception = True 88 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 89 | nn.ReLU6: ActivationModuleWrapperPost, 90 | nn.Conv2d: ParameterModuleWrapperPost} 91 | 92 | def eval_pnorm(p): 93 | args.qtype = 'lp_norm' 94 | args.lp = p 95 | # Fix the seed 96 | random.seed(args.seed) 97 | if not args.dont_fix_np_seed: 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.cuda.manual_seed_all(args.seed) 101 | cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | 104 | inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 105 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 106 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 107 | 108 | all_layers = [] 109 | if args.bit_weights is not None: 110 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 111 | if args.bit_act is not None: 112 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 113 | if args.bit_act is not None and 'mobilenet' in args.arch: 114 | all_layers += [n for n, m in inf_model.model.named_modules() 
if isinstance(m, nn.ReLU6)][1:-1] 115 | 116 | mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory) 117 | loss = inf_model.evaluate_calibration() 118 | point = mq.get_clipping() 119 | 120 | del inf_model 121 | del mq 122 | 123 | return point, loss 124 | 125 | random.seed(args.seed) 126 | if not args.dont_fix_np_seed: 127 | np.random.seed(args.seed) 128 | torch.manual_seed(args.seed) 129 | torch.cuda.manual_seed_all(args.seed) 130 | cudnn.deterministic = True 131 | torch.backends.cudnn.benchmark = False 132 | 133 | inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, 134 | args.datapath, 135 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 136 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 137 | 138 | all_layers = [] 139 | if args.bit_weights is not None: 140 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 141 | if args.bit_act is not None: 142 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 143 | if args.bit_act is not None and 'mobilenet' in args.arch: 144 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 145 | 146 | mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory) 147 | 148 | p1 = torch.tensor([0.7677084 , 1.7640269 , 0.80914754, 2.044024 , 0.87229156, 149 | 1.2659631 , 0.78454655, 1.3018194 , 0.7894693 , 0.92967707, 150 | 0.5754433 , 0.9115604 , 0.5689196 , 1.2382566 , 0.601773 ]) 151 | p2 = torch.tensor([0.8135005 , 1.7248632 , 0.8009758 , 2.005755 , 0.83956134, 152 | 1.2431265 , 0.7720454 , 1.3013302 , 0.76733077, 0.96402454, 153 | 0.5914314 , 0.9579072 , 0.56543064, 1.2535284 , 0.6261679]) 154 | 155 | k = 50 156 | step = p1 - p2 157 | losses = [] 158 | points = [] 159 | for i in range(k+1): 160 | point = p1 + 0.4*i * step - 10*step 161 | mq.set_clipping(point, inf_model.device) 162 | loss = inf_model.evaluate_calibration() 163 | losses.append(loss.item()) 164 | points.append(point.cpu().numpy()) 165 | print("({}: loss) - {}".format(i, loss.item())) 166 | 167 | f_name = "{}_W{}A{}_loss_conjugate_dir.pkl".format(args.arch, args.bit_weights, args.bit_act) 168 | f = open(os.path.join(proj_root_dir, 'data', f_name), 'wb') 169 | data = {'start': p1.cpu().numpy(), 'loss': losses, 'points': points} 170 | pickle.dump(data, f) 171 | f.close() 172 | print("Data saved to {}".format(f_name)) 173 | 174 | 175 | if __name__ == '__main__': 176 | args = parser.parse_args() 177 | main(args) 178 | -------------------------------------------------------------------------------- /quantization/analysis/loss_vs_p.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | proj_root_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) 3 | sys.path.append(proj_root_dir) 4 | import argparse 5 | import torch 6 | import torchvision.models as models 7 | import scipy.optimize as opt 8 | from pathlib import Path 9 | import numpy as np 10 | import torch.nn as nn 11 | from itertools import count 12 | import torch.backends.cudnn as cudnn 13 | from quantization.quantizer import ModelQuantizer 14 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 15 | from quantization.methods.clipped_uniform import FixedClipValueQuantization 16 | from utils.mllog import 
MLlogger
17 | from quantization.posttraining.cnn_classifier import CnnModel
18 | from tqdm import tqdm
19 | import pickle
20 |
21 |
22 | model_names = sorted(name for name in models.__dict__
23 |                      if name.islower() and not name.startswith("__")
24 |                      and callable(models.__dict__[name]))
25 |
26 | home = str(Path.home())
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
29 |                     choices=model_names,
30 |                     help='model architecture: ' +
31 |                          ' | '.join(model_names) +
32 |                          ' (default: resnet18)')
33 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
34 |                     help='dataset name')
35 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None,
36 |                     help='dataset folder')
37 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N',
38 |                     help='number of data loading workers (default: 25)')
39 | parser.add_argument('-b', '--batch-size', default=256, type=int,
40 |                     metavar='N',
41 |                     help='mini-batch size (default: 256), this is the total '
42 |                          'batch size of all GPUs on the current node when '
43 |                          'using Data Parallel or Distributed Data Parallel')
44 | parser.add_argument('-cb', '--cal-batch-size', default=256, type=int, help='Batch size for calibration')
45 | parser.add_argument('-cs', '--cal-set-size', default=256, type=int, help='Calibration set size (number of samples)')
46 | parser.add_argument('-p', '--print-freq', default=10, type=int,
47 |                     metavar='N', help='print frequency (default: 10)')
48 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
49 |                     help='path to latest checkpoint (default: none)')
50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
51 |                     help='evaluate model on validation set')
52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
53 |                     help='use pre-trained model')
54 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation')
55 | parser.add_argument('--seed', default=0, type=int,
56 |                     help='seed for initializing training')
57 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+',
58 |                     help='GPU ids to use (e.g. 0 1 2 3)')
59 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data')
60 |
61 | parser.add_argument('--bn_folding', '-bnf', action='store_true', help='Apply Batch Norm folding', default=False)
62 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default')
63 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None)
64 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None)
65 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization')
66 | parser.add_argument('--qtype', default='max_static', help='Type of quantization method')
67 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=2.)
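# The '-lp' flag above sets the exponent p of the Lp-norm quantization error that the
# 'lp_norm' clipping method minimizes per tensor. A minimal self-contained sketch of that
# objective (illustrative only; lp_quant_error is not a helper defined in this repo):
#
#   import torch
#
#   def lp_quant_error(t, alpha, num_bits, p):
#       # uniform symmetric quantization of t clipped to [-alpha, alpha],
#       # scored by the mean Lp error |t - Q(t)|^p
#       delta = max(2 * alpha / (2 ** num_bits - 1), 1e-8)
#       t_q = torch.round(torch.clamp(t, -alpha, alpha) / delta) * delta
#       return (t - t_q).abs().pow(p).mean()
#
#   t = torch.randn(10000)
#   alphas = torch.linspace(0.1, 4.0, 80)
#   errs = torch.stack([lp_quant_error(t, a, num_bits=4, p=2.4) for a in alphas.tolist()])
#   print('optimal clipping:', alphas[errs.argmin()].item())   # the optimum shifts with p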
68 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified') 69 | parser.add_argument('--bcorr_w', '-bcw', action='store_true', help='Bias correction for weights', default=False) 70 | parser.add_argument('--grid_resolution', '-gr', type=int, help='Number of intervals in the grid, one coordinate.', default=11) 71 | 72 | 73 | def main(args): 74 | # Fix the seed 75 | random.seed(args.seed) 76 | if not args.dont_fix_np_seed: 77 | np.random.seed(args.seed) 78 | torch.manual_seed(args.seed) 79 | torch.cuda.manual_seed_all(args.seed) 80 | cudnn.deterministic = True 81 | torch.backends.cudnn.benchmark = False 82 | 83 | args.qtype = 'max_static' 84 | # create model 85 | # Always enable shuffling to avoid issues where we get bad results due to weak statistics 86 | custom_resnet = True 87 | custom_inception = True 88 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 89 | nn.ReLU6: ActivationModuleWrapperPost, 90 | nn.Conv2d: ParameterModuleWrapperPost} 91 | 92 | def eval_pnorm(p): 93 | args.qtype = 'lp_norm' 94 | args.lp = p 95 | # Fix the seed 96 | random.seed(args.seed) 97 | if not args.dont_fix_np_seed: 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.cuda.manual_seed_all(args.seed) 101 | cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | 104 | inf_model = CnnModel(args.arch, custom_resnet, custom_inception, args.pretrained, args.dataset, args.gpu_ids, args.datapath, 105 | batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq, 106 | cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args) 107 | 108 | all_layers = [] 109 | if args.bit_weights is not None: 110 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1] 111 | if args.bit_act is not None: 112 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1] 113 | if args.bit_act is not None and 'mobilenet' in args.arch: 114 | all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1] 115 | 116 | mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory) 117 | loss = inf_model.evaluate_calibration() 118 | point = mq.get_clipping() 119 | 120 | # evaluate 121 | # acc = inf_model.validate() 122 | 123 | del inf_model 124 | del mq 125 | 126 | return point, loss 127 | 128 | ps = np.linspace(2, 4.5, 50) 129 | losses = [] 130 | points = [] 131 | for p in tqdm(ps): 132 | point, loss = eval_pnorm(p) 133 | points.append(point.cpu().numpy()) 134 | losses.append(loss.item()) 135 | print("(p, loss) - ({}, {})".format(p, loss.item())) 136 | 137 | f_name = "{}_W{}A{}_loss_vs_p_points.pkl".format(args.arch, args.bit_weights, args.bit_act) 138 | f = open(os.path.join(proj_root_dir, 'data', f_name), 'wb') 139 | data = {'p': ps, 'loss': losses, 'points': points} 140 | pickle.dump(data, f) 141 | f.close() 142 | print("Data saved to {}".format(f_name)) 143 | 144 | 145 | if __name__ == '__main__': 146 | args = parser.parse_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /quantization/analysis/separability_index.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 3 | import argparse 4 | import torch 5 | import torchvision.models as models 6 | 
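# One way to consume the '{arch}_W{}A{}_loss_vs_p_points.pkl' file written by loss_vs_p.py
# above: fit a smooth curve to the sampled (p, loss) pairs and read off the best norm.
# A sketch, not repo code; the file path and the quadratic model are assumptions made
# only for illustration:
#
#   import pickle
#   import numpy as np
#
#   with open('data/resnet18_WNoneA4_loss_vs_p_points.pkl', 'rb') as f:
#       data = pickle.load(f)
#   ps, losses = np.asarray(data['p']), np.asarray(data['loss'])
#   c2, c1, c0 = np.polyfit(ps, losses, deg=2)   # quadratic model of loss(p)
#   p_star = -c1 / (2 * c2)                      # vertex = approximate minimizer
#   print('best sampled p:', ps[losses.argmin()], 'quadratic estimate:', p_star)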
import scipy.optimize as opt
7 | from pathlib import Path
8 | import numpy as np
9 | import torch.nn as nn
10 | from itertools import count
11 | import torch.backends.cudnn as cudnn
12 | from quantization.quantizer import ModelQuantizer
13 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost
14 | from quantization.methods.clipped_uniform import FixedClipValueQuantization
15 | from utils.mllog import MLlogger
16 | from quantization.posttraining.cnn_classifier import CnnModel
17 | from tqdm import tqdm
18 |
19 |
20 | model_names = sorted(name for name in models.__dict__
21 |                      if name.islower() and not name.startswith("__")
22 |                      and callable(models.__dict__[name]))
23 |
24 | home = str(Path.home())
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
27 |                     choices=model_names,
28 |                     help='model architecture: ' +
29 |                          ' | '.join(model_names) +
30 |                          ' (default: resnet18)')
31 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
32 |                     help='dataset name')
33 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None,
34 |                     help='dataset folder')
35 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N',
36 |                     help='number of data loading workers (default: 25)')
37 | parser.add_argument('-b', '--batch-size', default=64, type=int,
38 |                     metavar='N',
39 |                     help='mini-batch size (default: 64), this is the total '
40 |                          'batch size of all GPUs on the current node when '
41 |                          'using Data Parallel or Distributed Data Parallel')
42 | parser.add_argument('-cb', '--cal-batch-size', default=64, type=int, help='Batch size for calibration')
43 | parser.add_argument('-cs', '--cal-set-size', default=64, type=int, help='Calibration set size (number of samples)')
44 | parser.add_argument('-p', '--print-freq', default=10, type=int,
45 |                     metavar='N', help='print frequency (default: 10)')
46 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
47 |                     help='path to latest checkpoint (default: none)')
48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
49 |                     help='evaluate model on validation set')
50 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
51 |                     help='use pre-trained model')
52 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation')
53 | parser.add_argument('--seed', default=0, type=int,
54 |                     help='seed for initializing training')
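# The seed-fixing block that follows in main() (random / numpy / torch / cudnn) is repeated
# verbatim at every entry point in this repo; an equivalent helper, shown only for
# reference (fix_seed is not a function this repo defines):
#
#   import random
#   import numpy as np
#   import torch
#   import torch.backends.cudnn as cudnn
#
#   def fix_seed(seed, fix_np=True):
#       random.seed(seed)
#       if fix_np:
#           np.random.seed(seed)
#       torch.manual_seed(seed)
#       torch.cuda.manual_seed_all(seed)
#       cudnn.deterministic = True
#       torch.backends.cudnn.benchmark = False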
55 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+',
56 |                     help='GPU ids to use (e.g. 0 1 2 3)')
57 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data')
58 |
59 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default')
60 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None)
61 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None)
62 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization')
63 | parser.add_argument('--qtype', default='max_static', help='Type of quantization method')
64 | parser.add_argument('--dont_fix_np_seed', '-dfns', action='store_true', help='Do not fix np seed even if seed specified')
65 |
66 | parser.add_argument('--num_iter', '-i', type=int, help='Number of rounds of the separability estimator', default=3)
67 | parser.add_argument('--num_points', '-n', type=int, help='Number of random sample points per round', default=100)
68 |
69 | # Monte-Carlo separability test: gamma estimates E[f(x) * (f(x) + (m-1)*f(z) - sum_i f(y_i))], where y_i is z with coordinate i taken from x. This is identically zero when f is additive across coordinates; T is its studentized version.
70 | def separability_index(f, m, n, k=1, gpu=False, status_callback=None):
71 |     g = None
72 |     max_ = None
73 |     gamma_ = []
74 |     T_ = []
75 |     for j in range(k):
76 |         x, z = torch.tensor(np.random.uniform(0, 1, size=(n, m))).float(), torch.tensor(
77 |             np.random.uniform(0, 1, size=(n, m))).float()
78 |         if gpu:
79 |             x = x.cuda()
80 |             z = z.cuda()
81 |
82 |         print("Calculate f(x)")
83 |         fx = f(x).double()
84 |         if max_ is None:
85 |             max_ = fx.max()
86 |         else:
87 |             max_ = torch.max(max_, fx.max())
88 |         print("Calculate f(z)")
89 |         fz = f(z).double()
90 |         max_ = torch.max(max_, fz.max())
91 |
92 |         t1 = fx + (m - 1) * fz
93 |
94 |         print("Calculate f(xj, zj')")
95 |         t2 = t1.new_zeros((n,), dtype=torch.float64)
96 |         for i in range(m):
97 |             y = z.clone()
98 |             y[:, i] = x[:, i]
99 |             fy = f(y).double()
100 |             max_ = torch.max(max_, fy.max())
101 |             t2 += fy
102 |
103 |         g_ = fx * (t1 - t2)
104 |         if g is None:
105 |             g = g_
106 |         else:
107 |             g = torch.cat([g, g_], dim=0)
108 |
109 |         gamma = g.mean()
110 |         gamma_.append(gamma.cpu().item())
111 |
112 |         s = torch.sqrt(torch.sum((g - gamma) ** 2) / (g.numel() - 1))
113 |         T = np.sqrt(g.numel()) * gamma / max(s, 1e-2)
114 |         print("std(g): {:.4f}".format(s.item()))
115 |         T_.append(T.cpu().item())
116 |
117 |         if status_callback is not None:
118 |             status_callback(j, gamma, T, max_)
119 |
120 |     return gamma_, T_, max_
121 |
122 | # assume x is a matrix (n, m) with entries in [0, 1]
123 | # n - number of samples
124 | # m - number of variables
125 | def model_func(x, scales, inf_model, mq, a, b):
126 |     loss = x.new_empty(x.shape[0])
127 |     for i in tqdm(range(x.shape[0])):
128 |         # in general, apply a transformation X: [0, 1] => [a, b], where [a, b] is the region of interest,
129 |         # e.g. a region around the point that minimizes some metric
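# Worked example of the transformation implemented just below: with a = 0.5 and b = 2.0,
# alpha = 0.5 / 1.5 = 1/3 and beta = 1 / 1.5 = 2/3, so r = (x + 1/3) / (2/3) = 1.5 * x + 0.5,
# which sends 0 -> 0.5 and 1 -> 2.0 as intended.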
130 |         # We can use the simple linear transformation (x + alpha) / beta; with alpha and beta below, r = x * (b - a) + a
131 |         alpha = a / (b - a)
132 |         beta = 1 / (b - a)
133 |         r = (x[i] + alpha) / beta
134 |         r = torch.min(r, scales)
135 |         r = torch.max(r, r.new_zeros(1))
136 |         mq.set_clipping(r, inf_model.device)
137 |
138 |         # evaluate with clipping
139 |         loss[i] = inf_model.evaluate_calibration()
140 |     return loss
141 |
142 |
143 | def main(args, ml_logger):
144 |     # Fix the seed
145 |     random.seed(args.seed)
146 |     if not args.dont_fix_np_seed:
147 |         np.random.seed(args.seed)
148 |     torch.manual_seed(args.seed)
149 |     torch.cuda.manual_seed_all(args.seed)
150 |     cudnn.deterministic = True
151 |     torch.backends.cudnn.benchmark = False
152 |
153 |     args.qtype = 'max_static'
154 |     # create model
155 |     # Always enable shuffling to avoid issues where we get bad results due to weak statistics
156 |     inf_model = CnnModel(args.arch, args.custom_resnet, False,  # CnnModel also expects use_custom_inception; this script has no such flag, so pass False
157 |                          args.pretrained, args.dataset, args.gpu_ids, args.datapath, batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq,
158 |                          cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args)
159 |
160 |     all_layers = []
161 |     if args.bit_weights is not None:
162 |         all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.Conv2d)][1:-1]
163 |     if args.bit_act is not None:
164 |         all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU)][1:-1]
165 |     if args.bit_act is not None and 'mobilenet' in args.arch:
166 |         all_layers += [n for n, m in inf_model.model.named_modules() if isinstance(m, nn.ReLU6)][1:-1]
167 |
168 |     replacement_factory = {nn.ReLU: ActivationModuleWrapperPost,
169 |                            nn.ReLU6: ActivationModuleWrapperPost,
170 |                            nn.Conv2d: ParameterModuleWrapperPost}
171 |     mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory)
172 |
173 |     loss = inf_model.evaluate_calibration()
174 |     print("loss: {:.4f}".format(loss.item()))
175 |     ml_logger.log_metric('loss', loss.item(), step='auto')
176 |
177 |     # get clipping values of the unclipped max_static baseline
178 |     p_max = mq.get_clipping()
179 |
180 |
181 |     args.qtype = 'l2_norm'
182 |     inf_model = CnnModel(args.arch, args.custom_resnet, False, args.pretrained, args.dataset, args.gpu_ids, args.datapath,
183 |                          batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq,
184 |                          cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args)
185 |     mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory)
186 |     loss = inf_model.evaluate_calibration()
187 |     print("loss l2: {:.4f}".format(loss.item()))
188 |     p_l2 = mq.get_clipping()
189 |
190 |     args.qtype = 'l3_norm'
191 |     inf_model = CnnModel(args.arch, args.custom_resnet, False, args.pretrained, args.dataset, args.gpu_ids, args.datapath,
192 |                          batch_size=args.batch_size, shuffle=True, workers=args.workers, print_freq=args.print_freq,
193 |                          cal_batch_size=args.cal_batch_size, cal_set_size=args.cal_set_size, args=args)
194 |     mq = ModelQuantizer(inf_model.model, args, all_layers, replacement_factory)
195 |     loss = inf_model.evaluate_calibration()
196 |     print("loss l3: {:.4f}".format(loss.item()))
197 |     p_l3 = mq.get_clipping()
198 |
199 |     # gamma_avg = 0
200 |     # T_avg = 0
201 |     num_iter = args.num_iter
202 |     n = args.num_points
203 |
204 |     def status_callback(i, gamma, T, f_max):
205 |         T = T.item()
206 |         gamma = gamma.item()
207 |         f_max = f_max.item()
208 |
209 |         print("gamma^2: {}, T: {}, max: {}".format(gamma, T, f_max))
210
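# Toy sanity check for the statistic computed above: for an additively separable f the
# bracket t1 - t2 vanishes identically, so gamma ~ 0, while coordinate interactions make
# it positive. A self-contained sketch (toy_gamma is illustrative, not repo code):
#
#   import torch
#
#   def toy_gamma(f, m=4, n=20000):
#       x, z = torch.rand(n, m), torch.rand(n, m)
#       t1 = f(x) + (m - 1) * f(z)
#       t2 = sum(f(torch.where(torch.arange(m) == i, x, z)) for i in range(m))
#       return (f(x) * (t1 - t2)).mean().item()
#
#   print(toy_gamma(lambda t: t.sum(dim=1)))    # separable   -> ~0.0
#   print(toy_gamma(lambda t: t.prod(dim=1)))   # interacting -> small positive value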
| ml_logger.log_metric('gamma', gamma, step='auto') 211 | ml_logger.log_metric('T', T, step='auto') 212 | ml_logger.log_metric('f_max', f_max, step='auto') 213 | T_norm = T / np.sqrt(i+1) 214 | ml_logger.log_metric('T_norm', T_norm, step='auto') 215 | gamma_norm = gamma / f_max**2 216 | ml_logger.log_metric('gamma_norm', gamma_norm, step='auto') 217 | 218 | gamma_, T_, f_max = separability_index(lambda x: model_func(x, p_max, inf_model, mq, p_l2, p_l3), len(p_max), n, num_iter, 219 | gpu=True, status_callback=status_callback) 220 | 221 | gamma_norm = np.mean(np.array(gamma_) / f_max.item()**2) 222 | T_norm = np.mean(np.array(T_) / np.sqrt(np.arange(1, num_iter + 1))) 223 | 224 | print("gamma^2 norm: {}, T norm: {}".format(gamma_norm, T_norm)) 225 | ml_logger.log_metric('gamma_tot', gamma_norm, step='auto') 226 | ml_logger.log_metric('T_tot', T_norm, step='auto') 227 | 228 | 229 | if __name__ == '__main__': 230 | args = parser.parse_args() 231 | with MLlogger(os.path.join(home, 'mxt-sim/mllog_runs'), args.experiment, args, 232 | name_args=[args.arch, args.dataset, "W{}A{}".format(args.bit_weights, args.bit_act)]) as ml_logger: 233 | main(args, ml_logger) 234 | -------------------------------------------------------------------------------- /quantization/methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/quantization/methods/__init__.py -------------------------------------------------------------------------------- /quantization/methods/stochastic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Noise(object): 6 | def __init__(self, module, tensor, *args, **kwargs): 7 | self.dist = torch.distributions.uniform.Uniform(-0.5, 0.5) 8 | s = 1./2 # TODO: get from kwargs 9 | with torch.no_grad(): 10 | self.amp = (tensor.max() - tensor.min()) * s 11 | 12 | def __call__(self, tensor): 13 | noise = self.dist.sample(sample_shape=tensor.shape).to(tensor.device) * self.amp 14 | return tensor + noise 15 | 16 | def loggable_parameters(self): 17 | return [] 18 | -------------------------------------------------------------------------------- /quantization/methods/uniform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RoundSTE(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, input): 8 | output = torch.round(input) 9 | return output 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): 13 | return grad_output 14 | 15 | 16 | class QuantizationBase(object): 17 | def __init__(self, module, num_bits): 18 | self.module = module 19 | self.num_bits = num_bits 20 | self.num_bins = int(2 ** num_bits) 21 | self.opt_params = {} 22 | self.named_params = [] 23 | 24 | def register_buffer(self, name, value): 25 | if hasattr(self.module, name): 26 | delattr(self.module, name) 27 | self.module.register_buffer(name, value) 28 | setattr(self, name, getattr(self.module, name)) 29 | 30 | def register_parameter(self, name, value): 31 | if hasattr(self.module, name): 32 | delattr(self.module, name) 33 | self.module.register_parameter(name, nn.Parameter(value)) 34 | setattr(self, name, getattr(self.module, name)) 35 | 36 | self.named_params.append((name, getattr(self.module, name))) 37 | 38 | def __add_optim_params__(self, optim_type, dataset, params): 39 | 
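# RoundSTE (top of this file) is the straight-through estimator: forward rounds to the
# nearest integer, backward passes the incoming gradient through unchanged, which is what
# keeps learned clipping parameters trainable despite the non-differentiable rounding.
# A quick illustrative check (sketch):
#
#   x = torch.linspace(-1, 1, 5, requires_grad=True)
#   y = RoundSTE.apply(4 * x) / 4                       # quantize onto a 0.25 grid
#   y.sum().backward()
#   assert torch.allclose(x.grad, torch.ones_like(x))   # gradient of identity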
learnable_params = [d for n, d in params if n in self.learned_parameters()] 40 | self.opt_params[optim_type + '_' + dataset] = learnable_params 41 | 42 | def optim_parameters(self): 43 | return self.opt_params 44 | 45 | def loggable_parameters(self): 46 | return self.named_parameters() 47 | 48 | def named_parameters(self): 49 | named_params = [(n, p) for n, p in self.named_params if n in self.learned_parameters()] 50 | return named_params 51 | 52 | @staticmethod 53 | def learned_parameters(): 54 | return [] 55 | 56 | 57 | class UniformQuantization(QuantizationBase): 58 | def __init__(self, module, num_bits, symmetric, uint=False, stochastic=False, tails=False): 59 | super(UniformQuantization, self).__init__(module, num_bits) 60 | if not symmetric and not uint: 61 | raise RuntimeError("Can't perform integer quantization on non symmetric distributions.") 62 | 63 | self.symmetric = symmetric 64 | self.uint = uint 65 | self.stochastic = stochastic 66 | self.tails = tails 67 | if uint: 68 | self.qmax = 2 ** self.num_bits - 1 69 | self.qmin = 0 70 | else: 71 | self.qmax = 2 ** (self.num_bits - 1) - 1 72 | self.qmin = -self.qmax - 1 73 | 74 | if tails: 75 | self.qmax -= 0.5 + 1e-6 76 | self.qmin -= 0.5 77 | 78 | def __quantize__(self, tensor, alpha): 79 | delta = (2 if self.symmetric else 1) * alpha / (self.num_bins - 1) 80 | delta = max(delta, 1e-8) 81 | 82 | # quantize 83 | if self.uint and self.symmetric: 84 | t_q = (tensor + alpha) / delta 85 | else: 86 | t_q = tensor / delta 87 | 88 | # stochastic rounding 89 | if self.stochastic and self.module.training: 90 | with torch.no_grad(): 91 | noise = t_q.new_empty(t_q.shape).uniform_(-0.5, 0.5) 92 | t_q += noise 93 | 94 | # clamp and round 95 | t_q = torch.clamp(t_q, self.qmin, self.qmax) 96 | t_q = RoundSTE.apply(t_q) 97 | assert torch.unique(t_q).shape[0] <= self.num_bins 98 | 99 | # uncomment to debug quantization 100 | # print(torch.unique(t_q)) 101 | 102 | # de-quantize 103 | if self.uint and self.symmetric: 104 | t_q = t_q * delta - alpha 105 | else: 106 | t_q = t_q * delta 107 | 108 | return t_q 109 | 110 | # def __distiller_quantize__(self, tensor, alpha): 111 | # # Leave one bit for sign 112 | # n = self.qmax 113 | # scale = n / alpha 114 | # t_q = torch.clamp(torch.round(tensor * scale), self.qmin, self.qmax) 115 | # t_q = t_q / scale 116 | # return t_q 117 | 118 | def __quantize_gemmlowp__(self, tensor, min_, max_): 119 | assert self.uint is True 120 | delta = (max_ - min_) / (self.num_bins - 1) 121 | delta = max(delta, 1e-8) 122 | 123 | # quantize 124 | t_q = (tensor - min_) / delta 125 | 126 | # stochastic rounding 127 | if self.stochastic and self.module.training: 128 | with torch.no_grad(): 129 | noise = t_q.new_empty(t_q.shape).uniform_(-0.5, 0.5) 130 | t_q += noise 131 | 132 | # clamp and round 133 | t_q = torch.clamp(t_q, self.qmin, self.qmax) 134 | t_q = RoundSTE.apply(t_q) 135 | assert torch.unique(t_q).shape[0] <= self.num_bins 136 | 137 | # uncomment to debug quantization 138 | # print(torch.unique(t_q)) 139 | 140 | # de-quantize 141 | t_q = t_q * delta + min_ 142 | 143 | return t_q 144 | 145 | def __for_repr__(self): 146 | return [('bits', self.num_bits), ('symmetric', self.symmetric), ('tails', self.tails)] 147 | 148 | def __repr__(self): 149 | s = '{} - ['.format(type(self).__name__) 150 | for name, value in self.__for_repr__(): 151 | s += '{}: {}, '.format(name, value) 152 | return s + ']' 153 | # return '{} - bits: {}, symmetric: {}'.format(type(self).__name__, self.num_bits, self.symmetric) 154 | 155 | 156 | class 
MaxAbsDynamicQuantization(UniformQuantization): 157 | def __init__(self, module, tensor, num_bits, symmetric, stochastic=False): 158 | super(MaxAbsDynamicQuantization, self).__init__(module, tensor, num_bits, symmetric) 159 | 160 | def __call__(self, tensor): 161 | alpha = tensor.abs().max() 162 | t_q = self.__quantize__(tensor, alpha) 163 | return t_q 164 | 165 | 166 | class MinMaxQuantization(UniformQuantization): 167 | def __init__(self, module, tensor, num_bits, symmetric, uint=False, stochastic=False, kwargs={}): 168 | super(MinMaxQuantization, self).__init__(module, num_bits, symmetric, uint, stochastic) 169 | 170 | with torch.no_grad(): 171 | self.register_buffer('min', tensor.new_tensor([tensor.min()])) 172 | self.register_buffer('max', tensor.new_tensor([tensor.max()])) 173 | 174 | def __call__(self, tensor): 175 | t_q = self.__quantize_gemmlowp__(tensor, min_=self.min, max_=self.max) 176 | return t_q 177 | -------------------------------------------------------------------------------- /quantization/posttraining/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/quantization/posttraining/__init__.py -------------------------------------------------------------------------------- /quantization/posttraining/cnn_classifier.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | import torch 3 | import torchvision.models as models 4 | import scipy.optimize as opt 5 | from pathlib import Path 6 | import numpy as np 7 | import torch.nn as nn 8 | from itertools import count 9 | from utils.data import get_dataset 10 | from utils.preprocess import get_transform 11 | from utils.meters import AverageMeter, ProgressMeter, accuracy 12 | from torch.utils.data import RandomSampler 13 | from models.resnet import resnet as custom_resnet 14 | from models.inception import inception_v3 as custom_inception 15 | from utils.misc import normalize_module_name, arch2depth 16 | 17 | 18 | class CnnModel(object): 19 | def __init__(self, arch, use_custom_resnet, use_custom_inception, pretrained, dataset, gpu_ids, datapath, batch_size, shuffle, workers, 20 | print_freq, cal_batch_size, cal_set_size, args): 21 | self.arch = arch 22 | self.use_custom_resnet = use_custom_resnet 23 | self.pretrained = pretrained 24 | self.dataset = dataset 25 | self.gpu_ids = gpu_ids 26 | self.datapath = datapath 27 | self.batch_size = batch_size 28 | self.shuffle = shuffle 29 | self.workers = workers 30 | self.print_freq = print_freq 31 | self.cal_batch_size = cal_batch_size 32 | self.cal_set_size = cal_set_size # TODO: pass it as cmd line argument 33 | 34 | # create model 35 | if 'resnet' in arch and use_custom_resnet: 36 | model = custom_resnet(arch=arch, pretrained=pretrained, depth=arch2depth(arch), 37 | dataset=dataset) 38 | elif 'inception_v3' in arch and use_custom_inception: 39 | model = custom_inception(pretrained=pretrained) 40 | else: 41 | print("=> using pre-trained model '{}'".format(arch)) 42 | model = models.__dict__[arch](pretrained=pretrained) 43 | 44 | self.device = torch.device('cuda:{}'.format(gpu_ids[0])) 45 | 46 | torch.cuda.set_device(gpu_ids[0]) 47 | model = model.to(self.device) 48 | 49 | # optionally resume from a checkpoint 50 | if args.resume: 51 | if os.path.isfile(args.resume): 52 | print("=> loading checkpoint '{}'".format(args.resume)) 53 | checkpoint = torch.load(args.resume, 
self.device) 54 | args.start_epoch = checkpoint['epoch'] 55 | checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()} 56 | model.load_state_dict(checkpoint['state_dict'], strict=False) 57 | print("=> loaded checkpoint '{}' (epoch {})" 58 | .format(args.resume, checkpoint['epoch'])) 59 | else: 60 | print("=> no checkpoint found at '{}'".format(args.resume)) 61 | 62 | if len(gpu_ids) > 1: 63 | # DataParallel will divide and allocate batch_size to all available GPUs 64 | if arch.startswith('alexnet') or arch.startswith('vgg'): 65 | model.features = torch.nn.DataParallel(model.features, gpu_ids) 66 | else: 67 | model = torch.nn.DataParallel(model, gpu_ids) 68 | 69 | self.model = model 70 | 71 | if args.bn_folding: 72 | print("Applying batch-norm folding ahead of post-training quantization") 73 | from utils.absorb_bn import search_absorbe_bn 74 | search_absorbe_bn(model) 75 | 76 | # define loss function (criterion) and optimizer 77 | self.criterion = torch.nn.CrossEntropyLoss().to(self.device) 78 | 79 | val_data = get_dataset(dataset, 'val', get_transform(dataset, augment=False, scale_size=299 if 'inception' in arch else None, 80 | input_size=299 if 'inception' in arch else None), 81 | datasets_path=datapath) 82 | self.val_loader = torch.utils.data.DataLoader( 83 | val_data, 84 | batch_size=batch_size, shuffle=shuffle, 85 | num_workers=workers, pin_memory=True) 86 | 87 | self.cal_loader = torch.utils.data.DataLoader( 88 | val_data, 89 | batch_size=self.cal_batch_size, shuffle=shuffle, 90 | num_workers=workers, pin_memory=True) 91 | 92 | @staticmethod 93 | def __arch2depth__(arch): 94 | depth = None 95 | if 'resnet18' in arch: 96 | depth = 18 97 | elif 'resnet34' in arch: 98 | depth = 34 99 | elif 'resnet50' in arch: 100 | depth = 50 101 | elif 'resnet101' in arch: 102 | depth = 101 103 | 104 | return depth 105 | 106 | def evaluate_calibration(self): 107 | # switch to evaluate mode 108 | self.model.eval() 109 | 110 | with torch.no_grad(): 111 | if not hasattr(self, 'cal_set'): 112 | self.cal_set = [] 113 | # TODO: Workaround, refactor this later 114 | for i, (images, target) in enumerate(self.cal_loader): 115 | if i * self.cal_batch_size >= self.cal_set_size: 116 | break 117 | images = images.to(self.device, non_blocking=True) 118 | target = target.to(self.device, non_blocking=True) 119 | self.cal_set.append((images, target)) 120 | 121 | res = torch.tensor([0.]).to(self.device) 122 | for i in range(len(self.cal_set)): 123 | images, target = self.cal_set[i] 124 | # compute output 125 | output = self.model(images) 126 | loss = self.criterion(output, target) 127 | res += loss 128 | 129 | return res / len(self.cal_set) 130 | 131 | def validate(self): 132 | batch_time = AverageMeter('Time', ':6.3f') 133 | losses = AverageMeter('Loss', ':.4e') 134 | top1 = AverageMeter('Acc@1', ':6.2f') 135 | top5 = AverageMeter('Acc@5', ':6.2f') 136 | progress = ProgressMeter(len(self.val_loader), batch_time, losses, top1, top5, 137 | prefix='Test: ') 138 | 139 | # switch to evaluate mode 140 | self.model.eval() 141 | 142 | with torch.no_grad(): 143 | end = time.time() 144 | for i, (images, target) in enumerate(self.val_loader): 145 | images = images.to(self.device, non_blocking=True) 146 | target = target.to(self.device, non_blocking=True) 147 | 148 | # compute output 149 | output = self.model(images) 150 | loss = self.criterion(output, target) 151 | 152 | # measure accuracy and record loss 153 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 154 | 
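# utils.meters.accuracy (called above) is not included in this dump; from its use here it
# presumably follows the standard torchvision ImageNet-example helper, roughly (sketch):
#
#   def accuracy(output, target, topk=(1,)):
#       maxk = max(topk)
#       _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
#       correct = pred.t().eq(target.view(1, -1))                # (maxk, batch) hit mask
#       return [correct[:k].flatten().float().sum() * (100.0 / target.size(0))
#               for k in topk]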
losses.update(loss.item(), images.size(0)) 155 | top1.update(acc1.item(), images.size(0)) 156 | top5.update(acc5.item(), images.size(0)) 157 | 158 | # measure elapsed time 159 | batch_time.update(time.time() - end) 160 | end = time.time() 161 | 162 | if i % self.print_freq == 0: 163 | progress.print(i) 164 | 165 | # TODO: this should also be done with the ProgressMeter 166 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 167 | .format(top1=top1, top5=top5)) 168 | 169 | return top1.avg 170 | -------------------------------------------------------------------------------- /quantization/posttraining/cnn_classifier_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.models as models 18 | import numpy as np 19 | from utils.data import get_dataset 20 | from utils.preprocess import get_transform 21 | from quantization.quantizer import ModelQuantizer 22 | from pathlib import Path 23 | from utils.mllog import MLlogger 24 | from utils.meters import AverageMeter, ProgressMeter, accuracy 25 | from models.resnet import resnet as custom_resnet 26 | from models.inception import inception_v3 as custom_inception 27 | from quantization.posttraining.module_wrapper import ActivationModuleWrapperPost, ParameterModuleWrapperPost 28 | 29 | home = str(Path.home()) 30 | 31 | model_names = sorted(name for name in models.__dict__ 32 | if name.islower() and not name.startswith("__") 33 | and callable(models.__dict__[name])) 34 | 35 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 37 | choices=model_names, 38 | help='model architecture: ' + 39 | ' | '.join(model_names) + 40 | ' (default: resnet18)') 41 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 42 | help='dataset name') 43 | parser.add_argument('--datapath', metavar='DATAPATH', type=str, default=None, 44 | help='dataset folder') 45 | parser.add_argument('-j', '--workers', default=25, type=int, metavar='N', 46 | help='number of data loading workers (default: 4)') 47 | parser.add_argument('-b', '--batch-size', default=256, type=int, 48 | metavar='N', 49 | help='mini-batch size (default: 256), this is the total ' 50 | 'batch size of all GPUs on the current node when ' 51 | 'using Data Parallel or Distributed Data Parallel') 52 | parser.add_argument('-p', '--print-freq', default=10, type=int, 53 | metavar='N', help='print frequency (default: 10)') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 59 | help='use pre-trained model') 60 | parser.add_argument('--custom_resnet', action='store_true', help='use custom resnet implementation') 61 | parser.add_argument('--custom_inception', action='store_true', help='use custom inception implementation') 62 | 63 | 
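# utils.meters.AverageMeter (imported above) is likewise not part of this dump; validate()
# uses it as a running-average tracker, approximately (a sketch, not the repo's exact code):
#
#   class AverageMeter:
#       def __init__(self, name, fmt=':f'):
#           self.name, self.fmt = name, fmt
#           self.val = self.sum = self.count = self.avg = 0
#
#       def update(self, val, n=1):
#           self.val = val
#           self.sum += val * n
#           self.count += n
#           self.avg = self.sum / self.count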
parser.add_argument('--seed', default=12345, type=int, 64 | help='seed for initializing training. ') 65 | parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+', 66 | help='GPU ids to use (e.g 0 1 2 3)') 67 | parser.add_argument('--shuffle', '-sh', action='store_true', help='shuffle data') 68 | 69 | parser.add_argument('--quantize', '-q', action='store_true', help='Enable quantization', default=False) 70 | parser.add_argument('--experiment', '-exp', help='Name of the experiment', default='default') 71 | parser.add_argument('--bit_weights', '-bw', type=int, help='Number of bits for weights', default=None) 72 | parser.add_argument('--bit_act', '-ba', type=int, help='Number of bits for activations', default=None) 73 | parser.add_argument('--pre_relu', dest='pre_relu', action='store_true', help='use pre-ReLU quantization') 74 | parser.add_argument('--qtype', default='aciq_laplace', help='Type of quantization method') 75 | parser.add_argument('-lp', type=float, help='p parameter of Lp norm', default=3.) 76 | parser.add_argument('--bcorr_w', '-bcw', action='store_true', help='Bias correction for weights', default=False) 77 | 78 | best_acc1 = 0 79 | 80 | 81 | def main(): 82 | args = parser.parse_args() 83 | args.post_relu = not args.pre_relu 84 | 85 | if args.seed is not None: 86 | random.seed(args.seed) 87 | np.random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | torch.cuda.manual_seed_all(args.seed) 90 | cudnn.deterministic = True 91 | torch.backends.cudnn.benchmark = False 92 | 93 | with MLlogger(os.path.join(home, 'mxt-sim/mllog_runs'), args.experiment, args, 94 | name_args=[args.arch, args.dataset, "W{}A{}".format(args.bit_weights, args.bit_act)]) as ml_logger: 95 | main_worker(args, ml_logger) 96 | 97 | 98 | def arch2depth(arch): 99 | depth = None 100 | if 'resnet18' in arch: 101 | depth = 18 102 | elif 'resnet34' in arch: 103 | depth = 34 104 | elif 'resnet50' in arch: 105 | depth = 50 106 | elif 'resnet101' in arch: 107 | depth = 101 108 | 109 | return depth 110 | 111 | 112 | def main_worker(args, ml_logger): 113 | global best_acc1 114 | 115 | if args.gpu_ids is not None: 116 | print("Use GPU: {} for training".format(args.gpu_ids)) 117 | 118 | # create model 119 | if 'resnet' in args.arch and args.custom_resnet: 120 | model = custom_resnet(arch=args.arch, pretrained=args.pretrained, depth=arch2depth(args.arch), dataset=args.dataset) 121 | elif 'inception_v3' in args.arch and args.custom_inception: 122 | model = custom_inception(pretrained=args.pretrained) 123 | 124 | elif args.pretrained: 125 | print("=> using pre-trained model '{}'".format(args.arch)) 126 | model = models.__dict__[args.arch](pretrained=True) 127 | else: 128 | print("=> creating model '{}'".format(args.arch)) 129 | model = models.__dict__[args.arch]() 130 | 131 | device = torch.device('cuda:{}'.format(args.gpu_ids[0])) 132 | cudnn.benchmark = True 133 | 134 | torch.cuda.set_device(args.gpu_ids[0]) 135 | model = model.to(device) 136 | 137 | # optionally resume from a checkpoint 138 | if args.resume: 139 | if os.path.isfile(args.resume): 140 | # mq = ModelQuantizer(model, args) 141 | print("=> loading checkpoint '{}'".format(args.resume)) 142 | checkpoint = torch.load(args.resume, device) 143 | args.start_epoch = checkpoint['epoch'] 144 | best_acc1 = checkpoint['best_acc1'] 145 | # best_acc1 may be from a checkpoint from a different GPU 146 | # best_acc1 = best_acc1.to(device) 147 | model.load_state_dict(checkpoint['state_dict']) 148 | # optimizer.load_state_dict(checkpoint['optimizer']) 149 | print("=> 
loaded checkpoint '{}' (epoch {})" 150 | .format(args.resume, checkpoint['epoch'])) 151 | else: 152 | print("=> no checkpoint found at '{}'".format(args.resume)) 153 | 154 | if len(args.gpu_ids) > 1: 155 | # DataParallel will divide and allocate batch_size to all available GPUs 156 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 157 | model.features = torch.nn.DataParallel(model.features, args.gpu_ids) 158 | else: 159 | model = torch.nn.DataParallel(model, args.gpu_ids) 160 | 161 | val_data = get_dataset(args.dataset, 'val', get_transform(args.dataset, augment=False, scale_size = 299 if 'inception' in args.arch else None, 162 | input_size = 299 if 'inception' in args.arch else None), datasets_path=args.datapath) 163 | val_loader = torch.utils.data.DataLoader( 164 | val_data, 165 | batch_size=args.batch_size, shuffle=args.shuffle, 166 | num_workers=args.workers, pin_memory=True) 167 | 168 | # define loss function (criterion) and optimizer 169 | criterion = nn.CrossEntropyLoss().to(device) 170 | if 'inception' in args.arch and args.custom_inception: 171 | first = 3 172 | last = -1 173 | else: 174 | first = 1 175 | last = -1 176 | if args.quantize: 177 | all_convs = [n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)] 178 | all_relu = [n for n, m in model.named_modules() if isinstance(m, nn.ReLU)] 179 | all_relu6 = [n for n, m in model.named_modules() if isinstance(m, nn.ReLU6)] 180 | layers = all_relu[first:last] + all_relu6[first:last] + all_convs[first:last] 181 | replacement_factory = {nn.ReLU: ActivationModuleWrapperPost, 182 | nn.ReLU6: ActivationModuleWrapperPost, 183 | nn.Conv2d: ParameterModuleWrapperPost} 184 | mq = ModelQuantizer(model, args, layers, replacement_factory) 185 | mq.log_quantizer_state(ml_logger, -1) 186 | 187 | acc = validate(val_loader, model, criterion, args, device) 188 | ml_logger.log_metric('Val Acc1', acc, step='auto') 189 | 190 | 191 | def validate(val_loader, model, criterion, args, device): 192 | batch_time = AverageMeter('Time', ':6.3f') 193 | losses = AverageMeter('Loss', ':.4e') 194 | top1 = AverageMeter('Acc@1', ':6.2f') 195 | top5 = AverageMeter('Acc@5', ':6.2f') 196 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 197 | prefix='Test: ') 198 | 199 | # switch to evaluate mode 200 | model.eval() 201 | 202 | with torch.no_grad(): 203 | end = time.time() 204 | for i, (images, target) in enumerate(val_loader): 205 | images = images.to(device, non_blocking=True) 206 | target = target.to(device, non_blocking=True) 207 | 208 | # compute output 209 | output = model(images) 210 | loss = criterion(output, target) 211 | 212 | # measure accuracy and record loss 213 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 214 | losses.update(loss.item(), images.size(0)) 215 | top1.update(acc1.item(), images.size(0)) 216 | top5.update(acc5.item(), images.size(0)) 217 | 218 | # measure elapsed time 219 | batch_time.update(time.time() - end) 220 | end = time.time() 221 | 222 | if i % args.print_freq == 0: 223 | progress.print(i) 224 | 225 | # TODO: this should also be done with the ProgressMeter 226 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 227 | .format(top1=top1, top5=top5)) 228 | 229 | return top1.avg 230 | 231 | 232 | if __name__ == '__main__': 233 | main() 234 | -------------------------------------------------------------------------------- /quantization/posttraining/module_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
torch.nn as nn 3 | import torch 4 | import torch.nn as nn 5 | 6 | from quantization.methods.clipped_uniform import AngDistanceQuantization, L3NormQuantization, L2NormQuantization, \ 7 | LpNormQuantization, L1NormQuantization 8 | from quantization.methods.clipped_uniform import MaxAbsStaticQuantization, AciqLaplaceQuantization, \ 9 | AciqGausQuantization, LogLikeQuantization 10 | from quantization.methods.clipped_uniform import MseNoPriorQuantization, MseUniformPriorQuantization 11 | from quantization.methods.non_uniform import KmeansQuantization 12 | 13 | quantization_mapping = {'max_static': MaxAbsStaticQuantization, 14 | 'aciq_laplace': AciqLaplaceQuantization, 15 | 'aciq_gaus': AciqGausQuantization, 16 | 'mse_uniform_prior': MseUniformPriorQuantization, 17 | 'mse_no_prior': MseNoPriorQuantization, 18 | 'ang_dis': AngDistanceQuantization, 19 | 'l3_norm': L3NormQuantization, 20 | 'l2_norm': L2NormQuantization, 21 | 'l1_norm': L1NormQuantization, 22 | 'lp_norm': LpNormQuantization, 23 | 'log_like': LogLikeQuantization 24 | } 25 | 26 | 27 | def is_positive(module): 28 | return isinstance(module, nn.ReLU) or isinstance(module, nn.ReLU6) 29 | 30 | 31 | class ActivationModuleWrapperPost(nn.Module): 32 | def __init__(self, name, wrapped_module, **kwargs): 33 | super(ActivationModuleWrapperPost, self).__init__() 34 | self.name = name 35 | self.wrapped_module = wrapped_module 36 | self.bits_out = kwargs['bits_out'] 37 | self.qtype = kwargs['qtype'] 38 | self.post_relu = True 39 | self.enabled = True 40 | self.active = True 41 | 42 | if self.bits_out is not None: 43 | self.out_quantization = self.out_quantization_default = None 44 | 45 | def __init_out_quantization__(tensor): 46 | self.out_quantization_default = quantization_mapping[self.qtype](self, tensor, self.bits_out, 47 | symmetric=(not is_positive(wrapped_module)), 48 | uint=True, kwargs=kwargs) 49 | self.out_quantization = self.out_quantization_default 50 | print("ActivationModuleWrapperPost - {} | {} | {}".format(self.name, str(self.out_quantization), str(tensor.device))) 51 | 52 | self.out_quantization_init_fn = __init_out_quantization__ 53 | 54 | def __enabled__(self): 55 | return self.enabled and self.active and self.bits_out is not None 56 | 57 | def forward(self, *input): 58 | # Uncomment to enable dump 59 | # torch.save(*input, os.path.join('dump', self.name + '_in' + '.pt')) 60 | 61 | if self.post_relu: 62 | out = self.wrapped_module(*input) 63 | 64 | # Quantize output 65 | if self.__enabled__(): 66 | self.verify_initialized(self.out_quantization, out, self.out_quantization_init_fn) 67 | out = self.out_quantization(out) 68 | else: 69 | # Quantize output 70 | if self.__enabled__(): 71 | self.verify_initialized(self.out_quantization, *input, self.out_quantization_init_fn) 72 | out = self.out_quantization(*input) 73 | else: 74 | out = self.wrapped_module(*input) 75 | 76 | # Uncomment to enable dump 77 | # torch.save(out, os.path.join('dump', self.name + '_out' + '.pt')) 78 | 79 | return out 80 | 81 | def get_quantization(self): 82 | return self.out_quantization 83 | 84 | def set_quantization(self, qtype, kwargs, verbose=False): 85 | self.out_quantization = qtype(self, self.bits_out, symmetric=(not is_positive(self.wrapped_module)), 86 | uint=True, kwargs=kwargs) 87 | if verbose: 88 | print("ActivationModuleWrapperPost - {} | {} | {}".format(self.name, str(self.out_quantization), 89 | str(kwargs['device']))) 90 | 91 | def set_quant_method(self, method=None): 92 | if self.bits_out is not None: 93 | if method == 'kmeans': 94 | 
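# The activation wrapper above builds its quantizer lazily: out_quantization starts as None
# and verify_initialized() constructs it from the first batch seen in forward(), because
# data-dependent methods such as 'max_static' need real activation statistics before they
# can pick a range. The pattern reduced to its core (a sketch, not this repo's API):
#
#   class LazyQuant(nn.Module):
#       def __init__(self, make_quant):
#           super().__init__()
#           self.make_quant = make_quant
#           self.quant = None
#
#       def forward(self, x):
#           if self.quant is None:        # first batch fixes the quantization range
#               self.quant = self.make_quant(x)
#           return self.quant(x)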
self.out_quantization = KmeansQuantization(self.bits_out) 95 | else: 96 | self.out_quantization = self.out_quantization_default 97 | 98 | @staticmethod 99 | def verify_initialized(quantization_handle, tensor, init_fn): 100 | if quantization_handle is None: 101 | init_fn(tensor) 102 | 103 | def log_state(self, step, ml_logger): 104 | if self.__enabled__(): 105 | if self.out_quantization is not None: 106 | for n, p in self.out_quantization.named_parameters(): 107 | if p.numel() == 1: 108 | ml_logger.log_metric(self.name + '.' + n, p.item(), step='auto') 109 | else: 110 | for i, e in enumerate(p): 111 | ml_logger.log_metric(self.name + '.' + n + '.' + str(i), e.item(), step='auto') 112 | 113 | 114 | class ParameterModuleWrapperPost(nn.Module): 115 | def __init__(self, name, wrapped_module, **kwargs): 116 | super(ParameterModuleWrapperPost, self).__init__() 117 | self.name = name 118 | self.wrapped_module = wrapped_module 119 | self.forward_functor = kwargs['forward_functor'] 120 | self.bit_weights = kwargs['bits_weight'] 121 | self.bits_out = kwargs['bits_out'] 122 | self.qtype = kwargs['qtype'] 123 | self.enabled = True 124 | self.active = True 125 | self.centroids_hist = {} 126 | self.log_weights_hist = False 127 | self.log_weights_mse = False 128 | self.log_clustering = False 129 | self.bn = kwargs['bn'] if 'bn' in kwargs else None 130 | self.dynamic_weight_quantization = True 131 | self.bcorr_w = kwargs['bcorr_w'] 132 | 133 | setattr(self, 'weight', wrapped_module.weight) 134 | delattr(wrapped_module, 'weight') 135 | if hasattr(wrapped_module, 'bias'): 136 | setattr(self, 'bias', wrapped_module.bias) 137 | delattr(wrapped_module, 'bias') 138 | 139 | if self.bit_weights is not None: 140 | self.weight_quantization_default = quantization_mapping[self.qtype](self, self.weight, self.bit_weights, 141 | symmetric=True, uint=True, kwargs=kwargs) 142 | self.weight_quantization = self.weight_quantization_default 143 | if not self.dynamic_weight_quantization: 144 | self.weight_q = self.weight_quantization(self.weight) 145 | self.weight_mse = torch.mean((self.weight_q - self.weight)**2).item() 146 | print("ParameterModuleWrapperPost - {} | {} | {}".format(self.name, str(self.weight_quantization), 147 | str(self.weight.device))) 148 | 149 | def __enabled__(self): 150 | return self.enabled and self.active and self.bit_weights is not None 151 | 152 | def bias_corr(self, x, xq): 153 | bias_q = xq.view(xq.shape[0], -1).mean(-1) 154 | bias_orig = x.view(x.shape[0], -1).mean(-1) 155 | bcorr = bias_q - bias_orig 156 | 157 | return xq - bcorr.view(bcorr.numel(), 1, 1, 1) if len(x.shape) == 4 else xq - bcorr.view(bcorr.numel(), 1) 158 | 159 | def forward(self, *input): 160 | w = self.weight 161 | if self.__enabled__(): 162 | # Quantize weights 163 | if self.dynamic_weight_quantization: 164 | w = self.weight_quantization(self.weight) 165 | 166 | if self.bcorr_w: 167 | w = self.bias_corr(self.weight, w) 168 | else: 169 | w = self.weight_q 170 | 171 | out = self.forward_functor(*input, weight=w, bias=(self.bias if hasattr(self, 'bias') else None)) 172 | 173 | return out 174 | 175 | def get_quantization(self): 176 | return self.weight_quantization 177 | 178 | def set_quantization(self, qtype, kwargs, verbose=False): 179 | self.weight_quantization = qtype(self, self.bit_weights, symmetric=True, uint=True, kwargs=kwargs) 180 | if verbose: 181 | print("ParameterModuleWrapperPost - {} | {} | {}".format(self.name, str(self.weight_quantization), 182 | str(kwargs['device']))) 183 | 184 | def set_quant_method(self, 
method=None): 185 | if self.bit_weights is not None: 186 | if method is None: 187 | self.weight_quantization = self.weight_quantization_default 188 | elif method == 'kmeans': 189 | self.weight_quantization = KmeansQuantization(self.bit_weights) 190 | else: 191 | self.weight_quantization = self.weight_quantization_default 192 | 193 | # TODO: make it more generic 194 | def set_quant_mode(self, mode=None): 195 | if self.bit_weights is not None: 196 | if mode is not None: 197 | self.soft = self.weight_quantization.soft_quant 198 | self.hard = self.weight_quantization.hard_quant 199 | if mode is None: 200 | self.weight_quantization.soft_quant = self.soft 201 | self.weight_quantization.hard_quant = self.hard 202 | elif mode == 'soft': 203 | self.weight_quantization.soft_quant = True 204 | self.weight_quantization.hard_quant = False 205 | elif mode == 'hard': 206 | self.weight_quantization.soft_quant = False 207 | self.weight_quantization.hard_quant = True 208 | 209 | def log_state(self, step, ml_logger): 210 | if self.__enabled__(): 211 | if self.weight_quantization is not None: 212 | for n, p in self.weight_quantization.loggable_parameters(): 213 | if p.numel() == 1: 214 | ml_logger.log_metric(self.name + '.' + n, p.item(), step='auto') 215 | else: 216 | for i, e in enumerate(p): 217 | ml_logger.log_metric(self.name + '.' + n + '.' + str(i), e.item(), step='auto') 218 | 219 | if self.log_weights_hist: 220 | ml_logger.tf_logger.add_histogram(self.name + '.weight', self.weight.cpu().flatten(), step='auto') 221 | 222 | if self.log_weights_mse: 223 | ml_logger.log_metric(self.name + '.mse_q', self.weight_mse, step='auto') 224 | -------------------------------------------------------------------------------- /quantization/qat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/quantization/qat/__init__.py -------------------------------------------------------------------------------- /quantization/quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from itertools import count 5 | from quantization.methods.clipped_uniform import LearnedStepSizeQuantization 6 | from quantization.methods.non_uniform import LearnableDifferentiableQuantization, LearnedCentroidsQuantization 7 | from quantization.methods.clipped_uniform import FixedClipValueQuantization 8 | # from utils.absorb_bn import is_absorbing, is_bn 9 | 10 | 11 | class Conv2dFunctor: 12 | def __init__(self, conv2d): 13 | self.conv2d = conv2d 14 | 15 | def __call__(self, *input, weight, bias): 16 | res = torch.nn.functional.conv2d(*input, weight, bias, self.conv2d.stride, self.conv2d.padding, 17 | self.conv2d.dilation, self.conv2d.groups) 18 | return res 19 | 20 | 21 | class LinearFunctor: 22 | def __init__(self, linear): 23 | self.linear = linear 24 | 25 | def __call__(self, *input, weight, bias): 26 | res = torch.nn.functional.linear(*input, weight, bias) 27 | return res 28 | 29 | 30 | class EmbeddingFunctor: 31 | def __init__(self, embedding): 32 | self.embedding = embedding 33 | 34 | def __call__(self, *input, weight, bias=None): 35 | res = torch.nn.functional.embedding( 36 | *input, weight, self.embedding.padding_idx, self.embedding.max_norm, 37 | self.embedding.norm_type, self.embedding.scale_grad_by_freq, self.embedding.sparse) 38 | return res 39 | 40 | 41 | # class 
QuantizationScheduler(object): 42 | # _iter_counter = count(0) 43 | # 44 | # def __init__(self, model, optimizer, grad_rate, enable=True): 45 | # self.quantizations = [] 46 | # self.optimizer = optimizer 47 | # self.grad_rate = grad_rate 48 | # self.scheduling_enabled = enable 49 | # 50 | # model.register_forward_hook(lambda m, inp, out: self.step(m)) 51 | # 52 | # def register_module_quantization(self, qwrapper): 53 | # self.quantizations.append(qwrapper) 54 | # if len(self.quantizations) == 1 or not self.scheduling_enabled: 55 | # qwrapper.enabled = True 56 | # else: 57 | # qwrapper.enabled = False 58 | # 59 | # def step(self, model): 60 | # if model.training: 61 | # step = next(QuantizationScheduler._iter_counter) 62 | # 63 | # if self.scheduling_enabled and model.training: 64 | # if step % self.grad_rate == 0: 65 | # i = int(step / self.grad_rate) 66 | # if i < len(self.quantizations): 67 | # self.quantizations[i].enabled = True 68 | 69 | 70 | class OptimizerBridge(object): 71 | def __init__(self, optimizer, settings={'algo': 'SGD', 'dataset': 'imagenet'}): 72 | self.optimizer = optimizer 73 | self.settings = settings 74 | 75 | def add_quantization_params(self, all_quant_params): 76 | key = self.settings['algo'] + '_' + self.settings['dataset'] 77 | if key in all_quant_params: 78 | quant_params = all_quant_params[key] 79 | for group in quant_params: 80 | self.optimizer.add_param_group(group) 81 | 82 | 83 | class ModelQuantizer: 84 | def __init__(self, model, args, quantizable_layers, replacement_factory, optimizer_bridge=None): 85 | self.model = model 86 | self.args = args 87 | self.bit_weights = args.bit_weights 88 | self.bit_act = args.bit_act 89 | self.post_relu = True 90 | self.functor_map = {nn.Conv2d: Conv2dFunctor, nn.Linear: LinearFunctor, nn.Embedding: EmbeddingFunctor} 91 | self.replacement_factory = replacement_factory 92 | 93 | self.optimizer_bridge = optimizer_bridge 94 | 95 | self.quantization_wrappers = [] 96 | self.quantizable_modules = [] 97 | self.quantizable_layers = quantizable_layers 98 | self._pre_process_container(model) 99 | self._create_quantization_wrappers() 100 | 101 | # TODO: hack, make it generic 102 | self.quantization_params = LearnedStepSizeQuantization.learned_parameters() 103 | 104 | def load_state_dict(self, state_dict): 105 | for name, qwrapper in self.quantization_wrappers: 106 | qwrapper.load_state_dict(state_dict) 107 | 108 | def freeze(self): 109 | for n, p in self.model.named_parameters(): 110 | # TODO: hack, make it more robust 111 | if not np.any([qp in n for qp in self.quantization_params]): 112 | p.requires_grad = False 113 | 114 | # for n, p in self.model.named_parameters(): 115 | # if not ('conv' in n or 'downsample.0' in n): 116 | # p.requires_grad = False 117 | 118 | # for n, p in self.model.named_parameters(): 119 | # if not ('bn' in n or 'downsample.1' in n): 120 | # p.requires_grad = False 121 | 122 | # for n, m in self.model.named_modules(): 123 | # if isinstance(m, nn.BatchNorm2d): 124 | # m.momentum = 0 125 | 126 | @staticmethod 127 | def has_children(module): 128 | try: 129 | next(module.children()) 130 | return True 131 | except StopIteration: 132 | return False 133 | 134 | def _create_quantization_wrappers(self): 135 | for qm in self.quantizable_modules: 136 | # replace module by it's wrapper 137 | fn = self.functor_map[type(qm.module)](qm.module) if type(qm.module) in self.functor_map else None 138 | args = {"bits_out": self.bit_act, "bits_weight": self.bit_weights, "forward_functor": fn, 139 | "post_relu": self.post_relu, 
"optim_bridge": self.optimizer_bridge} 140 | args.update(vars(self.args)) 141 | if hasattr(qm, 'bn'): 142 | args['bn'] = qm.bn 143 | module_wrapper = self.replacement_factory[type(qm.module)](qm.full_name, qm.module, 144 | **args) 145 | setattr(qm.container, qm.name, module_wrapper) 146 | self.quantization_wrappers.append((qm.full_name, module_wrapper)) 147 | 148 | def _pre_process_container(self, container, prefix=''): 149 | prev, prev_name = None, None 150 | for name, module in container.named_children(): 151 | # if is_bn(module) and is_absorbing(prev) and prev_name in self.quantizable_layers: 152 | # # Pass BN module to prev module quantization wrapper for BN folding/unfolding 153 | # self.quantizable_modules[-1].bn = module 154 | 155 | full_name = prefix + name 156 | if full_name in self.quantizable_layers: 157 | self.quantizable_modules.append( 158 | type('', (object,), {'name': name, 'full_name': full_name, 'module': module, 'container': container})() 159 | ) 160 | 161 | if self.has_children(module): 162 | # For container we call recursively 163 | self._pre_process_container(module, full_name + '.') 164 | 165 | prev = module 166 | prev_name = full_name 167 | 168 | def log_quantizer_state(self, ml_logger, step): 169 | if self.bit_weights is not None or self.bit_act is not None: 170 | with torch.no_grad(): 171 | for name, qwrapper in self.quantization_wrappers: 172 | qwrapper.log_state(step, ml_logger) 173 | 174 | def get_qwrappers(self): 175 | return [qwrapper for (name, qwrapper) in self.quantization_wrappers if qwrapper.__enabled__()] 176 | 177 | def set_clipping(self, clipping, device): # TODO: handle device internally somehow 178 | qwrappers = self.get_qwrappers() 179 | for i, qwrapper in enumerate(qwrappers): 180 | qwrapper.set_quantization(FixedClipValueQuantization, 181 | {'clip_value': clipping[i], 'device': device}) 182 | 183 | def get_clipping(self): 184 | clipping = [] 185 | qwrappers = self.get_qwrappers() 186 | for i, qwrapper in enumerate(qwrappers): 187 | q = qwrapper.get_quantization() 188 | clip_value = getattr(q, 'alpha') 189 | clipping.append(clip_value.item()) 190 | 191 | return qwrappers[0].get_quantization().alpha.new_tensor(clipping) 192 | 193 | class QuantMethod: 194 | def __init__(self, quantization_wrappers, method): 195 | self.quantization_wrappers = quantization_wrappers 196 | self.method = method 197 | 198 | def __enter__(self): 199 | for n, qw in self.quantization_wrappers: 200 | qw.set_quant_method(self.method) 201 | 202 | def __exit__(self, exc_type, exc_val, exc_tb): 203 | for n, qw in self.quantization_wrappers: 204 | qw.set_quant_method() 205 | 206 | class QuantMode: 207 | def __init__(self, quantization_wrappers, mode): 208 | self.quantization_wrappers = quantization_wrappers 209 | self.mode = mode 210 | 211 | def __enter__(self): 212 | for n, qw in self.quantization_wrappers: 213 | qw.set_quant_mode(self.mode) 214 | 215 | def __exit__(self, exc_type, exc_val, exc_tb): 216 | for n, qw in self.quantization_wrappers: 217 | qw.set_quant_mode() 218 | 219 | class DisableQuantizer: 220 | def __init__(self, quantization_wrappers): 221 | self.quantization_wrappers = quantization_wrappers 222 | 223 | def __enter__(self): 224 | for n, qw in self.quantization_wrappers: 225 | qw.active = False 226 | 227 | def __exit__(self, exc_type, exc_val, exc_tb): 228 | for n, qw in self.quantization_wrappers: 229 | qw.active = True 230 | 231 | def quantization_method(self, method): 232 | return ModelQuantizer.QuantMethod(self.quantization_wrappers, method) 233 | 234 | def 
quantization_mode(self, mode): 235 | return ModelQuantizer.QuantMode(self.quantization_wrappers, mode) 236 | 237 | def disable(self): 238 | return ModelQuantizer.DisableQuantizer(self.quantization_wrappers) 239 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynahshan/nn-quantization-pytorch/12cd3bf4b5430128483220a2521998c7a7ae9bd1/utils/__init__.py -------------------------------------------------------------------------------- /utils/absorb_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def absorb_bn(module, bn_module): # fold BN statistics and affine params into the preceding conv/linear 6 | w = module.weight.data 7 | if module.bias is None: 8 | zeros = torch.Tensor(module.out_channels).zero_().type(w.type()) 9 | module.bias = nn.Parameter(zeros) 10 | b = module.bias.data 11 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5) 12 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) 13 | b.add_(-bn_module.running_mean).mul_(invstd)
14 | 15 | if bn_module.affine: 16 | w.mul_(bn_module.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) 17 | b.mul_(bn_module.weight.data).add_(bn_module.bias.data) 18 | 19 | bn_module.register_buffer('running_mean', torch.zeros(module.out_channels, device=w.device)) # allocate on the conv weight's device rather than assuming CUDA 20 | bn_module.register_buffer('running_var', torch.ones(module.out_channels, device=w.device)) 21 | bn_module.register_parameter('weight', None) 22 | bn_module.register_parameter('bias', None) 23 | bn_module.affine = False 24 | 25 | 26 | def is_bn(m): 27 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 28 | 29 | 30 | def is_absorbing(m): 31 | return (isinstance(m, nn.Conv2d) and m.groups == 1) or isinstance(m, nn.Linear) 32 | 33 | 34 | def search_absorbe_bn(model): 35 | prev = None 36 | for m in model.children(): 37 | if is_bn(m) and is_absorbing(prev): 38 | m.absorbed = True 39 | absorb_bn(prev, m) 40 | search_absorbe_bn(m) 41 | prev = m 42 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | 4 | __DATASETS_DEFAULT_PATH = '/tmp/Datasets/' 5 | 6 | from pathlib import Path 7 | home = str(Path.home()) 8 | __IMAGENET_DEFAULT_PATH = '/home/cvds_lab/datasets/ILSVRC2012/' 9 | 10 | def get_dataset(name, split='train', transform=None, 11 | target_transform=None, download=True, 12 | datasets_path=None): 13 | train = (split == 'train') 14 | root = os.path.join(datasets_path if datasets_path is not None else __DATASETS_DEFAULT_PATH, name) 15 | if name == 'cifar10': 16 | return datasets.CIFAR10(root=root, 17 | train=train, 18 | transform=transform, 19 | target_transform=target_transform, 20 | download=download) 21 | elif name == 'cifar100': 22 | return datasets.CIFAR100(root=root, 23 | train=train, 24 | transform=transform, 25 | target_transform=target_transform, 26 | download=download) 27 | elif name == 'mnist': 28 | return datasets.MNIST(root=root, 29 | train=train, 30 | transform=transform, 31 | target_transform=target_transform, 32 | download=download) 33 | elif name == 'stl10': 34 | return datasets.STL10(root=root, 35 | split=split, 36 | transform=transform, 37 | target_transform=target_transform, 38 | download=download) 39 | elif name == 'imagenet': 40 | if datasets_path is None: 41 | datasets_path = __IMAGENET_DEFAULT_PATH 42 | if train: 43 | root = os.path.join(datasets_path, 'train') 44 | else: 45 | root = os.path.join(datasets_path, 'val') 46 | return datasets.ImageFolder(root=root, 47 | transform=transform, 48 | target_transform=target_transform) 49 |
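Usage sketch for utils/data.py (illustrative, not part of the repository): get_dataset pairs naturally with get_transform from utils/preprocess.py to build an evaluation dataset; when datasets_path is omitted, the ImageNet branch falls back to __IMAGENET_DEFAULT_PATH.

from utils.data import get_dataset
from utils.preprocess import get_transform

# evaluation pipeline: resize to 256, center-crop to 224, normalize
val_transform = get_transform('imagenet', augment=False)
val_data = get_dataset('imagenet', split='val', transform=val_transform)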
-------------------------------------------------------------------------------- /utils/entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def shannon_entropy(t, handle_negative=False): # handle_negative is accepted but currently unused 5 | # workaround for out of memory issue 6 | torch.cuda.empty_cache() 7 | 8 | pk = torch.unique(t.flatten(), return_counts=True)[1] 9 | 10 | probs = pk.float() / pk.sum() 11 | probs[probs == 0] = 1 # counts from torch.unique are >= 1, so this log2(0) guard never triggers; kept for safety 12 | entropy = -probs * torch.log2(probs) 13 | res = entropy.sum() 14 | 15 | return res 16 | -------------------------------------------------------------------------------- /utils/experiments_log.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | class ExperimentsLogger(object): 5 | """Stub experiments logger; the interface is declared but not implemented yet.""" 6 | 7 | def log_metric(self, name, value): 8 | raise NotImplementedError # placeholder until a concrete logger is implemented
-------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from itertools import cycle 4 | import torch 5 | import logging.config 6 | from datetime import datetime 7 | import json 8 | 9 | import pandas as pd 10 | from bokeh.io import output_file, save, show 11 | from bokeh.plotting import figure 12 | from bokeh.layouts import column 13 | from bokeh.models import Div 14 | 15 | try: 16 | import hyperdash 17 | HYPERDASH_AVAILABLE = True 18 | except ImportError: 19 | HYPERDASH_AVAILABLE = False 20 | 21 | 22 | def export_args_namespace(args, filename): 23 | """ 24 | args: argparse.Namespace 25 | arguments to save 26 | filename: string 27 | filename to save at 28 | """ 29 | with open(filename, 'w') as fp: 30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) 31 | 32 | class logfile_filter: 33 | def filter(self, record): 34 | return record.levelname == 'DEBUG' 35 | 36 | def setup_logging(log_file='log.txt', resume=False): 37 | """ 38 | Setup logging configuration 39 | """ 40 | if os.path.isfile(log_file) and resume: 41 | file_mode = 'a' 42 | else: 43 | file_mode = 'w' 44 | 45 | root_logger = logging.getLogger() 46 | if root_logger.handlers: 47 | root_logger.removeHandler(root_logger.handlers[0]) 48 | logging.basicConfig(level=logging.DEBUG, 49 | format="%(asctime)s - %(levelname)s - %(message)s", 50 | datefmt="%Y-%m-%d %H:%M:%S", 51 | filename=log_file, 52 | filemode=file_mode) 53 | console = logging.StreamHandler() 54 | console.setLevel(logging.INFO) 55 | formatter = logging.Formatter('%(message)s') 56 | console.setFormatter(formatter) 57 | logging.getLogger('').addHandler(console) 58 | 59 | handler = logging.FileHandler(os.path.join(os.path.dirname(log_file), "quantizer-debug.log"), "w") 60 | handler.setLevel(logging.DEBUG) 61 | formatter = logging.Formatter("%(message)s") 62 | handler.setFormatter(formatter) 63 | handler.addFilter(logfile_filter()) 64 | logging.getLogger('').addHandler(handler) 65 | 66 | 67 | class ResultsLog(object): 68 | 69 | supported_data_formats = ['csv', 'json'] 70 | 71 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'): 72 | """ 73 | Parameters 74 | ----------
75 | path: string 76 | path prefix for the data files (path + '.csv' or path + '.json') 77 | plot_path: string 78 | derived, not an argument: the HTML plot file, path + '.html' 79 | title: string 80 | title of HTML file 81 | params: Namespace 82 | optionally save parameters for results 83 | resume: bool 84 | resume previous logging 85 | data_format: str('csv'|'json') 86 | which file format to use to save the data 87 | """ 88 | if data_format not in ResultsLog.supported_data_formats: 89 | raise ValueError('data_format must be one of the following: ' + 90 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats])) 91 | 92 | if data_format == 'json': 93 | self.data_path = '{}.json'.format(path) 94 | else: 95 | self.data_path = '{}.csv'.format(path) 96 | if params is not None: 97 | export_args_namespace(params, '{}.json'.format(path)) 98 | self.plot_path = '{}.html'.format(path) 99 | self.results = None 100 | self.clear() 101 | self.first_save = True 102 | if os.path.isfile(self.data_path): 103 | if resume: 104 | self.load(self.data_path) 105 | self.first_save = False 106 | else: 107 | os.remove(self.data_path) 108 | self.results = pd.DataFrame() 109 | else: 110 | self.results = pd.DataFrame() 111 | 112 | self.title = title 113 | self.data_format = data_format 114 | 115 | if HYPERDASH_AVAILABLE: 116 | name = self.title if title != '' else path 117 | self.hd_experiment = hyperdash.Experiment(name) 118 | if params is not None: 119 | for k, v in params._get_kwargs(): 120 | self.hd_experiment.param(k, v, log=False)
121 | 122 | def clear(self): 123 | self.figures = [] 124 | 125 | def add(self, **kwargs): 126 | """Add a new row to the dataframe 127 | example: 128 | resultsLog.add(epoch=epoch_num, train_loss=loss, 129 | test_loss=test_loss) 130 | """ 131 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 132 | self.results = self.results.append(df, ignore_index=True) # note: DataFrame.append was removed in pandas 2.0; use pd.concat on newer pandas 133 | if hasattr(self, 'hd_experiment'): 134 | for k, v in kwargs.items(): 135 | self.hd_experiment.metric(k, v, log=False) 136 | 137 | def smooth(self, column_name, window): 138 | """Smooth a column's values over a rolling window""" 139 | # TODO: smooth only new data 140 | smoothed_column = self.results[column_name].rolling( 141 | window=window, center=False).mean() 142 | self.results[column_name + '_smoothed'] = smoothed_column 143 | 144 | def save(self, title=None): 145 | """save the data file and the HTML plot. 146 | Parameters 147 | ---------- 148 | title: string 149 | title of the HTML file 150 | """ 151 | title = title or self.title 152 | if len(self.figures) > 0: 153 | if os.path.isfile(self.plot_path): 154 | os.remove(self.plot_path) 155 | if self.first_save: 156 | self.first_save = False 157 | logging.info('Plot file saved at: {}'.format( 158 | os.path.abspath(self.plot_path))) 159 | 160 | output_file(self.plot_path, title=title) 161 | plot = column( 162 | Div(text='
<h1 align="center">{}</h1>
'.format(title)), *self.figures) 163 | save(plot) 164 | self.clear() 165 | 166 | if self.data_format == 'json': 167 | self.results.to_json(self.data_path, orient='records', lines=True) 168 | else: 169 | self.results.to_csv(self.data_path, index=False, index_label=False) 170 | 171 | def load(self, path=None): 172 | """load the data file 173 | Parameters 174 | ---------- 175 | path: 176 | path to load the json|csv file from 177 | """ 178 | path = path or self.data_path 179 | if os.path.isfile(path): 180 | if self.data_format == 'json': 181 | self.results = pd.read_json(path, orient='records', lines=True) # pandas read functions are module-level and return a new DataFrame 182 | else: 183 | self.results = pd.read_csv(path) 184 | else: 185 | raise ValueError("{} isn't a file".format(path)) 186 | 187 | def show(self, title=None): 188 | title = title or self.title 189 | if len(self.figures) > 0: 190 | plot = column( 191 | Div(text='
<h1 align="center">{}</h1>
'.format(title)), *self.figures) 192 | show(plot) 193 | 194 | def plot(self, x, y, title=None, xlabel=None, ylabel=None, legend=None, 195 | width=800, height=400, line_width=2, 196 | colors=['red', 'green', 'blue', 'orange', 197 | 'black', 'purple', 'brown'], 198 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save'): 199 | """ 200 | add a new plot to the HTML file 201 | example: 202 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 203 | title='Loss', ylabel='loss') 204 | """ 205 | if not isinstance(y, list): 206 | y = [y] 207 | xlabel = xlabel or x 208 | legend = legend or y 209 | assert len(legend) == len(y) 210 | f = figure(title=title, tools=tools, 211 | width=width, height=height, 212 | x_axis_label=xlabel or x, 213 | y_axis_label=ylabel or '') 214 | colors = cycle(colors) 215 | for i, yi in enumerate(y): 216 | f.line(self.results[x], self.results[yi], 217 | line_width=line_width, 218 | line_color=next(colors), legend=legend[i]) 219 | f.legend.click_policy = "hide" 220 | self.figures.append(f) 221 | 222 | def image(self, *kargs, **kwargs): 223 | fig = figure() 224 | fig.image(*kargs, **kwargs) 225 | self.figures.append(fig) 226 | 227 | def end(self): 228 | if hasattr(self, 'hd_experiment'): 229 | self.hd_experiment.end() 230 | 231 |
232 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 233 | filename = os.path.join(path, filename) 234 | torch.save(state, filename) 235 | if is_best: 236 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 237 | if save_all: 238 | shutil.copyfile(filename, os.path.join( 239 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 240 | 241 | class EvalLog: 242 | def __init__(self, headers, f_name=None, auto_save=False): 243 | if auto_save and f_name is None: 244 | raise Exception('auto_save option requires a file name') 245 | 246 | dir_name = os.path.dirname(f_name) if f_name is not None else '' # f_name may be None when auto_save is off 247 | if dir_name and not os.path.exists(dir_name): 248 | os.makedirs(dir_name) 249 | self.df = pd.DataFrame(columns=headers) 250 | self.file_name = f_name 251 | self.auto_save = auto_save 252 | 253 | def log(self, *kargs): 254 | # append one row to the dataframe; positional args must follow 255 | # the column order given by headers 256 | row = list(kargs) 257 | self.df.loc[len(self.df)] = row 258 | if self.auto_save: 259 | self.df.to_csv(self.file_name, index=False) 260 | 261 | def save(self, fpath): 262 | if not self.auto_save: 263 | self.df.to_csv(fpath, index=False) 264 | 265 | def __str__(self): 266 | return self.df.__str__() -------------------------------------------------------------------------------- /utils/mark_relu.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import Bottleneck, BasicBlock 2 | from torch.nn.parallel.data_parallel import DataParallel 3 | 4 | 5 | def mark_bottlenetck_before_relu(model): 6 | for m in model.children(): 7 | if isinstance(m, Bottleneck): 8 | m.conv1.before_relu = True 9 | m.bn1.before_relu = True 10 | m.conv2.before_relu = True 11 | m.bn2.before_relu = True 12 | else: 13 | mark_bottlenetck_before_relu(m) 14 | 15 | 16 | def mark_basicblock_before_relu(model): 17 | for m in model.children(): 18 | if isinstance(m, BasicBlock): 19 | m.conv1.before_relu = True 20 | m.bn1.before_relu = True 21 | else: 22 | mark_basicblock_before_relu(m) 23 | 24 | 25 | def resnet_mark_before_relu(model): 26 | if isinstance(model, DataParallel): 27 | model.module.conv1.before_relu = True 28 | else: 29 | model.conv1.before_relu = True 30 | 31 |
mark_bottlenetck_before_relu(model) 32 | mark_basicblock_before_relu(model) 33 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ProgressMeter(object): 5 | def __init__(self, num_batches, *meters, prefix=""): 6 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 7 | self.meters = meters 8 | self.prefix = prefix 9 | 10 | def print(self, batch): 11 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 12 | entries += [str(meter) for meter in self.meters] 13 | print('\t'.join(entries)) 14 | 15 | def _get_batch_fmtstr(self, num_batches): 16 | num_digits = len(str(num_batches // 1)) 17 | fmt = '{:' + str(num_digits) + 'd}' 18 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 19 | 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value""" 23 | def __init__(self, name='', fmt=':f'): 24 | self.name = name 25 | self.fmt = fmt 26 | self.reset() 27 | 28 | def reset(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | self.avg = self.sum / self.count 39 | 40 | def __str__(self): 41 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 42 | return fmtstr.format(**self.__dict__) 43 | 44 | 45 | class OnlineMeter(object): 46 | """Computes and stores the average and variance/std values of tensor""" 47 | 48 | def __init__(self): 49 | self.mean = torch.FloatTensor(1).fill_(-1) 50 | self.M2 = torch.FloatTensor(1).zero_() 51 | self.count = 0. 52 | self.needs_init = True 53 | 54 | def reset(self, x): 55 | self.mean = x.new(x.size()).zero_() 56 | self.M2 = x.new(x.size()).zero_() 57 | self.count = 0. 
58 | self.needs_init = False 59 | 60 | def update(self, x): 61 | self.val = x 62 | if self.needs_init: 63 | self.reset(x) 64 | self.count += 1 # Welford's online update of mean and M2 65 | delta = x - self.mean 66 | self.mean.add_(delta / self.count) 67 | delta2 = x - self.mean 68 | self.M2.add_(delta * delta2) 69 | 70 | @property 71 | def var(self): 72 | if self.count < 2: 73 | return self.M2.clone().zero_() 74 | return self.M2 / (self.count - 1) 75 | 76 | @property 77 | def std(self): 78 | return self.var.sqrt() # var is a property, not a method, so it must not be called 79 | 80 |
81 | def accuracy(output, target, topk=(1,)): 82 | """Computes the accuracy over the k top predictions for the specified values of k""" 83 | with torch.no_grad(): 84 | maxk = max(topk) 85 | batch_size = target.size(0) 86 | 87 | _, pred = output.topk(maxk, 1, True, True) 88 | pred = pred.t() 89 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 90 | 91 | res = [] 92 | for k in topk: 93 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 94 | res.append(correct_k.mul_(100.0 / batch_size)) 95 | return res 96 | 97 | 98 | class AccuracyMeter(object): 99 | """Computes and stores the average and current topk accuracy""" 100 | 101 | def __init__(self, topk=(1,)): 102 | self.topk = topk 103 | self.reset() 104 | 105 | def reset(self): 106 | self._meters = {} 107 | for k in self.topk: 108 | self._meters[k] = AverageMeter() 109 | 110 | def update(self, output, target): 111 | n = target.nelement() 112 | acc_vals = accuracy(output, target, self.topk) 113 | for i, k in enumerate(self.topk): 114 | self._meters[k].update(acc_vals[i], n) # weight by batch size so uneven batches average correctly 115 | 116 | @property 117 | def val(self): 118 | return {n: meter.val for (n, meter) in self._meters.items()} 119 | 120 | @property 121 | def avg(self): 122 | return {n: meter.avg for (n, meter) in self._meters.items()} 123 | 124 | @property 125 | def avg_error(self): 126 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()} 127 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def arch2depth(arch): 7 | depth = None 8 | if 'resnet18' in arch: 9 | depth = 18 10 | elif 'resnet34' in arch: 11 | depth = 34 12 | elif 'resnet50' in arch: 13 | depth = 50 14 | elif 'resnet101' in arch: 15 | depth = 101 16 | 17 | return depth 18 | 19 | 20 | torch_dtypes = { 21 | 'float': torch.float, 22 | 'float32': torch.float32, 23 | 'float64': torch.float64, 24 | 'double': torch.double, 25 | 'float16': torch.float16, 26 | 'half': torch.half, 27 | 'uint8': torch.uint8, 28 | 'int8': torch.int8, 29 | 'int16': torch.int16, 30 | 'short': torch.short, 31 | 'int32': torch.int32, 32 | 'int': torch.int, 33 | 'int64': torch.int64, 34 | 'long': torch.long 35 | } 36 | 37 | 38 | def normalize_module_name(layer_name): 39 | """Normalize a module's name. 40 | 41 | PyTorch lets you parallelize the computation of a model by wrapping it with a 42 | DataParallel module. Unfortunately, this changes the fully-qualified name of a module, 43 | even though the actual functionality of the module doesn't change. 44 | Many times, when we search for modules by name, we are indifferent to the DataParallel 45 | module and want to use the same module name whether the module is parallel or not. 46 | We call this module name normalization, and this is implemented here.
47 | """ 48 | modules = layer_name.split('.') 49 | try: 50 | idx = modules.index('module') 51 | except ValueError: 52 | return layer_name 53 | del modules[idx] 54 | return '.'.join(modules) 55 | 56 | 57 | def expand_shape(base_shape, target_shape): 58 | d = len(target_shape) - len(base_shape) 59 | for i in range(d): 60 | base_shape += torch.Size([1]) 61 | return base_shape 62 | 63 | 64 | def cos_sim(x, y, dims=[-1]): 65 | dot = x*y 66 | for d in dims: 67 | dot = torch.sum(dot, dim=d) 68 | norm_x = x**2 69 | for d in dims: 70 | norm_x = torch.sqrt(torch.sum(norm_x, dim=d)) 71 | norm_y = y ** 2 72 | for d in dims: 73 | norm_y = torch.sqrt(torch.sum(norm_y, dim=d)) 74 | 75 | return dot / (norm_x * norm_y) 76 | 77 |
78 | def onehot(indexes, N=None, ignore_index=None): 79 | """ 80 | Creates a one-hot representation of indexes with N possible entries 81 | if N is not specified, it will suit the maximum index appearing. 82 | indexes is a long-tensor of indexes 83 | ignore_index will be zero in onehot representation 84 | """ 85 | if N is None: 86 | N = indexes.max() + 1 87 | sz = list(indexes.size()) 88 | output = indexes.new().byte().resize_(*sz, N).zero_() 89 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 90 | if ignore_index is not None and ignore_index >= 0: 91 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 92 | return output 93 | 94 | 95 | def set_global_seeds(i): 96 | try: 97 | import torch 98 | except ImportError: 99 | pass 100 | else: 101 | torch.manual_seed(i) 102 | if torch.cuda.is_available(): 103 | torch.cuda.manual_seed_all(i) 104 | np.random.seed(i) 105 | random.seed(i) 106 | 107 | # The following is for monitoring 108 | class Singleton(type): 109 | _instances = {} 110 | 111 | def __call__(cls, *args, **kwargs): 112 | if cls not in cls._instances: 113 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 114 | return cls._instances[cls] 115 | 116 | 117 | import re 118 | 119 | 120 | def sorted_nicely(l): 121 | """ Sorts the given iterable in natural (human) order, e.g. 'layer2' before 'layer10'. 122 | 123 | Required arguments: 124 | l -- The iterable to be sorted.
125 | 126 | """ 127 | convert = lambda text: int(text) if text.isdigit() else text 128 | alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] 129 | return sorted(l, key=alphanum_key) 130 | -------------------------------------------------------------------------------- /utils/mllog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mlflow 3 | from tensorboardX import SummaryWriter 4 | from itertools import count 5 | from utils.meters import AverageMeter 6 | 7 | 8 | class MLlogger: 9 | def __init__(self, log_dir, experiment_name, args=None, name_args=None): 10 | self.log_dir = log_dir 11 | self.args = vars(args) if args is not None else {} # guard: args may be None 12 | self.name_args = name_args if name_args is not None else [] 13 | 14 | mlflow.set_tracking_uri(log_dir) 15 | mlflow.set_experiment(experiment_name) 16 | 17 | self.auto_steps = {} 18 | self.metters = {} 19 | 20 | def __enter__(self): 21 | self.mlflow = mlflow 22 | 23 | name = '_'.join(self.name_args) if len(self.name_args) > 0 else 'run1' 24 | self.run = mlflow.start_run(run_name=name) 25 | self.run_loc = os.path.join(self.log_dir, self.run.info.experiment_id, self.run.info.run_uuid) 26 | # Save tensorboard events to artifacts directory 27 | self.tf_logger = SummaryWriter(os.path.join(self.run_loc, 'artifacts', "events")) 28 | 29 | self.mlflow.set_tag('Tensor board', 'tensorboard --logdir={} --port={} --samples_per_plugin images=0'.format(self.mlflow.get_artifact_uri(), 9999)) 30 | 31 | for key, value in self.args.items(): 32 | self.mlflow.log_param(key, value) 33 | 34 | return self 35 | 36 | def __exit__(self, exc_type, exc_val, exc_tb): 37 | self.mlflow.end_run() 38 |
39 | def log_metric(self, key, value, step=None, log_to_tfboard=False, meterId=None, weight=1.): 40 | if meterId not in self.metters: 41 | self.metters[meterId] = AverageMeter() 42 | 43 | if step is not None and type(step) is str and step == 'auto': 44 | if key not in self.auto_steps: 45 | self.auto_steps[key] = count(0) 46 | step = next(self.auto_steps[key]) 47 | self.mlflow.log_metric(key, value, step) 48 | else: 49 | self.mlflow.log_metric(key, value, step=step) 50 | if log_to_tfboard: 51 | self.tf_logger.add_scalar(key, value, step) 52 | 53 | if meterId is not None: 54 | self.metters[meterId].update(value, weight) 55 | self.mlflow.log_metric(meterId, self.metters[meterId].avg, step) 56 | -------------------------------------------------------------------------------- /utils/model_naming.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel.data_parallel import DataParallel 2 | 3 | 4 | def module_type_to_string(m): 5 | return (str(type(m)).replace('>', '').replace('\'', '').split('.')[-1]).replace('WithId', '') 6 | 7 | 8 | def set_node_name_recurcive(parent, parent_name, ldict=None): 9 | has_children = False 10 | for m in parent.named_children(): 11 | has_children = True 12 | t = module_type_to_string(m[1]) 13 | m_name = parent_name + '/' + t + '[' + m[0] + ']' 14 | set_node_name_recurcive(m[1], m_name, ldict=ldict) 15 | 16 | if not has_children: 17 | parent.internal_name = parent_name 18 | if ldict is not None: 19 | ldict[parent_name] = parent 20 | # print(parent_name) 21 | 22 | 23 | def set_node_names(model, format='tensorboard', create_ldict=None): 24 | # Currently only tensorboard format supported 25 | ldict = {} if create_ldict else None 26 | set_node_name_recurcive(model, module_type_to_string(model), ldict=ldict) 27 | if create_ldict: 28 | return ldict 29 |
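Usage sketch for utils/model_naming.py (illustrative, not part of the repository): set_node_names tags every leaf module with a tensorboard-style internal_name and, with create_ldict=True, returns a name-to-module dict.

import torchvision.models as models
from utils.model_naming import set_node_names

model = models.resnet18()
ldict = set_node_names(model, create_ldict=True)
print(list(ldict)[0])  # e.g. 'ResNet/Conv2d[conv1]'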
-------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | 5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 6 | 'std': [0.229, 0.224, 0.225]} 7 | 8 | __imagenet_pca = { 9 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 10 | 'eigvec': torch.Tensor([ 11 | [-0.5675, 0.7192, 0.4009], 12 | [-0.5808, -0.0045, -0.8140], 13 | [-0.5836, -0.6948, 0.4203], 14 | ]) 15 | } 16 | 17 | 18 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 19 | t_list = [ 20 | transforms.CenterCrop(input_size), 21 | transforms.ToTensor(), 22 | transforms.Normalize(**normalize), 23 | ] 24 | # if scale_size != input_size: 25 | t_list = [transforms.Resize(scale_size)] + t_list 26 | 27 | return transforms.Compose(t_list) 28 | 29 | 30 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 31 | t_list = [ 32 | transforms.RandomCrop(input_size), 33 | transforms.ToTensor(), 34 | transforms.Normalize(**normalize), 35 | ] 36 | if scale_size != input_size: 37 | t_list = [transforms.Resize(scale_size)] + t_list 38 | 39 | return transforms.Compose(t_list) # the return was missing, leaving callers with None 40 | 41 |
42 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 43 | padding = int((scale_size - input_size) / 2) 44 | return transforms.Compose([ 45 | transforms.RandomCrop(input_size, padding=padding), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | transforms.Normalize(**normalize), 49 | ]) 50 | 51 | 52 | def inception_preproccess(input_size, normalize=__imagenet_stats): 53 | return transforms.Compose([ 54 | transforms.RandomResizedCrop(input_size), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | transforms.Normalize(**normalize) 58 | ]) 59 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 60 | return transforms.Compose([ 61 | transforms.RandomResizedCrop(input_size), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | ColorJitter( 65 | brightness=0.4, 66 | contrast=0.4, 67 | saturation=0.4, 68 | ), 69 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 70 | transforms.Normalize(**normalize) 71 | ]) 72 | 73 | 74 | def get_transform(name='imagenet', input_size=None, 75 | scale_size=None, normalize=None, augment=True): 76 | normalize = normalize or __imagenet_stats 77 | if name == 'imagenet': 78 | scale_size = scale_size or 256 79 | input_size = input_size or 224 80 | if augment: 81 | return inception_preproccess(input_size, normalize=normalize) 82 | else: 83 | return scale_crop(input_size=input_size, 84 | scale_size=scale_size, normalize=normalize) 85 | elif 'cifar' in name: 86 | input_size = input_size or 32 87 | if augment: 88 | scale_size = scale_size or 40 89 | return pad_random_crop(input_size, scale_size=scale_size, 90 | normalize=normalize) 91 | else: 92 | scale_size = scale_size or 32 93 | return scale_crop(input_size=input_size, 94 | scale_size=scale_size, normalize=normalize) 95 | elif name == 'mnist': 96 | normalize = {'mean': [0.5], 'std': [0.5]} 97 | input_size = input_size or 28 98 | if augment: 99 | scale_size = scale_size or 32 100 | return pad_random_crop(input_size, scale_size=scale_size, 101 | normalize=normalize) 102 | else: 103 | scale_size = scale_size or 32 104 | return scale_crop(input_size=input_size, 105 | scale_size=scale_size, normalize=normalize) 106 |
107 | 108 | class Lighting(object): 109 | """Lighting noise (AlexNet-style PCA-based noise)""" 110 | 111 | def __init__(self, alphastd, eigval, eigvec): 112 | self.alphastd = alphastd 113 | self.eigval = eigval 114 | self.eigvec = eigvec 115 | 116 | def __call__(self, img): 117 | if self.alphastd == 0: 118 | return img 119 | 120 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 121 | rgb = self.eigvec.type_as(img).clone()\ 122 | .mul(alpha.view(1, 3).expand(3, 3))\ 123 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 124 | .sum(1).squeeze() 125 | 126 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 127 | 128 | 129 | class Grayscale(object): 130 | 131 | def __call__(self, img): 132 | gs = img.clone() 133 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) # legacy torch add_(alpha, tensor) signature 134 | gs[1].copy_(gs[0]) 135 | gs[2].copy_(gs[0]) 136 | return gs 137 | 138 |
139 | class Saturation(object): 140 | 141 | def __init__(self, var): 142 | self.var = var 143 | 144 | def __call__(self, img): 145 | gs = Grayscale()(img) 146 | alpha = random.uniform(0, self.var) 147 | return img.lerp(gs, alpha) 148 | 149 | 150 | class Brightness(object): 151 | 152 | def __init__(self, var): 153 | self.var = var 154 | 155 | def __call__(self, img): 156 | gs = img.new().resize_as_(img).zero_() 157 | alpha = random.uniform(0, self.var) 158 | return img.lerp(gs, alpha) 159 | 160 | 161 | class Contrast(object): 162 | 163 | def __init__(self, var): 164 | self.var = var 165 | 166 | def __call__(self, img): 167 | gs = Grayscale()(img) 168 | gs.fill_(gs.mean()) 169 | alpha = random.uniform(0, self.var) 170 | return img.lerp(gs, alpha) 171 | 172 | 173 | class RandomOrder(object): 174 | """ Composes several transforms together in random order. 175 | """ 176 | 177 | def __init__(self, transforms): 178 | self.transforms = transforms 179 | 180 | def __call__(self, img): 181 | if self.transforms is None: 182 | return img 183 | order = torch.randperm(len(self.transforms)) 184 | for i in order: 185 | img = self.transforms[i](img) 186 | return img 187 | 188 | 189 | class ColorJitter(RandomOrder): 190 | 191 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 192 | self.transforms = [] 193 | if brightness != 0: 194 | self.transforms.append(Brightness(brightness)) 195 | if contrast != 0: 196 | self.transforms.append(Contrast(contrast)) 197 | if saturation != 0: 198 | self.transforms.append(Saturation(saturation)) 199 | -------------------------------------------------------------------------------- /utils/stats_trucker.py: -------------------------------------------------------------------------------- 1 | from utils.misc import Singleton 2 | import numpy as np 3 | import pandas as pd 4 | import signal 5 | import atexit 6 | import os 7 | import shutil 8 | from utils.misc import sorted_nicely 9 | import torch 10 | from pathlib import Path 11 | import pickle 12 | home = str(Path.home()) 13 | 14 | 15 | def list_mean(l): 16 | if len(l) == 0: 17 | return None 18 | 19 | res = l[0].clone() 20 | for t in l[1:]: 21 | res += t 22 | return res / len(l) 23 | 24 | 25 | def exit(trucker): # note: shadows the builtin exit(); invoked via atexit with the trucker instance 26 | trucker.__exit__() 27 | 28 | 29 | class StatsTrucker(metaclass=Singleton): 30 | def __init__(self, sufix): 31 | self.folder = 'mxt-sim/stats' 32 | self.fname = 'stats_{}.pkl'.format(sufix) 33 | self.mode = 'mean' 34 | self.stats = {} 35 | self.exited = False 36 | 37 | signal.signal(signal.SIGINT, lambda signum, frame: exit(self)) # signal handlers receive (signum, frame); passing exit directly would raise a TypeError 38 | signal.signal(signal.SIGTERM, lambda signum, frame: exit(self)) 39 | atexit.register(exit, self) 40 | 41 | def add(self, stat_name, id, value): 42 | if stat_name not in self.stats: 43 
| self.stats[stat_name] = {} 44 | if id not in self.stats[stat_name]: 45 | self.stats[stat_name][id] = [value] 46 | else: 47 | self.stats[stat_name][id].append(value) 48 | 49 | def get_stats(self): 50 | stats = {} 51 | for s in self.stats: 52 | stats[s] = [list_mean(self.stats[s][k]).mean().item() for k in self.stats[s]] 53 | 54 | return stats 55 | 56 | def __exit__(self, *args): 57 | if self.exited: 58 | return 59 | print("Saving stats.") 60 | # Save measures 61 | location = os.path.join(home, self.folder) 62 | # if os.path.exists(location): 63 | # shutil.rmtree(location) 64 | if not os.path.exists(location): 65 | os.makedirs(location) 66 | f = open(os.path.join(location, self.fname), 'wb') 67 | pickle.dump(self.stats, f) 68 | f.close() 69 | self.exited = True 70 | --------------------------------------------------------------------------------
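Usage sketch for utils/stats_trucker.py (illustrative, not part of the repository): StatsTrucker uses the Singleton metaclass from utils/misc.py, so repeated constructions return the same instance; get_stats() averages the accumulated tensors per statistic, and everything is pickled to ~/mxt-sim/stats/stats_<sufix>.pkl when the process exits.

import torch
from utils.stats_trucker import StatsTrucker

trucker = StatsTrucker('W4A4')          # first call creates the singleton
assert trucker is StatsTrucker('W4A4')  # later calls return the same object
for _ in range(3):
    trucker.add('mse', 'layer0', torch.rand(16))  # accumulate per-layer tensors
print(trucker.get_stats())  # e.g. {'mse': [mean over the accumulated tensors]}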