├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── analysis.py
├── datasets
│   ├── dynamic_mnist
│   │   └── readme.me
│   ├── fashion_mnist
│   │   └── readme.me
│   └── omniglot
│       └── readme.me
├── density_estimation.py
├── images
│   ├── augmentation.gif
│   ├── celebA_exemplar_generation.png
│   ├── cyclic_generation.png
│   ├── data_augmentation.png
│   ├── density_estimation.png
│   ├── exemplar_generation.png
│   └── full_generation.png
├── models
│   ├── AbsHModel.py
│   ├── AbsModel.py
│   ├── BaseModel.py
│   ├── HVAE_2level.py
│   ├── PixelCNN.py
│   ├── VAE.py
│   ├── __init__.py
│   ├── convHVAE_2level.py
│   └── fully_conv.py
├── pretrained_model
│   └── exemplar_prior_on_dynamic_mnist_model_name=vae
│       └── 1
│           ├── checkpoint.pth
│           ├── checkpoint_best.pth
│           ├── generated
│           │   ├── generated_0.png
│           │   ├── generated_1.png
│           │   ├── generated_10.png
│           │   ├── generated_11.png
│           │   ├── generated_12.png
│           │   ├── generated_13.png
│           │   ├── generated_14.png
│           │   ├── generated_15.png
│           │   ├── generated_16.png
│           │   ├── generated_17.png
│           │   ├── generated_18.png
│           │   ├── generated_19.png
│           │   ├── generated_2.png
│           │   ├── generated_20.png
│           │   ├── generated_21.png
│           │   ├── generated_22.png
│           │   ├── generated_23.png
│           │   ├── generated_24.png
│           │   ├── generated_25.png
│           │   ├── generated_26.png
│           │   ├── generated_27.png
│           │   ├── generated_28.png
│           │   ├── generated_29.png
│           │   ├── generated_3.png
│           │   ├── generated_30.png
│           │   ├── generated_31.png
│           │   ├── generated_32.png
│           │   ├── generated_33.png
│           │   ├── generated_34.png
│           │   ├── generated_35.png
│           │   ├── generated_36.png
│           │   ├── generated_37.png
│           │   ├── generated_38.png
│           │   ├── generated_39.png
│           │   ├── generated_4.png
│           │   ├── generated_40.png
│           │   ├── generated_41.png
│           │   ├── generated_42.png
│           │   ├── generated_43.png
│           │   ├── generated_44.png
│           │   ├── generated_45.png
│           │   ├── generated_46.png
│           │   ├── generated_47.png
│           │   ├── generated_48.png
│           │   ├── generated_49.png
│           │   ├── generated_5.png
│           │   ├── generated_6.png
│           │   ├── generated_7.png
│           │   ├── generated_8.png
│           │   └── generated_9.png
│           ├── generations_0.png
│           ├── real.png
│           ├── reconstructions.png
│           ├── vae.config
│           ├── vae.test_kl
│           ├── vae.test_log_likelihood
│           ├── vae.test_loss
│           ├── vae.test_re
│           ├── vae.train_kl
│           ├── vae.train_loss
│           ├── vae.train_re
│           ├── vae.val_kl
│           ├── vae.val_loss
│           ├── vae.val_re
│           ├── vae_config.txt
│           ├── vae_experiment_log.txt
│           └── whole_log.txt
├── requirements.txt
└── utils
    ├── __init__.py
    ├── classify_data.py
    ├── distributions.py
    ├── evaluation.py
    ├── knn_on_latent.py
    ├── load_data
    │   ├── __init__.py
    │   ├── base_load_data.py
    │   └── data_loader_instances.py
    ├── nn.py
    ├── optimizer.py
    ├── plot_images.py
    ├── training.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/
2 | snapshots/*
3 | models/__pycache__
4 | utils/load_data/__pycache__
5 | utils/__pycache__
6 | checkpoints/
7 | __pycache__
8 | .idea/*
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Jakub Tomczak
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Exemplar-VAE
2 | Code for reproducing the results of the [Exemplar VAE](https://arxiv.org/abs/2004.04795) paper, accepted at NeurIPS 2020.
3 |
4 | ## Requirements
5 | ```
6 | pip3 install -r requirements.txt
7 | ```
8 | ## Exemplar VAE Samples
9 |
10 |
11 |
12 | ## Exemplar Based Generation
13 | ```
14 | python3 analysis.py --dir pretrained_model --generate
15 | ```
16 |
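This samples 50 random training exemplars and generates 11 new images from each (see the `--generate` branch of `analysis.py`).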
17 |
18 |
19 |
20 | ## Density Estimation
21 | ```
22 | python3 density_estimation.py --prior exemplar_prior --dataset_name {dynamic_mnist|fashion_mnist|omniglot} --model_name {vae|hvae_2level|convhvae_2level} --number_components {25000|11500} --approximate_prior {True|False}
23 | ```
24 |
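For example, to train an Exemplar VAE on dynamically binarized MNIST with the exact (non-approximate) exemplar prior, leaving the remaining hyper-parameters at the defaults in `density_estimation.py`:
```
python3 density_estimation.py --prior exemplar_prior --dataset_name dynamic_mnist --model_name vae --number_components 25000 --approximate_prior False
```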
25 |
26 |
27 | ## Data Augmentation
28 | ```
29 | python3 analysis.py --dir pretrained_model --classify
30 | ```
31 |
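The classifier hyper-parameters exposed by `analysis.py` (e.g. `--hyper_lambda`, `--lr`, `--epochs`) can be overridden on the command line; the values below are the defaults:
```
python3 analysis.py --dir pretrained_model --classify --hyper_lambda 0.4 --lr 0.1 --epochs 100
```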
32 |
33 |
34 |
35 | ## Cyclic Generation
36 | ```
37 | python3 analysis.py --dir pretrained_model --cyclic_generation
38 | ```
39 |
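Each of the 10 chains starts from uniform random noise and repeatedly feeds the previous sample back in as the exemplar, producing 30 images per chain (see `cyclic_generation` in `analysis.py`).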
40 |
41 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/__init__.py
--------------------------------------------------------------------------------
/analysis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from utils.load_data.data_loader_instances import load_dataset
4 | import torchvision
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import os
8 | from utils.plot_images import imshow
9 | from utils.utils import load_model
10 | from utils.classify_data import classify_data
11 | from utils.knn_on_latent import report_knn_on_latent, extract_full_data
12 | from utils.evaluation import compute_mean_variance_per_dimension
13 | from utils.plot_images import plot_images_in_line, generate_fancy_grid
14 | from utils.utils import importing_model
15 | from sklearn.manifold import TSNE
16 | import copy
17 | from pylab import rcParams
18 |
19 |
20 | parser = argparse.ArgumentParser(description='Exemplar VAE analysis')
21 | parser.add_argument('--KNN', action='store_true', default=False, help='run KNN classification on latent')
22 | parser.add_argument('--generate', action='store_true', default=False, help='generate images')
23 | parser.add_argument('--classify', action='store_true', default=False,
24 | help='train a classifier on data with augmentation')
25 | parser.add_argument('--dir', type=str, default='pretrained_model', help='directory of the pretrained model')
26 | parser.add_argument('--just_log_likelihood', action='store_true', default=False)
27 | parser.add_argument('--cyclic_generation', action='store_true', default=False, help='cyclic generation')
28 | parser.add_argument('--training_set_size', default=50000, type=int)
29 | parser.add_argument('--hyper_lambda', type=float, default=0.4, help='proportion of real data to augmented data')
30 | parser.add_argument('--lr', type=float, default=0.1)
31 | parser.add_argument('--batch_size', type=int, default=100)
32 | parser.add_argument('--input_size', type=int, nargs=3, default=[1, 28, 28])
33 | parser.add_argument('--count_active_dimensions', action='store_true', default=False)
34 | parser.add_argument('--grid_interpolation', action='store_true', default=False)
35 | parser.add_argument('--tsne_visualization', action='store_true', default=False)
36 | parser.add_argument('--hidden_units', type=int, default=1024)
37 | parser.add_argument('--save_model_path', type=str, default='')
38 | parser.add_argument('--classification_dir', type=str, default='classification_report')
39 | parser.add_argument('--epochs', type=int, default=100)
40 | parser.add_argument('--seed', type=int, default=1)
41 | args = parser.parse_args()
42 |
43 | print(args)
44 |
45 | TRAIN_NUM = 50000
46 |
47 |
48 | def plot_data(data, labels):
49 | k = 10
50 | print(data.shape)
51 | subplot_num = data.shape[1]
52 | for i in range(subplot_num):
53 | plt.subplot2grid((subplot_num, 1), (i, 0), colspan=1, rowspan=1)
54 | imshow(torchvision.utils.make_grid(data[:k, i, :].view(-1, 1, 28, 28)))
55 | plt.axis('off')
56 | print(labels[:k, i, :].squeeze())
57 | plt.show()
58 |
59 | directory = args.dir
60 |
61 |
62 | def grid_interpolation_in_latent(model, dir, index, reference_image):
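    # Decode a 5x5 grid of latent codes obtained by offsetting the first two
    # dimensions of the reference image's posterior mean.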
63 | z, _ = model.q_z(reference_image.to(args.device), prior=True)
64 | whole_generation = []
65 | for offset_0 in range(-2, 3, 1):
66 | row_generation = []
67 | for offset_1 in range(-2, 3, 1):
68 | new_z = copy.deepcopy(z)
69 | new_z[0][0] += offset_0*3
70 | new_z[0][1] += offset_1*3
71 | image = model.generate_x_from_z(new_z, with_reparameterize=False)
72 | row_generation.append(image)
73 | whole_generation.append(torch.cat(row_generation, dim=0))
74 | # print("LENNN", len(whole_generation))
75 | whole_generation = torch.cat(whole_generation, dim=0)
76 | print('whole_generation shape', whole_generation.shape)
77 | imshow(torchvision.utils.make_grid(whole_generation.reshape(-1, *model.args.input_size), nrow=5))
78 | save_dir = os.path.join(dir, 'grid_interpolation')
79 | os.makedirs(save_dir, exist_ok=True)
80 | plt.axis('off')
81 |     plt.savefig(os.path.join(save_dir, 'interpolation{}'.format(index)), bbox_inches='tight')
82 |
83 |
84 | def compute_test_metrics(test_log_likelihood, test_kl, test_re):
85 | test_log_likelihood.append(torch.load(dir + model_name + '.test_log_likelihood'))
86 |
87 | kl = torch.load(dir + model_name + '.test_kl')
88 | if type(kl) == torch.Tensor:
89 | kl = kl.cpu().numpy()
90 | test_kl.append(kl)
91 |
92 | reconst = torch.load(dir + model_name + '.test_re')
93 | if type(reconst) == torch.Tensor:
94 | reconst = reconst.cpu().numpy()
95 | test_re.append(reconst)
96 |
97 |
98 | def cyclic_generation(start_data, dir, index):
99 | cyclic_generation_dir = os.path.join(dir, 'cyclic_generation')
100 | os.makedirs(cyclic_generation_dir, exist_ok=True)
101 | single_data = start_data.unsqueeze(0)
102 | generated_cycle = [single_data.to(args.device)]
103 | for i in range(29):
104 | single_data = \
105 | model.reference_based_generation_x(N=1, reference_image=single_data)
106 | generated_cycle.append(single_data)
107 |
108 | generated_cycle = torch.cat(generated_cycle, dim=0)
109 | plot_images_in_line(generated_cycle, args, cyclic_generation_dir, 'cycle_{}.png'.format(index))
110 |
111 |
112 | temp = ''
113 | active_units_text = ''
114 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
115 |
116 | for folder in sorted(os.listdir(directory)):
117 | if os.path.isdir(directory+'/'+folder) is False:
118 | continue
119 | knn_results = []
120 | test_log_likelihoods, test_kl, test_reconst, active_dimensions = [], [], [], []
121 | knn_dictionary = {'3': [], '5': [], '7': [], '9': [], '11': [], '13': [], '15': []}
122 |
123 |
124 | torch.manual_seed(args.seed)
125 | if args.device=='cuda':
126 | torch.cuda.manual_seed(args.seed)
127 | np.random.seed(args.seed)
128 |
129 | for filename in os.listdir(directory+'/'+folder):
130 | print('filename**', filename)
131 | dir = directory + '/' + folder+'/'+filename + '/'
132 | model_name_start_index = folder.find('model_name=')
133 | model_name = folder[model_name_start_index + len('model_name='):]
134 | print("MODEL NAME", model_name)
135 |
136 | config = torch.load(dir + model_name + '.config')
137 | config.device = args.device
138 | VAE = importing_model(config)
139 | model = VAE(config)
140 | model.to(args.device)
141 | train_loader, val_loader, test_loader, config = load_dataset(config,
142 | training_num=args.training_set_size,
143 | no_binarization=True)
144 |
145 | if args.just_log_likelihood is False:
146 | load_model(dir + 'checkpoint_best.pth', model)
147 | model.eval()
148 | try:
149 | print('prior variance', model.prior_log_variance.item())
150 |             except AttributeError:  # only exemplar-prior models define prior_log_variance
151 | pass
152 |
153 | if args.cyclic_generation:
154 | with torch.no_grad():
155 | for i in range(10):
156 | random_image = torch.rand([784])
157 | cyclic_generation(random_image, dir, index=i)
158 |
159 | if args.KNN:
160 | with torch.no_grad():
161 | report_knn_on_latent(train_loader, val_loader, test_loader, model,
162 | dir, knn_dictionary, args, val=False)
163 | if args.generate:
164 | with torch.no_grad():
165 | exemplars_n = 50
166 | selected_indices = torch.randint(low=0, high=config.training_set_size, size=(exemplars_n,))
167 |                 reference_images, indices, labels = train_loader.dataset[selected_indices]
168 | per_exemplar = 11
169 | generated = model.reference_based_generation_x(N=per_exemplar, reference_image=reference_images)
170 | generated = generated.reshape(-1, per_exemplar, *config.input_size)
171 | rcParams['figure.figsize'] = 4, 3
172 | generated_dir = dir + 'generated/'
173 | if config.use_logit:
174 | reference_images = model.logit_inverse(reference_images)
175 | generate_fancy_grid(config, dir, reference_images, generated)
176 |
177 | if args.count_active_dimensions:
178 | train_loader, val_loader, test_loader, config = load_dataset(config,
179 | training_num=args.training_set_size,
180 | no_binarization=False)
181 | with torch.no_grad():
182 | num_active = compute_mean_variance_per_dimension(args, model, test_loader)
183 | active_dimensions.append(num_active)
184 |
185 | #TODO remove loop
186 | if args.grid_interpolation:
187 | with torch.no_grad():
188 | for i in range(100):
189 | image = train_loader.dataset.tensors[0][torch.randint(low=0, high=args.training_set_size,
190 | size=(1,))]
191 | grid_interpolation_in_latent(model, dir, i, reference_image=image)
192 |
193 | if args.tsne_visualization:
194 | test_x, _, test_labels = extract_full_data(test_loader)
195 | test_z, _ = model.q_z(test_x.to(args.device))
196 | tsne = TSNE(n_components=2)
197 | plt_colors = np.array(
198 | ['blue', 'orange', 'green', 'red', 'cyan', 'pink', 'purple', 'brown', 'gray', 'olive'])
199 |
200 | points_to_visualize = tsne.fit_transform(X=test_z.detach().cpu().numpy())
201 | plt.scatter(points_to_visualize[:, 0], points_to_visualize[:, 1],
202 | c=plt_colors[test_labels.cpu().numpy()], s=2)
203 | plt.savefig(dir+'tsne.png')
204 | plt.show()
205 |
206 | if args.classify:
207 | test_acc = []
208 | val_acc = []
209 | test_acc_single_run, val_acc_single_run = classify_data(train_loader, val_loader, test_loader,
210 | args.classification_dir, args, model)
211 | test_acc.append(test_acc_single_run)
212 | val_acc.append(val_acc_single_run)
213 | test_acc = np.array(test_acc)
214 | val_acc = np.array(val_acc)
215 |
216 | print('averaged test accuracy: {0:.2f} \\pm {1:.2f}'.format(np.mean(test_acc), np.std(test_acc)))
217 | print('averaged val accuracy: {0:.2f} \\pm {1:.2f}'.format(np.mean(val_acc), np.std(val_acc)))
218 | exit()
219 | else:
220 | compute_test_metrics(test_log_likelihoods, test_kl, test_reconst)
221 |
222 | if args.just_log_likelihood:
223 | test_log_likelihoods = np.array(test_log_likelihoods)
224 | print("test log-likelihood", np.mean(test_log_likelihoods), np.std(test_log_likelihoods))
225 |
226 | if args.count_active_dimensions:
227 | active_dimensions = np.array(active_dimensions).astype(float)
228 | print(np.mean(active_dimensions), np.std(active_dimensions))
229 |
--------------------------------------------------------------------------------
/datasets/dynamic_mnist/readme.me:
--------------------------------------------------------------------------------
1 | Will be downloaded automatically by torchvision.
2 |
3 |
--------------------------------------------------------------------------------
/datasets/fashion_mnist/readme.me:
--------------------------------------------------------------------------------
1 | Will be downloaded automatically by torchvision.
2 |
--------------------------------------------------------------------------------
/datasets/omniglot/readme.me:
--------------------------------------------------------------------------------
1 | Download the OMNIGLOT data from https://github.com/yburda/iwae/tree/master/datasets/OMNIGLOT and place it in this folder.
2 |
--------------------------------------------------------------------------------
/density_estimation.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import datetime
4 | from utils.load_data.data_loader_instances import load_dataset
5 | from utils.utils import importing_model
6 | import torch
7 | import math
8 | import os
9 | from utils.utils import save_model, load_model
10 | from utils.optimizer import AdamNormGrad
11 | import time
12 | from utils.training import train_one_epoch
13 | from utils.evaluation import evaluate_loss, final_evaluation
14 | import random
15 |
16 | def str2bool(v):
17 | if isinstance(v, bool):
18 | return v
19 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
20 | return True
21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
22 | return False
23 | else:
24 | raise argparse.ArgumentTypeError('Boolean value expected.')
25 |
26 |
27 | parser = argparse.ArgumentParser(description='Exemplar VAE density estimation')
28 | parser.add_argument('--batch_size', type=int, default=100, metavar='BStrain',
29 | help='input batch size for training (default: 100)')
30 | parser.add_argument('--test_batch_size', type=int, default=100, metavar='BStest',
31 | help='input batch size for testing (default: 100)')
32 | parser.add_argument('--epochs', type=int, default=2000, metavar='E',
33 | help='number of epochs to train (default: 2000)')
34 | parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
35 | help='learning rate (default: 0.0005)')
36 | parser.add_argument('--early_stopping_epochs', type=int, default=50, metavar='ES',
37 | help='number of epochs for early stopping')
38 | parser.add_argument('--z1_size', type=int, default=40, metavar='M1',
39 | help='latent size')
40 | parser.add_argument('--z2_size', type=int, default=40, metavar='M2',
41 | help='latent size')
42 | parser.add_argument('--input_size', type=int, nargs=3, default=[1, 28, 28], metavar='D',
43 | help='input size')
44 | parser.add_argument('--number_components', type=int, default=50000, metavar='NC',
45 | help='number of pseudo-inputs')
46 | parser.add_argument('--pseudoinputs_mean', type=float, default=-0.05, metavar='PM',
47 | help='mean for init pseudo-inputs')
48 | parser.add_argument('--pseudoinputs_std', type=float, default=0.01, metavar='PS',
49 | help='std for init pseudo-inputs')
50 | parser.add_argument('--use_training_data_init', action='store_true', default=False,
51 | help='initialize pseudo-inputs with randomly chosen training data')
52 | parser.add_argument('--model_name', type=str, default='vae', metavar='MN',
53 | help='model name: vae, hvae_2level, convhvae_2level')
54 | parser.add_argument('--prior', type=str, default='vampprior', metavar='P',
55 | help='prior: standard, vampprior, exemplar_prior')
56 | parser.add_argument('--input_type', type=str, default='binary', metavar='IT',
57 | help='type of the input: binary, gray, continuous, pca')
58 | parser.add_argument('--S', type=int, default=5000, metavar='SLL',
59 | help='number of samples used for approximating log-likelihood,'
60 |                          ' i.e. the number of samples in IWAE')
61 | parser.add_argument('--MB', type=int, default=100, metavar='MBLL',
62 | help='size of a mini-batch used for approximating log-likelihood')
63 | parser.add_argument('--use_whole_train', type=str2bool, default=False,
64 |                     help='use all training data points at test time')
65 | parser.add_argument('--dataset_name', type=str, default='freyfaces', metavar='DN',
66 | help='name of the dataset: static_mnist, dynamic_mnist, omniglot, caltech101silhouettes,'
67 | ' histopathologyGray, freyfaces, cifar10')
68 | parser.add_argument('--dynamic_binarization', action='store_true', default=False,
69 | help='allow dynamic binarization')
70 | parser.add_argument('--seed', type=int, default=14, metavar='S',
71 | help='random seed (default: 14)')
72 |
73 | parser.add_argument('--no_mask', action='store_true', default=False, help='no leave one out')
74 |
75 | parser.add_argument('--parent_dir', type=str, default='')
76 | parser.add_argument('--same_variational_var', type=str2bool, default=False,
77 |                     help='use the same variance for all dimensions')
78 | parser.add_argument('--model_signature', type=str, default='', help='load from this directory and continue training')
79 | parser.add_argument('--warmup', type=int, default=100, metavar='WU',
80 |                     help='number of epochs for warm-up')
81 | parser.add_argument('--slurm_task_id', type=str, default='')
82 | parser.add_argument('--slurm_job_id', type=str, default='')
83 | parser.add_argument('--approximate_prior', type=str2bool, default=False)
84 | parser.add_argument('--just_evaluate', type=str2bool, default=False)
85 | parser.add_argument('--no_attention', type=str2bool, default=False)
86 | parser.add_argument('--approximate_k', type=int, default=10)
87 | parser.add_argument('--hidden_size', type=int, default=300)
88 | parser.add_argument('--base_dir', type=str, default='snapshots/')
89 | parser.add_argument('--continuous', type=str2bool, default=False)
90 | parser.add_argument('--use_logit', type=str2bool, default=False)
91 | parser.add_argument('--lambd', type=float, default=1e-4)
92 | parser.add_argument('--bottleneck', type=int, default=6)
93 | parser.add_argument('--training_set_size', type=int, default=50000)
94 |
95 |
96 | def initial_or_load(checkpoint_path_load, model, optimizer, dir):
97 | if os.path.exists(checkpoint_path_load):
98 | model_loaded_str = "******model is loaded*********"
99 | print(model_loaded_str)
100 | with open(dir + 'whole_log.txt', 'a') as f:
101 | print(model_loaded_str, file=f)
102 | checkpoint = load_model(checkpoint_path_load, model, optimizer)
103 | begin_epoch = checkpoint['epoch']
104 | best_loss = checkpoint['best_loss']
105 | e = checkpoint['e']
106 | else:
107 | torch.manual_seed(args.seed)
108 | if args.device=='cuda':
109 | torch.cuda.manual_seed(args.seed)
110 | random.seed(args.seed)
111 | begin_epoch = 1
112 | best_loss = math.inf
113 | e = 0
114 | return begin_epoch, best_loss, e
115 |
116 |
117 | def save_loss_files(folder, train_loss_history,
118 | train_re_history, train_kl_history, val_loss_history, val_re_history, val_kl_history):
119 | torch.save(train_loss_history, folder + '.train_loss')
120 | torch.save(train_re_history, folder + '.train_re')
121 | torch.save(train_kl_history, folder + '.train_kl')
122 | torch.save(val_loss_history, folder + '.val_loss')
123 | torch.save(val_re_history, folder + '.val_re')
124 | torch.save(val_kl_history, folder + '.val_kl')
125 |
126 |
127 | def run_density_estimation(args, train_loader_input, val_loader_input, test_loader_input, model, optimizer, dir, model_name='vae'):
128 | torch.save(args, dir + args.model_name + '.config')
129 | train_loss_history, train_re_history, train_kl_history, val_loss_history, val_re_history, val_kl_history, \
130 | time_history = [], [], [], [], [], [], []
131 | checkpoint_path_save = os.path.join(dir, 'checkpoint_temp.pth')
132 | checkpoint_path_load = os.path.join(dir, 'checkpoint.pth')
133 | best_model_path_load = os.path.join(dir, 'checkpoint_best.pth')
134 | decayed = False
135 | time_history = []
136 | # with torch.autograd.detect_anomaly():
137 | begin_epoch, best_loss, e = initial_or_load(checkpoint_path_load, model, optimizer, dir)
138 | if args.just_evaluate is False:
139 | for epoch in range(begin_epoch, args.epochs + 1):
140 | time_start = time.time()
141 | train_loss_epoch, train_re_epoch, train_kl_epoch \
142 | = train_one_epoch(epoch, args, train_loader_input, model, optimizer)
143 | with torch.no_grad():
144 | val_loss_epoch, val_re_epoch, val_kl_epoch = evaluate_loss(args, model, val_loader_input,
145 | dataset=train_loader_input.dataset)
146 | time_end = time.time()
147 | time_elapsed = time_end - time_start
148 | content = {'epoch': epoch, 'state_dict': model.state_dict(),
149 | 'optimizer': optimizer.state_dict(), 'best_loss': best_loss, 'e': e}
150 | if epoch % 10 == 0:
151 | save_model(checkpoint_path_save, checkpoint_path_load, content)
152 | if val_loss_epoch < best_loss:
153 | e = 0
154 | best_loss = val_loss_epoch
155 | print('->model saved<-')
156 | save_model(checkpoint_path_save, best_model_path_load, content)
157 | else:
158 | e += 1
159 | if epoch < args.warmup:
160 | e = 0
161 | if e > args.early_stopping_epochs:
162 | break
163 |
164 | if math.isnan(val_loss_epoch):
165 | print("***** val loss is Nan *******")
166 | break
167 |
168 | for param_group in optimizer.param_groups:
169 | learning_rate = param_group['lr']
170 | break
171 |
172 | time_history.append(time_elapsed)
173 |
174 | epoch_report = 'Epoch: {}/{}, Time elapsed: {:.2f}s\n' \
175 | 'learning rate: {:.5f}\n' \
176 | '* Train loss: {:.2f} (RE: {:.2f}, KL: {:.2f})\n' \
177 | 'o Val. loss: {:.2f} (RE: {:.2f}, KL: {:.2f})\n' \
178 | '--> Early stopping: {}/{} (BEST: {:.2f})\n'.format(epoch, args.epochs, time_elapsed,
179 | learning_rate,
180 | train_loss_epoch, train_re_epoch,
181 | train_kl_epoch, val_loss_epoch,
182 | val_re_epoch, val_kl_epoch, e,
183 | args.early_stopping_epochs, best_loss)
184 |
185 | if args.prior == 'exemplar_prior':
186 | print("Prior Variance", model.prior_log_variance.item())
187 | if args.continuous is True:
188 | print("Decoder Variance", model.decoder_logstd.item())
189 | print(epoch_report)
190 | with open(dir + 'whole_log.txt', 'a') as f:
191 | print(epoch_report, file=f)
192 |
193 | train_loss_history.append(train_loss_epoch), train_re_history.append(
194 | train_re_epoch), train_kl_history.append(train_kl_epoch)
195 | val_loss_history.append(val_loss_epoch), val_re_history.append(val_re_epoch), val_kl_history.append(
196 | val_kl_epoch)
197 |
198 | save_loss_files(dir + args.model_name, train_loss_history,
199 | train_re_history, train_kl_history, val_loss_history, val_re_history, val_kl_history)
200 |
201 | with torch.no_grad():
202 | final_evaluation(train_loader_input, test_loader_input, val_loader_input,
203 | best_model_path_load, model, optimizer, args, dir)
204 |
205 |
206 | def run(args, kwargs):
207 | print('create model')
208 | # importing model
209 | VAE = importing_model(args)
210 | print('load data')
211 | train_loader, val_loader, test_loader, args = load_dataset(args, use_fixed_validation=True, **kwargs)
212 | if args.slurm_job_id != '':
213 | args.model_signature = str(args.seed)
214 | # base_dir = 'checkpoints/final_report/'
215 | elif args.model_signature == '':
216 | args.model_signature = str(datetime.datetime.now())[0:19]
217 |
218 | if args.parent_dir == '':
219 | args.parent_dir = args.prior + '_on_' + args.dataset_name+'_model_name='+args.model_name
220 | model_name = args.dataset_name + '_' + args.model_name + '_' + args.prior \
221 | + '_(components_' + str(args.number_components) + ', lr=' + str(args.lr) + ')'
222 | snapshots_path = os.path.join(args.base_dir, args.parent_dir) + '/'
223 | dir = snapshots_path + args.model_signature + '_' + model_name + '_' + args.parent_dir + '/'
224 |
225 | if args.just_evaluate:
226 | config = torch.load(dir + args.model_name + '.config')
227 | config.translation = False
228 | config.hidden_size = 300
229 | model = VAE(config)
230 | else:
231 | model = VAE(args)
232 | if not os.path.exists(dir):
233 | os.makedirs(dir)
234 | model.to(args.device)
235 | optimizer = AdamNormGrad(model.parameters(), lr=args.lr)
236 | print(args)
237 | config_file = dir+'vae_config.txt'
238 | with open(config_file, 'a') as f:
239 | print(args, file=f)
240 |     run_density_estimation(args, train_loader, val_loader, test_loader, model, optimizer, dir, model_name=args.model_name)
241 |
242 |
243 | if __name__ == "__main__":
244 | args = parser.parse_args()
245 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
246 |
247 | kwargs = {'num_workers': 2, 'pin_memory': True} if args.device=='cuda' else {}
248 | run(args, kwargs)
249 |
250 |
--------------------------------------------------------------------------------
/images/augmentation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/augmentation.gif
--------------------------------------------------------------------------------
/images/celebA_exemplar_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/celebA_exemplar_generation.png
--------------------------------------------------------------------------------
/images/cyclic_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/cyclic_generation.png
--------------------------------------------------------------------------------
/images/data_augmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/data_augmentation.png
--------------------------------------------------------------------------------
/images/density_estimation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/density_estimation.png
--------------------------------------------------------------------------------
/images/exemplar_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/exemplar_generation.png
--------------------------------------------------------------------------------
/images/full_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/images/full_generation.png
--------------------------------------------------------------------------------
/models/AbsHModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | from utils.distributions import log_normal_diag
6 | from .BaseModel import BaseModel
7 |
8 |
9 | class BaseHModel(BaseModel):
10 | def __init__(self, args):
11 | super(BaseHModel, self).__init__(args)
12 |
13 | def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
14 | z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = latent_stats
15 | if exemplars_embedding is None and self.args.prior == 'exemplar_prior':
16 | exemplars_embedding = self.get_exemplar_set(z2_q_mean, z2_q_logvar,
17 | dataset, cache, x_indices)
18 | log_p_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size),
19 | z1_p_mean.view(-1, self.args.z1_size),
20 | z1_p_logvar.view(-1, self.args.z1_size), dim=1)
21 | log_q_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size),
22 | z1_q_mean.view(-1, self.args.z1_size),
23 | z1_q_logvar.view(-1, self.args.z1_size), dim=1)
24 | log_p_z2 = self.log_p_z(z=(z2_q, x_indices),
25 | exemplars_embedding=exemplars_embedding)
26 | log_q_z2 = log_normal_diag(z2_q.view(-1, self.args.z2_size),
27 | z2_q_mean.view(-1, self.args.z2_size),
28 | z2_q_logvar.view(-1, self.args.z2_size), dim=1)
29 | return -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)
30 |
31 | def generate_x_from_z(self, z, with_reparameterize=True):
32 | z1_sample_mean, z1_sample_logvar = self.p_z1(z)
33 | if with_reparameterize:
34 | z1_sample_rand = self.reparameterize(z1_sample_mean, z1_sample_logvar)
35 | else:
36 | z1_sample_rand = z1_sample_mean
37 |
38 | if self.args.model_name=='pixelcnn':
39 | generated_xs = self.pixelcnn_generate(z1_sample_rand.view(-1, self.args.z1_size), z.reshape(-1, self.args.z2_size))
40 | else:
41 | generated_xs, _ = self.p_x(z1_sample_rand.view(-1, self.args.z1_size),
42 | z.view(-1, self.args.z2_size))
43 | return generated_xs
44 |
45 | def p_z1(self, z2):
46 | z2 = self.p_z1_layers_z2(z2)
47 | z1_p_mean = self.p_z1_mean(z2)
48 | z1_p_logvar = self.p_z1_logvar(z2)
49 | return z1_p_mean, z1_p_logvar
50 |
51 | def q_z1(self, x, z2):
52 | x = self.q_z1_layers_x(x)
53 | if self.args.model_name == 'convhvae_2level' or self.args.model_name == 'pixelcnn':
54 | x = x.view(x.size(0),-1)
55 | z2 = self.q_z1_layers_z2(z2)
56 | h = torch.cat((x,z2),1)
57 | h = self.q_z1_layers_joint(h)
58 | z1_q_mean = self.q_z1_mean(h)
59 | z1_q_logvar = self.q_z1_logvar(h)
60 | return z1_q_mean, z1_q_logvar
61 |
62 | def p_x(self, z1, z2, x=None):
63 | z1 = self.p_x_layers_z1(z1)
64 |
65 | z2 = self.p_x_layers_z2(z2)
66 |
67 | if self.args.model_name == 'pixelcnn':
68 | z2 = z2.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])
69 | z1 = z1.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])
70 | h = torch.cat((x, z1, z2), 1)
71 | # pixelcnn part of the decoder
72 | h_decoder = self.pixelcnn(h)
73 |
74 | else:
75 |
76 | h = torch.cat((z1, z2), 1)
77 | if 'convhvae_2level' in self.args.model_name:
78 | h = self.p_x_layers_joint_pre(h)
79 | h = h.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])
80 |
81 | h_decoder = self.p_x_layers_joint(h)
82 | x_mean = self.p_x_mean(h_decoder)
83 | if 'convhvae_2level' in self.args.model_name or self.args.model_name=='pixelcnn':
84 | x_mean = x_mean.view(-1, np.prod(self.args.input_size))
85 |
86 | if self.args.input_type == 'binary':
87 | x_logvar = 0.
88 | else:
89 | x_mean = torch.clamp(x_mean, min=0.+1./512., max=1.-1./512.)
90 | x_logvar = self.p_x_logvar(h_decoder)
91 | if 'convhvae_2level' in self.args.model_name or self.args.model_name=='pixelcnn':
92 | x_logvar = x_logvar.view(-1, np.prod(self.args.input_size))
93 |
94 | return x_mean, x_logvar
95 |
96 | def forward(self, x):
97 | z2_q_mean, z2_q_logvar = self.q_z(x)
98 | z2_q = self.reparameterize(z2_q_mean, z2_q_logvar)
99 | z1_q_mean, z1_q_logvar = self.q_z1(x, z2_q)
100 | z1_q = self.reparameterize(z1_q_mean, z1_q_logvar)
101 | z1_p_mean, z1_p_logvar = self.p_z1(z2_q)
102 | if self.args.model_name == 'pixelcnn':
103 | x_mean, x_logvar = self.p_x(z1_q, z2_q, x=x)
104 | else:
105 | x_mean, x_logvar = self.p_x(z1_q, z2_q)
106 | return x_mean, x_logvar, (z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar)
107 |
108 |
--------------------------------------------------------------------------------
/models/AbsModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | from models.BaseModel import BaseModel
6 | from utils.distributions import log_normal_diag
7 |
8 |
9 | class AbsModel(BaseModel):
10 | def __init__(self, args):
11 | super(AbsModel, self).__init__(args)
12 |
13 | def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
14 | z_q, z_q_mean, z_q_logvar = latent_stats
15 | if exemplars_embedding is None and self.args.prior == 'exemplar_prior':
16 | exemplars_embedding = self.get_exemplar_set(z_q_mean, z_q_logvar, dataset, cache, x_indices)
17 | log_p_z = self.log_p_z(z=(z_q, x_indices), exemplars_embedding=exemplars_embedding)
18 | log_q_z = log_normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
19 | return -(log_p_z - log_q_z)
20 |
21 | def generate_x_from_z(self, z, with_reparameterize=True):
22 | generated_x, _ = self.p_x(z)
23 | try:
24 | if self.args.use_logit is True:
25 | return self.logit_inverse(generated_x)
26 | else:
27 | return generated_x
28 |         except AttributeError:  # older configs may not define use_logit
29 | return generated_x
30 |
31 | def p_x(self, z):
32 | if 'conv' in self.args.model_name:
33 | z = z.reshape(-1, self.bottleneck, self.args.input_size[1]//4, self.args.input_size[1]//4)
34 | z = self.p_x_layers(z)
35 | x_mean = self.p_x_mean(z)
36 | if self.args.input_type == 'binary':
37 | x_logvar = torch.zeros(1, np.prod(self.args.input_size))
38 | else:
39 | if self.args.use_logit is False:
40 | x_mean = torch.clamp(x_mean, min=0.+1./512., max=1.-1./512.)
41 | x_logvar = self.decoder_logstd*x_mean.new_ones(size=x_mean.shape)
42 | return x_mean.reshape(-1, np.prod(self.args.input_size)), x_logvar.reshape(-1, np.prod(self.args.input_size))
43 |
44 | def forward(self, x, label=0, num_categories=10):
45 | z_q_mean, z_q_logvar = self.q_z(x)
46 |
47 | z_q = self.reparameterize(z_q_mean, z_q_logvar)
48 | x_mean, x_logvar = self.p_x(z_q)
49 | return x_mean, x_logvar, (z_q, z_q_mean, z_q_logvar)
50 |
--------------------------------------------------------------------------------
/models/BaseModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | from utils.nn import normal_init, NonLinear
8 | from utils.distributions import log_normal_diag_vectorized
9 | import math
10 | from utils.nn import he_init
11 | from utils.distributions import pairwise_distance
12 | from utils.distributions import log_bernoulli, log_normal_diag, log_normal_standard, log_logistic_256
13 | from abc import ABC, abstractmethod
14 |
15 |
16 | class BaseModel(nn.Module, ABC):
17 | def __init__(self, args):
18 | super(BaseModel, self).__init__()
19 | print("constructor")
20 | self.args = args
21 |
22 | if self.args.prior == 'vampprior':
23 | self.add_pseudoinputs()
24 |
25 | if self.args.prior == 'exemplar_prior':
26 | self.prior_log_variance = torch.nn.Parameter(torch.randn((1)))
27 |
28 | if self.args.input_type == 'binary':
29 | self.p_x_mean = NonLinear(self.args.hidden_size, np.prod(self.args.input_size), activation=nn.Sigmoid())
30 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
31 | self.p_x_mean = NonLinear(self.args.hidden_size, np.prod(self.args.input_size))
32 | self.p_x_logvar = NonLinear(self.args.hidden_size, np.prod(self.args.input_size),
33 | activation=nn.Hardtanh(min_val=-4.5, max_val=0))
34 | self.decoder_logstd = torch.nn.Parameter(torch.tensor([0.], requires_grad=True))
35 |
36 | self.create_model(args)
37 | self.he_initializer()
38 |
39 | def he_initializer(self):
40 | print("he initializer")
41 |
42 | for m in self.modules():
43 | if isinstance(m, nn.Linear):
44 | he_init(m)
45 |
46 | @abstractmethod
47 | def create_model(self, args):
48 | pass
49 |
50 | @abstractmethod
51 |     def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
52 | pass
53 |
54 | def reconstruction_loss(self, x, x_mean, x_logvar):
55 | if self.args.input_type == 'binary':
56 | return log_bernoulli(x, x_mean, dim=1)
57 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
58 | if self.args.use_logit is True:
59 | return log_normal_diag(x, x_mean, x_logvar, dim=1)
60 | else:
61 | return log_logistic_256(x, x_mean, x_logvar, dim=1)
62 | else:
63 | raise Exception('Wrong input type!')
64 |
65 | def calculate_loss(self, x, beta=1., average=False,
66 | exemplars_embedding=None, cache=None, dataset=None):
67 | x, x_indices = x
68 | x_mean, x_logvar, latent_stats = self.forward(x)
69 | RE = self.reconstruction_loss(x, x_mean, x_logvar)
70 | KL = self.kl_loss(latent_stats, exemplars_embedding, dataset, cache, x_indices)
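        # Negative (beta-weighted) ELBO: maximizing RE - beta*KL is equivalent
        # to minimizing -RE + beta*KL.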
71 | loss = -RE + beta*KL
72 | if average:
73 | loss = torch.mean(loss)
74 | RE = torch.mean(RE)
75 | KL = torch.mean(KL)
76 |
77 | return loss, RE, KL
78 |
79 | def reparameterize(self, mu, logvar):
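        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
        # keeping the sample differentiable w.r.t. mu and logvar.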
80 | std = logvar.mul(0.5).exp_()
81 | eps = mu.new_empty(size=std.shape).normal_()
82 | return eps.mul(std).add_(mu)
83 |
84 | def log_p_z_vampprior(self, z, exemplars_embedding):
85 | if exemplars_embedding is None:
86 | C = self.args.number_components
87 | X = self.means(self.idle_input)
88 | z_p_mean, z_p_logvar = self.q_z(X, prior=True) # C x M
89 | else:
90 | C = torch.tensor(self.args.number_components).float()
91 | z_p_mean, z_p_logvar = exemplars_embedding
92 |
93 | z_expand = z.unsqueeze(1)
94 | means = z_p_mean.unsqueeze(0)
95 | logvars = z_p_logvar.unsqueeze(0)
96 | return log_normal_diag(z_expand, means, logvars, dim=2) - math.log(C)
97 |
98 | def log_p_z_exemplar(self, z, z_indices, exemplars_embedding, test):
99 | centers, center_log_variance, center_indices = exemplars_embedding
100 | denominator = torch.tensor(len(centers)).expand(len(z)).float().to(self.args.device)
101 | center_log_variance = center_log_variance[0, :].unsqueeze(0)
102 | prob, _ = log_normal_diag_vectorized(z, centers, center_log_variance) # MB x C
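        # Leave-one-out during training: mask each training point's own
        # exemplar out of its prior (log-probability -inf) and shrink the
        # normalizing constant by the number of masked components.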
103 | if test is False and self.args.no_mask is False:
104 | mask = z_indices.expand(-1, len(center_indices)) \
105 | == center_indices.squeeze().unsqueeze(0).expand(len(z_indices), -1)
106 | prob.masked_fill_(mask, value=float('-inf'))
107 | denominator = denominator - mask.sum(dim=1).float()
108 | prob -= torch.log(denominator).unsqueeze(1)
109 | return prob
110 |
111 | def log_p_z(self, z, exemplars_embedding, sum=True, test=None):
112 | z, z_indices = z
113 | if test is None:
114 | test = not self.training
115 | if self.args.prior == 'standard':
116 | return log_normal_standard(z, dim=1)
117 | elif self.args.prior == 'vampprior':
118 | prob = self.log_p_z_vampprior(z, exemplars_embedding)
119 | elif self.args.prior == 'exemplar_prior':
120 | prob = self.log_p_z_exemplar(z, z_indices, exemplars_embedding, test)
121 | else:
122 | raise Exception('Wrong name of the prior!')
123 | if sum:
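            # Log of the exemplar/component mixture via the log-sum-exp trick:
            # subtract the per-row max before exponentiating for stability.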
124 | prob_max, _ = torch.max(prob, 1) # MB x 1
125 | log_prior = prob_max + torch.log(torch.sum(torch.exp(prob - prob_max.unsqueeze(1)), 1)) # MB x 1
126 | else:
127 | return prob
128 | return log_prior
129 |
130 | def add_pseudoinputs(self):
131 | nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
132 | self.means = NonLinear(self.args.number_components, np.prod(self.args.input_size), bias=False, activation=nonlinearity)
133 | # init pseudo-inputs
134 | if self.args.use_training_data_init:
135 | self.means.linear.weight.data = self.args.pseudoinputs_mean
136 | else:
137 | normal_init(self.means.linear, self.args.pseudoinputs_mean, self.args.pseudoinputs_std)
138 | self.idle_input = Variable(torch.eye(self.args.number_components, self.args.number_components), requires_grad=False)
139 | self.idle_input = self.idle_input.to(self.args.device)
140 |
141 | def generate_z_interpolate(self, exemplars_embedding=None, dim=0):
142 | new_zs = []
143 | exemplars_embedding, _, _ = exemplars_embedding
144 | step_counts = 10
145 | step = (exemplars_embedding[1] - exemplars_embedding[0])/step_counts
146 | for i in range(step_counts):
147 | new_z = exemplars_embedding[0].clone()
148 | new_z += i*step
149 | new_zs.append(new_z.unsqueeze(0))
150 | return torch.cat(new_zs, dim=0)
151 |
152 | def generate_z(self, N=25, dataset=None):
153 | if self.args.prior == 'standard':
154 | z_sample_rand = torch.FloatTensor(N, self.args.z1_size).normal_().to(self.args.device)
155 | elif self.args.prior == 'vampprior':
156 | means = self.means(self.idle_input)[0:N]
157 | z_sample_gen_mean, z_sample_gen_logvar = self.q_z(means)
158 | z_sample_rand = self.reparameterize(z_sample_gen_mean, z_sample_gen_logvar)
159 | z_sample_rand = z_sample_rand.to(self.args.device)
160 | elif self.args.prior == 'exemplar_prior':
161 | rand_indices = torch.randint(low=0, high=self.args.training_set_size, size=(N,))
162 | exemplars = dataset.tensors[0][rand_indices]
163 | z_sample_gen_mean, z_sample_gen_logvar = self.q_z(exemplars.to(self.args.device), prior=True)
164 | z_sample_rand = self.reparameterize(z_sample_gen_mean, z_sample_gen_logvar)
165 | z_sample_rand = z_sample_rand.to(self.args.device)
166 | return z_sample_rand
167 |
168 | def reference_based_generation_z(self, N=25, reference_image=None):
169 | pseudo, log_var = self.q_z(reference_image.to(self.args.device), prior=True)
170 | pseudo = pseudo.unsqueeze(1).expand(-1, N, -1).reshape(-1, pseudo.shape[-1])
171 | log_var = log_var[0].unsqueeze(0).expand(len(pseudo), -1)
172 | z_sample_rand = self.reparameterize(pseudo, log_var)
173 | z_sample_rand = z_sample_rand.reshape(-1, N, pseudo.shape[1])
174 | return z_sample_rand
175 |
176 | def reconstruct_x(self, x):
177 | x_reconstructed, _, z = self.forward(x)
178 | if self.args.model_name == 'pixelcnn':
179 | x_reconstructed = self.pixelcnn_generate(z[0].reshape(-1, self.args.z1_size), z[3].reshape(-1, self.args.z2_size))
180 | return x_reconstructed
181 |
182 | def logit_inverse(self, x):
183 | sigmoid = torch.nn.Sigmoid()
184 | lambd = self.args.lambd
185 | return ((sigmoid(x) - lambd)/(1-2*lambd))
186 |
187 | def generate_x(self, N=25, dataset=None):
188 | z2_sample_rand = self.generate_z(N=N, dataset=dataset)
189 | return self.generate_x_from_z(z2_sample_rand)
190 |
191 | def reference_based_generation_x(self, N=25, reference_image=None):
192 | z2_sample_rand = \
193 | self.reference_based_generation_z(N=N, reference_image=reference_image)
194 | generated_x = self.generate_x_from_z(z2_sample_rand)
195 | return generated_x
196 |
197 | def generate_x_interpolate(self, exemplars_embedding, dim=0):
198 | zs = self.generate_z_interpolate(exemplars_embedding, dim=dim)
199 | print(zs.shape)
200 | return self.generate_x_from_z(zs, with_reparameterize=False)
201 |
202 | def reshape_variance(self, variance, shape):
203 | return variance[0]*torch.ones(shape).to(self.args.device)
204 |
205 | def q_z(self, x, prior=False):
206 | if 'conv' in self.args.model_name or 'pixelcnn'==self.args.model_name:
207 | x = x.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])
208 | h = self.q_z_layers(x)
209 | if self.args.model_name == 'convhvae_2level' or self.args.model_name=='pixelcnn':
210 | h = h.view(x.size(0), -1)
211 | z_q_mean = self.q_z_mean(h)
212 | if prior is True:
213 | if self.args.prior == 'exemplar_prior':
214 | z_q_logvar = self.prior_log_variance * torch.ones((x.shape[0], self.args.z1_size)).to(self.args.device)
215 | if self.args.model_name == 'newconvhvae_2level':
216 | z_q_logvar = z_q_logvar.reshape(-1, 4, 4, 4)
217 | else:
218 | z_q_logvar = self.q_z_logvar(h)
219 | else:
220 | z_q_logvar = self.q_z_logvar(h)
221 | return z_q_mean.reshape(-1, self.args.z1_size), z_q_logvar.reshape(-1, self.args.z1_size)
222 |
223 | def cache_z(self, dataset, prior=True, cuda=True):
224 | cached_z = []
225 | cached_log_var = []
226 | caching_batch_size = 10000
227 | num_batchs = math.ceil(len(dataset) / caching_batch_size)
228 | for i in range(num_batchs):
229 | if len(dataset[0]) == 3:
230 | batch_data, batch_indices, _ = dataset[i * caching_batch_size:(i + 1) * caching_batch_size]
231 | else:
232 | batch_data, _ = dataset[i * caching_batch_size:(i + 1) * caching_batch_size]
233 |
234 | exemplars_embedding, log_variance_z = self.q_z(batch_data.to(self.args.device), prior=prior)
235 | cached_z.append(exemplars_embedding)
236 | cached_log_var.append(log_variance_z)
237 | cached_z = torch.cat(cached_z, dim=0)
238 | cached_log_var = torch.cat(cached_log_var, dim=0)
239 | cached_z = cached_z.to(self.args.device)
240 | cached_log_var = cached_log_var.to(self.args.device)
241 | return cached_z, cached_log_var
242 |
243 | def get_exemplar_set(self, z_mean, z_log_var, dataset, cache, x_indices):
244 | if self.args.approximate_prior is False:
245 | exemplars_indices = torch.randint(low=0, high=self.args.training_set_size,
246 | size=(self.args.number_components, ))
247 | exemplars_z, log_variance = self.q_z(dataset.tensors[0][exemplars_indices].to(self.args.device), prior=True)
248 | exemplar_set = (exemplars_z, log_variance, exemplars_indices.to(self.args.device))
249 | else:
250 | exemplar_set = self.get_approximate_nearest_exemplars(
251 | z=(z_mean, z_log_var, x_indices),
252 | dataset=dataset,
253 | cache=cache)
254 | return exemplar_set
255 |
256 | def get_approximate_nearest_exemplars(self, z, cache, dataset):
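        # Approximate prior: refresh the cached latents for the current batch,
        # draw a random candidate subset of exemplars, and keep only the
        # approximate_k nearest neighbours of each z.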
257 | exemplars_indices = torch.randint(low=0, high=self.args.training_set_size,
258 | size=(self.args.number_components, )).to(self.args.device)
259 | z, _, indices = z
260 | cached_z, cached_log_variance = cache
261 | cached_z[indices.reshape(-1)] = z
262 | sub_cache = cached_z[exemplars_indices, :]
263 | _, nearest_indices = pairwise_distance(z, sub_cache) \
264 | .topk(k=self.args.approximate_k, largest=False, dim=1)
265 | nearest_indices = torch.unique(nearest_indices.view(-1))
266 | exemplars_indices = exemplars_indices[nearest_indices].view(-1)
267 | exemplars = dataset.tensors[0][exemplars_indices].to(self.args.device)
268 | exemplars_z, log_variance = self.q_z(exemplars, prior=True)
269 | cached_z[exemplars_indices] = exemplars_z
270 | exemplar_set = (exemplars_z, log_variance, exemplars_indices)
271 | return exemplar_set
272 |
273 |
274 |
--------------------------------------------------------------------------------
/models/HVAE_2level.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | import torch.nn as nn
6 | from torch.nn import Linear
7 | from utils.nn import GatedDense, NonLinear
8 | from models.AbsHModel import BaseHModel
9 |
10 |
11 | class VAE(BaseHModel):
12 | def __init__(self, args):
13 | super(VAE, self).__init__(args)
14 |
15 | def create_model(self, args):
16 | print("create_model")
17 |
18 |         # because the superclass constructor uses h_size
19 | self.args = args
20 |
21 | # encoder: q(z2 | x)
22 | self.q_z_layers = nn.Sequential(
23 | GatedDense(np.prod(self.args.input_size), self.args.hidden_size),
24 | GatedDense(self.args.hidden_size, self.args.hidden_size)
25 | )
26 |
27 | self.q_z_mean = Linear(self.args.hidden_size, self.args.z2_size)
28 |
29 | if args.same_variational_var:
30 | self.q_z_logvar = torch.nn.Parameter(torch.randn((1)))
31 | else:
32 | self.q_z_logvar = NonLinear(self.args.hidden_size, self.args.z2_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
33 |
34 | # encoder: q(z1 | x, z2)
35 | self.q_z1_layers_x = nn.Sequential(
36 | GatedDense(np.prod(self.args.input_size), self.args.hidden_size)
37 | )
38 | self.q_z1_layers_z2 = nn.Sequential(
39 | GatedDense(self.args.z2_size, self.args.hidden_size)
40 | )
41 | self.q_z1_layers_joint = nn.Sequential(
42 | GatedDense(2 * self.args.hidden_size, self.args.hidden_size)
43 | )
44 |
45 | self.q_z1_mean = Linear(self.args.hidden_size, self.args.z1_size)
46 | self.q_z1_logvar = NonLinear(self.args.hidden_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
47 |
48 | # decoder: p(z1 | z2)
49 | self.p_z1_layers_z2 = nn.Sequential(
50 | GatedDense(self.args.z2_size, self.args.hidden_size),
51 | GatedDense(self.args.hidden_size, self.args.hidden_size)
52 | )
53 |
54 | self.p_z1_mean = Linear(self.args.hidden_size, self.args.z1_size)
55 | self.p_z1_logvar = NonLinear(self.args.hidden_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6.,max_val=2.))
56 |
57 | # decoder: p(x | z1, z2)
58 | self.p_x_layers_z1 = nn.Sequential(
59 | GatedDense(self.args.z1_size, self.args.hidden_size)
60 | )
61 | self.p_x_layers_z2 = nn.Sequential(
62 | GatedDense(self.args.z2_size, self.args.hidden_size)
63 | )
64 | self.p_x_layers_joint = nn.Sequential(
65 | GatedDense(2 * self.args.hidden_size, self.args.hidden_size)
66 | )
67 |
68 |
69 |
--------------------------------------------------------------------------------
/models/PixelCNN.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch.nn as nn
4 | from utils.nn import GatedDense, NonLinear, \
5 | Conv2d, GatedConv2d, MaskedConv2d, PixelSNAIL
6 | from models.AbsHModel import BaseHModel
7 | import torch
8 |
9 | class VAE(BaseHModel):
10 | def __init__(self, args):
11 | super(VAE, self).__init__(args)
12 |
13 | def create_model(self, args):
14 | if args.dataset_name == 'freyfaces':
15 | self.h_size = 210
16 | elif args.dataset_name == 'cifar10' or args.dataset_name == 'svhn':
17 | self.h_size = 384
18 | else:
19 | self.h_size = 294
20 |
21 | # encoder: q(z2 | x)
22 | self.q_z_layers = nn.Sequential(
23 | GatedConv2d(self.args.input_size[0], 32, 7, 1, 3),
24 | GatedConv2d(32, 32, 3, 2, 1),
25 | GatedConv2d(32, 64, 5, 1, 2),
26 | GatedConv2d(64, 64, 3, 2, 1),
27 | GatedConv2d(64, 6, 3, 1, 1)
28 | )
29 | # linear layers
30 | self.q_z_mean = NonLinear(self.h_size, self.args.z2_size, activation=None)
31 | self.q_z_logvar = NonLinear(self.h_size, self.args.z2_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
32 |
33 | # encoder: q(z1|x,z2)
34 | # PROCESSING x
35 | self.q_z1_layers_x = nn.Sequential(
36 | GatedConv2d(self.args.input_size[0], 32, 3, 1, 1),
37 | GatedConv2d(32, 32, 3, 2, 1),
38 | GatedConv2d(32, 64, 3, 1, 1),
39 | GatedConv2d(64, 64, 3, 2, 1),
40 | GatedConv2d(64, 6, 3, 1, 1)
41 | )
42 | # PROCESSING Z2
43 | self.q_z1_layers_z2 = nn.Sequential(
44 | GatedDense(self.args.z2_size, self.h_size)
45 | )
46 | # PROCESSING JOINT
47 | self.q_z1_layers_joint = nn.Sequential(
48 | GatedDense( 2 * self.h_size, 300)
49 | )
50 | # linear layers
51 | self.q_z1_mean = NonLinear(300, self.args.z1_size, activation=None)
52 | self.q_z1_logvar = NonLinear(300, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
53 |
54 | # decoder p(z1|z2)
55 | self.p_z1_layers_z2 = nn.Sequential(
56 | GatedDense(self.args.z2_size, 300),
57 | GatedDense(300, 300)
58 | )
59 | self.p_z1_mean = NonLinear(300, self.args.z1_size, activation=None)
60 | self.p_z1_logvar = NonLinear(300, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
61 |
62 | # decoder: p(x | z)
63 | self.p_x_layers_z1 = nn.Sequential(
64 | GatedDense(self.args.z1_size, np.prod(self.args.input_size))
65 | )
66 | self.p_x_layers_z2 = nn.Sequential(
67 | GatedDense(self.args.z2_size, np.prod(self.args.input_size))
68 | )
69 |
70 | # decoder: p(x | z)
71 | act = nn.ReLU(True)
72 | #self.pixelcnn = nn.Sequential(
73 | # MaskedConv2d('A', self.args.input_size[0] + 2 * self.args.input_size[0], 64, 3, 1, 1, bias=False),
74 | # nn.BatchNorm2d(64), act,
75 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act,
76 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act,
77 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act,
78 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act,
79 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act,
80 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act,
81 | # MaskedConv2d('B', 64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), act
82 | #)
83 | self.pixelcnn = PixelSNAIL([28, 28], 64, 64, 3, 1, 4, 64)
84 |
85 | if self.args.input_type == 'binary':
86 | self.p_x_mean = Conv2d(64, 1, 1, 1, 0, activation=nn.Sigmoid())
87 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
88 | self.p_x_mean = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Sigmoid(), bias=False)
89 | self.p_x_logvar = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Hardtanh(min_val=-4.5, max_val=0.), bias=False)
90 |
91 | def pixelcnn_generate(self, z1, z2):
92 | # Sampling from PixelCNN
93 | x_zeros = torch.zeros(
94 | (z1.size(0), self.args.input_size[0], self.args.input_size[1], self.args.input_size[2]))
95 | x_zeros = x_zeros.to(self.args.device)
96 |
97 | for i in range(self.args.input_size[1]):
98 | for j in range(self.args.input_size[2]):
99 | samples_mean, samples_logvar = self.p_x(z1, z2, x=x_zeros.detach())
100 | samples_mean = samples_mean.view(samples_mean.size(0), self.args.input_size[0], self.args.input_size[1],
101 | self.args.input_size[2])
102 |
103 | if self.args.input_type == 'binary':
104 | probs = samples_mean[:, :, i, j].data
105 | x_zeros[:, :, i, j] = torch.bernoulli(probs).float()
106 | samples_gen = samples_mean
107 |
108 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
109 | binsize = 1. / 256.
110 | samples_logvar = samples_logvar.view(samples_mean.size(0), self.args.input_size[0],
111 | self.args.input_size[1], self.args.input_size[2])
112 | means = samples_mean[:, :, i, j].data
113 | logvar = samples_logvar[:, :, i, j].data
114 | # sample from logistic distribution
115 |                     u = torch.rand(means.size()).to(self.args.device)
116 | y = torch.log(u) - torch.log(1. - u)
117 | sample = means + torch.exp(logvar) * y
118 | x_zeros[:, :, i, j] = torch.floor(sample / binsize) * binsize
119 | samples_gen = samples_mean
120 | return samples_gen
121 |
122 | def forward(self, x):
123 | x = x.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])
124 | return super(VAE, self).forward(x)
125 |
126 |
127 |
--------------------------------------------------------------------------------
/models/VAE.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | import torch.nn as nn
6 | from torch.nn import Linear
7 | from utils.nn import GatedDense, NonLinear
8 | from models.AbsModel import AbsModel
9 |
10 |
11 | class VAE(AbsModel):
12 | def __init__(self, args):
13 | super(VAE, self).__init__(args)
14 |
15 | def create_model(self, args, train_data_size=None):
16 | self.train_data_size = train_data_size
17 | self.q_z_layers = nn.Sequential(
18 | GatedDense(np.prod(self.args.input_size), self.args.hidden_size, no_attention=self.args.no_attention),
19 | GatedDense(self.args.hidden_size, self.args.hidden_size, no_attention=self.args.no_attention)
20 | )
21 | self.q_z_mean = Linear(self.args.hidden_size, self.args.z1_size)
22 | if args.same_variational_var:
23 | self.q_z_logvar = torch.nn.Parameter(torch.randn((1)))
24 | else:
25 | self.q_z_logvar = NonLinear(self.args.hidden_size,
26 | self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
27 |
28 | self.p_x_layers = nn.Sequential(
29 | GatedDense(self.args.z1_size, self.args.hidden_size, no_attention=self.args.no_attention),
30 | GatedDense(self.args.hidden_size, self.args.hidden_size, no_attention=self.args.no_attention))
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/models/__init__.py
--------------------------------------------------------------------------------
/models/convHVAE_2level.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch.nn as nn
4 | from utils.nn import GatedDense, NonLinear, \
5 | Conv2d, GatedConv2d
6 | from models.AbsHModel import BaseHModel
7 |
8 |
9 | class VAE(BaseHModel):
10 | def __init__(self, args):
11 | super(VAE, self).__init__(args)
12 |
13 | def create_model(self, args):
14 | if args.dataset_name == 'freyfaces':
15 | self.h_size = 210
16 | elif args.dataset_name == 'cifar10' or args.dataset_name == 'svhn':
17 | self.h_size = 384
18 | else:
19 | self.h_size = 294
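    |         # h_size = 6 output channels * (H/4) * (W/4) of the final encoder
    |         # feature map: 6*7*5=210 (28x20), 6*8*8=384 (32x32), 6*7*7=294 (28x28).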
20 |
21 | fc_size = 300
22 |
23 | # encoder: q(z2 | x)
24 | self.q_z_layers = nn.Sequential(
25 | GatedConv2d(self.args.input_size[0], 32, 7, 1, 3, no_attention=args.no_attention),
26 | GatedConv2d(32, 32, 3, 2, 1, no_attention=args.no_attention),
27 | GatedConv2d(32, 64, 5, 1, 2, no_attention=args.no_attention),
28 | GatedConv2d(64, 64, 3, 2, 1, no_attention=args.no_attention),
29 | GatedConv2d(64, 6, 3, 1, 1, no_attention=args.no_attention)
30 | )
31 |
32 | # linear layers
33 | self.q_z_mean = NonLinear(self.h_size, self.args.z2_size, activation=None)
34 |
35 |         # SAME VARIATIONAL VARIANCE TO SEE IF IT HELPS
36 | self.q_z_logvar = NonLinear(self.h_size, self.args.z2_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
37 |
38 | # encoder: q(z1|x,z2)
39 | # PROCESSING x
40 | self.q_z1_layers_x = nn.Sequential(
41 | GatedConv2d(self.args.input_size[0], 32, 3, 1, 1, no_attention=args.no_attention),
42 | GatedConv2d(32, 32, 3, 2, 1, no_attention=args.no_attention),
43 | GatedConv2d(32, 64, 3, 1, 1, no_attention=args.no_attention),
44 | GatedConv2d(64, 64, 3, 2, 1, no_attention=args.no_attention),
45 | GatedConv2d(64, 6, 3, 1, 1, no_attention=args.no_attention)
46 | )
47 | # PROCESSING Z2
48 | self.q_z1_layers_z2 = nn.Sequential(GatedDense(self.args.z2_size, self.h_size))
49 |
50 | # PROCESSING JOINT
51 |         self.q_z1_layers_joint = nn.Sequential(GatedDense(2 * self.h_size, fc_size))
52 |
53 | # linear layers
54 | self.q_z1_mean = NonLinear(fc_size, self.args.z1_size, activation=None)
55 | self.q_z1_logvar = NonLinear(fc_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
56 |
57 | # decoder p(z1|z2)
58 | self.p_z1_layers_z2 = nn.Sequential(
59 | GatedDense(self.args.z2_size, fc_size, no_attention=args.no_attention),
60 | GatedDense(fc_size, fc_size, no_attention=args.no_attention)
61 | )
62 | self.p_z1_mean = NonLinear(fc_size, self.args.z1_size, activation=None)
63 | self.p_z1_logvar = NonLinear(fc_size, self.args.z1_size, activation=nn.Hardtanh(min_val=-6., max_val=2.))
64 |
65 | # decoder: p(x | z)
66 | self.p_x_layers_z1 = nn.Sequential(
67 | GatedDense(self.args.z1_size, fc_size, no_attention=args.no_attention)
68 | )
69 | self.p_x_layers_z2 = nn.Sequential(
70 | GatedDense(self.args.z2_size, fc_size, no_attention=args.no_attention)
71 | )
72 |
73 | self.p_x_layers_joint_pre = nn.Sequential(
74 | GatedDense(2 * fc_size, np.prod(self.args.input_size), no_attention=args.no_attention)
75 | )
76 |
77 | # decoder: p(x | z)
78 | self.p_x_layers_joint = nn.Sequential(
79 | GatedConv2d(self.args.input_size[0], 64, 3, 1, 1, no_attention=args.no_attention),
80 | GatedConv2d(64, 64, 3, 1, 1, no_attention=args.no_attention),
81 | GatedConv2d(64, 64, 3, 1, 1, no_attention=args.no_attention),
82 | GatedConv2d(64, 64, 3, 1, 1, no_attention=args.no_attention),
83 | )
84 |
85 | if self.args.input_type == 'binary':
86 | self.p_x_mean = Conv2d(64, 1, 1, 1, 0, activation=nn.Sigmoid())
87 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
88 | self.p_x_mean = Conv2d(64, self.args.input_size[0], 1, 1, 0)
89 | self.p_x_logvar = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Hardtanh(min_val=-4.5, max_val=0.))
90 | elif self.args.input_type == 'pca':
91 | self.p_x_mean = Conv2d(64, 1, 1, 1, 0)
92 | self.p_x_logvar = Conv2d(64, self.args.input_size[0], 1, 1, 0, activation=nn.Hardtanh(min_val=-4.5, max_val=0.))
93 |
94 | # THE MODEL: FORWARD PASS
95 | def forward(self, x):
96 | x = x.view(-1, self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])
97 | return super(VAE, self).forward(x)
98 |
99 |
--------------------------------------------------------------------------------
/models/fully_conv.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.utils.data
4 | import torch.nn as nn
5 | from models.AbsModel import AbsModel
6 | from torch.nn.utils import weight_norm
7 |
8 | class VAE(AbsModel):
9 | def __init__(self, args):
10 | super(VAE, self).__init__(args)
11 |
12 | def create_model(self, args, train_data_size=None):
13 | class block(nn.Module):
14 | def __init__(self, input_size, output_size, stride=1, kernel=3, padding=1):
15 | super(block, self).__init__()
16 | self.normalization = nn.BatchNorm2d(input_size)
17 | self.conv1 = weight_norm(nn.Conv2d(input_size, output_size, kernel_size=kernel, stride=stride, padding=padding,
18 | bias=True))
19 | self.activation = torch.nn.ELU()
20 | self.f = torch.nn.Sequential(self.activation, self.conv1)
21 |
22 | def forward(self, x):
23 | return x + self.f(x)
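    |         # `block` is a pre-activation residual unit (ELU then conv); the skip
    |         # connection requires output_size == input_size and stride == 1, and
    |         # the BatchNorm in self.normalization is defined but never applied
    |         # in forward().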
24 |
25 | self.train_data_size = train_data_size
26 | self.cs = 48
27 |         self.bottleneck = self.args.bottleneck
28 | self.q_z_layers = nn.Sequential(
29 | weight_norm(nn.Conv2d(in_channels=self.args.input_size[0], out_channels=self.cs, kernel_size=3, stride=2, padding=1)),
30 | nn.ELU(),
31 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
32 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
33 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
34 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
35 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
36 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
37 | weight_norm(nn.Conv2d(in_channels=self.cs, out_channels=self.cs*2, kernel_size=3, stride=2, padding=1)),
38 | nn.ELU(),
39 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
40 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
41 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
42 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
43 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
44 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
45 | # nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=3, stride=2, padding=1),
46 | # nn.ELU(),
47 | # nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=3, stride=1, padding=1),
48 | # nn.ELU(),
49 | )
50 | self.q_z_mean = weight_norm(nn.Conv2d(in_channels=self.cs*2, out_channels=self.bottleneck, kernel_size=3, stride=1, padding=1))
51 | # self.q_z_mean = weight_norm(nn.Linear(self.args.hidden_size, self.args.z1_size))
52 | self.q_z_logvar = weight_norm(nn.Conv2d(in_channels=self.cs*2, out_channels=self.bottleneck, kernel_size=3, stride=1, padding=1))
53 | # self.q_z_logvar = weight_norm(nn.Linear(self.args.hidden_size, self.args.z1_size))
54 | self.p_x_layers = nn.Sequential(
55 | # weight_norm(nn.Linear(self.args.z1_size, self.args.hidden_size)),
56 | nn.Upsample(scale_factor=2),
57 | weight_norm(nn.Conv2d(in_channels=self.bottleneck, out_channels=self.cs*2, kernel_size=3, stride=1, padding=1)),
58 | nn.ELU(),
59 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
60 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
61 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
62 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
63 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
64 | block(input_size=self.cs*2, output_size=self.cs*2, stride=1, kernel=3, padding=1),
65 | nn.Upsample(scale_factor=2),
66 | weight_norm(nn.Conv2d(in_channels=self.cs*2, out_channels=self.cs, kernel_size=3, stride=1, padding=1)),
67 | nn.ELU(),
68 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
69 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
70 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
71 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
72 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
73 | block(input_size=self.cs, output_size=self.cs, stride=1, kernel=3, padding=1),
74 | # nn.Upsample(size=(28, 28)),
75 | )
76 |
77 | if self.args.input_type == 'binary':
78 | self.p_x_mean = nn.Sequential(nn.Conv2d(in_channels=self.cs, out_channels=self.args.input_size[0], kernel_size=3, stride=1, padding=1), nn.Sigmoid())
79 | elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
80 | self.p_x_mean = weight_norm(nn.Conv2d(in_channels=self.cs, out_channels=self.args.input_size[0], kernel_size=3, stride=1, padding=1))
81 | self.p_x_logvar = nn.Conv2d(in_channels=self.cs, out_channels=self.args.input_size[0], kernel_size=3, stride=1, padding=1)
82 |
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint.pth
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint_best.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/checkpoint_best.pth
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_0.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_1.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_10.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_11.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_12.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_13.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_14.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_15.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_16.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_17.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_18.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_19.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_2.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_20.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_21.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_22.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_23.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_24.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_25.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_26.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_27.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_28.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_29.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_3.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_30.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_31.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_32.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_33.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_33.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_34.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_35.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_36.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_37.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_37.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_38.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_38.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_39.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_39.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_4.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_40.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_41.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_41.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_42.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_42.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_43.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_43.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_44.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_44.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_45.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_46.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_46.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_47.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_47.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_48.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_48.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_49.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_5.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_6.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_7.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_8.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generated/generated_9.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generations_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/generations_0.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/real.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/reconstructions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/reconstructions.png
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.config:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.config
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_kl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_kl
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_log_likelihood:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_log_likelihood
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_loss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_loss
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_re:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.test_re
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_kl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_kl
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_loss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_loss
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_re:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.train_re
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_kl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_kl
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_loss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_loss
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_re:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae.val_re
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae_config.txt:
--------------------------------------------------------------------------------
1 | Namespace(MB=100, S=5000, approximate_k=10, approximate_prior=False, base_dir='/checkpoint/sajad/143803', batch_size=100, bottleneck=6, continuous=False, cuda=True, dataset_name='dynamic_mnist', dir_extra='exemplar_prior_on_dynamic_mnist_components=25000_lr=5e-4_model_name=vae_variance_type=shared_independent=True', dynamic_binarization=True, early_stopping_epochs=50, epochs=2000, hidden_size=300, input_size=[1, 28, 28], input_type='binary', just_evaluate=False, lambd=0.0001, lr=0.0005, model_name='vae', model_signature='2', no_attention=False, no_cuda=False, no_mask=False, number_components=25000, prior='exemplar_prior', pseudoinputs_mean=0.05, pseudoinputs_std=0.01, same_variational_var=False, seed=2, slurm_job_id='143803', slurm_task_id='', test_batch_size=100, training_set_size=50000, use_logit=False, use_training_data_init=False, use_whole_train=False, warmup=100, z1_size=40, z2_size=40)
2 |
--------------------------------------------------------------------------------
/pretrained_model/exemplar_prior_on_dynamic_mnist_model_name=vae/1/vae_experiment_log.txt:
--------------------------------------------------------------------------------
1 | FINAL EVALUATION ON TEST SET
2 | LogL (TEST): 82.05
3 | LogL (TRAIN): 0.00
4 | ELBO (TEST): 85.50
5 | ELBO (TRAIN): 99.77
6 | RE: 61.04
7 | KL: 24.45
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | scipy
5 | scikit-learn
6 | opencv-python
7 | matplotlib
8 | wget
9 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/utils/__init__.py
--------------------------------------------------------------------------------
/utils/classify_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import time
5 | from utils.plot_images import imshow
6 | import matplotlib.pylab as plt
7 | import torchvision
8 | from pylab import rcParams
9 |
10 | rcParams['figure.figsize'] = 15, 15
11 |
12 |
13 | def compute_accuracy(classifier, model, loader, mean, args, dir=None, plot_mistakes=False):
14 | acc = 0
15 | mistakes_list = []
16 | for data, labels in loader:
17 | try:
18 | if model.args.use_logit is True and model.args.continuous is True:
19 | data = torch.round(model.logit_inverse(data) * 255) / 255
20 |         except AttributeError:
21 |             pass
22 | labels = labels.to(args.device)
23 | pred = classifier(data.double().to(args.device) - mean)
24 | acc += torch.mean((labels == torch.argmax(pred, dim=1)).double())
25 | mistakes = (labels != torch.argmax(pred, dim=1))
26 | mistakes_list.append(data[mistakes])
27 | mistakes_list = torch.cat(mistakes_list, dim=0)
28 | if plot_mistakes is True:
29 | imshow(torchvision.utils.make_grid(mistakes_list.reshape(-1, *args.input_size)))
30 | # plt.show()
31 | plt.axis('off')
32 | plt.savefig(os.path.join(dir, 'mistakes.png'), bbox_inches='tight')
33 | acc /= len(loader)
34 | return acc
35 |
36 |
37 | def save_model(save_path, load_path, content):
38 | torch.save(content, save_path)
39 | os.rename(save_path, load_path)
40 |
41 |
42 | def load_model(load_path, model, optimizer=None):
43 | checkpoint = torch.load(load_path)
44 | model.load_state_dict(checkpoint['state_dict'])
45 | if optimizer is not None:
46 | optimizer.load_state_dict(checkpoint['optimizer'])
47 | return checkpoint
48 |
49 |
50 | def compute_loss(pred, label, args):
51 | held_out_percent = 0.1
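    |     # Label smoothing over 10 classes: `held_out_percent` probability mass is
    |     # spread uniformly (0.01 per class) and the true label keeps the rest, so
    |     # the loss below is cross-entropy against this soft target.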
52 |
53 | denom = torch.logsumexp(pred, dim=1, keepdim=True)
54 | prediction = pred - denom
55 |
56 | one_hot_label = torch.ones_like(prediction) * (held_out_percent / 10)
57 | one_hot_label[torch.arange(args.batch_size), label] += (1 - held_out_percent)
58 | return -torch.sum(prediction * one_hot_label, dim=1).mean()
59 |
60 |
61 | def classify_data(train_loader, val_loader, test_loader, dir, args, model):
62 | classifier = nn.Sequential(nn.Linear(784, args.hidden_units), nn.ReLU(),
63 | nn.Linear(args.hidden_units, args.hidden_units), nn.ReLU(),
64 | nn.Linear(args.hidden_units, 10)).double().to(args.device)
65 |
66 | lr = args.lr
67 |
68 | optimizer = torch.optim.SGD(classifier.parameters(), lr=lr, momentum=0.9)
69 | epochs = args.epochs
70 | mean = 0
71 |     lr_lambda = lambda epoch: 1 - 0.99 * (epoch / epochs)  # linear decay to 1% of the initial lr
72 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
73 | os.makedirs(dir, exist_ok=True)
74 |
75 | if os.path.exists(os.path.join(dir, 'checkpoint.pth')):
76 | checkpoint = load_model(os.path.join(dir, 'checkpoint.pth'),
77 | model=classifier,
78 | optimizer=optimizer)
79 | begin_epoch = checkpoint['epoch']
80 | else:
81 | begin_epoch = 1
82 |
83 | for epoch_number in range(begin_epoch, epochs + 1):
84 | start_time = time.time()
85 | if epoch_number % 10 == 0:
86 | content = {'epoch': epoch_number, 'state_dict': classifier.state_dict(),
87 | 'optimizer': optimizer.state_dict()}
88 | save_model(os.path.join(dir, 'checkpoint_temp.pth'),
89 | os.path.join(dir, 'checkpoint.pth'), content)
90 |
91 | print('epoch number:', epoch_number)
92 | for index, data in enumerate(train_loader):
93 |
94 | data, _, label = data
95 | data_augment = model.reference_based_generation_x(reference_image=data.detach(), N=1).squeeze().double()
96 | label_augment = label
97 |
98 | data_augment = data_augment.to(args.device)
99 | label_augment = label_augment.to(args.device)
100 |
101 | data = data.to(args.device).double()
102 | label = label.to(args.device).long()
103 |
104 | # imshow(torchvision.utils.make_grid(data.reshape(-1, *args.input_size)).detach())
105 | # plt.show()
106 | try:
107 | if model.args.use_logit is True and model.args.continuous is True:
108 | data = torch.round(model.logit_inverse(data) * 255) / 255
109 |             except AttributeError:
110 | pass
111 | data_augment = torch.round(data_augment * 255) / 255
112 |
113 | loss1 = compute_loss(classifier(data), label, args)
114 | loss2 = compute_loss(classifier(data_augment), label_augment, args)
115 |
116 | loss = args.hyper_lambda*loss1 + (1-args.hyper_lambda)*loss2
117 |
118 | optimizer.zero_grad()
119 | loss.backward()
120 | optimizer.step()
121 | scheduler.step(epoch=epoch_number)
122 |
123 | for param_group in optimizer.param_groups:
124 | print('learning rate:', param_group['lr'])
125 | break
126 |
127 | if val_loader is not None:
128 | val_acc = compute_accuracy(classifier, model, val_loader, mean, args)
129 | print('val acc', val_acc.item())
130 | test_acc = compute_accuracy(classifier, model, test_loader, mean, args)
131 | print('accuracy test:', test_acc.item())
132 | print("time:", time.time() - start_time)
133 |
134 | content = {'epoch': args.epochs, 'state_dict': classifier.state_dict(),
135 | 'optimizer': optimizer.state_dict()}
136 | save_model(os.path.join(dir, 'checkpoint_temp.pth'), os.path.join(dir, 'checkpoint.pth'), content)
137 | classifier.eval()
138 | if val_loader is not None:
139 | val_acc = compute_accuracy(classifier, model, val_loader, mean, args)
140 | print('accuracy val:', val_acc.item())
141 | else:
142 | val_acc = torch.zeros(1)
143 | test_acc = compute_accuracy(classifier, model, test_loader, mean, args, dir=dir, plot_mistakes=True)
144 | print('accuracy test:', test_acc.item())
145 |     return (test_acc*10000).item()/100, (val_acc*10000).item()/100
146 |
--------------------------------------------------------------------------------
/utils/distributions.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.utils.data
4 | import math
5 |
6 | min_epsilon = 1e-5
7 | max_epsilon = 1.-1e-5
8 | log_sigmoid = torch.nn.LogSigmoid()
9 | log_2_pi = math.log(2*math.pi)
10 |
11 |
12 | def pairwise_distance(z, means):
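    |     # Pairwise squared distances via ||z_i - m_j||^2 = ||z_i||^2 + ||m_j||^2
    |     # - 2<z_i, m_j>; inputs are cast to double for numerical stability.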
13 | z = z.double()
14 | means = means.double()
15 | dist1 = (z**2).sum(dim=1).unsqueeze(1).expand(-1, means.shape[0]) #MB x C
16 | dist2 = (means**2).sum(dim=1).unsqueeze(0).expand(z.shape[0], -1) #MB x C
17 | dist3 = torch.mm(z, torch.transpose(means, 0, 1)) #MB x C
18 |     return (dist1 + dist2 - 2*dist3).float()
19 |
20 |
21 | def log_normal_diag_vectorized(x, mean, log_var):
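    |     # Evaluates a diagonal Gaussian log-density for every (x, mean) pair at
    |     # once: whitening both by the shared std reduces the Mahalanobis term to
    |     # a plain pairwise squared distance.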
22 | log_var_sqrt = log_var.mul(0.5).exp_()
23 | pair_dist = pairwise_distance(x/log_var_sqrt, mean/log_var_sqrt)
24 | log_normal = -0.5 * torch.sum(log_var+log_2_pi, dim=1) - 0.5*pair_dist
25 | return log_normal, pair_dist
26 |
27 |
28 | def log_normal_diag(x, mean, log_var, average=False, dim=None):
29 | log_normal = -0.5 * (log_var + log_2_pi + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
30 | if average:
31 | return torch.mean(log_normal, dim)
32 | else:
33 | return torch.sum(log_normal, dim)
34 |
35 |
36 | def log_normal_standard(x, average=False, dim=None):
37 | log_normal = -0.5 * torch.pow(x, 2) - 0.5 * log_2_pi*x.new_ones(size=x.shape)
38 | if average:
39 | return torch.mean(log_normal, dim)
40 | else:
41 | return torch.sum(log_normal, dim)
42 |
43 |
44 | def log_bernoulli(x, mean, average=False, dim=None):
45 | probs = torch.clamp( mean, min=min_epsilon, max=max_epsilon)
46 | log_bernoulli = x * torch.log(probs) + (1. - x) * torch.log(1. - probs)
47 |
48 | if average:
49 | return torch.mean(log_bernoulli, dim)
50 | else:
51 | return torch.sum(log_bernoulli, dim)
52 |
53 |
54 | def log_logistic_256(x, mean, logvar, average=False, reduce=True, dim=None):
55 | bin_size = 1. / 256.
56 | # implementation like https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L28
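    |     # Discretized logistic likelihood: the probability that x falls in its
    |     # 1/256-wide bin is CDF(x + bin_size) - CDF(x), with the logistic CDF
    |     # given by a sigmoid.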
57 | scale = torch.exp(logvar)
58 | x = (torch.floor(x / bin_size) * bin_size - mean) / scale
59 | cdf_plus = torch.sigmoid(x + bin_size/scale)
60 | cdf_minus = torch.sigmoid(x)
61 | log_logist_256 = torch.log(cdf_plus - cdf_minus + 1e-7)
62 |
63 | if average:
64 | return torch.mean(log_logist_256, dim)
65 | else:
66 | return torch.sum(log_logist_256, dim)
67 |
68 |
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from utils.plot_images import plot_images
3 | import torch
4 | import time
5 | from scipy.special import logsumexp
6 | import numpy as np
7 | from utils.utils import load_model
8 | import torch.nn.functional as F
9 |
10 |
11 | def evaluate_loss(args, model, loader, dataset=None, exemplars_embedding=None):
12 |     evaluated_elbo, evaluate_re, evaluate_kl = 0, 0, 0
13 | model.eval()
14 | if exemplars_embedding is None:
15 | exemplars_embedding = load_all_pseudo_input(args, model, dataset)
16 |
17 | for data in loader:
18 | if len(data) == 3:
19 | data, _, _ = data
20 | else:
21 | data, _ = data
22 | data = data.to(args.device)
23 | x = data
24 | x_indices = None
25 | x = (x, x_indices)
26 | loss, RE, KL = model.calculate_loss(x, average=False, exemplars_embedding=exemplars_embedding)
27 |         evaluated_elbo += loss.sum().item()
28 | evaluate_re += -RE.sum().item()
29 | evaluate_kl += KL.sum().item()
30 |     evaluated_elbo /= len(loader.dataset)
31 | evaluate_re /= len(loader.dataset)
32 | evaluate_kl /= len(loader.dataset)
33 |     return evaluated_elbo, evaluate_re, evaluate_kl
34 |
35 |
36 | def visualize_reconstruction(test_samples, model, args, dir):
37 | samples_reconstruction = model.reconstruct_x(test_samples[0:25])
38 |
39 | if args.use_logit:
40 | test_samples = model.logit_inverse(test_samples)
41 | samples_reconstruction = model.logit_inverse(samples_reconstruction)
42 | plot_images(args, test_samples.cpu().numpy()[0:25], dir, 'real', size_x=5, size_y=5)
43 | plot_images(args, samples_reconstruction.cpu().numpy(), dir, 'reconstructions', size_x=5, size_y=5)
44 |
45 |
46 | def visualize_generation(dataset, model, args, dir):
47 | generation_rounds = 1
48 | for i in range(generation_rounds):
49 | samples_rand = model.generate_x(25, dataset=dataset)
50 | plot_images(args, samples_rand.cpu().numpy(), dir, 'generations_{}'.format(i), size_x=5, size_y=5)
51 | if args.prior == 'vampprior':
52 | pseudo_means = model.means(model.idle_input)
53 | plot_images(args, pseudo_means[0:25].cpu().numpy(), dir, 'pseudoinputs', size_x=5, size_y=5)
54 |
55 |
56 | def load_all_pseudo_input(args, model, dataset):
57 | if args.prior == 'exemplar_prior':
58 | exemplars_z, exemplars_log_var = model.cache_z(dataset)
59 | embedding = (exemplars_z, exemplars_log_var, torch.arange(len(exemplars_z)))
60 | elif args.prior == 'vampprior':
61 | pseudo_means = model.means(model.idle_input)
62 | if 'conv' in args.model_name:
63 | pseudo_means = pseudo_means.view(-1, args.input_size[0], args.input_size[1], args.input_size[2])
64 | embedding = model.q_z(pseudo_means, prior=True) # C x M
65 | elif args.prior == 'standard':
66 | embedding = None
67 | else:
68 | raise Exception("wrong name of prior")
69 | return embedding
70 |
71 |
72 | def calculate_likelihood(args, model, loader, S=5000, exemplars_embedding=None):
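    |     # Importance-sampled marginal likelihood: per test point, log p(x) is
    |     # estimated as logsumexp_s(-loss_s) - log(S) over S posterior samples.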
73 | likelihood_test = []
74 | batch_size_evaluation = 1
75 |     auxiliary_loader = torch.utils.data.DataLoader(loader.dataset, batch_size=batch_size_evaluation)
76 | t0 = time.time()
77 |     for index, (data, _) in enumerate(auxiliary_loader):
78 | data = data.to(args.device)
79 | if index % 100 == 0:
80 | print(time.time() - t0)
81 | t0 = time.time()
82 |             print('{:.2f}%'.format(index / (1. * len(auxiliary_loader)) * 100))
83 | x = data.expand(S, data.size(1))
84 | if args.model_name == 'pixelcnn':
85 | BS = S//100
86 | prob = []
87 | for i in range(BS):
88 | bx = x[i*100:(i+1)*100]
89 | x_indices = None
90 | bprob, _, _ = model.calculate_loss((bx, x_indices), exemplars_embedding=exemplars_embedding)
91 | prob.append(bprob)
92 | prob = torch.cat(prob, dim=0)
93 | else:
94 | x_indices = None
95 | prob, _, _ = model.calculate_loss((x, x_indices), exemplars_embedding=exemplars_embedding)
96 | likelihood_x = logsumexp(-prob.cpu().numpy())
97 | if model.args.use_logit:
98 | lambd = torch.tensor(model.args.lambd).float()
99 | likelihood_x -= (-F.softplus(-x) - F.softplus(x)\
100 | - torch.log((1 - 2 * lambd)/256)).sum(dim=1).cpu().numpy()
101 | likelihood_test.append(likelihood_x - np.log(len(prob)))
102 | likelihood_test = np.array(likelihood_test)
103 | return -np.mean(likelihood_test)
104 |
105 |
106 | def final_evaluation(train_loader, test_loader, valid_loader, best_model_path_load,
107 | model, optimizer, args, dir):
108 | _ = load_model(best_model_path_load, model, optimizer)
109 | model.eval()
110 | exemplars_embedding = load_all_pseudo_input(args, model, train_loader.dataset)
111 | test_samples = next(iter(test_loader))[0].to(args.device)
112 | visualize_reconstruction(test_samples, model, args, dir)
113 | visualize_generation(train_loader.dataset, model, args, dir)
114 | test_elbo, test_re, test_kl = evaluate_loss(args, model, test_loader, dataset=train_loader.dataset, exemplars_embedding=exemplars_embedding)
115 | valid_elbo, valid_re, valid_kl = evaluate_loss(args, model, valid_loader, dataset=valid_loader.dataset, exemplars_embedding=exemplars_embedding)
116 | train_elbo, _, _ = evaluate_loss(args, model, train_loader, dataset=train_loader.dataset, exemplars_embedding=exemplars_embedding)
117 | test_log_likelihood = calculate_likelihood(args, model, test_loader, exemplars_embedding=exemplars_embedding, S=args.S)
118 | final_evaluation_txt = 'FINAL EVALUATION ON TEST SET\n' \
119 | 'LogL (TEST): {:.2f}\n' \
120 | 'LogL (TRAIN): {:.2f}\n' \
121 | 'ELBO (TEST): {:.2f}\n' \
122 | 'ELBO (TRAIN): {:.2f}\n' \
123 | 'ELBO (VALID): {:.2f}\n' \
124 | 'RE: {:.2f}\n' \
125 | 'KL: {:.2f}'.format(
126 | test_log_likelihood,
127 |                                 0,  # train log-likelihood is not computed; placeholder
128 | test_elbo,
129 | train_elbo,
130 | valid_elbo,
131 | test_re,
132 | test_kl)
133 |
134 | print(final_evaluation_txt)
135 | with open(dir + 'vae_experiment_log.txt', 'a') as f:
136 | print(final_evaluation_txt, file=f)
137 | torch.save(test_log_likelihood, dir + args.model_name + '.test_log_likelihood')
138 | torch.save(test_elbo, dir + args.model_name + '.test_loss')
139 | torch.save(test_re, dir + args.model_name + '.test_re')
140 | torch.save(test_kl, dir + args.model_name + '.test_kl')
141 |
142 |
143 | # TODO remove last loop from this function
144 | def compute_mean_variance_per_dimension(args, model, test_loader):
145 | means = []
146 | for batch, _ in test_loader:
147 | mean, _ = model.q_z(batch.to(args.device))
148 | means.append(mean)
149 | means = torch.cat(means, dim=0).cpu().detach().numpy()
150 | active = 0
151 | for i in range(means.shape[1]):
152 | if np.var(means[:, i].reshape(-1)) > 0.01:
153 | active += 1
154 | print('active dimensions', active)
155 | return active
156 |
157 |
158 |
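159 | # Illustrative note (not part of the original file): calculate_likelihood above
160 | # implements the standard importance-sampling estimate
161 | #     log p(x) ~ logsumexp_s(-loss_s) - log S,
162 | # which can be sanity-checked on toy numbers:
163 | #
164 | #   import numpy as np
165 | #   from scipy.special import logsumexp
166 | #   losses = np.array([2.0, 2.1, 1.9])                 # per-sample losses
167 | #   log_px = logsumexp(-losses) - np.log(len(losses))  # ~= -2.0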
--------------------------------------------------------------------------------
/utils/knn_on_latent.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def find_nearest_neighbors(z_val, z_train, z_train_log_var):
5 | z_expand = z_val.unsqueeze(1)
6 | means = z_train.unsqueeze(0)
7 | distance = (z_expand - means)**2
8 |     _, indices_batch = (torch.sum(distance, dim=2)**(0.5)).topk(k=20, dim=1, largest=False, sorted=True)  # 20 must cover the largest K queried downstream
9 | return indices_batch
10 |
11 |
12 | def extract_full_data(data_loader):
13 | full_data = []
14 | full_labels = []
15 | full_indices = []
16 | for data in data_loader:
17 | if len(data) == 3:
18 | data, indices, labels = data
19 | full_indices.append(indices)
20 | else:
21 | data, labels = data
22 | full_data.append(data)
23 | full_labels.append(labels)
24 | full_data = torch.cat(full_data, dim=0)
25 | full_labels = torch.cat(full_labels, dim=0)
26 | if len(full_indices) > 0:
27 | full_indices = torch.cat(full_indices, dim=0)
28 | return full_data, full_indices, full_labels
29 |
30 |
31 | # TODO refactor this function
32 | def report_knn_on_latent(train_loader, val_loader, test_loader, model, dir, knn_dictionary, args, val=True):
33 | train_data, _, train_labels = extract_full_data(train_loader)
34 | val_data, _, val_labels = extract_full_data(val_loader)
35 | test_data, _, test_labels = extract_full_data(test_loader)
36 |
37 | train_data = train_data.to(args.device)
38 | val_data = val_data.to(args.device)
39 |
40 | if val is True:
41 | data_to_evaluate = val_data
42 | labels = val_labels
43 | else:
44 | train_data = torch.cat((train_data, val_data), dim=0)
45 | train_labels = torch.cat((train_labels, val_labels), dim=0)
46 | data_to_evaluate = test_data
47 | labels = test_labels
48 |
49 | with torch.no_grad():
50 | z_train = []
51 | for i in range(len(train_data)//args.batch_size):
52 | train_batch = train_data[i*args.batch_size: (i+1)*args.batch_size]
53 | z_train_batch, _ = model.q_z(train_batch.to(args.device), prior=True)
54 | z_train.append(z_train_batch)
55 | z_train = torch.cat(z_train, dim=0)
56 |
57 | print(z_train.shape)
58 | indices = []
59 | for i in range(len(data_to_evaluate)//args.batch_size):
60 | z_val, _ = model.q_z(data_to_evaluate[i*args.batch_size: (i+1)*args.batch_size].to(args.device), prior=True)
61 | indices.append(find_nearest_neighbors(z_val, z_train, None))
62 | indices = torch.cat(indices, dim=0)
63 |
64 | for k in knn_dictionary.keys():
65 | k = int(k)
66 |         k_labels = train_labels[indices[:, :k]].squeeze().long()
67 |         num_classes = 10
68 |         counts = torch.zeros(len(k_labels), num_classes)  # one row per evaluated sample (val and test sizes may differ)
69 |         for i in range(num_classes):
70 |             counts[:, i] = (k_labels == torch.tensor(i).long()).sum(dim=1)
71 |         y_pred = torch.argmax(counts, dim=1)
72 |         acc = (torch.mean((y_pred == labels[:len(y_pred)].long()).float()) * 10000).round().item()/100  # truncate labels to the batch-aligned subset
73 | print('K:', k, 'Accuracy:', acc)
74 | knn_dictionary[str(k)].append(acc)
75 |
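76 | # Usage sketch (hypothetical names; assumes loaders and a trained model are
77 | # already available):
78 | #
79 | #   knn_results = {'1': [], '5': [], '10': []}
80 | #   report_knn_on_latent(train_loader, val_loader, test_loader,
81 | #                        model, output_dir, knn_results, args, val=True)
82 | #   # knn_results now maps each K to a list of accuracies (percent).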
--------------------------------------------------------------------------------
/utils/load_data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sajadn/Exemplar-VAE/4dbd0913d315db6ef9e3b5a689fae4140add2323/utils/load_data/__init__.py
--------------------------------------------------------------------------------
/utils/load_data/base_load_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.utils.data as data_utils
4 | import numpy as np
5 | from abc import ABC, abstractmethod
6 |
7 |
8 | class base_load_data(ABC):
9 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
10 | self.args = args
11 | self.train_num = args.training_set_size
12 | self.use_fixed_validation = use_fixed_validation
13 | self.no_binarization = no_binarization
14 |
15 | @abstractmethod
16 | def obtain_data(self):
17 | pass
18 |
19 | def logit(self, x):
20 | return np.log(x) - np.log1p(-x)
21 |
22 | def seperate_data_from_label(self, train_dataset, test_dataset):
23 | x_train = train_dataset.data.numpy()
24 |         y_train = train_dataset.targets.numpy().astype(int)  # .targets supersedes the deprecated train_labels/test_labels
25 |         x_test = test_dataset.data.numpy()
26 |         y_test = test_dataset.targets.numpy().astype(int)
27 | return x_train, y_train, x_test, y_test
28 |
29 | def preprocessing_(self, x_train, x_test):
30 | if self.args.input_type == 'gray' or self.args.input_type == 'continuous':
31 | if self.args.use_logit:
32 | lambd = self.args.lambd
33 | x_train = self.logit(lambd + (1 - 2 * lambd) * (x_train + np.random.rand(*x_train.shape)) / 256.)
34 | x_test = self.logit(lambd + (1 - 2 * lambd) * (x_test + np.random.rand(*x_test.shape)) / 256.)
35 | elif self.args.continuous:
36 | x_train = np.clip((x_train + 0.5) / 256., 0., 1.)
37 | x_test = np.clip((x_test + 0.5) / 256., 0., 1.)
38 | else:
39 | x_train = x_train / 255.
40 | x_test = x_test / 255.
41 |
42 | return x_train, x_test
43 |
44 | def vampprior_initialization(self, x_train, init_mean, init_std):
45 | if self.args.use_training_data_init == 1:
46 | self.args.pseudoinputs_std = 0.01
47 | init = x_train[0:self.args.number_components].T
48 | self.args.pseudoinputs_mean = torch.from_numpy(
49 | init + self.args.pseudoinputs_std * np.random.randn(np.prod(self.args.input_size),
50 | self.args.number_components)).float()
51 | else:
52 | self.args.pseudoinputs_mean = init_mean
53 | self.args.pseudoinputs_std = init_std
54 |
55 | def post_processing(self, x_train, x_val, x_test, y_train, y_val, y_test, init_mean=0.05, init_std=0.01, **kwargs):
56 | indices = np.arange(len(x_train)).reshape(-1, 1)
57 | train = data_utils.TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(indices),
58 | torch.from_numpy(y_train))
59 | train_loader = data_utils.DataLoader(train, batch_size=self.args.batch_size, shuffle=True, **kwargs)
60 |
61 | if len(x_val) > 0:
62 | validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val))
63 | val_loader = data_utils.DataLoader(validation, batch_size=self.args.test_batch_size, shuffle=True, **kwargs)
64 | else:
65 | val_loader = None
66 | test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test))
67 | test_loader = data_utils.DataLoader(test, batch_size=self.args.test_batch_size, shuffle=False, **kwargs)
68 |
69 | self.vampprior_initialization(x_train, init_mean, init_std)
70 | return train_loader, val_loader, test_loader
71 |
72 | def binarize(self, x_val, x_test):
73 | self.args.input_type = 'binary'
74 | np.random.seed(777)
75 | x_val = np.random.binomial(1, x_val)
76 | x_test = np.random.binomial(1, x_test)
77 | return x_val, x_test
78 |
79 | def load_dataset(self, **kwargs):
80 | # start processing
81 | train, test = self.obtain_data()
82 | x_train, y_train, x_test, y_test = self.seperate_data_from_label(train, test)
83 | x_train, x_test = self.preprocessing_(x_train, x_test)
84 |
85 | if self.use_fixed_validation is False:
86 | permutation = np.arange(len(x_train))
87 | np.random.shuffle(permutation)
88 | x_train = x_train[permutation]
89 | y_train = y_train[permutation]
90 |
91 | if self.args.dataset_name == 'static_mnist':
92 | x_train, x_val = x_train
93 | y_train, y_val = y_train
94 | else:
95 | x_val = x_train[self.train_num:]
96 | y_val = y_train[self.train_num:]
97 | x_train = x_train[:self.train_num]
98 | y_train = y_train[:self.train_num]
99 |
100 | # imshow(torchvision.utils.make_grid(torch.from_numpy(x_val[:50].reshape(-1, *self.args.input_size))))
101 | # plt.axis('off')
102 | # plt.show()
103 |
104 | x_train = np.reshape(x_train, (-1, np.prod(self.args.input_size)))
105 | x_val = np.reshape(x_val, (-1, np.prod(self.args.input_size)))
106 |
107 | x_test = np.reshape(x_test, (-1, np.prod(self.args.input_size)))
108 |
109 | if self.args.dynamic_binarization and self.no_binarization is False:
110 | x_val, x_test = self.binarize(x_val, x_test)
111 |
112 | print("data stats:")
113 | print(len(x_train), len(y_train))
114 | print(len(x_val), len(y_val))
115 | print(len(x_test), len(y_test))
116 |
117 |         train_loader, val_loader, test_loader = self.post_processing(x_train, x_val, x_test,
118 | y_train, y_val, y_test, **kwargs)
119 |
120 | return train_loader, val_loader, test_loader, self.args
121 |
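122 | # Sanity-check sketch (illustrative): the logit preprocessing above is
123 | # invertible up to dequantization noise; the sigmoid recovers the rescaled
124 | # pixel value (here with a deterministic +0.5 in place of the random noise):
125 | #
126 | #   import numpy as np
127 | #   lambd = 1e-6
128 | #   x = np.array([0.0, 127.0, 255.0])
129 | #   u = lambd + (1 - 2 * lambd) * (x + 0.5) / 256.
130 | #   y = np.log(u) - np.log1p(-u)              # logit, as in base_load_data.logit
131 | #   assert np.allclose(1 / (1 + np.exp(-y)), u)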
--------------------------------------------------------------------------------
/utils/load_data/data_loader_instances.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import datasets
3 | import numpy as np
4 | from scipy.io import loadmat
5 | from .base_load_data import base_load_data
6 | import wget
7 |
8 | class dynamic_mnist_loader(base_load_data):
9 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
10 | super(dynamic_mnist_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization)
11 |
12 | def obtain_data(self):
13 | train = datasets.MNIST(os.path.join('datasets', self.args.dataset_name), train=True, download=True)
14 | test = datasets.MNIST(os.path.join('datasets', self.args.dataset_name), train=False)
15 | return train, test
16 |
17 |
18 | class fashion_mnist_loader(base_load_data):
19 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
20 | super(fashion_mnist_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization)
21 |
22 | def obtain_data(self):
23 | train = datasets.FashionMNIST(os.path.join('datasets', self.args.dataset_name), train=True, download=True)
24 | test = datasets.FashionMNIST(os.path.join('datasets', self.args.dataset_name), train=False)
25 | return train, test
26 |
27 |
28 | class svhn_loader(base_load_data):
29 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
30 | super(svhn_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization)
31 |
32 | def obtain_data(self):
33 | train = datasets.SVHN(os.path.join('datasets', self.args.dataset_name), split='train', download=True)
34 | test = datasets.SVHN(os.path.join('datasets', self.args.dataset_name), split='test', download=True)
35 | return train, test
36 |
37 | def seperate_data_from_label(self, train_dataset, test_dataset):
38 | x_train = train_dataset.data
39 | y_train = train_dataset.labels.astype(dtype=int)
40 | x_test = test_dataset.data
41 | y_test = test_dataset.labels.astype(dtype=int)
42 | return x_train, y_train, x_test, y_test
43 |
44 |
45 | class static_mnist_loader(base_load_data):
46 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
47 | super(static_mnist_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization)
48 |
49 | def obtain_data(self):
50 | def lines_to_np_array(lines):
51 | return np.array([[int(i) for i in line.split()] for line in lines])
52 |
53 | with open(os.path.join('datasets', self.args.dataset_name, 'binarized_mnist_train.amat')) as f:
54 | lines = f.readlines()
55 | x_train = lines_to_np_array(lines).astype('float32')
56 | with open(os.path.join('datasets', self.args.dataset_name, 'binarized_mnist_valid.amat')) as f:
57 | lines = f.readlines()
58 | x_val = lines_to_np_array(lines).astype('float32')
59 | with open(os.path.join('datasets', self.args.dataset_name, 'binarized_mnist_test.amat')) as f:
60 | lines = f.readlines()
61 | x_test = lines_to_np_array(lines).astype('float32')
62 |
63 | y_train = np.zeros((x_train.shape[0], 1)).astype(int)
64 | y_val = np.zeros((x_val.shape[0], 1)).astype(int)
65 | y_test = np.zeros((x_test.shape[0], 1)).astype(int)
66 | return (x_train, x_val, y_train, y_val), (x_test, y_test)
67 |
68 | def seperate_data_from_label(self, train_dataset, test_dataset):
69 | x_train, x_val, y_train, y_val = train_dataset
70 | x_test, y_test = test_dataset
71 | return (x_train, x_val), (y_train, y_val), x_test, y_test
72 |
73 | def preprocessing_(self, x_train, x_test):
74 | return x_train, x_test
75 |
76 |
77 | class omniglot_loader(base_load_data):
78 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
79 | super(omniglot_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization)
80 |
81 | def obtain_data(self):
82 | def reshape_data(data):
83 | return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='F')
84 | dataset_file = os.path.join('datasets', self.args.dataset_name, 'chardata.mat')
85 | if not os.path.exists(dataset_file):
86 | url = "https://raw.githubusercontent.com/yburda/iwae/master/datasets/OMNIGLOT/chardata.mat"
87 | wget.download(url, dataset_file)
88 |
89 | omni_raw = loadmat(os.path.join('datasets', self.args.dataset_name, 'chardata.mat'))
90 |
91 | x_train = reshape_data(omni_raw['data'].T.astype('float32'))
92 | x_test = reshape_data(omni_raw['testdata'].T.astype('float32'))
93 |
94 | y_train = omni_raw['targetchar'].reshape((-1, 1))
95 | y_test = omni_raw['testtargetchar'].reshape((-1, 1))
96 | return (x_train, y_train), (x_test, y_test)
97 |
98 | def seperate_data_from_label(self, train_dataset, test_dataset):
99 | x_train, y_train = train_dataset
100 | x_test, y_test = test_dataset
101 | return x_train, y_train, x_test, y_test
102 |
103 | def preprocessing_(self, x_train, x_test):
104 | return x_train, x_test
105 |
106 |
107 | class cifar10_loader(base_load_data):
108 | def __init__(self, args, use_fixed_validation=False, no_binarization=False):
109 | super(cifar10_loader, self).__init__(args, use_fixed_validation, no_binarization=no_binarization)
110 |
111 | def obtain_data(self):
112 | training_dataset = datasets.CIFAR10(os.path.join('datasets', self.args.dataset_name), train=True, download=True)
113 | test_dataset = datasets.CIFAR10(os.path.join('datasets', self.args.dataset_name), train=False)
114 | return training_dataset, test_dataset
115 |
116 | def seperate_data_from_label(self, train_dataset, test_dataset):
117 | train_data = np.swapaxes(np.swapaxes(train_dataset.data, 1, 2), 1, 3)
118 | y_train = np.zeros((train_data.shape[0], 1)).astype(int)
119 | test_data = np.swapaxes(np.swapaxes(test_dataset.data, 1, 2), 1, 3)
120 | y_test = np.zeros((test_data.shape[0], 1)).astype(int)
121 | return train_data, y_train, test_data, y_test
122 |
123 |
124 | def load_dataset(args, training_num=None, use_fixed_validation=False, no_binarization=False, **kwargs):
125 | if training_num is not None:
126 | args.training_set_size = training_num
127 | if args.dataset_name == 'static_mnist':
128 | args.input_size = [1, 28, 28]
129 | args.input_type = 'binary'
130 | train_loader, val_loader, test_loader, args = static_mnist_loader(args).load_dataset(**kwargs)
131 | elif args.dataset_name == 'dynamic_mnist':
132 | if training_num is None:
133 | args.training_set_size = 50000
134 | args.input_size = [1, 28, 28]
135 | if args.continuous is True:
136 | args.input_type = 'gray'
137 | args.dynamic_binarization = False
138 | no_binarization = True
139 | else:
140 | args.input_type = 'binary'
141 | args.dynamic_binarization = True
142 |
143 | train_loader, val_loader, test_loader, args = \
144 | dynamic_mnist_loader(args, use_fixed_validation, no_binarization=no_binarization).load_dataset(**kwargs)
145 | elif args.dataset_name == 'fashion_mnist':
146 | if training_num is None:
147 | args.training_set_size = 50000
148 | args.input_size = [1, 28, 28]
149 |
150 | if args.continuous is True:
151 | print("*****Continuous Data*****")
152 | args.input_type = 'gray'
153 | args.dynamic_binarization = False
154 | no_binarization = True
155 | else:
156 | args.input_type = 'binary'
157 | args.dynamic_binarization = True
158 |
159 | train_loader, val_loader, test_loader, args = \
160 | fashion_mnist_loader(args, use_fixed_validation, no_binarization=no_binarization).load_dataset(**kwargs)
161 | elif args.dataset_name == 'omniglot':
162 | if training_num is None:
163 | args.training_set_size = 23000
164 | args.input_size = [1, 28, 28]
165 | args.input_type = 'binary'
166 | args.dynamic_binarization = True
167 | train_loader, val_loader, test_loader, args = omniglot_loader(args).load_dataset(**kwargs)
168 | elif args.dataset_name == 'svhn':
169 | args.training_set_size = 60000
170 | args.input_size = [3, 32, 32]
171 | args.input_type = 'continuous'
172 | train_loader, val_loader, test_loader, args = svhn_loader(args).load_dataset(**kwargs)
173 | elif args.dataset_name == 'cifar10':
174 | args.training_set_size = 40000
175 | args.input_size = [3, 32, 32]
176 | args.input_type = 'continuous'
177 | train_loader, val_loader, test_loader, args = cifar10_loader(args).load_dataset(**kwargs)
178 | else:
179 | raise Exception('Wrong name of the dataset!')
180 | print('train size', len(train_loader.dataset))
181 | if val_loader is not None:
182 | print('val size', len(val_loader.dataset))
183 | print('test size', len(test_loader.dataset))
184 | return train_loader, val_loader, test_loader, args
185 |
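186 | # Usage sketch (hypothetical values; `args` must carry at least the fields
187 | # read above, e.g. built via argparse):
188 | #
189 | #   from argparse import Namespace
190 | #   args = Namespace(dataset_name='dynamic_mnist', batch_size=100,
191 | #                    test_batch_size=100, continuous=False, use_logit=False,
192 | #                    training_set_size=50000, use_training_data_init=0)
193 | #   train_loader, val_loader, test_loader, args = load_dataset(args)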
--------------------------------------------------------------------------------
/utils/nn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def xavier_init(m):
8 | s = np.sqrt( 2. / (m.in_features + m.out_features) )
9 | m.weight.data.normal_(0, s)
10 |
11 |
12 | def he_init(m):
13 | s = np.sqrt( 2. / m.in_features )
14 | m.weight.data.normal_(0, s)
15 |
16 |
17 | def normal_init(m, mean=0., std=0.01):
18 | m.weight.data.normal_(mean, std)
19 |
20 |
21 | class CReLU(nn.Module):
22 | def __init__(self):
23 | super(CReLU, self).__init__()
24 |
25 | def forward(self, x):
26 |         return torch.cat((F.relu(x), F.relu(-x)), 1)
27 |
28 |
29 | class NonLinear(nn.Module):
30 | def __init__(self, input_size, output_size, bias=True, activation=None):
31 | super(NonLinear, self).__init__()
32 |
33 | self.activation = activation
34 | self.linear = nn.Linear(int(input_size), int(output_size), bias=bias)
35 |
36 | def forward(self, x):
37 | h = self.linear(x)
38 | if self.activation is not None:
39 | h = self.activation( h )
40 |
41 | return h
42 |
43 |
44 | class GatedDense(nn.Module):
45 | def __init__(self, input_size, output_size, activation=None, no_attention=False):
46 | super(GatedDense, self).__init__()
47 |
48 | self.activation = activation
49 | self.no_attention = no_attention
50 | self.sigmoid = nn.Sigmoid()
51 | self.h = nn.Linear(input_size, output_size)
52 | if no_attention is False:
53 | self.g = nn.Linear(input_size, output_size)
54 | else:
55 | self.activation = torch.nn.ReLU()
56 |
57 |     def forward(self, x):
58 |         h = self.h(x)
59 |         if self.activation is not None:
60 |             h = self.activation(h)
61 |         try:
62 |             if self.no_attention is False:
63 |                 g = self.sigmoid(self.g(x))
64 |                 return h * g
65 |             else:
66 |                 return h
67 |         except AttributeError:  # fall back to gating if no_attention is missing (e.g. older checkpoints)
68 |             g = self.sigmoid(self.g(x))
69 |             return h * g
70 |
71 |
72 | class GatedConv2d(nn.Module):
73 | def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None,
74 | no_attention=False):
75 | super(GatedConv2d, self).__init__()
76 | self.no_attention = no_attention
77 |
78 | self.activation = activation
79 | self.sigmoid = nn.Sigmoid()
80 |
81 | self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
82 | if no_attention is False:
83 | self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
84 | else:
85 | self.activation = torch.nn.ELU()
86 |
87 |     def forward(self, x):
88 |         if self.activation is None:
89 |             h = self.h(x)
90 |         else:
91 |             h = self.activation( self.h( x ) )
92 |
93 |         if self.no_attention is False:
94 |             g = self.sigmoid( self.g( x ) )
95 |             return h * g
96 |         else:
97 |             return h
98 |
99 |
100 | class Conv2d(nn.Module):
101 | def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None, bias=True):
102 | super(Conv2d, self).__init__()
103 |
104 | self.activation = activation
105 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation, bias=bias)
106 |
107 | def forward(self, x):
108 | h = self.conv(x)
109 | if self.activation is None:
110 | out = h
111 | else:
112 | out = self.activation(h)
113 |
114 | return out
115 |
116 |
117 | class MaskedConv2d(nn.Conv2d):
118 | def __init__(self, mask_type, *args, **kwargs):
119 | super(MaskedConv2d, self).__init__(*args, **kwargs)
120 | assert mask_type in {'A', 'B'}
121 | self.register_buffer('mask', self.weight.data.clone())
122 | _, _, kH, kW = self.weight.size()
123 | self.mask.fill_(1)
124 | self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
125 | self.mask[:, :, kH // 2 + 1:] = 0
126 |
127 | def forward(self, x):
128 | self.weight.data *= self.mask
129 | return super(MaskedConv2d, self).forward(x)
130 |
131 |
132 | # Copyright (c) Xi Chen
133 | #
134 | # This source code is licensed under the MIT license found in the
135 | # LICENSE file in the root directory of this source tree.
136 |
137 | # Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch
138 |
139 | from math import sqrt
140 | from functools import partial, lru_cache
141 |
142 | import numpy as np
143 | import torch
144 | from torch import nn
145 | from torch.nn import functional as F
146 |
147 |
148 | def wn_linear(in_dim, out_dim):
149 | return nn.utils.weight_norm(nn.Linear(in_dim, out_dim))
150 |
151 |
152 | class WNConv2d(nn.Module):
153 | def __init__(
154 | self,
155 | in_channel,
156 | out_channel,
157 | kernel_size,
158 | stride=1,
159 | padding=0,
160 | bias=True,
161 | activation=None,
162 | ):
163 | super().__init__()
164 |
165 | self.conv = nn.utils.weight_norm(
166 | nn.Conv2d(
167 | in_channel,
168 | out_channel,
169 | kernel_size,
170 | stride=stride,
171 | padding=padding,
172 | bias=bias,
173 | )
174 | )
175 |
176 | self.out_channel = out_channel
177 |
178 | if isinstance(kernel_size, int):
179 | kernel_size = [kernel_size, kernel_size]
180 |
181 | self.kernel_size = kernel_size
182 |
183 | self.activation = activation
184 |
185 | def forward(self, input):
186 | out = self.conv(input)
187 |
188 | if self.activation is not None:
189 | out = self.activation(out)
190 |
191 | return out
192 |
193 |
194 | def shift_down(input, size=1):
195 | return F.pad(input, [0, 0, size, 0])[:, :, : input.shape[2], :]
196 |
197 |
198 | def shift_right(input, size=1):
199 | return F.pad(input, [size, 0, 0, 0])[:, :, :, : input.shape[3]]
200 |
201 |
202 | class CausalConv2d(nn.Module):
203 | def __init__(
204 | self,
205 | in_channel,
206 | out_channel,
207 | kernel_size,
208 | stride=1,
209 | padding='downright',
210 | activation=None,
211 | ):
212 | super().__init__()
213 |
214 | if isinstance(kernel_size, int):
215 | kernel_size = [kernel_size] * 2
216 |
217 | self.kernel_size = kernel_size
218 |
219 | if padding == 'downright':
220 | pad = [kernel_size[1] - 1, 0, kernel_size[0] - 1, 0]
221 |
222 | elif padding == 'down' or padding == 'causal':
223 | pad = kernel_size[1] // 2
224 |
225 | pad = [pad, pad, kernel_size[0] - 1, 0]
226 |
227 | self.causal = 0
228 | if padding == 'causal':
229 | self.causal = kernel_size[1] // 2
230 |
231 | self.pad = nn.ZeroPad2d(pad)
232 |
233 | self.conv = WNConv2d(
234 | in_channel,
235 | out_channel,
236 | kernel_size,
237 | stride=stride,
238 | padding=0,
239 | activation=activation,
240 | )
241 |
242 | def forward(self, input):
243 | out = self.pad(input)
244 |
245 | if self.causal > 0:
246 | self.conv.conv.weight_v.data[:, :, -1, self.causal :].zero_()
247 |
248 | out = self.conv(out)
249 |
250 | return out
251 |
252 |
253 | class GatedResBlock(nn.Module):
254 | def __init__(
255 | self,
256 | in_channel,
257 | channel,
258 | kernel_size,
259 | conv='wnconv2d',
260 | activation=nn.ELU,
261 | dropout=0.1,
262 | auxiliary_channel=0,
263 | condition_dim=0,
264 | ):
265 | super().__init__()
266 |
267 | if conv == 'wnconv2d':
268 | conv_module = partial(WNConv2d, padding=kernel_size // 2)
269 |
270 | elif conv == 'causal_downright':
271 | conv_module = partial(CausalConv2d, padding='downright')
272 |
273 | elif conv == 'causal':
274 | conv_module = partial(CausalConv2d, padding='causal')
275 |
276 | self.activation = activation()
277 | self.conv1 = conv_module(in_channel, channel, kernel_size)
278 |
279 | if auxiliary_channel > 0:
280 | self.aux_conv = WNConv2d(auxiliary_channel, channel, 1)
281 |
282 | self.dropout = nn.Dropout(dropout)
283 |
284 | self.conv2 = conv_module(channel, in_channel * 2, kernel_size)
285 |
286 | if condition_dim > 0:
287 | # self.condition = nn.Linear(condition_dim, in_channel * 2, bias=False)
288 | self.condition = WNConv2d(condition_dim, in_channel * 2, 1, bias=False)
289 |
290 | self.gate = nn.GLU(1)
291 |
292 | def forward(self, input, aux_input=None, condition=None):
293 | out = self.conv1(self.activation(input))
294 |
295 | if aux_input is not None:
296 | out = out + self.aux_conv(self.activation(aux_input))
297 |
298 | out = self.activation(out)
299 | out = self.dropout(out)
300 | out = self.conv2(out)
301 |
302 | if condition is not None:
303 | condition = self.condition(condition)
304 | out += condition
305 | # out = out + condition.view(condition.shape[0], 1, 1, condition.shape[1])
306 |
307 | out = self.gate(out)
308 | out += input
309 |
310 | return out
311 |
312 |
313 | @lru_cache(maxsize=64)
314 | def causal_mask(size):
315 | shape = [size, size]
316 | mask = np.triu(np.ones(shape), k=1).astype(np.uint8).T
317 | start_mask = np.ones(size).astype(np.float32)
318 | start_mask[0] = 0
319 |
320 | return (
321 | torch.from_numpy(mask).unsqueeze(0),
322 | torch.from_numpy(start_mask).unsqueeze(1),
323 | )
324 |
325 |
326 | class CausalAttention(nn.Module):
327 | def __init__(self, query_channel, key_channel, channel, n_head=8, dropout=0.1):
328 | super().__init__()
329 |
330 | self.query = wn_linear(query_channel, channel)
331 | self.key = wn_linear(key_channel, channel)
332 | self.value = wn_linear(key_channel, channel)
333 |
334 | self.dim_head = channel // n_head
335 | self.n_head = n_head
336 |
337 | self.dropout = nn.Dropout(dropout)
338 |
339 | def forward(self, query, key):
340 | batch, _, height, width = key.shape
341 |
342 | def reshape(input):
343 | return input.view(batch, -1, self.n_head, self.dim_head).transpose(1, 2)
344 |
345 | query_flat = query.view(batch, query.shape[1], -1).transpose(1, 2)
346 | key_flat = key.view(batch, key.shape[1], -1).transpose(1, 2)
347 | query = reshape(self.query(query_flat))
348 | key = reshape(self.key(key_flat)).transpose(2, 3)
349 | value = reshape(self.value(key_flat))
350 |
351 | attn = torch.matmul(query, key) / sqrt(self.dim_head)
352 | mask, start_mask = causal_mask(height * width)
353 | mask = mask.type_as(query)
354 | start_mask = start_mask.type_as(query)
355 | attn = attn.masked_fill(mask == 0, -1e4)
356 | attn = torch.softmax(attn, 3) * start_mask
357 | attn = self.dropout(attn)
358 |
359 | out = attn @ value
360 | out = out.transpose(1, 2).reshape(
361 | batch, height, width, self.dim_head * self.n_head
362 | )
363 | out = out.permute(0, 3, 1, 2)
364 |
365 | return out
366 |
367 |
368 | class PixelBlock(nn.Module):
369 | def __init__(
370 | self,
371 | in_channel,
372 | channel,
373 | kernel_size,
374 | n_res_block,
375 | attention=True,
376 | dropout=0.1,
377 | condition_dim=0,
378 | ):
379 | super().__init__()
380 |
381 | resblocks = []
382 | for i in range(n_res_block):
383 | resblocks.append(
384 | GatedResBlock(
385 | in_channel,
386 | channel,
387 | kernel_size,
388 | conv='causal',
389 | dropout=dropout,
390 | condition_dim=condition_dim,
391 | )
392 | )
393 |
394 | self.resblocks = nn.ModuleList(resblocks)
395 |
396 | self.attention = attention
397 |
398 | if attention:
399 | self.key_resblock = GatedResBlock(
400 | in_channel * 2 + 2, in_channel, 1, dropout=dropout
401 | )
402 | self.query_resblock = GatedResBlock(
403 | in_channel + 2, in_channel, 1, dropout=dropout
404 | )
405 |
406 | self.causal_attention = CausalAttention(
407 | in_channel + 2, in_channel * 2 + 2, in_channel // 2, dropout=dropout
408 | )
409 |
410 | self.out_resblock = GatedResBlock(
411 | in_channel,
412 | in_channel,
413 | 1,
414 | auxiliary_channel=in_channel // 2,
415 | dropout=dropout,
416 | )
417 |
418 | else:
419 | self.out = WNConv2d(in_channel + 2, in_channel, 1)
420 |
421 | def forward(self, input, background, condition=None):
422 | out = input
423 |
424 | for resblock in self.resblocks:
425 | out = resblock(out, condition=condition)
426 |
427 | if self.attention:
428 | key_cat = torch.cat([input, out, background], 1)
429 | key = self.key_resblock(key_cat)
430 | query_cat = torch.cat([out, background], 1)
431 | query = self.query_resblock(query_cat)
432 | attn_out = self.causal_attention(query, key)
433 | out = self.out_resblock(out, attn_out)
434 |
435 | else:
436 | bg_cat = torch.cat([out, background], 1)
437 | out = self.out(bg_cat)
438 |
439 | return out
440 |
441 |
442 | class CondResNet(nn.Module):
443 | def __init__(self, in_channel, channel, kernel_size, n_res_block):
444 | super().__init__()
445 |
446 | blocks = [WNConv2d(in_channel, channel, kernel_size, padding=kernel_size // 2)]
447 |
448 | for i in range(n_res_block):
449 | blocks.append(GatedResBlock(channel, channel, kernel_size))
450 |
451 | self.blocks = nn.Sequential(*blocks)
452 |
453 | def forward(self, input):
454 | return self.blocks(input)
455 |
456 |
457 | class PixelSNAIL(nn.Module):
458 | def __init__(
459 | self,
460 | shape,
461 | n_class,
462 | channel,
463 | kernel_size,
464 | n_block,
465 | n_res_block,
466 | res_channel,
467 | attention=True,
468 | dropout=0.1,
469 | n_cond_res_block=0,
470 | cond_res_channel=0,
471 | cond_res_kernel=3,
472 | n_out_res_block=0,
473 | ):
474 | super().__init__()
475 |
476 | height, width = shape
477 |
478 | self.n_class = n_class
479 |
480 | if kernel_size % 2 == 0:
481 | kernel = kernel_size + 1
482 |
483 | else:
484 | kernel = kernel_size
485 |
486 | self.horizontal = CausalConv2d(
487 | 3, channel, [kernel // 2, kernel], padding='down'
488 | )
489 | self.vertical = CausalConv2d(
490 | 3, channel, [(kernel + 1) // 2, kernel // 2], padding='downright'
491 | )
492 |
493 | coord_x = (torch.arange(height).float() - height / 2) / height
494 | coord_x = coord_x.view(1, 1, height, 1).expand(1, 1, height, width)
495 | coord_y = (torch.arange(width).float() - width / 2) / width
496 | coord_y = coord_y.view(1, 1, 1, width).expand(1, 1, height, width)
497 | self.register_buffer('background', torch.cat([coord_x, coord_y], 1))
498 |
499 | self.blocks = nn.ModuleList()
500 |
501 | for i in range(n_block):
502 | self.blocks.append(
503 | PixelBlock(
504 | channel,
505 | res_channel,
506 | kernel_size,
507 | n_res_block,
508 | attention=attention,
509 | dropout=dropout,
510 | condition_dim=cond_res_channel,
511 | )
512 | )
513 |
514 | if n_cond_res_block > 0:
515 | self.cond_resnet = CondResNet(
516 | n_class, cond_res_channel, cond_res_kernel, n_cond_res_block
517 | )
518 |
519 | out = []
520 |
521 | for i in range(n_out_res_block):
522 | out.append(GatedResBlock(channel, res_channel, 1))
523 |
524 | out.extend([nn.ELU(inplace=True), WNConv2d(channel, n_class, 1)])
525 |
526 | self.out = nn.Sequential(*out)
527 |
528 | def forward(self, input, condition=None, cache=None):
529 | if cache is None:
530 | cache = {}
531 | batch, _, height, width = input.shape
532 | #input = (
533 | # F.one_hot(input, self.n_class).permute(0, 3, 1, 2).type_as(self.background)
534 | #)
535 | horizontal = shift_down(self.horizontal(input))
536 | vertical = shift_right(self.vertical(input))
537 | out = horizontal + vertical
538 |
539 | background = self.background[:, :, :height, :].expand(batch, 2, height, width)
540 |
541 | if condition is not None:
542 | if 'condition' in cache:
543 | condition = cache['condition']
544 | condition = condition[:, :, :height, :]
545 |
546 | else:
547 | condition = (
548 | F.one_hot(condition, self.n_class)
549 | .permute(0, 3, 1, 2)
550 | .type_as(self.background)
551 | )
552 | condition = self.cond_resnet(condition)
553 | condition = F.interpolate(condition, scale_factor=2)
554 | cache['condition'] = condition.detach().clone()
555 | condition = condition[:, :, :height, :]
556 |
557 | for block in self.blocks:
558 | out = block(out, background, condition=condition)
559 |
560 | out = self.out(out)
561 |
562 | return out
563 |
564 |
565 |
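566 | # Shape sketch (illustrative): shift_down / shift_right pad one row / column
567 | # of zeros and crop back to the input size, so a causal stack never sees the
568 | # current pixel:
569 | #
570 | #   x = torch.arange(9.).view(1, 1, 3, 3)
571 | #   shift_down(x)[0, 0]   # row i holds the original row i-1; row 0 is zeros
572 | #   shift_right(x)[0, 0]  # column j holds the original column j-1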
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | from torch.optim import Optimizer
4 | import math
5 |
6 |
7 | class AdamNormGrad(Optimizer):
8 | """Implements Adam algorithm.
9 |
10 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
11 |
12 | Arguments:
13 | params (iterable): iterable of parameters to optimize or dicts defining
14 | parameter groups
15 | lr (float, optional): learning rate (default: 1e-3)
16 | betas (Tuple[float, float], optional): coefficients used for computing
17 | running averages of gradient and its square (default: (0.9, 0.999))
18 | eps (float, optional): term added to the denominator to improve
19 | numerical stability (default: 1e-8)
20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
21 |
22 | .. _Adam\: A Method for Stochastic Optimization:
23 | https://arxiv.org/abs/1412.6980
24 | """
25 |
26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
27 | weight_decay=0):
28 | defaults = dict(lr=lr, betas=betas, eps=eps,
29 | weight_decay=weight_decay)
30 | super(AdamNormGrad, self).__init__(params, defaults)
31 |
32 | def step(self, closure=None):
33 | """Performs a single optimization step.
34 |
35 | Arguments:
36 | closure (callable, optional): A closure that reevaluates the model
37 | and returns the loss.
38 | """
39 | loss = None
40 | if closure is not None:
41 | loss = closure()
42 |
43 | for group in self.param_groups:
44 | for p in group['params']:
45 | if p.grad is None:
46 | continue
47 | grad = p.grad.data
48 |                 # normalize gradients
49 | grad = grad / ( torch.norm(grad,2) + 1.e-7 )
50 | state = self.state[p]
51 |
52 | # State initialization
53 | if len(state) == 0:
54 | state['step'] = 0
55 | # Exponential moving average of gradient values
56 |                     state['exp_avg'] = torch.zeros_like(grad)
57 |                     # Exponential moving average of squared gradient values
58 |                     state['exp_avg_sq'] = torch.zeros_like(grad)
59 |
60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
61 | beta1, beta2 = group['betas']
62 |
63 | state['step'] += 1
64 |
65 |                 if group['weight_decay'] != 0:
66 |                     grad = grad.add(p.data, alpha=group['weight_decay'])
67 |
68 |                 # Decay the first and second moment running average coefficient
69 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
70 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
71 |
72 | denom = exp_avg_sq.sqrt().add_(group['eps'])
73 |
74 | bias_correction1 = 1 - beta1 ** state['step']
75 | bias_correction2 = 1 - beta2 ** state['step']
76 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
77 |
78 |                 p.data.addcdiv_(exp_avg, denom, value=-step_size)
79 |
80 | return loss
81 |
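82 | # Usage sketch (illustrative): AdamNormGrad is a drop-in replacement for
83 | # torch.optim.Adam; the only difference is the per-parameter gradient
84 | # normalization at the top of step():
85 | #
86 | #   optimizer = AdamNormGrad(model.parameters(), lr=5e-4)
87 | #   loss.backward(); optimizer.step(); optimizer.zero_grad()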
--------------------------------------------------------------------------------
/utils/plot_images.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib.gridspec as gridspec
3 | import numpy as np
4 | import os
5 |
6 |
7 | def imshow(img, title=None, interpolation=None, show_plot=False):
8 | npimg = img.detach().cpu().numpy()
9 | plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation=interpolation)
10 | if title is not None:
11 | plt.title(title)
12 | if show_plot:
13 | plt.show()
14 |
15 |
16 | def generate_fancy_grid(config, dir, reference_data, generated, col_num=4, row_num=3):
17 | import cv2
18 |
19 | image_size = config.input_size[-1]
20 | width = col_num*image_size+2
21 | height = row_num*image_size+2
22 |
23 | print('references', reference_data.shape)
24 | print('generated', generated.shape)
25 |
26 | generated_dir = os.path.join(dir, 'generated/')
27 | os.makedirs(generated_dir, exist_ok=True)
28 |
29 | for k in range(len(reference_data)):
30 | grid = np.ones((config.input_size[0], height, width))
31 | original_image = reference_data[k].reshape(1, *config.input_size).cpu().detach().numpy()
32 | grid[:, 0:image_size, 0:image_size] = original_image
33 | generated_images = generated[k].reshape(-1, *config.input_size).cpu().detach().numpy()
34 | offset = 2
35 | counts = 0
36 | for i in range(row_num):
37 | j_counts = col_num
38 | extra_offset = 0
39 | if i == 0:
40 | j_counts = col_num-1
41 | extra_offset = image_size
42 |
43 | row = i*image_size+offset
44 | for j in range(j_counts):
45 |                 # write the next generated sample into its grid cell
46 |                 grid[:, row:row+image_size, extra_offset+j*image_size+offset:extra_offset+(j+1)*image_size+offset] = generated_images[counts]
47 |                 counts += 1
48 |
49 | if config.input_size[0] > 1:
50 | grid = np.transpose(grid, (1, 2, 0))
51 | grid = np.squeeze(grid)
52 | plt.imsave(arr=np.clip(grid, 0, 1),
53 | fname=generated_dir + "generated_{}.png".format(k),
54 | cmap='gray', format='png')
55 |
56 | img = cv2.imread(generated_dir + "generated_{}.png".format(k))
57 | res = cv2.resize(img, dsize=(width*3, height*3), interpolation=cv2.INTER_NEAREST)
58 | cv2.imwrite(generated_dir + "generated_{}.png".format(k), res)
59 | # plt.show()
60 |
61 |
62 | def plot_images_in_line(images, args, dir, file_name):
63 | import cv2
64 |
65 | width = len(images) * 28
66 | height = 28
67 | grid = np.ones((height, width))
68 | for index, image in enumerate(images):
69 | image = image.reshape(*args.input_size).cpu().detach().numpy()
70 | grid[0:28, 28*index:28*(index+1)] = image[0]
71 | file_name = os.path.join(dir, file_name)
72 | plt.imsave(arr=grid / 255,
73 | fname=file_name,
74 | cmap='gray', format='png')
75 |
76 | img = cv2.imread(file_name)
77 | res = cv2.resize(img, dsize=(width*3, height*3), interpolation=cv2.INTER_NEAREST)
78 | cv2.imwrite(file_name, res)
79 |
80 |
81 | def plot_images(config, x_sample, dir, file_name, size_x=3, size_y=3):
82 | if len(x_sample.shape) < 4:
83 | x_sample = x_sample.reshape(-1, *config.input_size)
84 | fig = plt.figure(figsize=(size_x, size_y))
85 | # fig = plt.figure(1)
86 | gs = gridspec.GridSpec(size_x, size_y)
87 | gs.update(wspace=0.01, hspace=0.01)
88 |
89 | for i, sample in enumerate(x_sample):
90 | ax = plt.subplot(gs[i])
91 | plt.axis('off')
92 | ax.set_xticklabels([])
93 | ax.set_yticklabels([])
94 | ax.set_aspect('equal')
95 |
96 | sample = sample.swapaxes(0, 2)
97 | sample = sample.swapaxes(0, 1)
98 | if config.input_type == 'binary' or config.input_type == 'gray':
99 | sample = sample[:, :, 0]
100 | plt.imshow(sample, cmap='gray')
101 | else:
102 | plt.imshow(sample)
103 |
104 | plt.savefig(dir + file_name + '.png', bbox_inches='tight')
105 | plt.close(fig)
106 |
107 |
108 |
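109 | # Usage sketch (hypothetical config object): plot_images accepts a flat numpy
110 | # batch and reshapes it with config.input_size before drawing the grid:
111 | #
112 | #   x = np.random.rand(9, 784)   # nine flattened 1x28x28 samples
113 | #   plot_images(config, x, 'snapshots/', 'demo', size_x=3, size_y=3)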
--------------------------------------------------------------------------------
/utils/training.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 |
4 |
5 | def set_beta(args, epoch):
6 | if args.warmup == 0:
7 | beta = 1.
8 | else:
9 | beta = 1. * epoch / args.warmup
10 | if beta > 1.:
11 | beta = 1.
12 | return beta
13 |
14 |
15 | def train_one_epoch(epoch, args, train_loader, model, optimizer):
16 | train_loss, train_re, train_kl = 0, 0, 0
17 | model.train()
18 | beta = set_beta(args, epoch)
19 | print('beta: {}'.format(beta))
20 | if args.approximate_prior is True:
21 | with torch.no_grad():
22 | cached_z, cached_log_var = model.cache_z(train_loader.dataset)
23 | cache = (cached_z, cached_log_var)
24 | else:
25 | cache = None
26 |
27 | for batch_idx, (data, indices, target) in enumerate(train_loader):
28 | data, indices, target = data.to(args.device), indices.to(args.device), target.to(args.device)
29 |
30 | if args.dynamic_binarization:
31 | x = torch.bernoulli(data)
32 | else:
33 | x = data
34 |
35 | x = (x, indices)
36 | optimizer.zero_grad()
37 | loss, RE, KL = model.calculate_loss(x, beta, average=True, cache=cache, dataset=train_loader.dataset)
38 | loss.backward()
39 | optimizer.step()
40 |
41 | with torch.no_grad():
42 | train_loss += loss.data.item()
43 | train_re += -RE.data.item()
44 | train_kl += KL.data.item()
45 | if cache is not None:
46 | cache = (cache[0].detach(), cache[1].detach())
47 |
48 | train_loss /= len(train_loader)
49 | train_re /= len(train_loader)
50 | train_kl /= len(train_loader)
51 | return train_loss, train_re, train_kl
52 |
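53 | # Example (illustrative): with args.warmup = 100, set_beta ramps the KL weight
54 | # linearly over the warm-up epochs and then clamps it at 1:
55 | #
56 | #   epoch:  1     50    100   200
57 | #   beta :  0.01  0.50  1.00  1.00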
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | def importing_model(args):
5 | if args.model_name == 'vae':
6 | from models.VAE import VAE
7 | elif args.model_name == 'hvae_2level':
8 | from models.HVAE_2level import VAE
9 | elif args.model_name == 'convhvae_2level':
10 | from models.convHVAE_2level import VAE
11 | elif args.model_name == 'new_vae':
12 | from models.new_vae import VAE
13 | elif args.model_name == 'single_conv':
14 | from models.fully_conv import VAE
15 | elif args.model_name == 'pixelcnn':
16 | from models.PixelCNN import VAE
17 | else:
18 | raise Exception('Wrong name of the model!')
19 | return VAE
20 |
21 |
22 | def save_model(save_path, load_path, content):
23 | torch.save(content, save_path)
24 | os.rename(save_path, load_path)
25 |
26 |
27 | def load_model(load_path, model, optimizer=None):
28 | checkpoint = torch.load(load_path)
29 | model.load_state_dict(checkpoint['state_dict'])
30 | if optimizer is not None:
31 | optimizer.load_state_dict(checkpoint['optimizer'])
32 | return checkpoint
33 |
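34 | # Usage sketch: save_model writes to a temporary path and then renames it, so
35 | # an interrupted write never corrupts the checkpoint that load_model reads
36 | # (the rename is atomic when both paths are on the same filesystem):
37 | #
38 | #   content = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
39 | #   save_model(dir + 'checkpoint.tmp', dir + 'checkpoint.pth', content)
40 | #   checkpoint = load_model(dir + 'checkpoint.pth', model, optimizer)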
--------------------------------------------------------------------------------