├── URSABench ├── hyperopt │ └── __init__.py ├── trtprof │ ├── __init__.py │ ├── readme.md │ ├── batch_torch2onnx.sh │ ├── batch_onnx2trt.sh │ ├── dataset.py │ ├── pred.bash │ ├── make_table.py │ ├── to_onnx.py │ ├── utils.py │ ├── prof.py │ └── run_prediction.py ├── hyperparams │ ├── ResNet50CIFAR10 │ │ ├── sgld_hyperparams.json │ │ ├── sghmc_hyperparams.json │ │ ├── mc_dropout_hyperparams.json │ │ ├── csgld_hyperparams.json │ │ ├── csghmc_hyperparams.json │ │ ├── swag_hyperparams.json │ │ └── pca_ess_hyperparams.json │ ├── ResNet50CIFAR100 │ │ ├── sgld_hyperparams.json │ │ ├── sghmc_hyperparams.json │ │ ├── mc_dropout_hyperparams.json │ │ ├── csgld_hyperparams.json │ │ ├── swag_hyperparams.json │ │ ├── csghmc_hyperparams.json │ │ └── pca_ess_hyperparams.json │ ├── ResNet50ImageNet │ │ ├── sgld_hyperparams.json │ │ ├── sghmc_hyperparams.json │ │ ├── mc_dropout_hyperparams.json │ │ ├── csgld_hyperparams.json │ │ ├── csghmc_hyperparams.json │ │ ├── swag_hyperparams.json │ │ └── pca_ess_hyperparams.json │ ├── WideResNet28x10CIFAR10 │ │ ├── sgld_hyperparams.json │ │ ├── sghmc_hyperparams.json │ │ ├── mc_dropout_hyperparams.json │ │ ├── csgld_hyperparams.json │ │ ├── swag_hyperparams.json │ │ ├── csghmc_hyperparams.json │ │ └── pca_ess_hyperparams.json │ ├── WideResNet28x10CIFAR100 │ │ ├── sgld_hyperparams.json │ │ ├── sghmc_hyperparams.json │ │ ├── mc_dropout_hyperparams.json │ │ ├── csgld_hyperparams.json │ │ ├── swag_hyperparams.json │ │ ├── csghmc_hyperparams.json │ │ └── pca_ess_hyperparams.json │ └── MLP200MNIST │ │ ├── SGD_BO.json │ │ ├── SGLD_BO.json │ │ ├── HMC_BO.json │ │ ├── SGHMC_BO.json │ │ ├── cSGLD_BO.json │ │ ├── MCdropout_BO.json │ │ ├── cSGHMC_BO.json │ │ └── PCASubspaceSampler_BO.json ├── models │ ├── __init__.py │ ├── mlp.py │ ├── resnet.py │ ├── imagenet_resnet.py │ ├── wideresnet.py │ └── preresnet.py ├── tasks │ ├── __init__.py │ ├── task_base.py │ ├── ood_detection_distilled.py │ ├── decision_making.py │ └── ood_detection.py ├── __init__.py ├── 
inference │ ├── __init__.py │ ├── projection_model.py │ ├── sgld.py │ ├── csgld.py │ ├── inference_base.py │ ├── optim_sghmc.py │ ├── hmc.py │ ├── sgd.py │ ├── sghmc.py │ ├── vi_dropout.py │ ├── csghmc.py │ ├── subspaces.py │ ├── pca_subspace.py │ ├── swag.py │ └── swa.py ├── run_seq_hypOpt.py ├── run_par_hypOpt.py ├── time_script.py └── experiment.py ├── setup.py ├── LICENSE ├── .gitignore └── README.md /URSABench/hyperopt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /URSABench/trtprof/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/sgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/sgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/sgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.001, "prior_std": 0.1, "burn_in_epochs": 5, "num_samples": 6} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR10/sgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | 
-------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/sgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/SGD_BO.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "epochs": 50, "momentum": 0.9, "weight_decay": 0.00030405964935198426, "num_samples": 30} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/SGLD_BO.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.09999999403953552, "prior_std": 0.1664159595966339, "burn_in_epochs": 50, "num_samples": 100} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/sghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "alpha": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/sghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "alpha": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/sghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.01, "prior_std": 0.1, "alpha": 0.5, "burn_in_epochs": 5, "num_samples": 6} 2 | 3 | 
-------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR10/sghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "alpha": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/sghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "prior_std": 0.5, "alpha": 0.5, "burn_in_epochs": 100, "num_samples": 50} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/HMC_BO.json: -------------------------------------------------------------------------------- 1 | {"step_size": 0.00020897435024380684, "L": 40, "tau": 100.0, "mass": 0.1919177919626236, "num_samples": 300, "burn": 200} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/mc_dropout_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "epochs": 50, "dropout": 0.2, "momentum": 0.9, "weight_decay": 0.0001, "num_samples": 50} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/mc_dropout_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "epochs": 50, "dropout": 0.2, "momentum": 0.9, "weight_decay": 0.0001, "num_samples": 50} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/mc_dropout_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.01, "epochs": 0, "dropout": 0.2, "momentum": 0.9, "weight_decay": 0.0001, 
"num_samples": 5} 2 | -------------------------------------------------------------------------------- /URSABench/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagenet_resnet import * 2 | from .mlp import * 3 | from .preresnet import * 4 | from .resnet import * 5 | from .wideresnet import * 6 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR10/mc_dropout_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.1, "epochs": 50, "dropout": 0.2, "momentum": 0.9, "weight_decay": 0.0001, "num_samples": 50} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/csgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50,"burn_in_epochs": 2, "num_cycles": 17} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/csgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50,"burn_in_epochs": 0, "num_cycles": 17} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/csgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.01, "prior_std": 0.1, "num_samples_per_cycle": 2, "cycle_length": 7,"burn_in_epochs": 1, "num_cycles": 3} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/mc_dropout_hyperparams.json: 
-------------------------------------------------------------------------------- 1 | {"lr": 0.1, "epochs": 50, "dropout": 0.2, "momentum": 0.9, "weight_decay": 0.0001, "num_samples": 50} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/SGHMC_BO.json: -------------------------------------------------------------------------------- 1 | {"lr": 0.03134895861148834, "prior_std": 0.14046818017959595, "alpha": 0.10199674218893051, "burn_in_epochs": 50, "num_samples": 100} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR10/csgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50,"burn_in_epochs": 0, "num_cycles": 17} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/csgld_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50,"burn_in_epochs": 0, "num_cycles": 17} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/cSGLD_BO.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.05866509675979614, "prior_std": 0.17959332466125488, "num_samples_per_cycle": 3, "cycle_length": 21, "burn_in_epochs": 1, "num_cycles": 10} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/csghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50, "burn_in_epochs": 2, "num_cycles": 17, "alpha": 0.5} 2 | 
3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/swag_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_init": 0.01, "swag_lr": 0.01, "swag_wd": 0.0001, "momentum": 0.9, "burn_in_epochs": 1, "num_samples": 50, "num_iterates": 140} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/swag_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_init": 0.01, "swag_lr": 0.01, "swag_wd": 0.0001, "momentum": 0.9, "burn_in_epochs": 1, "num_samples": 50, "num_iterates": 140} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/csghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.01, "prior_std": 0.1, "num_samples_per_cycle": 2, "cycle_length": 7, "burn_in_epochs": 1, "num_cycles": 3, "alpha": 0.5} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/swag_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_init": 0.005, "swag_lr": 0.005, "swag_wd": 0.0001, "momentum": 0.9, "burn_in_epochs": 1, "num_samples": 6, "num_iterates": 20} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/csghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50, "burn_in_epochs": 2, "num_cycles": 17, "alpha": 0.5} 2 | 3 | -------------------------------------------------------------------------------- 
/URSABench/hyperparams/WideResNet28x10CIFAR10/swag_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_init": 0.05, "swag_lr": 0.05, "swag_wd": 0.0005, "momentum": 0.9, "burn_in_epochs": 1, "num_samples": 50, "num_iterates": 10} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/swag_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_init": 0.05, "swag_lr": 0.05, "swag_wd": 0.0005, "momentum": 0.9, "burn_in_epochs": 1, "num_samples": 50, "num_iterates": 140} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .prediction import * 2 | from .ood_detection import * 3 | from .decision_making import * 4 | from .ood_detection_distilled import * 5 | from .prediction_distilled import * 6 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR10/csghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50, "burn_in_epochs": 0, "num_cycles": 17, "alpha": 0.5} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/csghmc_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.5, "prior_std": 0.5, "num_samples_per_cycle": 3, "cycle_length": 50, "burn_in_epochs": 0, "num_cycles": 17, "alpha": 0.5} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/MCdropout_BO.json: 
-------------------------------------------------------------------------------- 1 | {"lr": 0.08323124796152115, "epochs": 0, "dropout": 0.2, "momentum": 0.9040795087814331, "lengthscale": 0.0016902670031413436, "num_samples": 100, "weight_decay": 0.0003} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/cSGHMC_BO.json: -------------------------------------------------------------------------------- 1 | {"lr_0": 0.06825362145900726, "prior_std": 0.33042997121810913, "num_samples_per_cycle": 4, "cycle_length": 22, "burn_in_epochs": 1, "num_cycles": 10, "alpha": 0.21256764233112335} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR10/pca_ess_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"swag_lr": 0.01,"swag_wd": 0.0001,"lr_init": 0.1, "num_samples": 50, "swag_momentum": 0.9, "swag_burn_in_epochs":160, "num_swag_iterates":140, "rank": 20, "max_rank": 20, "temperature": 5000, "prior_std": 2.0} 2 | -------------------------------------------------------------------------------- /URSABench/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.0.dev1' 2 | 3 | from .util import set_random_seed 4 | from .hyperopt.hyper_optimization import GridSearch, BayesOpt 5 | from .inference import HMC, SGLD, SGHMC 6 | from .models import * 7 | from .tasks import * 8 | from .datasets import * 9 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50CIFAR100/pca_ess_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"swag_lr": 0.05, "swag_wd": 0.0001,"lr_init": 0.1, "num_samples": 50, "swag_momentum": 0.9, "swag_burn_in_epochs":160, "num_swag_iterates":140, "rank": 20, "max_rank": 20, "temperature": 5000, "prior_std": 
2.0} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/ResNet50ImageNet/pca_ess_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"swag_lr": 0.005, "swag_wd": 0.0001,"lr_init": 0.005, "num_samples": 6, "swag_momentum": 0.9, "swag_burn_in_epochs": 1, "num_swag_iterates":30, "rank": 20, "max_rank": 20, "temperature": 10000, "prior_std": 2.0} 2 | 3 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR10/pca_ess_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"swag_lr": 0.05,"swag_wd": 0.0005,"lr_init": 0.1, "num_samples": 50, "swag_momentum": 0.9, "swag_burn_in_epochs":160, "num_swag_iterates":140, "rank": 20, "max_rank": 20, "temperature": 5000, "prior_std": 2.0} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/WideResNet28x10CIFAR100/pca_ess_hyperparams.json: -------------------------------------------------------------------------------- 1 | {"swag_lr": 0.05,"swag_wd": 0.0005,"lr_init": 0.1, "num_samples": 50, "swag_momentum": 0.9, "swag_burn_in_epochs":160, "num_swag_iterates":140, "rank": 20, "max_rank": 20, "temperature": 5000, "prior_std": 2.0} 2 | -------------------------------------------------------------------------------- /URSABench/hyperparams/MLP200MNIST/PCASubspaceSampler_BO.json: -------------------------------------------------------------------------------- 1 | {"lr_init": 0.04393735155463219, "swag_lr": 0.0017009282018989325, "swag_wd": 0.00039999998989515007, "swag_momentum": 0.5357643365859985, "swag_burn_in_epochs": 50, "num_samples": 30, "num_swag_iterates": 50, "rank": 20, "max_rank": 20, "temperature": 5000, "prior_std": 2.0} -------------------------------------------------------------------------------- /URSABench/inference/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .csghmc import * 2 | from .csgld import * 3 | from .hmc import * 4 | from .sghmc import * 5 | from .sgld import * 6 | from .swa import * 7 | from .swag import * 8 | from .pca_subspace import * 9 | from .projection_model import * 10 | from .vi_dropout import * 11 | from .sgd import * 12 | -------------------------------------------------------------------------------- /URSABench/trtprof/readme.md: -------------------------------------------------------------------------------- 1 | # Profiling on Jetson Platform 2 | 3 | 1. Save PyTorch models to checkpoints to state dict files (`pt` or `pth` files). 4 | 2. Export `pt` or `pth` files to ONNX format: 5 | ```bash 6 | bash batch_torch2onnx.sh 7 | ``` 8 | 3. Export ONNX models to TensorRT engines: 9 | ```bash 10 | bash batch_onnx2trt.sh 11 | ``` 12 | 4. Profile TensorRT engines: 13 | ```bash 14 | bash pred.bash 15 | ``` 16 | -------------------------------------------------------------------------------- /URSABench/trtprof/batch_torch2onnx.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Convert a folder of PyTorch models (pt, pth) to ONNX format. 4 | 5 | INPUT_DIR="/data/ResNet50_ImageNet" 6 | NUM_CLASSES=1000 7 | MODEL_CLASSES="ResNet_ImageNet" 8 | 9 | for input_file in $INPUT_DIR/*.pt; do 10 | onnx_file="${input_file%.pt}.onnx" 11 | if [ -e $onnx_file ] 12 | then 13 | echo "$onnx_file exists: Skip" 14 | else 15 | echo "$onnx_file doesn't exist: Exporting..." 
16 | python3 to_onnx.py $input_file -s 1 3 224 224 --model_class $MODEL_CLASSES || exit 1 17 | fi 18 | done 19 | -------------------------------------------------------------------------------- /URSABench/tasks/task_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class _Task: 5 | def __init__(self, data_loader=None, num_classes=None, device=torch.device('cpu')): 6 | self.data_loader = data_loader 7 | self.num_classes = num_classes 8 | self.device = device 9 | 10 | def reset(self): 11 | raise NotImplementedError 12 | 13 | def update_statistics(self, model, output_performance=False): 14 | raise NotImplementedError 15 | 16 | def ensemble_update_statistics(self, model_list, output_performance=False): 17 | raise NotImplementedError 18 | 19 | def get_performance_metrics(self): 20 | raise NotImplementedError 21 | -------------------------------------------------------------------------------- /URSABench/trtprof/batch_onnx2trt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # MODEL_SUFFIX="trt32" 4 | # TRT_FLAGS="--explicitBatch" 5 | MODEL_SUFFIX="trt" 6 | TRT_FLAGS="--explicitBatch --fp16" 7 | 8 | # Convert an ONNX file to TensorRT engine 9 | do_convert() { 10 | echo "Converting $1" 11 | time trtexec --onnx="$1" --saveEngine="${1%.onnx}.$MODEL_SUFFIX" $TRT_FLAGS 12 | } 13 | 14 | # Convert a folder of ONNX models to TensorRT engines; Skip if output file already exists. 15 | convert_folder() { 16 | echo "Coverting models in folder $1" 17 | for onnx_file in $1/*.onnx; do 18 | trt_file="${onnx_file%.onnx}.$MODEL_SUFFIX" 19 | if [ -e $trt_file ] 20 | then 21 | echo "$trt_file exists: Skip" 22 | else 23 | echo "$trt_file doesn't exist: Exporting..." 
24 | do_convert $onnx_file 25 | fi 26 | done 27 | } 28 | 29 | # do_convert /data/model_dense.onnx 30 | convert_folder "/data/ResNet50_ImageNet" 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from setuptools import setup, find_packages 4 | PACKAGE_NAME = 'URSABench' 5 | MINIMUM_PYTHON_VERSION = 3, 5 6 | 7 | 8 | def check_python_version(): 9 | """Exit when the Python version is too low.""" 10 | if sys.version_info < MINIMUM_PYTHON_VERSION: 11 | sys.exit("Python {}.{}+ is required.".format(*MINIMUM_PYTHON_VERSION)) 12 | 13 | 14 | def read_package_variable(key): 15 | """Read the value of a variable from the package without importing.""" 16 | module_path = os.path.join(PACKAGE_NAME, '__init__.py') 17 | with open(module_path) as module: 18 | for line in module: 19 | parts = line.strip().split(' ') 20 | if parts and parts[0] == key: 21 | return parts[-1].strip("'") 22 | assert 0, "'{0}' not found in '{1}'".format(key, module_path) 23 | 24 | 25 | check_python_version() 26 | setup( 27 | name='URSABench', 28 | version=read_package_variable('__version__'), 29 | description='A PyTorch-based benchmark library for MCMC', 30 | author='Adam D. 
Cobb, Meet Vadera, Ben Marlin, Brian Jalaian', 31 | author_email='cobb.derek.adam@gmail.com, mvadera@cs.umass.edu', 32 | packages=find_packages(), 33 | install_requires=['torch>=1.4.0', 'numpy'], 34 | url='https://github.com/reml-lab/URSABench', 35 | classifiers=['Development Status :: 1 - Planning', 'License :: OSI Approved :: MIT License', 'Programming Language :: Python :: 3.5'], 36 | license='MIT', 37 | keywords='pytorch MCMC BNN', 38 | ) 39 | -------------------------------------------------------------------------------- /URSABench/inference/projection_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from URSABench.util import unflatten_like 4 | 5 | 6 | class SubspaceModel(torch.nn.Module): 7 | def __init__(self, mean, cov_factor): 8 | super(SubspaceModel, self).__init__() 9 | self.rank = cov_factor.size(0) 10 | self.register_buffer('mean', mean) 11 | self.register_buffer('cov_factor', cov_factor) 12 | 13 | def forward(self, t): 14 | return self.mean + self.cov_factor.t() @ t 15 | 16 | 17 | class ProjectedModel(torch.nn.Module): 18 | def __init__(self, proj_params, model, projection=None, mean=None, subspace=None): 19 | super(ProjectedModel, self).__init__() 20 | self.model = model 21 | 22 | if subspace is None: 23 | self.subspace = SubspaceModel(mean, projection) 24 | else: 25 | self.subspace = subspace 26 | 27 | if mean is None and subspace is None: 28 | raise NotImplementedError('Must enter either subspace or mean') 29 | 30 | self.proj_params = proj_params 31 | 32 | def update_params(self, vec, model): 33 | vec_list = unflatten_like(likeTensorList=list(model.parameters()), vector=vec.view(1, -1)) 34 | for param, v in zip(model.parameters(), vec_list): 35 | param.detach_() 36 | param.mul_(0.0).add_(v) 37 | 38 | def forward(self, *args, **kwargs): 39 | y = self.subspace(self.proj_params) 40 | 41 | self.update_params(y, self.model) 42 | return self.model(*args, **kwargs) 43 | 
-------------------------------------------------------------------------------- /URSABench/trtprof/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage import io 4 | from skimage.transform import resize 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms import Normalize 7 | 8 | 9 | class MLPDataset(Dataset): 10 | def __init__(self, n_samples, n_feats): 11 | super(MLPDataset, self).__init__() 12 | self.n_samples = n_samples 13 | self.n_feats = n_feats 14 | self.labels = np.random.randint(1, size=self.n_samples) 15 | self.feats = np.random.randn(self.n_samples, self.n_feats).astype("f") 16 | 17 | def __len__(self): 18 | return self.n_samples 19 | 20 | def __getitem__(self, index): 21 | return self.feats[index], self.labels[index] 22 | 23 | 24 | class DummyDataset(Dataset): 25 | def __init__(self, img_hw, n_samples, dtype): 26 | self.n_samples = n_samples 27 | self.dtype = dtype 28 | 29 | url = "https://images.dog.ceo/breeds/retriever-golden/n02099601_3004.jpg" 30 | self.img = self.transform(resize(io.imread(url), (img_hw, img_hw))) 31 | self.img_labels = np.random.randint(0, 2, size=self.n_samples) 32 | 33 | def __len__(self): 34 | return self.n_samples 35 | 36 | def transform(self, img): 37 | norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 38 | result = norm(torch.from_numpy(img).transpose(0, 2).transpose(1, 2)) 39 | return np.array(result, dtype=self.dtype) 40 | 41 | def __getitem__(self, idx): 42 | label = self.img_labels[idx] 43 | return self.img, label 44 | -------------------------------------------------------------------------------- /URSABench/inference/sgld.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from URSABench.util import reset_model 4 | from . 
class SGLD(SGHMC):
    """SGLD as a special case of SGHMC with momentum disabled (alpha = 1)."""

    def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output',
                 device=torch.device('cpu')):
        '''
        :param hyperparameters: Hyperparameters include {'lr', 'prior_std', 'num_samples', 'burn_in_epochs'}
        :param model: Pytorch model to run SGLD on.
        :param train_loader: DataLoader for train data
        :param model_loss: Loss function to use for the model. (e.g.: 'multi_class_linear_output')
        :param device: Device on which model is present (e.g.: torch.device('cpu'))
        '''
        if hyperparameters is None:
            # Initialise as some default values
            hyperparameters = {'lr': 0.001, 'prior_std': 10, 'num_samples': 2, 'alpha': 0.1, 'burn_in_epochs': 10}
        else:
            # Copy so forcing alpha below does not mutate the caller's dict.
            hyperparameters = dict(hyperparameters)
        # SGLD is SGHMC without momentum: alpha = 1 makes momentum = 1 - alpha = 0.
        hyperparameters['alpha'] = 1.
        super(SGLD, self).__init__(hyperparameters, model, train_loader, model_loss, device)

    def update_hyp(self, hyperparameters):
        """Reset model and optimizer with new hyperparameters (alpha stays pinned to 1)."""
        self.lr = hyperparameters['lr']
        self.prior_std = hyperparameters['prior_std']
        self.num_samples = hyperparameters['num_samples']
        self.alpha = 1.
        self.burn_in_epochs = hyperparameters['burn_in_epochs']
        self.model = reset_model(self.model)
        # Gaussian prior with std prior_std corresponds to weight decay 1/prior_std^2.
        self.optimizer = optimSGHMC(params=self.model.parameters(), lr=self.lr, momentum=1 - self.alpha,
                                    num_training_samples=self.dataset_size, weight_decay=1 / (self.prior_std ** 2))
        self.burnt_in = False
        self.epochs_run = 0
