├── .gitignore ├── LICENSE ├── README.md ├── data └── .gitkeep ├── model.py ├── requirements.txt ├── results └── .gitkeep └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .DS_Store 107 | .idea 108 | data 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Koichiro Mori 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Variational Autoencoder in PyTorch 2 | 3 | ## Install 4 | 5 | ``` 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | ## Run 10 | 11 | ``` 12 | python train.py 13 | ``` -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidiary/conv-vae/c82baad1b77cf5025ce058798e03fc7490e7f7c9/data/.gitkeep -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Flatten(nn.Module): 6 | def forward(self, input): 7 | return input.view(input.size(0), -1) 8 | 9 | 10 | class Unflatten(nn.Module): 11 | def __init__(self, channel, height, width): 12 | super(Unflatten, self).__init__() 13 | self.channel = channel 14 | self.height = height 15 | self.width = width 16 | 17 | def forward(self, input): 18 | return input.view(input.size(0), self.channel, self.height, self.width) 19 | 20 | 21 | class ConvVAE(nn.Module): 22 | 23 | def __init__(self, latent_size): 24 | super(ConvVAE, self).__init__() 25 | 26 | self.latent_size = latent_size 27 | 28 | self.encoder = nn.Sequential( 29 | nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), 30 | nn.ReLU(), 31 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), 32 | nn.ReLU(), 33 | Flatten(), 34 | nn.Linear(6272, 1024), 35 | nn.ReLU() 36 | ) 37 | 38 | # hidden => mu 39 | self.fc1 = nn.Linear(1024, self.latent_size) 40 | 41 | # hidden => logvar 42 | self.fc2 = nn.Linear(1024, self.latent_size) 43 | 44 | self.decoder = nn.Sequential( 45 | nn.Linear(self.latent_size, 1024), 46 | nn.ReLU(), 47 | nn.Linear(1024, 6272), 48 | nn.ReLU(), 49 | Unflatten(128, 7, 7), 50 | nn.ReLU(), 51 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), 52 | nn.ReLU(), 53 | nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), 54 | nn.Sigmoid() 55 | ) 56 | 57 | def encode(self, x): 58 | h = self.encoder(x) 59 | mu, logvar = self.fc1(h), self.fc2(h) 60 | return mu, logvar 61 | 62 | def decode(self, z): 63 | z = self.decoder(z) 64 | return z 65 | 66 | def reparameterize(self, mu, logvar): 67 | if self.training: 68 | std = torch.exp(0.5 * logvar) 69 | eps = torch.randn_like(std) 70 | return eps.mul(std).add_(mu) 71 | else: 72 | return mu 73 | 74 | def forward(self, x): 75 | mu, logvar = self.encode(x) 76 | z = self.reparameterize(mu, logvar) 77 | return self.decode(z), mu, logvar 78 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 0.4.1 2 | torchvision >= 0.2.1 3 | tensorboardX >= 1.4 4 | tqdm >= 4.25.0 5 | -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidiary/conv-vae/c82baad1b77cf5025ce058798e03fc7490e7f7c9/results/.gitkeep -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | from torch.nn import functional as F 5 | from torchvision import datasets, transforms 6 | from torchvision.utils import save_image, make_grid 7 | 8 | import os 9 | import shutil 10 | import numpy as np 11 | from tensorboardX import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | from model import ConvVAE 15 | 16 | 17 | cuda = torch.cuda.is_available() 18 | if cuda: 19 | print('cuda available') 20 | 21 | device = torch.device("cuda" if cuda else "cpu") 22 | 23 | 24 | def loss_function(recon_x, x, mu, logvar): 25 | # reconstruction loss 26 | BCE = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum') 27 | 28 | # KL divergence loss 29 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 30 | 31 | return BCE + KLD 32 | 33 | 34 | def train(epoch, model, train_loader, optimizer, args): 35 | model.train() 36 | train_loss = 0 37 | 38 | for batch_idx, (data, _) in tqdm(enumerate(train_loader), total=len(train_loader), desc='train'): 39 | data = data.to(device) 40 | 41 | optimizer.zero_grad() 42 | recon_batch, mu, logvar = model(data) 43 | 44 | loss = loss_function(recon_batch, data, mu, logvar) 45 | train_loss += loss.item() 46 | 47 | loss.backward() 48 | optimizer.step() 49 | 50 | train_loss /= len(train_loader.dataset) 51 | 52 | return train_loss 53 | 54 | 55 | def test(epoch, model, test_loader, writer, args): 56 | model.eval() 57 | test_loss = 0 58 | 59 | with torch.no_grad(): 60 | for batch_idx, (data, _) in tqdm(enumerate(test_loader), total=len(test_loader), desc='test'): 61 | data = data.to(device) 62 | 63 | recon_batch, mu, logvar = model(data) 64 | 65 | test_loss += loss_function(recon_batch, data, mu, logvar).item() 66 | 67 | if batch_idx == 0: 68 | n = min(data.size(0), 8) 69 | comparison = torch.cat([data[:n], recon_batch.view(args.batch_size, 1, 28, 28)[:n]]).cpu() 70 | img = make_grid(comparison) 71 | writer.add_image('reconstruction', img, epoch) 72 | # save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow=n) 73 | 74 | test_loss /= len(test_loader.dataset) 75 | 76 | return test_loss 77 | 78 | 79 | def save_checkpoint(state, is_best, outdir='results'): 80 | checkpoint_file = os.path.join(outdir, 'checkpoint.pth') 81 | best_file = os.path.join(outdir, 'model_best.pth') 82 | torch.save(state, checkpoint_file) 83 | if is_best: 84 | shutil.copyfile(checkpoint_file, best_file) 85 | 86 | 87 | def main(): 88 | parser = argparse.ArgumentParser(description='Convolutional VAE MNIST Example') 89 | parser.add_argument('--result_dir', type=str, default='results', metavar='DIR', 90 | help='output directory') 91 | parser.add_argument('--batch_size', type=int, default=100, metavar='N', 92 | help='input batch size for training (default: 128)') 93 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 94 | help='number of epochs to train (default: 10)') 95 | parser.add_argument('--seed', type=int, default=1, metavar='S', 96 | help='random seed (default: 1)') 97 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 98 | help='path to latest checkpoint (default: None') 99 | 100 | # model options 101 | parser.add_argument('--latent_size', type=int, default=32, metavar='N', 102 | help='latent vector size of encoder') 103 | 104 | args = parser.parse_args() 105 | 106 | torch.manual_seed(args.seed) 107 | 108 | kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {} 109 | 110 | train_loader = torch.utils.data.DataLoader( 111 | datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor()), 112 | batch_size=args.batch_size, shuffle=True, **kwargs) 113 | 114 | test_loader = torch.utils.data.DataLoader( 115 | datasets.MNIST('./data', train=False, transform=transforms.ToTensor()), 116 | batch_size=args.batch_size, shuffle=True, **kwargs) 117 | 118 | model = ConvVAE(args.latent_size).to(device) 119 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 120 | 121 | start_epoch = 0 122 | best_test_loss = np.finfo('f').max 123 | 124 | # optionally resume from a checkpoint 125 | if args.resume: 126 | if os.path.isfile(args.resume): 127 | print('=> loading checkpoint %s' % args.resume) 128 | checkpoint = torch.load(args.resume) 129 | start_epoch = checkpoint['epoch'] + 1 130 | best_test_loss = checkpoint['best_test_loss'] 131 | model.load_state_dict(checkpoint['state_dict']) 132 | optimizer.load_state_dict(checkpoint['optimizer']) 133 | print('=> loaded checkpoint %s' % args.resume) 134 | else: 135 | print('=> no checkpoint found at %s' % args.resume) 136 | 137 | writer = SummaryWriter() 138 | 139 | for epoch in range(start_epoch, args.epochs): 140 | train_loss = train(epoch, model, train_loader, optimizer, args) 141 | test_loss = test(epoch, model, test_loader, writer, args) 142 | 143 | # logging 144 | writer.add_scalar('train/loss', train_loss, epoch) 145 | writer.add_scalar('test/loss', test_loss, epoch) 146 | 147 | print('Epoch [%d/%d] loss: %.3f val_loss: %.3f' % (epoch + 1, args.epochs, train_loss, test_loss)) 148 | 149 | is_best = test_loss < best_test_loss 150 | best_test_loss = min(test_loss, best_test_loss) 151 | save_checkpoint({ 152 | 'epoch': epoch, 153 | 'best_test_loss': best_test_loss, 154 | 'state_dict': model.state_dict(), 155 | 'optimizer': optimizer.state_dict(), 156 | }, is_best) 157 | 158 | with torch.no_grad(): 159 | sample = torch.randn(64, 32).to(device) 160 | sample = model.decode(sample).cpu() 161 | img = make_grid(sample) 162 | writer.add_image('sampling', img, epoch) 163 | # save_image(sample.view(64, 1, 28, 28), 'results/sample_' + str(epoch) + '.png') 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | --------------------------------------------------------------------------------