├── .gitignore
├── models
│   ├── base.py
│   ├── evaluation.py
│   ├── training.py
│   └── vae.py
├── requirements.txt
├── config.py
├── utils.py
├── README.md
├── sampling.py
├── generate_data.py
├── LICENSE
├── train_vae.py
└── models_to_train.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | generated_data/*
2 | */__pycache__/**
3 | *.pyc
4 | trained_vae_models/*
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.61.2
2 | matplotlib==3.3.4
3 | torchvision==0.10.0
4 | h5py==2.10.0
5 | imageio==2.8.0
6 | scipy
7 | scikit_learn_extra==0.2.0
8 | torch==1.9.0
9 | numpy
10 | scikit_learn
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | 
4 | class BaseVAE(ABC):
5 |     def __init__(self, args):
6 |         self.device = args.device
7 | 
8 |     @abstractmethod
9 |     def encode(self):
10 |         pass
11 | 
12 |     @abstractmethod
13 |     def decode(self):
14 |         pass
15 | 
16 |     @abstractmethod
17 |     def loss_function(self):
18 |         pass
19 | 
--------------------------------------------------------------------------------
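**Note**: `BaseVAE` above only fixes the interface that the concrete models in `models/vae.py` implement. For orientation, a minimal sketch of a conforming subclass; the toy layers and the `ToyVAE` name are illustrative only and not part of the repo:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from models.base import BaseVAE


class ToyVAE(BaseVAE, nn.Module):
    """Illustrative only; the repo's real models live in models/vae.py."""

    def __init__(self, args):
        BaseVAE.__init__(self, args)
        nn.Module.__init__(self)
        self.fc_mu = nn.Linear(args.input_dim, args.latent_dim)
        self.fc_logvar = nn.Linear(args.input_dim, args.latent_dim)
        self.fc_dec = nn.Linear(args.latent_dim, args.input_dim)

    def encode(self, x):
        return self.fc_mu(x), self.fc_logvar(x)

    def decode(self, z):
        return torch.sigmoid(self.fc_dec(z))

    def loss_function(self, recon_x, x, mu, log_var):
        # standard ELBO: reconstruction term + KL divergence to N(0, I)
        rec = nn.functional.binary_cross_entropy(recon_x, x, reduction="sum")
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return rec + kld


args = SimpleNamespace(device="cpu", input_dim=1024, latent_dim=16)
model = ToyVAE(args)
mu, log_var = model.encode(torch.rand(4, 1024))
```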
/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | @dataclass
4 | class VAE_config:
5 |     input_dim: int = None
6 |     model_name: str = "VAE"
7 |     architecture: str = 'convnet'
8 |     n_channels: int = None
9 |     latent_dim: int = None
10 |     beta: float = 1
11 |     device: str = 'cuda'
12 |     cuda: bool = True
13 |     dynamic_binarization: bool = False
14 |     dataset: str = None
15 | 
16 | @dataclass
17 | class RHVAE_config:
18 |     input_dim: int = None
19 |     model_name: str = "RHVAE"
20 |     architecture: str = 'convnet'
21 |     n_channels: int = None
22 |     latent_dim: int = None
23 |     beta: float = 1
24 |     n_lf: int = 1
25 |     eps_lf: float = 0.001
26 |     temperature: float = 0.8
27 |     regularization: float = 0.001
28 |     device: str = 'cuda'
29 |     cuda: bool = True
30 |     beta_zero: float = 0.3
31 |     metric_fc: int = 400
32 |     dynamic_binarization: bool = False
33 |     dataset: str = None
34 | 
35 | 
36 | @dataclass
37 | class VAMP_config:
38 |     model_name: str = "VAMP"
39 |     architecture: str = 'convnet'
40 |     input_size: int = None
41 |     z1_size: int = None
42 |     n_channels: int = None
43 |     prior: str = 'vampprior'
44 |     input_type: str = 'continuous'
45 |     use_training_data_init: bool = False
46 |     pseudoinputs_mean: float = 0.05
47 |     pseudoinputs_std: float = 0.01
48 |     number_components: int = 10
49 |     dataset: str = None
50 |     dynamic_binarization: bool = False
51 |     warmup: int = 0
52 |     beta: int = 1
53 |     cuda: bool = True
54 | 
--------------------------------------------------------------------------------
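These dataclasses are what `train_vae.py` instantiates, checkpoints under the `'args'` key, and threads through training, evaluation and generation. A minimal sketch mirroring the MNIST settings hard-coded in `train_vae.py`:

```python
from config import VAE_config

# Values below mirror the MNIST branch of train_vae.py
# (32x32 images, 1 channel, 16-dimensional latent space, beta = 0.01)
train_args = VAE_config(
    input_dim=32 * 32,
    latent_dim=16,
    architecture='convnet',
    n_channels=1,
    dataset='mnist',
    beta=0.01,
)
print(train_args.model_name)  # "VAE"
```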
/models/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def eval_vae(epoch, args, model, val_loader):
5 |     val_loss = 0
6 |     val_loss_rec = 0
7 |     val_loss_kld = 0
8 | 
9 |     # Set model on eval mode
10 |     model.eval()
11 | 
12 |     for batch_idx, (data, target) in enumerate(val_loader):
13 |         if args.cuda:
14 |             data, target = data.cuda(), target.cuda()
15 | 
16 |         if args.dynamic_binarization:
17 |             x = torch.bernoulli(data)
18 | 
19 |         else:
20 |             x = data
21 | 
22 |         if args.model_name == "RHVAE":
23 |             # forward pass
24 |             (
25 |                 recon_batch,
26 |                 z,
27 |                 z0,
28 |                 rho,
29 |                 eps0,
30 |                 gamma,
31 |                 mu,
32 |                 log_var,
33 |                 G_inv,
34 |                 G_log_det,
35 |             ) = model(x)
36 |             # loss computation
37 |             loss = model.loss_function(
38 |                 recon_batch,
39 |                 x,
40 |                 z0,
41 |                 z,
42 |                 rho,
43 |                 eps0,
44 |                 gamma,
45 |                 mu,
46 |                 log_var,
47 |                 G_inv,
48 |                 G_log_det,
49 |             )
50 | 
51 |             loss_rec = torch.zeros(1)
52 |             loss_kld = torch.zeros(1)
53 | 
54 |         else:
55 |             with torch.no_grad():
56 | 
57 |                 # forward pass
58 |                 recon_batch, z, _, mu, log_var = model(x)
59 |                 # loss computation
60 |                 loss, loss_rec, loss_kld = model.loss_function(recon_batch, x, mu, log_var, z)
61 | 
62 | 
63 | 
64 |         val_loss += loss.item() / len(val_loader.dataset)
65 |         val_loss_rec += loss_rec.item() / len(val_loader.dataset)
66 |         val_loss_kld += loss_kld.item() / len(val_loader.dataset)
67 | 
68 |     # calculate final loss
69 | 
70 |     return val_loss, val_loss_rec, val_loss_kld
--------------------------------------------------------------------------------
/models/training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | 
4 | def train_vae(epoch, args, model, train_loader, optimizer):
5 |     train_loss = 0
6 |     train_loss_rec = 0
7 |     train_loss_kld = 0
8 | 
9 |     # Set model on training mode
10 |     model.train()
11 | 
12 |     for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
13 |         if args.cuda:
14 |             data, target = data.cuda(), target.cuda()
15 | 
16 |         if args.dynamic_binarization:
17 |             x = torch.bernoulli(data)
18 | 
19 |         else:
20 |             x = data
21 | 
22 |         # reset gradients
23 |         optimizer.zero_grad()
24 | 
25 |         if args.model_name == "RHVAE":
26 |             # forward pass
27 |             (
28 |                 recon_batch,
29 |                 z,
30 |                 z0,
31 |                 rho,
32 |                 eps0,
33 |                 gamma,
34 |                 mu,
35 |                 log_var,
36 |                 G_inv,
37 |                 G_log_det,
38 |             ) = model(x)
39 |             # loss computation
40 |             loss = model.loss_function(
41 |                 recon_batch,
42 |                 x,
43 |                 z0,
44 |                 z,
45 |                 rho,
46 |                 eps0,
47 |                 gamma,
48 |                 mu,
49 |                 log_var,
50 |                 G_inv,
51 |                 G_log_det,
52 |             )
53 | 
54 |             loss_rec = torch.zeros(1)
55 |             loss_kld = torch.zeros(1)
56 | 
57 |         else:
58 | 
59 |             # forward pass
60 |             recon_batch, z, _, mu, log_var = model(x)
61 |             # loss computation
62 |             loss, loss_rec, loss_kld = model.loss_function(recon_batch, x, mu, log_var, z)
63 | 
64 |         # backward pass
65 |         loss.backward()
66 |         # optimization
67 |         optimizer.step()
68 | 
69 |         train_loss += loss.item() / len(train_loader.dataset)
70 |         train_loss_rec += loss_rec.item() / len(train_loader.dataset)
71 |         train_loss_kld += loss_kld.item() / len(train_loader.dataset)
72 | 
73 |     if args.model_name == "RHVAE":
74 |         model.update_metric()
75 |     # calculate final loss
76 | 
77 |     return model, train_loss, train_loss_rec, train_loss_kld
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | 
4 | 
5 | class Digits(data.Dataset):
6 |     def __init__(self, digits, labels, mask=None, binarize=False):
7 | 
8 |         self.labels = labels
9 | 
10 |         if binarize:
11 |             self.data = (torch.rand_like(digits) < digits).type(torch.float)
12 | 
13 |         else:
14 |             self.data = digits.type(torch.float)
15 | 
16 |     def __len__(self):
17 |         return len(self.data)
18 | 
19 |     def __getitem__(self, index):
20 |         "Generates one sample of data"
21 |         # Select sample
22 |         X = self.data[index]
23 |         y = self.labels[index]
24 | 
25 | 
26 |         return X, y
27 | 
28 | def create_metric(model, device='cpu'):
29 |     """
30 |     Metric creation for RHVAE model
31 |     """
32 |     def G(z):
33 |         return torch.inverse(
34 |             (
35 |                 model.M_tens.unsqueeze(0)
36 |                 * torch.exp(
37 |                     -torch.norm(
38 |                         model.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1
39 |                     )
40 |                     ** 2
41 |                     / (model.T ** 2)
42 |                 )
43 |                 .unsqueeze(-1)
44 |                 .unsqueeze(-1)
45 |             ).sum(dim=1)
46 |             + model.lbd * torch.eye(model.latent_dim).to(device)
47 |         )
48 | 
49 |     return G
50 | 
51 | def create_metric_inv(model, device='cpu'):
52 |     """
53 |     Metric creation for RHVAE model
54 |     """
55 |     def G_inv(z):
56 |         return (
57 |             model.M_tens.unsqueeze(0)
58 |             * torch.exp(
59 |                 -torch.norm(model.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1)
60 |                 ** 2
61 |                 / (model.T ** 2)
62 |             )
63 |             .unsqueeze(-1)
64 |             .unsqueeze(-1)
65 |         ).sum(dim=1) + model.lbd * torch.eye(model.latent_dim).to(device)
66 | 
67 |     return G_inv
68 | 
69 | def create_dH_dz(model):
70 |     """
71 |     Computation of derivative of Hamiltonian for RHVAE model
72 |     """
73 |     def dH_dz(z, q):
74 | 
75 |         a = (
76 |             torch.transpose(q.unsqueeze(-1).unsqueeze(1), 2, 3)
77 |             @ model.M_tens.unsqueeze(0)
78 |             @ q.unsqueeze(-1).unsqueeze(1)
79 |         )
80 | 
81 |         b = model.centroids_tens.unsqueeze(0) - z.unsqueeze(1)
82 | 
83 |         return (
84 |             -1
85 |             / (model.T ** 2)
86 |             * b.unsqueeze(-1)
87 |             @ a
88 |             * (
89 |                 torch.exp(
90 |                     -torch.norm(
91 |                         model.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1
92 |                     )
93 |                     ** 2
94 |                     / (model.T ** 2)
95 |                 )
96 |             )
97 |             .unsqueeze(-1)
98 |             .unsqueeze(-1)
99 |         ).sum(dim=1)
100 |     return dH_dz
101 | 
102 | 
--------------------------------------------------------------------------------
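A quick sanity check of the closed-form metric built above: `G` and `G_inv` should be inverses of each other at any `z`. A minimal sketch, assuming only that `utils.py` is importable; the `SimpleNamespace` below merely carries the attributes the factories read (a trained RHVAE exposes the same `M_tens`, `centroids_tens`, `T`, `lbd`, `latent_dim`):

```python
import torch
from types import SimpleNamespace

from utils import create_metric, create_metric_inv

# Mock attribute carrier, not an actual RHVAE
model = SimpleNamespace(
    latent_dim=2,
    T=0.8,                                # temperature
    lbd=1e-3,                             # regularization factor
    centroids_tens=torch.randn(5, 2),     # 5 centroids in latent space
    M_tens=torch.eye(2).repeat(5, 1, 1),  # one PSD matrix per centroid
)

G = create_metric(model)
G_inv = create_metric_inv(model)

z = torch.zeros(1, 2)
print(G(z) @ G_inv(z))  # should be close to the 2x2 identity
```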
/README.md:
--------------------------------------------------------------------------------
1 | # This is the official implementation of ["A Geometric Perspective on Variational Autoencoders"](https://arxiv.org/abs/2209.07370) (NeurIPS 2022)
2 | 
3 | This code uses **Python 3.6**.
4 | 
5 | **Note**: The method should soon be added to [`pythae`](https://github.com/clementchadebec/benchmark_VAE).
6 | 
7 | To install the requirements, run
8 | 
9 | ```bash
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | ## Data folders
14 | 
15 | The data must be located in `data_folders`:
16 | 
17 | ### MNIST
18 | The provided code requires a file `mnist_32x32.npz` to be located in `data_folders/mnist/`.
19 | The data must be in the range [0, 255] and loadable as follows:
20 | ```python
21 | import numpy as np
22 | mnist_digits = np.load(args.path_to_train)
23 | train_data = mnist_digits['x_train'] # data of shape 60000x32x32x1 in [0-255]
24 | train_targets = mnist_digits['y_train'] # corresponding labels
25 | ```
26 | 
27 | The 10k test images in `.png` format used for metric computation must be located in
28 | `data_folders/mnist/test_folder`.
29 | 
30 | ### CIFAR10
31 | The provided code requires a file `cifar_10.npz` to be located in `data_folders/cifar/`.
32 | The data must be in the range [0, 255] and loadable as follows:
33 | ```python
34 | import numpy as np
35 | cifar_data = np.load(args.path_to_train)
36 | train_data = cifar_data['x_train'] # data of shape 50000x32x32x3 in [0-255]
37 | train_targets = cifar_data['y_train'] # corresponding labels
38 | ```
39 | 
40 | The 10k test images in `.png` format used for metric computation must be located in
41 | `data_folders/cifar/test_folder`.
42 | 
43 | ### Celeba
44 | The provided code requires a file `train_data.pt` to be located in `data_folders/celeba/`. The data
45 | must be a single tensor of shape n_samplesx3x64x64 in the range [0, 1] and loadable as follows:
46 | 
47 | ```python
48 | import os
49 | import torch
50 | train_data = torch.load(os.path.join(args.path_to_train, 'train_data.pt')) # data of shape 162770x64x64x3 in the range of [0-1]
51 | val_data = torch.load(os.path.join(args.path_to_train, 'val_data.pt')) # data of shape 19867x64x64x3 in the range of [0-1]
52 | ```
53 | 
54 | The test images in `.png` format used for metric computation must be located in
55 | `data_folders/celeba/test/test`.
56 | 
57 | ### SVHN
58 | 
59 | The provided code requires a file `train_32x32.mat` to be located in `data_folders/svhn/`.
60 | The data must be in the range [0, 255] and loadable as follows:
61 | 
62 | ```python
63 | from scipy.io import loadmat
64 | svnh_digits = loadmat(args.path_to_train)['X'] # data of shape 32x32x3x73257 in the range of [0-255]
65 | svnh_targets = loadmat(args.path_to_train)['y'] # corresponding labels
66 | ```
67 | 
68 | The test images in `.png` format used for metric computation must be located in
69 | `data_folders/svhn/test_folder`.
70 | 
71 | 
72 | ### OASIS
73 | The provided code requires a file `OASIS.npz` to be located in `data_folders/oasis/`. The data must be in the range [0, 255] and you must ensure that each image has a maximum voxel value of 255 and a minimum of 0.
74 | The data must be loadable as follows:
75 | 
76 | ```python
77 | import numpy as np
78 | import torch
79 | oasis_data = np.load(args.path_to_train)
80 | train_data = oasis_data['x_train'] # data of shape 416x208x176x1 in the range of [0-255]
81 | train_targets = torch.tensor(oasis_data['y_train']) # corresponding targets
82 | ```
83 | 
84 | ## Performing experiments
85 | 
86 | The command lines to train a model, generate new data and compute the metrics are available in
87 | `models_to_train.sh`.
88 | 
89 | 
90 | ## Reference
91 | 
92 | ```bibtex
93 | @article{chadebec2022geometric,
94 |     title={A geometric perspective on variational autoencoders},
95 |     author={Chadebec, Cl{\'e}ment and Allassonni{\`e}re, St{\'e}phanie},
96 |     journal={Advances in Neural Information Processing Systems},
97 |     volume={35},
98 |     pages={19618--19630},
99 |     year={2022}
100 | }
101 | ```
--------------------------------------------------------------------------------
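The checkpoints written by `train_vae.py` store the config under `'args'` and the weights under `'state_dict'`, which is how `generate_data.py` restores them. A minimal sketch for reusing a trained model outside the provided scripts (the checkpoint path is illustrative):

```python
import torch
from models.vae import VAE

# best_model.pt is a dict: {'state_dict': ..., 'args': <config dataclass>}
checkpoint = torch.load('trained_vae_models/vae/mnist/best_model.pt')
model = VAE(checkpoint['args'])
model.load_state_dict(checkpoint['state_dict'])
model.eval()

with torch.no_grad():
    z = torch.randn(16, checkpoint['args'].latent_dim)
    samples = model.decode(z)  # decoded images in [0, 1]
```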
/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
4 | 
5 | def build_metrics(model, mu, log_var, idx=None, T=0.3, lbd=0.0001):
6 | 
7 |     if idx is not None:
8 |         mu = mu[idx]
9 |         log_var = log_var[idx]
10 | 
11 |     with torch.no_grad():
12 |         model.M_i = torch.diag_embed((-log_var).exp()).detach()
13 |         model.M_i_flat = (-log_var).exp().detach()
14 |         model.M_i_inverse_flat = (log_var).exp().detach()
15 |         model.centroids = mu.detach()
16 |         model.T = T
17 |         model.lbd = lbd
18 | 
19 | 
20 |     def G_sampl(z):
21 |         omega = (
22 |             -(
23 |                 torch.transpose(
24 |                     (model.centroids.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(-1), 2, 3) @ torch.diag_embed(model.M_i_flat).unsqueeze(0) @ (model.centroids.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(-1)
25 |             ) / model.T**2
26 |         ).exp()
27 | 
28 |         return (torch.diag_embed(model.M_i_flat).unsqueeze(0) * omega
29 |                 ).sum(dim=1) + model.lbd * torch.eye(model.latent_dim).to(device)
30 | 
31 |     model.G_sampl = G_sampl
32 | 
33 |     return model
34 | 
35 | 
36 | def d_log_sqrt_det_G(z, model):
37 |     with torch.no_grad():
38 |         omega = (
39 |             -(
40 |                 torch.transpose(
41 |                     (model.centroids.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(-1), 2, 3) @ model.M_i.unsqueeze(0) @ (model.centroids.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(-1)
42 |             ) / model.T**2
43 |         ).exp()
44 |         d_omega_dz = ((-2 * model.M_i_flat * (z.unsqueeze(1) - model.centroids.unsqueeze(0)) / (model.T ** 2)).unsqueeze(-2) * omega).squeeze(-2)
45 |         num = (d_omega_dz.unsqueeze(-2) * (model.M_i_flat.unsqueeze(0).unsqueeze(-1))).sum(1)
46 |         denom = (model.M_i_flat.unsqueeze(0) * omega.squeeze(-1) + model.lbd).sum(1)
47 | 
48 |     return torch.transpose(num / denom.unsqueeze(-1), 1, 2).sum(-1)
49 | 
50 | def log_pi(model, z):
51 |     return 0.5 * (torch.clamp(model.G_sampl(z).det(), 0, 1e10)).log()
52 | 
53 | 
54 | def hmc_sampling(model, mu, n_samples=1, mcmc_steps_nbr=1000, n_lf=10, eps_lf=0.01):
55 | 
56 |     acc_nbr = torch.zeros(n_samples, 1).to(device)
57 |     path = torch.zeros(n_samples, mcmc_steps_nbr, model.latent_dim)
58 |     with torch.no_grad():
59 | 
60 |         idx = torch.randint(0, len(mu), (n_samples,))
61 |         z0 = mu[idx]
62 |         z = z0
63 |         for i in range(mcmc_steps_nbr):
64 |             gamma = 0.5 * torch.randn_like(z, device=device)
65 |             rho = gamma
66 | 
67 |             H0 = -log_pi(model, z) + 0.5 * torch.norm(rho, dim=1) ** 2
68 |             for k in range(n_lf):
69 | 
70 |                 g = -d_log_sqrt_det_G(z, model).reshape(
71 |                     n_samples, model.latent_dim
72 |                 )
73 |                 # step 1: half step on the momentum
74 |                 rho_ = rho - (eps_lf / 2) * g
75 | 
76 |                 # step 2: full step on the position
77 |                 z = z + eps_lf * rho_
78 |                 g = -d_log_sqrt_det_G(z, model).reshape(
79 |                     n_samples, model.latent_dim
80 |                 )
81 | 
82 |                 # step 3: half step on the momentum
83 |                 rho__ = rho_ - (eps_lf / 2) * g
84 | 
85 |                 # tempering (disabled here)
86 |                 beta_sqrt = 1
87 | 
88 |                 rho = rho__
89 | 
90 |             H = -log_pi(model, z) + 0.5 * torch.norm(rho, dim=1) ** 2
91 |             alpha = torch.exp(-H) / (torch.exp(-H0))
92 | 
93 |             acc = torch.rand(n_samples).to(device)
94 |             moves = (acc < alpha).type(torch.int).reshape(n_samples, 1)
95 | 
96 |             acc_nbr += moves
97 | 
98 |             z = z * moves + (1 - moves) * z0
99 |             path[:, i] = z
100 |             z0 = z
101 | 
102 |     return z.detach(), path.detach().cpu()
--------------------------------------------------------------------------------
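The sampler above only needs the attributes that `build_metrics` attaches, so it can be exercised end-to-end without a trained VAE. A minimal sketch, assuming `sampling.py` is importable; the `SimpleNamespace` and the random `mu`/`log_var` stand in for encoder outputs and are not part of the repo:

```python
import torch
from types import SimpleNamespace

from sampling import build_metrics, hmc_sampling

device = 'cuda' if torch.cuda.is_available() else 'cpu'
latent_dim = 2

# Stand-ins for the per-sample posterior means / log-variances that
# generate_data.py collects from the encoder
mu = torch.randn(100, latent_dim).to(device)
log_var = torch.zeros(100, latent_dim).to(device)

model = SimpleNamespace(latent_dim=latent_dim)  # build_metrics attaches the rest
model = build_metrics(model, mu, log_var, idx=torch.arange(10), T=0.3, lbd=1e-4)

z, path = hmc_sampling(model, mu, n_samples=5, mcmc_steps_nbr=50)
print(z.shape, path.shape)  # torch.Size([5, 2]) torch.Size([5, 50, 2])
```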
/generate_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn import mixture
3 | from imageio import imwrite
4 | import numpy as np
5 | from models.vae import VAE
6 | from utils import Digits
7 | import os
8 | from scipy.io import loadmat
9 | from config import *
10 | from sampling import *
11 | 
12 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
13 | 
14 | 
15 | def main(args):
16 | 
17 | 
18 | 
19 |     checkpoint = torch.load(args.model_path)
20 | 
21 |     print(checkpoint['args'])
22 |     if 'rhvae' in args.model_path:
23 |         from models.vae import RHVAE
24 |         best_model = RHVAE(checkpoint['args'])
25 | 
26 |     else:
27 |         best_model = VAE(checkpoint['args'])
28 | 
29 |     best_model.to(device)
30 |     best_model.load_state_dict(checkpoint['state_dict'])
31 |     print(best_model)
32 | 
33 |     print("Num params", sum(p.numel() for p in best_model.parameters() if p.requires_grad) - sum(p.numel() for p in best_model.fc21.parameters() if p.requires_grad))
34 | 
35 |     path_to_train = args.data_path
36 | 
37 |     dataset_name = args.data_path.split('/')[-2]
38 | 
39 | 
40 |     if dataset_name == 'mnist':
41 |         eps_lf = 0.01
42 |         lbd = 0.01
43 | 
44 |         mnist_digits = np.load(path_to_train)
45 | 
46 |         train_data = torch.tensor(mnist_digits['x_train'][:-10000]).type(torch.float).permute(0, 3, 1, 2) / 255.
47 |         train_targets = torch.tensor(mnist_digits['y_train'][:-10000])
48 | 
49 | 
50 |     elif dataset_name == 'cifar':
51 | 
52 |         cifar_digits = np.load(path_to_train)
53 | 
54 |         train_data = torch.tensor(cifar_digits['x_train'][:-10000]).type(torch.float).permute(0, 3, 1, 2) / 255.
55 |         train_targets = torch.tensor(cifar_digits['y_train'][:-10000])
56 | 
57 |         eps_lf = 0.01
58 |         lbd = 0.1
59 | 
60 | 
61 |     elif dataset_name == 'celeba':
62 | 
63 |         train_data = torch.load(os.path.join(path_to_train, 'train_data.pt'))
64 |         train_targets = torch.ones(len(train_data))
65 | 
66 |         eps_lf = 0.01
67 |         lbd = 1
68 | 
69 |     elif dataset_name == 'oasis':
70 | 
71 |         oasis_data = np.load(path_to_train)
72 | 
73 |         train_targets = torch.tensor(oasis_data['y_train'][:])
74 |         train_data = torch.tensor(oasis_data['x_train']).type(torch.float).permute(0, 3, 1, 2) / 255.
75 | 
76 |         eps_lf = 0.01
77 |         lbd = 1
78 | 
79 |     elif dataset_name == 'svhn':
80 |         eps_lf = 0.01
81 |         lbd = 0.01
82 | 
83 |         svnh_digits = loadmat(path_to_train)['X']
84 |         svnh_targets = loadmat(path_to_train)['y']
85 | 
86 |         svnh_digits = np.transpose(svnh_digits, (3, 0, 1, 2))
87 | 
88 |         train_data = torch.tensor(svnh_digits[:-10000]).type(torch.float).permute(0, 3, 1, 2) / 255.
89 |         train_targets = torch.tensor(svnh_targets[:-10000])
90 | 
91 |     train = Digits(train_data, train_targets)
92 |     train_loader = torch.utils.data.DataLoader(
93 |         dataset=train, batch_size=100, shuffle=False
94 |     )
95 | 
96 |     if args.generation == 'hmc':
97 | 
98 |         path_to_save = f"generated_data/vae/{dataset_name}/manifold_sampling/"
99 |         if not os.path.exists(path_to_save):
100 |             os.makedirs(path_to_save)
101 |             print(f"Created folder {path_to_save}. Data will be saved here")
102 | 
103 |         mu = []
104 |         log_var = []
105 | 
106 |         with torch.no_grad():
107 |             for _, (data, _) in enumerate(train_loader):
108 | 
109 |                 mu_data, log_var_data = best_model.encode(data.to(device))
110 | 
111 |                 mu.append(mu_data)
112 |                 log_var.append(log_var_data)
113 | 
114 |         mu = torch.cat(mu)
115 |         log_var = torch.cat(log_var)
116 | 
117 |         if dataset_name == 'cifar' or dataset_name == 'mnist' or dataset_name == 'svhn':
118 |             print('Running Kmedoids')
119 |             from sklearn_extra.cluster import KMedoids
120 |             kmedoids = KMedoids(n_clusters=100).fit(mu.detach().cpu())
121 |             medoids = torch.tensor(kmedoids.cluster_centers_).to(device)
122 |             centroids_idx = kmedoids.medoid_indices_
123 | 
124 |         elif dataset_name == 'oasis':
125 |             centroids_idx = torch.arange(0, 50)
126 |             medoids = mu[centroids_idx]
127 | 
128 |         else:
129 |             centroids_idx = torch.arange(0, 100).to(device)
130 |             medoids = mu[centroids_idx]
131 | 
132 |         print("Finding temperature")
133 | 
134 |         T = 0
135 |         T_is = []
136 |         for i in range(len(medoids)):
137 |             mask = torch.tensor([k for k in range(len(medoids)) if k != i])
138 |             dist = torch.norm(medoids[i].unsqueeze(0) - medoids[mask], dim=-1)
139 |             T_i = torch.min(dist, dim=0)[0]
140 |             T_is.append(T_i.item())
141 | 
142 |         T = np.max(T_is)
143 |         print('Best temperature found: ', T)
144 | 
145 |         print('Building metric')
146 |         best_model = build_metrics(best_model, mu, log_var, centroids_idx, T=T, lbd=lbd)
147 | 
148 |         if args.n_samples % args.batch_size > 0:
149 |             print('Cropping batch for now....')
150 | 
151 |         print('Launching generation HMC')
152 |         for j in range(0, int(args.n_samples / args.batch_size)):
153 |             z, p = hmc_sampling(best_model, mu, n_samples=args.batch_size, eps_lf=eps_lf, mcmc_steps_nbr=100)
154 |             recon_x = best_model.decode(z)
155 |             for i in range(args.batch_size):
156 |                 img = (255. * torch.movedim(recon_x[i], 0, 2).cpu().detach().numpy())
157 | 
158 |                 if img.shape[-1] == 1:
159 |                     img = np.repeat(img, repeats=3, axis=-1)
160 |                 img = img.astype('uint8')
161 |                 imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img)
162 | 
163 | 
164 |     elif args.generation == 'gmm' or args.generation == 'GMM':
165 |         print('Launching generation GMM')
166 | 
167 |         mu = []
168 | 
169 |         with torch.no_grad():
170 |             for _, (data, _) in enumerate(train_loader):
171 | 
172 |                 mu_data, _ = best_model.encode(data.to(device))
173 | 
174 |                 mu.append(mu_data)
175 | 
176 |         mu = torch.cat(mu)
177 |         print(mu.shape)
178 | 
179 |         gmm = mixture.GaussianMixture(n_components=args.n_components, covariance_type='full', max_iter=2000,
180 |                                       verbose=2, tol=1e-3)
181 |         gmm.fit(mu.cpu().detach())
182 | 
183 |         for j in range(0, int(args.n_samples / args.batch_size)):
184 | 
185 |             idx = np.array(range(args.batch_size))
186 |             np.random.shuffle(idx)
187 | 
188 |             z = torch.tensor(gmm.sample(args.batch_size)[0][idx, :]).to(device).type(torch.float)
189 | 
190 |             recon_x = best_model.decode(z)
191 |             for i in range(args.batch_size):
192 |                 img = (255.
* torch.movedim(recon_x[i], 0, 2).cpu().detach().numpy()) 193 | 194 | if img.shape[-1]==1: 195 | img = np.repeat(img, repeats=3, axis=-1) 196 | img = img.astype('uint8') 197 | 198 | if best_model.model_name == 'AE': 199 | 200 | path_to_save = f"generated_data/ae/{dataset_name}/gmm/" 201 | if not os.path.exists(path_to_save): 202 | os.makedirs(path_to_save) 203 | print(f"Created folder {path_to_save}. Data will be saved here") 204 | imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img) 205 | 206 | elif best_model.model_name == 'RHVAE': 207 | path_to_save = f"generated_data/rhvae/{dataset_name}/gmm/" 208 | if not os.path.exists(path_to_save): 209 | os.makedirs(path_to_save) 210 | print(f"Created folder {path_to_save}. Data will be saved here") 211 | imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img) 212 | else: 213 | path_to_save = f"generated_data/vae/{dataset_name}/gmm/" 214 | if not os.path.exists(path_to_save): 215 | os.makedirs(path_to_save) 216 | print(f"Created folder {path_to_save}. Data will be saved here") 217 | imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img) 218 | 219 | 220 | else: 221 | print('Launching generation Gaussian') 222 | for j in range(0, int(args.n_samples / args.batch_size)): 223 | 224 | z = torch.randn(args.batch_size, best_model.latent_dim).to(device) 225 | recon_x = best_model.decode(z) 226 | 227 | for i in range(args.batch_size): 228 | 229 | img = (255. * torch.movedim(recon_x[i], 0, 2).cpu().detach().numpy()) 230 | 231 | if img.shape[-1]==1: 232 | img = np.repeat(img, repeats=3, axis=-1) 233 | img = img.astype('uint8') 234 | 235 | if best_model.model_name == 'AE': 236 | path_to_save = f"generated_data/ae/{dataset_name}/gaussian_prior/" 237 | if not os.path.exists(path_to_save): 238 | os.makedirs(path_to_save) 239 | print(f"Created folder {path_to_save}. Data will be saved here") 240 | imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img) 241 | 242 | elif best_model.model_name == 'RHVAE': 243 | path_to_save = f"generated_data/rhvae/{dataset_name}/gaussian_prior/" 244 | if not os.path.exists(path_to_save): 245 | os.makedirs(path_to_save) 246 | print(f"Created folder {path_to_save}. Data will be saved here") 247 | imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img) 248 | else: 249 | path_to_save = f"generated_data/vae/{dataset_name}/gaussian_prior/" 250 | if not os.path.exists(path_to_save): 251 | os.makedirs(path_to_save) 252 | print(f"Created folder {path_to_save}. 
Data will be saved here") 253 | imwrite(os.path.join(path_to_save, '%08d.png' % int(args.batch_size*j + i)), img) 254 | 255 | if __name__ == "__main__": 256 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 257 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 258 | 259 | 260 | parser.add_argument("--model_path", type=str, 261 | help='Path to the model') 262 | parser.add_argument("--data_path", type=str, 263 | help='Path to the training data .npz files') 264 | parser.add_argument("--generation", type=str, 265 | help='Generation type', default='hmc') 266 | parser.add_argument("--n_samples", type=int, 267 | help='Number of samples', default=10000) 268 | parser.add_argument("--batch_size", type=int, 269 | help='Batch size', default=500) 270 | parser.add_argument("--n_components", type=int, 271 | help='Number of comp for gmm', default=10) 272 | 273 | 274 | args = parser.parse_args() 275 | 276 | 277 | np.random.seed(8) 278 | torch.manual_seed(8) 279 | torch.cuda.manual_seed(8) 280 | 281 | main(args) 282 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/train_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import Digits
3 | from models.training import train_vae
4 | from models.evaluation import eval_vae
5 | from copy import deepcopy
6 | import torch.optim as optim
7 | import numpy as np
8 | import os
9 | from models.vae import VAE, RHVAE
10 | from imageio import imread
11 | from scipy.io import loadmat
12 | from config import *
13 | 
14 | 
15 | def main(args):
16 | 
17 |     if args.path_to_train.split('/')[-2] == 'mnist':
18 |         im_size_x, im_size_y = 32, 32
19 |         im_channels = 1
20 |         latent_dim = 16
21 |         beta = 0.01
22 |         patience = 5
23 |         n_epochs = 100
24 |         architecture = 'convnet'
25 |         lr = 1e-3
26 | 
27 |         mnist_digits = np.load(args.path_to_train)
28 | 
29 |         train_data = torch.tensor(mnist_digits['x_train'][:-10000]).type(torch.float) / 255.
30 |         train_targets = torch.tensor(mnist_digits['y_train'][:-10000])
31 |         val_data = torch.tensor(mnist_digits['x_train'][-10000:]).type(torch.float) / 255.
32 |         val_targets = torch.tensor(mnist_digits['y_train'][-10000:])
33 | 
34 | 
35 |     elif args.path_to_train.split('/')[-2] == 'cifar':
36 |         im_size_x, im_size_y = 32, 32
37 |         im_channels = 3
38 |         latent_dim = 32
39 |         beta = 0.001
40 |         lr = 5e-4
41 |         patience = 5
42 |         n_epochs = 200
43 |         architecture = 'convnet'
44 |         n_components = 100
45 | 
46 |         cifar_digits = np.load(args.path_to_train)
47 | 
48 |         train_data = torch.tensor(cifar_digits['x_train'][:-10000]).type(torch.float).permute(0, 3, 1, 2) / 255.
49 |         train_targets = torch.tensor(cifar_digits['y_train'][:-10000])
50 |         val_data = torch.tensor(cifar_digits['x_train'][-10000:]).type(torch.float).permute(0, 3, 1, 2) / 255.
51 |         val_targets = torch.tensor(cifar_digits['y_train'][-10000:])
52 | 
53 | 
54 |     elif args.path_to_train.split('/')[-1] == 'celeba' or args.path_to_train.split('/')[-2] == 'celeba':
55 |         im_size_x, im_size_y = 64, 64
56 |         im_channels = 3
57 |         latent_dim = 64
58 |         beta = 0.05
59 |         architecture = 'convnet'
60 |         lr = 1e-3
61 |         patience = 5
62 |         n_epochs = 100
63 | 
64 |         train_data = torch.load(os.path.join(args.path_to_train, 'train_data.pt'))
65 |         train_targets = torch.ones(len(train_data))
66 |         val_data = torch.load(os.path.join(args.path_to_train, 'val_data.pt'))
67 |         val_targets = torch.ones(len(val_data))
68 | 
69 | 
70 |     elif args.path_to_train.split('/')[-2] == 'oasis':
71 |         im_size_x, im_size_y = 208, 176
72 |         im_channels = 1
73 |         latent_dim = 16
74 |         beta = 1
75 |         architecture = 'convnet'
76 |         lr = 1e-4
77 |         patience = 20
78 |         n_epochs = 1000
79 | 
80 |         oasis_data = np.load(args.path_to_train)
81 | 
82 |         train_targets = torch.tensor(oasis_data['y_train'][:])
83 |         train_data = torch.tensor(oasis_data['x_train']).type(torch.float).permute(0, 3, 1, 2) / 255.
84 |         val_targets = torch.tensor(oasis_data['y_train'][:])
85 |         val_data = torch.tensor(oasis_data['x_train']).type(torch.float).permute(0, 3, 1, 2) / 255.
86 | 87 | 88 | elif args.path_to_train.split('/')[-2] == 'svhn': 89 | im_size_x, im_size_y = 32, 32 90 | im_channels = 3 91 | latent_dim = 16 92 | beta = 0.01 93 | lr=1e-3 94 | patience = 5 95 | n_epochs = 100 96 | architecture = 'mlp' 97 | 98 | if args.model == 'rhvae': 99 | lr = 5e-4 100 | temperature = 2.5 101 | 102 | svnh_digits = loadmat(args.path_to_train)['X'] 103 | svnh_targets = loadmat(args.path_to_train)['y'] 104 | 105 | svnh_digits = np.transpose(svnh_digits, (3, 0, 1, 2)) 106 | 107 | train_data = torch.tensor(svnh_digits[:-10000]).type(torch.float).permute(0, 3, 1, 2) / 255. 108 | train_targets = torch.tensor(svnh_targets[:-10000]) 109 | val_data = torch.tensor(svnh_digits[-10000:]).permute(0, 3, 1, 2) / 255. 110 | val_targets = torch.tensor(svnh_targets[-10000:]) 111 | 112 | else: 113 | raise NotImplementedError() 114 | 115 | 116 | train = Digits(train_data.reshape(-1, im_channels, im_size_x, im_size_y), train_targets) 117 | val = Digits(val_data.reshape(-1, im_channels, im_size_x, im_size_y), val_targets) 118 | 119 | train_loader = torch.utils.data.DataLoader( 120 | dataset=train, batch_size=100, shuffle=True 121 | ) 122 | val_loader = torch.utils.data.DataLoader( 123 | dataset=val, batch_size=100, shuffle=True 124 | ) 125 | 126 | 127 | print('---------------------------------------------------------------') 128 | print(f'Train size: {train_loader.dataset.data.shape, train_loader.dataset.data.min(), train_loader.dataset.data.max()}') 129 | print(f'Val size: {val_loader.dataset.data.shape, val_loader.dataset.data.min(), val_loader.dataset.data.max()}') 130 | print('---------------------------------------------------------------') 131 | 132 | 133 | if args.model == 'vae': 134 | path_to_save = os.path.join('trained_vae_models', 'vae', args.path_to_train.split('/')[-2]) 135 | 136 | elif args.model == 'ae': 137 | path_to_save = os.path.join('trained_vae_models', 'ae', args.path_to_train.split('/')[-2]) 138 | 139 | elif args.model == 'rhvae': 140 | path_to_save = os.path.join('trained_vae_models', 'rhvae', args.path_to_train.split('/')[-2]) 141 | 142 | if not os.path.exists(path_to_save): 143 | os.makedirs(path_to_save) 144 | print(f"Created folder {path_to_save}. 
Best model is saved here") 145 | 146 | 147 | ##### Training ##### 148 | 149 | if args.model == 'vae': 150 | 151 | train_args = VAE_config( 152 | input_dim=im_size_x*im_size_y, 153 | latent_dim=latent_dim, 154 | architecture=architecture, 155 | n_channels=im_channels, 156 | dataset=args.path_to_train.split('/')[-2], 157 | beta=beta 158 | ) 159 | 160 | vae = VAE(train_args) 161 | 162 | elif args.model == 'ae': 163 | 164 | train_args = VAE_config( 165 | model_name="AE", 166 | input_dim=im_size_x*im_size_y, 167 | latent_dim=latent_dim, 168 | architecture=architecture, 169 | n_channels=im_channels, 170 | dataset=args.path_to_train.split('/')[-2], 171 | beta=beta, 172 | ) 173 | 174 | vae = VAE(train_args) 175 | 176 | elif args.model == 'rhvae': 177 | 178 | train_args = RHVAE_config( 179 | model_name="RHVAE", 180 | input_dim=im_size_x*im_size_y, 181 | latent_dim=latent_dim, 182 | architecture=architecture, 183 | n_channels=im_channels, 184 | dataset=args.path_to_train.split('/')[-2], 185 | beta=beta 186 | ) 187 | 188 | vae = RHVAE(train_args) 189 | 190 | print(train_args) 191 | 192 | if torch.cuda.is_available(): 193 | print('Using cuda') 194 | vae.cuda() 195 | print("Model") 196 | print(vae) 197 | if train_args.architecture == 'convnet' and args.model == 'vae': 198 | print(f"Encoder num params: {sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad) + sum(p.numel() for p in vae.conv.parameters() if p.requires_grad)} + log_var: {sum(p.numel() for p in vae.fc22.parameters() if p.requires_grad)}") 199 | print(f"Decoder num params: {sum(p.numel() for p in vae.fc3.parameters() if p.requires_grad) + sum(p.numel() for p in vae.deconv.parameters() if p.requires_grad)}") 200 | 201 | print("Nu params", sum(p.numel() for p in vae.parameters() if p.requires_grad) - sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad) ) 202 | 203 | elif train_args.architecture == 'convnet' and args.model=='rhvae': 204 | num_metric_param = sum(p.numel() for p in vae.metric_fc21.parameters() if p.requires_grad) + sum(p.numel() for p in vae.metric_fc22.parameters() if p.requires_grad) + sum(p.numel() for p in vae.metric_fc1.parameters() if p.requires_grad) 205 | num_cov_param = sum(p.numel() for p in vae.fc22.parameters() if p.requires_grad) 206 | print(f"Encoder num params: {sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad) + sum(p.numel() for p in vae.conv.parameters() if p.requires_grad)} + log_var: {num_cov_param} + metric: {num_metric_param}") 207 | print(f"Decoder num params: {sum(p.numel() for p in vae.fc3.parameters() if p.requires_grad) + sum(p.numel() for p in vae.deconv.parameters() if p.requires_grad)}") 208 | 209 | print("Nu params", sum(p.numel() for p in vae.parameters() if p.requires_grad) - num_cov_param - num_metric_param) 210 | 211 | elif train_args.architecture == 'convnet' and args.model == 'vamp': 212 | print(f"Encoder num params: {sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad) + sum(p.numel() for p in vae.conv.parameters() if p.requires_grad)} + log_var: {sum(p.numel() for p in vae.fc22.parameters() if p.requires_grad)}") 213 | print(f"Decoder num params: {sum(p.numel() for p in vae.fc3.parameters() if p.requires_grad) + sum(p.numel() for p in vae.deconv.parameters() if p.requires_grad)}") 214 | 215 | print("Nu params", sum(p.numel() for p in vae.parameters() if p.requires_grad) - sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad) - sum(p.numel() for p in vae.means.parameters() if p.requires_grad) ) 216 | 217 | elif 
217 |     elif train_args.architecture == 'mlp' and (not args.model == 'vamp' and not args.model == 'ae'):
218 |         print("Num params", sum(p.numel() for p in vae.parameters() if p.requires_grad) - sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad))
219 | 
220 |     elif train_args.architecture == 'mlp' and args.model == 'vamp':
221 |         print("Num params", sum(p.numel() for p in vae.parameters() if p.requires_grad) - sum(p.numel() for p in vae.fc21.parameters() if p.requires_grad) - sum(p.numel() for p in vae.means.parameters() if p.requires_grad))
222 | 
223 |     else:
224 |         print("Num params", sum(p.numel() for p in vae.parameters() if p.requires_grad))
225 | 
226 |     optimizer = optim.Adam(vae.parameters(), lr=lr)
227 |     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=patience, verbose=True)
228 | 
229 | 
230 |     best_loss = 1e10
231 |     torch.manual_seed(8)
232 |     torch.cuda.manual_seed(8)
233 |     e = 0
234 |     for epoch in range(1, n_epochs+1):
235 |         print(f"Epoch {epoch}")
236 | 
237 |         if args.model == 'vae' or args.model == 'ae' or args.model == 'rhvae' or args.model == 'geoae':
238 | 
239 |             vae, train_loss, train_loss_rec, train_loss_kld = train_vae(epoch, train_args, vae, train_loader, optimizer)
240 |             val_loss, val_loss_rec, val_loss_kld = eval_vae(epoch, train_args, vae, val_loader)
241 | 
242 |         scheduler.step(val_loss)
243 |         if val_loss < best_loss:
244 |             e = 0
245 |             best_model_dict = {
246 |                 'state_dict': deepcopy(vae.state_dict()),
247 |                 'args': train_args
248 |             }
249 |             best_loss = val_loss
250 | 
251 |         if epoch % 1 == 0:
252 |             print('----------------------------------------------------------------------------------------------------------------')
253 |             print(f'Epoch {epoch}: Train loss: {np.round(train_loss, 10)}\t Rec Loss: {np.round(train_loss_rec, 10)}\t KLD Loss: {np.round(train_loss_kld, 10)}')
254 |             print(f'Epoch {epoch}: Eval loss: {np.round(val_loss, 10)}\t Rec Loss: {np.round(val_loss_rec, 10)}\t KLD Loss: {np.round(val_loss_kld, 10)}')
255 |             print('----------------------------------------------------------------------------------------------------------------')
256 | 
257 |         torch.save(best_model_dict, os.path.join(path_to_save, 'best_model.pt'))
258 |         print('<<<<<<<<<<<<<<<<<<<<< Saved best model >>>>>>>>>>>>>>>>>>>>>>>>')
259 | 
260 | 
261 | if __name__ == '__main__':
262 |     from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
263 |     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
264 | 
265 | 
266 |     parser.add_argument("path_to_train", type=str,
267 |                         help='Path to the training data .npz files')
268 |     parser.add_argument("--model", type=str, choices=['ae', 'vae', 'vamp', 'rhvae'],
269 |                         help='Model to train', default='vae')
270 | 
271 |     args = parser.parse_args()
272 | 
273 |     main(args)
274 | 
--------------------------------------------------------------------------------
/models_to_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ################## MNIST ##################
4 | 
5 | ######## training ########
6 | 
7 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 0
8 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 3
9 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 4
10 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 5
11 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 6
12 | #python train_vae.py data_folders/mnist/mnist_32x32.npz --model 'vamp'
13 | #python train_vae.py
data_folders/mnist/mnist_32x32.npz --model 'vae' 14 | #python train_vae.py data_folders/mnist/mnist_32x32.npz --model 'ae' 15 | #python train_vae.py data_folders/mnist/mnist_32x32.npz --model 'rhvae' 16 | 17 | 18 | ######## generation ######## 19 | 20 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 0 21 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 3 22 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 4 23 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 5 24 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 6 25 | #python generate_data.py --model_path trained_vae_models/vamp/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 26 | #python generate_data.py --model_path trained_vae_models/vae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 27 | #python generate_data.py --model_path trained_vae_models/vae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 --generation 'gauss' 28 | #python generate_data.py --model_path trained_vae_models/vae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 --generation 'gmm' 29 | #python generate_data.py --model_path trained_vae_models/ae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 30 | #python generate_data.py --model_path trained_vae_models/ae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 --generation 'gauss' 31 | #python generate_data.py --model_path trained_vae_models/ae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 --generation 'gmm' 32 | #python generate_data.py --model_path trained_vae_models/rhvae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000 --generation 'gauss' 33 | 34 | ######## metric computation ######## 35 | 36 | #python TTUR/fid.py peers/logs/0/WAE_1/one_gaussian_sampled/ data_folders/mnist/test_folder --gpu '0' 37 | #python TTUR/fid.py peers/logs/3/RAE-GP_1/GMM_10_sampled/ data_folders/mnist/test_folder --gpu '0' 38 | #python TTUR/fid.py peers/logs/4/RAE-L2_1/GMM_10_sampled/ data_folders/mnist/test_folder --gpu '0' 39 | #python TTUR/fid.py peers/logs/5/RAE-SN_1/GMM_10_sampled/ data_folders/mnist/test_folder --gpu '0' 40 | #python TTUR/fid.py peers/logs/6/RAE_1/GMM_10_sampled/ data_folders/mnist/test_folder --gpu '0' 41 | #python TTUR/fid.py generated_data/vamp/mnist/ data_folders/mnist/test_folder/ --gpu '0' 42 | #python TTUR/fid.py generated_data/vae/mnist/manifold_sampling/ data_folders/mnist/test_folder/ --gpu '0' 43 | #python TTUR/fid.py generated_data/vae/mnist/gmm/ data_folders/mnist/test_folder/ --gpu '0' 44 | #python TTUR/fid.py generated_data/vae/mnist/gaussian_prior/ data_folders/mnist/test_folder/ --gpu '0' 45 | #python TTUR/fid.py generated_data/ae/mnist/gmm/ data_folders/mnist/test_folder/ --gpu '0' 46 | #python TTUR/fid.py generated_data/ae/mnist/gaussian_prior/ data_folders/mnist/test_folder/ --gpu '0' 47 | #python TTUR/fid.py generated_data/rhvae/mnist/gaussian_prior/ data_folders/mnist/test_folder/ --gpu '0' 48 | 49 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs peers/logs/0/WAE_1/one_gaussian_sampled/ 50 | #python precision-recall-distributions/prd_from_image_folders.py 
--inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs peers/logs/3/RAE-GP_1/GMM_10_sampled/ 51 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs peers/logs/4/RAE-L2_1/GMM_10_sampled/ 52 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs peers/logs/5/RAE-SN_1/GMM_10_sampled/ 53 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs peers/logs/6/RAE_1/GMM_10_sampled/ 54 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/vamp/mnist/ 55 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/vae/mnist/manifold_sampling/ 56 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/vae/mnist/gmm/ 57 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/vae/mnist/gaussian_prior/ 58 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/ae/mnist/gmm/ 59 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/ae/mnist/gaussian_prior/ 60 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/mnist/test_folder/ --eval_dirs generated_data/rhvae/mnist/gaussian_prior/ 61 | # 62 | 63 | 64 | ################## CIFAR ################## 65 | 66 | ######## training ######## 67 | 68 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 8 69 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 11 70 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 12 71 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 13 72 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 14 73 | #python train_vae.py data_folders/cifar/cifar_10.npz --model 'vamp' 74 | #python train_vae.py data_folders/cifar/cifar_10.npz --model 'vae' 75 | #python train_vae.py data_folders/cifar/cifar_10.npz --model 'ae' 76 | #python train_vae.py data_folders/cifar/cifar_10.npz --model 'rhvae' 77 | 78 | 79 | ######## generation ######## 80 | 81 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 8 82 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 11 83 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 12 84 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 13 85 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 14 86 | #python generate_data.py 
--model_path trained_vae_models/vamp/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 87 | #python generate_data.py --model_path trained_vae_models/vae/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 88 | #python generate_data.py --model_path trained_vae_models/vae/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 --generation 'gauss' 89 | #python generate_data.py --model_path trained_vae_models/vae/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 --generation 'gmm' 90 | #python generate_data.py --model_path trained_vae_models/ae/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 --generation 'gauss' 91 | #python generate_data.py --model_path trained_vae_models/ae/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 --generation 'gmm' 92 | #python generate_data.py --model_path trained_vae_models/rhvae/cifar/best_model.pt --data_path data_folders/cifar/cifar_10.npz --n_samples 10000 --generation 'gauss' 93 | 94 | 95 | ######## metric computation ######## 96 | 97 | #python TTUR/fid.py peers/logs/8/WAE_1/one_gaussian_sampled/ data_folders/cifar/test_folder --gpu '0' 98 | #python TTUR/fid.py peers/logs/11/RAE-GP_1/GMM_10_sampled/ data_folders/cifar/test_folder --gpu '0' 99 | #python TTUR/fid.py peers/logs/12/RAE-L2_1/GMM_10_sampled/ data_folders/cifar/test_folder --gpu '0' 100 | #python TTUR/fid.py peers/logs/13/RAE-SN_1/GMM_10_sampled/ data_folders/cifar/test_folder --gpu '0' 101 | #python TTUR/fid.py peers/logs/14/RAE_1/GMM_10_sampled/ data_folders/cifar/test_folder --gpu '0' 102 | #python TTUR/fid.py generated_data/vamp/cifar/ data_folders/cifar/test_folder/ --gpu '0' 103 | #python TTUR/fid.py generated_data/vae/cifar/manifold_sampling/ data_folders/cifar/test_folder/ --gpu '0' 104 | #python TTUR/fid.py generated_data/vae/cifar/gmm/ data_folders/cifar/test_folder/ --gpu '0' 105 | #python TTUR/fid.py generated_data/vae/cifar/gaussian_prior/ data_folders/cifar/test_folder/ --gpu '0' 106 | #python TTUR/fid.py generated_data/ae/cifar/gmm/ data_folders/cifar/test_folder/ --gpu '0' 107 | #python TTUR/fid.py generated_data/ae/cifar/gaussian_prior/ data_folders/cifar/test_folder/ --gpu '0' 108 | #python TTUR/fid.py generated_data/rhvae/cifar/gaussian_prior/ data_folders/cifar/test_folder/ --gpu '0' 109 | 110 | 111 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs peers/logs/8/WAE_1/one_gaussian_sampled/ 112 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs peers/logs/11/RAE-GP_1/GMM_10_sampled/ 113 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs peers/logs/12/RAE-L2_1/GMM_10_sampled/ 114 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs peers/logs/13/RAE-SN_1/GMM_10_sampled/ 115 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs peers/logs/14/RAE_1/GMM_10_sampled/ 116 | #python 
precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/vamp/cifar/ 117 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/vae/cifar/manifold_sampling/ 118 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/vae/cifar/gmm/ 119 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/vae/cifar/gaussian_prior/ 120 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/ae/cifar/gmm/ 121 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/ae/cifar/gaussian_prior/ 122 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/cifar/test_folder/ --eval_dirs generated_data/rhvae/cifar/gaussian_prior/ 123 | 124 | 125 | 126 | ################## CELEBA ################## 127 | 128 | ######## training ######## 129 | 130 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 16 131 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 19 132 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 20 133 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 21 134 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 22 135 | #python train_vae.py data_folders/celeba/ --model 'vamp' 136 | #python train_vae.py data_folders/celeba/ --model 'vae' 137 | #python train_vae.py data_folders/celeba/ --model 'ae' 138 | #python train_vae.py data_folders/celeba/ --model 'rhvae' 139 | 140 | 141 | ######## generation ######## 142 | 143 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 16 144 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 19 145 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 20 146 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 21 147 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 22 148 | #python generate_data.py --model_path trained_vae_models/vamp/celeba/best_model.pt --data_path data_folders/celeba/. --n_samples 10000 149 | #python generate_data.py --model_path trained_vae_models/vae/celeba/best_model.pt --data_path data_folders/celeba/. --n_samples 10000 150 | #python generate_data.py --model_path trained_vae_models/vae/celeba/best_model.pt --data_path data_folders/celeba/. --n_samples 10000 --generation 'gauss' 151 | #python generate_data.py --model_path trained_vae_models/vae/celeba/best_model.pt --data_path data_folders/celeba/. --n_samples 10000 --generation 'gmm' 152 | #python generate_data.py --model_path trained_vae_models/ae/celeba/best_model.pt --data_path data_folders/celeba/. 
--n_samples 10000 --generation 'gauss' 153 | #python generate_data.py --model_path trained_vae_models/ae/celeba/best_model.pt --data_path data_folders/celeba/. --n_samples 10000 --generation 'gmm' 154 | #python generate_data.py --model_path trained_vae_models/rhvae/celeba/best_model.pt --data_path data_folders/celeba/ --n_samples 10000 --generation 'gauss' 155 | 156 | 157 | ######## metric computation ######## 158 | 159 | #python TTUR/fid.py peers/logs/16/WAE_1/one_gaussian_sampled/ data_folders/celeba/test/test/ --gpu '0' 160 | #python TTUR/fid.py peers/logs/19/RAE-GP_1/GMM_10_sampled/ data_folders/celeba/test/test/ --gpu '0' 161 | #python TTUR/fid.py peers/logs/20/RAE-L2_1/GMM_10_sampled/ data_folders/celeba/test/test/ --gpu '0' 162 | #python TTUR/fid.py peers/logs/21/RAE-SN_1/GMM_10_sampled/ data_folders/celeba/test/test/ --gpu '0' 163 | #python TTUR/fid.py peers/logs/22/RAE_1/GMM_10_sampled/ data_folders/celeba/test/test/ --gpu '0' 164 | #python TTUR/fid.py generated_data/vamp/celeba/ data_folders/celeba/test/test/ --gpu '0' 165 | #python TTUR/fid.py generated_data/vae/celeba/manifold_sampling/ data_folders/celeba/test/test/ --gpu '0' 166 | #python TTUR/fid.py generated_data/vae/celeba/gmm/ data_folders/celeba/test/test/ --gpu '0' 167 | #python TTUR/fid.py generated_data/vae/celeba/gaussian_prior/ data_folders/celeba/test/test/ --gpu '0' 168 | #python TTUR/fid.py generated_data/ae/celeba/gmm/ data_folders/celeba/test/test/ --gpu '0' 169 | #python TTUR/fid.py generated_data/ae/celeba/gaussian_prior/ data_folders/celeba/test/test/ --gpu '0' 170 | #python TTUR/fid.py generated_data/rhvae/celeba/gaussian_prior/ data_folders/celeba/test/test/ --gpu '0' 171 | 172 | 173 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs peers/logs/16/WAE_1/one_gaussian_sampled 174 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs peers/logs/19/RAE-GP_1/GMM_10_sampled/ 175 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs peers/logs/20/RAE-L2_1/GMM_10_sampled/ 176 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs peers/logs/21/RAE-SN_1/GMM_10_sampled/ 177 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs peers/logs/22/RAE_1/GMM_10_sampled/ 178 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/vamp/celeba/ 179 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/vae/celeba/manifold_sampling/ 180 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/vae/celeba/gmm/ 181 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb 
--reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/vae/celeba/gaussian_prior/ 182 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/ae/celeba/gmm/ 183 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/ae/celeba/gaussian_prior/ 184 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/celeba/test/test/ --eval_dirs generated_data/rhvae/celeba/gaussian_prior/ 185 | 186 | 187 | 188 | ################## SVHN ################## 189 | 190 | ######## training ######## 191 | 192 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 24 193 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 27 194 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 28 195 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 29 196 | #python peers/Regularized_autoencoders-RAE-/train_raes_vaes.py 30 197 | #python train_vae.py data_folders/svhn/train_32x32.mat --model 'vamp' 198 | #python train_vae.py data_folders/svhn/train_32x32.mat --model 'vae' 199 | #python train_vae.py data_folders/svhn/train_32x32.mat --model 'ae' 200 | #python train_vae.py data_folders/svhn/train_32x32.mat --model 'rhvae' 201 | 202 | 203 | ######## generation ######## 204 | 205 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 24 206 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 27 207 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 28 208 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 29 209 | #python peers/Regularized_autoencoders-RAE-/interpolation_fid_and_viz.py 30 210 | #python generate_data.py --model_path trained_vae_models/vamp/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 211 | #python generate_data.py --model_path trained_vae_models/vae/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 212 | #python generate_data.py --model_path trained_vae_models/vae/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 --generation 'gauss' 213 | #python generate_data.py --model_path trained_vae_models/vae/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 --generation 'gmm' 214 | #python generate_data.py --model_path trained_vae_models/ae/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 --generation 'gauss' 215 | #python generate_data.py --model_path trained_vae_models/ae/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 --generation 'gmm' 216 | #python generate_data.py --model_path trained_vae_models/rhvae/svhn/best_model.pt --data_path data_folders/svhn/train_32x32.mat --n_samples 10000 --generation 'gauss' 217 | 218 | 219 | ######## metric computation ######## 220 | 221 | #python TTUR/fid.py peers/logs/24/WAE_1/one_gaussian_sampled/ data_folders/svhn/test_folder/ --gpu '0' 222 | #python TTUR/fid.py peers/logs/27/RAE-GP_1/GMM_10_sampled/ data_folders/svhn/test_folder/ --gpu '0' 223 | #python TTUR/fid.py peers/logs/28/RAE-L2_1/GMM_10_sampled/ data_folders/svhn/test_folder/ --gpu '0' 224 | #python TTUR/fid.py 
peers/logs/29/RAE-SN_1/GMM_10_sampled/ data_folders/svhn/test_folder/ --gpu '0' 225 | #python TTUR/fid.py peers/logs/30/RAE_1/GMM_10_sampled/ data_folders/svhn/test_folder/ --gpu '0' 226 | #python TTUR/fid.py generated_data/vamp/svhn/ data_folders/svhn/test_folder/ --gpu '0' 227 | #python TTUR/fid.py generated_data/vae/svhn/manifold_sampling/ data_folders/svhn/test_folder/ --gpu '0' 228 | #python TTUR/fid.py generated_data/vae/svhn/gmm/ data_folders/svhn/test_folder/ --gpu '0' 229 | #python TTUR/fid.py generated_data/vae/svhn/gaussian_prior/ data_folders/svhn/test_folder/ --gpu '0' 230 | #python TTUR/fid.py generated_data/ae/svhn/gmm/ data_folders/svhn/test_folder/ --gpu '0' 231 | #python TTUR/fid.py generated_data/ae/svhn/gaussian_prior/ data_folders/svhn/test_folder/ --gpu '0' 232 | #python TTUR/fid.py generated_data/rhvae/svhn/gaussian_prior/ data_folders/svhn/test_folder/ --gpu '0' 233 | 234 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs peers/logs/24/WAE_1/one_gaussian_sampled 235 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs peers/logs/27/RAE-GP_1/GMM_10_sampled/ 236 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs peers/logs/28/RAE-L2_1/GMM_10_sampled/ 237 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs peers/logs/29/RAE-SN_1/GMM_10_sampled/ 238 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs peers/logs/30/RAE_1/GMM_10_sampled/ 239 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/vamp/svhn/ 240 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/vae/svhn/manifold_sampling/ 241 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/vae/svhn/gmm/ 242 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/vae/svhn/gaussian_prior/ 243 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/ae/svhn/gmm/ 244 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/ae/svhn/gaussian_prior/ 245 | #python precision-recall-distributions/prd_from_image_folders.py --inception_path /tmp/classify_image_graph_def.pb --reference_dir data_folders/svhn/test_folder --eval_dirs generated_data/rhvae/svhn/gaussian_prior/ 246 | 
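Note: every command in this script ships commented out; an experiment is launched by uncommenting the relevant lines. As a minimal end-to-end sketch for one dataset (assuming the MNIST archive sits at data_folders/mnist/mnist_32x32.npz and test images have been extracted to data_folders/mnist/test_folder/; the pairing of the FID commands with output folders above suggests a flag-less generate_data.py run writes to manifold_sampling/, while --generation 'gauss' and 'gmm' map to gaussian_prior/ and gmm/), a single VAE experiment chains three stages:

python train_vae.py data_folders/mnist/mnist_32x32.npz --model 'vae'
python generate_data.py --model_path trained_vae_models/vae/mnist/best_model.pt --data_path data_folders/mnist/mnist_32x32.npz --n_samples 10000
python TTUR/fid.py generated_data/vae/mnist/manifold_sampling/ data_folders/mnist/test_folder/ --gpu '0'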
-------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import grad 6 | 7 | from .base import BaseVAE 8 | 9 | 10 | class VAE(BaseVAE, nn.Module): 11 | def __init__(self, args): 12 | 13 | BaseVAE.__init__(self, args) 14 | nn.Module.__init__(self) 15 | 16 | self.model_name = args.model_name 17 | self.dataset = args.dataset 18 | self.input_dim = args.input_dim 19 | self.latent_dim = args.latent_dim 20 | self.n_channels = args.n_channels 21 | self.beta = args.beta 22 | 23 | if args.architecture == 'convnet': 24 | 25 | if args.dataset == 'mnist' or args.dataset == 'cifar': # SAME as From Var. to Deter. AE (check num of params) 26 | self.conv = torch.nn.Sequential( 27 | nn.Conv2d(self.n_channels, 128, 4, 2, padding=1), 28 | nn.BatchNorm2d(128), 29 | nn.ReLU(), 30 | nn.Conv2d(128, 256, 4, 2, padding=1), 31 | nn.BatchNorm2d(256), 32 | nn.ReLU(), 33 | nn.Conv2d(256, 512, 4, 2, padding=1), 34 | nn.BatchNorm2d(512), 35 | nn.ReLU(), 36 | nn.Conv2d(512, 1024, 4, 2, padding=1), 37 | nn.BatchNorm2d(1024), 38 | nn.ReLU(), 39 | ) 40 | 41 | #self.fc1 = nn.Linear(1024*2*2, 16) 42 | self.fc21 = nn.Linear(1024*2*2, args.latent_dim) 43 | if not args.model_name == 'AE': 44 | self.fc22 = nn.Linear(1024*2*2, args.latent_dim) 45 | #self.fc22 = nn.Linear(400, args.latent_dim) 46 | 47 | self.fc3 = nn.Linear(args.latent_dim, 1024*8*8) 48 | #self.fc4 = nn.Linear(16, 1024*8*8) 49 | self.deconv = nn.Sequential( 50 | nn.ConvTranspose2d(1024, 512, 4, 2, padding=1), 51 | nn.BatchNorm2d(512), 52 | nn.ReLU(), 53 | nn.ConvTranspose2d(512, 256, 4, 2, padding=1, output_padding=1), 54 | nn.BatchNorm2d(256), 55 | nn.ReLU(), 56 | nn.ConvTranspose2d(256, self.n_channels, 4, 1, padding=2), 57 | #nn.BatchNorm2d(self.n_channels), 58 | #nn.ReLU(), 59 | #nn.ConvTranspose2d(self.n_channels, self.n_channels, 3, 1, padding=1), 60 | #nn.BatchNorm2d(self.n_channels), 61 | nn.Sigmoid() 62 | ) 63 | 64 | elif args.dataset == 'celeba': 65 | self.conv = torch.nn.Sequential( 66 | nn.Conv2d(self.n_channels, 128, 5, 2, padding=1), 67 | nn.BatchNorm2d(128), 68 | nn.ReLU(), 69 | nn.Conv2d(128, 256, 5, 2, padding=1), 70 | nn.BatchNorm2d(256), 71 | nn.ReLU(), 72 | nn.Conv2d(256, 512, 5, 2, padding=2), 73 | nn.BatchNorm2d(512), 74 | nn.ReLU(), 75 | nn.Conv2d(512, 1024, 5, 2, padding=2), 76 | nn.BatchNorm2d(1024), 77 | nn.ReLU(), 78 | ) 79 | 80 | #self.fc1 = nn.Linear(1024*2*2, 128) 81 | self.fc21 = nn.Linear(1024*4*4, args.latent_dim) 82 | if not args.model_name == 'AE': 83 | self.fc22 = nn.Linear(1024*4*4, args.latent_dim) 84 | #self.fc22 = nn.Linear(400, args.latent_dim) 85 | 86 | self.fc3 = nn.Linear(args.latent_dim, 1024*8*8) 87 | #self.fc4 = nn.Linear(128, 1024*8*8) 88 | self.deconv = nn.Sequential( 89 | nn.ConvTranspose2d(1024, 512, 5, 2, padding=2), 90 | nn.BatchNorm2d(512), 91 | nn.ReLU(), 92 | nn.ConvTranspose2d(512, 256, 5, 2, padding=1, output_padding=0), 93 | nn.BatchNorm2d(256), 94 | nn.ReLU(), 95 | nn.ConvTranspose2d(256, 128, 5, 2, padding=2, output_padding=1), 96 | nn.BatchNorm2d(128), 97 | nn.ReLU(), 98 | nn.ConvTranspose2d(128, 3, 5, 1, padding=1), 99 | #nn.BatchNorm2d(self.n_channels), 100 | #nn.ReLU(), 101 | #nn.ConvTranspose2d(self.n_channels, self.n_channels, 3, 1, padding=1), 102 | #nn.BatchNorm2d(self.n_channels), 103 | nn.Sigmoid()) 104 | 105 | 106 | elif args.dataset == 'oasis': 107 | 
108 | self.conv = torch.nn.Sequential(
109 | nn.Conv2d(self.n_channels, 64, 5, 2, padding=1),
110 | #nn.BatchNorm2d(64),
111 | nn.ReLU(),
112 | nn.Conv2d(64, 128, 5, 2, padding=1),
113 | #nn.BatchNorm2d(128),
114 | nn.ReLU(),
115 | nn.Conv2d(128, 256, 5, 2, padding=1),
116 | #nn.BatchNorm2d(256),
117 | nn.ReLU(),
118 | nn.Conv2d(256, 512, 5, 2, padding=(1, 2)),
119 | #nn.BatchNorm2d(512),
120 | nn.ReLU(),
121 | nn.Conv2d(512, 1024, 5, 2, padding=0),
122 | #nn.BatchNorm2d(1024),
123 | nn.ReLU(),
124 | #nn.Conv2d(1024, 20, 5, 2, padding=1),
125 | #nn.BatchNorm2d(1024),
126 | #nn.ReLU(),
127 | )
128 | self.fc21 = nn.Linear(1024*4*4, args.latent_dim)
129 | if not args.model_name == 'AE':
130 | self.fc22 = nn.Linear(1024*4*4, args.latent_dim)
131 |
132 | self.fc3 = nn.Linear(args.latent_dim, 1024*8*8)
133 | self.deconv = nn.Sequential(
134 | nn.ConvTranspose2d(1024, 512, 5, (3, 2), padding=(1, 0), output_padding=(0, 0)),
135 | #nn.BatchNorm2d(512),
136 | nn.ReLU(),
137 | nn.ConvTranspose2d(512, 256, 5, 2, padding=(1, 0), output_padding=(0, 0)),
138 | #nn.BatchNorm2d(256),
139 | nn.ReLU(),
140 | nn.ConvTranspose2d(256, 128, 5, 2, padding=0, output_padding=(0, 0)),
141 | #nn.BatchNorm2d(128),
142 | nn.ReLU(),
143 | nn.ConvTranspose2d(128, 64, 5, 2, padding=0, output_padding=(1, 1)),
144 | #nn.BatchNorm2d(64),
145 | nn.ReLU(),
146 | nn.ConvTranspose2d(64, self.n_channels, 5, 1, padding=1),
147 | #nn.BatchNorm2d(self.n_channels),
148 | nn.Sigmoid())
149 |
150 | else:
151 | # encoder network
152 | self.fc1 = nn.Linear(args.input_dim*args.n_channels, 1000)
153 | self.fc11 = nn.Linear(1000, 500)
154 | #self.fc111 = nn.Linear(500, 500)
155 | #self.fc1111 = nn.Linear(500, 500)
156 | self.fc21 = nn.Linear(500, args.latent_dim)
157 | self.fc22 = nn.Linear(500, args.latent_dim)
158 |
159 | # decoder network
160 | self.fc3 = nn.Linear(args.latent_dim, 500)
161 | self.fc33 = nn.Linear(500, 1000)
162 | #self.fc333 = nn.Linear(500, 500)
163 | #self.fc3333 = nn.Linear(500, 500)
164 | self.fc4 = nn.Linear(1000, args.input_dim*args.n_channels)
165 |
166 | if args.architecture == 'convnet':
167 | self._encoder = self._encode_convnet
168 | self._decoder = self._decode_convnet
169 |
170 | else:
171 | self._encoder = self._encode_mlp
172 | self._decoder = self._decode_mlp
173 |
174 |
175 |
176 | # define an N(0, I) distribution
177 | self.normal = torch.distributions.MultivariateNormal(
178 | loc=torch.zeros(args.latent_dim).to(self.device),
179 | covariance_matrix=torch.eye(args.latent_dim).to(self.device),
180 | )
181 |
182 | def forward(self, x):
183 | """
184 | Forward pass: encode, reparametrize and decode (no sampling for the AE variant)
185 | """
186 | mu, log_var = self.encode(x)
187 |
188 | if self.model_name == 'AE':
189 | std = 0
190 | eps = 0
191 | z = mu
192 |
193 | else:
194 | std = torch.exp(0.5 * log_var)
195 | z, eps = self._sample_gauss(mu, std)
196 |
197 | recon_x = self.decode(z)
198 |
199 | return recon_x, z, eps, mu, log_var
200 |
201 | def loss_function(self, recon_x, x, mu, log_var, z):
202 | BCE = F.mse_loss(  # reconstruction term: pixel-wise MSE summed over the batch
203 | recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction='none'
204 | ).sum()
205 | if self.model_name == 'VAE':
206 | KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
207 |
208 | elif self.model_name == 'AE':
209 | KLD = torch.zeros(1).to(self.device)
210 |
211 | elif self.model_name == 'GeoAE':
212 | KLD = torch.linalg.norm(log_var.exp() - torch.ones_like(log_var), dim=1).sum()
213 |
214 | return BCE + self.beta * KLD, BCE, KLD
215 |
216 | def encode(self, x):
217 | return self._encoder(x)
218 |
219 |
220 | def decode(self, z):
221 | x_prob = self._decoder(z)
222 | return x_prob
223 |
224 | def sample_img(
225 | self,
226 | z=None,
227 | n_samples=1
228 | ):
229 | """
230 | Generate images: decode z if provided, otherwise decode n_samples draws from the prior
231 | """
232 | if z is None:
233 | z = self.normal.sample(sample_shape=(n_samples,)).to(self.device)
234 |
235 | else:
236 | n_samples = z.shape[0]
237 |
238 | recon_x = self.decode(z)
239 |
240 | return recon_x
241 |
249 | def _encode_mlp(self, x):
250 | h1 = F.relu(self.fc1(x.reshape(-1, self.input_dim*self.n_channels)))
251 | h1 = F.relu(self.fc11(h1))
252 | #h1 = F.relu(self.fc111(h1))
253 | #h1 = F.relu(self.fc1111(h1))
254 | if self.model_name == 'AE':
255 | return self.fc21(h1), 0
256 | else:
257 | return self.fc21(h1), self.fc22(h1)
258 |
259 | def _decode_mlp(self, z):
260 | h3 = F.relu(self.fc3(z))
261 | h3 = F.relu(self.fc33(h3))
262 | #h3 = F.relu(self.fc333(h3))
263 | #h3 = F.relu(self.fc3333(h3))
264 | if self.dataset == "oasis":
265 | return torch.sigmoid(self.fc4(h3)).reshape(z.shape[0], self.n_channels, 208, 176)
266 | else:
267 | return torch.sigmoid(self.fc4(h3)).reshape(z.shape[0], self.n_channels, int(self.input_dim**0.5), int(self.input_dim**0.5))
268 |
269 | def _encode_convnet(self, x):
270 | h1 = self.conv(x).reshape(x.shape[0], -1)
271 | #h1 = self.fc1(h1.reshape(h1.shape[0], -1))#F.relu(self.fc1(h1.reshape(h1.shape[0], -1)))
272 |
273 | if self.model_name == 'AE':
274 | return self.fc21(h1), 0
275 | if self.dataset == 'oasis':
276 | return self.fc21(h1), torch.tanh(self.fc22(h1))
277 | else:
278 | return self.fc21(h1), self.fc22(h1)
279 |
280 | def _decode_convnet(self, z):
281 | h3 = self.fc3(z)
282 | #h3 = F.relu(self.fc4(h3))
283 | #h3 = F.relu(self.fc4(h3))
284 |
285 | h3 = self.deconv(h3.reshape(-1, 1024, 8, 8))
286 |
287 | return h3
288 |
289 | def _sample_gauss(self, mu, std):
290 | # Reparametrization trick
291 | # Sample eps ~ N(0, I) and return z = mu + eps * std
292 | eps = torch.randn_like(std)
293 | return mu + eps * std, eps
294 |
295 | def _tempering(self, k, K):
296 | """Perform tempering step"""
297 |
298 | beta_k = (
299 | (1 - 1 / self.beta_zero_sqrt) * (k / K) ** 2
300 | ) + 1 / self.beta_zero_sqrt
301 |
302 | return 1 / beta_k
303 |
304 | ########## Estimate densities ##########
305 |
306 | def log_p_x_given_z(self, recon_x, x, reduction="none"):
307 | """
308 | Estimate the decoder's log-density, modelled as follows:
309 | p(x|z) = \prod_i Bernoulli(x_i|pi_{theta}(z_i))
310 | p(x = s|z) = \prod_i pi(z_i)^{x_i} * (1 - pi(z_i))^{(1 - x_i)}"""
311 | return -F.binary_cross_entropy(
312 | recon_x.reshape(x.shape[0], -1), x.view(x.shape[0], -1), reduction=reduction
313 | ).sum(dim=1)
314 |
315 | def log_z(self, z):
316 | """
317 | Return the log-density of the N(0, I) prior on z
318 | """
319 | return self.normal.log_prob(z)
320 |
321 | def log_p_z_given_x(self, z, recon_x, x, sample_size=10):
322 | """
323 | Estimate log(p(z|x)) using Bayes' rule and importance sampling for log(p(x))
324 | """
325 | logpx = self.log_p_x(x, sample_size)
326 | logpxz = self.log_p_x_given_z(recon_x, x)
327 | logpz = self.log_z(z)
328 | return logpxz + logpz - logpx
329 |
330 | def log_p_xz(self, recon_x, x, z):
331 | """
332 | Estimate log(p(x, z)) from the factorization p(x, z) = p(x|z) p(z)
333 | """
334 | logpxz = self.log_p_x_given_z(recon_x, x)
335 | logpz = self.log_z(z)
336 | return logpxz + logpz
337 |
338 | ########## Kullback-Leibler divergence estimates ##########
339 |
340 | def kl_prior(self, mu, log_var):
341 | """KL[q(z|x) || p(z)]: closed-form expression"""
342 | return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
343 |
344 | def kl_cond(self, recon_x, x, z, mu, log_var, sample_size=10):
345 | """
346 | KL[p(z|x) || q(z|x)]
347 |
348 | Note:
349 | -----
350 | p(z|x) is approximated using importance sampling on log(p(x))
351 | """
352 | logpzx = self.log_p_z_given_x(z, recon_x, x, sample_size=sample_size)
353 | logqzx = torch.distributions.MultivariateNormal(
354 | loc=mu, covariance_matrix=torch.diag_embed(torch.exp(log_var))
355 | ).log_prob(z)
356 |
357 | return (logqzx - logpzx).sum()
358 |
359 | class HVAE(VAE):
360 | def __init__(self, args):
361 | """
362 | Inputs:
363 | -------
364 |
365 | n_lf (int): Number of leapfrog steps to perform
366 | eps_lf (float): Leapfrog step size
367 | beta_zero (float): Initial tempering factor
368 | tempering (str): Tempering type (free, fixed)
369 | model_type (str): Model type for the VAE (mlp, convnet)
370 | latent_dim (int): Latent dimension
371 | """
372 | VAE.__init__(self, args)
373 |
374 | self.vae_forward = super().forward
375 | self.n_lf = args.n_lf
376 |
377 | self.eps_lf = nn.Parameter(torch.Tensor([args.eps_lf]), requires_grad=False)
378 |
379 | assert 0 < args.beta_zero <= 1, "Tempering factor should belong to (0, 1]"
380 |
381 | self.beta_zero_sqrt = nn.Parameter(
382 | torch.Tensor([args.beta_zero]), requires_grad=False
383 | )
384 |
385 | def forward(self, x):
386 | """
387 | The HVAE forward pass
388 | """
389 |
390 | recon_x, z0, eps0, mu, log_var = self.vae_forward(x)
391 | gamma = torch.randn_like(z0, device=self.device)
392 | rho = gamma / self.beta_zero_sqrt
393 | z = z0
394 | beta_sqrt_old = self.beta_zero_sqrt
395 |
396 | recon_x = self.decode(z)
397 |
398 | for k in range(self.n_lf):
399 |
400 | # perform leapfrog steps
401 |
402 | # compute potential energy
403 | U = -self.log_p_xz(recon_x, x, z).sum()
404 |
405 | # compute its gradient
406 | g = grad(U, z, create_graph=True)[0]
407 |
408 | # 1st leapfrog step
409 | rho_ = rho - (self.eps_lf / 2) * g
410 |
411 | # 2nd leapfrog step
412 | z = z + self.eps_lf * rho_
413 |
414 | recon_x = self.decode(z)
415 |
416 | U = -self.log_p_xz(recon_x, x, z).sum()
417 | g = grad(U, z, create_graph=True)[0]
418 |
419 | # 3rd leapfrog step
420 | rho__ = rho_ - (self.eps_lf / 2) * g
421 |
422 | # tempering step
423 | beta_sqrt = self._tempering(k + 1, self.n_lf)
424 | rho = (beta_sqrt_old / beta_sqrt) * rho__
425 | beta_sqrt_old = beta_sqrt
426 |
427 | return recon_x, z, z0, rho, eps0, gamma, mu, log_var
428 |
429 | def loss_function(self, recon_x, x, z0, zK, rhoK, eps0, gamma, mu, log_var):
430 |
431 | logpxz = self.log_p_xz(recon_x.reshape(x.shape[0], -1), x, zK) # log p(x, z_K)
432 | logrhoK = self.normal.log_prob(rhoK) # log p(\rho_K)
433 | logp = logpxz + logrhoK
434 |
435 | logq = self.normal.log_prob(eps0) - 0.5 * log_var.sum(dim=1) # log q(z_0|x)
436 |
437 | return -(logp - logq).sum()
438 |
439 |
440 | def hamiltonian(self, recon_x, x, z, rho, G_inv=None, G_log_det=None):
441 | """
442 | Computes the Hamiltonian function
443 | used for both HVAE and RHVAE
444 | """
445 | if self.model_name == "HVAE":
446 | return -self.log_p_xz(recon_x.reshape(x.shape[0], -1), x, z).sum()
447 |
448 | # norm = (torch.solve(rho[:, :, None], G).solution[:, :, 0] * rho).sum()
449 | norm = ( # rho^T G^{-1} rho
450 | torch.transpose(rho.unsqueeze(-1), 1, 2) @ G_inv @ rho.unsqueeze(-1)
451 | ).sum()
452 |
453 | return -self.log_p_xz(recon_x.reshape(x.shape[0], -1), x, z).sum() + 0.5 * norm + 0.5 * G_log_det.sum()
454 |
455 | def _tempering(self, k, K):
456 | """Perform tempering step"""
457 |
458 | beta_k = (
459 | (1 - 1 / self.beta_zero_sqrt) * (k / K) ** 2
460 | ) + 1 / self.beta_zero_sqrt
461 |
462 | return 1 / beta_k
463 |
464 |
465 | class RHVAE(HVAE):
466 | def __init__(self, args):
467 |
468 | HVAE.__init__(self, args)
469 | # defines the neural net used to compute the metric
470 |
471 | # first layer
472 | self.metric_fc1 = nn.Linear(self.input_dim*self.n_channels, args.metric_fc)
473 |
474 | # diagonal coefficients
475 | self.metric_fc21 = nn.Linear(args.metric_fc, self.latent_dim)
476 | # remaining lower-triangular coefficients
477 | k = int(self.latent_dim * (self.latent_dim - 1) / 2)
478 | self.metric_fc22 = nn.Linear(args.metric_fc, k)
479 |
480 | self.T = nn.Parameter(torch.Tensor([args.temperature]), requires_grad=False)
481 | self.lbd = nn.Parameter(
482 | torch.Tensor([args.regularization]), requires_grad=False
483 | )
484 |
485 | # this is used to store the matrices and centroids throughout training for
486 | # further use in the metric update (L is the Cholesky decomposition of M)
487 | self.M = []
488 | self.centroids = []
489 |
490 | # define a starting metric (gamma_i = 0 & L = I_d)
491 | def G(z):
492 | return (
493 | torch.eye(self.latent_dim, device=self.device).unsqueeze(0)
494 | * torch.exp(-torch.norm(z.unsqueeze(1), dim=-1) ** 2)
495 | .unsqueeze(-1)
496 | .unsqueeze(-1)
497 | ).sum(dim=1) + self.lbd * torch.eye(self.latent_dim).to(self.device)
498 |
499 | self.G = G
500 |
501 | def metric_forward(self, x):
502 | """
503 | Returns the outputs of the metric neural network
504 |
505 | Outputs:
506 | --------
507 |
508 | L (Tensor): The lower-triangular matrix L used in the metric definition
509 | M (Tensor): L L^T
510 | """
511 |
512 | h1 = torch.relu(self.metric_fc1(x.reshape(-1, self.input_dim*self.n_channels)))
513 | h21, h22 = self.metric_fc21(h1), self.metric_fc22(h1)
514 |
515 | L = torch.zeros((x.shape[0], self.latent_dim, self.latent_dim)).to(self.device)
516 | indices = torch.tril_indices(
517 | row=self.latent_dim, col=self.latent_dim, offset=-1
518 | )
519 |
520 | # get non-diagonal coefficients
521 | L[:, indices[0], indices[1]] = h22
522 |
523 | # add diagonal coefficients (exponentiated to keep them positive)
524 | L = L + torch.diag_embed(h21.exp())
525 |
526 | return L, L @ torch.transpose(L, 1, 2)
527 |
528 | def update_metric(self):
529 | """
530 | As soon as the model has seen all the data points (i.e. at the end of one epoch)
531 | we update the final metric function using \mu(x_i) as centroids
532 | """
533 | # convert to one big tensor
534 | self.M_tens = torch.cat(self.M)
535 | self.centroids_tens = torch.cat(self.centroids)
536 |
537 | # define the new metric
538 | def G(z):
539 | return torch.inverse(
540 | (
541 | self.M_tens.unsqueeze(0)
542 | * torch.exp(
543 | -torch.norm(
544 | self.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1
545 | )
546 | ** 2
547 | / (self.T ** 2)
548 | )
549 | .unsqueeze(-1)
550 | .unsqueeze(-1)
551 | ).sum(dim=1)
552 | + self.lbd * torch.eye(self.latent_dim).to(self.device)
553 | )
554 |
555 | def G_inv(z):
556 | return (
557 | self.M_tens.unsqueeze(0)
558 | * torch.exp(
559 | -torch.norm(
560 | self.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1
561 | )
562 | ** 2
563 | / (self.T ** 2)
564 | )
565 | .unsqueeze(-1)
566 | .unsqueeze(-1)
567 | ).sum(dim=1) + self.lbd * torch.eye(self.latent_dim).to(self.device)
568 |
569 | self.G = G
570 | self.G_inv = G_inv
571 | self.M = []
572 | self.centroids = []
573 |
574 | def forward(self, x):
575 | """
576 | The RHVAE forward pass
577 | """
578 |
579 | recon_x, z0, eps0, mu, log_var = self.vae_forward(x)
580 |
581 | z = z0
582 |
583 | if self.training:
584 |
585 | # update the metric using batch data points
586 | L, M = self.metric_forward(x)
587 |
588 | # store LL^T and mu(x_i) to update the final metric
589 | self.M.append(M.clone().detach())
590 | self.centroids.append(mu.clone().detach())
591 |
592 | G_inv = (
593 | M.unsqueeze(0)
594 | * torch.exp(
595 | -torch.norm(mu.unsqueeze(0) - z.unsqueeze(1), dim=-1) ** 2
596 | / (self.T ** 2)
597 | )
598 | .unsqueeze(-1)
599 | .unsqueeze(-1)
600 | ).sum(dim=1) + self.lbd * torch.eye(self.latent_dim).to(self.device)
601 |
602 | else:
603 | G = self.G(z)
604 | G_inv = self.G_inv(z)
605 | L = torch.cholesky(G)
606 |
607 | G_log_det = -torch.logdet(G_inv)
608 |
609 | gamma = torch.randn_like(z0, device=self.device)
610 | rho = gamma / self.beta_zero_sqrt
611 | beta_sqrt_old = self.beta_zero_sqrt
612 |
613 | # sample \rho from N(0, G)
614 | rho = (L @ rho.unsqueeze(-1)).squeeze(-1)
615 |
616 | recon_x = self.decode(z)
617 |
618 | for k in range(self.n_lf):
619 |
620 | # perform leapfrog steps
621 |
622 | # step 1
623 | rho_ = self.leap_step_1(recon_x, x, z, rho, G_inv, G_log_det)
624 |
625 | # step 2
626 | z = self.leap_step_2(recon_x, x, z, rho_, G_inv, G_log_det)
627 |
628 | recon_x = self.decode(z)
629 |
630 | if self.training:
631 | G_inv = (
632 | M.unsqueeze(0)
633 | * torch.exp(
634 | -torch.norm(mu.unsqueeze(0) - z.unsqueeze(1), dim=-1) ** 2
635 | / (self.T ** 2)
636 | )
637 | .unsqueeze(-1)
638 | .unsqueeze(-1)
639 | ).sum(dim=1) + self.lbd * torch.eye(self.latent_dim).to(self.device)
640 |
641 | else:
642 | # compute metric value on new z using the final metric
643 | G = self.G(z)
644 | G_inv = self.G_inv(z)
645 |
646 | G_log_det = -torch.logdet(G_inv)
647 |
648 | # step 3
649 | rho__ = self.leap_step_3(recon_x, x, z, rho_, G_inv, G_log_det)
650 |
651 | # tempering
652 | beta_sqrt = self._tempering(k + 1, self.n_lf)
653 | rho = (beta_sqrt_old / beta_sqrt) * rho__
654 | beta_sqrt_old = beta_sqrt
655 |
656 | return recon_x, z, z0, rho, eps0, gamma, mu, log_var, G_inv, G_log_det
657 |
658 | def leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
659 | """
660 | Solves the first equation of the generalized leapfrog integrator
661 | using fixed-point iterations
662 | """
663 |
664 | def f_(rho_):
665 | H = self.hamiltonian(recon_x, x, z, rho_, G_inv, G_log_det)
666 | gz = grad(H, z, retain_graph=True)[0]
667 | return rho - 0.5 * self.eps_lf * gz
668 |
669 | rho_ = rho.clone()
670 | for _ in range(steps):
671 | rho_ = f_(rho_)
672 | return rho_
673 |
674 | def leap_step_2(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
675 | """
676 | Solves the second equation of the generalized leapfrog integrator
677 | using fixed-point iterations
678 | """
679 | H0 = self.hamiltonian(recon_x, x, z, rho, G_inv, G_log_det)
680 | grho_0 = grad(H0, rho)[0]
681 |
682 | def f_(z_):
683 | H = self.hamiltonian(recon_x, x, z_, rho, G_inv, G_log_det)
684 | grho = grad(H, rho, retain_graph=True)[0]
685 | return z + 0.5 * self.eps_lf * (grho_0 + grho)
686 |
687 | z_ = z.clone()
688 | for _ in range(steps):
689 | z_ = f_(z_)
690 | return z_
691 |
692 | def leap_step_3(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
693 | """
694 | Solves the third equation of the generalized leapfrog integrator
695 | using fixed-point iterations
696 | """
697 | H = self.hamiltonian(recon_x, x, z, rho, G_inv, G_log_det)
698 | gz = grad(H, z, create_graph=True)[0]
699 | return rho - 0.5 * self.eps_lf * gz
700 |
701 | def loss_function(
702 | self, recon_x, x, z0, zK, rhoK, eps0, gamma, mu, log_var, G_inv, G_log_det
703 | ):
704 |
705 | logpxz = self.log_p_xz(recon_x.reshape(x.shape[0], -1), x, zK) # log p(x, z_K)
706 | logrhoK = (
707 | -0.5
708 | * (torch.transpose(rhoK.unsqueeze(-1), 1, 2) @ G_inv @ rhoK.unsqueeze(-1))
709 | .squeeze()
710 | .squeeze()
711 | - 0.5 * G_log_det
712 | ) - torch.log(
713 | torch.tensor([2 * np.pi]).to(self.device)
714 | ) * self.latent_dim / 2 # log p(\rho_K)
715 |
716 | logp = logpxz + logrhoK
717 |
718 | logq = self.normal.log_prob(eps0) - 0.5 * log_var.sum(dim=1) # log q(z_0|x)
719 |
720 | return -(logp - logq).sum()
721 |
--------------------------------------------------------------------------------
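As a closing illustration, here is a minimal, hypothetical smoke test of the plain VAE path in models/vae.py; it is not part of the repository. The SimpleNamespace is a hand-rolled stand-in for the project's config objects, its field names being exactly the attributes the VAE constructor reads from args (including device), and the 'mlp' architecture with input_dim = 32*32 keeps the shapes simple on CPU:

import torch
from types import SimpleNamespace

from models.vae import VAE

# Stand-in for the project's config dataclasses (assumed fields only;
# they mirror what the VAE constructor reads from `args`).
args = SimpleNamespace(
    model_name='VAE', dataset='mnist', architecture='mlp',
    input_dim=32 * 32, n_channels=1, latent_dim=16,
    beta=1.0, device='cpu',
)

model = VAE(args)
x = torch.rand(8, 1, 32, 32)               # dummy batch of 32x32 images in [0, 1]
recon_x, z, eps, mu, log_var = model(x)    # encode, reparametrize, decode
loss, rec, kld = model.loss_function(recon_x, x, mu, log_var, z)
samples = model.sample_img(n_samples=4)    # decode draws from the N(0, I) prior
print(loss.item(), samples.shape)          # scalar loss, torch.Size([4, 1, 32, 32])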