class cSGLD(cSGHMC):
    """Cyclical SGLD: cSGHMC with momentum disabled (alpha fixed to 1)."""

    def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output',
                 device=torch.device('cpu')):
        '''
        :param hyperparameters: Hyperparameters include {'lr_0', 'prior_std', 'num_samples_per_cycle',
            'cycle_length', 'burn_in_epochs', 'num_cycles'}
        :param model: Pytorch model to run cSGLD on.
        :param train_loader: DataLoader for train data
        :param model_loss: Loss function to use for the model. (e.g.: 'multi_class_linear_output')
        :param device: Device on which model is present (e.g.: torch.device('cpu'))
        '''
        if hyperparameters is None:
            # Initialise as some default values
            hyperparameters = {'lr_0': 0.001, 'prior_std': 10.1, 'num_samples_per_cycle': 5,
                               'cycle_length': 20, 'burn_in_epochs': 5, 'num_cycles': 10}
        else:
            # Copy so forcing alpha below does not mutate the caller's dict.
            hyperparameters = dict(hyperparameters)
        # cSGLD is cSGHMC without momentum: alpha = 1 makes momentum = 1 - alpha = 0.
        hyperparameters['alpha'] = 1.
        super(cSGLD, self).__init__(hyperparameters, model, train_loader, model_loss, device)

    def update_hyp(self, hyperparameters):
        """Reset model and optimizer with new hyperparameters (alpha stays pinned to 1)."""
        self.lr_0 = hyperparameters['lr_0']
        self.prior_std = hyperparameters['prior_std']
        self.num_samples_per_cycle = hyperparameters['num_samples_per_cycle']
        self.cycle_length = hyperparameters['cycle_length']
        self.alpha = 1.
        self.burn_in_epochs = hyperparameters['burn_in_epochs']
        self.num_cycles = hyperparameters['num_cycles']
        self.model = reset_model(self.model)
        # Gaussian prior with std prior_std corresponds to weight decay 1/prior_std^2.
        self.optimizer = optimSGHMC(params=self.model.parameters(), lr=self.lr_0, momentum=1 - self.alpha,
                                    num_training_samples=self.dataset_size, weight_decay=1 / (self.prior_std ** 2))
        self.burnt_in = False
        self.epochs_run = 0
[[0.0], [2., 4.]] 21 | train_loader: torch.utils.data.DataLoader 22 | device: default 'cpu' 23 | """ 24 | 25 | self.model = model 26 | self.hyperparameters = hyperparameters 27 | self.train_loader = train_loader 28 | self.device = device 29 | self.loss_criterion = get_loss_criterion(loss=model_loss) 30 | 31 | def update_hyp(self, hyperparameters): 32 | """ Update hyperparameters """ 33 | raise NotImplementedError 34 | 35 | def sample_iterative(self): 36 | """ Sample in an online manner (return a single sample per call) """ 37 | raise NotImplementedError 38 | 39 | def sample(self): 40 | """ 41 | Sample multiple samples 42 | Output: Torch Tensor shape (No Samples, No Parameters) 43 | """ 44 | raise NotImplementedError 45 | 46 | def compute_val_loss(self, val_loader=None): 47 | with torch.no_grad(): 48 | num_val_samples = 0 49 | total_loss = 0. 50 | self.model.eval() 51 | for batch_idx, (batch_data, batch_labels) in enumerate(val_loader): 52 | batch_data_logits = self.model(batch_data.to(self.device)) 53 | batch_loss = self.loss_criterion(batch_data_logits, batch_labels.to(self.device)) 54 | num_val_samples += len(batch_data) 55 | total_loss += batch_loss.item() * len(batch_data) 56 | return total_loss / num_val_samples 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 The Laboratory for Robust and Efficient Machine Learning 4 | 5 | Parts of this software are based on the following repositories: 6 | - Stochastic Weight Averaging (SWA), https://github.com/timgaripov/swa, Copyright (c) 2018, Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson 7 | - A Simple Baseline for Bayesian Deep Learning, https://github.com/wjmaddox/swa_gaussian, Copyright (c) 2019, Wesley Maddox, Timur Garipov, Pavel Izmailov, Dmitry Vetrov, Andrew Gordon Wilson 8 | - Cyclical Stochastic 
Gradient MCMC for Bayesian Deep Learning, https://github.com/ruqizhang/csgmcmc, Copyright (c) 2019 Ruqi Zhang, Chunyuan Li, Jianyi Zhang, Changyou Chen and Andrew Gordon Wilson 9 | - Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs, https://github.com/timgaripov/dnn-mode-connectivity, Copyright (c) 2018 Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov and Andrew Gordon Wilson 10 | - PyTorch Ensembles, https://github.com/bayesgroup/pytorch-ensembles, Copyright (c) 2020 Samsung AI Center Moscow, Arsenii Ashukha, Alexander Lyzhov, Dmitry Molchanov, Dmitry Vetrov 11 | - PyTorch, https://github.com/pytorch/pytorch, Copyright (c) 2016-present, Facebook Inc 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | SOFTWARE. 
class MLP(nn.Module):
    """Three-layer fully connected network with ReLU activations.

    Inputs are flattened to ``(batch, input_dim)`` before the first layer;
    the output is raw (un-normalised) class scores.
    """

    def __init__(self, hidden_size, input_dim, num_classes):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.fc1 = nn.Linear(input_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Collapse any leading dimensions into a single batch dimension.
        flat = x.view(-1, self.input_dim)
        hidden1 = self.fc1(flat)
        hidden2 = self.fc2(F.relu(hidden1))
        logits = self.fc3(F.relu(hidden2))
        return logits
59 | 60 | class MLP200MNIST(Base): 61 | kwargs = {'hidden_size': 200, 'input_dim':784} 62 | 63 | class MLP200MNIST_dropout(Base_dropout): 64 | kwargs = {'hidden_size': 200, 'input_dim':784, 'dropout': 0.2} 65 | 66 | class MLP400MNIST(Base): 67 | kwargs = {'hidden_size': 400, 'input_dim':784} 68 | 69 | class MLP600MNIST(Base): 70 | kwargs = {'hidden_size': 600, 'input_dim':784} 71 | -------------------------------------------------------------------------------- /URSABench/trtprof/pred.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # There's a memory leak somewhere in the TensorRT runtime or the PyCUDA library 4 | # or our code. As a workaround, we keep running the `run_prediction.py` script 5 | # until it exits successfully (returns 0 exit code). 6 | max_retry=5 7 | counter=0 8 | 9 | run() { 10 | echo "$@" 11 | python3 run_prediction.py "$@" 12 | } 13 | 14 | run_to_finish() { 15 | run "$@" 16 | local last_run_exit_code=$? 17 | while [ $last_run_exit_code -ne 0 ]; do 18 | 19 | if [ $last_run_exit_code -eq 4 ]; then 20 | # exit code 4 means finish one batch of models, not an error 21 | printf "\nLast exit code: %d (%s)\n\n" $last_run_exit_code "finished one ensemble but not all" 22 | elif [ $last_run_exit_code -eq 3 ]; then 23 | # Sometimes, somehow the symbolic links of the dynamic libs in 24 | # `/usr/lib/aarch64-linux-gnu` are messed up / modified by someone after 25 | # running the program. I haven't figured out how that happens and who 26 | # does it. When that happens, an OSError exception will be raised 27 | # because certain libs cannot be found. We will catch this exception in 28 | # python and exit with code 3. As a workaround, when we see the exit 29 | # code 3, a potential symbolic link issue, we run `ldconfig` to 30 | # reconcile the links. 31 | printf "\nLast exit code: %d (%s)\n\n" $last_run_exit_code "potential symbolic link issue" 32 | printf "Symbolic links broken. 
Trying to fix that with ldconfig.\n" 33 | ldconfig 34 | printf "Retry %s\n" $counter 35 | elif [ $last_run_exit_code -eq 137 ]; then 36 | printf "\nLast exit code: %d (%s)\n\n" $last_run_exit_code "potential OOM" 37 | printf "Retry %s\n" $counter 38 | else 39 | printf "\nLast exit code: %d (%s)\n\n" $last_run_exit_code "" 40 | printf "Retry %s\n" $counter 41 | fi 42 | 43 | run "$@" 44 | local last_run_exit_code=$? 45 | 46 | [[ $last_run_exit_code -eq 0 ]] && exit 0 47 | 48 | [[ $last_run_exit_code -ne 4 ]] && ((counter++)) 49 | if [ $counter -eq $max_retry ]; then 50 | printf "Max retry reached. Stop trying. Last exit code: %d" $last_run_exit_code 51 | exit 0 52 | fi 53 | done 54 | } 55 | 56 | # 57 | # Latency profiling 58 | # 59 | run_to_finish /data/ResNet50_ImageNet trt latency ensemble 6 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
def parse_file_name(file_name):
    """Split a results-file name into (model, dataset, precision, ensemble_size).

    Expects names of the form
    ``<Model>_<Dataset>.<trt|other>.latency.ensemble<K>.<ext>``.

    :param file_name: path to a results JSON file.
    :return: tuple (model, dataset, precision, ensemble) where precision is a
        human-readable label and ensemble is an int.
    """
    model_dataset, precision, _, ensemble, _ = tuple(
        pathlib.Path(file_name).name.split(".")
    )
    model, dataset = tuple(model_dataset.split("_"))
    if precision == "trt":
        precision = "FP32+FP16"
    else:
        precision = "FP32"
    # BUG FIX: str.lstrip("ensemble") strips a *character set*, not a prefix,
    # so it would silently corrupt the count if the remainder ever started
    # with one of those letters. Slice off the exact prefix instead.
    ensemble = int(ensemble[len("ensemble"):])
    return model, dataset, precision, ensemble
"ImageNet" 37 | #dataset_keyword = "MNIST" 38 | #dataset_keyword = "CIFAR" 39 | files = get_results_files() 40 | files = [x for x in files if dataset_keyword in x] 41 | results = [] 42 | for file_name in files: 43 | model, dataset, precision, ensemble = parse_file_name(file_name) 44 | lat_mean, std_mean = get_average_latency(file_name, ensemble) 45 | results.append( 46 | { 47 | "dataset": dataset, 48 | "model": model, 49 | "precision": precision, 50 | "1_latency_mean": lat_mean, 51 | "2_latency_std": std_mean, 52 | "3_ensemble": ensemble, 53 | } 54 | ) 55 | results = pd.DataFrame(results) 56 | # only keep the largest ensemble 57 | indices = results.groupby(["dataset", "model", "precision"], sort=True)[ 58 | "3_ensemble" 59 | ].idxmax() 60 | results = results.loc[indices] 61 | 62 | results["1_latency_mean"] = results["1_latency_mean"].apply(lambda x: f"{x:.4f}s") 63 | results["2_latency_std"] = results["2_latency_std"].apply( 64 | lambda x: f"$\\pm$ {x:.4f}s" 65 | ) 66 | results["3_ensemble"] = results["3_ensemble"].apply(lambda x: f"({x} models)") 67 | 68 | # format for latex table 69 | results = pd.melt(results, id_vars=["dataset", "model", "precision"]) 70 | results = results.pivot( 71 | index=["precision", "variable"], 72 | columns=["dataset", "model"], 73 | values="value", 74 | ) 75 | results.index = results.index.droplevel(1) 76 | results.to_latex( 77 | f"/data/result_latency_{dataset_keyword}.tex", 78 | caption="Caption here.", 79 | # index=False, 80 | escape=False, 81 | column_format="rccccc", 82 | multicolumn_format="c", 83 | ) 84 | -------------------------------------------------------------------------------- /URSABench/inference/optim_sghmc.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | 7 | class optimSGHMC(Optimizer): 8 | 9 | def __init__(self, params, lr=required, momentum=0, dampening=0, 10 | weight_decay=0, 
num_training_samples=None, nesterov=False): 11 | if lr is not required and lr < 0.0: 12 | raise ValueError("Invalid learning rate: {}".format(lr)) 13 | if momentum < 0.0: 14 | raise ValueError("Invalid momentum value: {}".format(momentum)) 15 | if weight_decay < 0.0: 16 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 17 | 18 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 19 | weight_decay=weight_decay, nesterov=nesterov, num_training_samples=num_training_samples) 20 | if nesterov and (momentum <= 0 or dampening != 0): 21 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 22 | super(optimSGHMC, self).__init__(params, defaults) 23 | 24 | def __setstate__(self, state): 25 | super(optimSGHMC, self).__setstate__(state) 26 | for group in self.param_groups: 27 | group.setdefault('nesterov', False) 28 | 29 | @torch.no_grad() 30 | def step(self, add_langevin_noise=True, closure=None): 31 | loss = None 32 | if closure is not None: 33 | with torch.enable_grad(): 34 | loss = closure() 35 | 36 | for group in self.param_groups: 37 | weight_decay = group['weight_decay'] 38 | momentum = group['momentum'] 39 | dampening = group['dampening'] 40 | nesterov = group['nesterov'] 41 | num_training_samples = group['num_training_samples'] 42 | 43 | for p in group['params']: 44 | if p.grad is None: 45 | continue 46 | d_p = p.grad 47 | if weight_decay != 0: 48 | d_p = d_p.add(p, alpha=weight_decay / num_training_samples) 49 | if momentum != 0: 50 | param_state = self.state[p] 51 | if 'momentum_buffer' not in param_state: 52 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 53 | buf.mul_(momentum).add_(d_p, alpha=-group['lr']) 54 | else: 55 | buf = param_state['momentum_buffer'] 56 | buf.mul_(momentum).add_(d_p, alpha=-group['lr']) 57 | if nesterov: 58 | d_p = d_p.add(buf, alpha=momentum) 59 | else: 60 | d_p = buf 61 | else: 62 | d_p = d_p.mul(-group['lr']) 63 | if add_langevin_noise: 64 | d_p = 
d_p.add(torch.randn_like(d_p) * math.sqrt(2 * (1 - momentum) * group['lr']) / (num_training_samples)) 65 | p.add_(d_p) 66 | if momentum != 0: 67 | param_state['momentum_buffer'] = d_p 68 | return loss 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # URSABench 2 | This repository contains the PyTorch implementation for the paper "URSABench: A System for Comprehensive Benchmarking of Bayesian Deep Neural Network Models and Inference methods" by [Meet P. Vadera](https://meetvadera.github.io), [Jinyang Li](https://scholar.google.com/citations?hl=en&user=VbeL3UUAAAAJ), [Adam D. Cobb](https://adamcobb.github.io/), [Brian Jalaian](https://brianjalaian.netlify.app/), [Tarek Abdelzaher](http://abdelzaher.cs.illinois.edu/) and [Benjamin M. Marlin](https://people.cs.umass.edu/~marlin). 3 | 4 | This paper will be presented at [MLSys '22](https://mlsys.org/). An initial version of this paper was presented at ICML '20 Workshop on [Uncertainty and Robustness in Deep Learning](https://sites.google.com/view/udlworkshop2020/home). 5 | 6 | ## Folder Structure & Usage 7 | 8 | URSABench (`URSABench/`) consists of the following main components: 9 | 10 | * `inference/`: This folder consists of all the approximate inference techniques. The `inference/inference_base.py` lays out the basic functions that are used for all inference methods. 11 | 12 | * `hyperopt/`: This is the hyperparameter optimization module. There are three main hyperparameter optimization classes included: `RandomSearch`, `GridSearch`, and `BayesOpt`. 13 | 14 | * `models/`: This contains some pre-defined model architectures. 15 | 16 | * `tasks/`: This module contains evaluation tasks. Files with the suffix `_distilled.py` contain tasks for the distilled models. 
17 | 18 | * `trtprof/`: This module contains the code for run-time latency profiling using ONNX and NVIDIA TensorRT optimization. Note that the TensorRT optimizations are done on NVIDIA Jetson devices directly. 19 | 20 | We provide a notebook under `examples/` to illustrate how to use URSABench over a standard PyTorch model. This notebook is also available on Google Colab: https://colab.research.google.com/drive/174Urpg2nAc8C4LgBsynt8oiNwjvxJ9Yh?usp=sharing. 21 | 22 | ## Code references: 23 | 24 | * Model implementations: 25 | - PreResNet: https://github.com/bearpaw/pytorch-classification 26 | - WideResNet: https://github.com/meliketoy/wide-resnet.pytorch 27 | 28 | * The included inference schemes have been adapted from the following repos: 29 | - SWA https://github.com/timgaripov/swa/ 30 | - SWAG https://github.com/wjmaddox/swa_gaussian 31 | 32 | * For HMC, we use https://github.com/AdamCobb/hamiltorch. 33 | * Some metrics incorporate code from https://github.com/bayesgroup/pytorch-ensembles 34 | 35 | Please cite our work if you find this approach useful in your research: 36 | ```bibtex 37 | @inproceedings{MLSYS2022_3ef81541, 38 | author = {Vadera, Meet P. and Li, Jinyang and Cobb, Adam and Jalaian, Brian and Abdelzaher, Tarek and Marlin, Benjamin}, 39 | booktitle = {Proceedings of Machine Learning and Systems}, 40 | editor = {D. Marculescu and Y. Chi and C. Wu}, 41 | pages = {217--237}, 42 | title = {URSABench: A System for Comprehensive Benchmarking of Bayesian Deep Neural Network Models and Inference methods}, 43 | url = {https://proceedings.mlsys.org/paper/2022/file/3ef815416f775098fe977004015c6193-Paper.pdf}, 44 | volume = {4}, 45 | year = {2022}, 46 | bdsk-url-1 = {https://proceedings.mlsys.org/paper/2022/file/3ef815416f775098fe977004015c6193-Paper.pdf}} 47 | ``` 48 | 49 | ## Acknowledgements 50 | 51 | Research reported in this paper was sponsored in part by the CCDC Army Research Laboratory under Cooperative Agreement W911NF-17-2-0196 (ARL IoBT CRA). 
The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the Army Research Laboratory or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein. 52 | -------------------------------------------------------------------------------- /URSABench/inference/hmc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | 5 | from URSABench.util import reset_model, convert_sample_to_net 6 | from .inference_base import _Inference 7 | 8 | if 'hamiltorch' not in sys.modules: 9 | print('You have not imported the hamiltorch module,\nrun: pip install git+https://github.com/AdamCobb/hamiltorch') 10 | import hamiltorch 11 | 12 | 13 | # def make_model(sample, model): 14 | # fmodel = hamiltorch.util.make_functional(model) 15 | # params_unflattened = hamiltorch.util.unflatten(model, sample) 16 | # return lambda x : fmodel(x,params=params_unflattened) 17 | 18 | 19 | # TODO: Refactor docstrings for class below. 
class HMC(_Inference):
    """ Basic class for HMC using hamiltorch """

    def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output',
                 device=torch.device('cpu')):
        """Full-batch HMC sampler.

        :param hyperparameters: dict with keys 'step_size', 'num_samples',
            'L' (leapfrog steps per sample), 'tau' (prior precision),
            'burn', 'mass'; None selects defaults below.
        :param model: torch.nn.Module to sample over.
        :param train_loader: DataLoader; the whole dataset is materialised
            onto `device` because HMC needs full-batch gradients.
        :param model_loss: loss identifier passed to hamiltorch.
        :param device: device holding the model and data.
        """
        super(HMC, self).__init__(hyperparameters, model, train_loader, device)
        """
        Inputs:
            hyperparameters: ['step_size', 'num_samples', 'L', 'prior_precision']
            model_loss: Specific to output of model, default linear output (classification)
        """
        if hyperparameters == None:
            # Initialise as some default values
            hyperparameters = {'step_size': 0.001,'num_samples': 10, 'L': 1, 'tau': 0.1, 'burn': -1, 'mass': 1.0}

        self.step_size = hyperparameters['step_size']
        self.num_samples = hyperparameters['num_samples']
        self.L = hyperparameters['L']
        self.tau = hyperparameters['tau']
        self.burn = hyperparameters['burn']
        self.mass = hyperparameters['mass']

        self.model_loss = model_loss

        # Concatenate the entire training set onto the device: HMC computes
        # full-batch log-posterior gradients, not minibatch ones.
        x_train = [];
        y_train = []
        for batch_idx, (data, target) in enumerate(train_loader):
            x_train.append(data.clone().to(self.device))
            y_train.append(target.clone().to(self.device))
        self.x = torch.cat(x_train)
        self.y = torch.cat(y_train)

    def update_hyp(self, hyperparameters):
        """Replace all sampler hyperparameters and reset the model weights."""
        # TODO: Check the hyperparameters ar the right type
        self.step_size = hyperparameters['step_size']
        self.num_samples = hyperparameters['num_samples']
        self.L = hyperparameters['L']
        self.tau = hyperparameters['tau']
        self.mass = hyperparameters['mass']
        self.burn = hyperparameters['burn']
        self.model = reset_model(self.model)

    def sample(self, debug = False):
        """Run hamiltorch HMC and return a list of sampled networks."""
        if issubclass(self.model.__class__, torch.nn.Module):
            # One prior precision per parameter tensor (all set to tau).
            tau_list = []
            for w in self.model.parameters():
                tau_list.append(self.tau)
            tau_list = torch.tensor(tau_list).to(self.device)
            tau_out = 1. # For Regression make this a hyperparameter
            params_init = hamiltorch.util.flatten(self.model).to(self.device).clone()
            # Diagonal inverse mass matrix, uniform 1/mass.
            inv_mass = (torch.ones(params_init.shape) / self.mass).to(self.device)
            # NOTE(review): burn is hard-coded to -1 here instead of
            # self.burn; burn-in appears to be applied via the slice below —
            # confirm this is intentional.
            samples = hamiltorch.sample_model(self.model, self.x, self.y, params_init=params_init,
                                              model_loss=self.model_loss, num_samples=self.num_samples,
                                              burn=-1, inv_mass=inv_mass, step_size=self.step_size,
                                              num_steps_per_sample=self.L, tau_out=tau_out, tau_list=tau_list,
                                              debug=debug)
            model_list = []
            # Do not return initial sample
            if len(samples) != self.L * self.num_samples + 1:
                print('Warning, thinning of sampling not aligned as reject occured in first sample.')
            # Thin by L leapfrog positions; with the default burn=-1 this
            # slice keeps only the final sample — presumably intended as
            # "discard everything before the last"; verify against callers.
            for sample in samples[self.burn*self.L::self.L]:
                model_list.append(convert_sample_to_net(sample, self.model))
        else:
            raise NotImplementedError

        return model_list#torch.stack(samples)
9 | """ 10 | 11 | import argparse 12 | import pathlib 13 | import time 14 | from collections import OrderedDict 15 | 16 | import torch 17 | import torchvision 18 | from URSABench import models 19 | 20 | 21 | def remove_prefix(text, prefix): 22 | if text.startswith(prefix): 23 | return text[len(prefix) :] 24 | else: 25 | return text 26 | 27 | 28 | def run(pickle_file, input_shape, num_classes, model_class): 29 | 30 | output_dir = pickle_file.parent 31 | model_name = pickle_file.stem 32 | onnx_filename = output_dir / f"{model_name}.onnx" 33 | 34 | print(f"Loading {args.pickle_file} ... ", end="", flush=True) 35 | t0 = time.perf_counter() 36 | if model_class == "ResNet_ImageNet": 37 | model = torchvision.models.resnet50().to("cuda") 38 | # model = torch.nn.DataParallel(model) 39 | checkpoint = torch.load(pickle_file, map_location="cuda") 40 | checkpoint = OrderedDict( 41 | [(remove_prefix(k, "module."), v) for k, v in checkpoint.items()] 42 | ) 43 | model.load_state_dict(checkpoint) 44 | elif model_class: 45 | model = getattr(models, model_class) 46 | model = model.base(*model.args, num_classes=num_classes, **model.kwargs).to( 47 | "cuda" 48 | ) 49 | checkpoint = torch.load(pickle_file, map_location="cuda") 50 | model.load_state_dict(checkpoint) 51 | else: 52 | model = torch.load(pickle_file, map_location=torch.device("cuda")) 53 | t1 = time.perf_counter() 54 | print(f"done in {t1-t0:.2f}s") 55 | 56 | dummy_input = torch.randn(*input_shape).cuda() 57 | 58 | print("Exporting to ONNX...", end="", flush=True) 59 | t0 = time.perf_counter() 60 | with torch.no_grad(): 61 | model.eval() 62 | torch.onnx.export( 63 | model, 64 | dummy_input, 65 | onnx_filename, 66 | # store the trained parameter weights inside the model file 67 | export_params=True, 68 | # the ONNX version to export the model to 69 | opset_version=11, 70 | # whether to execute constant folding for optimization 71 | do_constant_folding=True, 72 | # the model's input names 73 | input_names=["input"], 74 | # the 
model's output names 75 | output_names=["output"], 76 | verbose=False, 77 | # dynamic_axes={ 78 | # # variable length axes 79 | # "input": {0: "batch_size"}, 80 | # "output": {0: "batch_size"}, 81 | # }, 82 | ) 83 | t1 = time.perf_counter() 84 | print(f"done in {t1-t0:.2f}s") 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser(description="pth -> onnx") 89 | parser.add_argument( 90 | "pickle_file", 91 | type=pathlib.Path, 92 | help="Path to PyTorch pickle file.", 93 | ) 94 | parser.add_argument( 95 | "-s", 96 | "--input_shape", 97 | type=int, 98 | nargs="+", 99 | help="Shape of input. The first int is batch size. For conv2d, NCHW.", 100 | ) 101 | parser.add_argument( 102 | "--model_class", 103 | nargs="?", 104 | type=str, 105 | help="Class of the model", 106 | ) 107 | parser.add_argument( 108 | "--num_classes", type=int, help="Number of classification classes." 109 | ) 110 | parser.set_defaults(model_class=None, num_classes=10) 111 | args = parser.parse_args() 112 | print(args) 113 | run(args.pickle_file, args.input_shape, args.num_classes, args.model_class) 114 | -------------------------------------------------------------------------------- /URSABench/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | from torchvision.transforms import transforms 7 | 8 | __all__ = ['ResNet20', 'ResNet32', 'ResNet44', 'ResNet56', 'ResNet110', 'ResNet1202'] 9 | 10 | 11 | def _weights_init(m): 12 | classname = m.__class__.__name__ 13 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 14 | init.kaiming_normal_(m.weight) 15 | 16 | 17 | class LambdaLayer(nn.Module): 18 | def __init__(self, lambd): 19 | super(LambdaLayer, self).__init__() 20 | self.lambd = lambd 21 | 22 | def forward(self, x): 23 | return self.lambd(x) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 
29 | def __init__(self, in_planes, planes, stride=1, option='A'): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | 36 | self.shortcut = nn.Sequential() 37 | if stride != 1 or in_planes != planes: 38 | if option == 'A': 39 | """ 40 | For CIFAR10 ResNet paper uses option A. 41 | """ 42 | self.shortcut = LambdaLayer(lambda x: 43 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 44 | 0)) 45 | elif option == 'B': 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 48 | nn.BatchNorm2d(self.expansion * planes) 49 | ) 50 | 51 | def forward(self, x): 52 | out = F.relu(self.bn1(self.conv1(x))) 53 | out = self.bn2(self.conv2(out)) 54 | out += self.shortcut(x) 55 | out = F.relu(out) 56 | return out 57 | 58 | 59 | class ResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=10): 61 | super(ResNet, self).__init__() 62 | self.in_planes = 16 63 | 64 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(16) 66 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 67 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 68 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 69 | self.linear = nn.Linear(64, num_classes) 70 | 71 | self.apply(_weights_init) 72 | 73 | def _make_layer(self, block, planes, num_blocks, stride): 74 | strides = [stride] + [1] * (num_blocks - 1) 75 | layers = [] 76 | for stride in strides: 77 | layers.append(block(self.in_planes, planes, stride)) 78 | self.in_planes = planes * block.expansion 79 | 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, x): 83 | out 
class Base:
    """Configuration record consumed as ``cfg.base(*cfg.args, **cfg.kwargs)``.

    Subclasses only override ``kwargs`` to pick a depth; the CIFAR-10 data
    pipelines below are shared by every variant.
    """

    base = ResNet
    args = list()
    kwargs = dict()

    # CIFAR-10 per-channel mean/std, shared by both pipelines (stateless).
    _normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                      (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        _normalize,
    ])

    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        _normalize,
    ])


