├── .gitignore
├── 2d_density_estimation
│   ├── ae_1_8gaussians.pth
│   ├── ae_3_8gaussians.pth
│   ├── draw_2d_density_estimation_figure.ipynb
│   ├── draw_nae_training_methods.ipynb
│   ├── nae_3_8gaussians.pth
│   ├── nae_8gaussians.pth
│   ├── nae_cd_3_8gaussians.pth
│   ├── nae_pcd_3_8gaussians.pth
│   ├── nae_training_methods.ipynb
│   ├── vae_1_8gaussians.pth
│   └── vae_3_8gaussians.pth
├── LICENSE
├── README.md
├── assets
│   ├── celeba64samples.png
│   ├── cifar10samples.png
│   ├── fig_mnist_recon_with_box_v2.png
│   ├── main.png
│   ├── mnistsamples.png
│   └── vanilla_ae_mnistsamples.png
├── augmentations
│   ├── __init__.py
│   └── augmentations.py
├── configs
│   ├── celeba64_ood_nae
│   │   └── z64gr_h32g8.yml
│   ├── cifar_ood_nae
│   │   └── z32gn.yml
│   ├── fmnist_ood_nae
│   │   └── z32.yml
│   ├── mnist_ho_nae
│   │   └── l2_z32.yml
│   └── mnist_ood_nae
│       └── z32.yml
├── conftest.py
├── datasets
│   └── .gitkeep
├── evaluate_ood.py
├── loaders
│   ├── __init__.py
│   ├── chimera_dataset.py
│   ├── leaveout_dataset.py
│   ├── modified_dataset.py
│   └── synthetic.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── ae.py
│   ├── energybased.py
│   ├── igebm.py
│   ├── langevin.py
│   ├── mcmc.py
│   ├── mmd.py
│   ├── modules.py
│   ├── modules_sngan.py
│   ├── nae.py
│   ├── spectral_norm.py
│   └── utils.py
├── optimizers.py
├── pretrained
│   └── .gitkeep
├── sample.py
├── tests
│   ├── test_load_pretrained.py
│   ├── test_loader.py
│   ├── test_modules.py
│   └── test_nae.py
├── train.py
├── trainers
│   ├── __init__.py
│   ├── base.py
│   ├── logger.py
│   └── nae.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Add
132 | datasets/
133 | results/
134 | pretrained/
135 |
--------------------------------------------------------------------------------
/2d_density_estimation/ae_1_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/ae_1_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/ae_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/ae_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_cd_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_cd_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/nae_pcd_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/nae_pcd_3_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/vae_1_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/vae_1_8gaussians.pth
--------------------------------------------------------------------------------
/2d_density_estimation/vae_3_8gaussians.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/2d_density_estimation/vae_3_8gaussians.pth
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Sangwoong Yoon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Autoencoding Under Normalization Constraints
2 |
3 | The official repository for *Autoencoding Under Normalization Constraints* (Yoon, Noh and Park, ICML 2021) and **normalized autoencoders**.
4 |
5 | > The paper proposes Normalized Autoencoder (NAE), which is a novel energy-based model where the energy function is the reconstruction error. NAE effectively remedies outlier reconstruction, a pathological phenomenon limiting the performance of an autoencoder as an outlier detector.
6 |
7 | Paper: https://arxiv.org/abs/2105.05735
8 | 5-min video: https://www.youtube.com/watch?v=ra6usGKnPGk
9 |
10 |
11 | 
12 |
13 | ## News
14 |
15 | * 2023-06-15 : Added a config file for FashionMNIST.
16 | * 2022-06-13 : Refactored the NAE class.
17 |
18 |
19 | ## Set-up
20 |
21 | ### Environment
22 |
23 | I encourage you to use conda to set up a virtual environment. However, other methods should work without problems.
24 | ```
25 | conda create -n nae python=3.7
26 | ```
27 |
28 | The main dependencies of the repository are as follows:
29 |
30 | - python 3.7.2
31 | - numpy
32 | - pillow
33 | - pytorch 1.7.1
34 | - CUDA 10.1
35 | - scikit-learn 0.24.2
36 | - tensorboard 2.5.0
37 | - pytest 6.2.3
38 |
39 |
40 | ### Datasets
41 |
42 | All datasets are stored in `datasets/` directory.
43 |
44 | - MNIST, CIFAR-10, SVHN, Omniglot : Retrieved using `torchvision.datasets`.
45 | - Noise, Constant, ConstantGray : [Dropbox link](https://www.dropbox.com/sh/u41ewgwujuvqvpm/AABM6YbklJFAruczJPhBWNwZa?dl=0)
46 | - [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), [ImageNet 32x32](http://image-net.org/small/download.php): Retrieved from their official sites. I am afraid the website for ImageNet 32x32 is not available as of June 24, 2021. I will temporarily upload the data to the above Dropbox link.
47 |
48 | When set up, the dataset directory should look as follows.
49 |
50 | ```
51 | datasets
52 | ├── CelebA
53 | │ ├── Anno
54 | │ ├── Eval
55 | │ └── Img
56 | ├── cifar-10-batches-py
57 | ├── const_img_gray.npy
58 | ├── const_img.npy
59 | ├── FashionMNIST
60 | ├── ImageNet32
61 | │ ├── train_32x32
62 | │ └── valid_32x32
63 | ├── MNIST
64 | ├── noise_img.npy
65 | ├── omniglot-py
66 | │ ├── images_background
67 | │ └── images_evaluation
68 | ├── test_32x32.mat
69 | └── train_32x32.mat
70 |
71 | ```
72 |
73 | ### Pre-trained Models
74 |
75 | Pre-trained models are stored under `pretrained/`. The pre-trained models are provided through the [Dropbox link](https://www.dropbox.com/sh/u41ewgwujuvqvpm/AABM6YbklJFAruczJPhBWNwZa?dl=0).
76 |
77 | If the pretrained models are prepared successfully, the directory structure should look like the following.
78 |
79 | ```
80 | pretrained
81 | ├── celeba64_ood_nae
82 | │ └── z64gr_h32g8
83 | ├── cifar_ood_nae
84 | │ └── z32gn
85 | └── mnist_ood_nae
86 | └── z32
87 | ```
88 |
89 | ## Unit Testing
90 |
91 | PyTest is used for unit testing.
92 |
93 | ```
94 | pytest tests
95 | ```
96 |
97 | The code should pass all tests once the pre-trained models and datasets are prepared.
98 |
99 | ## Execution
100 |
101 | ### OOD Detection Evaluation
102 |
103 | ```
104 | python evaluate_ood.py --ood ConstantGray_OOD,FashionMNIST_OOD,SVHN_OOD,CelebA_OOD,Noise_OOD --resultdir pretrained/cifar_ood_nae/z32gn/ --ckpt nae_9.pkl --config z32gn.yml --device 0 --dataset CIFAR10_OOD
105 | ```
106 |
107 |
108 | Expected Results
109 |
110 | ```
111 | OOD Detection Results in AUC
112 | ConstantGray_OOD:0.9632
113 | FashionMNIST_OOD:0.8193
114 | SVHN_OOD:0.9196
115 | CelebA_OOD:0.8873
116 | Noise_OOD:1.0000
117 | ```
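
A minimal sketch of how each AUC can be computed from per-sample scores, assuming higher scores indicate OOD (the repository's `roc_btw_arr` in `utils.py` plays this role; its exact implementation may differ):

```
import numpy as np
from sklearn.metrics import roc_auc_score

def auc_ood_vs_indist(ood_scores, indist_scores):
    """AUROC with OOD samples treated as the positive class."""
    y_true = np.concatenate([np.ones(len(ood_scores)), np.zeros(len(indist_scores))])
    y_score = np.concatenate([ood_scores, indist_scores])
    return roc_auc_score(y_true, y_score)
```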
118 |
119 |
120 |
121 | ### Training
122 |
123 | Use `train.py` to train NAE.
124 | * `--config` option specifies a path to a configuration yaml file.
125 | * `--logdir` specifies a directory where results files will be written.
126 | * `--run` specifies an id for each run, i.e., an experiment.
127 |
128 | **Training on MNIST**
129 | ```
130 | python train.py --config configs/mnist_ood_nae/z32.yml --logdir results/mnist_ood_nae/ --run run --device 0
131 | ```
132 |
133 | **Training on MNIST digits 0 to 8 for the hold-out digit detection task**
134 | ```
135 | python train.py --config configs/mnist_ho_nae/l2_z32.yml --logdir results/mnist_ho_nae --run run --device 0
136 | ```
137 |
138 | **Training on CIFAR-10**
139 | ```
140 | python train.py --config configs/cifar_ood_nae/z32gn.yml --logdir results/cifar_ood_nae/ --run run --device 0
141 | ```
142 |
143 | **Training on CelebA 64x64**
144 | ```
145 | python train.py --config configs/celeba64_ood_nae/z64gr_h32g8.yml --logdir results/celeba64_ood_nae/ --run run --device 0
146 | ```
147 |
148 | **Training on FashionMNIST**
149 | ```
150 | python train.py --config configs/fmnist_ood_nae/z32.yml --logdir results/fmnist_ood_nae --run run --device 0
151 | ```
152 |
153 |
154 | ### Sampling
155 |
156 | Use `sample.py` to generate sample images from NAE. Samples are saved as a `.npy` file containing an `(n_sample, img_h, img_w, channels)` array.
157 | Note that the quality of the generated images is not supposed to match that of state-of-the-art generative models. Improving sample quality is an important direction for future research.
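
To inspect the saved samples, the array can be loaded back with NumPy. A minimal sketch (the file name `samples.npy` is illustrative; check the path printed by the script):

```
import numpy as np
import matplotlib.pyplot as plt

samples = np.load('samples.npy')  # (n_sample, img_h, img_w, channels)
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), samples):
    ax.imshow(img.squeeze(), cmap='gray' if img.shape[-1] == 1 else None)
    ax.axis('off')
plt.show()
```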
158 |
159 | **Sampling for MNIST**
160 |
161 | ```
162 | python sample.py pretrained/mnist_ood_nae/z32/ z32.yml nae_20.pkl --zstep 200 --x_shape 28 --batch_size 64 --n_sample 64 --x_channel 1 --device 0
163 | ```
164 |
165 | 
166 |
167 | The white square is an artifact of NAE, possibly caused by distortions introduced by the encoder and the decoder.
168 |
169 | The result is comparable to the samples from a vanilla autoencoder generated with the same procedure.
170 |
171 | 
172 |
173 |
174 | **Sampling for CIFAR-10**
175 | ```
176 | python sample.py pretrained/cifar_ood_nae/z32gn/ z32gn.yml nae_8.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0
177 | ```
178 |
179 | 
180 |
181 | **Sampling for CelebA 64x64**
182 | ```
183 | python sample.py pretrained/celeba64_ood_nae/z64gr_h32g8/ z64gr_h32g8.yml nae_3.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0 --x_shape 64
184 | ```
185 |
186 |
187 | 
188 |
189 |
190 |
191 | ## Citation
192 |
193 |
194 | ```
195 | @InProceedings{pmlr-v139-yoon21c,
196 | title = {Autoencoding Under Normalization Constraints},
197 | author = {Yoon, Sangwoong and Noh, Yung-Kyun and Park, Frank},
198 | booktitle = {Proceedings of the 38th International Conference on Machine Learning},
199 | pages = {12087--12097},
200 | year = {2021},
201 | editor = {Meila, Marina and Zhang, Tong},
202 | volume = {139},
203 | series = {Proceedings of Machine Learning Research},
204 | month = {18--24 Jul},
205 | publisher = {PMLR},
206 | pdf = {http://proceedings.mlr.press/v139/yoon21c/yoon21c.pdf},
207 | url = {https://proceedings.mlr.press/v139/yoon21c.html}
208 | }
209 |
210 | ```
211 |
212 |
--------------------------------------------------------------------------------
/assets/celeba64samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/celeba64samples.png
--------------------------------------------------------------------------------
/assets/cifar10samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/cifar10samples.png
--------------------------------------------------------------------------------
/assets/fig_mnist_recon_with_box_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/fig_mnist_recon_with_box_v2.png
--------------------------------------------------------------------------------
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/main.png
--------------------------------------------------------------------------------
/assets/mnistsamples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/mnistsamples.png
--------------------------------------------------------------------------------
/assets/vanilla_ae_mnistsamples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/assets/vanilla_ae_mnistsamples.png
--------------------------------------------------------------------------------
/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from torchvision.transforms import (
3 | RandomHorizontalFlip,
4 | RandomRotation,
5 | RandomResizedCrop,
6 | ColorJitter,
7 | RandomGrayscale,
8 | RandomChoice,
9 | Compose,
10 | Normalize,
11 | ToTensor
12 | )
13 | from augmentations.augmentations import (
14 | RandomRotate90,
15 | ColorJitterSimCLR,
16 | GaussianDequantize,
17 | UniformDequantize,
18 | ToGray,
19 | )
20 |
21 | logger = logging.getLogger("ptsemseg")
22 |
23 |
24 | key2aug = {
25 | 'hflip': RandomHorizontalFlip,
26 | 'rotate': RandomRotate90,
27 | 'rcrop': RandomResizedCrop,
28 | 'cjitter': ColorJitterSimCLR,
29 | 'rgray': RandomGrayscale,
30 | 'GaussianDequantize': GaussianDequantize,
31 | 'UniformDequantize': UniformDequantize,
32 | 'togray': ToGray,
33 | 'normalize': Normalize,
34 | 'totensor': ToTensor
35 | }
36 |
37 |
38 | def get_composed_augmentations(aug_dict):
39 | if aug_dict is None:
40 | # print("Using No Augmentations")
41 | return None
42 |
43 | augmentations = []
44 | for aug_key, aug_param in aug_dict.items():
45 | augmentations.append(key2aug[aug_key](**aug_param))
46 | # print("Using {} aug with params {}".format(aug_key, aug_param))
47 | return Compose(augmentations)
48 |
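# Example usage (sketch; the keys are illustrative, any entry of `key2aug` works the same way):
#   aug = get_composed_augmentations({'hflip': {'p': 0.5}, 'GaussianDequantize': {}})
#   x_aug = aug(x)  # x: a torch.Tensor image, e.g. produced by ToTensor()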
--------------------------------------------------------------------------------
/augmentations/augmentations.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 | import numpy as np
5 | import torch
6 | import torchvision.transforms.functional as tf
7 | from torchvision.transforms import (
8 | RandomApply,
9 |     ColorJitter, RandomChoice, RandomRotation,  # RandomChoice and RandomRotation are used by RandomRotate90
10 | )
11 |
12 | from PIL import Image, ImageOps
13 |
14 |
15 | class RandomRotate90:
16 | def __init__(self):
17 |         # (d, d) pins the angle exactly; a bare RandomRotation(d) samples uniformly from [-d, d]
18 |         self.t = RandomChoice([RandomRotation((90 * k, 90 * k)) for k in range(4)])
19 | def __call__(self, img):
20 | return self.t(img)
21 |
22 |
23 | class ColorJitterSimCLR:
24 | def __init__(self, jitter_d=0.5, jitter_p=0.8):
25 | self.t = RandomApply([ColorJitter(0.8 * jitter_d, 0.8 * jitter_d, 0.8 * jitter_d, 0.2 * jitter_d)],
26 | p=jitter_p)
27 |
28 | def __call__(self, img):
29 | return self.t(img)
30 |
31 |
32 | class Invert(object):
33 | """Transform that inverts pixel intensities"""
34 | def __init__(self):
35 | pass
36 |
37 | def __call__(self, img):
38 | return 1 - img
39 |
40 |
41 | class ToGray:
42 | """ Transform color image to gray scale"""
43 | def __init__(self):
44 | pass
45 |
46 | def __call__(self, img):
47 | img = torch.sqrt(torch.sum(img ** 2, dim=0, keepdim=True))
48 | return torch.cat([img, img, img], dim=0).detach()
49 |
50 |
51 | class Fragment:
52 | def __init__(self, mode):
53 | self.mode = mode
54 |
55 | def __call__(self, sample):
56 | '''sample: numpy array for image (after ToTensor)'''
57 | if self.mode == 'horizontal':
58 | half = (sample.shape[2]) // 2
59 |             if np.random.rand() > 0.5:  # left half
60 | sample[:,:,:half] = 0
61 | else:
62 | sample[:,:,half:] = 0
63 |
64 | elif self.mode == 'vertical':
65 | half = (sample.shape[1]) // 2
66 |             if np.random.rand() > 0.5:  # upper half
67 | sample[:,:half,:] = 0
68 | else:
69 | sample[:,half:,:] = 0
70 | elif self.mode == '1/4':
71 |             choice = np.random.randint(4)  # pick one of the four quadrants
72 | half1 = (sample.shape[1]) // 2
73 | half2 = (sample.shape[2]) // 2
74 | img = torch.zeros_like(sample)
75 | if choice == 0:
76 | img[:,:half1,:half2] = sample[:,:half1,:half2]
77 | elif choice == 1:
78 | img[:,half1:,:half2] = sample[:,half1:,:half2]
79 | elif choice == 2:
80 | img[:,half1:,half2:] = sample[:,half1:,half2:]
81 | else:
82 | img[:,:half1,half2:] = sample[:,:half1,half2:]
83 | sample = img
84 | else:
85 | raise ValueError
86 | return sample
87 |
88 |
89 | class UniformDequantize:
90 | def __init__(self):
91 | pass
92 |
93 | def __call__(self, img):
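        # Maps ToTensor() outputs k/255 to (k + u)/256 with u ~ Uniform[0, 1),
        # spreading each pixel uniformly over its quantization bin.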
94 | img = img / 256. * 255. + torch.rand_like(img) / 256.
95 | return img
96 |
97 |
98 | class GaussianDequantize:
99 | def __init__(self):
100 | pass
101 |
102 | def __call__(self, img):
103 | img = img + torch.randn_like(img) * 0.01
104 | return img
105 |
106 |
--------------------------------------------------------------------------------
/configs/celeba64_ood_nae/z64gr_h32g8.yml:
--------------------------------------------------------------------------------
1 | trainer: nae
2 | logger: nae
3 | model:
4 | arch: nae
5 | encoder:
6 | arch: conv64
7 | nh: 32
8 | out_activation: linear
9 | activation: relu
10 | use_bn: false
11 | num_groups: 8
12 | decoder:
13 | arch: deconv64
14 | nh: 32
15 | use_bn: false
16 | num_groups: 8
17 | out_activation: sigmoid
18 | nae:
19 | spherical: True
20 | gamma: 1
21 | sampling: on_manifold
22 |
23 | z_step: 20
24 | z_stepsize: 1
25 | z_noise_std: 0.02
26 | z_noise_anneal: Null
27 | z_clip_langevin_grad: Null
28 | x_step: 40
29 | x_stepsize: 20
30 | x_noise_std: 0.02
31 | x_noise_anneal: 1
32 | x_clip_langevin_grad: 0.01
33 |
34 | buffer_size: 10000
35 | replay_ratio: 0.95
36 | replay: True
37 |
38 | x_bound:
39 | - 0
40 | - 1
41 | z_bound: Null
42 | l2_norm_reg: Null
43 | l2_norm_reg_en: Null
44 | temperature: 1.
45 | temperature_trainable: False
46 |
47 | x_dim: 3
48 | z_dim: 64
49 | data:
50 | indist_train:
51 | dataset: CelebA_OOD
52 | path: datasets
53 | batch_size: 128
54 | n_workers: 8
55 | split: training
56 | shuffle: True
57 | size: 64
58 | augmentations:
59 | hflip:
60 | p: 0.5
61 | dequant:
62 | UniformDequantize: {}
63 | indist_val:
64 | dataset: CelebA_OOD
65 | path: datasets
66 | batch_size: 128
67 | n_workers: 8
68 | split: validation
69 | size: 64
70 | ood_val:
71 | dataset: CelebA_OOD
72 | channel: 3
73 | path: datasets
74 | batch_size: 128
75 | split: validation
76 | n_workers: 8
77 | size: 64
78 | dequant:
79 | togray: {}
80 | ood_target:
81 | dataset: ConstantGray_OOD
82 | size: 64
83 | path: datasets
84 | batch_size: 128
85 | n_workers: 8
86 | split: validation
87 | training:
88 | # load_ae: /opt/home3/swyoon/energy-based-autoencoder/src/results/celeba_ood_nae/big/run/model_best.pkl
89 | ae_epoch: 30
90 | nae_epoch: 15
91 | lr_schedule: null
92 | # name: on_plateau
93 | # factor: 0.5
94 | # patience: 3
95 | # min_lr: 1.0e-4
96 | resume: null
97 | # file: /a/b/c.pkl
98 | # optimizer: False
99 | print_interval: 100
100 | val_interval: 200
101 | save_interval: 2000
102 | ae_lr: 1.0e-4
103 | nae_lr: 1.0e-5
104 | temperature_lr: 1.0e-3
105 | fix_D: false
106 |
--------------------------------------------------------------------------------
/configs/cifar_ood_nae/z32gn.yml:
--------------------------------------------------------------------------------
1 | data:
2 | indist_train:
3 | augmentations:
4 | hflip:
5 | p: 0.5
6 | batch_size: 128
7 | dataset: CIFAR10_OOD
8 | dequant:
9 | UniformDequantize: {}
10 | n_workers: 8
11 | path: datasets
12 | shuffle: true
13 | split: training
14 | indist_val:
15 | batch_size: 128
16 | dataset: CIFAR10_OOD
17 | n_workers: 8
18 | path: datasets
19 | split: validation
20 | ood_target:
21 | batch_size: 128
22 | dataset: Constant_OOD
23 | n_workers: 8
24 | path: datasets
25 | size: 32
26 | split: validation
27 | ood_val:
28 | batch_size: 128
29 | channel: 3
30 | dataset: SVHN_OOD
31 | n_workers: 8
32 | path: datasets
33 | split: validation
34 | device: cuda:3
35 | logger: nae
36 | model:
37 | arch: nae
38 | decoder:
39 | arch: sngan_generator_gn
40 | hidden_dim: 128
41 | num_groups: 2
42 | out_activation: sigmoid
43 | encoder:
44 | arch: IGEBMEncoder
45 | keepdim: true
46 | use_spectral_norm: false
47 | nae:
48 | buffer_size: 10000
49 | gamma: 1
50 | l2_norm_reg: null
51 | l2_norm_reg_en: null
52 | replay: true
53 | replay_ratio: 0.95
54 | sampling: on_manifold
55 | spherical: true
56 | temperature: 1.0
57 | temperature_trainable: false
58 | x_bound:
59 | - 0
60 | - 1
61 | x_clip_langevin_grad: 0.01
62 | x_noise_anneal: 1
63 | x_noise_std: 0.02
64 | x_step: 40
65 | x_stepsize: 10
66 | z_bound: null
67 | z_clip_langevin_grad: null
68 | z_noise_anneal: null
69 | z_noise_std: 0.02
70 | z_step: 20
71 | z_stepsize: 1
72 | x_dim: 3
73 | z_dim: 32
74 | trainer: nae
75 | training:
76 | ae_epoch: 180
77 | ae_lr: 0.0001
78 | fix_D: false
79 | lr_schedule: null
80 | nae_epoch: 10
81 | nae_lr: 1.0e-05
82 | print_interval: 352
83 | resume: null
84 | save_interval: 2000
85 | temperature_lr: 0.001
86 | val_interval: 352
87 |
--------------------------------------------------------------------------------
/configs/fmnist_ood_nae/z32.yml:
--------------------------------------------------------------------------------
1 | trainer: nae
2 | logger: nae
3 | model:
4 | arch: nae
5 | encoder:
6 | arch: conv2fc
7 | nh: 8
8 | nh_mlp: 1024
9 | out_activation: linear
10 | decoder:
11 | arch: deconv2
12 | nh: 8
13 | out_activation: sigmoid
14 | nae:
15 | spherical: True
16 | gamma: 1
17 | sampling: on_manifold
18 |
19 | z_step: 10
20 | z_stepsize: 0.2
21 | z_noise_std: 0.05
22 | z_noise_anneal: Null
23 | z_clip_langevin_grad: Null
24 | x_step: 50
25 | x_stepsize: 10
26 | x_noise_std: 0.05
27 | x_noise_anneal: 1
28 | x_clip_langevin_grad: 0.01
29 |
30 | buffer_size: 10000
31 | replay_ratio: 0.95
32 | replay: True
33 |
34 | x_bound:
35 | - 0
36 | - 1
37 | z_bound: Null
38 | l2_norm_reg: Null
39 |     l2_norm_reg_en: 1.0e-4  # decimal point ensures PyYAML parses a float, not a string
40 | temperature: 1.
41 | temperature_trainable: False
42 |
43 | x_dim: 1
44 | z_dim: 32
45 | data:
46 | indist_train:
47 | dataset: FashionMNIST_OOD
48 | path: datasets
49 | shuffle: True
50 | batch_size: 128
51 | n_workers: 8
52 | split: training
53 | size: 28
54 | dequant:
55 | UniformDequantize: {}
56 | indist_val:
57 | dataset: FashionMNIST_OOD
58 | path: datasets
59 | batch_size: 128
60 | n_workers: 8
61 | split: validation
62 | size: 28
63 | ood_val:
64 | dataset: Constant_OOD
65 | size: 28
66 | channel: 1
67 | path: datasets
68 | split: validation
69 | n_workers: 4
70 | batch_size: 128
71 | ood_target:
72 | dataset: MNIST_OOD
73 | size: 28
74 | channel: 1
75 | path: datasets
76 | split: validation
77 | n_workers: 4
78 | batch_size: 128
79 | training:
80 | ae_epoch: 100
81 | nae_epoch: 50
82 | save_interval: 2000
83 | val_interval: 1000
84 | print_interval: 100
85 |   ae_lr: 1.0e-4
86 |   nae_lr: 1.0e-5
87 | fix_D: false
88 |
--------------------------------------------------------------------------------
/configs/mnist_ho_nae/l2_z32.yml:
--------------------------------------------------------------------------------
1 | trainer: nae_v2
2 | logger: base
3 | model:
4 | arch: nae_l2
5 | sampling: omi
6 | encoder:
7 | arch: conv2fc
8 | nh: 8
9 | nh_mlp: 1024
10 | out_activation: spherical
11 | decoder:
12 | arch: deconv2
13 | nh: 8
14 | out_activation: sigmoid
15 | nae:
16 | gamma: 1
17 | l2_norm_reg_de: Null
18 | l2_norm_reg_en: 0.0001
19 | T: 1.
20 | sampler_z:
21 | sampler: langevin
22 | n_step: 10
23 | stepsize: 0.2
24 | noise_std: 0.05
25 | noise_anneal: Null
26 | clip_langevin_grad: Null
27 | buffer_size: 10000
28 | replay_ratio: 0.95
29 | mh: False
30 | bound: spherical
31 | initial_dist: uniform_sphere
32 | sampler_x:
33 | sampler: langevin
34 | n_step: 50
35 | stepsize: 10
36 | noise_std: 0.05
37 | noise_anneal: 1
38 | clip_langevin_grad: 0.01
39 | mh: False
40 | buffer_size: 0
41 | bound: [0, 1]
42 | x_dim: 1
43 | z_dim: 32
44 | data:
45 | indist_train:
46 | dataset: MNISTLeaveOut
47 | path: datasets
48 | batch_size: 128
49 | n_workers: 8
50 | shuffle: True
51 | split: training
52 | out_class:
53 | - 9
54 | indist_val:
55 | dataset: MNISTLeaveOut
56 | path: datasets
57 | batch_size: 128
58 | n_workers: 8
59 | split: validation
60 | out_class:
61 | - 9
62 | ood_val:
63 | dataset: Constant_OOD
64 | size: 28
65 | channel: 1
66 | path: datasets
67 | batch_size: 128
68 | split: validation
69 | n_workers: 8
70 | ood_target:
71 | dataset: MNISTLeaveOut
72 | path: datasets
73 | batch_size: 128
74 | n_workers: 8
75 | split: validation
76 | out_class:
77 | - 0
78 | - 1
79 | - 2
80 | - 3
81 | - 4
82 | - 5
83 | - 6
84 | - 7
85 | - 8
86 | training:
87 | load_ae: Null
88 | ae_epoch: 100
89 | nae_epoch: 50
90 | print_interval: 500
91 | print_interval_nae: 100
92 | val_interval: 352
93 | save_interval: 2000
94 | ae_lr: 1.0e-4
95 | nae_lr: 1.0e-5
96 | nae_opt: all
97 |
--------------------------------------------------------------------------------
/configs/mnist_ood_nae/z32.yml:
--------------------------------------------------------------------------------
1 | trainer: nae
2 | logger: nae
3 | model:
4 | arch: nae
5 | encoder:
6 | arch: conv2fc
7 | nh: 8
8 | nh_mlp: 1024
9 | out_activation: linear
10 | decoder:
11 | arch: deconv2
12 | nh: 8
13 | out_activation: sigmoid
14 | nae:
15 | spherical: True
16 | gamma: 1
17 | sampling: on_manifold
18 |
19 | z_step: 10
20 | z_stepsize: 0.2
21 | z_noise_std: 0.05
22 | z_noise_anneal: Null
23 | z_clip_langevin_grad: Null
24 | x_step: 50
25 | x_stepsize: 10
26 | x_noise_std: 0.05
27 | x_noise_anneal: 1
28 | x_clip_langevin_grad: 0.01
29 |
30 | buffer_size: 10000
31 | replay_ratio: 0.95
32 | replay: True
33 |
34 | x_bound:
35 | - 0
36 | - 1
37 | z_bound: Null
38 | l2_norm_reg: Null
39 |     l2_norm_reg_en: 1.0e-4  # decimal point ensures PyYAML parses a float, not a string
40 | temperature: 1.
41 | temperature_trainable: False
42 |
43 | x_dim: 1
44 | z_dim: 32
45 | data:
46 | indist_train:
47 | dataset: MNIST_OOD
48 | path: datasets
49 | shuffle: True
50 | batch_size: 128
51 | n_workers: 8
52 | split: training
53 | size: 28
54 | dequant:
55 | UniformDequantize: {}
56 | indist_val:
57 | dataset: MNIST_OOD
58 | path: datasets
59 | batch_size: 128
60 | n_workers: 8
61 | split: validation
62 | size: 28
63 | ood_val:
64 | dataset: Constant_OOD
65 | size: 28
66 | channel: 1
67 | path: datasets
68 | split: validation
69 | n_workers: 4
70 | batch_size: 128
71 | ood_target:
72 | dataset: Noise_OOD
73 | size: 28
74 | channel: 1
75 | path: datasets
76 | split: validation
77 | n_workers: 4
78 | batch_size: 128
79 | training:
80 | ae_epoch: 100
81 | nae_epoch: 50
82 | save_interval: 2000
83 | val_interval: 1000
84 | print_interval: 100
85 |   ae_lr: 1.0e-4
86 |   nae_lr: 1.0e-5
87 | fix_D: false
88 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/conftest.py
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/datasets/.gitkeep
--------------------------------------------------------------------------------
/evaluate_ood.py:
--------------------------------------------------------------------------------
1 | """
2 | evaluate OOD detection performance through AUROC score
3 |
4 | Example:
5 |     python evaluate_ood.py --dataset FashionMNIST_OOD \
6 | --ood MNIST_OOD,ConstantGray_OOD \
7 | --resultdir results/fmnist_ood_vqvae/Z7K512/e300 \
8 | --ckpt model_epoch_280.pkl \
9 | --config Z7K512.yml \
10 | --device 1
11 | """
12 | import os
13 | import yaml
14 | import argparse
15 | import copy
16 | import torch
17 | import numpy as np
18 | from torch.utils import data
19 | from models import get_model, load_pretrained
20 | from loaders import get_dataloader
21 |
22 | from utils import roc_btw_arr, batch_run, search_params_intp, parse_unknown_args, parse_nested_args
23 |
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--resultdir', type=str, help='result dir. results/... or pretrained/...')
27 | parser.add_argument('--config', type=str, help='config file name')
28 | parser.add_argument('--ckpt', type=str, help='checkpoint file name to load. default', default=None)
29 | parser.add_argument('--ood', type=str, help='list of OOD datasets, separated by comma')
30 | parser.add_argument('--device', type=str, help='device')
31 | parser.add_argument('--dataset', type=str, choices=['MNIST_OOD', 'CIFAR10_OOD', 'ImageNet32', 'FashionMNIST_OOD',
32 |                                                     'FashionMNISTpad_OOD'],
33 |                     default='MNIST_OOD', help='in-distribution (inlier) dataset')
34 | parser.add_argument('--aug', type=str, help='pre-defiend data augmentation', choices=[None, 'CIFAR10', 'CIFAR10-OE'])
35 | parser.add_argument('--method', type=str, choices=[None, 'likelihood_regret', 'input_complexity', 'outlier_exposure'])
36 | args, unknown = parser.parse_known_args()
37 | d_cmd_cfg = parse_unknown_args(unknown)
38 | d_cmd_cfg = parse_nested_args(d_cmd_cfg)
39 | print(d_cmd_cfg)
40 |
41 |
42 | # load config file
43 | cfg = yaml.load(open(os.path.join(args.resultdir, args.config)), Loader=yaml.FullLoader)
44 | result_dir = args.resultdir
45 | if args.ckpt is not None:
46 | ckpt_file = os.path.join(result_dir, args.ckpt)
47 | else:
48 |     raise ValueError('ckpt file not specified')
49 |
50 | print(f'loading from {ckpt_file}')
51 | l_ood = [s.strip() for s in args.ood.split(',')]
52 | device = f'cuda:{args.device}'
53 |
55 |
56 |
57 | def evaluate(m, in_dl, out_dl, device):
58 | """computes OOD detection score"""
59 | in_pred = batch_run(m, in_dl, device, method='predict')
60 | out_pred = batch_run(m, out_dl, device, method='predict')
61 | auc = roc_btw_arr(out_pred, in_pred)
62 | return auc
63 |
64 |
65 | # load dataset
66 | print('ood datasets')
67 | print(l_ood)
68 | if args.dataset in {'MNIST_OOD', 'FashionMNIST_OOD'}:
69 | size = 28
70 | channel = 1
71 | else:
72 | size = 32
73 | channel = 3
74 | data_dict = {'path': 'datasets',
75 |              'size': size,
76 |              'channel': channel,
77 |              'batch_size': 64,
78 |              'n_workers': 4,
79 |              'split': 'evaluation'}
80 |
81 |
82 |
83 | data_dict_ = copy.copy(data_dict)
84 | data_dict_['dataset'] = args.dataset
85 | in_dl = get_dataloader(data_dict_)
86 |
87 | l_ood_dl = []
88 | for ood_name in l_ood:
89 | data_dict_ = copy.copy(data_dict)
90 | data_dict_['dataset'] = ood_name
91 | dl = get_dataloader(data_dict_)
92 | l_ood_dl.append(dl)
93 |
94 | model = get_model(cfg).to(device)
95 | ckpt_data = torch.load(ckpt_file)
96 | if 'model_state' in ckpt_data:
97 | model.load_state_dict(ckpt_data['model_state'])
98 | else:
99 |     model.load_state_dict(ckpt_data)
100 |
101 | model.eval()
102 | model.to(device)
103 |
104 | in_pred = batch_run(model, in_dl, device=device, no_grad=False)
105 |
106 | l_ood_pred = []
107 | for dl in l_ood_dl:
108 | out_pred = batch_run(model, dl, device=device, no_grad=False)
109 | l_ood_pred.append(out_pred)
110 |
111 | l_ood_auc = []
112 | for pred in l_ood_pred:
113 | l_ood_auc.append(roc_btw_arr(pred, in_pred))
114 |
115 | print('OOD Detection Results in AUC')
116 | for ds, auc in zip(l_ood, l_ood_auc):
117 | print(f'{ds}:{auc:.4f}')
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import json
4 | import torch
5 | from torch.utils import data
6 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop,\
7 | Pad, Normalize
8 |
9 | from loaders.leaveout_dataset import MNISTLeaveOut, CIFAR10LeaveOut
10 | from loaders.modified_dataset import Gray2RGB, MNIST_OOD, FashionMNIST_OOD, \
11 | CIFAR10_OOD, SVHN_OOD, Constant_OOD, \
12 | Noise_OOD, CIFAR100_OOD, CelebA_OOD, \
13 | NotMNIST, ConstantGray_OOD, ImageNet32
14 | from loaders.chimera_dataset import Chimera
15 | from torchvision.datasets import FashionMNIST, Omniglot
16 | from augmentations import get_composed_augmentations
17 | from augmentations.augmentations import ToGray, Invert, Fragment
18 |
19 |
20 | OOD_SIZE = 32 # common image size for OOD detection experiments
21 |
22 |
23 | def get_dataloader(data_dict, mode=None, mode_dict=None, data_aug=None):
24 | """constructs DataLoader
25 | data_dict: data part of cfg
26 |
27 | mode: deprecated argument
28 | mode_dict: deprecated argument
29 | data_aug: deprecated argument
30 |
31 | Example data_dict
32 | dataset: FashionMNISTpad_OOD
33 | path: datasets
34 | shuffle: True
35 | batch_size: 128
36 | n_workers: 8
37 | split: training
38 | dequant:
39 | UniformDequantize: {}
40 | """
41 |
42 | # dataset loading
43 | aug = get_composed_augmentations(data_dict.get('augmentations', None))
44 | dequant = get_composed_augmentations(data_dict.get('dequant', None))
45 | dataset = get_dataset(data_dict, split_type=None, data_aug=aug, dequant=dequant)
46 |
47 | # dataloader loading
48 | loader = data.DataLoader(
49 | dataset,
50 | batch_size=data_dict["batch_size"],
51 | num_workers=data_dict["n_workers"],
52 | shuffle=data_dict.get('shuffle', False),
53 | pin_memory=False,
54 | )
55 |
56 | return loader
57 |
58 |
59 | def get_dataset(data_dict, split_type=None, data_aug=None, dequant=None):
60 | """
61 | split_type: deprecated argument
62 | """
63 | do_concat = any([k.startswith('concat') for k in data_dict.keys()])
64 | if do_concat:
65 | if data_aug is not None:
66 | return data.ConcatDataset([get_dataset(d, data_aug=data_aug) for k, d in data_dict.items() if k.startswith('concat')])
67 | elif dequant is not None:
68 | return data.ConcatDataset([get_dataset(d, dequant=dequant) for k, d in data_dict.items() if k.startswith('concat')])
69 | else: return data.ConcatDataset([get_dataset(d) for k, d in data_dict.items() if k.startswith('concat')])
70 | name = data_dict["dataset"]
71 | split_type = data_dict['split']
72 | data_path = data_dict["path"][split_type] if split_type in data_dict["path"] else data_dict["path"]
73 |
74 |     # default transform behavior.
75 | original_data_aug = data_aug
76 | if data_aug is not None:
77 | #data_aug = Compose([data_aug, ToTensor()])
78 | data_aug = Compose([ToTensor(), data_aug])
79 | else:
80 | data_aug = ToTensor()
81 |
82 | if dequant is not None: # dequantization should be applied last
83 | data_aug = Compose([data_aug, dequant])
84 |
85 |
86 | # datasets
87 | if name == 'MNISTLeaveOut':
88 | l_out_class = data_dict['out_class']
89 | dataset = MNISTLeaveOut(data_path, l_out_class=l_out_class, split=split_type, download=True,
90 | transform=data_aug)
91 | elif name == 'MNISTLeaveOutFragment':
92 | l_out_class = data_dict['out_class']
93 | fragment = data_dict['fragment']
94 | dataset = MNISTLeaveOut(data_path, l_out_class=l_out_class, split=split_type, download=True,
95 | transform=Compose([ToTensor(),
96 | Fragment(fragment)]))
97 | elif name == 'MNIST_OOD':
98 | size = data_dict.get('size', 28)
99 | if size == 28:
100 | l_transform = [ToTensor()]
101 | else:
102 | l_transform = [Gray2RGB(), Resize(OOD_SIZE), ToTensor()]
103 | dataset = MNIST_OOD(data_path, split=split_type, download=True,
104 | transform=Compose(l_transform))
105 | dataset.img_size = (size, size)
106 |
107 | elif name == 'MNISTpad_OOD':
108 | dataset = MNIST_OOD(data_path, split=split_type, download=True,
109 | transform=Compose([Gray2RGB(),
110 | Pad(2),
111 | ToTensor()]))
112 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
113 |
114 | elif name == 'FashionMNIST_OOD':
115 | size = data_dict.get('size', 28)
116 | if size == 28:
117 | l_transform = [ToTensor()]
118 | else:
119 | l_transform = [Gray2RGB(), Resize(OOD_SIZE), ToTensor()]
120 |
121 | dataset = FashionMNIST_OOD(data_path, split=split_type, download=True,
122 | transform=Compose(l_transform))
123 | dataset.img_size = (size, size)
124 |
125 | elif name == 'FashionMNISTpad_OOD':
126 | dataset = FashionMNIST_OOD(data_path, split=split_type, download=True,
127 | transform=Compose([Gray2RGB(),
128 | Pad(2),
129 | ToTensor()]))
130 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
131 |
132 | elif name == 'HalfMNIST':
133 | mnist = MNIST_OOD(data_path, split=split_type, download=True,
134 | transform=ToTensor())
135 | dataset = Chimera(mnist, mode='horizontal_blank')
136 | elif name == 'ChimeraMNIST':
137 | mnist = MNIST_OOD(data_path, split=split_type, download=True,
138 | transform=ToTensor())
139 | dataset = Chimera(mnist, mode='horizontal')
140 | elif name == 'CIFAR10_OOD':
141 | dataset = CIFAR10_OOD(data_path, split=split_type, download=True,
142 | transform=data_aug)
143 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
144 |
145 | elif name == 'CIFAR10LeaveOut':
146 | l_out_class = data_dict['out_class']
147 | seed = data_dict.get('seed', 1)
148 | dataset = CIFAR10LeaveOut(data_path, l_out_class=l_out_class, split=split_type, download=True,
149 | transform=data_aug, seed=seed)
150 |
151 | elif name == 'CIFAR10_GRAY':
152 | dataset = CIFAR10_OOD(data_path, split=split_type, download=True,
153 | transform=Compose([ToTensor(),
154 | ToGray()]))
155 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
156 |
157 |
158 | elif name == 'CIFAR100_OOD':
159 | dataset = CIFAR100_OOD(data_path, split=split_type, download=True,
160 | transform=ToTensor())
161 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
162 |
163 | elif name == 'SVHN_OOD':
164 | dataset = SVHN_OOD(data_path, split=split_type, download=True,
165 | transform=data_aug)
166 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
167 |
168 | elif name == 'Constant_OOD':
169 | size = data_dict.get('size', OOD_SIZE)
170 | channel = data_dict.get('channel', 3)
171 | dataset = Constant_OOD(data_path, split=split_type, size=(size, size),
172 | channel=channel,
173 | transform=ToTensor())
174 |
175 | elif name == 'ConstantGray_OOD':
176 | size = data_dict.get('size', OOD_SIZE)
177 | channel = data_dict.get('channel', 3)
178 | dataset = ConstantGray_OOD(data_path, split=split_type, size=(size, size),
179 | channel=channel,
180 | transform=ToTensor())
181 |
182 | elif name == 'Noise_OOD':
183 | channel = data_dict.get('channel', 3)
184 | size = data_dict.get('size', OOD_SIZE)
185 | dataset = Noise_OOD(data_path, split=split_type,
186 | transform=ToTensor(), channel=channel, size=(size, size))
187 |
188 | elif name == 'CelebA_OOD':
189 | size = data_dict.get('size', OOD_SIZE)
190 | l_aug = []
191 | l_aug.append(CenterCrop(140))
192 | l_aug.append(Resize(size))
193 | if original_data_aug is not None:
194 | l_aug.append(original_data_aug)
195 | l_aug.append(ToTensor())
196 | if dequant is not None:
197 | l_aug.append(dequant)
198 | data_aug = Compose(l_aug)
199 | dataset = CelebA_OOD(data_path, split=split_type,
200 | transform=data_aug)
201 | dataset.img_size = (OOD_SIZE, OOD_SIZE)
202 |
203 |     elif name == 'FashionMNIST':  # plain FashionMNIST
204 | dataset = FashionMNIST_OOD(data_path, split=split_type, download=True,
205 | transform=ToTensor())
206 | dataset.img_size = (28, 28)
207 | elif name == 'MNIST': # normal MNIST
208 | dataset = MNIST_OOD(data_path, split=split_type, download=True,
209 | transform=ToTensor())
210 | dataset.img_size = (28, 28)
211 | elif name == 'NotMNIST':
212 | dataset = NotMNIST(data_path, split=split_type, transform=ToTensor())
213 | dataset.img_size = (28, 28)
214 | elif name == 'Omniglot':
215 | size = data_dict.get('size', OOD_SIZE)
216 | invert = data_dict.get('invert', True) # invert pixel intensity: x -> 1 - x
217 | if split_type == 'training':
218 | background = True
219 | else:
220 | background = False
221 |
222 | if invert:
223 | tr = Compose([Resize(size), ToTensor(), Invert()])
224 | else:
225 | tr = Compose([Resize(size), ToTensor()])
226 |
227 | dataset = Omniglot(data_path, background=background, download=False,
228 | transform=tr)
229 | elif name == 'ImageNet32':
230 | train_split_ratio = data_dict.get('train_split_ratio', 0.8)
231 | seed = data_dict.get('seed', 1)
232 | dataset = ImageNet32(data_path, split=split_type, transform=ToTensor(), seed=seed,
233 | train_split_ratio=train_split_ratio)
234 | else:
235 | n_classes = data_dict["n_classes"]
236 | split = data_dict['split'][split_type]
237 |
238 | param_dict = copy.deepcopy(data_dict)
239 | param_dict.pop("dataset")
240 | param_dict.pop("path")
241 | param_dict.pop("n_classes")
242 | param_dict.pop("split")
243 | param_dict.update({"split_type": split_type})
244 |
245 |
246 | dataset_instance = _get_dataset_instance(name)
247 | dataset = dataset_instance(
248 | data_path,
249 | n_classes,
250 | split=split,
251 | augmentations=data_aug,
252 | is_transform=True,
253 | **param_dict,
254 | )
255 |
256 | return dataset
257 |
258 |
259 | def _get_dataset_instance(name):
260 | """get_loader
261 |
262 | :param name:
263 | """
264 | return {
265 | "basic": basic_dataset,
266 | "inmemory": InMemoryDataset,
267 | }[name]
268 |
269 |
270 |
271 | def np_to_loader(l_tensors, batch_size, num_workers, load_all=False, shuffle=False):
272 | '''Convert a list of numpy arrays to a torch.DataLoader'''
273 | if load_all:
274 | dataset = data.TensorDataset(*[torch.Tensor(X).cuda() for X in l_tensors])
275 | num_workers = 0
276 | return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
277 | else:
278 | dataset = data.TensorDataset(*[torch.Tensor(X) for X in l_tensors])
279 | return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle, pin_memory=False)
280 |
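# Example usage (sketch; X and y stand for arbitrary numpy arrays of equal length):
#   dl = np_to_loader([X, y], batch_size=128, num_workers=4, shuffle=True)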
--------------------------------------------------------------------------------
/loaders/chimera_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 |
4 | class Chimera(Dataset):
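    """Wraps a base dataset to produce chimera images (used for HalfMNIST / ChimeraMNIST).

    mode='horizontal': replaces the top half of each image with the top half of an
    image of a different class and blanks a thin band at the seam; the label becomes
    '<label1>,<label2>'.
    mode='horizontal_blank': zeroes out the top or bottom half, alternating by index.
    """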
5 | def __init__(self, ds, mode='horizontal'):
6 | self.ds = ds
7 | self.mode = mode
8 |
9 | def __len__(self):
10 | return len(self.ds)
11 |
12 | def __getitem__(self, i):
13 | img, lbl = self.ds[i]
14 | if self.mode == 'horizontal':
15 | half = (img.shape[2]) // 2
16 | img2, lbl2 = self.select_diff_digit(lbl, i)
17 |
18 | img[:, :half, :] = img2[:, :half, :]
19 |
20 | img[:, half-3:half+3, :] = 0
21 |
22 | return img, str(f'{lbl},{lbl2}')
23 |
24 | elif self.mode == 'horizontal_blank':
25 | half = (img.shape[2]) // 2
26 | if i % 2 == 0:
27 | img[:, :half, :] = 0.
28 | else:
29 | img[:, half:, :] = 0.
30 | return img, lbl
31 | else:
32 | raise ValueError
33 |
34 | def select_diff_digit(self, digit, i):
35 | new_digit = digit
36 | idx = i
37 | while new_digit == digit:
38 | idx += 1
39 | idx = idx % len(self.ds)
40 | # idx = np.random.randint(len(self.ds))
41 | img , lbl = self.ds[idx]
42 | new_digit = lbl
43 | return img, lbl
44 |
--------------------------------------------------------------------------------
/loaders/leaveout_dataset.py:
--------------------------------------------------------------------------------
1 | '''Dataset Class'''
2 | import os
3 | import sys
4 | import pickle
5 | import numpy as np
6 | import torch
7 | from torchvision.datasets.mnist import MNIST
8 | from torchvision.datasets.cifar import CIFAR10
9 | from utils import get_shuffled_idx
10 |
11 |
12 | class MNISTLeaveOut(MNIST):
13 | """
14 | MNIST Dataset with some digits excluded.
15 |
16 |     Samples of the excluded digits are removed; `targets` holds the digit labels of the remaining samples.
17 |
18 | See also the original MNIST class:
19 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/mnist.html#MNIST
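
    Example (sketch):
        ds = MNISTLeaveOut('datasets', l_out_class=[9], split='training', download=True)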
20 | """
21 | img_size = (28, 28)
22 |
23 | def __init__(self, root, l_out_class, split='training', transform=None, target_transform=None,
24 | download=False):
25 | """
26 |         l_out_class : a list of ints. These classes are excluded from training.
27 | """
28 | super(MNISTLeaveOut, self).__init__(root, transform=transform,
29 | target_transform=target_transform, download=download)
30 | if split == 'training' or split == 'validation':
31 | self.train = True # training set or test set
32 | else:
33 | self.train = False
34 | self.split = split
35 | self.l_out_class = list(l_out_class)
36 | for c in l_out_class:
37 | assert c in set(list(range(10)))
38 | set_out_class = set(l_out_class)
39 |
40 | if download:
41 | self.download()
42 |
43 | if not self._check_exists():
44 | raise RuntimeError('Dataset not found.' +
45 | ' You can use download=True to download it')
46 |
47 | if self.train:
48 | data_file = self.training_file
49 | else:
50 | data_file = self.test_file
51 |
52 | data, targets = torch.load(os.path.join(self.processed_folder, data_file))
53 |
54 | if split == 'training':
55 | data = data[:50000]
56 | targets = targets[:50000]
57 | elif split == 'validation':
58 | data = data[50000:]
59 | targets = targets[50000:]
60 |
61 | out_idx = torch.zeros(len(data), dtype=torch.bool) # pytorch 1.2
62 | for c in l_out_class:
63 | out_idx = out_idx | (targets == c)
64 |
65 | self.data = data[~out_idx]
66 | self.digits = targets[~out_idx]
67 | self.targets = self.digits
68 |
69 | @property
70 | def raw_folder(self):
71 | return os.path.join(self.root, 'MNIST', 'raw')
72 |
73 | @property
74 | def processed_folder(self):
75 | return os.path.join(self.root, 'MNIST', 'processed')
76 |
77 |
78 | class CIFAR10LeaveOut(CIFAR10):
79 | def __init__(self, root, l_out_class, split='training', transform=None, target_transform=None,
80 | download=True, seed=1):
81 |
82 | super(CIFAR10LeaveOut, self).__init__(root, transform=transform,
83 | target_transform=target_transform, download=True)
84 | assert split in ('training', 'validation', 'evaluation')
85 |
86 | if split == 'training' or split == 'validation':
87 | self.train = True
88 | shuffle_idx = get_shuffled_idx(50000, seed)
89 | else:
90 | self.train = False
91 | self.split = split
92 |
93 | if download:
94 | self.download()
95 |
96 | if not self._check_integrity():
97 | raise RuntimeError('Dataset not found or corrupted.' +
98 | ' You can use download=True to download it')
99 |
100 | if self.train:
101 | downloaded_list = self.train_list
102 | else:
103 | downloaded_list = self.test_list
104 |
105 | self.data = []
106 | self.targets = []
107 |
108 |         # now load the pickled numpy arrays
109 | for file_name, checksum in downloaded_list:
110 | file_path = os.path.join(self.root, self.base_folder, file_name)
111 | with open(file_path, 'rb') as f:
112 | if sys.version_info[0] == 2:
113 | entry = pickle.load(f)
114 | else:
115 | entry = pickle.load(f, encoding='latin1')
116 | self.data.append(entry['data'])
117 | if 'labels' in entry:
118 | self.targets.extend(entry['labels'])
119 | else:
120 | self.targets.extend(entry['fine_labels'])
121 |
122 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
123 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
124 | self.targets = np.array(self.targets)
125 |
126 | if split == 'training':
127 | self.data = self.data[shuffle_idx][:45000]
128 | self.targets = self.targets[shuffle_idx][:45000]
129 | elif split == 'validation':
130 | self.data = self.data[shuffle_idx][45000:]
131 | self.targets = self.targets[shuffle_idx][45000:]
132 |
133 | out_idx = torch.zeros(len(self.data), dtype=torch.bool)
134 |
135 | for c in l_out_class:
136 | out_idx = out_idx | (self.targets == c)
137 | out_idx = out_idx.bool()
138 |
139 | self.data = self.data[~out_idx]
140 | self.targets = self.targets[~out_idx]
141 | self._load_meta()
142 |
--------------------------------------------------------------------------------
/loaders/modified_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | modified_dataset.py
3 | ===================
4 | Inherited and modified pytorch datasets
5 | """
6 | import os
7 | import re
8 | import sys
9 | import numpy as np
10 | from scipy.io import loadmat
11 | import pickle
12 | from PIL import Image
13 | import torch
14 | from torch.utils.data import Dataset
15 | from torchvision.datasets.mnist import MNIST, FashionMNIST
16 | from torchvision.datasets.cifar import CIFAR10
17 | from torchvision.datasets.svhn import SVHN
18 | from torchvision.datasets.utils import verify_str_arg
19 | from utils import get_shuffled_idx
20 |
21 |
22 | class Gray2RGB:
23 | """change grayscale PIL image to RGB format.
24 | channel values are copied"""
25 | def __call__(self, x):
26 | return x.convert('RGB')
27 |
28 |
29 | class MNIST_OOD(MNIST):
30 | """
31 | See also the original MNIST class:
32 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/mnist.html#MNIST
33 | """
34 | def __init__(self, root, split='training', transform=None, target_transform=None,
35 | download=False, seed=1):
36 | super(MNIST_OOD, self).__init__(root, transform=transform,
37 | target_transform=target_transform, download=download)
38 | assert split in ('training', 'validation', 'evaluation')
39 | if split == 'training' or split == 'validation':
40 | self.train = True
41 | shuffle_idx = get_shuffled_idx(60000, seed)
42 | else:
43 | self.train = False
44 | self.split = split
45 |
46 | if download:
47 | self.download()
48 |
49 | if not self._check_exists():
50 | raise RuntimeError('Dataset not found.' +
51 | ' You can use download=True to download it')
52 |
53 | if self.train:
54 | data_file = self.training_file
55 | else:
56 | data_file = self.test_file
57 |
58 | data, targets = self.data, self.targets
59 |
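        # fixed 54,000/6,000 train/validation split of the 60,000 MNIST training images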
60 | if split == 'training':
61 | self.data = data[shuffle_idx][:54000]
62 | self.targets = targets[shuffle_idx][:54000]
63 | elif split == 'validation':
64 | self.data = data[shuffle_idx][54000:]
65 | self.targets = targets[shuffle_idx][54000:]
66 | elif split == 'evaluation':
67 | self.data = data
68 | self.targets = targets
69 |
70 |
71 | class FashionMNIST_OOD(FashionMNIST):
72 | def __init__(self, root, split='training', transform=None, target_transform=None,
73 | download=False, seed=1):
74 | super(FashionMNIST_OOD, self).__init__(root, transform=transform,
75 | target_transform=target_transform, download=download)
76 | assert split in ('training', 'validation', 'evaluation')
77 | if split == 'training' or split == 'validation':
78 | self.train = True
79 | shuffle_idx = get_shuffled_idx(60000, seed)
80 | else:
81 | self.train = False
82 | self.split = split
83 |
84 | if download:
85 | self.download()
86 |
87 | if not self._check_exists():
88 | raise RuntimeError('Dataset not found.' +
89 | ' You can use download=True to download it')
90 |
91 | if self.train:
92 | data_file = self.training_file
93 | else:
94 | data_file = self.test_file
95 |
96 | data, targets = self.data, self.targets
97 |
98 | if split == 'training':
99 | self.data = data[shuffle_idx][:54000]
100 | self.targets = targets[shuffle_idx][:54000]
101 | elif split == 'validation':
102 | self.data = data[shuffle_idx][54000:]
103 | self.targets = targets[shuffle_idx][54000:]
104 | elif split == 'evaluation':
105 | self.data = data
106 | self.targets = targets
107 |
108 |
109 | class CIFAR10_OOD(CIFAR10):
110 | def __init__(self, root, split='training', transform=None, target_transform=None,
111 | download=True):
112 |
113 | super(CIFAR10_OOD, self).__init__(root, transform=transform,
114 | target_transform=target_transform, download=True)
115 | assert split in ('training', 'validation', 'evaluation', 'training_full')
116 |
117 | if split == 'training' or split == 'validation' or split == 'training_full':
118 | self.train = True
119 | shuffle_idx = np.load(os.path.join(root, 'cifar10_trainval_idx.npy'))
120 | else:
121 | self.train = False
122 | self.split = split
123 |
124 | if download:
125 | self.download()
126 |
127 | if not self._check_integrity():
128 | raise RuntimeError('Dataset not found or corrupted.' +
129 | ' You can use download=True to download it')
130 |
131 | if self.train:
132 | downloaded_list = self.train_list
133 | else:
134 | downloaded_list = self.test_list
135 |
136 | self.data = []
137 | self.targets = []
138 |
139 |         # now load the pickled numpy arrays
140 | for file_name, checksum in downloaded_list:
141 | file_path = os.path.join(self.root, self.base_folder, file_name)
142 | with open(file_path, 'rb') as f:
143 | if sys.version_info[0] == 2:
144 | entry = pickle.load(f)
145 | else:
146 | entry = pickle.load(f, encoding='latin1')
147 | self.data.append(entry['data'])
148 | if 'labels' in entry:
149 | self.targets.extend(entry['labels'])
150 | else:
151 | self.targets.extend(entry['fine_labels'])
152 |
153 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
154 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
155 | self.targets = np.array(self.targets)
156 |
157 | if split == 'training':
158 | self.data = self.data[shuffle_idx][:45000]
159 | self.targets = self.targets[shuffle_idx][:45000]
160 | elif split == 'validation':
161 | self.data = self.data[shuffle_idx][45000:]
162 | self.targets = self.targets[shuffle_idx][45000:]
163 | elif split == 'training_full':
164 | pass
165 |
166 | self._load_meta()
167 |
168 |
169 | class CIFAR100_OOD(CIFAR10_OOD):
170 | base_folder = 'cifar-100-python'
171 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
172 | filename = "cifar-100-python.tar.gz"
173 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
174 | train_list = [
175 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
176 | ]
177 |
178 | test_list = [
179 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
180 | ]
181 | meta = {
182 | 'filename': 'meta',
183 | 'key': 'fine_label_names',
184 | 'md5': '7973b15100ade9c7d40fb424638fde48',
185 | }
186 |
187 |
188 | class SVHN_OOD(SVHN):
189 |
190 | def __init__(self, root, split='training', transform=None, target_transform=None,
191 | download=False):
192 |         super(SVHN, self).__init__(root, transform=transform,
193 |                                    target_transform=target_transform)  # bypasses SVHN.__init__; data is loaded manually below
194 | assert split in ('training', 'validation', 'evaluation')
195 | if split == 'training' or split == 'validation':
196 | svhn_split = 'train'
197 | shuffle_idx = np.load(os.path.join(root, 'svhn_trainval_idx.npy'))
198 | else:
199 | svhn_split = 'test'
200 |
201 | # self.split = verify_str_arg(svhn_split, "split", tuple(self.split_list.keys()))
202 | self.split = svhn_split # special treatment
203 | self.url = self.split_list[svhn_split][0]
204 | self.filename = self.split_list[svhn_split][1]
205 | self.file_md5 = self.split_list[svhn_split][2]
206 |
207 | if download:
208 | self.download()
209 |
210 | if not self._check_integrity():
211 | raise RuntimeError('Dataset not found or corrupted.' +
212 | ' You can use download=True to download it')
213 |
214 | # import here rather than at top of file because this is
215 | # an optional dependency for torchvision
216 | import scipy.io as sio
217 |
218 | # reading(loading) mat file as array
219 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
220 |
221 | self.data = loaded_mat['X']
222 | # loading from the .mat file gives an np array of type np.uint8
223 | # converting to np.int64, so that we have a LongTensor after
224 | # the conversion from the numpy array
225 | # the squeeze is needed to obtain a 1D tensor
226 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
227 |
228 | # the svhn dataset assigns the class label "10" to the digit 0
229 | # this makes it inconsistent with several loss functions
230 | # which expect the class labels to be in the range [0, C-1]
231 | np.place(self.labels, self.labels == 10, 0)
232 | self.data = np.transpose(self.data, (3, 2, 0, 1))
233 |
234 | if split == 'training':
235 | self.data = self.data[shuffle_idx][:65930]
236 | self.labels = self.labels[shuffle_idx][:65930]
237 | elif split == 'validation':
238 | self.data = self.data[shuffle_idx][65930:]
239 | self.labels = self.labels[shuffle_idx][65930:]
240 |
241 |
242 | class Constant_OOD(Dataset):
243 | def __init__(self, root, split='training', size=(32, 32), transform=None, channel=3):
244 | super(Constant_OOD, self).__init__()
245 | assert split in ('training', 'validation', 'evaluation')
246 | self.transform = transform
247 | self.root = root
248 | self.img_size = size
249 | self.channel = channel
250 | self.const_vals = np.load(os.path.join(root, 'const_img.npy')) # (40,000, 3) array
251 |
252 | if split == 'training':
253 | self.const_vals = self.const_vals[:32000]
254 | elif split == 'validation':
255 | self.const_vals = self.const_vals[32000:36000]
256 | elif split == 'evaluation':
257 | self.const_vals = self.const_vals[36000:]
258 |
259 | def __getitem__(self, index):
260 | img = np.ones(self.img_size + (self.channel,), dtype=np.float32) * self.const_vals[index] / 255 # (H, W, C)
261 | img = img.astype(np.float32)
262 | if self.channel == 1:
263 | img = img[:, :, 0:1]
264 |
265 | if self.transform is not None:
266 | img = self.transform(img)
267 | return img, 0
268 |
269 | def __len__(self):
270 | return len(self.const_vals)
271 |
272 |
273 | class ConstantGray_OOD(Dataset):
274 | def __init__(self, root, split='training', size=(32, 32), transform=None, channel=3):
275 | super(ConstantGray_OOD, self).__init__()
276 | assert split in ('training', 'validation', 'evaluation')
277 | self.transform = transform
278 | self.root = root
279 | self.img_size = size
280 | self.channel = channel
281 | self.const_vals = np.load(os.path.join(root, 'const_img_gray.npy')) # (40,000,) array
282 |
283 | if split == 'training':
284 | self.const_vals = self.const_vals[:32000]
285 | elif split == 'validation':
286 | self.const_vals = self.const_vals[32000:36000]
287 | elif split == 'evaluation':
288 | self.const_vals = self.const_vals[36000:]
289 |
290 | def __getitem__(self, index):
291 | img = np.ones(self.img_size + (self.channel,), dtype=np.float32) * self.const_vals[index] / 255 # (H, W, C)
292 | img = img.astype(np.float32)
293 |
294 | if self.transform is not None:
295 | img = self.transform(img)
296 | return img, 0
297 |
298 | def __len__(self):
299 | return len(self.const_vals)
300 |
301 |
302 | class Noise_OOD(Dataset):
303 | def __init__(self, root, split='training', transform=None, channel=3, size=(32,32)):
304 | super(Noise_OOD, self).__init__()
305 | assert split in ('training', 'validation', 'evaluation')
306 | self.transform = transform
307 | self.root = root
308 | self.vals = np.load(os.path.join(root, 'noise_img.npy')) # (40000, 32, 32, 3) array
309 | self.channel = channel
310 | self.size = size
311 |
312 | if split == 'training':
313 | self.vals = self.vals[:32000]
314 | elif split == 'validation':
315 | self.vals = self.vals[32000:36000]
316 | elif split == 'evaluation':
317 | self.vals = self.vals[36000:]
318 |
319 | def __getitem__(self, index):
320 | img = self.vals[index] / 255
321 |
322 | img = img.astype(np.float32)
323 | if self.channel == 1:
324 | img = img[:, :, 0:1]
325 | if self.size != (32, 32):
326 | img = img[:self.size[0], :self.size[1], :]
327 | if self.transform is not None:
328 | img = self.transform(img)
329 | return img, 0
330 |
331 | def __len__(self):
332 | return len(self.vals)
333 |
334 |
335 | #
336 | # CelebA
337 | #
338 | IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".bmp"]
339 | ATTR_ANNO = "list_attr_celeba.txt"
340 |
341 |
342 | def _is_image(fname):
343 | _, ext = os.path.splitext(fname)
344 |     return ext.lower() in IMAGE_EXTENSIONS
345 |
346 |
347 | def _find_images_and_annotation(root_dir):
348 | images = {}
349 | attr = None
350 |     assert os.path.exists(root_dir), "{} does not exist".format(root_dir)
351 | img_dir = os.path.join(root_dir, 'Img/img_align_celeba')
352 | for fname in os.listdir(img_dir):
353 | if _is_image(fname):
354 | path = os.path.join(img_dir, fname)
355 | images[os.path.splitext(fname)[0]] = path
356 |
357 |     attr = os.path.join(root_dir, 'Anno', ATTR_ANNO)
358 |     assert os.path.exists(attr), "Failed to find `list_attr_celeba.txt`"
359 |
360 |     # parse the annotation file and pair each image with its attribute vector
361 | final = []
362 | with open(attr, "r") as fin:
363 | image_total = 0
364 | attrs = []
365 | for i_line, line in enumerate(fin):
366 | line = line.strip()
367 | if i_line == 0:
368 | image_total = int(line)
369 | elif i_line == 1:
370 | attrs = line.split(" ")
371 | else:
372 | line = re.sub("[ ]+", " ", line)
373 | line = line.split(" ")
374 | fname = os.path.splitext(line[0])[0]
375 | onehot = [int(int(d) > 0) for d in line[1:]]
376 | assert len(onehot) == len(attrs), "{} only has {} attrs < {}".format(
377 | fname, len(onehot), len(attrs))
378 | final.append({
379 | "path": images[fname],
380 | "attr": onehot
381 | })
382 | print("Find {} images, with {} attrs".format(len(final), len(attrs)))
383 | return final, attrs
384 |
385 |
386 | class CelebA_OOD(Dataset):
387 | def __init__(self, root_dir, split='training', transform=None, seed=1):
388 | """attributes are not implemented"""
389 | super().__init__()
390 | assert split in ('training', 'validation', 'evaluation')
391 | if split == 'training':
392 | setnum = 0
393 | elif split == 'validation':
394 | setnum = 1
395 | elif split == 'evaluation':
396 | setnum = 2
397 | else:
398 | raise ValueError(f'Unexpected split {split}')
399 |
400 | d_split = self.read_split_file(root_dir)
401 | self.data = d_split[setnum]
402 | self.transform = transform
403 | self.split = split
404 | self.root_dir = os.path.join(root_dir, 'CelebA', 'Img', 'img_align_celeba')
405 |
406 |
407 | def __getitem__(self, index):
408 | filename = self.data[index]
409 | path = os.path.join(self.root_dir, filename)
410 | image = Image.open(path).convert("RGB")
411 | if self.transform is not None:
412 | image = self.transform(image)
413 | return image, 0.
414 |
415 | def __len__(self):
416 | return len(self.data)
417 |
418 | def read_split_file(self, root_dir):
419 | split_path = os.path.join(root_dir, 'CelebA', 'Eval', 'list_eval_partition.txt')
420 | d_split = {0:[], 1:[], 2:[]}
421 | with open(split_path) as f:
422 | for line in f:
423 | fname, setnum = line.strip().split()
424 | d_split[int(setnum)].append(fname)
425 | return d_split
426 |
427 |
428 | class NotMNIST(Dataset):
429 | def __init__(self, root_dir, split='training', transform=None):
430 | super().__init__()
431 | self.transform = transform
432 | shuffle_idx = np.load(os.path.join(root_dir, 'notmnist_trainval_idx.npy'))
433 | datadict = loadmat(os.path.join(root_dir, 'NotMNIST/notMNIST_small.mat'))
434 | data = datadict['images'].transpose(2, 0, 1).astype('float32')
435 | data = data[shuffle_idx]
436 | targets = datadict['labels'].astype('float32')
437 | targets = targets[shuffle_idx]
438 | if split == 'training':
439 | self.data = data[:14979]
440 | self.targets = targets[:14979]
441 | elif split == 'validation':
442 | self.data = data[14979:16851]
443 | self.targets = targets[14979:16851]
444 | elif split == 'evaluation':
445 | self.data = data[16851:]
446 | self.targets = targets[16851:]
447 |         else:
448 |             raise ValueError(f'Unexpected split {split}')
449 |
450 | def __getitem__(self, index):
451 | image = self.data[index]
452 | if self.transform is not None:
453 | image = self.transform(image)
454 | return image, self.targets[index]
455 |
456 | def __len__(self):
457 | return len(self.data)
458 |
459 |
460 | class ImageNet32(Dataset):
461 | def __init__(self, root_dir, split='training', transform=None, seed=1, train_split_ratio=0.8):
462 | """
463 | split: 'training' - the whole train split (1281149)
464 | 'evaluation' - the whole val split (49999)
465 | 'train_train' - (train_split_ratio) portion of train split
466 | 'train_val' - (1 - train_split_ratio) portion of train split
467 | """
468 | super().__init__()
469 | self.root_dir = os.path.join(root_dir, 'ImageNet32')
470 | self.split = split
471 | self.transform = transform
472 | self.shuffle_idx = get_shuffled_idx(1281149, seed)
473 | n_train = int(len(self.shuffle_idx) * train_split_ratio)
474 |
475 | if split == 'training': # whole train split
476 | self.imgdir = os.path.join(self.root_dir, 'train_32x32')
477 | self.l_img_file = sorted(os.listdir(self.imgdir))
478 | elif split == 'evaluation': # whole val split
479 | self.imgdir = os.path.join(self.root_dir, 'valid_32x32')
480 | self.l_img_file = sorted(os.listdir(self.imgdir))
481 | elif split == 'train_train': # 80 % of train split
482 | self.imgdir = os.path.join(self.root_dir, 'train_32x32')
483 | self.l_img_file = sorted(os.listdir(self.imgdir))
484 | self.l_img_file = [self.l_img_file[i] for i in self.shuffle_idx[:n_train]]
485 | elif split == 'train_val': # 20 % of train split
486 | self.imgdir = os.path.join(self.root_dir, 'train_32x32')
487 | self.l_img_file = sorted(os.listdir(self.imgdir))
488 | self.l_img_file = [self.l_img_file[i] for i in self.shuffle_idx[n_train:]]
489 | else:
490 |             raise ValueError(f'Unexpected split {split}')
491 |
492 | def __getitem__(self, index):
493 | imgpath = os.path.join(self.imgdir, self.l_img_file[index])
494 | im = Image.open(imgpath)
495 | if self.transform is not None:
496 | im = self.transform(im)
497 | return im, 0
498 |
499 | def __len__(self):
500 | return len(self.l_img_file)
501 |
--------------------------------------------------------------------------------
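A minimal usage sketch for the *_OOD loaders above (illustrative, not part of the repository): each one carves a held-out validation set out of the official train split via a fixed shuffling index, so 'training', 'validation', and 'evaluation' are disjoint. The root path and transform below are assumptions for the example.

import torchvision.transforms as transforms
from loaders.modified_dataset import MNIST_OOD

tf = transforms.ToTensor()
train_set = MNIST_OOD('datasets', split='training', transform=tf, download=True)   # 54,000 images
val_set = MNIST_OOD('datasets', split='validation', transform=tf, download=True)   # 6,000 held-out images
test_set = MNIST_OOD('datasets', split='evaluation', transform=tf, download=True)  # official test split
x, y = train_set[0]  # (1, 28, 28) tensor and an integer label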
/loaders/synthetic.py:
--------------------------------------------------------------------------------
1 | '''
2 | Synthetic distributions from https://github.com/nicola-decao/BNAF/
3 | '''
4 | import torch
5 | import numpy as np
6 |
7 | def sample2d(data, batch_size=200):
8 | rng = np.random.RandomState()
9 |
10 | if data == '8gaussians':
11 | scale = 4.
12 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)),
13 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2),
14 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]
15 | centers = [(scale * x, scale * y) for x, y in centers]
16 |
17 | dataset = []
18 | for i in range(batch_size):
19 | point = rng.randn(2) * 0.5
20 | idx = rng.randint(8)
21 | center = centers[idx]
22 | point[0] += center[0]
23 | point[1] += center[1]
24 | dataset.append(point)
25 | dataset = np.array(dataset, dtype='float32')
26 | dataset /= 1.414
27 | return dataset
28 |
29 | elif data == '2spirals':
30 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360
31 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5
32 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5
33 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3
34 | x += np.random.randn(*x.shape) * 0.1
35 | return x
36 |
37 | elif data == 'checkerboard':
38 | x1 = np.random.rand(batch_size) * 4 - 2
39 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2
40 | x2 = x2_ + (np.floor(x1) % 2)
41 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2
42 |
43 | else:
44 |         raise RuntimeError(f'Unknown synthetic dataset: {data}')
45 |
46 |
47 |
48 | def energy2d(data, z):
49 |
50 | if data == 't1':
51 | return U1(z)
52 | elif data == 't2':
53 | return U2(z)
54 | elif data == 't3':
55 | return U3(z)
56 | elif data == 't4':
57 | return U4(z)
58 | else:
59 |         raise RuntimeError(f'Unknown energy: {data}')
60 |
61 | def w1(z):
62 | return torch.sin(2 * np.pi * z[:, 0] / 4)
63 |
64 | def w2(z):
65 | return 3 * torch.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)
66 |
67 | def w3(z):
68 | return 3 * torch.sigmoid((z[:, 0] - 1) / 0.3)
69 |
70 | def U1(z):
71 | z_norm = torch.norm(z, 2, 1)
72 | add1 = 0.5 * ((z_norm - 2) / 0.4) ** 2
73 | add2 = - torch.log(torch.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) +\
74 | torch.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) + 1e-9)
75 |
76 | return add1 + add2
77 |
78 | def U2(z):
79 | return 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2
80 |
81 | def U3(z):
82 | in1 = torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.35) ** 2)
83 | in2 = torch.exp(-0.5 * ((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2)
84 | return -torch.log(in1 + in2 + 1e-9)
85 |
86 | def U4(z):
87 | in1 = torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2)
88 | in2 = torch.exp(-0.5 * ((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2)
89 | return -torch.log(in1 + in2 + 1e-9)
--------------------------------------------------------------------------------
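A usage sketch for the toy densities above (illustrative, not part of the repository): sample2d returns a numpy batch, while energy2d evaluates the chosen potential on a torch tensor, e.g. on a grid for visualization. Grid bounds and resolution are arbitrary choices.

import torch
from loaders.synthetic import sample2d, energy2d

batch = sample2d('8gaussians', batch_size=512)      # (512, 2) float32 numpy array
xs = torch.linspace(-4, 4, 100)
grid = torch.stack(torch.meshgrid(xs, xs), dim=-1).reshape(-1, 2)
energy = energy2d('t1', grid)                       # (10000,) values of U1 on the grid
density = torch.exp(-energy)                        # unnormalized density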
/metrics.py:
--------------------------------------------------------------------------------
1 | # Adapted from score written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 | from sklearn.metrics import roc_auc_score
6 |
7 |
8 | class runningScore(object):
9 | def __init__(self, n_classes):
10 | self.n_classes = n_classes
11 | self.confusion_matrix = np.zeros((n_classes, n_classes))
12 |
13 | def _fast_hist(self, label_true, label_pred, n_class):
14 | mask = (label_true >= 0) & (label_true < n_class)
15 | hist = np.bincount(
16 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
17 | ).reshape(n_class, n_class)
18 | return hist
19 |
20 | def update(self, label_trues, label_preds):
21 | for lt, lp in zip(label_trues, label_preds):
22 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
23 |
24 | def get_scores(self):
25 | """Returns accuracy score evaluation result.
26 | - overall accuracy
27 | - mean accuracy
28 | - mean IU
29 | - fwavacc
30 | """
31 | hist = self.confusion_matrix
32 | acc = np.diag(hist).sum() / hist.sum()
33 | acc_cls = np.diag(hist) / hist.sum(axis=1)
34 | acc_cls = np.nanmean(acc_cls)
35 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
36 | mean_iu = np.nanmean(iu)
37 | freq = hist.sum(axis=1) / hist.sum()
38 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
39 | cls_iu = dict(zip(range(self.n_classes), iu))
40 |
41 | return (
42 | {
43 | "Overall Acc: \t": acc,
44 | "Mean Acc : \t": acc_cls,
45 | "FreqW Acc : \t": fwavacc,
46 | "Mean IoU : \t": mean_iu,
47 | },
48 | cls_iu
49 | )
50 |
51 | def reset(self):
52 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
53 |
54 |
55 | class runningScore_cls(object):
56 | def __init__(self, n_classes):
57 | self.n_classes = n_classes
58 | self.confusion_matrix = np.zeros((n_classes, n_classes))
59 |
60 | def _fast_hist(self, label_true, label_pred, n_class):
61 | mask = (label_true >= 0) & (label_true < n_class)
62 | hist = np.bincount(
63 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
64 | ).reshape(n_class, n_class)
65 | return hist
66 |
67 | def update(self, label_trues, label_preds):
68 | for lt, lp in zip(label_trues, label_preds):
69 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
70 |
71 | def get_scores(self):
72 | """Returns accuracy score evaluation result.
73 | - overall accuracy
74 | - mean accuracy
75 | - mean IU
76 | - fwavacc
77 | """
78 | hist = self.confusion_matrix
79 | acc = np.diag(hist).sum() / hist.sum()
80 | acc_cls = np.diag(hist) / hist.sum(axis=1)
81 | acc_cls = np.nanmean(acc_cls)
82 |
83 |
84 | return {"Overall Acc : \t": acc,
85 | "Mean Acc : \t": acc_cls,}
86 |
87 |
88 | def reset(self):
89 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
90 |
91 |
92 | class AUC:
93 | def __init__(self):
94 | self.l_label = []
95 | self.l_pred = []
96 |
97 | def update(self, label_trues, label_preds):
98 | for lt, lp in zip(label_trues, label_preds):
99 | self.l_label.append(lt)
100 | self.l_pred.append(lp)
101 |
102 | def reset(self):
103 | self.l_label = []
104 | self.l_pred = []
105 |
106 | def get_scores(self):
107 | auc = roc_auc_score(self.l_label, self.l_pred)
108 | return {'AUROC: \t': auc}
109 |
110 |
111 | class averageMeter(object):
112 | """Computes and stores the average and current value"""
113 |
114 | def __init__(self):
115 | self.reset()
116 |
117 | def reset(self):
118 | self.val = 0
119 | self.avg = 0
120 | self.sum = 0
121 | self.count = 0
122 |
123 | def update(self, val, n=1):
124 | self.val = val
125 | self.sum += val * n
126 | self.count += n
127 | self.avg = self.sum / self.count
128 |
--------------------------------------------------------------------------------
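A usage sketch for the OOD metrics above (illustrative): AUC accumulates per-sample scores and binary labels across batches and computes AUROC once at the end, while averageMeter keeps a running average of a scalar (e.g. a loss) weighted by batch size. The scores below are synthetic stand-ins.

import numpy as np
from metrics import AUC, averageMeter

rng = np.random.default_rng(0)
auc = AUC()
auc.update(np.zeros(100), rng.normal(0.0, 1.0, 100))  # label 0: in-distribution scores
auc.update(np.ones(100), rng.normal(2.0, 1.0, 100))   # label 1: OOD scores (higher energy)
print(auc.get_scores())                               # {'AUROC: \t': ...}

meter = averageMeter()
meter.update(0.31, n=128)                             # one batch of 128 samples
meter.update(0.27, n=128)
print(meter.avg)                                      # sample-weighted mean loss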
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import OmegaConf
3 | import copy
4 | import torch
5 | import torchvision.models as models
6 | import torchvision.transforms as transforms
7 | from augmentations import get_composed_augmentations
8 |
9 | from models.ae import (
10 | AE,
11 | VAE,
12 | DAE,
13 | WAE,
14 | )
15 | from models.nae import NAE, NAE_L2_OMI
16 | from models.mcmc import get_sampler
17 | from models.modules import (
18 | DeConvNet2,
19 | FCNet,
20 | ConvNet2FC,
21 | ConvMLP,
22 | IGEBMEncoder,
23 | ConvNet64,
24 | DeConvNet64,
25 | )
26 | from models.modules_sngan import Generator as SNGANGeneratorBN
27 | from models.modules_sngan import GeneratorNoBN as SNGANGeneratorNoBN
28 | from models.modules_sngan import GeneratorNoBN64 as SNGANGeneratorNoBN64
29 | from models.modules_sngan import GeneratorGN as SNGANGeneratorGN
30 | from models.energybased import EnergyBasedModel
31 |
32 |
33 |
34 | def get_net(in_dim, out_dim, **kwargs):
35 | nh = kwargs.get("nh", 8)
36 | out_activation = kwargs.get("out_activation", "linear")
37 |
38 | if kwargs["arch"] == "conv2fc":
39 | nh_mlp = kwargs["nh_mlp"]
40 | net = ConvNet2FC(
41 | in_chan=in_dim,
42 | out_chan=out_dim,
43 | nh=nh,
44 | nh_mlp=nh_mlp,
45 | out_activation=out_activation,
46 | )
47 |
48 | elif kwargs["arch"] == "deconv2":
49 | net = DeConvNet2(
50 | in_chan=in_dim, out_chan=out_dim, nh=nh, out_activation=out_activation
51 | )
52 | elif kwargs["arch"] == "conv64":
53 | num_groups = kwargs.get("num_groups", None)
54 | use_bn = kwargs.get("use_bn", False)
55 | net = ConvNet64(
56 | in_chan=in_dim,
57 | out_chan=out_dim,
58 | nh=nh,
59 | out_activation=out_activation,
60 | num_groups=num_groups,
61 | use_bn=use_bn,
62 | )
63 | elif kwargs["arch"] == "deconv64":
64 | num_groups = kwargs.get("num_groups", None)
65 | use_bn = kwargs.get("use_bn", False)
66 | net = DeConvNet64(
67 | in_chan=in_dim,
68 | out_chan=out_dim,
69 | nh=nh,
70 | out_activation=out_activation,
71 | num_groups=num_groups,
72 | use_bn=use_bn,
73 | )
74 | elif kwargs["arch"] == "fc":
75 | l_hidden = kwargs["l_hidden"]
76 | activation = kwargs["activation"]
77 | net = FCNet(
78 | in_dim=in_dim,
79 | out_dim=out_dim,
80 | l_hidden=l_hidden,
81 | activation=activation,
82 | out_activation=out_activation,
83 | )
84 | elif kwargs["arch"] == "convmlp":
85 | l_hidden = kwargs["l_hidden"]
86 | activation = kwargs["activation"]
87 | net = ConvMLP(
88 | in_dim=in_dim,
89 | out_dim=out_dim,
90 | l_hidden=l_hidden,
91 | activation=activation,
92 | out_activation=out_activation,
93 | )
94 | elif kwargs["arch"] == "IGEBMEncoder":
95 | use_spectral_norm = kwargs.get("use_spectral_norm", False)
96 | keepdim = kwargs.get("keepdim", True)
97 | net = IGEBMEncoder(
98 | in_chan=in_dim,
99 | out_chan=out_dim,
100 | n_class=None,
101 | use_spectral_norm=use_spectral_norm,
102 | keepdim=keepdim,
103 | )
104 | elif kwargs["arch"] == "sngan_generator_bn":
105 | hidden_dim = kwargs.get("hidden_dim", 128)
106 | out_activation = kwargs["out_activation"]
107 | net = SNGANGeneratorBN(
108 | z_dim=in_dim,
109 | channels=out_dim,
110 | hidden_dim=hidden_dim,
111 | out_activation=out_activation,
112 | )
113 | elif kwargs["arch"] == "sngan_generator_nobn":
114 | hidden_dim = kwargs.get("hidden_dim", 128)
115 | out_activation = kwargs["out_activation"]
116 | net = SNGANGeneratorNoBN(
117 | z_dim=in_dim,
118 | channels=out_dim,
119 | hidden_dim=hidden_dim,
120 | out_activation=out_activation,
121 | )
122 | elif kwargs["arch"] == "sngan_generator_nobn64":
123 | hidden_dim = kwargs.get("hidden_dim", 128)
124 | out_activation = kwargs["out_activation"]
125 | net = SNGANGeneratorNoBN64(
126 | z_dim=in_dim,
127 | channels=out_dim,
128 | hidden_dim=hidden_dim,
129 | out_activation=out_activation,
130 | )
131 | elif kwargs["arch"] == "sngan_generator_gn":
132 | hidden_dim = kwargs.get("hidden_dim", 128)
133 | out_activation = kwargs["out_activation"]
134 | num_groups = kwargs["num_groups"]
135 | net = SNGANGeneratorGN(
136 | z_dim=in_dim,
137 | channels=out_dim,
138 | hidden_dim=hidden_dim,
139 | out_activation=out_activation,
140 | num_groups=num_groups,
141 | )
142 |     else:
143 |         raise ValueError(f"Unknown architecture: {kwargs['arch']}")
144 |     return net
145 |
146 | def get_ae(**model_cfg):
147 | arch = model_cfg.pop('arch')
148 | x_dim = model_cfg.pop("x_dim")
149 | z_dim = model_cfg.pop("z_dim")
150 | enc_cfg = model_cfg.pop('encoder')
151 | dec_cfg = model_cfg.pop('decoder')
152 |
153 | if arch == "ae":
154 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **enc_cfg)
155 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg)
156 | ae = AE(encoder, decoder)
157 | elif arch == "dae":
158 | sig = model_cfg["sig"]
159 | noise_type = model_cfg["noise_type"]
160 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **enc_cfg)
161 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg)
162 | ae = DAE(encoder, decoder, sig=sig, noise_type=noise_type)
163 | elif arch == "wae":
164 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **enc_cfg)
165 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg)
166 | ae = WAE(encoder, decoder, **model_cfg)
167 | elif arch == "vae":
168 |         # sigma_trainable is forwarded to VAE through **model_cfg below
169 | encoder = get_net(in_dim=x_dim, out_dim=z_dim * 2, **enc_cfg)
170 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **dec_cfg)
171 | ae = VAE(encoder, decoder, **model_cfg)
172 | return ae
173 |
174 |
175 |
176 | def get_vae(**model_cfg):
177 | x_dim = model_cfg["x_dim"]
178 | z_dim = model_cfg["z_dim"]
179 | encoder_out_dim = z_dim * 2
180 |
181 | encoder = get_net(in_dim=x_dim, out_dim=encoder_out_dim, **model_cfg["encoder"])
182 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg["decoder"])
183 | n_sample = model_cfg.get("n_sample", 1)
184 | pred_method = model_cfg.get("pred_method", "recon")
185 |
186 | if model_cfg["arch"] == "vae":
187 | ae = VAE(encoder, decoder, n_sample=n_sample, pred_method=pred_method)
188 | return ae
189 |
190 |
191 | def get_ebm(**model_cfg):
192 | model_cfg = copy.deepcopy(model_cfg)
193 | if "arch" in model_cfg:
194 | model_cfg.pop("arch")
195 | in_dim = model_cfg["x_dim"]
196 | model_cfg.pop("x_dim")
197 | net = get_net(in_dim=in_dim, out_dim=1, **model_cfg["net"])
198 | model_cfg.pop("net")
199 | return EnergyBasedModel(net, **model_cfg)
200 |
201 |
202 | def get_nae(**model_cfg):
203 | arch = model_cfg.pop("arch")
204 | x_dim = model_cfg["x_dim"]
205 | z_dim = model_cfg["z_dim"]
206 |
207 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **model_cfg["encoder"])
208 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg["decoder"])
209 |
210 | if arch == "nae":
211 | ae = NAE(encoder, decoder, **model_cfg["nae"])
212 | else:
213 | raise ValueError(f"{arch}")
214 | return ae
215 |
216 |
217 | def get_nae_v2(**model_cfg):
218 | arch = model_cfg.pop('arch')
219 | sampling = model_cfg.pop('sampling')
220 | x_dim = model_cfg['x_dim']
221 | z_dim = model_cfg['z_dim']
222 |
223 | encoder = get_net(in_dim=x_dim, out_dim=z_dim, **model_cfg["encoder"])
224 | decoder = get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg["decoder"])
225 | if arch == 'nae_l2' and sampling == 'omi':
226 | sampler_z = get_sampler(**model_cfg['sampler_z'])
227 | sampler_x = get_sampler(**model_cfg['sampler_x'])
228 | nae = NAE_L2_OMI(encoder, decoder, sampler_z, sampler_x, **model_cfg['nae'])
229 | else:
230 | raise ValueError(f'Invalid sampling: {sampling}')
231 | return nae
232 |
233 |
234 | def get_model(cfg, *args, version=None, **kwargs):
235 | # cfg can be a whole config dictionary or a value of a key 'model' in the config dictionary (cfg['model']).
236 | if "model" in cfg:
237 | model_dict = cfg["model"]
238 | elif "arch" in cfg:
239 | model_dict = cfg
240 | else:
241 | raise ValueError(f"Invalid model configuration dictionary: {cfg}")
242 | name = model_dict["arch"]
243 | model = _get_model_instance(name)
244 | model = model(**model_dict)
245 | return model
246 |
247 |
248 | def _get_model_instance(name):
249 | try:
250 | return {
251 | "ae": get_ae,
252 | "nae": get_nae,
253 | "nae_l2": get_nae_v2,
254 | }[name]
255 |     except KeyError:
256 |         raise ValueError("Model {} not available".format(name))
257 |
258 |
259 | def load_pretrained(identifier, config_file, ckpt_file, root='pretrained', **kwargs):
260 | """
261 | load pre-trained model.
262 |     identifier: path of the model directory under `root`, e.g. 'ae_mnist/z16'
263 | config_file: name of a config file. e.g. 'ae.yml'
264 | ckpt_file: name of a model checkpoint file. e.g. 'model_best.pth'
265 | root: path to pretrained directory
266 | """
267 | config_path = os.path.join(root, identifier, config_file)
268 | ckpt_path = os.path.join(root, identifier, ckpt_file)
269 | cfg = OmegaConf.load(config_path)
270 | model_name = cfg['model']['arch']
271 |
272 | model = get_model(cfg)
273 | ckpt = torch.load(ckpt_path, map_location='cpu')
274 | if 'model_state' in ckpt:
275 | ckpt = ckpt['model_state']
276 | model.load_state_dict(ckpt)
277 | model.eval()
278 | return model, cfg
279 |
280 |
--------------------------------------------------------------------------------
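A usage sketch for the factory functions above (illustrative): the identifier and file names follow the docstring's examples and assume a checkpoint actually exists under pretrained/; the input shape is a dummy assumption.

import torch
from models import load_pretrained

model, cfg = load_pretrained('ae_mnist/z16', 'ae.yml', 'model_best.pth', root='pretrained')
x = torch.rand(8, 1, 28, 28)  # dummy MNIST-shaped batch (an assumption for the example)
score = model.predict(x)      # per-sample anomaly score, shape (8,)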
/models/ae.py:
--------------------------------------------------------------------------------
1 | """
2 | ae.py
3 | =====
4 | Autoencoders
5 | """
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch import optim
11 | from torchvision.utils import make_grid
12 | from models.modules import IsotropicGaussian, IsotropicLaplace
13 |
14 |
15 | class AE(nn.Module):
16 | """autoencoder"""
17 | def __init__(self, encoder, decoder):
18 | """
19 | encoder, decoder : neural networks
20 | """
21 | super(AE, self).__init__()
22 | self.encoder = encoder
23 | self.decoder = decoder
24 | self.own_optimizer = False
25 |
26 | def forward(self, x):
27 | z = self.encode(x)
28 | recon = self.decoder(z)
29 | return recon
30 |
31 | def encode(self, x):
32 | z = self.encoder(x)
33 | return z
34 |
35 | def predict(self, x):
36 | """one-class anomaly prediction"""
37 | recon = self(x)
38 | if hasattr(self.decoder, 'error'):
39 | predict = self.decoder.error(x, recon)
40 | else:
41 | predict = ((recon - x) ** 2).view(len(x), -1).mean(dim=1)
42 | return predict
43 |
44 | def predict_and_reconstruct(self, x):
45 | recon = self(x)
46 | if hasattr(self.decoder, 'error'):
47 | recon_err = self.decoder.error(x, recon)
48 | else:
49 | recon_err = ((recon - x) ** 2).view(len(x), -1).mean(dim=1)
50 | return recon_err, recon
51 |
52 | def validation_step(self, x, **kwargs):
53 | recon = self(x)
54 | if hasattr(self.decoder, 'error'):
55 | predict = self.decoder.error(x, recon)
56 | else:
57 | predict = ((recon - x) ** 2).view(len(x), -1).mean(dim=1)
58 | loss = predict.mean()
59 |
60 | if kwargs.get('show_image', True):
61 | x_img = make_grid(x.detach().cpu(), nrow=10, range=(0, 1))
62 | recon_img = make_grid(recon.detach().cpu(), nrow=10, range=(0, 1))
63 | else:
64 | x_img, recon_img = None, None
65 | return {'loss': loss.item(), 'predict': predict, 'reconstruction': recon,
66 | 'input@': x_img, 'recon@': recon_img}
67 |
68 | def train_step(self, x, optimizer, clip_grad=None, **kwargs):
69 | optimizer.zero_grad()
70 | recon_error = self.predict(x)
71 | loss = recon_error.mean()
72 | loss.backward()
73 | if clip_grad is not None:
74 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=clip_grad)
75 | optimizer.step()
76 | return {'loss': loss.item()}
77 |
78 | def reconstruct(self, x):
79 | return self(x)
80 |
81 | def sample(self, N, z_shape=None, device='cpu'):
82 | if z_shape is None:
83 | z_shape = self.encoder.out_shape
84 |
85 | rand_z = torch.rand(N, *z_shape).to(device) * 2 - 1
86 | sample_x = self.decoder(rand_z)
87 | return sample_x
88 |
89 |
90 |
91 | def clip_vector_norm(x, max_norm):  # rescale rows whose norm exceeds max_norm down to max_norm
92 | norm = x.norm(dim=-1, keepdim=True)
93 | x = x * ((norm < max_norm).to(torch.float) + (norm > max_norm).to(torch.float) * max_norm/norm + 1e-6)
94 | return x
95 |
96 |
97 | class DAE(AE):
98 | """denoising autoencoder"""
99 | def __init__(self, encoder, decoder, sig=0.0, noise_type='gaussian'):
100 | super(DAE, self).__init__(encoder, decoder)
101 | self.sig = sig
102 | self.noise_type = noise_type
103 |
104 | def train_step(self, x, optimizer, y=None):
105 | optimizer.zero_grad()
106 | if self.noise_type == 'gaussian':
107 | noise = torch.randn(*x.shape, dtype=torch.float32)
108 | noise = noise.to(x.device)
109 | recon = self(x + self.sig * noise)
110 | elif self.noise_type == 'saltnpepper':
111 | x = self.salt_and_pepper(x)
112 | recon = self(x)
113 | else:
114 | raise ValueError(f'Invalid noise_type: {self.noise_type}')
115 |
116 | loss = torch.mean((recon - x) ** 2)
117 | loss.backward()
118 | optimizer.step()
119 | return {'loss': loss.item()}
120 |
121 | def salt_and_pepper(self, img):
122 | """salt and pepper noise for mnist"""
123 |         # for salt and pepper noise, sig is the probability of occurrence of noise pixels.
124 |         img = img.clone()
125 |         prob = self.sig
126 |         rnd = torch.rand(*img.shape).to(img.device)
127 | img[rnd < prob / 2] = 0.
128 | img[rnd > 1 - prob / 2] = 1.
129 | return img
130 |
131 |
132 | class VAE(AE):
133 | def __init__(self, encoder, decoder, n_sample=1, use_mean=False, pred_method='recon', sigma_trainable=False):
134 | super(VAE, self).__init__(encoder, IsotropicGaussian(decoder, sigma=1, sigma_trainable=sigma_trainable))
135 | self.n_sample = n_sample # the number of samples to generate for anomaly detection
136 | self.use_mean = use_mean # if True, does not sample from posterior distribution
137 | self.pred_method = pred_method # which anomaly score to use
138 | self.z_shape = None
139 |
140 | def forward(self, x):
141 | z = self.encoder(x)
142 | z_sample = self.sample_latent(z)
143 | return self.decoder(z_sample)
144 |
145 | def sample_latent(self, z):
146 | half_chan = int(z.shape[1] / 2)
147 | mu, log_sig = z[:, :half_chan], z[:, half_chan:]
148 | if self.use_mean:
149 | return mu
150 | eps = torch.randn(*mu.shape, dtype=torch.float32)
151 | eps = eps.to(z.device)
152 | return mu + torch.exp(log_sig) * eps
153 |
154 | # def sample_marginal_latent(self, z_shape):
155 | # return torch.randn(z_shape)
156 |
157 | def kl_loss(self, z):
158 | """analytic (positive) KL divergence between gaussians
159 | KL(q(z|x) | p(z))"""
160 | half_chan = int(z.shape[1] / 2)
161 | mu, log_sig = z[:, :half_chan], z[:, half_chan:]
162 | mu_sq = mu ** 2
163 | sig_sq = torch.exp(log_sig) ** 2
164 | kl = mu_sq + sig_sq - torch.log(sig_sq) - 1
165 | # return 0.5 * torch.mean(kl.view(len(kl), -1), dim=1)
166 | return 0.5 * torch.sum(kl.view(len(kl), -1), dim=1)
167 |
168 | def train_step(self, x, optimizer, y=None, clip_grad=None):
169 | optimizer.zero_grad()
170 | z = self.encoder(x)
171 | z_sample = self.sample_latent(z)
172 | nll = - self.decoder.log_likelihood(x, z_sample)
173 |
174 | kl_loss = self.kl_loss(z)
175 | loss = nll + kl_loss
176 | loss = loss.mean()
177 | nll = nll.mean()
178 |
179 | loss.backward()
180 | optimizer.step()
181 |         return {'loss': nll.item(), 'vae/kl_loss_': kl_loss.mean().item(), 'vae/sigma_': self.decoder.sigma.item()}
182 |
183 | def predict(self, x):
184 | """one-class anomaly prediction using the metric specified by self.anomaly_score"""
185 | if self.pred_method == 'recon':
186 | return self.reconstruction_probability(x)
187 | elif self.pred_method == 'lik':
188 | return - self.marginal_likelihood(x) # negative log likelihood
189 | else:
190 | raise ValueError(f'{self.pred_method} should be recon or lik')
191 |
192 | def validation_step(self, x, y=None, **kwargs):
193 | z = self.encoder(x)
194 | z_sample = self.sample_latent(z)
195 | recon = self.decoder(z_sample)
196 | loss = torch.mean((recon - x) ** 2)
197 | predict = - self.decoder.log_likelihood(x, z_sample)
198 |
199 | if kwargs.get('show_image', True):
200 | x_img = make_grid(x.detach().cpu(), nrow=10, range=(0, 1))
201 | recon_img = make_grid(recon.detach().cpu(), nrow=10, range=(0, 1))
202 | else:
203 | x_img, recon_img = None, None
204 |
205 | return {'loss': loss.item(), 'predict': predict, 'reconstruction': recon,
206 | 'input@': x_img, 'recon@': recon_img}
207 |
208 | def reconstruction_probability(self, x):
209 | l_score = []
210 | z = self.encoder(x)
211 | for i in range(self.n_sample):
212 | z_sample = self.sample_latent(z)
213 | recon_loss = - self.decoder.log_likelihood(x, z_sample)
214 | score = recon_loss
215 | l_score.append(score)
216 | return torch.stack(l_score).mean(dim=0)
217 |
218 | def marginal_likelihood(self, x, n_sample=None):
219 | """marginal likelihood from importance sampling
220 | log P(X) = log \int P(X|Z) * P(Z)/Q(Z|X) * Q(Z|X) dZ"""
221 | if n_sample is None:
222 | n_sample = self.n_sample
223 |
224 |         # encode once; the posterior parameters are reused for every importance sample
225 | with torch.no_grad():
226 | z = self.encoder(x)
227 |
228 | l_score = []
229 | for i in range(n_sample):
230 | z_sample = self.sample_latent(z)
231 | log_recon = self.decoder.log_likelihood(x, z_sample)
232 | log_prior = self.log_prior(z_sample)
233 | log_posterior = self.log_posterior(z, z_sample)
234 | l_score.append(log_recon + log_prior - log_posterior)
235 | score = torch.stack(l_score)
236 | logN = torch.log(torch.tensor(n_sample, dtype=torch.float, device=x.device))
237 | return torch.logsumexp(score, dim=0) - logN
238 |
239 | def marginal_likelihood_naive(self, x, n_sample=None):
240 | if n_sample is None:
241 | n_sample = self.n_sample
242 |
243 |         # a zero mu/log_sig vector makes sample_latent draw from the standard normal prior
244 | z_dummy = self.encoder(x[[0]])
245 | z = torch.zeros(len(x), *list(z_dummy.shape[1:]), dtype=torch.float).to(x.device)
246 |
247 | l_score = []
248 | for i in range(n_sample):
249 | z_sample = self.sample_latent(z)
250 | recon_loss = - self.decoder.log_likelihood(x, z_sample)
251 | score = recon_loss
252 | l_score.append(score)
253 | score = torch.stack(l_score)
254 | return - torch.logsumexp(-score, dim=0)
255 |
256 | def elbo(self, x):
257 | l_score = []
258 | z = self.encoder(x)
259 | for i in range(self.n_sample):
260 | z_sample = self.sample_latent(z)
261 | recon_loss = - self.decoder.log_likelihood(x, z_sample)
262 | kl_loss = self.kl_loss(z)
263 | score = recon_loss + kl_loss
264 | l_score.append(score)
265 | return torch.stack(l_score).mean(dim=0)
266 |
267 | def log_posterior(self, z, z_sample):
268 | half_chan = int(z.shape[1] / 2)
269 | mu, log_sig = z[:, :half_chan], z[:, half_chan:]
270 |
271 | log_p = torch.distributions.Normal(mu, torch.exp(log_sig)).log_prob(z_sample)
272 | log_p = log_p.view(len(z), -1).sum(-1)
273 | return log_p
274 |
275 | def log_prior(self, z_sample):
276 | log_p = torch.distributions.Normal(torch.zeros_like(z_sample), torch.ones_like(z_sample)).log_prob(z_sample)
277 | log_p = log_p.view(len(z_sample), -1).sum(-1)
278 | return log_p
279 |
280 | def posterior_entropy(self, z):
281 | half_chan = int(z.shape[1] / 2)
282 | mu, log_sig = z[:, :half_chan], z[:, half_chan:]
283 | D = mu.shape[1]
284 | pi = torch.tensor(np.pi, dtype=torch.float32).to(z.device)
285 | term1 = D / 2
286 | term2 = D / 2 * torch.log(2 * pi)
287 | term3 = log_sig.view(len(log_sig), -1).sum(dim=-1)
288 | return term1 + term2 + term3
289 |
290 | def _set_z_shape(self, x):
291 | if self.z_shape is not None:
292 | return
293 | with torch.no_grad():
294 | dummy_z = self.encode(x[[0]])
295 | dummy_z = self.sample_latent(dummy_z)
296 | z_shape = dummy_z.shape
297 | self.z_shape = z_shape[1:]
298 |
299 | def sample_z(self, n_sample, device):
300 | z_shape = (n_sample,) + self.z_shape
301 | return torch.randn(z_shape, device=device, dtype=torch.float)
302 |
303 | def sample(self, n_sample, device):
304 | z = self.sample_z(n_sample, device)
305 | return {'sample_x': self.decoder.sample(z)}
306 |
307 |
308 | class WAE(AE):
309 | """Wassertstein Autoencoder with MMD loss"""
310 | def __init__(self, encoder, decoder, reg=1., bandwidth='median', prior='gaussian'):
311 | super().__init__(encoder, decoder)
312 | if not isinstance(bandwidth, str):
313 | bandwidth = float(bandwidth)
314 | self.bandwidth = bandwidth
315 | self.reg = reg # coefficient for MMD loss
316 | self.prior = prior
317 |
318 | def train_step(self, x, optimizer, y=None, **kwargs):
319 | optimizer.zero_grad()
320 | # forward step
321 | z = self.encoder(x)
322 | recon = self.decoder(z)
323 | # recon_loss = torch.mean(self.decoder.square_error(x, recon))
324 | recon_loss = torch.mean((x - recon) ** 2)
325 |
326 | # MMD step
327 | z_prior = self.sample_prior(z)
328 | mmd_loss = self.mmd(z_prior, z)
329 |
330 | loss = recon_loss + mmd_loss * self.reg
331 | loss.backward()
332 | optimizer.step()
333 | return {'loss': loss.item(), 'recon_loss': recon_loss, 'mmd_loss': mmd_loss.item()}
334 |
335 | def sample_prior(self, z):
336 | if self.prior == 'gaussian':
337 | return torch.randn_like(z)
338 | elif self.prior == 'uniform_tanh':
339 | return torch.rand_like(z) * 2 - 1
340 | else:
341 | raise ValueError(f'invalid prior {self.prior}')
342 |
343 | def mmd(self, X1, X2):
344 | if len(X1.shape) == 4:
345 | X1 = X1.view(len(X1), -1)
346 | if len(X2.shape) == 4:
347 | X2 = X2.view(len(X2), -1)
348 |
349 | N1 = len(X1)
350 | X1_sq = X1.pow(2).sum(1).unsqueeze(0)
351 | X1_cr = torch.mm(X1, X1.t())
352 | X1_dist = X1_sq + X1_sq.t() - 2 * X1_cr
353 |
354 | N2 = len(X2)
355 | X2_sq = X2.pow(2).sum(1).unsqueeze(0)
356 | X2_cr = torch.mm(X2, X2.t())
357 | X2_dist = X2_sq + X2_sq.t() - 2 * X2_cr
358 |
359 | X12 = torch.mm(X1, X2.t())
360 | X12_dist = X1_sq.t() + X2_sq - 2 * X12
361 |
362 | # median heuristic to select bandwidth
363 | if self.bandwidth == 'median':
364 | X1_triu = X1_dist[torch.triu(torch.ones_like(X1_dist), diagonal=1) == 1]
365 | bandwidth1 = torch.median(X1_triu)
366 | X2_triu = X2_dist[torch.triu(torch.ones_like(X2_dist), diagonal=1) == 1]
367 | bandwidth2 = torch.median(X2_triu)
368 | bandwidth_sq = ((bandwidth1 + bandwidth2) / 2).detach()
369 | else:
370 | bandwidth_sq = (self.bandwidth ** 2)
371 |
372 | C = - 0.5 / bandwidth_sq
373 | K11 = torch.exp(C * X1_dist)
374 | K22 = torch.exp(C * X2_dist)
375 | K12 = torch.exp(C * X12_dist)
376 | K11 = (1 - torch.eye(N1).to(X1.device)) * K11
377 | K22 = (1 - torch.eye(N2).to(X1.device)) * K22
378 | mmd = K11.sum() / N1 / (N1 - 1) + K22.sum() / N2 / (N2 - 1) - 2 * K12.mean()
379 | return mmd
380 |
381 |
382 |
383 |
--------------------------------------------------------------------------------
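Every autoencoder above exposes the same train_step(x, optimizer) / validation_step(x) contract that the trainers rely on. A minimal training sketch with a fully-connected AE (illustrative; the FCNet argument names follow the factory in models/__init__.py, and the data is random):

import torch
from models.ae import AE
from models.modules import FCNet

encoder = FCNet(in_dim=2, out_dim=1, l_hidden=(64,), activation='relu', out_activation='linear')
decoder = FCNet(in_dim=1, out_dim=2, l_hidden=(64,), activation='relu', out_activation='linear')
model = AE(encoder, decoder)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(100):
    x = torch.randn(128, 2)             # stand-in for a real data batch
    d_result = model.train_step(x, opt)
print(d_result['loss'])                 # mean reconstruction error of the last batch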
/models/energybased.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from models.langevin import sample_langevin
7 |
8 |
9 | class EnergyBasedModel(nn.Module):
10 | def __init__(self, net, alpha=1, step_size=10, sample_step=60,
11 | noise_std=0.005, buffer_size=10000, replay_ratio=0.95,
12 | langevin_clip_grad=0.01, clip_x=(0, 1)):
13 | super().__init__()
14 | self.net = net
15 | self.alpha = alpha
16 | self.step_size = step_size
17 | self.sample_step = sample_step
18 | self.noise_std = noise_std
19 | self.buffer_size = buffer_size
20 | self.buffer = SampleBuffer(max_samples=buffer_size, replay_ratio=replay_ratio)
21 | self.replay_ratio = replay_ratio
22 | self.replay = True if self.replay_ratio > 0 else False
23 | self.langevin_clip_grad = langevin_clip_grad
24 | self.clip_x = clip_x
25 |
26 | self.own_optimizer = False
27 |
28 | def forward(self, x):
29 | return self.net(x).view(-1)
30 |
31 | def predict(self, x):
32 | return self(x)
33 |
34 | def validation_step(self, x, y=None):
35 | with torch.no_grad():
36 | pos_e = self(x)
37 |
38 | return {'loss': pos_e.mean(),
39 | 'predict': pos_e,
40 | }
41 |
42 | def train_step(self, x, optimizer, clip_grad=None, y=None):
43 | neg_x = self.sample(shape=x.shape, device=x.device, replay=self.replay)
44 | optimizer.zero_grad()
45 | pos_e = self(x)
46 | neg_e = self(neg_x)
47 |
48 | ebm_loss = pos_e.mean() - neg_e.mean()
49 | reg_loss = (pos_e ** 2).mean() + (neg_e ** 2).mean()
50 | weight_norm = sum([(w ** 2).sum() for w in self.net.parameters()])
51 | loss = ebm_loss + self.alpha * reg_loss # + self.beta * weight_norm
52 | loss.backward()
53 |
54 | if clip_grad is not None:
55 | # torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=clip_grad)
56 | clip_grad_ebm(self.parameters(), optimizer)
57 |
58 | optimizer.step()
59 | return {'loss': loss.item(),
60 | 'ebm_loss': ebm_loss.item(), 'pos_e': pos_e.mean().item(), 'neg_e': neg_e.mean().item(),
61 | 'reg_loss': reg_loss.item(), 'neg_sample': neg_x.detach().cpu(),
62 | 'weight_norm': weight_norm.item()}
63 |
64 | def sample(self, shape, device, replay=True, intermediate=False, step_size=None,
65 | sample_step=None):
66 | if step_size is None:
67 | step_size = self.step_size
68 | if sample_step is None:
69 | sample_step = self.sample_step
70 | # initialize
71 | x0 = self.buffer.sample(shape, device, replay=replay)
72 | # run langevin
73 | sample_x = sample_langevin(x0, self, step_size, sample_step,
74 | noise_scale=self.noise_std,
75 | intermediate_samples=intermediate,
76 | clip_x=self.clip_x,
77 | clip_grad=self.langevin_clip_grad,
78 | )
79 | # push samples
80 | if replay:
81 | self.buffer.push(sample_x)
82 | return sample_x
83 |
84 |
85 | class SampleBuffer:
86 | def __init__(self, max_samples=10000, replay_ratio=0.95, bound=None):
87 | self.max_samples = max_samples
88 | self.buffer = []
89 | self.replay_ratio = replay_ratio
90 | if bound is None:
91 | self.bound = (0, 1)
92 | else:
93 | self.bound = bound
94 |
95 | def __len__(self):
96 | return len(self.buffer)
97 |
98 | def push(self, samples):
99 | samples = samples.detach().to('cpu')
100 |
101 | for sample in samples:
102 | self.buffer.append(sample)
103 |
104 | if len(self.buffer) > self.max_samples:
105 | self.buffer.pop(0)
106 |
107 | def get(self, n_samples):
108 | samples = random.choices(self.buffer, k=n_samples)
109 | samples = torch.stack(samples, 0)
110 | return samples
111 |
112 | def sample(self, shape, device, replay=False):
113 | if len(self.buffer) < 1 or not replay: # empty buffer
114 | return self.random(shape, device)
115 |
116 | n_replay = (np.random.rand(shape[0]) < self.replay_ratio).sum()
117 |
118 | replay_sample = self.get(n_replay).to(device)
119 | n_random = shape[0] - n_replay
120 | if n_random > 0:
121 | random_sample = self.random((n_random,) + shape[1:], device)
122 | return torch.cat([replay_sample, random_sample])
123 | else:
124 | return replay_sample
125 |
126 | def random(self, shape, device):
127 | if self.bound is None:
128 | r = torch.rand(*shape, dtype=torch.float).to(device)
129 |
130 | elif self.bound == 'spherical':
131 | r = torch.randn(*shape, dtype=torch.float).to(device)
132 | norm = r.view(len(r), -1).norm(dim=-1)
133 | if len(shape) == 4:
134 | r = r / norm[:, None, None, None]
135 | elif len(shape) == 2:
136 | r = r / norm[:, None]
137 | else:
138 | raise NotImplementedError
139 |
140 | elif len(self.bound) == 2:
141 | r = torch.rand(*shape, dtype=torch.float).to(device)
142 | r = r * (self.bound[1] - self.bound[0]) + self.bound[0]
143 | return r
144 |
145 |
146 | class SampleBufferV2:
147 | def __init__(self, max_samples=10000, replay_ratio=0.95):
148 | self.max_samples = max_samples
149 | self.buffer = []
150 | self.replay_ratio = replay_ratio
151 |
152 | def __len__(self):
153 | return len(self.buffer)
154 |
155 | def push(self, samples):
156 | samples = samples.detach().to('cpu')
157 |
158 | for sample in samples:
159 | self.buffer.append(sample)
160 |
161 | if len(self.buffer) > self.max_samples:
162 | self.buffer.pop(0)
163 |
164 | def get(self, n_samples):
165 | samples = random.choices(self.buffer, k=n_samples)
166 | samples = torch.stack(samples, 0)
167 | return samples
168 |
169 |
170 | def clip_grad_ebm(parameters, optimizer):
171 | with torch.no_grad():
172 | for group in optimizer.param_groups:
173 | for p in group['params']:
174 | state = optimizer.state[p]
175 |
176 | if 'step' not in state or state['step'] < 1:
177 | continue
178 |
179 | step = state['step']
180 | exp_avg_sq = state['exp_avg_sq']
181 | _, beta2 = group['betas']
182 |
183 | bound = 3 * torch.sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1
184 | p.grad.data.copy_(torch.max(torch.min(p.grad.data, bound), -bound))
185 |
--------------------------------------------------------------------------------
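The EnergyBasedModel above trains with persistent chains: negative samples start from the replay buffer with probability replay_ratio (uniform noise otherwise), are refined by Langevin dynamics, and are pushed back into the buffer. A small sketch (illustrative; the two-layer net and data shapes are assumptions):

import torch
import torch.nn as nn
from models.energybased import EnergyBasedModel

net = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))
ebm = EnergyBasedModel(net, alpha=1.0, step_size=10, sample_step=30)
opt = torch.optim.Adam(ebm.parameters(), lr=1e-4)
x = torch.rand(64, 2)  # data in [0, 1], matching the default clip_x=(0, 1)
d_result = ebm.train_step(x, opt)
print(d_result['pos_e'], d_result['neg_e'])  # energies of data and of negative samples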
/models/igebm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import utils
6 |
7 |
8 | class SpectralNorm:
9 | def __init__(self, name, bound=False):
10 | self.name = name
11 | self.bound = bound
12 |
13 | def compute_weight(self, module):
14 | weight = getattr(module, self.name + '_orig')
15 | u = getattr(module, self.name + '_u')
16 | size = weight.size()
17 | weight_mat = weight.contiguous().view(size[0], -1)
18 |
19 | with torch.no_grad():
20 | v = weight_mat.t() @ u
21 | v = v / v.norm()
22 | u = weight_mat @ v
23 | u = u / u.norm()
24 |
25 | sigma = u @ weight_mat @ v
26 |
27 | if self.bound:
28 | weight_sn = weight / (sigma + 1e-6) * torch.clamp(sigma, max=1)
29 |
30 | else:
31 | weight_sn = weight / sigma
32 |
33 | return weight_sn, u
34 |
35 | @staticmethod
36 | def apply(module, name, bound):
37 | fn = SpectralNorm(name, bound)
38 |
39 | weight = getattr(module, name)
40 | del module._parameters[name]
41 | module.register_parameter(name + '_orig', weight)
42 | input_size = weight.size(0)
43 | u = weight.new_empty(input_size).normal_()
44 | module.register_buffer(name, weight)
45 | module.register_buffer(name + '_u', u)
46 |
47 | module.register_forward_pre_hook(fn)
48 |
49 | return fn
50 |
51 | def __call__(self, module, input):
52 | weight_sn, u = self.compute_weight(module)
53 | setattr(module, self.name, weight_sn)
54 | setattr(module, self.name + '_u', u)
55 |
56 |
57 | def spectral_norm(module, init=True, std=1, bound=False):
58 | if init:
59 | nn.init.normal_(module.weight, 0, std)
60 |
61 | if hasattr(module, 'bias') and module.bias is not None:
62 | module.bias.data.zero_()
63 |
64 | SpectralNorm.apply(module, 'weight', bound=bound)
65 |
66 | return module
67 |
68 |
69 | class ResBlock(nn.Module):
70 | def __init__(self, in_channel, out_channel, n_class=None, downsample=False, use_spectral_norm=True):
71 | super().__init__()
72 |
73 | self.conv1 = nn.Conv2d(in_channel,
74 | out_channel,
75 | 3,
76 | padding=1,
77 | bias=False if n_class is not None else True)
78 |
79 | self.conv2 = nn.Conv2d(out_channel,
80 | out_channel,
81 | 3,
82 | padding=1,
83 | bias=False if n_class is not None else True)
84 |
85 | if use_spectral_norm:
86 | self.conv1 = spectral_norm(self.conv1)
87 | self.conv2 = spectral_norm(self.conv2, std=1e-10, bound=True)
88 |
89 | self.class_embed = None
90 |
91 | if n_class is not None:
92 | class_embed = nn.Embedding(n_class, out_channel * 2 * 2)
93 | class_embed.weight.data[:, : out_channel * 2] = 1
94 | class_embed.weight.data[:, out_channel * 2 :] = 0
95 |
96 | self.class_embed = class_embed
97 |
98 | self.skip = None
99 |
100 | if in_channel != out_channel or downsample:
101 | if use_spectral_norm:
102 | self.skip = nn.Sequential(
103 | spectral_norm(nn.Conv2d(in_channel, out_channel, 1, bias=False))
104 | )
105 | else:
106 | self.skip = nn.Sequential(
107 | nn.Conv2d(in_channel, out_channel, 1, bias=False)
108 | )
109 |
110 | self.downsample = downsample
111 |
112 | def forward(self, input, class_id=None):
113 | out = input
114 |
115 | out = self.conv1(out)
116 |
117 | if self.class_embed is not None:
118 | embed = self.class_embed(class_id).view(input.shape[0], -1, 1, 1)
119 | weight1, weight2, bias1, bias2 = embed.chunk(4, 1)
120 | out = weight1 * out + bias1
121 |
122 | out = F.leaky_relu(out, negative_slope=0.2)
123 |
124 | out = self.conv2(out)
125 |
126 | if self.class_embed is not None:
127 | out = weight2 * out + bias2
128 |
129 | if self.skip is not None:
130 | skip = self.skip(input)
131 |
132 | else:
133 | skip = input
134 |
135 | out = out + skip
136 |
137 | if self.downsample:
138 | out = F.avg_pool2d(out, 2)
139 |
140 | out = F.leaky_relu(out, negative_slope=0.2)
141 |
142 | return out
143 |
144 |
145 | class IGEBM(nn.Module):
146 | def __init__(self, in_chan=3, n_class=None):
147 | super().__init__()
148 |
149 | self.conv1 = spectral_norm(nn.Conv2d(in_chan, 128, 3, padding=1), std=1)
150 |
151 | self.blocks = nn.ModuleList(
152 | [
153 | ResBlock(128, 128, n_class, downsample=True),
154 | ResBlock(128, 128, n_class),
155 | ResBlock(128, 256, n_class, downsample=True),
156 | ResBlock(256, 256, n_class),
157 | ResBlock(256, 256, n_class, downsample=True),
158 | ResBlock(256, 256, n_class),
159 | ]
160 | )
161 |
162 | self.linear = nn.Linear(256, 1)
163 |
164 | def forward(self, input, class_id=None):
165 | out = self.conv1(input)
166 |
167 | out = F.leaky_relu(out, negative_slope=0.2)
168 |
169 | for block in self.blocks:
170 | out = block(out, class_id)
171 |
172 | out = F.relu(out)
173 | out = out.view(out.shape[0], out.shape[1], -1).sum(2)
174 | out = self.linear(out)
175 |
176 | return out
177 |
178 |
--------------------------------------------------------------------------------
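A usage sketch for the spectral normalization above (illustrative): the forward pre-hook stores the unnormalized weight as weight_orig, keeps a left singular vector u as a buffer, and on every forward call performs one power-iteration step to divide the weight by its estimated largest singular value (with bound=True, only when that value exceeds 1).

import torch
import torch.nn as nn
from models.igebm import spectral_norm

conv = spectral_norm(nn.Conv2d(3, 128, 3, padding=1), std=1)
y = conv(torch.randn(4, 3, 32, 32))  # sigma is re-estimated on every forward pass
print(conv.weight_orig.shape)        # raw weight; `weight` holds the normalized copy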
/models/langevin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.autograd as autograd
4 |
5 | def clip_vector_norm(x, max_norm):  # rescale rows whose norm exceeds max_norm down to max_norm
6 | norm = x.norm(dim=-1, keepdim=True)
7 | x = x * ((norm < max_norm).to(torch.float) + (norm > max_norm).to(torch.float) * max_norm/norm + 1e-6)
8 | return x
9 |
10 |
11 | def sample_langevin(x, model, stepsize, n_steps, noise_scale=None, intermediate_samples=False,
12 | clip_x=None, clip_grad=None, reject_boundary=False, noise_anneal=None,
13 | spherical=False, mh=False):
14 | """Draw samples using Langevin dynamics
15 | x: torch.Tensor, initial points
16 | model: An energy-based model. returns energy
17 | stepsize: float
18 | n_steps: integer
19 | noise_scale: Optional. float. If None, set to np.sqrt(stepsize * 2)
20 |     clip_x: tuple (min, max) or None. Boundary of the square domain.
21 |     reject_boundary: if True, reject out-of-domain proposals; otherwise clip them.
22 | """
23 | assert not ((stepsize is None) and (noise_scale is None)), 'stepsize and noise_scale cannot be None at the same time'
24 | if noise_scale is None:
25 | noise_scale = np.sqrt(stepsize * 2)
26 | if stepsize is None:
27 | stepsize = (noise_scale ** 2) / 2
28 | noise_scale_ = noise_scale
29 |
30 | l_samples = []
31 | l_dynamics = []; l_drift = []; l_diffusion = []
32 | x.requires_grad = True
33 | for i_step in range(n_steps):
34 | l_samples.append(x.detach().to('cpu'))
35 | noise = torch.randn_like(x) * noise_scale_
36 | out = model(x)
37 | grad = autograd.grad(out.sum(), x, only_inputs=True)[0]
38 | if clip_grad is not None:
39 | grad = clip_vector_norm(grad, max_norm=clip_grad)
40 | dynamics = - stepsize * grad + noise # negative!
41 | xnew = x + dynamics
42 | if clip_x is not None:
43 | if reject_boundary:
44 | accept = ((xnew >= clip_x[0]) & (xnew <= clip_x[1])).view(len(x), -1).all(dim=1)
45 | reject = ~ accept
46 | xnew[reject] = x[reject]
47 | x = xnew
48 | else:
49 | x = torch.clamp(xnew, clip_x[0], clip_x[1])
50 | else:
51 | x = xnew
52 |
53 | if spherical:
54 | if len(x.shape) == 4:
55 | x = x / x.view(len(x), -1).norm(dim=1)[:, None, None ,None]
56 | else:
57 | x = x / x.norm(dim=1, keepdim=True)
58 |
59 | if noise_anneal is not None:
60 | noise_scale_ = noise_scale / (1 + i_step)
61 |
62 | l_dynamics.append(dynamics.detach().to('cpu'))
63 | l_drift.append((- stepsize * grad).detach().cpu())
64 | l_diffusion.append(noise.detach().cpu())
65 | l_samples.append(x.detach().to('cpu'))
66 |
67 | if intermediate_samples:
68 | return l_samples, l_dynamics, l_drift, l_diffusion
69 | else:
70 | return x.detach()
71 |
72 |
73 | def sample_langevin_v2(x, model, stepsize, n_steps, noise_scale=None, intermediate_samples=False,
74 | clip_x=None, clip_grad=None, reject_boundary=False, noise_anneal=None,
75 | spherical=False, mh=False, temperature=None):
76 | """Langevin Monte Carlo
77 | x: torch.Tensor, initial points
78 | model: An energy-based model. returns energy
79 | stepsize: float
80 | n_steps: integer
81 | noise_scale: Optional. float. If None, set to np.sqrt(stepsize * 2)
82 |     clip_x: tuple (min, max) or None. Boundary of the square domain.
83 |     reject_boundary: if True, reject out-of-domain proposals; otherwise clip them.
84 | """
85 | assert not ((stepsize is None) and (noise_scale is None)), 'stepsize and noise_scale cannot be None at the same time'
86 | if noise_scale is None:
87 | noise_scale = np.sqrt(stepsize * 2)
88 | if stepsize is None:
89 | stepsize = (noise_scale ** 2) / 2
90 | noise_scale_ = noise_scale
91 | stepsize_ = stepsize
92 | if temperature is None:
93 | temperature = 1.
94 |
95 | # initial data
96 | x.requires_grad = True
97 | E_x = model(x)
98 | grad_E_x = autograd.grad(E_x.sum(), x, only_inputs=True)[0]
99 | if clip_grad is not None:
100 | grad_E_x = clip_vector_norm(grad_E_x, max_norm=clip_grad)
101 | E_y = E_x; grad_E_y = grad_E_x;
102 |
103 | l_samples = [x.detach().to('cpu')]
104 | l_dynamics = []; l_drift = []; l_diffusion = []; l_accept = []
105 | for i_step in range(n_steps):
106 | noise = torch.randn_like(x) * noise_scale_
107 | dynamics = - stepsize_ * grad_E_x / temperature + noise
108 | y = x + dynamics
109 | reject = torch.zeros(len(y), dtype=torch.bool)
110 |
111 | if clip_x is not None:
112 | if reject_boundary:
113 | accept = ((y >= clip_x[0]) & (y <= clip_x[1])).view(len(x), -1).all(dim=1)
114 | reject = ~ accept
115 | y[reject] = x[reject]
116 | else:
117 | y = torch.clamp(y, clip_x[0], clip_x[1])
118 |
119 | if spherical:
120 | y = y / y.norm(dim=1, p=2, keepdim=True)
121 |
122 | # y_accept = y[~reject]
123 | # E_y[~reject] = model(y_accept)
124 | # grad_E_y[~reject] = autograd.grad(E_y.sum(), y_accept, only_inputs=True)[0]
125 | E_y = model(y)
126 | grad_E_y = autograd.grad(E_y.sum(), y, only_inputs=True)[0]
127 |
128 | if clip_grad is not None:
129 | grad_E_y = clip_vector_norm(grad_E_y, max_norm=clip_grad)
130 |
131 | if mh:
132 | y_to_x = ((grad_E_x + grad_E_y) * stepsize_ - noise).view(len(x), -1).norm(p=2, dim=1, keepdim=True) ** 2
133 | x_to_y = (noise).view(len(x), -1).norm(dim=1, keepdim=True, p=2) ** 2
134 | transition = - (y_to_x - x_to_y) / 4 / stepsize_ # B x 1
135 | prob = -E_y + E_x
136 | accept_prob = torch.exp((transition + prob) / temperature)[:,0] # B
137 | reject = (torch.rand_like(accept_prob) > accept_prob) # | reject
138 | y[reject] = x[reject]
139 | E_y[reject] = E_x[reject]
140 | grad_E_y[reject] = grad_E_x[reject]
141 | x = y; E_x = E_y; grad_E_x = grad_E_y
142 | l_accept.append(~reject)
143 |
144 | x = y; E_x = E_y; grad_E_x = grad_E_y
145 |
146 | if noise_anneal is not None:
147 | noise_scale_ = noise_scale / (1 + i_step)
148 |
149 | l_dynamics.append(dynamics.detach().cpu())
150 | l_drift.append((- stepsize * grad_E_x).detach().cpu())
151 | l_diffusion.append(noise.detach().cpu())
152 | l_samples.append(x.detach().cpu())
153 |
154 | return {'sample': x.detach(), 'l_samples': l_samples, 'l_dynamics': l_dynamics,
155 | 'l_drift': l_drift, 'l_diffusion': l_diffusion, 'l_accept': l_accept}
156 |
157 |
--------------------------------------------------------------------------------
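A minimal usage sketch for sample_langevin_v2 above. The quadratic energy and tensor
shapes below are illustrative stand-ins, not part of the repository; note that the
energy must return shape (B, 1) for the Metropolis-Hastings branch to index correctly.

import torch
from models.langevin import sample_langevin_v2

# standard Gaussian energy, E(x) = ||x||^2 / 2, returned with shape (B, 1)
energy = lambda x: 0.5 * (x ** 2).sum(dim=1, keepdim=True)

x0 = torch.randn(128, 2)  # initial particles
d = sample_langevin_v2(x0, energy, stepsize=0.01, n_steps=100, mh=True)
print(d['sample'].shape)  # torch.Size([128, 2]); d['l_accept'] holds per-step MH acceptance masks
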
/models/mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mmd(X1, X2, bandwidth='median'):
5 | """Compute Maximum Mean Discrepancy"""
6 | if len(X1.shape) == 4:
7 | X1 = X1.view(len(X1), -1)
8 | if len(X2.shape) == 4:
9 | X2 = X2.view(len(X2), -1)
10 |
11 | N1 = len(X1)
12 | X1_sq = X1.pow(2).sum(1).unsqueeze(0)
13 | X1_cr = torch.mm(X1, X1.t())
14 | X1_dist = X1_sq + X1_sq.t() - 2 * X1_cr
15 |
16 | N2 = len(X2)
17 | X2_sq = X2.pow(2).sum(1).unsqueeze(0)
18 | X2_cr = torch.mm(X2, X2.t())
19 | X2_dist = X2_sq + X2_sq.t() - 2 * X2_cr
20 |
21 | X12 = torch.mm(X1, X2.t())
22 | X12_dist = X1_sq.t() + X2_sq - 2 * X12
23 |
24 | # median heuristic to select bandwidth
25 | if bandwidth == 'median':
26 | X1_triu = X1_dist[torch.triu(torch.ones_like(X1_dist), diagonal=1) == 1]
27 | bandwidth1 = torch.median(X1_triu)
28 | X2_triu = X2_dist[torch.triu(torch.ones_like(X2_dist), diagonal=1) == 1]
29 | bandwidth2 = torch.median(X2_triu)
30 | bandwidth_sq = ((bandwidth1 + bandwidth2) / 2).detach()
31 | else:
32 | bandwidth_sq = (bandwidth ** 2)
33 |
34 | C = - 0.5 / bandwidth_sq
35 | K11 = torch.exp(C * X1_dist)
36 | K22 = torch.exp(C * X2_dist)
37 | K12 = torch.exp(C * X12_dist)
38 | K11 = (1 - torch.eye(N1).to(X1.device)) * K11
39 | K22 = (1 - torch.eye(N2).to(X1.device)) * K22
40 | mmd = K11.sum() / N1 / (N1 - 1) + K22.sum() / N2 / (N2 - 1) - 2 * K12.mean()
41 | return mmd
42 |
43 |
44 |
--------------------------------------------------------------------------------
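A quick sanity-check sketch for the mmd function above (the Gaussian samples are
illustrative): MMD between samples from the same distribution should be near zero,
while a shifted sample set should score noticeably higher.

import torch
from models.mmd import mmd

torch.manual_seed(0)
X1 = torch.randn(256, 10)        # "data" samples
X2 = torch.randn(256, 10) + 0.5  # "model" samples, shifted by 0.5
print(mmd(X1, torch.randn(256, 10)).item())  # close to zero
print(mmd(X1, X2).item())                    # noticeably larger
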
/models/modules.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import models
6 | from torch import optim
7 | from torch.distributions import Normal
8 | from torch.distributions.multivariate_normal import MultivariateNormal
9 | import numpy as np
10 | from models.spectral_norm import spectral_norm
11 | from models import igebm
12 |
13 |
14 | class DummyDistribution(nn.Module):
15 | """ Function-less class introduced for backward-compatibility of model checkpoint files. """
16 | def __init__(self, net):
17 | super().__init__()
18 | self.net = net
19 | self.register_buffer('sigma', torch.tensor(0., dtype=torch.float))
20 |
21 | def forward(self, x):
22 | return self.net(x)
23 |
24 |
25 | class IsotropicGaussian(nn.Module):
26 | """Isotripic Gaussian density function paramerized by a neural net.
27 | standard deviation is a free scalar parameter"""
28 | def __init__(self, net, sigma=1., sigma_trainable=False, error_normalize=True, deterministic=False):
29 | super().__init__()
30 | self.net = net
31 | self.sigma_trainable = sigma_trainable
32 | self.error_normalize = error_normalize
33 | self.deterministic = deterministic
34 | if sigma_trainable:
35 | # self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float))
36 | self.register_parameter('sigma', nn.Parameter(torch.tensor(sigma, dtype=torch.float)))
37 | else:
38 | self.register_buffer('sigma', torch.tensor(sigma, dtype=torch.float))
39 |
40 | def log_likelihood(self, x, z):
41 | decoder_out = self.net(z)
42 | if self.deterministic:
43 | return - ((x - decoder_out)**2).view((x.shape[0], -1)).sum(dim=1)
44 | else:
45 | D = torch.prod(torch.tensor(x.shape[1:]))
46 | # sig = torch.tensor(1, dtype=torch.float32)
47 | sig = self.sigma
48 | const = - D * 0.5 * torch.log(2 * torch.tensor(np.pi, dtype=torch.float32)) - D * torch.log(sig)
49 | loglik = const - 0.5 * ((x - decoder_out)**2).view((x.shape[0], -1)).sum(dim=1) / (sig ** 2)
50 | return loglik
51 |
52 | def error(self, x, x_hat):
53 | if not self.error_normalize:
54 | return (((x - x_hat) / self.sigma) ** 2).view(len(x), -1).sum(-1)
55 | else:
56 | return ((x - x_hat) ** 2).view(len(x), -1).mean(-1)
57 |
58 | def forward(self, z):
59 | """returns reconstruction"""
60 | return self.net(z)
61 |
62 | def sample(self, z):
63 | if self.deterministic:
64 | return self.mean(z)
65 | else:
66 | x_hat = self.net(z)
67 | return x_hat + torch.randn_like(x_hat) * self.sigma
68 |
69 | def mean(self, z):
70 | return self.net(z)
71 |
72 | def max_log_likelihood(self, x):
73 | if self.deterministic:
74 | return torch.tensor(0., dtype=torch.float, device=x.device)
75 | else:
76 | D = torch.prod(torch.tensor(x.shape[1:]))
77 | sig = self.sigma
78 | const = - D * 0.5 * torch.log(2 * torch.tensor(np.pi, dtype=torch.float32)) - D * torch.log(sig)
79 | return const
80 |
81 | class IsotropicLaplace(nn.Module):
82 | """Isotropic Laplace density function -- equivalent to using L1 error """
83 | def __init__(self, net, sigma=0.1, sigma_trainable=False):
84 | super().__init__()
85 | self.net = net
86 | self.sigma_trainable = sigma_trainable
87 | if sigma_trainable:
88 | self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float))
89 | else:
90 | self.register_buffer('sigma', torch.tensor(sigma, dtype=torch.float))
91 |
92 | def log_likelihood(self, x, z):
93 | # decoder_out = self.net(z)
94 | # D = torch.prod(torch.tensor(x.shape[1:]))
95 | # sig = torch.tensor(1, dtype=torch.float32)
96 | # const = - D * 0.5 * torch.log(2 * torch.tensor(np.pi, dtype=torch.float32)) - D * torch.log(sig)
97 | # loglik = const - 0.5 * (torch.abs(x - decoder_out)).view((x.shape[0], -1)).sum(dim=1) / (sig ** 2)
98 | # return loglik
99 | raise NotImplementedError
100 |
101 | def error(self, x, x_hat):
102 | if self.sigma_trainable:
103 | return ((torch.abs(x - x_hat) / self.sigma)).view(len(x), -1).sum(-1)
104 | else:
105 | return (torch.abs(x - x_hat)).view(len(x), -1).mean(-1)
106 |
107 | def forward(self, z):
108 | """returns reconstruction"""
109 | return self.net(z)
110 |
111 | def sample(self, z):
112 | # x_hat = self.net(z)
113 | # return x_hat + torch.randn_like(x_hat) * self.sigma
114 | raise NotImplementedError
115 |
116 |
117 | class ConvNet2FC(nn.Module):
118 | """additional 1x1 conv layer at the top"""
119 | def __init__(self, in_chan=1, out_chan=64, nh=8, nh_mlp=512, out_activation='linear', use_spectral_norm=False):
120 | """nh: determines the numbers of conv filters"""
121 | super(ConvNet2FC, self).__init__()
122 | self.conv1 = nn.Conv2d(in_chan, nh * 4, kernel_size=3, bias=True)
123 | self.conv2 = nn.Conv2d(nh * 4, nh * 8, kernel_size=3, bias=True)
124 | self.max1 = nn.MaxPool2d(kernel_size=2, stride=2)
125 | self.conv3 = nn.Conv2d(nh * 8, nh * 8, kernel_size=3, bias=True)
126 | self.conv4 = nn.Conv2d(nh * 8, nh * 16, kernel_size=3, bias=True)
127 | self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)
128 | self.conv5 = nn.Conv2d(nh * 16, nh_mlp, kernel_size=4, bias=True)
129 | self.conv6 = nn.Conv2d(nh_mlp, out_chan, kernel_size=1, bias=True)
130 | self.in_chan, self.out_chan = in_chan, out_chan
131 | self.out_activation = get_activation(out_activation)
132 |
133 | if use_spectral_norm:
134 | self.conv1 = spectral_norm(self.conv1)
135 | self.conv2 = spectral_norm(self.conv2)
136 | self.conv3 = spectral_norm(self.conv3)
137 | self.conv4 = spectral_norm(self.conv4)
138 | self.conv5 = spectral_norm(self.conv5)
139 |
140 | layers = [self.conv1,
141 | nn.ReLU(),
142 | self.conv2,
143 | nn.ReLU(),
144 | self.max1,
145 | self.conv3,
146 | nn.ReLU(),
147 | self.conv4,
148 | nn.ReLU(),
149 | self.max2,
150 | self.conv5,
151 | nn.ReLU(),
152 | self.conv6,]
153 | if self.out_activation is not None:
154 | layers.append(self.out_activation)
155 |
156 |
157 | self.net = nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | return self.net(x)
161 |
162 |
163 | class DeConvNet2(nn.Module):
164 | def __init__(self, in_chan=1, out_chan=1, nh=8, out_activation='linear',
165 | use_spectral_norm=False):
166 | """nh: determines the numbers of conv filters"""
167 | super(DeConvNet2, self).__init__()
168 | self.conv1 = nn.ConvTranspose2d(in_chan, nh * 16, kernel_size=4, bias=True)
169 | self.conv2 = nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=3, bias=True)
170 | self.conv3 = nn.ConvTranspose2d(nh * 8, nh * 8, kernel_size=3, bias=True)
171 | self.conv4 = nn.ConvTranspose2d(nh * 8, nh * 4, kernel_size=3, bias=True)
172 | self.conv5 = nn.ConvTranspose2d(nh * 4, out_chan, kernel_size=3, bias=True)
173 | self.in_chan, self.out_chan = in_chan, out_chan
174 | self.out_activation = get_activation(out_activation)
175 |
176 | if use_spectral_norm:
177 | self.conv1 = spectral_norm(self.conv1)
178 | self.conv2 = spectral_norm(self.conv2)
179 | self.conv3 = spectral_norm(self.conv3)
180 | self.conv4 = spectral_norm(self.conv4)
181 | self.conv5 = spectral_norm(self.conv5)
182 |
183 | def forward(self, x):
184 | x = self.conv1(x)
185 | x = F.relu(x)
186 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
187 | x = self.conv2(x)
188 | x = F.relu(x)
189 | x = self.conv3(x)
190 | x = F.relu(x)
191 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
192 | x = self.conv4(x)
193 | x = F.relu(x)
194 | x = self.conv5(x)
195 | if self.out_activation is not None:
196 | x = self.out_activation(x)
197 | return x
198 |
199 |
200 | '''
201 | ConvNet architectures following (Ghosh et al., 2019),
202 | but excluding batch normalization
203 | '''
204 |
205 | class ConvNet64(nn.Module):
206 | """ConvNet architecture for CelebA64 following Ghosh et al., 2019"""
207 | def __init__(self, in_chan=3, out_chan=64, nh=32, out_activation='linear', activation='relu',
208 | num_groups=None, use_bn=False):
209 | super().__init__()
210 | self.conv1 = nn.Conv2d(in_chan, nh * 4, kernel_size=5, bias=True, stride=2)
211 | self.conv2 = nn.Conv2d(nh * 4, nh * 8, kernel_size=5, bias=True, stride=2)
212 | self.conv3 = nn.Conv2d(nh * 8, nh * 16, kernel_size=5, bias=True, stride=2)
213 | self.conv4 = nn.Conv2d(nh * 16, nh * 32, kernel_size=5, bias=True, stride=2)
214 | self.fc1 = nn.Conv2d(nh * 32, out_chan, kernel_size=1, bias=True)
215 | self.in_chan, self.out_chan = in_chan, out_chan
216 | self.num_groups = num_groups
217 | self.use_bn = use_bn
218 |
219 | layers = []
220 | layers.append(self.conv1)
221 | if num_groups is not None:
222 | layers.append(self.get_norm_layer(num_channels=nh * 4))
223 | layers.append(get_activation(activation))
224 | layers.append(self.conv2)
225 | if num_groups is not None:
226 | layers.append(self.get_norm_layer(num_channels=nh * 8))
227 | layers.append(get_activation(activation))
228 | layers.append(self.conv3)
229 | if num_groups is not None:
230 | layers.append(self.get_norm_layer(num_channels=nh * 16))
231 | layers.append(get_activation(activation))
232 | layers.append(self.conv4)
233 | if num_groups is not None:
234 | layers.append(self.get_norm_layer(num_channels=nh * 32))
235 | layers.append(get_activation(activation))
236 | layers.append(self.fc1)
237 | out_activation = get_activation(out_activation)
238 | if out_activation is not None:
239 | layers.append(out_activation)
240 |
241 | self.net = nn.Sequential(*layers)
242 |
243 | def forward(self, x):
244 | return self.net(x)
245 |
246 | def get_norm_layer(self, num_channels):
247 | if self.num_groups is not None:
248 | return nn.GroupNorm(num_groups=self.num_groups, num_channels=num_channels)
249 | elif self.use_bn:
250 | return nn.BatchNorm2d(num_channels)
251 |
252 |
253 | class DeConvNet64(nn.Module):
254 | """ConvNet architecture for CelebA64 following Ghosh et al., 2019"""
255 | def __init__(self, in_chan=64, out_chan=3, nh=32, out_activation='linear', activation='relu',
256 | num_groups=None, use_bn=False):
257 | super().__init__()
258 | self.fc1 = nn.ConvTranspose2d(in_chan, nh * 32, kernel_size=8, bias=True)
259 | self.conv1 = nn.ConvTranspose2d(nh * 32, nh * 16, kernel_size=4, stride=2, padding=1, bias=True)
260 | self.conv2 = nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=4, stride=2, padding=1, bias=True)
261 | self.conv3 = nn.ConvTranspose2d(nh * 8, nh * 4, kernel_size=4, stride=2, padding=1, bias=True)
262 | self.conv4 = nn.ConvTranspose2d(nh * 4, out_chan, kernel_size=1, bias=True)
263 | self.in_chan, self.out_chan = in_chan, out_chan
264 | self.num_groups = num_groups
265 | self.use_bn = use_bn
266 |
267 | layers = []
268 | layers.append(self.fc1)
269 | if num_groups is not None:
270 | layers.append(self.get_norm_layer(num_channels=nh * 32))
271 | layers.append(get_activation(activation))
272 | layers.append(self.conv1)
273 | if num_groups is not None:
274 | layers.append(self.get_norm_layer(num_channels=nh * 16))
275 | layers.append(get_activation(activation))
276 | layers.append(self.conv2)
277 | if num_groups is not None:
278 | layers.append(self.get_norm_layer(num_channels=nh * 8))
279 | layers.append(get_activation(activation))
280 | layers.append(self.conv3)
281 | if num_groups is not None:
282 | layers.append(self.get_norm_layer(num_channels=nh * 4))
283 | layers.append(get_activation(activation))
284 | layers.append(self.conv4)
285 | out_activation = get_activation(out_activation)
286 | if out_activation is not None:
287 | layers.append(out_activation)
288 |
289 | self.net = nn.Sequential(*layers)
290 |
291 | def forward(self, x):
292 | return self.net(x)
293 |
294 | def get_norm_layer(self, num_channels):
295 | if self.num_groups is not None:
296 | return nn.GroupNorm(num_groups=self.num_groups, num_channels=num_channels)
297 | elif self.use_bn:
298 | return nn.BatchNorm2d(num_channels)
299 |
300 |
301 | class ConvMLPBlock(nn.Module):
302 | def __init__(self, dim, hidden_dim=None, out_dim=None):
303 | super().__init__()
304 | if hidden_dim is None:
305 | hidden_dim = dim
306 | if out_dim is None:
307 | out_dim = dim
308 |
309 | self.block = nn.Sequential(
310 | nn.Conv2d(dim, hidden_dim, kernel_size=1, stride=1),
311 | nn.ReLU(),
312 | nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1))
313 |
314 | def forward(self, x):
315 | return self.block(x)
316 |
317 |
318 | class DeConvNet3(nn.Module):
319 | def __init__(self, in_chan=1, out_chan=1, nh=32, out_activation='linear',
320 | activation='relu', num_groups=None):
321 | """nh: determines the numbers of conv filters"""
322 | super(DeConvNet3, self).__init__()
323 | self.num_groups = num_groups
324 | self.fc1 = nn.ConvTranspose2d(in_chan, nh * 32, kernel_size=8, bias=True)
325 | self.conv1 = nn.ConvTranspose2d(nh * 32, nh * 16, kernel_size=4, stride=2, padding=1, bias=True)
326 | self.conv2 = nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=4, stride=2, padding=1, bias=True)
327 | self.conv3 = nn.ConvTranspose2d(nh * 8, out_chan, kernel_size=1, bias=True)
328 | self.in_chan, self.out_chan = in_chan, out_chan
329 |
330 | layers = [self.fc1,]
331 | layers += [] if self.num_groups is None else [self.get_norm_layer(nh*32)]
332 | layers += [get_activation(activation), self.conv1,]
333 | layers += [] if self.num_groups is None else [self.get_norm_layer(nh*16)]
334 | layers += [get_activation(activation), self.conv2,]
335 | layers += [] if self.num_groups is None else [self.get_norm_layer(nh*8)]
336 | layers += [get_activation(activation), self.conv3]
337 | out_activation = get_activation(out_activation)
338 | if out_activation is not None:
339 | layers.append(out_activation)
340 |
341 | self.net = nn.Sequential(*layers)
342 |
343 | def forward(self, x):
344 | return self.net(x)
345 |
346 | def get_norm_layer(self, num_channels):
347 | if self.num_groups is not None:
348 | return nn.GroupNorm(num_groups=self.num_groups, num_channels=num_channels)
349 | # elif self.use_bn:
350 | # return nn.BatchNorm2d(num_channels)
351 | else:
352 | return None
353 |
354 |
355 | class IGEBMEncoder(nn.Module):
356 | """Neural Network used in IGEBM"""
357 | def __init__(self, in_chan=3, out_chan=1, n_class=None, use_spectral_norm=False, keepdim=True):
358 | super().__init__()
359 | self.keepdim = keepdim
360 | self.use_spectral_norm = use_spectral_norm
361 |
362 | if use_spectral_norm:
363 | self.conv1 = spectral_norm(nn.Conv2d(in_chan, 128, 3, padding=1), std=1)
364 | else:
365 | self.conv1 = nn.Conv2d(in_chan, 128, 3, padding=1)
366 |
367 | self.blocks = nn.ModuleList(
368 | [
369 | igebm.ResBlock(128, 128, n_class, downsample=True, use_spectral_norm=use_spectral_norm),
370 | igebm.ResBlock(128, 128, n_class, use_spectral_norm=use_spectral_norm),
371 | igebm.ResBlock(128, 256, n_class, downsample=True, use_spectral_norm=use_spectral_norm),
372 | igebm.ResBlock(256, 256, n_class, use_spectral_norm=use_spectral_norm),
373 | igebm.ResBlock(256, 256, n_class, downsample=True, use_spectral_norm=use_spectral_norm),
374 | igebm.ResBlock(256, 256, n_class, use_spectral_norm=use_spectral_norm),
375 | ]
376 | )
377 |
378 | if keepdim:
379 | self.linear = nn.Conv2d(256, out_chan, 1, 1, 0)
380 | else:
381 | self.linear = nn.Linear(256, out_chan)
382 | if use_spectral_norm:
383 | self.linear = spectral_norm(self.linear)
384 |
385 | def forward(self, input, class_id=None):
386 | out = self.conv1(input)
387 |
388 | out = F.leaky_relu(out, negative_slope=0.2)
389 |
390 | for block in self.blocks:
391 | out = block(out, class_id)
392 |
393 | out = F.relu(out)
394 | if self.keepdim:
395 | out = F.adaptive_avg_pool2d(out, (1,1))
396 | else:
397 | out = out.view(out.shape[0], out.shape[1], -1).sum(2)
398 |
399 | out = self.linear(out)
400 |
401 | return out
402 |
403 |
404 | class SphericalActivation(nn.Module):
405 | def __init__(self):
406 | super().__init__()
407 |
408 | def forward(self, x):
409 | return x / x.norm(p=2, dim=1, keepdim=True)
410 |
411 |
412 | # activation functions and fully-connected networks
413 | def get_activation(s_act):
414 | if s_act == 'relu':
415 | return nn.ReLU(inplace=True)
416 | elif s_act == 'sigmoid':
417 | return nn.Sigmoid()
418 | elif s_act == 'softplus':
419 | return nn.Softplus()
420 | elif s_act == 'linear':
421 | return None
422 | elif s_act == 'tanh':
423 | return nn.Tanh()
424 | elif s_act == 'leakyrelu':
425 | return nn.LeakyReLU(0.2, inplace=True)
426 | elif s_act == 'softmax':
427 | return nn.Softmax(dim=1)
428 | elif s_act == 'spherical':
429 | return SphericalActivation()
430 | else:
431 | raise ValueError(f'Unexpected activation: {s_act}')
432 |
433 |
434 | class FCNet(nn.Module):
435 | """fully-connected network"""
436 | def __init__(self, in_dim, out_dim, l_hidden=(50,), activation='sigmoid', out_activation='linear',
437 | use_spectral_norm=False):
438 | super().__init__()
439 | l_neurons = tuple(l_hidden) + (out_dim,)
440 | if isinstance(activation, str):
441 | activation = (activation,) * len(l_hidden)
442 | activation = tuple(activation) + (out_activation,)
443 |
444 | l_layer = []
445 | prev_dim = in_dim
446 | for i_layer, (n_hidden, act) in enumerate(zip(l_neurons, activation)):
447 | if use_spectral_norm and i_layer < len(l_neurons) - 1: # don't apply SN to the last layer
448 | l_layer.append(spectral_norm(nn.Linear(prev_dim, n_hidden)))
449 | else:
450 | l_layer.append(nn.Linear(prev_dim, n_hidden))
451 | act_fn = get_activation(act)
452 | if act_fn is not None:
453 | l_layer.append(act_fn)
454 | prev_dim = n_hidden
455 |
456 | self.net = nn.Sequential(*l_layer)
457 | self.in_dim = in_dim
458 | self.out_shape = (out_dim,)
459 |
460 | def forward(self, x):
461 | return self.net(x)
462 |
463 |
464 | class ConvMLP(nn.Module):
465 | def __init__(self, in_dim, out_dim, l_hidden=(50,), activation='sigmoid', out_activation='linear',
466 | likelihood_type='isotropic_gaussian'):
467 | super(ConvMLP, self).__init__()
468 | self.likelihood_type = likelihood_type
469 | l_neurons = tuple(l_hidden) + (out_dim,)
470 | activation = (activation,) * len(l_hidden)
471 | activation = tuple(activation) + (out_activation,)
472 |
473 | l_layer = []
474 | prev_dim = in_dim
475 | for i_layer, (n_hidden, act) in enumerate(zip(l_neurons, activation)):
476 | l_layer.append(nn.Conv2d(prev_dim, n_hidden, 1, bias=True))
477 | act_fn = get_activation(act)
478 | if act_fn is not None:
479 | l_layer.append(act_fn)
480 | prev_dim = n_hidden
481 |
482 | self.net = nn.Sequential(*l_layer)
483 | self.in_dim = in_dim
484 |
485 | def forward(self, x):
486 | return self.net(x)
487 |
488 |
489 | class FCResNet(nn.Module):
490 | """FullyConnected Residual Network
491 | Input - Linear - (ResBlock * K) - Linear - Output"""
492 | def __init__(self, in_dim, out_dim, res_dim, n_res_hidden=100, n_resblock=2, out_activation='linear', use_spectral_norm=False):
493 | super().__init__()
494 | l_layer = []
495 | block = nn.Linear(in_dim, res_dim)
496 | if use_spectral_norm:
497 | block = spectral_norm(block)
498 | l_layer.append(block)
499 |
500 | for i_resblock in range(n_resblock):
501 | block = FCResBlock(res_dim, n_res_hidden, use_spectral_norm=use_spectral_norm)
502 | l_layer.append(block)
503 | l_layer.append(nn.ReLU())
504 |
505 | block = nn.Linear(res_dim, out_dim)
506 | if use_spectral_norm:
507 | block = spectral_norm(block)
508 | l_layer.append(block)
509 | out_activation = get_activation(out_activation)
510 | if out_activation is not None:
511 | l_layer.append(out_activation)
512 | self.net = nn.Sequential(*l_layer)
513 |
514 | def forward(self, x):
515 | return self.net(x)
516 |
517 |
518 | class FCResBlock(nn.Module):
519 | def __init__(self, res_dim, n_res_hidden, use_spectral_norm=False):
520 | super().__init__()
521 | if use_spectral_norm:
522 | self.net = nn.Sequential(nn.ReLU(),
523 | spectral_norm(nn.Linear(res_dim, n_res_hidden)),
524 | nn.ReLU(),
525 | spectral_norm(nn.Linear(n_res_hidden, res_dim)))
526 | else:
527 | self.net = nn.Sequential(nn.ReLU(),
528 | nn.Linear(res_dim, n_res_hidden),
529 | nn.ReLU(),
530 | nn.Linear(n_res_hidden, res_dim))
531 |
532 | def forward(self, x):
533 | return x + self.net(x)
534 |
--------------------------------------------------------------------------------
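A hedged sketch of combining the modules above into a small autoencoder. The 28x28
single-channel (MNIST-like) shapes are assumptions for illustration and follow from
the kernel sizes in ConvNet2FC / DeConvNet2.

import torch
from models.modules import ConvNet2FC, DeConvNet2, IsotropicGaussian

encoder = ConvNet2FC(in_chan=1, out_chan=32)                     # (B, 1, 28, 28) -> (B, 32, 1, 1)
decoder = IsotropicGaussian(DeConvNet2(in_chan=32, out_chan=1))  # (B, 32, 1, 1) -> (B, 1, 28, 28)

x = torch.rand(8, 1, 28, 28)
z = encoder(x)
x_hat = decoder(z)                    # forward() returns the reconstruction mean
print(decoder.error(x, x_hat).shape)  # per-example reconstruction error, shape (8,)
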
/models/modules_sngan.py:
--------------------------------------------------------------------------------
1 | """
2 | ResNet architectures from SNGAN (Miyato et al., 2018).
3 | Code adapted from:
4 | https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
5 | """
6 | # ResNet generator and discriminator
7 | import torch  # needed for torch.mv / torch.t in SpectralNorm._update_u_v
8 | from torch import nn, Tensor
9 | from torch.nn import Parameter
10 | import torch.nn.functional as F
11 |
12 | import numpy as np
13 |
14 |
15 | def l2normalize(v, eps=1e-12):
16 | return v / (v.norm() + eps)
17 |
18 |
19 | class SpectralNorm(nn.Module):
20 | def __init__(self, module, name='weight', power_iterations=1):
21 | super(SpectralNorm, self).__init__()
22 | self.module = module
23 | self.name = name
24 | self.power_iterations = power_iterations
25 | if not self._made_params():
26 | self._make_params()
27 |
28 | def _update_u_v(self):
29 | u = getattr(self.module, self.name + "_u")
30 | v = getattr(self.module, self.name + "_v")
31 | w = getattr(self.module, self.name + "_bar")
32 |
33 | height = w.data.shape[0]
34 | for _ in range(self.power_iterations):
35 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
36 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
37 |
38 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
39 | sigma = u.dot(w.view(height, -1).mv(v))
40 | setattr(self.module, self.name, w / sigma.expand_as(w))
41 |
42 | def _made_params(self):
43 | try:
44 | u = getattr(self.module, self.name + "_u")
45 | v = getattr(self.module, self.name + "_v")
46 | w = getattr(self.module, self.name + "_bar")
47 | return True
48 | except AttributeError:
49 | return False
50 |
51 |
52 | def _make_params(self):
53 | w = getattr(self.module, self.name)
54 |
55 | height = w.data.shape[0]
56 | width = w.view(height, -1).data.shape[1]
57 |
58 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
59 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
60 | u.data = l2normalize(u.data)
61 | v.data = l2normalize(v.data)
62 | w_bar = Parameter(w.data)
63 |
64 | del self.module._parameters[self.name]
65 |
66 | self.module.register_parameter(self.name + "_u", u)
67 | self.module.register_parameter(self.name + "_v", v)
68 | self.module.register_parameter(self.name + "_bar", w_bar)
69 |
70 |
71 | def forward(self, *args):
72 | self._update_u_v()
73 | return self.module.forward(*args)
74 |
75 |
76 |
77 |
78 | channels = 3
79 |
80 | class ResBlockGenerator(nn.Module):
81 |
82 | def __init__(self, in_channels, out_channels, stride=1):
83 | super(ResBlockGenerator, self).__init__()
84 |
85 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1)
86 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1)
87 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
88 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.)
89 |
90 | self.model = nn.Sequential(
91 | nn.BatchNorm2d(in_channels),
92 | nn.ReLU(),
93 | nn.Upsample(scale_factor=2),
94 | self.conv1,
95 | nn.BatchNorm2d(out_channels),
96 | nn.ReLU(),
97 | self.conv2
98 | )
99 | self.bypass = nn.Sequential()
100 | if stride != 1:
101 | self.bypass = nn.Upsample(scale_factor=2)
102 |
103 | def forward(self, x):
104 | return self.model(x) + self.bypass(x)
105 |
106 |
107 | class ResBlockGeneratorNoBN(nn.Module):
108 |
109 | def __init__(self, in_channels, out_channels, stride=1):
110 | super(ResBlockGeneratorNoBN, self).__init__()
111 |
112 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1)
113 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1)
114 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
115 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.)
116 |
117 | self.model = nn.Sequential(
118 | nn.ReLU(),
119 | nn.Upsample(scale_factor=2),
120 | self.conv1,
121 | nn.ReLU(),
122 | self.conv2
123 | )
124 | self.bypass = nn.Sequential()
125 | if stride != 1:
126 | self.bypass = nn.Upsample(scale_factor=2)
127 |
128 | def forward(self, x):
129 | return self.model(x) + self.bypass(x)
130 |
131 |
132 | class ResBlockGeneratorGN(nn.Module):
133 |
134 | def __init__(self, in_channels, out_channels, stride=1, num_groups=1):
135 | super().__init__()
136 |
137 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1)
138 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1)
139 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
140 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.)
141 |
142 | self.model = nn.Sequential(
143 | nn.GroupNorm(num_groups=num_groups, num_channels=in_channels),
144 | nn.ReLU(),
145 | nn.Upsample(scale_factor=2),
146 | self.conv1,
147 | nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
148 | nn.ReLU(),
149 | self.conv2
150 | )
151 | self.bypass = nn.Sequential()
152 | if stride != 1:
153 | self.bypass = nn.Upsample(scale_factor=2)
154 |
155 | def forward(self, x):
156 | return self.model(x) + self.bypass(x)
157 |
158 |
159 |
160 | class ResBlockDiscriminator(nn.Module):
161 |
162 | def __init__(self, in_channels, out_channels, stride=1):
163 | super(ResBlockDiscriminator, self).__init__()
164 |
165 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1)
166 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1)
167 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
168 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.)
169 |
170 | if stride == 1:
171 | self.model = nn.Sequential(
172 | nn.ReLU(),
173 | SpectralNorm(self.conv1),
174 | nn.ReLU(),
175 | SpectralNorm(self.conv2)
176 | )
177 | else:
178 | self.model = nn.Sequential(
179 | nn.ReLU(),
180 | SpectralNorm(self.conv1),
181 | nn.ReLU(),
182 | SpectralNorm(self.conv2),
183 | nn.AvgPool2d(2, stride=stride, padding=0)
184 | )
185 | self.bypass = nn.Sequential()
186 | if stride != 1:
187 |
188 | self.bypass_conv = nn.Conv2d(in_channels,out_channels, 1, 1, padding=0)
189 | nn.init.xavier_uniform_(self.bypass_conv.weight.data, np.sqrt(2))
190 |
191 | self.bypass = nn.Sequential(
192 | SpectralNorm(self.bypass_conv),
193 | nn.AvgPool2d(2, stride=stride, padding=0)
194 | )
195 | # if in_channels == out_channels:
196 | # self.bypass = nn.AvgPool2d(2, stride=stride, padding=0)
197 | # else:
198 | # self.bypass = nn.Sequential(
199 | # SpectralNorm(nn.Conv2d(in_channels,out_channels, 1, 1, padding=0)),
200 | # nn.AvgPool2d(2, stride=stride, padding=0)
201 | # )
202 |
203 |
204 | def forward(self, x):
205 | return self.model(x) + self.bypass(x)
206 |
207 | # special ResBlock just for the first layer of the discriminator
208 | class FirstResBlockDiscriminator(nn.Module):
209 |
210 | def __init__(self, in_channels, out_channels, stride=1):
211 | super(FirstResBlockDiscriminator, self).__init__()
212 |
213 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1)
214 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1)
215 | self.bypass_conv = nn.Conv2d(in_channels, out_channels, 1, 1, padding=0)
216 | nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
217 | nn.init.xavier_uniform_(self.conv2.weight.data, 1.)
218 | nn.init.xavier_uniform_(self.bypass_conv.weight.data, np.sqrt(2))
219 |
220 | # do not apply a ReLU to the raw image before the first convolution.
221 | self.model = nn.Sequential(
222 | SpectralNorm(self.conv1),
223 | nn.ReLU(),
224 | SpectralNorm(self.conv2),
225 | nn.AvgPool2d(2)
226 | )
227 | self.bypass = nn.Sequential(
228 | nn.AvgPool2d(2),
229 | SpectralNorm(self.bypass_conv),
230 | )
231 |
232 | def forward(self, x):
233 | return self.model(x) + self.bypass(x)
234 |
235 | GEN_SIZE = 128
236 | DISC_SIZE = 128
237 |
238 | class Generator(nn.Module):
239 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None):
240 | super().__init__()
241 | self.z_dim = z_dim
242 |
243 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4)
244 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1)
245 | nn.init.xavier_uniform_(self.dense.weight.data, 1.)
246 | nn.init.xavier_uniform_(self.final.weight.data, 1.)
247 |
248 | l_layers = [
249 | ResBlockGenerator(hidden_dim, hidden_dim, stride=2),
250 | ResBlockGenerator(hidden_dim, hidden_dim, stride=2),
251 | ResBlockGenerator(hidden_dim, hidden_dim, stride=2),
252 | # nn.BatchNorm2d(hidden_dim), # should be uncommented. temporarily commented for backward compatibility
253 | nn.ReLU(),
254 | self.final,
255 | ]
256 |
257 | if out_activation == 'sigmoid':
258 | l_layers.append(nn.Sigmoid())
259 | elif out_activation == 'tanh':
260 | l_layers.append(nn.Tanh())
261 | self.model = nn.Sequential(*l_layers)
262 |
263 | def forward(self, z):
264 | return self.model(self.dense(z))
265 |
266 |
267 | class GeneratorNoBN(nn.Module):
268 | """remove batch normalization
269 | z vector is assumed to be 4-dimensional (fully-convolutional)"""
270 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None):
271 | super().__init__()
272 | self.z_dim = z_dim
273 |
274 | # self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE)
275 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4)
276 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1)
277 | nn.init.xavier_uniform_(self.dense.weight.data, 1.)
278 | nn.init.xavier_uniform_(self.final.weight.data, 1.)
279 |
280 | l_layers = [
281 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
282 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
283 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
284 | nn.ReLU(),
285 | self.final]
286 |
287 | if out_activation == 'sigmoid':
288 | l_layers.append(nn.Sigmoid())
289 | elif out_activation == 'tanh':
290 | l_layers.append(nn.Tanh())
291 | self.model = nn.Sequential(*l_layers)
292 |
293 | def forward(self, z):
294 | # return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4))
295 | return self.model(self.dense(z))
296 |
297 |
298 | class GeneratorNoBN64(nn.Module):
299 | """remove batch normalization
300 | z vector is assumed to be 4-dimensional (fully-convolutional)
301 | for generating 64x64 output"""
302 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None):
303 | super().__init__()
304 | self.z_dim = z_dim
305 |
306 | # self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE)
307 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4)
308 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1)
309 | nn.init.xavier_uniform_(self.dense.weight.data, 1.)
310 | nn.init.xavier_uniform_(self.final.weight.data, 1.)
311 |
312 | l_layers = [
313 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
314 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
315 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
316 | ResBlockGeneratorNoBN(hidden_dim, hidden_dim, stride=2),
317 | nn.ReLU(),
318 | self.final]
319 |
320 | if out_activation == 'sigmoid':
321 | l_layers.append(nn.Sigmoid())
322 | elif out_activation == 'tanh':
323 | l_layers.append(nn.Tanh())
324 | self.model = nn.Sequential(*l_layers)
325 |
326 | def forward(self, z):
327 | # return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4))
328 | return self.model(self.dense(z))
329 |
330 |
331 | class GeneratorGN(nn.Module):
332 | """replace Batch Normalization to Group Normalization
333 | z vector is assumed to be 4-dimensional (fully-convolutional)"""
334 | def __init__(self, z_dim, channels, hidden_dim=128, out_activation=None, num_groups=1):
335 | super().__init__()
336 | self.z_dim = z_dim
337 |
338 | # self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE)
339 | self.dense = nn.ConvTranspose2d(z_dim, hidden_dim, 4)
340 | self.final = nn.Conv2d(hidden_dim, channels, 3, stride=1, padding=1)
341 | nn.init.xavier_uniform_(self.dense.weight.data, 1.)
342 | nn.init.xavier_uniform_(self.final.weight.data, 1.)
343 |
344 | l_layers = [
345 | ResBlockGeneratorGN(hidden_dim, hidden_dim, stride=2),
346 | ResBlockGeneratorGN(hidden_dim, hidden_dim, stride=2),
347 | ResBlockGeneratorGN(hidden_dim, hidden_dim, stride=2),
348 | nn.GroupNorm(num_groups=num_groups, num_channels=hidden_dim),
349 | nn.ReLU(),
350 | self.final]
351 |
352 | if out_activation == 'sigmoid':
353 | l_layers.append(nn.Sigmoid())
354 | elif out_activation == 'tanh':
355 | l_layers.append(nn.Tanh())
356 | self.model = nn.Sequential(*l_layers)
357 |
358 | def forward(self, z):
359 | # return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4))
360 | return self.model(self.dense(z))
361 |
362 |
363 |
364 |
365 | class Discriminator(nn.Module):
366 | def __init__(self):
367 | super(Discriminator, self).__init__()
368 |
369 | self.model = nn.Sequential(
370 | FirstResBlockDiscriminator(channels, DISC_SIZE, stride=2),
371 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE, stride=2),
372 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE),
373 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE),
374 | nn.ReLU(),
375 | nn.AvgPool2d(8),
376 | )
377 | self.fc = nn.Linear(DISC_SIZE, 1)
378 | nn.init.xavier_uniform_(self.fc.weight.data, 1.)
379 | self.fc = SpectralNorm(self.fc)
380 |
381 | def forward(self, x):
382 | return self.fc(self.model(x).view(-1,DISC_SIZE))
383 |
--------------------------------------------------------------------------------
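A small sketch of the SNGAN-style generator above: the dense transposed convolution maps
a fully-convolutional latent of shape (B, z_dim, 1, 1) to a 4x4 feature map, which the
three residual blocks upsample 4 -> 8 -> 16 -> 32 (this mirrors tests/test_modules.py).

import torch
from models.modules_sngan import GeneratorNoBN

g = GeneratorNoBN(z_dim=16, channels=3, hidden_dim=64, out_activation='sigmoid')
z = torch.rand(4, 16, 1, 1)
print(g(z).shape)  # torch.Size([4, 3, 32, 32])
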
/models/spectral_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import utils
6 |
7 |
8 | class SpectralNorm:
9 | def __init__(self, name, bound=False):
10 | self.name = name
11 | self.bound = bound
12 |
13 | def compute_weight(self, module):
14 | weight = getattr(module, self.name + '_orig')
15 | u = getattr(module, self.name + '_u')
16 | size = weight.size()
17 | weight_mat = weight.contiguous().view(size[0], -1)
18 |
19 | with torch.no_grad():
20 | v = weight_mat.t() @ u
21 | v = v / v.norm()
22 | u = weight_mat @ v
23 | u = u / u.norm()
24 |
25 | sigma = u @ weight_mat @ v
26 |
27 | if self.bound:
28 | weight_sn = weight / (sigma + 1e-6) * torch.clamp(sigma, max=1)
29 |
30 | else:
31 | weight_sn = weight / sigma
32 |
33 | return weight_sn, u
34 |
35 | @staticmethod
36 | def apply(module, name, bound):
37 | fn = SpectralNorm(name, bound)
38 |
39 | weight = getattr(module, name)
40 | del module._parameters[name]
41 | module.register_parameter(name + '_orig', weight)
42 | input_size = weight.size(0)
43 | u = weight.new_empty(input_size).normal_()
44 | module.register_buffer(name, weight)
45 | module.register_buffer(name + '_u', u)
46 |
47 | module.register_forward_pre_hook(fn)
48 |
49 | return fn
50 |
51 | def __call__(self, module, input):
52 | weight_sn, u = self.compute_weight(module)
53 | setattr(module, self.name, weight_sn)
54 | setattr(module, self.name + '_u', u)
55 |
56 |
57 | def spectral_norm(module, init=True, std=1, bound=False):
58 | if init:
59 | nn.init.normal_(module.weight, 0, std)
60 |
61 | if hasattr(module, 'bias') and module.bias is not None:
62 | module.bias.data.zero_()
63 |
64 | SpectralNorm.apply(module, 'weight', bound=bound)
65 |
66 | return module
67 |
--------------------------------------------------------------------------------
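A minimal sketch of the spectral_norm wrapper above. Note that init=True (the default)
re-initializes the weights from N(0, std); the registered forward pre-hook then divides
the weight by a one-step power-iteration estimate of its spectral norm on every call.

import torch
from torch import nn
from models.spectral_norm import spectral_norm

layer = spectral_norm(nn.Linear(10, 10))  # registers 'weight_orig', 'weight_u' and the hook
x = torch.randn(4, 10)
y = layer(x)                              # weight is renormalized before this forward pass
print(y.shape)                            # torch.Size([4, 10])
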
/optimizers.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop
4 |
5 | logger = logging.getLogger("ptsemseg")
6 |
7 | key2opt = {
8 | "sgd": SGD,
9 | "adam": Adam,
10 | "asgd": ASGD,
11 | "adamax": Adamax,
12 | "adadelta": Adadelta,
13 | "adagrad": Adagrad,
14 | "rmsprop": RMSprop,
15 | }
16 |
17 |
18 | def get_optimizer(opt_dict, model_params):
19 | opt_dict = opt_dict.copy()
20 | optimizer = _get_optimizer_instance(opt_dict)
21 |
22 | # params = {k: v for k, v in opt_dict.items() if k != "name"}
23 | opt_dict.pop('name')
24 |
25 | optimizer = optimizer(model_params, **opt_dict)
26 |
27 | return optimizer, None
28 |
29 |
30 | def _get_optimizer_instance(opt_dict):
31 | if opt_dict is None:
32 | logger.info("Using SGD optimizer")
33 | return SGD
34 |
35 | else:
36 | opt_name = opt_dict["name"]
37 | if opt_name not in key2opt:
38 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name))
39 |
40 | logger.info("Using {} optimizer".format(opt_name))
41 | return key2opt[opt_name]
42 |
--------------------------------------------------------------------------------
/pretrained/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/normalized-autoencoders/77e198712bc88460ea2c2086166fe10f16735d0b/pretrained/.gitkeep
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | """
2 | generate samples
3 | """
4 | import os
5 | import numpy as np
6 | import torch
7 | from models import get_model
8 | from omegaconf import OmegaConf
9 | from tqdm import tqdm
10 | import argparse
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('result_dir', type=str, help='path to the directory where the config yaml and the checkpoint are stored.')
15 | parser.add_argument('config', type=str, help='the name of the config file.')
16 | parser.add_argument('ckpt', type=str, help='the name of the checkpoint file.')
17 | parser.add_argument('--device', default=0, type=int, help='the id of cuda device to use')
18 | parser.add_argument('--n_sample', default=1000, type=int, help='the number of samples to be generated')
19 | parser.add_argument('--zstep', default=None, type=int, help='the number of the latent Langevin MC steps')
20 | parser.add_argument('--xstep', default=None, type=int, help='the number of the visible Langevin MC steps')
21 | parser.add_argument('--x_shape', default=32, type=int, help='the size of a side of an image.')
22 | parser.add_argument('--x_channel', default=3, type=int, help='the number of channels in an image.')
23 | parser.add_argument('--batch_size', default=128, type=int)
24 | parser.add_argument('--name', default=None, type=str, help='additional identifier for the result file.')
25 | parser.add_argument('--replay', default=False, action='store_true', help='to use the sample replay buffer')
26 | args = parser.parse_args()
27 |
28 |
29 | config_path = os.path.join(args.result_dir, args.config)
30 | ckpt_path = os.path.join(args.result_dir, args.ckpt)
31 | device = f'cuda:{args.device}'
32 | print(f'using {device}')
33 |
34 | # load model
35 | print(f'building a model from {config_path}')
36 | cfg = OmegaConf.load(config_path)
37 | model = get_model(cfg)
38 | model_class = cfg['model']['arch']
39 |
40 | print(f'loading a model from {ckpt_path}')
41 | state = torch.load(ckpt_path, map_location=device)
42 | if 'model_state' in state:
43 | state = state['model_state']
44 | model.load_state_dict(state)
45 | model.to(device)
46 | model.eval()
47 |
48 | if model_class == 'nae':
49 | print(f'replay {args.replay}')
50 | model.replay = args.replay
51 |
52 | dummy_x = torch.rand(1, args.x_channel, args.x_shape, args.x_shape, dtype=torch.float).to(device)
53 | model._set_x_shape(dummy_x)
54 | model._set_z_shape(dummy_x)
55 |
56 | if args.zstep is not None:
57 | model.z_step = args.zstep
58 | print(f'z_step: {model.z_step}')
59 | if args.xstep is not None:
60 | model.x_step = args.xstep
61 | print(f'x_step: {model.x_step}')
62 | elif model_class == 'vae':
63 | dummy_x = torch.rand(1, 3, args.x_shape, args.x_shape, dtype=torch.float).to(device)
64 | model._set_z_shape(dummy_x)
65 |
66 |
67 | # run sampling
68 | batch_size = args.batch_size
69 | n_batch = int(np.ceil(args.n_sample / batch_size))
70 | l_sample = []
71 | for i_batch in tqdm(range(n_batch)):
72 | if i_batch == n_batch - 1:
73 | n_sample = args.n_sample % batch_size if args.n_sample % batch_size else batch_size
74 | else:
75 | n_sample = batch_size
76 | d_sample = model.sample(n_sample=n_sample, device=device)
77 | sample = d_sample['sample_x'].detach()
78 |
79 | # re-quantization
80 | # sample = (sample * 255 + 0.5).clamp(0, 255).cpu().to(torch.uint8)
81 | sample = (sample * 256).clamp(0, 255).cpu().to(torch.uint8)
82 |
83 | sample = sample.permute(0, 2, 3, 1).numpy()
84 |
85 |
86 | l_sample.append(sample)
87 | sample = np.concatenate(l_sample)
88 | print(f'sample shape: {sample.shape}')
89 |
90 | # save result
91 | if args.name is not None:
92 | out_name = os.path.join(args.result_dir, f'{os.path.splitext(args.ckpt)[0]}_sample_{args.name}.npy')
93 | else:
94 | out_name = os.path.join(args.result_dir, f'{os.path.splitext(args.ckpt)[0]}_sample.npy')
95 | np.save(out_name, sample)
96 | print(f'sample saved at {out_name}')
97 |
98 |
99 |
--------------------------------------------------------------------------------
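A typical invocation of sample.py (the result directory and checkpoint name below are
illustrative; they follow the pretrained-model layout used in tests/test_load_pretrained.py):

python sample.py pretrained/cifar_ood_nae/z32gn z32gn.yml nae_8.pkl --device 0 --n_sample 1000
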
/tests/test_load_pretrained.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from models import load_pretrained
3 |
4 | # (identifier, config_file, ckpt_file, kwargs)
5 | cifar_ood_nae = ('cifar_ood_nae/z32gn', 'z32gn.yml', 'nae_8.pkl', {})
6 | cifar_ood_nae_ae = ('cifar_ood_nae/z32gn', 'z32gn.yml', 'model_best.pkl', {}) # NAE before training
7 |
8 | mnist_ood_nae = ('mnist_ood_nae/z32', 'z32.yml', 'nae_20.pkl', {})
9 | mnist_ood_nae_ae = ('mnist_ood_nae/z32', 'z32.yml', 'model_best.pkl', {}) # NAE before training
10 |
11 | celeba64_ood_nae = ('celeba64_ood_nae/z64gr_h32g8', 'z64gr_h32g8.yml', 'nae_3.pkl', {})
12 | celeba64_ood_nae_ae = ('celeba64_ood_nae/z64gr_h32g8', 'z64gr_h32g8.yml', 'model_best.pkl', {}) # NAE before training
13 | l_setting = [cifar_ood_nae, cifar_ood_nae_ae,
14 | mnist_ood_nae, mnist_ood_nae_ae,
15 | celeba64_ood_nae, celeba64_ood_nae_ae]
16 |
17 |
18 | @pytest.mark.parametrize('model_setting', l_setting)
19 | def test_load_pretrained(model_setting):
20 | identifier, config_file, ckpt_file, kwargs = model_setting
21 | model, cfg = load_pretrained(identifier, config_file, ckpt_file, **kwargs)
22 |
--------------------------------------------------------------------------------
/tests/test_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from loaders import get_dataset, get_dataloader
3 | from torch.utils import data
4 | import yaml
5 | import pytest
6 | from skimage import io
7 | import pickle
8 | import torch
9 | import os
10 |
11 |
12 | def test_get_dataloader():
13 | cfg = {'dataset': 'FashionMNISTpad_OOD',
14 | 'path': 'datasets',
15 | 'shuffle': True,
16 | 'n_workers': 0,
17 | 'batch_size': 1,
18 | 'split': 'training'}
19 | dl = get_dataloader(cfg)
20 |
21 |
22 | def test_concat_dataset():
23 | data_cfg = {'concat1':
24 | {'dataset': 'FashionMNISTpad_OOD',
25 | 'path': 'datasets',
26 | 'shuffle': True,
27 | 'split': 'training'},
28 | 'concat2':
29 | {'dataset': 'MNISTpad_OOD',
30 | 'path': 'datasets',
31 | 'shuffle': True,
32 | 'n_workers': 0,
33 | 'batch_size': 1,
34 | 'split': 'training'},
35 | 'n_workers': 0,
36 | 'batch_size': 1,
37 | }
38 | get_dataset(data_cfg)
39 |
40 |
--------------------------------------------------------------------------------
/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.optim import Adam
4 | from models.modules import FCNet, FCResNet, IsotropicGaussian
5 | from models.modules import ConvNet64, DeConvNet64
6 | from models.modules_sngan import GeneratorNoBN
7 |
8 | @pytest.mark.parametrize('out_activation', ['linear', 'spherical'])
9 | def test_fcnet(out_activation):
10 | net = FCNet(in_dim=1, out_dim=5, l_hidden=(50,), activation='sigmoid', out_activation=out_activation)
11 | X = torch.rand(10, 1)
12 | Z = net(X)
13 |
14 |
15 |
16 | def test_sngan_generator():
17 | z_dim = 4
18 | channel = 3
19 | z = torch.rand(3, z_dim, 1, 1, dtype=torch.float)
20 |
21 | net = GeneratorNoBN(z_dim, channel, hidden_dim=32, out_activation='sigmoid')
22 | out = net(z)
23 | assert out.shape == (3, 3, 32, 32)
24 |
25 |
26 | def test_fcresnet():
27 | net = FCResNet(in_dim=2, out_dim=2, res_dim=20, n_res_hidden=100, n_resblock=2)
28 |
29 | x = torch.rand((20, 2))
30 | out = net(x)
31 | assert out.shape == (20, 2)
32 |
33 |
34 | @pytest.mark.parametrize('distribution', ['IsotropicGaussian'])
35 | def test_distribution_modules(distribution):
36 | if distribution == 'IsotropicGaussian':
37 | dist = IsotropicGaussian
38 |
39 | N = 10; z_dim = 3; x_dim = 5
40 | z = torch.rand(N, z_dim)
41 | x = torch.rand(N, x_dim)
42 | net = FCNet(in_dim=z_dim, out_dim=x_dim, l_hidden=(50,), activation='sigmoid', out_activation='sigmoid')
43 | net = dist(net)
44 | lik = net.log_likelihood(x, z)
45 | assert lik.shape == (N,)
46 | samples = net.sample(z)
47 | assert samples.shape == (N, x_dim)
48 |
49 |
50 | @pytest.mark.parametrize('num_groups', [None, 2])
51 | def test_convnet64(num_groups):
52 | x = torch.rand(2, 3, 64, 64)
53 | encoder = ConvNet64(num_groups=num_groups)
54 | decoder = DeConvNet64(num_groups=num_groups)
55 | recon = decoder(encoder(x))
56 | assert x.shape == recon.shape
57 |
58 |
--------------------------------------------------------------------------------
/tests/test_nae.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.optim import Adam
4 | from models.modules import FCNet
5 | from models.nae import NAE, FFEBM, FFEBMV2, NAE_L2_CD, NAE_L2_NCE, NAE_L2_OMI
6 | from models import get_model
7 | from models.mcmc import LangevinSampler, MHSampler, NoiseSampler
8 | from omegaconf import OmegaConf
9 |
10 |
11 | def test_ffebm():
12 | net = FCNet(2, 1, out_activation='linear')
13 | model = FFEBM(net, x_step=3, x_stepsize=1., sampling='x')
14 | X = torch.randn((10, 2), dtype=torch.float)
15 | opt = Adam(model.parameters(), lr=1e-4)
16 |
17 | # forward
18 | lik = model.predict(X)
19 |
20 | # sample
21 | model._set_x_shape(X)
22 | d_sample = model.sample_x(10, 'cpu:0')
23 |
24 | # training
25 | model.train_step(X, opt)
26 |
27 |
28 | @pytest.mark.parametrize('sampling', ['x', 'on_manifold', 'cd'])
29 | def test_nae(sampling):
30 | encoder = FCNet(2, 1)
31 | decoder = FCNet(1, 2)
32 | nae = NAE(encoder, decoder, initial_dist='gaussian', sampling=sampling)
33 | opt = Adam(nae.parameters(), lr=1e-4)
34 |
35 | X = torch.randn((10, 2), dtype=torch.float)
36 |
37 | # forward
38 | lik = nae.predict(X)
39 |
40 | # sample
41 | nae._set_x_shape(X)
42 | nae._set_z_shape(X)
43 | d_sample = nae.sample(X)
44 |
45 | # training
46 | nae.train_step(X, opt)
47 |
48 |
49 | def test_cifar():
50 | cfg = OmegaConf.load('configs/cifar_ood_nae/z32gn.yml')
51 | model = get_model(cfg)
52 | xx = torch.rand(2, 3, 32, 32)
53 | recon = model.reconstruct(xx)
54 | error = model(xx)
55 |
56 | assert recon.shape == (2, 3, 32, 32)
57 | assert error.shape == (2,)
58 |
59 |
60 | @pytest.mark.parametrize('sampling', ['cd', 'x'])
61 | @pytest.mark.parametrize('sampler', ['mh', 'langevin'])
62 | def test_ffebm_v2(sampling, sampler):
63 | net = FCNet(2, 1, out_activation='linear')
64 | if sampler == 'mh':
65 | sampler = MHSampler(n_step=2, stepsize=0.1, bound=(-5, 5), buffer_size=10000, replay_ratio=0.95, reject_boundary=False,
66 | mh=True, initial_dist='uniform')
67 | elif sampler == 'langevin':
68 | sampler = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=1.,
69 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.95, reject_boundary=False,
70 | mh=True, initial_dist='uniform')
71 | model = FFEBMV2(net, sampler, gamma=1, sampling=sampling)
72 | X = torch.randn((10, 2), dtype=torch.float)
73 | opt = Adam(model.parameters(), lr=1e-4)
74 |
75 | # forward
76 | lik = model.predict(X)
77 |
78 | # sample
79 | model._set_x_shape(X)
80 | d_sample = model.sample(n_sample=10, device='cpu:0', sampling='x')
81 |
82 | # training
83 | model.train_step(X, opt)
84 |
85 |
86 | @pytest.mark.parametrize('sampling', ['cd', 'x'])
87 | def test_nae_l2_cd(sampling):
88 | encoder = FCNet(2, 1, out_activation='spherical')
89 | decoder = FCNet(1, 2)
90 | sampler = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=1.,
91 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.95, reject_boundary=False,
92 | mh=True, initial_dist='uniform')
93 | nae = NAE_L2_CD(encoder, decoder, sampler, sampling=sampling)
94 | opt = Adam(nae.parameters(), lr=1e-4)
95 |
96 | X = torch.randn((10, 2), dtype=torch.float)
97 |
98 | # forward
99 | lik = nae.predict(X)
100 |
101 | # sample
102 | nae._set_x_shape(X)
103 | # nae._set_z_shape(X)
104 | d_sample = nae.sample(X)
105 |
106 | # training
107 | nae.train_step(X, opt)
108 |
109 |
110 | def test_nae_l2_nce():
111 | encoder = FCNet(2, 1, out_activation='spherical')
112 | decoder = FCNet(1, 2)
113 | sampler = NoiseSampler(dist='gaussian', shape=(2,), offset=1., scale=1.)
114 | nae = NAE_L2_NCE(encoder, decoder, sampler, T=0.5, T_trainable=True)
115 | opt = Adam(nae.parameters(), lr=1e-4)
116 |
117 | X = torch.randn((10, 2), dtype=torch.float)
118 |
119 | # forward
120 | lik = nae.predict(X)
121 |
122 | # training
123 | nae.train_step(X, opt)
124 |
125 |
126 | @pytest.mark.parametrize('spherical', [True, False])
127 | def test_nae_l2_omi(spherical):
128 | encoder = FCNet(2, 3, out_activation='linear')
129 | decoder = FCNet(3, 2)
130 | sampler_x = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=1.,
131 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.,
132 | reject_boundary=False, mh=True, initial_dist='uniform')
133 |
134 | sampler_z = LangevinSampler(n_step=2, stepsize=0.1, noise_std=0.1, noise_anneal=None,
135 | bound=(-5, 5), buffer_size=10000, replay_ratio=0.95,
136 | reject_boundary=False, mh=True, initial_dist='uniform')
137 |
138 | nae = NAE_L2_OMI(encoder, decoder, sampler_z, sampler_x, T=0.5, T_trainable=True, spherical=spherical)
139 | opt = Adam(nae.parameters(), lr=1e-4)
140 |
141 | X = torch.randn((10, 2), dtype=torch.float)
142 |
143 | # forward
144 | lik = nae.predict(X)
145 |
146 | # training
147 | nae.train_step(X, opt)
148 |
149 |
150 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import argparse
5 | from omegaconf import OmegaConf
6 | import numpy as np
7 | from itertools import cycle
8 | import torch
9 | from models import get_model
10 | from trainers import get_trainer, get_logger
11 | from loaders import get_dataloader
12 | from optimizers import get_optimizer
13 | from datetime import datetime
14 | from tensorboardX import SummaryWriter
15 | from utils import save_yaml, search_params_intp, eprint, parse_unknown_args, parse_nested_args
16 |
17 |
18 | def run(cfg, writer):
19 | """main training function"""
20 | # Setup seeds
21 | seed = cfg.get('seed', 1)
22 | print(f'running with random seed : {seed}')
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed(seed)
25 | np.random.seed(seed)
26 | random.seed(seed)
27 | # for reproducibility
28 | # torch.backends.cudnn.deterministic = True
29 | # torch.backends.cudnn.benchmark = False
30 |
31 | # Setup device
32 | device = cfg.device
33 |
34 | # Setup Dataloader
35 | d_dataloaders = {}
36 | for key, dataloader_cfg in cfg['data'].items():
37 | if 'holdout' in cfg:
38 | dataloader_cfg = process_holdout(dataloader_cfg, int(cfg['holdout']))
39 | d_dataloaders[key] = get_dataloader(dataloader_cfg)
40 |
41 | # Setup Model
42 | model = get_model(cfg).to(device)
43 | trainer = get_trainer(cfg)
44 | logger = get_logger(cfg, writer)
45 |
46 | # Setup optimizer
47 | if hasattr(model, 'own_optimizer') and model.own_optimizer:
48 | optimizer, sch = model.get_optimizer(cfg['training']['optimizer'])
49 | elif 'optimizer' not in cfg['training']:
50 | optimizer = None
51 | sch = None
52 | else:
53 | optimizer, sch = get_optimizer(cfg["training"]["optimizer"], model.parameters())
54 |
55 | model, train_result = trainer.train(model, optimizer, d_dataloaders, logger=logger,
56 | logdir=writer.file_writer.get_logdir(), scheduler=sch,
57 | clip_grad=cfg['training'].get('clip_grad', None))
58 |
59 |
60 |
61 | def process_holdout(dataloader_cfg, holdout):
62 | """udpate config if holdout option is present in config"""
63 | if 'LeaveOut' in dataloader_cfg['dataset'] and 'out_class' in dataloader_cfg:
64 | if len(dataloader_cfg['out_class']) == 1: # in-distribution
65 | dataloader_cfg['out_class'] = [holdout]
66 | else: # ood
67 | dataloader_cfg['out_class'] = [i for i in range(10) if i != holdout]
68 | print(dataloader_cfg)
69 | return dataloader_cfg
70 |
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--config', type=str)
76 | parser.add_argument('--device', default=0)
77 | parser.add_argument('--logdir', default='results/')
78 | parser.add_argument('--run', default=None, help='unique run id of the experiment')
79 | args, unknown = parser.parse_known_args()
80 | d_cmd_cfg = parse_unknown_args(unknown)
81 | d_cmd_cfg = parse_nested_args(d_cmd_cfg)
82 | print(d_cmd_cfg)
83 | cfg = OmegaConf.load(args.config)
84 | if args.device == 'cpu':
85 | cfg['device'] = 'cpu'
86 | else:
87 | cfg['device'] = f'cuda:{args.device}'
88 |
89 | if args.run is None:
90 | run_id = datetime.now().strftime('%Y%m%d-%H%M')
91 | else:
92 | run_id = args.run
93 | cfg = OmegaConf.merge(cfg, d_cmd_cfg)
94 | print(OmegaConf.to_yaml(cfg))
95 |
96 | config_basename = os.path.basename(args.config).split('.')[0]
97 | logdir = os.path.join(args.logdir, config_basename, str(run_id))
98 | writer = SummaryWriter(logdir=logdir)
99 | print("Result directory: {}".format(logdir))
100 |
101 | # copy config file
102 | copied_yml = os.path.join(logdir, os.path.basename(args.config))
103 | save_yaml(copied_yml, OmegaConf.to_yaml(cfg))
104 | print(f'config saved as {copied_yml}')
105 |
106 | run(cfg, writer)
107 |
108 |
109 |
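End to end, train.py wires config -> dataloaders -> model -> trainer. A condensed sketch of the same flow, assuming one of the shipped configs and omitting the holdout and optimizer branching; NAETrainer builds its own optimizers, so `optimizer` may be None here:

from omegaconf import OmegaConf
from tensorboardX import SummaryWriter
from models import get_model
from trainers import get_trainer, get_logger
from loaders import get_dataloader

cfg = OmegaConf.load('configs/mnist_ood_nae/z32.yml')
cfg['device'] = 'cuda:0'
writer = SummaryWriter(logdir='results/demo')

d_loaders = {k: get_dataloader(v) for k, v in cfg['data'].items()}
model = get_model(cfg).to(cfg.device)
trainer = get_trainer(cfg)
logger = get_logger(cfg, writer)
model, result = trainer.train(model, None, d_loaders, logger=logger,
                              logdir=writer.file_writer.get_logdir())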
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from trainers.logger import BaseLogger
2 | from trainers.base import BaseTrainer
3 | from trainers.nae import NAETrainer, NAETrainerV2, NAELogger
4 |
5 |
6 | def get_trainer(cfg):
7 | # get the trainer specified by the `trainer` field;
8 | # if not specified, fall back to BaseTrainer
9 | trainer_type = cfg.get('trainer', None)
10 | arch = cfg['model']['arch']
11 | device = cfg['device']
12 | if trainer_type == 'nae':
13 | trainer = NAETrainer(cfg['training'], device=device)
14 | elif trainer_type == 'nae_v2':
15 | trainer = NAETrainerV2(cfg['training'], device=device)
16 | else:
17 | trainer = BaseTrainer(cfg['training'], device=device)
18 | return trainer
19 |
20 |
21 | def get_logger(cfg, writer):
22 | logger_type = cfg['logger']
23 | if logger_type == 'nae':
24 | logger = NAELogger(writer)
25 | else:
26 | logger = BaseLogger(writer)
27 | return logger
28 |
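A minimal illustrative config showing the dispatch above; the `trainer` and `logger` fields select the NAE-specific classes, and anything else falls back to the base classes (field values below are illustrative):

from omegaconf import OmegaConf
from trainers import get_trainer

cfg = OmegaConf.create({
    'trainer': 'nae',        # -> NAETrainer ('nae_v2' -> NAETrainerV2, else BaseTrainer)
    'logger': 'nae',         # -> NAELogger (else BaseLogger)
    'device': 'cpu',
    'model': {'arch': 'nae'},
    'training': {},
})
trainer = get_trainer(cfg)   # NAETrainer instance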
--------------------------------------------------------------------------------
/trainers/base.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import torch
4 | import numpy as np
5 | from metrics import averageMeter
6 |
7 |
8 | class BaseTrainer:
9 | """Trainer for a conventional iterative training of model"""
10 | def __init__(self, training_cfg, device):
11 | self.cfg = training_cfg
12 | self.device = device
13 | self.d_val_result = {}
14 |
15 | def train(self, model, opt, d_dataloaders, logger=None, logdir='', scheduler=None, clip_grad=None):
16 | cfg = self.cfg
17 | best_val_loss = np.inf
18 | time_meter = averageMeter()
19 | i = 0
20 | train_loader, val_loader = d_dataloaders['training'], d_dataloaders['validation']
21 |
22 | for i_epoch in range(cfg.n_epoch):
23 |
24 | for x, y in train_loader:
25 | i += 1
26 |
27 | model.train()
28 | x = x.to(self.device)
29 | y = y.to(self.device)
30 |
31 | start_ts = time.time()
32 | d_train = model.train_step(x, y=y, optimizer=opt, clip_grad=clip_grad)
33 | time_meter.update(time.time() - start_ts)
34 | logger.process_iter_train(d_train)
35 |
36 | if i % cfg.print_interval == 0:
37 | d_train = logger.summary_train(i)
38 | print(f"Iter [{i:d}] Avg Loss: {d_train['loss/train_loss_']:.4f} Elapsed time: {time_meter.sum:.4f}")
39 | time_meter.reset()
40 |
41 | if i % cfg.val_interval == 0:
42 | model.eval()
43 | for val_x, val_y in val_loader:
44 | val_x = val_x.to(self.device)
45 | val_y = val_y.to(self.device)
46 |
47 | d_val = model.validation_step(val_x, y=val_y)
48 | logger.process_iter_val(d_val)
49 | d_val = logger.summary_val(i)
50 | val_loss = d_val['loss/val_loss_']
51 | print(d_val['print_str'])
52 | best_model = val_loss < best_val_loss
53 |
54 | if i % cfg.save_interval == 0 or best_model:
55 | self.save_model(model, logdir, best=best_model, i_iter=i)
56 | if best_model:
57 | print(f'Iter [{i:d}] best model saved {val_loss} <= {best_val_loss}')
58 | best_val_loss = val_loss
59 | if i_epoch % cfg.save_interval_epoch == 0:
60 | self.save_model(model, logdir, best=False, i_epoch=i_epoch)
61 |
62 | return model, best_val_loss
63 |
64 | def save_model(self, model, logdir, best=False, i_iter=None, i_epoch=None):
65 | if best:
66 | pkl_name = "model_best.pkl"
67 | else:
68 | if i_iter is not None:
69 | pkl_name = "model_iter_{}.pkl".format(i_iter)
70 | else:
71 | pkl_name = "model_epoch_{}.pkl".format(i_epoch)
72 | state = {"epoch": i_epoch, "model_state": model.state_dict(), 'iter': i_iter}
73 | save_path = os.path.join(logdir, pkl_name)
74 | torch.save(state, save_path)
75 | print(f'Model saved: {pkl_name}')
76 |
77 |
78 |
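BaseTrainer only assumes a small interface from the model. A minimal stub satisfying it (the class, network, and loss are illustrative, not one of the repository's models):

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 2)

    def train_step(self, x, y=None, optimizer=None, clip_grad=None):
        # must return a dict containing a scalar 'loss'
        optimizer.zero_grad()
        loss = ((self.net(x) - x) ** 2).mean()
        loss.backward()
        if clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad)
        optimizer.step()
        return {'loss': loss.item()}

    def validation_step(self, x, y=None):
        with torch.no_grad():
            loss = ((self.net(x) - x) ** 2).mean()
        return {'loss': loss.item()}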
--------------------------------------------------------------------------------
/trainers/logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from metrics import averageMeter
3 |
4 |
5 | class BaseLogger:
6 | """BaseLogger that can handle most of the logging
7 | logging convention
8 | ------------------
9 | 'loss' has to be exist in all training settings
10 | endswith('_') : scalar
11 | endswith('@') : image
12 | """
13 | def __init__(self, tb_writer):
14 | """tb_writer: tensorboard SummaryWriter"""
15 | self.writer = tb_writer
16 | self.train_loss_meter = averageMeter()
17 | self.val_loss_meter = averageMeter()
18 | self.d_train = {}
19 | self.d_val = {}
20 | self.has_val_loss = True # If True, we assume that validation loss is available
21 |
22 | def process_iter_train(self, d_result):
23 | self.train_loss_meter.update(d_result['loss'])
24 | self.d_train = d_result
25 |
26 | def summary_train(self, i):
27 | self.d_train['loss/train_loss_'] = self.train_loss_meter.avg
28 | for key, val in self.d_train.items():
29 | if key.endswith('_'):
30 | self.writer.add_scalar(key, val, i)
31 | if key.endswith('@'):
32 | if val is not None:
33 | self.writer.add_image(key, val, i)
34 |
35 | result = self.d_train
36 | self.d_train = {}
37 | return result
38 |
39 | def process_iter_val(self, d_result):
40 | if self.has_val_loss:
41 | self.val_loss_meter.update(d_result['loss'])
42 | self.d_val = d_result
43 |
44 | def summary_val(self, i):
45 | if self.has_val_loss:
46 | self.d_val['loss/val_loss_'] = self.val_loss_meter.avg
47 | l_print_str = [f'Iter [{i:d}]']
48 | for key, val in self.d_val.items():
49 | if key.endswith('_'):
50 | self.writer.add_scalar(key, val, i)
51 | l_print_str.append(f'{key}: {val:.4f}')
52 | if key.endswith('@'):
53 | if val is not None:
54 | self.writer.add_image(key, val, i)
55 |
56 | print_str = ' '.join(l_print_str)
57 |
58 | result = self.d_val
59 | result['print_str'] = print_str
60 | self.d_val = {}
61 | return result
62 |
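A small sketch of the key-suffix convention in action (the log directory and image tensor are illustrative): keys ending in '_' are routed to add_scalar, keys ending in '@' to add_image, and 'loss' feeds the running average meter.

import torch
from tensorboardX import SummaryWriter
from trainers.logger import BaseLogger

logger = BaseLogger(SummaryWriter(logdir='/tmp/logger_demo'))
logger.process_iter_train({
    'loss': 0.42,                         # consumed by the average meter
    'loss/recon_error_': 0.42,            # '_' suffix -> scalar
    'recon_img@': torch.rand(3, 64, 64),  # '@' suffix -> image (CHW)
})
d = logger.summary_train(i=100)
print(d['loss/train_loss_'])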
--------------------------------------------------------------------------------
/trainers/nae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import time
4 | from metrics import averageMeter
5 | from trainers.base import BaseTrainer
6 | from trainers.logger import BaseLogger
7 | from optimizers import get_optimizer
8 | import torch
9 | from torch.optim import Adam, SGD
10 | from tqdm import tqdm
11 | from torchvision.utils import make_grid, save_image
12 | from utils import roc_btw_arr
13 |
14 |
15 | class NAETrainer(BaseTrainer):
16 | def train(self, model, opt, d_dataloaders, logger=None, logdir='', scheduler=None, clip_grad=None):
17 | cfg = self.cfg
18 | best_val_loss = np.inf
19 | time_meter = averageMeter()
20 | i = 0
21 | indist_train_loader = d_dataloaders['indist_train']
22 | indist_val_loader = d_dataloaders['indist_val']
23 | oodval_val_loader = d_dataloaders['ood_val']
24 | oodtarget_val_loader = d_dataloaders['ood_target']
25 | no_best_model_tolerance = 3
26 | no_best_model_count = 0
27 |
28 | n_ae_epoch = cfg.ae_epoch
29 | n_nae_epoch = cfg.nae_epoch
30 | ae_opt = Adam(model.parameters(), lr=cfg.ae_lr)
31 | # nae_opt = Adam([{'params': list(model.encoder.parameters()) + list(model.decoder.parameters())},
32 | # {'params': model.temperature_, 'lr': cfg.temperature_lr}], lr=cfg.nae_lr)
33 |
34 | if cfg.get('fix_D', False):
35 | if hasattr(model, 'temperature_'):
36 | nae_opt = Adam(list(model.encoder.parameters()) + [model.temperature_], lr=cfg.nae_lr)
37 | else:
38 | nae_opt = Adam(model.encoder.parameters(), lr=cfg.nae_lr)
39 | elif cfg.get('small_D_lr', False):
40 | print('small decoder learning rate')
41 | nae_opt = Adam([{'params': model.encoder.parameters(), 'lr': cfg.nae_lr},
42 | {'params': model.decoder.parameters(), 'lr': cfg.nae_lr / 10}])
43 | else:
44 | l_params = [{'params': list(model.encoder.parameters()) + list(model.decoder.parameters())}]
45 | if model.temperature_trainable:
46 | l_params.append({'params': model.temperature_, 'lr': cfg.temperature_lr})
47 | nae_opt = Adam(l_params, lr=cfg.nae_lr)
48 |
49 | '''AE PASS'''
50 | if 'load_ae' in cfg:
51 | n_ae_epoch = 0
52 | model.load_state_dict(torch.load(cfg['load_ae'])['model_state'])
53 | print(f'model loaded from {cfg["load_ae"]}')
54 |
55 | for i_epoch in range(n_ae_epoch):
56 |
57 | for x, y in indist_train_loader:
58 | i += 1
59 |
60 | model.train()
61 | x = x.to(self.device)
62 | y = y.to(self.device)
63 |
64 | start_ts = time.time()
65 | d_train = model.train_step_ae(x, ae_opt, clip_grad=0.1) # TODO: use the cfg clip_grad instead of the hard-coded 0.1
66 | time_meter.update(time.time() - start_ts)
67 | logger.process_iter_train(d_train)
68 |
69 | if i % cfg.print_interval == 0:
70 | d_train = logger.summary_train(i)
71 | print(f"Iter [{i:d}] Avg Loss: {d_train['loss/train_loss_']:.4f} Elapsed time: {time_meter.sum:.4f}")
72 | time_meter.reset()
73 |
74 | if i % cfg.val_interval == 0:
75 | model.eval()
76 | for val_x, val_y in indist_val_loader:
77 | val_x = val_x.to(self.device)
78 | val_y = val_y.to(self.device)
79 |
80 | d_val = model.validation_step(val_x, y=val_y)
81 | logger.process_iter_val(d_val)
82 | d_val = logger.summary_val(i)
83 | val_loss = d_val['loss/val_loss_']
84 | print(d_val['print_str'])
85 | best_model = val_loss < best_val_loss
86 |
87 | if i % cfg.save_interval == 0 or best_model:
88 | self.save_model(model, logdir, best=best_model, i_iter=i)
89 | if best_model:
90 | print(f'Iter [{i:d}] best model saved {val_loss} <= {best_val_loss}')
91 | best_val_loss = val_loss
92 | else:
93 | no_best_model_count += 1
94 | if no_best_model_count > no_best_model_tolerance:
95 | break
96 |
97 | if no_best_model_count > no_best_model_tolerance:
98 | print('terminating autoencoder training since the validation loss no longer decreases')
99 | break
100 |
101 | '''NAE PASS'''
102 | # load best autoencoder model
103 | # model.load_state_dict(torch.load(os.path.join(logdir, 'model_best.pkl'))['model_state'])
104 | # print('best model loaded')
105 | i = 0
106 | for i_epoch in tqdm(range(n_nae_epoch)):
107 | for x, _ in tqdm(indist_train_loader):
108 | i += 1
109 |
110 | x = x.cuda(self.device)
111 | if cfg.get('flatten', False):
112 | x = x.view(len(x), -1)
113 | d_result = model.train_step(x, nae_opt)
114 | logger.process_iter_train_nae(d_result)
115 |
116 | if i % cfg.print_interval == 1:
117 | logger.summary_train_nae(i)
118 |
119 | if i % cfg.val_interval == 1:
120 | '''AUC'''
121 | in_pred = self.predict(model, indist_val_loader, self.device)
122 | ood1_pred = self.predict(model, oodval_val_loader, self.device)
123 | auc_val = roc_btw_arr(ood1_pred, in_pred)
124 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device)
125 | auc_target = roc_btw_arr(ood2_pred, in_pred)
126 | d_result = {'nae/auc_val': auc_val, 'nae/auc_target': auc_target}
127 | print(logger.summary_val_nae(i, d_result))
128 | torch.save({'model_state': model.state_dict()}, f'{logdir}/nae_iter_{i}.pkl')
129 |
130 | torch.save(model.state_dict(), f'{logdir}/nae_{i_epoch}.pkl')
131 | torch.save(model.state_dict(), f'{logdir}/nae.pkl')
132 |
133 | '''EBAE sample'''
134 | nae_sample = model.sample(n_sample=30, device=self.device, replay=True)
135 | img_grid = make_grid(nae_sample['sample_x'].detach().cpu(), nrow=10, range=(0, 1))
136 | logger.writer.add_image('nae/sample', img_grid, i + 1)
137 | save_image(img_grid, f'{logdir}/nae_sample.png')
138 |
139 | '''AUC'''
140 | in_pred = self.predict(model, indist_val_loader, self.device)
141 | ood1_pred = self.predict(model, oodval_val_loader, self.device)
142 | auc_val = roc_btw_arr(ood1_pred, in_pred)
143 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device)
144 | auc_target = roc_btw_arr(ood2_pred, in_pred)
145 | d_result = {'nae/auc_val': auc_val, 'nae/auc_target': auc_target}
146 | print(d_result)
147 |
148 | return model, auc_val
149 |
150 | def predict(self, m, dl, device, flatten=False):
151 | """run prediction for the whole dataset"""
152 | l_result = []
153 | for x, _ in dl:
154 | with torch.no_grad():
155 | if flatten:
156 | x = x.view(len(x), -1)
157 | pred = m.predict(x.cuda(device)).detach().cpu()
158 | l_result.append(pred)
159 | return torch.cat(l_result)
160 |
161 |
162 | class NAELogger(BaseLogger):
163 | def __init__(self, tb_writer):
164 | super().__init__(tb_writer)
165 | self.train_loss_meter_nae = averageMeter()
166 | self.val_loss_meter_nae = averageMeter()
167 | self.d_train_nae = {}
168 | self.d_val_nae = {}
169 |
170 | def process_iter_train_nae(self, d_result):
171 | self.train_loss_meter_nae.update(d_result['loss'])
172 | self.d_train_nae = d_result
173 |
174 | def summary_train_nae(self, i):
175 | d_result = self.d_train_nae
176 | writer = self.writer
177 | writer.add_scalar('nae/loss', d_result['loss'], i)
178 | writer.add_scalar('nae/energy_diff', d_result['pos_e'] - d_result['neg_e'], i)
179 | writer.add_scalar('nae/pos_e', d_result['pos_e'], i)
180 | writer.add_scalar('nae/neg_e', d_result['neg_e'], i)
181 | writer.add_scalar('nae/encoder_l2', d_result['encoder_norm'], i)
182 | writer.add_scalar('nae/decoder_l2', d_result['decoder_norm'], i)
183 | if 'neg_e_x0' in d_result:
184 | writer.add_scalar('nae/neg_e_x0', d_result['neg_e_x0'], i)
185 | if 'neg_e_z0' in d_result:
186 | writer.add_scalar('nae/neg_e_z0', d_result['neg_e_z0'], i)
187 | if 'temperature' in d_result:
188 | writer.add_scalar('nae/temperature', d_result['temperature'], i)
189 | if 'sigma' in d_result:
190 | writer.add_scalar('nae/sigma', d_result['sigma'], i)
191 | if 'delta_term' in d_result:
192 | writer.add_scalar('nae/delta_term', d_result['delta_term'], i)
193 | if 'gamma_term' in d_result:
194 | writer.add_scalar('nae/gamma_term', d_result['gamma_term'], i)
195 |
196 |
197 | '''images'''
198 | x_neg = d_result['x_neg']
199 | recon_neg = d_result['recon_neg']
200 | img_grid = make_grid(x_neg, nrow=10, range=(0, 1))
201 | writer.add_image('nae/sample', img_grid, i)
202 | img_grid = make_grid(recon_neg, nrow=10, range=(0, 1), normalize=True)
203 | writer.add_image('nae/sample_recon', img_grid, i)
204 |
205 | # to uint8 and save as array
206 | x_neg = (x_neg.permute(0,2,3,1).numpy() * 256.).clip(0, 255).astype('uint8')
207 | recon_neg = (recon_neg.permute(0,2,3,1).numpy() * 256.).clip(0, 255).astype('uint8')
208 | # save_image(img_grid, f'{writer.file_writer.get_logdir()}/nae_sample_{i}.png')
209 | np.save(f'{writer.file_writer.get_logdir()}/nae_neg_{i}.npy', x_neg)
210 | np.save(f'{writer.file_writer.get_logdir()}/nae_neg_recon_{i}.npy', recon_neg)
211 |
212 |
213 | def summary_val_nae(self, i, d_result):
214 | l_print_str = [f'Iter [{i:d}]']
215 | for key, val in d_result.items():
216 | self.writer.add_scalar(key, val, i)
217 | l_print_str.append(f'{key}: {val:.4f}')
218 | print_str = ' '.join(l_print_str)
219 | return print_str
220 |
221 |
222 | class NAETrainerV2(BaseTrainer):
223 | def train(self, model, opt, d_dataloaders, logger=None, logdir='', scheduler=None, clip_grad=None):
224 | cfg = self.cfg
225 | best_val_loss = np.inf
226 | time_meter = averageMeter()
227 | i = 0
228 | indist_train_loader = d_dataloaders['indist_train']
229 | indist_val_loader = d_dataloaders['indist_val']
230 | oodval_val_loader = d_dataloaders['ood_val']
231 | oodtarget_val_loader = d_dataloaders['ood_target']
232 | no_best_model_tolerance = 3
233 | no_best_model_count = 0
234 |
235 | n_ae_epoch = cfg.ae_epoch
236 | n_nae_epoch = cfg.nae_epoch
237 | ae_opt = Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=cfg.ae_lr)
238 | nae_opt_mode = cfg.get('nae_opt', 'fix_ae')
239 | if nae_opt_mode == 'fix_ae':
240 | if hasattr(model, 'varnetx'):
241 | nae_opt = Adam(list(model.varnetx.parameters()) + list(model.varnetz.parameters()), lr=cfg.nae_lr)
242 | else:
243 | nae_opt = Adam(model.varnet.parameters(), lr=cfg.nae_lr)
244 | elif nae_opt_mode == 'all':
245 | nae_opt = Adam(model.parameters(), lr=cfg.nae_lr)
246 | elif nae_opt_mode == 'fix_D':
247 | if hasattr(model, 'varnet'):
248 | nae_opt = Adam(list(model.varnet.parameters()) + list(model.encoder.parameters()), lr=cfg.nae_lr)
249 | else:
250 | nae_opt = Adam(model.encoder.parameters(), lr=cfg.nae_lr)
251 | elif nae_opt_mode == 'slow_E':
252 | nae_opt = Adam([{'params': model.varnet.parameters(), 'lr': cfg.nae_lr * 10},
253 | {'params': model.encoder.parameters(), 'lr': cfg.nae_lr}])
254 | elif nae_opt_mode == 'sgd':
255 | nae_opt = SGD(model.parameters(), lr=cfg.nae_lr)
256 | else:
257 | raise ValueError(f'Invalid nae_opt_mode {nae_opt_mode}')
258 |
259 | '''AE PASS'''
260 | if cfg.get('load_ae', None) is not None:
261 | n_ae_epoch = 0
262 | state_dict = torch.load(cfg['load_ae'])['model_state']
263 | encoder_state = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder')}
264 | decoder_state = {k.replace('decoder.', ''): v for k, v in state_dict.items() if k.startswith('decoder')}
265 | model.encoder.load_state_dict(encoder_state)
266 | model.decoder.load_state_dict(decoder_state)
267 | print(f'model loaded from {cfg["load_ae"]}')
268 |
269 | for i_epoch in range(n_ae_epoch):
270 |
271 | for x, y in indist_train_loader:
272 | i += 1
273 |
274 | model.train()
275 | x = x.to(self.device)
276 | y = y.to(self.device)
277 |
278 | start_ts = time.time()
279 | d_train = model.train_step_ae(x, ae_opt, clip_grad=clip_grad)
280 | time_meter.update(time.time() - start_ts)
281 | logger.process_iter_train(d_train)
282 |
283 | if i % cfg.print_interval == 0:
284 | d_train = logger.summary_train(i)
285 | print(f"Iter [{i:d}] Avg Loss: {d_train['loss/train_loss_']:.4f} Elapsed time: {time_meter.sum:.4f}")
286 | time_meter.reset()
287 |
288 | if i % cfg.val_interval == 0:
289 | model.eval()
290 | for val_x, val_y in indist_val_loader:
291 | val_x = val_x.to(self.device)
292 | val_y = val_y.to(self.device)
293 |
294 | d_val = model.validation_step_ae(val_x, y=val_y)
295 | logger.process_iter_val(d_val)
296 | d_val = logger.summary_val(i)
297 | val_loss = d_val['loss/val_loss_']
298 | print(d_val['print_str'])
299 | best_model = val_loss < best_val_loss
300 |
301 | if best_model:
302 | self.save_model(model, logdir, best=best_model, i_iter=i)
303 | print(f'Iter [{i:d}] best model saved {val_loss} <= {best_val_loss}')
304 | best_val_loss = val_loss
305 | else:
306 | no_best_model_count += 1
307 | if no_best_model_count > no_best_model_tolerance:
308 | break
309 |
310 | if i_epoch % cfg.get('save_interval_epoch', 1) == 0:
311 | self.save_model(model, logdir, best=False, i_iter=None, i_epoch=i_epoch)
312 | print(f'Epoch [{i_epoch:d}] model saved (interval: {cfg.get("save_interval_epoch", 1)})')
313 |
314 | if no_best_model_count > no_best_model_tolerance:
315 | print('terminating autoencoder training since the validation loss no longer decreases')
316 | break
317 |
318 | '''NAE PASS'''
319 | logger.has_val_loss = False
320 | if cfg.get('transfer_encoder', False):
321 | state_dict = model.encoder.state_dict()
322 | state_dict.pop('linear.weight'); state_dict.pop('linear.bias')
323 | transfer_result = model.varnet.load_state_dict(state_dict, strict=False)
324 | print('varnet initialized from encoder:', transfer_result)
325 | # load best autoencoder model
326 | # model.load_state_dict(torch.load(os.path.join(logdir, 'model_best.pkl'))['model_state'])
327 | # print('best model loaded')
328 | i = 0
329 | for i_epoch in tqdm(range(n_nae_epoch)):
330 | for x, _ in tqdm(indist_train_loader):
331 | i += 1
332 |
333 | x = x.cuda(self.device)
334 | if cfg.get('flatten', False):
335 | x = x.view(len(x), -1)
336 |
337 | d_result = model.train_step(x, nae_opt, clip_grad=clip_grad)
338 | logger.process_iter_train(d_result)
339 |
340 | if i % cfg.print_interval_nae == 1:
341 | input_img = make_grid(x.detach().cpu(), nrow=10, range=(0, 1))
342 | recon_img = make_grid(model.reconstruct(x).detach().cpu(), nrow=10, range=(0, 1))
343 | logger.d_train['input_img@'] = input_img
344 | logger.d_train['recon_img@'] = recon_img
345 | logger.summary_train(i)
346 |
347 | if i % cfg.val_interval == 1:
348 | '''AUC'''
349 | in_pred = self.predict(model, indist_val_loader, self.device)
350 | ood1_pred = self.predict(model, oodval_val_loader, self.device)
351 | auc_val = roc_btw_arr(ood1_pred, in_pred)
352 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device)
353 | auc_target = roc_btw_arr(ood2_pred, in_pred)
354 | d_result = {'nae/auc_val_': auc_val, 'nae/auc_target_': auc_target}
355 | logger.process_iter_val(d_result)
356 | print(logger.summary_val(i)['print_str'])
357 | torch.save({'model_state': model.state_dict()}, f'{logdir}/nae_iter_{i}.pkl')
358 |
359 | torch.save(model.state_dict(), f'{logdir}/nae_{i_epoch}.pkl')
360 | torch.save(model.state_dict(), f'{logdir}/nae.pkl')
361 |
362 | '''AUC'''
363 | in_pred = self.predict(model, indist_val_loader, self.device)
364 | ood1_pred = self.predict(model, oodval_val_loader, self.device)
365 | auc_val = roc_btw_arr(ood1_pred, in_pred)
366 | ood2_pred = self.predict(model, oodtarget_val_loader, self.device)
367 | auc_target = roc_btw_arr(ood2_pred, in_pred)
368 | d_result = {'nae/auc_val': auc_val, 'nae/auc_target': auc_target}
369 | print(d_result)
370 |
371 | return model, auc_val
372 |
373 | def predict(self, m, dl, device, flatten=False):
374 | """run prediction for the whole dataset"""
375 | l_result = []
376 | for x, _ in dl:
377 | with torch.no_grad():
378 | if flatten:
379 | x = x.view(len(x), -1)
380 | pred = m.predict(x.cuda(device)).detach().cpu()
381 | l_result.append(pred)
382 | return torch.cat(l_result)
383 |
384 |
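Both trainers implement the same two-stage schedule: (1) pre-train the autoencoder with train_step_ae until the validation loss stops improving, then (2) run the NAE energy training with train_step while tracking OOD AUROC. A condensed sketch, assuming the model, optimizers, loaders, and cfg from the trainers above exist:

for epoch in range(cfg.ae_epoch):            # stage 1: autoencoder pass
    for x, _ in indist_train_loader:
        model.train_step_ae(x.to(device), ae_opt)

for epoch in range(cfg.nae_epoch):           # stage 2: NAE pass
    for x, _ in indist_train_loader:
        model.train_step(x.to(device), nae_opt)
    # the monitored AUROC should approach 1 when OOD inputs get higher scores
    auc = roc_btw_arr(trainer.predict(model, oodval_val_loader, device),
                      trainer.predict(model, indist_val_loader, device))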
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc Utility functions
3 | """
4 | import os
5 | import sys
6 | import logging
7 | import datetime
8 | import numpy as np
9 | import torch
10 | import yaml
11 |
12 | from collections import OrderedDict
13 |
14 |
15 | def recursive_glob(rootdir=".", suffix=""):
16 | """Performs a recursive glob with the given suffix and rootdir
17 | :param rootdir: the root directory
18 | :param suffix: the suffix to be searched
19 | """
20 | return [
21 | os.path.join(looproot, filename)
22 | for looproot, _, filenames in os.walk(rootdir)
23 | for filename in filenames
24 | if filename.endswith(suffix)
25 | ]
26 |
27 |
28 | def alpha_blend(input_image, segmentation_mask, alpha=0.5):
29 | """Alpha blending utility to overlay RGB masks on RGB images
30 | :param input_image: np.ndarray with 3 channels
31 | :param segmentation_mask: np.ndarray with 3 channels
32 | :param alpha: blending weight in [0, 1]
33 | """
34 | # computed directly; the former np.zeros(input_image.size) pre-allocation was dead code (and 1-D)
35 | blended = input_image * alpha + segmentation_mask * (1 - alpha)
36 | return blended
37 |
38 |
39 | def convert_state_dict(state_dict):
40 | """Converts a state_dict saved from a DataParallel module to a normal
41 | module state_dict
42 | :param state_dict: the loaded DataParallel model_state
43 | """
44 | if not next(iter(state_dict)).startswith("module."):
45 | return state_dict # abort if dict is not a DataParallel model_state
46 | new_state_dict = OrderedDict()
47 | for k, v in state_dict.items():
48 | name = k[7:] # remove `module.`
49 | new_state_dict[name] = v
50 | return new_state_dict
51 |
52 | def convert_state_dict_remove_main(state_dict):
53 | if not next(iter(state_dict)).startswith("module."):
54 | return state_dict # not a DataParallel model_state; nothing to strip
55 | elif not next(iter(state_dict)).startswith("module.main"):
56 | return convert_state_dict(state_dict) # only a `module.` prefix; strip that instead
57 |
58 | new_state_dict = OrderedDict()
59 | for k, v in state_dict.items():
60 | name = k[12:] # remove `module.main`
61 | new_state_dict[name] = v
62 | return new_state_dict
63 |
64 |
65 | def get_logger(logdir):
66 | logger = logging.getLogger("ptsemseg")
67 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_")
68 | ts = ts.replace(":", "_").replace("-", "_")
69 | file_path = os.path.join(logdir, "run_{}.log".format(ts))
70 | hdlr = logging.FileHandler(file_path, mode='w')
71 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
72 | hdlr.setFormatter(formatter)
73 | logger.addHandler(hdlr)
74 | logger.setLevel(logging.INFO)
75 | return logger
76 |
77 |
78 | def check_config_validity(cfg):
79 | """checks validity of config"""
80 | assert 'model' in cfg
81 | assert 'training' in cfg
82 | assert 'validation' in cfg
83 | assert 'evaluation' in cfg
84 | assert 'data' in cfg
85 |
86 | data = cfg['data']
87 | assert 'dataset' in data
88 | assert 'path' in data
89 | assert 'n_classes' in data
90 | assert 'split' in data
91 | assert 'resize_factor' in data
92 | assert 'label' in data
93 |
94 | d = cfg['training']
95 | s = 'training'
96 | assert 'train_iters' in d, s
97 | assert 'val_interval' in d, s
98 | assert 'print_interval' in d, s
99 | assert 'optimizer' in d, s
100 | assert 'loss' in d, s
101 | assert 'batch_size' in d, s
102 | assert 'n_workers' in d, s
103 |
104 | d = cfg['validation']
105 | s = 'validation'
106 | assert 'batch_size' in d, s
107 | assert 'n_workers' in d, s
108 |
109 | d = cfg['evaluation']
110 | s = 'evaluation'
111 | assert 'batch_size' in d, s
112 | assert 'n_workers' in d, s
113 | assert 'num_crop_width' in d, s
114 | assert 'num_crop_height' in d, s
115 |
116 |
117 | print('config file validation passed')
118 |
119 |
120 | def add_uniform_noise(x):
121 | """uniform dequantization: maps a pixel value k/255 to a uniform sample in [k/256, (k+1)/256)
122 | :param x: torch.Tensor with values in [0, 1]
123 | """
124 | return x * 255. / 256. + torch.rand_like(x) / 256.
125 |
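A sanity check on the dequantization above: a pixel value k/255 lands in [k/256, (k+1)/256), so the 256 intervals tile [0, 1) without overlap. Using the function defined above:

import torch

x = torch.tensor([0., 128 / 255., 1.])  # discrete pixel values k/255
lo = x * 255. / 256.                     # left edge k/256
hi = lo + 1. / 256.                      # right edge (k+1)/256
sample = add_uniform_noise(x)
assert ((lo <= sample) & (sample < hi)).all()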
126 |
127 | import errno # used by mkdir_p below; os is already imported at the top
128 |
129 |
130 |
131 | # Recursive mkdir
132 | def mkdir_p(path):
133 | try:
134 | os.makedirs(path)
135 | except OSError as exc: # Python >2.5
136 | if exc.errno == errno.EEXIST and os.path.isdir(path):
137 | pass
138 | else:
139 | raise
140 |
141 |
142 | from sklearn.metrics import roc_auc_score
143 | def roc_btw_arr(arr1, arr2):
144 | true_label = np.concatenate([np.ones_like(arr1),
145 | np.zeros_like(arr2)])
146 | score = np.concatenate([arr1, arr2])
147 | return roc_auc_score(true_label, score)
148 |
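Note the label convention: arr1 is treated as the positive class, so passing OOD scores first yields an AUROC near 1 when the detector assigns higher scores to OOD inputs. A quick check with synthetic scores:

import numpy as np

rng = np.random.default_rng(0)
ood_scores = rng.normal(3.0, 1.0, 1000)  # higher = more anomalous
ind_scores = rng.normal(0.0, 1.0, 1000)
print(roc_btw_arr(ood_scores, ind_scores))  # ~1.0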
149 |
150 | def batch_run(m, dl, device, flatten=False, method='predict', input_type='first', no_grad=True, **kwargs):
151 | """
152 | m: model
153 | dl: dataloader
154 | device: device
155 | method: the name of a function to be called
156 | no_grad: use torch.no_grad if True.
157 | kwargs: additional argument for the method being called
158 | """
159 | method = getattr(m, method)
160 | l_result = []
161 | for batch in dl:
162 | if input_type == 'first':
163 | x = batch[0]
164 |
165 | if no_grad:
166 | with torch.no_grad():
167 | if flatten:
168 | x = x.view(len(x), -1)
169 | pred = method(x.cuda(device), **kwargs).detach().cpu()
170 | else:
171 | if flatten:
172 | x = x.view(len(x), -1)
173 | pred = method(x.cuda(device), **kwargs).detach().cpu()
174 |
175 | l_result.append(pred)
176 | return torch.cat(l_result)
177 |
178 |
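batch_run generalizes the trainers' predict helpers: it can run any named method of a model over a whole dataloader. A usage sketch, assuming `model` and `test_loader` exist and the model exposes the named methods:

scores = batch_run(model, test_loader, device=0, method='predict')
# any other method name works the same way, e.g. 'reconstruct', which the NAE models expose:
recon = batch_run(model, test_loader, device=0, method='reconstruct', no_grad=True)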
179 |
180 | def save_yaml(filename, text):
181 | """parse string as yaml then dump as a file"""
182 | with open(filename, 'w') as f:
183 | yaml.dump(yaml.safe_load(text), f, default_flow_style=False)
184 |
185 |
186 | def get_shuffled_idx(N, seed):
187 | """get randomly permuted indices"""
188 | rng = np.random.default_rng(seed)
189 | return rng.permutation(N)
190 |
191 |
192 | def search_params_intp(params):
193 | ret = {}
194 | for param in params.keys():
195 | # param : "train.batch"
196 | spl = param.split(".")
197 | if len(spl) == 2:
198 | if spl[0] in ret:
199 | ret[spl[0]][spl[1]] = params[param]
200 | else:
201 | ret[spl[0]] = {spl[1]: params[param]}
202 | # temp = {}
203 | # temp[spl[1]] = params[param]
204 | # ret[spl[0]] = temp
205 | elif len(spl) == 3:
206 | if spl[0] in ret:
207 | if spl[1] in ret[spl[0]]:
208 | ret[spl[0]][spl[1]][spl[2]] = params[param]
209 | else:
210 | ret[spl[0]][spl[1]] = {spl[2]: params[param]}
211 | else:
212 | ret[spl[0]] = {spl[1]: {spl[2]: params[param]}}
213 | elif len(spl) == 1:
214 | ret[spl[0]] = params[param]
215 | else:
216 | raise ValueError
217 | return ret
218 |
219 |
220 | def eprint(*args, **kwargs):
221 | print(*args, file=sys.stderr, **kwargs)
222 |
223 |
224 | def weight_norm(net):
225 | """computes L2 norm of weights of parameters"""
226 | norm = 0
227 | for param in net.parameters():
228 | norm += (param ** 2).sum()
229 | return norm
230 |
231 |
232 |
233 | ## additional command-line argument parsing
234 | def parse_arg_type(val):
235 | if val.isnumeric():
236 | return int(val)
237 | if val == 'True': return True
238 | if val == 'False': return False
239 | try:
240 | return float(val)
241 | except ValueError:
242 | return val
243 |
244 |
245 | def parse_unknown_args(l_args):
246 | """convert the list of unknown args into dict
247 | this does similar stuff to OmegaConf.from_cli()
248 | I may have reinvented the wheel..."""
249 | n_args = len(l_args) // 2
250 | kwargs = {}
251 | for i_args in range(n_args):
252 | key = l_args[i_args*2]
253 | val = l_args[i_args*2 + 1]
254 | assert '=' not in key, 'optional arguments should be separated by space'
255 | kwargs[key.strip('-')] = parse_arg_type(val)
256 | return kwargs
257 |
258 |
259 | def parse_nested_args(d_cmd_cfg):
260 | """produce a nested dictionary by parsing dot-separated keys
261 | e.g. {key1.key2 : 1} --> {key1: {key2: 1}}"""
262 | d_new_cfg = {}
263 | for key, val in d_cmd_cfg.items():
264 | l_key = key.split('.')
265 | d = d_new_cfg
266 | for i_key, each_key in enumerate(l_key):
267 | if i_key == len(l_key) - 1:
268 | d[each_key] = val
269 | else:
270 | if each_key not in d:
271 | d[each_key] = {}
272 | d = d[each_key]
273 | return d_new_cfg
274 |
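Taken together, the two parsers turn a flat, space-separated command-line tail into a nested dict ready for OmegaConf.merge; the keys and values below are illustrative:

unknown = ['--training.optimizer.lr', '0.0001', '--model.arch', 'nae']
flat = parse_unknown_args(unknown)
# {'training.optimizer.lr': 0.0001, 'model.arch': 'nae'}
nested = parse_nested_args(flat)
# {'training': {'optimizer': {'lr': 0.0001}}, 'model': {'arch': 'nae'}}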
--------------------------------------------------------------------------------