├── .gitignore ├── LICENSE ├── MNISTparameters.py ├── README.md ├── model.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Panayiotis Panayiotou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MNISTparameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import datasets, transforms 4 | 5 | #### Everything is as described in Appendix A of the paper #### 6 | 7 | # PARAMETERS 8 | IMAGESIZE = 784 9 | N_LABELS = 10 10 | EMBED_DIMS = 10 11 | Z_DIMS = 64 # latent dimensions 12 | BATCH_SIZE = 100 13 | EPOCHS = 500 14 | ANNEAL_EPOCHS = 200 # Epochs to anneal for KL 15 | L_RATE = 1E-3 16 | BATCH_LOGGING_INTERVAL = 2 17 | MULTIPLIER_IMAGE = 1 18 | MULTIPLIER_LABEL = 50 19 | # Load train and test data 20 | train_loader = torch.utils.data.DataLoader( 21 | datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor()), 22 | batch_size=BATCH_SIZE, 23 | shuffle=True) 24 | test_loader = torch.utils.data.DataLoader( 25 | datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor()), 26 | batch_size=BATCH_SIZE, 27 | shuffle=False) 28 | N_MINI_BATCHES = len(train_loader) 29 | 30 | # Encoders and Decoders 31 | class LabelEncoder(nn.Module): 32 | def __init__(self): 33 | super(LabelEncoder, self).__init__() 34 | self.fc1 = nn.Embedding(N_LABELS, N_LABELS) 35 | self.fc2 = nn.Linear(N_LABELS, 512) 36 | self.fc_means = nn.Linear(512, Z_DIMS) 37 | self.fc_logvar = nn.Linear(512, Z_DIMS) 38 | 39 | def forward(self, x): 40 | h = self.fc1(x) 41 | h = self.fc2(h) 42 | return self.fc_means(h), self.fc_logvar(h) 43 | 44 | class LabelDecoder(nn.Module): 45 | def __init__(self): 46 | super(LabelDecoder, self).__init__() 47 | self.fc1 = nn.Linear(Z_DIMS, 512) 48 | self.fc2 = nn.Linear(512, 512) 49 | self.fc_out = nn.Linear(512, N_LABELS) 50 | 51 | def forward(self, z): 52 | h = self.fc1(z) 53 | h = self.fc2(h) 54 | return self.fc_out(h) 55 | 56 | class ImageEncoder(nn.Module): 57 | def __init__(self): 58 | super(ImageEncoder, self).__init__() 59 | self.fc1 = nn.Linear(IMAGESIZE, 512) 60 | self.fc2 = nn.Linear(512, 512) 61 | self.fc_means = nn.Linear(512, Z_DIMS) 62 | self.fc_logvar = nn.Linear(512, Z_DIMS) 63 | 64 | def forward(self, x): 65 | h = self.fc1(x.view(-1, IMAGESIZE)) 66 | h = self.fc2(h) 67 | return self.fc_means(h), self.fc_logvar(h) 68 | 69 | class ImageDecoder(nn.Module): 70 | def __init__(self): 71 | super(ImageDecoder, self).__init__() 72 | self.fc1 = nn.Linear(Z_DIMS, 512) 73 | self.fc2 = nn.Linear(512, 512) 74 | self.fc_out = nn.Linear(512, IMAGESIZE) 75 | 76 | def forward(self, z): 77 | h = self.fc1(z) 78 | h = self.fc2(h) 79 | return self.fc_out(h) 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Variational Autoencoder 2 | 3 | Implementation of ["Multimodal Generative Models for Scalable Weakly-Supervised Learning"](https://arxiv.org/abs/1802.05335) paper 4 | 5 | ### Requirements 6 | torch == 0.4.1 7 | 8 | torchvision == 0.2.1 9 | 10 | ### Usage 11 | 12 | ```sh 13 | python3 train.py 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from MNISTparameters import ImageEncoder, ImageDecoder, LabelEncoder, LabelDecoder, Z_DIMS 5 | 6 | class MVAE(nn.Module): 7 | def __init__(self): 8 | super(MVAE, self).__init__() 9 | self.image_encoder = ImageEncoder() 10 | self.image_decoder = ImageDecoder() 11 | self.label_encoder = LabelEncoder() 12 | self.label_decoder = LabelDecoder() 13 | 14 | def reparametrize(self, means, logvar): 15 | if self.training: 16 | eps = Variable(torch.Tensor(means.shape).normal_()) 17 | return means + eps * logvar.mul(0.5).exp_() 18 | else: 19 | return means 20 | 21 | def prior_expert(self, size): 22 | # N(0, 1) 23 | means = Variable(torch.zeros(size)) 24 | logvar = Variable(torch.zeros(size)) 25 | return means, logvar 26 | 27 | # Mix gaussians 28 | def product_of_experts(self, means, logvar): 29 | P = 1.0 / torch.exp(logvar) 30 | Psum = P.sum(dim=0) 31 | prod_means = torch.sum(means * P, dim=0) / Psum 32 | prod_logvar = torch.log(1.0 / Psum) 33 | return prod_means, prod_logvar 34 | 35 | def forward(self, image=None, label=None): 36 | means, logvar = self.encode_modalities(image, label) 37 | z = self.reparametrize(means, logvar) 38 | # Reconstruct 39 | decoded_img = self.image_decoder(z) 40 | decoded_lbl = self.label_decoder(z) 41 | return decoded_img, decoded_lbl, means, logvar 42 | 43 | def encode_modalities(self, image=None, label=None): 44 | if (image is not None): 45 | batch_size = image.size(0) 46 | else: 47 | batch_size = label.size(0) 48 | 49 | # Initialization 50 | means, logvar = self.prior_expert((1, batch_size, Z_DIMS)) 51 | 52 | # Support for weak supervision setting 53 | if image is not None: 54 | img_mean, img_logvar = self.image_encoder(image) 55 | means = torch.cat((means, img_mean.unsqueeze(0))) 56 | logvar = torch.cat((logvar, img_logvar.unsqueeze(0))) 57 | 58 | if label is not None: 59 | lbl_mean, lbl_logvar = self.label_encoder(label) 60 | means = torch.cat((means, lbl_mean.unsqueeze(0))) 61 | logvar = torch.cat((logvar, lbl_logvar.unsqueeze(0))) 62 | 63 | # Combine the gaussians 64 | means, logvar = self.product_of_experts(means, logvar) 65 | return means, logvar 66 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from model import MVAE 7 | from MNISTparameters import Z_DIMS, IMAGESIZE, BATCH_SIZE, EPOCHS, ANNEAL_EPOCHS, L_RATE, BATCH_LOGGING_INTERVAL, MULTIPLIER_IMAGE, MULTIPLIER_LABEL, N_MINI_BATCHES 8 | from MNISTparameters import train_loader, test_loader 9 | 10 | def elbo_loss(recon_image, image, recon_label, label, mean, logvar, 11 | lambda_image=1.0, lambda_label=1.0, anneal_factor=1): 12 | image_bce = 0 13 | if recon_image is not None and image is not None: 14 | image_bce = torch.sum(binary_cross_entropy_of_logits( 15 | recon_image.view(-1, IMAGESIZE), 16 | image.view(-1, IMAGESIZE)), dim=1) 17 | 18 | label_bce = 0 19 | if recon_label is not None and label is not None: 20 | label_bce = torch.sum(cross_entropy_of_logits(recon_label, label), dim=1) 21 | 22 | KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1) 23 | ELBO = torch.mean(lambda_image * image_bce + lambda_label * label_bce + anneal_factor * KLD) 24 | return ELBO 25 | 26 | 27 | def binary_cross_entropy_of_logits(input, target): 28 | return torch.clamp(input, 0) - input * target + torch.log(1 + torch.exp(-torch.abs(input))) 29 | 30 | 31 | def cross_entropy_of_logits(input, target): 32 | log_input = F.log_softmax(input, dim=1) 33 | y_onehot = Variable(torch.zeros(input.shape)) 34 | y_onehot = y_onehot.scatter(1, target.unsqueeze(1), 1) 35 | return -y_onehot * log_input 36 | 37 | def train(epoch): 38 | model.train() 39 | total_loss = 0 40 | total_examples = 0 41 | 42 | for batch_idx, (image, label) in enumerate(train_loader): 43 | # Linearly increase from 0 to 1 on each epoch 44 | anneal_factor = min(1.0, float(epoch - 1) / ANNEAL_EPOCHS) 45 | 46 | image = Variable(image) 47 | label = Variable(label) 48 | batch_size = len(image) 49 | 50 | # Refresh 51 | optimizer.zero_grad() 52 | recon_image_1, recon_label_1, mean_1, logvar_1 = model(image, label) 53 | recon_image_2, recon_label_2, mean_2, logvar_2 = model(image) 54 | recon_image_3, recon_label_3, mean_3, logvar_3 = model(label=label) 55 | # Compute ELBO 56 | joint_loss = elbo_loss(recon_image_1, image, recon_label_1, label, mean_1, logvar_1, 57 | lambda_image=MULTIPLIER_IMAGE, lambda_label=MULTIPLIER_LABEL, anneal_factor=anneal_factor) 58 | image_loss = elbo_loss(recon_image_2, image, None, None, mean_2, logvar_2, 59 | lambda_image=MULTIPLIER_IMAGE, lambda_label=MULTIPLIER_LABEL, anneal_factor=anneal_factor) 60 | label_loss = elbo_loss(None, None, recon_label_3, label, mean_3, logvar_3, 61 | lambda_image=MULTIPLIER_IMAGE, lambda_label=MULTIPLIER_LABEL, anneal_factor=anneal_factor) 62 | train_loss = joint_loss + image_loss + label_loss 63 | total_loss += train_loss.item() * batch_size 64 | total_examples += batch_size 65 | train_loss.backward() 66 | optimizer.step() 67 | if batch_idx % BATCH_LOGGING_INTERVAL == 0: 68 | print('Epoch: {} [{:5}/{}] Loss: {:11.6f} Annealing-Factor: {:.5f}'.format( 69 | epoch, batch_idx * len(image), len(train_loader.dataset), total_loss / total_examples, anneal_factor)) 70 | 71 | print('######## Epoch: {}\tLoss: {:.6f} ########'.format(epoch, total_loss / total_examples)) 72 | 73 | 74 | def test(epoch): 75 | model.eval() 76 | total_loss = 0 77 | total_examples = 0 78 | 79 | for batch_idx, (image, label) in enumerate(test_loader): 80 | 81 | with torch.no_grad(): 82 | image = Variable(image) 83 | label = Variable(label) 84 | batch_size = len(image) 85 | 86 | recon_image_1, recon_label_1, mean_1, logvar_1 = model(image, label) 87 | recon_image_2, recon_label_2, mean_2, logvar_2 = model(image) 88 | recon_image_3, recon_label_3, mean_3, logvar_3 = model(label=label) 89 | 90 | joint_loss = elbo_loss(recon_image_1, image, recon_label_1, label, mean_1, logvar_1) 91 | image_loss = elbo_loss(recon_image_2, image, None, None, mean_2, logvar_2) 92 | label_loss = elbo_loss(None, None, recon_label_3, label, mean_3, logvar_3) 93 | test_loss = joint_loss + image_loss + label_loss 94 | total_loss += test_loss.item() * batch_size 95 | total_examples += batch_size 96 | 97 | 98 | print('######## Test Loss: {} ########'.format(total_loss / total_examples)) 99 | return total_loss / total_examples 100 | 101 | def save_model(): 102 | state = { 103 | 'state_dict': model.state_dict(), 104 | 'optimizer': optimizer.state_dict() 105 | } 106 | torch.save(state, os.path.join(folder, './bestmodel')) 107 | print('Model saved!') 108 | 109 | 110 | # Save the model every 5 epochs 111 | if __name__ == "__main__": 112 | model = MVAE() 113 | optimizer = torch.optim.Adam(model.parameters(), lr=L_RATE) 114 | 115 | # Directory to save results 116 | folder = './models' 117 | if not os.path.isdir(folder): 118 | os.mkdir(folder) 119 | 120 | # Train 121 | train(1) 122 | test_loss = test(1) 123 | best_loss = test_loss 124 | save_model() 125 | for epoch in range(2, EPOCHS + 1): 126 | train(epoch) 127 | test_loss = test(epoch) 128 | if test_loss < best_loss: 129 | best_loss = test_loss 130 | save_model() 131 | --------------------------------------------------------------------------------