# Depth variants of the CIFAR ResNet (6n+2 layers, n blocks per stage).
class ResNet20(Base):
    kwargs = {'block': BasicBlock, 'num_blocks': [3, 3, 3]}


class ResNet32(Base):
    kwargs = {'block': BasicBlock, 'num_blocks': [5, 5, 5]}


class ResNet44(Base):
    kwargs = {'block': BasicBlock, 'num_blocks': [7, 7, 7]}


class ResNet56(Base):
    kwargs = {'block': BasicBlock, 'num_blocks': [9, 9, 9]}


class ResNet110(Base):
    kwargs = {'block': BasicBlock, 'num_blocks': [18, 18, 18]}


class ResNet1202(Base):
    kwargs = {'block': BasicBlock, 'num_blocks': [200, 200, 200]}
3 | """ 4 | 5 | import logging 6 | 7 | import torch 8 | from torch import nn 9 | from URSABench import models 10 | from URSABench.models.wideresnet import WideResNet_dropout 11 | 12 | logger = logging.getLogger("URSABench") 13 | logger.setLevel(logging.DEBUG) 14 | if not logger.handlers: 15 | ch = logging.StreamHandler() 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(levelname)s %(pathname)s:%(lineno)d] %(message)s") 18 | ch.setFormatter(formatter) 19 | logger.addHandler(ch) 20 | 21 | 22 | class MLP(torch.nn.Module): 23 | def __init__(self, input_size, hidden_size): 24 | super(MLP, self).__init__() 25 | self.input_size = input_size 26 | self.hidden_size = hidden_size 27 | self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) 28 | self.relu = torch.nn.ReLU() 29 | self.fc2 = torch.nn.Linear(self.hidden_size, 1) 30 | self.sigmoid = torch.nn.Sigmoid() 31 | 32 | def forward(self, x): 33 | hidden = self.fc1(x) 34 | relu = self.relu(hidden) 35 | output = self.fc2(relu) 36 | output = self.sigmoid(output) 37 | return output 38 | 39 | 40 | class MLPEnsemble(torch.nn.Module): 41 | def __init__(self, input_size, hidden_size): 42 | super(MLPEnsemble, self).__init__() 43 | self.input_size = input_size 44 | self.hidden_size = hidden_size 45 | self.fc1_list = nn.ModuleList( 46 | [nn.Linear(self.input_size, self.hidden_size) for i in range(30)] 47 | ) 48 | # self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size) 49 | self.relu_list = nn.ModuleList([torch.nn.ReLU() for i in range(30)]) 50 | self.fc2_list = nn.ModuleList( 51 | [nn.Linear(self.hidden_size, 1) for i in range(30)] 52 | ) 53 | self.sigmoid_list = nn.ModuleList([torch.nn.Sigmoid() for i in range(30)]) 54 | 55 | def forward(self, x): 56 | hidden_list = [fc1(x) for fc1 in self.fc1_list] 57 | relu_list = [relu(hidden) for relu, hidden in zip(self.relu_list, hidden_list)] 58 | output_list = [fc2(relu) for fc2, relu in zip(self.fc2_list, relu_list)] 59 | output_list = [ 60 | 
class MLPEnsemble2(torch.nn.Module):
    """Deep ensemble of 30 independent MLPs; prediction is the member mean."""

    def __init__(self, input_size, hidden_size):
        super(MLPEnsemble2, self).__init__()
        self.module_list = nn.ModuleList(
            [MLP(input_size, hidden_size) for i in range(30)]
        )

    def forward(self, x):
        # Inference-only ensemble: freeze every member before evaluating.
        for mlp in self.module_list:
            for params in mlp.parameters():
                params.requires_grad = False
        output_list = [mlp(x) for mlp in self.module_list]
        return torch.stack(output_list).mean(dim=0)


wrn_cfg = getattr(models, "WideResNet28x10")


class WRNEnsemble2(torch.nn.Module):
    """Ensemble of 2 WideResNet28x10 members; prediction is the member mean."""

    def __init__(self):
        super(WRNEnsemble2, self).__init__()
        # BUG FIX: members were held in a plain Python list, so they were not
        # registered as submodules -- .to(device), .parameters() and
        # .state_dict() silently ignored them. Use nn.ModuleList instead.
        self.module_list = nn.ModuleList(
            [wrn_cfg.base(*wrn_cfg.args, **wrn_cfg.kwargs, num_classes=10)
             for i in range(2)]
        )

    def forward(self, x):
        for m in self.module_list:
            for params in m.parameters():
                params.requires_grad = False
        output_list = [model(x) for model in self.module_list]
        return torch.stack(output_list).mean(dim=0)


rn50_cfg = getattr(models, "INResNet50")


class ResNet50Ensemble2(torch.nn.Module):
    """Ensemble of 3 ResNet-50 members; prediction is the member mean."""

    def __init__(self):
        super(ResNet50Ensemble2, self).__init__()
        # BUG FIX: nn.ModuleList so members are registered submodules
        # (see WRNEnsemble2 above).
        self.module_list = nn.ModuleList(
            [rn50_cfg.base(*rn50_cfg.args, **rn50_cfg.kwargs, num_classes=10)
             for i in range(3)]
        )

    def forward(self, x):
        for m in self.module_list:
            for params in m.parameters():
                params.requires_grad = False
        output_list = [model(x) for model in self.module_list]
        return torch.stack(output_list).mean(dim=0)


class WRNDEnsemble2(torch.nn.Module):
    """Ensemble of 3 WideResNet-with-dropout members; prediction is the member mean."""

    def __init__(self):
        super(WRNDEnsemble2, self).__init__()
        # BUG FIX: nn.ModuleList so members are registered submodules.
        self.module_list = nn.ModuleList(
            [WideResNet_dropout(num_classes=10) for i in range(3)]
        )

    def forward(self, x):
        for m in self.module_list:
            for params in m.parameters():
                params.requires_grad = False
        output_list = [model(x) for model in self.module_list]
        return torch.stack(output_list).mean(dim=0)
# --- remaining CLI options of the sequential hyper-opt driver -------------
parser.add_argument('--inference_method', type=str, default='HMC', help='Inference Method (default: HMC)')
parser.add_argument('--task', type=str, default='Prediction', help='Downstream task to evaluate (default: Prediction)')
parser.add_argument('--validation', type=float, default=0.2, help='Proportation of training used as validation (default: Prediction)')
parser.add_argument('--split_classes', type=int, default=None)
parser.add_argument('--device_num', type=int, default=0, help='Device number to select (default: 0)')
# BayesOpt specific.
# FIX: these defaults were the string literals '10' and 'Inf'. argparse does
# pass string defaults through `type`, but real numeric defaults are the
# idiomatic form and cannot silently break if `type` is ever removed.
parser.add_argument('--N_evaluations', type=int, default=10, help='Number of evaluations for BayesOpt (default: 10)')
parser.add_argument('--init_evaluations', type=int, default=10, help='Number of randomly drawn evaluations for initialiation of BayesOpt (default: 10)')
parser.add_argument('--time_limit', type=float, default=float('inf'), help='Time limit in seconds before stopping BayesOpt (default: Inf)')

args = parser.parse_args()
util.set_random_seed(args.seed)
# Select CUDA device when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    args.device = torch.device('cuda')
    torch.cuda.set_device(args.device_num)
else:
    args.device = torch.device('cpu')
model_cfg = getattr(models, args.model)
loaders, num_classes = datasets.loaders(
    args.dataset,
    args.data_path,
    args.batch_size,
    args.num_workers,
    transform_train=model_cfg.transform_train,
    transform_test=model_cfg.transform_test,
    shuffle_train=True,
    use_validation=True,
    val_size=args.validation,
    split_classes=args.split_classes
)
train_loader = loaders['train']
test_loader = loaders['test']

num_classes = int(num_classes)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs).to(args.device)
inference_method = getattr(inference, args.inference_method)
inference_object = inference_method(hyperparameters=None, model=model, train_loader=train_loader,
                                    device=args.device)
task_method = getattr(tasks, args.task)
task_data_loader = {'in_distribution_test': test_loader}
# The hyper-opt objective is validation log-likelihood only.
task_object = task_method(dataloader=task_data_loader, num_classes=num_classes, device=args.device, metric_list=['ll'])

hyper_opt_class = getattr(hyperOptimization, args.hyper_opt)

if args.hyper_opt == 'BayesOpt':
    hyper_opt = hyper_opt_class(task_object, args.domain, inference_object, time_limit=args.time_limit,
                                init_evaluations=args.init_evaluations, N_evaluations=args.N_evaluations,
                                iterative_mode=False, seed=args.seed)
    best_hyp, max_obj, hyp_list, best_Y = hyper_opt.run(verbose=args.verbose, return_all=True,
                                                        initialisation='RandomSearch', save_path=args.save_path)

best_hyp = util.make_dic_json_format(best_hyp)

with open(args.save_path + '.json', 'w') as fout:
    json.dump(best_hyp, fout)

results = {'best_hyp': best_hyp, 'best_Y': best_Y, 'time': hyper_opt.time, 'args': args}

torch.save(results, args.save_path + '_args.npy')
# --- remaining CLI options of the parallel hyper-opt driver ---------------
parser.add_argument('--hyper_opt', type=str, default='BayesOpt', help='Hyperparameter Optimisation Scheme (default: BayesOpt)')
parser.add_argument('--verbose', type=int, default=1, help='Whether to print each iteration (default: 1)')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (default: None)')
parser.add_argument('--save_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to file to store results (default: None)')
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--inference_method', type=str, default='HMC', help='Inference Method (default: HMC)')
parser.add_argument('--task', type=str, default='Prediction', help='Downstream task to evaluate (default: Prediction)')
parser.add_argument('--validation', type=float, default=0.2, help='Proportation of training used as validation (default: Prediction)')
parser.add_argument('--split_classes', type=int, default=None)
parser.add_argument('--device_num', type=int, default=0, help='Device number to select (default: 0)')
# BayesOpt specific.
# FIX: defaults were the strings '10'/'Inf'; plain numbers are the idiomatic
# form (argparse converts string defaults through `type`, so behaviour is
# unchanged).
parser.add_argument('--N_evaluations', type=int, default=10, help='Number of evaluations for BayesOpt (default: 10)')
parser.add_argument('--init_evaluations', type=int, default=10, help='Number of randomly drawn evaluations for initialiation of BayesOpt (default: 10)')
parser.add_argument('--time_limit', type=float, default=float('inf'), help='Time limit in seconds before stopping BayesOpt (default: Inf)')

args = parser.parse_args()
util.set_random_seed(args.seed)
if torch.cuda.is_available():
    args.device = torch.device('cuda')
    torch.cuda.set_device(args.device_num)
else:
    args.device = torch.device('cpu')
model_cfg = getattr(models, args.model)
loaders, num_classes = datasets.loaders(
    args.dataset,
    args.data_path,
    args.batch_size,
    args.num_workers,
    transform_train=model_cfg.transform_train,
    transform_test=model_cfg.transform_test,
    shuffle_train=True,
    use_validation=True,
    val_size=args.validation,
    split_classes=args.split_classes
)
train_loader = loaders['train']
test_loader = loaders['test']
num_classes = int(num_classes)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs).to(args.device)
inference_method = getattr(inference, args.inference_method)
inference_object = inference_method(hyperparameters=None, model=model, train_loader=train_loader,
                                    device=args.device)
task_method = getattr(tasks, args.task)
task_data_loader = {'in_distribution_test': test_loader}
task_object = task_method(dataloader=task_data_loader, num_classes=num_classes, device=args.device, metric_list=['ll'])

hyper_opt_class = getattr(hyperOptimization, args.hyper_opt)

if args.hyper_opt == 'RandomSearch':
    hyper_opt = hyper_opt_class(task_object, args.domain, inference_object, N_evaluations=args.N_evaluations,
                                iterative_mode=False, seed=args.seed)
    command_list = hyper_opt.run_parallel(args.dataset, args.data_path, args.model, args.validation,
                                          args.inference_method, args.task, args.verbose)

# NOTE(review): `command_list` is only bound inside the RandomSearch branch,
# so any other --hyper_opt value raises NameError here -- kept as found,
# confirm intended behaviour. Commands are lists run without a shell.
for command in command_list:
    subprocess.run(command)
    print(command)
class SGD(_Inference):
    """Point-estimate baseline: plain SGD training that yields one model.

    Exposes the same sample()/sample_iterative() interface as the posterior
    samplers so it can be benchmarked interchangeably with them.
    """

    def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output',
                 device=torch.device('cpu')):
        '''
        :param hyperparameters: Hyperparameters include {'lr', 'epochs', 'momentum', 'weight_decay'}
        :param model: Pytorch model to train.
        :param train_loader: DataLoader for train data
        :param model_loss: Loss function to use for the model. (e.g.: 'multi_class_linear_output')
        :param device: Device on which model is present (e.g.: torch.device('cpu'))
        '''
        if hyperparameters == None:
            # Initialise as some default values
            hyperparameters = {'lr': 0.1, 'epochs': 10, 'momentum': 0.9, 'weight_decay': 0.001}

        super(SGD, self).__init__(hyperparameters, model, train_loader, device)
        self.lr = hyperparameters['lr']
        self.num_samples = 1  # a point estimate is a single "sample"
        self.burn_in_epochs = hyperparameters['epochs']
        self.momentum = hyperparameters['momentum']
        self.model = model.to(device)

        self.train_loader = train_loader
        self.device = device
        self.dataset_size = len(train_loader.dataset)
        self.weight_decay = hyperparameters['weight_decay']
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum,
                                         weight_decay=self.weight_decay)
        self.loss_criterion = get_loss_criterion(loss=model_loss)
        self.burnt_in = False
        self.epochs_run = 0
        self.lr_final = self.lr / 100.
        self.optimizer_scheduler = CosineAnnealingLR(optimizer=self.optimizer,
                                                     T_max=self.burn_in_epochs + self.num_samples,
                                                     eta_min=self.lr_final)

    def update_hyp(self, hyperparameters):
        """Reset model, optimizer and LR schedule for a new hyperparameter setting."""
        self.lr = hyperparameters['lr']
        self.num_samples = 1
        # BUG FIX: this previously assigned to an unused `self.epochs`
        # attribute, so `self.burn_in_epochs` (read by the scheduler below and
        # by sample_iterative) stayed stale from the previous setting.
        self.burn_in_epochs = hyperparameters['epochs']
        self.momentum = hyperparameters['momentum']
        self.weight_decay = hyperparameters['weight_decay']
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum,
                                         weight_decay=self.weight_decay)
        self.model = reset_model(self.model).to(self.device)
        self.burnt_in = False
        self.epochs_run = 0
        # NOTE(review): __init__ anneals to lr/100 while update_hyp anneals to
        # lr/2 -- kept as found; confirm which floor is intended.
        self.lr_final = self.lr / 2
        self.optimizer_scheduler = CosineAnnealingLR(optimizer=self.optimizer,
                                                     T_max=self.burn_in_epochs + self.num_samples,
                                                     eta_min=self.lr_final)

    def sample_iterative(self, val_loader=None, debug_val_loss=False, wandb_debug=False):
        """Train for the full epoch budget on the first call; later calls return the model unchanged."""
        if issubclass(self.model.__class__, torch.nn.Module):
            if self.burnt_in is False:
                epochs = self.burn_in_epochs + 1
                self.burnt_in = True
            else:
                epochs = 0  # point estimate: no further training on later calls
            for epoch in range(epochs):
                self.model.train()
                total_epoch_train_loss = 0.
                for batch_idx, (batch_data, batch_labels) in enumerate(self.train_loader):
                    batch_data = batch_data.to(self.device)
                    batch_labels = batch_labels.to(self.device)
                    batch_data_logits = self.model(batch_data)
                    self.optimizer.zero_grad()
                    loss = self.loss_criterion(batch_data_logits, batch_labels)
                    loss.backward()
                    total_epoch_train_loss += loss.item() * len(batch_data)
                    self.optimizer.step()
                if debug_val_loss:
                    avg_val_loss = self.compute_val_loss(val_loader)
                    avg_train_loss = total_epoch_train_loss / self.dataset_size
                    metrics = {
                        'train_loss': avg_train_loss,
                        'val_loss': avg_val_loss
                    }
                    print(metrics)
                    if wandb_debug:
                        wandb.log(metrics)
                self.optimizer_scheduler.step()
            # NOTE(review): unlike SGHMC this returns the live model (no
            # deepcopy); with num_samples fixed at 1 that is harmless.
            return self.model
        else:
            raise NotImplementedError

    def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False):
        """Return a list of `num_samples` models (all the same point estimate)."""
        output_list = []
        if num_samples is None:
            num_samples = self.num_samples
        if issubclass(self.model.__class__, torch.nn.Module):
            for i in range(num_samples):
                output_list.append(self.sample_iterative(val_loader=val_loader, debug_val_loss=debug_val_loss,
                                                         wandb_debug=wandb_debug))
            return output_list
        else:
            raise NotImplementedError
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--hyperparams_path', default=None, help="Path to json file containing hyperparams",
                    type=str)
parser.add_argument('--validation', type=float, default=0.2,
                    help='Proportation of training used as validation (default: Prediction)')
parser.add_argument('--use_val', dest='use_val', action='store_true', help='use val dataset instead of test (default: False)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
parser.add_argument('--save_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to file to store results (default: None)')
parser.add_argument('--device_num', type=int, default=0, help='Device number to select (default: 0)')


args = parser.parse_args()
util.set_random_seed(args.seed)
if torch.cuda.is_available():
    args.device = torch.device('cuda')
    torch.cuda.set_device(args.device_num)
else:
    args.device = torch.device('cpu')

model_cfg = getattr(models, args.model)
loaders, num_classes = datasets.loaders(
    args.dataset,
    args.data_path,
    args.batch_size,
    args.num_workers,
    transform_train=model_cfg.transform_train,
    transform_test=model_cfg.transform_test,
    shuffle_train=True,
    use_validation=False,
    val_size=args.validation
)

train_loader = loaders['train']
test_loader = loaders['test']
num_classes = int(num_classes)

inference_method_list = ['HMC', 'SGLD', 'SGHMC', 'cSGLD', 'cSGHMC', 'SWAG', 'PCA', 'MCdropout', 'SGD', 'PCASubspaceSampler']

timer_dic = {}
S = 3   # number of posterior samples timed per method
T = 10  # number of repeated trials per method

for inference_method in inference_method_list:
    hyperparams = util.json_open_from_file(parser, args.hyperparams_path + inference_method + '_BO.json')

    print(inference_method)
    print('Time for ' + str(S) + ' sample.')

    # Force each method to (near-)zero burn-in so only sampling is timed.
    if inference_method == 'HMC':
        hyperparams['burn'] = -1
    if inference_method == 'SWAG':
        hyperparams['burn_in_epochs'] = 1
    if inference_method == 'PCASubspaceSampler':
        # BUG FIX: this line was a bare subscript with no assignment (a
        # no-op). Set to 1 by analogy with the SWAG branch above.
        # NOTE(review): confirm the intended value.
        hyperparams['swag_burn_in_epochs'] = 1
    if inference_method == 'SGHMC' or inference_method == 'SGLD':
        hyperparams['burn_in_epochs'] = 0
    if inference_method == 'cSGHMC' or inference_method == 'cSGLD':
        hyperparams['burn_in_epochs'] = 0
        hyperparams['num_cycles'] = 1
        hyperparams['num_samples_per_cycle'] = S
    if inference_method == 'MCdropout' or inference_method == 'SGD':
        hyperparams['epochs'] = 0

    hyperparams['num_samples'] = S  # was 30

    inference_scheme = getattr(inference, inference_method)

    t_tensor = torch.zeros(T)
    for t in range(T):
        print('Trial: ', t)
        model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs).to(args.device)
        inference_object = inference_scheme(hyperparameters=hyperparams, model=model, train_loader=train_loader,
                                            device=args.device)

        # Suppress the sampler's own console output while timing.
        silent_inference_method = util.silent(inference_object.sample)

        t0 = time.perf_counter()
        model_ensemble = silent_inference_method()
        t1 = time.perf_counter()
        t_tensor[t] = t1 - t0

    timer_dic[inference_method + '_mean'] = t_tensor.mean()
    timer_dic[inference_method + '_std'] = t_tensor.std()

    print('Time: ', t_tensor.mean(), ' +- ', t_tensor.std())

timer_dic = util.make_dic_json_format(timer_dic)

with open(args.save_path + '.json', 'w') as fout:
    json.dump(timer_dic, fout)
class SGHMC(_Inference):
    """Stochastic-Gradient Hamiltonian Monte Carlo posterior sampler."""

    def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output',
                 device=torch.device('cpu')):
        '''
        :param hyperparameters: Hyperparameters include {'lr', 'prior_std', 'num_samples', 'alpha', 'burn_in_epochs'}
        :param model: Pytorch model to run SGHMC on.
        :param train_loader: DataLoader for train data
        :param model_loss: Loss function to use for the model. (e.g.: 'multi_class_linear_output')
        :param device: Device on which model is present (e.g.: torch.device('cpu'))
        '''
        if hyperparameters == None:
            # Initialise as some default values
            hyperparameters = {'lr': 0.001, 'prior_std': 10, 'num_samples': 2, 'alpha': 0.1, 'burn_in_epochs': 10}

        super(SGHMC, self).__init__(hyperparameters, model, train_loader, device)
        self.lr = hyperparameters['lr']
        self.prior_std = hyperparameters['prior_std']
        self.num_samples = hyperparameters['num_samples']
        self.alpha = hyperparameters['alpha']  # friction; SGD momentum = 1 - alpha
        self.burn_in_epochs = hyperparameters['burn_in_epochs']
        self.model_loss = model_loss
        self.model = model
        self.train_loader = train_loader
        self.device = device
        self.dataset_size = len(train_loader.dataset)
        # Gaussian prior with std `prior_std` enters as L2 weight decay.
        self.optimizer = optimSGHMC(params=self.model.parameters(), lr=self.lr, momentum=1 - self.alpha,
                                    num_training_samples=self.dataset_size, weight_decay=1 / (self.prior_std ** 2))
        self.loss_criterion = get_loss_criterion(loss=model_loss)
        self.burnt_in = False
        self.epochs_run = 0
        self.lr_final = self.lr / 2
        # BUG FIX: `self.lr_final` was computed but never passed, so the
        # schedule annealed to 0 after construction yet to lr/2 after
        # update_hyp. Pass eta_min here for consistency with update_hyp.
        self.optimizer_scheduler = CosineAnnealingLR(optimizer=self.optimizer,
                                                     T_max=self.burn_in_epochs + self.num_samples,
                                                     eta_min=self.lr_final)

    def update_hyp(self, hyperparameters):
        """Reset model, optimizer and LR schedule for a new hyperparameter setting."""
        self.lr = hyperparameters['lr']
        self.prior_std = hyperparameters['prior_std']
        self.num_samples = hyperparameters['num_samples']
        self.alpha = hyperparameters['alpha']
        self.burn_in_epochs = hyperparameters['burn_in_epochs']
        self.model = reset_model(self.model)
        self.burnt_in = False
        self.epochs_run = 0
        self.optimizer = optimSGHMC(params=self.model.parameters(), lr=self.lr, momentum=1 - self.alpha,
                                    num_training_samples=self.dataset_size, weight_decay=1 / (self.prior_std ** 2))
        self.lr_final = self.lr / 2
        self.optimizer_scheduler = CosineAnnealingLR(optimizer=self.optimizer,
                                                     T_max=self.burn_in_epochs + self.num_samples,
                                                     eta_min=self.lr_final)

    def sample_iterative(self, val_loader=None, debug_val_loss=False, wandb_debug=False):
        """Run burn-in on the first call, then one epoch per call; return a CPU snapshot of the model."""
        if issubclass(self.model.__class__, torch.nn.Module):
            if self.burnt_in is False:
                epochs = self.burn_in_epochs + 1
                self.burnt_in = True
            else:
                epochs = 1
            for epoch in range(epochs):
                self.model.train()
                total_epoch_train_loss = 0.
                for batch_idx, (batch_data, batch_labels) in enumerate(self.train_loader):
                    batch_data = batch_data.to(self.device)
                    batch_labels = batch_labels.to(self.device)
                    batch_data_logits = self.model(batch_data)
                    self.optimizer.zero_grad()
                    loss = self.loss_criterion(batch_data_logits, batch_labels)
                    loss.backward()
                    total_epoch_train_loss += loss.item() * len(batch_data)
                    # NOTE(review): `self.burnt_in` is set True before this
                    # loop, so the `or self.burnt_in` clause makes Langevin
                    # noise unconditional and the `epoch > 0.8 * epochs` test
                    # is dead code -- kept as found, confirm intent.
                    if epoch > 0.8 * epochs or self.burnt_in:
                        self.optimizer.step(add_langevin_noise=True)
                    else:
                        self.optimizer.step(add_langevin_noise=False)
                self.optimizer_scheduler.step()
                if debug_val_loss:
                    avg_val_loss = self.compute_val_loss(val_loader)
                    avg_train_loss = total_epoch_train_loss / self.dataset_size
                    metrics = {
                        'train_loss': avg_train_loss,
                        'val_loss': avg_val_loss,
                        'lr': self.optimizer_scheduler.get_lr()
                    }
                    print(metrics)
                    if wandb_debug:
                        wandb.log(metrics)
            # Snapshot on CPU so the returned sample is decoupled from
            # further training of the live model.
            output_model = deepcopy(self.model.cpu())
            self.model.to(self.device)
            return output_model
        else:
            raise NotImplementedError

    def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False):
        """Return a list of posterior sample models (one per sampler epoch)."""
        output_list = []
        if num_samples is None:
            num_samples = self.num_samples
        if issubclass(self.model.__class__, torch.nn.Module):
            for i in range(num_samples):
                output_list.append(self.sample_iterative(val_loader=val_loader, debug_val_loss=debug_val_loss,
                                                         wandb_debug=wandb_debug))
            return output_list
        else:
            raise NotImplementedError
import get_loss_criterion, reset_model 10 | from .inference_base import _Inference 11 | 12 | 13 | def change_to_dropout_model(model, dropout): 14 | signature = inspect.signature(model.__init__) 15 | kwargs = {} 16 | for key in signature.parameters.keys(): 17 | kwargs[key] = getattr(model, key) 18 | 19 | name = model.__class__.__name__ + '_dropout' 20 | model_cfg = getattr(models, name) 21 | model = model_cfg(dropout = 0.2, **kwargs) 22 | return model 23 | 24 | 25 | class MCdropout(_Inference): 26 | 27 | def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output', 28 | device=torch.device('cpu')): 29 | ''' 30 | :param hyperparameters: Hyperparameters include {'lr', 'prior_std', 'num_samples'} 31 | :param model: Pytorch model to run SGLD on. 32 | :param train_loader: DataLoader for train data 33 | :param model_loss: Loss function to use for the model. (e.g.: 'multi_class_linear_output') 34 | :param device: Device on which model is present (e.g.: torch.device('cpu')) 35 | ''' 36 | if hyperparameters == None: 37 | # Initialise as some default values 38 | hyperparameters = {'lr': 0.1, 'epochs':10, 'dropout': 0.2, 'lengthscale': 0.01, 'num_samples': 10, 'momentum': 0.9, 'weight_decay': 0} 39 | 40 | super(MCdropout, self).__init__(hyperparameters, model, train_loader, device) 41 | self.lr = hyperparameters['lr'] 42 | self.num_samples = hyperparameters['num_samples'] 43 | self.burn_in_epochs = hyperparameters['epochs'] 44 | self.dropout = hyperparameters['dropout'] 45 | self.momentum = hyperparameters['momentum'] 46 | self.model = change_to_dropout_model(model, self.dropout).to(device) 47 | 48 | self.train_loader = train_loader 49 | self.device = device 50 | self.dataset_size = len(train_loader.dataset) 51 | 52 | if hyperparameters['weight_decay'] != 0: 53 | self.weight_decay = hyperparameters['weight_decay'] 54 | else: 55 | self.weight_decay = hyperparameters['lengthscale'] ** 2 * (1 - self.dropout) / (2. 
* self.dataset_size) 56 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum, 57 | weight_decay=self.weight_decay) 58 | self.loss_criterion = get_loss_criterion(loss=model_loss) 59 | self.burnt_in = False 60 | self.epochs_run = 0 61 | self.lr_final = self.lr / 100. 62 | # self.optimizer_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max= 63 | # self.burn_in_epochs + self.num_samples, eta_min=self.lr_final) 64 | self.optimizer_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr * 5, 65 | steps_per_epoch=len(self.train_loader), 66 | epochs=self.burn_in_epochs + self.num_samples) 67 | 68 | def update_hyp(self, hyperparameters): 69 | self.lr = hyperparameters['lr'] 70 | self.num_samples = hyperparameters['num_samples'] 71 | self.epochs = hyperparameters['epochs'] 72 | self.dropout = hyperparameters['dropout'] 73 | self.momentum = hyperparameters['momentum'] 74 | if hyperparameters['weight_decay'] != 0: 75 | self.weight_decay = hyperparameters['weight_decay'] 76 | else: 77 | self.weight_decay = hyperparameters['lengthscale'] ** 2 * (1 - self.dropout) / (2. * self.dataset_size) 78 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum, 79 | weight_decay=self.weight_decay) 80 | self.model = reset_model(self.model).to(self.device) 81 | self.burnt_in = False 82 | self.epochs_run = 0 83 | self.lr_final = self.lr / 2 84 | self.optimizer_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max= 85 | self.burn_in_epochs + self.num_samples, eta_min=self.lr_final) 86 | 87 | def sample_iterative(self, val_loader=None, debug_val_loss=False, wandb_debug=False): 88 | if issubclass(self.model.__class__, torch.nn.Module): 89 | if self.burnt_in is False: 90 | epochs = self.burn_in_epochs + 1 91 | self.burnt_in = True 92 | else: 93 | epochs = 1 94 | for epoch in range(epochs): 95 | self.model.train() 96 | total_epoch_train_loss = 0. 
97 | for batch_idx, (batch_data, batch_labels) in enumerate(self.train_loader): 98 | batch_data = batch_data.to(self.device) 99 | batch_labels = batch_labels.to(self.device) 100 | batch_data_logits = self.model(batch_data) 101 | self.optimizer.zero_grad() 102 | loss = self.loss_criterion(batch_data_logits, batch_labels) 103 | loss.backward() 104 | total_epoch_train_loss += loss.item() * len(batch_data) 105 | self.optimizer.step() 106 | self.optimizer_scheduler.step() 107 | if debug_val_loss: 108 | avg_val_loss = self.compute_val_loss(val_loader) 109 | avg_train_loss = total_epoch_train_loss / self.dataset_size 110 | metrics = { 111 | 'train_loss': avg_train_loss, 112 | 'val_loss': avg_val_loss 113 | } 114 | print(metrics) 115 | if wandb_debug: 116 | wandb.log(metrics) 117 | return self.model 118 | else: 119 | raise NotImplementedError 120 | 121 | def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False): 122 | output_list = [] 123 | if num_samples is None: 124 | num_samples = self.num_samples 125 | if issubclass(self.model.__class__, torch.nn.Module): 126 | for i in range(num_samples): 127 | output_list.append(self.sample_iterative(val_loader=val_loader, debug_val_loss=debug_val_loss, 128 | wandb_debug=wandb_debug)) 129 | return output_list 130 | else: 131 | raise NotImplementedError 132 | -------------------------------------------------------------------------------- /URSABench/inference/csghmc.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import torch 5 | import wandb 6 | 7 | from URSABench.util import get_loss_criterion, reset_model 8 | from .inference_base import _Inference 9 | from .optim_sghmc import optimSGHMC 10 | 11 | 12 | # TODO: Add docstrings for classes below. 
class cSGHMC(_Inference):
    """Cyclical SGHMC sampler: the learning rate follows a cosine schedule
    within each cycle; early-cycle steps run noise-free (exploration) and the
    last `burn_in_epochs + num_samples_per_cycle` epochs of each cycle inject
    Langevin noise, with samples collected from the final
    `num_samples_per_cycle` epochs."""

    def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output',
                 device=torch.device('cpu')):
        """
        :param hyperparameters: {'lr_0', 'prior_std', 'num_samples_per_cycle',
            'cycle_length', 'burn_in_epochs', 'num_cycles', 'alpha'}
        :param model: PyTorch model to run cSGHMC on.
        :param train_loader: DataLoader for train data.
        :param model_loss: loss name understood by get_loss_criterion.
        :param device: device the model lives on.
        """
        if hyperparameters == None:
            # Initialise as some default values
            hyperparameters = {'lr_0': 0.001000, 'prior_std': 10.1000, 'num_samples_per_cycle': 5,
                               'cycle_length': 20, 'burn_in_epochs': 5, 'num_cycles': 10, 'alpha': 1., }

        super(cSGHMC, self).__init__(hyperparameters, model, train_loader, device)
        self.lr_0 = hyperparameters['lr_0']
        self.prior_std = hyperparameters['prior_std']
        self.num_samples_per_cycle = hyperparameters['num_samples_per_cycle']
        self.cycle_length = hyperparameters['cycle_length']
        self.alpha = hyperparameters['alpha']
        self.burn_in_epochs = hyperparameters['burn_in_epochs']
        self.num_cycles = hyperparameters['num_cycles']
        self.batch_size = train_loader.batch_size
        # Bug fix: this was computed with true division (len/batch_size + 1),
        # leaving num_batch -- and every quantity derived from it, such as
        # total_iterations and the cyclical-LR counters -- a float. Use the
        # exact integer number of batches per epoch (ceiling division).
        self.num_batch = max(1, (len(train_loader.dataset) + self.batch_size - 1) // self.batch_size)
        self.model_loss = model_loss
        self.model = model
        self.train_loader = train_loader
        self.device = device
        self.dataset_size = len(train_loader.dataset)
        # momentum = 1 - alpha (friction), weight_decay = 1/prior_std^2 (prior).
        self.optimizer = optimSGHMC(params=self.model.parameters(), lr=self.lr_0, momentum=1 - self.alpha,
                                    num_training_samples=self.dataset_size,
                                    weight_decay=1 / (self.prior_std ** 2))
        self.loss_criterion = get_loss_criterion(loss=model_loss)
        self.burnt_in = False
        self.epochs_run = 0
        self.total_epochs = self.cycle_length * self.num_cycles
        self.dataloader_batch_size = self.train_loader.batch_size
        self.total_iterations = self.total_epochs * self.num_batch

        # Each cycle must contain noise-free epochs before the sampling phase.
        assert ((self.cycle_length - self.burn_in_epochs - self.num_samples_per_cycle) > 0)

    def update_hyp(self, hyperparameters):
        """Re-initialise the sampler for a new hyperparameter setting: reset
        the weights and rebuild the optimizer around the fresh parameters."""
        self.lr_0 = hyperparameters['lr_0']
        self.prior_std = hyperparameters['prior_std']
        self.num_samples_per_cycle = hyperparameters['num_samples_per_cycle']
        self.cycle_length = hyperparameters['cycle_length']
        self.alpha = hyperparameters['alpha']
        self.burn_in_epochs = hyperparameters['burn_in_epochs']
        self.num_cycles = hyperparameters['num_cycles']
        self.model = reset_model(self.model)
        self.optimizer = optimSGHMC(params=self.model.parameters(), lr=self.lr_0, momentum=1 - self.alpha,
                                    num_training_samples=self.dataset_size,
                                    weight_decay=1 / (self.prior_std ** 2))
        self.burnt_in = False
        self.epochs_run = 0

        assert ((self.cycle_length - self.burn_in_epochs - self.num_samples_per_cycle) > 0)

    def _adjust_learning_rate(self, optimizer, epoch, batch_idx):
        """Cyclical cosine schedule: lr goes from lr_0 down to 0 within each
        cycle of total_iterations // num_cycles optimizer steps.

        :return: the learning rate written into every param group.
        """
        rcounter = epoch * self.num_batch + batch_idx
        cos_inner = np.pi * (rcounter % (self.total_iterations // self.num_cycles))
        cos_inner /= self.total_iterations // self.num_cycles
        cos_out = np.cos(cos_inner) + 1
        lr = 0.5 * cos_out * self.lr_0
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def sample_iterative(self, val_loader=None, debug_val_loss=False, wandb_debug=False):
        """Run epochs until the current cycle reaches its sampling phase, then
        return a CPU copy of the model as one posterior sample.

        :raises NotImplementedError: if self.model is not an nn.Module.
        """
        if not issubclass(self.model.__class__, torch.nn.Module):
            raise NotImplementedError
        sample_collected = False
        while sample_collected is False:
            self.model.train()
            total_epoch_train_loss = 0.
            for batch_idx, (batch_data, batch_labels) in enumerate(self.train_loader):
                self.lr = self._adjust_learning_rate(self.optimizer, self.epochs_run, batch_idx, )
                batch_data = batch_data.to(self.device)
                batch_labels = batch_labels.to(self.device)
                batch_data_logits = self.model(batch_data)
                self.optimizer.zero_grad()
                loss = self.loss_criterion(batch_data_logits, batch_labels)
                loss.backward()
                total_epoch_train_loss += loss.item() * len(batch_data)
                # Langevin noise only in the tail of the cycle (burn-in +
                # sampling epochs); the exploration phase is noise-free.
                if (self.epochs_run % self.cycle_length) + 1 > (self.cycle_length - self.burn_in_epochs
                                                                - self.num_samples_per_cycle):
                    self.optimizer.step(add_langevin_noise=True)
                else:
                    self.optimizer.step(add_langevin_noise=False)
            self.epochs_run += 1
            print('Epoch: ', self.epochs_run, ' lr: ', self.lr)
            if debug_val_loss:
                avg_val_loss = self.compute_val_loss(val_loader)
                avg_train_loss = total_epoch_train_loss / self.dataset_size
                metrics = {
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss
                }
                print(metrics)
                if wandb_debug:
                    wandb.log(metrics)
            # A sample is collected once the just-finished epoch falls in the
            # final num_samples_per_cycle epochs of its cycle.
            if ((self.epochs_run - 1) % self.cycle_length) >= (self.cycle_length - self.num_samples_per_cycle):
                sample_collected = True
        output_model = deepcopy(self.model.cpu())
        self.model.to(self.device)
        return output_model

    def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False):
        """Collect posterior samples (defaults to num_samples_per_cycle * num_cycles)."""
        if not issubclass(self.model.__class__, torch.nn.Module):
            raise NotImplementedError
        if num_samples is None:
            num_samples = self.num_samples_per_cycle * self.num_cycles
        return [self.sample_iterative(val_loader=val_loader, debug_val_loss=debug_val_loss,
                                      wandb_debug=wandb_debug)
                for _ in range(num_samples)]
-------------------------------------------------------------------------------- /URSABench/models/imagenet_resnet.py: --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torchvision.transforms import transforms

__all__ = ['INResNet18', 'INResNet34', 'INResNet50', 'INResNet101',
           'INResNet152', 'ResNet_dropout']


def _weights_init(m):
    """Kaiming-normal init for conv/linear weights; applied via nn.Module.apply()."""
    # NOTE(review): classname is computed but never used.
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class BasicBlock(nn.Module):
    """Standard two-conv residual block (3x3 + 3x3) with identity/projection shortcut."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 projection shortcut when the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 reduce, 3x3, 1x1 expand by `expansion`)."""
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        # 1x1 projection shortcut when the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem and four stages.

    NOTE(review): despite the 'imagenet' filename, the stem (3x3, stride 1) and
    the 4x average pool match 32x32 CIFAR-style inputs -- confirm intended
    input resolution against the callers.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        # __init__ args are stored as attributes so change_to_dropout_model()
        # elsewhere in the repo can clone the configuration via inspect.
        self.block = block
        self.layers = layers
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block of a stage may downsample; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


