├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── configs ├── config.yaml ├── dataset │ ├── cifar10.yaml │ ├── cifar10_embedded.yaml │ ├── fashionmnist.yaml │ ├── fashionmnist_embedded.yaml │ ├── gaussian2d.yaml │ ├── genomics.yaml │ ├── genomics_embedded.yaml │ ├── segment.yaml │ └── sensorless.yaml ├── grid │ ├── seed-bottleneck.yaml │ ├── seed-clf_weight.yaml │ ├── seed-pyxce.yaml │ └── seed.yaml ├── model │ ├── fc_ce_baseline.yaml │ ├── fc_mcmc.yaml │ ├── fc_mcmc_supervised.yaml │ ├── fc_norm_flow.yaml │ ├── fc_ssm.yaml │ ├── fc_ssm_supervised.yaml │ ├── fc_vera.yaml │ ├── fc_vera_supervised.yaml │ ├── genomics_autoregressive_density.yaml │ ├── genomics_ce_baseline.yaml │ ├── genomics_mcmc.yaml │ ├── genomics_mcmc_supervised.yaml │ ├── genomics_vera.yaml │ ├── genomics_vera_supervised.yaml │ ├── img_autoencoder.yaml │ ├── img_ce_baseline.yaml │ ├── img_glow.yaml │ ├── img_mcmc.yaml │ ├── img_mcmc_supervised.yaml │ ├── img_real_nvp.yaml │ ├── img_ssm.yaml │ ├── img_ssm_supervised.yaml │ ├── img_vera.yaml │ ├── img_vera_posteriornet.yaml │ ├── img_vera_priornet.yaml │ └── img_vera_supervised.yaml └── seml_config.yaml ├── density_histograms.png ├── req.txt └── uncertainty_est ├── __init__.py ├── archs ├── __init__.py ├── arch_factory.py ├── fc.py ├── flows.py ├── glow │ ├── __init__.py │ ├── act_norm.py │ ├── coupling.py │ ├── glow.py │ └── inv_conv.py ├── real_nvp │ ├── __init__.py │ ├── coupling_layer.py │ ├── real_nvp.py │ ├── resnet.py │ └── util.py ├── resnet.py └── wrn.py ├── data ├── __init__.py ├── dataloaders.py └── datasets.py ├── evaluate.py ├── models ├── __init__.py ├── ce_baseline.py ├── ebm │ ├── __init__.py │ ├── conditional_nce.py │ ├── discrete_mcmc.py │ ├── flow_contrastive_estimation.py │ ├── hdge.py │ ├── mcmc.py │ ├── nce.py │ ├── ssm.py │ ├── utils │ │ ├── __init__.py │ │ ├── model.py │ │ ├── utils.py │ │ └── vera_utils.py │ ├── vera.py │ ├── vera_posteriornet.py │ └── vera_priornet.py ├── energy_finetuning.py ├── normalizing_flow 
│ ├── __init__.py │ ├── approx_flow.py │ ├── image_flows.py │ ├── iresnet.py │ └── norm_flow.py └── ood_detection_model.py ├── train.py └── utils ├── dirichlet.py ├── metrics.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | 4 | logs 5 | grid_search 6 | slurm 7 | /_cache_datalsuntestlmdb 8 | notebooks 9 | tmp 10 | temp 11 | thesis_logs 12 | ebm_investigation_logs 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/ambv/black 9 | rev: 20.8b1 10 | hooks: 11 | - id: black 12 | language_version: python3.6 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On Out-of-distribution Detection with Energy-based Models 2 | 3 | This repository contains the code for the experiments conducted in the paper 4 | 5 | > [On Out-of-distribution Detection with Energy-based Models](https://arxiv.org/abs/2107.08785) \ 6 | Sven Elflein, Bertrand Charpentier, Daniel Zügner, Stephan Günnemann \ 7 | ICML 2021, Workshop on Uncertainty & Robustness in Deep Learning. 8 | 9 | 10 |

11 | 12 |

13 | 14 | 15 | ## Setup 16 | 17 | ``` 18 | conda create --name env --file req.txt 19 | conda activate env 20 | pip install git+https://github.com/selflein/nn_uncertainty_eval 21 | ``` 22 | 23 | ### Datasets 24 | The image datasets should download automatically. For "Sensorless Drive" and "Segment", pre-processed .csv files can be downloaded from the [PostNet repo](https://github.com/sharpenb/Posterior-Network#training--evaluation) under "Training & Evaluation". 25 | 26 | ## Training & Evaluation 27 | 28 | In order to train a model, use the respective combination of configurations for dataset and model, e.g., 29 | 30 | ``` 31 | python uncertainty_est/train.py fixed.output_folder=./path/to/output/folder dataset=sensorless model=fc_mcmc 32 | ``` 33 | 34 | to train an EBM with MCMC on the Sensorless dataset. See `configs/model` for all model configurations. 35 | 36 | In order to evaluate models, use 37 | 38 | ``` 39 | python uncertainty_est/evaluate.py --checkpoint-dir ./path/to/directory/with/models --output-folder ./path/to/output/folder 40 | ``` 41 | 42 | This script generates CSVs with the respective OOD metrics. 43 | 44 | ## Cite 45 | 46 | If you find our work helpful, please consider citing our paper in your own work. 
47 | 48 | ``` 49 | @misc{elflein2021outofdistribution, 50 | title={On Out-of-distribution Detection with Energy-based Models}, 51 | author={Sven Elflein and Bertrand Charpentier and Daniel Zügner and Stephan Günnemann}, 52 | year={2021}, 53 | eprint={2107.08785}, 54 | archivePrefix={arXiv}, 55 | primaryClass={cs.LG} 56 | } 57 | ``` 58 | 59 | ## Acknowledgements 60 | 61 | * RealNVP from https://github.com/chrischute/real-nvp 62 | * Glow from https://github.com/chrischute/glow 63 | * JEM from https://github.com/wgrathwohl/JEM 64 | * VERA from https://github.com/wgrathwohl/VERA 65 | * SSM from https://github.com/ermongroup/sliced_score_matching 66 | * WideResNet from https://github.com/meliketoy/wide-resnet.pytorch 67 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: . 4 | output_subdir: 5 | 6 | seml: 7 | executable: uncertainty_est/train.py 8 | output_dir: slurm 9 | project_root_dir: .. 10 | 11 | slurm: 12 | experiments_per_job: 1 13 | sbatch_options: 14 | gres: gpu:1 # num GPUs 15 | mem: 16G # memory 16 | cpus-per-task: 4 # num cores 17 | time: 0-05:00 # max time, D-HH:MM 18 | 19 | fixed: 20 | trainer_config: 21 | gpus: 1 22 | benchmark: True 23 | log_dir: . 
24 | output_folder: 25 | ood_dataset: 26 | seed: 1 27 | 28 | model_config: 29 | data_shape: ${fixed.data_shape} 30 | 31 | defaults: 32 | - hydra/job_logging: stdout 33 | - model: fc_mcmc_supervised 34 | - dataset: segment 35 | - model/updates: ${defaults.0.model}_${defaults.1.dataset} 36 | optional: True 37 | - grid: seed 38 | -------------------------------------------------------------------------------- /configs/dataset/cifar10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 1 5 | sbatch_options: 6 | time: 0-30:00 7 | 8 | fixed: 9 | trainer_config: 10 | max_epochs: 50 11 | 12 | checkpoint_config: 13 | monitor: val/ood 14 | mode: max 15 | 16 | dataset: cifar10 17 | num_classes: 10 18 | batch_size: 32 19 | data_shape: [32, 32, 3] 20 | 21 | test_ood_datasets: 22 | - lsun 23 | - textures 24 | - cifar100 25 | - svhn 26 | - celeb-a 27 | - uniform_noise 28 | - gaussian_noise 29 | - constant 30 | - svhn_unscaled 31 | 32 | model_config: 33 | ood_val_datasets: 34 | - celeb-a 35 | - cifar100 36 | -------------------------------------------------------------------------------- /configs/dataset/cifar10_embedded.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 5 5 | sbatch_options: 6 | time: 0-05:00 7 | 8 | fixed: 9 | dataset: cifar10_embedded 10 | num_classes: 10 11 | batch_size: 512 12 | data_shape: [640] 13 | 14 | checkpoint_config: 15 | monitor: val/ood 16 | mode: max 17 | 18 | test_ood_datasets: 19 | - lsun_embedded 20 | - textures_embedded 21 | - cifar100_embedded 22 | - svhn_embedded 23 | - celeb-a_embedded 24 | - cifar10_uniform_noise_embedded 25 | - cifar10_gaussian_noise_embedded 26 | - cifar10_constant_embedded 27 | - svhn_unscaled_embedded 28 | 29 | model_config: 30 | ood_val_datasets: 31 | - celeb-a_embedded 32 | - cifar100_embedded 33 | 
-------------------------------------------------------------------------------- /configs/dataset/fashionmnist.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 1 5 | sbatch_options: 6 | time: 0-10:00 7 | 8 | fixed: 9 | dataset: fashionmnist 10 | num_classes: 10 11 | batch_size: 64 12 | data_shape: [32, 32, 1] 13 | 14 | trainer_config: 15 | max_epochs: 50 16 | 17 | checkpoint_config: 18 | monitor: val/ood 19 | mode: max 20 | save_last: True 21 | 22 | test_ood_datasets: 23 | - mnist 24 | - notmnist 25 | - kmnist 26 | - gaussian_noise 27 | - uniform_noise 28 | - constant 29 | - kmnist_unscaled 30 | 31 | model_config: 32 | ood_val_datasets: 33 | - mnist 34 | - kmnist 35 | -------------------------------------------------------------------------------- /configs/dataset/fashionmnist_embedded.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 5 5 | 6 | fixed: 7 | dataset: fashionmnist_embedded 8 | num_classes: 10 9 | batch_size: 128 10 | data_shape: [640] 11 | 12 | checkpoint_config: 13 | monitor: val/ood 14 | mode: max 15 | 16 | test_ood_datasets: 17 | - mnist_embedded 18 | - notmnist_embedded 19 | - kmnist_embedded 20 | - fashionmnist_gaussian_noise_embedded 21 | - fashionmnist_uniform_noise_embedded 22 | - fashionmnist_constant_embedded 23 | 24 | model_config: 25 | ood_val_datasets: 26 | - mnist_embedded 27 | - kmnist_embedded 28 | -------------------------------------------------------------------------------- /configs/dataset/gaussian2d.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 5 5 | sbatch_options: 6 | time: 0-05:00 7 | 8 | fixed: 9 | dataset: Gaussian2D 10 | num_classes: 3 11 | batch_size: 512 12 | data_shape: 13 | - 2 14 | 15 | model_config: 16 | 
is_toy_dataset: True 17 | 18 | test_ood_datasets: 19 | - AnomalousGaussian2D 20 | -------------------------------------------------------------------------------- /configs/dataset/genomics.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 1 5 | sbatch_options: 6 | time: 0-30:00 7 | mem: 32G 8 | 9 | fixed: 10 | dataset: genomics 11 | data_shape: [250] 12 | num_cat: 4 13 | num_classes: 10 14 | 15 | test_ood_datasets: 16 | - genomics-ood 17 | - genomics-noise 18 | 19 | model_config: 20 | ood_val_datasets: 21 | - genomics-ood 22 | 23 | checkpoint_config: 24 | monitor: val/ood 25 | mode: max 26 | save_last: True 27 | 28 | trainer_config: 29 | max_steps: 100_000 30 | max_epochs: 100_000 31 | limit_val_batches: 1000 32 | val_check_interval: 10_000 33 | limit_test_batches: 10_000 34 | -------------------------------------------------------------------------------- /configs/dataset/genomics_embedded.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 5 5 | sbatch_options: 6 | time: 0-10:00 7 | mem: 32G 8 | 9 | fixed: 10 | dataset: genomics_embedded 11 | batch_size: 512 12 | data_shape: 13 | - 128 14 | num_classes: 10 15 | 16 | checkpoint_config: 17 | monitor: val/ood 18 | mode: max 19 | 20 | test_ood_datasets: 21 | - genomics-ood_embedded 22 | - genomics-noise_embedded 23 | 24 | trainer_config: 25 | max_epochs: 50_000 26 | max_steps: 50_000 27 | 28 | model_config: 29 | ood_val_datasets: 30 | - genomics-ood_embedded 31 | -------------------------------------------------------------------------------- /configs/dataset/segment.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 5 5 | sbatch_options: 6 | time: 0-05:00 7 | 8 | fixed: 9 | dataset: segment 10 | num_classes: 6 11 | 
batch_size: 512 12 | data_shape: 13 | - 18 14 | 15 | checkpoint_config: 16 | monitor: val/ood 17 | mode: max 18 | 19 | test_ood_datasets: 20 | - segment-ood 21 | - uniform_noise 22 | - gaussian_noise 23 | - constant 24 | 25 | trainer_config: 26 | max_epochs: 10_000 27 | max_steps: 10_000 28 | check_val_every_n_epoch: 10 29 | 30 | model_config: 31 | ood_val_datasets: 32 | - segment-ood 33 | -------------------------------------------------------------------------------- /configs/dataset/sensorless.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | slurm: 4 | experiments_per_job: 5 5 | sbatch_options: 6 | time: 0-05:00 7 | 8 | fixed: 9 | dataset: sensorless 10 | num_classes: 9 11 | batch_size: 512 12 | data_shape: [48] 13 | 14 | test_ood_datasets: 15 | - sensorless-ood 16 | - uniform_noise 17 | - gaussian_noise 18 | - constant 19 | 20 | checkpoint_config: 21 | monitor: val/ood 22 | mode: max 23 | save_last: True 24 | 25 | trainer_config: 26 | max_epochs: 10_000 27 | max_steps: 10_000 28 | check_val_every_n_epoch: 10 29 | 30 | model_config: 31 | ood_val_datasets: 32 | - sensorless-ood 33 | -------------------------------------------------------------------------------- /configs/grid/seed-bottleneck.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | seed: 4 | type: choice 5 | options: 6 | - 2525 7 | - 2123 8 | - 293993 9 | - 2324234 10 | - 6566634 11 | 12 | model_config.arch_config.bottleneck_channels_factor: 13 | type: choice 14 | options: 15 | - 0.05 16 | - 0.1 17 | - 0.2 18 | -------------------------------------------------------------------------------- /configs/grid/seed-clf_weight.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | seed: 4 | type: choice 5 | options: 6 | - 2525 7 | - 2123 8 | - 293993 9 | - 2324234 10 | - 6566634 11 | 12 | model_config.clf_weight: 
13 | type: loguniform 14 | min: 0.01 15 | max: 100 16 | num: 5 17 | -------------------------------------------------------------------------------- /configs/grid/seed-pyxce.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | seed: 4 | type: choice 5 | options: 6 | - 2525 7 | - 2123 8 | - 293993 9 | - 2324234 10 | - 6566634 11 | 12 | model_config.pyxce: 13 | type: loguniform 14 | min: 0.01 15 | max: 100 16 | num: 5 17 | -------------------------------------------------------------------------------- /configs/grid/seed.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | seed: 4 | type: choice 5 | options: 6 | - 2525 7 | - 2123 8 | - 293993 9 | - 2324234 10 | - 6566634 11 | -------------------------------------------------------------------------------- /configs/model/fc_ce_baseline.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 10_000 5 | max_steps: 10_000 6 | check_val_every_n_epoch: 100 7 | 8 | checkpoint_config: 9 | monitor: val/acc 10 | mode: max 11 | 12 | earlystop_config: 13 | monitor: val/acc 14 | mode: max 15 | patience: 10 16 | 17 | ood_dataset: 18 | 19 | model_name: CEBaseline 20 | model_config: 21 | arch_name: fc 22 | arch_config: 23 | inp_dim: ${fixed.data_shape.0} 24 | num_classes: ${fixed.num_classes} 25 | hidden_dims: [100, 100, 100, 100, 100] 26 | learning_rate: 0.001 27 | momentum: 0.9 28 | weight_decay: 0.0005 29 | -------------------------------------------------------------------------------- /configs/model/fc_mcmc.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 10_000 5 | max_steps: 10_000 6 | 7 | checkpoint_config: 8 | save_last: True 9 | 10 | earlystop_config: 11 | 12 | ood_dataset: ${fixed.dataset} 13 | 14 | 
model_name: JEM 15 | model_config: 16 | arch_name: fc 17 | arch_config: 18 | inp_dim: ${fixed.data_shape.0} 19 | num_classes: 1 20 | hidden_dims: [100, 100, 100, 100, 100] 21 | 22 | learning_rate: 0.001 23 | momentum: 0.9 24 | weight_decay: 0.0 25 | smoothing: 0.0 26 | sgld_lr: 1.0 27 | sgld_std: 0.01 28 | sgld_steps: 100 29 | pyxce: 0.0 30 | pxsgld: 1.0 31 | pxysgld: 0.0 32 | buffer_size: 9000 33 | reinit_freq: 0.05 34 | data_shape: ${fixed.data_shape} 35 | sgld_batch_size: ${fixed.batch_size} 36 | class_cond_p_x_sample: False 37 | n_classes: 1 38 | warmup_steps: 2500 39 | entropy_reg_weight: 0.0 40 | lr_step_size: 1000 41 | -------------------------------------------------------------------------------- /configs/model/fc_mcmc_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 10_000 5 | max_steps: 10_000 6 | 7 | earlystop_config: 8 | 9 | ood_dataset: ${fixed.dataset} 10 | 11 | model_name: JEM 12 | model_config: 13 | arch_name: fc 14 | arch_config: 15 | inp_dim: ${fixed.data_shape.0} 16 | num_classes: ${fixed.num_classes} 17 | hidden_dims: [100, 100, 100, 100, 100] 18 | 19 | learning_rate: 0.001 20 | momentum: 0.9 21 | weight_decay: 0.0 22 | smoothing: 0.0 23 | sgld_lr: 1.0 24 | sgld_std: 0.01 25 | sgld_steps: 100 26 | pyxce: 1.0 27 | pxsgld: 1.0 28 | pxysgld: 0.0 29 | buffer_size: 9000 30 | reinit_freq: 0.05 31 | data_shape: ${fixed.data_shape} 32 | sgld_batch_size: ${fixed.batch_size} 33 | class_cond_p_x_sample: True 34 | n_classes: ${fixed.num_classes} 35 | warmup_steps: 2500 36 | entropy_reg_weight: 0.0 37 | lr_step_size: 1000 38 | -------------------------------------------------------------------------------- /configs/model/fc_norm_flow.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | checkpoint_config: 4 | monitor: val/loss 5 | mode: min 6 | 7 | earlystop_config: 8 | monitor: val/loss 
9 | mode: min 10 | patience: 10 11 | 12 | ood_dataset: 13 | 14 | model_name: NormalizingFlow 15 | model_config: 16 | arch_name: normalizing_flow 17 | arch_config: 18 | flow_type: radial_flow 19 | dim: ${fixed.data_shape.0} 20 | flow_length: 20 21 | learning_rate: 0.001 22 | momentum: 0.9 23 | weight_decay: 0.0 24 | -------------------------------------------------------------------------------- /configs/model/fc_ssm.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 10_000 5 | max_steps: 10_000 6 | 7 | checkpoint_config: 8 | save_last: True 9 | 10 | earlystop_config: 11 | 12 | ood_dataset: ${fixed.dataset} 13 | 14 | model_name: SSM 15 | model_config: 16 | arch_name: fc 17 | arch_config: 18 | inp_dim: ${fixed.data_shape.0} 19 | num_classes: 1 20 | hidden_dims: [100, 100, 100, 100, 100] 21 | 22 | learning_rate: 0.001 23 | momentum: 0.9 24 | weight_decay: 0.0 25 | clf_weight: 0.0 26 | n_classes: 1 27 | n_particles: 1 28 | warmup_steps: 2500 29 | lr_step_size: 1000 30 | -------------------------------------------------------------------------------- /configs/model/fc_ssm_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 10_000 5 | max_steps: 10_000 6 | 7 | earlystop_config: 8 | 9 | ood_dataset: ${fixed.dataset} 10 | 11 | model_name: SSM 12 | model_config: 13 | arch_name: fc 14 | arch_config: 15 | inp_dim: ${fixed.data_shape.0} 16 | num_classes: ${fixed.num_classes} 17 | hidden_dims: [100, 100, 100, 100, 100] 18 | 19 | learning_rate: 0.001 20 | momentum: 0.9 21 | weight_decay: 0.0 22 | clf_weight: 1.0 23 | n_classes: ${fixed.num_classes} 24 | n_particles: 1 25 | warmup_steps: 2500 26 | lr_step_size: 1000 27 | -------------------------------------------------------------------------------- /configs/model/fc_vera.yaml: 
-------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_steps: 10_000 5 | max_epochs: 10_000 6 | 7 | checkpoint_config: 8 | save_last: True 9 | 10 | earlystop_config: 11 | 12 | ood_dataset: ${fixed.dataset} 13 | 14 | model_name: VERA 15 | model_config: 16 | arch_name: fc 17 | arch_config: 18 | inp_dim: ${fixed.data_shape.0} 19 | num_classes: 1 20 | hidden_dims: [100, 100, 100, 100, 100] 21 | batch_norm: False 22 | bias: True 23 | slope: 0.2 24 | 25 | learning_rate: 0.00003 26 | beta1: 0.0 27 | beta2: 0.9 28 | weight_decay: 0.0 29 | n_classes: 1 30 | uncond: False 31 | gen_learning_rate: 0.00006 32 | ebm_iters: 1 33 | generator_iters: 1 34 | entropy_weight: 0.0001 35 | 36 | generator_type: vera 37 | generator_arch_name: fc 38 | generator_config: 39 | noise_dim: 16 40 | post_lr: 0.00003 41 | init_post_logsigma: 0.1 42 | generator_arch_config: 43 | inp_dim: ${fixed.model_config.generator_config.noise_dim} 44 | num_classes: ${fixed.data_shape.0} 45 | hidden_dims: ${fixed.model_config.arch_config.hidden_dims} 46 | batch_norm: True 47 | bias: False 48 | 49 | min_sigma: 0.01 50 | max_sigma: 0.3 51 | p_control: 0.0 52 | n_control: 0.0 53 | pg_control: 0.1 54 | clf_ent_weight: 0.0 55 | ebm_type: p_x 56 | clf_weight: 0.0 57 | warmup_steps: 2500 58 | no_g_batch_norm: False 59 | batch_size: ${fixed.batch_size} 60 | lr_decay: 0.3 61 | lr_decay_epochs: [3000, 4000] 62 | -------------------------------------------------------------------------------- /configs/model/fc_vera_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_steps: 10_000 5 | max_epochs: 10_000 6 | 7 | earlystop_config: 8 | 9 | ood_dataset: ${fixed.dataset} 10 | 11 | model_name: VERA 12 | model_config: 13 | arch_name: fc 14 | arch_config: 15 | inp_dim: ${fixed.data_shape.0} 16 | num_classes: ${fixed.num_classes} 17 | hidden_dims: [100, 100, 
100, 100, 100] 18 | batch_norm: False 19 | slope: 0.2 20 | 21 | learning_rate: 0.001 22 | beta1: 0.0 23 | beta2: 0.9 24 | weight_decay: 0.0 25 | n_classes: ${fixed.num_classes} 26 | uncond: False 27 | gen_learning_rate: 0.001 28 | ebm_iters: 1 29 | generator_iters: 1 30 | entropy_weight: 0.0001 31 | 32 | generator_type: vera 33 | generator_arch_name: fc 34 | generator_config: 35 | noise_dim: 16 36 | post_lr: 0.00003 37 | init_post_logsigma: 0.1 38 | generator_arch_config: 39 | inp_dim: ${fixed.model_config.generator_config.noise_dim} 40 | num_classes: ${fixed.data_shape.0} 41 | hidden_dims: ${fixed.model_config.arch_config.hidden_dims} 42 | activation: relu 43 | batch_norm: True 44 | bias: False 45 | 46 | min_sigma: 0.01 47 | max_sigma: 0.3 48 | p_control: 0.0 49 | n_control: 0.0 50 | pg_control: 0.1 51 | clf_ent_weight: 0.0 52 | ebm_type: jem 53 | clf_weight: 1.0 54 | warmup_steps: 2500 55 | no_g_batch_norm: False 56 | batch_size: ${fixed.batch_size} 57 | lr_decay: 0.3 58 | lr_decay_epochs: [3000, 4000] 59 | -------------------------------------------------------------------------------- /configs/model/genomics_autoregressive_density.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | checkpoint_config: 4 | monitor: val/loss 5 | mode: min 6 | 7 | earlystop_config: 8 | monitor: val/loss 9 | mode: min 10 | patience: 10 11 | 12 | batch_size: 512 13 | 14 | model_name: NormalizingFlow 15 | model_config: 16 | arch_name: seq_generative_model 17 | arch_config: 18 | input_size: ${fixed.num_cat} 19 | hidden_size: 128 20 | num_classes: ${fixed.num_classes} 21 | learning_rate: 0.005 22 | momentum: 0.9 23 | weight_decay: 0.0005 24 | -------------------------------------------------------------------------------- /configs/model/genomics_ce_baseline.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | batch_size: 512 4 | 5 | earlystop_config: 6 | 
monitor: val/acc 7 | mode: max 8 | patience: 10 9 | 10 | model_name: CEBaseline 11 | model_config: 12 | arch_name: seq_classifier 13 | arch_config: 14 | in_channels: ${fixed.num_cat} 15 | num_filters: 128 16 | fc_hidden_size: 1000 17 | num_classes: 10 18 | kernel_size: 20 19 | learning_rate: 0.0001 20 | momentum: 0.9 21 | weight_decay: 0.0005 22 | -------------------------------------------------------------------------------- /configs/model/genomics_mcmc.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | batch_size: 512 4 | 5 | checkpoint_config: 6 | 7 | earlystop_config: 8 | 9 | ood_dataset: ${fixed.dataset} 10 | 11 | model_name: DiscreteMCMC 12 | model_config: 13 | num_cat: ${fixed.num_cat} 14 | 15 | arch_name: seq_classifier 16 | arch_config: 17 | in_channels: ${fixed.num_cat} 18 | num_filters: 128 19 | fc_hidden_size: 1000 20 | num_classes: 1 21 | kernel_size: 20 22 | 23 | learning_rate: 0.0001 24 | momentum: 0.9 25 | weight_decay: 0.0 26 | smoothing: 0.0 27 | sgld_steps: 40 28 | pyxce: 0.0 29 | pxsgld: 1.0 30 | pxysgld: 0.0 31 | buffer_size: 9999 32 | reinit_freq: 0.0 33 | data_shape: ${fixed.data_shape} 34 | sgld_batch_size: ${fixed.batch_size} 35 | class_cond_p_x_sample: False 36 | n_classes: 1 37 | warmup_steps: 2500 38 | entropy_reg_weight: 0.0 39 | lr_step_size: 1000 40 | -------------------------------------------------------------------------------- /configs/model/genomics_mcmc_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | batch_size: 512 4 | 5 | earlystop_config: 6 | 7 | ood_dataset: ${fixed.dataset} 8 | 9 | model_name: DiscreteMCMC 10 | model_config: 11 | num_cat: ${fixed.num_cat} 12 | 13 | arch_name: seq_classifier 14 | arch_config: 15 | in_channels: ${fixed.num_cat} 16 | num_filters: 128 17 | fc_hidden_size: 1000 18 | num_classes: ${fixed.num_classes} 19 | kernel_size: 20 20 | 21 | learning_rate: 0.0001 
22 | momentum: 0.9 23 | weight_decay: 0.0 24 | smoothing: 0.0 25 | sgld_steps: 40 26 | pyxce: 1.0 27 | pxsgld: 1.0 28 | pxysgld: 0.0 29 | buffer_size: 9999 30 | reinit_freq: 0.0 31 | data_shape: ${fixed.data_shape} 32 | sgld_batch_size: 512 33 | class_cond_p_x_sample: True 34 | n_classes: ${fixed.num_classes} 35 | warmup_steps: 2500 36 | entropy_reg_weight: 0.0001 37 | lr_step_size: 1000 38 | -------------------------------------------------------------------------------- /configs/model/genomics_vera.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | batch_size: 256 4 | 5 | checkpoint_config: 6 | 7 | earlystop_config: 8 | 9 | ood_dataset: ${fixed.dataset} 10 | 11 | model_name: VERA 12 | model_config: 13 | arch_name: seq_classifier 14 | arch_config: 15 | in_channels: ${fixed.num_cat} 16 | num_filters: 128 17 | fc_hidden_size: 1000 18 | num_classes: 1 19 | kernel_size: 20 20 | 21 | learning_rate: 0.00003 22 | beta1: 0.0 23 | beta2: 0.9 24 | weight_decay: 0.0 25 | n_classes: 1 26 | uncond: False 27 | gen_learning_rate: 0.00006 28 | ebm_iters: 1 29 | generator_iters: 1 30 | entropy_weight: 0.0001 31 | 32 | generator_type: vera_discrete 33 | generator_arch_name: seq_generator 34 | generator_config: 35 | noise_dim: 128 36 | post_lr: 0.00003 37 | init_post_logsigma: 0.1 38 | generator_arch_config: 39 | inp_dim: ${fixed.model_config.generator_config.noise_dim} 40 | num_classes: ${fixed.num_cat} 41 | seq_length: ${fixed.data_shape.0} 42 | 43 | min_sigma: 0.01 44 | max_sigma: 0.3 45 | p_control: 0.0 46 | n_control: 0.0 47 | pg_control: 0.1 48 | clf_ent_weight: 0.0 49 | ebm_type: p_x 50 | clf_weight: 0.0 51 | warmup_steps: 2500 52 | no_g_batch_norm: False 53 | batch_size: ${fixed.batch_size} 54 | lr_decay: 0.3 55 | lr_decay_epochs: [150, 180] 56 | -------------------------------------------------------------------------------- /configs/model/genomics_vera_supervised.yaml: 
-------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | batch_size: 256 4 | 5 | checkpoint_config: 6 | 7 | earlystop_config: 8 | 9 | ood_dataset: ${fixed.dataset} 10 | 11 | model_name: VERA 12 | model_config: 13 | arch_name: seq_classifier 14 | arch_config: 15 | in_channels: ${fixed.num_cat} 16 | num_filters: 128 17 | fc_hidden_size: 1000 18 | num_classes: ${fixed.num_classes} 19 | kernel_size: 20 20 | 21 | learning_rate: 0.00003 22 | beta1: 0.0 23 | beta2: 0.9 24 | weight_decay: 0.0 25 | n_classes: ${fixed.num_classes} 26 | uncond: False 27 | gen_learning_rate: 0.00006 28 | ebm_iters: 1 29 | generator_iters: 1 30 | entropy_weight: 0.0001 31 | 32 | generator_type: vera_discrete 33 | generator_arch_name: seq_generator 34 | generator_config: 35 | noise_dim: 128 36 | post_lr: 0.00003 37 | init_post_logsigma: 0.1 38 | generator_arch_config: 39 | inp_dim: ${fixed.model_config.generator_config.noise_dim} 40 | num_classes: ${fixed.num_cat} 41 | seq_length: ${fixed.data_shape.0} 42 | 43 | min_sigma: 0.01 44 | max_sigma: 0.3 45 | p_control: 0.0 46 | n_control: 0.0 47 | pg_control: 0.1 48 | clf_ent_weight: 0.0 49 | ebm_type: jem 50 | clf_weight: 1.0 51 | warmup_steps: 2500 52 | no_g_batch_norm: False 53 | batch_size: ${fixed.batch_size} 54 | lr_decay: 0.3 55 | lr_decay_epochs: [150, 180] 56 | -------------------------------------------------------------------------------- /configs/model/img_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 50 5 | 6 | checkpoint_config: 7 | monitor: val/loss 8 | mode: min 9 | 10 | earlystop_config: 11 | monitor: val/loss 12 | mode: min 13 | patience: 10 14 | 15 | model_name: Autoencoder 16 | model_config: 17 | arch_name: wrn 18 | arch_config: 19 | depth: 10 20 | num_classes: 32 21 | widen_factor: 2 22 | input_channels: ${fixed.data_shape.2} 23 | decoder_arch_name: resnetgenerator 
24 | decoder_arch_config: 25 | unit_interval: False 26 | feats: 32 27 | out_channels: ${fixed.data_shape.2} 28 | learning_rate: 0.0001 29 | momentum: 0.9 30 | weight_decay: 0.0005 31 | -------------------------------------------------------------------------------- /configs/model/img_ce_baseline.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 50 5 | 6 | checkpoint_config: 7 | monitor: val/acc 8 | mode: max 9 | 10 | earlystop_config: 11 | monitor: val/acc 12 | mode: max 13 | patience: 10 14 | 15 | model_name: CEBaseline 16 | model_config: 17 | arch_name: wrn 18 | arch_config: 19 | depth: 10 20 | num_classes: ${fixed.num_classes} 21 | widen_factor: 2 22 | input_channels: ${fixed.data_shape.2} 23 | learning_rate: 0.0001 24 | momentum: 0.9 25 | weight_decay: 0.0005 26 | -------------------------------------------------------------------------------- /configs/model/img_glow.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | earlystop_config: 4 | monitor: val/loss 5 | mode: min 6 | patience: 10 7 | 8 | normalize: False 9 | ood_dataset: 10 | 11 | model_name: Glow 12 | model_config: 13 | in_channels: ${fixed.data_shape.2} 14 | num_channels: 512 15 | num_levels: 3 16 | num_steps: 32 17 | learning_rate: 0.001 18 | momentum: 0.9 19 | weight_decay: 0.0 20 | -------------------------------------------------------------------------------- /configs/model/img_mcmc.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 20 5 | 6 | checkpoint_config: 7 | save_last: True 8 | 9 | earlystop_config: 10 | 11 | ood_dataset: ${fixed.dataset} 12 | sigma: 0.1 13 | 14 | model_name: JEM 15 | model_config: 16 | n_classes: 1 17 | arch_name: wrn 18 | arch_config: 19 | depth: 10 20 | num_classes: 1 21 | widen_factor: 2 22 | input_channels: 
${fixed.data_shape.2} 23 | strides: [1, 2, 2] 24 | learning_rate: 0.0001 25 | momentum: 0.9 26 | weight_decay: 0.0 27 | smoothing: 0.0 28 | sgld_lr: 1. 29 | sgld_std: 0.01 30 | sgld_steps: 100 31 | pyxce: 0.0 32 | pxsgld: 1.0 33 | pxysgld: 0.0 34 | buffer_size: 10000 35 | reinit_freq: 0.05 36 | data_shape: ${fixed.data_shape} 37 | sgld_batch_size: ${fixed.batch_size} 38 | class_cond_p_x_sample: False 39 | entropy_reg_weight: 0.0001 40 | -------------------------------------------------------------------------------- /configs/model/img_mcmc_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 20 5 | 6 | checkpoint_config: 7 | 8 | earlystop_config: 9 | 10 | ood_dataset: ${fixed.dataset} 11 | sigma: 0.1 12 | 13 | model_name: JEM 14 | model_config: 15 | n_classes: ${fixed.num_classes} 16 | arch_name: wrn 17 | arch_config: 18 | depth: 10 19 | num_classes: ${fixed.num_classes} 20 | widen_factor: 2 21 | input_channels: ${fixed.data_shape.2} 22 | strides: [1, 2, 2] 23 | learning_rate: 0.0001 24 | momentum: 0.9 25 | weight_decay: 0.0 26 | smoothing: 0.0 27 | sgld_lr: 1. 
28 | sgld_std: 0.01 29 | sgld_steps: 100 30 | pyxce: 1.0 31 | pxsgld: 1.0 32 | pxysgld: 0.0 33 | buffer_size: 10000 34 | reinit_freq: 0.05 35 | data_shape: ${fixed.data_shape} 36 | sgld_batch_size: ${fixed.batch_size} 37 | class_cond_p_x_sample: True 38 | entropy_reg_weight: 0.0001 39 | -------------------------------------------------------------------------------- /configs/model/img_real_nvp.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | checkpoint_config: 4 | monitor: val/loss 5 | mode: min 6 | save_last: True 7 | 8 | earlystop_config: 9 | monitor: val/loss 10 | mode: min 11 | patience: 10 12 | 13 | normalize: False 14 | ood_dataset: 15 | 16 | model_name: RealNVP 17 | model_config: 18 | num_scales: 2 19 | in_channels: ${fixed.data_shape.2} 20 | mid_channels: 32 21 | num_blocks: 4 22 | num_classes: 10 23 | learning_rate: 0.0001 24 | momentum: 0.9 25 | weight_decay: 0.0 26 | -------------------------------------------------------------------------------- /configs/model/img_ssm.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 50 5 | 6 | checkpoint_config: 7 | save_last: True 8 | 9 | earlystop_config: 10 | 11 | ood_dataset: ${fixed.dataset} 12 | sigma: 0.1 13 | 14 | model_name: SSM 15 | model_config: 16 | arch_name: wrn 17 | arch_config: 18 | depth: 10 19 | num_classes: 1 20 | widen_factor: 2 21 | input_channels: ${fixed.data_shape.2} 22 | strides: [1, 2, 2] 23 | 24 | learning_rate: 0.0001 25 | momentum: 0.9 26 | weight_decay: 0.0 27 | clf_weight: 0.0 28 | n_classes: 1 29 | n_particles: 1 30 | warmup_steps: 2500 31 | lr_step_size: 1000 32 | -------------------------------------------------------------------------------- /configs/model/img_ssm_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 50 
5 | 6 | checkpoint_config: 7 | 8 | earlystop_config: 9 | 10 | ood_dataset: ${fixed.dataset} 11 | sigma: 0.1 12 | 13 | model_name: SSM 14 | model_config: 15 | arch_name: wrn 16 | arch_config: 17 | depth: 10 18 | num_classes: ${fixed.num_classes} 19 | widen_factor: 2 20 | input_channels: ${fixed.data_shape.2} 21 | strides: [1, 2, 2] 22 | 23 | learning_rate: 0.0001 24 | momentum: 0.9 25 | weight_decay: 0.0 26 | clf_weight: 1.0 27 | n_classes: ${fixed.num_classes} 28 | n_particles: 1 29 | warmup_steps: 2500 30 | lr_step_size: 1000 31 | -------------------------------------------------------------------------------- /configs/model/img_vera.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | checkpoint_config: 4 | save_last: True 5 | 6 | earlystop_config: 7 | 8 | ood_dataset: ${fixed.dataset} 9 | 10 | model_name: VERA 11 | model_config: 12 | n_classes: 1 13 | arch_name: wrn 14 | arch_config: 15 | depth: 10 16 | num_classes: 1 17 | widen_factor: 2 18 | input_channels: ${fixed.data_shape.2} 19 | strides: [1, 2, 2] 20 | learning_rate: 0.00003 21 | beta1: 0.0 22 | beta2: 0.9 23 | weight_decay: 0.0 24 | gen_learning_rate: 0.00006 25 | ebm_iters: 1 26 | generator_iters: 1 27 | entropy_weight: 0.0001 28 | generator_type: vera 29 | generator_arch_name: resnetgenerator 30 | generator_arch_config: 31 | unit_interval: False 32 | feats: 128 33 | out_channels: ${fixed.data_shape.2} 34 | generator_config: 35 | noise_dim: 128 36 | post_lr: 0.00003 37 | init_post_logsigma: 0.1 38 | min_sigma: 0.01 39 | max_sigma: 0.3 40 | p_control: 0.0 41 | n_control: 0.0 42 | pg_control: 0.1 43 | clf_ent_weight: 0.0001 44 | ebm_type: p_x 45 | clf_weight: 0.0 46 | warmup_steps: 2500 47 | no_g_batch_norm: False 48 | batch_size: ${fixed.batch_size} 49 | lr_decay: 0.3 50 | lr_decay_epochs: [15, 18] 51 | -------------------------------------------------------------------------------- /configs/model/img_vera_posteriornet.yaml: 
-------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 40 5 | terminate_on_nan: True 6 | 7 | checkpoint_config: 8 | monitor: val/acc 9 | mode: max 10 | save_last: True 11 | 12 | earlystop_config: 13 | monitor: val/acc 14 | mode: max 15 | patience: 10 16 | 17 | ood_dataset: ${fixed.dataset} 18 | sigma: 0.1 19 | 20 | model_name: VERAPosteriorNet 21 | model_config: 22 | arch_name: wrn 23 | arch_config: 24 | depth: 4 25 | num_classes: ${fixed.num_classes} 26 | widen_factor: 10 27 | input_channels: ${fixed.data_shape.2} 28 | dropout: 0.3 29 | norm: group 30 | 31 | learning_rate: 0.00003 32 | beta1: 0.0 33 | beta2: 0.9 34 | weight_decay: 0.1 35 | n_classes: 10 36 | uncond: False 37 | gen_learning_rate: 0.00006 38 | ebm_iters: 1 39 | generator_iters: 1 40 | entropy_weight: 0.0001 41 | 42 | generator_type: vera 43 | generator_arch_name: resnetgenerator 44 | generator_arch_config: 45 | unit_interval: False 46 | feats: 128 47 | out_channels: ${fixed.data_shape.2} 48 | 49 | generator_config: 50 | noise_dim: 128 51 | post_lr: 0.00003 52 | init_post_logsigma: 0.1 53 | 54 | min_sigma: 0.01 55 | max_sigma: 0.3 56 | p_control: 1.0 57 | n_control: 1.0 58 | pg_control: 0.1 59 | clf_ent_weight: 0.1 60 | ebm_type: jem 61 | clf_weight: 100.0 62 | warmup_steps: 2500 63 | no_g_batch_norm: False 64 | batch_size: ${fixed.batch_size} 65 | lr_decay: 0.3 66 | lr_decay_epochs: [30, 35] 67 | vis_every: -1 68 | alpha_fix: True 69 | entropy_reg: 0.0001 70 | -------------------------------------------------------------------------------- /configs/model/img_vera_priornet.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | trainer_config: 4 | max_epochs: 50 5 | terminate_on_nan: True 6 | 7 | checkpoint_config: 8 | monitor: val/acc 9 | mode: max 10 | save_last: True 11 | 12 | earlystop_config: 13 | monitor: val/acc 14 | mode: max 15 | patience: 10 16 | 
17 | ood_dataset: ${fixed.dataset} 18 | 19 | model_name: VERAPriorNet 20 | model_config: 21 | arch_name: wrn 22 | arch_config: 23 | depth: 10 24 | num_classes: ${fixed.num_classes} 25 | widen_factor: 4 26 | input_channels: ${fixed.data_shape.2} 27 | dropout: 0.3 28 | norm: group 29 | 30 | learning_rate: 0.00003 31 | beta1: 0.0 32 | beta2: 0.9 33 | weight_decay: 0.1 34 | n_classes: 10 35 | uncond: False 36 | gen_learning_rate: 0.00006 37 | ebm_iters: 1 38 | generator_iters: 1 39 | entropy_weight: 0.0001 40 | 41 | generator_type: vera 42 | generator_arch_name: resnetgenerator 43 | generator_arch_config: 44 | unit_interval: False 45 | feats: 128 46 | out_channels: ${fixed.data_shape.2} 47 | 48 | generator_config: 49 | noise_dim: 128 50 | post_lr: 0.00003 51 | init_post_logsigma: 0.1 52 | 53 | min_sigma: 0.01 54 | max_sigma: 0.3 55 | p_control: 1.0 56 | n_control: 1.0 57 | pg_control: 1.0 58 | clf_ent_weight: 0.1 59 | ebm_type: jem 60 | clf_weight: 100.0 61 | warmup_steps: 2500 62 | no_g_batch_norm: False 63 | batch_size: 32 64 | lr_decay: 0.3 65 | lr_decay_epochs: [40, 45] 66 | vis_every: -1 67 | alpha_fix: True 68 | concentration: 1.0 69 | target_concentration: 70 | entropy_reg: 0.0001 71 | reverse_kl: True 72 | w_neg_sample_loss: 0.0 73 | -------------------------------------------------------------------------------- /configs/model/img_vera_supervised.yaml: -------------------------------------------------------------------------------- 1 | # @package fixed 2 | 3 | checkpoint_config: 4 | 5 | earlystop_config: 6 | 7 | ood_dataset: ${fixed.dataset} 8 | 9 | model_name: VERA 10 | model_config: 11 | n_classes: ${fixed.num_classes} 12 | arch_name: wrn 13 | arch_config: 14 | depth: 10 15 | num_classes: ${fixed.num_classes} 16 | widen_factor: 2 17 | input_channels: ${fixed.data_shape.2} 18 | strides: [1, 2, 2] 19 | learning_rate: 0.00003 20 | beta1: 0.0 21 | beta2: 0.9 22 | weight_decay: 0.0 23 | gen_learning_rate: 0.00006 24 | ebm_iters: 1 25 | generator_iters: 1 26 | 
entropy_weight: 0.0001 27 | generator_type: vera 28 | generator_arch_name: resnetgenerator 29 | generator_arch_config: 30 | unit_interval: False 31 | feats: 128 32 | out_channels: ${fixed.data_shape.2} 33 | generator_config: 34 | noise_dim: 128 35 | post_lr: 0.00003 36 | init_post_logsigma: 0.1 37 | min_sigma: 0.01 38 | max_sigma: 0.3 39 | p_control: 0.0 40 | n_control: 0.0 41 | pg_control: 0.1 42 | clf_ent_weight: 0.0001 43 | ebm_type: jem 44 | clf_weight: 100.0 45 | warmup_steps: 2500 46 | no_g_batch_norm: False 47 | batch_size: ${fixed.batch_size} 48 | lr_decay: 0.3 49 | lr_decay_epochs: [15, 18] 50 | -------------------------------------------------------------------------------- /configs/seml_config.yaml: -------------------------------------------------------------------------------- 1 | seml: 2 | executable: uncertainty_est/train.py 3 | name: vera_dim_reduction 4 | output_dir: slurm 5 | project_root_dir: .. 6 | 7 | slurm: 8 | experiments_per_job: 1 9 | sbatch_options: 10 | gres: gpu:1 # num GPUs 11 | mem: 8G # memory 12 | cpus-per-task: 2 # num cores 13 | time: 0-20:00 # max time, D-HH:MM 14 | mail-type: FAIL 15 | 16 | ###### BEGIN PARAMETER CONFIGURATION ###### 17 | 18 | fixed: 19 | trainer_config: 20 | max_epochs: 40 21 | gpus: 1 22 | benchmark: True 23 | limit_val_batches: 0 24 | 25 | checkpoint_config: 26 | monitor: 27 | save_last: True 28 | 29 | earlystop_config: 30 | 31 | test_ood_datasets: 32 | - lsun 33 | - svhn 34 | - svhn_unscaled 35 | - gaussian_noise 36 | - uniform_noise 37 | 38 | log_dir: grid_search/vera_dim_reduction 39 | dataset: &dataset cifar10 40 | # Use second dataset 41 | ood_dataset: *dataset 42 | seed: 1 43 | batch_size: &batch_size 32 44 | data_shape: 45 | - 32 46 | - 32 47 | - &n_channels 3 48 | 49 | model_name: VERA 50 | model_config: 51 | n_classes: &n_classes 1 52 | 53 | arch_name: wrn 54 | arch_config: 55 | depth: 28 56 | num_classes: *n_classes 57 | widen_factor: 10 58 | input_channels: *n_channels 59 | # strides: [1, 2, 2] 60 | 
learning_rate: 0.00003 61 | beta1: 0.0 62 | beta2: 0.9 63 | weight_decay: 0.0 64 | gen_learning_rate: 0.00006 65 | ebm_iters: 1 66 | generator_iters: 1 67 | entropy_weight: 0.0001 68 | generator_type: vera 69 | generator_arch_name: resnetgenerator 70 | generator_arch_config: 71 | unit_interval: False 72 | feats: 128 73 | out_channels: *n_channels 74 | generator_config: 75 | noise_dim: 128 76 | post_lr: 0.00003 77 | init_post_logsigma: 0.1 78 | min_sigma: 0.01 79 | max_sigma: 0.3 80 | p_control: 0.0 81 | n_control: 0.0 82 | pg_control: 0.1 83 | clf_ent_weight: 0.0 84 | ebm_type: p_x 85 | clf_weight: 0.0 86 | warmup_steps: 2500 87 | no_g_batch_norm: False 88 | batch_size: *batch_size 89 | lr_decay: 0.3 90 | lr_decay_epochs: [20, 30] 91 | 92 | grid: 93 | model_config: 94 | type: parameter_collection 95 | params: 96 | arch_config: 97 | type: parameter_collection 98 | params: 99 | strides: 100 | type: choice 101 | options: 102 | - [1, 2, 2] 103 | - [1, 1, 2] 104 | - [1, 1, 1] 105 | -------------------------------------------------------------------------------- /density_histograms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/density_histograms.png -------------------------------------------------------------------------------- /req.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | absl-py=0.13.0=pypi_0 6 | aiohttp=3.7.4.post0=pypi_0 7 | antlr4-python3-runtime=4.8=pypi_0 8 | async-timeout=3.0.1=pypi_0 9 | attrs=21.2.0=pypi_0 10 | blas=1.0=mkl 11 | ca-certificates=2021.5.25=h06a4308_1 12 | cachetools=4.2.2=pypi_0 13 | certifi=2021.5.30=py38h06a4308_0 14 | chardet=4.0.0=pypi_0 15 | cudatoolkit=10.1.243=h6bb024c_0 16 | cycler=0.10.0=py38_0 17 | 
dbus=1.13.18=hb2f20db_0 18 | expat=2.4.1=h2531618_2 19 | fontconfig=2.13.1=h6c09931_0 20 | freetype=2.10.4=h5ab3b9f_0 21 | fsspec=2021.6.1=pypi_0 22 | future=0.18.2=pypi_0 23 | glib=2.68.2=h36276a3_0 24 | google-auth=1.32.1=pypi_0 25 | google-auth-oauthlib=0.4.4=pypi_0 26 | grpcio=1.38.1=pypi_0 27 | gst-plugins-base=1.14.0=h8213a91_2 28 | gstreamer=1.14.0=h28cd5cc_2 29 | hydra-core=1.0.0=pypi_0 30 | icu=58.2=he6710b0_3 31 | idna=2.10=pypi_0 32 | importlib-resources=5.2.0=pypi_0 33 | intel-openmp=2021.2.0=h06a4308_610 34 | joblib=1.0.1=pypi_0 35 | jpeg=9b=h024ee3a_2 36 | kiwisolver=1.3.1=py38h2531618_0 37 | lcms2=2.12=h3be6417_0 38 | ld_impl_linux-64=2.33.1=h53a641e_7 39 | libffi=3.3=he6710b0_2 40 | libgcc-ng=9.1.0=hdf63c60_0 41 | libgfortran-ng=7.5.0=ha8ba4b0_17 42 | libgfortran4=7.5.0=ha8ba4b0_17 43 | libpng=1.6.37=hbc83047_0 44 | libstdcxx-ng=9.1.0=hdf63c60_0 45 | libtiff=4.2.0=h85742a9_0 46 | libuuid=1.0.3=h1bed415_2 47 | libuv=1.40.0=h7b6447c_0 48 | libwebp-base=1.2.0=h27cfd23_0 49 | libxcb=1.14=h7b6447c_0 50 | libxml2=2.9.10=hb55368b_3 51 | lz4-c=1.9.3=h2531618_0 52 | markdown=3.3.4=pypi_0 53 | matplotlib=3.3.4=py38h06a4308_0 54 | matplotlib-base=3.3.4=py38h62a2d02_0 55 | mkl=2021.2.0=h06a4308_296 56 | mkl-service=2.3.0=py38h27cfd23_1 57 | mkl_fft=1.3.0=py38h42c9631_2 58 | mkl_random=1.2.1=py38ha9443f7_2 59 | multidict=5.1.0=pypi_0 60 | ncurses=6.2=he6710b0_1 61 | ninja=1.10.2=hff7bd54_1 62 | numpy=1.20.2=py38h2d18471_0 63 | numpy-base=1.20.2=py38hfae3a4d_0 64 | oauthlib=3.1.1=pypi_0 65 | olefile=0.46=py_0 66 | omegaconf=2.1.0=pypi_0 67 | openssl=1.1.1k=h27cfd23_0 68 | opt-einsum=3.3.0=pypi_0 69 | packaging=21.0=pypi_0 70 | pandas=1.2.5=py38h295c915_0 71 | pcre=8.45=h295c915_0 72 | pillow=8.2.0=py38he98fc37_0 73 | pip=21.0.1=py38h06a4308_0 74 | protobuf=3.17.3=pypi_0 75 | pyasn1=0.4.8=pypi_0 76 | pyasn1-modules=0.2.8=pypi_0 77 | pydeprecate=0.3.0=pypi_0 78 | pyparsing=2.4.7=pyhd3eb1b0_0 79 | pyqt=5.9.2=py38h05f1152_4 80 | pyro-api=0.1.2=pypi_0 81 | 
pyro-ppl=1.5.2=pypi_0 82 | python=3.8.8=hdb3f193_5 83 | python-dateutil=2.8.1=pyhd3eb1b0_0 84 | pytorch=1.7.1=py3.8_cuda10.1.243_cudnn7.6.3_0 85 | pytorch-lightning=1.3.8=pypi_0 86 | pytz=2021.1=pyhd3eb1b0_0 87 | pyyaml=5.4.1=py38h27cfd23_1 88 | qt=5.9.7=h5867ecd_1 89 | readline=8.1=h27cfd23_0 90 | requests=2.25.1=pypi_0 91 | requests-oauthlib=1.3.0=pypi_0 92 | rsa=4.7.2=pypi_0 93 | scikit-learn=0.24.2=pypi_0 94 | scipy=1.6.2=py38had2a1c9_1 95 | setuptools=52.0.0=py38h06a4308_0 96 | sip=4.19.13=py38he6710b0_0 97 | six=1.16.0=pyhd3eb1b0_0 98 | sqlite=3.35.4=hdfb4753_0 99 | tensorboard=2.4.1=pypi_0 100 | tensorboard-plugin-wit=1.8.0=pypi_0 101 | tfrecord=1.11=pypi_0 102 | threadpoolctl=2.1.0=pypi_0 103 | tk=8.6.10=hbc83047_0 104 | torchmetrics=0.4.0=pypi_0 105 | torchvision=0.8.2=py38_cu101 106 | tornado=6.1=py38h27cfd23_0 107 | tqdm=4.61.1=pypi_0 108 | typing_extensions=3.10.0.0=pyh06a4308_0 109 | uncertainty-eval=0.0.1=pypi_0 110 | urllib3=1.26.6=pypi_0 111 | werkzeug=2.0.1=pypi_0 112 | wheel=0.36.2=pyhd3eb1b0_0 113 | xz=5.2.5=h7b6447c_0 114 | yaml=0.2.5=h7b6447c_0 115 | yarl=1.6.3=pypi_0 116 | zipp=3.5.0=pypi_0 117 | zlib=1.2.11=h7b6447c_3 118 | zstd=1.4.9=haebb681_0 119 | -------------------------------------------------------------------------------- /uncertainty_est/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/__init__.py -------------------------------------------------------------------------------- /uncertainty_est/archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/archs/__init__.py -------------------------------------------------------------------------------- /uncertainty_est/archs/arch_factory.py: 
def make_mlp(
    dim_list, activation="relu", batch_norm=False, dropout=0, bias=True, slope=1e-2
):
    """Build a fully connected network as an ``nn.Sequential``.

    Every consecutive pair in ``dim_list`` becomes one ``nn.Linear``. Each
    linear layer except the last is followed by optional batch norm, the
    chosen activation and optional dropout; the last layer is left bare so the
    network emits raw outputs.

    Args:
        dim_list: Layer sizes ``[in, hidden_1, ..., hidden_k, out]``; needs at
            least two entries.
        activation: One of ``"relu"``, ``"leaky_relu"`` or ``"elu"``.
        batch_norm: Insert ``nn.BatchNorm1d`` after every hidden linear layer.
        dropout: Dropout probability after every hidden activation; values
            ``<= 0`` disable dropout.
        bias: Whether the linear layers carry a bias term.
        slope: Negative slope used by ``"leaky_relu"``.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        ValueError: If ``dim_list`` has fewer than two entries.
        NotImplementedError: If ``activation`` is not a supported name. This
            is checked even when there are no hidden layers, so a misspelled
            option cannot pass unnoticed.
    """
    if len(dim_list) < 2:
        raise ValueError(
            "dim_list needs at least an input and an output size, "
            f"got {list(dim_list)}"
        )

    # Factories rather than instances: every hidden layer gets its own module.
    activation_factories = {
        "relu": lambda: nn.ReLU(),
        "leaky_relu": lambda: nn.LeakyReLU(slope, inplace=True),
        "elu": lambda: nn.ELU(inplace=True),
    }
    if activation not in activation_factories:
        raise NotImplementedError(f"Activation '{activation}' not implemented!")
    make_activation = activation_factories[activation]

    layers = []
    # Hidden layers: all (in, out) pairs except the final projection.
    for dim_in, dim_out in zip(dim_list[:-2], dim_list[1:-1]):
        layers.append(nn.Linear(dim_in, dim_out, bias=bias))
        if batch_norm:
            layers.append(nn.BatchNorm1d(dim_out, affine=True))
        layers.append(make_activation())
        if dropout > 0:
            layers.append(nn.Dropout(p=dropout))

    # Output projection without normalisation or non-linearity.
    layers.append(nn.Linear(dim_list[-2], dim_list[-1], bias=bias))
    return nn.Sequential(*layers)
class OrthogonalTransform(nn.Module):
    """Square, bias-free linear flow layer that is encouraged to be orthogonal.

    Orthogonality is not enforced by construction. `compute_weight_penalty`
    returns a term that is zero exactly when the weight is orthogonal, and
    `log_abs_det_jacobian` always reports zero, which is exact only in that
    case.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.transform = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        """Apply the linear map to `x`."""
        return self.transform(x)

    @staticmethod
    def log_abs_det_jacobian(z, z_next):
        """Return one zero per row of `z`, placed on `z`'s device."""
        batch_size = z.shape[0]
        return torch.zeros(batch_size, device=z.device)

    def compute_weight_penalty(self):
        """Return the Frobenius norm of `I - W^T W` for the layer weight `W`."""
        weight = self.transform.weight
        gram = torch.mm(weight.T, weight)
        target = torch.eye(self.dim, dtype=gram.dtype, device=gram.device)
        return torch.linalg.norm(target - gram, ord="fro")
self.u_mat = nn.Parameter(torch.Tensor(dim, dim)) 46 | self.v_mat = nn.Parameter(torch.Tensor(dim, dim)) 47 | self.sigma = nn.Parameter(torch.ones(dim)) 48 | nn.init.kaiming_uniform_(self.u_mat, a=math.sqrt(5)) 49 | nn.init.kaiming_uniform_(self.v_mat, a=math.sqrt(5)) 50 | 51 | def forward(self, x): 52 | weight = self.u_mat @ torch.diag_embed(self.sigma) @ self.v_mat.T 53 | lin_out = torch.nn.functional.linear(x, weight, self.bias) 54 | return lin_out 55 | 56 | def log_abs_det_jacobian(self, z, z_next): 57 | ladj = torch.sum(torch.log(self.sigma.abs())) 58 | ladj = torch.empty(z.size(0), device=z.device).fill_(ladj) 59 | return ladj 60 | 61 | def compute_weight_penalty(self): 62 | return self._weight_penalty(self.u_mat) + self._weight_penalty(self.v_mat) 63 | 64 | @staticmethod 65 | def _weight_penalty(weight): 66 | sq_weight = torch.mm(weight.T, weight) 67 | identity = torch.eye(weight.size(0)).to(sq_weight) 68 | penalty = torch.linalg.norm(identity - sq_weight, ord="fro") 69 | return penalty 70 | 71 | 72 | @torch.jit.script 73 | def sequential_mult(V, X): 74 | for row in range(V.shape[0] - 1, -1, -1): 75 | X = X - 2 * V[row : row + 1, :].t() @ (V[row : row + 1, :] @ X) 76 | return X 77 | 78 | 79 | @torch.jit.script 80 | def sequential_inv_mult(V, X): 81 | for row in range(V.shape[0]): 82 | X = X - 2 * V[row : row + 1, :].t() @ (V[row : row + 1, :] @ X) 83 | return X 84 | 85 | 86 | class Orthogonal(nn.Module): 87 | def __init__(self, d, m=28, strategy="sequential"): 88 | super(Orthogonal, self).__init__() 89 | self.d = d 90 | self.strategy = strategy 91 | self.U = torch.nn.Parameter(torch.zeros((d, d)).normal_(0, 0.05)) 92 | 93 | if strategy == "fast": 94 | assert d % m == 0, ( 95 | "The CUDA implementation assumes m=%i divides d=%i which, for current parameters, is not true. 
class MatExpOrthogonal(nn.Module):
    """Orthogonal matrix parameterised as the exponential of a skew-symmetric matrix.

    The strictly upper triangle of `param` defines a skew-symmetric matrix
    `A = U - U^T`; the layer multiplies its input by `exp(A)`, which is
    orthogonal for every value of `param`.

    Caching: the matrix exponential is reused only when no autograd graph is
    being recorded (for example under `torch.no_grad()`), and only while the
    parameter's in-place version counter, device and dtype are unchanged.
    While gradients are tracked the matrix is rebuilt on every access, so each
    forward pass has its own graph and optimizer updates take effect at once.
    Code that rewrites the parameter through `param.data` bypasses the version
    counter and should call `reset_cache()` afterwards.

    The previous implementation tried to invalidate the cache from a
    `backward` method that autograd never calls, so the first computed matrix
    was reused forever. That method has been removed.
    """

    def __init__(self, d):
        super().__init__()
        self.param = nn.Parameter(torch.randn(d, d))
        self.cached = False
        self._cache = None
        self._cache_key = None

    def reset_cache(self):
        """Drop the cached matrix so the next access recomputes it."""
        self.cached = False
        self._cache = None
        self._cache_key = None

    @property
    def w(self):
        """Return the orthogonal matrix `exp(U - U^T)` for the current `param`."""
        param = self.param
        needs_graph = torch.is_grad_enabled() and param.requires_grad
        key = (param._version, param.device, param.dtype)

        # getattr keeps instances unpickled from the older class layout usable.
        if (
            not needs_graph
            and self.cached
            and getattr(self, "_cache_key", None) == key
        ):
            return self._cache

        # Build the skew-symmetric generator from the strict upper triangle.
        upper = param.triu(diagonal=1)
        skew = upper - upper.t()
        w = torch.matrix_exp(skew)

        if needs_graph:
            # Never keep a graph-attached tensor: its graph is freed by the
            # first backward pass and it goes stale after the optimizer step.
            self.reset_cache()
        else:
            self.cached = True
            self._cache = w
            self._cache_key = key
        return w

    def forward(self, x):
        """Right-multiply `x` by the orthogonal matrix."""
        return x @ self.w
Orthogonal(d, **kwargs) 170 | self.V = Orthogonal(d, **kwargs) 171 | else: 172 | raise NotImplementedError( 173 | f"Orthogonal parameterization {orthogonal_param} not implemented." 174 | ) 175 | self.D = nn.Parameter(torch.empty(d).uniform_(0.99, 1.01)) 176 | self.bias = nn.Parameter(torch.zeros(d)) 177 | 178 | def forward(self, X): 179 | X = self.U(X) 180 | X = X * self.D 181 | X = self.V(X) 182 | return X + self.bias 183 | 184 | def log_abs_det_jacobian(self, z, z_next): 185 | ladj = torch.log(torch.prod(self.D).abs()) 186 | return torch.empty(z.size(0), device=z.device).fill_(ladj) 187 | 188 | 189 | class NonLinearity(nn.Module): 190 | def __init__(self, type_="elu"): 191 | super().__init__() 192 | nonlins = { 193 | "elu": ELUTransform, 194 | "leaky_relu": LeakyReLUTransform, 195 | } 196 | self.transform = nonlins[type_]() 197 | 198 | def forward(self, x): 199 | return self.transform(x) 200 | 201 | def log_abs_det_jacobian(self, z, z_next): 202 | return self.transform.log_abs_det_jacobian(z, z_next).sum(1) 203 | 204 | def compute_weight_penalty(self, *args, **kwargs): 205 | return 0.0 206 | 207 | 208 | class NormalizingFlowDensity(nn.Module): 209 | def __init__(self, dim, flow_length, flow_type="planar_flow", **kwargs): 210 | super(NormalizingFlowDensity, self).__init__() 211 | self.dim = dim 212 | self.flow_length = flow_length 213 | self.flow_type = flow_type 214 | 215 | self.mean = nn.Parameter(torch.zeros(self.dim), requires_grad=False) 216 | self.cov = nn.Parameter(torch.eye(self.dim), requires_grad=False) 217 | 218 | if self.flow_type == "radial_flow": 219 | self.transforms = nn.ModuleList([Radial(dim) for _ in range(flow_length)]) 220 | elif self.flow_type == "iaf_flow": 221 | self.transforms = nn.ModuleList( 222 | [ 223 | affine_autoregressive(dim, hidden_dims=[128, 128], **kwargs) 224 | for _ in range(flow_length) 225 | ] 226 | ) 227 | elif self.flow_type == "planar_flow": 228 | self.transforms = nn.ModuleList([Planar(dim) for _ in range(flow_length)]) 
229 | elif self.flow_type == "affine_coupling": 230 | self.transforms = [] 231 | for i in range(flow_length): 232 | coupling = affine_coupling(dim, hidden_dims=[128, 128], **kwargs) 233 | self.transforms.append(coupling) 234 | self.add_module(f"coupling_{i}", coupling) 235 | 236 | permutation = nn.Parameter(torch.randperm(dim), requires_grad=False) 237 | self.register_parameter(f"permutation_{i}", permutation) 238 | permute_layer = permute(dim, permutation=permutation) 239 | self.transforms.append(permute_layer) 240 | 241 | elif self.flow_type == "orthogonal_flow": 242 | self.transforms = nn.ModuleList( 243 | [OrthogonalTransform(dim) for _ in range(flow_length)] 244 | ) 245 | elif self.flow_type == "reparameterized_flow": 246 | self.transforms = nn.ModuleList() 247 | for i in range(flow_length): 248 | self.transforms.append(ReparameterizedTransform(dim)) 249 | if i != (flow_length - 1): 250 | self.transforms.append(NonLinearity("elu")) 251 | elif self.flow_type == "svd": 252 | self.transforms = nn.ModuleList() 253 | for i in range(flow_length): 254 | self.transforms.append(LinearSVD(dim)) 255 | if i != (flow_length - 1): 256 | self.transforms.append(NonLinearity("elu")) 257 | elif self.flow_type == "svd_mat_exp": 258 | self.transforms = nn.ModuleList() 259 | for i in range(flow_length): 260 | self.transforms.append(LinearSVD(dim, orthogonal_param="mat_exp")) 261 | if i != (flow_length - 1): 262 | self.transforms.append(NonLinearity("elu")) 263 | else: 264 | raise NotImplementedError 265 | 266 | def forward(self, z): 267 | sum_log_jacobians = 0 268 | for transform in self.transforms: 269 | z_next = transform(z) 270 | sum_log_jacobians = sum_log_jacobians + transform.log_abs_det_jacobian( 271 | z, z_next 272 | ) 273 | z = z_next 274 | 275 | return z, sum_log_jacobians 276 | 277 | def log_prob(self, x): 278 | z, sum_log_jacobians = self.forward(x) 279 | log_prob_z = tdist.MultivariateNormal(self.mean, self.cov).log_prob(z) 280 | log_prob_x = log_prob_z + 
def mean_dim(tensor, dim=None, keepdims=False):
    """Take the mean along multiple dimensions.

    Negative and duplicate entries in `dim` are normalised before reducing.
    The previous dimension-by-dimension squeeze removed the wrong axes for
    such input when `keepdims` was False.

    Args:
        tensor (torch.Tensor): Tensor of values to average.
        dim (int or list): Dimension(s) along which to take the mean. `None`
            averages over every element; an empty list returns `tensor`
            unchanged.
        keepdims (bool): Keep reduced dimensions with size 1 rather than
            squeezing them.
    Returns:
        mean (torch.Tensor): New tensor of mean value(s).
    """
    if dim is None:
        return tensor.mean()

    if isinstance(dim, int):
        dim = [dim]

    # Map negative indices to their non-negative form and drop duplicates.
    # Out-of-range values are passed through so torch reports them.
    ndim = tensor.dim()
    dims = sorted({d + ndim if d < 0 else d for d in dim})
    if not dims:
        return tensor

    return tensor.mean(dim=tuple(dims), keepdim=keepdims)
After the init, bias and scale are trainable parameters. 33 | 34 | Adapted from: 35 | > https://github.com/openai/glow 36 | """ 37 | 38 | def __init__(self, num_features, scale=1.0, return_ldj=False): 39 | super(ActNorm, self).__init__() 40 | self.register_buffer("is_initialized", torch.zeros(1)) 41 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 42 | self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 43 | 44 | self.num_features = num_features 45 | self.scale = float(scale) 46 | self.eps = 1e-6 47 | self.return_ldj = return_ldj 48 | 49 | def initialize_parameters(self, x): 50 | if not self.training: 51 | return 52 | 53 | with torch.no_grad(): 54 | bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True) 55 | v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True) 56 | logs = (self.scale / (v.sqrt() + self.eps)).log() 57 | self.bias.data.copy_(bias.data) 58 | self.logs.data.copy_(logs.data) 59 | self.is_initialized += 1.0 60 | 61 | def _center(self, x, reverse=False): 62 | if reverse: 63 | return x - self.bias 64 | else: 65 | return x + self.bias 66 | 67 | def _scale(self, x, sldj, reverse=False): 68 | logs = self.logs 69 | if reverse: 70 | x = x * logs.mul(-1).exp() 71 | else: 72 | x = x * logs.exp() 73 | 74 | if sldj is not None: 75 | ldj = logs.sum() * x.size(2) * x.size(3) 76 | if reverse: 77 | sldj = sldj - ldj 78 | else: 79 | sldj = sldj + ldj 80 | 81 | return x, sldj 82 | 83 | def forward(self, x, ldj=None, reverse=False): 84 | if not self.is_initialized: 85 | self.initialize_parameters(x) 86 | 87 | if reverse: 88 | x, ldj = self._scale(x, ldj, reverse) 89 | x = self._center(x, reverse) 90 | else: 91 | x = self._center(x, reverse) 92 | x, ldj = self._scale(x, ldj, reverse) 93 | 94 | if self.return_ldj: 95 | return x, ldj 96 | 97 | return x 98 | -------------------------------------------------------------------------------- /uncertainty_est/archs/glow/coupling.py: 
class Coupling(nn.Module):
    """Affine coupling layer in the Real NVP formulation, as described by Glow.

    Note: The official Glow implementation (https://github.com/openai/glow)
    uses a different affine coupling formulation than described in the paper.
    This implementation follows the paper and Real NVP.

    The input is split in two along the channel axis. The second half passes
    through unchanged and conditions the scale and shift applied to the first.

    Args:
        in_channels (int): Number of channels fed to the conditioning network
            `NN`, which sees only the unchanged half of the input.
        mid_channels (int): Number of channels in the intermediate activation
            in NN.
    """

    def __init__(self, in_channels, mid_channels):
        super(Coupling, self).__init__()
        # NN emits interleaved (scale, shift) channels, hence 2 * in_channels.
        self.nn = NN(in_channels, mid_channels, 2 * in_channels)
        self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))

    def forward(self, x, ldj, reverse=False):
        """Apply the coupling, or its inverse when `reverse` is set.

        Returns the transformed tensor and `ldj` with the per-sample sum of
        log-scales added (forward) or subtracted (reverse).
        """
        x_change, x_id = x.chunk(2, dim=1)

        params = self.nn(x_id)
        # Even channels are log-scales, odd channels are shifts.
        log_scale = self.scale * torch.tanh(params[:, 0::2, ...])
        shift = params[:, 1::2, ...]
        log_det = log_scale.flatten(1).sum(-1)

        if reverse:
            x_change = x_change * torch.exp(-log_scale) - shift
            ldj = ldj - log_det
        else:
            x_change = (x_change + shift) * torch.exp(log_scale)
            ldj = ldj + log_det

        return torch.cat((x_change, x_id), dim=1), ldj
