├── .gitignore ├── LICENSE ├── README.md ├── experiments ├── __init__.py ├── average_results.py ├── presentation │ ├── __init__.py │ ├── plot_candlesticks.py │ ├── plot_continous.py │ ├── plot_locally.sh │ └── plot_settings.py ├── run_all_float.sh ├── run_all_quant.sh ├── scripts │ ├── pointwise │ │ ├── __init__.py │ │ ├── float │ │ │ ├── __init__.py │ │ │ ├── pointwise_cifar.py │ │ │ ├── pointwise_mnist.py │ │ │ └── pointwise_regression.py │ │ └── quantised │ │ │ ├── __init__.py │ │ │ └── train │ │ │ ├── __init__.py │ │ │ ├── pointwise_cifar.py │ │ │ ├── pointwise_mnist.py │ │ │ └── pointwise_regression.py │ └── stochastic │ │ ├── __init__.py │ │ ├── bbb │ │ ├── float │ │ │ ├── __init__.py │ │ │ ├── bbb_cifar.py │ │ │ ├── bbb_mnist.py │ │ │ └── bbb_regression.py │ │ └── quantised │ │ │ └── train │ │ │ ├── __init__.py │ │ │ ├── bbb_cifar.py │ │ │ ├── bbb_mnist.py │ │ │ └── bbb_regression.py │ │ ├── mcdropout │ │ ├── float │ │ │ ├── mcdropout_cifar.py │ │ │ ├── mcdropout_mnist.py │ │ │ └── mcdropout_regression.py │ │ └── quantised │ │ │ └── train │ │ │ ├── __init__.py │ │ │ ├── mcdropout_cifar.py │ │ │ ├── mcdropout_mnist.py │ │ │ └── mcdropout_regression.py │ │ └── sgld │ │ ├── __init__.py │ │ ├── float │ │ ├── __init__.py │ │ ├── sgld_cifar.py │ │ ├── sgld_mnist.py │ │ └── sgld_regression.py │ │ └── quantised │ │ └── train │ │ ├── __init__.py │ │ ├── sgld_cifar.py │ │ ├── sgld_mnist.py │ │ └── sgld_regression.py └── utils.py ├── poster.pdf ├── requirements.txt ├── src ├── __init__.py ├── data.py ├── losses.py ├── metrics.py ├── models │ ├── __init__.py │ ├── pointwise │ │ ├── __init__.py │ │ └── models_p.py │ └── stochastic │ │ ├── __init__.py │ │ ├── bbb │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── linear.py │ │ ├── models_bbb.py │ │ ├── quantized │ │ │ ├── __init__.py │ │ │ ├── conv_q.py │ │ │ ├── conv_qat.py │ │ │ ├── linear_q.py │ │ │ └── linear_qat.py │ │ └── utils_bbb.py │ │ ├── mcdropout │ │ ├── __init__.py │ │ ├── dropout.py │ │ └── models_mc.py │ │ └── sgld │ │ ├── __init__.py │ │ ├── models_sgld.py │ │ └── utils_sgld.py ├── quant_utils.py ├── trainer.py └── utils.py └── tests ├── __init__.py ├── plot_datasets.py └── plot_distortions.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | experiments/scripts/data/** 131 | **not_q** 132 | **q-** 133 | **qat-** 134 | experiments/data/** 135 | **/summary/** 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Martin Ferianc 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/average_results.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import numpy as np 4 | import logging 5 | 6 | sys.path.append("../") 7 | sys.path.append("../../") 8 | sys.path.append("../../../") 9 | sys.path.append("../../../../") 10 | sys.path.append("../../../../../") 11 | sys.path.append("../../../../../../") 12 | sys.path.append("../../../../../../../") 13 | sys.path.append("../../../../../../../../") 14 | 15 | import src.utils as utils 16 | 17 | parser = argparse.ArgumentParser("average_results") 18 | 19 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 20 | parser.add_argument('--result_paths', nargs='+', default='EXP', help='paths to the result folders to be averaged') 21 | parser.add_argument('--label', type=str, default='', help='experiment category label') 22 | 23 | parser.add_argument('--seed', type=int, default=1, help='random seed') 24 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 25 | 26 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 27 | 28 | 29 | def get_dict_path(dictionary, path=[]): 30 | for key, value in dictionary.items(): 31 | if type(value) is dict: 32 | return get_dict_path(dictionary[key], path+[key]) 33 | return path+[key] 34 | return path 35 | 36 | 37 | def get_dict_value(dictionary, path=[], delete=True): 38 | if len(path)==1: 39 | val = dictionary[path[0]] 40 | if delete: 41 | dictionary.pop(path[0]) 42 | return val 43 | else: 44 | return get_dict_value(dictionary[path[0]], path[1:], delete) 45 | 46 | def set_dict_value(dictionary, value, path=[]): 47 | if len(path)==1: 48 | dictionary[path[0]] = value 49 | else: 50 | if not path[0] in dictionary: 51 | dictionary[path[0]] = {} 52 | set_dict_value(dictionary[path[0]], value, path[1:]) 53 | 54 | def main(): 55 | args = parser.parse_args() 56 | 57 | args, _ = utils.parse_args(args, args.label) 58 | logging.info('# Beginning analysis #') 59 | 60 | final_results = utils.load_pickle(args.save+"/results.pickle") 61 | 62 | logging.info('## Loading of result pickles for the experiment ##') 63 | 64 | results = [] 65 | if len(args.result_paths)==1: 66 | args.result_paths = args.result_paths[0].split(" ") 67 | for result_path in args.result_paths: 68 | result = utils.load_pickle(result_path+"/results.pickle") 69 | logging.info('### Loading result: {} ###'.format(result)) 70 | 71 | results.append(result) 72 | 73 | assert len(results)>1 74 | 75 | final_results['dataset'] = results[0]['dataset'] 76 |
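# The traversal below repeatedly takes one leaf path from the first seed's
# result dictionary, gathers the value stored at that path across all seeds,
# and writes a (mean, std) summary back into final_results; get_dict_value
# pops every visited leaf, so the loop terminates once results[0] is emptied.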
final_results['model'] = results[0]['model'] 77 | 78 | traversing_result = results[0] 79 | while len(get_dict_path(traversing_result))!=0: 80 | path = get_dict_path(traversing_result) 81 | values = [] 82 | mean = None 83 | std = None 84 | for result in results: 85 | val = get_dict_value(result, path) 86 | if not isinstance(val, dict): 87 | values.append(val) 88 | 89 | if len(values) == 0 or type(values[0]) == str: 90 | continue 91 | 92 | if type(values[0]) == tuple: 93 | _values = [] 94 | for i in range(len(values)): 95 | try: 96 | count, val = values[i] 97 | _values.append(val) 98 | except (TypeError, ValueError): 99 | _values.append(values[i]) 100 | values = _values 101 | 102 | values = np.array(values) 103 | mean = np.nanmean(values) 104 | std = np.nanstd(values) 105 | set_dict_value(final_results, (mean, std), path) 106 | 107 | logging.info('## Results: {} ##'.format(final_results)) 108 | utils.save_pickle(final_results, args.save+"/results.pickle", True) 109 | 110 | 111 | 112 | logging.info('# Finished #') 113 | 114 | 115 | if __name__ == '__main__': 116 | main() -------------------------------------------------------------------------------- /experiments/presentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/presentation/__init__.py -------------------------------------------------------------------------------- /experiments/presentation/plot_continous.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import numpy as np 4 | import logging 5 | 6 | sys.path.append("../") 7 | sys.path.append("../../") 8 | sys.path.append("../../../") 9 | sys.path.append("../../../../") 10 | sys.path.append("../../../../../") 11 | sys.path.append("../../../../../../") 12 | 13 | from experiments.presentation.plot_settings import PLT as plt 14 | from matplotlib.ticker import MaxNLocator 15 | import src.utils as utils 16 | from experiments.utils import METRICS_UNITS, RELEVANT_COMBINATIONS, REGRESSION_DATASETS 17 | 18 | parser = argparse.ArgumentParser("compare_ood_results") 19 | 20 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 21 | 22 | parser.add_argument('--pointwise_paths', nargs='+', 23 | default=[], help='paths to pointwise result folders') 24 | parser.add_argument('--mcd_paths', nargs='+', 25 | default=[], help='paths to MC Dropout result folders') 26 | parser.add_argument('--bbb_paths', nargs='+', 27 | default=[], help='paths to Bayes-by-Backprop result folders') 28 | parser.add_argument('--sgld_paths', nargs='+', 29 | default=[], help='paths to SGLD result folders') 30 | parser.add_argument('--task', type=str, 31 | default='classification', help='experiment task') 32 | parser.add_argument('--label', type=str, default='', 33 | help='experiment category label') 34 | 35 | parser.add_argument('--seed', type=int, default=1, help='random seed') 36 | parser.add_argument('--debug', action='store_true', 37 | help='whether we are currently debugging') 38 | parser.add_argument('--weight', action='store_true', 39 | default=False, help='whether to sweep the weight precision instead of the activation precision') 40 | 41 | parser.add_argument('--gpu', type=int, default=0, help='gpu device ids') 42 | parser.add_argument('--q', action='store_true', default=False, 43 | help='whether to do post training quantisation') 44 | parser.add_argument('--at', action='store_true', default=False, 45 | help='whether to do quantisation-aware training') 46 | 47 | def main(): 48 | args = parser.parse_args()
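# Bit-widths plotted on the x-axis: 32 denotes the float baseline; with
# --weight the weight precision is swept at fixed 7-bit activations,
# otherwise the activation precision is swept at fixed 8-bit weights
# (see QUANT below).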
49 | args, _ = utils.parse_args(args, args.label) 50 | labels = ['Pointwise', 'MCD', 'BBB', 'SGHMC'] 51 | QUANT= None 52 | if args.weight: 53 | QUANT = [32, 8, 7, 6, 5, 4, 3] 54 | else: 55 | QUANT = [32, 7, 6, 5, 4, 3] 56 | for i, combination in enumerate(RELEVANT_COMBINATIONS[args.task]): 57 | logging.info('## Loading of result pickles for the experiment ##') 58 | fig = plt.figure(figsize=(5,1.75)) 59 | plt.grid(True) 60 | for j, paths in enumerate([args.pointwise_paths, args.mcd_paths, args.bbb_paths, args.sgld_paths]): 61 | if len(paths)==0: 62 | continue 63 | data = [] 64 | for k, path in enumerate(paths): 65 | result = utils.load_pickle(path+"/results.pickle") 66 | logging.info('### Loading result: {} ###'.format(result)) 67 | if args.task=='classification': 68 | data.append(result[combination[0]][combination[1]]) 69 | elif combination[1]=='uci' and args.task=='regression': 70 | mean = [] 71 | for dataset, _ in REGRESSION_DATASETS[1:]: 72 | val = result[combination[0] 73 | ]['regression_'+dataset]['test'][0] 74 | if utils.isoutlier(val): 75 | continue 76 | mean.append(val) 77 | 78 | data.append([np.mean(mean), np.std(mean)]) 79 | if combination[0] == "nll": 80 | data[-1][0]*=-1 81 | elif combination[1]=='synthetic' and args.task=='regression': 82 | d = list(result[combination[0]]["regression_synthetic"]['test']) 83 | if utils.isoutlier(d[0]): 84 | continue 85 | data.append(d) 86 | if combination[0] == "nll": 87 | data[-1][0]*=-1 88 | 89 | positions = np.array([x for x in range(len(data))]) 90 | mean = [d[0] for d in data] 91 | stds = [d[1] for d in data] 92 | plt.plot(positions, mean, 93 | color="C"+str(j), alpha=0.7, label=labels[j]) 94 | plt.errorbar(positions, mean, yerr=stds, fmt='o' if j != 0 else 'v', capsize=10, 95 | color="C"+str(j), alpha=0.7) 96 | 97 | ax = fig.gca() 98 | ax.xaxis.set_major_locator(MaxNLocator(integer=False)) 99 | positions = np.array([k for k in range(len(QUANT))]) 100 | ax.spines['top'].set_visible(False) 101 | ax.spines['right'].set_visible(False) 102 | if args.weight: 103 | ticks = ['$Float_{32}$' if j == 32 else '$Q:A_7\\boldsymbol{W_'+str(j)+"}$" for j in QUANT] 104 | else: 105 | ticks = ['$Float_{32}$' if j == 32 else '$Q:\\boldsymbol{A_'+str(j)+"}W_8$" for j in QUANT] 106 | plt.tick_params(axis="x", which="both", bottom=False) 107 | plt.xticks(ticks=positions, labels=ticks) 108 | if combination[0] == "error" and "mnist" in args.label and not args.weight: 109 | plt.ylim(0, 10) 110 | elif combination[0] == "error" and "cifar" in args.label and not args.weight: 111 | plt.ylim(0, 22.5) 112 | 113 | ax.legend(loc='upper left') 114 | plt.xlabel('Bit-width \& Precision') 115 | plt.ylabel(METRICS_UNITS[combination[0] + 116 | "_regression" if args.task == "regression" else combination[0]]) 117 | plt.tight_layout() 118 | path = utils.check_path( 119 | args.save+'/{}_{}_{}.pdf'.format(combination[0], combination[1], "weight" if args.weight else "activation")) 120 | box = ax.get_position() 121 | ax.set_position([box.x0, box.y0 - box.height * 0.08, 122 | box.width, box.height * 0.92]) 123 | ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=4) 124 | plt.savefig(path) 125 | 126 | if __name__ == '__main__': 127 | main() 128 | 129 | -------------------------------------------------------------------------------- /experiments/presentation/plot_settings.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib as mlp 3 | import brewer2mpl 4 | 5 | LINESTYLES = ['solid', 
'dashed', 'dotted', 'dashdot'] 6 | bmap = brewer2mpl.get_map('Set1', 'qualitative', 9) 7 | 8 | COLORS = bmap.mpl_colors 9 | 10 | params = { 11 | 'figure.dpi':150, 12 | 'axes.labelsize': 10, 13 | 'font.size': 8, 14 | 'legend.fontsize': 10, 15 | 'legend.fancybox': False, 16 | 'legend.framealpha': .5, 17 | 'legend.frameon':False, 18 | 'xtick.labelsize': 10, 19 | 'ytick.labelsize': 10, 20 | 'text.usetex': True, 21 | 'figure.figsize': [7.5, 5.5], 22 | 'mathtext.default':'regular', 23 | } 24 | 25 | mlp.rcParams.update(params) 26 | mlp.rcParams['text.latex.preamble']=r"\usepackage{amsmath}" 27 | 28 | MLP = mlp 29 | PLT = plt -------------------------------------------------------------------------------- /experiments/run_all_float.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SAMPLES=3 4 | ret_val="" 5 | GPU=$1 6 | function run_experiments { 7 | declare -a samples_list=() 8 | for i in $(seq 1 $SAMPLES); 9 | do 10 | local sample=$(python3 $1 --seed $i --gpu $GPU | grep Experiment | cut -d " " -f 4) 11 | echo $sample 12 | samples_list+=("../${sample}") 13 | done 14 | echo "${samples_list[*]}" 15 | if [ ! -d "./summary" ]; then 16 | mkdir summary 17 | fi 18 | cd ./summary/ 19 | result=$(python3 ../$2 --result_paths "${samples_list[*]}" --label $3 | grep Experiment | cut -d " " -f 4) 20 | ret_val=$result 21 | cd ../ 22 | } 23 | 24 | function run_q_experiments { 25 | local float_samples=(./../../float/default/*"$4"*) 26 | echo "${float_samples[*]}" 27 | declare -a samples_list=() 28 | for i in $(seq 1 $SAMPLES); 29 | do 30 | local index=${i}-1 31 | echo "${float_samples[$index]}" 32 | local sample=$(python3 $1 --seed $i --gpu $GPU --load ${float_samples[$index]} | grep Experiment | cut -d " " -f 4) 33 | echo $sample 34 | samples_list+=("../${sample}") 35 | done 36 | if [ ! 
-d "./summary" ]; then 37 | mkdir summary 38 | fi 39 | cd ./summary/ 40 | echo "${samples_list[*]}" 41 | result=$(python3 ../$2 --result_paths "${samples_list[*]}" --label $3 | grep Experiment | cut -d " " -f 4) 42 | ret_val=$result 43 | cd ../ 44 | } 45 | 46 | cd ./scripts/pointwise/float 47 | #run_experiments pointwise_regression.py ../../../average_results.py float_pointwise_regression 48 | #run_experiments pointwise_mnist.py ../../../average_results.py float_pointwise_mnist 49 | #run_experiments pointwise_cifar.py ../../../average_results.py float_pointwise_cifar 50 | cd ./../../../ 51 | 52 | cd ./scripts/stochastic/mcdropout/float 53 | #run_experiments mcdropout_regression.py ../../../../average_results.py float_mcdropout_regression 54 | #run_experiments mcdropout_mnist.py ../../../../average_results.py float_mcdropout_mnist 55 | #run_experiments mcdropout_cifar.py ../../../../average_results.py float_mcdropout_cifar 56 | cd ./../../../../ 57 | 58 | cd ./scripts/stochastic/sgld/float 59 | #run_experiments sgld_regression.py ../../../../average_results.py float_sgld_regression 60 | #run_experiments sgld_mnist.py ../../../../average_results.py float_sgld_mnist 61 | #run_experiments sgld_cifar.py ../../../../average_results.py float_sgld_cifar 62 | cd ./../../../../ 63 | 64 | cd ./scripts/stochastic/bbb/float 65 | #run_experiments bbb_regression.py ../../../../average_results.py float_bbb_regression 66 | #run_experiments bbb_mnist.py ../../../../average_results.py float_bbb_mnist 67 | #run_experiments bbb_cifar.py ../../../../average_results.py float_bbb_cifar 68 | cd ./../../../../ 69 | -------------------------------------------------------------------------------- /experiments/run_all_quant.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SAMPLES=3 4 | ACTIVATION_PRECISION=7 5 | WEIGHT_PRECISION=8 6 | GPU=$1 7 | function run_q_experiments { 8 | local float_samples=(./../../float/default/*"$4"*) 9 | echo "${float_samples[*]}" 10 | 11 | for w in $(seq 3 $WEIGHT_PRECISION); 12 | do 13 | echo "a_${ACTIVATION_PRECISION}_w_${w}" 14 | if [ ! -d "a_${ACTIVATION_PRECISION}_w_${w}" ]; then 15 | mkdir a_"${ACTIVATION_PRECISION}"_w_"${w}" 16 | fi 17 | cd ./a_"${ACTIVATION_PRECISION}"_w_"${w}"/ 18 | declare -a samples_list=() 19 | for i in $(seq 1 $SAMPLES); 20 | do 21 | local index=${i}-1 22 | echo "${float_samples[$index]}" 23 | local sample=$(python3 ../$1 --seed $i --gpu $GPU --load ../${float_samples[$index]} --weight_precision $w --activation_precision ${ACTIVATION_PRECISION} | grep Experiment | cut -d " " -f 4) 24 | echo $sample 25 | samples_list+=("../${sample}") 26 | done 27 | if [ ! -d "./summary" ]; then 28 | mkdir summary 29 | fi 30 | cd ./summary/ 31 | echo "${samples_list[*]}" 32 | result=$(python3 ../../$2 --result_paths "${samples_list[*]}" --label $3_${ACTIVATION_PRECISION}_${w} | grep Experiment | cut -d " " -f 4) 33 | ret_val=$result 34 | cd ../../ 35 | done 36 | 37 | for a in $(seq 3 $(($ACTIVATION_PRECISION-1))); 38 | do 39 | echo "a_${a}_w_${WEIGHT_PRECISION}" 40 | if [ ! 
-d "a_${a}_w_${WEIGHT_PRECISION}" ]; then 41 | mkdir a_"${a}"_w_"${WEIGHT_PRECISION}" 42 | fi 43 | cd ./a_"${a}"_w_"${WEIGHT_PRECISION}"/ 44 | declare -a samples_list=() 45 | for i in $(seq 1 $SAMPLES); 46 | do 47 | local index=${i}-1 48 | echo "${float_samples[$index]}" 49 | local sample=$(python3 ../$1 --seed $i --gpu $GPU --load ../${float_samples[$index]} --weight_precision ${WEIGHT_PRECISION} --activation_precision $a | grep Experiment | cut -d " " -f 4) 50 | echo $sample 51 | samples_list+=("../${sample}") 52 | done 53 | if [ ! -d "./summary" ]; then 54 | mkdir summary 55 | fi 56 | cd ./summary/ 57 | echo "${samples_list[*]}" 58 | result=$(python3 ../../$2 --result_paths "${samples_list[*]}" --label $3_${a}_${WEIGHT_PRECISION} | grep Experiment | cut -d " " -f 4) 59 | ret_val=$result 60 | cd ../../ 61 | done 62 | } 63 | 64 | cd ./scripts/pointwise/quantised/train 65 | #run_q_experiments pointwise_regression.py ../../../../average_results.py qat_pointwise_regression regression-regression 66 | #run_q_experiments pointwise_mnist.py ../../../../average_results.py qat_pointwise_mnist mnist 67 | #run_q_experiments pointwise_cifar.py ../../../../average_results.py qat_pointwise_cifar cifar 68 | cd ./../../../../ 69 | 70 | cd ./scripts/stochastic/mcdropout/quantised/train 71 | #run_q_experiments mcdropout_regression.py ../../../../../average_results.py qat_mcdropout_regression regression-regression 72 | #run_q_experiments mcdropout_mnist.py ../../../../../average_results.py qat_mcdropout_mnist mnist 73 | #run_q_experiments mcdropout_cifar.py ../../../../../average_results.py qat_mcdropout_cifar cifar 74 | cd ./../../../../../ 75 | 76 | cd ./scripts/stochastic/sgld/quantised/train 77 | #run_q_experiments sgld_regression.py ../../../../../average_results.py qat_sgld_regression regression-regression 78 | #run_q_experiments sgld_mnist.py ../../../../../average_results.py qat_sgld_mnist mnist 79 | #run_q_experiments sgld_cifar.py ../../../../../average_results.py qat_sgld_cifar cifar 80 | cd ./../../../../../ 81 | 82 | cd ./scripts/stochastic/bbb/quantised/train 83 | #run_q_experiments bbb_regression.py ../../../../../average_results.py qat_bbb_regression regression-regression 84 | #run_q_experiments bbb_mnist.py ../../../../../average_results.py qat_bbb_mnist mnist 85 | #run_q_experiments bbb_cifar.py ../../../../../average_results.py qat_bbb_cifar cifar 86 | -------------------------------------------------------------------------------- /experiments/scripts/pointwise/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/pointwise/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/pointwise/float/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/pointwise/float/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/pointwise/float/pointwise_cifar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | 
sys.path.append("../../../../") 11 | 12 | from experiments.utils import evaluate_cifar_uncertainty 13 | from src.data import * 14 | from src.trainer import Trainer 15 | from src.models import ModelFactory 16 | from src.losses import LOSS_FACTORY 17 | import src.utils as utils 18 | 19 | parser = argparse.ArgumentParser("pointwise_cifar") 20 | 21 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 22 | parser.add_argument('--model', type=str, default='conv_resnet', help='the model that we want to train') 23 | 24 | parser.add_argument('--learning_rate', type=float, 25 | default=0.001, help='init learning rate') 26 | parser.add_argument('--loss_scaling', type=str, 27 | default='batch', help='smoothing factor') 28 | parser.add_argument('--weight_decay', type=float, 29 | default=0.00001, help='weight decay') 30 | 31 | parser.add_argument('--data', type=str, default='./../../data/', 32 | help='location of the data corpus') 33 | parser.add_argument('--dataset', type=str, default='cifar', 34 | help='dataset') 35 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 36 | 37 | 38 | parser.add_argument('--valid_portion', type=float, 39 | default=0.1, help='portion of training data') 40 | 41 | parser.add_argument('--epochs', type=int, default=300, 42 | help='num of training epochs') 43 | 44 | parser.add_argument('--input_size', nargs='+', 45 | default=[1, 3, 32, 32], help='input size') 46 | parser.add_argument('--output_size', type=int, 47 | default=10, help='output size') 48 | parser.add_argument('--samples', type=int, 49 | default=1, help='output size') 50 | 51 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 52 | parser.add_argument('--save_last', action='store_true', default=True, 53 | help='whether to just save the last model') 54 | 55 | parser.add_argument('--num_workers', type=int, 56 | default=16, help='number of workers') 57 | parser.add_argument('--seed', type=int, default=1, help='random seed') 58 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 59 | 60 | parser.add_argument('--report_freq', type=float, 61 | default=50, help='report frequency') 62 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 63 | 64 | 65 | parser.add_argument('--q', action='store_true', 66 | help='whether to do post training quantisation') 67 | parser.add_argument('--at', action='store_true', 68 | help='whether to do training aware quantisation') 69 | 70 | 71 | def main(): 72 | args = parser.parse_args() 73 | load = False 74 | if args.save!='EXP': 75 | load=True 76 | 77 | args, writer = utils.parse_args(args) 78 | 79 | logging.info('# Start Re-training #') 80 | 81 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 82 | 83 | model_temp = ModelFactory.get_model 84 | 85 | logging.info('## Downloading and preparing data ##') 86 | train_loader, valid_loader= get_train_loaders(args) 87 | 88 | if not load: 89 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args) 90 | 91 | logging.info('## Model created: ##') 92 | logging.info(model.__repr__()) 93 | 94 | logging.info('### Loading model to parallel GPUs ###') 95 | model = utils.model_to_gpus(model, args) 96 | 97 | logging.info('### Preparing schedulers and optimizers ###') 98 | optimizer = torch.optim.Adam( 99 | model.parameters(), 100 | args.learning_rate, 101 | weight_decay = args.weight_decay) 102 | 103 | scheduler = 
torch.optim.lr_scheduler.CosineAnnealingLR( 104 | optimizer, args.epochs) 105 | 106 | 107 | logging.info('## Beginning Training ##') 108 | 109 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 110 | 111 | best_error, train_time, val_time = train.train_loop( 112 | train_loader, valid_loader) 113 | 114 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 115 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 116 | 117 | logging.info('## Beginning Plotting ##') 118 | del model 119 | 120 | with torch.no_grad(): 121 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args) 122 | 123 | utils.load_model(model, args.save+"/weights.pt") 124 | logging.info('## Model re-created: ##') 125 | logging.info(model.__repr__()) 126 | model = utils.model_to_gpus(model, args) 127 | model.eval() 128 | 129 | evaluate_cifar_uncertainty(model, args) 130 | logging.info('# Finished #') 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /experiments/scripts/pointwise/float/pointwise_mnist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | 12 | from experiments.utils import evaluate_mnist_uncertainty 13 | from src.data import * 14 | from src.trainer import Trainer 15 | from src.models import ModelFactory 16 | from src.losses import LOSS_FACTORY 17 | import src.utils as utils 18 | 19 | parser = argparse.ArgumentParser("pointwise_mnist") 20 | 21 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 22 | parser.add_argument('--model', type=str, default='conv_lenet', help='the model that we want to train') 23 | 24 | parser.add_argument('--learning_rate', type=float, 25 | default=0.001, help='init learning rate') 26 | parser.add_argument('--loss_scaling', type=str, 27 | default='batch', help='smoothing factor') 28 | parser.add_argument('--weight_decay', type=float, 29 | default=0.0001, help='weight decay') 30 | 31 | 32 | parser.add_argument('--data', type=str, default='./../../data/', 33 | help='location of the data corpus') 34 | parser.add_argument('--dataset', type=str, default='mnist', 35 | help='dataset') 36 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 37 | 38 | parser.add_argument('--valid_portion', type=float, 39 | default=0.1, help='portion of training data') 40 | 41 | parser.add_argument('--epochs', type=int, default=100, 42 | help='num of training epochs') 43 | 44 | parser.add_argument('--input_size', nargs='+', 45 | default=[1, 1, 28, 28], help='input size') 46 | parser.add_argument('--output_size', type=int, 47 | default=10, help='output size') 48 | parser.add_argument('--samples', type=int, 49 | default=1, help='output size') 50 | 51 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 52 | parser.add_argument('--save_last', action='store_true', default=True, 53 | help='whether to just save the last model') 54 | 55 | parser.add_argument('--num_workers', type=int, 56 | default=16, help='number of workers') 57 | parser.add_argument('--seed', type=int, default=1, help='random 
seed') 58 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 59 | 60 | parser.add_argument('--report_freq', type=float, 61 | default=50, help='report frequency') 62 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 63 | 64 | 65 | parser.add_argument('--q', action='store_true', 66 | help='whether to do post training quantisation') 67 | parser.add_argument('--at', action='store_true', 68 | help='whether to do training aware quantisation') 69 | 70 | 71 | def main(): 72 | args = parser.parse_args() 73 | load = False 74 | if args.save!='EXP': 75 | load=True 76 | 77 | args, writer = utils.parse_args(args) 78 | 79 | logging.info('# Start Re-training #') 80 | 81 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 82 | 83 | model_temp = ModelFactory.get_model 84 | logging.info('## Downloading and preparing data ##') 85 | train_loader, valid_loader= get_train_loaders(args) 86 | 87 | if not load: 88 | 89 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args) 90 | 91 | logging.info('## Model created: ##') 92 | logging.info(model.__repr__()) 93 | 94 | logging.info('### Loading model to parallel GPUs ###') 95 | 96 | model = utils.model_to_gpus(model, args) 97 | 98 | logging.info('### Preparing schedulers and optimizers ###') 99 | optimizer = torch.optim.Adam( 100 | model.parameters(), 101 | args.learning_rate, 102 | weight_decay = args.weight_decay) 103 | 104 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 105 | optimizer, args.epochs) 106 | 107 | logging.info('## Beginning Training ##') 108 | 109 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 110 | 111 | best_error, train_time, val_time = train.train_loop( 112 | train_loader, valid_loader) 113 | 114 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 115 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 116 | 117 | logging.info('## Beginning Plotting ##') 118 | del model 119 | 120 | with torch.no_grad(): 121 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args) 122 | 123 | utils.load_model(model, args.save+"/weights.pt") 124 | logging.info('## Model re-created: ##') 125 | logging.info(model.__repr__()) 126 | model = utils.model_to_gpus(model, args) 127 | model.eval() 128 | 129 | evaluate_mnist_uncertainty(model, args) 130 | logging.info('# Finished #') 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /experiments/scripts/pointwise/float/pointwise_regression.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | 12 | from src.data import * 13 | from src.trainer import Trainer 14 | from src.models import ModelFactory 15 | from src.losses import LOSS_FACTORY 16 | import src.utils as utils 17 | from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS 18 | 19 | parser = argparse.ArgumentParser("pointwise_regression") 20 | 21 | parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss') 22 | parser.add_argument('--model', type=str, 
default='linear', help='the model that we want to train') 23 | 24 | parser.add_argument('--learning_rate', type=float, 25 | default=0.001, help='init learning rate') 26 | 27 | parser.add_argument('--loss_scaling', type=str, 28 | default='batch', help='smoothing factor') 29 | parser.add_argument('--weight_decay', type=float, 30 | default=0.00005, help='weight decay') 31 | 32 | parser.add_argument('--data', type=str, default='./../../data/', 33 | help='location of the data corpus') 34 | parser.add_argument('--dataset', type=str, default='regression', 35 | help='dataset') 36 | parser.add_argument('--batch_size', type=int, default=1000, help='batch size') 37 | 38 | parser.add_argument('--valid_portion', type=float, 39 | default=0.2, help='portion of training data') 40 | 41 | parser.add_argument('--epochs', type=int, default=300, 42 | help='num of training epochs') 43 | 44 | parser.add_argument('--input_size', nargs='+', 45 | default=[1], help='input size') 46 | parser.add_argument('--output_size', type=int, 47 | default=1, help='output size') 48 | parser.add_argument('--samples', type=int, 49 | default=1, help='output size') 50 | 51 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 52 | parser.add_argument('--save_last', action='store_true', default=True, 53 | help='whether to just save the last model') 54 | 55 | parser.add_argument('--num_workers', type=int, 56 | default=0, help='number of workers') 57 | parser.add_argument('--seed', type=int, default=1, help='random seed') 58 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 59 | 60 | parser.add_argument('--report_freq', type=float, 61 | default=50, help='report frequency') 62 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 63 | 64 | parser.add_argument('--q', action='store_true',default=False, 65 | help='whether to do post training quantisation') 66 | parser.add_argument('--at', action='store_true',default=False, 67 | help='whether to do training aware quantisation') 68 | 69 | 70 | def main(): 71 | args = parser.parse_args() 72 | load = False 73 | if args.save!='EXP': 74 | load=True 75 | 76 | model_temp = ModelFactory.get_model 77 | 78 | args, writer = utils.parse_args(args) 79 | logging.info('# Start Re-training #') 80 | if not load: 81 | for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS): 82 | for j in range(n_folds): 83 | logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j)) 84 | 85 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 86 | 87 | logging.info('## Downloading and preparing data ##') 88 | args.dataset = "regression_" + dataset 89 | 90 | train_loader, valid_loader= get_train_loaders(args, split=j) 91 | in_shape = next(iter(train_loader))[0].shape[1] 92 | args.input_size = [in_shape] 93 | 94 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args) 95 | 96 | logging.info('## Model created: ##') 97 | logging.info(model.__repr__()) 98 | 99 | logging.info('### Loading model to parallel GPUs ###') 100 | 101 | model = utils.model_to_gpus(model, args) 102 | 103 | logging.info('### Preparing schedulers and optimizers ###') 104 | optimizer = torch.optim.Adam( 105 | model.parameters(), 106 | args.learning_rate, 107 | weight_decay = args.weight_decay) 108 | 109 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 110 | optimizer, args.epochs) 111 | 112 | logging.info('## Beginning Training ##') 113 | 114 | train = Trainer(model, criterion, optimizer, scheduler, args, 
writer=writer) 115 | 116 | best_error, train_time, val_time = train.train_loop( 117 | train_loader, valid_loader, special_info="_"+dataset+"_"+str(j)) 118 | 119 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 120 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 121 | 122 | 123 | del model 124 | 125 | with torch.no_grad(): 126 | logging.info('## Beginning Plotting ##') 127 | evaluate_regression_uncertainty(model_temp, args) 128 | 129 | logging.info('# Finished #') 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /experiments/scripts/pointwise/quantised/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/pointwise/quantised/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/pointwise/quantised/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/pointwise/quantised/train/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/pointwise/quantised/train/pointwise_cifar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | 14 | from experiments.utils import evaluate_cifar_uncertainty 15 | from src.data import * 16 | from src.trainer import Trainer 17 | from src.models import ModelFactory 18 | from src.losses import LOSS_FACTORY 19 | import src.utils as utils 20 | import src.quant_utils as quant_utils 21 | 22 | parser = argparse.ArgumentParser("pointwise_cifar") 23 | 24 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 25 | parser.add_argument('--model', type=str, default='conv_resnet', help='the model that we want to train') 26 | 27 | parser.add_argument('--learning_rate', type=float, 28 | default=0.00001, help='init learning rate') 29 | parser.add_argument('--loss_scaling', type=str, 30 | default='batch', help='smoothing factor') 31 | parser.add_argument('--weight_decay', type=float, 32 | default=0.00001, help='weight decay') 33 | 34 | 35 | parser.add_argument('--data', type=str, default='./../../../../data/', 36 | help='location of the data corpus') 37 | parser.add_argument('--dataset', type=str, default='cifar', 38 | help='dataset') 39 | parser.add_argument('--batch_size', type=int, default=1024, help='batch size') 40 | 41 | parser.add_argument('--valid_portion', type=float, 42 | default=0.1, help='portion of training data') 43 | 44 | parser.add_argument('--epochs', type=int, default=10, 45 | help='num of training epochs') 46 | 47 | parser.add_argument('--input_size', nargs='+', 48 | default=[1, 3, 32, 32], help='input size') 49 | parser.add_argument('--output_size', type=int, 50 | default=10, help='output 
size') 51 | parser.add_argument('--samples', type=int, 52 | default=1, help='output size') 53 | 54 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 55 | parser.add_argument('--load', type=str, default='EXP', help='to load pre-trained model') 56 | 57 | parser.add_argument('--save_last', action='store_true', default=True, 58 | help='whether to just save the last model') 59 | 60 | parser.add_argument('--num_workers', type=int, 61 | default=16, help='number of workers') 62 | parser.add_argument('--seed', type=int, default=1, help='random seed') 63 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 64 | 65 | parser.add_argument('--report_freq', type=float, 66 | default=50, help='report frequency') 67 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 68 | 69 | parser.add_argument('--q', action='store_true', default=True, 70 | help='whether to do post training quantisation') 71 | parser.add_argument('--at', action='store_true', default=True, 72 | help='whether to do training aware quantisation') 73 | parser.add_argument('--activation_precision', type=int, default=7, 74 | help='how many bits to be used for the activations') 75 | parser.add_argument('--weight_precision', type=int, default=8, 76 | help='how many bits to be used for the weights') 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | load = False 81 | if args.save!='EXP': 82 | load=True 83 | args, writer = utils.parse_args(args) 84 | 85 | logging.info('# Start Re-training #') 86 | 87 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 88 | 89 | model_temp = ModelFactory.get_model 90 | logging.info('## Downloading and preparing data ##') 91 | train_loader, valid_loader= get_train_loaders(args) 92 | 93 | if not load: 94 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args) 95 | utils.load_model(model, args.load+"/weights.pt") 96 | 97 | if args.at: 98 | logging.info('## Preparing model for quantization aware training ##') 99 | quant_utils.prepare_model(model, args) 100 | 101 | logging.info('## Model created: ##') 102 | logging.info(model.__repr__()) 103 | 104 | 105 | logging.info('### Loading model to parallel GPUs ###') 106 | 107 | model = utils.model_to_gpus(model, args) 108 | 109 | 110 | logging.info('### Preparing schedulers and optimizers ###') 111 | optimizer = torch.optim.SGD( 112 | model.parameters(), 113 | args.learning_rate, 114 | momentum=0.9, 115 | weight_decay = args.weight_decay) 116 | 117 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 118 | optimizer, args.epochs) 119 | 120 | logging.info('## Beginning Training ##') 121 | 122 | 123 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 124 | 125 | best_error, train_time, val_time = train.train_loop( 126 | train_loader, valid_loader) 127 | 128 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 129 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 130 | 131 | if args.q: 132 | quant_utils.postprocess_model(model, args) 133 | 134 | 135 | logging.info('## Beginning Plotting ##') 136 | del model 137 | 138 | with torch.no_grad(): 139 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args) 140 | if args.q: 141 | quant_utils.prepare_model(model, args) 142 | quant_utils.convert(model) 143 | 144 | utils.load_model(model, args.save+"/weights.pt") 145 | 
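# Note: PyTorch's eager-mode quantised kernels execute on the CPU, which is
# why the converted model is only moved to the GPU below when quantisation
# is disabled.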
logging.info('## Model re-created: ##') 146 | logging.info(model.__repr__()) 147 | 148 | if not args.q: 149 | model = utils.model_to_gpus(model, args) 150 | 151 | model.eval() 152 | 153 | evaluate_cifar_uncertainty(model, args) 154 | logging.info('# Finished #') 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /experiments/scripts/pointwise/quantised/train/pointwise_mnist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | 14 | from experiments.utils import evaluate_mnist_uncertainty 15 | from src.data import * 16 | from src.trainer import Trainer 17 | from src.models import ModelFactory 18 | from src.losses import LOSS_FACTORY 19 | import src.utils as utils 20 | import src.quant_utils as quant_utils 21 | 22 | parser = argparse.ArgumentParser("pointwise_mnist") 23 | 24 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 25 | parser.add_argument('--model', type=str, default='conv_lenet', help='the model that we want to train') 26 | 27 | parser.add_argument('--learning_rate', type=float, 28 | default=0.00001, help='init learning rate') 29 | parser.add_argument('--loss_scaling', type=str, 30 | default='batch', help='smoothing factor') 31 | parser.add_argument('--weight_decay', type=float, 32 | default=0.0001, help='weight decay') 33 | 34 | 35 | parser.add_argument('--data', type=str, default='./../../../../data/', 36 | help='location of the data corpus') 37 | parser.add_argument('--dataset', type=str, default='mnist', 38 | help='dataset') 39 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 40 | 41 | parser.add_argument('--valid_portion', type=float, 42 | default=0.1, help='portion of training data') 43 | 44 | parser.add_argument('--epochs', type=int, default=10, 45 | help='num of training epochs') 46 | 47 | parser.add_argument('--input_size', nargs='+', 48 | default=[1, 1, 28, 28], help='input size') 49 | parser.add_argument('--output_size', type=int, 50 | default=10, help='output size') 51 | parser.add_argument('--samples', type=int, 52 | default=1, help='output size') 53 | 54 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 55 | parser.add_argument('--load', type=str, default='EXP', help='to load pre-trained model') 56 | 57 | parser.add_argument('--save_last', action='store_true', default=True, 58 | help='whether to just save the last model') 59 | 60 | parser.add_argument('--num_workers', type=int, 61 | default=16, help='number of workers') 62 | parser.add_argument('--seed', type=int, default=1, help='random seed') 63 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 64 | 65 | parser.add_argument('--report_freq', type=float, 66 | default=50, help='report frequency') 67 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 68 | 69 | parser.add_argument('--q', action='store_true', default=True, 70 | help='whether to do post training quantisation') 71 | parser.add_argument('--at', action='store_true', default=True, 72 | help='whether to do training aware 
quantisation') 73 | parser.add_argument('--activation_precision', type=int, default=7, 74 | help='how many bits to be used for the activations') 75 | parser.add_argument('--weight_precision', type=int, default=8, 76 | help='how many bits to be used for the weights') 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | load = False 81 | if args.save!='EXP': 82 | load=True 83 | 84 | args, writer = utils.parse_args(args) 85 | 86 | logging.info('# Start Re-training #') 87 | 88 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 89 | 90 | model_temp = ModelFactory.get_model 91 | logging.info('## Downloading and preparing data ##') 92 | train_loader, valid_loader= get_train_loaders(args) 93 | 94 | if not load: 95 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args) 96 | utils.load_model(model, args.load+"/weights.pt") 97 | 98 | if args.at: 99 | logging.info('## Preparing model for quantization aware training ##') 100 | quant_utils.prepare_model(model, args) 101 | 102 | 103 | logging.info('## Model created: ##') 104 | logging.info(model.__repr__()) 105 | 106 | 107 | logging.info('### Loading model to parallel GPUs ###') 108 | 109 | model = utils.model_to_gpus(model, args) 110 | 111 | logging.info('### Preparing schedulers and optimizers ###') 112 | optimizer = torch.optim.SGD( 113 | model.parameters(), 114 | args.learning_rate, 115 | momentum=0.9, 116 | weight_decay = args.weight_decay) 117 | 118 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 119 | optimizer, args.epochs) 120 | 121 | logging.info('## Beginning Training ##') 122 | 123 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 124 | 125 | best_error, train_time, val_time = train.train_loop( 126 | train_loader, valid_loader) 127 | 128 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 129 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 130 | 131 | if args.q: 132 | quant_utils.postprocess_model(model, args) 133 | 134 | 135 | logging.info('## Beginning Plotting ##') 136 | del model 137 | 138 | with torch.no_grad(): 139 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args) 140 | if args.q: 141 | quant_utils.prepare_model(model, args) 142 | quant_utils.convert(model) 143 | 144 | utils.load_model(model, args.save+"/weights.pt") 145 | logging.info('## Model re-created: ##') 146 | logging.info(model.__repr__()) 147 | 148 | model.eval() 149 | 150 | evaluate_mnist_uncertainty(model, args) 151 | logging.info('# Finished #') 152 | 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /experiments/scripts/pointwise/quantised/train/pointwise_regression.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | 14 | from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS 15 | from src.data import * 16 | from src.trainer import Trainer 17 | from src.models import ModelFactory 18 | from src.losses import LOSS_FACTORY 19 | import src.utils as utils 20 | import src.quant_utils 
as quant_utils 21 | 22 | parser = argparse.ArgumentParser("pointwise_regression") 23 | 24 | parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss') 25 | parser.add_argument('--model', type=str, default='linear', help='the model that we want to train') 26 | 27 | parser.add_argument('--learning_rate', type=float, 28 | default=0.00001, help='init learning rate') 29 | parser.add_argument('--loss_scaling', type=str, 30 | default='batch', help='smoothing factor') 31 | parser.add_argument('--weight_decay', type=float, 32 | default=0.00005, help='weight decay') 33 | 34 | parser.add_argument('--data', type=str, default='./../../../../data/', 35 | help='location of the data corpus') 36 | parser.add_argument('--dataset', type=str, default='regression', 37 | help='dataset') 38 | parser.add_argument('--batch_size', type=int, default=1000, help='batch size') 39 | 40 | 41 | parser.add_argument('--valid_portion', type=float, 42 | default=0.2, help='portion of training data') 43 | 44 | parser.add_argument('--epochs', type=int, default=10, 45 | help='num of training epochs') 46 | 47 | parser.add_argument('--input_size', nargs='+', 48 | default=[1], help='input size') 49 | parser.add_argument('--output_size', type=int, 50 | default=1, help='output size') 51 | parser.add_argument('--samples', type=int, 52 | default=1, help='output size') 53 | 54 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 55 | parser.add_argument('--load', type=str, default='EXP', help='to load pre-trained model') 56 | 57 | parser.add_argument('--save_last', action='store_true', default=True, 58 | help='whether to just save the last model') 59 | 60 | parser.add_argument('--num_workers', type=int, 61 | default=0, help='number of workers') 62 | parser.add_argument('--seed', type=int, default=1, help='random seed') 63 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 64 | 65 | parser.add_argument('--report_freq', type=float, 66 | default=50, help='report frequency') 67 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 68 | 69 | parser.add_argument('--q', action='store_true', default=True, 70 | help='whether to do post training quantisation') 71 | parser.add_argument('--at', action='store_true', default=True, 72 | help='whether to do training aware quantisation') 73 | parser.add_argument('--activation_precision', type=int, default=7, 74 | help='how many bits to be used for the activations') 75 | parser.add_argument('--weight_precision', type=int, default=8, 76 | help='how many bits to be used for the weights') 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | load = False 81 | if args.save!='EXP': 82 | load=True 83 | 84 | model_temp = ModelFactory.get_model 85 | 86 | args, writer = utils.parse_args(args) 87 | logging.info('# Start Re-training #') 88 | if not load: 89 | for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS): 90 | for j in range(n_folds): 91 | logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j)) 92 | 93 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 94 | 95 | logging.info('## Downloading and preparing data ##') 96 | args.dataset = "regression_" + dataset 97 | 98 | train_loader, valid_loader = get_train_loaders(args, split=j) 99 | in_shape = next(iter(train_loader))[0].shape[1] 100 | args.input_size = [in_shape] 101 | 102 | model = model_temp(args.model, args.input_size, 103 | args.output_size, args.at, args) 104 | utils.load_model( 105 | 
                if args.at:
                    logging.info('## Preparing model for quantization aware training ##')
                    quant_utils.prepare_model(model, args)

                logging.info('## Model created: ##')
                logging.info(model.__repr__())

                logging.info('### Loading model to parallel GPUs ###')

                model = utils.model_to_gpus(model, args)

                logging.info('### Preparing schedulers and optimizers ###')
                optimizer = torch.optim.SGD(
                    model.parameters(),
                    args.learning_rate,
                    momentum=0.9,
                    weight_decay=args.weight_decay)

                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, args.epochs)

                logging.info('## Beginning Training ##')

                train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer)

                best_error, train_time, val_time = train.train_loop(
                    train_loader, valid_loader, special_info="_"+dataset+"_"+str(j))

                logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
                    best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

                if args.q:
                    quant_utils.postprocess_model(
                        model, args, special_info="_{}_{}".format(dataset, j))

                del model

    with torch.no_grad():
        logging.info('## Beginning Plotting ##')
        evaluate_regression_uncertainty(model_temp, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/__init__.py

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/float/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/bbb/float/__init__.py

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/float/bbb_cifar.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")

from experiments.utils import evaluate_cifar_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils

parser = argparse.ArgumentParser("cifar_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_resnet_bbb', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.0, help='weight decay')

parser.add_argument('--data', type=str, default='./../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--gamma', type=float,
                    default=0.01, help='weighting of the KL term in the BBB loss')
parser.add_argument('--sigma_prior', type=float,
                    default=.05, help='standard deviation of the Gaussian weight prior')

parser.add_argument('--epochs', type=int, default=300,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 3, 32, 32], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=False,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=False,
                    help='whether to do quantisation aware training')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')

        model = utils.model_to_gpus(model, args)

        logging.info('### Preparing schedulers and optimizers ###')
        optimizer = torch.optim.Adam(
            model.parameters(),
            args.learning_rate,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer)

        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())
        model = utils.model_to_gpus(model, args)
        model.eval()

        evaluate_cifar_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/float/bbb_mnist.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")

from experiments.utils import evaluate_mnist_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils

parser = argparse.ArgumentParser("mnist_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_lenet_bbb', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.0, help='weight decay')

parser.add_argument('--data', type=str, default='./../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='mnist',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--gamma', type=float,
                    default=.1, help='weighting of the KL term in the BBB loss')
parser.add_argument('--sigma_prior', type=float,
                    default=.1, help='standard deviation of the Gaussian weight prior')

parser.add_argument('--epochs', type=int, default=100,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 1, 28, 28], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
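# --q applies post-training quantisation to an already trained network, while
# --at additionally simulates quantised arithmetic during training itself
# (quantisation aware training); both are plain store_true flags here, so the
# float scripts run entirely in floating point unless these are set explicitly.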

parser.add_argument('--q', action='store_true',
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true',
                    help='whether to do quantisation aware training')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')

        model = utils.model_to_gpus(model, args)

        logging.info('### Preparing schedulers and optimizers ###')
        optimizer = torch.optim.Adam(
            model.parameters(),
            args.learning_rate,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer)

        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.q, args)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())

        model = utils.model_to_gpus(model, args)
        model.eval()
        evaluate_mnist_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/float/bbb_regression.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")

from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils
from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS

parser = argparse.ArgumentParser("pointwise_regression")

parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='linear_bbb', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.01, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.0, help='weight decay')

parser.add_argument('--data', type=str, default='./../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='regression',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=1000, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.2, help='portion of training data used for validation')

parser.add_argument('--gamma', type=float,
                    default=1.0, help='weighting of the KL term in the BBB loss')
parser.add_argument('--sigma_prior', type=float,
                    default=1., help='standard deviation of the Gaussian weight prior')

parser.add_argument('--epochs', type=int, default=300,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1], help='input size')
parser.add_argument('--output_size', type=int,
                    default=1, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--num_workers', type=int,
                    default=0, help='number of workers')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true',
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true',
                    help='whether to do quantisation aware training')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    model_temp = ModelFactory.get_model

    args, writer = utils.parse_args(args)
    logging.info('# Start Re-training #')
    if not load:
        for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS):
            for j in range(n_folds):
                logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j))

                criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

                logging.info('## Downloading and preparing data ##')
                args.dataset = "regression_" + dataset

                train_loader, valid_loader = get_train_loaders(args, split=j)
                in_shape = next(iter(train_loader))[0].shape[1]
                args.input_size = [in_shape]

                model = model_temp(args.model, args.input_size,
                                   args.output_size, args.at, args)

                logging.info('## Model created: ##')
                logging.info(model.__repr__())

                logging.info('### Loading model to parallel GPUs ###')

                model = utils.model_to_gpus(model, args)

                logging.info('### Preparing schedulers and optimizers ###')
                optimizer = torch.optim.Adam(
                    model.parameters(),
                    args.learning_rate,
                    weight_decay=args.weight_decay)

                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, args.epochs)

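                # train_loop tags each fold's checkpoint through special_info,
                # so every dataset/split pair gets its own weights file
                # (special_info="_"+dataset+"_"+str(j) ends up as
                # weights_<dataset>_<split>.pt, which the quantised retraining
                # scripts later reload).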
                logging.info('## Beginning Training ##')

                train = Trainer(model, criterion, optimizer, scheduler, args, writer)

                best_error, train_time, val_time = train.train_loop(
                    train_loader, valid_loader, special_info="_"+dataset+"_"+str(j))

                logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
                    best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

                del model

    with torch.no_grad():
        logging.info('## Beginning Plotting ##')
        evaluate_regression_uncertainty(model_temp, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/quantised/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/bbb/quantised/train/__init__.py

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/quantised/train/bbb_cifar.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")
sys.path.append("../../../../../../../")

from experiments.utils import evaluate_cifar_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils
import src.quant_utils as quant_utils

parser = argparse.ArgumentParser("cifar_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_resnet_bbb', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.00001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.0, help='weight decay')

parser.add_argument('--data', type=str, default='./../../../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--sigma_prior', type=float,
                    default=.05, help='standard deviation of the Gaussian weight prior')
parser.add_argument('--gamma', type=float,
                    default=.0, help='weighting of the KL term in the BBB loss')
parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=10,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 3, 32, 32], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--load', type=str, default='EXP', help='path to a pre-trained model to load')

parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=True,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=True,
                    help='whether to do quantisation aware training')
parser.add_argument('--activation_precision', type=int, default=7,
                    help='how many bits to be used for the activations')
parser.add_argument('--weight_precision', type=int, default=8,
                    help='how many bits to be used for the weights')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)

        utils.load_model(model, args.load+"/weights.pt")

        if args.at:
            logging.info('## Preparing model for quantization aware training ##')
            quant_utils.prepare_model(model, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')
        model = utils.model_to_gpus(model, args)

        optimizer = torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.9,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer)
        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        if args.q:
            quant_utils.postprocess_model(model, args)

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.q, args)
        if args.q:
            quant_utils.prepare_model(model, args)
            quant_utils.convert(model)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())

        if not args.q:
            model = utils.model_to_gpus(model, args)
        model.eval()

        evaluate_cifar_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/quantised/train/bbb_mnist.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")
sys.path.append("../../../../../../../")

from experiments.utils import evaluate_mnist_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils
import src.quant_utils as quant_utils

parser = argparse.ArgumentParser("mnist_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_lenet_bbb', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.00001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.0, help='weight decay')

parser.add_argument('--data', type=str, default='./../../../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='mnist',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--sigma_prior', type=float,
                    default=.1, help='standard deviation of the Gaussian weight prior')
parser.add_argument('--gamma', type=float,
                    default=.0, help='weighting of the KL term in the BBB loss')

parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=10,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 1, 28, 28], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--load', type=str, default='EXP', help='path to a pre-trained model to load')

parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=True,
                    help='whether to do post training quantisation')
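# In these quantised retraining scripts --q and --at default to True: the
# float checkpoint given via --load is fine-tuned for a few epochs with
# simulated quantisation and then converted into a fully quantised model.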
parser.add_argument('--at', action='store_true', default=True,
                    help='whether to do quantisation aware training')
parser.add_argument('--activation_precision', type=int, default=7,
                    help='how many bits to be used for the activations')
parser.add_argument('--weight_precision', type=int, default=8,
                    help='how many bits to be used for the weights')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)
        utils.load_model(model, args.load+"/weights.pt")

        if args.at:
            logging.info('## Preparing model for quantization aware training ##')
            quant_utils.prepare_model(model, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')

        model = utils.model_to_gpus(model, args)

        logging.info('### Preparing schedulers and optimizers ###')
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.9,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer)

        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        if args.q:
            quant_utils.postprocess_model(model, args)

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.q, args)
        if args.q:
            quant_utils.prepare_model(model, args)
            quant_utils.convert(model)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())

        if not args.q:
            model = utils.model_to_gpus(model, args)

        model.eval()
        evaluate_mnist_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/bbb/quantised/train/bbb_regression.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")
sys.path.append("../../../../../../../")

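# The stack of sys.path.append("../") calls above climbs out of the deeply
# nested script directory so that `experiments` and `src` resolve as top-level
# packages when this file is executed directly rather than as a module.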
from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils
import src.quant_utils as quant_utils

parser = argparse.ArgumentParser("pointwise_regression")

parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='linear_bbb', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.00001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.0, help='weight decay')

parser.add_argument('--data', type=str, default='./../../../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='regression',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=1000, help='batch size')
parser.add_argument('--sigma_prior', type=float,
                    default=1., help='standard deviation of the Gaussian weight prior')
parser.add_argument('--gamma', type=float,
                    default=.0, help='weighting of the KL term in the BBB loss')
parser.add_argument('--valid_portion', type=float,
                    default=0.2, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=10,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1], help='input size')
parser.add_argument('--output_size', type=int,
                    default=1, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--load', type=str, default='EXP', help='path to a pre-trained model to load')

parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--num_workers', type=int,
                    default=0, help='number of workers')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=True,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=True,
                    help='whether to do quantisation aware training')
parser.add_argument('--activation_precision', type=int, default=7,
                    help='how many bits to be used for the activations')
parser.add_argument('--weight_precision', type=int, default=8,
                    help='how many bits to be used for the weights')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    model_temp = ModelFactory.get_model

    args, writer = utils.parse_args(args)
    logging.info('# Start Re-training #')
    if not load:
        for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS):
            for j in range(n_folds):
                logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j))

                criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

                logging.info('## Downloading and preparing data ##')
                args.dataset = "regression_" + dataset

                train_loader, valid_loader = get_train_loaders(args, split=j)
                in_shape = next(iter(train_loader))[0].shape[1]
                args.input_size = [in_shape]

                model = model_temp(args.model, args.input_size,
                                   args.output_size, args.at, args)
                utils.load_model(
                    model, args.load+"/weights_{}_{}.pt".format(dataset, j))

                if args.at:
                    logging.info('## Preparing model for quantization aware training ##')
                    quant_utils.prepare_model(model, args)

                logging.info('## Model created: ##')
                logging.info(model.__repr__())

                logging.info('### Loading model to parallel GPUs ###')

                model = utils.model_to_gpus(model, args)

                logging.info('### Preparing schedulers and optimizers ###')
                optimizer = torch.optim.SGD(
                    model.parameters(),
                    args.learning_rate,
                    momentum=0.9,
                    weight_decay=args.weight_decay)

                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, args.epochs)

                logging.info('## Beginning Training ##')

                train = Trainer(model, criterion, optimizer, scheduler, args, writer)

                best_error, train_time, val_time = train.train_loop(
                    train_loader, valid_loader, special_info="_"+dataset+"_"+str(j))

                logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
                    best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

                if args.q:
                    quant_utils.postprocess_model(
                        model, args, special_info="_{}_{}".format(dataset, j))

                del model

    with torch.no_grad():
        logging.info('## Beginning Plotting ##')
        evaluate_regression_uncertainty(model_temp, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/mcdropout/float/mcdropout_cifar.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")

from experiments.utils import evaluate_cifar_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils

parser = argparse.ArgumentParser("cifar_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_resnet_mc', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.005, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.00001, help='weight decay')
parser.add_argument('--p', type=float,
                    default=0.15, help='dropout probability')

parser.add_argument('--data', type=str, default='./../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=300,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 3, 32, 32], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')

parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=False,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=False,
                    help='whether to do quantisation aware training')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')

        model = utils.model_to_gpus(model, args)

        logging.info('### Preparing schedulers and optimizers ###')
        optimizer = torch.optim.Adam(
            model.parameters(),
            args.learning_rate,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer)

        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.q, args)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())
        model = utils.model_to_gpus(model, args)

        model.eval()
        evaluate_cifar_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/mcdropout/float/mcdropout_mnist.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")

from experiments.utils import evaluate_mnist_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils

parser = argparse.ArgumentParser("mnist_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_lenet_mc', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.00001, help='weight decay')
parser.add_argument('--p', type=float,
                    default=0.2, help='dropout probability')

parser.add_argument('--data', type=str, default='./../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='mnist',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=100,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 1, 28, 28], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')

parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=False,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=False,
                    help='whether to do quantisation aware training')


def main():
    args = parser.parse_args()
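    # A non-default --save marks the run as evaluation-only: training below is
    # skipped and the weights previously stored under that directory are
    # re-loaded for the uncertainty evaluation.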
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')

        model = utils.model_to_gpus(model, args)

        logging.info('### Preparing schedulers and optimizers ###')
        optimizer = torch.optim.Adam(
            model.parameters(),
            args.learning_rate,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer)

        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.q, args)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())
        model = utils.model_to_gpus(model, args)

        model.eval()
        evaluate_mnist_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/mcdropout/float/mcdropout_regression.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")

from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils
from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS


parser = argparse.ArgumentParser("mcdropout_regression")

parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='linear_mc', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.00005, help='weight decay')
parser.add_argument('--p', type=float,
                    default=0.2, help='dropout probability')

parser.add_argument('--data', type=str, default='./../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='regression',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=1000, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.2, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=300,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1], help='input size')
parser.add_argument('--output_size', type=int,
                    default=1, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--num_workers', type=int,
                    default=0, help='number of workers')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--save_last', action='store_true', default=True,
                    help='whether to just save the last model')

parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=False,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=False,
                    help='whether to do quantisation aware training')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)
    model_temp = ModelFactory.get_model

    logging.info('# Start Re-training #')
    if not load:
        for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS):
            for j in range(n_folds):
                logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j))

                criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

                logging.info('## Downloading and preparing data ##')
                args.dataset = "regression_" + dataset

                train_loader, valid_loader = get_train_loaders(args, split=j)
                in_shape = next(iter(train_loader))[0].shape[1]
                args.input_size = [in_shape]

                model = model_temp(args.model, args.input_size,
                                   args.output_size, args.at, args)

                logging.info('## Model created: ##')
                logging.info(model.__repr__())

                logging.info('### Loading model to parallel GPUs ###')

                model = utils.model_to_gpus(model, args)

                logging.info('### Preparing schedulers and optimizers ###')
                optimizer = torch.optim.Adam(
                    model.parameters(),
                    args.learning_rate,
                    weight_decay=args.weight_decay)

                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, args.epochs)

                logging.info('## Beginning Training ##')

                train = Trainer(model, criterion, optimizer, scheduler, args, writer)

                best_error, train_time, val_time = train.train_loop(
                    train_loader, valid_loader, special_info="_"+dataset+"_"+str(j))

                logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
                    best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

                del model

    with torch.no_grad():
        logging.info('## Beginning Plotting ##')
        evaluate_regression_uncertainty(model_temp, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/mcdropout/quantised/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/mcdropout/quantised/train/__init__.py

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/mcdropout/quantised/train/mcdropout_cifar.py:
--------------------------------------------------------------------------------
import sys
import torch
import argparse
from datetime import timedelta
import logging

sys.path.append("../")
sys.path.append("../../")
sys.path.append("../../../")
sys.path.append("../../../../")
sys.path.append("../../../../../")
sys.path.append("../../../../../../")
sys.path.append("../../../../../../../")

from experiments.utils import evaluate_cifar_uncertainty
from src.data import *
from src.trainer import Trainer
from src.models import ModelFactory
from src.losses import LOSS_FACTORY
import src.utils as utils
import src.quant_utils as quant_utils

parser = argparse.ArgumentParser("cifar_classifier")

parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss')
parser.add_argument('--model', type=str, default='conv_resnet_mc', help='the model that we want to train')

parser.add_argument('--learning_rate', type=float,
                    default=0.001, help='init learning rate')
parser.add_argument('--loss_scaling', type=str,
                    default='batch', help='loss scaling mode')
parser.add_argument('--weight_decay', type=float,
                    default=0.00001, help='weight decay')
parser.add_argument('--p', type=float,
                    default=0.15, help='dropout probability')

parser.add_argument('--data', type=str, default='./../../../../../data/',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar',
                    help='dataset')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')

parser.add_argument('--valid_portion', type=float,
                    default=0.1, help='portion of training data used for validation')

parser.add_argument('--epochs', type=int, default=10,
                    help='num of training epochs')

parser.add_argument('--input_size', nargs='+',
                    default=[1, 3, 32, 32], help='input size')
parser.add_argument('--output_size', type=int,
                    default=10, help='output size')
parser.add_argument('--samples', type=int,
                    default=20, help='number of samples')

parser.add_argument('--num_workers', type=int,
                    default=16, help='number of workers')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--load', type=str, default='EXP', help='path to a pre-trained model to load')


parser.add_argument('--save_last', action='store_true',
                    help='whether to just save the last model', default=True)

parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')

parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')

parser.add_argument('--q', action='store_true', default=True,
                    help='whether to do post training quantisation')
parser.add_argument('--at', action='store_true', default=True,
                    help='whether to do quantisation aware training')
parser.add_argument('--activation_precision', type=int, default=7,
                    help='how many bits to be used for the activations')
parser.add_argument('--weight_precision', type=int, default=8,
                    help='how many bits to be used for the weights')


def main():
    args = parser.parse_args()
    load = False
    if args.save != 'EXP':
        load = True

    args, writer = utils.parse_args(args)

    logging.info('# Start Re-training #')

    criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)

    model_temp = ModelFactory.get_model

    logging.info('## Downloading and preparing data ##')
    train_loader, valid_loader = get_train_loaders(args)

    if not load:
        model = model_temp(args.model, args.input_size, args.output_size, args.at, args)
        utils.load_model(model, args.load+"/weights.pt")

        if args.at:
            logging.info('## Preparing model for quantization aware training ##')
            quant_utils.prepare_model(model, args)

        logging.info('## Model created: ##')
        logging.info(model.__repr__())

        logging.info('### Loading model to parallel GPUs ###')

        model = utils.model_to_gpus(model, args)

        logging.info('### Preparing schedulers and optimizers ###')
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.9,
            weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)

        logging.info('## Beginning Training ##')

        train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer)

        best_error, train_time, val_time = train.train_loop(
            train_loader, valid_loader)

        logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
            best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))

        if args.q:
            quant_utils.postprocess_model(model, args)

        logging.info('## Beginning Plotting ##')
        del model

    with torch.no_grad():
        model = model_temp(args.model, args.input_size, args.output_size, args.q, args)

        if args.q:
            quant_utils.prepare_model(model, args)
            quant_utils.convert(model)

        utils.load_model(model, args.save+"/weights.pt")

        logging.info('## Model re-created: ##')
        logging.info(model.__repr__())

        if not args.q:
            model = utils.model_to_gpus(model, args)

        model.eval()
        evaluate_cifar_uncertainty(model, args)

    logging.info('# Finished #')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/experiments/scripts/stochastic/mcdropout/quantised/train/mcdropout_mnist.py:
--------------------------------------------------------------------------------
import sys
torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | sys.path.append("../../../../../../../") 14 | 15 | from experiments.utils import evaluate_mnist_uncertainty 16 | from src.data import * 17 | from src.trainer import Trainer 18 | from src.models import ModelFactory 19 | from src.losses import LOSS_FACTORY 20 | import src.utils as utils 21 | import src.quant_utils as quant_utils 22 | 23 | parser = argparse.ArgumentParser("mnist_classifier") 24 | 25 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 26 | parser.add_argument('--model', type=str, default='conv_lenet_mc', help='the model that we want to train') 27 | 28 | parser.add_argument('--learning_rate', type=float, 29 | default=0.00001, help='init learning rate') 30 | parser.add_argument('--loss_scaling', type=str, 31 | default='batch', help='smoothing factor') 32 | parser.add_argument('--weight_decay', type=float, 33 | default=0.00001, help='weight decay') 34 | parser.add_argument('--p', type=float, 35 | default=0.2, help='dropout probability') 36 | 37 | parser.add_argument('--data', type=str, default='./../../../../../data/', 38 | help='location of the data corpus') 39 | parser.add_argument('--dataset', type=str, default='mnist', 40 | help='dataset') 41 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 42 | 43 | parser.add_argument('--valid_portion', type=float, 44 | default=0.1, help='portion of training data') 45 | 46 | parser.add_argument('--epochs', type=int, default=10, 47 | help='num of training epochs') 48 | 49 | parser.add_argument('--input_size', nargs='+', 50 | default=[1, 1, 28, 28], help='input size') 51 | parser.add_argument('--output_size', type=int, 52 | default=10, help='output size') 53 | parser.add_argument('--samples', type=int, 54 | default=20, help='output size') 55 | 56 | parser.add_argument('--num_workers', type=int, 57 | default=16, help='number of workers') 58 | parser.add_argument('--load', type=str, default='EXP', help='to load pre-trained model') 59 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 60 | 61 | parser.add_argument('--save_last', action='store_true', default=True, 62 | help='whether to just save the last model') 63 | 64 | parser.add_argument('--seed', type=int, default=2, help='random seed') 65 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 66 | 67 | parser.add_argument('--report_freq', type=float, 68 | default=50, help='report frequency') 69 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 70 | 71 | parser.add_argument('--q', action='store_true', default=True, 72 | help='whether to do post training quantisation') 73 | parser.add_argument('--at', action='store_true', default=True, 74 | help='whether to do training aware quantisation') 75 | parser.add_argument('--activation_precision', type=int, default=7, 76 | help='how many bits to be used for the activations') 77 | parser.add_argument('--weight_precision', type=int, default=8, 78 | help='how many bits to be used for the weights') 79 | 80 | 81 | def main(): 82 | args = parser.parse_args() 83 | load = False 84 | if args.save!='EXP': 85 | load=True 86 | 87 | args, writer = utils.parse_args(args) 88 | 89 | 
logging.info('# Start Re-training #') 90 | 91 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 92 | 93 | model_temp = ModelFactory.get_model 94 | 95 | logging.info('## Downloading and preparing data ##') 96 | train_loader, valid_loader= get_train_loaders(args) 97 | 98 | if not load: 99 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args) 100 | utils.load_model(model, args.load+"/weights.pt") 101 | 102 | if args.at: 103 | logging.info('## Preparing model for quantization aware training ##') 104 | quant_utils.prepare_model(model, args) 105 | 106 | logging.info('## Model created: ##') 107 | logging.info(model.__repr__()) 108 | 109 | 110 | logging.info('### Loading model to parallel GPUs ###') 111 | 112 | model = utils.model_to_gpus(model, args) 113 | 114 | logging.info('### Preparing schedulers and optimizers ###') 115 | optimizer = torch.optim.SGD( 116 | model.parameters(), 117 | args.learning_rate, 118 | momentum = 0.9, 119 | weight_decay = args.weight_decay) 120 | 121 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 122 | optimizer, args.epochs) 123 | 124 | logging.info('## Beginning Training ##') 125 | 126 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 127 | 128 | best_error, train_time, val_time = train.train_loop( 129 | train_loader, valid_loader) 130 | 131 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 132 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 133 | 134 | if args.q: 135 | quant_utils.postprocess_model(model, args) 136 | 137 | logging.info('## Beginning Plotting ##') 138 | del model 139 | 140 | with torch.no_grad(): 141 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args) 142 | if args.q: 143 | quant_utils.prepare_model(model, args) 144 | quant_utils.convert(model) 145 | 146 | utils.load_model(model, args.save+"/weights.pt") 147 | 148 | logging.info('## Model re-created: ##') 149 | logging.info(model.__repr__()) 150 | 151 | if not args.q: 152 | model = utils.model_to_gpus(model, args) 153 | 154 | model.eval() 155 | evaluate_mnist_uncertainty(model, args) 156 | 157 | logging.info('# Finished #') 158 | 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /experiments/scripts/stochastic/mcdropout/quantised/train/mcdropout_regression.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | sys.path.append("../../../../../../../") 14 | 15 | from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS 16 | from src.data import * 17 | from src.trainer import Trainer 18 | from src.models import ModelFactory 19 | from src.losses import LOSS_FACTORY 20 | import src.utils as utils 21 | import src.quant_utils as quant_utils 22 | 23 | parser = argparse.ArgumentParser("mcdropout_regression") 24 | 25 | parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss') 26 | parser.add_argument('--model', type=str, default='linear_mc', help='the model that we want to train') 27 | 
28 | parser.add_argument('--learning_rate', type=float,
29 | default=0.00001, help='init learning rate')
30 | parser.add_argument('--loss_scaling', type=str,
31 | default='batch', help="loss scaling mode: 'batch' or 'whole'")
32 | parser.add_argument('--weight_decay', type=float,
33 | default=0.00005, help='weight decay')
34 | parser.add_argument('--p', type=float,
35 | default=0.2, help='dropout probability')
36 |
37 | parser.add_argument('--data', type=str, default='./../../../../../data/',
38 | help='location of the data corpus')
39 | parser.add_argument('--dataset', type=str, default='regression',
40 | help='dataset')
41 | parser.add_argument('--batch_size', type=int, default=1000, help='batch size')
42 |
43 | parser.add_argument('--valid_portion', type=float,
44 | default=0.2, help='portion of training data used for validation')
45 |
46 | parser.add_argument('--epochs', type=int, default=10,
47 | help='number of training epochs')
48 |
49 | parser.add_argument('--input_size', nargs='+',
50 | default=[1], help='input size')
51 | parser.add_argument('--output_size', type=int,
52 | default=1, help='output size')
53 | parser.add_argument('--samples', type=int,
54 | default=20, help='number of Monte Carlo samples')
55 |
56 | parser.add_argument('--num_workers', type=int,
57 | default=0, help='number of workers')
58 | parser.add_argument('--save', type=str, default='EXP', help='experiment name')
59 | parser.add_argument('--load', type=str, default='EXP', help='path of the pre-trained model to load')
60 |
61 | parser.add_argument('--save_last', action='store_true', default=True,
62 | help='whether to just save the last model')
63 |
64 |
65 | parser.add_argument('--seed', type=int, default=1, help='random seed')
66 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')
67 |
68 | parser.add_argument('--report_freq', type=float,
69 | default=50, help='report frequency')
70 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
71 |
72 | parser.add_argument('--q', action='store_true', default=True,
73 | help='whether to do post-training quantisation')
74 | parser.add_argument('--at', action='store_true', default=True,
75 | help='whether to do quantisation-aware training')
76 | parser.add_argument('--activation_precision', type=int, default=7,
77 | help='how many bits to be used for the activations')
78 | parser.add_argument('--weight_precision', type=int, default=8,
79 | help='how many bits to be used for the weights')
80 |
81 | def main():
82 | args = parser.parse_args()
83 | load = False
84 | if args.save != 'EXP':
85 | load = True
86 |
87 | model_temp = ModelFactory.get_model
88 |
89 | args, writer = utils.parse_args(args)
90 | logging.info('# Start Re-training #')
91 | if not load:
92 | for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS):
93 | for j in range(n_folds):
94 | logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j))
95 |
96 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)
97 |
98 | logging.info('## Downloading and preparing data ##')
99 | args.dataset = "regression_" + dataset
100 |
101 | train_loader, valid_loader = get_train_loaders(args, split=j)
102 | in_shape = next(iter(train_loader))[0].shape[1]
103 | args.input_size = [in_shape]
104 |
105 | model = model_temp(args.model, args.input_size,
106 | args.output_size, args.at, args)
107 | utils.load_model(
108 | model, args.load+"/weights_{}_{}.pt".format(dataset, j))
109 |
110 | if args.at:
111 | logging.info('## Preparing model for quantization aware training ##')
112 | quant_utils.prepare_model(model, args)
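# --- Editor's note (sketch) -------------------------------------------------
# `quant_utils.prepare_model` / `quant_utils.convert` are this repository's
# own wrappers around PyTorch quantisation. For orientation only, a minimal
# sketch of the stock torch.quantization QAT flow they correspond to
# (torch==1.7.0, as pinned in requirements.txt); the 'fbgemm' backend and the
# helper name below are assumptions for illustration, not the repo's API:
def _qat_flow_sketch(float_model):
    import torch.quantization as tq
    float_model.train()
    float_model.qconfig = tq.get_default_qat_qconfig('fbgemm')
    tq.prepare_qat(float_model, inplace=True)      # insert fake-quant + observers
    # ... fine-tune for a few epochs so the observers learn value ranges ...
    float_model.eval()
    return tq.convert(float_model, inplace=False)  # swap float ops for int8 kernels
# ----------------------------------------------------------------------------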
113 | 114 | logging.info('## Model created: ##') 115 | logging.info(model.__repr__()) 116 | 117 | logging.info('### Loading model to parallel GPUs ###') 118 | 119 | model = utils.model_to_gpus(model, args) 120 | 121 | logging.info('### Preparing schedulers and optimizers ###') 122 | optimizer = torch.optim.SGD( 123 | model.parameters(), 124 | args.learning_rate, 125 | momentum=0.9, 126 | weight_decay=args.weight_decay) 127 | 128 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 129 | optimizer, args.epochs) 130 | 131 | logging.info('## Beginning Training ##') 132 | 133 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 134 | 135 | best_error, train_time, val_time = train.train_loop( 136 | train_loader, valid_loader, special_info="_"+dataset+"_"+str(j)) 137 | 138 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 139 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 140 | 141 | if args.q: 142 | quant_utils.postprocess_model( 143 | model, args, special_info="_{}_{}".format(dataset, j)) 144 | 145 | del model 146 | 147 | with torch.no_grad(): 148 | logging.info('## Beginning Plotting ##') 149 | evaluate_regression_uncertainty(model_temp, args) 150 | 151 | logging.info('# Finished #') 152 | 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/sgld/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/float/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/sgld/float/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/float/sgld_cifar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | 14 | from experiments.utils import evaluate_cifar_uncertainty 15 | from src.data import * 16 | from src.trainer import Trainer 17 | from src.models import ModelFactory 18 | from src.losses import LOSS_FACTORY 19 | from src.models.stochastic.sgld.utils_sgld import SGLD 20 | import src.utils as utils 21 | 22 | 23 | parser = argparse.ArgumentParser("cifar_classifier") 24 | 25 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 26 | parser.add_argument('--model', type=str, default='conv_resnet_sgld', help='the model that we want to train') 27 | 28 | parser.add_argument('--learning_rate', type=float, 29 | default=0.01, help='init learning rate') 30 | parser.add_argument('--loss_scaling', type=str, 31 | default='whole', help='smoothing factor') 32 | 
parser.add_argument('--loss_multiplier', type=float,
33 | default=16, help='multiplier for the loss')
34 |
35 |
36 | parser.add_argument('--data', type=str, default='./../../../data/',
37 | help='location of the data corpus')
38 | parser.add_argument('--dataset', type=str, default='cifar',
39 | help='dataset')
40 | parser.add_argument('--batch_size', type=int, default=256, help='batch size')
41 |
42 | parser.add_argument('--valid_portion', type=float,
43 | default=0.1, help='portion of training data used for validation')
44 |
45 | parser.add_argument('--burnin_epochs', type=int,
46 | default=200, help='number of burn-in epochs before weight samples are collected')
47 | parser.add_argument('--resample_momentum_iterations', type=int,
48 | default=50, help='iterations between momentum resampling')
49 | parser.add_argument('--resample_prior_iterations', type=int,
50 | default=25, help='iterations between prior resampling')
51 |
52 | parser.add_argument('--epochs', type=int, default=300,
53 | help='number of training epochs')
54 |
55 | parser.add_argument('--input_size', nargs='+',
56 | default=[1, 3, 32, 32], help='input size')
57 | parser.add_argument('--output_size', type=int,
58 | default=10, help='output size')
59 | parser.add_argument('--samples', type=int,
60 | default=20, help='number of SGLD weight samples')
61 | parser.add_argument('--save', type=str, default='EXP', help='experiment name')
62 | parser.add_argument('--save_last', action='store_true', default=True,
63 | help='whether to just save the last model')
64 |
65 | parser.add_argument('--num_workers', type=int,
66 | default=16, help='number of workers')
67 | parser.add_argument('--seed', type=int, default=1, help='random seed')
68 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')
69 |
70 | parser.add_argument('--report_freq', type=float,
71 | default=50, help='report frequency')
72 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
73 |
74 | parser.add_argument('--q', action='store_true', default=False,
75 | help='whether to do post-training quantisation')
76 | parser.add_argument('--at', action='store_true', default=False,
77 | help='whether to do quantisation-aware training')
78 |
79 |
80 |
81 | def main():
82 | args = parser.parse_args()
83 | load = False
84 | if args.save != 'EXP':
85 | load = True
86 | args, writer = utils.parse_args(args)
87 |
88 | logging.info('# Start Re-training #')
89 |
90 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)
91 |
92 | model_temp = ModelFactory.get_model
93 | logging.info('## Downloading and preparing data ##')
94 | train_loader, valid_loader = get_train_loaders(args)
95 |
96 | if not load:
97 | model = model_temp(args.model, args.input_size, args.output_size, args.at, args, True)
98 |
99 | logging.info('## Model created: ##')
100 | logging.info(model.__repr__())
101 |
102 | logging.info('### Loading model to parallel GPUs ###')
103 | model = utils.model_to_gpus(model, args)
104 |
105 | logging.info('### Preparing schedulers and optimizers ###')
106 | optimizer = SGLD(
107 | model.parameters(),
108 | args.learning_rate)
109 |
110 | scheduler = None
111 |
112 | logging.info('## Beginning Training ##')
113 |
114 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer)
115 |
116 | best_error, train_time, val_time = train.train_loop(
117 | train_loader, valid_loader)
118 |
119 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
120 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))
121 |
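# --- Editor's note (sketch) -------------------------------------------------
# The SGLD optimizer used above is implemented in
# src/models/stochastic/sgld/utils_sgld and, judging by the CLI flags, also
# periodically resamples momenta and the prior. A minimal sketch of the core
# Langevin update it is built around (Welling & Teh, 2011), with `p.grad`
# holding the gradient of the loss and `eps` the step size; this is an
# illustration only, not the repository's exact implementation:
def _sgld_step_sketch(params, eps):
    # theta <- theta - (eps/2) * grad(loss) + N(0, eps)
    with torch.no_grad():
        for p in params:
            if p.grad is None:
                continue
            noise = torch.randn_like(p) * (eps ** 0.5)  # N(0, eps) injection
            p.add_(-0.5 * eps * p.grad + noise)         # gradient step + noise
# ----------------------------------------------------------------------------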
122 | logging.info('## Beginning Plotting ##') 123 | del model 124 | 125 | with torch.no_grad(): 126 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args, False) 127 | model.load_ensemble(args) 128 | 129 | logging.info('## Model re-created: ##') 130 | logging.info(model.__repr__()) 131 | 132 | model = utils.model_to_gpus(model, args) 133 | 134 | model.eval() 135 | evaluate_cifar_uncertainty(model, args) 136 | 137 | logging.info('# Finished #') 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/float/sgld_mnist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | 14 | from experiments.utils import evaluate_mnist_uncertainty 15 | from src.data import * 16 | from src.trainer import Trainer 17 | from src.models import ModelFactory 18 | from src.losses import LOSS_FACTORY 19 | from src.models.stochastic.sgld.utils_sgld import SGLD 20 | import src.utils as utils 21 | 22 | 23 | parser = argparse.ArgumentParser("mnist_classifier") 24 | 25 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 26 | parser.add_argument('--model', type=str, default='conv_lenet_sgld', help='the model that we want to train') 27 | 28 | parser.add_argument('--learning_rate', type=float, 29 | default=0.01, help='init learning rate') 30 | parser.add_argument('--loss_scaling', type=str, 31 | default='whole', help='smoothing factor') 32 | parser.add_argument('--loss_multiplier', type=float, 33 | default=1, help='smoothing factor') 34 | 35 | parser.add_argument('--data', type=str, default='./../../../data/', 36 | help='location of the data corpus') 37 | parser.add_argument('--dataset', type=str, default='mnist', 38 | help='dataset') 39 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 40 | 41 | 42 | parser.add_argument('--valid_portion', type=float, 43 | default=0.1, help='portion of training data') 44 | 45 | parser.add_argument('--burnin_epochs', type=int, 46 | default=20, help='portion of training data') 47 | parser.add_argument('--resample_momentum_iterations', type=int, 48 | default=50, help='portion of training data') 49 | parser.add_argument('--resample_prior_iterations', type=int, 50 | default=15, help='portion of training data') 51 | parser.add_argument('--epochs', type=int, default=100, 52 | help='num of training epochs') 53 | 54 | parser.add_argument('--input_size', nargs='+', 55 | default=[1, 1, 28, 28], help='input size') 56 | parser.add_argument('--output_size', type=int, 57 | default=10, help='output size') 58 | parser.add_argument('--samples', type=int, 59 | default=20, help='output size') 60 | 61 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 62 | parser.add_argument('--save_last', action='store_true', default=True, 63 | help='whether to just save the last model') 64 | 65 | parser.add_argument('--num_workers', type=int, 66 | default=16, help='number of workers') 67 | parser.add_argument('--seed', type=int, default=1, help='random seed') 68 | parser.add_argument('--debug', action='store_true', 
help='whether we are currently debugging') 69 | 70 | parser.add_argument('--report_freq', type=float, 71 | default=50, help='report frequency') 72 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 73 | 74 | parser.add_argument('--q', action='store_true', default=False, 75 | help='whether to do post training quantisation') 76 | parser.add_argument('--at', action='store_true', default=False, 77 | help='whether to do training aware quantisation') 78 | 79 | 80 | 81 | def main(): 82 | args = parser.parse_args() 83 | load = False 84 | if args.save!='EXP': 85 | load=True 86 | 87 | args, writer = utils.parse_args(args) 88 | 89 | logging.info('# Start Re-training #') 90 | 91 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 92 | 93 | model_temp = ModelFactory.get_model 94 | 95 | logging.info('## Downloading and preparing data ##') 96 | train_loader, valid_loader= get_train_loaders(args) 97 | 98 | if not load: 99 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args, True) 100 | 101 | logging.info('## Model created: ##') 102 | logging.info(model.__repr__()) 103 | 104 | logging.info('### Loading model to parallel GPUs ###') 105 | model = utils.model_to_gpus(model, args) 106 | 107 | logging.info('### Preparing schedulers and optimizers ###') 108 | optimizer = SGLD( 109 | model.parameters(), 110 | args.learning_rate) 111 | scheduler = None 112 | 113 | 114 | logging.info('## Beginning Training ##') 115 | 116 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 117 | 118 | best_error, train_time, val_time = train.train_loop( 119 | train_loader, valid_loader) 120 | 121 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 122 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 123 | 124 | logging.info('## Beginning Plotting ##') 125 | del model 126 | 127 | with torch.no_grad(): 128 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args, False) 129 | model.load_ensemble(args) 130 | 131 | logging.info('## Model re-created: ##') 132 | logging.info(model.__repr__()) 133 | 134 | model = utils.model_to_gpus(model, args) 135 | 136 | model.eval() 137 | evaluate_mnist_uncertainty(model, args) 138 | 139 | logging.info('# Finished #') 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/float/sgld_regression.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | 14 | from src.data import * 15 | from src.trainer import Trainer 16 | from src.models import ModelFactory 17 | from src.losses import LOSS_FACTORY 18 | from src.models.stochastic.sgld.utils_sgld import SGLD 19 | import src.utils as utils 20 | from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS 21 | 22 | parser = argparse.ArgumentParser("stochastic_sgld_regression") 23 | 24 | parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss') 25 | parser.add_argument('--model', type=str, 
default='linear_sgld', help='the model that we want to train') 26 | 27 | parser.add_argument('--learning_rate', type=float, 28 | default=0.01, help='init learning rate') 29 | parser.add_argument('--loss_scaling', type=str, 30 | default='whole', help='smoothing factor') 31 | parser.add_argument('--loss_multiplier', type=float, 32 | default=2, help='smoothing factor') 33 | 34 | parser.add_argument('--data', type=str, default='./../../../data/', 35 | help='location of the data corpus') 36 | parser.add_argument('--dataset', type=str, default='regression', 37 | help='dataset') 38 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 39 | 40 | parser.add_argument('--valid_portion', type=float, 41 | default=0.2, help='portion of training data') 42 | 43 | parser.add_argument('--burnin_epochs', type=int, 44 | default=200, help='portion of training data') 45 | parser.add_argument('--resample_momentum_iterations', type=int, 46 | default=10, help='portion of training data') 47 | parser.add_argument('--resample_prior_iterations', type=int, 48 | default=5, help='portion of training data') 49 | parser.add_argument('--epochs', type=int, default=300, 50 | help='num of training epochs') 51 | 52 | parser.add_argument('--input_size', nargs='+', 53 | default=[10], help='input size') 54 | parser.add_argument('--output_size', type=int, 55 | default=1, help='output size') 56 | parser.add_argument('--samples', type=int, 57 | default=20, help='output size') 58 | 59 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 60 | parser.add_argument('--save_last', action='store_true', default=True, 61 | help='whether to just save the last model') 62 | 63 | parser.add_argument('--num_workers', type=int, 64 | default=0, help='number of workers') 65 | parser.add_argument('--seed', type=int, default=1, help='random seed') 66 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 67 | 68 | parser.add_argument('--report_freq', type=float, 69 | default=50, help='report frequency') 70 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 71 | 72 | parser.add_argument('--q', action='store_true', default=False, 73 | help='whether to do post training quantisation') 74 | parser.add_argument('--at', action='store_true', default=False, 75 | help='whether to do training aware quantisation') 76 | 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | load = False 81 | if args.save!='EXP': 82 | load=True 83 | args, writer = utils.parse_args(args) 84 | 85 | logging.info('# Start Re-training #') 86 | model_temp = ModelFactory.get_model 87 | 88 | if not load: 89 | for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS): 90 | for j in range(n_folds): 91 | logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j)) 92 | 93 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 94 | 95 | logging.info('## Downloading and preparing data ##') 96 | args.dataset = "regression_" + dataset 97 | 98 | train_loader, valid_loader = get_train_loaders(args, split=j) 99 | in_shape = next(iter(train_loader))[0].shape[1] 100 | args.input_size = [in_shape] 101 | 102 | if dataset == "yacht": 103 | args.batch_size = 64 104 | 105 | model = model_temp(args.model, args.input_size, 106 | args.output_size, args.at, args, True) 107 | 108 | logging.info('## Model created: ##') 109 | logging.info(model.__repr__()) 110 | 111 | logging.info('### Loading model to parallel GPUs ###') 112 | 113 | model = utils.model_to_gpus(model, args) 114 
| 115 | logging.info('### Preparing schedulers and optimizers ###') 116 | optimizer = SGLD( 117 | model.parameters(), 118 | args.learning_rate) 119 | 120 | scheduler = None 121 | 122 | logging.info('## Beginning Training ##') 123 | 124 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 125 | 126 | best_error, train_time, val_time = train.train_loop( 127 | train_loader, valid_loader, special_info="_"+dataset+"_"+str(j)) 128 | 129 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 130 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 131 | 132 | del model 133 | with torch.no_grad(): 134 | args.batch_size = 1000 135 | logging.info('## Beginning Plotting ##') 136 | evaluate_regression_uncertainty(model_temp, args) 137 | logging.info('# Finished #') 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/quantised/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/experiments/scripts/stochastic/sgld/quantised/train/__init__.py -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/quantised/train/sgld_cifar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | sys.path.append("../../../../../../../") 14 | 15 | from experiments.utils import evaluate_cifar_uncertainty 16 | from src.data import * 17 | from src.trainer import Trainer 18 | from src.models import ModelFactory 19 | from src.losses import LOSS_FACTORY 20 | import src.utils as utils 21 | import src.quant_utils as quant_utils 22 | from src.utils import natural_keys 23 | import re 24 | 25 | parser = argparse.ArgumentParser("cifar_classifier") 26 | 27 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 28 | parser.add_argument('--model', type=str, default='conv_resnet_sgld', help='the model that we want to train') 29 | 30 | parser.add_argument('--learning_rate', type=float, 31 | default=0.00001, help='init learning rate') 32 | parser.add_argument('--loss_scaling', type=str, 33 | default='batch', help='smoothing factor') 34 | 35 | parser.add_argument('--data', type=str, default='./../../../../../data/', 36 | help='location of the data corpus') 37 | parser.add_argument('--dataset', type=str, default='cifar', 38 | help='dataset') 39 | parser.add_argument('--batch_size', type=int, default=1024, help='batch size') 40 | 41 | parser.add_argument('--valid_portion', type=float, 42 | default=0.1, help='portion of training data') 43 | 44 | parser.add_argument('--epochs', type=int, default=10, 45 | help='num of training epochs') 46 | 47 | parser.add_argument('--input_size', nargs='+', 48 | default=[1, 3, 32, 32], help='input size') 49 | parser.add_argument('--output_size', type=int, 50 | default=10, help='output size') 51 | parser.add_argument('--samples', type=int, 
52 | default=20, help='output size') 53 | 54 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 55 | parser.add_argument('--load', type=str, default='EXP', help='to load pre-trained model') 56 | 57 | parser.add_argument('--save_last', action='store_true', default=True, 58 | help='whether to just save the last model') 59 | 60 | parser.add_argument('--num_workers', type=int, 61 | default=16, help='number of workers') 62 | parser.add_argument('--seed', type=int, default=1, help='random seed') 63 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 64 | 65 | parser.add_argument('--report_freq', type=float, 66 | default=50, help='report frequency') 67 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 68 | 69 | parser.add_argument('--q', action='store_true', default=True, 70 | help='whether to do post training quantisation') 71 | parser.add_argument('--at', action='store_true', default=True, 72 | help='whether to do training aware quantisation') 73 | parser.add_argument('--activation_precision', type=int, default=7, 74 | help='how many bits to be used for the activations') 75 | parser.add_argument('--weight_precision', type=int, default=8, 76 | help='how many bits to be used for the weights') 77 | 78 | 79 | def main(): 80 | args = parser.parse_args() 81 | load = False 82 | if args.save!='EXP': 83 | load=True 84 | 85 | args, writer = utils.parse_args(args) 86 | 87 | logging.info('# Start Re-training #') 88 | 89 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 90 | 91 | model_temp = ModelFactory.get_model 92 | logging.info('## Downloading and preparing data ##') 93 | train_loader, valid_loader= get_train_loaders(args) 94 | 95 | if not load: 96 | sample_names = [] 97 | for root, dirs, files in os.walk(args.load): 98 | for filename in files: 99 | if ".pt" in filename: 100 | sample_name = re.findall('weights_[0-9]*.pt', filename) 101 | if len(sample_name)>=1: 102 | sample_name = sample_name[0] 103 | sample_names.append(sample_name) 104 | sample_names.sort(key=natural_keys) 105 | sample_names = sample_names[-args.samples:] 106 | 107 | for i in range(args.samples): 108 | model= model_temp(args.model, args.input_size, args.output_size, args.q, args, True) 109 | utils.load_model(model, args.load+"/"+sample_names[i], replace=False) 110 | logging.info('### Loading model: {} ###'.format(args.load+"/"+sample_names[i])) 111 | 112 | if args.at: 113 | logging.info('## Preparing model for quantization aware training ##') 114 | quant_utils.prepare_model(model, args) 115 | 116 | logging.info('## Model created: ##') 117 | logging.info(model.__repr__()) 118 | 119 | logging.info('### Loading model to parallel GPUs ###') 120 | model = utils.model_to_gpus(model, args) 121 | 122 | logging.info('### Preparing schedulers and optimizers ###') 123 | optimizer = torch.optim.SGD( 124 | model.parameters(), 125 | args.learning_rate, 126 | momentum=0.9, 127 | weight_decay=0.0) 128 | 129 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 130 | optimizer, args.epochs) 131 | 132 | logging.info('## Beginning Training ##') 133 | 134 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 135 | 136 | sample_name = re.findall('[0-9]+', sample_names[i]) 137 | sample_name = "_"+sample_name[0] 138 | best_error, train_time, val_time = train.train_loop( 139 | train_loader, valid_loader, special_info=sample_name) 140 | 141 | logging.info('## Finished training, the best observed validation error: {}, 
total training time: {}, total validation time: {} ##'.format( 142 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 143 | 144 | logging.info('## Beginning Plotting ##') 145 | del model 146 | 147 | with torch.no_grad(): 148 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args, False) 149 | if args.q: 150 | quant_utils.prepare_model(model, args) 151 | quant_utils.convert(model) 152 | 153 | model.load_ensemble(args) 154 | 155 | logging.info('## Model re-created: ##') 156 | logging.info(model.__repr__()) 157 | 158 | if not args.q: 159 | model = utils.model_to_gpus(model, args) 160 | 161 | model.eval() 162 | evaluate_cifar_uncertainty(model, args) 163 | 164 | logging.info('# Finished #') 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/quantised/train/sgld_mnist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | from datetime import timedelta 5 | import logging 6 | 7 | sys.path.append("../") 8 | sys.path.append("../../") 9 | sys.path.append("../../../") 10 | sys.path.append("../../../../") 11 | sys.path.append("../../../../../") 12 | sys.path.append("../../../../../../") 13 | sys.path.append("../../../../../../../") 14 | 15 | from experiments.utils import evaluate_mnist_uncertainty 16 | from src.data import * 17 | from src.trainer import Trainer 18 | from src.models import ModelFactory 19 | from src.losses import LOSS_FACTORY 20 | import src.utils as utils 21 | import src.quant_utils as quant_utils 22 | from src.utils import natural_keys 23 | import re 24 | 25 | parser = argparse.ArgumentParser("mnist_classifier") 26 | 27 | parser.add_argument('--task', type=str, default='classification', help='the main task; defines loss') 28 | parser.add_argument('--model', type=str, default='conv_lenet_sgld', help='the model that we want to train') 29 | 30 | parser.add_argument('--learning_rate', type=float, 31 | default=0.00001, help='init learning rate') 32 | parser.add_argument('--loss_scaling', type=str, 33 | default='batch', help='smoothing factor') 34 | parser.add_argument('--data', type=str, default='./../../../../../data/', 35 | help='location of the data corpus') 36 | parser.add_argument('--dataset', type=str, default='mnist', 37 | help='dataset') 38 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 39 | 40 | parser.add_argument('--valid_portion', type=float, 41 | default=0.1, help='portion of training data') 42 | 43 | parser.add_argument('--epochs', type=int, default=10, 44 | help='num of training epochs') 45 | 46 | parser.add_argument('--input_size', nargs='+', 47 | default=[1, 1, 28, 28], help='input size') 48 | parser.add_argument('--output_size', type=int, 49 | default=10, help='output size') 50 | parser.add_argument('--samples', type=int, 51 | default=20, help='output size') 52 | 53 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 54 | parser.add_argument('--load', type=str, default='EXP', help='to load pre-trained model') 55 | 56 | parser.add_argument('--save_last', action='store_true', default=True, 57 | help='whether to just save the last model') 58 | 59 | parser.add_argument('--num_workers', type=int, 60 | default=16, help='number of workers') 61 | parser.add_argument('--seed', type=int, default=1, help='random seed') 62 | parser.add_argument('--debug', 
action='store_true', help='whether we are currently debugging') 63 | 64 | parser.add_argument('--report_freq', type=float, 65 | default=50, help='report frequency') 66 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 67 | 68 | parser.add_argument('--q', action='store_true', default=True, 69 | help='whether to do post training quantisation') 70 | parser.add_argument('--at', action='store_true', default=True, 71 | help='whether to do training aware quantisation') 72 | parser.add_argument('--activation_precision', type=int, default=7, 73 | help='how many bits to be used for the activations') 74 | parser.add_argument('--weight_precision', type=int, default=8, 75 | help='how many bits to be used for the weights') 76 | 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | load = False 81 | if args.save!='EXP': 82 | load=True 83 | 84 | args, writer = utils.parse_args(args) 85 | 86 | logging.info('# Start Re-training #') 87 | 88 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling) 89 | 90 | model_temp = ModelFactory.get_model 91 | logging.info('## Downloading and preparing data ##') 92 | train_loader, valid_loader= get_train_loaders(args) 93 | 94 | if not load: 95 | sample_names = [] 96 | for root, dirs, files in os.walk(args.load): 97 | for filename in files: 98 | if ".pt" in filename: 99 | sample_name = re.findall('weights_[0-9]*.pt', filename) 100 | if len(sample_name)>=1: 101 | sample_name = sample_name[0] 102 | sample_names.append(sample_name) 103 | sample_names.sort(key=natural_keys) 104 | sample_names = sample_names[-args.samples:] 105 | 106 | for i in range(args.samples): 107 | model= model_temp(args.model, args.input_size, args.output_size, args.at, args, True) 108 | utils.load_model(model, args.load+"/"+sample_names[i], replace=False) 109 | logging.info('### Loading model: {} ###'.format(args.load+"/"+sample_names[i])) 110 | 111 | if args.at: 112 | logging.info('## Preparing model for quantization aware training ##') 113 | quant_utils.prepare_model(model, args) 114 | 115 | logging.info('## Model created: ##') 116 | logging.info(model.__repr__()) 117 | 118 | logging.info('### Loading model to parallel GPUs ###') 119 | model = utils.model_to_gpus(model, args) 120 | 121 | logging.info('### Preparing schedulers and optimizers ###') 122 | optimizer = torch.optim.SGD( 123 | model.parameters(), 124 | args.learning_rate, 125 | momentum=0.9, 126 | weight_decay=0.0) 127 | 128 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 129 | optimizer, args.epochs) 130 | 131 | logging.info('## Beginning Training ##') 132 | 133 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer) 134 | sample_name = re.findall('[0-9]+', sample_names[i]) 135 | sample_name = "_"+sample_name[0] 136 | best_error, train_time, val_time = train.train_loop( 137 | train_loader, valid_loader, special_info=sample_name) 138 | 139 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format( 140 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time))) 141 | 142 | logging.info('## Beginning Plotting ##') 143 | del model 144 | 145 | with torch.no_grad(): 146 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args, False) 147 | if args.q: 148 | quant_utils.prepare_model(model, args) 149 | quant_utils.convert(model) 150 | 151 | model.load_ensemble(args) 152 | 153 | logging.info('## Model re-created: ##') 154 | logging.info(model.__repr__()) 155 | 
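# --- Editor's note (sketch) -------------------------------------------------
# `model.load_ensemble(args)` above is this repository's loader for the saved
# SGLD weight snapshots. A minimal sketch of the prediction such an ensemble
# enables (assumed behaviour, not the exact API): average the per-snapshot
# class probabilities to approximate the posterior predictive. The networks
# here already return probabilities (src/losses.py applies torch.log to the
# output), hence a plain mean with no extra softmax:
def _ensemble_predict_sketch(snapshot_models, x):
    probs = torch.stack([m(x) for m in snapshot_models], dim=0)
    return probs.mean(dim=0)  # Monte Carlo average over posterior samples
# ----------------------------------------------------------------------------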
156 | if not args.q:
157 | model = utils.model_to_gpus(model, args)
158 |
159 | model.eval()
160 | evaluate_mnist_uncertainty(model, args)
161 |
162 | logging.info('# Finished #')
163 |
164 |
165 | if __name__ == '__main__':
166 | main()
167 |
-------------------------------------------------------------------------------- /experiments/scripts/stochastic/sgld/quantised/train/sgld_regression.py: --------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import argparse
4 | from datetime import timedelta
5 | import logging
6 |
7 | sys.path.append("../")
8 | sys.path.append("../../")
9 | sys.path.append("../../../")
10 | sys.path.append("../../../../")
11 | sys.path.append("../../../../../")
12 | sys.path.append("../../../../../../")
13 | sys.path.append("../../../../../../../")
14 |
15 | from experiments.utils import evaluate_regression_uncertainty, REGRESSION_DATASETS
16 | from src.data import *
17 | from src.trainer import Trainer
18 | from src.models import ModelFactory
19 | from src.losses import LOSS_FACTORY
20 | import src.utils as utils
21 | import src.quant_utils as quant_utils
22 | from src.utils import natural_keys
23 | import re
24 |
25 | parser = argparse.ArgumentParser("stochastic_sgld_regression")
26 |
27 | parser.add_argument('--task', type=str, default='regression', help='the main task; defines loss')
28 | parser.add_argument('--model', type=str, default='linear_sgld', help='the model that we want to train')
29 |
30 | parser.add_argument('--learning_rate', type=float,
31 | default=0.00001, help='init learning rate')
32 | parser.add_argument('--loss_scaling', type=str,
33 | default='batch', help="loss scaling mode: 'batch' or 'whole'")
34 |
35 | parser.add_argument('--data', type=str, default='./../../../../../data/',
36 | help='location of the data corpus')
37 | parser.add_argument('--dataset', type=str, default='regression',
38 | help='dataset')
39 | parser.add_argument('--batch_size', type=int, default=1000, help='batch size')
40 |
41 | parser.add_argument('--valid_portion', type=float,
42 | default=0.2, help='portion of training data used for validation')
43 | parser.add_argument('--epochs', type=int, default=10,
44 | help='number of training epochs')
45 |
46 | parser.add_argument('--input_size', nargs='+',
47 | default=[1], help='input size')
48 | parser.add_argument('--output_size', type=int,
49 | default=1, help='output size')
50 | parser.add_argument('--samples', type=int,
51 | default=20, help='number of SGLD weight samples')
52 |
53 | parser.add_argument('--save', type=str, default='EXP', help='experiment name')
54 | parser.add_argument('--load', type=str, default='EXP', help='path of the pre-trained model to load')
55 |
56 | parser.add_argument('--save_last', action='store_true', default=True,
57 | help='whether to just save the last model')
58 |
59 | parser.add_argument('--num_workers', type=int,
60 | default=0, help='number of workers')
61 | parser.add_argument('--seed', type=int, default=1, help='random seed')
62 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging')
63 |
64 | parser.add_argument('--report_freq', type=float,
65 | default=50, help='report frequency')
66 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
67 |
68 | parser.add_argument('--q', action='store_true', default=True,
69 | help='whether to do post-training quantisation')
70 | parser.add_argument('--at', action='store_true', default=True,
71 | help='whether to do quantisation-aware training')
72 | parser.add_argument('--activation_precision', type=int, default=7,
73 | help='how many bits to be used for the activations')
74 | parser.add_argument('--weight_precision', type=int, default=8,
75 | help='how many bits to be used for the weights')
76 |
77 | def main():
78 | args = parser.parse_args()
79 | load = False
80 | if args.save != 'EXP':
81 | load = True
82 | args, writer = utils.parse_args(args)
83 | model_temp = ModelFactory.get_model
84 |
85 | logging.info('# Start Re-training #')
86 | if not load:
87 | for i, (dataset, n_folds) in enumerate(REGRESSION_DATASETS):
88 | for j in range(n_folds):
89 | logging.info('## Dataset: {}, Split: {} ##'.format(dataset, j))
90 | criterion = LOSS_FACTORY[args.task](args, args.loss_scaling)
91 |
92 | logging.info('## Downloading and preparing data ##')
93 | args.dataset = "regression_" + dataset
94 |
95 | train_loader, valid_loader = get_train_loaders(args, split=j)
96 | in_shape = next(iter(train_loader))[0].shape[1]
97 | args.input_size = [in_shape]
98 |
99 | sample_names = []
100 | for root, dirs, files in os.walk(args.load):
101 | for filename in files:
102 | if ".pt" in filename:
103 | sample_name = re.findall('weights_{}_{}_[0-9]*.pt'.format(dataset, j), filename)
104 | if len(sample_name)>=1:
105 | sample_name = sample_name[0]
106 | sample_names.append(sample_name)
107 | sample_names.sort(key=natural_keys)
108 | sample_names = sample_names[-args.samples:]
109 | for s in range(args.samples):
110 | model = model_temp(args.model, args.input_size, args.output_size, args.q, args, True)
111 | utils.load_model(model, args.load+"/"+sample_names[s], replace=False)
112 | logging.info('### Loading model: {} ###'.format(args.load+"/"+sample_names[s]))
113 |
114 | if args.at:
115 | logging.info('## Preparing model for quantization aware training ##')
116 | quant_utils.prepare_model(model, args)
117 |
118 | logging.info('## Model created: ##')
119 | logging.info(model.__repr__())
120 |
121 | logging.info('### Loading model to parallel GPUs ###')
122 | model = utils.model_to_gpus(model, args)
123 | logging.info('### Preparing schedulers and optimizers ###')
124 | optimizer = torch.optim.SGD(
125 | model.parameters(),
126 | args.learning_rate,
127 | momentum=0.9,
128 | weight_decay=0.0)
129 |
130 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
131 | optimizer, args.epochs)
132 |
133 | logging.info('## Beginning Training ##')
134 |
135 | train = Trainer(model, criterion, optimizer, scheduler, args, writer=writer)
136 | sample_name = re.findall('[0-9]+', sample_names[s])
137 | sample_name = "_"+dataset+"_"+str(j)+"_"+sample_name[1]
138 | best_error, train_time, val_time = train.train_loop(
139 | train_loader, valid_loader, special_info=sample_name)
140 |
141 | logging.info('## Finished training, the best observed validation error: {}, total training time: {}, total validation time: {} ##'.format(
142 | best_error, timedelta(seconds=train_time), timedelta(seconds=val_time)))
143 |
144 | del model
145 |
146 | with torch.no_grad():
147 | logging.info('## Beginning Plotting ##')
148 | evaluate_regression_uncertainty(model_temp, args)
149 |
150 | logging.info('# Finished #')
151 |
152 |
153 | if __name__ == '__main__':
154 | main()
155 |
-------------------------------------------------------------------------------- /poster.pdf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/poster.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | autopep8==1.5.4 3 | backcall==0.2.0 4 | brewer2mpl==1.4.1 5 | cachetools==4.1.1 6 | certifi==2020.11.8 7 | chardet==3.0.4 8 | cycler==0.10.0 9 | dataclasses==0.8 10 | decorator==4.4.2 11 | future==0.18.2 12 | google-auth==1.23.0 13 | google-auth-oauthlib==0.4.2 14 | grpcio==1.33.2 15 | idna==2.10 16 | importlib-metadata==3.1.0 17 | ipykernel==5.3.4 18 | ipython==7.16.1 19 | ipython-genutils==0.2.0 20 | jedi==0.17.2 21 | joblib==0.17.0 22 | jupyter-client==6.1.7 23 | jupyter-core==4.7.0 24 | kiwisolver==1.3.1 25 | Markdown==3.3.3 26 | matplotlib==3.3.3 27 | numpy==1.19.4 28 | oauthlib==3.1.0 29 | pandas==1.1.4 30 | parso==0.7.1 31 | pexpect==4.8.0 32 | pickleshare==0.7.5 33 | Pillow==8.0.1 34 | prompt-toolkit==3.0.8 35 | protobuf==3.14.0 36 | ptyprocess==0.6.0 37 | pyasn1==0.4.8 38 | pyasn1-modules==0.2.8 39 | pycodestyle==2.6.0 40 | Pygments==2.7.2 41 | pyparsing==2.4.7 42 | python-dateutil==2.8.1 43 | pytz==2020.4 44 | pyzmq==20.0.0 45 | requests==2.25.0 46 | requests-oauthlib==1.3.0 47 | rsa==4.6 48 | scikit-learn==0.23.2 49 | scipy==1.5.4 50 | six==1.15.0 51 | tensorboard==2.4.0 52 | tensorboard-plugin-wit==1.7.0 53 | threadpoolctl==2.1.0 54 | toml==0.10.2 55 | torch==1.7.0 56 | torchvision==0.8.1 57 | tornado==6.1 58 | tqdm==4.19.9 59 | traitlets==4.3.3 60 | typing-extensions==3.7.4.3 61 | uncertainty-metrics==0.0.81 62 | urllib3==1.26.2 63 | wcwidth==0.2.5 64 | Werkzeug==1.0.1 65 | xlrd==1.2.0 66 | zipp==3.4.0 67 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/src/__init__.py -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | 5 | LOSS_FACTORY = {'classification': lambda args, scaling: ClassificationLoss(args, scaling), 6 | 'regression': lambda args, scaling: RegressionLoss(args, scaling)} 7 | 8 | class Loss(nn.Module): 9 | def __init__(self, args, scaling): 10 | super(Loss, self).__init__() 11 | self.args = args 12 | self.scaling = scaling 13 | 14 | class ClassificationLoss(Loss): 15 | def __init__(self, args, scaling): 16 | super(ClassificationLoss, self).__init__(args, scaling) 17 | self.ce = F.nll_loss 18 | def forward(self, output, target, kl, gamma, n_batches, n_points): 19 | if self.scaling=='whole': 20 | ce = n_points*self.ce(torch.log(output+1e-8), target) * self.args.loss_multiplier 21 | kl = kl / n_batches 22 | elif self.scaling=='batch': 23 | ce = self.ce(torch.log(output+1e-8), target) 24 | kl = kl / (target.shape[0]*n_batches) 25 | else: 26 | raise NotImplementedError('Other scaling not implemented!') 27 | loss = ce + gamma * kl 28 | 29 | return loss, ce, kl 30 | 31 | class RegressionLoss(Loss): 32 | def __init__(self, args, scaling): 33 | super(RegressionLoss, self).__init__(args, scaling) 34 | 35 | def forward(self, output, target, kl, gamma, n_batches, n_points): 36 | mean = output[0] 37 | var = 
output[1] 38 | precision = 1/(var+1e-8) 39 | if self.scaling == 'whole': 40 | heteroscedastic_loss = n_points * \ 41 | torch.mean(torch.sum(precision * (target - mean)**2 + 42 | torch.log(var+1e-8), 1), 0) * self.args.loss_multiplier 43 | kl = kl / n_batches 44 | elif self.scaling == 'batch': 45 | heteroscedastic_loss = torch.mean( 46 | torch.sum(precision * (target - mean)**2 + torch.log(var+1e-8), 1), 0) 47 | kl = kl / (target.shape[0]*n_batches) 48 | else: 49 | raise NotImplementedError('Other scaling not implemented!') 50 | loss = heteroscedastic_loss + gamma*kl 51 | return loss, heteroscedastic_loss, kl 52 | 53 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from src.models.pointwise.models_p import LinearNetwork, ConvNetwork_LeNet, ConvNetwork_ResNet 2 | from src.models.stochastic.mcdropout.models_mc import LinearNetwork as LinearNetworkMC 3 | from src.models.stochastic.mcdropout.models_mc import ConvNetwork_LeNet as ConvNetwork_LeNetMC 4 | from src.models.stochastic.mcdropout.models_mc import ConvNetwork_ResNet as ConvNetwork_ResNetMC 5 | from src.models.stochastic.bbb.models_bbb import LinearNetwork as LinearNetworkBBB 6 | from src.models.stochastic.bbb.models_bbb import ConvNetwork_LeNet as ConvNetwork_LeNetBBB 7 | from src.models.stochastic.bbb.models_bbb import ConvNetwork_ResNet as ConvNetwork_ResNetBBB 8 | from src.models.stochastic.sgld.models_sgld import Network as NetworkSGLD 9 | 10 | 11 | class ModelFactory(): 12 | def __init__(self): 13 | pass 14 | 15 | @staticmethod 16 | def get_model(model, input_size, output_size, q, args, training_mode=True): 17 | net = None 18 | if model == "linear": 19 | net = LinearNetwork(input_size, output_size, q, args) 20 | elif model == "conv_lenet": 21 | net = ConvNetwork_LeNet(input_size, output_size, q, args) 22 | elif model == "conv_resnet": 23 | net = ConvNetwork_ResNet(input_size, output_size, q, args) 24 | elif model == "linear_mc": 25 | net = LinearNetworkMC(input_size, output_size, q, args) 26 | elif model == "conv_lenet_mc": 27 | net = ConvNetwork_LeNetMC(input_size, output_size, q, args) 28 | elif model == "conv_resnet_mc": 29 | net = ConvNetwork_ResNetMC(input_size, output_size, q, args) 30 | elif model == "linear_bbb": 31 | net = LinearNetworkBBB(input_size, output_size, q, args) 32 | elif model == "conv_lenet_bbb": 33 | net = ConvNetwork_LeNetBBB(input_size, output_size, q, args) 34 | elif model == "conv_resnet_bbb": 35 | net = ConvNetwork_ResNetBBB(input_size, output_size, q, args) 36 | elif "sgld" in model: 37 | net = NetworkSGLD(input_size, output_size, q, args, training_mode) 38 | else: 39 | raise NotImplementedError("Other models not implemented") 40 | 41 | return net -------------------------------------------------------------------------------- /src/models/pointwise/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/src/models/pointwise/__init__.py -------------------------------------------------------------------------------- /src/models/stochastic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/src/models/stochastic/__init__.py 
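Editor's note: ModelFactory.get_model above is the single entry point the training scripts use to build networks; a minimal usage sketch follows, where the concrete argument values are assumptions chosen purely for illustration and `args` stands for the namespace produced by one of the scripts' argparse parsers:

from src.models import ModelFactory
# key 'conv_lenet_mc' selects the MC-dropout LeNet; the fourth argument (q/at) requests the float build when False
net = ModelFactory.get_model('conv_lenet_mc', [1, 1, 28, 28], 10, False, args)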
-------------------------------------------------------------------------------- /src/models/stochastic/bbb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/src/models/stochastic/bbb/__init__.py -------------------------------------------------------------------------------- /src/models/stochastic/bbb/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import ReLU 5 | from torch.autograd import Variable 6 | from src.models.stochastic.bbb.utils_bbb import kl_divergence, softplusinv 7 | import copy 8 | 9 | class Conv2d(nn.Conv2d): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 11 | padding=0, dilation=1, groups=1, 12 | bias=False, padding_mode='zeros', sigma_prior=-2, args=None): 13 | 14 | super(Conv2d, self).__init__(in_channels, out_channels,kernel_size, stride, padding, dilation, groups, bias, padding_mode) 15 | self.weight.data.uniform_(-0.01, 0.01) 16 | self.std = torch.nn.Parameter( 17 | torch.zeros_like(self.weight).uniform_(-10, -10), requires_grad=True) 18 | self.std_prior = torch.nn.Parameter(torch.tensor((1,))*sigma_prior, requires_grad=False) 19 | self.add_weight = torch.nn.quantized.FloatFunctional() 20 | self.mul_noise = torch.nn.quantized.FloatFunctional() 21 | self.args = args 22 | 23 | def forward(self, X): 24 | if self.training: 25 | Z_mean = F.conv2d(X, self.weight, None, self.stride, self.padding, self.dilation, self.groups) 26 | Z_std = torch.sqrt(1e-8+F.conv2d(torch.pow(X, 2), torch.pow(F.softplus(self.std), 2), 27 | None, self.stride, self.padding, self.dilation, self.groups)) 28 | Z_noise = Variable(Z_mean.new( 29 | Z_mean.size()).normal_()) 30 | 31 | Z = Z_mean + Z_std * Z_noise 32 | Z = Z + self.bias if self.bias is not None else Z 33 | else: 34 | noise = Variable(self.weight.data.new( 35 | self.weight.data.size()).normal_()) 36 | 37 | std = self.mul_noise.mul(noise, F.softplus(self.std)) 38 | weight = self.add_weight.add(self.weight, std) 39 | Z = F.conv2d(X, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 40 | return Z 41 | 42 | 43 | def get_kl_divergence(self): 44 | kl = kl_divergence(self.weight, F.softplus(self.std), 45 | torch.zeros_like(self.weight).to(self.weight.device), 46 | (torch.ones_like(self.std)*self.std_prior).to(self.weight.device)) 47 | return kl 48 | 49 | class ConvBn2d(torch.nn.Sequential): 50 | def __init__(self, conv, bn): 51 | assert type(conv) == Conv2d and type(bn) == torch.nn.BatchNorm2d, \ 52 | 'Incorrect types for input modules{}{}'.format( 53 | type(conv), type(bn)) 54 | super(ConvBn2d, self).__init__(conv, bn) 55 | 56 | class ConvReLU2d(torch.nn.Sequential): 57 | def __init__(self, conv, relu): 58 | assert type(conv) == Conv2d and type(relu) == ReLU, \ 59 | 'Incorrect types for input modules{}{}'.format( 60 | type(conv), type(relu)) 61 | super(ConvReLU2d, self).__init__(conv, relu) 62 | 63 | class ConvBnReLU2d(torch.nn.Sequential): 64 | def __init__(self, conv, bn, relu): 65 | assert type(conv) == Conv2d and type(bn) == torch.nn.BatchNorm2d and \ 66 | type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ 67 | .format(type(conv), type(bn), type(relu)) 68 | super(ConvBnReLU2d, self).__init__(conv, bn, relu) 69 | 70 | def fuse_conv_bn_weights(conv_w, conv_b, conv_std, bn_rm, bn_rv, bn_eps, 
bn_w, bn_b): 71 | if conv_b is None: 72 | conv_b = bn_rm.new_zeros(bn_rm.shape) 73 | bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) 74 | 75 | c = (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 76 | conv_w = conv_w * c 77 | conv_std = softplusinv(F.softplus(conv_std) * c) 78 | conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b 79 | 80 | return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b), torch.nn.Parameter(conv_std) 81 | 82 | def fuse_conv_bn_eval(conv, bn): 83 | assert(not (conv.training or bn.training)), "Fusion only for eval!" 84 | fused_conv = copy.deepcopy(conv) 85 | 86 | fused_conv.weight, fused_conv.bias, fused_conv.std= \ 87 | fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, fused_conv.std, 88 | bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) 89 | 90 | return fused_conv 91 | 92 | def fuse_conv_bn(conv, bn): 93 | assert(conv.training == bn.training),\ 94 | "Conv and BN both must be in the same mode (train or eval)." 95 | 96 | if conv.training: 97 | assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' 98 | assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' 99 | assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' 100 | return ConvBn2d(conv, bn) 101 | else: 102 | return fuse_conv_bn_eval(conv, bn) 103 | 104 | def fuse_conv_bn_relu(conv, bn, relu): 105 | assert(conv.training == bn.training == relu.training),\ 106 | "Conv and BN both must be in the same mode (train or eval)." 107 | 108 | if conv.training: 109 | map_to_fused_module_train = { 110 | Conv2d: ConvBnReLU2d 111 | } 112 | assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm' 113 | assert bn.affine, 'Only support fusing BatchNorm with affine set to True' 114 | assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True' 115 | fused_module = map_to_fused_module_train.get(type(conv)) 116 | if fused_module is not None: 117 | return fused_module(conv, bn, relu) 118 | else: 119 | raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu))) 120 | 121 | else: 122 | map_to_fused_module_eval = { 123 | Conv2d: ConvReLU2d, 124 | } 125 | fused_module = map_to_fused_module_eval[type(conv)] 126 | if fused_module is not None: 127 | return fused_module(fuse_conv_bn_eval(conv, bn), relu) 128 | else: 129 | raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) 130 | 131 | -------------------------------------------------------------------------------- /src/models/stochastic/bbb/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import ReLU 5 | from torch.autograd import Variable 6 | from src.models.stochastic.bbb.utils_bbb import kl_divergence 7 | 8 | class Linear(nn.Linear): 9 | def __init__(self, in_features, out_features, bias, sigma_prior=1.0, args = None): 10 | super(Linear, self).__init__(in_features, out_features, bias) 11 | self.std_prior = torch.nn.Parameter( 12 | torch.ones((1,))*sigma_prior, requires_grad=False) 13 | 14 | self.weight.data.uniform_(-0.01, 0.01) 15 | self.std = nn.Parameter(torch.zeros_like(self.weight).uniform_(-3, -3)) 16 | 17 | self.add_weight = torch.nn.quantized.FloatFunctional() 18 | self.mul_noise = torch.nn.quantized.FloatFunctional() 19 | self.args = 
args 20 | 21 | if self.bias is not None: 22 | self.bias.data.uniform_(-0.01, 0.01) 23 | 24 | def get_kl_divergence(self): 25 | return kl_divergence(self.weight, F.softplus(self.std), 26 | torch.zeros_like(self.weight).to( 27 | self.weight.device), 28 | (torch.ones_like(self.std)*self.std_prior).to(self.weight.device)) 29 | 30 | def forward(self, x): 31 | output = None 32 | if self.training: 33 | mean = torch.mm(x, self.weight.t()) 34 | std = torch.sqrt(1e-8+torch.mm(torch.pow(x, 2), 35 | torch.pow(F.softplus(self.std).t(), 2))) 36 | noise = Variable(mean.new( 37 | mean.size()).normal_()) 38 | 39 | bias = self.bias if self.bias is not None else 0.0 40 | output = mean + std * noise + bias 41 | 42 | else: 43 | std = F.softplus(self.std) 44 | noise = Variable(self.weight.data.new( 45 | self.weight.size()).normal_()) 46 | std = self.mul_noise.mul(noise, std) 47 | weight_sample = self.add_weight.add(self.weight, std) 48 | bias = self.bias if self.bias is not None else 0.0 49 | 50 | output = torch.mm(x, weight_sample.t()) + bias 51 | 52 | return output 53 | 54 | class LinearReLU(torch.nn.Sequential): 55 | def __init__(self, linear, relu): 56 | assert type(linear) == Linear and type(relu) == ReLU, \ 57 | 'Incorrect types for input modules{}{}'.format( 58 | type(linear), type(relu)) 59 | super(LinearReLU, self).__init__(linear, relu) 60 | -------------------------------------------------------------------------------- /src/models/stochastic/bbb/quantized/__init__.py: -------------------------------------------------------------------------------- 1 | NOISE_SCALE = float(0.02362204724) 2 | NOISE_ZERO_POINT = int(0) -------------------------------------------------------------------------------- /src/models/stochastic/bbb/quantized/linear_qat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | 5 | from src.models.stochastic.bbb.linear import Linear as LinearBBB 6 | from src.models.stochastic.bbb.linear import LinearReLU as LinearReLUBBB 7 | 8 | class Linear(LinearBBB): 9 | _FLOAT_MODULE = LinearBBB 10 | def __init__(self, in_features, out_features, bias=False, qconfig=None, args=None): 11 | super(Linear, self).__init__(in_features, out_features, bias, args=args) 12 | assert qconfig, 'qconfig must be provided for QAT module' 13 | self.qconfig = qconfig 14 | self.weight_fake_quant = qconfig.weight() 15 | self.activation_post_process = qconfig.activation() 16 | self.std_fake_quant = qconfig.weight() 17 | 18 | def _forward(self, X): 19 | output = None 20 | weight = self.weight_fake_quant(self.weight) 21 | std = self.std_fake_quant(F.softplus(self.std)) 22 | if self.training: 23 | mean = torch.mm(X, weight.t()) 24 | std = torch.sqrt(1e-8+torch.mm(torch.pow(X,2), torch.pow(std.t(), 2))) 25 | noise = Variable(mean.new( 26 | mean.size()).normal_()) 27 | 28 | bias = self.bias if self.bias is not None else 0.0 29 | output = mean + std * noise + bias 30 | else: 31 | noise = Variable(weight.new( 32 | weight.size()).normal_()) 33 | 34 | std = self.mul_noise.mul(noise, std) 35 | weight_sample = self.add_weight.add(weight, std) 36 | bias = self.bias if self.bias is not None else 0.0 37 | output = torch.mm(X, weight_sample.t()) + bias 38 | return output 39 | 40 | def forward(self, X): 41 | return self.activation_post_process(self._forward(X)) 42 | 43 | def _get_name(self): 44 | return 'QATLinear' 45 | 46 | @classmethod 47 | def from_float(cls, mod, qconfig=None): 48 | assert type(mod) == 
cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \ 49 | cls._FLOAT_MODULE.__name__ 50 | if not qconfig: 51 | assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' 52 | assert mod.qconfig, 'Input float module must have a valid qconfig' 53 | if type(mod) == LinearReLUBBB: 54 | mod = mod[0] 55 | 56 | activation_post_process = mod.activation_post_process 57 | 58 | qconfig = mod.qconfig 59 | qat_linear = cls(mod.in_features, mod.out_features, mod.bias is not None, qconfig) 60 | qat_linear.activation_post_process = activation_post_process 61 | qat_linear.std_prior = mod.std_prior 62 | qat_linear.weight = mod.weight 63 | qat_linear.std = mod.std 64 | qat_linear.add_weight = mod.add_weight 65 | qat_linear.add_weight.activation_post_process = qconfig.weight() 66 | qat_linear.mul_noise = mod.mul_noise 67 | qat_linear.mul_noise.activation_post_process = qconfig.weight() 68 | qat_linear.bias = mod.bias 69 | qat_linear.args = mod.args 70 | return qat_linear 71 | 72 | class LinearReLU(Linear): 73 | _FLOAT_MODULE = LinearReLUBBB 74 | def __init__(self, in_features, out_features, bias=False, 75 | qconfig=None, args=None): 76 | super(LinearReLU, self).__init__(in_features, out_features, bias = bias, qconfig=qconfig, args=args) 77 | 78 | def forward(self, input): 79 | return self.activation_post_process(F.relu(self._forward(input))) 80 | 81 | @classmethod 82 | def from_float(cls, mod, qconfig=None): 83 | return super(LinearReLU, cls).from_float(mod, qconfig) 84 | 85 | def _get_name(self): 86 | return 'QATLinearReLU' 87 | -------------------------------------------------------------------------------- /src/models/stochastic/bbb/utils_bbb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def kl_divergence(mu, sigma, mu_prior, sigma_prior): 4 | kl = 0.5 * (2 * torch.log(sigma_prior / sigma) - 1 + (sigma / sigma_prior).pow(2) + ((mu_prior - mu) / sigma_prior).pow(2)).sum() 5 | return kl 6 | 7 | def softplusinv(x): 8 | return torch.log(torch.exp(x)-1.) 
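A quick sanity check for utils_bbb.py (a minimal sketch; the shapes and tolerances are illustrative): kl_divergence is the closed-form KL between diagonal Gaussians and should agree with torch.distributions, and softplusinv inverts F.softplus, which is what fuse_conv_bn_weights in conv.py relies on when it folds the BatchNorm scale back into the std parameter.

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence as kl_reference
from src.models.stochastic.bbb.utils_bbb import kl_divergence, softplusinv

mu, rho = torch.randn(16), torch.randn(16)
sigma = F.softplus(rho)  # positive std, matching the Linear/Conv2d parameterisation
expected = kl_reference(Normal(mu, sigma),
                        Normal(torch.zeros(16), torch.ones(16))).sum()
assert torch.allclose(kl_divergence(mu, sigma, torch.zeros(16), torch.ones(16)),
                      expected, atol=1e-5)

x = torch.rand(8) + 0.1  # keep inputs away from softplusinv's singularity at 0
assert torch.allclose(softplusinv(F.softplus(x)), x, atol=1e-5)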
-------------------------------------------------------------------------------- /src/models/stochastic/mcdropout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/src/models/stochastic/mcdropout/__init__.py -------------------------------------------------------------------------------- /src/models/stochastic/mcdropout/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.quantized import QFunctional 4 | 5 | 6 | class BernoulliDropout(nn.Module): 7 | def __init__(self, p=0.0): 8 | super(BernoulliDropout, self).__init__() 9 | self.p = torch.nn.Parameter(torch.ones((1,))*p, requires_grad=False) 10 | self.multiplier = torch.nn.Parameter(torch.ones((1,))/(1.0 - self.p), requires_grad=False) 11 | 12 | self.mul_mask = torch.nn.quantized.FloatFunctional() 13 | self.mul_scalar = torch.nn.quantized.FloatFunctional() 14 | 15 | def forward(self, x): 16 | if self.p<=0.0: 17 | return x 18 | mask_ = None 19 | if len(x.shape)<=2: 20 | if x.is_cuda: 21 | mask_ = torch.cuda.FloatTensor(x.shape).bernoulli_(1.-self.p) 22 | else: 23 | mask_ = torch.FloatTensor(x.shape).bernoulli_(1.-self.p) 24 | else: 25 | if x.is_cuda: 26 | mask_ = torch.cuda.FloatTensor(x.shape[:2]).bernoulli_( 27 | 1.-self.p) 28 | else: 29 | mask_ = torch.FloatTensor(x.shape[:2]).bernoulli_( 30 | 1.-self.p) 31 | if isinstance(self.mul_mask, QFunctional): 32 | scale = self.mul_mask.scale 33 | zero_point = self.mul_mask.zero_point 34 | mask_ = torch.quantize_per_tensor(mask_, scale, zero_point, dtype=torch.quint8) 35 | if len(x.shape) > 2: 36 | mask_ = mask_.view( 37 | mask_.shape[0], mask_.shape[1], 1, 1).expand(-1, -1, x.shape[2], x.shape[3]) 38 | x = self.mul_mask.mul(x, mask_) 39 | x = self.mul_scalar.mul_scalar(x, self.multiplier) 40 | return x 41 | 42 | def extra_repr(self): 43 | return 'p={}, quant={}'.format( 44 | self.p.item(), isinstance( 45 | self.mul_mask, QFunctional) 46 | ) 47 | -------------------------------------------------------------------------------- /src/models/stochastic/sgld/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/src/models/stochastic/sgld/__init__.py -------------------------------------------------------------------------------- /src/models/stochastic/sgld/utils_sgld.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | from numpy.random import gamma 4 | 5 | class SGLD(Optimizer): 6 | def __init__(self, params, lr=1e-2, base_C=0.05, gauss_sig=0.1, alpha0=10, beta0=10): 7 | self.eps = 1e-6 8 | self.alpha0 = alpha0 9 | self.beta0 = beta0 10 | 11 | if gauss_sig == 0: 12 | self.weight_decay = 0 13 | else: 14 | self.weight_decay = 1 / (gauss_sig ** 2) 15 | 16 | if self.weight_decay <= 0.0: 17 | raise ValueError( 18 | "Invalid weight_decay value: {}".format(self.weight_decay)) 19 | if lr < 0.0: 20 | raise ValueError("Invalid learning rate: {}".format(lr)) 21 | if base_C < 0: 22 | raise ValueError("Invalid friction term: {}".format(base_C)) 23 | 24 | defaults = dict( 25 | lr=lr, 26 | base_C=base_C, 27 | ) 28 | super(SGLD, self).__init__(params, defaults) 29 | 30 | def step(self, burn_in=False, 
resample_momentum=False, resample_prior=False): 31 | loss = None 32 | 33 | # iterate over blocks -> the ones defined in defaults. We dont use groups. 34 | for group in self.param_groups: 35 | for p in group["params"]: # these are weight and bias matrices 36 | if p.grad is None: 37 | continue 38 | state = self.state[p] # define dict for each individual param 39 | if len(state) == 0: 40 | state["iteration"] = 0 41 | state["tau"] = torch.ones_like(p) 42 | state["g"] = torch.ones_like(p) 43 | state["V_hat"] = torch.ones_like(p) 44 | state["v_momentum"] = torch.zeros_like(p) 45 | state['weight_decay'] = self.weight_decay 46 | 47 | if resample_prior: 48 | alpha = self.alpha0 + p.data.nelement() / 2 49 | beta = self.beta0 + (p.data ** 2).sum().item() / 2 50 | gamma_sample = gamma( 51 | shape=alpha, scale=1 / (beta+self.eps), size=None) 52 | state['weight_decay'] = gamma_sample 53 | 54 | base_C, lr = group["base_C"], group["lr"] 55 | weight_decay = state["weight_decay"] 56 | tau, g, V_hat = state["tau"], state["g"], state["V_hat"] 57 | 58 | d_p = p.grad.data 59 | if weight_decay != 0: 60 | d_p.add_(p.data, alpha=weight_decay) 61 | 62 | if burn_in: 63 | tau.add_(-tau * (g ** 2) / ( 64 | V_hat + self.eps) + 1) 65 | tau_inv = 1. / (tau + self.eps) 66 | g.add_(-tau_inv * g + tau_inv * d_p) 67 | V_hat.add_(-tau_inv * V_hat + tau_inv * (d_p ** 2)) 68 | 69 | V_sqrt = torch.sqrt(V_hat) 70 | V_inv_sqrt = 1. / (V_sqrt + self.eps) 71 | 72 | if resample_momentum: 73 | state["v_momentum"] = torch.normal(mean=torch.zeros_like(d_p), 74 | std=torch.sqrt((lr ** 2) * V_inv_sqrt)) 75 | v_momentum = state["v_momentum"] 76 | 77 | noise_var = (2. * (lr ** 2) * V_inv_sqrt * base_C - (lr ** 4)) 78 | noise_std = torch.sqrt(torch.clamp(noise_var, min=1e-16)) 79 | 80 | noise_sample = torch.normal(mean=torch.zeros_like( 81 | d_p), std=torch.ones_like(d_p) * noise_std) 82 | 83 | v_momentum.add_(- (lr ** 2) * V_inv_sqrt * 84 | d_p - base_C * v_momentum + noise_sample) 85 | 86 | v_momentum[v_momentum != v_momentum] = 0 87 | v_momentum[v_momentum == float('inf')] = 0 88 | v_momentum[v_momentum == -float('inf') ] = 0 89 | 90 | p.data.add_(v_momentum) 91 | 92 | return loss 93 | -------------------------------------------------------------------------------- /src/quant_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.quantization.observer import MovingAverageMinMaxObserver 3 | from torch.quantization.fake_quantize import FakeQuantize 4 | from torch.quantization.quantization_mappings import * 5 | from torch.quantization.quantize import swap_module 6 | import src.utils as utils 7 | import copy 8 | import torch.nn.intrinsic as nni 9 | 10 | from src.models.stochastic.bbb.conv import Conv2d as Conv2dBBB 11 | from src.models.stochastic.bbb.conv import ConvReLU2d as ConvReLU2dBBB 12 | from src.models.stochastic.bbb.conv import ConvBn2d as ConvBn2dBBB 13 | from src.models.stochastic.bbb.conv import ConvBnReLU2d as ConvBnReLU2dBBB 14 | 15 | from src.models.stochastic.bbb.quantized.conv_q import Conv2d as Conv2dBBB_Q 16 | from src.models.stochastic.bbb.quantized.conv_q import ConvReLU2d as ConvReLU2dBBB_Q 17 | 18 | from src.models.stochastic.bbb.quantized.conv_qat import Conv2d as Conv2dBBB_QAT 19 | from src.models.stochastic.bbb.quantized.conv_qat import ConvReLU2d as ConvReLU2dBBB_QAT 20 | from src.models.stochastic.bbb.quantized.conv_qat import ConvBn2d as ConvBn2dBBB_QAT 21 | from src.models.stochastic.bbb.quantized.conv_qat import ConvBnReLU2d as 
ConvBnReLU2dBBB_QAT 22 | 23 | from src.models.stochastic.bbb.linear import Linear as LinearBBB 24 | from src.models.stochastic.bbb.linear import LinearReLU as LinearReLUBBB 25 | from src.models.stochastic.bbb.quantized.linear_q import Linear as LinearBBB_Q 26 | from src.models.stochastic.bbb.quantized.linear_q import LinearReLU as LinearReLUBBB_Q 27 | from src.models.stochastic.bbb.quantized.linear_qat import Linear as LinearBBB_QAT 28 | from src.models.stochastic.bbb.quantized.linear_qat import LinearReLU as LinearReLUBBB_QAT 29 | 30 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST = get_qconfig_propagation_list() 31 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST.add(LinearBBB) 32 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST.add(LinearReLUBBB) 33 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST.add(Conv2dBBB) 34 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST.add(ConvReLU2dBBB) 35 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST.add(ConvBn2dBBB) 36 | DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST.add(ConvBnReLU2dBBB) 37 | 38 | 39 | QAT_MODULE_MAPPINGS[LinearBBB] = LinearBBB_QAT 40 | QAT_MODULE_MAPPINGS[LinearReLUBBB] = LinearReLUBBB_QAT 41 | QAT_MODULE_MAPPINGS[Conv2dBBB] = Conv2dBBB_QAT 42 | QAT_MODULE_MAPPINGS[ConvReLU2dBBB] = ConvReLU2dBBB_QAT 43 | QAT_MODULE_MAPPINGS[ConvBn2dBBB] = ConvBn2dBBB_QAT 44 | QAT_MODULE_MAPPINGS[ConvBnReLU2dBBB] = ConvBnReLU2dBBB_QAT 45 | 46 | STATIC_QUANT_MODULE_MAPPINGS[LinearBBB] = LinearBBB_Q 47 | STATIC_QUANT_MODULE_MAPPINGS[LinearBBB_QAT] = LinearBBB_Q 48 | STATIC_QUANT_MODULE_MAPPINGS[LinearReLUBBB] = LinearReLUBBB_Q 49 | STATIC_QUANT_MODULE_MAPPINGS[LinearReLUBBB_QAT] = LinearReLUBBB_Q 50 | STATIC_QUANT_MODULE_MAPPINGS[Conv2dBBB] = Conv2dBBB_Q 51 | STATIC_QUANT_MODULE_MAPPINGS[Conv2dBBB_QAT] = Conv2dBBB_Q 52 | STATIC_QUANT_MODULE_MAPPINGS[ConvReLU2dBBB] = ConvReLU2dBBB_Q 53 | STATIC_QUANT_MODULE_MAPPINGS[ConvReLU2dBBB_QAT] = ConvReLU2dBBB_Q 54 | STATIC_QUANT_MODULE_MAPPINGS[ConvBn2dBBB_QAT] = Conv2dBBB_Q 55 | STATIC_QUANT_MODULE_MAPPINGS[ConvBnReLU2dBBB_QAT] = ConvReLU2dBBB_Q 56 | 57 | STATIC_QUANT_MODULE_MAPPINGS[nni.ConvBn2d] = torch.nn.quantized.Conv2d 58 | STATIC_QUANT_MODULE_MAPPINGS[torch.nn.intrinsic.qat.modules.conv_fused.ConvBn2d] = torch.nn.quantized.Conv2d 59 | STATIC_QUANT_MODULE_MAPPINGS[nni.ConvBnReLU2d] = torch.nn.intrinsic.quantized.ConvReLU2d 60 | STATIC_QUANT_MODULE_MAPPINGS[torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d] = torch.nn.intrinsic.quantized.ConvReLU2d 61 | 62 | def convert(model, mapping=None, inplace=True): 63 | def _convert(module, mapping=None, inplace=True): 64 | if mapping is None: 65 | mapping = STATIC_QUANT_MODULE_MAPPINGS 66 | if not inplace: 67 | module = copy.deepcopy(module) 68 | reassign = {} 69 | SWAPPABLE_MODULES = (nni.ConvBn2d, 70 | nni.ConvBnReLU2d, 71 | torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d, 72 | torch.nn.intrinsic.qat.modules.conv_fused.ConvBn2d, 73 | nni.LinearReLU, 74 | nni.BNReLU2d, 75 | nni.BNReLU3d, 76 | nni.ConvBn1d, 77 | nni.ConvReLU1d, 78 | nni.ConvBnReLU1d, 79 | nni.ConvReLU2d, 80 | nni.ConvReLU3d, 81 | LinearReLUBBB, 82 | ConvReLU2dBBB, 83 | ConvBn2dBBB, 84 | ConvBnReLU2dBBB) 85 | 86 | for name, mod in module.named_children(): 87 | if type(mod) not in SWAPPABLE_MODULES: 88 | _convert(mod, mapping, inplace=True) 89 | swap = swap_module(mod, mapping) 90 | reassign[name] = swap 91 | 92 | for key, value in reassign.items(): 93 | module._modules[key] = value 94 | 95 | return module 96 | if mapping is None: 97 | mapping = STATIC_QUANT_MODULE_MAPPINGS 98 | model = _convert(model, mapping=mapping, inplace=inplace) 99 | return model 100 | 
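# Note: _convert above mirrors torch.quantization.convert. Children whose type
# is in SWAPPABLE_MODULES (the fused Conv/BN/ReLU and BBB sequences) are handed
# to swap_module whole; everything else is recursed into first and then offered
# to swap_module as well. A hedged usage sketch (calibrate() is a hypothetical
# calibration pass; prepare_model and convert are defined in this file):
#
#   prepare_model(model, args)   # fuse, attach qconfig, insert fake-quant/observers
#   calibrate(model, loader)     # run representative data through the observers
#   model = convert(model)       # swap observed modules for the quantised ones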
101 | def postprocess_model(model, args, q=None, at=None, special_info=""): 102 | if q is None: 103 | q = args.q 104 | if at is None: 105 | at = args.at 106 | if q and at and 'sgld' not in args.model: 107 | model = model.cpu() 108 | utils.load_model(model, args.save+"/weights{}.pt".format(special_info)) 109 | convert(model) 110 | utils.save_model(model, args, special_info) 111 | 112 | def prepare_model(model, args, q=None, at=None): 113 | if q is None: 114 | q = args.q 115 | if at is None: 116 | at = args.at 117 | 118 | torch.backends.quantized.engine = 'fbgemm' 119 | 120 | assert 2 <= args.activation_precision and args.activation_precision <= 7 121 | assert 2 <= args.weight_precision and args.weight_precision <= 8 122 | 123 | activation_precision = utils.UINT_BOUNDS[args.activation_precision] 124 | weight_precision = utils.INT_BOUNDS[args.weight_precision] 125 | 126 | if hasattr(model, 'fuse_model'): 127 | model.fuse_model() 128 | 129 | model.qconfig = torch.quantization.QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, 130 | dtype=torch.quint8, 131 | quant_min=activation_precision[0], 132 | quant_max=activation_precision[1], 133 | qscheme=torch.per_tensor_affine), 134 | weight=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, 135 | quant_min=weight_precision[0], 136 | quant_max=weight_precision[1], 137 | dtype=torch.qint8, 138 | qscheme=torch.per_tensor_affine)) 139 | if not 'bbb' in args.model: 140 | torch.quantization.prepare_qat(model, inplace=True) 141 | else: 142 | torch.quantization.prepare( 143 | model, allow_list=DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST, inplace=True) 144 | torch.quantization.prepare( 145 | model, allow_list=DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST, inplace=True, observer_non_leaf_module_list=[LinearBBB, Conv2dBBB]) 146 | 147 | convert(model, mapping=QAT_MODULE_MAPPINGS) 148 | 149 | 150 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import src.utils as utils 4 | import time 5 | import logging 6 | from src.models.stochastic.sgld.utils_sgld import SGLD 7 | import numpy as np 8 | from src.metrics import ClassificationMetric, RegressionMetric 9 | 10 | class Trainer(): 11 | def __init__(self, model, criterion, optimizer, scheduler, args, writer=None): 12 | super().__init__() 13 | self.model = model 14 | self.criterion = criterion 15 | self.optimizer = optimizer 16 | self.scheduler = scheduler 17 | self.args = args 18 | 19 | self.train_step = 0 20 | self.train_time = 0.0 21 | self.val_step = 0 22 | self.val_time = 0.0 23 | 24 | self.grad_buff = [] 25 | self.max_grad = 1e20 26 | self.grad_std_mul = 30 27 | 28 | self.epoch = 0 29 | self.iteration = 0 30 | 31 | self.train_metrics = ClassificationMetric(output_size=self.args.output_size, writer=writer) if "classification" in self.args.task else RegressionMetric(output_size=self.args.output_size, writer=writer) 32 | self.valid_metrics = ClassificationMetric(output_size=self.args.output_size, writer=writer) if "classification" in self.args.task else RegressionMetric(output_size=self.args.output_size, writer=writer) 33 | self.writer = writer 34 | 35 | def train_loop(self, train_loader, valid_loader, special_info=""): 36 | best_error = float('inf') 37 | 38 | for epoch in range(self.args.epochs): 39 | if epoch >= 1 and self.scheduler is not None: 40 | self.scheduler.step() 41 | 42 | if self.scheduler is not None: 43 | lr = 
self.scheduler.get_last_lr()[0] 44 | else: 45 | lr = self.args.learning_rate 46 | 47 | if self.writer is not None: 48 | self.writer.add_scalar('Train/learning_rate', lr, epoch) 49 | 50 | logging.info( 51 | '### Epoch: [%d/%d], Learning rate: %e ###', self.args.epochs, 52 | epoch, lr) 53 | if hasattr(self.args, 'gamma'): 54 | logging.info( 55 | '### Epoch: [%d/%d], Gamma: %e ###', self.args.epochs, 56 | epoch, self.args.gamma) 57 | 58 | self.train_metrics.reset() 59 | self.train(train_loader, self.optimizer) 60 | 61 | logging.info("#### Train | %s ####", self.train_metrics.get_str()) 62 | 63 | self.train_metrics.scalar_logging("train", epoch) 64 | 65 | # validation 66 | if valid_loader is not None: 67 | self.valid_metrics.reset() 68 | self.infer(valid_loader, "Valid") 69 | logging.info("#### Valid | %s ####", self.valid_metrics.get_str()) 70 | val_error_metric = self.valid_metrics.get_key_metric() 71 | 72 | if self.args.save_last or val_error_metric <= best_error: 73 | # Avoid correlation between the samples 74 | _special_info = None 75 | if hasattr(self.args, 'burnin_epochs') and epoch>=self.args.burnin_epochs and epoch%2==0 and epoch>=self.args.epochs-self.args.samples*2: 76 | _special_info= special_info+"_"+str(epoch) 77 | if _special_info is None: 78 | _special_info = special_info 79 | utils.save_model(self.model, self.args, _special_info) 80 | best_error = val_error_metric 81 | logging.info( 82 | '### Epoch: [%d/%d], Saving model! Current best error: %f ###', self.args.epochs, 83 | epoch, best_error) 84 | self.epoch+=1 85 | return best_error, self.train_time, self.val_time 86 | 87 | def _step(self, input, target, optimizer, n_batches, n_points, train_timer): 88 | start = time.time() 89 | if next(self.model.parameters()).is_cuda: 90 | input = input.cuda() 91 | target = target.cuda() 92 | 93 | if optimizer is not None: 94 | optimizer.zero_grad() 95 | output = self.model(input) 96 | if hasattr(self.model, 'get_kl_divergence'): 97 | kl = self.model.get_kl_divergence() 98 | else: 99 | kl = torch.tensor([0.0]).view(1).to(input.device) 100 | obj, main_obj, kl = self.criterion( 101 | output, target, kl, self.args.gamma if hasattr(self.args, 'gamma') else 0., n_batches, n_points) 102 | 103 | if optimizer is not None and obj == obj: 104 | obj.backward() 105 | for p in self.model.parameters(): 106 | if p.grad is not None: 107 | p.grad[p.grad != p.grad] = 0 108 | if isinstance(optimizer, SGLD): 109 | if len(self.grad_buff) > 1000: 110 | self.max_grad = np.mean(self.grad_buff) + \ 111 | self.grad_std_mul * np.std(self.grad_buff) 112 | self.grad_buff.pop(0) 113 | # Clipping to prevent explosions 114 | self.grad_buff.append(torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), 115 | max_norm=self.max_grad, norm_type=2).item()) 116 | if self.grad_buff[-1] >= self.max_grad: 117 | self.grad_buff.pop() 118 | 119 | optimizer.step(burn_in=(self.epoch < self.args.burnin_epochs), 120 | resample_momentum=(self.iteration % self.args.resample_momentum_iterations == 0), 121 | resample_prior=(self.iteration % self.args.resample_prior_iterations == 0)) 122 | else: 123 | optimizer.step() 124 | self.iteration+=1 125 | 126 | 127 | if train_timer: 128 | self.train_metrics.update(output=output, target=target, obj=obj, kl=kl, main_obj=main_obj) 129 | self.train_time += time.time() - start 130 | else: 131 | self.valid_metrics.update(output=output, target=target, obj=obj, kl=kl, main_obj=main_obj) 132 | self.val_time += time.time() - start 133 | 134 | 135 | 136 | def train(self, loader, optimizer): 137 | 
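        # One full pass over the loader; _step moves the batch to the model's
        # device, folds the gamma-weighted KL into the objective, zeroes NaN
        # gradients and, for SGLD, applies the clipped, noise-injected update
        # with periodic momentum/prior resampling.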
self.model.train() 138 | 139 | for step, (input, target) in enumerate(loader): 140 | self._step(input, target, optimizer, len(loader), len(loader.dataset), True) 141 | 142 | 143 | if step % self.args.report_freq == 0: 144 | logging.info( 145 | "##### Train step: [%03d/%03d] | %s #####", 146 | len(loader), 147 | step, 148 | self.train_metrics.get_str(), 149 | ) 150 | self.train_step += 1 151 | if self.args.debug: 152 | break 153 | 154 | def infer(self, loader, dataset="Valid"): 155 | with torch.no_grad(): 156 | self.model.eval() 157 | 158 | for step, (input, target) in enumerate(loader): 159 | n = input.shape[0] 160 | self._step( 161 | input, target, None, len(loader), n * len(loader), False) 162 | 163 | if step % self.args.report_freq == 0: 164 | logging.info( 165 | "##### %s step: [%03d/%03d] | %s #####", 166 | dataset, 167 | len(loader), 168 | step, 169 | self.valid_metrics.get_str(), 170 | ) 171 | self.val_step += 1 172 | 173 | if self.args.debug: 174 | break -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import shutil 5 | import random 6 | import pickle 7 | import sys 8 | import time 9 | import glob 10 | import logging 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torch.backends.cudnn as cudnn 13 | import re 14 | import shutil 15 | import copy 16 | 17 | UINT_BOUNDS = {8: [0, 255], 7: [0, 127], 6: [0, 63], 5: [0, 31], 4: [0, 15], 3: [0, 7], 2: [0, 3]} 18 | INT_BOUNDS = {8: [-128, 127], 7: [-64, 63], 6: [-32, 31], 19 | 5: [-16, 15], 4: [-8, 7], 3: [-4, 3], 2: [-2, 1]} 20 | BRIGHTNESS_LEVELS = [(1.5, 1.5), (2., 2.), (2.5, 2.5), (3, 3), (3.5, 3.5)] 21 | ROTATION_LEVELS = [(15,15),(30,30),(45,45),(60,60),(75,75)] 22 | SHIFT_LEVELS = [0.1,0.2,0.3,0.4,0.5] 23 | 24 | def clamp_activation(x, args): 25 | if x.dtype == torch.quint8: 26 | _min = (UINT_BOUNDS[args.activation_precision][0]-x.q_zero_point())*x.q_scale() 27 | _max = (UINT_BOUNDS[args.activation_precision][1]-x.q_zero_point())*x.q_scale() 28 | x = torch.clamp(x, _min, _max) 29 | return x 30 | 31 | def clamp_weight(x, args): 32 | if x.dtype == torch.qint8: 33 | _min = (INT_BOUNDS[args.weight_precision][0]-x.q_zero_point())*x.q_scale() 34 | _max = (INT_BOUNDS[args.weight_precision][1]-x.q_zero_point())*x.q_scale() 35 | x = torch.clamp(x, _min, _max) 36 | return x 37 | 38 | 39 | class Flatten(torch.nn.Module): 40 | def __init__(self): 41 | super(Flatten, self).__init__() 42 | 43 | def forward(self, x): 44 | if len(x.shape)==1: 45 | return x.unsqueeze(dim=0) 46 | return x.reshape(x.size(0), -1) 47 | 48 | class Add(torch.nn.Module): 49 | def __init__(self): 50 | super(Add, self).__init__() 51 | self.add = torch.nn.quantized.FloatFunctional() 52 | 53 | def forward(self, x, y): 54 | return self.add.add(x,y) 55 | 56 | def atoi(text): 57 | return int(text) if text.isdigit() else text 58 | 59 | def natural_keys(text): 60 | return [atoi(c) for c in re.split(r'(-?\d+)', text)] 61 | 62 | def size_of_model(model): 63 | torch.save(model.state_dict(), "temp.p") 64 | size = os.path.getsize("temp.p")/1e6 65 | os.remove('temp.p') 66 | return size 67 | 68 | 69 | def save_model(model, args, special_info=""): 70 | _model = model 71 | if args.q and args.at and 'sgld' in args.model: 72 | from src.quant_utils import convert 73 | _model_copy = copy.deepcopy(model) 74 | _model = convert(_model_copy.cpu(), inplace=False) 75 | torch.save(_model.state_dict(), 
os.path.join(args.save, 'weights'+special_info+'.pt')) 76 | 77 | with open(os.path.join(args.save, 'args.pt'), 'wb') as handle: 78 | pickle.dump(args, handle, protocol=pickle.HIGHEST_PROTOCOL) 79 | 80 | def save_pickle(data, path, overwrite=False): 81 | path = check_path(path) if not overwrite else path 82 | with open(path, 'wb') as fp: 83 | pickle.dump(data, fp, protocol=pickle.HIGHEST_PROTOCOL) 84 | 85 | def isoutlier(val): 86 | return val == np.inf or val == -np.inf or val<-9e1 or val>9e1 or np.isnan(val) 87 | 88 | def load_pickle(path): 89 | file = open(path, 'rb') 90 | return pickle.load(file) 91 | 92 | def transfer_weights(quantized_model, model): 93 | state_dict = model.state_dict() 94 | model = model.to('cpu') 95 | quantized_model.load_state_dict(state_dict) 96 | 97 | def load_model(model, model_path, replace=True): 98 | state_dict = torch.load(model_path, map_location=torch.device('cpu')) 99 | model_dict = model.state_dict() 100 | pretrained_dict = {} 101 | for k,v in state_dict.items(): 102 | _k = k 103 | if replace: 104 | _k = k.replace('module.','').replace('main_net.','') 105 | pretrained_dict[_k] = v 106 | pretrained_dict = {k: v for k,v in pretrained_dict.items() if k in model_dict.keys()} 107 | model_dict.update(pretrained_dict) 108 | model.load_state_dict(model_dict) 109 | 110 | def create_exp_dir(path, scripts_to_save=None): 111 | path = check_path(path) 112 | os.mkdir(path) 113 | 114 | if scripts_to_save is not None: 115 | os.mkdir(os.path.join(path, 'scripts')) 116 | for script in scripts_to_save: 117 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 118 | shutil.copyfile(script, dst_file) 119 | 120 | def check_path(path): 121 | if os.path.exists(path): 122 | filename, file_extension = os.path.splitext(path) 123 | counter = 0 124 | while os.path.exists(filename+"_"+str(counter)+file_extension): 125 | counter+=1 126 | return filename+"_"+str(counter)+file_extension 127 | return path 128 | 129 | 130 | def model_to_gpus(model, args): 131 | if args.gpu!= -1: 132 | device = torch.device("cuda:"+str(args.gpu)) 133 | model = model.to(device) 134 | return model 135 | 136 | def check_quantized(x): 137 | return x.dtype == torch.qint8 or x.dtype == torch.quint8 138 | 139 | def parse_args(args, label=""): 140 | if label=="": 141 | q="not_q" 142 | if args.q: 143 | q="q" 144 | if args.at: 145 | q+="at" 146 | label = q 147 | loading_path = args.save 148 | dataset = args.dataset if hasattr(args, 'dataset') else "" 149 | task = args.task if hasattr(args, 'task') else "" 150 | new_path = '{}-{}-{}-{}'.format(label, dataset, task, time.strftime("%Y%m%d-%H%M%S")) 151 | 152 | create_exp_dir( 153 | new_path, scripts_to_save=glob.glob('*.py') + \ 154 | glob.glob('../../src/**/*.py', recursive=True) + \ 155 | glob.glob('../../../src/**/*.py', recursive=True) + \ 156 | glob.glob('../../../../src/**/*.py', recursive=True) + \ 157 | glob.glob('../../../../../src/**/*.py', recursive=True) + \ 158 | glob.glob('../../../experiments/*.py', recursive=True) + \ 159 | glob.glob('../../../../experiments/*.py', recursive=True) + \ 160 | glob.glob('../../../../../experiments/*.py', recursive=True)) 161 | args.save = new_path 162 | if loading_path!="EXP": 163 | for root, dirs, files in os.walk(loading_path): 164 | for filename in files: 165 | if ".pt" in filename: 166 | shutil.copy(os.path.join(loading_path, filename), os.path.join(new_path, filename)) 167 | 168 | log_format = '%(asctime)s %(message)s' 169 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 170 | 
format=log_format, datefmt='%m/%d %I:%M:%S %p') 171 | log_path = os.path.join(args.save, 'log.log') 172 | log_path = check_path(log_path) 173 | 174 | fh = logging.FileHandler(log_path) 175 | fh.setFormatter(logging.Formatter(log_format)) 176 | logging.getLogger().addHandler(fh) 177 | 178 | print('Experiment dir : {}'.format(args.save)) 179 | 180 | writer = SummaryWriter( 181 | log_dir=args.save+"/",max_queue=5) 182 | if torch.cuda.is_available() and args.gpu!=-1: 183 | logging.info('## GPUs available = {} ##'.format(args.gpu)) 184 | torch.cuda.set_device(args.gpu) 185 | cudnn.benchmark = True 186 | cudnn.enabled = True 187 | torch.cuda.manual_seed(args.seed) 188 | else: 189 | logging.info('## No GPUs detected ##') 190 | 191 | random.seed(args.seed) 192 | np.random.seed(args.seed) 193 | torch.manual_seed(args.seed) 194 | logging.info("## Args = %s ##", args) 195 | 196 | path = os.path.join(args.save, 'results.pickle') 197 | path= check_path(path) 198 | results = {} 199 | results["dataset"] = args.dataset if hasattr(args, 'dataset') else "" 200 | results["model"] = args.model if hasattr(args, 'model') else "" 201 | results["error"] = {} 202 | results["nll"] = {} 203 | results["latency"] = {} 204 | results["ece"] = {} 205 | results["entropy"] = {} 206 | 207 | save_pickle(results, path, True) 208 | 209 | return args, writer 210 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/quantised-bayesian-nets/43cac356d4b086e3b3b987b860ce43507c79facf/tests/__init__.py -------------------------------------------------------------------------------- /tests/plot_datasets.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.gridspec as gridspec 3 | import sys 4 | import argparse 5 | import numpy as np 6 | import logging 7 | 8 | sys.path.append("../") 9 | sys.path.append("../../") 10 | sys.path.append("../../../") 11 | sys.path.append("../../../../") 12 | sys.path.append("../../../../../") 13 | 14 | from src.data import * 15 | import src.utils as utils 16 | from src.data import get_test_loader, CIFAR_MEAN, CIFAR_STD, MNIST_MEAN, MNIST_STD 17 | from experiments.presentation.plot_settings import PLT as plt 18 | 19 | 20 | parser = argparse.ArgumentParser("test_distortions") 21 | 22 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 23 | parser.add_argument('--data', type=str, 24 | default='./../experiments/data', help='experiment name') 25 | 26 | parser.add_argument('--label', type=str, default='dataset_sample_plots', 27 | help='default experiment category ') 28 | parser.add_argument('--dataset', type=str, default='', 29 | help='default dataset ') 30 | parser.add_argument('--batch_size', type=int, default=64, 31 | help='default batch size') 32 | parser.add_argument('--num_workers', type=int, 33 | default=1, help='default batch size') 34 | parser.add_argument('--valid_portion', type=float, 35 | default=0.1, help='portion of training data') 36 | parser.add_argument('--gpu', type=int, 37 | default=-1, help='portion of training data') 38 | parser.add_argument('--input_size', nargs='+', 39 | default=[1, 3, 32, 32], help='input size') 40 | parser.add_argument('--seed', type=int, 41 | default=1, help='input size') 42 | parser.add_argument('--q', type=bool, 43 | default=False, help='input size') 44 | 45 | INPUT_SIZES = {"mnist": (1, 1, 28, 28), "cifar": 
( 46 | 1, 3, 32, 32)} 47 | 48 | 49 | def main(): 50 | args = parser.parse_args() 51 | logging.info('## Testing datasets ##') 52 | args, _ = utils.parse_args(args, args.label) 53 | datasets = ['mnist', 'cifar'] 54 | 55 | for i, dataset in enumerate(datasets): 56 | plt.figure() 57 | gs = gridspec.GridSpec(5, 4) 58 | gs.update(wspace=0, hspace=0) 59 | args.dataset = dataset 60 | train_loader, valid_loader = get_train_loaders(args) 61 | test_loader = get_test_loader(args) 62 | args.input_size = INPUT_SIZES[dataset] 63 | args.dataset = "random_"+dataset 64 | random_loader = get_test_loader(args) 65 | for j, loader in enumerate([train_loader, valid_loader, test_loader, random_loader]): 66 | input, _ = next(iter(loader)) 67 | input = input[:5] 68 | for k, image in enumerate(input): 69 | plt.subplot(gs[k, j]) 70 | if "mnist" in args.dataset: 71 | plt.imshow(image.squeeze().numpy(), cmap='gray') 72 | elif "cifar" in args.dataset: 73 | means = np.array(CIFAR_MEAN).reshape((3, 1, 1)) 74 | stds = np.array(CIFAR_STD).reshape((3, 1, 1)) 75 | image = (image.numpy()*stds)+means 76 | plt.imshow(np.transpose(image, (1, 2, 0))) 77 | plt.axis('off') 78 | plt.tight_layout() 79 | path = utils.check_path(args.save+'/{}.png'.format(dataset)) 80 | plt.savefig(path) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /tests/plot_distortions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import numpy as np 4 | import logging 5 | 6 | sys.path.append("../") 7 | sys.path.append("../../") 8 | sys.path.append("../../../") 9 | sys.path.append("../../../../") 10 | sys.path.append("../../../../../") 11 | 12 | 13 | from experiments.utils import DISTORTIONS, LEVELS 14 | import src.utils as utils 15 | from src.data import get_test_loader 16 | from experiments.presentation.plot_settings import PLT as plt 17 | import matplotlib.gridspec as gridspec 18 | from src.data import CIFAR_MEAN, CIFAR_STD 19 | 20 | parser = argparse.ArgumentParser("test_distortions") 21 | 22 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 23 | parser.add_argument('--data', type=str, default='./../experiments/data', help='experiment name') 24 | 25 | parser.add_argument('--label', type=str, default='test_distortions', help='default experiment category ') 26 | parser.add_argument('--dataset', type=str, default='mnist', help='default dataset ') 27 | parser.add_argument('--batch_size', type=int, default=64, help='default batch size') 28 | parser.add_argument('--num_workers', type=int, default=1, help='default batch size') 29 | 30 | 31 | parser.add_argument('--seed', type=int, default=1, help='random seed') 32 | parser.add_argument('--debug', action='store_true', help='whether we are currently debugging') 33 | 34 | parser.add_argument('--gpu', type=int, default = 0, help='gpu device ids') 35 | 36 | 37 | 38 | def main(): 39 | args = parser.parse_args() 40 | args, _ = utils.parse_args(args, args.label) 41 | logging.info('## Testing distortions ##') 42 | 43 | for distortion in DISTORTIONS: 44 | plt.figure(figsize=(3, 1)) 45 | gs = gridspec.GridSpec(1, 5) 46 | gs.update(wspace=0, hspace=0) 47 | for level in range(LEVELS): 48 | test_loader = get_test_loader(args, distortion=distortion, level=level) 49 | input, _ = next(iter(test_loader)) 50 | plt.subplot(gs[level]) 51 | if args.dataset == "mnist": 52 | image = input[0] 53 | plt.imshow(image.squeeze().numpy(), cmap='gray') 
54 | elif args.dataset == "cifar": 55 | image = input[2] 56 | means = np.array(CIFAR_MEAN).reshape((3,1,1)) 57 | stds = np.array(CIFAR_STD).reshape((3,1,1)) 58 | image = (image.numpy()*stds)+means 59 | plt.imshow(np.transpose(image,(1,2,0))) 60 | 61 | plt.axis('off') 62 | plt.tight_layout() 63 | path = utils.check_path(args.save+'/{}.png'.format(distortion)) 64 | plt.savefig(path) 65 | 66 | if __name__ == '__main__': 67 | main() --------------------------------------------------------------------------------
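Tying the pieces together, a minimal test-time sketch for the classification models (model, x and the sample count are placeholders): because the stochastic layers above, BernoulliDropout and the BBB Linear/Conv2d, draw fresh noise on every forward call even in eval mode, predictive mean and uncertainty fall out of straightforward repeated forward passes.

import torch

@torch.no_grad()
def mc_predict(model, x, samples=10):
    # eval() only freezes BatchNorm statistics; the stochastic layers keep
    # sampling, so each pass is one draw from the predictive distribution.
    model.eval()
    probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(samples)])
    return probs.mean(dim=0), probs.std(dim=0)  # per-class mean and spread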