class ResNet_dropout(nn.Module):
    """ResNet variant with dropout before the final linear layer.

    Used by the MC-dropout sampler; mirrors ResNet except for the extra
    `dropout` argument and the F.dropout call in forward().
    """

    def __init__(self, block, layers, num_classes=10, dropout=0.2):
        super(ResNet_dropout, self).__init__()
        self.in_planes = 64
        self.block = block
        self.layers = layers
        self.num_classes = num_classes
        self.dropout = dropout

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        # F.dropout defaults to training=True, so dropout stays active in
        # eval() mode as well -- presumably intentional for MC-dropout
        # sampling at prediction time; confirm before changing.
        out = self.linear(F.dropout(out, p=self.dropout))
        return out


class Base:
    """Model config record: architecture class + kwargs + train/test transforms.

    Subclasses below only override `kwargs` to pick the depth/block type.
    """
    base = ResNet
    args = list()
    kwargs = dict()
    # NOTE(review): 32x32 crops with CIFAR normalisation constants, despite the
    # 'imagenet' module name -- verify against the dataset pipeline.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])


class INResNet18(Base):
    kwargs = {'block': BasicBlock, 'layers': [2, 2, 2, 2]}


class INResNet34(Base):
    kwargs = {'block': BasicBlock, 'layers': [3, 4, 6, 3]}


