├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── configs
│   ├── config.yaml
│   ├── dataset
│   │   ├── cifar10.yaml
│   │   ├── cifar10_embedded.yaml
│   │   ├── fashionmnist.yaml
│   │   ├── fashionmnist_embedded.yaml
│   │   ├── gaussian2d.yaml
│   │   ├── genomics.yaml
│   │   ├── genomics_embedded.yaml
│   │   ├── segment.yaml
│   │   └── sensorless.yaml
│   ├── grid
│   │   ├── seed-bottleneck.yaml
│   │   ├── seed-clf_weight.yaml
│   │   ├── seed-pyxce.yaml
│   │   └── seed.yaml
│   ├── model
│   │   ├── fc_ce_baseline.yaml
│   │   ├── fc_mcmc.yaml
│   │   ├── fc_mcmc_supervised.yaml
│   │   ├── fc_norm_flow.yaml
│   │   ├── fc_ssm.yaml
│   │   ├── fc_ssm_supervised.yaml
│   │   ├── fc_vera.yaml
│   │   ├── fc_vera_supervised.yaml
│   │   ├── genomics_autoregressive_density.yaml
│   │   ├── genomics_ce_baseline.yaml
│   │   ├── genomics_mcmc.yaml
│   │   ├── genomics_mcmc_supervised.yaml
│   │   ├── genomics_vera.yaml
│   │   ├── genomics_vera_supervised.yaml
│   │   ├── img_autoencoder.yaml
│   │   ├── img_ce_baseline.yaml
│   │   ├── img_glow.yaml
│   │   ├── img_mcmc.yaml
│   │   ├── img_mcmc_supervised.yaml
│   │   ├── img_real_nvp.yaml
│   │   ├── img_ssm.yaml
│   │   ├── img_ssm_supervised.yaml
│   │   ├── img_vera.yaml
│   │   ├── img_vera_posteriornet.yaml
│   │   ├── img_vera_priornet.yaml
│   │   └── img_vera_supervised.yaml
│   └── seml_config.yaml
├── density_histograms.png
├── req.txt
└── uncertainty_est
    ├── __init__.py
    ├── archs
    │   ├── __init__.py
    │   ├── arch_factory.py
    │   ├── fc.py
    │   ├── flows.py
    │   ├── glow
    │   │   ├── __init__.py
    │   │   ├── act_norm.py
    │   │   ├── coupling.py
    │   │   ├── glow.py
    │   │   └── inv_conv.py
    │   ├── real_nvp
    │   │   ├── __init__.py
    │   │   ├── coupling_layer.py
    │   │   ├── real_nvp.py
    │   │   ├── resnet.py
    │   │   └── util.py
    │   ├── resnet.py
    │   └── wrn.py
    ├── data
    │   ├── __init__.py
    │   ├── dataloaders.py
    │   └── datasets.py
    ├── evaluate.py
    ├── models
    │   ├── __init__.py
    │   ├── ce_baseline.py
    │   ├── ebm
    │   │   ├── __init__.py
    │   │   ├── conditional_nce.py
    │   │   ├── discrete_mcmc.py
    │   │   ├── flow_contrastive_estimation.py
    │   │   ├── hdge.py
    │   │   ├── mcmc.py
    │   │   ├── nce.py
    │   │   ├── ssm.py
    │   │   ├── utils
    │   │   │   ├── __init__.py
    │   │   │   ├── model.py
    │   │   │   ├── utils.py
    │   │   │   └── vera_utils.py
    │   │   ├── vera.py
    │   │   ├── vera_posteriornet.py
    │   │   └── vera_priornet.py
    │   ├── energy_finetuning.py
    │   ├── normalizing_flow
    │   │   ├── __init__.py
    │   │   ├── approx_flow.py
    │   │   ├── image_flows.py
    │   │   ├── iresnet.py
    │   │   └── norm_flow.py
    │   └── ood_detection_model.py
    ├── train.py
    └── utils
        ├── dirichlet.py
        ├── metrics.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__
3 |
4 | logs
5 | grid_search
6 | slurm
7 | /_cache_datalsuntestlmdb
8 | notebooks
9 | tmp
10 | temp
11 | thesis_logs
12 | ebm_investigation_logs
13 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | - repo: https://github.com/ambv/black
9 | rev: 20.8b1
10 | hooks:
11 | - id: black
12 | language_version: python3.6
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On Out-of-distribution Detection with Energy-based Models
2 |
3 | This repository contains the code for the experiments conducted in the paper
4 |
5 | > [On Out-of-distribution Detection with Energy-based Models](https://arxiv.org/abs/2107.08785) \
6 | Sven Elflein, Bertrand Charpentier, Daniel Zügner, Stephan Günnemann \
7 | ICML 2021, Workshop on Uncertainty & Robustness in Deep Learning.
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Setup
16 |
17 | ```
18 | conda create --name env --file req.txt
19 | conda activate env
20 | pip install git+https://github.com/selflein/nn_uncertainty_eval
21 | ```
22 |
23 | ### Datasets
24 | The image datasets should download automatically. For "Sensorless Drive" and "Segment", pre-processed .csv files can be downloaded from the [PostNet repo](https://github.com/sharpenb/Posterior-Network#training--evaluation) under "Training & Evaluation".
25 |
26 | ## Training & Evaluation
27 |
28 | To train a model, pass the desired combination of dataset and model configurations, e.g.,
29 |
30 | ```
31 | python uncertainty_est/train.py fixed.output_folder=./path/to/output/folder dataset=sensorless model=fc_mcmc
32 | ```
33 |
34 | to train an EBM with MCMC on the Sensorless dataset. See `configs/model` for all model configurations.
35 |
36 | To evaluate models, use
37 |
38 | ```
39 | python uncertainty_est/evaluate.py --checkpoint-dir ./path/to/directory/with/models --output-folder ./path/to/output/folder
40 | ```
41 |
42 | This script generates CSVs with the respective OOD metrics.
43 |
44 | ## Cite
45 |
46 | If you find our work helpful, please consider citing our paper in your own work.
47 |
48 | ```
49 | @misc{elflein2021outofdistribution,
50 | title={On Out-of-distribution Detection with Energy-based Models},
51 | author={Sven Elflein and Bertrand Charpentier and Daniel Zügner and Stephan Günnemann},
52 | year={2021},
53 | eprint={2107.08785},
54 | archivePrefix={arXiv},
55 | primaryClass={cs.LG}
56 | }
57 | ```
58 |
59 | ## Acknowledgements
60 |
61 | * RealNVP from https://github.com/chrischute/real-nvp
62 | * Glow from https://github.com/chrischute/glow
63 | * JEM from https://github.com/wgrathwohl/JEM
64 | * VERA from https://github.com/wgrathwohl/VERA
65 | * SSM from https://github.com/ermongroup/sliced_score_matching
66 | * WideResNet from https://github.com/meliketoy/wide-resnet.pytorch
67 |
--------------------------------------------------------------------------------
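The training command in the README pairs one file from `configs/dataset` with one from `configs/model`. The sketch below composes that command and adds a small sanity check. The helper `build_train_command` and the `DATA_SHAPES` table are hypothetical and not part of the repository; the shapes are copied from the dataset configs. The pairing rule is an assumption inferred from the configs: `img_*` models read `${fixed.data_shape.2}` and `fc_*` models read `${fixed.data_shape.0}`.

```python
import shlex

# data_shape values copied from configs/dataset/*.yaml (subset only).
DATA_SHAPES = {
    "sensorless": [48],
    "segment": [18],
    "cifar10": [32, 32, 3],
    "fashionmnist": [32, 32, 1],
}


def build_train_command(dataset, model, output_folder):
    """Hypothetical helper: compose the README training command as an argv list."""
    shape = DATA_SHAPES[dataset]
    # Assumption: img_* configs index data_shape.2, so they need [H, W, C].
    if model.startswith("img_") and len(shape) != 3:
        raise ValueError(f"{model} expects a 3-entry data_shape, got {shape}")
    # Assumption: fc_* configs use data_shape.0 as the flat input dimension.
    if model.startswith("fc_") and len(shape) != 1:
        raise ValueError(f"{model} expects a 1-entry data_shape, got {shape}")
    return [
        "python",
        "uncertainty_est/train.py",
        f"fixed.output_folder={output_folder}",
        f"dataset={dataset}",
        f"model={model}",
    ]


cmd = build_train_command("sensorless", "fc_mcmc", "./path/to/output/folder")
print(shlex.join(cmd))
```

The list form can be passed to `subprocess.run` directly; `shlex.join` is only used here to print a shell-ready string.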
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: .
4 | output_subdir:
5 |
6 | seml:
7 | executable: uncertainty_est/train.py
8 | output_dir: slurm
9 | project_root_dir: ..
10 |
11 | slurm:
12 | experiments_per_job: 1
13 | sbatch_options:
14 | gres: gpu:1 # num GPUs
15 | mem: 16G # memory
16 | cpus-per-task: 4 # num cores
17 | time: 0-05:00 # max time, D-HH:MM
18 |
19 | fixed:
20 | trainer_config:
21 | gpus: 1
22 | benchmark: True
23 | log_dir: .
24 | output_folder:
25 | ood_dataset:
26 | seed: 1
27 |
28 | model_config:
29 | data_shape: ${fixed.data_shape}
30 |
31 | defaults:
32 | - hydra/job_logging: stdout
33 | - model: fc_mcmc_supervised
34 | - dataset: segment
35 | - model/updates: ${defaults.0.model}_${defaults.1.dataset}
36 | optional: True
37 | - grid: seed
38 |
--------------------------------------------------------------------------------
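Many values in these configs are references such as `${fixed.data_shape}` or `${fixed.data_shape.0}`, which Hydra/OmegaConf resolves against the composed configuration. The following standard-library sketch is only a stand-in to illustrate the idea; it is not how Hydra implements interpolation, and the functions `lookup` and `resolve` are hypothetical. The example dictionary assumes the nesting implied by the `fixed` package and uses values from `sensorless.yaml` and `fc_mcmc_supervised.yaml`.

```python
import re

# Matches one ${dotted.path} reference.
_REF = re.compile(r"\$\{([^${}]+)\}")


def lookup(cfg, dotted):
    """Follow a dotted path; numeric parts index into lists."""
    node = cfg
    for part in dotted.split("."):
        node = node[int(part)] if isinstance(node, list) else node[part]
    return node


def resolve(value, root):
    """Recursively replace ${...} references (no cycle detection in this sketch)."""
    if isinstance(value, dict):
        return {k: resolve(v, root) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve(v, root) for v in value]
    if isinstance(value, str):
        whole = _REF.fullmatch(value)
        if whole:
            # The value is a single reference: keep the referenced type.
            return resolve(lookup(root, whole.group(1)), root)
        # Otherwise substitute each reference into the string.
        return _REF.sub(lambda m: str(resolve(lookup(root, m.group(1)), root)), value)
    return value


cfg = {
    "fixed": {
        "data_shape": [48],
        "num_classes": 9,
        "model_config": {
            "data_shape": "${fixed.data_shape}",
            "arch_config": {
                "inp_dim": "${fixed.data_shape.0}",
                "num_classes": "${fixed.num_classes}",
            },
        },
    }
}

resolved = resolve(cfg, cfg)
print(resolved["fixed"]["model_config"])
```

This is why a model config can stay dataset-agnostic: the input dimension and class count are filled in from whichever dataset config is selected.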
/configs/dataset/cifar10.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 1
5 | sbatch_options:
6 | time: 0-30:00
7 |
8 | fixed:
9 | trainer_config:
10 | max_epochs: 50
11 |
12 | checkpoint_config:
13 | monitor: val/ood
14 | mode: max
15 |
16 | dataset: cifar10
17 | num_classes: 10
18 | batch_size: 32
19 | data_shape: [32, 32, 3]
20 |
21 | test_ood_datasets:
22 | - lsun
23 | - textures
24 | - cifar100
25 | - svhn
26 | - celeb-a
27 | - uniform_noise
28 | - gaussian_noise
29 | - constant
30 | - svhn_unscaled
31 |
32 | model_config:
33 | ood_val_datasets:
34 | - celeb-a
35 | - cifar100
36 |
--------------------------------------------------------------------------------
/configs/dataset/cifar10_embedded.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 5
5 | sbatch_options:
6 | time: 0-05:00
7 |
8 | fixed:
9 | dataset: cifar10_embedded
10 | num_classes: 10
11 | batch_size: 512
12 | data_shape: [640]
13 |
14 | checkpoint_config:
15 | monitor: val/ood
16 | mode: max
17 |
18 | test_ood_datasets:
19 | - lsun_embedded
20 | - textures_embedded
21 | - cifar100_embedded
22 | - svhn_embedded
23 | - celeb-a_embedded
24 | - cifar10_uniform_noise_embedded
25 | - cifar10_gaussian_noise_embedded
26 | - cifar10_constant_embedded
27 | - svhn_unscaled_embedded
28 |
29 | model_config:
30 | ood_val_datasets:
31 | - celeb-a_embedded
32 | - cifar100_embedded
33 |
--------------------------------------------------------------------------------
/configs/dataset/fashionmnist.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 1
5 | sbatch_options:
6 | time: 0-10:00
7 |
8 | fixed:
9 | dataset: fashionmnist
10 | num_classes: 10
11 | batch_size: 64
12 | data_shape: [32, 32, 1]
13 |
14 | trainer_config:
15 | max_epochs: 50
16 |
17 | checkpoint_config:
18 | monitor: val/ood
19 | mode: max
20 | save_last: True
21 |
22 | test_ood_datasets:
23 | - mnist
24 | - notmnist
25 | - kmnist
26 | - gaussian_noise
27 | - uniform_noise
28 | - constant
29 | - kmnist_unscaled
30 |
31 | model_config:
32 | ood_val_datasets:
33 | - mnist
34 | - kmnist
35 |
--------------------------------------------------------------------------------
/configs/dataset/fashionmnist_embedded.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 5
5 |
6 | fixed:
7 | dataset: fashionmnist_embedded
8 | num_classes: 10
9 | batch_size: 128
10 | data_shape: [640]
11 |
12 | checkpoint_config:
13 | monitor: val/ood
14 | mode: max
15 |
16 | test_ood_datasets:
17 | - mnist_embedded
18 | - notmnist_embedded
19 | - kmnist_embedded
20 | - fashionmnist_gaussian_noise_embedded
21 | - fashionmnist_uniform_noise_embedded
22 | - fashionmnist_constant_embedded
23 |
24 | model_config:
25 | ood_val_datasets:
26 | - mnist_embedded
27 | - kmnist_embedded
28 |
--------------------------------------------------------------------------------
/configs/dataset/gaussian2d.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 5
5 | sbatch_options:
6 | time: 0-05:00
7 |
8 | fixed:
9 | dataset: Gaussian2D
10 | num_classes: 3
11 | batch_size: 512
12 | data_shape:
13 | - 2
14 |
15 | model_config:
16 | is_toy_dataset: True
17 |
18 | test_ood_datasets:
19 | - AnomalousGaussian2D
20 |
--------------------------------------------------------------------------------
/configs/dataset/genomics.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 1
5 | sbatch_options:
6 | time: 0-30:00
7 | mem: 32G
8 |
9 | fixed:
10 | dataset: genomics
11 | data_shape: [250]
12 | num_cat: 4
13 | num_classes: 10
14 |
15 | test_ood_datasets:
16 | - genomics-ood
17 | - genomics-noise
18 |
19 | model_config:
20 | ood_val_datasets:
21 | - genomics-ood
22 |
23 | checkpoint_config:
24 | monitor: val/ood
25 | mode: max
26 | save_last: True
27 |
28 | trainer_config:
29 | max_steps: 100_000
30 | max_epochs: 100_000
31 | limit_val_batches: 1000
32 | val_check_interval: 10_000
33 | limit_test_batches: 10_000
34 |
--------------------------------------------------------------------------------
/configs/dataset/genomics_embedded.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 5
5 | sbatch_options:
6 | time: 0-10:00
7 | mem: 32G
8 |
9 | fixed:
10 | dataset: genomics_embedded
11 | batch_size: 512
12 | data_shape:
13 | - 128
14 | num_classes: 10
15 |
16 | checkpoint_config:
17 | monitor: val/ood
18 | mode: max
19 |
20 | test_ood_datasets:
21 | - genomics-ood_embedded
22 | - genomics-noise_embedded
23 |
24 | trainer_config:
25 | max_epochs: 50_000
26 | max_steps: 50_000
27 |
28 | model_config:
29 | ood_val_datasets:
30 | - genomics-ood_embedded
31 |
--------------------------------------------------------------------------------
/configs/dataset/segment.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 5
5 | sbatch_options:
6 | time: 0-05:00
7 |
8 | fixed:
9 | dataset: segment
10 | num_classes: 6
11 | batch_size: 512
12 | data_shape:
13 | - 18
14 |
15 | checkpoint_config:
16 | monitor: val/ood
17 | mode: max
18 |
19 | test_ood_datasets:
20 | - segment-ood
21 | - uniform_noise
22 | - gaussian_noise
23 | - constant
24 |
25 | trainer_config:
26 | max_epochs: 10_000
27 | max_steps: 10_000
28 | check_val_every_n_epoch: 10
29 |
30 | model_config:
31 | ood_val_datasets:
32 | - segment-ood
33 |
--------------------------------------------------------------------------------
/configs/dataset/sensorless.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | slurm:
4 | experiments_per_job: 5
5 | sbatch_options:
6 | time: 0-05:00
7 |
8 | fixed:
9 | dataset: sensorless
10 | num_classes: 9
11 | batch_size: 512
12 | data_shape: [48]
13 |
14 | test_ood_datasets:
15 | - sensorless-ood
16 | - uniform_noise
17 | - gaussian_noise
18 | - constant
19 |
20 | checkpoint_config:
21 | monitor: val/ood
22 | mode: max
23 | save_last: True
24 |
25 | trainer_config:
26 | max_epochs: 10_000
27 | max_steps: 10_000
28 | check_val_every_n_epoch: 10
29 |
30 | model_config:
31 | ood_val_datasets:
32 | - sensorless-ood
33 |
--------------------------------------------------------------------------------
/configs/grid/seed-bottleneck.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | seed:
4 | type: choice
5 | options:
6 | - 2525
7 | - 2123
8 | - 293993
9 | - 2324234
10 | - 6566634
11 |
12 | model_config.arch_config.bottleneck_channels_factor:
13 | type: choice
14 | options:
15 | - 0.05
16 | - 0.1
17 | - 0.2
18 |
--------------------------------------------------------------------------------
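The grid file above lists two `choice` parameters. Assuming the grid is expanded as a Cartesian product, which is how grid searches are usually run (the exact behaviour is defined by SEML, not by this sketch), the file would yield 5 seeds × 3 bottleneck factors. The helper `expand_grid` is hypothetical.

```python
from itertools import product

# Options copied from configs/grid/seed-bottleneck.yaml.
grid = {
    "seed": [2525, 2123, 293993, 2324234, 6566634],
    "model_config.arch_config.bottleneck_channels_factor": [0.05, 0.1, 0.2],
}


def expand_grid(grid):
    """Hypothetical helper: one dict per combination of the listed options."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


runs = expand_grid(grid)
print(len(runs))  # 15 combinations under the Cartesian-product assumption
print(runs[0])
```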
/configs/grid/seed-clf_weight.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | seed:
4 | type: choice
5 | options:
6 | - 2525
7 | - 2123
8 | - 293993
9 | - 2324234
10 | - 6566634
11 |
12 | model_config.clf_weight:
13 | type: loguniform
14 | min: 0.01
15 | max: 100
16 | num: 5
17 |
--------------------------------------------------------------------------------
/configs/grid/seed-pyxce.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | seed:
4 | type: choice
5 | options:
6 | - 2525
7 | - 2123
8 | - 293993
9 | - 2324234
10 | - 6566634
11 |
12 | model_config.pyxce:
13 | type: loguniform
14 | min: 0.01
15 | max: 100
16 | num: 5
17 |
--------------------------------------------------------------------------------
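The `loguniform` entries in `seed-clf_weight.yaml` and `seed-pyxce.yaml` specify `min`, `max`, and `num`. A plausible reading, stated here as an assumption rather than as SEML's documented behaviour, is `num` values evenly spaced in log10 between the two endpoints. The function `log_spaced` below is a hypothetical illustration of that reading.

```python
import math


def log_spaced(min_value, max_value, num):
    """Hypothetical helper: num values evenly spaced in log10, endpoints included."""
    if num == 1:
        return [min_value]
    lo, hi = math.log10(min_value), math.log10(max_value)
    step = (hi - lo) / (num - 1)
    return [10 ** (lo + i * step) for i in range(num)]


# min/max/num copied from the grid files above.
values = log_spaced(0.01, 100, 5)
print(values)  # approximately 0.01, 0.1, 1, 10, 100
```

Under this reading, each weight is swept one decade at a time, from 0.01 to 100.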
/configs/grid/seed.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | seed:
4 | type: choice
5 | options:
6 | - 2525
7 | - 2123
8 | - 293993
9 | - 2324234
10 | - 6566634
11 |
--------------------------------------------------------------------------------
/configs/model/fc_ce_baseline.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 10_000
5 | max_steps: 10_000
6 | check_val_every_n_epoch: 100
7 |
8 | checkpoint_config:
9 | monitor: val/acc
10 | mode: max
11 |
12 | earlystop_config:
13 | monitor: val/acc
14 | mode: max
15 | patience: 10
16 |
17 | ood_dataset:
18 |
19 | model_name: CEBaseline
20 | model_config:
21 | arch_name: fc
22 | arch_config:
23 | inp_dim: ${fixed.data_shape.0}
24 | num_classes: ${fixed.num_classes}
25 | hidden_dims: [100, 100, 100, 100, 100]
26 | learning_rate: 0.001
27 | momentum: 0.9
28 | weight_decay: 0.0005
29 |
--------------------------------------------------------------------------------
/configs/model/fc_mcmc.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 10_000
5 | max_steps: 10_000
6 |
7 | checkpoint_config:
8 | save_last: True
9 |
10 | earlystop_config:
11 |
12 | ood_dataset: ${fixed.dataset}
13 |
14 | model_name: JEM
15 | model_config:
16 | arch_name: fc
17 | arch_config:
18 | inp_dim: ${fixed.data_shape.0}
19 | num_classes: 1
20 | hidden_dims: [100, 100, 100, 100, 100]
21 |
22 | learning_rate: 0.001
23 | momentum: 0.9
24 | weight_decay: 0.0
25 | smoothing: 0.0
26 | sgld_lr: 1.0
27 | sgld_std: 0.01
28 | sgld_steps: 100
29 | pyxce: 0.0
30 | pxsgld: 1.0
31 | pxysgld: 0.0
32 | buffer_size: 9000
33 | reinit_freq: 0.05
34 | data_shape: ${fixed.data_shape}
35 | sgld_batch_size: ${fixed.batch_size}
36 | class_cond_p_x_sample: False
37 | n_classes: 1
38 | warmup_steps: 2500
39 | entropy_reg_weight: 0.0
40 | lr_step_size: 1000
41 |
--------------------------------------------------------------------------------
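The keys `sgld_lr`, `sgld_std`, `sgld_steps`, `buffer_size`, and `reinit_freq` in the config above parameterise SGLD sampling with a replay buffer, as in JEM-style training. The sketch below shows the general technique on a 1-D toy density using only the standard library. It is not the repository's implementation: the functions `sgld_sample` and `init_chain`, the toy gradient, and the uniform re-initialisation range are all assumptions made for illustration.

```python
import random


def sgld_sample(grad_logp, x_init, rng, sgld_lr=1.0, sgld_std=0.01, sgld_steps=100):
    """Run a gradient-ascent chain on log p(x) with added Gaussian noise."""
    x = x_init
    for _ in range(sgld_steps):
        # Step size and noise scale are set independently, mirroring the two config keys.
        x = x + sgld_lr * grad_logp(x) + sgld_std * rng.gauss(0.0, 1.0)
    return x


def init_chain(buffer, rng, reinit_freq=0.05):
    """Start from a buffered sample, or from fresh noise with probability reinit_freq."""
    if not buffer or rng.random() < reinit_freq:
        return rng.uniform(-1.0, 1.0)  # assumed noise range for this toy example
    return rng.choice(buffer)


def toy_grad(x):
    # Toy target: log p(x) = -0.25 * (x - 2)^2, so the gradient is -0.5 * (x - 2).
    return -0.5 * (x - 2.0)


rng = random.Random(0)
buffer_size = 9000  # value from the config; the toy loop never reaches it
buffer = []
for _ in range(20):
    x0 = init_chain(buffer, rng)
    buffer.append(sgld_sample(toy_grad, x0, rng))
    if len(buffer) > buffer_size:
        buffer.pop(0)  # simplification: drop the oldest sample

print(min(buffer), max(buffer))
```

In this toy setting every chain ends near the mode at 2. With a learned energy function the gradient would come from automatic differentiation of the network output with respect to the input.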
/configs/model/fc_mcmc_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 10_000
5 | max_steps: 10_000
6 |
7 | earlystop_config:
8 |
9 | ood_dataset: ${fixed.dataset}
10 |
11 | model_name: JEM
12 | model_config:
13 | arch_name: fc
14 | arch_config:
15 | inp_dim: ${fixed.data_shape.0}
16 | num_classes: ${fixed.num_classes}
17 | hidden_dims: [100, 100, 100, 100, 100]
18 |
19 | learning_rate: 0.001
20 | momentum: 0.9
21 | weight_decay: 0.0
22 | smoothing: 0.0
23 | sgld_lr: 1.0
24 | sgld_std: 0.01
25 | sgld_steps: 100
26 | pyxce: 1.0
27 | pxsgld: 1.0
28 | pxysgld: 0.0
29 | buffer_size: 9000
30 | reinit_freq: 0.05
31 | data_shape: ${fixed.data_shape}
32 | sgld_batch_size: ${fixed.batch_size}
33 | class_cond_p_x_sample: True
34 | n_classes: ${fixed.num_classes}
35 | warmup_steps: 2500
36 | entropy_reg_weight: 0.0
37 | lr_step_size: 1000
38 |
--------------------------------------------------------------------------------
/configs/model/fc_norm_flow.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | checkpoint_config:
4 | monitor: val/loss
5 | mode: min
6 |
7 | earlystop_config:
8 | monitor: val/loss
9 | mode: min
10 | patience: 10
11 |
12 | ood_dataset:
13 |
14 | model_name: NormalizingFlow
15 | model_config:
16 | arch_name: normalizing_flow
17 | arch_config:
18 | flow_type: radial_flow
19 | dim: ${fixed.data_shape.0}
20 | flow_length: 20
21 | learning_rate: 0.001
22 | momentum: 0.9
23 | weight_decay: 0.0
24 |
--------------------------------------------------------------------------------
/configs/model/fc_ssm.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 10_000
5 | max_steps: 10_000
6 |
7 | checkpoint_config:
8 | save_last: True
9 |
10 | earlystop_config:
11 |
12 | ood_dataset: ${fixed.dataset}
13 |
14 | model_name: SSM
15 | model_config:
16 | arch_name: fc
17 | arch_config:
18 | inp_dim: ${fixed.data_shape.0}
19 | num_classes: 1
20 | hidden_dims: [100, 100, 100, 100, 100]
21 |
22 | learning_rate: 0.001
23 | momentum: 0.9
24 | weight_decay: 0.0
25 | clf_weight: 0.0
26 | n_classes: 1
27 | n_particles: 1
28 | warmup_steps: 2500
29 | lr_step_size: 1000
30 |
--------------------------------------------------------------------------------
/configs/model/fc_ssm_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 10_000
5 | max_steps: 10_000
6 |
7 | earlystop_config:
8 |
9 | ood_dataset: ${fixed.dataset}
10 |
11 | model_name: SSM
12 | model_config:
13 | arch_name: fc
14 | arch_config:
15 | inp_dim: ${fixed.data_shape.0}
16 | num_classes: ${fixed.num_classes}
17 | hidden_dims: [100, 100, 100, 100, 100]
18 |
19 | learning_rate: 0.001
20 | momentum: 0.9
21 | weight_decay: 0.0
22 | clf_weight: 1.0
23 | n_classes: ${fixed.num_classes}
24 | n_particles: 1
25 | warmup_steps: 2500
26 | lr_step_size: 1000
27 |
--------------------------------------------------------------------------------
/configs/model/fc_vera.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_steps: 10_000
5 | max_epochs: 10_000
6 |
7 | checkpoint_config:
8 | save_last: True
9 |
10 | earlystop_config:
11 |
12 | ood_dataset: ${fixed.dataset}
13 |
14 | model_name: VERA
15 | model_config:
16 | arch_name: fc
17 | arch_config:
18 | inp_dim: ${fixed.data_shape.0}
19 | num_classes: 1
20 | hidden_dims: [100, 100, 100, 100, 100]
21 | batch_norm: False
22 | bias: True
23 | slope: 0.2
24 |
25 | learning_rate: 0.00003
26 | beta1: 0.0
27 | beta2: 0.9
28 | weight_decay: 0.0
29 | n_classes: 1
30 | uncond: False
31 | gen_learning_rate: 0.00006
32 | ebm_iters: 1
33 | generator_iters: 1
34 | entropy_weight: 0.0001
35 |
36 | generator_type: vera
37 | generator_arch_name: fc
38 | generator_config:
39 | noise_dim: 16
40 | post_lr: 0.00003
41 | init_post_logsigma: 0.1
42 | generator_arch_config:
43 | inp_dim: ${fixed.model_config.generator_config.noise_dim}
44 | num_classes: ${fixed.data_shape.0}
45 | hidden_dims: ${fixed.model_config.arch_config.hidden_dims}
46 | batch_norm: True
47 | bias: False
48 |
49 | min_sigma: 0.01
50 | max_sigma: 0.3
51 | p_control: 0.0
52 | n_control: 0.0
53 | pg_control: 0.1
54 | clf_ent_weight: 0.0
55 | ebm_type: p_x
56 | clf_weight: 0.0
57 | warmup_steps: 2500
58 | no_g_batch_norm: False
59 | batch_size: ${fixed.batch_size}
60 | lr_decay: 0.3
61 | lr_decay_epochs: [3000, 4000]
62 |
--------------------------------------------------------------------------------
/configs/model/fc_vera_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_steps: 10_000
5 | max_epochs: 10_000
6 |
7 | earlystop_config:
8 |
9 | ood_dataset: ${fixed.dataset}
10 |
11 | model_name: VERA
12 | model_config:
13 | arch_name: fc
14 | arch_config:
15 | inp_dim: ${fixed.data_shape.0}
16 | num_classes: ${fixed.num_classes}
17 | hidden_dims: [100, 100, 100, 100, 100]
18 | batch_norm: False
19 | slope: 0.2
20 |
21 | learning_rate: 0.001
22 | beta1: 0.0
23 | beta2: 0.9
24 | weight_decay: 0.0
25 | n_classes: ${fixed.num_classes}
26 | uncond: False
27 | gen_learning_rate: 0.001
28 | ebm_iters: 1
29 | generator_iters: 1
30 | entropy_weight: 0.0001
31 |
32 | generator_type: vera
33 | generator_arch_name: fc
34 | generator_config:
35 | noise_dim: 16
36 | post_lr: 0.00003
37 | init_post_logsigma: 0.1
38 | generator_arch_config:
39 | inp_dim: ${fixed.model_config.generator_config.noise_dim}
40 | num_classes: ${fixed.data_shape.0}
41 | hidden_dims: ${fixed.model_config.arch_config.hidden_dims}
42 | activation: relu
43 | batch_norm: True
44 | bias: False
45 |
46 | min_sigma: 0.01
47 | max_sigma: 0.3
48 | p_control: 0.0
49 | n_control: 0.0
50 | pg_control: 0.1
51 | clf_ent_weight: 0.0
52 | ebm_type: jem
53 | clf_weight: 1.0
54 | warmup_steps: 2500
55 | no_g_batch_norm: False
56 | batch_size: ${fixed.batch_size}
57 | lr_decay: 0.3
58 | lr_decay_epochs: [3000, 4000]
59 |
--------------------------------------------------------------------------------
/configs/model/genomics_autoregressive_density.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | checkpoint_config:
4 | monitor: val/loss
5 | mode: min
6 |
7 | earlystop_config:
8 | monitor: val/loss
9 | mode: min
10 | patience: 10
11 |
12 | batch_size: 512
13 |
14 | model_name: NormalizingFlow
15 | model_config:
16 | arch_name: seq_generative_model
17 | arch_config:
18 | input_size: ${fixed.num_cat}
19 | hidden_size: 128
20 | num_classes: ${fixed.num_classes}
21 | learning_rate: 0.005
22 | momentum: 0.9
23 | weight_decay: 0.0005
24 |
--------------------------------------------------------------------------------
/configs/model/genomics_ce_baseline.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | batch_size: 512
4 |
5 | earlystop_config:
6 | monitor: val/acc
7 | mode: max
8 | patience: 10
9 |
10 | model_name: CEBaseline
11 | model_config:
12 | arch_name: seq_classifier
13 | arch_config:
14 | in_channels: ${fixed.num_cat}
15 | num_filters: 128
16 | fc_hidden_size: 1000
17 | num_classes: 10
18 | kernel_size: 20
19 | learning_rate: 0.0001
20 | momentum: 0.9
21 | weight_decay: 0.0005
22 |
--------------------------------------------------------------------------------
/configs/model/genomics_mcmc.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | batch_size: 512
4 |
5 | checkpoint_config:
6 |
7 | earlystop_config:
8 |
9 | ood_dataset: ${fixed.dataset}
10 |
11 | model_name: DiscreteMCMC
12 | model_config:
13 | num_cat: ${fixed.num_cat}
14 |
15 | arch_name: seq_classifier
16 | arch_config:
17 | in_channels: ${fixed.num_cat}
18 | num_filters: 128
19 | fc_hidden_size: 1000
20 | num_classes: 1
21 | kernel_size: 20
22 |
23 | learning_rate: 0.0001
24 | momentum: 0.9
25 | weight_decay: 0.0
26 | smoothing: 0.0
27 | sgld_steps: 40
28 | pyxce: 0.0
29 | pxsgld: 1.0
30 | pxysgld: 0.0
31 | buffer_size: 9999
32 | reinit_freq: 0.0
33 | data_shape: ${fixed.data_shape}
34 | sgld_batch_size: ${fixed.batch_size}
35 | class_cond_p_x_sample: False
36 | n_classes: 1
37 | warmup_steps: 2500
38 | entropy_reg_weight: 0.0
39 | lr_step_size: 1000
40 |
--------------------------------------------------------------------------------
/configs/model/genomics_mcmc_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | batch_size: 512
4 |
5 | earlystop_config:
6 |
7 | ood_dataset: ${fixed.dataset}
8 |
9 | model_name: DiscreteMCMC
10 | model_config:
11 | num_cat: ${fixed.num_cat}
12 |
13 | arch_name: seq_classifier
14 | arch_config:
15 | in_channels: ${fixed.num_cat}
16 | num_filters: 128
17 | fc_hidden_size: 1000
18 | num_classes: ${fixed.num_classes}
19 | kernel_size: 20
20 |
21 | learning_rate: 0.0001
22 | momentum: 0.9
23 | weight_decay: 0.0
24 | smoothing: 0.0
25 | sgld_steps: 40
26 | pyxce: 1.0
27 | pxsgld: 1.0
28 | pxysgld: 0.0
29 | buffer_size: 9999
30 | reinit_freq: 0.0
31 | data_shape: ${fixed.data_shape}
32 | sgld_batch_size: 512
33 | class_cond_p_x_sample: True
34 | n_classes: ${fixed.num_classes}
35 | warmup_steps: 2500
36 | entropy_reg_weight: 0.0001
37 | lr_step_size: 1000
38 |
--------------------------------------------------------------------------------
/configs/model/genomics_vera.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | batch_size: 256
4 |
5 | checkpoint_config:
6 |
7 | earlystop_config:
8 |
9 | ood_dataset: ${fixed.dataset}
10 |
11 | model_name: VERA
12 | model_config:
13 | arch_name: seq_classifier
14 | arch_config:
15 | in_channels: ${fixed.num_cat}
16 | num_filters: 128
17 | fc_hidden_size: 1000
18 | num_classes: 1
19 | kernel_size: 20
20 |
21 | learning_rate: 0.00003
22 | beta1: 0.0
23 | beta2: 0.9
24 | weight_decay: 0.0
25 | n_classes: 1
26 | uncond: False
27 | gen_learning_rate: 0.00006
28 | ebm_iters: 1
29 | generator_iters: 1
30 | entropy_weight: 0.0001
31 |
32 | generator_type: vera_discrete
33 | generator_arch_name: seq_generator
34 | generator_config:
35 | noise_dim: 128
36 | post_lr: 0.00003
37 | init_post_logsigma: 0.1
38 | generator_arch_config:
39 | inp_dim: ${fixed.model_config.generator_config.noise_dim}
40 | num_classes: ${fixed.num_cat}
41 | seq_length: ${fixed.data_shape.0}
42 |
43 | min_sigma: 0.01
44 | max_sigma: 0.3
45 | p_control: 0.0
46 | n_control: 0.0
47 | pg_control: 0.1
48 | clf_ent_weight: 0.0
49 | ebm_type: p_x
50 | clf_weight: 0.0
51 | warmup_steps: 2500
52 | no_g_batch_norm: False
53 | batch_size: ${fixed.batch_size}
54 | lr_decay: 0.3
55 | lr_decay_epochs: [150, 180]
56 |
--------------------------------------------------------------------------------
/configs/model/genomics_vera_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | batch_size: 256
4 |
5 | checkpoint_config:
6 |
7 | earlystop_config:
8 |
9 | ood_dataset: ${fixed.dataset}
10 |
11 | model_name: VERA
12 | model_config:
13 | arch_name: seq_classifier
14 | arch_config:
15 | in_channels: ${fixed.num_cat}
16 | num_filters: 128
17 | fc_hidden_size: 1000
18 | num_classes: ${fixed.num_classes}
19 | kernel_size: 20
20 |
21 | learning_rate: 0.00003
22 | beta1: 0.0
23 | beta2: 0.9
24 | weight_decay: 0.0
25 | n_classes: ${fixed.num_classes}
26 | uncond: False
27 | gen_learning_rate: 0.00006
28 | ebm_iters: 1
29 | generator_iters: 1
30 | entropy_weight: 0.0001
31 |
32 | generator_type: vera_discrete
33 | generator_arch_name: seq_generator
34 | generator_config:
35 | noise_dim: 128
36 | post_lr: 0.00003
37 | init_post_logsigma: 0.1
38 | generator_arch_config:
39 | inp_dim: ${fixed.model_config.generator_config.noise_dim}
40 | num_classes: ${fixed.num_cat}
41 | seq_length: ${fixed.data_shape.0}
42 |
43 | min_sigma: 0.01
44 | max_sigma: 0.3
45 | p_control: 0.0
46 | n_control: 0.0
47 | pg_control: 0.1
48 | clf_ent_weight: 0.0
49 | ebm_type: jem
50 | clf_weight: 1.0
51 | warmup_steps: 2500
52 | no_g_batch_norm: False
53 | batch_size: ${fixed.batch_size}
54 | lr_decay: 0.3
55 | lr_decay_epochs: [150, 180]
56 |
--------------------------------------------------------------------------------
/configs/model/img_autoencoder.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 50
5 |
6 | checkpoint_config:
7 | monitor: val/loss
8 | mode: min
9 |
10 | earlystop_config:
11 | monitor: val/loss
12 | mode: min
13 | patience: 10
14 |
15 | model_name: Autoencoder
16 | model_config:
17 | arch_name: wrn
18 | arch_config:
19 | depth: 10
20 | num_classes: 32
21 | widen_factor: 2
22 | input_channels: ${fixed.data_shape.2}
23 | decoder_arch_name: resnetgenerator
24 | decoder_arch_config:
25 | unit_interval: False
26 | feats: 32
27 | out_channels: ${fixed.data_shape.2}
28 | learning_rate: 0.0001
29 | momentum: 0.9
30 | weight_decay: 0.0005
31 |
--------------------------------------------------------------------------------
/configs/model/img_ce_baseline.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 50
5 |
6 | checkpoint_config:
7 | monitor: val/acc
8 | mode: max
9 |
10 | earlystop_config:
11 | monitor: val/acc
12 | mode: max
13 | patience: 10
14 |
15 | model_name: CEBaseline
16 | model_config:
17 | arch_name: wrn
18 | arch_config:
19 | depth: 10
20 | num_classes: ${fixed.num_classes}
21 | widen_factor: 2
22 | input_channels: ${fixed.data_shape.2}
23 | learning_rate: 0.0001
24 | momentum: 0.9
25 | weight_decay: 0.0005
26 |
--------------------------------------------------------------------------------
/configs/model/img_glow.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | earlystop_config:
4 | monitor: val/loss
5 | mode: min
6 | patience: 10
7 |
8 | normalize: False
9 | ood_dataset:
10 |
11 | model_name: Glow
12 | model_config:
13 | in_channels: ${fixed.data_shape.2}
14 | num_channels: 512
15 | num_levels: 3
16 | num_steps: 32
17 | learning_rate: 0.001
18 | momentum: 0.9
19 | weight_decay: 0.0
20 |
--------------------------------------------------------------------------------
/configs/model/img_mcmc.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 20
5 |
6 | checkpoint_config:
7 | save_last: True
8 |
9 | earlystop_config:
10 |
11 | ood_dataset: ${fixed.dataset}
12 | sigma: 0.1
13 |
14 | model_name: JEM
15 | model_config:
16 | n_classes: 1
17 | arch_name: wrn
18 | arch_config:
19 | depth: 10
20 | num_classes: 1
21 | widen_factor: 2
22 | input_channels: ${fixed.data_shape.2}
23 | strides: [1, 2, 2]
24 | learning_rate: 0.0001
25 | momentum: 0.9
26 | weight_decay: 0.0
27 | smoothing: 0.0
28 | sgld_lr: 1.
29 | sgld_std: 0.01
30 | sgld_steps: 100
31 | pyxce: 0.0
32 | pxsgld: 1.0
33 | pxysgld: 0.0
34 | buffer_size: 10000
35 | reinit_freq: 0.05
36 | data_shape: ${fixed.data_shape}
37 | sgld_batch_size: ${fixed.batch_size}
38 | class_cond_p_x_sample: False
39 | entropy_reg_weight: 0.0001
40 |
--------------------------------------------------------------------------------
/configs/model/img_mcmc_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 20
5 |
6 | checkpoint_config:
7 |
8 | earlystop_config:
9 |
10 | ood_dataset: ${fixed.dataset}
11 | sigma: 0.1
12 |
13 | model_name: JEM
14 | model_config:
15 | n_classes: ${fixed.num_classes}
16 | arch_name: wrn
17 | arch_config:
18 | depth: 10
19 | num_classes: ${fixed.num_classes}
20 | widen_factor: 2
21 | input_channels: ${fixed.data_shape.2}
22 | strides: [1, 2, 2]
23 | learning_rate: 0.0001
24 | momentum: 0.9
25 | weight_decay: 0.0
26 | smoothing: 0.0
27 | sgld_lr: 1.
28 | sgld_std: 0.01
29 | sgld_steps: 100
30 | pyxce: 1.0
31 | pxsgld: 1.0
32 | pxysgld: 0.0
33 | buffer_size: 10000
34 | reinit_freq: 0.05
35 | data_shape: ${fixed.data_shape}
36 | sgld_batch_size: ${fixed.batch_size}
37 | class_cond_p_x_sample: True
38 | entropy_reg_weight: 0.0001
39 |
--------------------------------------------------------------------------------
/configs/model/img_real_nvp.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | checkpoint_config:
4 | monitor: val/loss
5 | mode: min
6 | save_last: True
7 |
8 | earlystop_config:
9 | monitor: val/loss
10 | mode: min
11 | patience: 10
12 |
13 | normalize: False
14 | ood_dataset:
15 |
16 | model_name: RealNVP
17 | model_config:
18 | num_scales: 2
19 | in_channels: ${fixed.data_shape.2}
20 | mid_channels: 32
21 | num_blocks: 4
22 | num_classes: 10
23 | learning_rate: 0.0001
24 | momentum: 0.9
25 | weight_decay: 0.0
26 |
--------------------------------------------------------------------------------
/configs/model/img_ssm.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 50
5 |
6 | checkpoint_config:
7 | save_last: True
8 |
9 | earlystop_config:
10 |
11 | ood_dataset: ${fixed.dataset}
12 | sigma: 0.1
13 |
14 | model_name: SSM
15 | model_config:
16 | arch_name: wrn
17 | arch_config:
18 | depth: 10
19 | num_classes: 1
20 | widen_factor: 2
21 | input_channels: ${fixed.data_shape.2}
22 | strides: [1, 2, 2]
23 |
24 | learning_rate: 0.0001
25 | momentum: 0.9
26 | weight_decay: 0.0
27 | clf_weight: 0.0
28 | n_classes: 1
29 | n_particles: 1
30 | warmup_steps: 2500
31 | lr_step_size: 1000
32 |
--------------------------------------------------------------------------------
/configs/model/img_ssm_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 50
5 |
6 | checkpoint_config:
7 |
8 | earlystop_config:
9 |
10 | ood_dataset: ${fixed.dataset}
11 | sigma: 0.1
12 |
13 | model_name: SSM
14 | model_config:
15 | arch_name: wrn
16 | arch_config:
17 | depth: 10
18 | num_classes: ${fixed.num_classes}
19 | widen_factor: 2
20 | input_channels: ${fixed.data_shape.2}
21 | strides: [1, 2, 2]
22 |
23 | learning_rate: 0.0001
24 | momentum: 0.9
25 | weight_decay: 0.0
26 | clf_weight: 1.0
27 | n_classes: ${fixed.num_classes}
28 | n_particles: 1
29 | warmup_steps: 2500
30 | lr_step_size: 1000
31 |
--------------------------------------------------------------------------------
/configs/model/img_vera.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | checkpoint_config:
4 | save_last: True
5 |
6 | earlystop_config:
7 |
8 | ood_dataset: ${fixed.dataset}
9 |
10 | model_name: VERA
11 | model_config:
12 | n_classes: 1
13 | arch_name: wrn
14 | arch_config:
15 | depth: 10
16 | num_classes: 1
17 | widen_factor: 2
18 | input_channels: ${fixed.data_shape.2}
19 | strides: [1, 2, 2]
20 | learning_rate: 0.00003
21 | beta1: 0.0
22 | beta2: 0.9
23 | weight_decay: 0.0
24 | gen_learning_rate: 0.00006
25 | ebm_iters: 1
26 | generator_iters: 1
27 | entropy_weight: 0.0001
28 | generator_type: vera
29 | generator_arch_name: resnetgenerator
30 | generator_arch_config:
31 | unit_interval: False
32 | feats: 128
33 | out_channels: ${fixed.data_shape.2}
34 | generator_config:
35 | noise_dim: 128
36 | post_lr: 0.00003
37 | init_post_logsigma: 0.1
38 | min_sigma: 0.01
39 | max_sigma: 0.3
40 | p_control: 0.0
41 | n_control: 0.0
42 | pg_control: 0.1
43 | clf_ent_weight: 0.0001
44 | ebm_type: p_x
45 | clf_weight: 0.0
46 | warmup_steps: 2500
47 | no_g_batch_norm: False
48 | batch_size: ${fixed.batch_size}
49 | lr_decay: 0.3
50 | lr_decay_epochs: [15, 18]
51 |
--------------------------------------------------------------------------------
/configs/model/img_vera_posteriornet.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 40
5 | terminate_on_nan: True
6 |
7 | checkpoint_config:
8 | monitor: val/acc
9 | mode: max
10 | save_last: True
11 |
12 | earlystop_config:
13 | monitor: val/acc
14 | mode: max
15 | patience: 10
16 |
17 | ood_dataset: ${fixed.dataset}
18 | sigma: 0.1
19 |
20 | model_name: VERAPosteriorNet
21 | model_config:
22 | arch_name: wrn
23 | arch_config:
24 | depth: 4
25 | num_classes: ${fixed.num_classes}
26 | widen_factor: 10
27 | input_channels: ${fixed.data_shape.2}
28 | dropout: 0.3
29 | norm: group
30 |
31 | learning_rate: 0.00003
32 | beta1: 0.0
33 | beta2: 0.9
34 | weight_decay: 0.1
35 | n_classes: 10
36 | uncond: False
37 | gen_learning_rate: 0.00006
38 | ebm_iters: 1
39 | generator_iters: 1
40 | entropy_weight: 0.0001
41 |
42 | generator_type: vera
43 | generator_arch_name: resnetgenerator
44 | generator_arch_config:
45 | unit_interval: False
46 | feats: 128
47 | out_channels: ${fixed.data_shape.2}
48 |
49 | generator_config:
50 | noise_dim: 128
51 | post_lr: 0.00003
52 | init_post_logsigma: 0.1
53 |
54 | min_sigma: 0.01
55 | max_sigma: 0.3
56 | p_control: 1.0
57 | n_control: 1.0
58 | pg_control: 0.1
59 | clf_ent_weight: 0.1
60 | ebm_type: jem
61 | clf_weight: 100.0
62 | warmup_steps: 2500
63 | no_g_batch_norm: False
64 | batch_size: ${fixed.batch_size}
65 | lr_decay: 0.3
66 | lr_decay_epochs: [30, 35]
67 | vis_every: -1
68 | alpha_fix: True
69 | entropy_reg: 0.0001
70 |
--------------------------------------------------------------------------------
/configs/model/img_vera_priornet.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | trainer_config:
4 | max_epochs: 50
5 | terminate_on_nan: True
6 |
7 | checkpoint_config:
8 | monitor: val/acc
9 | mode: max
10 | save_last: True
11 |
12 | earlystop_config:
13 | monitor: val/acc
14 | mode: max
15 | patience: 10
16 |
17 | ood_dataset: ${fixed.dataset}
18 |
19 | model_name: VERAPriorNet
20 | model_config:
21 | arch_name: wrn
22 | arch_config:
23 | depth: 10
24 | num_classes: ${fixed.num_classes}
25 | widen_factor: 4
26 | input_channels: ${fixed.data_shape.2}
27 | dropout: 0.3
28 | norm: group
29 |
30 | learning_rate: 0.00003
31 | beta1: 0.0
32 | beta2: 0.9
33 | weight_decay: 0.1
34 | n_classes: 10
35 | uncond: False
36 | gen_learning_rate: 0.00006
37 | ebm_iters: 1
38 | generator_iters: 1
39 | entropy_weight: 0.0001
40 |
41 | generator_type: vera
42 | generator_arch_name: resnetgenerator
43 | generator_arch_config:
44 | unit_interval: False
45 | feats: 128
46 | out_channels: ${fixed.data_shape.2}
47 |
48 | generator_config:
49 | noise_dim: 128
50 | post_lr: 0.00003
51 | init_post_logsigma: 0.1
52 |
53 | min_sigma: 0.01
54 | max_sigma: 0.3
55 | p_control: 1.0
56 | n_control: 1.0
57 | pg_control: 1.0
58 | clf_ent_weight: 0.1
59 | ebm_type: jem
60 | clf_weight: 100.0
61 | warmup_steps: 2500
62 | no_g_batch_norm: False
63 | batch_size: 32
64 | lr_decay: 0.3
65 | lr_decay_epochs: [40, 45]
66 | vis_every: -1
67 | alpha_fix: True
68 | concentration: 1.0
69 | target_concentration:
70 | entropy_reg: 0.0001
71 | reverse_kl: True
72 | w_neg_sample_loss: 0.0
73 |
--------------------------------------------------------------------------------
/configs/model/img_vera_supervised.yaml:
--------------------------------------------------------------------------------
1 | # @package fixed
2 |
3 | checkpoint_config:
4 |
5 | earlystop_config:
6 |
7 | ood_dataset: ${fixed.dataset}
8 |
9 | model_name: VERA
10 | model_config:
11 | n_classes: ${fixed.num_classes}
12 | arch_name: wrn
13 | arch_config:
14 | depth: 10
15 | num_classes: ${fixed.num_classes}
16 | widen_factor: 2
17 | input_channels: ${fixed.data_shape.2}
18 | strides: [1, 2, 2]
19 | learning_rate: 0.00003
20 | beta1: 0.0
21 | beta2: 0.9
22 | weight_decay: 0.0
23 | gen_learning_rate: 0.00006
24 | ebm_iters: 1
25 | generator_iters: 1
26 | entropy_weight: 0.0001
27 | generator_type: vera
28 | generator_arch_name: resnetgenerator
29 | generator_arch_config:
30 | unit_interval: False
31 | feats: 128
32 | out_channels: ${fixed.data_shape.2}
33 | generator_config:
34 | noise_dim: 128
35 | post_lr: 0.00003
36 | init_post_logsigma: 0.1
37 | min_sigma: 0.01
38 | max_sigma: 0.3
39 | p_control: 0.0
40 | n_control: 0.0
41 | pg_control: 0.1
42 | clf_ent_weight: 0.0001
43 | ebm_type: jem
44 | clf_weight: 100.0
45 | warmup_steps: 2500
46 | no_g_batch_norm: False
47 | batch_size: ${fixed.batch_size}
48 | lr_decay: 0.3
49 | lr_decay_epochs: [15, 18]
50 |
--------------------------------------------------------------------------------
/configs/seml_config.yaml:
--------------------------------------------------------------------------------
1 | seml:
2 | executable: uncertainty_est/train.py
3 | name: vera_dim_reduction
4 | output_dir: slurm
5 | project_root_dir: ..
6 |
7 | slurm:
8 | experiments_per_job: 1
9 | sbatch_options:
10 | gres: gpu:1 # num GPUs
11 | mem: 8G # memory
12 | cpus-per-task: 2 # num cores
13 | time: 0-20:00 # max time, D-HH:MM
14 | mail-type: FAIL
15 |
16 | ###### BEGIN PARAMETER CONFIGURATION ######
17 |
18 | fixed:
19 | trainer_config:
20 | max_epochs: 40
21 | gpus: 1
22 | benchmark: True
23 | limit_val_batches: 0
24 |
25 | checkpoint_config:
26 | monitor:
27 | save_last: True
28 |
29 | earlystop_config:
30 |
31 | test_ood_datasets:
32 | - lsun
33 | - svhn
34 | - svhn_unscaled
35 | - gaussian_noise
36 | - uniform_noise
37 |
38 | log_dir: grid_search/vera_dim_reduction
39 | dataset: &dataset cifar10
40 | # Use second dataset
41 | ood_dataset: *dataset
42 | seed: 1
43 | batch_size: &batch_size 32
44 | data_shape:
45 | - 32
46 | - 32
47 | - &n_channels 3
48 |
49 | model_name: VERA
50 | model_config:
51 | n_classes: &n_classes 1
52 |
53 | arch_name: wrn
54 | arch_config:
55 | depth: 28
56 | num_classes: *n_classes
57 | widen_factor: 10
58 | input_channels: *n_channels
59 | # strides: [1, 2, 2]
60 | learning_rate: 0.00003
61 | beta1: 0.0
62 | beta2: 0.9
63 | weight_decay: 0.0
64 | gen_learning_rate: 0.00006
65 | ebm_iters: 1
66 | generator_iters: 1
67 | entropy_weight: 0.0001
68 | generator_type: vera
69 | generator_arch_name: resnetgenerator
70 | generator_arch_config:
71 | unit_interval: False
72 | feats: 128
73 | out_channels: *n_channels
74 | generator_config:
75 | noise_dim: 128
76 | post_lr: 0.00003
77 | init_post_logsigma: 0.1
78 | min_sigma: 0.01
79 | max_sigma: 0.3
80 | p_control: 0.0
81 | n_control: 0.0
82 | pg_control: 0.1
83 | clf_ent_weight: 0.0
84 | ebm_type: p_x
85 | clf_weight: 0.0
86 | warmup_steps: 2500
87 | no_g_batch_norm: False
88 | batch_size: *batch_size
89 | lr_decay: 0.3
90 | lr_decay_epochs: [20, 30]
91 |
92 | grid:
93 | model_config:
94 | type: parameter_collection
95 | params:
96 | arch_config:
97 | type: parameter_collection
98 | params:
99 | strides:
100 | type: choice
101 | options:
102 | - [1, 2, 2]
103 | - [1, 1, 2]
104 | - [1, 1, 1]
105 |
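The `grid` block above lists three `strides` options under a `choice` parameter. As a rough illustration of how such a grid turns into individual runs, the sketch below enumerates the Cartesian product of all option lists. The helper `expand_grid` and the flat dotted key names are hypothetical and are not seml's actual API; the `seed` entry in the second call is likewise an invented example.

```python
from itertools import product


def expand_grid(choices):
    """Return one dict per combination of the given option lists (hypothetical helper)."""
    keys = list(choices)
    combos = product(*(choices[key] for key in keys))
    return [dict(zip(keys, combo)) for combo in combos]


# The three stride settings from the grid section of the config above.
strides_grid = {
    "model_config.arch_config.strides": [[1, 2, 2], [1, 1, 2], [1, 1, 1]],
}
runs = expand_grid(strides_grid)

# Adding a second (invented) parameter multiplies the number of runs.
runs_with_seed = expand_grid({**strides_grid, "seed": [1, 2]})

for run in runs:
    print(run)
```

With only `strides` in the grid, this yields three configurations, one per stride setting.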
--------------------------------------------------------------------------------
/density_histograms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/density_histograms.png
--------------------------------------------------------------------------------
/req.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.13.0=pypi_0
6 | aiohttp=3.7.4.post0=pypi_0
7 | antlr4-python3-runtime=4.8=pypi_0
8 | async-timeout=3.0.1=pypi_0
9 | attrs=21.2.0=pypi_0
10 | blas=1.0=mkl
11 | ca-certificates=2021.5.25=h06a4308_1
12 | cachetools=4.2.2=pypi_0
13 | certifi=2021.5.30=py38h06a4308_0
14 | chardet=4.0.0=pypi_0
15 | cudatoolkit=10.1.243=h6bb024c_0
16 | cycler=0.10.0=py38_0
17 | dbus=1.13.18=hb2f20db_0
18 | expat=2.4.1=h2531618_2
19 | fontconfig=2.13.1=h6c09931_0
20 | freetype=2.10.4=h5ab3b9f_0
21 | fsspec=2021.6.1=pypi_0
22 | future=0.18.2=pypi_0
23 | glib=2.68.2=h36276a3_0
24 | google-auth=1.32.1=pypi_0
25 | google-auth-oauthlib=0.4.4=pypi_0
26 | grpcio=1.38.1=pypi_0
27 | gst-plugins-base=1.14.0=h8213a91_2
28 | gstreamer=1.14.0=h28cd5cc_2
29 | hydra-core=1.0.0=pypi_0
30 | icu=58.2=he6710b0_3
31 | idna=2.10=pypi_0
32 | importlib-resources=5.2.0=pypi_0
33 | intel-openmp=2021.2.0=h06a4308_610
34 | joblib=1.0.1=pypi_0
35 | jpeg=9b=h024ee3a_2
36 | kiwisolver=1.3.1=py38h2531618_0
37 | lcms2=2.12=h3be6417_0
38 | ld_impl_linux-64=2.33.1=h53a641e_7
39 | libffi=3.3=he6710b0_2
40 | libgcc-ng=9.1.0=hdf63c60_0
41 | libgfortran-ng=7.5.0=ha8ba4b0_17
42 | libgfortran4=7.5.0=ha8ba4b0_17
43 | libpng=1.6.37=hbc83047_0
44 | libstdcxx-ng=9.1.0=hdf63c60_0
45 | libtiff=4.2.0=h85742a9_0
46 | libuuid=1.0.3=h1bed415_2
47 | libuv=1.40.0=h7b6447c_0
48 | libwebp-base=1.2.0=h27cfd23_0
49 | libxcb=1.14=h7b6447c_0
50 | libxml2=2.9.10=hb55368b_3
51 | lz4-c=1.9.3=h2531618_0
52 | markdown=3.3.4=pypi_0
53 | matplotlib=3.3.4=py38h06a4308_0
54 | matplotlib-base=3.3.4=py38h62a2d02_0
55 | mkl=2021.2.0=h06a4308_296
56 | mkl-service=2.3.0=py38h27cfd23_1
57 | mkl_fft=1.3.0=py38h42c9631_2
58 | mkl_random=1.2.1=py38ha9443f7_2
59 | multidict=5.1.0=pypi_0
60 | ncurses=6.2=he6710b0_1
61 | ninja=1.10.2=hff7bd54_1
62 | numpy=1.20.2=py38h2d18471_0
63 | numpy-base=1.20.2=py38hfae3a4d_0
64 | oauthlib=3.1.1=pypi_0
65 | olefile=0.46=py_0
66 | omegaconf=2.1.0=pypi_0
67 | openssl=1.1.1k=h27cfd23_0
68 | opt-einsum=3.3.0=pypi_0
69 | packaging=21.0=pypi_0
70 | pandas=1.2.5=py38h295c915_0
71 | pcre=8.45=h295c915_0
72 | pillow=8.2.0=py38he98fc37_0
73 | pip=21.0.1=py38h06a4308_0
74 | protobuf=3.17.3=pypi_0
75 | pyasn1=0.4.8=pypi_0
76 | pyasn1-modules=0.2.8=pypi_0
77 | pydeprecate=0.3.0=pypi_0
78 | pyparsing=2.4.7=pyhd3eb1b0_0
79 | pyqt=5.9.2=py38h05f1152_4
80 | pyro-api=0.1.2=pypi_0
81 | pyro-ppl=1.5.2=pypi_0
82 | python=3.8.8=hdb3f193_5
83 | python-dateutil=2.8.1=pyhd3eb1b0_0
84 | pytorch=1.7.1=py3.8_cuda10.1.243_cudnn7.6.3_0
85 | pytorch-lightning=1.3.8=pypi_0
86 | pytz=2021.1=pyhd3eb1b0_0
87 | pyyaml=5.4.1=py38h27cfd23_1
88 | qt=5.9.7=h5867ecd_1
89 | readline=8.1=h27cfd23_0
90 | requests=2.25.1=pypi_0
91 | requests-oauthlib=1.3.0=pypi_0
92 | rsa=4.7.2=pypi_0
93 | scikit-learn=0.24.2=pypi_0
94 | scipy=1.6.2=py38had2a1c9_1
95 | setuptools=52.0.0=py38h06a4308_0
96 | sip=4.19.13=py38he6710b0_0
97 | six=1.16.0=pyhd3eb1b0_0
98 | sqlite=3.35.4=hdfb4753_0
99 | tensorboard=2.4.1=pypi_0
100 | tensorboard-plugin-wit=1.8.0=pypi_0
101 | tfrecord=1.11=pypi_0
102 | threadpoolctl=2.1.0=pypi_0
103 | tk=8.6.10=hbc83047_0
104 | torchmetrics=0.4.0=pypi_0
105 | torchvision=0.8.2=py38_cu101
106 | tornado=6.1=py38h27cfd23_0
107 | tqdm=4.61.1=pypi_0
108 | typing_extensions=3.10.0.0=pyh06a4308_0
109 | uncertainty-eval=0.0.1=pypi_0
110 | urllib3=1.26.6=pypi_0
111 | werkzeug=2.0.1=pypi_0
112 | wheel=0.36.2=pyhd3eb1b0_0
113 | xz=5.2.5=h7b6447c_0
114 | yaml=0.2.5=h7b6447c_0
115 | yarl=1.6.3=pypi_0
116 | zipp=3.5.0=pypi_0
117 | zlib=1.2.11=h7b6447c_3
118 | zstd=1.4.9=haebb681_0
119 |
--------------------------------------------------------------------------------
/uncertainty_est/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/archs/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/archs/arch_factory.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import vgg16
2 |
3 | from uncertainty_est.archs.wrn import WideResNet
4 | from uncertainty_est.archs.fc import SynthModel
5 | from uncertainty_est.archs.resnet import ResNetGenerator
6 | from uncertainty_est.archs.flows import NormalizingFlowDensity
7 |
8 |
9 | def get_arch(name, config_dict: dict):
10 | if name == "wrn":
11 | return WideResNet(**config_dict)
12 | elif name == "vgg16":
13 | return vgg16(**config_dict)
14 | elif name == "fc":
15 | return SynthModel(**config_dict)
16 | elif name == "resnetgenerator":
17 | return ResNetGenerator(**config_dict)
18 | elif name == "normalizing_flow":
19 | return NormalizingFlowDensity(**config_dict)
20 | else:
21 | raise ValueError(f'Architecture "{name}" not implemented!')
22 |
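The `get_arch` function above dispatches on a name string with an `if`/`elif` chain. An equivalent way to express the same lookup is a dictionary registry, sketched below in plain Python. The classes `TinyMLP` and `TinyConvNet` and the function `get_arch_from_registry` are hypothetical stand-ins so that the example runs without PyTorch; they are not part of the repository.

```python
from typing import Any, Callable, Dict


class TinyMLP:
    """Hypothetical stand-in for an architecture class such as SynthModel."""

    def __init__(self, inp_dim: int, num_classes: int):
        self.inp_dim = inp_dim
        self.num_classes = num_classes


class TinyConvNet:
    """Hypothetical stand-in for an architecture class such as WideResNet."""

    def __init__(self, depth: int, num_classes: int):
        self.depth = depth
        self.num_classes = num_classes


# Maps the arch_name string from a config to a constructor.
ARCH_REGISTRY: Dict[str, Callable[..., Any]] = {
    "fc": TinyMLP,
    "wrn": TinyConvNet,
}


def get_arch_from_registry(name: str, config_dict: dict):
    """Look up `name` and build the architecture from keyword arguments."""
    try:
        constructor = ARCH_REGISTRY[name]
    except KeyError:
        raise ValueError(f'Architecture "{name}" not implemented!') from None
    return constructor(**config_dict)


model = get_arch_from_registry("fc", {"inp_dim": 2, "num_classes": 3})
print(type(model).__name__, model.num_classes)
```

A registry keeps the error behaviour of the original (a `ValueError` for unknown names) while making it possible to add architectures without touching the dispatch function.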
--------------------------------------------------------------------------------
/uncertainty_est/archs/fc.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def make_mlp(
5 | dim_list, activation="relu", batch_norm=False, dropout=0, bias=True, slope=1e-2
6 | ):
7 | layers = []
8 | if len(dim_list) > 2:
9 | for dim_in, dim_out in zip(dim_list[:-2], dim_list[1:-1]):
10 | layers.append(nn.Linear(dim_in, dim_out, bias=bias))
11 |
12 | if batch_norm:
13 | layers.append(nn.BatchNorm1d(dim_out, affine=True))
14 |
15 | if activation == "relu":
16 | layers.append(nn.ReLU())
17 | elif activation == "leaky_relu":
18 | layers.append(nn.LeakyReLU(slope, inplace=True))
19 | elif activation == "elu":
20 | layers.append(nn.ELU(inplace=True))
21 | else:
22 | raise NotImplementedError(f"Activation '{activation}' not implemented!")
23 |
24 | if dropout > 0:
25 | layers.append(nn.Dropout(p=dropout))
26 | layers.append(nn.Linear(dim_list[-2], dim_list[-1], bias=bias))
27 | model = nn.Sequential(*layers)
28 | return model
29 |
30 |
31 | class SynthModel(nn.Module):
32 | def __init__(
33 | self,
34 | inp_dim,
35 | num_classes,
36 | hidden_dims=[
37 | 50,
38 | 50,
39 | ],
40 | activation="leaky_relu",
41 | batch_norm=False,
42 | dropout=0.0,
43 | **kwargs,
44 | ):
45 | super().__init__()
46 | self.net = make_mlp(
47 | [
48 | inp_dim,
49 | ]
50 | + hidden_dims
51 | + [
52 | num_classes,
53 | ],
54 | activation=activation,
55 | batch_norm=batch_norm,
56 | dropout=dropout,
57 | **kwargs,
58 | )
59 |
60 | def forward(self, inp):
61 | return self.net(inp)
62 |
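To make the slicing in `make_mlp` easier to follow, the sketch below reproduces only the dimension pairing, without any PyTorch layers. The helper name `layer_shapes` is hypothetical, and the sketch assumes the final linear layer is always appended after the hidden layers.

```python
def layer_shapes(dim_list):
    """Return the (in, out) sizes of the linear layers for a list of widths."""
    # Hidden layers: consecutive pairs, excluding the final output width.
    hidden = list(zip(dim_list[:-2], dim_list[1:-1]))
    # The last linear layer maps the final hidden width to the output width.
    return hidden + [(dim_list[-2], dim_list[-1])]


# SynthModel builds [inp_dim] + hidden_dims + [num_classes]; here 2 -> 50 -> 50 -> 10.
shapes = layer_shapes([2, 50, 50, 10])
print(shapes)
```

Only the hidden pairs are followed by normalization, activation, and dropout in `make_mlp`; the final pair is a bare linear layer that produces the logits.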
--------------------------------------------------------------------------------
/uncertainty_est/archs/flows.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/sharpenb/Posterior-Network/blob/main/src/posterior_networks/NormalizingFlowDensity.py """
2 | import math
3 |
4 | import torch
5 | from torch import nn
6 | import torch.distributions as tdist
7 | from pyro.distributions.transforms import ELUTransform, LeakyReLUTransform
8 | from pyro.distributions.torch_transform import TransformModule
9 | from pyro.distributions.transforms import (
10 | Planar,
11 | Radial,
12 | affine_autoregressive,
13 | affine_coupling,
14 | permute,
15 | )
16 |
17 |
18 | class OrthogonalTransform(nn.Module):
19 | def __init__(self, dim):
20 | super().__init__()
21 | self.dim = dim
22 | self.transform = nn.Linear(dim, dim, bias=False)
23 |
24 | def forward(self, x):
25 | return self.transform(x)
26 |
27 | @staticmethod
28 | def log_abs_det_jacobian(z, z_next):
29 | return torch.zeros(z.shape[0], device=z.device)
30 |
31 | def compute_weight_penalty(self):
32 | sq_weight = torch.mm(self.transform.weight.T, self.transform.weight)
33 | identity = torch.eye(self.dim).to(sq_weight)
34 | penalty = torch.linalg.norm(identity - sq_weight, ord="fro")
35 |
36 | return penalty
37 |
38 |
39 | class ReparameterizedTransform(nn.Module):
40 | def __init__(self, dim):
41 | super().__init__()
42 | self.dim = dim
43 | self.bias = nn.Parameter(torch.zeros(dim))
44 |
45 | self.u_mat = nn.Parameter(torch.Tensor(dim, dim))
46 | self.v_mat = nn.Parameter(torch.Tensor(dim, dim))
47 | self.sigma = nn.Parameter(torch.ones(dim))
48 | nn.init.kaiming_uniform_(self.u_mat, a=math.sqrt(5))
49 | nn.init.kaiming_uniform_(self.v_mat, a=math.sqrt(5))
50 |
51 | def forward(self, x):
52 | weight = self.u_mat @ torch.diag_embed(self.sigma) @ self.v_mat.T
53 | lin_out = torch.nn.functional.linear(x, weight, self.bias)
54 | return lin_out
55 |
56 | def log_abs_det_jacobian(self, z, z_next):
57 | ladj = torch.sum(torch.log(self.sigma.abs()))
58 | ladj = torch.empty(z.size(0), device=z.device).fill_(ladj)
59 | return ladj
60 |
61 | def compute_weight_penalty(self):
62 | return self._weight_penalty(self.u_mat) + self._weight_penalty(self.v_mat)
63 |
64 | @staticmethod
65 | def _weight_penalty(weight):
66 | sq_weight = torch.mm(weight.T, weight)
67 | identity = torch.eye(weight.size(0)).to(sq_weight)
68 | penalty = torch.linalg.norm(identity - sq_weight, ord="fro")
69 | return penalty
70 |
71 |
72 | @torch.jit.script
73 | def sequential_mult(V, X):
74 | for row in range(V.shape[0] - 1, -1, -1):
75 | X = X - 2 * V[row : row + 1, :].t() @ (V[row : row + 1, :] @ X)
76 | return X
77 |
78 |
79 | @torch.jit.script
80 | def sequential_inv_mult(V, X):
81 | for row in range(V.shape[0]):
82 | X = X - 2 * V[row : row + 1, :].t() @ (V[row : row + 1, :] @ X)
83 | return X
84 |
85 |
86 | class Orthogonal(nn.Module):
87 | def __init__(self, d, m=28, strategy="sequential"):
88 | super(Orthogonal, self).__init__()
89 | self.d = d
90 | self.strategy = strategy
91 | self.U = torch.nn.Parameter(torch.zeros((d, d)).normal_(0, 0.05))
92 |
93 | if strategy == "fast":
94 | assert d % m == 0, (
95 | "The CUDA implementation assumes m=%i divides d=%i which, for current parameters, is not true. "
96 | % (m, d)
97 | )
98 | HouseProd.m = m  # NOTE: HouseProd is neither defined nor imported in this file
99 |
100 | if strategy not in ["fast", "sequential"]:
101 | raise NotImplementedError(
102 | "The only implemented strategies are 'fast' and 'sequential'. "
103 | )
104 |
105 | def forward(self, X):
106 | if self.strategy == "fast":
107 | X = HouseProd.apply(X, self.U)
108 | elif self.strategy == "sequential":
109 | X = sequential_mult(self.U, X.t()).t()
110 | else:
111 | raise NotImplementedError(
112 | "The only implemented strategies are 'fast' and 'sequential'. "
113 | )
114 | return X
115 |
116 | def inverse(self, X):
117 | if self.strategy == "fast":
118 | X = HouseProd.apply(X, torch.flip(self.U, dims=[0]))
119 | elif self.strategy == "sequential":
120 | X = sequential_inv_mult(self.U, X.t()).t()
121 | else:
122 | raise NotImplementedError(
123 | "The only implemented strategies are 'fast' and 'sequential'. "
124 | )
125 | return X
126 |
127 | def lgdet(self, X):
128 | return 0
129 |
130 |
131 | class MatExpOrthogonal(nn.Module):
132 | def __init__(self, d):
133 | super().__init__()
134 | self.param = nn.Parameter(torch.randn(d, d))
135 | self.cached = False
136 | self._cache = None
137 |
138 | @property
139 | def w(self):
140 | if self.cached:
141 | return self._cache
142 |
143 | # Compute skew-symmetric matrix
144 | aux = self.param.triu(diagonal=1)
145 | aux = aux - aux.t()
146 |
147 | w = torch.matrix_exp(aux)
148 | self.cached = True
149 | self._cache = w
150 | return w
151 |
152 | def forward(self, x):
153 | return x @ self.w
154 |
155 | def backward(self, *args, **kwargs):
156 | self.cached = False
157 | return super().backward(*args, **kwargs)
158 |
159 |
160 | class LinearSVD(torch.nn.Module):
161 | def __init__(self, d, orthogonal_param="householder", **kwargs):
162 | super(LinearSVD, self).__init__()
163 | self.d = d
164 |
165 | if orthogonal_param == "mat_exp":
166 | self.U = MatExpOrthogonal(d, **kwargs)
167 | self.V = MatExpOrthogonal(d, **kwargs)
168 | elif orthogonal_param == "householder":
169 | self.U = Orthogonal(d, **kwargs)
170 | self.V = Orthogonal(d, **kwargs)
171 | else:
172 | raise NotImplementedError(
173 | f"Orthogonal parameterization {orthogonal_param} not implemented."
174 | )
175 | self.D = nn.Parameter(torch.empty(d).uniform_(0.99, 1.01))
176 | self.bias = nn.Parameter(torch.zeros(d))
177 |
178 | def forward(self, X):
179 | X = self.U(X)
180 | X = X * self.D
181 | X = self.V(X)
182 | return X + self.bias
183 |
184 | def log_abs_det_jacobian(self, z, z_next):
185 | ladj = torch.log(torch.prod(self.D).abs())
186 | return torch.empty(z.size(0), device=z.device).fill_(ladj)
187 |
188 |
189 | class NonLinearity(nn.Module):
190 | def __init__(self, type_="elu"):
191 | super().__init__()
192 | nonlins = {
193 | "elu": ELUTransform,
194 | "leaky_relu": LeakyReLUTransform,
195 | }
196 | self.transform = nonlins[type_]()
197 |
198 | def forward(self, x):
199 | return self.transform(x)
200 |
201 | def log_abs_det_jacobian(self, z, z_next):
202 | return self.transform.log_abs_det_jacobian(z, z_next).sum(1)
203 |
204 | def compute_weight_penalty(self, *args, **kwargs):
205 | return 0.0
206 |
207 |
208 | class NormalizingFlowDensity(nn.Module):
209 | def __init__(self, dim, flow_length, flow_type="planar_flow", **kwargs):
210 | super(NormalizingFlowDensity, self).__init__()
211 | self.dim = dim
212 | self.flow_length = flow_length
213 | self.flow_type = flow_type
214 |
215 | self.mean = nn.Parameter(torch.zeros(self.dim), requires_grad=False)
216 | self.cov = nn.Parameter(torch.eye(self.dim), requires_grad=False)
217 |
218 | if self.flow_type == "radial_flow":
219 | self.transforms = nn.ModuleList([Radial(dim) for _ in range(flow_length)])
220 | elif self.flow_type == "iaf_flow":
221 | self.transforms = nn.ModuleList(
222 | [
223 | affine_autoregressive(dim, hidden_dims=[128, 128], **kwargs)
224 | for _ in range(flow_length)
225 | ]
226 | )
227 | elif self.flow_type == "planar_flow":
228 | self.transforms = nn.ModuleList([Planar(dim) for _ in range(flow_length)])
229 | elif self.flow_type == "affine_coupling":
230 | self.transforms = []
231 | for i in range(flow_length):
232 | coupling = affine_coupling(dim, hidden_dims=[128, 128], **kwargs)
233 | self.transforms.append(coupling)
234 | self.add_module(f"coupling_{i}", coupling)
235 |
236 | permutation = nn.Parameter(torch.randperm(dim), requires_grad=False)
237 | self.register_parameter(f"permutation_{i}", permutation)
238 | permute_layer = permute(dim, permutation=permutation)
239 | self.transforms.append(permute_layer)
240 |
241 | elif self.flow_type == "orthogonal_flow":
242 | self.transforms = nn.ModuleList(
243 | [OrthogonalTransform(dim) for _ in range(flow_length)]
244 | )
245 | elif self.flow_type == "reparameterized_flow":
246 | self.transforms = nn.ModuleList()
247 | for i in range(flow_length):
248 | self.transforms.append(ReparameterizedTransform(dim))
249 | if i != (flow_length - 1):
250 | self.transforms.append(NonLinearity("elu"))
251 | elif self.flow_type == "svd":
252 | self.transforms = nn.ModuleList()
253 | for i in range(flow_length):
254 | self.transforms.append(LinearSVD(dim))
255 | if i != (flow_length - 1):
256 | self.transforms.append(NonLinearity("elu"))
257 | elif self.flow_type == "svd_mat_exp":
258 | self.transforms = nn.ModuleList()
259 | for i in range(flow_length):
260 | self.transforms.append(LinearSVD(dim, orthogonal_param="mat_exp"))
261 | if i != (flow_length - 1):
262 | self.transforms.append(NonLinearity("elu"))
263 | else:
264 | raise NotImplementedError
265 |
266 | def forward(self, z):
267 | sum_log_jacobians = 0
268 | for transform in self.transforms:
269 | z_next = transform(z)
270 | sum_log_jacobians = sum_log_jacobians + transform.log_abs_det_jacobian(
271 | z, z_next
272 | )
273 | z = z_next
274 |
275 | return z, sum_log_jacobians
276 |
277 | def log_prob(self, x):
278 | z, sum_log_jacobians = self.forward(x)
279 | log_prob_z = tdist.MultivariateNormal(self.mean, self.cov).log_prob(z)
280 | log_prob_x = log_prob_z + sum_log_jacobians # [batch_size]
281 | return log_prob_x
282 |
283 | def sample(self, num):
284 | dist = tdist.MultivariateNormal(self.mean, self.cov)
285 | z = dist.sample([num])
286 | base_log_prob = dist.log_prob(z)
287 | sum_log_jacobians = 0.0
288 | for transform in self.transforms[::-1]:
289 | z_next = transform.inv(z)
290 | sum_log_jacobians = sum_log_jacobians - transform.log_abs_det_jacobian(
291 | z, z_next
292 | )
293 | z = z_next
294 | return z, base_log_prob + sum_log_jacobians
295 |
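`NormalizingFlowDensity.log_prob` above adds the base log-density of the transformed point to the accumulated log-determinants of the Jacobians. The sketch below checks that change-of-variables rule on a one-dimensional affine map `z = scale * x + shift`, using only the standard library. All function names here are hypothetical and chosen for illustration; the affine map stands in for the learned transforms.

```python
import math


def standard_normal_log_prob(z):
    """Log-density of a standard normal at z."""
    return -0.5 * (z * z + math.log(2.0 * math.pi))


def affine_flow_log_prob(x, scale, shift):
    """log p(x) = log N(f(x); 0, 1) + log|df/dx| for f(x) = scale * x + shift."""
    z = scale * x + shift
    log_abs_det = math.log(abs(scale))
    return standard_normal_log_prob(z) + log_abs_det


def gaussian_log_prob(x, mean, std):
    """Closed-form log-density of N(mean, std^2) at x."""
    return (
        -0.5 * ((x - mean) / std) ** 2
        - math.log(std)
        - 0.5 * math.log(2.0 * math.pi)
    )


scale, shift, x = 2.0, -1.0, 0.3
flow_value = affine_flow_log_prob(x, scale, shift)
# The induced density on x is Gaussian with mean -shift/scale and std 1/|scale|.
closed_form = gaussian_log_prob(x, mean=-shift / scale, std=1.0 / abs(scale))
print(flow_value, closed_form)
```

The two values agree up to floating-point error, which is the same identity the class relies on when it sums `log_abs_det_jacobian` over its transforms.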
--------------------------------------------------------------------------------
/uncertainty_est/archs/glow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/archs/glow/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/archs/glow/act_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def mean_dim(tensor, dim=None, keepdims=False):
6 | """Take the mean along multiple dimensions.
7 | Args:
8 | tensor (torch.Tensor): Tensor of values to average.
9 | dim (list): List of dimensions along which to take the mean.
10 | keepdims (bool): Keep dimensions rather than squeezing.
11 | Returns:
12 | mean (torch.Tensor): New tensor of mean value(s).
13 | """
14 | if dim is None:
15 | return tensor.mean()
16 | else:
17 | if isinstance(dim, int):
18 | dim = [dim]
19 | dim = sorted(dim)
20 | for d in dim:
21 | tensor = tensor.mean(dim=d, keepdim=True)
22 | if not keepdims:
23 | for i, d in enumerate(dim):
24 | tensor.squeeze_(d - i)
25 | return tensor
26 |
27 |
28 | class ActNorm(nn.Module):
29 | """Activation normalization for 2D inputs.
30 |
31 | The bias and scale get initialized using the mean and variance of the
32 | first mini-batch. After the init, bias and scale are trainable parameters.
33 |
34 | Adapted from:
35 | > https://github.com/openai/glow
36 | """
37 |
38 | def __init__(self, num_features, scale=1.0, return_ldj=False):
39 | super(ActNorm, self).__init__()
40 | self.register_buffer("is_initialized", torch.zeros(1))
41 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
42 | self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))
43 |
44 | self.num_features = num_features
45 | self.scale = float(scale)
46 | self.eps = 1e-6
47 | self.return_ldj = return_ldj
48 |
49 | def initialize_parameters(self, x):
50 | if not self.training:
51 | return
52 |
53 | with torch.no_grad():
54 | bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)
55 | v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True)
56 | logs = (self.scale / (v.sqrt() + self.eps)).log()
57 | self.bias.data.copy_(bias.data)
58 | self.logs.data.copy_(logs.data)
59 | self.is_initialized += 1.0
60 |
61 | def _center(self, x, reverse=False):
62 | if reverse:
63 | return x - self.bias
64 | else:
65 | return x + self.bias
66 |
67 | def _scale(self, x, sldj, reverse=False):
68 | logs = self.logs
69 | if reverse:
70 | x = x * logs.mul(-1).exp()
71 | else:
72 | x = x * logs.exp()
73 |
74 | if sldj is not None:
75 | ldj = logs.sum() * x.size(2) * x.size(3)
76 | if reverse:
77 | sldj = sldj - ldj
78 | else:
79 | sldj = sldj + ldj
80 |
81 | return x, sldj
82 |
83 | def forward(self, x, ldj=None, reverse=False):
84 | if not self.is_initialized:
85 | self.initialize_parameters(x)
86 |
87 | if reverse:
88 | x, ldj = self._scale(x, ldj, reverse)
89 | x = self._center(x, reverse)
90 | else:
91 | x = self._center(x, reverse)
92 | x, ldj = self._scale(x, ldj, reverse)
93 |
94 | if self.return_ldj:
95 | return x, ldj
96 |
97 | return x
98 |
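The data-dependent initialization in `ActNorm.initialize_parameters` can be checked in isolation. The following is a minimal sketch, assuming PyTorch is available; the batch `x`, its shape, and its offset are made up for illustration and do not come from the repository. After initializing from a batch, that same batch should come out with roughly zero mean and unit variance per channel.

```python
import torch

torch.manual_seed(0)

# Hypothetical mini-batch: 8 images, 3 channels, 4x4 pixels, deliberately off-center.
x = 5.0 * torch.randn(8, 3, 4, 4) + 2.0

scale, eps = 1.0, 1e-6  # same defaults as ActNorm above

# Per-channel statistics over batch and spatial dims, as in initialize_parameters.
bias = -x.mean(dim=(0, 2, 3), keepdim=True)
var = ((x + bias) ** 2).mean(dim=(0, 2, 3), keepdim=True)
logs = (scale / (var.sqrt() + eps)).log()

# Forward pass of ActNorm: center first, then scale.
y = (x + bias) * logs.exp()

channel_mean = y.mean(dim=(0, 2, 3))
channel_var = (y ** 2).mean(dim=(0, 2, 3))

# Log-determinant contribution per sample: sum of log-scales times H * W.
ldj = logs.sum() * x.size(2) * x.size(3)
```

Because every spatial position of a channel shares one scale, the log-determinant is the sum of `logs` multiplied by the number of pixels, which is what `_scale` computes.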
--------------------------------------------------------------------------------
/uncertainty_est/archs/glow/coupling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .act_norm import ActNorm
6 |
7 |
8 | class Coupling(nn.Module):
9 | """Affine coupling layer originally used in Real NVP and described by Glow.
10 |
11 | Note: The official Glow implementation (https://github.com/openai/glow)
12 | uses a different affine coupling formulation than described in the paper.
13 | This implementation follows the paper and Real NVP.
14 |
15 | Args:
16 | in_channels (int): Number of channels in the input.
17 | mid_channels (int): Number of channels in the intermediate activation
18 | in NN.
19 | """
20 |
21 | def __init__(self, in_channels, mid_channels):
22 | super(Coupling, self).__init__()
23 | self.nn = NN(in_channels, mid_channels, 2 * in_channels)
24 | self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))
25 |
26 | def forward(self, x, ldj, reverse=False):
27 | x_change, x_id = x.chunk(2, dim=1)
28 |
29 | st = self.nn(x_id)
30 | s, t = st[:, 0::2, ...], st[:, 1::2, ...]
31 | s = self.scale * torch.tanh(s)
32 |
33 | # Scale and translate
34 | if reverse:
35 | x_change = x_change * s.mul(-1).exp() - t
36 | ldj = ldj - s.flatten(1).sum(-1)
37 | else:
38 | x_change = (x_change + t) * s.exp()
39 | ldj = ldj + s.flatten(1).sum(-1)
40 |
41 | x = torch.cat((x_change, x_id), dim=1)
42 |
43 | return x, ldj
44 |
45 |
46 | class NN(nn.Module):
47 | """Small convolutional network used to compute scale and translate factors.
48 |
49 | Args:
50 | in_channels (int): Number of channels in the input.
51 | mid_channels (int): Number of channels in the hidden activations.
52 | out_channels (int): Number of channels in the output.
53 | use_act_norm (bool): Use activation norm rather than batch norm.
54 | """
55 |
56 | def __init__(self, in_channels, mid_channels, out_channels, use_act_norm=False):
57 | super(NN, self).__init__()
58 | norm_fn = ActNorm if use_act_norm else nn.BatchNorm2d
59 |
60 | self.in_norm = norm_fn(in_channels)
61 | self.in_conv = nn.Conv2d(
62 | in_channels, mid_channels, kernel_size=3, padding=1, bias=False
63 | )
64 | nn.init.normal_(self.in_conv.weight, 0.0, 0.05)
65 |
66 | self.mid_norm = norm_fn(mid_channels)
67 | self.mid_conv = nn.Conv2d(
68 | mid_channels, mid_channels, kernel_size=1, padding=0, bias=False
69 | )
70 | nn.init.normal_(self.mid_conv.weight, 0.0, 0.05)
71 |
72 | self.out_norm = norm_fn(mid_channels)
73 | self.out_conv = nn.Conv2d(
74 | mid_channels, out_channels, kernel_size=3, padding=1, bias=True
75 | )
76 | nn.init.zeros_(self.out_conv.weight)
77 | nn.init.zeros_(self.out_conv.bias)
78 |
79 | def forward(self, x):
80 | x = self.in_norm(x)
81 | x = F.relu(x)
82 | x = self.in_conv(x)
83 |
84 | x = self.mid_norm(x)
85 | x = F.relu(x)
86 | x = self.mid_conv(x)
87 |
88 | x = self.out_norm(x)
89 | x = F.relu(x)
90 | x = self.out_conv(x)
91 |
92 | return x
93 |
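The forward and reverse branches of `Coupling.forward` are exact inverses of each other. The sketch below shows this on plain Python lists rather than tensors; the function names and the values of `s` and `t` are illustrative assumptions. In the real layer, `s` and `t` come from `self.nn(x_id)`, and `x_id` passes through unchanged so the reverse pass can recompute them.

```python
import math


def coupling_forward(x_change, s, t):
    """Affine coupling: y = (x + t) * exp(s); the log-det is the sum of s."""
    y = [(x + ti) * math.exp(si) for x, si, ti in zip(x_change, s, t)]
    return y, sum(s)


def coupling_reverse(y_change, s, t):
    """Inverse transform: x = y * exp(-s) - t; the log-det changes sign."""
    x = [y * math.exp(-si) - ti for y, si, ti in zip(y_change, s, t)]
    return x, -sum(s)


x_change = [0.5, -1.2, 3.0]
s = [0.1, -0.3, 0.7]   # hypothetical log-scales
t = [1.0, 0.0, -2.0]   # hypothetical translations

y, ldj_fwd = coupling_forward(x_change, s, t)
x_rec, ldj_rev = coupling_reverse(y, s, t)
```

The Jacobian of the forward map is diagonal with entries `exp(s)`, so its log-determinant is just the sum of `s`, with no matrix determinant required.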
--------------------------------------------------------------------------------
/uncertainty_est/archs/glow/glow.py:
--------------------------------------------------------------------------------
1 | """ From https://github.com/chrischute/glow """
2 |
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .act_norm import ActNorm
9 | from .coupling import Coupling
10 | from .inv_conv import InvConv
11 |
12 |
13 | class Glow(nn.Module):
14 | """Glow Model
15 |
16 | Based on the paper:
17 | "Glow: Generative Flow with Invertible 1x1 Convolutions"
18 | by Diederik P. Kingma, Prafulla Dhariwal
19 | (https://arxiv.org/abs/1807.03039).
20 |
21 | Args:
22 | in_channels (int): Number of channels in the input.
23 | num_channels (int): Number of channels in middle convolution of each
24 | step of flow.
25 | num_levels (int): Number of levels in the entire model.
26 | num_steps (int): Number of steps of flow for each level.
27 | """
28 |
29 | def __init__(self, in_channels, num_channels, num_levels, num_steps, k=256):
30 | super(Glow, self).__init__()
31 | self.k = k
32 |
33 | # Use bounds to rescale images before converting to logits, not learned
34 | self.register_buffer("bounds", torch.tensor([0.9], dtype=torch.float32))
35 | self.flows = _Glow(
36 | in_channels=4 * in_channels, # RGB image after squeeze
37 | mid_channels=num_channels,
38 | num_levels=num_levels,
39 | num_steps=num_steps,
40 | )
41 |
42 | def forward(self, x, reverse=False):
43 | if reverse:
44 | sldj = torch.zeros(x.size(0), device=x.device)
45 | else:
46 | # Expect inputs in [0, 1]
47 | if x.min() < 0:
48 | x = (x + 1) / 2.0
49 | elif x.max() > 1:
50 | print("Warning: Input not properly normalized!")
51 |
52 | # De-quantize and convert to logits
53 | x, sldj = self._pre_process(x)
54 |
55 | x = squeeze(x)
56 | x, sldj = self.flows(x, sldj, reverse)
57 | x = squeeze(x, reverse=True)
58 |
59 | return x, sldj
60 |
61 | def _pre_process(self, x):
62 | """Dequantize the input image `x` and convert to logits.
63 |
64 | See Also:
65 | - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1
66 | - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1
67 |
68 | Args:
69 | x (torch.Tensor): Input image.
70 |
71 | Returns:
72 | y (torch.Tensor): Dequantized logits of `x`.
73 | """
74 | y = (x * 255.0 + torch.rand_like(x)) / 256.0
75 | y = (2 * y - 1) * self.bounds
76 | y = (y + 1) / 2
77 | y = y.log() - (1.0 - y).log()
78 |
79 | # Save log-determinant of Jacobian of initial transform
80 | ldj = (
81 | F.softplus(y)
82 | + F.softplus(-y)
83 | - F.softplus((1.0 - self.bounds).log() - self.bounds.log())
84 | )
85 | sldj = ldj.flatten(1).sum(-1)
86 |
87 | return y, sldj
88 |
89 | def log_prob(self, x):
90 | z, sldj = self.forward(x, reverse=False)
91 | prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
92 | prior_ll = prior_ll.reshape(z.size(0), -1).sum(-1) - np.log(self.k) * np.prod(
93 | z.size()[1:]
94 | )
95 | ll = prior_ll + sldj
96 | return ll
97 |
98 |
99 | class _Glow(nn.Module):
100 | """Recursive constructor for a Glow model. Each call creates a single level.
101 |
102 | Args:
103 | in_channels (int): Number of channels in the input.
104 | mid_channels (int): Number of channels in hidden layers of each step.
105 | num_levels (int): Number of levels to construct. Counter for recursion.
106 | num_steps (int): Number of steps of flow for each level.
107 | """
108 |
109 | def __init__(self, in_channels, mid_channels, num_levels, num_steps):
110 | super(_Glow, self).__init__()
111 | self.steps = nn.ModuleList(
112 | [
113 | _FlowStep(in_channels=in_channels, mid_channels=mid_channels)
114 | for _ in range(num_steps)
115 | ]
116 | )
117 |
118 | if num_levels > 1:
119 | self.next = _Glow(
120 | in_channels=2 * in_channels,
121 | mid_channels=mid_channels,
122 | num_levels=num_levels - 1,
123 | num_steps=num_steps,
124 | )
125 | else:
126 | self.next = None
127 |
128 | def forward(self, x, sldj, reverse=False):
129 | if not reverse:
130 | for step in self.steps:
131 | x, sldj = step(x, sldj, reverse)
132 |
133 | if self.next is not None:
134 | x = squeeze(x)
135 | x, x_split = x.chunk(2, dim=1)
136 | x, sldj = self.next(x, sldj, reverse)
137 | x = torch.cat((x, x_split), dim=1)
138 | x = squeeze(x, reverse=True)
139 |
140 | if reverse:
141 | for step in reversed(self.steps):
142 | x, sldj = step(x, sldj, reverse)
143 |
144 | return x, sldj
145 |
146 |
147 | class _FlowStep(nn.Module):
148 | def __init__(self, in_channels, mid_channels):
149 | super(_FlowStep, self).__init__()
150 |
151 | # Activation normalization, invertible 1x1 convolution, affine coupling
152 | self.norm = ActNorm(in_channels, return_ldj=True)
153 | self.conv = InvConv(in_channels)
154 | self.coup = Coupling(in_channels // 2, mid_channels)
155 |
156 | def forward(self, x, sldj=None, reverse=False):
157 | if reverse:
158 | x, sldj = self.coup(x, sldj, reverse)
159 | x, sldj = self.conv(x, sldj, reverse)
160 | x, sldj = self.norm(x, sldj, reverse)
161 | else:
162 | x, sldj = self.norm(x, sldj, reverse)
163 | x, sldj = self.conv(x, sldj, reverse)
164 | x, sldj = self.coup(x, sldj, reverse)
165 |
166 | return x, sldj
167 |
168 |
169 | def squeeze(x, reverse=False):
170 | """Trade spatial extent for channels. In forward direction, convert each
171 | 1x2x2 volume of input into a 4x1x1 volume of output.
172 |
173 | Args:
174 | x (torch.Tensor): Input to squeeze or unsqueeze.
175 | reverse (bool): Reverse the operation, i.e., unsqueeze.
176 |
177 | Returns:
178 | x (torch.Tensor): Squeezed or unsqueezed tensor.
179 | """
180 | b, c, h, w = x.size()
181 | if reverse:
182 | # Unsqueeze
183 | x = x.view(b, c // 4, 2, 2, h, w)
184 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
185 | x = x.view(b, c // 4, h * 2, w * 2)
186 | else:
187 | # Squeeze
188 | x = x.view(b, c, h // 2, 2, w // 2, 2)
189 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
190 | x = x.view(b, c * 2 * 2, h // 2, w // 2)
191 |
192 | return x
193 |
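To see what `squeeze` does to a concrete tensor, the sketch below re-implements the same view/permute sequence (assuming PyTorch is available) and applies it to a 4x4 single-channel input filled with `0..15`. Each 2x2 spatial block is spread across four output channels, and `reverse=True` restores the input exactly.

```python
import torch


def squeeze(x, reverse=False):
    """Same reshaping as the `squeeze` helper above, repeated for a standalone demo."""
    b, c, h, w = x.size()
    if reverse:
        x = x.view(b, c // 4, 2, 2, h, w)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        return x.view(b, c // 4, h * 2, w * 2)
    x = x.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    return x.view(b, c * 4, h // 2, w // 2)


x = torch.arange(16.0).view(1, 1, 4, 4)
y = squeeze(x)                    # shape (1, 4, 2, 2)
x_rec = squeeze(y, reverse=True)  # shape (1, 1, 4, 4)

top_left = y[0, 0].tolist()       # top-left pixel of every 2x2 block
bottom_right = y[0, 3].tolist()   # bottom-right pixel of every 2x2 block
```

Since the operation only rearranges elements, it contributes nothing to the log-determinant, which is why `squeeze` never touches `sldj`.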
--------------------------------------------------------------------------------
/uncertainty_est/archs/glow/inv_conv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class InvConv(nn.Module):
8 | """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow
9 | (https://arxiv.org/abs/1807.03039). Does not support LU-decomposed version.
10 |
11 | Args:
12 | num_channels (int): Number of channels in the input and output.
13 | """
14 |
15 | def __init__(self, num_channels):
16 | super(InvConv, self).__init__()
17 | self.num_channels = num_channels
18 |
19 | # Initialize with a random orthogonal matrix
20 | w_init = np.random.randn(num_channels, num_channels)
21 | w_init = np.linalg.qr(w_init)[0].astype(np.float32)
22 | self.weight = nn.Parameter(torch.from_numpy(w_init))
23 |
24 | def forward(self, x, sldj, reverse=False):
25 | ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
26 |
27 | if reverse:
28 | weight = torch.inverse(self.weight.double()).float()
29 | sldj = sldj - ldj
30 | else:
31 | weight = self.weight
32 | sldj = sldj + ldj
33 |
34 | weight = weight.view(self.num_channels, self.num_channels, 1, 1)
35 | z = F.conv2d(x, weight)
36 |
37 | return z, sldj
38 |
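The log-determinant used by `InvConv` can be sanity-checked with a weight whose determinant is known in advance. This is a minimal sketch, assuming a PyTorch version that provides `torch.linalg.qr`; the weight `2 * Q` (with `Q` orthogonal) is an illustrative choice, not the initialization used above. For that weight `log|det W| = c * log(2)`, and the same matrix is applied at each of the `h * w` positions.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
c, h, w = 4, 3, 3

# Assumed weight: 2 * Q with Q orthogonal, so log|det W| = c * log(2).
q, _ = torch.linalg.qr(torch.randn(c, c))
weight = 2.0 * q

x = torch.randn(2, c, h, w)

# Forward: a 1x1 convolution applies the same c x c matrix at every pixel.
z = F.conv2d(x, weight.reshape(c, c, 1, 1))
ldj = torch.slogdet(weight)[1] * h * w

# Reverse: convolve with the inverse matrix (computed in float64, as above).
w_inv = torch.inverse(weight.double()).float()
x_rec = F.conv2d(z, w_inv.reshape(c, c, 1, 1))

expected_ldj = h * w * c * math.log(2.0)
```

With the orthogonal initialization in `InvConv.__init__`, the log-determinant starts at approximately zero and only moves away from it as training changes the weight.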
--------------------------------------------------------------------------------
/uncertainty_est/archs/real_nvp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/archs/real_nvp/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/archs/real_nvp/coupling_layer.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .resnet import ResNet
7 | from .util import checkerboard_mask
8 |
9 |
10 | class MaskType(IntEnum):
11 | CHECKERBOARD = 0
12 | CHANNEL_WISE = 1
13 |
14 |
15 | class CouplingLayer(nn.Module):
16 | """Coupling layer in RealNVP.
17 |
18 | Args:
19 | in_channels (int): Number of channels in the input.
20 | mid_channels (int): Number of channels in the `s` and `t` network.
21 | num_blocks (int): Number of residual blocks in the `s` and `t` network.
22 | mask_type (MaskType): One of `MaskType.CHECKERBOARD` or `MaskType.CHANNEL_WISE`.
23 | reverse_mask (bool): Whether to reverse the mask. Useful for alternating masks.
24 | """
25 |
26 | def __init__(self, in_channels, mid_channels, num_blocks, mask_type, reverse_mask):
27 | super(CouplingLayer, self).__init__()
28 |
29 | # Save mask info
30 | self.mask_type = mask_type
31 | self.reverse_mask = reverse_mask
32 |
33 | # Build scale and translate network
34 | if self.mask_type == MaskType.CHANNEL_WISE:
35 | in_channels //= 2
36 | self.st_net = ResNet(
37 | in_channels,
38 | mid_channels,
39 | 2 * in_channels,
40 | num_blocks=num_blocks,
41 | kernel_size=3,
42 | padding=1,
43 | double_after_norm=(self.mask_type == MaskType.CHECKERBOARD),
44 | )
45 |
46 | # Learnable scale for s
47 | self.rescale = nn.utils.weight_norm(Rescale(in_channels))
48 |
49 | def forward(self, x, sldj=None, reverse=True):
50 | if self.mask_type == MaskType.CHECKERBOARD:
51 | # Checkerboard mask
52 | b = checkerboard_mask(
53 | x.size(2), x.size(3), self.reverse_mask, device=x.device
54 | )
55 | x_b = x * b
56 | st = self.st_net(x_b)
57 | s, t = st.chunk(2, dim=1)
58 | s = self.rescale(torch.tanh(s))
59 | s = s * (1 - b)
60 | t = t * (1 - b)
61 |
62 | # Scale and translate
63 | if reverse:
64 | inv_exp_s = s.mul(-1).exp()
65 | if torch.isnan(inv_exp_s).any():
66 | raise RuntimeError("Scale factor has NaN entries")
67 | x = x * inv_exp_s - t
68 | else:
69 | exp_s = s.exp()
70 | if torch.isnan(exp_s).any():
71 | raise RuntimeError("Scale factor has NaN entries")
72 | x = (x + t) * exp_s
73 |
74 | # Add log-determinant of the Jacobian
75 | sldj += s.reshape(s.size(0), -1).sum(-1)
76 | else:
77 | # Channel-wise mask
78 | if self.reverse_mask:
79 | x_id, x_change = x.chunk(2, dim=1)
80 | else:
81 | x_change, x_id = x.chunk(2, dim=1)
82 |
83 | st = self.st_net(x_id)
84 | s, t = st.chunk(2, dim=1)
85 | s = self.rescale(torch.tanh(s))
86 |
87 | # Scale and translate
88 | if reverse:
89 | inv_exp_s = s.mul(-1).exp()
90 | if torch.isnan(inv_exp_s).any():
91 | raise RuntimeError("Scale factor has NaN entries")
92 | x_change = x_change * inv_exp_s - t
93 | else:
94 | exp_s = s.exp()
95 | if torch.isnan(exp_s).any():
96 | raise RuntimeError("Scale factor has NaN entries")
97 | x_change = (x_change + t) * exp_s
98 |
99 | # Add log-determinant of the Jacobian
100 | sldj += s.reshape(s.size(0), -1).sum(-1)
101 |
102 | if self.reverse_mask:
103 | x = torch.cat((x_id, x_change), dim=1)
104 | else:
105 | x = torch.cat((x_change, x_id), dim=1)
106 |
107 | return x, sldj
108 |
109 |
110 | class Rescale(nn.Module):
111 | """Per-channel rescaling. Need a proper `nn.Module` so we can wrap it
112 | with `torch.nn.utils.weight_norm`.
113 |
114 | Args:
115 | num_channels (int): Number of channels in the input.
116 | """
117 |
118 | def __init__(self, num_channels):
119 | super(Rescale, self).__init__()
120 | self.weight = nn.Parameter(torch.ones(num_channels, 1, 1))
121 |
122 | def forward(self, x):
123 | x = self.weight * x
124 | return x
125 |
--------------------------------------------------------------------------------
/uncertainty_est/archs/real_nvp/real_nvp.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/chrischute/real-nvp """
2 |
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .coupling_layer import CouplingLayer, MaskType
9 | from .util import squeeze_2x2
10 |
11 |
12 | class RealNVPLoss(nn.Module):
13 | """Get the NLL loss for a RealNVP model.
14 | Args:
15 | k (int or float): Number of discrete values in each input dimension.
16 | E.g., `k` is 256 for natural images.
17 | See Also:
18 | Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
19 | """
20 |
21 | def __init__(self, k=256):
22 | super(RealNVPLoss, self).__init__()
23 | self.k = k
24 |
25 | def forward(self, z, sldj):
26 | prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
27 | prior_ll = prior_ll.view(z.size(0), -1).sum(-1) - np.log(self.k) * np.prod(
28 | z.size()[1:]
29 | )
30 | ll = prior_ll + sldj
31 | nll = -ll.mean()
32 |
33 | return nll
34 |
35 |
36 | class RealNVP(nn.Module):
37 | """RealNVP Model
38 |
39 | Based on the paper:
40 | "Density estimation using Real NVP"
41 | by Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio
42 | (https://arxiv.org/abs/1605.08803).
43 |
44 | Args:
45 | num_scales (int): Number of scales in the RealNVP model.
46 | in_channels (int): Number of channels in the input.
47 | mid_channels (int): Number of channels in the intermediate layers.
48 | num_blocks (int): Number of residual blocks in the s and t network of
49 | `CouplingLayer` modules.
50 | """
51 |
52 | def __init__(
53 | self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8, k=256
54 | ):
55 | super(RealNVP, self).__init__()
56 | self.k = k
57 | # Register data_constraint to pre-process images, not learnable
58 | self.register_buffer(
59 | "data_constraint", torch.tensor([0.9], dtype=torch.float32)
60 | )
61 |
62 | self.flows = _RealNVP(0, num_scales, in_channels, mid_channels, num_blocks)
63 |
64 | def forward(self, x, reverse=False, **kwargs):
65 | sldj = None
66 | if not reverse:
67 | # Expect inputs in [0, 1]
68 | if x.min() < 0:
69 | x = (x + 1) / 2.0
70 | elif x.max() > 1:
71 | print("Warning: Input not properly normalized!")
72 |
73 | # De-quantize and convert to logits
74 | x, sldj = self._pre_process(x)
75 |
76 | x, sldj = self.flows(x, sldj, reverse)
77 |
78 | return x, sldj
79 |
80 | def _pre_process(self, x):
81 | """Dequantize the input image `x` and convert to logits.
82 |
83 | Args:
84 | x (torch.Tensor): Input image.
85 |
86 | Returns:
87 | y (torch.Tensor): Dequantized logits of `x`.
88 |
89 | See Also:
90 | - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1
91 | - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1
92 | """
93 | y = (x * 255.0 + torch.rand_like(x)) / 256.0
94 | y = (2 * y - 1) * self.data_constraint
95 | y = (y + 1) / 2
96 | y = y.log() - (1.0 - y).log()
97 |
98 | # Save log-determinant of Jacobian of initial transform
99 | ldj = (
100 | F.softplus(y)
101 | + F.softplus(-y)
102 | - F.softplus(
103 | (1.0 - self.data_constraint).log() - self.data_constraint.log()
104 | )
105 | )
106 | sldj = ldj.view(ldj.size(0), -1).sum(-1)
107 |
108 | return y, sldj
109 |
110 | def log_prob(self, x):
111 | z, sldj = self.forward(x, reverse=False)
112 | prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
113 | prior_ll = prior_ll.reshape(z.size(0), -1).sum(-1) - np.log(self.k) * np.prod(
114 | z.size()[1:]
115 | )
116 | ll = prior_ll + sldj
117 | return ll
118 |
119 | def inverse(self, z, **kwargs):
120 | return self.forward(z, reverse=True)
121 |
122 |
123 | class _RealNVP(nn.Module):
124 | """Recursive builder for a `RealNVP` model.
125 |
126 | Each `_RealNVP` corresponds to a single scale in `RealNVP`,
127 | and the constructor is recursively called to build a full `RealNVP` model.
128 |
129 | Args:
130 | scale_idx (int): Index of current scale.
131 | num_scales (int): Number of scales in the RealNVP model.
132 | in_channels (int): Number of channels in the input.
133 | mid_channels (int): Number of channels in the intermediate layers.
134 | num_blocks (int): Number of residual blocks in the s and t network of
135 | `CouplingLayer` modules.
136 | """
137 |
138 | def __init__(self, scale_idx, num_scales, in_channels, mid_channels, num_blocks):
139 | super(_RealNVP, self).__init__()
140 |
141 | self.is_last_block = scale_idx == num_scales - 1
142 |
143 | self.in_couplings = nn.ModuleList(
144 | [
145 | CouplingLayer(
146 | in_channels,
147 | mid_channels,
148 | num_blocks,
149 | MaskType.CHECKERBOARD,
150 | reverse_mask=False,
151 | ),
152 | CouplingLayer(
153 | in_channels,
154 | mid_channels,
155 | num_blocks,
156 | MaskType.CHECKERBOARD,
157 | reverse_mask=True,
158 | ),
159 | CouplingLayer(
160 | in_channels,
161 | mid_channels,
162 | num_blocks,
163 | MaskType.CHECKERBOARD,
164 | reverse_mask=False,
165 | ),
166 | ]
167 | )
168 |
169 | if self.is_last_block:
170 | self.in_couplings.append(
171 | CouplingLayer(
172 | in_channels,
173 | mid_channels,
174 | num_blocks,
175 | MaskType.CHECKERBOARD,
176 | reverse_mask=True,
177 | )
178 | )
179 | else:
180 | self.out_couplings = nn.ModuleList(
181 | [
182 | CouplingLayer(
183 | 4 * in_channels,
184 | 2 * mid_channels,
185 | num_blocks,
186 | MaskType.CHANNEL_WISE,
187 | reverse_mask=False,
188 | ),
189 | CouplingLayer(
190 | 4 * in_channels,
191 | 2 * mid_channels,
192 | num_blocks,
193 | MaskType.CHANNEL_WISE,
194 | reverse_mask=True,
195 | ),
196 | CouplingLayer(
197 | 4 * in_channels,
198 | 2 * mid_channels,
199 | num_blocks,
200 | MaskType.CHANNEL_WISE,
201 | reverse_mask=False,
202 | ),
203 | ]
204 | )
205 | self.next_block = _RealNVP(
206 | scale_idx + 1, num_scales, 2 * in_channels, 2 * mid_channels, num_blocks
207 | )
208 |
209 | def forward(self, x, sldj, reverse=False):
210 | if reverse:
211 | if not self.is_last_block:
212 | # Re-squeeze -> split -> next block
213 | x = squeeze_2x2(x, reverse=False, alt_order=True)
214 | x, x_split = x.chunk(2, dim=1)
215 | x, sldj = self.next_block(x, sldj, reverse)
216 | x = torch.cat((x, x_split), dim=1)
217 | x = squeeze_2x2(x, reverse=True, alt_order=True)
218 |
219 | # Squeeze -> 3x coupling (channel-wise)
220 | x = squeeze_2x2(x, reverse=False)
221 | for coupling in reversed(self.out_couplings):
222 | x, sldj = coupling(x, sldj, reverse)
223 | x = squeeze_2x2(x, reverse=True)
224 |
225 | for coupling in reversed(self.in_couplings):
226 | x, sldj = coupling(x, sldj, reverse)
227 | else:
228 | for coupling in self.in_couplings:
229 | x, sldj = coupling(x, sldj, reverse)
230 |
231 | if not self.is_last_block:
232 | # Squeeze -> 3x coupling (channel-wise)
233 | x = squeeze_2x2(x, reverse=False)
234 | for coupling in self.out_couplings:
235 | x, sldj = coupling(x, sldj, reverse)
236 | x = squeeze_2x2(x, reverse=True)
237 |
238 | # Re-squeeze -> split -> next block
239 | x = squeeze_2x2(x, reverse=False, alt_order=True)
240 | x, x_split = x.chunk(2, dim=1)
241 | x, sldj = self.next_block(x, sldj, reverse)
242 | x = torch.cat((x, x_split), dim=1)
243 | x = squeeze_2x2(x, reverse=True, alt_order=True)
244 |
245 | return x, sldj
246 |
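The softplus expression in `_pre_process` is the log-derivative of the rescale-then-logit map. The sketch below checks this numerically for a single scalar, using only the standard library. The helper names are hypothetical, `x` stands for an already dequantized value in (0, 1), and `bound` plays the role of `data_constraint`.

```python
import math


def softplus(v):
    return math.log1p(math.exp(v))


def to_logit(x, bound=0.9):
    """Rescale x from [0, 1] into a band of width `bound` around 0.5, then take the logit."""
    p = ((2.0 * x - 1.0) * bound + 1.0) / 2.0
    return math.log(p) - math.log(1.0 - p)


def analytic_ldj(x, bound=0.9):
    """Closed form used in `_pre_process`: log |d logit / d x|."""
    y = to_logit(x, bound)
    return softplus(y) + softplus(-y) - softplus(math.log(1.0 - bound) - math.log(bound))


x, h = 0.3, 1e-6
# Central finite difference of the same map, as an independent check.
numeric_ldj = math.log((to_logit(x + h) - to_logit(x - h)) / (2.0 * h))
```

The identity behind it: with `p` the rescaled value, `softplus(y) = -log(1 - p)`, `softplus(-y) = -log(p)`, and the last term equals `-log(bound)`, so the sum is `log(bound / (p * (1 - p)))`. The remaining factor of `1 / 256` from dequantization appears to be what the `np.log(self.k)` term in `log_prob` accounts for.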
--------------------------------------------------------------------------------
/uncertainty_est/archs/real_nvp/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class WNConv2d(nn.Module):
7 | """Weight-normalized 2d convolution.
8 | Args:
9 | in_channels (int): Number of channels in the input.
10 | out_channels (int): Number of channels in the output.
11 | kernel_size (int): Side length of each convolutional kernel.
12 | padding (int): Padding to add on edges of input.
13 | bias (bool): Use bias in the convolution operation.
14 | """
15 |
16 | def __init__(self, in_channels, out_channels, kernel_size, padding, bias=True):
17 | super(WNConv2d, self).__init__()
18 | self.conv = nn.utils.weight_norm(
19 | nn.Conv2d(
20 | in_channels, out_channels, kernel_size, padding=padding, bias=bias
21 | )
22 | )
23 |
24 | def forward(self, x):
25 | x = self.conv(x)
26 |
27 | return x
28 |
29 |
30 | class ResidualBlock(nn.Module):
31 | """ResNet basic block with weight norm."""
32 |
33 | def __init__(self, in_channels, out_channels):
34 | super(ResidualBlock, self).__init__()
35 |
36 | self.in_norm = nn.BatchNorm2d(in_channels)
37 | self.in_conv = WNConv2d(
38 | in_channels, out_channels, kernel_size=3, padding=1, bias=False
39 | )
40 |
41 | self.out_norm = nn.BatchNorm2d(out_channels)
42 | self.out_conv = WNConv2d(
43 | out_channels, out_channels, kernel_size=3, padding=1, bias=True
44 | )
45 |
46 | def forward(self, x):
47 | skip = x
48 |
49 | x = self.in_norm(x)
50 | x = F.relu(x)
51 | x = self.in_conv(x)
52 |
53 | x = self.out_norm(x)
54 | x = F.relu(x)
55 | x = self.out_conv(x)
56 |
57 | x = x + skip
58 |
59 | return x
60 |
61 |
62 | class ResNet(nn.Module):
63 | """ResNet for scale and translate factors in Real NVP.
64 |
65 | Args:
66 | in_channels (int): Number of channels in the input.
67 | mid_channels (int): Number of channels in the intermediate layers.
68 | out_channels (int): Number of channels in the output.
69 | num_blocks (int): Number of residual blocks in the network.
70 | kernel_size (int): Side length of each filter in convolutional layers.
71 | padding (int): Padding for convolutional layers.
72 | double_after_norm (bool): Double input after input BatchNorm.
73 | """
74 |
75 | def __init__(
76 | self,
77 | in_channels,
78 | mid_channels,
79 | out_channels,
80 | num_blocks,
81 | kernel_size,
82 | padding,
83 | double_after_norm,
84 | ):
85 | super(ResNet, self).__init__()
86 | self.in_norm = nn.BatchNorm2d(in_channels)
87 | self.double_after_norm = double_after_norm
88 | self.in_conv = WNConv2d(
89 | 2 * in_channels, mid_channels, kernel_size, padding, bias=True
90 | )
91 | self.in_skip = WNConv2d(
92 | mid_channels, mid_channels, kernel_size=1, padding=0, bias=True
93 | )
94 |
95 | self.blocks = nn.ModuleList(
96 | [ResidualBlock(mid_channels, mid_channels) for _ in range(num_blocks)]
97 | )
98 | self.skips = nn.ModuleList(
99 | [
100 | WNConv2d(
101 | mid_channels, mid_channels, kernel_size=1, padding=0, bias=True
102 | )
103 | for _ in range(num_blocks)
104 | ]
105 | )
106 |
107 | self.out_norm = nn.BatchNorm2d(mid_channels)
108 | self.out_conv = WNConv2d(
109 | mid_channels, out_channels, kernel_size=1, padding=0, bias=True
110 | )
111 |
112 | def forward(self, x):
113 | x = self.in_norm(x)
114 | if self.double_after_norm:
115 | x *= 2.0
116 | x = torch.cat((x, -x), dim=1)
117 | x = F.relu(x)
118 | x = self.in_conv(x)
119 | x_skip = self.in_skip(x)
120 |
121 | for block, skip in zip(self.blocks, self.skips):
122 | x = block(x)
123 | x_skip += skip(x)
124 |
125 | x = self.out_norm(x_skip)
126 | x = F.relu(x)
127 | x = self.out_conv(x)
128 |
129 | return x
130 |
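`ResNet.forward` concatenates `x` and `-x` before the first ReLU, which is why `in_conv` is built with `2 * in_channels` inputs. A minimal sketch on plain Python floats (the helper names are illustrative) shows why this keeps the information a single ReLU would discard:

```python
def relu(v):
    return max(v, 0.0)


def concat_relu(values):
    """ReLU applied to [x, -x]: the positive parts followed by the negative parts."""
    return [relu(v) for v in values] + [relu(-v) for v in values]


values = [1.5, -2.0, 0.0]
out = concat_relu(values)

# The input is recoverable as (positive part) - (negative part).
n = len(values)
recovered = [out[i] - out[n + i] for i in range(n)]
```

The output has twice as many entries as the input, mirroring the channel doubling in the network.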
--------------------------------------------------------------------------------
/uncertainty_est/archs/real_nvp/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def squeeze_2x2(x, reverse=False, alt_order=False):
6 | """Squeeze each `NxNxC` spatial block into a `1x1x(N^2 * C)` sub-volume,
7 | or undo that when `reverse=True`, where `N = block_size`.
8 |
9 | Adapted from:
10 | https://github.com/tensorflow/models/blob/master/research/real_nvp/real_nvp_utils.py
11 |
12 | See Also:
13 | - TensorFlow nn.depth_to_space: https://www.tensorflow.org/api_docs/python/tf/nn/depth_to_space
14 | - Figure 3 of RealNVP paper: https://arxiv.org/abs/1605.08803
15 |
16 | Args:
17 | x (torch.Tensor): Input tensor of shape (B, C, H, W).
18 | reverse (bool): Whether to do a reverse squeeze (unsqueeze).
19 | alt_order (bool): Whether to use alternate ordering.
20 | """
21 | block_size = 2
22 | if alt_order:
23 | n, c, h, w = x.size()
24 |
25 | if reverse:
26 | if c % 4 != 0:
27 | raise ValueError(
28 | "Number of channels must be divisible by 4, got {}.".format(c)
29 | )
30 | c //= 4
31 | else:
32 | if h % 2 != 0:
33 | raise ValueError("Height must be divisible by 2, got {}.".format(h))
34 | if w % 2 != 0:
35 | raise ValueError("Width must be divisible by 2, got {}.".format(w))
36 | # Defines permutation of input channels (shape is (4, 1, 2, 2)).
37 | squeeze_matrix = torch.tensor(
38 | [
39 | [[[1.0, 0.0], [0.0, 0.0]]],
40 | [[[0.0, 0.0], [0.0, 1.0]]],
41 | [[[0.0, 1.0], [0.0, 0.0]]],
42 | [[[0.0, 0.0], [1.0, 0.0]]],
43 | ],
44 | dtype=x.dtype,
45 | device=x.device,
46 | )
47 | perm_weight = torch.zeros((4 * c, c, 2, 2), dtype=x.dtype, device=x.device)
48 | for c_idx in range(c):
49 | slice_0 = slice(c_idx * 4, (c_idx + 1) * 4)
50 | slice_1 = slice(c_idx, c_idx + 1)
51 | perm_weight[slice_0, slice_1, :, :] = squeeze_matrix
52 | shuffle_channels = torch.tensor(
53 | [c_idx * 4 for c_idx in range(c)]
54 | + [c_idx * 4 + 1 for c_idx in range(c)]
55 | + [c_idx * 4 + 2 for c_idx in range(c)]
56 | + [c_idx * 4 + 3 for c_idx in range(c)]
57 | )
58 | perm_weight = perm_weight[shuffle_channels, :, :, :]
59 |
60 | if reverse:
61 | x = F.conv_transpose2d(x, perm_weight, stride=2)
62 | else:
63 | x = F.conv2d(x, perm_weight, stride=2)
64 | else:
65 | b, c, h, w = x.size()
66 | x = x.permute(0, 2, 3, 1)
67 |
68 | if reverse:
69 | if c % 4 != 0:
70 | raise ValueError(
71 | "Number of channels {} is not divisible by 4".format(c)
72 | )
73 | x = x.view(b, h, w, c // 4, 2, 2)
74 | x = x.permute(0, 1, 4, 2, 5, 3)
75 | x = x.contiguous().view(b, 2 * h, 2 * w, c // 4)
76 | else:
77 | if h % 2 != 0 or w % 2 != 0:
78 | raise ValueError(
79 | "Expected even spatial dims HxW, got {}x{}".format(h, w)
80 | )
81 | x = x.view(b, h // 2, 2, w // 2, 2, c)
82 | x = x.permute(0, 1, 3, 5, 2, 4)
83 | x = x.contiguous().view(b, h // 2, w // 2, c * 4)
84 |
85 | x = x.permute(0, 3, 1, 2)
86 |
87 | return x
88 |
89 |
90 | def checkerboard_mask(
91 | height, width, reverse=False, dtype=torch.float32, device=None, requires_grad=False
92 | ):
93 | """Get a checkerboard mask, such that no two adjacent entries
94 | have the same value. In non-reversed mask, top-left entry is 0.
95 |
96 | Args:
97 | height (int): Number of rows in the mask.
98 | width (int): Number of columns in the mask.
99 | reverse (bool): If True, reverse the mask (i.e., make top-left entry 1).
100 | Useful for alternating masks in RealNVP.
101 | dtype (torch.dtype): Data type of the tensor.
102 | device (torch.device): Device on which to construct the tensor.
103 | requires_grad (bool): Whether the tensor requires gradient.
104 |
105 |
106 | Returns:
107 | mask (torch.tensor): Checkerboard mask of shape (1, 1, height, width).
108 | """
109 | checkerboard = [[((i % 2) + j) % 2 for j in range(width)] for i in range(height)]
110 | mask = torch.tensor(
111 | checkerboard, dtype=dtype, device=device, requires_grad=requires_grad
112 | )
113 |
114 | if reverse:
115 | mask = 1 - mask
116 |
117 | # Reshape to (1, 1, height, width) for broadcasting with tensors of shape (B, C, H, W)
118 | mask = mask.view(1, 1, height, width)
119 |
120 | return mask
121 |
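For reference, the pattern produced by `checkerboard_mask` can be reproduced without tensors. The sketch below uses the same list comprehension on a 3x4 grid; the function name `checkerboard` is illustrative.

```python
def checkerboard(height, width, reverse=False):
    """Nested-list version of the mask: top-left is 0 unless reversed."""
    mask = [[((i % 2) + j) % 2 for j in range(width)] for i in range(height)]
    if reverse:
        mask = [[1 - v for v in row] for row in mask]
    return mask


mask = checkerboard(3, 4)
flipped = checkerboard(3, 4, reverse=True)

for row in mask:
    print(row)
```

The reversed mask is the exact complement, so alternating `reverse_mask` between consecutive coupling layers means every position is transformed by one of each pair.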
--------------------------------------------------------------------------------
/uncertainty_est/archs/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import init as nninit
4 |
5 |
6 | class GeneratorBlock(nn.Module):
7 | """ResNet-style block for the generator model."""
8 |
9 | def __init__(self, in_chans, out_chans, upsample=False):
10 | super().__init__()
11 |
12 | self.upsample = upsample
13 |
14 | if in_chans != out_chans:
15 | self.shortcut_conv = nn.Conv2d(in_chans, out_chans, kernel_size=1)
16 | else:
17 | self.shortcut_conv = None
18 | self.bn1 = nn.BatchNorm2d(in_chans)
19 | self.conv1 = nn.Conv2d(in_chans, in_chans, kernel_size=3, padding=1)
20 | self.bn2 = nn.BatchNorm2d(in_chans)
21 | self.conv2 = nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1)
22 | self.relu = nn.ReLU()
23 |
24 | def forward(self, *inputs):
25 | x = inputs[0]
26 |
27 | if self.upsample:
28 | shortcut = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
29 | else:
30 | shortcut = x
31 |
32 | if self.shortcut_conv is not None:
33 | shortcut = self.shortcut_conv(shortcut)
34 |
35 | x = self.bn1(x)
36 | x = self.relu(x)
37 | if self.upsample:
38 | x = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
39 | x = self.conv1(x)
40 | x = self.bn2(x)
41 | x = self.relu(x)
42 | x = self.conv2(x)
43 |
44 | return x + shortcut
45 |
46 |
47 | class ResNetGenerator(nn.Module):
48 | """The generator model."""
49 |
50 | def __init__(self, unit_interval, feats=128, out_channels=3):
51 | super().__init__()
52 |
53 | self.input_linear = nn.Linear(feats, 4 * 4 * feats)
54 | self.block1 = GeneratorBlock(feats, feats, upsample=True)
55 | self.block2 = GeneratorBlock(feats, feats, upsample=True)
56 | self.block3 = GeneratorBlock(feats, feats, upsample=True)
57 | self.output_bn = nn.BatchNorm2d(feats)
58 | self.output_conv = nn.Conv2d(feats, out_channels, kernel_size=3, padding=1)
59 | self.relu = nn.ReLU()
60 | self.feats = feats
61 |
62 | # Apply Xavier initialisation to the weights
63 | relu_gain = nninit.calculate_gain("relu")
64 | for module in self.modules():
65 | if isinstance(module, (nn.Conv2d, nn.Linear)):
66 | gain = relu_gain if module != self.input_linear else 1.0
67 | nninit.xavier_uniform_(module.weight.data, gain=gain)
68 | module.bias.data.zero_()
69 |
70 | if unit_interval == True:
71 | self.final_act = nn.functional.sigmoid
72 | elif unit_interval == False:
73 | self.final_act = torch.tanh
74 | else:
75 | self.final_act = nn.Identity()
76 |
77 | self.last_output = None
78 |
79 | def forward(self, *inputs):
80 | x = inputs[0]
81 |
82 | x = self.input_linear(x)
83 | x = x.view(-1, self.feats, 4, 4)
84 | x = self.block1(x)
85 | x = self.block2(x)
86 | x = self.block3(x)
87 | x = self.output_bn(x)
88 | x = self.relu(x)
89 | x = self.output_conv(x)
90 | x = self.final_act(x)
91 |
92 | self.last_output = x
93 |
94 | return x
95 |
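As a rough check of the generator's shape arithmetic: the linear layer produces a 4x4 feature map and each of the three upsampling blocks doubles the spatial size. The helper names below are hypothetical and only mirror the numbers visible in `ResNetGenerator`.

```python
def linear_out_features(feats=128, base_size=4):
    """Number of outputs of the input linear layer (4 * 4 * feats in the model)."""
    return base_size * base_size * feats


def generator_output_shape(out_channels=3, base_size=4, n_upsample_blocks=3, scale=2):
    """Per-sample output shape (C, H, W) after the upsampling blocks."""
    size = base_size * scale ** n_upsample_blocks
    return (out_channels, size, size)


if __name__ == "__main__":
    print(linear_out_features())
    print(generator_output_shape())
```

With the default arguments this gives 32x32 images, which matches CIFAR-10-sized inputs; other resolutions would need a different number of blocks or base size.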
--------------------------------------------------------------------------------
/uncertainty_est/archs/wrn.py:
--------------------------------------------------------------------------------
1 | """
2 | WideResnet architecture adapted from https://github.com/meliketoy/wide-resnet.pytorch
3 | """
4 |
5 | import numpy as np
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.nn.init as init
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1):
12 | """
13 | Convolution with 3x3 kernels.
14 | """
15 | return nn.Conv2d(
16 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True
17 | )
18 |
19 |
20 | def conv_init(m):
21 | """
22 | Initializing convolution layers.
23 | """
24 | classname = m.__class__.__name__
25 | if classname.find("Conv") != -1:
26 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
27 | init.constant_(m.bias, 0)
28 | elif classname.find("BatchNorm") != -1:
29 | init.constant_(m.weight, 1)
30 | init.constant_(m.bias, 0)
31 |
32 |
33 | class Identity(nn.Module):
34 | """
35 | Identity module used as a stand-in when no normalization is applied.
36 | """
37 |
38 | def __init__(self, *args, **kwargs):
39 | super().__init__()
40 |
41 | def forward(self, x):
42 | """
43 | Forward pass of model.
44 | """
45 | return x
46 |
47 |
48 | class wide_basic(nn.Module):
49 | """
50 | One block in the Wide resnet.
51 | """
52 |
53 | def __init__(
54 | self,
55 | in_planes,
56 | planes,
57 | dropout_rate,
58 | stride=1,
59 | norm=None,
60 | leak=0.2,
61 | first=False,
62 | ):
63 | super(wide_basic, self).__init__()
64 | self.lrelu = nn.LeakyReLU(leak)
65 | self.bn1 = get_norm(in_planes, norm)
66 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
67 | self.dropout = Identity() if dropout_rate == 0.0 else nn.Dropout(p=dropout_rate)
68 | self.bn2 = get_norm(planes, norm)
69 | self.conv2 = nn.Conv2d(
70 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=True
71 | )
72 |
73 | self.first = first
74 |
75 | self.shortcut = nn.Sequential()
76 | if stride != 1 or in_planes != planes:
77 | self.shortcut = nn.Sequential(
78 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
79 | )
80 |
81 | def forward(self, x):
82 | """
83 | Forward pass of block.
84 | """
85 | if (
86 | self.first
87 | ): # if it's the first block, don't apply the first batchnorm to the data
88 | out = self.dropout(self.conv1(self.lrelu(x)))
89 | else:
90 | out = self.dropout(self.conv1(self.lrelu(self.bn1(x))))
91 | out = self.conv2(self.lrelu(self.bn2(out)))
92 | out += self.shortcut(x)
93 |
94 | return out
95 |
96 |
97 | def get_norm(n_filters, norm):
98 | """
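99 | Return the normalization layer selected by `norm`: None (identity), "batch", "instance", "layer" or "group"; any other value returns None.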
100 | """
101 | if norm is None:
102 | return Identity()
103 | elif norm == "batch":
104 | return nn.BatchNorm2d(n_filters)
105 | elif norm == "instance":
106 | return nn.InstanceNorm2d(n_filters, affine=True)
107 | elif norm == "layer":
108 | return nn.GroupNorm(1, n_filters)
109 | elif norm == "group":
110 | return nn.GroupNorm(32, n_filters)
111 |
112 |
113 | class WideResNet(nn.Module):
114 | """
115 | Wide resnet model.
116 | """
117 |
118 | def __init__(
119 | self,
120 | depth,
121 | widen_factor,
122 | num_classes=10,
123 | input_channels=3,
124 | sum_pool=False,
125 | norm=None,
126 | leak=0.2,
127 | dropout=0.0,
128 | strides=(1, 2, 2),
129 | bottleneck_dim=None,
130 | bottleneck_channels_factor=None,
131 | ):
132 | super(WideResNet, self).__init__()
133 | self.leak = leak
134 | self.in_planes = 16
135 | self.sum_pool = sum_pool
136 | self.norm = norm
137 | self.lrelu = nn.LeakyReLU(leak)
138 | self.bottleneck_dim = bottleneck_dim
139 | self.bottleneck_channels_factor = bottleneck_channels_factor
140 |
141 | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4"
142 | n = (depth - 4) // 6
143 | k = widen_factor
144 |
145 | print("| Wide-Resnet %dx%d" % (depth, k))
146 | nStages = [16, 16 * k, 32 * k, 64 * k]
147 |
148 | self.conv1 = conv3x3(input_channels, nStages[0])
149 | self.layer1 = self._wide_layer(
150 | wide_basic, nStages[1], n, dropout, stride=strides[0], first=True
151 | )
152 | self.layer2 = self._wide_layer(
153 | wide_basic, nStages[2], n, dropout, stride=strides[1]
154 | )
155 | self.layer3 = self._wide_layer(
156 | wide_basic, nStages[3], n, dropout, stride=strides[2]
157 | )
158 | self.bn1 = get_norm(nStages[3], self.norm)
159 | self.last_dim = nStages[3]
160 | self.linear = nn.Linear(nStages[3], num_classes)
161 |
162 | if self.bottleneck_dim is not None:
163 | self.bottleneck = nn.Sequential(
164 | nn.Linear(nStages[3], nStages[3] // 2),
165 | nn.ReLU(True),
166 | nn.Linear(nStages[3] // 2, self.bottleneck_dim),
167 | nn.ReLU(True),
168 | nn.Linear(self.bottleneck_dim, nStages[3] // 2),
169 | nn.ReLU(True),
170 | nn.Linear(nStages[3] // 2, nStages[3]),
171 | )
172 |
173 | self.apply(conv_init)
174 |
175 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride, first=False):
176 | strides = [stride] + [1] * (num_blocks - 1)
177 | layers = []
178 |
179 | for i, stride in enumerate(strides):
180 | if first and i == 0: # first block of first layer has no BN
181 | layers.append(
182 | block(
183 | self.in_planes,
184 | planes,
185 | dropout_rate,
186 | stride,
187 | norm=self.norm,
188 | first=True,
189 | )
190 | )
191 | else:
192 | layers.append(
193 | block(self.in_planes, planes, dropout_rate, stride, norm=self.norm)
194 | )
195 | self.in_planes = planes
196 |
197 | if self.bottleneck_channels_factor is not None:
198 | bottleneck_channels = int(self.in_planes * self.bottleneck_channels_factor)
199 | layers.extend(
200 | [
201 | nn.Conv2d(self.in_planes, bottleneck_channels, kernel_size=1),
202 | nn.Conv2d(bottleneck_channels, self.in_planes, kernel_size=1),
203 | ]
204 | )
205 |
206 | return nn.Sequential(*layers)
207 |
208 | def encode(self, x, vx=None):
209 | out = self.conv1(x)
210 | out = self.layer1(out)
211 | out = self.layer2(out)
212 | out = self.layer3(out)
213 | out = self.lrelu(self.bn1(out))
214 | if self.sum_pool:
215 | out = out.view(out.size(0), out.size(1), -1).sum(2)
216 | else:
217 | out = F.avg_pool2d(out, out.shape[2:])
218 | out = out.view(out.size(0), -1)
219 |
220 | if self.bottleneck_dim is not None:
221 | out = self.bottleneck(out)
222 | return out
223 |
224 | def forward(self, x, vx=None):
225 | """
226 | Forward pass. `vx` is accepted for interface compatibility but is currently unused.
227 | """
228 | out = self.encode(x, vx)
229 |
230 | return self.linear(out)
231 |
232 |
233 | if __name__ == "__main__":
234 | import torch
235 |
236 | for strides in [(1, 2, 2), (1, 1, 2), (1, 1, 1)]:
237 | wrn = WideResNet(28, 10, strides=strides, bottleneck_channels_factor=0.1)
238 | print(wrn)
239 | out = wrn(torch.zeros(1, 3, 32, 32))
240 | print(strides, out.shape)
241 |
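The relation between `depth`, `widen_factor` and the resulting layer layout can be sketched in isolation. The function below (`wrn_layout` is a hypothetical name) only repeats the arithmetic from `WideResNet.__init__`; it builds no layers.

```python
def wrn_layout(depth, widen_factor):
    """Return (blocks per stage, channel widths) for a Wide-ResNet of the given size."""
    if (depth - 4) % 6 != 0:
        raise ValueError("Wide-resnet depth should be 6n+4")
    n = (depth - 4) // 6  # number of wide_basic blocks in each of the three stages
    k = widen_factor
    return n, [16, 16 * k, 32 * k, 64 * k]


if __name__ == "__main__":
    print(wrn_layout(28, 10))
```

For the 28x10 configuration used in the `__main__` block above, this yields four blocks per stage and a final feature dimension of 640, which is the input size of the linear classifier.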
--------------------------------------------------------------------------------
/uncertainty_est/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/data/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/data/dataloaders.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import PIL.Image as InterpolationMode
6 | from torchvision import transforms as tvt
7 | from uncertainty_eval.datasets import get_dataset as ue_get_dataset
8 |
9 | from uncertainty_est.data.datasets import ConcatDataset, ConcatIterableDataset
10 |
11 | DATA_ROOT = Path("../data")
12 |
13 |
14 | def get_dataset(dataset, data_shape=None, length=10_000, split_seed=1):
15 | try:
16 | ds_class = ue_get_dataset(dataset)
17 |
18 | if data_shape is None:
19 | data_shape = ds_class.data_shape
20 |
21 | if dataset == "gaussian_noise":
22 | m = 127.5 if len(data_shape) == 3 else 0.0
23 | s = 60.0 if len(data_shape) == 3 else 1.0
24 | mean = torch.empty(*data_shape).fill_(m)
25 | std = torch.empty(*data_shape).fill_(s)
26 | ds = ds_class(DATA_ROOT, length=length, mean=mean, std=std)
27 | elif dataset == "uniform_noise":
28 | l = 0.0 if len(data_shape) == 3 else -5.0
29 | h = 255.0 if len(data_shape) == 3 else 5.0
30 | low = torch.empty(*data_shape).fill_(l)
31 | high = torch.empty(*data_shape).fill_(h)
32 | ds = ds_class(DATA_ROOT, length=length, low=low, high=high)
33 | elif dataset == "constant":
34 | low = 0.0 if len(data_shape) == 3 else -5.0
35 | high = 255.0 if len(data_shape) == 3 else 5.0
36 | ds = ds_class(
37 | DATA_ROOT, length=length, low=low, high=high, shape=data_shape
38 | )
39 | else:
40 | ds = ds_class(DATA_ROOT)
41 | except KeyError as e:
42 | raise ValueError(f'Dataset "{dataset}" not supported') from e
43 | return ds, data_shape
44 |
45 |
46 | def get_dataloader(
47 | dataset,
48 | split,
49 | batch_size=32,
50 | data_shape=None,
51 | ood_dataset=None,
52 | sigma=0.0,
53 | num_workers=0,
54 | drop_last=None,
55 | shuffle=None,
56 | mutation_rate=0.0,
57 | split_seed=1,
58 | normalize=True,
59 | extra_train_transforms=(),
60 | extra_test_transforms=(),
61 | ):
62 | train_transform = []
63 | test_transform = []
64 |
65 | unscaled = False
66 | try:
67 | ds, data_shape = get_dataset(dataset, data_shape, split_seed=split_seed)
68 | except ValueError as e:
69 | if "_unscaled" in dataset:
70 | dataset = dataset.replace("_unscaled", "")
71 | unscaled = True
72 | ds, data_shape = get_dataset(dataset, data_shape, split_seed=split_seed)
73 | else:
74 | raise e
75 |
76 | if len(data_shape) == 3:
77 | img_size = data_shape[1]
78 | train_transform.extend(
79 | [
80 | tvt.Resize(img_size, InterpolationMode.LANCZOS),
81 | tvt.CenterCrop(img_size),
82 | tvt.Pad(4, padding_mode="reflect"),
83 | tvt.RandomRotation(15, resample=InterpolationMode.BILINEAR),
84 | tvt.RandomHorizontalFlip(),
85 | tvt.RandomCrop(img_size),
86 | ]
87 | )
88 |
89 | test_transform.extend(
90 | [
91 | tvt.Resize(img_size, InterpolationMode.LANCZOS),
92 | tvt.CenterCrop(img_size),
93 | ]
94 | )
95 |
96 | if unscaled:
97 | scale_transform = [tvt.ToTensor(), tvt.Lambda(lambda x: x * 255.0)]
98 | else:
99 | scale_transform = [tvt.ToTensor()]
100 | if normalize:
101 | scale_transform.append(
102 | tvt.Normalize((0.5,) * data_shape[2], (0.5,) * data_shape[2])
103 | )
104 |
105 | test_transform.extend(scale_transform)
106 | train_transform.extend(scale_transform)
107 |
108 | test_transform.extend(extra_test_transforms)
109 | train_transform.extend(extra_train_transforms)
110 |
111 | if sigma > 0.0:
112 | noise_transform = lambda x: x + sigma * torch.randn_like(x)
113 | train_transform.append(noise_transform)
114 | test_transform.append(noise_transform)
115 |
116 | if mutation_rate > 0.0:
117 | if len(data_shape) == 3:
118 | mutation_data_shape = (data_shape[2], data_shape[0], data_shape[1])
119 | else:
120 | mutation_data_shape = data_shape
121 |
122 | def mutation_transform(x):
123 | mask = torch.bernoulli(
124 | torch.empty(mutation_data_shape).fill_(mutation_rate)
125 | )
126 | replace = torch.empty(mutation_data_shape).uniform_(-1, 1) * (1 - mask)
127 | return x * mask + replace
128 |
129 | train_transform.append(mutation_transform)
130 |
131 | train_transform = tvt.Compose(train_transform)
132 | test_transform = tvt.Compose(test_transform)
133 |
134 | if split == "train":
135 | ds = ds.train(train_transform)
136 | elif split == "val":
137 | ds = ds.val(test_transform)
138 | else:
139 | ds = ds.test(test_transform)
140 |
141 | setattr(ds, "data_shape", data_shape)
142 |
143 | if isinstance(ds, torch.utils.data.IterableDataset):
144 | shuffle = False
145 | else:
146 | shuffle = split == "train" if shuffle is None else shuffle
147 |
148 | if ood_dataset is not None:
149 | if isinstance(ds, torch.utils.data.IterableDataset):
150 | ood_ds, _ = get_dataset(ood_dataset, data_shape)
151 | ds = ConcatIterableDataset(ds, ood_ds.train(train_transform))
152 | else:
153 | ood_ds, _ = get_dataset(ood_dataset, data_shape, length=len(ds))
154 |
155 | ood_train = ood_ds.train(train_transform)
156 | ds = ConcatDataset(ds, ood_train)
157 |
158 | dataloader = DataLoader(
159 | ds,
160 | batch_size=batch_size,
161 | pin_memory=True,
162 | num_workers=num_workers,
163 | shuffle=shuffle,
164 | drop_last=split == "train" if drop_last is None else drop_last,
165 | )
166 | return dataloader
167 |
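The effect of the `normalize` option can be checked by hand. The sketch below assumes, as the indexing in `get_dataloader` suggests, that `data_shape` is ordered (height, width, channels); the helper names are hypothetical.

```python
def normalize_params(data_shape):
    """Per-channel mean and std tuples, as passed to the Normalize transform."""
    channels = data_shape[2]
    return (0.5,) * channels, (0.5,) * channels


def normalize_value(x, mean=0.5, std=0.5):
    """Apply (x - mean) / std to a single value in [0, 1]."""
    return (x - mean) / std


if __name__ == "__main__":
    print(normalize_params((32, 32, 3)))
    print(normalize_value(0.0), normalize_value(1.0))
```

With a mean and standard deviation of 0.5, values in [0, 1] are mapped to [-1, 1], which is the same range the mutation transform samples its replacement values from.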
--------------------------------------------------------------------------------
/uncertainty_est/data/datasets.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from torch.utils.data import Dataset, IterableDataset
4 |
5 |
6 | class ConcatDataset(Dataset):
7 | def __init__(self, *datasets):
8 | super().__init__()
9 | self.datasets: List[Dataset] = datasets
10 |
11 | def __len__(self):
12 | return min([len(ds) for ds in self.datasets])
13 |
14 | def __getitem__(self, idx):
15 | return [ds[idx] for ds in self.datasets]
16 |
17 |
18 | class ConcatIterableDataset(IterableDataset):
19 | def __init__(self, *datasets):
20 | super().__init__()
21 | self.datasets: List[Dataset] = datasets
22 |
23 | def __iter__(self):
24 | return zip(*self.datasets)
25 |
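Despite the name, `ConcatDataset` here pairs items across datasets instead of chaining them end to end. A minimal stand-in without the PyTorch base class (`PairedSequence` is a hypothetical name) shows the behaviour on plain sequences.

```python
class PairedSequence:
    """Plain-Python stand-in that mirrors the indexing logic of ConcatDataset."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __len__(self):
        # The shortest dataset bounds the length, so every index is valid for all.
        return min(len(ds) for ds in self.datasets)

    def __getitem__(self, idx):
        # One item from each dataset at the same index.
        return [ds[idx] for ds in self.datasets]


if __name__ == "__main__":
    paired = PairedSequence([1, 2, 3], "ab")
    print(len(paired), paired[1])
    # The iterable variant relies on zip, which also stops at the shortest input.
    print(list(zip([1, 2, 3], "ab")))
```

This is why `get_dataloader` requests an OOD dataset of length `len(ds)`: a shorter OOD dataset would silently truncate the in-distribution data.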
--------------------------------------------------------------------------------
/uncertainty_est/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
6 | sys.path.insert(0, os.getcwd())
7 |
8 | import logging
9 | from copy import copy
10 | from pathlib import Path
11 | from argparse import ArgumentParser
12 |
13 | import numpy as np
14 | import pandas as pd
15 |
16 | from uncertainty_est.data.dataloaders import get_dataloader
17 | from uncertainty_est.models import load_checkpoint, load_model
18 |
19 |
20 | parser = ArgumentParser()
21 | parser.add_argument("--checkpoint", type=str, action="append", default=[])
22 | parser.add_argument("--dataset", type=str)
23 | parser.add_argument("--ood_dataset", type=str, action="append")
24 | parser.add_argument("--eval-classification", action="store_true")
25 | parser.add_argument("--output-folder", type=str)
26 | parser.add_argument("--name", type=str, default="out")
27 | parser.add_argument("--max-eval", type=int, default=10_000)
28 | parser.add_argument("--checkpoint-dir", type=str)
29 | parser.add_argument("--config-entries", type=str, action="append", default=[])
30 |
31 |
32 | logger = logging.getLogger()
33 | logger.setLevel(logging.INFO)
34 |
35 |
36 | def eval_model(
37 | model,
38 | dataset,
39 | ood_datasets,
40 | eval_classification=False,
41 | model_name="",
42 | batch_size=128,
43 | max_items=-1,
44 | data_shape=None,
45 | **kwargs,
46 | ):
47 | id_test_loader = get_dataloader(
48 | dataset, "test", batch_size=batch_size, data_shape=data_shape, **kwargs
49 | )
50 |
51 | clf_accum = []
52 | if eval_classification:
53 | clf_results = model.eval_classifier(id_test_loader, max_items)
54 | for k, v in clf_results.items():
55 | logger.info(f"{k}: {v:.02f}")
56 | clf_accum.append((model_name, model.__class__.__name__, dataset, k, v))
57 |
58 | test_ood_dataloaders = []
59 | for test_ood_dataset in ood_datasets:
60 | loader = get_dataloader(
61 | test_ood_dataset,
62 | "test",
63 | data_shape=id_test_loader.dataset.data_shape,
64 | batch_size=batch_size,
65 | **kwargs,
66 | )
67 | test_ood_dataloaders.append((test_ood_dataset, loader))
68 |
69 | ood_results = model.eval_ood(id_test_loader, test_ood_dataloaders)
70 |
71 | accum = []
72 | for k, v in ood_results.items():
73 | logger.info(f"{k}: {v:.02f}")
74 | accum.append((model_name, model.__class__.__name__, dataset, *k, v))
75 |
76 | return accum, clf_accum
77 |
78 |
79 | if __name__ == "__main__":
80 | args = parser.parse_args()
81 | base_args = copy(args)
82 |
83 | ood_tbl_rows = []
84 | clf_tbl_rows = []
85 | for checkpoint in args.checkpoint:
86 | # Reset args to original
87 | args = copy(base_args)
88 | checkpoint_path = Path(checkpoint)
89 | model_name = checkpoint_path.parent.stem
90 |
91 | model, config = load_checkpoint(checkpoint_path, strict=False)
92 | model.eval()
93 | model.cuda()
94 |
95 | if not args.ood_dataset:
96 | args.ood_dataset = config["test_ood_datasets"]
97 |
98 | if not args.dataset:
99 | args.dataset = config["dataset"]
100 |
101 | ood_rows, clf_rows = eval_model(
102 | model,
103 | args.dataset,
104 | args.ood_dataset,
105 | args.eval_classification,
106 | model_name=model_name,
107 | batch_size=128,
108 | max_items=args.max_eval,
109 | normalize=config["normalize"] if "normalize" in config else True,
110 | data_shape=config["data_shape"],
111 | )
112 | ood_tbl_rows.extend(ood_rows)
113 | clf_tbl_rows.extend(clf_rows)
114 |
115 | if args.checkpoint_dir:
116 | checkpoint_dir = Path(args.checkpoint_dir)
117 | for model_dir in checkpoint_dir.glob("**/version_*"):
118 | # Reset args to original
119 | args = copy(base_args)
120 | try:
121 | model, config = load_model(model_dir, last=False, strict=False)
122 | except Exception as e:
123 | logger.info(str(e))
124 | continue
125 | model.eval()
126 | model.cuda()
127 |
128 | if not args.ood_dataset:
129 | args.ood_dataset = config["test_ood_datasets"]
130 |
131 | if not args.dataset:
132 | args.dataset = config["dataset"]
133 |
134 | ood_rows, clf_rows = eval_model(
135 | model,
136 | args.dataset,
137 | args.ood_dataset,
138 | args.eval_classification,
139 | model_name=model_dir.parent.stem,
140 | batch_size=128,
141 | max_items=args.max_eval,
142 | normalize=config["normalize"] if "normalize" in config else True,
143 | data_shape=config["data_shape"],
144 | )
145 |
146 | extra_cols = []
147 | for e in args.config_entries:
148 | out = config
149 | for key in e.split("."):
150 | out = out.get(key, np.nan)
151 | if out is np.nan:  # NaN never compares equal to itself, so check identity with the default
152 | break
153 |
154 | if isinstance(out, dict):
155 | raise ValueError("Error getting config entry")
156 | extra_cols.append(out)
157 | ood_tbl_rows.extend([[*row, *extra_cols] for row in ood_rows])
158 | clf_tbl_rows.extend([[*row, *extra_cols] for row in clf_rows])
159 |
160 | extra_row_names = [s.split(".")[-1] for s in args.config_entries]
161 | if args.output_folder:
162 | output_folder = Path(args.output_folder)
163 | ood_df = pd.DataFrame(
164 | ood_tbl_rows,
165 | columns=(
166 | "Model",
167 | "Model Type",
168 | "ID dataset",
169 | "OOD dataset",
170 | "Score",
171 | "Metric",
172 | "Value",
173 | *extra_row_names,
174 | ),
175 | )
176 | ood_df.to_csv(output_folder / f"ood-{args.name}.csv", index=False)
177 |
178 | if args.eval_classification:
179 | clf_df = pd.DataFrame(
180 | clf_tbl_rows,
181 | columns=(
182 | "Model",
183 | "Model Type",
184 | "ID dataset",
185 | "Metric",
186 | "Value",
187 | *extra_row_names,
188 | ),
189 | )
190 | clf_df.to_csv(output_folder / f"clf-{args.name}.csv", index=False)
191 |
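The `--config-entries` option resolves dotted keys such as `arch_config.depth` against the nested config. A self-contained sketch of that lookup follows; `lookup` is a hypothetical helper, not the script's own code, and unlike the script it does not raise when the result is still a dictionary.

```python
import math


def lookup(config, dotted_key, default=float("nan")):
    """Walk a nested dict along a dotted key, returning `default` if a part is missing."""
    out = config
    for key in dotted_key.split("."):
        if not isinstance(out, dict) or key not in out:
            return default
        out = out[key]
    return out


if __name__ == "__main__":
    config = {"model_name": "VERA", "arch_config": {"depth": 28}}
    print(lookup(config, "arch_config.depth"))
    print(lookup(config, "arch_config.width"))
```

Missing entries come back as NaN, which has to be detected with `math.isnan` or an identity check because NaN is not equal to itself.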
--------------------------------------------------------------------------------
/uncertainty_est/models/__init__.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from pathlib import Path
3 |
4 | from uncertainty_est.models.ebm.mcmc import MCMC
5 | from uncertainty_est.models.ebm.discrete_mcmc import DiscreteMCMC
6 | from uncertainty_est.models.ce_baseline import CEBaseline
7 | from uncertainty_est.models.energy_finetuning import EnergyFinetune
8 | from uncertainty_est.models.ebm.vera import VERA
9 | from uncertainty_est.models.normalizing_flow.norm_flow import NormalizingFlow
10 | from uncertainty_est.models.ebm.conditional_nce import PerSampleNCE
11 | from .normalizing_flow.image_flows import RealNVPModel, GlowModel
12 | from .ebm.nce import NoiseContrastiveEstimation
13 | from .ebm.flow_contrastive_estimation import FlowContrastiveEstimation
14 | from .ebm.ssm import SSM
15 |
16 |
17 | MODELS = {
18 | "JEM": MCMC,
19 | "DiscreteMCMC": DiscreteMCMC,
20 | "CEBaseline": CEBaseline,
21 | "EnergyOOD": EnergyFinetune,
22 | "VERA": VERA,
23 | "NormalizingFlow": NormalizingFlow,
24 | "PerSampleNCE": PerSampleNCE,
25 | "RealNVP": RealNVPModel,
26 | "Glow": GlowModel,
27 | "NCE": NoiseContrastiveEstimation,
28 | "FlowCE": FlowContrastiveEstimation,
29 | "SSM": SSM,
30 | }
31 |
32 |
33 | def load_model(model_folder: Path, last=False, *args, **kwargs):
34 | ckpts = [file for file in model_folder.iterdir() if file.suffix == ".ckpt"]
35 | last_ckpt = [ckpt for ckpt in ckpts if ckpt.stem == "last"]
36 | best_ckpt = sorted([ckpt for ckpt in ckpts if ckpt.stem != "last"])
37 |
38 | if last:
39 | ckpts = best_ckpt + last_ckpt
40 | else:
41 | ckpts = last_ckpt + best_ckpt
42 | assert len(ckpts) > 0
43 |
44 | checkpoint_path = ckpts[-1]
45 | return load_checkpoint(checkpoint_path, *args, **kwargs)
46 |
47 |
48 | def load_checkpoint(checkpoint_path: Path, *args, **kwargs):
49 | with (checkpoint_path.parent / "config.yaml").open("r") as f:
50 | config = yaml.load(f, Loader=yaml.FullLoader)
51 |
52 | model = MODELS[config["model_name"]].load_from_checkpoint(
53 | checkpoint_path, *args, **kwargs
54 | )
55 | return model, config
56 |
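The checkpoint choice in `load_model` can be tried out on plain path objects. The function below (`pick_checkpoint` is a hypothetical name) follows the same ordering rule but raises an exception instead of using `assert`.

```python
from pathlib import PurePosixPath


def pick_checkpoint(paths, last=False):
    """Pick `last.ckpt` if `last` is set, otherwise the highest-sorting other checkpoint."""
    ckpts = [p for p in paths if p.suffix == ".ckpt"]
    last_ckpt = [p for p in ckpts if p.stem == "last"]
    best_ckpt = sorted(p for p in ckpts if p.stem != "last")
    # The preferred kind goes to the end of the list; the other kind is the fallback.
    ordered = best_ckpt + last_ckpt if last else last_ckpt + best_ckpt
    if not ordered:
        raise FileNotFoundError("no .ckpt files found")
    return ordered[-1]


if __name__ == "__main__":
    files = [PurePosixPath(f"run/{n}") for n in ("epoch=1.ckpt", "last.ckpt", "epoch=2.ckpt")]
    print(pick_checkpoint(files), pick_checkpoint(files, last=True))
```

Note that the sort is lexicographic, so a file named `epoch=10.ckpt` sorts before `epoch=2.ckpt`; zero-padded epoch numbers avoid that surprise.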
--------------------------------------------------------------------------------
/uncertainty_est/models/ce_baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from uncertainty_est.archs.arch_factory import get_arch
5 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
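6 | from uncertainty_est.utils.dirichlet import dirichlet_prior_network_uncertainty  # assumed module path; needed by get_ood_scores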
7 |
8 | class CEBaseline(OODDetectionModel):
9 | def __init__(
10 | self, arch_name, arch_config, learning_rate, momentum, weight_decay, **kwargs
11 | ):
12 | super().__init__(**kwargs)
13 | self.__dict__.update(locals())
14 | self.save_hyperparameters()
15 |
16 | self.backbone = get_arch(arch_name, arch_config)
17 |
18 | def forward(self, x):
19 | return self.backbone(x)
20 |
21 | def training_step(self, batch, batch_idx):
22 | x, y = batch
23 | y_hat = self(x)
24 |
25 | loss = F.cross_entropy(y_hat, y)
26 | self.log("train/loss", loss)
27 | return loss
28 |
29 | def validation_step(self, batch, batch_idx):
30 | x, y = batch
31 | y_hat = self(x)
32 |
33 | loss = F.cross_entropy(y_hat, y)
34 | self.log("val/loss", loss)
35 |
36 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
37 | self.log("val/acc", acc)
38 |
39 | def test_step(self, batch, batch_idx):
40 | x, y = batch
41 | y_hat = self(x)
42 |
43 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
44 | self.log("test/acc", acc)
45 |
46 | def configure_optimizers(self):
47 | optim = torch.optim.AdamW(
48 | self.parameters(),
49 | betas=(self.momentum, 0.999),
50 | lr=self.learning_rate,
51 | weight_decay=self.weight_decay,
52 | )
53 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.5)
54 | return [optim], [scheduler]
55 |
56 | def classify(self, x):
57 | return torch.softmax(self.backbone(x), -1)
58 |
59 | def get_ood_scores(self, x):
60 | logits = self(x).cpu()
61 | dir_uncert = dirichlet_prior_network_uncertainty(logits)
62 | dir_uncert["p(x)"] = logits.logsumexp(1)
63 | dir_uncert["max p(y|x)"] = logits.softmax(1).max(1)[0]
64 | return dir_uncert
65 |
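Two of the OOD scores returned by `get_ood_scores` are simple functions of the logits: the log-sum-exp (stored under the key "p(x)") and the maximum softmax probability. A plain-Python sketch for a single logit vector, with hypothetical helper names, looks like this.

```python
import math


def logsumexp(logits):
    """Numerically stable log(sum(exp(logits)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))


def max_softmax(logits):
    """Largest softmax probability, computed via the log-sum-exp."""
    lse = logsumexp(logits)
    return max(math.exp(v - lse) for v in logits)


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]
    print(logsumexp(logits), max_softmax(logits))
```

Subtracting the maximum before exponentiating keeps the computation finite even for very large logits, which is the same trick the tensor `logsumexp` method relies on.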
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/models/ebm/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/conditional_nce.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 |
4 | from uncertainty_est.archs.arch_factory import get_arch
5 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
6 |
7 |
8 | class PerSampleNCE(OODDetectionModel):
9 | def __init__(
10 | self,
11 | arch_name,
12 | arch_config,
13 | learning_rate,
14 | momentum,
15 | weight_decay,
16 | noise_sigma=0.01,
17 | p_control_weight=0.0,
18 | ):
19 | super().__init__()
20 | self.__dict__.update(locals())
21 | self.save_hyperparameters()
22 |
23 | self.model = get_arch(arch_name, arch_config)
24 |
25 | def forward(self, x):
26 | return self.model(x)
27 |
28 | def training_step(self, batch, batch_idx):
29 | x, _ = batch
30 |
31 | sample_shape = x.shape[1:]
32 | sample_dim = sample_shape.numel()
33 | noise_dist = distributions.multivariate_normal.MultivariateNormal(
34 | torch.zeros(sample_dim).to(self.device),
35 | torch.eye(sample_dim).to(self.device) * self.noise_sigma,
36 | )
37 | noise = noise_dist.sample(x.size()[:1])
38 |
39 | # Implements Eq. 9 in "Conditional Noise-Contrastive Estimation of Unnormalised Models"
40 | # Uses symmetry of noise distribution meaning p(u1|u2) = p(u2|u1) to simplify
41 | # Sets k = 1
42 | x_noisy = x + noise.reshape_as(x)
43 | log_p_model = self.model(torch.cat((x, x_noisy))).squeeze()
44 | log_p_x = log_p_model[: len(x)]
45 | log_p_x_noisy = log_p_model[len(x) :]
46 |
47 | loss = torch.nn.functional.softplus(log_p_x_noisy - log_p_x).mean()
48 |
49 | p_control = log_p_model.abs().mean()
50 | loss += self.p_control_weight * p_control
51 |
52 | self.log("train/log_p_magnitude", log_p_x.mean(), prog_bar=True)
53 | self.log("train/log_p_noisy_magnitude", log_p_x_noisy.mean(), prog_bar=True)
54 | self.log("train/loss", loss)
55 | return loss
56 |
57 | def validation_step(self, batch, batch_idx):
58 | return
59 |
60 | def test_step(self, batch, batch_idx):
61 | self.to(torch.float32)
62 | x, y = batch
63 | y_hat = self.model(x)
64 |
65 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
66 | self.log("test_acc", acc)
67 | return y_hat
68 |
69 | def configure_optimizers(self):
70 | optim = torch.optim.AdamW(
71 | self.parameters(),
72 | betas=(self.momentum, 0.999),
73 | lr=self.learning_rate,
74 | weight_decay=self.weight_decay,
75 | )
76 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.5)
77 | return [optim], [scheduler]
78 |
79 | def get_ood_scores(self, x):
80 | return {"p(x)": self.model(x)}
81 |
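The training loss is log(1 + exp(-(log p(x) - log p(x_noisy)))), which equals a softplus of the difference log p(x_noisy) - log p(x). The sketch below evaluates it on plain floats in a numerically stable way; the function names are hypothetical.

```python
import math


def softplus(z):
    """Stable log(1 + exp(z)) that avoids overflow for large |z|."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def cnce_loss(log_p_x, log_p_x_noisy):
    """Mean of log(1 + exp(-(log_p_x - log_p_x_noisy))) over paired samples."""
    terms = [softplus(n - d) for d, n in zip(log_p_x, log_p_x_noisy)]
    return sum(terms) / len(terms)


if __name__ == "__main__":
    print(cnce_loss([0.0, 3.0], [0.0, -1.0]))
```

The loss is log(2) when the model cannot tell data from its noisy copy and approaches zero as the model assigns much higher scores to the data.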
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/discrete_mcmc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.distributions import Categorical
4 |
5 | from uncertainty_est.models.ebm.mcmc import MCMC
6 |
7 |
8 | def init_random(buffer_size, data_shape, dim):
9 | buffer = torch.FloatTensor(buffer_size, *data_shape).random_(0, dim)
10 | buffer = F.one_hot(buffer.long(), num_classes=dim).float()
11 | return buffer
12 |
13 |
14 | class DiscreteMCMC(MCMC):
15 | def __init__(
16 | self,
17 | arch_name,
18 | arch_config,
19 | learning_rate,
20 | momentum,
21 | weight_decay,
22 | buffer_size,
23 | n_classes,
24 | data_shape,
25 | smoothing,
26 | pyxce,
27 | pxsgld,
28 | pxysgld,
29 | class_cond_p_x_sample,
30 | sgld_batch_size,
31 | reinit_freq,
32 | num_cat,
33 | sgld_steps=20,
34 | entropy_reg_weight=0.0,
35 | warmup_steps=0,
36 | lr_step_size=50,
37 | **kwargs
38 | ):
39 | kwargs.pop("sgld_lr")
40 | kwargs.pop("sgld_std")
41 | self.num_cat = num_cat
42 | super().__init__(
43 | arch_name=arch_name,
44 | arch_config=arch_config,
45 | learning_rate=learning_rate,
46 | momentum=momentum,
47 | weight_decay=weight_decay,
48 | buffer_size=buffer_size,
49 | n_classes=n_classes,
50 | data_shape=data_shape,
51 | smoothing=smoothing,
52 | pyxce=pyxce,
53 | pxsgld=pxsgld,
54 | pxysgld=pxysgld,
55 | class_cond_p_x_sample=class_cond_p_x_sample,
56 | sgld_batch_size=sgld_batch_size,
57 | sgld_lr=0.0,
58 | sgld_std=0.0,
59 | reinit_freq=reinit_freq,
60 | sgld_steps=sgld_steps,
61 | entropy_reg_weight=entropy_reg_weight,
62 | warmup_steps=warmup_steps,
63 | lr_step_size=lr_step_size,
64 | **kwargs
65 | )
66 | self.save_hyperparameters()
67 |
68 | def _init_buffer(self):
69 | self.replay_buffer = init_random(
70 | self.buffer_size, self.data_shape, self.num_cat
71 | ).cpu()
72 |
73 | def sample_p_0(self, replay_buffer, bs, y=None):
74 | if len(replay_buffer) == 0:
75 | return init_random(bs, self.data_shape, self.num_cat), []
76 |
77 | buffer_size = (
78 | len(replay_buffer) if y is None else len(replay_buffer) // self.n_classes
79 | )
80 | inds = torch.randint(0, buffer_size, (bs,))
81 | # if cond, convert inds to class conditional inds
82 | if y is not None:
83 | inds = y.cpu() * buffer_size + inds
84 |
85 | buffer_samples = replay_buffer[inds].to(self.device)
86 | if self.reinit_freq > 0.0:
87 | random_samples = init_random(bs, self.data_shape, self.num_cat).to(
88 | self.device
89 | )
90 | choose_random = (torch.rand(bs) < self.reinit_freq).to(buffer_samples)[
91 | (...,) + (None,) * (len(self.data_shape) + 1)
92 | ]
93 | samples = (
94 | choose_random * random_samples + (1 - choose_random) * buffer_samples
95 | )
96 | else:
97 | samples = buffer_samples
98 | return samples.to(self.device), inds
99 |
100 | def sample_q(self, replay_buffer, y=None, n_steps=20, contrast=False):
101 | self.model.eval()
102 | bs = self.sgld_batch_size if y is None else y.size(0)
103 |
104 | # generate initial samples and buffer inds of those samples (if buffer is used)
105 | init_sample, buffer_inds = self.sample_p_0(replay_buffer, bs=bs, y=y)
106 | x_k = torch.autograd.Variable(init_sample, requires_grad=True)
107 |
108 | # Gibbs-with-Gradients proposal (http://arxiv.org/abs/2102.04509)
109 | for _ in range(n_steps):
110 | energy = self.model(x_k, y=y)
111 | f_prime = torch.autograd.grad(energy.sum(), [x_k], retain_graph=True)[0]
112 |
113 | d = f_prime - (x_k * f_prime).sum(-1, keepdim=True)
114 | q_i_given_x = Categorical(logits=(d / 2.0).flatten(start_dim=1))
115 | i = q_i_given_x.sample()
116 | prob_i_given_x = q_i_given_x.log_prob(i).exp()
117 |
118 | # Flip sampled dimension
119 | x_q_idx = x_k.argmax(-1)
120 | x_q_idx.flatten(start_dim=1)[torch.arange(len(x_k)), i // self.num_cat] = (
121 | i % self.num_cat
122 | )
123 | x_q = F.one_hot(x_q_idx, self.num_cat).float()
124 | x_q.requires_grad_()
125 |
126 | energy_q = self.model(x_q, y=y)
127 | f_prime = torch.autograd.grad(energy_q.sum(), [x_q], retain_graph=True)[0]
128 | d = f_prime - (x_q * f_prime).sum(-1, keepdim=True)
129 | q_i_given_x = Categorical(logits=(d / 2.0).flatten(start_dim=1))
130 | prob_i_given_x_q = q_i_given_x.log_prob(i).exp()
131 |
132 | # Update samples depending on the Metropolis-Hastings acceptance probability
133 | keep_prob = torch.exp(energy_q - energy) * (
134 | prob_i_given_x_q / prob_i_given_x
135 | )
136 | keep_prob[keep_prob > 1.0] = 1.0
137 | keep = torch.rand(len(x_k)).to(self.device) < keep_prob
138 |
139 | x_k = x_k.detach()
140 | x_k[keep] = x_q[keep]
141 |
142 | self.model.train()
143 | final_samples = x_k.detach()
144 |
145 | # update replay buffer
146 | if len(replay_buffer) > 0:
147 | replay_buffer[buffer_inds] = final_samples.cpu()
148 | return final_samples
149 |
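Note: the accept/reject rule in `sample_q` above has the usual Metropolis-Hastings form. The following is a minimal, dependency-free sketch of that rule for a single proposed move. The name `mh_accept_prob` and its arguments are illustrative and not part of the repository; `log_p_cur` and `log_p_prop` stand in for the unnormalized log-densities that `self.model` returns.

```python
import math


def mh_accept_prob(log_p_cur, log_p_prop, q_forward, q_reverse):
    """Metropolis-Hastings acceptance probability for one proposed move.

    log_p_cur, log_p_prop: unnormalized log-densities of the current and
        the proposed state.
    q_forward: probability of proposing this move from the current state.
    q_reverse: probability of proposing the reverse move from the proposed state.
    """
    ratio = math.exp(log_p_prop - log_p_cur) * (q_reverse / q_forward)
    # Clip at 1, as the listing does with keep_prob[keep_prob > 1.0] = 1.0
    return min(1.0, ratio)


if __name__ == "__main__":
    # A move to a less likely state under a symmetric proposal is accepted
    # only some of the time.
    print(mh_accept_prob(log_p_cur=1.0, log_p_prop=0.0, q_forward=0.5, q_reverse=0.5))
```

A move is then kept when a uniform random number in [0, 1) falls below this probability, which is what the `keep` mask in the listing implements per batch element.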
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/flow_contrastive_estimation.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import matplotlib.pyplot as plt
6 |
7 | from uncertainty_est.utils.utils import to_np
8 | from uncertainty_est.archs.arch_factory import get_arch
9 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
10 |
11 |
12 | class FlowContrastiveEstimation(OODDetectionModel):
13 | """Implementation of Flow Contrastive Estimation: Noise Contrastive Estimation
14 | (http://proceedings.mlr.press/v9/gutmann10a.html) with a jointly trained flow as the noise distribution.
15 | """
16 |
17 | def __init__(
18 | self,
19 | arch_name,
20 | arch_config,
21 | flow_arch_name,
22 | flow_arch_config,
23 | learning_rate,
24 | momentum,
25 | weight_decay,
26 | flow_learning_rate,
27 | rho=0.5,
28 | is_toy_dataset=False,
29 | toy_dataset_dim=2,
30 | ):
31 | super().__init__()
32 | self.automatic_optimization = False
33 | self.__dict__.update(locals())
34 | self.save_hyperparameters()
35 |
36 | self.model = get_arch(arch_name, arch_config)
37 | self.noise_dist = get_arch(flow_arch_name, flow_arch_config)
38 |
39 | def forward(self, x):
40 | return self.model(x)
41 |
42 | def training_step(self, batch, batch_idx, optimizer_idx):
43 | optim_ebm, optim_flow = self.optimizers()
44 | x, _ = batch
45 |
46 | noise, neg_f_log_p = self.noise_dist.sample(len(x))
47 | neg_e_log_p = self.model(noise.detach()).logsumexp(-1)
48 |
49 | target = torch.zeros_like(neg_e_log_p).long().to(self.device)
50 | neg_pred = torch.stack(
51 | (neg_f_log_p + math.log(1 - self.rho), neg_e_log_p + math.log(self.rho)), -1
52 | )
53 | neg_loss = F.cross_entropy(neg_pred, target)
54 |
55 | pos_f_log_p = self.noise_dist.log_prob(x)
56 | pos_e_log_p = self.model(x).logsumexp(-1)
57 |
58 | target = torch.ones_like(pos_e_log_p).long().to(self.device)
59 | pos_pred = torch.stack(
60 | (pos_f_log_p + math.log(1 - self.rho), pos_e_log_p + math.log(self.rho)), -1
61 | )
62 | pos_loss = F.cross_entropy(pos_pred, target)
63 |
64 | loss = self.rho * pos_loss + (1 - self.rho) * neg_loss
65 |
66 | pos_acc = (pos_e_log_p > pos_f_log_p).float().mean()
67 | neg_acc = (neg_f_log_p > neg_e_log_p).float().mean()
68 |
69 | self.log("train/pos_acc", pos_acc, prog_bar=True)
70 | self.log("train/neg_acc", neg_acc, prog_bar=True)
71 |
72 | if pos_acc < 0.55 or neg_acc < 0.55:
73 | optim_ebm.zero_grad()
74 | self.manual_backward(loss)
75 | optim_ebm.step()
76 | self.log("train/loss", loss, prog_bar=True)
77 | else:
78 | optim_flow.zero_grad()
79 | self.manual_backward(-pos_loss)
80 | optim_flow.step()
81 | self.log("train/flow_loss", -pos_loss, prog_bar=True)  # log the quantity that is actually optimized
82 |
83 | def validation_step(self, batch, batch_idx):
84 | return
85 |
86 | def validation_epoch_end(self, outputs):
87 | if self.is_toy_dataset and self.toy_dataset_dim == 2:
88 | interp = torch.linspace(-4, 4, 500)
89 | x, y = torch.meshgrid(interp, interp)
90 | data = torch.stack((x.reshape(-1), y.reshape(-1)), 1).to(self.device)
91 | p_xy = torch.exp(self(data))
92 | px = to_np(p_xy.sum(1))
93 | flow_px = to_np(self.noise_dist.log_prob(data).exp())
94 |
95 | x, y = to_np(x), to_np(y)
96 | for i in range(p_xy.shape[1]):
97 | fig, ax = plt.subplots()
98 | mesh = ax.pcolormesh(x, y, to_np(p_xy[:, i]).reshape(*x.shape))
99 | fig.colorbar(mesh)
100 | self.logger.experiment.add_figure(
101 | f"dist/p(x,y={i})", fig, self.current_epoch
102 | )
103 | plt.close()
104 |
105 | fig, ax = plt.subplots()
106 | mesh = ax.pcolormesh(x, y, px.reshape(*x.shape))
107 | fig.colorbar(mesh)
108 | self.logger.experiment.add_figure("dist/p(x)", fig, self.current_epoch)
109 | plt.close()
110 |
111 | fig, ax = plt.subplots()
112 | samples = to_np(self.noise_dist.sample(1000)[0])
113 | mesh = ax.scatter(samples[:, 0], samples[:, 1])
114 | self.logger.experiment.add_figure(
115 | "dist/flow_samples", fig, self.current_epoch
116 | )
117 | plt.close()
118 |
119 | fig, ax = plt.subplots()
120 | mesh = ax.pcolormesh(x, y, flow_px.reshape(*x.shape))
121 | fig.colorbar(mesh)
122 | self.logger.experiment.add_figure("dist/Flow p(x)", fig, self.current_epoch)
123 | plt.close()
124 |
125 | def test_step(self, batch, batch_idx):
126 | self.to(torch.float32)
127 | x, y = batch
128 | y_hat = self.model(x)
129 |
130 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
131 | self.log("test_acc", acc)
132 | return y_hat
133 |
134 | def configure_optimizers(self):
135 | optim = torch.optim.AdamW(
136 | self.model.parameters(),
137 | betas=(self.momentum, 0.999),
138 | lr=self.learning_rate,
139 | weight_decay=self.weight_decay,
140 | )
141 | optim_flow = torch.optim.AdamW(
142 | self.noise_dist.parameters(),
143 | betas=(self.momentum, 0.999),
144 | lr=self.flow_learning_rate,
145 | weight_decay=self.weight_decay,
146 | )
147 | return [optim, optim_flow]
148 |
149 | def get_ood_scores(self, x):
150 | return {"p(x)": self.model(x)}
151 |
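Note: `training_step` above builds a two-class problem per sample, with the flow logit `log q(x) + log(1 - rho)` stacked first and the EBM logit `log p(x) + log(rho)` stacked second. The sketch below reproduces that loss for a single sample in plain Python, assuming scalar log-densities as inputs. The helper names `log_softmax2` and `fce_losses` are hypothetical and not taken from the repository.

```python
import math


def log_softmax2(a, b):
    """Log-softmax over the two logits (a, b), computed stably."""
    m = max(a, b)
    log_z = m + math.log(math.exp(a - m) + math.exp(b - m))
    return a - log_z, b - log_z


def fce_losses(log_p_ebm, log_q_flow, rho=0.5):
    """Return (loss if x is a data sample, loss if x is a flow sample).

    Class 0 means "drawn from the flow", class 1 means "drawn from the data",
    matching the stacking order used in the listing.
    """
    flow_logit = log_q_flow + math.log(1.0 - rho)
    ebm_logit = log_p_ebm + math.log(rho)
    log_prob_flow, log_prob_data = log_softmax2(flow_logit, ebm_logit)
    return -log_prob_data, -log_prob_flow


if __name__ == "__main__":
    pos_loss, neg_loss = fce_losses(log_p_ebm=-1.0, log_q_flow=-1.0, rho=0.5)
    print(pos_loss, neg_loss)
```

When both densities agree and `rho` is 0.5, the discriminator cannot tell the two sources apart and both losses equal `log 2`, which is the equilibrium the alternating updates push toward.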
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/hdge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import torch.nn.functional as F
4 |
5 | from uncertainty_est.utils.utils import to_np
6 | from uncertainty_est.archs.arch_factory import get_arch
7 | from uncertainty_est.models.ebm.utils.model import HDGE
8 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
9 | from uncertainty_est.models.ebm.utils.utils import (
10 | KHotCrossEntropyLoss,
11 | smooth_one_hot,
12 | )
13 |
14 |
15 | class HDGEModel(OODDetectionModel):
16 | def __init__(
17 | self,
18 | arch_name,
19 | arch_config,
20 | learning_rate,
21 | momentum,
22 | weight_decay,
23 | pyxce,
24 | pxcontrast,
25 | pxycontrast,
26 | smoothing,
27 | n_classes,
28 | contrast_k,
29 | contrast_t,
30 | warmup_steps=-1,
31 | ):
32 | super().__init__()
33 | self.__dict__.update(locals())
34 | self.save_hyperparameters()
35 |
36 | arch = get_arch(arch_name, arch_config)
37 | self.model = HDGE(arch, n_classes, contrast_k, contrast_t)
38 |
39 | def forward(self, x):
40 | return self.model(x)
41 |
42 | def compute_losses(self, x_lab, dist, logits=None, evaluation=False):
43 | l_pyxce, l_pxcontrast, l_pxycontrast = 0.0, 0.0, 0.0
44 | # log p(y|x) cross entropy loss
45 | if self.pyxce > 0:
46 | if logits is None:
47 | logits = self.model.classify(x_lab)
48 | l_pyxce = KHotCrossEntropyLoss()(logits, dist)
49 | l_pyxce *= self.pyxce
50 |
51 | # log p(x) using contrastive learning
52 | if self.pxcontrast > 0:
53 | # ones like dist to use all indexes
54 | ones_dist = torch.ones_like(dist).to(self.device)
55 | output, target, _, _ = self.model.joint(
56 | img=x_lab, dist=ones_dist, evaluation=evaluation
57 | )
58 | l_pxcontrast = F.cross_entropy(output, target)
59 | l_pxcontrast *= self.pxcontrast
60 |
61 | # log p(x|y) using contrastive learning
62 | if self.pxycontrast > 0:
63 | output, target, _, _ = self.model.joint(
64 | img=x_lab, dist=dist, evaluation=evaluation
65 | )
66 | l_pxycontrast = F.cross_entropy(output, target)
67 | l_pxycontrast *= self.pxycontrast
68 |
69 | return l_pyxce, l_pxcontrast, l_pxycontrast
70 |
71 | def training_step(self, batch, batch_idx):
72 | x_lab, y_lab = batch
73 | dist = smooth_one_hot(y_lab, self.n_classes, self.smoothing)
74 |
75 | loss = sum(self.compute_losses(x_lab, dist))
76 | return loss
77 |
78 | def validation_step(self, batch, batch_idx):
79 | x, y = batch
80 | logits = self.model.classify(x)
81 | dist = smooth_one_hot(y, self.n_classes, self.smoothing)
82 |
83 | self.log(
84 | "val/loss",
85 | sum(self.compute_losses(x, dist, logits=logits, evaluation=True)),
86 | )
87 |
88 | acc = (y == logits.argmax(1)).float().mean(0).item()
89 | self.log("val_acc", acc)
90 |
91 | def validation_epoch_end(self, training_step_outputs):
92 | if getattr(self, "vis_every", 0) > 0 and self.current_epoch % self.vis_every == 0:
93 | interp = torch.linspace(-10, 10, 500)
94 | x, y = torch.meshgrid(interp, interp)
95 | data = torch.stack((x.reshape(-1), y.reshape(-1)), 1)
96 |
97 | log_px = to_np(self.model(data.to(self.device)))
98 |
99 | fig, ax = plt.subplots()
100 | ax.set_title("log p(x)")
101 | mesh = ax.pcolormesh(to_np(x), to_np(y), log_px.reshape(*x.shape))
102 | fig.colorbar(mesh)
103 | self.logger.experiment.add_figure("log p(x)", fig, self.current_epoch)
104 | plt.close()
105 |
106 | def test_step(self, batch, batch_idx):
107 | x, y = batch
108 | y_hat = self(x)
109 |
110 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
111 | self.log("test_acc", acc)
112 |
113 | def configure_optimizers(self):
114 | optim = torch.optim.AdamW(
115 | self.parameters(),
116 | betas=(self.momentum, 0.999),
117 | lr=self.learning_rate,
118 | weight_decay=self.weight_decay,
119 | )
120 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.5)
121 | return [optim], [scheduler]
122 |
123 | def classify(self, x):
124 | return torch.softmax(self.model.classify(x), -1)
125 |
126 | def get_ood_scores(self, x):
127 | return {"p(x)": self.model(x)}
128 |
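Note: both contrastive terms in `compute_losses` above reduce to a cross-entropy in which the positive logit sits at index 0 and the queue entries act as negatives. A minimal sketch of that loss for one sample follows, written in plain Python. The function name `contrastive_loss` is an assumption made for illustration; the temperature argument plays the role of `contrast_t`.

```python
import math


def contrastive_loss(l_pos, l_neg, temperature=1.0):
    """Cross-entropy over [l_pos, *l_neg] / temperature with target index 0."""
    logits = [l_pos / temperature] + [v / temperature for v in l_neg]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    # -log softmax(logits)[0]
    return log_z - logits[0]


if __name__ == "__main__":
    # One positive against three equally scored negatives
    print(contrastive_loss(0.0, [0.0, 0.0, 0.0]))
```

With K negatives that all score the same as the positive, the loss is `log(K + 1)`; it decreases as the positive logit grows relative to the negatives, and a lower temperature sharpens that effect.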
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/mcmc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from matplotlib import pyplot as plt
3 |
4 | from uncertainty_est.utils.utils import to_np
5 | from uncertainty_est.models.ebm.utils.model import JEM
6 | from uncertainty_est.archs.arch_factory import get_arch
7 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
8 | from uncertainty_est.models.ebm.utils.utils import (
9 | KHotCrossEntropyLoss,
10 | smooth_one_hot,
11 | )
12 |
13 |
14 | def init_random(buffer_size, data_shape):
15 | return torch.FloatTensor(buffer_size, *data_shape).uniform_(-1, 1)
16 |
17 |
18 | class MCMC(OODDetectionModel):
19 | def __init__(
20 | self,
21 | arch_name,
22 | arch_config,
23 | learning_rate,
24 | momentum,
25 | weight_decay,
26 | buffer_size,
27 | n_classes,
28 | data_shape,
29 | smoothing,
30 | pyxce,
31 | pxsgld,
32 | pxysgld,
33 | class_cond_p_x_sample,
34 | sgld_batch_size,
35 | sgld_lr,
36 | sgld_std,
37 | reinit_freq,
38 | sgld_steps=20,
39 | entropy_reg_weight=0.0,
40 | warmup_steps=2500,
41 | lr_step_size=50,
42 | is_toy_dataset=False,
43 | **kwargs
44 | ):
45 | super().__init__(**kwargs)
46 | self.__dict__.update(locals())
47 | self.save_hyperparameters()
48 |
49 | if len(data_shape) == 3:
50 | self.sample_shape = [data_shape[-1], data_shape[0], data_shape[1]]
51 | else:
52 | self.sample_shape = data_shape
53 |
54 | if class_cond_p_x_sample:
55 | assert n_classes > 1
56 |
57 | arch = get_arch(arch_name, arch_config)
58 | self.model = JEM(arch)
59 |
60 | self.buffer_size = self.buffer_size - (self.buffer_size % self.n_classes)
61 | self._init_buffer()
62 |
63 | def forward(self, x):
64 | return self.model(x)
65 |
66 | def _init_buffer(self):
67 | self.replay_buffer = init_random(self.buffer_size, self.sample_shape).cpu()
68 |
69 | def training_step(self, batch, batch_idx):
70 | (x_lab, y_lab), (x_p_d, _) = batch
71 | if self.n_classes > 1:
72 | dist = smooth_one_hot(y_lab, self.n_classes, self.smoothing)
73 | else:
74 | dist = y_lab[None, :]
75 |
76 | l_pyxce, l_pxsgld, l_pxysgld = 0.0, 0.0, 0.0
77 | # log p(y|x) cross entropy loss
78 | if self.pyxce > 0:
79 | logits = self.model.classify(x_lab)
80 | l_pyxce = KHotCrossEntropyLoss()(logits, dist)
81 | l_pyxce *= self.pyxce
82 |
83 | l_pyxce += (
84 | self.entropy_reg_weight
85 | * -torch.distributions.Categorical(logits=logits).entropy().mean()
86 | )
87 | self.log("train/clf_loss", l_pyxce)
88 |
89 | # log p(x) using sgld
90 | if self.pxsgld > 0:
91 | if self.class_cond_p_x_sample:
92 | y_q = torch.randint(0, self.n_classes, (self.sgld_batch_size,)).to(
93 | self.device
94 | )
95 | x_q = self.sample_q(self.replay_buffer, y=y_q)
96 | else:
97 | x_q = self.sample_q(
98 | self.replay_buffer, n_steps=self.sgld_steps
99 | ) # sample from log-sumexp
100 |
101 | fp = self.model(x_p_d)
102 | fq = self.model(x_q)
103 | l_pxsgld = -(fp.mean() - fq.mean()) + (fp ** 2).mean() + (fq ** 2).mean()
104 | l_pxsgld *= self.pxsgld
105 |
106 | # log p(x|y) using sgld
107 | if self.pxysgld > 0:
108 | x_q_lab = self.sample_q(self.replay_buffer, y=y_lab)
109 | fp, fq = self.model(x_lab).mean(), self.model(x_q_lab).mean()
110 | l_pxysgld = -(fp - fq)
111 | l_pxysgld *= self.pxysgld
112 |
113 | loss = l_pxysgld + l_pxsgld + l_pyxce
114 | self.log("train/loss", loss)
115 | return loss
116 |
117 | def validation_step(self, batch, batch_idx):
118 | (x_lab, y_lab), (_, _) = batch
119 |
120 | if self.n_classes < 2:
121 | return
122 |
123 | _, logits = self.model(x_lab, return_logits=True)
124 | acc = (y_lab == logits.argmax(1)).float().mean(0).item()
125 | self.log("val/acc", acc)
126 |
127 | def validation_epoch_end(self, outputs):
128 | if self.is_toy_dataset:
129 | interp = torch.linspace(-4, 4, 500)
130 | x, y = torch.meshgrid(interp, interp)
131 | data = torch.stack((x.reshape(-1), y.reshape(-1)), 1).to(self.device)
132 | px = to_np(torch.exp(self(data)))
133 |
134 | fig, ax = plt.subplots()
135 | mesh = ax.pcolormesh(x, y, px.reshape(*x.shape))
136 | fig.colorbar(mesh)
137 | self.logger.experiment.add_figure("dist/p(x)", fig, self.current_epoch)
138 | plt.close()
139 | super().validation_epoch_end(outputs)
140 |
141 | def test_step(self, batch, batch_idx):
142 | x, y = batch
143 | y_hat = self.model.classify(x)
144 |
145 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
146 | self.log("test/acc", acc)
147 |
148 | return y_hat
149 |
150 | def configure_optimizers(self):
151 | optim = torch.optim.AdamW(
152 | self.parameters(),
153 | betas=(self.momentum, 0.999),
154 | lr=self.learning_rate,
155 | weight_decay=self.weight_decay,
156 | )
157 | scheduler = torch.optim.lr_scheduler.StepLR(
158 | optim, step_size=self.lr_step_size, gamma=0.5
159 | )
160 | return [optim], [scheduler]
161 |
162 | def classify(self, x):
163 | return torch.softmax(self.model.classify(x), -1)
164 |
165 | def get_ood_scores(self, x):
166 | return {"p(x)": self.model(x)}
167 |
168 | def sample_p_0(self, replay_buffer, bs, y=None):
169 | if len(replay_buffer) == 0:
170 | return init_random(bs, self.sample_shape), []
171 |
172 | buffer_size = (
173 | len(replay_buffer) if y is None else len(replay_buffer) // self.n_classes
174 | )
175 | inds = torch.randint(0, buffer_size, (bs,))
176 | # if cond, convert inds to class conditional inds
177 | if y is not None:
178 | inds = y.cpu() * buffer_size + inds
179 |
180 | buffer_samples = replay_buffer[inds].to(self.device)
181 | random_samples = init_random(bs, self.sample_shape).to(self.device)
182 | choose_random = (torch.rand(bs) < self.reinit_freq).to(buffer_samples)[
183 | (...,) + (None,) * len(self.sample_shape)
184 | ]
185 | samples = choose_random * random_samples + (1 - choose_random) * buffer_samples
186 | return samples.to(self.device), inds
187 |
188 | def sample_q(self, replay_buffer, y=None, n_steps=20, contrast=False):
189 | self.model.eval()
190 | bs = self.sgld_batch_size if y is None else y.size(0)
191 |
192 | # generate initial samples and buffer inds of those samples (if buffer is used)
193 | init_sample, buffer_inds = self.sample_p_0(replay_buffer, bs=bs, y=y)
194 | x_k = torch.autograd.Variable(init_sample, requires_grad=True)
195 |
196 | # sgld
197 | for _ in range(n_steps):
198 | if not contrast:
199 | energy = self.model(x_k, y=y).sum(0)
200 | else:
201 | if y is not None:
202 | dist = smooth_one_hot(y, self.n_classes, self.smoothing)
203 | else:
204 | dist = torch.ones((bs, self.n_classes)).to(self.device)
205 | output, target, _, _ = self.model.joint(
206 | img=x_k, dist=dist, evaluation=True
207 | )
208 | energy = -1.0 * torch.nn.functional.cross_entropy(output, target)  # F is not imported in this module
209 | f_prime = torch.autograd.grad(energy, [x_k], retain_graph=True)[0]
210 | x_k.data += self.sgld_lr * f_prime + self.sgld_std * torch.randn_like(x_k)
211 | self.model.train()
212 | final_samples = x_k.detach()
213 |
214 | # update replay buffer
215 | if len(replay_buffer) > 0:
216 | replay_buffer[buffer_inds] = final_samples.cpu()
217 | return final_samples
218 |
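Note: the SGLD loop in `sample_q` above applies the update `x <- x + sgld_lr * grad + sgld_std * noise`, with step size and noise scale set independently. The sketch below shows the same update on a one-dimensional toy target in plain Python. `sgld_sample` and `grad_log_standard_normal` are hypothetical names, and the standard normal target is an assumption chosen so the gradient is known in closed form.

```python
import random


def sgld_sample(x0, grad_log_p, n_steps=20, step_size=0.1, noise_std=0.01, rng=None):
    """Run x <- x + step_size * grad_log_p(x) + noise_std * eps for n_steps."""
    rng = rng or random.Random(0)
    x = x0
    for _ in range(n_steps):
        eps = rng.gauss(0.0, 1.0)
        x = x + step_size * grad_log_p(x) + noise_std * eps
    return x


def grad_log_standard_normal(x):
    # d/dx log N(x; 0, 1) = -x
    return -x


if __name__ == "__main__":
    # Starting far from the mode, the chain drifts toward x = 0
    print(sgld_sample(3.0, grad_log_standard_normal, n_steps=50))
```

In the listing, the chain is usually started from a replay-buffer entry rather than from fresh noise, so each call continues a chain from an earlier training step.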
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/nce.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import torch.nn.functional as F
4 | from torch import distributions
5 | import matplotlib.pyplot as plt
6 |
7 | from uncertainty_est.archs.arch_factory import get_arch
8 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
9 | from uncertainty_est.utils.utils import (
10 | to_np,
11 | estimate_normalizing_constant,
12 | sum_except_batch,
13 | )
14 |
15 |
16 | class NoiseContrastiveEstimation(OODDetectionModel):
17 | """Implementation of Noise Contrastive Estimation http://proceedings.mlr.press/v9/gutmann10a.html"""
18 |
19 | def __init__(
20 | self,
21 | arch_name,
22 | arch_config,
23 | learning_rate,
24 | momentum,
25 | weight_decay,
26 | noise_distribution="uniform",
27 | noise_distribution_kwargs={"low": 0, "high": 1},
28 | **kwargs,
29 | ):
30 | super().__init__()
31 | self.__dict__.update(locals())
32 | self.save_hyperparameters()
33 |
34 | self.model = get_arch(arch_name, arch_config)
35 |
36 | if noise_distribution == "uniform":
37 | noise_dist = distributions.Uniform
38 | elif noise_distribution == "gaussian":
39 | noise_dist = distributions.Normal
40 | else:
41 | raise NotImplementedError(
42 | f"Requested noise distribution {noise_distribution} not implemented."
43 | )
44 |
45 | self.dist_parameters = torch.nn.ParameterDict(
46 | {
47 | k: torch.nn.Parameter(torch.tensor(v).float(), requires_grad=False)
48 | for k, v in noise_distribution_kwargs.items()
49 | }
50 | )
51 | self.noise_dist = noise_dist(**self.dist_parameters)
52 |
53 | def forward(self, x):
54 | return self.model(x)
55 |
56 | def compute_ebm_loss(self, batch, return_outputs=False):
57 | x, _ = batch
58 | noise = self.noise_dist.sample(x.shape).to(self.device)
59 | inp = torch.cat((x, noise))
60 |
61 | logits = self.model(inp)
62 | log_p_model = logits.logsumexp(-1)
63 | log_p_noise = sum_except_batch(self.noise_dist.log_prob(inp))
64 |
65 | loss = F.binary_cross_entropy_with_logits(
66 | log_p_model - log_p_noise,
67 | torch.cat((torch.ones(len(x)), torch.zeros(len(x)))).to(self.device),
68 | )
69 | if return_outputs:
70 | return loss, logits[: len(x)]
71 | return loss
72 |
73 | def training_step(self, batch, batch_idx):
74 | loss = self.compute_ebm_loss(batch)
75 |
76 | self.log("train/loss", loss)
77 | return loss
78 |
79 | def validation_step(self, batch, batch_idx):
80 | return
81 |
82 | def test_step(self, batch, batch_idx):
83 | self.to(torch.float32)
84 | x, y = batch
85 | y_hat = self.model(x)
86 |
87 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
88 | self.log("test_acc", acc)
89 | return y_hat
90 |
91 | def configure_optimizers(self):
92 | optim = torch.optim.AdamW(
93 | self.parameters(),
94 | betas=(self.momentum, 0.999),
95 | lr=self.learning_rate,
96 | weight_decay=self.weight_decay,
97 | )
98 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.5)
99 | return [optim], [scheduler]
100 |
101 | def get_ood_scores(self, x):
102 | return {"p(x)": self.model(x)}
103 |
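Note: `compute_ebm_loss` above feeds `log_p_model - log_p_noise` into a binary cross-entropy, labelling data as 1 and noise as 0. The following plain-Python sketch mirrors that computation. All function names are illustrative, and `uniform_log_prob` assumes every coordinate of `x` lies inside the support.

```python
import math


def uniform_log_prob(x, low=0.0, high=1.0):
    """Log-density of independent Uniform(low, high) dims, summed over dims.

    Assumes every coordinate of x lies inside [low, high).
    """
    return -len(x) * math.log(high - low)


def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit; label is 1.0 (data) or 0.0 (noise)."""
    # softplus(z) = log(1 + exp(z)), written in a numerically stable form
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return softplus - label * logit


def nce_loss(data_logits, noise_logits):
    """Mean BCE over data samples (label 1) and noise samples (label 0).

    Each logit is log_p_model(x) - log_p_noise(x) for one sample.
    """
    losses = [bce_with_logits(z, 1.0) for z in data_logits]
    losses += [bce_with_logits(z, 0.0) for z in noise_logits]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    log_q = uniform_log_prob([0.5, 0.5], low=0.0, high=2.0)
    print(nce_loss(data_logits=[-1.0 - log_q], noise_logits=[-3.0 - log_q]))
```

Summing the per-dimension noise log-probabilities corresponds to the `sum_except_batch` call in the listing.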
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/ssm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import autograd
3 | import torch.nn.functional as F
4 |
5 | from uncertainty_est.utils.utils import to_np
6 | from uncertainty_est.models.ebm.utils.model import JEM
7 | from uncertainty_est.archs.arch_factory import get_arch
8 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
9 |
10 |
11 | class SSM(OODDetectionModel):
12 | def __init__(
13 | self,
14 | arch_name,
15 | arch_config,
16 | learning_rate,
17 | momentum,
18 | weight_decay,
19 | n_classes,
20 | clf_weight,
21 | noise_type="radermacher",
22 | n_particles=1,
23 | warmup_steps=2500,
24 | lr_step_size=50,
25 | is_toy_dataset=False,
26 | **kwargs
27 | ):
28 | super().__init__(**kwargs)
29 | self.__dict__.update(locals())
30 | self.save_hyperparameters()
31 |
32 | arch = get_arch(arch_name, arch_config)
33 | self.model = JEM(arch)
34 |
35 | def forward(self, x):
36 | return self.model(x)
37 |
38 | def training_step(self, batch, batch_idx):
39 | (x_lab, y_lab), (x_p_d, _) = batch
40 | dup_samples = (
41 | x_p_d.unsqueeze(0)
42 | .expand(self.n_particles, *x_p_d.shape)
43 | .contiguous()
44 | .view(-1, *x_p_d.shape[1:])
45 | )
46 | dup_samples.requires_grad_(True)
47 |
48 | vectors = torch.randn_like(dup_samples)
49 | if self.noise_type == "radermacher":
50 | vectors = vectors.sign()
51 | elif self.noise_type == "gaussian":
52 | pass
53 | else:
54 | raise ValueError("Noise type not implemented")
55 |
56 | logp = self.model(dup_samples).sum()
57 |
58 | grad1 = autograd.grad(logp, dup_samples, create_graph=True)[0]
59 | loss1 = torch.sum(grad1 * grad1, dim=-1) / 2.0
60 | gradv = torch.sum(grad1 * vectors)
61 |
62 | grad2 = autograd.grad(gradv, dup_samples, create_graph=True)[0]
63 | loss2 = torch.sum(vectors * grad2, dim=-1)
64 |
65 | loss1 = loss1.view(self.n_particles, -1).mean(dim=0)
66 | loss2 = loss2.view(self.n_particles, -1).mean(dim=0)
67 |
68 | ssm_loss = (loss1 + loss2).mean()
69 | self.log("train/ssm_loss", ssm_loss)
70 |
71 | clf_loss = 0.0
72 | if self.clf_weight > 0.0:
73 | _, logits = self.model(x_lab, return_logits=True)
74 | clf_loss = self.clf_weight * F.cross_entropy(logits, y_lab)
75 | self.log("train/clf_loss", clf_loss)
76 | return ssm_loss + clf_loss
77 |
78 | def validation_step(self, batch, batch_idx):
79 | (x_lab, y_lab), (_, _) = batch
80 |
81 | if self.n_classes < 2:
82 | return
83 |
84 | _, logits = self.model(x_lab, return_logits=True)
85 | acc = (y_lab == logits.argmax(1)).float().mean(0).item()
86 | self.log("val/acc", acc)
87 |
88 | def test_step(self, batch, batch_idx):
89 | x, y = batch
90 | y_hat = self.model.classify(x)
91 |
92 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
93 | self.log("test.acc", acc)
94 |
95 | return y_hat
96 |
97 | def configure_optimizers(self):
98 | optim = torch.optim.AdamW(
99 | self.parameters(),
100 | betas=(self.momentum, 0.999),
101 | lr=self.learning_rate,
102 | weight_decay=self.weight_decay,
103 | )
104 | scheduler = torch.optim.lr_scheduler.StepLR(
105 | optim, step_size=self.lr_step_size, gamma=0.5
106 | )
107 | return [optim], [scheduler]
108 |
109 | def classify(self, x):
110 | return torch.softmax(self.model.classify(x), -1)
111 |
112 | def get_ood_scores(self, x):
113 | return {"p(x)": self.model(x)}
114 |
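Note: the training objective above is sliced score matching, `0.5 * ||s(x)||^2 + v^T (ds/dx) v`, where `s` is the gradient of the model's log-density and `v` is a random projection vector. The sketch below evaluates that objective for one sample in plain Python. It swaps the double `autograd.grad` call of the listing for a central finite difference, which is an assumption made to keep the example free of dependencies; `ssm_loss_single` and `standard_normal_score` are hypothetical names.

```python
def ssm_loss_single(x, v, score, eps=1e-4):
    """Sliced score matching loss for one sample x and one projection v.

    The Hessian-vector term v^T (ds/dx) v is estimated as the directional
    derivative of v . s(x) along v, using a central finite difference.
    """
    s = score(x)
    loss1 = 0.5 * sum(si * si for si in s)

    def v_dot_score(point):
        return sum(vi * si for vi, si in zip(v, score(point)))

    x_plus = [xi + eps * vi for xi, vi in zip(x, v)]
    x_minus = [xi - eps * vi for xi, vi in zip(x, v)]
    loss2 = (v_dot_score(x_plus) - v_dot_score(x_minus)) / (2.0 * eps)
    return loss1 + loss2


def standard_normal_score(x):
    # grad_x log N(x; 0, I) = -x
    return [-xi for xi in x]


if __name__ == "__main__":
    # v has entries in {-1, +1}, i.e. a Rademacher projection
    print(ssm_loss_single([1.0, 2.0], [1.0, -1.0], standard_normal_score))
```

For a standard normal model and a Rademacher `v` in D dimensions, the second term equals `-D`, so the loss for this example is `0.5 * 5 - 2`.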
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/models/ebm/utils/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/utils/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class JEM(nn.Module):
6 | def __init__(self, model):
7 | super().__init__()
8 | self.f = model
9 |
10 | def forward(self, x, return_logits=False, y=None):
11 | logits = self.classify(x)
12 |
13 | if y is not None:
14 | return logits[torch.arange(len(x)), y]
15 |
16 | if return_logits:
17 | return logits.logsumexp(1), logits
18 | else:
19 | return logits.logsumexp(1)
20 |
21 | def classify(self, x):
22 | return self.f(x)
23 |
24 |
25 | class HDGE(JEM):
26 | def __init__(self, model, n_classes, contrast_k, contrast_t):
27 | super(HDGE, self).__init__(model)
28 |
29 | self.K = contrast_k
30 | self.T = contrast_t
31 | self.dim = n_classes
32 |
33 | # create the queue
34 | init_logit = torch.randn(n_classes, contrast_k)
35 | self.register_buffer("queue_logit", init_logit)
36 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
37 |
38 | @torch.no_grad()
39 | def _dequeue_and_enqueue(self, logits):
40 | # gather logits before updating queue
41 | batch_size = logits.shape[0]
42 |
43 | ptr = int(self.queue_ptr)
44 | assert self.K % batch_size == 0 # for simplicity
45 |
46 | # replace the logits at ptr (dequeue and enqueue)
47 | self.queue_logit[:, ptr : ptr + batch_size] = logits.T
48 |
49 | ptr = (ptr + batch_size) % self.K # move pointer
50 |
51 | self.queue_ptr[0] = ptr
52 |
53 | def joint(self, img, dist=None, evaluation=False):
54 | f_logit = self.classify(img)  # queries: NxC (no class_output head is defined on JEM/HDGE)
55 | ce_logit = f_logit # cross-entropy loss logits
56 | prob = nn.functional.normalize(f_logit, dim=1)
57 | # positive logits: Nx1
58 | l_pos = dist * prob # NxC
59 | l_pos = torch.logsumexp(l_pos, dim=1, keepdim=True) # Nx1
60 | # negative logits: NxK
61 | buffer = nn.functional.normalize(self.queue_logit.clone().detach(), dim=0)
62 | l_neg = torch.einsum("nc,ck->nck", [dist, buffer]) # NxCxK
63 | l_neg = torch.logsumexp(l_neg, dim=1) # NxK
64 |
65 | # logits: Nx(1+K)
66 | logits = torch.cat([l_pos, l_neg], dim=1)
67 |
68 | # apply temperature
69 | logits /= self.T
70 |
71 | # labels: positive key indicators
72 | labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
73 |
74 | # dequeue and enqueue
75 | if not evaluation:
76 | self._dequeue_and_enqueue(f_logit)
77 |
78 | return logits, labels, ce_logit, l_neg.size(1)
79 |
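Note: `JEM.forward` above reads an unnormalized log-density off the classifier as the log-sum-exp of its logits, while `classify` keeps the logits for the usual softmax. A small plain-Python sketch of that relationship for a single row of logits is given below; the names `logsumexp` and `jem_outputs` are illustrative only.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def jem_outputs(logits):
    """Return (unnormalized log p(x), p(y|x)) from one row of classifier logits."""
    log_px = logsumexp(logits)
    probs = [math.exp(v - log_px) for v in logits]
    return log_px, probs


if __name__ == "__main__":
    log_px, probs = jem_outputs([2.0, 0.5, -1.0])
    print(log_px, probs)
```

Adding a constant to every logit leaves `p(y|x)` unchanged but shifts the density score by that constant, which is the extra degree of freedom the energy-based objectives in this package train.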
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class KHotCrossEntropyLoss(nn.Module):
6 | def __init__(self, dim=-1):
7 | super(KHotCrossEntropyLoss, self).__init__()
8 | self.dim = dim
9 |
10 | def forward(self, pred, target):
11 | pred = pred.log_softmax(dim=self.dim)
12 | return torch.mean(torch.sum(-target * pred, dim=self.dim))
13 |
14 |
15 | def smooth_one_hot(labels, classes, smoothing=0.0):
16 | """
17 | if smoothing == 0, it's one-hot method
18 | if 0 < smoothing < 1, it's smooth method
19 | """
20 | assert 0 <= smoothing < 1
21 | label_shape = torch.Size((labels.size(0), classes))
22 | with torch.no_grad():
23 | dist = torch.empty(size=label_shape, device=labels.device)
24 | dist.fill_(smoothing / (classes - 1))
25 | dist.scatter_(1, labels.data.unsqueeze(-1), 1.0 - smoothing)
26 | return dist
27 |
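Note: `smooth_one_hot` above puts `1 - smoothing` on the true class and spreads `smoothing` evenly over the remaining classes, and `KHotCrossEntropyLoss` scores logits against such a soft target. The sketch below does the same for a single label in plain Python; `smooth_one_hot_row` and `k_hot_cross_entropy` are hypothetical single-sample counterparts, not functions from the repository.

```python
import math


def smooth_one_hot_row(label, classes, smoothing=0.0):
    """Soft target for one label: 1 - smoothing on it, the rest shared evenly."""
    assert 0 <= smoothing < 1
    row = [smoothing / (classes - 1)] * classes
    row[label] = 1.0 - smoothing
    return row


def k_hot_cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a soft target distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return sum(-t * (v - log_z) for v, t in zip(logits, target))


if __name__ == "__main__":
    target = smooth_one_hot_row(label=1, classes=3, smoothing=0.1)
    print(target, k_hot_cross_entropy([0.2, 1.5, -0.3], target))
```

With `smoothing=0.0` the target is an ordinary one-hot vector and the loss reduces to the standard cross-entropy.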
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/vera.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 |
4 | from uncertainty_est.models.ebm.utils.model import JEM
5 | from uncertainty_est.archs.arch_factory import get_arch
6 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
7 | from uncertainty_est.models.ebm.utils.vera_utils import (
8 | VERADiscreteGenerator,
9 | VERAGenerator,
10 | VERAHMCGenerator,
11 | set_bn_to_eval,
12 | set_bn_to_train,
13 | )
14 |
15 |
16 | class VERA(OODDetectionModel):
17 | def __init__(
18 | self,
19 | arch_name,
20 | arch_config,
21 | learning_rate,
22 | beta1,
23 | beta2,
24 | weight_decay,
25 | n_classes,
26 | gen_learning_rate,
27 | ebm_iters,
28 | generator_iters,
29 | entropy_weight,
30 | generator_type,
31 | generator_arch_name,
32 | generator_arch_config,
33 | generator_config,
34 | min_sigma,
35 | max_sigma,
36 | p_control,
37 | n_control,
38 | pg_control,
39 | clf_ent_weight,
40 | ebm_type,
41 | clf_weight,
42 | warmup_steps,
43 | no_g_batch_norm,
44 | batch_size,
45 | lr_decay,
46 | lr_decay_epochs,
47 | **kwargs,
48 | ):
49 | super().__init__(**kwargs)
50 | self.__dict__.update(locals())
51 | self.save_hyperparameters()
52 | self.automatic_optimization = False
53 |
54 | arch = get_arch(arch_name, arch_config)
55 | self.model = JEM(arch)
56 |
57 | g = get_arch(generator_arch_name, generator_arch_config)
58 | if generator_type == "verahmc":
59 | self.generator = VERAHMCGenerator(g, **generator_config)
60 | elif generator_type == "vera":
61 | self.generator = VERAGenerator(g, **generator_config)
62 | elif generator_type == "vera_discrete":
63 | self.generator = VERADiscreteGenerator(g, **generator_config)
64 | else:
65 | raise NotImplementedError(f"Generator '{generator_type}' not implemented!")
66 |
67 | self.predict_mode = "density"
68 |
69 | def forward(self, x):
70 | if self.predict_mode == "density":
71 | return self.model(x)
72 | elif self.predict_mode in ("logits", "probs"):
73 | _, logits = self.model(x, return_logits=True)
74 | if self.predict_mode == "logits":
75 | return logits
76 | else:
77 | return torch.softmax(logits, -1)
78 |
79 | def training_step(self, batch, batch_idx, optimizer_idx):
80 | opt_e, opt_g = self.optimizers()
81 | (x_l, y_l), (x_d, _) = batch
82 |
83 | x_l.requires_grad_()
84 | x_d.requires_grad_()
85 |
86 | # sample from q(x, h)
87 | x_g, h_g = self.generator.sample(x_l.size(0), requires_grad=True)
88 |
89 | # ebm (contrastive divergence) objective
90 | if batch_idx % self.ebm_iters == 0:
91 | ebm_loss = self.ebm_step(x_d, x_l, x_g, y_l)
92 |
93 | self.log("train/ebm_loss", ebm_loss, prog_bar=True)
94 |
95 | opt_e.zero_grad()
96 | self.manual_backward(ebm_loss, opt_e)
97 | opt_e.step()
98 |
99 | # gen obj
100 | if batch_idx % self.generator_iters == 0:
101 | gen_loss = self.generator_step(x_g, h_g)
102 |
103 | self.log("train/gen_loss", gen_loss, prog_bar=True)
104 |
105 | opt_g.zero_grad()
106 | self.manual_backward(gen_loss, opt_g)
107 | opt_g.step()
108 |
109 | # clamp sigma to (min_sigma, max_sigma) for generators
110 | if self.generator_type in ["verahmc", "vera"]:
111 | self.generator.clamp_sigma(self.max_sigma, sigma_min=self.min_sigma)
112 |
113 | def ebm_step(self, x_d, x_l, x_g, y_l):
114 | x_g_detach = x_g.detach().requires_grad_()
115 |
116 | if self.no_g_batch_norm:
117 | self.model.apply(set_bn_to_eval)
118 | lg_detach, lg_logits = self.model(x_g_detach, return_logits=True)
119 | self.model.apply(set_bn_to_train)
120 | else:
121 | lg_detach, lg_logits = self.model(x_g_detach, return_logits=True)
122 |
123 | unsup_ent = torch.tensor(0.0)
124 | if self.ebm_type == "ssl":
125 | ld, unsup_logits = self.model(x_d, return_logits=True)
126 | _, ld_logits = self.model(x_l, return_logits=True)
127 | unsup_ent = distributions.Categorical(logits=unsup_logits).entropy()
128 | elif self.ebm_type == "jem":
129 | ld, ld_logits = self.model(x_l, return_logits=True)
130 | self.log("train/acc", (ld_logits.argmax(1) == y_l).float().mean(0))
131 | elif self.ebm_type == "p_x":
132 | ld, ld_logits = self.model(x_l).squeeze(), torch.tensor(0.0).to(self.device)
133 | else:
134 | raise NotImplementedError(f"EBM type '{self.ebm_type}' not implemented!")
135 |
136 | logp_obj = ld.mean() - lg_detach.mean()
137 | e_loss = (
138 | -logp_obj
139 | + self.p_control * (ld ** 2).mean()
140 | + self.n_control * (lg_detach ** 2).mean()
141 | + self.clf_ent_weight * unsup_ent.mean()
142 | )
143 |
144 | if self.pg_control > 0:
145 | grad_ld = (
146 | torch.autograd.grad(ld.mean(), x_l, create_graph=True)[0]
147 | .flatten(start_dim=1)
148 | .norm(2, 1)
149 | )
150 | e_loss += self.pg_control * (grad_ld ** 2.0 / 2.0).mean()
151 |
152 | self.log("train/e_loss", e_loss.item())
153 |
154 | if self.clf_weight > 0:
155 | clf_loss = self.clf_weight * self.classifier_loss(ld_logits, y_l, lg_logits)
156 | self.log("train/clf_loss", clf_loss)
157 | e_loss += clf_loss
158 |
159 | return e_loss
160 |
161 | def classifier_loss(self, ld_logits, y_l, lg_logits):
162 | return torch.nn.CrossEntropyLoss()(ld_logits, y_l)
163 |
164 | def generator_step(self, x_g, h_g):
165 | lg = self.model(x_g).squeeze()
166 | grad = torch.autograd.grad(lg.sum(), x_g, retain_graph=True)[0]
167 | ebm_gn = grad.norm(2, 1).mean()
168 |         entropy_obj = 0.0  # keeps logq_obj defined when entropy_weight == 0.0
169 | if self.entropy_weight != 0.0:
170 | entropy_obj, ent_gn = self.generator.entropy_obj(x_g, h_g)
171 |
172 | logq_obj = lg.mean() + self.entropy_weight * entropy_obj
173 | return -logq_obj
174 |
175 | def validation_step(self, batch, batch_idx):
176 | (x_l, y_l), _ = batch
177 | ld, ld_logits = self.model(x_l, return_logits=True)
178 |
179 | self.log("val/loss", -ld.mean())
180 |
181 | # Performing density estimation only
182 | if ld_logits.shape[1] < 2:
183 | return
184 |
185 | acc = (y_l == ld_logits.argmax(1)).float().mean(0)
186 | self.log("val/acc", acc)
187 | return ld_logits
188 |
189 | def test_step(self, batch, batch_idx):
190 | x, y = batch
191 | _, y_hat = self.model(x, return_logits=True)
192 |
193 | if self.n_classes < 2:
194 | return
195 |
196 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
197 | self.log("acc", acc)
198 |
199 | return y_hat
200 |
201 | def configure_optimizers(self):
202 | optim = torch.optim.AdamW(
203 | self.model.parameters(),
204 | betas=(self.beta1, self.beta2),
205 | lr=self.learning_rate,
206 | weight_decay=self.weight_decay,
207 | )
208 | gen_optim = torch.optim.AdamW(
209 | self.generator.parameters(),
210 | betas=(self.beta1, self.beta2),
211 | lr=self.gen_learning_rate,
212 | weight_decay=self.weight_decay,
213 | )
214 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
215 | optim, gamma=self.lr_decay, milestones=self.lr_decay_epochs
216 | )
217 | gen_scheduler = torch.optim.lr_scheduler.MultiStepLR(
218 | gen_optim, gamma=self.lr_decay, milestones=self.lr_decay_epochs
219 | )
220 | return [optim, gen_optim], [scheduler, gen_scheduler]
221 |
222 | def classify(self, x):
223 | return self.model.classify(x).softmax(-1)
224 |
225 | def get_ood_scores(self, x):
226 | return {"p(x)": self.model(x)}
227 |
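A minimal, framework-free sketch of the scalar objective assembled in `ebm_step` above, assuming plain Python lists stand in for the per-sample model outputs `ld` (data) and `lg_detach` (generated samples). The name `ebm_objective` and the toy numbers are illustrative and not part of the repository; the gradient penalty, entropy and classifier terms are left out.

```python
from statistics import fmean


def ebm_objective(ld, lg, p_control=0.0, n_control=0.0):
    """Toy version of the e_loss terms that depend only on ld and lg."""
    # Contrast between real and generated samples; the EBM maximises it.
    logp_obj = fmean(ld) - fmean(lg)
    # Squared-output penalties keep both sets of outputs from growing unbounded.
    pos_penalty = p_control * fmean([v ** 2 for v in ld])
    neg_penalty = n_control * fmean([v ** 2 for v in lg])
    return -logp_obj + pos_penalty + neg_penalty


loss = ebm_objective([1.0, 3.0], [0.0, 2.0], p_control=0.1, n_control=0.1)
print(round(loss, 6))
```

With both control weights at zero the loss reduces to the negated contrast `lg.mean() - ld.mean()`, which is the term the generator step pushes against.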
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/vera_posteriornet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.distributions import Dirichlet
4 |
5 | from uncertainty_est.models.ebm.vera import VERA
6 | from uncertainty_est.models.priornet.dpn_losses import dirichlet_kl_divergence
7 | from uncertainty_est.models.priornet.uncertainties import (
8 | dirichlet_prior_network_uncertainty,
9 | )
10 |
11 |
12 | class VERAPosteriorNet(VERA):
13 | def __init__(
14 | self,
15 | arch_name,
16 | arch_config,
17 | learning_rate,
18 | beta1,
19 | beta2,
20 | weight_decay,
21 | n_classes,
22 | gen_learning_rate,
23 | ebm_iters,
24 | generator_iters,
25 | entropy_weight,
26 | generator_type,
27 | generator_arch_name,
28 | generator_arch_config,
29 | generator_config,
30 | min_sigma,
31 | max_sigma,
32 | p_control,
33 | n_control,
34 | pg_control,
35 | clf_ent_weight,
36 | ebm_type,
37 | clf_weight,
38 | warmup_steps,
39 | no_g_batch_norm,
40 | batch_size,
41 | lr_decay,
42 | lr_decay_epochs,
43 | alpha_fix=True,
44 | entropy_reg=0.0,
45 | **kwargs,
46 | ):
47 | if n_control is None:
48 | n_control = p_control
49 |
50 | super().__init__(
51 | arch_name,
52 | arch_config,
53 | learning_rate,
54 | beta1,
55 | beta2,
56 | weight_decay,
57 | n_classes,
58 | gen_learning_rate,
59 | ebm_iters,
60 | generator_iters,
61 | entropy_weight,
62 | generator_type,
63 | generator_arch_name,
64 | generator_arch_config,
65 | generator_config,
66 | min_sigma,
67 | max_sigma,
68 | p_control,
69 | n_control,
70 | pg_control,
71 | clf_ent_weight,
72 | ebm_type,
73 | clf_weight,
74 | warmup_steps,
75 | no_g_batch_norm,
76 | batch_size,
77 | lr_decay,
78 | lr_decay_epochs,
79 | sample_term=0.0,
80 | **kwargs,
81 | )
82 | self.__dict__.update(locals())
83 | self.save_hyperparameters()
84 |
85 | def classifier_loss(self, ld_logits, y_l, lg_logits):
86 | alpha = torch.exp(ld_logits) # / self.p_y.unsqueeze(0).to(self.device)
87 | # Multiply by class counts for Bayesian update
88 |
89 | if self.alpha_fix:
90 | alpha = alpha + 1
91 |
92 | soft_output = F.one_hot(y_l, self.n_classes)
93 | alpha_0 = alpha.sum(1).unsqueeze(-1).repeat(1, self.n_classes)
94 | UCE_loss = torch.mean(
95 | soft_output * (torch.digamma(alpha_0) - torch.digamma(alpha))
96 | )
97 | UCE_loss = UCE_loss + self.clf_ent_weight * -Dirichlet(alpha).entropy().mean()
98 |
102 | lg_alpha = torch.exp(lg_logits)
103 | if self.alpha_fix:
104 | lg_alpha = lg_alpha + 1
105 | sample_loss = self.sample_term * -Dirichlet(lg_alpha).entropy().mean()
106 |
107 | return UCE_loss + sample_loss
108 |
109 | def validation_epoch_end(self, outputs):
110 | super().validation_epoch_end(outputs)
111 |         alphas = torch.exp(outputs[0]).reshape(-1) + (1 if self.alpha_fix else 0)
112 | self.logger.experiment.add_histogram("alphas", alphas, self.current_epoch)
113 |
114 | def get_ood_scores(self, x):
115 | px, logits = self.model(x, return_logits=True)
116 | uncert = {}
117 | uncert["p(x)"] = px
118 | dirichlet_uncerts = dirichlet_prior_network_uncertainty(
119 | logits.cpu().numpy(), alpha_correction=self.alpha_fix
120 | )
121 | uncert = {**uncert, **dirichlet_uncerts}
122 | return uncert
123 |
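The mapping from logits to Dirichlet parameters used by `classifier_loss` and `get_ood_scores` above can be sketched without any tensor library. This is an illustration under the assumption that a single row of logits is passed as a Python list; `dirichlet_from_logits` is a hypothetical helper, not a function from the repository.

```python
import math


def dirichlet_from_logits(logits, alpha_fix=True):
    """Map one row of logits to Dirichlet parameters and summary statistics."""
    # exp() keeps concentrations positive; the optional +1 mirrors alpha_fix.
    alphas = [math.exp(z) + (1.0 if alpha_fix else 0.0) for z in logits]
    alpha0 = sum(alphas)  # total concentration (precision of the Dirichlet)
    probs = [a / alpha0 for a in alphas]  # expected class probabilities
    return alphas, alpha0, probs


alphas, alpha0, probs = dirichlet_from_logits([0.0, 0.0, 0.0])
print(alphas, alpha0, probs)
```

Larger logits raise `alpha0`, so the total concentration can serve as a confidence signal alongside the expected probabilities.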
--------------------------------------------------------------------------------
/uncertainty_est/models/ebm/vera_priornet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from uncertainty_est.models.ebm.vera import VERA
4 | from uncertainty_est.models.priornet.dpn_losses import UnfixedDirichletKLLoss
5 | from uncertainty_est.models.priornet.uncertainties import (
6 | dirichlet_prior_network_uncertainty,
7 | )
8 |
9 |
10 | class VERAPriorNet(VERA):
11 | def __init__(
12 | self,
13 | arch_name,
14 | arch_config,
15 | learning_rate,
16 | beta1,
17 | beta2,
18 | weight_decay,
19 | n_classes,
20 | gen_learning_rate,
21 | ebm_iters,
22 | generator_iters,
23 | entropy_weight,
24 | generator_type,
25 | generator_arch_name,
26 | generator_arch_config,
27 | generator_config,
28 | min_sigma,
29 | max_sigma,
30 | p_control,
31 | n_control,
32 | pg_control,
33 | clf_ent_weight,
34 | ebm_type,
35 | clf_weight,
36 | warmup_steps,
37 | no_g_batch_norm,
38 | batch_size,
39 | lr_decay,
40 | lr_decay_epochs,
41 | alpha_fix=True,
42 | concentration=1.0,
43 | target_concentration=None,
44 | entropy_reg=0.0,
45 | reverse_kl=True,
46 | w_neg_sample_loss=1.0,
47 | **kwargs,
48 | ):
49 | super().__init__(
50 | arch_name,
51 | arch_config,
52 | learning_rate,
53 | beta1,
54 | beta2,
55 | weight_decay,
56 | n_classes,
57 | gen_learning_rate,
58 | ebm_iters,
59 | generator_iters,
60 | entropy_weight,
61 | generator_type,
62 | generator_arch_name,
63 | generator_arch_config,
64 | generator_config,
65 | min_sigma,
66 | max_sigma,
67 | p_control,
68 | n_control,
69 | pg_control,
70 | clf_ent_weight,
71 | ebm_type,
72 | clf_weight,
73 | warmup_steps,
74 | no_g_batch_norm,
75 | batch_size,
76 | lr_decay,
77 | lr_decay_epochs,
78 | **kwargs,
79 | )
80 | self.__dict__.update(locals())
81 | self.save_hyperparameters()
82 |
83 | self.clf_loss = UnfixedDirichletKLLoss(
84 | concentration, target_concentration, entropy_reg, reverse_kl, alpha_fix
85 | )
86 |
87 | def classifier_loss(self, ld_logits, y_l, lg_logits):
88 | loss = self.clf_loss(ld_logits, y_l)
89 |
90 | loss_ood = 0.0
91 | if self.w_neg_sample_loss > 0:
92 | loss_ood = self.w_neg_sample_loss * self.clf_loss(lg_logits)
93 | self.log("train/clf_loss", loss + loss_ood)
94 |         return loss + loss_ood
95 |
96 | def validation_epoch_end(self, outputs):
97 | super().validation_epoch_end(outputs)
98 | alphas = torch.exp(outputs[0]).reshape(-1) + self.concentration
99 | self.logger.experiment.add_histogram("alphas", alphas, self.current_epoch)
100 |
101 | def get_ood_scores(self, x):
102 | px, logits = self.model(x, return_logits=True)
103 | uncert = {}
104 | uncert["p(x)"] = px
105 | dirichlet_uncerts = dirichlet_prior_network_uncertainty(
106 | logits.cpu().numpy(), alpha_correction=self.alpha_fix
107 | )
108 | uncert = {**uncert, **dirichlet_uncerts}
109 | return uncert
110 |
--------------------------------------------------------------------------------
/uncertainty_est/models/energy_finetuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 | from uncertainty_est.models.ce_baseline import CEBaseline
6 |
7 |
8 | class EnergyFinetune(CEBaseline):
9 | def __init__(
10 | self,
11 | arch_name,
12 | arch_config,
13 | learning_rate,
14 | momentum,
15 | weight_decay,
16 | score,
17 | m_in,
18 | m_out,
19 | checkpoint,
20 | max_steps,
21 | **kwargs
22 | ):
23 | super().__init__(
24 | arch_name, arch_config, learning_rate, momentum, weight_decay, **kwargs
25 | )
26 | self.__dict__.update(locals())
27 | self.save_hyperparameters()
28 | self.load_state_dict(torch.load(checkpoint)["state_dict"])
29 |
30 | def forward(self, x):
31 | return self.backbone(x)
32 |
33 | def training_step(self, batch, batch_idx):
34 | (x, y), (x_ood, _) = batch
35 |
36 | y_hat = self(torch.cat((x, x_ood)))
37 | y_hat_ood = y_hat[len(x) :]
38 | y_hat = y_hat[: len(x)]
39 | loss = F.cross_entropy(y_hat, y)
40 | self.log("train_ce_loss", loss, prog_bar=True)
41 |
42 |         # auxiliary OOD loss: energy margin ("energy") or cross-entropy to the uniform distribution ("OE")
43 | if self.score == "energy":
44 | Ec_out = -torch.logsumexp(y_hat_ood, dim=1)
45 | Ec_in = -torch.logsumexp(y_hat, dim=1)
46 | margin_loss = 0.1 * (
47 | (F.relu(Ec_in - self.m_in) ** 2).mean()
48 | + (F.relu(self.m_out - Ec_out) ** 2).mean()
49 | )
50 | self.log("train_margin_loss", margin_loss, prog_bar=True)
51 | loss += margin_loss
52 | elif self.score == "OE":
53 | loss += (
54 | 0.5 * -(y_hat_ood.mean(1) - torch.logsumexp(y_hat_ood, dim=1)).mean()
55 | )
56 |
57 | self.log("train_loss", loss)
58 | return loss
59 |
60 | def validation_step(self, batch, batch_idx):
61 | x, y = batch
62 | y_hat = self.backbone(x)
63 |
64 | loss = F.cross_entropy(y_hat, y)
65 | self.log("val/loss", loss)
66 |
67 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
68 | self.log("val_acc", acc)
69 |
70 | def test_step(self, batch, batch_idx):
71 | x, y = batch
72 | y_hat = self.backbone(x)
73 |
74 | acc = (y == y_hat.argmax(1)).float().mean(0).item()
75 | self.log("test_acc", acc)
76 |
77 | def configure_optimizers(self):
78 | optim = torch.optim.AdamW(
79 | self.parameters(),
80 | betas=(self.momentum, 0.999),
81 | lr=self.learning_rate,
82 | weight_decay=self.weight_decay,
83 | )
84 |
85 | def cosine_annealing(step, total_steps, lr_max, lr_min):
86 | return lr_min + (lr_max - lr_min) * 0.5 * (
87 | 1 + np.cos(step / total_steps * np.pi)
88 | )
89 |
90 | scheduler = torch.optim.lr_scheduler.LambdaLR(
91 | optim,
92 | lr_lambda=lambda step: cosine_annealing(
93 | step,
94 | self.max_steps,
95 | 1, # since lr_lambda computes multiplicative factor
96 | 1e-6 / self.learning_rate,
97 | ),
98 | )
99 | return [optim], [scheduler]
100 |
101 | def ood_detect(self, loader):
102 | _, logits = self.get_gt_preds(loader)
103 |
104 | uncert = {}
105 | uncert["Energy"] = torch.logsumexp(logits, 1)
106 | return uncert
107 |
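The two numeric pieces of `EnergyFinetune` above, the squared-hinge margin on energies and the cosine learning-rate factor, can be checked in isolation. The sketch below assumes logits are given as nested Python lists; `energy` and `margin_loss` are illustrative names, and the default `weight=0.1` simply mirrors the constant in `training_step`.

```python
import math


def energy(logits):
    """Free energy E(x) = -logsumexp(logits), computed stably."""
    m = max(logits)
    return -(m + math.log(sum(math.exp(z - m) for z in logits)))


def margin_loss(id_logits, ood_logits, m_in, m_out, weight=0.1):
    """Penalise ID energies above m_in and OOD energies below m_out."""
    e_in = [energy(row) for row in id_logits]
    e_out = [energy(row) for row in ood_logits]
    in_term = sum(max(0.0, e - m_in) ** 2 for e in e_in) / len(e_in)
    out_term = sum(max(0.0, m_out - e) ** 2 for e in e_out) / len(e_out)
    return weight * (in_term + out_term)


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Multiplicative factor decaying from lr_max to lr_min over total_steps."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + math.cos(step / total_steps * math.pi))


print(margin_loss([[10.0, 0.0]], [[0.0, 0.0]], m_in=-5.0, m_out=-1.0))
print(cosine_annealing(50, 100, 1.0, 0.01))
```

The margin term is zero once in-distribution energies sit below `m_in` and OOD energies sit above `m_out`, so only samples that violate a margin contribute gradient.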
--------------------------------------------------------------------------------
/uncertainty_est/models/normalizing_flow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selflein/EBM-OOD-Detection/bbd0243cd2d33cf3e20b865229fc040611a8870b/uncertainty_est/models/normalizing_flow/__init__.py
--------------------------------------------------------------------------------
/uncertainty_est/models/normalizing_flow/approx_flow.py:
--------------------------------------------------------------------------------
1 | from uncertainty_est.models.normalizing_flow.norm_flow import NormalizingFlow
2 |
3 |
4 | class ApproxNormalizingFlow(NormalizingFlow):
5 | def __init__(
6 | self,
7 | density_type,
8 | latent_dim,
9 | n_density,
10 | learning_rate,
11 | momentum,
12 | weight_decay,
13 | weight_penalty_weight=1.0,
14 | ):
15 | super().__init__(
16 | density_type,
17 | latent_dim,
18 | n_density,
19 | learning_rate,
20 | momentum,
21 | weight_decay,
22 | )
23 | assert density_type in ("orthogonal_flow", "reparameterized_flow")
24 | self.__dict__.update(locals())
25 | self.save_hyperparameters()
26 |
27 | def training_step(self, batch, batch_idx):
28 | x, _ = batch
29 | log_p = self.density_estimation.log_prob(x)
30 |
31 | loss = -log_p.mean()
32 | self.log("train/ml_loss", loss, on_epoch=True)
33 |
34 | weight_penalty = 0.0
35 | transforms = self.density_estimation.transforms
36 | for t in transforms:
37 | weight_penalty += t.compute_weight_penalty()
38 | weight_penalty /= len(transforms)
39 | self.log("train/weight_penalty", weight_penalty, on_epoch=True)
40 | loss += self.weight_penalty_weight * weight_penalty
41 |
42 | return loss
43 |
--------------------------------------------------------------------------------
/uncertainty_est/models/normalizing_flow/image_flows.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from uncertainty_est.archs.glow.glow import Glow
6 | from uncertainty_est.archs.real_nvp.real_nvp import RealNVP
7 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
8 |
9 |
10 | class AffineCouplingModel(OODDetectionModel):
11 | def __init__(self, learning_rate, momentum, weight_decay, num_classes=1, **kwargs):
12 | super().__init__(**kwargs)
13 | self.__dict__.update(locals())
14 | self.conditional_densities = []
15 |
16 | def forward(self, x):
17 | log_p_xy = []
18 | for cd in self.conditional_densities:
19 | log_p_xy.append(cd.log_prob(x))
20 | log_p_xy = torch.stack(log_p_xy, 1)
21 |
22 | return log_p_xy
23 |
24 | def training_step(self, batch, batch_idx):
25 | x, y = batch
26 | log_p_xy = self(x)
27 | log_p_x = torch.logsumexp(log_p_xy, dim=1)
28 |
29 | if self.num_classes > 1:
30 | loss = F.cross_entropy(log_p_xy, y)
31 | self.log("train/clf_loss", loss)
32 | else:
33 | loss = -log_p_x.mean()
34 | self.log("train/loss", loss)
35 |
36 | return loss
37 |
38 | def validation_step(self, batch, batch_idx):
39 | x, y = batch
40 | log_p_xy = self(x)
41 | log_p_x = torch.logsumexp(log_p_xy, dim=1)
42 |
43 | loss = -log_p_x.mean()
44 | self.log("val/loss", loss)
45 |
46 | acc = (y == log_p_xy.argmax(1)).float().mean(0).item()
47 | self.log("val/acc", acc)
48 |
49 | def test_step(self, batch, batch_idx):
50 | x, y = batch
51 | log_p_xy = self(x)
52 | log_p_x = torch.logsumexp(log_p_xy, dim=1)
53 | self.log("log_likelihood", log_p_x.mean())
54 |
55 | acc = (y == log_p_xy.argmax(1)).float().mean(0).item()
56 | self.log("acc", acc)
57 |
58 | def configure_optimizers(self):
59 | optim = torch.optim.AdamW(
60 | self.parameters(),
61 | betas=(self.momentum, 0.999),
62 | lr=self.learning_rate,
63 | weight_decay=self.weight_decay,
64 | )
65 | return optim
66 |
67 | def classify(self, x):
68 | log_p_xy = self(x)
69 | return log_p_xy.softmax(-1)
70 |
71 | def get_ood_scores(self, x):
72 | log_p_xy = self(x)
73 | log_p_x = torch.logsumexp(log_p_xy, dim=1)
74 | return {"p(x)": log_p_x}
75 |
76 |
77 | class RealNVPModel(AffineCouplingModel):
78 | def __init__(
79 | self,
80 | num_scales,
81 | in_channels,
82 | mid_channels,
83 | num_blocks,
84 | learning_rate,
85 | momentum,
86 | weight_decay,
87 | num_classes=1,
88 | **kwargs
89 | ):
90 |         super().__init__(learning_rate, momentum, weight_decay, num_classes=num_classes, **kwargs)
91 | self.__dict__.update(locals())
92 | self.save_hyperparameters()
93 |
94 | self.conditional_densities = nn.ModuleList()
95 | for _ in range(num_classes):
96 | self.conditional_densities.append(
97 | RealNVP(num_scales, in_channels, mid_channels, num_blocks)
98 | )
99 |
100 |
101 | class GlowModel(AffineCouplingModel):
102 | def __init__(
103 | self,
104 | in_channels,
105 | num_channels,
106 | num_levels,
107 | num_steps,
108 | learning_rate,
109 | momentum,
110 | weight_decay,
111 | num_classes=1,
112 | **kwargs
113 | ):
114 |         super().__init__(learning_rate, momentum, weight_decay, num_classes=num_classes, **kwargs)
115 | self.__dict__.update(locals())
116 | self.save_hyperparameters()
117 |
118 | self.conditional_densities = nn.ModuleList()
119 | for _ in range(num_classes):
120 | self.conditional_densities.append(
121 | Glow(in_channels, num_channels, num_levels, num_steps)
122 | )
123 |
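`AffineCouplingModel` above stacks one log-density per class and derives both the OOD score and the class prediction from that stack. A small sketch of the arithmetic, assuming the per-class values are plain floats; as in the listing, the sum weights all classes equally and applies no explicit class prior. `marginal_and_posterior` is a hypothetical helper name.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_and_posterior(log_p_x_given_y):
    """Return the logsumexp score and the softmax over class-conditional log-densities."""
    log_p_x = logsumexp(log_p_x_given_y)
    # softmax(v) == exp(v - logsumexp(v)), matching classify() in the listing
    posterior = [math.exp(v - log_p_x) for v in log_p_x_given_y]
    return log_p_x, posterior


log_p_x, posterior = marginal_and_posterior([math.log(0.2), math.log(0.3)])
print(log_p_x, posterior)
```

With a single class the posterior is trivially `[1.0]` and the score equals that one flow's log-likelihood, which is the density-estimation-only setting.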
--------------------------------------------------------------------------------
/uncertainty_est/models/normalizing_flow/iresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from uncertainty_est.archs.arch_factory import get_arch
4 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
5 |
6 |
7 | class IResNetFlow(OODDetectionModel):
8 | def __init__(
9 | self,
10 | arch_name,
11 | arch_config,
12 | learning_rate,
13 | momentum,
14 | weight_decay,
15 | warmup_steps=0,
16 | ):
17 | super().__init__()
18 | self.__dict__.update(locals())
19 | self.save_hyperparameters()
20 | assert arch_name in ("iresnet_fc", "iresnet_conv")
21 |
22 | self.model = get_arch(arch_name, arch_config)
23 |
24 | def forward(self, x):
25 | return self.model(x)
26 |
27 | def training_step(self, batch, batch_idx):
28 | x, _ = batch
29 | x.requires_grad_()
30 | log_p = self.model.log_prob(x)
31 |
32 | loss = -log_p.mean()
33 | self.log("train/loss", loss)
34 | return loss
35 |
36 | def validation_step(self, batch, batch_idx):
37 | x, _ = batch
38 | with torch.enable_grad():
39 | x.requires_grad_()
40 | log_p = self.model.log_prob(x)
41 |
42 | sigmas = []
43 | for k, v in self.model.state_dict().items():
44 | if "_sigma" in k:
45 | sigmas.append(v.item())
46 | sigmas = torch.tensor(sigmas)
47 | self.log("val/sigma_mean", sigmas.mean().item())
48 |
49 | loss = -log_p.mean()
50 | self.log("val/loss", loss)
51 |
52 | def test_step(self, batch, batch_idx):
53 | x, _ = batch
54 | with torch.enable_grad():
55 | x.requires_grad_()
56 | log_p = self.model.log_prob(x)
57 | self.log("log_likelihood", log_p.mean())
58 |
59 | def configure_optimizers(self):
60 | optim = torch.optim.AdamW(
61 | self.parameters(),
62 | betas=(self.momentum, 0.999),
63 | lr=self.learning_rate,
64 | weight_decay=self.weight_decay,
65 | )
66 | return optim
67 |
68 | def get_ood_scores(self, x):
69 | with torch.enable_grad():
70 | x.requires_grad_()
71 | log_p = self.model.log_prob(x).detach()
72 | return {"p(x)": log_p}
73 |
--------------------------------------------------------------------------------
/uncertainty_est/models/normalizing_flow/norm_flow.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from uncertainty_est.archs.arch_factory import get_arch
4 | from uncertainty_est.models.ood_detection_model import OODDetectionModel
5 |
6 |
7 | class NormalizingFlow(OODDetectionModel):
8 | def __init__(
9 | self,
10 | arch_name,
11 | arch_config,
12 | learning_rate,
13 | momentum,
14 | weight_decay,
15 | **kwargs,
16 | ):
17 | super().__init__(**kwargs)
18 | self.__dict__.update(locals())
19 | self.save_hyperparameters()
20 |
21 | self.density_estimation = get_arch(arch_name, arch_config)
22 |
23 | def forward(self, x):
24 | return self.density_estimation.log_prob(x)
25 |
26 | def training_step(self, batch, batch_idx):
27 | x, _ = batch
28 | log_p = self.density_estimation.log_prob(x)
29 |
30 | loss = -log_p.mean()
31 | self.log("train/loss", loss)
32 | return loss
33 |
34 | def validation_step(self, batch, batch_idx):
35 | x, _ = batch
36 | log_p = self.density_estimation.log_prob(x)
37 |
38 | loss = -log_p.mean()
39 | self.log("val/loss", loss)
40 |
41 | def test_step(self, batch, batch_idx):
42 | x, _ = batch
43 | log_p = self.density_estimation.log_prob(x)
44 | self.log("log_likelihood", log_p.mean())
45 |
46 | def configure_optimizers(self):
47 | optim = torch.optim.AdamW(
48 | self.parameters(),
49 | betas=(self.momentum, 0.999),
50 | lr=self.learning_rate,
51 | weight_decay=self.weight_decay,
52 | )
53 | return optim
54 |
55 | def get_ood_scores(self, x):
56 | return {"p(x)": self.density_estimation.log_prob(x)}
57 |
--------------------------------------------------------------------------------
/uncertainty_est/models/ood_detection_model.py:
--------------------------------------------------------------------------------
1 | from itertools import islice
2 | from os import path
3 | from typing import Any, Dict
4 | from collections import defaultdict
5 |
6 | import torch
7 | import numpy as np
8 | from tqdm import tqdm
9 | import pytorch_lightning as pl
10 | import matplotlib.pyplot as plt
11 | from sklearn.metrics import roc_auc_score, average_precision_score
13 | from uncertainty_eval.metrics.brier import brier_score, brier_decomposition
14 | from uncertainty_eval.metrics.calibration_error import classification_calibration
15 | from uncertainty_eval.vis import draw_reliability_graph, plot_score_hist
16 |
17 | from uncertainty_est.utils.utils import to_np
18 | from uncertainty_est.utils.metrics import accuracy
19 | from uncertainty_est.data.dataloaders import get_dataloader
20 |
21 |
22 | class OODDetectionModel(pl.LightningModule):
23 | def __init__(
24 | self, ood_val_datasets=None, is_toy_dataset=False, data_shape=None, **kwargs
25 | ):
26 | super().__init__()
27 | self.ood_val_datasets = ood_val_datasets
28 | self.is_toy_dataset = is_toy_dataset
29 | self.data_shape = data_shape
30 | self.test_ood_dataloaders = []
31 |
32 | def eval_ood(
33 | self, id_loader, ood_loaders: Dict[str, Any], num=10_000
34 | ) -> Dict[str, float]:
35 | self.eval()
36 |
37 | if num > 0:
38 | max_batches = (num // id_loader.batch_size) + 1
39 | else:
40 | assert num == -1
41 | max_batches = None
42 |
43 | ood_metrics = {}
44 |
45 | # Compute ID OOD scores
46 | id_scores_dict = self.ood_detect(islice(id_loader, max_batches))
47 |
48 | # Compute OOD detection metrics
49 | for dataset_name, loader in ood_loaders:
50 | try:
51 | ood_scores_dict = self.ood_detect(islice(loader, max_batches))
52 | except Exception as e:
53 | print(e)
54 | continue
55 |
56 | for score_name, id_scores in id_scores_dict.items():
57 | try:
58 | ood_scores = ood_scores_dict[score_name]
59 |
60 | length = min(len(ood_scores), len(id_scores))
61 | ood = ood_scores[:length]
62 | id_ = id_scores[:length]
63 |
64 | if self.logger is not None and self.logger.log_dir is not None:
65 | ax = plot_score_hist(
66 | id_,
67 | ood,
68 | title="",
69 | )
70 | ax.figure.savefig(
71 | path.join(
72 | self.logger.log_dir, f"{dataset_name}_{score_name}.png"
73 | )
74 | )
75 | plt.close()
76 |
77 | preds = np.concatenate([ood, id_])
78 |
79 | labels = np.concatenate([np.zeros_like(ood), np.ones_like(id_)])
80 | ood_metrics[(dataset_name, score_name, "AUROC")] = (
81 | roc_auc_score(labels, preds) * 100.0
82 | )
83 | ood_metrics[(dataset_name, score_name, "AUPR")] = (
84 | average_precision_score(labels, preds) * 100.0
85 | )
86 |
87 | labels = np.concatenate([np.ones_like(ood), np.zeros_like(id_)])
88 | ood_metrics[(dataset_name, score_name, "AUROC")] = (
89 | roc_auc_score(labels, -preds) * 100.0
90 | )
91 | ood_metrics[(dataset_name, score_name, "AUPR")] = (
92 | average_precision_score(labels, -preds) * 100.0
93 | )
94 | except Exception as e:
95 | print(e)
96 | return ood_metrics
97 |
98 | def eval_classifier(self, loader, num=10_000):
99 | self.eval()
100 |
101 | if num > 0:
102 | max_batches = (num // loader.batch_size) + 1
103 | else:
104 | assert num == -1
105 | max_batches = None
106 |
107 | try:
108 | y, probs = self.get_gt_preds(islice(loader, max_batches))
109 | y, probs = y[:num], probs[:num]
110 | except NotImplementedError:
111 | print("Model does not support classification.")
112 | return {}
113 |
114 | try:
115 | # Compute accuracy
116 | acc = accuracy(y, probs)
117 |
118 | # Compute calibration
119 | y_np, probs_np = to_np(y), to_np(probs)
120 | ece, mce = classification_calibration(y_np, probs_np)
121 | brier = brier_score(y_np, probs_np)
122 | uncertainty, resolution, reliability = brier_decomposition(y_np, probs_np)
123 |
124 | if self.logger is not None and self.logger.log_dir is not None:
125 | fig, ax = plt.subplots(figsize=(10, 10))
126 | draw_reliability_graph(y_np, probs_np, 10, ax=ax)
127 |                 fig.savefig(path.join(self.logger.log_dir, "calibration.png"), dpi=200)
128 |         except Exception:
129 | return {}
130 |
131 | return {
132 | "Accuracy": acc * 100,
133 | "ECE": ece * 100.0,
134 | "MCE": mce * 100.0,
135 | "Brier": brier * 100,
136 | "Brier uncertainty": uncertainty * 100,
137 | "Brier resolution": resolution * 100,
138 | "Brier reliability": reliability * 100,
139 | "Brier (via decomposition)": (reliability - resolution + uncertainty) * 100,
140 | }
141 |
142 | def setup(self, mode):
143 | if mode == "fit" and self.ood_val_datasets:
144 | batch_size = self.val_dataloader.dataloader.batch_size
145 |
146 | self.ood_val_loaders = []
147 | for ood_ds_name in self.ood_val_datasets:
148 | ood_loader = get_dataloader(
149 | ood_ds_name,
150 | "val",
151 | batch_size=batch_size,
152 | data_shape=self.data_shape,
153 | )
154 | self.ood_val_loaders.append((ood_ds_name, ood_loader))
155 |
156 | def validation_epoch_end(self, outputs):
157 | if self.is_toy_dataset:
158 | interp = torch.linspace(-4, 4, 500)
159 | x, y = torch.meshgrid(interp, interp)
160 | data = torch.stack((x.reshape(-1), y.reshape(-1)), 1).to(self.device)
161 | px = to_np(torch.exp(self(data)))
162 |
163 | fig, ax = plt.subplots()
164 | mesh = ax.pcolormesh(x, y, px.reshape(*x.shape))
165 | fig.colorbar(mesh)
166 | self.logger.experiment.add_figure("dist/p(x)", fig, self.current_epoch)
167 | plt.close()
168 |
169 | if hasattr(self, "ood_val_loaders"):
170 | ood_metrics = self.eval_ood(
171 | self.val_dataloader.dataloader, self.ood_val_loaders
172 | )
173 | self.logger.experiment.add_scalars(
174 | "val/all_ood",
175 | {", ".join(k): v for k, v in ood_metrics.items()},
176 | self.trainer.global_step,
177 | )
178 |
179 | avg_over_dataset_results = defaultdict(list)
180 | for k, v in ood_metrics.items():
181 | avg_over_dataset_results[", ".join(k[1:])].append(v)
182 |
183 | k, v = next(
184 | iter(
185 | {k: np.mean(v) for k, v in avg_over_dataset_results.items()}.items()
186 | )
187 | )
188 |             self.log("val/ood", v)
189 |
190 | def optimizer_step(
191 | self,
192 | epoch: int = None,
193 | batch_idx: int = None,
194 | optimizer=None,
195 | optimizer_idx: int = None,
196 | optimizer_closure=None,
197 | on_tpu: bool = None,
198 | using_native_amp: bool = None,
199 | using_lbfgs: bool = None,
200 | **kwargs,
201 | ):
202 | # learning rate warm-up
203 | if (
204 | optimizer is not None
205 | and hasattr(self, "warmup_steps")
206 | and self.trainer.global_step < self.warmup_steps
207 | ):
208 | lr_scale = min(
209 | 1.0, float(self.trainer.global_step + 1) / float(self.warmup_steps)
210 | )
211 | for pg in optimizer.param_groups:
212 | pg["lr"] = lr_scale * self.hparams.learning_rate
213 |
214 | optimizer.step(closure=optimizer_closure)
215 |
216 | def ood_detect(self, loader):
217 | self.eval()
218 | torch.set_grad_enabled(False)
219 |
220 | scores = defaultdict(list)
221 | for x, _ in tqdm(loader, miniters=100):
222 | if not isinstance(x, torch.Tensor):
223 | x, _ = x
224 | x = x.to(self.device)
225 | out = self.get_ood_scores(x)
226 | for k, v in out.items():
227 | scores[k].append(to_np(v))
228 |
229 | scores = {k: np.concatenate(v) for k, v in scores.items()}
230 | return scores
231 |
232 | def get_gt_preds(self, loader):
233 | self.eval()
234 | torch.set_grad_enabled(False)
235 |
236 | gt, preds = [], []
237 | for x, y in tqdm(loader, miniters=100):
238 | x = x.to(self.device)
239 | y_hat = self.classify(x).cpu()
240 | gt.append(y)
241 | preds.append(y_hat)
242 | return torch.cat(gt), torch.cat(preds)
243 |
244 |     def get_ood_scores(self, x) -> Dict[str, torch.Tensor]:
245 |         raise NotImplementedError
246 |
247 |     def classify(self, x) -> torch.Tensor:
248 | raise NotImplementedError
249 |
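`eval_ood` above scores in-distribution (ID) and OOD samples with the same function and then computes AUROC with two label conventions. The sketch below uses a quadratic-time pairwise AUROC written for illustration (a stand-in for `roc_auc_score`, not the library routine) to show that treating ID as positive with raw scores and treating OOD as positive with negated scores give the same AUROC. The same symmetry does not hold for AUPR, which depends on which class is positive.

```python
def auroc(labels, scores):
    """Probability that a random positive outranks a random negative (ties count half)."""
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy scores: higher means "more in-distribution".
id_scores = [2.0, 1.5, 0.7]
ood_scores = [0.1, 0.9, -1.0]
preds = ood_scores + id_scores

id_positive = auroc([0, 0, 0, 1, 1, 1], preds)
ood_positive = auroc([1, 1, 1, 0, 0, 0], [-s for s in preds])
print(id_positive, ood_positive)
```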
--------------------------------------------------------------------------------
/uncertainty_est/train.py:
--------------------------------------------------------------------------------
1 | # Fix https://github.com/pytorch/pytorch/issues/37377
2 | import numpy as _
3 |
4 | import os
5 | import sys
6 | from uuid import uuid4
7 |
10 | sys.path.insert(0, os.getcwd())
11 |
12 | from pathlib import Path
13 | from datetime import datetime
14 |
15 | import yaml
16 | import hydra
17 | import pytorch_lightning as pl
18 | from omegaconf import DictConfig, OmegaConf
19 |
20 | from uncertainty_est.models import MODELS
21 | from uncertainty_est.data.dataloaders import get_dataloader
22 |
23 |
24 | @hydra.main(config_path="../configs", config_name="config")
25 | def run(cfg: DictConfig) -> None:
26 | cfg = OmegaConf.to_container(cfg.fixed, resolve=True)
27 | run_ex(**cfg, _run=cfg)
28 |
29 |
30 | def run_ex(
31 | trainer_config,
32 | model_name,
33 | model_config,
34 | dataset,
35 | batch_size,
36 | ood_dataset,
37 | earlystop_config,
38 | checkpoint_config,
39 | data_shape,
40 | seed,
41 | _run,
42 | num_classes=1,
43 | sigma=0.0,
44 | output_folder=None,
45 | log_dir=None,
46 | num_workers=4,
47 | test_ood_datasets=[],
48 | mutation_rate=0.0,
49 | num_cat=1,
50 | normalize=True,
51 | **kwargs,
52 | ):
53 | pl.seed_everything(seed)
54 | assert num_classes > 0
55 |
56 | model = MODELS[model_name](**model_config)
57 |
58 | train_loader = get_dataloader(
59 | dataset,
60 | "train",
61 | batch_size,
62 | data_shape=data_shape,
63 | ood_dataset=ood_dataset,
64 | sigma=sigma,
65 | num_workers=num_workers,
66 | mutation_rate=mutation_rate,
67 | normalize=normalize,
68 | )
69 | val_loader = get_dataloader(
70 | dataset,
71 | "val",
72 | batch_size,
73 | data_shape=data_shape,
74 | sigma=sigma,
75 | ood_dataset=ood_dataset,
76 | num_workers=num_workers,
77 | normalize=normalize,
78 | )
79 | test_loader = get_dataloader(
80 | dataset,
81 | "test",
82 | batch_size,
83 | data_shape=data_shape,
84 | sigma=sigma,
85 | ood_dataset=None,
86 | num_workers=num_workers,
87 | normalize=normalize,
88 | )
89 |
90 |     if log_dir is None:
91 | out_path = Path("logs") / model_name / dataset
92 | else:
93 | out_path = Path(log_dir)
94 |
95 | output_folder = (
96 | f'{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}_{uuid4()}'
97 | if output_folder is None
98 | else output_folder
99 | )
100 |
101 | # Circumvent issue when starting multiple versions with the same name
102 | trys = 10
103 | for i in range(trys):
104 | try:
105 | logger = pl.loggers.TensorBoardLogger(
106 | out_path, name=output_folder, default_hp_metric=False
107 | )
108 | out_dir = Path(logger.log_dir)
109 | out_dir.mkdir(exist_ok=False, parents=True)
110 | break
111 | except FileExistsError as e:
112 | if i == (trys - 1):
113 | raise ValueError("Could not create log folder") from e
114 | print("Failed to create unique log folder. Trying again.")
115 |
116 | with (out_dir / "config.yaml").open("w") as f:
117 | f.write(yaml.dump(_run))
118 |
119 | callbacks = []
120 | callbacks.append(pl.callbacks.ModelCheckpoint(dirpath=out_dir, **checkpoint_config))
121 |
122 | if earlystop_config is not None:
123 | es_callback = pl.callbacks.EarlyStopping(**earlystop_config)
124 | callbacks.append(es_callback)
125 |
126 | trainer = pl.Trainer(
127 | **trainer_config,
128 | logger=logger,
129 | callbacks=callbacks,
130 | progress_bar_refresh_rate=100,
131 | )
132 | trainer.fit(model, train_loader, val_loader)
133 |
134 | try:
135 | _ = trainer.test(test_dataloaders=test_loader)
136 | except Exception:
137 | _ = trainer.test(test_dataloaders=test_loader, ckpt_path=None)
138 |
139 | test_ood_dataloaders = []
140 | for test_ood_dataset in test_ood_datasets:
141 | loader = get_dataloader(
142 | test_ood_dataset,
143 | "test",
144 | batch_size,
145 | data_shape=data_shape,
146 | sigma=sigma,
147 | ood_dataset=None,
148 | num_workers=num_workers,
149 | normalize=normalize,
150 | )
151 | test_ood_dataloaders.append((test_ood_dataset, loader))
152 | ood_results = model.eval_ood(test_loader, test_ood_dataloaders)
153 | ood_results = {", ".join(k): v for k, v in ood_results.items()}
154 |
155 | clf_results = model.eval_classifier(test_loader)
156 |
157 | results = {**ood_results, **clf_results}
158 |
159 | logger.log_hyperparams(model.hparams, results)
160 | return results
161 |
162 |
163 | if __name__ == "__main__":
164 | run()
165 |
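The retry loop in `run()` relies on `mkdir(exist_ok=False)` raising `FileExistsError` when a concurrently started run has already claimed the same version directory. A minimal standalone sketch of that pattern follows, using only the standard library. `make_unique_dir` and its arguments are hypothetical names, not part of this repository, and the version index is chosen by the helper itself rather than by a `TensorBoardLogger`.

```python
import tempfile
from pathlib import Path


def make_unique_dir(base: Path, name: str, tries: int = 10) -> Path:
    """Create and return base/name/version_<k>, retrying on name collisions.

    Hypothetical helper: it mimics the retry loop in run(), but picks the
    version index itself instead of asking a logger for it.
    """
    for _ in range(tries):
        parent = base / name
        parent.mkdir(parents=True, exist_ok=True)
        # Next free index, judged from the directories that exist right now.
        version = sum(1 for p in parent.iterdir() if p.name.startswith("version_"))
        candidate = parent / f"version_{version}"
        try:
            # exist_ok=False makes the claim exclusive: only one process wins.
            candidate.mkdir(exist_ok=False)
            return candidate
        except FileExistsError:
            # Another process took this index first; recompute and retry.
            continue
    raise RuntimeError(f"Could not create a unique log folder after {tries} tries")


with tempfile.TemporaryDirectory() as tmp:
    first = make_unique_dir(Path(tmp), "demo")
    second = make_unique_dir(Path(tmp), "demo")
    created = (first.name, second.name)
```

Recomputing the candidate name inside the loop is what makes the retry useful: a second attempt with the same name would fail in the same way.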
--------------------------------------------------------------------------------
/uncertainty_est/utils/dirichlet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import gammaln, digamma
3 |
4 |
5 | def dirichlet_prior_network_uncertainty(logits, epsilon=1e-10, alpha_correction=True):
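6 | """Compute uncertainty measures for a Dirichlet parameterized by `logits`.
7 | :param logits: array-like of shape (num_samples, num_classes).
8 | :param epsilon: small constant added inside the log for numerical stability.
9 | :param alpha_correction: if True, add 1 to every concentration parameter.
10 | :return: dict mapping each measure name to its per-sample values.
11 | """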
12 |
13 | logits = np.asarray(logits, dtype=np.float64)
14 | alphas = np.exp(logits)
15 | if alpha_correction:
16 | alphas = alphas + 1
17 | alpha0 = np.sum(alphas, axis=1, keepdims=True)
18 | probs = alphas / alpha0
19 |
20 | conf = np.max(probs, axis=1)
21 |
22 | entropy_of_exp = -np.sum(probs * np.log(probs + epsilon), axis=1)
23 | expected_entropy = -np.sum(
24 | (alphas / alpha0) * (digamma(alphas + 1) - digamma(alpha0 + 1.0)), axis=1
25 | )
26 | mutual_info = entropy_of_exp - expected_entropy
27 |
28 | epkl = np.squeeze((alphas.shape[1] - 1.0) / alpha0)
29 |
30 | dentropy = (
31 | np.sum(
32 | gammaln(alphas) - (alphas - 1.0) * (digamma(alphas) - digamma(alpha0)),
33 | axis=1,
34 | keepdims=True,
35 | )
36 | - gammaln(alpha0)
37 | )
38 |
39 | uncertainty = {
40 | "confidence": 1 - conf,
41 | "entropy_of_expected": entropy_of_exp,
42 | "expected_entropy": expected_entropy,
43 | "mutual_information": mutual_info,
44 | "EPKL": epkl,
45 | "differential_entropy": np.squeeze(dentropy),
46 | }
47 |
48 | return uncertainty
49 |
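As a sanity check on the closed-form measures in `dirichlet_prior_network_uncertainty`, the sketch below recomputes the ones that need no special functions (the max-probability score, the entropy of the expected distribution, and EPKL) for a single logit vector using only the `math` module. `simple_dirichlet_measures` is a hypothetical helper written for illustration; it mirrors the `alpha_correction=True` path and omits the measures that require `digamma` or `gammaln`.

```python
import math


def simple_dirichlet_measures(logits, epsilon=1e-10):
    """Recompute a few measures for one sample (hypothetical helper)."""
    # Concentrations: exp(logit) + 1, as in the alpha_correction=True branch.
    alphas = [math.exp(logit) + 1.0 for logit in logits]
    alpha0 = sum(alphas)
    # Expected class probabilities under the Dirichlet.
    probs = [a / alpha0 for a in alphas]
    entropy_of_expected = -sum(p * math.log(p + epsilon) for p in probs)
    return {
        "confidence": 1.0 - max(probs),
        "entropy_of_expected": entropy_of_expected,
        "EPKL": (len(alphas) - 1.0) / alpha0,
    }


# All-zero logits give alphas = [2, 2, 2], a symmetric Dirichlet.
flat = simple_dirichlet_measures([0.0, 0.0, 0.0])
# One large logit concentrates the mass on the first class.
peaked = simple_dirichlet_measures([5.0, 0.0, 0.0])
```

Both the max-probability score and EPKL come out smaller for the peaked input than for the flat one, which matches the convention that larger values mean more uncertainty.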
--------------------------------------------------------------------------------
/uncertainty_est/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(y: torch.Tensor, y_hat: torch.Tensor):
5 | return (y == y_hat.argmax(dim=1)).float().mean(0).item()
6 |
--------------------------------------------------------------------------------
/uncertainty_est/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from scipy.integrate import trapz
4 |
5 |
6 | def to_np(tensor: torch.Tensor):
7 | return tensor.detach().cpu().numpy()
8 |
9 |
10 | def eval_func_on_grid(
11 | density_func,
12 | interval=(-10, 10),
13 | num_samples=200,
14 | device="cpu",
15 | dimensions=2,
16 | batch_size=10_000,
17 | dtype=torch.float32,
18 | ):
19 | interp = torch.linspace(*interval, num_samples)
20 | grid_coords = torch.meshgrid(*[interp for _ in range(dimensions)])
21 | grid = torch.stack([coords.reshape(-1) for coords in grid_coords], 1).to(dtype)
22 |
23 | vals = []
24 | for samples in tqdm(torch.split(grid, batch_size)):
25 | vals.append(density_func(samples.to(device)).cpu())
26 | vals = torch.cat(vals)
27 | return grid_coords, vals
28 |
29 |
30 | def estimate_normalizing_constant(
31 | density_func,
32 | interval=(-10, 10),
33 | num_samples=200,
34 | device="cpu",
35 | dimensions=2,
36 | batch_size=10_000,
37 | dtype=torch.float32,
38 | ):
39 | """
40 | Numerically integrate a function over the specified interval in each dimension.
41 | """
42 | with torch.no_grad():
43 | _, p_x = eval_func_on_grid(
44 | density_func, interval, num_samples, device, dimensions, batch_size, dtype
45 | )
46 |
47 | dx = (interval[1] - interval[0]) / (num_samples - 1)
48 | # Integrate one dimension after another
49 | grid_vals = to_np(p_x).reshape(*[num_samples for _ in range(dimensions)])
50 | for _ in range(dimensions):
51 | grid_vals = trapz(grid_vals, dx=dx, axis=-1)
52 |
53 | return torch.tensor(grid_vals)
54 |
55 |
56 | def sum_except_batch(x, num_batch_dims=1):
57 | """Sums all elements of `x` except for the first `num_batch_dims` dimensions."""
58 | if x.ndimension() == 1:
59 | return x
60 | reduce_dims = list(range(num_batch_dims, x.ndimension()))
61 | return torch.sum(x, dim=reduce_dims)
62 |
63 |
64 | def split_leading_dim(x, shape):
65 | """Reshapes the leading dim of `x` to have the given shape."""
66 | new_shape = torch.Size(shape) + x.shape[1:]
67 | return torch.reshape(x, new_shape)
68 |
69 |
70 | if __name__ == "__main__":
71 | dims = 2
72 | samples = 100
73 | print(
74 | estimate_normalizing_constant(
75 | lambda x: torch.empty(x.shape[0]).fill_(1 / (samples ** dims)),
76 | num_samples=samples,
77 | dimensions=dims,
78 | )
79 | )
80 |
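The grid integration in `estimate_normalizing_constant` can be cross-checked without PyTorch or SciPy. The sketch below applies the composite trapezoidal rule on `n` evenly spaced points that include both endpoints, which is how `torch.linspace` lays out the grid, so neighbouring points are `(b - a) / (n - 1)` apart. `trapezoid_1d` and `std_normal_pdf` are hypothetical helpers for illustration only.

```python
import math


def trapezoid_1d(func, a, b, n):
    """Composite trapezoidal rule on n evenly spaced points in [a, b]."""
    dx = (b - a) / (n - 1)  # n points span n - 1 intervals
    xs = [a + i * dx for i in range(n)]
    ys = [func(x) for x in xs]
    # Interior points get full weight, the two endpoints half weight.
    return dx * (sum(ys) - 0.5 * (ys[0] + ys[-1]))


def std_normal_pdf(x):
    """Density of the standard normal distribution."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


# A normalized density should integrate to about 1 over a wide interval.
z_normal = trapezoid_1d(std_normal_pdf, -10.0, 10.0, 200)
# A constant 1/20 on [-10, 10] integrates to 1 up to rounding.
z_const = trapezoid_1d(lambda x: 1.0 / 20.0, -10.0, 10.0, 200)
```

For a `d`-dimensional grid the same rule is applied once per axis, which is what the loop over `dimensions` does with `axis=-1`.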
--------------------------------------------------------------------------------