├── .gitignore
├── README.md
├── config
│   ├── __init__.py
│   ├── vqvae_colored_mnist.yaml
│   └── vqvae_mnist.yaml
├── data
├── dataset
│   ├── __init__.py
│   └── mnist_dataset.py
├── model
│   ├── __init__.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── quantizer.py
│   └── vqvae.py
├── requirements.txt
├── run_simple_vqvae.py
└── tools
    ├── __init__.py
    ├── generate_images.py
    ├── infer_vqvae.py
    ├── train_lstm.py
    └── train_vqvae.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 |
15 | # Ignore checkpoints
16 | *.pth
17 |
18 | # Ignore pickle files
19 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | VQVAE Implementation in PyTorch with generation using LSTM
2 | ========
3 |
4 | This repository implements [VQVAE](https://arxiv.org/abs/1711.00937) for MNIST and a colored version of MNIST, and follows up with a simple LSTM for generating numbers.
5 |
6 | ## VQVAE Explanation and Implementation Video
7 |
8 |
10 |
11 |
12 |
13 | # Quickstart
14 | * Create a new conda environment with python 3.8 and then run the below commands
15 | * ```git clone https://github.com/explainingai-code/VQVAE-Pytorch.git```
16 | * ```cd VQVAE-Pytorch```
17 | * ```pip install -r requirements.txt```
18 | * For running a simple VQVAE with minimal code to understand the basics, run ```python run_simple_vqvae.py```
19 | * For playing around with VQVAE and training/inferencing the LSTM, use the below commands, passing the desired configuration file as the config argument (an end-to-end example follows this list)
20 | * ```python -m tools.train_vqvae``` for training vqvae
21 | * ```python -m tools.infer_vqvae``` for generating reconstructions and encoder outputs for LSTM training
22 | * ```python -m tools.train_lstm``` for training minimal LSTM
23 | * ```python -m tools.generate_images``` for using the trained LSTM to generate some numbers
24 |
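All four tools read the configuration passed via the ```--config``` argument (they default to the colored mnist config). For example, an end-to-end run on the black and white mnist config would look like:

```
python -m tools.train_vqvae --config config/vqvae_mnist.yaml
python -m tools.infer_vqvae --config config/vqvae_mnist.yaml
python -m tools.train_lstm --config config/vqvae_mnist.yaml
python -m tools.generate_images --config config/vqvae_mnist.yaml
```
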
25 | ## Configurations
26 | * ```config/vqvae_mnist.yaml``` - VQVAE for training on black and white mnist images
27 | * ```config/vqvae_colored_mnist.yaml``` - VQVAE with more embedding vectors for training colored mnist images (a sketch of the common keys is shown below)
28 |
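Both configs share the same layout. A minimal sketch of the keys you are most likely to tweak (values here are from ```config/vqvae_mnist.yaml```):

```
model_params:
  latent_dim: 2      # dimensionality of each codebook vector
  codebook_size: 5   # number of codebook vectors
train_params:
  task_name: 'vqvae_latent_2_codebook_5'  # all outputs are saved under this folder
  output_train_dir: 'output'              # created inside the task_name folder
```
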
29 | ## Data preparation
30 | For setting up the dataset:
31 | Follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation
32 |
33 | Verify the data directory has the following structure:
34 | ```
35 | VQVAE-Pytorch/data/train/images/{0/1/.../9}
36 | *.png
37 | VQVAE-Pytorch/data/test/images/{0/1/.../9}
38 | *.png
39 | ```
40 |
41 | ## Output
42 | Outputs will be saved according to the configuration present in the yaml files. An example of the resulting layout is sketched after the lists below.
43 |
44 | For every run, a folder named after the ```task_name``` key in the config will be created, and ```output_train_dir``` will be created inside it.
45 |
46 | During training of VQVAE the following output will be saved
47 | * Best model checkpoints (VQVAE and LSTM) in the ```task_name``` directory
48 |
49 | During inference the following output will be saved
50 | * Reconstructions for a sample of the test set in ```task_name/output_train_dir/reconstruction.png```
51 | * Encoder outputs on the train set for LSTM training in ```task_name/output_train_dir/mnist_encodings.pkl```
52 | * LSTM generation output in ```task_name/output_train_dir/generation_results.png```
53 |
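For example, with the default colored mnist config the outputs would roughly be laid out as:

```
vqvae_latent_8_colored_codebook_20/
    best_vqvae_latent_8_colored_codebook_20.pth
    best_mnist_lstm.pth
    output/
        reconstruction.png
        mnist_encodings.pkl
        generation_results.png
```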
54 |
55 | ## Sample Output for VQVAE
56 |
57 | Running `run_simple_vqvae` should be very quick (as it is a very simple model) and should give you the below reconstructions (input on a black background and reconstruction on a white background)
58 |
59 |
60 |
61 | Running the default config VQVAE for mnist should give you the below reconstructions for both versions
62 |
63 |
64 |
65 |
66 | Sample Generation Output after just 10 epochs
67 | Training the vqvae and lstm for longer and with more parameters (codebook size, codebook dimension, channels, lstm hidden dimension etc.) will give better results
68 |
69 |
70 |
71 |
72 | ## Citations
73 | ```
74 | @misc{oord2018neural,
75 | title={Neural Discrete Representation Learning},
76 | author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
77 | year={2018},
78 | eprint={1711.00937},
79 | archivePrefix={arXiv},
80 | primaryClass={cs.LG}
81 | }
82 | ```
83 |
84 |
85 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/config/__init__.py
--------------------------------------------------------------------------------
/config/vqvae_colored_mnist.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | in_channels: 3
3 | convbn_blocks: 4
4 | conv_kernel_size: [3,3,3,2]
5 | conv_kernel_strides: [2, 2, 1, 1]
6 | convbn_channels: [3, 16, 32, 8, 8]
7 | conv_activation_fn: 'leaky'
8 | transpose_bn_blocks: 4
9 | transposebn_channels: [8, 8, 32, 16, 3]
10 | transpose_kernel_size: [3,4,4,4]
11 | transpose_kernel_strides: [1,2,1,1]
12 | transpose_activation_fn: 'leaky'
13 | latent_dim: 8
14 | codebook_size : 20
15 |
16 | train_params:
17 | task_name: 'vqvae_latent_8_colored_codebook_20'
18 | batch_size: 64
19 | epochs: 20
20 | lr: 0.005
21 | crit: 'l2'
22 | reconstruction_loss_weight : 5
23 | codebook_loss_weight : 1
24 | commitment_loss_weight : 0.2
25 | ckpt_name: 'best_vqvae_latent_8_colored_codebook_20.pth'
26 | seed: 111
27 | save_training_image: True
28 | train_path: 'data/train/images'
29 | test_path: 'data/test/images'
30 | output_train_dir: 'output'
31 |
--------------------------------------------------------------------------------
/config/vqvae_mnist.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | in_channels: 1
3 | convbn_blocks: 4
4 | conv_kernel_size: [3,3,3,2]
5 | conv_kernel_strides: [2, 2, 1, 1]
6 | convbn_channels: [1, 16, 32, 8, 4]
7 | conv_activation_fn: 'leaky'
8 | transpose_bn_blocks: 4
9 | transposebn_channels: [4, 8, 32, 16, 1]
10 | transpose_kernel_size: [3,4,4,4]
11 | transpose_kernel_strides: [1,2,1,1]
12 | transpose_activation_fn: 'leaky'
13 | latent_dim: 2
14 | codebook_size : 5
15 |
16 | train_params:
17 | task_name: 'vqvae_latent_2_codebook_5'
18 | batch_size: 64
19 | epochs: 10
20 | lr: 0.005
21 | crit: 'l2'
22 | reconstruction_loss_weight : 1
23 | codebook_loss_weight : 1
24 | commitment_loss_weight : 0.2
25 | ckpt_name: 'best_vqvae_latent_2_codebook_5.pth'
26 | seed: 111
27 | save_training_image: True
28 | train_path: 'data/train/images'
29 | test_path: 'data/test/images'
30 | output_train_dir: 'output'
31 |
--------------------------------------------------------------------------------
/data:
--------------------------------------------------------------------------------
1 | /Users/tusharkumar/PycharmProjects/explainingai-repos/Pytorch-VAE/data
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/mnist_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import torch
5 | import random
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data.dataset import Dataset
9 | from torch.utils.data.dataloader import DataLoader
10 |
11 | r"""
12 | Simple Dataloader for mnist.
13 | """
14 |
15 | class MnistDataset(Dataset):
16 | def __init__(self, split, im_path, im_ext='png', im_channels=1):
17 | self.split = split
18 | self.im_ext = im_ext
19 | self.im_channels = im_channels
20 | self.images, self.labels = self.load_images(im_path)
21 |
22 | def load_images(self, im_path):
23 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
24 | ims = []
25 | labels = []
26 | for d_name in tqdm(os.listdir(im_path)):
27 | for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):
28 | ims.append(fname)
29 | labels.append(int(d_name))
30 | print('Found {} images for split {}'.format(len(ims), self.split))
31 | return ims, labels
32 |
33 | def __len__(self):
34 | return len(self.images)
35 |
36 | def __getitem__(self, index):
37 |         assert self.im_channels == 1 or self.im_channels == 3, "Input image channels can only be 1 or 3"
38 | if self.im_channels == 1:
39 | im = cv2.imread(self.images[index], 0)
40 | else:
41 | # Generate a random color digit
42 | im_1 = cv2.imread(self.images[index], 0)[None, :]*np.clip(random.random(), 0.2, 1.0)
43 | im_2 = cv2.imread(self.images[index], 0)[None, :]*np.clip(random.random(), 0.2, 1.0)
44 | im_3 = cv2.imread(self.images[index], 0)[None, :]*np.clip(random.random(), 0.2, 1.0)
45 | im = np.concatenate([im_1, im_2, im_3], axis=0)
46 |
47 | label = self.labels[index]
48 |         # Convert from the 0-255 range to the -1 to 1 range
49 | im = 2 * (im / 255) - 1
50 | im_tensor = torch.from_numpy(im)[None, :] if self.im_channels == 1 else torch.from_numpy(im)
51 | return im_tensor, torch.as_tensor(label)
52 |
53 |
54 | if __name__ == '__main__':
55 | mnist = MnistDataset('test', 'data/test/images', im_channels=3)
56 | mnist_loader = DataLoader(mnist, batch_size=16, shuffle=True, num_workers=0)
57 | for im, label in mnist_loader:
58 | print('Image dimension', im.shape)
59 | print('Label dimension: {}'.format(label.shape))
60 | break
61 |
62 |
63 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/model/__init__.py
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Decoder(nn.Module):
6 | def __init__(self,
7 | config
8 | ):
9 | super(Decoder, self).__init__()
10 | activation_map = {
11 | 'relu': nn.ReLU(),
12 | 'leaky': nn.LeakyReLU(),
13 | 'tanh': nn.Tanh(),
14 | 'gelu': nn.GELU(),
15 | 'silu': nn.SiLU()
16 | }
17 |
18 | self.config = config
19 |         ##### Validate that the configuration for the model is correctly set up #######
20 | assert config['transpose_activation_fn'] is None or config['transpose_activation_fn'] in activation_map
21 | self.latent_dim = config['latent_dim']
22 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 |
25 |
26 | self.decoder_layers = nn.ModuleList([
27 | nn.Sequential(
28 | nn.ConvTranspose2d(config['transposebn_channels'][i], config['transposebn_channels'][i + 1],
29 | kernel_size=config['transpose_kernel_size'][i],
30 | stride=config['transpose_kernel_strides'][i],
31 | padding=0),
32 | nn.BatchNorm2d(config['transposebn_channels'][i + 1]),
33 | activation_map[config['transpose_activation_fn']]
34 | )
35 | for i in range(config['transpose_bn_blocks']-1)
36 | ])
37 |
38 | dec_last_idx = config['transpose_bn_blocks']
39 | self.decoder_layers.append(
40 | nn.Sequential(
41 | nn.ConvTranspose2d(config['transposebn_channels'][dec_last_idx - 1], config['transposebn_channels'][dec_last_idx],
42 | kernel_size=config['transpose_kernel_size'][dec_last_idx - 1],
43 | stride=config['transpose_kernel_strides'][dec_last_idx - 1],
44 | padding=0),
45 | nn.Tanh()
46 | )
47 | )
48 |
49 | def forward(self, x):
50 | out = x
51 | for idx, layer in enumerate(self.decoder_layers):
52 | out = layer(out)
53 | return out
54 |
55 |
56 | def get_decoder(config):
57 | decoder = Decoder(
58 | config=config['model_params']
59 | )
60 | return decoder
61 |
62 |
63 |
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Encoder(nn.Module):
6 | def __init__(self,
7 | config
8 | ):
9 | super(Encoder, self).__init__()
10 | activation_map = {
11 | 'relu': nn.ReLU(),
12 | 'leaky': nn.LeakyReLU(),
13 | 'tanh': nn.Tanh(),
14 | 'gelu': nn.GELU(),
15 | 'silu': nn.SiLU()
16 | }
17 |
18 | self.config = config
19 |
20 |         ##### Validate that the configuration for the model is correctly set up #######
21 | assert config['conv_activation_fn'] is None or config['conv_activation_fn'] in activation_map
22 | self.latent_dim = config['latent_dim']
23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 |
25 | # Encoder is just Conv bn activation blocks
26 | self.encoder_layers = nn.ModuleList([
27 | nn.Sequential(
28 | nn.Conv2d(config['convbn_channels'][i], config['convbn_channels'][i + 1],
29 | kernel_size=config['conv_kernel_size'][i], stride=config['conv_kernel_strides'][i],padding=1),
30 | nn.BatchNorm2d(config['convbn_channels'][i + 1]),
31 | activation_map[config['conv_activation_fn']],
32 | )
33 | for i in range(config['convbn_blocks']-1)
34 | ])
35 |
36 | enc_last_idx = config['convbn_blocks']
37 | self.encoder_layers.append(
38 | nn.Sequential(
39 | nn.Conv2d(config['convbn_channels'][enc_last_idx - 1], config['convbn_channels'][enc_last_idx],
40 | kernel_size=config['conv_kernel_size'][enc_last_idx-1],
41 | stride=config['conv_kernel_strides'][enc_last_idx-1], padding=1),
42 | )
43 | )
44 |
45 | def forward(self, x):
46 | out = x
47 | for layer in self.encoder_layers:
48 | out = layer(out)
49 | return out
50 |
51 |
52 | def get_encoder(config):
53 | encoder = Encoder(
54 | config=config['model_params']
55 | )
56 | return encoder
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/model/quantizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import einsum
4 |
5 |
6 | class Quantizer(nn.Module):
7 | def __init__(self,
8 | config
9 | ):
10 | super(Quantizer, self).__init__()
11 | self.config = config
12 | self.embedding = nn.Embedding(config['codebook_size'], config['latent_dim'])
13 |
14 | def forward(self, x):
15 | B, C, H, W = x.shape
16 | x = x.permute(0, 2, 3, 1)
17 | x = x.reshape(x.size(0), -1, x.size(-1))
18 |
19 | dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
20 | min_encoding_indices = torch.argmin(dist, dim=-1)
21 |
22 |         # Replace each encoder output vector with its nearest codebook embedding
23 | quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
24 | x = x.reshape((-1, x.size(-1)))
25 |         commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
26 |         codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
27 |         quantize_losses = {
28 |             'codebook_loss' : codebook_loss,
29 |             'commitment_loss' : commitment_loss
30 | }
31 | quant_out = x + (quant_out - x).detach()
32 | quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
33 | min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
34 | return quant_out, quantize_losses, min_encoding_indices
35 |
36 | def quantize_indices(self, indices):
37 | return einsum(indices, self.embedding.weight, 'b n h w, n d -> b d h w')
38 |
39 |
40 | def get_quantizer(config):
41 | quantizer = Quantizer(
42 | config=config['model_params']
43 | )
44 | return quantizer
45 |
46 |
--------------------------------------------------------------------------------
/model/vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.encoder import get_encoder
4 | from model.decoder import get_decoder
5 | from model.quantizer import get_quantizer
6 |
7 |
8 | class VQVAE(nn.Module):
9 | def __init__(self,
10 | config
11 | ):
12 | super(VQVAE, self).__init__()
13 | self.encoder = get_encoder(config)
14 | self.pre_quant_conv = nn.Conv2d(config['model_params']['convbn_channels'][-1],
15 | config['model_params']['latent_dim'],
16 | kernel_size=1)
17 | self.quantizer = get_quantizer(config)
18 | self.post_quant_conv = nn.Conv2d(config['model_params']['latent_dim'],
19 | config['model_params']['transposebn_channels'][0],
20 | kernel_size=1)
21 | self.decoder = get_decoder(config)
22 |
23 | def forward(self, x):
24 | enc = self.encoder(x)
25 | quant_input = self.pre_quant_conv(enc)
26 | quant_output, quant_loss, quant_idxs = self.quantizer(quant_input)
27 | dec_input = self.post_quant_conv(quant_output)
28 | out = self.decoder(dec_input)
29 | return {
30 | 'generated_image' : out,
31 | 'quantized_output' : quant_output,
32 | 'quantized_losses' : quant_loss,
33 | 'quantized_indices' : quant_idxs
34 | }
35 |
36 | def decode_from_codebook_indices(self, indices):
37 | quantized_output = self.quantizer.quantize_indices(indices)
38 | dec_input = self.post_quant_conv(quantized_output)
39 | return self.decoder(dec_input)
40 |
41 |
42 |
43 | def get_model(config):
44 | print(config)
45 | model = VQVAE(
46 | config=config
47 | )
48 | return model
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.1
2 | matplotlib==3.7.2
3 | numpy==1.23.5
4 | opencv_python==4.8.0.74
5 | PyYAML==6.0
6 | torch==1.11.0
7 | torchvision==0.12.0
8 | tqdm==4.65.0
9 |
--------------------------------------------------------------------------------
/run_simple_vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import torchvision
4 | import torch.nn as nn
5 | import numpy as np
6 | from tqdm import tqdm
7 | from einops import rearrange
8 | from torch.optim import Adam
9 | from dataset.mnist_dataset import MnistDataset
10 | from torch.utils.data import DataLoader
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 |
15 | class VQVAE(nn.Module):
16 | def __init__(self):
17 | super(VQVAE, self).__init__()
18 | self.encoder = nn.Sequential(
19 | nn.Conv2d(1, 16, 4, stride=2, padding=1),
20 | nn.BatchNorm2d(16),
21 | nn.ReLU(),
22 | nn.Conv2d(16, 4, 4, stride=2, padding=1),
23 | nn.BatchNorm2d(4),
24 | nn.ReLU(),
25 | )
26 |
27 | self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1)
28 | self.embedding = nn.Embedding(num_embeddings=3, embedding_dim=2)
29 | self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1)
30 |
31 | # Commitment Loss Beta
32 | self.beta = 0.2
33 |
34 | self.decoder = nn.Sequential(
35 | nn.ConvTranspose2d(4, 16, 4, stride=2, padding=1),
36 | nn.BatchNorm2d(16),
37 | nn.ReLU(),
38 | nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),
39 | nn.Tanh(),
40 | )
41 |
42 |
43 | def forward(self, x):
44 | # B, C, H, W
45 | encoded_output = self.encoder(x)
46 | quant_input = self.pre_quant_conv(encoded_output)
47 |
48 | ## Quantization
49 | B, C, H, W = quant_input.shape
50 | quant_input = quant_input.permute(0, 2, 3, 1)
51 | quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1)))
52 |
53 | # Compute pairwise distances
54 | dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1)))
55 |
56 | # Find index of nearest embedding
57 | min_encoding_indices = torch.argmin(dist, dim=-1)
58 |
59 | # Select the embedding weights
60 | quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
61 | quant_input = quant_input.reshape((-1, quant_input.size(-1)))
62 |
63 | # Compute losses
64 | commitment_loss = torch.mean((quant_out.detach() - quant_input)**2)
65 | codebook_loss = torch.mean((quant_out - quant_input.detach())**2)
66 | quantize_losses = codebook_loss + self.beta*commitment_loss
67 |
68 | # Ensure straight through gradient
69 | quant_out = quant_input + (quant_out - quant_input).detach()
70 |
71 | # Reshaping back to original input shape
72 | quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
73 | min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
74 |
75 |
76 | ## Decoder part
77 | decoder_input = self.post_quant_conv(quant_out)
78 | output = self.decoder(decoder_input)
79 | return output, quantize_losses
80 |
81 | def train_vqvae():
82 | mnist = MnistDataset('train', im_path='data/train/images')
83 | mnist_test = MnistDataset('test', im_path='data/test/images')
84 | mnist_loader = DataLoader(mnist, batch_size=64, shuffle=True, num_workers=4)
85 |
86 | model = VQVAE().to(device)
87 |
88 | num_epochs = 20
89 | optimizer = Adam(model.parameters(), lr=1E-3)
90 | criterion = torch.nn.MSELoss()
91 |
92 | for epoch_idx in range(num_epochs):
93 | for im, label in tqdm(mnist_loader):
94 | im = im.float().to(device)
95 | optimizer.zero_grad()
96 | out, quantize_loss = model(im)
97 |
98 | recon_loss = criterion(out, im)
99 | loss = recon_loss + quantize_loss
100 | loss.backward()
101 | optimizer.step()
102 | print('Finished epoch {}'.format(epoch_idx+1))
103 | print('Done Training...')
104 |
105 | # Reconstruction part
106 |
107 | idxs = torch.randint(0, len(mnist_test), (100, ))
108 | ims = torch.cat([mnist_test[idx][0][None, :] for idx in idxs]).float()
109 | ims = ims.to(device)
110 | model.eval()
111 |
112 |
113 | generated_im, _ = model(ims)
114 | ims = (ims+1)/2
115 | generated_im = 1 - (generated_im+1)/2
116 | out = torch.hstack([ims, generated_im])
117 | output = rearrange(out, 'b c h w -> b () h (c w)')
118 | grid = torchvision.utils.make_grid(output.detach().cpu(), nrow=10)
119 | img = torchvision.transforms.ToPILImage()(grid)
120 | img.save('reconstruction.png')
121 |
122 | print('Done Reconstruction ...')
123 |
124 | if __name__ == '__main__':
125 | train_vqvae()
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VQVAE-Pytorch/378815c6609c5dd7271d33e32a756c5f74335cc0/tools/__init__.py
--------------------------------------------------------------------------------
/tools/generate_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import os
4 | import torch
5 | import pickle
6 | from tqdm import tqdm
7 | import torchvision
8 | from model.vqvae import get_model
9 | from torchvision.utils import make_grid
10 | from tools.train_lstm import MnistLSTM
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 |
13 |
14 |
15 | def generate(args):
16 | r"""
17 | Method for generating images after training vqvae and lstm
18 | 1. Create config
19 | 2. Create and load vqvae model
20 | 3. Create and load LSTM model
21 | 4. Generate 100 encoder outputs from trained LSTM
22 | 5. Pass them to the trained vqvae decoder
23 | 6. Save the generated image
24 | :param args:
25 | :return:
26 | """
27 |
28 | ########## Read the Config ##############
29 | with open(args.config_path, 'r') as file:
30 | try:
31 | config = yaml.safe_load(file)
32 | except yaml.YAMLError as exc:
33 | print(exc)
34 | print(config)
35 | #########################################
36 |
37 | ########## Load VQVAE Model ##############
38 | vqvae_model = get_model(config).to(device)
39 | vqvae_model.to(device)
40 | assert os.path.exists(os.path.join(config['train_params']['task_name'],
41 | config['train_params']['ckpt_name'])), "Train the vqvae model first"
42 | vqvae_model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
43 | config['train_params']['ckpt_name']), map_location=device))
44 |
45 | vqvae_model.eval()
46 | #########################################
47 |
48 | ########## Load LSTM ##############
49 | default_lstm_config = {
50 | 'input_size': 2,
51 | 'hidden_size': 128,
52 | 'codebook_size': config['model_params']['codebook_size']
53 | }
54 |
55 | model = MnistLSTM(input_size=default_lstm_config['input_size'],
56 | hidden_size=default_lstm_config['hidden_size'],
57 | codebook_size=default_lstm_config['codebook_size']).to(device)
58 | model.to(device)
59 | assert os.path.exists(os.path.join(config['train_params']['task_name'],
60 | 'best_mnist_lstm.pth')), "Train the lstm first"
61 | model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
62 | 'best_mnist_lstm.pth'), map_location=device))
63 | model.eval()
64 | #########################################
65 |
66 | ################ Generate Samples #############
67 | generated_quantized_indices = []
68 | mnist_encodings = pickle.load(open(os.path.join(config['train_params']['task_name'],
69 | config['train_params']['output_train_dir'],
70 | 'mnist_encodings.pkl'), 'rb'))
71 | mnist_encodings_length = mnist_encodings.reshape(mnist_encodings.size(0), -1).shape[-1]
72 |     # Assume fixed context size (same as the one used for lstm training)
73 | context_size = 32
74 | num_samples = 100
75 | print('Generating Samples')
76 | for _ in tqdm(range(num_samples)):
77 | # Initialize with start token
78 | ctx = torch.ones((1)).to(device) * (config['model_params']['codebook_size'])
79 |
80 | for i in range(mnist_encodings_length):
81 | padded_ctx = ctx
82 | if len(ctx) < context_size:
83 | # Pad context with pad token
84 | padded_ctx = torch.nn.functional.pad(padded_ctx, (0, context_size - len(ctx)), "constant",
85 | config['model_params']['codebook_size']+1)
86 |
87 | out = model(padded_ctx[None, :].long().to(device))
88 | probs = torch.nn.functional.softmax(out, dim=-1)
89 | pred = torch.multinomial(probs[0], num_samples=1)
90 | # Update the context with the new prediction
91 | ctx = torch.cat([ctx, pred])
92 | generated_quantized_indices.append(ctx[1:][None, :])
93 |
94 | ######## Decode the Generated Indices ##########
95 | generated_quantized_indices = torch.cat(generated_quantized_indices, dim=0)
96 | h = int(generated_quantized_indices[0].size(-1)**0.5)
97 | quantized_indices = generated_quantized_indices.reshape((generated_quantized_indices.size(0), h, h)).long()
98 | quantized_indices = torch.nn.functional.one_hot(quantized_indices, config['model_params']['codebook_size'])
99 | quantized_indices = quantized_indices.permute((0, 3, 1, 2))
100 | output = vqvae_model.decode_from_codebook_indices(quantized_indices.float())
101 |
102 | # Transform from -1, 1 range to 0,1
103 | output = (output + 1) / 2
104 |
105 | if config['model_params']['in_channels'] == 3:
106 |         # Inputs were read with cv2.imread which is BGR, so flip channels to make it RGB
107 | output = output[:, [2, 1, 0], :, :]
108 | grid = make_grid(output.detach().cpu(), nrow=10)
109 |
110 | img = torchvision.transforms.ToPILImage()(grid)
111 | img.save(os.path.join(config['train_params']['task_name'],
112 | config['train_params']['output_train_dir'],
113 | 'generation_results.png'))
114 |
115 | if __name__ == '__main__':
116 | parser = argparse.ArgumentParser(description='Arguments for LSTM generation')
117 | parser.add_argument('--config', dest='config_path',
118 | default='config/vqvae_colored_mnist.yaml', type=str)
119 | args = parser.parse_args()
120 | generate(args)
121 |
--------------------------------------------------------------------------------
/tools/infer_vqvae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import os
5 | from tqdm import tqdm
6 | import torchvision
7 | from model.vqvae import get_model
8 | from torch.utils.data.dataloader import DataLoader
9 | from dataset.mnist_dataset import MnistDataset
10 | from torchvision.utils import make_grid
11 | from einops import rearrange
12 | import pickle
13 |
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 | def reconstruct(config, model, dataset, num_images=100):
18 | r"""
19 | Randomly sample points from the dataset and visualize image and its reconstruction
20 | :param config: Config file used to create the model
21 | :param model: Trained model
22 | :param dataset: Mnist dataset(not the data loader)
23 |     :param num_images: Number of images to visualize
24 | :return:
25 | """
26 | print('Generating reconstructions')
27 | if not os.path.exists(config['train_params']['task_name']):
28 | os.mkdir(config['train_params']['task_name'])
29 | if not os.path.exists(
30 | os.path.join(config['train_params']['task_name'], config['train_params']['output_train_dir'])):
31 | os.mkdir(os.path.join(config['train_params']['task_name'], config['train_params']['output_train_dir']))
32 |
33 | idxs = torch.randint(0, len(dataset) - 1, (num_images,))
34 | ims = torch.cat([dataset[idx][0][None, :] for idx in idxs]).float()
35 | ims = ims.to(device)
36 | model_output = model(ims)
37 | output = model_output['generated_image']
38 |
39 |     # Dataset generates images in the -1 to 1 range; we convert them to 0-1
40 | ims = (ims + 1) / 2
41 |
42 |     # For reconstruction we specifically flip it (white digit on black background -> black digit on white background)
43 |     # for easier visualization, but only if it's not colored:
44 | generated_im = (output + 1) / 2
45 | if config['model_params']['in_channels'] == 1:
46 | generated_im = 1 - generated_im
47 | out = torch.hstack([ims, generated_im])
48 | output = rearrange(out, 'b (c d) h w -> b (d) h (c w)', c=2, d=config['model_params']['in_channels'])
49 |     # Flip r and b channels as everything was trained on bgr (cv2),
50 |     # although it doesn't matter since both input and output would be flipped
51 | if config['model_params']['in_channels'] == 3:
52 | output = output[:, [2, 1, 0], :, :]
53 | grid = make_grid(output.detach().cpu(), nrow=10)
54 |
55 | img = torchvision.transforms.ToPILImage()(grid)
56 | img.save(os.path.join(config['train_params']['task_name'],
57 | config['train_params']['output_train_dir'],
58 | 'reconstruction.png'))
59 |
60 |
61 | def save_encodings(config, model, mnist_loader):
62 | r"""
63 | Method to save the encoder outputs for training LSTM
64 | :param config:
65 | :param model:
66 | :param mnist_loader:
67 | :return:
68 | """
69 | save_encodings = None
70 | print('Saving Encodings for lstm')
71 | for im, _ in tqdm(mnist_loader):
72 | im = im.float().to(device)
73 | model_output = model(im)
74 | quant_indices = model_output['quantized_indices']
75 | save_encodings = quant_indices if save_encodings is None else torch.cat([save_encodings, quant_indices], dim=0)
76 | pickle.dump(save_encodings, open(os.path.join(config['train_params']['task_name'],
77 | config['train_params']['output_train_dir'],
78 | 'mnist_encodings.pkl'), 'wb'))
79 |     print('Done saving encoder outputs for lstm training')
80 |
81 |
82 | def inference(args):
83 | with open(args.config_path, 'r') as file:
84 | try:
85 | config = yaml.safe_load(file)
86 | except yaml.YAMLError as exc:
87 | print(exc)
88 | print(config)
89 |
90 | model = get_model(config).to(device)
91 | model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
92 | config['train_params']['ckpt_name']), map_location='cpu'))
93 | model.to(device)
94 | model.eval()
95 |
96 | ######### For generating encoder output for training lstm #############
97 | mnist = MnistDataset('train', config['train_params']['train_path'],
98 | im_channels=config['model_params']['in_channels'])
99 |
100 | ######### For visualizing reconstructions #############
101 |     mnist_test = MnistDataset('test', config['train_params']['test_path'],
102 | im_channels=config['model_params']['in_channels'])
103 | mnist_train_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=False, num_workers=0)
104 | with torch.no_grad():
105 | # Generate Reconstructions
106 | reconstruct(config, model, mnist_test)
107 | # Save Encoder Outputs for training lstm
108 | save_encodings(config, model, mnist_train_loader)
109 |
110 |
111 | if __name__ == '__main__':
112 | parser = argparse.ArgumentParser(description='Arguments for vqvae inference')
113 | parser.add_argument('--config', dest='config_path',
114 | default='config/vqvae_colored_mnist.yaml', type=str)
115 | args = parser.parse_args()
116 | inference(args)
--------------------------------------------------------------------------------
/tools/train_lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import cv2
5 | import glob
7 | import yaml
8 | import argparse
9 | import random
10 | import numpy as np
11 | from torch.optim import Adam
12 | import pickle
13 | from tqdm import tqdm
14 | from torch.utils.data.dataset import Dataset
15 | from torch.utils.data.dataloader import DataLoader
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 | class MnistLSTM(nn.Module):
20 | r"""
21 |     Very simple 2-layer LSTM with an fc layer on the last step's hidden dimension
22 | """
23 | def __init__(self, input_size, hidden_size, codebook_size):
24 | super(MnistLSTM, self).__init__()
25 |         self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True)
26 | self.fc = nn.Sequential(nn.Linear(hidden_size, hidden_size // 4),
27 | nn.ReLU(),
28 | nn.Linear(hidden_size // 4, codebook_size))
29 | # Add pad and start token to embedding size
30 | self.word_embedding = nn.Embedding(codebook_size+2, input_size)
31 |
32 | def forward(self, x):
33 | x = self.word_embedding(x)
34 | output, _ = self.rnn(x)
35 | output = output[:, -1, :]
36 | return self.fc(output)
37 |
38 |
39 | class MnistSeqDataset(Dataset):
40 | r"""
41 | Dataset for training of LSTM. Assumes the encodings are already generated
42 | by running vqvae inference
43 | """
44 | def __init__(self, config):
45 | self.codebook_size = config['model_params']['codebook_size']
46 |
47 | # Codebook tokens will be 0 to codebook_size-1
48 | self.start_token = self.codebook_size
49 | self.pad_token = self.codebook_size+1
50 | # Fix context size
51 | self.context_size = 32
52 | self.sents = self.load_sents(config)
53 |
54 | def load_sents(self, config):
55 | assert os.path.exists(os.path.join(config['train_params']['task_name'],
56 | config['train_params']['output_train_dir'],
57 | 'mnist_encodings.pkl')), ("No encodings generated for lstm."
58 | "Run save_encodings method in inference script")
59 | mnist_encodings = pickle.load(open(os.path.join(config['train_params']['task_name'],
60 | config['train_params']['output_train_dir'],
61 | 'mnist_encodings.pkl'), 'rb'))
62 | mnist_encodings = mnist_encodings.reshape(mnist_encodings.size(0), -1)
63 | num_encodings = mnist_encodings.size(0)
64 | padded_sents = []
65 |
66 | for encoding_idx in tqdm(range(num_encodings)):
67 |             # Use only 10% of the encodings to get some kind of output quickly
68 |             # and validate things are working. Remove this check to use all encodings.
69 | if random.random() > 0.1:
70 | continue
71 | enc = mnist_encodings[encoding_idx]
72 | encoding_length = enc.shape[-1]
73 |
74 | # Make sure all encodings start with start token
75 | enc = torch.cat([torch.ones((1)).to(device) * self.start_token, enc.to(device)])
76 |
77 | # Create batches of context sized inputs(if possible) and target
78 | sents = [(enc[:i], enc[i]) if i < self.context_size else (enc[i - self.context_size:i], enc[i])
79 | for i in range(1, encoding_length+1)]
80 |
81 | for context, target in sents:
82 | # Pad token if context not enough
83 | if len(context) < self.context_size:
84 | context = torch.nn.functional.pad(context, (0, self.context_size-len(context)), "constant", self.pad_token)
85 | padded_sents.append((context, target))
86 | return padded_sents
87 |
88 | def __len__(self):
89 | return len(self.sents)
90 |
91 | def __getitem__(self, index):
92 | context, target = self.sents[index]
93 | return context, target
94 |
95 | def train_lstm(args):
96 | ############ Read the config #############
97 | with open(args.config_path, 'r') as file:
98 | try:
99 | config = yaml.safe_load(file)
100 | except yaml.YAMLError as exc:
101 | print(exc)
102 | print(config)
103 | #########################################
104 |
105 | ############## Create dataset ###########
106 | mnist = MnistSeqDataset(config)
107 | mnist_seq_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=0)
108 | #########################################
109 |
110 | ############## Create LSTM ###########
111 | default_lstm_config = {
112 | 'input_size' : 2,
113 | 'hidden_size' : 128,
114 | 'codebook_size' : config['model_params']['codebook_size']
115 | }
116 | model = MnistLSTM(input_size=default_lstm_config['input_size'],
117 | hidden_size=default_lstm_config['hidden_size'],
118 | codebook_size=default_lstm_config['codebook_size']).to(device)
119 | model.to(device)
120 | model.train()
121 |
122 | ############## Training Params ###########
123 | num_epochs = 10
124 | optimizer = Adam(model.parameters(), lr=1E-3)
125 | criterion = torch.nn.CrossEntropyLoss()
126 |
127 | for epoch in range(num_epochs):
128 | losses = []
129 | for sent, target in tqdm(mnist_seq_loader):
130 | sent = sent.to(device).long()
131 | target = target.to(device).long()
132 | optimizer.zero_grad()
133 | pred = model(sent)
134 | loss = torch.mean(criterion(pred, target))
135 | loss.backward()
136 | optimizer.step()
137 | losses.append(loss.item())
138 | print('Epoch {} : {}'.format(epoch, np.mean(losses)))
139 | print('=' * 50)
140 | torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'],
141 | 'best_mnist_lstm.pth'))
142 |
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser(description='Arguments for lstm training')
146 | parser.add_argument('--config', dest='config_path',
147 | default='../config/vqvae_colored_mnist.yaml', type=str)
148 | args = parser.parse_args()
149 | train_lstm(args)
150 |
151 |
--------------------------------------------------------------------------------
/tools/train_vqvae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import cv2
5 | import random
6 | import os
7 | import shutil
8 | import torchvision
9 | import numpy as np
10 | from tqdm import tqdm
11 | from model.vqvae import get_model
12 | from torch.utils.data.dataloader import DataLoader
13 | from dataset.mnist_dataset import MnistDataset
14 | from torch.optim import Adam
15 | from torchvision.utils import make_grid
16 | from torch.optim.lr_scheduler import ReduceLROnPlateau
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 | def train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config):
20 | r"""
21 | Method to run the training for one epoch.
22 | :param epoch_idx: iteration number of current epoch
23 | :param model: VQVAE model
24 |     :param mnist_loader: Data loader for mnist
25 |     :param optimizer: optimizer to be used, taken from config
26 |     :param criterion: For computing the loss
27 | :param config: configuration for the current run
28 | :return:
29 | """
30 | recon_losses = []
31 | codebook_losses = []
32 | commitment_losses = []
33 | losses = []
34 | # We ignore the label for VQVAE
35 | count = 0
36 | for im, _ in tqdm(mnist_loader):
37 | im = im.float().to(device)
38 | optimizer.zero_grad()
39 | model_output = model(im)
40 | output = model_output['generated_image']
41 | quantize_losses = model_output['quantized_losses']
42 |
43 | if config['train_params']['save_training_image']:
44 | cv2.imwrite('input.jpeg', (255 * (im.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0])
45 | cv2.imwrite('output.jpeg', (255 * (output.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0])
46 |
47 |         recon_loss = criterion(output, im)
48 | loss = (config['train_params']['reconstruction_loss_weight']*recon_loss +
49 | config['train_params']['codebook_loss_weight']*quantize_losses['codebook_loss'] +
50 | config['train_params']['commitment_loss_weight']*quantize_losses['commitment_loss'])
51 | recon_losses.append(recon_loss.item())
52 | codebook_losses.append(config['train_params']['codebook_loss_weight']*quantize_losses['codebook_loss'].item())
53 | commitment_losses.append(quantize_losses['commitment_loss'].item())
54 | losses.append(loss.item())
55 | loss.backward()
56 | optimizer.step()
57 | print('Finished epoch: {} | Recon Loss : {:.4f} | Codebook Loss : {:.4f} | Commitment Loss : {:.4f}'.
58 | format(epoch_idx + 1,
59 | np.mean(recon_losses),
60 | np.mean(codebook_losses),
61 | np.mean(commitment_losses)))
62 | return np.mean(losses)
63 |
64 |
65 | def train(args):
66 | ######## Read the config file #######
67 | with open(args.config_path, 'r') as file:
68 | try:
69 | config = yaml.safe_load(file)
70 | except yaml.YAMLError as exc:
71 | print(exc)
72 | print(config)
73 | #######################################
74 |
75 | ######## Set the desired seed value #######
76 | seed = config['train_params']['seed']
77 | torch.manual_seed(seed)
78 | np.random.seed(seed)
79 | random.seed(seed)
80 |     if device.type == 'cuda':
81 |         torch.cuda.manual_seed_all(seed)
82 | #######################################
83 |
84 | # Create the model and dataset
85 | model = get_model(config).to(device)
86 | mnist = MnistDataset('train', config['train_params']['train_path'], im_channels=config['model_params']['in_channels'])
87 | mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=0)
88 | num_epochs = config['train_params']['epochs']
89 | optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
90 | scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True)
91 | criterion = {
92 | 'l1': torch.nn.L1Loss(),
93 | 'l2': torch.nn.MSELoss()
94 | }.get(config['train_params']['crit'])
95 |
96 | # Create output directories
97 | if not os.path.exists(config['train_params']['task_name']):
98 | os.mkdir(config['train_params']['task_name'])
99 | if not os.path.exists(os.path.join(config['train_params']['task_name'],
100 | config['train_params']['output_train_dir'])):
101 | os.mkdir(os.path.join(config['train_params']['task_name'],
102 | config['train_params']['output_train_dir']))
103 |
104 | # Load checkpoint if found
105 | if os.path.exists(os.path.join(config['train_params']['task_name'],
106 | config['train_params']['ckpt_name'])):
107 | print('Loading checkpoint')
108 | model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
109 | config['train_params']['ckpt_name']), map_location=device))
110 | best_loss = np.inf
111 |
112 | for epoch_idx in range(num_epochs):
113 | mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config)
114 | scheduler.step(mean_loss)
115 | # Simply update checkpoint if found better version
116 | if mean_loss < best_loss:
117 | print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
118 | torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'],
119 | config['train_params']['ckpt_name']))
120 | best_loss = mean_loss
121 | else:
122 | print('No Loss Improvement')
123 |
124 |
125 | if __name__ == '__main__':
126 | parser = argparse.ArgumentParser(description='Arguments for vq vae training')
127 | parser.add_argument('--config', dest='config_path',
128 | default='config/vqvae_colored_mnist.yaml', type=str)
129 | args = parser.parse_args()
130 | train(args)
--------------------------------------------------------------------------------