├── .gitignore ├── LICENSE ├── README.md ├── data ├── load_hmnist.sh ├── load_physionet.sh └── load_sprites.sh ├── figures └── overview.png ├── lib ├── __init__.py ├── gp_kernel.py ├── healing_mnist.py ├── models.py ├── nn_utils.py └── utils.py ├── models └── .gitkeep ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | MANIFEST 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | .pytest_cache/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | db.sqlite3 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 
98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ratschlab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GP-VAE: Deep Probabilistic Time Series Imputation 2 | 3 | Code for [paper](http://arxiv.org/abs/1907.04155) 4 | 5 | ## Overview 6 | Our approach utilizes Variational Autoencoders with Gaussian Process prior for time series imputation. 7 | 8 | * The inference model takes time series with missingness and predicts variational parameters for multivariate Gaussian variational distribution. 
9 | 10 | * The Gaussian Process prior encourages latent representations to capture the temporal correlations in data. 11 | 12 | * The generative model takes a sample from the posterior approximation and reconstructs the original time series with imputed missing values. 13 | 14 | ![img](./figures/overview.png) 15 | 16 | ## Dependencies 17 | 18 | * Python >= 3.6 19 | * TensorFlow 1.15 20 | * Some more packages: see `requirements.txt` 21 | 22 | ## Run 23 | 1. Clone or download this repo and `cd` into its root directory. 24 | 2. Grab or build a working Python environment. [Anaconda](https://www.anaconda.com/) works fine. 25 | 3. Install dependencies using `pip install -r requirements.txt` 26 | 4. Download data: `bash data/load_<dataset>.sh`, where `<dataset>` is one of `hmnist`, `sprites`, `physionet`. 27 | 5. Run command `CUDA_VISIBLE_DEVICES=* python train.py --model_type {vae, hi-vae, gp-vae} --data_type {hmnist, sprites, physionet} --exp_name ...` 28 | 29 | To see all available flags run: `python train.py --help` 30 | 31 | ## Reproducibility 32 | 33 | We provide a set of hyperparameters used in our final runs. Some flags have common values for all datasets by default.
For reproducibility of reported results run: 34 | * HMNIST: `python train.py --model_type gp-vae --data_type hmnist --exp_name reproduce_hmnist --seed $RANDOM --testing --banded_covar 35 | --latent_dim 256 --encoder_sizes=256,256 --decoder_sizes=256,256,256 --window_size 3 --sigma 1 --length_scale 2 --beta 0.8 --num_epochs 20` 36 | * SPRITES: `python train.py --model_type gp-vae --data_type sprites --exp_name reproduce_sprites --seed $RANDOM --testing --banded_covar 37 | --latent_dim 256 --encoder_sizes=32,256,256 --decoder_sizes=256,256,256 --window_size 3 --sigma 1 --length_scale 2 --beta 0.1 --num_epochs 20` 38 | * Physionet: `python train.py --model_type gp-vae --data_type physionet --exp_name reproduce_physionet --seed $RANDOM --testing --banded_covar 39 | --latent_dim 35 --encoder_sizes=128,128 --decoder_sizes=256,256 --window_size 24 --sigma 1.005 --length_scale 7 --beta 0.2 --num_epochs 40` 40 | 41 | -------------------------------------------------------------------------------- /data/load_hmnist.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR="data/hmnist" 2 | random_mechanism="mnar" # one of: mnar, spatial, random, temporal_neg, temporal_pos (see branches below) 3 | 4 | mkdir -p ${DATA_DIR} 5 | 6 | if [ "$random_mechanism" == "mnar" ] ; then 7 | wget https://www.dropbox.com/s/xzhelx89bzpkkvq/hmnist_mnar.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz 8 | elif [ "$random_mechanism" == "spatial" ] ; then 9 | wget https://www.dropbox.com/s/jiix44usv7ibv1z/hmnist_spatial.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz 10 | elif [ "$random_mechanism" == "random" ] ; then 11 | wget https://www.dropbox.com/s/7s5y70f4idw9nei/hmnist_random.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz 12 | elif [ "$random_mechanism" == "temporal_neg" ] ; then 13 | wget https://www.dropbox.com/s/fnqi4rv9wtt2hqo/hmnist_temporal_neg.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz 14 | elif [ "$random_mechanism" == "temporal_pos" ] ; then 15 | wget
https://www.dropbox.com/s/tae3rdm9ouaicfb/hmnist_temporal_pos.npz?dl=1 -O ${DATA_DIR}/hmnist_${random_mechanism}.npz 16 | fi 17 | -------------------------------------------------------------------------------- /data/load_physionet.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR="data/physionet" 2 | 3 | mkdir -p ${DATA_DIR} 4 | wget https://www.dropbox.com/s/651d86winb4cy9n/physionet.npz?dl=1 -O ${DATA_DIR}/physionet.npz 5 | -------------------------------------------------------------------------------- /data/load_sprites.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR="data/sprites" 2 | 3 | mkdir -p ${DATA_DIR} 4 | wget https://www.dropbox.com/s/1bdpmsmf7vu7pmb/sprites.npz?dl=1 -O ${DATA_DIR}/sprites.npz 5 | -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ratschlab/GP-VAE/6e06c93d37d07ede2eb1f578b577c80a59846e2c/figures/overview.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .healing_mnist import * 2 | from .utils import * 3 | from .nn_utils import * 4 | from .gp_kernel import * 5 | from .models import * 6 | -------------------------------------------------------------------------------- /lib/gp_kernel.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ''' 4 | 5 | GP kernel functions 6 | 7 | ''' 8 | 9 | 10 | def rbf_kernel(T, length_scale): # K[i, j] = exp(-(i - j)**2 / length_scale**2) for i, j in range(T) 11 | xs = tf.range(T, dtype=tf.float32) 12 | xs_in = tf.expand_dims(xs, 0) 13 | xs_out = tf.expand_dims(xs, 1) 14 | distance_matrix = tf.math.squared_difference(xs_in, xs_out) 15 | distance_matrix_scaled = distance_matrix / length_scale ** 2 16 |
kernel_matrix = tf.math.exp(-distance_matrix_scaled) 17 | return kernel_matrix 18 | 19 | 20 | def diffusion_kernel(T, length_scale): # tridiagonal T x T matrix: 1 on the diagonal, length_scale on the first off-diagonals 21 | assert length_scale < 0.5, "length_scale has to be smaller than 0.5 for the "\ 22 | "kernel matrix to be diagonally dominant" # NOTE: assert statements are skipped under `python -O` 23 | sigmas = tf.ones(shape=[T, T]) * length_scale 24 | sigmas_tridiag = tf.linalg.band_part(sigmas, 1, 1) 25 | kernel_matrix = sigmas_tridiag + tf.eye(T)*(1. - length_scale) 26 | return kernel_matrix 27 | 28 | 29 | def matern_kernel(T, length_scale): # K[i, j] = exp(-|i - j| / sqrt(length_scale)) for i, j in range(T) 30 | xs = tf.range(T, dtype=tf.float32) 31 | xs_in = tf.expand_dims(xs, 0) 32 | xs_out = tf.expand_dims(xs, 1) 33 | distance_matrix = tf.math.abs(xs_in - xs_out) 34 | distance_matrix_scaled = distance_matrix / tf.cast(tf.math.sqrt(length_scale), dtype=tf.float32) # NOTE(review): scales by sqrt(length_scale); the textbook Matern-1/2 kernel uses exp(-|i - j| / length_scale) -- confirm intent 35 | kernel_matrix = tf.math.exp(-distance_matrix_scaled) 36 | return kernel_matrix 37 | 38 | 39 | def cauchy_kernel(T, sigma, length_scale): # K[i, j] = sigma / (1 + (i - j)**2 / length_scale**2), plus alpha * I 40 | xs = tf.range(T, dtype=tf.float32) 41 | xs_in = tf.expand_dims(xs, 0) 42 | xs_out = tf.expand_dims(xs, 1) 43 | distance_matrix = tf.math.squared_difference(xs_in, xs_out) 44 | distance_matrix_scaled = distance_matrix / length_scale ** 2 45 | kernel_matrix = tf.math.divide(sigma, (distance_matrix_scaled + 1.)) 46 | 47 | alpha = 0.001 # diagonal jitter, presumably to keep the matrix well-conditioned 48 | eye = tf.eye(num_rows=kernel_matrix.shape.as_list()[-1]) 49 | return kernel_matrix + alpha * eye 50 | -------------------------------------------------------------------------------- /lib/healing_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data loader for the Healing MNIST data set (c.f.
https://arxiv.org/abs/1511.05121) 3 | 4 | Adapted from https://github.com/Nikita6000/deep_kalman_filter_for_BM/blob/master/healing_mnist.py 5 | """ 6 | 7 | 8 | import numpy as np 9 | import scipy.ndimage 10 | from tensorflow.keras.datasets import mnist 11 | 12 | 13 | def apply_square(img, square_size): 14 | img = np.array(img) 15 | img[:square_size, :square_size] = 255 16 | return img 17 | 18 | 19 | def apply_noise(img, bit_flip_ratio): 20 | img = np.array(img) 21 | mask = np.random.random(size=(28,28)) < bit_flip_ratio 22 | img[mask] = 255 - img[mask] 23 | return img 24 | 25 | 26 | def get_rotations(img, rotation_steps): 27 | for rot in rotation_steps: 28 | img = scipy.ndimage.rotate(img, rot, reshape=False) 29 | yield img 30 | 31 | 32 | def binarize(img): 33 | return (img > 127).astype(np.int) 34 | 35 | 36 | def heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle): 37 | squares_begin = np.random.randint(0, seq_len - square_count) 38 | squares_end = squares_begin + square_count 39 | 40 | rotations = [] 41 | rotation_steps = np.random.normal(size=seq_len, scale=max_angle) 42 | 43 | for idx, rotation in enumerate(get_rotations(img, rotation_steps)): 44 | # Don't add the squares right now 45 | # if idx >= squares_begin and idx < squares_end: 46 | # rotation = apply_square(rotation, square_size) 47 | 48 | # Don't add noise for now 49 | # noisy_img = apply_noise(rotation, noise_ratio) 50 | noisy_img = rotation 51 | binarized_img = binarize(noisy_img) 52 | rotations.append(binarized_img) 53 | 54 | return rotations, rotation_steps 55 | 56 | 57 | class HealingMNIST(): 58 | def __init__(self, seq_len=5, square_count=3, square_size=5, noise_ratio=0.15, digits=range(10), max_angle=180): 59 | (x_train, y_train),(x_test, y_test) = mnist.load_data() 60 | mnist_train = [(img,label) for img, label in zip(x_train, y_train) if label in digits] 61 | mnist_test = [(img, label) for img, label in zip(x_test, y_test) if label in digits] 62 | 63 | train_images = [] 
64 | test_images = [] 65 | train_rotations = [] 66 | test_rotations = [] 67 | train_labels = [] 68 | test_labels = [] 69 | 70 | for img, label in mnist_train: 71 | train_img, train_rot = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle) 72 | train_images.append(train_img) 73 | train_rotations.append(train_rot) 74 | train_labels.append(label) 75 | 76 | for img, label in mnist_test: 77 | test_img, test_rot = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle) 78 | test_images.append(test_img) 79 | test_rotations.append(test_rot) 80 | test_labels.append(label) 81 | 82 | self.train_images = np.array(train_images) 83 | self.test_images = np.array(test_images) 84 | self.train_rotations = np.array(train_rotations) 85 | self.test_rotations = np.array(test_rotations) 86 | self.train_labels = np.array(train_labels) 87 | self.test_labels = np.array(test_labels) -------------------------------------------------------------------------------- /lib/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | TensorFlow models for use in this project. 4 | 5 | """ 6 | 7 | from .utils import * 8 | from .nn_utils import * 9 | from .gp_kernel import * 10 | from tensorflow_probability import distributions as tfd 11 | import tensorflow as tf 12 | 13 | 14 | # Encoders 15 | 16 | class DiagonalEncoder(tf.keras.Model): 17 | def __init__(self, z_size, hidden_sizes=(64, 64), **kwargs): 18 | """ Encoder with factorized Normal posterior over temporal dimension 19 | Used by disjoint VAE and HI-VAE with Standard Normal prior 20 | :param z_size: latent space dimensionality 21 | :param hidden_sizes: tuple of hidden layer sizes. 22 | The tuple length sets the number of hidden layers. 
23 | """ 24 | super(DiagonalEncoder, self).__init__() 25 | self.z_size = int(z_size) 26 | self.net = make_nn(2*z_size, hidden_sizes) 27 | 28 | def __call__(self, x): 29 | mapped = self.net(x) # last axis: [loc (z_size values) | pre-softplus scale (z_size values)] 30 | return tfd.MultivariateNormalDiag( 31 | loc=mapped[..., :self.z_size], 32 | scale_diag=tf.nn.softplus(mapped[..., self.z_size:])) 33 | 34 | 35 | class JointEncoder(tf.keras.Model): 36 | def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, transpose=False, **kwargs): 37 | """ Encoder with 1d-convolutional network and factorized Normal posterior 38 | Used by joint VAE and HI-VAE with Standard Normal prior or GP-VAE with factorized Normal posterior 39 | :param z_size: latent space dimensionality 40 | :param hidden_sizes: tuple of hidden layer sizes. 41 | The tuple length sets the number of hidden layers. 42 | :param window_size: kernel size for Conv1D layer 43 | :param transpose: True for GP prior | False for Standard Normal prior 44 | """ 45 | super(JointEncoder, self).__init__() 46 | self.z_size = int(z_size) 47 | self.net = make_cnn(2*z_size, hidden_sizes, window_size) 48 | self.transpose = transpose 49 | 50 | def __call__(self, x): 51 | mapped = self.net(x) 52 | if self.transpose: 53 | num_dim = len(x.shape.as_list()) 54 | perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2] 55 | mapped = tf.transpose(mapped, perm=perm) # swap the last two axes; loc/scale are then split along the second-to-last axis 56 | return tfd.MultivariateNormalDiag( 57 | loc=mapped[..., :self.z_size, :], 58 | scale_diag=tf.nn.softplus(mapped[..., self.z_size:, :])) 59 | return tfd.MultivariateNormalDiag( 60 | loc=mapped[..., :self.z_size], 61 | scale_diag=tf.nn.softplus(mapped[..., self.z_size:])) 62 | 63 | 64 | class BandedJointEncoder(tf.keras.Model): 65 | def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, data_type=None, **kwargs): 66 | """ Encoder with 1d-convolutional network and multivariate Normal posterior 67 | Used by GP-VAE with proposed banded covariance matrix 68 | :param z_size: latent space dimensionality 69 | :param hidden_sizes:
tuple of hidden layer sizes. 70 | The tuple length sets the number of hidden layers. 71 | :param window_size: kernel size for Conv1D layer 72 | :param data_type: needed for some data specific modifications, e.g: 73 | tf.nn.softplus is a more common and correct choice, however 74 | tf.nn.sigmoid provides more stable performance on Physionet dataset 75 | """ 76 | super(BandedJointEncoder, self).__init__() 77 | self.z_size = int(z_size) 78 | self.net = make_cnn(3*z_size, hidden_sizes, window_size) 79 | self.data_type = data_type 80 | 81 | def __call__(self, x): 82 | mapped = self.net(x) 83 | 84 | batch_size = mapped.shape.as_list()[0] 85 | time_length = mapped.shape.as_list()[1] 86 | 87 | # Obtain mean and precision matrix components 88 | num_dim = len(mapped.shape.as_list()) 89 | perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2] 90 | mapped_transposed = tf.transpose(mapped, perm=perm) 91 | mapped_mean = mapped_transposed[:, :self.z_size] 92 | mapped_covar = mapped_transposed[:, self.z_size:] # remaining 2*z_size rows parameterise the precision band 93 | 94 | # tf.nn.sigmoid provides more stable performance on Physionet dataset 95 | if self.data_type == 'physionet': 96 | mapped_covar = tf.nn.sigmoid(mapped_covar) 97 | else: 98 | mapped_covar = tf.nn.softplus(mapped_covar) 99 | 100 | mapped_reshaped = tf.reshape(mapped_covar, [batch_size, self.z_size, 2*time_length]) # 2*time_length band values per latent dimension 101 | 102 | dense_shape = [batch_size, self.z_size, time_length, time_length] 103 | idxs_1 = np.repeat(np.arange(batch_size), self.z_size*(2*time_length-1)) 104 | idxs_2 = np.tile(np.repeat(np.arange(self.z_size), (2*time_length-1)), batch_size) 105 | idxs_3 = np.tile(np.concatenate([np.arange(time_length), np.arange(time_length-1)]), batch_size*self.z_size) 106 | idxs_4 = np.tile(np.concatenate([np.arange(time_length), np.arange(1,time_length)]), batch_size*self.z_size) 107 | idxs_all = np.stack([idxs_1, idxs_2, idxs_3, idxs_4], axis=1) # (batch, latent, row, col) of the diagonal (i, i) and first super-diagonal (i, i+1) 108 | 109 | # ~10x faster on CPU than on GPU 110 | with tf.device('/cpu:0'): 111 | # Obtain covariance matrix
from precision one 112 | mapped_values = tf.reshape(mapped_reshaped[:, :, :-1], [-1]) # keep 2*time_length-1 values per latent dim: first time_length -> diagonal, next time_length-1 -> super-diagonal (order of idxs_all) 113 | prec_sparse = tf.sparse.SparseTensor(indices=idxs_all, values=mapped_values, dense_shape=dense_shape) 114 | prec_sparse = tf.sparse.reorder(prec_sparse) 115 | prec_tril = tf.sparse_add(tf.zeros(prec_sparse.dense_shape, dtype=tf.float32), prec_sparse) # densify; despite the name this factor is upper bidiagonal, hence lower=False below 116 | eye = tf.eye(num_rows=prec_tril.shape.as_list()[-1], batch_shape=prec_tril.shape.as_list()[:-2]) 117 | prec_tril = prec_tril + eye # adds 1 to every diagonal entry 118 | cov_tril = tf.linalg.triangular_solve(matrix=prec_tril, rhs=eye, lower=False) # inverse of the upper-triangular factor 119 | cov_tril = tf.where(tf.math.is_finite(cov_tril), cov_tril, tf.zeros_like(cov_tril)) # zero out non-finite entries 120 | 121 | num_dim = len(cov_tril.shape) 122 | perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2] 123 | cov_tril_lower = tf.transpose(cov_tril, perm=perm) # transpose to the lower-triangular scale_tril 124 | z_dist = tfd.MultivariateNormalTriL(loc=mapped_mean, scale_tril=cov_tril_lower) 125 | return z_dist 126 | 127 | 128 | # Decoders 129 | 130 | class Decoder(tf.keras.Model): 131 | def __init__(self, output_size, hidden_sizes=(64, 64)): 132 | """ Decoder parent class with no specified output distribution 133 | :param output_size: output dimensionality 134 | :param hidden_sizes: tuple of hidden layer sizes. 135 | The tuple length sets the number of hidden layers.
136 | """ 137 | super(Decoder, self).__init__() 138 | self.net = make_nn(output_size, hidden_sizes) 139 | 140 | def __call__(self, x): 141 | pass 142 | 143 | 144 | class BernoulliDecoder(Decoder): 145 | """ Decoder with Bernoulli output distribution (used for HMNIST) """ 146 | def __call__(self, x): 147 | mapped = self.net(x) 148 | return tfd.Bernoulli(logits=mapped) 149 | 150 | 151 | class GaussianDecoder(Decoder): 152 | """ Decoder with Gaussian output distribution (used for SPRITES and Physionet) """ 153 | def __call__(self, x): 154 | mean = self.net(x) 155 | var = tf.ones(tf.shape(mean), dtype=tf.float32) 156 | return tfd.Normal(loc=mean, scale=var) 157 | 158 | 159 | # Image preprocessor 160 | 161 | class ImagePreprocessor(tf.keras.Model): 162 | def __init__(self, image_shape, hidden_sizes=(256, ), kernel_size=3.): 163 | """ Decoder parent class without specified output distribution 164 | :param image_shape: input image size 165 | :param hidden_sizes: tuple of hidden layer sizes. 166 | The tuple length sets the number of hidden layers. 
167 | :param kernel_size: kernel/filter width and height 168 | """ 169 | super(ImagePreprocessor, self).__init__() 170 | self.image_shape = image_shape 171 | self.net = make_2d_cnn(image_shape[-1], hidden_sizes, kernel_size) 172 | 173 | def __call__(self, x): 174 | return self.net(x) 175 | 176 | 177 | # VAE models 178 | 179 | class VAE(tf.keras.Model): 180 | def __init__(self, latent_dim, data_dim, time_length, 181 | encoder_sizes=(64, 64), encoder=DiagonalEncoder, 182 | decoder_sizes=(64, 64), decoder=BernoulliDecoder, 183 | image_preprocessor=None, beta=1.0, M=1, K=1, **kwargs): 184 | """ Basic Variational Autoencoder with Standard Normal prior 185 | :param latent_dim: latent space dimensionality 186 | :param data_dim: original data dimensionality 187 | :param time_length: time series duration 188 | 189 | :param encoder_sizes: layer sizes for the encoder network 190 | :param encoder: encoder model class {Diagonal, Joint, BandedJoint}Encoder 191 | :param decoder_sizes: layer sizes for the decoder network 192 | :param decoder: decoder model class {Bernoulli, Gaussian}Decoder 193 | 194 | :param image_preprocessor: 2d-convolutional network used for image data preprocessing 195 | :param beta: tradeoff coefficient between reconstruction and KL terms in ELBO 196 | :param M: number of Monte Carlo samples for ELBO estimation 197 | :param K: number of importance weights for IWAE model (see: https://arxiv.org/abs/1509.00519) 198 | """ 199 | super(VAE, self).__init__() 200 | self.latent_dim = latent_dim 201 | self.data_dim = data_dim 202 | self.time_length = time_length 203 | 204 | self.encoder = encoder(latent_dim, encoder_sizes, **kwargs) 205 | self.decoder = decoder(data_dim, decoder_sizes) 206 | self.preprocessor = image_preprocessor 207 | 208 | self.beta = beta 209 | self.K = K 210 | self.M = M 211 | 212 | def encode(self, x): 213 | x = tf.identity(x) # in case x is not a Tensor already... 
214 | if self.preprocessor is not None: 215 | x_shape = x.shape.as_list() 216 | new_shape = [x_shape[0] * x_shape[1]] + list(self.preprocessor.image_shape) 217 | x_reshaped = tf.reshape(x, new_shape) 218 | x_preprocessed = self.preprocessor(x_reshaped) 219 | x = tf.reshape(x_preprocessed, x_shape) 220 | return self.encoder(x) 221 | 222 | def decode(self, z): 223 | z = tf.identity(z) # in case z is not a Tensor already... 224 | return self.decoder(z) 225 | 226 | def __call__(self, inputs): 227 | return self.decode(self.encode(inputs).sample()).sample() 228 | 229 | def generate(self, noise=None, num_samples=1): 230 | if noise is None: 231 | noise = tf.random_normal(shape=(num_samples, self.latent_dim)) 232 | return self.decode(noise) 233 | 234 | def _get_prior(self): 235 | if self.prior is None: 236 | self.prior = tfd.MultivariateNormalDiag(loc=tf.zeros(self.latent_dim, dtype=tf.float32), 237 | scale_diag=tf.ones(self.latent_dim, dtype=tf.float32)) 238 | return self.prior 239 | 240 | def compute_nll(self, x, y=None, m_mask=None): 241 | # Used only for evaluation 242 | assert len(x.shape) == 3, "Input should have shape: [batch_size, time_length, data_dim]" 243 | if y is None: y = x 244 | 245 | z_sample = self.encode(x).sample() 246 | x_hat_dist = self.decode(z_sample) 247 | nll = -x_hat_dist.log_prob(y) # shape=(BS, TL, D) 248 | nll = tf.where(tf.math.is_finite(nll), nll, tf.zeros_like(nll)) 249 | if m_mask is not None: 250 | m_mask = tf.cast(m_mask, tf.bool) 251 | nll = tf.where(m_mask, nll, tf.zeros_like(nll)) # !!! 
inverse mask, set zeros for observed 252 | return tf.reduce_sum(nll) 253 | 254 | def compute_mse(self, x, y=None, m_mask=None, binary=False): 255 | # Used only for evaluation 256 | assert len(x.shape) == 3, "Input should have shape: [batch_size, time_length, data_dim]" 257 | if y is None: y = x 258 | 259 | z_mean = self.encode(x).mean() 260 | x_hat_mean = self.decode(z_mean).mean() # shape=(BS, TL, D) 261 | if binary: 262 | x_hat_mean = tf.round(x_hat_mean) 263 | mse = tf.math.squared_difference(x_hat_mean, y) 264 | if m_mask is not None: 265 | m_mask = tf.cast(m_mask, tf.bool) 266 | mse = tf.where(m_mask, mse, tf.zeros_like(mse)) # !!! inverse mask, set zeros for observed 267 | return tf.reduce_sum(mse) 268 | 269 | def _compute_loss(self, x, m_mask=None, return_parts=False): 270 | assert len(x.shape) == 3, "Input should have shape: [batch_size, time_length, data_dim]" 271 | x = tf.identity(x) # in case x is not a Tensor already... 272 | x = tf.tile(x, [self.M * self.K, 1, 1]) # shape=(M*K*BS, TL, D) 273 | 274 | if m_mask is not None: 275 | m_mask = tf.identity(m_mask) # in case m_mask is not a Tensor already... 
276 | m_mask = tf.tile(m_mask, [self.M * self.K, 1, 1]) # shape=(M*K*BS, TL, D) 277 | m_mask = tf.cast(m_mask, tf.bool) 278 | 279 | pz = self._get_prior() 280 | qz_x = self.encode(x) 281 | z = qz_x.sample() 282 | px_z = self.decode(z) 283 | 284 | nll = -px_z.log_prob(x) # shape=(M*K*BS, TL, D) 285 | nll = tf.where(tf.math.is_finite(nll), nll, tf.zeros_like(nll)) 286 | if m_mask is not None: 287 | nll = tf.where(m_mask, tf.zeros_like(nll), nll) # if not HI-VAE, m_mask is always zeros 288 | nll = tf.reduce_sum(nll, [1, 2]) # shape=(M*K*BS) 289 | 290 | if self.K > 1: 291 | kl = qz_x.log_prob(z) - pz.log_prob(z) # shape=(M*K*BS, TL or d) 292 | kl = tf.where(tf.is_finite(kl), kl, tf.zeros_like(kl)) 293 | kl = tf.reduce_sum(kl, 1) # shape=(M*K*BS) 294 | 295 | weights = -nll - kl # shape=(M*K*BS) 296 | weights = tf.reshape(weights, [self.M, self.K, -1]) # shape=(M, K, BS) 297 | 298 | elbo = reduce_logmeanexp(weights, axis=1) # shape=(M, 1, BS) 299 | elbo = tf.reduce_mean(elbo) # scalar 300 | else: 301 | # if K==1, compute KL analytically 302 | kl = self.kl_divergence(qz_x, pz) # shape=(M*K*BS, TL or d) 303 | kl = tf.where(tf.math.is_finite(kl), kl, tf.zeros_like(kl)) 304 | kl = tf.reduce_sum(kl, 1) # shape=(M*K*BS) 305 | 306 | elbo = -nll - self.beta * kl # shape=(M*K*BS) K=1 307 | elbo = tf.reduce_mean(elbo) # scalar 308 | 309 | if return_parts: 310 | nll = tf.reduce_mean(nll) # scalar 311 | kl = tf.reduce_mean(kl) # scalar 312 | return -elbo, nll, kl 313 | else: 314 | return -elbo 315 | 316 | def compute_loss(self, x, m_mask=None, return_parts=False): 317 | del m_mask 318 | return self._compute_loss(x, return_parts=return_parts) 319 | 320 | def kl_divergence(self, a, b): 321 | return tfd.kl_divergence(a, b) 322 | 323 | def get_trainable_vars(self): 324 | self.compute_loss(tf.random.normal(shape=(1, self.time_length, self.data_dim), dtype=tf.float32), 325 | tf.zeros(shape=(1, self.time_length, self.data_dim), dtype=tf.float32)) 326 | return self.trainable_variables 327 | 
328 | 329 | class HI_VAE(VAE): 330 | """ HI-VAE model, where the reconstruction term in ELBO is summed only over observed components """ 331 | def compute_loss(self, x, m_mask=None, return_parts=False): 332 | return self._compute_loss(x, m_mask=m_mask, return_parts=return_parts) 333 | 334 | 335 | class GP_VAE(HI_VAE): 336 | def __init__(self, *args, kernel="cauchy", sigma=1., length_scale=1.0, kernel_scales=1, **kwargs): 337 | """ Proposed GP-VAE model with Gaussian Process prior 338 | :param kernel: Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] 339 | :param sigma: scale parameter for a kernel function 340 | :param length_scale: length scale parameter for a kernel function 341 | :param kernel_scales: number of different length scales over latent space dimensions 342 | """ 343 | super(GP_VAE, self).__init__(*args, **kwargs) 344 | self.kernel = kernel 345 | self.sigma = sigma 346 | self.length_scale = length_scale 347 | self.kernel_scales = kernel_scales 348 | 349 | if isinstance(self.encoder, JointEncoder): 350 | self.encoder.transpose = True 351 | 352 | # Precomputed KL components for efficiency 353 | self.pz_scale_inv = None 354 | self.pz_scale_log_abs_determinant = None 355 | self.prior = None 356 | 357 | def decode(self, z): 358 | num_dim = len(z.shape) 359 | assert num_dim > 2 360 | perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2] 361 | return self.decoder(tf.transpose(z, perm=perm)) 362 | 363 | def _get_prior(self): 364 | if self.prior is None: 365 | # Compute kernel matrices for each latent dimension 366 | kernel_matrices = [] 367 | for i in range(self.kernel_scales): 368 | if self.kernel == "rbf": 369 | kernel_matrices.append(rbf_kernel(self.time_length, self.length_scale / 2**i)) 370 | elif self.kernel == "diffusion": 371 | kernel_matrices.append(diffusion_kernel(self.time_length, self.length_scale / 2**i)) 372 | elif self.kernel == "matern": 373 | kernel_matrices.append(matern_kernel(self.time_length, self.length_scale / 
2**i)) 374 | elif self.kernel == "cauchy": 375 | kernel_matrices.append(cauchy_kernel(self.time_length, self.sigma, self.length_scale / 2**i)) 376 | 377 | # Combine kernel matrices for each latent dimension 378 | tiled_matrices = [] 379 | total = 0 380 | for i in range(self.kernel_scales): 381 | if i == self.kernel_scales-1: 382 | multiplier = self.latent_dim - total 383 | else: 384 | multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) 385 | total += multiplier 386 | tiled_matrices.append(tf.tile(tf.expand_dims(kernel_matrices[i], 0), [multiplier, 1, 1])) 387 | kernel_matrix_tiled = np.concatenate(tiled_matrices) 388 | assert len(kernel_matrix_tiled) == self.latent_dim 389 | 390 | self.prior = tfd.MultivariateNormalFullCovariance( 391 | loc=tf.zeros([self.latent_dim, self.time_length], dtype=tf.float32), 392 | covariance_matrix=kernel_matrix_tiled) 393 | return self.prior 394 | 395 | def kl_divergence(self, a, b): 396 | """ Batched KL divergence `KL(a || b)` for multivariate Normals. 397 | See https://github.com/tensorflow/probability/blob/master/tensorflow_probability 398 | /python/distributions/mvn_linear_operator.py 399 | It's used instead of default KL class in order to exploit precomputed components for efficiency 400 | """ 401 | 402 | def squared_frobenius_norm(x): 403 | """Helper to make KL calculation slightly more readable.""" 404 | return tf.reduce_sum(tf.square(x), axis=[-2, -1]) 405 | 406 | def is_diagonal(x): 407 | """Helper to identify if `LinearOperator` has only a diagonal component.""" 408 | return (isinstance(x, tf.linalg.LinearOperatorIdentity) or 409 | isinstance(x, tf.linalg.LinearOperatorScaledIdentity) or 410 | isinstance(x, tf.linalg.LinearOperatorDiag)) 411 | 412 | if is_diagonal(a.scale) and is_diagonal(b.scale): 413 | # Using `stddev` because it handles expansion of Identity cases. 
414 | b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis] 415 | else: 416 | if self.pz_scale_inv is None: 417 | self.pz_scale_inv = tf.linalg.inv(b.scale.to_dense()) 418 | self.pz_scale_inv = tf.where(tf.math.is_finite(self.pz_scale_inv), 419 | self.pz_scale_inv, tf.zeros_like(self.pz_scale_inv)) 420 | 421 | if self.pz_scale_log_abs_determinant is None: 422 | self.pz_scale_log_abs_determinant = b.scale.log_abs_determinant() 423 | 424 | a_shape = a.scale.shape 425 | if len(b.scale.shape) == 3: 426 | _b_scale_inv = tf.tile(self.pz_scale_inv[tf.newaxis], [a_shape[0]] + [1] * (len(a_shape) - 1)) 427 | else: 428 | _b_scale_inv = tf.tile(self.pz_scale_inv, [a_shape[0]] + [1] * (len(a_shape) - 1)) 429 | 430 | b_inv_a = _b_scale_inv @ a.scale.to_dense() 431 | 432 | # ~10x times faster on CPU then on GPU 433 | with tf.device('/cpu:0'): 434 | kl_div = (self.pz_scale_log_abs_determinant - a.scale.log_abs_determinant() + 435 | 0.5 * (-tf.cast(a.scale.domain_dimension_tensor(), a.dtype) + 436 | squared_frobenius_norm(b_inv_a) + squared_frobenius_norm( 437 | b.scale.solve((b.mean() - a.mean())[..., tf.newaxis])))) 438 | return kl_div 439 | -------------------------------------------------------------------------------- /lib/nn_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | ''' NN utils ''' 5 | 6 | 7 | def make_nn(output_size, hidden_sizes): 8 | """ Creates fully connected neural network 9 | :param output_size: output dimensionality 10 | :param hidden_sizes: tuple of hidden layer sizes. 11 | The tuple length sets the number of hidden layers. 
12 | """ 13 | layers = [tf.keras.layers.Dense(h, activation=tf.nn.relu, dtype=tf.float32) 14 | for h in hidden_sizes] 15 | layers.append(tf.keras.layers.Dense(output_size, dtype=tf.float32)) 16 | return tf.keras.Sequential(layers) 17 | 18 | 19 | def make_cnn(output_size, hidden_sizes, kernel_size=3): 20 | """ Construct neural network consisting of 21 | one 1d-convolutional layer that utilizes temporal dependences, 22 | fully connected network 23 | 24 | :param output_size: output dimensionality 25 | :param hidden_sizes: tuple of hidden layer sizes. 26 | The tuple length sets the number of hidden layers. 27 | :param kernel_size: kernel size for convolutional layer 28 | """ 29 | cnn_layer = [tf.keras.layers.Conv1D(hidden_sizes[0], kernel_size=kernel_size, 30 | padding="same", dtype=tf.float32)] 31 | layers = [tf.keras.layers.Dense(h, activation=tf.nn.relu, dtype=tf.float32) 32 | for h in hidden_sizes[1:]] 33 | layers.append(tf.keras.layers.Dense(output_size, dtype=tf.float32)) 34 | return tf.keras.Sequential(cnn_layer + layers) 35 | 36 | 37 | def make_2d_cnn(output_size, hidden_sizes, kernel_size=3): 38 | """ Creates fully convolutional neural network. 39 | Used as CNN preprocessor for image data (HMNIST, SPRITES) 40 | 41 | :param output_size: output dimensionality 42 | :param hidden_sizes: tuple of hidden layer sizes. 43 | The tuple length sets the number of hidden layers. 
44 | :param kernel_size: kernel size for convolutional layers 45 | """ 46 | layers = [tf.keras.layers.Conv2D(h, kernel_size=kernel_size, padding="same", 47 | activation=tf.nn.relu, dtype=tf.float32) 48 | for h in hidden_sizes + [output_size]] 49 | return tf.keras.Sequential(layers) -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | ''' TF utils ''' 6 | 7 | 8 | def reduce_logmeanexp(x, axis, eps=1e-5): 9 | """Numerically-stable (?) implementation of log-mean-exp. 10 | Args: 11 | x: The tensor to reduce. Should have numeric type. 12 | axis: The dimensions to reduce. If `None` (the default), 13 | reduces all dimensions. Must be in the range 14 | `[-rank(input_tensor), rank(input_tensor)]`. 15 | eps: Floating point scalar to avoid log-underflow. 16 | Returns: 17 | log_mean_exp: A `Tensor` representing `log(Avg{exp(x): x})`. 18 | """ 19 | x_max = tf.reduce_max(x, axis=axis, keepdims=True) 20 | return tf.log(tf.reduce_mean( 21 | tf.exp(x - x_max), axis=axis, keepdims=True) + eps) + x_max 22 | 23 | 24 | def multiply_tfd_gaussians(gaussians): 25 | """Multiplies two tfd.MultivariateNormal distributions.""" 26 | mus = [gauss.mean() for gauss in gaussians] 27 | Sigmas = [gauss.covariance() for gauss in gaussians] 28 | mu_3, Sigma_3, _ = multiply_gaussians(mus, Sigmas) 29 | return tfd.MultivariateNormalFullCovariance(loc=mu_3, covariance_matrix=Sigma_3) 30 | 31 | 32 | def multiply_inv_gaussians(mus, lambdas): 33 | """Multiplies a series of Gaussians that is given as a list of mean vectors and a list of precision matrices. 
34 | mus: list of mean with shape [n, d] 35 | lambdas: list of precision matrices with shape [n, d, d] 36 | Returns the mean vector, covariance matrix, and precision matrix of the product 37 | """ 38 | assert len(mus) == len(lambdas) 39 | batch_size = int(mus[0].shape[0]) 40 | d_z = int(lambdas[0].shape[-1]) 41 | identity_matrix = tf.reshape(tf.tile(tf.eye(d_z), [batch_size,1]), [-1,d_z,d_z]) 42 | lambda_new = tf.reduce_sum(lambdas, axis=0) + identity_matrix 43 | mus_summed = tf.reduce_sum([tf.einsum("bij, bj -> bi", lamb, mu) 44 | for lamb, mu in zip(lambdas, mus)], axis=0) 45 | sigma_new = tf.linalg.inv(lambda_new) 46 | mu_new = tf.einsum("bij, bj -> bi", sigma_new, mus_summed) 47 | return mu_new, sigma_new, lambda_new 48 | 49 | 50 | def multiply_inv_gaussians_batch(mus, lambdas): 51 | """Multiplies a series of Gaussians that is given as a list of mean vectors and a list of precision matrices. 52 | mus: list of mean with shape [..., d] 53 | lambdas: list of precision matrices with shape [..., d, d] 54 | Returns the mean vector, covariance matrix, and precision matrix of the product 55 | """ 56 | assert len(mus) == len(lambdas) 57 | batch_size = mus[0].shape.as_list()[:-1] 58 | d_z = lambdas[0].shape.as_list()[-1] 59 | identity_matrix = tf.tile(tf.expand_dims(tf.expand_dims(tf.eye(d_z), axis=0), axis=0), batch_size+[1,1]) 60 | lambda_new = tf.reduce_sum(lambdas, axis=0) + identity_matrix 61 | mus_summed = tf.reduce_sum([tf.einsum("bcij, bcj -> bci", lamb, mu) 62 | for lamb, mu in zip(lambdas, mus)], axis=0) 63 | sigma_new = tf.linalg.inv(lambda_new) 64 | mu_new = tf.einsum("bcij, bcj -> bci", sigma_new, mus_summed) 65 | return mu_new, sigma_new, lambda_new 66 | 67 | 68 | def multiply_gaussians(mus, sigmas): 69 | """Multiplies a series of Gaussians that is given as a list of mean vectors and a list of covariance matrices. 
70 | mus: list of mean with shape [n, d] 71 | sigmas: list of covariance matrices with shape [n, d, d] 72 | Returns the mean vector, covariance matrix, and precision matrix of the product 73 | """ 74 | assert len(mus) == len(sigmas) 75 | batch_size = [int(n) for n in mus[0].shape[0]] 76 | d_z = int(sigmas[0].shape[-1]) 77 | identity_matrix = tf.reshape(tf.tile(tf.eye(d_z), [batch_size,1]), batch_size+[d_z,d_z]) 78 | sigma_new = identity_matrix 79 | mu_new = tf.zeros((batch_size, d_z)) 80 | for mu, sigma in zip(mus, sigmas): 81 | sigma_inv = tf.linalg.inv(sigma_new + sigma) 82 | sigma_prod = tf.matmul(tf.matmul(sigma_new, sigma_inv), sigma) 83 | mu_prod = (tf.einsum("bij,bj->bi", tf.matmul(sigma, sigma_inv), mu_new) 84 | + tf.einsum("bij,bj->bi", tf.matmul(sigma_new, sigma_inv), mu)) 85 | sigma_new = sigma_prod 86 | mu_new = mu_prod 87 | lambda_new = tf.linalg.inv(sigma_new) 88 | return mu_new, sigma_new, lambda_new 89 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ratschlab/GP-VAE/6e06c93d37d07ede2eb1f578b577c80a59846e2c/models/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.0 2 | numpy==1.16.4 3 | scipy==1.2.0 4 | tensorflow==1.15.0 5 | tensorflow-gpu==1.15.0 6 | tensorflow_probability==0.7.0 7 | matplotlib 8 | scikit-learn 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Script to train the proposed GP-VAE model. 
4 | 5 | """ 6 | 7 | import sys 8 | import os 9 | import time 10 | from datetime import datetime 11 | import numpy as np 12 | import matplotlib 13 | matplotlib.use("Agg") 14 | from matplotlib import pyplot as plt 15 | import tensorflow as tf 16 | 17 | tf.compat.v1.enable_eager_execution() 18 | 19 | from sklearn.metrics import average_precision_score, roc_auc_score 20 | from sklearn.linear_model import LogisticRegression 21 | 22 | import warnings 23 | warnings.simplefilter(action='ignore', category=FutureWarning) 24 | 25 | from absl import app 26 | from absl import flags 27 | 28 | sys.path.append("..") 29 | from lib.models import * 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | # HMNIST config 35 | # flags.DEFINE_integer('latent_dim', 256, 'Dimensionality of the latent space') 36 | # flags.DEFINE_list('encoder_sizes', [256, 256], 'Layer sizes of the encoder') 37 | # flags.DEFINE_list('decoder_sizes', [256, 256, 256], 'Layer sizes of the decoder') 38 | # flags.DEFINE_integer('window_size', 3, 'Window size for the inference CNN: Ignored if model_type is not gp-vae') 39 | # flags.DEFINE_float('sigma', 1.0, 'Sigma value for the GP prior: Ignored if model_type is not gp-vae') 40 | # flags.DEFINE_float('length_scale', 2.0, 'Length scale value for the GP prior: Ignored if model_type is not gp-vae') 41 | # flags.DEFINE_float('beta', 0.8, 'Factor to weigh the KL term (similar to beta-VAE)') 42 | # flags.DEFINE_integer('num_epochs', 20, 'Number of training epochs') 43 | 44 | # SPRITES config GP-VAE 45 | # flags.DEFINE_integer('latent_dim', 256, 'Dimensionality of the latent space') 46 | # flags.DEFINE_list('encoder_sizes', [32, 256, 256], 'Layer sizes of the encoder') 47 | # flags.DEFINE_list('decoder_sizes', [256, 256, 256], 'Layer sizes of the decoder') 48 | # flags.DEFINE_integer('window_size', 3, 'Window size for the inference CNN: Ignored if model_type is not gp-vae') 49 | # flags.DEFINE_float('sigma', 1.0, 'Sigma value for the GP prior: Ignored if model_type is not 
gp-vae') 50 | # flags.DEFINE_float('length_scale', 2.0, 'Length scale value for the GP prior: Ignored if model_type is not gp-vae') 51 | # flags.DEFINE_float('beta', 0.1, 'Factor to weigh the KL term (similar to beta-VAE)') 52 | # flags.DEFINE_integer('num_epochs', 20, 'Number of training epochs') 53 | 54 | # Physionet config 55 | flags.DEFINE_integer('latent_dim', 35, 'Dimensionality of the latent space') 56 | flags.DEFINE_list('encoder_sizes', [128, 128], 'Layer sizes of the encoder') 57 | flags.DEFINE_list('decoder_sizes', [256, 256], 'Layer sizes of the decoder') 58 | flags.DEFINE_integer('window_size', 24, 'Window size for the inference CNN: Ignored if model_type is not gp-vae') 59 | flags.DEFINE_float('sigma', 1.005, 'Sigma value for the GP prior: Ignored if model_type is not gp-vae') 60 | flags.DEFINE_float('length_scale', 7.0, 'Length scale value for the GP prior: Ignored if model_type is not gp-vae') 61 | flags.DEFINE_float('beta', 0.2, 'Factor to weigh the KL term (similar to beta-VAE)') 62 | flags.DEFINE_integer('num_epochs', 40, 'Number of training epochs') 63 | 64 | # Flags with common default values for all three datasets 65 | flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate for training') 66 | flags.DEFINE_float('gradient_clip', 1e4, 'Maximum global gradient norm for the gradient clipping during training') 67 | flags.DEFINE_integer('num_steps', 0, 'Number of training steps: If non-zero it overwrites num_epochs') 68 | flags.DEFINE_integer('print_interval', 0, 'Interval for printing the loss and saving the model during training') 69 | flags.DEFINE_string('exp_name', "debug", 'Name of the experiment') 70 | flags.DEFINE_string('basedir', "models", 'Directory where the models should be stored') 71 | flags.DEFINE_string('data_dir', "", 'Directory from where the data should be read in') 72 | flags.DEFINE_enum('data_type', 'hmnist', ['hmnist', 'physionet', 'sprites'], 'Type of data to be trained on') 73 | flags.DEFINE_integer('seed', 1337, 'Seed for 
the random number generator') 74 | flags.DEFINE_enum('model_type', 'gp-vae', ['vae', 'hi-vae', 'gp-vae'], 'Type of model to be trained') 75 | flags.DEFINE_integer('cnn_kernel_size', 3, 'Kernel size for the CNN preprocessor') 76 | flags.DEFINE_list('cnn_sizes', [256], 'Number of filters for the layers of the CNN preprocessor') 77 | flags.DEFINE_boolean('testing', False, 'Use the actual test set for testing') 78 | flags.DEFINE_boolean('banded_covar', False, 'Use a banded covariance matrix instead of a diagonal one for the output of the inference network: Ignored if model_type is not gp-vae') 79 | flags.DEFINE_integer('batch_size', 64, 'Batch size for training') 80 | 81 | flags.DEFINE_integer('M', 1, 'Number of samples for ELBO estimation') 82 | flags.DEFINE_integer('K', 1, 'Number of importance sampling weights') 83 | 84 | flags.DEFINE_enum('kernel', 'cauchy', ['rbf', 'diffusion', 'matern', 'cauchy'], 'Kernel to be used for the GP prior: Ignored if model_type is not (m)gp-vae') 85 | flags.DEFINE_integer('kernel_scales', 1, 'Number of different length scales sigma for the GP prior: Ignored if model_type is not gp-vae') 86 | 87 | 88 | def main(argv): 89 | del argv # unused 90 | np.random.seed(FLAGS.seed) 91 | tf.compat.v1.set_random_seed(FLAGS.seed) 92 | 93 | print("Testing: ", FLAGS.testing, f"\t Seed: {FLAGS.seed}") 94 | 95 | FLAGS.encoder_sizes = [int(size) for size in FLAGS.encoder_sizes] 96 | FLAGS.decoder_sizes = [int(size) for size in FLAGS.decoder_sizes] 97 | 98 | if 0 in FLAGS.encoder_sizes: 99 | FLAGS.encoder_sizes.remove(0) 100 | if 0 in FLAGS.decoder_sizes: 101 | FLAGS.decoder_sizes.remove(0) 102 | 103 | # Make up full exp name 104 | timestamp = datetime.now().strftime("%y%m%d") 105 | full_exp_name = "{}_{}".format(timestamp, FLAGS.exp_name) 106 | outdir = os.path.join(FLAGS.basedir, full_exp_name) 107 | if not os.path.exists(outdir): os.mkdir(outdir) 108 | checkpoint_prefix = os.path.join(outdir, "ckpt") 109 | print("Full exp name: ", full_exp_name) 110 | 
111 | 112 | ################################### 113 | # Define data specific parameters # 114 | ################################### 115 | 116 | if FLAGS.data_type == "hmnist": 117 | FLAGS.data_dir = "data/hmnist/hmnist_mnar.npz" 118 | data_dim = 784 119 | time_length = 10 120 | num_classes = 10 121 | decoder = BernoulliDecoder 122 | img_shape = (28, 28, 1) 123 | val_split = 50000 124 | elif FLAGS.data_type == "physionet": 125 | if FLAGS.data_dir == "": 126 | FLAGS.data_dir = "data/physionet/physionet.npz" 127 | data_dim = 35 128 | time_length = 48 129 | num_classes = 2 130 | 131 | decoder = GaussianDecoder 132 | elif FLAGS.data_type == "sprites": 133 | if FLAGS.data_dir == "": 134 | FLAGS.data_dir = "data/sprites/sprites.npz" 135 | data_dim = 12288 136 | time_length = 8 137 | decoder = GaussianDecoder 138 | img_shape = (64, 64, 3) 139 | val_split = 8000 140 | else: 141 | raise ValueError("Data type must be one of ['hmnist', 'physionet', 'sprites']") 142 | 143 | 144 | ############# 145 | # Load data # 146 | ############# 147 | 148 | data = np.load(FLAGS.data_dir) 149 | x_train_full = data['x_train_full'] 150 | x_train_miss = data['x_train_miss'] 151 | m_train_miss = data['m_train_miss'] 152 | if FLAGS.data_type in ['hmnist', 'physionet']: 153 | y_train = data['y_train'] 154 | 155 | if FLAGS.testing: 156 | if FLAGS.data_type in ['hmnist', 'sprites']: 157 | x_val_full = data['x_test_full'] 158 | x_val_miss = data['x_test_miss'] 159 | m_val_miss = data['m_test_miss'] 160 | if FLAGS.data_type == 'hmnist': 161 | y_val = data['y_test'] 162 | elif FLAGS.data_type == 'physionet': 163 | x_val_full = data['x_train_full'] 164 | x_val_miss = data['x_train_miss'] 165 | m_val_miss = data['m_train_miss'] 166 | y_val = data['y_train'] 167 | m_val_artificial = data["m_train_artificial"] 168 | elif FLAGS.data_type in ['hmnist', 'sprites']: 169 | x_val_full = x_train_full[val_split:] 170 | x_val_miss = x_train_miss[val_split:] 171 | m_val_miss = m_train_miss[val_split:] 172 | if 
FLAGS.data_type == 'hmnist': 173 | y_val = y_train[val_split:] 174 | x_train_full = x_train_full[:val_split] 175 | x_train_miss = x_train_miss[:val_split] 176 | m_train_miss = m_train_miss[:val_split] 177 | y_train = y_train[:val_split] 178 | elif FLAGS.data_type == 'physionet': 179 | x_val_full = data["x_val_full"] # full for artificial missings 180 | x_val_miss = data["x_val_miss"] 181 | m_val_miss = data["m_val_miss"] 182 | m_val_artificial = data["m_val_artificial"] 183 | y_val = data["y_val"] 184 | else: 185 | raise ValueError("Data type must be one of ['hmnist', 'physionet', 'sprites']") 186 | 187 | tf_x_train_miss = tf.data.Dataset.from_tensor_slices((x_train_miss, m_train_miss))\ 188 | .shuffle(len(x_train_miss)).batch(FLAGS.batch_size).repeat() 189 | tf_x_val_miss = tf.data.Dataset.from_tensor_slices((x_val_miss, m_val_miss)).batch(FLAGS.batch_size).repeat() 190 | tf_x_val_miss = tf.compat.v1.data.make_one_shot_iterator(tf_x_val_miss) 191 | 192 | # Build Conv2D preprocessor for image data 193 | if FLAGS.data_type in ['hmnist', 'sprites']: 194 | print("Using CNN preprocessor") 195 | image_preprocessor = ImagePreprocessor(img_shape, FLAGS.cnn_sizes, FLAGS.cnn_kernel_size) 196 | elif FLAGS.data_type == 'physionet': 197 | image_preprocessor = None 198 | else: 199 | raise ValueError("Data type must be one of ['hmnist', 'physionet', 'sprites']") 200 | 201 | 202 | ############### 203 | # Build model # 204 | ############### 205 | 206 | if FLAGS.model_type == "vae": 207 | model = VAE(latent_dim=FLAGS.latent_dim, data_dim=data_dim, time_length=time_length, 208 | encoder_sizes=FLAGS.encoder_sizes, encoder=DiagonalEncoder, 209 | decoder_sizes=FLAGS.decoder_sizes, decoder=decoder, 210 | image_preprocessor=image_preprocessor, window_size=FLAGS.window_size, 211 | beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K) 212 | elif FLAGS.model_type == "hi-vae": 213 | model = HI_VAE(latent_dim=FLAGS.latent_dim, data_dim=data_dim, time_length=time_length, 214 | 
encoder_sizes=FLAGS.encoder_sizes, encoder=DiagonalEncoder, 215 | decoder_sizes=FLAGS.decoder_sizes, decoder=decoder, 216 | image_preprocessor=image_preprocessor, window_size=FLAGS.window_size, 217 | beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K) 218 | elif FLAGS.model_type == "gp-vae": 219 | encoder = BandedJointEncoder if FLAGS.banded_covar else JointEncoder 220 | model = GP_VAE(latent_dim=FLAGS.latent_dim, data_dim=data_dim, time_length=time_length, 221 | encoder_sizes=FLAGS.encoder_sizes, encoder=encoder, 222 | decoder_sizes=FLAGS.decoder_sizes, decoder=decoder, 223 | kernel=FLAGS.kernel, sigma=FLAGS.sigma, 224 | length_scale=FLAGS.length_scale, kernel_scales = FLAGS.kernel_scales, 225 | image_preprocessor=image_preprocessor, window_size=FLAGS.window_size, 226 | beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K, data_type=FLAGS.data_type) 227 | else: 228 | raise ValueError("Model type must be one of ['vae', 'hi-vae', 'gp-vae']") 229 | 230 | 231 | ######################## 232 | # Training preparation # 233 | ######################## 234 | 235 | print("GPU support: ", tf.test.is_gpu_available()) 236 | 237 | print("Training...") 238 | _ = tf.compat.v1.train.get_or_create_global_step() 239 | trainable_vars = model.get_trainable_vars() 240 | optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) 241 | 242 | print("Encoder: ", model.encoder.net.summary()) 243 | print("Decoder: ", model.decoder.net.summary()) 244 | 245 | if model.preprocessor is not None: 246 | print("Preprocessor: ", model.preprocessor.net.summary()) 247 | saver = tf.compat.v1.train.Checkpoint(optimizer=optimizer, encoder=model.encoder.net, 248 | decoder=model.decoder.net, preprocessor=model.preprocessor.net, 249 | optimizer_step=tf.compat.v1.train.get_or_create_global_step()) 250 | else: 251 | saver = tf.compat.v1.train.Checkpoint(optimizer=optimizer, encoder=model.encoder.net, decoder=model.decoder.net, 252 | optimizer_step=tf.compat.v1.train.get_or_create_global_step()) 253 | 254 | 
summary_writer = tf.contrib.summary.create_file_writer(outdir, flush_millis=10000) 255 | 256 | if FLAGS.num_steps == 0: 257 | num_steps = FLAGS.num_epochs * len(x_train_miss) // FLAGS.batch_size 258 | else: 259 | num_steps = FLAGS.num_steps 260 | 261 | if FLAGS.print_interval == 0: 262 | FLAGS.print_interval = num_steps // FLAGS.num_epochs 263 | 264 | 265 | ############ 266 | # Training # 267 | ############ 268 | 269 | losses_train = [] 270 | losses_val = [] 271 | 272 | t0 = time.time() 273 | with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): 274 | for i, (x_seq, m_seq) in enumerate(tf_x_train_miss.take(num_steps)): 275 | try: 276 | with tf.GradientTape() as tape: 277 | tape.watch(trainable_vars) 278 | loss = model.compute_loss(x_seq, m_mask=m_seq) 279 | losses_train.append(loss.numpy()) 280 | grads = tape.gradient(loss, trainable_vars) 281 | grads = [np.nan_to_num(grad) for grad in grads] 282 | grads, global_norm = tf.clip_by_global_norm(grads, FLAGS.gradient_clip) 283 | optimizer.apply_gradients(zip(grads, trainable_vars), 284 | global_step=tf.compat.v1.train.get_or_create_global_step()) 285 | 286 | # Print intermediate results 287 | if i % FLAGS.print_interval == 0: 288 | print("================================================") 289 | print("Learning rate: {} | Global gradient norm: {:.2f}".format(optimizer._lr, global_norm)) 290 | print("Step {}) Time = {:2f}".format(i, time.time() - t0)) 291 | loss, nll, kl = model.compute_loss(x_seq, m_mask=m_seq, return_parts=True) 292 | print("Train loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".format(loss, nll, kl)) 293 | 294 | saver.save(checkpoint_prefix) 295 | tf.contrib.summary.scalar("loss_train", loss) 296 | tf.contrib.summary.scalar("kl_train", kl) 297 | tf.contrib.summary.scalar("nll_train", nll) 298 | 299 | # Validation loss 300 | x_val_batch, m_val_batch = tf_x_val_miss.get_next() 301 | val_loss, val_nll, val_kl = model.compute_loss(x_val_batch, m_mask=m_val_batch, return_parts=True) 
302 | losses_val.append(val_loss.numpy()) 303 | print("Validation loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".format(val_loss, val_nll, val_kl)) 304 | 305 | tf.contrib.summary.scalar("loss_val", val_loss) 306 | tf.contrib.summary.scalar("kl_val", val_kl) 307 | tf.contrib.summary.scalar("nll_val", val_nll) 308 | 309 | if FLAGS.data_type in ["hmnist", "sprites"]: 310 | # Draw reconstructed images 311 | x_hat = model.decode(model.encode(x_seq).sample()).mean() 312 | tf.contrib.summary.image("input_train", tf.reshape(x_seq, [-1]+list(img_shape))) 313 | tf.contrib.summary.image("reconstruction_train", tf.reshape(x_hat, [-1]+list(img_shape))) 314 | elif FLAGS.data_type == 'physionet': 315 | # Eval MSE and AUROC on entire val set 316 | x_val_miss_batches = np.array_split(x_val_miss, FLAGS.batch_size, axis=0) 317 | x_val_full_batches = np.array_split(x_val_full, FLAGS.batch_size, axis=0) 318 | m_val_artificial_batches = np.array_split(m_val_artificial, FLAGS.batch_size, axis=0) 319 | get_val_batches = lambda: zip(x_val_miss_batches, x_val_full_batches, m_val_artificial_batches) 320 | 321 | n_missings = m_val_artificial.sum() 322 | mse_miss = np.sum([model.compute_mse(x, y=y, m_mask=m).numpy() 323 | for x, y, m in get_val_batches()]) / n_missings 324 | 325 | x_val_imputed = np.vstack([model.decode(model.encode(x_batch).mean()).mean().numpy() 326 | for x_batch in x_val_miss_batches]) 327 | x_val_imputed[m_val_miss == 0] = x_val_miss[m_val_miss == 0] # impute gt observed values 328 | 329 | x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim]) 330 | val_split = len(x_val_imputed) // 2 331 | cls_model = LogisticRegression(solver='liblinear', tol=1e-10, max_iter=10000) 332 | cls_model.fit(x_val_imputed[:val_split], y_val[:val_split]) 333 | probs = cls_model.predict_proba(x_val_imputed[val_split:])[:, 1] 334 | auroc = roc_auc_score(y_val[val_split:], probs) 335 | print("MSE miss: {:.4f} | AUROC: {:.4f}".format(mse_miss, auroc)) 336 | 337 | # Update learning rate 
(used only for physionet with decay=0.5) 338 | if i > 0 and i % (10*FLAGS.print_interval) == 0: 339 | optimizer._lr = max(0.5 * optimizer._lr, 0.1 * FLAGS.learning_rate) 340 | t0 = time.time() 341 | except KeyboardInterrupt: 342 | saver.save(checkpoint_prefix) 343 | if FLAGS.debug: 344 | import ipdb 345 | ipdb.set_trace() 346 | break 347 | 348 | 349 | ############## 350 | # Evaluation # 351 | ############## 352 | 353 | print("Evaluation...") 354 | 355 | # Split data on batches 356 | x_val_miss_batches = np.array_split(x_val_miss, FLAGS.batch_size, axis=0) 357 | x_val_full_batches = np.array_split(x_val_full, FLAGS.batch_size, axis=0) 358 | if FLAGS.data_type == 'physionet': 359 | m_val_batches = np.array_split(m_val_artificial, FLAGS.batch_size, axis=0) 360 | else: 361 | m_val_batches = np.array_split(m_val_miss, FLAGS.batch_size, axis=0) 362 | get_val_batches = lambda: zip(x_val_miss_batches, x_val_full_batches, m_val_batches) 363 | 364 | # Compute NLL and MSE on missing values 365 | n_missings = m_val_artificial.sum() if FLAGS.data_type == 'physionet' else m_val_miss.sum() 366 | nll_miss = np.sum([model.compute_nll(x, y=y, m_mask=m).numpy() 367 | for x, y, m in get_val_batches()]) / n_missings 368 | mse_miss = np.sum([model.compute_mse(x, y=y, m_mask=m, binary=FLAGS.data_type=="hmnist").numpy() 369 | for x, y, m in get_val_batches()]) / n_missings 370 | print("NLL miss: {:.4f}".format(nll_miss)) 371 | print("MSE miss: {:.4f}".format(mse_miss)) 372 | 373 | # Save imputed values 374 | z_mean = [model.encode(x_batch).mean().numpy() for x_batch in x_val_miss_batches] 375 | np.save(os.path.join(outdir, "z_mean"), np.vstack(z_mean)) 376 | x_val_imputed = np.vstack([model.decode(z_batch).mean().numpy() for z_batch in z_mean]) 377 | np.save(os.path.join(outdir, "imputed_no_gt"), x_val_imputed) 378 | 379 | # impute gt observed values 380 | x_val_imputed[m_val_miss == 0] = x_val_miss[m_val_miss == 0] 381 | np.save(os.path.join(outdir, "imputed"), x_val_imputed) 382 | 383 | 
if FLAGS.data_type == "hmnist": 384 | # AUROC evaluation using Logistic Regression 385 | x_val_imputed = np.round(x_val_imputed) 386 | x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim]) 387 | 388 | cls_model = LogisticRegression(solver='lbfgs', multi_class='multinomial', tol=1e-10, max_iter=10000) 389 | val_split = len(x_val_imputed) // 2 390 | 391 | cls_model.fit(x_val_imputed[:val_split], y_val[:val_split]) 392 | probs = cls_model.predict_proba(x_val_imputed[val_split:]) 393 | 394 | auprc = average_precision_score(np.eye(num_classes)[y_val[val_split:]], probs) 395 | auroc = roc_auc_score(np.eye(num_classes)[y_val[val_split:]], probs) 396 | print("AUROC: {:.4f}".format(auroc)) 397 | print("AUPRC: {:.4f}".format(auprc)) 398 | 399 | elif FLAGS.data_type == "sprites": 400 | auroc, auprc = 0, 0 401 | 402 | elif FLAGS.data_type == "physionet": 403 | # Uncomment to preserve some z_samples and their reconstructions 404 | # for i in range(5): 405 | # z_sample = [model.encode(x_batch).sample().numpy() for x_batch in x_val_miss_batches] 406 | # np.save(os.path.join(outdir, "z_sample_{}".format(i)), np.vstack(z_sample)) 407 | # x_val_imputed_sample = np.vstack([model.decode(z_batch).mean().numpy() for z_batch in z_sample]) 408 | # np.save(os.path.join(outdir, "imputed_sample_{}_no_gt".format(i)), x_val_imputed_sample) 409 | # x_val_imputed_sample[m_val_miss == 0] = x_val_miss[m_val_miss == 0] 410 | # np.save(os.path.join(outdir, "imputed_sample_{}".format(i)), x_val_imputed_sample) 411 | 412 | # AUROC evaluation using Logistic Regression 413 | x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim]) 414 | val_split = len(x_val_imputed) // 2 415 | cls_model = LogisticRegression(solver='liblinear', tol=1e-10, max_iter=10000) 416 | cls_model.fit(x_val_imputed[:val_split], y_val[:val_split]) 417 | probs = cls_model.predict_proba(x_val_imputed[val_split:])[:, 1] 418 | auprc = average_precision_score(y_val[val_split:], probs) 419 | auroc = 
roc_auc_score(y_val[val_split:], probs) 420 | 421 | print("AUROC: {:.4f}".format(auroc)) 422 | print("AUPRC: {:.4f}".format(auprc)) 423 | 424 | # Visualize reconstructions 425 | if FLAGS.data_type in ["hmnist", "sprites"]: 426 | img_index = 0 427 | if FLAGS.data_type == "hmnist": 428 | img_shape = (28, 28) 429 | cmap = "gray" 430 | elif FLAGS.data_type == "sprites": 431 | img_shape = (64, 64, 3) 432 | cmap = None 433 | 434 | fig, axes = plt.subplots(nrows=3, ncols=x_val_miss.shape[1], figsize=(2*x_val_miss.shape[1], 6)) 435 | 436 | x_hat = model.decode(model.encode(x_val_miss[img_index: img_index+1]).mean()).mean().numpy() 437 | seqs = [x_val_miss[img_index:img_index+1], x_hat, x_val_full[img_index:img_index+1]] 438 | 439 | for axs, seq in zip(axes, seqs): 440 | for ax, img in zip(axs, seq[0]): 441 | ax.imshow(img.reshape(img_shape), cmap=cmap) 442 | ax.axis('off') 443 | 444 | suptitle = FLAGS.model_type + f" reconstruction, NLL missing = {mse_miss}" 445 | fig.suptitle(suptitle, size=18) 446 | fig.savefig(os.path.join(outdir, FLAGS.data_type + "_reconstruction.pdf")) 447 | 448 | results_all = [FLAGS.seed, FLAGS.model_type, FLAGS.data_type, FLAGS.kernel, FLAGS.beta, FLAGS.latent_dim, 449 | FLAGS.num_epochs, FLAGS.batch_size, FLAGS.learning_rate, FLAGS.window_size, 450 | FLAGS.kernel_scales, FLAGS.sigma, FLAGS.length_scale, 451 | len(FLAGS.encoder_sizes), FLAGS.encoder_sizes[0] if len(FLAGS.encoder_sizes) > 0 else 0, 452 | len(FLAGS.decoder_sizes), FLAGS.decoder_sizes[0] if len(FLAGS.decoder_sizes) > 0 else 0, 453 | FLAGS.cnn_kernel_size, FLAGS.cnn_sizes, 454 | nll_miss, mse_miss, losses_train[-1], losses_val[-1], auprc, auroc, FLAGS.testing, FLAGS.data_dir] 455 | 456 | with open(os.path.join(outdir, "results.tsv"), "w") as outfile: 457 | outfile.write("seed\tmodel\tdata\tkernel\tbeta\tz_size\tnum_epochs" 458 | "\tbatch_size\tlearning_rate\twindow_size\tkernel_scales\t" 459 | "sigma\tlength_scale\tencoder_depth\tencoder_width\t" 460 | 
"decoder_depth\tdecoder_width\tcnn_kernel_size\t" 461 | "cnn_sizes\tNLL\tMSE\tlast_train_loss\tlast_val_loss\tAUPRC\tAUROC\ttesting\tdata_dir\n") 462 | outfile.write("\t".join(map(str, results_all))) 463 | 464 | with open(os.path.join(outdir, "training_curve.tsv"), "w") as outfile: 465 | outfile.write("\t".join(map(str, losses_train))) 466 | outfile.write("\n") 467 | outfile.write("\t".join(map(str, losses_val))) 468 | 469 | print("Training finished.") 470 | 471 | 472 | if __name__ == '__main__': 473 | app.run(main) 474 | --------------------------------------------------------------------------------