├── .gitignore ├── LICENSE.md ├── README.md ├── data ├── __init__.py ├── base.py ├── bsds300.py ├── caltech101.py ├── celeba.py ├── cifar10.py ├── download.py ├── frey.py ├── gas.py ├── hepmass.py ├── imagenet.py ├── miniboone.py ├── omniglot.py ├── plane.py ├── plane_test.py └── power.py ├── environment.yml ├── experiments ├── autils.py ├── cutils │ ├── __init__.py │ ├── io.py │ └── misc.py ├── face.py ├── faceeval.py ├── image_configs │ ├── cifar-10-5bit.json │ ├── cifar-10-8bit.json │ ├── imagenet-64-5bit.json │ └── imagenet-64-8bit.json ├── images.py ├── images_data.py ├── plane.py ├── planeeval.py ├── process_celeba_hq.py ├── process_imagenet.py ├── uci.py └── vae_.py ├── nde ├── __init__.py ├── distributions │ ├── __init__.py │ ├── base.py │ ├── discrete.py │ ├── discrete_test.py │ ├── mixture.py │ ├── normal.py │ ├── normal_test.py │ └── uniform.py ├── flows │ ├── __init__.py │ ├── autoregressive.py │ ├── autoregressive_test.py │ ├── base.py │ ├── base_test.py │ ├── realnvp.py │ └── realnvp_test.py └── transforms │ ├── __init__.py │ ├── autoregressive.py │ ├── autoregressive_test.py │ ├── base.py │ ├── base_test.py │ ├── conv.py │ ├── conv_test.py │ ├── coupling.py │ ├── coupling_test.py │ ├── linear.py │ ├── linear_test.py │ ├── lu.py │ ├── lu_test.py │ ├── made.py │ ├── made_test.py │ ├── nonlinearities.py │ ├── nonlinearities_test.py │ ├── normalization.py │ ├── normalization_test.py │ ├── orthogonal.py │ ├── orthogonal_test.py │ ├── permutations.py │ ├── permutations_test.py │ ├── qr.py │ ├── qr_test.py │ ├── reshape.py │ ├── reshape_test.py │ ├── splines │ ├── __init__.py │ ├── cubic.py │ ├── cubic_test.py │ ├── linear.py │ ├── linear_test.py │ ├── quadratic.py │ ├── quadratic_test.py │ ├── rational_quadratic.py │ └── rational_quadratic_test.py │ ├── standard.py │ ├── standard_test.py │ ├── svd.py │ ├── svd_test.py │ └── transform_test.py ├── nn ├── __init__.py ├── attention.py ├── conv.py ├── mlp.py ├── mlp_test.py ├── resnet.py └── unet.py ├── optim ├── __init__.py └── custom_lr_schedulers.py ├── utils ├── __init__.py ├── io.py ├── torchutils.py ├── torchutils_test.py └── typechecks.py └── vae ├── __init__.py ├── base.py └── base_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm 107 | .idea/ 108 | 109 | # latex files 110 | *.pdf 111 | *.dvi 112 | *.toc 113 | *.aux 114 | *.bbl 115 | *.blg 116 | *.log 117 | *.out 118 | *.bak 119 | *.idx 120 | *.ilg 121 | *.ind 122 | *.nav 123 | *.snm 124 | *.synctex.gz 125 | 126 | # images 127 | *.png 128 | 129 | # conor 130 | log/* 131 | out/* 132 | checkpoints/* 133 | datasets/* 134 | 135 | # tests 136 | *_trial_temp* 137 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Spline Flows 2 | 3 | A record of the code and experiments for the paper: 4 | 5 | > C. Durkan, A. Bekasov, I. Murray, G. Papamakarios, _Neural Spline Flows_, NeurIPS 2019. 6 | > [[arXiv]](https://arxiv.org/abs/1906.04032) [[bibtex]](https://gpapamak.github.io/bibtex/neural_spline_flows.bib) 7 | 8 | Work in this repository has now stopped. 
Please go to [nflows](https://github.com/bayesiains/nflows) for an updated and pip-installable normalizing flows framework for PyTorch. 9 | 10 | ## Dependencies 11 | 12 | See `environment.yml` for required Conda/pip packages, or use this to create a Conda environment with 13 | all dependencies: 14 | ```bash 15 | conda env create -f environment.yml 16 | ``` 17 | 18 | Tested with Python 3.5 and PyTorch 1.1. 19 | 20 | ## Data 21 | 22 | Data for density-estimation experiments is available at https://zenodo.org/record/1161203#.Wmtf_XVl8eN. 23 | 24 | Data for VAE and image-modeling experiments is downloaded automatically using either `torchvision` or custom 25 | data providers. 26 | 27 | ## Usage 28 | 29 | `DATAROOT` environment variable needs to be set before running experiments. 30 | 31 | ### 2D toy density experiments 32 | 33 | Use `experiments/face.py` or `experiments/plane.py`. 34 | 35 | ### Density-estimation experiments 36 | 37 | Use `experiments/uci.py`. 38 | 39 | ### VAE experiments 40 | 41 | Use `experiments/vae_.py`. 42 | 43 | ### Image-modeling experiments 44 | 45 | Use `experiments/images.py`. 46 | 47 | [Sacred](https://github.com/IDSIA/sacred) is used to organize image experiments. See the 48 | [documentation](http://sacred.readthedocs.org) for more information. 49 | 50 | `experiments/image_configs` contains .json configurations used for RQ-NSF (C) experiments. For baseline experiments use `coupling_layer_type='affine'`. 51 | 52 | For example, to run RQ-NSF (C) on CIFAR-10 8-bit: 53 | ```bash 54 | python experiments/images.py with experiments/image_configs/cifar-10-8bit.json 55 | ``` 56 | 57 | Corresponding affine baseline run: 58 | ```bash 59 | python experiments/images.py with experiments/image_configs/cifar-10-8bit.json coupling_layer_type='affine' 60 | ``` 61 | 62 | To evaluate on the test set: 63 | ```bash 64 | python experiments/images.py eval_on_test with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='' 65 | ``` 66 | 67 | To sample: 68 | ```bash 69 | python experiments/images.py sample with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='' 70 | ``` 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | get_uci_dataset_range, 3 | get_uci_dataset_max_abs_value, 4 | load_dataset, 5 | load_plane_dataset, 6 | load_face_dataset, 7 | batch_generator, 8 | InfiniteLoader, 9 | load_num_batches, 10 | UnlabelledImageFolder 11 | ) 12 | from data.download import download_file, download_file_from_google_drive 13 | 14 | from .plane import TestGridDataset 15 | 16 | from .celeba import CelebA, CelebAHQ, CelebAHQ64Fast 17 | 18 | from .imagenet import ImageNet32, ImageNet64, ImageNet64Fast 19 | 20 | from .cifar10 import CIFAR10Fast 21 | 22 | from .omniglot import OmniglotDataset 23 | -------------------------------------------------------------------------------- /data/bsds300.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import os 4 | import utils 5 | 6 | from matplotlib import pyplot as plt 7 | from torch.utils import data 8 | 9 | 10 | def load_bsds300(): 11 | path = os.path.join(utils.get_data_root(), 'bsds300', 'bsds300.hdf5') 12 | file = h5py.File(path, 'r') 13 | return file['train'], file['validation'], file['test'] 14 | 15 | 16 | class BSDS300Dataset(data.Dataset): 17 | def __init__(self, 
split='train', frac=None): 18 | splits = dict(zip( 19 | ('train', 'val', 'test'), 20 | load_bsds300() 21 | )) 22 | self.data = np.array(splits[split]).astype(np.float32) 23 | self.n, self.dim = self.data.shape 24 | if frac is not None: 25 | self.n = int(frac * self.n) 26 | 27 | def __getitem__(self, item): 28 | return self.data[item] 29 | 30 | def __len__(self): 31 | return self.n 32 | 33 | 34 | def main(): 35 | dataset = BSDS300Dataset(split='train') 36 | print(type(dataset.data)) 37 | print(dataset.data.shape) 38 | print(dataset.data.min(), dataset.data.max()) 39 | fig, axs = plt.subplots(8, 8, figsize=(10, 10), sharex=True, sharey=True) 40 | axs = axs.reshape(-1) 41 | for i, dimension in enumerate(dataset.data.T): 42 | axs[i].hist(dimension, bins=100) 43 | # plt.hist(dataset.data.reshape(-1), bins=250) 44 | plt.tight_layout() 45 | plt.show() 46 | print(len(dataset)) 47 | loader = data.DataLoader(dataset, batch_size=128, drop_last=True) 48 | print(len(loader)) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /data/caltech101.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesiains/nsf/eaa9377f75df1193025f6b2487524cf266874472/data/caltech101.py -------------------------------------------------------------------------------- /data/celeba.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import zipfile 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from data import UnlabelledImageFolder 10 | from data.download import download_file_from_google_drive 11 | 12 | 13 | class CelebA(UnlabelledImageFolder): 14 | """Unlabelled standard CelebA dataset, the aligned version.""" 15 | GOOGLE_DRIVE_FILE_ID = '0B7EVK8r0v71pZjFTYXZWM3FlRnM' 16 | ZIP_FILE_NAME = 'img_align_celeba.zip' 17 | 18 | def __init__(self, root, transform=None, download=False): 19 | if download: 20 | self.download(root) 21 | super(CelebA, self).__init__(os.path.join(root, self.img_dir), 22 | transform=transform) 23 | 24 | @property 25 | def img_dir(self): 26 | return 'img_align_celeba' 27 | 28 | def download(self, root): 29 | if os.path.isdir(os.path.join(root, self.img_dir)): 30 | return # Downloaded already 31 | 32 | os.makedirs(root, exist_ok=True) 33 | 34 | zip_file = os.path.join(root, self.ZIP_FILE_NAME) 35 | 36 | print('Downloading {}...'.format(os.path.basename(zip_file))) 37 | download_file_from_google_drive(self.GOOGLE_DRIVE_FILE_ID, zip_file) 38 | 39 | print('Extracting {}...'.format(os.path.basename(zip_file))) 40 | with zipfile.ZipFile(zip_file, 'r') as fp: 41 | fp.extractall(root) 42 | 43 | os.remove(zip_file) 44 | 45 | 46 | class CelebAHQ(CelebA): 47 | """Unlabelled high quality CelebA dataset with 256x256 images.""" 48 | GOOGLE_DRIVE_FILE_ID = '1psLniAvAvyDgJV8DBk7cvTZ9EasB_2tZ' 49 | ZIP_FILE_NAME = 'celeba-hq-256.zip' 50 | 51 | def __init__(self, root, transform=None, train=True, download=False): 52 | self.train = train 53 | super().__init__(root, transform=transform, download=download) 54 | 55 | @property 56 | def img_dir(self): 57 | if self.train: 58 | return 'celeba-hq-256/train-png' 59 | else: 60 | return 'celeba-hq-256/validation-png' 61 | 62 | class CelebAHQ64Fast(Dataset): 63 | GOOGLE_DRIVE_FILE_ID = { 64 | 'train': '1bcaqMKWzJ-2ca7HCQrUPwN61lfk115TO', 65 | 'valid': '1WfE64z9FNgOnLliGshUDuCrGBfJSwf-t' 66 | } 67 | 68 | NPY_NAME = { 69 | 'train': 
'train.npy', 70 | 'valid': 'valid.npy' 71 | } 72 | 73 | def __init__(self, root, train=True, download=False, transform=None): 74 | self.transform = transform 75 | self.root = root 76 | 77 | if download: 78 | self._download() 79 | 80 | tag = 'train' if train else 'valid' 81 | npy_data = np.load(os.path.join(root, self.NPY_NAME[tag])) 82 | self.data = torch.from_numpy(npy_data) # Shouldn't make a copy. 83 | 84 | def __getitem__(self, index): 85 | img = self.data[index, ...] 86 | 87 | if self.transform is not None: 88 | img = self.transform(img) 89 | 90 | # Add a bogus label to be compatible with standard image datasets. 91 | return img, torch.tensor([0.]) 92 | 93 | def __len__(self): 94 | return self.data.shape[0] 95 | 96 | def _download(self): 97 | os.makedirs(self.root, exist_ok=True) 98 | 99 | for tag in ['train','valid']: 100 | npy = os.path.join(self.root, self.NPY_NAME[tag]) 101 | if not os.path.isfile(npy): 102 | print('Downloading {}...'.format(self.NPY_NAME[tag])) 103 | download_file_from_google_drive(self.GOOGLE_DRIVE_FILE_ID[tag], npy) 104 | -------------------------------------------------------------------------------- /data/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import CIFAR10 3 | 4 | class CIFAR10Fast(CIFAR10): 5 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 6 | super().__init__(root, train, transform, target_transform, download) 7 | 8 | self.data = self.data.transpose((0, 3, 1, 2)) # HWC -> CHW. 9 | self.data = torch.from_numpy(self.data) # Shouldn't make a copy. 10 | assert self.data.dtype == torch.uint8 11 | 12 | def __getitem__(self, index): 13 | img, target = self.data[index], self.targets[index] 14 | 15 | # Don't convert to PIL Image, just to convert back later: slow. 16 | 17 | if self.transform is not None: 18 | img = self.transform(img) 19 | 20 | if self.target_transform is not None: 21 | target = self.target_transform(target) 22 | 23 | return img, target 24 | -------------------------------------------------------------------------------- /data/download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def download_file(url, dest): 5 | CHUNK_SIZE = 8192 6 | 7 | with requests.get(url, stream=True) as r: 8 | r.raise_for_status() 9 | with open(dest, 'wb') as f: 10 | for chunk in r.iter_content(chunk_size=CHUNK_SIZE): 11 | if chunk: 12 | f.write(chunk) 13 | 14 | 15 | def download_file_from_google_drive(id, dest): 16 | URL = "https://docs.google.com/uc?export=download" 17 | 18 | session = requests.Session() 19 | 20 | response = session.get(URL, params={'id': id}, stream=True) 21 | token = get_confirm_token(response) 22 | 23 | if token: 24 | params = {'id': id, 'confirm': token} 25 | response = session.get(URL, params=params, stream=True) 26 | 27 | save_response_content(response, dest) 28 | 29 | 30 | def get_confirm_token(response): 31 | for key, value in response.cookies.items(): 32 | if key.startswith('download_warning'): 33 | return value 34 | 35 | return None 36 | 37 | 38 | def save_response_content(response, destination): 39 | CHUNK_SIZE = 32768 40 | 41 | with open(destination, "wb") as f: 42 | for chunk in response.iter_content(CHUNK_SIZE): 43 | if chunk: # Filter out keep-alive new chunks. 
44 | f.write(chunk) 45 | -------------------------------------------------------------------------------- /data/frey.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesiains/nsf/eaa9377f75df1193025f6b2487524cf266874472/data/frey.py -------------------------------------------------------------------------------- /data/gas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import utils 5 | 6 | from matplotlib import pyplot as plt 7 | from torch.utils.data import Dataset 8 | 9 | 10 | def load_gas(): 11 | def load_data(file): 12 | data = pd.read_pickle(file) 13 | data.drop("Meth", axis=1, inplace=True) 14 | data.drop("Eth", axis=1, inplace=True) 15 | data.drop("Time", axis=1, inplace=True) 16 | return data 17 | 18 | def get_correlation_numbers(data): 19 | C = data.corr() 20 | A = C > 0.98 21 | B = A.sum(axis=1) 22 | return B 23 | 24 | def load_data_and_clean(file): 25 | data = load_data(file) 26 | B = get_correlation_numbers(data) 27 | 28 | while np.any(B > 1): 29 | col_to_remove = np.where(B > 1)[0][0] 30 | col_name = data.columns[col_to_remove] 31 | data.drop(col_name, axis=1, inplace=True) 32 | B = get_correlation_numbers(data) 33 | data = (data - data.mean()) / data.std() 34 | 35 | return data.values 36 | 37 | def load_data_and_clean_and_split(file): 38 | data = load_data_and_clean(file) 39 | N_test = int(0.1 * data.shape[0]) 40 | data_test = data[-N_test:] 41 | data_train = data[0:-N_test] 42 | N_validate = int(0.1 * data_train.shape[0]) 43 | data_validate = data_train[-N_validate:] 44 | data_train = data_train[0:-N_validate] 45 | 46 | return data_train, data_validate, data_test 47 | 48 | return load_data_and_clean_and_split( 49 | file=os.path.join(utils.get_data_root(), 'gas', 'ethylene_CO.pickle') 50 | ) 51 | 52 | 53 | def save_splits(): 54 | train, val, test = load_gas() 55 | splits = ( 56 | ('train', train), 57 | ('val', val), 58 | ('test', test) 59 | ) 60 | for split in splits: 61 | name, data = split 62 | file = os.path.join(utils.get_data_root(), 'gas', '{}.npy'.format(name)) 63 | np.save(file, data) 64 | 65 | 66 | class GasDataset(Dataset): 67 | def __init__(self, split='train', frac=None): 68 | path = os.path.join(utils.get_data_root(), 'gas', '{}.npy'.format(split)) 69 | self.data = np.load(path).astype(np.float32) 70 | self.n, self.dim = self.data.shape 71 | if frac is not None: 72 | self.n = int(frac * self.n) 73 | 74 | def __getitem__(self, item): 75 | return self.data[item] 76 | 77 | def __len__(self): 78 | return self.n 79 | 80 | 81 | def main(): 82 | dataset = GasDataset(split='train') 83 | print(type(dataset.data)) 84 | print(dataset.data.shape) 85 | print(dataset.data.min(), dataset.data.max()) 86 | print(np.where(dataset.data == dataset.data.max())) 87 | fig, axs = plt.subplots(3, 3, figsize=(10, 10), sharex=True, sharey=True) 88 | axs = axs.reshape(-1) 89 | for i, dimension in enumerate(dataset.data.T): 90 | print(i) 91 | axs[i].hist(dimension, bins=100) 92 | plt.tight_layout() 93 | plt.show() 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /data/hepmass.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import utils 5 | 6 | from collections import Counter 7 | from matplotlib import pyplot as plt 8 | from 
torch.utils.data import Dataset 9 | 10 | 11 | def load_hepmass(): 12 | def load_data(path): 13 | 14 | data_train = pd.read_csv(filepath_or_buffer=os.path.join(path, '1000_train.csv'), 15 | index_col=False) 16 | data_test = pd.read_csv(filepath_or_buffer=os.path.join(path, '1000_test.csv'), 17 | index_col=False) 18 | 19 | return data_train, data_test 20 | 21 | def load_data_no_discrete(path): 22 | """Loads the positive class examples from the first 10% of the dataset.""" 23 | data_train, data_test = load_data(path) 24 | 25 | # Gets rid of any background noise examples i.e. class label 0. 26 | data_train = data_train[data_train[data_train.columns[0]] == 1] 27 | data_train = data_train.drop(data_train.columns[0], axis=1) 28 | data_test = data_test[data_test[data_test.columns[0]] == 1] 29 | data_test = data_test.drop(data_test.columns[0], axis=1) 30 | # Because the data_ set is messed up! 31 | data_test = data_test.drop(data_test.columns[-1], axis=1) 32 | 33 | return data_train, data_test 34 | 35 | def load_data_no_discrete_normalised(path): 36 | 37 | data_train, data_test = load_data_no_discrete(path) 38 | mu = data_train.mean() 39 | s = data_train.std() 40 | data_train = (data_train - mu) / s 41 | data_test = (data_test - mu) / s 42 | 43 | return data_train, data_test 44 | 45 | def load_data_no_discrete_normalised_as_array(path): 46 | 47 | data_train, data_test = load_data_no_discrete_normalised(path) 48 | data_train, data_test = data_train.values, data_test.values 49 | 50 | i = 0 51 | # Remove any features that have too many re-occurring real values. 52 | features_to_remove = [] 53 | for feature in data_train.T: 54 | c = Counter(feature) 55 | max_count = np.array([v for k, v in sorted(c.items())])[0] 56 | if max_count > 5: 57 | features_to_remove.append(i) 58 | i += 1 59 | data_train = data_train[:, np.array( 60 | [i for i in range(data_train.shape[1]) if i not in features_to_remove])] 61 | data_test = data_test[:, np.array( 62 | [i for i in range(data_test.shape[1]) if i not in features_to_remove])] 63 | 64 | N = data_train.shape[0] 65 | N_validate = int(N * 0.1) 66 | data_validate = data_train[-N_validate:] 67 | data_train = data_train[0:-N_validate] 68 | 69 | return data_train, data_validate, data_test 70 | 71 | return load_data_no_discrete_normalised_as_array( 72 | path=os.path.join(utils.get_data_root(), 'hepmass') 73 | ) 74 | 75 | 76 | def save_splits(): 77 | train, val, test = load_hepmass() 78 | splits = ( 79 | ('train', train), 80 | ('val', val), 81 | ('test', test) 82 | ) 83 | for split in splits: 84 | name, data = split 85 | file = os.path.join(utils.get_data_root(), 'hepmass', '{}.npy'.format(name)) 86 | np.save(file, data) 87 | 88 | 89 | class HEPMASSDataset(Dataset): 90 | def __init__(self, split='train', frac=None): 91 | path = os.path.join(utils.get_data_root(), 'hepmass', '{}.npy'.format(split)) 92 | self.data = np.load(path).astype(np.float32) 93 | self.n, self.dim = self.data.shape 94 | if frac is not None: 95 | self.n = int(frac * self.n) 96 | 97 | def __getitem__(self, item): 98 | return self.data[item] 99 | 100 | def __len__(self): 101 | return self.n 102 | 103 | 104 | def main(): 105 | dataset = HEPMASSDataset(split='train') 106 | print(type(dataset.data)) 107 | print(dataset.data.shape) 108 | print(dataset.data.min(), dataset.data.max()) 109 | plt.hist(dataset.data.reshape(-1), bins=250) 110 | plt.show() 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | -------------------------------------------------------------------------------- /data/imagenet.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from data import UnlabelledImageFolder 7 | from data.download import download_file_from_google_drive 8 | import numpy as np 9 | 10 | class ImageNet32(UnlabelledImageFolder): 11 | GOOGLE_DRIVE_FILE_ID = '1TXsg8TP5SfsSL6Gk39McCkZu9rhSQnNX' 12 | UNZIPPED_DIR_NAME = 'imagenet32' 13 | UNZIPPED_TRAIN_SUBDIR = 'train_32x32' 14 | UNZIPPED_VAL_SUBDIR = 'valid_32x32' 15 | 16 | def __init__(self, root, train=True, download=False, transform=None): 17 | if download: 18 | self._download(root) 19 | 20 | img_dir = 'train' if train else 'val' 21 | super(ImageNet32, self).__init__(os.path.join(root, img_dir), 22 | transform=transform) 23 | 24 | def _download(self, root): 25 | if os.path.isdir(os.path.join(root, 'train')): 26 | return # Downloaded already 27 | 28 | os.makedirs(root, exist_ok=True) 29 | 30 | zip_file = os.path.join(root, self.UNZIPPED_DIR_NAME + '.zip') 31 | 32 | print('Downloading {}...'.format(os.path.basename(zip_file))) 33 | download_file_from_google_drive(self.GOOGLE_DRIVE_FILE_ID, zip_file) 34 | 35 | print('Extracting {}...'.format(os.path.basename(zip_file))) 36 | with zipfile.ZipFile(zip_file, 'r') as fp: 37 | fp.extractall(root) 38 | os.remove(zip_file) 39 | 40 | os.rename(os.path.join(root, self.UNZIPPED_DIR_NAME, self.UNZIPPED_TRAIN_SUBDIR), 41 | os.path.join(root, 'train')) 42 | os.rename(os.path.join(root, self.UNZIPPED_DIR_NAME, self.UNZIPPED_VAL_SUBDIR), 43 | os.path.join(root, 'val')) 44 | os.rmdir(os.path.join(root, self.UNZIPPED_DIR_NAME)) 45 | 46 | 47 | class ImageNet64(ImageNet32): 48 | GOOGLE_DRIVE_FILE_ID = '1NqpYnfluJz9A2INgsn16238FUfZh9QwR' 49 | UNZIPPED_DIR_NAME = 'imagenet64' 50 | UNZIPPED_TRAIN_SUBDIR = 'train_64x64' 51 | UNZIPPED_VAL_SUBDIR = 'valid_64x64' 52 | 53 | class ImageNet64Fast(Dataset): 54 | GOOGLE_DRIVE_FILE_ID = { 55 | 'train': '15AMmVSX-LDbP7LqC3R9Ns0RPbDI9301D', 56 | 'valid': '1Me8EhsSwWbQjQ91vRG1emkIOCgDKK4yC' 57 | } 58 | 59 | NPY_NAME = { 60 | 'train': 'train_64x64.npy', 61 | 'valid': 'valid_64x64.npy' 62 | } 63 | 64 | def __init__(self, root, train=True, download=False, transform=None): 65 | self.transform = transform 66 | self.root = root 67 | 68 | if download: 69 | self._download() 70 | 71 | tag = 'train' if train else 'valid' 72 | npy_data = np.load(os.path.join(root, self.NPY_NAME[tag])) 73 | self.data = torch.from_numpy(npy_data) # Shouldn't make a copy. 74 | 75 | def __getitem__(self, index): 76 | img = self.data[index, ...] 77 | 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | 81 | # Add a bogus label to be compatible with standard image datasets. 
82 | return img, torch.tensor([0.]) 83 | 84 | def __len__(self): 85 | return self.data.shape[0] 86 | 87 | def _download(self): 88 | os.makedirs(self.root, exist_ok=True) 89 | 90 | for tag in ['train','valid']: 91 | npy = os.path.join(self.root, self.NPY_NAME[tag]) 92 | if not os.path.isfile(npy): 93 | print('Downloading {}...'.format(self.NPY_NAME[tag])) 94 | download_file_from_google_drive(self.GOOGLE_DRIVE_FILE_ID[tag], npy) 95 | -------------------------------------------------------------------------------- /data/miniboone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import utils 4 | 5 | from matplotlib import pyplot as plt 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def load_miniboone(): 10 | def load_data(path): 11 | # NOTE: To remember how the pre-processing was done. 12 | # data_ = pd.read_csv(root_path, names=[str(x) for x in range(50)], delim_whitespace=True) 13 | # print data_.head() 14 | # data_ = data_.as_matrix() 15 | # # Remove some random outliers 16 | # indices = (data_[:, 0] < -100) 17 | # data_ = data_[~indices] 18 | # 19 | # i = 0 20 | # # Remove any features that have too many re-occuring real values. 21 | # features_to_remove = [] 22 | # for feature in data_.T: 23 | # c = Counter(feature) 24 | # max_count = np.array([v for k, v in sorted(c.iteritems())])[0] 25 | # if max_count > 5: 26 | # features_to_remove.append(i) 27 | # i += 1 28 | # data_ = data_[:, np.array([i for i in range(data_.shape[1]) if i not in features_to_remove])] 29 | # np.save("~/data_/miniboone/data_.npy", data_) 30 | 31 | data = np.load(path) 32 | N_test = int(0.1 * data.shape[0]) 33 | data_test = data[-N_test:] 34 | data = data[0:-N_test] 35 | N_validate = int(0.1 * data.shape[0]) 36 | data_validate = data[-N_validate:] 37 | data_train = data[0:-N_validate] 38 | 39 | return data_train, data_validate, data_test 40 | 41 | def load_data_normalised(path): 42 | data_train, data_validate, data_test = load_data(path) 43 | data = np.vstack((data_train, data_validate)) 44 | mu = data.mean(axis=0) 45 | s = data.std(axis=0) 46 | data_train = (data_train - mu) / s 47 | data_validate = (data_validate - mu) / s 48 | data_test = (data_test - mu) / s 49 | 50 | return data_train, data_validate, data_test 51 | 52 | return load_data_normalised( 53 | path=os.path.join(utils.get_data_root(), 'miniboone', 'data.npy') 54 | ) 55 | 56 | 57 | def save_splits(): 58 | train, val, test = load_miniboone() 59 | splits = ( 60 | ('train', train), 61 | ('val', val), 62 | ('test', test) 63 | ) 64 | for split in splits: 65 | name, data = split 66 | file = os.path.join(utils.get_data_root(), 'miniboone', '{}.npy'.format(name)) 67 | np.save(file, data) 68 | 69 | 70 | class MiniBooNEDataset(Dataset): 71 | def __init__(self, split='train', frac=None): 72 | path = os.path.join(utils.get_data_root(), 'miniboone', '{}.npy'.format(split)) 73 | self.data = np.load(path).astype(np.float32) 74 | self.n, self.dim = self.data.shape 75 | if frac is not None: 76 | self.n = int(frac * self.n) 77 | 78 | def __getitem__(self, item): 79 | return self.data[item] 80 | 81 | def __len__(self): 82 | return self.n 83 | 84 | 85 | def main(): 86 | dataset = MiniBooNEDataset(split='train') 87 | print(type(dataset.data)) 88 | print(dataset.data.shape) 89 | print(dataset.data.min(), dataset.data.max()) 90 | plt.hist(dataset.data.reshape(-1), bins=250) 91 | plt.show() 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | 
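
The UCI dataset modules above all follow the same pattern: a one-off `save_splits()` writes `train`/`val`/`test` `.npy` files under the data root, and a lightweight `Dataset` loads a single split into memory. A minimal usage sketch (not part of the repository), assuming the MiniBooNE archive linked in the README has been extracted so that `data.npy` sits under `$DATAROOT/miniboone/`, and that `utils.get_data_root()` resolves the `DATAROOT` environment variable:

```python
# Illustrative sketch only -- the DATAROOT layout described above is an assumption.
from torch.utils.data import DataLoader
from data.miniboone import save_splits, MiniBooNEDataset

save_splits()  # one-off: writes train.npy / val.npy / test.npy next to data.npy

train_dataset = MiniBooNEDataset(split='train')
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True)
batch = next(iter(train_loader))  # float32 tensor of shape [512, train_dataset.dim]
```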
-------------------------------------------------------------------------------- /data/omniglot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | 5 | import utils 6 | 7 | from PIL import Image 8 | from scipy.io import loadmat 9 | from torch.utils import data 10 | from torchvision import transforms as tvtransforms 11 | 12 | 13 | class OmniglotDataset(data.Dataset): 14 | def __init__(self, split='train', transform=None): 15 | self.transform = transform 16 | path = os.path.join(utils.get_data_root(), 'omniglot', 'omniglot.mat') 17 | rawdata = loadmat(path) 18 | 19 | if split == 'train': 20 | self.data = rawdata['data'].T.reshape(-1, 28, 28) 21 | self.targets = rawdata['target'].T 22 | elif split == 'test': 23 | self.data = rawdata['testdata'].T.reshape(-1, 28, 28) 24 | self.targets = rawdata['testtarget'].T 25 | else: 26 | raise ValueError 27 | 28 | def __getitem__(self, item): 29 | image, target = self.data[item], self.targets[item] 30 | image = Image.fromarray(image) 31 | if self.transform is not None: 32 | image = self.transform(image) 33 | return image, target 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | 39 | def main(): 40 | transform = tvtransforms.Compose([ 41 | tvtransforms.ToTensor(), 42 | tvtransforms.Lambda(torch.bernoulli) 43 | ]) 44 | dataset = OmniglotDataset(split='test', transform=transform) 45 | loader = data.DataLoader(dataset, batch_size=16) 46 | batch = next(iter(loader))[0] 47 | from matplotlib import pyplot as plt 48 | from experiments import cutils 49 | from torchvision.utils import make_grid 50 | fig, ax = plt.subplots(1, 1, figsize=(5, 5)) 51 | cutils.gridimshow(make_grid(batch, nrow=4), ax) 52 | plt.show() 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /data/plane_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the 2-dim plane datasets.""" 2 | 3 | import torchtestcase 4 | import unittest 5 | 6 | from data import plane 7 | 8 | 9 | class PlaneDatasetTest(torchtestcase.TorchTestCase): 10 | 11 | def test_all(self): 12 | num_points = 40 13 | constructors = [ 14 | plane.GaussianDataset, 15 | plane.CrescentDataset, 16 | plane.CrescentCubedDataset, 17 | plane.SineWaveDataset, 18 | plane.AbsDataset, 19 | plane.SignDataset, 20 | plane.FourCircles, 21 | plane.DiamondDataset, 22 | plane.TwoSpiralsDataset, 23 | plane.CheckerboardDataset, 24 | ] 25 | for constructor in constructors: 26 | for flip_axes in [True, False]: 27 | with self.subTest(constructor=constructor, flip_axes=flip_axes): 28 | dataset = constructor(num_points=num_points, flip_axes=flip_axes) 29 | dataset.reset() 30 | self.assertEqual(len(dataset), num_points) 31 | self.assertEqual(dataset[0], dataset.data[0]) 32 | 33 | 34 | class FaceDatasetTest(torchtestcase.TorchTestCase): 35 | 36 | def test_all(self): 37 | num_points = 40 38 | for name in ['einstein', 'boole', 'bayes']: 39 | for flip_axes in [True, False]: 40 | with self.subTest(name=name, flip_axes=flip_axes): 41 | dataset = plane.FaceDataset( 42 | num_points=num_points, flip_axes=flip_axes, name=name) 43 | dataset.reset() 44 | self.assertEqual(len(dataset), num_points) 45 | self.assertEqual(dataset[0], dataset.data[0]) 46 | 47 | 48 | class TestGridDatasetTest(torchtestcase.TorchTestCase): 49 | 50 | def test_all(self): 51 | num_points = 40 52 | bounds = [[-1, 1]] * 2 53 | for flip_axes in [True, 
False]: 54 | with self.subTest(flip_axes=flip_axes): 55 | dataset = plane.TestGridDataset( 56 | num_points_per_axis=num_points, bounds=bounds) 57 | dataset.reset() 58 | self.assertEqual(len(dataset), num_points ** 2) 59 | self.assertEqual(dataset[0], dataset.data[0]) 60 | 61 | 62 | if __name__ == '__main__': 63 | unittest.main() 64 | -------------------------------------------------------------------------------- /data/power.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import utils 4 | 5 | from matplotlib import pyplot as plt 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def load_power(): 10 | def load_data(): 11 | file = os.path.join(utils.get_data_root(), 'power', 'data.npy') 12 | return np.load(file) 13 | 14 | def load_data_split_with_noise(): 15 | rng = np.random.RandomState(42) 16 | 17 | data = load_data() 18 | rng.shuffle(data) 19 | N = data.shape[0] 20 | 21 | data = np.delete(data, 3, axis=1) 22 | data = np.delete(data, 1, axis=1) 23 | ############################ 24 | # Add noise 25 | ############################ 26 | # global_intensity_noise = 0.1*rng.rand(N, 1) 27 | voltage_noise = 0.01 * rng.rand(N, 1) 28 | # grp_noise = 0.001*rng.rand(N, 1) 29 | gap_noise = 0.001 * rng.rand(N, 1) 30 | sm_noise = rng.rand(N, 3) 31 | time_noise = np.zeros((N, 1)) 32 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, global_intensity_noise, sm_noise, time_noise)) 33 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, sm_noise, time_noise)) 34 | noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) 35 | data += noise 36 | 37 | N_test = int(0.1 * data.shape[0]) 38 | data_test = data[-N_test:] 39 | data = data[0:-N_test] 40 | N_validate = int(0.1 * data.shape[0]) 41 | data_validate = data[-N_validate:] 42 | data_train = data[0:-N_validate] 43 | 44 | return data_train, data_validate, data_test 45 | 46 | def load_data_normalised(): 47 | data_train, data_validate, data_test = load_data_split_with_noise() 48 | data = np.vstack((data_train, data_validate)) 49 | mu = data.mean(axis=0) 50 | s = data.std(axis=0) 51 | data_train = (data_train - mu) / s 52 | data_validate = (data_validate - mu) / s 53 | data_test = (data_test - mu) / s 54 | 55 | return data_train, data_validate, data_test 56 | 57 | return load_data_normalised() 58 | 59 | 60 | def save_splits(): 61 | train, val, test = load_power() 62 | splits = ( 63 | ('train', train), 64 | ('val', val), 65 | ('test', test) 66 | ) 67 | for split in splits: 68 | name, data = split 69 | file = os.path.join(utils.get_data_root(), 'power', '{}.npy'.format(name)) 70 | np.save(file, data) 71 | 72 | 73 | def print_shape_info(): 74 | train, val, test = load_power() 75 | print(train.shape, val.shape, test.shape) 76 | 77 | 78 | class PowerDataset(Dataset): 79 | def __init__(self, split='train', frac=None): 80 | path = os.path.join(utils.get_data_root(), 'power', '{}.npy'.format(split)) 81 | self.data = np.load(path).astype(np.float32) 82 | self.n, self.dim = self.data.shape 83 | if frac is not None: 84 | self.n = int(frac * self.n) 85 | 86 | def __getitem__(self, item): 87 | return self.data[item] 88 | 89 | def __len__(self): 90 | return self.n 91 | 92 | 93 | def main(): 94 | dataset = PowerDataset(split='train') 95 | print(type(dataset.data)) 96 | print(dataset.data.shape) 97 | print(dataset.data.min(), dataset.data.max()) 98 | plt.hist(dataset.data.reshape(-1), bins=250) 99 | plt.show() 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | 
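
`load_power()` above standardizes every split using the mean and standard deviation of the concatenated train and validation data, after adding small amounts of uniform noise to some columns. A quick, illustrative check of that invariant (assumes the raw `data.npy` from the Zenodo archive is present under `$DATAROOT/power/`):

```python
# Illustrative sketch only -- requires the raw POWER data.npy to be in place.
import numpy as np
from data.power import load_power

train, val, test = load_power()
print(train.shape, val.shape, test.shape)  # test is 10% of rows, val is 10% of the rest

stats_data = np.vstack((train, val))  # the same arrays used to compute mu and s above
print(stats_data.mean(axis=0))        # ~0 in every dimension
print(stats_data.std(axis=0))         # ~1 in every dimension
```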
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nsf 2 | channels: 3 | - defaults 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - h5py 8 | - matplotlib 9 | - pip 10 | - python 11 | - scikit-image 12 | - tqdm 13 | - numpy 14 | - scipy 15 | - pytorch 16 | - cudatoolkit=9.0 17 | - torchvision 18 | - pandas 19 | - pip: 20 | - requests 21 | - sacred 22 | - tensorflow 23 | - tensorboard 24 | - tensorboardx 25 | - pymongo 26 | -------------------------------------------------------------------------------- /experiments/cutils/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import * 2 | 3 | from .misc import gridimshow -------------------------------------------------------------------------------- /experiments/cutils/io.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import time 3 | 4 | 5 | def on_cluster(): 6 | hostname = socket.gethostname() 7 | return False if hostname == 'coldingham' else True 8 | 9 | 10 | def get_timestamp(): 11 | formatted_time = time.strftime('%d-%b-%y||%H:%M:%S') 12 | return formatted_time 13 | 14 | 15 | def get_project_root(): 16 | if on_cluster(): 17 | path = '/home/s1638128/deployment/decomposition-flows' 18 | else: 19 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows' 20 | return path 21 | 22 | 23 | def get_log_root(): 24 | if on_cluster(): 25 | path = '/home/s1638128/tmp/decomposition-flows/log' 26 | else: 27 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows/log' 28 | return path 29 | 30 | 31 | def get_data_root(): 32 | if on_cluster(): 33 | path = '/home/s1638128/deployment/decomposition-flows/datasets' 34 | else: 35 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows/datasets' 36 | return path 37 | 38 | 39 | def get_checkpoint_root(from_cluster=False): 40 | if on_cluster(): 41 | path = '/home/s1638128/tmp/decomposition-flows/checkpoints' 42 | else: 43 | if from_cluster: 44 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows/checkpoints/cluster' 45 | else: 46 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows/checkpoints' 47 | return path 48 | 49 | 50 | def get_output_root(): 51 | if on_cluster(): 52 | path = '/home/s1638128/tmp/decomposition-flows/out' 53 | else: 54 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows/out' 55 | return path 56 | 57 | 58 | def get_final_root(): 59 | if on_cluster(): 60 | path = '/home/s1638128/deployment/decomposition-flows/final' 61 | else: 62 | path = '/home/conor/Dropbox/phd/projects/decomposition-flows/final' 63 | return path 64 | 65 | 66 | def main(): 67 | print(get_timestamp()) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /experiments/cutils/misc.py: -------------------------------------------------------------------------------- 1 | import utils 2 | 3 | 4 | def gridimshow(image, ax): 5 | if image.shape[0] == 1: 6 | image = utils.tensor2numpy(image[0, ...]) 7 | ax.imshow(1 - image, cmap='Greys') 8 | else: 9 | image = utils.tensor2numpy(image.permute(1, 2, 0)) 10 | ax.imshow(image) 11 | ax.spines['top'].set_visible(False) 12 | ax.spines['right'].set_visible(False) 13 | ax.spines['left'].set_visible(False) 14 | ax.spines['bottom'].set_visible(False) 15 | ax.tick_params(axis='both', length=0) 16 | 
ax.set_xticklabels('') 17 | ax.set_yticklabels('') 18 | -------------------------------------------------------------------------------- /experiments/faceeval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import torch 5 | import os 6 | 7 | from matplotlib import cm, pyplot as plt 8 | from tensorboardX import SummaryWriter 9 | from torch import optim 10 | from torch.utils import data 11 | from tqdm import tqdm 12 | 13 | import data as data_ 14 | import nn as nn_ 15 | import utils 16 | 17 | from experiments import cutils 18 | from nde import distributions, flows, transforms 19 | 20 | dataset_name = 'shannon' 21 | path = os.path.join(cutils.get_final_root(), '{}-final.json'.format(dataset_name)) 22 | with open(path) as file: 23 | dictionary = json.load(file) 24 | args = argparse.Namespace(**dictionary) 25 | 26 | torch.manual_seed(args.seed) 27 | np.random.seed(args.seed) 28 | 29 | if args.use_gpu: 30 | device = torch.device('cuda') 31 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 32 | else: 33 | device = torch.device('cpu') 34 | 35 | # create data 36 | train_dataset = data_.load_face_dataset( 37 | name=args.dataset_name, 38 | num_points=args.n_data_points 39 | ) 40 | train_loader = data_.InfiniteLoader( 41 | dataset=train_dataset, 42 | batch_size=args.batch_size, 43 | shuffle=True, 44 | drop_last=True, 45 | num_epochs=None 46 | ) 47 | dim = 2 48 | 49 | # Generate test grid data 50 | num_points_per_axis = 512 51 | bounds = np.array([ 52 | [1e-3, 1 - 1e-3], 53 | [1e-3, 1 - 1e-3] 54 | ]) 55 | grid_dataset = data_.TestGridDataset( 56 | num_points_per_axis=num_points_per_axis, 57 | bounds=bounds 58 | ) 59 | grid_loader = data.DataLoader( 60 | dataset=grid_dataset, 61 | batch_size=1000, 62 | drop_last=False 63 | ) 64 | 65 | # create model 66 | distribution = distributions.TweakedUniform( 67 | low=torch.zeros(dim), 68 | high=torch.ones(dim) 69 | ) 70 | transform = transforms.CompositeTransform([ 71 | transforms.CompositeTransform([ 72 | transforms.PiecewiseRationalQuadraticCouplingTransform( 73 | mask=utils.create_alternating_binary_mask( 74 | features=dim, 75 | even=(i % 2 == 0) 76 | ), 77 | transform_net_create_fn=lambda in_features, out_features: 78 | nn_.ResidualNet( 79 | in_features=in_features, 80 | out_features=out_features, 81 | hidden_features=args.hidden_features, 82 | num_blocks=args.num_transform_blocks, 83 | dropout_probability=args.dropout_probability, 84 | use_batch_norm=args.use_batch_norm 85 | ), 86 | num_bins=args.num_bins, 87 | tails=None, 88 | tail_bound=1, 89 | # apply_unconditional_transform=args.apply_unconditional_transform, 90 | min_bin_width=args.min_bin_width 91 | ), 92 | ]) for i in range(args.num_flow_steps) 93 | ]) 94 | 95 | flow = flows.Flow(transform, distribution).to(device) 96 | path = os.path.join(cutils.get_final_root(), '{}-final.t'.format(dataset_name)) 97 | state_dict = torch.load(path) 98 | flow.load_state_dict(state_dict) 99 | flow.eval() 100 | 101 | n_params = utils.get_num_parameters(flow) 102 | print('There are {} trainable parameters in this model.'.format(n_params)) 103 | 104 | log_density_np = [] 105 | for batch in grid_loader: 106 | batch = batch.to(device) 107 | log_density = flow.log_prob(batch) 108 | log_density_np = np.concatenate( 109 | (log_density_np, utils.tensor2numpy(log_density)) 110 | ) 111 | 112 | vmax = np.exp(log_density_np).max() 113 | cmap = cm.magma 114 | # plot data 115 | figure, axes = plt.subplots(1, 1, figsize=(2.5, 
2.5)) 116 | axes.hist2d(utils.tensor2numpy(train_dataset.data[:, 0]), 117 | utils.tensor2numpy(train_dataset.data[:, 1]), 118 | range=bounds, bins=512, cmap=cmap, rasterized=False, normed=True, vmax=1.1*vmax) 119 | axes.set_xlim(bounds[0]) 120 | axes.set_ylim(bounds[1]) 121 | axes.set_xticks([]) 122 | axes.set_yticks([]) 123 | plt.tight_layout() 124 | path = os.path.join(cutils.get_output_root(), '{}-data.png'.format(dataset_name)) 125 | plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=300) 126 | plt.close() 127 | 128 | # plot density 129 | figure, axes = plt.subplots(1, 1, figsize=(2.5, 2.5)) 130 | axes.pcolormesh(grid_dataset.X, grid_dataset.Y, 131 | np.exp(log_density_np).reshape(grid_dataset.X.shape), 132 | cmap=cmap, vmin=0, vmax=vmax) 133 | axes.set_xlim(bounds[0]) 134 | axes.set_ylim(bounds[1]) 135 | axes.set_xticks([]) 136 | axes.set_yticks([]) 137 | plt.tight_layout() 138 | path = os.path.join(cutils.get_output_root(), '{}-density.png'.format(dataset_name)) 139 | plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=300) 140 | plt.close() 141 | 142 | # plot samples 143 | figure, axes = plt.subplots(1, 1, figsize=(2.5, 2.5)) 144 | with torch.no_grad(): 145 | samples = utils.tensor2numpy( 146 | flow.sample(num_samples=int(1e6), batch_size=int(1e5))) 147 | axes.hist2d(samples[:, 0], samples[:, 1], 148 | range=bounds, bins=512, cmap=cmap, rasterized=False, normed=True, vmax=1.1*vmax) 149 | axes.set_xlim(bounds[0]) 150 | axes.set_ylim(bounds[1]) 151 | axes.set_xticks([]) 152 | axes.set_yticks([]) 153 | plt.tight_layout() 154 | path = os.path.join(cutils.get_output_root(), '{}-samples.png'.format(dataset_name)) 155 | plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=300) 156 | plt.close() 157 | 158 | -------------------------------------------------------------------------------- /experiments/image_configs/cifar-10-5bit.json: -------------------------------------------------------------------------------- 1 | { 2 | "actnorm": true, 3 | "alpha": 0.05, 4 | "batch_size": 512, 5 | "cosine_annealing": true, 6 | "coupling_layer_type": "rational_quadratic_spline", 7 | "dataset": "cifar-10-fast", 8 | "dropout_prob": 0.2, 9 | "eta_min": 0.0, 10 | "flow_checkpoint": null, 11 | "hidden_channels": 64, 12 | "intervals": { 13 | "eval": 500, 14 | "log": 10, 15 | "reconstruct": 500, 16 | "sample": 250, 17 | "save": 500 18 | }, 19 | "learning_rate": 0.0005, 20 | "levels": 3, 21 | "multi_gpu": false, 22 | "multi_scale": false, 23 | "num_bits": 5, 24 | "num_res_blocks": 3, 25 | "num_steps": 100000, 26 | "num_workers": 1, 27 | "optimizer_checkpoint": null, 28 | "pad": 2, 29 | "preprocessing": "glow", 30 | "resnet_batchnorm": true, 31 | "run_descr": "5-bit", 32 | "seed": 656693568, 33 | "spline_params": { 34 | "apply_unconditional_transform": false, 35 | "min_bin_height": 0.001, 36 | "min_bin_width": 0.001, 37 | "min_derivative": 0.001, 38 | "num_bins": 2, 39 | "tail_bound": 3.0 40 | }, 41 | "start_step": 0, 42 | "steps_per_level": 7, 43 | "temperatures": [ 44 | 0.5, 45 | 0.75, 46 | 1.0 47 | ], 48 | "use_gpu": true, 49 | "use_resnet": true, 50 | "valid_frac": 0.01, 51 | "warmup_fraction": 0.0 52 | } 53 | -------------------------------------------------------------------------------- /experiments/image_configs/cifar-10-8bit.json: -------------------------------------------------------------------------------- 1 | { 2 | "actnorm": true, 3 | "alpha": 0.05, 4 | "batch_size": 512, 5 | "cosine_annealing": true, 6 | "coupling_layer_type": "rational_quadratic_spline", 7 | "dataset": 
"cifar-10-fast", 8 | "dropout_prob": 0.2, 9 | "eta_min": 0.0, 10 | "flow_checkpoint": null, 11 | "hidden_channels": 96, 12 | "intervals": { 13 | "eval": 500, 14 | "log": 10, 15 | "reconstruct": 500, 16 | "sample": 250, 17 | "save": 500 18 | }, 19 | "learning_rate": 0.0005, 20 | "levels": 3, 21 | "multi_gpu": false, 22 | "multi_scale": false, 23 | "num_bits": 8, 24 | "num_res_blocks": 3, 25 | "num_steps": 200000, 26 | "num_workers": 1, 27 | "optimizer_checkpoint": null, 28 | "pad": 2, 29 | "preprocessing": "glow", 30 | "resnet_batchnorm": true, 31 | "run_descr": "8-bit", 32 | "seed": 656693568, 33 | "spline_params": { 34 | "apply_unconditional_transform": false, 35 | "min_bin_height": 0.001, 36 | "min_bin_width": 0.001, 37 | "min_derivative": 0.001, 38 | "num_bins": 4, 39 | "tail_bound": 3.0 40 | }, 41 | "start_step": 0, 42 | "steps_per_level": 7, 43 | "temperatures": [ 44 | 0.5, 45 | 0.75, 46 | 1.0 47 | ], 48 | "use_gpu": true, 49 | "use_resnet": true, 50 | "valid_frac": 0.01, 51 | "warmup_fraction": 0.0 52 | } 53 | -------------------------------------------------------------------------------- /experiments/image_configs/imagenet-64-5bit.json: -------------------------------------------------------------------------------- 1 | { 2 | "actnorm": true, 3 | "alpha": 0.05, 4 | "batch_size": 256, 5 | "cosine_annealing": true, 6 | "coupling_layer_type": "rational_quadratic_spline", 7 | "dataset": "imagenet-64-fast", 8 | "dropout_prob": 0.1, 9 | "eta_min": 0.0, 10 | "flow_checkpoint": null, 11 | "hidden_channels": 96, 12 | "intervals": { 13 | "eval": 500, 14 | "log": 10, 15 | "reconstruct": 500, 16 | "sample": 250, 17 | "save": 500 18 | }, 19 | "learning_rate": 0.0005, 20 | "levels": 4, 21 | "multi_gpu": true, 22 | "multi_scale": true, 23 | "num_bits": 5, 24 | "num_res_blocks": 3, 25 | "num_steps": 100000, 26 | "num_workers": 1, 27 | "optimizer_checkpoint": null, 28 | "pad": 2, 29 | "preprocessing": "glow", 30 | "resnet_batchnorm": true, 31 | "run_descr": "5-bit", 32 | "seed": 656693568, 33 | "spline_params": { 34 | "apply_unconditional_transform": false, 35 | "min_bin_height": 0.001, 36 | "min_bin_width": 0.001, 37 | "min_derivative": 0.001, 38 | "num_bins": 8, 39 | "tail_bound": 3.0 40 | }, 41 | "start_step": 0, 42 | "steps_per_level": 7, 43 | "temperatures": [ 44 | 0.5, 45 | 0.75, 46 | 1.0 47 | ], 48 | "use_gpu": true, 49 | "use_resnet": true, 50 | "valid_frac": 0.01, 51 | "warmup_fraction": 0.0 52 | } 53 | -------------------------------------------------------------------------------- /experiments/image_configs/imagenet-64-8bit.json: -------------------------------------------------------------------------------- 1 | { 2 | "actnorm": true, 3 | "alpha": 0.05, 4 | "batch_size": 256, 5 | "cosine_annealing": true, 6 | "coupling_layer_type": "rational_quadratic_spline", 7 | "dataset": "imagenet-64-fast", 8 | "dropout_prob": 0.0, 9 | "eta_min": 0.0, 10 | "flow_checkpoint": null, 11 | "hidden_channels": 96, 12 | "intervals": { 13 | "eval": 500, 14 | "log": 10, 15 | "reconstruct": 500, 16 | "sample": 250, 17 | "save": 500 18 | }, 19 | "learning_rate": 0.0005, 20 | "levels": 4, 21 | "multi_gpu": true, 22 | "multi_scale": true, 23 | "num_bits": 8, 24 | "num_res_blocks": 3, 25 | "num_steps": 200000, 26 | "num_workers": 1, 27 | "optimizer_checkpoint": null, 28 | "pad": 2, 29 | "preprocessing": "glow", 30 | "resnet_batchnorm": true, 31 | "run_descr": "8-bit", 32 | "seed": 656693568, 33 | "spline_params": { 34 | "apply_unconditional_transform": false, 35 | "min_bin_height": 0.001, 36 | "min_bin_width": 
0.001, 37 | "min_derivative": 0.001, 38 | "num_bins": 8, 39 | "tail_bound": 3.0 40 | }, 41 | "start_step": 0, 42 | "steps_per_level": 7, 43 | "temperatures": [ 44 | 0.5, 45 | 0.75, 46 | 1.0 47 | ], 48 | "use_gpu": true, 49 | "use_resnet": true, 50 | "valid_frac": 0.01, 51 | "warmup_fraction": 0.0 52 | } 53 | -------------------------------------------------------------------------------- /experiments/process_celeba_hq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | from scipy.misc import imresize 4 | import os 5 | from os import listdir 6 | from os.path import isfile, join 7 | from tqdm import tqdm 8 | 9 | def process_images(*, path, outfile): 10 | assert os.path.exists(path), "Input path doesn't exist" 11 | files = [f for f in listdir(path) if isfile(join(path, f))] 12 | print('Number of valid images is:', len(files)) 13 | imgs = [] 14 | for i in tqdm(range(len(files))): 15 | img = scipy.ndimage.imread(join(path, files[i])) 16 | assert img.shape == (256, 256, 3) 17 | assert isinstance(img, np.ndarray) 18 | 19 | img = img.astype('uint8') 20 | img = imresize(img, (64, 64)) 21 | 22 | assert img.dtype == np.uint8 23 | assert img.shape == (64, 64, 3) 24 | 25 | # HWC -> CHW, for use in PyTorch 26 | img = img.transpose(2, 0, 1) 27 | assert img.shape == (3, 64, 64) 28 | 29 | imgs.append(img) 30 | 31 | imgs = np.asarray(imgs).astype('uint8') 32 | assert imgs.shape[1:] == (3, 64, 64) 33 | 34 | np.save(outfile, imgs) 35 | 36 | if __name__ == '__main__': 37 | process_images(path='./train-png', outfile='./train.npy') 38 | process_images(path='./validation-png', outfile='./valid.npy') 39 | -------------------------------------------------------------------------------- /experiments/process_imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run the following commands before running this file 3 | wget http://image-net.org/small/train_64x64.tar 4 | wget http://image-net.org/small/valid_64x64.tar 5 | tar -xvf train_64x64.tar 6 | tar -xvf valid_64x64.tar 7 | """ 8 | 9 | import numpy as np 10 | import scipy.ndimage 11 | import os 12 | from os import listdir 13 | from os.path import isfile, join 14 | from tqdm import tqdm 15 | 16 | def process_images(*, path, outfile): 17 | assert os.path.exists(path), "Input path doesn't exist" 18 | files = [f for f in listdir(path) if isfile(join(path, f))] 19 | print('Number of valid images is:', len(files)) 20 | imgs = [] 21 | for i in tqdm(range(len(files))): 22 | img = scipy.ndimage.imread(join(path, files[i])) 23 | 24 | assert isinstance(img, np.ndarray) 25 | 26 | img = img.astype('uint8') 27 | 28 | # HWC -> CHW, for use in PyTorch 29 | img = img.transpose(2, 0, 1) 30 | assert img.shape == (3, 64, 64) 31 | 32 | imgs.append(img) 33 | 34 | imgs = np.asarray(imgs).astype('uint8') 35 | assert imgs.shape[1:] == (3, 64, 64) 36 | 37 | np.save(outfile, imgs) 38 | 39 | if __name__ == '__main__': 40 | process_images(path='./train_64x64', outfile='./train_64x64.npy') 41 | process_images(path='./valid_64x64', outfile='./valid_64x64.npy') 42 | -------------------------------------------------------------------------------- /nde/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesiains/nsf/eaa9377f75df1193025f6b2487524cf266874472/nde/__init__.py -------------------------------------------------------------------------------- /nde/distributions/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .base import Distribution 2 | from .base import NoMeanException 3 | 4 | from .normal import StandardNormal 5 | from .normal import ConditionalDiagonalNormal 6 | 7 | from .discrete import ConditionalIndependentBernoulli 8 | -------------------------------------------------------------------------------- /nde/distributions/base.py: -------------------------------------------------------------------------------- 1 | """Basic definitions for the distributions module.""" 2 | 3 | import torch 4 | 5 | from torch import nn 6 | 7 | import utils 8 | 9 | 10 | class NoMeanException(Exception): 11 | """Exception to be thrown when a mean function doesn't exist.""" 12 | pass 13 | 14 | 15 | class Distribution(nn.Module): 16 | """Base class for all distribution objects.""" 17 | 18 | def forward(self, *args): 19 | raise RuntimeError('Forward method cannot be called for a Distribution object.') 20 | 21 | def log_prob(self, inputs, context=None): 22 | """Calculate log probability under the distribution. 23 | 24 | Args: 25 | inputs: Tensor, input variables. 26 | context: Tensor or None, conditioning variables. If a Tensor, it must have the same 27 | number or rows as the inputs. If None, the context is ignored. 28 | 29 | Returns: 30 | A Tensor of shape [input_size], the log probability of the inputs given the context. 31 | """ 32 | inputs = torch.as_tensor(inputs) 33 | if context is not None: 34 | context = torch.as_tensor(context) 35 | if inputs.shape[0] != context.shape[0]: 36 | raise ValueError('Number of input items must be equal to number of context items.') 37 | return self._log_prob(inputs, context) 38 | 39 | def _log_prob(self, inputs, context): 40 | raise NotImplementedError() 41 | 42 | def sample(self, num_samples, context=None, batch_size=None): 43 | """Generates samples from the distribution. Samples can be generated in batches. 44 | 45 | Args: 46 | num_samples: int, number of samples to generate. 47 | context: Tensor or None, conditioning variables. If None, the context is ignored. 48 | batch_size: int or None, number of samples per batch. If None, all samples are generated 49 | in one batch. 50 | 51 | Returns: 52 | A Tensor containing the samples, with shape [num_samples, ...] if context is None, or 53 | [context_size, num_samples, ...] if context is given. 54 | """ 55 | if not utils.is_positive_int(num_samples): 56 | raise TypeError('Number of samples must be a positive integer.') 57 | 58 | if context is not None: 59 | context = torch.as_tensor(context) 60 | 61 | if batch_size is None: 62 | return self._sample(num_samples, context) 63 | 64 | else: 65 | if not utils.is_positive_int(batch_size): 66 | raise TypeError('Batch size must be a positive integer.') 67 | 68 | num_batches = num_samples // batch_size 69 | num_leftover = num_samples % batch_size 70 | samples = [self._sample(batch_size, context) for _ in range(num_batches)] 71 | if num_leftover > 0: 72 | samples.append(self._sample(num_leftover, context)) 73 | return torch.cat(samples, dim=0) 74 | 75 | def _sample(self, num_samples, context): 76 | raise NotImplementedError() 77 | 78 | def sample_and_log_prob(self, num_samples, context=None): 79 | """Generates samples from the distribution together with their log probability. 80 | 81 | Args: 82 | num_samples: int, number of samples to generate. 83 | context: Tensor or None, conditioning variables. If None, the context is ignored. 
84 | 85 | Returns: 86 | A tuple of: 87 | * A Tensor containing the samples, with shape [num_samples, ...] if context is None, 88 | or [context_size, num_samples, ...] if context is given. 89 | * A Tensor containing the log probabilities of the samples, with shape 90 | [num_samples, ...] if context is None, or [context_size, num_samples, ...] if 91 | context is given. 92 | """ 93 | samples = self.sample(num_samples, context=context) 94 | 95 | if context is not None: 96 | # Merge the context dimension with sample dimension in order to call log_prob. 97 | samples = utils.merge_leading_dims(samples, num_dims=2) 98 | context = utils.repeat_rows(context, num_reps=num_samples) 99 | assert samples.shape[0] == context.shape[0] 100 | 101 | log_prob = self.log_prob(samples, context=context) 102 | 103 | if context is not None: 104 | # Split the context dimension from sample dimension. 105 | samples = utils.split_leading_dim(samples, shape=[-1, num_samples]) 106 | log_prob = utils.split_leading_dim(log_prob, shape=[-1, num_samples]) 107 | 108 | return samples, log_prob 109 | 110 | def mean(self, context=None): 111 | if context is not None: 112 | context = torch.as_tensor(context) 113 | return self._mean(context) 114 | 115 | def _mean(self, context): 116 | raise NoMeanException() 117 | -------------------------------------------------------------------------------- /nde/distributions/discrete.py: -------------------------------------------------------------------------------- 1 | """Implementations of discrete distributions.""" 2 | 3 | import torch 4 | 5 | from torch.nn import functional as F 6 | 7 | import utils 8 | 9 | from nde import distributions 10 | 11 | 12 | class ConditionalIndependentBernoulli(distributions.Distribution): 13 | """An independent Bernoulli whose parameters are functions of a context.""" 14 | 15 | def __init__(self, shape, context_encoder=None): 16 | """Constructor. 17 | 18 | Args: 19 | shape: list, tuple or torch.Size, the shape of the input variables. 20 | context_encoder: callable or None, encodes the context to the distribution parameters. 21 | If None, defaults to the identity function. 22 | """ 23 | super().__init__() 24 | self._shape = torch.Size(shape) 25 | if context_encoder is None: 26 | self._context_encoder = lambda x: x 27 | else: 28 | self._context_encoder = context_encoder 29 | 30 | def _compute_params(self, context): 31 | """Compute the logits from context.""" 32 | if context is None: 33 | raise ValueError('Context can\'t be None.') 34 | 35 | logits = self._context_encoder(context) 36 | if logits.shape[0] != context.shape[0]: 37 | raise RuntimeError( 38 | 'The batch dimension of the parameters is inconsistent with the input.') 39 | 40 | return logits.reshape(logits.shape[0], *self._shape) 41 | 42 | def _log_prob(self, inputs, context): 43 | if inputs.shape[1:] != self._shape: 44 | raise ValueError('Expected input of shape {}, got {}'.format( 45 | self._shape, inputs.shape[1:])) 46 | 47 | # Compute parameters. 48 | logits = self._compute_params(context) 49 | assert logits.shape == inputs.shape 50 | 51 | # Compute log prob. 52 | log_prob = -inputs * F.softplus(-logits) - (1.0 - inputs) * F.softplus(logits) 53 | log_prob = utils.sum_except_batch(log_prob, num_batch_dims=1) 54 | return log_prob 55 | 56 | def _sample(self, num_samples, context): 57 | # Compute parameters. 58 | logits = self._compute_params(context) 59 | probs = torch.sigmoid(logits) 60 | probs = utils.repeat_rows(probs, num_samples) 61 | 62 | # Generate samples. 
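# How the sampling below works: `probs` was repeated num_samples times per context
# row above, so drawing uniform noise of the same shape and thresholding with
# (noise < probs) yields independent Bernoulli(p) draws, because P(U < p) = p.
# split_leading_dim then reshapes the flat [context_size * num_samples, ...] batch
# back to [context_size, num_samples, ...].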
63 | context_size = context.shape[0] 64 | noise = torch.rand(context_size * num_samples, *self._shape) 65 | samples = (noise < probs).float() 66 | return utils.split_leading_dim(samples, [context_size, num_samples]) 67 | 68 | def _mean(self, context): 69 | logits = self._compute_params(context) 70 | return torch.sigmoid(logits) 71 | -------------------------------------------------------------------------------- /nde/distributions/discrete_test.py: -------------------------------------------------------------------------------- 1 | """Tests for discrete distributions.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | from nde.distributions import discrete 7 | 8 | 9 | class ConditionalIndependentBernoulliTest(torchtestcase.TorchTestCase): 10 | 11 | def test_log_prob(self): 12 | batch_size = 10 13 | input_shape = [2, 3, 4] 14 | context_shape = [2, 3, 4] 15 | dist = discrete.ConditionalIndependentBernoulli(input_shape) 16 | inputs = torch.randn(batch_size, *input_shape) 17 | context = torch.randn(batch_size, *context_shape) 18 | log_prob = dist.log_prob(inputs, context=context) 19 | self.assertIsInstance(log_prob, torch.Tensor) 20 | self.assertEqual(log_prob.shape, torch.Size([batch_size])) 21 | self.assertFalse(torch.isnan(log_prob).any()) 22 | self.assertFalse(torch.isinf(log_prob).any()) 23 | self.assert_tensor_less_equal(log_prob, 0.0) 24 | 25 | def test_sample(self): 26 | num_samples = 10 27 | context_size = 20 28 | input_shape = [2, 3, 4] 29 | context_shape = [2, 3, 4] 30 | dist = discrete.ConditionalIndependentBernoulli(input_shape) 31 | context = torch.randn(context_size, *context_shape) 32 | samples = dist.sample(num_samples, context=context) 33 | self.assertIsInstance(samples, torch.Tensor) 34 | self.assertEqual(samples.shape, torch.Size([context_size, num_samples] + input_shape)) 35 | self.assertFalse(torch.isnan(samples).any()) 36 | self.assertFalse(torch.isinf(samples).any()) 37 | binary = (samples == 1.0) | (samples == 0.0) 38 | self.assertEqual(binary, torch.ones_like(binary)) 39 | 40 | def test_sample_and_log_prob_with_context(self): 41 | num_samples = 10 42 | context_size = 20 43 | input_shape = [2, 3, 4] 44 | context_shape = [2, 3, 4] 45 | 46 | dist = discrete.ConditionalIndependentBernoulli(input_shape) 47 | context = torch.randn(context_size, *context_shape) 48 | samples, log_prob = dist.sample_and_log_prob(num_samples, context=context) 49 | 50 | self.assertIsInstance(samples, torch.Tensor) 51 | self.assertIsInstance(log_prob, torch.Tensor) 52 | 53 | self.assertEqual(samples.shape, torch.Size([context_size, num_samples] + input_shape)) 54 | self.assertEqual(log_prob.shape, torch.Size([context_size, num_samples])) 55 | 56 | self.assertFalse(torch.isnan(log_prob).any()) 57 | self.assertFalse(torch.isinf(log_prob).any()) 58 | self.assert_tensor_less_equal(log_prob, 0.0) 59 | 60 | self.assertFalse(torch.isnan(samples).any()) 61 | self.assertFalse(torch.isinf(samples).any()) 62 | binary = (samples == 1.0) | (samples == 0.0) 63 | self.assertEqual(binary, torch.ones_like(binary)) 64 | 65 | def test_mean(self): 66 | context_size = 20 67 | input_shape = [2, 3, 4] 68 | context_shape = [2, 3, 4] 69 | dist = discrete.ConditionalIndependentBernoulli(input_shape) 70 | context = torch.randn(context_size, *context_shape) 71 | means = dist.mean(context=context) 72 | self.assertIsInstance(means, torch.Tensor) 73 | self.assertEqual(means.shape, torch.Size([context_size] + input_shape)) 74 | self.assertFalse(torch.isnan(means).any()) 75 | 
self.assertFalse(torch.isinf(means).any()) 76 | self.assert_tensor_greater_equal(means, 0.0) 77 | self.assert_tensor_less_equal(means, 1.0) 78 | 79 | 80 | if __name__ == '__main__': 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /nde/distributions/mixture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import distributions 4 | 5 | 6 | class MixtureSameFamily(distributions.Distribution): 7 | def __init__(self, mixture_distribution, components_distribution): 8 | self.mixture_distribution = mixture_distribution 9 | self.components_distribution = components_distribution 10 | 11 | has_rsample=False 12 | 13 | super().__init__( 14 | batch_shape=self.components_distribution.batch_shape, 15 | event_shape=self.components_distribution.event_shape 16 | ) 17 | 18 | def expand(self, batch_shape, _instance=None): 19 | raise NotImplementedError 20 | 21 | @property 22 | def arg_constraints(self): 23 | return dict( 24 | self.mixture_distribution.arg_constraints, 25 | **self.components_distribution.arg_constraints 26 | ) 27 | 28 | @property 29 | def support(self): 30 | return self.components_distribution.support 31 | 32 | @property 33 | def mean(self): 34 | raise NotImplementedError 35 | 36 | @property 37 | def variance(self): 38 | raise NotImplementedError 39 | 40 | def sample(self, sample_shape=torch.Size()): 41 | mixture_mask = self.mixture_distribution.sample(sample_shape) # [S, B, D, M] 42 | if len(mixture_mask.shape) == 3: 43 | mixture_mask = mixture_mask[:, None, ...] 44 | components_samples = self.components_distribution.rsample(sample_shape) # [S, B, D, M] 45 | samples = torch.sum(mixture_mask * components_samples, dim=-1) # [S, B, D] 46 | return samples 47 | 48 | def rsample(self, sample_shape=torch.Size()): 49 | raise NotImplementedError 50 | 51 | def log_prob(self, value): 52 | # pad value for evaluation under component density 53 | value = value.permute(2, 0, 1) # [S, B, D] 54 | value = value[..., None].repeat(1, 1, 1, self.batch_shape[-1]) # [S, B, D, M] 55 | log_prob_components = self.components_distribution.log_prob(value).permute(1, 2, 3, 0) 56 | 57 | # calculate numerically stable log coefficients, and pad 58 | log_prob_mixture = self.mixture_distribution.logits 59 | log_prob_mixture = log_prob_mixture[..., None] 60 | return torch.logsumexp(log_prob_mixture + log_prob_components, dim=-2) 61 | 62 | def cdf(self, value): 63 | raise NotImplementedError 64 | 65 | def icdf(self, value): 66 | raise NotImplementedError 67 | 68 | def enumerate_support(self, expand=True): 69 | raise NotImplementedError 70 | 71 | def entropy(self): 72 | raise NotImplementedError 73 | 74 | 75 | def main(): 76 | pass 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /nde/distributions/normal.py: -------------------------------------------------------------------------------- 1 | """Implementations of Normal distributions.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import utils 7 | 8 | from nde import distributions 9 | 10 | 11 | class StandardNormal(distributions.Distribution): 12 | """A multivariate Normal with zero mean and unit covariance.""" 13 | 14 | def __init__(self, shape): 15 | super().__init__() 16 | self._shape = torch.Size(shape) 17 | self._log_z = 0.5 * np.prod(shape) * np.log(2 * np.pi) 18 | 19 | def _log_prob(self, inputs, context): 20 | # Note: the context is ignored. 
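# The quantity computed below is the standard multivariate normal log-density,
#   log p(x) = -0.5 * ||x||^2 - 0.5 * D * log(2 * pi),  with D = prod(self._shape);
# the first term is `neg_energy` and the constant term is precomputed in __init__
# as `self._log_z`.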
21 | if inputs.shape[1:] != self._shape: 22 | raise ValueError('Expected input of shape {}, got {}'.format( 23 | self._shape, inputs.shape[1:])) 24 | neg_energy = -0.5 * utils.sum_except_batch(inputs ** 2, num_batch_dims=1) 25 | return neg_energy - self._log_z 26 | 27 | def _sample(self, num_samples, context): 28 | if context is None: 29 | return torch.randn(num_samples, *self._shape) 30 | else: 31 | # The value of the context is ignored, only its size is taken into account. 32 | context_size = context.shape[0] 33 | samples = torch.randn(context_size * num_samples, *self._shape) 34 | return utils.split_leading_dim(samples, [context_size, num_samples]) 35 | 36 | def _mean(self, context): 37 | if context is None: 38 | return torch.zeros(self._shape) 39 | else: 40 | # The value of the context is ignored, only its size is taken into account. 41 | return torch.zeros(context.shape[0], *self._shape) 42 | 43 | 44 | class ConditionalDiagonalNormal(distributions.Distribution): 45 | """A diagonal multivariate Normal whose parameters are functions of a context.""" 46 | 47 | def __init__(self, shape, context_encoder=None): 48 | """Constructor. 49 | 50 | Args: 51 | shape: list, tuple or torch.Size, the shape of the input variables. 52 | context_encoder: callable or None, encodes the context to the distribution parameters. 53 | If None, defaults to the identity function. 54 | """ 55 | super().__init__() 56 | self._shape = torch.Size(shape) 57 | if context_encoder is None: 58 | self._context_encoder = lambda x: x 59 | else: 60 | self._context_encoder = context_encoder 61 | self._log_z = 0.5 * np.prod(shape) * np.log(2 * np.pi) 62 | 63 | def _compute_params(self, context): 64 | """Compute the means and log stds form the context.""" 65 | if context is None: 66 | raise ValueError('Context can\'t be None.') 67 | 68 | params = self._context_encoder(context) 69 | if params.shape[-1] % 2 != 0: 70 | raise RuntimeError( 71 | 'The context encoder must return a tensor whose last dimension is even.') 72 | if params.shape[0] != context.shape[0]: 73 | raise RuntimeError( 74 | 'The batch dimension of the parameters is inconsistent with the input.') 75 | 76 | split = params.shape[-1] // 2 77 | means = params[..., :split].reshape(params.shape[0], *self._shape) 78 | log_stds = params[..., split:].reshape(params.shape[0], *self._shape) 79 | return means, log_stds 80 | 81 | def _log_prob(self, inputs, context): 82 | if inputs.shape[1:] != self._shape: 83 | raise ValueError('Expected input of shape {}, got {}'.format( 84 | self._shape, inputs.shape[1:])) 85 | 86 | # Compute parameters. 87 | means, log_stds = self._compute_params(context) 88 | assert means.shape == inputs.shape and log_stds.shape == inputs.shape 89 | 90 | # Compute log prob. 91 | norm_inputs = (inputs - means) * torch.exp(-log_stds) 92 | log_prob = -0.5 * utils.sum_except_batch(norm_inputs ** 2, num_batch_dims=1) 93 | log_prob -= utils.sum_except_batch(log_stds, num_batch_dims=1) 94 | log_prob -= self._log_z 95 | return log_prob 96 | 97 | def _sample(self, num_samples, context): 98 | # Compute parameters. 99 | means, log_stds = self._compute_params(context) 100 | stds = torch.exp(log_stds) 101 | means = utils.repeat_rows(means, num_samples) 102 | stds = utils.repeat_rows(stds, num_samples) 103 | 104 | # Generate samples. 
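# Sampling below uses the location-scale form x = mean + std * eps with eps ~ N(0, I).
# `means` and `stds` were repeated num_samples times per context row above, so the
# flat [context_size * num_samples, ...] batch is reshaped back to
# [context_size, num_samples, ...] by split_leading_dim.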
105 | context_size = context.shape[0] 106 | noise = torch.randn(context_size * num_samples, *self._shape) 107 | samples = means + stds * noise 108 | return utils.split_leading_dim(samples, [context_size, num_samples]) 109 | 110 | def _mean(self, context): 111 | means, _ = self._compute_params(context) 112 | return means 113 | -------------------------------------------------------------------------------- /nde/distributions/uniform.py: -------------------------------------------------------------------------------- 1 | from torch import distributions 2 | 3 | import utils 4 | 5 | 6 | class TweakedUniform(distributions.Uniform): 7 | def log_prob(self, value, context): 8 | return utils.sum_except_batch(super().log_prob(value)) 9 | # result = super().log_prob(value) 10 | # if len(result.shape) == 2 and result.shape[1] == 1: 11 | # return result.reshape(-1) 12 | # else: 13 | # return result 14 | 15 | def sample(self, num_samples, context): 16 | return super().sample((num_samples, )) 17 | -------------------------------------------------------------------------------- /nde/flows/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Flow 2 | 3 | from .autoregressive import MaskedAutoregressiveFlow 4 | 5 | from .realnvp import SimpleRealNVP 6 | -------------------------------------------------------------------------------- /nde/flows/autoregressive.py: -------------------------------------------------------------------------------- 1 | """Implementations of autoregressive flows.""" 2 | 3 | from torch.nn import functional as F 4 | 5 | from nde import distributions 6 | from nde import flows 7 | from nde import transforms 8 | 9 | 10 | class MaskedAutoregressiveFlow(flows.Flow): 11 | """An autoregressive flow that uses affine transforms with masking. 12 | 13 | Reference: 14 | > G. Papamakarios et al., Masked Autoregressive Flow for Density Estimation, 15 | > Advances in Neural Information Processing Systems, 2017. 
16 | """ 17 | 18 | def __init__(self, 19 | features, 20 | hidden_features, 21 | num_layers, 22 | num_blocks_per_layer, 23 | use_residual_blocks=True, 24 | use_random_masks=False, 25 | use_random_permutations=False, 26 | activation=F.relu, 27 | dropout_probability=0., 28 | batch_norm_within_layers=False, 29 | batch_norm_between_layers=False): 30 | 31 | if use_random_permutations: 32 | permutation_constructor = transforms.RandomPermutation 33 | else: 34 | permutation_constructor = transforms.ReversePermutation 35 | 36 | layers = [] 37 | for _ in range(num_layers): 38 | layers.append(permutation_constructor(features)) 39 | layers.append( 40 | transforms.MaskedAffineAutoregressiveTransform( 41 | features=features, 42 | hidden_features=hidden_features, 43 | num_blocks=num_blocks_per_layer, 44 | use_residual_blocks=use_residual_blocks, 45 | random_mask=use_random_masks, 46 | activation=activation, 47 | dropout_probability=dropout_probability, 48 | use_batch_norm=batch_norm_within_layers, 49 | ) 50 | ) 51 | if batch_norm_between_layers: 52 | layers.append(transforms.BatchNorm(features)) 53 | 54 | super().__init__( 55 | transform=transforms.CompositeTransform(layers), 56 | distribution=distributions.StandardNormal([features]), 57 | ) 58 | -------------------------------------------------------------------------------- /nde/flows/autoregressive_test.py: -------------------------------------------------------------------------------- 1 | """Tests for autoregressive flows.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | 7 | from nde.flows import autoregressive as ar 8 | 9 | 10 | class MaskedAutoregressiveFlowTest(torchtestcase.TorchTestCase): 11 | 12 | def test_log_prob(self): 13 | batch_size = 10 14 | features = 20 15 | flow = ar.MaskedAutoregressiveFlow( 16 | features=features, 17 | hidden_features=30, 18 | num_layers=5, 19 | num_blocks_per_layer=2, 20 | ) 21 | inputs = torch.randn(batch_size, features) 22 | log_prob = flow.log_prob(inputs) 23 | self.assertIsInstance(log_prob, torch.Tensor) 24 | self.assertEqual(log_prob.shape, torch.Size([batch_size])) 25 | 26 | def test_sample(self): 27 | num_samples = 10 28 | features = 20 29 | flow = ar.MaskedAutoregressiveFlow( 30 | features=features, 31 | hidden_features=30, 32 | num_layers=5, 33 | num_blocks_per_layer=2, 34 | ) 35 | samples = flow.sample(num_samples) 36 | self.assertIsInstance(samples, torch.Tensor) 37 | self.assertEqual(samples.shape, torch.Size([num_samples, features])) 38 | 39 | 40 | if __name__ == '__main__': 41 | unittest.main() 42 | -------------------------------------------------------------------------------- /nde/flows/base.py: -------------------------------------------------------------------------------- 1 | """Basic definitions for the flows module.""" 2 | 3 | import utils 4 | 5 | from nde import distributions 6 | 7 | 8 | class Flow(distributions.Distribution): 9 | """Base class for all flow objects.""" 10 | 11 | def __init__(self, transform, distribution): 12 | """Constructor. 13 | 14 | Args: 15 | transform: A `Transform` object, it transforms data into noise. 16 | distribution: A `Distribution` object, the base distribution of the flow that 17 | generates the noise. 
18 | """ 19 | super().__init__() 20 | self._transform = transform 21 | self._distribution = distribution 22 | 23 | def _log_prob(self, inputs, context): 24 | noise, logabsdet = self._transform(inputs, context=context) 25 | log_prob = self._distribution.log_prob(noise, context=context) 26 | return log_prob + logabsdet 27 | 28 | def _sample(self, num_samples, context): 29 | noise = self._distribution.sample(num_samples, context=context) 30 | 31 | if context is not None: 32 | # Merge the context dimension with sample dimension in order to apply the transform. 33 | noise = utils.merge_leading_dims(noise, num_dims=2) 34 | context = utils.repeat_rows(context, num_reps=num_samples) 35 | 36 | samples, _ = self._transform.inverse(noise, context=context) 37 | 38 | if context is not None: 39 | # Split the context dimension from sample dimension. 40 | samples = utils.split_leading_dim(samples, shape=[-1, num_samples]) 41 | 42 | return samples 43 | 44 | def sample_and_log_prob(self, num_samples, context=None): 45 | """Generates samples from the flow, together with their log probabilities. 46 | 47 | For flows, this is more efficient that calling `sample` and `log_prob` separately. 48 | """ 49 | noise, log_prob = self._distribution.sample_and_log_prob(num_samples, context=context) 50 | 51 | if context is not None: 52 | # Merge the context dimension with sample dimension in order to apply the transform. 53 | noise = utils.merge_leading_dims(noise, num_dims=2) 54 | context = utils.repeat_rows(context, num_reps=num_samples) 55 | 56 | samples, logabsdet = self._transform.inverse(noise, context=context) 57 | 58 | if context is not None: 59 | # Split the context dimension from sample dimension. 60 | samples = utils.split_leading_dim(samples, shape=[-1, num_samples]) 61 | logabsdet = utils.split_leading_dim(logabsdet, shape=[-1, num_samples]) 62 | 63 | return samples, log_prob - logabsdet 64 | 65 | def transform_to_noise(self, inputs, context=None): 66 | """Transforms given data into noise. Useful for goodness-of-fit checking. 67 | 68 | Args: 69 | inputs: A `Tensor` of shape [batch_size, ...], the data to be transformed. 70 | context: A `Tensor` of shape [batch_size, ...] or None, optional context associated 71 | with the data. 72 | 73 | Returns: 74 | A `Tensor` of shape [batch_size, ...], the noise. 
75 | """ 76 | noise, _ = self._transform(inputs, context=context) 77 | return noise 78 | -------------------------------------------------------------------------------- /nde/flows/base_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the basic flow definitions.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | from nde import transforms 7 | from nde import distributions 8 | from nde.flows import base 9 | 10 | 11 | class FlowTest(torchtestcase.TorchTestCase): 12 | 13 | def test_log_prob(self): 14 | batch_size = 10 15 | input_shape = [2, 3, 4] 16 | context_shape = [5, 6] 17 | flow = base.Flow( 18 | transform=transforms.AffineScalarTransform(scale=2.0), 19 | distribution=distributions.StandardNormal(input_shape), 20 | ) 21 | inputs = torch.randn(batch_size, *input_shape) 22 | maybe_context = torch.randn(batch_size, *context_shape) 23 | for context in [None, maybe_context]: 24 | with self.subTest(context=context): 25 | log_prob = flow.log_prob(inputs, context=context) 26 | self.assertIsInstance(log_prob, torch.Tensor) 27 | self.assertEqual(log_prob.shape, torch.Size([batch_size])) 28 | 29 | def test_sample(self): 30 | num_samples = 10 31 | context_size = 20 32 | input_shape = [2, 3, 4] 33 | context_shape = [5, 6] 34 | flow = base.Flow( 35 | transform=transforms.AffineScalarTransform(scale=2.0), 36 | distribution=distributions.StandardNormal(input_shape), 37 | ) 38 | maybe_context = torch.randn(context_size, *context_shape) 39 | for context in [None, maybe_context]: 40 | with self.subTest(context=context): 41 | samples = flow.sample(num_samples, context=context) 42 | self.assertIsInstance(samples, torch.Tensor) 43 | if context is None: 44 | self.assertEqual(samples.shape, torch.Size([num_samples] + input_shape)) 45 | else: 46 | self.assertEqual( 47 | samples.shape, torch.Size([context_size, num_samples] + input_shape)) 48 | 49 | def test_sample_and_log_prob(self): 50 | num_samples = 10 51 | input_shape = [2, 3, 4] 52 | flow = base.Flow( 53 | transform=transforms.AffineScalarTransform(scale=2.0), 54 | distribution=distributions.StandardNormal(input_shape), 55 | ) 56 | samples, log_prob_1 = flow.sample_and_log_prob(num_samples) 57 | log_prob_2 = flow.log_prob(samples) 58 | self.assertIsInstance(samples, torch.Tensor) 59 | self.assertIsInstance(log_prob_1, torch.Tensor) 60 | self.assertIsInstance(log_prob_2, torch.Tensor) 61 | self.assertEqual(samples.shape, torch.Size([num_samples] + input_shape)) 62 | self.assertEqual(log_prob_1.shape, torch.Size([num_samples])) 63 | self.assertEqual(log_prob_2.shape, torch.Size([num_samples])) 64 | self.assertEqual(log_prob_1, log_prob_2) 65 | 66 | def test_sample_and_log_prob_with_context(self): 67 | num_samples = 10 68 | context_size = 20 69 | input_shape = [2, 3, 4] 70 | context_shape = [5, 6] 71 | flow = base.Flow( 72 | transform=transforms.AffineScalarTransform(scale=2.0), 73 | distribution=distributions.StandardNormal(input_shape), 74 | ) 75 | context = torch.randn(context_size, *context_shape) 76 | samples, log_prob = flow.sample_and_log_prob(num_samples, context=context) 77 | self.assertIsInstance(samples, torch.Tensor) 78 | self.assertIsInstance(log_prob, torch.Tensor) 79 | self.assertEqual(samples.shape, torch.Size([context_size, num_samples] + input_shape)) 80 | self.assertEqual(log_prob.shape, torch.Size([context_size, num_samples])) 81 | 82 | def test_transform_to_noise(self): 83 | batch_size = 10 84 | context_size = 20 85 | shape = [2, 3, 4] 86 | context_shape = [5, 
6] 87 | flow = base.Flow( 88 | transform=transforms.AffineScalarTransform(scale=2.0), 89 | distribution=distributions.StandardNormal(shape), 90 | ) 91 | inputs = torch.randn(batch_size, *shape) 92 | maybe_context = torch.randn(context_size, *context_shape) 93 | for context in [None, maybe_context]: 94 | with self.subTest(context=context): 95 | noise = flow.transform_to_noise(inputs, context=context) 96 | self.assertIsInstance(noise, torch.Tensor) 97 | self.assertEqual(noise.shape, torch.Size([batch_size] + shape)) 98 | 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /nde/flows/realnvp.py: -------------------------------------------------------------------------------- 1 | """Implementations of Real NVP.""" 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from nde import distributions 7 | from nde import flows 8 | from nde import transforms 9 | import nn as nn_ 10 | 11 | 12 | class SimpleRealNVP(flows.Flow): 13 | """An simplified version of Real NVP for 1-dim inputs. 14 | 15 | This implementation uses 1-dim checkerboard masking but doesn't use multi-scaling. 16 | 17 | Reference: 18 | > L. Dinh et al., Density estimation using Real NVP, ICLR 2017. 19 | """ 20 | 21 | def __init__(self, 22 | features, 23 | hidden_features, 24 | num_layers, 25 | num_blocks_per_layer, 26 | use_volume_preserving=False, 27 | activation=F.relu, 28 | dropout_probability=0., 29 | batch_norm_within_layers=False, 30 | batch_norm_between_layers=False): 31 | 32 | if use_volume_preserving: 33 | coupling_constructor = transforms.AdditiveCouplingTransform 34 | else: 35 | coupling_constructor = transforms.AffineCouplingTransform 36 | 37 | mask = torch.ones(features) 38 | mask[::2] = -1 39 | 40 | def create_resnet(in_features, out_features): 41 | return nn_.ResidualNet( 42 | in_features, 43 | out_features, 44 | hidden_features=hidden_features, 45 | num_blocks=num_blocks_per_layer, 46 | activation=activation, 47 | dropout_probability=dropout_probability, 48 | use_batch_norm=batch_norm_within_layers 49 | ) 50 | 51 | layers = [] 52 | for _ in range(num_layers): 53 | transform = coupling_constructor( 54 | mask=mask, 55 | transform_net_create_fn=create_resnet 56 | ) 57 | layers.append(transform) 58 | mask *= -1 59 | if batch_norm_between_layers: 60 | layers.append(transforms.BatchNorm(features=features)) 61 | 62 | super().__init__( 63 | transform=transforms.CompositeTransform(layers), 64 | distribution=distributions.StandardNormal([features]), 65 | ) 66 | -------------------------------------------------------------------------------- /nde/flows/realnvp_test.py: -------------------------------------------------------------------------------- 1 | """Tests for Real NVP.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | 7 | from nde.flows import realnvp 8 | 9 | 10 | class SimpleRealNVPTest(torchtestcase.TorchTestCase): 11 | 12 | def test_log_prob(self): 13 | batch_size = 10 14 | features = 20 15 | flow = realnvp.SimpleRealNVP( 16 | features=features, 17 | hidden_features=30, 18 | num_layers=5, 19 | num_blocks_per_layer=2, 20 | ) 21 | inputs = torch.randn(batch_size, features) 22 | log_prob = flow.log_prob(inputs) 23 | self.assertIsInstance(log_prob, torch.Tensor) 24 | self.assertEqual(log_prob.shape, torch.Size([batch_size])) 25 | 26 | def test_sample(self): 27 | num_samples = 10 28 | features = 20 29 | flow = realnvp.SimpleRealNVP( 30 | features=features, 31 | hidden_features=30, 32 | 
num_layers=5, 33 | num_blocks_per_layer=2, 34 | ) 35 | samples = flow.sample(num_samples) 36 | self.assertIsInstance(samples, torch.Tensor) 37 | self.assertEqual(samples.shape, torch.Size([num_samples, features])) 38 | 39 | 40 | if __name__ == '__main__': 41 | unittest.main() 42 | -------------------------------------------------------------------------------- /nde/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | InverseNotAvailable, 3 | InputOutsideDomain, 4 | Transform, 5 | CompositeTransform, 6 | MultiscaleCompositeTransform, 7 | InverseTransform 8 | ) 9 | 10 | from .autoregressive import ( 11 | MaskedAffineAutoregressiveTransform, 12 | MaskedPiecewiseLinearAutoregressiveTransform, 13 | MaskedPiecewiseQuadraticAutoregressiveTransform, 14 | MaskedPiecewiseCubicAutoregressiveTransform, 15 | MaskedPiecewiseRationalQuadraticAutoregressiveTransform 16 | ) 17 | 18 | from .linear import NaiveLinear 19 | from .lu import LULinear 20 | from .qr import QRLinear 21 | from .svd import SVDLinear 22 | 23 | from .nonlinearities import ( 24 | CompositeCDFTransform, 25 | LeakyReLU, 26 | Logit, 27 | LogTanh, 28 | PiecewiseLinearCDF, 29 | PiecewiseQuadraticCDF, 30 | PiecewiseCubicCDF, 31 | PiecewiseRationalQuadraticCDF, 32 | Sigmoid, 33 | Tanh 34 | ) 35 | 36 | from .normalization import ( 37 | BatchNorm, 38 | ActNorm 39 | ) 40 | 41 | from .orthogonal import HouseholderSequence 42 | 43 | from .permutations import Permutation 44 | from .permutations import RandomPermutation 45 | from .permutations import ReversePermutation 46 | 47 | from .coupling import ( 48 | AffineCouplingTransform, 49 | AdditiveCouplingTransform, 50 | PiecewiseLinearCouplingTransform, 51 | PiecewiseQuadraticCouplingTransform, 52 | PiecewiseCubicCouplingTransform, 53 | PiecewiseRationalQuadraticCouplingTransform 54 | ) 55 | 56 | from .standard import ( 57 | IdentityTransform, 58 | AffineScalarTransform, 59 | ) 60 | 61 | from .reshape import SqueezeTransform, ReshapeTransform 62 | from .conv import OneByOneConvolution -------------------------------------------------------------------------------- /nde/transforms/autoregressive_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the autoregressive transforms.""" 2 | 3 | import torch 4 | import unittest 5 | 6 | from nde.transforms import autoregressive 7 | from nde.transforms.transform_test import TransformTest 8 | 9 | 10 | class MaskedAffineAutoregressiveTransformTest(TransformTest): 11 | 12 | def test_forward(self): 13 | batch_size = 10 14 | features = 20 15 | inputs = torch.randn(batch_size, features) 16 | for use_residual_blocks, random_mask in [(False, False), 17 | (False, True), 18 | (True, False)]: 19 | with self.subTest(use_residual_blocks=use_residual_blocks, 20 | random_mask=random_mask): 21 | transform = autoregressive.MaskedAffineAutoregressiveTransform( 22 | features=features, 23 | hidden_features=30, 24 | num_blocks=5, 25 | use_residual_blocks=use_residual_blocks, 26 | random_mask=random_mask, 27 | ) 28 | outputs, logabsdet = transform(inputs) 29 | self.assert_tensor_is_good(outputs, [batch_size, features]) 30 | self.assert_tensor_is_good(logabsdet, [batch_size]) 31 | 32 | def test_inverse(self): 33 | batch_size = 10 34 | features = 20 35 | inputs = torch.randn(batch_size, features) 36 | for use_residual_blocks, random_mask in [(False, False), 37 | (False, True), 38 | (True, False)]: 39 | with 
self.subTest(use_residual_blocks=use_residual_blocks, 40 | random_mask=random_mask): 41 | transform = autoregressive.MaskedAffineAutoregressiveTransform( 42 | features=features, 43 | hidden_features=30, 44 | num_blocks=5, 45 | use_residual_blocks=use_residual_blocks, 46 | random_mask=random_mask, 47 | ) 48 | outputs, logabsdet = transform.inverse(inputs) 49 | self.assert_tensor_is_good(outputs, [batch_size, features]) 50 | self.assert_tensor_is_good(logabsdet, [batch_size]) 51 | 52 | def test_forward_inverse_are_consistent(self): 53 | batch_size = 10 54 | features = 20 55 | inputs = torch.randn(batch_size, features) 56 | self.eps = 1e-6 57 | for use_residual_blocks, random_mask in [(False, False), 58 | (False, True), 59 | (True, False)]: 60 | with self.subTest(use_residual_blocks=use_residual_blocks, 61 | random_mask=random_mask): 62 | transform = autoregressive.MaskedAffineAutoregressiveTransform( 63 | features=features, 64 | hidden_features=30, 65 | num_blocks=5, 66 | use_residual_blocks=use_residual_blocks, 67 | random_mask=random_mask, 68 | ) 69 | self.assert_forward_inverse_are_consistent(transform, inputs) 70 | 71 | 72 | class MaskedPiecewiseLinearAutoregressiveTranformTest(TransformTest): 73 | def test_forward_inverse_are_consistent(self): 74 | batch_size = 10 75 | features = 20 76 | inputs = torch.rand(batch_size, features) 77 | self.eps = 1e-3 78 | 79 | transform = autoregressive.MaskedPiecewiseLinearAutoregressiveTransform( 80 | num_bins=10, 81 | features=features, 82 | hidden_features=30, 83 | num_blocks=5, 84 | use_residual_blocks=True 85 | ) 86 | 87 | self.assert_forward_inverse_are_consistent(transform, inputs) 88 | 89 | 90 | class MaskedPiecewiseQuadraticAutoregressiveTranformTest(TransformTest): 91 | def test_forward_inverse_are_consistent(self): 92 | batch_size = 10 93 | features = 20 94 | inputs = torch.rand(batch_size, features) 95 | self.eps = 1e-4 96 | 97 | transform = autoregressive.MaskedPiecewiseQuadraticAutoregressiveTransform( 98 | num_bins=10, 99 | features=features, 100 | hidden_features=30, 101 | num_blocks=5, 102 | use_residual_blocks=True 103 | ) 104 | 105 | self.assert_forward_inverse_are_consistent(transform, inputs) 106 | 107 | 108 | class MaskedPiecewiseCubicAutoregressiveTranformTest(TransformTest): 109 | def test_forward_inverse_are_consistent(self): 110 | batch_size = 10 111 | features = 20 112 | inputs = torch.rand(batch_size, features) 113 | self.eps = 1e-3 114 | 115 | transform = autoregressive.MaskedPiecewiseCubicAutoregressiveTransform( 116 | num_bins=10, 117 | features=features, 118 | hidden_features=30, 119 | num_blocks=5, 120 | use_residual_blocks=True 121 | ) 122 | 123 | self.assert_forward_inverse_are_consistent(transform, inputs) 124 | 125 | 126 | if __name__ == '__main__': 127 | unittest.main() 128 | -------------------------------------------------------------------------------- /nde/transforms/base_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the basic transform definitions.""" 2 | import unittest 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from nde.transforms import base 8 | from nde.transforms import standard 9 | from nde.transforms.transform_test import TransformTest 10 | 11 | class CompositeTransformTest(TransformTest): 12 | 13 | def test_forward(self): 14 | batch_size = 10 15 | shape = [2, 3, 4] 16 | inputs = torch.randn(batch_size, *shape) 17 | transforms = [ 18 | standard.AffineScalarTransform(scale=2.0), 19 | standard.IdentityTransform(), 20 | 
standard.AffineScalarTransform(scale=-0.25), 21 | ] 22 | composite = base.CompositeTransform(transforms) 23 | reference = standard.AffineScalarTransform(scale=-0.5) 24 | outputs, logabsdet = composite(inputs) 25 | outputs_ref, logabsdet_ref = reference(inputs) 26 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 27 | self.assert_tensor_is_good(logabsdet, [batch_size]) 28 | self.assertEqual(outputs, outputs_ref) 29 | self.assertEqual(logabsdet, logabsdet_ref) 30 | 31 | def test_inverse(self): 32 | batch_size = 10 33 | shape = [2, 3, 4] 34 | inputs = torch.randn(batch_size, *shape) 35 | transforms = [ 36 | standard.AffineScalarTransform(scale=2.0), 37 | standard.IdentityTransform(), 38 | standard.AffineScalarTransform(scale=-0.25), 39 | ] 40 | composite = base.CompositeTransform(transforms) 41 | reference = standard.AffineScalarTransform(scale=-0.5) 42 | outputs, logabsdet = composite.inverse(inputs) 43 | outputs_ref, logabsdet_ref = reference.inverse(inputs) 44 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 45 | self.assert_tensor_is_good(logabsdet, [batch_size]) 46 | self.assertEqual(outputs, outputs_ref) 47 | self.assertEqual(logabsdet, logabsdet_ref) 48 | 49 | class MultiscaleCompositeTransformTest(TransformTest): 50 | def create_transform(self, shape, split_dim=1): 51 | mct = base.MultiscaleCompositeTransform(num_transforms=4, 52 | split_dim=split_dim) 53 | for transform in [ standard.AffineScalarTransform(scale=2.), 54 | standard.AffineScalarTransform(scale=4.), 55 | standard.AffineScalarTransform(scale=0.5), 56 | standard.AffineScalarTransform(scale=0.25)]: 57 | shape = mct.add_transform(transform, shape) 58 | 59 | return mct 60 | 61 | 62 | def test_forward(self): 63 | batch_size = 5 64 | for shape in [(32,4,4), 65 | (64,), 66 | (65,)]: 67 | with self.subTest(shape=shape): 68 | inputs = torch.ones(batch_size, *shape) 69 | transform = self.create_transform(shape) 70 | outputs, logabsdet = transform(inputs) 71 | self.assert_tensor_is_good(outputs, [batch_size] + [np.prod(shape)]) 72 | self.assert_tensor_is_good(logabsdet, [batch_size]) 73 | 74 | def test_forward_bad_shape(self): 75 | shape = (8,) 76 | with self.assertRaises(ValueError): 77 | transform = self.create_transform(shape) 78 | 79 | def test_forward_bad_split_dim(self): 80 | batch_size = 5 81 | shape = [32] 82 | inputs = torch.randn(batch_size, *shape) 83 | with self.assertRaises(ValueError): 84 | transform = self.create_transform(shape, split_dim=2) 85 | 86 | def test_inverse_not_flat(self): 87 | batch_size = 5 88 | shape = [32, 4, 4] 89 | inputs = torch.randn(batch_size, *shape) 90 | transform = self.create_transform(shape) 91 | with self.assertRaises(ValueError): 92 | transform.inverse(inputs) 93 | 94 | def test_forward_inverse_are_consistent(self): 95 | batch_size = 5 96 | for shape in [(32,4,4), 97 | (64,), 98 | (65,), 99 | (21,)]: 100 | with self.subTest(shape=shape): 101 | transform = self.create_transform(shape) 102 | inputs = torch.randn(batch_size, *shape).view(batch_size, -1) 103 | self.assert_forward_inverse_are_consistent(transform, inputs) 104 | 105 | class InverseTransformTest(TransformTest): 106 | 107 | def test_forward(self): 108 | batch_size = 10 109 | shape = [2, 3, 4] 110 | inputs = torch.randn(batch_size, *shape) 111 | transform = base.InverseTransform(standard.AffineScalarTransform(scale=-2.0)) 112 | reference = standard.AffineScalarTransform(scale=-0.5) 113 | outputs, logabsdet = transform(inputs) 114 | outputs_ref, logabsdet_ref = reference(inputs) 115 | 
self.assert_tensor_is_good(outputs, [batch_size] + shape) 116 | self.assert_tensor_is_good(logabsdet, [batch_size]) 117 | self.assertEqual(outputs, outputs_ref) 118 | self.assertEqual(logabsdet, logabsdet_ref) 119 | 120 | def test_inverse(self): 121 | batch_size = 10 122 | shape = [2, 3, 4] 123 | inputs = torch.randn(batch_size, *shape) 124 | transform = base.InverseTransform(standard.AffineScalarTransform(scale=-2.0)) 125 | reference = standard.AffineScalarTransform(scale=-0.5) 126 | outputs, logabsdet = transform.inverse(inputs) 127 | outputs_ref, logabsdet_ref = reference.inverse(inputs) 128 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 129 | self.assert_tensor_is_good(logabsdet, [batch_size]) 130 | self.assertEqual(outputs, outputs_ref) 131 | self.assertEqual(logabsdet, logabsdet_ref) 132 | 133 | 134 | if __name__ == '__main__': 135 | unittest.main() 136 | -------------------------------------------------------------------------------- /nde/transforms/conv.py: -------------------------------------------------------------------------------- 1 | import utils 2 | from nde import transforms 3 | 4 | 5 | class OneByOneConvolution(transforms.LULinear): 6 | """An invertible 1x1 convolution with a fixed permutation, as introduced in the Glow paper. 7 | 8 | Reference: 9 | > D. Kingma et. al., Glow: Generative flow with invertible 1x1 convolutions, NeurIPS 2018. 10 | """ 11 | def __init__(self, 12 | num_channels, 13 | using_cache=False, 14 | identity_init=True): 15 | super().__init__(num_channels, using_cache, identity_init) 16 | self.permutation = transforms.RandomPermutation(num_channels, dim=1) 17 | 18 | def _lu_forward_inverse(self, inputs, inverse=False): 19 | b, c, h, w = inputs.shape 20 | inputs = inputs.permute(0, 2, 3, 1).reshape(b*h*w, c) 21 | 22 | if inverse: 23 | outputs, logabsdet = super().inverse(inputs) 24 | else: 25 | outputs, logabsdet = super().forward(inputs) 26 | 27 | outputs = outputs.reshape(b, h, w, c).permute(0, 3, 1, 2) 28 | logabsdet = logabsdet.reshape(b, h, w) 29 | 30 | return outputs, utils.sum_except_batch(logabsdet) 31 | 32 | def forward(self, inputs, context=None): 33 | if inputs.dim() != 4: 34 | raise ValueError('Inputs must be a 4D tensor.') 35 | 36 | inputs, _ = self.permutation(inputs) 37 | 38 | return self._lu_forward_inverse(inputs, inverse=False) 39 | 40 | def inverse(self, inputs, context=None): 41 | if inputs.dim() != 4: 42 | raise ValueError('Inputs must be a 4D tensor.') 43 | 44 | outputs, logabsdet = self._lu_forward_inverse(inputs, inverse=True) 45 | 46 | outputs, _ = self.permutation.inverse(outputs) 47 | 48 | return outputs, logabsdet 49 | -------------------------------------------------------------------------------- /nde/transforms/conv_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from nde import transforms 5 | from nde.transforms.transform_test import TransformTest 6 | 7 | 8 | class OneByOneConvolutionTest(TransformTest): 9 | def test_forward_and_inverse_are_consistent(self): 10 | batch_size = 10 11 | c, h, w = 3, 28, 28 12 | inputs = torch.randn(batch_size, c, h, w) 13 | transform = transforms.OneByOneConvolution(c) 14 | self.eps = 1e-6 15 | self.assert_forward_inverse_are_consistent(transform, inputs) 16 | 17 | 18 | if __name__ == '__main__': 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /nde/transforms/lu.py: 
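A brief note on OneByOneConvolution above: after a fixed random channel permutation, the forward pass reshapes the [B, C, H, W] input to [B * H * W, C], so the Glow-style 1x1 convolution is the LU-parameterized linear map applied independently at every spatial location, and the per-example log-determinant is H * W * log|det W|, where W is the C x C channel-mixing matrix. A hedged usage sketch (shapes are illustrative only, mirroring conv_test.py above):

import torch
from nde import transforms

conv = transforms.OneByOneConvolution(num_channels=3)
x = torch.randn(16, 3, 8, 8)        # [B, C, H, W]
y, logabsdet = conv(x)              # y: [16, 3, 8, 8], logabsdet: [16]
x_recovered, _ = conv.inverse(y)    # recovers x up to numerical error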
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch import nn 5 | from torch.nn import functional as F, init 6 | 7 | from nde.transforms.linear import Linear 8 | 9 | 10 | class LULinear(Linear): 11 | """A linear transform where we parameterize the LU decomposition of the weights.""" 12 | 13 | def __init__(self, features, using_cache=False, identity_init=True, eps=1e-3): 14 | super().__init__(features, using_cache) 15 | 16 | self.eps = eps 17 | 18 | self.lower_indices = np.tril_indices(features, k=-1) 19 | self.upper_indices = np.triu_indices(features, k=1) 20 | self.diag_indices = np.diag_indices(features) 21 | 22 | n_triangular_entries = ((features - 1) * features) // 2 23 | 24 | self.lower_entries = nn.Parameter(torch.zeros(n_triangular_entries)) 25 | self.upper_entries = nn.Parameter(torch.zeros(n_triangular_entries)) 26 | self.unconstrained_upper_diag = nn.Parameter(torch.zeros(features)) 27 | 28 | self._initialize(identity_init) 29 | 30 | def _initialize(self, identity_init): 31 | init.zeros_(self.bias) 32 | 33 | if identity_init: 34 | init.zeros_(self.lower_entries) 35 | init.zeros_(self.upper_entries) 36 | constant = np.log(np.exp(1 - self.eps) - 1) 37 | init.constant_(self.unconstrained_upper_diag, constant) 38 | else: 39 | stdv = 1.0 / np.sqrt(self.features) 40 | init.uniform_(self.lower_entries, -stdv, stdv) 41 | init.uniform_(self.upper_entries, -stdv, stdv) 42 | init.uniform_(self.unconstrained_upper_diag, -stdv, stdv) 43 | 44 | def _create_lower_upper(self): 45 | lower = self.lower_entries.new_zeros(self.features, self.features) 46 | lower[self.lower_indices[0], self.lower_indices[1]] = self.lower_entries 47 | # The diagonal of L is taken to be all-ones without loss of generality. 48 | lower[self.diag_indices[0], self.diag_indices[1]] = 1. 
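# With a unit diagonal on L, det(W) = det(L) * det(U) = prod(diag(U)); the diagonal
# of U is kept strictly positive via the softplus in `upper_diag`, so `logabsdet`
# further down reduces to sum(log(upper_diag)) and costs only O(D).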
49 | 50 | upper = self.upper_entries.new_zeros(self.features, self.features) 51 | upper[self.upper_indices[0], self.upper_indices[1]] = self.upper_entries 52 | upper[self.diag_indices[0], self.diag_indices[1]] = self.upper_diag 53 | 54 | return lower, upper 55 | 56 | def forward_no_cache(self, inputs): 57 | """Cost: 58 | output = O(D^2N) 59 | logabsdet = O(D) 60 | where: 61 | D = num of features 62 | N = num of inputs 63 | """ 64 | lower, upper = self._create_lower_upper() 65 | outputs = F.linear(inputs, upper) 66 | outputs = F.linear(outputs, lower, self.bias) 67 | logabsdet = self.logabsdet() * inputs.new_ones(outputs.shape[0]) 68 | return outputs, logabsdet 69 | 70 | def inverse_no_cache(self, inputs): 71 | """Cost: 72 | output = O(D^2N) 73 | logabsdet = O(D) 74 | where: 75 | D = num of features 76 | N = num of inputs 77 | """ 78 | lower, upper = self._create_lower_upper() 79 | outputs = inputs - self.bias 80 | outputs, _ = torch.triangular_solve(outputs.t(), lower, upper=False, unitriangular=True) 81 | outputs, _ = torch.triangular_solve(outputs, upper, upper=True, unitriangular=False) 82 | outputs = outputs.t() 83 | 84 | logabsdet = -self.logabsdet() 85 | logabsdet = logabsdet * inputs.new_ones(outputs.shape[0]) 86 | 87 | return outputs, logabsdet 88 | 89 | def weight(self): 90 | """Cost: 91 | weight = O(D^3) 92 | where: 93 | D = num of features 94 | """ 95 | lower, upper = self._create_lower_upper() 96 | return lower @ upper 97 | 98 | def weight_inverse(self): 99 | """Cost: 100 | inverse = O(D^3) 101 | where: 102 | D = num of features 103 | """ 104 | lower, upper = self._create_lower_upper() 105 | identity = torch.eye(self.features, self.features) 106 | lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True) 107 | weight_inverse, _ = torch.trtrs(lower_inverse, upper, upper=True, unitriangular=False) 108 | return weight_inverse 109 | 110 | @property 111 | def upper_diag(self): 112 | return F.softplus(self.unconstrained_upper_diag) + self.eps 113 | 114 | def logabsdet(self): 115 | """Cost: 116 | logabsdet = O(D) 117 | where: 118 | D = num of features 119 | """ 120 | return torch.sum(torch.log(self.upper_diag)) -------------------------------------------------------------------------------- /nde/transforms/lu_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the LU linear transforms.""" 2 | 3 | import torch 4 | import unittest 5 | 6 | import utils 7 | 8 | from nde.transforms import lu 9 | from nde.transforms.transform_test import TransformTest 10 | 11 | 12 | class LULinearTest(TransformTest): 13 | 14 | def setUp(self): 15 | self.features = 3 16 | self.transform = lu.LULinear(features=self.features) 17 | 18 | lower, upper = self.transform._create_lower_upper() 19 | self.weight = lower @ upper 20 | self.weight_inverse = torch.inverse(self.weight) 21 | self.logabsdet = utils.logabsdet(self.weight) 22 | 23 | self.eps = 1e-5 24 | 25 | def test_forward_no_cache(self): 26 | batch_size = 10 27 | inputs = torch.randn(batch_size, self.features) 28 | outputs, logabsdet = self.transform.forward_no_cache(inputs) 29 | 30 | outputs_ref = inputs @ self.weight.t() + self.transform.bias 31 | logabsdet_ref = torch.full([batch_size], self.logabsdet.item()) 32 | 33 | self.assert_tensor_is_good(outputs, [batch_size, self.features]) 34 | self.assert_tensor_is_good(logabsdet, [batch_size]) 35 | 36 | self.assertEqual(outputs, outputs_ref) 37 | self.assertEqual(logabsdet, logabsdet_ref) 38 | 39 | def test_inverse_no_cache(self): 40 
| batch_size = 10 41 | inputs = torch.randn(batch_size, self.features) 42 | outputs, logabsdet = self.transform.inverse_no_cache(inputs) 43 | 44 | outputs_ref = (inputs - self.transform.bias) @ self.weight_inverse.t() 45 | logabsdet_ref = torch.full([batch_size], -self.logabsdet.item()) 46 | 47 | self.assert_tensor_is_good(outputs, [batch_size, self.features]) 48 | self.assert_tensor_is_good(logabsdet, [batch_size]) 49 | 50 | self.assertEqual(outputs, outputs_ref) 51 | self.assertEqual(logabsdet, logabsdet_ref) 52 | 53 | def test_weight(self): 54 | weight = self.transform.weight() 55 | self.assert_tensor_is_good(weight, [self.features, self.features]) 56 | self.assertEqual(weight, self.weight) 57 | 58 | def test_weight_inverse(self): 59 | weight_inverse = self.transform.weight_inverse() 60 | self.assert_tensor_is_good(weight_inverse, [self.features, self.features]) 61 | self.assertEqual(weight_inverse, self.weight_inverse) 62 | 63 | def test_logabsdet(self): 64 | logabsdet = self.transform.logabsdet() 65 | self.assert_tensor_is_good(logabsdet, []) 66 | self.assertEqual(logabsdet, self.logabsdet) 67 | 68 | def test_forward_inverse_are_consistent(self): 69 | batch_size = 10 70 | inputs = torch.randn(batch_size, self.features) 71 | self.assert_forward_inverse_are_consistent(self.transform, inputs) 72 | 73 | 74 | if __name__ == '__main__': 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /nde/transforms/made_test.py: -------------------------------------------------------------------------------- 1 | """Tests for MADE.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | 7 | from nde.transforms import made 8 | 9 | 10 | class ShapeTest(torchtestcase.TorchTestCase): 11 | 12 | def test_conditional(self): 13 | features = 100 14 | hidden_features = 200 15 | num_blocks = 5 16 | output_multiplier = 3 17 | conditional_features = 50 18 | batch_size = 16 19 | 20 | inputs = torch.randn(batch_size, features) 21 | conditional_inputs = torch.randn(batch_size, conditional_features) 22 | 23 | for use_residual_blocks, random_mask in [(False, False), 24 | (False, True), 25 | (True, False)]: 26 | with self.subTest(use_residual_blocks=use_residual_blocks, 27 | random_mask=random_mask): 28 | model = made.MADE( 29 | features=features, 30 | hidden_features=hidden_features, 31 | num_blocks=num_blocks, 32 | output_multiplier=output_multiplier, 33 | context_features=conditional_features, 34 | use_residual_blocks=use_residual_blocks, 35 | random_mask=random_mask, 36 | ) 37 | outputs = model(inputs, conditional_inputs) 38 | self.assertEqual(outputs.dim(), 2) 39 | self.assertEqual(outputs.shape[0], batch_size) 40 | self.assertEqual(outputs.shape[1], output_multiplier * features) 41 | 42 | def test_unconditional(self): 43 | features = 100 44 | hidden_features = 200 45 | num_blocks = 5 46 | output_multiplier = 3 47 | batch_size = 16 48 | 49 | inputs = torch.randn(batch_size, features) 50 | 51 | for use_residual_blocks, random_mask in [(False, False), 52 | (False, True), 53 | (True, False)]: 54 | with self.subTest(use_residual_blocks=use_residual_blocks, 55 | random_mask=random_mask): 56 | model = made.MADE( 57 | features=features, 58 | hidden_features=hidden_features, 59 | num_blocks=num_blocks, 60 | output_multiplier=output_multiplier, 61 | use_residual_blocks=use_residual_blocks, 62 | random_mask=random_mask, 63 | ) 64 | outputs = model(inputs) 65 | self.assertEqual(outputs.dim(), 2) 66 | self.assertEqual(outputs.shape[0], batch_size) 67 | 
self.assertEqual(outputs.shape[1], output_multiplier * features) 68 | 69 | 70 | class ConnectivityTest(torchtestcase.TorchTestCase): 71 | 72 | def test_gradients(self): 73 | features = 10 74 | hidden_features = 256 75 | num_blocks = 20 76 | output_multiplier = 3 77 | 78 | for use_residual_blocks, random_mask in [(False, False), 79 | (False, True), 80 | (True, False)]: 81 | with self.subTest(use_residual_blocks=use_residual_blocks, 82 | random_mask=random_mask): 83 | model = made.MADE( 84 | features=features, 85 | hidden_features=hidden_features, 86 | num_blocks=num_blocks, 87 | output_multiplier=output_multiplier, 88 | use_residual_blocks=use_residual_blocks, 89 | random_mask=random_mask, 90 | ) 91 | inputs = torch.randn(1, features) 92 | inputs.requires_grad = True 93 | for k in range(features * output_multiplier): 94 | outputs = model(inputs) 95 | outputs[0, k].backward() 96 | depends = inputs.grad.data[0] != 0.0 97 | dim = k // output_multiplier 98 | self.assertEqual(torch.all(depends[dim:] == 0), 1) 99 | 100 | def test_total_mask_sequential(self): 101 | features = 10 102 | hidden_features = 50 103 | num_blocks = 5 104 | output_multiplier = 1 105 | 106 | for use_residual_blocks in [True, False]: 107 | with self.subTest(use_residual_blocks=use_residual_blocks): 108 | model = made.MADE( 109 | features=features, 110 | hidden_features=hidden_features, 111 | num_blocks=num_blocks, 112 | output_multiplier=output_multiplier, 113 | use_residual_blocks=use_residual_blocks, 114 | random_mask=False, 115 | ) 116 | total_mask = model.initial_layer.mask 117 | for block in model.blocks: 118 | if use_residual_blocks: 119 | self.assertIsInstance(block, made.MaskedResidualBlock) 120 | total_mask = block.linear_layers[0].mask @ total_mask 121 | total_mask = block.linear_layers[1].mask @ total_mask 122 | else: 123 | self.assertIsInstance(block, made.MaskedFeedforwardBlock) 124 | total_mask = block.linear.mask @ total_mask 125 | total_mask = model.final_layer.mask @ total_mask 126 | total_mask = (total_mask > 0).float() 127 | reference = torch.tril(torch.ones([features, features]), -1) 128 | self.assertEqual(total_mask, reference) 129 | 130 | def test_total_mask_random(self): 131 | features = 10 132 | hidden_features = 50 133 | num_blocks = 5 134 | output_multiplier = 1 135 | 136 | model = made.MADE( 137 | features=features, 138 | hidden_features=hidden_features, 139 | num_blocks=num_blocks, 140 | output_multiplier=output_multiplier, 141 | use_residual_blocks=False, 142 | random_mask=True, 143 | ) 144 | total_mask = model.initial_layer.mask 145 | for block in model.blocks: 146 | self.assertIsInstance(block, made.MaskedFeedforwardBlock) 147 | total_mask = block.linear.mask @ total_mask 148 | total_mask = model.final_layer.mask @ total_mask 149 | total_mask = (total_mask > 0).float() 150 | self.assertEqual(torch.triu(total_mask), torch.zeros([features, features])) 151 | 152 | 153 | if __name__ == '__main__': 154 | unittest.main() 155 | -------------------------------------------------------------------------------- /nde/transforms/nonlinearities_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the invertible non-linearities.""" 2 | 3 | import torch 4 | import unittest 5 | 6 | import nde.transforms.splines.quadratic 7 | from nde.transforms import base, nonlinearities as nl, standard 8 | from nde.transforms.transform_test import TransformTest 9 | 10 | 11 | class TanhTest(TransformTest): 12 | def test_raises_domain_exception(self): 13 | shape = [2, 3, 4] 
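# tanh maps the real line onto the open interval (-1, 1), so its inverse (atanh) is
# undefined for |x| >= 1; the boundary and out-of-range values below are therefore
# expected to raise InputOutsideDomain.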
14 | transform = nl.Tanh() 15 | for value in [-2.0, -1.0, 1.0, 2.0]: 16 | with self.assertRaises(base.InputOutsideDomain): 17 | inputs = torch.full(shape, value) 18 | transform.inverse(inputs) 19 | 20 | 21 | class TestPiecewiseCDF(TransformTest): 22 | def setUp(self): 23 | self.shape = [2, 3, 4] 24 | self.batch_size = 10 25 | self.transforms = [nl.PiecewiseLinearCDF(self.shape), 26 | nl.PiecewiseQuadraticCDF(self.shape), 27 | nl.PiecewiseCubicCDF(self.shape), 28 | nl.PiecewiseRationalQuadraticCDF(self.shape)] 29 | 30 | def test_raises_domain_exception(self): 31 | for transform in self.transforms: 32 | with self.subTest(transform=transform): 33 | for value in [-1.0, -0.1, 1.1, 2.0]: 34 | with self.assertRaises(base.InputOutsideDomain): 35 | inputs = torch.full([self.batch_size] + self.shape, value) 36 | transform.forward(inputs) 37 | 38 | def test_zeros_to_zeros(self): 39 | for transform in self.transforms: 40 | with self.subTest(transform=transform): 41 | inputs = torch.zeros(self.batch_size, *self.shape) 42 | outputs, _ = transform(inputs) 43 | self.eps = 1e-5 44 | self.assertEqual(outputs, inputs) 45 | 46 | def test_ones_to_ones(self): 47 | for transform in self.transforms: 48 | with self.subTest(transform=transform): 49 | inputs = torch.ones(self.batch_size, *self.shape) 50 | outputs, _ = transform(inputs) 51 | self.eps = 1e-5 52 | self.assertEqual(outputs, inputs) 53 | 54 | def test_forward_inverse_are_consistent(self): 55 | for transform in self.transforms: 56 | with self.subTest(transform=transform): 57 | inputs = torch.rand(self.batch_size, *self.shape) 58 | self.eps = 1e-4 59 | self.assert_forward_inverse_are_consistent(transform, inputs) 60 | 61 | class TestUnconstrainedPiecewiseCDF(TransformTest): 62 | def test_forward_inverse_are_consistent(self): 63 | shape = [2,3,4] 64 | batch_size = 10 65 | transforms = [nl.PiecewiseLinearCDF(shape, tails='linear'), 66 | nl.PiecewiseQuadraticCDF(shape, tails='linear'), 67 | nl.PiecewiseCubicCDF(shape, tails='linear'), 68 | nl.PiecewiseRationalQuadraticCDF(shape, tails='linear')] 69 | 70 | for transform in transforms: 71 | with self.subTest(transform=transform): 72 | inputs = 3 * torch.randn(batch_size, *shape) 73 | self.eps = 1e-4 74 | self.assert_forward_inverse_are_consistent(transform, inputs) 75 | 76 | 77 | class LogitTest(TransformTest): 78 | def test_forward_zero_and_one(self): 79 | batch_size = 10 80 | shape = [5, 10, 15] 81 | inputs = torch.cat([torch.zeros(batch_size // 2, *shape), 82 | torch.ones(batch_size // 2, *shape)]) 83 | 84 | transform = nl.Logit() 85 | outputs, logabsdet = transform(inputs) 86 | 87 | self.assert_tensor_is_good(outputs) 88 | self.assert_tensor_is_good(logabsdet) 89 | 90 | 91 | class NonlinearitiesTest(TransformTest): 92 | 93 | def test_forward(self): 94 | batch_size = 10 95 | shape = [5, 10, 15] 96 | inputs = torch.rand(batch_size, *shape) 97 | transforms = [ 98 | nl.Tanh(), 99 | nl.LogTanh(), 100 | nl.LeakyReLU(), 101 | nl.Sigmoid(), 102 | nl.Logit(), 103 | nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()) 104 | ] 105 | for transform in transforms: 106 | with self.subTest(transform=transform): 107 | outputs, logabsdet = transform(inputs) 108 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 109 | self.assert_tensor_is_good(logabsdet, [batch_size]) 110 | 111 | def test_inverse(self): 112 | batch_size = 10 113 | shape = [5, 10, 15] 114 | inputs = torch.rand(batch_size, *shape) 115 | transforms = [ 116 | nl.Tanh(), 117 | nl.LogTanh(), 118 | nl.LeakyReLU(), 119 | nl.Sigmoid(), 120 | 
nl.Logit(), 121 | nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()) 122 | ] 123 | for transform in transforms: 124 | with self.subTest(transform=transform): 125 | outputs, logabsdet = transform.inverse(inputs) 126 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 127 | self.assert_tensor_is_good(logabsdet, [batch_size]) 128 | 129 | def test_forward_inverse_are_consistent(self): 130 | batch_size = 10 131 | shape = [5, 10, 15] 132 | inputs = torch.rand(batch_size, *shape) 133 | transforms = [ 134 | nl.Tanh(), 135 | nl.LogTanh(), 136 | nl.LeakyReLU(), 137 | nl.Sigmoid(), 138 | nl.Logit(), 139 | nl.CompositeCDFTransform(nl.Sigmoid(), standard.IdentityTransform()) 140 | ] 141 | self.eps = 1e-3 142 | for transform in transforms: 143 | with self.subTest(transform=transform): 144 | self.assert_forward_inverse_are_consistent(transform, inputs) 145 | 146 | 147 | if __name__ == '__main__': 148 | unittest.main() 149 | -------------------------------------------------------------------------------- /nde/transforms/normalization_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the normalization-based transforms.""" 2 | 3 | import torch 4 | import unittest 5 | 6 | from nde.transforms import base 7 | from nde.transforms import normalization as norm 8 | from nde.transforms.transform_test import TransformTest 9 | 10 | 11 | class BatchNormTest(TransformTest): 12 | 13 | def test_forward(self): 14 | features = 100 15 | batch_size = 50 16 | bn_eps = 1e-5 17 | self.eps = 1e-4 18 | 19 | for affine in [True, False]: 20 | with self.subTest(affine=affine): 21 | inputs = torch.randn(batch_size, features) 22 | transform = norm.BatchNorm(features=features, affine=affine, eps=bn_eps) 23 | 24 | outputs, logabsdet = transform(inputs) 25 | self.assert_tensor_is_good(outputs, [batch_size, features]) 26 | self.assert_tensor_is_good(logabsdet, [batch_size]) 27 | 28 | mean, var = inputs.mean(0), inputs.var(0) 29 | outputs_ref = (inputs - mean) / torch.sqrt(var + bn_eps) 30 | logabsdet_ref = torch.sum(torch.log(1.0 / torch.sqrt(var + bn_eps))) 31 | logabsdet_ref = torch.full([batch_size], logabsdet_ref.item()) 32 | if affine: 33 | outputs_ref *= transform.weight 34 | outputs_ref += transform.bias 35 | logabsdet_ref += torch.sum(torch.log(transform.weight)) 36 | self.assert_tensor_is_good(outputs_ref, [batch_size, features]) 37 | self.assert_tensor_is_good(logabsdet_ref, [batch_size]) 38 | 39 | self.assertEqual(outputs, outputs_ref) 40 | self.assertEqual(logabsdet, logabsdet_ref) 41 | 42 | transform.eval() 43 | outputs, logabsdet = transform(inputs) 44 | self.assert_tensor_is_good(outputs, [batch_size, features]) 45 | self.assert_tensor_is_good(logabsdet, [batch_size]) 46 | 47 | mean = transform.running_mean 48 | var = transform.running_var 49 | outputs_ref = (inputs - mean) / torch.sqrt(var + bn_eps) 50 | logabsdet_ref = torch.sum(torch.log(1.0 / torch.sqrt(var + bn_eps))) 51 | logabsdet_ref = torch.full([batch_size], logabsdet_ref.item()) 52 | if affine: 53 | outputs_ref *= transform.weight 54 | outputs_ref += transform.bias 55 | logabsdet_ref += torch.sum(torch.log(transform.weight)) 56 | self.assert_tensor_is_good(outputs_ref, [batch_size, features]) 57 | self.assert_tensor_is_good(logabsdet_ref, [batch_size]) 58 | self.assertEqual(outputs, outputs_ref) 59 | self.assertEqual(logabsdet, logabsdet_ref) 60 | 61 | def test_inverse(self): 62 | features = 100 63 | batch_size = 50 64 | inputs = torch.randn(batch_size, 
features) 65 | 66 | for affine in [True, False]: 67 | with self.subTest(affine=affine): 68 | transform = norm.BatchNorm(features=features, affine=affine) 69 | with self.assertRaises(base.InverseNotAvailable): 70 | transform.inverse(inputs) 71 | transform.eval() 72 | outputs, logabsdet = transform.inverse(inputs) 73 | self.assert_tensor_is_good(outputs, [batch_size, features]) 74 | self.assert_tensor_is_good(logabsdet, [batch_size]) 75 | 76 | def test_forward_inverse_are_consistent(self): 77 | features = 100 78 | batch_size = 50 79 | inputs = torch.randn(batch_size, features) 80 | transforms = [ 81 | norm.BatchNorm(features=features, affine=affine) 82 | for affine in [True, False] 83 | ] 84 | self.eps = 1e-6 85 | for transform in transforms: 86 | with self.subTest(transform=transform): 87 | transform.eval() 88 | self.assert_forward_inverse_are_consistent(transform, inputs) 89 | 90 | 91 | class ActNormTest(TransformTest): 92 | 93 | def test_forward(self): 94 | batch_size = 50 95 | for shape in [(100,), 96 | (32,8,8)]: 97 | with self.subTest(shape=shape): 98 | inputs = torch.randn(batch_size, *shape) 99 | transform = norm.ActNorm(shape[0]) 100 | 101 | outputs, logabsdet = transform.forward(inputs) 102 | self.assert_tensor_is_good(outputs, [batch_size] + list(shape)) 103 | self.assert_tensor_is_good(logabsdet, [batch_size]) 104 | 105 | def test_inverse(self): 106 | batch_size = 50 107 | for shape in [(100,), 108 | (32,8,8)]: 109 | with self.subTest(shape=shape): 110 | inputs = torch.randn(batch_size, *shape) 111 | transform = norm.ActNorm(shape[0]) 112 | 113 | outputs, logabsdet = transform.inverse(inputs) 114 | self.assert_tensor_is_good(outputs, [batch_size] + list(shape)) 115 | self.assert_tensor_is_good(logabsdet, [batch_size]) 116 | 117 | def test_forward_inverse_are_consistent(self): 118 | batch_size = 50 119 | for shape in [(100,), 120 | (32,8,8)]: 121 | with self.subTest(shape=shape): 122 | inputs = torch.randn(batch_size, *shape) 123 | transform = norm.ActNorm(shape[0]) 124 | transform.forward(inputs) # One forward pass to initialize 125 | self.eps = 1e-6 126 | self.assert_forward_inverse_are_consistent(transform, inputs) 127 | 128 | 129 | if __name__ == '__main__': 130 | unittest.main() 131 | -------------------------------------------------------------------------------- /nde/transforms/orthogonal.py: -------------------------------------------------------------------------------- 1 | """Implementations of orthogonal transforms.""" 2 | 3 | import torch 4 | 5 | from torch import nn 6 | 7 | import utils 8 | 9 | from nde import transforms 10 | 11 | 12 | class HouseholderSequence(transforms.Transform): 13 | """A sequence of Householder transforms. 14 | 15 | This class can be used as a way of parameterizing an orthogonal matrix. 16 | """ 17 | 18 | def __init__(self, features, num_transforms): 19 | """Constructor. 20 | 21 | Args: 22 | features: int, dimensionality of the input. 23 | num_transforms: int, number of Householder transforms to use. 24 | 25 | Raises: 26 | TypeError: if arguments are not the right type. 27 | """ 28 | if not utils.is_positive_int(features): 29 | raise TypeError('Number of features must be a positive integer.') 30 | if not utils.is_positive_int(num_transforms): 31 | raise TypeError('Number of transforms must be a positive integer.') 32 | 33 | super().__init__() 34 | self.features = features 35 | self.num_transforms = num_transforms 36 | # TODO: are randn good initial values? 
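# Each q_vector q below defines a Householder reflection H = I - 2 q q^T / ||q||^2,
# which is orthogonal with determinant -1; composing num_transforms of them yields an
# orthogonal matrix, which is why the log absolute determinant of the sequence is always zero.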
37 | self.q_vectors = nn.Parameter(torch.randn(num_transforms, features)) 38 | 39 | @staticmethod 40 | def _apply_transforms(inputs, q_vectors): 41 | """Apply the sequence of transforms parameterized by given q_vectors to inputs. 42 | 43 | Costs O(KDN), where: 44 | - K is number of transforms 45 | - D is dimensionality of inputs 46 | - N is number of inputs 47 | 48 | Args: 49 | inputs: Tensor of shape [N, D] 50 | q_vectors: Tensor of shape [K, D] 51 | 52 | Returns: 53 | A tuple of: 54 | - A Tensor of shape [N, D], the outputs. 55 | - A Tensor of shape [N], the log absolute determinants of the total transform. 56 | """ 57 | squared_norms = torch.sum(q_vectors ** 2, dim=-1) 58 | outputs = inputs 59 | for q_vector, squared_norm in zip(q_vectors, squared_norms): 60 | temp = outputs @ q_vector # Inner product. 61 | temp = torch.ger(temp, (2.0 / squared_norm) * q_vector) # Outer product. 62 | outputs = outputs - temp 63 | batch_size = inputs.shape[0] 64 | logabsdet = torch.zeros(batch_size) 65 | return outputs, logabsdet 66 | 67 | def forward(self, inputs, context=None): 68 | return self._apply_transforms(inputs, self.q_vectors) 69 | 70 | def inverse(self, inputs, context=None): 71 | # Each householder transform is its own inverse, so the total inverse is given by 72 | # simply performing each transform in the reverse order. 73 | reverse_idx = torch.arange(self.num_transforms - 1, -1, -1) 74 | return self._apply_transforms(inputs, self.q_vectors[reverse_idx]) 75 | 76 | def matrix(self): 77 | """Returns the orthogonal matrix that is equivalent to the total transform. 78 | 79 | Costs O(KD^2), where: 80 | - K is number of transforms 81 | - D is dimensionality of inputs 82 | 83 | Returns: 84 | A Tensor of shape [D, D]. 85 | """ 86 | identity = torch.eye(self.features, self.features) 87 | outputs, _ = self.inverse(identity) 88 | return outputs 89 | -------------------------------------------------------------------------------- /nde/transforms/orthogonal_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the orthogonal transforms.""" 2 | 3 | import torch 4 | import unittest 5 | 6 | import utils 7 | 8 | from nde.transforms import orthogonal 9 | from nde.transforms.transform_test import TransformTest 10 | 11 | 12 | class HouseholderSequenceTest(TransformTest): 13 | 14 | def test_forward(self): 15 | features = 100 16 | batch_size = 50 17 | 18 | for num_transforms in [1, 2, 11, 12]: 19 | with self.subTest(num_transforms=num_transforms): 20 | transform = orthogonal.HouseholderSequence( 21 | features=features, num_transforms=num_transforms) 22 | matrix = transform.matrix() 23 | inputs = torch.randn(batch_size, features) 24 | outputs, logabsdet = transform(inputs) 25 | self.assert_tensor_is_good(outputs, [batch_size, features]) 26 | self.assert_tensor_is_good(logabsdet, [batch_size]) 27 | self.eps = 1e-5 28 | self.assertEqual(outputs, inputs @ matrix.t()) 29 | self.assertEqual(logabsdet, utils.logabsdet(matrix) * torch.ones(batch_size)) 30 | 31 | def test_inverse(self): 32 | features = 100 33 | batch_size = 50 34 | 35 | for num_transforms in [1, 2, 11, 12]: 36 | with self.subTest(num_transforms=num_transforms): 37 | transform = orthogonal.HouseholderSequence( 38 | features=features, num_transforms=num_transforms) 39 | matrix = transform.matrix() 40 | inputs = torch.randn(batch_size, features) 41 | outputs, logabsdet = transform.inverse(inputs) 42 | self.assert_tensor_is_good(outputs, [batch_size, features]) 43 | self.assert_tensor_is_good(logabsdet, 
[batch_size]) 44 | self.eps = 1e-5 45 | self.assertEqual(outputs, inputs @ matrix) 46 | self.assertEqual(logabsdet, utils.logabsdet(matrix) * torch.ones(batch_size)) 47 | 48 | def test_matrix(self): 49 | features = 100 50 | 51 | for num_transforms in [1, 2, 11, 12]: 52 | with self.subTest(num_transforms=num_transforms): 53 | transform = orthogonal.HouseholderSequence( 54 | features=features, num_transforms=num_transforms) 55 | matrix = transform.matrix() 56 | self.assert_tensor_is_good(matrix, [features, features]) 57 | self.eps = 1e-5 58 | self.assertEqual(matrix @ matrix.t(), torch.eye(features, features)) 59 | self.assertEqual(matrix.t() @ matrix, torch.eye(features, features)) 60 | self.assertEqual(matrix.t(), torch.inverse(matrix)) 61 | det_ref = torch.tensor(1.0 if num_transforms % 2 == 0 else -1.0) 62 | self.assertEqual(matrix.det(), det_ref) 63 | 64 | def test_forward_inverse_are_consistent(self): 65 | features = 100 66 | batch_size = 50 67 | inputs = torch.randn(batch_size, features) 68 | transforms = [ 69 | orthogonal.HouseholderSequence( 70 | features=features, num_transforms=num_transforms) 71 | for num_transforms in [1, 2, 11, 12] 72 | ] 73 | self.eps = 1e-5 74 | for transform in transforms: 75 | with self.subTest(transform=transform): 76 | self.assert_forward_inverse_are_consistent(transform, inputs) 77 | 78 | 79 | if __name__ == '__main__': 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /nde/transforms/permutations.py: -------------------------------------------------------------------------------- 1 | """Implementations of permutation-like transforms.""" 2 | 3 | import torch 4 | import utils 5 | 6 | from nde import transforms 7 | 8 | 9 | class Permutation(transforms.Transform): 10 | """Permutes inputs on a given dimension using a given permutation.""" 11 | 12 | def __init__(self, permutation, dim=1): 13 | if permutation.ndimension() != 1: 14 | raise ValueError('Permutation must be a 1D tensor.') 15 | if not utils.is_positive_int(dim): 16 | raise ValueError('dim must be a positive integer.') 17 | 18 | super().__init__() 19 | self._dim = dim 20 | self.register_buffer('_permutation', permutation) 21 | 22 | @property 23 | def _inverse_permutation(self): 24 | return torch.argsort(self._permutation) 25 | 26 | @staticmethod 27 | def _permute(inputs, permutation, dim): 28 | if dim >= inputs.ndimension(): 29 | raise ValueError("No dimension {} in inputs.".format(dim)) 30 | if inputs.shape[dim] != len(permutation): 31 | raise ValueError("Dimension {} in inputs must be of size {}." 32 | .format(dim, len(permutation))) 33 | batch_size = inputs.shape[0] 34 | outputs = torch.index_select(inputs, dim, permutation) 35 | logabsdet = torch.zeros(batch_size) 36 | return outputs, logabsdet 37 | 38 | def forward(self, inputs, context=None): 39 | return self._permute(inputs, self._permutation, self._dim) 40 | 41 | def inverse(self, inputs, context=None): 42 | return self._permute(inputs, self._inverse_permutation, self._dim) 43 | 44 | 45 | class RandomPermutation(Permutation): 46 | """Permutes using a random, but fixed, permutation. Only works with 1D inputs.""" 47 | 48 | def __init__(self, features, dim=1): 49 | if not utils.is_positive_int(features): 50 | raise ValueError('Number of features must be a positive integer.') 51 | super().__init__(torch.randperm(features), dim) 52 | 53 | 54 | class ReversePermutation(Permutation): 55 | """Reverses the elements of the input. 
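For example, features=4 gives the fixed permutation [3, 2, 1, 0].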
Only works with 1D inputs.""" 56 | 57 | def __init__(self, features, dim=1): 58 | if not utils.is_positive_int(features): 59 | raise ValueError('Number of features must be a positive integer.') 60 | super().__init__(torch.arange(features - 1, -1, -1), dim) 61 | -------------------------------------------------------------------------------- /nde/transforms/permutations_test.py: -------------------------------------------------------------------------------- 1 | """Tests for permutations.""" 2 | 3 | import torch 4 | import unittest 5 | 6 | from nde.transforms import permutations 7 | from nde.transforms.transform_test import TransformTest 8 | 9 | 10 | class PermutationTest(TransformTest): 11 | 12 | def test_forward(self): 13 | batch_size = 10 14 | features = 100 15 | inputs = torch.randn(batch_size, features) 16 | permutation = torch.randperm(features) 17 | transform = permutations.Permutation(permutation) 18 | outputs, logabsdet = transform(inputs) 19 | self.assert_tensor_is_good(outputs, [batch_size, features]) 20 | self.assert_tensor_is_good(logabsdet, [batch_size]) 21 | self.assertEqual(outputs, inputs[:, permutation]) 22 | self.assertEqual(logabsdet, torch.zeros([batch_size])) 23 | 24 | def test_inverse(self): 25 | batch_size = 10 26 | features = 100 27 | inputs = torch.randn(batch_size, features) 28 | permutation = torch.randperm(features) 29 | transform = permutations.Permutation(permutation) 30 | temp, _ = transform(inputs) 31 | outputs, logabsdet = transform.inverse(temp) 32 | self.assert_tensor_is_good(outputs, [batch_size, features]) 33 | self.assert_tensor_is_good(logabsdet, [batch_size]) 34 | self.assertEqual(outputs, inputs) 35 | self.assertEqual(logabsdet, torch.zeros([batch_size])) 36 | 37 | def test_forward_inverse_are_consistent(self): 38 | batch_size = 10 39 | features = 100 40 | inputs = torch.randn(batch_size, features) 41 | transforms = [ 42 | permutations.Permutation(torch.randperm(features)), 43 | permutations.RandomPermutation(features), 44 | permutations.ReversePermutation(features), 45 | ] 46 | for transform in transforms: 47 | with self.subTest(transform=transform): 48 | self.assert_forward_inverse_are_consistent(transform, inputs) 49 | 50 | 51 | if __name__ == '__main__': 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /nde/transforms/qr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch import nn 5 | from torch.nn import functional as F, init 6 | 7 | from nde import transforms 8 | from nde.transforms.linear import Linear 9 | 10 | 11 | class QRLinear(Linear): 12 | """A linear module using the QR decomposition for the weight matrix.""" 13 | 14 | def __init__(self, features, num_householder, using_cache=False): 15 | super().__init__(features, using_cache) 16 | 17 | # Parameterization for R 18 | self.upper_indices = np.triu_indices(features, k=1) 19 | self.diag_indices = np.diag_indices(features) 20 | n_triangular_entries = ((features - 1) * features) // 2 21 | self.upper_entries = nn.Parameter(torch.zeros(n_triangular_entries)) 22 | self.log_upper_diag = nn.Parameter(torch.zeros(features)) 23 | 24 | # Parameterization for Q 25 | self.orthogonal = transforms.HouseholderSequence( 26 | features=features, num_transforms=num_householder) 27 | 28 | self._initialize() 29 | 30 | def _initialize(self): 31 | stdv = 1.0 / np.sqrt(self.features) 32 | init.uniform_(self.upper_entries, -stdv, stdv) 33 | 
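# log_upper_diag holds the log of R's diagonal, so values near zero initialize the
# diagonal of R close to 1; with Q orthogonal, log|det W| for W = Q R is simply
# sum(log_upper_diag), as returned by logabsdet() below.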
init.uniform_(self.log_upper_diag, -stdv, stdv) 34 | init.constant_(self.bias, 0.0) 35 | 36 | def _create_upper(self): 37 | upper = torch.zeros(self.features, self.features) 38 | upper[self.upper_indices[0], self.upper_indices[1]] = self.upper_entries 39 | upper[self.diag_indices[0], self.diag_indices[1]] = torch.exp(self.log_upper_diag) 40 | return upper 41 | 42 | def forward_no_cache(self, inputs): 43 | """Cost: 44 | output = O(D^2N + KDN) 45 | logabsdet = O(D) 46 | where: 47 | K = num of householder transforms 48 | D = num of features 49 | N = num of inputs 50 | """ 51 | upper = self._create_upper() 52 | 53 | outputs = F.linear(inputs, upper) 54 | outputs, _ = self.orthogonal(outputs) # Ignore logabsdet as we know it's zero. 55 | outputs += self.bias 56 | 57 | logabsdet = self.logabsdet() * torch.ones(outputs.shape[0]) 58 | 59 | return outputs, logabsdet 60 | 61 | def inverse_no_cache(self, inputs): 62 | """Cost: 63 | output = O(D^2N + KDN) 64 | logabsdet = O(D) 65 | where: 66 | K = num of householder transforms 67 | D = num of features 68 | N = num of inputs 69 | """ 70 | upper = self._create_upper() 71 | outputs = inputs - self.bias 72 | outputs, _ = self.orthogonal.inverse(outputs) # Ignore logabsdet since we know it's zero. 73 | outputs, _ = torch.trtrs(outputs.t(), upper, upper=True) 74 | outputs = outputs.t() 75 | logabsdet = -self.logabsdet() 76 | logabsdet = logabsdet * torch.ones(outputs.shape[0]) 77 | return outputs, logabsdet 78 | 79 | def weight(self): 80 | """Cost: 81 | weight = O(KD^2) 82 | where: 83 | K = num of householder transforms 84 | D = num of features 85 | """ 86 | upper = self._create_upper() 87 | weight, _ = self.orthogonal(upper.t()) 88 | return weight.t() 89 | 90 | def weight_inverse(self): 91 | """Cost: 92 | inverse = O(D^3 + KD^2) 93 | where: 94 | K = num of householder transforms 95 | D = num of features 96 | """ 97 | upper = self._create_upper() 98 | identity = torch.eye(self.features, self.features) 99 | upper_inv, _ = torch.trtrs(identity, upper, upper=True) 100 | weight_inv, _ = self.orthogonal(upper_inv) 101 | return weight_inv 102 | 103 | def logabsdet(self): 104 | """Cost: 105 | logabsdet = O(D) 106 | where: 107 | D = num of features 108 | """ 109 | return torch.sum(self.log_upper_diag) 110 | -------------------------------------------------------------------------------- /nde/transforms/qr_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | import utils 5 | 6 | from nde.transforms import qr 7 | from nde.transforms.transform_test import TransformTest 8 | 9 | 10 | class QRLinearTest(TransformTest): 11 | 12 | def setUp(self): 13 | self.features = 3 14 | self.transform = qr.QRLinear(features=self.features, num_householder=4) 15 | 16 | upper = self.transform._create_upper() 17 | orthogonal = self.transform.orthogonal.matrix() 18 | self.weight = orthogonal @ upper 19 | self.weight_inverse = torch.inverse(self.weight) 20 | self.logabsdet = utils.logabsdet(self.weight) 21 | 22 | self.eps = 1e-5 23 | 24 | def test_forward_no_cache(self): 25 | batch_size = 10 26 | inputs = torch.randn(batch_size, self.features) 27 | outputs, logabsdet = self.transform.forward_no_cache(inputs) 28 | 29 | outputs_ref = torch.matmul(inputs, self.weight.t()) + self.transform.bias 30 | logabsdet_ref = torch.full([batch_size], self.logabsdet.item()) 31 | 32 | self.assert_tensor_is_good(outputs, [batch_size, self.features]) 33 | self.assert_tensor_is_good(logabsdet, [batch_size]) 34 | 35 | 
self.assertEqual(outputs, outputs_ref) 36 | self.assertEqual(logabsdet, logabsdet_ref) 37 | 38 | def test_inverse_no_cache(self): 39 | batch_size = 10 40 | inputs = torch.randn(batch_size, self.features) 41 | outputs, logabsdet = self.transform.inverse_no_cache(inputs) 42 | 43 | outputs_ref = (inputs - self.transform.bias) @ self.weight_inverse.t() 44 | logabsdet_ref = torch.full([batch_size], -self.logabsdet.item()) 45 | 46 | self.assert_tensor_is_good(outputs, [batch_size, self.features]) 47 | self.assert_tensor_is_good(logabsdet, [batch_size]) 48 | 49 | self.assertEqual(outputs, outputs_ref) 50 | self.assertEqual(logabsdet, logabsdet_ref) 51 | 52 | def test_weight(self): 53 | weight = self.transform.weight() 54 | self.assert_tensor_is_good(weight, [self.features, self.features]) 55 | self.assertEqual(weight, self.weight) 56 | 57 | def test_weight_inverse(self): 58 | weight_inverse = self.transform.weight_inverse() 59 | self.assert_tensor_is_good(weight_inverse, [self.features, self.features]) 60 | self.assertEqual(weight_inverse, self.weight_inverse) 61 | 62 | def test_logabsdet(self): 63 | logabsdet = self.transform.logabsdet() 64 | self.assert_tensor_is_good(logabsdet, []) 65 | self.assertEqual(logabsdet, self.logabsdet) 66 | 67 | def test_forward_inverse_are_consistent(self): 68 | batch_size = 10 69 | inputs = torch.randn(batch_size, self.features) 70 | self.assert_forward_inverse_are_consistent(self.transform, inputs) 71 | 72 | 73 | if __name__ == '__main__': 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /nde/transforms/reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utils 4 | from nde import transforms 5 | 6 | 7 | class SqueezeTransform(transforms.Transform): 8 | """A transformation defined for image data that trades spatial dimensions for channel 9 | dimensions, i.e. "squeezes" the inputs along the channel dimensions. 10 | 11 | Implementation adapted from https://github.com/pclucas14/pytorch-glow and 12 | https://github.com/chaiyujin/glow-pytorch. 13 | 14 | Reference: 15 | > L. Dinh et al., Density estimation using Real NVP, ICLR 2017. 
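With factor=2, inputs of shape [B, C, H, W] are mapped to [B, 4 * C, H / 2, W / 2]:
each position within a 2x2 spatial block is sent to its own channel. The transform is
volume-preserving, so its log absolute determinant is zero.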
16 | """ 17 | def __init__(self, factor=2): 18 | super(SqueezeTransform, self).__init__() 19 | 20 | if not utils.is_int(factor) or factor <= 1: 21 | raise ValueError('Factor must be an integer > 1.') 22 | 23 | self.factor = factor 24 | 25 | def get_output_shape(self, c, h, w): 26 | return (c * self.factor * self.factor, 27 | h // self.factor, 28 | w // self.factor) 29 | 30 | def forward(self, inputs, context=None): 31 | if inputs.dim() != 4: 32 | raise ValueError('Expecting inputs with 4 dimensions') 33 | 34 | batch_size, c, h, w = inputs.size() 35 | 36 | if h % self.factor != 0 or w % self.factor != 0: 37 | raise ValueError('Input image size not compatible with the factor.') 38 | 39 | inputs = inputs.view(batch_size, c, h // self.factor, self.factor, w // self.factor, 40 | self.factor) 41 | inputs = inputs.permute(0, 1, 3, 5, 2, 4).contiguous() 42 | inputs = inputs.view(batch_size, c * self.factor * self.factor, h // self.factor, 43 | w // self.factor) 44 | 45 | return inputs, torch.zeros(batch_size) 46 | 47 | def inverse(self, inputs, context=None): 48 | if inputs.dim() != 4: 49 | raise ValueError('Expecting inputs with 4 dimensions') 50 | 51 | batch_size, c, h, w = inputs.size() 52 | 53 | if c < 4 or c % 4 != 0: 54 | raise ValueError('Invalid number of channel dimensions.') 55 | 56 | inputs = inputs.view(batch_size, c // self.factor ** 2, self.factor, self.factor, h, w) 57 | inputs = inputs.permute(0, 1, 4, 2, 5, 3).contiguous() 58 | inputs = inputs.view(batch_size, c // self.factor ** 2, h * self.factor, w * self.factor) 59 | 60 | return inputs, torch.zeros(batch_size) 61 | 62 | class ReshapeTransform(transforms.Transform): 63 | def __init__(self, input_shape, output_shape): 64 | super().__init__() 65 | self.input_shape = input_shape 66 | self.output_shape = output_shape 67 | 68 | def forward(self, inputs, context=None): 69 | if tuple(inputs.shape[1:]) != self.input_shape: 70 | raise RuntimeError('Unexpected inputs shape ({}, but expecting {})' 71 | .format(tuple(inputs.shape[1:]), self.input_shape)) 72 | return inputs.reshape(-1, *self.output_shape), torch.zeros(inputs.shape[0]) 73 | 74 | def inverse(self, inputs, context=None): 75 | if tuple(inputs.shape[1:]) != self.output_shape: 76 | raise RuntimeError('Unexpected inputs shape ({}, but expecting {})' 77 | .format(tuple(inputs.shape[1:]), self.output_shape)) 78 | return inputs.reshape(-1, *self.input_shape), torch.zeros(inputs.shape[0]) 79 | -------------------------------------------------------------------------------- /nde/transforms/reshape_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from nde.transforms.transform_test import TransformTest 5 | from nde import transforms 6 | 7 | 8 | class SqueezeTransformTest(TransformTest): 9 | def setUp(self): 10 | self.transform = transforms.SqueezeTransform() 11 | 12 | def test_forward(self): 13 | batch_size = 10 14 | for shape in [[32, 4, 4], 15 | [16, 8, 8]]: 16 | with self.subTest(shape=shape): 17 | c, h, w = shape 18 | inputs = torch.randn(batch_size, c, h, w) 19 | outputs, logabsdet = self.transform(inputs) 20 | self.assert_tensor_is_good(outputs, [batch_size, c * 4, h // 2, w // 2]) 21 | self.assert_tensor_is_good(logabsdet, [batch_size]) 22 | self.assertEqual(logabsdet, torch.zeros(batch_size)) 23 | 24 | def test_forward_values(self): 25 | inputs = torch.arange(1, 17, 1).long().view(1, 1, 4, 4) 26 | outputs, _ = self.transform(inputs) 27 | 28 | def assert_channel_equal(channel, values): 29 | 
self.assertEqual(outputs[0, channel, ...], torch.LongTensor(values)) 30 | 31 | assert_channel_equal(0, [[1,3], 32 | [9,11]]) 33 | assert_channel_equal(1, [[2,4], 34 | [10,12]]) 35 | assert_channel_equal(2, [[5,7], 36 | [13,15]]) 37 | assert_channel_equal(3, [[6,8], 38 | [14,16]]) 39 | 40 | def test_forward_wrong_shape(self): 41 | batch_size = 10 42 | for shape in [[32, 3, 3], 43 | [32, 5, 5], 44 | [32, 4]]: 45 | with self.subTest(shape=shape): 46 | inputs = torch.randn(batch_size, *shape) 47 | with self.assertRaises(ValueError): 48 | self.transform(inputs) 49 | 50 | def test_forward_inverse_are_consistent(self): 51 | batch_size = 10 52 | for shape in [[32, 4, 4], 53 | [16, 8, 8]]: 54 | with self.subTest(shape=shape): 55 | c, h, w = shape 56 | inputs = torch.randn(batch_size, c, h, w) 57 | self.assert_forward_inverse_are_consistent(self.transform, inputs) 58 | 59 | def test_inverse_wrong_shape(self): 60 | batch_size = 10 61 | for shape in [[3, 4, 4], 62 | [33, 4, 4], 63 | [32, 4]]: 64 | with self.subTest(shape=shape): 65 | inputs = torch.randn(batch_size, *shape) 66 | with self.assertRaises(ValueError): 67 | self.transform.inverse(inputs) 68 | 69 | 70 | if __name__ == '__main__': 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /nde/transforms/splines/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import ( 2 | linear_spline, 3 | unconstrained_linear_spline 4 | ) 5 | 6 | from .quadratic import ( 7 | quadratic_spline, 8 | unconstrained_quadratic_spline 9 | ) 10 | 11 | from .cubic import ( 12 | cubic_spline, 13 | unconstrained_cubic_spline 14 | ) 15 | 16 | from .rational_quadratic import ( 17 | rational_quadratic_spline, 18 | unconstrained_rational_quadratic_spline 19 | ) 20 | 21 | -------------------------------------------------------------------------------- /nde/transforms/splines/cubic_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchtestcase 3 | 4 | from nde.transforms import splines 5 | 6 | 7 | class CubicSplineTest(torchtestcase.TorchTestCase): 8 | def test_forward_inverse_are_consistent(self): 9 | num_bins = 10 10 | shape = [2,3,4] 11 | 12 | unnormalized_widths = torch.randn(*shape, num_bins) 13 | unnormalized_heights = torch.randn(*shape, num_bins) 14 | unnorm_derivatives_left = torch.randn(*shape, 1) 15 | unnorm_derivatives_right = torch.randn(*shape, 1) 16 | 17 | def call_spline_fn(inputs, inverse=False): 18 | return splines.cubic_spline( 19 | inputs=inputs, 20 | unnormalized_widths=unnormalized_widths, 21 | unnormalized_heights=unnormalized_heights, 22 | unnorm_derivatives_left=unnorm_derivatives_left, 23 | unnorm_derivatives_right=unnorm_derivatives_right, 24 | inverse=inverse 25 | ) 26 | 27 | inputs = torch.rand(*shape) 28 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 29 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 30 | 31 | self.eps = 1e-4 32 | self.assertEqual(inputs, inputs_inv) 33 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 34 | 35 | class UnconstrainedCubicSplineTest(torchtestcase.TorchTestCase): 36 | def test_forward_inverse_are_consistent(self): 37 | num_bins = 10 38 | shape = [2,3,4] 39 | 40 | unnormalized_widths = torch.randn(*shape, num_bins) 41 | unnormalized_heights = torch.randn(*shape, num_bins) 42 | unnorm_derivatives_left = torch.randn(*shape, 1) 43 | unnorm_derivatives_right = torch.randn(*shape, 1) 44 
| 45 | def call_spline_fn(inputs, inverse=False): 46 | return splines.unconstrained_cubic_spline( 47 | inputs=inputs, 48 | unnormalized_widths=unnormalized_widths, 49 | unnormalized_heights=unnormalized_heights, 50 | unnorm_derivatives_left=unnorm_derivatives_left, 51 | unnorm_derivatives_right=unnorm_derivatives_right, 52 | inverse=inverse 53 | ) 54 | 55 | inputs = 3 * torch.randn(*shape) # Note inputs are outside [0,1]. 56 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 57 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 58 | 59 | self.eps = 1e-4 60 | self.assertEqual(inputs, inputs_inv) 61 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 62 | -------------------------------------------------------------------------------- /nde/transforms/splines/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | 7 | import utils 8 | from nde import transforms 9 | 10 | def unconstrained_linear_spline(inputs, unnormalized_pdf, 11 | inverse=False, 12 | tail_bound=1., 13 | tails='linear'): 14 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 15 | outside_interval_mask = ~inside_interval_mask 16 | 17 | outputs = torch.zeros_like(inputs) 18 | logabsdet = torch.zeros_like(inputs) 19 | 20 | if tails == 'linear': 21 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 22 | logabsdet[outside_interval_mask] = 0 23 | else: 24 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 25 | 26 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline( 27 | inputs=inputs[inside_interval_mask], 28 | unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :], 29 | inverse=inverse, 30 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound 31 | ) 32 | 33 | return outputs, logabsdet 34 | 35 | def linear_spline(inputs, unnormalized_pdf, 36 | inverse=False, 37 | left=0., right=1., bottom=0., top=1.): 38 | """ 39 | Reference: 40 | > Müller et al., Neural Importance Sampling, arXiv:1808.03856, 2018. 41 | """ 42 | if not inverse and (torch.min(inputs) < left or torch.max(inputs) > right): 43 | raise transforms.InputOutsideDomain() 44 | elif inverse and (torch.min(inputs) < bottom or torch.max(inputs) > top): 45 | raise transforms.InputOutsideDomain() 46 | 47 | if inverse: 48 | inputs = (inputs - bottom) / (top - bottom) 49 | else: 50 | inputs = (inputs - left) / (right - left) 51 | 52 | num_bins = unnormalized_pdf.size(-1) 53 | 54 | pdf = F.softmax(unnormalized_pdf, dim=-1) 55 | 56 | cdf = torch.cumsum(pdf, dim=-1) 57 | cdf[..., -1] = 1. 
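# Pinning the final value to 1 guards against floating-point drift in the cumulative sum;
# the pad below prepends a zero, so cdf holds num_bins + 1 knot values, one per bin edge.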
58 | cdf = F.pad(cdf, pad=(1, 0), mode='constant', value=0.0) 59 | 60 | if inverse: 61 | inv_bin_idx = utils.searchsorted(cdf, inputs) 62 | 63 | bin_boundaries = (torch.linspace(0, 1, num_bins+1) 64 | .view([1] * inputs.dim() + [-1]) 65 | .expand(*inputs.shape, -1)) 66 | 67 | slopes = ((cdf[..., 1:] - cdf[..., :-1]) 68 | / (bin_boundaries[..., 1:] - bin_boundaries[..., :-1])) 69 | offsets = cdf[..., 1:] - slopes * bin_boundaries[..., 1:] 70 | 71 | inv_bin_idx = inv_bin_idx.unsqueeze(-1) 72 | input_slopes = slopes.gather(-1, inv_bin_idx)[..., 0] 73 | input_offsets = offsets.gather(-1, inv_bin_idx)[..., 0] 74 | 75 | outputs = (inputs - input_offsets) / input_slopes 76 | outputs = torch.clamp(outputs, 0, 1) 77 | 78 | logabsdet = -torch.log(input_slopes) 79 | else: 80 | bin_pos = inputs * num_bins 81 | 82 | bin_idx = torch.floor(bin_pos).long() 83 | bin_idx[bin_idx >= num_bins] = num_bins - 1 84 | 85 | alpha = bin_pos - bin_idx.float() 86 | 87 | input_pdfs = pdf.gather(-1, bin_idx[..., None])[..., 0] 88 | 89 | outputs = cdf.gather(-1, bin_idx[..., None])[..., 0] 90 | outputs += alpha * input_pdfs 91 | outputs = torch.clamp(outputs, 0, 1) 92 | 93 | bin_width = 1.0 / num_bins 94 | logabsdet = torch.log(input_pdfs) - np.log(bin_width) 95 | 96 | if inverse: 97 | outputs = outputs * (right - left) + left 98 | logabsdet = logabsdet - math.log(top - bottom) + math.log(right - left) 99 | else: 100 | outputs = outputs * (top - bottom) + bottom 101 | logabsdet = logabsdet + math.log(top - bottom) - math.log(right - left) 102 | 103 | return outputs, logabsdet 104 | -------------------------------------------------------------------------------- /nde/transforms/splines/linear_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchtestcase 3 | 4 | from nde.transforms import splines 5 | 6 | class LinearSplineTest(torchtestcase.TorchTestCase): 7 | def test_forward_inverse_are_consistent(self): 8 | num_bins = 10 9 | shape = [2,3,4] 10 | 11 | unnormalized_pdf = torch.randn(*shape, num_bins) 12 | 13 | def call_spline_fn(inputs, inverse=False): 14 | return splines.linear_spline( 15 | inputs=inputs, 16 | unnormalized_pdf=unnormalized_pdf, 17 | inverse=inverse 18 | ) 19 | 20 | inputs = torch.rand(*shape) 21 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 22 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 23 | 24 | self.eps = 1e-4 25 | self.assertEqual(inputs, inputs_inv) 26 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 27 | 28 | class UnconstrainedLinearSplineTest(torchtestcase.TorchTestCase): 29 | def test_forward_inverse_are_consistent(self): 30 | num_bins = 10 31 | shape = [2,3,4] 32 | 33 | unnormalized_pdf = torch.randn(*shape, num_bins) 34 | 35 | def call_spline_fn(inputs, inverse=False): 36 | return splines.unconstrained_linear_spline( 37 | inputs=inputs, 38 | unnormalized_pdf=unnormalized_pdf, 39 | inverse=inverse 40 | ) 41 | 42 | inputs = 3 * torch.randn(*shape) # Note inputs are outside [0,1]. 
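# With tails='linear', points outside [-tail_bound, tail_bound] pass through unchanged
# (with logabsdet 0), so this check exercises both the spline body and its identity tails.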
43 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 44 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 45 | 46 | self.eps = 1e-4 47 | self.assertEqual(inputs, inputs_inv) 48 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 49 | -------------------------------------------------------------------------------- /nde/transforms/splines/quadratic.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | import utils 7 | from nde import transforms 8 | 9 | DEFAULT_MIN_BIN_WIDTH = 1e-3 10 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 11 | 12 | def unconstrained_quadratic_spline(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | inverse=False, 16 | tail_bound=1., 17 | tails='linear', 18 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 19 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT): 20 | 21 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 22 | outside_interval_mask = ~inside_interval_mask 23 | 24 | outputs = torch.zeros_like(inputs) 25 | logabsdet = torch.zeros_like(inputs) 26 | 27 | num_bins = unnormalized_widths.shape[-1] 28 | 29 | if tails == 'linear': 30 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 31 | logabsdet[outside_interval_mask] = 0 32 | assert unnormalized_heights.shape[-1] == num_bins - 1 33 | else: 34 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 35 | 36 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = quadratic_spline( 37 | inputs=inputs[inside_interval_mask], 38 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 39 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 40 | inverse=inverse, 41 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 42 | min_bin_width=min_bin_width, 43 | min_bin_height=min_bin_height 44 | ) 45 | 46 | return outputs, logabsdet 47 | 48 | def quadratic_spline(inputs, 49 | unnormalized_widths, 50 | unnormalized_heights, 51 | inverse=False, 52 | left=0., right=1., bottom=0., top=1., 53 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 54 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT): 55 | if not inverse and (torch.min(inputs) < left or torch.max(inputs) > right): 56 | raise transforms.InputOutsideDomain() 57 | elif inverse and (torch.min(inputs) < bottom or torch.max(inputs) > top): 58 | raise transforms.InputOutsideDomain() 59 | 60 | if inverse: 61 | inputs = (inputs - bottom) / (top - bottom) 62 | else: 63 | inputs = (inputs - left) / (right - left) 64 | 65 | num_bins = unnormalized_widths.shape[-1] 66 | 67 | if min_bin_width * num_bins > 1.0: 68 | raise ValueError('Minimal bin width too large for the number of bins') 69 | if min_bin_height * num_bins > 1.0: 70 | raise ValueError('Minimal bin height too large for the number of bins') 71 | 72 | widths = F.softmax(unnormalized_widths, dim=-1) 73 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 74 | 75 | unnorm_heights_exp = torch.exp(unnormalized_heights) 76 | 77 | if unnorm_heights_exp.shape[-1] == num_bins - 1: 78 | # Set boundary heights s.t. after normalization they are exactly 1. 
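# Only num_bins - 1 interior knot heights are given in this case; the same 'constant'
# value is appended at both ends and is chosen so that, once everything is divided by
# the total trapezoidal area below, the two boundary heights come out as exactly 1.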
79 | first_widths = 0.5 * widths[..., 0] 80 | last_widths = 0.5 * widths[..., -1] 81 | numerator = (0.5 * first_widths * unnorm_heights_exp[..., 0] 82 | + 0.5 * last_widths * unnorm_heights_exp[..., -1] 83 | + torch.sum(((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) 84 | * widths[..., 1:-1], dim=-1)) 85 | constant = numerator / (1 - 0.5 * first_widths - 0.5 * last_widths) 86 | constant = constant[..., None] 87 | unnorm_heights_exp = torch.cat([constant, unnorm_heights_exp, constant], dim=-1) 88 | 89 | unnormalized_area = torch.sum(((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) 90 | * widths, dim=-1)[..., None] 91 | heights = unnorm_heights_exp / unnormalized_area 92 | heights = min_bin_height + (1 - min_bin_height) * heights 93 | 94 | bin_left_cdf = torch.cumsum(((heights[..., :-1] + heights[..., 1:]) / 2) * widths, dim=-1) 95 | bin_left_cdf[..., -1] = 1. 96 | bin_left_cdf = F.pad(bin_left_cdf, pad=(1, 0), mode='constant', value=0.0) 97 | 98 | bin_locations = torch.cumsum(widths, dim=-1) 99 | bin_locations[..., -1] = 1. 100 | bin_locations = F.pad(bin_locations, pad=(1, 0), mode='constant', value=0.0) 101 | 102 | if inverse: 103 | bin_idx = utils.searchsorted(bin_left_cdf, inputs)[..., None] 104 | else: 105 | bin_idx = utils.searchsorted(bin_locations, inputs)[..., None] 106 | 107 | input_bin_locations = bin_locations.gather(-1, bin_idx)[..., 0] 108 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 109 | 110 | input_left_cdf = bin_left_cdf.gather(-1, bin_idx)[..., 0] 111 | 112 | input_left_heights = heights.gather(-1, bin_idx)[..., 0] 113 | input_right_heights = heights.gather(-1, bin_idx+1)[..., 0] 114 | 115 | a = 0.5 * (input_right_heights - input_left_heights) * input_bin_widths 116 | b = input_left_heights * input_bin_widths 117 | c = input_left_cdf 118 | 119 | if inverse: 120 | c_ = c - inputs 121 | alpha = (-b + torch.sqrt(b.pow(2) - 4*a*c_)) / (2*a) 122 | outputs = alpha * input_bin_widths + input_bin_locations 123 | outputs = torch.clamp(outputs, 0, 1) 124 | logabsdet = -torch.log((alpha * (input_right_heights - input_left_heights) 125 | + input_left_heights)) 126 | else: 127 | alpha = (inputs - input_bin_locations) / input_bin_widths 128 | outputs = a * alpha.pow(2) + b * alpha + c 129 | outputs = torch.clamp(outputs, 0, 1) 130 | logabsdet = torch.log((alpha * (input_right_heights - input_left_heights) 131 | + input_left_heights)) 132 | 133 | if inverse: 134 | outputs = outputs * (right - left) + left 135 | logabsdet = logabsdet - math.log(top - bottom) + math.log(right - left) 136 | else: 137 | outputs = outputs * (top - bottom) + bottom 138 | logabsdet = logabsdet + math.log(top - bottom) - math.log(right - left) 139 | 140 | return outputs, logabsdet 141 | -------------------------------------------------------------------------------- /nde/transforms/splines/quadratic_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchtestcase 3 | 4 | from nde.transforms import splines 5 | 6 | class QuadraticSplineTest(torchtestcase.TorchTestCase): 7 | def test_forward_inverse_are_consistent(self): 8 | num_bins = 10 9 | shape = [2,3,4] 10 | 11 | unnormalized_widths = torch.randn(*shape, num_bins) 12 | unnormalized_heights = torch.randn(*shape, num_bins + 1) 13 | 14 | def call_spline_fn(inputs, inverse=False): 15 | return splines.quadratic_spline( 16 | inputs=inputs, 17 | unnormalized_widths=unnormalized_widths, 18 | unnormalized_heights=unnormalized_heights, 19 | inverse=inverse 
20 | ) 21 | 22 | inputs = torch.rand(*shape) 23 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 24 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 25 | 26 | self.eps = 1e-4 27 | self.assertEqual(inputs, inputs_inv) 28 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 29 | 30 | class UnconstrainedQuadraticSplineTest(torchtestcase.TorchTestCase): 31 | def test_forward_inverse_are_consistent(self): 32 | num_bins = 10 33 | shape = [2,3,4] 34 | 35 | unnormalized_widths = torch.randn(*shape, num_bins) 36 | unnormalized_heights = torch.randn(*shape, num_bins - 1) 37 | 38 | def call_spline_fn(inputs, inverse=False): 39 | return splines.unconstrained_quadratic_spline( 40 | inputs=inputs, 41 | unnormalized_widths=unnormalized_widths, 42 | unnormalized_heights=unnormalized_heights, 43 | inverse=inverse 44 | ) 45 | 46 | inputs = 3 * torch.randn(*shape) # Note inputs are outside [0,1]. 47 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 48 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 49 | 50 | self.eps = 1e-4 51 | self.assertEqual(inputs, inputs_inv) 52 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 53 | -------------------------------------------------------------------------------- /nde/transforms/splines/rational_quadratic_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchtestcase 3 | 4 | from nde.transforms import splines 5 | 6 | class RationalQuadraticSplineTest(torchtestcase.TorchTestCase): 7 | def test_forward_inverse_are_consistent(self): 8 | num_bins = 10 9 | shape = [2,3,4] 10 | 11 | unnormalized_widths = torch.randn(*shape, num_bins) 12 | unnormalized_heights = torch.randn(*shape, num_bins) 13 | unnormalized_derivatives = torch.randn(*shape, num_bins + 1) 14 | 15 | def call_spline_fn(inputs, inverse=False): 16 | return splines.rational_quadratic_spline( 17 | inputs=inputs, 18 | unnormalized_widths=unnormalized_widths, 19 | unnormalized_heights=unnormalized_heights, 20 | unnormalized_derivatives=unnormalized_derivatives, 21 | inverse=inverse 22 | ) 23 | 24 | inputs = torch.rand(*shape) 25 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 26 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 27 | 28 | self.eps = 1e-4 29 | self.assertEqual(inputs, inputs_inv) 30 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 31 | 32 | class UnconstrainedRationalQuadraticSplineTest(torchtestcase.TorchTestCase): 33 | def test_forward_inverse_are_consistent(self): 34 | num_bins = 10 35 | shape = [2,3,4] 36 | 37 | unnormalized_widths = torch.randn(*shape, num_bins) 38 | unnormalized_heights = torch.randn(*shape, num_bins) 39 | unnormalized_derivatives = torch.randn(*shape, num_bins + 1) 40 | 41 | def call_spline_fn(inputs, inverse=False): 42 | return splines.unconstrained_rational_quadratic_spline( 43 | inputs=inputs, 44 | unnormalized_widths=unnormalized_widths, 45 | unnormalized_heights=unnormalized_heights, 46 | unnormalized_derivatives=unnormalized_derivatives, 47 | inverse=inverse 48 | ) 49 | 50 | inputs = 3 * torch.randn(*shape) # Note inputs are outside [0,1]. 
51 | outputs, logabsdet = call_spline_fn(inputs, inverse=False) 52 | inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True) 53 | 54 | self.eps = 1e-4 55 | self.assertEqual(inputs, inputs_inv) 56 | self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet)) 57 | -------------------------------------------------------------------------------- /nde/transforms/standard.py: -------------------------------------------------------------------------------- 1 | """Implementations of some standard transforms.""" 2 | 3 | import torch 4 | from nde import transforms 5 | 6 | 7 | class IdentityTransform(transforms.Transform): 8 | """Transform that leaves input unchanged.""" 9 | 10 | def forward(self, inputs, context=None): 11 | batch_size = inputs.shape[0] 12 | logabsdet = torch.zeros(batch_size) 13 | return inputs, logabsdet 14 | 15 | def inverse(self, inputs, context=None): 16 | return self(inputs, context) 17 | 18 | 19 | class AffineScalarTransform(transforms.Transform): 20 | """Computes X = X * scale + shift, where scale and shift are scalars, and scale is non-zero.""" 21 | 22 | def __init__(self, shift=None, scale=None): 23 | super().__init__() 24 | 25 | if shift is None and scale is None: 26 | raise ValueError('At least one of scale and shift must be provided.') 27 | if scale == 0.: 28 | raise ValueError('Scale cannot be zero.') 29 | 30 | self.register_buffer('_shift', torch.tensor(shift if (shift is not None) else 0.)) 31 | self.register_buffer('_scale', torch.tensor(scale if (scale is not None) else 1.)) 32 | 33 | @property 34 | def _log_scale(self): 35 | return torch.log(torch.abs(self._scale)) 36 | 37 | def forward(self, inputs, context=None): 38 | batch_size = inputs.shape[0] 39 | num_dims = torch.prod(torch.tensor(inputs.shape[1:]), dtype=torch.float) 40 | outputs = inputs * self._scale + self._shift 41 | logabsdet = torch.full([batch_size], self._log_scale * num_dims) 42 | return outputs, logabsdet 43 | 44 | def inverse(self, inputs, context=None): 45 | batch_size = inputs.shape[0] 46 | num_dims = torch.prod(torch.tensor(inputs.shape[1:]), dtype=torch.float) 47 | outputs = (inputs - self._shift) / self._scale 48 | logabsdet = torch.full([batch_size], -self._log_scale * num_dims) 49 | return outputs, logabsdet 50 | -------------------------------------------------------------------------------- /nde/transforms/standard_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the standard transforms.""" 2 | 3 | import numpy as np 4 | import torch 5 | import unittest 6 | 7 | from nde.transforms import standard 8 | from nde.transforms.transform_test import TransformTest 9 | 10 | 11 | class IdentityTransformTest(TransformTest): 12 | 13 | def test_forward(self): 14 | batch_size = 10 15 | shape = [2, 3, 4] 16 | inputs = torch.randn(batch_size, *shape) 17 | transform = standard.IdentityTransform() 18 | outputs, logabsdet = transform(inputs) 19 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 20 | self.assert_tensor_is_good(logabsdet, [batch_size]) 21 | self.assertEqual(outputs, inputs) 22 | self.assertEqual(logabsdet, torch.zeros(batch_size)) 23 | 24 | def test_inverse(self): 25 | batch_size = 10 26 | shape = [2, 3, 4] 27 | inputs = torch.randn(batch_size, *shape) 28 | transform = standard.IdentityTransform() 29 | outputs, logabsdet = transform.inverse(inputs) 30 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 31 | self.assert_tensor_is_good(logabsdet, [batch_size]) 32 | self.assertEqual(outputs, inputs) 
33 | self.assertEqual(logabsdet, torch.zeros(batch_size)) 34 | 35 | def test_forward_inverse_are_consistent(self): 36 | batch_size = 10 37 | shape = [2, 3, 4] 38 | inputs = torch.randn(batch_size, *shape) 39 | transform = standard.IdentityTransform() 40 | self.assert_forward_inverse_are_consistent(transform, inputs) 41 | 42 | 43 | class AffineScalarTransformTest(TransformTest): 44 | 45 | def test_forward(self): 46 | batch_size = 10 47 | shape = [2, 3, 4] 48 | inputs = torch.randn(batch_size, *shape) 49 | 50 | def test_case(scale, shift, true_outputs, true_logabsdet): 51 | with self.subTest(scale=scale, shift=shift): 52 | transform = standard.AffineScalarTransform(scale=scale, shift=shift) 53 | outputs, logabsdet = transform(inputs) 54 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 55 | self.assert_tensor_is_good(logabsdet, [batch_size]) 56 | self.assertEqual(outputs, true_outputs) 57 | self.assertEqual(logabsdet, 58 | torch.full([batch_size], true_logabsdet * np.prod(shape))) 59 | 60 | self.eps = 1e-6 61 | test_case(None, 2., inputs + 2., 0) 62 | test_case(2., None, inputs * 2., np.log(2.)) 63 | test_case(2., 2., inputs * 2. + 2., np.log(2.)) 64 | 65 | def test_inverse(self): 66 | batch_size = 10 67 | shape = [2, 3, 4] 68 | inputs = torch.randn(batch_size, *shape) 69 | 70 | def test_case(scale, shift, true_outputs, true_logabsdet): 71 | with self.subTest(scale=scale, shift=shift): 72 | transform = standard.AffineScalarTransform(scale=scale, shift=shift) 73 | outputs, logabsdet = transform.inverse(inputs) 74 | self.assert_tensor_is_good(outputs, [batch_size] + shape) 75 | self.assert_tensor_is_good(logabsdet, [batch_size]) 76 | self.assertEqual(outputs, true_outputs) 77 | self.assertEqual(logabsdet, 78 | torch.full([batch_size], true_logabsdet * np.prod(shape))) 79 | 80 | self.eps = 1e-6 81 | test_case(None, 2., inputs - 2., 0) 82 | test_case(2., None, inputs / 2., -np.log(2.)) 83 | test_case(2., 2., (inputs - 2.) / 2., -np.log(2.)) 84 | 85 | def test_forward_inverse_are_consistent(self): 86 | batch_size = 10 87 | shape = [2, 3, 4] 88 | inputs = torch.randn(batch_size, *shape) 89 | 90 | def test_case(scale, shift): 91 | transform = standard.AffineScalarTransform(scale=scale, shift=shift) 92 | self.assert_forward_inverse_are_consistent(transform, inputs) 93 | 94 | self.eps = 1e-6 95 | test_case(None, 2.) 96 | test_case(2., None) 97 | test_case(2., 2.) 98 | 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /nde/transforms/svd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch import nn 5 | from torch.nn import init 6 | 7 | from nde import transforms 8 | from nde.transforms.linear import Linear 9 | 10 | 11 | class SVDLinear(Linear): 12 | """A linear module using the SVD decomposition for the weight matrix.""" 13 | 14 | def __init__(self, features, num_householder, using_cache=False): 15 | super().__init__(features, using_cache) 16 | 17 | # First orthogonal matrix (U). 18 | self.orthogonal_1 = transforms.HouseholderSequence( 19 | features=features, num_transforms=num_householder) 20 | 21 | # Logs of diagonal entries of the diagonal matrix (S). 22 | self.log_diagonal = nn.Parameter(torch.zeros(features)) 23 | 24 | # Second orthogonal matrix (V^T). 
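# Together, U, the diagonal, and V^T give a weight W = U diag(exp(log_diagonal)) V^T,
# so exp(log_diagonal) are the singular values of W and log|det W| = sum(log_diagonal),
# as returned by logabsdet() below.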
25 | self.orthogonal_2 = transforms.HouseholderSequence( 26 | features=features, num_transforms=num_householder) 27 | 28 | self._initialize() 29 | 30 | def _initialize(self): 31 | stdv = 1.0 / np.sqrt(self.features) 32 | init.uniform_(self.log_diagonal, -stdv, stdv) 33 | init.constant_(self.bias, 0.0) 34 | 35 | 36 | def forward_no_cache(self, inputs): 37 | """Cost: 38 | output = O(KDN) 39 | logabsdet = O(D) 40 | where: 41 | K = num of householder transforms 42 | D = num of features 43 | N = num of inputs 44 | """ 45 | outputs, _ = self.orthogonal_2(inputs) # Ignore logabsdet as we know it's zero. 46 | outputs *= torch.exp(self.log_diagonal) 47 | outputs, _ = self.orthogonal_1(outputs) # Ignore logabsdet as we know it's zero. 48 | outputs += self.bias 49 | 50 | logabsdet = self.logabsdet() * torch.ones(outputs.shape[0]) 51 | 52 | return outputs, logabsdet 53 | 54 | def inverse_no_cache(self, inputs): 55 | """Cost: 56 | output = O(KDN) 57 | logabsdet = O(D) 58 | where: 59 | K = num of householder transforms 60 | D = num of features 61 | N = num of inputs 62 | """ 63 | outputs = inputs - self.bias 64 | outputs, _ = self.orthogonal_1.inverse(outputs) # Ignore logabsdet since we know it's zero. 65 | outputs *= torch.exp(-self.log_diagonal) 66 | outputs, _ = self.orthogonal_2.inverse(outputs) # Ignore logabsdet since we know it's zero. 67 | logabsdet = -self.logabsdet() 68 | logabsdet = logabsdet * torch.ones(outputs.shape[0]) 69 | return outputs, logabsdet 70 | 71 | def weight(self): 72 | """Cost: 73 | weight = O(KD^2) 74 | where: 75 | K = num of householder transforms 76 | D = num of features 77 | """ 78 | diagonal = torch.diag(torch.exp(self.log_diagonal)) 79 | weight, _ = self.orthogonal_2.inverse(diagonal) 80 | weight, _ = self.orthogonal_1(weight.t()) 81 | return weight.t() 82 | 83 | def weight_inverse(self): 84 | """Cost: 85 | inverse = O(KD^2) 86 | where: 87 | K = num of householder transforms 88 | D = num of features 89 | """ 90 | diagonal_inv = torch.diag(torch.exp(-self.log_diagonal)) 91 | weight_inv, _ = self.orthogonal_1(diagonal_inv) 92 | weight_inv, _ = self.orthogonal_2.inverse(weight_inv.t()) 93 | return weight_inv.t() 94 | 95 | def logabsdet(self): 96 | """Cost: 97 | logabsdet = O(D) 98 | where: 99 | D = num of features 100 | """ 101 | return torch.sum(self.log_diagonal) 102 | -------------------------------------------------------------------------------- /nde/transforms/svd_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | import utils 5 | 6 | from nde.transforms import svd 7 | from nde.transforms.transform_test import TransformTest 8 | 9 | 10 | class SVDLinearTest(TransformTest): 11 | 12 | def setUp(self): 13 | self.features = 3 14 | self.transform = svd.SVDLinear(features=self.features, num_householder=4) 15 | self.transform.bias.data = torch.randn(self.features) # Just so bias isn't zero. 
16 | 17 | diagonal = torch.diag(torch.exp(self.transform.log_diagonal)) 18 | orthogonal_1 = self.transform.orthogonal_1.matrix() 19 | orthogonal_2 = self.transform.orthogonal_2.matrix() 20 | self.weight = orthogonal_1 @ diagonal @ orthogonal_2 21 | self.weight_inverse = torch.inverse(self.weight) 22 | self.logabsdet = utils.logabsdet(self.weight) 23 | 24 | self.eps = 1e-5 25 | 26 | def test_forward_no_cache(self): 27 | batch_size = 10 28 | inputs = torch.randn(batch_size, self.features) 29 | outputs, logabsdet = self.transform.forward_no_cache(inputs) 30 | 31 | outputs_ref = inputs @ self.weight.t() + self.transform.bias 32 | logabsdet_ref = torch.full([batch_size], self.logabsdet.item()) 33 | 34 | self.assert_tensor_is_good(outputs, [batch_size, self.features]) 35 | self.assert_tensor_is_good(logabsdet, [batch_size]) 36 | 37 | self.assertEqual(outputs, outputs_ref) 38 | self.assertEqual(logabsdet, logabsdet_ref) 39 | 40 | def test_inverse_no_cache(self): 41 | batch_size = 10 42 | inputs = torch.randn(batch_size, self.features) 43 | outputs, logabsdet = self.transform.inverse_no_cache(inputs) 44 | 45 | outputs_ref = (inputs - self.transform.bias) @ self.weight_inverse.t() 46 | logabsdet_ref = torch.full([batch_size], -self.logabsdet.item()) 47 | 48 | self.assert_tensor_is_good(outputs, [batch_size, self.features]) 49 | self.assert_tensor_is_good(logabsdet, [batch_size]) 50 | 51 | self.assertEqual(outputs, outputs_ref) 52 | self.assertEqual(logabsdet, logabsdet_ref) 53 | 54 | def test_weight(self): 55 | weight = self.transform.weight() 56 | self.assert_tensor_is_good(weight, [self.features, self.features]) 57 | self.assertEqual(weight, self.weight) 58 | 59 | def test_weight_inverse(self): 60 | weight_inverse = self.transform.weight_inverse() 61 | self.assert_tensor_is_good(weight_inverse, [self.features, self.features]) 62 | self.assertEqual(weight_inverse, self.weight_inverse) 63 | 64 | def test_logabsdet(self): 65 | logabsdet = self.transform.logabsdet() 66 | self.assert_tensor_is_good(logabsdet, []) 67 | self.assertEqual(logabsdet, self.logabsdet) 68 | 69 | def test_forward_inverse_are_consistent(self): 70 | batch_size = 10 71 | inputs = torch.randn(batch_size, self.features) 72 | self.assert_forward_inverse_are_consistent(self.transform, inputs) 73 | 74 | 75 | if __name__ == '__main__': 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /nde/transforms/transform_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchtestcase 3 | 4 | from nde.transforms import base 5 | 6 | 7 | class TransformTest(torchtestcase.TorchTestCase): 8 | """Base test for all transforms.""" 9 | 10 | def assert_tensor_is_good(self, tensor, shape=None): 11 | self.assertIsInstance(tensor, torch.Tensor) 12 | self.assertFalse(torch.isnan(tensor).any()) 13 | self.assertFalse(torch.isinf(tensor).any()) 14 | if shape is not None: 15 | self.assertEqual(tensor.shape, torch.Size(shape)) 16 | 17 | def assert_forward_inverse_are_consistent(self, transform, inputs): 18 | inverse = base.InverseTransform(transform) 19 | identity = base.CompositeTransform([inverse, transform]) 20 | outputs, logabsdet = identity(inputs) 21 | 22 | self.assert_tensor_is_good(outputs, shape=inputs.shape) 23 | self.assert_tensor_is_good(logabsdet, shape=inputs.shape[:1]) 24 | self.assertEqual(outputs, inputs) 25 | self.assertEqual(logabsdet, torch.zeros(inputs.shape[:1])) 26 | 27 | def assertNotEqual(self, first, second, 
msg=None): 28 | if ((self._eps and (first - second).abs().max().item() < self._eps) or 29 | (not self._eps and torch.equal(first, second))): 30 | self._fail_with_message(msg, "The tensors are _not_ different!") 31 | 32 | 33 | -------------------------------------------------------------------------------- /nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ( 2 | ResidualNet, 3 | ConvResidualNet, 4 | ) 5 | 6 | from .unet import UNet 7 | 8 | from .attention import ConvAttentionNet 9 | 10 | from .mlp import MLP 11 | 12 | from .conv import ( 13 | SylvesterFlowConvEncoderNet, 14 | SylvesterFlowConvDecoderNet, 15 | ConvEncoder, 16 | ConvDecoder 17 | ) -------------------------------------------------------------------------------- /nn/mlp.py: -------------------------------------------------------------------------------- 1 | """Implementations of multi-layer perceptrons.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from torch.nn import functional as F 7 | from torch import nn 8 | 9 | 10 | class MLP(nn.Module): 11 | """A standard multi-layer perceptron.""" 12 | 13 | def __init__(self, 14 | in_shape, 15 | out_shape, 16 | hidden_sizes, 17 | activation=F.relu, 18 | activate_output=False): 19 | """ 20 | Args: 21 | in_shape: tuple, list or torch.Size, the shape of the input. 22 | out_shape: tuple, list or torch.Size, the shape of the output. 23 | hidden_sizes: iterable of ints, the hidden-layer sizes. 24 | activation: callable, the activation function. 25 | activate_output: bool, whether to apply the activation to the output. 26 | """ 27 | super().__init__() 28 | self._in_shape = torch.Size(in_shape) 29 | self._out_shape = torch.Size(out_shape) 30 | self._hidden_sizes = hidden_sizes 31 | self._activation = activation 32 | self._activate_output = activate_output 33 | 34 | if len(hidden_sizes) == 0: 35 | raise ValueError('List of hidden sizes can\'t be empty.') 36 | 37 | self._input_layer = nn.Linear(np.prod(in_shape), hidden_sizes[0]) 38 | self._hidden_layers = nn.ModuleList([ 39 | nn.Linear(in_size, out_size) 40 | for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]) 41 | ]) 42 | self._output_layer = nn.Linear(hidden_sizes[-1], np.prod(out_shape)) 43 | 44 | def forward(self, inputs): 45 | if inputs.shape[1:] != self._in_shape: 46 | raise ValueError('Expected inputs of shape {}, got {}.'.format( 47 | self._in_shape, inputs.shape[1:])) 48 | 49 | inputs = inputs.reshape(-1, np.prod(self._in_shape)) 50 | outputs = self._input_layer(inputs) 51 | outputs = self._activation(outputs) 52 | 53 | for hidden_layer in self._hidden_layers: 54 | outputs = hidden_layer(outputs) 55 | outputs = self._activation(outputs) 56 | 57 | outputs = self._output_layer(outputs) 58 | if self._activate_output: 59 | outputs = self._activation(outputs) 60 | outputs = outputs.reshape(-1, *self._out_shape) 61 | 62 | return outputs 63 | -------------------------------------------------------------------------------- /nn/mlp_test.py: -------------------------------------------------------------------------------- 1 | """Tests for multi-layer perceptrons.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | 7 | from nn import mlp 8 | 9 | 10 | class MLPTest(torchtestcase.TorchTestCase): 11 | 12 | def test_forward(self): 13 | batch_size = 10 14 | in_shape = [2, 3, 4] 15 | out_shape = [5, 6] 16 | inputs = torch.randn(batch_size, *in_shape) 17 | 18 | for hidden_sizes in [[20], [20, 30], [20, 30, 40]]: 19 | with
self.subTest(hidden_sizes=hidden_sizes): 20 | model = mlp.MLP( 21 | in_shape=in_shape, 22 | out_shape=out_shape, 23 | hidden_sizes=hidden_sizes, 24 | ) 25 | outputs = model(inputs) 26 | self.assertIsInstance(outputs, torch.Tensor) 27 | self.assertEqual(outputs.shape, torch.Size([batch_size] + out_shape)) 28 | self.assertFalse(torch.isnan(outputs).any()) 29 | self.assertFalse(torch.isinf(outputs).any()) 30 | 31 | with self.assertRaises(Exception): 32 | mlp.MLP( 33 | in_shape=in_shape, 34 | out_shape=out_shape, 35 | hidden_sizes=[], 36 | ) 37 | 38 | 39 | if __name__ == '__main__': 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /nn/unet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | import utils 7 | 8 | 9 | class UNet(nn.Module): 10 | def __init__(self, 11 | in_features, 12 | max_hidden_features, 13 | num_layers, 14 | out_features, 15 | nonlinearity=F.relu): 16 | super().__init__() 17 | 18 | assert utils.is_power_of_two(max_hidden_features), \ 19 | '\'max_hidden_features\' must be a power of two.' 20 | assert max_hidden_features // 2 ** num_layers > 1, \ 21 | '\'num_layers\' must be {} or fewer'.format(int(np.log2(max_hidden_features) - 1)) 22 | 23 | self.nonlinearity = nonlinearity 24 | self.num_layers = num_layers 25 | 26 | self.initial_layer = nn.Linear(in_features, max_hidden_features) 27 | 28 | self.down_layers = nn.ModuleList([ 29 | nn.Linear( 30 | in_features=max_hidden_features // 2 ** i, 31 | out_features=max_hidden_features // 2 ** (i + 1) 32 | ) 33 | for i in range(num_layers) 34 | ]) 35 | 36 | self.middle_layer = nn.Linear( 37 | in_features=max_hidden_features // 2 ** num_layers, 38 | out_features=max_hidden_features // 2 ** num_layers) 39 | 40 | self.up_layers = nn.ModuleList([ 41 | nn.Linear( 42 | in_features=max_hidden_features // 2 ** (i + 1), 43 | out_features=max_hidden_features // 2 ** i 44 | ) 45 | for i in range(num_layers - 1, -1, -1) 46 | ]) 47 | 48 | self.final_layer = nn.Linear(max_hidden_features, out_features) 49 | 50 | def forward(self, inputs): 51 | temps = self.initial_layer(inputs) 52 | temps = self.nonlinearity(temps) 53 | 54 | down_temps = [] 55 | for layer in self.down_layers: 56 | temps = layer(temps) 57 | temps = self.nonlinearity(temps) 58 | down_temps.append(temps) 59 | 60 | temps = self.middle_layer(temps) 61 | temps = self.nonlinearity(temps) 62 | 63 | for i, layer in enumerate(self.up_layers): 64 | temps += down_temps[self.num_layers - i - 1] 65 | temps = self.nonlinearity(temps) 66 | temps = layer(temps) 67 | 68 | return self.final_layer(temps) 69 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom_lr_schedulers import CosineAnnealingWarmUpLR 2 | -------------------------------------------------------------------------------- /optim/custom_lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch import nn, optim 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | class CosineAnnealingWarmUpLR(_LRScheduler): 8 | def __init__(self, optimizer, warm_up_epochs, total_epochs, eta_min=0, last_epoch=-1): 9 | self.warm_up_epochs = warm_up_epochs 10 | self.total_epochs = total_epochs 11 | self.eta_min = eta_min 12 | 
super().__init__(optimizer, last_epoch) 13 | 14 | def get_lr(self): 15 | if self.last_epoch < self.warm_up_epochs: 16 | return [base_lr * (self.last_epoch / self.warm_up_epochs) 17 | for base_lr in self.base_lrs] 18 | else: 19 | frac_epochs = ( 20 | (self.last_epoch - self.warm_up_epochs) 21 | / (self.total_epochs - self.warm_up_epochs) 22 | ) 23 | return [self.eta_min + (base_lr - self.eta_min) * 24 | (1 + math.cos(math.pi * frac_epochs)) / 2 25 | for base_lr in self.base_lrs] 26 | 27 | 28 | def main(): 29 | model = nn.Linear(5, 3) 30 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 31 | warm_up_steps = 0 32 | total_steps = 100 33 | scheduler = CosineAnnealingWarmUpLR(optimizer, warm_up_steps, total_steps) 34 | lrs = [] 35 | for _ in range(total_steps): 36 | optimizer.zero_grad() 37 | optimizer.step() 38 | scheduler.step() 39 | lrs.append(scheduler.get_lr()[0]) 40 | print(lrs) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .torchutils import ( 2 | create_alternating_binary_mask, 3 | create_mid_split_binary_mask, 4 | create_random_binary_mask, 5 | get_num_parameters, 6 | logabsdet, 7 | random_orthogonal, 8 | sum_except_batch, 9 | split_leading_dim, 10 | merge_leading_dims, 11 | repeat_rows, 12 | tensor2numpy, 13 | tile, 14 | searchsorted, 15 | cbrt, 16 | get_temperature 17 | ) 18 | 19 | from .typechecks import is_bool 20 | from .typechecks import is_int 21 | from .typechecks import is_positive_int 22 | from .typechecks import is_nonnegative_int 23 | from .typechecks import is_power_of_two 24 | 25 | from .io import get_data_root 26 | from .io import NoDataRootError 27 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | """Utility functions for Input/Output.""" 2 | 3 | import os 4 | 5 | 6 | class NoDataRootError(Exception): 7 | """Exception to be thrown when data root doesn't exist.""" 8 | pass 9 | 10 | 11 | def get_data_root(): 12 | """Returns the data root, which we assume is contained in an environment variable. 13 | 14 | Returns: 15 | string, the data root. 16 | 17 | Raises: 18 | NoDataRootError: If environment variable doesn't exist. 
19 | """ 20 | data_root_var = 'DATAROOT' 21 | try: 22 | return os.environ[data_root_var] 23 | except KeyError: 24 | raise NoDataRootError('Data root must be in environment variable {}, which' 25 | ' doesn\'t exist.'.format(data_root_var)) 26 | -------------------------------------------------------------------------------- /utils/torchutils.py: -------------------------------------------------------------------------------- 1 | """Various PyTorch utility functions.""" 2 | 3 | import torch 4 | 5 | import utils 6 | 7 | 8 | def tile(x, n): 9 | if not utils.is_positive_int(n): 10 | raise TypeError('Argument \'n\' must be a positive integer.') 11 | x_ = x.reshape(-1) 12 | x_ = x_.repeat(n) 13 | x_ = x_.reshape(n, -1) 14 | x_ = x_.transpose(1, 0) 15 | x_ = x_.reshape(-1) 16 | return x_ 17 | 18 | 19 | def sum_except_batch(x, num_batch_dims=1): 20 | """Sums all elements of `x` except for the first `num_batch_dims` dimensions.""" 21 | if not utils.is_nonnegative_int(num_batch_dims): 22 | raise TypeError('Number of batch dimensions must be a non-negative integer.') 23 | reduce_dims = list(range(num_batch_dims, x.ndimension())) 24 | return torch.sum(x, dim=reduce_dims) 25 | 26 | 27 | def split_leading_dim(x, shape): 28 | """Reshapes the leading dim of `x` to have the given shape.""" 29 | new_shape = torch.Size(shape) + x.shape[1:] 30 | return torch.reshape(x, new_shape) 31 | 32 | 33 | def merge_leading_dims(x, num_dims): 34 | """Reshapes the tensor `x` such that the first `num_dims` dimensions are merged to one.""" 35 | if not utils.is_positive_int(num_dims): 36 | raise TypeError('Number of leading dims must be a positive integer.') 37 | if num_dims > x.dim(): 38 | raise ValueError('Number of leading dims can\'t be greater than total number of dims.') 39 | new_shape = torch.Size([-1]) + x.shape[num_dims:] 40 | return torch.reshape(x, new_shape) 41 | 42 | 43 | def repeat_rows(x, num_reps): 44 | """Each row of tensor `x` is repeated `num_reps` times along leading dimension.""" 45 | if not utils.is_positive_int(num_reps): 46 | raise TypeError('Number of repetitions must be a positive integer.') 47 | shape = x.shape 48 | x = x.unsqueeze(1) 49 | x = x.expand(shape[0], num_reps, *shape[1:]) 50 | return merge_leading_dims(x, num_dims=2) 51 | 52 | 53 | def tensor2numpy(x): 54 | return x.detach().cpu().numpy() 55 | 56 | 57 | def logabsdet(x): 58 | """Returns the log absolute determinant of square matrix x.""" 59 | # Note: torch.logdet() only works for positive determinant. 60 | _, res = torch.slogdet(x) 61 | return res 62 | 63 | 64 | def random_orthogonal(size): 65 | """ 66 | Returns a random orthogonal matrix as a 2-dim tensor of shape [size, size]. 67 | """ 68 | 69 | # Use the QR decomposition of a random Gaussian matrix. 70 | x = torch.randn(size, size) 71 | q, _ = torch.qr(x) 72 | return q 73 | 74 | 75 | def get_num_parameters(model): 76 | """ 77 | Returns the number of trainable parameters in a model of type nn.Module 78 | :param model: nn.Module containing trainable parameters 79 | :return: number of trainable parameters in model 80 | """ 81 | num_parameters = 0 82 | for parameter in model.parameters(): 83 | num_parameters += torch.numel(parameter) 84 | return num_parameters 85 | 86 | 87 | def create_alternating_binary_mask(features, even=True): 88 | """ 89 | Creates a binary mask of a given dimension which alternates its masking. 90 | 91 | :param features: Dimension of mask. 92 | :param even: If True, even values are assigned 1s, odd 0s. If False, vice versa. 
93 | :return: Alternating binary mask of type torch.Tensor. 94 | """ 95 | mask = torch.zeros(features).byte() 96 | start = 0 if even else 1 97 | mask[start::2] += 1 98 | return mask 99 | 100 | 101 | def create_mid_split_binary_mask(features): 102 | """ 103 | Creates a binary mask of a given dimension which splits its masking at the midpoint. 104 | 105 | :param features: Dimension of mask. 106 | :return: Binary mask split at midpoint of type torch.Tensor 107 | """ 108 | mask = torch.zeros(features).byte() 109 | midpoint = features // 2 if features % 2 == 0 else features // 2 + 1 110 | mask[:midpoint] += 1 111 | return mask 112 | 113 | 114 | def create_random_binary_mask(features): 115 | """ 116 | Creates a random binary mask of a given dimension with half of its entries 117 | randomly set to 1s. 118 | 119 | :param features: Dimension of mask. 120 | :return: Binary mask with half of its entries set to 1s, of type torch.Tensor. 121 | """ 122 | mask = torch.zeros(features).byte() 123 | weights = torch.ones(features).float() 124 | num_samples = features // 2 if features % 2 == 0 else features // 2 + 1 125 | indices = torch.multinomial( 126 | input=weights, 127 | num_samples=num_samples, 128 | replacement=False 129 | ) 130 | mask[indices] += 1 131 | return mask 132 | 133 | def searchsorted(bin_locations, inputs, eps=1e-6): 134 | bin_locations[..., -1] += eps 135 | return torch.sum( 136 | inputs[..., None] >= bin_locations, 137 | dim=-1 138 | ) - 1 139 | 140 | def cbrt(x): 141 | """Cube root. Equivalent to torch.pow(x, 1/3), but numerically stable.""" 142 | return torch.sign(x) * torch.exp(torch.log(torch.abs(x)) / 3.0) 143 | 144 | 145 | def get_temperature(max_value, bound=1-1e-3): 146 | """ 147 | For a dataset with max value 'max_value', returns the temperature such that 148 | 149 | sigmoid(temperature * max_value) = bound. 150 | 151 | If temperature is greater than 1, returns 1. 
152 | 153 | :param max_value: 154 | :param bound: 155 | :return: 156 | """ 157 | max_value = torch.Tensor([max_value]) 158 | bound = torch.Tensor([bound]) 159 | temperature = min(- (1 / max_value) * (torch.log1p(-bound) - torch.log(bound)), 1) 160 | return temperature 161 | -------------------------------------------------------------------------------- /utils/torchutils_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the PyTorch utility functions.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | 7 | from utils import torchutils 8 | 9 | 10 | class TorchUtilsTest(torchtestcase.TorchTestCase): 11 | 12 | def test_split_leading_dim(self): 13 | x = torch.randn(24, 5) 14 | self.assertEqual(torchutils.split_leading_dim(x, [-1]), x) 15 | self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5)) 16 | self.assertEqual(torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5)) 17 | with self.assertRaises(Exception): 18 | self.assertEqual(torchutils.split_leading_dim(x, []), x) 19 | with self.assertRaises(Exception): 20 | self.assertEqual(torchutils.split_leading_dim(x, [5, 5]), x) 21 | 22 | def test_merge_leading_dims(self): 23 | x = torch.randn(2, 3, 4, 5) 24 | self.assertEqual(torchutils.merge_leading_dims(x, 1), x) 25 | self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5)) 26 | self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5)) 27 | self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120)) 28 | with self.assertRaises(Exception): 29 | torchutils.merge_leading_dims(x, 0) 30 | with self.assertRaises(Exception): 31 | torchutils.merge_leading_dims(x, 5) 32 | 33 | def test_split_merge_leading_dims_are_consistent(self): 34 | x = torch.randn(2, 3, 4, 5) 35 | y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 1), [2]) 36 | self.assertEqual(y, x) 37 | y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 2), [2, 3]) 38 | self.assertEqual(y, x) 39 | y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 3), [2, 3, 4]) 40 | self.assertEqual(y, x) 41 | y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 4), [2, 3, 4, 5]) 42 | self.assertEqual(y, x) 43 | 44 | def test_repeat_rows(self): 45 | x = torch.randn(2, 3, 4, 5) 46 | self.assertEqual(torchutils.repeat_rows(x, 1), x) 47 | y = torchutils.repeat_rows(x, 2) 48 | self.assertEqual(y.shape, torch.Size([4, 3, 4, 5])) 49 | self.assertEqual(x[0], y[0]) 50 | self.assertEqual(x[0], y[1]) 51 | self.assertEqual(x[1], y[2]) 52 | self.assertEqual(x[1], y[3]) 53 | with self.assertRaises(Exception): 54 | torchutils.repeat_rows(x, 0) 55 | 56 | def test_logabsdet(self): 57 | size = 10 58 | matrix = torch.randn(size, size) 59 | logabsdet = torchutils.logabsdet(matrix) 60 | logabsdet_ref = torch.log(torch.abs(matrix.det())) 61 | self.eps = 1e-6 62 | self.assertEqual(logabsdet, logabsdet_ref) 63 | 64 | def test_random_orthogonal(self): 65 | size = 100 66 | matrix = torchutils.random_orthogonal(size) 67 | self.assertIsInstance(matrix, torch.Tensor) 68 | self.assertEqual(matrix.shape, torch.Size([size, size])) 69 | self.eps = 1e-5 70 | unit = torch.eye(size, size) 71 | self.assertEqual(matrix @ matrix.t(), unit) 72 | self.assertEqual(matrix.t() @ matrix, unit) 73 | self.assertEqual(matrix.t(), matrix.inverse()) 74 | self.assertEqual(torch.abs(matrix.det()), torch.tensor(1.0)) 75 | 76 | def test_searchsorted(self): 77 | bin_locations = torch.linspace(0,1,10) # 9 bins == 10 
locations 78 | 79 | left_boundaries = bin_locations[:-1] 80 | right_boundaries = bin_locations[:-1] + 0.1 81 | mid_points = bin_locations[:-1] + 0.05 82 | 83 | for inputs in [left_boundaries, 84 | right_boundaries, 85 | mid_points]: 86 | with self.subTest(inputs=inputs): 87 | idx = torchutils.searchsorted(bin_locations[None, :], inputs) 88 | self.assertEqual(idx, torch.arange(0, 9)) 89 | 90 | def test_searchsorted_arbitrary_shape(self): 91 | shape = [2,3,4] 92 | bin_locations = torch.linspace(0,1,10).repeat(*shape, 1) 93 | inputs = torch.rand(*shape) 94 | idx = torchutils.searchsorted(bin_locations, inputs) 95 | self.assertEqual(idx.shape, inputs.shape) 96 | 97 | if __name__ == '__main__': 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /utils/typechecks.py: -------------------------------------------------------------------------------- 1 | """Functions that check types.""" 2 | 3 | 4 | def is_bool(x): 5 | return isinstance(x, bool) 6 | 7 | 8 | def is_int(x): 9 | return isinstance(x, int) 10 | 11 | 12 | def is_positive_int(x): 13 | return is_int(x) and x > 0 14 | 15 | 16 | def is_nonnegative_int(x): 17 | return is_int(x) and x >= 0 18 | 19 | 20 | def is_power_of_two(n): 21 | if is_positive_int(n): 22 | return not n & (n - 1) 23 | else: 24 | return False 25 | -------------------------------------------------------------------------------- /vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import VariationalAutoencoder 2 | -------------------------------------------------------------------------------- /vae/base.py: -------------------------------------------------------------------------------- 1 | """Basic definitions for VAEs.""" 2 | 3 | import torch 4 | 5 | from torch import nn 6 | 7 | import utils 8 | 9 | 10 | class VariationalAutoencoder(nn.Module): 11 | """Implementation of a standard VAE.""" 12 | 13 | def __init__(self, prior, approximate_posterior, likelihood, inputs_encoder=None): 14 | """ 15 | Args: 16 | prior: a distribution object, the prior. 17 | approximate_posterior: a distribution object, the encoder. 18 | likelihood: a distribution object, the decoder. 19 | """ 20 | super().__init__() 21 | self._prior = prior 22 | self._approximate_posterior = approximate_posterior 23 | self._likelihood = likelihood 24 | self._inputs_encoder = inputs_encoder 25 | 26 | def forward(self, *args): 27 | raise RuntimeError('Forward method cannot be called for a VAE object.') 28 | 29 | def stochastic_elbo(self, inputs, num_samples=1, kl_multiplier=1, keepdim=False): 30 | """Calculates an unbiased Monte-Carlo estimate of the evidence lower bound. 31 | 32 | Note: the KL term is also estimated via Monte Carlo. 33 | 34 | Args: 35 | inputs: Tensor of shape [batch_size, ...], the inputs. 36 | num_samples: int, number of samples to use for the Monte-Carlo estimate. 37 | 38 | Returns: 39 | A Tensor of shape [batch_size], an ELBO estimate for each input. 40 | """ 41 | # Sample latents and calculate their log prob under the encoder. 42 | if self._inputs_encoder is None: 43 | posterior_context = inputs 44 | else: 45 | posterior_context = self._inputs_encoder(inputs) 46 | latents, log_q_z = self._approximate_posterior.sample_and_log_prob( 47 | num_samples, 48 | context=posterior_context 49 | ) 50 | latents = utils.merge_leading_dims(latents, num_dims=2) 51 | log_q_z = utils.merge_leading_dims(log_q_z, num_dims=2) 52 | 53 | # Compute log prob of latents under the prior. 
54 | log_p_z = self._prior.log_prob(latents) 55 | 56 | # Compute log prob of inputs under the decoder. 57 | inputs = utils.repeat_rows(inputs, num_reps=num_samples) 58 | log_p_x = self._likelihood.log_prob(inputs, context=latents) 59 | 60 | # Compute ELBO. 61 | # TODO: maybe compute KL analytically when possible? 62 | elbo = log_p_x + kl_multiplier * (log_p_z - log_q_z) 63 | elbo = utils.split_leading_dim(elbo, [-1, num_samples]) 64 | if keepdim: 65 | return elbo 66 | else: 67 | return torch.sum(elbo, dim=1) / num_samples # Average ELBO across samples. 68 | 69 | def log_prob_lower_bound(self, inputs, num_samples=100): 70 | elbo = self.stochastic_elbo(inputs, num_samples=num_samples, keepdim=True) 71 | log_prob_lower_bound = torch.logsumexp(elbo, dim=1) - torch.log(torch.Tensor([num_samples])) 72 | return log_prob_lower_bound 73 | 74 | def _decode(self, latents, mean): 75 | if mean: 76 | return self._likelihood.mean(context=latents) 77 | else: 78 | samples = self._likelihood.sample(num_samples=1, context=latents) 79 | return utils.merge_leading_dims(samples, num_dims=2) 80 | 81 | def sample(self, num_samples, mean=False): 82 | """Generates samples from the VAE. 83 | 84 | Args: 85 | num_samples: int, number of samples to generate. 86 | mean: bool, if True it uses the mean of the decoder instead of sampling from it. 87 | 88 | Returns: 89 | A tensor of shape [num_samples, ...], the samples. 90 | """ 91 | latents = self._prior.sample(num_samples) 92 | return self._decode(latents, mean) 93 | 94 | def encode(self, inputs, num_samples=None): 95 | """Encodes inputs into the latent space. 96 | 97 | Args: 98 | inputs: Tensor of shape [batch_size, ...], the inputs to encode. 99 | num_samples: int or None, the number of latent samples to generate per input. If None, 100 | only one latent sample is generated per input. 101 | 102 | Returns: 103 | A Tensor of shape [batch_size, num_samples, ...] or [batch_size, ...] if num_samples 104 | is None, the latent samples for each input. 105 | """ 106 | if num_samples is None: 107 | latents = self._approximate_posterior.sample(num_samples=1, context=inputs) 108 | latents = utils.merge_leading_dims(latents, num_dims=2) 109 | else: 110 | latents = self._approximate_posterior.sample(num_samples=num_samples, context=inputs) 111 | return latents 112 | 113 | def reconstruct(self, inputs, num_samples=None, mean=False): 114 | """Reconstructs given inputs. 115 | 116 | Args: 117 | inputs: Tensor of shape [batch_size, ...], the inputs to reconstruct. 118 | num_samples: int or None, the number of reconstructions to generate per input. If None, 119 | only one reconstruction is generated per input. 120 | mean: bool, if True it uses the mean of the decoder instead of sampling from it. 121 | 122 | Returns: 123 | A Tensor of shape [batch_size, num_samples, ...] or [batch_size, ...] if num_samples 124 | is None, the reconstructions for each input.
125 | """ 126 | latents = self.encode(inputs, num_samples) 127 | if num_samples is not None: 128 | latents = utils.merge_leading_dims(latents, num_dims=2) 129 | recons = self._decode(latents, mean) 130 | if num_samples is not None: 131 | recons = utils.split_leading_dim(recons, [-1, num_samples]) 132 | return recons 133 | -------------------------------------------------------------------------------- /vae/base_test.py: -------------------------------------------------------------------------------- 1 | """Tests for VAEs.""" 2 | 3 | import torch 4 | import torchtestcase 5 | import unittest 6 | 7 | from nde import distributions 8 | from vae import base 9 | 10 | 11 | class VariationalAutoencoderTest(torchtestcase.TorchTestCase): 12 | 13 | def test_stochastic_elbo(self): 14 | batch_size = 10 15 | input_shape = [2, 3, 4] 16 | latent_shape = [5, 6] 17 | 18 | prior = distributions.StandardNormal(latent_shape) 19 | approximate_posterior = distributions.StandardNormal(latent_shape) 20 | likelihood = distributions.StandardNormal(input_shape) 21 | vae = base.VariationalAutoencoder(prior, approximate_posterior, likelihood) 22 | 23 | inputs = torch.randn(batch_size, *input_shape) 24 | for num_samples in [1, 10, 100]: 25 | with self.subTest(num_samples=num_samples): 26 | elbo = vae.stochastic_elbo(inputs, num_samples) 27 | self.assertIsInstance(elbo, torch.Tensor) 28 | self.assertFalse(torch.isnan(elbo).any()) 29 | self.assertFalse(torch.isinf(elbo).any()) 30 | self.assertEqual(elbo.shape, torch.Size([batch_size])) 31 | 32 | def test_sample(self): 33 | num_samples = 10 34 | input_shape = [2, 3, 4] 35 | latent_shape = [5, 6] 36 | 37 | prior = distributions.StandardNormal(latent_shape) 38 | approximate_posterior = distributions.StandardNormal(latent_shape) 39 | likelihood = distributions.StandardNormal(input_shape) 40 | vae = base.VariationalAutoencoder(prior, approximate_posterior, likelihood) 41 | 42 | for mean in [True, False]: 43 | with self.subTest(mean=mean): 44 | samples = vae.sample(num_samples, mean=mean) 45 | self.assertIsInstance(samples, torch.Tensor) 46 | self.assertFalse(torch.isnan(samples).any()) 47 | self.assertFalse(torch.isinf(samples).any()) 48 | self.assertEqual(samples.shape, torch.Size([num_samples] + input_shape)) 49 | 50 | def test_encode(self): 51 | batch_size = 20 52 | input_shape = [2, 3, 4] 53 | latent_shape = [5, 6] 54 | inputs = torch.randn(batch_size, *input_shape) 55 | 56 | prior = distributions.StandardNormal(latent_shape) 57 | approximate_posterior = distributions.StandardNormal(latent_shape) 58 | likelihood = distributions.StandardNormal(input_shape) 59 | vae = base.VariationalAutoencoder(prior, approximate_posterior, likelihood) 60 | 61 | for num_samples in [None, 1, 10]: 62 | with self.subTest(num_samples=num_samples): 63 | encodings = vae.encode(inputs, num_samples) 64 | self.assertIsInstance(encodings, torch.Tensor) 65 | self.assertFalse(torch.isnan(encodings).any()) 66 | self.assertFalse(torch.isinf(encodings).any()) 67 | if num_samples is None: 68 | self.assertEqual(encodings.shape, torch.Size([batch_size] + latent_shape)) 69 | else: 70 | self.assertEqual( 71 | encodings.shape, torch.Size([batch_size, num_samples] + latent_shape)) 72 | 73 | def test_reconstruct(self): 74 | batch_size = 20 75 | input_shape = [2, 3, 4] 76 | latent_shape = [5, 6] 77 | inputs = torch.randn(batch_size, *input_shape) 78 | 79 | prior = distributions.StandardNormal(latent_shape) 80 | approximate_posterior = distributions.StandardNormal(latent_shape) 81 | likelihood = 
distributions.StandardNormal(input_shape) 82 | vae = base.VariationalAutoencoder(prior, approximate_posterior, likelihood) 83 | 84 | for mean in [True, False]: 85 | for num_samples in [None, 1, 10]: 86 | with self.subTest(mean=mean, num_samples=num_samples): 87 | recons = vae.reconstruct(inputs, num_samples=num_samples, mean=mean) 88 | self.assertIsInstance(recons, torch.Tensor) 89 | self.assertFalse(torch.isnan(recons).any()) 90 | self.assertFalse(torch.isinf(recons).any()) 91 | if num_samples is None: 92 | self.assertEqual(recons.shape, torch.Size([batch_size] + input_shape)) 93 | else: 94 | self.assertEqual( 95 | recons.shape, torch.Size([batch_size, num_samples] + input_shape)) 96 | 97 | 98 | if __name__ == '__main__': 99 | unittest.main() 100 | --------------------------------------------------------------------------------
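The `SVDLinear` transform in `nde/transforms/svd.py` parameterises an invertible linear layer as two Householder sequences sandwiching a learned diagonal, so a forward pass costs O(KDN) while the log-absolute-determinant is just the sum of `log_diagonal`. A minimal sketch of exercising it against the explicitly assembled weight matrix, mirroring `SVDLinearTest` and assuming the repository root is on `PYTHONPATH`:

```python
import torch
from nde.transforms import svd

# Constructor arguments follow svd_test.py.
transform = svd.SVDLinear(features=3, num_householder=4)
inputs = torch.randn(10, 3)

# O(KDN) forward pass vs. an explicit multiply by the O(KD^2) assembled weight.
outputs, logabsdet = transform.forward_no_cache(inputs)
outputs_ref = inputs @ transform.weight().t() + transform.bias

assert torch.allclose(outputs, outputs_ref, atol=1e-5)
assert torch.allclose(logabsdet, transform.logabsdet() * torch.ones(10))
```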
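`nn.mlp.MLP` flattens its input to `prod(in_shape)`, runs it through the hidden layers, and reshapes the result to `out_shape`. A small usage sketch consistent with `nn/mlp_test.py` (repository root assumed importable):

```python
import torch
from nn import mlp

model = mlp.MLP(in_shape=[2, 3, 4], out_shape=[5, 6], hidden_sizes=[20, 30])
outputs = model(torch.randn(10, 2, 3, 4))
print(outputs.shape)  # torch.Size([10, 5, 6])

# An empty list of hidden sizes is rejected:
# mlp.MLP(in_shape=[2, 3, 4], out_shape=[5, 6], hidden_sizes=[])  -> ValueError
```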
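The fully connected `UNet` halves its hidden width at every down layer and mirrors the widths back up with additive skip connections, which is why `max_hidden_features` must be a power of two and `num_layers` cannot exceed `log2(max_hidden_features) - 1`. A configuration that satisfies both assertions (sizes are hypothetical, chosen only for illustration):

```python
import torch
from nn.unet import UNet

# 64 // 2**3 = 8 > 1, so three down/up layers are allowed.
model = UNet(in_features=8, max_hidden_features=64, num_layers=3, out_features=8)
outputs = model(torch.randn(10, 8))
print(outputs.shape)  # torch.Size([10, 8])
```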
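`CosineAnnealingWarmUpLR` ramps each learning rate linearly from zero over `warm_up_epochs` epochs and then follows a standard cosine decay towards `eta_min`. The same rule, written out as a plain function for a single base learning rate (a sketch for illustration, not part of the package):

```python
import math

def expected_lr(base_lr, epoch, warm_up_epochs, total_epochs, eta_min=0.0):
    if epoch < warm_up_epochs:
        return base_lr * epoch / warm_up_epochs  # linear warm-up from 0
    frac = (epoch - warm_up_epochs) / (total_epochs - warm_up_epochs)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * frac)) / 2

print(expected_lr(1e-3, 0, 10, 100))    # 0.0, start of warm-up
print(expected_lr(1e-3, 10, 10, 100))   # 0.001, end of warm-up
print(expected_lr(1e-3, 100, 10, 100))  # 0.0, fully annealed to eta_min
```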
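Two of the `utils.torchutils` helpers benefit from a worked example. The mask constructors produce the binary masks used by coupling layers, and `get_temperature` solves `sigmoid(temperature * max_value) = bound` for the temperature, capping it at 1. A short sketch (repository root assumed importable; printed values are approximate):

```python
import torch
from utils import torchutils

# Coupling masks over 6 features.
print(torchutils.create_alternating_binary_mask(6))  # [1, 0, 1, 0, 1, 0]
print(torchutils.create_mid_split_binary_mask(6))    # [1, 1, 1, 0, 0, 0]

# Temperature such that sigmoid(T * 255) = 1 - 1e-3, i.e. T = logit(0.999) / 255.
T = torchutils.get_temperature(max_value=255, bound=1 - 1e-3)
print(T)                                        # ~0.0271
print(torch.sigmoid(T * torch.tensor([255.])))  # ~0.9990
```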
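`VariationalAutoencoder.stochastic_elbo` draws `num_samples` latents per input from the approximate posterior and returns a Monte-Carlo ELBO estimate for each input. A minimal sketch using standard-normal distributions for all three components, following the setup in `vae/base_test.py`:

```python
import torch
from nde import distributions
from vae import base

latent_shape, input_shape = [5, 6], [2, 3, 4]
prior = distributions.StandardNormal(latent_shape)
approximate_posterior = distributions.StandardNormal(latent_shape)
likelihood = distributions.StandardNormal(input_shape)
model = base.VariationalAutoencoder(prior, approximate_posterior, likelihood)

inputs = torch.randn(10, *input_shape)
elbo = model.stochastic_elbo(inputs, num_samples=10)
print(elbo.shape)  # torch.Size([10]), one ELBO estimate per input

samples = model.sample(num_samples=4, mean=False)
print(samples.shape)  # torch.Size([4, 2, 3, 4])
```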