class INResNet50(Base):
    kwargs = {'block': Bottleneck, 'layers': [3, 4, 6, 3]}


class INResNet101(Base):
    kwargs = {'block': Bottleneck, 'layers': [3, 4, 23, 3]}


class INResNet152(Base):
    kwargs = {'block': Bottleneck, 'layers': [3, 8, 36, 3]}
-------------------------------------------------------------------------------- /URSABench/tasks/ood_detection_distilled.py: --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from .task_base import _Task
from ..util import central_smoothing, compute_predictive_entropy

__all__ = ['OODDetectionDistilled']
# TODO: Add docstrings.
11 | class OODDetectionDistilled(_Task): 12 | def __init__(self, data_loader=None, num_classes=None, device=torch.device('cpu')): 13 | super(OODDetectionDistilled, self).__init__(data_loader, num_classes, device) 14 | self.in_distribution_loader = data_loader['in_distribution_test'] 15 | self.out_distribution_loader = data_loader['out_distribution_test'] 16 | self.num_classes = num_classes 17 | self.device = device 18 | self.in_distribution_ensemble_proba = torch.zeros(len(self.in_distribution_loader.dataset), num_classes) 19 | self.out_distribution_ensemble_proba = torch.zeros(len(self.out_distribution_loader.dataset), num_classes) 20 | self.in_distribution_data_uncertainty = torch.zeros(len(self.in_distribution_loader.dataset)) 21 | self.out_distribution_data_uncertainty = torch.zeros(len(self.out_distribution_loader.dataset)) 22 | self.in_distribution_total_uncertainty = None 23 | self.out_distribution_total_uncertainty = None 24 | self.in_distribution_model_uncertainty = None 25 | self.out_distribution_model_uncertainty = None 26 | self.num_samples_collected = 0 27 | 28 | def reset(self): 29 | self.in_distribution_ensemble_proba = torch.zeros(len(self.in_distribution_loader.dataset), self.num_classes) 30 | self.out_distribution_ensemble_proba = torch.zeros(len(self.out_distribution_loader.dataset), self.num_classes) 31 | self.in_distribution_data_uncertainty = torch.zeros(len(self.in_distribution_loader.dataset)) 32 | self.out_distribution_data_uncertainty = torch.zeros(len(self.out_distribution_loader.dataset)) 33 | self.in_distribution_total_uncertainty = None 34 | self.out_distribution_total_uncertainty = None 35 | self.in_distribution_model_uncertainty = None 36 | self.out_distribution_model_uncertainty = None 37 | self.num_samples_collected = 0 38 | 39 | def update_statistics(self, models, output_performance=True): 40 | if isinstance(models, list): 41 | if all(issubclass(model.__class__, torch.nn.Module) for model in models): 42 | 
self.num_samples_collected += 1 43 | else: 44 | raise NotImplementedError 45 | 46 | with torch.no_grad(): 47 | start_idx = 0 48 | for batch_idx, (batch_data, batch_labels) in enumerate(self.in_distribution_loader): 49 | end_idx = start_idx + len(batch_data) 50 | batch_data = batch_data.to(self.device) 51 | if isinstance(models, list): 52 | prediction_model = models[0] 53 | expected_data_uncertainty_model = models[1] 54 | # prediction_model.to(self.device) 55 | # expected_data_uncertainty_model.to(self.device) 56 | prediction_model.eval() 57 | expected_data_uncertainty_model.eval() 58 | batch_logits = prediction_model(batch_data) 59 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 60 | entropy = expected_data_uncertainty_model(batch_data).exp().cpu() 61 | self.in_distribution_ensemble_proba[start_idx: end_idx] += smoothened_proba 62 | self.in_distribution_data_uncertainty[start_idx: end_idx] += entropy.squeeze() 63 | # model.to('cpu') 64 | else: 65 | raise Exception("Need exactly two models here") 66 | start_idx = end_idx 67 | 68 | start_idx = 0 69 | for batch_idx, (batch_data, batch_labels) in enumerate(self.out_distribution_loader): 70 | end_idx = start_idx + len(batch_data) 71 | batch_data = batch_data.to(self.device) 72 | if isinstance(models, list): 73 | 74 | prediction_model = models[0] 75 | expected_data_uncertainty_model = models[1] 76 | # prediction_model.to(self.device) 77 | # expected_data_uncertainty_model.to(self.device) 78 | prediction_model.eval() 79 | expected_data_uncertainty_model.eval() 80 | batch_logits = prediction_model(batch_data) 81 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 82 | entropy = expected_data_uncertainty_model(batch_data).exp().cpu() 83 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 84 | self.out_distribution_ensemble_proba[start_idx: end_idx] += smoothened_proba 85 | 
self.out_distribution_data_uncertainty[start_idx: end_idx] += compute_predictive_entropy( 86 | smoothened_proba) 87 | else: 88 | raise Exception("Need exactly two models here") 89 | start_idx = end_idx 90 | if output_performance: 91 | return self.get_performance_metrics() 92 | 93 | def get_performance_metrics(self): 94 | self.in_distribution_total_uncertainty = compute_predictive_entropy( 95 | self.in_distribution_ensemble_proba / self.num_samples_collected 96 | ) 97 | self.out_distribution_total_uncertainty = compute_predictive_entropy( 98 | self.out_distribution_ensemble_proba / self.num_samples_collected 99 | ) 100 | self.in_distribution_model_uncertainty = self.in_distribution_total_uncertainty - \ 101 | self.in_distribution_data_uncertainty / self.num_samples_collected 102 | self.out_distribution_model_uncertainty = self.out_distribution_total_uncertainty - \ 103 | self.out_distribution_data_uncertainty / self.num_samples_collected 104 | label_array = np.concatenate([np.ones(len(self.out_distribution_loader.dataset)), 105 | np.zeros(len(self.in_distribution_loader.dataset))]) 106 | total_uncertainty_array = np.concatenate([self.out_distribution_total_uncertainty.numpy(), 107 | self.in_distribution_total_uncertainty.numpy()]) 108 | model_uncertainty_array = np.concatenate([self.out_distribution_model_uncertainty.numpy(), 109 | self.in_distribution_model_uncertainty.numpy()]) 110 | total_uncertainty_auroc_score = roc_auc_score(label_array, total_uncertainty_array) 111 | model_uncertainty_auroc_score = roc_auc_score(label_array, model_uncertainty_array) 112 | 113 | return { 114 | 'total_uncertainty_auroc': total_uncertainty_auroc_score, 115 | 'model_uncertainty_auroc': model_uncertainty_auroc_score 116 | } 117 | -------------------------------------------------------------------------------- /URSABench/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | WideResNet model definition 3 | ported from 
https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py 4 | """ 5 | 6 | import math 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.init as init 11 | import torchvision.transforms as transforms 12 | 13 | __all__ = ['WideResNet28x10', 'WideResNet_dropout'] 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 18 | 19 | 20 | def conv_init(m): 21 | classname = m.__class__.__name__ 22 | if classname.find('Conv') != -1: 23 | init.xavier_uniform(m.weight, gain=math.sqrt(2)) 24 | init.constant(m.bias, 0) 25 | elif classname.find('BatchNorm') != -1: 26 | init.constant(m.weight, 1) 27 | init.constant(m.bias, 0) 28 | 29 | 30 | class WideBasic(nn.Module): 31 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 32 | super(WideBasic, self).__init__() 33 | self.bn1 = nn.BatchNorm2d(in_planes) 34 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 35 | self.dropout = nn.Dropout(p=dropout_rate) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride != 1 or in_planes != planes: 41 | self.shortcut = nn.Sequential( 42 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 43 | ) 44 | 45 | def forward(self, x): 46 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 47 | out = self.conv2(F.relu(self.bn2(out))) 48 | out += self.shortcut(x) 49 | 50 | return out 51 | 52 | class WideBasic_dropout(nn.Module): 53 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 54 | super(WideBasic_dropout, self).__init__() 55 | self.dropout = dropout_rate 56 | 57 | self.bn1 = nn.BatchNorm2d(in_planes) 58 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 59 | # self.dropout = 
nn.Dropout(p=dropout_rate) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 62 | 63 | self.shortcut = nn.Sequential() 64 | if stride != 1 or in_planes != planes: 65 | self.shortcut = nn.Sequential( 66 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 67 | ) 68 | 69 | def forward(self, x): 70 | out = self.conv1(F.relu(self.bn1(x))) 71 | out = F.dropout(out, p=self.dropout) 72 | out = self.conv2(F.relu(self.bn2(out))) 73 | out += self.shortcut(x) 74 | 75 | return out 76 | 77 | 78 | class WideResNet(nn.Module): 79 | def __init__(self, num_classes=10, depth=28, widen_factor=10, dropout_rate=0.): 80 | super(WideResNet, self).__init__() 81 | self.in_planes = 16 82 | self.num_classes = num_classes 83 | self.depth = depth 84 | self.widen_factor = widen_factor 85 | self.dropout_rate = dropout_rate 86 | 87 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 88 | n = (depth - 4) / 6 89 | k = widen_factor 90 | 91 | nstages = [16, 16 * k, 32 * k, 64 * k] 92 | 93 | self.conv1 = conv3x3(3, nstages[0]) 94 | self.layer1 = self._wide_layer(WideBasic, nstages[1], n, dropout_rate, stride=1) 95 | self.layer2 = self._wide_layer(WideBasic, nstages[2], n, dropout_rate, stride=2) 96 | self.layer3 = self._wide_layer(WideBasic, nstages[3], n, dropout_rate, stride=2) 97 | self.bn1 = nn.BatchNorm2d(nstages[3], momentum=0.9) 98 | self.linear = nn.Linear(nstages[3], num_classes) 99 | 100 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 101 | strides = [stride] + [1] * int(num_blocks - 1) 102 | layers = [] 103 | 104 | for stride in strides: 105 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 106 | self.in_planes = planes 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | out = self.conv1(x) 112 | out = self.layer1(out) 113 | out = self.layer2(out) 114 | out = self.layer3(out) 115 | out = 
F.relu(self.bn1(out)) 116 | out = F.avg_pool2d(out, 8) 117 | out = out.view(out.size(0), -1) 118 | out = self.linear(out) 119 | 120 | return out 121 | 122 | class WideResNet_dropout(nn.Module): 123 | def __init__(self, num_classes=10, depth=28, widen_factor=10, dropout_rate=0., dropout = 0.1): 124 | super(WideResNet_dropout, self).__init__() 125 | self.in_planes = 16 126 | self.num_classes = num_classes 127 | self.depth = depth 128 | self.widen_factor = widen_factor 129 | self.dropout_rate = dropout_rate 130 | self.dropout = dropout 131 | 132 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 133 | n = (depth - 4) / 6 134 | k = widen_factor 135 | 136 | nstages = [16, 16 * k, 32 * k, 64 * k] 137 | 138 | self.conv1 = conv3x3(3, nstages[0]) 139 | self.layer1 = self._wide_layer(WideBasic_dropout, nstages[1], n, dropout, stride=1) 140 | self.layer2 = self._wide_layer(WideBasic_dropout, nstages[2], n, dropout, stride=2) 141 | self.layer3 = self._wide_layer(WideBasic_dropout, nstages[3], n, dropout, stride=2) 142 | self.bn1 = nn.BatchNorm2d(nstages[3], momentum=0.9) 143 | self.linear = nn.Linear(nstages[3], num_classes) 144 | 145 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 146 | strides = [stride] + [1] * int(num_blocks - 1) 147 | layers = [] 148 | 149 | for stride in strides: 150 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 151 | self.in_planes = planes 152 | 153 | return nn.Sequential(*layers) 154 | 155 | def forward(self, x): 156 | out = self.conv1(x) 157 | out = self.layer1(out) 158 | out = self.layer2(out) 159 | out = self.layer3(out) 160 | out = F.relu(self.bn1(out)) 161 | out = F.avg_pool2d(out, 8) 162 | out = out.view(out.size(0), -1) 163 | out = self.linear(F.dropout(out, p=self.dropout)) 164 | 165 | return out 166 | 167 | class WideResNet28x10: 168 | base = WideResNet 169 | args = list() 170 | kwargs = {'depth': 28, 'widen_factor': 10} 171 | transform_train = transforms.Compose([ 172 | 
transforms.Resize(32), 173 | transforms.RandomCrop(32, padding=4), 174 | transforms.RandomHorizontalFlip(), 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 177 | ]) 178 | transform_test = transforms.Compose([ 179 | transforms.Resize(32), 180 | transforms.ToTensor(), 181 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 182 | ]) 183 | -------------------------------------------------------------------------------- /URSABench/tasks/decision_making.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision 4 | 5 | from ..util import central_smoothing 6 | 7 | from .task_base import _Task 8 | 9 | __all__ = ['Decision'] 10 | 11 | 12 | def MNIST_cost(num_classes): 13 | I =torch.eye(num_classes) 14 | digits = [3,7] 15 | C = I.clone() 16 | C[torch.where(I==0)] = 0.1 17 | C[digits] = 100.0 # Select more important rows with high cost in error of decision 18 | C[torch.where(I==1)] = 0 19 | return C 20 | 21 | def CIFAR10_cost(num_classes): 22 | I =torch.eye(num_classes) 23 | digits = [0,1,8,9] # Plane, automobile, ship, truck 24 | C = I.clone() 25 | C[torch.where(I==0)] = 0.1 26 | C[digits] = 1.0 # Select more important rows with high cost in error of decision 27 | C[torch.where(I==1)] = 0 28 | return C 29 | 30 | coarse_label = ['apple', # id 0 31 | 'aquarium_fish', 'baby', 'bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 32 | 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 33 | 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 34 | 
'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
                'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
                'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']


def CIFAR100_cost(num_classes):
    """Cost matrix for CIFAR-100: classes named 'tank', 'rocket' and
    'pickup_truck' (looked up by name in coarse_label) carry error cost 1.0,
    other errors 0.1, correct decisions 0."""
    labels = ['tank', 'rocket', 'pickup_truck']
    digits = []
    # Translate class names into label indices via the coarse_label table.
    for i, l in enumerate(coarse_label):
        for c in labels:
            if c == l:
                digits.append(i)
    I = torch.eye(num_classes)
    C = I.clone()
    C[torch.where(I == 0)] = 0.1
    C[digits] = 1.0  # Select more important rows with high cost in error of decision
    C[torch.where(I == 1)] = 0
    return C
# print(np.array2string(L.numpy()))

# def decision(y_pred=None, cost_mat=None):
#     """
#     torch function to calculate cost from the cost matrix:
#     Inputs:
#         y_pred: predicted values (S,N,D) (samples)
#         loss_mat: matrix of loss values for selecting outputs (D,D)
#     """
#     S,N,D = y_pred.shape
#     A = torch.matmul(y_pred,cost_mat).mean(0) #Integration over samples
#     D = A.argmin(1)
#     return D


def decision_cost(D, y_true, cost_mat=None):
    """
    torch function to calculate cost from the cost matrix:
    Inputs:
        y_true: true values (N,)
        D: Decisions torch tensor of integers (N,)
        loss_mat: matrix of loss values for selecting outputs (D,D)
    """
    # Sum of per-example costs: cost_mat[true_label, decision].
    return cost_mat[y_true, D].sum()


class Decision(_Task):
    """Cost-sensitive decision-making task: accumulate ensemble probabilities
    and Bayes risk over a test loader, then pick the minimum-risk decision."""

    def __init__(self, dataloader, num_classes, device):
        """
        :param dataloader: dict with a 'decision_data_test' DataLoader.
        :param num_classes: number of classes (also the decision space size).
        :param device: device to run the models on.
        :raises NotImplementedError: if the dataset is not MNIST/CIFAR-10/CIFAR-100.
        """
        super(Decision, self).__init__(dataloader, num_classes, device)
        self.data_loader = dataloader['decision_data_test']
        self.num_classes = num_classes
        self.device = device
        self.num_samples_collected = 0
        self.ensemble_proba = torch.zeros(len(self.data_loader.dataset), num_classes)
        self.risk = torch.zeros(len(self.data_loader.dataset), self.num_classes)
        # Materialise targets once, in loader order, for decision_cost().
        # NOTE(review): assumes the loader is NOT shuffled, so this order
        # matches the positional indexing used in update_statistics -- confirm.
        self.targets = list()
        for batch_idx, (batch_data, batch_labels) in enumerate(self.data_loader):
            self.targets.append(batch_labels)
        self.targets = torch.cat(self.targets)

        # Pick the cost matrix from the concrete torchvision dataset class.
        if self.data_loader.dataset.__class__ is torchvision.datasets.mnist.MNIST:
            self.cost_mat = MNIST_cost(self.num_classes)
        elif self.data_loader.dataset.__class__ is torchvision.datasets.cifar.CIFAR10:
            self.cost_mat = CIFAR10_cost(self.num_classes)
        elif self.data_loader.dataset.__class__ is torchvision.datasets.cifar.CIFAR100:
            self.cost_mat = CIFAR100_cost(self.num_classes)
        else:
            raise NotImplementedError

    def reset(self):
        """Zero the accumulated statistics so a new ensemble can be evaluated."""
        self.num_samples_collected = 0
        self.ensemble_proba = torch.zeros(len(self.data_loader.dataset), self.num_classes)
        self.risk = torch.zeros(len(self.data_loader.dataset), self.num_classes)
    def update_statistics(self, models, output_performance=True, smoothing = True):
        """Accumulate smoothed probabilities and Bayes risk for one model or a
        list of models (each list entry counts as one collected sample).

        :param models: an nn.Module or a list of nn.Modules.
        :param output_performance: if True, return get_performance_metrics().
        :param smoothing: forwarded to get_performance_metrics (currently unused there).
        :raises NotImplementedError: if (any) model is not an nn.Module.
        """
        if isinstance(models, list):
            if all(issubclass(model.__class__, torch.nn.Module) for model in models):
                num_models = len(models)
                self.num_samples_collected += num_models
            else:
                raise NotImplementedError
        else:
            if issubclass(models.__class__, torch.nn.Module):
                self.num_samples_collected += 1
            else:
                raise NotImplementedError

        with torch.no_grad():
            start_idx = 0
            for batch_idx, (batch_data, batch_labels) in enumerate(self.data_loader):
                end_idx = start_idx + len(batch_data)
                batch_data = batch_data.to(self.device)
                if isinstance(models, list):
                    for model_idx, model in enumerate(models):
                        # Models are shuttled to the device per batch and back
                        # to CPU afterwards to bound GPU memory usage.
                        model.to(self.device)
                        model.eval()
                        batch_logits = model(batch_data)
                        proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu())
                        self.ensemble_proba[start_idx: end_idx] += proba
                        # Per-decision expected cost: proba @ cost_mat.
                        self.risk[start_idx: end_idx] += torch.matmul(proba, self.cost_mat)
                        model.to('cpu')
                else:
                    ## Here models indicates a single model.
                    models.to(self.device)
                    models.eval()
                    batch_logits = models(batch_data)
                    proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu())
                    self.ensemble_proba[start_idx: end_idx] += proba
                    self.risk[start_idx: end_idx] += torch.matmul(proba, self.cost_mat)
                    models.to('cpu')
                start_idx = end_idx
        if output_performance:
            return self.get_performance_metrics(output_performance, smoothing)

    def get_performance_metrics(self, output_performance=False, smoothing = True):
        """Return the Bayes-optimal decisions, their realised cost on the true
        labels, and the accumulated predicted risk.

        :return: dict with 'True_Cost' (scalar tensor), 'Decision' (N,) and
            'Pred_cost' (N, num_classes) accumulated risk.
        """
        output_dict = {}
        # Bayes decision: minimise the (averaged) expected cost per example.
        D = (self.risk / self.num_samples_collected).argmin(1)
        cost = decision_cost(D, self.targets, self.cost_mat)

        output_dict['True_Cost'] = cost
        output_dict['Decision'] = D
        output_dict['Pred_cost'] = self.risk
        return output_dict
-------------------------------------------------------------------------------- /URSABench/inference/subspaces.py: --------------------------------------------------------------------------------
"""
subspace classes
CovarianceSpace: covariance subspace
PCASpace: PCA subspace
FreqDirSpace: Frequent Directions Space
"""

import abc

