├── .gitignore
├── README.md
├── config
│   └── celebhq.yaml
├── dataset
│   ├── __init__.py
│   └── celeb_dataset.py
├── model
│   ├── __init__.py
│   ├── attention.py
│   ├── blocks.py
│   ├── discriminator.py
│   ├── lpips.py
│   ├── patch_embed.py
│   ├── transformer.py
│   ├── transformer_layer.py
│   ├── vae.py
│   └── weights
│       └── v0.1
│           └── .gitkeep
├── requirements.txt
├── scheduler
│   ├── __init__.py
│   └── linear_scheduler.py
├── tools
│   ├── __init__.py
│   ├── infer_vae.py
│   ├── sample_vae_dit.py
│   ├── train_vae.py
│   └── train_vae_dit.py
└── utils
    ├── __init__.py
    └── diffusion_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 | *.json
15 |
16 | # Ignore checkpoints
17 | *.pth
18 |
19 | # Ignore pickle files
20 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Diffusion Transformers (DiT) Implementation in PyTorch
2 | ========
3 |
4 | ## DiT Tutorial Video
5 |
6 |
8 |
9 |
10 | ## Sample Output for DiT on CelebHQ
11 | Trained for 200 epochs
12 |
13 |
14 |
15 | ___
16 |
17 | This repository implements DiT in PyTorch for diffusion models. It provides code for the following:
18 | * Training and inference of VAE on CelebHQ (128x128 to 32x32x4 latents)
19 | * Training and inference of DiT using the trained VAE on CelebHQ
20 | * Configurable code for training all models from DiT-S to DiT-XL
21 |
22 | This is very similar to the [official DiT implementation](https://github.com/facebookresearch/DiT), except for the following changes:
23 | * Since training is on CelebHQ, it is unconditional generation as of now (but it can easily be modified for class-conditional or text-conditional generation as well)
24 | * Variance is fixed during training and not learned (as in the original DDPM)
25 | * No EMA
26 | * Ability to train the VAE
27 | * Ability to save the VAE latents for faster training
28 |
29 |
30 | ## Setup
31 | * Create a new conda environment with Python 3.10, then run the commands below
32 | * `conda activate `
33 | * ```git clone https://github.com/explainingai-code/DiT-PyTorch.git```
34 | * ```cd DiT-PyTorch```
35 | * ```pip install -r requirements.txt```
36 | * Download the LPIPS weights by opening this link in a browser (don't use cURL or wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```model/weights/v0.1/vgg.pth```
37 | ___
38 |
39 | ## Data Preparation
40 |
41 | ### CelebHQ
42 | For setting up CelebHQ, simply download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file)
43 | and add them to the `data` directory.
44 | Ensure the directory structure is the following:
45 | ```
46 | DiT-PyTorch
47 | -> data
48 | -> CelebAMask-HQ
49 | -> CelebA-HQ-img
50 | -> *.jpg
51 |
52 | ```
53 | ---
54 | ## Configuration
55 | Allows you to play with different components of DiT and the autoencoder.
56 | * ```config/celebhq.yaml``` - Configuration used for the CelebHQ dataset
57 |
58 | Important configuration parameters
59 |
60 | * `autoencoder_acc_steps` : For accumulating gradients when the image size is too large and a large batch size can't be used (a minimal sketch of gradient accumulation is shown below).
61 | * `save_latents` : Enable this to save the latents during inference of the autoencoder. That way DiT training will be faster.
62 |
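For reference, here is a minimal sketch of what gradient accumulation amounts to. This is not the repo's actual training loop; the tiny model, random data and `acc_steps` value below are toy stand-ins for the autoencoder, the dataloader and `autoencoder_acc_steps`.

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)                        # toy stand-in for the autoencoder
loader = [torch.randn(4, 16) for _ in range(8)]  # toy stand-in for the dataloader
acc_steps = 4                                    # plays the role of autoencoder_acc_steps

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
optimizer.zero_grad()
for step, batch in enumerate(loader):
    loss = nn.functional.mse_loss(model(batch), batch) / acc_steps  # scale loss by the accumulation window
    loss.backward()                                                 # gradients accumulate across iterations
    if (step + 1) % acc_steps == 0:
        optimizer.step()                                            # effective batch size = batch_size * acc_steps
        optimizer.zero_grad()
```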
63 | ___
64 | ## Training
65 | The repo provides training and inference for CelebHQ (Unconditional DiT)
66 |
67 | For working on your own dataset:
68 | * Create your own config and have the path in config point to images (look at `celebhq.yaml` for guidance)
69 | * Create your own dataset class which will just collect all the filenames and return the image or latent in its `__getitem__` method. Look at `celeb_dataset.py` for guidance (a minimal sketch is shown below)
70 |
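Below is a minimal, hypothetical sketch of such a dataset class. The `*.jpg` glob, `im_size` default and class name are placeholders; `celeb_dataset.py` remains the reference implementation.

```python
import glob
import os
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset


class MyDataset(Dataset):
    def __init__(self, split, im_path, im_size=128):
        self.split = split
        self.im_size = im_size
        # Collect all image file paths under im_path
        self.images = sorted(glob.glob(os.path.join(im_path, '*.jpg')))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.im_size),
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()
        # Scale to [-1, 1] as the training code expects
        return (2 * im_tensor) - 1
```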
71 | Once the config and dataset are set up:
72 | * First train the autoencoder on your dataset using [this section](#training-autoencoder-for-dit)
73 | * For training and inference of unconditional DiT, follow [this section](#training-unconditional-dit)
74 |
75 | ## Training AutoEncoder for DiT
76 | * For training the autoencoder on CelebHQ, ensure the right path is mentioned in `celebhq.yaml`
77 | * For training the autoencoder on your own dataset:
78 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance)
79 | * Create your own dataset class, similar to celeb_dataset.py
80 | * Call the desired dataset class in training file [here](https://github.com/explainingai-code/DiT-PyTorch/blob/main/tools/train_vae.py#L49)
81 | * For training the autoencoder, run ```python -m tools.train_vae --config config/celebhq.yaml``` with the desired config file
82 | * For inference, make sure `save_latents` is `True` in the config
83 | * For inference, run ```python -m tools.infer_vae --config config/celebhq.yaml``` to generate reconstructions and save latents with the right config file.
84 |
85 | ## Training Unconditional DiT
86 | Train the autoencoder first and set up the dataset accordingly.
87 |
88 | For training unconditional DiT, ensure the right dataset is used in `train_vae_dit.py`
89 | * ```python -m tools.train_vae_dit --config config/celebhq.yaml``` for training unconditional DiT using the right config
90 | * ```python -m tools.sample_vae_dit --config config/celebhq.yaml``` for generating images using the trained DiT
91 |
92 |
93 | ## Output
94 | Outputs will be saved according to the configuration present in yaml files.
95 |
96 | For every run, a folder named after the ```task_name``` key in the config will be created
97 |
98 | During training of the autoencoder the following outputs will be saved
99 | * Latest autoencoder and discriminator checkpoints in the ```task_name``` directory
100 | * Sample reconstructions in ```task_name/vae_autoencoder_samples```
101 |
102 | During inference of the autoencoder the following outputs will be saved
103 | * Reconstructions for random images in ```task_name```
104 | * Latents will be saved in ```task_name/vae_latent_dir_name``` if mentioned in the config
105 |
106 | During training and inference of unconditional DiT the following outputs will be saved:
107 | * During training we will save the latest checkpoint in the ```task_name``` directory
108 | * During sampling, an unconditionally sampled image grid for all timesteps is saved in ```task_name/samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be the latent image predictions of the denoising process from T=999 to T=1. The generated image is at T=0.
109 |
110 |
111 | ## Citations
112 | ```
113 | @misc{peebles2023scalablediffusionmodelstransformers,
114 | title={Scalable Diffusion Models with Transformers},
115 | author={William Peebles and Saining Xie},
116 | year={2023},
117 | eprint={2212.09748},
118 | archivePrefix={arXiv},
119 | primaryClass={cs.CV},
120 | url={https://arxiv.org/abs/2212.09748},
121 | }
122 | ```
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
/config/celebhq.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/CelebAMask-HQ'
3 | im_size : 128
4 | im_channels : 3
5 |
6 | diffusion_params:
7 | num_timesteps : 1000
8 | beta_start : 0.0001
9 | beta_end : 0.02
10 |
11 | dit_params:
12 | patch_size : 2
13 | num_layers : 12
14 | hidden_size : 768
15 | num_heads : 12
16 | head_dim : 64
17 | timestep_emb_dim : 768
18 |
19 | autoencoder_params:
20 | z_channels: 4
21 | codebook_size : 8192
22 | down_channels : [128, 256, 384]
23 | mid_channels : [384]
24 | down_sample : [True, True]
25 | attn_down : [False, False]
26 | norm_channels: 32
27 | num_heads: 4
28 | num_down_layers : 2
29 | num_mid_layers : 2
30 | num_up_layers : 2
31 |
32 |
33 | train_params:
34 | seed : 1111
35 | task_name: 'celebhq'
36 | autoencoder_batch_size: 4
37 | autoencoder_epochs: 3
38 | autoencoder_lr: 0.00001
39 | autoencoder_acc_steps: 1
40 | disc_start: 7500
41 | disc_weight: 0.5
42 | codebook_weight: 1
43 | commitment_beta: 0.2
44 | perceptual_weight: 1
45 | kl_weight: 0.000005
46 | autoencoder_img_save_steps: 64
47 | save_latents: False
48 | dit_batch_size: 32
49 | dit_epochs: 500
50 | num_samples: 1
51 | num_grid_rows: 2
52 | dit_lr: 0.00001
53 | dit_acc_steps: 1
54 | vae_latent_dir_name: 'vae_latents'
55 | dit_ckpt_name: 'dit_ckpt.pth'
56 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
57 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
58 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/celeb_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import cv2
4 | import torchvision
5 | import numpy as np
6 | from PIL import Image
7 | from utils.diffusion_utils import load_latents
8 | from tqdm import tqdm
9 | from torch.utils.data.dataset import Dataset
10 |
11 |
12 | class CelebDataset(Dataset):
13 | r"""
14 | Celeb dataset will by default centre crop and resize the images.
15 |     This can be replaced by any other dataset, as long as all the images
16 | are under one directory.
17 | """
18 |
19 | def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
20 | use_latents=False, latent_path=None):
21 | self.split = split
22 | self.im_size = im_size
23 | self.im_channels = im_channels
24 | self.im_ext = im_ext
25 | self.im_path = im_path
26 | self.latent_maps = None
27 | self.use_latents = False
28 |
29 | self.images = self.load_images(im_path)
30 |
31 | # Whether to load images or to load latents
32 | if use_latents and latent_path is not None:
33 | latent_maps = load_latents(latent_path)
34 | if len(latent_maps) == len(self.images):
35 | self.use_latents = True
36 | self.latent_maps = latent_maps
37 | print('Found {} latents'.format(len(self.latent_maps)))
38 | else:
39 | print('Latents not found')
40 |
41 | def load_images(self, im_path):
42 | r"""
43 | Gets all images from the path specified
44 | and stacks them all up
45 | """
46 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
47 | ims = []
48 | fnames = glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('png')))
49 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpg')))
50 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpeg')))
51 |
52 | for fname in tqdm(fnames):
53 | ims.append(fname)
54 |
55 | print('Found {} images'.format(len(ims)))
56 | return ims
57 |
58 | def __len__(self):
59 | return len(self.images)
60 |
61 | def __getitem__(self, index):
62 | if self.use_latents:
63 | latent = self.latent_maps[self.images[index]]
64 | return latent
65 |
66 | else:
67 | im = Image.open(self.images[index])
68 | im_tensor = torchvision.transforms.Compose([
69 | torchvision.transforms.Resize(self.im_size),
70 | torchvision.transforms.CenterCrop(self.im_size),
71 | torchvision.transforms.ToTensor(),
72 | ])(im)
73 | im.close()
74 |
75 | # Convert input to -1 to 1 range.
76 | im_tensor = (2 * im_tensor) - 1
77 |
78 | return im_tensor
79 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/model/__init__.py
--------------------------------------------------------------------------------
/model/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 |
6 | class Attention(nn.Module):
7 | r"""
8 | Attention Module for DiT.
9 |     This is the same as the ViT code and does not have any changes
10 | from it.
11 | """
12 | def __init__(self, config):
13 | super().__init__()
14 | self.n_heads = config['num_heads']
15 | self.hidden_size = config['hidden_size']
16 | self.head_dim = config['head_dim']
17 |
18 | self.att_dim = self.n_heads * self.head_dim
19 |
20 | # QKV projection for the input
21 | self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.att_dim, bias=True)
22 | self.output_proj = nn.Sequential(
23 | nn.Linear(self.att_dim, self.hidden_size))
24 |
25 | ############################
26 | # DiT Layer Initialization #
27 | ############################
28 | nn.init.xavier_uniform_(self.qkv_proj.weight)
29 | nn.init.constant_(self.qkv_proj.bias, 0)
30 | nn.init.xavier_uniform_(self.output_proj[0].weight)
31 | nn.init.constant_(self.output_proj[0].bias, 0)
32 |
33 | def forward(self, x):
34 | # Converting to Attention Dimension
35 | ######################################################
36 | # Batch Size x Number of Patches x Dimension
37 | B, N = x.shape[:2]
38 | # Projecting to 3*att_dim and then splitting to get q, k v(each of att_dim)
39 | # qkv -> Batch Size x Number of Patches x (3* Attention Dimension)
40 | # q(as well as k and v) -> Batch Size x Number of Patches x Attention Dimension
41 | q, k, v = self.qkv_proj(x).split(self.att_dim, dim=-1)
42 | # Batch Size x Number of Patches x Attention Dimension
43 | # -> Batch Size x Number of Patches x (Heads * Head Dimension)
44 | # -> Batch Size x Number of Patches x (Heads * Head Dimension)
45 | # -> Batch Size x Heads x Number of Patches x Head Dimension
46 | # -> B x H x N x Head Dimension
47 | q = rearrange(q, 'b n (n_h h_dim) -> b n_h n h_dim',
48 | n_h=self.n_heads, h_dim=self.head_dim)
49 | k = rearrange(k, 'b n (n_h h_dim) -> b n_h n h_dim',
50 | n_h=self.n_heads, h_dim=self.head_dim)
51 | v = rearrange(v, 'b n (n_h h_dim) -> b n_h n h_dim',
52 | n_h=self.n_heads, h_dim=self.head_dim)
53 | #########################################################
54 |
55 | # Compute Attention Weights
56 | #########################################################
57 | # B x H x N x Head Dimension @ B x H x Head Dimension x N
58 | # -> B x H x N x N
59 | att = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** (-0.5))
60 | att = torch.nn.functional.softmax(att, dim=-1)
61 | #########################################################
62 |
63 | # Weighted Value Computation
64 | #########################################################
65 | # B x H x N x N @ B x H x N x Head Dimension
66 | # -> B x H x N x Head Dimension
67 | out = torch.matmul(att, v)
68 | #########################################################
69 |
70 | # Converting to Transformer Dimension
71 | #########################################################
72 |         # B x H x N x Head Dimension -> B x N x (Heads * Head Dimension) = B x N x Attention Dimension
73 | out = rearrange(out, 'b n_h n h_dim -> b n (n_h h_dim)')
74 | # B x N x Dimension
75 | out = self.output_proj(out)
76 | ##########################################################
77 |
78 | return out
79 |
--------------------------------------------------------------------------------
/model/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_time_embedding(time_steps, temb_dim):
6 | r"""
7 | Convert time steps tensor into an embedding using the
8 | sinusoidal time embedding formula
9 | :param time_steps: 1D tensor of length batch size
10 | :param temb_dim: Dimension of the embedding
11 | :return: BxD embedding representation of B time steps
12 | """
13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
14 |
15 | # factor = 10000^(2i/d_model)
16 | factor = 10000 ** ((torch.arange(
17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
18 | )
19 |
20 | # pos / factor
21 | # timesteps B -> B, 1 -> B, temb_dim
22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
24 | return t_emb
25 |
26 |
27 | class DownBlock(nn.Module):
28 | r"""
29 | Down conv block with attention.
30 |     Sequence of the following blocks
31 | 1. Resnet block with time embedding
32 | 2. Attention block
33 | 3. Downsample
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, t_emb_dim,
37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
38 | super().__init__()
39 | self.num_layers = num_layers
40 | self.down_sample = down_sample
41 | self.attn = attn
42 | self.context_dim = context_dim
43 | self.cross_attn = cross_attn
44 | self.t_emb_dim = t_emb_dim
45 | self.resnet_conv_first = nn.ModuleList(
46 | [
47 | nn.Sequential(
48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
49 | nn.SiLU(),
50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
51 | kernel_size=3, stride=1, padding=1),
52 | )
53 | for i in range(num_layers)
54 | ]
55 | )
56 | if self.t_emb_dim is not None:
57 | self.t_emb_layers = nn.ModuleList([
58 | nn.Sequential(
59 | nn.SiLU(),
60 | nn.Linear(self.t_emb_dim, out_channels)
61 | )
62 | for _ in range(num_layers)
63 | ])
64 | self.resnet_conv_second = nn.ModuleList(
65 | [
66 | nn.Sequential(
67 | nn.GroupNorm(norm_channels, out_channels),
68 | nn.SiLU(),
69 | nn.Conv2d(out_channels, out_channels,
70 | kernel_size=3, stride=1, padding=1),
71 | )
72 | for _ in range(num_layers)
73 | ]
74 | )
75 |
76 | if self.attn:
77 | self.attention_norms = nn.ModuleList(
78 | [nn.GroupNorm(norm_channels, out_channels)
79 | for _ in range(num_layers)]
80 | )
81 |
82 | self.attentions = nn.ModuleList(
83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
84 | for _ in range(num_layers)]
85 | )
86 |
87 | if self.cross_attn:
88 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
89 | self.cross_attention_norms = nn.ModuleList(
90 | [nn.GroupNorm(norm_channels, out_channels)
91 | for _ in range(num_layers)]
92 | )
93 | self.cross_attentions = nn.ModuleList(
94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
95 | for _ in range(num_layers)]
96 | )
97 | self.context_proj = nn.ModuleList(
98 | [nn.Linear(context_dim, out_channels)
99 | for _ in range(num_layers)]
100 | )
101 |
102 | self.residual_input_conv = nn.ModuleList(
103 | [
104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
105 | for i in range(num_layers)
106 | ]
107 | )
108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
109 | 4, 2, 1) if self.down_sample else nn.Identity()
110 |
111 | def forward(self, x, t_emb=None, context=None):
112 | out = x
113 | for i in range(self.num_layers):
114 | # Resnet block of Unet
115 | resnet_input = out
116 | out = self.resnet_conv_first[i](out)
117 | if self.t_emb_dim is not None:
118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
119 | out = self.resnet_conv_second[i](out)
120 | out = out + self.residual_input_conv[i](resnet_input)
121 |
122 | if self.attn:
123 | # Attention block of Unet
124 | batch_size, channels, h, w = out.shape
125 | in_attn = out.reshape(batch_size, channels, h * w)
126 | in_attn = self.attention_norms[i](in_attn)
127 | in_attn = in_attn.transpose(1, 2)
128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
130 | out = out + out_attn
131 |
132 | if self.cross_attn:
133 | assert context is not None, "context cannot be None if cross attention layers are used"
134 | batch_size, channels, h, w = out.shape
135 | in_attn = out.reshape(batch_size, channels, h * w)
136 | in_attn = self.cross_attention_norms[i](in_attn)
137 | in_attn = in_attn.transpose(1, 2)
138 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
139 | context_proj = self.context_proj[i](context)
140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
142 | out = out + out_attn
143 |
144 | # Downsample
145 | out = self.down_sample_conv(out)
146 | return out
147 |
148 |
149 | class MidBlock(nn.Module):
150 | r"""
151 | Mid conv block with attention.
152 |     Sequence of the following blocks
153 | 1. Resnet block with time embedding
154 | 2. Attention block
155 | 3. Resnet block with time embedding
156 | """
157 |
158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None,
159 | context_dim=None):
160 | super().__init__()
161 | self.num_layers = num_layers
162 | self.t_emb_dim = t_emb_dim
163 | self.context_dim = context_dim
164 | self.cross_attn = cross_attn
165 | self.resnet_conv_first = nn.ModuleList(
166 | [
167 | nn.Sequential(
168 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
169 | nn.SiLU(),
170 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
171 | padding=1),
172 | )
173 | for i in range(num_layers + 1)
174 | ]
175 | )
176 |
177 | if self.t_emb_dim is not None:
178 | self.t_emb_layers = nn.ModuleList([
179 | nn.Sequential(
180 | nn.SiLU(),
181 | nn.Linear(t_emb_dim, out_channels)
182 | )
183 | for _ in range(num_layers + 1)
184 | ])
185 | self.resnet_conv_second = nn.ModuleList(
186 | [
187 | nn.Sequential(
188 | nn.GroupNorm(norm_channels, out_channels),
189 | nn.SiLU(),
190 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
191 | )
192 | for _ in range(num_layers + 1)
193 | ]
194 | )
195 |
196 | self.attention_norms = nn.ModuleList(
197 | [nn.GroupNorm(norm_channels, out_channels)
198 | for _ in range(num_layers)]
199 | )
200 |
201 | self.attentions = nn.ModuleList(
202 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
203 | for _ in range(num_layers)]
204 | )
205 | if self.cross_attn:
206 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
207 | self.cross_attention_norms = nn.ModuleList(
208 | [nn.GroupNorm(norm_channels, out_channels)
209 | for _ in range(num_layers)]
210 | )
211 | self.cross_attentions = nn.ModuleList(
212 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
213 | for _ in range(num_layers)]
214 | )
215 | self.context_proj = nn.ModuleList(
216 | [nn.Linear(context_dim, out_channels)
217 | for _ in range(num_layers)]
218 | )
219 | self.residual_input_conv = nn.ModuleList(
220 | [
221 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
222 | for i in range(num_layers + 1)
223 | ]
224 | )
225 |
226 | def forward(self, x, t_emb=None, context=None):
227 | out = x
228 |
229 | # First resnet block
230 | resnet_input = out
231 | out = self.resnet_conv_first[0](out)
232 | if self.t_emb_dim is not None:
233 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
234 | out = self.resnet_conv_second[0](out)
235 | out = out + self.residual_input_conv[0](resnet_input)
236 |
237 | for i in range(self.num_layers):
238 | # Attention Block
239 | batch_size, channels, h, w = out.shape
240 | in_attn = out.reshape(batch_size, channels, h * w)
241 | in_attn = self.attention_norms[i](in_attn)
242 | in_attn = in_attn.transpose(1, 2)
243 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
244 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
245 | out = out + out_attn
246 |
247 | if self.cross_attn:
248 | assert context is not None, "context cannot be None if cross attention layers are used"
249 | batch_size, channels, h, w = out.shape
250 | in_attn = out.reshape(batch_size, channels, h * w)
251 | in_attn = self.cross_attention_norms[i](in_attn)
252 | in_attn = in_attn.transpose(1, 2)
253 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
254 | context_proj = self.context_proj[i](context)
255 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
256 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
257 | out = out + out_attn
258 |
259 | # Resnet Block
260 | resnet_input = out
261 | out = self.resnet_conv_first[i + 1](out)
262 | if self.t_emb_dim is not None:
263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
264 | out = self.resnet_conv_second[i + 1](out)
265 | out = out + self.residual_input_conv[i + 1](resnet_input)
266 |
267 | return out
268 |
269 |
270 | class UpBlock(nn.Module):
271 | r"""
272 | Up conv block with attention.
273 |     Sequence of the following blocks
274 |     1. Upsample
275 |     2. Concatenate Down block output
276 |     3. Resnet block with time embedding
277 |     4. Attention Block
278 | """
279 |
280 | def __init__(self, in_channels, out_channels, t_emb_dim,
281 | up_sample, num_heads, num_layers, attn, norm_channels):
282 | super().__init__()
283 | self.num_layers = num_layers
284 | self.up_sample = up_sample
285 | self.t_emb_dim = t_emb_dim
286 | self.attn = attn
287 | self.resnet_conv_first = nn.ModuleList(
288 | [
289 | nn.Sequential(
290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
291 | nn.SiLU(),
292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
293 | padding=1),
294 | )
295 | for i in range(num_layers)
296 | ]
297 | )
298 |
299 | if self.t_emb_dim is not None:
300 | self.t_emb_layers = nn.ModuleList([
301 | nn.Sequential(
302 | nn.SiLU(),
303 | nn.Linear(t_emb_dim, out_channels)
304 | )
305 | for _ in range(num_layers)
306 | ])
307 |
308 | self.resnet_conv_second = nn.ModuleList(
309 | [
310 | nn.Sequential(
311 | nn.GroupNorm(norm_channels, out_channels),
312 | nn.SiLU(),
313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
314 | )
315 | for _ in range(num_layers)
316 | ]
317 | )
318 | if self.attn:
319 | self.attention_norms = nn.ModuleList(
320 | [
321 | nn.GroupNorm(norm_channels, out_channels)
322 | for _ in range(num_layers)
323 | ]
324 | )
325 |
326 | self.attentions = nn.ModuleList(
327 | [
328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
329 | for _ in range(num_layers)
330 | ]
331 | )
332 |
333 | self.residual_input_conv = nn.ModuleList(
334 | [
335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
336 | for i in range(num_layers)
337 | ]
338 | )
339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
340 | 4, 2, 1) \
341 | if self.up_sample else nn.Identity()
342 |
343 | def forward(self, x, out_down=None, t_emb=None):
344 | # Upsample
345 | x = self.up_sample_conv(x)
346 |
347 | # Concat with Downblock output
348 | if out_down is not None:
349 | x = torch.cat([x, out_down], dim=1)
350 |
351 | out = x
352 | for i in range(self.num_layers):
353 | # Resnet Block
354 | resnet_input = out
355 | out = self.resnet_conv_first[i](out)
356 | if self.t_emb_dim is not None:
357 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
358 | out = self.resnet_conv_second[i](out)
359 | out = out + self.residual_input_conv[i](resnet_input)
360 |
361 | # Self Attention
362 | if self.attn:
363 | batch_size, channels, h, w = out.shape
364 | in_attn = out.reshape(batch_size, channels, h * w)
365 | in_attn = self.attention_norms[i](in_attn)
366 | in_attn = in_attn.transpose(1, 2)
367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
369 | out = out + out_attn
370 | return out
371 |
372 |
373 | class UpBlockUnet(nn.Module):
374 | r"""
375 | Up conv block with attention.
376 |     Sequence of the following blocks
377 |     1. Upsample
378 |     2. Concatenate Down block output
379 |     3. Resnet block with time embedding
380 |     4. Attention Block
381 | """
382 |
383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
385 | super().__init__()
386 | self.num_layers = num_layers
387 | self.up_sample = up_sample
388 | self.t_emb_dim = t_emb_dim
389 | self.cross_attn = cross_attn
390 | self.context_dim = context_dim
391 | self.resnet_conv_first = nn.ModuleList(
392 | [
393 | nn.Sequential(
394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
395 | nn.SiLU(),
396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
397 | padding=1),
398 | )
399 | for i in range(num_layers)
400 | ]
401 | )
402 |
403 | if self.t_emb_dim is not None:
404 | self.t_emb_layers = nn.ModuleList([
405 | nn.Sequential(
406 | nn.SiLU(),
407 | nn.Linear(t_emb_dim, out_channels)
408 | )
409 | for _ in range(num_layers)
410 | ])
411 |
412 | self.resnet_conv_second = nn.ModuleList(
413 | [
414 | nn.Sequential(
415 | nn.GroupNorm(norm_channels, out_channels),
416 | nn.SiLU(),
417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
418 | )
419 | for _ in range(num_layers)
420 | ]
421 | )
422 |
423 | self.attention_norms = nn.ModuleList(
424 | [
425 | nn.GroupNorm(norm_channels, out_channels)
426 | for _ in range(num_layers)
427 | ]
428 | )
429 |
430 | self.attentions = nn.ModuleList(
431 | [
432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
433 | for _ in range(num_layers)
434 | ]
435 | )
436 |
437 | if self.cross_attn:
438 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
439 | self.cross_attention_norms = nn.ModuleList(
440 | [nn.GroupNorm(norm_channels, out_channels)
441 | for _ in range(num_layers)]
442 | )
443 | self.cross_attentions = nn.ModuleList(
444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
445 | for _ in range(num_layers)]
446 | )
447 | self.context_proj = nn.ModuleList(
448 | [nn.Linear(context_dim, out_channels)
449 | for _ in range(num_layers)]
450 | )
451 | self.residual_input_conv = nn.ModuleList(
452 | [
453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
454 | for i in range(num_layers)
455 | ]
456 | )
457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
458 | 4, 2, 1) \
459 | if self.up_sample else nn.Identity()
460 |
461 | def forward(self, x, out_down=None, t_emb=None, context=None):
462 | x = self.up_sample_conv(x)
463 | if out_down is not None:
464 | x = torch.cat([x, out_down], dim=1)
465 |
466 | out = x
467 | for i in range(self.num_layers):
468 | # Resnet
469 | resnet_input = out
470 | out = self.resnet_conv_first[i](out)
471 | if self.t_emb_dim is not None:
472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
473 | out = self.resnet_conv_second[i](out)
474 | out = out + self.residual_input_conv[i](resnet_input)
475 | # Self Attention
476 | batch_size, channels, h, w = out.shape
477 | in_attn = out.reshape(batch_size, channels, h * w)
478 | in_attn = self.attention_norms[i](in_attn)
479 | in_attn = in_attn.transpose(1, 2)
480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
482 | out = out + out_attn
483 | # Cross Attention
484 | if self.cross_attn:
485 | assert context is not None, "context cannot be None if cross attention layers are used"
486 | batch_size, channels, h, w = out.shape
487 | in_attn = out.reshape(batch_size, channels, h * w)
488 | in_attn = self.cross_attention_norms[i](in_attn)
489 | in_attn = in_attn.transpose(1, 2)
490 | assert len(context.shape) == 3, \
491 | "Context shape does not match B,_,CONTEXT_DIM"
492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
493 | "Context shape does not match B,_,CONTEXT_DIM"
494 | context_proj = self.context_proj[i](context)
495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
497 | out = out + out_attn
498 |
499 | return out
500 |
501 |
502 |
--------------------------------------------------------------------------------
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Discriminator(nn.Module):
6 | r"""
7 | PatchGAN Discriminator.
8 |     Rather than taking IMG_CHANNELS x IMG_H x IMG_W all the way to
9 |     a single scalar value, we instead predict a grid of values,
10 |     where each grid cell is a prediction of how likely
11 |     the discriminator thinks the image patch corresponding
12 |     to that cell is real.
13 | """
14 |
15 | def __init__(self, im_channels=3,
16 | conv_channels=[64, 128, 256],
17 | kernels=[4,4,4,4],
18 | strides=[2,2,2,1],
19 | paddings=[1,1,1,1]):
20 | super().__init__()
21 | self.im_channels = im_channels
22 | activation = nn.LeakyReLU(0.2)
23 | layers_dim = [self.im_channels] + conv_channels + [1]
24 | self.layers = nn.ModuleList([
25 | nn.Sequential(
26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1],
27 | kernel_size=kernels[i],
28 | stride=strides[i],
29 | padding=paddings[i],
30 | bias=False if i !=0 else True),
31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
32 | activation if i != len(layers_dim) - 2 else nn.Identity()
33 | )
34 | for i in range(len(layers_dim) - 1)
35 | ])
36 |
37 | def forward(self, x):
38 | out = x
39 | for layer in self.layers:
40 | out = layer(out)
41 | return out
42 |
43 |
44 | if __name__ == '__main__':
45 | x = torch.randn((2,3, 256, 256))
46 | prob = Discriminator(im_channels=3)(x)
47 | print(prob.shape)
48 |
--------------------------------------------------------------------------------
/model/lpips.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import namedtuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn
9 | import torchvision
10 |
11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 | def spatial_average(in_tens, keepdim=True):
19 | return in_tens.mean([2, 3], keepdim=keepdim)
20 |
21 |
22 | class vgg16(torch.nn.Module):
23 | def __init__(self, requires_grad=False, pretrained=True):
24 | super(vgg16, self).__init__()
25 | # Load pretrained vgg model from torchvision
26 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features
27 | self.slice1 = torch.nn.Sequential()
28 | self.slice2 = torch.nn.Sequential()
29 | self.slice3 = torch.nn.Sequential()
30 | self.slice4 = torch.nn.Sequential()
31 | self.slice5 = torch.nn.Sequential()
32 | self.N_slices = 5
33 | for x in range(4):
34 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
35 | for x in range(4, 9):
36 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
37 | for x in range(9, 16):
38 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
39 | for x in range(16, 23):
40 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
41 | for x in range(23, 30):
42 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
43 |
44 | # Freeze vgg model
45 | if not requires_grad:
46 | for param in self.parameters():
47 | param.requires_grad = False
48 |
49 | def forward(self, X):
50 | # Return output of vgg features
51 | h = self.slice1(X)
52 | h_relu1_2 = h
53 | h = self.slice2(h)
54 | h_relu2_2 = h
55 | h = self.slice3(h)
56 | h_relu3_3 = h
57 | h = self.slice4(h)
58 | h_relu4_3 = h
59 | h = self.slice5(h)
60 | h_relu5_3 = h
61 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
62 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
63 | return out
64 |
65 |
66 | # Learned perceptual metric
67 | class LPIPS(nn.Module):
68 | def __init__(self, net='vgg', version='0.1', use_dropout=True):
69 | super(LPIPS, self).__init__()
70 | self.version = version
71 | # Imagenet normalization
72 | self.scaling_layer = ScalingLayer()
73 | ########################
74 |
75 | # Instantiate vgg model
76 | self.chns = [64, 128, 256, 512, 512]
77 | self.L = len(self.chns)
78 | self.net = vgg16(pretrained=True, requires_grad=False)
79 |
80 | # Add 1x1 convolutional Layers
81 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
82 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
83 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
84 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
85 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
86 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
87 | self.lins = nn.ModuleList(self.lins)
88 | ########################
89 |
90 | # Load the weights of trained LPIPS model
91 | import inspect
92 | import os
93 | model_path = os.path.abspath(
94 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
95 | print('Loading model from: %s' % model_path)
96 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
97 | ########################
98 |
99 | # Freeze all parameters
100 | self.eval()
101 | for param in self.parameters():
102 | param.requires_grad = False
103 | ########################
104 |
105 | def forward(self, in0, in1, normalize=False):
106 | # Scale the inputs to -1 to +1 range if needed
107 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
108 | in0 = 2 * in0 - 1
109 | in1 = 2 * in1 - 1
110 | ########################
111 |
112 | # Normalize the inputs according to imagenet normalization
113 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
114 | ########################
115 |
116 | # Get VGG outputs for image0 and image1
117 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
118 | feats0, feats1, diffs = {}, {}, {}
119 | ########################
120 |
121 | # Compute Square of Difference for each layer output
122 | for kk in range(self.L):
123 |             feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize(
124 |                 outs1[kk], dim=1)
125 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
126 | ########################
127 |
128 | # 1x1 convolution followed by spatial average on the square differences
129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
130 | val = 0
131 |
132 | # Aggregate the results of each layer
133 | for l in range(self.L):
134 | val += res[l]
135 | return val
136 |
137 |
138 | class ScalingLayer(nn.Module):
139 | def __init__(self):
140 | super(ScalingLayer, self).__init__()
141 | # Imagnet normalization for (0-1)
142 | # mean = [0.485, 0.456, 0.406]
143 | # std = [0.229, 0.224, 0.225]
144 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
145 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
146 |
147 | def forward(self, inp):
148 | return (inp - self.shift) / self.scale
149 |
150 |
151 | class NetLinLayer(nn.Module):
152 | ''' A single linear layer which does a 1x1 conv '''
153 |
154 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
155 | super(NetLinLayer, self).__init__()
156 |
157 | layers = [nn.Dropout(), ] if (use_dropout) else []
158 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
159 | self.model = nn.Sequential(*layers)
160 |
161 | def forward(self, x):
162 | out = self.model(x)
163 | return out
164 |
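# --- Hedged usage sketch (not part of the original file) ---
# Computes the perceptual distance between two images in [-1, 1]. This assumes the
# LPIPS weights have been placed at model/weights/v0.1/vgg.pth as described in the
# README (the torchvision VGG16 weights are downloaded automatically).
if __name__ == '__main__':
    lpips_model = LPIPS().eval()
    img0 = torch.rand(1, 3, 128, 128) * 2 - 1
    img1 = torch.rand(1, 3, 128, 128) * 2 - 1
    with torch.no_grad():
        dist = lpips_model(img0, img1)
    print(dist.shape)   # (1, 1, 1, 1) per-image perceptual distance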
--------------------------------------------------------------------------------
/model/patch_embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 |
7 |
8 | def get_patch_position_embedding(pos_emb_dim, grid_size, device):
9 | assert pos_emb_dim % 4 == 0, 'Position embedding dimension must be divisible by 4'
10 | grid_size_h, grid_size_w = grid_size
11 | grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
12 | grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
13 | grid = torch.meshgrid(grid_h, grid_w, indexing='ij')
14 | grid = torch.stack(grid, dim=0)
15 |
16 | # grid_h_positions -> (Number of patch tokens,)
17 | grid_h_positions = grid[0].reshape(-1)
18 | grid_w_positions = grid[1].reshape(-1)
19 |
20 | # factor = 10000^(2i/d_model)
21 | factor = 10000 ** ((torch.arange(
22 | start=0,
23 | end=pos_emb_dim // 4,
24 | dtype=torch.float32,
25 | device=device) / (pos_emb_dim // 4))
26 | )
27 |
28 | grid_h_emb = grid_h_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
29 | grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1)
30 | # grid_h_emb -> (Number of patch tokens, pos_emb_dim // 2)
31 |
32 | grid_w_emb = grid_w_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
33 | grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1)
34 | pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
35 |
36 | # pos_emb -> (Number of patch tokens, pos_emb_dim)
37 | return pos_emb
38 |
39 |
40 | class PatchEmbedding(nn.Module):
41 | r"""
42 | Layer to take in the input image and do the following:
43 | 1. Transform grid of image patches into a sequence of patches.
44 | Number of patches are decided based on image height,width and
45 | patch height, width.
46 | 2. Add positional embedding to the above sequence
47 | """
48 |
49 | def __init__(self,
50 | image_height,
51 | image_width,
52 | im_channels,
53 | patch_height,
54 | patch_width,
55 | hidden_size):
56 | super().__init__()
57 | self.image_height = image_height
58 | self.image_width = image_width
59 | self.im_channels = im_channels
60 |
61 | self.hidden_size = hidden_size
62 |
63 | self.patch_height = patch_height
64 | self.patch_width = patch_width
65 |
66 | # Input dimension for Patch Embedding FC Layer
67 | patch_dim = self.im_channels * self.patch_height * self.patch_width
68 | self.patch_embed = nn.Sequential(
69 | nn.Linear(patch_dim, self.hidden_size)
70 | )
71 |
72 | ############################
73 | # DiT Layer Initialization #
74 | ############################
75 | nn.init.xavier_uniform_(self.patch_embed[0].weight)
76 | nn.init.constant_(self.patch_embed[0].bias, 0)
77 |
78 | def forward(self, x):
79 | grid_size_h = self.image_height // self.patch_height
80 | grid_size_w = self.image_width // self.patch_width
81 |
82 | # B, C, H, W -> B, (Patches along height * Patches along width), Patch Dimension
83 | # Number of tokens = Patches along height * Patches along width
84 | out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)',
85 | ph=self.patch_height,
86 | pw=self.patch_width)
87 |
88 | # BxNumber of tokens x Patch Dimension -> B x Number of tokens x Transformer Dimension
89 | out = self.patch_embed(out)
90 |
91 | # Add 2d sinusoidal position embeddings
92 | pos_embed = get_patch_position_embedding(pos_emb_dim=self.hidden_size,
93 | grid_size=(grid_size_h, grid_size_w),
94 | device=x.device)
95 | out += pos_embed
96 | return out
97 |
98 |
--------------------------------------------------------------------------------
/model/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.patch_embed import PatchEmbedding
4 | from model.transformer_layer import TransformerLayer
5 | from einops import rearrange
6 |
7 |
8 | def get_time_embedding(time_steps, temb_dim):
9 | r"""
10 | Convert time steps tensor into an embedding using the
11 | sinusoidal time embedding formula
12 | :param time_steps: 1D tensor of length batch size
13 | :param temb_dim: Dimension of the embedding
14 | :return: BxD embedding representation of B time steps
15 | """
16 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
17 |
18 | # factor = 10000^(2i/d_model)
19 | factor = 10000 ** ((torch.arange(
20 | start=0,
21 | end=temb_dim // 2,
22 | dtype=torch.float32,
23 | device=time_steps.device) / (temb_dim // 2))
24 | )
25 |
26 | # pos / factor
27 | # timesteps B -> B, 1 -> B, temb_dim
28 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
29 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
30 | return t_emb
31 |
32 |
33 | class DIT(nn.Module):
34 | def __init__(self, im_size, im_channels, config):
35 | super().__init__()
36 |
37 | num_layers = config['num_layers']
38 | self.image_height = im_size
39 | self.image_width = im_size
40 | self.im_channels = im_channels
41 | self.hidden_size = config['hidden_size']
42 | self.patch_height = config['patch_size']
43 | self.patch_width = config['patch_size']
44 |
45 | self.timestep_emb_dim = config['timestep_emb_dim']
46 |
47 | # Number of patches along height and width
48 | self.nh = self.image_height // self.patch_height
49 | self.nw = self.image_width // self.patch_width
50 |
51 | # Patch Embedding Block
52 | self.patch_embed_layer = PatchEmbedding(image_height=self.image_height,
53 | image_width=self.image_width,
54 | im_channels=self.im_channels,
55 | patch_height=self.patch_height,
56 | patch_width=self.patch_width,
57 | hidden_size=self.hidden_size)
58 |
59 | # Initial projection from sinusoidal time embedding
60 | self.t_proj = nn.Sequential(
61 | nn.Linear(self.timestep_emb_dim, self.hidden_size),
62 | nn.SiLU(),
63 | nn.Linear(self.hidden_size, self.hidden_size)
64 | )
65 |
66 | # All Transformer Layers
67 | self.layers = nn.ModuleList([
68 | TransformerLayer(config) for _ in range(num_layers)
69 | ])
70 |
71 | # Final normalization for unpatchify block
72 | self.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6)
73 |
74 | # Scale and Shift parameters for the norm
75 | self.adaptive_norm_layer = nn.Sequential(
76 | nn.SiLU(),
77 | nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=True)
78 | )
79 |
80 | # Final Linear Layer
81 | self.proj_out = nn.Linear(self.hidden_size,
82 | self.patch_height * self.patch_width * self.im_channels)
83 |
84 | ############################
85 | # DiT Layer Initialization #
86 | ############################
87 | nn.init.normal_(self.t_proj[0].weight, std=0.02)
88 | nn.init.normal_(self.t_proj[2].weight, std=0.02)
89 |
90 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
91 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)
92 |
93 | nn.init.constant_(self.proj_out.weight, 0)
94 | nn.init.constant_(self.proj_out.bias, 0)
95 |
96 | def forward(self, x, t):
97 | # Patchify
98 | out = self.patch_embed_layer(x)
99 |
100 | # Compute Timestep representation
101 | # t_emb -> (Batch, timestep_emb_dim)
102 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.timestep_emb_dim)
103 | # (Batch, timestep_emb_dim) -> (Batch, hidden_size)
104 | t_emb = self.t_proj(t_emb)
105 |
106 | # Go through the transformer layers
107 | for layer in self.layers:
108 | out = layer(out, t_emb)
109 |
110 | # Shift and scale predictions for output normalization
111 | pre_mlp_shift, pre_mlp_scale = self.adaptive_norm_layer(t_emb).chunk(2, dim=1)
112 | out = (self.norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) +
113 | pre_mlp_shift.unsqueeze(1))
114 |
115 | # Unpatchify
116 | # (B,patches,hidden_size) -> (B,patches,channels * patch_width * patch_height)
117 | out = self.proj_out(out)
118 | out = rearrange(out, 'b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)',
119 | ph=self.patch_height,
120 | pw=self.patch_width,
121 | nw=self.nw,
122 | nh=self.nh)
123 | return out
124 |
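# --- Hedged usage sketch (not part of the original file) ---
# A minimal shape check of the DIT block with the dit_params from config/celebhq.yaml,
# assuming the 32x32x4 latents produced by the VAE. The config values below are copied
# from that yaml; the random inputs are illustrative only.
if __name__ == '__main__':
    dit_config = {
        'patch_size': 2,
        'num_layers': 12,
        'hidden_size': 768,
        'num_heads': 12,
        'head_dim': 64,
        'timestep_emb_dim': 768,
    }
    model = DIT(im_size=32, im_channels=4, config=dit_config)
    x = torch.randn(2, 4, 32, 32)       # batch of latents
    t = torch.randint(0, 1000, (2,))    # random timesteps
    out = model(x, t)
    print(out.shape)                    # torch.Size([2, 4, 32, 32])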
--------------------------------------------------------------------------------
/model/transformer_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from model.attention import Attention
3 |
4 |
5 | class TransformerLayer(nn.Module):
6 | r"""
7 | Transformer block which is just doing the following based on VIT
8 | 1. LayerNorm followed by Attention
9 | 2. LayerNorm followed by Feed forward Block
10 | Both these also have residuals added to them
11 |
12 | For DiT we additionally have
13 | 1. Layernorm mlp to predict layernorm affine parameters from
14 | 2. Same Layernorm mlp to also predict scale parameters for outputs
15 | of both mlp/attention prior to residual connection.
16 | """
17 |
18 | def __init__(self, config):
19 | super().__init__()
20 | self.hidden_size = config['hidden_size']
21 |
22 | ff_hidden_dim = 4 * self.hidden_size
23 |
24 | # Layer norm for attention block
25 | self.att_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6)
26 |
27 | self.attn_block = Attention(config)
28 |
29 | # Layer norm for mlp block
30 | self.ff_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6)
31 |
32 | self.mlp_block = nn.Sequential(
33 | nn.Linear(self.hidden_size, ff_hidden_dim),
34 | nn.GELU(approximate='tanh'),
35 | nn.Linear(ff_hidden_dim, self.hidden_size),
36 | )
37 |
38 | # Scale Shift Parameter predictions for this layer
39 | # 1. Scale and shift parameters for layernorm of attention (2 * hidden_size)
40 | # 2. Scale and shift parameters for layernorm of mlp (2 * hidden_size)
41 | # 3. Scale for output of attention prior to residual connection (hidden_size)
42 | # 4. Scale for output of mlp prior to residual connection (hidden_size)
43 | # Total 6 * hidden_size
44 | self.adaptive_norm_layer = nn.Sequential(
45 | nn.SiLU(),
46 | nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)
47 | )
48 |
49 | ############################
50 | # DiT Layer Initialization #
51 | ############################
52 | nn.init.xavier_uniform_(self.mlp_block[0].weight)
53 | nn.init.constant_(self.mlp_block[0].bias, 0)
54 | nn.init.xavier_uniform_(self.mlp_block[-1].weight)
55 | nn.init.constant_(self.mlp_block[-1].bias, 0)
56 |
57 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
58 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)
59 |
60 | def forward(self, x, condition):
61 | scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1)
62 | (pre_attn_shift, pre_attn_scale, post_attn_scale,
63 | pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = scale_shift_params
64 | out = x
65 | attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1))
66 | + pre_attn_shift.unsqueeze(1))
67 | out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output)
68 | mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) +
69 | pre_mlp_shift.unsqueeze(1))
70 | out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output)
71 | return out
72 |
--------------------------------------------------------------------------------
/model/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.blocks import DownBlock, MidBlock, UpBlock
4 |
5 |
6 | class VAE(nn.Module):
7 | def __init__(self, im_channels, model_config):
8 | super().__init__()
9 | self.down_channels = model_config['down_channels']
10 | self.mid_channels = model_config['mid_channels']
11 | self.down_sample = model_config['down_sample']
12 | self.num_down_layers = model_config['num_down_layers']
13 | self.num_mid_layers = model_config['num_mid_layers']
14 | self.num_up_layers = model_config['num_up_layers']
15 |
16 | # To disable attention in Downblock of Encoder and Upblock of Decoder
17 | self.attns = model_config['attn_down']
18 |
19 | # Latent Dimension
20 | self.z_channels = model_config['z_channels']
21 | self.norm_channels = model_config['norm_channels']
22 | self.num_heads = model_config['num_heads']
23 |
24 | # Assertion to validate the channel information
25 | assert self.mid_channels[0] == self.down_channels[-1]
26 | assert self.mid_channels[-1] == self.down_channels[-1]
27 | assert len(self.down_sample) == len(self.down_channels) - 1
28 | assert len(self.attns) == len(self.down_channels) - 1
29 |
30 | # Wherever we use downsampling in encoder correspondingly use
31 | # upsampling in decoder
32 | self.up_sample = list(reversed(self.down_sample))
33 |
34 | ##################### Encoder ######################
35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
36 |
37 | # Downblock + Midblock
38 | self.encoder_layers = nn.ModuleList([])
39 | for i in range(len(self.down_channels) - 1):
40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
41 | t_emb_dim=None, down_sample=self.down_sample[i],
42 | num_heads=self.num_heads,
43 | num_layers=self.num_down_layers,
44 | attn=self.attns[i],
45 | norm_channels=self.norm_channels))
46 |
47 | self.encoder_mids = nn.ModuleList([])
48 | for i in range(len(self.mid_channels) - 1):
49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
50 | t_emb_dim=None,
51 | num_heads=self.num_heads,
52 | num_layers=self.num_mid_layers,
53 | norm_channels=self.norm_channels))
54 |
55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1)
57 |
58 | # Latent Dimension is 2*Latent because we are predicting mean & variance
59 | self.pre_quant_conv = nn.Conv2d(2 * self.z_channels, 2 * self.z_channels, kernel_size=1)
60 | ####################################################
61 |
62 | ##################### Decoder ######################
63 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
64 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
65 |
66 | # Midblock + Upblock
67 | self.decoder_mids = nn.ModuleList([])
68 | for i in reversed(range(1, len(self.mid_channels))):
69 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
70 | t_emb_dim=None,
71 | num_heads=self.num_heads,
72 | num_layers=self.num_mid_layers,
73 | norm_channels=self.norm_channels))
74 |
75 | self.decoder_layers = nn.ModuleList([])
76 | for i in reversed(range(1, len(self.down_channels))):
77 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
78 | t_emb_dim=None, up_sample=self.down_sample[i - 1],
79 | num_heads=self.num_heads,
80 | num_layers=self.num_up_layers,
81 | attn=self.attns[i - 1],
82 | norm_channels=self.norm_channels))
83 |
84 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
85 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
86 |
87 | def encode(self, x):
88 | out = self.encoder_conv_in(x)
89 | for idx, down in enumerate(self.encoder_layers):
90 | out = down(out)
91 | for mid in self.encoder_mids:
92 | out = mid(out)
93 | out = self.encoder_norm_out(out)
94 | out = nn.SiLU()(out)
95 | out = self.encoder_conv_out(out)
96 | out = self.pre_quant_conv(out)
97 | mean, logvar = torch.chunk(out, 2, dim=1)
98 | std = torch.exp(0.5 * logvar)
99 | sample = mean + std * torch.randn(mean.shape).to(device=x.device)
100 | return sample, out
101 |
102 | def decode(self, z):
103 | out = z
104 | out = self.post_quant_conv(out)
105 | out = self.decoder_conv_in(out)
106 | for mid in self.decoder_mids:
107 | out = mid(out)
108 | for idx, up in enumerate(self.decoder_layers):
109 | out = up(out)
110 |
111 | out = self.decoder_norm_out(out)
112 | out = nn.SiLU()(out)
113 | out = self.decoder_conv_out(out)
114 | return out
115 |
116 | def forward(self, x):
117 | z, encoder_output = self.encode(x)
118 | out = self.decode(z)
119 | return out, encoder_output
120 |
121 |
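# --- Hedged usage sketch (not part of the original file) ---
# A minimal shape check using the autoencoder_params from config/celebhq.yaml (copied
# below); with two downsampling stages a 3x128x128 image maps to a 4x32x32 latent,
# matching the "128x128 to 32x32x4 latents" note in the README.
if __name__ == '__main__':
    model_config = {
        'z_channels': 4,
        'down_channels': [128, 256, 384],
        'mid_channels': [384],
        'down_sample': [True, True],
        'attn_down': [False, False],
        'norm_channels': 32,
        'num_heads': 4,
        'num_down_layers': 2,
        'num_mid_layers': 2,
        'num_up_layers': 2,
    }
    vae = VAE(im_channels=3, model_config=model_config)
    x = torch.randn(1, 3, 128, 128)
    z, encoder_output = vae.encode(x)
    print(z.shape)                 # torch.Size([1, 4, 32, 32])
    print(vae.decode(z).shape)     # torch.Size([1, 3, 128, 128])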
--------------------------------------------------------------------------------
/model/weights/v0.1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/model/weights/v0.1/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.8.0
2 | numpy==2.0.1
3 | opencv_python==4.10.0.84
4 | Pillow==10.4.0
5 | PyYAML==6.0.1
6 | torch==2.3.1
7 | torchvision==0.18.1
8 | tqdm==4.66.4
9 |
--------------------------------------------------------------------------------
/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/scheduler/__init__.py
--------------------------------------------------------------------------------
/scheduler/linear_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LinearNoiseScheduler:
5 | r"""
6 | Class for the linear noise scheduler that is used in DDPM.
7 | """
8 |
9 | def __init__(self, num_timesteps, beta_start, beta_end):
10 | self.num_timesteps = num_timesteps
11 | self.beta_start = beta_start
12 | self.beta_end = beta_end
13 |
14 | self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
15 | self.alphas = 1. - self.betas
16 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
17 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
18 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
19 |
20 | def add_noise(self, original, noise, t):
21 | r"""
22 | Forward method for diffusion
23 | :param original: Image on which noise is to be applied
24 | :param noise: Random Noise Tensor (from normal dist)
25 | :param t: timestep of the forward process of shape -> (B,)
26 | :return:
27 | """
28 | original_shape = original.shape
29 | batch_size = original_shape[0]
30 |
31 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
32 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
33 |
34 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
35 | for _ in range(len(original_shape) - 1):
36 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
37 | for _ in range(len(original_shape) - 1):
38 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
39 |
40 | # Apply and Return Forward process equation
41 | return (sqrt_alpha_cum_prod.to(original.device) * original
42 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
43 |
44 | def sample_prev_timestep(self, xt, pred, t):
45 | r"""
46 |     Use the model's noise prediction to compute
47 |     x_{t-1} from x_t
48 | :param xt: current timestep sample
49 | :param pred: model noise prediction
50 | :param t: current timestep we are at
51 | :return:
52 | """
53 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * pred)) /
54 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
55 | x0 = torch.clamp(x0, -1., 1.)
56 |
57 | mean = xt - ((self.betas.to(xt.device)[t]) * pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
58 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
59 |
60 | if t == 0:
61 | return mean, x0
62 | else:
63 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
64 | variance = variance * self.betas.to(xt.device)[t]
65 | sigma = variance ** 0.5
66 | z = torch.randn(xt.shape).to(xt.device)
67 | return mean + sigma * z, x0
68 |
--------------------------------------------------------------------------------
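
A minimal usage sketch of `LinearNoiseScheduler` on random latents. The timestep count and beta values below are placeholders; the actual values come from `diffusion_params` in `config/celebhq.yaml`.

```python
import torch
from scheduler.linear_scheduler import LinearNoiseScheduler

# Placeholder schedule parameters (the real ones live in config/celebhq.yaml)
scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=1e-4, beta_end=0.02)

x0 = torch.randn(4, 4, 32, 32)      # clean latents, (B, z_channels, H, W) assumed
noise = torch.randn_like(x0)        # epsilon ~ N(0, I)
t = torch.randint(0, 1000, (4,))    # one timestep per sample

# Forward process: xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
xt = scheduler.add_noise(x0, noise, t)

# The reverse step expects a single timestep and the model's noise prediction;
# the true noise stands in for that prediction here just to exercise the API.
xt_prev, x0_pred = scheduler.sample_prev_timestep(xt[:1], noise[:1], torch.as_tensor(999))
print(xt.shape, xt_prev.shape, x0_pred.shape)
```
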
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/tools/__init__.py
--------------------------------------------------------------------------------
/tools/infer_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import pickle
5 |
6 | import torch
7 | import torchvision
8 | import yaml
9 | from torch.utils.data.dataloader import DataLoader
10 | from torchvision.utils import make_grid
11 | from tqdm import tqdm
12 |
13 | from dataset.celeb_dataset import CelebDataset
14 | from model.vae import VAE
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | if torch.backends.mps.is_available():
18 | device = torch.device('mps')
19 | print('Using mps')
20 |
21 |
22 | def infer(args):
23 | ######## Read the config file #######
24 | with open(args.config_path, 'r') as file:
25 | try:
26 | config = yaml.safe_load(file)
27 | except yaml.YAMLError as exc:
28 | print(exc)
29 | print(config)
30 |
31 | dataset_config = config['dataset_params']
32 | autoencoder_config = config['autoencoder_params']
33 | train_config = config['train_params']
34 |
35 | im_dataset = CelebDataset(split='train',
36 | im_path=dataset_config['im_path'],
37 | im_size=dataset_config['im_size'],
38 | im_channels=dataset_config['im_channels'])
39 |
40 |     # This data loader is only used for saving latents, which as of now
41 |     # is not done in batches, hence batch size 1
42 | data_loader = DataLoader(im_dataset,
43 | batch_size=1,
44 | shuffle=False)
45 |
46 | num_images = train_config['num_samples']
47 | ngrid = train_config['num_grid_rows']
48 |
49 |     idxs = torch.randint(0, len(im_dataset), (num_images,))  # upper bound of randint is exclusive
50 | ims = torch.cat([im_dataset[idx][None, :] for idx in idxs]).float()
51 | ims = ims.to(device)
52 |
53 | model = VAE(im_channels=dataset_config['im_channels'],
54 | model_config=autoencoder_config).to(device)
55 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
56 | train_config['vae_autoencoder_ckpt_name']),
57 | map_location=device))
58 | model.eval()
59 |
60 | with torch.no_grad():
61 |
62 | encoded_output, _ = model.encode(ims)
63 | decoded_output = model.decode(encoded_output)
64 | encoded_output = torch.clamp(encoded_output, -1., 1.)
65 | encoded_output = (encoded_output + 1) / 2
66 | decoded_output = torch.clamp(decoded_output, -1., 1.)
67 | decoded_output = (decoded_output + 1) / 2
68 | ims = (ims + 1) / 2
69 |
70 | encoder_grid = make_grid(encoded_output.cpu(), nrow=ngrid)
71 | decoder_grid = make_grid(decoded_output.cpu(), nrow=ngrid)
72 | input_grid = make_grid(ims.cpu(), nrow=ngrid)
73 | encoder_grid = torchvision.transforms.ToPILImage()(encoder_grid)
74 | decoder_grid = torchvision.transforms.ToPILImage()(decoder_grid)
75 | input_grid = torchvision.transforms.ToPILImage()(input_grid)
76 |
77 | input_grid.save(os.path.join(train_config['task_name'], 'input_samples.png'))
78 | encoder_grid.save(os.path.join(train_config['task_name'], 'encoded_samples.png'))
79 | decoder_grid.save(os.path.join(train_config['task_name'], 'reconstructed_samples.png'))
80 |
81 | if train_config['save_latents']:
82 |             # Save latents (currently one image at a time, so not optimized for speed)
83 | latent_path = os.path.join(train_config['task_name'], train_config['vae_latent_dir_name'])
84 | latent_fnames = glob.glob(os.path.join(train_config['task_name'], train_config['vae_latent_dir_name'],
85 | '*.pkl'))
86 | assert len(latent_fnames) == 0, 'Latents already present. Delete all latent files and re-run'
87 | if not os.path.exists(latent_path):
88 | os.mkdir(latent_path)
89 | print('Saving Latents for {}'.format(dataset_config['name']))
90 |
91 | fname_latent_map = {}
92 | part_count = 0
93 | count = 0
94 | for idx, im in enumerate(tqdm(data_loader)):
95 | _, encoded_output = model.encode(im.float().to(device))
96 | fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu()
97 | # Save latents every 1000 images
98 | if (count + 1) % 1000 == 0:
99 | pickle.dump(fname_latent_map, open(os.path.join(latent_path,
100 | '{}.pkl'.format(part_count)), 'wb'))
101 | part_count += 1
102 | fname_latent_map = {}
103 | count += 1
104 | if len(fname_latent_map) > 0:
105 | pickle.dump(fname_latent_map, open(os.path.join(latent_path,
106 | '{}.pkl'.format(part_count)), 'wb'))
107 | print('Done saving latents')
108 |
109 |
110 | if __name__ == '__main__':
111 | parser = argparse.ArgumentParser(description='Arguments for vae inference')
112 | parser.add_argument('--config', dest='config_path',
113 | default='config/celebhq.yaml', type=str)
114 | args = parser.parse_args()
115 | infer(args)
116 |
--------------------------------------------------------------------------------
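
When `save_latents` is enabled, each `{part_count}.pkl` written above holds a dict mapping an image path to the full encoder output (mean and log-variance concatenated along channels) with its batch dimension still attached. Below is a sketch of reading one such file back; the path and shapes are assumptions based on the CelebHQ setup.

```python
import pickle

# Hypothetical path: the real files land in <task_name>/<vae_latent_dir_name>/0.pkl, 1.pkl, ...
with open('celebhq/vae_latents/0.pkl', 'rb') as f:
    fname_latent_map = pickle.load(f)

for image_path, encoder_output in list(fname_latent_map.items())[:3]:
    # encoder_output stacks mean and logvar along channels,
    # so its shape is (1, 2 * z_channels, H, W), e.g. (1, 8, 32, 32) here.
    print(image_path, tuple(encoder_output.shape))
```
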
/tools/sample_vae_dit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | from torchvision.utils import make_grid
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from model.vae import VAE
10 | from model.transformer import DIT
11 | from scheduler.linear_scheduler import LinearNoiseScheduler
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 |
19 | def sample(model, scheduler, train_config, dit_model_config,
20 | autoencoder_model_config, diffusion_config, dataset_config, vae):
21 | r"""
22 | Sample stepwise by going backward one timestep at a time.
23 |     We save the intermediate samples and decode only the final one through the VAE
24 | """
25 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
26 | xt = torch.randn((train_config['num_samples'],
27 | autoencoder_model_config['z_channels'],
28 | im_size,
29 | im_size)).to(device)
30 |
31 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
32 | # Get prediction of noise
33 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
34 |
35 | # Use scheduler to get x0 and xt-1
36 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
37 |
38 |         # Save the current sample (decoded through the VAE only at the final step)
39 | # ims = torch.clamp(xt, -1., 1.).detach().cpu()
40 | if i == 0:
41 | # Decode ONLY the final image to save time
42 | ims = vae.to(device).decode(xt)
43 | else:
44 |             # Keep all but the last latent channel so the remaining channels can be saved as an image
45 |             ims = xt[:, :-1, :, :]
46 |
47 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
48 | ims = (ims + 1) / 2
49 |
50 | grid = make_grid(ims, nrow=train_config['num_grid_rows'])
51 | img = torchvision.transforms.ToPILImage()(grid)
52 |
53 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
54 | os.mkdir(os.path.join(train_config['task_name'], 'samples'))
55 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
56 | img.close()
57 |
58 |
59 | def infer(args):
60 | # Read the config file #
61 | with open(args.config_path, 'r') as file:
62 | try:
63 | config = yaml.safe_load(file)
64 | except yaml.YAMLError as exc:
65 | print(exc)
66 | print(config)
67 | ########################
68 |
69 | diffusion_config = config['diffusion_params']
70 | dataset_config = config['dataset_params']
71 | dit_model_config = config['dit_params']
72 | autoencoder_model_config = config['autoencoder_params']
73 | train_config = config['train_params']
74 |
75 | # Create the noise scheduler
76 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
77 | beta_start=diffusion_config['beta_start'],
78 | beta_end=diffusion_config['beta_end'])
79 |
80 | # Get latent image size
81 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
82 | model = DIT(im_size=im_size,
83 | im_channels=autoencoder_model_config['z_channels'],
84 | config=dit_model_config).to(device)
85 |
86 | model.eval()
87 |
88 | assert os.path.exists(os.path.join(train_config['task_name'],
89 | train_config['dit_ckpt_name'])), "Train DiT first"
90 |
91 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
92 | train_config['dit_ckpt_name']),
93 | map_location=device))
94 | print('Loaded dit checkpoint')
95 |
96 | # Create output directories
97 | if not os.path.exists(train_config['task_name']):
98 | os.mkdir(train_config['task_name'])
99 |
100 | vae = VAE(im_channels=dataset_config['im_channels'],
101 | model_config=autoencoder_model_config)
102 | vae.eval()
103 |
104 | # Load vae if found
105 | assert os.path.exists(os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])), \
106 | "VAE checkpoint not present. Train VAE first."
107 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
108 | train_config['vae_autoencoder_ckpt_name']),
109 | map_location=device), strict=True)
110 | print('Loaded vae checkpoint')
111 |
112 | with torch.no_grad():
113 | sample(model, scheduler, train_config, dit_model_config,
114 | autoencoder_model_config, diffusion_config, dataset_config, vae)
115 |
116 |
117 | if __name__ == '__main__':
118 | parser = argparse.ArgumentParser(description='Arguments for dit image generation')
119 | parser.add_argument('--config', dest='config_path',
120 | default='config/celebhq.yaml', type=str)
121 | args = parser.parse_args()
122 | infer(args)
123 |
--------------------------------------------------------------------------------
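
The latent spatial size that `sample` starts from is derived from the image size and the number of active downsampling stages in the autoencoder config. Below is a quick check of that arithmetic, assuming a 128x128 input and two active downsamples (consistent with the 32x32x4 latents the README mentions); the actual flags come from `autoencoder_params['down_sample']`.

```python
# Assumed values for illustration only; the real ones come from config/celebhq.yaml
im_size_px = 128
down_sample = [True, True, False]  # hypothetical flags: two active downsample stages
z_channels = 4
num_samples = 4

latent_size = im_size_px // 2 ** sum(down_sample)  # 128 // 4 = 32
print(latent_size)  # 32 -> xt is initialized with shape (num_samples, z_channels, 32, 32)
```
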
/tools/train_vae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import random
5 | import torchvision
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | from model.vae import VAE
10 | from model.lpips import LPIPS
11 | from model.discriminator import Discriminator
12 | from torch.utils.data.dataloader import DataLoader
13 | from dataset.celeb_dataset import CelebDataset
14 | from torch.optim import Adam
15 | from torchvision.utils import make_grid
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 | if torch.backends.mps.is_available():
19 | device = torch.device('mps')
20 | print('Using mps')
21 |
22 |
23 | def train(args):
24 | # Read the config file #
25 | with open(args.config_path, 'r') as file:
26 | try:
27 | config = yaml.safe_load(file)
28 | except yaml.YAMLError as exc:
29 | print(exc)
30 | print(config)
31 |
32 | dataset_config = config['dataset_params']
33 | autoencoder_config = config['autoencoder_params']
34 | train_config = config['train_params']
35 |
36 | # Set the desired seed value #
37 | seed = train_config['seed']
38 | torch.manual_seed(seed)
39 | np.random.seed(seed)
40 | random.seed(seed)
41 |     if device.type == 'cuda':
42 | torch.cuda.manual_seed_all(seed)
43 | #############################
44 |
45 | # Create the model and dataset #
46 | model = VAE(im_channels=dataset_config['im_channels'],
47 | model_config=autoencoder_config).to(device)
48 | # Create the dataset
49 | im_dataset = CelebDataset(split='train',
50 | im_path=dataset_config['im_path'],
51 | im_size=dataset_config['im_size'],
52 | im_channels=dataset_config['im_channels'])
53 |
54 | data_loader = DataLoader(im_dataset,
55 | batch_size=train_config['autoencoder_batch_size'],
56 | shuffle=True)
57 |
58 | # Create output directories
59 | if not os.path.exists(train_config['task_name']):
60 | os.mkdir(train_config['task_name'])
61 |
62 | num_epochs = train_config['autoencoder_epochs']
63 |
64 | # L1/L2 loss for Reconstruction
65 | recon_criterion = torch.nn.MSELoss()
66 | # Disc Loss can even be BCEWithLogits
67 | disc_criterion = torch.nn.MSELoss()
68 |
69 | # No need to freeze lpips as lpips.py takes care of that
70 | lpips_model = LPIPS().eval().to(device)
71 | discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device)
72 |
73 | if os.path.exists(os.path.join(train_config['task_name'],
74 | train_config['vae_autoencoder_ckpt_name'])):
75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
76 | train_config['vae_autoencoder_ckpt_name']),
77 | map_location=device))
78 | print('Loaded autoencoder from checkpoint')
79 |
80 | if os.path.exists(os.path.join(train_config['task_name'],
81 | train_config['vae_discriminator_ckpt_name'])):
82 | discriminator.load_state_dict(torch.load(os.path.join(train_config['task_name'],
83 | train_config['vae_discriminator_ckpt_name']),
84 | map_location=device))
85 | print('Loaded discriminator from checkpoint')
86 |
87 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
88 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
89 |
90 | disc_step_start = train_config['disc_start']
91 | step_count = 0
92 |
93 |     # This is for accumulating gradients in case the images are huge
94 |     # and one can't afford higher batch sizes
95 | acc_steps = train_config['autoencoder_acc_steps']
96 | image_save_steps = train_config['autoencoder_img_save_steps']
97 | img_save_count = 0
98 |
99 | for epoch_idx in range(num_epochs):
100 | recon_losses = []
101 | perceptual_losses = []
102 | disc_losses = []
103 | gen_losses = []
104 | losses = []
105 |
106 | optimizer_g.zero_grad()
107 | optimizer_d.zero_grad()
108 |
109 | for im in tqdm(data_loader):
110 | step_count += 1
111 | im = im.float().to(device)
112 |
113 | # Fetch autoencoders output(reconstructions)
114 | model_output = model(im)
115 | output, encoder_output = model_output
116 |
117 | # Image Saving Logic
118 | if step_count % image_save_steps == 0 or step_count == 1:
119 | sample_size = min(8, im.shape[0])
120 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
121 | save_output = ((save_output + 1) / 2)
122 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
123 |
124 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
125 | img = torchvision.transforms.ToPILImage()(grid)
126 | if not os.path.exists(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')):
127 | os.mkdir(os.path.join(train_config['task_name'], 'vae_autoencoder_samples'))
128 | img.save(os.path.join(train_config['task_name'], 'vae_autoencoder_samples',
129 | 'current_autoencoder_sample_{}.png'.format(img_save_count)))
130 | img_save_count += 1
131 | img.close()
132 |
133 | ######### Optimize Generator ##########
134 | # L2 Loss
135 | recon_loss = recon_criterion(output, im)
136 | recon_losses.append(recon_loss.item())
137 | recon_loss = recon_loss / acc_steps
138 |
139 | mean, logvar = torch.chunk(encoder_output, 2, dim=1)
140 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3]))
141 |
142 | g_loss = recon_loss + (train_config['kl_weight'] * kl_loss / acc_steps)
143 |
144 | # Adversarial loss only if disc_step_start steps passed
145 | if step_count > disc_step_start:
146 | disc_fake_pred = discriminator(model_output[0])
147 | disc_fake_loss = disc_criterion(disc_fake_pred,
148 | torch.ones(disc_fake_pred.shape,
149 | device=disc_fake_pred.device))
150 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item())
151 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps
152 | lpips_loss = torch.mean(lpips_model(output, im))
153 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item())
154 | g_loss += train_config['perceptual_weight'] * lpips_loss / acc_steps
155 | losses.append(g_loss.item())
156 | g_loss.backward()
157 | #####################################
158 |
159 | ######### Optimize Discriminator #######
160 | if step_count > disc_step_start:
161 | fake = output
162 | disc_fake_pred = discriminator(fake.detach())
163 | disc_real_pred = discriminator(im)
164 | disc_fake_loss = disc_criterion(disc_fake_pred,
165 | torch.zeros(disc_fake_pred.shape,
166 | device=disc_fake_pred.device))
167 | disc_real_loss = disc_criterion(disc_real_pred,
168 | torch.ones(disc_real_pred.shape,
169 | device=disc_real_pred.device))
170 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2
171 | disc_losses.append(disc_loss.item())
172 | disc_loss = disc_loss / acc_steps
173 | disc_loss.backward()
174 | if step_count % acc_steps == 0:
175 | optimizer_d.step()
176 | optimizer_d.zero_grad()
177 | #####################################
178 |
179 | if step_count % acc_steps == 0:
180 | optimizer_g.step()
181 | optimizer_g.zero_grad()
182 | optimizer_d.step()
183 | optimizer_d.zero_grad()
184 | optimizer_g.step()
185 | optimizer_g.zero_grad()
186 | if len(disc_losses) > 0:
187 | print(
188 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
189 | 'G Loss : {:.4f} | D Loss {:.4f}'.
190 | format(epoch_idx + 1,
191 | np.mean(recon_losses),
192 | np.mean(perceptual_losses),
193 | np.mean(gen_losses),
194 | np.mean(disc_losses)))
195 | else:
196 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}'.
197 | format(epoch_idx + 1,
198 | np.mean(recon_losses),
199 | np.mean(perceptual_losses)))
200 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
201 | train_config['vae_autoencoder_ckpt_name']))
202 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'],
203 | train_config['vae_discriminator_ckpt_name']))
204 | print('Done Training...')
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser(description='Arguments for vae training')
209 | parser.add_argument('--config', dest='config_path',
210 | default='config/celebhq.yaml', type=str)
211 | args = parser.parse_args()
212 | train(args)
213 |
--------------------------------------------------------------------------------
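
The `kl_loss` term in the loop above is the closed-form KL divergence between the encoder's Gaussian N(mean, exp(logvar)) and a standard normal, summed over the latent dimensions and averaged over the batch. A small self-contained check of that expression:

```python
import torch

# Hypothetical encoder output split into mean and log-variance for 4-channel, 32x32 latents
mean = torch.randn(2, 4, 32, 32)
logvar = torch.randn(2, 4, 32, 32)

# KL( N(mean, exp(logvar)) || N(0, I) ): 0.5 * sum(exp(logvar) + mean^2 - 1 - logvar)
# summed over (C, H, W) per sample, then averaged over the batch
kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3]))
print(kl_loss.item())
```
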
/tools/train_vae_dit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.optim import AdamW
8 | from dataset.celeb_dataset import CelebDataset
9 | from torch.utils.data import DataLoader
10 | from model.transformer import DIT
11 | from model.vae import VAE
12 | from scheduler.linear_scheduler import LinearNoiseScheduler
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | if torch.backends.mps.is_available():
16 | device = torch.device('mps')
17 | print('Using mps')
18 |
19 |
20 | def train(args):
21 | # Read the config file #
22 | with open(args.config_path, 'r') as file:
23 | try:
24 | config = yaml.safe_load(file)
25 | except yaml.YAMLError as exc:
26 | print(exc)
27 | print(config)
28 | ########################
29 |
30 | diffusion_config = config['diffusion_params']
31 | dataset_config = config['dataset_params']
32 | dit_model_config = config['dit_params']
33 | autoencoder_model_config = config['autoencoder_params']
34 | train_config = config['train_params']
35 |
36 | # Create the noise scheduler
37 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
38 | beta_start=diffusion_config['beta_start'],
39 | beta_end=diffusion_config['beta_end'])
40 |
41 | im_dataset = CelebDataset(split='train',
42 | im_path=dataset_config['im_path'],
43 | im_size=dataset_config['im_size'],
44 | im_channels=dataset_config['im_channels'],
45 | use_latents=True,
46 | latent_path=os.path.join(train_config['task_name'],
47 | train_config['vae_latent_dir_name'])
48 | )
49 |
50 | data_loader = DataLoader(im_dataset,
51 | batch_size=train_config['dit_batch_size'],
52 | shuffle=True)
53 |
54 | # Instantiate the model
55 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
56 | model = DIT(im_size=im_size,
57 | im_channels=autoencoder_model_config['z_channels'],
58 | config=dit_model_config).to(device)
59 | model.train()
60 |
61 | if os.path.exists(os.path.join(train_config['task_name'],
62 | train_config['dit_ckpt_name'])):
63 | print('Loaded DiT checkpoint')
64 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
65 | train_config['dit_ckpt_name']),
66 | map_location=device))
67 |
68 | # Load VAE ONLY if latents are not to be used or are missing
69 | if not im_dataset.use_latents:
70 | print('Loading vae model as latents not present')
71 | vae = VAE(im_channels=dataset_config['im_channels'],
72 | model_config=autoencoder_model_config).to(device)
73 | vae.eval()
74 | # Load vae if found
75 | if os.path.exists(os.path.join(train_config['task_name'],
76 | train_config['vae_autoencoder_ckpt_name'])):
77 | print('Loaded vae checkpoint')
78 | vae.load_state_dict(torch.load(os.path.join(
79 | train_config['task_name'],
80 | train_config['vae_autoencoder_ckpt_name']),
81 | map_location=device))
82 | # Specify training parameters
83 | num_epochs = train_config['dit_epochs']
84 | optimizer = AdamW(model.parameters(), lr=1E-5, weight_decay=0)
85 | criterion = torch.nn.MSELoss()
86 |
87 | # Run training
88 | if not im_dataset.use_latents:
89 | for param in vae.parameters():
90 | param.requires_grad = False
91 |
92 | acc_steps = train_config['dit_acc_steps']
93 | for epoch_idx in range(num_epochs):
94 | losses = []
95 | step_count = 0
96 | for im in tqdm(data_loader):
97 | step_count += 1
98 | im = im.float().to(device)
99 | if im_dataset.use_latents:
100 | mean, logvar = torch.chunk(im, 2, dim=1)
101 | std = torch.exp(0.5 * logvar)
102 | im = mean + std * torch.randn(mean.shape).to(device=im.device)
103 | else:
104 | with torch.no_grad():
105 | im, _ = vae.encode(im)
106 |
107 | # Sample random noise
108 | noise = torch.randn_like(im).to(device)
109 |
110 | # Sample timestep
111 | t = torch.randint(0, diffusion_config['num_timesteps'],
112 | (im.shape[0],)).to(device)
113 |
114 | # Add noise to images according to timestep
115 | noisy_im = scheduler.add_noise(im, noise, t)
116 | pred = model(noisy_im, t)
117 | loss = criterion(pred, noise)
118 | losses.append(loss.item())
119 | loss = loss / acc_steps
120 | loss.backward()
121 | if step_count % acc_steps == 0:
122 | optimizer.step()
123 | optimizer.zero_grad()
124 | optimizer.step()
125 | optimizer.zero_grad()
126 | print('Finished epoch:{} | Loss : {:.4f}'.format(
127 | epoch_idx + 1,
128 | np.mean(losses)))
129 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
130 | train_config['dit_ckpt_name']))
131 |
132 | print('Done Training ...')
133 |
134 |
135 | if __name__ == '__main__':
136 | parser = argparse.ArgumentParser(description='Arguments for dit training')
137 | parser.add_argument('--config', dest='config_path',
138 | default='config/celebhq.yaml', type=str)
139 | args = parser.parse_args()
140 | train(args)
141 |
--------------------------------------------------------------------------------
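
Both training scripts scale each loss by `1 / acc_steps` and step the optimizer only every `acc_steps` iterations, so the effective batch size is roughly `batch_size * acc_steps`. Below is a stripped-down sketch of that accumulation pattern with a toy model standing in for DiT.

```python
import torch
from torch.optim import AdamW

# Toy stand-ins; in the real script the model is DIT and acc_steps comes from the config
model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0)
acc_steps = 4

optimizer.zero_grad()
for step_count in range(1, 17):
    x = torch.randn(2, 8)
    loss = torch.nn.functional.mse_loss(model(x), torch.zeros_like(x))
    (loss / acc_steps).backward()   # scale so accumulated grads match one larger batch
    if step_count % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
# Flush any remaining accumulated gradients (mirrors the end-of-epoch step above)
optimizer.step()
optimizer.zero_grad()
```
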
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import glob
3 | import os
4 | import torch
5 |
6 |
7 | def load_latents(latent_path):
8 | r"""
9 |     Simple utility to load latents saved by infer_vae.py to speed up DiT training
10 |     :param latent_path: directory containing the saved latent .pkl files
11 |     :return: dict mapping image filename to its latent tensor
12 | """
13 | latent_maps = {}
14 | for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
15 | s = pickle.load(open(fname, 'rb'))
16 | for k, v in s.items():
17 | latent_maps[k] = v[0]
18 | return latent_maps
19 |
--------------------------------------------------------------------------------
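
A sketch of how the map returned by `load_latents` would be consumed. Since `v[0]` strips the batch dimension that `infer_vae.py` saved with, each value is an encoder output of shape `(2 * z_channels, H, W)`, which can be chunked into mean and log-variance the same way `train_vae_dit.py` does. The latent directory below is a hypothetical example.

```python
import torch
from utils.diffusion_utils import load_latents

# Hypothetical latent directory; the real path is <task_name>/<vae_latent_dir_name>
latent_maps = load_latents('celebhq/vae_latents')

image_path, encoder_output = next(iter(latent_maps.items()))
print(image_path, tuple(encoder_output.shape))  # e.g. (8, 32, 32) for 4 latent channels

# No batch dimension here, so chunk along dim 0 instead of dim 1
mean, logvar = torch.chunk(encoder_output, 2, dim=0)
latent = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
```
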