├── Load_model_visual.py
├── Nmetrics.py
├── README.txt
├── main.py
├── models
│   ├── Multi-COIL-10.pt
│   ├── Multi-FMNIST.pt
│   ├── Multi-MNIST.pt
│   └── README.txt
├── multi_vae
│   ├── MvModels.py
│   └── MvTraining.py
├── requirements.txt
├── utils
│   ├── DATA
│   │   ├── Multi-COIL-10.mat
│   │   └── README.txt
│   ├── MvDataloaders.py
│   └── MvLoad_models.py
└── viz
    ├── MvVisualize.py
    └── latent_traversals.py

/Load_model_visual.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 | # Look at architecture and latent spec
4 | from utils.MvDataloaders import Get_dataloaders
5 | from utils.MvLoad_models import load
6 | 
7 | # Choose the dataset, the view to display, and whether to show or save figures
8 | show = 1  # 1: display figures with matplotlib; 0: save them to disk
9 | datasets = ['Multi-MNIST', 'Multi-FMNIST', 'Multi-COIL-10']
10 | settings = [[1, 2], [1, 3], [1, 3]]  # [shareAE, view_num]
11 | d_idx = 2
12 | show_view = 1
13 | DATA = datasets[d_idx]
14 | share = settings[d_idx][0]
15 | view_num = settings[d_idx][1]
16 | path_to_model_folder = './models/' + DATA + '.pt'
17 | latent_spec = {"disc": [10], "cont": 10}
18 | model = load(path=path_to_model_folder, view_num=view_num, shareAE=share, latent_spec=latent_spec)
19 | 
20 | # Print the latent distribution info
21 | print(model.MvLatent_spec)
22 | # Print model architecture
23 | print(model)
24 | 
25 | # Visualize various aspects of the model
26 | from viz.MvVisualize import Visualizer as Viz
27 | 
28 | # Create a Visualizer for the model
29 | viz = Viz(model, view_num=view_num, show_view=show_view)
30 | if show == 0:
31 |     viz.save_images = True  # Save images to disk
32 | else:
33 |     viz.save_images = False  # Return tensors instead of saving images
34 | 
35 | # Plot generated samples from the model
36 | import matplotlib.pyplot as plt
37 | import pylab
38 | 
39 | plt.figure('Plot generated samples from the model')
40 | print("Plot generated samples from the model")
41 | samples = viz.samples()
42 | if show == 1:
43 |     plt.imshow(samples.numpy()[0, :, :], cmap='gray')
44 |     pylab.show()
45 | 
46 | # Plot traversal of single dimension
47 | plt.figure('Plot traversal of single dimension')
48 | print("Plot traversal of single dimension")
49 | traversal = viz.latent_traversal_line(disc_idx=0, size=10)
50 | if show == 1:
51 |     plt.imshow(traversal.numpy()[0, :, :], cmap='gray')
52 |     pylab.show()
53 | # All latent traversals
54 | plt.figure('All latent traversals')
55 | print("All latent traversals")
56 | traversals = viz.all_latent_traversals()
57 | if show == 1:
58 |     plt.imshow(traversals.numpy()[0, :, :], cmap='gray')
59 |     pylab.show()
60 | 
61 | # Plot a grid of two interesting traversals
62 | # discrete latent dimension across rows
63 | plt.figure('Plot a grid of two interesting traversals')
64 | print("Plot a grid of two interesting traversals")
65 | traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))
66 | if show == 1:
67 |     plt.imshow(traversals.numpy()[0, :, :], cmap='gray')
68 |     pylab.show()
69 | 
70 | # Reorder discrete latent to match order of digits
71 | from viz.MvVisualize import reorder_img
72 | 
73 | if show == 1:
74 |     ordering = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # The 9th dimension corresponds to 0, the 3rd to 1, etc.
75 | traversals = reorder_img(traversals, ordering, by_row=True) 76 | plt.figure('Reorder discrete latent to match order of digits') 77 | print("Reorder discrete latent to match order of digits") 78 | plt.imshow(traversals.numpy()[0, :, :], cmap='gray') 79 | pylab.show() 80 | 81 | for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: 82 | plt.figure(i) 83 | print('cont_idx:' + str(i)) 84 | traversal = viz.latent_traversal_grid(filename='cont_idx_' + str(i) + '.png', cont_idx=i, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10)) 85 | if show == 1: 86 | plt.imshow(traversal.numpy()[0, :, :], cmap='gray') 87 | pylab.show() 88 | 89 | # Plot reconstructions 90 | dataloader, _, _, _ = Get_dataloaders(batch_size=50, DATANAME=DATA + '.mat') 91 | # Extract a batch of data 92 | datarecon = [] 93 | if view_num == 2: 94 | for data1, data2, labels in dataloader: 95 | break 96 | datarecon = [data1, data2] 97 | else: 98 | for data1, data2, data3, labels in dataloader: 99 | break 100 | datarecon = [data1, data2, data3] 101 | recon = viz.reconstructions(datarecon, size=(10, 10)) 102 | 103 | plt.figure('Plot reconstructions') 104 | print("Plot reconstructions") 105 | if show == 1: 106 | plt.imshow(recon.numpy()[0, :, :], cmap='gray') 107 | pylab.show() 108 | 109 | print("------END-----") 110 | -------------------------------------------------------------------------------- /Nmetrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, v_measure_score, accuracy_score 3 | 4 | nmi = normalized_mutual_info_score 5 | vmeasure = v_measure_score 6 | ari = adjusted_rand_score 7 | 8 | 9 | def acc(y_true, y_pred): 10 | """ 11 | Calculate clustering accuracy. 
Requires SciPy's linear_sum_assignment (Hungarian algorithm).
12 | 
13 |     # Arguments
14 |         y_true: true labels, numpy.array with shape `(n_samples,)`
15 |         y_pred: predicted labels, numpy.array with shape `(n_samples,)`
16 | 
17 |     # Return
18 |         accuracy, in [0,1]
19 |     """
20 |     y_true = y_true.astype(np.int64)
21 |     assert y_pred.size == y_true.size
22 |     D = max(y_pred.max(), y_true.max()) + 1
23 |     w = np.zeros((D, D), dtype=np.int64)
24 |     for i in range(y_pred.size):
25 |         w[y_pred[i], y_true[i]] += 1
26 |     # Hungarian matching; sklearn.utils.linear_assignment_ was removed in
27 |     # scikit-learn 0.23, so use the SciPy implementation instead.
28 |     from scipy.optimize import linear_sum_assignment
29 |     row_ind, col_ind = linear_sum_assignment(w.max() - w)
30 |     return w[row_ind, col_ind].sum() * 1.0 / y_pred.size
31 | 
32 | def purity(y_true, y_pred):
33 |     """Purity score
34 |     Args:
35 |         y_true(np.ndarray): n*1 matrix Ground truth labels
36 |         y_pred(np.ndarray): n*1 matrix Predicted clusters
37 | 
38 |     Returns:
39 |         float: Purity score
40 |     """
41 |     # matrix which will hold the majority-voted labels
42 |     y_voted_labels = np.zeros(y_true.shape)
43 |     # Ordering labels
44 |     ## Labels might be missing, e.g. a label set like {0, 2} where 1 is missing;
45 |     ## first find the unique labels, then map them to an ordered set,
46 |     ## so 0,2 becomes 0,1
47 |     labels = np.unique(y_true)
48 |     ordered_labels = np.arange(labels.shape[0])
49 |     for k in range(labels.shape[0]):
50 |         y_true[y_true==labels[k]] = ordered_labels[k]
51 |     # Update unique labels
52 |     labels = np.unique(y_true)
53 |     # Append one extra bin edge so that np.histogram counts each class in
54 |     # its own half-open bin [label_i, label_i + 1), the upper edge being
55 |     # excluded
56 |     bins = np.concatenate((labels, [np.max(labels)+1]), axis=0)
57 | 
58 |     for cluster in np.unique(y_pred):
59 |         hist, _ = np.histogram(y_true[y_pred==cluster], bins=bins)
60 |         # Find the most present label in the cluster
61 |         winner = np.argmax(hist)
62 |         y_voted_labels[y_pred==cluster] = winner
63 | 
64 |     return accuracy_score(y_true, y_voted_labels)
65 | 
66 | def test(y_true, y_pred):
67 |     print("ACC:%.4f, NMI:%.4f, VME:%.4f, ARI:%.4f, PUR:%.4f" % (acc(y_true, y_pred),
68 |                                                                 nmi(y_true, y_pred),
69 |                                                                 vmeasure(y_true, y_pred),
70 |                                                                 ari(y_true, y_pred),
71 |                                                                 purity(y_true, y_pred)))
72 |     return nmi(y_true, y_pred)
--------------------------------------------------------------------------------
/README.txt:
--------------------------------------------------------------------------------
1 | # Accepted by ICCV 2021.
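
# Evaluate clustering labels
The metrics in Nmetrics.py above can also be called directly; a minimal
sketch with toy labels (the arrays below are illustrative only):

    import numpy as np
    import Nmetrics
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([1, 1, 0, 0, 2, 2])  # cluster ids permuted w.r.t. labels
    Nmetrics.test(y_true, y_pred)          # ACC is 1.0 after Hungarian matching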
2 | 
3 | # Settings in main.py
4 | 
5 | # TEST = True
6 | when TEST = True, the code only tests the trained Multi-VAE model
7 | when TEST = False, the code will train the Multi-VAE model
8 | 
9 | # Run the code:
10 | python main.py
11 | 
12 | # Visualize the generative model:
13 | python Load_model_visual.py
14 | 
15 | # BibTex
16 | @InProceedings{Xu_2021_ICCV,
17 |     author = {Xu, Jie and Ren, Yazhou and Tang, Huayi and Pu, Xiaorong and Zhu, Xiaofeng and Zeng, Ming and He, Lifang},
18 |     title = {Multi-{VAE}: Learning Disentangled View-Common and View-Peculiar Visual Representations for Multi-View Clustering},
19 |     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
20 |     year = {2021},
21 |     pages = {9234-9243}
22 | }
23 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 | from utils.MvDataloaders import Get_dataloaders
4 | from utils.MvLoad_models import load
5 | from sklearn.cluster import KMeans
6 | from sklearn import preprocessing
7 | min_max_scaler = preprocessing.MinMaxScaler()
8 | import numpy as np
9 | import torch
10 | import scipy.io as scio
11 | import os
12 | import random
13 | 
14 | 
15 | def setup_seed(seed):
16 |     torch.manual_seed(seed)
17 |     torch.cuda.manual_seed_all(seed)
18 |     np.random.seed(seed)
19 |     random.seed(seed)
20 |     torch.backends.cudnn.deterministic = True
21 | 
22 | 
23 | setup_seed(5)
24 | 
25 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
26 | 
27 | NMI_c = []
28 | NMI_cz = []
29 | ACC_c = []
30 | ACC_cz = []
31 | 
32 | datasets = ['Multi-COIL-10',
33 |             'Multi-COIL-20',
34 |             'Object-Digit-Product',
35 |             'Multi-MNIST',
36 |             'Multi-FMNIST',
37 |             'Digit-Product']
38 | settings = [[1, 32], [1, 32], [0, 32], [1, 64], [1, 64], [0, 64]]  # share autoencoder, Batch_size
39 | iters_to_add_capacity = [25000, 25000, 25000, 25000, 25000, 25000]
40 | for d in [0]:  # datasets index
41 |     DATA = datasets[d]
42 |     share = settings[d][0]
43 |     Batch_size = settings[d][1]
44 |     iters_add_capacity = iters_to_add_capacity[d]
45 |     Epochs = 500
46 |     lr = 5e-4
47 |     Net = 'C'  # CNN
48 |     hidden_dim = 256
49 |     z_variables = 10
50 |     # for beta in [10, 20, 30, 40, 50]:
51 |     #     for capacity in [3, 4, 5, 6, 7]:
52 |     runs = 1
53 |     TEST = True
54 |     for beta in [30]:
55 |         for capacity in [5]:
56 |             ACCc = 0
57 |             NMIc = 0
58 |             ARIc = 0
59 |             PURc = 0
60 |             ACCcz = 0
61 |             NMIcz = 0
62 |             ARIcz = 0
63 |             PURcz = 0
64 |             for i in range(runs):
65 |                 model_name = DATA + '.pt'
66 |                 print('Run:' + str(i))
67 |                 if Net == 'C':
68 |                     train_loader, view_num, n_clusters, size = Get_dataloaders(batch_size=Batch_size,
69 |                                                                                DATANAME=DATA + '.mat')  # 64
70 |                 print('Iters:' + str(size / Batch_size * Epochs))
71 |                 # Define the capacities of the latent channels:
72 |                 # [max_capacity, gamma (beta), iters_to_reach_max_capacity]
73 |                 # for the continuous and discrete channels respectively
74 |                 # iters_add_capacity = size/Batch_size*Epochs
75 |                 cont_capacity = [capacity, beta, iters_add_capacity]
76 |                 disc_capacity = [np.log(n_clusters), beta, iters_add_capacity]
77 | 
78 |                 latent_spec = {'cont': z_variables,  # view-peculiar variable (representation)
79 |                                'disc': [n_clusters]}  # view-common variable (representation)
80 |                 use_cuda = torch.cuda.is_available()
81 |                 print("cuda is available?")
82 |                 print(use_cuda)
83 | 
84 |                 # img_size=(3, 64, 64)
85 |                 # img_size=(3, 32, 32)
86 |                 img_size = (1, 32, 32)
87 |                 if TEST == False:
88 |                     # Build a model
89 |                     from multi_vae.MvModels import VAE
90 | 
91 |                     model = VAE(latent_spec=latent_spec,
img_size=img_size, 92 | view_num=view_num, use_cuda=use_cuda, 93 | Network=Net, hidden_dim=hidden_dim, shareAE=share) 94 | if use_cuda: 95 | model.cuda() 96 | 97 | print(model) 98 | 99 | # Train the model 100 | from torch import optim 101 | 102 | # Build optimizer 103 | optimizer = optim.Adam(model.parameters(), lr=lr) 104 | 105 | from multi_vae.MvTraining import Trainer 106 | 107 | # Build a trainer 108 | trainer = Trainer(model, optimizer, 109 | cont_capacity=cont_capacity, 110 | disc_capacity=disc_capacity, view_num=view_num, use_cuda=use_cuda, DATA=DATA) 111 | 112 | # Train model for Epochs 113 | trainer.train(train_loader, epochs=Epochs) 114 | torch.save(trainer.model.state_dict(), './models/' + model_name) 115 | # print('save model?') 116 | TEST = True 117 | 118 | if TEST == True: 119 | path_to_model_folder = './models/' + model_name 120 | batch_size_test = 140000 # max number to cover all dataset. 121 | if Net == 'C': 122 | train_loader, view_num, n_clusters, _ = Get_dataloaders(batch_size=batch_size_test, 123 | DATANAME=DATA + '.mat') 124 | model = load(latent_spec=latent_spec, 125 | path=path_to_model_folder, 126 | view_num=view_num, 127 | img_size=img_size, 128 | Network=Net, 129 | hid=hidden_dim, shareAE=share) 130 | 131 | # Print the latent distribution info 132 | print(model.MvLatent_spec) 133 | 134 | # Print model architecture 135 | print(model) 136 | from torch.autograd import Variable 137 | 138 | for batch_idx, Data in enumerate(train_loader): 139 | break 140 | data = Data[0:-1] 141 | labels = Data[-1] 142 | inputs = [] 143 | for i in range(view_num): 144 | inputs.append(Variable(data[i])) 145 | encodings = model.encode(inputs) 146 | 147 | import Nmetrics 148 | 149 | kmeans = KMeans(n_clusters=n_clusters, n_init=100) 150 | # Discrete encodings, view-common variable 151 | x = encodings['disc'][0].cpu().detach().data.numpy() 152 | multiview_z = [] 153 | multiview_cz = [] 154 | for i in range(view_num): 155 | name = 'cont' + str(i + 1) 156 | # Continuous encodings, view-peculiar variables 157 | x_c = encodings[name][0].cpu().detach().data.numpy() # z 158 | xi = min_max_scaler.fit_transform(x_c) # scale to [0,1] 159 | multiview_z.append(np.concatenate([xi, x], axis=1)) # z + c 160 | multiview_cz.append(xi) 161 | print(multiview_z[-1].shape) 162 | print(multiview_z[-1][0]) 163 | y = labels.cpu().detach().data.numpy() 164 | 165 | p = kmeans.fit_predict(x) 166 | print('k-means on C') 167 | print(x.shape) 168 | from Nmetrics import test 169 | 170 | test(y, p) 171 | p = x.argmax(1) 172 | print('Multi-VAE-C: y = C.argmax(1)') 173 | test(y, p) 174 | ACCc += Nmetrics.acc(y, p) 175 | NMIc += Nmetrics.nmi(y, p) 176 | ARIc += Nmetrics.ari(y, p) 177 | PURc += Nmetrics.purity(y, p) 178 | 179 | X_all = np.concatenate(multiview_cz, axis=1) 180 | p = kmeans.fit_predict(X_all) 181 | print('k-means on [z1, z2, ..., zV]') 182 | print(X_all.shape) 183 | test(y, p) 184 | print('k-means on [zv]\nk-means on [C, zv]') 185 | print(multiview_cz[0].shape, multiview_z[0].shape) 186 | for i in range(view_num): 187 | name = 'cont' + str(i + 1) 188 | x_cz = encodings[name][0].cpu().detach().data.numpy() 189 | x_Conz = multiview_z[i] 190 | p = kmeans.fit_predict(x_cz) 191 | test(y, p) 192 | p = kmeans.fit_predict(x_Conz) 193 | test(y, p) 194 | print('\n') 195 | 196 | multiview_cz.append(x) 197 | X_all = np.concatenate(multiview_cz, axis=1) 198 | p = kmeans.fit_predict(X_all) 199 | # scio.savemat('./viz/' + str(Epochs) + '.mat', {'Z': X_all, 'Y': y, 'P': p}) 200 | print('Multi-VAE-CZ: k-means on [C, z1, z2, 
..., zV]') 201 | print(X_all.shape) 202 | test(y, p) 203 | ACCcz += Nmetrics.acc(y, p) 204 | NMIcz += Nmetrics.nmi(y, p) 205 | ARIcz += Nmetrics.ari(y, p) 206 | PURcz += Nmetrics.purity(y, p) 207 | 208 | print('Multi-VAE-C:', ACCc / runs, NMIc / runs, ARIc / runs, PURc / runs) 209 | print('Multi-VAE-CZ:', ACCcz / runs, NMIcz / runs, ARIcz / runs, PURcz / runs) 210 | # np.save('Cmetics.npy', [ACCc/runs, NMIc/runs, ARIc/runs, PURc/runs]) 211 | # np.save('CZmetics.npy', [ACCcz/runs, NMIcz/runs, ARIcz/runs, PURcz/runs]) 212 | NMI_c.append(NMIc / runs) 213 | NMI_cz.append(NMIcz / runs) 214 | ACC_c.append(ACCc / runs) 215 | ACC_cz.append(ACCcz / runs) 216 | # print(NMI_c) 217 | # np.save('NMI_c.npy', NMI_c) 218 | # print(NMI_cz) 219 | # np.save('NMI_cz.npy', NMI_cz) 220 | # print(ACC_c) 221 | # np.save('ACC_c.npy', ACC_c) 222 | # print(ACC_cz) 223 | # np.save('ACC_cz.npy', ACC_cz) 224 | -------------------------------------------------------------------------------- /models/Multi-COIL-10.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SubmissionsIn/Multi-VAE/f1b9d7aac8d896267d349b577418eb6a31e31145/models/Multi-COIL-10.pt -------------------------------------------------------------------------------- /models/Multi-FMNIST.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SubmissionsIn/Multi-VAE/f1b9d7aac8d896267d349b577418eb6a31e31145/models/Multi-FMNIST.pt -------------------------------------------------------------------------------- /models/Multi-MNIST.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SubmissionsIn/Multi-VAE/f1b9d7aac8d896267d349b577418eb6a31e31145/models/Multi-MNIST.pt -------------------------------------------------------------------------------- /models/README.txt: -------------------------------------------------------------------------------- 1 | Link:https://pan.baidu.com/s/1CHQ5EDgNPkW37uanDztGjQ 2 | Pass code:todz 3 | -------------------------------------------------------------------------------- /multi_vae/MvModels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | EPS = 1e-12 5 | 6 | 7 | class VAE(nn.Module): 8 | def __init__(self, img_size, latent_spec, temperature=.67, 9 | use_cuda=False, view_num=2, Network='C', 10 | hidden_dim=256, shareAE=1): 11 | """ 12 | Class which defines model and forward pass. 13 | 14 | Parameters 15 | ---------- 16 | img_size : tuple of ints 17 | Size of images. E.g. (1, 32, 32) 18 | 19 | latent_spec : dict 20 | Specifies latent distribution. 21 | 22 | temperature : float 23 | Temperature for gumbel softmax distribution. 
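            Lower temperatures push the relaxed samples closer to one-hot
            vectors; the default of .67 follows the joint-VAE reference
            implementation this code builds on.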
24 | 
25 |         use_cuda : bool
26 |             If True moves model to GPU
27 |         """
28 |         super(VAE, self).__init__()
29 |         self.use_cuda = use_cuda
30 | 
31 |         # Parameters
32 |         self.img_size = img_size
33 |         self.is_continuous = 'cont' in latent_spec
34 |         self.is_discrete = 'disc' in latent_spec
35 |         self.latent_spec = latent_spec
36 | 
37 |         self.view_num = view_num
38 |         self.MvLatent_spec = {'disc': latent_spec['disc']}
39 |         for i in range(self.view_num):
40 |             continue_name = 'cont' + str(i + 1)
41 |             self.MvLatent_spec[continue_name] = latent_spec['cont']
42 | 
43 |         self.Net = Network
44 |         self.share = shareAE
45 |         if self.Net == 'C':
46 |             self.num_pixels = []
47 |             for i in range(self.view_num):
48 |                 self.num_pixels.append(img_size[1] * img_size[2])
49 | 
50 |         self.temperature = temperature
51 |         self.hidden_dim = hidden_dim  # Hidden dimension of linear layer
52 |         self.reshape = (64, 4, 4)  # Shape required to start transpose convs
53 | 
54 |         # Calculate dimensions of latent distribution
55 |         self.latent_cont_dim = 0
56 |         self.latent_disc_dim = 0
57 |         self.num_disc_latents = 0
58 |         if self.is_continuous:
59 |             self.latent_cont_dim = self.latent_spec['cont']
60 |         if self.is_discrete:
61 |             self.latent_disc_dim += sum([dim for dim in self.latent_spec['disc']])
62 |             self.num_disc_latents = len(self.latent_spec['disc'])
63 |         self.latent_dim = self.latent_cont_dim + self.latent_disc_dim
64 | 
65 |         if self.Net == 'C':
66 |             Mv_img_to_features = []
67 |             Mv_features_to_hidden = []
68 |             Mv_latent_to_features = []
69 |             Mv_features_to_img = []
70 |             # Define encoder layers
71 |             for i in range(self.view_num):
72 |                 # Initial layer
73 |                 encoder_layers = [
74 |                     nn.Conv2d(self.img_size[0], 32, (4, 4), stride=2, padding=1),
75 |                     nn.ReLU()
76 |                 ]
77 |                 # Add additional layer if (64, 64) images
78 |                 if self.img_size[1:] == (64, 64):
79 |                     encoder_layers += [
80 |                         nn.Conv2d(32, 32, (4, 4), stride=2, padding=1),
81 |                         nn.ReLU()
82 |                     ]
83 |                 elif self.img_size[1:] == (32, 32):
84 |                     # (32, 32) images are supported but do not require an extra layer
85 |                     pass
86 |                 else:
87 |                     raise RuntimeError(
88 |                         "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported.
Build your own architecture or reshape images!".format( 89 | img_size)) 90 | # Add final layers 91 | encoder_layers += [ 92 | nn.Conv2d(32, 64, (4, 4), stride=2, padding=1), 93 | nn.ReLU(), 94 | nn.Conv2d(64, 64, (4, 4), stride=2, padding=1), 95 | nn.ReLU() 96 | ] 97 | # Define encoder 98 | Mv_img_to_features.append(nn.Sequential(*encoder_layers)) 99 | # Map encoded features into a hidden vector which will be used to 100 | # encode parameters of the latent distribution 101 | features_to_hidden = nn.Sequential( 102 | nn.Linear(64 * 4 * 4, self.hidden_dim), 103 | nn.ReLU() 104 | # nn.Sigmoid() 105 | ) 106 | Mv_features_to_hidden.append(features_to_hidden) 107 | 108 | if self.share: 109 | self.Mv_img_to_features = nn.ModuleList([Mv_img_to_features[0]]) 110 | self.Mv_features_to_hidden = nn.ModuleList([Mv_features_to_hidden[0]]) 111 | else: 112 | self.Mv_img_to_features = nn.ModuleList(Mv_img_to_features) 113 | self.Mv_features_to_hidden = nn.ModuleList(Mv_features_to_hidden) 114 | # Encode parameters of latent distribution 115 | means = [] 116 | log_vars = [] 117 | if self.is_continuous: 118 | for i in range(self.view_num): 119 | means.append(nn.Linear(self.hidden_dim, self.latent_cont_dim)) 120 | log_vars.append(nn.Linear(self.hidden_dim, self.latent_cont_dim)) 121 | self.means = nn.ModuleList(means) 122 | self.log_vars = nn.ModuleList(log_vars) 123 | if self.is_discrete: 124 | # Linear layer for each of the categorical distributions 125 | fc_alphas = [] 126 | for disc_dim in self.latent_spec['disc']: 127 | fc_alphas.append(nn.Linear(self.hidden_dim * self.view_num, disc_dim)) 128 | self.fc_alphas = nn.ModuleList(fc_alphas) 129 | 130 | for i in range(self.view_num): 131 | # Map latent samples to features to be used by generative model 132 | latent_to_features = nn.Sequential( 133 | nn.Linear(self.latent_dim, self.hidden_dim), 134 | nn.ReLU(), 135 | nn.Linear(self.hidden_dim, 64 * 4 * 4), 136 | nn.ReLU() 137 | ) 138 | Mv_latent_to_features.append(latent_to_features) 139 | # Define decoder 140 | decoder_layers = [] 141 | # Additional decoding layer for (64, 64) images 142 | if self.img_size[1:] == (64, 64): 143 | decoder_layers += [ 144 | nn.ConvTranspose2d(64, 64, (4, 4), stride=2, padding=1), 145 | nn.ReLU() 146 | ] 147 | 148 | decoder_layers += [ 149 | nn.ConvTranspose2d(64, 32, (4, 4), stride=2, padding=1), 150 | nn.ReLU(), 151 | nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=1), 152 | nn.ReLU(), 153 | nn.ConvTranspose2d(32, self.img_size[0], (4, 4), stride=2, padding=1), 154 | nn.Sigmoid() 155 | # nn.Tanh() 156 | ] 157 | # Define decoder 158 | Mv_features_to_img.append(nn.Sequential(*decoder_layers)) 159 | 160 | if self.share: 161 | self.Mv_latent_to_features = nn.ModuleList([Mv_latent_to_features[0]]) 162 | self.Mv_features_to_img = nn.ModuleList([Mv_features_to_img[0]]) 163 | else: 164 | self.Mv_latent_to_features = nn.ModuleList(Mv_latent_to_features) 165 | self.Mv_features_to_img = nn.ModuleList(Mv_features_to_img) 166 | 167 | def encode(self, X): 168 | """ 169 | Encodes an image into parameters of a latent distribution defined in 170 | self.latent_spec. 
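        X is a list with one batch tensor per view. The returned dict holds
        one 'cont<i>' entry per view with the [mean, logvar] pair of that
        view's Gaussian (the view-peculiar variable), plus a 'disc' list with
        the softmax alphas of the shared categorical (view-common) variable.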
171 | 172 | Parameters 173 | ---------- 174 | x : torch.Tensor 175 | Batch of data, shape (N, C, H, W) 176 | """ 177 | batch_size = X[0].size()[0] 178 | features = [] 179 | hiddens = [] 180 | for i in range(self.view_num): 181 | # Encode image to hidden features 182 | if self.share: 183 | net_num = 0 184 | else: 185 | net_num = i 186 | # print(X[i].shape) 187 | features.append(self.Mv_img_to_features[net_num](X[i])) 188 | hiddens.append(self.Mv_features_to_hidden[net_num](features[i].view(batch_size, -1))) 189 | 190 | fusion = torch.cat(hiddens, dim=1) 191 | # print(hidden.shape) 192 | # print(fusion.shape) 193 | # Output parameters of latent distribution from hidden representation 194 | latent_dist = {} 195 | if self.is_continuous: 196 | for i in range(self.view_num): 197 | continue_name = 'cont' + str(i+1) 198 | latent_dist[continue_name] = [self.means[i](hiddens[i]), self.log_vars[i](hiddens[i])] 199 | if self.is_discrete: 200 | latent_dist['disc'] = [] 201 | for fc_alpha in self.fc_alphas: 202 | latent_dist['disc'].append(F.softmax(fc_alpha(fusion), dim=1)) 203 | 204 | return latent_dist 205 | 206 | def reparameterize(self, latent_dist): 207 | """ 208 | Samples from latent distribution using the reparameterization trick. 209 | 210 | Parameters 211 | ---------- 212 | latent_dist : dict 213 | Dict with keys 'cont' or 'disc' or both, containing the parameters 214 | of the latent distributions as torch.Tensor instances. 215 | """ 216 | latent_sample = [] 217 | if self.is_continuous: 218 | for i in range(self.view_num): 219 | countinus_name = 'cont' + str(i+1) 220 | mean, logvar = latent_dist[countinus_name] 221 | cont_sample = self.sample_normal(mean, logvar) 222 | latent_sample.append(cont_sample) 223 | 224 | if self.is_discrete: 225 | for alpha in latent_dist['disc']: 226 | disc_sample = self.sample_gumbel_softmax(alpha) 227 | latent_sample.append(disc_sample) 228 | 229 | # Concatenate continuous and discrete samples into one large sample 230 | return torch.cat(latent_sample, dim=1) 231 | 232 | def sample_normal(self, mean, logvar): 233 | """ 234 | Samples from a normal distribution using the reparameterization trick. 235 | 236 | Parameters 237 | ---------- 238 | mean : torch.Tensor 239 | Mean of the normal distribution. Shape (N, D) where D is dimension 240 | of distribution. 241 | 242 | logvar : torch.Tensor 243 | Diagonal log variance of the normal distribution. Shape (N, D) 244 | """ 245 | if self.training: 246 | std = torch.exp(0.5 * logvar) 247 | eps = torch.zeros(std.size()).normal_() 248 | if self.use_cuda: 249 | eps = eps.cuda() 250 | return mean + std * eps 251 | else: 252 | # Reconstruction mode 253 | return mean 254 | 255 | def sample_gumbel_softmax(self, alpha): 256 | """ 257 | Samples from a gumbel-softmax distribution using the reparameterization 258 | trick. 259 | 260 | Parameters 261 | ---------- 262 | alpha : torch.Tensor 263 | Parameters of the gumbel-softmax distribution. 
Shape (N, D)
264 |         """
265 |         if self.training:
266 |             # Sample from gumbel distribution
267 |             unif = torch.rand(alpha.size())
268 |             if self.use_cuda:
269 |                 unif = unif.cuda()
270 |             gumbel = -torch.log(-torch.log(unif + EPS) + EPS)
271 |             # Reparameterize to create gumbel softmax sample
272 |             log_alpha = torch.log(alpha + EPS)
273 |             logit = (log_alpha + gumbel) / self.temperature
274 |             return F.softmax(logit, dim=1)
275 |         else:
276 |             # In reconstruction mode, pick most likely sample
277 |             _, max_alpha = torch.max(alpha, dim=1)
278 |             one_hot_samples = torch.zeros(alpha.size())
279 |             # On axis 1 of one_hot_samples, scatter the value 1 at indices
280 |             # max_alpha. Note the view is because scatter_ only accepts 2D
281 |             # tensors.
282 |             one_hot_samples.scatter_(1, max_alpha.view(-1, 1).data.cpu(), 1)
283 |             if self.use_cuda:
284 |                 one_hot_samples = one_hot_samples.cuda()
285 |             return one_hot_samples
286 | 
287 |     def decode(self, latent_samples):
288 |         """
289 |         Decodes sample from latent distribution into an image.
290 | 
291 |         Parameters
292 |         ----------
293 |         latent_samples : list of torch.Tensor
294 |             One latent sample per view, as assembled in forward().
295 |         """
296 |         features_to_img = []
297 |         for i in range(self.view_num):
298 |             if self.share:
299 |                 net_num = 0
300 |             else:
301 |                 net_num = i
302 |             feature = self.Mv_latent_to_features[net_num](latent_samples[i])
303 |             if self.Net == 'C':
304 |                 features_to_img.append(self.Mv_features_to_img[net_num](feature.view(-1, *self.reshape)))
305 |         return features_to_img[:]
306 | 
307 |     def forward(self, X):
308 |         """
309 |         Forward pass of model.
310 | 
311 |         Parameters
312 |         ----------
313 |         X : list of torch.Tensor
314 |             One batch of data per view, each of shape (N, C, H, W)
315 |         """
316 |         latent_dist = self.encode(X)
317 |         # print(latent_dist)
318 |         latent_sample = self.reparameterize(latent_dist)
319 |         # print(latent_sample.shape)
320 |         split_list = []
321 |         for i in range(self.view_num):
322 |             split_list.append(self.latent_spec['cont'])
323 |         split_list.append(self.latent_spec['disc'][0])
324 |         cont_des_list = latent_sample.split(split_list, dim=1)
325 |         decode_list = []
326 |         for i in range(self.view_num):
327 |             decode_list.append(torch.cat([cont_des_list[i], cont_des_list[-1]], dim=1))
328 |         out_list = self.decode(decode_list)
329 |         return out_list, latent_dist
--------------------------------------------------------------------------------
/multi_vae/MvTraining.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import numpy as np
3 | import torch
4 | from torch.nn import functional as F
5 | 
6 | EPS = 1e-12
7 | 
8 | class Trainer():
9 |     def __init__(self, model, optimizer, cont_capacity=None,
10 |                  disc_capacity=None, print_loss_every=50, record_loss_every=100,
11 |                  use_cuda=False, view_num=2, DATA='DATA'):
12 |         """
13 |         **Acknowledgments**
14 |         This code is inspired by https://github.com/Schlumberger/joint-vae
15 | 
16 |         Class to handle training of model.
17 | 
18 |         Parameters
19 |         ----------
20 |         model : multi-vae.models.VAE instance
21 | 
22 |         optimizer : torch.optim.Optimizer instance
23 | 
24 |         cont_capacity : list (float, float, int) or None
25 |             List containing (max_capacity, gamma_z, num_iters), as built in main.py.
26 |             Parameters to control the capacity of the continuous latent
27 |             channels. Cannot be None if model.is_continuous is True.
28 | 
29 |         disc_capacity : list (float, float, int) or None
30 |             List containing (max_capacity, gamma_c, num_iters).
31 | Parameters to control the capacity of the discrete latent channels. 32 | Cannot be None if model.is_discrete is True. 33 | 34 | print_loss_every : int 35 | Frequency with which loss is printed during training. 36 | 37 | record_loss_every : int 38 | Frequency with which loss is recorded during training. 39 | 40 | use_cuda : bool 41 | If True moves model and training to GPU. 42 | """ 43 | self.model = model 44 | self.optimizer = optimizer 45 | self.cont_capacity = cont_capacity 46 | self.disc_capacity = disc_capacity 47 | self.print_loss_every = print_loss_every 48 | self.record_loss_every = record_loss_every 49 | self.use_cuda = use_cuda 50 | self.view_num = view_num 51 | self.DATA = DATA 52 | if self.model.is_continuous and self.cont_capacity is None: 53 | raise RuntimeError("Model is continuous but cont_capacity not provided.") 54 | 55 | if self.model.is_discrete and self.disc_capacity is None: 56 | raise RuntimeError("Model is discrete but disc_capacity not provided.") 57 | 58 | if self.use_cuda: 59 | self.model.cuda() 60 | 61 | # Initialize attributes 62 | self.num_steps = 0 63 | self.beta = [] 64 | for i in range(self.view_num): 65 | self.beta.append(1) 66 | self.batch_size = None 67 | self.losses = {'loss': [], 68 | 'recon_loss': [], 69 | 'kl_loss': []} 70 | 71 | self.loss_r = [[], [], []] # recon_loss 72 | self.loss_z = [[], [], []] # kl_loss z 73 | self.loss_c = [[], [], []] # kl_loss c 74 | 75 | self.mean_loss = [[], [], []] 76 | 77 | # Keep track of divergence values for each latent variable 78 | if self.model.is_continuous: 79 | self.losses['kl_loss_cont'] = [] 80 | # For every dimension of continuous latent variables 81 | for i in range(self.model.latent_spec['cont']): 82 | self.losses['kl_loss_cont_' + str(i)] = [] 83 | 84 | if self.model.is_discrete: 85 | self.losses['kl_loss_disc'] = [] 86 | # For every discrete latent variable 87 | for i in range(len(self.model.latent_spec['disc'])): 88 | self.losses['kl_loss_disc_' + str(i)] = [] 89 | 90 | def train(self, data_loader, epochs=10, save_training_gif=None): 91 | """ 92 | Trains the model. 93 | 94 | Parameters 95 | ---------- 96 | data_loader : torch.utils.data.DataLoader 97 | 98 | epochs : int 99 | Number of epochs to train the model for. 100 | 101 | save_training_gif : None or tuple (string, Visualizer instance) 102 | If not None, will use visualizer object to create image of samples 103 | after every epoch and will save gif of these at location specified 104 | by string. Note that string should end with '.gif'. 105 | """ 106 | if save_training_gif is not None: 107 | training_progress_images = [] 108 | 109 | self.batch_size = data_loader.batch_size 110 | self.model.train() 111 | c = [] 112 | z = [[], [], []] 113 | for epoch in range(epochs): 114 | # print("Epoch:" + str(epoch)) 115 | mean_epoch_loss = self._train_epoch(data_loader) 116 | for i in range(self.view_num): 117 | print('Epoch: {}'.format(epoch + 1) + '. 
Average loss view-' + str(i+1) + ': {:.2f}'.format( self.batch_size * self.model.num_pixels[i] * mean_epoch_loss[i])) 118 | self.mean_loss[i].append(self.batch_size * self.model.num_pixels[i] * mean_epoch_loss[i]) 119 | show = 0 120 | if (epoch % 1 == 0) and show == 1: 121 | batch_size_test = 100000 # max size 122 | from utils.MvDataloaders import Get_dataloaders 123 | if self.model.Net == 'C': 124 | Train_loader, View_num, N_clusters, _ = Get_dataloaders(batch_size=batch_size_test, DATANAME=self.DATA + '.mat') 125 | for Batch_idx, Data in enumerate(Train_loader): 126 | break 127 | data = Data[0:-1] 128 | labels = Data[-1] 129 | inputs = [] 130 | from torch.autograd import Variable 131 | for i in range(View_num): 132 | inputs.append(Variable(data[i])) 133 | inputs[i] = inputs[i].cuda() 134 | encodings = self.model.encode(inputs) 135 | from Nmetrics import test 136 | from sklearn.cluster import KMeans 137 | kmeans = KMeans(n_clusters=N_clusters, n_init=100) 138 | x = encodings['disc'][0].cpu().detach().data.numpy() 139 | y = labels.cpu().detach().data.numpy() 140 | for i in range(View_num): 141 | name = 'cont' + str(i+1) 142 | x_c = encodings[name][0].cpu().detach().data.numpy() # cz 143 | cz = kmeans.fit_predict(x_c) 144 | z[i].append(test(y, cz)) 145 | p = kmeans.fit_predict(x) 146 | c.append(test(y, p)) 147 | # yp = x.argmax(1) 148 | # acc_cmax = test(yp, y) 149 | # acc_abs = test(p, yp) 150 | # if acc_abs == 1: 151 | # break 152 | 153 | if save_training_gif is not None: 154 | # Generate batch of images and convert to grid 155 | viz = save_training_gif[1] 156 | viz.save_images = False 157 | img_grid = viz.all_latent_traversals(size=10) 158 | # Convert to numpy and transpose axes to fit imageio convention 159 | # i.e. (width, height, channels) 160 | img_grid = np.transpose(img_grid.numpy(), (1, 2, 0)) 161 | # Add image grid to training progress 162 | training_progress_images.append(img_grid) 163 | # np.save('./MultiC.npy', c) 164 | # np.save('./MultiZ.npy', z) 165 | # np.save('./LossR.npy', self.loss_r) 166 | # np.save('./LossZ.npy', self.loss_z) 167 | # np.save('./LossC.npy', self.loss_c) 168 | np.save('./mean_loss.npy', self.mean_loss) 169 | 170 | if save_training_gif is not None: 171 | imageio.mimsave(save_training_gif[0], training_progress_images, 172 | fps=24) 173 | 174 | def _train_epoch(self, data_loader): 175 | """ 176 | Trains the model for one epoch. 177 | 178 | Parameters 179 | ---------- 180 | data_loader : torch.utils.data.DataLoader 181 | """ 182 | epoch_loss = [] 183 | print_every_loss = [] 184 | for i in range(self.view_num): 185 | epoch_loss.append(0.) 186 | print_every_loss.append(0.) 187 | 188 | for batch_idx, Data in enumerate(data_loader): 189 | iter_loss = self._train_iteration(Data[0:-1]) 190 | # print(iter_loss) 191 | for i in range(self.view_num): 192 | epoch_loss[i] += iter_loss[i] 193 | print_every_loss[i] += iter_loss[i] 194 | # Print loss info every self.print_loss_every iteration 195 | if batch_idx % self.print_loss_every == 0: 196 | if batch_idx == 0: 197 | mean_loss = print_every_loss 198 | else: 199 | for i in range(self.view_num): 200 | mean_loss[i] = print_every_loss[i] / self.print_loss_every 201 | for i in range(self.view_num): 202 | # print('Loss' + str(i+1) + '\t{}/{}\t: {:.3f}'.format(batch_idx * len(Data[i]), len(data_loader.dataset), 203 | # self.model.num_pixels * mean_loss[i])) 204 | print_every_loss[i] = 0. 
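        # Note: _loss_function normalises each view's loss by its pixel count,
        # so train() scales the printed epoch averages back up by
        # batch_size * num_pixels.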
205 | # Return mean epoch loss 206 | Epoch_loss = [] 207 | for i in range(self.view_num): 208 | Epoch_loss.append(epoch_loss[i] / len(data_loader.dataset)) 209 | return Epoch_loss 210 | 211 | def _train_iteration(self, data): 212 | """ 213 | Trains the model for one iteration on a batch of data. 214 | 215 | Parameters 216 | ---------- 217 | data : torch.Tensor 218 | A batch of data. Shape (N, C, H, W) 219 | """ 220 | self.num_steps += 1 221 | for i in range(self.view_num): 222 | if self.use_cuda: 223 | data[i] = data[i].cuda() 224 | 225 | self.optimizer.zero_grad() 226 | listout, latent_dist = self.model(data) 227 | recon_batchs = [] 228 | for i in range(self.view_num): 229 | recon_batchs.append(listout[i]) 230 | 231 | Loss = [] 232 | 233 | max_recon_loss = F.binary_cross_entropy(recon_batchs[0].view(-1, self.model.num_pixels[0]), 234 | data[0].view(-1, self.model.num_pixels[0])) 235 | max_recon_loss *= self.model.num_pixels[0] 236 | # print(float(max_recon_loss)) 237 | recloss = [max_recon_loss] 238 | for i in range(self.view_num - 1): 239 | loss_view = F.binary_cross_entropy(recon_batchs[i+1].view(-1, self.model.num_pixels[i+1]), 240 | data[i+1].view(-1, self.model.num_pixels[i+1])) 241 | loss_view *= self.model.num_pixels[i+1] 242 | # print(float(loss_view)) 243 | recloss.append(loss_view) 244 | if max_recon_loss < loss_view: 245 | max_recon_loss = loss_view 246 | # print(self.num_steps) 247 | if self.num_steps == 1: 248 | for i in range(self.view_num): 249 | self.beta[i] = float(recloss[i])/float(max_recon_loss) 250 | # print(self.beta) 251 | 252 | # print(max_recon_loss) 253 | # print(latent_dist) 254 | for i in range(self.view_num): 255 | countinue_name = 'cont' + str(i+1) 256 | Loss.append(self._loss_function(data[i], recon_batchs[i], {'cont': latent_dist[countinue_name], 'disc': latent_dist['disc']}, 257 | view=i, max_recon_loss=max_recon_loss, beta=self.beta[i])) 258 | for i in range(self.view_num - 1): 259 | Loss[i].backward(retain_graph=True) 260 | Loss[-1].backward() 261 | self.optimizer.step() 262 | 263 | train_loss = [] 264 | for i in range(self.view_num): 265 | train_loss.append(Loss[i].item()) 266 | # print(train_loss) 267 | return train_loss 268 | 269 | def _loss_function(self, data, recon_data, latent_dist, view=0, max_recon_loss=1000, beta=1): 270 | """ 271 | Calculates loss for a batch of data. 272 | 273 | Parameters 274 | ---------- 275 | data : torch.Tensor 276 | Input data (e.g. batch of images). Should have shape (N, C, H, W) 277 | 278 | recon_data : torch.Tensor 279 | Reconstructed data. Should have shape (N, C, H, W) 280 | 281 | latent_dist : dict 282 | Dict with keys 'cont' or 'disc' or both containing the parameters 283 | of the latent distributions as values. 
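        Notes
        -----
        A sketch of the objective assembled below, following the joint-VAE
        capacity formulation:

            total = recon + gamma_z * |C_z(t) - KL_cont| + gamma_c * |C_c(t) - KL_disc|

        where both capacities grow linearly with the training step t at the
        rate cont_max / num_iters and are clipped at their respective
        maximum (cont_max, disc_max).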
284 | """ 285 | # Reconstruction loss is pixel wise cross-entropy 286 | if self.model.Net == 'C': 287 | recon_loss = F.binary_cross_entropy(recon_data.view(-1, self.model.num_pixels[view]), 288 | data.view(-1, self.model.num_pixels[view])) 289 | 290 | # F.binary_cross_entropy takes mean over pixels, so unnormalise this 291 | recon_loss *= self.model.num_pixels[view] 292 | # print(recon_loss, self.model.num_pixels[view]) 293 | # Calculate KL divergences 294 | kl_cont_loss = 0 # Used to compute capacity loss (but not a loss in itself) 295 | kl_disc_loss = 0 # Used to compute capacity loss (but not a loss in itself) 296 | cont_capacity_loss = 0 297 | disc_capacity_loss = 0 298 | 299 | cont_max, cont_gamma, iters_add_cont_max = \ 300 | self.cont_capacity 301 | disc_max, disc_gamma, iters_add_disc_max = \ 302 | self.disc_capacity 303 | 304 | step = cont_max / iters_add_cont_max 305 | 306 | 307 | if self.model.is_continuous: 308 | # Calculate KL divergence 309 | mean, logvar = latent_dist['cont'] 310 | kl_cont_loss = self._kl_normal_loss(mean, logvar) 311 | # Linearly increase capacity of continuous channels 312 | # Increase continuous capacity without exceeding cont_max 313 | cont_cap_current = step * self.num_steps 314 | cont_cap_current = min(cont_cap_current, cont_max) 315 | # Calculate continuous capacity loss 316 | if self.DATA in ['Object-Digit-Product']: 317 | # cont_gamma = cont_gamma * recon_loss/max_recon_loss 318 | cont_gamma = cont_gamma * beta 319 | cont_capacity_loss = cont_gamma * torch.abs(cont_cap_current - kl_cont_loss) 320 | if self.model.is_discrete: 321 | # Calculate KL divergence 322 | kl_disc_loss = self._kl_multiple_discrete_loss(latent_dist['disc']) 323 | # Linearly increase capacity of discrete channels 324 | # Increase discrete capacity without exceeding disc_max or theoretical 325 | disc_cap_current = step * self.num_steps 326 | disc_cap_current = min(disc_cap_current, disc_max) 327 | # Calculate discrete capacity loss 328 | if self.DATA in ['Object-Digit-Product']: 329 | # disc_gamma = disc_gamma * recon_loss/max_recon_loss 330 | disc_gamma = disc_gamma * beta 331 | disc_capacity_loss = disc_gamma * torch.abs(disc_cap_current - kl_disc_loss) 332 | # Calculate total kl value to record it 333 | kl_loss = kl_cont_loss + kl_disc_loss 334 | # Calculate total loss 335 | total_loss = recon_loss + cont_capacity_loss + disc_capacity_loss 336 | # print(self.num_steps, recon_loss, cont_capacity_loss+disc_capacity_loss, kl_loss) 337 | 338 | # Record losses 339 | # if self.model.training and self.num_steps % self.record_loss_every == 0: 340 | # self.losses['recon_loss'].append(recon_loss.item()) 341 | # self.losses['kl_loss'].append(kl_loss.item()) 342 | # self.losses['loss'].append(total_loss.item()) 343 | # print(recon_loss.data, kl_cont_loss.data, kl_disc_loss.data) 344 | 345 | # self.loss_r[view].append(recon_loss.item()) 346 | # self.loss_z[view].append(kl_cont_loss.item()) 347 | # self.loss_c[view].append(kl_disc_loss.item()) 348 | 349 | # To avoid large losses normalise by number of pixels 350 | return total_loss / self.model.num_pixels[view] 351 | 352 | def _kl_normal_loss(self, mean, logvar): 353 | """ 354 | Calculates the KL divergence between a normal distribution with 355 | diagonal covariance and a unit normal distribution. 356 | 357 | Parameters 358 | ---------- 359 | mean : torch.Tensor 360 | Mean of the normal distribution. Shape (N, D) where D is dimension 361 | of distribution. 
362 | 
363 |         logvar : torch.Tensor
364 |             Diagonal log variance of the normal distribution. Shape (N, D)
365 |         """
366 |         # Calculate KL divergence
367 |         kl_values = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
368 |         # Mean KL divergence across batch for each latent variable
369 |         kl_means = torch.mean(kl_values, dim=0)
370 |         # KL loss is sum of mean KL of each latent variable
371 |         kl_loss = torch.sum(kl_means)
372 | 
373 |         # Record losses
374 |         if self.model.training and self.num_steps % self.record_loss_every == 1:
375 |             self.losses['kl_loss_cont'].append(kl_loss.item())
376 |             for i in range(self.model.latent_spec['cont']):
377 |                 self.losses['kl_loss_cont_' + str(i)].append(kl_means[i].item())
378 | 
379 |         return kl_loss
380 | 
381 |     def _kl_multiple_discrete_loss(self, alphas):
382 |         """
383 |         Calculates the KL divergence between a set of categorical distributions
384 |         and a set of uniform categorical distributions.
385 | 
386 |         Parameters
387 |         ----------
388 |         alphas : list
389 |             List of the alpha parameters of a categorical (or gumbel-softmax)
390 |             distribution. For example, if the categorical latent distribution of
391 |             the model has dimensions [2, 5, 10] then alphas will contain 3
392 |             torch.Tensor instances with the parameters for each of
393 |             the distributions. Each of these will have shape (N, D).
394 |         """
395 |         # Calculate kl losses for each discrete latent
396 |         kl_losses = [self._kl_discrete_loss(alpha) for alpha in alphas]
397 | 
398 |         # Total loss is sum of kl loss for each discrete latent
399 |         kl_loss = torch.sum(torch.cat(kl_losses))
400 | 
401 |         # Record losses
402 |         if self.model.training and self.num_steps % self.record_loss_every == 1:
403 |             self.losses['kl_loss_disc'].append(kl_loss.item())
404 |             for i in range(len(alphas)):
405 |                 self.losses['kl_loss_disc_' + str(i)].append(kl_losses[i].item())
406 | 
407 |         return kl_loss
408 | 
409 |     def _kl_discrete_loss(self, alpha):
410 |         """
411 |         Calculates the KL divergence between a categorical distribution and a
412 |         uniform categorical distribution.
413 | 
414 |         Parameters
415 |         ----------
416 |         alpha : torch.Tensor
417 |             Parameters of the categorical or gumbel-softmax distribution.
418 |             Shape (N, D)
419 |         """
420 |         disc_dim = int(alpha.size()[-1])
421 |         log_dim = torch.Tensor([np.log(disc_dim)])
422 |         if self.use_cuda:
423 |             log_dim = log_dim.cuda()
424 |         # Calculate negative entropy of each row
425 |         neg_entropy = torch.sum(alpha * torch.log(alpha + EPS), dim=1)
426 |         # Take mean of negative entropy across batch
427 |         mean_neg_entropy = torch.mean(neg_entropy, dim=0)
428 |         # KL loss of alpha with uniform categorical variable
429 |         kl_loss = log_dim + mean_neg_entropy
430 |         return kl_loss
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==0.3.1
2 | torchvision==0.2.0
3 | scipy==0.19.1
4 | imageio==2.1.2
5 | scikit-image==0.13.1
--------------------------------------------------------------------------------
/utils/DATA/Multi-COIL-10.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SubmissionsIn/Multi-VAE/f1b9d7aac8d896267d349b577418eb6a31e31145/utils/DATA/Multi-COIL-10.mat
--------------------------------------------------------------------------------
/utils/DATA/README.txt:
--------------------------------------------------------------------------------
1 | Data & Model Link: https://pan.baidu.com/s/1CHQ5EDgNPkW37uanDztGjQ
2 | Pass code: todz
3 | 
4 | 
5 | Disentangling multi-view information and representations imposes a strict condition: it must be clear which view-common information and which view-peculiar/private information is to be disentangled. To verify the disentanglement effectiveness of the visual representations in Multi-VAE, we constructed two kinds of datasets:
6 | 
7 | 1. Multiple views share view-common visual information but differ in view-peculiar visual information. Such samples are constructed from the same dataset: the view-common information among the views is the same sample object, and the view-peculiar information is their different visual patterns. For convenience, these datasets are named Multi-MNIST, Multi-Fashion, Multi-COIL-10, and Multi-COIL-20.
8 | 
9 | 2. Multiple views share view-common semantic information but differ in view-peculiar visual information. Such samples are constructed from different datasets: the view-common information among the views is the same semantic category/label, and the view-peculiar information is the different visual patterns coming from unrelated datasets. For convenience, these datasets are named Digit-Product and Object-Digit-Product.
10 | 
11 | It should be noted that:
12 | 
13 | 1. Traditional multi-view datasets may not meet the requirement that multiple views contain view-common and view-peculiar information that is obvious and can be disentangled, so they may not be suitable for Multi-VAE, which focuses on disentanglement.
14 | 
15 | 2. This link provides the datasets and models. If you use the datasets we constructed, please cite the original sources of these datasets, as described in many previous publications.
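
For orientation, the loader that follows (utils/MvDataloaders.py) expects each
.mat file to hold view arrays X1, X2 (and optionally X3) plus a label row
vector Y. A minimal sketch of writing a compatible file (shapes and file name
are illustrative, not part of the released data):

    import numpy as np
    import scipy.io as scio

    N = 100  # illustrative sample count
    scio.savemat('./utils/DATA/My-Data.mat', {
        'X1': np.random.rand(N, 1, 32, 32).astype(np.float32),  # view 1, (N, C, H, W)
        'X2': np.random.rand(N, 1, 32, 32).astype(np.float32),  # view 2
        'Y': (np.arange(N) % 10).reshape(1, N),                 # labels as a (1, N) row
    })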
16 | 
--------------------------------------------------------------------------------
/utils/MvDataloaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, TensorDataset
3 | import torch
4 | import scipy.io as scio
5 | 
6 | 
7 | def Get_dataloaders(batch_size=128, path_to_data='./utils/DATA/', DATANAME='MNIST-Sobel.mat'):
8 |     """Builds a dataloader over the views of a multi-view .mat file with (32, 32) images."""
9 |     DATA = scio.loadmat(path_to_data + DATANAME)
10 |     view = len(DATA) - 3 - 1  # drop loadmat's three meta keys and 'Y' to count the X views
11 |     X1 = DATA['X1']
12 |     X2 = DATA['X2']
13 |     print(X1.shape)
14 |     print(X2.shape)
15 |     if view == 3:
16 |         X3 = DATA['X3']
17 |         print(X3.shape)
18 |     y = DATA['Y']
19 |     size = y.shape[1]
20 |     print(y.shape[1])
21 |     cluster = np.unique(y)
22 |     # print(cluster)
23 |     print('Cluster K:' + str(len(cluster)))
24 |     x1 = torch.from_numpy(X1).float()
25 |     X1 = []
26 |     x2 = torch.from_numpy(X2).float()
27 |     X2 = []
28 |     if view == 3:
29 |         x3 = torch.from_numpy(X3).float()
30 |         X3 = []
31 |     y = torch.from_numpy(y[0])
32 |     if view == 2:
33 |         X = TensorDataset(x1, x2, y)
34 |     if view == 3:
35 |         X = TensorDataset(x1, x2, x3, y)
36 |     train_loader = DataLoader(X, batch_size=batch_size, shuffle=True)
37 |     return train_loader, view, len(cluster), size
--------------------------------------------------------------------------------
/utils/MvLoad_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from multi_vae.MvModels import VAE
3 | 
4 | 
5 | def load(view_num=2,
6 |          Network='C',
7 |          hid=256,
8 |          img_size=(1, 32, 32),
9 |          path='./example-model.pt',
10 |          latent_spec={"disc": [10], "cont": 10},
11 |          shareAE=1):
12 |     """
13 |     Loads a trained model.
14 | 
15 |     Parameters
16 |     ----------
17 |     path : string
18 |         Path to the saved model file itself, for example
19 |         './models/Multi-COIL-10.pt' (not a folder).
20 |     """
21 |     path_to_model = path
22 |     # Get model
23 |     model = VAE(latent_spec=latent_spec, img_size=img_size, view_num=view_num,
24 |                 Network=Network, hidden_dim=hid, shareAE=shareAE)
25 |     model.load_state_dict(torch.load(path_to_model,
26 |                                      map_location=lambda storage, loc: storage))
27 | 
28 |     return model
29 | 
--------------------------------------------------------------------------------
/viz/MvVisualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from viz.latent_traversals import LatentTraverser
4 | from torch.autograd import Variable
5 | from torchvision.utils import make_grid, save_image
6 | 
7 | 
8 | class Visualizer():
9 |     def __init__(self, model, view_num=2, show_view=1):
10 |         """
11 |         Visualizer is used to generate images of samples, reconstructions,
12 |         latent traversals and so on of the trained model.
13 | 
14 |         Parameters
15 |         ----------
16 |         model : multi-vae.models.VAE instance
17 |         """
18 |         self.view_num = view_num
19 |         self.show_view = show_view - 1
20 |         self.model = model
21 |         self.latent_traverser = LatentTraverser(self.model.latent_spec)
22 |         self.save_images = False  # If False, each method returns a tensor
23 |                                   # instead of saving an image.
24 | 
25 |     def reconstructions(self, data, size=(8, 8), filename='recon.png'):
26 |         """
27 |         Generates reconstructions of data through the model.
28 | 
29 |         Parameters
30 |         ----------
31 |         data : list of torch.Tensor
32 |             One batch per view to be reconstructed; each has shape (N, C, H, W)
33 | 
34 |         size : tuple of ints
35 |             Size of grid on which reconstructions will be plotted.
The number
36 |             of rows should be even, so that the upper half contains true data and
37 |             the bottom half contains reconstructions
38 |         """
39 |         # Plot reconstructions in test mode, i.e. without sampling from latent
40 |         self.model.eval()
41 |         # Pass data through VAE to obtain reconstruction
42 | 
43 |         # input_data = Variable(data, volatile=True)
44 |         input_datas = []
45 |         with torch.no_grad():
46 |             for i in range(self.view_num):
47 |                 input_data = Variable(data[i])
48 |                 if self.model.use_cuda:
49 |                     input_datas.append(input_data.cuda())
50 |                 else:
51 |                     input_datas.append(input_data)
52 |             view = self.show_view
53 |             recon_datas, prior_samples = self.model(input_datas)
54 |             print(prior_samples)
55 | 
56 |         # prior_sample = self.latent_traverser.traverse_grid(size=(5, 10),
57 |         #                                                    cont_idx=None,
58 |         #                                                    cont_axis=None,
59 |         #                                                    disc_idx=0,
60 |         #                                                    disc_axis=0)
61 |         # print(prior_sample.shape)
62 |         # print(prior_samples['disc'][0].shape)
63 |         # for i in range(50):
64 |         #     for j in range(10):
65 |         #         prior_sample[i, j+10] = prior_samples['disc'][0][i, j]
66 |         # generateds = self._decode_latents([prior_sample, prior_sample, prior_sample])
67 |         # if self.save_images:
68 |         #     save_image(generateds[view].data, 'cluster.png', nrow=size[1])
69 |         # else:
70 |         #     return make_grid(generateds[view].data, nrow=size[1])
71 | 
72 |         self.model.train()
73 |         # Upper half of plot will contain data, bottom half will contain
74 |         # reconstructions
75 |         num_images = size[0] * size[1] // 2
76 |         num_images = int(num_images)  # np.int was removed from NumPy; use the builtin
77 | 
78 |         originals = input_datas[view][:num_images].cpu()
79 |         reconstructions = recon_datas[view].view(-1, *self.model.img_size)[:num_images].cpu()
80 |         # If there are fewer examples given than spaces available in grid,
81 |         # augment with blank images
82 |         num_examples = originals.size()[0]
83 |         if num_images > num_examples:
84 |             blank_images = torch.zeros((num_images - num_examples,) + originals.size()[1:])
85 |             originals = torch.cat([originals, blank_images])
86 |             reconstructions = torch.cat([reconstructions, blank_images])
87 | 
88 |         # Concatenate images and reconstructions
89 |         comparison = torch.cat([originals, reconstructions])
90 | 
91 |         if self.save_images:
92 |             save_image(comparison.data, filename, nrow=size[0])
93 |         else:
94 |             return make_grid(comparison.data, nrow=size[0])
95 | 
96 |     def samples(self, size=(10, 10), filename='samples.png'):
97 |         """
98 |         Generates samples from learned distribution by sampling prior and
99 |         decoding.
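        Rows of the returned grid fix successive categories of the shared
        discrete code, while the untraversed continuous dimensions are drawn
        from the prior (sample_prior is temporarily enabled below).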
100 | 101 | size : tuple of ints 102 | """ 103 | # Get prior samples from latent distribution 104 | cached_sample_prior = self.latent_traverser.sample_prior 105 | self.latent_traverser.sample_prior = True 106 | prior_samples_1 = self.latent_traverser.traverse_grid(size=size, 107 | cont_idx=None, 108 | cont_axis=None, 109 | disc_idx=0, 110 | disc_axis=0) 111 | self.latent_traverser.sample_prior = cached_sample_prior 112 | # Map samples through decoder 113 | print("------------------------------") 114 | print(prior_samples_1.shape) 115 | prior_zeros = torch.zeros(prior_samples_1.shape) 116 | # prior_zeros = torch.ones(prior_samples_1.shape) 117 | # print(prior_zeros.shape) 118 | # print(prior_samples_1[:, 0:10].shape) 119 | # prior_samples_1 = torch.cat([prior_samples_1[:, 0:10], prior_zeros[:, 10:20]], dim=1) 120 | print(prior_samples_1) 121 | # print(prior_samples_1[0:10]) 122 | print("------------------------------") 123 | prior_samples = [] 124 | for i in range(self.view_num): 125 | prior_samples.append(prior_samples_1) 126 | generateds = self._decode_latents(prior_samples) 127 | 128 | if self.save_images: 129 | save_image(generateds[self.show_view].data, filename, nrow=size[1]) 130 | else: 131 | return make_grid(generateds[self.show_view].data, nrow=size[1]) 132 | 133 | def latent_traversal_line(self, cont_idx=None, disc_idx=None, size=8, 134 | filename='traversal_line.png'): 135 | """ 136 | Generates an image traversal through a latent dimension. 137 | 138 | Parameters 139 | ---------- 140 | See viz.latent_traversals.LatentTraverser.traverse_line for parameter 141 | documentation. 142 | """ 143 | # Generate latent traversal 144 | latent_samples = self.latent_traverser.traverse_line(cont_idx=cont_idx, 145 | disc_idx=disc_idx, 146 | size=size) 147 | 148 | # Map samples through decoder 149 | x = torch.cat([latent_samples, latent_samples, latent_samples], dim=0) 150 | for i in range(10): 151 | x[3 * i] = latent_samples[i] 152 | x[3 * i + 1] = x[3 * i] 153 | x[3 * i + 2] = x[3 * i] 154 | print(x) 155 | prior_samples = [] 156 | for i in range(self.view_num): 157 | prior_samples.append(x) 158 | # print(prior_samples) 159 | generateds = self._decode_latents(prior_samples) 160 | 161 | if self.save_images: 162 | save_image(generateds[self.show_view].data, filename, nrow=3) 163 | else: 164 | return make_grid(generateds[self.show_view].data, nrow=3) 165 | 166 | def latent_traversal_grid(self, cont_idx=None, cont_axis=None, 167 | disc_idx=None, disc_axis=None, size=(5, 5), 168 | filename='traversal_grid.png'): 169 | """ 170 | Generates a grid of image traversals through two latent dimensions. 171 | 172 | Parameters 173 | ---------- 174 | See viz.latent_traversals.LatentTraverser.traverse_grid for parameter 175 | documentation. 
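        For example, cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0 with
        size=(10, 10), as in Load_model_visual.py, sweeps continuous
        dimension 2 across columns while rows run over the 10 categories of
        the shared discrete code.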
176 | """ 177 | # Generate latent traversal 178 | latent_samples = self.latent_traverser.traverse_grid(cont_idx=cont_idx, 179 | cont_axis=cont_axis, 180 | disc_idx=disc_idx, 181 | disc_axis=disc_axis, 182 | size=size) 183 | 184 | # Map samples through decoder 185 | prior_samples = [] 186 | for i in range(self.view_num): 187 | prior_samples.append(latent_samples) 188 | generateds = self._decode_latents(prior_samples) 189 | 190 | if self.save_images: 191 | save_image(generateds[self.show_view].data, filename, nrow=size[1]) 192 | else: 193 | return make_grid(generateds[self.show_view].data, nrow=size[1]) 194 | 195 | def all_latent_traversals(self, size=10, filename='all_traversals.png'): 196 | """ 197 | Traverses all latent dimensions one by one and plots a grid of images 198 | where each row corresponds to a latent traversal of one latent 199 | dimension. 200 | 201 | Parameters 202 | ---------- 203 | size : int 204 | Number of samples for each latent traversal. 205 | """ 206 | latent_samples = [] 207 | 208 | # Perform line traversal of every continuous and discrete latent 209 | for cont_idx in range(self.model.latent_cont_dim): 210 | latent_samples.append(self.latent_traverser.traverse_line(cont_idx=cont_idx, 211 | disc_idx=None, 212 | size=size)) 213 | 214 | for disc_idx in range(self.model.num_disc_latents): 215 | latent_samples.append(self.latent_traverser.traverse_line(cont_idx=None, 216 | disc_idx=disc_idx, 217 | size=size)) 218 | 219 | # Decode samples 220 | samples = [] 221 | for i in range(self.view_num): 222 | samples.append(torch.cat(latent_samples, dim=0)) 223 | generateds = self._decode_latents(samples) 224 | 225 | if self.save_images: 226 | save_image(generateds[self.show_view].data, filename, nrow=size) 227 | else: 228 | return make_grid(generateds[self.show_view].data, nrow=size) 229 | 230 | def _decode_latents(self, latent_sample): 231 | """ 232 | Decodes latent samples into images. 233 | 234 | Parameters 235 | ---------- 236 | latent_samples : torch.autograd.Variable 237 | Samples from latent distribution. Shape (N, L) where L is dimension 238 | of latent distribution. 239 | """ 240 | latent_samples = [] 241 | for i in range(self.view_num): 242 | latent_samples.append(Variable(latent_sample[i])) 243 | if self.model.use_cuda: 244 | latent_samples[i] = latent_samples[i].cuda() 245 | 246 | return self.model.decode(latent_samples) 247 | 248 | 249 | def reorder_img(orig_img, reorder, by_row=True, img_size=(3, 32, 32), padding=2): 250 | """ 251 | Reorders rows or columns of an image grid. 252 | 253 | Parameters 254 | ---------- 255 | orig_img : torch.Tensor 256 | Original image. 
228 |     def _decode_latents(self, latent_sample):
229 |         """
230 |         Decodes latent samples into images.
231 | 
232 |         Parameters
233 |         ----------
234 |         latent_sample : list of torch.Tensor
235 |             One (N, L) sample batch per view, where L is the dimension of
236 |             the latent distribution.
237 |         """
238 |         latent_samples = []
239 |         for i in range(self.view_num):
240 |             latent_samples.append(Variable(latent_sample[i]))
241 |             if self.model.use_cuda:
242 |                 latent_samples[i] = latent_samples[i].cuda()
243 | 
244 |         return self.model.decode(latent_samples)
245 | 
246 | 
247 | def reorder_img(orig_img, reorder, by_row=True, img_size=(3, 32, 32), padding=2):
248 |     """
249 |     Reorders rows or columns of an image grid.
250 | 
251 |     Parameters
252 |     ----------
253 |     orig_img : torch.Tensor
254 |         Original image. Shape (channels, height, width)
255 | 
256 |     reorder : list of ints
257 |         List corresponding to the desired permutation of rows or columns
258 | 
259 |     by_row : bool
260 |         If True reorders rows, otherwise reorders columns
261 | 
262 |     img_size : tuple of ints
263 |         Size of a single image, following the pytorch (channels, height, width) convention
264 | 
265 |     padding : int
266 |         Number of pixels used as padding in torchvision.utils.make_grid
267 |     """
268 |     reordered_img = torch.zeros(orig_img.size())
269 |     _, height, width = img_size
270 | 
271 |     for new_idx, old_idx in enumerate(reorder):
272 |         if by_row:
273 |             start_pix_new = new_idx * (padding + height) + padding
274 |             start_pix_old = old_idx * (padding + height) + padding
275 |             reordered_img[:, start_pix_new:start_pix_new + height, :] = orig_img[:, start_pix_old:start_pix_old + height, :]
276 |         else:
277 |             start_pix_new = new_idx * (padding + width) + padding
278 |             start_pix_old = old_idx * (padding + width) + padding
279 |             reordered_img[:, :, start_pix_new:start_pix_new + width] = orig_img[:, :, start_pix_old:start_pix_old + width]
280 | 
281 |     return reordered_img
--------------------------------------------------------------------------------
/viz/latent_traversals.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy import stats
4 | 
5 | 
6 | class LatentTraverser:
7 |     def __init__(self, latent_spec):
8 |         """
9 |         LatentTraverser is used to generate traversals of the latent space.
10 | 
11 |         Parameters
12 |         ----------
13 |         latent_spec : dict
14 |             See multi_vae.MvModels for parameter definition.
15 |         """
16 |         self.latent_spec = latent_spec
17 |         self.sample_prior = False  # If False, fixes samples in untraversed
18 |                                    # latent dimensions. If True, samples
19 |                                    # untraversed latent dimensions from the prior.
20 |         self.is_continuous = 'cont' in latent_spec
21 |         self.is_discrete = 'disc' in latent_spec
22 |         self.cont_dim = latent_spec['cont'] if self.is_continuous else None
23 |         self.disc_dims = latent_spec['disc'] if self.is_discrete else None
24 |         self.begin = 0.001  # CDF bounds for the gaussian traversal sweep
25 |         self.end = 0.999
26 | 
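27 |     # Example: LatentTraverser({'disc': [10], 'cont': 10}) traverses a latent
28 |     # space with a 10-dimensional continuous block plus one 10-way categorical,
29 |     # i.e. D = 20 total latent dimensions.
30 | 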
31 |     def traverse_line(self, cont_idx=None, disc_idx=None, size=5):
32 |         """
33 |         Returns a (size, D) latent sample, corresponding to a traversal of the
34 |         latent variable indicated by cont_idx or disc_idx.
35 | 
36 |         Parameters
37 |         ----------
38 |         cont_idx : int or None
39 |             Index of continuous dimension to traverse. If the continuous latent
40 |             vector is 10 dimensional and cont_idx = 7, then the 7th dimension
41 |             will be traversed while all others will either be fixed or randomly
42 |             sampled. If None, no latent is traversed and all latent
43 |             dimensions are randomly sampled or kept fixed.
44 | 
45 |         disc_idx : int or None
46 |             Index of discrete latent dimension to traverse. If there are 5
47 |             discrete latent variables and disc_idx = 3, then only the 3rd
48 |             discrete latent will be traversed while others will be fixed or
49 |             randomly sampled. If None, no latent is traversed and all latent
50 |             dimensions are randomly sampled or kept fixed.
51 | 
52 |         size : int
53 |             Number of samples to generate.
54 |         """
55 |         samples = []
56 | 
57 |         if self.is_continuous:
58 |             samples.append(self._traverse_continuous_line(idx=cont_idx,
59 |                                                           size=size))
60 |         if self.is_discrete:
61 |             for i, disc_dim in enumerate(self.disc_dims):
62 |                 if i == disc_idx:
63 |                     samples.append(self._traverse_discrete_line(dim=disc_dim,
64 |                                                                 traverse=True,
65 |                                                                 size=size))
66 |                 else:
67 |                     samples.append(self._traverse_discrete_line(dim=disc_dim,
68 |                                                                 traverse=False,
69 |                                                                 size=size))
70 | 
71 |         return torch.cat(samples, dim=1)
72 | 
73 |     def _traverse_continuous_line(self, idx, size):
74 |         """
75 |         Returns a (size, cont_dim) latent sample, corresponding to a traversal
76 |         of a continuous latent variable indicated by idx.
77 | 
78 |         Parameters
79 |         ----------
80 |         idx : int or None
81 |             Index of continuous latent dimension to traverse. If None, no
82 |             latent is traversed and all latent dimensions are randomly sampled
83 |             or kept fixed.
84 | 
85 |         size : int
86 |             Number of samples to generate.
87 |         """
88 |         if self.sample_prior:
89 |             samples = np.random.normal(size=(size, self.cont_dim))
90 |         else:
91 |             samples = np.zeros(shape=(size, self.cont_dim))
92 | 
93 |         if idx is not None:
94 |             # Sweep over linearly spaced coordinates transformed through the
95 |             # inverse CDF (ppf) of a gaussian, since the prior of the latent
96 |             # space is gaussian
97 |             cdf_traversal = np.linspace(self.begin, self.end, size)
98 |             cont_traversal = stats.norm.ppf(cdf_traversal)
99 | 
100 |             for i in range(size):
101 |                 samples[i, idx] = cont_traversal[i]
102 | 
103 |         return torch.Tensor(samples)
104 | 
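105 |     # Example: with begin=0.001 and end=0.999, a size-5 sweep maps
106 |     # np.linspace(0.001, 0.999, 5) through stats.norm.ppf, giving roughly
107 |     # [-3.09, -0.67, 0.00, 0.67, 3.09] along the traversed dimension.
108 | 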
109 |     def _traverse_discrete_line(self, dim, traverse, size):
110 |         """
111 |         Returns a (size, dim) latent sample, corresponding to a traversal of a
112 |         discrete latent variable.
113 | 
114 |         Parameters
115 |         ----------
116 |         dim : int
117 |             Number of categories of the discrete latent variable.
118 | 
119 |         traverse : bool
120 |             If True, traverse the categorical variable, otherwise keep it fixed
121 |             or randomly sample.
122 | 
123 |         size : int
124 |             Number of samples to generate.
125 |         """
126 |         samples = np.zeros((size, dim))
127 | 
128 |         if traverse:
129 |             for i in range(size):
130 |                 samples[i, i % dim] = 1.
131 |         else:
132 |             # Randomly select the discrete variable (i.e. sample from a uniform prior)
133 |             if self.sample_prior:
134 |                 samples[np.arange(size), np.random.randint(0, dim, size)] = 1.
135 |             else:
136 |                 samples[:, 0] = 1.
137 | 
138 |         return torch.Tensor(samples)
139 | 
140 |     def traverse_grid(self, cont_idx=None, cont_axis=None, disc_idx=None,
141 |                       disc_axis=None, size=(5, 5)):
142 |         """
143 |         Returns a (size[0] * size[1], D) latent sample, corresponding to a
144 |         two dimensional traversal of the latent space.
145 | 
146 |         Parameters
147 |         ----------
148 |         cont_idx : int or None
149 |             Index of continuous dimension to traverse. If the continuous latent
150 |             vector is 10 dimensional and cont_idx = 7, then the 7th dimension
151 |             will be traversed while all others will either be fixed or randomly
152 |             sampled. If None, no latent is traversed and all latent
153 |             dimensions are randomly sampled or kept fixed.
154 | 
155 |         cont_axis : int or None
156 |             Either 0 for traversal across the rows or 1 for traversal across
157 |             the columns. If None and disc_axis is not None, defaults to the
158 |             axis disc_axis is not. Otherwise defaults to 0.
159 | 
160 |         disc_idx : int or None
161 |             Index of discrete latent dimension to traverse. If there are 5
162 |             discrete latent variables and disc_idx = 3, then only the 3rd
163 |             discrete latent will be traversed while others will be fixed or
164 |             randomly sampled. If None, no latent is traversed and all latent
165 |             dimensions are randomly sampled or kept fixed.
166 | 
167 |         disc_axis : int or None
168 |             Either 0 for traversal across the rows or 1 for traversal across
169 |             the columns. If None and cont_axis is not None, defaults to the
170 |             axis cont_axis is not. Otherwise defaults to 1.
171 | 
172 |         size : tuple of ints
173 |             Shape of grid to generate. E.g. (6, 4).
174 |         """
175 |         # Default the axes so the two traversals vary along different axes,
176 |         # as documented above
177 |         if cont_axis is None and disc_axis is None:
178 |             cont_axis = 0
179 |             disc_axis = 1
180 |         elif cont_axis is None:
181 |             cont_axis = int(not disc_axis)
182 |         elif disc_axis is None:
183 |             disc_axis = int(not cont_axis)
184 | 
185 |         samples = []
186 | 
187 |         if self.is_continuous:
188 |             samples.append(self._traverse_continuous_grid(idx=cont_idx,
189 |                                                           axis=cont_axis,
190 |                                                           size=size))
191 |         if self.is_discrete:
192 |             for i, disc_dim in enumerate(self.disc_dims):
193 |                 if i == disc_idx:
194 |                     samples.append(self._traverse_discrete_grid(dim=disc_dim,
195 |                                                                 axis=disc_axis,
196 |                                                                 traverse=True,
197 |                                                                 size=size))
198 |                 else:
199 |                     samples.append(self._traverse_discrete_grid(dim=disc_dim,
200 |                                                                 axis=disc_axis,
201 |                                                                 traverse=False,
202 |                                                                 size=size))
203 | 
204 |         return torch.cat(samples, dim=1)
205 | 
206 |     def _traverse_continuous_grid(self, idx, axis, size):
207 |         """
208 |         Returns a (size[0] * size[1], cont_dim) latent sample, corresponding to
209 |         a two dimensional traversal of the continuous latent space.
210 | 
211 |         Parameters
212 |         ----------
213 |         idx : int or None
214 |             Index of continuous latent dimension to traverse. If None, no
215 |             latent is traversed and all latent dimensions are randomly sampled
216 |             or kept fixed.
217 | 
218 |         axis : int
219 |             Either 0 for traversal across the rows or 1 for traversal across
220 |             the columns.
221 | 
222 |         size : tuple of ints
223 |             Shape of grid to generate. E.g. (6, 4).
224 |         """
225 |         num_samples = size[0] * size[1]
226 | 
227 |         if self.sample_prior:
228 |             samples = np.random.normal(size=(num_samples, self.cont_dim))
229 |         else:
230 |             samples = np.zeros(shape=(num_samples, self.cont_dim))
231 | 
232 |         if idx is not None:
233 |             # Sweep over linearly spaced coordinates transformed through the
234 |             # inverse CDF (ppf) of a gaussian, since the prior of the latent
235 |             # space is gaussian
236 |             cdf_traversal = np.linspace(self.begin, self.end, size[axis])
237 |             cont_traversal = stats.norm.ppf(cdf_traversal)
238 | 
239 |             for i in range(size[0]):
240 |                 for j in range(size[1]):
241 |                     if axis == 0:
242 |                         samples[i * size[1] + j, idx] = cont_traversal[i]
243 |                     else:
244 |                         samples[i * size[1] + j, idx] = cont_traversal[j]
245 | 
246 |         return torch.Tensor(samples)
247 | 
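248 |     # Example: traverse_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0,
249 |     # size=(10, 10)) varies the discrete category down the rows and the third
250 |     # continuous dimension across the columns, via the row-major index
251 |     # i * size[1] + j used above.
252 | 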
253 |     def _traverse_discrete_grid(self, dim, axis, traverse, size):
254 |         """
255 |         Returns a (size[0] * size[1], dim) latent sample, corresponding to a
256 |         two dimensional traversal of a discrete latent variable, where the
257 |         dimension of the traversal is determined by axis.
258 | 
259 |         Parameters
260 |         ----------
261 |         dim : int
262 |             Number of categories of the discrete latent variable.
263 | 
264 |         axis : int
265 |             Either 0 for traversal across the rows or 1 for traversal across
266 |             the columns.
267 | 
268 |         traverse : bool
269 |             If True, traverse the categorical variable, otherwise keep it fixed
270 |             or randomly sample.
271 | 
272 |         size : tuple of ints
273 |             Shape of grid to generate. E.g. (6, 4).
274 |         """
275 |         num_samples = size[0] * size[1]
276 |         samples = np.zeros((num_samples, dim))
277 | 
278 |         if traverse:
279 |             disc_traversal = [i % dim for i in range(size[axis])]
280 |             for i in range(size[0]):
281 |                 for j in range(size[1]):
282 |                     if axis == 0:
283 |                         samples[i * size[1] + j, disc_traversal[i]] = 1.
284 |                     else:
285 |                         samples[i * size[1] + j, disc_traversal[j]] = 1.
286 |         else:
287 |             # Randomly select the discrete variable (i.e. sample from a uniform prior)
288 |             if self.sample_prior:
289 |                 samples[np.arange(num_samples), np.random.randint(0, dim, num_samples)] = 1.
290 |             else:
291 |                 samples[:, 0] = 1.
292 | 
293 |         return torch.Tensor(samples)
294 | 
--------------------------------------------------------------------------------
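A minimal usage sketch of LatentTraverser on its own (not part of the repository files; it assumes the snippet is run from the repository root so that viz/ is importable):

    from viz.latent_traversals import LatentTraverser

    # One 10-dimensional continuous block plus a single 10-way categorical (D = 20)
    traverser = LatentTraverser({'disc': [10], 'cont': 10})

    line = traverser.traverse_line(disc_idx=0, size=10)
    grid = traverser.traverse_grid(cont_idx=0, cont_axis=1,
                                   disc_idx=0, disc_axis=0, size=(10, 10))
    print(line.shape, grid.shape)  # torch.Size([10, 20]) torch.Size([100, 20])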