├── .env.example ├── .gitignore ├── LICENSE ├── environment.yml ├── hopfield_boosting ├── __init__.py ├── __main__.py ├── config │ ├── __init__.py │ ├── auxiliary │ │ ├── cifar-10.yaml │ │ ├── cifar-100.yaml │ │ └── dataset │ │ │ └── cifar-10-imagenet.yaml │ ├── data │ │ └── hopfield_base.yaml │ ├── dataset │ │ ├── cifar-10-val.yaml │ │ ├── cifar-10.yaml │ │ ├── cifar-100-val.yaml │ │ └── cifar-100.yaml │ ├── early_stopping │ │ └── low_accuracy.yaml │ ├── energy │ │ └── border_energy.yaml │ ├── loss │ │ └── ce.yaml │ ├── model │ │ └── resnet_18_32x32.yaml │ ├── ood_eval │ │ ├── cifar-10.yaml │ │ ├── cifar-100.yaml │ │ └── data │ │ │ ├── cifar-10-eval.yaml │ │ │ └── cifar-10-test.yaml │ ├── optim │ │ └── sgd.yaml │ ├── paths │ │ └── dotenv.yaml │ ├── preprocess │ │ ├── cifar-10.yaml │ │ └── cifar-100.yaml │ ├── projection_head │ │ ├── identity.yaml │ │ ├── linear.yaml │ │ └── multilayer.yaml │ ├── resnet-18-cifar-10-aux-from-scratch.yaml │ ├── resnet-18-cifar-100-aux-from-scratch.yaml │ ├── scheduler │ │ └── cosine_annealing.yaml │ ├── trainer │ │ ├── energy.yaml │ │ ├── hopfield.yaml │ │ └── oe.yaml │ └── transform │ │ ├── cifar-10.yaml │ │ └── to-tensor.yaml ├── data │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── imagenet.py │ │ ├── kamnist.py │ │ ├── noise.py │ │ ├── random_noise.py │ │ └── svhn.py │ ├── setup.py │ └── softmax_sampler.py ├── download_data.py ├── early_stopping.py ├── encoder │ ├── __init__.py │ ├── densenet.py │ ├── identity.py │ ├── mlp.py │ └── resnet_18_32x32.py ├── energy.py ├── logger │ ├── __init__.py │ ├── base.py │ ├── console.py │ ├── file.py │ └── wandb.py ├── main.py ├── ood │ ├── __init__.py │ ├── evaluator.py │ ├── metrics.py │ └── test │ │ ├── test_detector.py │ │ └── test_metrics.py ├── trainer.py ├── transforms │ └── img_transforms.py ├── util.py └── utils │ ├── config_util.py │ ├── eval_util.py │ ├── model_util.py │ └── plot_util.py ├── images └── figure1.png ├── notebooks └── hopfield_boosting_demo_extended.ipynb ├── readme.md └── setup.py /.env.example: -------------------------------------------------------------------------------- 1 | # -------------------------- 2 | # Hopfield Boosting Training 3 | # -------------------------- 4 | 5 | # Downloaded data set root (downloaded data sets will be placed there) 6 | DOWNLOADED_PATH=downloaded_datasets 7 | 8 | # Project root (Model checkpoints will be stored here) 9 | PROJECT_ROOT=project_root 10 | 11 | # In-Distribution training data sets 12 | CIFAR10_ROOT=${DOWNLOADED_PATH} 13 | CIFAR100_ROOT=${DOWNLOADED_PATH} 14 | 15 | 16 | # AUX data set paths 17 | IMAGENET_ROOT=/path/to/dataset 18 | 19 | 20 | # OOD data set roots 21 | MNIST_ROOT=${DOWNLOADED_PATH} 22 | FASHION_MNIST_ROOT=${DOWNLOADED_PATH} 23 | 24 | 25 | SVHN_ROOT=${DOWNLOADED_PATH}/svhn 26 | PLACES_ROOT=${DOWNLOADED_PATH}/places365 27 | LSUN_CROP_ROOT=${DOWNLOADED_PATH}/LSUN 28 | LSUN_RESIZE_ROOT=${DOWNLOADED_PATH}/LSUN_resize 29 | ISUN_ROOT=${DOWNLOADED_PATH}/iSUN 30 | TEXTURES_ROOT=${DOWNLOADED_PATH}/dtd/images 31 | 32 | 33 | # --------------------- 34 | # Jupyter Notebook only 35 | # (you don't have to set these for training) 36 | # --------------------- 37 | 38 | # Path to the pre-trained model (required for the Jupyter Notebook only) 39 | BASE_MODEL_PATH=trained_model 40 | 41 | # Optional data sets (required for the Jupyter Notebook only) 42 | RPC_ROOT=/path/to/dataset 43 | ICARTOONFACE_ROOT=/path/to/dataset 44 | SHAPES3D_ROOT=/path/to/dataset 45 | FOUR_SHAPES_ROOT=/path/to/dataset 46 |
AFHQV2_ROOT=/path/to/dataset 47 | MOEIMOUTO_FACES_ROOT=/path/to/dataset 48 | IMAGENETO_ROOT=/path/to/dataset 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | .pytest_cache 4 | *.egg-info 5 | wandb 6 | saved_models 7 | datasets/ood_datasets 8 | runs 9 | outputs 10 | multirun 11 | *.pth 12 | build/ 13 | notebooks/.ipynb_checkpoints 14 | test_logs/ 15 | downloaded_datasets/ 16 | hopfield_boosting/config/paths/iml.yaml 17 | run_sweep.py 18 | .env 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Institute for Machine Learning, Johannes Kepler University Linz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
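The variables in .env.example above reach the Hydra configs in two steps: load_dotenv() (called in main.py and download_data.py) copies them from a local .env file into the process environment, and config/paths/dotenv.yaml resolves them with OmegaConf's ${oc.env:VAR,default} resolver. A minimal sketch of that resolution, assuming only that python-dotenv and omegaconf are installed:

import os
from dotenv import load_dotenv
from omegaconf import OmegaConf

load_dotenv()  # copies the variables from a local .env file into os.environ

# mirrors one entry of config/paths/dotenv.yaml
paths = OmegaConf.create({'cifar-10': '${oc.env:CIFAR10_ROOT,/path/to/dataset}'})
print(paths['cifar-10'])  # env value if CIFAR10_ROOT is set, otherwise the default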
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hopfield-boosting 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.8 7 | - pytorch=2.2 8 | - torchvision 9 | - torchaudio 10 | - pytorch-cuda=12.1 11 | - pip 12 | - pip: 13 | - wandb 14 | - hydra-core 15 | - scipy 16 | - seaborn 17 | - matplotlib 18 | - numpy 19 | - scikit-learn 20 | - tqdm 21 | - python-dotenv -------------------------------------------------------------------------------- /hopfield_boosting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/hopfield-boosting/c5ae5b7e57d256a0a7848729052d04140d35dca8/hopfield_boosting/__init__.py -------------------------------------------------------------------------------- /hopfield_boosting/__main__.py: -------------------------------------------------------------------------------- 1 | from hopfield_boosting.main import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/hopfield-boosting/c5ae5b7e57d256a0a7848729052d04140d35dca8/hopfield_boosting/config/__init__.py -------------------------------------------------------------------------------- /hopfield_boosting/config/auxiliary/cifar-10.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar-10-imagenet 3 | 4 | _target_: hopfield_boosting.data.DataSetup 5 | wrapper: ~ 6 | 7 | loader: 8 | _target_: torch.utils.data.DataLoader 9 | _partial_: true 10 | shuffle: true 11 | batch_size: ${aux_batch_size} 12 | -------------------------------------------------------------------------------- /hopfield_boosting/config/auxiliary/cifar-100.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - cifar-10 -------------------------------------------------------------------------------- /hopfield_boosting/config/auxiliary/dataset/cifar-10-imagenet.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.data.datasets.imagenet.ImageNet 2 | root: ${paths.imagenet} 3 | transform: 4 | _target_: torchvision.transforms.Compose 5 | transforms: 6 | - _target_: torchvision.transforms.RandomCrop 7 | _args_: 8 | - 32 9 | - ${transform} 10 | -------------------------------------------------------------------------------- /hopfield_boosting/config/data/hopfield_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - /dataset@aug.train.dataset: ${dataset} 4 | - /dataset@aug.val.dataset: ${dataset}-val 5 | - /dataset@no_aug.train.dataset: ${dataset} 6 | - /dataset@no_aug.val.dataset: ${dataset}-val 7 | - /transform@no_aug.train.dataset.transform: to-tensor 8 | - /transform@no_aug.val.dataset.transform: to-tensor 9 | 10 | aug: 11 | train: 12 | _target_: hopfield_boosting.data.DataSetup 13 | dataset: 14 | transform: ${transform} 15 | wrapper: ~ 16 | loader: 17 | _target_: torch.utils.data.DataLoader 18 | _partial_: true 19 | shuffle: true 20 | drop_last: true 21 | batch_size: ${train_batch_size} 22 | val: 23 | _target_: 
hopfield_boosting.data.DataSetup 24 | dataset: 25 | transform: ${transform} 26 | wrapper: ~ 27 | batch_sampler: ~ 28 | loader: 29 | _target_: torch.utils.data.DataLoader 30 | _partial_: true 31 | batch_size: ${val_batch_size} 32 | 33 | no_aug: ${data.aug} 34 | -------------------------------------------------------------------------------- /hopfield_boosting/config/dataset/cifar-10-val.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.datasets.CIFAR10 2 | root: ${paths.cifar-10} 3 | train: false 4 | download: true 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/dataset/cifar-10.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.datasets.CIFAR10 2 | root: ${paths.cifar-10} 3 | train: true 4 | download: true 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/dataset/cifar-100-val.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.datasets.CIFAR100 2 | root: ${paths.cifar-100} 3 | train: false 4 | download: true 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/dataset/cifar-100.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.datasets.CIFAR100 2 | root: ${paths.cifar-100} 3 | train: true 4 | download: true 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/early_stopping/low_accuracy.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.early_stopping.LowAccuracyEarlyStopping 2 | epoch: ??? 3 | accuracy: ??? 
4 | -------------------------------------------------------------------------------- /hopfield_boosting/config/energy/border_energy.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.energy.BorderEnergy 2 | _partial_: true 3 | beta_a: ${beta} 4 | beta_b: ${beta} 5 | beta_border: ${beta} 6 | mask_diagonale: false 7 | -------------------------------------------------------------------------------- /hopfield_boosting/config/loss/ce.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.CrossEntropyLoss 2 | -------------------------------------------------------------------------------- /hopfield_boosting/config/model/resnet_18_32x32.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.encoder.CNNOODWrapper 2 | cnn: 3 | _target_: hopfield_boosting.encoder.resnet_18_32x32.ResNet18_32x32 4 | preprocess: ${preprocess} 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/ood_eval/cifar-10.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /transform@validation.in_dataset.dataset.transform: to-tensor 3 | - data@validation.out_datasets: cifar-10-eval 4 | - data@test.out_datasets: cifar-10-test 5 | 6 | validation: 7 | _target_: hopfield_boosting.ood.OODEvaluator 8 | in_dataset: 9 | _target_: hopfield_boosting.data.DataSetup 10 | dataset: 11 | _target_: torchvision.datasets.CIFAR10 12 | root: ${paths.cifar-10} 13 | train: false 14 | loader: 15 | _target_: torch.utils.data.DataLoader 16 | _partial_: true 17 | batch_size: ${val_batch_size} 18 | metrics: 19 | fpr95: 20 | _target_: hopfield_boosting.ood.FPR95OODMetric 21 | auroc: 22 | _target_: hopfield_boosting.ood.AUROCOODMetric 23 | auprs: 24 | _target_: hopfield_boosting.ood.AUPRSOODMetric 25 | logger: 26 | _target_: hopfield_boosting.logger.WandbLogger 27 | device: ${device} 28 | 29 | test: 30 | _target_: hopfield_boosting.ood.OODEvaluator 31 | in_dataset: ${ood_eval.validation.in_dataset} 32 | metrics: ${ood_eval.validation.metrics} 33 | logger: 34 | _target_: hopfield_boosting.logger.FileLogger 35 | path: test_logs/ 36 | device: ${device} 37 | -------------------------------------------------------------------------------- /hopfield_boosting/config/ood_eval/cifar-100.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /transform@validation.in_dataset.dataset.transform: to-tensor 3 | - data@validation.out_datasets: cifar-10-eval 4 | - data@test.out_datasets: cifar-10-test 5 | 6 | validation: 7 | _target_: hopfield_boosting.ood.OODEvaluator 8 | in_dataset: 9 | _target_: hopfield_boosting.data.DataSetup 10 | dataset: 11 | _target_: torchvision.datasets.CIFAR100 12 | root: ${paths.cifar-100} 13 | train: false 14 | loader: 15 | _target_: torch.utils.data.DataLoader 16 | _partial_: true 17 | batch_size: ${val_batch_size} 18 | metrics: 19 | fpr95: 20 | _target_: hopfield_boosting.ood.FPR95OODMetric 21 | auroc: 22 | _target_: hopfield_boosting.ood.AUROCOODMetric 23 | auprs: 24 | _target_: hopfield_boosting.ood.AUPRSOODMetric 25 | logger: 26 | _target_: hopfield_boosting.logger.WandbLogger 27 | device: ${device} 28 | 29 | test: 30 | _target_: hopfield_boosting.ood.OODEvaluator 31 | in_dataset: ${ood_eval.validation.in_dataset} 32 | metrics: ${ood_eval.validation.metrics} 33 | logger: 34 | _target_: 
hopfield_boosting.logger.FileLogger 35 | path: test_logs/ 36 | device: ${device} 37 | -------------------------------------------------------------------------------- /hopfield_boosting/config/ood_eval/data/cifar-10-eval.yaml: -------------------------------------------------------------------------------- 1 | imagenet: 2 | _target_: hopfield_boosting.data.DataSetup 3 | dataset: 4 | _target_: hopfield_boosting.data.datasets.ImageNet 5 | root: ${paths.imagenet} 6 | transform: 7 | _target_: torchvision.transforms.Compose 8 | transforms: 9 | - _target_: torchvision.transforms.Resize 10 | _args_: 11 | - 32 12 | - _target_: torchvision.transforms.CenterCrop 13 | _args_: 14 | - 32 15 | - _target_: torchvision.transforms.ToTensor 16 | loader: 17 | _target_: torch.utils.data.DataLoader 18 | _partial_: true 19 | batch_size: ${val_batch_size} 20 | num_workers: 2 21 | 22 | mnist: 23 | _target_: hopfield_boosting.data.DataSetup 24 | dataset: 25 | _target_: torchvision.datasets.MNIST 26 | root: ${paths.mnist} 27 | download: true 28 | transform: 29 | _target_: torchvision.transforms.Compose 30 | transforms: 31 | - _target_: torchvision.transforms.Resize 32 | _args_: 33 | - 32 34 | - _target_: torchvision.transforms.CenterCrop 35 | _args_: 36 | - 32 37 | - _target_: torchvision.transforms.Grayscale 38 | num_output_channels: 3 39 | - _target_: torchvision.transforms.ToTensor 40 | loader: ${ood_eval.validation.out_datasets.imagenet.loader} 41 | 42 | uniform_noise: 43 | _target_: hopfield_boosting.data.DataSetup 44 | dataset: 45 | _target_: hopfield_boosting.data.datasets.noise.UniformNoise 46 | len: 10000 47 | img_size: [3, 32, 32] 48 | loader: ${ood_eval.validation.out_datasets.imagenet.loader} 49 | 50 | gaussian_noise: 51 | _target_: hopfield_boosting.data.DataSetup 52 | dataset: 53 | _target_: hopfield_boosting.data.datasets.noise.GaussianNoise 54 | len: 10000 55 | img_size: [3, 32, 32] 56 | mean: [0.4914, 0.4822, 0.4465] 57 | std: [0.2023, 0.1994, 0.2010] 58 | loader: ${ood_eval.validation.out_datasets.imagenet.loader} 59 | -------------------------------------------------------------------------------- /hopfield_boosting/config/ood_eval/data/cifar-10-test.yaml: -------------------------------------------------------------------------------- 1 | svhn: 2 | _target_: hopfield_boosting.data.DataSetup 3 | dataset: 4 | _target_: hopfield_boosting.data.datasets.SVHN 5 | root: ${paths.svhn} 6 | split: test 7 | download: false 8 | transform: 9 | _target_: torchvision.transforms.ToTensor 10 | loader: 11 | _target_: torch.utils.data.DataLoader 12 | _partial_: true 13 | batch_size: ${val_batch_size} 14 | num_workers: 2 15 | 16 | dtd: 17 | _target_: hopfield_boosting.data.DataSetup 18 | dataset: 19 | _target_: torchvision.datasets.ImageFolder 20 | root: ${paths.textures} 21 | transform: 22 | _target_: torchvision.transforms.Compose 23 | transforms: 24 | - _target_: torchvision.transforms.Resize 25 | _args_: 26 | - 32 27 | - _target_: torchvision.transforms.CenterCrop 28 | _args_: 29 | - 32 30 | - _target_: torchvision.transforms.ToTensor 31 | loader: ${ood_eval.test.out_datasets.svhn.loader} 32 | 33 | places365: 34 | _target_: hopfield_boosting.data.DataSetup 35 | dataset: 36 | _target_: torch.utils.data.Subset 37 | dataset: 38 | _target_: torchvision.datasets.ImageFolder 39 | root: ${paths.places} 40 | transform: 41 | _target_: torchvision.transforms.Compose 42 | transforms: 43 | - _target_: torchvision.transforms.Resize 44 | _args_: 45 | - 32 46 | - _target_: torchvision.transforms.CenterCrop 47 | _args_: 48 | - 
32 49 | - _target_: torchvision.transforms.ToTensor 50 | indices: 51 | _target_: numpy.random.choice 52 | a: 328500 53 | size: 10000 54 | replace: false 55 | loader: ${ood_eval.test.out_datasets.svhn.loader} 56 | 57 | lsun: 58 | _target_: hopfield_boosting.data.DataSetup 59 | dataset: 60 | _target_: torchvision.datasets.ImageFolder 61 | root: ${paths.lsun-crop} 62 | transform: 63 | _target_: torchvision.transforms.Compose 64 | transforms: 65 | - _target_: torchvision.transforms.Resize 66 | _args_: 67 | - 32 68 | - _target_: torchvision.transforms.CenterCrop 69 | _args_: 70 | - 32 71 | - _target_: torchvision.transforms.ToTensor 72 | loader: ${ood_eval.test.out_datasets.svhn.loader} 73 | 74 | lsun_resize: 75 | _target_: hopfield_boosting.data.DataSetup 76 | dataset: 77 | _target_: torchvision.datasets.ImageFolder 78 | root: ${paths.lsun-resize} 79 | transform: 80 | _target_: torchvision.transforms.Compose 81 | transforms: 82 | - _target_: torchvision.transforms.Resize 83 | _args_: 84 | - 32 85 | - _target_: torchvision.transforms.CenterCrop 86 | _args_: 87 | - 32 88 | - _target_: torchvision.transforms.ToTensor 89 | loader: ${ood_eval.test.out_datasets.svhn.loader} 90 | 91 | isun: 92 | _target_: hopfield_boosting.data.DataSetup 93 | dataset: 94 | _target_: torchvision.datasets.ImageFolder 95 | root: ${paths.isun} 96 | transform: 97 | _target_: torchvision.transforms.Compose 98 | transforms: 99 | - _target_: torchvision.transforms.Resize 100 | _args_: 101 | - 32 102 | - _target_: torchvision.transforms.CenterCrop 103 | _args_: 104 | - 32 105 | - _target_: torchvision.transforms.ToTensor 106 | loader: ${ood_eval.test.out_datasets.svhn.loader} 107 | -------------------------------------------------------------------------------- /hopfield_boosting/config/optim/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.SGD 2 | _partial_: true 3 | lr: ??? 4 | momentum: ??? 
5 | nesterov: true 6 | -------------------------------------------------------------------------------- /hopfield_boosting/config/paths/dotenv.yaml: -------------------------------------------------------------------------------- 1 | project_root: ${oc.env:PROJECT_ROOT,/path/to/project_root} 2 | 3 | cifar-10: ${oc.env:CIFAR10_ROOT,/path/to/dataset} 4 | 5 | cifar-100: ${oc.env:CIFAR100_ROOT,/path/to/dataset} 6 | 7 | imagenet: ${oc.env:IMAGENET_ROOT,/path/to/dataset} 8 | 9 | svhn: ${oc.env:SVHN_ROOT,/path/to/dataset} 10 | 11 | lsun-crop: ${oc.env:LSUN_CROP_ROOT,/path/to/dataset} 12 | 13 | lsun-resize: ${oc.env:LSUN_RESIZE_ROOT,/path/to/dataset} 14 | 15 | isun: ${oc.env:ISUN_ROOT,/path/to/dataset} 16 | 17 | places: ${oc.env:PLACES_ROOT,/path/to/dataset} 18 | 19 | textures: ${oc.env:TEXTURES_ROOT,/path/to/dataset} 20 | 21 | mnist: ${oc.env:MNIST_ROOT,/path/to/dataset} 22 | 23 | fashion_mnist: ${oc.env:FASHION_MNIST_ROOT,/path/to/dataset} 24 | -------------------------------------------------------------------------------- /hopfield_boosting/config/preprocess/cifar-10.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Compose 2 | transforms: 3 | - _target_: torchvision.transforms.Normalize 4 | mean: [0.4914, 0.4822, 0.4465] 5 | std: [0.2023, 0.1994, 0.2010] 6 | -------------------------------------------------------------------------------- /hopfield_boosting/config/preprocess/cifar-100.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - cifar-10 3 | -------------------------------------------------------------------------------- /hopfield_boosting/config/projection_head/identity.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.Identity -------------------------------------------------------------------------------- /hopfield_boosting/config/projection_head/linear.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.Linear 2 | _args_: 3 | - 512 4 | - 512 5 | -------------------------------------------------------------------------------- /hopfield_boosting/config/projection_head/multilayer.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.encoder.mlp.MLP 2 | no_layers: 2 3 | in_dim: 512 4 | hidden_dim: 256 5 | out_dim: 128 -------------------------------------------------------------------------------- /hopfield_boosting/config/resnet-18-cifar-10-aux-from-scratch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - loss: ce 3 | - optim: sgd 4 | - data: hopfield_base 5 | - transform: cifar-10 6 | - ood_eval: ${dataset} 7 | - auxiliary: ${dataset} 8 | - preprocess: ${dataset} 9 | - dataset: cifar-10 10 | - scheduler: cosine_annealing 11 | - early_stopping: low_accuracy 12 | - energy: border_energy 13 | - model: resnet_18_32x32 14 | - trainer: hopfield 15 | - projection_head: multilayer 16 | - paths: dotenv 17 | - _self_ 18 | 19 | no_epochs: 100 20 | device: cuda 21 | train_batch_size: 128 22 | aux_batch_size: 2500 23 | val_batch_size: 1024 24 | image_size: 32 25 | force_deterministic: false 26 | num_classes: 10 27 | use_seed: false 28 | 29 | early_stopping: 30 | accuracy: 0.3 31 | epoch: 5 32 | 33 | beta: 4 34 | 35 | optim: 36 | lr: 0.1 37 | weight_decay: 0.0005 38 | momentum: 0.9 39 | 40 | hydra: 41 | job: 42 | chdir: false 43 | 44 | 
project_name: hopfield_boosting 45 | 46 | ood_weight: 0.5 47 | 48 | sampler: 49 | _target_: hopfield_boosting.data.softmax_sampler.SoftmaxBorderSampler 50 | num_tested_samples: 400_000 51 | out_batch_size: 128 52 | replacement: true 53 | device: ${device} 54 | 55 | do_sampling: true 56 | recursive_sampling: false 57 | -------------------------------------------------------------------------------- /hopfield_boosting/config/resnet-18-cifar-100-aux-from-scratch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - resnet-18-cifar-10-aux-from-scratch 3 | - override dataset: cifar-100 4 | 5 | project_name: hopfield_boosting-cifar-100 6 | num_classes: 100 7 | -------------------------------------------------------------------------------- /hopfield_boosting/config/scheduler/cosine_annealing.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 2 | eta_min: 1e-6 3 | T_max: please_set_manually 4 | -------------------------------------------------------------------------------- /hopfield_boosting/config/trainer/energy.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.trainer.EnergyTrainer 2 | model: ${model} 3 | classifier: 4 | _target_: torch.nn.Linear 5 | _args_: 6 | - 512 7 | - ${num_classes} 8 | criterion: ${loss} 9 | optim: ${optim} 10 | energy_fn: ~ 11 | ood_weight: 0.1 12 | m_in: -23 13 | m_out: -5 14 | -------------------------------------------------------------------------------- /hopfield_boosting/config/trainer/hopfield.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.trainer.HopfieldTrainer 2 | model: ${model} 3 | classifier: 4 | _target_: torch.nn.Linear 5 | _args_: 6 | - 512 7 | - ${num_classes} 8 | criterion: ${loss} 9 | optim: ${optim} 10 | energy_fn: ${energy} 11 | ood_weight: ${ood_weight} 12 | beta: ${beta} 13 | projection_head: ${projection_head} 14 | use_ood: true 15 | -------------------------------------------------------------------------------- /hopfield_boosting/config/trainer/oe.yaml: -------------------------------------------------------------------------------- 1 | _target_: hopfield_boosting.trainer.MSPTrainer 2 | model: ${model} 3 | classifier: 4 | _target_: torch.nn.Linear 5 | _args_: 6 | - 512 7 | - ${num_classes} 8 | criterion: ${loss} 9 | optim: ${optim} 10 | energy_fn: ~ 11 | ood_weight: 0.5 12 | -------------------------------------------------------------------------------- /hopfield_boosting/config/transform/cifar-10.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Compose 2 | transforms: 3 | - _target_: torchvision.transforms.RandomCrop 4 | size: 32 5 | padding: 4 6 | - _target_: torchvision.transforms.RandomHorizontalFlip 7 | - _target_: torchvision.transforms.ToTensor 8 | -------------------------------------------------------------------------------- /hopfield_boosting/config/transform/to-tensor.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.ToTensor 2 | -------------------------------------------------------------------------------- /hopfield_boosting/data/__init__.py: -------------------------------------------------------------------------------- 1 | from hopfield_boosting.data.setup import DataSetup 2 | from 
hopfield_boosting.data.softmax_sampler import SoftmaxBorderSampler -------------------------------------------------------------------------------- /hopfield_boosting/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from hopfield_boosting.data.datasets.imagenet import ImageNet 2 | from hopfield_boosting.data.datasets.svhn import SVHN 3 | -------------------------------------------------------------------------------- /hopfield_boosting/data/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | from torchvision.transforms import ToPILImage 7 | 8 | 9 | def unpickle(file): 10 | with open(file, 'rb') as fo: 11 | dict = pickle.load(fo) 12 | return dict 13 | 14 | class ImageNet(torch.utils.data.Dataset): 15 | 16 | def __init__(self, transform=None, img_size=64, root=None): 17 | self.root = root 18 | self.S = np.zeros(11, dtype=np.int32) 19 | self.img_size = img_size 20 | self.labels = [] 21 | for idx in range(1, 11): 22 | data_file = os.path.join(self.root, 'train_data_batch_{}'.format(idx)) 23 | d = unpickle(data_file) 24 | y = d['labels'] 25 | y = [i-1 for i in y] 26 | self.labels.extend(y) 27 | self.S[idx] = self.S[idx-1] + len(y) 28 | 29 | self.labels = np.array(self.labels) 30 | self.N = len(self.labels) 31 | self.curr_batch = -1 32 | 33 | self.transform = transform 34 | self.batch_cache = dict() 35 | 36 | def load_image_batch(self, batch_index): 37 | if batch_index in self.batch_cache: 38 | return self.batch_cache[batch_index] 39 | 40 | data_file = os.path.join(self.root, 'train_data_batch_{}'.format(batch_index)) 41 | d = unpickle(data_file) 42 | x = d['data'] 43 | 44 | img_size = self.img_size 45 | img_size2 = img_size * img_size 46 | x = np.dstack((x[:, :img_size2], x[:, img_size2:2*img_size2], x[:, 2*img_size2:])) 47 | x = x.reshape((x.shape[0], img_size, img_size, 3)) 48 | 49 | self.batch_cache[batch_index] = x 50 | return x 51 | 52 | def get_batch_index(self, index): 53 | j = 1 54 | while index >= self.S[j]: 55 | j += 1 56 | return j 57 | 58 | def load_image(self, index): 59 | batch_index = self.get_batch_index(index) 60 | batch = self.load_image_batch(batch_index)  # batches are cached, so repeated accesses to the same file are cheap 61 | self.curr_batch = batch_index 62 | 63 | return batch[index-self.S[batch_index-1]] 64 | 65 | def __getitem__(self, index): 66 | img = ToPILImage()(self.load_image(index)) 67 | if self.transform is not None: 68 | img = self.transform(img) 69 | 70 | return img, index 71 | 72 | def __len__(self): 73 | return self.N -------------------------------------------------------------------------------- /hopfield_boosting/data/datasets/kamnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | 5 | class KaMNIST(torch.utils.data.Dataset): 6 | def __init__(self, root='../datasets/ood_datasets/kamnist/', transform=None): 7 | self.kamnist_x = np.load(f'{root}/X_KaMNIST_ICLR.npy') 8 | self.kamnist_y = np.load(f'{root}/y_KaMNIST_ICLR.npy') 9 | self.transform = transform 10 | 11 | def __getitem__(self, idx): 12 | return self.transform(torchvision.transforms.ToPILImage()(self.kamnist_x[idx])), self.kamnist_y[idx] 13 | 14 | def __len__(self): 15 | return len(self.kamnist_x) 16 | --------------------------------------------------------------------------------
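The ImageNet class above serves the downsampled ImageNet64 training batches (train_data_batch_1 through train_data_batch_10) as the auxiliary outlier pool: S accumulates the batch sizes so that get_batch_index can map a global index to its pickle file, and __getitem__ deliberately returns (image, index) rather than (image, label), since the outlier sampler only needs to re-identify samples by index. A hypothetical usage sketch, assuming the ten pickle files are present under IMAGENET_ROOT:

import os
from torch.utils.data import DataLoader
from torchvision import transforms
from hopfield_boosting.data.datasets.imagenet import ImageNet

aux = ImageNet(
    root=os.environ['IMAGENET_ROOT'],
    transform=transforms.Compose([
        transforms.RandomCrop(32),   # 64x64 -> 32x32, as in config/auxiliary/dataset/cifar-10-imagenet.yaml
        transforms.ToTensor(),
    ]),
)
loader = DataLoader(aux, batch_size=8, shuffle=True)
imgs, idxs = next(iter(loader))
print(imgs.shape)  # torch.Size([8, 3, 32, 32]); idxs holds the sampled dataset indices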
/hopfield_boosting/data/datasets/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class GaussianNoise(torch.utils.data.Dataset): 4 | def __init__(self, len, img_size, mean, std): 5 | self.len = len 6 | self.img_size = list(img_size) 7 | self.mean = torch.tensor(mean).reshape(-1, 1, 1) 8 | self.std = torch.tensor(std).reshape(-1, 1, 1) 9 | 10 | def __len__(self): 11 | return self.len 12 | 13 | def __getitem__(self, idx): 14 | generator = torch.Generator() 15 | generator.manual_seed(idx) 16 | return torch.randn(self.img_size, generator=generator) * self.std + self.mean, torch.empty((1,)) 17 | 18 | 19 | class UniformNoise(torch.utils.data.Dataset): 20 | def __init__(self, len, img_size): 21 | self.len = len 22 | self.img_size = list(img_size) 23 | 24 | def __len__(self): 25 | return self.len 26 | 27 | def __getitem__(self, idx): 28 | generator = torch.Generator() 29 | generator.manual_seed(idx) 30 | return torch.rand(self.img_size, generator=generator), torch.empty((1,)) 31 | -------------------------------------------------------------------------------- /hopfield_boosting/data/datasets/random_noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | 5 | class RandomNoiseDataset(torch.utils.data.Dataset): 6 | def __init__(self, size=(32, 32), no_classes=10, transform=None, length=10000): 7 | self.no_classes = no_classes 8 | self.size = size 9 | self.transform = transform 10 | self.length = length 11 | 12 | def __getitem__(self, idx): 13 | if idx >= self.length: 14 | raise IndexError('Illegal Index')  # IndexError ends iteration over the dataset cleanly 15 | rand = torch.rand(1, *self.size) 16 | if self.transform: 17 | rand = self.transform(Image.fromarray(np.array(rand[0].detach()))) 18 | return rand, torch.randint(high=self.no_classes, size=[]) 19 | 20 | 21 | def __len__(self): 22 | return self.length 23 | --------------------------------------------------------------------------------
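Both noise datasets above seed a fresh torch.Generator with the sample index, so each item is generated on the fly yet fully deterministic, which keeps the OOD evaluation reproducible without storing any images. A quick sanity check, assuming only that the package is importable:

import torch
from hopfield_boosting.data.datasets.noise import GaussianNoise, UniformNoise

uniform = UniformNoise(len=10, img_size=(3, 32, 32))
x, _ = uniform[0]
x_again, _ = uniform[0]
assert torch.equal(x, x_again)  # same index, identical sample

gaussian = GaussianNoise(len=10, img_size=(3, 32, 32),
                         mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
y, _ = gaussian[0]
print(y.shape)  # torch.Size([3, 32, 32]), shifted and scaled per channel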
/hopfield_boosting/data/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | from torchvision.datasets.utils import check_integrity, download_url 7 | 8 | class SVHN(data.Dataset): 9 | url = "" 10 | filename = "" 11 | file_md5 = "" 12 | 13 | split_list = { 14 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 15 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 16 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 17 | "selected_test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 18 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 19 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"], 20 | 'train_and_extra': [ 21 | ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 22 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 23 | ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 24 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]]} 25 | 26 | def __init__(self, root, split='train', 27 | transform=None, target_transform=None, download=False): 28 | self.root = root 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | self.split = split # training set or test set or extra set 32 | 33 | if self.split not in self.split_list: 34 | raise ValueError('Wrong split entered! Please use split="train" ' 35 | 'or split="extra" or split="test" ' 36 | 'or split="train_and_extra" ') 37 | 38 | if self.split == "train_and_extra": 39 | self.url = self.split_list[split][0][0] 40 | self.filename = self.split_list[split][0][1] 41 | self.file_md5 = self.split_list[split][0][2] 42 | else: 43 | self.url = self.split_list[split][0] 44 | self.filename = self.split_list[split][1] 45 | self.file_md5 = self.split_list[split][2] 46 | 47 | # import here rather than at top of file because this is 48 | # an optional dependency for torchvision 49 | import scipy.io as sio 50 | 51 | # reading(loading) mat file as array 52 | loaded_mat = sio.loadmat(os.path.join(root, self.filename)) 53 | 54 | if self.split == "test": 55 | self.data = loaded_mat['X'] 56 | self.targets = loaded_mat['y'] 57 | # Note label 10 == 0 so modulo operator required 58 | self.targets = (self.targets % 10).squeeze() # convert to zero-based indexing 59 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 60 | else: 61 | self.data = loaded_mat['X'] 62 | self.targets = loaded_mat['y'] 63 | 64 | if self.split == "train_and_extra": 65 | extra_filename = self.split_list[split][1][1] 66 | loaded_mat = sio.loadmat(os.path.join(root, extra_filename)) 67 | self.data = np.concatenate([self.data, 68 | loaded_mat['X']], axis=3) 69 | self.targets = np.vstack((self.targets, 70 | loaded_mat['y'])) 71 | # Note label 10 == 0 so modulo operator required 72 | self.targets = (self.targets % 10).squeeze() # convert to zero-based indexing 73 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 74 | 75 | def __getitem__(self, index): 76 | if self.split == "test": 77 | img, target = self.data[index], self.targets[index] 78 | else: 79 | img, target = self.data[index], self.targets[index] 80 | 81 | # doing this so that it is consistent with all other datasets 82 | # to return a PIL Image 83 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 84 | 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | 91 | return img, target.astype('long') 92 | 93 | def __len__(self): 94 | if self.split == "test": 95 | return len(self.data) 96 | else: 97 | return len(self.data) 98 | 99 | def _check_integrity(self): 100 | root = self.root 101 | if self.split == "train_and_extra": 102 | md5 = self.split_list[self.split][0][2] 103 | fpath = os.path.join(root, self.filename) 104 | train_integrity = check_integrity(fpath, md5) 105 | extra_filename = self.split_list[self.split][1][1] 106 | md5 = self.split_list[self.split][1][2] 107 | fpath = os.path.join(root, extra_filename) 108 | return check_integrity(fpath, md5) and train_integrity 109 | else: 110 | md5 = self.split_list[self.split][2] 111 | fpath = os.path.join(root, self.filename) 112 | return check_integrity(fpath, md5) 113 | 114 | def download(self): 115 | if self.split == "train_and_extra": 116 | md5 = self.split_list[self.split][0][2] 117 | download_url(self.url, self.root, self.filename, md5) 118 | extra_filename = self.split_list[self.split][1][1] 119 | extra_url, md5 = self.split_list[self.split][1][0], self.split_list[self.split][1][2] 120 | download_url(extra_url, self.root, extra_filename, md5)  # fetch the extra split from its own URL 121 | else: 122 | md5 = self.split_list[self.split][2] 123 | download_url(self.url, self.root, self.filename, md5) 124 | -------------------------------------------------------------------------------- /hopfield_boosting/data/setup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
torch 3 | 4 | 5 | class DataSetup: 6 | def __init__(self, 7 | dataset=None, 8 | loader=None, 9 | batch_sampler=None, 10 | sampler=None, 11 | wrapper=None, 12 | split_sizes=None, 13 | splits=None): 14 | self.dataset = dataset 15 | 16 | if wrapper is not None: 17 | self.dataset = wrapper(self.dataset) 18 | 19 | self.splits = splits 20 | self.split_sizes = split_sizes 21 | 22 | if splits is not None: 23 | assert loader is None and batch_sampler is None and sampler is None 24 | assert np.allclose(sum(split_sizes.values()), 1.), 'Relative size of splits does not add up to 1.' 25 | 26 | split_lens = {key: int(value * len(dataset)) for key, value in list(split_sizes.items())[:-1]} 27 | split_lens[list(split_sizes.keys())[-1]] = len(dataset) - sum(split_lens.values()) 28 | 29 | generator = torch.Generator() 30 | generator.manual_seed(42) 31 | subsets = {key: subset for key, subset in zip(split_sizes.keys(), torch.utils.data.random_split(dataset, split_lens.values(), generator=generator))} 32 | 33 | for split_name in self.splits.keys(): 34 | subset = subsets[split_name] 35 | self.__dict__[split_name] = self.splits[split_name](dataset=subset) 36 | 37 | else: 38 | if sampler: 39 | self.sampler = sampler(self.dataset) 40 | else: 41 | self.sampler = None 42 | if batch_sampler: 43 | self.batch_sampler = batch_sampler(self.dataset) 44 | else: 45 | self.batch_sampler = None 46 | self.collate_fn = self.batch_sampler.collate_fn if self.batch_sampler else None 47 | self.loader = loader(dataset=self.dataset, sampler=self.sampler, batch_sampler=self.batch_sampler, collate_fn=self.collate_fn) 48 | -------------------------------------------------------------------------------- /hopfield_boosting/data/softmax_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SoftmaxBorderSampler: 5 | def __init__(self, num_batches, num_tested_samples, beta, out_batch_size, replacement, device): 6 | self.num_batches = num_batches 7 | self.num_tested_samples = num_tested_samples 8 | self.beta = beta 9 | self.out_batch_size = out_batch_size 10 | self.num_samples = self.num_batches * self.out_batch_size 11 | self.replacement = replacement 12 | self.device = device 13 | 14 | def sample_border_points(self, energy_fn, data): 15 | 16 | energies = [] 17 | idxs = [] 18 | 19 | num_tested_samples = 0 20 | 21 | with torch.no_grad(): 22 | 23 | for x, idx in data: 24 | if num_tested_samples >= self.num_tested_samples: 25 | break 26 | energy = energy_fn(x.to(self.device)) 27 | 28 | energies.append(energy.cpu()) 29 | idxs.append(idx) 30 | 31 | num_tested_samples += len(x) 32 | 33 | energies = torch.concat(energies, dim=0) 34 | idxs = torch.concat(idxs, dim=0) 35 | 36 | p = torch.softmax(-self.beta*energies, dim=-1) 37 | samples = torch.multinomial(p, self.num_samples, replacement=self.replacement) 38 | sample_idxs = idxs[samples] 39 | 40 | return torch.utils.data.DataLoader(torch.utils.data.Subset(data.dataset, sample_idxs), shuffle=True, batch_size=self.out_batch_size) 41 | -------------------------------------------------------------------------------- /hopfield_boosting/download_data.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | import os 4 | 5 | from dotenv import load_dotenv 6 | 7 | def select_svhn_data(source_path='test_32x32.mat', dest_path='selected_test_32x32.mat'): 8 | # Code adopted from POEM: https://github.com/deeplearning-wisc/poem/blob/main/select_svhn_data.py 9 
| import scipy.io as sio 10 | import os 11 | import numpy as np 12 | 13 | loaded_mat = sio.loadmat(source_path) 14 | 15 | data = loaded_mat['X'] 16 | targets = loaded_mat['y'] 17 | 18 | data = np.transpose(data, (3, 0, 1, 2)) 19 | 20 | selected_data = [] 21 | selected_targets = [] 22 | count = np.zeros(11) 23 | 24 | for i, y in enumerate(targets): 25 | if count[y[0]] < 1000: 26 | selected_data.append(data[i]) 27 | selected_targets.append(y) 28 | count[y[0]] += 1 29 | 30 | selected_data = np.array(selected_data) 31 | selected_targets = np.array(selected_targets) 32 | 33 | selected_data = np.transpose(selected_data, (1, 2, 3, 0)) 34 | 35 | save_mat = {'X': selected_data, 'y': selected_targets} 36 | 37 | sio.savemat(dest_path, save_mat) 38 | 39 | 40 | def download_dataset(dataset_url, download_path, suppress_stderr=False): 41 | if suppress_stderr: 42 | subprocess.check_call(['wget', dataset_url, '-P', download_path], stderr=subprocess.DEVNULL) 43 | else: 44 | subprocess.check_call(['wget', dataset_url, '-P', download_path]) 45 | return Path(download_path) / dataset_url.split('/')[-1] 46 | 47 | 48 | def extract_tar(tar_path, target_dir, suppress_output=False): 49 | os.makedirs(target_dir, exist_ok=True) 50 | if str(tar_path).endswith('.gz'): 51 | options = '-xvzf' 52 | else: 53 | options = '-xf' 54 | if suppress_output: 55 | subprocess.check_call(['tar', options, tar_path, '-C', target_dir], stdout=subprocess.DEVNULL) 56 | else: 57 | subprocess.check_call(['tar', options, tar_path, '-C', target_dir]) 58 | 59 | def remove_file(path): 60 | subprocess.check_call(['rm', path]) 61 | 62 | 63 | def prepare_all_datasets(download_path='downloaded_datasets/', suppress_output=False): 64 | os.makedirs(download_path, exist_ok=True) 65 | 66 | tar_dataset_urls = { 67 | 'TEXTURES': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz', 68 | 'PLACES': 'http://data.csail.mit.edu/places/places365/test_256.tar', 69 | 'LSUN_CROP': 'https://www.dropbox.com/s/fhtsw1m3qxlwj6h/LSUN.tar.gz', 70 | 'LSUN_RESIZE': 'https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz', 71 | 'ISUN': 'https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz' 72 | } 73 | 74 | for dataset_name, dataset_url in tar_dataset_urls.items(): 75 | print(f'Downloading {dataset_name}') 76 | downloaded_path = download_dataset(dataset_url, download_path, suppress_stderr=suppress_output) 77 | if dataset_name == 'PLACES': 78 | extract_path = Path(download_path) / 'places365' 79 | os.makedirs(extract_path) 80 | else: 81 | extract_path = download_path 82 | print(f'Extracting {dataset_name}') 83 | extract_tar(downloaded_path, extract_path, suppress_output=suppress_output) 84 | remove_file(downloaded_path) 85 | 86 | 87 | svhn_url = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat' 88 | 89 | print('Downloading SVHN') 90 | svhn_test_file = download_dataset(svhn_url, Path(download_path) / 'svhn') 91 | print('Selecting SVHN') 92 | select_svhn_data(svhn_test_file, Path(download_path) / 'svhn' / 'selected_test_32x32.mat') 93 | 94 | 95 | if __name__ == '__main__': 96 | load_dotenv() 97 | download_path = os.getenv('DOWNLOADED_PATH', 'downloaded_datasets') 98 | prepare_all_datasets(download_path) 99 | --------------------------------------------------------------------------------
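download_data.py is also usable as a library; the sketch below (hypothetical target directory, wget assumed to be on PATH) fetches only the SVHN test split and builds the class-balanced selected_test_32x32.mat that the SVHN dataset class expects: select_svhn_data keeps at most 1,000 images per digit class, 10,000 in total.

from pathlib import Path
from hopfield_boosting.download_data import download_dataset, select_svhn_data

svhn_dir = Path('downloaded_datasets') / 'svhn'
mat = download_dataset('http://ufldl.stanford.edu/housenumbers/test_32x32.mat', svhn_dir)
select_svhn_data(mat, svhn_dir / 'selected_test_32x32.mat')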
/hopfield_boosting/early_stopping.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class EarlyStopping(ABC): 4 | @abstractmethod 5 | def should_stop(self, log_dict) -> bool: 6 | pass 7 | 8 | 9 | class LowAccuracyEarlyStopping(EarlyStopping): 10 | def __init__(self, accuracy, epoch, key='classifier/acc_val', epoch_key='general/epoch'): 11 | self.accuracy = accuracy 12 | self.epoch = epoch 13 | self.key = key 14 | self.epoch_key = epoch_key 15 | 16 | def should_stop(self, log_dict): 17 | if self.epoch == log_dict[self.epoch_key]: 18 | return log_dict[self.key] < self.accuracy 19 | else: 20 | return False 21 | -------------------------------------------------------------------------------- /hopfield_boosting/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class CNNOODWrapper(nn.Module): 5 | def __init__(self, cnn, preprocess) -> None: 6 | super(CNNOODWrapper, self).__init__() 7 | self.preprocess = preprocess 8 | self.module = cnn 9 | 10 | def forward(self, x): 11 | assert x.shape[-2:] == (32, 32) # ensure CIFAR-10 size 12 | with torch.no_grad(): 13 | x = self.preprocess(x) 14 | x = self.module(x) 15 | return x 16 | -------------------------------------------------------------------------------- /hopfield_boosting/encoder/densenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | from functools import cached_property 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | def __init__(self, in_planes, out_planes, dropRate=0.0, batchnorm_track_running_stats=True): 18 | super(BasicBlock, self).__init__() 19 | self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=batchnorm_track_running_stats) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 22 | padding=1, bias=False) 23 | self.droprate = dropRate 24 | def forward(self, x): 25 | out = self.conv1(self.relu(self.bn1(x))) 26 | if self.droprate > 0: 27 | out = F.dropout(out, p=self.droprate, training=self.training) 28 | return torch.cat([x, out], 1) 29 | 30 | 31 | class BottleneckBlock(nn.Module): 32 | def __init__(self, in_planes, out_planes, dropRate=0.0, batchnorm_track_running_stats=True): 33 | super(BottleneckBlock, self).__init__() 34 | inter_planes = out_planes * 4 35 | self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=batchnorm_track_running_stats) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, 38 | padding=0, bias=False) 39 | self.bn2 = nn.BatchNorm2d(inter_planes, track_running_stats=batchnorm_track_running_stats) 40 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, 41 | padding=1, bias=False) 42 | self.droprate = dropRate 43 | def forward(self, x): 44 | out = self.conv1(self.relu(self.bn1(x))) 45 | if self.droprate > 0: 46 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 47 | out = self.conv2(self.relu(self.bn2(out))) 48 | if self.droprate > 0: 49 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 50 | return torch.cat([x, out], 1) 51 | 52 | 53 | class TransitionBlock(nn.Module): 54 | def __init__(self, in_planes, out_planes, dropRate=0.0, batchnorm_track_running_stats=True): 55 | super(TransitionBlock, self).__init__() 56 | self.bn1 = nn.BatchNorm2d(in_planes, 
track_running_stats=batchnorm_track_running_stats) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 59 | padding=0, bias=False) 60 | self.droprate = dropRate 61 | def forward(self, x): 62 | out = self.conv1(self.relu(self.bn1(x))) 63 | if self.droprate > 0: 64 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 65 | return F.avg_pool2d(out, 2) 66 | 67 | 68 | class DenseBlock(nn.Module): 69 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0, batchnorm_track_running_stats=True): 70 | super(DenseBlock, self).__init__() 71 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate, batchnorm_track_running_stats) 72 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate, batchnorm_track_running_stats): 73 | layers = [] 74 | for i in range(nb_layers): 75 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate, batchnorm_track_running_stats=batchnorm_track_running_stats)) 76 | return nn.Sequential(*layers) 77 | def forward(self, x): 78 | return self.layer(x) 79 | 80 | 81 | class DenseNet3(nn.Module): 82 | def __init__(self, depth, growth_rate=12, 83 | reduction=0.5, bottleneck=True, dropRate=0.0, normalizer = None, 84 | batchnorm_track_running_stats=True): 85 | super(DenseNet3, self).__init__() 86 | 87 | in_planes = 2 * growth_rate 88 | n = int((depth - 4) / 3) 89 | if bottleneck: 90 | n = int(n/2) 91 | block = BottleneckBlock 92 | else: 93 | block = BasicBlock 94 | # 1st conv before any dense block 95 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, 96 | padding=1, bias=False) 97 | # 1st block 98 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate, batchnorm_track_running_stats=batchnorm_track_running_stats) 99 | in_planes = int(in_planes+n*growth_rate) 100 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate, batchnorm_track_running_stats=batchnorm_track_running_stats) 101 | in_planes = int(math.floor(in_planes*reduction)) 102 | # 2nd block 103 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate, batchnorm_track_running_stats=batchnorm_track_running_stats) 104 | in_planes = int(in_planes+n*growth_rate) 105 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate, batchnorm_track_running_stats=batchnorm_track_running_stats) 106 | in_planes = int(math.floor(in_planes*reduction)) 107 | # 3rd block 108 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate, batchnorm_track_running_stats=batchnorm_track_running_stats) 109 | in_planes = int(in_planes+n*growth_rate) 110 | # global average pooling and classifier 111 | self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=batchnorm_track_running_stats) 112 | self.relu = nn.ReLU(inplace=True) 113 | 114 | self.blocks = [self.block1, self.block2, self.block3] 115 | 116 | self.in_planes = in_planes 117 | self.normalizer = normalizer 118 | self._repr_dim = None 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels; m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | m.bias.data.zero_() 128 | 129 | 130 | @cached_property 131 | def repr_dim(self): 132 | return self.in_planes 133 | 134 | def forward(self, x): 135 | if self.normalizer is not None: 136 | x = self.normalizer(x) 137 | 138 | out = self.conv1(x) 139 | out = self.trans1(self.block1(out)) 140 | out = self.trans2(self.block2(out)) 141 | out = self.block3(out) 142 | out = self.relu(self.bn1(out)) 143 | out = F.avg_pool2d(out, 8) 144 | out = out.view(-1, self.in_planes) 145 | return out 146 | -------------------------------------------------------------------------------- /hopfield_boosting/encoder/identity.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class IdentityEncoder(nn.Module): 5 | def forward(self, x): 6 | return x 7 | -------------------------------------------------------------------------------- /hopfield_boosting/encoder/mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, no_layers, in_dim, out_dim, hidden_dim): 6 | super(MLP, self).__init__() 7 | self.no_layers = no_layers 8 | assert self.no_layers >= 2, 'MLP requires at least two layers' 9 | 10 | self.in_layer = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU()) 11 | self.out_layer = nn.Linear(hidden_dim, out_dim) 12 | self.hidden = nn.Sequential(*(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) for _ in range(no_layers - 2))) 13 | 14 | self.model = nn.Sequential(self.in_layer, self.hidden, self.out_layer) 15 | 16 | def forward(self, x): 17 | return self.model(x) 18 | -------------------------------------------------------------------------------- /hopfield_boosting/encoder/resnet_18_32x32.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | # Implementation from OpenOOD: https://github.com/Jingkang50/OpenOOD/blob/main/openood/networks/resnet18_32x32.py 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(in_planes, 13 | planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=1, 17 | bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, 20 | planes, 21 | kernel_size=3, 22 | stride=1, 23 | padding=1, 24 | bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion * planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, 31 | self.expansion * planes, 32 | kernel_size=1, 33 | stride=stride, 34 | bias=False), nn.BatchNorm2d(self.expansion * planes)) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(Bottleneck, self).__init__() 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, 52 | planes, 53 | kernel_size=3, 54 | stride=stride, 55 | padding=1, 56 | bias=False) 57 | 
self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, 59 | self.expansion * planes, 60 | kernel_size=1, 61 | bias=False) 62 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 63 | 64 | self.shortcut = nn.Sequential() 65 | if stride != 1 or in_planes != self.expansion * planes: 66 | self.shortcut = nn.Sequential( 67 | nn.Conv2d(in_planes, 68 | self.expansion * planes, 69 | kernel_size=1, 70 | stride=stride, 71 | bias=False), nn.BatchNorm2d(self.expansion * planes)) 72 | 73 | def forward(self, x): 74 | out = F.relu(self.bn1(self.conv1(x))) 75 | out = F.relu(self.bn2(self.conv2(out))) 76 | out = self.bn3(self.conv3(out)) 77 | out += self.shortcut(x) 78 | out = F.relu(out) 79 | return out 80 | 81 | 82 | class ResNet18_32x32(nn.Module): 83 | def __init__(self, block=BasicBlock, num_blocks=None, num_classes=10): 84 | super(ResNet18_32x32, self).__init__() 85 | if num_blocks is None: 86 | num_blocks = [2, 2, 2, 2] 87 | self.in_planes = 64 88 | 89 | self.conv1 = nn.Conv2d(3, 90 | 64, 91 | kernel_size=3, 92 | stride=1, 93 | padding=1, 94 | bias=False) 95 | self.bn1 = nn.BatchNorm2d(64) 96 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 97 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 98 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 99 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 100 | # self.avgpool = nn.AvgPool2d(4) 101 | self.avgpool = nn.AdaptiveAvgPool2d(1) 102 | self.fc = nn.Linear(512 * block.expansion, num_classes) 103 | self.feature_size = 512 * block.expansion 104 | 105 | def _make_layer(self, block, planes, num_blocks, stride): 106 | strides = [stride] + [1] * (num_blocks - 1) 107 | layers = [] 108 | for stride in strides: 109 | layers.append(block(self.in_planes, planes, stride)) 110 | self.in_planes = planes * block.expansion 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | feature1 = F.relu(self.bn1(self.conv1(x))) 115 | feature2 = self.layer1(feature1) 116 | feature3 = self.layer2(feature2) 117 | feature4 = self.layer3(feature3) 118 | feature5 = self.layer4(feature4) 119 | feature5 = self.avgpool(feature5) 120 | feature = feature5.view(feature5.size(0), -1) 121 | #logits_cls = self.fc(feature) 122 | #feature_list = [feature1, feature2, feature3, feature4, feature5] 123 | return feature 124 | 125 | def forward_threshold(self, x, threshold): 126 | feature1 = F.relu(self.bn1(self.conv1(x))) 127 | feature2 = self.layer1(feature1) 128 | feature3 = self.layer2(feature2) 129 | feature4 = self.layer3(feature3) 130 | feature5 = self.layer4(feature4) 131 | feature5 = self.avgpool(feature5) 132 | feature = feature5.clip(max=threshold) 133 | feature = feature.view(feature.size(0), -1) 134 | logits_cls = self.fc(feature) 135 | 136 | return logits_cls 137 | 138 | def intermediate_forward(self, x, layer_index): 139 | out = F.relu(self.bn1(self.conv1(x))) 140 | 141 | out = self.layer1(out) 142 | if layer_index == 1: 143 | return out 144 | 145 | out = self.layer2(out) 146 | if layer_index == 2: 147 | return out 148 | 149 | out = self.layer3(out) 150 | if layer_index == 3: 151 | return out 152 | 153 | out = self.layer4(out) 154 | if layer_index == 4: 155 | return out 156 | 157 | raise ValueError 158 | 159 | def get_fc(self): 160 | fc = self.fc 161 | return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy() 162 | 163 | def get_fc_layer(self): 164 | return self.fc -------------------------------------------------------------------------------- 
/hopfield_boosting/energy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from hopfield_boosting.util import logmeanexp
5 | 
6 | class Energy(nn.Module):
7 |     def __init__(self, a, b, beta_a, beta_b, normalize=True):
8 |         super(Energy, self).__init__()
9 |         if normalize:
10 |             a = nn.functional.normalize(a, dim=-1)
11 |             b = nn.functional.normalize(b, dim=-1)
12 |         self.register_buffer('a', a)
13 |         self.register_buffer('b', b)
14 |         self.normalize = normalize
15 |         self.beta_a = beta_a
16 |         self.beta_b = beta_b
17 | 
18 |     def forward(self, x, mask_diagonale=False):
19 |         if self.normalize:
20 |             x = nn.functional.normalize(x, dim=-1)
21 |         attn = torch.einsum('...if,...jf->...ij', x, torch.concat([self.a, self.b], dim=0))
22 |         if mask_diagonale:
23 |             assert attn.shape[-1] == attn.shape[-2], 'diagonal masking requires queries and stored patterns to coincide'
24 |             attn[..., range(attn.size(-1)), range(attn.size(-1))] = -torch.inf
25 | 
26 |         a_energy = -logmeanexp(self.beta_a, attn[...,:len(self.a)], dim=-1)
27 |         b_energy = -logmeanexp(self.beta_b, attn[...,len(self.a):], dim=-1)
28 | 
29 |         return a_energy, b_energy, attn
30 | 
31 | 
32 | class BorderEnergy(Energy):
33 |     def __init__(self, a, b, beta_a, beta_b, beta_border, normalize=True, mask_diagonale=True):
34 |         super(BorderEnergy, self).__init__(a, b, beta_a, beta_b, normalize=normalize)
35 |         self.beta_border = beta_border
36 |         self.mask_diagonale = mask_diagonale
37 | 
38 |     def forward(self, x, return_dict=False):
39 |         a_energy, b_energy, attn = super().forward(x, self.mask_diagonale)
40 |         union_energy = -1/self.beta_border*torch.logaddexp(-self.beta_border*a_energy, -self.beta_border*b_energy) + 1/self.beta_border*torch.log(torch.tensor(2))  # soft minimum: -1/beta * log(mean(exp(-beta*E_a), exp(-beta*E_b)))
41 |         border_energy = a_energy + b_energy - 2*union_energy  # approximately |a_energy - b_energy|: small near the ID/AUX decision border
42 |         if return_dict:
43 |             return border_energy, {
44 |                 'attn': attn,
45 |                 'a_energy': a_energy,
46 |                 'b_energy': b_energy,
47 |                 'union_energy': union_energy,
48 |             }
49 |         else:
50 |             return border_energy
51 | 
52 | 
53 | class OneSidedEnergy(Energy):
54 |     def __init__(self, a, b, beta_a, beta_b, normalize=True):
55 |         super(OneSidedEnergy, self).__init__(a, b, beta_a, beta_b, normalize=normalize)
56 | 
57 |     def forward(self, x, mask_diagonale=False, return_dict=False):
58 |         a_energy, b_energy, attn = super().forward(x, mask_diagonale)
59 |         one_sided_energy = a_energy - b_energy
60 |         if return_dict:
61 |             return one_sided_energy, {
62 |                 'attn': attn,
63 |                 'a_energy': a_energy,
64 |                 'b_energy': b_energy,
65 |             }
66 |         else:
67 |             return one_sided_energy
68 | 
69 | 
70 | class HopfieldEnergy(Energy):
71 |     def __init__(self, a, beta_a, normalize=True):
72 |         super(HopfieldEnergy, self).__init__(a, a, beta_a, beta_a, normalize=normalize)
73 | 
74 |     def forward(self, x, mask_diagonale=False, return_dict=False):
75 |         a_energy, _, attn = super().forward(x, mask_diagonale)
76 |         maxnorm = torch.max(torch.linalg.norm(self.a, dim=-1))
77 |         if self.normalize:
78 |             assert torch.allclose(maxnorm, torch.tensor(1.)), 'stored patterns should be unit norm when normalize=True'
79 |             # normalize the query to match the stored patterns
80 |             x = torch.nn.functional.normalize(x, dim=-1)
81 |         energy = a_energy + 1/2 * torch.einsum('...f,...f->...', x, x) + 1/2 * maxnorm**2
82 |         if return_dict:
83 |             return energy, {
84 |                 'attn': attn,
85 |                 'a_energy': a_energy,
86 |             }
87 |         else:
88 |             return energy
89 | 
--------------------------------------------------------------------------------
/hopfield_boosting/logger/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from hopfield_boosting.logger.base import Logger
3 | from 
hopfield_boosting.logger.console import ConsoleLogger
4 | from hopfield_boosting.logger.file import FileLogger
5 | from hopfield_boosting.logger.wandb import WandbLogger
6 | 
7 | __all__ = ['Logger', 'FileLogger', 'WandbLogger', 'ConsoleLogger']
8 | 
--------------------------------------------------------------------------------
/hopfield_boosting/logger/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | 
4 | class Logger(ABC):
5 |     @abstractmethod
6 |     def log(self, log_dict: dict, epoch=None):
7 |         pass
8 | 
--------------------------------------------------------------------------------
/hopfield_boosting/logger/console.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from hopfield_boosting.logger.base import Logger
4 | 
5 | 
6 | class ConsoleLogger(Logger):
7 |     def log(self, log_dict: dict, epoch=None):
8 |         print(json.dumps(log_dict))
9 | 
--------------------------------------------------------------------------------
/hopfield_boosting/logger/file.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | 
4 | import wandb
5 | 
6 | from hopfield_boosting.logger.base import Logger
7 | 
8 | 
9 | class FileLogger(Logger):
10 |     def __init__(self, path: str, mode='w'):  # note: with mode 'w', every log call overwrites the file; use 'a' to append
11 |         os.makedirs(path, exist_ok=True)
12 |         self.file_path = os.path.join(path, f'{wandb.run.id}.txt')
13 |         assert mode in ('w', 'a'), 'Please choose either mode "w" or "a"'
14 |         self.mode = mode
15 | 
16 |     def log(self, log_dict: dict, epoch=None):
17 |         with open(self.file_path, self.mode) as f:
18 |             f.write(f'{json.dumps(log_dict, indent=2)}\n')
19 | 
--------------------------------------------------------------------------------
/hopfield_boosting/logger/wandb.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | from hopfield_boosting.logger.base import Logger
3 | 
4 | 
5 | class WandbLogger(Logger):
6 |     def __init__(self, prefix='ood/'):
7 |         self.prefix = prefix
8 | 
9 |     def log(self, log_dict: dict, epoch=None):
10 |         if self.prefix is not None:
11 |             log_dict = {f'{self.prefix}{key}': value for key, value in log_dict.items()}
12 |         log_dict['general/epoch'] = epoch
13 |         wandb.log(log_dict)
14 | 
--------------------------------------------------------------------------------
/hopfield_boosting/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from pathlib import Path
4 | from contextlib import contextmanager
5 | 
6 | import wandb
7 | import hydra
8 | import numpy as np
9 | import torch
10 | from torch import nn
11 | from torch.optim.lr_scheduler import CosineAnnealingLR
12 | from tqdm import tqdm
13 | from hydra.utils import instantiate
14 | from omegaconf import DictConfig, OmegaConf
15 | from torch.optim.lr_scheduler import ReduceLROnPlateau
16 | from dotenv import load_dotenv
17 | 
18 | import hopfield_boosting
19 | from hopfield_boosting.util import infer_loader
20 | 
21 | 
22 | @contextmanager
23 | def get_patterns(model, in_loader, aux_loader, device):
24 |     with infer_loader(in_loader, model=model, device=device) as id, \
25 |          infer_loader(aux_loader, model=model, device=device, max_samples=len(in_loader)*in_loader.batch_size) as aux:
26 |         yield id, aux
27 | 
28 | 
29 | def evaluate_id_classifier(classifier, val_loader, criterion, device, epoch):
30 |     logits_val = []
31 |     ys_val = []
32 |     for x_val, y_val in 
val_loader:
33 |         x_val, y_val = x_val.to(device), y_val.to(device)
34 | 
35 |         logit_val = classifier(x_val)
36 | 
37 |         logits_val.append(logit_val)
38 |         ys_val.append(y_val)
39 | 
40 |     logits_val = torch.concat(logits_val, dim=0)
41 |     ys_val = torch.concat(ys_val, dim=0)
42 | 
43 |     loss = criterion(logits_val, ys_val)
44 |     acc = torch.mean((torch.argmax(logits_val, dim=-1) == ys_val).to(torch.float))
45 | 
46 |     d = {'classifier/loss_val': loss,
47 |          'classifier/acc_val': acc,
48 |          'general/epoch': epoch}
49 |     return d
50 | 
51 | 
52 | def evaluate(id, aux, trainer, data, device, ood_evaluators, epoch):
53 |     with torch.no_grad():
54 |         log_dict = evaluate_id_classifier(nn.Sequential(trainer.model, trainer.classifier), data.no_aug.val.loader, trainer.criterion, device, epoch=epoch)
55 |         if epoch % 25 == 24:  # run the full OOD evaluation every 25 epochs (also triggers for the initial evaluation at epoch -1)
56 |             ood_tester = trainer.ood_tester(emb_in=id, emb_out=aux)
57 |             for ood_evaluator in ood_evaluators.values():
58 |                 ood_evaluator.evaluate(
59 |                     ood_tester,
60 |                     epoch=epoch
61 |                 )
62 |     return log_dict
63 | 
64 | 
65 | def select_aux_data(id, aux, trainer, raw_aux_loader, sampler):
66 |     if sampler:
67 |         border_energy = trainer.energy_fn(a=id, b=aux)
68 |         border_energy = nn.Sequential(trainer.model, trainer.projection_head, border_energy)
69 |         aux_loader_selected = sampler.sample_border_points(energy_fn=border_energy, data=raw_aux_loader)
70 |     else:
71 |         aux_loader_selected = raw_aux_loader
72 | 
73 |     return aux_loader_selected
74 | 
75 | 
76 | def evaluate_and_select_aux_data(trainer, data, aux_loader, device, ood_evaluators, sampler, early_stopping, run_dir, epoch):
77 |     trainer.eval()
78 |     projection_model = nn.Sequential(trainer.model, trainer.projection_head)
79 |     with get_patterns(projection_model, data.aug.train.loader, aux_loader, device) as (id, aux):
80 |         log_dict = evaluate(
81 |             id=id,
82 |             aux=aux,
83 |             trainer=trainer,
84 |             data=data,
85 |             device=device,
86 |             ood_evaluators=ood_evaluators,
87 |             epoch=epoch
88 |         )
89 |         aux_loader_selected = select_aux_data(
90 |             id=id,
91 |             aux=aux,
92 |             trainer=trainer,
93 |             raw_aux_loader=aux_loader,
94 |             sampler=sampler
95 |         )
96 | 
97 |         torch.save(trainer.model.state_dict(), f'{run_dir}/model.ckpt')
98 |         torch.save(trainer.classifier.state_dict(), f'{run_dir}/classifier.ckpt')
99 |         torch.save(id, f'{run_dir}/id.ckpt')
100 |         torch.save(aux, f'{run_dir}/ood.ckpt')
101 |         try:
102 |             torch.save(trainer.projection_head.state_dict(), f'{run_dir}/projection_head.ckpt')
103 |         except Exception:
104 |             print('Projection head could not be saved!')
105 | 
106 |     wandb.log(log_dict)
107 |     if early_stopping.should_stop(log_dict):
108 |         return  # early stopping: no new auxiliary loader is returned
109 |     trainer.train()
110 |     return aux_loader_selected
111 | 
112 | 
113 | def torch_deterministic():
114 |     torch.use_deterministic_algorithms(True, warn_only=True)
115 |     torch.backends.cudnn.deterministic = True
116 |     os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
117 | 
118 | 
119 | @hydra.main(config_path='config', config_name='resnet-18-cifar-10-aux-from-scratch', version_base=None)
120 | def main(config: DictConfig):
121 | 
122 |     print(f'Location of hopfield_boosting library: {hopfield_boosting.__file__}')
123 | 
124 |     if config.use_seed:
125 |         torch.cuda.random.manual_seed(0)
126 |         torch.random.manual_seed(0)
127 |         np.random.seed(0)
128 |         random.seed(0)
129 | 
130 |     if config.force_deterministic:
131 |         torch_deterministic()
132 | 
133 |     project_name = config.project_name
134 |     load_dotenv()  # does not override variables from the environment
135 | 
136 |     wandb.init(project=project_name, config={'hydra': OmegaConf.to_container(
137 |         config, 
resolve=True, throw_on_missing=True)}, anonymous='allow')
138 |     wandb.define_metric('epoch')
139 | 
140 |     print(OmegaConf.to_yaml(config))
141 | 
142 |     paths = instantiate(config.paths)
143 |     run_dir = Path(paths.project_root) / 'runs' / project_name / wandb.run.id
144 |     os.makedirs(run_dir, exist_ok=True)
145 |     OmegaConf.save(config, f'{run_dir}/cfg.yaml', resolve=True)
146 | 
147 |     device = config.device
148 |     trainer = instantiate(config.trainer).to(device)
149 |     model = trainer.model
150 |     classifier = trainer.classifier
151 |     projection_head = trainer.projection_head
152 |     opt = trainer.optim
153 |     early_stopping = instantiate(config.early_stopping)
154 | 
155 |     beta = config.beta
156 | 
157 |     data = instantiate(config.data)
158 |     aux_data = instantiate(config.auxiliary)
159 |     aux_loader = aux_data.loader
160 | 
161 |     sampler = instantiate(config.sampler, num_batches=len(data.aug.train.loader), beta=beta)
162 |     ood_evaluators = instantiate(config.ood_eval)
163 | 
164 |     if config.get('scheduler'):
165 |         if 'T_max' in config.scheduler.keys():
166 |             scheduler = instantiate(config.scheduler, optimizer=opt, T_max=len(data.aug.train.loader) * config.no_epochs)
167 |         else:
168 |             scheduler = instantiate(config.scheduler, optimizer=opt)
169 |     else:
170 |         scheduler = None
171 | 
172 |     wandb.watch([model, classifier, projection_head], log_freq=100, log='all')
173 | 
174 |     print('Initial Evaluation')
175 |     aux_loader_selected = evaluate_and_select_aux_data(
176 |         trainer=trainer,
177 |         data=data,
178 |         aux_loader=aux_loader,
179 |         device=device,
180 |         ood_evaluators=ood_evaluators,
181 |         sampler=sampler,
182 |         early_stopping=early_stopping,
183 |         run_dir=run_dir,
184 |         epoch=-1
185 |     )
186 | 
187 |     for epoch in tqdm(range(config.no_epochs)):
188 | 
189 |         in_losses = []
190 |         for i, ((xs_train, ys_train), (xs_ood, _)) in enumerate(zip(data.aug.train.loader, aux_loader_selected)):
191 |             trainer.train()
192 | 
193 |             log_dict = {
194 |                 'general/beta': beta,
195 |                 'general/lr': opt.param_groups[0]['lr'],
196 |                 'general/epoch': epoch
197 |             }
198 | 
199 |             xs_train, ys_train = xs_train.to(device), ys_train.to(device)
200 |             xs_ood = xs_ood.to(device)
201 | 
202 |             loss, info_dict = trainer.step(xs_train, ys_train, xs_ood)
203 |             in_losses.append(info_dict['classifier/in_loss'])
204 | 
205 |             log_dict.update(info_dict)
206 | 
207 |             wandb.log(log_dict)
208 | 
209 |             if scheduler and isinstance(scheduler, CosineAnnealingLR):
210 |                 scheduler.step()  # cosine annealing is stepped once per batch (T_max is set accordingly above)
211 | 
212 |         # without recursive sampling, reset to the full auxiliary loader;
213 |         # with it, keep sampling from the subset selected in the previous round
214 |         if not config.recursive_sampling:
215 |             aux_loader_selected = aux_loader
216 | 
217 |         aux_loader_selected = evaluate_and_select_aux_data(
218 |             trainer=trainer,
219 |             data=data,
220 |             aux_loader=aux_loader_selected,
221 |             device=device,
222 |             ood_evaluators=ood_evaluators,
223 |             sampler=sampler,
224 |             early_stopping=early_stopping,
225 |             run_dir=run_dir,
226 |             epoch=epoch
227 |         )
228 | 
229 |         if scheduler:
230 |             if isinstance(scheduler, ReduceLROnPlateau):
231 |                 scheduler.step(torch.mean(torch.tensor(in_losses)))
232 |             elif not isinstance(scheduler, CosineAnnealingLR):
233 |                 scheduler.step()
234 | 
235 |     wandb.finish(0)
236 | 
237 | if __name__ == '__main__':
238 |     main()
239 | 
--------------------------------------------------------------------------------
/hopfield_boosting/ood/__init__.py:
--------------------------------------------------------------------------------
1 | from hopfield_boosting.ood.metrics import FPR95OODMetric, AUPRSOODMetric, AUROCOODMetric
2 | from hopfield_boosting.ood.evaluator import OODEvaluator
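3 | 
4 | # Score convention used throughout this package: higher scores indicate
5 | # in-distribution samples (the positive class for AUROC, AUPR, and FPR95),
6 | # lower scores indicate OOD samples.
7 | #
8 | # Minimal usage sketch (illustrative only; `in_data` and `out_data` are
9 | # placeholders for objects exposing a `.loader` attribute, and `score_fn`
10 | # maps a batch of inputs to per-sample scores):
11 | #
12 | #   evaluator = OODEvaluator(in_dataset=in_data, out_datasets={'svhn': out_data},
13 | #                            metrics={'auroc': AUROCOODMetric()}, logger=None)
14 | #   results = evaluator.evaluate(score_fn)  # -> {'auroc.svhn': ...}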
--------------------------------------------------------------------------------
/hopfield_boosting/ood/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from torch.utils.data import DataLoader
4 | 
5 | 
6 | class OODEvaluator:
7 |     def __init__(self, in_dataset, out_datasets, metrics, logger, device=None):
8 |         self.in_dataset = in_dataset.loader
9 |         self.out_datasets = {key: ds.loader for key, ds in out_datasets.items()}
10 |         self.metrics = metrics
11 |         self.logger = logger
12 |         self.device = device
13 | 
14 |     def evaluate(self, ood_tester, epoch=None, prefix=None):
15 |         if prefix is not None:
16 |             prefix = f'{prefix}_'
17 |         else:
18 |             prefix = ''
19 |         results = {}
20 |         with torch.no_grad():
21 |             in_scores = self.compute_scores_loader(self.in_dataset, ood_tester)
22 |             for out_dataset_name, out_dataset in self.out_datasets.items():
23 |                 out_scores = self.compute_scores_loader(out_dataset, ood_tester)
24 |                 for metric_name, metric in self.metrics.items():
25 |                     metric_result = metric(in_scores, out_scores)
26 |                     results[f'{prefix}{metric_name}.{out_dataset_name}'] = metric_result
27 | 
28 |         if self.logger:
29 |             self.logger.log(results, epoch=epoch)
30 | 
31 |         return results
32 | 
33 |     def compute_scores_loader(self, loader: DataLoader, score_fn, num_samples=10000) -> torch.Tensor:
34 |         # scores at most ~num_samples inputs (rounded up to full batches)
35 |         total_samples = 0
36 |         scores = []
37 |         for xi, _ in loader:
38 |             if total_samples >= num_samples:
39 |                 break
40 |             xi = xi.to(self.device)
41 |             total_samples += len(xi)
42 |             score = score_fn(xi)
43 |             scores.append(score)
44 |         return torch.concat(scores, dim=0)
45 | 
--------------------------------------------------------------------------------
/hopfield_boosting/ood/metrics.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | import torch
4 | from torch import nn
5 | import numpy as np
6 | import sklearn.metrics
7 | 
8 | 
9 | class OODMetric(nn.Module, ABC):
10 |     @abstractmethod
11 |     def forward(self, in_scores, out_scores):
12 |         pass
13 | 
14 | 
15 | class AUROCOODMetric(OODMetric):
16 |     def forward(self, in_scores, out_scores):
17 |         targets = torch.concat([torch.ones_like(in_scores, dtype=torch.int), torch.zeros_like(out_scores, dtype=torch.int)])
18 |         scores = torch.concat([in_scores, out_scores])
19 |         return float(sklearn.metrics.roc_auc_score(targets.cpu().detach().numpy(), scores.cpu().detach().numpy()))
20 | 
21 | 
22 | class AUPRSOODMetric(OODMetric):
23 |     def forward(self, in_scores, out_scores):
24 |         targets = torch.concat([torch.ones_like(in_scores, dtype=torch.int), torch.zeros_like(out_scores, dtype=torch.int)])
25 |         scores = torch.concat([in_scores, out_scores])
26 |         return float(sklearn.metrics.average_precision_score(targets.cpu().detach().numpy(), scores.cpu().detach().numpy()))
27 | 
28 | 
29 | class FPR95OODMetric(OODMetric):
30 |     # source code adapted from: https://github.com/deeplearning-wisc/poem/blob/main/utils/anom_utils.py
31 | 
32 |     def stable_cumsum(self, arr, rtol=1e-05, atol=1e-08):
33 |         """Use high precision for cumsum and check that final value matches sum
34 |         Parameters
35 |         ----------
36 |         arr : array-like
37 |             To be cumulatively summed as flat
38 |         rtol : float
39 |             Relative tolerance, see ``np.allclose``
40 |         atol : float
41 |             Absolute tolerance, see ``np.allclose``
42 |         """
43 |         out = np.cumsum(arr, 
dtype=np.float64) 44 | expected = np.sum(arr, dtype=np.float64) 45 | if not np.allclose(out[-1], expected, rtol=rtol, atol=atol): 46 | raise RuntimeError('cumsum was found to be unstable: ' 47 | 'its last element does not correspond to sum') 48 | return out 49 | 50 | def fpr_and_fdr_at_recall(self, y_true, y_score, recall_level=0.95, pos_label=None): 51 | classes = np.unique(y_true) 52 | if (pos_label is None and 53 | not (np.array_equal(classes, [0, 1]) or 54 | np.array_equal(classes, [-1, 1]) or 55 | np.array_equal(classes, [0]) or 56 | np.array_equal(classes, [-1]) or 57 | np.array_equal(classes, [1]))): 58 | raise ValueError("Data is not binary and pos_label is not specified") 59 | elif pos_label is None: 60 | pos_label = 1. 61 | 62 | # make y_true a boolean vector 63 | y_true = (y_true == pos_label) 64 | 65 | # sort scores and corresponding truth values 66 | desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] 67 | y_score = y_score[desc_score_indices] 68 | y_true = y_true[desc_score_indices] 69 | 70 | # y_score typically has many tied values. Here we extract 71 | # the indices associated with the distinct values. We also 72 | # concatenate a value for the end of the curve. 73 | distinct_value_indices = np.where(np.diff(y_score))[0] 74 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 75 | 76 | # accumulate the true positives with decreasing threshold 77 | tps = self.stable_cumsum(y_true)[threshold_idxs] 78 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 79 | 80 | thresholds = y_score[threshold_idxs] 81 | 82 | recall = tps / tps[-1] 83 | 84 | last_ind = tps.searchsorted(tps[-1]) 85 | sl = slice(last_ind, None, -1) # [last_ind::-1] 86 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl] 87 | 88 | cutoff = np.argmin(np.abs(recall - recall_level)) 89 | return fps[cutoff] / (np.sum(np.logical_not(y_true))), fps[cutoff]/(fps[cutoff] + tps[cutoff]) 90 | 91 | def forward(self, in_scores, out_scores): 92 | targets = torch.concat([torch.ones_like(in_scores, dtype=torch.int), torch.zeros_like(out_scores, dtype=torch.int)]) 93 | scores = torch.concat([in_scores, out_scores]) 94 | fpr95, _ = self.fpr_and_fdr_at_recall(targets.cpu().detach().numpy(), scores.cpu().detach().numpy()) 95 | return fpr95 96 | -------------------------------------------------------------------------------- /hopfield_boosting/ood/test/test_detector.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from hopfield_boosting.energy import HopfieldClassifier 6 | from hopfield_boosting.classification_head import SimpleHopfieldClassificationHead 7 | from hopfield_boosting.ood import MSPOODDetector, EnergyOODDetector, MaxLogitOODDetector, ComposedOODDetector 8 | 9 | 10 | @pytest.fixture 11 | def in_data(): 12 | in_data = torch.rand([100, 100]) 13 | in_stored, in_state = in_data[:50], in_data[50:] 14 | in_stored_y = torch.concat([torch.zeros([25]), torch.ones([25])]).to(torch.int64) 15 | return in_stored, in_stored_y, in_state 16 | 17 | @pytest.fixture 18 | def out_data(): 19 | return -torch.rand([50, 100]) 20 | 21 | @pytest.fixture 22 | def classifier(in_data): 23 | with HopfieldClassifier(nn.Identity(), SimpleHopfieldClassificationHead(1.)).store_tensor(in_data[0], y=in_data[1]) as cls: 24 | yield cls 25 | 26 | 27 | @pytest.fixture 28 | def detectors(classifier): 29 | return [EnergyOODDetector(), MaxLogitOODDetector(), 
MSPOODDetector()]
30 | 
31 | 
32 | def test_detectors(detectors, in_data, out_data, classifier):
33 |     _, _, in_data = in_data
34 |     for detector in detectors:
35 |         in_scores = next(iter(detector.compute_scores(in_data, classifier, beta=1.).values()))
36 |         out_scores = next(iter(detector.compute_scores(out_data, classifier, beta=1.).values()))
37 |         if not isinstance(detector, MSPOODDetector):
38 |             # MSP unfortunately can't detect these outliers
39 |             assert torch.min(in_scores) > torch.max(out_scores)
40 | 
41 | 
42 | def test_composed_detector(detectors, classifier, in_data, out_data):
43 |     _, _, in_data = in_data
44 |     detector = ComposedOODDetector(detectors)
45 | 
46 |     in_scores = detector.compute_scores(in_data, classifier, beta=1.)
47 |     out_scores = detector.compute_scores(out_data, classifier, beta=1.)
48 | 
49 |     for k, in_score in in_scores.items():
50 |         out_score = out_scores[k]
51 |         if k != 'msp':
52 |             assert torch.min(in_score) > torch.max(out_score)
53 | 
--------------------------------------------------------------------------------
/hopfield_boosting/ood/test/test_metrics.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | 
5 | from hopfield_boosting.ood.metrics import AUROCOODMetric, AUPRSOODMetric, FPR95OODMetric
6 | 
7 | # note: deliberately plain functions, not fixtures -- they are passed to pytest.mark.parametrize below
8 | def dummy_ood_scores_1():
9 |     in_scores = torch.tensor([.21, .31, .41, .51])
10 |     out_scores = torch.tensor([.0, .1, .2, .3, .4])
11 | 
12 |     return in_scores, out_scores
13 | 
14 | def dummy_ood_scores_2():
15 |     in_scores = torch.tensor([1.1, 1.2, 1.3])
16 |     out_scores = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
17 |     return in_scores, out_scores
18 | 
19 | @pytest.mark.parametrize('data_expected_result', [(dummy_ood_scores_1, 0.85), (dummy_ood_scores_2, 1.0)])
20 | def test_auroc(data_expected_result):
21 |     data, expected_result = data_expected_result
22 |     in_scores, out_scores = data()
23 |     auroc = AUROCOODMetric()
24 |     assert np.allclose(auroc(in_scores, out_scores), expected_result)
25 | 
26 | 
27 | @pytest.mark.parametrize('data_expected_result', [(dummy_ood_scores_1, 0.8542), (dummy_ood_scores_2, 1.0)])
28 | def test_auprs(data_expected_result):
29 |     data, expected_result = data_expected_result
30 |     in_scores, out_scores = data()
31 |     auprs = AUPRSOODMetric()
32 |     assert np.allclose(auprs(in_scores, out_scores), expected_result, atol=1e-4)
33 | 
34 | @pytest.mark.parametrize('data_expected_result', [(dummy_ood_scores_1, 0.4), (dummy_ood_scores_2, 0.0)])
35 | def test_fpr95(data_expected_result):
36 |     data, expected_result = data_expected_result
37 |     in_scores, out_scores = data()
38 |     fpr95 = FPR95OODMetric()
39 |     assert np.allclose(fpr95(in_scores, out_scores), expected_result)
40 | 
--------------------------------------------------------------------------------
/hopfield_boosting/trainer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | 
4 | import torch
5 | from torch import nn
6 | 
7 | from hopfield_boosting.energy import OneSidedEnergy
8 | from hopfield_boosting.util import Negative, FirstElement
9 | 
10 | class Trainer(nn.Module, ABC):
11 |     def __init__(self, model, classifier, criterion, optim, energy_fn, ood_weight, projection_head=None, beta=None, use_ood=True):
12 |         super(Trainer, self).__init__()
13 |         self.model = model
14 |         self.classifier = classifier
15 |         self.criterion = criterion
16 |         self.optim = optim(params=nn.Sequential(model, 
classifier).parameters())
17 |         self.energy_fn = energy_fn
18 |         self.ood_weight = ood_weight
19 |         self.beta = beta
20 |         self.use_ood = use_ood
21 |         if projection_head is None:
22 |             projection_head = nn.Identity()
23 |         self.projection_head = projection_head
24 | 
25 |     def forward(self, x, y, x_aux):
26 |         assert x_aux is not None
27 |         len_xs_ood = len(x_aux)
28 |         if self.use_ood:
29 |             xs_train = torch.concat([x, x_aux], dim=0)
30 |         else:
31 |             xs_train = x
32 | 
33 |         emb = self.model(xs_train)
34 |         emb_projected = self.projection_head(emb)
35 |         logits = self.classifier(emb)
36 | 
37 |         if torch.any(torch.isnan(logits)):
38 |             raise RuntimeError('NaN encountered in classifier logits')
39 | 
40 |         logits_in = logits[:len(x)]
41 |         logits_out = logits[len(x):]
42 | 
43 |         in_loss = self.criterion(logits_in, y)
44 |         pred = torch.argmax(logits_in, dim=-1)
45 |         acc = torch.mean((pred == y).to(torch.float))
46 |         loss = in_loss
47 | 
48 |         if self.use_ood:
49 | 
50 |             ood_loss, ood_dict = self.ood_loss(emb_projected[:-len_xs_ood], emb_projected[-len_xs_ood:], logits_in, logits_out)
51 |             loss = loss + self.ood_weight * ood_loss
52 | 
53 |         else:
54 |             ood_loss = torch.tensor(0.)
55 |             ood_dict = {}
56 | 
57 |         info_dict = {
58 |             'classifier/acc': acc,
59 |             'classifier/in_loss': in_loss,
60 |             'classifier/in_logits': logits_in,
61 |             'general/loss': loss,
62 |             'out/ood_loss': ood_loss,
63 |         }
64 | 
65 |         info_dict.update({f'energies/{k}': v for k, v in ood_dict.items() if k != 'attn'})
66 | 
67 |         info_dict = {k: v.detach() for k, v in info_dict.items()}
68 | 
69 |         return loss, info_dict
70 | 
71 |     def step(self, x, y, x_aux):
72 |         loss, info_dict = self(x, y, x_aux)
73 | 
74 |         self.optim.zero_grad()
75 |         loss.backward()
76 | 
77 |         self.optim.step()
78 | 
79 |         return loss, info_dict
80 | 
81 |     @abstractmethod
82 |     def ood_loss(self, emb_in, emb_out, logits_in, logits_out):
83 |         pass
84 | 
85 |     @abstractmethod
86 |     def ood_tester(self, emb_in, emb_out) -> nn.Module:
87 |         """
88 |         Returns a module that maps inputs to OOD scores, following the convention: high score -> ID; low score -> OOD.
89 | """ 90 | pass 91 | 92 | 93 | class HopfieldTrainer(Trainer): 94 | def ood_loss(self, emb_in, emb_out, logits_in, logits_out): 95 | border_energy = self.energy_fn(a=emb_in, b=emb_out) 96 | energies, info_dict = border_energy(torch.concat([emb_in, emb_out], dim=0), return_dict=True) 97 | return -torch.mean(energies), {'border_energies': energies, **info_dict} 98 | 99 | def ood_tester(self, emb_in, emb_out): 100 | energy_one_sided = OneSidedEnergy(a=emb_in, b=emb_out, beta_a=self.beta, beta_b=self.beta, normalize=True) 101 | 102 | return nn.Sequential(self.model, self.projection_head, energy_one_sided, Negative()) 103 | 104 | 105 | class EnergyTrainer(Trainer): # implementation of https://arxiv.org/pdf/2010.03759.pdf 106 | def __init__(self, model, classifier, criterion, optim, m_in=-23, m_out=-5, energy_fn=None, ood_weight=0.1): # hyperparameters for CIFAR-10 taken from https://arxiv.org/pdf/2010.03759.pdf 107 | super().__init__(model, classifier, criterion, optim, energy_fn, ood_weight) 108 | self.m_in = m_in 109 | self.m_out = m_out 110 | 111 | def ood_loss(self, emb_in, emb_out, logits_in, logits_out): 112 | in_energy = -torch.logsumexp(logits_in, dim=-1) 113 | out_energy = -torch.logsumexp(logits_out, dim=-1) 114 | in_losses = torch.maximum(torch.tensor(0.), in_energy - self.m_in)**2 115 | out_losses = torch.maximum(torch.tensor(0.), self.m_out - out_energy)**2 116 | loss = torch.mean(in_losses) + torch.mean(out_losses) 117 | return loss, {'in_energy': in_energy, 'out_energy': out_energy, 'in_losses': in_losses, 'out_losses': out_losses} 118 | 119 | def ood_tester(self, emb_in, emb_out): 120 | class Energy(nn.Module): 121 | def forward(self, logits): 122 | return -torch.logsumexp(logits, dim=-1) 123 | return nn.Sequential(self.model, self.classifier, Energy(), Negative()) 124 | 125 | 126 | class MSPTrainer(Trainer): # implementation of https://arxiv.org/pdf/1812.04606.pdf 127 | def __init__(self, model, classifier, criterion, optim, energy_fn=None, ood_weight=0.5): 128 | super().__init__(model, classifier, criterion, optim, energy_fn, ood_weight) 129 | 130 | def ood_loss(self, emb_in, emb_out, logits_in, logits_out): 131 | return -(logits_out.mean(-1) - torch.logsumexp(logits_out, dim=-1)).mean(), {} 132 | 133 | def ood_tester(self, emb_in, emb_out): 134 | class MSP(nn.Module): 135 | def __init__(self) -> None: 136 | super(MSP, self).__init__() 137 | 138 | def forward(self, logits): 139 | return torch.max(torch.softmax(logits, dim=-1), dim=-1)[0] 140 | return nn.Sequential(self.model, self.classifier, MSP()) 141 | -------------------------------------------------------------------------------- /hopfield_boosting/transforms/img_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | class BWToRandColor(object): 7 | """ 8 | Transform to replace black and white pixels with random colors in an RGB image. 9 | 10 | Methods: 11 | __call__(self, img: Image.Image) -> Image.Image: Apply the transformation to the input image. 12 | """ 13 | def __call__(self, img: Image.Image) -> Image.Image: 14 | """ 15 | Apply the transformation to the input image. 16 | 17 | Parameters: 18 | img (PIL.Image.Image): Input grayscale image in "RGB" mode. 19 | 20 | Returns: 21 | PIL.Image.Image: Transformed image with replaced black and white pixels. 
22 | """ 23 | if img.mode == "RGB": 24 | # Generate random colors for black and white replacement 25 | color1 = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 26 | color2 = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 27 | 28 | # Convert image to NumPy array 29 | img_array = np.array(img) 30 | 31 | # Create a mask for black pixels 32 | black_mask = (img_array[:, :, 0] == 0) & (img_array[:, :, 1] == 0) & (img_array[:, :, 2] == 0) 33 | 34 | # Create a mask for white pixels 35 | white_mask = (img_array[:, :, 0] == 255) & (img_array[:, :, 1] == 255) & (img_array[:, :, 2] == 255) 36 | 37 | # Replace black pixels with random color1 38 | img.paste(color1, None, mask=Image.fromarray(black_mask)) 39 | 40 | # Replace white pixels with random color2 41 | img.paste(color2, None, mask=Image.fromarray(white_mask)) 42 | 43 | # Convert the NumPy array back to a PIL image 44 | img = Image.fromarray(np.uint8(img)) 45 | 46 | return img 47 | 48 | 49 | class GrayToRandColor(object): 50 | """ 51 | Transform to interpolate colors based on grayscale intensity in an RGB image. 52 | 53 | Methods: 54 | __call__(self, img: Image.Image) -> Image.Image: Apply the transformation to the input image. 55 | """ 56 | def __call__(self, img: Image.Image) -> Image.Image: 57 | """ 58 | Apply the transformation to the input image. 59 | 60 | Parameters: 61 | img (PIL.Image.Image): Input grayscale image in "RGB" mode. 62 | 63 | Returns: 64 | PIL.Image.Image: Transformed image with colors interpolated based on grayscale intensity. 65 | """ 66 | if img.mode == "RGB": 67 | # Generate random colors for smooth fade 68 | color1 = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 69 | color2 = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 70 | 71 | # Convert image to NumPy array 72 | img_array = np.array(img) 73 | 74 | # Calculate grayscale intensity 75 | gray_intensity = np.mean(img_array, axis=-1) 76 | 77 | # Normalize the intensity to range [0, 1] 78 | normalized_intensity = gray_intensity / 255.0 79 | 80 | # Interpolate between color1 and color2 based on intensity 81 | interpolated_colors = ( 82 | (1 - normalized_intensity) * color1[0] + normalized_intensity * color2[0], 83 | (1 - normalized_intensity) * color1[1] + normalized_intensity * color2[1], 84 | (1 - normalized_intensity) * color1[2] + normalized_intensity * color2[2] 85 | ) 86 | 87 | # Replace pixels with interpolated colors 88 | img_array[:, :, 0] = interpolated_colors[0] 89 | img_array[:, :, 1] = interpolated_colors[1] 90 | img_array[:, :, 2] = interpolated_colors[2] 91 | 92 | # Convert the NumPy array back to a PIL image 93 | img = Image.fromarray(np.uint8(img_array)) 94 | 95 | return img -------------------------------------------------------------------------------- /hopfield_boosting/util.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from itertools import count 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | 8 | 9 | def logmeanexp(beta, tensor, dim, ignore_negative_inf=False, keepdim=False): 10 | n = torch.tensor(tensor.size(dim)) 11 | if ignore_negative_inf: 12 | no_neg_inf = torch.sum((torch.isinf(tensor) & (tensor < 0)).to(torch.int), dim=dim) 13 | n = n - no_neg_inf 14 | lse = 1/beta * torch.logsumexp(beta * tensor, dim=dim, keepdim=keepdim) 15 | return lse - 1 / beta * torch.log(n) 16 | 17 | 18 | @contextmanager 19 | def infer_loader(loader: 
torch.utils.data.DataLoader, model, device='cpu', max_samples=None): 20 | try: 21 | xs = [] 22 | 23 | if max_samples: 24 | r = range(int(np.ceil(max_samples / loader.batch_size))) 25 | else: 26 | r = count(0) 27 | 28 | with torch.no_grad(): 29 | for _, (x, y) in zip(r, loader): 30 | x = x.to(device) 31 | x = model(x) 32 | xs.append(x) 33 | 34 | x = torch.concat(xs, dim=0) 35 | 36 | yield x 37 | 38 | finally: 39 | del x 40 | 41 | class Negative(nn.Module): 42 | def forward(self, x): 43 | return -x 44 | 45 | class FirstElement(nn.Module): 46 | def forward(self, x): 47 | return x[0] -------------------------------------------------------------------------------- /hopfield_boosting/utils/config_util.py: -------------------------------------------------------------------------------- 1 | 2 | from omegaconf import OmegaConf 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | 7 | class ConstOmegaConfLoader: 8 | """ 9 | A class for loading an OmegaConf configuration from a file and setting it to read-only. 10 | 11 | Attributes: 12 | config (omegaconf.DictConfig): The loaded and read-only OmegaConf configuration. 13 | 14 | Methods: 15 | __init__(file_path: pathlib.Path): Initialize the loader with the path to the configuration file. 16 | load_config(): Load the configuration from the specified file and set it to read-only. 17 | """ 18 | 19 | def __init__(self, file_path: Union[Path, str]) -> None: 20 | """ 21 | Initialize the ConstOmegaConfLoader. 22 | 23 | Args: 24 | file_path (pathlib.Path): The path to the configuration file. 25 | """ 26 | self._file_path = file_path 27 | self._config = None 28 | 29 | @property 30 | def config(self) -> 'omegaconf.DictConfig': 31 | """ 32 | Getter for the loaded and read-only OmegaConf configuration. 33 | 34 | Returns: 35 | omegaconf.DictConfig: The loaded and read-only configuration. 36 | """ 37 | if self._config is None: 38 | self.load_config() 39 | return self._config 40 | 41 | def load_config(self) -> None: 42 | """ 43 | Load the configuration from the specified file and set it to read-only. 
44 | """ 45 | self._config = OmegaConf.load(self._file_path) 46 | 47 | # Set the entire configuration as read-only 48 | OmegaConf.set_readonly(self._config, True) -------------------------------------------------------------------------------- /hopfield_boosting/utils/eval_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score 3 | from scipy.stats import wasserstein_distance 4 | 5 | 6 | 7 | def eval_auroc(a, b): 8 | labels = np.concatenate([np.zeros(len(a)), np.ones(len(b))], axis=0) 9 | score = np.concatenate([a, b], axis=0) 10 | 11 | auroc = roc_auc_score(labels, score) 12 | 13 | return auroc 14 | 15 | 16 | def eval_metrics_str(a, b): 17 | result_str = f"AUROC: {eval_auroc(a, b):.4f}\n" + f"Wasserstein distance: {wasserstein_distance(a, b):.4f}" 18 | 19 | return result_str 20 | 21 | 22 | def dump_eval_metrics(a, b): 23 | result_str = eval_metrics_str(a, b) 24 | print(result_str) -------------------------------------------------------------------------------- /hopfield_boosting/utils/model_util.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | from torch import nn 4 | import yaml 5 | from hydra.utils import instantiate 6 | from omegaconf import OmegaConf 7 | 8 | # helper function to load a model 9 | def model_file(model_path, epoch=99, name='model'): 10 | model_file = model_path / str(epoch) / (name + '.ckpt') 11 | 12 | return model_file 13 | 14 | 15 | def load_model(model_file, map_location='cpu'): 16 | model = torch.load(model_file, map_location=map_location) 17 | 18 | return model 19 | 20 | 21 | def create_model(model_path, model_config='cfg.yaml', device='cuda'): 22 | with open(model_path / model_config) as p: 23 | config = OmegaConf.create(yaml.safe_load(p)) 24 | 25 | resnet = instantiate(config['model']) 26 | projection_head = instantiate(config['projection_head']) 27 | classifier = nn.Linear(512, config.num_classes) 28 | 29 | return config, resnet.eval().to(device), projection_head.eval().to(device), classifier.eval().to(device) 30 | 31 | 32 | def load_model_weights(model_path, resnet, projection_head, classifier, epoch='', device='cuda'): 33 | resnet.load_state_dict(load_model(model_file(model_path, epoch=epoch))) 34 | projection_head.load_state_dict(load_model(model_file(model_path, name='projection_head', epoch=epoch))) 35 | classifier.load_state_dict(load_model(model_file(model_path, name='classifier', epoch=epoch))) 36 | 37 | return resnet.eval().to(device), projection_head.eval().to(device), classifier.eval().to(device) -------------------------------------------------------------------------------- /hopfield_boosting/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Tuple, Optional 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.colors import Colormap 6 | 7 | def plot_histograms(X: Union[np.ndarray, List[np.ndarray]], 8 | X_labels: Union[None, str, List[str]]=None, 9 | bins: Union[int, str, List[float]]='auto', 10 | density: bool=False, 11 | colormap: Union[str, Colormap, None]=None, 12 | colors: Union[None, List[str]]=None) -> plt.figure: 13 | """ 14 | Plot histograms for single or multiple datasets for visual comparison. 15 | 16 | Parameters: 17 | - X (Union[np.ndarray, List[np.ndarray]]): Single or list of 1D measurements to be turned into histograms. 
18 | - X_labels (Union[None, str, List[str]], optional): Label or list of labels corresponding to each dataset in X. 19 | - bins (Union[int, str, List[float]], optional): Specification of histogram bins. Default is 'auto'. 20 | - density (bool, optional): If True, the histogram represents a probability density. Default is False. 21 | - colormap (Union[str, Colormap, None], optional): Name of the colormap or a colormap object itself. 22 | - colors (Union[None, List[str]], optional): List of colors corresponding to each dataset in X. 23 | 24 | Returns: 25 | - plt.figure: Matplotlib figure object containing the histogram plot. 26 | """ 27 | if (colormap is not None) and (colors is not None): 28 | raise ValueError("Both 'colormap' and 'colors' parameters cannot be provided simultaneously. " 29 | "Please choose either colormap or colors.") 30 | 31 | if not isinstance(X, list): 32 | X = [X] 33 | if isinstance(X_labels, str): 34 | X_labels = [X_labels] 35 | 36 | if X_labels is not None and len(X) != len(X_labels): 37 | raise ValueError("Length of X and X_labels should match") 38 | 39 | combined_data = np.concatenate(X) 40 | common_bins = np.histogram_bin_edges(combined_data, bins=bins) 41 | 42 | if colors is not None: 43 | if len(colors) != len(X): 44 | raise ValueError("Length of 'colors' should match the number of datasets in X") 45 | color_iter = iter(colors) 46 | 47 | if colormap is None: 48 | colormap = plt.rcParams['axes.prop_cycle'].by_key()['color'] 49 | elif isinstance(colormap, str): 50 | colormap = plt.get_cmap(colormap) 51 | 52 | fig, ax = plt.subplots() 53 | for i, data in enumerate(X): 54 | label = None if X_labels is None else X_labels[i] 55 | color = None 56 | 57 | if colors is not None: 58 | color = next(color_iter) 59 | else: 60 | if isinstance(colormap, list): 61 | color = colormap[i] 62 | else: 63 | color = colormap(i / len(X)) 64 | 65 | ax.hist(data, bins=common_bins, alpha=0.5, label=label, color=color, density=density) 66 | 67 | ylbl = "Relative Frequency" if density else "Frequency" 68 | ax.set_ylabel(ylbl) 69 | 70 | if X_labels is not None: 71 | ax.legend() 72 | 73 | return fig 74 | 75 | 76 | def create_subplots(n_subplots, rotate=False): 77 | """ 78 | Creates subplots in a grid layout. 79 | 80 | This function calculates the number of rows and columns based on the desired number of subplots. 81 | The subplots are then created using the `subplots` function from Matplotlib. 82 | 83 | Parameters: 84 | n_subplots (int): The total number of subplots. 85 | rotate (bool): Whether to rotate the layout by 90 degrees. 86 | 87 | Returns: 88 | matplotlib.figure.Figure: The generated figure object. 89 | numpy.ndarray: Flattened array of axes objects representing the subplots. 
90 | 
91 |     """
92 |     square_len = np.sqrt(n_subplots)
93 |     # we sometimes need an additional row depending on the rotation and the number of subplots
94 |     row_appendix = int(bool(np.remainder(n_subplots,square_len))*rotate)
95 | 
96 |     nrows = int(square_len) + row_appendix
97 |     ncols = int((n_subplots+nrows-1) // nrows)
98 | 
99 |     fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
100 | 
101 |     # Flatten the axes array for easier indexing
102 |     axes = axes.flatten()
103 | 
104 |     # Remove axis and ticks for empty subplots
105 |     for i in range(n_subplots, nrows * ncols):
106 |         axes[i].axis('off')
107 | 
108 |     return fig, axes
109 | 
110 | 
111 | import matplotlib.pyplot as plt
112 | import matplotlib.patches as patches
113 | from matplotlib.collections import LineCollection
114 | 
115 | 
116 | def ax2ax(source_ax, target_ax):
117 |     """
118 |     Reproduces the contents of one Matplotlib axis onto another axis.
119 | 
120 |     Parameters:
121 |         source_ax (matplotlib.axes.Axes): The source axis from which the content will be copied.
122 |         target_ax (matplotlib.axes.Axes): The target axis where the content will be reproduced.
123 | 
124 |     Returns:
125 |         None
126 |     """
127 |     # Reproduce line plots
128 |     for line in source_ax.get_lines():
129 |         target_ax.plot(line.get_xdata(),
130 |                        line.get_ydata(),
131 |                        label=line.get_label(),
132 |                        color=line.get_color(),
133 |                        linestyle=line.get_linestyle(),
134 |                        linewidth=line.get_linewidth(),
135 |                        marker=line.get_marker(),
136 |                        markeredgecolor=line.get_markeredgecolor(),
137 |                        markeredgewidth=line.get_markeredgewidth(),
138 |                        markerfacecolor=line.get_markerfacecolor(),
139 |                        markersize=line.get_markersize(),
140 |                        )
141 | 
142 |     # Reproduce rectangles (histogram bars)
143 |     for artist in source_ax.__dict__['_children']:
144 |         if isinstance(artist, patches.Rectangle):
145 |             rect = artist
146 |             # Retrieve properties of each rectangle and reproduce it on the target axis
147 |             target_ax.add_patch(patches.Rectangle((rect.get_x(), rect.get_y()),
148 |                                                   rect.get_width(),
149 |                                                   rect.get_height(),
150 |                                                   edgecolor=rect.get_edgecolor(),
151 |                                                   facecolor=rect.get_facecolor(),
152 |                                                   linewidth=rect.get_linewidth(),
153 |                                                   linestyle=rect.get_linestyle()
154 |                                                   ))
155 | 
156 |     # Reproduce collections (e.g., LineCollection)
157 |     for collection in source_ax.collections:
158 |         if isinstance(collection, LineCollection):
159 |             lc = LineCollection(segments=collection.get_segments(),
160 |                                 label=collection.get_label(),
161 |                                 color=collection.get_color(),
162 |                                 linestyle=collection.get_linestyle(),
163 |                                 linewidth=collection.get_linewidth(),
164 |                                 )
165 |             target_ax.add_collection(lc)
166 | 
167 |     # Reproduce axis limits and aspect ratio
168 |     target_ax.set_xlim(source_ax.get_xlim())
169 |     target_ax.set_ylim(source_ax.get_ylim())
170 |     target_ax.set_aspect(source_ax.get_aspect())
171 | 
172 |     # Reproduce axis labels
173 |     target_ax.set_xlabel(source_ax.get_xlabel())
174 |     target_ax.set_ylabel(source_ax.get_ylabel())
175 | 
176 |     # Reproduce title
177 |     target_ax.set_title(source_ax.get_title())
178 | 
179 |     # Reproduce legend
180 |     handles, labels = source_ax.get_legend_handles_labels()
181 |     target_ax.legend(handles, labels)
182 | 
183 | 
184 | def plot_random_samples(dataset: torch.utils.data.Dataset, n_samples: int,
185 |                         grid_layout: Optional[bool] = True, title_label: Optional[bool] = True,
186 |                         frame: Optional[bool] = True) -> plt.Figure:
187 |     """
188 |     Plot random samples from a PyTorch dataset.
189 | 190 | Parameters: 191 | dataset (torch.utils.data.Dataset): PyTorch dataset containing images and labels. 192 | n_samples (int): Number of samples to plot. 193 | grid_layout (bool, optional): If True, arrange the samples in a grid layout. Default is True. 194 | title_label (bool, optional): If True, display labels as titles for each sample. Default is True. 195 | frame (bool, optional): If True, display frames around the images. Default is True. 196 | 197 | Returns: 198 | matplotlib.figure.Figure: The generated matplotlib figure containing the plotted samples. 199 | """ 200 | # Randomly sample indices 201 | n_images = len(dataset) 202 | #sampled_indices = torch.randperm(n_images)[:n_samples] 203 | sampled_indices = np.random.permutation(n_images)[:n_samples] 204 | 205 | # Extract images and labels 206 | sampled_images = [dataset[i][0] for i in sampled_indices] 207 | if title_label: 208 | sampled_labels = [dataset[i][1] for i in sampled_indices] 209 | 210 | if grid_layout: 211 | # Create the layout of subplots 212 | fig, axes = create_subplots(n_samples, rotate=False) 213 | else: 214 | fig, axes = plt.subplots(nrows=1, ncols=n_samples) 215 | 216 | # Flatten the axes array for easier indexing 217 | axes = axes.flatten() 218 | 219 | for i in range(n_samples): 220 | ax = axes[i] 221 | ax.imshow(np.transpose(sampled_images[i], (1, 2, 0))) 222 | 223 | if title_label: 224 | ax.set_title(f"Label: {sampled_labels[i]}") 225 | 226 | ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False) 227 | 228 | if not frame: 229 | ax.axis("off") 230 | 231 | return fig 232 | 233 | 234 | def plot_random_samples_multi(datasets: Union[torch.utils.data.Dataset, List[torch.utils.data.Dataset]], 235 | n_samples: int, title_label: Optional[bool] = True, 236 | frame: Optional[bool] = True) -> plt.Figure: 237 | """ 238 | Plot random samples from one or more PyTorch datasets. 239 | 240 | Parameters: 241 | datasets (Union[torch.utils.data.Dataset, List[torch.utils.data.Dataset]]): Single or list of PyTorch datasets. 242 | n_samples (int): Number of samples to plot for each dataset. 243 | title_label (bool, optional): If True, display labels as titles for each sample. Default is True. 244 | frame (bool, optional): If True, display frames around the images. Default is True. 245 | 246 | Returns: 247 | matplotlib.figure.Figure: The generated matplotlib figure containing the plotted samples. 
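248 | 
249 |     Example (illustrative sketch; `ds_a` and `ds_b` are placeholders for any
250 |     datasets returning (image, label) pairs with images in CxHxW layout):
251 |         >>> fig = plot_random_samples_multi([ds_a, ds_b], n_samples=4)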
252 |     """
253 |     if not isinstance(datasets, list):
254 |         datasets = [datasets]
255 | 
256 |     n_datasets = len(datasets)
257 | 
258 |     fig, axes = plt.subplots(nrows=n_datasets, ncols=n_samples)
259 | 
260 |     if n_datasets < 2:
261 |         axes = np.expand_dims(axes, axis=0)  # Add an extra dimension for consistency
262 | 
263 |     for dataset_index, dataset in enumerate(datasets):
264 |         # Randomly sample indices
265 |         n_images = len(dataset)
266 |         sampled_indices = torch.randperm(n_images)[:n_samples]
267 | 
268 |         # Extract images and labels
269 |         sampled_images = [dataset[i][0] for i in sampled_indices]
270 |         if title_label:
271 |             sampled_labels = [dataset[i][1] for i in sampled_indices]
272 | 
273 |         for i in range(n_samples):
274 |             ax = axes[dataset_index, i]
275 |             ax.imshow(np.transpose(sampled_images[i], (1, 2, 0)))
276 | 
277 |             if title_label:
278 |                 ax.set_title(f"Label: {sampled_labels[i]}")
279 | 
280 |             ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
281 | 
282 |             if not frame:
283 |                 ax.axis("off")
284 | 
285 |     return fig
--------------------------------------------------------------------------------
/images/figure1.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ml-jku/hopfield-boosting/c5ae5b7e57d256a0a7848729052d04140d35dca8/images/figure1.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Energy-based Hopfield Boosting for Out-of-Distribution Detection
2 | 
3 | [![arXiv](https://img.shields.io/badge/arXiv-2405.08766-b31b1b.svg)](https://arxiv.org/abs/2405.08766)
4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
5 | 
6 | This is the official implementation of "Energy-based Hopfield Boosting for Out-of-Distribution Detection". The paper is available [here](https://arxiv.org/abs/2405.08766).
7 | 
8 | https://github.com/claushofmann/hopfield-classifier/assets/23155858/83021bdf-755b-4e79-8b52-4dc999cbc56f
9 | 
10 | ## Installation
11 | 
12 | - Hopfield Boosting works best with Anaconda ([download here](https://www.anaconda.com/download)).
13 | To install Hopfield Boosting and all dependencies, run the following commands:
14 | 
15 | ```
16 | conda env create -f environment.yml
17 | conda activate hopfield-boosting
18 | pip install -e .
19 | ```
20 | 
21 | ## Weights and Biases
22 | 
23 | - Hopfield Boosting supports logging with Weights and Biases (W&B). By default, W&B will log all metrics in [anonymous mode](https://docs.wandb.ai/guides/app/features/anon). Note that runs logged in anonymous mode will be deleted after 7 days. To keep the logs, you need to [create a W&B account](https://docs.wandb.ai/quickstart). When done, log in to your account using the command line.
24 | 
25 | ## Data Sets
26 | To run, you need the following data sets. We follow the established benchmark, which is also used by e.g. [Liu et al. (2020)](https://arxiv.org/abs/2010.03759) and [Ming et al. (2022)](https://arxiv.org/abs/2206.13687).
27 | 
28 | ### In-Distribution Data Sets
29 | 
30 | * [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html): Automatically downloaded by PyTorch
31 | 
32 | ### Auxiliary Outlier Data Set
33 | 
34 | * [ImageNet-RC](https://patrykchrabaszcz.github.io/Imagenet32/): We use ImageNet64x64, which can be downloaded from the [ImageNet Website](http://image-net.org/download-images). 
35 | 
36 | 
37 | ### Out-of-Distribution Validation Data Sets
38 | 
39 | * **MNIST**: Automatically downloaded by PyTorch
40 | * **FashionMNIST**: Automatically downloaded by PyTorch
41 | 
42 | ### Out-of-Distribution Test Data Sets
43 | 
44 | The OOD test data consists of a selection of vision data sets:
45 | 
46 | * **SVHN**: Street View House Numbers
47 | * **Places 365**: Scene recognition data set
48 | * **LSUN-Resize**: A resized version of the Large-scale Scene UNderstanding Challenge
49 | * **LSUN-Crop**: A cropped version of the Large-scale Scene UNderstanding Challenge
50 | * **iSUN**: Contains a large number of different scenes
51 | * **Textures**: A collection of textural images in the wild
52 | 
53 | We have included a Python script that conveniently downloads all OOD test data sets. To execute it, simply run
54 | 
55 | ```
56 | python -m hopfield_boosting.download_data
57 | ```
58 | 
59 | The downloaded data sets will be placed in the current working directory under `downloaded_datasets/`.
60 | 
61 | ## How to Run
62 | 
63 | - Set the paths to the data sets: Copy the `.env.example` file located in the root directory of the repository.
64 | Name the newly created file `.env`.
65 | Customize the new file to contain the paths to the data sets on your machine.
66 | You can also set `PROJECT_ROOT`, which is where Hopfield Boosting will store your model checkpoints.
67 | 
68 | - To run Hopfield Boosting on CIFAR-10, run the command
69 | ```
70 | python -m hopfield_boosting -cn resnet-18-cifar-10-aux-from-scratch
71 | ```
72 | 
73 | - For CIFAR-100, use the command
74 | ```
75 | python -m hopfield_boosting -cn resnet-18-cifar-100-aux-from-scratch
76 | ```
77 | 
78 | - The performance on the OOD validation data sets will be logged to W&B; the performance on the OOD test sets will be logged to a file located in
79 | `test_logs` named according to the `run.id` from W&B.
80 | 
81 | ## 📓 Demo Notebook
82 | 
83 | We have provided a demo notebook [here](notebooks/hopfield_boosting_demo_extended.ipynb) where we demonstrate the capability of Hopfield Boosting to detect OOD inputs. We provide a pre-trained model trained on CIFAR-10 for running the notebook, which is available for download [here](https://drive.google.com/file/d/1LK1VyjvQfA3qUG8LGue0IBOy0Sja2GJb/view?usp=sharing).
84 | 
85 | To run, first set the paths to the data sets and the model in `hopfield_boosting_notebook_config.yaml`. The notebook uses additional data sets. You can find the link to download these data sets in the notebook itself.
86 | 
87 | 
88 | # 📚 Citation
89 | 
90 | If you found this repository helpful, consider giving it a ⭐ and citing our paper:
91 | 
92 | ```
93 | @article{hofmann2024energybased,
94 |       title={Energy-based Hopfield Boosting for Out-of-Distribution Detection}, 
95 |       author={Claus Hofmann and Simon Schmid and Bernhard Lehner and Daniel Klotz and Sepp Hochreiter},
96 |       year={2024},
97 |       journal={arXiv preprint arXiv:2405.08766}
98 | }
99 | ```
100 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | from setuptools import setup, find_packages
4 | 
5 | setup(name='Hopfield Boosting',
6 |       version='1.0',
7 |       packages=find_packages(include=['hopfield_boosting', 'hopfield_boosting.*'])
8 | 
9 |       )
--------------------------------------------------------------------------------