class Glow(nn.Module):
    """Glow Model

    Based on the paper:
    "Glow: Generative Flow with Invertible 1x1 Convolutions"
    by Diederik P. Kingma, Prafulla Dhariwal
    (https://arxiv.org/abs/1807.03039).

    Args:
        in_channels (int): Number of channels of input.
        num_channels (int): Number of channels in middle convolution of each
            step of flow.
        num_levels (int): Number of levels in the entire model.
        num_steps (int): Number of steps of flow for each level.
        k (int): Constant whose log is subtracted once per latent dimension
            in `log_prob` (presumably the number of discrete input values,
            e.g. 256 for 8-bit images -- confirm against callers).
    """

    def __init__(self, in_channels, num_channels, num_levels, num_steps, k=256):
        super().__init__()
        self.k = k

        # Use bounds to rescale images before converting to logits, not learned
        self.register_buffer("bounds", torch.tensor([0.9], dtype=torch.float32))
        # The flow operates on the squeezed input, hence 4x the channels.
        self.flows = _Glow(
            in_channels=4 * in_channels,
            mid_channels=num_channels,
            num_levels=num_levels,
            num_steps=num_steps,
        )

    def forward(self, x, reverse=False):
        """Run the flow in the forward or reverse direction.

        Args:
            x (torch.Tensor): Input batch. In the forward direction inputs
                are expected in [0, 1]; a batch with a negative minimum is
                remapped with ``(x + 1) / 2``.
            reverse (bool): If True, run the flow in reverse. No
                pre-processing is applied in this direction and the
                log-determinant accumulator starts from zeros.

        Returns:
            tuple: ``(x, sldj)`` -- transformed tensor and per-sample sum of
            log-determinants.
        """
        if not reverse:
            # Expect inputs in [0, 1]
            if x.min() < 0:
                x = (x + 1) / 2.0
            elif x.max() > 1:
                print("Warning: Input not properly normalized!")

            # De-quantize and convert to logits
            x, sldj = self._pre_process(x)
        else:
            sldj = torch.zeros(x.size(0), device=x.device)

        x = squeeze(x)
        x, sldj = self.flows(x, sldj, reverse)
        x = squeeze(x, reverse=True)

        return x, sldj

    def _pre_process(self, x):
        """Dequantize the input image `x` and convert to logits.

        See Also:
            - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1
            - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1

        Args:
            x (torch.Tensor): Input image.

        Returns:
            tuple: ``(y, sldj)`` where ``y`` holds the dequantized logits of
            `x` and ``sldj`` is the per-sample log-determinant term of this
            transform.
        """
        # Add uniform noise, then shrink into the range given by `bounds`.
        dequantized = (x * 255.0 + torch.rand_like(x)) / 256.0
        constrained = (2 * dequantized - 1) * self.bounds
        probs = (constrained + 1) / 2
        y = probs.log() - (1.0 - probs).log()

        # Save log-determinant of Jacobian of initial transform
        bounds_term = F.softplus((1.0 - self.bounds).log() - self.bounds.log())
        ldj = F.softplus(y) + F.softplus(-y) - bounds_term
        sldj = ldj.flatten(1).sum(-1)

        return y, sldj

    def log_prob(self, x):
        """Return the per-sample log-likelihood of `x` under the model.

        The latent is scored under a standard normal, ``log(k)`` is
        subtracted for every latent dimension, and the accumulated
        log-determinant is added.
        """
        z, sldj = self.forward(x, reverse=False)
        batch = z.size(0)
        dims_per_sample = np.prod(z.size()[1:])

        gauss_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
        prior_ll = gauss_ll.reshape(batch, -1).sum(-1) - np.log(self.k) * dims_per_sample
        return prior_ll + sldj
