├── .github
│   └── workflows
│       └── main.yml
├── LICENSE
├── README.md
├── configs
│   └── default.yml
├── dataloader
│   ├── __init__.py
│   ├── cifar10.py
│   ├── load_dataloader.py
│   └── mnist.py
├── environment.yml
├── experiments
│   ├── generated_0.jpg
│   ├── generated_1.jpg
│   ├── generated_2.jpg
│   ├── generated_3.jpg
│   ├── generated_4.jpg
│   └── reconstruction.gif
├── generate.py
├── requirements.txt
├── test
│   ├── __init__.py
│   ├── test_dataloader.py
│   └── test_vqgan.py
├── train.py
├── trainer
│   ├── __init__.py
│   ├── trainer.py
│   ├── transformer.py
│   └── vqgan.py
├── transformer
│   ├── __init__.py
│   ├── mingpt.py
│   └── transformer.py
├── utils
│   ├── __init__.py
│   ├── assets
│   │   ├── aim_images.png
│   │   ├── aim_metrics.png
│   │   ├── encoder_arch.png
│   │   ├── nonlocalblocks_arch.png
│   │   ├── patchgan_disc.png
│   │   ├── perceptual_loss.png
│   │   ├── reconstruction.gif
│   │   ├── sliding_window.png
│   │   ├── stage_1.png
│   │   ├── stage_2.png
│   │   ├── vqgan.png
│   │   └── vqvae_arch.png
│   └── utils.py
└── vqgan
    ├── __init__.py
    ├── codebook.py
    ├── common.py
    ├── decoder.py
    ├── discriminator.py
    ├── encoder.py
    └── vqgan.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 |
3 | name: Run Python Tests
4 |
5 | # Controls when the workflow will run
6 | on:
7 | # Triggers the workflow on push or pull request events but only for the "main" branch
8 | push:
9 | branches: [ "main" ]
10 | pull_request:
11 | branches: [ "main" ]
12 |
13 | # Allows you to run this workflow manually from the Actions tab
14 | # workflow_dispatch:
15 |
16 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
17 | jobs:
18 | # This workflow contains a single job called "build"
19 | build:
20 | # The type of runner that the job will run on
21 | runs-on: ubuntu-latest
22 |
23 | # Steps represent a sequence of tasks that will be executed as part of the job
24 | steps:
25 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
26 | - uses: actions/checkout@v3
27 |
28 | - name: Setup Python
29 | uses: actions/setup-python@v4.2.0
30 | with:
31 | # Version range or exact version of Python or PyPy to use, using SemVer's version range syntax. Reads from .python-version if unset.
32 | python-version: 3.7.13
33 |
34 | - name: Install dependencies
35 | run: |
36 | python -m pip install --upgrade pip
37 | pip install -r requirements.txt
38 |
39 | - name: Run tests with pytest
40 | run: pytest --cov=. test --cov-report=xml
41 |
42 | - name: Upload coverage reports to Codecov with GitHub Action
43 | uses: codecov/codecov-action@v3
44 | with:
45 | token: ${{ secrets.CODECOV_TOKEN }}
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Shubhamai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch VQGAN
2 |
3 | [](https://github.com/Shubhamai/pytorch-vqgan/blob/main/LICENSE)
4 | [](https://github.com/Shubhamai/pytorch-vqgan/actions/workflows/main.yml)
5 | [](https://codecov.io/gh/Shubhamai/pytorch-vqgan)
6 |
7 |
8 | 
9 | Figure 1. VQGAN Architecture
10 |
11 |
12 | > **Note:** This is a work in progress.
13 |
14 |
15 | This repo's purpose is to serve as a cleaner and more feature-rich implementation of VQGAN - *[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841)* - building on the initial work in [dome272's repo](https://github.com/dome272/VQGAN-pytorch), written in PyTorch from scratch. There's also a great [video explaining VQGAN](https://youtu.be/wcqLFDXaDO8) by dome272.
16 |
17 | I created this repo to better understand VQGAN myself and to provide scripts for faster training and experimentation with toy datasets like MNIST. I also tried to make it as clean as possible, with comments, logging, testing & coverage, custom datasets & visualizations, etc.
18 |
19 | - [PyTorch VQGAN](#pytorch-vqgan)
20 | - [What is VQGAN?](#what-is-vqgan)
21 | - [Stage 1](#stage-1)
22 | - [Training](#training)
23 | - [Generation](#generation)
24 | - [Stage 2](#stage-2)
25 | - [Setup](#setup)
26 | - [Usage](#usage)
27 | - [Training](#training-1)
28 | - [Generation](#generation-1)
29 | - [Tests](#tests)
30 | - [Hardware requirements](#hardware-requirements)
31 | - [Shoutouts](#shoutouts)
32 | - [BibTeX](#bibtex)
33 |
34 | ## What is VQGAN?
35 |
36 | VQGAN stands for **V**ector **Q**uantised **G**enerative **A**dversarial **N**etworks. The main idea behind the paper is to use a CNN to learn a codebook of context-rich visual parts of images, and then to use a Transformer to learn the long-range/global interactions between those visual parts. Combining the two, we can generate very high-resolution images.
37 |
38 | Learning these short- and long-range interactions to generate high-resolution images is done in two stages.
39 |
40 | 1. The first stage uses VQGAN to learn a codebook of **context-rich** visual representations of the images. In terms of architecture, it is very similar to VQVAE in that it consists of an encoder, a decoder and a codebook. We will learn more about this in the next section.
41 |
42 |
43 | 
44 | Figure 2. VQVAE Architecture
45 |
46 |
47 | 2. The second stage uses a Transformer to learn the global interactions between the vectors in the codebook by predicting the next code in the sequence from the previous ones, which is what allows generating high-resolution images.
48 |
49 | ---
50 |
51 | ### Stage 1
52 |
53 |
54 | 
55 | Stage 1 : VQGAN Architecture
56 |
57 |
58 |
59 | The architecture of VQGAN consists of three major parts: the encoder, the decoder and the codebook, similar to the VQVAE paper.
60 |
61 |
62 | 1. The encoder ([`encoder.py`](vqgan/encoder.py)) learns to represent the images in a much lower-dimensional space, called the embeddings or latents. It consists of convolution, downsample and residual blocks, plus special attention blocks (Non-Local blocks), and has around 30 million parameters in the default settings.
63 | 2. The embeddings are then quantized using the codebook, and the quantized embeddings are used as input to the decoder ([`decoder.py`](vqgan/decoder.py)).
64 | 3. The decoder takes the "quantized" embeddings and reconstructs the image. The architecture is similar to the encoder but reversed, with around 40 million parameters in the default settings, slightly more than the encoder due to a larger number of residual blocks. A quick shape sketch follows this list.
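
For a concrete sense of the shapes involved, here is a rough sketch mirroring [`test/test_vqgan.py`](test/test_vqgan.py): with the default MNIST config, a single-channel 256×256 image is encoded into a 256-channel 16×16 latent and decoded back to full resolution.

```python
import torch

from vqgan import Decoder, Encoder

# Encode a 1x256x256 image into a 256-channel 16x16 latent, then decode it back.
# The keyword arguments follow test/test_vqgan.py; img_channels=1 matches configs/default.yml.
encoder = Encoder(img_channels=1, image_size=256, latent_channels=256, attention_resolution=[16])
decoder = Decoder(img_channels=1, latent_size=16, latent_channels=256, attention_resolution=[16])

x = torch.randn(1, 1, 256, 256)  # (batch, channels, height, width)
z = encoder(x)                   # -> (1, 256, 16, 16)
x_hat = decoder(z)               # -> (1, 1, 256, 256)
```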
65 |
66 | The main idea behind the codebook and quantization is to convert the continuous latent representation into a discrete one. The codebook is simply a list of `n` latent vectors (learned during training); each latent produced by the encoder is replaced with the closest vector (in terms of distance) from the codebook. This is where the **VQ** part of the name comes from.
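
Here is a minimal, illustrative sketch of that lookup (the repo's actual implementation lives in [`codebook.py`](vqgan/codebook.py) and also returns the chosen indices and the codebook losses):

```python
import torch
import torch.nn as nn

# Nearest-neighbour quantization: replace every encoder latent with its closest
# codebook vector. Sizes follow the default config (1024 vectors, 256-dim latents).
codebook = nn.Embedding(num_embeddings=1024, embedding_dim=256)
z = torch.randn(1, 256, 16, 16)                                   # encoder output (B, C, H, W)

z_flat = z.permute(0, 2, 3, 1).reshape(-1, 256)                   # (B*H*W, C)
distances = torch.cdist(z_flat, codebook.weight)                  # distance to every codebook vector
indices = distances.argmin(dim=1)                                 # index of the closest vector
z_q = codebook(indices).view(1, 16, 16, 256).permute(0, 3, 1, 2)  # quantized latents, (B, C, H, W)
```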
67 |
68 | #### Training
69 |
70 | The training involves sending a batch of images through the encoder, quantizing the embeddings and then sending the quantized embeddings through the decoder to reconstruct the images. The loss function is computed as follows:
71 |
72 |
73 | $`
74 | \begin{aligned}
75 | \mathcal{L}_{\mathrm{VQ}}(E, G, \mathcal{Z})=\|x-\hat{x}\|^{2} &+\left\|\text{sg}[E(x)]-z_{\mathbf{q}}\right\|_{2}^{2}+\left\|\text{sg}\left[z_{\mathbf{q}}\right]-E(x)\right\|_{2}^{2} .
76 | \end{aligned}
77 | `$
78 |
79 | The above equation is the sum of the reconstruction loss, the alignment loss and the commitment loss.
80 |
81 | 1. Reconstruction loss
82 |
83 | > Apparently there is some confusion about whether the reconstruction loss was replaced by the perceptual loss or whether the two are combined. We will go with what was implemented in the official code (https://github.com/CompVis/taming-transformers/issues/40), which is L1 + perceptual loss.
84 |
85 |
86 |
87 |
88 | The reconstruction loss is the sum of the L1 loss and the perceptual loss.
89 | $`\text { L1 Loss }=\sum_{i=1}^{n}\left|y_{\text {true }}-y_{\text {predicted }}\right|`$
90 |
91 | The perceptual loss is the L2 distance between the feature activations (e.g. the last-layer outputs) of the generated and original images in a pre-trained model like VGG.
92 |
93 | 2. The alignment and commitment losses come from the quantization step and measure the distance between the latent vectors from the encoder output and the closest vectors from the codebook. `sg` here denotes the stop-gradient function. A sketch of the full VQ loss is given just below this list.
94 |
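As a rough sketch of the full VQ loss: the reconstruction part mirrors `VQGANTrainer.step` in [`trainer/vqgan.py`](trainer/vqgan.py) (L1 plus the LPIPS perceptual loss), while the alignment/commitment part follows the equation above (in the repo it is returned by the codebook's forward pass):

```python
import lpips
import torch

# Reconstruction term: per-pixel L1 plus the LPIPS perceptual distance, then averaged.
perceptual_loss_fn = lpips.LPIPS(net="vgg")
imgs = torch.rand(2, 3, 256, 256) * 2 - 1     # real images, scaled to [-1, 1]
decoded = torch.rand(2, 3, 256, 256) * 2 - 1  # reconstructions from the decoder

rec_loss = torch.abs(imgs - decoded)                 # L1, kept per-pixel
perceptual_loss = perceptual_loss_fn(imgs, decoded)  # one distance per image
perceptual_rec_loss = (1.0 * perceptual_loss + 1.0 * rec_loss).mean()

# Alignment + commitment terms with the stop-gradient (detach) trick.
z = torch.randn(2, 256, 16, 16)    # encoder output E(x)
z_q = torch.randn(2, 256, 16, 16)  # quantized latents from the codebook
codebook_loss = torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
```
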
95 | ---
96 |
97 |
98 |
99 | $`
100 | \mathcal{L}_{\mathrm{GAN}}(\{E, G, \mathcal{Z}\}, D)=[\log D(x)+\log (1-D(\hat{x}))]
101 | `$
102 |
103 | The above loss is for the discriminator, which takes in real and generated images and learns to classify which ones are real and which are fake. The **GAN** in VQGAN comes from here :)
104 |
105 | The discriminator here is a bit different from conventional discriminators: instead of scoring the whole image at once, it converts the image into patches using convolutions and predicts, for each patch, whether it is real or fake (a PatchGAN-style discriminator).
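
In code, the discriminator and generator terms look like this (taken almost verbatim from `VQGANTrainer.step` in [`trainer/vqgan.py`](trainer/vqgan.py), which uses a hinge loss rather than the log form shown above; the random tensors stand in for the discriminator's patch-wise outputs):

```python
import torch
import torch.nn.functional as F

# Patch-wise discriminator outputs for real and reconstructed 256x256 images;
# with the default discriminator these have shape (B, 1, 30, 30).
disc_real = torch.randn(2, 1, 30, 30)
disc_fake = torch.randn(2, 1, 30, 30)

# Hinge loss for the discriminator, averaged over all patches.
d_loss_real = torch.mean(F.relu(1.0 - disc_real))
d_loss_fake = torch.mean(F.relu(1.0 + disc_fake))
gan_loss = 0.5 * (d_loss_real + d_loss_fake)

# Generator term that is later scaled by lambda and added to the VQ loss.
g_loss = -torch.mean(disc_fake)
```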
106 |
107 | ---
108 |
109 | $`
110 | \lambda=\frac{\nabla_{G_{L}}\left[\mathcal{L}_{\mathrm{rec}}\right]}{\nabla_{G_{L}}\left[\mathcal{L}_{\mathrm{GAN}}\right]+\delta}
111 | `$
112 |
113 | We calculate lambda as the ratio of the gradient of the reconstruction loss to the gradient of the GAN loss, both taken with respect to the last layer of the decoder; see `calculate_lambda` in [`vqgan.py`](vqgan/vqgan.py).
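
A sketch of that computation is below; the clamp to `[0, 1e4]` and the 0.8 scaling follow the reference implementations and are assumptions here, the repo's actual version being `calculate_lambda` in [`vqgan.py`](vqgan/vqgan.py):

```python
import torch

def calculate_lambda_sketch(perceptual_rec_loss, gan_loss, last_layer_weight, delta=1e-6):
    """Adaptive weight: ratio of the two losses' gradient norms w.r.t. the decoder's last layer."""
    rec_grads = torch.autograd.grad(perceptual_rec_loss, last_layer_weight, retain_graph=True)[0]
    gan_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]
    lam = torch.norm(rec_grads) / (torch.norm(gan_grads) + delta)
    return 0.8 * torch.clamp(lam, 0, 1e4).detach()

# Toy usage with a stand-in "last layer" weight.
w = torch.randn(8, 8, requires_grad=True)
rec_loss = (w ** 2).sum()
gan_loss = w.sum() ** 2
lam = calculate_lambda_sketch(rec_loss, gan_loss, w)
```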
114 |
115 | The final loss then becomes -
116 |
117 | $`
118 | \begin{aligned}
119 | \mathcal{Q}^{*}=\underset{E, G, \mathcal{Z}}{\arg \min } \max _{D} \mathbb{E}_{x \sim p(x)}\left[\mathcal{L}_{\mathrm{VQ}}(E, G, \mathcal{Z})+\lambda \mathcal{L}_{\mathrm{GAN}}(\{E, G, \mathcal{Z}\}, D)\right]
120 | \end{aligned}
121 | `$
122 |
123 | which combines the VQ loss (reconstruction, alignment and commitment losses) with the GAN loss weighted by `lambda`.
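
Roughly how `VQGANTrainer.step` combines the terms, with placeholder scalars purely for illustration (`adopt_weight` is assumed to zero the GAN factor before `disc_start` steps, the usual warm-up trick):

```python
import torch

perceptual_rec_loss = torch.tensor(0.70)  # L1 + perceptual reconstruction term
q_loss = torch.tensor(0.10)               # codebook alignment + commitment term
g_loss = torch.tensor(-0.20)              # generator term, -mean(D(x_hat))
lam = torch.tensor(0.85)                  # adaptive weight from calculate_lambda

global_step, disc_start = 150, 100                       # GAN term kicks in after disc_start steps
disc_factor = 1.0 if global_step >= disc_start else 0.0  # adopt_weight warm-up

vq_loss = perceptual_rec_loss + q_loss + disc_factor * lam * g_loss
```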
124 |
125 | ### Generation
126 |
127 |
128 | To generate images with VQGAN, we sample quantized vectors using [Stage 2](#stage-2) and pass them through the decoder to produce the image.
129 |
130 | ---
131 |
132 | ### Stage 2
133 |
134 |
135 | 
136 | Stage 2: Transformers
137 |
138 |
139 |
140 |
141 |
142 | This stage contains a Transformer 🤖 which is trained to predict the next latent code from the sequence of previous codes in the quantized encoder output. The implementation uses [`mingpt.py`](transformer/mingpt.py) from Andrej Karpathy's [karpathy/minGPT](https://github.com/karpathy/minGPT) repo, as the paper's official code does. A minimal sketch of the training objective is shown below.
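
The training objective is plain next-token prediction over codebook indices; here is a minimal sketch with dummy tensors (in the real loop in [`trainer/transformer.py`](trainer/transformer.py), the model itself produces the logits and the shifted targets):

```python
import torch
import torch.nn.functional as F

# A 16x16 grid of latents gives a sequence of 256 tokens drawn from 1024 codebook vectors.
batch_size, seq_len, vocab_size = 2, 256, 1024
logits = torch.randn(batch_size, seq_len, vocab_size)          # transformer predictions
targets = torch.randint(0, vocab_size, (batch_size, seq_len))  # ground-truth codebook indices

loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
```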
143 |
144 | Due to the computational cost of generating high-resolution images, the paper also uses a sliding attention window, so the next latent code is predicted from its neighbouring codes in the quantized encoder output rather than from the full sequence.
145 |
146 | ## Setup
147 |
148 | 1. Clone the repo - `git clone https://github.com/Shubhamai/pytorch-vqgan`
149 | 2. Create a new conda environment using `conda env create --prefix env python=3.7.13 --file=environment.yml`
150 | 3. Activate the conda environment using `conda activate ./env`
151 |
152 | ## Usage
153 |
154 |
155 |
156 |
157 | ### Training
158 |
159 | - You can start the training by running `python train.py`. It reads the default config file from `configs/default.yml`. To use a different config, pass its path, e.g. `python train.py --config_path configs/default.yml`.
160 |
161 | Here's roughly what the script does -
162 | - Downloads the MNIST dataset automatically and saves it in the [data](/data) directory (specified in the config).
163 | - Trains the VQGAN and transformer models on the MNIST train set with the parameters passed from the config file.
164 | - The training metrics, visualizations and models are saved in the [experiments/](/experiments/) directory, with the corresponding path specified in the config file.
165 |
166 | - Run `aim up` to open the experiment tracker to see the metrics and reconstructed & generated images.
167 |
168 |
169 | 
170 |
171 |
172 |
173 | ### Generation
174 |
175 | To generate images, simply run `python generate.py`; the models will be loaded from `experiments/checkpoints` and the outputs will be saved in `experiments`.
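
Roughly what happens under the hood, condensed from [`generate.py`](generate.py) and `Trainer.generate_images` (this assumes trained checkpoints already exist at the paths below):

```python
import torch
import yaml

from transformer import VQGANTransformer
from vqgan import VQGAN

with open("configs/default.yml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Load the trained VQGAN and transformer checkpoints.
vqgan = VQGAN(**config["architecture"]["vqgan"])
vqgan.load_checkpoint("./experiments/checkpoints/vqgan.pt")
transformer = VQGANTransformer(vqgan, device="cpu", **config["architecture"]["transformer"])
transformer.load_checkpoint("./experiments/checkpoints/transformer.pt")

# Sample 256 codebook indices autoregressively, starting from the SOS token, then decode.
start_indices = torch.zeros((4, 0)).long()  # batch of 4, no tokens generated yet
sos_tokens = torch.zeros((4, 1)).long()     # sos_token = 0, as in configs/default.yml
with torch.no_grad():
    indices = transformer.sample(start_indices, sos_tokens, steps=256)
    images = transformer.z_to_image(indices)  # decoded images, (4, 1, 256, 256) with defaults
```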
176 |
177 | ### Tests
178 |
179 | I have also just started getting my feet wet with testing and automated testing with GitHub CI/CD, so the tests here might not follow best practices.
180 |
181 | To run tests, run `pytest --cov-config=.coveragerc --cov=. test`
182 |
183 |
184 | ## Hardware requirements
185 |
186 | The hardware I tried the model on with the default settings -
187 | - Ryzen 5 4600H
188 | - NVIDIA GeForce GTX 1660 Ti - 6 GB VRAM
189 | - 12 GB RAM
190 |
191 | It took around 2-3 minutes to get good reconstruction results. Since Google Colab has similar hardware in terms of compute power, from what I understand, it should run just fine on Colab :)
192 |
193 |
194 | ## Shoutouts
195 |
196 | The list here contains some helpful blogs and videos that helped me a lot in understanding VQGAN.
197 |
198 | 1. [The Illustrated VQGAN](https://ljvmiranda921.github.io/notebook/2021/08/08/clip-vqgan/) by Lj Miranda
199 | 2. [VQGAN: Taming Transformers for High-Resolution Image Synthesis [Paper Explained]](https://youtu.be/-wDSDtIAyWQ) by Gradient Dude
200 | 3. [VQ-GAN: Taming Transformers for High-Resolution Image Synthesis | Paper Explained](https://youtu.be/j2PXES-liuc) by The AI Epiphany
201 | 4. [VQ-GAN | Paper Explanation](https://youtu.be/wcqLFDXaDO8) and [VQ-GAN | PyTorch Implementation](https://youtu.be/_Br5WRwUz_U) by Outlier
202 | 5. [TL#006 Robin Rombach Taming Transformers for High Resolution Image Synthesis](https://youtu.be/fy153-yXSQk) by one of the paper's authors - Robin Rombach. Thanks for the talk :)
203 |
204 | ## BibTeX
205 |
206 | ```
207 | @misc{esser2020taming,
208 | title={Taming Transformers for High-Resolution Image Synthesis},
209 | author={Patrick Esser and Robin Rombach and Björn Ommer},
210 | year={2020},
211 | eprint={2012.09841},
212 | archivePrefix={arXiv},
213 | primaryClass={cs.CV}
214 | }
215 | ```
216 |
--------------------------------------------------------------------------------
/configs/default.yml:
--------------------------------------------------------------------------------
1 | architecture:
2 | vqgan:
3 | img_channels: 1
4 | img_size: 256
5 | latent_channels: 256
6 | latent_size: 16
7 | intermediate_channels: [128, 128, 256, 256, 512]
8 | num_residual_blocks_encoder: 2
9 | num_residual_blocks_decoder: 3
10 | dropout: 0.0
11 | attention_resolution: [16]
12 | num_codebook_vectors: 1024
13 |
14 | transformer:
15 | sos_token: 0
16 | pkeep: 0.5
17 | block_size: 512
18 | n_layer: 12
19 | n_head: 16
20 | n_embd: 1024
21 |
22 | trainer:
23 | vqgan:
24 | learning_rate: 2.25e-05
25 | beta1: 0.5
26 | beta2: 0.9
27 | perceptual_loss_factor: 1.0
28 | rec_loss_factor: 1.0
29 | disc_factor: 1.0
30 | disc_start: 100
31 | perceptual_model: "vgg"
32 | save_every: 10
33 |
34 | transformer:
35 | learning_rate: 4.5e-06
36 | beta1: 0.9
37 | beta2: 0.95
38 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .mnist import load_mnist
2 | from .cifar10 import load_cifar10
3 | from .load_dataloader import load_dataloader
--------------------------------------------------------------------------------
/dataloader/cifar10.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import torch
3 | import torchvision
4 |
5 | from utils import collate_fn
6 |
7 |
8 | def load_cifar10(
9 | batch_size: int = 16,
10 | image_size: int = 28,
11 | num_workers: int = 4,
12 | save_path: str = "data",
13 | ) -> torch.utils.data.DataLoader:
14 | """Load the Cifar 10 data and returns the dataloaders (train ). The data is downloaded if it does not exist.
15 |
16 | Args:
17 | batch_size (int): The batch size.
18 | image_size (int): The image size.
19 | num_workers (int): The number of workers to use for the dataloader.
20 | save_path (str): The path to save the data to.
21 |
22 | Returns:
23 | torch.utils.data.DataLoader: The data loader.
24 | """
25 |
26 | # Load the data
27 | dataloader = torch.utils.data.DataLoader(
28 | torchvision.datasets.CIFAR10(
29 | root=save_path,
30 | train=True,
31 | download=True,
32 | transform=torchvision.transforms.Compose(
33 | [
34 | torchvision.transforms.Resize((image_size, image_size)),
35 | torchvision.transforms.ToTensor(),
36 | torchvision.transforms.Normalize((0.1307,), (0.3081,)),
37 | ]
38 | ),
39 | ),
40 | batch_size=batch_size,
41 | shuffle=True,
42 | num_workers=num_workers,
43 | pin_memory=True,
44 | collate_fn=collate_fn,
45 | )
46 |
47 | return dataloader
48 |
--------------------------------------------------------------------------------
/dataloader/load_dataloader.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import torch
3 |
4 | from dataloader import load_mnist, load_cifar10
5 |
6 |
7 | def load_dataloader(
8 | name: str = "mnist",
9 | batch_size: int = 2,
10 | image_size: int = 256,
11 | num_workers: int = 4,
12 | save_path: str = "data",
13 | ) -> torch.utils.data.DataLoader:
14 | """Load the data loader for the given name.
15 |
16 | Args:
17 | name (str, optional): The name of the data loader. Defaults to "mnist".
18 | batch_size (int, optional): The batch size. Defaults to 2.
19 | image_size (int, optional): The image size. Defaults to 256.
20 | num_workers (int, optional): The number of workers to use for the dataloader. Defaults to 4.
21 | save_path (str, optional): The path to save the data to. Defaults to "data".
22 |
23 | Returns:
24 | torch.utils.data.DataLoader: The data loader.
25 | """
26 |
27 | if name == "mnist":
28 | return load_mnist(
29 | batch_size=batch_size,
30 | image_size=image_size,
31 | num_workers=num_workers,
32 | save_path=save_path,
33 | )
34 |
35 | elif name == "cifar10":
36 | return load_cifar10(
37 | batch_size=batch_size,
38 | image_size=image_size,
39 | num_workers=num_workers,
40 | save_path=save_path,
41 | )
--------------------------------------------------------------------------------
/dataloader/mnist.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import torch
3 | import torchvision
4 |
5 | from utils import collate_fn
6 |
7 |
8 | def load_mnist(
9 | batch_size: int = 2,
10 | image_size: int = 256,
11 | num_workers: int = 4,
12 | save_path: str = "data",
13 | ) -> torch.utils.data.DataLoader:
14 | """Load the MNIST data and returns the dataloaders (train ). The data is downloaded if it does not exist.
15 |
16 | Args:
17 | batch_size (int): The batch size.
18 | image_size (int): The image size.
19 | num_workers (int): The number of workers to use for the dataloader.
20 | save_path (str): The path to save the data to.
21 |
22 | Returns:
23 | torch.utils.data.DataLoader: The data loader.
24 | """
25 |
26 | # Load the data
27 | mnist_data = torchvision.datasets.MNIST(
28 | root=save_path,
29 | train=True,
30 | download=True,
31 | transform=torchvision.transforms.Compose(
32 | [
33 | torchvision.transforms.Resize((image_size, image_size)),
34 | torchvision.transforms.Grayscale(num_output_channels=1),
35 | torchvision.transforms.ToTensor(),
36 | torchvision.transforms.Normalize((0.1307,), (0.3081,)),
37 | ]
38 | ),
39 | )
40 |
41 | # Reduced set for faster training
42 | mnist_data_reduced = torch.utils.data.Subset(mnist_data, list(range(0, 800)))
43 |
44 | dataloader = torch.utils.data.DataLoader(
45 | mnist_data_reduced,
46 | batch_size=batch_size,
47 | shuffle=True,
48 | num_workers=num_workers,
49 | pin_memory=True,
50 | collate_fn=collate_fn,
51 | )
52 |
53 | return dataloader
54 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: /mnt/Stuffs/Projects/fromScratch/pytorch_vqgan/env
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - blas=1.0=mkl
10 | - brotlipy=0.7.0=py37h27cfd23_1003
11 | - bzip2=1.0.8=h7b6447c_0
12 | - ca-certificates=2022.07.19=h06a4308_0
13 | - certifi=2022.6.15=py37h06a4308_0
14 | - cffi=1.15.1=py37h74dc2b5_0
15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
16 | - cryptography=37.0.1=py37h9ce1e76_0
17 | - cudatoolkit=11.6.0=hecad31d_10
18 | - ffmpeg=4.2.2=h20bf706_0
19 | - freetype=2.11.0=h70c0345_0
20 | - giflib=5.2.1=h7b6447c_0
21 | - gmp=6.2.1=h295c915_3
22 | - gnutls=3.6.15=he1e5248_0
23 | - idna=3.3=pyhd3eb1b0_0
24 | - intel-openmp=2021.4.0=h06a4308_3561
25 | - jpeg=9e=h7f8727e_0
26 | - lame=3.100=h7b6447c_0
27 | - lcms2=2.12=h3be6417_0
28 | - ld_impl_linux-64=2.38=h1181459_1
29 | - libffi=3.3=he6710b0_2
30 | - libgcc-ng=11.2.0=h1234567_1
31 | - libgomp=11.2.0=h1234567_1
32 | - libidn2=2.3.2=h7f8727e_0
33 | - libopus=1.3.1=h7b6447c_0
34 | - libpng=1.6.37=hbc83047_0
35 | - libstdcxx-ng=11.2.0=h1234567_1
36 | - libtasn1=4.16.0=h27cfd23_0
37 | - libtiff=4.2.0=h2818925_1
38 | - libunistring=0.9.10=h27cfd23_0
39 | - libvpx=1.7.0=h439df22_0
40 | - libwebp=1.2.2=h55f646e_0
41 | - libwebp-base=1.2.2=h7f8727e_0
42 | - lz4-c=1.9.3=h295c915_1
43 | - mkl=2021.4.0=h06a4308_640
44 | - mkl-service=2.4.0=py37h7f8727e_0
45 | - mkl_fft=1.3.1=py37hd3c417c_0
46 | - mkl_random=1.2.2=py37h51133e4_0
47 | - ncurses=6.3=h5eee18b_3
48 | - nettle=3.7.3=hbbd107a_1
49 | - numpy=1.21.5=py37h6c91a56_3
50 | - numpy-base=1.21.5=py37ha15fc14_3
51 | - openh264=2.1.1=h4ff587b_0
52 | - openssl=1.1.1q=h7f8727e_0
53 | - pillow=9.2.0=py37hace64e9_1
54 | - pip=22.1.2=py37h06a4308_0
55 | - pycparser=2.21=pyhd3eb1b0_0
56 | - pyopenssl=22.0.0=pyhd3eb1b0_0
57 | - pysocks=1.7.1=py37_1
58 | - python=3.7.13=h12debd9_0
59 | - python_abi=3.7=2_cp37m
60 | - pytorch=1.12.1=py3.7_cuda11.6_cudnn8.3.2_0
61 | - pytorch-mutex=1.0=cuda
62 | - readline=8.1.2=h7f8727e_1
63 | - requests=2.28.1=py37h06a4308_0
64 | - setuptools=61.2.0=py37h06a4308_0
65 | - six=1.16.0=pyhd3eb1b0_1
66 | - sqlite=3.39.2=h5082296_0
67 | - tk=8.6.12=h1ccaba5_0
68 | - torchaudio=0.12.1=py37_cu116
69 | - torchvision=0.13.1=py37_cu116
70 | - typing_extensions=4.3.0=py37h06a4308_0
71 | - urllib3=1.26.11=py37h06a4308_0
72 | - wheel=0.37.1=pyhd3eb1b0_0
73 | - x264=1!157.20191217=h7b6447c_0
74 | - xz=5.2.5=h7f8727e_1
75 | - zlib=1.2.12=h7f8727e_2
76 | - zstd=1.5.2=ha4553b6_0
77 | - pip:
78 | - aim==3.12.2
79 | - aim-ui==3.12.2
80 | - aimrecords==0.0.7
81 | - aimrocks==0.2.1
82 | - aiofiles==0.8.0
83 | - albumentations==1.2.1
84 | - alembic==1.8.1
85 | - attrs==22.1.0
86 | - base58==2.0.1
87 | - beautifulsoup4==4.11.1
88 | - cachetools==5.2.0
89 | - click==8.1.3
90 | - codecov==2.1.12
91 | - coverage==6.4.3
92 | - exceptiongroup==1.0.0rc8
93 | - fastapi==0.67.0
94 | - filelock==3.8.0
95 | - google-api-core==2.8.2
96 | - google-api-python-client==2.56.0
97 | - google-auth==2.10.0
98 | - google-auth-httplib2==0.1.0
99 | - googleapis-common-protos==1.56.4
100 | - greenlet==1.1.2
101 | - grpcio==1.47.0
102 | - h11==0.13.0
103 | - httplib2==0.20.4
104 | - hypothesis==6.54.3
105 | - imageio==2.21.1
106 | - importlib-metadata==4.12.0
107 | - importlib-resources==5.9.0
108 | - iniconfig==1.1.1
109 | - jinja2==3.1.2
110 | - joblib==1.1.0
111 | - lpips==0.1.4
112 | - mako==1.2.1
113 | - markupsafe==2.1.1
114 | - networkx==2.6.3
115 | - opencv-python-headless==4.6.0.66
116 | - packaging==21.3
117 | - pandas==1.3.5
118 | - pluggy==1.0.0
119 | - protobuf==3.20.0
120 | - psutil==5.9.1
121 | - py==1.11.0
122 | - py3nvml==0.2.7
123 | - pyasn1==0.4.8
124 | - pyasn1-modules==0.2.8
125 | - pydantic==1.9.2
126 | - pyparsing==3.0.9
127 | - pytest==7.1.2
128 | - pytest-cov==3.0.0
129 | - python-dateutil==2.8.2
130 | - pytz==2022.2.1
131 | - pywavelets==1.3.0
132 | - pyyaml==6.0
133 | - qudida==0.0.4
134 | - restrictedpython==5.2
135 | - rsa==4.9
136 | - scikit-image==0.19.3
137 | - scikit-learn==1.0.2
138 | - scipy==1.7.3
139 | - sortedcontainers==2.4.0
140 | - soupsieve==2.3.2.post1
141 | - sqlalchemy==1.4.40
142 | - starlette==0.14.2
143 | - threadpoolctl==3.1.0
144 | - tifffile==2021.11.2
145 | - tomli==2.0.1
146 | - torch-summary==1.4.5
147 | - tqdm==4.64.0
148 | - uritemplate==4.1.1
149 | - uvicorn==0.18.2
150 | - xmltodict==0.13.0
151 | - zipp==3.8.1
152 | prefix: /mnt/Stuffs/Projects/fromScratch/pytorch_vqgan/env
153 |
--------------------------------------------------------------------------------
/experiments/generated_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_0.jpg
--------------------------------------------------------------------------------
/experiments/generated_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_1.jpg
--------------------------------------------------------------------------------
/experiments/generated_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_2.jpg
--------------------------------------------------------------------------------
/experiments/generated_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_3.jpg
--------------------------------------------------------------------------------
/experiments/generated_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/generated_4.jpg
--------------------------------------------------------------------------------
/experiments/reconstruction.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/experiments/reconstruction.gif
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import argparse
3 |
4 | import yaml
5 |
6 | from trainer import Trainer
7 | from transformer import VQGANTransformer
8 | from vqgan import VQGAN
9 |
10 |
11 | def main(args, config):
12 |
13 | vqgan = VQGAN(**config["architecture"]["vqgan"])
14 | vqgan.load_checkpoint("./experiments/checkpoints/vqgan.pt")
15 |
16 | transformer = VQGANTransformer(
17 | vqgan, device=args.device, **config["architecture"]["transformer"]
18 | )
19 | transformer.load_checkpoint("./experiments/checkpoints/transformer.pt")
20 |
21 | trainer = Trainer(
22 | vqgan,
23 | transformer,
24 | run=None,
25 | config=config["trainer"],
26 | seed=args.seed,
27 | device=args.device,
28 | )
29 |
30 | trainer.generate_images()
31 |
32 |
33 | if __name__ == "__main__":
34 |
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument(
37 | "--config_path",
38 | type=str,
39 | default="configs/default.yml",
40 | help="path to config file",
41 | )
42 | parser.add_argument(
43 | "--dataset_name",
44 | type=str,
45 | choices=["mnist", "cifar", "custom"],
46 | default="mnist",
47 | help="Dataset for the model",
48 | )
49 | parser.add_argument(
50 | "--device",
51 | type=str,
52 | default="cuda",
53 | choices=["cpu", "cuda"],
54 | help="Device to train the model on",
55 | )
56 | parser.add_argument(
57 | "--seed",
58 | type=int,
59 | default=42,
60 | help="Seed for Reproducibility",
61 | )
62 |
63 | args = parser.parse_args()
64 |
66 | with open(args.config_path) as f:
67 | config = yaml.load(f, Loader=yaml.FullLoader)
68 |
69 | main(args, config)
70 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aim==3.12.2
2 | aim-ui==3.12.2
3 | aimrecords==0.0.7
4 | aimrocks==0.2.1
5 | aiofiles==0.8.0
6 | albumentations==1.2.1
7 | alembic==1.8.1
8 | attrs==22.1.0
9 | base58==2.0.1
10 | beautifulsoup4==4.11.1
11 | brotlipy==0.7.0
12 | cachetools==5.2.0
13 | certifi==2022.6.15
14 | cffi==1.15.1
15 | charset-normalizer==2.0.4
16 | click==8.1.3
17 | codecov==2.1.12
18 | coverage==6.4.3
19 | cryptography==37.0.1
20 | exceptiongroup==1.0.0rc8
21 | fastapi==0.67.0
22 | filelock==3.8.0
23 | google-api-core==2.8.2
24 | google-api-python-client==2.56.0
25 | google-auth==2.10.0
26 | google-auth-httplib2==0.1.0
27 | googleapis-common-protos==1.56.4
28 | greenlet==1.1.2
29 | grpcio==1.47.0
30 | h11==0.13.0
31 | httplib2==0.20.4
32 | hypothesis==6.54.3
33 | idna==3.3
34 | imageio==2.21.1
35 | importlib-metadata==4.12.0
36 | importlib-resources==5.9.0
37 | iniconfig==1.1.1
38 | Jinja2==3.1.2
39 | joblib==1.1.1
40 | lpips==0.1.4
41 | Mako==1.2.1
42 | MarkupSafe==2.1.1
43 | mkl-fft==1.3.1
44 | mkl-random==1.2.2
45 | mkl-service==2.4.0
46 | networkx==2.6.3
47 | numpy==1.21.5
48 | opencv-python-headless==4.6.0.66
49 | packaging==21.3
50 | pandas==1.3.5
51 | Pillow==9.2.0
52 | pip==22.1.2
53 | pluggy==1.0.0
54 | protobuf==3.20.0
55 | psutil==5.9.1
56 | py==1.11.0
57 | py3nvml==0.2.7
58 | pyasn1==0.4.8
59 | pyasn1-modules==0.2.8
60 | pycparser==2.21
61 | pydantic==1.9.2
62 | pyOpenSSL==22.0.0
63 | pyparsing==3.0.9
64 | PySocks==1.7.1
65 | pytest==7.1.2
66 | pytest-cov==3.0.0
67 | python-dateutil==2.8.2
68 | pytz==2022.2.1
69 | PyWavelets==1.3.0
70 | PyYAML==6.0
71 | qudida==0.0.4
72 | requests==2.28.1
73 | RestrictedPython==5.2
74 | rsa==4.9
75 | scikit-image==0.19.3
76 | scikit-learn==1.0.2
77 | scipy==1.7.3
78 | setuptools==61.2.0
79 | six==1.16.0
80 | sortedcontainers==2.4.0
81 | soupsieve==2.3.2.post1
82 | SQLAlchemy==1.4.40
83 | starlette==0.25.0
84 | threadpoolctl==3.1.0
85 | tifffile==2021.11.2
86 | tomli==2.0.1
87 | torch==1.12.1
88 | torch-summary==1.4.5
89 | torchaudio==0.12.1
90 | torchvision==0.13.1
91 | tqdm==4.64.0
92 | typing_extensions==4.3.0
93 | uritemplate==4.1.1
94 | urllib3==1.26.11
95 | uvicorn==0.18.2
96 | wheel==0.37.1
97 | xmltodict==0.13.0
98 | zipp==3.8.1
99 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/test/__init__.py
--------------------------------------------------------------------------------
/test/test_dataloader.py:
--------------------------------------------------------------------------------
1 | from dataloader import load_mnist
2 |
3 |
4 | def test_load_mnist():
5 |
6 | dataloader = load_mnist(batch_size=16)
7 |
8 | for imgs in dataloader:
9 | break
10 |
11 | assert imgs.shape == (16, 1, 256, 256)
--------------------------------------------------------------------------------
/test/test_vqgan.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains test functions for the VQGAN model.
3 | """
4 |
5 | # Importing Libraries
6 | import torch
7 |
8 | from vqgan import Encoder, Decoder, CodeBook, Discriminator
9 |
10 |
11 | def test_encoder():
12 |
13 | image_channels = 3
14 | image_size = 256
15 | latent_channels = 256
16 | attn_resolution = 16
17 |
18 | image = torch.randn(1, image_channels, image_size, image_size)
19 |
20 | model = Encoder(
21 | img_channels=image_channels,
22 | image_size=image_size,
23 | latent_channels=latent_channels,
24 | attention_resolution=[attn_resolution],
25 | )
26 |
27 | output = model(image)
28 |
29 | assert output.shape == (
30 | 1,
31 | latent_channels,
32 | attn_resolution,
33 | attn_resolution,
34 | ), "Output of encoder does not match"
35 |
36 |
37 | def test_decoder():
38 |
39 | img_channels = 3
40 | img_size = 256
41 | latent_channels = 256
42 | latent_size = 16
43 | attn_resolution = 16
44 |
45 | latent = torch.randn(1, latent_channels, latent_size, latent_size)
46 | model = Decoder(
47 | img_channels=img_channels,
48 | latent_size=latent_size,
49 | latent_channels=latent_channels,
50 | attention_resolution=[attn_resolution],
51 | )
52 |
53 | output = model(latent)
54 |
55 | assert output.shape == (
56 | 1,
57 | img_channels,
58 | img_size,
59 | img_size,
60 | ), "Output of decoder does not match"
61 |
62 |
63 | def test_codebook():
64 |
65 | z = torch.randn(1, 256, 16, 16)
66 |
67 | codebook = CodeBook(num_codebook_vectors=100, latent_dim=16)
68 |
69 | z_q, min_distance_indices, loss = codebook(z)
70 |
71 | assert z_q.shape == (1, 256, 16, 16), "Output of codebook does not match"
72 |
73 |
74 | def test_discriminator():
75 |
76 | image = torch.randn(1, 3, 256, 256)
77 |
78 | model = Discriminator()
79 |
80 |
81 | output = model(image)
82 |
83 | assert output.shape == (1, 1, 30, 30), "Output of discriminator does not match"
84 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import argparse
3 |
4 | import yaml
5 | from aim import Run
6 |
7 | from dataloader import load_dataloader
8 | from trainer import Trainer
9 | from transformer import VQGANTransformer
10 | from vqgan import VQGAN
11 |
12 |
13 | def main(args, config):
14 |
15 | vqgan = VQGAN(**config["architecture"]["vqgan"])
16 | transformer = VQGANTransformer(
17 | vqgan, **config["architecture"]["transformer"], device=args.device
18 | )
19 | dataloader = load_dataloader(name=args.dataset_name)
20 |
21 | run = Run(experiment=args.dataset_name)
22 | run["hparams"] = config
23 |
24 | trainer = Trainer(
25 | vqgan,
26 | transformer,
27 | run=run,
28 | config=config["trainer"],
29 | seed=args.seed,
30 | device=args.device,
31 | )
32 |
33 | trainer.train_vqgan(dataloader)
34 | trainer.train_transformers(dataloader)
35 | trainer.generate_images()
36 |
37 |
38 | if __name__ == "__main__":
39 |
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument(
42 | "--config_path",
43 | type=str,
44 | default="configs/default.yml",
45 | help="path to config file",
46 | )
47 | parser.add_argument(
48 | "--dataset_name",
49 | type=str,
50 | choices=["mnist", "cifar", "custom"],
51 | default="mnist",
52 | help="Dataset for the model",
53 | )
54 | parser.add_argument(
55 | "--device",
56 | type=str,
57 | default="cuda",
58 | choices=["cpu", "cuda"],
59 | help="Device to train the model on",
60 | )
61 | parser.add_argument(
62 | "--seed",
63 | type=int,
64 | default=42,
65 | help="Seed for Reproducibility",
66 | )
67 |
68 | args = parser.parse_args()
69 |
71 | with open(args.config_path) as f:
72 | config = yaml.load(f, Loader=yaml.FullLoader)
73 |
74 | main(args, config)
75 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .vqgan import VQGANTrainer
2 | from .transformer import TransformerTrainer
3 | from .trainer import Trainer
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import os
3 |
4 | import torch
5 | import torchvision
6 | from aim import Run
7 | from utils import reproducibility
8 |
9 | from trainer import TransformerTrainer, VQGANTrainer
10 |
11 |
12 | class Trainer:
13 | def __init__(
14 | self,
15 | vqgan: torch.nn.Module,
16 | transformer: torch.nn.Module,
17 | run: Run,
18 | config: dict,
19 | experiment_dir: str = "experiments",
20 | seed: int = 42,
21 | device: str = "cuda",
22 | ) -> None:
23 |
24 | self.vqgan = vqgan
25 | self.transformer = transformer
26 |
27 | self.run = run
28 | self.config = config
29 | self.experiment_dir = experiment_dir
30 | self.seed = seed
31 | self.device = device
32 |
33 | print(f"[INFO] Setting seed to {seed}")
34 | reproducibility(seed)
35 |
36 | print(f"[INFO] Results will be saved in {experiment_dir}")
38 |
39 | def train_vqgan(self, dataloader: torch.utils.data.DataLoader, epochs: int = 1):
40 |
41 | print(f"[INFO] Training VQGAN on {self.device} for {epochs} epoch(s).")
42 |
43 | self.vqgan.to(self.device)
44 |
45 | self.vqgan_trainer = VQGANTrainer(
46 | model=self.vqgan,
47 | run=self.run,
48 | device=self.device,
49 | experiment_dir=self.experiment_dir,
50 | **self.config["vqgan"],
51 | )
52 |
53 | self.vqgan_trainer.train(
54 | dataloader=dataloader,
55 | epochs=epochs,
56 | )
57 |
58 | # Saving the model
59 | self.vqgan.save_checkpoint(
60 | os.path.join(self.experiment_dir, "checkpoints", "vqgan.pt")
61 | )
62 |
63 | def train_transformers(
64 | self, dataloader: torch.utils.data.DataLoader, epochs: int = 1
65 | ):
66 |
67 | print(f"[INFO] Training Transformer on {self.device} for {epochs} epoch(s).")
68 |
69 | self.vqgan.eval()
70 | self.transformer = self.transformer.to(self.device)
71 |
72 | self.transformer_trainer = TransformerTrainer(
73 | model=self.transformer,
74 | run=self.run,
75 | device=self.device,
76 | experiment_dir=self.experiment_dir,
77 | **self.config["transformer"],
78 | )
79 |
80 | self.transformer_trainer.train(dataloader=dataloader, epochs=epochs)
81 |
82 | self.transformer.save_checkpoint(
83 | os.path.join(self.experiment_dir, "checkpoints", "transformer.pt")
84 | )
85 |
86 | def generate_images(self, n_images: int = 5):
87 |
88 | print(f"[INFO] Generating {n_images} images...")
89 |
90 | self.vqgan.to(self.device)
91 | self.transformer = self.transformer.to(self.device)
92 |
93 |
94 | for i in range(n_images):
95 | start_indices = torch.zeros((4, 0)).long().to(self.device)
96 | sos_tokens = torch.ones(start_indices.shape[0], 1) * 0
97 |
98 | sos_tokens = sos_tokens.long().to(self.device)
99 | sample_indices = self.transformer.sample(
100 | start_indices, sos_tokens, steps=256
101 | )
102 | sampled_imgs = self.transformer.z_to_image(sample_indices)
103 | torchvision.utils.save_image(
104 | sampled_imgs,
105 | os.path.join(self.experiment_dir, f"generated_{i}.jpg"),
106 | nrow=4,
107 | )
108 |
109 |
--------------------------------------------------------------------------------
/trainer/transformer.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import torchvision
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from aim import Run, Image
8 |
9 |
10 | class TransformerTrainer:
11 | def __init__(
12 | self,
13 | model: nn.Module,
14 | run: Run,
15 | experiment_dir: str = "experiments",
16 | device: str = "cuda",
17 | learning_rate: float = 4.5e-06,
18 | beta1: float = 0.9,
19 | beta2: float = 0.95,
20 | ):
21 | self.run = run
22 | self.experiment_dir = experiment_dir
23 |
24 | self.model = model
25 | self.device = device
26 | self.optim = self.configure_optimizers(
27 | learning_rate=learning_rate, beta1=beta1, beta2=beta2
28 | )
29 |
30 | def configure_optimizers(
31 | self, learning_rate: float = 4.5e-06, beta1: float = 0.9, beta2: float = 0.95
32 | ):
33 | decay, no_decay = set(), set()
34 | whitelist_weight_modules = (nn.Linear,)
35 | blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
36 |
37 | # Enabling weight decay to only certain layers
38 | for mn, m in self.model.transformer.named_modules():
39 | for pn, p in m.named_parameters():
40 | fpn = f"{mn}.{pn}" if mn else pn
41 |
42 | if pn.endswith("bias"):
43 | no_decay.add(fpn)
44 |
45 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
46 | decay.add(fpn)
47 |
48 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
49 | no_decay.add(fpn)
50 |
51 | no_decay.add("pos_emb")
52 |
53 | param_dict = {pn: p for pn, p in self.model.transformer.named_parameters()}
54 |
55 | optim_groups = [
56 | {
57 | "params": [param_dict[pn] for pn in sorted(list(decay))],
58 | "weight_decay": 0.01,
59 | },
60 | {
61 | "params": [param_dict[pn] for pn in sorted(list(no_decay))],
62 | "weight_decay": 0.0,
63 | },
64 | ]
65 |
66 | optimizer = torch.optim.AdamW(
67 | optim_groups, lr=learning_rate, betas=(beta1, beta2)
68 | )
69 | return optimizer
70 |
71 | def train(self, dataloader: torch.utils.data.DataLoader, epochs: int):
72 | for epoch in range(epochs):
73 |
74 | for index, imgs in enumerate(dataloader):
75 | self.optim.zero_grad()
76 | imgs = imgs.to(device=self.device)
77 | logits, targets = self.model(imgs)
78 | loss = F.cross_entropy(
79 | logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
80 | )
81 | loss.backward()
82 | self.optim.step()
83 |
84 | self.run.track(
85 | loss,
86 | name="Cross Entropy Loss",
87 | step=index,
88 | context={"stage": "transformer"},
89 | )
90 |
91 | if index % 10 == 0:
92 | print(
93 | f"Epoch: {epoch+1}/{epochs} | Batch: {index}/{len(dataloader)} | Cross Entropy Loss : {loss:.4f}"
94 | )
95 |
96 | _, sampled_imgs = self.model.log_images(imgs[0][None])
97 |
98 | self.run.track(
99 | Image(
100 | torchvision.utils.make_grid(sampled_imgs)
101 | .mul(255)
102 | .add_(0.5)
103 | .clamp_(0, 255)
104 | .permute(1, 2, 0)
105 | .to("cpu", torch.uint8)
106 | .numpy()
107 | ),
108 | name="Transformer Images",
109 | step=index,
110 | context={"stage": "transformer"},
111 | )
112 |
--------------------------------------------------------------------------------
/trainer/vqgan.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/training_vqgan.py
3 | """
4 |
5 | # Importing Libraries
6 | import os
7 |
8 | import imageio
9 | import lpips
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | import torchvision
14 | from aim import Image, Run
15 | from utils import weights_init
16 | from vqgan import Discriminator
17 |
18 |
19 | class VQGANTrainer:
20 | """Trainer class for VQGAN, contains step, train methods"""
21 |
22 | def __init__(
23 | self,
24 | model: torch.nn.Module,
25 | run: Run,
26 | # Training parameters
27 | device: str or torch.device = "cuda",
28 | learning_rate: float = 2.25e-05,
29 | beta1: float = 0.5,
30 | beta2: float = 0.9,
31 | # Loss parameters
32 | perceptual_loss_factor: float = 1.0,
33 | rec_loss_factor: float = 1.0,
34 | # Discriminator parameters
35 | disc_factor: float = 1.0,
36 | disc_start: int = 100,
37 | # Miscellaneous parameters
38 | experiment_dir: str = "./experiments",
39 | perceptual_model: str = "vgg",
40 | save_every: int = 10,
41 | ):
42 |
43 | self.run = run
44 | self.device = device
45 |
46 | # VQGAN parameters
47 | self.vqgan = model
48 |
49 | # Discriminator parameters
50 | self.discriminator = Discriminator(image_channels=self.vqgan.img_channels).to(
51 | self.device
52 | )
53 | self.discriminator.apply(weights_init)
54 |
55 | # Loss parameters
56 | self.perceptual_loss = lpips.LPIPS(net=perceptual_model).to(self.device)
57 |
58 | # Optimizers
59 | self.opt_vq, self.opt_disc = self.configure_optimizers(
60 | learning_rate=learning_rate, beta1=beta1, beta2=beta2
61 | )
62 |
63 | # Hyperparameters
64 | self.disc_factor = disc_factor
65 | self.disc_start = disc_start
66 | self.perceptual_loss_factor = perceptual_loss_factor
67 | self.rec_loss_factor = rec_loss_factor
68 |
69 | # Save directory
70 | self.experiment_save_dir = experiment_dir
71 |
72 | # Miscellaneous
73 | self.global_step = 0
74 | self.sample_batch = None
75 | self.gif_images = []
76 | self.save_every = save_every
77 |
78 | def configure_optimizers(
79 | self, learning_rate: float = 2.25e-05, beta1: float = 0.5, beta2: float = 0.9
80 | ):
81 | opt_vq = torch.optim.Adam(
82 | list(self.vqgan.encoder.parameters())
83 | + list(self.vqgan.decoder.parameters())
84 | + list(self.vqgan.codebook.parameters())
85 | + list(self.vqgan.quant_conv.parameters())
86 | + list(self.vqgan.post_quant_conv.parameters()),
87 | lr=learning_rate,
88 | eps=1e-08,
89 | betas=(beta1, beta2),
90 | )
91 | opt_disc = torch.optim.Adam(
92 | self.discriminator.parameters(),
93 | lr=learning_rate,
94 | eps=1e-08,
95 | betas=(beta1, beta2),
96 | )
97 |
98 | return opt_vq, opt_disc
99 |
100 | def step(self, imgs: torch.Tensor) -> torch.Tensor:
101 | """Performs a single training step from the dataloader images batch
102 |
103 | For the VQGAN, it calculates the perceptual loss, reconstruction loss, and the codebook loss and does the backward pass.
104 |
105 | For the discriminator, it calculates lambda for the discriminator loss and does the backward pass.
106 |
107 | Args:
108 | imgs: input tensor of shape (batch_size, channel, H, W)
109 |
110 | Returns:
111 | decoded_imgs: output tensor of shape (batch_size, channel, H, W)
112 | """
113 |
114 | # Getting decoder output
115 | decoded_images, _, q_loss = self.vqgan(imgs)
116 |
117 | """
118 | =======================================================================================================================
119 | VQ Loss
120 | """
121 | perceptual_loss = self.perceptual_loss(imgs, decoded_images)
122 | rec_loss = torch.abs(imgs - decoded_images)
123 | perceptual_rec_loss = (
124 | self.perceptual_loss_factor * perceptual_loss
125 | + self.rec_loss_factor * rec_loss
126 | )
127 | perceptual_rec_loss = perceptual_rec_loss.mean()
128 |
129 | """
130 | =======================================================================================================================
131 | Discriminator Loss
132 | """
133 | disc_real = self.discriminator(imgs)
134 | disc_fake = self.discriminator(decoded_images)
135 |
136 | disc_factor = self.vqgan.adopt_weight(
137 | self.disc_factor, self.global_step, threshold=self.disc_start
138 | )
139 |
140 | g_loss = -torch.mean(disc_fake)
141 |
142 | λ = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
143 | vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss
144 |
145 | d_loss_real = torch.mean(F.relu(1.0 - disc_real))
146 | d_loss_fake = torch.mean(F.relu(1.0 + disc_fake))
147 | gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake)
148 |
149 | # ======================================================================================================================
150 | # Tracking metrics
151 |
152 | self.run.track(
153 | perceptual_rec_loss,
154 | name="Perceptual & Reconstruction loss",
155 | step=self.global_step,
156 | context={"stage": "vqgan"},
157 | )
158 |
159 | self.run.track(
160 | vq_loss, name="VQ Loss", step=self.global_step, context={"stage": "vqgan"}
161 | )
162 | self.run.track(
163 | gan_loss, name="GAN Loss", step=self.global_step, context={"stage": "vqgan"}
164 | )
165 |
166 | # =======================================================================================================================
167 | # Backpropagation
168 |
169 | self.opt_vq.zero_grad()
170 | vq_loss.backward(
171 | retain_graph=True
172 | ) # retain_graph is used to retain the computation graph for the discriminator loss
173 |
174 | self.opt_disc.zero_grad()
175 | gan_loss.backward()
176 |
177 | self.opt_vq.step()
178 | self.opt_disc.step()
179 |
180 | return decoded_images, vq_loss, gan_loss
181 |
182 | def train(
183 | self,
184 | dataloader: torch.utils.data.DataLoader,
185 | epochs: int = 1,
186 | ):
187 | """Trains the VQGAN for the given number of epochs
188 |
189 | Args:
190 | dataloader (torch.utils.data.DataLoader): dataloader to use.
191 | epochs (int, optional): number of epochs to train for. Defaults to 100.
192 | """
193 |
194 | for epoch in range(epochs):
195 | for index, imgs in enumerate(dataloader):
196 |
197 | # Training step
198 | imgs = imgs.to(self.device)
199 |
200 | decoded_images, vq_loss, gan_loss = self.step(imgs)
201 |
202 | # Updating global step
203 | self.global_step += 1
204 |
205 | if index % self.save_every == 0:
206 |
207 | print(
208 | f"Epoch: {epoch+1}/{epochs} | Batch: {index}/{len(dataloader)} | VQ Loss : {vq_loss:.4f} | Discriminator Loss: {gan_loss:.4f}"
209 | )
210 |
211 | # Only saving the gif for the first 2000 save steps
212 | if self.global_step // self.save_every <= 2000:
213 | self.sample_batch = (
214 | imgs[:] if self.sample_batch is None else self.sample_batch
215 | )
216 |
217 | with torch.no_grad():
218 |
219 | """
220 | Note : Lots of efficiency & cleaning needed here
221 | """
222 |
223 | gif_img = (
224 | torchvision.utils.make_grid(
225 | torch.cat(
226 | (
227 | self.sample_batch,
228 | self.vqgan(self.sample_batch)[0],
229 | ),
230 | )
231 | )
232 | .detach()
233 | .cpu()
234 | .permute(1, 2, 0)
235 | .numpy()
236 | )
237 |
238 | gif_img = (gif_img - gif_img.min()) * (
239 | 255 / (gif_img.max() - gif_img.min())
240 | )
241 | gif_img = gif_img.astype(np.uint8)
242 |
243 | self.run.track(
244 | Image(
245 | torchvision.utils.make_grid(
246 | torch.cat(
247 | (
248 | imgs,
249 | decoded_images,
250 | ),
251 | )
252 | ).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
253 | ),
254 | name="VQGAN Reconstruction",
255 | step=self.global_step,
256 | context={"stage": "vqgan"},
257 | )
258 |
259 | self.gif_images.append(gif_img)
260 |
261 | imageio.mimsave(
262 | os.path.join(self.experiment_save_dir, "reconstruction.gif"),
263 | self.gif_images,
264 | fps=5,
265 | )
266 |
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import VQGANTransformer
--------------------------------------------------------------------------------
/transformer/mingpt.py:
--------------------------------------------------------------------------------
1 | """
2 | from : https://github.com/dome272/VQGAN-pytorch/blob/main/mingpt.py
3 | which is taken from: https://github.com/karpathy/minGPT/
4 | GPT model:
5 | - the initial stem consists of a combination of token encoding and a positional encoding
6 | - the meat of it is a uniform sequence of Transformer blocks
7 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
8 | - all blocks feed into a central residual pathway similar to resnets
9 | - the final decoder is a linear projection into a vanilla Softmax classifier
10 | """
11 |
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 |
17 |
18 | class GPTConfig:
19 | """ base GPT config, params common to all GPT versions """
20 | embd_pdrop = 0.1
21 | resid_pdrop = 0.1
22 | attn_pdrop = 0.1
23 |
24 | def __init__(self, vocab_size, block_size, **kwargs):
25 | self.vocab_size = vocab_size
26 | self.block_size = block_size
27 | for k, v in kwargs.items():
28 | setattr(self, k, v)
29 |
30 |
31 | class CausalSelfAttention(nn.Module):
32 | """
33 | A vanilla multi-head masked self-attention layer with a projection at the end.
34 | It is possible to use torch.nn.MultiheadAttention here but I am including an
35 | explicit implementation here to show that there is nothing too scary here.
36 | """
37 |
38 | def __init__(self, config):
39 | super().__init__()
40 | assert config.n_embd % config.n_head == 0
41 | # key, query, value projections for all heads
42 | self.key = nn.Linear(config.n_embd, config.n_embd)
43 | self.query = nn.Linear(config.n_embd, config.n_embd)
44 | self.value = nn.Linear(config.n_embd, config.n_embd)
45 | # regularization
46 | self.attn_drop = nn.Dropout(config.attn_pdrop)
47 | self.resid_drop = nn.Dropout(config.resid_pdrop)
48 | # output projection
49 | self.proj = nn.Linear(config.n_embd, config.n_embd)
50 | # causal mask to ensure that attention is only applied to the left in the input sequence
51 | mask = torch.tril(torch.ones(config.block_size,
52 | config.block_size))
53 | if hasattr(config, "n_unmasked"):
54 | mask[:config.n_unmasked, :config.n_unmasked] = 1
55 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
56 | self.n_head = config.n_head
57 |
58 | def forward(self, x, layer_past=None):
59 | B, T, C = x.size()
60 |
61 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
62 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
63 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
64 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
65 |
66 | present = torch.stack((k, v))
67 | if layer_past is not None:
68 | past_key, past_value = layer_past
69 | k = torch.cat((past_key, k), dim=-2)
70 | v = torch.cat((past_value, v), dim=-2)
71 |
72 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
73 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
74 | if layer_past is None:
75 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
76 |
77 | att = F.softmax(att, dim=-1)
78 | att = self.attn_drop(att)
79 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
80 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
81 |
82 | # output projection
83 | y = self.resid_drop(self.proj(y))
84 | return y, present # TODO: check that this does not break anything
85 |
86 |
87 | class Block(nn.Module):
88 | """ an unassuming Transformer block """
89 |
90 | def __init__(self, config):
91 | super().__init__()
92 | self.ln1 = nn.LayerNorm(config.n_embd)
93 | self.ln2 = nn.LayerNorm(config.n_embd)
94 | self.attn = CausalSelfAttention(config)
95 | self.mlp = nn.Sequential(
96 | nn.Linear(config.n_embd, 4 * config.n_embd),
97 | nn.GELU(), # nice
98 | nn.Linear(4 * config.n_embd, config.n_embd),
99 | nn.Dropout(config.resid_pdrop),
100 | )
101 |
102 | def forward(self, x, layer_past=None, return_present=False):
103 | # TODO: check that training still works
104 | if return_present:
105 | assert not self.training
106 | # layer past: tuple of length two with B, nh, T, hs
107 | attn, present = self.attn(self.ln1(x), layer_past=layer_past)
108 |
109 | x = x + attn
110 | x = x + self.mlp(self.ln2(x))
111 | if layer_past is not None or return_present:
112 | return x, present
113 | return x
114 |
115 |
116 | class GPT(nn.Module):
117 | """ the full GPT language model, with a context size of block_size """
118 |
119 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
120 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
121 | super().__init__()
122 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
123 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
124 | n_layer=n_layer, n_head=n_head, n_embd=n_embd,
125 | n_unmasked=n_unmasked)
126 | # input embedding stem
127 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
128 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) # 512 x 1024
129 | self.drop = nn.Dropout(config.embd_pdrop)
130 | # transformer
131 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
132 | # decoder head
133 | self.ln_f = nn.LayerNorm(config.n_embd)
134 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
135 |
136 | self.block_size = config.block_size
137 | self.apply(self._init_weights)
138 | self.config = config
139 |
140 | def get_block_size(self):
141 | return self.block_size
142 |
143 | def _init_weights(self, module):
144 | if isinstance(module, (nn.Linear, nn.Embedding)):
145 | module.weight.data.normal_(mean=0.0, std=0.02)
146 | if isinstance(module, nn.Linear) and module.bias is not None:
147 | module.bias.data.zero_()
148 | elif isinstance(module, nn.LayerNorm):
149 | module.bias.data.zero_()
150 | module.weight.data.fill_(1.0)
151 |
152 | def forward(self, idx, embeddings=None):
153 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
154 |
155 | if embeddings is not None: # prepend explicit embeddings
156 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
157 |
158 | t = token_embeddings.shape[1]
159 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
160 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
161 | x = self.drop(token_embeddings + position_embeddings)
162 | x = self.blocks(x)
163 | x = self.ln_f(x)
164 | logits = self.head(x)
165 |
166 | return logits, None
167 |
168 |
169 |
170 |
171 |
172 |
--------------------------------------------------------------------------------
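The GPT above maps a batch of token indices to per-position vocabulary logits: token embeddings plus learned positional embeddings, the stack of Blocks, a final LayerNorm, and a linear head. A minimal shape-check sketch, assuming the repository root is on the import path; the hyperparameters below are illustrative, not the defaults used elsewhere in the repo:

    import torch
    from transformer.mingpt import GPT

    # Small illustrative model: 1024-token vocabulary, context of 512 positions
    model = GPT(vocab_size=1024, block_size=512, n_layer=2, n_head=8, n_embd=256)

    idx = torch.randint(0, 1024, (4, 256))  # (B, T) batch of codebook indices
    logits, _ = model(idx)                  # forward returns (logits, None)
    assert logits.shape == (4, 256, 1024)   # (B, T, vocab_size)
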
/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/transformer.py
3 | """
4 |
5 | # Importing Libraries
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from transformer.mingpt import GPT
10 |
11 |
12 | class VQGANTransformer(nn.Module):
13 | def __init__(
14 | self,
15 | vqgan: nn.Module,
16 | device: str = "cuda",
17 | sos_token: int = 0,
18 | pkeep: float = 0.5,
19 | block_size: int = 512,
20 | n_layer: int = 12,
21 | n_head: int = 16,
22 | n_embd: int = 1024,
23 | ):
24 | super().__init__()
25 |
26 | self.sos_token = sos_token
27 | self.device = device
28 |
29 | self.vqgan = vqgan
30 |
31 | self.transformer = GPT(
32 | vocab_size=self.vqgan.num_codebook_vectors,
33 | block_size=block_size,
34 | n_layer=n_layer,
35 | n_head=n_head,
36 | n_embd=n_embd,
37 | )
38 |
39 | self.pkeep = pkeep
40 |
41 | @torch.no_grad()
42 |     def encode_to_z(self, x: torch.Tensor) -> torch.Tensor:
43 |         """Passes the input batch of images through the VQGAN encoder and returns the quantized encodings with their flattened codebook indices
44 | 
45 |         Args:
46 |             x (torch.Tensor): the input batch of shape b*c*h*w
47 | 
48 |         Returns:
49 |             torch.Tensor: the quantized encodings and the flattened codebook indices
50 |         """
51 | quant_z, indices, _ = self.vqgan.encode(x)
52 | indices = indices.view(quant_z.shape[0], -1)
53 | return quant_z, indices
54 |
55 | @torch.no_grad()
56 | def z_to_image(
57 | self, indices: torch.tensor, p1: int = 16, p2: int = 16
58 | ) -> torch.Tensor:
59 | """Returns the decoded image from the indices for the codebook embeddings
60 |
61 | Args:
62 | indices (torch.tensor): the indices of the vectors in codebook to use for generating the decoder output
63 | p1 (int, optional): encoding size. Defaults to 16.
64 | p2 (int, optional): encoding size. Defaults to 16.
65 |
66 | Returns:
67 | torch.tensor: generated image from decoder
68 | """
69 |
70 | ix_to_vectors = self.vqgan.codebook.codebook(indices).reshape(
71 | indices.shape[0], p1, p2, 256
72 | )
73 | ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2)
74 | image = self.vqgan.decode(ix_to_vectors)
75 | return image
76 |
77 | def forward(self, x:torch.Tensor) -> torch.Tensor:
78 | """
79 | transformer model forward pass
80 |
81 | Args:
82 | x (torch.tensor): Batch of images
83 | """
84 |
85 | # Getting the codebook indices of the image
86 | _, indices = self.encode_to_z(x)
87 |
88 |         # sos tokens, these will be needed when we generate new, unseen images
89 | sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
90 | sos_tokens = sos_tokens.long().to(self.device)
91 |
92 |         # Generating a matrix with the same shape as indices, filled with 1s and 0s
93 | mask = torch.bernoulli(
94 | self.pkeep * torch.ones(indices.shape, device=indices.device)
95 | ) # torch.bernoulli([0.5 ... 0.5]) -> [1, 0, 1, 1, 0, 0] ; p(1) - 0.5
96 | mask = mask.round().to(dtype=torch.int64)
97 |
98 |         # Generate a vector containing random indices drawn from the codebook vocabulary
99 | random_indices = torch.randint_like(
100 | indices, high=self.transformer.config.vocab_size
101 | ) # generating indices from the distribution
102 |
103 | """
104 | indices - [3, 56, 72, 67, 45, 53, 78, 90]
105 | mask - [1, 1, 0, 0, 1, 1, 1, 0]
106 |         random_indices - [15, 67, 27, 89, 92, 40, 91, 10]
107 |
108 | new_indices - [ 3, 56, 0, 0, 45, 53, 78, 0] + [ 0, 0, 27, 89, 0, 0, 0, 10] => [ 3, 56, 27, 89, 45, 53, 78, 10]
109 | """
110 | new_indices = mask * indices + (1 - mask) * random_indices
111 |
112 | # Adding sos ( start of sentence ) token
113 | new_indices = torch.cat((sos_tokens, new_indices), dim=1)
114 |
115 | target = indices
116 |
117 | logits, _ = self.transformer(new_indices[:, :-1])
118 |
119 | return logits, target
120 |
121 | def top_k_logits(self, logits: torch.Tensor, k: int) -> torch.Tensor:
122 |         """Masks out all but the k highest logits by setting the rest to -inf.
123 | 
124 | Args:
125 | logits (torch.Tensor): predictions from the transformer
126 | k (int): returning k highest values
127 |
128 | Returns:
129 |             torch.Tensor: returning a tensor of the same dimensions as the input, keeping only the top k entries
130 | """
131 | v, ix = torch.topk(logits, k)
132 | out = logits.clone()
133 | out[out < v[..., [-1]]] = -float(
134 | "inf"
135 |         )  # Setting all values outside the top k to -inf
136 | return out
137 |
138 | @torch.no_grad()
139 | def sample(
140 | self,
141 | x: torch.Tensor,
142 | c: torch.Tensor,
143 | steps: int = 256,
144 | temperature: float = 1.0,
145 | top_k: int = 100,
146 | ) -> torch.Tensor:
147 | """Generating sample indices from the transformer
148 |
149 | Args:
150 | x (torch.Tensor): the batch of images
151 | c (torch.Tensor): sos token
152 |             steps (int, optional): the length of the index sequence to generate. Defaults to 256.
153 |             temperature (float, optional): softmax temperature used to scale the logits before sampling. Defaults to 1.0.
154 |             top_k (int, optional): number of top entries to keep when sampling. Defaults to 100.
155 |
156 | Returns:
157 |             torch.Tensor: the generated sequence of codebook indices
158 | """
159 |
160 | self.transformer.eval()
161 |
162 | x = torch.cat((c, x), dim=1) # Appending sos token
163 | for k in range(steps):
164 | logits, _ = self.transformer(x) # Getting the predicted indices
165 | logits = (
166 | logits[:, -1, :] / temperature
167 | ) # Getting the last prediction and scaling it by temperature
168 |
169 | if top_k is not None:
170 | logits = self.top_k_logits(logits, top_k)
171 |
172 | probs = F.softmax(logits, dim=-1)
173 |
174 | ix = torch.multinomial(
175 | probs, num_samples=1
176 |             )  # Sampling the next index from the predicted distribution
177 |
178 | x = torch.cat((x, ix), dim=1)
179 |
180 | x = x[:, c.shape[1] :] # Removing the sos token
181 | self.transformer.train()
182 | return x
183 |
184 | @torch.no_grad()
185 | def log_images(self, x:torch.Tensor):
186 | """ Generating images using the transformer and decoder. Also uses encoder to complete partial images.
187 |
188 | Args:
189 | x (torch.Tensor): batch of images
190 |
191 | Returns:
192 |             Returns the input and generated images, both as a dictionary and as a single concatenated image
193 | """
194 | log = dict()
195 |
196 | _, indices = self.encode_to_z(x) # Getting the indices of the quantized encoding
197 | sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
198 | sos_tokens = sos_tokens.long().to(self.device)
199 |
200 | start_indices = indices[:, : indices.shape[1] // 2]
201 | sample_indices = self.sample(
202 | start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1]
203 | )
204 | half_sample = self.z_to_image(sample_indices)
205 |
206 | start_indices = indices[:, :0]
207 | sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1])
208 | full_sample = self.z_to_image(sample_indices)
209 |
210 | x_rec = self.z_to_image(indices)
211 |
212 | log["input"] = x
213 | log["rec"] = x_rec
214 | log["half_sample"] = half_sample
215 | log["full_sample"] = full_sample
216 |
217 | return log, torch.concat((x, x_rec, half_sample, full_sample))
218 |
219 | def load_checkpoint(self, path):
220 | """Loads the checkpoint from the given path."""
221 |
222 | self.load_state_dict(torch.load(path))
223 |
224 | def save_checkpoint(self, path):
225 | """Saves the checkpoint to the given path."""
226 |
227 | torch.save(self.state_dict(), path)
228 |
--------------------------------------------------------------------------------
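Putting the pieces of VQGANTransformer together: during training, `forward` keeps each codebook index with probability `pkeep` and replaces the rest with random indices while the target stays the original sequence; at generation time, `sample` grows a sequence autoregressively from the sos token and `z_to_image` decodes it. A rough sketch, assuming a freshly constructed (normally pretrained) VQGAN; the cross-entropy step mirrors what a training loop would typically do and is not code from this file:

    import torch
    import torch.nn.functional as F
    from transformer.transformer import VQGANTransformer
    from vqgan import VQGAN

    device = "cuda" if torch.cuda.is_available() else "cpu"

    vqgan = VQGAN(img_size=256).to(device)      # placeholder: usually a trained checkpoint is loaded here
    model = VQGANTransformer(vqgan, device=device).to(device)

    # Training step: the logits predict the original indices, even where inputs were randomly replaced (pkeep)
    imgs = torch.randn(2, 3, 256, 256, device=device)
    logits, target = model(imgs)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))

    # Generation: start from only the sos token, sample 256 indices, then decode them to images
    sos = torch.ones(2, 1, dtype=torch.long, device=device) * model.sos_token
    start = torch.empty(2, 0, dtype=torch.long, device=device)
    indices = model.sample(start, sos, steps=256)
    generated = model.z_to_image(indices)       # (2, 3, 256, 256)
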
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import weights_init, print_summary, generate_gif, clean_directory, reproducibility, collate_fn
--------------------------------------------------------------------------------
/utils/assets/aim_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/aim_images.png
--------------------------------------------------------------------------------
/utils/assets/aim_metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/aim_metrics.png
--------------------------------------------------------------------------------
/utils/assets/encoder_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/encoder_arch.png
--------------------------------------------------------------------------------
/utils/assets/nonlocalblocks_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/nonlocalblocks_arch.png
--------------------------------------------------------------------------------
/utils/assets/patchgan_disc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/patchgan_disc.png
--------------------------------------------------------------------------------
/utils/assets/perceptual_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/perceptual_loss.png
--------------------------------------------------------------------------------
/utils/assets/reconstruction.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/reconstruction.gif
--------------------------------------------------------------------------------
/utils/assets/sliding_window.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/sliding_window.png
--------------------------------------------------------------------------------
/utils/assets/stage_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/stage_1.png
--------------------------------------------------------------------------------
/utils/assets/stage_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/stage_2.png
--------------------------------------------------------------------------------
/utils/assets/vqgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/vqgan.png
--------------------------------------------------------------------------------
/utils/assets/vqvae_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shubhamai/pytorch-vqgan/120ae164f770cbf35a1ff3a51d95d707b29ef841/utils/assets/vqvae_arch.png
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import glob
3 | import os
4 | import random
5 | import shutil
6 |
7 | import imageio
8 | import numpy as np
9 | import torch
10 | from torchsummary import summary
11 |
12 |
13 | def print_summary(
14 | model: torch.nn.Module,
15 | input_data: torch.Tensor,
16 | col_names: list = ["input_size", "output_size", "num_params"],
17 | device: str = "cpu",
18 | depth: int = 2,
19 | ):
20 | """
21 | Prints a summary of the model.
22 | """
23 | return summary(
24 | model, input_data=input_data, col_names=col_names, device=device, depth=depth
25 | )
26 |
27 |
28 | def weights_init(m):
29 | """Setting up the weights for the discriminator model.
30 |     This is mentioned in the original PatchGAN paper, on page 16, section 6.2 - Training Details
31 |
32 | ```
33 | All networks were trained from scratch. Weights were initialized from a Gaussian distribution with mean 0 and
34 | standard deviation 0.02.
35 | ```
36 |
37 | Image-to-Image Translation with Conditional Adversarial Network - https://arxiv.org/pdf/1611.07004v3.pdf
38 |
39 | Args:
40 |         m (torch.nn.Module): module whose weights are initialized ( intended to be used via model.apply(weights_init) )
41 | """
42 |
43 | classname = m.__class__.__name__
44 | if classname.find("Conv") != -1:
45 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
46 | elif classname.find("BatchNorm") != -1:
47 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
48 | torch.nn.init.constant_(m.bias.data, 0)
49 |
50 |
51 | def generate_gif(imgs_path: str, save_path: str):
52 | """Generates a gif from a directory of images.
53 |
54 | Args:
55 | imgs_path: Path to the directory of images.
56 | save_path: Path to save the gif.
57 | """
58 |
59 | with imageio.get_writer(save_path, mode="I") as writer:
60 | for filename in glob.glob(imgs_path + "/*.jpg"):
61 | image = imageio.imread(filename)
62 | writer.append_data(image)
63 |
64 |
65 | def clean_directory(directory: str):
66 | """Cleans a directory.
67 | Args:
68 | directory: Path to the directory.
69 | """
70 |
71 |     if os.path.exists(directory):
72 |         shutil.rmtree(directory)
73 |     os.makedirs(directory, exist_ok=True)
74 |
75 |
76 | def reproducibility(seed: int = 42):
77 | """Set the random seed.
78 |
79 | Args:
80 | seed (int): The seed to use.
81 |
82 | Returns:
83 | None
84 | """
85 |
86 | torch.manual_seed(seed)
87 | torch.cuda.manual_seed(seed)
88 |
89 | np.random.seed(seed)
90 | random.seed(seed)
91 |
92 |
93 | def collate_fn(batch):
94 | """
95 | Collate function for the dataloader like mnist or cifar10.
96 | """
97 |
98 | imgs = torch.stack([img[0] for img in batch])
99 |
100 | return imgs
101 |
--------------------------------------------------------------------------------
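The typical way these helpers fit together: `reproducibility` fixes the seeds before anything is built, and `weights_init` is applied module-by-module through `nn.Module.apply`, giving every Conv and BatchNorm layer the N(0, 0.02) initialization quoted above. A small sketch (the Discriminator is just a convenient model to initialize here):

    import torch
    from vqgan import Discriminator
    from utils import reproducibility, weights_init

    reproducibility(42)                      # seeds torch, numpy and random

    disc = Discriminator(image_channels=3)
    disc.apply(weights_init)                 # Conv -> N(0, 0.02), BatchNorm -> N(1, 0.02) with zero bias

    # Rough check that the first conv layer now follows the N(0, 0.02) scheme
    w = disc.model[0].weight
    print(w.mean().item(), w.std().item())   # approximately 0.0 and 0.02
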
/vqgan/__init__.py:
--------------------------------------------------------------------------------
1 | from vqgan.encoder import Encoder
2 | from vqgan.decoder import Decoder
3 | from vqgan.codebook import CodeBook
4 | from vqgan.discriminator import Discriminator
5 | from vqgan.vqgan import VQGAN
6 |
--------------------------------------------------------------------------------
/vqgan/codebook.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/codebook.py
3 |
4 | Contains the implementation of the codebook for the VQGAN.
5 | With each forward pass, it returns the quantization loss, the indices of the nearest codebook vectors to the encoder output, and the corresponding codebook vectors themselves.
6 | """
7 |
8 | # Importing Libraries
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class CodeBook(nn.Module):
14 | """
15 |     In this class, we mostly implement section 3.1 from the paper.
16 | 
17 |     We generate the codebook from an nn.Embedding of the given size and randomly initialize its weights from a uniform distribution.
18 | 
19 |     The `forward` method calculates
20 |     1. the nearest codebook vector to the latent vector produced by the encoder,
21 |     2. the index of that nearest vector in the codebook,
22 |     3. the loss from eq. 4 ( except the reconstruction loss ).
23 |
24 | Args:
25 | num_codebook_vectors (int): Number of codebook vectors.
26 | latent_dim (int): Latent dimension of individual vectors.
27 |         beta (float): Beta value for the commitment loss.
28 | """
29 |
30 | def __init__(
31 |         self, num_codebook_vectors: int = 1024, latent_dim: int = 256, beta: float = 0.25
32 | ):
33 | super().__init__()
34 |
35 | self.num_codebook_vectors = num_codebook_vectors
36 | self.latent_dim = latent_dim
37 | self.beta = beta
38 |
39 | # creating the codebook, nn.Embedding here is simply a 2D array mainly for storing our embeddings, it's also learnable
40 | self.codebook = nn.Embedding(num_codebook_vectors, latent_dim)
41 |
42 | # Initializing the weights in codebook in uniform distribution
43 | self.codebook.weight.data.uniform_(
44 | -1 / num_codebook_vectors, 1 / num_codebook_vectors
45 | )
46 |
47 | def forward(self, z: torch.Tensor) -> torch.Tensor:
48 | """
49 | Calculates the loss and nearest vector in the codebook from the given latent vector.
50 |
51 | We are mostly implementing the eq 2 and 4 ( except reconstruction loss ) from the paper.
52 |
53 | Args:
54 | z (torch.Tensor): Latent vector.
55 | Returns:
56 | torch.Tensor: Nearest vector in the codebook.
57 | torch.Tensor: Index of the nearest vector in the codebook.
58 | torch.Tensor: Loss ( except reconstruction loss ).
59 | """
60 |
61 | # Channel to last dimension and copying the tensor to store it in a contiguous ( in a sequence ) way
62 | z = z.permute(0, 2, 3, 1).contiguous()
63 |
64 | z_flattened = z.view(
65 | -1, self.latent_dim
66 | ) # b*h*w * latent_dim, will look similar to codebook in fig 2 of the paper
67 |
68 | # calculating the distance between the z to the vectors in flattened codebook, from eq. 2
69 | # (a - b)^2 = a^2 + b^2 - 2ab
70 | distance = (
71 | torch.sum(
72 | z_flattened**2, dim=1, keepdim=True
73 | ) # keepdim = True to keep the same original shape after the sum
74 | + torch.sum(self.codebook.weight**2, dim=1)
75 | - 2
76 | * torch.matmul(
77 | z_flattened, self.codebook.weight.t()
78 | ) # 2*dot(z, codebook.T)
79 | )
80 |
81 | # getting indices of vectors with minimum distance from the codebook
82 | min_distance_indices = torch.argmin(distance, dim=1)
83 |
84 | # getting the corresponding vector from the codebook
85 | z_q = self.codebook(min_distance_indices).view(z.shape)
86 |
87 | """
88 |         This represents equation 4 from the paper ( except the reconstruction loss ). This loss will then be added
89 |         to the GAN loss to create the final loss function for the VQGAN, eq. 6 in the paper.
90 |
91 |
92 | Note : In the first para of A. Changlog section of the paper,
93 | they found a bug which resulted in beta equal to 1. here https://github.com/CompVis/taming-transformers/issues/57
94 | just a note :)
95 | """
96 | loss = torch.mean(
97 | (z_q.detach() - z) ** 2
98 | # detach() to avoid calculating gradient while backpropagating
99 | + self.beta
100 | * torch.mean(
101 | (z_q - z.detach()) ** 2
102 | ) # commitment loss, detach() to avoid calculating gradient while backpropagating
103 | )
104 |
105 |         # Straight-through estimator ( "preserving gradients" in the original implementation ): the forward value is z_q, but gradients flow back to z, bypassing the non-differentiable argmin
106 | z_q = z + (z_q - z).detach()
107 |
108 |         # reshaping to the original shape ( channels back to dim 1 )
109 | z_q = z_q.permute(0, 3, 1, 2)
110 |
111 | return z_q, min_distance_indices, loss
112 |
--------------------------------------------------------------------------------
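A short shape walk-through of the lookup in `forward`, using the default 1024-vector codebook, a 16x16 latent grid, and an arbitrary batch size of 2:

    import torch
    from vqgan import CodeBook

    codebook = CodeBook(num_codebook_vectors=1024, latent_dim=256)

    z = torch.randn(2, 256, 16, 16)   # encoder output: (B, latent_dim, H, W)
    z_q, indices, loss = codebook(z)

    print(z_q.shape)                  # (2, 256, 16, 16): nearest codebook vector per spatial position
    print(indices.shape)              # (2*16*16,) = (512,): one codebook index per position
    print(loss.shape)                 # scalar: codebook + commitment terms from eq. 4
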
/vqgan/common.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/helper.py
3 |
4 | The file contains Swish, Group Norm, Residual & Non-Local Blocks, Upsample and Downsample layer for VQGAN encoder and decoder blocks.
5 | """
6 |
7 | # Importing Libraries
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class Swish(nn.Module):
13 | """Swish activation function first introduced in the paper https://arxiv.org/abs/1710.05941v2
14 |     It has been shown to work better than ReLU on many datasets.
15 | """
16 |
17 | def __init__(self) -> None:
18 | super().__init__()
19 |
20 | def forward(self, x: torch.Tensor) -> torch.Tensor:
21 |
22 | return x * torch.sigmoid(x)
23 |
24 |
25 | class GroupNorm(nn.Module):
26 |     """Group Normalization is a method which normalizes the activations of a layer and gives consistent results across any batch size.
27 |     Note : Weight Standardization has also been shown to give better results when combined with group norm.
28 |
29 | Args:
30 | in_channels (int): Number of channels in the input tensor.
31 | """
32 |
33 | def __init__(self, in_channels: int) -> None:
34 | super().__init__()
35 |
36 | # num_groups is according to the official code provided by the authors,
37 | # eps is for numerical stability
38 |         # affine=True adds learnable scale and shift parameters applied after normalization
39 | self.group_norm = nn.GroupNorm(
40 | num_groups=32, num_channels=in_channels, eps=1e-06, affine=True
41 | )
42 |
43 | def forward(self, x: torch.Tensor) -> torch.Tensor:
44 | return self.group_norm(x)
45 |
46 |
47 | class ResidualBlock(nn.Module):
48 | """Residual Block from the paper,
49 | group norm -> swish -> conv -> group norm -> swish -> conv -> dropout -> conv -> skip connection
50 |
51 | Args:
52 | in_channels (int): Number of channels in the input tensor.
53 | out_channels (int): Number of channels in the output tensor.
54 | dropout (float): Dropout probability.
55 | """
56 |
57 | def __init__(self, in_channels:int, out_channels:int, dropout:float=0.0) -> None:
58 | super().__init__()
59 |
60 | self.in_channels = in_channels
61 | self.out_channels = out_channels
62 |
63 | self.block = nn.Sequential(
64 | GroupNorm(in_channels),
65 | Swish(),
66 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
67 | GroupNorm(out_channels),
68 | Swish(),
69 | nn.Dropout(dropout),
70 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
71 | )
72 |
73 | """
74 |         When the input and output channel counts differ, a 1x1 convolution is applied
75 |         on the shortcut path so the skip connection dimensions match.
76 | """
77 | if in_channels != out_channels:
78 | self.conv_shortcut = nn.Conv2d(
79 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
80 | )
81 |
82 | def forward(self, x: torch.Tensor) -> torch.Tensor:
83 |
84 | # shortcut connection
85 | if self.in_channels != self.out_channels:
86 | return self.conv_shortcut(x) + self.block(x)
87 | else:
88 | return x + self.block(x)
89 |
90 |
91 | class DownsampleBlock(nn.Module):
92 | """
93 | Down sample block for the encoder. pad -> conv
94 |
95 | Args:
96 | in_channels (int): Number of channels in the input tensor.
97 | """
98 |
99 | def __init__(self, in_channels:int) -> None:
100 | super().__init__()
101 |
102 | # (0, 1, 0, 1) - pad on left, right, top, bottom, with respective size
103 | self.pad = nn.ConstantPad2d((0, 1, 0, 1), value=0) # and fill value of 0
104 |
105 | self.conv = nn.Conv2d(
106 | in_channels, in_channels, kernel_size=3, stride=2, padding=0
107 | )
108 |
109 | def forward(self, x: torch.Tensor) -> torch.Tensor:
110 |
111 | x = self.pad(x)
112 |
113 | return self.conv(x)
114 |
115 |
116 | class UpsampleBlock(nn.Module):
117 | """
118 | Upsample block for the decoder. interpolate -> conv
119 |
120 | Args:
121 | in_channels (int): Number of channels in the input tensor.
122 | """
123 |
124 | def __init__(self, in_channels:int) -> None:
125 | super().__init__()
126 |
127 | self.conv = nn.Conv2d(
128 | in_channels, in_channels, kernel_size=3, stride=1, padding=1
129 | )
130 |
131 | def forward(self, x: torch.Tensor) -> torch.Tensor:
132 |
133 | x = torch.nn.functional.interpolate(x, scale_factor=2.0)
134 |
135 | return self.conv(x)
136 |
137 |
138 | class NonLocalBlock(nn.Module):
139 | """Attention mechanism similar to transformers but for CNNs, paper https://arxiv.org/abs/1805.08318
140 |
141 | Args:
142 | in_channels (int): Number of channels in the input tensor.
143 | """
144 |
145 | def __init__(self, in_channels:int) -> None:
146 | super().__init__()
147 |
148 | self.in_channels = in_channels
149 |
150 | # normalization layer
151 | self.norm = GroupNorm(in_channels)
152 |
153 | # query, key and value layers
154 | self.q = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
155 | self.k = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
156 | self.v = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
157 |
158 | self.project_out = nn.Conv2d(in_channels, in_channels, 1, 1, 0)
159 |
160 | self.softmax = nn.Softmax(dim=2)
161 |
162 | def forward(self, x):
163 |
164 | batch, _, height, width = x.size()
165 |
166 | x = self.norm(x)
167 |
168 | # query, key and value layers
169 | q = self.q(x)
170 | k = self.k(x)
171 | v = self.v(x)
172 |
173 | # resizing the output from 4D to 3D to generate attention map
174 | q = q.reshape(batch, self.in_channels, height * width)
175 | k = k.reshape(batch, self.in_channels, height * width)
176 | v = v.reshape(batch, self.in_channels, height * width)
177 |
178 | # transpose the query tensor for dot product
179 | q = q.permute(0, 2, 1)
180 |
181 | # main attention formula
182 | scores = torch.bmm(q, k) * (self.in_channels**-0.5)
183 | weights = self.softmax(scores)
184 | weights = weights.permute(0, 2, 1)
185 |
186 | attention = torch.bmm(v, weights)
187 |
188 | # resizing the output from 3D to 4D to match the input
189 | attention = attention.reshape(batch, self.in_channels, height, width)
190 | attention = self.project_out(attention)
191 |
192 | # adding the identity to the output
193 | return x + attention
194 |
--------------------------------------------------------------------------------
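All of these blocks preserve the spatial size except the down/upsampling ones, which halve or double it. A quick sketch exercising them (the channel and spatial sizes are just examples; channel counts must stay divisible by the 32 groups used in GroupNorm):

    import torch
    from vqgan.common import ResidualBlock, NonLocalBlock, DownsampleBlock, UpsampleBlock

    x = torch.randn(1, 128, 32, 32)

    print(ResidualBlock(128, 256)(x).shape)   # (1, 256, 32, 32): 1x1 shortcut matches the channel change
    print(NonLocalBlock(128)(x).shape)        # (1, 128, 32, 32): attention over the 32*32 positions
    print(DownsampleBlock(128)(x).shape)      # (1, 128, 16, 16): pad + stride-2 conv halves H and W
    print(UpsampleBlock(128)(x).shape)        # (1, 128, 64, 64): nearest-neighbour interpolate then conv
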
/vqgan/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/decoder.py
3 |
4 | Contains the decoder implementation of VQGAN.
5 |
6 | The decoder architecture, like the encoder, is highly inspired by Denoising Diffusion Probabilistic Models - https://arxiv.org/abs/2006.11239 -
7 | according to the official implementation.
8 | """
9 |
10 | # Importing Libraries
11 | import torch
12 | import torch.nn as nn
13 |
14 | from vqgan.common import GroupNorm, NonLocalBlock, ResidualBlock, Swish, UpsampleBlock
15 |
16 |
17 | class Decoder(nn.Module):
18 | """
19 | The decoder part of the VQGAN.
20 |
21 | The implementation is similar to the encoder but inverse, to produce an image from a latent vector.
22 |
23 | Args:
24 | img_channels (int): Number of channels in the output image.
25 | latent_channels (int): Number of channels in the latent vector.
26 | latent_size (int): Size of the latent vector.
27 | intermediate_channels (list): List of channels in the intermediate layers.
28 | num_residual_blocks (int): Number of residual blocks b/w each downsample block.
29 | dropout (float): Dropout probability for residual blocks.
30 | attention_resolution (list): tensor size ( height or width ) at which to add attention blocks
31 | """
32 |
33 | def __init__(
34 | self,
35 | img_channels: int = 3,
36 | latent_channels: int = 256,
37 | latent_size: int = 16,
38 | intermediate_channels: list = [128, 128, 256, 256, 512],
39 | num_residual_blocks: int = 3,
40 | dropout: float = 0.0,
41 | attention_resolution: list = [16],
42 | ):
43 | super().__init__()
44 |
45 | # Reverse the list to get the correct order of decoder layer channels
46 | intermediate_channels = intermediate_channels[::-1]
47 |
48 | # Appends all the layers to this list
49 | layers = []
50 |
51 | # Adding the first conv layer to increase the input channels to the first intermediate channels
52 | in_channels = intermediate_channels[0]
53 | layers.extend(
54 | [
55 | nn.Conv2d(
56 | latent_channels,
57 | intermediate_channels[0],
58 | kernel_size=3,
59 | stride=1,
60 | padding=1,
61 | ),
62 | ResidualBlock(
63 | in_channels=in_channels, out_channels=in_channels, dropout=dropout
64 | ),
65 | NonLocalBlock(in_channels=in_channels),
66 | ResidualBlock(
67 | in_channels=in_channels, out_channels=in_channels, dropout=dropout
68 | ),
69 | ]
70 | )
71 |
72 | # Loop over the intermediate channels
73 | for n in range(len(intermediate_channels)):
74 | out_channels = intermediate_channels[n]
75 |
76 | # adding the residual blocks
77 | for _ in range(num_residual_blocks):
78 | layers.append(ResidualBlock(in_channels, out_channels, dropout=dropout))
79 | in_channels = out_channels
80 |
81 | # adding the non local block
82 | if latent_size in attention_resolution:
83 | layers.append(NonLocalBlock(in_channels))
84 |
85 |             # Skip upsampling on the first iteration, since the initial conv stack above already operates at this resolution
86 | if n != 0:
87 | layers.append(UpsampleBlock(in_channels=in_channels))
88 | latent_size = latent_size * 2 # Upsample by a factor of 2
89 |
90 | # Adding rest of the layers
91 | layers.extend(
92 | [
93 | GroupNorm(in_channels=in_channels),
94 | Swish(),
95 | nn.Conv2d(
96 | in_channels, img_channels, kernel_size=3, stride=1, padding=1
97 | ),
98 | ]
99 | )
100 | self.model = nn.Sequential(*layers)
101 |
102 | def forward(self, x: torch.Tensor) -> torch.Tensor:
103 | return self.model(x)
104 |
--------------------------------------------------------------------------------
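With the default arguments the decoder applies four UpsampleBlocks, so a 16x16 quantized latent is doubled four times back to a 256x256 image. A shape sketch (batch size arbitrary):

    import torch
    from vqgan import Decoder

    decoder = Decoder(img_channels=3, latent_channels=256, latent_size=16)

    z_q = torch.randn(1, 256, 16, 16)   # quantized latent from the codebook
    img = decoder(z_q)
    print(img.shape)                    # (1, 3, 256, 256): 16x16 doubled four times
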
/vqgan/discriminator.py:
--------------------------------------------------------------------------------
1 | """
2 | PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538)
3 |
4 |
5 | This isn't a standard GAN discriminator, where the input is a batch of images and the output is a batch of real/fake labels.
6 | 
7 | 
8 | Instead, a PatchGAN discriminator is a network that effectively splits each image
9 | into multiple patches
10 | 
11 | for ex. - 30x30 patches, 30 along the x axis and 30 along the y axis, similar to convolution kernels,
12 | and scores each of those individual patches as real/fake.
13 | 
14 | ex. - input size (1, 3, 256, 256) -> output size (1, 1, 30, 30)
15 |
16 | """
17 |
18 | import torch.nn as nn
19 |
20 |
21 | class Discriminator(nn.Module):
22 | """ PatchGAN Discriminator
23 |
24 |
25 | Args:
26 | image_channels (int): Number of channels in the input image.
27 | num_filters_last (int): Number of filters in the last layer of the discriminator.
28 | n_layers (int): Number of layers in the discriminator.
29 |
30 |
31 | """
32 |
33 | def __init__(self, image_channels:int=3, num_filters_last=64, n_layers=3):
34 | super(Discriminator, self).__init__()
35 |
36 | layers = [
37 | nn.Conv2d(image_channels, num_filters_last, 4, 2, 1),
38 | nn.LeakyReLU(0.2),
39 | ]
40 | num_filters_mult = 1
41 |
42 | for i in range(1, n_layers + 1):
43 | num_filters_mult_last = num_filters_mult
44 | num_filters_mult = min(2**i, 8)
45 | layers += [
46 | nn.Conv2d(
47 | num_filters_last * num_filters_mult_last,
48 | num_filters_last * num_filters_mult,
49 | 4,
50 | 2 if i < n_layers else 1,
51 | 1,
52 | bias=False,
53 | ),
54 | nn.BatchNorm2d(num_filters_last * num_filters_mult),
55 | nn.LeakyReLU(0.2, True),
56 | ]
57 |
58 | layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
59 | self.model = nn.Sequential(*layers)
60 |
61 | def forward(self, x):
62 | return self.model(x)
--------------------------------------------------------------------------------
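A quick check of the patch-level output described in the docstring, using the defaults (num_filters_last=64, n_layers=3):

    import torch
    from vqgan import Discriminator

    disc = Discriminator(image_channels=3, num_filters_last=64, n_layers=3)

    fake = torch.randn(1, 3, 256, 256)
    patch_scores = disc(fake)
    print(patch_scores.shape)   # torch.Size([1, 1, 30, 30]): one real/fake logit per image patch
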
/vqgan/encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/encoder.py
3 |
4 | Contains the encoder implementation of VQGAN.
5 |
6 | The encoder is highly inspired by the - Denoising Diffusion Probabilistic Models - https://arxiv.org/abs/2006.11239
7 | According to the official implementation.
8 | """
9 |
10 | # Importing Libraries
11 | import torch
12 | import torch.nn as nn
13 |
14 | from vqgan.common import DownsampleBlock, GroupNorm, NonLocalBlock, ResidualBlock, Swish
15 |
16 |
17 | class Encoder(nn.Module):
18 | """
19 | The encoder part of the VQGAN.
20 |
21 | Args:
22 | img_channels (int): Number of channels in the input image.
23 | image_size (int): Size of the input image, only used in encoder (height or width ).
24 | latent_channels (int): Number of channels in the latent vector.
25 | intermediate_channels (list): List of channels in the intermediate layers.
26 | num_residual_blocks (int): Number of residual blocks b/w each downsample block.
27 | dropout (float): Dropout probability for residual blocks.
28 | attention_resolution (list): tensor size ( height or width ) at which to add attention blocks
29 | """
30 |
31 | def __init__(
32 | self,
33 | img_channels: int = 3,
34 | image_size: int = 256,
35 | latent_channels: int = 256,
36 | intermediate_channels: list = [128, 128, 256, 256, 512],
37 | num_residual_blocks: int = 2,
38 | dropout: float = 0.0,
39 | attention_resolution: list = [16],
40 | ):
41 | super().__init__()
42 |
43 |         # Duplicating the first intermediate channel at index 0, so the first residual stage keeps the same channel count
44 | intermediate_channels.insert(0, intermediate_channels[0])
45 |
46 | # Appends all the layers to this list
47 | layers = []
48 |
49 |         # Adding the first conv layer to increase the input channels to the first intermediate channel count
50 | layers.append(
51 | nn.Conv2d(
52 | img_channels,
53 | intermediate_channels[0],
54 | kernel_size=3,
55 | stride=1,
56 | padding=1,
57 | )
58 | )
59 |
60 | # Loop over the intermediate channels except the last one
61 | for n in range(len(intermediate_channels) - 1):
62 | in_channels = intermediate_channels[n]
63 | out_channels = intermediate_channels[n + 1]
64 |
65 | # Adding the residual blocks for each channel
66 | for _ in range(num_residual_blocks):
67 | layers.append(ResidualBlock(in_channels, out_channels, dropout=dropout))
68 | in_channels = out_channels
69 |
70 | # Once we have downsampled the image to the size in attention resolution, we add attention blocks
71 | if image_size in attention_resolution:
72 | layers.append(NonLocalBlock(in_channels))
73 |
74 |             # Downsample on all but the last stage, halving the spatial size each time
75 | if n != len(intermediate_channels) - 2:
76 | layers.append(DownsampleBlock(in_channels=intermediate_channels[n + 1]))
77 | image_size = image_size // 2 # Downsample by a factor of 2
78 |
79 | in_channels = intermediate_channels[-1]
80 | layers.extend(
81 | [
82 | ResidualBlock(
83 | in_channels=in_channels, out_channels=in_channels, dropout=dropout
84 | ),
85 | NonLocalBlock(in_channels=in_channels),
86 | ResidualBlock(
87 | in_channels=in_channels, out_channels=in_channels, dropout=dropout
88 | ),
89 | GroupNorm(in_channels=in_channels),
90 | Swish(),
91 |                 # increase the channels up to the latent vector channels
92 | nn.Conv2d(
93 | in_channels, latent_channels, kernel_size=3, stride=1, padding=1
94 | ),
95 | ]
96 | )
97 | self.model = nn.Sequential(*layers)
98 |
99 | def forward(self, x: torch.Tensor) -> torch.Tensor:
100 | return self.model(x)
101 |
--------------------------------------------------------------------------------
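With the defaults the encoder applies four DownsampleBlocks, taking a 256x256 image down to a 16x16 latent; that is also the resolution at which attention_resolution=[16] switches on the NonLocalBlocks. A shape sketch:

    import torch
    from vqgan import Encoder

    encoder = Encoder(img_channels=3, image_size=256, latent_channels=256)

    img = torch.randn(1, 3, 256, 256)
    z = encoder(img)
    print(z.shape)   # (1, 256, 16, 16): 256 halved four times, channels raised to latent_channels
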
/vqgan/vqgan.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/dome272/VQGAN-pytorch/blob/main/vqgan.py
3 |
4 | Implementing the main VQGAN, containing forward pass, lambda calculation, and to "enable" discriminator loss after a certain number of global steps.
5 | """
6 |
7 | # Importing Libraries
8 | import torch
9 | import torch.nn as nn
10 |
11 | from vqgan import Encoder
12 | from vqgan import Decoder
13 | from vqgan import CodeBook
14 |
15 |
16 | class VQGAN(nn.Module):
17 | """
18 | VQGAN class
19 |
20 | Args:
21 | img_channels (int, optional): Number of channels in the input image. Defaults to 3.
22 | img_size (int, optional): Size of the input image. Defaults to 256.
23 | latent_channels (int, optional): Number of channels in the latent vector. Defaults to 256.
24 | latent_size (int, optional): Size of the latent vector. Defaults to 16.
25 | intermediate_channels (list, optional): List of channels in the intermediate layers of encoder and decoder. Defaults to [128, 128, 256, 256, 512].
26 | num_residual_blocks_encoder (int, optional): Number of residual blocks in the encoder. Defaults to 2.
27 | num_residual_blocks_decoder (int, optional): Number of residual blocks in the decoder. Defaults to 3.
28 | dropout (float, optional): Dropout probability. Defaults to 0.0.
29 | attention_resolution (list, optional): Resolution of the attention mechanism. Defaults to [16].
30 | num_codebook_vectors (int, optional): Number of codebook vectors. Defaults to 1024.
31 | """
32 |
33 | def __init__(
34 | self,
35 | img_channels: int = 3,
36 | img_size: int = 256,
37 | latent_channels: int = 256,
38 | latent_size: int = 16,
39 | intermediate_channels: list = [128, 128, 256, 256, 512],
40 | num_residual_blocks_encoder: int = 2,
41 | num_residual_blocks_decoder: int = 3,
42 | dropout: float = 0.0,
43 | attention_resolution: list = [16],
44 | num_codebook_vectors: int = 1024,
45 | ):
46 |
47 | super().__init__()
48 |
49 | self.img_channels = img_channels
50 | self.num_codebook_vectors = num_codebook_vectors
51 |
52 | self.encoder = Encoder(
53 | img_channels=img_channels,
54 | image_size=img_size,
55 | latent_channels=latent_channels,
56 |             intermediate_channels=intermediate_channels[:],  # shallow copy of the list, so the encoder does not mutate the shared argument
57 | num_residual_blocks=num_residual_blocks_encoder,
58 | dropout=dropout,
59 | attention_resolution=attention_resolution,
60 | )
61 |
62 | self.decoder = Decoder(
63 | img_channels=img_channels,
64 | latent_channels=latent_channels,
65 | latent_size=latent_size,
66 |             intermediate_channels=intermediate_channels[:],  # shallow copy of the list
67 | num_residual_blocks=num_residual_blocks_decoder,
68 | dropout=dropout,
69 | attention_resolution=attention_resolution,
70 | )
71 | self.codebook = CodeBook(
72 | num_codebook_vectors=num_codebook_vectors, latent_dim=latent_channels
73 | )
74 |
75 | self.quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
76 | self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
77 |
78 | def forward(self, x: torch.Tensor) -> torch.Tensor:
79 |         """Runs a full encoder -> codebook -> decoder pass on the input tensor x
80 | 
81 |         Args:
82 |             x (torch.Tensor): Input tensor to the encoder.
83 | 
84 |         Returns:
85 |             torch.Tensor: Decoded images, codebook indices and the codebook loss.
86 | """
87 |
88 | encoded_images = self.encoder(x)
89 | quant_x = self.quant_conv(encoded_images)
90 |
91 | codebook_mapping, codebook_indices, codebook_loss = self.codebook(quant_x)
92 |
93 | post_quant_x = self.post_quant_conv(codebook_mapping)
94 | decoded_images = self.decoder(post_quant_x)
95 |
96 | return decoded_images, codebook_indices, codebook_loss
97 |
98 | def encode(self, x: torch.Tensor) -> torch.Tensor:
99 |
100 | x = self.encoder(x)
101 | quant_x = self.quant_conv(x)
102 |
103 | codebook_mapping, codebook_indices, q_loss = self.codebook(quant_x)
104 |
105 | return codebook_mapping, codebook_indices, q_loss
106 |
107 | def decode(self, x: torch.Tensor) -> torch.Tensor:
108 |
109 | x = self.post_quant_conv(x)
110 | x = self.decoder(x)
111 |
112 | return x
113 |
114 | def calculate_lambda(self, perceptual_loss, gan_loss):
115 | """Calculating lambda shown in the eq. 7 of the paper
116 |
117 | Args:
118 | perceptual_loss (torch.Tensor): Perceptual reconstruction loss.
119 | gan_loss (torch.Tensor): loss from the GAN discriminator.
120 | """
121 |
122 | last_layer = self.decoder.model[-1]
123 | last_layer_weight = last_layer.weight
124 |
125 | # Because we have multiple loss functions in the networks, retain graph helps to keep the computational graph for backpropagation
126 | # https://stackoverflow.com/a/47174709
127 | perceptual_loss_grads = torch.autograd.grad(
128 | perceptual_loss, last_layer_weight, retain_graph=True
129 | )[0]
130 | gan_loss_grads = torch.autograd.grad(
131 | gan_loss, last_layer_weight, retain_graph=True
132 | )[0]
133 |
134 | lmda = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
135 | lmda = torch.clamp(
136 | lmda, 0, 1e4
137 | ).detach() # Here, we are constraining the value of lambda between 0 and 1e4,
138 |
139 |         return 0.8 * lmda  # The 0.8 factor scales lambda down, matching the discriminator weight used in the original implementation's configs
140 |
141 | @staticmethod
142 | def adopt_weight(
143 | disc_factor: float, i: int, threshold: int, value: float = 0.0
144 | ) -> float:
145 |         """Starts the discriminator later in training, so that our model has enough time to generate "good-enough" images to try to "fool the discriminator".
146 | 
147 |         To do that, before reaching a certain global step we set the discriminator factor to `value` ( default 0.0 ).
148 |         This discriminator factor is then used to multiply the discriminator's loss.
149 | 
150 |         Args:
151 |             disc_factor (float): The value multiplied with the discriminator's loss once the threshold is reached.
152 |             i (int): The current global step
153 |             threshold (int): The global step after which the `disc_factor` value is returned.
154 | value (float, optional): The value of discriminator factor before the threshold is reached. Defaults to 0.0.
155 |
156 | Returns:
157 | float: The discriminator factor.
158 | """
159 |
160 | if i < threshold:
161 | disc_factor = value
162 |
163 | return disc_factor
164 |
165 | def load_checkpoint(self, path):
166 | """Loads the checkpoint from the given path."""
167 |
168 | self.load_state_dict(torch.load(path))
169 |
170 | def save_checkpoint(self, path):
171 | """Saves the checkpoint to the given path."""
172 |
173 | torch.save(self.state_dict(), path)
174 |
--------------------------------------------------------------------------------
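A rough sketch of how `calculate_lambda` and `adopt_weight` might slot into the first-stage training objective (eq. 6): the reconstruction and GAN losses below are simple stand-ins, `disc_start` is an illustrative threshold, and the actual loop lives in the trainer module, not here:

    import torch
    import torch.nn.functional as F
    from vqgan import VQGAN, Discriminator

    vqgan = VQGAN(img_size=256)
    disc = Discriminator(image_channels=3)
    global_step, disc_start = 0, 1000       # illustrative step counter and warm-up threshold

    imgs = torch.randn(1, 3, 256, 256)
    decoded, _, q_loss = vqgan(imgs)

    # Stand-ins for the perceptual reconstruction loss and the generator's GAN loss
    perceptual_loss = F.l1_loss(decoded, imgs)
    gan_loss = -disc(decoded).mean()

    lmbda = vqgan.calculate_lambda(perceptual_loss, gan_loss)        # adaptive weight from eq. 7
    disc_factor = VQGAN.adopt_weight(1.0, global_step, disc_start)   # 0.0 until disc_start is reached
    vq_loss = perceptual_loss + q_loss + disc_factor * lmbda * gan_loss
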