import numpy as np
import torch
from sklearn.decomposition import TruncatedSVD
# NOTE(review): _assess_dimension_ is a private sklearn API that was removed
# in sklearn >= 0.23 -- this import fails on modern sklearn; verify pinning.
from sklearn.decomposition.pca import _assess_dimension_
from sklearn.utils.extmath import randomized_svd


class Subspace(torch.nn.Module, metaclass=abc.ABCMeta):
    # Registry of subspace implementations, populated by register_subclass.
    subclasses = {}

    @classmethod
    def register_subclass(cls, subspace_type):
        # Class decorator: registers `subclass` under the key `subspace_type`.
        def decorator(subclass):
            cls.subclasses[subspace_type] = subclass
            return subclass

        return decorator

    @classmethod
    def create(cls, subspace_type,
@Subspace.register_subclass('random')
class RandomSpace(Subspace):
    """Data-independent random projection subspace.

    A fixed (rank, num_parameters) Gaussian matrix is drawn once at
    construction; vectors offered via collect_vector are ignored.
    """

    def __init__(self, num_parameters, rank=20, method='dense'):
        assert method in ['dense', 'fastfood']
        super(RandomSpace, self).__init__()

        self.num_parameters = num_parameters
        self.rank = rank
        self.method = method

        if method == 'fastfood':
            raise NotImplementedError("FastFood transform hasn't been implemented yet")
        # 'dense' is the only other possibility after the assert above.
        self.subspace = torch.randn(rank, num_parameters)

    def collect_vector(self, vector):
        """No-op: the random subspace does not depend on collected data."""
        pass

    def get_space(self):
        """Return the fixed (rank, num_parameters) projection matrix."""
        return self.subspace
@Subspace.register_subclass('pca')
class PCASpace(CovarianceSpace):
    """PCA subspace over the deviation matrix collected by CovarianceSpace."""

    def __init__(self, num_parameters, pca_rank=20, max_rank=20):
        """
        Args:
            num_parameters: dimensionality of the flattened model weights.
            pca_rank: either the literal string 'mle' (rank chosen by
                Minka's marginal likelihood in get_space) or a fixed int rank.
            max_rank: maximum number of deviation vectors retained.
        """
        super(PCASpace, self).__init__(num_parameters, max_rank=max_rank)

        assert (pca_rank == 'mle' or isinstance(pca_rank, int))
        if pca_rank != 'mle':
            assert 1 <= pca_rank <= max_rank

        self.pca_rank = pca_rank

    def get_space(self):
        """Return the PCA basis scaled by singular values, shape (rank, num_parameters)."""
        cov_mat_sqrt_np = self.cov_mat_sqrt.clone().numpy()

        # Normalise so that DD' is the (unbiased) covariance estimate;
        # max(1, ...) guards the rank-0/1 case against division by zero.
        cov_mat_sqrt_np /= (max(1, self.rank.item() - 1)) ** 0.5

        if self.pca_rank == 'mle':
            pca_rank = self.rank.item()
        else:
            pca_rank = self.pca_rank

        pca_rank = max(1, min(pca_rank, self.rank.item()))

        # FIX: removed a TruncatedSVD(n_components=pca_rank).fit(...) call whose
        # result was never read — it duplicated the randomized_svd below and
        # cost a full extra decomposition on every call.
        _, s, Vt = randomized_svd(cov_mat_sqrt_np, n_components=pca_rank, n_iter=5)

        # perform post-selection fitting
        if self.pca_rank == 'mle':
            eigs = s ** 2.0
            ll = np.zeros(len(eigs))
            correction = np.zeros(len(eigs))

            # Minka's PCA marginal log likelihood plus a complexity correction.
            for rank in range(len(eigs)):
                # degrees of freedom of a rank-`rank` model
                m = cov_mat_sqrt_np.shape[1] * rank - rank * (rank + 1) / 2.
                correction[rank] = 0.5 * m * np.log(cov_mat_sqrt_np.shape[0])
                # NOTE(review): _assess_dimension_ is a private sklearn API that
                # was removed in scikit-learn >= 0.23 — pin sklearn or vendor it.
                ll[rank] = _assess_dimension_(spectrum=eigs,
                                              rank=rank,
                                              n_features=min(cov_mat_sqrt_np.shape),
                                              n_samples=max(cov_mat_sqrt_np.shape))

            self.ll = ll
            self.corrected_ll = ll - correction
            # NOTE(review): this permanently overwrites pca_rank, so the 'mle'
            # selection only happens on the first get_space() call.
            self.pca_rank = np.nanargmax(self.corrected_ll)
            print('PCA Rank is: ', self.pca_rank)
            return torch.FloatTensor(s[:self.pca_rank, None] * Vt[:self.pca_rank, :])
        else:
            return torch.FloatTensor(s[:, None] * Vt)
    def __init__(self, data_loader=None, num_classes=None, device=torch.device('cpu')):
        """Set up buffers for accumulating OOD-detection statistics.

        Args:
            data_loader: dict with DataLoader entries under the keys
                'in_distribution_test' and 'out_distribution_test'.
            num_classes: number of classes of the in-distribution task.
            device: torch device used for the model forward passes.
        """
        super(OODDetection, self).__init__(data_loader, num_classes, device)
        self.in_distribution_loader = data_loader['in_distribution_test']
        self.out_distribution_loader = data_loader['out_distribution_test']
        self.num_classes = num_classes
        self.device = device
        # Running sums of per-model predictive probabilities (one row per test example).
        self.in_distribution_ensemble_proba = torch.zeros(len(self.in_distribution_loader.dataset), num_classes)
        self.out_distribution_ensemble_proba = torch.zeros(len(self.out_distribution_loader.dataset), num_classes)
        # Running sums of per-model predictive entropies (expected data uncertainty).
        self.in_distribution_data_uncertainty = torch.zeros(len(self.in_distribution_loader.dataset))
        self.out_distribution_data_uncertainty = torch.zeros(len(self.out_distribution_loader.dataset))
        # Derived quantities; computed lazily in get_performance_metrics.
        self.in_distribution_total_uncertainty = None
        self.out_distribution_total_uncertainty = None
        self.in_distribution_model_uncertainty = None
        self.out_distribution_model_uncertainty = None
        # Number of posterior samples (models) accumulated so far.
        self.num_samples_collected = 0
torch.zeros(len(self.out_distribution_loader.dataset)) 33 | self.in_distribution_total_uncertainty = None 34 | self.out_distribution_total_uncertainty = None 35 | self.in_distribution_model_uncertainty = None 36 | self.out_distribution_model_uncertainty = None 37 | self.num_samples_collected = 0 38 | 39 | def update_statistics(self, models, output_performance=True): 40 | if isinstance(models, list): 41 | if all(issubclass(model.__class__, torch.nn.Module) for model in models): 42 | num_models = len(models) 43 | self.num_samples_collected += num_models 44 | else: 45 | raise NotImplementedError 46 | else: 47 | if issubclass(models.__class__, torch.nn.Module): 48 | self.num_samples_collected += 1 49 | else: 50 | raise NotImplementedError 51 | 52 | with torch.no_grad(): 53 | start_idx = 0 54 | for batch_idx, (batch_data, batch_labels) in enumerate(self.in_distribution_loader): 55 | end_idx = start_idx + len(batch_data) 56 | batch_data = batch_data.to(self.device) 57 | if isinstance(models, list): 58 | for model_idx, model in enumerate(models): 59 | model.to(self.device) 60 | model.eval() 61 | batch_logits = model(batch_data) 62 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 63 | self.in_distribution_ensemble_proba[start_idx: end_idx] += smoothened_proba 64 | self.in_distribution_data_uncertainty[start_idx: end_idx] += compute_predictive_entropy( 65 | smoothened_proba) 66 | model.to('cpu') 67 | else: 68 | ## Here models indicates a single model. 
69 | models.to(self.device) 70 | models.eval() 71 | batch_logits = models(batch_data) 72 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 73 | self.in_distribution_ensemble_proba[start_idx: end_idx] += smoothened_proba 74 | self.in_distribution_data_uncertainty[start_idx: end_idx] += compute_predictive_entropy( 75 | smoothened_proba) 76 | models.to('cpu') 77 | start_idx = end_idx 78 | 79 | start_idx = 0 80 | for batch_idx, (batch_data, batch_labels) in enumerate(self.out_distribution_loader): 81 | end_idx = start_idx + len(batch_data) 82 | batch_data = batch_data.to(self.device) 83 | if isinstance(models, list): 84 | for model_idx, model in enumerate(models): 85 | model.to(self.device) 86 | model.eval() 87 | batch_logits = model(batch_data) 88 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 89 | self.out_distribution_ensemble_proba[start_idx: end_idx] += smoothened_proba 90 | self.out_distribution_data_uncertainty[start_idx: end_idx] += compute_predictive_entropy( 91 | smoothened_proba) 92 | model.to('cpu') 93 | else: 94 | ## Here models indicates a single model. 
95 | models.to(self.device) 96 | models.eval() 97 | batch_logits = models(batch_data) 98 | smoothened_proba = central_smoothing(F.log_softmax(batch_logits, dim=-1).exp_().cpu()) 99 | self.out_distribution_ensemble_proba[start_idx: end_idx] += smoothened_proba 100 | self.out_distribution_data_uncertainty[start_idx: end_idx] += compute_predictive_entropy( 101 | smoothened_proba) 102 | models.to('cpu') 103 | start_idx = end_idx 104 | if output_performance: 105 | return self.get_performance_metrics() 106 | 107 | def get_performance_metrics(self): 108 | self.in_distribution_total_uncertainty = compute_predictive_entropy( 109 | self.in_distribution_ensemble_proba / self.num_samples_collected 110 | ) 111 | self.out_distribution_total_uncertainty = compute_predictive_entropy( 112 | self.out_distribution_ensemble_proba / self.num_samples_collected 113 | ) 114 | self.in_distribution_model_uncertainty = self.in_distribution_total_uncertainty - \ 115 | self.in_distribution_data_uncertainty / self.num_samples_collected 116 | self.out_distribution_model_uncertainty = self.out_distribution_total_uncertainty - \ 117 | self.out_distribution_data_uncertainty / self.num_samples_collected 118 | label_array = np.concatenate([np.ones(len(self.out_distribution_loader.dataset)), 119 | np.zeros(len(self.in_distribution_loader.dataset))]) 120 | total_uncertainty_array = np.concatenate([self.out_distribution_total_uncertainty.numpy(), 121 | self.in_distribution_total_uncertainty.numpy()]) 122 | model_uncertainty_array = np.concatenate([self.out_distribution_model_uncertainty.numpy(), 123 | self.in_distribution_model_uncertainty.numpy()]) 124 | total_uncertainty_auroc_score = roc_auc_score(label_array, total_uncertainty_array) 125 | model_uncertainty_auroc_score = roc_auc_score(label_array, model_uncertainty_array) 126 | 127 | return { 128 | 'total_uncertainty_auroc': total_uncertainty_auroc_score, 129 | 'model_uncertainty_auroc': model_uncertainty_auroc_score 130 | } 131 | 
-------------------------------------------------------------------------------- /URSABench/inference/pca_subspace.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import torch 5 | import wandb 6 | 7 | from URSABench.inference.inference_base import _Inference 8 | from URSABench.inference.projection_model import SubspaceModel 9 | from URSABench.inference.swa import SWA 10 | from URSABench.util import reset_model, log_pdf, cross_entropy, elliptical_slice, bn_update 11 | 12 | 13 | class PCASubspaceSampler(_Inference): 14 | def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output', 15 | device=torch.device('cpu')): 16 | super(PCASubspaceSampler, self).__init__(hyperparameters=hyperparameters, 17 | model=model, train_loader=train_loader, device=device) 18 | if hyperparameters == None: 19 | # Initialise as some default values 20 | self.hyperparameters = {'swag_lr': 0.001, 'swag_wd': 0.001, 'lr_init': 0.001, 'num_samples': 20, 21 | 'swag_momentum': 0.1, 'swag_burn_in_epochs': 100, 'num_swag_iterates': 50, 22 | 'rank': 20, 'max_rank': 20, 'temperature': 5000, 'prior_std': 2.} 23 | else: 24 | self.hyperparameters = hyperparameters 25 | 26 | self.rank = self.hyperparameters['rank'] 27 | self.max_rank = self.hyperparameters['max_rank'] 28 | self.model = model 29 | # self.sampler = hyperparameters['ess'] 30 | self.train_loader = train_loader 31 | self.device = device 32 | self.lr_init = self.hyperparameters['lr_init'] 33 | # import pdb; pdb.set_trace() 34 | self.swag_lr = self.hyperparameters['swag_lr'] 35 | self.swag_burn_in_epochs = self.hyperparameters['swag_burn_in_epochs'] 36 | self.num_samples = self.hyperparameters['num_samples'] 37 | self.num_swag_iterates = self.hyperparameters['num_swag_iterates'] 38 | self.swag_momentum = self.hyperparameters['swag_momentum'] 39 | self.lr_init = self.hyperparameters['lr_init'] 40 | 
self.swag_lr = self.hyperparameters['swag_lr'] 41 | self.swag_wd = self.hyperparameters['swag_wd'] 42 | self.prior_std = self.hyperparameters['prior_std'] 43 | self.temperature = self.hyperparameters['temperature'] 44 | self.model_loss_type = model_loss 45 | self.subspace_constructed = False 46 | self.current_theta = None 47 | swag_hyperparam_dict = { 48 | 'burn_in_epochs': self.swag_burn_in_epochs, 49 | # 'num_samples': self.num_swag_iterates, 50 | 'momentum': self.swag_momentum, 51 | 'lr_init': self.lr_init, 52 | 'swag_lr': self.swag_lr, 53 | 'swag_wd': self.swag_wd, 54 | 'num_iterates': self.num_swag_iterates, 55 | 'subspace_type': 'pca' 56 | } 57 | subspace_kwargs = { 58 | 'max_rank': self.max_rank, 59 | 'pca_rank': self.rank 60 | } 61 | self.swag_model = SWA(hyperparameters=swag_hyperparam_dict, model=self.model, 62 | train_loader=self.train_loader, model_loss=self.model_loss_type, 63 | device=self.device, **subspace_kwargs) 64 | self.weight_mean = None 65 | self.weight_covariance = None 66 | self.subspace = None 67 | 68 | def update_hyp(self, hyperparameters): 69 | self.rank = hyperparameters['rank'] 70 | self.max_rank = hyperparameters['max_rank'] 71 | # self.sampler = hyperparameters['ess'] 72 | self.lr_init = hyperparameters['lr_init'] 73 | self.swag_lr = hyperparameters['swag_lr'] 74 | self.swag_burn_in_epochs = hyperparameters['swag_burn_in_epochs'] 75 | self.num_samples = hyperparameters['num_samples'] 76 | self.num_swag_iterates = hyperparameters['num_swag_iterates'] 77 | self.swag_momentum = hyperparameters['swag_momentum'] 78 | self.lr_init = hyperparameters['lr_init'] 79 | self.swag_lr = hyperparameters['swag_lr'] 80 | self.swag_wd = hyperparameters['swag_wd'] 81 | # import pdb; pdb.set_trace() 82 | self.prior_std = hyperparameters['prior_std'] 83 | self.temperature = self.hyperparameters['temperature'] 84 | self.subspace_constructed = False 85 | self.current_theta = None 86 | swag_hyperparam_dict = { 87 | 'burn_in_epochs': self.swag_burn_in_epochs, 
88 | 'num_iterates': self.num_swag_iterates, 89 | 'momentum': self.swag_momentum, 90 | 'lr_init': self.lr_init, 91 | 'swag_lr': self.swag_lr, 92 | 'swag_wd': self.swag_wd, 93 | 'subspace_type': 'pca' 94 | } 95 | subspace_kwargs = { 96 | 'max_rank': self.max_rank, 97 | 'pca_rank': self.rank 98 | } 99 | self.model = reset_model(self.model) 100 | self.swag_model.update_hyp(swag_hyperparam_dict, **subspace_kwargs) 101 | self.subspace_constructed = False 102 | self.weight_mean = None 103 | self.weight_covariance = None 104 | self.subspace = None 105 | 106 | def _oracle(self, theta, subspace): 107 | return log_pdf(theta, subspace, self.model, self.train_loader, cross_entropy, self.temperature, 108 | self.device) 109 | 110 | def sample_iterative(self, update_bn=True, val_loader=None, debug_val_loss=False, wandb_debug=False): 111 | if self.subspace_constructed is False: 112 | self.swag_model.sample(val_loader=val_loader, debug_val_loss=debug_val_loss, wandb_debug=wandb_debug) 113 | self.subspace_constructed = True 114 | if self.weight_mean is None or self.weight_covariance is None: 115 | self.weight_mean, _, self.weight_covariance = self.swag_model.get_space() 116 | if self.subspace is None: 117 | self.subspace = SubspaceModel(self.weight_mean, self.weight_covariance) 118 | if self.current_theta is None: 119 | self.current_theta = torch.zeros(self.rank) 120 | prior_sample = np.random.normal(loc=0.0, scale=self.prior_std, size=self.rank) 121 | theta, log_prob = elliptical_slice(initial_theta=self.current_theta.numpy().copy(), prior=prior_sample, 122 | lnpdf=self._oracle, subspace=self.subspace) 123 | self.current_theta = torch.FloatTensor(theta) 124 | weight_sample = self.subspace(self.current_theta) 125 | offset = 0 126 | for param in self.model.parameters(): 127 | param.data.copy_(weight_sample[offset:offset + param.numel()].view(param.size()).to(self.device)) 128 | offset += param.numel() 129 | if debug_val_loss: 130 | avg_val_loss = self.compute_val_loss(val_loader) 131 
| # avg_train_loss = total_epoch_train_loss / self.dataset_size 132 | metrics = { 133 | # 'train_loss': avg_train_loss, 134 | 'val_loss': avg_val_loss 135 | } 136 | print(metrics) 137 | if wandb_debug: 138 | wandb.log(metrics) 139 | if update_bn: 140 | bn_update(self.train_loader, self.model) 141 | output_model = deepcopy(self.model.cpu()) 142 | self.model.to(self.device) 143 | return output_model 144 | 145 | def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False): 146 | if num_samples is None: 147 | num_samples = self.num_samples 148 | output_model_list = [] 149 | for i in range(num_samples): 150 | if i == num_samples - 1: 151 | output_model_list.append(self.sample_iterative(update_bn=True, val_loader=val_loader, 152 | debug_val_loss=debug_val_loss, wandb_debug=wandb_debug)) 153 | else: 154 | output_model_list.append(self.sample_iterative(update_bn=False, val_loader=val_loader, 155 | debug_val_loss=debug_val_loss, wandb_debug=wandb_debug)) 156 | 157 | return output_model_list 158 | -------------------------------------------------------------------------------- /URSABench/models/preresnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PreResNet model definition 3 | ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py 4 | """ 5 | 6 | import math 7 | 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | 11 | __all__ = ['PreResNet110', 'PreResNet56', 'PreResNet8', 'PreResNet83', 'PreResNet164'] 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(inplanes) 25 | self.relu = nn.ReLU(inplace=True) 26 
class Bottleneck(nn.Module):
    """Pre-activation bottleneck block: three BN-ReLU-Conv stages plus a shortcut.

    Channel plan is inplanes -> planes (1x1) -> planes (3x3, stride) ->
    planes * 4 (1x1); `downsample`, when given, projects the shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three pre-activated convolutions and add the shortcut."""
        # The shortcut is taken from the un-activated input (pre-activation ResNet).
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        return out + shortcut
self.inplanes = 16 104 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 105 | bias=False) 106 | self.layer1 = self._make_layer(block, 16, n) 107 | self.layer2 = self._make_layer(block, 32, n, stride=2) 108 | self.layer3 = self._make_layer(block, 64, n, stride=2) 109 | self.bn = nn.BatchNorm2d(64 * block.expansion) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.avgpool = nn.AvgPool2d(8) 112 | self.fc = nn.Linear(64 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | ) 129 | 130 | layers = list() 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | 141 | x = self.layer1(x) # 32x32 142 | x = self.layer2(x) # 16x16 143 | x = self.layer3(x) # 8x8 144 | x = self.bn(x) 145 | x = self.relu(x) 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | x = self.fc(x) 150 | 151 | return x 152 | 153 | 154 | class PreResNet164: 155 | base = PreResNet 156 | args = list() 157 | kwargs = {'depth': 164} 158 | transform_train = transforms.Compose([ 159 | transforms.Resize(32), 160 | transforms.RandomCrop(32, padding=4), 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 
class PreResNet83:
    """Config bundle for a depth-83 PreResNet on CIFAR-style 32x32 inputs."""
    base = PreResNet
    args = list()
    kwargs = {'depth': 83}
    # CONSISTENCY FIX: every other PreResNet config (164/110/56/8) begins both
    # pipelines with Resize(32); this one was missing it and would break on
    # inputs that are not already 32x32.
    transform_train = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    def update_hyp(self, hyperparameters, **subspace_kwargs):
        """Install new hyperparameters and reset all SWAG training state.

        Zeroes the running first/second weight moments, re-initialises both the
        training model and the SWAG model, rebuilds the SGD optimizer, and
        creates a fresh subspace, so burn-in and iterate collection restart.

        Args:
            hyperparameters: dict with keys 'burn_in_epochs', 'num_iterates',
                'num_samples', 'momentum', 'lr_init', 'swag_lr', 'swag_wd'
                and optionally 'subspace_type' (defaults to 'pca').
            subspace_kwargs: forwarded to Subspace.create (e.g. max_rank).
        """
        # Running mean / second moment of flattened weights, reset to zero.
        self.weight_mean = torch.zeros(self.num_parameters)
        self.weight_variance = None
        self.sq_mean = torch.zeros(self.num_parameters)
        self.num_models_collected = torch.zeros(1, dtype=torch.long)
        self.burnt_in = False
        self.epochs_run = 0
        self.hyperparameters = hyperparameters
        self.burn_in_epochs = self.hyperparameters['burn_in_epochs']
        self.num_iterates = self.hyperparameters['num_iterates']
        self.num_samples = self.hyperparameters['num_samples']
        self.momentum = self.hyperparameters['momentum']
        self.lr_init = self.hyperparameters['lr_init']
        self.swag_lr = self.hyperparameters['swag_lr']
        self.swag_wd = self.hyperparameters['swag_wd']
        # Re-initialise both networks so training restarts from fresh weights.
        self.model = reset_model(self.model)
        self.swag_model = reset_model(self.swag_model)
        self.optimizer = SGD(params=self.model.parameters(), lr=self.lr_init, momentum=self.momentum,
                             weight_decay=self.swag_wd)
        if 'subspace_type' not in hyperparameters.keys():
            self.subspace_type = 'pca'
        else:
            self.subspace_type = hyperparameters['subspace_type']
        # NOTE(review): **subspace_kwargs always binds to a dict, so this
        # None-check can never fire — apparently dead code; confirm intent.
        if subspace_kwargs is None:
            subspace_kwargs = dict()
        self.subspace = Subspace.create(self.subspace_type, num_parameters=self.num_parameters,
                                        **subspace_kwargs)
