├── .github └── workflows │ └── main.yml ├── LICENSE ├── README.md ├── configs └── default.yml ├── dataloader ├── __init__.py ├── cifar10.py ├── load_dataloader.py └── mnist.py ├── environment.yml ├── experiments ├── generated_0.jpg ├── generated_1.jpg ├── generated_2.jpg ├── generated_3.jpg ├── generated_4.jpg └── reconstruction.gif ├── generate.py ├── requirements.txt ├── test ├── __init__.py ├── test_dataloader.py └── test_vqgan.py ├── train.py ├── trainer ├── __init__.py ├── trainer.py ├── transformer.py └── vqgan.py ├── transformer ├── __init__.py ├── mingpt.py └── transformer.py ├── utils ├── __init__.py ├── assets │ ├── aim_images.png │ ├── aim_metrics.png │ ├── encoder_arch.png │ ├── nonlocalblocks_arch.png │ ├── patchgan_disc.png │ ├── perceptual_loss.png │ ├── reconstruction.gif │ ├── sliding_window.png │ ├── stage_1.png │ ├── stage_2.png │ ├── vqgan.png │ └── vqvae_arch.png └── utils.py └── vqgan ├── __init__.py ├── codebook.py ├── common.py ├── decoder.py ├── discriminator.py ├── encoder.py └── vqgan.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: Run Python Tests 4 | 5 | # Controls when the workflow will run 6 | on: 7 | # Triggers the workflow on push or pull request events but only for the "main" branch 8 | push: 9 | branches: [ "main" ] 10 | pull_request: 11 | branches: [ "main" ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | # workflow_dispatch: 15 | 16 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 17 | jobs: 18 | # This workflow contains a single job called "build" 19 | build: 20 | # The type of runner that the job will run on 21 | runs-on: ubuntu-latest 22 | 23 | # Steps represent a sequence of tasks that will be executed as part of the job 24 | steps: 25 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 26 | - uses: actions/checkout@v3 27 | 28 | - name: Setup Python 29 | uses: actions/setup-python@v4.2.0 30 | with: 31 | # Version range or exact version of Python or PyPy to use, using SemVer's version range syntax. Reads from .python-version if unset. 32 | python-version: 3.7.13 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install -r requirements.txt 38 | 39 | - name: Run tests with pytest 40 | run: pytest --cov=. test --cov-report=xml 41 | 42 | - name: Upload coverage reports to Codecov with GitHub Action 43 | uses: codecov/codecov-action@v3 44 | with: 45 | token: ${{ secrets.CODECOV_TOKEN }} 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shubhamai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch VQGAN 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/Shubhamai/pytorch-vqgan/blob/main/LICENSE) 4 | [![Run Python Tests](https://github.com/Shubhamai/pytorch-vqgan/actions/workflows/main.yml/badge.svg)](https://github.com/Shubhamai/pytorch-vqgan/actions/workflows/main.yml) 5 | [![codecov](https://codecov.io/gh/Shubhamai/pytorch-vqgan/branch/main/graph/badge.svg?token=NANKT1FU4M)](https://codecov.io/gh/Shubhamai/pytorch-vqgan) 6 | 7 |

8 | [image: VQGAN architecture]
9 | Figure 1. VQGAN Architecture 10 |

11 | 12 | > **Note:** This is a work in progress. 13 | 14 | 15 | This repo's purpose is to serve as a cleaner and more feature-rich implementation of VQGAN - *[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841)* - written in PyTorch from scratch and building on the initial work in [dome272's repo](https://github.com/dome272/VQGAN-pytorch). There's also a great video on the [explanation of VQGAN](https://youtu.be/wcqLFDXaDO8) by dome272. 16 | 17 | I created this repo to better understand VQGAN myself and to provide scripts for faster training and experimentation with toy datasets like MNIST. I also tried to make it as clean as possible, with comments, logging, testing & coverage, custom datasets & visualizations, etc. 18 | 19 | - [PyTorch VQGAN](#pytorch-vqgan) 20 | - [What is VQGAN?](#what-is-vqgan) 21 | - [Stage 1](#stage-1) 22 | - [Training](#training) 23 | - [Generation](#generation) 24 | - [Stage 2](#stage-2) 25 | - [Setup](#setup) 26 | - [Usage](#usage) 27 | - [Training](#training-1) 28 | - [Generation](#generation-1) 29 | - [Tests](#tests) 30 | - [Hardware requirements](#hardware-requirements) 31 | - [Shoutouts](#shoutouts) 32 | - [BibTeX](#bibtex) 33 | 34 | ## What is VQGAN? 35 | 36 | VQGAN stands for **V**ector **Q**uantised **G**enerative **A**dversarial **N**etworks. The main idea of the paper is to use a CNN to learn context-rich visual parts of images and collect them in a codebook, and then use a Transformer to learn the long-range/global interactions between the visual parts of the image embedded in the codebook. Combining the two, we can generate very high-resolution images. 37 | 38 | Learning both the short-range and long-range interactions needed to generate high-resolution images is done in two stages. 39 | 40 | 1. The first stage uses VQGAN to learn a codebook of **context-rich** visual representations of the images. In terms of architecture, it is very similar to VQVAE: it consists of an encoder, a decoder and the codebook. We will learn more about this in the next section. 41 | 42 |

43 | [image: VQVAE architecture]
44 | Figure 2. VQVAE Architecture 45 |

46 | 47 | 2. Using a transformer to learn the global interactions between the vectors in the codebook by predicting the next sequence from the previous sequences, to generate high-resolution images. 48 | 49 | --- 50 | 51 | ### Stage 1 52 | 53 |

54 | [image: Stage 1 overview]
55 | Stage 1: VQGAN Architecture 56 |

57 | 58 | 59 | The architecture of VQGAN consists of three main parts - the encoder, the decoder and the codebook - similar to the VQVAE paper. 60 | 61 | 62 | 1. The encoder ([`encoder.py`](vqgan/encoder.py)) learns to represent the images in a much lower-dimensional space called embeddings or latents, and consists of convolution, downsample and residual blocks plus special attention blocks ( non-local blocks ) - around 30 million parameters with default settings. 63 | 2. The embeddings are then quantized using the codebook, and the quantized embeddings are used as input to the decoder ([`decoder.py`](vqgan/decoder.py)). 64 | 3. The decoder takes the "quantized" embeddings and reconstructs the image. Its architecture is similar to the encoder but reversed - around 40 million parameters with default settings, slightly more than the encoder due to a larger number of residual blocks. 65 | 66 | The main idea behind the codebook and quantization is to convert the continuous latent representation into a discrete one. The codebook is simply a list of `n` latent vectors ( learned during training ) which are used to replace the latents produced by the encoder with the closest vector ( in terms of distance ) from the codebook. The **VQ** part of the name comes from here. 67 | 68 | #### Training 69 | 70 | Training involves sending a batch of images through the encoder, quantizing the embeddings and then sending the quantized embeddings through the decoder to reconstruct the image. The loss function is computed as follows: 71 | 72 | 73 | $` 74 | \begin{aligned} 75 | \mathcal{L}_{\mathrm{VQ}}(E, G, \mathcal{Z})=\|x-\hat{x}\|^{2} &+\left\|\text{sg}[E(x)]-z_{\mathbf{q}}\right\|_{2}^{2}+\left\|\text{sg}\left[z_{\mathbf{q}}\right]-E(x)\right\|_{2}^{2} . 76 | \end{aligned} 77 | `$ 78 | 79 | The above equation is the sum of the reconstruction loss and the alignment and commitment losses. 80 | 81 | 1. Reconstruction loss 82 | 83 | > Apparently there is some confusion about whether the reconstruction loss was replaced by a perceptual loss or combined with it; we go with what was implemented in the official code ( https://github.com/CompVis/taming-transformers/issues/40 ), which is L1 + perceptual loss. 84 | 85 | 86 | 87 | 88 | The reconstruction loss is the sum of the L1 loss and the perceptual loss. 89 | $`\text { L1 Loss }=\sum_{i=1}^{n}\left|y_{\text {true }}-y_{\text {predicted }}\right|`$ 90 | 91 | The perceptual loss is the L2 distance between feature activations of the reconstructed and original images, taken from a pre-trained network such as VGG. 92 | 93 | 2. The alignment and commitment losses come from the quantization step: they measure the distance between the latent vectors produced by the encoder and the closest vectors in the codebook. `sg` here denotes the stop-gradient function. 94 | 95 | --- 96 | 97 | 98 | 99 | $` 100 | \mathcal{L}_{\mathrm{GAN}}(\{E, G, \mathcal{Z}\}, D)=[\log D(x)+\log (1-D(\hat{x}))] 101 | `$ 102 | 103 | The above loss is for the discriminator, which takes in real and generated images and learns to classify which ones are real and which are fake - the **GAN** in VQGAN comes from here :) 104 | 105 | The discriminator here is a bit different from conventional discriminators: instead of producing a single real/fake score for the whole image, it uses convolutions to split the image into patches and predicts which patches are real or fake ( a PatchGAN-style discriminator ).
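Putting the pieces above together, the sketch below shows how one training step combines these losses. It is a simplified illustration modelled on the repo's own `trainer/vqgan.py`; the function name is made up for illustration, and the adaptive weight λ and `disc_factor` scheduling used in the actual trainer are only noted in comments, not applied.

```python
# Simplified sketch of a single VQGAN training step, modelled on trainer/vqgan.py.
# `vqgan_losses_sketch` is a made-up name used only for illustration.
import torch
import torch.nn.functional as F


def vqgan_losses_sketch(vqgan, discriminator, perceptual_loss, imgs):
    decoded_images, _, q_loss = vqgan(imgs)  # q_loss = alignment + commitment terms

    # Reconstruction term: L1 distance plus perceptual (LPIPS) distance
    rec_loss = torch.abs(imgs - decoded_images)
    perceptual_rec_loss = (perceptual_loss(imgs, decoded_images) + rec_loss).mean()

    # PatchGAN logits: one real/fake score per image patch, e.g. shape (B, 1, 30, 30)
    disc_real = discriminator(imgs)
    disc_fake = discriminator(decoded_images)

    # Generator side: push the fake patches towards "real"
    g_loss = -torch.mean(disc_fake)
    vq_loss = perceptual_rec_loss + q_loss + g_loss  # real trainer scales g_loss by disc_factor * λ

    # Discriminator side: hinge loss on the real vs. fake patch scores
    d_loss = 0.5 * (
        torch.mean(F.relu(1.0 - disc_real)) + torch.mean(F.relu(1.0 + disc_fake))
    )
    return vq_loss, d_loss
```

Two separate optimizers are then stepped: one updates the encoder, decoder and codebook using `vq_loss`, and the other updates the discriminator using `d_loss`.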
106 | 107 | --- 108 | 109 | $` 110 | \lambda=\frac{\nabla_{G_{L}}\left[\mathcal{L}_{\mathrm{rec}}\right]}{\nabla_{G_{L}}\left[\mathcal{L}_{\mathrm{GAN}}\right]+\delta} 111 | `$ 112 | 113 | We calculate the adaptive weight lambda as the ratio between the gradients of the reconstruction loss and of the GAN loss, both taken with respect to the last layer of the decoder ( δ is a small constant for numerical stability ) - see `calculate_lambda` in [`vqgan.py`](vqgan/vqgan.py). 114 | 115 | The final objective then becomes - 116 | 117 | $` 118 | \begin{aligned} 119 | \mathcal{Q}^{*}=\underset{E, G, \mathcal{Z}}{\arg \min } \max _{D} \mathbb{E}_{x \sim p(x)}\left[\mathcal{L}_{\mathrm{VQ}}(E, G, \mathcal{Z})+\lambda \mathcal{L}_{\mathrm{GAN}}(\{E, G, \mathcal{Z}\}, D)\right] 120 | \end{aligned} 121 | `$ 122 | 123 | which combines the VQ loss ( reconstruction, alignment and commitment ) with the GAN loss weighted by `lambda`. 124 | 125 | ### Generation 126 | 127 | 128 | To generate images with VQGAN, we take the quantized vectors produced in [Stage 2](#stage-2) and pass them through the decoder to reconstruct the image. 129 | 130 | --- 131 | 132 | ### Stage 2 133 | 134 |

135 | [image: Stage 2 overview]
136 | Stage 2: Transformers 137 |

138 | 139 | 140 | 141 | 142 | This stage contains Transformers 🤖 which are trained to predict the next latent vector given the sequence of previous latent vectors from the quantized encoder output. The paper uses [`mingpt.py`](transformer/mingpt.py) from Andrej Karpathy's [karpathy/minGPT](https://github.com/karpathy/minGPT) repo. 143 | 144 | Due to the computational cost of generating high-resolution images, a sliding attention window is also used, so that each next latent vector is predicted from its neighbouring vectors in the quantized encoder output. 145 | 146 | ## Setup 147 | 148 | 1. Clone the repo - `git clone https://github.com/Shubhamai/pytorch-vqgan` 149 | 2. Create a new conda environment using `conda env create --prefix env python=3.7.13 --file=environment.yml` 150 | 3. Activate the conda environment using `conda activate ./env` 151 | 152 | ## Usage 153 | 154 | 155 | 156 | 157 | ### Training 158 | 159 | - You can start training by running `python train.py`. It reads the default config file from `configs/default.yml`; to use a different config, run `python train.py --config_path configs/default.yml`. 160 | 161 | Here's roughly what the script does ( a programmatic sketch of this is shown below ) - 162 | - Downloads the MNIST dataset automatically and saves it in the [data](/data) directory ( specified in the config ). 163 | - Trains the VQGAN and transformer models on the MNIST train set with the parameters passed from the config file. 164 | - Saves the training metrics, visualizations and models in the [experiments/](/experiments/) directory, with the corresponding paths specified in the config file. 165 | 166 | - Run `aim up` to open the experiment tracker to see the metrics and the reconstructed & generated images. 167 | 168 |

169 | [images: Aim experiment tracker screenshots] 170 |
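If you prefer to drive training from Python rather than the CLI, the script boils down to roughly the following sketch, which mirrors the repo's own `train.py` and `Trainer` API using the default MNIST config:

```python
# Programmatic equivalent of `python train.py` with the default config.
import yaml
from aim import Run

from dataloader import load_dataloader
from trainer import Trainer
from transformer import VQGANTransformer
from vqgan import VQGAN

with open("configs/default.yml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

vqgan = VQGAN(**config["architecture"]["vqgan"])
transformer = VQGANTransformer(vqgan, **config["architecture"]["transformer"], device="cuda")
dataloader = load_dataloader(name="mnist")

run = Run(experiment="mnist")  # Aim experiment tracker
run["hparams"] = config

trainer = Trainer(vqgan, transformer, run=run, config=config["trainer"], seed=42, device="cuda")
trainer.train_vqgan(dataloader)         # Stage 1: encoder/decoder/codebook + discriminator
trainer.train_transformers(dataloader)  # Stage 2: minGPT over codebook indices
trainer.generate_images()               # samples saved under experiments/
```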

171 | 172 | 173 | ### Generation 174 | 175 | To generate the images, simply run `python generate.py`, the models will be loaded from the `experiments/checkpoints` and the output will be saved in `experiments`. 176 | 177 | ### Tests 178 | 179 | I have also just started getting my feet wet with testing and automated testing with GitHub CI/CD, so the tests here might not be the best practices. 180 | 181 | To run tests, run `pytest --cov-config=.coveragerc --cov=. test` 182 | 183 | 184 | ## Hardware requirements 185 | 186 | The hardware which I tried the model on default settings is - 187 | - Ryzen 5 4600H 188 | - NVIDIA GeForce GTX 1660Ti - 6 GB VRAM 189 | - 12 GB ram 190 | 191 | It took around 2-3 min to get good reconstruction results. Since, google colab has similar hardware in terms compute power from what I understand, it should run just fine on colab :) 192 | 193 | 194 | ## Shoutouts 195 | 196 | The list here contains some helpful blogs or videos that helped me a bunch in understanding the VQGAN. 197 | 198 | 1. [The Illustrated VQGAN](https://ljvmiranda921.github.io/notebook/2021/08/08/clip-vqgan/) by Lj Miranda 199 | 2. [VQGAN: Taming Transformers for High-Resolution Image Synthesis [Paper Explained]](https://youtu.be/-wDSDtIAyWQ) by Gradient Dude 200 | 3. [VQ-GAN: Taming Transformers for High-Resolution Image Synthesis | Paper Explained](https://youtu.be/j2PXES-liuc) by The AI Epiphany 201 | 4. [VQ-GAN | Paper Explanation](https://youtu.be/wcqLFDXaDO8) and [VQ-GAN | PyTorch Implementation](https://youtu.be/_Br5WRwUz_U) by Outlier 202 | 5. [TL#006 Robin Rombach Taming Transformers for High Resolution Image Synthesis](https://youtu.be/fy153-yXSQk) by one of the paper's author - Robin Rombach. Thanks for the talk :) 203 | 204 | ## BibTeX 205 | 206 | ``` 207 | @misc{esser2020taming, 208 | title={Taming Transformers for High-Resolution Image Synthesis}, 209 | author={Patrick Esser and Robin Rombach and Björn Ommer}, 210 | year={2020}, 211 | eprint={2012.09841}, 212 | archivePrefix={arXiv}, 213 | primaryClass={cs.CV} 214 | } 215 | ``` 216 | -------------------------------------------------------------------------------- /configs/default.yml: -------------------------------------------------------------------------------- 1 | architecture: 2 | vqgan: 3 | img_channels: 1 4 | img_size: 256 5 | latent_channels: 256 6 | latent_size: 16 7 | intermediate_channels: [128, 128, 256, 256, 512] 8 | num_residual_blocks_encoder: 2 9 | num_residual_blocks_decoder: 3 10 | dropout: 0.0 11 | attention_resolution: [16] 12 | num_codebook_vectors: 1024 13 | 14 | transformer: 15 | sos_token: 0 16 | pkeep: 0.5 17 | block_size: 512 18 | n_layer: 12 19 | n_head: 16 20 | n_embd: 1024 21 | 22 | trainer: 23 | vqgan: 24 | learning_rate: 2.25e-05 25 | beta1: 0.5 26 | beta2: 0.9 27 | perceptual_loss_factor: 1.0 28 | rec_loss_factor: 1.0 29 | disc_factor: 1.0 30 | disc_start: 100 31 | perceptual_model: "vgg" 32 | save_every: 10 33 | 34 | transformer: 35 | learning_rate: 4.5e-06 36 | beta1: 0.9 37 | beta2: 0.95 38 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist import load_mnist 2 | from .cifar10 import load_cifar10 3 | from .load_dataloader import load_dataloader -------------------------------------------------------------------------------- /dataloader/cifar10.py: -------------------------------------------------------------------------------- 1 | # Importing 
Libraries 2 | import torch 3 | import torchvision 4 | 5 | from utils import collate_fn 6 | 7 | 8 | def load_cifar10( 9 | batch_size: int = 16, 10 | image_size: int = 28, 11 | num_workers: int = 4, 12 | save_path: str = "data", 13 | ) -> torch.utils.data.DataLoader: 14 | """Load the Cifar 10 data and returns the dataloaders (train ). The data is downloaded if it does not exist. 15 | 16 | Args: 17 | batch_size (int): The batch size. 18 | image_size (int): The image size. 19 | num_workers (int): The number of workers to use for the dataloader. 20 | save_path (str): The path to save the data to. 21 | 22 | Returns: 23 | torch.utils.data.DataLoader: The data loader. 24 | """ 25 | 26 | # Load the data 27 | dataloader = torch.utils.data.DataLoader( 28 | torchvision.datasets.CIFAR10( 29 | root=save_path, 30 | train=True, 31 | download=True, 32 | transform=torchvision.transforms.Compose( 33 | [ 34 | torchvision.transforms.Resize((image_size, image_size)), 35 | torchvision.transforms.ToTensor(), 36 | torchvision.transforms.Normalize((0.1307,), (0.3081,)), 37 | ] 38 | ), 39 | ), 40 | batch_size=batch_size, 41 | shuffle=True, 42 | num_workers=num_workers, 43 | pin_memory=True, 44 | collate_fn=collate_fn, 45 | ) 46 | 47 | return dataloader 48 | -------------------------------------------------------------------------------- /dataloader/load_dataloader.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import torch 3 | 4 | from dataloader import load_mnist, load_cifar10 5 | 6 | 7 | def load_dataloader( 8 | name: str = "mnist", 9 | batch_size: int = 2, 10 | image_size: int = 256, 11 | num_workers: int = 4, 12 | save_path: str = "data", 13 | ) -> torch.utils.data.DataLoader: 14 | """Load the data loader for the given name. 15 | 16 | Args: 17 | name (str, optional): The name of the data loader. Defaults to "mnist". 18 | batch_size (int, optional): The batch size. Defaults to 2. 19 | image_size (int, optional): The image size. Defaults to 256. 20 | num_workers (int, optional): The number of workers to use for the dataloader. Defaults to 4. 21 | save_path (str, optional): The path to save the data to. Defaults to "data". 22 | 23 | Returns: 24 | torch.utils.data.DataLoader: The data loader. 25 | """ 26 | 27 | if name == "mnist": 28 | return load_mnist( 29 | batch_size=batch_size, 30 | image_size=image_size, 31 | num_workers=num_workers, 32 | save_path=save_path, 33 | ) 34 | 35 | elif name == "cifar10": 36 | return load_cifar10( 37 | batch_size=batch_size, 38 | image_size=image_size, 39 | num_workers=num_workers, 40 | save_path=save_path, 41 | ) -------------------------------------------------------------------------------- /dataloader/mnist.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import torch 3 | import torchvision 4 | 5 | from utils import collate_fn 6 | 7 | 8 | def load_mnist( 9 | batch_size: int = 2, 10 | image_size: int = 256, 11 | num_workers: int = 4, 12 | save_path: str = "data", 13 | ) -> torch.utils.data.DataLoader: 14 | """Load the MNIST data and returns the dataloaders (train ). The data is downloaded if it does not exist. 15 | 16 | Args: 17 | batch_size (int): The batch size. 18 | image_size (int): The image size. 19 | num_workers (int): The number of workers to use for the dataloader. 20 | save_path (str): The path to save the data to. 21 | 22 | Returns: 23 | torch.utils.data.DataLoader: The data loader. 
24 | """ 25 | 26 | # Load the data 27 | mnist_data = torchvision.datasets.MNIST( 28 | root=save_path, 29 | train=True, 30 | download=True, 31 | transform=torchvision.transforms.Compose( 32 | [ 33 | torchvision.transforms.Resize((image_size, image_size)), 34 | torchvision.transforms.Grayscale(num_output_channels=1), 35 | torchvision.transforms.ToTensor(), 36 | torchvision.transforms.Normalize((0.1307,), (0.3081,)), 37 | ] 38 | ), 39 | ) 40 | 41 | # Reduced set for faster training 42 | mnist_data_reduced = torch.utils.data.Subset(mnist_data, list(range(0, 800))) 43 | 44 | dataloader = torch.utils.data.DataLoader( 45 | mnist_data_reduced, 46 | batch_size=batch_size, 47 | shuffle=True, 48 | num_workers=num_workers, 49 | pin_memory=True, 50 | collate_fn=collate_fn, 51 | ) 52 | 53 | return dataloader 54 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: /mnt/Stuffs/Projects/fromScratch/pytorch_vqgan/env 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotlipy=0.7.0=py37h27cfd23_1003 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2022.07.19=h06a4308_0 13 | - certifi=2022.6.15=py37h06a4308_0 14 | - cffi=1.15.1=py37h74dc2b5_0 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cryptography=37.0.1=py37h9ce1e76_0 17 | - cudatoolkit=11.6.0=hecad31d_10 18 | - ffmpeg=4.2.2=h20bf706_0 19 | - freetype=2.11.0=h70c0345_0 20 | - giflib=5.2.1=h7b6447c_0 21 | - gmp=6.2.1=h295c915_3 22 | - gnutls=3.6.15=he1e5248_0 23 | - idna=3.3=pyhd3eb1b0_0 24 | - intel-openmp=2021.4.0=h06a4308_3561 25 | - jpeg=9e=h7f8727e_0 26 | - lame=3.100=h7b6447c_0 27 | - lcms2=2.12=h3be6417_0 28 | - ld_impl_linux-64=2.38=h1181459_1 29 | - libffi=3.3=he6710b0_2 30 | - libgcc-ng=11.2.0=h1234567_1 31 | - libgomp=11.2.0=h1234567_1 32 | - libidn2=2.3.2=h7f8727e_0 33 | - libopus=1.3.1=h7b6447c_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libstdcxx-ng=11.2.0=h1234567_1 36 | - libtasn1=4.16.0=h27cfd23_0 37 | - libtiff=4.2.0=h2818925_1 38 | - libunistring=0.9.10=h27cfd23_0 39 | - libvpx=1.7.0=h439df22_0 40 | - libwebp=1.2.2=h55f646e_0 41 | - libwebp-base=1.2.2=h7f8727e_0 42 | - lz4-c=1.9.3=h295c915_1 43 | - mkl=2021.4.0=h06a4308_640 44 | - mkl-service=2.4.0=py37h7f8727e_0 45 | - mkl_fft=1.3.1=py37hd3c417c_0 46 | - mkl_random=1.2.2=py37h51133e4_0 47 | - ncurses=6.3=h5eee18b_3 48 | - nettle=3.7.3=hbbd107a_1 49 | - numpy=1.21.5=py37h6c91a56_3 50 | - numpy-base=1.21.5=py37ha15fc14_3 51 | - openh264=2.1.1=h4ff587b_0 52 | - openssl=1.1.1q=h7f8727e_0 53 | - pillow=9.2.0=py37hace64e9_1 54 | - pip=22.1.2=py37h06a4308_0 55 | - pycparser=2.21=pyhd3eb1b0_0 56 | - pyopenssl=22.0.0=pyhd3eb1b0_0 57 | - pysocks=1.7.1=py37_1 58 | - python=3.7.13=h12debd9_0 59 | - python_abi=3.7=2_cp37m 60 | - pytorch=1.12.1=py3.7_cuda11.6_cudnn8.3.2_0 61 | - pytorch-mutex=1.0=cuda 62 | - readline=8.1.2=h7f8727e_1 63 | - requests=2.28.1=py37h06a4308_0 64 | - setuptools=61.2.0=py37h06a4308_0 65 | - six=1.16.0=pyhd3eb1b0_1 66 | - sqlite=3.39.2=h5082296_0 67 | - tk=8.6.12=h1ccaba5_0 68 | - torchaudio=0.12.1=py37_cu116 69 | - torchvision=0.13.1=py37_cu116 70 | - typing_extensions=4.3.0=py37h06a4308_0 71 | - urllib3=1.26.11=py37h06a4308_0 72 | - wheel=0.37.1=pyhd3eb1b0_0 73 | - x264=1!157.20191217=h7b6447c_0 74 | - xz=5.2.5=h7f8727e_1 75 | - zlib=1.2.12=h7f8727e_2 76 | - zstd=1.5.2=ha4553b6_0 77 | - pip: 78 | - aim==3.12.2 79 | - 
aim-ui==3.12.2 80 | - aimrecords==0.0.7 81 | - aimrocks==0.2.1 82 | - aiofiles==0.8.0 83 | - albumentations==1.2.1 84 | - alembic==1.8.1 85 | - attrs==22.1.0 86 | - base58==2.0.1 87 | - beautifulsoup4==4.11.1 88 | - cachetools==5.2.0 89 | - click==8.1.3 90 | - codecov==2.1.12 91 | - coverage==6.4.3 92 | - exceptiongroup==1.0.0rc8 93 | - fastapi==0.67.0 94 | - filelock==3.8.0 95 | - google-api-core==2.8.2 96 | - google-api-python-client==2.56.0 97 | - google-auth==2.10.0 98 | - google-auth-httplib2==0.1.0 99 | - googleapis-common-protos==1.56.4 100 | - greenlet==1.1.2 101 | - grpcio==1.47.0 102 | - h11==0.13.0 103 | - httplib2==0.20.4 104 | - hypothesis==6.54.3 105 | - imageio==2.21.1 106 | - importlib-metadata==4.12.0 107 | - importlib-resources==5.9.0 108 | - iniconfig==1.1.1 109 | - jinja2==3.1.2 110 | - joblib==1.1.0 111 | - lpips==0.1.4 112 | - mako==1.2.1 113 | - markupsafe==2.1.1 114 | - networkx==2.6.3 115 | - opencv-python-headless==4.6.0.66 116 | - packaging==21.3 117 | - pandas==1.3.5 118 | - pluggy==1.0.0 119 | - protobuf==3.20.0 120 | - psutil==5.9.1 121 | - py==1.11.0 122 | - py3nvml==0.2.7 123 | - pyasn1==0.4.8 124 | - pyasn1-modules==0.2.8 125 | - pydantic==1.9.2 126 | - pyparsing==3.0.9 127 | - pytest==7.1.2 128 | - pytest-cov==3.0.0 129 | - python-dateutil==2.8.2 130 | - pytz==2022.2.1 131 | - pywavelets==1.3.0 132 | - pyyaml==6.0 133 | - qudida==0.0.4 134 | - restrictedpython==5.2 135 | - rsa==4.9 136 | - scikit-image==0.19.3 137 | - scikit-learn==1.0.2 138 | - scipy==1.7.3 139 | - sortedcontainers==2.4.0 140 | - soupsieve==2.3.2.post1 141 | - sqlalchemy==1.4.40 142 | - starlette==0.14.2 143 | - threadpoolctl==3.1.0 144 | - tifffile==2021.11.2 145 | - tomli==2.0.1 146 | - torch-summary==1.4.5 147 | - tqdm==4.64.0 148 | - uritemplate==4.1.1 149 | - uvicorn==0.18.2 150 | - xmltodict==0.13.0 151 | - zipp==3.8.1 152 | prefix: /mnt/Stuffs/Projects/fromScratch/pytorch_vqgan/env 153 | -------------------------------------------------------------------------------- /experiments/generated_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_0.jpg -------------------------------------------------------------------------------- /experiments/generated_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_1.jpg -------------------------------------------------------------------------------- /experiments/generated_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_2.jpg -------------------------------------------------------------------------------- /experiments/generated_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_3.jpg -------------------------------------------------------------------------------- /experiments/generated_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_4.jpg 
-------------------------------------------------------------------------------- /experiments/reconstruction.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/reconstruction.gif -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import argparse 3 | 4 | import yaml 5 | 6 | from trainer import Trainer 7 | from transformer import VQGANTransformer 8 | from vqgan import VQGAN 9 | 10 | 11 | def main(args, config): 12 | 13 | vqgan = VQGAN(**config["architecture"]["vqgan"]) 14 | vqgan.load_checkpoint("./experiments/checkpoints/vqgan.pt") 15 | 16 | transformer = VQGANTransformer( 17 | vqgan, device=args.device, **config["architecture"]["transformer"] 18 | ) 19 | transformer.load_checkpoint("./experiments/checkpoints/transformer.pt") 20 | 21 | trainer = Trainer( 22 | vqgan, 23 | transformer, 24 | run=None, 25 | config=config["trainer"], 26 | seed=args.seed, 27 | device=args.device, 28 | ) 29 | 30 | trainer.generate_images() 31 | 32 | 33 | if __name__ == "__main__": 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--config_path", 38 | type=str, 39 | default="configs/default.yml", 40 | help="path to config file", 41 | ) 42 | parser.add_argument( 43 | "--dataset_name", 44 | type=str, 45 | choices=["mnist", "cifar", "custom"], 46 | default="mnist", 47 | help="Dataset for the model", 48 | ) 49 | parser.add_argument( 50 | "--device", 51 | type=str, 52 | default="cuda", 53 | choices=["cpu", "cuda"], 54 | help="Device to train the model on", 55 | ) 56 | parser.add_argument( 57 | "--seed", 58 | type=str, 59 | default=42, 60 | help="Seed for Reproducibility", 61 | ) 62 | 63 | args = parser.parse_args() 64 | 65 | args = parser.parse_args() 66 | with open(args.config_path) as f: 67 | config = yaml.load(f, Loader=yaml.FullLoader) 68 | 69 | main(args, config) 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aim==3.12.2 2 | aim-ui==3.12.2 3 | aimrecords==0.0.7 4 | aimrocks==0.2.1 5 | aiofiles==0.8.0 6 | albumentations==1.2.1 7 | alembic==1.8.1 8 | attrs==22.1.0 9 | base58==2.0.1 10 | beautifulsoup4==4.11.1 11 | brotlipy==0.7.0 12 | cachetools==5.2.0 13 | certifi==2022.6.15 14 | cffi==1.15.1 15 | charset-normalizer==2.0.4 16 | click==8.1.3 17 | codecov==2.1.12 18 | coverage==6.4.3 19 | cryptography==37.0.1 20 | exceptiongroup==1.0.0rc8 21 | fastapi==0.67.0 22 | filelock==3.8.0 23 | google-api-core==2.8.2 24 | google-api-python-client==2.56.0 25 | google-auth==2.10.0 26 | google-auth-httplib2==0.1.0 27 | googleapis-common-protos==1.56.4 28 | greenlet==1.1.2 29 | grpcio==1.47.0 30 | h11==0.13.0 31 | httplib2==0.20.4 32 | hypothesis==6.54.3 33 | idna==3.3 34 | imageio==2.21.1 35 | importlib-metadata==4.12.0 36 | importlib-resources==5.9.0 37 | iniconfig==1.1.1 38 | Jinja2==3.1.2 39 | joblib==1.1.1 40 | lpips==0.1.4 41 | Mako==1.2.1 42 | MarkupSafe==2.1.1 43 | mkl-fft==1.3.1 44 | mkl-random==1.2.2 45 | mkl-service==2.4.0 46 | networkx==2.6.3 47 | numpy==1.21.5 48 | opencv-python-headless==4.6.0.66 49 | packaging==21.3 50 | pandas==1.3.5 51 | Pillow==9.2.0 52 | pip==22.1.2 53 | pluggy==1.0.0 54 | protobuf==3.20.0 55 | psutil==5.9.1 56 | py==1.11.0 57 | 
py3nvml==0.2.7 58 | pyasn1==0.4.8 59 | pyasn1-modules==0.2.8 60 | pycparser==2.21 61 | pydantic==1.9.2 62 | pyOpenSSL==22.0.0 63 | pyparsing==3.0.9 64 | PySocks==1.7.1 65 | pytest==7.1.2 66 | pytest-cov==3.0.0 67 | python-dateutil==2.8.2 68 | pytz==2022.2.1 69 | PyWavelets==1.3.0 70 | PyYAML==6.0 71 | qudida==0.0.4 72 | requests==2.28.1 73 | RestrictedPython==5.2 74 | rsa==4.9 75 | scikit-image==0.19.3 76 | scikit-learn==1.0.2 77 | scipy==1.7.3 78 | setuptools==61.2.0 79 | six==1.16.0 80 | sortedcontainers==2.4.0 81 | soupsieve==2.3.2.post1 82 | SQLAlchemy==1.4.40 83 | starlette==0.25.0 84 | threadpoolctl==3.1.0 85 | tifffile==2021.11.2 86 | tomli==2.0.1 87 | torch==1.12.1 88 | torch-summary==1.4.5 89 | torchaudio==0.12.1 90 | torchvision==0.13.1 91 | tqdm==4.64.0 92 | typing_extensions==4.3.0 93 | uritemplate==4.1.1 94 | urllib3==1.26.11 95 | uvicorn==0.18.2 96 | wheel==0.37.1 97 | xmltodict==0.13.0 98 | zipp==3.8.1 99 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/test/__init__.py -------------------------------------------------------------------------------- /test/test_dataloader.py: -------------------------------------------------------------------------------- 1 | from dataloader import load_mnist 2 | 3 | 4 | def test_load_mnist(): 5 | 6 | dataloader = load_mnist(batch_size=16) 7 | 8 | for imgs in dataloader: 9 | break 10 | 11 | assert imgs.shape == (16, 1, 256, 256) -------------------------------------------------------------------------------- /test/test_vqgan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains test functions for the VQGAN model. 
3 | """ 4 | 5 | # Importing Libraries 6 | import torch 7 | 8 | from vqgan import Encoder, Decoder, CodeBook, Discriminator 9 | 10 | 11 | def test_encoder(): 12 | 13 | image_channels = 3 14 | image_size = 256 15 | latent_channels = 256 16 | attn_resolution = 16 17 | 18 | image = torch.randn(1, image_channels, image_size, image_size) 19 | 20 | model = Encoder( 21 | img_channels=image_channels, 22 | image_size=image_size, 23 | latent_channels=latent_channels, 24 | attention_resolution=[attn_resolution], 25 | ) 26 | 27 | output = model(image) 28 | 29 | assert output.shape == ( 30 | 1, 31 | latent_channels, 32 | attn_resolution, 33 | attn_resolution, 34 | ), "Output of encoder does not match" 35 | 36 | 37 | def test_decoder(): 38 | 39 | img_channels = 3 40 | img_size = 256 41 | latent_channels = 256 42 | latent_size = 16 43 | attn_resolution = 16 44 | 45 | latent = torch.randn(1, latent_channels, latent_size, latent_size) 46 | model = Decoder( 47 | img_channels=img_channels, 48 | latent_size=latent_size, 49 | latent_channels=latent_channels, 50 | attention_resolution=[attn_resolution], 51 | ) 52 | 53 | output = model(latent) 54 | 55 | assert output.shape == ( 56 | 1, 57 | img_channels, 58 | img_size, 59 | img_size, 60 | ), "Output of decoder does not match" 61 | 62 | 63 | def test_codebook(): 64 | 65 | z = torch.randn(1, 256, 16, 16) 66 | 67 | codebok = CodeBook(num_codebook_vectors=100, latent_dim=16) 68 | 69 | z_q, min_distance_indices, loss = codebok(z) 70 | 71 | assert z_q.shape == (1, 256, 16, 16), "Output of codebook does not match" 72 | 73 | 74 | def test_discriminator(): 75 | 76 | image = torch.randn(1, 3, 256, 256) 77 | 78 | model = Discriminator() 79 | 80 | 81 | output = model(image) 82 | 83 | assert output.shape == (1, 1, 30, 30), "Output of discriminator does not match" 84 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import argparse 3 | 4 | import yaml 5 | from aim import Run 6 | 7 | from dataloader import load_dataloader 8 | from trainer import Trainer 9 | from transformer import VQGANTransformer 10 | from vqgan import VQGAN 11 | 12 | 13 | def main(args, config): 14 | 15 | vqgan = VQGAN(**config["architecture"]["vqgan"]) 16 | transformer = VQGANTransformer( 17 | vqgan, **config["architecture"]["transformer"], device=args.device 18 | ) 19 | dataloader = load_dataloader(name=args.dataset_name) 20 | 21 | run = Run(experiment=args.dataset_name) 22 | run["hparams"] = config 23 | 24 | trainer = Trainer( 25 | vqgan, 26 | transformer, 27 | run=run, 28 | config=config["trainer"], 29 | seed=args.seed, 30 | device=args.device, 31 | ) 32 | 33 | trainer.train_vqgan(dataloader) 34 | trainer.train_transformers(dataloader) 35 | trainer.generate_images() 36 | 37 | 38 | if __name__ == "__main__": 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | "--config_path", 43 | type=str, 44 | default="configs/default.yml", 45 | help="path to config file", 46 | ) 47 | parser.add_argument( 48 | "--dataset_name", 49 | type=str, 50 | choices=["mnist", "cifar", "custom"], 51 | default="mnist", 52 | help="Dataset for the model", 53 | ) 54 | parser.add_argument( 55 | "--device", 56 | type=str, 57 | default="cuda", 58 | choices=["cpu", "cuda"], 59 | help="Device to train the model on", 60 | ) 61 | parser.add_argument( 62 | "--seed", 63 | type=str, 64 | default=42, 65 | help="Seed for Reproducibility", 66 | ) 67 | 68 | args = 
parser.parse_args() 69 | 70 | args = parser.parse_args() 71 | with open(args.config_path) as f: 72 | config = yaml.load(f, Loader=yaml.FullLoader) 73 | 74 | main(args, config) 75 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .vqgan import VQGANTrainer 2 | from .transformer import TransformerTrainer 3 | from .trainer import Trainer -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import os 3 | 4 | import torch 5 | import torchvision 6 | from aim import Run 7 | from utils import reproducibility 8 | 9 | from trainer import TransformerTrainer, VQGANTrainer 10 | 11 | 12 | class Trainer: 13 | def __init__( 14 | self, 15 | vqgan: torch.nn.Module, 16 | transformer: torch.nn.Module, 17 | run: Run, 18 | config: dict, 19 | experiment_dir: str = "experiments", 20 | seed: int = 42, 21 | device: str = "cuda", 22 | ) -> None: 23 | 24 | self.vqgan = vqgan 25 | self.transformer = transformer 26 | 27 | self.run = run 28 | self.config = config 29 | self.experiment_dir = experiment_dir 30 | self.seed = seed 31 | self.device = device 32 | 33 | print(f"[INFO] Setting seed to {seed}") 34 | reproducibility(seed) 35 | 36 | print(f"[INFO] Results will be saved in {experiment_dir}") 37 | self.experiment_dir = experiment_dir 38 | 39 | def train_vqgan(self, dataloader: torch.utils.data.DataLoader, epochs: int = 1): 40 | 41 | print(f"[INFO] Training VQGAN on {self.device} for {epochs} epoch(s).") 42 | 43 | self.vqgan.to(self.device) 44 | 45 | self.vqgan_trainer = VQGANTrainer( 46 | model=self.vqgan, 47 | run=self.run, 48 | device=self.device, 49 | experiment_dir=self.experiment_dir, 50 | **self.config["vqgan"], 51 | ) 52 | 53 | self.vqgan_trainer.train( 54 | dataloader=dataloader, 55 | epochs=epochs, 56 | ) 57 | 58 | # Saving the model 59 | self.vqgan.save_checkpoint( 60 | os.path.join(self.experiment_dir, "checkpoints", "vqgan.pt") 61 | ) 62 | 63 | def train_transformers( 64 | self, dataloader: torch.utils.data.DataLoader, epochs: int = 1 65 | ): 66 | 67 | print(f"[INFO] Training Transformer on {self.device} for {epochs} epoch(s).") 68 | 69 | self.vqgan.eval() 70 | self.transformer = self.transformer.to(self.device) 71 | 72 | self.transformer_trainer = TransformerTrainer( 73 | model=self.transformer, 74 | run=self.run, 75 | device=self.device, 76 | experiment_dir=self.experiment_dir, 77 | **self.config["transformer"], 78 | ) 79 | 80 | self.transformer_trainer.train(dataloader=dataloader, epochs=epochs) 81 | 82 | self.transformer.save_checkpoint( 83 | os.path.join(self.experiment_dir, "checkpoints", "transformer.pt") 84 | ) 85 | 86 | def generate_images(self, n_images: int = 5): 87 | 88 | print(f"[INFO] Generating {n_images} images...") 89 | 90 | self.vqgan.to(self.device) 91 | self.transformer = self.transformer.to(self.device) 92 | 93 | 94 | for i in range(n_images): 95 | start_indices = torch.zeros((4, 0)).long().to(self.device) 96 | sos_tokens = torch.ones(start_indices.shape[0], 1) * 0 97 | 98 | sos_tokens = sos_tokens.long().to(self.device) 99 | sample_indices = self.transformer.sample( 100 | start_indices, sos_tokens, steps=256 101 | ) 102 | sampled_imgs = self.transformer.z_to_image(sample_indices) 103 | torchvision.utils.save_image( 104 | sampled_imgs, 105 | os.path.join(self.experiment_dir, 
f"generated_{i}.jpg"), 106 | nrow=4, 107 | ) 108 | 109 | -------------------------------------------------------------------------------- /trainer/transformer.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import torchvision 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from aim import Run, Image 8 | 9 | 10 | class TransformerTrainer: 11 | def __init__( 12 | self, 13 | model: nn.Module, 14 | run: Run, 15 | experiment_dir: str = "experiments", 16 | device: str = "cuda", 17 | learning_rate: float = 4.5e-06, 18 | beta1: float = 0.9, 19 | beta2: float = 0.95, 20 | ): 21 | self.run = run 22 | self.experiment_dir = experiment_dir 23 | 24 | self.model = model 25 | self.device = device 26 | self.optim = self.configure_optimizers( 27 | learning_rate=learning_rate, beta1=beta1, beta2=beta2 28 | ) 29 | 30 | def configure_optimizers( 31 | self, learning_rate: float = 4.5e-06, beta1: float = 0.9, beta2: float = 0.95 32 | ): 33 | decay, no_decay = set(), set() 34 | whitelist_weight_modules = (nn.Linear,) 35 | blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) 36 | 37 | # Enabling weight decay to only certain layers 38 | for mn, m in self.model.transformer.named_modules(): 39 | for pn, p in m.named_parameters(): 40 | fpn = f"{mn}.{pn}" if mn else pn 41 | 42 | if pn.endswith("bias"): 43 | no_decay.add(fpn) 44 | 45 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 46 | decay.add(fpn) 47 | 48 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 49 | no_decay.add(fpn) 50 | 51 | no_decay.add("pos_emb") 52 | 53 | param_dict = {pn: p for pn, p in self.model.transformer.named_parameters()} 54 | 55 | optim_groups = [ 56 | { 57 | "params": [param_dict[pn] for pn in sorted(list(decay))], 58 | "weight_decay": 0.01, 59 | }, 60 | { 61 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 62 | "weight_decay": 0.0, 63 | }, 64 | ] 65 | 66 | optimizer = torch.optim.AdamW( 67 | optim_groups, lr=learning_rate, betas=(beta1, beta2) 68 | ) 69 | return optimizer 70 | 71 | def train(self, dataloader: torch.utils.data.DataLoader, epochs: int): 72 | for epoch in range(epochs): 73 | 74 | for index, imgs in enumerate(dataloader): 75 | self.optim.zero_grad() 76 | imgs = imgs.to(device=self.device) 77 | logits, targets = self.model(imgs) 78 | loss = F.cross_entropy( 79 | logits.reshape(-1, logits.size(-1)), targets.reshape(-1) 80 | ) 81 | loss.backward() 82 | self.optim.step() 83 | 84 | self.run.track( 85 | loss, 86 | name="Cross Entropy Loss", 87 | step=index, 88 | context={"stage": "transformer"}, 89 | ) 90 | 91 | if index % 10 == 0: 92 | print( 93 | f"Epoch: {epoch+1}/{epochs} | Batch: {index}/{len(dataloader)} | Cross Entropy Loss : {loss:.4f}" 94 | ) 95 | 96 | _, sampled_imgs = self.model.log_images(imgs[0][None]) 97 | 98 | self.run.track( 99 | Image( 100 | torchvision.utils.make_grid(sampled_imgs) 101 | .mul(255) 102 | .add_(0.5) 103 | .clamp_(0, 255) 104 | .permute(1, 2, 0) 105 | .to("cpu", torch.uint8) 106 | .numpy() 107 | ), 108 | name="Transformer Images", 109 | step=index, 110 | context={"stage": "transformer"}, 111 | ) 112 | -------------------------------------------------------------------------------- /trainer/vqgan.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/dome272/VQGAN-pytorch/blob/main/training_vqgan.py 3 | """ 4 | 5 | # Importing Libraries 6 | import os 7 | 8 | import imageio 9 
| import lpips 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import torchvision 14 | from aim import Image, Run 15 | from utils import weights_init 16 | from vqgan import Discriminator 17 | 18 | 19 | class VQGANTrainer: 20 | """Trainer class for VQGAN, contains step, train methods""" 21 | 22 | def __init__( 23 | self, 24 | model: torch.nn.Module, 25 | run: Run, 26 | # Training parameters 27 | device: str or torch.device = "cuda", 28 | learning_rate: float = 2.25e-05, 29 | beta1: float = 0.5, 30 | beta2: float = 0.9, 31 | # Loss parameters 32 | perceptual_loss_factor: float = 1.0, 33 | rec_loss_factor: float = 1.0, 34 | # Discriminator parameters 35 | disc_factor: float = 1.0, 36 | disc_start: int = 100, 37 | # Miscellaneous parameters 38 | experiment_dir: str = "./experiments", 39 | perceptual_model: str = "vgg", 40 | save_every: int = 10, 41 | ): 42 | 43 | self.run = run 44 | self.device = device 45 | 46 | # VQGAN parameters 47 | self.vqgan = model 48 | 49 | # Discriminator parameters 50 | self.discriminator = Discriminator(image_channels=self.vqgan.img_channels).to( 51 | self.device 52 | ) 53 | self.discriminator.apply(weights_init) 54 | 55 | # Loss parameters 56 | self.perceptual_loss = lpips.LPIPS(net=perceptual_model).to(self.device) 57 | 58 | # Optimizers 59 | self.opt_vq, self.opt_disc = self.configure_optimizers( 60 | learning_rate=learning_rate, beta1=beta1, beta2=beta2 61 | ) 62 | 63 | # Hyperprameters 64 | self.disc_factor = disc_factor 65 | self.disc_start = disc_start 66 | self.perceptual_loss_factor = perceptual_loss_factor 67 | self.rec_loss_factor = rec_loss_factor 68 | 69 | # Save directory 70 | self.expriment_save_dir = experiment_dir 71 | 72 | # Miscellaneous 73 | self.global_step = 0 74 | self.sample_batch = None 75 | self.gif_images = [] 76 | self.save_every = save_every 77 | 78 | def configure_optimizers( 79 | self, learning_rate: float = 2.25e-05, beta1: float = 0.5, beta2: float = 0.9 80 | ): 81 | opt_vq = torch.optim.Adam( 82 | list(self.vqgan.encoder.parameters()) 83 | + list(self.vqgan.decoder.parameters()) 84 | + list(self.vqgan.codebook.parameters()) 85 | + list(self.vqgan.quant_conv.parameters()) 86 | + list(self.vqgan.post_quant_conv.parameters()), 87 | lr=learning_rate, 88 | eps=1e-08, 89 | betas=(beta1, beta2), 90 | ) 91 | opt_disc = torch.optim.Adam( 92 | self.discriminator.parameters(), 93 | lr=learning_rate, 94 | eps=1e-08, 95 | betas=(beta1, beta2), 96 | ) 97 | 98 | return opt_vq, opt_disc 99 | 100 | def step(self, imgs: torch.Tensor) -> torch.Tensor: 101 | """Performs a single training step from the dataloader images batch 102 | 103 | For the VQGAN, it calculates the perceptual loss, reconstruction loss, and the codebook loss and does the backward pass. 104 | 105 | For the discriminator, it calculates lambda for the discriminator loss and does the backward pass. 
106 | 107 | Args: 108 | imgs: input tensor of shape (batch_size, channel, H, W) 109 | 110 | Returns: 111 | decoded_imgs: output tensor of shape (batch_size, channel, H, W) 112 | """ 113 | 114 | # Getting decoder output 115 | decoded_images, _, q_loss = self.vqgan(imgs) 116 | 117 | """ 118 | ======================================================================================================================= 119 | VQ Loss 120 | """ 121 | perceptual_loss = self.perceptual_loss(imgs, decoded_images) 122 | rec_loss = torch.abs(imgs - decoded_images) 123 | perceptual_rec_loss = ( 124 | self.perceptual_loss_factor * perceptual_loss 125 | + self.rec_loss_factor * rec_loss 126 | ) 127 | perceptual_rec_loss = perceptual_rec_loss.mean() 128 | 129 | """ 130 | ======================================================================================================================= 131 | Discriminator Loss 132 | """ 133 | disc_real = self.discriminator(imgs) 134 | disc_fake = self.discriminator(decoded_images) 135 | 136 | disc_factor = self.vqgan.adopt_weight( 137 | self.disc_factor, self.global_step, threshold=self.disc_start 138 | ) 139 | 140 | g_loss = -torch.mean(disc_fake) 141 | 142 | λ = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss) 143 | vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss 144 | 145 | d_loss_real = torch.mean(F.relu(1.0 - disc_real)) 146 | d_loss_fake = torch.mean(F.relu(1.0 + disc_fake)) 147 | gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake) 148 | 149 | # ====================================================================================================================== 150 | # Tracking metrics 151 | 152 | self.run.track( 153 | perceptual_rec_loss, 154 | name="Perceptual & Reconstruction loss", 155 | step=self.global_step, 156 | context={"stage": "vqgan"}, 157 | ) 158 | 159 | self.run.track( 160 | vq_loss, name="VQ Loss", step=self.global_step, context={"stage": "vqgan"} 161 | ) 162 | self.run.track( 163 | gan_loss, name="GAN Loss", step=self.global_step, context={"stage": "vqgan"} 164 | ) 165 | 166 | # ======================================================================================================================= 167 | # Backpropagation 168 | 169 | self.opt_vq.zero_grad() 170 | vq_loss.backward( 171 | retain_graph=True 172 | ) # retain_graph is used to retain the computation graph for the discriminator loss 173 | 174 | self.opt_disc.zero_grad() 175 | gan_loss.backward() 176 | 177 | self.opt_vq.step() 178 | self.opt_disc.step() 179 | 180 | return decoded_images, vq_loss, gan_loss 181 | 182 | def train( 183 | self, 184 | dataloader: torch.utils.data.DataLoader, 185 | epochs: int = 1, 186 | ): 187 | """Trains the VQGAN for the given number of epochs 188 | 189 | Args: 190 | dataloader (torch.utils.data.DataLoader): dataloader to use. 191 | epochs (int, optional): number of epochs to train for. Defaults to 100. 
192 | """ 193 | 194 | for epoch in range(epochs): 195 | for index, imgs in enumerate(dataloader): 196 | 197 | # Training step 198 | imgs = imgs.to(self.device) 199 | 200 | decoded_images, vq_loss, gan_loss = self.step(imgs) 201 | 202 | # Updating global step 203 | self.global_step += 1 204 | 205 | if index % self.save_every == 0: 206 | 207 | print( 208 | f"Epoch: {epoch+1}/{epochs} | Batch: {index}/{len(dataloader)} | VQ Loss : {vq_loss:.4f} | Discriminator Loss: {gan_loss:.4f}" 209 | ) 210 | 211 | # Only saving the gif for the first 2000 save steps 212 | if self.global_step // self.save_every <= 2000: 213 | self.sample_batch = ( 214 | imgs[:] if self.sample_batch is None else self.sample_batch 215 | ) 216 | 217 | with torch.no_grad(): 218 | 219 | """ 220 | Note : Lots of efficiency & cleaning needed here 221 | """ 222 | 223 | gif_img = ( 224 | torchvision.utils.make_grid( 225 | torch.cat( 226 | ( 227 | self.sample_batch, 228 | self.vqgan(self.sample_batch)[0], 229 | ), 230 | ) 231 | ) 232 | .detach() 233 | .cpu() 234 | .permute(1, 2, 0) 235 | .numpy() 236 | ) 237 | 238 | gif_img = (gif_img - gif_img.min()) * ( 239 | 255 / (gif_img.max() - gif_img.min()) 240 | ) 241 | gif_img = gif_img.astype(np.uint8) 242 | 243 | self.run.track( 244 | Image( 245 | torchvision.utils.make_grid( 246 | torch.cat( 247 | ( 248 | imgs, 249 | decoded_images, 250 | ), 251 | ) 252 | ).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 253 | ), 254 | name="VQGAN Reconstruction", 255 | step=self.global_step, 256 | context={"stage": "vqgan"}, 257 | ) 258 | 259 | self.gif_images.append(gif_img) 260 | 261 | imageio.mimsave( 262 | os.path.join(self.expriment_save_dir, "reconstruction.gif"), 263 | self.gif_images, 264 | fps=5, 265 | ) 266 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import VQGANTransformer -------------------------------------------------------------------------------- /transformer/mingpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | from : https://github.com/dome272/VQGAN-pytorch/blob/main/mingpt.py 3 | which is taken from: https://github.com/karpathy/minGPT/ 4 | GPT model: 5 | - the initial stem consists of a combination of token encoding and a positional encoding 6 | - the meat of it is a uniform sequence of Transformer blocks 7 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 8 | - all blocks feed into a central residual pathway similar to resnets 9 | - the final decoder is a linear projection into a vanilla Softmax classifier 10 | """ 11 | 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | class GPTConfig: 19 | """ base GPT config, params common to all GPT versions """ 20 | embd_pdrop = 0.1 21 | resid_pdrop = 0.1 22 | attn_pdrop = 0.1 23 | 24 | def __init__(self, vocab_size, block_size, **kwargs): 25 | self.vocab_size = vocab_size 26 | self.block_size = block_size 27 | for k, v in kwargs.items(): 28 | setattr(self, k, v) 29 | 30 | 31 | class CausalSelfAttention(nn.Module): 32 | """ 33 | A vanilla multi-head masked self-attention layer with a projection at the end. 34 | It is possible to use torch.nn.MultiheadAttention here but I am including an 35 | explicit implementation here to show that there is nothing too scary here. 
36 | """ 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | assert config.n_embd % config.n_head == 0 41 | # key, query, value projections for all heads 42 | self.key = nn.Linear(config.n_embd, config.n_embd) 43 | self.query = nn.Linear(config.n_embd, config.n_embd) 44 | self.value = nn.Linear(config.n_embd, config.n_embd) 45 | # regularization 46 | self.attn_drop = nn.Dropout(config.attn_pdrop) 47 | self.resid_drop = nn.Dropout(config.resid_pdrop) 48 | # output projection 49 | self.proj = nn.Linear(config.n_embd, config.n_embd) 50 | # causal mask to ensure that attention is only applied to the left in the input sequence 51 | mask = torch.tril(torch.ones(config.block_size, 52 | config.block_size)) 53 | if hasattr(config, "n_unmasked"): 54 | mask[:config.n_unmasked, :config.n_unmasked] = 1 55 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 56 | self.n_head = config.n_head 57 | 58 | def forward(self, x, layer_past=None): 59 | B, T, C = x.size() 60 | 61 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 62 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 63 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 64 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 65 | 66 | present = torch.stack((k, v)) 67 | if layer_past is not None: 68 | past_key, past_value = layer_past 69 | k = torch.cat((past_key, k), dim=-2) 70 | v = torch.cat((past_value, v), dim=-2) 71 | 72 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 73 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 74 | if layer_past is None: 75 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 76 | 77 | att = F.softmax(att, dim=-1) 78 | att = self.attn_drop(att) 79 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 80 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 81 | 82 | # output projection 83 | y = self.resid_drop(self.proj(y)) 84 | return y, present # TODO: check that this does not break anything 85 | 86 | 87 | class Block(nn.Module): 88 | """ an unassuming Transformer block """ 89 | 90 | def __init__(self, config): 91 | super().__init__() 92 | self.ln1 = nn.LayerNorm(config.n_embd) 93 | self.ln2 = nn.LayerNorm(config.n_embd) 94 | self.attn = CausalSelfAttention(config) 95 | self.mlp = nn.Sequential( 96 | nn.Linear(config.n_embd, 4 * config.n_embd), 97 | nn.GELU(), # nice 98 | nn.Linear(4 * config.n_embd, config.n_embd), 99 | nn.Dropout(config.resid_pdrop), 100 | ) 101 | 102 | def forward(self, x, layer_past=None, return_present=False): 103 | # TODO: check that training still works 104 | if return_present: 105 | assert not self.training 106 | # layer past: tuple of length two with B, nh, T, hs 107 | attn, present = self.attn(self.ln1(x), layer_past=layer_past) 108 | 109 | x = x + attn 110 | x = x + self.mlp(self.ln2(x)) 111 | if layer_past is not None or return_present: 112 | return x, present 113 | return x 114 | 115 | 116 | class GPT(nn.Module): 117 | """ the full GPT language model, with a context size of block_size """ 118 | 119 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, 120 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): 121 | super().__init__() 122 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 123 | 
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 124 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 125 | n_unmasked=n_unmasked) 126 | # input embedding stem 127 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 128 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) # 512 x 1024 129 | self.drop = nn.Dropout(config.embd_pdrop) 130 | # transformer 131 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 132 | # decoder head 133 | self.ln_f = nn.LayerNorm(config.n_embd) 134 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 135 | 136 | self.block_size = config.block_size 137 | self.apply(self._init_weights) 138 | self.config = config 139 | 140 | def get_block_size(self): 141 | return self.block_size 142 | 143 | def _init_weights(self, module): 144 | if isinstance(module, (nn.Linear, nn.Embedding)): 145 | module.weight.data.normal_(mean=0.0, std=0.02) 146 | if isinstance(module, nn.Linear) and module.bias is not None: 147 | module.bias.data.zero_() 148 | elif isinstance(module, nn.LayerNorm): 149 | module.bias.data.zero_() 150 | module.weight.data.fill_(1.0) 151 | 152 | def forward(self, idx, embeddings=None): 153 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 154 | 155 | if embeddings is not None: # prepend explicit embeddings 156 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 157 | 158 | t = token_embeddings.shape[1] 159 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 160 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 161 | x = self.drop(token_embeddings + position_embeddings) 162 | x = self.blocks(x) 163 | x = self.ln_f(x) 164 | logits = self.head(x) 165 | 166 | return logits, None 167 | 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /transformer/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/dome272/VQGAN-pytorch/blob/main/transformer.py 3 | """ 4 | 5 | # Importing Libraries 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from transformer.mingpt import GPT 10 | 11 | 12 | class VQGANTransformer(nn.Module): 13 | def __init__( 14 | self, 15 | vqgan: nn.Module, 16 | device: str = "cuda", 17 | sos_token: int = 0, 18 | pkeep: float = 0.5, 19 | block_size: int = 512, 20 | n_layer: int = 12, 21 | n_head: int = 16, 22 | n_embd: int = 1024, 23 | ): 24 | super().__init__() 25 | 26 | self.sos_token = sos_token 27 | self.device = device 28 | 29 | self.vqgan = vqgan 30 | 31 | self.transformer = GPT( 32 | vocab_size=self.vqgan.num_codebook_vectors, 33 | block_size=block_size, 34 | n_layer=n_layer, 35 | n_head=n_head, 36 | n_embd=n_embd, 37 | ) 38 | 39 | self.pkeep = pkeep 40 | 41 | @torch.no_grad() 42 | def encode_to_z(self, x: torch.tensor) -> torch.tensor: 43 | """Processes the input batch ( containing images ) to encoder and returning flattened quantized encodings 44 | 45 | Args: 46 | x (torch.tensor): the input batch b*c*h*w 47 | 48 | Returns: 49 | torch.tensor: the flattened quantized encodings 50 | """ 51 | quant_z, indices, _ = self.vqgan.encode(x) 52 | indices = indices.view(quant_z.shape[0], -1) 53 | return quant_z, indices 54 | 55 | @torch.no_grad() 56 | def z_to_image( 57 | self, indices: torch.tensor, p1: int = 16, p2: int = 16 58 | ) -> torch.Tensor: 59 | """Returns the 
decoded image from the indices for the codebook embeddings 60 | 61 | Args: 62 | indices (torch.tensor): the indices of the vectors in codebook to use for generating the decoder output 63 | p1 (int, optional): encoding size. Defaults to 16. 64 | p2 (int, optional): encoding size. Defaults to 16. 65 | 66 | Returns: 67 | torch.tensor: generated image from decoder 68 | """ 69 | 70 | ix_to_vectors = self.vqgan.codebook.codebook(indices).reshape( 71 | indices.shape[0], p1, p2, 256 72 | ) 73 | ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2) 74 | image = self.vqgan.decode(ix_to_vectors) 75 | return image 76 | 77 | def forward(self, x:torch.Tensor) -> torch.Tensor: 78 | """ 79 | transformer model forward pass 80 | 81 | Args: 82 | x (torch.tensor): Batch of images 83 | """ 84 | 85 | # Getting the codebook indices of the image 86 | _, indices = self.encode_to_z(x) 87 | 88 | # sos tokens, this will be needed when we will generate new and unseen images 89 | sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token 90 | sos_tokens = sos_tokens.long().to(self.device) 91 | 92 | # Generating a matrix of shape indices with 1s and 0s 93 | mask = torch.bernoulli( 94 | self.pkeep * torch.ones(indices.shape, device=indices.device) 95 | ) # torch.bernoulli([0.5 ... 0.5]) -> [1, 0, 1, 1, 0, 0] ; p(1) - 0.5 96 | mask = mask.round().to(dtype=torch.int64) 97 | 98 | # Generate a vector containing randomlly indices 99 | random_indices = torch.randint_like( 100 | indices, high=self.transformer.config.vocab_size 101 | ) # generating indices from the distribution 102 | 103 | """ 104 | indices - [3, 56, 72, 67, 45, 53, 78, 90] 105 | mask - [1, 1, 0, 0, 1, 1, 1, 0] 106 | random_indices - 15, 67, 27, 89, 92, 40, 91, 10] 107 | 108 | new_indices - [ 3, 56, 0, 0, 45, 53, 78, 0] + [ 0, 0, 27, 89, 0, 0, 0, 10] => [ 3, 56, 27, 89, 45, 53, 78, 10] 109 | """ 110 | new_indices = mask * indices + (1 - mask) * random_indices 111 | 112 | # Adding sos ( start of sentence ) token 113 | new_indices = torch.cat((sos_tokens, new_indices), dim=1) 114 | 115 | target = indices 116 | 117 | logits, _ = self.transformer(new_indices[:, :-1]) 118 | 119 | return logits, target 120 | 121 | def top_k_logits(self, logits: torch.Tensor, k: int) -> torch.Tensor: 122 | """ 123 | 124 | Args: 125 | logits (torch.Tensor): predictions from the transformer 126 | k (int): returning k highest values 127 | 128 | Returns: 129 | torch.Tensor: retuning tensor of same dimension as input keeping the top k entries 130 | """ 131 | v, ix = torch.topk(logits, k) 132 | out = logits.clone() 133 | out[out < v[..., [-1]]] = -float( 134 | "inf" 135 | ) # Setting all values except in topk to inf 136 | return out 137 | 138 | @torch.no_grad() 139 | def sample( 140 | self, 141 | x: torch.Tensor, 142 | c: torch.Tensor, 143 | steps: int = 256, 144 | temperature: float = 1.0, 145 | top_k: int = 100, 146 | ) -> torch.Tensor: 147 | """Generating sample indices from the transformer 148 | 149 | Args: 150 | x (torch.Tensor): the batch of images 151 | c (torch.Tensor): sos token 152 | steps (int, optional): the lenght of indices to generate. Defaults to 256. 153 | temperature (float, optional): hyperparameter for minGPT model. Defaults to 1.0. 154 | top_k (int, optional): keeping top k entries. Defaults to 100. 
155 | 156 | Returns: 157 | torch.Tensor: _description_ 158 | """ 159 | 160 | self.transformer.eval() 161 | 162 | x = torch.cat((c, x), dim=1) # Appending sos token 163 | for k in range(steps): 164 | logits, _ = self.transformer(x) # Getting the predicted indices 165 | logits = ( 166 | logits[:, -1, :] / temperature 167 | ) # Getting the last prediction and scaling it by temperature 168 | 169 | if top_k is not None: 170 | logits = self.top_k_logits(logits, top_k) 171 | 172 | probs = F.softmax(logits, dim=-1) 173 | 174 | ix = torch.multinomial( 175 | probs, num_samples=1 176 | ) # Note : not sure what's happening here 177 | 178 | x = torch.cat((x, ix), dim=1) 179 | 180 | x = x[:, c.shape[1] :] # Removing the sos token 181 | self.transformer.train() 182 | return x 183 | 184 | @torch.no_grad() 185 | def log_images(self, x:torch.Tensor): 186 | """ Generating images using the transformer and decoder. Also uses encoder to complete partial images. 187 | 188 | Args: 189 | x (torch.Tensor): batch of images 190 | 191 | Returns: 192 | Retures the input and generated image in dictionary and in a simple concatenated image 193 | """ 194 | log = dict() 195 | 196 | _, indices = self.encode_to_z(x) # Getting the indices of the quantized encoding 197 | sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token 198 | sos_tokens = sos_tokens.long().to(self.device) 199 | 200 | start_indices = indices[:, : indices.shape[1] // 2] 201 | sample_indices = self.sample( 202 | start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1] 203 | ) 204 | half_sample = self.z_to_image(sample_indices) 205 | 206 | start_indices = indices[:, :0] 207 | sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1]) 208 | full_sample = self.z_to_image(sample_indices) 209 | 210 | x_rec = self.z_to_image(indices) 211 | 212 | log["input"] = x 213 | log["rec"] = x_rec 214 | log["half_sample"] = half_sample 215 | log["full_sample"] = full_sample 216 | 217 | return log, torch.concat((x, x_rec, half_sample, full_sample)) 218 | 219 | def load_checkpoint(self, path): 220 | """Loads the checkpoint from the given path.""" 221 | 222 | self.load_state_dict(torch.load(path)) 223 | 224 | def save_checkpoint(self, path): 225 | """Saves the checkpoint to the given path.""" 226 | 227 | torch.save(self.state_dict(), path) 228 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import weights_init, print_summary, generate_gif, clean_directory, reproducibility, collate_fn -------------------------------------------------------------------------------- /utils/assets/aim_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/aim_images.png -------------------------------------------------------------------------------- /utils/assets/aim_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/aim_metrics.png -------------------------------------------------------------------------------- /utils/assets/encoder_arch.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/encoder_arch.png -------------------------------------------------------------------------------- /utils/assets/nonlocalblocks_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/nonlocalblocks_arch.png -------------------------------------------------------------------------------- /utils/assets/patchgan_disc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/patchgan_disc.png -------------------------------------------------------------------------------- /utils/assets/perceptual_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/perceptual_loss.png -------------------------------------------------------------------------------- /utils/assets/reconstruction.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/reconstruction.gif -------------------------------------------------------------------------------- /utils/assets/sliding_window.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/sliding_window.png -------------------------------------------------------------------------------- /utils/assets/stage_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/stage_1.png -------------------------------------------------------------------------------- /utils/assets/stage_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/stage_2.png -------------------------------------------------------------------------------- /utils/assets/vqgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/vqgan.png -------------------------------------------------------------------------------- /utils/assets/vqvae_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/vqvae_arch.png -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import glob 3 | import os 4 | import random 5 | import shutil 6 | 7 | import imageio 8 | import numpy as np 9 | import torch 10 | from torchsummary import summary 11 | 12 | 13 | def print_summary( 14 | model: torch.nn.Module, 15 | input_data: torch.Tensor, 16 | col_names: list = ["input_size", "output_size", 
"num_params"], 17 | device: str = "cpu", 18 | depth: int = 2, 19 | ): 20 | """ 21 | Prints a summary of the model. 22 | """ 23 | return summary( 24 | model, input_data=input_data, col_names=col_names, device=device, depth=depth 25 | ) 26 | 27 | 28 | def weights_init(m): 29 | """Setting up the weights for the discriminator model. 30 | This is mentioned in the original PatchGAN paper, in page 16, section 6.2 - Training Details 31 | 32 | ``` 33 | All networks were trained from scratch. Weights were initialized from a Gaussian distribution with mean 0 and 34 | standard deviation 0.02. 35 | ``` 36 | 37 | Image-to-Image Translation with Conditional Adversarial Network - https://arxiv.org/pdf/1611.07004v3.pdf 38 | 39 | Args: 40 | m 41 | """ 42 | 43 | classname = m.__class__.__name__ 44 | if classname.find("Conv") != -1: 45 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 46 | elif classname.find("BatchNorm") != -1: 47 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 48 | torch.nn.init.constant_(m.bias.data, 0) 49 | 50 | 51 | def generate_gif(imgs_path: str, save_path: str): 52 | """Generates a gif from a directory of images. 53 | 54 | Args: 55 | imgs_path: Path to the directory of images. 56 | save_path: Path to save the gif. 57 | """ 58 | 59 | with imageio.get_writer(save_path, mode="I") as writer: 60 | for filename in glob.glob(imgs_path + "/*.jpg"): 61 | image = imageio.imread(filename) 62 | writer.append_data(image) 63 | 64 | 65 | def clean_directory(directory: str): 66 | """Cleans a directory. 67 | Args: 68 | directory: Path to the directory. 69 | """ 70 | 71 | if os.path.exists(directory): 72 | shutil.rmtree(directory) 73 | os.mkdir(directory) 74 | 75 | 76 | def reproducibility(seed: int = 42): 77 | """Set the random seed. 78 | 79 | Args: 80 | seed (int): The seed to use. 81 | 82 | Returns: 83 | None 84 | """ 85 | 86 | torch.manual_seed(seed) 87 | torch.cuda.manual_seed(seed) 88 | 89 | np.random.seed(seed) 90 | random.seed(seed) 91 | 92 | 93 | def collate_fn(batch): 94 | """ 95 | Collate function for the dataloader like mnist or cifar10. 96 | """ 97 | 98 | imgs = torch.stack([img[0] for img in batch]) 99 | 100 | return imgs 101 | -------------------------------------------------------------------------------- /vqgan/__init__.py: -------------------------------------------------------------------------------- 1 | from vqgan.encoder import Encoder 2 | from vqgan.decoder import Decoder 3 | from vqgan.codebook import CodeBook 4 | from vqgan.discriminator import Discriminator 5 | from vqgan.vqgan import VQGAN 6 | -------------------------------------------------------------------------------- /vqgan/codebook.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/dome272/VQGAN-pytorch/blob/main/codebook.py 3 | 4 | Contains the implementation of the codebook for the VQGAN. 5 | With each forward pass, it returns the loss, indices of min distance latent vectors between codebook and encoder output and latent vector with minimim distance. 6 | """ 7 | 8 | # Importing Libraries 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class CodeBook(nn.Module): 14 | """ 15 | This is class, we are mostly implemented para 3.1 from the paper, 16 | 17 | We generate the codebook from nn.Embeddings of given size and randomly initialize the weights in uniform distribution. 18 | 19 | The `forward` method is mostly to calculates 20 | 1. the nearest vector in the codebook from the given latent vector by the encoder. 21 | 2. 
The index of the nearest vector in the codebook. 22 | 3. loss ( from eq. 4 ) ( except reconstruction loss ) 23 | 24 | Args: 25 | num_codebook_vectors (int): Number of codebook vectors. 26 | latent_dim (int): Latent dimension of individual vectors. 27 | beta (int): Beta value for the commitment loss. 28 | """ 29 | 30 | def __init__( 31 | self, num_codebook_vectors: int = 1024, latent_dim: int = 256, beta: int = 0.25 32 | ): 33 | super().__init__() 34 | 35 | self.num_codebook_vectors = num_codebook_vectors 36 | self.latent_dim = latent_dim 37 | self.beta = beta 38 | 39 | # creating the codebook, nn.Embedding here is simply a 2D array mainly for storing our embeddings, it's also learnable 40 | self.codebook = nn.Embedding(num_codebook_vectors, latent_dim) 41 | 42 | # Initializing the weights in codebook in uniform distribution 43 | self.codebook.weight.data.uniform_( 44 | -1 / num_codebook_vectors, 1 / num_codebook_vectors 45 | ) 46 | 47 | def forward(self, z: torch.Tensor) -> torch.Tensor: 48 | """ 49 | Calculates the loss and nearest vector in the codebook from the given latent vector. 50 | 51 | We are mostly implementing the eq 2 and 4 ( except reconstruction loss ) from the paper. 52 | 53 | Args: 54 | z (torch.Tensor): Latent vector. 55 | Returns: 56 | torch.Tensor: Nearest vector in the codebook. 57 | torch.Tensor: Index of the nearest vector in the codebook. 58 | torch.Tensor: Loss ( except reconstruction loss ). 59 | """ 60 | 61 | # Channel to last dimension and copying the tensor to store it in a contiguous ( in a sequence ) way 62 | z = z.permute(0, 2, 3, 1).contiguous() 63 | 64 | z_flattened = z.view( 65 | -1, self.latent_dim 66 | ) # b*h*w * latent_dim, will look similar to codebook in fig 2 of the paper 67 | 68 | # calculating the distance between the z to the vectors in flattened codebook, from eq. 2 69 | # (a - b)^2 = a^2 + b^2 - 2ab 70 | distance = ( 71 | torch.sum( 72 | z_flattened**2, dim=1, keepdim=True 73 | ) # keepdim = True to keep the same original shape after the sum 74 | + torch.sum(self.codebook.weight**2, dim=1) 75 | - 2 76 | * torch.matmul( 77 | z_flattened, self.codebook.weight.t() 78 | ) # 2*dot(z, codebook.T) 79 | ) 80 | 81 | # getting indices of vectors with minimum distance from the codebook 82 | min_distance_indices = torch.argmin(distance, dim=1) 83 | 84 | # getting the corresponding vector from the codebook 85 | z_q = self.codebook(min_distance_indices).view(z.shape) 86 | 87 | """ 88 | this represent the equation 4 from the paper ( except the reconstruction loss ) . Thia loss will then be added 89 | to GAN loss to create the final loss function for VQGAN, eq. 6 in the paper. 90 | 91 | 92 | Note : In the first para of A. Changlog section of the paper, 93 | they found a bug which resulted in beta equal to 1. 
here https://github.com/CompVis/taming-transformers/issues/57
94 |         just a note :)
95 |         """
96 |         loss = torch.mean(
97 |             (z_q.detach() - z) ** 2
98 |             # detach() to avoid calculating gradient while backpropagating
99 |             + self.beta
100 |             * torch.mean(
101 |                 (z_q - z.detach()) ** 2
102 |             )  # commitment loss, detach() to avoid calculating gradient while backpropagating
103 |         )
104 | 
105 |         # Straight-through estimator: lets gradients flow from z_q straight back to z ("preserving gradients", as the original implementation puts it)
106 |         z_q = z + (z_q - z).detach()
107 | 
108 |         # reshaping to the original shape
109 |         z_q = z_q.permute(0, 3, 1, 2)
110 | 
111 |         return z_q, min_distance_indices, loss
112 | 
--------------------------------------------------------------------------------
/vqgan/common.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/helper.py
3 | 
4 | The file contains the Swish, Group Norm, Residual and Non-Local Blocks, plus the Upsample and Downsample layers used by the VQGAN encoder and decoder.
5 | """
6 | 
7 | # Importing Libraries
8 | import torch
9 | import torch.nn as nn
10 | 
11 | 
12 | class Swish(nn.Module):
13 |     """Swish activation function, first introduced in the paper https://arxiv.org/abs/1710.05941v2
14 |     It has been shown to work better than ReLU on many datasets.
15 |     """
16 | 
17 |     def __init__(self) -> None:
18 |         super().__init__()
19 | 
20 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
21 | 
22 |         return x * torch.sigmoid(x)
23 | 
24 | 
25 | class GroupNorm(nn.Module):
26 |     """Group Normalization normalizes the activations of a layer over groups of channels, so it behaves consistently for any batch size.
27 |     Note : Weight Standardization has also been shown to give better results when combined with group norm
28 | 
29 |     Args:
30 |         in_channels (int): Number of channels in the input tensor.
31 |     """
32 | 
33 |     def __init__(self, in_channels: int) -> None:
34 |         super().__init__()
35 | 
36 |         # num_groups is according to the official code provided by the authors,
37 |         # eps is for numerical stability
38 |         # affine enables learnable parameters for the affine transform applied after normalizing with the calculated mean & standard deviation
39 |         self.group_norm = nn.GroupNorm(
40 |             num_groups=32, num_channels=in_channels, eps=1e-06, affine=True
41 |         )
42 | 
43 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
44 |         return self.group_norm(x)
45 | 
46 | 
47 | class ResidualBlock(nn.Module):
48 |     """Residual Block from the paper,
49 |     group norm -> swish -> conv -> group norm -> swish -> dropout -> conv -> skip connection
50 | 
51 |     Args:
52 |         in_channels (int): Number of channels in the input tensor.
53 |         out_channels (int): Number of channels in the output tensor.
54 |         dropout (float): Dropout probability.
55 | """ 56 | 57 | def __init__(self, in_channels:int, out_channels:int, dropout:float=0.0) -> None: 58 | super().__init__() 59 | 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | 63 | self.block = nn.Sequential( 64 | GroupNorm(in_channels), 65 | Swish(), 66 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 67 | GroupNorm(out_channels), 68 | Swish(), 69 | nn.Dropout(dropout), 70 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 71 | ) 72 | 73 | """ 74 | In some cases, the shortcut connection needs to be added 75 | to match the dimension of the input and the output for skip connection 76 | """ 77 | if in_channels != out_channels: 78 | self.conv_shortcut = nn.Conv2d( 79 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 80 | ) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | 84 | # shortcut connection 85 | if self.in_channels != self.out_channels: 86 | return self.conv_shortcut(x) + self.block(x) 87 | else: 88 | return x + self.block(x) 89 | 90 | 91 | class DownsampleBlock(nn.Module): 92 | """ 93 | Down sample block for the encoder. pad -> conv 94 | 95 | Args: 96 | in_channels (int): Number of channels in the input tensor. 97 | """ 98 | 99 | def __init__(self, in_channels:int) -> None: 100 | super().__init__() 101 | 102 | # (0, 1, 0, 1) - pad on left, right, top, bottom, with respective size 103 | self.pad = nn.ConstantPad2d((0, 1, 0, 1), value=0) # and fill value of 0 104 | 105 | self.conv = nn.Conv2d( 106 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 107 | ) 108 | 109 | def forward(self, x: torch.Tensor) -> torch.Tensor: 110 | 111 | x = self.pad(x) 112 | 113 | return self.conv(x) 114 | 115 | 116 | class UpsampleBlock(nn.Module): 117 | """ 118 | Upsample block for the decoder. interpolate -> conv 119 | 120 | Args: 121 | in_channels (int): Number of channels in the input tensor. 122 | """ 123 | 124 | def __init__(self, in_channels:int) -> None: 125 | super().__init__() 126 | 127 | self.conv = nn.Conv2d( 128 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 129 | ) 130 | 131 | def forward(self, x: torch.Tensor) -> torch.Tensor: 132 | 133 | x = torch.nn.functional.interpolate(x, scale_factor=2.0) 134 | 135 | return self.conv(x) 136 | 137 | 138 | class NonLocalBlock(nn.Module): 139 | """Attention mechanism similar to transformers but for CNNs, paper https://arxiv.org/abs/1805.08318 140 | 141 | Args: 142 | in_channels (int): Number of channels in the input tensor. 
143 | """ 144 | 145 | def __init__(self, in_channels:int) -> None: 146 | super().__init__() 147 | 148 | self.in_channels = in_channels 149 | 150 | # normalization layer 151 | self.norm = GroupNorm(in_channels) 152 | 153 | # query, key and value layers 154 | self.q = nn.Conv2d(in_channels, in_channels, 1, 1, 0) 155 | self.k = nn.Conv2d(in_channels, in_channels, 1, 1, 0) 156 | self.v = nn.Conv2d(in_channels, in_channels, 1, 1, 0) 157 | 158 | self.project_out = nn.Conv2d(in_channels, in_channels, 1, 1, 0) 159 | 160 | self.softmax = nn.Softmax(dim=2) 161 | 162 | def forward(self, x): 163 | 164 | batch, _, height, width = x.size() 165 | 166 | x = self.norm(x) 167 | 168 | # query, key and value layers 169 | q = self.q(x) 170 | k = self.k(x) 171 | v = self.v(x) 172 | 173 | # resizing the output from 4D to 3D to generate attention map 174 | q = q.reshape(batch, self.in_channels, height * width) 175 | k = k.reshape(batch, self.in_channels, height * width) 176 | v = v.reshape(batch, self.in_channels, height * width) 177 | 178 | # transpose the query tensor for dot product 179 | q = q.permute(0, 2, 1) 180 | 181 | # main attention formula 182 | scores = torch.bmm(q, k) * (self.in_channels**-0.5) 183 | weights = self.softmax(scores) 184 | weights = weights.permute(0, 2, 1) 185 | 186 | attention = torch.bmm(v, weights) 187 | 188 | # resizing the output from 3D to 4D to match the input 189 | attention = attention.reshape(batch, self.in_channels, height, width) 190 | attention = self.project_out(attention) 191 | 192 | # adding the identity to the output 193 | return x + attention 194 | -------------------------------------------------------------------------------- /vqgan/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/dome272/VQGAN-pytorch/blob/main/decoder.py 3 | 4 | Contains the decoder implementation of VQGAN. 5 | 6 | The decoder architecture is also highly inspired by the - Denoising Diffusion Probabilistic Models - https://arxiv.org/abs/2006.11239 7 | According to the official implementation. 8 | """ 9 | 10 | # Importing Libraries 11 | import torch 12 | import torch.nn as nn 13 | 14 | from vqgan.common import GroupNorm, NonLocalBlock, ResidualBlock, Swish, UpsampleBlock 15 | 16 | 17 | class Decoder(nn.Module): 18 | """ 19 | The decoder part of the VQGAN. 20 | 21 | The implementation is similar to the encoder but inverse, to produce an image from a latent vector. 22 | 23 | Args: 24 | img_channels (int): Number of channels in the output image. 25 | latent_channels (int): Number of channels in the latent vector. 26 | latent_size (int): Size of the latent vector. 27 | intermediate_channels (list): List of channels in the intermediate layers. 28 | num_residual_blocks (int): Number of residual blocks b/w each downsample block. 29 | dropout (float): Dropout probability for residual blocks. 
30 | attention_resolution (list): tensor size ( height or width ) at which to add attention blocks 31 | """ 32 | 33 | def __init__( 34 | self, 35 | img_channels: int = 3, 36 | latent_channels: int = 256, 37 | latent_size: int = 16, 38 | intermediate_channels: list = [128, 128, 256, 256, 512], 39 | num_residual_blocks: int = 3, 40 | dropout: float = 0.0, 41 | attention_resolution: list = [16], 42 | ): 43 | super().__init__() 44 | 45 | # Reverse the list to get the correct order of decoder layer channels 46 | intermediate_channels = intermediate_channels[::-1] 47 | 48 | # Appends all the layers to this list 49 | layers = [] 50 | 51 | # Adding the first conv layer to increase the input channels to the first intermediate channels 52 | in_channels = intermediate_channels[0] 53 | layers.extend( 54 | [ 55 | nn.Conv2d( 56 | latent_channels, 57 | intermediate_channels[0], 58 | kernel_size=3, 59 | stride=1, 60 | padding=1, 61 | ), 62 | ResidualBlock( 63 | in_channels=in_channels, out_channels=in_channels, dropout=dropout 64 | ), 65 | NonLocalBlock(in_channels=in_channels), 66 | ResidualBlock( 67 | in_channels=in_channels, out_channels=in_channels, dropout=dropout 68 | ), 69 | ] 70 | ) 71 | 72 | # Loop over the intermediate channels 73 | for n in range(len(intermediate_channels)): 74 | out_channels = intermediate_channels[n] 75 | 76 | # adding the residual blocks 77 | for _ in range(num_residual_blocks): 78 | layers.append(ResidualBlock(in_channels, out_channels, dropout=dropout)) 79 | in_channels = out_channels 80 | 81 | # adding the non local block 82 | if latent_size in attention_resolution: 83 | layers.append(NonLocalBlock(in_channels)) 84 | 85 | # Due to conv in first layer, do not upsample 86 | if n != 0: 87 | layers.append(UpsampleBlock(in_channels=in_channels)) 88 | latent_size = latent_size * 2 # Upsample by a factor of 2 89 | 90 | # Adding rest of the layers 91 | layers.extend( 92 | [ 93 | GroupNorm(in_channels=in_channels), 94 | Swish(), 95 | nn.Conv2d( 96 | in_channels, img_channels, kernel_size=3, stride=1, padding=1 97 | ), 98 | ] 99 | ) 100 | self.model = nn.Sequential(*layers) 101 | 102 | def forward(self, x: torch.Tensor) -> torch.Tensor: 103 | return self.model(x) 104 | -------------------------------------------------------------------------------- /vqgan/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538) 3 | 4 | 5 | This isn't a standward GAN discriminator, where the input is a batch of images and the output is a batch of real/fake labels. 6 | 7 | 8 | But instead, PatchGAN discriminator is a network that takes a batch of images 9 | split into multiple patches 10 | 11 | for ex. - 30-30 patches, 30 in x and 30 in y axis, similar to convolution kernels, 12 | and then runs them through a network to get a score of real/fake on those individual patches. 13 | 14 | ex. - input size (1, 3, 256, 256) -> output size (1, 1, 30, 30) 15 | 16 | """ 17 | 18 | import torch.nn as nn 19 | 20 | 21 | class Discriminator(nn.Module): 22 | """ PatchGAN Discriminator 23 | 24 | 25 | Args: 26 | image_channels (int): Number of channels in the input image. 27 | num_filters_last (int): Number of filters in the last layer of the discriminator. 28 | n_layers (int): Number of layers in the discriminator. 
29 | 30 | 31 | """ 32 | 33 | def __init__(self, image_channels:int=3, num_filters_last=64, n_layers=3): 34 | super(Discriminator, self).__init__() 35 | 36 | layers = [ 37 | nn.Conv2d(image_channels, num_filters_last, 4, 2, 1), 38 | nn.LeakyReLU(0.2), 39 | ] 40 | num_filters_mult = 1 41 | 42 | for i in range(1, n_layers + 1): 43 | num_filters_mult_last = num_filters_mult 44 | num_filters_mult = min(2**i, 8) 45 | layers += [ 46 | nn.Conv2d( 47 | num_filters_last * num_filters_mult_last, 48 | num_filters_last * num_filters_mult, 49 | 4, 50 | 2 if i < n_layers else 1, 51 | 1, 52 | bias=False, 53 | ), 54 | nn.BatchNorm2d(num_filters_last * num_filters_mult), 55 | nn.LeakyReLU(0.2, True), 56 | ] 57 | 58 | layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1)) 59 | self.model = nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | return self.model(x) -------------------------------------------------------------------------------- /vqgan/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/dome272/VQGAN-pytorch/blob/main/encoder.py 3 | 4 | Contains the encoder implementation of VQGAN. 5 | 6 | The encoder is highly inspired by the - Denoising Diffusion Probabilistic Models - https://arxiv.org/abs/2006.11239 7 | According to the official implementation. 8 | """ 9 | 10 | # Importing Libraries 11 | import torch 12 | import torch.nn as nn 13 | 14 | from vqgan.common import DownsampleBlock, GroupNorm, NonLocalBlock, ResidualBlock, Swish 15 | 16 | 17 | class Encoder(nn.Module): 18 | """ 19 | The encoder part of the VQGAN. 20 | 21 | Args: 22 | img_channels (int): Number of channels in the input image. 23 | image_size (int): Size of the input image, only used in encoder (height or width ). 24 | latent_channels (int): Number of channels in the latent vector. 25 | intermediate_channels (list): List of channels in the intermediate layers. 26 | num_residual_blocks (int): Number of residual blocks b/w each downsample block. 27 | dropout (float): Dropout probability for residual blocks. 
28 | attention_resolution (list): tensor size ( height or width ) at which to add attention blocks 29 | """ 30 | 31 | def __init__( 32 | self, 33 | img_channels: int = 3, 34 | image_size: int = 256, 35 | latent_channels: int = 256, 36 | intermediate_channels: list = [128, 128, 256, 256, 512], 37 | num_residual_blocks: int = 2, 38 | dropout: float = 0.0, 39 | attention_resolution: list = [16], 40 | ): 41 | super().__init__() 42 | 43 | # Inserting first intermediate channel to index 0 44 | intermediate_channels.insert(0, intermediate_channels[0]) 45 | 46 | # Appends all the layers to this list 47 | layers = [] 48 | 49 | # Addingt the first conv layer increase input channels to the first intermediate channels 50 | layers.append( 51 | nn.Conv2d( 52 | img_channels, 53 | intermediate_channels[0], 54 | kernel_size=3, 55 | stride=1, 56 | padding=1, 57 | ) 58 | ) 59 | 60 | # Loop over the intermediate channels except the last one 61 | for n in range(len(intermediate_channels) - 1): 62 | in_channels = intermediate_channels[n] 63 | out_channels = intermediate_channels[n + 1] 64 | 65 | # Adding the residual blocks for each channel 66 | for _ in range(num_residual_blocks): 67 | layers.append(ResidualBlock(in_channels, out_channels, dropout=dropout)) 68 | in_channels = out_channels 69 | 70 | # Once we have downsampled the image to the size in attention resolution, we add attention blocks 71 | if image_size in attention_resolution: 72 | layers.append(NonLocalBlock(in_channels)) 73 | 74 | # only downsample for the first n-2 layers, and decrease the input size by a factor of 2 75 | if n != len(intermediate_channels) - 2: 76 | layers.append(DownsampleBlock(in_channels=intermediate_channels[n + 1])) 77 | image_size = image_size // 2 # Downsample by a factor of 2 78 | 79 | in_channels = intermediate_channels[-1] 80 | layers.extend( 81 | [ 82 | ResidualBlock( 83 | in_channels=in_channels, out_channels=in_channels, dropout=dropout 84 | ), 85 | NonLocalBlock(in_channels=in_channels), 86 | ResidualBlock( 87 | in_channels=in_channels, out_channels=in_channels, dropout=dropout 88 | ), 89 | GroupNorm(in_channels=in_channels), 90 | Swish(), 91 | # increase the channels upto the latent vector channels 92 | nn.Conv2d( 93 | in_channels, latent_channels, kernel_size=3, stride=1, padding=1 94 | ), 95 | ] 96 | ) 97 | self.model = nn.Sequential(*layers) 98 | 99 | def forward(self, x: torch.Tensor) -> torch.Tensor: 100 | return self.model(x) 101 | -------------------------------------------------------------------------------- /vqgan/vqgan.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/dome272/VQGAN-pytorch/blob/main/vqgan.py 3 | 4 | Implementing the main VQGAN, containing forward pass, lambda calculation, and to "enable" discriminator loss after a certain number of global steps. 5 | """ 6 | 7 | # Importing Libraries 8 | import torch 9 | import torch.nn as nn 10 | 11 | from vqgan import Encoder 12 | from vqgan import Decoder 13 | from vqgan import CodeBook 14 | 15 | 16 | class VQGAN(nn.Module): 17 | """ 18 | VQGAN class 19 | 20 | Args: 21 | img_channels (int, optional): Number of channels in the input image. Defaults to 3. 22 | img_size (int, optional): Size of the input image. Defaults to 256. 23 | latent_channels (int, optional): Number of channels in the latent vector. Defaults to 256. 24 | latent_size (int, optional): Size of the latent vector. Defaults to 16. 
25 | intermediate_channels (list, optional): List of channels in the intermediate layers of encoder and decoder. Defaults to [128, 128, 256, 256, 512]. 26 | num_residual_blocks_encoder (int, optional): Number of residual blocks in the encoder. Defaults to 2. 27 | num_residual_blocks_decoder (int, optional): Number of residual blocks in the decoder. Defaults to 3. 28 | dropout (float, optional): Dropout probability. Defaults to 0.0. 29 | attention_resolution (list, optional): Resolution of the attention mechanism. Defaults to [16]. 30 | num_codebook_vectors (int, optional): Number of codebook vectors. Defaults to 1024. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | img_channels: int = 3, 36 | img_size: int = 256, 37 | latent_channels: int = 256, 38 | latent_size: int = 16, 39 | intermediate_channels: list = [128, 128, 256, 256, 512], 40 | num_residual_blocks_encoder: int = 2, 41 | num_residual_blocks_decoder: int = 3, 42 | dropout: float = 0.0, 43 | attention_resolution: list = [16], 44 | num_codebook_vectors: int = 1024, 45 | ): 46 | 47 | super().__init__() 48 | 49 | self.img_channels = img_channels 50 | self.num_codebook_vectors = num_codebook_vectors 51 | 52 | self.encoder = Encoder( 53 | img_channels=img_channels, 54 | image_size=img_size, 55 | latent_channels=latent_channels, 56 | intermediate_channels=intermediate_channels[:], # shallow copy of the link 57 | num_residual_blocks=num_residual_blocks_encoder, 58 | dropout=dropout, 59 | attention_resolution=attention_resolution, 60 | ) 61 | 62 | self.decoder = Decoder( 63 | img_channels=img_channels, 64 | latent_channels=latent_channels, 65 | latent_size=latent_size, 66 | intermediate_channels=intermediate_channels[:], # shallow copy of the link 67 | num_residual_blocks=num_residual_blocks_decoder, 68 | dropout=dropout, 69 | attention_resolution=attention_resolution, 70 | ) 71 | self.codebook = CodeBook( 72 | num_codebook_vectors=num_codebook_vectors, latent_dim=latent_channels 73 | ) 74 | 75 | self.quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) 76 | self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) 77 | 78 | def forward(self, x: torch.Tensor) -> torch.Tensor: 79 | """Performs a single step of training on the input tensor x 80 | 81 | Args: 82 | x (torch.Tensor): Input tensor to the encoder. 83 | 84 | Returns: 85 | torch.Tensor: Output tensor from the decoder. 86 | """ 87 | 88 | encoded_images = self.encoder(x) 89 | quant_x = self.quant_conv(encoded_images) 90 | 91 | codebook_mapping, codebook_indices, codebook_loss = self.codebook(quant_x) 92 | 93 | post_quant_x = self.post_quant_conv(codebook_mapping) 94 | decoded_images = self.decoder(post_quant_x) 95 | 96 | return decoded_images, codebook_indices, codebook_loss 97 | 98 | def encode(self, x: torch.Tensor) -> torch.Tensor: 99 | 100 | x = self.encoder(x) 101 | quant_x = self.quant_conv(x) 102 | 103 | codebook_mapping, codebook_indices, q_loss = self.codebook(quant_x) 104 | 105 | return codebook_mapping, codebook_indices, q_loss 106 | 107 | def decode(self, x: torch.Tensor) -> torch.Tensor: 108 | 109 | x = self.post_quant_conv(x) 110 | x = self.decoder(x) 111 | 112 | return x 113 | 114 | def calculate_lambda(self, perceptual_loss, gan_loss): 115 | """Calculating lambda shown in the eq. 7 of the paper 116 | 117 | Args: 118 | perceptual_loss (torch.Tensor): Perceptual reconstruction loss. 119 | gan_loss (torch.Tensor): loss from the GAN discriminator. 
120 |         """
121 | 
122 |         last_layer = self.decoder.model[-1]
123 |         last_layer_weight = last_layer.weight
124 | 
125 |         # Because we have multiple loss functions in the network, retain_graph keeps the computational graph alive for the later backward pass
126 |         # https://stackoverflow.com/a/47174709
127 |         perceptual_loss_grads = torch.autograd.grad(
128 |             perceptual_loss, last_layer_weight, retain_graph=True
129 |         )[0]
130 |         gan_loss_grads = torch.autograd.grad(
131 |             gan_loss, last_layer_weight, retain_graph=True
132 |         )[0]
133 | 
134 |         lmda = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
135 |         lmda = torch.clamp(
136 |             lmda, 0, 1e4
137 |         ).detach()  # Here, we are constraining the value of lambda between 0 and 1e4
138 | 
139 |         return 0.8 * lmda  # Note: not sure why we are multiplying it by 0.8... ?
140 | 
141 |     @staticmethod
142 |     def adopt_weight(
143 |         disc_factor: float, i: int, threshold: int, value: float = 0.0
144 |     ) -> float:
145 |         """Starts the discriminator later in training, so that our model has enough time to generate "good-enough" images to try to "fool" the discriminator.
146 | 
147 |         To do that, before reaching a certain global step, we set the discriminator factor to `value` (default 0.0).
148 |         This discriminator factor is then used to scale the discriminator's loss.
149 | 
150 |         Args:
151 |             disc_factor (float): The value the discriminator's loss is multiplied by.
152 |             i (int): The current global step.
153 |             threshold (int): The global step after which the `disc_factor` value is returned.
154 |             value (float, optional): The discriminator factor before the threshold is reached. Defaults to 0.0.
155 | 
156 |         Returns:
157 |             float: The discriminator factor.
158 |         """
159 | 
160 |         if i < threshold:
161 |             disc_factor = value
162 | 
163 |         return disc_factor
164 | 
165 |     def load_checkpoint(self, path):
166 |         """Loads the checkpoint from the given path."""
167 | 
168 |         self.load_state_dict(torch.load(path))
169 | 
170 |     def save_checkpoint(self, path):
171 |         """Saves the checkpoint to the given path."""
172 | 
173 |         torch.save(self.state_dict(), path)
174 | 
--------------------------------------------------------------------------------
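
A few usage sketches follow; none of them ship with the repository, they are only meant to make the data flow above concrete. First, the building blocks in `vqgan/common.py` are easiest to understand through their shapes: the non-local (attention) and residual blocks preserve the spatial size, the downsample block halves it, and the upsample block doubles it. The channel counts below are multiples of 32 because `GroupNorm` uses 32 groups.

```python
import torch

from vqgan.common import DownsampleBlock, NonLocalBlock, ResidualBlock, UpsampleBlock

x = torch.randn(1, 64, 16, 16)

assert NonLocalBlock(64)(x).shape == (1, 64, 16, 16)        # attention keeps the spatial size
assert ResidualBlock(64, 128)(x).shape == (1, 128, 16, 16)  # 1x1 shortcut conv matches the new channel count
assert DownsampleBlock(64)(x).shape == (1, 64, 8, 8)        # pad + stride-2 conv halves H and W
assert UpsampleBlock(64)(x).shape == (1, 64, 32, 32)        # nearest-neighbour x2, then a 3x3 conv
```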
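
The quantisation step in `vqgan/codebook.py` relies on two details worth spelling out: the distances to every codebook vector are computed with the expansion (a - b)^2 = a^2 + b^2 - 2ab, which avoids materialising a (b*h*w, num_codebook_vectors, latent_dim) tensor, and the line `z_q = z + (z_q - z).detach()` is the straight-through estimator, so the forward pass returns the quantised vectors while gradients flow back to the encoder output as if quantisation were the identity. The standalone sketch below checks both claims on random tensors; it is an illustration, not part of the repository.

```python
import torch

torch.manual_seed(0)

z_flat = torch.randn(8, 4, requires_grad=True)  # 8 flattened latent vectors of dimension 4
codebook = torch.randn(16, 4)                   # 16 codebook vectors

# Expanded form used in CodeBook.forward: (a - b)^2 = a^2 + b^2 - 2ab
dist_expanded = (
    torch.sum(z_flat**2, dim=1, keepdim=True)
    + torch.sum(codebook**2, dim=1)
    - 2 * z_flat @ codebook.t()
)

# Naive pairwise distances for comparison
dist_naive = ((z_flat.unsqueeze(1) - codebook.unsqueeze(0)) ** 2).sum(dim=-1)
assert torch.allclose(dist_expanded, dist_naive, atol=1e-5)

# Straight-through estimator: the forward pass returns the quantised vectors,
# the backward pass treats quantisation as the identity, so gradients reach z_flat.
indices = dist_expanded.argmin(dim=1)
z_q = codebook[indices]
z_q_st = z_flat + (z_q - z_flat).detach()

assert torch.allclose(z_q_st, z_q)
z_q_st.sum().backward()
assert torch.allclose(z_flat.grad, torch.ones_like(z_flat))
```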
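
Putting the pieces of `vqgan/vqgan.py` together, one random 256x256 image goes through the encoder to a 16x16 grid of 256-dimensional latents, gets snapped to the 1024-entry codebook, and is decoded back to image space. The sketch below only checks shapes on untrained weights; it assumes the repository root is on `PYTHONPATH` and runs (slowly) on CPU.

```python
import torch

from vqgan import VQGAN

model = VQGAN()  # defaults: 3x256x256 images, 16x16x256 latents, 1024 codebook vectors
model.eval()

x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    decoded, _, codebook_loss = model(x)
    quant_z, indices, _ = model.encode(x)

assert decoded.shape == (1, 3, 256, 256)  # reconstruction has the input shape
assert quant_z.shape == (1, 256, 16, 16)  # quantised latent grid
assert indices.shape == (1 * 16 * 16,)    # one codebook index per latent position
assert codebook_loss.dim() == 0           # scalar VQ loss (eq. 4 without the reconstruction term)
```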
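
The PatchGAN discriminator's docstring promises a (1, 3, 256, 256) -> (1, 1, 30, 30) mapping, i.e. one real/fake logit per overlapping patch, and `weights_init` from `utils/utils.py` is the Gaussian(0, 0.02) initialisation it is meant to be paired with. A quick check, again untrained and illustrative only:

```python
import torch

from utils import weights_init
from vqgan import Discriminator

disc = Discriminator(image_channels=3, num_filters_last=64, n_layers=3)
disc.apply(weights_init)  # Gaussian(mean=0, std=0.02) init for the conv layers, as in pix2pix

patch_logits = disc(torch.randn(1, 3, 256, 256))
assert patch_logits.shape == (1, 1, 30, 30)  # one score per image patch, not per image
```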
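
How `adopt_weight` and `calculate_lambda` fit into stage-1 training is easiest to see in a single generator update. The fragment below is a compressed sketch, not the project's trainer (that lives in `trainer/trainer.py`): the hyper-parameters are placeholders, plain MSE stands in for the perceptual loss, and the discriminator's own update step is omitted. Before `disc_start` steps the warm-up factor is 0, so the GAN term is switched off exactly as `adopt_weight` intends.

```python
import torch
import torch.nn.functional as F

from utils import weights_init
from vqgan import VQGAN, Discriminator

# Placeholder hyper-parameters; the real values live in configs/default.yml
disc_factor, disc_start, global_step = 1.0, 1000, 0

vqgan = VQGAN()
disc = Discriminator().apply(weights_init)
opt_vq = torch.optim.Adam(vqgan.parameters(), lr=2.25e-05)

x = torch.randn(1, 3, 256, 256)  # stand-in batch; a real run would come from the dataloader

decoded, _, q_loss = vqgan(x)

# Reconstruction term: the repository uses a perceptual (LPIPS-style) loss here,
# plain MSE only keeps this sketch self-contained.
rec_loss = F.mse_loss(decoded, x)

# Generator side of the GAN loss: push the patch logits of the reconstruction up.
gan_loss = -disc(decoded).mean()

# Eq. 7: adaptive weight from the gradients at the decoder's last layer,
# and the warm-up factor that keeps the GAN term off before `disc_start`.
lmbda = vqgan.calculate_lambda(rec_loss, gan_loss)
factor = VQGAN.adopt_weight(disc_factor, global_step, threshold=disc_start)

loss_vq = rec_loss + q_loss + factor * lmbda * gan_loss  # eq. 6, up to the stand-ins above

opt_vq.zero_grad()
loss_vq.backward()
opt_vq.step()
```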
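
For stage 2, `VQGANTransformer.forward` returns next-token logits over codebook indices together with the unmasked indices as targets, so the training objective is plain cross-entropy. The smoke test below shrinks the GPT (2 layers, 128-dim) purely to stay CPU-friendly; the repository's defaults are 12 layers and 1024 dimensions, and the weights here are untrained.

```python
import torch
import torch.nn.functional as F

from transformer.transformer import VQGANTransformer
from vqgan import VQGAN

vqgan = VQGAN()
model = VQGANTransformer(
    vqgan, device="cpu", sos_token=0, pkeep=0.5,
    block_size=512, n_layer=2, n_head=4, n_embd=128,  # reduced sizes for a quick check
)

x = torch.randn(1, 3, 256, 256)

logits, target = model(x)  # partially masked indices in, next-token logits out
assert logits.shape == (1, 256, vqgan.num_codebook_vectors)
assert target.shape == (1, 256)

# Stage-2 objective: cross-entropy over the 1024 codebook indices
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()
```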
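
Finally, the two lines in `sample` that the in-code comments flag as unclear do the following: `top_k_logits` pushes every logit outside the k largest down to -inf, so that after the softmax only the top-k indices have non-zero probability, and `torch.multinomial` then draws one index per row from that categorical distribution. A tiny standalone illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0, 3.0]])
k = 2

# Same filtering as VQGANTransformer.top_k_logits
v, _ = torch.topk(logits, k)
filtered = logits.clone()
filtered[filtered < v[..., [-1]]] = -float("inf")

probs = F.softmax(filtered, dim=-1)
assert torch.count_nonzero(probs) == k        # only the top-k entries survive

ix = torch.multinomial(probs, num_samples=1)  # sample one index from the filtered distribution
assert ix.item() in (0, 4)                    # indices of the two largest logits
```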