class _FlowStep(nn.Module):
    """One step of flow: activation norm, invertible 1x1 conv, affine coupling.

    Args:
        in_channels (int): Number of channels in the input.
        mid_channels (int): Number of hidden channels passed to the coupling
            layer.
    """

    def __init__(self, in_channels, mid_channels):
        super().__init__()

        # Activation normalization, invertible 1x1 convolution, affine coupling
        self.norm = ActNorm(in_channels, return_ldj=True)
        self.conv = InvConv(in_channels)
        self.coup = Coupling(in_channels // 2, mid_channels)

    def forward(self, x, sldj=None, reverse=False):
        """Apply the three sub-layers, in reversed order when `reverse` is set.

        Returns:
            tuple: ``(x, sldj)`` after all three sub-layers.
        """
        layers = (self.norm, self.conv, self.coup)
        if reverse:
            # Invert the composition by undoing the last layer first.
            layers = layers[::-1]

        for layer in layers:
            x, sldj = layer(x, sldj, reverse)

        return x, sldj
class InvConv(nn.Module):
    """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow
    (https://arxiv.org/abs/1807.03039). Does not support LU-decomposed version.

    Args:
        num_channels (int): Number of channels in the input and output.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

        # Initialize with a random orthogonal matrix (Q factor of a random
        # Gaussian matrix), drawn from NumPy's global RNG.
        q_factor, _ = np.linalg.qr(np.random.randn(num_channels, num_channels))
        self.weight = nn.Parameter(torch.from_numpy(q_factor.astype(np.float32)))

    def forward(self, x, sldj, reverse=False):
        """Apply the 1x1 convolution, or its inverse when `reverse` is set.

        Args:
            x (torch.Tensor): Input of shape (B, num_channels, H, W).
            sldj (torch.Tensor): Running sum of log-determinants.
            reverse (bool): If True, convolve with the inverse weight matrix
                and subtract the log-determinant.

        Returns:
            tuple: ``(z, sldj)`` -- convolved tensor and updated
            log-determinant sum.
        """
        # log|det W| is contributed once per spatial position.
        _, log_abs_det = torch.slogdet(self.weight)
        ldj = log_abs_det * x.size(2) * x.size(3)

        if reverse:
            # Invert in double precision, then cast back.
            kernel = torch.inverse(self.weight.double()).float()
            sldj = sldj - ldj
        else:
            kernel = self.weight
            sldj = sldj + ldj

        kernel = kernel.view(self.num_channels, self.num_channels, 1, 1)
        return F.conv2d(x, kernel), sldj
23 | reverse_mask (bool): Whether to reverse the mask. Useful for alternating masks. 24 | """ 25 | 26 | def __init__(self, in_channels, mid_channels, num_blocks, mask_type, reverse_mask): 27 | super(CouplingLayer, self).__init__() 28 | 29 | # Save mask info 30 | self.mask_type = mask_type 31 | self.reverse_mask = reverse_mask 32 | 33 | # Build scale and translate network 34 | if self.mask_type == MaskType.CHANNEL_WISE: 35 | in_channels //= 2 36 | self.st_net = ResNet( 37 | in_channels, 38 | mid_channels, 39 | 2 * in_channels, 40 | num_blocks=num_blocks, 41 | kernel_size=3, 42 | padding=1, 43 | double_after_norm=(self.mask_type == MaskType.CHECKERBOARD), 44 | ) 45 | 46 | # Learnable scale for s 47 | self.rescale = nn.utils.weight_norm(Rescale(in_channels)) 48 | 49 | def forward(self, x, sldj=None, reverse=True): 50 | if self.mask_type == MaskType.CHECKERBOARD: 51 | # Checkerboard mask 52 | b = checkerboard_mask( 53 | x.size(2), x.size(3), self.reverse_mask, device=x.device 54 | ) 55 | x_b = x * b 56 | st = self.st_net(x_b) 57 | s, t = st.chunk(2, dim=1) 58 | s = self.rescale(torch.tanh(s)) 59 | s = s * (1 - b) 60 | t = t * (1 - b) 61 | 62 | # Scale and translate 63 | if reverse: 64 | inv_exp_s = s.mul(-1).exp() 65 | if torch.isnan(inv_exp_s).any(): 66 | raise RuntimeError("Scale factor has NaN entries") 67 | x = x * inv_exp_s - t 68 | else: 69 | exp_s = s.exp() 70 | if torch.isnan(exp_s).any(): 71 | raise RuntimeError("Scale factor has NaN entries") 72 | x = (x + t) * exp_s 73 | 74 | # Add log-determinant of the Jacobian 75 | sldj += s.reshape(s.size(0), -1).sum(-1) 76 | else: 77 | # Channel-wise mask 78 | if self.reverse_mask: 79 | x_id, x_change = x.chunk(2, dim=1) 80 | else: 81 | x_change, x_id = x.chunk(2, dim=1) 82 | 83 | st = self.st_net(x_id) 84 | s, t = st.chunk(2, dim=1) 85 | s = self.rescale(torch.tanh(s)) 86 | 87 | # Scale and translate 88 | if reverse: 89 | inv_exp_s = s.mul(-1).exp() 90 | if torch.isnan(inv_exp_s).any(): 91 | raise RuntimeError("Scale 
class Rescale(nn.Module):
    """Per-channel rescaling. Need a proper `nn.Module` so we can wrap it
    with `torch.nn.utils.weight_norm`.

    Args:
        num_channels (int): Number of channels in the input.
    """

    def __init__(self, num_channels):
        super().__init__()
        # One multiplier per channel, shaped to broadcast over (H, W).
        self.weight = nn.Parameter(torch.ones(num_channels, 1, 1))

    def forward(self, x):
        """Multiply `x` by the per-channel weight (broadcast over space)."""
        return self.weight * x
class RealNVP(nn.Module):
    """RealNVP Model

    Based on the paper:
    "Density estimation using Real NVP"
    by Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio
    (https://arxiv.org/abs/1605.08803).

    Args:
        num_scales (int): Number of scales in the RealNVP model.
        in_channels (int): Number of channels in the input.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t network of
            `Coupling` layers.
        k (int): Constant whose log is subtracted once per latent dimension
            in `log_prob`.
    """

    def __init__(
        self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8, k=256
    ):
        super().__init__()
        self.k = k
        # Register data_constraint to pre-process images, not learnable
        constraint = torch.tensor([0.9], dtype=torch.float32)
        self.register_buffer("data_constraint", constraint)

        self.flows = _RealNVP(0, num_scales, in_channels, mid_channels, num_blocks)

    def forward(self, x, reverse=False, **kwargs):
        """Run the flow forward (data to latent) or in reverse.

        Args:
            x (torch.Tensor): Input batch. In the forward direction inputs
                are expected in [0, 1]; a batch with a negative minimum is
                remapped with ``(x + 1) / 2``.
            reverse (bool): If True, skip pre-processing and hand
                ``sldj=None`` to the flow.
            **kwargs: Accepted and ignored.

        Returns:
            tuple: ``(x, sldj)`` as returned by the underlying flow.
        """
        if reverse:
            sldj = None
        else:
            # Expect inputs in [0, 1]
            if x.min() < 0:
                x = (x + 1) / 2.0
            elif x.max() > 1:
                print("Warning: Input not properly normalized!")

            # De-quantize and convert to logits
            x, sldj = self._pre_process(x)

        x, sldj = self.flows(x, sldj, reverse)

        return x, sldj

    def _pre_process(self, x):
        """Dequantize the input image `x` and convert to logits.

        Args:
            x (torch.Tensor): Input image.

        Returns:
            tuple: ``(y, sldj)`` where ``y`` holds the dequantized logits of
            `x` and ``sldj`` is the per-sample log-determinant term of this
            transform.

        See Also:
            - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1
            - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1
        """
        # Add uniform noise, then shrink into the range given by the buffer.
        dequantized = (x * 255.0 + torch.rand_like(x)) / 256.0
        constrained = (2 * dequantized - 1) * self.data_constraint
        probs = (constrained + 1) / 2
        y = probs.log() - (1.0 - probs).log()

        # Save log-determinant of Jacobian of initial transform
        constraint_term = F.softplus(
            (1.0 - self.data_constraint).log() - self.data_constraint.log()
        )
        ldj = F.softplus(y) + F.softplus(-y) - constraint_term
        sldj = ldj.view(ldj.size(0), -1).sum(-1)

        return y, sldj

    def log_prob(self, x):
        """Return the per-sample log-likelihood of `x` under the model.

        The latent is scored under a standard normal, ``log(k)`` is
        subtracted for every latent dimension, and the accumulated
        log-determinant is added.
        """
        z, sldj = self.forward(x, reverse=False)
        batch = z.size(0)
        dims_per_sample = np.prod(z.size()[1:])

        gauss_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
        prior_ll = gauss_ll.reshape(batch, -1).sum(-1) - np.log(self.k) * dims_per_sample
        return prior_ll + sldj

    def inverse(self, z, **kwargs):
        """Map latents back through the flow; equivalent to ``forward(z, reverse=True)``."""
        return self.forward(z, reverse=True)
136 | """ 137 | 138 | def __init__(self, scale_idx, num_scales, in_channels, mid_channels, num_blocks): 139 | super(_RealNVP, self).__init__() 140 | 141 | self.is_last_block = scale_idx == num_scales - 1 142 | 143 | self.in_couplings = nn.ModuleList( 144 | [ 145 | CouplingLayer( 146 | in_channels, 147 | mid_channels, 148 | num_blocks, 149 | MaskType.CHECKERBOARD, 150 | reverse_mask=False, 151 | ), 152 | CouplingLayer( 153 | in_channels, 154 | mid_channels, 155 | num_blocks, 156 | MaskType.CHECKERBOARD, 157 | reverse_mask=True, 158 | ), 159 | CouplingLayer( 160 | in_channels, 161 | mid_channels, 162 | num_blocks, 163 | MaskType.CHECKERBOARD, 164 | reverse_mask=False, 165 | ), 166 | ] 167 | ) 168 | 169 | if self.is_last_block: 170 | self.in_couplings.append( 171 | CouplingLayer( 172 | in_channels, 173 | mid_channels, 174 | num_blocks, 175 | MaskType.CHECKERBOARD, 176 | reverse_mask=True, 177 | ) 178 | ) 179 | else: 180 | self.out_couplings = nn.ModuleList( 181 | [ 182 | CouplingLayer( 183 | 4 * in_channels, 184 | 2 * mid_channels, 185 | num_blocks, 186 | MaskType.CHANNEL_WISE, 187 | reverse_mask=False, 188 | ), 189 | CouplingLayer( 190 | 4 * in_channels, 191 | 2 * mid_channels, 192 | num_blocks, 193 | MaskType.CHANNEL_WISE, 194 | reverse_mask=True, 195 | ), 196 | CouplingLayer( 197 | 4 * in_channels, 198 | 2 * mid_channels, 199 | num_blocks, 200 | MaskType.CHANNEL_WISE, 201 | reverse_mask=False, 202 | ), 203 | ] 204 | ) 205 | self.next_block = _RealNVP( 206 | scale_idx + 1, num_scales, 2 * in_channels, 2 * mid_channels, num_blocks 207 | ) 208 | 209 | def forward(self, x, sldj, reverse=False): 210 | if reverse: 211 | if not self.is_last_block: 212 | # Re-squeeze -> split -> next block 213 | x = squeeze_2x2(x, reverse=False, alt_order=True) 214 | x, x_split = x.chunk(2, dim=1) 215 | x, sldj = self.next_block(x, sldj, reverse) 216 | x = torch.cat((x, x_split), dim=1) 217 | x = squeeze_2x2(x, reverse=True, alt_order=True) 218 | 219 | # Squeeze -> 3x coupling 
class WNConv2d(nn.Module):
    """Weight-normalized 2d convolution.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        kernel_size (int): Side length of each convolutional kernel.
        padding (int): Padding to add on edges of input.
        bias (bool): Use bias in the convolution operation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, bias=True):
        super().__init__()
        plain_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, padding=padding, bias=bias
        )
        # Wrap the convolution with weight normalization.
        self.conv = nn.utils.weight_norm(plain_conv)

    def forward(self, x):
        """Apply the weight-normalized convolution to `x`."""
        return self.conv(x)
