├── .gitignore ├── README.md ├── config ├── __init__.py ├── vqvae_colored_mnist.yaml └── vqvae_mnist.yaml ├── data ├── dataset ├── __init__.py └── mnist_dataset.py ├── model ├── __init__.py ├── decoder.py ├── encoder.py ├── quantizer.py └── vqvae.py ├── requirements.txt ├── run_simple_vqvae.py └── tools ├── __init__.py ├── generate_images.py ├── infer_vqvae.py ├── train_lstm.py └── train_vqvae.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all image files 2 | *.jpg 3 | *.png 4 | *.jpeg 5 | 6 | # Ignore pycharm and system files 7 | .DS_Store 8 | *.idea 9 | __pycache__ 10 | *.zip 11 | 12 | # Ignore dataset files 13 | *.csv 14 | 15 | # Ignore checkpoints 16 | *.pth 17 | 18 | # Ignore pickle files 19 | *.pkl -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | VQVAE Implementation in PyTorch with generation using LSTM 2 | ======== 3 | 4 | This repository implements [VQVAE](https://arxiv.org/abs/1711.00937) for MNIST and a colored version of MNIST, and follows up with a simple LSTM for generating numbers. 5 | 6 | ## VQVAE Explanation and Implementation Video 7 | 8 | VQVAE Video 10 | 11 | 12 | 13 | # Quickstart 14 | * Create a new conda environment with python 3.8 and then run the commands below 15 | * ```git clone https://github.com/explainingai-code/VQVAE-Pytorch.git``` 16 | * ```cd VQVAE-Pytorch``` 17 | * ```pip install -r requirements.txt``` 18 | * For running a simple VQVAE with minimal code to understand the basics: ```python run_simple_vqvae.py``` 19 | * For playing around with the VQVAE and training/running inference with the LSTM, use the commands below, passing the desired configuration file as the config argument 20 | * ```python -m tools.train_vqvae``` for training the VQVAE 21 | * ```python -m tools.infer_vqvae``` for generating reconstructions and encoder outputs for LSTM training 22 | * ```python -m tools.train_lstm``` for training the minimal LSTM 23 | * ```python -m tools.generate_images``` for using the trained LSTM to generate some numbers 24 | 25 | ## Configurations 26 | * ```config/vqvae_mnist.yaml``` - VQVAE for training on black and white MNIST images 27 | * ```config/vqvae_colored_mnist.yaml``` - VQVAE with more embedding vectors for training on colored MNIST images 28 | 29 | ## Data preparation 30 | For setting up the dataset: 31 | Follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation 32 | 33 | Verify the data directory has the following structure: 34 | ``` 35 | VQVAE-Pytorch/data/train/images/{0/1/.../9} 36 | *.png 37 | VQVAE-Pytorch/data/test/images/{0/1/.../9} 38 | *.png 39 | ``` 40 | 41 | ## Output 42 | Outputs will be saved according to the configuration present in the yaml files. 43 | 44 | For every run, a folder named after the ```task_name``` key in the config will be created, and ```output_train_dir``` will be created inside it. 
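For the default ```config/vqvae_mnist.yaml``` the resulting layout should look roughly as follows (directory and file names are taken from the config keys and the tools scripts; shown only for orientation):
```
vqvae_latent_2_codebook_5/
├── best_vqvae_latent_2_codebook_5.pth
├── best_mnist_lstm.pth
└── output/
    ├── reconstruction.png
    ├── mnist_encodings.pkl
    └── generation_results.png
```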
45 | 46 | During training of the VQVAE the following outputs will be saved 47 | * Best model checkpoints (VQVAE and LSTM) in the ```task_name``` directory 48 | 49 | During inference the following outputs will be saved 50 | * Reconstructions for a sample of the test set in ```task_name/output_train_dir/reconstruction.png``` 51 | * Encoder outputs on the train set for LSTM training in ```task_name/output_train_dir/mnist_encodings.pkl``` 52 | * LSTM generation output in ```task_name/output_train_dir/generation_results.png``` 53 | 54 | 55 | ## Sample Output for VQVAE 56 | 57 | Running `run_simple_vqvae` should be very quick (as it is a very simple model) and should give you the reconstructions below (input on a black background and reconstruction on a white background) 58 | 59 | 60 | 61 | Running the default config VQVAE for MNIST should give you the reconstructions below for both versions 62 | 63 | 64 | 65 | 66 | Sample Generation Output after just 10 epochs. 67 | Training the VQVAE and LSTM longer and with more parameters (codebook size, codebook dimension, channels, LSTM hidden dimension, etc.) will give better results 68 | 69 | 70 | 71 | 72 | ## Citations 73 | ``` 74 | @misc{oord2018neural, 75 | title={Neural Discrete Representation Learning}, 76 | author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu}, 77 | year={2018}, 78 | eprint={1711.00937}, 79 | archivePrefix={arXiv}, 80 | primaryClass={cs.LG} 81 | } 82 | ``` 83 | 84 | 85 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/config/__init__.py -------------------------------------------------------------------------------- /config/vqvae_colored_mnist.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | in_channels: 3 3 | convbn_blocks: 4 4 | conv_kernel_size: [3,3,3,2] 5 | conv_kernel_strides: [2, 2, 1, 1] 6 | convbn_channels: [3, 16, 32, 8, 8] 7 | conv_activation_fn: 'leaky' 8 | transpose_bn_blocks: 4 9 | transposebn_channels: [8, 8, 32, 16, 3] 10 | transpose_kernel_size: [3,4,4,4] 11 | transpose_kernel_strides: [1,2,1,1] 12 | transpose_activation_fn: 'leaky' 13 | latent_dim: 8 14 | codebook_size : 20 15 | 16 | train_params: 17 | task_name: 'vqvae_latent_8_colored_codebook_20' 18 | batch_size: 64 19 | epochs: 20 20 | lr: 0.005 21 | crit: 'l2' 22 | reconstruction_loss_weight : 5 23 | codebook_loss_weight : 1 24 | commitment_loss_weight : 0.2 25 | ckpt_name: 'best_vqvae_latent_8_colored_codebook_20.pth' 26 | seed: 111 27 | save_training_image: True 28 | train_path: 'data/train/images' 29 | test_path: 'data/test/images' 30 | output_train_dir: 'output' 31 | -------------------------------------------------------------------------------- /config/vqvae_mnist.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | in_channels: 1 3 | convbn_blocks: 4 4 | conv_kernel_size: [3,3,3,2] 5 | conv_kernel_strides: [2, 2, 1, 1] 6 | convbn_channels: [1, 16, 32, 8, 4] 7 | conv_activation_fn: 'leaky' 8 | transpose_bn_blocks: 4 9 | transposebn_channels: [4, 8, 32, 16, 1] 10 | transpose_kernel_size: [3,4,4,4] 11 | transpose_kernel_strides: [1,2,1,1] 12 | transpose_activation_fn: 'leaky' 13 | latent_dim: 2 14 | codebook_size : 5 15 | 16 | train_params: 17 | task_name: 'vqvae_latent_2_codebook_5' 18 | batch_size: 64 19 | epochs: 10 20 | lr: 0.005 21 
| crit: 'l2' 22 | reconstruction_loss_weight : 1 23 | codebook_loss_weight : 1 24 | commitment_loss_weight : 0.2 25 | ckpt_name: 'best_vqvae_latent_2_codebook_5.pth' 26 | seed: 111 27 | save_training_image: True 28 | train_path: 'data/train/images' 29 | test_path: 'data/test/images' 30 | output_train_dir: 'output' 31 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | /Users/tusharkumar/PycharmProjects/explainingai-repos/Pytorch-VAE/data -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import torch 5 | import random 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torch.utils.data.dataset import Dataset 9 | from torch.utils.data.dataloader import DataLoader 10 | 11 | r""" 12 | Simple Dataloader for mnist. 13 | """ 14 | 15 | class MnistDataset(Dataset): 16 | def __init__(self, split, im_path, im_ext='png', im_channels=1): 17 | self.split = split 18 | self.im_ext = im_ext 19 | self.im_channels = im_channels 20 | self.images, self.labels = self.load_images(im_path) 21 | 22 | def load_images(self, im_path): 23 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 24 | ims = [] 25 | labels = [] 26 | for d_name in tqdm(os.listdir(im_path)): 27 | for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))): 28 | ims.append(fname) 29 | labels.append(int(d_name)) 30 | print('Found {} images for split {}'.format(len(ims), self.split)) 31 | return ims, labels 32 | 33 | def __len__(self): 34 | return len(self.images) 35 | 36 | def __getitem__(self, index): 37 | assert self.im_channels == 1 or self.im_channels == 3, "Input iamge channels can only be 1 or 3" 38 | if self.im_channels == 1: 39 | im = cv2.imread(self.images[index], 0) 40 | else: 41 | # Generate a random color digit 42 | im_1 = cv2.imread(self.images[index], 0)[None, :]*np.clip(random.random(), 0.2, 1.0) 43 | im_2 = cv2.imread(self.images[index], 0)[None, :]*np.clip(random.random(), 0.2, 1.0) 44 | im_3 = cv2.imread(self.images[index], 0)[None, :]*np.clip(random.random(), 0.2, 1.0) 45 | im = np.concatenate([im_1, im_2, im_3], axis=0) 46 | 47 | label = self.labels[index] 48 | # Convert to 0 to 255 into -1 to 1 49 | im = 2 * (im / 255) - 1 50 | im_tensor = torch.from_numpy(im)[None, :] if self.im_channels == 1 else torch.from_numpy(im) 51 | return im_tensor, torch.as_tensor(label) 52 | 53 | 54 | if __name__ == '__main__': 55 | mnist = MnistDataset('test', 'data/test/images', im_channels=3) 56 | mnist_loader = DataLoader(mnist, batch_size=16, shuffle=True, num_workers=0) 57 | for im, label in mnist_loader: 58 | print('Image dimension', im.shape) 59 | print('Label dimension: {}'.format(label.shape)) 60 | break 61 | 62 | 63 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/model/__init__.py -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Decoder(nn.Module): 6 | def __init__(self, 7 | config 8 | ): 9 | super(Decoder, self).__init__() 10 | activation_map = { 11 | 'relu': nn.ReLU(), 12 | 'leaky': nn.LeakyReLU(), 13 | 'tanh': nn.Tanh(), 14 | 'gelu': nn.GELU(), 15 | 'silu': nn.SiLU() 16 | } 17 | 18 | self.config = config 19 | ##### Validate the configuration for the model is correctly setup ####### 20 | assert config['transpose_activation_fn'] is None or config['transpose_activation_fn'] in activation_map 21 | self.latent_dim = config['latent_dim'] 22 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | 25 | 26 | self.decoder_layers = nn.ModuleList([ 27 | nn.Sequential( 28 | nn.ConvTranspose2d(config['transposebn_channels'][i], config['transposebn_channels'][i + 1], 29 | kernel_size=config['transpose_kernel_size'][i], 30 | stride=config['transpose_kernel_strides'][i], 31 | padding=0), 32 | nn.BatchNorm2d(config['transposebn_channels'][i + 1]), 33 | activation_map[config['transpose_activation_fn']] 34 | ) 35 | for i in range(config['transpose_bn_blocks']-1) 36 | ]) 37 | 38 | dec_last_idx = config['transpose_bn_blocks'] 39 | self.decoder_layers.append( 40 | nn.Sequential( 41 | nn.ConvTranspose2d(config['transposebn_channels'][dec_last_idx - 1], config['transposebn_channels'][dec_last_idx], 42 | kernel_size=config['transpose_kernel_size'][dec_last_idx - 1], 43 | stride=config['transpose_kernel_strides'][dec_last_idx - 1], 44 | padding=0), 45 | nn.Tanh() 46 | ) 47 | ) 48 | 49 | def forward(self, x): 50 | out = x 51 | for idx, layer in enumerate(self.decoder_layers): 52 | out = layer(out) 53 | return out 54 | 55 | 56 | def get_decoder(config): 57 | decoder = Decoder( 58 | config=config['model_params'] 59 | ) 60 | return decoder 61 | 62 | 63 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, 7 | config 8 | ): 9 | super(Encoder, self).__init__() 10 | activation_map = { 11 | 'relu': nn.ReLU(), 12 | 'leaky': nn.LeakyReLU(), 13 | 'tanh': nn.Tanh(), 14 | 'gelu': nn.GELU(), 15 | 'silu': nn.SiLU() 16 | } 17 | 18 | self.config = config 19 | 20 | ##### Validate the configuration for the model is correctly setup ####### 21 | assert config['conv_activation_fn'] is None or config['conv_activation_fn'] in activation_map 22 | self.latent_dim = config['latent_dim'] 23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | # Encoder is just Conv bn activation blocks 26 | self.encoder_layers = nn.ModuleList([ 27 | nn.Sequential( 28 | nn.Conv2d(config['convbn_channels'][i], config['convbn_channels'][i + 1], 29 | kernel_size=config['conv_kernel_size'][i], stride=config['conv_kernel_strides'][i],padding=1), 30 | nn.BatchNorm2d(config['convbn_channels'][i + 1]), 31 | activation_map[config['conv_activation_fn']], 32 | ) 33 | for i in range(config['convbn_blocks']-1) 34 | ]) 35 | 36 | enc_last_idx = config['convbn_blocks'] 37 | self.encoder_layers.append( 38 | nn.Sequential( 39 | 
nn.Conv2d(config['convbn_channels'][enc_last_idx - 1], config['convbn_channels'][enc_last_idx], 40 | kernel_size=config['conv_kernel_size'][enc_last_idx-1], 41 | stride=config['conv_kernel_strides'][enc_last_idx-1], padding=1), 42 | ) 43 | ) 44 | 45 | def forward(self, x): 46 | out = x 47 | for layer in self.encoder_layers: 48 | out = layer(out) 49 | return out 50 | 51 | 52 | def get_encoder(config): 53 | encoder = Encoder( 54 | config=config['model_params'] 55 | ) 56 | return encoder 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /model/quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import einsum 4 | 5 | 6 | class Quantizer(nn.Module): 7 | def __init__(self, 8 | config 9 | ): 10 | super(Quantizer, self).__init__() 11 | self.config = config 12 | self.embedding = nn.Embedding(config['codebook_size'], config['latent_dim']) 13 | 14 | def forward(self, x): 15 | B, C, H, W = x.shape 16 | x = x.permute(0, 2, 3, 1) 17 | x = x.reshape(x.size(0), -1, x.size(-1)) 18 | 19 | dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1))) 20 | min_encoding_indices = torch.argmin(dist, dim=-1) 21 | 22 | # 23 | quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) 24 | x = x.reshape((-1, x.size(-1))) 25 | commmitment_loss = torch.mean((quant_out.detach() - x) ** 2) 26 | codebook_loss = torch.mean((quant_out - x.detach()) ** 2) 27 | quantize_losses = { 28 | 'codebook_loss' : codebook_loss, 29 | 'commitment_loss' : commmitment_loss 30 | } 31 | quant_out = x + (quant_out - x).detach() 32 | quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) 33 | min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) 34 | return quant_out, quantize_losses, min_encoding_indices 35 | 36 | def quantize_indices(self, indices): 37 | return einsum(indices, self.embedding.weight, 'b n h w, n d -> b d h w') 38 | 39 | 40 | def get_quantizer(config): 41 | quantizer = Quantizer( 42 | config=config['model_params'] 43 | ) 44 | return quantizer 45 | 46 | -------------------------------------------------------------------------------- /model/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.encoder import get_encoder 4 | from model.decoder import get_decoder 5 | from model.quantizer import get_quantizer 6 | 7 | 8 | class VQVAE(nn.Module): 9 | def __init__(self, 10 | config 11 | ): 12 | super(VQVAE, self).__init__() 13 | self.encoder = get_encoder(config) 14 | self.pre_quant_conv = nn.Conv2d(config['model_params']['convbn_channels'][-1], 15 | config['model_params']['latent_dim'], 16 | kernel_size=1) 17 | self.quantizer = get_quantizer(config) 18 | self.post_quant_conv = nn.Conv2d(config['model_params']['latent_dim'], 19 | config['model_params']['transposebn_channels'][0], 20 | kernel_size=1) 21 | self.decoder = get_decoder(config) 22 | 23 | def forward(self, x): 24 | enc = self.encoder(x) 25 | quant_input = self.pre_quant_conv(enc) 26 | quant_output, quant_loss, quant_idxs = self.quantizer(quant_input) 27 | dec_input = self.post_quant_conv(quant_output) 28 | out = self.decoder(dec_input) 29 | return { 30 | 'generated_image' : out, 31 | 'quantized_output' : quant_output, 32 | 'quantized_losses' : quant_loss, 33 | 'quantized_indices' : quant_idxs 34 | } 35 | 36 | def 
decode_from_codebook_indices(self, indices): 37 | quantized_output = self.quantizer.quantize_indices(indices) 38 | dec_input = self.post_quant_conv(quantized_output) 39 | return self.decoder(dec_input) 40 | 41 | 42 | 43 | def get_model(config): 44 | print(config) 45 | model = VQVAE( 46 | config=config 47 | ) 48 | return model 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | matplotlib==3.7.2 3 | numpy==1.23.5 4 | opencv_python==4.8.0.74 5 | PyYAML==6.0 6 | torch==1.11.0 7 | torchvision==0.12.0 8 | tqdm==4.65.0 9 | -------------------------------------------------------------------------------- /run_simple_vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import torchvision 4 | import torch.nn as nn 5 | import numpy as np 6 | from tqdm import tqdm 7 | from einops import rearrange 8 | from torch.optim import Adam 9 | from dataset.mnist_dataset import MnistDataset 10 | from torch.utils.data import DataLoader 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | class VQVAE(nn.Module): 16 | def __init__(self): 17 | super(VQVAE, self).__init__() 18 | self.encoder = nn.Sequential( 19 | nn.Conv2d(1, 16, 4, stride=2, padding=1), 20 | nn.BatchNorm2d(16), 21 | nn.ReLU(), 22 | nn.Conv2d(16, 4, 4, stride=2, padding=1), 23 | nn.BatchNorm2d(4), 24 | nn.ReLU(), 25 | ) 26 | 27 | self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) 28 | self.embedding = nn.Embedding(num_embeddings=3, embedding_dim=2) 29 | self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) 30 | 31 | # Commitment Loss Beta 32 | self.beta = 0.2 33 | 34 | self.decoder = nn.Sequential( 35 | nn.ConvTranspose2d(4, 16, 4, stride=2, padding=1), 36 | nn.BatchNorm2d(16), 37 | nn.ReLU(), 38 | nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1), 39 | nn.Tanh(), 40 | ) 41 | 42 | 43 | def forward(self, x): 44 | # B, C, H, W 45 | encoded_output = self.encoder(x) 46 | quant_input = self.pre_quant_conv(encoded_output) 47 | 48 | ## Quantization 49 | B, C, H, W = quant_input.shape 50 | quant_input = quant_input.permute(0, 2, 3, 1) 51 | quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) 52 | 53 | # Compute pairwise distances 54 | dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) 55 | 56 | # Find index of nearest embedding 57 | min_encoding_indices = torch.argmin(dist, dim=-1) 58 | 59 | # Select the embedding weights 60 | quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) 61 | quant_input = quant_input.reshape((-1, quant_input.size(-1))) 62 | 63 | # Compute losses 64 | commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) 65 | codebook_loss = torch.mean((quant_out - quant_input.detach())**2) 66 | quantize_losses = codebook_loss + self.beta*commitment_loss 67 | 68 | # Ensure straight through gradient 69 | quant_out = quant_input + (quant_out - quant_input).detach() 70 | 71 | # Reshaping back to original input shape 72 | quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) 73 | min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) 74 | 75 | 76 | ## Decoder part 77 | decoder_input = self.post_quant_conv(quant_out) 78 | output = self.decoder(decoder_input) 79 | return output, quantize_losses 
80 | 81 | def train_vqvae(): 82 | mnist = MnistDataset('train', im_path='data/train/images') 83 | mnist_test = MnistDataset('test', im_path='data/test/images') 84 | mnist_loader = DataLoader(mnist, batch_size=64, shuffle=True, num_workers=4) 85 | 86 | model = VQVAE().to(device) 87 | 88 | num_epochs = 20 89 | optimizer = Adam(model.parameters(), lr=1E-3) 90 | criterion = torch.nn.MSELoss() 91 | 92 | for epoch_idx in range(num_epochs): 93 | for im, label in tqdm(mnist_loader): 94 | im = im.float().to(device) 95 | optimizer.zero_grad() 96 | out, quantize_loss = model(im) 97 | 98 | recon_loss = criterion(out, im) 99 | loss = recon_loss + quantize_loss 100 | loss.backward() 101 | optimizer.step() 102 | print('Finished epoch {}'.format(epoch_idx+1)) 103 | print('Done Training...') 104 | 105 | # Reconstruction part 106 | 107 | idxs = torch.randint(0, len(mnist_test), (100, )) 108 | ims = torch.cat([mnist_test[idx][0][None, :] for idx in idxs]).float() 109 | ims = ims.to(device) 110 | model.eval() 111 | 112 | 113 | generated_im, _ = model(ims) 114 | ims = (ims+1)/2 115 | generated_im = 1 - (generated_im+1)/2 116 | out = torch.hstack([ims, generated_im]) 117 | output = rearrange(out, 'b c h w -> b () h (c w)') 118 | grid = torchvision.utils.make_grid(output.detach().cpu(), nrow=10) 119 | img = torchvision.transforms.ToPILImage()(grid) 120 | img.save('reconstruction.png') 121 | 122 | print('Done Reconstruction ...') 123 | 124 | if __name__ == '__main__': 125 | train_vqvae() 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/tools/__init__.py -------------------------------------------------------------------------------- /tools/generate_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | import torch 5 | import pickle 6 | from tqdm import tqdm 7 | import torchvision 8 | from model.vqvae import get_model 9 | from torchvision.utils import make_grid 10 | from tools.train_lstm import MnistLSTM 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | 14 | 15 | def generate(args): 16 | r""" 17 | Method for generating images after training vqvae and lstm 18 | 1. Create config 19 | 2. Create and load vqvae model 20 | 3. Create and load LSTM model 21 | 4. Generate 100 encoder outputs from trained LSTM 22 | 5. Pass them to the trained vqvae decoder 23 | 6. 
Save the generated image 24 | :param args: 25 | :return: 26 | """ 27 | 28 | ########## Read the Config ############## 29 | with open(args.config_path, 'r') as file: 30 | try: 31 | config = yaml.safe_load(file) 32 | except yaml.YAMLError as exc: 33 | print(exc) 34 | print(config) 35 | ######################################### 36 | 37 | ########## Load VQVAE Model ############## 38 | vqvae_model = get_model(config).to(device) 39 | vqvae_model.to(device) 40 | assert os.path.exists(os.path.join(config['train_params']['task_name'], 41 | config['train_params']['ckpt_name'])), "Train the vqvae model first" 42 | vqvae_model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'], 43 | config['train_params']['ckpt_name']), map_location=device)) 44 | 45 | vqvae_model.eval() 46 | ######################################### 47 | 48 | ########## Load LSTM ############## 49 | default_lstm_config = { 50 | 'input_size': 2, 51 | 'hidden_size': 128, 52 | 'codebook_size': config['model_params']['codebook_size'] 53 | } 54 | 55 | model = MnistLSTM(input_size=default_lstm_config['input_size'], 56 | hidden_size=default_lstm_config['hidden_size'], 57 | codebook_size=default_lstm_config['codebook_size']).to(device) 58 | model.to(device) 59 | assert os.path.exists(os.path.join(config['train_params']['task_name'], 60 | 'best_mnist_lstm.pth')), "Train the lstm first" 61 | model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'], 62 | 'best_mnist_lstm.pth'), map_location=device)) 63 | model.eval() 64 | ######################################### 65 | 66 | ################ Generate Samples ############# 67 | generated_quantized_indices = [] 68 | mnist_encodings = pickle.load(open(os.path.join(config['train_params']['task_name'], 69 | config['train_params']['output_train_dir'], 70 | 'mnist_encodings.pkl'), 'rb')) 71 | mnist_encodings_length = mnist_encodings.reshape(mnist_encodings.size(0), -1).shape[-1] 72 | # Assume fixed contex size 73 | context_size = 32 74 | num_samples = 100 75 | print('Generating Samples') 76 | for _ in tqdm(range(num_samples)): 77 | # Initialize with start token 78 | ctx = torch.ones((1)).to(device) * (config['model_params']['codebook_size']) 79 | 80 | for i in range(mnist_encodings_length): 81 | padded_ctx = ctx 82 | if len(ctx) < context_size: 83 | # Pad context with pad token 84 | padded_ctx = torch.nn.functional.pad(padded_ctx, (0, context_size - len(ctx)), "constant", 85 | config['model_params']['codebook_size']+1) 86 | 87 | out = model(padded_ctx[None, :].long().to(device)) 88 | probs = torch.nn.functional.softmax(out, dim=-1) 89 | pred = torch.multinomial(probs[0], num_samples=1) 90 | # Update the context with the new prediction 91 | ctx = torch.cat([ctx, pred]) 92 | generated_quantized_indices.append(ctx[1:][None, :]) 93 | 94 | ######## Decode the Generated Indices ########## 95 | generated_quantized_indices = torch.cat(generated_quantized_indices, dim=0) 96 | h = int(generated_quantized_indices[0].size(-1)**0.5) 97 | quantized_indices = generated_quantized_indices.reshape((generated_quantized_indices.size(0), h, h)).long() 98 | quantized_indices = torch.nn.functional.one_hot(quantized_indices, config['model_params']['codebook_size']) 99 | quantized_indices = quantized_indices.permute((0, 3, 1, 2)) 100 | output = vqvae_model.decode_from_codebook_indices(quantized_indices.float()) 101 | 102 | # Transform from -1, 1 range to 0,1 103 | output = (output + 1) / 2 104 | 105 | if config['model_params']['in_channels'] == 3: 106 | # Just because we took 
input as cv2.imread which is BGR so make it RGB 107 | output = output[:, [2, 1, 0], :, :] 108 | grid = make_grid(output.detach().cpu(), nrow=10) 109 | 110 | img = torchvision.transforms.ToPILImage()(grid) 111 | img.save(os.path.join(config['train_params']['task_name'], 112 | config['train_params']['output_train_dir'], 113 | 'generation_results.png')) 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser(description='Arguments for LSTM generation') 117 | parser.add_argument('--config', dest='config_path', 118 | default='config/vqvae_colored_mnist.yaml', type=str) 119 | args = parser.parse_args() 120 | generate(args) 121 | -------------------------------------------------------------------------------- /tools/infer_vqvae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import os 5 | from tqdm import tqdm 6 | import torchvision 7 | from model.vqvae import get_model 8 | from torch.utils.data.dataloader import DataLoader 9 | from dataset.mnist_dataset import MnistDataset 10 | from torchvision.utils import make_grid 11 | from einops import rearrange 12 | import pickle 13 | 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | def reconstruct(config, model, dataset, num_images=100): 18 | r""" 19 | Randomly sample points from the dataset and visualize image and its reconstruction 20 | :param config: Config file used to create the model 21 | :param model: Trained model 22 | :param dataset: Mnist dataset(not the data loader) 23 | :param num_images: NUmber of images to visualize 24 | :return: 25 | """ 26 | print('Generating reconstructions') 27 | if not os.path.exists(config['train_params']['task_name']): 28 | os.mkdir(config['train_params']['task_name']) 29 | if not os.path.exists( 30 | os.path.join(config['train_params']['task_name'], config['train_params']['output_train_dir'])): 31 | os.mkdir(os.path.join(config['train_params']['task_name'], config['train_params']['output_train_dir'])) 32 | 33 | idxs = torch.randint(0, len(dataset) - 1, (num_images,)) 34 | ims = torch.cat([dataset[idx][0][None, :] for idx in idxs]).float() 35 | ims = ims.to(device) 36 | model_output = model(ims) 37 | output = model_output['generated_image'] 38 | 39 | # Dataset generates -1 to 1 we convert it to 0-1 40 | ims = (ims + 1) / 2 41 | 42 | # For reconstruction, we specifically flip it(white digit on black background -> black digit on white background) 43 | # for easier visualization but only if its not colored: 44 | generated_im = (output + 1) / 2 45 | if config['model_params']['in_channels'] == 1: 46 | generated_im = 1 - generated_im 47 | out = torch.hstack([ims, generated_im]) 48 | output = rearrange(out, 'b (c d) h w -> b (d) h (c w)', c=2, d=config['model_params']['in_channels']) 49 | # flip r and b channels as everything was trained on bgr(cv2) 50 | # although doesnt matter since both input and output would be flipped 51 | if config['model_params']['in_channels'] == 3: 52 | output = output[:, [2, 1, 0], :, :] 53 | grid = make_grid(output.detach().cpu(), nrow=10) 54 | 55 | img = torchvision.transforms.ToPILImage()(grid) 56 | img.save(os.path.join(config['train_params']['task_name'], 57 | config['train_params']['output_train_dir'], 58 | 'reconstruction.png')) 59 | 60 | 61 | def save_encodings(config, model, mnist_loader): 62 | r""" 63 | Method to save the encoder outputs for training LSTM 64 | :param config: 65 | :param model: 66 | :param mnist_loader: 67 | :return: 
68 | """ 69 | save_encodings = None 70 | print('Saving Encodings for lstm') 71 | for im, _ in tqdm(mnist_loader): 72 | im = im.float().to(device) 73 | model_output = model(im) 74 | quant_indices = model_output['quantized_indices'] 75 | save_encodings = quant_indices if save_encodings is None else torch.cat([save_encodings, quant_indices], dim=0) 76 | pickle.dump(save_encodings, open(os.path.join(config['train_params']['task_name'], 77 | config['train_params']['output_train_dir'], 78 | 'mnist_encodings.pkl'), 'wb')) 79 | print('Done saving encoder outputs for lstm training') 80 | 81 | 82 | def inference(args): 83 | with open(args.config_path, 'r') as file: 84 | try: 85 | config = yaml.safe_load(file) 86 | except yaml.YAMLError as exc: 87 | print(exc) 88 | print(config) 89 | 90 | model = get_model(config).to(device) 91 | model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'], 92 | config['train_params']['ckpt_name']), map_location='cpu')) 93 | model.to(device) 94 | model.eval() 95 | 96 | ######### For generating encoder output for training lstm ############# 97 | mnist = MnistDataset('train', config['train_params']['train_path'], 98 | im_channels=config['model_params']['in_channels']) 99 | 100 | ######### For visualizing reconstructions ############# 101 | mnist_test = MnistDataset('test', config['train_params']['test_path'], 102 | im_channels=config['model_params']['in_channels']) 103 | mnist_train_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=False, num_workers=0) 104 | with torch.no_grad(): 105 | # Generate Reconstructions 106 | reconstruct(config, model, mnist_test) 107 | # Save Encoder Outputs for training lstm 108 | save_encodings(config, model, mnist_train_loader) 109 | 110 | 111 | if __name__ == '__main__': 112 | parser = argparse.ArgumentParser(description='Arguments for vqvae inference') 113 | parser.add_argument('--config', dest='config_path', 114 | default='config/vqvae_colored_mnist.yaml', type=str) 115 | args = parser.parse_args() 116 | inference(args) -------------------------------------------------------------------------------- /tools/train_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import cv2 5 | import glob 6 | import torch 7 | import yaml 8 | import argparse 9 | import random 10 | import numpy as np 11 | from torch.optim import Adam 12 | import pickle 13 | from tqdm import tqdm 14 | from torch.utils.data.dataset import Dataset 15 | from torch.utils.data.dataloader import DataLoader 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | class MnistLSTM(nn.Module): 20 | r""" 21 | Very simple 2-layer LSTM with an fc layer on the last step's hidden state 22 | """ 23 | def __init__(self, input_size, hidden_size, codebook_size): 24 | super(MnistLSTM, self).__init__() 25 | self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True) 26 | self.fc = nn.Sequential(nn.Linear(hidden_size, hidden_size // 4), 27 | nn.ReLU(), 28 | nn.Linear(hidden_size // 4, codebook_size)) 29 | # Add pad and start tokens to the embedding size 30 | self.word_embedding = nn.Embedding(codebook_size+2, input_size) 31 | 32 | def forward(self, x): 33 | x = self.word_embedding(x) 34 | output, _ = self.rnn(x) 35 | output = output[:, -1, :] 36 | return self.fc(output) 37 | 38 | 39 | class MnistSeqDataset(Dataset): 40 | r""" 41 | Dataset for training of LSTM. 
Assumes the encodings are already generated 42 | by running vqvae inference 43 | """ 44 | def __init__(self, config): 45 | self.codebook_size = config['model_params']['codebook_size'] 46 | 47 | # Codebook tokens will be 0 to codebook_size-1 48 | self.start_token = self.codebook_size 49 | self.pad_token = self.codebook_size+1 50 | # Fix context size 51 | self.context_size = 32 52 | self.sents = self.load_sents(config) 53 | 54 | def load_sents(self, config): 55 | assert os.path.exists(os.path.join(config['train_params']['task_name'], 56 | config['train_params']['output_train_dir'], 57 | 'mnist_encodings.pkl')), ("No encodings generated for lstm." 58 | "Run save_encodings method in inference script") 59 | mnist_encodings = pickle.load(open(os.path.join(config['train_params']['task_name'], 60 | config['train_params']['output_train_dir'], 61 | 'mnist_encodings.pkl'), 'rb')) 62 | mnist_encodings = mnist_encodings.reshape(mnist_encodings.size(0), -1) 63 | num_encodings = mnist_encodings.size(0) 64 | padded_sents = [] 65 | 66 | for encoding_idx in tqdm(range(num_encodings)): 67 | # Use only 10% encodings. 68 | # Uncomment this for getting some kind of output quickly validate working 69 | if random.random() > 0.1: 70 | continue 71 | enc = mnist_encodings[encoding_idx] 72 | encoding_length = enc.shape[-1] 73 | 74 | # Make sure all encodings start with start token 75 | enc = torch.cat([torch.ones((1)).to(device) * self.start_token, enc.to(device)]) 76 | 77 | # Create batches of context sized inputs(if possible) and target 78 | sents = [(enc[:i], enc[i]) if i < self.context_size else (enc[i - self.context_size:i], enc[i]) 79 | for i in range(1, encoding_length+1)] 80 | 81 | for context, target in sents: 82 | # Pad token if context not enough 83 | if len(context) < self.context_size: 84 | context = torch.nn.functional.pad(context, (0, self.context_size-len(context)), "constant", self.pad_token) 85 | padded_sents.append((context, target)) 86 | return padded_sents 87 | 88 | def __len__(self): 89 | return len(self.sents) 90 | 91 | def __getitem__(self, index): 92 | context, target = self.sents[index] 93 | return context, target 94 | 95 | def train_lstm(args): 96 | ############ Read the config ############# 97 | with open(args.config_path, 'r') as file: 98 | try: 99 | config = yaml.safe_load(file) 100 | except yaml.YAMLError as exc: 101 | print(exc) 102 | print(config) 103 | ######################################### 104 | 105 | ############## Create dataset ########### 106 | mnist = MnistSeqDataset(config) 107 | mnist_seq_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=0) 108 | ######################################### 109 | 110 | ############## Create LSTM ########### 111 | default_lstm_config = { 112 | 'input_size' : 2, 113 | 'hidden_size' : 128, 114 | 'codebook_size' : config['model_params']['codebook_size'] 115 | } 116 | model = MnistLSTM(input_size=default_lstm_config['input_size'], 117 | hidden_size=default_lstm_config['hidden_size'], 118 | codebook_size=default_lstm_config['codebook_size']).to(device) 119 | model.to(device) 120 | model.train() 121 | 122 | ############## Training Params ########### 123 | num_epochs = 10 124 | optimizer = Adam(model.parameters(), lr=1E-3) 125 | criterion = torch.nn.CrossEntropyLoss() 126 | 127 | for epoch in range(num_epochs): 128 | losses = [] 129 | for sent, target in tqdm(mnist_seq_loader): 130 | sent = sent.to(device).long() 131 | target = target.to(device).long() 132 | optimizer.zero_grad() 133 | pred = 
model(sent) 134 | loss = torch.mean(criterion(pred, target)) 135 | loss.backward() 136 | optimizer.step() 137 | losses.append(loss.item()) 138 | print('Epoch {} : {}'.format(epoch, np.mean(losses))) 139 | print('=' * 50) 140 | torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'], 141 | 'best_mnist_lstm.pth')) 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser(description='Arguments for lstm training') 146 | parser.add_argument('--config', dest='config_path', 147 | default='../config/vqvae_colored_mnist.yaml', type=str) 148 | args = parser.parse_args() 149 | train_lstm(args) 150 | 151 | -------------------------------------------------------------------------------- /tools/train_vqvae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import cv2 5 | import random 6 | import os 7 | import shutil 8 | import torchvision 9 | import numpy as np 10 | from tqdm import tqdm 11 | from model.vqvae import get_model 12 | from torch.utils.data.dataloader import DataLoader 13 | from dataset.mnist_dataset import MnistDataset 14 | from torch.optim import Adam 15 | from torchvision.utils import make_grid 16 | from torch.optim.lr_scheduler import ReduceLROnPlateau 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | def train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config): 20 | r""" 21 | Method to run the training for one epoch. 22 | :param epoch_idx: iteration number of current epoch 23 | :param model: VQVAE model 24 | :param mnist_loader: Data loader for mnist 25 | :param optimizer: optimizer to be used, taken from config 26 | :param criterion: For computing the loss 27 | :param config: configuration for the current run 28 | :return: 29 | """ 30 | recon_losses = [] 31 | codebook_losses = [] 32 | commitment_losses = [] 33 | losses = [] 34 | # We ignore the label for VQVAE 35 | count = 0 36 | for im, _ in tqdm(mnist_loader): 37 | im = im.float().to(device) 38 | optimizer.zero_grad() 39 | model_output = model(im) 40 | output = model_output['generated_image'] 41 | quantize_losses = model_output['quantized_losses'] 42 | 43 | if config['train_params']['save_training_image']: 44 | cv2.imwrite('input.jpeg', (255 * (im.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0]) 45 | cv2.imwrite('output.jpeg', (255 * (output.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0]) 46 | 47 | recon_loss = criterion(output, im) 48 | loss = (config['train_params']['reconstruction_loss_weight']*recon_loss + 49 | config['train_params']['codebook_loss_weight']*quantize_losses['codebook_loss'] + 50 | config['train_params']['commitment_loss_weight']*quantize_losses['commitment_loss']) 51 | recon_losses.append(recon_loss.item()) 52 | codebook_losses.append(config['train_params']['codebook_loss_weight']*quantize_losses['codebook_loss'].item()) 53 | commitment_losses.append(quantize_losses['commitment_loss'].item()) 54 | losses.append(loss.item()) 55 | loss.backward() 56 | optimizer.step() 57 | print('Finished epoch: {} | Recon Loss : {:.4f} | Codebook Loss : {:.4f} | Commitment Loss : {:.4f}'. 
58 | format(epoch_idx + 1, 59 | np.mean(recon_losses), 60 | np.mean(codebook_losses), 61 | np.mean(commitment_losses))) 62 | return np.mean(losses) 63 | 64 | 65 | def train(args): 66 | ######## Read the config file ####### 67 | with open(args.config_path, 'r') as file: 68 | try: 69 | config = yaml.safe_load(file) 70 | except yaml.YAMLError as exc: 71 | print(exc) 72 | print(config) 73 | ####################################### 74 | 75 | ######## Set the desired seed value ####### 76 | seed = config['train_params']['seed'] 77 | torch.manual_seed(seed) 78 | np.random.seed(seed) 79 | random.seed(seed) 80 | if device.type == 'cuda': 81 | torch.cuda.manual_seed_all(seed) 82 | ####################################### 83 | 84 | # Create the model and dataset 85 | model = get_model(config).to(device) 86 | mnist = MnistDataset('train', config['train_params']['train_path'], im_channels=config['model_params']['in_channels']) 87 | mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=0) 88 | num_epochs = config['train_params']['epochs'] 89 | optimizer = Adam(model.parameters(), lr=config['train_params']['lr']) 90 | scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True) 91 | criterion = { 92 | 'l1': torch.nn.L1Loss(), 93 | 'l2': torch.nn.MSELoss() 94 | }.get(config['train_params']['crit']) 95 | 96 | # Create output directories 97 | if not os.path.exists(config['train_params']['task_name']): 98 | os.mkdir(config['train_params']['task_name']) 99 | if not os.path.exists(os.path.join(config['train_params']['task_name'], 100 | config['train_params']['output_train_dir'])): 101 | os.mkdir(os.path.join(config['train_params']['task_name'], 102 | config['train_params']['output_train_dir'])) 103 | 104 | # Load checkpoint if found 105 | if os.path.exists(os.path.join(config['train_params']['task_name'], 106 | config['train_params']['ckpt_name'])): 107 | print('Loading checkpoint') 108 | model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'], 109 | config['train_params']['ckpt_name']), map_location=device)) 110 | best_loss = np.inf 111 | 112 | for epoch_idx in range(num_epochs): 113 | mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config) 114 | scheduler.step(mean_loss) 115 | # Simply update checkpoint if found better version 116 | if mean_loss < best_loss: 117 | print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss)) 118 | torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'], 119 | config['train_params']['ckpt_name'])) 120 | best_loss = mean_loss 121 | else: 122 | print('No Loss Improvement') 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser(description='Arguments for vqvae training') 127 | parser.add_argument('--config', dest='config_path', 128 | default='config/vqvae_colored_mnist.yaml', type=str) 129 | args = parser.parse_args() 130 | train(args) --------------------------------------------------------------------------------
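For readers who mainly want the core idea, the vector quantization step implemented in ```model/quantizer.py``` (and repeated inline in ```run_simple_vqvae.py```) boils down to the minimal, standalone sketch below. It flattens the batch and spatial dimensions into a single axis and assumes the default sizes from ```config/vqvae_mnist.yaml``` (codebook_size 5, latent_dim 2); it is illustrative only, not part of the repository code.

```python
import torch

# Codebook of 5 embedding vectors with 2 dimensions each (vqvae_mnist.yaml defaults)
codebook = torch.nn.Embedding(num_embeddings=5, embedding_dim=2)
# Stand-in for the flattened encoder output: 49 spatial positions, latent_dim=2
latents = torch.randn(49, 2, requires_grad=True)

# Nearest codebook entry for every latent position
dist = torch.cdist(latents, codebook.weight)   # (49, 5) pairwise L2 distances
indices = torch.argmin(dist, dim=-1)           # (49,) index of the closest embedding
quant_out = codebook.weight[indices]           # (49, 2) quantized latents

# Auxiliary losses: commitment loss pulls the encoder towards the codebook,
# codebook loss pulls the codebook entries towards the encoder outputs
commitment_loss = torch.mean((quant_out.detach() - latents) ** 2)
codebook_loss = torch.mean((quant_out - latents.detach()) ** 2)

# Straight-through estimator: quantized values in the forward pass,
# identity gradient back to the encoder in the backward pass
quant_out = latents + (quant_out - latents).detach()
```

The straight-through line is what lets reconstruction gradients flow from the decoder back into the encoder even though the argmin lookup itself is not differentiable.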