├── .gitignore ├── LICENSE ├── README.md ├── arts └── vae.png ├── data.py ├── main.py ├── model.py ├── requirements.txt ├── train.py ├── utils.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | datasets/ 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ha Junsoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAE PyTorch Implementation 2 | 3 | PyTorch implementation of [Auto-Encoding Variational Bayes, arxiv:1312.6114](https://arxiv.org/abs/1312.6114) 4 | 5 | ![vae-graphical-model](./arts/vae.png) 6 | 7 | 8 | ## Installation 9 | 10 | ``` 11 | $ git clone https://github.com/kuc2477/pytorch-vae && cd pytorch-vae 12 | $ pip install -r requirements.txt 13 | ``` 14 | 15 | 16 | ## CLI 17 | 18 | Implementation CLI is provided by `main.py` 19 | 20 | #### Usage 21 | ``` 22 | $ ./main.py --help 23 | $ usage: VAE PyTorch implementation [-h] [--dataset {mnist,cifar10,cifar100}] 24 | [--kernel-num KERNEL_NUM] [--z-size Z_SIZE] 25 | [--epochs EPOCHS] [--batch-size BATCH_SIZE] 26 | [--sample-size SAMPLE_SIZE] [--lr LR] 27 | [--weight-decay WEIGHT_DECAY] 28 | [--loss-log-interval LOSS_LOG_INTERVAL] 29 | [--image-log-interval IMAGE_LOG_INTERVAL] 30 | [--resume] [--checkpoint-dir CHECKPOINT_DIR] 31 | [--sample-dir SAMPLE_DIR] [--no-gpus] 32 | (--test | --train) 33 | 34 | optional arguments: 35 | -h, --help show this help message and exit 36 | --dataset {mnist,cifar10,cifar100} 37 | --kernel-num KERNEL_NUM 38 | --z-size Z_SIZE 39 | --epochs EPOCHS 40 | --batch-size BATCH_SIZE 41 | --sample-size SAMPLE_SIZE 42 | --lr LR 43 | --weight-decay WEIGHT_DECAY 44 | --loss-log-interval LOSS_LOG_INTERVAL 45 | --image-log-interval IMAGE_LOG_INTERVAL 46 | --resume 47 | --checkpoint-dir CHECKPOINT_DIR 48 | --sample-dir SAMPLE_DIR 49 | --no-gpus 50 | --test 51 | --train 52 | ``` 53 | 54 | #### Train 55 | ``` 56 | ./main.py --train 57 | ``` 58 | 59 | #### Test 60 | ``` 61 | ./main.py --test 62 | ``` 63 | 64 | 65 | ## Reference 66 | - [Auto-Encoding Variational Bayes, arxiv:1312.6114](https://arxiv.org/abs/1312.6114) 67 | 68 | 69 | ## Author 70 | Ha Junsoo / [@kuc2477](https://github.com/kuc2477) / MIT License 71 | -------------------------------------------------------------------------------- /arts/vae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-vae/726a27eb668edcd7a6afdd23a2ee001a1f31d54f/arts/vae.png -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | 4 | _MNIST_TRAIN_TRANSFORMS = _MNIST_TEST_TRANSFORMS = [ 5 | transforms.ToTensor(), 6 | transforms.ToPILImage(), 7 | transforms.Pad(2), 8 | transforms.ToTensor(), 9 | ] 10 | 11 | _CIFAR_TRAIN_TRANSFORMS = [ 12 | transforms.RandomCrop(32, padding=4), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | ] 16 | 17 | _CIFAR_TEST_TRANSFORMS = [ 18 | transforms.ToTensor(), 19 | ] 20 | 21 | 22 | TRAIN_DATASETS = { 23 | 'mnist': datasets.MNIST( 24 | './datasets/mnist', train=True, download=True, 25 | transform=transforms.Compose(_MNIST_TRAIN_TRANSFORMS) 26 | ), 27 | 'cifar10': datasets.CIFAR10( 28 | './datasets/cifar10', train=True, download=True, 29 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS) 30 | ), 31 | 'cifar100': datasets.CIFAR100( 32 | './datasets/cifar100', train=True, download=True, 33 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS) 34 | ) 35 | } 36 | 37 | 38 | TEST_DATASETS = { 39 | 'mnist': datasets.MNIST( 40 | './datasets/mnist', train=False, 41 | transform=transforms.Compose(_MNIST_TEST_TRANSFORMS) 42 | ), 43 | 'cifar10': datasets.CIFAR10( 44 | './datasets/cifar10', train=False, 45 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS) 46 | ), 47 | 'cifar100': datasets.CIFAR100( 48 | './datasets/cifar100', train=False, 49 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS) 50 | ) 51 | } 52 | 53 | 54 | DATASET_CONFIGS = { 55 | 'mnist': {'size': 32, 'channels': 1, 'classes': 10}, 56 | 'cifar10': {'size': 32, 'channels': 3, 'classes': 10}, 57 | 'cifar100': {'size': 32, 'channels': 3, 'classes': 100}, 58 | } 59 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import torch 4 | import torchvision 5 | from model import VAE 6 | from data import TRAIN_DATASETS, DATASET_CONFIGS 7 | from train import train_model 8 | 9 | 10 | parser = argparse.ArgumentParser('VAE PyTorch implementation') 11 | parser.add_argument('--dataset', default='mnist', 12 | choices=list(TRAIN_DATASETS.keys())) 13 | 14 | parser.add_argument('--kernel-num', type=int, default=128) 15 | parser.add_argument('--z-size', type=int, default=128) 16 | 17 | parser.add_argument('--epochs', type=int, default=50) 18 | parser.add_argument('--batch-size', type=int, default=64) 19 | parser.add_argument('--sample-size', type=int, default=32) 20 | parser.add_argument('--lr', type=float, default=3e-05) 21 | parser.add_argument('--weight-decay', type=float, default=1e-06) 22 | 23 | parser.add_argument('--loss-log-interval', type=int, default=100) 24 | parser.add_argument('--image-log-interval', type=int, default=500) 25 | parser.add_argument('--resume', action='store_true') 26 | parser.add_argument('--checkpoint-dir', type=str, default='./checkpoints') 27 | parser.add_argument('--sample-dir', type=str, default='./samples') 28 | parser.add_argument('--no-gpus', action='store_false', dest='cuda') 29 | 30 | main_command = parser.add_mutually_exclusive_group(required=True) 31 | main_command.add_argument('--test', action='store_false', dest='train') 32 | main_command.add_argument('--train', action='store_true') 33 | 34 | 35 | if __name__ == '__main__': 36 | args = parser.parse_args() 37 | cuda = args.cuda and torch.cuda.is_available() 38 | dataset_config = DATASET_CONFIGS[args.dataset] 39 | dataset = TRAIN_DATASETS[args.dataset] 40 | 41 | vae = VAE( 42 | label=args.dataset, 43 | image_size=dataset_config['size'], 44 | channel_num=dataset_config['channels'], 45 | kernel_num=args.kernel_num, 46 | z_size=args.z_size, 47 | ) 48 | 49 | # move the model parameters to the gpu if needed. 50 | if cuda: 51 | vae.cuda() 52 | 53 | # run a test or a training process. 54 | if args.train: 55 | train_model( 56 | vae, dataset=dataset, 57 | epochs=args.epochs, 58 | batch_size=args.batch_size, 59 | sample_size=args.sample_size, 60 | lr=args.lr, 61 | weight_decay=args.weight_decay, 62 | checkpoint_dir=args.checkpoint_dir, 63 | loss_log_interval=args.loss_log_interval, 64 | image_log_interval=args.image_log_interval, 65 | resume=args.resume, 66 | cuda=cuda, 67 | ) 68 | else: 69 | images = vae.sample(args.sample_size) 70 | torchvision.utils.save_image(images, args.sample_dir) 71 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | 5 | 6 | class VAE(nn.Module): 7 | def __init__(self, label, image_size, channel_num, kernel_num, z_size): 8 | # configurations 9 | super().__init__() 10 | self.label = label 11 | self.image_size = image_size 12 | self.channel_num = channel_num 13 | self.kernel_num = kernel_num 14 | self.z_size = z_size 15 | 16 | # encoder 17 | self.encoder = nn.Sequential( 18 | self._conv(channel_num, kernel_num // 4), 19 | self._conv(kernel_num // 4, kernel_num // 2), 20 | self._conv(kernel_num // 2, kernel_num, last=True), 21 | ) 22 | 23 | # encoded feature's size and volume 24 | self.feature_size = image_size // 8 25 | self.feature_volume = kernel_num * (self.feature_size ** 2) 26 | 27 | # q 28 | self.q_mean = self._linear(self.feature_volume, z_size, relu=False) 29 | self.q_logvar = self._linear(self.feature_volume, z_size, relu=False) 30 | 31 | # projection 32 | self.project = self._linear(z_size, self.feature_volume, relu=False) 33 | 34 | # decoder 35 | self.decoder = nn.Sequential( 36 | self._deconv(kernel_num, kernel_num // 2), 37 | self._deconv(kernel_num // 2, kernel_num // 4), 38 | self._deconv(kernel_num // 4, channel_num, last=True), 39 | nn.Sigmoid() 40 | ) 41 | 42 | def forward(self, x): 43 | # encode x 44 | encoded = self.encoder(x) 45 | 46 | # sample latent code z from q given x. 47 | mean, logvar = self.q(encoded) 48 | z = self.z(mean, logvar) 49 | z_projected = self.project(z).view( 50 | -1, self.kernel_num, 51 | self.feature_size, 52 | self.feature_size, 53 | ) 54 | 55 | # reconstruct x from z 56 | x_reconstructed = self.decoder(z_projected) 57 | 58 | # return the parameters of distribution of q given x and the 59 | # reconstructed image. 60 | return (mean, logvar), x_reconstructed 61 | 62 | # ============== 63 | # VAE components 64 | # ============== 65 | 66 | def q(self, encoded): 67 | unrolled = encoded.view(-1, self.feature_volume) 68 | return self.q_mean(unrolled), self.q_logvar(unrolled) 69 | 70 | def z(self, mean, logvar): 71 | std = logvar.mul(0.5).exp_() 72 | eps = ( 73 | Variable(torch.randn(std.size())).cuda() if self._is_on_cuda else 74 | Variable(torch.randn(std.size())) 75 | ) 76 | return eps.mul(std).add_(mean) 77 | 78 | def reconstruction_loss(self, x_reconstructed, x): 79 | return nn.BCELoss(size_average=False)(x_reconstructed, x) / x.size(0) 80 | 81 | def kl_divergence_loss(self, mean, logvar): 82 | return ((mean**2 + logvar.exp() - 1 - logvar) / 2).sum() / mean.size(0) 83 | 84 | # ===== 85 | # Utils 86 | # ===== 87 | 88 | @property 89 | def name(self): 90 | return ( 91 | 'VAE' 92 | '-{kernel_num}k' 93 | '-{label}' 94 | '-{channel_num}x{image_size}x{image_size}' 95 | ).format( 96 | label=self.label, 97 | kernel_num=self.kernel_num, 98 | image_size=self.image_size, 99 | channel_num=self.channel_num, 100 | ) 101 | 102 | def sample(self, size): 103 | z = Variable( 104 | torch.randn(size, self.z_size).cuda() if self._is_on_cuda() else 105 | torch.randn(size, self.z_size) 106 | ) 107 | z_projected = self.project(z).view( 108 | -1, self.kernel_num, 109 | self.feature_size, 110 | self.feature_size, 111 | ) 112 | return self.decoder(z_projected).data 113 | 114 | def _is_on_cuda(self): 115 | return next(self.parameters()).is_cuda 116 | 117 | # ====== 118 | # Layers 119 | # ====== 120 | 121 | def _conv(self, channel_size, kernel_num, last=False): 122 | conv = nn.Conv2d( 123 | channel_size, kernel_num, 124 | kernel_size=3, stride=2, padding=1, 125 | ) 126 | return conv if last else nn.Sequential( 127 | conv, 128 | nn.BatchNorm2d(kernel_num), 129 | nn.ReLU(), 130 | ) 131 | 132 | def _deconv(self, channel_num, kernel_num, last=False): 133 | deconv = nn.ConvTranspose2d( 134 | channel_num, kernel_num, 135 | kernel_size=4, stride=2, padding=1, 136 | ) 137 | return deconv if last else nn.Sequential( 138 | deconv, 139 | nn.BatchNorm2d(kernel_num), 140 | nn.ReLU(), 141 | ) 142 | 143 | def _linear(self, in_size, out_size, relu=True): 144 | return nn.Sequential( 145 | nn.Linear(in_size, out_size), 146 | nn.ReLU(), 147 | ) if relu else nn.Linear(in_size, out_size) 148 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl 3 | torchvision==0.1.8 4 | torchtext==0.1.1 5 | visdom 6 | 7 | # data 8 | scipy 9 | scikit-learn 10 | numpy 11 | pillow 12 | 13 | # utils 14 | tqdm 15 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.autograd import Variable 3 | from tqdm import tqdm 4 | import utils 5 | import visual 6 | 7 | 8 | def train_model(model, dataset, epochs=10, 9 | batch_size=32, sample_size=32, 10 | lr=3e-04, weight_decay=1e-5, 11 | loss_log_interval=30, 12 | image_log_interval=300, 13 | checkpoint_dir='./checkpoints', 14 | resume=False, 15 | cuda=False): 16 | # prepare optimizer and model 17 | model.train() 18 | optimizer = optim.Adam( 19 | model.parameters(), lr=lr, 20 | weight_decay=weight_decay, 21 | ) 22 | 23 | if resume: 24 | epoch_start = utils.load_checkpoint(model, checkpoint_dir) 25 | else: 26 | epoch_start = 1 27 | 28 | for epoch in range(epoch_start, epochs+1): 29 | data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda) 30 | data_stream = tqdm(enumerate(data_loader, 1)) 31 | 32 | for batch_index, (x, _) in data_stream: 33 | # where are we? 34 | iteration = (epoch-1)*(len(dataset)//batch_size) + batch_index 35 | 36 | # prepare data on gpu if needed 37 | x = Variable(x).cuda() if cuda else Variable(x) 38 | 39 | # flush gradients and run the model forward 40 | optimizer.zero_grad() 41 | (mean, logvar), x_reconstructed = model(x) 42 | reconstruction_loss = model.reconstruction_loss(x_reconstructed, x) 43 | kl_divergence_loss = model.kl_divergence_loss(mean, logvar) 44 | total_loss = reconstruction_loss + kl_divergence_loss 45 | 46 | # backprop gradients from the loss 47 | total_loss.backward() 48 | optimizer.step() 49 | 50 | # update progress 51 | data_stream.set_description(( 52 | 'epoch: {epoch} | ' 53 | 'iteration: {iteration} | ' 54 | 'progress: [{trained}/{total}] ({progress:.0f}%) | ' 55 | 'loss => ' 56 | 'total: {total_loss:.4f} / ' 57 | 're: {reconstruction_loss:.3f} / ' 58 | 'kl: {kl_divergence_loss:.3f}' 59 | ).format( 60 | epoch=epoch, 61 | iteration=iteration, 62 | trained=batch_index * len(x), 63 | total=len(data_loader.dataset), 64 | progress=(100. * batch_index / len(data_loader)), 65 | total_loss=total_loss.data.item(), 66 | reconstruction_loss=reconstruction_loss.data.item(), 67 | kl_divergence_loss=kl_divergence_loss.data.item(), 68 | )) 69 | 70 | if iteration % loss_log_interval == 0: 71 | losses = [ 72 | reconstruction_loss.data.item(), 73 | kl_divergence_loss.data.item(), 74 | total_loss.data.item(), 75 | ] 76 | names = ['reconstruction', 'kl divergence', 'total'] 77 | visual.visualize_scalars( 78 | losses, names, 'loss', 79 | iteration, env=model.name) 80 | 81 | if iteration % image_log_interval == 0: 82 | images = model.sample(sample_size) 83 | visual.visualize_images( 84 | images, 'generated samples', 85 | env=model.name 86 | ) 87 | 88 | # notify that we've reached to a new checkpoint. 89 | print() 90 | print() 91 | print('#############') 92 | print('# checkpoint!') 93 | print('#############') 94 | print() 95 | 96 | # save the checkpoint. 97 | utils.save_checkpoint(model, checkpoint_dir, epoch) 98 | print() 99 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | def get_data_loader(dataset, batch_size, cuda=False): 8 | return DataLoader( 9 | dataset, batch_size=batch_size, shuffle=True, 10 | **({'num_workers': 1, 'pin_memory': True} if cuda else {}) 11 | ) 12 | 13 | 14 | def save_checkpoint(model, model_dir, epoch): 15 | path = os.path.join(model_dir, model.name) 16 | 17 | # save the checkpoint. 18 | if not os.path.exists(model_dir): 19 | os.makedirs(model_dir) 20 | torch.save({'state': model.state_dict(), 'epoch': epoch}, path) 21 | 22 | # notify that we successfully saved the checkpoint. 23 | print('=> saved the model {name} to {path}'.format( 24 | name=model.name, path=path 25 | )) 26 | 27 | 28 | def load_checkpoint(model, model_dir): 29 | path = os.path.join(model_dir, model.name) 30 | 31 | # load the checkpoint. 32 | checkpoint = torch.load(path) 33 | print('=> loaded checkpoint of {name} from {path}'.format( 34 | name=model.name, path=(path) 35 | )) 36 | 37 | # load parameters and return the checkpoint's epoch and precision. 38 | model.load_state_dict(checkpoint['state']) 39 | epoch = checkpoint['epoch'] 40 | return epoch 41 | 42 | 43 | def xavier_initialize(model): 44 | modules = [ 45 | m for n, m in model.named_modules() if 46 | 'conv' in n or 'linear' in n 47 | ] 48 | 49 | parameters = [ 50 | p for 51 | m in modules for 52 | p in m.parameters() if 53 | p.dim() >= 2 54 | ] 55 | 56 | for p in parameters: 57 | init.xavier_normal(p) 58 | -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.cuda import FloatTensor as CUDATensor 3 | from visdom import Visdom 4 | 5 | _WINDOW_CASH = {} 6 | 7 | 8 | def _vis(env='main'): 9 | return Visdom(env=env) 10 | 11 | 12 | def visualize_image(tensor, name, label=None, env='main', w=250, h=250, 13 | update_window_without_label=False): 14 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 15 | title = name + ('-{}'.format(label) if label is not None else '') 16 | 17 | _WINDOW_CASH[title] = _vis(env).image( 18 | tensor.numpy(), win=_WINDOW_CASH.get(title), 19 | opts=dict(title=title, width=w, height=h) 20 | ) 21 | 22 | # This is useful when you want to maintain the most recent images. 23 | if update_window_without_label: 24 | _WINDOW_CASH[name] = _vis(env).image( 25 | tensor.numpy(), win=_WINDOW_CASH.get(name), 26 | opts=dict(title=name, width=w, height=h) 27 | ) 28 | 29 | 30 | def visualize_images(tensor, name, label=None, env='main', w=250, h=250, 31 | update_window_without_label=False): 32 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 33 | title = name + ('-{}'.format(label) if label is not None else '') 34 | 35 | _WINDOW_CASH[title] = _vis(env).images( 36 | tensor.numpy(), win=_WINDOW_CASH.get(title), 37 | opts=dict(title=title, width=w, height=h) 38 | ) 39 | 40 | # This is useful when you want to maintain the most recent images. 41 | if update_window_without_label: 42 | _WINDOW_CASH[name] = _vis(env).images( 43 | tensor.numpy(), win=_WINDOW_CASH.get(name), 44 | opts=dict(title=name, width=w, height=h) 45 | ) 46 | 47 | 48 | def visualize_kernel(kernel, name, label=None, env='main', w=250, h=250, 49 | update_window_without_label=False, compress_tensor=False): 50 | # Do not visualize kernels that does not exists. 51 | if kernel is None: 52 | return 53 | 54 | assert len(kernel.size()) in (2, 4) 55 | title = name + ('-{}'.format(label) if label is not None else '') 56 | kernel = kernel.cpu() if isinstance(kernel, CUDATensor) else kernel 57 | kernel_norm = kernel if len(kernel.size()) == 2 else ( 58 | (kernel**2).mean(-1).mean(-1) if compress_tensor else 59 | kernel.view( 60 | kernel.size()[0] * kernel.size()[2], 61 | kernel.size()[1] * kernel.size()[3], 62 | ) 63 | ) 64 | kernel_norm = kernel_norm.abs() 65 | 66 | visualized = ( 67 | (kernel_norm - kernel_norm.min()) / 68 | (kernel_norm.max() - kernel_norm.min()) 69 | ).numpy() 70 | 71 | _WINDOW_CASH[title] = _vis(env).image( 72 | visualized, win=_WINDOW_CASH.get(title), 73 | opts=dict(title=title, width=w, height=h) 74 | ) 75 | 76 | # This is useful when you want to maintain the most recent images. 77 | if update_window_without_label: 78 | _WINDOW_CASH[name] = _vis(env).image( 79 | visualized, win=_WINDOW_CASH.get(name), 80 | opts=dict(title=name, width=w, height=h) 81 | ) 82 | 83 | 84 | def visualize_scalar(scalar, name, iteration, env='main'): 85 | visualize_scalars( 86 | [scalar] if isinstance(scalar, float) or len(scalar) == 1 else scalar, 87 | [name], name, iteration, env=env 88 | ) 89 | 90 | 91 | def visualize_scalars(scalars, names, title, iteration, env='main'): 92 | assert len(scalars) == len(names) 93 | # Convert scalar tensors to numpy arrays. 94 | scalars, names = list(scalars), list(names) 95 | scalars = [s.cpu() if isinstance(s, CUDATensor) else s for s in scalars] 96 | scalars = [s.numpy() if hasattr(s, 'numpy') else np.array([s]) for s in 97 | scalars] 98 | multi = len(scalars) > 1 99 | num = len(scalars) 100 | 101 | options = dict( 102 | fillarea=True, 103 | legend=names, 104 | width=400, 105 | height=400, 106 | xlabel='Iterations', 107 | ylabel=title, 108 | title=title, 109 | marginleft=30, 110 | marginright=30, 111 | marginbottom=80, 112 | margintop=30, 113 | ) 114 | 115 | X = ( 116 | np.column_stack(np.array([iteration] * num)) if multi else 117 | np.array([iteration] * num) 118 | ) 119 | Y = np.column_stack(scalars) if multi else scalars[0] 120 | 121 | if title in _WINDOW_CASH: 122 | _vis(env).line(X=X, Y=Y, win=_WINDOW_CASH[title], opts=options) 123 | else: 124 | _WINDOW_CASH[title] = _vis(env).line(X=X, Y=Y, opts=options) 125 | --------------------------------------------------------------------------------