def squeeze_2x2(x, reverse=False, alt_order=False):
    """Trade spatial extent for channels with 2x2 blocks, or undo it.

    In the forward direction a ``(B, C, H, W)`` tensor becomes
    ``(B, 4C, H/2, W/2)``: every 2x2 spatial block is moved into the channel
    axis. With ``reverse=True`` a ``(B, 4C, H, W)`` tensor becomes
    ``(B, C, 2H, 2W)``.

    Adapted from:
        https://github.com/tensorflow/models/blob/master/research/real_nvp/real_nvp_utils.py

    See Also:
        - TensorFlow nn.depth_to_space: https://www.tensorflow.org/api_docs/python/tf/nn/depth_to_space
        - Figure 3 of RealNVP paper: https://arxiv.org/abs/1605.08803

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).
        reverse (bool): Whether to do a reverse squeeze (unsqueeze).
        alt_order (bool): Whether to use alternate ordering. The alternate
            ordering is implemented as a stride-2 (transposed) convolution
            with a fixed 0/1 permutation kernel, and groups the output
            channels by position inside the 2x2 block rather than by input
            channel.

    Returns:
        torch.Tensor: The squeezed (or unsqueezed) tensor.

    Raises:
        ValueError: If the channel count is not divisible by 4 when
            unsqueezing, or the height/width is not divisible by 2 when
            squeezing.
    """
    if alt_order:
        _, c, h, w = x.size()

        if reverse:
            if c % 4 != 0:
                raise ValueError(
                    "Number of channels must be divisible by 4, got {}.".format(c)
                )
            # From here on `c` is the channel count of the unsqueezed tensor.
            c //= 4
        else:
            if h % 2 != 0:
                raise ValueError("Height must be divisible by 2, got {}.".format(h))
            if w % 2 != 0:
                # The condition checks divisibility by 2, so say so.
                raise ValueError("Width must be divisible by 2, got {}.".format(w))

        # Defines permutation of input channels (shape is (4, 1, 2, 2)).
        squeeze_matrix = torch.tensor(
            [
                [[[1.0, 0.0], [0.0, 0.0]]],
                [[[0.0, 0.0], [0.0, 1.0]]],
                [[[0.0, 1.0], [0.0, 0.0]]],
                [[[0.0, 0.0], [1.0, 0.0]]],
            ],
            dtype=x.dtype,
            device=x.device,
        )
        # Block-diagonal kernel: input channel i feeds output channels 4i..4i+3.
        perm_weight = torch.zeros((4 * c, c, 2, 2), dtype=x.dtype, device=x.device)
        for c_idx in range(c):
            perm_weight[4 * c_idx : 4 * (c_idx + 1), c_idx : c_idx + 1, :, :] = (
                squeeze_matrix
            )

        # Regroup output channels so that all "position 0" channels come
        # first, then all "position 1" channels, and so on.
        shuffle_channels = torch.tensor(
            [c_idx * 4 + offset for offset in range(4) for c_idx in range(c)],
            dtype=torch.long,
            device=x.device,
        )
        perm_weight = perm_weight[shuffle_channels, :, :, :]

        if reverse:
            x = F.conv_transpose2d(x, perm_weight, stride=2)
        else:
            x = F.conv2d(x, perm_weight, stride=2)
    else:
        b, c, h, w = x.size()
        # Work channels-last so the 2x2 block can be folded next to C.
        x = x.permute(0, 2, 3, 1)

        if reverse:
            if c % 4 != 0:
                raise ValueError(
                    "Number of channels {} is not divisible by 4".format(c)
                )
            x = x.view(b, h, w, c // 4, 2, 2)
            x = x.permute(0, 1, 4, 2, 5, 3)
            x = x.contiguous().view(b, 2 * h, 2 * w, c // 4)
        else:
            if h % 2 != 0 or w % 2 != 0:
                raise ValueError(
                    "Expected even spatial dims HxW, got {}x{}".format(h, w)
                )
            x = x.view(b, h // 2, 2, w // 2, 2, c)
            x = x.permute(0, 1, 3, 5, 2, 4)
            x = x.contiguous().view(b, h // 2, w // 2, c * 4)

        x = x.permute(0, 3, 1, 2)

    return x
class GeneratorBlock(nn.Module):
    """ResNet-style block for the generator model.

    The main path is BN -> ReLU -> (optional 2x nearest upsample) -> 3x3 conv
    -> BN -> ReLU -> 3x3 conv; its output is added to a shortcut that is
    upsampled the same way and projected by a 1x1 conv when the channel
    counts differ.

    Args:
        in_chans (int): Number of input channels.
        out_chans (int): Number of output channels.
        upsample (bool): If True, double the spatial resolution.
    """

    def __init__(self, in_chans, out_chans, upsample=False):
        super().__init__()

        self.upsample = upsample

        # A 1x1 projection is only needed when the channel count changes.
        self.shortcut_conv = (
            nn.Conv2d(in_chans, out_chans, kernel_size=1)
            if in_chans != out_chans
            else None
        )
        self.bn1 = nn.BatchNorm2d(in_chans)
        self.conv1 = nn.Conv2d(in_chans, in_chans, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_chans)
        self.conv2 = nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, *inputs):
        """Apply the block to ``inputs[0]``; any further inputs are ignored."""
        x = inputs[0]

        # Shortcut branch.
        shortcut = x
        if self.upsample:
            shortcut = nn.functional.interpolate(
                shortcut, scale_factor=2, mode="nearest"
            )
        if self.shortcut_conv is not None:
            shortcut = self.shortcut_conv(shortcut)

        # Main branch.
        out = self.relu(self.bn1(x))
        if self.upsample:
            out = nn.functional.interpolate(out, scale_factor=2, mode="nearest")
        out = self.conv1(out)
        out = self.conv2(self.relu(self.bn2(out)))

        return out + shortcut
unit_interval, feats=128, out_channels=3): 51 | super().__init__() 52 | 53 | self.input_linear = nn.Linear(feats, 4 * 4 * feats) 54 | self.block1 = GeneratorBlock(feats, feats, upsample=True) 55 | self.block2 = GeneratorBlock(feats, feats, upsample=True) 56 | self.block3 = GeneratorBlock(feats, feats, upsample=True) 57 | self.output_bn = nn.BatchNorm2d(feats) 58 | self.output_conv = nn.Conv2d(feats, out_channels, kernel_size=3, padding=1) 59 | self.relu = nn.ReLU() 60 | self.feats = feats 61 | 62 | # Apply Xavier initialisation to the weights 63 | relu_gain = nninit.calculate_gain("relu") 64 | for module in self.modules(): 65 | if isinstance(module, (nn.Conv2d, nn.Linear)): 66 | gain = relu_gain if module != self.input_linear else 1.0 67 | nninit.xavier_uniform_(module.weight.data, gain=gain) 68 | module.bias.data.zero_() 69 | 70 | if unit_interval == True: 71 | self.final_act = nn.functional.sigmoid 72 | elif unit_interval == False: 73 | self.final_act = torch.tanh 74 | else: 75 | self.final_act = nn.Identity() 76 | 77 | self.last_output = None 78 | 79 | def forward(self, *inputs): 80 | x = inputs[0] 81 | 82 | x = self.input_linear(x) 83 | x = x.view(-1, self.feats, 4, 4) 84 | x = self.block1(x) 85 | x = self.block2(x) 86 | x = self.block3(x) 87 | x = self.output_bn(x) 88 | x = self.relu(x) 89 | x = self.output_conv(x) 90 | x = self.final_act(x) 91 | 92 | self.last_output = x 93 | 94 | return x 95 | -------------------------------------------------------------------------------- /uncertainty_est/archs/wrn.py: -------------------------------------------------------------------------------- 1 | """ 2 | WideResnet architecture adapted from https://github.com/meliketoy/wide-resnet.pytorch 3 | """ 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | """ 13 | Convolution with 3x3 kernels. 
def conv_init(m):
    """Initialize convolution and batch-norm layers in place.

    Intended for use with ``nn.Module.apply``. Modules are matched by class
    name: a name containing ``"Conv"`` gets Xavier-uniform weights (gain
    ``sqrt(2)``) and a zero bias; a name containing ``"BatchNorm"`` gets unit
    weight and zero bias. Every other module is left untouched.

    Parameters that are absent (for example ``bias=False`` convolutions or
    ``affine=False`` batch norm, where the attribute is ``None``) are skipped
    instead of raising.

    Args:
        m (nn.Module): Module to initialize.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            init.xavier_uniform_(weight, gain=np.sqrt(2))
        if bias is not None:
            init.constant_(bias, 0)
    elif "BatchNorm" in classname:
        if weight is not None:
            init.constant_(weight, 1)
        if bias is not None:
            init.constant_(bias, 0)
84 | """ 85 | if ( 86 | self.first 87 | ): # if it's the first block, don't apply the first batchnorm to the data 88 | out = self.dropout(self.conv1(self.lrelu(x))) 89 | else: 90 | out = self.dropout(self.conv1(self.lrelu(self.bn1(x)))) 91 | out = self.conv2(self.lrelu(self.bn2(out))) 92 | out += self.shortcut(x) 93 | 94 | return out 95 | 96 | 97 | def get_norm(n_filters, norm): 98 | """ 99 | Get batchnorm or other. 100 | """ 101 | if norm is None: 102 | return Identity() 103 | elif norm == "batch": 104 | return nn.BatchNorm2d(n_filters) 105 | elif norm == "instance": 106 | return nn.InstanceNorm2d(n_filters, affine=True) 107 | elif norm == "layer": 108 | return nn.GroupNorm(1, n_filters) 109 | elif norm == "group": 110 | return nn.GroupNorm(32, n_filters) 111 | 112 | 113 | class WideResNet(nn.Module): 114 | """ 115 | Wide resnet model. 116 | """ 117 | 118 | def __init__( 119 | self, 120 | depth, 121 | widen_factor, 122 | num_classes=10, 123 | input_channels=3, 124 | sum_pool=False, 125 | norm=None, 126 | leak=0.2, 127 | dropout=0.0, 128 | strides=(1, 2, 2), 129 | bottleneck_dim=None, 130 | bottleneck_channels_factor=None, 131 | ): 132 | super(WideResNet, self).__init__() 133 | self.leak = leak 134 | self.in_planes = 16 135 | self.sum_pool = sum_pool 136 | self.norm = norm 137 | self.lrelu = nn.LeakyReLU(leak) 138 | self.bottleneck_dim = bottleneck_dim 139 | self.bottleneck_channels_factor = bottleneck_channels_factor 140 | 141 | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4" 142 | n = (depth - 4) // 6 143 | k = widen_factor 144 | 145 | print("| Wide-Resnet %dx%d" % (depth, k)) 146 | nStages = [16, 16 * k, 32 * k, 64 * k] 147 | 148 | self.conv1 = conv3x3(input_channels, nStages[0]) 149 | self.layer1 = self._wide_layer( 150 | wide_basic, nStages[1], n, dropout, stride=strides[0], first=True 151 | ) 152 | self.layer2 = self._wide_layer( 153 | wide_basic, nStages[2], n, dropout, stride=strides[1] 154 | ) 155 | self.layer3 = self._wide_layer( 156 | 
wide_basic, nStages[3], n, dropout, stride=strides[2] 157 | ) 158 | self.bn1 = get_norm(nStages[3], self.norm) 159 | self.last_dim = nStages[3] 160 | self.linear = nn.Linear(nStages[3], num_classes) 161 | 162 | if self.bottleneck_dim is not None: 163 | self.bottleneck = nn.Sequential( 164 | nn.Linear(nStages[3], nStages[3] // 2), 165 | nn.ReLU(True), 166 | nn.Linear(nStages[3] // 2, self.bottleneck_dim), 167 | nn.ReLU(True), 168 | nn.Linear(self.bottleneck_dim, nStages[3] // 2), 169 | nn.ReLU(True), 170 | nn.Linear(nStages[3] // 2, nStages[3]), 171 | ) 172 | 173 | self.apply(conv_init) 174 | 175 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride, first=False): 176 | strides = [stride] + [1] * (num_blocks - 1) 177 | layers = [] 178 | 179 | for i, stride in enumerate(strides): 180 | if first and i == 0: # first block of first layer has no BN 181 | layers.append( 182 | block( 183 | self.in_planes, 184 | planes, 185 | dropout_rate, 186 | stride, 187 | norm=self.norm, 188 | first=True, 189 | ) 190 | ) 191 | else: 192 | layers.append( 193 | block(self.in_planes, planes, dropout_rate, stride, norm=self.norm) 194 | ) 195 | self.in_planes = planes 196 | 197 | if self.bottleneck_channels_factor is not None: 198 | bottleneck_channels = int(self.in_planes * self.bottleneck_channels_factor) 199 | layers.extend( 200 | [ 201 | nn.Conv2d(self.in_planes, bottleneck_channels, kernel_size=1), 202 | nn.Conv2d(bottleneck_channels, self.in_planes, kernel_size=1), 203 | ] 204 | ) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def encode(self, x, vx=None): 209 | out = self.conv1(x) 210 | out = self.layer1(out) 211 | out = self.layer2(out) 212 | out = self.layer3(out) 213 | out = self.lrelu(self.bn1(out)) 214 | if self.sum_pool: 215 | out = out.view(out.size(0), out.size(1), -1).sum(2) 216 | else: 217 | out = F.avg_pool2d(out, out.shape[2:]) 218 | out = out.view(out.size(0), -1) 219 | 220 | if self.bottleneck_dim is not None: 221 | out = self.bottleneck(out) 222 
| return out 223 | 224 | def forward(self, x, vx=None): 225 | """ 226 | Forward pass. TODO: purpose of vx? 227 | """ 228 | out = self.encode(x, vx) 229 | 230 | return self.linear(out) 231 | 232 | 233 | if __name__ == "__main__": 234 | import torch 235 | 236 | for strides in [(1, 2, 2), (1, 1, 2), (1, 1, 1)]: 237 | wrn = WideResNet(28, 10, strides=strides, bottleneck_channels_factor=0.1) 238 | print(wrn) 239 | out = wrn(torch.zeros(1, 3, 32, 32)) 240 | print(strides, out.shape) 241 | -------------------------------------------------------------------------------- /uncertainty_est/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/data/__init__.py -------------------------------------------------------------------------------- /uncertainty_est/data/dataloaders.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import PIL.Image as InterpolationMode 6 | from torchvision import transforms as tvt 7 | from uncertainty_eval.datasets import get_dataset as ue_get_dataset 8 | 9 | from uncertainty_est.data.datasets import ConcatDataset, ConcatIterableDataset 10 | 11 | DATA_ROOT = Path("../data") 12 | 13 | 14 | def get_dataset(dataset, data_shape=None, length=10_000, split_seed=1): 15 | try: 16 | ds_class = ue_get_dataset(dataset) 17 | 18 | if data_shape is None: 19 | data_shape = ds_class.data_shape 20 | 21 | if dataset == "gaussian_noise": 22 | m = 127.5 if len(data_shape) == 3 else 0.0 23 | s = 60.0 if len(data_shape) == 3 else 1.0 24 | mean = torch.empty(*data_shape).fill_(m) 25 | std = torch.empty(*data_shape).fill_(s) 26 | ds = ds_class(DATA_ROOT, length=length, mean=mean, std=std) 27 | elif dataset == "uniform_noise": 28 | l = 0.0 if len(data_shape) == 3 else -5.0 29 | h = 255.0 
if len(data_shape) == 3 else 5.0 30 | low = torch.empty(*data_shape).fill_(l) 31 | high = torch.empty(*data_shape).fill_(h) 32 | ds = ds_class(DATA_ROOT, length=length, low=low, high=high) 33 | elif dataset == "constant": 34 | low = 0.0 if len(data_shape) == 3 else -5.0 35 | high = 255.0 if len(data_shape) == 3 else 5.0 36 | ds = ds_class( 37 | DATA_ROOT, length=length, low=low, high=high, shape=data_shape 38 | ) 39 | else: 40 | ds = ds_class(DATA_ROOT) 41 | except KeyError as e: 42 | raise ValueError(f'Dataset "{dataset}" not supported') from e 43 | return ds, data_shape 44 | 45 | 46 | def get_dataloader( 47 | dataset, 48 | split, 49 | batch_size=32, 50 | data_shape=None, 51 | ood_dataset=None, 52 | sigma=0.0, 53 | num_workers=0, 54 | drop_last=None, 55 | shuffle=None, 56 | mutation_rate=0.0, 57 | split_seed=1, 58 | normalize=True, 59 | extra_train_transforms=[], 60 | extra_test_transforms=[], 61 | ): 62 | train_transform = [] 63 | test_transform = [] 64 | 65 | unscaled = False 66 | try: 67 | ds, data_shape = get_dataset(dataset, data_shape, split_seed=split_seed) 68 | except ValueError as e: 69 | if "_unscaled" in dataset: 70 | dataset = dataset.replace("_unscaled", "") 71 | unscaled = True 72 | ds, data_shape = get_dataset(dataset, data_shape, split_seed=split_seed) 73 | else: 74 | raise e 75 | 76 | if len(data_shape) == 3: 77 | img_size = data_shape[1] 78 | train_transform.extend( 79 | [ 80 | tvt.Resize(img_size, InterpolationMode.LANCZOS), 81 | tvt.CenterCrop(img_size), 82 | tvt.Pad(4, padding_mode="reflect"), 83 | tvt.RandomRotation(15, resample=InterpolationMode.BILINEAR), 84 | tvt.RandomHorizontalFlip(), 85 | tvt.RandomCrop(img_size), 86 | ] 87 | ) 88 | 89 | test_transform.extend( 90 | [ 91 | tvt.Resize(img_size, InterpolationMode.LANCZOS), 92 | tvt.CenterCrop(img_size), 93 | ] 94 | ) 95 | 96 | if unscaled: 97 | scale_transform = [tvt.ToTensor(), tvt.Lambda(lambda x: x * 255.0)] 98 | else: 99 | scale_transform = [tvt.ToTensor()] 100 | if normalize: 101 | 
scale_transform.append( 102 | tvt.Normalize((0.5,) * data_shape[2], (0.5,) * data_shape[2]) 103 | ) 104 | 105 | test_transform.extend(scale_transform) 106 | train_transform.extend(scale_transform) 107 | 108 | test_transform.extend(extra_test_transforms) 109 | train_transform.extend(extra_train_transforms) 110 | 111 | if sigma > 0.0: 112 | noise_transform = lambda x: x + sigma * torch.randn_like(x) 113 | train_transform.append(noise_transform) 114 | test_transform.append(noise_transform) 115 | 116 | if mutation_rate > 0.0: 117 | if len(data_shape) == 3: 118 | mutation_data_shape = (data_shape[2], data_shape[0], data_shape[1]) 119 | else: 120 | mutation_data_shape = data_shape 121 | 122 | def mutation_transform(x): 123 | mask = torch.bernoulli( 124 | torch.empty(mutation_data_shape).fill_(mutation_rate) 125 | ) 126 | replace = torch.empty(mutation_data_shape).uniform_(-1, 1) * (1 - mask) 127 | return x * mask + replace 128 | 129 | train_transform.append(mutation_transform) 130 | 131 | train_transform = tvt.Compose(train_transform) 132 | test_transform = tvt.Compose(test_transform) 133 | 134 | if split == "train": 135 | ds = ds.train(train_transform) 136 | elif split == "val": 137 | ds = ds.val(test_transform) 138 | else: 139 | ds = ds.test(test_transform) 140 | 141 | setattr(ds, "data_shape", data_shape) 142 | 143 | if isinstance(ds, torch.utils.data.IterableDataset): 144 | shuffle = False 145 | else: 146 | shuffle = split == "train" if shuffle is None else shuffle 147 | 148 | if ood_dataset is not None: 149 | if isinstance(ds, torch.utils.data.IterableDataset): 150 | ood_ds, _ = get_dataset(ood_dataset, data_shape) 151 | ds = ConcatIterableDataset(ds, ood_ds.train(train_transform)) 152 | else: 153 | ood_ds, _ = get_dataset(ood_dataset, data_shape, length=len(ds)) 154 | 155 | ood_train = ood_ds.train(train_transform) 156 | ds = ConcatDataset(ds, ood_train) 157 | 158 | dataloader = DataLoader( 159 | ds, 160 | batch_size=batch_size, 161 | pin_memory=True, 162 | 
num_workers=num_workers, 163 | shuffle=shuffle, 164 | drop_last=split == "train" if drop_last is None else drop_last, 165 | ) 166 | return dataloader 167 | -------------------------------------------------------------------------------- /uncertainty_est/data/datasets.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from torch.utils.data import Dataset, IterableDataset 4 | 5 | 6 | class ConcatDataset(Dataset): 7 | def __init__(self, *datasets): 8 | super().__init__() 9 | self.datasets: List[Dataset] = datasets 10 | 11 | def __len__(self): 12 | return min([len(ds) for ds in self.datasets]) 13 | 14 | def __getitem__(self, idx): 15 | return [ds[idx] for ds in self.datasets] 16 | 17 | 18 | class ConcatIterableDataset(IterableDataset): 19 | def __init__(self, *datasets): 20 | super().__init__() 21 | self.datasets: List[Dataset] = datasets 22 | 23 | def __iter__(self): 24 | return zip(*self.datasets) 25 | -------------------------------------------------------------------------------- /uncertainty_est/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from numpy.lib.arraysetops import isin 5 | 6 | sys.path.insert(0, os.getcwd()) 7 | 8 | import logging 9 | from copy import copy 10 | from pathlib import Path 11 | from argparse import ArgumentParser 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | from uncertainty_est.data.dataloaders import get_dataloader 17 | from uncertainty_est.models import load_checkpoint, load_model 18 | 19 | 20 | parser = ArgumentParser() 21 | parser.add_argument("--checkpoint", type=str, action="append", default=[]) 22 | parser.add_argument("--dataset", type=str) 23 | parser.add_argument("--ood_dataset", type=str, action="append") 24 | parser.add_argument("--eval-classification", action="store_true") 25 | parser.add_argument("--output-folder", type=str) 26 | 
parser.add_argument("--name", type=str, default="out")
parser.add_argument("--max-eval", type=int, default=10_000)
parser.add_argument("--checkpoint-dir", type=str)
parser.add_argument("--config-entries", type=str, action="append", default=[])


logger = logging.getLogger()
logger.setLevel(logging.INFO)


def eval_model(
    model,
    dataset,
    ood_datasets,
    eval_classification=False,
    model_name="",
    batch_size=128,
    max_items=-1,
    data_shape=None,
    **kwargs,
):
    """Evaluate a model on OOD detection and, optionally, classification.

    Args:
        model: Object exposing ``eval_ood(id_loader, ood_loaders)`` and, when
            ``eval_classification`` is set, ``eval_classifier(loader, max_items)``.
        dataset: Name of the in-distribution dataset; its "test" split is used.
        ood_datasets: Iterable of OOD dataset names; their "test" splits are used.
        eval_classification: Whether to also run ``model.eval_classifier``.
        model_name: Label stored in the first column of every result row.
        batch_size: Batch size for every dataloader created here.
        max_items: Forwarded to ``model.eval_classifier`` only.
        data_shape: Forwarded to ``get_dataloader`` for the ID dataset.
        **kwargs: Extra keyword arguments forwarded to every ``get_dataloader`` call.

    Returns:
        ``(ood_rows, clf_rows)``: two lists of tuples. OOD rows are
        ``(model_name, model class name, dataset, *key, value)`` and
        classification rows are ``(model_name, model class name, dataset, key, value)``.
    """
    id_test_loader = get_dataloader(
        dataset, "test", batch_size=batch_size, data_shape=data_shape, **kwargs
    )
    model_type = model.__class__.__name__

    clf_accum = []
    if eval_classification:
        clf_results = model.eval_classifier(id_test_loader, max_items)
        for k, v in clf_results.items():
            logger.info(f"{k}: {v:.02f}")
            clf_accum.append((model_name, model_type, dataset, k, v))

    # OOD loaders reuse the data shape resolved for the ID dataset so that all
    # inputs fed to the model have the same shape.
    test_ood_dataloaders = [
        (
            test_ood_dataset,
            get_dataloader(
                test_ood_dataset,
                "test",
                data_shape=id_test_loader.dataset.data_shape,
                batch_size=batch_size,
                **kwargs,
            ),
        )
        for test_ood_dataset in ood_datasets
    ]

    ood_results = model.eval_ood(id_test_loader, test_ood_dataloaders)

    accum = []
    for k, v in ood_results.items():
        logger.info(f"{k}: {v:.02f}")
        # Keys are unpacked into several columns; the CSV header written by the
        # script suggests (OOD dataset, score, metric) -- TODO confirm.
        accum.append((model_name, model_type, dataset, *k, v))

    return accum, clf_accum


def _lookup_config_entry(config, dotted_key):
    """Return ``config[a][b]...`` for ``dotted_key == "a.b..."`` or NaN if absent.

    Raises:
        ValueError: If the key resolves to a nested dict instead of a value.
    """
    value = config
    for key in dotted_key.split("."):
        # NaN never compares equal to itself, so a missing key has to be
        # detected structurally rather than with ``== np.nan``.
        if not isinstance(value, dict) or key not in value:
            return np.nan
        value = value[key]

    if isinstance(value, dict):
        raise ValueError("Error getting config entry")
    return value


def _evaluate_checkpoint(model, config, args, model_name):
    """Evaluate one loaded model and return its table rows.

    ``args.dataset`` / ``args.ood_dataset`` are filled from ``config`` when they
    were not given on the command line. Every returned row is extended with the
    values of the requested ``--config-entries``.
    """
    model.eval()
    model.cuda()

    if not args.ood_dataset:
        args.ood_dataset = config["test_ood_datasets"]

    if not args.dataset:
        args.dataset = config["dataset"]

    ood_rows, clf_rows = eval_model(
        model,
        args.dataset,
        args.ood_dataset,
        args.eval_classification,
        model_name=model_name,
        batch_size=128,
        max_items=args.max_eval,
        normalize=config["normalize"] if "normalize" in config else True,
        data_shape=config["data_shape"],
    )

    extra_cols = [_lookup_config_entry(config, e) for e in args.config_entries]
    return (
        [[*row, *extra_cols] for row in ood_rows],
        [[*row, *extra_cols] for row in clf_rows],
    )


if __name__ == "__main__":
    args = parser.parse_args()
    base_args = copy(args)

    ood_tbl_rows = []
    clf_tbl_rows = []
    for checkpoint in args.checkpoint:
        # Reset args to original
        args = copy(base_args)
        checkpoint_path = Path(checkpoint)
        model_name = checkpoint_path.parent.stem

        model, config = load_checkpoint(checkpoint_path, strict=False)
        ood_rows, clf_rows = _evaluate_checkpoint(model, config, args, model_name)
        ood_tbl_rows.extend(ood_rows)
        # ``extend`` (not ``append``): each element must be one table row.
        clf_tbl_rows.extend(clf_rows)

    if args.checkpoint_dir:
        checkpoint_dir = Path(args.checkpoint_dir)
        for model_dir in checkpoint_dir.glob("**/version_*"):
            # Reset args to original
            args = copy(base_args)
            try:
                model, config = load_model(model_dir, last=False, strict=False)
            except Exception as e:
                # Best effort: skip run folders that cannot be loaded.
                logger.info(str(e))
                continue

            ood_rows, clf_rows = _evaluate_checkpoint(
                model, config, args, model_dir.parent.stem
            )
            ood_tbl_rows.extend(ood_rows)
            clf_tbl_rows.extend(clf_rows)

    extra_row_names = [s.split(".")[-1] for s in args.config_entries]
    if args.output_folder:
        output_folder = Path(args.output_folder)
        ood_df = pd.DataFrame(
            ood_tbl_rows,
            columns=(
                "Model",
                "Model Type",
                "ID dataset",
                "OOD dataset",
                "Score",
                "Metric",
                "Value",
                *extra_row_names,
            ),
        )
        ood_df.to_csv(output_folder / f"ood-{args.name}.csv", index=False)

        if args.eval_classification:
            clf_df = pd.DataFrame(
                clf_tbl_rows,
                columns=(
                    "Model",
                    "Model Type",
                    "ID dataset",
                    "Metric",
                    "Value",
                    *extra_row_names,
                ),
            )
            clf_df.to_csv(output_folder / f"clf-{args.name}.csv", index=False)

# --------------------------------------------------------------------------------
# /uncertainty_est/models/__init__.py:
# --------------------------------------------------------------------------------
import yaml
from pathlib import Path

from uncertainty_est.models.ebm.mcmc import MCMC
from uncertainty_est.models.ebm.discrete_mcmc import DiscreteMCMC
from uncertainty_est.models.ce_baseline import CEBaseline
from uncertainty_est.models.energy_finetuning import EnergyFinetune
from uncertainty_est.models.ebm.vera import VERA
from uncertainty_est.models.normalizing_flow.norm_flow import NormalizingFlow
from uncertainty_est.models.ebm.conditional_nce import PerSampleNCE
from .normalizing_flow.image_flows import RealNVPModel, GlowModel
from .ebm.nce import NoiseContrastiveEstimation
from .ebm.flow_contrastive_estimation import FlowContrastiveEstimation
from .ebm.ssm import SSM


# Maps the ``model_name`` entry of a run's config to the class used to load it.
MODELS = {
    "JEM": MCMC,
    "DiscreteMCMC": DiscreteMCMC,
    "CEBaseline": CEBaseline,
    "EnergyOOD": EnergyFinetune,
    "VERA": VERA,
    "NormalizingFlow": NormalizingFlow,
    "PerSampleNCE": PerSampleNCE,
    "RealNVP": RealNVPModel,
    "Glow": GlowModel,
    "NCE": NoiseContrastiveEstimation,
    "FlowCE": FlowContrastiveEstimation,
    "SSM": SSM,
}


def load_model(model_folder: Path, last=False, *args, **kwargs):
    """Load a model from the preferred ``.ckpt`` file inside ``model_folder``.

    Args:
        model_folder: Directory containing ``*.ckpt`` files and ``config.yaml``.
        last: If true, prefer the checkpoint whose stem is ``"last"``; otherwise
            prefer the lexicographically greatest other checkpoint. The
            non-preferred kind is used as a fallback when the preferred one is
            missing.
        *args, **kwargs: Forwarded to :func:`load_checkpoint`.

    Returns:
        ``(model, config)`` as returned by :func:`load_checkpoint`.

    Raises:
        AssertionError: If the folder contains no ``.ckpt`` file.
    """
    ckpts = [file for file in model_folder.iterdir() if file.suffix == ".ckpt"]
    last_ckpt = [ckpt for ckpt in ckpts if ckpt.stem == "last"]
    best_ckpt = sorted(ckpt for ckpt in ckpts if ckpt.stem != "last")

    # The preferred kind goes to the end of the list, from where it is picked.
    ordered = best_ckpt + last_ckpt if last else last_ckpt + best_ckpt
    assert len(ordered) > 0, f"No .ckpt file found in {model_folder}"

    return load_checkpoint(ordered[-1], *args, **kwargs)


def load_checkpoint(checkpoint_path: Path, *args, **kwargs):
    """Instantiate the model stored at ``checkpoint_path``.

    The model class is chosen via the ``model_name`` entry of the
    ``config.yaml`` located next to the checkpoint.

    Returns:
        ``(model, config)`` where ``config`` is the parsed YAML content.
    """
    with (checkpoint_path.parent / "config.yaml").open("r") as f:
        # NOTE(review): FullLoader is kept as-is; the config is assumed to be a
        # trusted, locally produced file. Do not point this at untrusted input.
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_cls = MODELS[config["model_name"]]
    model = model_cls.load_from_checkpoint(checkpoint_path, *args, **kwargs)
    return model, config

# --------------------------------------------------------------------------------
# /uncertainty_est/models/ce_baseline.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from uncertainty_est.archs.arch_factory import get_arch
from uncertainty_est.models.ood_detection_model import OODDetectionModel


class CEBaseline(OODDetectionModel):
    """Classifier trained with plain cross-entropy, used as an OOD baseline."""

    def __init__(
        self, arch_name, arch_config, learning_rate, momentum, weight_decay, **kwargs
    ):
        super().__init__(**kwargs)
        # Exposes every constructor argument as an attribute of the same name
        # (e.g. ``self.momentum``), which ``configure_optimizers`` relies on.
        self.__dict__.update(locals())
        self.save_hyperparameters()

        self.backbone = get_arch(arch_name, arch_config)

    def forward(self, x):
        """Return the backbone output (used as class logits) for ``x``."""
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log cross-entropy loss and accuracy for one validation batch."""
        x, y = batch
        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y)
        self.log("val/loss", loss)

        acc = (y == y_hat.argmax(1)).float().mean(0).item()
        self.log("val/acc", acc)

    def test_step(self, batch, batch_idx):
        """Log accuracy for one test batch."""
        x, y = batch
        y_hat = self(x)

        acc = (y == y_hat.argmax(1)).float().mean(0).item()
        self.log("test/acc", acc)

    def configure_optimizers(self):
        """Return AdamW plus a StepLR schedule halving the LR every 30 steps."""
        optim = torch.optim.AdamW(
            self.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.5)
        return [optim], [scheduler]

    def classify(self, x):
        """Return softmax class probabilities for ``x``."""
        return torch.softmax(self.backbone(x), -1)

    def get_ood_scores(self, x):
        """Return a dict of OOD scores computed from the logits of ``x``.

        Besides whatever ``dirichlet_prior_network_uncertainty`` returns, the
        dict contains ``"p(x)"`` (log-sum-exp of the logits) and
        ``"max p(y|x)"`` (maximum softmax probability).
        """
        logits = self(x).cpu()
        # NOTE(review): ``dirichlet_prior_network_uncertainty`` is not imported
        # in this file, so this call raises NameError unless the name is
        # injected elsewhere -- add the import from the module that defines it.
        dir_uncert = dirichlet_prior_network_uncertainty(logits)
        dir_uncert["p(x)"] = logits.logsumexp(1)
        dir_uncert["max p(y|x)"] = logits.softmax(1).max(1)[0]
        return dir_uncert

# --------------------------------------------------------------------------------
# /uncertainty_est/models/ebm/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/models/ebm/__init__.py
# --------------------------------------------------------------------------------
# /uncertainty_est/models/ebm/conditional_nce.py:
# --------------------------------------------------------------------------------
import torch
from torch import distributions

from uncertainty_est.archs.arch_factory import get_arch
from uncertainty_est.models.ood_detection_model import OODDetectionModel


class PerSampleNCE(OODDetectionModel):
    """Energy model trained by contrasting each sample with a noisy copy."""

    def __init__(
        self,
        arch_name,
        arch_config,
        learning_rate,
        momentum,
        weight_decay,
        noise_sigma=0.01,
        p_control_weight=0.0,
    ):
        super().__init__()
        # Exposes every constructor argument as an attribute of the same name.
        self.__dict__.update(locals())
        self.save_hyperparameters()

        self.model = get_arch(arch_name, arch_config)

    def forward(self, x):
        """Return the model output for ``x`` (treated as unnormalised log p)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the conditional-NCE loss for one batch (labels are ignored)."""
        x, _ = batch

        # The noise is N(0, noise_sigma * I), i.e. ``noise_sigma`` is a
        # variance. The covariance is isotropic, so scaling standard-normal
        # draws by sqrt(noise_sigma) samples the same distribution without
        # materialising (and factorising) a dense D x D covariance every step.
        noise = torch.randn_like(x) * self.noise_sigma ** 0.5

        # Implements Eq. 9 in "Conditional Noise-Contrastive Estimation of Unnormalised Models"
        # Uses symmetry of noise distribution meaning p(u1|u2) = p(u2|u1) to simplify
        # Sets k = 1
        x_noisy = x + noise
        log_p_model = self.model(torch.cat((x, x_noisy))).squeeze()
        log_p_x = log_p_model[: len(x)]
        log_p_x_noisy = log_p_model[len(x) :]

        # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
        loss = torch.nn.functional.softplus(log_p_x_noisy - log_p_x).mean()

        # Optional penalty on the magnitude of the model outputs.
        p_control = log_p_model.abs().mean()
        loss += self.p_control_weight * p_control

        self.log("train/log_p_magnitude", log_p_x.mean(), prog_bar=True)
        self.log("train/log_p_noisy_magnitude", log_p_x_noisy.mean(), prog_bar=True)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """No validation metric is computed for this model."""
        return

    def test_step(self, batch, batch_idx):
        """Log arg-max accuracy of the model output on one test batch."""
        self.to(torch.float32)
        x, y = batch
        y_hat = self.model(x)

        acc = (y == y_hat.argmax(1)).float().mean(0).item()
        self.log("test_acc", acc)
        return y_hat

    def configure_optimizers(self):
        """Return AdamW plus a StepLR schedule halving the LR every 30 steps."""
        optim = torch.optim.AdamW(
            self.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.5)
        return [optim], [scheduler]

    def get_ood_scores(self, x):
        """Return the raw model output under the key ``"p(x)"``."""
        return {"p(x)": self.model(x)}
# --------------------------------------------------------------------------------
# /uncertainty_est/models/ebm/discrete_mcmc.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

from uncertainty_est.models.ebm.mcmc import MCMC


def init_random(buffer_size, data_shape, dim):
    """Return uniformly random one-hot samples.

    The result is a float tensor of shape ``(buffer_size, *data_shape, dim)``
    whose last axis is a one-hot encoding of a category drawn from ``[0, dim)``.
    """
    buffer = torch.FloatTensor(buffer_size, *data_shape).random_(0, dim)
    buffer = F.one_hot(buffer.long(), num_classes=dim).float()
    return buffer


class DiscreteMCMC(MCMC):
    """MCMC-trained energy model for categorical (one-hot encoded) inputs."""

    def __init__(
        self,
        arch_name,
        arch_config,
        learning_rate,
        momentum,
        weight_decay,
        buffer_size,
        n_classes,
        data_shape,
        smoothing,
        pyxce,
        pxsgld,
        pxysgld,
        class_cond_p_x_sample,
        sgld_batch_size,
        reinit_freq,
        num_cat,
        sgld_steps=20,
        entropy_reg_weight=0.0,
        warmup_steps=0,
        lr_step_size=50,
        **kwargs
    ):
        # The continuous SGLD settings are not used by the discrete sampler and
        # are always passed to the base class as 0.0. Any supplied value is
        # discarded; a default avoids a KeyError when the keys are absent.
        kwargs.pop("sgld_lr", None)
        kwargs.pop("sgld_std", None)
        # Set before ``super().__init__`` -- presumably because the base
        # constructor calls ``_init_buffer``, which reads ``num_cat``
        # (TODO confirm against ``MCMC.__init__``).
        self.num_cat = num_cat
        super().__init__(
            arch_name=arch_name,
            arch_config=arch_config,
            learning_rate=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
            buffer_size=buffer_size,
            n_classes=n_classes,
            data_shape=data_shape,
            smoothing=smoothing,
            pyxce=pyxce,
            pxsgld=pxsgld,
            pxysgld=pxysgld,
            class_cond_p_x_sample=class_cond_p_x_sample,
            sgld_batch_size=sgld_batch_size,
            sgld_lr=0.0,
            sgld_std=0.0,
            reinit_freq=reinit_freq,
            sgld_steps=sgld_steps,
            entropy_reg_weight=entropy_reg_weight,
            warmup_steps=warmup_steps,
            lr_step_size=lr_step_size,
            **kwargs
        )
        self.save_hyperparameters()

    def _init_buffer(self):
        """Fill the replay buffer with random one-hot samples kept on the CPU."""
        self.replay_buffer = init_random(
            self.buffer_size, self.data_shape, self.num_cat
        ).cpu()

    def sample_p_0(self, replay_buffer, bs, y=None):
        """Draw ``bs`` chain starting points.

        Samples come from ``replay_buffer``; each one is replaced by a fresh
        random sample with probability ``reinit_freq``. With an empty buffer
        only random samples are returned.

        Returns:
            ``(samples, inds)`` where ``inds`` are the buffer indices used
            (an empty list when the buffer is empty).
        """
        if len(replay_buffer) == 0:
            return init_random(bs, self.data_shape, self.num_cat), []

        buffer_size = (
            len(replay_buffer) if y is None else len(replay_buffer) // self.n_classes
        )
        inds = torch.randint(0, buffer_size, (bs,))
        # if cond, convert inds to class conditional inds
        if y is not None:
            inds = y.cpu() * buffer_size + inds

        buffer_samples = replay_buffer[inds].to(self.device)
        if self.reinit_freq > 0.0:
            random_samples = init_random(bs, self.data_shape, self.num_cat).to(
                self.device
            )
            # Per-sample 0/1 mask with trailing singleton axes so it broadcasts
            # over the data axes and the one-hot axis.
            choose_random = (torch.rand(bs) < self.reinit_freq).to(buffer_samples)[
                (...,) + (None,) * (len(self.data_shape) + 1)
            ]
            samples = (
                choose_random * random_samples + (1 - choose_random) * buffer_samples
            )
        else:
            samples = buffer_samples
        return samples.to(self.device), inds

    def sample_q(self, replay_buffer, y=None, n_steps=20, contrast=False):
        """Run ``n_steps`` gradient-informed Metropolis-Hastings updates.

        Starting points come from :meth:`sample_p_0`; the final samples are
        written back to ``replay_buffer`` (if non-empty) and returned detached.
        ``contrast`` is accepted but not used here.
        """
        self.model.eval()
        bs = self.sgld_batch_size if y is None else y.size(0)

        # generate initial samples and buffer inds of those samples (if buffer is used)
        init_sample, buffer_inds = self.sample_p_0(replay_buffer, bs=bs, y=y)
        x_k = torch.autograd.Variable(init_sample, requires_grad=True)

        # Gradient with Gibbs "http://arxiv.org/abs/2102.04509"
        for _ in range(n_steps):
            energy = self.model(x_k, y=y)
            f_prime = torch.autograd.grad(energy.sum(), [x_k], retain_graph=True)[0]

            # Proposal over (position, category) pairs from the input gradient.
            d = f_prime - (x_k * f_prime).sum(-1, keepdim=True)
            q_i_given_x = Categorical(logits=(d / 2.0).flatten(start_dim=1))
            i = q_i_given_x.sample()
            prob_i_given_x = q_i_given_x.log_prob(i).exp()

            # Flip sampled dimension
            x_q_idx = x_k.argmax(-1)
            x_q_idx.flatten(start_dim=1)[torch.arange(len(x_k)), i // self.num_cat] = (
                i % self.num_cat
            )
            x_q = F.one_hot(x_q_idx, self.num_cat).float()
            x_q.requires_grad_()

            # Probability of the same index under the proposal at the new state.
            energy_q = self.model(x_q, y=y)
            f_prime = torch.autograd.grad(energy_q.sum(), [x_q], retain_graph=True)[0]
            d = f_prime - (x_q * f_prime).sum(-1, keepdim=True)
            q_i_given_x = Categorical(logits=(d / 2.0).flatten(start_dim=1))
            prob_i_given_x_q = q_i_given_x.log_prob(i).exp()

            # Update samples depending on Metropolis-Hastings probability
            keep_prob = torch.exp(energy_q - energy) * (
                prob_i_given_x_q / prob_i_given_x
            )
            keep_prob[keep_prob > 1.0] = 1.0
            keep = torch.rand(len(x_k)).to(self.device) < keep_prob

            x_k = x_k.detach()
            x_k[keep] = x_q[keep]

        self.model.train()
        final_samples = x_k.detach()

        # update replay buffer
        if len(replay_buffer) > 0:
            replay_buffer[buffer_inds] = final_samples.cpu()
        return final_samples

# --------------------------------------------------------------------------------
# /uncertainty_est/models/ebm/flow_contrastive_estimation.py:
# --------------------------------------------------------------------------------
import math

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from uncertainty_est.utils.utils import to_np
from uncertainty_est.archs.arch_factory import get_arch
from uncertainty_est.models.ood_detection_model import OODDetectionModel


class FlowContrastiveEstimation(OODDetectionModel):
    """Noise-contrastive training of an energy model against a learned flow.

    The energy model is trained to tell data from samples of ``noise_dist``
    (Noise Contrastive Estimation,
    http://proceedings.mlr.press/v9/gutmann10a.html), while the noise
    distribution is itself a trainable flow that is updated whenever the
    energy model discriminates well enough.
    """

    def __init__(
        self,
        arch_name,
        arch_config,
        flow_arch_name,
        flow_arch_config,
        learning_rate,
        momentum,
        weight_decay,
        flow_learning_rate,
        rho=0.5,
        is_toy_dataset=False,
        toy_dataset_dim=2,
    ):
        super().__init__()
        # Both optimizers are stepped by hand in ``training_step``.
        self.automatic_optimization = False
        # Exposes every constructor argument as an attribute of the same name.
        self.__dict__.update(locals())
        self.save_hyperparameters()

        self.model = get_arch(arch_name, arch_config)
        self.noise_dist = get_arch(flow_arch_name, flow_arch_config)

    def forward(self, x):
        """Return the energy model output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Update either the energy model or the flow on one batch.

        The energy model is updated while its accuracy on data or on noise is
        below 0.55; otherwise the flow is updated.
        """
        optim_ebm, optim_flow = self.optimizers()
        x, _ = batch

        # Noise samples: class 0 is "flow", class 1 is "energy model".
        noise, neg_f_log_p = self.noise_dist.sample(len(x))
        neg_e_log_p = self.model(noise.detach()).logsumexp(-1)

        target = torch.zeros_like(neg_e_log_p).long().to(self.device)
        neg_pred = torch.stack(
            (neg_f_log_p + math.log(1 - self.rho), neg_e_log_p + math.log(self.rho)), -1
        )
        neg_loss = F.cross_entropy(neg_pred, target)

        # Data samples should be attributed to the energy model (class 1).
        pos_f_log_p = self.noise_dist.log_prob(x)
        pos_e_log_p = self.model(x).logsumexp(-1)

        target = torch.ones_like(pos_e_log_p).long().to(self.device)
        pos_pred = torch.stack(
            (pos_f_log_p + math.log(1 - self.rho), pos_e_log_p + math.log(self.rho)), -1
        )
        pos_loss = F.cross_entropy(pos_pred, target)

        loss = self.rho * pos_loss + (1 - self.rho) * neg_loss

        pos_acc = (pos_e_log_p > pos_f_log_p).float().mean()
        neg_acc = (neg_f_log_p > neg_e_log_p).float().mean()

        self.log("train/pos_acc", pos_acc, prog_bar=True)
        self.log("train/neg_acc", neg_acc, prog_bar=True)

        if pos_acc < 0.55 or neg_acc < 0.55:
            optim_ebm.zero_grad()
            self.manual_backward(loss)
            optim_ebm.step()
            self.log("train/loss", loss, prog_bar=True)
        else:
            optim_flow.zero_grad()
            # NOTE(review): the flow is optimised on ``-pos_loss`` but the value
            # logged below is ``-loss`` -- confirm which one is intended.
            self.manual_backward(-pos_loss)
            optim_flow.step()
            self.log("train/flow_loss", -loss, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """No per-batch validation metric is computed for this model."""
        return

    def validation_epoch_end(self, outputs):
        """For 2-D toy data, log density plots of the model and of the flow."""
        if self.is_toy_dataset and self.toy_dataset_dim == 2:
            interp = torch.linspace(-4, 4, 500)
            x, y = torch.meshgrid(interp, interp)
            data = torch.stack((x.reshape(-1), y.reshape(-1)), 1).to(self.device)
            p_xy = torch.exp(self(data))
            px = to_np(p_xy.sum(1))
            flow_px = to_np(self.noise_dist.log_prob(data).exp())

            x, y = to_np(x), to_np(y)
            for i in range(p_xy.shape[1]):
                fig, ax = plt.subplots()
                mesh = ax.pcolormesh(x, y, to_np(p_xy[:, i]).reshape(*x.shape))
                fig.colorbar(mesh)
                self.logger.experiment.add_figure(
                    f"dist/p(x,y={i})", fig, self.current_epoch
                )
                plt.close()

            fig, ax = plt.subplots()
            mesh = ax.pcolormesh(x, y, px.reshape(*x.shape))
            fig.colorbar(mesh)
            self.logger.experiment.add_figure("dist/p(x)", fig, self.current_epoch)
            plt.close()

            fig, ax = plt.subplots()
            samples = to_np(self.noise_dist.sample(1000)[0])
            ax.scatter(samples[:, 0], samples[:, 1])
            self.logger.experiment.add_figure(
                "dist/flow_samples", fig, self.current_epoch
            )
            plt.close()

            fig, ax = plt.subplots()
            mesh = ax.pcolormesh(x, y, flow_px.reshape(*x.shape))
            fig.colorbar(mesh)
            self.logger.experiment.add_figure("dist/Flow p(x)", fig, self.current_epoch)
            plt.close()

    def test_step(self, batch, batch_idx):
        """Log arg-max accuracy of the model output on one test batch."""
        self.to(torch.float32)
        x, y = batch
        y_hat = self.model(x)

        acc = (y == y_hat.argmax(1)).float().mean(0).item()
        self.log("test_acc", acc)
        return y_hat

    def configure_optimizers(self):
        """Return one AdamW optimizer for the energy model and one for the flow."""
        optim = torch.optim.AdamW(
            self.model.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        optim_flow = torch.optim.AdamW(
            self.noise_dist.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.flow_learning_rate,
            weight_decay=self.weight_decay,
        )
        return [optim, optim_flow]

    def get_ood_scores(self, x):
        """Return the raw model output under the key ``"p(x)"``."""
        return {"p(x)": self.model(x)}

# --------------------------------------------------------------------------------
# /uncertainty_est/models/ebm/hdge.py:
def compute_losses(self, x_lab, dist, logits=None, evaluation=False):
    """Compute the three weighted HDGE loss terms for a labelled batch.

    Args:
        x_lab: Input batch.
        dist: Per-sample target distribution over classes (the callers in
            this class pass the output of ``smooth_one_hot``).
        logits: Optional pre-computed classifier logits for ``x_lab``;
            computed via ``self.model.classify`` when omitted and needed.
        evaluation: Forwarded to ``self.model.joint``.

    Returns:
        Tuple ``(l_pyxce, l_pxcontrast, l_pxycontrast)``.  A term whose
        weight (``self.pyxce`` / ``self.pxcontrast`` / ``self.pxycontrast``)
        is not positive is returned as the float ``0.0``.
    """
    l_pyxce, l_pxcontrast, l_pxycontrast = 0.0, 0.0, 0.0

    # log p(y|x): cross entropy against the (possibly smoothed) targets.
    if self.pyxce > 0:
        if logits is None:
            logits = self.model.classify(x_lab)
        l_pyxce = self.pyxce * KHotCrossEntropyLoss()(logits, dist)

    # log p(x): contrastive loss over all classes (all-ones class mask).
    if self.pxcontrast > 0:
        ones_dist = torch.ones_like(dist).to(self.device)
        output, target, _, _ = self.model.joint(
            img=x_lab, dist=ones_dist, evaluation=evaluation
        )
        # Scale by this term's own weight.  The previous code multiplied by
        # ``self.pxycontrast``, which silently zeroed the term whenever the
        # class-conditional loss was disabled.
        l_pxcontrast = self.pxcontrast * F.cross_entropy(output, target)

    # log p(x|y): contrastive loss restricted by the target distribution.
    if self.pxycontrast > 0:
        output, target, _, _ = self.model.joint(
            img=x_lab, dist=dist, evaluation=evaluation
        )
        l_pxycontrast = self.pxycontrast * F.cross_entropy(output, target)

    return l_pyxce, l_pxcontrast, l_pxycontrast
def init_random(buffer_size, data_shape):
    """Return a CPU float32 tensor of uniform noise in [-1, 1).

    Args:
        buffer_size: Size of the leading dimension.
        data_shape: Iterable with the remaining dimensions.

    Returns:
        Tensor of shape ``(buffer_size, *data_shape)``.
    """
    full_shape = (buffer_size, *data_shape)
    noise = torch.empty(full_shape, dtype=torch.float32, device="cpu")
    noise.uniform_(-1, 1)
    return noise
def validation_step(self, batch, batch_idx):
    """Log classification accuracy on the labelled half of the batch.

    Args:
        batch: ``((x_lab, y_lab), (_, _))``; the second pair is ignored.
        batch_idx: Unused.

    Returns:
        None.  Nothing is logged when the model has fewer than two classes.
    """
    labelled, (_, _) = batch
    x_lab, y_lab = labelled

    has_classifier = not (self.n_classes < 2)
    if has_classifier:
        _, logits = self.model(x_lab, return_logits=True)
        predictions = logits.argmax(1)
        hits = (y_lab == predictions).float()
        self.log("val/acc", hits.mean(0).item())
def sample_p_0(self, replay_buffer, bs, y=None):
    """Draw ``bs`` SGLD starting points from the replay buffer.

    Each returned sample is either a buffer entry or, with probability
    ``self.reinit_freq``, fresh uniform noise from ``init_random``.

    Args:
        replay_buffer: Tensor of stored samples (indexable by a long tensor).
        bs: Number of samples to draw.
        y: Optional class labels.  When given, the buffer is treated as
            ``self.n_classes`` equally sized contiguous slices and sample
            ``i`` is drawn from the slice of class ``y[i]``.

    Returns:
        ``(samples, inds)``: samples on ``self.device`` and the buffer
        indices they came from.  ``inds`` is an empty list when the buffer
        is empty.
    """
    if len(replay_buffer) == 0:
        # Move to the model device here too; the buffered path below already
        # does, and ``sample_q`` feeds the result straight into the model.
        return init_random(bs, self.sample_shape).to(self.device), []

    buffer_size = (
        len(replay_buffer) if y is None else len(replay_buffer) // self.n_classes
    )
    inds = torch.randint(0, buffer_size, (bs,))
    # Class-conditional: offset into the slice belonging to each label.
    if y is not None:
        inds = y.cpu() * buffer_size + inds

    buffer_samples = replay_buffer[inds].to(self.device)
    random_samples = init_random(bs, self.sample_shape).to(self.device)
    # Per-sample 0/1 mask, given trailing singleton dims so it broadcasts
    # over the sample shape.
    choose_random = (torch.rand(bs) < self.reinit_freq).to(buffer_samples)[
        (...,) + (None,) * len(self.sample_shape)
    ]
    samples = choose_random * random_samples + (1 - choose_random) * buffer_samples
    return samples.to(self.device), inds
def __init__(
    self,
    arch_name,
    arch_config,
    learning_rate,
    momentum,
    weight_decay,
    noise_distribution="uniform",
    noise_distribution_kwargs={"low": 0, "high": 1},
    **kwargs,
):
    """Build the energy network and the fixed noise distribution.

    Args:
        arch_name: Architecture name passed to ``get_arch``.
        arch_config: Architecture config passed to ``get_arch``.
        learning_rate: AdamW learning rate (read in ``configure_optimizers``).
        momentum: First AdamW beta.
        weight_decay: AdamW weight decay.
        noise_distribution: ``"uniform"`` or ``"gaussian"``.
        noise_distribution_kwargs: Constructor arguments of the chosen
            ``torch.distributions`` class.  The default dict is shared
            between calls and is only read here -- do not mutate it.
        **kwargs: Accepted but not forwarded to the base class.

    Raises:
        NotImplementedError: If ``noise_distribution`` is not supported.
    """
    super().__init__()
    # Must run before any further local is created: every local in this
    # frame becomes an attribute / hyper-parameter.
    self.__dict__.update(locals())
    self.save_hyperparameters()

    self.model = get_arch(arch_name, arch_config)

    # ``elif`` is required: with two independent ``if`` statements the
    # default "uniform" fell through to the ``else`` and always raised.
    if noise_distribution == "uniform":
        noise_dist = distributions.Uniform
    elif noise_distribution == "gaussian":
        noise_dist = distributions.Normal
    else:
        raise NotImplementedError(
            f"Requested noise distribution {noise_distribution} not implemented."
        )

    # Stored as non-trainable parameters so they are part of the module state.
    self.dist_parameters = torch.nn.ParameterDict(
        {
            k: torch.nn.Parameter(torch.tensor(v).float(), requires_grad=False)
            for k, v in noise_distribution_kwargs.items()
        }
    )
    self.noise_dist = noise_dist(**self.dist_parameters)
def training_step(self, batch, batch_idx):
    """Compute the sliced score matching loss plus optional classifier loss.

    Args:
        batch: ``((x_lab, y_lab), (x_p_d, _))``.  ``x_p_d`` feeds the score
            matching term; the labelled pair is only used when
            ``self.clf_weight > 0``.
        batch_idx: Unused.

    Returns:
        ``ssm_loss + clf_loss``.

    Raises:
        ValueError: If ``self.noise_type`` is neither ``"radermacher"``
            nor ``"gaussian"``.
    """
    (x_lab, y_lab), (x_p_d, _) = batch
    n_particles = self.n_particles

    # Replicate every sample once per projection vector.
    particles = x_p_d.unsqueeze(0).expand(n_particles, *x_p_d.shape)
    particles = particles.contiguous().view(-1, *x_p_d.shape[1:])
    particles.requires_grad_(True)

    projections = torch.randn_like(particles)
    if self.noise_type == "radermacher":
        projections = projections.sign()
    elif self.noise_type != "gaussian":
        raise ValueError("Noise type not implemented")

    log_density = self.model(particles).sum()

    # First-order term: half the squared score, summed over the last dim.
    score = autograd.grad(log_density, particles, create_graph=True)[0]
    sq_norm_term = torch.sum(score * score, dim=-1) / 2.0

    # Second-order term: v^T (d score / dx) v through a second backward pass.
    projected_score = torch.sum(score * projections)
    hessian_vec = autograd.grad(projected_score, particles, create_graph=True)[0]
    curvature_term = torch.sum(projections * hessian_vec, dim=-1)

    # Average over particles first, then over the remaining entries.
    sq_norm_term = sq_norm_term.view(n_particles, -1).mean(dim=0)
    curvature_term = curvature_term.view(n_particles, -1).mean(dim=0)

    ssm_loss = (sq_norm_term + curvature_term).mean()
    self.log("train/ssm_loss", ssm_loss)

    clf_loss = 0.0
    if self.clf_weight > 0.0:
        _, logits = self.model(x_lab, return_logits=True)
        clf_loss = self.clf_weight * F.cross_entropy(logits, y_lab)
        self.log("train/clf_loss", clf_loss)
    return ssm_loss + clf_loss
class JEM(nn.Module):
    """Wrap a logit-producing network as a joint energy model.

    ``forward`` returns ``logsumexp`` over the logits; ``classify`` returns
    the raw logits of the wrapped network.
    """

    def __init__(self, model):
        super().__init__()
        self.f = model

    def forward(self, x, return_logits=False, y=None):
        """Score ``x``.

        Args:
            x: Input batch.
            return_logits: If true, also return the logits.
            y: Optional class indices; if given, the logit of class ``y[i]``
                is returned for sample ``i`` and ``return_logits`` is ignored.

        Returns:
            ``logits[i, y[i]]`` when ``y`` is given, otherwise
            ``logits.logsumexp(1)`` (plus the logits if requested).
        """
        logits = self.classify(x)

        if y is not None:
            rows = torch.arange(len(x), device=logits.device)
            return logits[rows, y]

        if return_logits:
            return logits.logsumexp(1), logits
        return logits.logsumexp(1)

    def classify(self, x):
        """Return the logits of the wrapped network."""
        return self.f(x)


class HDGE(JEM):
    """JEM with a fixed-size queue of past logits used as contrastive negatives.

    Args:
        model: Network producing ``n_classes`` logits per sample.
        n_classes: Number of logits, i.e. rows of the queue.
        contrast_k: Queue length ``K`` (number of negatives).
        contrast_t: Temperature the contrastive logits are divided by.
    """

    def __init__(self, model, n_classes, contrast_k, contrast_t):
        super(HDGE, self).__init__(model)

        self.K = contrast_k
        self.T = contrast_t
        self.dim = n_classes

        # Queue of logits (C x K), filled with noise until real logits arrive.
        init_logit = torch.randn(n_classes, contrast_k)
        self.register_buffer("queue_logit", init_logit)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, logits):
        """Overwrite the oldest ``batch_size`` queue columns with ``logits``."""
        batch_size = logits.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        self.queue_logit[:, ptr : ptr + batch_size] = logits.T

        ptr = (ptr + batch_size) % self.K  # wrap around
        self.queue_ptr[0] = ptr

    def joint(self, img, dist=None, evaluation=False):
        """Build contrastive logits for ``img`` against the queue.

        Args:
            img: Input batch (N samples).
            dist: ``N x C`` class weighting applied to the normalised logits.
            evaluation: If true, the queue is left untouched.

        Returns:
            ``(logits, labels, ce_logit, K)``: ``N x (1 + K)`` contrastive
            logits whose column 0 is the positive, all-zero long labels on
            the same device as ``logits``, the raw classifier logits, and
            the number of negatives.
        """
        # The wrapped network already outputs class logits.  The previous
        # code called a ``self.class_output`` head that is never defined.
        f_logit = self.classify(img)  # queries: NxC
        ce_logit = f_logit  # cross-entropy loss logits
        prob = nn.functional.normalize(f_logit, dim=1)

        # positive logits: Nx1
        l_pos = torch.logsumexp(dist * prob, dim=1, keepdim=True)

        # negative logits: NxK
        buffer = nn.functional.normalize(self.queue_logit.clone().detach(), dim=0)
        l_neg = torch.einsum("nc,ck->nck", [dist, buffer])  # NxCxK
        l_neg = torch.logsumexp(l_neg, dim=1)

        # logits: Nx(1+K), scaled by the temperature
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        # The positive key is always column 0.  Created on the logits'
        # device rather than unconditionally on CUDA so CPU runs work.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        if not evaluation:
            self._dequeue_and_enqueue(f_logit)

        return logits, labels, ce_logit, l_neg.size(1)
def __init__(
    self,
    arch_name,
    arch_config,
    learning_rate,
    beta1,
    beta2,
    weight_decay,
    n_classes,
    gen_learning_rate,
    ebm_iters,
    generator_iters,
    entropy_weight,
    generator_type,
    generator_arch_name,
    generator_arch_config,
    generator_config,
    min_sigma,
    max_sigma,
    p_control,
    n_control,
    pg_control,
    clf_ent_weight,
    ebm_type,
    clf_weight,
    warmup_steps,
    no_g_batch_norm,
    batch_size,
    lr_decay,
    lr_decay_epochs,
    **kwargs,
):
    """Build the energy model and its generator; enable manual optimization.

    Every argument is stored as an attribute of the same name (via
    ``self.__dict__.update(locals())``) and read by the other methods of
    this class.

    Args:
        arch_name, arch_config: Passed to ``get_arch`` for the energy
            network, which is then wrapped in ``JEM``.
        learning_rate, gen_learning_rate: AdamW learning rates of the
            energy model and the generator.
        beta1, beta2, weight_decay: AdamW settings shared by both optimizers.
        n_classes: ``test_step`` returns early when this is below 2.
        ebm_iters, generator_iters: The respective update runs on batches
            whose index is a multiple of this value.
        entropy_weight: Weight of the generator entropy objective.
        generator_type: ``"verahmc"``, ``"vera"`` or ``"vera_discrete"``.
        generator_arch_name, generator_arch_config: Passed to ``get_arch``
            for the generator network.
        generator_config: Keyword arguments of the generator wrapper.
        min_sigma, max_sigma: Bounds passed to ``generator.clamp_sigma``.
        p_control, n_control: Weights of the squared-energy penalties on
            data and on generated samples.
        pg_control: Weight of the gradient-norm penalty on data.
        clf_ent_weight: Weight of the classifier entropy term.
        ebm_type: ``"ssl"``, ``"jem"`` or ``"p_x"`` (see ``ebm_step``).
        clf_weight: Weight of the classifier loss.
        warmup_steps, batch_size: Stored only; not read in this class.
        no_g_batch_norm: If true, batch norm is switched to eval mode while
            scoring generated samples.
        lr_decay, lr_decay_epochs: ``gamma`` and ``milestones`` of the
            ``MultiStepLR`` schedulers.
        **kwargs: Forwarded to the base class constructor.

    Raises:
        NotImplementedError: If ``generator_type`` is not supported.
    """
    super().__init__(**kwargs)
    # Must stay ahead of any new local: every name in the frame at this
    # point becomes an attribute / hyper-parameter.
    self.__dict__.update(locals())
    self.save_hyperparameters()
    # training_step steps both optimizers by hand.
    self.automatic_optimization = False

    arch = get_arch(arch_name, arch_config)
    self.model = JEM(arch)

    g = get_arch(generator_arch_name, generator_arch_config)
    if generator_type == "verahmc":
        self.generator = VERAHMCGenerator(g, **generator_config)
    elif generator_type == "vera":
        self.generator = VERAGenerator(g, **generator_config)
    elif generator_type == "vera_discrete":
        self.generator = VERADiscreteGenerator(g, **generator_config)
    else:
        raise NotImplementedError(f"Generator '{generator_type}' not implemented!")

    # Selects what forward() returns: "density", "logits" or "probs".
    self.predict_mode = "density"
def training_step(self, batch, batch_idx, optimizer_idx):
    """Alternate manual updates of the energy model and the generator.

    Args:
        batch: ``((x_l, y_l), (x_d, _))`` -- a labelled pair and a second
            pair whose inputs are passed to ``ebm_step`` as ``x_d``.
        batch_idx: The EBM is updated when this is a multiple of
            ``self.ebm_iters``, the generator when it is a multiple of
            ``self.generator_iters``.
        optimizer_idx: Unused.

    Returns:
        None. Both optimizers are stepped by hand inside this method.
    """
    ebm_opt, gen_opt = self.optimizers()
    (x_l, y_l), (x_d, _) = batch

    # Both input batches need gradients (ebm_step differentiates w.r.t. x_l).
    for inputs in (x_l, x_d):
        inputs.requires_grad_()

    # Sample from q(x, h), one sample per labelled input.
    x_g, h_g = self.generator.sample(x_l.size(0), requires_grad=True)

    # EBM (contrastive divergence) objective.
    if batch_idx % self.ebm_iters == 0:
        ebm_loss = self.ebm_step(x_d, x_l, x_g, y_l)
        self.log("train/ebm_loss", ebm_loss, prog_bar=True)

        ebm_opt.zero_grad()
        self.manual_backward(ebm_loss, ebm_opt)
        ebm_opt.step()

    # Generator objective.
    if batch_idx % self.generator_iters == 0:
        gen_loss = self.generator_step(x_g, h_g)
        self.log("train/gen_loss", gen_loss, prog_bar=True)

        gen_opt.zero_grad()
        self.manual_backward(gen_loss, gen_opt)
        gen_opt.step()

    # Keep sigma within [min_sigma, max_sigma] for generators that have one.
    has_sigma = self.generator_type in ("verahmc", "vera")
    if has_sigma:
        self.generator.clamp_sigma(self.max_sigma, sigma_min=self.min_sigma)
def generator_step(self, x_g, h_g):
    """Return the generator loss for a batch of generated samples.

    The generator maximises the mean energy-model score of its samples
    plus, when ``self.entropy_weight`` is non-zero, the weighted entropy
    objective reported by ``self.generator.entropy_obj``.

    Args:
        x_g: Generated samples.
        h_g: Latents belonging to ``x_g``; only used for the entropy term.

    Returns:
        The negated objective, to be minimised.
    """
    lg = self.model(x_g).squeeze()
    logq_obj = lg.mean()

    # Only query the entropy estimator when it contributes.  Previously
    # ``entropy_obj`` was referenced unconditionally, so an entropy weight
    # of 0.0 crashed with an unbound local.  An unused gradient-norm
    # computation (an extra backward pass through the model) was dropped.
    if self.entropy_weight != 0.0:
        entropy_obj, _ = self.generator.entropy_obj(x_g, h_g)
        logq_obj = logq_obj + self.entropy_weight * entropy_obj

    return -logq_obj
def configure_optimizers(self):
    """Create AdamW optimizers and MultiStepLR schedulers for both networks.

    The energy model uses ``self.learning_rate`` and the generator uses
    ``self.gen_learning_rate``; betas, weight decay and the decay schedule
    (``self.lr_decay`` at ``self.lr_decay_epochs``) are shared.

    Returns:
        ``([ebm_optimizer, generator_optimizer],
        [ebm_scheduler, generator_scheduler])``.
    """
    optimizers, schedulers = [], []
    for module, lr in (
        (self.model, self.learning_rate),
        (self.generator, self.gen_learning_rate),
    ):
        optimizer = torch.optim.AdamW(
            module.parameters(),
            betas=(self.beta1, self.beta2),
            lr=lr,
            weight_decay=self.weight_decay,
        )
        optimizers.append(optimizer)
        schedulers.append(
            torch.optim.lr_scheduler.MultiStepLR(
                optimizer, gamma=self.lr_decay, milestones=self.lr_decay_epochs
            )
        )
    return optimizers, schedulers
def classifier_loss(self, ld_logits, y_l, lg_logits):
    """Dirichlet-based classification loss with entropy regularisers.

    The logits are exponentiated into Dirichlet concentrations (plus one
    when ``self.alpha_fix`` is set).

    Args:
        ld_logits: Logits of labelled data.
        y_l: Integer class labels for ``ld_logits``.
        lg_logits: Logits of generated samples.

    Returns:
        The data term -- mean of
        ``one_hot(y) * (digamma(alpha_0) - digamma(alpha))`` minus
        ``self.clf_ent_weight`` times the mean Dirichlet entropy -- plus
        the sample term, ``-self.sample_term`` times the mean Dirichlet
        entropy on generated samples.
    """
    alpha = torch.exp(ld_logits)
    if self.alpha_fix:
        alpha = alpha + 1

    soft_output = F.one_hot(y_l, self.n_classes)
    alpha_0 = alpha.sum(1).unsqueeze(-1).repeat(1, self.n_classes)
    UCE_loss = torch.mean(
        soft_output * (torch.digamma(alpha_0) - torch.digamma(alpha))
    )
    UCE_loss = UCE_loss + self.clf_ent_weight * -Dirichlet(alpha).entropy().mean()

    # A leftover ``pdb.set_trace()`` used to sit here and halted every
    # training step that reached the classifier loss.
    lg_alpha = torch.exp(lg_logits)
    if self.alpha_fix:
        lg_alpha = lg_alpha + 1
    # NOTE(review): ``sample_term`` is only passed on to the parent
    # constructor as a keyword; confirm it ends up as ``self.sample_term``.
    sample_loss = self.sample_term * -Dirichlet(lg_alpha).entropy().mean()

    return UCE_loss + sample_loss


def validation_epoch_end(self, outputs):
    """Log a histogram of the concentrations of the first validation batch.

    Args:
        outputs: Per-batch returns of ``validation_step``; only
            ``outputs[0]`` is used.
    """
    super().validation_epoch_end(outputs)
    # Parenthesised on purpose: ``a + 1 if c else 0`` parses as
    # ``(a + 1) if c else 0`` and logged the scalar 0 when alpha_fix was off.
    alphas = torch.exp(outputs[0]).reshape(-1) + (1 if self.alpha_fix else 0)
    self.logger.experiment.add_histogram("alphas", alphas, self.current_epoch)


def get_ood_scores(self, x):
    """Return the density score together with Dirichlet uncertainty measures.

    Args:
        x: Input batch.

    Returns:
        Dict with ``"p(x)"`` (the first output of ``self.model``) merged
        with whatever ``dirichlet_prior_network_uncertainty`` returns for
        the logits.
    """
    px, logits = self.model(x, return_logits=True)
    uncert = {"p(x)": px}
    # detach() so the NumPy conversion also works outside torch.no_grad().
    dirichlet_uncerts = dirichlet_prior_network_uncertainty(
        logits.detach().cpu().numpy(), alpha_correction=self.alpha_fix
    )
    return {**uncert, **dirichlet_uncerts}
def classifier_loss(self, ld_logits, y_l, lg_logits):
    """Return the Dirichlet KL loss on labelled data plus the weighted OOD term.

    Args:
        ld_logits: Logits of the labelled batch.
        y_l: Labels belonging to ``ld_logits``.
        lg_logits: Logits of the generated (negative) samples; they are scored
            by ``self.clf_loss`` without labels.

    Returns:
        The total loss. The original logged ``loss + loss_ood`` but returned
        only ``loss``, so ``w_neg_sample_loss`` had no effect on optimisation.
    """
    loss = self.clf_loss(ld_logits, y_l)
    if self.w_neg_sample_loss > 0:
        loss = loss + self.w_neg_sample_loss * self.clf_loss(lg_logits)

    self.log("train/clf_loss", loss)
    return loss


def validation_epoch_end(self, outputs):
    """Run the parent hook, then log a histogram of the first output's alphas."""
    super().validation_epoch_end(outputs)
    alphas = torch.exp(outputs[0]).reshape(-1) + self.concentration
    self.logger.experiment.add_histogram("alphas", alphas, self.current_epoch)


def get_ood_scores(self, x):
    """Return ``p(x)`` from the model together with Dirichlet uncertainty scores."""
    px, logits = self.model(x, return_logits=True)
    dirichlet_uncerts = dirichlet_prior_network_uncertainty(
        logits.cpu().numpy(), alpha_correction=self.alpha_fix
    )
    return {"p(x)": px, **dirichlet_uncerts}
super().__init__( 24 | arch_name, arch_config, learning_rate, momentum, weight_decay, **kwargs 25 | ) 26 | self.__dict__.update(locals()) 27 | self.save_hyperparameters() 28 | self.load_state_dict(torch.load(checkpoint)["state_dict"]) 29 | 30 | def forward(self, x): 31 | return self.backbone(x) 32 | 33 | def training_step(self, batch, batch_idx): 34 | (x, y), (x_ood, _) = batch 35 | 36 | y_hat = self(torch.cat((x, x_ood))) 37 | y_hat_ood = y_hat[len(x) :] 38 | y_hat = y_hat[: len(x)] 39 | loss = F.cross_entropy(y_hat, y) 40 | self.log("train_ce_loss", loss, prog_bar=True) 41 | 42 | # cross-entropy from softmax distribution to uniform distribution 43 | if self.score == "energy": 44 | Ec_out = -torch.logsumexp(y_hat_ood, dim=1) 45 | Ec_in = -torch.logsumexp(y_hat, dim=1) 46 | margin_loss = 0.1 * ( 47 | (F.relu(Ec_in - self.m_in) ** 2).mean() 48 | + (F.relu(self.m_out - Ec_out) ** 2).mean() 49 | ) 50 | self.log("train_margin_loss", margin_loss, prog_bar=True) 51 | loss += margin_loss 52 | elif self.score == "OE": 53 | loss += ( 54 | 0.5 * -(y_hat_ood.mean(1) - torch.logsumexp(y_hat_ood, dim=1)).mean() 55 | ) 56 | 57 | self.log("train_loss", loss) 58 | return loss 59 | 60 | def validation_step(self, batch, batch_idx): 61 | x, y = batch 62 | y_hat = self.backbone(x) 63 | 64 | loss = F.cross_entropy(y_hat, y) 65 | self.log("val/loss", loss) 66 | 67 | acc = (y == y_hat.argmax(1)).float().mean(0).item() 68 | self.log("val_acc", acc) 69 | 70 | def test_step(self, batch, batch_idx): 71 | x, y = batch 72 | y_hat = self.backbone(x) 73 | 74 | acc = (y == y_hat.argmax(1)).float().mean(0).item() 75 | self.log("test_acc", acc) 76 | 77 | def configure_optimizers(self): 78 | optim = torch.optim.AdamW( 79 | self.parameters(), 80 | betas=(self.momentum, 0.999), 81 | lr=self.learning_rate, 82 | weight_decay=self.weight_decay, 83 | ) 84 | 85 | def cosine_annealing(step, total_steps, lr_max, lr_min): 86 | return lr_min + (lr_max - lr_min) * 0.5 * ( 87 | 1 + np.cos(step / total_steps * 
np.pi) 88 | ) 89 | 90 | scheduler = torch.optim.lr_scheduler.LambdaLR( 91 | optim, 92 | lr_lambda=lambda step: cosine_annealing( 93 | step, 94 | self.max_steps, 95 | 1, # since lr_lambda computes multiplicative factor 96 | 1e-6 / self.learning_rate, 97 | ), 98 | ) 99 | return [optim], [scheduler] 100 | 101 | def ood_detect(self, loader): 102 | _, logits = self.get_gt_preds(loader) 103 | 104 | uncert = {} 105 | uncert["Energy"] = torch.logsumexp(logits, 1) 106 | return uncert 107 | -------------------------------------------------------------------------------- /uncertainty_est/models/normalizing_flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/models/normalizing_flow/__init__.py -------------------------------------------------------------------------------- /uncertainty_est/models/normalizing_flow/approx_flow.py: -------------------------------------------------------------------------------- 1 | from uncertainty_est.models.normalizing_flow.norm_flow import NormalizingFlow 2 | 3 | 4 | class ApproxNormalizingFlow(NormalizingFlow): 5 | def __init__( 6 | self, 7 | density_type, 8 | latent_dim, 9 | n_density, 10 | learning_rate, 11 | momentum, 12 | weight_decay, 13 | weight_penalty_weight=1.0, 14 | ): 15 | super().__init__( 16 | density_type, 17 | latent_dim, 18 | n_density, 19 | learning_rate, 20 | momentum, 21 | weight_decay, 22 | ) 23 | assert density_type in ("orthogonal_flow", "reparameterized_flow") 24 | self.__dict__.update(locals()) 25 | self.save_hyperparameters() 26 | 27 | def training_step(self, batch, batch_idx): 28 | x, _ = batch 29 | log_p = self.density_estimation.log_prob(x) 30 | 31 | loss = -log_p.mean() 32 | self.log("train/ml_loss", loss, on_epoch=True) 33 | 34 | weight_penalty = 0.0 35 | transforms = self.density_estimation.transforms 36 | for t in transforms: 37 | weight_penalty += 
class AffineCouplingModel(OODDetectionModel):
    """Base class for flow models holding one density per class.

    Subclasses populate ``self.conditional_densities`` with objects exposing
    ``log_prob(x)``; ``forward`` stacks their outputs along dimension 1.
    """

    def __init__(self, learning_rate, momentum, weight_decay, num_classes=1, **kwargs):
        super().__init__(**kwargs)
        # Stores every constructor argument as an attribute of the same name.
        self.__dict__.update(locals())
        self.conditional_densities = []

    def forward(self, x):
        """Return the per-class log densities stacked along dimension 1."""
        return torch.stack([cd.log_prob(x) for cd in self.conditional_densities], 1)

    def training_step(self, batch, batch_idx):
        """Train with cross-entropy when multi-class, else maximum likelihood."""
        x, y = batch
        log_p_xy = self(x)

        if self.num_classes > 1:
            loss = F.cross_entropy(log_p_xy, y)
            self.log("train/clf_loss", loss)
        else:
            loss = -torch.logsumexp(log_p_xy, dim=1).mean()
            self.log("train/loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log negative mean log-likelihood and accuracy on a validation batch."""
        x, y = batch
        log_p_xy = self(x)
        log_p_x = torch.logsumexp(log_p_xy, dim=1)

        self.log("val/loss", -log_p_x.mean())

        acc = (y == log_p_xy.argmax(1)).float().mean(0).item()
        self.log("val/acc", acc)

    def test_step(self, batch, batch_idx):
        """Log mean log-likelihood and accuracy on a test batch."""
        x, y = batch
        log_p_xy = self(x)
        log_p_x = torch.logsumexp(log_p_xy, dim=1)
        self.log("log_likelihood", log_p_x.mean())

        acc = (y == log_p_xy.argmax(1)).float().mean(0).item()
        self.log("acc", acc)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(
            self.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

    def classify(self, x):
        """Return the softmax over the per-class log densities."""
        return self(x).softmax(-1)

    def get_ood_scores(self, x):
        """Return ``{"p(x)": logsumexp over the per-class log densities}``."""
        return {"p(x)": torch.logsumexp(self(x), dim=1)}


class RealNVPModel(AffineCouplingModel):
    """Affine-coupling model with one ``RealNVP`` density per class."""

    def __init__(
        self,
        num_scales,
        in_channels,
        mid_channels,
        num_blocks,
        learning_rate,
        momentum,
        weight_decay,
        num_classes=1,
        **kwargs
    ):
        # Forward the real class count; the original passed a hard-coded 1 and
        # only worked because the ``locals()`` update below overwrote it.
        super().__init__(
            learning_rate, momentum, weight_decay, num_classes=num_classes, **kwargs
        )
        self.__dict__.update(locals())
        self.save_hyperparameters()

        self.conditional_densities = nn.ModuleList()
        for _ in range(num_classes):
            self.conditional_densities.append(
                RealNVP(num_scales, in_channels, mid_channels, num_blocks)
            )


class GlowModel(AffineCouplingModel):
    """Affine-coupling model with one ``Glow`` density per class."""

    def __init__(
        self,
        in_channels,
        num_channels,
        num_levels,
        num_steps,
        learning_rate,
        momentum,
        weight_decay,
        num_classes=1,
        **kwargs
    ):
        # See RealNVPModel: pass the actual ``num_classes`` to the base class.
        super().__init__(
            learning_rate, momentum, weight_decay, num_classes=num_classes, **kwargs
        )
        self.__dict__.update(locals())
        self.save_hyperparameters()

        self.conditional_densities = nn.ModuleList()
        for _ in range(num_classes):
            self.conditional_densities.append(
                Glow(in_channels, num_channels, num_levels, num_steps)
            )
class IResNetFlow(OODDetectionModel):
    """Density model built on an invertible-ResNet architecture.

    The log-probability is always evaluated with gradients enabled and with
    ``requires_grad`` set on the input, including during validation, testing
    and OOD scoring.
    """

    def __init__(
        self,
        arch_name,
        arch_config,
        learning_rate,
        momentum,
        weight_decay,
        warmup_steps=0,
    ):
        super().__init__()
        # Stores every constructor argument as an attribute of the same name.
        self.__dict__.update(locals())
        self.save_hyperparameters()
        assert arch_name in ("iresnet_fc", "iresnet_conv")

        self.model = get_arch(arch_name, arch_config)

    def _log_prob_with_grad(self, x):
        """Evaluate ``self.model.log_prob(x)`` with autograd switched on."""
        with torch.enable_grad():
            x.requires_grad_()
            return self.model.log_prob(x)

    def forward(self, x):
        """Return the wrapped architecture's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the negative mean log-likelihood of the batch."""
        inputs, _ = batch
        nll = -self._log_prob_with_grad(inputs).mean()
        self.log("train/loss", nll)
        return nll

    def validation_step(self, batch, batch_idx):
        """Log the mean of all ``_sigma`` state entries and the validation loss."""
        inputs, _ = batch
        log_p = self._log_prob_with_grad(inputs)

        sigma_values = [
            value.item()
            for name, value in self.model.state_dict().items()
            if "_sigma" in name
        ]
        self.log("val/sigma_mean", torch.tensor(sigma_values).mean().item())

        self.log("val/loss", -log_p.mean())

    def test_step(self, batch, batch_idx):
        """Log the mean log-likelihood of the batch."""
        inputs, _ = batch
        self.log("log_likelihood", self._log_prob_with_grad(inputs).mean())

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(
            self.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

    def get_ood_scores(self, x):
        """Return ``{"p(x)": detached log-probability of x}``."""
        return {"p(x)": self._log_prob_with_grad(x).detach()}
class NormalizingFlow(OODDetectionModel):
    """Density model trained by maximum likelihood on ``density_estimation``.

    ``density_estimation`` is whatever ``get_arch(arch_name, arch_config)``
    returns; it must expose ``log_prob(x)``.
    """

    def __init__(
        self,
        arch_name,
        arch_config,
        learning_rate,
        momentum,
        weight_decay,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Stores every constructor argument as an attribute of the same name.
        self.__dict__.update(locals())
        self.save_hyperparameters()

        self.density_estimation = get_arch(arch_name, arch_config)

    def _mean_log_prob(self, batch):
        """Return the mean log-probability of the inputs in ``batch``."""
        inputs, _ = batch
        return self.density_estimation.log_prob(inputs).mean()

    def forward(self, x):
        """Return the log-probability of ``x``."""
        return self.density_estimation.log_prob(x)

    def training_step(self, batch, batch_idx):
        """Return the negative mean log-likelihood of the batch."""
        nll = -self._mean_log_prob(batch)
        self.log("train/loss", nll)
        return nll

    def validation_step(self, batch, batch_idx):
        """Log the negative mean log-likelihood of the batch."""
        self.log("val/loss", -self._mean_log_prob(batch))

    def test_step(self, batch, batch_idx):
        """Log the mean log-likelihood of the batch."""
        self.log("log_likelihood", self._mean_log_prob(batch))

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(
            self.parameters(),
            betas=(self.momentum, 0.999),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

    def get_ood_scores(self, x):
        """Return ``{"p(x)": log-probability of x}``."""
        return {"p(x)": self.density_estimation.log_prob(x)}
class OODDetectionModel(pl.LightningModule):
    """Lightning base class adding OOD-detection and classifier evaluation.

    Subclasses implement ``get_ood_scores`` and, if they can classify,
    ``classify``.
    """

    def __init__(
        self, ood_val_datasets=None, is_toy_dataset=False, data_shape=None, **kwargs
    ):
        super().__init__()
        self.ood_val_datasets = ood_val_datasets
        self.is_toy_dataset = is_toy_dataset
        self.data_shape = data_shape
        self.test_ood_dataloaders = []

    @staticmethod
    def _max_batches(loader, num):
        """Return how many batches cover ``num`` samples, or None for ``num == -1``."""
        if num > 0:
            return (num // loader.batch_size) + 1
        assert num == -1
        return None

    def eval_ood(
        self, id_loader, ood_loaders: Dict[str, Any], num=10_000
    ) -> Dict[str, float]:
        """Compute AUROC and AUPR of every score for every OOD dataset.

        Args:
            id_loader: Loader of in-distribution data.
            ood_loaders: ``(dataset_name, loader)`` pairs, or a mapping from
                dataset name to loader.
            num: Approximate number of samples to evaluate; ``-1`` uses all.

        Returns:
            Mapping ``(dataset_name, score_name, metric_name) -> value`` in
            percent. OOD samples are the positive class and scores are negated
            before ranking. Datasets or scores that raise are reported via
            ``print`` and skipped.
        """
        self.eval()
        max_batches = self._max_batches(id_loader, num)

        named_loaders = (
            ood_loaders.items() if isinstance(ood_loaders, dict) else ood_loaders
        )
        ood_metrics = {}

        # Compute ID OOD scores
        id_scores_dict = self.ood_detect(islice(id_loader, max_batches))

        # Compute OOD detection metrics
        for dataset_name, loader in named_loaders:
            try:
                ood_scores_dict = self.ood_detect(islice(loader, max_batches))
            except Exception as e:
                print(e)
                continue

            for score_name, id_scores in id_scores_dict.items():
                try:
                    ood_scores = ood_scores_dict[score_name]

                    # Compare equally many ID and OOD samples.
                    length = min(len(ood_scores), len(id_scores))
                    ood = ood_scores[:length]
                    id_ = id_scores[:length]

                    if self.logger is not None and self.logger.log_dir is not None:
                        ax = plot_score_hist(
                            id_,
                            ood,
                            title="",
                        )
                        ax.figure.savefig(
                            path.join(
                                self.logger.log_dir, f"{dataset_name}_{score_name}.png"
                            )
                        )
                        plt.close()

                    preds = np.concatenate([ood, id_])
                    # The original first computed both metrics with ID as the
                    # positive class and immediately overwrote them; only the
                    # effective (OOD-positive) computation is kept.
                    labels = np.concatenate([np.ones_like(ood), np.zeros_like(id_)])
                    ood_metrics[(dataset_name, score_name, "AUROC")] = (
                        roc_auc_score(labels, -preds) * 100.0
                    )
                    ood_metrics[(dataset_name, score_name, "AUPR")] = (
                        average_precision_score(labels, -preds) * 100.0
                    )
                except Exception as e:
                    print(e)
        return ood_metrics

    def eval_classifier(self, loader, num=10_000):
        """Return accuracy and calibration metrics (in percent) on ``loader``.

        Args:
            loader: Loader yielding ``(x, y)`` batches.
            num: Maximum number of samples to evaluate; ``-1`` uses all.

        Returns:
            Dict of metrics, or ``{}`` when the model cannot classify or the
            metric computation fails.
        """
        self.eval()
        max_batches = self._max_batches(loader, num)

        try:
            y, probs = self.get_gt_preds(islice(loader, max_batches))
        except NotImplementedError:
            print("Model does not support classification.")
            return {}

        # Only truncate for a positive limit: ``y[:-1]`` (num == -1) silently
        # dropped the last sample.
        if num > 0:
            y, probs = y[:num], probs[:num]

        try:
            # Compute accuracy
            acc = accuracy(y, probs)

            # Compute calibration
            y_np, probs_np = to_np(y), to_np(probs)
            ece, mce = classification_calibration(y_np, probs_np)
            brier = brier_score(y_np, probs_np)
            uncertainty, resolution, reliability = brier_decomposition(y_np, probs_np)
        except Exception as e:
            print(e)
            return {}

        if self.logger is not None and self.logger.log_dir is not None:
            # Plotting is best-effort and must not discard the metrics. The
            # original used ``log_dir / "calibration.png"``, which raises for
            # a ``str`` log dir and made the bare ``except`` return ``{}``.
            try:
                fig, ax = plt.subplots(figsize=(10, 10))
                draw_reliability_graph(y_np, probs_np, 10, ax=ax)
                fig.savefig(path.join(self.logger.log_dir, "calibration.png"), dpi=200)
                plt.close(fig)
            except Exception as e:
                print(e)

        return {
            "Accuracy": acc * 100,
            "ECE": ece * 100.0,
            "MCE": mce * 100.0,
            "Brier": brier * 100,
            "Brier uncertainty": uncertainty * 100,
            "Brier resolution": resolution * 100,
            "Brier reliability": reliability * 100,
            "Brier (via decomposition)": (reliability - resolution + uncertainty) * 100,
        }

    def setup(self, mode):
        """Create validation loaders for ``ood_val_datasets`` when fitting."""
        if mode == "fit" and self.ood_val_datasets:
            batch_size = self.val_dataloader.dataloader.batch_size

            self.ood_val_loaders = []
            for ood_ds_name in self.ood_val_datasets:
                ood_loader = get_dataloader(
                    ood_ds_name,
                    "val",
                    batch_size=batch_size,
                    data_shape=self.data_shape,
                )
                self.ood_val_loaders.append((ood_ds_name, ood_loader))

    def validation_epoch_end(self, outputs):
        """Log a density plot for toy data and OOD metrics when loaders exist."""
        if self.is_toy_dataset:
            interp = torch.linspace(-4, 4, 500)
            x, y = torch.meshgrid(interp, interp)
            data = torch.stack((x.reshape(-1), y.reshape(-1)), 1).to(self.device)
            px = to_np(torch.exp(self(data)))

            fig, ax = plt.subplots()
            mesh = ax.pcolormesh(x, y, px.reshape(*x.shape))
            fig.colorbar(mesh)
            self.logger.experiment.add_figure("dist/p(x)", fig, self.current_epoch)
            plt.close()

        if hasattr(self, "ood_val_loaders"):
            ood_metrics = self.eval_ood(
                self.val_dataloader.dataloader, self.ood_val_loaders
            )
            self.logger.experiment.add_scalars(
                "val/all_ood",
                {", ".join(k): v for k, v in ood_metrics.items()},
                self.trainer.global_step,
            )

            avg_over_dataset_results = defaultdict(list)
            for k, v in ood_metrics.items():
                avg_over_dataset_results[", ".join(k[1:])].append(v)

            # "val/ood" is the dataset average of the first (score, metric)
            # pair. Skip it when nothing was computed; the original raised
            # StopIteration in that case.
            if avg_over_dataset_results:
                first_key = next(iter(avg_over_dataset_results))
                self.log("val/ood", np.mean(avg_over_dataset_results[first_key]))

    def optimizer_step(
        self,
        epoch: int = None,
        batch_idx: int = None,
        optimizer=None,
        optimizer_idx: int = None,
        optimizer_closure=None,
        on_tpu: bool = None,
        using_native_amp: bool = None,
        using_lbfgs: bool = None,
        **kwargs,
    ):
        """Step the optimizer, linearly warming up the learning rate first.

        Warm-up applies only when the model defines ``warmup_steps`` and the
        global step is still below it.
        """
        # learning rate warm-up
        if (
            optimizer is not None
            and hasattr(self, "warmup_steps")
            and self.trainer.global_step < self.warmup_steps
        ):
            lr_scale = min(
                1.0, float(self.trainer.global_step + 1) / float(self.warmup_steps)
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate

        optimizer.step(closure=optimizer_closure)

    def ood_detect(self, loader):
        """Return ``{score_name: numpy array}`` of OOD scores over ``loader``."""
        self.eval()

        scores = defaultdict(list)
        # Scoped instead of ``torch.set_grad_enabled(False)``, which switched
        # autograd off globally and never restored it.
        with torch.no_grad():
            for x, _ in tqdm(loader, miniters=100):
                if not isinstance(x, torch.Tensor):
                    x, _ = x
                x = x.to(self.device)
                for k, v in self.get_ood_scores(x).items():
                    # Some models already return numpy arrays for a score.
                    if isinstance(v, torch.Tensor):
                        v = to_np(v)
                    # 0-d arrays cannot be concatenated.
                    scores[k].append(np.atleast_1d(v))

        return {k: np.concatenate(v) for k, v in scores.items()}

    def get_gt_preds(self, loader):
        """Return ``(labels, predictions)`` concatenated over ``loader``."""
        self.eval()

        gt, preds = [], []
        with torch.no_grad():
            for x, y in tqdm(loader, miniters=100):
                x = x.to(self.device)
                y_hat = self.classify(x).cpu()
                gt.append(y)
                preds.append(y_hat)
        return torch.cat(gt), torch.cat(preds)

    def get_ood_scores(self, x) -> Dict[str, torch.Tensor]:
        """Return named OOD scores for ``x``; subclasses must override."""
        raise NotImplementedError

    def classify(self, x) -> torch.Tensor:
        """Return class predictions for ``x``; subclasses may override."""
        raise NotImplementedError
@hydra.main(config_path="../configs", config_name="config")
def run(cfg: DictConfig) -> None:
    """Resolve the ``fixed`` section of the Hydra config and start the run."""
    cfg = OmegaConf.to_container(cfg.fixed, resolve=True)
    run_ex(**cfg, _run=cfg)


def run_ex(
    trainer_config,
    model_name,
    model_config,
    dataset,
    batch_size,
    ood_dataset,
    earlystop_config,
    checkpoint_config,
    data_shape,
    seed,
    _run,
    num_classes=1,
    sigma=0.0,
    output_folder=None,
    log_dir=None,
    num_workers=4,
    test_ood_datasets=None,
    mutation_rate=0.0,
    num_cat=1,
    normalize=True,
    **kwargs,
):
    """Train a model, test it and return its OOD and classification metrics.

    Args:
        trainer_config: Keyword arguments for ``pl.Trainer``.
        model_name: Key into ``MODELS``.
        model_config: Keyword arguments for the model constructor.
        dataset: Name of the in-distribution dataset.
        batch_size: Batch size for all loaders.
        ood_dataset: OOD dataset passed to the train and validation loaders.
        earlystop_config: Keyword arguments for ``EarlyStopping`` or None.
        checkpoint_config: Keyword arguments for ``ModelCheckpoint``.
        data_shape: Forwarded to ``get_dataloader``.
        seed: Seed passed to ``pl.seed_everything``.
        _run: Configuration that is dumped to ``config.yaml`` in the log dir.
        num_classes: Must be positive.
        output_folder: Run name; defaults to a timestamp plus a UUID.
        log_dir: Log root; defaults to ``logs/<model_name>/<dataset>``.
        test_ood_datasets: Names of OOD datasets evaluated after training;
            ``None`` (the default) means no OOD evaluation datasets.

    Returns:
        Dict merging the OOD metrics and the classifier metrics.
    """
    # A literal ``[]`` default would be shared between calls.
    if test_ood_datasets is None:
        test_ood_datasets = []

    pl.seed_everything(seed)
    assert num_classes > 0

    model = MODELS[model_name](**model_config)

    train_loader = get_dataloader(
        dataset,
        "train",
        batch_size,
        data_shape=data_shape,
        ood_dataset=ood_dataset,
        sigma=sigma,
        num_workers=num_workers,
        mutation_rate=mutation_rate,
        normalize=normalize,
    )
    val_loader = get_dataloader(
        dataset,
        "val",
        batch_size,
        data_shape=data_shape,
        sigma=sigma,
        ood_dataset=ood_dataset,
        num_workers=num_workers,
        normalize=normalize,
    )
    test_loader = get_dataloader(
        dataset,
        "test",
        batch_size,
        data_shape=data_shape,
        sigma=sigma,
        ood_dataset=None,
        num_workers=num_workers,
        normalize=normalize,
    )

    if log_dir is None:
        out_path = Path("logs") / model_name / dataset
    else:
        out_path = Path(log_dir)

    if output_folder is None:
        output_folder = f'{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}_{uuid4()}'

    # Circumvent issue when starting multiple versions with the same name
    trys = 10
    for i in range(trys):
        try:
            logger = pl.loggers.TensorBoardLogger(
                out_path, name=output_folder, default_hp_metric=False
            )
            out_dir = Path(logger.log_dir)
            out_dir.mkdir(exist_ok=False, parents=True)
            break
        except FileExistsError as e:
            if i == (trys - 1):
                raise ValueError("Could not create log folder") from e
            print("Failed to create unique log folder. Trying again.")

    with (out_dir / "config.yaml").open("w") as f:
        f.write(yaml.dump(_run))

    callbacks = [pl.callbacks.ModelCheckpoint(dirpath=out_dir, **checkpoint_config)]
    if earlystop_config is not None:
        callbacks.append(pl.callbacks.EarlyStopping(**earlystop_config))

    trainer = pl.Trainer(
        **trainer_config,
        logger=logger,
        callbacks=callbacks,
        progress_bar_refresh_rate=100,
    )
    trainer.fit(model, train_loader, val_loader)

    try:
        trainer.test(test_dataloaders=test_loader)
    except Exception as e:
        # ``except Exception`` (not a bare except) so Ctrl-C still aborts.
        print(f"trainer.test failed ({e}); retrying with ckpt_path=None.")
        trainer.test(test_dataloaders=test_loader, ckpt_path=None)

    test_ood_dataloaders = []
    for test_ood_dataset in test_ood_datasets:
        loader = get_dataloader(
            test_ood_dataset,
            "test",
            batch_size,
            data_shape=data_shape,
            sigma=sigma,
            ood_dataset=None,
            num_workers=num_workers,
            normalize=normalize,
        )
        test_ood_dataloaders.append((test_ood_dataset, loader))
    ood_results = model.eval_ood(test_loader, test_ood_dataloaders)
    ood_results = {", ".join(k): v for k, v in ood_results.items()}

    clf_results = model.eval_classifier(test_loader)

    results = {**ood_results, **clf_results}

    logger.log_hyperparams(model.hparams, results)
    return results
def dirichlet_prior_network_uncertainty(logits, epsilon=1e-10, alpha_correction=True):
    """Compute uncertainty measures of a Dirichlet parameterised by logits.

    The concentration parameters are ``exp(logits)``, plus one when
    ``alpha_correction`` is set.

    :param logits: array-like of shape ``(N, K)`` - one row per sample, one
        column per class.
    :param epsilon: added inside the logarithm of the expected probabilities
        to avoid ``log(0)``.
    :param alpha_correction: whether to add one to every concentration.
    :return: dict of arrays, each of shape ``(N,)``, with the keys
        ``confidence`` (one minus the largest expected probability),
        ``entropy_of_expected``, ``expected_entropy``,
        ``mutual_information``, ``EPKL`` and ``differential_entropy``.
    """
    logits = np.asarray(logits, dtype=np.float64)
    alphas = np.exp(logits)
    if alpha_correction:
        alphas = alphas + 1
    alpha0 = np.sum(alphas, axis=1, keepdims=True)
    probs = alphas / alpha0

    conf = np.max(probs, axis=1)

    entropy_of_exp = -np.sum(probs * np.log(probs + epsilon), axis=1)
    expected_entropy = -np.sum(
        probs * (digamma(alphas + 1) - digamma(alpha0 + 1.0)), axis=1
    )
    mutual_info = entropy_of_exp - expected_entropy

    # Squeeze only the class axis: an axis-less squeeze also removed the batch
    # axis for N == 1 and produced 0-d arrays that cannot be concatenated.
    epkl = np.squeeze((alphas.shape[1] - 1.0) / alpha0, axis=1)

    dentropy = (
        np.sum(
            gammaln(alphas) - (alphas - 1.0) * (digamma(alphas) - digamma(alpha0)),
            axis=1,
            keepdims=True,
        )
        - gammaln(alpha0)
    )

    return {
        "confidence": 1 - conf,
        "entropy_of_expected": entropy_of_exp,
        "expected_entropy": expected_entropy,
        "mutual_information": mutual_info,
        "EPKL": epkl,
        "differential_entropy": np.squeeze(dentropy, axis=1),
    }


def accuracy(y: torch.Tensor, y_hat: torch.Tensor):
    """Return the fraction of rows of ``y_hat`` whose argmax equals ``y``."""
    return (y == y_hat.argmax(dim=1)).float().mean(0).item()
def to_np(tensor: torch.Tensor):
    """Return ``tensor`` as a numpy array, detached and moved to the CPU."""
    return tensor.detach().cpu().numpy()


def eval_func_on_grid(
    density_func,
    interval=(-10, 10),
    num_samples=200,
    device="cpu",
    dimensions=2,
    batch_size=10_000,
    dtype=torch.float32,
):
    """Evaluate ``density_func`` on a regular grid.

    The grid has ``num_samples`` points per dimension, spaced with
    ``torch.linspace`` over ``interval`` (both ends included).

    Returns:
        ``(grid_coords, vals)`` - the meshgrid coordinate tensors and the
        function values for the flattened grid, concatenated on the CPU.
    """
    interp = torch.linspace(*interval, num_samples)
    grid_coords = torch.meshgrid(*[interp for _ in range(dimensions)])
    grid = torch.stack([coords.reshape(-1) for coords in grid_coords], 1).to(dtype)

    # Evaluate in chunks of ``batch_size`` grid points.
    vals = [
        density_func(samples.to(device)).cpu()
        for samples in tqdm(torch.split(grid, batch_size))
    ]
    return grid_coords, torch.cat(vals)


def estimate_normalizing_constant(
    density_func,
    interval=(-10, 10),
    num_samples=200,
    device="cpu",
    dimensions=2,
    batch_size=10_000,
    dtype=torch.float32,
):
    """
    Numerically integrate a function in the specified interval.

    The function is evaluated on the grid built by ``eval_func_on_grid`` and
    integrated with the trapezoidal rule, one dimension after another.

    Returns:
        0-d tensor holding the estimated integral.
    """
    with torch.no_grad():
        _, p_x = eval_func_on_grid(
            density_func, interval, num_samples, device, dimensions, batch_size, dtype
        )

    # Spacing of ``torch.linspace(a, b, n)`` is (b - a) / (n - 1). The former
    # (|a| + |b|) / n was wrong for intervals not straddling zero and always
    # underestimated the step. A single sample spans no width at all.
    lower, upper = interval
    dx = (upper - lower) / (num_samples - 1) if num_samples > 1 else 0.0

    # Integrate one dimension after another
    grid_vals = p_x.detach().reshape(*[num_samples for _ in range(dimensions)])
    for _ in range(dimensions):
        grid_vals = torch.trapz(grid_vals, dx=dx, dim=-1)

    return grid_vals


def sum_except_batch(x, num_batch_dims=1):
    """Sums all elements of `x` except for the first `num_batch_dims` dimensions."""
    if x.ndimension() == 1:
        return x
    reduce_dims = list(range(num_batch_dims, x.ndimension()))
    return torch.sum(x, dim=reduce_dims)


def split_leading_dim(x, shape):
    """Reshapes the leading dim of `x` to have the given shape."""
    new_shape = torch.Size(shape) + x.shape[1:]
    return torch.reshape(x, new_shape)


if __name__ == "__main__":
    dims = 2
    samples = 100
    print(
        estimate_normalizing_constant(
            lambda x: torch.empty(x.shape[0]).fill_(1 / (samples ** dims)),
            num_samples=samples,
            dimensions=dims,
        )
    )