├── .gitignore
├── README.md
├── config
│   └── celebhq.yaml
├── dataset
│   ├── __init__.py
│   └── celeb_dataset.py
├── model
│   ├── __init__.py
│   ├── attention.py
│   ├── blocks.py
│   ├── discriminator.py
│   ├── lpips.py
│   ├── patch_embed.py
│   ├── transformer.py
│   ├── transformer_layer.py
│   ├── vae.py
│   └── weights
│       └── v0.1
│           └── .gitkeep
├── requirements.txt
├── scheduler
│   ├── __init__.py
│   └── linear_scheduler.py
├── tools
│   ├── __init__.py
│   ├── infer_vae.py
│   ├── sample_vae_dit.py
│   ├── train_vae.py
│   └── train_vae_dit.py
└── utils
    ├── __init__.py
    └── diffusion_utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 | 
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 | 
12 | # Ignore dataset files
13 | *.csv
14 | *.json
15 | 
16 | # Ignore checkpoints
17 | *.pth
18 | 
19 | # Ignore pickle files
20 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Diffusion Transformers (DiT) Implementation in PyTorch
2 | ========
3 | 
4 | ## DiT Tutorial Video
5 | 
6 | DiT Tutorial
8 | 
9 | 
10 | ## Sample Output for DiT on CelebHQ
11 | Trained for 200 epochs
12 | 
13 | 
14 | 
15 | ___
16 | 
17 | This repository implements DiT in PyTorch for diffusion models. It provides code for the following:
18 | * Training and inference of a VAE on CelebHQ (128x128 images to 32x32x4 latents)
19 | * Training and inference of DiT using the trained VAE on CelebHQ
20 | * Configurable code for training all models from DiT-S to DiT-XL
21 | 
22 | This is very similar to the [official DiT implementation](https://github.com/facebookresearch/DiT) except for the following changes:
23 | * Since training is on CelebHQ, generation is unconditional as of now (but it can easily be modified to class-conditional or text-conditional as well)
24 | * Variance is fixed during training and not learned (as in the original DDPM)
25 | * No EMA
26 | * Ability to train the VAE
27 | * Ability to save the VAE latents for faster training
28 | 
29 | 
30 | ## Setup
31 | * Create a new conda environment with Python 3.10, then run the commands below
32 | * `conda activate <environment_name>`
33 | * ```git clone https://github.com/explainingai-code/DiT-PyTorch.git```
34 | * ```cd DiT-PyTorch```
35 | * ```pip install -r requirements.txt```
36 | * Download the LPIPS weights by opening this link in a browser (don't use cURL or wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```model/weights/v0.1/vgg.pth```
37 | ___
38 | 
39 | ## Data Preparation
40 | 
41 | ### CelebHQ
42 | For setting up CelebHQ, simply download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file)
43 | and add them to the `data` directory.
44 | Ensure the directory structure is the following:
45 | ```
46 | DiT-PyTorch
47 | -> data
48 |    -> CelebAMask-HQ
49 |       -> CelebA-HQ-img
50 |          -> *.jpg
51 | 
52 | ```
53 | ---
54 | ## Configuration
55 | Allows you to play with different components of DiT and the autoencoder
56 | * ```config/celebhq.yaml``` - Configuration used for the CelebHQ dataset
57 | 
58 | Important configuration parameters:
59 | 
60 | * `autoencoder_acc_steps` : For accumulating gradients if the image size is too large and a large batch size can't be used (a minimal sketch of this pattern is shown below)
61 | * `save_latents` : Enable this to save the latents during inference of the autoencoder. That way DiT training will be faster
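
The `autoencoder_acc_steps` option follows the standard gradient-accumulation recipe: scale the loss by the number of accumulation steps and call `optimizer.step()` only once every `acc_steps` mini-batches. Below is a minimal, self-contained sketch of that pattern; the model, optimizer and loader are toy stand-ins for illustration, not the actual objects built in `tools/train_vae.py`.

```python
import torch

# Toy stand-ins for illustration only; the real loop lives in tools/train_vae.py.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
data_loader = [torch.randn(4, 3, 128, 128) for _ in range(8)]
acc_steps = 2  # plays the role of autoencoder_acc_steps

optimizer.zero_grad()
for step, im in enumerate(data_loader):
    recon = model(im)
    # Divide so the accumulated gradient equals the average over acc_steps mini-batches
    loss = torch.nn.functional.mse_loss(recon, im) / acc_steps
    loss.backward()
    if (step + 1) % acc_steps == 0:
        optimizer.step()        # one parameter update per acc_steps mini-batches
        optimizer.zero_grad()
```

With the defaults in `config/celebhq.yaml` (`autoencoder_batch_size: 4`, `autoencoder_acc_steps: 1`) no accumulation happens; raising `autoencoder_acc_steps` trades extra forward/backward passes for a larger effective batch size.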
62 | 
63 | ___
64 | ## Training
65 | The repo provides training and inference for CelebHQ (unconditional DiT).
66 | 
67 | For working on your own dataset:
68 | * Create your own config and have the path in the config point to your images (look at `celebhq.yaml` for guidance)
69 | * Create your own dataset class which will just collect all the filenames and return the image or latent in its `__getitem__` method. Look at `celeb_dataset.py` for guidance
70 | 
71 | Once the config and dataset are set up:
72 | * First train the autoencoder on your dataset using [this section](#training-autoencoder-for-dit)
73 | * For training and inference of unconditional DiT follow [this section](#training-unconditional-dit)
74 | 
75 | ## Training AutoEncoder for DiT
76 | * For training the autoencoder on CelebHQ, ensure the right path is mentioned in `celebhq.yaml`
77 | * For training the autoencoder on your own dataset:
78 |   * Create your own config and have the path point to your images (look at celebhq.yaml for guidance)
79 |   * Create your own dataset class, similar to celeb_dataset.py
80 |   * Call the desired dataset class in the training file [here](https://github.com/explainingai-code/DiT-PyTorch/blob/main/tools/train_vae.py#L49)
81 | * For training the autoencoder, run ```python -m tools.train_vae --config config/celebhq.yaml``` with the desired config file
82 | * For inference, make sure `save_latents` is `True` in the config
83 | * For inference, run ```python -m tools.infer_vae --config config/celebhq.yaml``` to generate reconstructions and save latents with the right config file
84 | 
85 | ## Training Unconditional DiT
86 | Train the autoencoder first and set up the dataset accordingly.
87 | 
88 | For training unconditional DiT, ensure the right dataset is used in `train_vae_dit.py`
89 | * ```python -m tools.train_vae_dit --config config/celebhq.yaml``` for training unconditional DiT with the right config
90 | * ```python -m tools.sample_vae_dit --config config/celebhq.yaml``` for generating images using the trained DiT
91 | 
92 | 
93 | ## Output
94 | Outputs will be saved according to the configuration present in the yaml file.
95 | 
96 | For every run, a folder named after the ```task_name``` key in the config will be created.
97 | 
98 | During training of the autoencoder, the following outputs will be saved:
99 | * The latest autoencoder and discriminator checkpoints in the ```task_name``` directory
100 | * Sample reconstructions in ```task_name/vae_autoencoder_samples```
101 | 
102 | During inference of the autoencoder, the following outputs will be saved:
103 | * Reconstructions for random images in ```task_name```
104 | * Latents will be saved in ```task_name/vae_latent_dir_name``` if mentioned in the config
105 | 
106 | During training and inference of unconditional DiT, the following outputs will be saved:
107 | * During training we will save the latest checkpoint in the ```task_name``` directory
108 | * During sampling, an unconditionally sampled image grid for all timesteps is saved in ```task_name/samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be the latent image predictions of the denoising process from T=999 to T=1.
Generated Image is at T=0 109 | 110 | 111 | ## Citations 112 | ``` 113 | @misc{peebles2023scalablediffusionmodelstransformers, 114 | title={Scalable Diffusion Models with Transformers}, 115 | author={William Peebles and Saining Xie}, 116 | year={2023}, 117 | eprint={2212.09748}, 118 | archivePrefix={arXiv}, 119 | primaryClass={cs.CV}, 120 | url={https://arxiv.org/abs/2212.09748}, 121 | } 122 | ``` 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /config/celebhq.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/CelebAMask-HQ' 3 | im_size : 128 4 | im_channels : 3 5 | 6 | diffusion_params: 7 | num_timesteps : 1000 8 | beta_start : 0.0001 9 | beta_end : 0.02 10 | 11 | dit_params: 12 | patch_size : 2 13 | num_layers : 12 14 | hidden_size : 768 15 | num_heads : 12 16 | head_dim : 64 17 | timestep_emb_dim : 768 18 | 19 | autoencoder_params: 20 | z_channels: 4 21 | codebook_size : 8192 22 | down_channels : [128, 256, 384] 23 | mid_channels : [384] 24 | down_sample : [True, True] 25 | attn_down : [False, False] 26 | norm_channels: 32 27 | num_heads: 4 28 | num_down_layers : 2 29 | num_mid_layers : 2 30 | num_up_layers : 2 31 | 32 | 33 | train_params: 34 | seed : 1111 35 | task_name: 'celebhq' 36 | autoencoder_batch_size: 4 37 | autoencoder_epochs: 3 38 | autoencoder_lr: 0.00001 39 | autoencoder_acc_steps: 1 40 | disc_start: 7500 41 | disc_weight: 0.5 42 | codebook_weight: 1 43 | commitment_beta: 0.2 44 | perceptual_weight: 1 45 | kl_weight: 0.000005 46 | autoencoder_img_save_steps: 64 47 | save_latents: False 48 | dit_batch_size: 32 49 | dit_epochs: 500 50 | num_samples: 1 51 | num_grid_rows: 2 52 | dit_lr: 0.00001 53 | dit_acc_steps: 1 54 | vae_latent_dir_name: 'vae_latents' 55 | dit_ckpt_name: 'dit_ckpt.pth' 56 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 57 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 58 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/celeb_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import torchvision 5 | import numpy as np 6 | from PIL import Image 7 | from utils.diffusion_utils import load_latents 8 | from tqdm import tqdm 9 | from torch.utils.data.dataset import Dataset 10 | 11 | 12 | class CelebDataset(Dataset): 13 | r""" 14 | Celeb dataset will by default centre crop and resize the images. 15 | This can be replaced by any other dataset. As long as all the images 16 | are under one directory. 
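    Illustrative construction (values follow config/celebhq.yaml; the path is an
    assumption to adapt to your own setup):

        dataset = CelebDataset(split='train',
                               im_path='data/CelebAMask-HQ',
                               im_size=128,
                               im_channels=3)
        im = dataset[0]   # image tensor scaled to [-1, 1], or a latent if use_latents is enabled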
17 | """ 18 | 19 | def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg', 20 | use_latents=False, latent_path=None): 21 | self.split = split 22 | self.im_size = im_size 23 | self.im_channels = im_channels 24 | self.im_ext = im_ext 25 | self.im_path = im_path 26 | self.latent_maps = None 27 | self.use_latents = False 28 | 29 | self.images = self.load_images(im_path) 30 | 31 | # Whether to load images or to load latents 32 | if use_latents and latent_path is not None: 33 | latent_maps = load_latents(latent_path) 34 | if len(latent_maps) == len(self.images): 35 | self.use_latents = True 36 | self.latent_maps = latent_maps 37 | print('Found {} latents'.format(len(self.latent_maps))) 38 | else: 39 | print('Latents not found') 40 | 41 | def load_images(self, im_path): 42 | r""" 43 | Gets all images from the path specified 44 | and stacks them all up 45 | """ 46 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 47 | ims = [] 48 | fnames = glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('png'))) 49 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpg'))) 50 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpeg'))) 51 | 52 | for fname in tqdm(fnames): 53 | ims.append(fname) 54 | 55 | print('Found {} images'.format(len(ims))) 56 | return ims 57 | 58 | def __len__(self): 59 | return len(self.images) 60 | 61 | def __getitem__(self, index): 62 | if self.use_latents: 63 | latent = self.latent_maps[self.images[index]] 64 | return latent 65 | 66 | else: 67 | im = Image.open(self.images[index]) 68 | im_tensor = torchvision.transforms.Compose([ 69 | torchvision.transforms.Resize(self.im_size), 70 | torchvision.transforms.CenterCrop(self.im_size), 71 | torchvision.transforms.ToTensor(), 72 | ])(im) 73 | im.close() 74 | 75 | # Convert input to -1 to 1 range. 76 | im_tensor = (2 * im_tensor) - 1 77 | 78 | return im_tensor 79 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/model/__init__.py -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | 6 | class Attention(nn.Module): 7 | r""" 8 | Attention Module for DiT. 9 | This is same as VIT code and does not have any changes 10 | from it. 
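    Illustrative usage (sizes follow dit_params in config/celebhq.yaml; the input here is
    a hypothetical batch of patch tokens):

        attn = Attention({'num_heads': 12, 'hidden_size': 768, 'head_dim': 64})
        out = attn(torch.randn(2, 256, 768))   # (batch, num_patches, hidden_size) in and out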
11 | """ 12 | def __init__(self, config): 13 | super().__init__() 14 | self.n_heads = config['num_heads'] 15 | self.hidden_size = config['hidden_size'] 16 | self.head_dim = config['head_dim'] 17 | 18 | self.att_dim = self.n_heads * self.head_dim 19 | 20 | # QKV projection for the input 21 | self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.att_dim, bias=True) 22 | self.output_proj = nn.Sequential( 23 | nn.Linear(self.att_dim, self.hidden_size)) 24 | 25 | ############################ 26 | # DiT Layer Initialization # 27 | ############################ 28 | nn.init.xavier_uniform_(self.qkv_proj.weight) 29 | nn.init.constant_(self.qkv_proj.bias, 0) 30 | nn.init.xavier_uniform_(self.output_proj[0].weight) 31 | nn.init.constant_(self.output_proj[0].bias, 0) 32 | 33 | def forward(self, x): 34 | # Converting to Attention Dimension 35 | ###################################################### 36 | # Batch Size x Number of Patches x Dimension 37 | B, N = x.shape[:2] 38 | # Projecting to 3*att_dim and then splitting to get q, k v(each of att_dim) 39 | # qkv -> Batch Size x Number of Patches x (3* Attention Dimension) 40 | # q(as well as k and v) -> Batch Size x Number of Patches x Attention Dimension 41 | q, k, v = self.qkv_proj(x).split(self.att_dim, dim=-1) 42 | # Batch Size x Number of Patches x Attention Dimension 43 | # -> Batch Size x Number of Patches x (Heads * Head Dimension) 44 | # -> Batch Size x Number of Patches x (Heads * Head Dimension) 45 | # -> Batch Size x Heads x Number of Patches x Head Dimension 46 | # -> B x H x N x Head Dimension 47 | q = rearrange(q, 'b n (n_h h_dim) -> b n_h n h_dim', 48 | n_h=self.n_heads, h_dim=self.head_dim) 49 | k = rearrange(k, 'b n (n_h h_dim) -> b n_h n h_dim', 50 | n_h=self.n_heads, h_dim=self.head_dim) 51 | v = rearrange(v, 'b n (n_h h_dim) -> b n_h n h_dim', 52 | n_h=self.n_heads, h_dim=self.head_dim) 53 | ######################################################### 54 | 55 | # Compute Attention Weights 56 | ######################################################### 57 | # B x H x N x Head Dimension @ B x H x Head Dimension x N 58 | # -> B x H x N x N 59 | att = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** (-0.5)) 60 | att = torch.nn.functional.softmax(att, dim=-1) 61 | ######################################################### 62 | 63 | # Weighted Value Computation 64 | ######################################################### 65 | # B x H x N x N @ B x H x N x Head Dimension 66 | # -> B x H x N x Head Dimension 67 | out = torch.matmul(att, v) 68 | ######################################################### 69 | 70 | # Converting to Transformer Dimension 71 | ######################################################### 72 | # B x N x (Heads * Head Dimension) -> B x N x (Attention Dimension) 73 | out = rearrange(out, 'b n_h n h_dim -> b n (n_h h_dim)') 74 | # B x N x Dimension 75 | out = self.output_proj(out) 76 | ########################################################## 77 | 78 | return out 79 | -------------------------------------------------------------------------------- /model/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_time_embedding(time_steps, temb_dim): 6 | r""" 7 | Convert time steps tensor into an embedding using the 8 | sinusoidal time embedding formula 9 | :param time_steps: 1D tensor of length batch size 10 | :param temb_dim: Dimension of the embedding 11 | :return: BxD embedding representation of B time steps 12 | """ 13 
| assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 14 | 15 | # factor = 10000^(2i/d_model) 16 | factor = 10000 ** ((torch.arange( 17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) 18 | ) 19 | 20 | # pos / factor 21 | # timesteps B -> B, 1 -> B, temb_dim 22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 24 | return t_emb 25 | 26 | 27 | class DownBlock(nn.Module): 28 | r""" 29 | Down conv block with attention. 30 | Sequence of following block 31 | 1. Resnet block with time embedding 32 | 2. Attention block 33 | 3. Downsample 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, t_emb_dim, 37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None): 38 | super().__init__() 39 | self.num_layers = num_layers 40 | self.down_sample = down_sample 41 | self.attn = attn 42 | self.context_dim = context_dim 43 | self.cross_attn = cross_attn 44 | self.t_emb_dim = t_emb_dim 45 | self.resnet_conv_first = nn.ModuleList( 46 | [ 47 | nn.Sequential( 48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 49 | nn.SiLU(), 50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 51 | kernel_size=3, stride=1, padding=1), 52 | ) 53 | for i in range(num_layers) 54 | ] 55 | ) 56 | if self.t_emb_dim is not None: 57 | self.t_emb_layers = nn.ModuleList([ 58 | nn.Sequential( 59 | nn.SiLU(), 60 | nn.Linear(self.t_emb_dim, out_channels) 61 | ) 62 | for _ in range(num_layers) 63 | ]) 64 | self.resnet_conv_second = nn.ModuleList( 65 | [ 66 | nn.Sequential( 67 | nn.GroupNorm(norm_channels, out_channels), 68 | nn.SiLU(), 69 | nn.Conv2d(out_channels, out_channels, 70 | kernel_size=3, stride=1, padding=1), 71 | ) 72 | for _ in range(num_layers) 73 | ] 74 | ) 75 | 76 | if self.attn: 77 | self.attention_norms = nn.ModuleList( 78 | [nn.GroupNorm(norm_channels, out_channels) 79 | for _ in range(num_layers)] 80 | ) 81 | 82 | self.attentions = nn.ModuleList( 83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 84 | for _ in range(num_layers)] 85 | ) 86 | 87 | if self.cross_attn: 88 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 89 | self.cross_attention_norms = nn.ModuleList( 90 | [nn.GroupNorm(norm_channels, out_channels) 91 | for _ in range(num_layers)] 92 | ) 93 | self.cross_attentions = nn.ModuleList( 94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 95 | for _ in range(num_layers)] 96 | ) 97 | self.context_proj = nn.ModuleList( 98 | [nn.Linear(context_dim, out_channels) 99 | for _ in range(num_layers)] 100 | ) 101 | 102 | self.residual_input_conv = nn.ModuleList( 103 | [ 104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 105 | for i in range(num_layers) 106 | ] 107 | ) 108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 109 | 4, 2, 1) if self.down_sample else nn.Identity() 110 | 111 | def forward(self, x, t_emb=None, context=None): 112 | out = x 113 | for i in range(self.num_layers): 114 | # Resnet block of Unet 115 | resnet_input = out 116 | out = self.resnet_conv_first[i](out) 117 | if self.t_emb_dim is not None: 118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 119 | out = self.resnet_conv_second[i](out) 120 | out = out + self.residual_input_conv[i](resnet_input) 121 | 122 | if self.attn: 123 | # Attention block of Unet 124 | batch_size, 
channels, h, w = out.shape 125 | in_attn = out.reshape(batch_size, channels, h * w) 126 | in_attn = self.attention_norms[i](in_attn) 127 | in_attn = in_attn.transpose(1, 2) 128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 130 | out = out + out_attn 131 | 132 | if self.cross_attn: 133 | assert context is not None, "context cannot be None if cross attention layers are used" 134 | batch_size, channels, h, w = out.shape 135 | in_attn = out.reshape(batch_size, channels, h * w) 136 | in_attn = self.cross_attention_norms[i](in_attn) 137 | in_attn = in_attn.transpose(1, 2) 138 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 139 | context_proj = self.context_proj[i](context) 140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 142 | out = out + out_attn 143 | 144 | # Downsample 145 | out = self.down_sample_conv(out) 146 | return out 147 | 148 | 149 | class MidBlock(nn.Module): 150 | r""" 151 | Mid conv block with attention. 152 | Sequence of following blocks 153 | 1. Resnet block with time embedding 154 | 2. Attention block 155 | 3. Resnet block with time embedding 156 | """ 157 | 158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, 159 | context_dim=None): 160 | super().__init__() 161 | self.num_layers = num_layers 162 | self.t_emb_dim = t_emb_dim 163 | self.context_dim = context_dim 164 | self.cross_attn = cross_attn 165 | self.resnet_conv_first = nn.ModuleList( 166 | [ 167 | nn.Sequential( 168 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 169 | nn.SiLU(), 170 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 171 | padding=1), 172 | ) 173 | for i in range(num_layers + 1) 174 | ] 175 | ) 176 | 177 | if self.t_emb_dim is not None: 178 | self.t_emb_layers = nn.ModuleList([ 179 | nn.Sequential( 180 | nn.SiLU(), 181 | nn.Linear(t_emb_dim, out_channels) 182 | ) 183 | for _ in range(num_layers + 1) 184 | ]) 185 | self.resnet_conv_second = nn.ModuleList( 186 | [ 187 | nn.Sequential( 188 | nn.GroupNorm(norm_channels, out_channels), 189 | nn.SiLU(), 190 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 191 | ) 192 | for _ in range(num_layers + 1) 193 | ] 194 | ) 195 | 196 | self.attention_norms = nn.ModuleList( 197 | [nn.GroupNorm(norm_channels, out_channels) 198 | for _ in range(num_layers)] 199 | ) 200 | 201 | self.attentions = nn.ModuleList( 202 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 203 | for _ in range(num_layers)] 204 | ) 205 | if self.cross_attn: 206 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 207 | self.cross_attention_norms = nn.ModuleList( 208 | [nn.GroupNorm(norm_channels, out_channels) 209 | for _ in range(num_layers)] 210 | ) 211 | self.cross_attentions = nn.ModuleList( 212 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 213 | for _ in range(num_layers)] 214 | ) 215 | self.context_proj = nn.ModuleList( 216 | [nn.Linear(context_dim, out_channels) 217 | for _ in range(num_layers)] 218 | ) 219 | self.residual_input_conv = nn.ModuleList( 220 | [ 221 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 222 | for i in range(num_layers + 1) 223 | ] 224 | ) 225 | 226 | def forward(self, x, 
t_emb=None, context=None): 227 | out = x 228 | 229 | # First resnet block 230 | resnet_input = out 231 | out = self.resnet_conv_first[0](out) 232 | if self.t_emb_dim is not None: 233 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] 234 | out = self.resnet_conv_second[0](out) 235 | out = out + self.residual_input_conv[0](resnet_input) 236 | 237 | for i in range(self.num_layers): 238 | # Attention Block 239 | batch_size, channels, h, w = out.shape 240 | in_attn = out.reshape(batch_size, channels, h * w) 241 | in_attn = self.attention_norms[i](in_attn) 242 | in_attn = in_attn.transpose(1, 2) 243 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 244 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 245 | out = out + out_attn 246 | 247 | if self.cross_attn: 248 | assert context is not None, "context cannot be None if cross attention layers are used" 249 | batch_size, channels, h, w = out.shape 250 | in_attn = out.reshape(batch_size, channels, h * w) 251 | in_attn = self.cross_attention_norms[i](in_attn) 252 | in_attn = in_attn.transpose(1, 2) 253 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 254 | context_proj = self.context_proj[i](context) 255 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 256 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 257 | out = out + out_attn 258 | 259 | # Resnet Block 260 | resnet_input = out 261 | out = self.resnet_conv_first[i + 1](out) 262 | if self.t_emb_dim is not None: 263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] 264 | out = self.resnet_conv_second[i + 1](out) 265 | out = out + self.residual_input_conv[i + 1](resnet_input) 266 | 267 | return out 268 | 269 | 270 | class UpBlock(nn.Module): 271 | r""" 272 | Up conv block with attention. 273 | Sequence of following blocks 274 | 1. Upsample 275 | 1. Concatenate Down block output 276 | 2. Resnet block with time embedding 277 | 3. 
Attention Block 278 | """ 279 | 280 | def __init__(self, in_channels, out_channels, t_emb_dim, 281 | up_sample, num_heads, num_layers, attn, norm_channels): 282 | super().__init__() 283 | self.num_layers = num_layers 284 | self.up_sample = up_sample 285 | self.t_emb_dim = t_emb_dim 286 | self.attn = attn 287 | self.resnet_conv_first = nn.ModuleList( 288 | [ 289 | nn.Sequential( 290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 291 | nn.SiLU(), 292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 293 | padding=1), 294 | ) 295 | for i in range(num_layers) 296 | ] 297 | ) 298 | 299 | if self.t_emb_dim is not None: 300 | self.t_emb_layers = nn.ModuleList([ 301 | nn.Sequential( 302 | nn.SiLU(), 303 | nn.Linear(t_emb_dim, out_channels) 304 | ) 305 | for _ in range(num_layers) 306 | ]) 307 | 308 | self.resnet_conv_second = nn.ModuleList( 309 | [ 310 | nn.Sequential( 311 | nn.GroupNorm(norm_channels, out_channels), 312 | nn.SiLU(), 313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 314 | ) 315 | for _ in range(num_layers) 316 | ] 317 | ) 318 | if self.attn: 319 | self.attention_norms = nn.ModuleList( 320 | [ 321 | nn.GroupNorm(norm_channels, out_channels) 322 | for _ in range(num_layers) 323 | ] 324 | ) 325 | 326 | self.attentions = nn.ModuleList( 327 | [ 328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 329 | for _ in range(num_layers) 330 | ] 331 | ) 332 | 333 | self.residual_input_conv = nn.ModuleList( 334 | [ 335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 336 | for i in range(num_layers) 337 | ] 338 | ) 339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 340 | 4, 2, 1) \ 341 | if self.up_sample else nn.Identity() 342 | 343 | def forward(self, x, out_down=None, t_emb=None): 344 | # Upsample 345 | x = self.up_sample_conv(x) 346 | 347 | # Concat with Downblock output 348 | if out_down is not None: 349 | x = torch.cat([x, out_down], dim=1) 350 | 351 | out = x 352 | for i in range(self.num_layers): 353 | # Resnet Block 354 | resnet_input = out 355 | out = self.resnet_conv_first[i](out) 356 | if self.t_emb_dim is not None: 357 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 358 | out = self.resnet_conv_second[i](out) 359 | out = out + self.residual_input_conv[i](resnet_input) 360 | 361 | # Self Attention 362 | if self.attn: 363 | batch_size, channels, h, w = out.shape 364 | in_attn = out.reshape(batch_size, channels, h * w) 365 | in_attn = self.attention_norms[i](in_attn) 366 | in_attn = in_attn.transpose(1, 2) 367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 369 | out = out + out_attn 370 | return out 371 | 372 | 373 | class UpBlockUnet(nn.Module): 374 | r""" 375 | Up conv block with attention. 376 | Sequence of following blocks 377 | 1. Upsample 378 | 1. Concatenate Down block output 379 | 2. Resnet block with time embedding 380 | 3. 
Attention Block 381 | """ 382 | 383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, 384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): 385 | super().__init__() 386 | self.num_layers = num_layers 387 | self.up_sample = up_sample 388 | self.t_emb_dim = t_emb_dim 389 | self.cross_attn = cross_attn 390 | self.context_dim = context_dim 391 | self.resnet_conv_first = nn.ModuleList( 392 | [ 393 | nn.Sequential( 394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 395 | nn.SiLU(), 396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 397 | padding=1), 398 | ) 399 | for i in range(num_layers) 400 | ] 401 | ) 402 | 403 | if self.t_emb_dim is not None: 404 | self.t_emb_layers = nn.ModuleList([ 405 | nn.Sequential( 406 | nn.SiLU(), 407 | nn.Linear(t_emb_dim, out_channels) 408 | ) 409 | for _ in range(num_layers) 410 | ]) 411 | 412 | self.resnet_conv_second = nn.ModuleList( 413 | [ 414 | nn.Sequential( 415 | nn.GroupNorm(norm_channels, out_channels), 416 | nn.SiLU(), 417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 418 | ) 419 | for _ in range(num_layers) 420 | ] 421 | ) 422 | 423 | self.attention_norms = nn.ModuleList( 424 | [ 425 | nn.GroupNorm(norm_channels, out_channels) 426 | for _ in range(num_layers) 427 | ] 428 | ) 429 | 430 | self.attentions = nn.ModuleList( 431 | [ 432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 433 | for _ in range(num_layers) 434 | ] 435 | ) 436 | 437 | if self.cross_attn: 438 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 439 | self.cross_attention_norms = nn.ModuleList( 440 | [nn.GroupNorm(norm_channels, out_channels) 441 | for _ in range(num_layers)] 442 | ) 443 | self.cross_attentions = nn.ModuleList( 444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 445 | for _ in range(num_layers)] 446 | ) 447 | self.context_proj = nn.ModuleList( 448 | [nn.Linear(context_dim, out_channels) 449 | for _ in range(num_layers)] 450 | ) 451 | self.residual_input_conv = nn.ModuleList( 452 | [ 453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 454 | for i in range(num_layers) 455 | ] 456 | ) 457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 458 | 4, 2, 1) \ 459 | if self.up_sample else nn.Identity() 460 | 461 | def forward(self, x, out_down=None, t_emb=None, context=None): 462 | x = self.up_sample_conv(x) 463 | if out_down is not None: 464 | x = torch.cat([x, out_down], dim=1) 465 | 466 | out = x 467 | for i in range(self.num_layers): 468 | # Resnet 469 | resnet_input = out 470 | out = self.resnet_conv_first[i](out) 471 | if self.t_emb_dim is not None: 472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 473 | out = self.resnet_conv_second[i](out) 474 | out = out + self.residual_input_conv[i](resnet_input) 475 | # Self Attention 476 | batch_size, channels, h, w = out.shape 477 | in_attn = out.reshape(batch_size, channels, h * w) 478 | in_attn = self.attention_norms[i](in_attn) 479 | in_attn = in_attn.transpose(1, 2) 480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 482 | out = out + out_attn 483 | # Cross Attention 484 | if self.cross_attn: 485 | assert context is not None, "context cannot be None if cross attention layers are used" 486 | batch_size, channels, h, w = out.shape 487 | in_attn 
= out.reshape(batch_size, channels, h * w) 488 | in_attn = self.cross_attention_norms[i](in_attn) 489 | in_attn = in_attn.transpose(1, 2) 490 | assert len(context.shape) == 3, \ 491 | "Context shape does not match B,_,CONTEXT_DIM" 492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \ 493 | "Context shape does not match B,_,CONTEXT_DIM" 494 | context_proj = self.context_proj[i](context) 495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 497 | out = out + out_attn 498 | 499 | return out 500 | 501 | 502 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Discriminator(nn.Module): 6 | r""" 7 | PatchGAN Discriminator. 8 | Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to 9 | 1 scalar value , we instead predict grid of values. 10 | Where each grid is prediction of how likely 11 | the discriminator thinks that the image patch corresponding 12 | to the grid cell is real 13 | """ 14 | 15 | def __init__(self, im_channels=3, 16 | conv_channels=[64, 128, 256], 17 | kernels=[4,4,4,4], 18 | strides=[2,2,2,1], 19 | paddings=[1,1,1,1]): 20 | super().__init__() 21 | self.im_channels = im_channels 22 | activation = nn.LeakyReLU(0.2) 23 | layers_dim = [self.im_channels] + conv_channels + [1] 24 | self.layers = nn.ModuleList([ 25 | nn.Sequential( 26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1], 27 | kernel_size=kernels[i], 28 | stride=strides[i], 29 | padding=paddings[i], 30 | bias=False if i !=0 else True), 31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(), 32 | activation if i != len(layers_dim) - 2 else nn.Identity() 33 | ) 34 | for i in range(len(layers_dim) - 1) 35 | ]) 36 | 37 | def forward(self, x): 38 | out = x 39 | for layer in self.layers: 40 | out = layer(out) 41 | return out 42 | 43 | 44 | if __name__ == '__main__': 45 | x = torch.randn((2,3, 256, 256)) 46 | prob = Discriminator(im_channels=3)(x) 47 | print(prob.shape) 48 | -------------------------------------------------------------------------------- /model/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn 9 | import torchvision 10 | 11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 | 18 | def spatial_average(in_tens, keepdim=True): 19 | return in_tens.mean([2, 3], keepdim=keepdim) 20 | 21 | 22 | class vgg16(torch.nn.Module): 23 | def __init__(self, requires_grad=False, pretrained=True): 24 | super(vgg16, self).__init__() 25 | # Load pretrained vgg model from torchvision 26 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features 27 | self.slice1 = torch.nn.Sequential() 28 | self.slice2 = torch.nn.Sequential() 29 | self.slice3 = torch.nn.Sequential() 30 | self.slice4 = torch.nn.Sequential() 31 | self.slice5 = 
torch.nn.Sequential() 32 | self.N_slices = 5 33 | for x in range(4): 34 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 35 | for x in range(4, 9): 36 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 37 | for x in range(9, 16): 38 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 39 | for x in range(16, 23): 40 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 41 | for x in range(23, 30): 42 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 43 | 44 | # Freeze vgg model 45 | if not requires_grad: 46 | for param in self.parameters(): 47 | param.requires_grad = False 48 | 49 | def forward(self, X): 50 | # Return output of vgg features 51 | h = self.slice1(X) 52 | h_relu1_2 = h 53 | h = self.slice2(h) 54 | h_relu2_2 = h 55 | h = self.slice3(h) 56 | h_relu3_3 = h 57 | h = self.slice4(h) 58 | h_relu4_3 = h 59 | h = self.slice5(h) 60 | h_relu5_3 = h 61 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 62 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 63 | return out 64 | 65 | 66 | # Learned perceptual metric 67 | class LPIPS(nn.Module): 68 | def __init__(self, net='vgg', version='0.1', use_dropout=True): 69 | super(LPIPS, self).__init__() 70 | self.version = version 71 | # Imagenet normalization 72 | self.scaling_layer = ScalingLayer() 73 | ######################## 74 | 75 | # Instantiate vgg model 76 | self.chns = [64, 128, 256, 512, 512] 77 | self.L = len(self.chns) 78 | self.net = vgg16(pretrained=True, requires_grad=False) 79 | 80 | # Add 1x1 convolutional Layers 81 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 82 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 83 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 84 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 85 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 86 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 87 | self.lins = nn.ModuleList(self.lins) 88 | ######################## 89 | 90 | # Load the weights of trained LPIPS model 91 | import inspect 92 | import os 93 | model_path = os.path.abspath( 94 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))) 95 | print('Loading model from: %s' % model_path) 96 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False) 97 | ######################## 98 | 99 | # Freeze all parameters 100 | self.eval() 101 | for param in self.parameters(): 102 | param.requires_grad = False 103 | ######################## 104 | 105 | def forward(self, in0, in1, normalize=False): 106 | # Scale the inputs to -1 to +1 range if needed 107 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 108 | in0 = 2 * in0 - 1 109 | in1 = 2 * in1 - 1 110 | ######################## 111 | 112 | # Normalize the inputs according to imagenet normalization 113 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) 114 | ######################## 115 | 116 | # Get VGG outputs for image0 and image1 117 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 118 | feats0, feats1, diffs = {}, {}, {} 119 | ######################## 120 | 121 | # Compute Square of Difference for each layer output 122 | for kk in range(self.L): 123 | feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize( 124 | outs1[kk]) 125 | diffs[kk] = 
(feats0[kk] - feats1[kk]) ** 2 126 | ######################## 127 | 128 | # 1x1 convolution followed by spatial average on the square differences 129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 130 | val = 0 131 | 132 | # Aggregate the results of each layer 133 | for l in range(self.L): 134 | val += res[l] 135 | return val 136 | 137 | 138 | class ScalingLayer(nn.Module): 139 | def __init__(self): 140 | super(ScalingLayer, self).__init__() 141 | # Imagnet normalization for (0-1) 142 | # mean = [0.485, 0.456, 0.406] 143 | # std = [0.229, 0.224, 0.225] 144 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 145 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 146 | 147 | def forward(self, inp): 148 | return (inp - self.shift) / self.scale 149 | 150 | 151 | class NetLinLayer(nn.Module): 152 | ''' A single linear layer which does a 1x1 conv ''' 153 | 154 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 155 | super(NetLinLayer, self).__init__() 156 | 157 | layers = [nn.Dropout(), ] if (use_dropout) else [] 158 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 159 | self.model = nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | out = self.model(x) 163 | return out 164 | -------------------------------------------------------------------------------- /model/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | def get_patch_position_embedding(pos_emb_dim, grid_size, device): 9 | assert pos_emb_dim % 4 == 0, 'Position embedding dimension must be divisible by 4' 10 | grid_size_h, grid_size_w = grid_size 11 | grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device) 12 | grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device) 13 | grid = torch.meshgrid(grid_h, grid_w, indexing='ij') 14 | grid = torch.stack(grid, dim=0) 15 | 16 | # grid_h_positions -> (Number of patch tokens,) 17 | grid_h_positions = grid[0].reshape(-1) 18 | grid_w_positions = grid[1].reshape(-1) 19 | 20 | # factor = 10000^(2i/d_model) 21 | factor = 10000 ** ((torch.arange( 22 | start=0, 23 | end=pos_emb_dim // 4, 24 | dtype=torch.float32, 25 | device=device) / (pos_emb_dim // 4)) 26 | ) 27 | 28 | grid_h_emb = grid_h_positions[:, None].repeat(1, pos_emb_dim // 4) / factor 29 | grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1) 30 | # grid_h_emb -> (Number of patch tokens, pos_emb_dim // 2) 31 | 32 | grid_w_emb = grid_w_positions[:, None].repeat(1, pos_emb_dim // 4) / factor 33 | grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1) 34 | pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1) 35 | 36 | # pos_emb -> (Number of patch tokens, pos_emb_dim) 37 | return pos_emb 38 | 39 | 40 | class PatchEmbedding(nn.Module): 41 | r""" 42 | Layer to take in the input image and do the following: 43 | 1. Transform grid of image patches into a sequence of patches. 44 | Number of patches are decided based on image height,width and 45 | patch height, width. 46 | 2. 
Add positional embedding to the above sequence 47 | """ 48 | 49 | def __init__(self, 50 | image_height, 51 | image_width, 52 | im_channels, 53 | patch_height, 54 | patch_width, 55 | hidden_size): 56 | super().__init__() 57 | self.image_height = image_height 58 | self.image_width = image_width 59 | self.im_channels = im_channels 60 | 61 | self.hidden_size = hidden_size 62 | 63 | self.patch_height = patch_height 64 | self.patch_width = patch_width 65 | 66 | # Input dimension for Patch Embedding FC Layer 67 | patch_dim = self.im_channels * self.patch_height * self.patch_width 68 | self.patch_embed = nn.Sequential( 69 | nn.Linear(patch_dim, self.hidden_size) 70 | ) 71 | 72 | ############################ 73 | # DiT Layer Initialization # 74 | ############################ 75 | nn.init.xavier_uniform_(self.patch_embed[0].weight) 76 | nn.init.constant_(self.patch_embed[0].bias, 0) 77 | 78 | def forward(self, x): 79 | grid_size_h = self.image_height // self.patch_height 80 | grid_size_w = self.image_width // self.patch_width 81 | 82 | # B, C, H, W -> B, (Patches along height * Patches along width), Patch Dimension 83 | # Number of tokens = Patches along height * Patches along width 84 | out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)', 85 | ph=self.patch_height, 86 | pw=self.patch_width) 87 | 88 | # BxNumber of tokens x Patch Dimension -> B x Number of tokens x Transformer Dimension 89 | out = self.patch_embed(out) 90 | 91 | # Add 2d sinusoidal position embeddings 92 | pos_embed = get_patch_position_embedding(pos_emb_dim=self.hidden_size, 93 | grid_size=(grid_size_h, grid_size_w), 94 | device=x.device) 95 | out += pos_embed 96 | return out 97 | 98 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.patch_embed import PatchEmbedding 4 | from model.transformer_layer import TransformerLayer 5 | from einops import rearrange 6 | 7 | 8 | def get_time_embedding(time_steps, temb_dim): 9 | r""" 10 | Convert time steps tensor into an embedding using the 11 | sinusoidal time embedding formula 12 | :param time_steps: 1D tensor of length batch size 13 | :param temb_dim: Dimension of the embedding 14 | :return: BxD embedding representation of B time steps 15 | """ 16 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 17 | 18 | # factor = 10000^(2i/d_model) 19 | factor = 10000 ** ((torch.arange( 20 | start=0, 21 | end=temb_dim // 2, 22 | dtype=torch.float32, 23 | device=time_steps.device) / (temb_dim // 2)) 24 | ) 25 | 26 | # pos / factor 27 | # timesteps B -> B, 1 -> B, temb_dim 28 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 29 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 30 | return t_emb 31 | 32 | 33 | class DIT(nn.Module): 34 | def __init__(self, im_size, im_channels, config): 35 | super().__init__() 36 | 37 | num_layers = config['num_layers'] 38 | self.image_height = im_size 39 | self.image_width = im_size 40 | self.im_channels = im_channels 41 | self.hidden_size = config['hidden_size'] 42 | self.patch_height = config['patch_size'] 43 | self.patch_width = config['patch_size'] 44 | 45 | self.timestep_emb_dim = config['timestep_emb_dim'] 46 | 47 | # Number of patches along height and width 48 | self.nh = self.image_height // self.patch_height 49 | self.nw = self.image_width // self.patch_width 50 | 51 | # Patch Embedding Block 52 | 
self.patch_embed_layer = PatchEmbedding(image_height=self.image_height, 53 | image_width=self.image_width, 54 | im_channels=self.im_channels, 55 | patch_height=self.patch_height, 56 | patch_width=self.patch_width, 57 | hidden_size=self.hidden_size) 58 | 59 | # Initial projection from sinusoidal time embedding 60 | self.t_proj = nn.Sequential( 61 | nn.Linear(self.timestep_emb_dim, self.hidden_size), 62 | nn.SiLU(), 63 | nn.Linear(self.hidden_size, self.hidden_size) 64 | ) 65 | 66 | # All Transformer Layers 67 | self.layers = nn.ModuleList([ 68 | TransformerLayer(config) for _ in range(num_layers) 69 | ]) 70 | 71 | # Final normalization for unpatchify block 72 | self.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6) 73 | 74 | # Scale and Shift parameters for the norm 75 | self.adaptive_norm_layer = nn.Sequential( 76 | nn.SiLU(), 77 | nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=True) 78 | ) 79 | 80 | # Final Linear Layer 81 | self.proj_out = nn.Linear(self.hidden_size, 82 | self.patch_height * self.patch_width * self.im_channels) 83 | 84 | ############################ 85 | # DiT Layer Initialization # 86 | ############################ 87 | nn.init.normal_(self.t_proj[0].weight, std=0.02) 88 | nn.init.normal_(self.t_proj[2].weight, std=0.02) 89 | 90 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0) 91 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0) 92 | 93 | nn.init.constant_(self.proj_out.weight, 0) 94 | nn.init.constant_(self.proj_out.bias, 0) 95 | 96 | def forward(self, x, t): 97 | # Patchify 98 | out = self.patch_embed_layer(x) 99 | 100 | # Compute Timestep representation 101 | # t_emb -> (Batch, timestep_emb_dim) 102 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.timestep_emb_dim) 103 | # (Batch, timestep_emb_dim) -> (Batch, hidden_size) 104 | t_emb = self.t_proj(t_emb) 105 | 106 | # Go through the transformer layers 107 | for layer in self.layers: 108 | out = layer(out, t_emb) 109 | 110 | # Shift and scale predictions for output normalization 111 | pre_mlp_shift, pre_mlp_scale = self.adaptive_norm_layer(t_emb).chunk(2, dim=1) 112 | out = (self.norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) + 113 | pre_mlp_shift.unsqueeze(1)) 114 | 115 | # Unpatchify 116 | # (B,patches,hidden_size) -> (B,patches,channels * patch_width * patch_height) 117 | out = self.proj_out(out) 118 | out = rearrange(out, 'b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)', 119 | ph=self.patch_height, 120 | pw=self.patch_width, 121 | nw=self.nw, 122 | nh=self.nh) 123 | return out 124 | -------------------------------------------------------------------------------- /model/transformer_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from model.attention import Attention 3 | 4 | 5 | class TransformerLayer(nn.Module): 6 | r""" 7 | Transformer block which is just doing the following based on VIT 8 | 1. LayerNorm followed by Attention 9 | 2. LayerNorm followed by Feed forward Block 10 | Both these also have residuals added to them 11 | 12 | For DiT we additionally have 13 | 1. Layernorm mlp to predict layernorm affine parameters from 14 | 2. Same Layernorm mlp to also predict scale parameters for outputs 15 | of both mlp/attention prior to residual connection. 
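    In pseudo-form, each sub-block (attention and mlp) is applied as

        out = x + post_scale * SubBlock((1 + pre_scale) * LayerNorm(x) + pre_shift)

    where all six scale/shift terms are predicted from the conditioning (timestep)
    embedding by adaptive_norm_layer. Since that layer is zero-initialized below,
    every block starts out as an identity mapping (the adaLN-Zero scheme of DiT).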
16 | """ 17 | 18 | def __init__(self, config): 19 | super().__init__() 20 | self.hidden_size = config['hidden_size'] 21 | 22 | ff_hidden_dim = 4 * self.hidden_size 23 | 24 | # Layer norm for attention block 25 | self.att_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6) 26 | 27 | self.attn_block = Attention(config) 28 | 29 | # Layer norm for mlp block 30 | self.ff_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6) 31 | 32 | self.mlp_block = nn.Sequential( 33 | nn.Linear(self.hidden_size, ff_hidden_dim), 34 | nn.GELU(approximate='tanh'), 35 | nn.Linear(ff_hidden_dim, self.hidden_size), 36 | ) 37 | 38 | # Scale Shift Parameter predictions for this layer 39 | # 1. Scale and shift parameters for layernorm of attention (2 * hidden_size) 40 | # 2. Scale and shift parameters for layernorm of mlp (2 * hidden_size) 41 | # 3. Scale for output of attention prior to residual connection (hidden_size) 42 | # 4. Scale for output of mlp prior to residual connection (hidden_size) 43 | # Total 6 * hidden_size 44 | self.adaptive_norm_layer = nn.Sequential( 45 | nn.SiLU(), 46 | nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True) 47 | ) 48 | 49 | ############################ 50 | # DiT Layer Initialization # 51 | ############################ 52 | nn.init.xavier_uniform_(self.mlp_block[0].weight) 53 | nn.init.constant_(self.mlp_block[0].bias, 0) 54 | nn.init.xavier_uniform_(self.mlp_block[-1].weight) 55 | nn.init.constant_(self.mlp_block[-1].bias, 0) 56 | 57 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0) 58 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0) 59 | 60 | def forward(self, x, condition): 61 | scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1) 62 | (pre_attn_shift, pre_attn_scale, post_attn_scale, 63 | pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = scale_shift_params 64 | out = x 65 | attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1)) 66 | + pre_attn_shift.unsqueeze(1)) 67 | out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output) 68 | mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) + 69 | pre_mlp_shift.unsqueeze(1)) 70 | out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output) 71 | return out 72 | -------------------------------------------------------------------------------- /model/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.blocks import DownBlock, MidBlock, UpBlock 4 | 5 | 6 | class VAE(nn.Module): 7 | def __init__(self, im_channels, model_config): 8 | super().__init__() 9 | self.down_channels = model_config['down_channels'] 10 | self.mid_channels = model_config['mid_channels'] 11 | self.down_sample = model_config['down_sample'] 12 | self.num_down_layers = model_config['num_down_layers'] 13 | self.num_mid_layers = model_config['num_mid_layers'] 14 | self.num_up_layers = model_config['num_up_layers'] 15 | 16 | # To disable attention in Downblock of Encoder and Upblock of Decoder 17 | self.attns = model_config['attn_down'] 18 | 19 | # Latent Dimension 20 | self.z_channels = model_config['z_channels'] 21 | self.norm_channels = model_config['norm_channels'] 22 | self.num_heads = model_config['num_heads'] 23 | 24 | # Assertion to validate the channel information 25 | assert self.mid_channels[0] == self.down_channels[-1] 26 | assert self.mid_channels[-1] == self.down_channels[-1] 27 | assert len(self.down_sample) == 
len(self.down_channels) - 1 28 | assert len(self.attns) == len(self.down_channels) - 1 29 | 30 | # Wherever we use downsampling in encoder correspondingly use 31 | # upsampling in decoder 32 | self.up_sample = list(reversed(self.down_sample)) 33 | 34 | ##################### Encoder ###################### 35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) 36 | 37 | # Downblock + Midblock 38 | self.encoder_layers = nn.ModuleList([]) 39 | for i in range(len(self.down_channels) - 1): 40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], 41 | t_emb_dim=None, down_sample=self.down_sample[i], 42 | num_heads=self.num_heads, 43 | num_layers=self.num_down_layers, 44 | attn=self.attns[i], 45 | norm_channels=self.norm_channels)) 46 | 47 | self.encoder_mids = nn.ModuleList([]) 48 | for i in range(len(self.mid_channels) - 1): 49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], 50 | t_emb_dim=None, 51 | num_heads=self.num_heads, 52 | num_layers=self.num_mid_layers, 53 | norm_channels=self.norm_channels)) 54 | 55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) 56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1) 57 | 58 | # Latent Dimension is 2*Latent because we are predicting mean & variance 59 | self.pre_quant_conv = nn.Conv2d(2 * self.z_channels, 2 * self.z_channels, kernel_size=1) 60 | #################################################### 61 | 62 | ##################### Decoder ###################### 63 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) 64 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) 65 | 66 | # Midblock + Upblock 67 | self.decoder_mids = nn.ModuleList([]) 68 | for i in reversed(range(1, len(self.mid_channels))): 69 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], 70 | t_emb_dim=None, 71 | num_heads=self.num_heads, 72 | num_layers=self.num_mid_layers, 73 | norm_channels=self.norm_channels)) 74 | 75 | self.decoder_layers = nn.ModuleList([]) 76 | for i in reversed(range(1, len(self.down_channels))): 77 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], 78 | t_emb_dim=None, up_sample=self.down_sample[i - 1], 79 | num_heads=self.num_heads, 80 | num_layers=self.num_up_layers, 81 | attn=self.attns[i - 1], 82 | norm_channels=self.norm_channels)) 83 | 84 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) 85 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) 86 | 87 | def encode(self, x): 88 | out = self.encoder_conv_in(x) 89 | for idx, down in enumerate(self.encoder_layers): 90 | out = down(out) 91 | for mid in self.encoder_mids: 92 | out = mid(out) 93 | out = self.encoder_norm_out(out) 94 | out = nn.SiLU()(out) 95 | out = self.encoder_conv_out(out) 96 | out = self.pre_quant_conv(out) 97 | mean, logvar = torch.chunk(out, 2, dim=1) 98 | std = torch.exp(0.5 * logvar) 99 | sample = mean + std * torch.randn(mean.shape).to(device=x.device) 100 | return sample, out 101 | 102 | def decode(self, z): 103 | out = z 104 | out = self.post_quant_conv(out) 105 | out = self.decoder_conv_in(out) 106 | for mid in self.decoder_mids: 107 | out = mid(out) 108 | for idx, up in enumerate(self.decoder_layers): 109 | out = up(out) 110 | 111 | out = 
self.decoder_norm_out(out) 112 | out = nn.SiLU()(out) 113 | out = self.decoder_conv_out(out) 114 | return out 115 | 116 | def forward(self, x): 117 | z, encoder_output = self.encode(x) 118 | out = self.decode(z) 119 | return out, encoder_output 120 | 121 | -------------------------------------------------------------------------------- /model/weights/v0.1/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/model/weights/v0.1/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | numpy==2.0.1 3 | opencv_python==4.10.0.84 4 | Pillow==10.4.0 5 | PyYAML==6.0.1 6 | torch==2.3.1 7 | torchvision==0.18.1 8 | tqdm==4.66.4 9 | -------------------------------------------------------------------------------- /scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/scheduler/__init__.py -------------------------------------------------------------------------------- /scheduler/linear_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LinearNoiseScheduler: 5 | r""" 6 | Class for the linear noise scheduler that is used in DDPM. 7 | """ 8 | 9 | def __init__(self, num_timesteps, beta_start, beta_end): 10 | self.num_timesteps = num_timesteps 11 | self.beta_start = beta_start 12 | self.beta_end = beta_end 13 | 14 | self.betas = torch.linspace(beta_start, beta_end, num_timesteps) 15 | self.alphas = 1. 
- self.betas 16 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) 17 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) 18 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) 19 | 20 | def add_noise(self, original, noise, t): 21 | r""" 22 | Forward method for diffusion 23 | :param original: Image on which noise is to be applied 24 | :param noise: Random Noise Tensor (from normal dist) 25 | :param t: timestep of the forward process of shape -> (B,) 26 | :return: 27 | """ 28 | original_shape = original.shape 29 | batch_size = original_shape[0] 30 | 31 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 32 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 33 | 34 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W) 35 | for _ in range(len(original_shape) - 1): 36 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) 37 | for _ in range(len(original_shape) - 1): 38 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) 39 | 40 | # Apply and Return Forward process equation 41 | return (sqrt_alpha_cum_prod.to(original.device) * original 42 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise) 43 | 44 | def sample_prev_timestep(self, xt, pred, t): 45 | r""" 46 | Use the noise prediction by model to get 47 | xt-1 using xt and the noise predicted 48 | :param xt: current timestep sample 49 | :param pred: model noise prediction 50 | :param t: current timestep we are at 51 | :return: 52 | """ 53 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * pred)) / 54 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) 55 | x0 = torch.clamp(x0, -1., 1.) 56 | 57 | mean = xt - ((self.betas.to(xt.device)[t]) * pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) 58 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) 59 | 60 | if t == 0: 61 | return mean, x0 62 | else: 63 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) 64 | variance = variance * self.betas.to(xt.device)[t] 65 | sigma = variance ** 0.5 66 | z = torch.randn(xt.shape).to(xt.device) 67 | return mean + sigma * z, x0 68 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/tools/__init__.py -------------------------------------------------------------------------------- /tools/infer_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pickle 5 | 6 | import torch 7 | import torchvision 8 | import yaml 9 | from torch.utils.data.dataloader import DataLoader 10 | from torchvision.utils import make_grid 11 | from tqdm import tqdm 12 | 13 | from dataset.celeb_dataset import CelebDataset 14 | from model.vae import VAE 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | if torch.backends.mps.is_available(): 18 | device = torch.device('mps') 19 | print('Using mps') 20 | 21 | 22 | def infer(args): 23 | ######## Read the config file ####### 24 | with open(args.config_path, 'r') as file: 25 | try: 26 | config = yaml.safe_load(file) 27 | except yaml.YAMLError as exc: 28 | print(exc) 29 | print(config) 30 | 31 | dataset_config = 
config['dataset_params'] 32 | autoencoder_config = config['autoencoder_params'] 33 | train_config = config['train_params'] 34 | 35 | im_dataset = CelebDataset(split='train', 36 | im_path=dataset_config['im_path'], 37 | im_size=dataset_config['im_size'], 38 | im_channels=dataset_config['im_channels']) 39 | 40 | # This is only used for saving latents. Which as of now 41 | # is not done in batches hence batch size 1 42 | data_loader = DataLoader(im_dataset, 43 | batch_size=1, 44 | shuffle=False) 45 | 46 | num_images = train_config['num_samples'] 47 | ngrid = train_config['num_grid_rows'] 48 | 49 | idxs = torch.randint(0, len(im_dataset) - 1, (num_images,)) 50 | ims = torch.cat([im_dataset[idx][None, :] for idx in idxs]).float() 51 | ims = ims.to(device) 52 | 53 | model = VAE(im_channels=dataset_config['im_channels'], 54 | model_config=autoencoder_config).to(device) 55 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 56 | train_config['vae_autoencoder_ckpt_name']), 57 | map_location=device)) 58 | model.eval() 59 | 60 | with torch.no_grad(): 61 | 62 | encoded_output, _ = model.encode(ims) 63 | decoded_output = model.decode(encoded_output) 64 | encoded_output = torch.clamp(encoded_output, -1., 1.) 65 | encoded_output = (encoded_output + 1) / 2 66 | decoded_output = torch.clamp(decoded_output, -1., 1.) 67 | decoded_output = (decoded_output + 1) / 2 68 | ims = (ims + 1) / 2 69 | 70 | encoder_grid = make_grid(encoded_output.cpu(), nrow=ngrid) 71 | decoder_grid = make_grid(decoded_output.cpu(), nrow=ngrid) 72 | input_grid = make_grid(ims.cpu(), nrow=ngrid) 73 | encoder_grid = torchvision.transforms.ToPILImage()(encoder_grid) 74 | decoder_grid = torchvision.transforms.ToPILImage()(decoder_grid) 75 | input_grid = torchvision.transforms.ToPILImage()(input_grid) 76 | 77 | input_grid.save(os.path.join(train_config['task_name'], 'input_samples.png')) 78 | encoder_grid.save(os.path.join(train_config['task_name'], 'encoded_samples.png')) 79 | decoder_grid.save(os.path.join(train_config['task_name'], 'reconstructed_samples.png')) 80 | 81 | if train_config['save_latents']: 82 | # save Latents (but in a very unoptimized way) 83 | latent_path = os.path.join(train_config['task_name'], train_config['vae_latent_dir_name']) 84 | latent_fnames = glob.glob(os.path.join(train_config['task_name'], train_config['vae_latent_dir_name'], 85 | '*.pkl')) 86 | assert len(latent_fnames) == 0, 'Latents already present. 
Delete all latent files and re-run' 87 | if not os.path.exists(latent_path): 88 | os.mkdir(latent_path) 89 | print('Saving Latents for {}'.format(dataset_config['name'])) 90 | 91 | fname_latent_map = {} 92 | part_count = 0 93 | count = 0 94 | for idx, im in enumerate(tqdm(data_loader)): 95 | _, encoded_output = model.encode(im.float().to(device)) 96 | fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu() 97 | # Save latents every 1000 images 98 | if (count + 1) % 1000 == 0: 99 | pickle.dump(fname_latent_map, open(os.path.join(latent_path, 100 | '{}.pkl'.format(part_count)), 'wb')) 101 | part_count += 1 102 | fname_latent_map = {} 103 | count += 1 104 | if len(fname_latent_map) > 0: 105 | pickle.dump(fname_latent_map, open(os.path.join(latent_path, 106 | '{}.pkl'.format(part_count)), 'wb')) 107 | print('Done saving latents') 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser(description='Arguments for vae inference') 112 | parser.add_argument('--config', dest='config_path', 113 | default='config/celebhq.yaml', type=str) 114 | args = parser.parse_args() 115 | infer(args) 116 | -------------------------------------------------------------------------------- /tools/sample_vae_dit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | from torchvision.utils import make_grid 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from model.vae import VAE 10 | from model.transformer import DIT 11 | from scheduler.linear_scheduler import LinearNoiseScheduler 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 | 18 | 19 | def sample(model, scheduler, train_config, dit_model_config, 20 | autoencoder_model_config, diffusion_config, dataset_config, vae): 21 | r""" 22 | Sample stepwise by going backward one timestep at a time. 
23 | We save the x0 predictions 24 | """ 25 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 26 | xt = torch.randn((train_config['num_samples'], 27 | autoencoder_model_config['z_channels'], 28 | im_size, 29 | im_size)).to(device) 30 | 31 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 32 | # Get prediction of noise 33 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device)) 34 | 35 | # Use scheduler to get x0 and xt-1 36 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 37 | 38 | # Save x0 39 | # ims = torch.clamp(xt, -1., 1.).detach().cpu() 40 | if i == 0: 41 | # Decode ONLY the final image to save time 42 | ims = vae.to(device).decode(xt) 43 | else: 44 | ims = xt 45 | ims = xt[:, :-1, :, :] 46 | 47 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 48 | ims = (ims + 1) / 2 49 | 50 | grid = make_grid(ims, nrow=train_config['num_grid_rows']) 51 | img = torchvision.transforms.ToPILImage()(grid) 52 | 53 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')): 54 | os.mkdir(os.path.join(train_config['task_name'], 'samples')) 55 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i))) 56 | img.close() 57 | 58 | 59 | def infer(args): 60 | # Read the config file # 61 | with open(args.config_path, 'r') as file: 62 | try: 63 | config = yaml.safe_load(file) 64 | except yaml.YAMLError as exc: 65 | print(exc) 66 | print(config) 67 | ######################## 68 | 69 | diffusion_config = config['diffusion_params'] 70 | dataset_config = config['dataset_params'] 71 | dit_model_config = config['dit_params'] 72 | autoencoder_model_config = config['autoencoder_params'] 73 | train_config = config['train_params'] 74 | 75 | # Create the noise scheduler 76 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 77 | beta_start=diffusion_config['beta_start'], 78 | beta_end=diffusion_config['beta_end']) 79 | 80 | # Get latent image size 81 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 82 | model = DIT(im_size=im_size, 83 | im_channels=autoencoder_model_config['z_channels'], 84 | config=dit_model_config).to(device) 85 | 86 | model.eval() 87 | 88 | assert os.path.exists(os.path.join(train_config['task_name'], 89 | train_config['dit_ckpt_name'])), "Train DiT first" 90 | 91 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 92 | train_config['dit_ckpt_name']), 93 | map_location=device)) 94 | print('Loaded dit checkpoint') 95 | 96 | # Create output directories 97 | if not os.path.exists(train_config['task_name']): 98 | os.mkdir(train_config['task_name']) 99 | 100 | vae = VAE(im_channels=dataset_config['im_channels'], 101 | model_config=autoencoder_model_config) 102 | vae.eval() 103 | 104 | # Load vae if found 105 | assert os.path.exists(os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])), \ 106 | "VAE checkpoint not present. Train VAE first." 
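# NOTE: the VAE checkpoint below is read from the same task_name directory
# that tools/train_vae.py writes to; strict=True requires the saved weights
# to exactly match the VAE built from autoencoder_params in this config.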
107 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 108 | train_config['vae_autoencoder_ckpt_name']), 109 | map_location=device), strict=True) 110 | print('Loaded vae checkpoint') 111 | 112 | with torch.no_grad(): 113 | sample(model, scheduler, train_config, dit_model_config, 114 | autoencoder_model_config, diffusion_config, dataset_config, vae) 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser(description='Arguments for dit image generation') 119 | parser.add_argument('--config', dest='config_path', 120 | default='config/celebhq.yaml', type=str) 121 | args = parser.parse_args() 122 | infer(args) 123 | -------------------------------------------------------------------------------- /tools/train_vae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import random 5 | import torchvision 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from model.vae import VAE 10 | from model.lpips import LPIPS 11 | from model.discriminator import Discriminator 12 | from torch.utils.data.dataloader import DataLoader 13 | from dataset.celeb_dataset import CelebDataset 14 | from torch.optim import Adam 15 | from torchvision.utils import make_grid 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | if torch.backends.mps.is_available(): 19 | device = torch.device('mps') 20 | print('Using mps') 21 | 22 | 23 | def train(args): 24 | # Read the config file # 25 | with open(args.config_path, 'r') as file: 26 | try: 27 | config = yaml.safe_load(file) 28 | except yaml.YAMLError as exc: 29 | print(exc) 30 | print(config) 31 | 32 | dataset_config = config['dataset_params'] 33 | autoencoder_config = config['autoencoder_params'] 34 | train_config = config['train_params'] 35 | 36 | # Set the desired seed value # 37 | seed = train_config['seed'] 38 | torch.manual_seed(seed) 39 | np.random.seed(seed) 40 | random.seed(seed) 41 | if device == 'cuda': 42 | torch.cuda.manual_seed_all(seed) 43 | ############################# 44 | 45 | # Create the model and dataset # 46 | model = VAE(im_channels=dataset_config['im_channels'], 47 | model_config=autoencoder_config).to(device) 48 | # Create the dataset 49 | im_dataset = CelebDataset(split='train', 50 | im_path=dataset_config['im_path'], 51 | im_size=dataset_config['im_size'], 52 | im_channels=dataset_config['im_channels']) 53 | 54 | data_loader = DataLoader(im_dataset, 55 | batch_size=train_config['autoencoder_batch_size'], 56 | shuffle=True) 57 | 58 | # Create output directories 59 | if not os.path.exists(train_config['task_name']): 60 | os.mkdir(train_config['task_name']) 61 | 62 | num_epochs = train_config['autoencoder_epochs'] 63 | 64 | # L1/L2 loss for Reconstruction 65 | recon_criterion = torch.nn.MSELoss() 66 | # Disc Loss can even be BCEWithLogits 67 | disc_criterion = torch.nn.MSELoss() 68 | 69 | # No need to freeze lpips as lpips.py takes care of that 70 | lpips_model = LPIPS().eval().to(device) 71 | discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device) 72 | 73 | if os.path.exists(os.path.join(train_config['task_name'], 74 | train_config['vae_autoencoder_ckpt_name'])): 75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 76 | train_config['vae_autoencoder_ckpt_name']), 77 | map_location=device)) 78 | print('Loaded autoencoder from checkpoint') 79 | 80 | if os.path.exists(os.path.join(train_config['task_name'], 81 | 
train_config['vae_discriminator_ckpt_name'])): 82 | discriminator.load_state_dict(torch.load(os.path.join(train_config['task_name'], 83 | train_config['vae_discriminator_ckpt_name']), 84 | map_location=device)) 85 | print('Loaded discriminator from checkpoint') 86 | 87 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 88 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 89 | 90 | disc_step_start = train_config['disc_start'] 91 | step_count = 0 92 | 93 | # This is for accumulating gradients incase the images are huge 94 | # And one cant afford higher batch sizes 95 | acc_steps = train_config['autoencoder_acc_steps'] 96 | image_save_steps = train_config['autoencoder_img_save_steps'] 97 | img_save_count = 0 98 | 99 | for epoch_idx in range(num_epochs): 100 | recon_losses = [] 101 | perceptual_losses = [] 102 | disc_losses = [] 103 | gen_losses = [] 104 | losses = [] 105 | 106 | optimizer_g.zero_grad() 107 | optimizer_d.zero_grad() 108 | 109 | for im in tqdm(data_loader): 110 | step_count += 1 111 | im = im.float().to(device) 112 | 113 | # Fetch autoencoders output(reconstructions) 114 | model_output = model(im) 115 | output, encoder_output = model_output 116 | 117 | # Image Saving Logic 118 | if step_count % image_save_steps == 0 or step_count == 1: 119 | sample_size = min(8, im.shape[0]) 120 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu() 121 | save_output = ((save_output + 1) / 2) 122 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu() 123 | 124 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) 125 | img = torchvision.transforms.ToPILImage()(grid) 126 | if not os.path.exists(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')): 127 | os.mkdir(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')) 128 | img.save(os.path.join(train_config['task_name'], 'vae_autoencoder_samples', 129 | 'current_autoencoder_sample_{}.png'.format(img_save_count))) 130 | img_save_count += 1 131 | img.close() 132 | 133 | ######### Optimize Generator ########## 134 | # L2 Loss 135 | recon_loss = recon_criterion(output, im) 136 | recon_losses.append(recon_loss.item()) 137 | recon_loss = recon_loss / acc_steps 138 | 139 | mean, logvar = torch.chunk(encoder_output, 2, dim=1) 140 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3])) 141 | 142 | g_loss = recon_loss + (train_config['kl_weight'] * kl_loss / acc_steps) 143 | 144 | # Adversarial loss only if disc_step_start steps passed 145 | if step_count > disc_step_start: 146 | disc_fake_pred = discriminator(model_output[0]) 147 | disc_fake_loss = disc_criterion(disc_fake_pred, 148 | torch.ones(disc_fake_pred.shape, 149 | device=disc_fake_pred.device)) 150 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item()) 151 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps 152 | lpips_loss = torch.mean(lpips_model(output, im)) 153 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item()) 154 | g_loss += train_config['perceptual_weight'] * lpips_loss / acc_steps 155 | losses.append(g_loss.item()) 156 | g_loss.backward() 157 | ##################################### 158 | 159 | ######### Optimize Discriminator ####### 160 | if step_count > disc_step_start: 161 | fake = output 162 | disc_fake_pred = discriminator(fake.detach()) 163 | disc_real_pred = discriminator(im) 164 | disc_fake_loss = 
disc_criterion(disc_fake_pred, 165 | torch.zeros(disc_fake_pred.shape, 166 | device=disc_fake_pred.device)) 167 | disc_real_loss = disc_criterion(disc_real_pred, 168 | torch.ones(disc_real_pred.shape, 169 | device=disc_real_pred.device)) 170 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2 171 | disc_losses.append(disc_loss.item()) 172 | disc_loss = disc_loss / acc_steps 173 | disc_loss.backward() 174 | if step_count % acc_steps == 0: 175 | optimizer_d.step() 176 | optimizer_d.zero_grad() 177 | ##################################### 178 | 179 | if step_count % acc_steps == 0: 180 | optimizer_g.step() 181 | optimizer_g.zero_grad() 182 | optimizer_d.step() 183 | optimizer_d.zero_grad() 184 | optimizer_g.step() 185 | optimizer_g.zero_grad() 186 | if len(disc_losses) > 0: 187 | print( 188 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | ' 189 | 'G Loss : {:.4f} | D Loss {:.4f}'. 190 | format(epoch_idx + 1, 191 | np.mean(recon_losses), 192 | np.mean(perceptual_losses), 193 | np.mean(gen_losses), 194 | np.mean(disc_losses))) 195 | else: 196 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}'. 197 | format(epoch_idx + 1, 198 | np.mean(recon_losses), 199 | np.mean(perceptual_losses))) 200 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 201 | train_config['vae_autoencoder_ckpt_name'])) 202 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'], 203 | train_config['vae_discriminator_ckpt_name'])) 204 | print('Done Training...') 205 | 206 | 207 | if __name__ == '__main__': 208 | parser = argparse.ArgumentParser(description='Arguments for vae training') 209 | parser.add_argument('--config', dest='config_path', 210 | default='config/celebhq.yaml', type=str) 211 | args = parser.parse_args() 212 | train(args) 213 | -------------------------------------------------------------------------------- /tools/train_vae_dit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim import AdamW 8 | from dataset.celeb_dataset import CelebDataset 9 | from torch.utils.data import DataLoader 10 | from model.transformer import DIT 11 | from model.vae import VAE 12 | from scheduler.linear_scheduler import LinearNoiseScheduler 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | if torch.backends.mps.is_available(): 16 | device = torch.device('mps') 17 | print('Using mps') 18 | 19 | 20 | def train(args): 21 | # Read the config file # 22 | with open(args.config_path, 'r') as file: 23 | try: 24 | config = yaml.safe_load(file) 25 | except yaml.YAMLError as exc: 26 | print(exc) 27 | print(config) 28 | ######################## 29 | 30 | diffusion_config = config['diffusion_params'] 31 | dataset_config = config['dataset_params'] 32 | dit_model_config = config['dit_params'] 33 | autoencoder_model_config = config['autoencoder_params'] 34 | train_config = config['train_params'] 35 | 36 | # Create the noise scheduler 37 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 38 | beta_start=diffusion_config['beta_start'], 39 | beta_end=diffusion_config['beta_end']) 40 | 41 | im_dataset = CelebDataset(split='train', 42 | im_path=dataset_config['im_path'], 43 | im_size=dataset_config['im_size'], 44 | im_channels=dataset_config['im_channels'], 45 | use_latents=True, 46 | 
latent_path=os.path.join(train_config['task_name'], 47 | train_config['vae_latent_dir_name']) 48 | ) 49 | 50 | data_loader = DataLoader(im_dataset, 51 | batch_size=train_config['dit_batch_size'], 52 | shuffle=True) 53 | 54 | # Instantiate the model 55 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 56 | model = DIT(im_size=im_size, 57 | im_channels=autoencoder_model_config['z_channels'], 58 | config=dit_model_config).to(device) 59 | model.train() 60 | 61 | if os.path.exists(os.path.join(train_config['task_name'], 62 | train_config['dit_ckpt_name'])): 63 | print('Loaded DiT checkpoint') 64 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 65 | train_config['dit_ckpt_name']), 66 | map_location=device)) 67 | 68 | # Load VAE ONLY if latents are not to be used or are missing 69 | if not im_dataset.use_latents: 70 | print('Loading vae model as latents not present') 71 | vae = VAE(im_channels=dataset_config['im_channels'], 72 | model_config=autoencoder_model_config).to(device) 73 | vae.eval() 74 | # Load vae if found 75 | if os.path.exists(os.path.join(train_config['task_name'], 76 | train_config['vae_autoencoder_ckpt_name'])): 77 | print('Loaded vae checkpoint') 78 | vae.load_state_dict(torch.load(os.path.join( 79 | train_config['task_name'], 80 | train_config['vae_autoencoder_ckpt_name']), 81 | map_location=device)) 82 | # Specify training parameters 83 | num_epochs = train_config['dit_epochs'] 84 | optimizer = AdamW(model.parameters(), lr=1E-5, weight_decay=0) 85 | criterion = torch.nn.MSELoss() 86 | 87 | # Run training 88 | if not im_dataset.use_latents: 89 | for param in vae.parameters(): 90 | param.requires_grad = False 91 | 92 | acc_steps = train_config['dit_acc_steps'] 93 | for epoch_idx in range(num_epochs): 94 | losses = [] 95 | step_count = 0 96 | for im in tqdm(data_loader): 97 | step_count += 1 98 | im = im.float().to(device) 99 | if im_dataset.use_latents: 100 | mean, logvar = torch.chunk(im, 2, dim=1) 101 | std = torch.exp(0.5 * logvar) 102 | im = mean + std * torch.randn(mean.shape).to(device=im.device) 103 | else: 104 | with torch.no_grad(): 105 | im, _ = vae.encode(im) 106 | 107 | # Sample random noise 108 | noise = torch.randn_like(im).to(device) 109 | 110 | # Sample timestep 111 | t = torch.randint(0, diffusion_config['num_timesteps'], 112 | (im.shape[0],)).to(device) 113 | 114 | # Add noise to images according to timestep 115 | noisy_im = scheduler.add_noise(im, noise, t) 116 | pred = model(noisy_im, t) 117 | loss = criterion(pred, noise) 118 | losses.append(loss.item()) 119 | loss = loss / acc_steps 120 | loss.backward() 121 | if step_count % acc_steps == 0: 122 | optimizer.step() 123 | optimizer.zero_grad() 124 | optimizer.step() 125 | optimizer.zero_grad() 126 | print('Finished epoch:{} | Loss : {:.4f}'.format( 127 | epoch_idx + 1, 128 | np.mean(losses))) 129 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 130 | train_config['dit_ckpt_name'])) 131 | 132 | print('Done Training ...') 133 | 134 | 135 | if __name__ == '__main__': 136 | parser = argparse.ArgumentParser(description='Arguments for dit training') 137 | parser.add_argument('--config', dest='config_path', 138 | default='config/celebhq.yaml', type=str) 139 | args = parser.parse_args() 140 | train(args) 141 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/explainingai-code/DiT-PyTorch/ed9c0bd29f2c2b2a64fad8c5b759b834f8c1c4c5/utils/__init__.py -------------------------------------------------------------------------------- /utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import glob 3 | import os 4 | import torch 5 | 6 | 7 | def load_latents(latent_path): 8 | r""" 9 | Simple utility to load latents saved during VAE inference, to speed up DiT training 10 | :param latent_path: directory containing the saved latent pickle files 11 | :return: dict mapping image filename to its latent tensor 12 | """ 13 | latent_maps = {} 14 | for fname in glob.glob(os.path.join(latent_path, '*.pkl')): 15 | s = pickle.load(open(fname, 'rb')) 16 | for k, v in s.items(): 17 | latent_maps[k] = v[0] 18 | return latent_maps 19 | --------------------------------------------------------------------------------
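For reference, below is a minimal sketch of how the latents cached by `tools/infer_vae.py` can be consumed. The path `celebhq/vae_latents` is an assumption for illustration and should match the `task_name` and `vae_latent_dir_name` values in your config; `tools/train_vae_dit.py` performs the same mean/log-variance split and reparameterization internally on batched latents.

```
import torch

from utils.diffusion_utils import load_latents

# Hypothetical path for illustration; use the task_name and
# vae_latent_dir_name values from your own config file.
latent_maps = load_latents('celebhq/vae_latents')

# Each value is the raw encoder output for one image: 2 * z_channels
# channels, i.e. mean and log-variance stacked along the channel dim.
fname, latent = next(iter(latent_maps.items()))
mean, logvar = torch.chunk(latent, 2, dim=0)

# Reparameterization trick: draw a latent sample for DiT training.
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)
print(fname, z.shape)
```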