├── .gitignore
├── 2d_density_estimation
│   ├── ae_1_8gaussians.pth
│   ├── ae_3_8gaussians.pth
│   ├── draw_2d_density_estimation_figure.ipynb
│   ├── draw_nae_training_methods.ipynb
│   ├── nae_3_8gaussians.pth
│   ├── nae_8gaussians.pth
│   ├── nae_cd_3_8gaussians.pth
│   ├── nae_pcd_3_8gaussians.pth
│   ├── nae_training_methods.ipynb
│   ├── vae_1_8gaussians.pth
│   └── vae_3_8gaussians.pth
├── LICENSE
├── README.md
├── assets
│   ├── celeba64samples.png
│   ├── cifar10samples.png
│   ├── fig_mnist_recon_with_box_v2.png
│   ├── main.png
│   ├── mnistsamples.png
│   └── vanilla_ae_mnistsamples.png
├── augmentations
│   ├── __init__.py
│   └── augmentations.py
├── configs
│   ├── celeba64_ood_nae
│   │   └── z64gr_h32g8.yml
│   ├── cifar_ood_nae
│   │   └── z32gn.yml
│   ├── fmnist_ood_nae
│   │   └── z32.yml
│   ├── mnist_ho_nae
│   │   └── l2_z32.yml
│   └── mnist_ood_nae
│       └── z32.yml
├── conftest.py
├── datasets
│   └── .gitkeep
├── evaluate_ood.py
├── loaders
│   ├── __init__.py
│   ├── chimera_dataset.py
│   ├── leaveout_dataset.py
│   ├── modified_dataset.py
│   └── synthetic.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── ae.py
│   ├── energybased.py
│   ├── igebm.py
│   ├── langevin.py
│   ├── mcmc.py
│   ├── mmd.py
│   ├── modules.py
│   ├── modules_sngan.py
│   ├── nae.py
│   ├── spectral_norm.py
│   └── utils.py
├── optimizers.py
├── pretrained
│   └── .gitkeep
├── sample.py
├── tests
│   ├── test_load_pretrained.py
│   ├── test_loader.py
│   ├── test_modules.py
│   └── test_nae.py
├── train.py
├── trainers
│   ├── __init__.py
│   ├── base.py
│   ├── logger.py
│   └── nae.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 |
 6 | # C extensions
 7 | *.so
 8 |
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Add
132 | datasets/
133 | results/
134 | pretrained/
--------------------------------------------------------------------------------
/2d_density_estimation/ae_1_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/ae_1_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/ae_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/ae_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_cd_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_cd_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_pcd_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_pcd_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/vae_1_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/vae_1_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/vae_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/vae_3_8gaussians.pth
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 |
 3 | Copyright (c) 2021 Sangwoong Yoon
 4 |
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Autoencoding Under Normalization Constraints
 2 |
 3 | The official repository for <Autoencoding Under Normalization Constraints> (Yoon, Noh and Park, ICML 2021) and **normalized autoencoders**.
 4 |
 5 | > The paper proposes the Normalized Autoencoder (NAE), a novel energy-based model where the energy function is the reconstruction error. NAE effectively remedies outlier reconstruction, a pathological phenomenon limiting the performance of an autoencoder as an outlier detector.
 6 |
 7 | Paper: https://arxiv.org/abs/2105.05735
 8 | 5-min video: https://www.youtube.com/watch?v=ra6usGKnPGk
 9 |
10 |
11 | ![MNIST-figure](assets/fig_mnist_recon_with_box_v2.png)
12 |
13 | ## News
14 |
15 | * 2023-06-15 : Added a config file for FashionMNIST.
16 | * 2022-06-13 : Refactoring of the NAE class.
17 |
18 |
19 | ## Set-up
20 |
21 | ### Environment
22 |
23 | I encourage you to use conda to set up a virtual environment. However, other methods should work without problems.
24 | ```
25 | conda create -n nae python=3.7
26 | ```
27 |
28 | The main dependencies of the repository are as follows:
29 |
30 | - python 3.7.2
31 | - numpy
32 | - pillow
33 | - pytorch 1.7.1
34 | - CUDA 10.1
35 | - scikit-learn 0.24.2
36 | - tensorboard 2.5.0
37 | - pytest 6.2.3
38 |
39 |
40 | ### Datasets
41 |
42 | All datasets are stored in the `datasets/` directory.
43 |
44 | - MNIST, CIFAR-10, SVHN, Omniglot : Retrieved using `torchvision.datasets`.
45 | - Noise, Constant, ConstantGray : [Dropbox link](https://www.dropbox.com/sh/u41ewgwujuvqvpm/AABM6YbklJFAruczJPhBWNwZa?dl=0)
46 | - [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), [ImageNet 32x32](http://image-net.org/small/download.php): Retrieved from their official sites. I am afraid the website for ImageNet 32x32 is not available as of June 24, 2021. I will temporarily upload the data to the above Dropbox link.
47 |
48 | When set up, the dataset directory should look as follows.
49 |
50 | ```
51 | datasets
52 | ├── CelebA
53 | │   ├── Anno
54 | │   ├── Eval
55 | │   └── Img
56 | ├── cifar-10-batches-py
57 | ├── const_img_gray.npy
58 | ├── const_img.npy
59 | ├── FashionMNIST
60 | ├── ImageNet32
61 | │   ├── train_32x32
62 | │   └── valid_32x32
63 | ├── MNIST
64 | ├── noise_img.npy
65 | ├── omniglot-py
66 | │   ├── images_background
67 | │   └── images_evaluation
68 | ├── test_32x32.mat
69 | └── train_32x32.mat
70 |
71 | ```
72 |
73 | ### Pre-trained Models
74 |
75 | Pre-trained models are stored under `pretrained/`. The pre-trained models are provided through the [Dropbox link](https://www.dropbox.com/sh/u41ewgwujuvqvpm/AABM6YbklJFAruczJPhBWNwZa?dl=0).
76 |
77 | If the pretrained models are prepared successfully, the directory structure should look like the following.
78 |
79 | ```
80 | pretrained
81 | ├── celeba64_ood_nae
82 | │   └── z64gr_h32g8
83 | ├── cifar_ood_nae
84 | │   └── z32gn
85 | └── mnist_ood_nae
86 |     └── z32
87 | ```
88 |
89 | ## Unittesting
90 |
91 | PyTest is used for unittesting.
92 |
93 | ```
94 | pytest tests
95 | ```
96 |
97 | The code should pass all tests after the preparation of pre-trained models and datasets.
98 |
99 | ## Execution
100 |
101 | ### OOD Detection Evaluation
102 |
103 | ```
104 | python evaluate_ood.py --ood ConstantGray_OOD,FashionMNIST_OOD,SVHN_OOD,CelebA_OOD,Noise_OOD --resultdir pretrained/cifar_ood_nae/z32gn/ --ckpt nae_9.pkl --config z32gn.yml --device 0 --dataset CIFAR10_OOD
105 | ```
106 |
107 |
108 | **Expected Results**
109 |
110 | ```
111 | OOD Detection Results in AUC
112 | ConstantGray_OOD:0.9632
113 | FashionMNIST_OOD:0.8193
114 | SVHN_OOD:0.9196
115 | CelebA_OOD:0.8873
116 | Noise_OOD:1.0000
117 | ```
118 |
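For reference, the reported numbers are plain AUROC values between two score arrays: the model assigns an anomaly score to every test sample, and the AUC is computed with OOD samples as the positive class (`evaluate_ood.py` does this via `utils.roc_btw_arr`). Below is a minimal sketch of the same computation, not the repository's exact code: it assumes the per-sample scores have already been collected into numpy arrays and uses scikit-learn, which is among the listed dependencies.

```
import numpy as np
from sklearn.metrics import roc_auc_score

def auc_between(ood_scores, ind_scores):
    # OOD samples are the positive class; a higher anomaly score
    # should indicate "more OOD".
    y_true = np.concatenate([np.ones(len(ood_scores)), np.zeros(len(ind_scores))])
    y_score = np.concatenate([ood_scores, ind_scores])
    return roc_auc_score(y_true, y_score)
```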
119 |
120 |
121 | ### Training
122 |
123 | Use `train.py` to train an NAE.
124 | * `--config` option specifies a path to a configuration yaml file.
125 | * `--logdir` specifies a directory where result files will be written.
126 | * `--run` specifies an id for each run, i.e., an experiment.
127 |
128 | **Training on MNIST**
129 | ```
130 | python train.py --config configs/mnist_ood_nae/z32.yml --logdir results/mnist_ood_nae/ --run run --device 0
131 | ```
132 |
133 | **Training on MNIST digits 0 to 8 for the hold-out digit detection task**
134 | ```
135 | python train.py --config configs/mnist_ho_nae/l2_z32.yml --logdir results/mnist_ho_nae --run run --device 0
136 | ```
137 |
138 | **Training on CIFAR-10**
139 | ```
140 | python train.py --config configs/cifar_ood_nae/z32gn.yml --logdir results/cifar_ood_nae/ --run run --device 0
141 | ```
142 |
143 | **Training on CelebA 64x64**
144 | ```
145 | python train.py --config configs/celeba64_ood_nae/z64gr_h32g8.yml --logdir results/celeba64_ood_nae/ --run run --device 0
146 | ```
147 |
148 | **Training on FashionMNIST**
149 | ```
150 | python train.py --config configs/fmnist_ood_nae/z32.yml --logdir results/fmnist_ood_nae --run run --device 0
151 | ```
152 |
153 |
154 | ### Sampling
155 |
156 | Use `sample.py` to generate sample images from an NAE. Samples are saved as a `.npy` file containing an `(n_sample, img_h, img_w, channels)` array.
157 | Note that the quality of generated images is not supposed to match that of state-of-the-art generative models. Improving the sample quality is one of the important future research directions.
158 |
159 | **Sampling for MNIST**
160 |
161 | ```
162 | python sample.py pretrained/mnist_ood_nae/z32/ z32.yml nae_20.pkl --zstep 200 --x_shape 28 --batch_size 64 --n_sample 64 --x_channel 1 --device 0
163 | ```
164 |
165 | ![mnistsamples](assets/mnistsamples.png)
166 |
167 | The white square is an artifact of NAE, possibly occurring due to the distortion of the encoder and the decoder.
168 |
169 | The result is comparable to the samples from a vanilla autoencoder generated with the same procedure.
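Since the samples are stored as a plain `(n_sample, img_h, img_w, channels)` array, they can be inspected without any repository code. Here is a minimal sketch for rendering a grid of samples; `samples.npy` is a placeholder file name, and matplotlib is an extra dependency not pinned above.

```
import numpy as np
import matplotlib.pyplot as plt

samples = np.load('samples.npy')  # (n_sample, img_h, img_w, channels)
n = int(np.ceil(np.sqrt(len(samples))))
fig, axes = plt.subplots(n, n, figsize=(n, n))
for ax in axes.ravel():
    ax.axis('off')
for ax, img in zip(axes.ravel(), samples):
    # single-channel images are squeezed to (H, W) and drawn in grayscale
    ax.imshow(img.squeeze(), cmap='gray' if img.shape[-1] == 1 else None)
plt.savefig('sample_grid.png', bbox_inches='tight')
```

The vanilla autoencoder samples referenced above are shown below.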
170 | 171 | ![vanillamnistsamples](assets/vanilla_ae_mnistsamples.png) 172 | 173 | 174 | **Sampling for CIFAR-10** 175 | ``` 176 | python sample.py pretrained/cifar_ood_nae/z32gn/ z32gn.yml nae_8.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0 177 | ``` 178 | 179 | ![cifar10samples](assets/cifar10samples.png) 180 | 181 | **Sampling for CelebA 64x64** 182 | ``` 183 | python sample.py pretrained/celeba64_ood_nae/z64gr_h32g8/ z64gr_h32g8.yml nae_3.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0 --x_shape 64 184 | ``` 185 | 186 | 187 | ![celeba64samples](assets/celeba64samples.png) 188 | 189 | 190 | 191 | ## Citation 192 | 193 | 194 | ``` 195 | @InProceedings{pmlr-v139-yoon21c, 196 | title = {Autoencoding Under Normalization Constraints}, 197 | author = {Yoon, Sangwoong and Noh, Yung-Kyun and Park, Frank}, 198 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 199 | pages = {12087--12097}, 200 | year = {2021}, 201 | editor = {Meila, Marina and Zhang, Tong}, 202 | volume = {139}, 203 | series = {Proceedings of Machine Learning Research}, 204 | month = {18--24 Jul}, 205 | publisher = {PMLR}, 206 | pdf = {http://proceedings.mlr.press/v139/yoon21c/yoon21c.pdf}, 207 | url = {https://proceedings.mlr.press/v139/yoon21c.html} 208 | } 209 | 210 | ``` 211 | 212 | -------------------------------------------------------------------------------- /assets/celeba64samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/celeba64samples.png -------------------------------------------------------------------------------- /assets/cifar10samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/cifar10samples.png -------------------------------------------------------------------------------- /assets/fig_mnist_recon_with_box_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/fig_mnist_recon_with_box_v2.png -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/main.png -------------------------------------------------------------------------------- /assets/mnistsamples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/mnistsamples.png -------------------------------------------------------------------------------- /assets/vanilla_ae_mnistsamples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/vanilla_ae_mnistsamples.png -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from 
torchvision.transforms import (
 3 |     RandomHorizontalFlip,
 4 |     RandomRotation,
 5 |     RandomResizedCrop,
 6 |     ColorJitter,
 7 |     RandomGrayscale,
 8 |     RandomChoice,
 9 |     Compose,
10 |     Normalize,
11 |     ToTensor
12 | )
13 | from augmentations.augmentations import (
14 |     RandomRotate90,
15 |     ColorJitterSimCLR,
16 |     GaussianDequantize,
17 |     UniformDequantize,
18 |     ToGray,
19 | )
20 |
21 | logger = logging.getLogger("ptsemseg")
22 |
23 |
24 | key2aug = {
25 |     'hflip': RandomHorizontalFlip,
26 |     'rotate': RandomRotate90,
27 |     'rcrop': RandomResizedCrop,
28 |     'cjitter': ColorJitterSimCLR,
29 |     'rgray': RandomGrayscale,
30 |     'GaussianDequantize': GaussianDequantize,
31 |     'UniformDequantize': UniformDequantize,
32 |     'togray': ToGray,
33 |     'normalize': Normalize,
34 |     'totensor': ToTensor
35 | }
36 |
37 |
38 | def get_composed_augmentations(aug_dict):
39 |     if aug_dict is None:
40 |         # print("Using No Augmentations")
41 |         return None
42 |
43 |     augmentations = []
44 |     for aug_key, aug_param in aug_dict.items():
45 |         augmentations.append(key2aug[aug_key](**aug_param))
46 |         # print("Using {} aug with params {}".format(aug_key, aug_param))
47 |     return Compose(augmentations)
--------------------------------------------------------------------------------
/augmentations/augmentations.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import numbers
 3 | import random
 4 | import numpy as np
 5 | import torch
 6 | import torchvision.transforms.functional as tf
 7 | from torchvision.transforms import (
 8 |     RandomApply,
 9 |     ColorJitter, RandomChoice, RandomRotation,  # RandomChoice/RandomRotation are required by RandomRotate90 below
10 | )
11 |
12 | from PIL import Image, ImageOps
13 |
14 |
15 | class RandomRotate90:
16 |     def __init__(self):
17 |         self.t = RandomChoice([RandomRotation((0, 0)), RandomRotation((90, 90)),  # (d, d) rotates by exactly d degrees
18 |                                RandomRotation((180, 180)), RandomRotation((270, 270))])
19 |     def __call__(self, img):
20 |         return self.t(img)
21 |
22 |
23 | class ColorJitterSimCLR:
24 |     def __init__(self, jitter_d=0.5, jitter_p=0.8):
25 |         self.t = RandomApply([ColorJitter(0.8 * jitter_d, 0.8 * jitter_d, 0.8 * jitter_d, 0.2 * jitter_d)],
26 |                              p=jitter_p)
27 |
28 |     def __call__(self, img):
29 |         return self.t(img)
30 |
31 |
32 | class Invert(object):
33 |     """Transform that inverts pixel intensities"""
34 |     def __init__(self):
35 |         pass
36 |
37 |     def __call__(self, img):
38 |         return 1 - img
39 |
40 |
41 | class ToGray:
42 |     """Transform a color image to gray scale"""
43 |     def __init__(self):
44 |         pass
45 |
46 |     def __call__(self, img):
47 |         img = torch.sqrt(torch.sum(img ** 2, dim=0, keepdim=True))  # channel-wise L2 norm, not luminance-weighted grayscale
48 |         return torch.cat([img, img, img], dim=0).detach()
49 |
50 |
51 | class Fragment:
52 |     def __init__(self, mode):
53 |         self.mode = mode
54 |
55 |     def __call__(self, sample):
56 |         '''sample: image tensor (after ToTensor)'''
57 |         if self.mode == 'horizontal':
58 |             half = (sample.shape[2]) // 2
59 |             if np.random.rand() > 0.5:  # upper half
60 |                 sample[:,:,:half] = 0
61 |             else:
62 |                 sample[:,:,half:] = 0
63 |
64 |         elif self.mode == 'vertical':
65 |             half = (sample.shape[1]) // 2
66 |             if np.random.rand() > 0.5:  # upper half
67 |                 sample[:,:half,:] = 0
68 |             else:
69 |                 sample[:,half:,:] = 0
70 |         elif self.mode == '1/4':
71 |             choice = int(np.random.rand() * 4)  # int in {0, 1, 2, 3}, so the equality checks below can match
72 |             half1 = (sample.shape[1]) // 2
73 |             half2 = (sample.shape[2]) // 2
74 |             img = torch.zeros_like(sample)
75 |             if choice == 0:
76 |                 img[:,:half1,:half2] = sample[:,:half1,:half2]
77 |             elif choice == 1:
78 |                 img[:,half1:,:half2] = sample[:,half1:,:half2]
79 |             elif choice == 2:
80 |                 img[:,half1:,half2:] = sample[:,half1:,half2:]
81 |             else:
82 |
img[:,:half1,half2:] = sample[:,:half1,half2:] 83 | sample = img 84 | else: 85 | raise ValueError 86 | return sample 87 | 88 | 89 | class UniformDequantize: 90 | def __init__(self): 91 | pass 92 | 93 | def __call__(self, img): 94 | img = img / 256. * 255. + torch.rand_like(img) / 256. 95 | return img 96 | 97 | 98 | class GaussianDequantize: 99 | def __init__(self): 100 | pass 101 | 102 | def __call__(self, img): 103 | img = img + torch.randn_like(img) * 0.01 104 | return img 105 | 106 | -------------------------------------------------------------------------------- /configs/celeba64_ood_nae/z64gr_h32g8.yml: -------------------------------------------------------------------------------- 1 | trainer: nae 2 | logger: nae 3 | model: 4 | arch: nae 5 | encoder: 6 | arch: conv64 7 | nh: 32 8 | out_activation: linear 9 | activation: relu 10 | use_bn: false 11 | num_groups: 8 12 | decoder: 13 | arch: deconv64 14 | nh: 32 15 | use_bn: false 16 | num_groups: 8 17 | out_activation: sigmoid 18 | nae: 19 | spherical: True 20 | gamma: 1 21 | sampling: on_manifold 22 | 23 | z_step: 20 24 | z_stepsize: 1 25 | z_noise_std: 0.02 26 | z_noise_anneal: Null 27 | z_clip_langevin_grad: Null 28 | x_step: 40 29 | x_stepsize: 20 30 | x_noise_std: 0.02 31 | x_noise_anneal: 1 32 | x_clip_langevin_grad: 0.01 33 | 34 | buffer_size: 10000 35 | replay_ratio: 0.95 36 | replay: True 37 | 38 | x_bound: 39 | - 0 40 | - 1 41 | z_bound: Null 42 | l2_norm_reg: Null 43 | l2_norm_reg_en: Null 44 | temperature: 1. 45 | temperature_trainable: False 46 | 47 | x_dim: 3 48 | z_dim: 64 49 | data: 50 | indist_train: 51 | dataset: CelebA_OOD 52 | path: datasets 53 | batch_size: 128 54 | n_workers: 8 55 | split: training 56 | shuffle: True 57 | size: 64 58 | augmentations: 59 | hflip: 60 | p: 0.5 61 | dequant: 62 | UniformDequantize: {} 63 | indist_val: 64 | dataset: CelebA_OOD 65 | path: datasets 66 | batch_size: 128 67 | n_workers: 8 68 | split: validation 69 | size: 64 70 | ood_val: 71 | dataset: CelebA_OOD 72 | channel: 3 73 | path: datasets 74 | batch_size: 128 75 | split: validation 76 | n_workers: 8 77 | size: 64 78 | dequant: 79 | togray: {} 80 | ood_target: 81 | dataset: ConstantGray_OOD 82 | size: 64 83 | path: datasets 84 | batch_size: 128 85 | n_workers: 8 86 | split: validation 87 | training: 88 | # load_ae: /opt/home3/swyoon/energy-based-autoencoder/src/results/celeba_ood_nae/big/run/model_best.pkl 89 | ae_epoch: 30 90 | nae_epoch: 15 91 | lr_schedule: null 92 | # name: on_plateau 93 | # factor: 0.5 94 | # patience: 3 95 | # min_lr: 1.0e-4 96 | resume: null 97 | # file: /a/b/c.pkl 98 | # optimizer: False 99 | print_interval: 100 100 | val_interval: 200 101 | save_interval: 2000 102 | ae_lr: 1.0e-4 103 | nae_lr: 1.0e-5 104 | temperature_lr: 1.0e-3 105 | fix_D: false 106 | -------------------------------------------------------------------------------- /configs/cifar_ood_nae/z32gn.yml: -------------------------------------------------------------------------------- 1 | data: 2 | indist_train: 3 | augmentations: 4 | hflip: 5 | p: 0.5 6 | batch_size: 128 7 | dataset: CIFAR10_OOD 8 | dequant: 9 | UniformDequantize: {} 10 | n_workers: 8 11 | path: datasets 12 | shuffle: true 13 | split: training 14 | indist_val: 15 | batch_size: 128 16 | dataset: CIFAR10_OOD 17 | n_workers: 8 18 | path: datasets 19 | split: validation 20 | ood_target: 21 | batch_size: 128 22 | dataset: Constant_OOD 23 | n_workers: 8 24 | path: datasets 25 | size: 32 26 | split: validation 27 | ood_val: 28 | batch_size: 128 29 | channel: 3 30 | dataset: SVHN_OOD 
31 | n_workers: 8 32 | path: datasets 33 | split: validation 34 | device: cuda:3 35 | logger: nae 36 | model: 37 | arch: nae 38 | decoder: 39 | arch: sngan_generator_gn 40 | hidden_dim: 128 41 | num_groups: 2 42 | out_activation: sigmoid 43 | encoder: 44 | arch: IGEBMEncoder 45 | keepdim: true 46 | use_spectral_norm: false 47 | nae: 48 | buffer_size: 10000 49 | gamma: 1 50 | l2_norm_reg: null 51 | l2_norm_reg_en: null 52 | replay: true 53 | replay_ratio: 0.95 54 | sampling: on_manifold 55 | spherical: true 56 | temperature: 1.0 57 | temperature_trainable: false 58 | x_bound: 59 | - 0 60 | - 1 61 | x_clip_langevin_grad: 0.01 62 | x_noise_anneal: 1 63 | x_noise_std: 0.02 64 | x_step: 40 65 | x_stepsize: 10 66 | z_bound: null 67 | z_clip_langevin_grad: null 68 | z_noise_anneal: null 69 | z_noise_std: 0.02 70 | z_step: 20 71 | z_stepsize: 1 72 | x_dim: 3 73 | z_dim: 32 74 | trainer: nae 75 | training: 76 | ae_epoch: 180 77 | ae_lr: 0.0001 78 | fix_D: false 79 | lr_schedule: null 80 | nae_epoch: 10 81 | nae_lr: 1.0e-05 82 | print_interval: 352 83 | resume: null 84 | save_interval: 2000 85 | temperature_lr: 0.001 86 | val_interval: 352 87 | -------------------------------------------------------------------------------- /configs/fmnist_ood_nae/z32.yml: -------------------------------------------------------------------------------- 1 | trainer: nae 2 | logger: nae 3 | model: 4 | arch: nae 5 | encoder: 6 | arch: conv2fc 7 | nh: 8 8 | nh_mlp: 1024 9 | out_activation: linear 10 | decoder: 11 | arch: deconv2 12 | nh: 8 13 | out_activation: sigmoid 14 | nae: 15 | spherical: True 16 | gamma: 1 17 | sampling: on_manifold 18 | 19 | z_step: 10 20 | z_stepsize: 0.2 21 | z_noise_std: 0.05 22 | z_noise_anneal: Null 23 | z_clip_langevin_grad: Null 24 | x_step: 50 25 | x_stepsize: 10 26 | x_noise_std: 0.05 27 | x_noise_anneal: 1 28 | x_clip_langevin_grad: 0.01 29 | 30 | buffer_size: 10000 31 | replay_ratio: 0.95 32 | replay: True 33 | 34 | x_bound: 35 | - 0 36 | - 1 37 | z_bound: Null 38 | l2_norm_reg: Null 39 | l2_norm_reg_en: 1e-4 40 | temperature: 1. 41 | temperature_trainable: False 42 | 43 | x_dim: 1 44 | z_dim: 32 45 | data: 46 | indist_train: 47 | dataset: FashionMNIST_OOD 48 | path: datasets 49 | shuffle: True 50 | batch_size: 128 51 | n_workers: 8 52 | split: training 53 | size: 28 54 | dequant: 55 | UniformDequantize: {} 56 | indist_val: 57 | dataset: FashionMNIST_OOD 58 | path: datasets 59 | batch_size: 128 60 | n_workers: 8 61 | split: validation 62 | size: 28 63 | ood_val: 64 | dataset: Constant_OOD 65 | size: 28 66 | channel: 1 67 | path: datasets 68 | split: validation 69 | n_workers: 4 70 | batch_size: 128 71 | ood_target: 72 | dataset: MNIST_OOD 73 | size: 28 74 | channel: 1 75 | path: datasets 76 | split: validation 77 | n_workers: 4 78 | batch_size: 128 79 | training: 80 | ae_epoch: 100 81 | nae_epoch: 50 82 | save_interval: 2000 83 | val_interval: 1000 84 | print_interval: 100 85 | ae_lr: 1e-4 86 | nae_lr: 1e-5 87 | fix_D: false 88 | -------------------------------------------------------------------------------- /configs/mnist_ho_nae/l2_z32.yml: -------------------------------------------------------------------------------- 1 | trainer: nae_v2 2 | logger: base 3 | model: 4 | arch: nae_l2 5 | sampling: omi 6 | encoder: 7 | arch: conv2fc 8 | nh: 8 9 | nh_mlp: 1024 10 | out_activation: spherical 11 | decoder: 12 | arch: deconv2 13 | nh: 8 14 | out_activation: sigmoid 15 | nae: 16 | gamma: 1 17 | l2_norm_reg_de: Null 18 | l2_norm_reg_en: 0.0001 19 | T: 1. 
20 | sampler_z: 21 | sampler: langevin 22 | n_step: 10 23 | stepsize: 0.2 24 | noise_std: 0.05 25 | noise_anneal: Null 26 | clip_langevin_grad: Null 27 | buffer_size: 10000 28 | replay_ratio: 0.95 29 | mh: False 30 | bound: spherical 31 | initial_dist: uniform_sphere 32 | sampler_x: 33 | sampler: langevin 34 | n_step: 50 35 | stepsize: 10 36 | noise_std: 0.05 37 | noise_anneal: 1 38 | clip_langevin_grad: 0.01 39 | mh: False 40 | buffer_size: 0 41 | bound: [0, 1] 42 | x_dim: 1 43 | z_dim: 32 44 | data: 45 | indist_train: 46 | dataset: MNISTLeaveOut 47 | path: datasets 48 | batch_size: 128 49 | n_workers: 8 50 | shuffle: True 51 | split: training 52 | out_class: 53 | - 9 54 | indist_val: 55 | dataset: MNISTLeaveOut 56 | path: datasets 57 | batch_size: 128 58 | n_workers: 8 59 | split: validation 60 | out_class: 61 | - 9 62 | ood_val: 63 | dataset: Constant_OOD 64 | size: 28 65 | channel: 1 66 | path: datasets 67 | batch_size: 128 68 | split: validation 69 | n_workers: 8 70 | ood_target: 71 | dataset: MNISTLeaveOut 72 | path: datasets 73 | batch_size: 128 74 | n_workers: 8 75 | split: validation 76 | out_class: 77 | - 0 78 | - 1 79 | - 2 80 | - 3 81 | - 4 82 | - 5 83 | - 6 84 | - 7 85 | - 8 86 | training: 87 | load_ae: Null 88 | ae_epoch: 100 89 | nae_epoch: 50 90 | print_interval: 500 91 | print_interval_nae: 100 92 | val_interval: 352 93 | save_interval: 2000 94 | ae_lr: 1.0e-4 95 | nae_lr: 1.0e-5 96 | nae_opt: all 97 | -------------------------------------------------------------------------------- /configs/mnist_ood_nae/z32.yml: -------------------------------------------------------------------------------- 1 | trainer: nae 2 | logger: nae 3 | model: 4 | arch: nae 5 | encoder: 6 | arch: conv2fc 7 | nh: 8 8 | nh_mlp: 1024 9 | out_activation: linear 10 | decoder: 11 | arch: deconv2 12 | nh: 8 13 | out_activation: sigmoid 14 | nae: 15 | spherical: True 16 | gamma: 1 17 | sampling: on_manifold 18 | 19 | z_step: 10 20 | z_stepsize: 0.2 21 | z_noise_std: 0.05 22 | z_noise_anneal: Null 23 | z_clip_langevin_grad: Null 24 | x_step: 50 25 | x_stepsize: 10 26 | x_noise_std: 0.05 27 | x_noise_anneal: 1 28 | x_clip_langevin_grad: 0.01 29 | 30 | buffer_size: 10000 31 | replay_ratio: 0.95 32 | replay: True 33 | 34 | x_bound: 35 | - 0 36 | - 1 37 | z_bound: Null 38 | l2_norm_reg: Null 39 | l2_norm_reg_en: 1e-4 40 | temperature: 1. 
41 | temperature_trainable: False 42 | 43 | x_dim: 1 44 | z_dim: 32 45 | data: 46 | indist_train: 47 | dataset: MNIST_OOD 48 | path: datasets 49 | shuffle: True 50 | batch_size: 128 51 | n_workers: 8 52 | split: training 53 | size: 28 54 | dequant: 55 | UniformDequantize: {} 56 | indist_val: 57 | dataset: MNIST_OOD 58 | path: datasets 59 | batch_size: 128 60 | n_workers: 8 61 | split: validation 62 | size: 28 63 | ood_val: 64 | dataset: Constant_OOD 65 | size: 28 66 | channel: 1 67 | path: datasets 68 | split: validation 69 | n_workers: 4 70 | batch_size: 128 71 | ood_target: 72 | dataset: Noise_OOD 73 | size: 28 74 | channel: 1 75 | path: datasets 76 | split: validation 77 | n_workers: 4 78 | batch_size: 128 79 | training: 80 | ae_epoch: 100 81 | nae_epoch: 50 82 | save_interval: 2000 83 | val_interval: 1000 84 | print_interval: 100 85 | ae_lr: 1e-4 86 | nae_lr: 1e-5 87 | fix_D: false 88 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/conftest.py -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/datasets/.gitkeep -------------------------------------------------------------------------------- /evaluate_ood.py: -------------------------------------------------------------------------------- 1 | """ 2 | evaluate OOD detection performance through AUROC score 3 | 4 | Example: 5 | python evaluate_cifar_ood.py --dataset FashionMNIST_OOD \ 6 | --ood MNIST_OOD,ConstantGray_OOD \ 7 | --resultdir results/fmnist_ood_vqvae/Z7K512/e300 \ 8 | --ckpt model_epoch_280.pkl \ 9 | --config Z7K512.yml \ 10 | --device 1 11 | """ 12 | import os 13 | import yaml 14 | import argparse 15 | import copy 16 | import torch 17 | import numpy as np 18 | from torch.utils import data 19 | from models import get_model, load_pretrained 20 | from loaders import get_dataloader 21 | 22 | from utils import roc_btw_arr, batch_run, search_params_intp, parse_unknown_args, parse_nested_args 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--resultdir', type=str, help='result dir. results/... or pretrained/...') 27 | parser.add_argument('--config', type=str, help='config file name') 28 | parser.add_argument('--ckpt', type=str, help='checkpoint file name to load. 
', default=None)
29 | parser.add_argument('--ood', type=str, help='list of OOD datasets, separated by comma')
30 | parser.add_argument('--device', type=str, help='device')
31 | parser.add_argument('--dataset', type=str, choices=['MNIST_OOD', 'CIFAR10_OOD', 'ImageNet32', 'FashionMNIST_OOD',
32 |                     'FashionMNISTpad_OOD'],
33 |                     default='MNIST', help='inlier dataset')
34 | parser.add_argument('--aug', type=str, help='pre-defined data augmentation', choices=[None, 'CIFAR10', 'CIFAR10-OE'])
35 | parser.add_argument('--method', type=str, choices=[None, 'likelihood_regret', 'input_complexity', 'outlier_exposure'])
36 | args, unknown = parser.parse_known_args()
37 | d_cmd_cfg = parse_unknown_args(unknown)
38 | d_cmd_cfg = parse_nested_args(d_cmd_cfg)
39 | print(d_cmd_cfg)
40 |
41 |
42 | # load config file
43 | cfg = yaml.load(open(os.path.join(args.resultdir, args.config)), Loader=yaml.FullLoader)
44 | result_dir = args.resultdir
45 | if args.ckpt is not None:
46 |     ckpt_file = os.path.join(result_dir, args.ckpt)
47 | else:
48 |     raise ValueError('ckpt file not specified')
49 |
50 | print(f'loading from {ckpt_file}')
51 | l_ood = [s.strip() for s in args.ood.split(',')]
52 | device = f'cuda:{args.device}'
53 |
54 |
55 |
56 |
57 | def evaluate(m, in_dl, out_dl, device):
58 |     """computes OOD detection score"""
59 |     in_pred = batch_run(m, in_dl, device, method='predict')
60 |     out_pred = batch_run(m, out_dl, device, method='predict')
61 |     auc = roc_btw_arr(out_pred, in_pred)
62 |     return auc
63 |
64 |
65 | # load dataset
66 | print('ood datasets')
67 | print(l_ood)
68 | if args.dataset in {'MNIST_OOD', 'FashionMNIST_OOD'}:
69 |     size = 28
70 |     channel = 1
71 | else:
72 |     size = 32
73 |     channel = 3
74 | data_dict = {'path': 'datasets',
75 |              'size': size,
76 |              'channel': channel,
77 |              'batch_size': 64,
78 |              'n_workers': 4,
79 |              'split': 'evaluation'}
80 |
81 |
82 |
83 | data_dict_ = copy.copy(data_dict)
84 | data_dict_['dataset'] = args.dataset
85 | in_dl = get_dataloader(data_dict_)
86 |
87 | l_ood_dl = []
88 | for ood_name in l_ood:
89 |     data_dict_ = copy.copy(data_dict)
90 |     data_dict_['dataset'] = ood_name
91 |     dl = get_dataloader(data_dict_)
92 |     l_ood_dl.append(dl)
93 |
94 | model = get_model(cfg).to(device)
95 | ckpt_data = torch.load(ckpt_file)
96 | if 'model_state' in ckpt_data:
97 |     model.load_state_dict(ckpt_data['model_state'])
98 | else:
99 |     model.load_state_dict(ckpt_data)  # checkpoint is already loaded above
100 |
101 | model.eval()
102 | model.to(device)
103 |
104 | in_pred = batch_run(model, in_dl, device=device, no_grad=False)
105 |
106 | l_ood_pred = []
107 | for dl in l_ood_dl:
108 |     out_pred = batch_run(model, dl, device=device, no_grad=False)
109 |     l_ood_pred.append(out_pred)
110 |
111 | l_ood_auc = []
112 | for pred in l_ood_pred:
113 |     l_ood_auc.append(roc_btw_arr(pred, in_pred))
114 |
115 | print('OOD Detection Results in AUC')
116 | for ds, auc in zip(l_ood, l_ood_auc):
117 |     print(f'{ds}:{auc:.4f}')
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import copy
 3 | import json
 4 | import torch
 5 | from torch.utils import data
 6 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop,\
 7 |     Pad, Normalize
 8 |
 9 | from loaders.leaveout_dataset import MNISTLeaveOut, CIFAR10LeaveOut
10 | from loaders.modified_dataset import Gray2RGB, MNIST_OOD, FashionMNIST_OOD, \
11 |     CIFAR10_OOD, SVHN_OOD, Constant_OOD, \
12 |     Noise_OOD, CIFAR100_OOD, CelebA_OOD, \
13 |     NotMNIST, ConstantGray_OOD, ImageNet32
14 | from loaders.chimera_dataset import Chimera
15 | from torchvision.datasets import FashionMNIST, Omniglot
16 | from augmentations import get_composed_augmentations
17 | from augmentations.augmentations import ToGray, Invert, Fragment
18 |
19 |
20 | OOD_SIZE = 32  # common image size for OOD detection experiments
21 |
22 |
23 | def get_dataloader(data_dict, mode=None, mode_dict=None, data_aug=None):
24 |     """constructs DataLoader
25 |     data_dict: data part of cfg
26 |
27 |     mode: deprecated argument
28 |     mode_dict: deprecated argument
29 |     data_aug: deprecated argument
30 |
31 |     Example data_dict
32 |         dataset: FashionMNISTpad_OOD
33 |         path: datasets
34 |         shuffle: True
35 |         batch_size: 128
36 |         n_workers: 8
37 |         split: training
38 |         dequant:
39 |             UniformDequantize: {}
40 |     """
41 |
42 |     # dataset loading
43 |     aug = get_composed_augmentations(data_dict.get('augmentations', None))
44 |     dequant = get_composed_augmentations(data_dict.get('dequant', None))
45 |     dataset = get_dataset(data_dict, split_type=None, data_aug=aug, dequant=dequant)
46 |
47 |     # dataloader loading
48 |     loader = data.DataLoader(
49 |         dataset,
50 |         batch_size=data_dict["batch_size"],
51 |         num_workers=data_dict["n_workers"],
52 |         shuffle=data_dict.get('shuffle', False),
53 |         pin_memory=False,
54 |     )
55 |
56 |     return loader
57 |
58 |
59 | def get_dataset(data_dict, split_type=None, data_aug=None, dequant=None):
60 |     """
61 |     split_type: deprecated argument
62 |     """
63 |     do_concat = any([k.startswith('concat') for k in data_dict.keys()])
64 |     if do_concat:
65 |         if data_aug is not None:
66 |             return data.ConcatDataset([get_dataset(d, data_aug=data_aug) for k, d in data_dict.items() if k.startswith('concat')])
67 |         elif dequant is not None:
68 |             return data.ConcatDataset([get_dataset(d, dequant=dequant) for k, d in data_dict.items() if k.startswith('concat')])
69 |         else: return data.ConcatDataset([get_dataset(d) for k, d in data_dict.items() if k.startswith('concat')])
70 |     name = data_dict["dataset"]
71 |     split_type = data_dict['split']
72 |     data_path = data_dict["path"][split_type] if split_type in data_dict["path"] else data_dict["path"]
73 |
74 |     # default transform behavior.
75 | original_data_aug = data_aug 76 | if data_aug is not None: 77 | #data_aug = Compose([data_aug, ToTensor()]) 78 | data_aug = Compose([ToTensor(), data_aug]) 79 | else: 80 | data_aug = ToTensor() 81 | 82 | if dequant is not None: # dequantization should be applied last 83 | data_aug = Compose([data_aug, dequant]) 84 | 85 | 86 | # datasets 87 | if name == 'MNISTLeaveOut': 88 | l_out_class = data_dict['out_class'] 89 | dataset = MNISTLeaveOut(data_path, l_out_class=l_out_class, split=split_type, download=True, 90 | transform=data_aug) 91 | elif name == 'MNISTLeaveOutFragment': 92 | l_out_class = data_dict['out_class'] 93 | fragment = data_dict['fragment'] 94 | dataset = MNISTLeaveOut(data_path, l_out_class=l_out_class, split=split_type, download=True, 95 | transform=Compose([ToTensor(), 96 | Fragment(fragment)])) 97 | elif name == 'MNIST_OOD': 98 | size = data_dict.get('size', 28) 99 | if size == 28: 100 | l_transform = [ToTensor()] 101 | else: 102 | l_transform = [Gray2RGB(), Resize(OOD_SIZE), ToTensor()] 103 | dataset = MNIST_OOD(data_path, split=split_type, download=True, 104 | transform=Compose(l_transform)) 105 | dataset.img_size = (size, size) 106 | 107 | elif name == 'MNISTpad_OOD': 108 | dataset = MNIST_OOD(data_path, split=split_type, download=True, 109 | transform=Compose([Gray2RGB(), 110 | Pad(2), 111 | ToTensor()])) 112 | dataset.img_size = (OOD_SIZE, OOD_SIZE) 113 | 114 | elif name == 'FashionMNIST_OOD': 115 | size = data_dict.get('size', 28) 116 | if size == 28: 117 | l_transform = [ToTensor()] 118 | else: 119 | l_transform = [Gray2RGB(), Resize(OOD_SIZE), ToTensor()] 120 | 121 | dataset = FashionMNIST_OOD(data_path, split=split_type, download=True, 122 | transform=Compose(l_transform)) 123 | dataset.img_size = (size, size) 124 | 125 | elif name == 'FashionMNISTpad_OOD': 126 | dataset = FashionMNIST_OOD(data_path, split=split_type, download=True, 127 | transform=Compose([Gray2RGB(), 128 | Pad(2), 129 | ToTensor()])) 130 | dataset.img_size = (OOD_SIZE, OOD_SIZE) 131 | 132 | elif name == 'HalfMNIST': 133 | mnist = MNIST_OOD(data_path, split=split_type, download=True, 134 | transform=ToTensor()) 135 | dataset = Chimera(mnist, mode='horizontal_blank') 136 | elif name == 'ChimeraMNIST': 137 | mnist = MNIST_OOD(data_path, split=split_type, download=True, 138 | transform=ToTensor()) 139 | dataset = Chimera(mnist, mode='horizontal') 140 | elif name == 'CIFAR10_OOD': 141 | dataset = CIFAR10_OOD(data_path, split=split_type, download=True, 142 | transform=data_aug) 143 | dataset.img_size = (OOD_SIZE, OOD_SIZE) 144 | 145 | elif name == 'CIFAR10LeaveOut': 146 | l_out_class = data_dict['out_class'] 147 | seed = data_dict.get('seed', 1) 148 | dataset = CIFAR10LeaveOut(data_path, l_out_class=l_out_class, split=split_type, download=True, 149 | transform=data_aug, seed=seed) 150 | 151 | elif name == 'CIFAR10_GRAY': 152 | dataset = CIFAR10_OOD(data_path, split=split_type, download=True, 153 | transform=Compose([ToTensor(), 154 | ToGray()])) 155 | dataset.img_size = (OOD_SIZE, OOD_SIZE) 156 | 157 | 158 | elif name == 'CIFAR100_OOD': 159 | dataset = CIFAR100_OOD(data_path, split=split_type, download=True, 160 | transform=ToTensor()) 161 | dataset.img_size = (OOD_SIZE, OOD_SIZE) 162 | 163 | elif name == 'SVHN_OOD': 164 | dataset = SVHN_OOD(data_path, split=split_type, download=True, 165 | transform=data_aug) 166 | dataset.img_size = (OOD_SIZE, OOD_SIZE) 167 | 168 | elif name == 'Constant_OOD': 169 | size = data_dict.get('size', OOD_SIZE) 170 | channel = data_dict.get('channel', 3) 171 | dataset = 
Constant_OOD(data_path, split=split_type, size=(size, size),
172 |                               channel=channel,
173 |                               transform=ToTensor())
174 |
175 |     elif name == 'ConstantGray_OOD':
176 |         size = data_dict.get('size', OOD_SIZE)
177 |         channel = data_dict.get('channel', 3)
178 |         dataset = ConstantGray_OOD(data_path, split=split_type, size=(size, size),
179 |                                    channel=channel,
180 |                                    transform=ToTensor())
181 |
182 |     elif name == 'Noise_OOD':
183 |         channel = data_dict.get('channel', 3)
184 |         size = data_dict.get('size', OOD_SIZE)
185 |         dataset = Noise_OOD(data_path, split=split_type,
186 |                             transform=ToTensor(), channel=channel, size=(size, size))
187 |
188 |     elif name == 'CelebA_OOD':
189 |         size = data_dict.get('size', OOD_SIZE)
190 |         l_aug = []
191 |         l_aug.append(CenterCrop(140))
192 |         l_aug.append(Resize(size))
193 |         if original_data_aug is not None:
194 |             l_aug.append(original_data_aug)
195 |         l_aug.append(ToTensor())
196 |         if dequant is not None:
197 |             l_aug.append(dequant)
198 |         data_aug = Compose(l_aug)
199 |         dataset = CelebA_OOD(data_path, split=split_type,
200 |                              transform=data_aug)
201 |         dataset.img_size = (OOD_SIZE, OOD_SIZE)
202 |
203 |     elif name == 'FashionMNIST':  # normal FashionMNIST
204 |         dataset = FashionMNIST_OOD(data_path, split=split_type, download=True,
205 |                                    transform=ToTensor())
206 |         dataset.img_size = (28, 28)
207 |     elif name == 'MNIST':  # normal MNIST
208 |         dataset = MNIST_OOD(data_path, split=split_type, download=True,
209 |                             transform=ToTensor())
210 |         dataset.img_size = (28, 28)
211 |     elif name == 'NotMNIST':
212 |         dataset = NotMNIST(data_path, split=split_type, transform=ToTensor())
213 |         dataset.img_size = (28, 28)
214 |     elif name == 'Omniglot':
215 |         size = data_dict.get('size', OOD_SIZE)
216 |         invert = data_dict.get('invert', True)  # invert pixel intensity: x -> 1 - x
217 |         if split_type == 'training':
218 |             background = True
219 |         else:
220 |             background = False
221 |
222 |         if invert:
223 |             tr = Compose([Resize(size), ToTensor(), Invert()])
224 |         else:
225 |             tr = Compose([Resize(size), ToTensor()])
226 |
227 |         dataset = Omniglot(data_path, background=background, download=False,
228 |                            transform=tr)
229 |     elif name == 'ImageNet32':
230 |         train_split_ratio = data_dict.get('train_split_ratio', 0.8)
231 |         seed = data_dict.get('seed', 1)
232 |         dataset = ImageNet32(data_path, split=split_type, transform=ToTensor(), seed=seed,
233 |                              train_split_ratio=train_split_ratio)
234 |     else:
235 |         n_classes = data_dict["n_classes"]
236 |         split = data_dict['split'][split_type]
237 |
238 |         param_dict = copy.deepcopy(data_dict)
239 |         param_dict.pop("dataset")
240 |         param_dict.pop("path")
241 |         param_dict.pop("n_classes")
242 |         param_dict.pop("split")
243 |         param_dict.update({"split_type": split_type})
244 |
245 |
246 |         dataset_instance = _get_dataset_instance(name)
247 |         dataset = dataset_instance(
248 |             data_path,
249 |             n_classes,
250 |             split=split,
251 |             augmentations=data_aug,
252 |             is_transform=True,
253 |             **param_dict,
254 |         )
255 |
256 |     return dataset
257 |
258 |
259 | def _get_dataset_instance(name):
260 |     """get_loader
261 |
262 |     :param name:
263 |     """
264 |     return {  # note: basic_dataset and InMemoryDataset are not defined in this module
265 |         "basic": basic_dataset,
266 |         "inmemory": InMemoryDataset,
267 |     }[name]
268 |
269 |
270 |
271 | def np_to_loader(l_tensors, batch_size, num_workers, load_all=False, shuffle=False):
272 |     '''Convert a list of numpy arrays to a torch.DataLoader'''
273 |     if load_all:
274 |         dataset = data.TensorDataset(*[torch.Tensor(X).cuda() for X in l_tensors])
275 |         num_workers = 0
276 |         return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
277 |     else:
278 |         dataset = data.TensorDataset(*[torch.Tensor(X) for X in l_tensors])
279 |         return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle, pin_memory=False)
--------------------------------------------------------------------------------
/loaders/chimera_dataset.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import Dataset
 2 |
 3 |
 4 | class Chimera(Dataset):
 5 |     def __init__(self, ds, mode='horizontal'):
 6 |         self.ds = ds
 7 |         self.mode = mode
 8 |
 9 |     def __len__(self):
10 |         return len(self.ds)
11 |
12 |     def __getitem__(self, i):
13 |         img, lbl = self.ds[i]
14 |         if self.mode == 'horizontal':
15 |             half = (img.shape[2]) // 2
16 |             img2, lbl2 = self.select_diff_digit(lbl, i)
17 |
18 |             img[:, :half, :] = img2[:, :half, :]
19 |
20 |             img[:, half-3:half+3, :] = 0
21 |
22 |             return img, f'{lbl},{lbl2}'
23 |
24 |         elif self.mode == 'horizontal_blank':
25 |             half = (img.shape[2]) // 2
26 |             if i % 2 == 0:
27 |                 img[:, :half, :] = 0.
28 |             else:
29 |                 img[:, half:, :] = 0.
30 |             return img, lbl
31 |         else:
32 |             raise ValueError
33 |
34 |     def select_diff_digit(self, digit, i):
35 |         new_digit = digit
36 |         idx = i
37 |         while new_digit == digit:
38 |             idx += 1
39 |             idx = idx % len(self.ds)
40 |             # idx = np.random.randint(len(self.ds))
41 |             img, lbl = self.ds[idx]
42 |             new_digit = lbl
43 |         return img, lbl
44 |
--------------------------------------------------------------------------------
/loaders/leaveout_dataset.py:
--------------------------------------------------------------------------------
 1 | '''Dataset Class'''
 2 | import os
 3 | import sys
 4 | import pickle
 5 | import numpy as np
 6 | import torch
 7 | from torchvision.datasets.mnist import MNIST
 8 | from torchvision.datasets.cifar import CIFAR10
 9 | from utils import get_shuffled_idx
10 |
11 |
12 | class MNISTLeaveOut(MNIST):
13 |     """
14 |     MNIST Dataset with some digits excluded.
15 |
16 |     Digits in l_out_class are excluded entirely; targets are the remaining digit labels.
17 |
18 |     See also the original MNIST class:
19 |     https://pytorch.org/docs/stable/_modules/torchvision/datasets/mnist.html#MNIST
20 |     """
21 |     img_size = (28, 28)
22 |
23 |     def __init__(self, root, l_out_class, split='training', transform=None, target_transform=None,
24 |                  download=False):
25 |         """
26 |         l_out_class : a list of ints. these classes are excluded in training
27 |         """
28 |         super(MNISTLeaveOut, self).__init__(root, transform=transform,
29 |                                             target_transform=target_transform, download=download)
30 |         if split == 'training' or split == 'validation':
31 |             self.train = True  # training set or test set
32 |         else:
33 |             self.train = False
34 |         self.split = split
35 |         self.l_out_class = list(l_out_class)
36 |         for c in l_out_class:
37 |             assert c in set(list(range(10)))
38 |         set_out_class = set(l_out_class)
39 |
40 |         if download:
41 |             self.download()
42 |
43 |         if not self._check_exists():
44 |             raise RuntimeError('Dataset not found.'
+ 45 | ' You can use download=True to download it') 46 | 47 | if self.train: 48 | data_file = self.training_file 49 | else: 50 | data_file = self.test_file 51 | 52 | data, targets = torch.load(os.path.join(self.processed_folder, data_file)) 53 | 54 | if split == 'training': 55 | data = data[:50000] 56 | targets = targets[:50000] 57 | elif split == 'validation': 58 | data = data[50000:] 59 | targets = targets[50000:] 60 | 61 | out_idx = torch.zeros(len(data), dtype=torch.bool) # pytorch 1.2 62 | for c in l_out_class: 63 | out_idx = out_idx | (targets == c) 64 | 65 | self.data = data[~out_idx] 66 | self.digits = targets[~out_idx] 67 | self.targets = self.digits 68 | 69 | @property 70 | def raw_folder(self): 71 | return os.path.join(self.root, 'MNIST', 'raw') 72 | 73 | @property 74 | def processed_folder(self): 75 | return os.path.join(self.root, 'MNIST', 'processed') 76 | 77 | 78 | class CIFAR10LeaveOut(CIFAR10): 79 | def __init__(self, root, l_out_class, split='training', transform=None, target_transform=None, 80 | download=True, seed=1): 81 | 82 | super(CIFAR10LeaveOut, self).__init__(root, transform=transform, 83 | target_transform=target_transform, download=True) 84 | assert split in ('training', 'validation', 'evaluation') 85 | 86 | if split == 'training' or split == 'validation': 87 | self.train = True 88 | shuffle_idx = get_shuffled_idx(50000, seed) 89 | else: 90 | self.train = False 91 | self.split = split 92 | 93 | if download: 94 | self.download() 95 | 96 | if not self._check_integrity(): 97 | raise RuntimeError('Dataset not found or corrupted.' + 98 | ' You can use download=True to download it') 99 | 100 | if self.train: 101 | downloaded_list = self.train_list 102 | else: 103 | downloaded_list = self.test_list 104 | 105 | self.data = [] 106 | self.targets = [] 107 | 108 | # now load the picked numpy arrays 109 | for file_name, checksum in downloaded_list: 110 | file_path = os.path.join(self.root, self.base_folder, file_name) 111 | with open(file_path, 'rb') as f: 112 | if sys.version_info[0] == 2: 113 | entry = pickle.load(f) 114 | else: 115 | entry = pickle.load(f, encoding='latin1') 116 | self.data.append(entry['data']) 117 | if 'labels' in entry: 118 | self.targets.extend(entry['labels']) 119 | else: 120 | self.targets.extend(entry['fine_labels']) 121 | 122 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 123 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 124 | self.targets = np.array(self.targets) 125 | 126 | if split == 'training': 127 | self.data = self.data[shuffle_idx][:45000] 128 | self.targets = self.targets[shuffle_idx][:45000] 129 | elif split == 'validation': 130 | self.data = self.data[shuffle_idx][45000:] 131 | self.targets = self.targets[shuffle_idx][45000:] 132 | 133 | out_idx = torch.zeros(len(self.data), dtype=torch.bool) 134 | 135 | for c in l_out_class: 136 | out_idx = out_idx | (self.targets == c) 137 | out_idx = out_idx.bool() 138 | 139 | self.data = self.data[~out_idx] 140 | self.targets = self.targets[~out_idx] 141 | self._load_meta() 142 | -------------------------------------------------------------------------------- /loaders/modified_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified_dataset.py 3 | =================== 4 | Inherited and modified pytorch datasets 5 | """ 6 | import os 7 | import re 8 | import sys 9 | import numpy as np 10 | from scipy.io import loadmat 11 | import pickle 12 | from PIL import Image 13 | import torch 14 | from torch.utils.data import 
Dataset 15 | from torchvision.datasets.mnist import MNIST, FashionMNIST 16 | from torchvision.datasets.cifar import CIFAR10 17 | from torchvision.datasets.svhn import SVHN 18 | from torchvision.datasets.utils import verify_str_arg 19 | from utils import get_shuffled_idx 20 | 21 | 22 | class Gray2RGB: 23 | """change grayscale PIL image to RGB format. 24 | channel values are copied""" 25 | def __call__(self, x): 26 | return x.convert('RGB') 27 | 28 | 29 | class MNIST_OOD(MNIST): 30 | """ 31 | See also the original MNIST class: 32 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/mnist.html#MNIST 33 | """ 34 | def __init__(self, root, split='training', transform=None, target_transform=None, 35 | download=False, seed=1): 36 | super(MNIST_OOD, self).__init__(root, transform=transform, 37 | target_transform=target_transform, download=download) 38 | assert split in ('training', 'validation', 'evaluation') 39 | if split == 'training' or split == 'validation': 40 | self.train = True 41 | shuffle_idx = get_shuffled_idx(60000, seed) 42 | else: 43 | self.train = False 44 | self.split = split 45 | 46 | if download: 47 | self.download() 48 | 49 | if not self._check_exists(): 50 | raise RuntimeError('Dataset not found.' + 51 | ' You can use download=True to download it') 52 | 53 | if self.train: 54 | data_file = self.training_file 55 | else: 56 | data_file = self.test_file 57 | 58 | data, targets = self.data, self.targets 59 | 60 | if split == 'training': 61 | self.data = data[shuffle_idx][:54000] 62 | self.targets = targets[shuffle_idx][:54000] 63 | elif split == 'validation': 64 | self.data = data[shuffle_idx][54000:] 65 | self.targets = targets[shuffle_idx][54000:] 66 | elif split == 'evaluation': 67 | self.data = data 68 | self.targets = targets 69 | 70 | 71 | class FashionMNIST_OOD(FashionMNIST): 72 | def __init__(self, root, split='training', transform=None, target_transform=None, 73 | download=False, seed=1): 74 | super(FashionMNIST_OOD, self).__init__(root, transform=transform, 75 | target_transform=target_transform, download=download) 76 | assert split in ('training', 'validation', 'evaluation') 77 | if split == 'training' or split == 'validation': 78 | self.train = True 79 | shuffle_idx = get_shuffled_idx(60000, seed) 80 | else: 81 | self.train = False 82 | self.split = split 83 | 84 | if download: 85 | self.download() 86 | 87 | if not self._check_exists(): 88 | raise RuntimeError('Dataset not found.' 
+ 89 | ' You can use download=True to download it') 90 | 91 | if self.train: 92 | data_file = self.training_file 93 | else: 94 | data_file = self.test_file 95 | 96 | data, targets = self.data, self.targets 97 | 98 | if split == 'training': 99 | self.data = data[shuffle_idx][:54000] 100 | self.targets = targets[shuffle_idx][:54000] 101 | elif split == 'validation': 102 | self.data = data[shuffle_idx][54000:] 103 | self.targets = targets[shuffle_idx][54000:] 104 | elif split == 'evaluation': 105 | self.data = data 106 | self.targets = targets 107 | 108 | 109 | class CIFAR10_OOD(CIFAR10): 110 | def __init__(self, root, split='training', transform=None, target_transform=None, 111 | download=True): 112 | 113 | super(CIFAR10_OOD, self).__init__(root, transform=transform, 114 | target_transform=target_transform, download=True) 115 | assert split in ('training', 'validation', 'evaluation', 'training_full') 116 | 117 | if split == 'training' or split == 'validation' or split == 'training_full': 118 | self.train = True 119 | shuffle_idx = np.load(os.path.join(root, 'cifar10_trainval_idx.npy')) 120 | else: 121 | self.train = False 122 | self.split = split 123 | 124 | if download: 125 | self.download() 126 | 127 | if not self._check_integrity(): 128 | raise RuntimeError('Dataset not found or corrupted.' + 129 | ' You can use download=True to download it') 130 | 131 | if self.train: 132 | downloaded_list = self.train_list 133 | else: 134 | downloaded_list = self.test_list 135 | 136 | self.data = [] 137 | self.targets = [] 138 | 139 | # now load the picked numpy arrays 140 | for file_name, checksum in downloaded_list: 141 | file_path = os.path.join(self.root, self.base_folder, file_name) 142 | with open(file_path, 'rb') as f: 143 | if sys.version_info[0] == 2: 144 | entry = pickle.load(f) 145 | else: 146 | entry = pickle.load(f, encoding='latin1') 147 | self.data.append(entry['data']) 148 | if 'labels' in entry: 149 | self.targets.extend(entry['labels']) 150 | else: 151 | self.targets.extend(entry['fine_labels']) 152 | 153 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 154 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 155 | self.targets = np.array(self.targets) 156 | 157 | if split == 'training': 158 | self.data = self.data[shuffle_idx][:45000] 159 | self.targets = self.targets[shuffle_idx][:45000] 160 | elif split == 'validation': 161 | self.data = self.data[shuffle_idx][45000:] 162 | self.targets = self.targets[shuffle_idx][45000:] 163 | elif split == 'training_full': 164 | pass 165 | 166 | self._load_meta() 167 | 168 | 169 | class CIFAR100_OOD(CIFAR10_OOD): 170 | base_folder = 'cifar-100-python' 171 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 172 | filename = "cifar-100-python.tar.gz" 173 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 174 | train_list = [ 175 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 176 | ] 177 | 178 | test_list = [ 179 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 180 | ] 181 | meta = { 182 | 'filename': 'meta', 183 | 'key': 'fine_label_names', 184 | 'md5': '7973b15100ade9c7d40fb424638fde48', 185 | } 186 | 187 | 188 | class SVHN_OOD(SVHN): 189 | 190 | def __init__(self, root, split='training', transform=None, target_transform=None, 191 | download=False): 192 | super(SVHN, self).__init__(root, transform=transform, 193 | target_transform=target_transform) 194 | assert split in ('training', 'validation', 'evaluation') 195 | if split == 'training' or split == 'validation': 196 | svhn_split = 'train' 197 | shuffle_idx = 
np.load(os.path.join(root, 'svhn_trainval_idx.npy')) 198 | else: 199 | svhn_split = 'test' 200 | 201 | # self.split = verify_str_arg(svhn_split, "split", tuple(self.split_list.keys())) 202 | self.split = svhn_split # special treatment 203 | self.url = self.split_list[svhn_split][0] 204 | self.filename = self.split_list[svhn_split][1] 205 | self.file_md5 = self.split_list[svhn_split][2] 206 | 207 | if download: 208 | self.download() 209 | 210 | if not self._check_integrity(): 211 | raise RuntimeError('Dataset not found or corrupted.' + 212 | ' You can use download=True to download it') 213 | 214 | # import here rather than at top of file because this is 215 | # an optional dependency for torchvision 216 | import scipy.io as sio 217 | 218 | # reading(loading) mat file as array 219 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 220 | 221 | self.data = loaded_mat['X'] 222 | # loading from the .mat file gives an np array of type np.uint8 223 | # converting to np.int64, so that we have a LongTensor after 224 | # the conversion from the numpy array 225 | # the squeeze is needed to obtain a 1D tensor 226 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 227 | 228 | # the svhn dataset assigns the class label "10" to the digit 0 229 | # this makes it inconsistent with several loss functions 230 | # which expect the class labels to be in the range [0, C-1] 231 | np.place(self.labels, self.labels == 10, 0) 232 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 233 | 234 | if split == 'training': 235 | self.data = self.data[shuffle_idx][:65930] 236 | self.labels = self.labels[shuffle_idx][:65930] 237 | elif split == 'validation': 238 | self.data = self.data[shuffle_idx][65930:] 239 | self.labels = self.labels[shuffle_idx][65930:] 240 | 241 | 242 | class Constant_OOD(Dataset): 243 | def __init__(self, root, split='training', size=(32, 32), transform=None, channel=3): 244 | super(Constant_OOD, self).__init__() 245 | assert split in ('training', 'validation', 'evaluation') 246 | self.transform = transform 247 | self.root = root 248 | self.img_size = size 249 | self.channel = channel 250 | self.const_vals = np.load(os.path.join(root, 'const_img.npy')) # (40,000, 3) array 251 | 252 | if split == 'training': 253 | self.const_vals = self.const_vals[:32000] 254 | elif split == 'validation': 255 | self.const_vals = self.const_vals[32000:36000] 256 | elif split == 'evaluation': 257 | self.const_vals = self.const_vals[36000:] 258 | 259 | def __getitem__(self, index): 260 | img = np.ones(self.img_size + (self.channel,), dtype=np.float32) * self.const_vals[index] / 255 # (H, W, C) 261 | img = img.astype(np.float32) 262 | if self.channel == 1: 263 | img = img[:, :, 0:1] 264 | 265 | if self.transform is not None: 266 | img = self.transform(img) 267 | return img, 0 268 | 269 | def __len__(self): 270 | return len(self.const_vals) 271 | 272 | 273 | class ConstantGray_OOD(Dataset): 274 | def __init__(self, root, split='training', size=(32, 32), transform=None, channel=3): 275 | super(ConstantGray_OOD, self).__init__() 276 | assert split in ('training', 'validation', 'evaluation') 277 | self.transform = transform 278 | self.root = root 279 | self.img_size = size 280 | self.channel = channel 281 | self.const_vals = np.load(os.path.join(root, 'const_img_gray.npy')) # (40,000,) array 282 | 283 | if split == 'training': 284 | self.const_vals = self.const_vals[:32000] 285 | elif split == 'validation': 286 | self.const_vals = self.const_vals[32000:36000] 287 | elif split == 'evaluation': 288 | 
self.const_vals = self.const_vals[36000:] 289 | 290 | def __getitem__(self, index): 291 | img = np.ones(self.img_size + (self.channel,), dtype=np.float32) * self.const_vals[index] / 255 # (H, W, C) 292 | img = img.astype(np.float32) 293 | 294 | if self.transform is not None: 295 | img = self.transform(img) 296 | return img, 0 297 | 298 | def __len__(self): 299 | return len(self.const_vals) 300 | 301 | 302 | class Noise_OOD(Dataset): 303 | def __init__(self, root, split='training', transform=None, channel=3, size=(32,32)): 304 | super(Noise_OOD, self).__init__() 305 | assert split in ('training', 'validation', 'evaluation') 306 | self.transform = transform 307 | self.root = root 308 | self.vals = np.load(os.path.join(root, 'noise_img.npy')) # (40000, 32, 32, 3) array 309 | self.channel = channel 310 | self.size = size 311 | 312 | if split == 'training': 313 | self.vals = self.vals[:32000] 314 | elif split == 'validation': 315 | self.vals = self.vals[32000:36000] 316 | elif split == 'evaluation': 317 | self.vals = self.vals[36000:] 318 | 319 | def __getitem__(self, index): 320 | img = self.vals[index] / 255 321 | 322 | img = img.astype(np.float32) 323 | if self.channel == 1: 324 | img = img[:, :, 0:1] 325 | if self.size != (32, 32): 326 | img = img[:self.size[0], :self.size[1], :] 327 | if self.transform is not None: 328 | img = self.transform(img) 329 | return img, 0 330 | 331 | def __len__(self): 332 | return len(self.vals) 333 | 334 | 335 | # 336 | # CelebA 337 | # 338 | IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".bmp"] 339 | ATTR_ANNO = "list_attr_celeba.txt" 340 | 341 | 342 | def _is_image(fname): 343 | _, ext = os.path.splitext(fname) 344 | return ext.lower() in IMAGE_EXTENSIONS 345 | 346 | 347 | def _find_images_and_annotation(root_dir): 348 | images = {} 349 | attr = None 350 | assert os.path.exists(root_dir), "{} does not exist".format(root_dir) 351 | img_dir = os.path.join(root_dir, 'Img/img_align_celeba') 352 | for fname in os.listdir(img_dir): 353 | if _is_image(fname): 354 | path = os.path.join(img_dir, fname) 355 | images[os.path.splitext(fname)[0]] = path 356 | 357 | attr = os.path.join(root_dir, 'Anno', ATTR_ANNO) 358 | assert attr is not None, "Failed to find `list_attr_celeba.txt`" 359 | 360 | # parse all image annotations 361 | final = [] 362 | with open(attr, "r") as fin: 363 | image_total = 0 364 | attrs = [] 365 | for i_line, line in enumerate(fin): 366 | line = line.strip() 367 | if i_line == 0: 368 | image_total = int(line) 369 | elif i_line == 1: 370 | attrs = line.split(" ") 371 | else: 372 | line = re.sub("[ ]+", " ", line) 373 | line = line.split(" ") 374 | fname = os.path.splitext(line[0])[0] 375 | onehot = [int(int(d) > 0) for d in line[1:]] 376 | assert len(onehot) == len(attrs), "{} only has {} attrs < {}".format( 377 | fname, len(onehot), len(attrs)) 378 | final.append({ 379 | "path": images[fname], 380 | "attr": onehot 381 | }) 382 | print("Found {} images with {} attrs".format(len(final), len(attrs))) 383 | return final, attrs 384 | 385 | 386 | class CelebA_OOD(Dataset): 387 | def __init__(self, root_dir, split='training', transform=None, seed=1): 388 | """attributes are not implemented""" 389 | super().__init__() 390 | assert split in ('training', 'validation', 'evaluation') 391 | if split == 'training': 392 | setnum = 0 393 | elif split == 'validation': 394 | setnum = 1 395 | elif split == 'evaluation': 396 | setnum = 2 397 | else: 398 | raise ValueError(f'Unexpected split {split}') 399 | 400 | d_split = self.read_split_file(root_dir) 401 | self.data = 
d_split[setnum] 402 | self.transform = transform 403 | self.split = split 404 | self.root_dir = os.path.join(root_dir, 'CelebA', 'Img', 'img_align_celeba') 405 | 406 | 407 | def __getitem__(self, index): 408 | filename = self.data[index] 409 | path = os.path.join(self.root_dir, filename) 410 | image = Image.open(path).convert("RGB") 411 | if self.transform is not None: 412 | image = self.transform(image) 413 | return image, 0. 414 | 415 | def __len__(self): 416 | return len(self.data) 417 | 418 | def read_split_file(self, root_dir): 419 | split_path = os.path.join(root_dir, 'CelebA', 'Eval', 'list_eval_partition.txt') 420 | d_split = {0:[], 1:[], 2:[]} 421 | with open(split_path) as f: 422 | for line in f: 423 | fname, setnum = line.strip().split() 424 | d_split[int(setnum)].append(fname) 425 | return d_split 426 | 427 | 428 | class NotMNIST(Dataset): 429 | def __init__(self, root_dir, split='training', transform=None): 430 | super().__init__() 431 | self.transform = transform 432 | shuffle_idx = np.load(os.path.join(root_dir, 'notmnist_trainval_idx.npy')) 433 | datadict = loadmat(os.path.join(root_dir, 'NotMNIST/notMNIST_small.mat')) 434 | data = datadict['images'].transpose(2, 0, 1).astype('float32') 435 | data = data[shuffle_idx] 436 | targets = datadict['labels'].astype('float32') 437 | targets = targets[shuffle_idx] 438 | if split == 'training': 439 | self.data = data[:14979] 440 | self.targets = targets[:14979] 441 | elif split == 'validation': 442 | self.data = data[14979:16851] 443 | self.targets = targets[14979:16851] 444 | elif split == 'evaluation': 445 | self.data = data[16851:] 446 | self.targets = targets[16851:] 447 | else: 448 | raise ValueError 449 | 450 | def __getitem__(self, index): 451 | image = self.data[index] 452 | if self.transform is not None: 453 | image = self.transform(image) 454 | return image, self.targets[index] 455 | 456 | def __len__(self): 457 | return len(self.data) 458 | 459 | 460 | class ImageNet32(Dataset): 461 | def __init__(self, root_dir, split='training', transform=None, seed=1, train_split_ratio=0.8): 462 | """ 463 | split: 'training' - the whole train split (1281149) 464 | 'evaluation' - the whole val split (49999) 465 | 'train_train' - (train_split_ratio) portion of train split 466 | 'train_val' - (1 - train_split_ratio) portion of train split 467 | """ 468 | super().__init__() 469 | self.root_dir = os.path.join(root_dir, 'ImageNet32') 470 | self.split = split 471 | self.transform = transform 472 | self.shuffle_idx = get_shuffled_idx(1281149, seed) 473 | n_train = int(len(self.shuffle_idx) * train_split_ratio) 474 | 475 | if split == 'training': # whole train split 476 | self.imgdir = os.path.join(self.root_dir, 'train_32x32') 477 | self.l_img_file = sorted(os.listdir(self.imgdir)) 478 | elif split == 'evaluation': # whole val split 479 | self.imgdir = os.path.join(self.root_dir, 'valid_32x32') 480 | self.l_img_file = sorted(os.listdir(self.imgdir)) 481 | elif split == 'train_train': # 80 % of train split 482 | self.imgdir = os.path.join(self.root_dir, 'train_32x32') 483 | self.l_img_file = sorted(os.listdir(self.imgdir)) 484 | self.l_img_file = [self.l_img_file[i] for i in self.shuffle_idx[:n_train]] 485 | elif split == 'train_val': # 20 % of train split 486 | self.imgdir = os.path.join(self.root_dir, 'train_32x32') 487 | self.l_img_file = sorted(os.listdir(self.imgdir)) 488 | self.l_img_file = [self.l_img_file[i] for i in self.shuffle_idx[n_train:]] 489 | else: 490 | raise ValueError(f'{split}') 491 | 492 | def __getitem__(self, index): 493 | 
imgpath = os.path.join(self.imgdir, self.l_img_file[index]) 494 | im = Image.open(imgpath) 495 | if self.transform is not None: 496 | im = self.transform(im) 497 | return im, 0 498 | 499 | def __len__(self): 500 | return len(self.l_img_file) 501 | -------------------------------------------------------------------------------- /loaders/synthetic.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Synthetic distributions from https://github.com/nicola-decao/BNAF/ 3 | ''' 4 | import torch 5 | import numpy as np 6 | 7 | def sample2d(data, batch_size=200): 8 | rng = np.random.RandomState() 9 | 10 | if data == '8gaussians': 11 | scale = 4. 12 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 13 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 14 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))] 15 | centers = [(scale * x, scale * y) for x, y in centers] 16 | 17 | dataset = [] 18 | for i in range(batch_size): 19 | point = rng.randn(2) * 0.5 20 | idx = rng.randint(8) 21 | center = centers[idx] 22 | point[0] += center[0] 23 | point[1] += center[1] 24 | dataset.append(point) 25 | dataset = np.array(dataset, dtype='float32') 26 | dataset /= 1.414 27 | return dataset 28 | 29 | elif data == '2spirals': 30 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 31 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 32 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 33 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 34 | x += np.random.randn(*x.shape) * 0.1 35 | return x 36 | 37 | elif data == 'checkerboard': 38 | x1 = np.random.rand(batch_size) * 4 - 2 39 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 40 | x2 = x2_ + (np.floor(x1) % 2) 41 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 42 | 43 | else: 44 | raise RuntimeError 45 | 46 | 47 | 48 | def energy2d(data, z): 49 | 50 | if data == 't1': 51 | return U1(z) 52 | elif data == 't2': 53 | return U2(z) 54 | elif data == 't3': 55 | return U3(z) 56 | elif data == 't4': 57 | return U4(z) 58 | else: 59 | raise RuntimeError 60 | 61 | def w1(z): 62 | return torch.sin(2 * np.pi * z[:, 0] / 4) 63 | 64 | def w2(z): 65 | return 3 * torch.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2) 66 | 67 | def w3(z): 68 | return 3 * torch.sigmoid((z[:, 0] - 1) / 0.3) 69 | 70 | def U1(z): 71 | z_norm = torch.norm(z, 2, 1) 72 | add1 = 0.5 * ((z_norm - 2) / 0.4) ** 2 73 | add2 = - torch.log(torch.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) +\ 74 | torch.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) + 1e-9) 75 | 76 | return add1 + add2 77 | 78 | def U2(z): 79 | return 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2 80 | 81 | def U3(z): 82 | in1 = torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.35) ** 2) 83 | in2 = torch.exp(-0.5 * ((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2) 84 | return -torch.log(in1 + in2 + 1e-9) 85 | 86 | def U4(z): 87 | in1 = torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2) 88 | in2 = torch.exp(-0.5 * ((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2) 89 | return -torch.log(in1 + in2 + 1e-9) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from score written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | from sklearn.metrics import roc_auc_score 6 | 7 | 8 | class runningScore(object): 9 | def 
__init__(self, n_classes): 10 | self.n_classes = n_classes 11 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 12 | 13 | def _fast_hist(self, label_true, label_pred, n_class): 14 | mask = (label_true >= 0) & (label_true < n_class) 15 | hist = np.bincount( 16 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 17 | ).reshape(n_class, n_class) 18 | return hist 19 | 20 | def update(self, label_trues, label_preds): 21 | for lt, lp in zip(label_trues, label_preds): 22 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 23 | 24 | def get_scores(self): 25 | """Returns accuracy score evaluation result. 26 | - overall accuracy 27 | - mean accuracy 28 | - mean IU 29 | - fwavacc 30 | """ 31 | hist = self.confusion_matrix 32 | acc = np.diag(hist).sum() / hist.sum() 33 | acc_cls = np.diag(hist) / hist.sum(axis=1) 34 | acc_cls = np.nanmean(acc_cls) 35 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 36 | mean_iu = np.nanmean(iu) 37 | freq = hist.sum(axis=1) / hist.sum() 38 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 39 | cls_iu = dict(zip(range(self.n_classes), iu)) 40 | 41 | return ( 42 | { 43 | "Overall Acc: \t": acc, 44 | "Mean Acc : \t": acc_cls, 45 | "FreqW Acc : \t": fwavacc, 46 | "Mean IoU : \t": mean_iu, 47 | }, 48 | cls_iu 49 | ) 50 | 51 | def reset(self): 52 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 53 | 54 | 55 | class runningScore_cls(object): 56 | def __init__(self, n_classes): 57 | self.n_classes = n_classes 58 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 59 | 60 | def _fast_hist(self, label_true, label_pred, n_class): 61 | mask = (label_true >= 0) & (label_true < n_class) 62 | hist = np.bincount( 63 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 64 | ).reshape(n_class, n_class) 65 | return hist 66 | 67 | def update(self, label_trues, label_preds): 68 | for lt, lp in zip(label_trues, label_preds): 69 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 70 | 71 | def get_scores(self): 72 | """Returns accuracy score evaluation result. 
73 | - overall accuracy 74 | - mean accuracy 75 | - mean IU 76 | - fwavacc 77 | """ 78 | hist = self.confusion_matrix 79 | acc = np.diag(hist).sum() / hist.sum() 80 | acc_cls = np.diag(hist) / hist.sum(axis=1) 81 | acc_cls = np.nanmean(acc_cls) 82 | 83 | 84 | return {"Overall Acc : \t": acc, 85 | "Mean Acc : \t": acc_cls,} 86 | 87 | 88 | def reset(self): 89 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 90 | 91 | 92 | class AUC: 93 | def __init__(self): 94 | self.l_label = [] 95 | self.l_pred = [] 96 | 97 | def update(self, label_trues, label_preds): 98 | for lt, lp in zip(label_trues, label_preds): 99 | self.l_label.append(lt) 100 | self.l_pred.append(lp) 101 | 102 | def reset(self): 103 | self.l_label = [] 104 | self.l_pred = [] 105 | 106 | def get_scores(self): 107 | auc = roc_auc_score(self.l_label, self.l_pred) 108 | return {'AUROC: \t': auc} 109 | 110 | 111 | class averageMeter(object): 112 | """Computes and stores the average and current value""" 113 | 114 | def __init__(self): 115 | self.reset() 116 | 117 | def reset(self): 118 | self.val = 0 119 | self.avg = 0 120 | self.sum = 0 121 | self.count = 0 122 | 123 | def update(self, val, n=1): 124 | self.val = val 125 | self.sum += val * n 126 | self.count += n 127 | self.avg = self.sum / self.count 128 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | import copy 4 | import torch 5 | import torchvision.models as models 6 | import torchvision.transforms as transforms 7 | from augmentations import get_composed_augmentations 8 | 9 | from models.ae import ( 10 | AE, 11 | VAE, 12 | DAE, 13 | WAE, 14 | ) 15 | from models.nae import NAE, NAE_L2_OMI 16 | from models.mcmc import get_sampler 17 | from models.modules import ( 18 | DeConvNet2, 19 | FCNet, 20 | ConvNet2FC, 21 | ConvMLP, 22 | IGEBMEncoder, 23 | ConvNet64, 24 | DeConvNet64, 25 | ) 26 | from models.modules_sngan import Generator as SNGANGeneratorBN 27 | from models.modules_sngan import GeneratorNoBN as SNGANGeneratorNoBN 28 | from models.modules_sngan import GeneratorNoBN64 as SNGANGeneratorNoBN64 29 | from models.modules_sngan import GeneratorGN as SNGANGeneratorGN 30 | from models.energybased import EnergyBasedModel 31 | 32 | 33 | 34 | def get_net(in_dim, out_dim, **kwargs): 35 | nh = kwargs.get("nh", 8) 36 | out_activation = kwargs.get("out_activation", "linear") 37 | 38 | if kwargs["arch"] == "conv2fc": 39 | nh_mlp = kwargs["nh_mlp"] 40 | net = ConvNet2FC( 41 | in_chan=in_dim, 42 | out_chan=out_dim, 43 | nh=nh, 44 | nh_mlp=nh_mlp, 45 | out_activation=out_activation, 46 | ) 47 | 48 | elif kwargs["arch"] == "deconv2": 49 | net = DeConvNet2( 50 | in_chan=in_dim, out_chan=out_dim, nh=nh, out_activation=out_activation 51 | ) 52 | elif kwargs["arch"] == "conv64": 53 | num_groups = kwargs.get("num_groups", None) 54 | use_bn = kwargs.get("use_bn", False) 55 | net = ConvNet64( 56 | in_chan=in_dim, 57 | out_chan=out_dim, 58 | nh=nh, 59 | out_activation=out_activation, 60 | num_groups=num_groups, 61 | use_bn=use_bn, 62 | ) 63 | elif kwargs["arch"] == "deconv64": 64 | num_groups = kwargs.get("num_groups", None) 65 | use_bn = kwargs.get("use_bn", False) 66 | net = DeConvNet64( 67 | in_chan=in_dim, 68 | out_chan=out_dim, 69 | nh=nh, 70 | out_activation=out_activation, 71 | num_groups=num_groups, 72 | use_bn=use_bn, 73 | ) 74 | elif kwargs["arch"] == "fc": 75 | l_hidden = 
kwargs["l_hidden"] 76 | activation = kwargs["activation"] 77 | net = FCNet( 78 | in_dim=in_dim, 79 | out_dim=out_dim, 80 | l_hidden=l_hidden, 81 | activation=activation, 82 | out_activation=out_activation, 83 | ) 84 | elif kwargs["arch"] == "convmlp": 85 | l_hidden = kwargs["l_hidden"] 86 | activation = kwargs["activation"] 87 | net = ConvMLP( 88 | in_dim=in_dim, 89 | out_dim=out_dim, 90 | l_hidden=l_hidden, 91 | activation=activation, 92 | out_activation=out_activation, 93 | ) 94 | elif kwargs["arch"] == "IGEBMEncoder": 95 | use_spectral_norm = kwargs.get("use_spectral_norm", False) 96 | keepdim = kwargs.get("keepdim", True) 97 | net = IGEBMEncoder( 98 | in_chan=in_dim, 99 | out_chan=out_dim, 100 | n_class=None, 101 | use_spectral_norm=use_spectral_norm, 102 | keepdim=keepdim, 103 | ) 104 | elif kwargs["arch"] == "sngan_generator_bn": 105 | hidden_dim = kwargs.get("hidden_dim", 128) 106 | out_activation = kwargs["out_activation"] 107 | net = SNGANGeneratorBN( 108 | z_dim=in_dim, 109 | channels=out_dim, 110 | hidden_dim=hidden_dim, 111 | out_activation=out_activation, 112 | ) 113 | elif kwargs["arch"] == "sngan_generator_nobn": 114 | hidden_dim = kwargs.get("hidden_dim", 128) 115 | out_activation = kwargs["out_activation"] 116 | net = SNGANGeneratorNoBN( 117 | z_dim=in_dim, 118 | channels=out_dim, 119 | hidden_dim=hidden_dim, 120 | out_activation=out_activation, 121 | ) 122 | elif kwargs["arch"] == "sngan_generator_nobn64": 123 | hidden_dim = kwargs.get("hidden_dim", 128) 124 | out_activation = kwargs["out_activation"] 125 | net = SNGANGeneratorNoBN64( 126 | z_dim=in_dim, 127 | channels=out_dim, 128 | hidden_dim=hidden_dim, 129 | out_activation=out_activation, 130 | ) 131 | elif kwargs["arch"] == "sngan_generator_gn": 132 | hidden_dim = kwargs.get("hidden_dim", 128) 133 | out_activation = kwargs["out_activation"] 134 | num_groups = kwargs["num_groups"] 135 | net = SNGANGeneratorGN( 136 | z_dim=in_dim, 137 | channels=out_dim, 138 | hidden_dim=hidden_dim, 139 | out_activation=out_activation, 140 | num_groups=num_groups, 141 | ) 142 | 143 | return net 144 | 145 | 146 | def get_ae(**model_cfg): 147 | arch = model_cfg.pop('arch') 148 | x_dim = model_cfg.pop("x_dim") 149 | z_dim = model_cfg.pop("z_dim") 150 | enc_cfg = model_cfg.pop('encoder') 151 | dec_cfg = model_cfg.pop('decoder') 152 | 153 | if arch == "ae": 154 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **enc_cfg) 155 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg) 156 | ae = AE(encoder, decoder) 157 | elif arch == "dae": 158 | sig = model_cfg["sig"] 159 | noise_type = model_cfg["noise_type"] 160 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **enc_cfg) 161 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg) 162 | ae = DAE(encoder, decoder, sig=sig, noise_type=noise_type) 163 | elif arch == "wae": 164 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **enc_cfg) 165 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg) 166 | ae = WAE(encoder, decoder, **model_cfg) 167 | elif arch == "vae": 168 | sigma_trainable = model_cfg.get("sigma_trainable", False) 169 | encoder = get_net(in_dim=x_dim, out_dim=z_dim * 2, **enc_cfg) 170 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg) 171 | ae = VAE(encoder, decoder, **model_cfg) 172 | return ae 173 | 174 | 175 | 176 | def get_vae(**model_cfg): 177 | x_dim = model_cfg["x_dim"] 178 | z_dim = model_cfg["z_dim"] 179 | encoder_out_dim = z_dim * 2 180 | 181 | encoder = get_net(in_dim=x_dim, out_dim=encoder_out_dim, **model_cfg["encoder"]) 182 | decoder = 
get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg["decoder"]) 183 | n_sample = model_cfg.get("n_sample", 1) 184 | pred_method = model_cfg.get("pred_method", "recon") 185 | 186 | if model_cfg["arch"] == "vae": 187 | ae = VAE(encoder, decoder, n_sample=n_sample, pred_method=pred_method) 188 | return ae 189 | 190 | 191 | def get_ebm(**model_cfg): 192 | model_cfg = copy.deepcopy(model_cfg) 193 | if "arch" in model_cfg: 194 | model_cfg.pop("arch") 195 | in_dim = model_cfg["x_dim"] 196 | model_cfg.pop("x_dim") 197 | net = get_net(in_dim=in_dim, out_dim=1, **model_cfg["net"]) 198 | model_cfg.pop("net") 199 | return EnergyBasedModel(net, **model_cfg) 200 | 201 | 202 | def get_nae(**model_cfg): 203 | arch = model_cfg.pop("arch") 204 | x_dim = model_cfg["x_dim"] 205 | z_dim = model_cfg["z_dim"] 206 | 207 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **model_cfg["encoder"]) 208 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg["decoder"]) 209 | 210 | if arch == "nae": 211 | ae = NAE(encoder, decoder, **model_cfg["nae"]) 212 | else: 213 | raise ValueError(f"{arch}") 214 | return ae 215 | 216 | 217 | def get_nae_v2(**model_cfg): 218 | arch = model_cfg.pop('arch') 219 | sampling = model_cfg.pop('sampling') 220 | x_dim = model_cfg['x_dim'] 221 | z_dim = model_cfg['z_dim'] 222 | 223 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **model_cfg["encoder"]) 224 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg["decoder"]) 225 | if arch == 'nae_l2' and sampling == 'omi': 226 | sampler_z = get_sampler(**model_cfg['sampler_z']) 227 | sampler_x = get_sampler(**model_cfg['sampler_x']) 228 | nae = NAE_L2_OMI(encoder, decoder, sampler_z, sampler_x, **model_cfg['nae']) 229 | else: 230 | raise ValueError(f'Invalid sampling: {sampling}') 231 | return nae 232 | 233 | 234 | def get_model(cfg, *args, version=None, **kwargs): 235 | # cfg can be either the whole config dictionary or the value of its 'model' key (cfg['model']). 236 | if "model" in cfg: 237 | model_dict = cfg["model"] 238 | elif "arch" in cfg: 239 | model_dict = cfg 240 | else: 241 | raise ValueError(f"Invalid model configuration dictionary: {cfg}") 242 | name = model_dict["arch"] 243 | model = _get_model_instance(name) 244 | model = model(**model_dict) 245 | return model 246 | 247 | 248 | def _get_model_instance(name): 249 | try: 250 | return { 251 | "ae": get_ae, 252 | "nae": get_nae, 253 | "nae_l2": get_nae_v2, 254 | }[name] 255 | except KeyError: 256 | raise ValueError("Model {} not available".format(name)) 257 | 258 | 259 | def load_pretrained(identifier, config_file, ckpt_file, root='pretrained', **kwargs): 260 | """ 261 | load a pre-trained model. 262 | identifier: '<model>/<run>', e.g. 'ae_mnist/z16' 263 | config_file: name of a config file. e.g. 'ae.yml' 264 | ckpt_file: name of a model checkpoint file. e.g. 
'model_best.pth' 265 | root: path to pretrained directory 266 | """ 267 | config_path = os.path.join(root, identifier, config_file) 268 | ckpt_path = os.path.join(root, identifier, ckpt_file) 269 | cfg = OmegaConf.load(config_path) 270 | model_name = cfg['model']['arch'] 271 | 272 | model = get_model(cfg) 273 | ckpt = torch.load(ckpt_path, map_location='cpu') 274 | if 'model_state' in ckpt: 275 | ckpt = ckpt['model_state'] 276 | model.load_state_dict(ckpt) 277 | model.eval() 278 | return model, cfg 279 | 280 | -------------------------------------------------------------------------------- /models/ae.py: -------------------------------------------------------------------------------- 1 | """ 2 | ae.py 3 | ===== 4 | Autoencoders 5 | """ 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import optim 11 | from torchvision.utils import make_grid 12 | from models.modules import IsotropicGaussian, IsotropicLaplace 13 | 14 | 15 | class AE(nn.Module): 16 | """autoencoder""" 17 | def __init__(self, encoder, decoder): 18 | """ 19 | encoder, decoder : neural networks 20 | """ 21 | super(AE, self).__init__() 22 | self.encoder = encoder 23 | self.decoder = decoder 24 | self.own_optimizer = False 25 | 26 | def forward(self, x): 27 | z = self.encode(x) 28 | recon = self.decoder(z) 29 | return recon 30 | 31 | def encode(self, x): 32 | z = self.encoder(x) 33 | return z 34 | 35 | def predict(self, x): 36 | """one-class anomaly prediction""" 37 | recon = self(x) 38 | if hasattr(self.decoder, 'error'): 39 | predict = self.decoder.error(x, recon) 40 | else: 41 | predict = ((recon - x) ** 2).view(len(x), -1).mean(dim=1) 42 | return predict 43 | 44 | def predict_and_reconstruct(self, x): 45 | recon = self(x) 46 | if hasattr(self.decoder, 'error'): 47 | recon_err = self.decoder.error(x, recon) 48 | else: 49 | recon_err = ((recon - x) ** 2).view(len(x), -1).mean(dim=1) 50 | return recon_err, recon 51 | 52 | def validation_step(self, x, **kwargs): 53 | recon = self(x) 54 | if hasattr(self.decoder, 'error'): 55 | predict = self.decoder.error(x, recon) 56 | else: 57 | predict = ((recon - x) ** 2).view(len(x), -1).mean(dim=1) 58 | loss = predict.mean() 59 | 60 | if kwargs.get('show_image', True): 61 | x_img = make_grid(x.detach().cpu(), nrow=10, range=(0, 1)) 62 | recon_img = make_grid(recon.detach().cpu(), nrow=10, range=(0, 1)) 63 | else: 64 | x_img, recon_img = None, None 65 | return {'loss': loss.item(), 'predict': predict, 'reconstruction': recon, 66 | 'input@': x_img, 'recon@': recon_img} 67 | 68 | def train_step(self, x, optimizer, clip_grad=None, **kwargs): 69 | optimizer.zero_grad() 70 | recon_error = self.predict(x) 71 | loss = recon_error.mean() 72 | loss.backward() 73 | if clip_grad is not None: 74 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=clip_grad) 75 | optimizer.step() 76 | return {'loss': loss.item()} 77 | 78 | def reconstruct(self, x): 79 | return self(x) 80 | 81 | def sample(self, N, z_shape=None, device='cpu'): 82 | if z_shape is None: 83 | z_shape = self.encoder.out_shape 84 | 85 | rand_z = torch.rand(N, *z_shape).to(device) * 2 - 1 86 | sample_x = self.decoder(rand_z) 87 | return sample_x 88 | 89 | 90 | 91 | def clip_vector_norm(x, max_norm): 92 | norm = x.norm(dim=-1, keepdim=True) 93 | x = x * ((norm < max_norm).to(torch.float) + (norm > max_norm).to(torch.float) * max_norm/norm + 1e-6) 94 | return x 95 | 96 | 97 | class DAE(AE): 98 | """denoising autoencoder""" 99 | def __init__(self, encoder, decoder, 
sig=0.0, noise_type='gaussian'): 100 | super(DAE, self).__init__(encoder, decoder) 101 | self.sig = sig 102 | self.noise_type = noise_type 103 | 104 | def train_step(self, x, optimizer, y=None): 105 | optimizer.zero_grad() 106 | if self.noise_type == 'gaussian': 107 | noise = torch.randn(*x.shape, dtype=torch.float32) 108 | noise = noise.to(x.device) 109 | recon = self(x + self.sig * noise) 110 | elif self.noise_type == 'saltnpepper': 111 | x = self.salt_and_pepper(x) 112 | recon = self(x) 113 | else: 114 | raise ValueError(f'Invalid noise_type: {self.noise_type}') 115 | 116 | loss = torch.mean((recon - x) ** 2) 117 | loss.backward() 118 | optimizer.step() 119 | return {'loss': loss.item()} 120 | 121 | def salt_and_pepper(self, img): 122 | """salt-and-pepper noise for MNIST""" 123 | # for salt-and-pepper noise, sig is the probability of occurrence of noise pixels. 124 | img = img.clone() # tensors are copied with clone(), not copy() 125 | prob = self.sig 126 | rnd = torch.rand(*img.shape).to(img.device) 127 | img[rnd < prob / 2] = 0. 128 | img[rnd > 1 - prob / 2] = 1. 129 | return img 130 | 131 | 132 | class VAE(AE): 133 | def __init__(self, encoder, decoder, n_sample=1, use_mean=False, pred_method='recon', sigma_trainable=False): 134 | super(VAE, self).__init__(encoder, IsotropicGaussian(decoder, sigma=1, sigma_trainable=sigma_trainable)) 135 | self.n_sample = n_sample # the number of samples to generate for anomaly detection 136 | self.use_mean = use_mean # if True, does not sample from posterior distribution 137 | self.pred_method = pred_method # which anomaly score to use 138 | self.z_shape = None 139 | 140 | def forward(self, x): 141 | z = self.encoder(x) 142 | z_sample = self.sample_latent(z) 143 | return self.decoder(z_sample) 144 | 145 | def sample_latent(self, z): 146 | half_chan = int(z.shape[1] / 2) 147 | mu, log_sig = z[:, :half_chan], z[:, half_chan:] 148 | if self.use_mean: 149 | return mu 150 | eps = torch.randn(*mu.shape, dtype=torch.float32) 151 | eps = eps.to(z.device) 152 | return mu + torch.exp(log_sig) * eps 153 | 154 | # def sample_marginal_latent(self, z_shape): 155 | # return torch.randn(z_shape) 156 | 157 | def kl_loss(self, z): 158 | """analytic (positive) KL divergence between gaussians 159 | KL(q(z|x) | p(z))""" 160 | half_chan = int(z.shape[1] / 2) 161 | mu, log_sig = z[:, :half_chan], z[:, half_chan:] 162 | mu_sq = mu ** 2 163 | sig_sq = torch.exp(log_sig) ** 2 164 | kl = mu_sq + sig_sq - torch.log(sig_sq) - 1 165 | # return 0.5 * torch.mean(kl.view(len(kl), -1), dim=1) 166 | return 0.5 * torch.sum(kl.view(len(kl), -1), dim=1) 167 | 168 | def train_step(self, x, optimizer, y=None, clip_grad=None): 169 | optimizer.zero_grad() 170 | z = self.encoder(x) 171 | z_sample = self.sample_latent(z) 172 | nll = - self.decoder.log_likelihood(x, z_sample) 173 | 174 | kl_loss = self.kl_loss(z) 175 | loss = nll + kl_loss 176 | loss = loss.mean() 177 | nll = nll.mean() 178 | 179 | loss.backward() 180 | optimizer.step() 181 | return {'loss': nll.item(), 'vae/kl_loss_': kl_loss.mean(), 'vae/sigma_': self.decoder.sigma.item()} 182 | 183 | def predict(self, x): 184 | """one-class anomaly prediction using the metric specified by self.pred_method""" 185 | if self.pred_method == 'recon': 186 | return self.reconstruction_probability(x) 187 | elif self.pred_method == 'lik': 188 | return - self.marginal_likelihood(x) # negative log likelihood 189 | else: 190 | raise ValueError(f'{self.pred_method} should be recon or lik') 191 | 192 | def validation_step(self, x, y=None, **kwargs): 193 | z = self.encoder(x) 194 | z_sample = 
self.sample_latent(z) 195 | recon = self.decoder(z_sample) 196 | loss = torch.mean((recon - x) ** 2) 197 | predict = - self.decoder.log_likelihood(x, z_sample) 198 | 199 | if kwargs.get('show_image', True): 200 | x_img = make_grid(x.detach().cpu(), nrow=10, range=(0, 1)) 201 | recon_img = make_grid(recon.detach().cpu(), nrow=10, range=(0, 1)) 202 | else: 203 | x_img, recon_img = None, None 204 | 205 | return {'loss': loss.item(), 'predict': predict, 'reconstruction': recon, 206 | 'input@': x_img, 'recon@': recon_img} 207 | 208 | def reconstruction_probability(self, x): 209 | l_score = [] 210 | z = self.encoder(x) 211 | for i in range(self.n_sample): 212 | z_sample = self.sample_latent(z) 213 | recon_loss = - self.decoder.log_likelihood(x, z_sample) 214 | score = recon_loss 215 | l_score.append(score) 216 | return torch.stack(l_score).mean(dim=0) 217 | 218 | def marginal_likelihood(self, x, n_sample=None): 219 | """marginal likelihood from importance sampling 220 | log P(X) = log \int P(X|Z) * P(Z)/Q(Z|X) * Q(Z|X) dZ""" 221 | if n_sample is None: 222 | n_sample = self.n_sample 223 | 224 | # check z shape 225 | with torch.no_grad(): 226 | z = self.encoder(x) 227 | 228 | l_score = [] 229 | for i in range(n_sample): 230 | z_sample = self.sample_latent(z) 231 | log_recon = self.decoder.log_likelihood(x, z_sample) 232 | log_prior = self.log_prior(z_sample) 233 | log_posterior = self.log_posterior(z, z_sample) 234 | l_score.append(log_recon + log_prior - log_posterior) 235 | score = torch.stack(l_score) 236 | logN = torch.log(torch.tensor(n_sample, dtype=torch.float, device=x.device)) 237 | return torch.logsumexp(score, dim=0) - logN 238 | 239 | def marginal_likelihood_naive(self, x, n_sample=None): 240 | if n_sample is None: 241 | n_sample = self.n_sample 242 | 243 | # check z shape 244 | z_dummy = self.encoder(x[[0]]) 245 | z = torch.zeros(len(x), *list(z_dummy.shape[1:]), dtype=torch.float).to(x.device) 246 | 247 | l_score = [] 248 | for i in range(n_sample): 249 | z_sample = self.sample_latent(z) 250 | recon_loss = - self.decoder.log_likelihood(x, z_sample) 251 | score = recon_loss 252 | l_score.append(score) 253 | score = torch.stack(l_score) 254 | return - torch.logsumexp(-score, dim=0) 255 | 256 | def elbo(self, x): 257 | l_score = [] 258 | z = self.encoder(x) 259 | for i in range(self.n_sample): 260 | z_sample = self.sample_latent(z) 261 | recon_loss = - self.decoder.log_likelihood(x, z_sample) 262 | kl_loss = self.kl_loss(z) 263 | score = recon_loss + kl_loss 264 | l_score.append(score) 265 | return torch.stack(l_score).mean(dim=0) 266 | 267 | def log_posterior(self, z, z_sample): 268 | half_chan = int(z.shape[1] / 2) 269 | mu, log_sig = z[:, :half_chan], z[:, half_chan:] 270 | 271 | log_p = torch.distributions.Normal(mu, torch.exp(log_sig)).log_prob(z_sample) 272 | log_p = log_p.view(len(z), -1).sum(-1) 273 | return log_p 274 | 275 | def log_prior(self, z_sample): 276 | log_p = torch.distributions.Normal(torch.zeros_like(z_sample), torch.ones_like(z_sample)).log_prob(z_sample) 277 | log_p = log_p.view(len(z_sample), -1).sum(-1) 278 | return log_p 279 | 280 | def posterior_entropy(self, z): 281 | half_chan = int(z.shape[1] / 2) 282 | mu, log_sig = z[:, :half_chan], z[:, half_chan:] 283 | D = mu.shape[1] 284 | pi = torch.tensor(np.pi, dtype=torch.float32).to(z.device) 285 | term1 = D / 2 286 | term2 = D / 2 * torch.log(2 * pi) 287 | term3 = log_sig.view(len(log_sig), -1).sum(dim=-1) 288 | return term1 + term2 + term3 289 | 290 | def _set_z_shape(self, x): 291 | if self.z_shape is not 
None: 292 | return 293 | with torch.no_grad(): 294 | dummy_z = self.encode(x[[0]]) 295 | dummy_z = self.sample_latent(dummy_z) 296 | z_shape = dummy_z.shape 297 | self.z_shape = z_shape[1:] 298 | 299 | def sample_z(self, n_sample, device): 300 | z_shape = (n_sample,) + self.z_shape 301 | return torch.randn(z_shape, device=device, dtype=torch.float) 302 | 303 | def sample(self, n_sample, device): 304 | z = self.sample_z(n_sample, device) 305 | return {'sample_x': self.decoder.sample(z)} 306 | 307 | 308 | class WAE(AE): 309 | """Wasserstein Autoencoder with MMD loss""" 310 | def __init__(self, encoder, decoder, reg=1., bandwidth='median', prior='gaussian'): 311 | super().__init__(encoder, decoder) 312 | if not isinstance(bandwidth, str): 313 | bandwidth = float(bandwidth) 314 | self.bandwidth = bandwidth 315 | self.reg = reg # coefficient for MMD loss 316 | self.prior = prior 317 | 318 | def train_step(self, x, optimizer, y=None, **kwargs): 319 | optimizer.zero_grad() 320 | # forward step 321 | z = self.encoder(x) 322 | recon = self.decoder(z) 323 | # recon_loss = torch.mean(self.decoder.square_error(x, recon)) 324 | recon_loss = torch.mean((x - recon) ** 2) 325 | 326 | # MMD step 327 | z_prior = self.sample_prior(z) 328 | mmd_loss = self.mmd(z_prior, z) 329 | 330 | loss = recon_loss + mmd_loss * self.reg 331 | loss.backward() 332 | optimizer.step() 333 | return {'loss': loss.item(), 'recon_loss': recon_loss.item(), 'mmd_loss': mmd_loss.item()} 334 | 335 | def sample_prior(self, z): 336 | if self.prior == 'gaussian': 337 | return torch.randn_like(z) 338 | elif self.prior == 'uniform_tanh': 339 | return torch.rand_like(z) * 2 - 1 340 | else: 341 | raise ValueError(f'invalid prior {self.prior}') 342 | 343 | def mmd(self, X1, X2): 344 | if len(X1.shape) == 4: 345 | X1 = X1.view(len(X1), -1) 346 | if len(X2.shape) == 4: 347 | X2 = X2.view(len(X2), -1) 348 | 349 | N1 = len(X1) 350 | X1_sq = X1.pow(2).sum(1).unsqueeze(0) 351 | X1_cr = torch.mm(X1, X1.t()) 352 | X1_dist = X1_sq + X1_sq.t() - 2 * X1_cr 353 | 354 | N2 = len(X2) 355 | X2_sq = X2.pow(2).sum(1).unsqueeze(0) 356 | X2_cr = torch.mm(X2, X2.t()) 357 | X2_dist = X2_sq + X2_sq.t() - 2 * X2_cr 358 | 359 | X12 = torch.mm(X1, X2.t()) 360 | X12_dist = X1_sq.t() + X2_sq - 2 * X12 361 | 362 | # median heuristic to select bandwidth 363 | if self.bandwidth == 'median': 364 | X1_triu = X1_dist[torch.triu(torch.ones_like(X1_dist), diagonal=1) == 1] 365 | bandwidth1 = torch.median(X1_triu) 366 | X2_triu = X2_dist[torch.triu(torch.ones_like(X2_dist), diagonal=1) == 1] 367 | bandwidth2 = torch.median(X2_triu) 368 | bandwidth_sq = ((bandwidth1 + bandwidth2) / 2).detach() 369 | else: 370 | bandwidth_sq = (self.bandwidth ** 2) 371 | 372 | C = - 0.5 / bandwidth_sq 373 | K11 = torch.exp(C * X1_dist) 374 | K22 = torch.exp(C * X2_dist) 375 | K12 = torch.exp(C * X12_dist) 376 | K11 = (1 - torch.eye(N1).to(X1.device)) * K11 377 | K22 = (1 - torch.eye(N2).to(X1.device)) * K22 378 | mmd = K11.sum() / N1 / (N1 - 1) + K22.sum() / N2 / (N2 - 1) - 2 * K12.mean() 379 | return mmd 380 | 381 | 382 | 383 | -------------------------------------------------------------------------------- /models/energybased.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models.langevin import sample_langevin 7 | 8 | 9 | class EnergyBasedModel(nn.Module): 10 | def __init__(self, net, alpha=1, step_size=10, sample_step=60, 11 | 
noise_std=0.005, buffer_size=10000, replay_ratio=0.95, 12 | langevin_clip_grad=0.01, clip_x=(0, 1)): 13 | super().__init__() 14 | self.net = net 15 | self.alpha = alpha 16 | self.step_size = step_size 17 | self.sample_step = sample_step 18 | self.noise_std = noise_std 19 | self.buffer_size = buffer_size 20 | self.buffer = SampleBuffer(max_samples=buffer_size, replay_ratio=replay_ratio) 21 | self.replay_ratio = replay_ratio 22 | self.replay = True if self.replay_ratio > 0 else False 23 | self.langevin_clip_grad = langevin_clip_grad 24 | self.clip_x = clip_x 25 | 26 | self.own_optimizer = False 27 | 28 | def forward(self, x): 29 | return self.net(x).view(-1) 30 | 31 | def predict(self, x): 32 | return self(x) 33 | 34 | def validation_step(self, x, y=None): 35 | with torch.no_grad(): 36 | pos_e = self(x) 37 | 38 | return {'loss': pos_e.mean(), 39 | 'predict': pos_e, 40 | } 41 | 42 | def train_step(self, x, optimizer, clip_grad=None, y=None): 43 | neg_x = self.sample(shape=x.shape, device=x.device, replay=self.replay) 44 | optimizer.zero_grad() 45 | pos_e = self(x) 46 | neg_e = self(neg_x) 47 | 48 | ebm_loss = pos_e.mean() - neg_e.mean() 49 | reg_loss = (pos_e ** 2).mean() + (neg_e ** 2).mean() 50 | weight_norm = sum([(w ** 2).sum() for w in self.net.parameters()]) 51 | loss = ebm_loss + self.alpha * reg_loss # + self.beta * weight_norm 52 | loss.backward() 53 | 54 | if clip_grad is not None: 55 | # torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=clip_grad) 56 | clip_grad_ebm(self.parameters(), optimizer) 57 | 58 | optimizer.step() 59 | return {'loss': loss.item(), 60 | 'ebm_loss': ebm_loss.item(), 'pos_e': pos_e.mean().item(), 'neg_e': neg_e.mean().item(), 61 | 'reg_loss': reg_loss.item(), 'neg_sample': neg_x.detach().cpu(), 62 | 'weight_norm': weight_norm.item()} 63 | 64 | def sample(self, shape, device, replay=True, intermediate=False, step_size=None, 65 | sample_step=None): 66 | if step_size is None: 67 | step_size = self.step_size 68 | if sample_step is None: 69 | sample_step = self.sample_step 70 | # initialize 71 | x0 = self.buffer.sample(shape, device, replay=replay) 72 | # run langevin 73 | sample_x = sample_langevin(x0, self, step_size, sample_step, 74 | noise_scale=self.noise_std, 75 | intermediate_samples=intermediate, 76 | clip_x=self.clip_x, 77 | clip_grad=self.langevin_clip_grad, 78 | ) 79 | # push samples 80 | if replay: 81 | self.buffer.push(sample_x) 82 | return sample_x 83 | 84 | 85 | class SampleBuffer: 86 | def __init__(self, max_samples=10000, replay_ratio=0.95, bound=None): 87 | self.max_samples = max_samples 88 | self.buffer = [] 89 | self.replay_ratio = replay_ratio 90 | if bound is None: 91 | self.bound = (0, 1) 92 | else: 93 | self.bound = bound 94 | 95 | def __len__(self): 96 | return len(self.buffer) 97 | 98 | def push(self, samples): 99 | samples = samples.detach().to('cpu') 100 | 101 | for sample in samples: 102 | self.buffer.append(sample) 103 | 104 | if len(self.buffer) > self.max_samples: 105 | self.buffer.pop(0) 106 | 107 | def get(self, n_samples): 108 | samples = random.choices(self.buffer, k=n_samples) 109 | samples = torch.stack(samples, 0) 110 | return samples 111 | 112 | def sample(self, shape, device, replay=False): 113 | if len(self.buffer) < 1 or not replay: # empty buffer 114 | return self.random(shape, device) 115 | 116 | n_replay = (np.random.rand(shape[0]) < self.replay_ratio).sum() 117 | 118 | replay_sample = self.get(n_replay).to(device) 119 | n_random = shape[0] - n_replay 120 | if n_random > 0: 121 | random_sample = 
self.random((n_random,) + shape[1:], device) 122 | return torch.cat([replay_sample, random_sample]) 123 | else: 124 | return replay_sample 125 | 126 | def random(self, shape, device): 127 | if self.bound is None: 128 | r = torch.rand(*shape, dtype=torch.float).to(device) 129 | 130 | elif self.bound == 'spherical': 131 | r = torch.randn(*shape, dtype=torch.float).to(device) 132 | norm = r.view(len(r), -1).norm(dim=-1) 133 | if len(shape) == 4: 134 | r = r / norm[:, None, None, None] 135 | elif len(shape) == 2: 136 | r = r / norm[:, None] 137 | else: 138 | raise NotImplementedError 139 | 140 | elif len(self.bound) == 2: 141 | r = torch.rand(*shape, dtype=torch.float).to(device) 142 | r = r * (self.bound[1] - self.bound[0]) + self.bound[0] 143 | return r 144 | 145 | 146 | class SampleBufferV2: 147 | def __init__(self, max_samples=10000, replay_ratio=0.95): 148 | self.max_samples = max_samples 149 | self.buffer = [] 150 | self.replay_ratio = replay_ratio 151 | 152 | def __len__(self): 153 | return len(self.buffer) 154 | 155 | def push(self, samples): 156 | samples = samples.detach().to('cpu') 157 | 158 | for sample in samples: 159 | self.buffer.append(sample) 160 | 161 | if len(self.buffer) > self.max_samples: 162 | self.buffer.pop(0) 163 | 164 | def get(self, n_samples): 165 | samples = random.choices(self.buffer, k=n_samples) 166 | samples = torch.stack(samples, 0) 167 | return samples 168 | 169 | 170 | def clip_grad_ebm(parameters, optimizer): 171 | with torch.no_grad(): 172 | for group in optimizer.param_groups: 173 | for p in group['params']: 174 | state = optimizer.state[p] 175 | 176 | if 'step' not in state or state['step'] < 1: 177 | continue 178 | 179 | step = state['step'] 180 | exp_avg_sq = state['exp_avg_sq'] 181 | _, beta2 = group['betas'] 182 | 183 | bound = 3 * torch.sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1 184 | p.grad.data.copy_(torch.max(torch.min(p.grad.data, bound), -bound)) 185 | -------------------------------------------------------------------------------- /models/igebm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import utils 6 | 7 | 8 | class SpectralNorm: 9 | def __init__(self, name, bound=False): 10 | self.name = name 11 | self.bound = bound 12 | 13 | def compute_weight(self, module): 14 | weight = getattr(module, self.name + '_orig') 15 | u = getattr(module, self.name + '_u') 16 | size = weight.size() 17 | weight_mat = weight.contiguous().view(size[0], -1) 18 | 19 | with torch.no_grad(): 20 | v = weight_mat.t() @ u 21 | v = v / v.norm() 22 | u = weight_mat @ v 23 | u = u / u.norm() 24 | 25 | sigma = u @ weight_mat @ v 26 | 27 | if self.bound: 28 | weight_sn = weight / (sigma + 1e-6) * torch.clamp(sigma, max=1) 29 | 30 | else: 31 | weight_sn = weight / sigma 32 | 33 | return weight_sn, u 34 | 35 | @staticmethod 36 | def apply(module, name, bound): 37 | fn = SpectralNorm(name, bound) 38 | 39 | weight = getattr(module, name) 40 | del module._parameters[name] 41 | module.register_parameter(name + '_orig', weight) 42 | input_size = weight.size(0) 43 | u = weight.new_empty(input_size).normal_() 44 | module.register_buffer(name, weight) 45 | module.register_buffer(name + '_u', u) 46 | 47 | module.register_forward_pre_hook(fn) 48 | 49 | return fn 50 | 51 | def __call__(self, module, input): 52 | weight_sn, u = self.compute_weight(module) 53 | setattr(module, self.name, weight_sn) 54 | setattr(module, self.name + '_u', u) 55 | 
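# A minimal usage sketch (hypothetical shapes, not taken from this repository):
# the `spectral_norm` helper below re-registers a layer's weight as
# `weight_orig` and divides it by an estimate of the weight's largest singular
# value, refined with one power-iteration step per forward pass via the
# pre-forward hook installed above.
#
#     conv = spectral_norm(nn.Conv2d(3, 128, 3, padding=1))
#     y = conv(torch.randn(4, 3, 32, 32))   # forward pass uses weight_orig / sigma
#
# With bound=True, the weight is rescaled only when sigma > 1, so the layer is
# constrained to be roughly 1-Lipschitz rather than exactly unit-norm.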
56 | 57 | def spectral_norm(module, init=True, std=1, bound=False): 58 | if init: 59 | nn.init.normal_(module.weight, 0, std) 60 | 61 | if hasattr(module, 'bias') and module.bias is not None: 62 | module.bias.data.zero_() 63 | 64 | SpectralNorm.apply(module, 'weight', bound=bound) 65 | 66 | return module 67 | 68 | 69 | class ResBlock(nn.Module): 70 | def __init__(self, in_channel, out_channel, n_class=None, downsample=False, use_spectral_norm=True): 71 | super().__init__() 72 | 73 | self.conv1 = nn.Conv2d(in_channel, 74 | out_channel, 75 | 3, 76 | padding=1, 77 | bias=False if n_class is not None else True) 78 | 79 | self.conv2 = nn.Conv2d(out_channel, 80 | out_channel, 81 | 3, 82 | padding=1, 83 | bias=False if n_class is not None else True) 84 | 85 | if use_spectral_norm: 86 | self.conv1 = spectral_norm(self.conv1) 87 | self.conv2 = spectral_norm(self.conv2, std=1e-10, bound=True) 88 | 89 | self.class_embed = None 90 | 91 | if n_class is not None: 92 | class_embed = nn.Embedding(n_class, out_channel * 2 * 2) 93 | class_embed.weight.data[:, : out_channel * 2] = 1 94 | class_embed.weight.data[:, out_channel * 2 :] = 0 95 | 96 | self.class_embed = class_embed 97 | 98 | self.skip = None 99 | 100 | if in_channel != out_channel or downsample: 101 | if use_spectral_norm: 102 | self.skip = nn.Sequential( 103 | spectral_norm(nn.Conv2d(in_channel, out_channel, 1, bias=False)) 104 | ) 105 | else: 106 | self.skip = nn.Sequential( 107 | nn.Conv2d(in_channel, out_channel, 1, bias=False) 108 | ) 109 | 110 | self.downsample = downsample 111 | 112 | def forward(self, input, class_id=None): 113 | out = input 114 | 115 | out = self.conv1(out) 116 | 117 | if self.class_embed is not None: 118 | embed = self.class_embed(class_id).view(input.shape[0], -1, 1, 1) 119 | weight1, weight2, bias1, bias2 = embed.chunk(4, 1) 120 | out = weight1 * out + bias1 121 | 122 | out = F.leaky_relu(out, negative_slope=0.2) 123 | 124 | out = self.conv2(out) 125 | 126 | if self.class_embed is not None: 127 | out = weight2 * out + bias2 128 | 129 | if self.skip is not None: 130 | skip = self.skip(input) 131 | 132 | else: 133 | skip = input 134 | 135 | out = out + skip 136 | 137 | if self.downsample: 138 | out = F.avg_pool2d(out, 2) 139 | 140 | out = F.leaky_relu(out, negative_slope=0.2) 141 | 142 | return out 143 | 144 | 145 | class IGEBM(nn.Module): 146 | def __init__(self, in_chan=3, n_class=None): 147 | super().__init__() 148 | 149 | self.conv1 = spectral_norm(nn.Conv2d(in_chan, 128, 3, padding=1), std=1) 150 | 151 | self.blocks = nn.ModuleList( 152 | [ 153 | ResBlock(128, 128, n_class, downsample=True), 154 | ResBlock(128, 128, n_class), 155 | ResBlock(128, 256, n_class, downsample=True), 156 | ResBlock(256, 256, n_class), 157 | ResBlock(256, 256, n_class, downsample=True), 158 | ResBlock(256, 256, n_class), 159 | ] 160 | ) 161 | 162 | self.linear = nn.Linear(256, 1) 163 | 164 | def forward(self, input, class_id=None): 165 | out = self.conv1(input) 166 | 167 | out = F.leaky_relu(out, negative_slope=0.2) 168 | 169 | for block in self.blocks: 170 | out = block(out, class_id) 171 | 172 | out = F.relu(out) 173 | out = out.view(out.shape[0], out.shape[1], -1).sum(2) 174 | out = self.linear(out) 175 | 176 | return out 177 | 178 | -------------------------------------------------------------------------------- /models/langevin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.autograd as autograd 4 | 5 | def clip_vector_norm(x, max_norm): 6 | 
norm = x.norm(dim=-1, keepdim=True) 7 | x = x * ((norm < max_norm).to(torch.float) + (norm > max_norm).to(torch.float) * max_norm/norm + 1e-6) 8 | return x 9 | 10 | 11 | def sample_langevin(x, model, stepsize, n_steps, noise_scale=None, intermediate_samples=False, 12 | clip_x=None, clip_grad=None, reject_boundary=False, noise_anneal=None, 13 | spherical=False, mh=False): 14 | """Draw samples using Langevin dynamics 15 | x: torch.Tensor, initial points 16 | model: An energy-based model. returns energy 17 | stepsize: float 18 | n_steps: integer 19 | noise_scale: Optional. float. If None, set to np.sqrt(stepsize * 2) 20 | clip_x : tuple (start, end) or None boundary of square domain 21 | reject_boundary: Reject out-of-domain samples if True. otherwise clip. 22 | """ 23 | assert not ((stepsize is None) and (noise_scale is None)), 'stepsize and noise_scale cannot be None at the same time' 24 | if noise_scale is None: 25 | noise_scale = np.sqrt(stepsize * 2) 26 | if stepsize is None: 27 | stepsize = (noise_scale ** 2) / 2 28 | noise_scale_ = noise_scale 29 | 30 | l_samples = [] 31 | l_dynamics = []; l_drift = []; l_diffusion = [] 32 | x.requires_grad = True 33 | for i_step in range(n_steps): 34 | l_samples.append(x.detach().to('cpu')) 35 | noise = torch.randn_like(x) * noise_scale_ 36 | out = model(x) 37 | grad = autograd.grad(out.sum(), x, only_inputs=True)[0] 38 | if clip_grad is not None: 39 | grad = clip_vector_norm(grad, max_norm=clip_grad) 40 | dynamics = - stepsize * grad + noise # negative! 41 | xnew = x + dynamics 42 | if clip_x is not None: 43 | if reject_boundary: 44 | accept = ((xnew >= clip_x[0]) & (xnew <= clip_x[1])).view(len(x), -1).all(dim=1) 45 | reject = ~ accept 46 | xnew[reject] = x[reject] 47 | x = xnew 48 | else: 49 | x = torch.clamp(xnew, clip_x[0], clip_x[1]) 50 | else: 51 | x = xnew 52 | 53 | if spherical: 54 | if len(x.shape) == 4: 55 | x = x / x.view(len(x), -1).norm(dim=1)[:, None, None ,None] 56 | else: 57 | x = x / x.norm(dim=1, keepdim=True) 58 | 59 | if noise_anneal is not None: 60 | noise_scale_ = noise_scale / (1 + i_step) 61 | 62 | l_dynamics.append(dynamics.detach().to('cpu')) 63 | l_drift.append((- stepsize * grad).detach().cpu()) 64 | l_diffusion.append(noise.detach().cpu()) 65 | l_samples.append(x.detach().to('cpu')) 66 | 67 | if intermediate_samples: 68 | return l_samples, l_dynamics, l_drift, l_diffusion 69 | else: 70 | return x.detach() 71 | 72 | 73 | def sample_langevin_v2(x, model, stepsize, n_steps, noise_scale=None, intermediate_samples=False, 74 | clip_x=None, clip_grad=None, reject_boundary=False, noise_anneal=None, 75 | spherical=False, mh=False, temperature=None): 76 | """Langevin Monte Carlo 77 | x: torch.Tensor, initial points 78 | model: An energy-based model. returns energy 79 | stepsize: float 80 | n_steps: integer 81 | noise_scale: Optional. float. If None, set to np.sqrt(stepsize * 2) 82 | clip_x : tuple (start, end) or None boundary of square domain 83 | reject_boundary: Reject out-of-domain samples if True. otherwise clip. 84 | """ 85 | assert not ((stepsize is None) and (noise_scale is None)), 'stepsize and noise_scale cannot be None at the same time' 86 | if noise_scale is None: 87 | noise_scale = np.sqrt(stepsize * 2) 88 | if stepsize is None: 89 | stepsize = (noise_scale ** 2) / 2 90 | noise_scale_ = noise_scale 91 | stepsize_ = stepsize 92 | if temperature is None: 93 | temperature = 1. 
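# The loop below implements (optionally Metropolis-adjusted) Langevin dynamics:
#   y = x - stepsize * grad E(x) / temperature + noise,  noise ~ N(0, noise_scale^2 I).
# With mh=True, the proposal y is accepted with probability
#   min(1, exp((E(x) - E(y) + transition) / temperature)),
# where `transition` is the closed-form log-ratio of the backward and forward
# Gaussian proposal densities, log q(x|y) - log q(y|x), computed inside the loop.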
94 | 95 | # initial data 96 | x.requires_grad = True 97 | E_x = model(x) 98 | grad_E_x = autograd.grad(E_x.sum(), x, only_inputs=True)[0] 99 | if clip_grad is not None: 100 | grad_E_x = clip_vector_norm(grad_E_x, max_norm=clip_grad) 101 | E_y = E_x; grad_E_y = grad_E_x; 102 | 103 | l_samples = [x.detach().to('cpu')] 104 | l_dynamics = []; l_drift = []; l_diffusion = []; l_accept = [] 105 | for i_step in range(n_steps): 106 | noise = torch.randn_like(x) * noise_scale_ 107 | dynamics = - stepsize_ * grad_E_x / temperature + noise 108 | y = x + dynamics 109 | reject = torch.zeros(len(y), dtype=torch.bool) 110 | 111 | if clip_x is not None: 112 | if reject_boundary: 113 | accept = ((y >= clip_x[0]) & (y <= clip_x[1])).view(len(x), -1).all(dim=1) 114 | reject = ~ accept 115 | y[reject] = x[reject] 116 | else: 117 | y = torch.clamp(y, clip_x[0], clip_x[1]) 118 | 119 | if spherical: 120 | y = y / y.norm(dim=1, p=2, keepdim=True) 121 | 122 | # y_accept = y[~reject] 123 | # E_y[~reject] = model(y_accept) 124 | # grad_E_y[~reject] = autograd.grad(E_y.sum(), y_accept, only_inputs=True)[0] 125 | E_y = model(y) 126 | grad_E_y = autograd.grad(E_y.sum(), y, only_inputs=True)[0] 127 | 128 | if clip_grad is not None: 129 | grad_E_y = clip_vector_norm(grad_E_y, max_norm=clip_grad) 130 | 131 | if mh: 132 | y_to_x = ((grad_E_x + grad_E_y) * stepsize_ - noise).view(len(x), -1).norm(p=2, dim=1, keepdim=True) ** 2 133 | x_to_y = (noise).view(len(x), -1).norm(dim=1, keepdim=True, p=2) ** 2 134 | transition = - (y_to_x - x_to_y) / 4 / stepsize_ # B x 1 135 | prob = -E_y + E_x 136 | accept_prob = torch.exp((transition + prob) / temperature)[:,0] # B 137 | reject = (torch.rand_like(accept_prob) > accept_prob) # | reject 138 | y[reject] = x[reject] 139 | E_y[reject] = E_x[reject] 140 | grad_E_y[reject] = grad_E_x[reject] 141 | x = y; E_x = E_y; grad_E_x = grad_E_y 142 | l_accept.append(~reject) 143 | 144 | x = y; E_x = E_y; grad_E_x = grad_E_y 145 | 146 | if noise_anneal is not None: 147 | noise_scale_ = noise_scale / (1 + i_step) 148 | 149 | l_dynamics.append(dynamics.detach().cpu()) 150 | l_drift.append((- stepsize * grad_E_x).detach().cpu()) 151 | l_diffusion.append(noise.detach().cpu()) 152 | l_samples.append(x.detach().cpu()) 153 | 154 | return {'sample': x.detach(), 'l_samples': l_samples, 'l_dynamics': l_dynamics, 155 | 'l_drift': l_drift, 'l_diffusion': l_diffusion, 'l_accept': l_accept} 156 | 157 | -------------------------------------------------------------------------------- /models/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mmd(X1, X2, bandwidth='median'): 5 | """Compute Maximum Mean Discrepancy""" 6 | if len(X1.shape) == 4: 7 | X1 = X1.view(len(X1), -1) 8 | if len(X2.shape) == 4: 9 | X2 = X2.view(len(X2), -1) 10 | 11 | N1 = len(X1) 12 | X1_sq = X1.pow(2).sum(1).unsqueeze(0) 13 | X1_cr = torch.mm(X1, X1.t()) 14 | X1_dist = X1_sq + X1_sq.t() - 2 * X1_cr 15 | 16 | N2 = len(X2) 17 | X2_sq = X2.pow(2).sum(1).unsqueeze(0) 18 | X2_cr = torch.mm(X2, X2.t()) 19 | X2_dist = X2_sq + X2_sq.t() - 2 * X2_cr 20 | 21 | X12 = torch.mm(X1, X2.t()) 22 | X12_dist = X1_sq.t() + X2_sq - 2 * X12 23 | 24 | # median heuristic to select bandwidth 25 | if bandwidth == 'median': 26 | X1_triu = X1_dist[torch.triu(torch.ones_like(X1_dist), diagonal=1) == 1] 27 | bandwidth1 = torch.median(X1_triu) 28 | X2_triu = X2_dist[torch.triu(torch.ones_like(X2_dist), diagonal=1) == 1] 29 | bandwidth2 = torch.median(X2_triu) 30 | bandwidth_sq = ((bandwidth1 + 
bandwidth2) / 2).detach() 31 | else: 32 | bandwidth_sq = (bandwidth ** 2) 33 | 34 | C = - 0.5 / bandwidth_sq 35 | K11 = torch.exp(C * X1_dist) 36 | K22 = torch.exp(C * X2_dist) 37 | K12 = torch.exp(C * X12_dist) 38 | K11 = (1 - torch.eye(N1).to(X1.device)) * K11 39 | K22 = (1 - torch.eye(N2).to(X1.device)) * K22 40 | mmd = K11.sum() / N1 / (N1 - 1) + K22.sum() / N2 / (N2 - 1) - 2 * K12.mean() 41 | return mmd 42 | 43 | 44 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | from torch import optim 7 | from torch.distributions import Normal 8 | from torch.distributions.multivariate_normal import MultivariateNormal 9 | import numpy as np 10 | from models.spectral_norm import spectral_norm 11 | from models import igebm 12 | 13 | 14 | class DummyDistribution(nn.Module): 15 | """ Function-less class introduced for backward-compatibility of model checkpoint files. """ 16 | def __init__(self, net): 17 | super().__init__() 18 | self.net = net 19 | self.register_buffer('sigma', torch.tensor(0., dtype=torch.float)) 20 | 21 | def forward(self, x): 22 | return self.net(x) 23 | 24 | 25 | class IsotropicGaussian(nn.Module): 26 | """Isotropic Gaussian density function parameterized by a neural net. 27 | The standard deviation is a free scalar parameter""" 28 | def __init__(self, net, sigma=1., sigma_trainable=False, error_normalize=True, deterministic=False): 29 | super().__init__() 30 | self.net = net 31 | self.sigma_trainable = sigma_trainable 32 | self.error_normalize = error_normalize 33 | self.deterministic = deterministic 34 | if sigma_trainable: 35 | # self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float)) 36 | self.register_parameter('sigma', nn.Parameter(torch.tensor(sigma, dtype=torch.float))) 37 | else: 38 | self.register_buffer('sigma', torch.tensor(sigma, dtype=torch.float)) 39 | 40 | def log_likelihood(self, x, z): 41 | decoder_out = self.net(z) 42 | if self.deterministic: 43 | return - ((x - decoder_out)**2).view((x.shape[0], -1)).sum(dim=1) 44 | else: 45 | D = torch.prod(torch.tensor(x.shape[1:])) 46 | # sig = torch.tensor(1, dtype=torch.float32) 47 | sig = self.sigma 48 | const = - D * 0.5 * torch.log(2 * torch.tensor(np.pi, dtype=torch.float32)) - D * torch.log(sig) 49 | loglik = const - 0.5 * ((x - decoder_out)**2).view((x.shape[0], -1)).sum(dim=1) / (sig ** 2) 50 | return loglik 51 | 52 | def error(self, x, x_hat): 53 | if not self.error_normalize: 54 | return (((x - x_hat) / self.sigma) ** 2).view(len(x), -1).sum(-1) 55 | else: 56 | return ((x - x_hat) ** 2).view(len(x), -1).mean(-1) 57 | 58 | def forward(self, z): 59 | """returns reconstruction""" 60 | return self.net(z) 61 | 62 | def sample(self, z): 63 | if self.deterministic: 64 | return self.mean(z) 65 | else: 66 | x_hat = self.net(z) 67 | return x_hat + torch.randn_like(x_hat) * self.sigma 68 | 69 | def mean(self, z): 70 | return self.net(z) 71 | 72 | def max_log_likelihood(self, x): 73 | if self.deterministic: 74 | return torch.tensor(0., dtype=torch.float, device=x.device) 75 | else: 76 | D = torch.prod(torch.tensor(x.shape[1:])) 77 | sig = self.sigma 78 | const = - D * 0.5 * torch.log(2 * torch.tensor(np.pi, dtype=torch.float32)) - D * torch.log(sig) 79 | return const 80 | 81 | class IsotropicLaplace(nn.Module): 82 | """Isotropic Laplace density function
-- equivalent to using L1 error """ 83 | def __init__(self, net, sigma=0.1, sigma_trainable=False): 84 | super().__init__() 85 | self.net = net 86 | self.sigma_trainable = sigma_trainable 87 | if sigma_trainable: 88 | self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float)) 89 | else: 90 | self.register_buffer('sigma', torch.tensor(sigma, dtype=torch.float)) 91 | 92 | def log_likelihood(self, x, z): 93 | # decoder_out = self.net(z) 94 | # D = torch.prod(torch.tensor(x.shape[1:])) 95 | # sig = torch.tensor(1, dtype=torch.float32) 96 | # const = - D * 0.5 * torch.log(2 * torch.tensor(np.pi, dtype=torch.float32)) - D * torch.log(sig) 97 | # loglik = const - 0.5 * (torch.abs(x - decoder_out)).view((x.shape[0], -1)).sum(dim=1) / (sig ** 2) 98 | # return loglik 99 | raise NotImplementedError 100 | 101 | def error(self, x, x_hat): 102 | if self.sigma_trainable: 103 | return ((torch.abs(x - x_hat) / self.sigma)).view(len(x), -1).sum(-1) 104 | else: 105 | return (torch.abs(x - x_hat)).view(len(x), -1).mean(-1) 106 | 107 | def forward(self, z): 108 | """returns reconstruction""" 109 | return self.net(z) 110 | 111 | def sample(self, z): 112 | # x_hat = self.net(z) 113 | # return x_hat + torch.randn_like(x_hat) * self.sigma 114 | raise NotImplementedError 115 | 116 | 117 | class ConvNet2FC(nn.Module): 118 | """additional 1x1 conv layer at the top""" 119 | def __init__(self, in_chan=1, out_chan=64, nh=8, nh_mlp=512, out_activation='linear', use_spectral_norm=False): 120 | """nh: determines the numbers of conv filters""" 121 | super(ConvNet2FC, self).__init__() 122 | self.conv1 = nn.Conv2d(in_chan, nh * 4, kernel_size=3, bias=True) 123 | self.conv2 = nn.Conv2d(nh * 4, nh * 8, kernel_size=3, bias=True) 124 | self.max1 = nn.MaxPool2d(kernel_size=2, stride=2) 125 | self.conv3 = nn.Conv2d(nh * 8, nh * 8, kernel_size=3, bias=True) 126 | self.conv4 = nn.Conv2d(nh * 8, nh * 16, kernel_size=3, bias=True) 127 | self.max2 = nn.MaxPool2d(kernel_size=2, stride=2) 128 | self.conv5 = nn.Conv2d(nh * 16, nh_mlp, kernel_size=4, bias=True) 129 | self.conv6 = nn.Conv2d(nh_mlp, out_chan, kernel_size=1, bias=True) 130 | self.in_chan, self.out_chan = in_chan, out_chan 131 | self.out_activation = get_activation(out_activation) 132 | 133 | if use_spectral_norm: 134 | self.conv1 = spectral_norm(self.conv1) 135 | self.conv2 = spectral_norm(self.conv2) 136 | self.conv3 = spectral_norm(self.conv3) 137 | self.conv4 = spectral_norm(self.conv4) 138 | self.conv5 = spectral_norm(self.conv5) 139 | 140 | layers = [self.conv1, 141 | nn.ReLU(), 142 | self.conv2, 143 | nn.ReLU(), 144 | self.max1, 145 | self.conv3, 146 | nn.ReLU(), 147 | self.conv4, 148 | nn.ReLU(), 149 | self.max2, 150 | self.conv5, 151 | nn.ReLU(), 152 | self.conv6,] 153 | if self.out_activation is not None: 154 | layers.append(self.out_activation) 155 | 156 | 157 | self.net = nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | return self.net(x) 161 | 162 | 163 | class DeConvNet2(nn.Module): 164 | def __init__(self, in_chan=1, out_chan=1, nh=8, out_activation='linear', 165 | use_spectral_norm=False): 166 | """nh: determines the numbers of conv filters""" 167 | super(DeConvNet2, self).__init__() 168 | self.conv1 = nn.ConvTranspose2d(in_chan, nh * 16, kernel_size=4, bias=True) 169 | self.conv2 = nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=3, bias=True) 170 | self.conv3 = nn.ConvTranspose2d(nh * 8, nh * 8, kernel_size=3, bias=True) 171 | self.conv4 = nn.ConvTranspose2d(nh * 8, nh * 4, kernel_size=3, bias=True) 172 | self.conv5 = nn.ConvTranspose2d(nh 
* 4, out_chan, kernel_size=3, bias=True) 173 | self.in_chan, self.out_chan = in_chan, out_chan 174 | self.out_activation = get_activation(out_activation) 175 | 176 | if use_spectral_norm: 177 | self.conv1 = spectral_norm(self.conv1) 178 | self.conv2 = spectral_norm(self.conv2) 179 | self.conv3 = spectral_norm(self.conv3) 180 | self.conv4 = spectral_norm(self.conv4) 181 | self.conv5 = spectral_norm(self.conv5) 182 | 183 | def forward(self, x): 184 | x = self.conv1(x) 185 | x = F.relu(x) 186 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 187 | x = self.conv2(x) 188 | x = F.relu(x) 189 | x = self.conv3(x) 190 | x = F.relu(x) 191 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 192 | x = self.conv4(x) 193 | x = F.relu(x) 194 | x = self.conv5(x) 195 | if self.out_activation is not None: 196 | x = self.out_activation(x) 197 | return x 198 | 199 | 200 | ''' 201 | ConvNet for CIFAR10, following architecture in (Ghosh et al., 2019) 202 | but excluding batch normalization 203 | ''' 204 | 205 | class ConvNet64(nn.Module): 206 | """ConvNet architecture for CelebA64 following Ghosh et al., 2019""" 207 | def __init__(self, in_chan=3, out_chan=64, nh=32, out_activation='linear', activation='relu', 208 | num_groups=None, use_bn=False): 209 | super().__init__() 210 | self.conv1 = nn.Conv2d(in_chan, nh * 4, kernel_size=5, bias=True, stride=2) 211 | self.conv2 = nn.Conv2d(nh * 4, nh * 8, kernel_size=5, bias=True, stride=2) 212 | self.conv3 = nn.Conv2d(nh * 8, nh * 16, kernel_size=5, bias=True, stride=2) 213 | self.conv4 = nn.Conv2d(nh * 16, nh * 32, kernel_size=5, bias=True, stride=2) 214 | self.fc1 = nn.Conv2d(nh * 32, out_chan, kernel_size=1, bias=True) 215 | self.in_chan, self.out_chan = in_chan, out_chan 216 | self.num_groups = num_groups 217 | self.use_bn = use_bn 218 | 219 | layers = [] 220 | layers.append(self.conv1) 221 | if num_groups is not None: 222 | layers.append(self.get_norm_layer(num_channels=nh * 4)) 223 | layers.append(get_activation(activation)) 224 | layers.append(self.conv2) 225 | if num_groups is not None: 226 | layers.append(self.get_norm_layer(num_channels=nh * 8)) 227 | layers.append(get_activation(activation)) 228 | layers.append(self.conv3) 229 | if num_groups is not None: 230 | layers.append(self.get_norm_layer(num_channels=nh * 16)) 231 | layers.append(get_activation(activation)) 232 | layers.append(self.conv4) 233 | if num_groups is not None: 234 | layers.append(self.get_norm_layer(num_channels=nh * 32)) 235 | layers.append(get_activation(activation)) 236 | layers.append(self.fc1) 237 | out_activation = get_activation(out_activation) 238 | if out_activation is not None: 239 | layers.append(out_activation) 240 | 241 | self.net = nn.Sequential(*layers) 242 | 243 | def forward(self, x): 244 | return self.net(x) 245 | 246 | def get_norm_layer(self, num_channels): 247 | if self.num_groups is not None: 248 | return nn.GroupNorm(num_groups=self.num_groups, num_channels=num_channels) 249 | elif self.use_bn: 250 | return nn.BatchNorm2d(num_channels) 251 | 252 | 253 | class DeConvNet64(nn.Module): 254 | """ConvNet architecture for CelebA64 following Ghosh et al., 2019""" 255 | def __init__(self, in_chan=64, out_chan=3, nh=32, out_activation='linear', activation='relu', 256 | num_groups=None, use_bn=False): 257 | super().__init__() 258 | self.fc1 = nn.ConvTranspose2d(in_chan, nh * 32, kernel_size=8, bias=True) 259 | self.conv1 = nn.ConvTranspose2d(nh * 32, nh * 16, kernel_size=4, stride=2, padding=1, bias=True) 260 | self.conv2 = 
nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=4, stride=2, padding=1, bias=True) 261 | self.conv3 = nn.ConvTranspose2d(nh * 8, nh * 4, kernel_size=4, stride=2, padding=1, bias=True) 262 | self.conv4 = nn.ConvTranspose2d(nh * 4, out_chan, kernel_size=1, bias=True) 263 | self.in_chan, self.out_chan = in_chan, out_chan 264 | self.num_groups = num_groups 265 | self.use_bn = use_bn 266 | 267 | layers = [] 268 | layers.append(self.fc1) 269 | if num_groups is not None: 270 | layers.append(self.get_norm_layer(num_channels=nh * 32)) 271 | layers.append(get_activation(activation)) 272 | layers.append(self.conv1) 273 | if num_groups is not None: 274 | layers.append(self.get_norm_layer(num_channels=nh * 16)) 275 | layers.append(get_activation(activation)) 276 | layers.append(self.conv2) 277 | if num_groups is not None: 278 | layers.append(self.get_norm_layer(num_channels=nh * 8)) 279 | layers.append(get_activation(activation)) 280 | layers.append(self.conv3) 281 | if num_groups is not None: 282 | layers.append(self.get_norm_layer(num_channels=nh * 4)) 283 | layers.append(get_activation(activation)) 284 | layers.append(self.conv4) 285 | out_activation = get_activation(out_activation) 286 | if out_activation is not None: 287 | layers.append(out_activation) 288 | 289 | self.net = nn.Sequential(*layers) 290 | 291 | def forward(self, x): 292 | return self.net(x) 293 | 294 | def get_norm_layer(self, num_channels): 295 | if self.num_groups is not None: 296 | return nn.GroupNorm(num_groups=self.num_groups, num_channels=num_channels) 297 | elif self.use_bn: 298 | return nn.BatchNorm2d(num_channels) 299 | 300 | 301 | class ConvMLPBlock(nn.Module): 302 | def __init__(self, dim, hidden_dim=None, out_dim=None): 303 | super().__init__() 304 | if hidden_dim is None: 305 | hidden_dim = dim 306 | if out_dim is None: 307 | out_dim = dim 308 | 309 | self.block = nn.Sequential( 310 | nn.Conv2d(dim, hidden_dim, kernel_size=1, stride=1), 311 | nn.ReLU(), 312 | nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1)) 313 | 314 | def forward(self, x): 315 | return self.block(x) 316 | 317 | 318 | class DeConvNet3(nn.Module): 319 | def __init__(self, in_chan=1, out_chan=1, nh=32, out_activation='linear', 320 | activation='relu', num_groups=None): 321 | """nh: determines the numbers of conv filters""" 322 | super(DeConvNet3, self).__init__() 323 | self.num_groups = num_groups 324 | self.fc1 = nn.ConvTranspose2d(in_chan, nh * 32, kernel_size=8, bias=True) 325 | self.conv1 = nn.ConvTranspose2d(nh * 32, nh * 16, kernel_size=4, stride=2, padding=1, bias=True) 326 | self.conv2 = nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=4, stride=2, padding=1, bias=True) 327 | self.conv3 = nn.ConvTranspose2d(nh * 8, out_chan, kernel_size=1, bias=True) 328 | self.in_chan, self.out_chan = in_chan, out_chan 329 | 330 | layers = [self.fc1,] 331 | layers += [] if self.num_groups is None else [self.get_norm_layer(nh*32)] 332 | layers += [get_activation(activation), self.conv1,] 333 | layers += [] if self.num_groups is None else [self.get_norm_layer(nh*16)] 334 | layers += [get_activation(activation), self.conv2,] 335 | layers += [] if self.num_groups is None else [self.get_norm_layer(nh*8)] 336 | layers += [get_activation(activation), self.conv3] 337 | out_activation = get_activation(out_activation) 338 | if out_activation is not None: 339 | layers.append(out_activation) 340 | 341 | self.net = nn.Sequential(*layers) 342 | 343 | def forward(self, x): 344 | return self.net(x) 345 | 346 | def get_norm_layer(self, num_channels): 347 | if 
self.num_groups is not None: 348 | return nn.GroupNorm(num_groups=self.num_groups, num_channels=num_channels) 349 | # elif self.use_bn: 350 | # return nn.BatchNorm2d(num_channels) 351 | else: 352 | return None 353 | 354 | 355 | class IGEBMEncoder(nn.Module): 356 | """Neural Network used in IGEBM""" 357 | def __init__(self, in_chan=3, out_chan=1, n_class=None, use_spectral_norm=False, keepdim=True): 358 | super().__init__() 359 | self.keepdim = keepdim 360 | self.use_spectral_norm = use_spectral_norm 361 | 362 | if use_spectral_norm: 363 | self.conv1 = spectral_norm(nn.Conv2d(in_chan, 128, 3, padding=1), std=1) 364 | else: 365 | self.conv1 = nn.Conv2d(in_chan, 128, 3, padding=1) 366 | 367 | self.blocks = nn.ModuleList( 368 | [ 369 | igebm.ResBlock(128, 128, n_class, downsample=True, use_spectral_norm=use_spectral_norm), 370 | igebm.ResBlock(128, 128, n_class, use_spectral_norm=use_spectral_norm), 371 | igebm.ResBlock(128, 256, n_class, downsample=True, use_spectral_norm=use_spectral_norm), 372 | igebm.ResBlock(256, 256, n_class, use_spectral_norm=use_spectral_norm), 373 | igebm.ResBlock(256, 256, n_class, downsample=True, use_spectral_norm=use_spectral_norm), 374 | igebm.ResBlock(256, 256, n_class, use_spectral_norm=use_spectral_norm), 375 | ] 376 | ) 377 | 378 | if keepdim: 379 | self.linear = nn.Conv2d(256, out_chan, 1, 1, 0) 380 | else: 381 | self.linear = nn.Linear(256, out_chan) 382 | if use_spectral_norm: 383 | self.linear = spectral_norm(self.linear) 384 | 385 | def forward(self, input, class_id=None): 386 | out = self.conv1(input) 387 | 388 | out = F.leaky_relu(out, negative_slope=0.2) 389 | 390 | for block in self.blocks: 391 | out = block(out, class_id) 392 | 393 | out = F.relu(out) 394 | if self.keepdim: 395 | out = F.adaptive_avg_pool2d(out, (1,1)) 396 | else: 397 | out = out.view(out.shape[0], out.shape[1], -1).sum(2) 398 | 399 | out = self.linear(out) 400 | 401 | return out 402 | 403 | 404 | class SphericalActivation(nn.Module): 405 | def __init__(self): 406 | super().__init__() 407 | 408 | def forward(self, x): 409 | return x / x.norm(p=2, dim=1, keepdim=True) 410 | 411 | 412 | # Fully Connected Network 413 | def get_activation(s_act): 414 | if s_act == 'relu': 415 | return nn.ReLU(inplace=True) 416 | elif s_act == 'sigmoid': 417 | return nn.Sigmoid() 418 | elif s_act == 'softplus': 419 | return nn.Softplus() 420 | elif s_act == 'linear': 421 | return None 422 | elif s_act == 'tanh': 423 | return nn.Tanh() 424 | elif s_act == 'leakyrelu': 425 | return nn.LeakyReLU(0.2, inplace=True) 426 | elif s_act == 'softmax': 427 | return nn.Softmax(dim=1) 428 | elif s_act == 'spherical': 429 | return SphericalActivation() 430 | else: 431 | raise ValueError(f'Unexpected activation: {s_act}') 432 | 433 | 434 | class FCNet(nn.Module): 435 | """fully-connected network""" 436 | def __init__(self, in_dim, out_dim, l_hidden=(50,), activation='sigmoid', out_activation='linear', 437 | use_spectral_norm=False): 438 | super().__init__() 439 | l_neurons = tuple(l_hidden) + (out_dim,) 440 | if isinstance(activation, str): 441 | activation = (activation,) * len(l_hidden) 442 | activation = tuple(activation) + (out_activation,) 443 | 444 | l_layer = [] 445 | prev_dim = in_dim 446 | for i_layer, (n_hidden, act) in enumerate(zip(l_neurons, activation)): 447 | if use_spectral_norm and i_layer < len(l_neurons) - 1: # don't apply SN to the last layer 448 | l_layer.append(spectral_norm(nn.Linear(prev_dim, n_hidden))) 449 | else: 450 | l_layer.append(nn.Linear(prev_dim, n_hidden)) 451 | act_fn = 
get_activation(act) 452 | if act_fn is not None: 453 | l_layer.append(act_fn) 454 | prev_dim = n_hidden 455 | 456 | self.net = nn.Sequential(*l_layer) 457 | self.in_dim = in_dim 458 | self.out_shape = (out_dim,) 459 | 460 | def forward(self, x): 461 | return self.net(x) 462 | 463 | 464 | class ConvMLP(nn.Module): 465 | def __init__(self, in_dim, out_dim, l_hidden=(50,), activation='sigmoid', out_activation='linear', 466 | likelihood_type='isotropic_gaussian'): 467 | super(ConvMLP, self).__init__() 468 | self.likelihood_type = likelihood_type 469 | l_neurons = tuple(l_hidden) + (out_dim,) 470 | activation = (activation,) * len(l_hidden) 471 | activation = tuple(activation) + (out_activation,) 472 | 473 | l_layer = [] 474 | prev_dim = in_dim 475 | for i_layer, (n_hidden, act) in enumerate(zip(l_neurons, activation)): 476 | l_layer.append(nn.Conv2d(prev_dim, n_hidden, 1, bias=True)) 477 | act_fn = get_activation(act) 478 | if act_fn is not None: 479 | l_layer.append(act_fn) 480 | prev_dim = n_hidden 481 | 482 | self.net = nn.Sequential(*l_layer) 483 | self.in_dim = in_dim 484 | 485 | def forward(self, x): 486 | return self.net(x) 487 | 488 | 489 | class FCResNet(nn.Module): 490 | """Fully-connected residual network 491 | Input - Linear - (ResBlock * K) - Linear - Output""" 492 | def __init__(self, in_dim, out_dim, res_dim, n_res_hidden=100, n_resblock=2, out_activation='linear', use_spectral_norm=False): 493 | super().__init__() 494 | l_layer = [] 495 | block = nn.Linear(in_dim, res_dim) 496 | if use_spectral_norm: 497 | block = spectral_norm(block) 498 | l_layer.append(block) 499 | 500 | for i_resblock in range(n_resblock): 501 | block = FCResBlock(res_dim, n_res_hidden, use_spectral_norm=use_spectral_norm) 502 | l_layer.append(block) 503 | l_layer.append(nn.ReLU()) 504 | 505 | block = nn.Linear(res_dim, out_dim) 506 | if use_spectral_norm: 507 | block = spectral_norm(block) 508 | l_layer.append(block) 509 | out_activation = get_activation(out_activation) 510 | if out_activation is not None: 511 | l_layer.append(out_activation) 512 | self.net = nn.Sequential(*l_layer) 513 | 514 | def forward(self, x): 515 | return self.net(x) 516 | 517 | 518 | class FCResBlock(nn.Module): 519 | def __init__(self, res_dim, n_res_hidden, use_spectral_norm=False): 520 | super().__init__() 521 | if use_spectral_norm: 522 | self.net = nn.Sequential(nn.ReLU(), 523 | spectral_norm(nn.Linear(res_dim, n_res_hidden)), 524 | nn.ReLU(), 525 | spectral_norm(nn.Linear(n_res_hidden, res_dim))) 526 | else: 527 | self.net = nn.Sequential(nn.ReLU(), 528 | nn.Linear(res_dim, n_res_hidden), 529 | nn.ReLU(), 530 | nn.Linear(n_res_hidden, res_dim)) 531 | 532 | def forward(self, x): 533 | return x + self.net(x) 534 | -------------------------------------------------------------------------------- /models/modules_sngan.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet architectures from SNGAN (Miyato et al., 2018) 3 | Adapted from: 4 | https://github.com/christiancosgrove/pytorch-spectral-normalization-gan 5 | """ 6 | # ResNet generator and discriminator 7 | from torch import nn 8 | from torch import Tensor 9 | from torch.nn import Parameter 10 | import torch.nn.functional as F 11 | import torch  # used below via torch.mv/torch.t; 'import torch.nn.functional as F' does not bind 'torch' 12 | import numpy as np 13 | 14 | 15 | def l2normalize(v, eps=1e-12): 16 | return v / (v.norm() + eps) 17 | 18 | 19 | class SpectralNorm(nn.Module): 20 | def __init__(self, module, name='weight', power_iterations=1): 21 | super(SpectralNorm, self).__init__() 22 | self.module = module 23 |
self.name = name 24 | self.power_iterations = power_iterations 25 | if not self._made_params(): 26 | self._make_params() 27 | 28 | def _update_u_v(self): 29 | u = getattr(self.module, self.name + "_u") 30 | v = getattr(self.module, self.name + "_v") 31 | w = getattr(self.module, self.name + "_bar") 32 | 33 | height = w.data.shape[0] 34 | for _ in range(self.power_iterations): 35 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 36 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 37 | 38 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 39 | sigma = u.dot(w.view(height, -1).mv(v)) 40 | setattr(self.module, self.name, w / sigma.expand_as(w)) 41 | 42 | def _made_params(self): 43 | try: 44 | u = getattr(self.module, self.name + "_u") 45 | v = getattr(self.module, self.name + "_v") 46 | w = getattr(self.module, self.name + "_bar") 47 | return True 48 | except AttributeError: 49 | return False 50 | 51 | 52 | def _make_params(self): 53 | w = getattr(self.module, self.name) 54 | 55 | height = w.data.shape[0] 56 | width = w.view(height, -1).data.shape[1] 57 | 58 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 59 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 60 | u.data = l2normalize(u.data) 61 | v.data = l2normalize(v.data) 62 | w_bar = Parameter(w.data) 63 | 64 | del self.module._parameters[self.name] 65 | 66 | self.module.register_parameter(self.name + "_u", u) 67 | self.module.register_parameter(self.name + "_v", v) 68 | self.module.register_parameter(self.name + "_bar", w_bar) 69 | 70 | 71 | def forward(self, *args): 72 | self._update_u_v() 73 | return self.module.forward(*args) 74 | 75 | 76 | 77 | 78 | channels = 3 79 | 80 | class ResBlockGenerator(nn.Module): 81 | 82 | def __init__(self, in_channels, out_channels, stride=1): 83 | super(ResBlockGenerator, self).__init__() 84 | 85 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 86 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 87 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.) 88 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.) 89 | 90 | self.model = nn.Sequential( 91 | nn.BatchNorm2d(in_channels), 92 | nn.ReLU(), 93 | nn.Upsample(scale_factor=2), 94 | self.conv1, 95 | nn.BatchNorm2d(out_channels), 96 | nn.ReLU(), 97 | self.conv2 98 | ) 99 | self.bypass = nn.Sequential() 100 | if stride != 1: 101 | self.bypass = nn.Upsample(scale_factor=2) 102 | 103 | def forward(self, x): 104 | return self.model(x) + self.bypass(x) 105 | 106 | 107 | class ResBlockGeneratorNoBN(nn.Module): 108 | 109 | def __init__(self, in_channels, out_channels, stride=1): 110 | super(ResBlockGeneratorNoBN, self).__init__() 111 | 112 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 113 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 114 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.) 115 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.) 
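        # --- editor's note (annotation; not part of the original source) ------
        # Same pre-activation residual block as ResBlockGenerator above, with
        # the BatchNorm layers removed: the main path is ReLU -> 2x Upsample ->
        # conv1 -> ReLU -> conv2, and when stride != 1 the bypass is a plain 2x
        # Upsample, so the residual sum doubles the spatial resolution.
        # -----------------------------------------------------------------------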
116 | 117 | self.model = nn.Sequential( 118 | nn.ReLU(), 119 | nn.Upsample(scale_factor=2), 120 | self.conv1, 121 | nn.ReLU(), 122 | self.conv2 123 | ) 124 | self.bypass = nn.Sequential() 125 | if stride != 1: 126 | self.bypass = nn.Upsample(scale_factor=2) 127 | 128 | def forward(self, x): 129 | return self.model(x) + self.bypass(x) 130 | 131 | 132 | class ResBlockGeneratorGN(nn.Module): 133 | 134 | def __init__(self, in_channels, out_channels, stride=1, num_groups=1): 135 | super().__init__() 136 | 137 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 138 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 139 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.) 140 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.) 141 | 142 | self.model = nn.Sequential( 143 | nn.GroupNorm(num_groups=num_groups, num_channels=in_channels), 144 | nn.ReLU(), 145 | nn.Upsample(scale_factor=2), 146 | self.conv1, 147 | nn.GroupNorm(num_groups=num_groups, num_channels=out_channels), 148 | nn.ReLU(), 149 | self.conv2 150 | ) 151 | self.bypass = nn.Sequential() 152 | if stride != 1: 153 | self.bypass = nn.Upsample(scale_factor=2) 154 | 155 | def forward(self, x): 156 | return self.model(x) + self.bypass(x) 157 | 158 | 159 | 160 | class ResBlockDiscriminator(nn.Module): 161 | 162 | def __init__(self, in_channels, out_channels, stride=1): 163 | super(ResBlockDiscriminator, self).__init__() 164 | 165 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 166 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 167 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.) 168 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.) 169 | 170 | if stride == 1: 171 | self.model = nn.Sequential( 172 | nn.ReLU(), 173 | SpectralNorm(self.conv1), 174 | nn.ReLU(), 175 | SpectralNorm(self.conv2) 176 | ) 177 | else: 178 | self.model = nn.Sequential( 179 | nn.ReLU(), 180 | SpectralNorm(self.conv1), 181 | nn.ReLU(), 182 | SpectralNorm(self.conv2), 183 | nn.AvgPool2d(2, stride=stride, padding=0) 184 | ) 185 | self.bypass = nn.Sequential() 186 | if stride != 1: 187 | 188 | self.bypass_conv = nn.Conv2d(in_channels,out_channels, 1, 1, padding=0) 189 | nn.init.xavier_uniform_(self.bypass_conv.weight.data, np.sqrt(2)) 190 | 191 | self.bypass = nn.Sequential( 192 | SpectralNorm(self.bypass_conv), 193 | nn.AvgPool2d(2, stride=stride, padding=0) 194 | ) 195 | # if in_channels == out_channels: 196 | # self.bypass = nn.AvgPool2d(2, stride=stride, padding=0) 197 | # else: 198 | # self.bypass = nn.Sequential( 199 | # SpectralNorm(nn.Conv2d(in_channels,out_channels, 1, 1, padding=0)), 200 | # nn.AvgPool2d(2, stride=stride, padding=0) 201 | # ) 202 | 203 | 204 | def forward(self, x): 205 | return self.model(x) + self.bypass(x) 206 | 207 | # special ResBlock just for the first layer of the discriminator 208 | class FirstResBlockDiscriminator(nn.Module): 209 | 210 | def __init__(self, in_channels, out_channels, stride=1): 211 | super(FirstResBlockDiscriminator, self).__init__() 212 | 213 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 214 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 215 | self.bypass_conv = nn.Conv2d(in_channels, out_channels, 1, 1, padding=0) 216 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.) 217 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.) 
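        # --- editor's note (annotation; not part of the original source) ------
        # The two residual-path convs above get Xavier gain 1.0, while the 1x1
        # bypass conv below gets gain sqrt(2), matching the initialization in
        # the SNGAN reference implementation credited at the top of this file.
        # -----------------------------------------------------------------------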
218 | nn.init.xavier_uniform_(self.bypass_conv.weight.data, np.sqrt(2)) 219 | 220 | # we don't want to apply ReLU activation to raw image before convolution transformation. 221 | self.model = nn.Sequential( 222 | SpectralNorm(self.conv1), 223 | nn.ReLU(), 224 | SpectralNorm(self.conv2), 225 | nn.AvgPool2d(2) 226 | ) 227 | self.bypass = nn.Sequential( 228 | nn.AvgPool2d(2), 229 | SpectralNorm(self.bypass_conv), 230 | ) 231 | 232 | def forward(self, x): 233 | return self.model(x) + self.bypass(x) 234 | 235 | GEN_SIZE=128 236 | DISC_SIZE=128 237 | 238 | class Generator(nn.Module): 239 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None): 240 | super().__init__() 241 | self.z_dim = z_dim 242 | 243 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4) 244 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1) 245 | nn.init.xavier_uniform_(self.dense.weight.data, 1.) 246 | nn.init.xavier_uniform_(self.final.weight.data, 1.) 247 | 248 | l_layers = [ 249 | ResBlockGenerator(hidden_dim, hidden_dim, stride=2), 250 | ResBlockGenerator(hidden_dim, hidden_dim, stride=2), 251 | ResBlockGenerator(hidden_dim, hidden_dim, stride=2), 252 | # nn.BatchNorm2d(hidden_dim), # should be uncommented. temporarily commented for backward compatibility 253 | nn.ReLU(), 254 | self.final, 255 | ] 256 | 257 | if out_activation == 'sigmoid': 258 | l_layers.append(nn.Sigmoid()) 259 | elif out_activation == 'tanh': 260 | l_layers.append(nn.Tanh()) 261 | self.model = nn.Sequential(*l_layers) 262 | 263 | def forward(self, z): 264 | return self.model(self.dense(z)) 265 | 266 | 267 | class GeneratorNoBN(nn.Module): 268 | """remove batch normalization 269 | z vector is assumed to be 4-dimensional (fully-convolutional)""" 270 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None): 271 | super().__init__() 272 | self.z_dim = z_dim 273 | 274 | # self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE) 275 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4) 276 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1) 277 | nn.init.xavier_uniform_(self.dense.weight.data, 1.) 278 | nn.init.xavier_uniform_(self.final.weight.data, 1.) 279 | 280 | l_layers = [ 281 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 282 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 283 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 284 | nn.ReLU(), 285 | self.final] 286 | 287 | if out_activation == 'sigmoid': 288 | l_layers.append(nn.Sigmoid()) 289 | elif out_activation == 'tanh': 290 | l_layers.append(nn.Tanh()) 291 | self.model = nn.Sequential(*l_layers) 292 | 293 | def forward(self, z): 294 | # return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4)) 295 | return self.model(self.dense(z)) 296 | 297 | 298 | class GeneratorNoBN64(nn.Module): 299 | """remove batch normalization 300 | z vector is assumed to be 4-dimensional (fully-convolutional) 301 | for generating 64x64 output""" 302 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None): 303 | super().__init__() 304 | self.z_dim = z_dim 305 | 306 | # self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE) 307 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4) 308 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1) 309 | nn.init.xavier_uniform_(self.dense.weight.data, 1.) 310 | nn.init.xavier_uniform_(self.final.weight.data, 1.) 
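        # --- editor's note (annotation; not part of the original source) ------
        # Unlike GeneratorNoBN above, this 64x64 variant stacks four upsampling
        # residual blocks, so the 4x4 map produced by `dense` grows
        # 4 -> 8 -> 16 -> 32 -> 64 pixels before the final 3x3 conv.
        # -----------------------------------------------------------------------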
311 | 312 | l_layers = [ 313 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 314 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 315 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 316 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2), 317 | nn.ReLU(), 318 | self.final] 319 | 320 | if out_activation == 'sigmoid': 321 | l_layers.append(nn.Sigmoid()) 322 | elif out_activation == 'tanh': 323 | l_layers.append(nn.Tanh()) 324 | self.model = nn.Sequential(*l_layers) 325 | 326 | def forward(self, z): 327 | # return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4)) 328 | return self.model(self.dense(z)) 329 | 330 | 331 | class GeneratorGN(nn.Module): 332 | """replaces Batch Normalization with Group Normalization 333 | z vector is assumed to be 4-dimensional (fully-convolutional)""" 334 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None, num_groups=1): 335 | super().__init__() 336 | self.z_dim = z_dim 337 | 338 | # self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE) 339 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4) 340 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1) 341 | nn.init.xavier_uniform_(self.dense.weight.data, 1.) 342 | nn.init.xavier_uniform_(self.final.weight.data, 1.) 343 | 344 | l_layers = [ 345 | ResBlockGeneratorGN(hidden_dim, hidden_dim, stride=2), 346 | ResBlockGeneratorGN(hidden_dim, hidden_dim, stride=2), 347 | ResBlockGeneratorGN(hidden_dim, hidden_dim, stride=2), 348 | nn.GroupNorm(num_groups=num_groups, num_channels=hidden_dim), 349 | nn.ReLU(), 350 | self.final] 351 | 352 | if out_activation == 'sigmoid': 353 | l_layers.append(nn.Sigmoid()) 354 | elif out_activation == 'tanh': 355 | l_layers.append(nn.Tanh()) 356 | self.model = nn.Sequential(*l_layers) 357 | 358 | def forward(self, z): 359 | # return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4)) 360 | return self.model(self.dense(z)) 361 | 362 | 363 | 364 | 365 | class Discriminator(nn.Module): 366 | def __init__(self): 367 | super(Discriminator, self).__init__() 368 | 369 | self.model = nn.Sequential( 370 | FirstResBlockDiscriminator(channels, DISC_SIZE, stride=2), 371 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE, stride=2), 372 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE), 373 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE), 374 | nn.ReLU(), 375 | nn.AvgPool2d(8), 376 | ) 377 | self.fc = nn.Linear(DISC_SIZE, 1) 378 | nn.init.xavier_uniform_(self.fc.weight.data, 1.)
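        # --- editor's note (annotation; not part of the original source) ------
        # The linear head is initialized first and only then wrapped below:
        # SpectralNorm re-registers the weight as `weight_bar`, so a later
        # nn.init call on self.fc.weight would no longer reach the parameter.
        # -----------------------------------------------------------------------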
379 | self.fc = SpectralNorm(self.fc) 380 | 381 | def forward(self, x): 382 | return self.fc(self.model(x).view(-1,DISC_SIZE)) 383 | -------------------------------------------------------------------------------- /models/spectral_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import utils 6 | 7 | 8 | class SpectralNorm: 9 | def __init__(self, name, bound=False): 10 | self.name = name 11 | self.bound = bound 12 | 13 | def compute_weight(self, module): 14 | weight = getattr(module, self.name + '_orig') 15 | u = getattr(module, self.name + '_u') 16 | size = weight.size() 17 | weight_mat = weight.contiguous().view(size[0], -1) 18 | 19 | with torch.no_grad(): 20 | v = weight_mat.t() @ u 21 | v = v / v.norm() 22 | u = weight_mat @ v 23 | u = u / u.norm() 24 | 25 | sigma = u @ weight_mat @ v 26 | 27 | if self.bound: 28 | weight_sn = weight / (sigma + 1e-6) * torch.clamp(sigma, max=1) 29 | 30 | else: 31 | weight_sn = weight / sigma 32 | 33 | return weight_sn, u 34 | 35 | @staticmethod 36 | def apply(module, name, bound): 37 | fn = SpectralNorm(name, bound) 38 | 39 | weight = getattr(module, name) 40 | del module._parameters[name] 41 | module.register_parameter(name + '_orig', weight) 42 | input_size = weight.size(0) 43 | u = weight.new_empty(input_size).normal_() 44 | module.register_buffer(name, weight) 45 | module.register_buffer(name + '_u', u) 46 | 47 | module.register_forward_pre_hook(fn) 48 | 49 | return fn 50 | 51 | def __call__(self, module, input): 52 | weight_sn, u = self.compute_weight(module) 53 | setattr(module, self.name, weight_sn) 54 | setattr(module, self.name + '_u', u) 55 | 56 | 57 | def spectral_norm(module, init=True, std=1, bound=False): 58 | if init: 59 | nn.init.normal_(module.weight, 0, std) 60 | 61 | if hasattr(module, 'bias') and module.bias is not None: 62 | module.bias.data.zero_() 63 | 64 | SpectralNorm.apply(module, 'weight', bound=bound) 65 | 66 | return module 67 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop 4 | 5 | logger = logging.getLogger("ptsemseg") 6 | 7 | key2opt = { 8 | "sgd": SGD, 9 | "adam": Adam, 10 | "asgd": ASGD, 11 | "adamax": Adamax, 12 | "adadelta": Adadelta, 13 | "adagrad": Adagrad, 14 | "rmsprop": RMSprop, 15 | } 16 | 17 | 18 | def get_optimizer(opt_dict, model_params): 19 | opt_dict = opt_dict.copy() 20 | optimizer = _get_optimizer_instance(opt_dict) 21 | 22 | # params = {k: v for k, v in opt_dict.items() if k != "name"} 23 | opt_dict.pop('name') 24 | 25 | optimizer = optimizer(model_params, **opt_dict) 26 | 27 | return optimizer, None 28 | 29 | 30 | def _get_optimizer_instance(opt_dict): 31 | if opt_dict is None: 32 | logger.info("Using SGD optimizer") 33 | return SGD 34 | 35 | else: 36 | opt_name = opt_dict["name"] 37 | if opt_name not in key2opt: 38 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name)) 39 | 40 | logger.info("Using {} optimizer".format(opt_name)) 41 | return key2opt[opt_name] 42 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/pretrained/.gitkeep -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate samples 3 | """ 4 | import os 5 | import numpy as np 6 | import torch 7 | from models import get_model 8 | from omegaconf import OmegaConf 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('result_dir', type=str, help='path to the directory where yaml file and checkpoint is stored.') 15 | parser.add_argument('config', type=str, help='the name of configs file.') 16 | parser.add_argument('ckpt', type=str, help='the name of the checkpoint file.') 17 | parser.add_argument('--device', default=0, type=int, help='the id of cuda device to use') 18 | parser.add_argument('--n_sample', default=1000, type=int, help='the number of samples to be generated') 19 | parser.add_argument('--zstep', default=None, type=int, help='the number of the latent Langevin MC steps') 20 | parser.add_argument('--xstep', default=None, type=int, help='the number of the visible Langevin MC steps') 21 | parser.add_argument('--x_shape', default=32, type=int, help='the size of a side of an image.') 22 | parser.add_argument('--x_channel', default=3, type=int, help='the number of channels in an image.') 23 | parser.add_argument('--batch_size', default=128, type=int) 24 | parser.add_argument('--name', default=None, type=str, help='additional identifier for the result file.') 25 | parser.add_argument('--replay', default=False, action='store_true', help='to use the sample replay buffer') 26 | args = parser.parse_args() 27 | 28 | 29 | config_path = os.path.join(args.result_dir, args.config) 30 | ckpt_path = os.path.join(args.result_dir, args.ckpt) 31 | device = f'cuda:{args.device}' 32 | print(f'using {device}') 33 | 34 | # load model 35 | print(f'building a model from {config_path}') 36 | cfg = OmegaConf.load(config_path) 37 | model = get_model(cfg) 38 | model_class = cfg['model']['arch'] 39 | 40 | print(f'loading a model from {ckpt_path}') 41 | state = torch.load(ckpt_path, map_location=device) 42 | if 'model_state' in state: 43 | state = state['model_state'] 44 | model.load_state_dict(state) 45 | model.to(device) 46 | model.eval() 47 | 48 | if model_class == 'nae': 49 | print(f'replay {args.replay}') 50 | model.replay = args.replay 51 | 52 | dummy_x = torch.rand(1, args.x_channel, args.x_shape, args.x_shape, dtype=torch.float).to(device) 53 | model._set_x_shape(dummy_x) 54 | model._set_z_shape(dummy_x) 55 | 56 | if args.zstep is not None: 57 | model.z_step = args.zstep 58 | print(f'z_step: {model.z_step}') 59 | if args.xstep is not None: 60 | model.x_step = args.xstep 61 | print(f'x_step: {model.x_step}') 62 | elif model_class == 'vae': 63 | dummy_x = torch.rand(1, 3, args.x_shape, args.x_shape, dtype=torch.float).to(device) 64 | model._set_z_shape(dummy_x) 65 | 66 | 67 | # run sampling 68 | batch_size = args.batch_size 69 | n_batch = int(np.ceil(args.n_sample / batch_size)) 70 | l_sample = [] 71 | for i_batch in tqdm(range(n_batch)): 72 | if i_batch == n_batch - 1: 73 | n_sample = args.n_sample % batch_size if args.n_sample % batch_size else batch_size 74 | else: 75 | n_sample = batch_size 76 | d_sample = model.sample(n_sample=n_sample, device=device) 77 | sample = d_sample['sample_x'].detach() 78 | 79 | # re-quantization 80 | # sample = 
(sample * 255 + 0.5).clamp(0, 255).cpu().to(torch.uint8) 81 | sample = (sample * 256).clamp(0, 255).cpu().to(torch.uint8) 82 | 83 | sample = sample.permute(0, 2, 3, 1).numpy() 84 | 85 | 86 | l_sample.append(sample) 87 | sample = np.concatenate(l_sample) 88 | print(f'sample shape: {sample.shape}') 89 | 90 | # save result; os.path.splitext drops the extension safely (str.strip('.pkl') strips characters, not a suffix) 91 | if args.name is not None: 92 | out_name = os.path.join(args.result_dir, f'{os.path.splitext(args.ckpt)[0]}_sample_{args.name}.npy') 93 | else: 94 | out_name = os.path.join(args.result_dir, f'{os.path.splitext(args.ckpt)[0]}_sample.npy') 95 | np.save(out_name, sample) 96 | print(f'sample saved at {out_name}') 97 | 98 | 99 | -------------------------------------------------------------------------------- /tests/test_load_pretrained.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from models import load_pretrained 3 | 4 | # (identifier, config_file, ckpt_file, kwargs) 5 | cifar_ood_nae = ('cifar_ood_nae/z32gn', 'z32gn.yml', 'nae_8.pkl', {}) 6 | cifar_ood_nae_ae = ('cifar_ood_nae/z32gn', 'z32gn.yml', 'model_best.pkl', {}) # NAE before training 7 | 8 | mnist_ood_nae = ('mnist_ood_nae/z32', 'z32.yml', 'nae_20.pkl', {}) 9 | mnist_ood_nae_ae = ('mnist_ood_nae/z32', 'z32.yml', 'model_best.pkl', {}) # NAE before training 10 | 11 | celeba64_ood_nae = ('celeba64_ood_nae/z64gr_h32g8', 'z64gr_h32g8.yml', 'nae_3.pkl', {}) 12 | celeba64_ood_nae_ae = ('celeba64_ood_nae/z64gr_h32g8', 'z64gr_h32g8.yml', 'model_best.pkl', {}) # NAE before training 13 | l_setting = [cifar_ood_nae, cifar_ood_nae_ae, 14 | mnist_ood_nae, mnist_ood_nae_ae, 15 | celeba64_ood_nae, celeba64_ood_nae_ae] 16 | 17 | 18 | @pytest.mark.parametrize('model_setting', l_setting) 19 | def test_load_pretrained(model_setting): 20 | identifier, config_file, ckpt_file, kwargs = model_setting 21 | model, cfg = load_pretrained(identifier, config_file, ckpt_file, **kwargs) 22 | -------------------------------------------------------------------------------- /tests/test_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from loaders import get_dataset, get_dataloader 3 | from torch.utils import data 4 | import yaml 5 | import pytest 6 | from skimage import io 7 | import pickle 8 | import torch 9 | import os 10 | 11 | 12 | def test_get_dataloader(): 13 | cfg = {'dataset': 'FashionMNISTpad_OOD', 14 | 'path': 'datasets', 15 | 'shuffle': True, 16 | 'n_workers': 0, 17 | 'batch_size': 1, 18 | 'split': 'training'} 19 | dl = get_dataloader(cfg) 20 | 21 | 22 | def test_concat_dataset(): 23 | data_cfg = {'concat1': 24 | {'dataset': 'FashionMNISTpad_OOD', 25 | 'path': 'datasets', 26 | 'shuffle': True, 27 | 'split': 'training'}, 28 | 'concat2': 29 | {'dataset': 'MNISTpad_OOD', 30 | 'path': 'datasets', 31 | 'shuffle': True, 32 | 'n_workers': 0, 33 | 'batch_size': 1, 34 | 'split': 'training'}, 35 | 'n_workers': 0, 36 | 'batch_size': 1, 37 | } 38 | get_dataset(data_cfg) 39 | 40 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.optim import Adam 4 | from models.modules import FCNet, FCResNet, IsotropicGaussian 5 | from models.modules import ConvNet64, DeConvNet64 6 | from models.modules_sngan import GeneratorNoBN 7 | 8 | @pytest.mark.parametrize('out_activation', ['linear', 'spherical']) 9 | def test_fcnet(out_activation): 10 | net = FCNet(in_dim=1,
out_dim=5, l_hidden=(50,), activation='sigmoid', out_activation=out_activation) 11 | X = torch.rand(10, 1) 12 | Z = net(X) 13 | 14 | 15 | 16 | def test_sngan_generator(): 17 | z_dim = 4 18 | channel = 3 19 | z = torch.rand(3, z_dim, 1, 1, dtype=torch.float) 20 | 21 | net = GeneratorNoBN(z_dim, channel, hidden_dim=32, out_activation='sigmoid') 22 | out = net(z) 23 | assert out.shape == (3, 3, 32, 32) 24 | 25 | 26 | def test_fcresnet(): 27 | net = FCResNet(in_dim=2, out_dim=2, res_dim=20, n_res_hidden=100, n_resblock=2) 28 | 29 | x = torch.rand((20, 2)) 30 | out = net(x) 31 | assert out.shape == (20, 2) 32 | 33 | 34 | @pytest.mark.parametrize('distribution', ['IsotropicGaussian']) 35 | def test_distribution_modules(distribution): 36 | if distribution == 'IsotropicGaussian': 37 | dist = IsotropicGaussian 38 | 39 | N = 10; z_dim = 3; x_dim = 5 40 | z = torch.rand(N, z_dim) 41 | x = torch.rand(N, x_dim) 42 | net = FCNet(in_dim=z_dim, out_dim=x_dim, l_hidden=(50,), activation='sigmoid', out_activation='sigmoid') 43 | net = dist(net) 44 | lik = net.log_likelihood(x, z) 45 | assert lik.shape == (N,) 46 | samples = net.sample(z) 47 | assert samples.shape == (N, x_dim) 48 | 49 | 50 | @pytest.mark.parametrize('num_groups', [None, 2]) 51 | def test_convnet64(num_groups): 52 | x = torch.rand(2, 3, 64, 64) 53 | encoder = ConvNet64(num_groups=num_groups) 54 | decoder = DeConvNet64(num_groups=num_groups) 55 | recon = decoder(encoder(x)) 56 | assert x.shape == recon.shape 57 | 58 | -------------------------------------------------------------------------------- /tests/test_nae.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.optim import Adam 4 | from models.modules import FCNet 5 | from models.nae import NAE, FFEBM, FFEBMV2, NAE_L2_CD, NAE_L2_NCE, NAE_L2_OMI 6 | from models import get_model 7 | from models.mcmc import LangevinSampler, MHSampler, NoiseSampler 8 | from omegaconf import OmegaConf 9 | 10 | 11 | def test_ffebm(): 12 | net = FCNet(2, 1, out_activation='linear') 13 | model = FFEBM(net, x_step=3, x_stepsize=1., sampling='x') 14 | X = torch.randn((10, 2), dtype=torch.float) 15 | opt = Adam(model.parameters(), lr=1e-4) 16 | 17 | # forward 18 | lik = model.predict(X) 19 | 20 | # sample 21 | model._set_x_shape(X) 22 | d_sample = model.sample_x(10, 'cpu:0') 23 | 24 | # training 25 | model.train_step(X, opt) 26 | 27 | 28 | @pytest.mark.parametrize('sampling', ['x', 'on_manifold', 'cd']) 29 | def test_nae(sampling): 30 | encoder = FCNet(2, 1) 31 | decoder = FCNet(1, 2) 32 | nae = NAE(encoder, decoder, initial_dist='gaussian', sampling=sampling) 33 | opt = Adam(nae.parameters(), lr=1e-4) 34 | 35 | X = torch.randn((10, 2), dtype=torch.float) 36 | 37 | # forward 38 | lik = nae.predict(X) 39 | 40 | # sample 41 | nae._set_x_shape(X) 42 | nae._set_z_shape(X) 43 | d_sample = nae.sample(X) 44 | 45 | # training 46 | nae.train_step(X, opt) 47 | 48 | 49 | def test_cifar(): 50 | cfg = OmegaConf.load('configs/cifar_ood_nae/z32gn.yml') 51 | model = get_model(cfg) 52 | xx = torch.rand(2, 3, 32, 32) 53 | recon = model.reconstruct(xx) 54 | error = model(xx) 55 | 56 | assert recon.shape == (2, 3, 32, 32) 57 | assert error.shape == (2,) 58 | 59 | 60 | @pytest.mark.parametrize('sampling', ['cd', 'x']) 61 | @pytest.mark.parametrize('sampler', ['mh', 'langevin']) 62 | def test_ffebm_v2(sampling, sampler): 63 | net = FCNet(2, 1, out_activation='linear') 64 | if sampler == 'mh': 65 | sampler = MHSampler(n_step=2, stepsize=0.1, bound=(-5, 5), 
buffer_size=10000, replay_ratio=0.95, reject_boundary=False, 66 | mh=True, initial_dist='uniform') 67 | elif sampler == 'langevin': 68 | sampler = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=1., 69 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.95, reject_boundary=False, 70 | mh=True, initial_dist='uniform') 71 | model = FFEBMV2(net, sampler, gamma=1, sampling=sampling) 72 | X = torch.randn((10, 2), dtype=torch.float) 73 | opt = Adam(model.parameters(), lr=1e-4) 74 | 75 | # forward 76 | lik = model.predict(X) 77 | 78 | # sample 79 | model._set_x_shape(X) 80 | d_sample = model.sample(n_sample=10, device='cpu:0', sampling='x') 81 | 82 | # training 83 | model.train_step(X, opt) 84 | 85 | 86 | @pytest.mark.parametrize('sampling', ['cd', 'x']) 87 | def test_nae_l2_cd(sampling): 88 | encoder = FCNet(2, 1, out_activation='spherical') 89 | decoder = FCNet(1, 2) 90 | sampler = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=1., 91 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.95, reject_boundary=False, 92 | mh=True, initial_dist='uniform') 93 | nae = NAE_L2_CD(encoder, decoder, sampler, sampling=sampling) 94 | opt = Adam(nae.parameters(), lr=1e-4) 95 | 96 | X = torch.randn((10, 2), dtype=torch.float) 97 | 98 | # forward 99 | lik = nae.predict(X) 100 | 101 | # sample 102 | nae._set_x_shape(X) 103 | # nae._set_z_shape(X) 104 | d_sample = nae.sample(X) 105 | 106 | # training 107 | nae.train_step(X, opt) 108 | 109 | 110 | def test_nae_l2_nce(): 111 | encoder = FCNet(2, 1, out_activation='spherical') 112 | decoder = FCNet(1, 2) 113 | sampler = NoiseSampler(dist='gaussian', shape=(2,), offset=1., scale=1.) 114 | nae = NAE_L2_NCE(encoder, decoder, sampler, T=0.5, T_trainable=True) 115 | opt = Adam(nae.parameters(), lr=1e-4) 116 | 117 | X = torch.randn((10, 2), dtype=torch.float) 118 | 119 | # forward 120 | lik = nae.predict(X) 121 | 122 | # training 123 | nae.train_step(X, opt) 124 | 125 | 126 | @pytest.mark.parametrize('spherical', [True, False]) 127 | def test_nae_l2_omi(spherical): 128 | encoder = FCNet(2, 3, out_activation='linear') 129 | decoder = FCNet(3, 2) 130 | sampler_x = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=1., 131 | bound=(-5, 5), buffer_size=10000, replay_ratio=0., 132 | reject_boundary=False, mh=True, initial_dist='uniform') 133 | 134 | sampler_z = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=None, 135 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.95, 136 | reject_boundary=False, mh=True, initial_dist='uniform') 137 | 138 | nae = NAE_L2_OMI(encoder, decoder, sampler_z, sampler_x, T=0.5, T_trainable=True, spherical=spherical) 139 | opt = Adam(nae.parameters(), lr=1e-4) 140 | 141 | X = torch.randn((10, 2), dtype=torch.float) 142 | 143 | # forward 144 | lik = nae.predict(X) 145 | 146 | # training 147 | nae.train_step(X, opt) 148 | 149 | 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import argparse 5 | from omegaconf import OmegaConf 6 | import numpy as np 7 | from itertools import cycle 8 | import torch 9 | from models import get_model 10 | from trainers import get_trainer, get_logger 11 | from loaders import get_dataloader 12 | from optimizers import get_optimizer 13 | from datetime import datetime 14 | from tensorboardX import SummaryWriter 15 | from utils import save_yaml, search_params_intp, eprint, 
parse_unknown_args, parse_nested_args 16 | 17 | 18 | def run(cfg, writer): 19 | """main training function""" 20 | # Setup seeds 21 | seed = cfg.get('seed', 1) 22 | print(f'running with random seed: {seed}') 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | # for reproducibility 28 | # torch.backends.cudnn.deterministic = True 29 | # torch.backends.cudnn.benchmark = False 30 | 31 | # Setup device 32 | device = cfg.device 33 | 34 | # Setup Dataloader 35 | d_dataloaders = {} 36 | for key, dataloader_cfg in cfg['data'].items(): 37 | if 'holdout' in cfg: 38 | dataloader_cfg = process_holdout(dataloader_cfg, int(cfg['holdout'])) 39 | d_dataloaders[key] = get_dataloader(dataloader_cfg) 40 | 41 | # Setup Model 42 | model = get_model(cfg).to(device) 43 | trainer = get_trainer(cfg) 44 | logger = get_logger(cfg, writer) 45 | 46 | # Setup optimizer 47 | if hasattr(model, 'own_optimizer') and model.own_optimizer: 48 | optimizer, sch = model.get_optimizer(cfg['training']['optimizer']) 49 | elif 'optimizer' not in cfg['training']: 50 | optimizer = None 51 | sch = None 52 | else: 53 | optimizer, sch = get_optimizer(cfg["training"]["optimizer"], model.parameters()) 54 | 55 | model, train_result = trainer.train(model, optimizer, d_dataloaders, logger=logger, 56 | logdir=writer.file_writer.get_logdir(), scheduler=sch, 57 | clip_grad=cfg['training'].get('clip_grad', None)) 58 | 59 | 60 | 61 | def process_holdout(dataloader_cfg, holdout): 62 | """update config if the holdout option is present in the config""" 63 | if 'LeaveOut' in dataloader_cfg['dataset'] and 'out_class' in dataloader_cfg: 64 | if len(dataloader_cfg['out_class']) == 1: # indist 65 | dataloader_cfg['out_class'] = [holdout] 66 | else: # ood 67 | dataloader_cfg['out_class'] = [i for i in range(10) if i != holdout] 68 | print(dataloader_cfg) 69 | return dataloader_cfg 70 | 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--config', type=str) 76 | parser.add_argument('--device', default=0) 77 | parser.add_argument('--logdir', default='results/') 78 | parser.add_argument('--run', default=None, help='unique run id of the experiment') 79 | args, unknown = parser.parse_known_args() 80 | d_cmd_cfg = parse_unknown_args(unknown) 81 | d_cmd_cfg = parse_nested_args(d_cmd_cfg) 82 | print(d_cmd_cfg) 83 | cfg = OmegaConf.load(args.config) 84 | if args.device == 'cpu': 85 | cfg['device'] = 'cpu' 86 | else: 87 | cfg['device'] = f'cuda:{args.device}' 88 | 89 | if args.run is None: 90 | run_id = datetime.now().strftime('%Y%m%d-%H%M') 91 | else: 92 | run_id = args.run 93 | cfg = OmegaConf.merge(cfg, d_cmd_cfg) 94 | print(OmegaConf.to_yaml(cfg)) 95 | 96 | config_basename = os.path.basename(args.config).split('.')[0] 97 | logdir = os.path.join(args.logdir, config_basename, str(run_id)) 98 | writer = SummaryWriter(logdir=logdir) 99 | print("Result directory: {}".format(logdir)) 100 | 101 | # copy config file 102 | copied_yml = os.path.join(logdir, os.path.basename(args.config)) 103 | save_yaml(copied_yml, OmegaConf.to_yaml(cfg)) 104 | print(f'config saved as {copied_yml}') 105 | 106 | run(cfg, writer) 107 | 108 | 109 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from trainers.logger import BaseLogger 2 | from trainers.base import BaseTrainer 3 | from trainers.nae import NAETrainer, NAETrainerV2, NAELogger
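# --- editor's note (annotation; not part of the original source) -------------
# A minimal usage sketch for the two factory functions below, assuming a config
# shaped like the YAML files under configs/ (whether a given YAML actually sets
# `trainer: nae` and `logger: nae` is an assumption here):
#
#     cfg = OmegaConf.load('configs/cifar_ood_nae/z32gn.yml')
#     cfg['device'] = 'cuda:0'
#     trainer = get_trainer(cfg)        # 'nae' -> NAETrainer, else BaseTrainer
#     logger = get_logger(cfg, writer)  # 'nae' -> NAELogger, else BaseLogger
# ------------------------------------------------------------------------------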
4 | 5 | 6 | def get_trainer(cfg): 7 | # get the trainer specified by the `trainer` field 8 | # if not specified, fall back to BaseTrainer 9 | trainer_type = cfg.get('trainer', None) 10 | arch = cfg['model']['arch'] 11 | device = cfg['device'] 12 | if trainer_type == 'nae': 13 | trainer = NAETrainer(cfg['training'], device=device) 14 | elif trainer_type == 'nae_v2': 15 | trainer = NAETrainerV2(cfg['training'], device=device) 16 | else: 17 | trainer = BaseTrainer(cfg['training'], device=device) 18 | return trainer 19 | 20 | 21 | def get_logger(cfg, writer): 22 | logger_type = cfg['logger'] 23 | if logger_type == 'nae': 24 | logger = NAELogger(writer) 25 | else: 26 | logger = BaseLogger(writer) 27 | return logger 28 | -------------------------------------------------------------------------------- /trainers/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import torch 4 | import numpy as np 5 | from metrics import averageMeter 6 | 7 | 8 | class BaseTrainer: 9 | """Trainer for a conventional iterative training of model""" 10 | def __init__(self, training_cfg, device): 11 | self.cfg = training_cfg 12 | self.device = device 13 | self.d_val_result = {} 14 | 15 | def train(self, model, opt, d_dataloaders, logger=None, logdir='', scheduler=None, clip_grad=None): 16 | cfg = self.cfg 17 | best_val_loss = np.inf 18 | time_meter = averageMeter() 19 | i = 0 20 | train_loader, val_loader = d_dataloaders['training'], d_dataloaders['validation'] 21 | 22 | for i_epoch in range(cfg.n_epoch): 23 | 24 | for x, y in train_loader: 25 | i += 1 26 | 27 | model.train() 28 | x = x.to(self.device) 29 | y = y.to(self.device) 30 | 31 | start_ts = time.time() 32 | d_train = model.train_step(x, y=y, optimizer=opt, clip_grad=clip_grad) 33 | time_meter.update(time.time() - start_ts) 34 | logger.process_iter_train(d_train) 35 | 36 | if i % cfg.print_interval == 0: 37 | d_train = logger.summary_train(i) 38 | print(f"Iter [{i:d}] Avg Loss: {d_train['loss/train_loss_']:.4f} Elapsed time: {time_meter.sum:.4f}") 39 | time_meter.reset() 40 | 41 | if i % cfg.val_interval == 0: 42 | model.eval() 43 | for val_x, val_y in val_loader: 44 | val_x = val_x.to(self.device) 45 | val_y = val_y.to(self.device) 46 | 47 | d_val = model.validation_step(val_x, y=val_y) 48 | logger.process_iter_val(d_val) 49 | d_val = logger.summary_val(i) 50 | val_loss = d_val['loss/val_loss_'] 51 | print(d_val['print_str']) 52 | best_model = val_loss < best_val_loss 53 | 54 | if i % cfg.save_interval == 0 or best_model: 55 | self.save_model(model, logdir, best=best_model, i_iter=i) 56 | if best_model: 57 | print(f'Iter [{i:d}] best model saved {val_loss} < {best_val_loss}') 58 | best_val_loss = val_loss 59 | if i_epoch % cfg.save_interval_epoch == 0: 60 | self.save_model(model, logdir, best=False, i_epoch=i_epoch) 61 | 62 | return model, best_val_loss 63 | 64 | def save_model(self, model, logdir, best=False, i_iter=None, i_epoch=None): 65 | if best: 66 | pkl_name = "model_best.pkl" 67 | else: 68 | if i_iter is not None: 69 | pkl_name = "model_iter_{}.pkl".format(i_iter) 70 | else: 71 | pkl_name = "model_epoch_{}.pkl".format(i_epoch) 72 | state = {"epoch": i_epoch, "model_state": model.state_dict(), 'iter': i_iter} 73 | save_path = os.path.join(logdir, pkl_name) 74 | torch.save(state, save_path) 75 | print(f'Model saved: {pkl_name}') 76 | 77 | 78 | -------------------------------------------------------------------------------- /trainers/logger.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from metrics import averageMeter 3 | 4 | 5 | class BaseLogger: 6 | """BaseLogger that can handle most of the logging 7 | logging convention 8 | ------------------ 9 | 'loss' must exist in every training result dict 10 | endswith('_') : scalar 11 | endswith('@') : image 12 | """ 13 | def __init__(self, tb_writer): 14 | """tb_writer: tensorboard SummaryWriter""" 15 | self.writer = tb_writer 16 | self.train_loss_meter = averageMeter() 17 | self.val_loss_meter = averageMeter() 18 | self.d_train = {} 19 | self.d_val = {} 20 | self.has_val_loss = True # If True, we assume that validation loss is available 21 | 22 | def process_iter_train(self, d_result): 23 | self.train_loss_meter.update(d_result['loss']) 24 | self.d_train = d_result 25 | 26 | def summary_train(self, i): 27 | self.d_train['loss/train_loss_'] = self.train_loss_meter.avg 28 | for key, val in self.d_train.items(): 29 | if key.endswith('_'): 30 | self.writer.add_scalar(key, val, i) 31 | if key.endswith('@'): 32 | if val is not None: 33 | self.writer.add_image(key, val, i) 34 | 35 | result = self.d_train 36 | self.d_train = {} 37 | return result 38 | 39 | def process_iter_val(self, d_result): 40 | if self.has_val_loss: 41 | self.val_loss_meter.update(d_result['loss']) 42 | self.d_val = d_result 43 | 44 | def summary_val(self, i): 45 | if self.has_val_loss: 46 | self.d_val['loss/val_loss_'] = self.val_loss_meter.avg 47 | l_print_str = [f'Iter [{i:d}]'] 48 | for key, val in self.d_val.items(): 49 | if key.endswith('_'): 50 | self.writer.add_scalar(key, val, i) 51 | l_print_str.append(f'{key}: {val:.4f}') 52 | if key.endswith('@'): 53 | if val is not None: 54 | self.writer.add_image(key, val, i) 55 | 56 | print_str = ' '.join(l_print_str) 57 | 58 | result = self.d_val 59 | result['print_str'] = print_str 60 | self.d_val = {} 61 | return result 62 | -------------------------------------------------------------------------------- /trainers/nae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | from metrics import averageMeter 5 | from trainers.base import BaseTrainer 6 | from trainers.logger import BaseLogger 7 | from optimizers import get_optimizer 8 | import torch 9 | from torch.optim import Adam, SGD # SGD is required by NAETrainerV2's 'sgd' mode below 10 | from tqdm import tqdm 11 | from torchvision.utils import make_grid, save_image 12 | from utils import roc_btw_arr 13 | 14 | 15 | class NAETrainer(BaseTrainer): 16 | def train(self, model, opt, d_dataloaders, logger=None, logdir='', scheduler=None, clip_grad=None): 17 | cfg = self.cfg 18 | best_val_loss = np.inf 19 | time_meter = averageMeter() 20 | i = 0 21 | indist_train_loader = d_dataloaders['indist_train'] 22 | indist_val_loader = d_dataloaders['indist_val'] 23 | oodval_val_loader = d_dataloaders['ood_val'] 24 | oodtarget_val_loader = d_dataloaders['ood_target'] 25 | no_best_model_tolerance = 3 26 | no_best_model_count = 0 27 | 28 | n_ae_epoch = cfg.ae_epoch 29 | n_nae_epoch = cfg.nae_epoch 30 | ae_opt = Adam(model.parameters(), lr=cfg.ae_lr) 31 | # nae_opt = Adam([{'params': list(model.encoder.parameters()) + list(model.decoder.parameters())}, 32 | # {'params': model.temperature_, 'lr': cfg.temperature_lr}], lr=cfg.nae_lr) 33 | 34 | if cfg.get('fix_D', False): 35 | if hasattr(model, 'temperature_'): 36 | nae_opt = Adam(list(model.encoder.parameters()) + [model.temperature_], lr=cfg.nae_lr) 37 | else: 38 | nae_opt =
Adam(model.encoder.parameters(), lr=cfg.nae_lr) 39 | elif cfg.get('small_D_lr', False): 40 | print('small decoder learning rate') 41 | nae_opt = Adam([{'params': model.encoder.parameters(), 'lr': cfg.nae_lr}, 42 | {'params': model.decoder.parameters(), 'lr': cfg.nae_lr / 10}]) 43 | else: 44 | l_params = [{'params': list(model.encoder.parameters()) + list(model.decoder.parameters())}] 45 | if model.temperature_trainable: 46 | l_params.append({'params': model.temperature_, 'lr': cfg.temperature_lr}) 47 | nae_opt = Adam(l_params, lr=cfg.nae_lr) 48 | 49 | '''AE PASS''' 50 | if 'load_ae' in cfg: 51 | n_ae_epoch = 0 52 | model.load_state_dict(torch.load(cfg['load_ae'])['model_state']) 53 | print(f'model loaded from {cfg["load_ae"]}') 54 | 55 | for i_epoch in range(n_ae_epoch): 56 | 57 | for x, y in indist_train_loader: 58 | i += 1 59 | 60 | model.train() 61 | x = x.to(self.device) 62 | y = y.to(self.device) 63 | 64 | start_ts = time.time() 65 | d_train = model.train_step_ae(x, ae_opt, clip_grad=0.1) # TODO: clip_grad is hard-coded here; it should come from cfg 66 | time_meter.update(time.time() - start_ts) 67 | logger.process_iter_train(d_train) 68 | 69 | if i % cfg.print_interval == 0: 70 | d_train = logger.summary_train(i) 71 | print(f"Iter [{i:d}] Avg Loss: {d_train['loss/train_loss_']:.4f} Elapsed time: {time_meter.sum:.4f}") 72 | time_meter.reset() 73 | 74 | if i % cfg.val_interval == 0: 75 | model.eval() 76 | for val_x, val_y in indist_val_loader: 77 | val_x = val_x.to(self.device) 78 | val_y = val_y.to(self.device) 79 | 80 | d_val = model.validation_step(val_x, y=val_y) 81 | logger.process_iter_val(d_val) 82 | d_val = logger.summary_val(i) 83 | val_loss = d_val['loss/val_loss_'] 84 | print(d_val['print_str']) 85 | best_model = val_loss < best_val_loss 86 | 87 | if i % cfg.save_interval == 0 or best_model: 88 | self.save_model(model, logdir, best=best_model, i_iter=i) 89 | if best_model: 90 | print(f'Iter [{i:d}] best model saved {val_loss} < {best_val_loss}') 91 | best_val_loss = val_loss 92 | else: 93 | no_best_model_count += 1 94 | if no_best_model_count > no_best_model_tolerance: 95 | break 96 | 97 | if no_best_model_count > no_best_model_tolerance: 98 | print('terminating autoencoder training since validation loss does not decrease anymore') 99 | break 100 | 101 | '''NAE PASS''' 102 | # load best autoencoder model 103 | # model.load_state_dict(torch.load(os.path.join(logdir, 'model_best.pkl'))['model_state']) 104 | # print('best model loaded') 105 | i = 0 106 | for i_epoch in tqdm(range(n_nae_epoch)): 107 | for x, _ in tqdm(indist_train_loader): 108 | i += 1 109 | 110 | x = x.cuda(self.device) 111 | if cfg.get('flatten', False): 112 | x = x.view(len(x), -1) 113 | d_result = model.train_step(x, nae_opt) 114 | logger.process_iter_train_nae(d_result) 115 | 116 | if i % cfg.print_interval == 1: 117 | logger.summary_train_nae(i) 118 | 119 | if i % cfg.val_interval == 1: 120 | '''AUC''' 121 | in_pred = self.predict(model, indist_val_loader, self.device) 122 | ood1_pred = self.predict(model, oodval_val_loader, self.device) 123 | auc_val = roc_btw_arr(ood1_pred, in_pred) 124 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device) 125 | auc_target = roc_btw_arr(ood2_pred, in_pred) 126 | d_result = {'nae/auc_val': auc_val, 'nae/auc_target': auc_target} 127 | print(logger.summary_val_nae(i, d_result)) 128 | torch.save({'model_state': model.state_dict()}, f'{logdir}/nae_iter_{i}.pkl') 129 | 130 | torch.save(model.state_dict(), f'{logdir}/nae_{i_epoch}.pkl') 131 | torch.save(model.state_dict(), f'{logdir}/nae.pkl') 132 | 133 |
'''NAE sample''' 134 | nae_sample = model.sample(n_sample=30, device=self.device, replay=True) 135 | img_grid = make_grid(nae_sample['sample_x'].detach().cpu(), nrow=10, range=(0, 1)) 136 | logger.writer.add_image('nae/sample', img_grid, i + 1) 137 | save_image(img_grid, f'{logdir}/nae_sample.png') 138 | 139 | '''AUC''' 140 | in_pred = self.predict(model, indist_val_loader, self.device) 141 | ood1_pred = self.predict(model, oodval_val_loader, self.device) 142 | auc_val = roc_btw_arr(ood1_pred, in_pred) 143 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device) 144 | auc_target = roc_btw_arr(ood2_pred, in_pred) 145 | d_result = {'nae/auc_val': auc_val, 'nae/auc_target': auc_target} 146 | print(d_result) 147 | 148 | return model, auc_val 149 | 150 | def predict(self, m, dl, device, flatten=False): 151 | """run prediction for the whole dataset""" 152 | l_result = [] 153 | for x, _ in dl: 154 | with torch.no_grad(): 155 | if flatten: 156 | x = x.view(len(x), -1) 157 | pred = m.predict(x.cuda(device)).detach().cpu() 158 | l_result.append(pred) 159 | return torch.cat(l_result) 160 | 161 | 162 | class NAELogger(BaseLogger): 163 | def __init__(self, tb_writer): 164 | super().__init__(tb_writer) 165 | self.train_loss_meter_nae = averageMeter() 166 | self.val_loss_meter_nae = averageMeter() 167 | self.d_train_nae = {} 168 | self.d_val_nae = {} 169 | 170 | def process_iter_train_nae(self, d_result): 171 | self.train_loss_meter_nae.update(d_result['loss']) # use the NAE-specific meter, not the AE one 172 | self.d_train_nae = d_result 173 | 174 | def summary_train_nae(self, i): 175 | d_result = self.d_train_nae 176 | writer = self.writer 177 | writer.add_scalar('nae/loss', d_result['loss'], i) 178 | writer.add_scalar('nae/energy_diff', d_result['pos_e'] - d_result['neg_e'], i) 179 | writer.add_scalar('nae/pos_e', d_result['pos_e'], i) 180 | writer.add_scalar('nae/neg_e', d_result['neg_e'], i) 181 | writer.add_scalar('nae/encoder_l2', d_result['encoder_norm'], i) 182 | writer.add_scalar('nae/decoder_l2', d_result['decoder_norm'], i) 183 | if 'neg_e_x0' in d_result: 184 | writer.add_scalar('nae/neg_e_x0', d_result['neg_e_x0'], i) 185 | if 'neg_e_z0' in d_result: 186 | writer.add_scalar('nae/neg_e_z0', d_result['neg_e_z0'], i) 187 | if 'temperature' in d_result: 188 | writer.add_scalar('nae/temperature', d_result['temperature'], i) 189 | if 'sigma' in d_result: 190 | writer.add_scalar('nae/sigma', d_result['sigma'], i) 191 | if 'delta_term' in d_result: 192 | writer.add_scalar('nae/delta_term', d_result['delta_term'], i) 193 | if 'gamma_term' in d_result: 194 | writer.add_scalar('nae/gamma_term', d_result['gamma_term'], i) 195 | 196 | 197 | '''images''' 198 | x_neg = d_result['x_neg'] 199 | recon_neg = d_result['recon_neg'] 200 | img_grid = make_grid(x_neg, nrow=10, range=(0, 1)) 201 | writer.add_image('nae/sample', img_grid, i) 202 | img_grid = make_grid(recon_neg, nrow=10, range=(0, 1), normalize=True) 203 | writer.add_image('nae/sample_recon', img_grid, i) 204 | 205 | # to uint8 and save as array 206 | x_neg = (x_neg.permute(0,2,3,1).numpy() * 256.).clip(0, 255).astype('uint8') 207 | recon_neg = (recon_neg.permute(0,2,3,1).numpy() * 256.).clip(0, 255).astype('uint8') 208 | # save_image(img_grid, f'{writer.file_writer.get_logdir()}/nae_sample_{i}.png') 209 | np.save(f'{writer.file_writer.get_logdir()}/nae_neg_{i}.npy', x_neg) 210 | np.save(f'{writer.file_writer.get_logdir()}/nae_neg_recon_{i}.npy', recon_neg) 211 | 212 | 213 | def summary_val_nae(self, i, d_result): 214 | l_print_str = [f'Iter [{i:d}]'] 215 | for key, val in
d_result.items(): 216 | self.writer.add_scalar(key, val, i) 217 | l_print_str.append(f'{key}: {val:.4f}') 218 | print_str = ' '.join(l_print_str) 219 | return print_str 220 | 221 | 222 | class NAETrainerV2(BaseTrainer): 223 | def train(self, model, opt, d_dataloaders, logger=None, logdir='', scheduler=None, clip_grad=None): 224 | cfg = self.cfg 225 | best_val_loss = np.inf 226 | time_meter = averageMeter() 227 | i = 0 228 | indist_train_loader = d_dataloaders['indist_train'] 229 | indist_val_loader = d_dataloaders['indist_val'] 230 | oodval_val_loader = d_dataloaders['ood_val'] 231 | oodtarget_val_loader = d_dataloaders['ood_target'] 232 | no_best_model_tolerance = 3 233 | no_best_model_count = 0 234 | 235 | n_ae_epoch = cfg.ae_epoch 236 | n_nae_epoch = cfg.nae_epoch 237 | ae_opt = Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=cfg.ae_lr) 238 | nae_opt_mode = cfg.get('nae_opt', 'fix_ae') 239 | if nae_opt_mode == 'fix_ae': 240 | if hasattr(model, 'varnetx'): 241 | nae_opt = Adam(list(model.varnetx.parameters()) + list(model.varnetz.parameters()), lr=cfg.nae_lr) 242 | else: 243 | nae_opt = Adam(model.varnet.parameters(), lr=cfg.nae_lr) 244 | elif nae_opt_mode == 'all': 245 | nae_opt = Adam(model.parameters(), lr=cfg.nae_lr) 246 | elif nae_opt_mode == 'fix_D': 247 | if hasattr(model, 'varnet'): 248 | nae_opt = Adam(list(model.varnet.parameters()) + list(model.encoder.parameters()), lr=cfg.nae_lr) 249 | else: 250 | nae_opt = Adam(model.encoder.parameters(), lr=cfg.nae_lr) 251 | elif nae_opt_mode == 'slow_E': 252 | nae_opt = Adam([{'params': model.varnet.parameters(), 'lr': cfg.nae_lr * 10}, 253 | {'params': model.encoder.parameters(), 'lr': cfg.nae_lr}]) 254 | elif nae_opt_mode == 'sgd': 255 | nae_opt = SGD(model.parameters(), lr=cfg.nae_lr) 256 | else: 257 | raise ValueError(f'Invalid nae_opt_mode {nae_opt_mode}') 258 | 259 | '''AE PASS''' 260 | if cfg.get('load_ae', None) is not None: 261 | n_ae_epoch = 0 262 | state_dict = torch.load(cfg['load_ae'])['model_state'] 263 | encoder_state = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder')} 264 | decoder_state = {k.replace('decoder.', ''): v for k, v in state_dict.items() if k.startswith('decoder')} 265 | model.encoder.load_state_dict(encoder_state) 266 | model.decoder.load_state_dict(decoder_state) 267 | print(f'model loaded from {cfg["load_ae"]}') 268 | 269 | for i_epoch in range(n_ae_epoch): 270 | 271 | for x, y in indist_train_loader: 272 | i += 1 273 | 274 | model.train() 275 | x = x.to(self.device) 276 | y = y.to(self.device) 277 | 278 | start_ts = time.time() 279 | d_train = model.train_step_ae(x, ae_opt, clip_grad=clip_grad) 280 | time_meter.update(time.time() - start_ts) 281 | logger.process_iter_train(d_train) 282 | 283 | if i % cfg.print_interval == 0: 284 | d_train = logger.summary_train(i) 285 | print(f"Iter [{i:d}] Avg Loss: {d_train['loss/train_loss_']:.4f} Elapsed time: {time_meter.sum:.4f}") 286 | time_meter.reset() 287 | 288 | if i % cfg.val_interval == 0: 289 | model.eval() 290 | for val_x, val_y in indist_val_loader: 291 | val_x = val_x.to(self.device) 292 | val_y = val_y.to(self.device) 293 | 294 | d_val = model.validation_step_ae(val_x, y=val_y) 295 | logger.process_iter_val(d_val) 296 | d_val = logger.summary_val(i) 297 | val_loss = d_val['loss/val_loss_'] 298 | print(d_val['print_str']) 299 | best_model = val_loss < best_val_loss 300 | 301 | if best_model: 302 | self.save_model(model, logdir, best=best_model, i_iter=i) 303 | print(f'Iter [{i:d}] best 
model saved {val_loss} < {best_val_loss}') 304 | best_val_loss = val_loss 305 | else: 306 | no_best_model_count += 1 307 | if no_best_model_count > no_best_model_tolerance: 308 | break 309 | 310 | if i_epoch % cfg.get('save_interval_epoch', 1) == 0: 311 | self.save_model(model, logdir, best=False, i_iter=None, i_epoch=i_epoch) 312 | print(f"Epoch [{i_epoch:d}] model saved (save_interval_epoch: {cfg.get('save_interval_epoch', 1)})") 313 | 314 | if no_best_model_count > no_best_model_tolerance: 315 | print('terminating autoencoder training since validation loss does not decrease anymore') 316 | break 317 | 318 | '''NAE PASS''' 319 | logger.has_val_loss = False 320 | if cfg.get('transfer_encoder', False): 321 | state_dict = model.encoder.state_dict() 322 | state_dict.pop('linear.weight'); state_dict.pop('linear.bias') # drop the final linear layer before seeding varnet 323 | transfer_result = model.varnet.load_state_dict(state_dict, strict=False) 324 | print('varnet initialized from encoder:', transfer_result) 325 | # load best autoencoder model 326 | # model.load_state_dict(torch.load(os.path.join(logdir, 'model_best.pkl'))['model_state']) 327 | # print('best model loaded') 328 | i = 0 329 | for i_epoch in tqdm(range(n_nae_epoch)): 330 | for x, _ in tqdm(indist_train_loader): 331 | i += 1 332 | 333 | x = x.cuda(self.device) 334 | if cfg.get('flatten', False): 335 | x = x.view(len(x), -1) 336 | 337 | d_result = model.train_step(x, nae_opt, clip_grad=clip_grad) 338 | logger.process_iter_train(d_result) 339 | 340 | if i % cfg.print_interval_nae == 1: 341 | input_img = make_grid(x.detach().cpu(), nrow=10, range=(0, 1)) 342 | recon_img = make_grid(model.reconstruct(x).detach().cpu(), nrow=10, range=(0, 1)) 343 | logger.d_train['input_img@'] = input_img 344 | logger.d_train['recon_img@'] = recon_img 345 | logger.summary_train(i) 346 | 347 | if i % cfg.val_interval == 1: 348 | '''AUC''' 349 | in_pred = self.predict(model, indist_val_loader, self.device) 350 | ood1_pred = self.predict(model, oodval_val_loader, self.device) 351 | auc_val = roc_btw_arr(ood1_pred, in_pred) 352 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device) 353 | auc_target = roc_btw_arr(ood2_pred, in_pred) 354 | d_result = {'nae/auc_val_': auc_val, 'nae/auc_target_': auc_target} 355 | logger.process_iter_val(d_result) 356 | print(logger.summary_val(i)['print_str']) 357 | torch.save({'model_state': model.state_dict()}, f'{logdir}/nae_iter_{i}.pkl') 358 | 359 | torch.save(model.state_dict(), f'{logdir}/nae_{i_epoch}.pkl') 360 | torch.save(model.state_dict(), f'{logdir}/nae.pkl') 361 | 362 | '''AUC''' 363 | in_pred = self.predict(model, indist_val_loader, self.device) 364 | ood1_pred = self.predict(model, oodval_val_loader, self.device) 365 | auc_val = roc_btw_arr(ood1_pred, in_pred) 366 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device) 367 | auc_target = roc_btw_arr(ood2_pred, in_pred) 368 | d_result = {'nae/auc_val': auc_val, 'nae/auc_target': auc_target} 369 | print(d_result) 370 | 371 | return model, auc_val 372 | 373 | def predict(self, m, dl, device, flatten=False): 374 | """run prediction for the whole dataset""" 375 | l_result = [] 376 | for x, _ in dl: 377 | with torch.no_grad(): 378 | if flatten: 379 | x = x.view(len(x), -1) 380 | pred = m.predict(x.cuda(device)).detach().cpu() 381 | l_result.append(pred) 382 | return torch.cat(l_result) 383 | 384 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2
| Misc Utility functions 3 | """ 4 | import os 5 | import sys 6 | import logging 7 | import datetime 8 | import numpy as np 9 | import torch 10 | import yaml 11 | from torch.autograd import Variable 12 | from collections import OrderedDict 13 | 14 | 15 | def recursive_glob(rootdir=".", suffix=""): 16 | """Performs recursive glob with given suffix and rootdir 17 | :param rootdir is the root directory 18 | :param suffix is the suffix to be searched 19 | """ 20 | return [ 21 | os.path.join(looproot, filename) 22 | for looproot, _, filenames in os.walk(rootdir) 23 | for filename in filenames 24 | if filename.endswith(suffix) 25 | ] 26 | 27 | 28 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 29 | """Alpha Blending utility to overlay RGB masks on RGB images 30 | :param input_image is a np.ndarray with 3 channels 31 | :param segmentation_mask is a np.ndarray with 3 channels 32 | :param alpha is a float value 33 | """ 34 | 35 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 36 | return blended 37 | 38 | 39 | def convert_state_dict(state_dict): 40 | """Converts a state dict saved from a dataParallel module to normal 41 | module state_dict inplace 42 | :param state_dict is the loaded DataParallel model_state 43 | """ 44 | if not next(iter(state_dict)).startswith("module."): 45 | return state_dict # abort if dict is not a DataParallel model_state 46 | new_state_dict = OrderedDict() 47 | for k, v in state_dict.items(): 48 | name = k[7:] # remove `module.` 49 | new_state_dict[name] = v 50 | return new_state_dict 51 | 52 | def convert_state_dict_remove_main(state_dict): 53 | if not next(iter(state_dict)).startswith("module."): 54 | return state_dict # abort if dict is not a DataParallel model_state 55 | elif not next(iter(state_dict)).startswith("module.main"): 56 | return convert_state_dict(state_dict) # no `module.main` prefix; fall back to plain DataParallel conversion 57 | 58 | new_state_dict = OrderedDict() 59 | for k, v in state_dict.items(): 60 | name = k[12:] # remove `module.main` 61 | new_state_dict[name] = v 62 | return new_state_dict 63 | 64 | 65 | def get_logger(logdir): 66 | logger = logging.getLogger("ptsemseg") 67 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 68 | ts = ts.replace(":", "_").replace("-", "_") 69 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 70 | hdlr = logging.FileHandler(file_path, mode='w') 71 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 72 | hdlr.setFormatter(formatter) 73 | logger.addHandler(hdlr) 74 | logger.setLevel(logging.INFO) 75 | return logger 76 | 77 | 78 | def check_config_validity(cfg): 79 | """checks validity of config""" 80 | assert 'model' in cfg 81 | assert 'training' in cfg 82 | assert 'validation' in cfg 83 | assert 'evaluation' in cfg 84 | assert 'data' in cfg 85 | 86 | data = cfg['data'] 87 | assert 'dataset' in data 88 | assert 'path' in data 89 | assert 'n_classes' in data 90 | assert 'split' in data 91 | assert 'resize_factor' in data 92 | assert 'label' in data 93 | 94 | d = cfg['training'] 95 | s = 'training' 96 | assert 'train_iters' in d, s 97 | assert 'val_interval' in d, s 98 | assert 'print_interval' in d, s 99 | assert 'optimizer' in d, s 100 | assert 'loss' in d, s 101 | assert 'batch_size' in d, s 102 | assert 'n_workers' in d, s 103 | 104 | d = cfg['validation'] 105 | s = 'validation' 106 | assert 'batch_size' in d, s 107 | assert 'n_workers' in d, s 108 | 109 | d = cfg['evaluation'] 110 | s = 'evaluation'
111 | assert 'batch_size' in d, s 112 | assert 'n_workers' in d, s 113 | assert 'num_crop_width' in d, s 114 | assert 'num_crop_height' in d, s 115 | 116 | 117 | print('config file validation passed') 118 | 119 | 120 | def add_uniform_noise(x): 121 | """ 122 | x: torch.Tensor 123 | """ 124 | return x * 255. / 256. + torch.rand_like(x) / 256. 125 | 126 | 127 | import errno # os is already imported at the top of this file 128 | 129 | 130 | 131 | # Recursive mkdir 132 | def mkdir_p(path): 133 | try: 134 | os.makedirs(path) 135 | except OSError as exc: # Python >2.5 136 | if exc.errno == errno.EEXIST and os.path.isdir(path): 137 | pass 138 | else: 139 | raise 140 | 141 | 142 | from sklearn.metrics import roc_auc_score 143 | def roc_btw_arr(arr1, arr2): 144 | true_label = np.concatenate([np.ones_like(arr1), 145 | np.zeros_like(arr2)]) 146 | score = np.concatenate([arr1, arr2]) 147 | return roc_auc_score(true_label, score) 148 | 149 | 150 | def batch_run(m, dl, device, flatten=False, method='predict', input_type='first', no_grad=True, **kwargs): 151 | """ 152 | m: model 153 | dl: dataloader 154 | device: device 155 | method: the name of a function to be called 156 | no_grad: use torch.no_grad if True. 157 | kwargs: additional argument for the method being called 158 | """ 159 | method = getattr(m, method) 160 | l_result = [] 161 | for batch in dl: 162 | if input_type == 'first': 163 | x = batch[0] # other input_type values are not implemented 164 | 165 | if no_grad: 166 | with torch.no_grad(): 167 | if flatten: 168 | x = x.view(len(x), -1) 169 | pred = method(x.cuda(device), **kwargs).detach().cpu() 170 | else: 171 | if flatten: 172 | x = x.view(len(x), -1) 173 | pred = method(x.cuda(device), **kwargs).detach().cpu() 174 | 175 | l_result.append(pred) 176 | return torch.cat(l_result) 177 | 178 | 179 | 180 | def save_yaml(filename, text): 181 | """parse string as yaml then dump as a file""" 182 | with open(filename, 'w') as f: 183 | yaml.dump(yaml.safe_load(text), f, default_flow_style=False) 184 | 185 | 186 | def get_shuffled_idx(N, seed): 187 | """get randomly permuted indices""" 188 | rng = np.random.default_rng(seed) 189 | return rng.permutation(N) 190 | 191 | 192 | def search_params_intp(params): 193 | ret = {} 194 | for param in params.keys(): 195 | # param : "train.batch" 196 | spl = param.split(".") 197 | if len(spl) == 2: 198 | if spl[0] in ret: 199 | ret[spl[0]][spl[1]] = params[param] 200 | else: 201 | ret[spl[0]] = {spl[1]: params[param]} 202 | 203 | 204 | 205 | elif len(spl) == 3: 206 | if spl[0] in ret: 207 | if spl[1] in ret[spl[0]]: 208 | ret[spl[0]][spl[1]][spl[2]] = params[param] 209 | else: 210 | ret[spl[0]][spl[1]] = {spl[2]: params[param]} 211 | else: 212 | ret[spl[0]] = {spl[1]: {spl[2]: params[param]}} 213 | elif len(spl) == 1: 214 | ret[spl[0]] = params[param] 215 | else: 216 | raise ValueError 217 | return ret 218 | 219 | 220 | def eprint(*args, **kwargs): 221 | print(*args, file=sys.stderr, **kwargs) 222 | 223 | 224 | def weight_norm(net): 225 | """computes the squared L2 norm of the parameters of a network""" 226 | norm = 0 227 | for param in net.parameters(): 228 | norm += (param ** 2).sum() 229 | return norm 230 | 231 | 232 | 233 | ## additional command-line argument parsing 234 | def parse_arg_type(val): 235 | if val.isnumeric(): 236 | return int(val) 237 | if val in ('True', 'False'): # handle both booleans; 'False' used to fall through as a (truthy) string 238 | return val == 'True' 239 | try: 240 | return float(val) 241 | except ValueError: 242 | return val 243 | 244 | 245 | def parse_unknown_args(l_args): 246 | """convert the list of unknown args into a dict; 247 | this is similar to
OmegaConf.from_cli() 248 | I may have reinvented the wheel...""" 249 | n_args = len(l_args) // 2 250 | kwargs = {} 251 | for i_args in range(n_args): 252 | key = l_args[i_args*2] 253 | val = l_args[i_args*2 + 1] 254 | assert '=' not in key, 'optional arguments should be separated by space' 255 | kwargs[key.strip('-')] = parse_arg_type(val) 256 | return kwargs 257 | 258 | 259 | def parse_nested_args(d_cmd_cfg): 260 | """produce a nested dictionary by parsing dot-separated keys 261 | e.g. {key1.key2 : 1} --> {key1: {key2: 1}}""" 262 | d_new_cfg = {} 263 | for key, val in d_cmd_cfg.items(): 264 | l_key = key.split('.') 265 | d = d_new_cfg 266 | for i_key, each_key in enumerate(l_key): 267 | if i_key == len(l_key) - 1: 268 | d[each_key] = val 269 | else: 270 | if each_key not in d: 271 | d[each_key] = {} 272 | d = d[each_key] 273 | return d_new_cfg 274 | --------------------------------------------------------------------------------
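For reference, here is a minimal, editor-added usage sketch of the command-line override helpers (parse_unknown_args, parse_nested_args) and the AUROC utility (roc_btw_arr) defined in utils.py above. It is an illustration rather than repository code: it assumes the repository root is on PYTHONPATH, the config path is one of the files shipped under configs/, and the score arrays are made-up numbers chosen only to show the calling convention.

import numpy as np
from omegaconf import OmegaConf
from utils import parse_unknown_args, parse_nested_args, roc_btw_arr

# Leftover CLI tokens, e.g. from:
#   python train.py --config configs/mnist_ood_nae/z32.yml --training.nae_lr 0.00001
unknown = ['--training.nae_lr', '0.00001', '--model.arch', 'nae']
flat = parse_unknown_args(unknown)   # {'training.nae_lr': 1e-05, 'model.arch': 'nae'}
nested = parse_nested_args(flat)     # {'training': {'nae_lr': 1e-05}, 'model': {'arch': 'nae'}}
cfg = OmegaConf.merge(OmegaConf.load('configs/mnist_ood_nae/z32.yml'), nested)

# roc_btw_arr treats its first argument as the positive class, so the trainers
# call roc_btw_arr(ood_pred, in_pred) to obtain the OOD-detection AUROC.
ood_scores = np.array([2.0, 1.5, 3.0])       # illustrative energies for OOD inputs
indist_scores = np.array([0.1, 0.2, 0.3])    # illustrative energies for in-distribution inputs
print(roc_btw_arr(ood_scores, indist_scores))  # 1.0 when the two score sets are perfectly separated

The train.py entry point above composes exactly this chain: parse_known_args leaves the unmatched tokens in unknown, which are parsed into a nested dict and merged into the YAML config via OmegaConf.merge before training starts.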