├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── analysis.py ├── datasets ├── dynamic_mnist │ └── readme.me ├── fashion_mnist │ └── readme.me └── omniglot │ └── readme.me ├── density_estimation.py ├── images ├── augmentation.gif ├── celebA_exemplar_generation.png ├── cyclic_generation.png ├── data_augmentation.png ├── density_estimation.png ├── exemplar_generation.png └── full_generation.png ├── models ├── AbsHModel.py ├── AbsModel.py ├── BaseModel.py ├── HVAE_2level.py ├── PixelCNN.py ├── VAE.py ├── __init__.py ├── convHVAE_2level.py └── fully_conv.py ├── pretrained_model └── exemplar_prior_on_dynamic_mnist_model_name=vae │ └── 1 │ ├── checkpoint.pth │ ├── checkpoint_best.pth │ ├── generated │ ├── generated_0.png │ ├── generated_1.png │ ├── generated_10.png │ ├── generated_11.png │ ├── generated_12.png │ ├── generated_13.png │ ├── generated_14.png │ ├── generated_15.png │ ├── generated_16.png │ ├── generated_17.png │ ├── generated_18.png │ ├── generated_19.png │ ├── generated_2.png │ ├── generated_20.png │ ├── generated_21.png │ ├── generated_22.png │ ├── generated_23.png │ ├── generated_24.png │ ├── generated_25.png │ ├── generated_26.png │ ├── generated_27.png │ ├── generated_28.png │ ├── generated_29.png │ ├── generated_3.png │ ├── generated_30.png │ ├── generated_31.png │ ├── generated_32.png │ ├── generated_33.png │ ├── generated_34.png │ ├── generated_35.png │ ├── generated_36.png │ ├── generated_37.png │ ├── generated_38.png │ ├── generated_39.png │ ├── generated_4.png │ ├── generated_40.png │ ├── generated_41.png │ ├── generated_42.png │ ├── generated_43.png │ ├── generated_44.png │ ├── generated_45.png │ ├── generated_46.png │ ├── generated_47.png │ ├── generated_48.png │ ├── generated_49.png │ ├── generated_5.png │ ├── generated_6.png │ ├── generated_7.png │ ├── generated_8.png │ └── generated_9.png │ ├── generations_0.png │ ├── real.png │ ├── reconstructions.png │ ├── vae.config │ ├── vae.test_kl │ ├── vae.test_log_likelihood │ ├── vae.test_loss │ ├── vae.test_re │ ├── vae.train_kl │ ├── vae.train_loss │ ├── vae.train_re │ ├── vae.val_kl │ ├── vae.val_loss │ ├── vae.val_re │ ├── vae_config.txt │ ├── vae_experiment_log.txt │ └── whole_log.txt ├── requirements.txt └── utils ├── __init__.py ├── classify_data.py ├── distributions.py ├── evaluation.py ├── knn_on_latent.py ├── load_data ├── __init__.py ├── base_load_data.py └── data_loader_instances.py ├── nn.py ├── optimizer.py ├── plot_images.py ├── training.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | snapshots/* 3 | models/__pycache__ 4 | utils/load_data/__pycache__ 5 | utils/__pycache__ 6 | checkpoints/ 7 | __pycache__ 8 | .idea/* 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jakub Tomczak 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions 
of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Exemplar-VAE 2 | Code for reproducing the results of the [Exemplar VAE](https://arxiv.org/abs/2004.04795) paper, accepted at NeurIPS 2020. 3 | 4 | ## Requirements 5 | ``` 6 | pip3 install -r requirements.txt 7 | ``` 8 | ## Exemplar VAE Samples 9 | 10 | 11 | 12 | ## Exemplar-Based Generation 13 | ``` 14 | python3 analysis.py --dir pretrained_model --generate 15 | ``` 16 | 17 | 18 | 19 | 20 | ## Density Estimation 21 | ``` 22 | python3 density_estimation.py --prior exemplar_prior --dataset {dynamic_mnist|fashion_mnist|omniglot} --model_name {vae|hvae_2level|convhvae_2level} --number_components {25000|11500} --approximate_prior {True|False} 23 | ``` 24 | 25 | 26 | 27 | ## Data Augmentation 28 | ``` 29 | python3 analysis.py --dir pretrained_model --classify 30 | ``` 31 | 32 | 33 | 34 | 35 | ## Cyclic Generation 36 | ``` 37 | python3 analysis.py --dir pretrained_model --cyclic_generation 38 | ``` 39 | 40 | 41 |
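## Programmatic Generation
The commands above cover the common cases from the shell. As a minimal sketch of doing the same from Python, the snippet below performs exemplar-based generation with this repo's own utilities (`importing_model`, `load_model`, and `reference_based_generation_x`); the run directory mirrors the shipped `pretrained_model/` layout, and the random reference tensor is only a stand-in for a real training exemplar.
```
import torch
from utils.utils import importing_model, load_model

# Run-directory layout as shipped under pretrained_model/
run_dir = 'pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/'
config = torch.load(run_dir + 'vae.config')
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

VAE = importing_model(config)            # resolves the model class from config.model_name
model = VAE(config).to(config.device)
load_model(run_dir + 'checkpoint_best.pth', model)
model.eval()

with torch.no_grad():
    reference = torch.rand(1, 784)       # stand-in for a real flattened 28x28 exemplar
    samples = model.reference_based_generation_x(N=11, reference_image=reference)
print(samples.shape)                     # (11, 784) for the fully-connected MNIST VAE
```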
-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/__init__.py -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from utils.load_data.data_loader_instances import load_dataset 4 | import torchvision 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | from utils.plot_images import imshow 9 | from utils.utils import load_model 10 | from utils.classify_data import classify_data 11 | from utils.knn_on_latent import report_knn_on_latent, extract_full_data 12 | from utils.evaluation import compute_mean_variance_per_dimension 13 | from utils.plot_images import plot_images_in_line, generate_fancy_grid 14 | from utils.utils import importing_model 15 | from sklearn.manifold import TSNE 16 | import copy 17 | from pylab import rcParams 18 | 19 | 20 | parser = argparse.ArgumentParser(description='VAE+VampPrior') 21 | parser.add_argument('--KNN', action='store_true', default=False, help='run KNN classification on latent') 22 | parser.add_argument('--generate', action='store_true', default=False, help='generate images') 23 | parser.add_argument('--classify', action='store_true', default=False, 24 | help='train a classifier on data with augmentation') 25 | parser.add_argument('--dir', type=str, default='pretrained_model', help='directory of the pretrained model') 26 | parser.add_argument('--just_log_likelihood', action='store_true', default=False) 27 | parser.add_argument('--cyclic_generation', action='store_true', default=False, help='cyclic generation') 28 | parser.add_argument('--training_set_size', default=50000, type=int) 29 | parser.add_argument('--hyper_lambda', type=float, default=0.4, help='proportion of real data to augmented data') 30 | parser.add_argument('--lr', type=float, default=0.1) 31 | parser.add_argument('--batch_size', type=int, default=100) 32 | parser.add_argument('--input_size', type=list, default=[1, 28, 28]) 33 | parser.add_argument('--count_active_dimensions', action='store_true', default=False) 34 | parser.add_argument('--grid_interpolation', action='store_true', default=False) 35 | parser.add_argument('--tsne_visualization', action='store_true', default=False) 36 | parser.add_argument('--hidden_units', type=int, default=1024) 37 | parser.add_argument('--save_model_path', type=str, default='') 38 | parser.add_argument('--classification_dir', type=str, default='classification_report') 39 | parser.add_argument('--epochs', type=int, default=100) 40 | parser.add_argument('--seed', type=int, default=1) 41 | args = parser.parse_args() 42 | 43 | print(args) 44 | 45 | TRAIN_NUM = 50000 46 | 47 | 48 | def plot_data(data, labels): 49 | k = 10 50 | print(data.shape) 51 | subplot_num = data.shape[1] 52 | for i in range(subplot_num): 53 | plt.subplot2grid((subplot_num, 1), (i, 0), colspan=1, rowspan=1) 54 | imshow(torchvision.utils.make_grid(data[:k, i, :].view(-1, 1, 28, 28))) 55 | plt.axis('off') 56 | print(labels[:k, i, :].squeeze()) 57 | plt.show() 58 | 59 | directory = args.dir 60 | 61 | 62 | def grid_interpolation_in_latent(model, dir, index, reference_image): 63 | z, _ = model.q_z(reference_image.to(args.device), prior=True) 64 | whole_generation = [] 65 | for offset_0 in range(-2, 3, 1): 66 | row_generation = [] 67 | for offset_1 in range(-2, 3, 1): 68 | new_z = copy.deepcopy(z) 69 | new_z[0][0] += offset_0*3 70 | new_z[0][1] += offset_1*3 71 | image = model.generate_x_from_z(new_z, with_reparameterize=False) 72 | row_generation.append(image) 73 | whole_generation.append(torch.cat(row_generation, dim=0)) 74 | # print("LENNN", len(whole_generation)) 75 | whole_generation = torch.cat(whole_generation, dim=0) 76 | print('whole_generation shape', whole_generation.shape) 77 | imshow(torchvision.utils.make_grid(whole_generation.reshape(-1, *model.args.input_size), nrow=5)) 78 | save_dir = os.path.join(dir, 'grid_interpolation') 79 | os.makedirs(save_dir, exist_ok=True) 80 | plt.axis('off') 81 | plt.savefig(os.path.join(save_dir, 'interpolation{}'.format(index)), bbox_inches='tight') 82 | 83 | 84 | def compute_test_metrics(test_log_likelihood, test_kl, test_re): 85 | test_log_likelihood.append(torch.load(dir + model_name + '.test_log_likelihood')) 86 | 87 | kl = torch.load(dir + model_name + '.test_kl') 88 | if type(kl) == torch.Tensor: 89 | kl = kl.cpu().numpy() 90 | test_kl.append(kl) 91 | 92 | reconst = torch.load(dir + model_name + '.test_re') 93 | if type(reconst) == torch.Tensor: 94 | reconst = reconst.cpu().numpy() 95 | test_re.append(reconst) 96 | 97 | 98 | def cyclic_generation(start_data, dir, index): 99 | cyclic_generation_dir = os.path.join(dir, 'cyclic_generation') 100 | os.makedirs(cyclic_generation_dir, exist_ok=True) 101 | single_data = start_data.unsqueeze(0) 102 | generated_cycle = [single_data.to(args.device)] 103 | for i in range(29): 104 | single_data = \ 105 | model.reference_based_generation_x(N=1, reference_image=single_data) 106 | generated_cycle.append(single_data) 107 | 108 | generated_cycle = torch.cat(generated_cycle, dim=0) 109 | plot_images_in_line(generated_cycle, args, cyclic_generation_dir, 'cycle_{}.png'.format(index)) 110 | 111 | 112 | temp = '' 113 | active_units_text = ''
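# The loop below is the analysis driver: it walks every run folder under --dir,
# rebuilds the model class from the saved .config, reloads checkpoint_best.pth,
# and then runs whichever analyses were requested on the command line
# (KNN on the latent space, generation, classification, t-SNE, ...).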
114 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 115 | 116 | for folder in sorted(os.listdir(directory)): 117 | if os.path.isdir(directory+'/'+folder) is False: 118 | continue 119 | knn_results = [] 120 | test_log_likelihoods, test_kl, test_reconst, active_dimensions = [], [], [], [] 121 | knn_dictionary = {'3': [], '5': [], '7': [], '9': [], '11': [], '13': [], '15': []} 122 | 123 | 124 | torch.manual_seed(args.seed) 125 | if args.device=='cuda': 126 | torch.cuda.manual_seed(args.seed) 127 | np.random.seed(args.seed) 128 | 129 | for filename in os.listdir(directory+'/'+folder): 130 | print('filename**', filename) 131 | dir = directory + '/' + folder+'/'+filename + '/' 132 | model_name_start_index = folder.find('model_name=') 133 | model_name = folder[model_name_start_index + len('model_name='):] 134 | print("MODEL NAME", model_name) 135 | 136 | config = torch.load(dir + model_name + '.config') 137 | config.device = args.device 138 | VAE = importing_model(config) 139 | model = VAE(config) 140 | model.to(args.device) 141 | train_loader, val_loader, test_loader, config = load_dataset(config, 142 | training_num=args.training_set_size, 143 | no_binarization=True) 144 | 145 | if args.just_log_likelihood is False: 146 | load_model(dir + 'checkpoint_best.pth', model) 147 | model.eval() 148 | try: 149 | print('prior variance', model.prior_log_variance.item()) 150 | except: 151 | pass 152 | 153 | if args.cyclic_generation: 154 | with torch.no_grad(): 155 | for i in range(10): 156 | random_image = torch.rand([784]) 157 | cyclic_generation(random_image, dir, index=i) 158 | 159 | if args.KNN: 160 | with torch.no_grad(): 161 | report_knn_on_latent(train_loader, val_loader, test_loader, model, 162 | dir, knn_dictionary, args, val=False) 163 | if args.generate: 164 | with torch.no_grad(): 165 | exemplars_n = 50 166 | selected_indices = torch.randint(low=0, high=config.training_set_size, size=(exemplars_n,)) 167 | reference_images, indices, labels =train_loader.dataset[selected_indices] 168 | per_exemplar = 11 169 | generated = model.reference_based_generation_x(N=per_exemplar, reference_image=reference_images) 170 | generated = generated.reshape(-1, per_exemplar, *config.input_size) 171 | rcParams['figure.figsize'] = 4, 3 172 | generated_dir = dir + 'generated/' 173 | if config.use_logit: 174 | reference_images = model.logit_inverse(reference_images) 175 | generate_fancy_grid(config, dir, reference_images, generated) 176 | 177 | if args.count_active_dimensions: 178 | train_loader, val_loader, test_loader, config = load_dataset(config, 179 | training_num=args.training_set_size, 180 | no_binarization=False) 181 | with torch.no_grad(): 182 | num_active = compute_mean_variance_per_dimension(args, model, test_loader) 183 | active_dimensions.append(num_active) 184 | 185 | #TODO remove loop 186 | if args.grid_interpolation: 187 | with torch.no_grad(): 188 | for i in range(100): 189 | image = train_loader.dataset.tensors[0][torch.randint(low=0, high=args.training_set_size, 190 | size=(1,))] 191 | grid_interpolation_in_latent(model, dir, i, reference_image=image) 192 | 193 | if args.tsne_visualization: 194 | test_x, _, test_labels = extract_full_data(test_loader) 195 | test_z, _ = model.q_z(test_x.to(args.device)) 196 | tsne = TSNE(n_components=2) 197 | plt_colors = np.array( 198 | ['blue', 'orange', 'green', 'red', 'cyan', 'pink', 'purple', 'brown', 'gray', 'olive']) 199 | 200 | points_to_visualize = tsne.fit_transform(X=test_z.detach().cpu().numpy()) 201 | plt.scatter(points_to_visualize[:, 
0], points_to_visualize[:, 1], 202 | c=plt_colors[test_labels.cpu().numpy()], s=2) 203 | plt.savefig(dir+'tsne.png') 204 | plt.show() 205 | 206 | if args.classify: 207 | test_acc = [] 208 | val_acc = [] 209 | test_acc_single_run, val_acc_single_run = classify_data(train_loader, val_loader, test_loader, 210 | args.classification_dir, args, model) 211 | test_acc.append(test_acc_single_run) 212 | val_acc.append(val_acc_single_run) 213 | test_acc = np.array(test_acc) 214 | val_acc = np.array(val_acc) 215 | 216 | print('averaged test accuracy: {0:.2f} \\pm {1:.2f}'.format(np.mean(test_acc), np.std(test_acc))) 217 | print('averaged val accuracy: {0:.2f} \\pm {1:.2f}'.format(np.mean(val_acc), np.std(val_acc))) 218 | exit() 219 | else: 220 | compute_test_metrics(test_log_likelihoods, test_kl, test_reconst) 221 | 222 | if args.just_log_likelihood: 223 | test_log_likelihoods = np.array(test_log_likelihoods) 224 | print("test log-likelihood", np.mean(test_log_likelihoods), np.std(test_log_likelihoods)) 225 | 226 | if args.count_active_dimensions: 227 | active_dimensions = np.array(active_dimensions).astype(float) 228 | print(np.mean(active_dimensions), np.std(active_dimensions)) 229 | -------------------------------------------------------------------------------- /datasets/dynamic_mnist/readme.me: -------------------------------------------------------------------------------- 1 | Will be downloaded by pytorch 2 | 3 | -------------------------------------------------------------------------------- /datasets/fashion_mnist/readme.me: -------------------------------------------------------------------------------- 1 | Will be downloaded by pytorch 2 | -------------------------------------------------------------------------------- /datasets/omniglot/readme.me: -------------------------------------------------------------------------------- 1 | https://github.com/yburda/iwae/tree/master/datasets/OMNIGLOT 2 | -------------------------------------------------------------------------------- /density_estimation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import datetime 4 | from utils.load_data.data_loader_instances import load_dataset 5 | from utils.utils import importing_model 6 | import torch 7 | import math 8 | import os 9 | from utils.utils import save_model, load_model 10 | from utils.optimizer import AdamNormGrad 11 | import time 12 | from utils.training import train_one_epoch 13 | from utils.evaluation import evaluate_loss, final_evaluation 14 | import random 15 | 16 | def str2bool(v): 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 20 | return True 21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError('Boolean value expected.') 25 | 26 | 27 | parser = argparse.ArgumentParser(description='VAE+VampPrior') 28 | parser.add_argument('--batch_size', type=int, default=100, metavar='BStrain', 29 | help='input batch size for training (default: 100)') 30 | parser.add_argument('--test_batch_size', type=int, default=100, metavar='BStest', 31 | help='input batch size for testing (default: 100)') 32 | parser.add_argument('--epochs', type=int, default=2000, metavar='E', 33 | help='number of epochs to train (default: 2000)') 34 | parser.add_argument('--lr', type=float, default=0.0005, metavar='LR', 35 | help='learning rate (default: 0.0005)') 36 | parser.add_argument('--early_stopping_epochs', 
type=int, default=50, metavar='ES', 37 | help='number of epochs for early stopping') 38 | parser.add_argument('--z1_size', type=int, default=40, metavar='M1', 39 | help='latent size') 40 | parser.add_argument('--z2_size', type=int, default=40, metavar='M2', 41 | help='latent size') 42 | parser.add_argument('--input_size', type=int, default=[1, 28, 28], metavar='D', 43 | help='input size') 44 | parser.add_argument('--number_components', type=int, default=50000, metavar='NC', 45 | help='number of pseudo-inputs') 46 | parser.add_argument('--pseudoinputs_mean', type=float, default=-0.05, metavar='PM', 47 | help='mean for init pseudo-inputs') 48 | parser.add_argument('--pseudoinputs_std', type=float, default=0.01, metavar='PS', 49 | help='std for init pseudo-inputs') 50 | parser.add_argument('--use_training_data_init', action='store_true', default=False, 51 | help='initialize pseudo-inputs with randomly chosen training data') 52 | parser.add_argument('--model_name', type=str, default='vae', metavar='MN', 53 | help='model name: vae, hvae_2level, convhvae_2level') 54 | parser.add_argument('--prior', type=str, default='vampprior', metavar='P', 55 | help='prior: standard, vampprior, exemplar_prior') 56 | parser.add_argument('--input_type', type=str, default='binary', metavar='IT', 57 | help='type of the input: binary, gray, continuous, pca') 58 | parser.add_argument('--S', type=int, default=5000, metavar='SLL', 59 | help='number of samples used for approximating log-likelihood,' 60 | 'i.e. number of samples in IWAE') 61 | parser.add_argument('--MB', type=int, default=100, metavar='MBLL', 62 | help='size of a mini-batch used for approximating log-likelihood') 63 | parser.add_argument('--use_whole_train', type=str2bool, default=False, 64 | help='use the whole training set at test time') 65 | parser.add_argument('--dataset_name', type=str, default='freyfaces', metavar='DN', 66 | help='name of the dataset: static_mnist, dynamic_mnist, omniglot, caltech101silhouettes,' 67 | ' histopathologyGray, freyfaces, cifar10') 68 | parser.add_argument('--dynamic_binarization', action='store_true', default=False, 69 | help='allow dynamic binarization') 70 | parser.add_argument('--seed', type=int, default=14, metavar='S', 71 | help='random seed (default: 14)') 72 | 73 | parser.add_argument('--no_mask', action='store_true', default=False, help='no leave one out') 74 | 75 | parser.add_argument('--parent_dir', type=str, default='') 76 | parser.add_argument('--same_variational_var', type=str2bool, default=False, 77 | help='use the same variance for all dimensions') 78 | parser.add_argument('--model_signature', type=str, default='', help='load from this directory and continue training') 79 | parser.add_argument('--warmup', type=int, default=100, metavar='WU', 80 | help='number of epochs for warm-up') 81 | parser.add_argument('--slurm_task_id', type=str, default='') 82 | parser.add_argument('--slurm_job_id', type=str, default='') 83 | parser.add_argument('--approximate_prior', type=str2bool, default=False) 84 | parser.add_argument('--just_evaluate', type=str2bool, default=False) 85 | parser.add_argument('--no_attention', type=str2bool, default=False) 86 | parser.add_argument('--approximate_k', type=int, default=10) 87 | parser.add_argument('--hidden_size', type=int, default=300) 88 | parser.add_argument('--base_dir', type=str, default='snapshots/') 89 | parser.add_argument('--continuous', type=str2bool, default=False) 90 | parser.add_argument('--use_logit', type=str2bool, default=False) 91 |
parser.add_argument('--lambd', type=float, default=1e-4) 92 | parser.add_argument('--bottleneck', type=int, default=6) 93 | parser.add_argument('--training_set_size', type=int, default=50000) 94 | 95 | 96 | def initial_or_load(checkpoint_path_load, model, optimizer, dir): 97 | if os.path.exists(checkpoint_path_load): 98 | model_loaded_str = "******model is loaded*********" 99 | print(model_loaded_str) 100 | with open(dir + 'whole_log.txt', 'a') as f: 101 | print(model_loaded_str, file=f) 102 | checkpoint = load_model(checkpoint_path_load, model, optimizer) 103 | begin_epoch = checkpoint['epoch'] 104 | best_loss = checkpoint['best_loss'] 105 | e = checkpoint['e'] 106 | else: 107 | torch.manual_seed(args.seed) 108 | if args.device=='cuda': 109 | torch.cuda.manual_seed(args.seed) 110 | random.seed(args.seed) 111 | begin_epoch = 1 112 | best_loss = math.inf 113 | e = 0 114 | return begin_epoch, best_loss, e 115 | 116 | 117 | def save_loss_files(folder, train_loss_history, 118 | train_re_history, train_kl_history, val_loss_history, val_re_history, val_kl_history): 119 | torch.save(train_loss_history, folder + '.train_loss') 120 | torch.save(train_re_history, folder + '.train_re') 121 | torch.save(train_kl_history, folder + '.train_kl') 122 | torch.save(val_loss_history, folder + '.val_loss') 123 | torch.save(val_re_history, folder + '.val_re') 124 | torch.save(val_kl_history, folder + '.val_kl') 125 | 126 | 127 | def run_density_estimation(args, train_loader_input, val_loader_input, test_loader_input, model, optimizer, dir, model_name='vae'): 128 | torch.save(args, dir + args.model_name + '.config') 129 | train_loss_history, train_re_history, train_kl_history, val_loss_history, val_re_history, val_kl_history, \ 130 | time_history = [], [], [], [], [], [], [] 131 | checkpoint_path_save = os.path.join(dir, 'checkpoint_temp.pth') 132 | checkpoint_path_load = os.path.join(dir, 'checkpoint.pth') 133 | best_model_path_load = os.path.join(dir, 'checkpoint_best.pth') 134 | decayed = False 135 | time_history = [] 136 | # with torch.autograd.detect_anomaly(): 137 | begin_epoch, best_loss, e = initial_or_load(checkpoint_path_load, model, optimizer, dir) 138 | if args.just_evaluate is False: 139 | for epoch in range(begin_epoch, args.epochs + 1): 140 | time_start = time.time() 141 | train_loss_epoch, train_re_epoch, train_kl_epoch \ 142 | = train_one_epoch(epoch, args, train_loader_input, model, optimizer) 143 | with torch.no_grad(): 144 | val_loss_epoch, val_re_epoch, val_kl_epoch = evaluate_loss(args, model, val_loader_input, 145 | dataset=train_loader_input.dataset) 146 | time_end = time.time() 147 | time_elapsed = time_end - time_start 148 | content = {'epoch': epoch, 'state_dict': model.state_dict(), 149 | 'optimizer': optimizer.state_dict(), 'best_loss': best_loss, 'e': e} 150 | if epoch % 10 == 0: 151 | save_model(checkpoint_path_save, checkpoint_path_load, content) 152 | if val_loss_epoch < best_loss: 153 | e = 0 154 | best_loss = val_loss_epoch 155 | print('->model saved<-') 156 | save_model(checkpoint_path_save, best_model_path_load, content) 157 | else: 158 | e += 1 159 | if epoch < args.warmup: 160 | e = 0 161 | if e > args.early_stopping_epochs: 162 | break 163 | 164 | if math.isnan(val_loss_epoch): 165 | print("***** val loss is Nan *******") 166 | break 167 | 168 | for param_group in optimizer.param_groups: 169 | learning_rate = param_group['lr'] 170 | break 171 | 172 | time_history.append(time_elapsed) 173 | 174 | epoch_report = 'Epoch: {}/{}, Time elapsed: {:.2f}s\n' \ 175 | 'learning 
rate: {:.5f}\n' \ 176 | '* Train loss: {:.2f} (RE: {:.2f}, KL: {:.2f})\n' \ 177 | 'o Val. loss: {:.2f} (RE: {:.2f}, KL: {:.2f})\n' \ 178 | '--> Early stopping: {}/{} (BEST: {:.2f})\n'.format(epoch, args.epochs, time_elapsed, 179 | learning_rate, 180 | train_loss_epoch, train_re_epoch, 181 | train_kl_epoch, val_loss_epoch, 182 | val_re_epoch, val_kl_epoch, e, 183 | args.early_stopping_epochs, best_loss) 184 | 185 | if args.prior == 'exemplar_prior': 186 | print("Prior Variance", model.prior_log_variance.item()) 187 | if args.continuous is True: 188 | print("Decoder Variance", model.decoder_logstd.item()) 189 | print(epoch_report) 190 | with open(dir + 'whole_log.txt', 'a') as f: 191 | print(epoch_report, file=f) 192 | 193 | train_loss_history.append(train_loss_epoch), train_re_history.append( 194 | train_re_epoch), train_kl_history.append(train_kl_epoch) 195 | val_loss_history.append(val_loss_epoch), val_re_history.append(val_re_epoch), val_kl_history.append( 196 | val_kl_epoch) 197 | 198 | save_loss_files(dir + args.model_name, train_loss_history, 199 | train_re_history, train_kl_history, val_loss_history, val_re_history, val_kl_history) 200 | 201 | with torch.no_grad(): 202 | final_evaluation(train_loader_input, test_loader_input, val_loader_input, 203 | best_model_path_load, model, optimizer, args, dir) 204 | 205 | 206 | def run(args, kwargs): 207 | print('create model') 208 | # importing model 209 | VAE = importing_model(args) 210 | print('load data') 211 | train_loader, val_loader, test_loader, args = load_dataset(args, use_fixed_validation=True, **kwargs) 212 | if args.slurm_job_id != '': 213 | args.model_signature = str(args.seed) 214 | # base_dir = 'checkpoints/final_report/' 215 | elif args.model_signature == '': 216 | args.model_signature = str(datetime.datetime.now())[0:19] 217 | 218 | if args.parent_dir == '': 219 | args.parent_dir = args.prior + '_on_' + args.dataset_name+'_model_name='+args.model_name 220 | model_name = args.dataset_name + '_' + args.model_name + '_' + args.prior \ 221 | + '_(components_' + str(args.number_components) + ', lr=' + str(args.lr) + ')' 222 | snapshots_path = os.path.join(args.base_dir, args.parent_dir) + '/' 223 | dir = snapshots_path + args.model_signature + '_' + model_name + '_' + args.parent_dir + '/' 224 | 225 | if args.just_evaluate: 226 | config = torch.load(dir + args.model_name + '.config') 227 | config.translation = False 228 | config.hidden_size = 300 229 | model = VAE(config) 230 | else: 231 | model = VAE(args) 232 | if not os.path.exists(dir): 233 | os.makedirs(dir) 234 | model.to(args.device) 235 | optimizer = AdamNormGrad(model.parameters(), lr=args.lr) 236 | print(args) 237 | config_file = dir+'vae_config.txt' 238 | with open(config_file, 'a') as f: 239 | print(args, file=f) 240 | run_density_estimation(args, train_loader, val_loader, test_loader, model, optimizer, dir, model_name = args.model_name) 241 | 242 | 243 | if __name__ == "__main__": 244 | args = parser.parse_args() 245 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 246 | 247 | kwargs = {'num_workers': 2, 'pin_memory': True} if args.device=='cuda' else {} 248 | run(args, kwargs) 249 | 250 | -------------------------------------------------------------------------------- /images/augmentation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/augmentation.gif 
-------------------------------------------------------------------------------- /images/celebA_exemplar_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/celebA_exemplar_generation.png -------------------------------------------------------------------------------- /images/cyclic_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/cyclic_generation.png -------------------------------------------------------------------------------- /images/data_augmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/data_augmentation.png -------------------------------------------------------------------------------- /images/density_estimation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/density_estimation.png -------------------------------------------------------------------------------- /images/exemplar_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/exemplar_generation.png -------------------------------------------------------------------------------- /images/full_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/full_generation.png -------------------------------------------------------------------------------- /models/AbsHModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | from utils.distributions import log_normal_diag 6 | from .BaseModel import BaseModel 7 | 8 | 9 | class BaseHModel(BaseModel): 10 | def __init__(self, args): 11 | super(BaseHModel, self).__init__(args) 12 | 13 | def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices): 14 | z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = latent_stats 15 | if exemplars_embedding is None and self.args.prior == 'exemplar_prior': 16 | exemplars_embedding = self.get_exemplar_set(z2_q_mean, z2_q_logvar, 17 | dataset, cache, x_indices) 18 | log_p_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size), 19 | z1_p_mean.view(-1, self.args.z1_size), 20 | z1_p_logvar.view(-1, self.args.z1_size), dim=1) 21 | log_q_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size), 22 | z1_q_mean.view(-1, self.args.z1_size), 23 | z1_q_logvar.view(-1, self.args.z1_size), dim=1) 24 | log_p_z2 = self.log_p_z(z=(z2_q, x_indices), 25 | exemplars_embedding=exemplars_embedding) 26 | log_q_z2 = log_normal_diag(z2_q.view(-1, self.args.z2_size), 27 | z2_q_mean.view(-1, self.args.z2_size), 28 | z2_q_logvar.view(-1, self.args.z2_size), dim=1) 29 | return -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2) 30 | 31 | def generate_x_from_z(self, z, with_reparameterize=True): 32 | z1_sample_mean, z1_sample_logvar = 
self.p_z1(z) 33 | if with_reparameterize: 34 | z1_sample_rand = self.reparameterize(z1_sample_mean, z1_sample_logvar) 35 | else: 36 | z1_sample_rand = z1_sample_mean 37 | 38 | if self.args.model_name=='pixelcnn': 39 | generated_xs = self.pixelcnn_generate(z1_sample_rand.view(-1, self.args.z1_size), z.reshape(-1, self.args.z2_size)) 40 | else: 41 | generated_xs, _ = self.p_x(z1_sample_rand.view(-1, self.args.z1_size), 42 | z.view(-1, self.args.z2_size)) 43 | return generated_xs 44 | 45 | def p_z1(self, z2): 46 | z2 = self.p_z1_layers_z2(z2) 47 | z1_p_mean = self.p_z1_mean(z2) 48 | z1_p_logvar = self.p_z1_logvar(z2) 49 | return z1_p_mean, z1_p_logvar 50 | 51 | def q_z1(self, x, z2): 52 | x = self.q_z1_layers_x(x) 53 | if self.args.model_name == 'convhvae_2level' or self.args.model_name == 'pixelcnn': 54 | x = x.view(x.size(0),-1) 55 | z2 = self.q_z1_layers_z2(z2) 56 | h = torch.cat((x,z2),1) 57 | h = self.q_z1_layers_joint(h) 58 | z1_q_mean = self.q_z1_mean(h) 59 | z1_q_logvar = self.q_z1_logvar(h) 60 | return z1_q_mean, z1_q_logvar 61 | 62 | def p_x(self, z1, z2, x=None): 63 | z1 = self.p_x_layers_z1(z1) 64 | 65 | z2 = self.p_x_layers_z2(z2) 66 | 67 | if self.args.model_name == 'pixelcnn': 68 | z2 = z2.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]) 69 | z1 = z1.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]) 70 | h = torch.cat((x, z1, z2), 1) 71 | # pixelcnn part of the decoder 72 | h_decoder = self.pixelcnn(h) 73 | 74 | else: 75 | 76 | h = torch.cat((z1, z2), 1) 77 | if 'convhvae_2level' in self.args.model_name: 78 | h = self.p_x_layers_joint_pre(h) 79 | h = h.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]) 80 | 81 | h_decoder = self.p_x_layers_joint(h) 82 | x_mean = self.p_x_mean(h_decoder) 83 | if 'convhvae_2level' in self.args.model_name or self.args.model_name=='pixelcnn': 84 | x_mean = x_mean.view(-1, np.prod(self.args.input_size)) 85 | 86 | if self.args.input_type == 'binary': 87 | x_logvar = 0. 88 | else: 89 | x_mean = torch.clamp(x_mean, min=0.+1./512., max=1.-1./512.) 
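            # NOTE: the clamp above keeps the predicted mean half a bin away from 0 and 1
            # (bins of width 1/256, hence the limits 1/512 and 1 - 1/512), so it stays
            # inside the support of the discretized logistic-256 reconstruction likelihood.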
90 | x_logvar = self.p_x_logvar(h_decoder) 91 | if 'convhvae_2level' in self.args.model_name or self.args.model_name=='pixelcnn': 92 | x_logvar = x_logvar.view(-1, np.prod(self.args.input_size)) 93 | 94 | return x_mean, x_logvar 95 | 96 | def forward(self, x): 97 | z2_q_mean, z2_q_logvar = self.q_z(x) 98 | z2_q = self.reparameterize(z2_q_mean, z2_q_logvar) 99 | z1_q_mean, z1_q_logvar = self.q_z1(x, z2_q) 100 | z1_q = self.reparameterize(z1_q_mean, z1_q_logvar) 101 | z1_p_mean, z1_p_logvar = self.p_z1(z2_q) 102 | if self.args.model_name == 'pixelcnn': 103 | x_mean, x_logvar = self.p_x(z1_q, z2_q, x=x) 104 | else: 105 | x_mean, x_logvar = self.p_x(z1_q, z2_q) 106 | return x_mean, x_logvar, (z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar) 107 | 108 | -------------------------------------------------------------------------------- /models/AbsModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | from models.BaseModel import BaseModel 6 | from utils.distributions import log_normal_diag 7 | 8 | 9 | class AbsModel(BaseModel): 10 | def __init__(self, args): 11 | super(AbsModel, self).__init__(args) 12 | 13 | def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices): 14 | z_q, z_q_mean, z_q_logvar = latent_stats 15 | if exemplars_embedding is None and self.args.prior == 'exemplar_prior': 16 | exemplars_embedding = self.get_exemplar_set(z_q_mean, z_q_logvar, dataset, cache, x_indices) 17 | log_p_z = self.log_p_z(z=(z_q, x_indices), exemplars_embedding=exemplars_embedding) 18 | log_q_z = log_normal_diag(z_q, z_q_mean, z_q_logvar, dim=1) 19 | return -(log_p_z - log_q_z) 20 | 21 | def generate_x_from_z(self, z, with_reparameterize=True): 22 | generated_x, _ = self.p_x(z) 23 | try: 24 | if self.args.use_logit is True: 25 | return self.logit_inverse(generated_x) 26 | else: 27 | return generated_x 28 | except: 29 | return generated_x 30 | 31 | def p_x(self, z): 32 | if 'conv' in self.args.model_name: 33 | z = z.reshape(-1, self.bottleneck, self.args.input_size[1]//4, self.args.input_size[1]//4) 34 | z = self.p_x_layers(z) 35 | x_mean = self.p_x_mean(z) 36 | if self.args.input_type == 'binary': 37 | x_logvar = torch.zeros(1, np.prod(self.args.input_size)) 38 | else: 39 | if self.args.use_logit is False: 40 | x_mean = torch.clamp(x_mean, min=0.+1./512., max=1.-1./512.) 
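            # Same clamping rationale as in the hierarchical model: keep the mean inside the
            # valid logistic-256 bins. The next line then broadcasts the single learned
            # scalar decoder_logstd over all pixels as the x_logvar term.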
41 | x_logvar = self.decoder_logstd*x_mean.new_ones(size=x_mean.shape) 42 | return x_mean.reshape(-1, np.prod(self.args.input_size)), x_logvar.reshape(-1, np.prod(self.args.input_size)) 43 | 44 | def forward(self, x, label=0, num_categories=10): 45 | z_q_mean, z_q_logvar = self.q_z(x) 46 | 47 | z_q = self.reparameterize(z_q_mean, z_q_logvar) 48 | x_mean, x_logvar = self.p_x(z_q) 49 | return x_mean, x_logvar, (z_q, z_q_mean, z_q_logvar) 50 | -------------------------------------------------------------------------------- /models/BaseModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | from utils.nn import normal_init, NonLinear 8 | from utils.distributions import log_normal_diag_vectorized 9 | import math 10 | from utils.nn import he_init 11 | from utils.distributions import pairwise_distance 12 | from utils.distributions import log_bernoulli, log_normal_diag, log_normal_standard, log_logistic_256 13 | from abc import ABC, abstractmethod 14 | 15 | 16 | class BaseModel(nn.Module, ABC): 17 | def __init__(self, args): 18 | super(BaseModel, self).__init__() 19 | print("constructor") 20 | self.args = args 21 | 22 | if self.args.prior == 'vampprior': 23 | self.add_pseudoinputs() 24 | 25 | if self.args.prior == 'exemplar_prior': 26 | self.prior_log_variance = torch.nn.Parameter(torch.randn((1))) 27 | 28 | if self.args.input_type == 'binary': 29 | self.p_x_mean = NonLinear(self.args.hidden_size, np.prod(self.args.input_size), activation=nn.Sigmoid()) 30 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous': 31 | self.p_x_mean = NonLinear(self.args.hidden_size, np.prod(self.args.input_size)) 32 | self.p_x_logvar = NonLinear(self.args.hidden_size, np.prod(self.args.input_size), 33 | activation=nn.Hardtanh(min_val=-4.5, max_val=0)) 34 | self.decoder_logstd = torch.nn.Parameter(torch.tensor([0.], requires_grad=True)) 35 | 36 | self.create_model(args) 37 | self.he_initializer() 38 | 39 | def he_initializer(self): 40 | print("he initializer") 41 | 42 | for m in self.modules(): 43 | if isinstance(m, nn.Linear): 44 | he_init(m) 45 | 46 | @abstractmethod 47 | def create_model(self, args): 48 | pass 49 | 50 | @abstractmethod 51 | def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices): 52 | pass 53 | 54 | def reconstruction_loss(self, x, x_mean, x_logvar): 55 | if self.args.input_type == 'binary': 56 | return log_bernoulli(x, x_mean, dim=1) 57 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous': 58 | if self.args.use_logit is True: 59 | return log_normal_diag(x, x_mean, x_logvar, dim=1) 60 | else: 61 | return log_logistic_256(x, x_mean, x_logvar, dim=1) 62 | else: 63 | raise Exception('Wrong input type!') 64 | 65 | def calculate_loss(self, x, beta=1., average=False, 66 | exemplars_embedding=None, cache=None, dataset=None): 67 | x, x_indices = x 68 | x_mean, x_logvar, latent_stats = self.forward(x) 69 | RE = self.reconstruction_loss(x, x_mean, x_logvar) 70 | KL = self.kl_loss(latent_stats, exemplars_embedding, dataset, cache, x_indices) 71 | loss = -RE + beta*KL 72 | if average: 73 | loss = torch.mean(loss) 74 | RE = torch.mean(RE) 75 | KL = torch.mean(KL) 76 | 77 | return loss, RE, KL 78 | 79 | def reparameterize(self, mu, logvar): 80 | std = logvar.mul(0.5).exp_() 81 | eps = mu.new_empty(size=std.shape).normal_()
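        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I), which
        # keeps the sample differentiable with respect to mu and logvar.
82 |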
return eps.mul(std).add_(mu) 83 | 84 | def log_p_z_vampprior(self, z, exemplars_embedding): 85 | if exemplars_embedding is None: 86 | C = self.args.number_components 87 | X = self.means(self.idle_input) 88 | z_p_mean, z_p_logvar = self.q_z(X, prior=True) # C x M 89 | else: 90 | C = torch.tensor(self.args.number_components).float() 91 | z_p_mean, z_p_logvar = exemplars_embedding 92 | 93 | z_expand = z.unsqueeze(1) 94 | means = z_p_mean.unsqueeze(0) 95 | logvars = z_p_logvar.unsqueeze(0) 96 | return log_normal_diag(z_expand, means, logvars, dim=2) - math.log(C) 97 | 98 | def log_p_z_exemplar(self, z, z_indices, exemplars_embedding, test): 99 | centers, center_log_variance, center_indices = exemplars_embedding 100 | denominator = torch.tensor(len(centers)).expand(len(z)).float().to(self.args.device) 101 | center_log_variance = center_log_variance[0, :].unsqueeze(0) 102 | prob, _ = log_normal_diag_vectorized(z, centers, center_log_variance) # MB x C 103 | if test is False and self.args.no_mask is False: 104 | mask = z_indices.expand(-1, len(center_indices)) \ 105 | == center_indices.squeeze().unsqueeze(0).expand(len(z_indices), -1) 106 | prob.masked_fill_(mask, value=float('-inf')) 107 | denominator = denominator - mask.sum(dim=1).float() 108 | prob -= torch.log(denominator).unsqueeze(1) 109 | return prob 110 | 111 | def log_p_z(self, z, exemplars_embedding, sum=True, test=None): 112 | z, z_indices = z 113 | if test is None: 114 | test = not self.training 115 | if self.args.prior == 'standard': 116 | return log_normal_standard(z, dim=1) 117 | elif self.args.prior == 'vampprior': 118 | prob = self.log_p_z_vampprior(z, exemplars_embedding) 119 | elif self.args.prior == 'exemplar_prior': 120 | prob = self.log_p_z_exemplar(z, z_indices, exemplars_embedding, test) 121 | else: 122 | raise Exception('Wrong name of the prior!') 123 | if sum: 124 | prob_max, _ = torch.max(prob, 1) # MB x 1 125 | log_prior = prob_max + torch.log(torch.sum(torch.exp(prob - prob_max.unsqueeze(1)), 1)) # MB x 1 126 | else: 127 | return prob 128 | return log_prior 129 | 130 | def add_pseudoinputs(self): 131 | nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0) 132 | self.means = NonLinear(self.args.number_components, np.prod(self.args.input_size), bias=False, activation=nonlinearity) 133 | # init pseudo-inputs 134 | if self.args.use_training_data_init: 135 | self.means.linear.weight.data = self.args.pseudoinputs_mean 136 | else: 137 | normal_init(self.means.linear, self.args.pseudoinputs_mean, self.args.pseudoinputs_std) 138 | self.idle_input = Variable(torch.eye(self.args.number_components, self.args.number_components), requires_grad=False) 139 | self.idle_input = self.idle_input.to(self.args.device) 140 | 141 | def generate_z_interpolate(self, exemplars_embedding=None, dim=0): 142 | new_zs = [] 143 | exemplars_embedding, _, _ = exemplars_embedding 144 | step_counts = 10 145 | step = (exemplars_embedding[1] - exemplars_embedding[0])/step_counts 146 | for i in range(step_counts): 147 | new_z = exemplars_embedding[0].clone() 148 | new_z += i*step 149 | new_zs.append(new_z.unsqueeze(0)) 150 | return torch.cat(new_zs, dim=0) 151 | 152 | def generate_z(self, N=25, dataset=None): 153 | if self.args.prior == 'standard': 154 | z_sample_rand = torch.FloatTensor(N, self.args.z1_size).normal_().to(self.args.device) 155 | elif self.args.prior == 'vampprior': 156 | means = self.means(self.idle_input)[0:N] 157 | z_sample_gen_mean, z_sample_gen_logvar = self.q_z(means) 158 | z_sample_rand = self.reparameterize(z_sample_gen_mean, 
z_sample_gen_logvar) 159 | z_sample_rand = z_sample_rand.to(self.args.device) 160 | elif self.args.prior == 'exemplar_prior': 161 | rand_indices = torch.randint(low=0, high=self.args.training_set_size, size=(N,)) 162 | exemplars = dataset.tensors[0][rand_indices] 163 | z_sample_gen_mean, z_sample_gen_logvar = self.q_z(exemplars.to(self.args.device), prior=True) 164 | z_sample_rand = self.reparameterize(z_sample_gen_mean, z_sample_gen_logvar) 165 | z_sample_rand = z_sample_rand.to(self.args.device) 166 | return z_sample_rand 167 | 168 | def reference_based_generation_z(self, N=25, reference_image=None): 169 | pseudo, log_var = self.q_z(reference_image.to(self.args.device), prior=True) 170 | pseudo = pseudo.unsqueeze(1).expand(-1, N, -1).reshape(-1, pseudo.shape[-1]) 171 | log_var = log_var[0].unsqueeze(0).expand(len(pseudo), -1) 172 | z_sample_rand = self.reparameterize(pseudo, log_var) 173 | z_sample_rand = z_sample_rand.reshape(-1, N, pseudo.shape[1]) 174 | return z_sample_rand 175 | 176 | def reconstruct_x(self, x): 177 | x_reconstructed, _, z = self.forward(x) 178 | if self.args.model_name == 'pixelcnn': 179 | x_reconstructed = self.pixelcnn_generate(z[0].reshape(-1, self.args.z1_size), z[3].reshape(-1, self.args.z2_size)) 180 | return x_reconstructed 181 | 182 | def logit_inverse(self, x): 183 | sigmoid = torch.nn.Sigmoid() 184 | lambd = self.args.lambd 185 | return ((sigmoid(x) - lambd)/(1-2*lambd)) 186 | 187 | def generate_x(self, N=25, dataset=None): 188 | z2_sample_rand = self.generate_z(N=N, dataset=dataset) 189 | return self.generate_x_from_z(z2_sample_rand) 190 | 191 | def reference_based_generation_x(self, N=25, reference_image=None): 192 | z2_sample_rand = \ 193 | self.reference_based_generation_z(N=N, reference_image=reference_image) 194 | generated_x = self.generate_x_from_z(z2_sample_rand) 195 | return generated_x 196 | 197 | def generate_x_interpolate(self, exemplars_embedding, dim=0): 198 | zs = self.generate_z_interpolate(exemplars_embedding, dim=dim) 199 | print(zs.shape) 200 | return self.generate_x_from_z(zs, with_reparameterize=False) 201 | 202 | def reshape_variance(self, variance, shape): 203 | return variance[0]*torch.ones(shape).to(self.args.device) 204 | 205 | def q_z(self, x, prior=False): 206 | if 'conv' in self.args.model_name or 'pixelcnn'==self.args.model_name: 207 | x = x.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]) 208 | h = self.q_z_layers(x) 209 | if self.args.model_name == 'convhvae_2level' or self.args.model_name=='pixelcnn': 210 | h = h.view(x.size(0), -1) 211 | z_q_mean = self.q_z_mean(h) 212 | if prior is True: 213 | if self.args.prior == 'exemplar_prior': 214 | z_q_logvar = self.prior_log_variance * torch.ones((x.shape[0], self.args.z1_size)).to(self.args.device) 215 | if self.args.model_name == 'newconvhvae_2level': 216 | z_q_logvar = z_q_logvar.reshape(-1, 4, 4, 4) 217 | else: 218 | z_q_logvar = self.q_z_logvar(h) 219 | else: 220 | z_q_logvar = self.q_z_logvar(h) 221 | return z_q_mean.reshape(-1, self.args.z1_size), z_q_logvar.reshape(-1, self.args.z1_size) 222 | 223 | def cache_z(self, dataset, prior=True, cuda=True): 224 | cached_z = [] 225 | cached_log_var = [] 226 | caching_batch_size = 10000 227 | num_batchs = math.ceil(len(dataset) / caching_batch_size) 228 | for i in range(num_batchs): 229 | if len(dataset[0]) == 3: 230 | batch_data, batch_indices, _ = dataset[i * caching_batch_size:(i + 1) * caching_batch_size] 231 | else: 232 | batch_data, _ = dataset[i * caching_batch_size:(i + 1) * 
caching_batch_size] 233 | 234 | exemplars_embedding, log_variance_z = self.q_z(batch_data.to(self.args.device), prior=prior) 235 | cached_z.append(exemplars_embedding) 236 | cached_log_var.append(log_variance_z) 237 | cached_z = torch.cat(cached_z, dim=0) 238 | cached_log_var = torch.cat(cached_log_var, dim=0) 239 | cached_z = cached_z.to(self.args.device) 240 | cached_log_var = cached_log_var.to(self.args.device) 241 | return cached_z, cached_log_var 242 | 243 | def get_exemplar_set(self, z_mean, z_log_var, dataset, cache, x_indices): 244 | if self.args.approximate_prior is False: 245 | exemplars_indices = torch.randint(low=0, high=self.args.training_set_size, 246 | size=(self.args.number_components, )) 247 | exemplars_z, log_variance = self.q_z(dataset.tensors[0][exemplars_indices].to(self.args.device), prior=True) 248 | exemplar_set = (exemplars_z, log_variance, exemplars_indices.to(self.args.device)) 249 | else: 250 | exemplar_set = self.get_approximate_nearest_exemplars( 251 | z=(z_mean, z_log_var, x_indices), 252 | dataset=dataset, 253 | cache=cache) 254 | return exemplar_set 255 | 256 | def get_approximate_nearest_exemplars(self, z, cache, dataset): 257 | exemplars_indices = torch.randint(low=0, high=self.args.training_set_size, 258 | size=(self.args.number_components, )).to(self.args.device) 259 | z, _, indices = z 260 | cached_z, cached_log_variance = cache 261 | cached_z[indices.reshape(-1)] = z 262 | sub_cache = cached_z[exemplars_indices, :] 263 | _, nearest_indices = pairwise_distance(z, sub_cache) \ 264 | .topk(k=self.args.approximate_k, largest=False, dim=1) 265 | nearest_indices = torch.unique(nearest_indices.view(-1)) 266 | exemplars_indices = exemplars_indices[nearest_indices].view(-1) 267 | exemplars = dataset.tensors[0][exemplars_indices].to(self.args.device) 268 | exemplars_z, log_variance = self.q_z(exemplars, prior=True) 269 | cached_z[exemplars_indices] = exemplars_z 270 | exemplar_set = (exemplars_z, log_variance, exemplars_indices) 271 | return exemplar_set 272 | 273 | 274 | -------------------------------------------------------------------------------- /models/HVAE_2level.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | import torch.nn as nn 6 | from torch.nn import Linear 7 | from utils.nn import GatedDense, NonLinear 8 | from models.AbsHModel import BaseHModel 9 | 10 | 11 | class VAE(BaseHModel): 12 | def __init__(self, args): 13 | super(VAE, self).__init__(args) 14 | 15 | def create_model(self, args): 16 | print("create_model") 17 | 18 | # because the superclass uses h_size 19 | self.args = args 20 | 21 | # encoder: q(z2 | x) 22 | self.q_z_layers = nn.Sequential( 23 | GatedDense(np.prod(self.args.input_size), self.args.hidden_size), 24 | GatedDense(self.args.hidden_size, self.args.hidden_size) 25 | ) 26 | 27 | self.q_z_mean = Linear(self.args.hidden_size, self.args.z2_size) 28 | 29 | if args.same_variational_var: 30 | self.q_z_logvar = torch.nn.Parameter(torch.randn((1))) 31 | else: 32 | self.q_z_logvar = NonLinear(self.args.hidden_size, self.args.z2_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 33 | 34 | # encoder: q(z1 | x, z2) 35 | self.q_z1_layers_x = nn.Sequential( 36 | GatedDense(np.prod(self.args.input_size), self.args.hidden_size) 37 | ) 38 | self.q_z1_layers_z2 = nn.Sequential( 39 | GatedDense(self.args.z2_size, self.args.hidden_size) 40 | ) 41 | self.q_z1_layers_joint = nn.Sequential( 42 |
GatedDense(2 * self.args.hidden_size, self.args.hidden_size) 43 | ) 44 | 45 | self.q_z1_mean = Linear(self.args.hidden_size, self.args.z1_size) 46 | self.q_z1_logvar = NonLinear(self.args.hidden_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 47 | 48 | # decoder: p(z1 | z2) 49 | self.p_z1_layers_z2 = nn.Sequential( 50 | GatedDense(self.args.z2_size, self.args.hidden_size), 51 | GatedDense(self.args.hidden_size, self.args.hidden_size) 52 | ) 53 | 54 | self.p_z1_mean = Linear(self.args.hidden_size, self.args.z1_size) 55 | self.p_z1_logvar = NonLinear(self.args.hidden_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6.,max_val=2.)) 56 | 57 | # decoder: p(x | z1, z2) 58 | self.p_x_layers_z1 = nn.Sequential( 59 | GatedDense(self.args.z1_size, self.args.hidden_size) 60 | ) 61 | self.p_x_layers_z2 = nn.Sequential( 62 | GatedDense(self.args.z2_size, self.args.hidden_size) 63 | ) 64 | self.p_x_layers_joint = nn.Sequential( 65 | GatedDense(2 * self.args.hidden_size, self.args.hidden_size) 66 | ) 67 | 68 | 69 | -------------------------------------------------------------------------------- /models/PixelCNN.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch.nn as nn 4 | from utils.nn import GatedDense, NonLinear, \ 5 | Conv2d, GatedConv2d, MaskedConv2d, PixelSNAIL 6 | from models.AbsHModel import BaseHModel 7 | import torch 8 | 9 | class VAE(BaseHModel): 10 | def __init__(self, args): 11 | super(VAE, self).__init__(args) 12 | 13 | def create_model(self, args): 14 | if args.dataset_name == 'freyfaces': 15 | self.h_size = 210 16 | elif args.dataset_name == 'cifar10' or args.dataset_name == 'svhn': 17 | self.h_size = 384 18 | else: 19 | self.h_size = 294 20 | 21 | # encoder: q(z2 | x) 22 | self.q_z_layers = nn.Sequential( 23 | GatedConv2d(self.args.input_size[0], 32, 7, 1, 3), 24 | GatedConv2d(32, 32, 3, 2, 1), 25 | GatedConv2d(32, 64, 5, 1, 2), 26 | GatedConv2d(64, 64, 3, 2, 1), 27 | GatedConv2d(64, 6, 3, 1, 1) 28 | ) 29 | # linear layers 30 | self.q_z_mean = NonLinear(self.h_size, self.args.z2_size, activation=None) 31 | self.q_z_logvar = NonLinear(self.h_size, self.args.z2_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 32 | 33 | # encoder: q(z1|x,z2) 34 | # PROCESSING x 35 | self.q_z1_layers_x = nn.Sequential( 36 | GatedConv2d(self.args.input_size[0], 32, 3, 1, 1), 37 | GatedConv2d(32, 32, 3, 2, 1), 38 | GatedConv2d(32, 64, 3, 1, 1), 39 | GatedConv2d(64, 64, 3, 2, 1), 40 | GatedConv2d(64, 6, 3, 1, 1) 41 | ) 42 | # PROCESSING Z2 43 | self.q_z1_layers_z2 = nn.Sequential( 44 | GatedDense(self.args.z2_size, self.h_size) 45 | ) 46 | # PROCESSING JOINT 47 | self.q_z1_layers_joint = nn.Sequential( 48 | GatedDense( 2 * self.h_size, 300) 49 | ) 50 | # linear layers 51 | self.q_z1_mean = NonLinear(300, self.args.z1_size, activation=None) 52 | self.q_z1_logvar = NonLinear(300, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 53 | 54 | # decoder p(z1|z2) 55 | self.p_z1_layers_z2 = nn.Sequential( 56 | GatedDense(self.args.z2_size, 300), 57 | GatedDense(300, 300) 58 | ) 59 | self.p_z1_mean = NonLinear(300, self.args.z1_size, activation=None) 60 | self.p_z1_logvar = NonLinear(300, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 61 | 62 | # decoder: p(x | z) 63 | self.p_x_layers_z1 = nn.Sequential( 64 | GatedDense(self.args.z1_size, np.prod(self.args.input_size)) 65 | ) 66 | self.p_x_layers_z2 = nn.Sequential( 67 | 
GatedDense(self.args.z2_size, np.prod(self.args.input_size)) 68 | ) 69 | 70 | # decoder: p(x | z) 71 | act = nn.ReLU(True) 72 | #self.pixelcnn = nn.Sequential( 73 | # MaskedConv2d('A', self.args.input_size[0] + 2 * self.args.input_size[0], 64, 3, 1, 1, bias=False), 74 | # nn.BatchNorm2d(64), act, 75 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act, 76 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act, 77 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act, 78 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act, 79 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act, 80 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act, 81 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act 82 | #) 83 | self.pixelcnn = PixelSNAIL([28, 28], 64, 64, 3, 1, 4, 64) 84 | 85 | if self.args.input_type == 'binary': 86 | self.p_x_mean = Conv2d(64, 1, 1, 1, 0, activation=nn.Sigmoid()) 87 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous': 88 | self.p_x_mean = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Sigmoid(), bias=False) 89 | self.p_x_logvar = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Hardtanh(min_val=-4.5, max_val=0.), bias=False) 90 | 91 | def pixelcnn_generate(self, z1, z2): 92 | # Sampling from PixelCNN 93 | x_zeros = torch.zeros( 94 | (z1.size(0), self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])) 95 | x_zeros = x_zeros.to(self.args.device) 96 | 97 | for i in range(self.args.input_size[1]): 98 | for j in range(self.args.input_size[2]): 99 | samples_mean, samples_logvar = self.p_x(z1, z2, x=x_zeros.detach()) 100 | samples_mean = samples_mean.view(samples_mean.size(0), self.args.input_size[0], self.args.input_size[1], 101 | self.args.input_size[2]) 102 | 103 | if self.args.input_type == 'binary': 104 | probs = samples_mean[:, :, i, j].data 105 | x_zeros[:, :, i, j] = torch.bernoulli(probs).float() 106 | samples_gen = samples_mean 107 | 108 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous': 109 | binsize = 1. / 256. 110 | samples_logvar = samples_logvar.view(samples_mean.size(0), self.args.input_size[0], 111 | self.args.input_size[1], self.args.input_size[2]) 112 | means = samples_mean[:, :, i, j].data 113 | logvar = samples_logvar[:, :, i, j].data 114 | # sample from logistic distribution 115 | u = torch.rand(means.size()).to(self.args.device) 116 | y = torch.log(u) - torch.log(1.
- u) 117 | sample = means + torch.exp(logvar) * y 118 | x_zeros[:, :, i, j] = torch.floor(sample / binsize) * binsize 119 | samples_gen = samples_mean 120 | return samples_gen 121 | 122 | def forward(self, x): 123 | x = x.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]) 124 | return super(VAE, self).forward(x) 125 | 126 | 127 | -------------------------------------------------------------------------------- /models/VAE.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | import torch.nn as nn 6 | from torch.nn import Linear 7 | from utils.nn import GatedDense, NonLinear 8 | from models.AbsModel import AbsModel 9 | 10 | 11 | class VAE(AbsModel): 12 | def __init__(self, args): 13 | super(VAE, self).__init__(args) 14 | 15 | def create_model(self, args, train_data_size=None): 16 | self.train_data_size = train_data_size 17 | self.q_z_layers = nn.Sequential( 18 | GatedDense(np.prod(self.args.input_size), self.args.hidden_size, no_attention=self.args.no_attention), 19 | GatedDense(self.args.hidden_size, self.args.hidden_size, no_attention=self.args.no_attention) 20 | ) 21 | self.q_z_mean = Linear(self.args.hidden_size, self.args.z1_size) 22 | if args.same_variational_var: 23 | self.q_z_logvar = torch.nn.Parameter(torch.randn((1))) 24 | else: 25 | self.q_z_logvar = NonLinear(self.args.hidden_size, 26 | self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 27 | 28 | self.p_x_layers = nn.Sequential( 29 | GatedDense(self.args.z1_size, self.args.hidden_size, no_attention=self.args.no_attention), 30 | GatedDense(self.args.hidden_size, self.args.hidden_size, no_attention=self.args.no_attention)) 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/models/__init__.py -------------------------------------------------------------------------------- /models/convHVAE_2level.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch.nn as nn 4 | from utils.nn import GatedDense, NonLinear, \ 5 | Conv2d, GatedConv2d 6 | from models.AbsHModel import BaseHModel 7 | 8 | 9 | class VAE(BaseHModel): 10 | def __init__(self, args): 11 | super(VAE, self).__init__(args) 12 | 13 | def create_model(self, args): 14 | if args.dataset_name == 'freyfaces': 15 | self.h_size = 210 16 | elif args.dataset_name == 'cifar10' or args.dataset_name == 'svhn': 17 | self.h_size = 384 18 | else: 19 | self.h_size = 294 20 | 21 | fc_size = 300 22 | 23 | # encoder: q(z2 | x) 24 | self.q_z_layers = nn.Sequential( 25 | GatedConv2d(self.args.input_size[0], 32, 7, 1, 3, no_attention=args.no_attention), 26 | GatedConv2d(32, 32, 3, 2, 1, no_attention=args.no_attention), 27 | GatedConv2d(32, 64, 5, 1, 2, no_attention=args.no_attention), 28 | GatedConv2d(64, 64, 3, 2, 1, no_attention=args.no_attention), 29 | GatedConv2d(64, 6, 3, 1, 1, no_attention=args.no_attention) 30 | ) 31 | 32 | # linear layers 33 | self.q_z_mean = NonLinear(self.h_size, self.args.z2_size, activation=None) 34 | 35 | # SAME VARIATIONAL VARIANCE TO SEE IF IT HELPS 36 | self.q_z_logvar = NonLinear(self.h_size, self.args.z2_size,
activation=nn.Hardtanh(min_val=-6., max_val=2.)) 37 | 38 | # encoder: q(z1|x,z2) 39 | # PROCESSING x 40 | self.q_z1_layers_x = nn.Sequential( 41 | GatedConv2d(self.args.input_size[0], 32, 3, 1, 1, no_attention=args.no_attention), 42 | GatedConv2d(32, 32, 3, 2, 1, no_attention=args.no_attention), 43 | GatedConv2d(32, 64, 3, 1, 1, no_attention=args.no_attention), 44 | GatedConv2d(64, 64, 3, 2, 1, no_attention=args.no_attention), 45 | GatedConv2d(64, 6, 3, 1, 1, no_attention=args.no_attention) 46 | ) 47 | # PROCESSING Z2 48 | self.q_z1_layers_z2 = nn.Sequential(GatedDense(self.args.z2_size, self.h_size)) 49 | 50 | # PROCESSING JOINT 51 | self.q_z1_layers_joint = nn.Sequential(GatedDense(2* self.h_size, fc_size)) 52 | 53 | # linear layers 54 | self.q_z1_mean = NonLinear(fc_size, self.args.z1_size, activation=None) 55 | self.q_z1_logvar = NonLinear(fc_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 56 | 57 | # decoder p(z1|z2) 58 | self.p_z1_layers_z2 = nn.Sequential( 59 | GatedDense(self.args.z2_size, fc_size, no_attention=args.no_attention), 60 | GatedDense(fc_size, fc_size, no_attention=args.no_attention) 61 | ) 62 | self.p_z1_mean = NonLinear(fc_size, self.args.z1_size, activation=None) 63 | self.p_z1_logvar = NonLinear(fc_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.)) 64 | 65 | # decoder: p(x | z) 66 | self.p_x_layers_z1 = nn.Sequential( 67 | GatedDense(self.args.z1_size, fc_size, no_attention=args.no_attention) 68 | ) 69 | self.p_x_layers_z2 = nn.Sequential( 70 | GatedDense(self.args.z2_size, fc_size, no_attention=args.no_attention) 71 | ) 72 | 73 | self.p_x_layers_joint_pre = nn.Sequential( 74 | GatedDense(2 * fc_size, np.prod(self.args.input_size), no_attention=args.no_attention) 75 | ) 76 | 77 | # decoder: p(x | z) 78 | self.p_x_layers_joint = nn.Sequential( 79 | GatedConv2d(self.args.input_size[0], 64, 3, 1, 1, no_attention=args.no_attention), 80 | GatedConv2d(64, 64, 3, 1, 1, no_attention=args.no_attention), 81 | GatedConv2d(64, 64, 3, 1, 1, no_attention=args.no_attention), 82 | GatedConv2d(64, 64, 3, 1, 1, no_attention=args.no_attention), 83 | ) 84 | 85 | if self.args.input_type == 'binary': 86 | self.p_x_mean = Conv2d(64, 1, 1, 1, 0, activation=nn.Sigmoid()) 87 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous': 88 | self.p_x_mean = Conv2d(64, self.args.input_size[0], 1, 1, 0) 89 | self.p_x_logvar = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Hardtanh(min_val=-4.5, max_val=0.)) 90 | elif self.args.input_type == 'pca': 91 | self.p_x_mean = Conv2d(64, 1, 1, 1, 0) 92 | self.p_x_logvar = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Hardtanh(min_val=-4.5, max_val=0.)) 93 | 94 | # THE MODEL: FORWARD PASS 95 | def forward(self, x): 96 | x = x.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]) 97 | return super(VAE, self).forward(x) 98 | 99 | -------------------------------------------------------------------------------- /models/fully_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data 4 | import torch.nn as nn 5 | from models.AbsModel import AbsModel 6 | from torch.nn.utils import weight_norm 7 | 8 | class VAE(AbsModel): 9 | def __init__(self, args): 10 | super(VAE, self).__init__(args) 11 | 12 | def create_model(self, args, train_data_size=None): 13 | class block(nn.Module): 14 | def __init__(self, input_size, output_size, 
stride=1, kernel=3, padding=1): 15 | super(block, self).__init__() 16 | self.normalization = nn.BatchNorm2d(input_size) 17 | self.conv1 = weight_norm(nn.Conv2d(input_size, output_size, kernel_size=kernel, stride=stride, padding=padding, 18 | bias=True)) 19 | self.activation = torch.nn.ELU() 20 | self.f = torch.nn.Sequential(self.activation, self.conv1) 21 | 22 | def forward(self, x): 23 | return x + self.f(x) 24 | 25 | self.train_data_size = train_data_size 26 | self.cs = 48 27 | self.bottleneck=self.args.bottleneck 28 | self.q_z_layers = nn.Sequential( 29 | weight_norm(nn.Conv2d(in_channels=self.args.input_size[0], out_channels=self.cs, kernel_size=3, stride=2, padding=1)), 30 | nn.ELU(), 31 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 32 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 33 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 34 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 35 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 36 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 37 | weight_norm(nn.Conv2d(in_channels=self.cs, out_channels=self.cs*2, kernel_size=3, stride=2, padding=1)), 38 | nn.ELU(), 39 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 40 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 41 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 42 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 43 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 44 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 45 | # nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=3, stride=2, padding=1), 46 | # nn.ELU(), 47 | # nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=3, stride=1, padding=1), 48 | # nn.ELU(), 49 | ) 50 | self.q_z_mean = weight_norm(nn.Conv2d(in_channels=self.cs*2, out_channels=self.bottleneck, kernel_size=3, stride=1, padding=1)) 51 | # self.q_z_mean = weight_norm(nn.Linear(self.args.hidden_size, self.args.z1_size)) 52 | self.q_z_logvar = weight_norm(nn.Conv2d(in_channels=self.cs*2, out_channels=self.bottleneck, kernel_size=3, stride=1, padding=1)) 53 | # self.q_z_logvar = weight_norm(nn.Linear(self.args.hidden_size, self.args.z1_size)) 54 | self.p_x_layers = nn.Sequential( 55 | # weight_norm(nn.Linear(self.args.z1_size, self.args.hidden_size)), 56 | nn.Upsample(scale_factor=2), 57 | weight_norm(nn.Conv2d(in_channels=self.bottleneck, out_channels=self.cs*2, kernel_size=3, stride=1, padding=1)), 58 | nn.ELU(), 59 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 60 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 61 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 62 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 63 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 64 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1), 65 | nn.Upsample(scale_factor=2), 66 | weight_norm(nn.Conv2d(in_channels=self.cs*2, out_channels=self.cs, kernel_size=3, stride=1, padding=1)), 67 | nn.ELU(), 68 | block(input_size=self.cs, output_size=self.cs, 
stride=1, kernel=3, padding=1), 69 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 70 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 71 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 72 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 73 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1), 74 | # nn.Upsample(size=(28, 28)), 75 | ) 76 | 77 | if self.args.input_type == 'binary': 78 | self.p_x_mean = nn.Sequential(nn.Conv2d(in_channels=self.cs, out_channels=self.args.input_size[0], kernel_size=3, stride=1, padding=1), nn.Sigmoid()) 79 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous': 80 | self.p_x_mean = weight_norm(nn.Conv2d(in_channels=self.cs, out_channels=self.args.input_size[0], kernel_size=3, stride=1, padding=1)) 81 | self.p_x_logvar = nn.Conv2d(in_channels=self.cs, out_channels=self.args.input_size[0], kernel_size=3, stride=1, padding=1) 82 | -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint.pth -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint_best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint_best.pth -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_0.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_1.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_10.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_11.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_11.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_12.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_13.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_14.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_15.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_16.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_17.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_18.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_19.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_19.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_2.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_20.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_21.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_22.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_23.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_24.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_25.png -------------------------------------------------------------------------------- 
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_26.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_27.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_28.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_29.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_3.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_30.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_31.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_32.png 
-------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_33.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_34.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_35.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_36.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_37.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_38.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_38.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_39.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_4.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_40.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_41.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_42.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_43.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_43.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_44.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_45.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_46.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_46.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_47.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_47.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_48.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_49.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_49.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_5.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_6.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_7.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_8.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_9.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generations_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generations_0.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/real.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/reconstructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/reconstructions.png -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.config: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.config -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_kl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_kl -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_log_likelihood: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_log_likelihood -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_loss: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_loss -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_re: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_re -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_kl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_kl -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_loss: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_loss -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_re: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_re -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_kl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_kl -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_loss: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_loss -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_re: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_re -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae_config.txt: -------------------------------------------------------------------------------- 1 | Namespace(MB=100, S=5000, approximate_k=10, approximate_prior=False, base_dir='/checkpoint/sajad/143803', batch_size=100, bottleneck=6, continuous=False, cuda=True, dataset_name='dynamic_mnist', dir_extra='exemplar_prior_on_dynamic_mnist_components=25000_lr=5e-4_model_name=vae_variance_type=shared_independent=True', dynamic_binarization=True, early_stopping_epochs=50, epochs=2000, hidden_size=300, input_size=[1, 28, 28], input_type='binary', just_evaluate=False, lambd=0.0001, lr=0.0005, model_name='vae', model_signature='2', no_attention=False, no_cuda=False, no_mask=False, number_components=25000, prior='exemplar_prior', pseudoinputs_mean=0.05, pseudoinputs_std=0.01, same_variational_var=False, seed=2, slurm_job_id='143803', slurm_task_id='', test_batch_size=100, training_set_size=50000, use_logit=False, use_training_data_init=False, use_whole_train=False, warmup=100, z1_size=40, z2_size=40) 2 | Namespace(MB=100, S=5000, approximate_k=10, approximate_prior=False, base_dir='/checkpoint/sajad/143803', batch_size=100, bottleneck=6, 
continuous=False, cuda=True, dataset_name='dynamic_mnist', dir_extra='exemplar_prior_on_dynamic_mnist_components=25000_lr=5e-4_model_name=vae_variance_type=shared_independent=True', dynamic_binarization=True, early_stopping_epochs=50, epochs=2000, hidden_size=300, input_size=[1, 28, 28], input_type='binary', just_evaluate=False, lambd=0.0001, lr=0.0005, model_name='vae', model_signature='2', no_attention=False, no_cuda=False, no_mask=False, number_components=25000, prior='exemplar_prior', pseudoinputs_mean=0.05, pseudoinputs_std=0.01, same_variational_var=False, seed=2, slurm_job_id='143803', slurm_task_id='', test_batch_size=100, training_set_size=50000, use_logit=False, use_training_data_init=False, use_whole_train=False, warmup=100, z1_size=40, z2_size=40) 3 | Namespace(MB=100, S=5000, approximate_k=10, approximate_prior=False, base_dir='/checkpoint/sajad/143803', batch_size=100, bottleneck=6, continuous=False, cuda=True, dataset_name='dynamic_mnist', dir_extra='exemplar_prior_on_dynamic_mnist_components=25000_lr=5e-4_model_name=vae_variance_type=shared_independent=True', dynamic_binarization=True, early_stopping_epochs=50, epochs=2000, hidden_size=300, input_size=[1, 28, 28], input_type='binary', just_evaluate=False, lambd=0.0001, lr=0.0005, model_name='vae', model_signature='2', no_attention=False, no_cuda=False, no_mask=False, number_components=25000, prior='exemplar_prior', pseudoinputs_mean=0.05, pseudoinputs_std=0.01, same_variational_var=False, seed=2, slurm_job_id='143803', slurm_task_id='', test_batch_size=100, training_set_size=50000, use_logit=False, use_training_data_init=False, use_whole_train=False, warmup=100, z1_size=40, z2_size=40) 4 | -------------------------------------------------------------------------------- /pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae_experiment_log.txt: -------------------------------------------------------------------------------- 1 | FINAL EVALUATION ON TEST SET 2 | LogL (TEST): 82.05 3 | LogL (TRAIN): 0.00 4 | ELBO (TEST): 85.50 5 | ELBO (TRAIN): 99.77 6 | RE: 61.04 7 | KL: 24.45 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | scipy 5 | scikit-learn 6 | opencv-python 7 | matplotlib 8 | wget 9 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/utils/__init__.py -------------------------------------------------------------------------------- /utils/classify_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import time 5 | from utils.plot_images import imshow 6 | import matplotlib.pylab as plt 7 | import torchvision 8 | from pylab import rcParams 9 | 10 | rcParams['figure.figsize'] = 15, 15 11 | 12 | 13 | def compute_accuracy(classifier, model, loader, mean, args, dir=None, plot_mistakes=False): 14 | acc = 0 15 | mistakes_list = [] 16 | for data, labels in loader: 17 | try: 18 | if model.args.use_logit is True and model.args.continuous is True: 19 | data = torch.round(model.logit_inverse(data) * 255) / 255 20 | except: 21 | pass 22 | labels = labels.to(args.device) 23 | pred = classifier(data.double().to(args.device) - mean) 
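# (editor's note, not part of the original source) `pred` holds raw logits, so
# torch.argmax below picks the predicted class. Because `acc` accumulates
# per-batch means and is later divided by len(loader), the reported accuracy is
# exact only when all batches have equal size; a ragged final batch makes it a
# close approximation.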
24 | acc += torch.mean((labels == torch.argmax(pred, dim=1)).double()) 25 | mistakes = (labels != torch.argmax(pred, dim=1)) 26 | mistakes_list.append(data[mistakes]) 27 | mistakes_list = torch.cat(mistakes_list, dim=0) 28 | if plot_mistakes is True: 29 | imshow(torchvision.utils.make_grid(mistakes_list.reshape(-1, *args.input_size))) 30 | # plt.show() 31 | plt.axis('off') 32 | plt.savefig(os.path.join(dir, 'mistakes.png'), bbox_inches='tight') 33 | acc /= len(loader) 34 | return acc 35 | 36 | 37 | def save_model(save_path, load_path, content): 38 | torch.save(content, save_path) 39 | os.rename(save_path, load_path) 40 | 41 | 42 | def load_model(load_path, model, optimizer=None): 43 | checkpoint = torch.load(load_path) 44 | model.load_state_dict(checkpoint['state_dict']) 45 | if optimizer is not None: 46 | optimizer.load_state_dict(checkpoint['optimizer']) 47 | return checkpoint 48 | 49 | 50 | def compute_loss(pred, label, args): 51 | held_out_percent = 0.1 52 | 53 | denom = torch.logsumexp(pred, dim=1, keepdim=True) 54 | prediction = pred - denom 55 | 56 | one_hot_label = torch.ones_like(prediction) * (held_out_percent / 10) 57 | one_hot_label[torch.arange(args.batch_size), label] += (1 - held_out_percent) 58 | return -torch.sum(prediction * one_hot_label, dim=1).mean() 59 | 60 | 61 | def classify_data(train_loader, val_loader, test_loader, dir, args, model): 62 | classifier = nn.Sequential(nn.Linear(784, args.hidden_units), nn.ReLU(), 63 | nn.Linear(args.hidden_units, args.hidden_units), nn.ReLU(), 64 | nn.Linear(args.hidden_units, 10)).double().to(args.device) 65 | 66 | lr = args.lr 67 | 68 | optimizer = torch.optim.SGD(classifier.parameters(), lr=lr, momentum=0.9) 69 | epochs = args.epochs 70 | mean = 0 71 | lr_lambda = lambda epoch: 1-(0.99)*(epoch/epochs) 72 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 73 | os.makedirs(dir, exist_ok=True) 74 | 75 | if os.path.exists(os.path.join(dir, 'checkpoint.pth')): 76 | checkpoint = load_model(os.path.join(dir, 'checkpoint.pth'), 77 | model=classifier, 78 | optimizer=optimizer) 79 | begin_epoch = checkpoint['epoch'] 80 | else: 81 | begin_epoch = 1 82 | 83 | for epoch_number in range(begin_epoch, epochs + 1): 84 | start_time = time.time() 85 | if epoch_number % 10 == 0: 86 | content = {'epoch': epoch_number, 'state_dict': classifier.state_dict(), 87 | 'optimizer': optimizer.state_dict()} 88 | save_model(os.path.join(dir, 'checkpoint_temp.pth'), 89 | os.path.join(dir, 'checkpoint.pth'), content) 90 | 91 | print('epoch number:', epoch_number) 92 | for index, data in enumerate(train_loader): 93 | 94 | data, _, label = data 95 | data_augment = model.reference_based_generation_x(reference_image=data.detach(), N=1).squeeze().double() 96 | label_augment = label 97 | 98 | data_augment = data_augment.to(args.device) 99 | label_augment = label_augment.to(args.device) 100 | 101 | data = data.to(args.device).double() 102 | label = label.to(args.device).long() 103 | 104 | # imshow(torchvision.utils.make_grid(data.reshape(-1, *args.input_size)).detach()) 105 | # plt.show() 106 | try: 107 | if model.args.use_logit is True and model.args.continuous is True: 108 | data = torch.round(model.logit_inverse(data) * 255) / 255 109 | except: 110 | pass 111 | data_augment = torch.round(data_augment * 255) / 255 112 | 113 | loss1 = compute_loss(classifier(data), label, args) 114 | loss2 = compute_loss(classifier(data_augment), label_augment, args) 115 | 116 | loss = args.hyper_lambda*loss1 + (1-args.hyper_lambda)*loss2 117 | 118 | 
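# (editor's note, not part of the original source) The training objective is a
# convex combination of the clean-data loss and the exemplar-augmented loss:
#     loss = hyper_lambda * CE(f(x), y) + (1 - hyper_lambda) * CE(f(x_aug), y)
# where x_aug is decoded from the latent neighbourhood of the exemplar x via
# reference_based_generation_x, and CE is the smoothed cross-entropy from
# compute_loss above: every class receives held_out_percent / 10 of the label
# mass and the true class an extra (1 - held_out_percent).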
optimizer.zero_grad() 119 | loss.backward() 120 | optimizer.step() 121 | scheduler.step(epoch=epoch_number) 122 | 123 | for param_group in optimizer.param_groups: 124 | print('learning rate:', param_group['lr']) 125 | break 126 | 127 | if val_loader is not None: 128 | val_acc = compute_accuracy(classifier, model, val_loader, mean, args) 129 | print('val acc', val_acc.item()) 130 | test_acc = compute_accuracy(classifier, model, test_loader, mean, args) 131 | print('accuracy test:', test_acc.item()) 132 | print("time:", time.time() - start_time) 133 | 134 | content = {'epoch': args.epochs, 'state_dict': classifier.state_dict(), 135 | 'optimizer': optimizer.state_dict()} 136 | save_model(os.path.join(dir, 'checkpoint_temp.pth'), os.path.join(dir, 'checkpoint.pth'), content) 137 | classifier.eval() 138 | if val_loader is not None: 139 | val_acc = compute_accuracy(classifier, model, val_loader, mean, args) 140 | print('accuracy val:', val_acc.item()) 141 | else: 142 | val_acc = torch.zeros(1) 143 | test_acc = compute_accuracy(classifier, model, test_loader, mean, args, dir=dir, plot_mistakes=True) 144 | print('accuracy test:', test_acc.item()) 145 | # 146 | # 147 | return (test_acc*10000).item()/100, (val_acc*10000).item()/100 148 | -------------------------------------------------------------------------------- /utils/distributions.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data 4 | import math 5 | 6 | min_epsilon = 1e-5 7 | max_epsilon = 1.-1e-5 8 | log_sigmoid = torch.nn.LogSigmoid() 9 | log_2_pi = math.log(2*math.pi) 10 | 11 | 12 | def pairwise_distance(z, means): 13 | z = z.double() 14 | means = means.double() 15 | dist1 = (z**2).sum(dim=1).unsqueeze(1).expand(-1, means.shape[0]) #MB x C 16 | dist2 = (means**2).sum(dim=1).unsqueeze(0).expand(z.shape[0], -1) #MB x C 17 | dist3 = torch.mm(z, torch.transpose(means, 0, 1)) #MB x C 18 | return (dist1 + dist2 - 2*dist3).float() 19 | 20 | 21 | def log_normal_diag_vectorized(x, mean, log_var): 22 | log_var_sqrt = log_var.mul(0.5).exp_() 23 | pair_dist = pairwise_distance(x/log_var_sqrt, mean/log_var_sqrt) 24 | log_normal = -0.5 * torch.sum(log_var+log_2_pi, dim=1) - 0.5*pair_dist 25 | return log_normal, pair_dist 26 | 27 | 28 | def log_normal_diag(x, mean, log_var, average=False, dim=None): 29 | log_normal = -0.5 * (log_var + log_2_pi + torch.pow( x - mean, 2 ) / torch.exp( log_var ) ) 30 | if average: 31 | return torch.mean(log_normal, dim) 32 | else: 33 | return torch.sum(log_normal, dim) 34 | 35 | 36 | def log_normal_standard(x, average=False, dim=None): 37 | log_normal = -0.5 * torch.pow(x, 2) - 0.5 * log_2_pi*x.new_ones(size=x.shape) 38 | if average: 39 | return torch.mean(log_normal, dim) 40 | else: 41 | return torch.sum(log_normal, dim) 42 | 43 | 44 | def log_bernoulli(x, mean, average=False, dim=None): 45 | probs = torch.clamp( mean, min=min_epsilon, max=max_epsilon) 46 | log_bernoulli = x * torch.log(probs) + (1. - x) * torch.log(1. - probs) 47 | 48 | if average: 49 | return torch.mean(log_bernoulli, dim) 50 | else: 51 | return torch.sum(log_bernoulli, dim) 52 | 53 | 54 | def log_logistic_256(x, mean, logvar, average=False, reduce=True, dim=None): 55 | bin_size = 1. / 256. 
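# (editor's note, not part of the original source) Discretized logistic
# likelihood: a pixel quantized to bins of width 1/256 gets the probability mass
#     P(x) = sigmoid((x + bin_size - mean) / scale) - sigmoid((x - mean) / scale)
# with scale = exp(logvar). The code below standardizes x first, so the two
# sigmoids are evaluated at x + bin_size/scale and at x; the 1e-7 inside the
# log guards against log(0) for bins far in the tails.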
56 | # implementation like https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L28 57 | scale = torch.exp(logvar) 58 | x = (torch.floor(x / bin_size) * bin_size - mean) / scale 59 | cdf_plus = torch.sigmoid(x + bin_size/scale) 60 | cdf_minus = torch.sigmoid(x) 61 | log_logist_256 = torch.log(cdf_plus - cdf_minus + 1e-7) 62 | 63 | if average: 64 | return torch.mean(log_logist_256, dim) 65 | else: 66 | return torch.sum(log_logist_256, dim) 67 | 68 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from utils.plot_images import plot_images 3 | import torch 4 | import time 5 | from scipy.special import logsumexp 6 | import numpy as np 7 | from utils.utils import load_model 8 | import torch.nn.functional as F 9 | 10 | 11 | def evaluate_loss(args, model, loader, dataset=None, exemplars_embedding=None): 12 | evaluated_elbo, evaluate_re, evaluate_kl = 0, 0, 0 13 | model.eval() 14 | if exemplars_embedding is None: 15 | exemplars_embedding = load_all_pseudo_input(args, model, dataset) 16 | 17 | for data in loader: 18 | if len(data) == 3: 19 | data, _, _ = data 20 | else: 21 | data, _ = data 22 | data = data.to(args.device) 23 | x = data 24 | x_indices = None 25 | x = (x, x_indices) 26 | loss, RE, KL = model.calculate_loss(x, average=False, exemplars_embedding=exemplars_embedding) 27 | evaluated_elbo += loss.sum().item() 28 | evaluate_re += -RE.sum().item() 29 | evaluate_kl += KL.sum().item() 30 | evaluated_elbo /= len(loader.dataset) 31 | evaluate_re /= len(loader.dataset) 32 | evaluate_kl /= len(loader.dataset) 33 | return evaluated_elbo, evaluate_re, evaluate_kl 34 | 35 | 36 | def visualize_reconstruction(test_samples, model, args, dir): 37 | samples_reconstruction = model.reconstruct_x(test_samples[0:25]) 38 | 39 | if args.use_logit: 40 | test_samples = model.logit_inverse(test_samples) 41 | samples_reconstruction = model.logit_inverse(samples_reconstruction) 42 | plot_images(args, test_samples.cpu().numpy()[0:25], dir, 'real', size_x=5, size_y=5) 43 | plot_images(args, samples_reconstruction.cpu().numpy(), dir, 'reconstructions', size_x=5, size_y=5) 44 | 45 | 46 | def visualize_generation(dataset, model, args, dir): 47 | generation_rounds = 1 48 | for i in range(generation_rounds): 49 | samples_rand = model.generate_x(25, dataset=dataset) 50 | plot_images(args, samples_rand.cpu().numpy(), dir, 'generations_{}'.format(i), size_x=5, size_y=5) 51 | if args.prior == 'vampprior': 52 | pseudo_means = model.means(model.idle_input) 53 | plot_images(args, pseudo_means[0:25].cpu().numpy(), dir, 'pseudoinputs', size_x=5, size_y=5) 54 | 55 | 56 | def load_all_pseudo_input(args, model, dataset): 57 | if args.prior == 'exemplar_prior': 58 | exemplars_z, exemplars_log_var = model.cache_z(dataset) 59 | embedding = (exemplars_z, exemplars_log_var, torch.arange(len(exemplars_z))) 60 | elif args.prior == 'vampprior': 61 | pseudo_means = model.means(model.idle_input) 62 | if 'conv' in args.model_name: 63 | pseudo_means = pseudo_means.view(-1, args.input_size[0], args.input_size[1], args.input_size[2]) 64 | embedding = model.q_z(pseudo_means, prior=True) # C x M 65 | elif args.prior == 'standard': 66 | embedding = None 67 | else: 68 | raise Exception("wrong name of prior") 69 | return embedding 70 | 71 | 72 | def calculate_likelihood(args, model, loader, S=5000, exemplars_embedding=None): 73 | likelihood_test = [] 74 | 
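# (editor's note, not part of the original source) Marginal likelihood is
# estimated by importance sampling: each test point is replicated S times, the
# negative per-sample losses act as log importance weights, and
#     log p(x) ~= logsumexp_s(-loss_s) - log(S),
# which is exactly the `likelihood_x - np.log(len(prob))` term appended below.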
batch_size_evaluation = 1 75 | auxiliary_loader = torch.utils.data.DataLoader(loader.dataset, batch_size=batch_size_evaluation) 76 | t0 = time.time() 77 | for index, (data, _) in enumerate(auxiliary_loader): 78 | data = data.to(args.device) 79 | if index % 100 == 0: 80 | print(time.time() - t0) 81 | t0 = time.time() 82 | print('{:.2f}%'.format(index / (1. * len(auxiliary_loader)) * 100)) 83 | x = data.expand(S, data.size(1)) 84 | if args.model_name == 'pixelcnn': 85 | BS = S//100 86 | prob = [] 87 | for i in range(BS): 88 | bx = x[i*100:(i+1)*100] 89 | x_indices = None 90 | bprob, _, _ = model.calculate_loss((bx, x_indices), exemplars_embedding=exemplars_embedding) 91 | prob.append(bprob) 92 | prob = torch.cat(prob, dim=0) 93 | else: 94 | x_indices = None 95 | prob, _, _ = model.calculate_loss((x, x_indices), exemplars_embedding=exemplars_embedding) 96 | likelihood_x = logsumexp(-prob.cpu().numpy()) 97 | if model.args.use_logit: 98 | lambd = torch.tensor(model.args.lambd).float() 99 | likelihood_x -= (-F.softplus(-x) - F.softplus(x)\ 100 | - torch.log((1 - 2 * lambd)/256)).sum(dim=1).cpu().numpy() 101 | likelihood_test.append(likelihood_x - np.log(len(prob))) 102 | likelihood_test = np.array(likelihood_test) 103 | return -np.mean(likelihood_test) 104 | 105 | 106 | def final_evaluation(train_loader, test_loader, valid_loader, best_model_path_load, 107 | model, optimizer, args, dir): 108 | _ = load_model(best_model_path_load, model, optimizer) 109 | model.eval() 110 | exemplars_embedding = load_all_pseudo_input(args, model, train_loader.dataset) 111 | test_samples = next(iter(test_loader))[0].to(args.device) 112 | visualize_reconstruction(test_samples, model, args, dir) 113 | visualize_generation(train_loader.dataset, model, args, dir) 114 | test_elbo, test_re, test_kl = evaluate_loss(args, model, test_loader, dataset=train_loader.dataset, exemplars_embedding=exemplars_embedding) 115 | valid_elbo, valid_re, valid_kl = evaluate_loss(args, model, valid_loader, dataset=valid_loader.dataset, exemplars_embedding=exemplars_embedding) 116 | train_elbo, _, _ = evaluate_loss(args, model, train_loader, dataset=train_loader.dataset, exemplars_embedding=exemplars_embedding) 117 | test_log_likelihood = calculate_likelihood(args, model, test_loader, exemplars_embedding=exemplars_embedding, S=args.S) 118 | final_evaluation_txt = 'FINAL EVALUATION ON TEST SET\n' \ 119 | 'LogL (TEST): {:.2f}\n' \ 120 | 'LogL (TRAIN): {:.2f}\n' \ 121 | 'ELBO (TEST): {:.2f}\n' \ 122 | 'ELBO (TRAIN): {:.2f}\n' \ 123 | 'ELBO (VALID): {:.2f}\n' \ 124 | 'RE: {:.2f}\n' \ 125 | 'KL: {:.2f}'.format( 126 | test_log_likelihood, 127 | 0, 128 | test_elbo, 129 | train_elbo, 130 | valid_elbo, 131 | test_re, 132 | test_kl) 133 | 134 | print(final_evaluation_txt) 135 | with open(dir + 'vae_experiment_log.txt', 'a') as f: 136 | print(final_evaluation_txt, file=f) 137 | torch.save(test_log_likelihood, dir + args.model_name + '.test_log_likelihood') 138 | torch.save(test_elbo, dir + args.model_name + '.test_loss') 139 | torch.save(test_re, dir + args.model_name + '.test_re') 140 | torch.save(test_kl, dir + args.model_name + '.test_kl') 141 | 142 | 143 | # TODO remove last loop from this function 144 | def compute_mean_variance_per_dimension(args, model, test_loader): 145 | means = [] 146 | for batch, _ in test_loader: 147 | mean, _ = model.q_z(batch.to(args.device)) 148 | means.append(mean) 149 | means = torch.cat(means, dim=0).cpu().detach().numpy() 150 | active = 0 151 | for i in range(means.shape[1]): 152 | 
if np.var(means[:, i].reshape(-1)) > 0.01: 153 | active += 1 154 | print('active dimensions', active) 155 | return active 156 | 157 | 158 | -------------------------------------------------------------------------------- /utils/knn_on_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def find_nearest_neighbors(z_val, z_train, z_train_log_var): 5 | z_expand = z_val.unsqueeze(1) 6 | means = z_train.unsqueeze(0) 7 | distance = (z_expand - means)**2 8 | _, indices_batch = (torch.sum(distance, dim=2)**(0.5)).topk(k=20, dim=1, largest=False, sorted=True) 9 | return indices_batch 10 | 11 | 12 | def extract_full_data(data_loader): 13 | full_data = [] 14 | full_labels = [] 15 | full_indices = [] 16 | for data in data_loader: 17 | if len(data) == 3: 18 | data, indices, labels = data 19 | full_indices.append(indices) 20 | else: 21 | data, labels = data 22 | full_data.append(data) 23 | full_labels.append(labels) 24 | full_data = torch.cat(full_data, dim=0) 25 | full_labels = torch.cat(full_labels, dim=0) 26 | if len(full_indices) > 0: 27 | full_indices = torch.cat(full_indices, dim=0) 28 | return full_data, full_indices, full_labels 29 | 30 | 31 | # TODO refactor this function 32 | def report_knn_on_latent(train_loader, val_loader, test_loader, model, dir, knn_dictionary, args, val=True): 33 | train_data, _, train_labels = extract_full_data(train_loader) 34 | val_data, _, val_labels = extract_full_data(val_loader) 35 | test_data, _, test_labels = extract_full_data(test_loader) 36 | 37 | train_data = train_data.to(args.device) 38 | val_data = val_data.to(args.device) 39 | 40 | if val is True: 41 | data_to_evaluate = val_data 42 | labels = val_labels 43 | else: 44 | train_data = torch.cat((train_data, val_data), dim=0) 45 | train_labels = torch.cat((train_labels, val_labels), dim=0) 46 | data_to_evaluate = test_data 47 | labels = test_labels 48 | 49 | with torch.no_grad(): 50 | z_train = [] 51 | for i in range(len(train_data)//args.batch_size): 52 | train_batch = train_data[i*args.batch_size: (i+1)*args.batch_size] 53 | z_train_batch, _ = model.q_z(train_batch.to(args.device), prior=True) 54 | z_train.append(z_train_batch) 55 | z_train = torch.cat(z_train, dim=0) 56 | 57 | print(z_train.shape) 58 | indices = [] 59 | for i in range(len(data_to_evaluate)//args.batch_size): 60 | z_val, _ = model.q_z(data_to_evaluate[i*args.batch_size: (i+1)*args.batch_size].to(args.device), prior=True) 61 | indices.append(find_nearest_neighbors(z_val, z_train, None)) 62 | indices = torch.cat(indices, dim=0) 63 | 64 | for k in knn_dictionary.keys(): 65 | k = int(k) 66 | k_labels = train_labels[indices[:, :k]].squeeze().long() 67 | num_classes = 10 68 | counts = torch.zeros(len(test_loader.dataset), num_classes) 69 | for i in range(num_classes): 70 | counts[:, i] = (k_labels == torch.tensor(i).long()).sum(dim=1) 71 | y_pred = torch.argmax(counts, dim=1) 72 | acc = (torch.mean((y_pred == labels.long()).float()) * 10000).round().item()/100 73 | print('K:', k, 'Accuracy:', acc) 74 | knn_dictionary[str(k)].append(acc) 75 | -------------------------------------------------------------------------------- /utils/load_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/utils/load_data/__init__.py -------------------------------------------------------------------------------- /utils/load_data/base_load_data.py: 
-------------------------------------------------------------------------------- /utils/load_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/utils/load_data/__init__.py -------------------------------------------------------------------------------- /utils/load_data/base_load_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data as data_utils 4 | import numpy as np 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class base_load_data(ABC): 9 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 10 | self.args = args 11 | self.train_num = args.training_set_size 12 | self.use_fixed_validation = use_fixed_validation 13 | self.no_binarization = no_binarization 14 | 15 | @abstractmethod 16 | def obtain_data(self): 17 | pass 18 | 19 | def logit(self, x): 20 | return np.log(x) - np.log1p(-x) 21 | 22 | def seperate_data_from_label(self, train_dataset, test_dataset): 23 | x_train = train_dataset.data.numpy() 24 | y_train = train_dataset.targets.numpy().astype(int)  # .targets replaces the deprecated train_labels attribute 25 | x_test = test_dataset.data.numpy() 26 | y_test = test_dataset.targets.numpy().astype(int)  # likewise for the deprecated test_labels 27 | return x_train, y_train, x_test, y_test 28 | 29 | def preprocessing_(self, x_train, x_test): 30 | if self.args.input_type == 'gray' or self.args.input_type == 'continuous': 31 | if self.args.use_logit: 32 | lambd = self.args.lambd 33 | x_train = self.logit(lambd + (1 - 2 * lambd) * (x_train + np.random.rand(*x_train.shape)) / 256.) 34 | x_test = self.logit(lambd + (1 - 2 * lambd) * (x_test + np.random.rand(*x_test.shape)) / 256.) 35 | elif self.args.continuous: 36 | x_train = np.clip((x_train + 0.5) / 256., 0., 1.) 37 | x_test = np.clip((x_test + 0.5) / 256., 0., 1.) 38 | else: 39 | x_train = x_train / 255. 40 | x_test = x_test / 255. 41 | 42 | return x_train, x_test 43 | 44 | def vampprior_initialization(self, x_train, init_mean, init_std): 45 | if self.args.use_training_data_init == 1: 46 | self.args.pseudoinputs_std = 0.01 47 | init = x_train[0:self.args.number_components].T 48 | self.args.pseudoinputs_mean = torch.from_numpy( 49 | init + self.args.pseudoinputs_std * np.random.randn(np.prod(self.args.input_size), 50 | self.args.number_components)).float() 51 | else: 52 | self.args.pseudoinputs_mean = init_mean 53 | self.args.pseudoinputs_std = init_std 54 | 55 | def post_processing(self, x_train, x_val, x_test, y_train, y_val, y_test, init_mean=0.05, init_std=0.01, **kwargs): 56 | indices = np.arange(len(x_train)).reshape(-1, 1) 57 | train = data_utils.TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(indices), 58 | torch.from_numpy(y_train)) 59 | train_loader = data_utils.DataLoader(train, batch_size=self.args.batch_size, shuffle=True, **kwargs) 60 | 61 | if len(x_val) > 0: 62 | validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) 63 | val_loader = data_utils.DataLoader(validation, batch_size=self.args.test_batch_size, shuffle=True, **kwargs) 64 | else: 65 | val_loader = None 66 | test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) 67 | test_loader = data_utils.DataLoader(test, batch_size=self.args.test_batch_size, shuffle=False, **kwargs) 68 | 69 | self.vampprior_initialization(x_train, init_mean, init_std) 70 | return train_loader, val_loader, test_loader 71 | 72 | def binarize(self, x_val, x_test): 73 | self.args.input_type = 'binary' 74 | np.random.seed(777)  # fixed seed keeps the binarized val/test sets reproducible across runs 75 | x_val = np.random.binomial(1, x_val) 76 | x_test = np.random.binomial(1, x_test) 77 | return x_val, x_test 78 | 79 | def load_dataset(self, **kwargs): 80 | # start processing 81 | train, test = self.obtain_data() 82 | x_train, y_train, x_test, y_test = self.seperate_data_from_label(train, test) 83 | x_train, x_test = self.preprocessing_(x_train, x_test) 84 | 85 | if self.use_fixed_validation is False and self.args.dataset_name != 'static_mnist':  # static MNIST arrives as fixed (train, val) tuples, which cannot be permuted 86 | permutation = np.arange(len(x_train)) 87 | np.random.shuffle(permutation) 88 | x_train = x_train[permutation] 89 | y_train = y_train[permutation] 90 | 91 | if self.args.dataset_name == 'static_mnist': 92 | x_train, x_val = x_train 93 | y_train, y_val = y_train 94 | else: 95 | x_val = x_train[self.train_num:] 96 | y_val = y_train[self.train_num:] 97 | x_train = x_train[:self.train_num] 98 | y_train = y_train[:self.train_num] 99 | 100 | # imshow(torchvision.utils.make_grid(torch.from_numpy(x_val[:50].reshape(-1, *self.args.input_size)))) 101 | # plt.axis('off') 102 | # plt.show() 103 | 104 | x_train = np.reshape(x_train, (-1, np.prod(self.args.input_size))) 105 | x_val = np.reshape(x_val, (-1, np.prod(self.args.input_size))) 106 | 107 | x_test = np.reshape(x_test, (-1, np.prod(self.args.input_size))) 108 | 109 | if self.args.dynamic_binarization and self.no_binarization is False:  # the training split is instead binarized on the fly (torch.bernoulli) in utils/training.py 110 | x_val, x_test = self.binarize(x_val, x_test) 111 | 112 | print("data stats:") 113 | print(len(x_train), len(y_train)) 114 | print(len(x_val), len(y_val)) 115 | print(len(x_test), len(y_test)) 116 | 117 | train_loader, val_loader, test_loader = self.post_processing(x_train, x_val, x_test, 118 | y_train, y_val, y_test, **kwargs) 119 | 120 | return train_loader, val_loader, test_loader, self.args 121 |
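The logit branch of preprocessing_ above dequantizes 8-bit pixels into (lambd, 1 - lambd) and maps them to an unconstrained space; the correction subtracted in analysis.py undoes exactly this change of variables. A small sketch of the transform and its sigmoid inverse (the lambd value below is assumed for illustration; the code reads it from args.lambd):

import numpy as np

def logit(x):
    return np.log(x) - np.log1p(-x)

lambd = 1e-6                        # assumed magnitude, supplied via args.lambd in the repo
pixels = np.array([0., 128., 255.]) # raw 8-bit values
u = lambd + (1 - 2 * lambd) * (pixels + np.random.rand(3)) / 256.  # dequantized into (lambd, 1 - lambd)
y = logit(u)                        # unconstrained representation fed to the model
assert np.allclose(1 / (1 + np.exp(-y)), u)  # the sigmoid inverts the logit exactly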
-------------------------------------------------------------------------------- /utils/load_data/data_loader_instances.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import datasets 3 | import numpy as np 4 | from scipy.io import loadmat 5 | from .base_load_data import base_load_data 6 | import wget 7 | 8 | class dynamic_mnist_loader(base_load_data): 9 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 10 | super(dynamic_mnist_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization) 11 | 12 | def obtain_data(self): 13 | train = datasets.MNIST(os.path.join('datasets', self.args.dataset_name), train=True, download=True) 14 | test = datasets.MNIST(os.path.join('datasets', self.args.dataset_name), train=False)  # reuses the files fetched with the train split 15 | return train, test 16 | 17 | 18 | class fashion_mnist_loader(base_load_data): 19 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 20 | super(fashion_mnist_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization) 21 | 22 | def obtain_data(self): 23 | train = datasets.FashionMNIST(os.path.join('datasets', self.args.dataset_name), train=True, download=True) 24 | test = datasets.FashionMNIST(os.path.join('datasets', self.args.dataset_name), train=False) 25 | return train, test 26 | 27 | 28 | class svhn_loader(base_load_data): 29 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 30 | super(svhn_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization) 31 | 32 | def obtain_data(self): 33 | train = datasets.SVHN(os.path.join('datasets', self.args.dataset_name), split='train', download=True) 34 | test = datasets.SVHN(os.path.join('datasets', self.args.dataset_name), split='test', download=True) 35 | return train, test 36 | 37 | def seperate_data_from_label(self, train_dataset, test_dataset): 38 | x_train = train_dataset.data 39 | y_train = train_dataset.labels.astype(dtype=int) 40 | x_test = test_dataset.data 41 | y_test = test_dataset.labels.astype(dtype=int) 42 | return
x_train, y_train, x_test, y_test 43 | 44 | 45 | class static_mnist_loader(base_load_data): 46 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 47 | super(static_mnist_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization) 48 | 49 | def obtain_data(self): 50 | def lines_to_np_array(lines): 51 | return np.array([[int(i) for i in line.split()] for line in lines]) 52 | 53 | with open(os.path.join('datasets', self.args.dataset_name, 'binarized_mnist_train.amat')) as f: 54 | lines = f.readlines() 55 | x_train = lines_to_np_array(lines).astype('float32') 56 | with open(os.path.join('datasets', self.args.dataset_name, 'binarized_mnist_valid.amat')) as f: 57 | lines = f.readlines() 58 | x_val = lines_to_np_array(lines).astype('float32') 59 | with open(os.path.join('datasets', self.args.dataset_name, 'binarized_mnist_test.amat')) as f: 60 | lines = f.readlines() 61 | x_test = lines_to_np_array(lines).astype('float32') 62 | 63 | y_train = np.zeros((x_train.shape[0], 1)).astype(int) 64 | y_val = np.zeros((x_val.shape[0], 1)).astype(int) 65 | y_test = np.zeros((x_test.shape[0], 1)).astype(int) 66 | return (x_train, x_val, y_train, y_val), (x_test, y_test) 67 | 68 | def seperate_data_from_label(self, train_dataset, test_dataset): 69 | x_train, x_val, y_train, y_val = train_dataset 70 | x_test, y_test = test_dataset 71 | return (x_train, x_val), (y_train, y_val), x_test, y_test 72 | 73 | def preprocessing_(self, x_train, x_test): 74 | return x_train, x_test 75 | 76 | 77 | class omniglot_loader(base_load_data): 78 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 79 | super(omniglot_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization) 80 | 81 | def obtain_data(self): 82 | def reshape_data(data): 83 | return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='F') 84 | dataset_file = os.path.join('datasets', self.args.dataset_name, 'chardata.mat') 85 | if not os.path.exists(dataset_file): 86 | url = "https://raw.githubusercontent.com/yburda/iwae/master/datasets/OMNIGLOT/chardata.mat" 87 | wget.download(url, dataset_file) 88 | 89 | omni_raw = loadmat(os.path.join('datasets', self.args.dataset_name, 'chardata.mat')) 90 | 91 | x_train = reshape_data(omni_raw['data'].T.astype('float32')) 92 | x_test = reshape_data(omni_raw['testdata'].T.astype('float32')) 93 | 94 | y_train = omni_raw['targetchar'].reshape((-1, 1)) 95 | y_test = omni_raw['testtargetchar'].reshape((-1, 1)) 96 | return (x_train, y_train), (x_test, y_test) 97 | 98 | def seperate_data_from_label(self, train_dataset, test_dataset): 99 | x_train, y_train = train_dataset 100 | x_test, y_test = test_dataset 101 | return x_train, y_train, x_test, y_test 102 | 103 | def preprocessing_(self, x_train, x_test): 104 | return x_train, x_test 105 | 106 | 107 | class cifar10_loader(base_load_data): 108 | def __init__(self, args, use_fixed_validation=False, no_binarization=False): 109 | super(cifar10_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization) 110 | 111 | def obtain_data(self): 112 | training_dataset = datasets.CIFAR10(os.path.join('datasets', self.args.dataset_name), train=True, download=True) 113 | test_dataset = datasets.CIFAR10(os.path.join('datasets', self.args.dataset_name), train=False) 114 | return training_dataset, test_dataset 115 | 116 | def seperate_data_from_label(self, train_dataset, test_dataset): 117 | train_data = np.swapaxes(np.swapaxes(train_dataset.data, 1, 2), 1, 3) 118 | y_train = 
np.zeros((train_data.shape[0], 1)).astype(int) 119 | test_data = np.swapaxes(np.swapaxes(test_dataset.data, 1, 2), 1, 3) 120 | y_test = np.zeros((test_data.shape[0], 1)).astype(int) 121 | return train_data, y_train, test_data, y_test 122 | 123 | 124 | def load_dataset(args, training_num=None, use_fixed_validation=False, no_binarization=False, **kwargs): 125 | if training_num is not None: 126 | args.training_set_size = training_num 127 | if args.dataset_name == 'static_mnist': 128 | args.input_size = [1, 28, 28] 129 | args.input_type = 'binary' 130 | train_loader, val_loader, test_loader, args = static_mnist_loader(args).load_dataset(**kwargs) 131 | elif args.dataset_name == 'dynamic_mnist': 132 | if training_num is None: 133 | args.training_set_size = 50000 134 | args.input_size = [1, 28, 28] 135 | if args.continuous is True: 136 | args.input_type = 'gray' 137 | args.dynamic_binarization = False 138 | no_binarization = True 139 | else: 140 | args.input_type = 'binary' 141 | args.dynamic_binarization = True 142 | 143 | train_loader, val_loader, test_loader, args = \ 144 | dynamic_mnist_loader(args, use_fixed_validation, no_binarization=no_binarization).load_dataset(**kwargs) 145 | elif args.dataset_name == 'fashion_mnist': 146 | if training_num is None: 147 | args.training_set_size = 50000 148 | args.input_size = [1, 28, 28] 149 | 150 | if args.continuous is True: 151 | print("*****Continuous Data*****") 152 | args.input_type = 'gray' 153 | args.dynamic_binarization = False 154 | no_binarization = True 155 | else: 156 | args.input_type = 'binary' 157 | args.dynamic_binarization = True 158 | 159 | train_loader, val_loader, test_loader, args = \ 160 | fashion_mnist_loader(args, use_fixed_validation, no_binarization=no_binarization).load_dataset(**kwargs) 161 | elif args.dataset_name == 'omniglot': 162 | if training_num is None: 163 | args.training_set_size = 23000 164 | args.input_size = [1, 28, 28] 165 | args.input_type = 'binary' 166 | args.dynamic_binarization = True 167 | train_loader, val_loader, test_loader, args = omniglot_loader(args).load_dataset(**kwargs) 168 | elif args.dataset_name == 'svhn': 169 | args.training_set_size = 60000 170 | args.input_size = [3, 32, 32] 171 | args.input_type = 'continuous' 172 | train_loader, val_loader, test_loader, args = svhn_loader(args).load_dataset(**kwargs) 173 | elif args.dataset_name == 'cifar10': 174 | args.training_set_size = 40000 175 | args.input_size = [3, 32, 32] 176 | args.input_type = 'continuous' 177 | train_loader, val_loader, test_loader, args = cifar10_loader(args).load_dataset(**kwargs) 178 | else: 179 | raise Exception('Wrong name of the dataset!') 180 | print('train size', len(train_loader.dataset)) 181 | if val_loader is not None: 182 | print('val size', len(val_loader.dataset)) 183 | print('test size', len(test_loader.dataset)) 184 | return train_loader, val_loader, test_loader, args 185 | -------------------------------------------------------------------------------- /utils/nn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def xavier_init(m): 8 | s = np.sqrt( 2. / (m.in_features + m.out_features) ) 9 | m.weight.data.normal_(0, s) 10 | 11 | 12 | def he_init(m): 13 | s = np.sqrt( 2. 
/ m.in_features ) 14 | m.weight.data.normal_(0, s) 15 | 16 | 17 | def normal_init(m, mean=0., std=0.01): 18 | m.weight.data.normal_(mean, std) 19 | 20 | 21 | class CReLU(nn.Module): 22 | def __init__(self): 23 | super(CReLU, self).__init__() 24 | 25 | def forward(self, x): 26 | return torch.cat((F.relu(x), F.relu(-x)), 1)  # torch.cat takes a sequence of tensors, not separate arguments 27 | 28 | 29 | class NonLinear(nn.Module): 30 | def __init__(self, input_size, output_size, bias=True, activation=None): 31 | super(NonLinear, self).__init__() 32 | 33 | self.activation = activation 34 | self.linear = nn.Linear(int(input_size), int(output_size), bias=bias) 35 | 36 | def forward(self, x): 37 | h = self.linear(x) 38 | if self.activation is not None: 39 | h = self.activation( h ) 40 | 41 | return h 42 | 43 | 44 | class GatedDense(nn.Module): 45 | def __init__(self, input_size, output_size, activation=None, no_attention=False): 46 | super(GatedDense, self).__init__() 47 | 48 | self.activation = activation 49 | self.no_attention = no_attention 50 | self.sigmoid = nn.Sigmoid() 51 | self.h = nn.Linear(input_size, output_size) 52 | if no_attention is False: 53 | self.g = nn.Linear(input_size, output_size) 54 | else: 55 | self.activation = torch.nn.ReLU() 56 | 57 | def forward(self, x): 58 | h = self.h(x) 59 | if self.activation is not None: 60 | h = self.activation(h)  # reuse the pre-activation instead of recomputing self.h(x) 61 | # the original try/except silently fell back to the gated path; the flag makes the intent explicit 62 | if self.no_attention is False: 63 | g = self.sigmoid(self.g(x)) 64 | return h * g 65 | else: 66 | return h 67 | 68 | 69 | 70 | 71 | class GatedConv2d(nn.Module): 72 | def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None, 73 | no_attention=False): 74 | super(GatedConv2d, self).__init__() 75 | self.no_attention = no_attention 76 | 77 | self.activation = activation 78 | self.sigmoid = nn.Sigmoid() 79 | 80 | self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation) 81 | if no_attention is False: 82 | self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation) 83 | else: 84 | self.activation = torch.nn.ELU() 85 | 86 | 87 | def forward(self, x): 88 | if self.activation is None: 89 | h = self.h(x) 90 | else: 91 | h = self.activation( self.h( x ) ) 92 | 93 | if self.no_attention is False: 94 | g = self.sigmoid( self.g( x ) ) 95 | return h * g 96 | else: 97 | return h  # no gate module exists when no_attention=True; calling self.g would raise AttributeError 98 | 99 | 100 | class Conv2d(nn.Module): 101 | def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None, bias=True): 102 | super(Conv2d, self).__init__() 103 | 104 | self.activation = activation 105 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation, bias=bias) 106 | 107 | def forward(self, x): 108 | h = self.conv(x) 109 | if self.activation is None: 110 | out = h 111 | else: 112 | out = self.activation(h) 113 | 114 | return out 115 | 116 | 117 | class MaskedConv2d(nn.Conv2d): 118 | def __init__(self, mask_type, *args, **kwargs): 119 | super(MaskedConv2d, self).__init__(*args, **kwargs) 120 | assert mask_type in {'A', 'B'} 121 | self.register_buffer('mask', self.weight.data.clone()) 122 | _, _, kH, kW = self.weight.size() 123 | self.mask.fill_(1) 124 | self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0  # mask type 'A' also hides the centre pixel 125 | self.mask[:, :, kH // 2 + 1:] = 0 126 | 127 | def forward(self, x): 128 | self.weight.data *= self.mask 129 | return super(MaskedConv2d, self).forward(x) 130 | 131 | 132 | # Copyright (c) Xi Chen 133 | # 134 | # This source code is
licensed under the MIT license found in the 135 | # LICENSE file in the root directory of this source tree. 136 | 137 | # Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch 138 | 139 | from math import sqrt 140 | from functools import partial, lru_cache 141 | 142 | import numpy as np 143 | import torch 144 | from torch import nn 145 | from torch.nn import functional as F 146 | 147 | 148 | def wn_linear(in_dim, out_dim): 149 | return nn.utils.weight_norm(nn.Linear(in_dim, out_dim)) 150 | 151 | 152 | class WNConv2d(nn.Module): 153 | def __init__( 154 | self, 155 | in_channel, 156 | out_channel, 157 | kernel_size, 158 | stride=1, 159 | padding=0, 160 | bias=True, 161 | activation=None, 162 | ): 163 | super().__init__() 164 | 165 | self.conv = nn.utils.weight_norm( 166 | nn.Conv2d( 167 | in_channel, 168 | out_channel, 169 | kernel_size, 170 | stride=stride, 171 | padding=padding, 172 | bias=bias, 173 | ) 174 | ) 175 | 176 | self.out_channel = out_channel 177 | 178 | if isinstance(kernel_size, int): 179 | kernel_size = [kernel_size, kernel_size] 180 | 181 | self.kernel_size = kernel_size 182 | 183 | self.activation = activation 184 | 185 | def forward(self, input): 186 | out = self.conv(input) 187 | 188 | if self.activation is not None: 189 | out = self.activation(out) 190 | 191 | return out 192 | 193 | 194 | def shift_down(input, size=1): 195 | return F.pad(input, [0, 0, size, 0])[:, :, : input.shape[2], :] 196 | 197 | 198 | def shift_right(input, size=1): 199 | return F.pad(input, [size, 0, 0, 0])[:, :, :, : input.shape[3]] 200 | 201 | 202 | class CausalConv2d(nn.Module): 203 | def __init__( 204 | self, 205 | in_channel, 206 | out_channel, 207 | kernel_size, 208 | stride=1, 209 | padding='downright', 210 | activation=None, 211 | ): 212 | super().__init__() 213 | 214 | if isinstance(kernel_size, int): 215 | kernel_size = [kernel_size] * 2 216 | 217 | self.kernel_size = kernel_size 218 | 219 | if padding == 'downright': 220 | pad = [kernel_size[1] - 1, 0, kernel_size[0] - 1, 0] 221 | 222 | elif padding == 'down' or padding == 'causal': 223 | pad = kernel_size[1] // 2 224 | 225 | pad = [pad, pad, kernel_size[0] - 1, 0] 226 | 227 | self.causal = 0 228 | if padding == 'causal': 229 | self.causal = kernel_size[1] // 2 230 | 231 | self.pad = nn.ZeroPad2d(pad) 232 | 233 | self.conv = WNConv2d( 234 | in_channel, 235 | out_channel, 236 | kernel_size, 237 | stride=stride, 238 | padding=0, 239 | activation=activation, 240 | ) 241 | 242 | def forward(self, input): 243 | out = self.pad(input) 244 | 245 | if self.causal > 0: 246 | self.conv.conv.weight_v.data[:, :, -1, self.causal :].zero_() 247 | 248 | out = self.conv(out) 249 | 250 | return out 251 | 252 | 253 | class GatedResBlock(nn.Module): 254 | def __init__( 255 | self, 256 | in_channel, 257 | channel, 258 | kernel_size, 259 | conv='wnconv2d', 260 | activation=nn.ELU, 261 | dropout=0.1, 262 | auxiliary_channel=0, 263 | condition_dim=0, 264 | ): 265 | super().__init__() 266 | 267 | if conv == 'wnconv2d': 268 | conv_module = partial(WNConv2d, padding=kernel_size // 2) 269 | 270 | elif conv == 'causal_downright': 271 | conv_module = partial(CausalConv2d, padding='downright') 272 | 273 | elif conv == 'causal': 274 | conv_module = partial(CausalConv2d, padding='causal') 275 | 276 | self.activation = activation() 277 | self.conv1 = conv_module(in_channel, channel, kernel_size) 278 | 279 | if auxiliary_channel > 0: 280 | self.aux_conv = WNConv2d(auxiliary_channel, channel, 1) 281 | 282 | self.dropout = nn.Dropout(dropout) 283 
| 284 | self.conv2 = conv_module(channel, in_channel * 2, kernel_size) 285 | 286 | if condition_dim > 0: 287 | # self.condition = nn.Linear(condition_dim, in_channel * 2, bias=False) 288 | self.condition = WNConv2d(condition_dim, in_channel * 2, 1, bias=False) 289 | 290 | self.gate = nn.GLU(1) 291 | 292 | def forward(self, input, aux_input=None, condition=None): 293 | out = self.conv1(self.activation(input)) 294 | 295 | if aux_input is not None: 296 | out = out + self.aux_conv(self.activation(aux_input)) 297 | 298 | out = self.activation(out) 299 | out = self.dropout(out) 300 | out = self.conv2(out) 301 | 302 | if condition is not None: 303 | condition = self.condition(condition) 304 | out += condition 305 | # out = out + condition.view(condition.shape[0], 1, 1, condition.shape[1]) 306 | 307 | out = self.gate(out) 308 | out += input 309 | 310 | return out 311 | 312 | 313 | @lru_cache(maxsize=64) 314 | def causal_mask(size): 315 | shape = [size, size] 316 | mask = np.triu(np.ones(shape), k=1).astype(np.uint8).T 317 | start_mask = np.ones(size).astype(np.float32) 318 | start_mask[0] = 0 319 | 320 | return ( 321 | torch.from_numpy(mask).unsqueeze(0), 322 | torch.from_numpy(start_mask).unsqueeze(1), 323 | ) 324 | 325 | 326 | class CausalAttention(nn.Module): 327 | def __init__(self, query_channel, key_channel, channel, n_head=8, dropout=0.1): 328 | super().__init__() 329 | 330 | self.query = wn_linear(query_channel, channel) 331 | self.key = wn_linear(key_channel, channel) 332 | self.value = wn_linear(key_channel, channel) 333 | 334 | self.dim_head = channel // n_head 335 | self.n_head = n_head 336 | 337 | self.dropout = nn.Dropout(dropout) 338 | 339 | def forward(self, query, key): 340 | batch, _, height, width = key.shape 341 | 342 | def reshape(input): 343 | return input.view(batch, -1, self.n_head, self.dim_head).transpose(1, 2) 344 | 345 | query_flat = query.view(batch, query.shape[1], -1).transpose(1, 2) 346 | key_flat = key.view(batch, key.shape[1], -1).transpose(1, 2) 347 | query = reshape(self.query(query_flat)) 348 | key = reshape(self.key(key_flat)).transpose(2, 3) 349 | value = reshape(self.value(key_flat)) 350 | 351 | attn = torch.matmul(query, key) / sqrt(self.dim_head) 352 | mask, start_mask = causal_mask(height * width) 353 | mask = mask.type_as(query) 354 | start_mask = start_mask.type_as(query) 355 | attn = attn.masked_fill(mask == 0, -1e4) 356 | attn = torch.softmax(attn, 3) * start_mask 357 | attn = self.dropout(attn) 358 | 359 | out = attn @ value 360 | out = out.transpose(1, 2).reshape( 361 | batch, height, width, self.dim_head * self.n_head 362 | ) 363 | out = out.permute(0, 3, 1, 2) 364 | 365 | return out 366 | 367 | 368 | class PixelBlock(nn.Module): 369 | def __init__( 370 | self, 371 | in_channel, 372 | channel, 373 | kernel_size, 374 | n_res_block, 375 | attention=True, 376 | dropout=0.1, 377 | condition_dim=0, 378 | ): 379 | super().__init__() 380 | 381 | resblocks = [] 382 | for i in range(n_res_block): 383 | resblocks.append( 384 | GatedResBlock( 385 | in_channel, 386 | channel, 387 | kernel_size, 388 | conv='causal', 389 | dropout=dropout, 390 | condition_dim=condition_dim, 391 | ) 392 | ) 393 | 394 | self.resblocks = nn.ModuleList(resblocks) 395 | 396 | self.attention = attention 397 | 398 | if attention: 399 | self.key_resblock = GatedResBlock( 400 | in_channel * 2 + 2, in_channel, 1, dropout=dropout 401 | ) 402 | self.query_resblock = GatedResBlock( 403 | in_channel + 2, in_channel, 1, dropout=dropout 404 | ) 405 | 406 | self.causal_attention = 
CausalAttention( 407 | in_channel + 2, in_channel * 2 + 2, in_channel // 2, dropout=dropout 408 | ) 409 | 410 | self.out_resblock = GatedResBlock( 411 | in_channel, 412 | in_channel, 413 | 1, 414 | auxiliary_channel=in_channel // 2, 415 | dropout=dropout, 416 | ) 417 | 418 | else: 419 | self.out = WNConv2d(in_channel + 2, in_channel, 1) 420 | 421 | def forward(self, input, background, condition=None): 422 | out = input 423 | 424 | for resblock in self.resblocks: 425 | out = resblock(out, condition=condition) 426 | 427 | if self.attention: 428 | key_cat = torch.cat([input, out, background], 1) 429 | key = self.key_resblock(key_cat) 430 | query_cat = torch.cat([out, background], 1) 431 | query = self.query_resblock(query_cat) 432 | attn_out = self.causal_attention(query, key) 433 | out = self.out_resblock(out, attn_out) 434 | 435 | else: 436 | bg_cat = torch.cat([out, background], 1) 437 | out = self.out(bg_cat) 438 | 439 | return out 440 | 441 | 442 | class CondResNet(nn.Module): 443 | def __init__(self, in_channel, channel, kernel_size, n_res_block): 444 | super().__init__() 445 | 446 | blocks = [WNConv2d(in_channel, channel, kernel_size, padding=kernel_size // 2)] 447 | 448 | for i in range(n_res_block): 449 | blocks.append(GatedResBlock(channel, channel, kernel_size)) 450 | 451 | self.blocks = nn.Sequential(*blocks) 452 | 453 | def forward(self, input): 454 | return self.blocks(input) 455 | 456 | 457 | class PixelSNAIL(nn.Module): 458 | def __init__( 459 | self, 460 | shape, 461 | n_class, 462 | channel, 463 | kernel_size, 464 | n_block, 465 | n_res_block, 466 | res_channel, 467 | attention=True, 468 | dropout=0.1, 469 | n_cond_res_block=0, 470 | cond_res_channel=0, 471 | cond_res_kernel=3, 472 | n_out_res_block=0, 473 | ): 474 | super().__init__() 475 | 476 | height, width = shape 477 | 478 | self.n_class = n_class 479 | 480 | if kernel_size % 2 == 0: 481 | kernel = kernel_size + 1 482 | 483 | else: 484 | kernel = kernel_size 485 | 486 | self.horizontal = CausalConv2d( 487 | 3, channel, [kernel // 2, kernel], padding='down' 488 | ) 489 | self.vertical = CausalConv2d( 490 | 3, channel, [(kernel + 1) // 2, kernel // 2], padding='downright' 491 | ) 492 | 493 | coord_x = (torch.arange(height).float() - height / 2) / height 494 | coord_x = coord_x.view(1, 1, height, 1).expand(1, 1, height, width) 495 | coord_y = (torch.arange(width).float() - width / 2) / width 496 | coord_y = coord_y.view(1, 1, 1, width).expand(1, 1, height, width) 497 | self.register_buffer('background', torch.cat([coord_x, coord_y], 1)) 498 | 499 | self.blocks = nn.ModuleList() 500 | 501 | for i in range(n_block): 502 | self.blocks.append( 503 | PixelBlock( 504 | channel, 505 | res_channel, 506 | kernel_size, 507 | n_res_block, 508 | attention=attention, 509 | dropout=dropout, 510 | condition_dim=cond_res_channel, 511 | ) 512 | ) 513 | 514 | if n_cond_res_block > 0: 515 | self.cond_resnet = CondResNet( 516 | n_class, cond_res_channel, cond_res_kernel, n_cond_res_block 517 | ) 518 | 519 | out = [] 520 | 521 | for i in range(n_out_res_block): 522 | out.append(GatedResBlock(channel, res_channel, 1)) 523 | 524 | out.extend([nn.ELU(inplace=True), WNConv2d(channel, n_class, 1)]) 525 | 526 | self.out = nn.Sequential(*out) 527 | 528 | def forward(self, input, condition=None, cache=None): 529 | if cache is None: 530 | cache = {} 531 | batch, _, height, width = input.shape 532 | #input = ( 533 | # F.one_hot(input, self.n_class).permute(0, 3, 1, 2).type_as(self.background) 534 | #) 535 | horizontal = 
shift_down(self.horizontal(input)) 536 | vertical = shift_right(self.vertical(input)) 537 | out = horizontal + vertical 538 | 539 | background = self.background[:, :, :height, :].expand(batch, 2, height, width) 540 | 541 | if condition is not None: 542 | if 'condition' in cache: 543 | condition = cache['condition'] 544 | condition = condition[:, :, :height, :] 545 | 546 | else: 547 | condition = ( 548 | F.one_hot(condition, self.n_class) 549 | .permute(0, 3, 1, 2) 550 | .type_as(self.background) 551 | ) 552 | condition = self.cond_resnet(condition) 553 | condition = F.interpolate(condition, scale_factor=2) 554 | cache['condition'] = condition.detach().clone() 555 | condition = condition[:, :, :height, :] 556 | 557 | for block in self.blocks: 558 | out = block(out, background, condition=condition) 559 | 560 | out = self.out(out) 561 | 562 | return out 563 | 564 | 565 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.optim import Optimizer 4 | import math 5 | 6 | 7 | class AdamNormGrad(Optimizer): 8 | """Implements Adam algorithm. 9 | 10 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 11 | 12 | Arguments: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float, optional): learning rate (default: 1e-3) 16 | betas (Tuple[float, float], optional): coefficients used for computing 17 | running averages of gradient and its square (default: (0.9, 0.999)) 18 | eps (float, optional): term added to the denominator to improve 19 | numerical stability (default: 1e-8) 20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 21 | 22 | .. _Adam\: A Method for Stochastic Optimization: 23 | https://arxiv.org/abs/1412.6980 24 | """ 25 | 26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 27 | weight_decay=0): 28 | defaults = dict(lr=lr, betas=betas, eps=eps, 29 | weight_decay=weight_decay) 30 | super(AdamNormGrad, self).__init__(params, defaults) 31 | 32 | def step(self, closure=None): 33 | """Performs a single optimization step. 34 | 35 | Arguments: 36 | closure (callable, optional): A closure that reevaluates the model 37 | and returns the loss. 
38 | """ 39 | loss = None 40 | if closure is not None: 41 | loss = closure() 42 | 43 | for group in self.param_groups: 44 | for p in group['params']: 45 | if p.grad is None: 46 | continue 47 | grad = p.grad.data 48 | # normalize grdients 49 | grad = grad / ( torch.norm(grad,2) + 1.e-7 ) 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | # Exponential moving average of gradient values 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | # Exponential moving average of squared gradient values 58 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 59 | 60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 61 | beta1, beta2 = group['betas'] 62 | 63 | state['step'] += 1 64 | 65 | if group['weight_decay'] != 0: 66 | grad = grad.add(group['weight_decay'], p.data) 67 | 68 | # Decay the first and second moment running average coefficient 69 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 70 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 71 | 72 | denom = exp_avg_sq.sqrt().add_(group['eps']) 73 | 74 | bias_correction1 = 1 - beta1 ** state['step'] 75 | bias_correction2 = 1 - beta2 ** state['step'] 76 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 77 | 78 | p.data.addcdiv_(-step_size, exp_avg, denom) 79 | 80 | return loss 81 | -------------------------------------------------------------------------------- /utils/plot_images.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.gridspec as gridspec 3 | import numpy as np 4 | import os 5 | 6 | 7 | def imshow(img, title=None, interpolation=None, show_plot=False): 8 | npimg = img.detach().cpu().numpy() 9 | plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation=interpolation) 10 | if title is not None: 11 | plt.title(title) 12 | if show_plot: 13 | plt.show() 14 | 15 | 16 | def generate_fancy_grid(config, dir, reference_data, generated, col_num=4, row_num=3): 17 | import cv2 18 | 19 | image_size = config.input_size[-1] 20 | width = col_num*image_size+2 21 | height = row_num*image_size+2 22 | 23 | print('references', reference_data.shape) 24 | print('generated', generated.shape) 25 | 26 | generated_dir = os.path.join(dir, 'generated/') 27 | os.makedirs(generated_dir, exist_ok=True) 28 | 29 | for k in range(len(reference_data)): 30 | grid = np.ones((config.input_size[0], height, width)) 31 | original_image = reference_data[k].reshape(1, *config.input_size).cpu().detach().numpy() 32 | grid[:, 0:image_size, 0:image_size] = original_image 33 | generated_images = generated[k].reshape(-1, *config.input_size).cpu().detach().numpy() 34 | offset = 2 35 | counts = 0 36 | for i in range(row_num): 37 | j_counts = col_num 38 | extra_offset = 0 39 | if i == 0: 40 | j_counts = col_num-1 41 | extra_offset = image_size 42 | 43 | row = i*image_size+offset 44 | for j in range(j_counts): 45 | generated_images[counts] 46 | grid[:, row:row+image_size, extra_offset+j*image_size+offset:extra_offset+(j+1)*image_size+offset] = generated_images[counts] 47 | counts += 1 48 | 49 | if config.input_size[0] > 1: 50 | grid = np.transpose(grid, (1, 2, 0)) 51 | grid = np.squeeze(grid) 52 | plt.imsave(arr=np.clip(grid, 0, 1), 53 | fname=generated_dir + "generated_{}.png".format(k), 54 | cmap='gray', format='png') 55 | 56 | img = cv2.imread(generated_dir + "generated_{}.png".format(k)) 57 | res = cv2.resize(img, dsize=(width*3, height*3), interpolation=cv2.INTER_NEAREST) 58 | 
-------------------------------------------------------------------------------- /utils/plot_images.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.gridspec as gridspec 3 | import numpy as np 4 | import os 5 | 6 | 7 | def imshow(img, title=None, interpolation=None, show_plot=False): 8 | npimg = img.detach().cpu().numpy() 9 | plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation=interpolation) 10 | if title is not None: 11 | plt.title(title) 12 | if show_plot: 13 | plt.show() 14 | 15 | 16 | def generate_fancy_grid(config, dir, reference_data, generated, col_num=4, row_num=3): 17 | import cv2 18 | 19 | image_size = config.input_size[-1] 20 | width = col_num*image_size+2 21 | height = row_num*image_size+2 22 | 23 | print('references', reference_data.shape) 24 | print('generated', generated.shape) 25 | 26 | generated_dir = os.path.join(dir, 'generated/') 27 | os.makedirs(generated_dir, exist_ok=True) 28 | 29 | for k in range(len(reference_data)): 30 | grid = np.ones((config.input_size[0], height, width)) 31 | original_image = reference_data[k].reshape(1, *config.input_size).cpu().detach().numpy() 32 | grid[:, 0:image_size, 0:image_size] = original_image 33 | generated_images = generated[k].reshape(-1, *config.input_size).cpu().detach().numpy() 34 | offset = 2 35 | counts = 0 36 | for i in range(row_num): 37 | j_counts = col_num 38 | extra_offset = 0 39 | if i == 0: 40 | j_counts = col_num-1 41 | extra_offset = image_size 42 | 43 | row = i*image_size+offset 44 | for j in range(j_counts): 45 | # place generated image number `counts` into the next grid cell 46 | grid[:, row:row+image_size, extra_offset+j*image_size+offset:extra_offset+(j+1)*image_size+offset] = generated_images[counts] 47 | counts += 1 48 | 49 | if config.input_size[0] > 1: 50 | grid = np.transpose(grid, (1, 2, 0)) 51 | grid = np.squeeze(grid) 52 | plt.imsave(arr=np.clip(grid, 0, 1), 53 | fname=generated_dir + "generated_{}.png".format(k), 54 | cmap='gray', format='png') 55 | 56 | img = cv2.imread(generated_dir + "generated_{}.png".format(k)) 57 | res = cv2.resize(img, dsize=(width*3, height*3), interpolation=cv2.INTER_NEAREST) 58 | cv2.imwrite(generated_dir + "generated_{}.png".format(k), res) 59 | # plt.show() 60 | 61 | 62 | def plot_images_in_line(images, args, dir, file_name): 63 | import cv2 64 | 65 | width = len(images) * 28  # hard-coded for 28x28 images 66 | height = 28 67 | grid = np.ones((height, width)) 68 | for index, image in enumerate(images): 69 | image = image.reshape(*args.input_size).cpu().detach().numpy() 70 | grid[0:28, 28*index:28*(index+1)] = image[0] 71 | file_name = os.path.join(dir, file_name) 72 | plt.imsave(arr=grid / 255,  # assumes 8-bit (0-255) pixel values; drop the division for data already in [0, 1] 73 | fname=file_name, 74 | cmap='gray', format='png') 75 | 76 | img = cv2.imread(file_name) 77 | res = cv2.resize(img, dsize=(width*3, height*3), interpolation=cv2.INTER_NEAREST) 78 | cv2.imwrite(file_name, res) 79 | 80 | 81 | def plot_images(config, x_sample, dir, file_name, size_x=3, size_y=3): 82 | if len(x_sample.shape) < 4: 83 | x_sample = x_sample.reshape(-1, *config.input_size) 84 | fig = plt.figure(figsize=(size_x, size_y)) 85 | # fig = plt.figure(1) 86 | gs = gridspec.GridSpec(size_x, size_y) 87 | gs.update(wspace=0.01, hspace=0.01) 88 | 89 | for i, sample in enumerate(x_sample): 90 | ax = plt.subplot(gs[i]) 91 | plt.axis('off') 92 | ax.set_xticklabels([]) 93 | ax.set_yticklabels([]) 94 | ax.set_aspect('equal') 95 | 96 | sample = sample.swapaxes(0, 2) 97 | sample = sample.swapaxes(0, 1) 98 | if config.input_type == 'binary' or config.input_type == 'gray': 99 | sample = sample[:, :, 0] 100 | plt.imshow(sample, cmap='gray') 101 | else: 102 | plt.imshow(sample) 103 | 104 | plt.savefig(dir + file_name + '.png', bbox_inches='tight') 105 | plt.close(fig) 106 | 107 | 108 |
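For orientation, a toy call to plot_images above: it accepts flattened samples, reshapes them via config.input_size, and writes a size_x-by-size_y grid (the SimpleNamespace config is a stand-in for the repo's parsed arguments):

import numpy as np
from types import SimpleNamespace
from utils.plot_images import plot_images

config = SimpleNamespace(input_size=[1, 28, 28], input_type='binary')
samples = np.random.rand(9, 784)                                     # nine fake flattened images
plot_images(config, samples, './', 'demo_grid', size_x=3, size_y=3)  # writes ./demo_grid.png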
-------------------------------------------------------------------------------- /utils/training.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | 4 | 5 | def set_beta(args, epoch): 6 | if args.warmup == 0: 7 | beta = 1. 8 | else: 9 | beta = 1. * epoch / args.warmup  # linear KL warm-up over the first args.warmup epochs 10 | if beta > 1.: 11 | beta = 1. 12 | return beta 13 | 14 | 15 | def train_one_epoch(epoch, args, train_loader, model, optimizer): 16 | train_loss, train_re, train_kl = 0, 0, 0 17 | model.train() 18 | beta = set_beta(args, epoch) 19 | print('beta: {}'.format(beta)) 20 | if args.approximate_prior is True: 21 | with torch.no_grad(): 22 | cached_z, cached_log_var = model.cache_z(train_loader.dataset) 23 | cache = (cached_z, cached_log_var) 24 | else: 25 | cache = None 26 | 27 | for batch_idx, (data, indices, target) in enumerate(train_loader): 28 | data, indices, target = data.to(args.device), indices.to(args.device), target.to(args.device) 29 | 30 | if args.dynamic_binarization: 31 | x = torch.bernoulli(data)  # resample a fresh binarization of the batch every epoch 32 | else: 33 | x = data 34 | 35 | x = (x, indices) 36 | optimizer.zero_grad() 37 | loss, RE, KL = model.calculate_loss(x, beta, average=True, cache=cache, dataset=train_loader.dataset) 38 | loss.backward() 39 | optimizer.step() 40 | 41 | with torch.no_grad(): 42 | train_loss += loss.data.item() 43 | train_re += -RE.data.item() 44 | train_kl += KL.data.item() 45 | if cache is not None: 46 | cache = (cache[0].detach(), cache[1].detach())  # drop the graph so the cached latents do not accumulate history 47 | 48 | train_loss /= len(train_loader) 49 | train_re /= len(train_loader) 50 | train_kl /= len(train_loader) 51 | return train_loss, train_re, train_kl 52 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def importing_model(args): 5 | if args.model_name == 'vae': 6 | from models.VAE import VAE 7 | elif args.model_name == 'hvae_2level': 8 | from models.HVAE_2level import VAE 9 | elif args.model_name == 'convhvae_2level': 10 | from models.convHVAE_2level import VAE 11 | elif args.model_name == 'new_vae': 12 | from models.new_vae import VAE  # note: no models/new_vae.py appears in the repository tree 13 | elif args.model_name == 'single_conv': 14 | from models.fully_conv import VAE 15 | elif args.model_name == 'pixelcnn': 16 | from models.PixelCNN import VAE 17 | else: 18 | raise Exception('Wrong name of the model!') 19 | return VAE 20 | 21 | 22 | def save_model(save_path, load_path, content): 23 | torch.save(content, save_path) 24 | os.rename(save_path, load_path)  # write-then-rename keeps the previous checkpoint intact if saving is interrupted 25 | 26 | 27 | def load_model(load_path, model, optimizer=None): 28 | checkpoint = torch.load(load_path) 29 | model.load_state_dict(checkpoint['state_dict']) 30 | if optimizer is not None: 31 | optimizer.load_state_dict(checkpoint['optimizer']) 32 | return checkpoint 33 | --------------------------------------------------------------------------------
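Finally, the checkpointing pair in utils.py composes as below: save_model writes to a temporary path and renames it into place, and load_model restores both the model and, optionally, the optimizer. A small round-trip sketch (the paths and the toy model are illustrative, not taken from the repo's scripts):

import torch
import torch.nn as nn
from utils.utils import save_model, load_model

model = nn.Linear(3, 3)
optimizer = torch.optim.Adam(model.parameters())
content = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
save_model('checkpoint.tmp', 'checkpoint.pth', content)      # write, then rename into place
checkpoint = load_model('checkpoint.pth', model, optimizer)  # restores both state dicts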