├── .gitignore ├── README.md ├── configs ├── cifar10_ce_loss.json └── cifar10_mse_loss.json ├── data_loaders └── cifar10_data_loader.py ├── graph ├── ce_loss.py ├── ce_model.py ├── mse_loss.py └── mse_model.py ├── main.py ├── train ├── ce_trainer.py └── mse_trainer.py └── utils ├── utils.py └── weight_initializer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | data/ 109 | .idea/ 110 | experiments/ 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Autoencoder 2 | This repository contains a convolutional-VAE model implementation in pytorch and trained on CIFAR10 dataset. 
3 | 
4 | ## How to train
5 | ```
6 | python main.py --config=./configs/cifar10_ce_loss.json
7 | ```
8 | To train with the MSE reconstruction loss instead, pass `./configs/cifar10_mse_loss.json`.
9 | 
--------------------------------------------------------------------------------
/configs/cifar10_ce_loss.json:
--------------------------------------------------------------------------------
1 | {
2 |   "experiment_dir": "VAE_CIFAR10_ce_loss_ld_512",
3 | 
4 |   "num_epochs": 800,
5 |   "seed": 1,
6 | 
7 |   "dataset": "CIFAR10",
8 |   "loss": "ce",
9 |   "input_shape":{
10 |     "width": 32,
11 |     "height": 32,
12 |     "channels": 3
13 |   },
14 |   "num_classes": 10,
15 |   "dataloader_workers": 4,
16 |   "shuffle": true,
17 | 
18 |   "batch_size": 32,
19 |   "weight_decay": 5e-4,
20 |   "learning_rate": 1e-3,
21 |   "learning_rate_decay": 0.99,
22 | 
23 |   "resume": true,
24 |   "resume_from": "checkpoint.pth.tar",
25 |   "test_every": 20,
26 |   "to_train": true,
27 |   "to_test": true,
28 | 
29 |   "cuda": true,
30 |   "pin_memory": true,
31 |   "async_loading": true
32 | }
33 | 
--------------------------------------------------------------------------------
/configs/cifar10_mse_loss.json:
--------------------------------------------------------------------------------
1 | {
2 |   "experiment_dir": "VAE_CIFAR10_mse_loss_ld_512",
3 | 
4 |   "num_epochs": 800,
5 |   "seed": 1,
6 | 
7 |   "dataset": "CIFAR10",
8 |   "loss": "mse",
9 |   "input_shape":{
10 |     "width": 32,
11 |     "height": 32,
12 |     "channels": 3
13 |   },
14 |   "num_classes": 10,
15 |   "dataloader_workers": 4,
16 |   "shuffle": true,
17 | 
18 |   "batch_size": 32,
19 |   "weight_decay": 5e-4,
20 |   "learning_rate": 1e-3,
21 |   "learning_rate_decay": 0.99,
22 | 
23 |   "resume": true,
24 |   "resume_from": "checkpoint.pth.tar",
25 |   "test_every": 20,
26 |   "to_train": true,
27 |   "to_test": true,
28 | 
29 |   "cuda": true,
30 |   "pin_memory": true,
31 |   "async_loading": true
32 | }
33 | 
--------------------------------------------------------------------------------
/data_loaders/cifar10_data_loader.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms
2 | from torch.utils.data import DataLoader
3 | 
4 | 
5 | class CIFAR10DataLoader:
6 |     def __init__(self, args):
7 |         if args.dataset == 'CIFAR10':
8 |             # Data Loading
9 |             kwargs = {'num_workers': args.dataloader_workers, 'pin_memory': args.pin_memory} if args.cuda else {}
10 | 
11 |             transform_train = transforms.Compose([
12 |                 transforms.ToTensor()
13 |                 # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
14 |             ])
15 | 
16 |             transform_test = transforms.Compose([
17 |                 transforms.ToTensor()
18 |                 # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
19 |             ])
20 | 
21 |             train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
22 |             self.train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=args.shuffle,
23 |                                            **kwargs)
24 | 
25 |             test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
26 |             self.test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False,
27 |                                           **kwargs)
28 | 
29 |             self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
30 | 
31 |         else:
32 |             raise ValueError('The dataset should be CIFAR10')
33 | 
--------------------------------------------------------------------------------
/graph/ce_loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | 
5 | class Loss(nn.Module):
6 |     def __init__(self):
7 |         super(Loss, self).__init__()
8 | self.ce_loss = nn.CrossEntropyLoss(size_average=False) 9 | 10 | def forward(self, recon_x, x, mu, logvar): 11 | # BCE = F.mse_loss(recon_x, x, size_average=False) 12 | x = x * 255 13 | x.data = x.data.int().long().view(-1) 14 | # print(recon_x.shape) 15 | recon_x = recon_x.permute(0, 2, 3, 4, 1) # N * C * W * H 16 | # print(recon_x.shape) 17 | recon_x = recon_x.contiguous().view(-1, 256) 18 | 19 | CE = self.ce_loss(recon_x, x) 20 | 21 | # see Appendix B from VAE paper: https://arxiv.org/abs/1312.6114 22 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 23 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 24 | 25 | return CE + KLD 26 | -------------------------------------------------------------------------------- /graph/ce_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Variable 3 | 4 | 5 | class VAE(nn.Module): 6 | def __init__(self): 7 | super(VAE, self).__init__() 8 | 9 | # Encoder 10 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 11 | self.bn1 = nn.BatchNorm2d(16) 12 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(32) 14 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn3 = nn.BatchNorm2d(32) 16 | self.conv4 = nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1, bias=False) 17 | self.bn4 = nn.BatchNorm2d(16) 18 | 19 | self.fc1 = nn.Linear(8 * 8 * 16, 512) 20 | self.fc_bn1 = nn.BatchNorm1d(512) 21 | self.fc21 = nn.Linear(512, 512) 22 | self.fc22 = nn.Linear(512, 512) 23 | 24 | # Decoder 25 | self.fc3 = nn.Linear(512, 512) 26 | self.fc_bn3 = nn.BatchNorm1d(512) 27 | self.fc4 = nn.Linear(512, 8 * 8 * 16) 28 | self.fc_bn4 = nn.BatchNorm1d(8 * 8 * 16) 29 | 30 | self.conv5 = nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False) 31 | self.bn5 = nn.BatchNorm2d(32) 32 | self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn6 = nn.BatchNorm2d(32) 34 | self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False) 35 | self.bn7 = nn.BatchNorm2d(16) 36 | self.conv8 = nn.ConvTranspose2d(16, 3 * 256, kernel_size=3, stride=1, padding=1, bias=False) 37 | 38 | self.relu = nn.ReLU() 39 | 40 | def encode(self, x): 41 | conv1 = self.relu(self.bn1(self.conv1(x))) 42 | conv2 = self.relu(self.bn2(self.conv2(conv1))) 43 | conv3 = self.relu(self.bn3(self.conv3(conv2))) 44 | conv4 = self.relu(self.bn4(self.conv4(conv3))).view(-1, 8 * 8 * 16) 45 | 46 | fc1 = self.relu(self.fc_bn1(self.fc1(conv4))) 47 | return self.fc21(fc1), self.fc22(fc1) 48 | 49 | def reparameterize(self, mu, logvar): 50 | if self.training: 51 | std = logvar.mul(0.5).exp_() 52 | eps = Variable(std.data.new(std.size()).normal_()) 53 | return eps.mul(std).add_(mu) 54 | else: 55 | return mu 56 | 57 | def decode(self, z): 58 | fc3 = self.relu(self.fc_bn3(self.fc3(z))) 59 | fc4 = self.relu(self.fc_bn4(self.fc4(fc3))).view(-1, 16, 8, 8) 60 | 61 | conv5 = self.relu(self.bn5(self.conv5(fc4))) 62 | conv6 = self.relu(self.bn6(self.conv6(conv5))) 63 | conv7 = self.relu(self.bn7(self.conv7(conv6))) 64 | return self.conv8(conv7).view(-1, 256, 3, 32, 32) 65 | 66 | def forward(self, x): 67 | mu, logvar = self.encode(x) 68 | z = self.reparameterize(mu, logvar) 69 | return self.decode(z), mu, logvar 70 | 
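
A note on the cross-entropy pair above: `conv8` in ce_model.py emits 3 * 256 channels, reshaped to `N x 256 x 3 x 32 x 32`, so every RGB sub-pixel gets 256 logits, and ce_loss.py turns the target image into integer intensities (0-255) that `CrossEntropyLoss` scores as class labels. A minimal shape-check sketch, not part of the repository, assuming it is run from the repo root so the `graph` package is importable:

```
import torch

from graph.ce_loss import Loss
from graph.ce_model import VAE

model, criterion = VAE(), Loss()

# A fake batch of four CIFAR10-sized images in [0, 1], as ToTensor() would produce.
x = torch.rand(4, 3, 32, 32)

recon, mu, logvar = model(x)
print(recon.shape)  # torch.Size([4, 256, 3, 32, 32]) -- 256 logits per RGB sub-pixel
print(mu.shape)     # torch.Size([4, 512])            -- 512-dimensional latent code

# Summed cross-entropy reconstruction term plus the KL divergence.
loss = criterion(recon, x, mu, logvar)
print(loss)
```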
-------------------------------------------------------------------------------- /graph/mse_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | class Loss(nn.Module): 6 | def __init__(self): 7 | super(Loss, self).__init__() 8 | self.mse_loss = nn.MSELoss(size_average=False) 9 | 10 | def forward(self, recon_x, x, mu, logvar): 11 | MSE = self.mse_loss(recon_x, x) 12 | 13 | # see Appendix B from VAE paper: https://arxiv.org/abs/1312.6114 14 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 15 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 16 | 17 | return MSE + KLD 18 | -------------------------------------------------------------------------------- /graph/mse_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Variable 3 | 4 | 5 | class VAE(nn.Module): 6 | def __init__(self): 7 | super(VAE, self).__init__() 8 | 9 | # Encoder 10 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 11 | self.bn1 = nn.BatchNorm2d(16) 12 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(32) 14 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn3 = nn.BatchNorm2d(32) 16 | self.conv4 = nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1, bias=False) 17 | self.bn4 = nn.BatchNorm2d(16) 18 | 19 | self.fc1 = nn.Linear(8 * 8 * 16, 512) 20 | self.fc_bn1 = nn.BatchNorm1d(512) 21 | self.fc21 = nn.Linear(512, 512) 22 | self.fc22 = nn.Linear(512, 512) 23 | 24 | # Decoder 25 | self.fc3 = nn.Linear(512, 512) 26 | self.fc_bn3 = nn.BatchNorm1d(512) 27 | self.fc4 = nn.Linear(512, 8 * 8 * 16) 28 | self.fc_bn4 = nn.BatchNorm1d(8 * 8 * 16) 29 | 30 | self.conv5 = nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False) 31 | self.bn5 = nn.BatchNorm2d(32) 32 | self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn6 = nn.BatchNorm2d(32) 34 | self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False) 35 | self.bn7 = nn.BatchNorm2d(16) 36 | self.conv8 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1, bias=False) 37 | 38 | self.relu = nn.ReLU() 39 | 40 | def encode(self, x): 41 | conv1 = self.relu(self.bn1(self.conv1(x))) 42 | conv2 = self.relu(self.bn2(self.conv2(conv1))) 43 | conv3 = self.relu(self.bn3(self.conv3(conv2))) 44 | conv4 = self.relu(self.bn4(self.conv4(conv3))).view(-1, 8 * 8 * 16) 45 | 46 | fc1 = self.relu(self.fc_bn1(self.fc1(conv4))) 47 | return self.fc21(fc1), self.fc22(fc1) 48 | 49 | def reparameterize(self, mu, logvar): 50 | if self.training: 51 | std = logvar.mul(0.5).exp_() 52 | eps = Variable(std.data.new(std.size()).normal_()) 53 | return eps.mul(std).add_(mu) 54 | else: 55 | return mu 56 | 57 | def decode(self, z): 58 | fc3 = self.relu(self.fc_bn3(self.fc3(z))) 59 | fc4 = self.relu(self.fc_bn4(self.fc4(fc3))).view(-1, 16, 8, 8) 60 | 61 | conv5 = self.relu(self.bn5(self.conv5(fc4))) 62 | conv6 = self.relu(self.bn6(self.conv6(conv5))) 63 | conv7 = self.relu(self.bn7(self.conv7(conv6))) 64 | return self.conv8(conv7).view(-1, 3, 32, 32) 65 | 66 | def forward(self, x): 67 | mu, logvar = self.encode(x) 68 | z = self.reparameterize(mu, logvar) 69 | return self.decode(z), mu, logvar 70 | 
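
The MSE pair above shares the same encoder/decoder layout, but its `conv8` is a plain 3-channel transposed convolution, so the reconstruction is itself an image batch rather than per-sub-pixel logits. A minimal sketch of a single optimization step with this pair, mirroring the inner loop of train/mse_trainer.py; it is an illustration on a random batch, not repository code:

```
import torch
from torch import optim

from graph.mse_loss import Loss
from graph.mse_model import VAE

model, criterion = VAE(), Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

x = torch.rand(4, 3, 32, 32)            # fake batch in [0, 1]

optimizer.zero_grad()
recon, mu, logvar = model(x)
print(recon.shape)                      # torch.Size([4, 3, 32, 32]) -- directly an image batch
loss = criterion(recon, x, mu, logvar)  # summed MSE + KL divergence
loss.backward()
optimizer.step()
```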
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.backends.cudnn as cudnn 4 | import torch.nn.init as init 5 | import torch.utils.data 6 | 7 | from data_loaders.cifar10_data_loader import CIFAR10DataLoader 8 | from graph.ce_loss import Loss as Loss_ce 9 | from graph.mse_loss import Loss as Loss_mse 10 | from graph.ce_model import VAE as VAE_ce 11 | from graph.mse_model import VAE as VAE_mse 12 | from train.ce_trainer import Trainer as Trainer_ce 13 | from train.mse_trainer import Trainer as Trainer_mse 14 | from utils.utils import * 15 | from utils.weight_initializer import Initializer 16 | 17 | 18 | def main(): 19 | # Parse the JSON arguments 20 | args = parse_args() 21 | 22 | # Create the experiment directories 23 | args.summary_dir, args.checkpoint_dir = create_experiment_dirs( 24 | args.experiment_dir) 25 | 26 | if args.loss == 'ce': 27 | model = VAE_ce() 28 | else: 29 | model = VAE_mse() 30 | 31 | # to apply xavier_uniform: 32 | Initializer.initialize(model=model, initialization=init.xavier_uniform, gain=init.calculate_gain('relu')) 33 | 34 | if args.loss == 'ce': 35 | loss = Loss_ce() 36 | else: 37 | loss = Loss_mse() 38 | 39 | args.cuda = args.cuda and torch.cuda.is_available() 40 | if args.cuda: 41 | model.cuda() 42 | loss.cuda() 43 | cudnn.enabled = True 44 | cudnn.benchmark = True 45 | 46 | print("Loading Data...") 47 | data = CIFAR10DataLoader(args) 48 | print("Data loaded successfully\n") 49 | 50 | if args.loss == 'ce': 51 | trainer = Trainer_ce(model, loss, data.train_loader, data.test_loader, args) 52 | else: 53 | trainer = Trainer_mse(model, loss, data.train_loader, data.test_loader, args) 54 | 55 | if args.to_train: 56 | try: 57 | print("Training...") 58 | trainer.train() 59 | print("Training Finished\n") 60 | except KeyboardInterrupt: 61 | print("Training had been Interrupted\n") 62 | 63 | if args.to_test: 64 | print("Testing on training data...") 65 | trainer.test_on_trainings_set() 66 | print("Testing Finished\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /train/ce_trainer.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from torch import optim 3 | import torch 4 | 5 | from tensorboardX import SummaryWriter 6 | import shutil 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | class Trainer: 11 | def __init__(self, model, loss, train_loader, test_loader, args): 12 | self.model = model 13 | self.args = args 14 | self.args.start_epoch = 0 15 | 16 | self.train_loader = train_loader 17 | self.test_loader = test_loader 18 | 19 | # Loss function and Optimizer 20 | self.loss = loss 21 | self.optimizer = self.get_optimizer() 22 | 23 | # Tensorboard Writer 24 | self.summary_writer = SummaryWriter(log_dir=args.summary_dir) 25 | # Model Loading 26 | if args.resume: 27 | self.load_checkpoint(self.args.resume_from) 28 | 29 | def train(self): 30 | self.model.train() 31 | for epoch in range(self.args.start_epoch, self.args.num_epochs): 32 | loss_list = [] 33 | print("epoch {}...".format(epoch)) 34 | for batch_idx, (data, _) in enumerate(tqdm(self.train_loader)): 35 | if self.args.cuda: 36 | data = data.cuda() 37 | data = Variable(data) 38 | self.optimizer.zero_grad() 39 | recon_batch, mu, logvar = self.model(data) 40 | loss 
= self.loss(recon_batch, data, mu, logvar)
41 |                 loss.backward()
42 |                 self.optimizer.step()
43 |                 loss_list.append(loss.data[0])
44 | 
45 |             print("epoch {}: - loss: {}".format(epoch, np.mean(loss_list)))
46 |             new_lr = self.adjust_learning_rate(epoch)
47 |             print('learning rate:', new_lr)
48 | 
49 |             self.summary_writer.add_scalar('training/loss', np.mean(loss_list), epoch)
50 |             self.summary_writer.add_scalar('training/learning_rate', new_lr, epoch)
51 |             self.save_checkpoint({
52 |                 'epoch': epoch + 1,
53 |                 'state_dict': self.model.state_dict(),
54 |                 'optimizer': self.optimizer.state_dict(),
55 |             })
56 |             if epoch % self.args.test_every == 0:
57 |                 self.test(epoch)
58 | 
59 |     def test(self, cur_epoch):
60 |         print('testing...')
61 |         self.model.eval()
62 |         test_loss = 0
63 |         for i, (data, _) in enumerate(self.test_loader):
64 |             if self.args.cuda:
65 |                 data = data.cuda()
66 |             data = Variable(data, volatile=True)
67 |             recon_batch, mu, logvar = self.model(data)
68 |             test_loss += self.loss(recon_batch, data, mu, logvar).data[0]
69 |             _, indices = recon_batch.max(1)
70 |             indices.data = indices.data.float() / 255  # argmax class index -> pixel intensity in [0, 1]
71 |             if i == 0:
72 |                 n = min(data.size(0), 8)
73 |                 comparison = torch.cat([data[:n],
74 |                                         indices.view(-1, 3, 32, 32)[:n]])
75 |                 self.summary_writer.add_image('testing_set/image', comparison, cur_epoch)
76 | 
77 |         test_loss /= len(self.test_loader.dataset)
78 |         print('====> Test set loss: {:.4f}'.format(test_loss))
79 |         self.summary_writer.add_scalar('testing/loss', test_loss, cur_epoch)
80 |         self.model.train()
81 | 
82 |     def test_on_trainings_set(self):
83 |         print('testing...')
84 |         self.model.eval()
85 |         test_loss = 0
86 |         for i, (data, _) in enumerate(self.train_loader):
87 |             if self.args.cuda:
88 |                 data = data.cuda()
89 |             data = Variable(data, volatile=True)
90 |             recon_batch, mu, logvar = self.model(data)
91 |             test_loss += self.loss(recon_batch, data, mu, logvar).data[0]
92 |             _, indices = recon_batch.max(1)
93 |             indices.data = indices.data.float() / 255  # argmax class index -> pixel intensity in [0, 1]
94 |             if i % 50 == 0:
95 |                 n = min(data.size(0), 8)
96 |                 comparison = torch.cat([data[:n],
97 |                                         indices.view(-1, 3, 32, 32)[:n]])
98 |                 self.summary_writer.add_image('training_set/image', comparison, i)
99 | 
100 |         test_loss /= len(self.train_loader.dataset)  # average over the training set that was just iterated
101 |         print('====> Test on training set loss: {:.4f}'.format(test_loss))
102 |         self.model.train()
103 | 
104 |     def get_optimizer(self):
105 |         return optim.Adam(self.model.parameters(), lr=self.args.learning_rate,
106 |                           weight_decay=self.args.weight_decay)
107 | 
108 |     def adjust_learning_rate(self, epoch):
109 |         """Sets the learning rate to the initial LR decayed by learning_rate_decay every epoch"""
110 |         learning_rate = self.args.learning_rate * (self.args.learning_rate_decay ** epoch)
111 |         for param_group in self.optimizer.param_groups:
112 |             param_group['lr'] = learning_rate
113 |         return learning_rate
114 | 
115 |     def save_checkpoint(self, state, is_best=False, filename='checkpoint.pth.tar'):
116 |         '''
117 |         a function to save checkpoint of the training
118 |         :param state: {'epoch': cur_epoch + 1, 'state_dict': self.model.state_dict(),
119 |                        'optimizer': self.optimizer.state_dict()}
120 |         :param is_best: boolean to save the checkpoint aside if it has the best score so far
121 |         :param filename: the name of the saved file
122 |         '''
123 |         torch.save(state, self.args.checkpoint_dir + filename)
124 |         if is_best:
125 |             shutil.copyfile(self.args.checkpoint_dir + filename,
126 |                             self.args.checkpoint_dir + 'model_best.pth.tar')
127 | 
128 |     def load_checkpoint(self, filename):
129 |         filename = self.args.checkpoint_dir
+ filename 130 | try: 131 | print("Loading checkpoint '{}'".format(filename)) 132 | checkpoint = torch.load(filename) 133 | self.args.start_epoch = checkpoint['epoch'] 134 | self.model.load_state_dict(checkpoint['state_dict']) 135 | self.optimizer.load_state_dict(checkpoint['optimizer']) 136 | print("Checkpoint loaded successfully from '{}' at (epoch {})\n" 137 | .format(self.args.checkpoint_dir, checkpoint['epoch'])) 138 | except: 139 | print("No checkpoint exists from '{}'. Skipping...\n".format(self.args.checkpoint_dir)) 140 | -------------------------------------------------------------------------------- /train/mse_trainer.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from torch import optim 3 | import torch 4 | 5 | from tensorboardX import SummaryWriter 6 | import shutil 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | class Trainer: 12 | def __init__(self, model, loss, train_loader, test_loader, args): 13 | self.model = model 14 | self.args = args 15 | self.args.start_epoch = 0 16 | 17 | self.train_loader = train_loader 18 | self.test_loader = test_loader 19 | 20 | # Loss function and Optimizer 21 | self.loss = loss 22 | self.optimizer = self.get_optimizer() 23 | 24 | # Tensorboard Writer 25 | self.summary_writer = SummaryWriter(log_dir=args.summary_dir) 26 | # Model Loading 27 | if args.resume: 28 | self.load_checkpoint(self.args.resume_from) 29 | 30 | def train(self): 31 | self.model.train() 32 | for epoch in range(self.args.start_epoch, self.args.num_epochs): 33 | loss_list = [] 34 | print("epoch {}...".format(epoch)) 35 | for batch_idx, (data, _) in enumerate(tqdm(self.train_loader)): 36 | if self.args.cuda: 37 | data = data.cuda() 38 | data = Variable(data) 39 | self.optimizer.zero_grad() 40 | recon_batch, mu, logvar = self.model(data) 41 | loss = self.loss(recon_batch, data, mu, logvar) 42 | loss.backward() 43 | self.optimizer.step() 44 | loss_list.append(loss.data[0]) 45 | 46 | print("epoch {}: - loss: {}".format(epoch, np.mean(loss_list))) 47 | new_lr = self.adjust_learning_rate(epoch) 48 | print('learning rate:', new_lr) 49 | 50 | self.summary_writer.add_scalar('training/loss', np.mean(loss_list), epoch) 51 | self.summary_writer.add_scalar('training/learning_rate', new_lr, epoch) 52 | self.save_checkpoint({ 53 | 'epoch': epoch + 1, 54 | 'state_dict': self.model.state_dict(), 55 | 'optimizer': self.optimizer.state_dict(), 56 | }) 57 | if epoch % self.args.test_every == 0: 58 | self.test(epoch) 59 | 60 | def test(self, cur_epoch): 61 | print('testing...') 62 | self.model.eval() 63 | test_loss = 0 64 | for i, (data, _) in enumerate(self.test_loader): 65 | if self.args.cuda: 66 | data = data.cuda() 67 | data = Variable(data, volatile=True) 68 | recon_batch, mu, logvar = self.model(data) 69 | test_loss += self.loss(recon_batch, data, mu, logvar).data[0] 70 | if i == 0: 71 | n = min(data.size(0), 8) 72 | comparison = torch.cat([data[:n], 73 | recon_batch.view(-1, 3, 32, 32)[:n]]) 74 | self.summary_writer.add_image('testing_set/image', comparison, cur_epoch) 75 | 76 | test_loss /= len(self.test_loader.dataset) 77 | print('====> Test set loss: {:.4f}'.format(test_loss)) 78 | self.summary_writer.add_scalar('testing/loss', test_loss, cur_epoch) 79 | self.model.train() 80 | 81 | def test_on_trainings_set(self): 82 | print('testing...') 83 | self.model.eval() 84 | test_loss = 0 85 | for i, (data, _) in enumerate(self.train_loader): 86 | if self.args.cuda: 87 | data = data.cuda() 88 | data = 
Variable(data, volatile=True)
89 |             recon_batch, mu, logvar = self.model(data)
90 |             test_loss += self.loss(recon_batch, data, mu, logvar).data[0]
91 |             if i % 50 == 0:
92 |                 n = min(data.size(0), 8)
93 |                 comparison = torch.cat([data[:n],
94 |                                         recon_batch.view(-1, 3, 32, 32)[:n]])
95 |                 self.summary_writer.add_image('training_set/image', comparison, i)
96 | 
97 |         test_loss /= len(self.train_loader.dataset)  # average over the training set that was just iterated
98 |         print('====> Test on training set loss: {:.4f}'.format(test_loss))
99 |         self.model.train()
100 | 
101 |     def get_optimizer(self):
102 |         return optim.Adam(self.model.parameters(), lr=self.args.learning_rate,
103 |                           weight_decay=self.args.weight_decay)
104 | 
105 |     def adjust_learning_rate(self, epoch):
106 |         """Sets the learning rate to the initial LR decayed by learning_rate_decay every epoch"""
107 |         learning_rate = self.args.learning_rate * (self.args.learning_rate_decay ** epoch)
108 |         for param_group in self.optimizer.param_groups:
109 |             param_group['lr'] = learning_rate
110 |         return learning_rate
111 | 
112 |     def save_checkpoint(self, state, is_best=False, filename='checkpoint.pth.tar'):
113 |         '''
114 |         a function to save checkpoint of the training
115 |         :param state: {'epoch': cur_epoch + 1, 'state_dict': self.model.state_dict(),
116 |                        'optimizer': self.optimizer.state_dict()}
117 |         :param is_best: boolean to save the checkpoint aside if it has the best score so far
118 |         :param filename: the name of the saved file
119 |         '''
120 |         torch.save(state, self.args.checkpoint_dir + filename)
121 |         if is_best:
122 |             shutil.copyfile(self.args.checkpoint_dir + filename,
123 |                             self.args.checkpoint_dir + 'model_best.pth.tar')
124 | 
125 |     def load_checkpoint(self, filename):
126 |         filename = self.args.checkpoint_dir + filename
127 |         try:
128 |             print("Loading checkpoint '{}'".format(filename))
129 |             checkpoint = torch.load(filename)
130 |             self.args.start_epoch = checkpoint['epoch']
131 |             self.model.load_state_dict(checkpoint['state_dict'])
132 |             self.optimizer.load_state_dict(checkpoint['optimizer'])
133 |             print("Checkpoint loaded successfully from '{}' at (epoch {})\n"
134 |                   .format(self.args.checkpoint_dir, checkpoint['epoch']))
135 |         except:
136 |             print("No checkpoint exists from '{}'. 
Skipping...\n".format(self.args.checkpoint_dir)) 137 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | from pprint import pprint 6 | 7 | import numpy as np 8 | from easydict import EasyDict as edict 9 | 10 | 11 | def parse_args(): 12 | """ 13 | Parse the arguments of the program 14 | :return: (config_args) 15 | :rtype: tuple 16 | """ 17 | # Create a parser 18 | parser = argparse.ArgumentParser(description="VAE PyTorch Implementation") 19 | parser.add_argument('--version', action='version', version='%(prog)s 0.0.1') 20 | parser.add_argument('--config', default=None, type=str, help='Configuration file') 21 | 22 | # Parse the arguments 23 | args = parser.parse_args() 24 | 25 | # Parse the configurations from the config json file provided 26 | try: 27 | if args.config is not None: 28 | with open(args.config, 'r') as config_file: 29 | config_args_dict = json.load(config_file) 30 | else: 31 | print("Add a config file using \'--config file_name.json\'", file=sys.stderr) 32 | exit(1) 33 | 34 | except FileNotFoundError: 35 | print("ERROR: Config file not found: {}".format(args.config), file=sys.stderr) 36 | exit(1) 37 | except json.decoder.JSONDecodeError: 38 | print("ERROR: Config file is not a proper JSON file!", file=sys.stderr) 39 | exit(1) 40 | 41 | config_args = edict(config_args_dict) 42 | 43 | pprint(config_args) 44 | print("\n") 45 | 46 | return config_args 47 | 48 | 49 | def create_experiment_dirs(exp_dir): 50 | """ 51 | Create Directories of a regular tensorflow experiment directory 52 | :param exp_dir: 53 | :return summary_dir, checkpoint_dir: 54 | """ 55 | experiment_dir = os.path.realpath( 56 | os.path.join(os.path.dirname(__file__))) + "/../experiments/" + exp_dir + "/" 57 | summary_dir = experiment_dir + 'summaries/' 58 | checkpoint_dir = experiment_dir + 'checkpoints/' 59 | 60 | dirs = [summary_dir, checkpoint_dir] 61 | try: 62 | for dir_ in dirs: 63 | if not os.path.exists(dir_): 64 | os.makedirs(dir_) 65 | print("Experiment directories created!") 66 | # return experiment_dir, summary_dir, checkpoint_dir 67 | return summary_dir, checkpoint_dir 68 | except Exception as err: 69 | print("Creating directories error: {0}".format(err)) 70 | exit(-1) 71 | -------------------------------------------------------------------------------- /utils/weight_initializer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Initializer: 5 | def __init__(self): 6 | pass 7 | 8 | @staticmethod 9 | def initialize(model, initialization, **kwargs): 10 | 11 | def weights_init(m): 12 | if isinstance(m, nn.Conv2d): 13 | initialization(m.weight.data, **kwargs) 14 | try: 15 | initialization(m.bias.data) 16 | except: 17 | pass 18 | 19 | elif isinstance(m, nn.Linear): 20 | initialization(m.weight.data, **kwargs) 21 | try: 22 | initialization(m.bias.data) 23 | except: 24 | pass 25 | 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1.0) 28 | m.bias.data.fill_(0) 29 | 30 | elif isinstance(m, nn.BatchNorm1d): 31 | m.weight.data.fill_(1.0) 32 | m.bias.data.fill_(0) 33 | 34 | model.apply(weights_init) 35 | --------------------------------------------------------------------------------
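
Once training has produced `experiments/<experiment_dir>/checkpoints/checkpoint.pth.tar` (a dict with `epoch`, `state_dict` and `optimizer`, as saved by the trainers), the decoder can be reused on its own. A hedged sketch of restoring the MSE model and decoding random latent codes; the path and the 512-dimensional latent size come from the files above, but the snippet itself is illustrative and assumes it is run from the repo root:

```
import torch

from graph.mse_model import VAE

model = VAE()
checkpoint = torch.load('experiments/VAE_CIFAR10_mse_loss_ld_512/checkpoints/checkpoint.pth.tar',
                        map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # use BatchNorm running statistics

with torch.no_grad():
    z = torch.randn(8, 512)    # the prior over the latent code is standard normal, 512-dimensional
    samples = model.decode(z)  # shape: (8, 3, 32, 32), roughly in [0, 1] after training

print(samples.shape)
```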