59 | lr = self._schedule(self.epochs_run) 60 | adjust_learning_rate(self.optimizer, lr) 61 | for batch_idx, (batch_data, batch_labels) in enumerate(self.train_loader): 62 | batch_data = batch_data.to(self.device) 63 | batch_labels = batch_labels.to(self.device) 64 | batch_data_logits = self.model(batch_data) 65 | loss = self.loss_criterion(batch_data_logits, batch_labels) 66 | self.optimizer.zero_grad() 67 | loss.backward() 68 | total_epoch_train_loss += loss.item() * len(batch_data) 69 | self.optimizer.step() 70 | self.epochs_run += 1 71 | if debug_val_loss: 72 | avg_val_loss = self.compute_val_loss(val_loader) 73 | avg_train_loss = total_epoch_train_loss / self.dataset_size 74 | metrics = { 75 | 'train_loss': avg_train_loss, 76 | 'val_loss': avg_val_loss 77 | } 78 | print(metrics) 79 | if wandb_debug: 80 | wandb.log(metrics) 81 | if epoch >= self.burn_in_epochs: 82 | self._collect_model() 83 | self.burnt_in = True 84 | _, self.weight_variance = self._get_mean_and_variance() 85 | if full_cov is False: 86 | weight_sample = torch.normal(self.weight_mean, torch.sqrt(self.weight_variance)) 87 | else: 88 | var_sample = self.weight_variance.sqrt() * torch.randn_like(self.weight_variance, 89 | requires_grad=False) 90 | cov_sample = self.swag_model.subspace.cov_mat_sqrt.t().matmul( 91 | self.swag_model.subspace.cov_mat_sqrt.new_empty( 92 | (self.swag_model.subspace.cov_mat_sqrt.size(0),), requires_grad=False 93 | ).normal_() 94 | ) 95 | cov_sample /= (self.swag_model.subspace.max_rank - 1) ** 0.5 96 | rand_sample = var_sample + cov_sample 97 | weight_sample = self.weight_mean + rand_sample 98 | weight_sample = self.weight_mean 99 | offset = 0 100 | for param in self.swag_model.parameters(): 101 | param.data.copy_(weight_sample[offset:offset + param.numel()].view(param.size()).to(self.device)) 102 | offset += param.numel() 103 | else: 104 | assert (self.burnt_in is True) 105 | if full_cov is False: 106 | weight_sample = torch.normal(self.weight_mean, 
torch.sqrt(self.weight_variance)) 107 | else: 108 | var_sample = self.weight_variance.sqrt() * torch.randn_like(self.weight_variance, 109 | requires_grad=False) 110 | cov_sample = self.swag_model.subspace.cov_mat_sqrt.t().matmul( 111 | self.swag_model.subspace.cov_mat_sqrt.new_empty( 112 | (self.swag_model.subspace.cov_mat_sqrt.size(0),), requires_grad=False 113 | ).normal_() 114 | ) 115 | cov_sample /= (self.swag_model.subspace.max_rank - 1) ** 0.5 116 | rand_sample = var_sample + cov_sample 117 | weight_sample = self.weight_mean + rand_sample 118 | weight_sample = self.weight_mean 119 | offset = 0 120 | for param in self.swag_model.parameters(): 121 | param.data.copy_(weight_sample[offset:offset + param.numel()].view(param.size()).to(self.device)) 122 | offset += param.numel() 123 | if update_bn: 124 | bn_update(self.train_loader, self.swag_model) 125 | output_model = deepcopy(self.swag_model.cpu()) 126 | self.swag_model.to(self.device) 127 | return output_model 128 | else: 129 | raise NotImplementedError 130 | 131 | def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False, full_cov=False): 132 | output_list = [] 133 | if num_samples is None: 134 | num_samples = self.num_samples 135 | if issubclass(self.model.__class__, torch.nn.Module): 136 | for i in range(num_samples): 137 | if i == num_samples - 1: 138 | output_list.append(self.sample_iterative(update_bn=True, val_loader=val_loader, 139 | debug_val_loss=debug_val_loss, wandb_debug=wandb_debug, 140 | full_cov=full_cov)) 141 | else: 142 | output_list.append(self.sample_iterative(update_bn=True, val_loader=val_loader, 143 | debug_val_loss=debug_val_loss, wandb_debug=wandb_debug, 144 | full_cov=full_cov)) 145 | return output_list 146 | else: 147 | raise NotImplementedError 148 | -------------------------------------------------------------------------------- /URSABench/trtprof/prof.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
TensorRT 7.1.3 Python API: 3 | https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-713/api/python_api/index.html 4 | """ 5 | 6 | import argparse 7 | import pathlib 8 | import time 9 | from copy import deepcopy 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import pycuda.autoinit 15 | import pycuda.driver as cuda 16 | import tensorrt as trt 17 | import torch 18 | from torch.utils.data import DataLoader 19 | from URSABench.trtprof.dataset import DummyDataset 20 | from URSABench.trtprof.utils import * 21 | 22 | 23 | class TensorRTModel: 24 | """Create a TensorRT model for inference from the given `trt` file. 25 | 26 | See: https://github.com/NVIDIA/TensorRT/blob/eb8442dba3c9e85ffb77e0d870d2e29adcb0a4aa/quickstart/IntroNotebooks/onnx_helper.py""" 27 | 28 | def __init__(self, file: Path, num_classes: int, target_dtype=np.float16): 29 | 30 | self.target_dtype = target_dtype 31 | self.num_classes = num_classes 32 | self._load(file) 33 | self._allocate() 34 | # self.stream = None 35 | 36 | def _load(self, file: Path): 37 | logger.debug(f"Loading {file}") 38 | t0 = time.perf_counter() 39 | runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING)) 40 | with open(file, "rb") as f: 41 | self.engine = runtime.deserialize_cuda_engine(f.read()) 42 | self.context = self.engine.create_execution_context() 43 | t1 = time.perf_counter() 44 | logger.debug(f"{file} loaded in {t1-t0:.4f}s") 45 | 46 | def _allocate(self): 47 | bindings = [] 48 | 49 | # allocate for input 50 | input_size = trt.volume(self.engine.get_binding_shape("input")) 51 | input_size *= self.engine.max_batch_size 52 | input_dtype = trt.nptype(self.engine.get_binding_dtype("input")) 53 | input_host_mem = cuda.pagelocked_empty(input_size, input_dtype) 54 | input_device_mem = cuda.mem_alloc(input_host_mem.nbytes) 55 | bindings.append(int(input_device_mem)) 56 | 57 | # allocate for output 58 | output_size = trt.volume(self.engine.get_binding_shape("output")) 59 | 
output_size *= self.engine.max_batch_size 60 | output_dtype = trt.nptype(self.engine.get_binding_dtype("output")) 61 | output_host_mem = cuda.pagelocked_empty(output_size, output_dtype) 62 | output_device_mem = cuda.mem_alloc(output_host_mem.nbytes) 63 | bindings.append(int(output_device_mem)) 64 | 65 | self.input_host_mem = input_host_mem 66 | self.input_device_mem = input_device_mem 67 | self.output_host_mem = output_host_mem 68 | self.output_device_mem = output_device_mem 69 | self.bindings = bindings 70 | self.stream = cuda.Stream() 71 | 72 | def __call__(self, X: np.ndarray) -> np.ndarray: 73 | self.input_host_mem = X 74 | cuda.memcpy_htod_async(self.input_device_mem, self.input_host_mem, self.stream) 75 | self.context.execute_async_v2(self.bindings, self.stream.handle) 76 | cuda.memcpy_dtoh_async( 77 | self.output_host_mem, self.output_device_mem, self.stream 78 | ) 79 | self.stream.synchronize() 80 | return self.output_host_mem 81 | 82 | 83 | class TensorRTEnsemble(TensorRTModel): 84 | """Create an ensemble model by duplicating a TensorRTModel for several times.""" 85 | 86 | def __init__( 87 | self, 88 | file: Path, 89 | num_classes: int, 90 | target_dtype=np.float32, 91 | n_ensembles: int = 30, 92 | ): 93 | self.models = [ 94 | TensorRTModel(file, num_classes, target_dtype) for _ in range(n_ensembles) 95 | ] 96 | 97 | def __call__(self, X: np.ndarray) -> np.ndarray: 98 | 99 | output_list = [mlp(X) for mlp in self.models] 100 | output = np.stack(output_list).mean() 101 | return output 102 | 103 | 104 | class PyTorchModel: 105 | def __init__(self, file: Path): 106 | self.model = self._load(file) 107 | 108 | def _load(self, file): 109 | print(f"Loading {file} ... 
", end="", flush=True) 110 | t0 = time.perf_counter() 111 | model = torch.load(file, map_location=torch.device("cuda")) 112 | t1 = time.perf_counter() 113 | print(f"\r{file} loaded in {t1-t0:.4f}s") 114 | model.eval() 115 | return model 116 | 117 | def __call__(self, X): 118 | X = X.to("cuda") 119 | with torch.no_grad(): 120 | pred = self.model(X) 121 | torch.cuda.synchronize() 122 | pred = pred.cpu() 123 | return pred 124 | 125 | 126 | class PyTorchEnsemble(PyTorchModel): 127 | def __init__(self, file: Path, n_ensemble: int = 30): 128 | m = super()._load(file) 129 | self.model = [deepcopy(m) for _ in range(n_ensemble)] 130 | 131 | def __call__(self, X): 132 | X = X.to("cuda") 133 | with torch.no_grad(): 134 | output_list = [mlp(X) for mlp in self.model] 135 | output = torch.stack(output_list).mean(dim=0) 136 | torch.cuda.synchronize() 137 | output = output.cpu() 138 | return output 139 | 140 | 141 | def warm_up(model, dataloader, n): 142 | print(f"Warming up ... ", end="", flush=True) 143 | for i, (batch_x, _) in enumerate(dataloader): 144 | if i > n: 145 | break 146 | if isinstance(model, TensorRTModel): 147 | batch_x = batch_x.numpy() 148 | model.input_to_gpu(batch_x) 149 | model(batch_x) 150 | print(f"done: {n} runs") 151 | 152 | 153 | def time_batch(model, dataloader, repetition=10): 154 | print(f"Profiling ... ", end="", flush=True) 155 | results = [] 156 | for i, (batch_x, _) in enumerate(dataloader): 157 | latency = [] 158 | if isinstance(model, TensorRTModel): 159 | batch_x = batch_x.numpy() 160 | model.input_to_gpu(batch_x) 161 | for j in range(repetition): 162 | print(f"\rProfiling ... batch {i}, rep {j}", end="", flush=True) 163 | t0 = time.perf_counter() 164 | pred = model(batch_x) 165 | t1 = time.perf_counter() 166 | latency.append(t1 - t0) 167 | results.append( 168 | {"batch": i, "latency": np.mean(latency), "latency_std": np.std(latency)} 169 | ) 170 | print(f"\rProfiling ... 
done: {i+1} batches.") 171 | return pd.DataFrame(results) 172 | 173 | 174 | def print_stats(df, model_name): 175 | print(f"{model_name}: {df.latency.mean():.4f} +/- {df.latency.std():.4f} s") 176 | 177 | 178 | def run(model_file, is_ensemble, target_dtype): 179 | 180 | batch_size = 1 181 | n_samples = 32 182 | n_classes = 10 183 | n_ensembles = 3 184 | x = DummyDataset(32, n_samples, np.float32) 185 | dataloader = DataLoader(x, batch_size=batch_size, shuffle=False) 186 | # x = MLPDataset(n_samples, n_feats) 187 | # dataloader = DataLoader(x, batch_size=batch_size, shuffle=False) 188 | 189 | if target_dtype == 32: 190 | target_dtype = np.float32 191 | elif target_dtype == 16: 192 | target_dtype = np.float16 193 | 194 | if model_file.suffix == ".pth" and not is_ensemble: 195 | model = PyTorchModel(model_file) 196 | elif model_file.suffix == ".pth" and is_ensemble: 197 | model = PyTorchEnsemble(model_file, n_ensemble=n_ensembles) 198 | elif model_file.suffix == ".trt" and not is_ensemble: 199 | model = TensorRTModel(model_file, n_classes, target_dtype=target_dtype) 200 | elif model_file.suffix == ".trt" and is_ensemble: 201 | model = TensorRTEnsemble( 202 | model_file, n_classes, n_ensembles=n_ensembles, target_dtype=target_dtype 203 | ) 204 | else: 205 | raise 206 | 207 | print(f"CPU ensemble is {is_ensemble}, so use {type(model).__name__}") 208 | 209 | warm_up(model, dataloader, 30) 210 | lat = time_batch(model, dataloader) 211 | if is_ensemble: 212 | print_name = model_file.name + ".cpu_ensemble" 213 | else: 214 | print_name = model_file.name 215 | print_stats(lat, print_name) 216 | 217 | 218 | if __name__ == "__main__": 219 | parser = argparse.ArgumentParser(description="Profiling") 220 | parser.add_argument( 221 | "model_file", 222 | type=pathlib.Path, 223 | help="Path to pth or trt file.", 224 | ) 225 | parser.add_argument("--cpu-ensemble", dest="is_ensemble", action="store_true") 226 | parser.add_argument( 227 | "--target-dtype", dest="target_dtype", type=int, 
action="store_true" 228 | ) 229 | parser.set_defaults(is_ensemble=False, target_dtype=32) 230 | args = parser.parse_args() 231 | run(args.model_file, args.is_ensemble, args.target_dtype) 232 | -------------------------------------------------------------------------------- /URSABench/inference/swa.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import wandb 5 | from torch.optim import SGD 6 | 7 | from URSABench.util import get_loss_criterion, reset_model, set_weights, flatten, \ 8 | bn_update, adjust_learning_rate 9 | from .inference_base import _Inference 10 | from .subspaces import Subspace 11 | 12 | 13 | class SWA(_Inference): 14 | def __init__(self, hyperparameters, model=None, train_loader=None, model_loss='multi_class_linear_output', 15 | device=torch.device('cpu'), **subspace_kwargs): 16 | super(SWA, self).__init__(hyperparameters, model=None, train_loader=None, device=torch.device('cpu')) 17 | if hyperparameters == None: 18 | # Initialise as some default values 19 | hyperparameters = {'swag_lr': 0.001,'swag_wd': 0.001,'lr_init': 0.001, 'num_samples': 20, 'momentum': 0.1, 'burn_in_epochs':100, 'num_iterates':50} 20 | 21 | self.hyperparameters = hyperparameters 22 | self.swag_model = deepcopy(model) 23 | self.model = model 24 | self.num_parameters = sum(param.numel() for param in self.swag_model.parameters()) 25 | self.weight_mean = torch.zeros(self.num_parameters) 26 | self.sq_mean = torch.zeros(self.num_parameters) 27 | self.num_models_collected = torch.zeros(1, dtype=torch.long) 28 | self.var_clamp = 1e-30 29 | self.device = device 30 | self.train_loader = train_loader 31 | self.loss_criterion = get_loss_criterion(loss=model_loss) 32 | self.dataset_size = len(train_loader.dataset) 33 | self.burnt_in = False 34 | self.epochs_run = 0 35 | self.burn_in_epochs = self.hyperparameters['burn_in_epochs'] 36 | self.num_iterates = self.hyperparameters['num_iterates'] 37 | 
self.momentum = self.hyperparameters['momentum'] 38 | self.lr_init = self.hyperparameters['lr_init'] 39 | self.swag_lr = self.hyperparameters['swag_lr'] 40 | self.swag_wd = self.hyperparameters['swag_wd'] 41 | self.optimizer = SGD(params=self.model.parameters(), lr=self.lr_init, momentum=self.momentum, 42 | weight_decay=self.swag_wd) 43 | if 'subspace_type' not in hyperparameters.keys(): 44 | self.subspace_type = 'pca' 45 | else: 46 | self.subspace_type = hyperparameters['subspace_type'] 47 | if subspace_kwargs is None: 48 | subspace_kwargs = dict() 49 | self.subspace = Subspace.create(self.subspace_type, num_parameters=self.num_parameters, 50 | **subspace_kwargs) 51 | self.cov_factor = None 52 | 53 | def update_hyp(self, hyperparameters, **subspace_kwargs): 54 | self.weight_mean = torch.zeros(self.num_parameters) 55 | self.sq_mean = torch.zeros(self.num_parameters) 56 | self.num_models_collected = torch.zeros(1, dtype=torch.long) 57 | self.burnt_in = False 58 | self.epochs_run = 0 59 | self.hyperparameters = hyperparameters 60 | self.burn_in_epochs = self.hyperparameters['burn_in_epochs'] 61 | self.num_iterates = self.hyperparameters['num_iterates'] 62 | self.momentum = self.hyperparameters['momentum'] 63 | self.lr_init = self.hyperparameters['lr_init'] 64 | self.swag_lr = self.hyperparameters['swag_lr'] 65 | self.swag_wd = self.hyperparameters['swag_wd'] 66 | self.model = reset_model(self.model) 67 | self.swag_model = reset_model(self.swag_model) 68 | self.optimizer = SGD(params=self.model.parameters(), lr=self.lr_init, momentum=self.momentum, 69 | weight_decay=self.swag_wd) 70 | if 'subspace_type' not in hyperparameters.keys(): 71 | self.subspace_type = 'pca' 72 | else: 73 | self.subspace_type = hyperparameters['subspace_type'] 74 | if subspace_kwargs is None: 75 | subspace_kwargs = dict() 76 | self.subspace = Subspace.create(self.subspace_type, num_parameters=self.num_parameters, 77 | **subspace_kwargs) 78 | 79 | def _collect_model(self): 80 | 81 | w = 
flatten([param.detach().cpu() for param in self.model.parameters()]) 82 | # first moment 83 | self.weight_mean.mul_(self.num_models_collected.item() / (self.num_models_collected.item() + 1.0)) 84 | self.weight_mean.add_(w / (self.num_models_collected.item() + 1.0)) 85 | 86 | # second moment 87 | self.sq_mean.mul_(self.num_models_collected.item() / (self.num_models_collected.item() + 1.0)) 88 | self.sq_mean.add_(w ** 2 / (self.num_models_collected.item() + 1.0)) 89 | deviation_vector = w - self.weight_mean 90 | self.subspace.collect_vector(deviation_vector) 91 | 92 | def _schedule(self, epoch): 93 | t = epoch / self.burn_in_epochs 94 | lr_ratio = self.swag_lr / self.lr_init 95 | if t <= 0.5: 96 | factor = 1.0 97 | elif t <= 0.9: 98 | factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4 99 | else: 100 | factor = lr_ratio 101 | return self.lr_init * factor 102 | 103 | def _set_swa(self): 104 | set_weights(self.swag_model, self.weight_mean, self.device) 105 | 106 | def _get_mean_and_variance(self): 107 | variance = torch.clamp(self.sq_mean - self.weight_mean ** 2, self.var_clamp) 108 | return self.weight_mean, variance 109 | 110 | def fit(self): 111 | if self.cov_factor is not None: 112 | return 113 | self.cov_factor = self.subspace.get_space() 114 | 115 | def get_space(self, export_cov_factor=True): 116 | mean, variance = self._get_mean_and_variance() 117 | if not export_cov_factor: 118 | return mean.clone(), variance.clone() 119 | else: 120 | self.fit() 121 | return mean.clone(), variance.clone(), self.cov_factor.clone() 122 | 123 | def sample_iterative(self, update_bn_swa=True, val_loader=None, debug_val_loss=False, wandb_debug=False): 124 | if issubclass(self.model.__class__, torch.nn.Module): 125 | if self.burnt_in is False: 126 | epochs = self.burn_in_epochs + 1 127 | self.burnt_in = True 128 | else: 129 | epochs = 1 130 | self.num_models_collected += 1 131 | for epoch in range(epochs): 132 | self.model.train() 133 | lr = self._schedule(self.epochs_run) 134 | 
adjust_learning_rate(self.optimizer, lr) 135 | total_epoch_train_loss = 0. 136 | for batch_idx, (batch_data, batch_labels) in enumerate(self.train_loader): 137 | batch_data = batch_data.to(self.device) 138 | batch_labels = batch_labels.to(self.device) 139 | batch_data_logits = self.model(batch_data) 140 | loss = self.loss_criterion(batch_data_logits, batch_labels) 141 | self.optimizer.zero_grad() 142 | loss.backward() 143 | total_epoch_train_loss += loss.item() * len(batch_data) 144 | self.optimizer.step() 145 | self.epochs_run += 1 146 | if debug_val_loss: 147 | avg_val_loss = self.compute_val_loss(val_loader) 148 | avg_train_loss = total_epoch_train_loss / self.dataset_size 149 | metrics = { 150 | 'train_loss': avg_train_loss, 151 | 'val_loss': avg_val_loss 152 | } 153 | print(metrics) 154 | if wandb_debug: 155 | wandb.log(metrics) 156 | self._collect_model() 157 | if update_bn_swa: 158 | self._set_swa() 159 | bn_update(self.train_loader, self.swag_model) 160 | return self.swag_model 161 | else: 162 | raise NotImplementedError 163 | 164 | def sample(self, num_samples=None, val_loader=None, debug_val_loss=False, wandb_debug=False): 165 | output_list = [] 166 | if num_samples is None: 167 | num_samples = self.num_iterates 168 | if issubclass(self.model.__class__, torch.nn.Module): 169 | for i in range(num_samples): 170 | if i == num_samples - 1: 171 | output_list.append(self.sample_iterative(update_bn_swa=True, val_loader=val_loader, 172 | debug_val_loss=debug_val_loss, wandb_debug=wandb_debug)) 173 | else: 174 | output_list.append(self.sample_iterative(update_bn_swa=False, val_loader=val_loader, 175 | debug_val_loss=debug_val_loss, wandb_debug=wandb_debug)) 176 | return output_list 177 | else: 178 | raise NotImplementedError 179 | -------------------------------------------------------------------------------- /URSABench/trtprof/run_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
datetime 3 | import glob 4 | import json 5 | import pathlib 6 | import re 7 | import sys 8 | import time 9 | from collections import defaultdict 10 | 11 | # Caused by the symbolic link issue, e.g.: "OSError: libcublasLt.so.10: cannot 12 | # open shared object file: Too many levels of symbolic links" 13 | try: 14 | import URSABench 15 | except (OSError, ImportError): 16 | sys.exit(3) 17 | 18 | import numpy as np 19 | import psutil 20 | import torch 21 | from torch.utils.data import DataLoader 22 | from URSABench import datasets 23 | from URSABench import models 24 | from URSABench.trtprof.dataset import DummyDataset 25 | from URSABench.trtprof.dataset import MLPDataset 26 | from URSABench.trtprof.prof import TensorRTModel 27 | from URSABench.trtprof.utils import logger 28 | 29 | # experiment parameters 30 | # IMG_HW = 32 31 | IMG_HW = 224 32 | BATCH_SIZE = 1 33 | LATENCY_MODE_SAMPLE_SIZE = 100 34 | SHUFFLE = False 35 | DEVICE = torch.device("cuda") 36 | 37 | 38 | def load_results(json_file): 39 | try: 40 | with open(json_file, "r") as f: 41 | return json.load(f) 42 | except FileNotFoundError: 43 | return {} 44 | 45 | 46 | def dump_results(json_file, results): 47 | with open(json_file, "w") as f: 48 | json.dump(results, f, indent=4) 49 | 50 | 51 | def pytorch_from_stat_dict(checkpoint_file, model_cfg, n_classes): 52 | model = model_cfg.base( 53 | *model_cfg.args, num_classes=n_classes, **model_cfg.kwargs 54 | ).to("cuda") 55 | checkpoint = torch.load(checkpoint_file, map_location="cuda") 56 | model.load_state_dict(checkpoint) 57 | return model 58 | 59 | 60 | def load_model(model_file, model_suffix, model_cfg, num_classes): 61 | if model_suffix == "pt": 62 | model = pytorch_from_stat_dict(model_file, model_cfg, num_classes) 63 | elif model_suffix in ["trt", "trt32"]: 64 | model = TensorRTModel(model_file, num_classes) 65 | else: 66 | raise ValueError 67 | return model 68 | 69 | 70 | def get_latency(latencies, burn_in=10): 71 | by_batch = defaultdict(int) 72 | for x in 
latencies: 73 | if x["batch_idx"] > burn_in: 74 | by_batch[x["batch_idx"]] += x["latency"] 75 | lat = np.array(list(by_batch.values())) 76 | lat_mean = np.mean(lat) 77 | lat_avg = np.std(lat) 78 | return lat_mean, lat_avg 79 | 80 | 81 | def chunks(long_list, n): 82 | for i in range(0, len(long_list), n): 83 | yield long_list[i : i + n] 84 | 85 | 86 | def parse_model_number(fname): 87 | return int(fname.split(".")[0].split("_")[-1]) 88 | 89 | 90 | def batch_ensemble(file_list, n_ensemble): 91 | file_list = sorted(file_list, key=parse_model_number) 92 | ensembles = list(chunks(file_list, n_ensemble)) 93 | ensembles = [sorted(l, key=parse_model_number) for l in ensembles] 94 | ensembles = {",".join(l): l for l in ensembles} 95 | return ensembles 96 | 97 | 98 | def run( 99 | model_dir: pathlib.Path, 100 | model_suffix: str, 101 | profile_mode: str, 102 | output_suffix: str, 103 | n_ensemble: int, 104 | ): 105 | model_class_name, dataset_name = tuple(model_dir.name.split("_")) 106 | assert model_class_name in [ 107 | "WideResNet28x10", 108 | "INResNet50", 109 | "MLP200MNIST", 110 | "ResNet50", 111 | ] 112 | assert dataset_name in ["CIFAR10", "CIFAR100", "MNIST", "ImageNet"] 113 | assert model_suffix in ["pt", "trt", "trt32"] 114 | assert profile_mode in ["latency", "metrics"] 115 | 116 | dataset_path = f"/data/{dataset_name.lower()}" 117 | output_json_path = ( 118 | f"{model_dir}.{model_suffix}.{profile_mode}.{output_suffix}{n_ensemble}.json" 119 | ) 120 | 121 | logger.debug(f"model_class_name: {model_class_name}") 122 | logger.debug(f"model_suffix: {model_suffix}") 123 | logger.debug(f"profile_mode: {profile_mode}") 124 | logger.debug(f"dataset_path: {dataset_path}") 125 | logger.debug(f"output_json_path: {output_json_path}") 126 | 127 | # load cache 128 | results = load_results(output_json_path) 129 | 130 | # glob model files 131 | model_files = glob.glob(f"{model_dir}/*.{model_suffix}") 132 | ensemble_batches = batch_ensemble(model_files, n_ensemble) 133 | 
n_total_models = len(ensemble_batches) 134 | n_processed = len(results) 135 | logger.debug(f"{n_total_models} in total, {n_processed} already processed") 136 | # to_be_profiled = [x for x in model_files if x not in results] 137 | to_be_profiled = {k: v for k, v in ensemble_batches.items() if k not in results} 138 | 139 | if len(to_be_profiled) == 0: 140 | logger.info("All done.") 141 | sys.exit(0) 142 | 143 | # model config 144 | try: 145 | model_cfg = getattr(models, model_class_name) 146 | except AttributeError: 147 | model_cfg = None 148 | 149 | # setup dataloader 150 | if profile_mode == "latency" and dataset_name in [ 151 | "CIFAR10", 152 | "CIFAR100", 153 | "ImageNet", 154 | ]: 155 | dummy_dataset = DummyDataset(IMG_HW, LATENCY_MODE_SAMPLE_SIZE, np.float32) 156 | test_loader = DataLoader(dummy_dataset, batch_size=BATCH_SIZE) 157 | if dataset_name == "ImageNet": 158 | num_classes = 1000 159 | else: 160 | num_classes = int(re.sub("\D", "", dataset_name)) 161 | logger.debug("Using DummyDataset") 162 | elif profile_mode == "latency" and dataset_name in ["MNIST"]: 163 | dummy_dataset = MLPDataset(100, 10) 164 | test_loader = DataLoader(dummy_dataset, batch_size=BATCH_SIZE) 165 | num_classes = 10 166 | logger.debug("Using MLPDataset") 167 | else: 168 | URSABench.set_random_seed(0) 169 | loaders, num_classes = datasets.loaders( 170 | dataset_name, 171 | dataset_path, 172 | BATCH_SIZE, 173 | 0, 174 | transform_train=model_cfg.transform_train, 175 | transform_test=model_cfg.transform_test, 176 | shuffle_train=True, 177 | use_validation=False, 178 | ) 179 | test_loader = loaders["test"] 180 | logger.debug("Using real dataset") 181 | dataloader = {"in_distribution_test": test_loader} 182 | logger.debug( 183 | f"dataloader: {num_classes} classes, {test_loader.batch_size} batch_size" 184 | ) 185 | logger.debug( 186 | f"dataloader: {len(test_loader)} batches, {len(test_loader.dataset)} samples" 187 | ) 188 | 189 | # prediction instance 190 | if profile_mode == "latency": 
191 | latency_mode = True 192 | metrics = ["ll"] 193 | else: 194 | latency_mode = False 195 | metrics = "ALL" 196 | 197 | t0 = time.perf_counter() 198 | predict = URSABench.tasks.Prediction( 199 | dataloader=dataloader, 200 | metric_list=metrics, 201 | num_classes=num_classes, 202 | device=DEVICE, 203 | latency_mode=latency_mode, 204 | ) 205 | t1 = time.perf_counter() 206 | logger.debug(f"Prediction instance created in {t1-t0:.4f}s.") 207 | 208 | # for i, model_file in enumerate(to_be_profiled): 209 | # for ensemble_name, ensemble_files in to_be_profiled.items(): 210 | 211 | # only run one batch of models a time; repeat in bash script which calls 212 | # this script to finish the whole task, so that avoids fluctuations. 213 | ensemble_name, ensemble_files = to_be_profiled.popitem() 214 | 215 | logger.debug(f"{ensemble_files}") 216 | 217 | t0 = time.perf_counter() 218 | model_list = [ 219 | load_model(model_file, model_suffix, model_cfg, num_classes) 220 | for model_file in ensemble_files 221 | ] 222 | t1 = time.perf_counter() 223 | model_load_time = t1 - t0 224 | logger.debug(f"{len(model_list)} models loaded in {t1-t0:.4f}s") 225 | 226 | t0 = time.perf_counter() 227 | predict.update_statistics(model_list, output_performance=False) 228 | t1 = time.perf_counter() 229 | logger.debug( 230 | f"Availabel mem: {psutil.virtual_memory().available / (1024 * 1024)} MB" 231 | ) 232 | 233 | val_dict = predict.get_performance_metrics() 234 | lat_mean, lat_std = get_latency(predict.latencies) 235 | val_dict["latency_mean"] = lat_mean 236 | val_dict["latency_std"] = lat_std 237 | # include model load time in the result 238 | val_dict["model_load_time"] = model_load_time 239 | val_dict["added_time"] = datetime.datetime.now().isoformat() 240 | results[ensemble_name] = val_dict 241 | 242 | dump_results(output_json_path, results) 243 | logger.info(f"Ensemble profiled in {t1-t0:.4f} s.") 244 | 245 | logger.info( 246 | f"{len(results)}/{n_total_models}: {n_total_models - len(results)} 
left." 247 | ) 248 | 249 | # only finished one batch 250 | sys.exit(4) 251 | 252 | 253 | if __name__ == "__main__": 254 | try: 255 | parser = argparse.ArgumentParser(description="Run models.Prediction") 256 | parser.add_argument( 257 | "model_dir", 258 | type=pathlib.Path, 259 | help="Path to the directory of the `pt` files or `trt` files.", 260 | ) 261 | parser.add_argument( 262 | "model_suffix", type=str, help="Type of models: `pt` or `trt`." 263 | ) 264 | parser.add_argument( 265 | "profile_mode", type=str, help="Profiling mode: `latency` or `metrics`." 266 | ) 267 | parser.add_argument("output_suffix", type=str, help="Suffix to output json") 268 | parser.add_argument( 269 | "n_ensemble", type=int, help="Number of ensemble component." 270 | ) 271 | args = parser.parse_args() 272 | run( 273 | args.model_dir, 274 | args.model_suffix, 275 | args.profile_mode, 276 | args.output_suffix, 277 | args.n_ensemble, 278 | ) 279 | except KeyboardInterrupt: 280 | sys.exit(0) 281 | -------------------------------------------------------------------------------- /URSABench/experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import warnings 5 | 6 | import torch 7 | 8 | from URSABench import models, inference, tasks, datasets, util 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)') 15 | parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH', 16 | help='path to datasets location (default: None)') 17 | parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)') 18 | parser.add_argument('--num_trials', type=int, default=1, help='number of repeats of each experiment (default: 1)') 19 | parser.add_argument('--model', type=str, default=None, required=True, 
metavar='MODEL', 20 | help='model name (default: None)') 21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') 22 | parser.add_argument('--inference_method', type=str, default='HMC', help='Inference Method (default: HMC)') 23 | parser.add_argument('--hyperparams', type=str, default=None, help='Hyperparameters in JSON format (default:None)') 24 | parser.add_argument('--hyperparams_path', default=None, help="Path to json file containing hyperparams", 25 | type=lambda x: util.json_open_from_file(parser, x)) 26 | parser.add_argument('--task', type=str, default='Prediction', help='Downstream task to evaluate (default: Prediction)') 27 | parser.add_argument('--split_classes', type=int, default=None) 28 | parser.add_argument('--validation', type=float, default=0.2, 29 | help='Proportation of training used as validation (default: Prediction)') 30 | parser.add_argument('--use_val', dest = 'use_val', action='store_true', help='use val dataset instead of test (default: False)') 31 | parser.add_argument('--use_dm_imbalance', dest = 'use_dm_imbalance', action='store_true', help='use imbalance data set for DM (retrain) (default: False)') 32 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)') 33 | parser.add_argument('--save_path', type=str, default=None, required=True, metavar='PATH', 34 | help='path to file to store results (default: None)') 35 | parser.add_argument('--device_num', type=int, default=0, help='Device number to select (default: 0)') 36 | parser.add_argument('--pretrained_model_path', type=str, default=None, help='Path of the pretrained model') 37 | 38 | 39 | args = parser.parse_args() 40 | util.set_random_seed(args.seed) 41 | if torch.cuda.is_available(): 42 | args.device = torch.device('cuda') 43 | torch.cuda.set_device(args.device_num) 44 | else: 45 | args.device = torch.device('cpu') 46 | 47 | # import pdb; pdb.set_trace() 48 | if args.hyperparams is 
None: 49 | hyperparams = args.hyperparams_path 50 | else: 51 | hyperparams = json.loads(args.hyperparams) 52 | model_cfg = getattr(models, args.model) 53 | loaders, num_classes = datasets.loaders( 54 | args.dataset, 55 | args.data_path, 56 | args.batch_size, 57 | args.num_workers, 58 | transform_train=model_cfg.transform_train, 59 | transform_test=model_cfg.transform_test, 60 | shuffle_train=True, 61 | use_validation=args.use_val, 62 | val_size=args.validation, 63 | split_classes=args.split_classes 64 | ) 65 | train_loader = loaders['train'] 66 | test_loader = loaders['test'] 67 | # loaders['train'].dataset.data = loaders['train'].dataset.data[:30] 68 | # loaders['train'].dataset.targets = loaders['train'].dataset.targets[:30] 69 | 70 | num_classes = int(num_classes) 71 | model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs).to(args.device) 72 | if args.pretrained_model_path is not None: 73 | model.load_state_dict(torch.load(args.pretrained_model_path)) 74 | inference_method = getattr(inference, args.inference_method) 75 | inference_object = inference_method(hyperparameters=hyperparams, model=model, train_loader=train_loader, 76 | device=args.device) 77 | # model_ensemble = inference_object.sample() 78 | # 79 | # for model_idx, model in enumerate(model_ensemble): 80 | # torch.save(model.state_dict(), args.save_path + 'sghmc_sample_%d.pt' % model_idx) 81 | 82 | task_method = getattr(tasks, args.task) 83 | task_data_loader = {'in_distribution_test': test_loader} 84 | task_data_loader_dm = {'decision_data_test': test_loader} # For if we use no imbalance 85 | metric_list = 'ALL' 86 | 87 | # IF PART OF HYPOPT 88 | if args.task == 'Prediction' and args.use_val: 89 | task_object = task_method(dataloader=task_data_loader, num_classes=num_classes, device=args.device, 90 | metric_list=metric_list) 91 | task_object.update_statistics(models=model_ensemble, output_performance=False, smoothing=True) 92 | task_performance = 
task_object.get_performance_metrics() 93 | hyperparam_values = [hyperparams[key] for key in sorted(hyperparams.keys())] 94 | print(sorted(hyperparams.keys())) 95 | task_performance_values = [task_performance[key] for key in sorted(task_performance.keys())] 96 | print(sorted(task_performance.keys())) 97 | with open('results.csv', 'a+') as csvFile: 98 | writer = csv.writer(csvFile, dialect='excel') 99 | writer.writerow([ 100 | args.dataset, 101 | args.model, 102 | args.seed, 103 | args.inference_method, 104 | args.task, 105 | args.batch_size, 106 | *hyperparam_values, 107 | *task_performance_values 108 | ]) 109 | 110 | # IF PART OF TESTING 111 | S = args.num_trials # Number of trials (Random Seeds) 112 | OOD_loaders_list = [] 113 | 114 | if not args.use_val: 115 | if args.dataset == 'MNIST': 116 | # OOD 117 | data_name = ['FashionMNIST', 'KMNIST'] 118 | for d_name in data_name: 119 | loaders, _ = datasets.loaders( 120 | d_name, 121 | args.data_path + d_name, 122 | args.batch_size, 123 | args.num_workers, 124 | transform_train=model_cfg.transform_train, 125 | transform_test=model_cfg.transform_test, 126 | shuffle_train=True, 127 | use_validation=False, 128 | val_size=args.validation, 129 | split_classes=args.split_classes 130 | ) 131 | ood_d = {} 132 | ood_d['data'] = d_name 133 | ood_d['in_distribution_test'] = task_data_loader['in_distribution_test'] 134 | ood_d['out_distribution_test'] = loaders['test'] 135 | OOD_loaders_list.append(ood_d) 136 | 137 | elif args.dataset == 'CIFAR10' or args.dataset == 'CIFAR100': 138 | # OOD 139 | 140 | data_name = ['STL10', 'SVHN'] 141 | for d_name in data_name: 142 | loaders, _ = datasets.loaders( 143 | d_name, 144 | args.data_path + d_name, 145 | args.batch_size, 146 | args.num_workers, 147 | transform_train=model_cfg.transform_train, 148 | transform_test=model_cfg.transform_test, 149 | shuffle_train=True, 150 | use_validation=False, 151 | val_size=args.validation, 152 | split_classes=args.split_classes 153 | ) 154 | ood_d = {} 
155 | ood_d['data'] = d_name 156 | ood_d['in_distribution_test'] = task_data_loader['in_distribution_test'] 157 | ood_d['out_distribution_test'] = loaders['test'] 158 | OOD_loaders_list.append(ood_d) 159 | 160 | elif args.dataset == 'TIN': 161 | # TODO ADD TASK 162 | pass 163 | else: 164 | raise NotImplementedError 165 | 166 | results_dic = {} 167 | temp_dic = {} 168 | cost_list = [] 169 | for s in range(S): 170 | util.set_random_seed(s) 171 | print('Prediction: ',s) 172 | inference_object = inference_method(hyperparameters=hyperparams, model=model, train_loader=train_loader, 173 | device=args.device) 174 | model_ensemble = inference_object.sample() 175 | 176 | # Prediction 177 | task_object = task_method(dataloader=task_data_loader, num_classes=num_classes, device=args.device, metric_list=metric_list) 178 | task_object.update_statistics(models=model_ensemble, output_performance=False, smoothing=True) 179 | task_performance = task_object.get_performance_metrics() 180 | if not args.use_dm_imbalance and not args.dataset == 'TIN': 181 | print('Running DM task on balanced data: ',s) 182 | dec_object = tasks.Decision(dataloader=task_data_loader_dm, num_classes=num_classes, device=args.device) 183 | dec_object.update_statistics(models=model_ensemble, output_performance=False, smoothing=True) 184 | dec_result = dec_object.get_performance_metrics() 185 | cost_list.append(dec_result['True_Cost']) 186 | 187 | if args.dataset == 'TIN': 188 | pass 189 | else: 190 | print('OOD: ',s) 191 | for ood_data_loader in OOD_loaders_list: 192 | ood_object = tasks.OODDetection(data_loader=ood_data_loader, num_classes=num_classes, 193 | device=args.device) 194 | dic_ood = ood_object.update_statistics(model_ensemble, output_performance=True) 195 | 196 | for i, key in enumerate(dic_ood.keys()): 197 | if s == 0: 198 | temp_dic[key] = [dic_ood[key]] 199 | else: 200 | temp_dic[key].append(dic_ood[key]) 201 | if s == S-1: 202 | results_dic[key + '_'+ ood_data_loader['data'] +'_mean'] = 
torch.mean(torch.tensor(temp_dic[key]).float()) 203 | results_dic[key + '_'+ ood_data_loader['data'] +'_std'] = torch.std(torch.tensor(temp_dic[key]).float()) 204 | 205 | 206 | for i, key in enumerate(task_object.required_metric_list): 207 | if s == 0: 208 | temp_dic[key] = [task_performance[key]] 209 | else: 210 | temp_dic[key].append(task_performance[key]) 211 | if s == S-1: 212 | results_dic[key + '_mean'] = torch.mean(torch.tensor(temp_dic[key]).float()) 213 | results_dic[key + '_std'] = torch.std(torch.tensor(temp_dic[key]).float()) 214 | if not args.dataset == 'TIN': 215 | if args.use_dm_imbalance: 216 | # Decision Making: 217 | cost_list = [] 218 | for s in range(S): 219 | print('Decision Making SEED: ',s) 220 | util.set_random_seed(s) 221 | loaders, num_classes = datasets.loaders( 222 | args.dataset, 223 | args.data_path, 224 | args.batch_size, 225 | args.num_workers, 226 | transform_train=model_cfg.transform_train, 227 | transform_test=model_cfg.transform_test, 228 | shuffle_train=True, 229 | use_validation=False, 230 | val_size=args.validation, 231 | split_classes=args.split_classes, 232 | imbalance=True 233 | ) 234 | train_loader = loaders['train'] 235 | test_loader = loaders['test'] 236 | inference_method = getattr(inference, args.inference_method) 237 | inference_object = inference_method(hyperparameters=hyperparams, model=model, train_loader=train_loader, 238 | device=args.device) 239 | model_ensemble = inference_object.sample() 240 | task_data_loader = {'decision_data_test': test_loader} 241 | dec_object = tasks.Decision(dataloader=task_data_loader, num_classes=num_classes, device=args.device) 242 | dec_object.update_statistics(models=model_ensemble, output_performance=False, smoothing=True) 243 | dec_result = dec_object.get_performance_metrics() 244 | cost_list.append(dec_result['True_Cost']) 245 | 246 | results_dic['cost_mean'] = torch.mean(torch.tensor(cost_list)) 247 | results_dic['cost_std'] = torch.std(torch.tensor(cost_list)) 248 | 249 | 
hyperparam_values = [hyperparams[key] for key in sorted(hyperparams.keys())] 250 | task_performance_values = [results_dic[key] for key in sorted(results_dic.keys())] 251 | print(sorted(results_dic.keys())) 252 | with open(args.save_path + 'results.csv', 'a+') as csvFile: 253 | writer = csv.writer(csvFile, dialect='excel') 254 | writer.writerow([ 255 | args.dataset, 256 | args.model, 257 | args.seed, 258 | args.inference_method, 259 | args.task, 260 | args.batch_size, 261 | *hyperparam_values, 262 | *task_performance_values 263 | ]) 264 | 265 | print(results_dic) 266 | torch.save(results_dic, args.save_path + '_tests.npy') 267 | --------------------------------------------------------------------------------