├── data └── .gitkeep ├── config ├── __init__.py ├── mnist.yaml ├── celebhq.yaml ├── mnist_class_cond.yaml ├── celebhq_text_cond.yaml ├── celebhq_text_cond_bert.yaml └── celebhq_text_image_cond.yaml ├── models ├── __init__.py ├── weights │ ├── .gitkeep │ └── v0.1 │ │ └── .gitkeep ├── discriminator.py ├── unet_base.py ├── vae.py ├── lpips.py ├── vqvae.py ├── unet_cond_base.py └── blocks.py ├── tools ├── __init__.py ├── sample_ddpm_vqvae.py ├── train_ddpm_vqvae.py ├── infer_vqvae.py ├── sample_ddpm_class_cond.py ├── sample_ddpm_text_cond.py ├── train_ddpm_cond.py ├── train_vqvae.py └── sample_ddpm_text_image_cond.py ├── utils ├── __init__.py ├── create_celeb_mask.py ├── text_utils.py ├── diffusion_utils.py └── config_utils.py ├── dataset ├── __init__.py ├── mnist_dataset.py └── celeb_dataset.py ├── scheduler ├── __init__.py └── linear_noise_scheduler.py ├── .github └── FUNDING.yml ├── .gitignore ├── requirements.txt ├── LICENSE └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/weights/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/weights/v0.1/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: explainingai-code 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all image files 2 | *.jpg 3 | *.png 4 | *.jpeg 5 | 6 | # Ignore pycharm and system files 7 | .DS_Store 8 | *.idea 9 | __pycache__ 10 | *.zip 11 | 12 | # Ignore dataset files 13 | *.csv 14 | *.json 15 | 16 | # Ignore checkpoints 17 | *.pth 18 | 19 | # Ignore pickle files 20 | *.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | boto3==1.34.36 2 | botocore==1.34.36 3 | certifi==2023.7.22 
4 | charset-normalizer==3.2.0 5 | click==8.1.7 6 | contourpy==1.1.0 7 | cycler==0.11.0 8 | einops==0.6.1 9 | filelock==3.13.1 10 | fonttools==4.42.0 11 | fsspec==2024.2.0 12 | ftfy==6.1.3 13 | huggingface-hub==0.20.3 14 | idna==3.4 15 | importlib-metadata==7.0.1 16 | importlib-resources==6.0.1 17 | jmespath==1.0.1 18 | joblib==1.3.2 19 | kiwisolver==1.4.4 20 | matplotlib==3.7.2 21 | numpy==1.23.5 22 | opencv-python==4.8.0.74 23 | packaging==23.1 24 | Pillow==10.0.0 25 | pyparsing==3.0.9 26 | python-dateutil==2.8.2 27 | PyYAML==6.0 28 | regex==2023.12.25 29 | requests==2.31.0 30 | s3transfer==0.10.0 31 | sacremoses==0.1.1 32 | safetensors==0.4.2 33 | scipy==1.10.1 34 | sentencepiece==0.1.99 35 | six==1.16.0 36 | tokenizers==0.15.1 37 | torch==1.11.0 38 | torchvision==0.12.0 39 | tqdm==4.65.0 40 | transformers==4.37.2 41 | typing_extensions==4.7.1 42 | urllib3==1.26.18 43 | wcwidth==0.2.13 44 | zipp==3.16.2 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ExplainingAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/create_celeb_mask.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is exact copy of https://github.com/switchablenorms/CelebAMask-HQ/blob/master/face_parsing/Data_preprocessing/g_mask.py 3 | Only the folder locations are modified. 
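The per-attribute binary masks under CelebAMask-HQ-mask-anno are collapsed into a single
512x512 label map per image: pixel value 0 is background and values 1-18 index label_list below.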
4 | """ 5 | 6 | import os 7 | import cv2 8 | import glob 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'] 13 | 14 | folder_base = 'data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' 15 | folder_save = 'data/CelebAMask-HQ/CelebAMask-HQ-mask' 16 | img_num = 30000 17 | 18 | if not os.path.exists(folder_save): 19 | os.mkdir(folder_save) 20 | 21 | for k in tqdm(range(img_num)): 22 | folder_num = k // 2000 23 | im_base = np.zeros((512, 512)) 24 | for idx, label in enumerate(label_list): 25 | filename = os.path.join(folder_base, str(folder_num), str(k).rjust(5, '0') + '_' + label + '.png') 26 | if os.path.exists(filename): 27 | im = cv2.imread(filename) 28 | im = im[:, :, 0] 29 | im_base[im != 0] = (idx + 1) 30 | 31 | filename_save = os.path.join(folder_save, str(k) + '.png') 32 | cv2.imwrite(filename_save, im_base) 33 | -------------------------------------------------------------------------------- /utils/text_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import DistilBertModel, DistilBertTokenizer, CLIPTokenizer, CLIPTextModel 3 | 4 | 5 | def get_tokenizer_and_model(model_type, device, eval_mode=True): 6 | assert model_type in ('bert', 'clip'), "Text model can only be one of clip or bert" 7 | if model_type == 'bert': 8 | text_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') 9 | text_model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device) 10 | else: 11 | text_tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16') 12 | text_model = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch16').to(device) 13 | if eval_mode: 14 | text_model.eval() 15 | return text_tokenizer, text_model 16 | 17 | 18 | def get_text_representation(text, text_tokenizer, text_model, device, 19 | truncation=True, 20 | padding='max_length', 21 | max_length=77): 22 | token_output = text_tokenizer(text, 23 | truncation=truncation, 24 | padding=padding, 25 | return_attention_mask=True, 26 | max_length=max_length) 27 | indexed_tokens = token_output['input_ids'] 28 | att_masks = token_output['attention_mask'] 29 | tokens_tensor = torch.tensor(indexed_tokens).to(device) 30 | mask_tensor = torch.tensor(att_masks).to(device) 31 | text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state 32 | return text_embed 33 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Discriminator(nn.Module): 6 | r""" 7 | PatchGAN Discriminator. 8 | Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to 9 | 1 scalar value , we instead predict grid of values. 
10 | Where each grid is prediction of how likely 11 | the discriminator thinks that the image patch corresponding 12 | to the grid cell is real 13 | """ 14 | 15 | def __init__(self, im_channels=3, 16 | conv_channels=[64, 128, 256], 17 | kernels=[4,4,4,4], 18 | strides=[2,2,2,1], 19 | paddings=[1,1,1,1]): 20 | super().__init__() 21 | self.im_channels = im_channels 22 | activation = nn.LeakyReLU(0.2) 23 | layers_dim = [self.im_channels] + conv_channels + [1] 24 | self.layers = nn.ModuleList([ 25 | nn.Sequential( 26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1], 27 | kernel_size=kernels[i], 28 | stride=strides[i], 29 | padding=paddings[i], 30 | bias=False if i !=0 else True), 31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(), 32 | activation if i != len(layers_dim) - 2 else nn.Identity() 33 | ) 34 | for i in range(len(layers_dim) - 1) 35 | ]) 36 | 37 | def forward(self, x): 38 | out = x 39 | for layer in self.layers: 40 | out = layer(out) 41 | return out 42 | 43 | 44 | if __name__ == '__main__': 45 | x = torch.randn((2,3, 256, 256)) 46 | prob = Discriminator(im_channels=3)(x) 47 | print(prob.shape) 48 | -------------------------------------------------------------------------------- /config/mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/mnist/train/images' 3 | im_channels : 1 4 | im_size : 28 5 | name: 'mnist' 6 | 7 | diffusion_params: 8 | num_timesteps : 1000 9 | beta_start : 0.0015 10 | beta_end : 0.0195 11 | 12 | ldm_params: 13 | down_channels: [ 128, 256, 256, 256] 14 | mid_channels: [ 256, 256] 15 | down_sample: [ False, False, False ] 16 | attn_down : [True, True, True] 17 | time_emb_dim: 256 18 | norm_channels : 32 19 | num_heads : 16 20 | conv_out_channels : 128 21 | num_down_layers: 2 22 | num_mid_layers: 2 23 | num_up_layers: 2 24 | 25 | autoencoder_params: 26 | z_channels: 3 27 | codebook_size : 20 28 | down_channels : [32, 64, 128] 29 | mid_channels : [128, 128] 30 | down_sample : [True, True] 31 | attn_down : [False, False] 32 | norm_channels: 32 33 | num_heads: 16 34 | num_down_layers : 1 35 | num_mid_layers : 1 36 | num_up_layers : 1 37 | 38 | train_params: 39 | seed : 1111 40 | task_name: 'mnist' 41 | ldm_batch_size: 64 42 | autoencoder_batch_size: 64 43 | disc_start: 1000 44 | disc_weight: 0.5 45 | codebook_weight: 1 46 | commitment_beta: 0.2 47 | perceptual_weight: 1 48 | kl_weight: 0.000005 49 | ldm_epochs : 100 50 | autoencoder_epochs : 10 51 | num_samples : 25 52 | num_grid_rows : 5 53 | ldm_lr: 0.00001 54 | autoencoder_lr: 0.0001 55 | autoencoder_acc_steps : 1 56 | autoencoder_img_save_steps : 8 57 | save_latents : False 58 | vae_latent_dir_name : 'vae_latents' 59 | vqvae_latent_dir_name : 'vqvae_latents' 60 | ldm_ckpt_name: 'ddpm_ckpt.pth' 61 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 62 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 63 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 64 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 65 | -------------------------------------------------------------------------------- /config/celebhq.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/CelebAMask-HQ' 3 | im_channels : 3 4 | im_size : 256 5 | name: 'celebhq' 6 | 7 | diffusion_params: 8 | num_timesteps : 1000 9 | beta_start : 0.0015 10 | beta_end : 0.0195 11 | 12 | ldm_params: 13 | down_channels: [ 256, 384, 512, 768 ] 
14 | mid_channels: [ 768, 512 ] 15 | down_sample: [ True, True, True ] 16 | attn_down : [True, True, True] 17 | time_emb_dim: 512 18 | norm_channels: 32 19 | num_heads: 16 20 | conv_out_channels : 128 21 | num_down_layers : 2 22 | num_mid_layers : 2 23 | num_up_layers : 2 24 | 25 | autoencoder_params: 26 | z_channels: 4 27 | codebook_size : 8192 28 | down_channels : [64, 128, 256, 256] 29 | mid_channels : [256, 256] 30 | down_sample : [True, True, True] 31 | attn_down : [False, False, False] 32 | norm_channels: 32 33 | num_heads: 4 34 | num_down_layers : 2 35 | num_mid_layers : 2 36 | num_up_layers : 2 37 | 38 | 39 | train_params: 40 | seed : 1111 41 | task_name: 'celebhq' 42 | ldm_batch_size: 16 43 | autoencoder_batch_size: 4 44 | disc_start: 15000 45 | disc_weight: 0.5 46 | codebook_weight: 1 47 | commitment_beta: 0.2 48 | perceptual_weight: 1 49 | kl_weight: 0.000005 50 | ldm_epochs: 100 51 | autoencoder_epochs: 20 52 | num_samples: 1 53 | num_grid_rows: 1 54 | ldm_lr: 0.000005 55 | autoencoder_lr: 0.00001 56 | autoencoder_acc_steps: 4 57 | autoencoder_img_save_steps: 64 58 | save_latents : False 59 | vae_latent_dir_name: 'vae_latents' 60 | vqvae_latent_dir_name: 'vqvae_latents' 61 | ldm_ckpt_name: 'ddpm_ckpt.pth' 62 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 63 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 64 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 65 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 66 | -------------------------------------------------------------------------------- /utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import glob 3 | import os 4 | import torch 5 | 6 | 7 | def load_latents(latent_path): 8 | r""" 9 | Simple utility to save latents to speed up ldm training 10 | :param latent_path: 11 | :return: 12 | """ 13 | latent_maps = {} 14 | for fname in glob.glob(os.path.join(latent_path, '*.pkl')): 15 | s = pickle.load(open(fname, 'rb')) 16 | for k, v in s.items(): 17 | latent_maps[k] = v[0] 18 | return latent_maps 19 | 20 | 21 | def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob): 22 | if text_drop_prob > 0: 23 | text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0, 24 | 1) < text_drop_prob 25 | assert empty_text_embed is not None, ("Text Conditioning required as well as" 26 | " text dropping but empty text representation not created") 27 | text_embed[text_drop_mask, :, :] = empty_text_embed[0] 28 | return text_embed 29 | 30 | 31 | def drop_image_condition(image_condition, im, im_drop_prob): 32 | if im_drop_prob > 0: 33 | im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0, 34 | 1) > im_drop_prob 35 | return image_condition * im_drop_mask 36 | else: 37 | return image_condition 38 | 39 | 40 | def drop_class_condition(class_condition, class_drop_prob, im): 41 | if class_drop_prob > 0: 42 | class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0, 43 | 1) > class_drop_prob 44 | return class_condition * class_drop_mask 45 | else: 46 | return class_condition -------------------------------------------------------------------------------- /config/mnist_class_cond.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/mnist/train/images' 3 | im_channels : 1 4 | im_size : 28 5 | name: 'mnist' 6 | 7 | diffusion_params: 8 | num_timesteps : 1000 9 | 
beta_start : 0.0015 10 | beta_end : 0.0195 11 | 12 | ldm_params: 13 | down_channels: [ 128, 256, 256, 256] 14 | mid_channels: [ 256, 256] 15 | down_sample: [ False, False, False ] 16 | attn_down : [True, True, True] 17 | time_emb_dim: 256 18 | norm_channels : 32 19 | num_heads : 16 20 | conv_out_channels : 128 21 | num_down_layers: 2 22 | num_mid_layers: 2 23 | num_up_layers: 2 24 | condition_config: 25 | condition_types: ['class'] 26 | class_condition_config : 27 | num_classes : 10 28 | cond_drop_prob : 0.1 29 | 30 | autoencoder_params: 31 | z_channels: 3 32 | codebook_size : 20 33 | down_channels : [32, 64, 128] 34 | mid_channels : [128, 128] 35 | down_sample : [True, True] 36 | attn_down : [False, False] 37 | norm_channels: 32 38 | num_heads: 16 39 | num_down_layers : 1 40 | num_mid_layers : 1 41 | num_up_layers : 1 42 | 43 | train_params: 44 | seed : 1111 45 | task_name: 'mnist' 46 | ldm_batch_size: 64 47 | autoencoder_batch_size: 64 48 | disc_start: 1000 49 | disc_weight: 0.5 50 | codebook_weight: 1 51 | commitment_beta: 0.2 52 | perceptual_weight: 1 53 | kl_weight: 0.000005 54 | ldm_epochs : 100 55 | autoencoder_epochs : 10 56 | num_samples : 25 57 | num_grid_rows : 5 58 | ldm_lr: 0.00001 59 | autoencoder_lr: 0.0001 60 | autoencoder_acc_steps : 1 61 | autoencoder_img_save_steps : 8 62 | save_latents : False 63 | cf_guidance_scale : 1.0 64 | vae_latent_dir_name : 'vae_latents' 65 | vqvae_latent_dir_name : 'vqvae_latents' 66 | ldm_ckpt_name: 'ddpm_ckpt_class_cond.pth' 67 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 68 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 69 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 70 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 71 | -------------------------------------------------------------------------------- /config/celebhq_text_cond.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/CelebAMask-HQ' 3 | im_channels : 3 4 | im_size : 256 5 | name: 'celebhq' 6 | 7 | diffusion_params: 8 | num_timesteps : 1000 9 | beta_start : 0.00085 10 | beta_end : 0.012 11 | 12 | ldm_params: 13 | down_channels: [ 256, 384, 512, 768 ] 14 | mid_channels: [ 768, 512 ] 15 | down_sample: [ True, True, True ] 16 | attn_down : [True, True, True] 17 | time_emb_dim: 512 18 | norm_channels: 32 19 | num_heads: 16 20 | conv_out_channels : 128 21 | num_down_layers : 2 22 | num_mid_layers : 2 23 | num_up_layers : 2 24 | condition_config: 25 | condition_types: [ 'text' ] 26 | text_condition_config: 27 | text_embed_model: 'clip' 28 | train_text_embed_model: False 29 | text_embed_dim: 512 30 | cond_drop_prob: 0.1 31 | 32 | autoencoder_params: 33 | z_channels: 4 34 | codebook_size : 8192 35 | down_channels : [64, 128, 256, 256] 36 | mid_channels : [256, 256] 37 | down_sample : [True, True, True] 38 | attn_down : [False, False, False] 39 | norm_channels: 32 40 | num_heads: 4 41 | num_down_layers : 2 42 | num_mid_layers : 2 43 | num_up_layers : 2 44 | 45 | 46 | train_params: 47 | seed : 1111 48 | task_name: 'celebhq' 49 | ldm_batch_size: 16 50 | autoencoder_batch_size: 4 51 | disc_start: 15000 52 | disc_weight: 0.5 53 | codebook_weight: 1 54 | commitment_beta: 0.2 55 | perceptual_weight: 1 56 | kl_weight: 0.000005 57 | ldm_epochs: 100 58 | autoencoder_epochs: 20 59 | num_samples: 1 60 | num_grid_rows: 1 61 | ldm_lr: 0.000005 62 | autoencoder_lr: 0.00001 63 | autoencoder_acc_steps: 4 64 | autoencoder_img_save_steps: 64 65 | save_latents : False 66 | 
cf_guidance_scale : 1.0 67 | vae_latent_dir_name: 'vae_latents' 68 | vqvae_latent_dir_name: 'vqvae_latents' 69 | ldm_ckpt_name: 'ddpm_ckpt_text_cond_clip.pth' 70 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 71 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 72 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 73 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 74 | -------------------------------------------------------------------------------- /config/celebhq_text_cond_bert.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/CelebAMask-HQ' 3 | im_channels : 3 4 | im_size : 256 5 | name: 'celebhq' 6 | 7 | diffusion_params: 8 | num_timesteps : 1000 9 | beta_start : 0.00085 10 | beta_end : 0.012 11 | 12 | ldm_params: 13 | down_channels: [ 256, 384, 512, 768 ] 14 | mid_channels: [ 768, 512 ] 15 | down_sample: [ True, True, True ] 16 | attn_down : [True, True, True] 17 | time_emb_dim: 512 18 | norm_channels: 32 19 | num_heads: 16 20 | conv_out_channels : 128 21 | num_down_layers : 2 22 | num_mid_layers : 2 23 | num_up_layers : 2 24 | condition_config: 25 | condition_types: [ 'text' ] 26 | text_condition_config: 27 | text_embed_model: 'bert' 28 | train_text_embed_model: False 29 | text_embed_dim: 768 30 | cond_drop_prob: 0.1 31 | 32 | autoencoder_params: 33 | z_channels: 4 34 | codebook_size : 8192 35 | down_channels : [64, 128, 256, 256] 36 | mid_channels : [256, 256] 37 | down_sample : [True, True, True] 38 | attn_down : [False, False, False] 39 | norm_channels: 32 40 | num_heads: 4 41 | num_down_layers : 2 42 | num_mid_layers : 2 43 | num_up_layers : 2 44 | 45 | 46 | train_params: 47 | seed : 1111 48 | task_name: 'celebhq' 49 | ldm_batch_size: 16 50 | autoencoder_batch_size: 4 51 | disc_start: 15000 52 | disc_weight: 0.5 53 | codebook_weight: 1 54 | commitment_beta: 0.2 55 | perceptual_weight: 1 56 | kl_weight: 0.000005 57 | ldm_epochs: 100 58 | autoencoder_epochs: 20 59 | num_samples: 1 60 | num_grid_rows: 1 61 | ldm_lr: 0.000005 62 | autoencoder_lr: 0.00001 63 | autoencoder_acc_steps: 4 64 | autoencoder_img_save_steps: 64 65 | save_latents : False 66 | cf_guidance_scale : 1.0 67 | vae_latent_dir_name: 'vae_latents' 68 | vqvae_latent_dir_name: 'vqvae_latents' 69 | ldm_ckpt_name: 'ddpm_ckpt_text_cond_bert.pth' 70 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 71 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 72 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 73 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 74 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def validate_class_config(condition_config): 3 | assert 'class_condition_config' in condition_config, \ 4 | "Class conditioning desired but class condition config missing" 5 | assert 'num_classes' in condition_config['class_condition_config'], \ 6 | "num_class missing in class condition config" 7 | 8 | 9 | def validate_text_config(condition_config): 10 | assert 'text_condition_config' in condition_config, \ 11 | "Text conditioning desired but text condition config missing" 12 | assert 'text_embed_dim' in condition_config['text_condition_config'], \ 13 | "text_embed_dim missing in text condition config" 14 | 15 | 16 | def validate_image_config(condition_config): 17 | assert 'image_condition_config' in condition_config, \ 18 
| "Image conditioning desired but image condition config missing" 19 | assert 'image_condition_input_channels' in condition_config['image_condition_config'], \ 20 | "image_condition_input_channels missing in image condition config" 21 | assert 'image_condition_output_channels' in condition_config['image_condition_config'], \ 22 | "image_condition_output_channels missing in image condition config" 23 | 24 | 25 | def validate_image_conditional_input(cond_input, x): 26 | assert 'image' in cond_input, \ 27 | "Model initialized with image conditioning but cond_input has no image information" 28 | assert cond_input['image'].shape[0] == x.shape[0], \ 29 | "Batch size mismatch of image condition and input" 30 | assert cond_input['image'].shape[2] % x.shape[2] == 0, \ 31 | "Height/Width of image condition must be divisible by latent input" 32 | 33 | 34 | def validate_class_conditional_input(cond_input, x, num_classes): 35 | assert 'class' in cond_input, \ 36 | "Model initialized with class conditioning but cond_input has no class information" 37 | assert cond_input['class'].shape == (x.shape[0], num_classes), \ 38 | "Shape of class condition input must match (Batch Size, )" 39 | def get_config_value(config, key, default_value): 40 | return config[key] if key in config else default_value -------------------------------------------------------------------------------- /config/celebhq_text_image_cond.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/CelebAMask-HQ' 3 | im_channels : 3 4 | im_size : 256 5 | name: 'celebhq' 6 | 7 | diffusion_params: 8 | num_timesteps : 1000 9 | beta_start: 0.00085 10 | beta_end: 0.012 11 | 12 | ldm_params: 13 | down_channels: [ 256, 384, 512, 768 ] 14 | mid_channels: [ 768, 512 ] 15 | down_sample: [ True, True, True ] 16 | attn_down : [True, True, True] 17 | time_emb_dim: 512 18 | norm_channels: 32 19 | num_heads: 16 20 | conv_out_channels : 128 21 | num_down_layers : 2 22 | num_mid_layers : 2 23 | num_up_layers : 2 24 | condition_config: 25 | condition_types: [ 'text', 'image' ] 26 | text_condition_config: 27 | text_embed_model: 'clip' 28 | train_text_embed_model: False 29 | text_embed_dim: 512 30 | cond_drop_prob: 0.1 31 | image_condition_config: 32 | image_condition_input_channels: 18 33 | image_condition_output_channels: 3 34 | image_condition_h : 512 35 | image_condition_w : 512 36 | cond_drop_prob: 0.1 37 | 38 | 39 | autoencoder_params: 40 | z_channels: 4 41 | codebook_size : 8192 42 | down_channels : [64, 128, 256, 256] 43 | mid_channels : [256, 256] 44 | down_sample : [True, True, True] 45 | attn_down : [False, False, False] 46 | norm_channels: 32 47 | num_heads: 4 48 | num_down_layers : 2 49 | num_mid_layers : 2 50 | num_up_layers : 2 51 | 52 | 53 | train_params: 54 | seed : 1111 55 | task_name: 'celebhq' 56 | ldm_batch_size: 16 57 | autoencoder_batch_size: 4 58 | disc_start: 15000 59 | disc_weight: 0.5 60 | codebook_weight: 1 61 | commitment_beta: 0.2 62 | perceptual_weight: 1 63 | kl_weight: 0.000005 64 | ldm_epochs: 100 65 | autoencoder_epochs: 20 66 | num_samples: 1 67 | num_grid_rows: 1 68 | ldm_lr: 0.000005 69 | autoencoder_lr: 0.00001 70 | autoencoder_acc_steps: 4 71 | autoencoder_img_save_steps: 64 72 | save_latents : False 73 | cf_guidance_scale : 1.0 74 | vae_latent_dir_name: 'vae_latents' 75 | vqvae_latent_dir_name: 'vqvae_latents' 76 | ldm_ckpt_name: 'ddpm_ckpt_text_image_cond_clip.pth' 77 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 78 | 
vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 79 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 80 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 81 | -------------------------------------------------------------------------------- /scheduler/linear_noise_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class LinearNoiseScheduler: 6 | r""" 7 | Class for the linear noise scheduler that is used in DDPM. 8 | """ 9 | 10 | def __init__(self, num_timesteps, beta_start, beta_end): 11 | self.num_timesteps = num_timesteps 12 | self.beta_start = beta_start 13 | self.beta_end = beta_end 14 | # Mimicking how compvis repo creates schedule 15 | self.betas = ( 16 | torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2 17 | ) 18 | self.alphas = 1. - self.betas 19 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) 20 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) 21 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) 22 | 23 | def add_noise(self, original, noise, t): 24 | r""" 25 | Forward method for diffusion 26 | :param original: Image on which noise is to be applied 27 | :param noise: Random Noise Tensor (from normal dist) 28 | :param t: timestep of the forward process of shape -> (B,) 29 | :return: 30 | """ 31 | original_shape = original.shape 32 | batch_size = original_shape[0] 33 | 34 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 35 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 36 | 37 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W) 38 | for _ in range(len(original_shape) - 1): 39 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) 40 | for _ in range(len(original_shape) - 1): 41 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) 42 | 43 | # Apply and Return Forward process equation 44 | return (sqrt_alpha_cum_prod.to(original.device) * original 45 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise) 46 | 47 | def sample_prev_timestep(self, xt, noise_pred, t): 48 | r""" 49 | Use the noise prediction by model to get 50 | xt-1 using xt and the nosie predicted 51 | :param xt: current timestep sample 52 | :param noise_pred: model noise prediction 53 | :param t: current timestep we are at 54 | :return: 55 | """ 56 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) / 57 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) 58 | x0 = torch.clamp(x0, -1., 1.) 
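        # The computation below follows the standard DDPM reverse-step equations:
        #   mean_t    = (x_t - beta_t * eps_theta / sqrt(1 - alpha_bar_t)) / sqrt(alpha_t)
        #   sigma_t^2 = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)   (used only when t > 0)
        # At t == 0 no noise is added and the mean is returned directly.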
59 | 60 | mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) 61 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) 62 | 63 | if t == 0: 64 | return mean, x0 65 | else: 66 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) 67 | variance = variance * self.betas.to(xt.device)[t] 68 | sigma = variance ** 0.5 69 | z = torch.randn(xt.shape).to(xt.device) 70 | 71 | # OR 72 | # variance = self.betas[t] 73 | # sigma = variance ** 0.5 74 | # z = torch.randn(xt.shape).to(xt.device) 75 | return mean + sigma * z, x0 76 | -------------------------------------------------------------------------------- /dataset/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import torchvision 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from utils.diffusion_utils import load_latents 7 | from torch.utils.data.dataset import Dataset 8 | 9 | 10 | class MnistDataset(Dataset): 11 | r""" 12 | Nothing special here. Just a simple dataset class for mnist images. 13 | Created a dataset class rather using torchvision to allow 14 | replacement with any other image dataset 15 | """ 16 | 17 | def __init__(self, split, im_path, im_size, im_channels, 18 | use_latents=False, latent_path=None, condition_config=None): 19 | r""" 20 | Init method for initializing the dataset properties 21 | :param split: train/test to locate the image files 22 | :param im_path: root folder of images 23 | :param im_ext: image extension. assumes all 24 | images would be this type. 25 | """ 26 | self.split = split 27 | self.im_size = im_size 28 | self.im_channels = im_channels 29 | 30 | # Should we use latents or not 31 | self.latent_maps = None 32 | self.use_latents = False 33 | 34 | # Conditioning for the dataset 35 | self.condition_types = [] if condition_config is None else condition_config['condition_types'] 36 | 37 | self.images, self.labels = self.load_images(im_path) 38 | 39 | # Whether to load images and call vae or to load latents 40 | if use_latents and latent_path is not None: 41 | latent_maps = load_latents(latent_path) 42 | if len(latent_maps) == len(self.images): 43 | self.use_latents = True 44 | self.latent_maps = latent_maps 45 | print('Found {} latents'.format(len(self.latent_maps))) 46 | else: 47 | print('Latents not found') 48 | 49 | def load_images(self, im_path): 50 | r""" 51 | Gets all images from the path specified 52 | and stacks them all up 53 | :param im_path: 54 | :return: 55 | """ 56 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 57 | ims = [] 58 | labels = [] 59 | for d_name in tqdm(os.listdir(im_path)): 60 | fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png'))) 61 | fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg'))) 62 | fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg'))) 63 | for fname in fnames: 64 | ims.append(fname) 65 | if 'class' in self.condition_types: 66 | labels.append(int(d_name)) 67 | print('Found {} images for split {}'.format(len(ims), self.split)) 68 | return ims, labels 69 | 70 | def __len__(self): 71 | return len(self.images) 72 | 73 | def __getitem__(self, index): 74 | ######## Set Conditioning Info ######## 75 | cond_inputs = {} 76 | if 'class' in self.condition_types: 77 | cond_inputs['class'] = self.labels[index] 78 | ####################################### 79 | 80 | if self.use_latents: 81 | 
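            # latent_maps is keyed by image file path, matching how tools/infer_vqvae.py saves latents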
latent = self.latent_maps[self.images[index]] 82 | if len(self.condition_types) == 0: 83 | return latent 84 | else: 85 | return latent, cond_inputs 86 | else: 87 | im = Image.open(self.images[index]) 88 | im_tensor = torchvision.transforms.ToTensor()(im) 89 | 90 | # Convert input to -1 to 1 range. 91 | im_tensor = (2 * im_tensor) - 1 92 | if len(self.condition_types) == 0: 93 | return im_tensor 94 | else: 95 | return im_tensor, cond_inputs 96 | 97 | -------------------------------------------------------------------------------- /models/unet_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.blocks import get_time_embedding 4 | from models.blocks import DownBlock, MidBlock, UpBlockUnet 5 | 6 | 7 | class Unet(nn.Module): 8 | r""" 9 | Unet model comprising 10 | Down blocks, Midblocks and Uplocks 11 | """ 12 | 13 | def __init__(self, im_channels, model_config): 14 | super().__init__() 15 | self.down_channels = model_config['down_channels'] 16 | self.mid_channels = model_config['mid_channels'] 17 | self.t_emb_dim = model_config['time_emb_dim'] 18 | self.down_sample = model_config['down_sample'] 19 | self.num_down_layers = model_config['num_down_layers'] 20 | self.num_mid_layers = model_config['num_mid_layers'] 21 | self.num_up_layers = model_config['num_up_layers'] 22 | self.attns = model_config['attn_down'] 23 | self.norm_channels = model_config['norm_channels'] 24 | self.num_heads = model_config['num_heads'] 25 | self.conv_out_channels = model_config['conv_out_channels'] 26 | 27 | assert self.mid_channels[0] == self.down_channels[-1] 28 | assert self.mid_channels[-1] == self.down_channels[-2] 29 | assert len(self.down_sample) == len(self.down_channels) - 1 30 | assert len(self.attns) == len(self.down_channels) - 1 31 | 32 | # Initial projection from sinusoidal time embedding 33 | self.t_proj = nn.Sequential( 34 | nn.Linear(self.t_emb_dim, self.t_emb_dim), 35 | nn.SiLU(), 36 | nn.Linear(self.t_emb_dim, self.t_emb_dim) 37 | ) 38 | 39 | self.up_sample = list(reversed(self.down_sample)) 40 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1) 41 | 42 | self.downs = nn.ModuleList([]) 43 | for i in range(len(self.down_channels) - 1): 44 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim, 45 | down_sample=self.down_sample[i], 46 | num_heads=self.num_heads, 47 | num_layers=self.num_down_layers, 48 | attn=self.attns[i], norm_channels=self.norm_channels)) 49 | 50 | self.mids = nn.ModuleList([]) 51 | for i in range(len(self.mid_channels) - 1): 52 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim, 53 | num_heads=self.num_heads, 54 | num_layers=self.num_mid_layers, 55 | norm_channels=self.norm_channels)) 56 | 57 | self.ups = nn.ModuleList([]) 58 | for i in reversed(range(len(self.down_channels) - 1)): 59 | self.ups.append(UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels, 60 | self.t_emb_dim, up_sample=self.down_sample[i], 61 | num_heads=self.num_heads, 62 | num_layers=self.num_up_layers, 63 | norm_channels=self.norm_channels)) 64 | 65 | self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) 66 | self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1) 67 | 68 | def forward(self, x, t): 69 | # Shapes assuming downblocks are [C1, C2, C3, C4] 70 | # Shapes assuming midblocks are [C4, C4, C3] 
71 | # Shapes assuming downsamples are [True, True, False] 72 | # B x C x H x W 73 | out = self.conv_in(x) 74 | # B x C1 x H x W 75 | 76 | # t_emb -> B x t_emb_dim 77 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) 78 | t_emb = self.t_proj(t_emb) 79 | 80 | down_outs = [] 81 | 82 | for idx, down in enumerate(self.downs): 83 | down_outs.append(out) 84 | out = down(out, t_emb) 85 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4] 86 | # out B x C4 x H/4 x W/4 87 | 88 | for mid in self.mids: 89 | out = mid(out, t_emb) 90 | # out B x C3 x H/4 x W/4 91 | 92 | for up in self.ups: 93 | down_out = down_outs.pop() 94 | out = up(out, down_out, t_emb) 95 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] 96 | out = self.norm_out(out) 97 | out = nn.SiLU()(out) 98 | out = self.conv_out(out) 99 | # out B x C x H x W 100 | return out 101 | -------------------------------------------------------------------------------- /tools/sample_ddpm_vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | from torchvision.utils import make_grid 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from models.unet_base import Unet 10 | from models.vqvae import VQVAE 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | def sample(model, scheduler, train_config, diffusion_model_config, 17 | autoencoder_model_config, diffusion_config, dataset_config, vae): 18 | r""" 19 | Sample stepwise by going backward one timestep at a time. 20 | We save the x0 predictions 21 | """ 22 | im_size = dataset_config['im_size'] // 2**sum(autoencoder_model_config['down_sample']) 23 | xt = torch.randn((train_config['num_samples'], 24 | autoencoder_model_config['z_channels'], 25 | im_size, 26 | im_size)).to(device) 27 | 28 | save_count = 0 29 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 30 | # Get prediction of noise 31 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device)) 32 | 33 | # Use scheduler to get x0 and xt-1 34 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 35 | 36 | # Save x0 37 | #ims = torch.clamp(xt, -1., 1.).detach().cpu() 38 | if i == 0: 39 | # Decode ONLY the final iamge to save time 40 | ims = vae.decode(xt) 41 | else: 42 | ims = xt 43 | 44 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 45 | ims = (ims + 1) / 2 46 | grid = make_grid(ims, nrow=train_config['num_grid_rows']) 47 | img = torchvision.transforms.ToPILImage()(grid) 48 | 49 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')): 50 | os.mkdir(os.path.join(train_config['task_name'], 'samples')) 51 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i))) 52 | img.close() 53 | 54 | 55 | def infer(args): 56 | # Read the config file # 57 | with open(args.config_path, 'r') as file: 58 | try: 59 | config = yaml.safe_load(file) 60 | except yaml.YAMLError as exc: 61 | print(exc) 62 | print(config) 63 | ######################## 64 | 65 | diffusion_config = config['diffusion_params'] 66 | dataset_config = config['dataset_params'] 67 | diffusion_model_config = config['ldm_params'] 68 | autoencoder_model_config = config['autoencoder_params'] 69 | train_config = config['train_params'] 70 | 71 | # Create the noise scheduler 72 | scheduler = 
LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 73 | beta_start=diffusion_config['beta_start'], 74 | beta_end=diffusion_config['beta_end']) 75 | 76 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 77 | model_config=diffusion_model_config).to(device) 78 | model.eval() 79 | if os.path.exists(os.path.join(train_config['task_name'], 80 | train_config['ldm_ckpt_name'])): 81 | print('Loaded unet checkpoint') 82 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 83 | train_config['ldm_ckpt_name']), 84 | map_location=device)) 85 | # Create output directories 86 | if not os.path.exists(train_config['task_name']): 87 | os.mkdir(train_config['task_name']) 88 | 89 | vae = VQVAE(im_channels=dataset_config['im_channels'], 90 | model_config=autoencoder_model_config).to(device) 91 | vae.eval() 92 | 93 | # Load vae if found 94 | if os.path.exists(os.path.join(train_config['task_name'], 95 | train_config['vqvae_autoencoder_ckpt_name'])): 96 | print('Loaded vae checkpoint') 97 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 98 | train_config['vqvae_autoencoder_ckpt_name']), 99 | map_location=device), strict=True) 100 | with torch.no_grad(): 101 | sample(model, scheduler, train_config, diffusion_model_config, 102 | autoencoder_model_config, diffusion_config, dataset_config, vae) 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation') 107 | parser.add_argument('--config', dest='config_path', 108 | default='config/mnist.yaml', type=str) 109 | args = parser.parse_args() 110 | infer(args) 111 | -------------------------------------------------------------------------------- /tools/train_ddpm_vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim import Adam 8 | from dataset.mnist_dataset import MnistDataset 9 | from dataset.celeb_dataset import CelebDataset 10 | from torch.utils.data import DataLoader 11 | from models.unet_base import Unet 12 | from models.vqvae import VQVAE 13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def train(args): 19 | # Read the config file # 20 | with open(args.config_path, 'r') as file: 21 | try: 22 | config = yaml.safe_load(file) 23 | except yaml.YAMLError as exc: 24 | print(exc) 25 | print(config) 26 | ######################## 27 | 28 | diffusion_config = config['diffusion_params'] 29 | dataset_config = config['dataset_params'] 30 | diffusion_model_config = config['ldm_params'] 31 | autoencoder_model_config = config['autoencoder_params'] 32 | train_config = config['train_params'] 33 | 34 | # Create the noise scheduler 35 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 36 | beta_start=diffusion_config['beta_start'], 37 | beta_end=diffusion_config['beta_end']) 38 | 39 | im_dataset_cls = { 40 | 'mnist': MnistDataset, 41 | 'celebhq': CelebDataset, 42 | }.get(dataset_config['name']) 43 | 44 | im_dataset = im_dataset_cls(split='train', 45 | im_path=dataset_config['im_path'], 46 | im_size=dataset_config['im_size'], 47 | im_channels=dataset_config['im_channels'], 48 | use_latents=True, 49 | latent_path=os.path.join(train_config['task_name'], 50 | train_config['vqvae_latent_dir_name']) 51 | ) 52 | 53 | 
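    # use_latents=True only takes effect if latents previously saved by tools/infer_vqvae.py
    # (with save_latents: True) are found at the latent path; otherwise the dataset falls back
    # to raw images and the VQVAE below is used to encode them on the fly.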
data_loader = DataLoader(im_dataset, 54 | batch_size=train_config['ldm_batch_size'], 55 | shuffle=True) 56 | 57 | # Instantiate the model 58 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 59 | model_config=diffusion_model_config).to(device) 60 | model.train() 61 | 62 | # Load VAE ONLY if latents are not to be used or are missing 63 | if not im_dataset.use_latents: 64 | print('Loading vqvae model as latents not present') 65 | vae = VQVAE(im_channels=dataset_config['im_channels'], 66 | model_config=autoencoder_model_config).to(device) 67 | vae.eval() 68 | # Load vae if found 69 | if os.path.exists(os.path.join(train_config['task_name'], 70 | train_config['vqvae_autoencoder_ckpt_name'])): 71 | print('Loaded vae checkpoint') 72 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 73 | train_config['vqvae_autoencoder_ckpt_name']), 74 | map_location=device)) 75 | # Specify training parameters 76 | num_epochs = train_config['ldm_epochs'] 77 | optimizer = Adam(model.parameters(), lr=train_config['ldm_lr']) 78 | criterion = torch.nn.MSELoss() 79 | 80 | # Run training 81 | if not im_dataset.use_latents: 82 | for param in vae.parameters(): 83 | param.requires_grad = False 84 | 85 | for epoch_idx in range(num_epochs): 86 | losses = [] 87 | for im in tqdm(data_loader): 88 | optimizer.zero_grad() 89 | im = im.float().to(device) 90 | if not im_dataset.use_latents: 91 | with torch.no_grad(): 92 | im, _ = vae.encode(im) 93 | 94 | # Sample random noise 95 | noise = torch.randn_like(im).to(device) 96 | 97 | # Sample timestep 98 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device) 99 | 100 | # Add noise to images according to timestep 101 | noisy_im = scheduler.add_noise(im, noise, t) 102 | noise_pred = model(noisy_im, t) 103 | 104 | loss = criterion(noise_pred, noise) 105 | losses.append(loss.item()) 106 | loss.backward() 107 | optimizer.step() 108 | print('Finished epoch:{} | Loss : {:.4f}'.format( 109 | epoch_idx + 1, 110 | np.mean(losses))) 111 | 112 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 113 | train_config['ldm_ckpt_name'])) 114 | 115 | print('Done Training ...') 116 | 117 | 118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser(description='Arguments for ddpm training') 120 | parser.add_argument('--config', dest='config_path', 121 | default='config/mnist.yaml', type=str) 122 | args = parser.parse_args() 123 | train(args) 124 | -------------------------------------------------------------------------------- /tools/infer_vqvae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pickle 5 | 6 | import torch 7 | import torchvision 8 | import yaml 9 | from torch.utils.data.dataloader import DataLoader 10 | from torchvision.utils import make_grid 11 | from tqdm import tqdm 12 | 13 | from dataset.celeb_dataset import CelebDataset 14 | from dataset.mnist_dataset import MnistDataset 15 | from models.vqvae import VQVAE 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | def infer(args): 21 | ######## Read the config file ####### 22 | with open(args.config_path, 'r') as file: 23 | try: 24 | config = yaml.safe_load(file) 25 | except yaml.YAMLError as exc: 26 | print(exc) 27 | print(config) 28 | 29 | dataset_config = config['dataset_params'] 30 | autoencoder_config = config['autoencoder_params'] 31 | train_config = config['train_params'] 32 | 33 | # Create the 
dataset 34 | im_dataset_cls = { 35 | 'mnist': MnistDataset, 36 | 'celebhq': CelebDataset, 37 | }.get(dataset_config['name']) 38 | 39 | im_dataset = im_dataset_cls(split='train', 40 | im_path=dataset_config['im_path'], 41 | im_size=dataset_config['im_size'], 42 | im_channels=dataset_config['im_channels']) 43 | 44 | # This is only used for saving latents. Which as of now 45 | # is not done in batches hence batch size 1 46 | data_loader = DataLoader(im_dataset, 47 | batch_size=1, 48 | shuffle=False) 49 | 50 | num_images = train_config['num_samples'] 51 | ngrid = train_config['num_grid_rows'] 52 | 53 | idxs = torch.randint(0, len(im_dataset) - 1, (num_images,)) 54 | ims = torch.cat([im_dataset[idx][None, :] for idx in idxs]).float() 55 | ims = ims.to(device) 56 | 57 | model = VQVAE(im_channels=dataset_config['im_channels'], 58 | model_config=autoencoder_config).to(device) 59 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 60 | train_config['vqvae_autoencoder_ckpt_name']), 61 | map_location=device)) 62 | model.eval() 63 | 64 | with torch.no_grad(): 65 | 66 | encoded_output, _ = model.encode(ims) 67 | decoded_output = model.decode(encoded_output) 68 | encoded_output = torch.clamp(encoded_output, -1., 1.) 69 | encoded_output = (encoded_output + 1) / 2 70 | decoded_output = torch.clamp(decoded_output, -1., 1.) 71 | decoded_output = (decoded_output + 1) / 2 72 | ims = (ims + 1) / 2 73 | 74 | encoder_grid = make_grid(encoded_output.cpu(), nrow=ngrid) 75 | decoder_grid = make_grid(decoded_output.cpu(), nrow=ngrid) 76 | input_grid = make_grid(ims.cpu(), nrow=ngrid) 77 | encoder_grid = torchvision.transforms.ToPILImage()(encoder_grid) 78 | decoder_grid = torchvision.transforms.ToPILImage()(decoder_grid) 79 | input_grid = torchvision.transforms.ToPILImage()(input_grid) 80 | 81 | input_grid.save(os.path.join(train_config['task_name'], 'input_samples.png')) 82 | encoder_grid.save(os.path.join(train_config['task_name'], 'encoded_samples.png')) 83 | decoder_grid.save(os.path.join(train_config['task_name'], 'reconstructed_samples.png')) 84 | 85 | if train_config['save_latents']: 86 | # save Latents (but in a very unoptimized way) 87 | latent_path = os.path.join(train_config['task_name'], train_config['vqvae_latent_dir_name']) 88 | latent_fnames = glob.glob(os.path.join(train_config['task_name'], train_config['vqvae_latent_dir_name'], 89 | '*.pkl')) 90 | assert len(latent_fnames) == 0, 'Latents already present. 
Delete all latent files and re-run' 91 | if not os.path.exists(latent_path): 92 | os.mkdir(latent_path) 93 | print('Saving Latents for {}'.format(dataset_config['name'])) 94 | 95 | fname_latent_map = {} 96 | part_count = 0 97 | count = 0 98 | for idx, im in enumerate(tqdm(data_loader)): 99 | encoded_output, _ = model.encode(im.float().to(device)) 100 | fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu() 101 | # Save latents every 1000 images 102 | if (count+1) % 1000 == 0: 103 | pickle.dump(fname_latent_map, open(os.path.join(latent_path, 104 | '{}.pkl'.format(part_count)), 'wb')) 105 | part_count += 1 106 | fname_latent_map = {} 107 | count += 1 108 | if len(fname_latent_map) > 0: 109 | pickle.dump(fname_latent_map, open(os.path.join(latent_path, 110 | '{}.pkl'.format(part_count)), 'wb')) 111 | print('Done saving latents') 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser(description='Arguments for vq vae inference') 116 | parser.add_argument('--config', dest='config_path', 117 | default='config/mnist.yaml', type=str) 118 | args = parser.parse_args() 119 | infer(args) 120 | -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.blocks import DownBlock, MidBlock, UpBlock 4 | 5 | 6 | class VAE(nn.Module): 7 | def __init__(self, im_channels, model_config): 8 | super().__init__() 9 | self.down_channels = model_config['down_channels'] 10 | self.mid_channels = model_config['mid_channels'] 11 | self.down_sample = model_config['down_sample'] 12 | self.num_down_layers = model_config['num_down_layers'] 13 | self.num_mid_layers = model_config['num_mid_layers'] 14 | self.num_up_layers = model_config['num_up_layers'] 15 | 16 | # To disable attention in Downblock of Encoder and Upblock of Decoder 17 | self.attns = model_config['attn_down'] 18 | 19 | # Latent Dimension 20 | self.z_channels = model_config['z_channels'] 21 | self.norm_channels = model_config['norm_channels'] 22 | self.num_heads = model_config['num_heads'] 23 | 24 | # Assertion to validate the channel information 25 | assert self.mid_channels[0] == self.down_channels[-1] 26 | assert self.mid_channels[-1] == self.down_channels[-1] 27 | assert len(self.down_sample) == len(self.down_channels) - 1 28 | assert len(self.attns) == len(self.down_channels) - 1 29 | 30 | # Wherever we use downsampling in encoder correspondingly use 31 | # upsampling in decoder 32 | self.up_sample = list(reversed(self.down_sample)) 33 | 34 | ##################### Encoder ###################### 35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) 36 | 37 | # Downblock + Midblock 38 | self.encoder_layers = nn.ModuleList([]) 39 | for i in range(len(self.down_channels) - 1): 40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], 41 | t_emb_dim=None, down_sample=self.down_sample[i], 42 | num_heads=self.num_heads, 43 | num_layers=self.num_down_layers, 44 | attn=self.attns[i], 45 | norm_channels=self.norm_channels)) 46 | 47 | self.encoder_mids = nn.ModuleList([]) 48 | for i in range(len(self.mid_channels) - 1): 49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], 50 | t_emb_dim=None, 51 | num_heads=self.num_heads, 52 | num_layers=self.num_mid_layers, 53 | norm_channels=self.norm_channels)) 54 | 55 | self.encoder_norm_out = 
nn.GroupNorm(self.norm_channels, self.down_channels[-1]) 56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1) 57 | 58 | # Latent Dimension is 2*Latent because we are predicting mean & variance 59 | self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1) 60 | #################################################### 61 | 62 | 63 | ##################### Decoder ###################### 64 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) 65 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) 66 | 67 | # Midblock + Upblock 68 | self.decoder_mids = nn.ModuleList([]) 69 | for i in reversed(range(1, len(self.mid_channels))): 70 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], 71 | t_emb_dim=None, 72 | num_heads=self.num_heads, 73 | num_layers=self.num_mid_layers, 74 | norm_channels=self.norm_channels)) 75 | 76 | self.decoder_layers = nn.ModuleList([]) 77 | for i in reversed(range(1, len(self.down_channels))): 78 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], 79 | t_emb_dim=None, up_sample=self.down_sample[i - 1], 80 | num_heads=self.num_heads, 81 | num_layers=self.num_up_layers, 82 | attn=self.attns[i - 1], 83 | norm_channels=self.norm_channels)) 84 | 85 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) 86 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) 87 | 88 | def encode(self, x): 89 | out = self.encoder_conv_in(x) 90 | for idx, down in enumerate(self.encoder_layers): 91 | out = down(out) 92 | for mid in self.encoder_mids: 93 | out = mid(out) 94 | out = self.encoder_norm_out(out) 95 | out = nn.SiLU()(out) 96 | out = self.encoder_conv_out(out) 97 | out = self.pre_quant_conv(out) 98 | mean, logvar = torch.chunk(out, 2, dim=1) 99 | std = torch.exp(0.5 * logvar) 100 | sample = mean + std * torch.randn(mean.shape).to(device=x.device) 101 | return sample, out 102 | 103 | def decode(self, z): 104 | out = z 105 | out = self.post_quant_conv(out) 106 | out = self.decoder_conv_in(out) 107 | for mid in self.decoder_mids: 108 | out = mid(out) 109 | for idx, up in enumerate(self.decoder_layers): 110 | out = up(out) 111 | 112 | out = self.decoder_norm_out(out) 113 | out = nn.SiLU()(out) 114 | out = self.decoder_conv_out(out) 115 | return out 116 | 117 | def forward(self, x): 118 | z, encoder_output = self.encode(x) 119 | out = self.decode(z) 120 | return out, encoder_output 121 | 122 | -------------------------------------------------------------------------------- /dataset/celeb_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import torch 5 | import torchvision 6 | import numpy as np 7 | from PIL import Image 8 | from utils.diffusion_utils import load_latents 9 | from tqdm import tqdm 10 | from torch.utils.data.dataset import Dataset 11 | 12 | 13 | class CelebDataset(Dataset): 14 | r""" 15 | Celeb dataset will by default centre crop and resize the images. 16 | This can be replaced by any other dataset. As long as all the images 17 | are under one directory. 
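    Depending on condition_config, __getitem__ additionally returns a randomly sampled
    caption ('text' conditioning) and/or a per-class binary CelebAMask segmentation mask
    ('image' conditioning) along with the image or its latent.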
18 | """ 19 | 20 | def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg', 21 | use_latents=False, latent_path=None, condition_config=None): 22 | self.split = split 23 | self.im_size = im_size 24 | self.im_channels = im_channels 25 | self.im_ext = im_ext 26 | self.im_path = im_path 27 | self.latent_maps = None 28 | self.use_latents = False 29 | 30 | self.condition_types = [] if condition_config is None else condition_config['condition_types'] 31 | 32 | self.idx_to_cls_map = {} 33 | self.cls_to_idx_map ={} 34 | 35 | if 'image' in self.condition_types: 36 | self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels'] 37 | self.mask_h = condition_config['image_condition_config']['image_condition_h'] 38 | self.mask_w = condition_config['image_condition_config']['image_condition_w'] 39 | 40 | self.images, self.texts, self.masks = self.load_images(im_path) 41 | 42 | # Whether to load images or to load latents 43 | if use_latents and latent_path is not None: 44 | latent_maps = load_latents(latent_path) 45 | if len(latent_maps) == len(self.images): 46 | self.use_latents = True 47 | self.latent_maps = latent_maps 48 | print('Found {} latents'.format(len(self.latent_maps))) 49 | else: 50 | print('Latents not found') 51 | 52 | def load_images(self, im_path): 53 | r""" 54 | Gets all images from the path specified 55 | and stacks them all up 56 | """ 57 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 58 | ims = [] 59 | fnames = glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('png'))) 60 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpg'))) 61 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpeg'))) 62 | texts = [] 63 | masks = [] 64 | 65 | if 'image' in self.condition_types: 66 | label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 67 | 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'] 68 | self.idx_to_cls_map = {idx: label_list[idx] for idx in range(len(label_list))} 69 | self.cls_to_idx_map = {label_list[idx]: idx for idx in range(len(label_list))} 70 | 71 | for fname in tqdm(fnames): 72 | ims.append(fname) 73 | 74 | if 'text' in self.condition_types: 75 | im_name = os.path.split(fname)[1].split('.')[0] 76 | captions_im = [] 77 | with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f: 78 | for line in f.readlines(): 79 | captions_im.append(line.strip()) 80 | texts.append(captions_im) 81 | 82 | if 'image' in self.condition_types: 83 | im_name = int(os.path.split(fname)[1].split('.')[0]) 84 | masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name))) 85 | if 'text' in self.condition_types: 86 | assert len(texts) == len(ims), "Condition Type Text but could not find captions for all images" 87 | if 'image' in self.condition_types: 88 | assert len(masks) == len(ims), "Condition Type Image but could not find masks for all images" 89 | print('Found {} images'.format(len(ims))) 90 | print('Found {} masks'.format(len(masks))) 91 | print('Found {} captions'.format(len(texts))) 92 | return ims, texts, masks 93 | 94 | def get_mask(self, index): 95 | r""" 96 | Method to get the mask of WxH 97 | for given index and convert it into 98 | Classes x W x H mask image 99 | :param index: 100 | :return: 101 | """ 102 | mask_im = Image.open(self.masks[index]) 103 | mask_im = np.array(mask_im) 104 | im_base = np.zeros((self.mask_h, self.mask_w, 
self.mask_channels)) 105 | for orig_idx in range(len(self.idx_to_cls_map)): 106 | im_base[mask_im == (orig_idx+1), orig_idx] = 1 107 | mask = torch.from_numpy(im_base).permute(2, 0, 1).float() 108 | return mask 109 | 110 | def __len__(self): 111 | return len(self.images) 112 | 113 | def __getitem__(self, index): 114 | ######## Set Conditioning Info ######## 115 | cond_inputs = {} 116 | if 'text' in self.condition_types: 117 | cond_inputs['text'] = random.sample(self.texts[index], k=1)[0] 118 | if 'image' in self.condition_types: 119 | mask = self.get_mask(index) 120 | cond_inputs['image'] = mask 121 | ####################################### 122 | 123 | if self.use_latents: 124 | latent = self.latent_maps[self.images[index]] 125 | if len(self.condition_types) == 0: 126 | return latent 127 | else: 128 | return latent, cond_inputs 129 | else: 130 | im = Image.open(self.images[index]) 131 | im_tensor = torchvision.transforms.Compose([ 132 | torchvision.transforms.Resize(self.im_size), 133 | torchvision.transforms.CenterCrop(self.im_size), 134 | torchvision.transforms.ToTensor(), 135 | ])(im) 136 | im.close() 137 | 138 | # Convert input to -1 to 1 range. 139 | im_tensor = (2 * im_tensor) - 1 140 | if len(self.condition_types) == 0: 141 | return im_tensor 142 | else: 143 | return im_tensor, cond_inputs 144 | -------------------------------------------------------------------------------- /models/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn 9 | import torchvision 10 | 11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | def spatial_average(in_tens, keepdim=True): 17 | return in_tens.mean([2, 3], keepdim=keepdim) 18 | 19 | 20 | class vgg16(torch.nn.Module): 21 | def __init__(self, requires_grad=False, pretrained=True): 22 | super(vgg16, self).__init__() 23 | # Load pretrained vgg model from torchvision 24 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features 25 | self.slice1 = torch.nn.Sequential() 26 | self.slice2 = torch.nn.Sequential() 27 | self.slice3 = torch.nn.Sequential() 28 | self.slice4 = torch.nn.Sequential() 29 | self.slice5 = torch.nn.Sequential() 30 | self.N_slices = 5 31 | for x in range(4): 32 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 33 | for x in range(4, 9): 34 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 35 | for x in range(9, 16): 36 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 37 | for x in range(16, 23): 38 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 39 | for x in range(23, 30): 40 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 41 | 42 | # Freeze vgg model 43 | if not requires_grad: 44 | for param in self.parameters(): 45 | param.requires_grad = False 46 | 47 | def forward(self, X): 48 | # Return output of vgg features 49 | h = self.slice1(X) 50 | h_relu1_2 = h 51 | h = self.slice2(h) 52 | h_relu2_2 = h 53 | h = self.slice3(h) 54 | h_relu3_3 = h 55 | h = self.slice4(h) 56 | h_relu4_3 = h 57 | h = self.slice5(h) 58 | h_relu5_3 = h 59 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 60 | out = 
vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 61 | return out 62 | 63 | 64 | # Learned perceptual metric 65 | class LPIPS(nn.Module): 66 | def __init__(self, net='vgg', version='0.1', use_dropout=True): 67 | super(LPIPS, self).__init__() 68 | self.version = version 69 | # Imagenet normalization 70 | self.scaling_layer = ScalingLayer() 71 | ######################## 72 | 73 | # Instantiate vgg model 74 | self.chns = [64, 128, 256, 512, 512] 75 | self.L = len(self.chns) 76 | self.net = vgg16(pretrained=True, requires_grad=False) 77 | 78 | # Add 1x1 convolutional Layers 79 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 80 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 81 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 82 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 83 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 84 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 85 | self.lins = nn.ModuleList(self.lins) 86 | ######################## 87 | 88 | # Load the weights of trained LPIPS model 89 | import inspect 90 | import os 91 | model_path = os.path.abspath( 92 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))) 93 | print('Loading model from: %s' % model_path) 94 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False) 95 | ######################## 96 | 97 | # Freeze all parameters 98 | self.eval() 99 | for param in self.parameters(): 100 | param.requires_grad = False 101 | ######################## 102 | 103 | def forward(self, in0, in1, normalize=False): 104 | # Scale the inputs to -1 to +1 range if needed 105 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 106 | in0 = 2 * in0 - 1 107 | in1 = 2 * in1 - 1 108 | ######################## 109 | 110 | # Normalize the inputs according to imagenet normalization 111 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) 112 | ######################## 113 | 114 | # Get VGG outputs for image0 and image1 115 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 116 | feats0, feats1, diffs = {}, {}, {} 117 | ######################## 118 | 119 | # Compute Square of Difference for each layer output 120 | for kk in range(self.L): 121 | feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize( 122 | outs1[kk]) 123 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 124 | ######################## 125 | 126 | # 1x1 convolution followed by spatial average on the square differences 127 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 128 | val = 0 129 | 130 | # Aggregate the results of each layer 131 | for l in range(self.L): 132 | val += res[l] 133 | return val 134 | 135 | 136 | class ScalingLayer(nn.Module): 137 | def __init__(self): 138 | super(ScalingLayer, self).__init__() 139 | # Imagnet normalization for (0-1) 140 | # mean = [0.485, 0.456, 0.406] 141 | # std = [0.229, 0.224, 0.225] 142 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 143 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 144 | 145 | def forward(self, inp): 146 | return (inp - self.shift) / self.scale 147 | 148 | 149 | class NetLinLayer(nn.Module): 150 | ''' A single linear layer which does a 1x1 conv ''' 151 | 152 | def __init__(self, chn_in, chn_out=1, 
use_dropout=False): 153 | super(NetLinLayer, self).__init__() 154 | 155 | layers = [nn.Dropout(), ] if (use_dropout) else [] 156 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 157 | self.model = nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | out = self.model(x) 161 | return out 162 | -------------------------------------------------------------------------------- /tools/sample_ddpm_class_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | from torchvision.utils import make_grid 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from models.unet_cond_base import Unet 10 | from models.vqvae import VQVAE 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | from utils.config_utils import * 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | def sample(model, scheduler, train_config, diffusion_model_config, 18 | autoencoder_model_config, diffusion_config, dataset_config, vae): 19 | r""" 20 | Sample stepwise by going backward one timestep at a time. 21 | We save the x0 predictions 22 | """ 23 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 24 | 25 | ########### Sample random noise latent ########## 26 | xt = torch.randn((train_config['num_samples'], 27 | autoencoder_model_config['z_channels'], 28 | im_size, 29 | im_size)).to(device) 30 | ############################################### 31 | 32 | ############# Validate the config ################# 33 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None) 34 | assert condition_config is not None, ("This sampling script is for class conditional " 35 | "but no conditioning config found") 36 | condition_types = get_config_value(condition_config, 'condition_types', []) 37 | assert 'class' in condition_types, ("This sampling script is for class conditional " 38 | "but no class condition found in config") 39 | validate_class_config(condition_config) 40 | ############################################### 41 | 42 | ############ Create Conditional input ############### 43 | num_classes = condition_config['class_condition_config']['num_classes'] 44 | sample_classes = torch.randint(0, num_classes, (train_config['num_samples'], )) 45 | print('Generating images for {}'.format(list(sample_classes.numpy()))) 46 | cond_input = { 47 | 'class': torch.nn.functional.one_hot(sample_classes, num_classes).to(device) 48 | } 49 | # Unconditional input for classifier free guidance 50 | uncond_input = { 51 | 'class': cond_input['class'] * 0 52 | } 53 | ############################################### 54 | 55 | # By default classifier free guidance is disabled 56 | # Change value in config or change default value here to enable it 57 | cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0) 58 | 59 | ################# Sampling Loop ######################## 60 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 61 | # Get prediction of noise 62 | t = (torch.ones((xt.shape[0],))*i).long().to(device) 63 | noise_pred_cond = model(xt, t, cond_input) 64 | 65 | if cf_guidance_scale > 1: 66 | noise_pred_uncond = model(xt, t, uncond_input) 67 | noise_pred = noise_pred_uncond + cf_guidance_scale*(noise_pred_cond - noise_pred_uncond) 68 | else: 69 | noise_pred = noise_pred_cond 70 | 71 | # Use scheduler to get x0 
and xt-1 72 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 73 | 74 | if i == 0: 75 | # Decode ONLY the final image to save time 76 | ims = vae.decode(xt) 77 | else: 78 | ims = x0_pred 79 | 80 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 81 | ims = (ims + 1) / 2 82 | grid = make_grid(ims, nrow=1) 83 | img = torchvision.transforms.ToPILImage()(grid) 84 | 85 | if not os.path.exists(os.path.join(train_config['task_name'], 'cond_class_samples')): 86 | os.mkdir(os.path.join(train_config['task_name'], 'cond_class_samples')) 87 | img.save(os.path.join(train_config['task_name'], 'cond_class_samples', 'x0_{}.png'.format(i))) 88 | img.close() 89 | ############################################################## 90 | 91 | def infer(args): 92 | # Read the config file # 93 | with open(args.config_path, 'r') as file: 94 | try: 95 | config = yaml.safe_load(file) 96 | except yaml.YAMLError as exc: 97 | print(exc) 98 | print(config) 99 | ######################## 100 | 101 | diffusion_config = config['diffusion_params'] 102 | dataset_config = config['dataset_params'] 103 | diffusion_model_config = config['ldm_params'] 104 | autoencoder_model_config = config['autoencoder_params'] 105 | train_config = config['train_params'] 106 | 107 | ########## Create the noise scheduler ############# 108 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 109 | beta_start=diffusion_config['beta_start'], 110 | beta_end=diffusion_config['beta_end']) 111 | ############################################### 112 | 113 | ########## Load Unet ############# 114 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 115 | model_config=diffusion_model_config).to(device) 116 | model.eval() 117 | if os.path.exists(os.path.join(train_config['task_name'], 118 | train_config['ldm_ckpt_name'])): 119 | print('Loaded unet checkpoint') 120 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 121 | train_config['ldm_ckpt_name']), 122 | map_location=device)) 123 | else: 124 | raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'], 125 | train_config['ldm_ckpt_name']))) 126 | ##################################### 127 | 128 | # Create output directories 129 | if not os.path.exists(train_config['task_name']): 130 | os.mkdir(train_config['task_name']) 131 | 132 | ########## Load VQVAE ############# 133 | vae = VQVAE(im_channels=dataset_config['im_channels'], 134 | model_config=autoencoder_model_config).to(device) 135 | vae.eval() 136 | 137 | # Load vae if found 138 | if os.path.exists(os.path.join(train_config['task_name'], 139 | train_config['vqvae_autoencoder_ckpt_name'])): 140 | print('Loaded vae checkpoint') 141 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 142 | train_config['vqvae_autoencoder_ckpt_name']), 143 | map_location=device), strict=True) 144 | else: 145 | raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'], 146 | train_config['vqvae_autoencoder_ckpt_name']))) 147 | ##################################### 148 | 149 | with torch.no_grad(): 150 | sample(model, scheduler, train_config, diffusion_model_config, 151 | autoencoder_model_config, diffusion_config, dataset_config, vae) 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation for class conditional ' 156 | 'Mnist generation') 157 | parser.add_argument('--config', dest='config_path', 158 | 
default='config/mnist_class_cond.yaml', type=str) 159 | args = parser.parse_args() 160 | infer(args) 161 | -------------------------------------------------------------------------------- /models/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.blocks import DownBlock, MidBlock, UpBlock 4 | 5 | 6 | class VQVAE(nn.Module): 7 | def __init__(self, im_channels, model_config): 8 | super().__init__() 9 | self.down_channels = model_config['down_channels'] 10 | self.mid_channels = model_config['mid_channels'] 11 | self.down_sample = model_config['down_sample'] 12 | self.num_down_layers = model_config['num_down_layers'] 13 | self.num_mid_layers = model_config['num_mid_layers'] 14 | self.num_up_layers = model_config['num_up_layers'] 15 | 16 | # To disable attention in Downblock of Encoder and Upblock of Decoder 17 | self.attns = model_config['attn_down'] 18 | 19 | # Latent Dimension 20 | self.z_channels = model_config['z_channels'] 21 | self.codebook_size = model_config['codebook_size'] 22 | self.norm_channels = model_config['norm_channels'] 23 | self.num_heads = model_config['num_heads'] 24 | 25 | # Assertion to validate the channel information 26 | assert self.mid_channels[0] == self.down_channels[-1] 27 | assert self.mid_channels[-1] == self.down_channels[-1] 28 | assert len(self.down_sample) == len(self.down_channels) - 1 29 | assert len(self.attns) == len(self.down_channels) - 1 30 | 31 | # Wherever we use downsampling in encoder correspondingly use 32 | # upsampling in decoder 33 | self.up_sample = list(reversed(self.down_sample)) 34 | 35 | ##################### Encoder ###################### 36 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) 37 | 38 | # Downblock + Midblock 39 | self.encoder_layers = nn.ModuleList([]) 40 | for i in range(len(self.down_channels) - 1): 41 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], 42 | t_emb_dim=None, down_sample=self.down_sample[i], 43 | num_heads=self.num_heads, 44 | num_layers=self.num_down_layers, 45 | attn=self.attns[i], 46 | norm_channels=self.norm_channels)) 47 | 48 | self.encoder_mids = nn.ModuleList([]) 49 | for i in range(len(self.mid_channels) - 1): 50 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], 51 | t_emb_dim=None, 52 | num_heads=self.num_heads, 53 | num_layers=self.num_mid_layers, 54 | norm_channels=self.norm_channels)) 55 | 56 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) 57 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1) 58 | 59 | # Pre Quantization Convolution 60 | self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) 61 | 62 | # Codebook 63 | self.embedding = nn.Embedding(self.codebook_size, self.z_channels) 64 | #################################################### 65 | 66 | ##################### Decoder ###################### 67 | 68 | # Post Quantization Convolution 69 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) 70 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) 71 | 72 | # Midblock + Upblock 73 | self.decoder_mids = nn.ModuleList([]) 74 | for i in reversed(range(1, len(self.mid_channels))): 75 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], 76 | 
t_emb_dim=None, 77 | num_heads=self.num_heads, 78 | num_layers=self.num_mid_layers, 79 | norm_channels=self.norm_channels)) 80 | 81 | self.decoder_layers = nn.ModuleList([]) 82 | for i in reversed(range(1, len(self.down_channels))): 83 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], 84 | t_emb_dim=None, up_sample=self.down_sample[i - 1], 85 | num_heads=self.num_heads, 86 | num_layers=self.num_up_layers, 87 | attn=self.attns[i-1], 88 | norm_channels=self.norm_channels)) 89 | 90 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) 91 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) 92 | 93 | def quantize(self, x): 94 | B, C, H, W = x.shape 95 | 96 | # B, C, H, W -> B, H, W, C 97 | x = x.permute(0, 2, 3, 1) 98 | 99 | # B, H, W, C -> B, H*W, C 100 | x = x.reshape(x.size(0), -1, x.size(-1)) 101 | 102 | # Find nearest embedding/codebook vector 103 | # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K) 104 | dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1))) 105 | # (B, H*W) 106 | min_encoding_indices = torch.argmin(dist, dim=-1) 107 | 108 | # Replace encoder output with nearest codebook 109 | # quant_out -> B*H*W, C 110 | quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) 111 | 112 | # x -> B*H*W, C 113 | x = x.reshape((-1, x.size(-1))) 114 | commmitment_loss = torch.mean((quant_out.detach() - x) ** 2) 115 | codebook_loss = torch.mean((quant_out - x.detach()) ** 2) 116 | quantize_losses = { 117 | 'codebook_loss': codebook_loss, 118 | 'commitment_loss': commmitment_loss 119 | } 120 | # Straight through estimation 121 | quant_out = x + (quant_out - x).detach() 122 | 123 | # quant_out -> B, C, H, W 124 | quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) 125 | min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) 126 | return quant_out, quantize_losses, min_encoding_indices 127 | 128 | def encode(self, x): 129 | out = self.encoder_conv_in(x) 130 | for idx, down in enumerate(self.encoder_layers): 131 | out = down(out) 132 | for mid in self.encoder_mids: 133 | out = mid(out) 134 | out = self.encoder_norm_out(out) 135 | out = nn.SiLU()(out) 136 | out = self.encoder_conv_out(out) 137 | out = self.pre_quant_conv(out) 138 | out, quant_losses, _ = self.quantize(out) 139 | return out, quant_losses 140 | 141 | def decode(self, z): 142 | out = z 143 | out = self.post_quant_conv(out) 144 | out = self.decoder_conv_in(out) 145 | for mid in self.decoder_mids: 146 | out = mid(out) 147 | for idx, up in enumerate(self.decoder_layers): 148 | out = up(out) 149 | 150 | out = self.decoder_norm_out(out) 151 | out = nn.SiLU()(out) 152 | out = self.decoder_conv_out(out) 153 | return out 154 | 155 | def forward(self, x): 156 | z, quant_losses = self.encode(x) 157 | out = self.decode(z) 158 | return out, z, quant_losses 159 | 160 | -------------------------------------------------------------------------------- /tools/sample_ddpm_text_cond.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import argparse 3 | import yaml 4 | import os 5 | from torchvision.utils import make_grid 6 | from tqdm import tqdm 7 | from models.unet_cond_base import Unet 8 | from models.vqvae import VQVAE 9 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 10 | from utils.config_utils import * 11 | from utils.text_utils import * 12 | 
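# Note: this file has no direct 'import torch'; torch is presumably brought into scope
# by the wildcard imports above (utils.config_utils / utils.text_utils).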
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | def sample(model, scheduler, train_config, diffusion_model_config, 17 | autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model): 18 | r""" 19 | Sample stepwise by going backward one timestep at a time. 20 | We save the x0 predictions 21 | """ 22 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 23 | 24 | ########### Sample random noise latent ########## 25 | # For not fixing generation with one sample 26 | xt = torch.randn((1, 27 | autoencoder_model_config['z_channels'], 28 | im_size, 29 | im_size)).to(device) 30 | ############################################### 31 | 32 | ############ Create Conditional input ############### 33 | text_prompt = ['She is a woman with blond hair. She is wearing lipstick.'] 34 | neg_prompt = ['He is a man.'] 35 | empty_prompt = [''] 36 | text_prompt_embed = get_text_representation(text_prompt, 37 | text_tokenizer, 38 | text_model, 39 | device) 40 | # Can replace empty prompt with negative prompt 41 | empty_text_embed = get_text_representation(empty_prompt, text_tokenizer, text_model, device) 42 | assert empty_text_embed.shape == text_prompt_embed.shape 43 | 44 | uncond_input = { 45 | 'text': empty_text_embed 46 | } 47 | cond_input = { 48 | 'text': text_prompt_embed 49 | } 50 | ############################################### 51 | 52 | # By default classifier free guidance is disabled 53 | # Change value in config or change default value here to enable it 54 | cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0) 55 | 56 | ################# Sampling Loop ######################## 57 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 58 | # Get prediction of noise 59 | t = (torch.ones((xt.shape[0],)) * i).long().to(device) 60 | noise_pred_cond = model(xt, t, cond_input) 61 | 62 | if cf_guidance_scale > 1: 63 | noise_pred_uncond = model(xt, t, uncond_input) 64 | noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond) 65 | else: 66 | noise_pred = noise_pred_cond 67 | 68 | # Use scheduler to get x0 and xt-1 69 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 70 | 71 | # Save x0 72 | # ims = torch.clamp(xt, -1., 1.).detach().cpu() 73 | if i == 0: 74 | # Decode ONLY the final iamge to save time 75 | ims = vae.decode(xt) 76 | else: 77 | ims = x0_pred 78 | 79 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 80 | ims = (ims + 1) / 2 81 | grid = make_grid(ims, nrow=1) 82 | img = torchvision.transforms.ToPILImage()(grid) 83 | 84 | if not os.path.exists(os.path.join(train_config['task_name'], 'cond_text_samples')): 85 | os.mkdir(os.path.join(train_config['task_name'], 'cond_text_samples')) 86 | img.save(os.path.join(train_config['task_name'], 'cond_text_samples', 'x0_{}.png'.format(i))) 87 | img.close() 88 | ############################################################## 89 | 90 | 91 | def infer(args): 92 | # Read the config file # 93 | with open(args.config_path, 'r') as file: 94 | try: 95 | config = yaml.safe_load(file) 96 | except yaml.YAMLError as exc: 97 | print(exc) 98 | print(config) 99 | ######################## 100 | 101 | diffusion_config = config['diffusion_params'] 102 | dataset_config = config['dataset_params'] 103 | diffusion_model_config = config['ldm_params'] 104 | autoencoder_model_config = config['autoencoder_params'] 105 | train_config = config['train_params'] 106 | 107 | 
########## Create the noise scheduler ############# 108 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 109 | beta_start=diffusion_config['beta_start'], 110 | beta_end=diffusion_config['beta_end']) 111 | ############################################### 112 | 113 | text_tokenizer = None 114 | text_model = None 115 | 116 | ############# Validate the config ################# 117 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None) 118 | assert condition_config is not None, ("This sampling script is for text conditional " 119 | "but no conditioning config found") 120 | condition_types = get_config_value(condition_config, 'condition_types', []) 121 | assert 'text' in condition_types, ("This sampling script is for text conditional " 122 | "but no text condition found in config") 123 | validate_text_config(condition_config) 124 | ############################################### 125 | 126 | ############# Load tokenizer and text model ################# 127 | with torch.no_grad(): 128 | # Load tokenizer and text model based on config 129 | # Also get empty text representation 130 | text_tokenizer, text_model = get_tokenizer_and_model(condition_config['text_condition_config'] 131 | ['text_embed_model'], device=device) 132 | ############################################### 133 | 134 | ########## Load Unet ############# 135 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 136 | model_config=diffusion_model_config).to(device) 137 | model.eval() 138 | if os.path.exists(os.path.join(train_config['task_name'], 139 | train_config['ldm_ckpt_name'])): 140 | print('Loaded unet checkpoint') 141 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 142 | train_config['ldm_ckpt_name']), 143 | map_location=device)) 144 | else: 145 | raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'], 146 | train_config['ldm_ckpt_name']))) 147 | ##################################### 148 | 149 | # Create output directories 150 | if not os.path.exists(train_config['task_name']): 151 | os.mkdir(train_config['task_name']) 152 | 153 | ########## Load VQVAE ############# 154 | vae = VQVAE(im_channels=dataset_config['im_channels'], 155 | model_config=autoencoder_model_config).to(device) 156 | vae.eval() 157 | 158 | # Load vae if found 159 | if os.path.exists(os.path.join(train_config['task_name'], 160 | train_config['vqvae_autoencoder_ckpt_name'])): 161 | print('Loaded vae checkpoint') 162 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 163 | train_config['vqvae_autoencoder_ckpt_name']), 164 | map_location=device), strict=True) 165 | else: 166 | raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'], 167 | train_config['vqvae_autoencoder_ckpt_name']))) 168 | ##################################### 169 | 170 | with torch.no_grad(): 171 | sample(model, scheduler, train_config, diffusion_model_config, 172 | autoencoder_model_config, diffusion_config, dataset_config, vae,text_tokenizer, text_model) 173 | 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation with only ' 177 | 'text conditioning') 178 | parser.add_argument('--config', dest='config_path', 179 | default='config/celebhq_text_cond.yaml', type=str) 180 | args = parser.parse_args() 181 | infer(args) 182 | -------------------------------------------------------------------------------- 
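For reference, below is a minimal sketch of how a prompt could be turned into the (batch, seq_len, embed_dim) representation that `get_text_representation` is expected to return when a CLIP text encoder is configured. The checkpoint name and the helper are illustrative assumptions, not the repo's `utils/text_utils.py` implementation:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

def encode_text_with_clip(prompts, device):
    # Assumed checkpoint; the config decides which text encoder is actually used
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')
    text_model = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch16').eval().to(device)
    tokens = tokenizer(prompts, padding='max_length', truncation=True,
                       max_length=tokenizer.model_max_length, return_tensors='pt').to(device)
    with torch.no_grad():
        # (batch, 77, 512) hidden states, used as cross-attention context by the Unet
        return text_model(**tokens).last_hidden_state
```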
/tools/train_ddpm_cond.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | from torch.optim import Adam 6 | from dataset.mnist_dataset import MnistDataset 7 | from dataset.celeb_dataset import CelebDataset 8 | from torch.utils.data import DataLoader 9 | from models.unet_cond_base import Unet 10 | from models.vqvae import VQVAE 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | from utils.text_utils import * 13 | from utils.config_utils import * 14 | from utils.diffusion_utils import * 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def train(args): 20 | # Read the config file # 21 | with open(args.config_path, 'r') as file: 22 | try: 23 | config = yaml.safe_load(file) 24 | except yaml.YAMLError as exc: 25 | print(exc) 26 | print(config) 27 | ######################## 28 | 29 | diffusion_config = config['diffusion_params'] 30 | dataset_config = config['dataset_params'] 31 | diffusion_model_config = config['ldm_params'] 32 | autoencoder_model_config = config['autoencoder_params'] 33 | train_config = config['train_params'] 34 | 35 | ########## Create the noise scheduler ############# 36 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 37 | beta_start=diffusion_config['beta_start'], 38 | beta_end=diffusion_config['beta_end']) 39 | ############################################### 40 | 41 | # Instantiate Condition related components 42 | text_tokenizer = None 43 | text_model = None 44 | empty_text_embed = None 45 | condition_types = [] 46 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None) 47 | if condition_config is not None: 48 | assert 'condition_types' in condition_config, \ 49 | "condition type missing in conditioning config" 50 | condition_types = condition_config['condition_types'] 51 | if 'text' in condition_types: 52 | validate_text_config(condition_config) 53 | with torch.no_grad(): 54 | # Load tokenizer and text model based on config 55 | # Also get empty text representation 56 | text_tokenizer, text_model = get_tokenizer_and_model(condition_config['text_condition_config'] 57 | ['text_embed_model'], device=device) 58 | empty_text_embed = get_text_representation([''], text_tokenizer, text_model, device) 59 | 60 | im_dataset_cls = { 61 | 'mnist': MnistDataset, 62 | 'celebhq': CelebDataset, 63 | }.get(dataset_config['name']) 64 | 65 | im_dataset = im_dataset_cls(split='train', 66 | im_path=dataset_config['im_path'], 67 | im_size=dataset_config['im_size'], 68 | im_channels=dataset_config['im_channels'], 69 | use_latents=True, 70 | latent_path=os.path.join(train_config['task_name'], 71 | train_config['vqvae_latent_dir_name']), 72 | condition_config=condition_config) 73 | 74 | data_loader = DataLoader(im_dataset, 75 | batch_size=train_config['ldm_batch_size'], 76 | shuffle=True) 77 | 78 | # Instantiate the unet model 79 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 80 | model_config=diffusion_model_config).to(device) 81 | model.train() 82 | 83 | vae = None 84 | # Load VAE ONLY if latents are not to be saved or some are missing 85 | if not im_dataset.use_latents: 86 | print('Loading vqvae model as latents not present') 87 | vae = VQVAE(im_channels=dataset_config['im_channels'], 88 | model_config=autoencoder_model_config).to(device) 89 | vae.eval() 90 | # Load vae if found 91 | if 
os.path.exists(os.path.join(train_config['task_name'], 92 | train_config['vqvae_autoencoder_ckpt_name'])): 93 | print('Loaded vae checkpoint') 94 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 95 | train_config['vqvae_autoencoder_ckpt_name']), 96 | map_location=device)) 97 | else: 98 | raise Exception('VAE checkpoint not found and use_latents was disabled') 99 | 100 | # Specify training parameters 101 | num_epochs = train_config['ldm_epochs'] 102 | optimizer = Adam(model.parameters(), lr=train_config['ldm_lr']) 103 | criterion = torch.nn.MSELoss() 104 | 105 | # Load vae and freeze parameters ONLY if latents already not saved 106 | if not im_dataset.use_latents: 107 | assert vae is not None 108 | for param in vae.parameters(): 109 | param.requires_grad = False 110 | 111 | # Run training 112 | for epoch_idx in range(num_epochs): 113 | losses = [] 114 | for data in tqdm(data_loader): 115 | cond_input = None 116 | if condition_config is not None: 117 | im, cond_input = data 118 | else: 119 | im = data 120 | optimizer.zero_grad() 121 | im = im.float().to(device) 122 | if not im_dataset.use_latents: 123 | with torch.no_grad(): 124 | im, _ = vae.encode(im) 125 | 126 | ########### Handling Conditional Input ########### 127 | if 'text' in condition_types: 128 | with torch.no_grad(): 129 | assert 'text' in cond_input, 'Conditioning Type Text but no text conditioning input present' 130 | validate_text_config(condition_config) 131 | text_condition = get_text_representation(cond_input['text'], 132 | text_tokenizer, 133 | text_model, 134 | device) 135 | text_drop_prob = get_config_value(condition_config['text_condition_config'], 136 | 'cond_drop_prob', 0.) 137 | text_condition = drop_text_condition(text_condition, im, empty_text_embed, text_drop_prob) 138 | cond_input['text'] = text_condition 139 | if 'image' in condition_types: 140 | assert 'image' in cond_input, 'Conditioning Type Image but no image conditioning input present' 141 | validate_image_config(condition_config) 142 | cond_input_image = cond_input['image'].to(device) 143 | # Drop condition 144 | im_drop_prob = get_config_value(condition_config['image_condition_config'], 145 | 'cond_drop_prob', 0.) 146 | cond_input['image'] = drop_image_condition(cond_input_image, im, im_drop_prob) 147 | if 'class' in condition_types: 148 | assert 'class' in cond_input, 'Conditioning Type Class but no class conditioning input present' 149 | validate_class_config(condition_config) 150 | class_condition = torch.nn.functional.one_hot( 151 | cond_input['class'], 152 | condition_config['class_condition_config']['num_classes']).to(device) 153 | class_drop_prob = get_config_value(condition_config['class_condition_config'], 154 | 'cond_drop_prob', 0.) 
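# Dropping the conditioning signal with probability cond_drop_prob (null class,
# empty text embedding or zero mask) also teaches the network the unconditional
# prediction that classifier-free guidance relies on at sampling time.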
155 | # Drop condition 156 | cond_input['class'] = drop_class_condition(class_condition, class_drop_prob, im) 157 | ################################################ 158 | 159 | # Sample random noise 160 | noise = torch.randn_like(im).to(device) 161 | 162 | # Sample timestep 163 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device) 164 | 165 | # Add noise to images according to timestep 166 | noisy_im = scheduler.add_noise(im, noise, t) 167 | noise_pred = model(noisy_im, t, cond_input=cond_input) 168 | loss = criterion(noise_pred, noise) 169 | losses.append(loss.item()) 170 | loss.backward() 171 | optimizer.step() 172 | print('Finished epoch:{} | Loss : {:.4f}'.format( 173 | epoch_idx + 1, 174 | np.mean(losses))) 175 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 176 | train_config['ldm_ckpt_name'])) 177 | 178 | print('Done Training ...') 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser(description='Arguments for ddpm training') 183 | parser.add_argument('--config', dest='config_path', 184 | default='config/celebhq_text_cond_clip.yaml', type=str) 185 | args = parser.parse_args() 186 | train(args) 187 | -------------------------------------------------------------------------------- /models/unet_cond_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum 3 | import torch.nn as nn 4 | from models.blocks import get_time_embedding 5 | from models.blocks import DownBlock, MidBlock, UpBlockUnet 6 | from utils.config_utils import * 7 | 8 | 9 | class Unet(nn.Module): 10 | r""" 11 | Unet model comprising 12 | Down blocks, Midblocks and Uplocks 13 | """ 14 | 15 | def __init__(self, im_channels, model_config): 16 | super().__init__() 17 | self.down_channels = model_config['down_channels'] 18 | self.mid_channels = model_config['mid_channels'] 19 | self.t_emb_dim = model_config['time_emb_dim'] 20 | self.down_sample = model_config['down_sample'] 21 | self.num_down_layers = model_config['num_down_layers'] 22 | self.num_mid_layers = model_config['num_mid_layers'] 23 | self.num_up_layers = model_config['num_up_layers'] 24 | self.attns = model_config['attn_down'] 25 | self.norm_channels = model_config['norm_channels'] 26 | self.num_heads = model_config['num_heads'] 27 | self.conv_out_channels = model_config['conv_out_channels'] 28 | 29 | # Validating Unet Model configurations 30 | assert self.mid_channels[0] == self.down_channels[-1] 31 | assert self.mid_channels[-1] == self.down_channels[-2] 32 | assert len(self.down_sample) == len(self.down_channels) - 1 33 | assert len(self.attns) == len(self.down_channels) - 1 34 | 35 | ######## Class, Mask and Text Conditioning Config ##### 36 | self.class_cond = False 37 | self.text_cond = False 38 | self.image_cond = False 39 | self.text_embed_dim = None 40 | self.condition_config = get_config_value(model_config, 'condition_config', None) 41 | if self.condition_config is not None: 42 | assert 'condition_types' in self.condition_config, 'Condition Type not provided in model config' 43 | condition_types = self.condition_config['condition_types'] 44 | if 'class' in condition_types: 45 | validate_class_config(self.condition_config) 46 | self.class_cond = True 47 | self.num_classes = self.condition_config['class_condition_config']['num_classes'] 48 | if 'text' in condition_types: 49 | validate_text_config(self.condition_config) 50 | self.text_cond = True 51 | self.text_embed_dim = 
self.condition_config['text_condition_config']['text_embed_dim'] 52 | if 'image' in condition_types: 53 | self.image_cond = True 54 | self.im_cond_input_ch = self.condition_config['image_condition_config'][ 55 | 'image_condition_input_channels'] 56 | self.im_cond_output_ch = self.condition_config['image_condition_config'][ 57 | 'image_condition_output_channels'] 58 | if self.class_cond: 59 | # Rather than using a special null class we dont add the 60 | # class embedding information for unconditional generation 61 | self.class_emb = nn.Embedding(self.num_classes, 62 | self.t_emb_dim) 63 | 64 | if self.image_cond: 65 | # Map the mask image to a N channel image and 66 | # concat that with input across channel dimension 67 | self.cond_conv_in = nn.Conv2d(in_channels=self.im_cond_input_ch, 68 | out_channels=self.im_cond_output_ch, 69 | kernel_size=1, 70 | bias=False) 71 | self.conv_in_concat = nn.Conv2d(im_channels + self.im_cond_output_ch, 72 | self.down_channels[0], kernel_size=3, padding=1) 73 | else: 74 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1) 75 | self.cond = self.text_cond or self.image_cond or self.class_cond 76 | ################################### 77 | 78 | # Initial projection from sinusoidal time embedding 79 | self.t_proj = nn.Sequential( 80 | nn.Linear(self.t_emb_dim, self.t_emb_dim), 81 | nn.SiLU(), 82 | nn.Linear(self.t_emb_dim, self.t_emb_dim) 83 | ) 84 | 85 | self.up_sample = list(reversed(self.down_sample)) 86 | self.downs = nn.ModuleList([]) 87 | 88 | # Build the Downblocks 89 | for i in range(len(self.down_channels) - 1): 90 | # Cross Attention and Context Dim only needed if text condition is present 91 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim, 92 | down_sample=self.down_sample[i], 93 | num_heads=self.num_heads, 94 | num_layers=self.num_down_layers, 95 | attn=self.attns[i], norm_channels=self.norm_channels, 96 | cross_attn=self.text_cond, 97 | context_dim=self.text_embed_dim)) 98 | 99 | self.mids = nn.ModuleList([]) 100 | # Build the Midblocks 101 | for i in range(len(self.mid_channels) - 1): 102 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim, 103 | num_heads=self.num_heads, 104 | num_layers=self.num_mid_layers, 105 | norm_channels=self.norm_channels, 106 | cross_attn=self.text_cond, 107 | context_dim=self.text_embed_dim)) 108 | 109 | self.ups = nn.ModuleList([]) 110 | # Build the Upblocks 111 | for i in reversed(range(len(self.down_channels) - 1)): 112 | self.ups.append( 113 | UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels, 114 | self.t_emb_dim, up_sample=self.down_sample[i], 115 | num_heads=self.num_heads, 116 | num_layers=self.num_up_layers, 117 | norm_channels=self.norm_channels, 118 | cross_attn=self.text_cond, 119 | context_dim=self.text_embed_dim)) 120 | 121 | self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) 122 | self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1) 123 | 124 | def forward(self, x, t, cond_input=None): 125 | # Shapes assuming downblocks are [C1, C2, C3, C4] 126 | # Shapes assuming midblocks are [C4, C4, C3] 127 | # Shapes assuming downsamples are [True, True, False] 128 | if self.cond: 129 | assert cond_input is not None, \ 130 | "Model initialized with conditioning so cond_input cannot be None" 131 | if self.image_cond: 132 | ######## Mask Conditioning ######## 133 | 
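# The conditioning mask is resized to the spatial size of the (latent) input,
# projected to im_cond_output_ch channels by a 1x1 convolution and concatenated
# with x along the channel dimension before the first convolution.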
validate_image_conditional_input(cond_input, x) 134 | im_cond = cond_input['image'] 135 | im_cond = torch.nn.functional.interpolate(im_cond, size=x.shape[-2:]) 136 | im_cond = self.cond_conv_in(im_cond) 137 | assert im_cond.shape[-2:] == x.shape[-2:] 138 | x = torch.cat([x, im_cond], dim=1) 139 | # B x (C+N) x H x W 140 | out = self.conv_in_concat(x) 141 | ##################################### 142 | else: 143 | # B x C x H x W 144 | out = self.conv_in(x) 145 | # B x C1 x H x W 146 | 147 | # t_emb -> B x t_emb_dim 148 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) 149 | t_emb = self.t_proj(t_emb) 150 | 151 | ######## Class Conditioning ######## 152 | if self.class_cond: 153 | validate_class_conditional_input(cond_input, x, self.num_classes) 154 | class_embed = einsum(cond_input['class'].float(), self.class_emb.weight, 'b n, n d -> b d') 155 | t_emb += class_embed 156 | #################################### 157 | 158 | context_hidden_states = None 159 | if self.text_cond: 160 | assert 'text' in cond_input, \ 161 | "Model initialized with text conditioning but cond_input has no text information" 162 | context_hidden_states = cond_input['text'] 163 | down_outs = [] 164 | 165 | for idx, down in enumerate(self.downs): 166 | down_outs.append(out) 167 | out = down(out, t_emb, context_hidden_states) 168 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4] 169 | # out B x C4 x H/4 x W/4 170 | 171 | for mid in self.mids: 172 | out = mid(out, t_emb, context_hidden_states) 173 | # out B x C3 x H/4 x W/4 174 | 175 | for up in self.ups: 176 | down_out = down_outs.pop() 177 | out = up(out, down_out, t_emb, context_hidden_states) 178 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] 179 | out = self.norm_out(out) 180 | out = nn.SiLU()(out) 181 | out = self.conv_out(out) 182 | # out B x C x H x W 183 | return out 184 | -------------------------------------------------------------------------------- /tools/train_vqvae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import random 5 | import torchvision 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from models.vqvae import VQVAE 10 | from models.lpips import LPIPS 11 | from models.discriminator import Discriminator 12 | from torch.utils.data.dataloader import DataLoader 13 | from dataset.mnist_dataset import MnistDataset 14 | from dataset.celeb_dataset import CelebDataset 15 | from torch.optim import Adam 16 | from torchvision.utils import make_grid 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def train(args): 22 | # Read the config file # 23 | with open(args.config_path, 'r') as file: 24 | try: 25 | config = yaml.safe_load(file) 26 | except yaml.YAMLError as exc: 27 | print(exc) 28 | print(config) 29 | 30 | dataset_config = config['dataset_params'] 31 | autoencoder_config = config['autoencoder_params'] 32 | train_config = config['train_params'] 33 | 34 | # Set the desired seed value # 35 | seed = train_config['seed'] 36 | torch.manual_seed(seed) 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | if device == 'cuda': 40 | torch.cuda.manual_seed_all(seed) 41 | ############################# 42 | 43 | # Create the model and dataset # 44 | model = VQVAE(im_channels=dataset_config['im_channels'], 45 | model_config=autoencoder_config).to(device) 46 | # Create the dataset 47 | im_dataset_cls = { 48 | 'mnist': MnistDataset, 49 | 'celebhq': 
CelebDataset, 50 | }.get(dataset_config['name']) 51 | 52 | im_dataset = im_dataset_cls(split='train', 53 | im_path=dataset_config['im_path'], 54 | im_size=dataset_config['im_size'], 55 | im_channels=dataset_config['im_channels']) 56 | 57 | data_loader = DataLoader(im_dataset, 58 | batch_size=train_config['autoencoder_batch_size'], 59 | shuffle=True) 60 | 61 | # Create output directories 62 | if not os.path.exists(train_config['task_name']): 63 | os.mkdir(train_config['task_name']) 64 | 65 | num_epochs = train_config['autoencoder_epochs'] 66 | 67 | # L1/L2 loss for Reconstruction 68 | recon_criterion = torch.nn.MSELoss() 69 | # Disc Loss can even be BCEWithLogits 70 | disc_criterion = torch.nn.MSELoss() 71 | 72 | # No need to freeze lpips as lpips.py takes care of that 73 | lpips_model = LPIPS().eval().to(device) 74 | discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device) 75 | 76 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 77 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 78 | 79 | disc_step_start = train_config['disc_start'] 80 | step_count = 0 81 | 82 | # This is for accumulating gradients incase the images are huge 83 | # And one cant afford higher batch sizes 84 | acc_steps = train_config['autoencoder_acc_steps'] 85 | image_save_steps = train_config['autoencoder_img_save_steps'] 86 | img_save_count = 0 87 | 88 | for epoch_idx in range(num_epochs): 89 | recon_losses = [] 90 | codebook_losses = [] 91 | #commitment_losses = [] 92 | perceptual_losses = [] 93 | disc_losses = [] 94 | gen_losses = [] 95 | losses = [] 96 | 97 | optimizer_g.zero_grad() 98 | optimizer_d.zero_grad() 99 | 100 | for im in tqdm(data_loader): 101 | step_count += 1 102 | im = im.float().to(device) 103 | 104 | # Fetch autoencoders output(reconstructions) 105 | model_output = model(im) 106 | output, z, quantize_losses = model_output 107 | 108 | # Image Saving Logic 109 | if step_count % image_save_steps == 0 or step_count == 1: 110 | sample_size = min(8, im.shape[0]) 111 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu() 112 | save_output = ((save_output + 1) / 2) 113 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu() 114 | 115 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) 116 | img = torchvision.transforms.ToPILImage()(grid) 117 | if not os.path.exists(os.path.join(train_config['task_name'],'vqvae_autoencoder_samples')): 118 | os.mkdir(os.path.join(train_config['task_name'], 'vqvae_autoencoder_samples')) 119 | img.save(os.path.join(train_config['task_name'],'vqvae_autoencoder_samples', 120 | 'current_autoencoder_sample_{}.png'.format(img_save_count))) 121 | img_save_count += 1 122 | img.close() 123 | 124 | ######### Optimize Generator ########## 125 | # L2 Loss 126 | recon_loss = recon_criterion(output, im) 127 | recon_losses.append(recon_loss.item()) 128 | recon_loss = recon_loss / acc_steps 129 | g_loss = (recon_loss + 130 | (train_config['codebook_weight'] * quantize_losses['codebook_loss'] / acc_steps) + 131 | (train_config['commitment_beta'] * quantize_losses['commitment_loss'] / acc_steps)) 132 | codebook_losses.append(train_config['codebook_weight'] * quantize_losses['codebook_loss'].item()) 133 | # Adversarial loss only if disc_step_start steps passed 134 | if step_count > disc_step_start: 135 | disc_fake_pred = discriminator(model_output[0]) 136 | disc_fake_loss = disc_criterion(disc_fake_pred, 137 | 
torch.ones(disc_fake_pred.shape, 138 | device=disc_fake_pred.device)) 139 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item()) 140 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps 141 | lpips_loss = torch.mean(lpips_model(output, im)) 142 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item()) 143 | g_loss += train_config['perceptual_weight']*lpips_loss / acc_steps 144 | losses.append(g_loss.item()) 145 | g_loss.backward() 146 | ##################################### 147 | 148 | ######### Optimize Discriminator ####### 149 | if step_count > disc_step_start: 150 | fake = output 151 | disc_fake_pred = discriminator(fake.detach()) 152 | disc_real_pred = discriminator(im) 153 | disc_fake_loss = disc_criterion(disc_fake_pred, 154 | torch.zeros(disc_fake_pred.shape, 155 | device=disc_fake_pred.device)) 156 | disc_real_loss = disc_criterion(disc_real_pred, 157 | torch.ones(disc_real_pred.shape, 158 | device=disc_real_pred.device)) 159 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2 160 | disc_losses.append(disc_loss.item()) 161 | disc_loss = disc_loss / acc_steps 162 | disc_loss.backward() 163 | if step_count % acc_steps == 0: 164 | optimizer_d.step() 165 | optimizer_d.zero_grad() 166 | ##################################### 167 | 168 | if step_count % acc_steps == 0: 169 | optimizer_g.step() 170 | optimizer_g.zero_grad() 171 | optimizer_d.step() 172 | optimizer_d.zero_grad() 173 | optimizer_g.step() 174 | optimizer_g.zero_grad() 175 | if len(disc_losses) > 0: 176 | print( 177 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | ' 178 | 'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'. 179 | format(epoch_idx + 1, 180 | np.mean(recon_losses), 181 | np.mean(perceptual_losses), 182 | np.mean(codebook_losses), 183 | np.mean(gen_losses), 184 | np.mean(disc_losses))) 185 | else: 186 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'. 
187 | format(epoch_idx + 1, 188 | np.mean(recon_losses), 189 | np.mean(perceptual_losses), 190 | np.mean(codebook_losses))) 191 | 192 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 193 | train_config['vqvae_autoencoder_ckpt_name'])) 194 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'], 195 | train_config['vqvae_discriminator_ckpt_name'])) 196 | print('Done Training...') 197 | 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser(description='Arguments for vq vae training') 201 | parser.add_argument('--config', dest='config_path', 202 | default='config/mnist.yaml', type=str) 203 | args = parser.parse_args() 204 | train(args) 205 | -------------------------------------------------------------------------------- /tools/sample_ddpm_text_image_cond.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import torchvision 5 | import argparse 6 | import yaml 7 | import os 8 | from torchvision.utils import make_grid 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from models.unet_cond_base import Unet 12 | from models.vqvae import VQVAE 13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 14 | from transformers import DistilBertModel, DistilBertTokenizer, CLIPTokenizer, CLIPTextModel 15 | from utils.config_utils import * 16 | from utils.text_utils import * 17 | from dataset.celeb_dataset import CelebDataset 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | 22 | def sample(model, scheduler, train_config, diffusion_model_config, 23 | autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model): 24 | r""" 25 | Sample stepwise by going backward one timestep at a time. 26 | We save the x0 predictions 27 | """ 28 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 29 | 30 | ########### Sample random noise latent ########## 31 | # For not fixing generation with one sample 32 | xt = torch.randn((1, 33 | autoencoder_model_config['z_channels'], 34 | im_size, 35 | im_size)).to(device) 36 | ############################################### 37 | 38 | ############ Create Conditional input ############### 39 | text_prompt = ['She is a woman with blond hair. 
She is wearing lipstick.'] 40 | neg_prompts = ['He is a man.'] 41 | empty_prompt = [''] 42 | text_prompt_embed = get_text_representation(text_prompt, 43 | text_tokenizer, 44 | text_model, 45 | device) 46 | # Can replace empty prompt with negative prompt 47 | empty_text_embed = get_text_representation(empty_prompt, text_tokenizer, text_model, device) 48 | assert empty_text_embed.shape == text_prompt_embed.shape 49 | 50 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None) 51 | validate_image_config(condition_config) 52 | 53 | # This is required to get a random but valid mask 54 | dataset = CelebDataset(split='train', 55 | im_path=dataset_config['im_path'], 56 | im_size=dataset_config['im_size'], 57 | im_channels=dataset_config['im_channels'], 58 | use_latents=True, 59 | latent_path=os.path.join(train_config['task_name'], 60 | train_config['vqvae_latent_dir_name']), 61 | condition_config=condition_config) 62 | mask_idx = random.randint(0, len(dataset.masks)) 63 | mask = dataset.get_mask(mask_idx).unsqueeze(0).to(device) 64 | uncond_input = { 65 | 'text': empty_text_embed, 66 | 'image': torch.zeros_like(mask) 67 | } 68 | cond_input = { 69 | 'text': text_prompt_embed, 70 | 'image': mask 71 | } 72 | ############################################### 73 | 74 | # By default classifier free guidance is disabled 75 | # Change value in config or change default value here to enable it 76 | cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0) 77 | 78 | ################# Sampling Loop ######################## 79 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 80 | # Get prediction of noise 81 | t = (torch.ones((xt.shape[0],)) * i).long().to(device) 82 | noise_pred_cond = model(xt, t, cond_input) 83 | 84 | if cf_guidance_scale > 1: 85 | noise_pred_uncond = model(xt, t, uncond_input) 86 | noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond) 87 | else: 88 | noise_pred = noise_pred_cond 89 | 90 | # Use scheduler to get x0 and xt-1 91 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 92 | 93 | # Save x0 94 | if i == 0: 95 | # Decode ONLY the final image to save time 96 | ims = vae.decode(xt) 97 | else: 98 | ims = x0_pred 99 | 100 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 101 | ims = (ims + 1) / 2 102 | grid = make_grid(ims, nrow=10) 103 | img = torchvision.transforms.ToPILImage()(grid) 104 | 105 | if not os.path.exists(os.path.join(train_config['task_name'], 'cond_text_image_samples')): 106 | os.mkdir(os.path.join(train_config['task_name'], 'cond_text_image_samples')) 107 | img.save(os.path.join(train_config['task_name'], 'cond_text_image_samples', 'x0_{}.png'.format(i))) 108 | img.close() 109 | ############################################################## 110 | 111 | def infer(args): 112 | # Read the config file # 113 | with open(args.config_path, 'r') as file: 114 | try: 115 | config = yaml.safe_load(file) 116 | except yaml.YAMLError as exc: 117 | print(exc) 118 | print(config) 119 | ######################## 120 | 121 | diffusion_config = config['diffusion_params'] 122 | dataset_config = config['dataset_params'] 123 | diffusion_model_config = config['ldm_params'] 124 | autoencoder_model_config = config['autoencoder_params'] 125 | train_config = config['train_params'] 126 | 127 | ########## Create the noise scheduler ############# 128 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 129 | 
beta_start=diffusion_config['beta_start'], 130 | beta_end=diffusion_config['beta_end']) 131 | ############################################### 132 | 133 | ############# Validate the config ################# 134 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None) 135 | assert condition_config is not None, ("This sampling script is for image and text conditional " 136 | "but no conditioning config found") 137 | condition_types = get_config_value(condition_config, 'condition_types', []) 138 | assert 'text' in condition_types, ("This sampling script is for image and text conditional " 139 | "but no text condition found in config") 140 | assert 'image' in condition_types, ("This sampling script is for image and text conditional " 141 | "but no image condition found in config") 142 | validate_text_config(condition_config) 143 | validate_image_config(condition_config) 144 | ############################################### 145 | 146 | ############# Load tokenizer and text model ################# 147 | with torch.no_grad(): 148 | # Load tokenizer and text model based on config 149 | # Also get empty text representation 150 | text_tokenizer, text_model = get_tokenizer_and_model(condition_config['text_condition_config'] 151 | ['text_embed_model'], device=device) 152 | ############################################### 153 | 154 | ########## Load Unet ############# 155 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 156 | model_config=diffusion_model_config).to(device) 157 | model.eval() 158 | if os.path.exists(os.path.join(train_config['task_name'], 159 | train_config['ldm_ckpt_name'])): 160 | print('Loaded unet checkpoint') 161 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 162 | train_config['ldm_ckpt_name']), 163 | map_location=device)) 164 | else: 165 | raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'], 166 | train_config['ldm_ckpt_name']))) 167 | ##################################### 168 | 169 | # Create output directories 170 | if not os.path.exists(train_config['task_name']): 171 | os.mkdir(train_config['task_name']) 172 | 173 | ########## Load VQVAE ############# 174 | vae = VQVAE(im_channels=dataset_config['im_channels'], 175 | model_config=autoencoder_model_config).to(device) 176 | vae.eval() 177 | 178 | # Load vae if found 179 | if os.path.exists(os.path.join(train_config['task_name'], 180 | train_config['vqvae_autoencoder_ckpt_name'])): 181 | print('Loaded vae checkpoint') 182 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 183 | train_config['vqvae_autoencoder_ckpt_name']), 184 | map_location=device)) 185 | else: 186 | raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'], 187 | train_config['vqvae_autoencoder_ckpt_name']))) 188 | ##################################### 189 | 190 | with torch.no_grad(): 191 | sample(model, scheduler, train_config, diffusion_model_config, 192 | autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model) 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation ' 197 | 'with text and mask conditioning') 198 | parser.add_argument('--config', dest='config_path', 199 | default='config/celebhq_text_image_cond.yaml', type=str) 200 | args = parser.parse_args() 201 | infer(args) 202 | 
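All of the conditional sampling scripts above apply the same classifier-free guidance update inside their denoising loops. A minimal sketch of just that step, factored into a standalone helper (the function name and signature are illustrative, not part of the repo's API):

```python
import torch

def guided_noise_prediction(model, xt, t, cond_input, uncond_input, cf_guidance_scale=1.0):
    # Conditional prediction is always computed
    noise_pred_cond = model(xt, t, cond_input)
    if cf_guidance_scale > 1:
        # Prediction with the "empty" condition (zeroed one-hot / empty prompt / zero mask)
        noise_pred_uncond = model(xt, t, uncond_input)
        # Move the estimate away from the unconditional prediction, towards the conditional one
        return noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond)
    return noise_pred_cond
```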
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | Stable Diffusion Implementation in PyTorch
2 | ========
3 | 
4 | This repository implements Stable Diffusion.
5 | As of today the repo provides code to do the following:
6 | * Training and Inference on Unconditional Latent Diffusion Models
7 | * Training a Class Conditional Latent Diffusion Model
8 | * Training a Text Conditioned Latent Diffusion Model
9 | * Training a Semantic Mask Conditioned Latent Diffusion Model
10 | * Any combination of the above three types of conditioning
11 | 
12 | For the autoencoder, the repo provides code for both VAE and VQVAE.
13 | However, both stages of training use VQVAE only. One can easily switch that to VAE if needed.
14 | 
15 | For the diffusion part, only DDPM with a linear noise schedule is implemented as of now.
16 | 
17 | 
18 | ## Stable Diffusion Tutorial Videos
19 | 
20 | Stable Diffusion Tutorial
22 | 
23 | 
24 | Stable Diffusion Conditioning Tutorial
26 | 
27 | ___
28 | 
29 | 
30 | ## Sample Output for Autoencoder on CelebHQ
31 | Image - Top, Reconstructions - Below
32 | 
33 | 
34 | 
35 | ## Sample Output for Unconditional LDM on CelebHQ (not fully converged)
36 | 
37 | 
38 | 
39 | ## Sample Output for Conditional LDM
40 | ### Sample Output for Class Conditioned on MNIST
41 | ![50](https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/46a38d36-3770-4f40-895a-95a16dc6462a)
42 | ![50](https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/1562c41d-e6ff-41cf-8d1e-6909a4240a04)
43 | ![50](https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/0cde44a6-746b-4f05-9422-9604f9436d91)
44 | ![50](https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/7d6b8db2-dab4-4a17-9fe6-570d938669f6)
45 | ![50](https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/6ecc3c61-3668-4305-aa4a-0f0e3cf815a0)
46 | 
47 | ### Sample Output for Text (using CLIP) and Mask Conditioned on CelebHQ (not converged)
48 | 
49 | 
50 | 51 | Text - She is a woman with blond hair 52 |
53 | 
54 | Text - She is a woman with black hair
55 | 
56 | ___
57 | 
58 | ## Setup
59 | * Create a new conda environment with Python 3.8, then run the below commands
60 | * `conda activate <environment_name>`
61 | * ```git clone https://github.com/explainingai-code/StableDiffusion-PyTorch.git```
62 | * ```cd StableDiffusion-PyTorch```
63 | * ```pip install -r requirements.txt```
64 | * Download the lpips weights by opening this link in a browser (don't use cURL or wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```models/weights/v0.1/vgg.pth```
65 | 
66 | ___
67 | 
68 | ## Data Preparation
69 | ### Mnist
70 | 
71 | For setting up the mnist dataset follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation
72 | 
73 | Ensure the directory structure is the following
74 | ```
75 | StableDiffusion-PyTorch
76 |     -> data
77 |         -> mnist
78 |             -> train
79 |                 -> images
80 |                     -> *.png
81 |             -> test
82 |                 -> images
83 |                     -> *.png
84 | ```
85 | 
86 | ### CelebHQ
87 | #### Unconditional
88 | For setting up CelebHQ for unconditional training, simply download the images from the official repo of CelebAMask-HQ [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file).
89 | 
90 | Ensure the directory structure is the following
91 | ```
92 | StableDiffusion-PyTorch
93 |     -> data
94 |         -> CelebAMask-HQ
95 |             -> CelebA-HQ-img
96 |                 -> *.jpg
97 | 
98 | ```
99 | #### Mask Conditional
100 | For the mask conditional LDM on CelebHQ, additionally do the following:
101 | 
102 | Ensure the directory structure is the following
103 | ```
104 | StableDiffusion-PyTorch
105 |     -> data
106 |         -> CelebAMask-HQ
107 |             -> CelebA-HQ-img
108 |                 -> *.jpg
109 |             -> CelebAMask-HQ-mask-anno
110 |                 -> 0/1/2/3.../14
111 |                     -> *.png
112 | 
113 | ```
114 | 
115 | * Run `python -m utils.create_celeb_mask` from the repo root to create the mask images from the mask annotations
116 | 
117 | Ensure the directory structure is the following
118 | ```
119 | StableDiffusion-PyTorch
120 |     -> data
121 |         -> CelebAMask-HQ
122 |             -> CelebA-HQ-img
123 |                 -> *.jpg
124 |             -> CelebAMask-HQ-mask-anno
125 |                 -> 0/1/2/3.../14
126 |                     -> *.png
127 |             -> CelebAMask-HQ-mask
128 |                 -> *.png
129 | ```
130 | 
131 | #### Text Conditional
132 | For the text conditional LDM on CelebHQ, additionally do the following:
133 | * The repo uses the captions collected as part of this repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file
134 | * Download the captions from the `text` link provided in the repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file#overview
135 | * This will download a `celeba-caption` folder; simply move it inside the `data/CelebAMask-HQ` folder, as that is where the dataset class expects it to be.
136 | 
137 | Ensure the directory structure is the following
138 | ```
139 | StableDiffusion-PyTorch
140 |     -> data
141 |         -> CelebAMask-HQ
142 |             -> CelebA-HQ-img
143 |                 -> *.jpg
144 |             -> CelebAMask-HQ-mask-anno
145 |                 -> 0/1/2/3.../14
146 |                     -> *.png
147 |             -> CelebAMask-HQ-mask
148 |                 -> *.png
149 |             -> celeba-caption
150 |                 -> *.txt
151 | ```
152 | ---
153 | ## Configuration
154 | The config files allow you to play with different components of the DDPM and autoencoder training
155 | * ```config/mnist.yaml``` - Small autoencoder and LDM that can even be trained on a CPU
156 | * ```config/celebhq.yaml``` - Configuration used for the celebhq dataset
157 | 
158 | Relevant configuration parameters
159 | 
160 | Most parameters are self-explanatory, but below are a couple that are specific to this repo.
161 | * ```autoencoder_acc_steps``` : For accumulating gradients when the image size is too large to train with larger batch sizes
162 | * ```save_latents``` : Enable this to save the latents during inference of the autoencoder. That way ddpm training will be faster
163 | 
164 | ___
165 | ## Training
166 | The repo provides training and inference for Mnist (Unconditional and Class Conditional) and CelebHQ (Unconditional, Text and/or Mask Conditional).
167 | 
168 | For working on your own dataset:
169 | * Create your own config and have the path in the config point to your images (look at `celebhq.yaml` for guidance)
170 | * Create your own dataset class which will just collect all the filenames and return the image in its getitem method (a minimal sketch of such a class is shown below). Look at `mnist_dataset.py` or `celeb_dataset.py` for guidance
171 | 
172 | Once the config and dataset are set up:
173 | * Train the autoencoder on your dataset using [this section](#training-autoencoder-for-ldm)
174 | * For training Unconditional LDM follow [this section](#training-unconditional-ldm)
175 | * For class conditional ldm go through [this section](#training-class-conditional-ldm)
176 | * For text conditional ldm go through [this section](#training-text-conditional-ldm)
177 | * For text and mask conditional ldm go through [this section](#training-text-and-mask-conditional-ldm)
178 | 
179 | 
180 | ## Training AutoEncoder for LDM
181 | * For training the autoencoder on mnist, ensure the right path is mentioned in `mnist.yaml`
182 | * For training the autoencoder on celebhq, ensure the right path is mentioned in `celebhq.yaml`
183 | * For training the autoencoder on your own dataset:
184 |   * Create your own config and have the path point to your images (look at celebhq.yaml for guidance)
185 |   * Create your own dataset class, similar to celeb_dataset.py without the conditioning parts
186 |   * Map the dataset name to the right class in the training code [here](https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/tools/train_ddpm_vqvae.py#L40)
187 | * For training the autoencoder run ```python -m tools.train_vqvae --config config/mnist.yaml``` for training vqvae with the desired config file
188 | * For inference using the trained autoencoder run ```python -m tools.infer_vqvae --config config/mnist.yaml``` for generating reconstructions with the right config file. Use `save_latents` in the config to save the latent files
189 | 
190 | 
191 | ## Training Unconditional LDM
192 | Train the autoencoder first and set up the dataset accordingly.
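A minimal sketch of such a custom dataset class is shown below. The class name, file glob and transforms are illustrative placeholders rather than code from this repo; `mnist_dataset.py` and `celeb_dataset.py` are the actual reference implementations.
```
import glob
import os

import torchvision
from PIL import Image
from torch.utils.data import Dataset


class MyImageDataset(Dataset):
    """Minimal example dataset: collects image paths and returns image tensors in [-1, 1]."""

    def __init__(self, split, im_path, im_size=256, im_channels=3):
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        # Collect all image filenames once, up front
        self.images = sorted(glob.glob(os.path.join(im_path, '*.png')) +
                             glob.glob(os.path.join(im_path, '*.jpg')))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index]).convert('RGB')
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.im_size),
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()
        # Scale from [0, 1] to [-1, 1], the range used elsewhere in this repo
        # (the sampling code clamps decoded images to [-1, 1])
        im_tensor = 2 * im_tensor - 1
        # For conditional training you would instead return
        # (im_tensor, {'class': ..., 'text': ..., 'image': ...}) with only the
        # keys listed in condition_types of your config.
        return im_tensor
```
The real dataset classes in this repo additionally support returning pre-computed latents (`use_latents`, `latent_path`) and the conditioning inputs described in the conditional training sections below, so treat the sketch above only as the minimal contract.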
193 | 
194 | For training the unconditional LDM, map the dataset to the right class in `train_ddpm_vqvae.py` and run:
195 | * ```python -m tools.train_ddpm_vqvae --config config/mnist.yaml``` for training the unconditional ddpm using the right config
196 | * ```python -m tools.sample_ddpm_vqvae --config config/mnist.yaml``` for generating images using the trained ddpm
197 | 
198 | ## Training Conditional LDM
199 | For training conditional models we need two changes:
200 | * Dataset classes must provide the additional conditional inputs (see below)
201 | * The config must be changed to add the additional conditioning configuration
202 | 
203 | Specifically, the dataset `getitem` will return the following:
204 | * `image_tensor` for unconditional training
205 | * tuple of `(image_tensor, cond_input)` for conditional training, where cond_input is a dictionary with keys from ```{class/text/image}```
206 | 
207 | ### Training Class Conditional LDM
208 | The repo provides class conditional latent diffusion model training code for the mnist dataset, so one
209 | can use that as a reference for their own dataset
210 | 
211 | * Use the `mnist_class_cond.yaml` config file as a guide to create your class conditional config file.
212 | Specifically, the following new keys need to be modified according to your dataset within `ldm_params`.
213 | * ```
214 |   condition_config:
215 |     condition_types: ['class']
216 |     class_condition_config:
217 |       num_classes: <number of classes>
218 |       cond_drop_prob: <probability of dropping the class condition during training>
219 |   ```
220 | * Create a dataset class similar to mnist where the getitem method now returns a tuple of image_tensor and a dictionary of conditional_inputs.
221 | * For class, the conditional input will ONLY be the integer class
222 | * ```
223 |   (image_tensor, {
224 |     'class' : {0/1/.../num_classes}
225 |   })
226 |   ```
227 | For training the class conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
228 | * ```python -m tools.train_ddpm_cond --config config/mnist_class_cond.yaml``` for training class conditional on mnist
229 | * ```python -m tools.sample_ddpm_class_cond --config config/mnist_class_cond.yaml``` for generating images using the class conditional trained ddpm
230 | 
231 | ### Training Text Conditional LDM
232 | The repo provides text conditional latent diffusion model training code for the celebhq dataset, so one
233 | can use that as a reference for their own dataset
234 | 
235 | * Use the `celebhq_text_cond.yaml` config file as a guide to create your config file.
236 | Specifically, the following new keys need to be modified according to your dataset within `ldm_params`.
237 | * ```
238 |   condition_config:
239 |     condition_types: [ 'text' ]
240 |     text_condition_config:
241 |       text_embed_model: 'clip' or 'bert'
242 |       text_embed_dim: 512 or 768
243 |       cond_drop_prob: 0.1
244 |   ```
245 | * Create a dataset class similar to celebhq where the getitem method now returns a tuple of image_tensor and a dictionary of conditional_inputs.
246 | * For text, the conditional input will ONLY be the caption
247 | * ```
248 |   (image_tensor, {
249 |     'text' : 'a sample caption for image_tensor'
250 |   })
251 |   ```
252 | For training the text conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
253 | * ```python -m tools.train_ddpm_cond --config config/celebhq_text_cond.yaml``` for training the text conditioned ldm on celebhq
254 | * ```python -m tools.sample_ddpm_text_cond --config config/celebhq_text_cond.yaml``` for generating images using the text conditional trained ddpm
255 | 
256 | ### Training Text and Mask Conditional LDM
257 | The repo provides text and mask conditional latent diffusion model training code for the celebhq dataset, so one
258 | can use that as a reference for their own dataset, and can even use it to train a mask-only conditional ldm
259 | 
260 | * Use the `celebhq_text_image_cond.yaml` config file as a guide to create your config file.
261 | Specifically, the following new keys need to be modified according to your dataset within `ldm_params`.
262 | * ```
263 |   condition_config:
264 |     condition_types: [ 'text', 'image' ]
265 |     text_condition_config:
266 |       text_embed_model: 'clip' or 'bert'
267 |       text_embed_dim: 512 or 768
268 |       cond_drop_prob: 0.1
269 |     image_condition_config:
270 |       image_condition_input_channels: 18
271 |       image_condition_output_channels: 3
272 |       image_condition_h: 512
273 |       image_condition_w: 512
274 |       cond_drop_prob: 0.1
275 |   ```
276 | * Create a dataset class similar to celebhq where the getitem method now returns a tuple of image_tensor and a dictionary of conditional_inputs.
277 | * For text and mask, the conditional inputs will be the caption and the mask image
278 | * ```
279 |   (image_tensor, {
280 |     'text' : 'a sample caption for image_tensor',
281 |     'image' : mask tensor of shape NUM_CLASSES x MASK_H x MASK_W
282 |   })
283 |   ```
284 | For training the text and mask conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
285 | * ```python -m tools.train_ddpm_cond --config config/celebhq_text_image_cond.yaml``` for training the text and mask conditioned ldm on celebhq
286 | * ```python -m tools.sample_ddpm_text_image_cond --config config/celebhq_text_image_cond.yaml``` for generating images using the text and mask conditional trained ddpm
287 | 
288 | 
289 | ## Output
290 | Outputs will be saved according to the configuration present in the yaml files.
291 | 
292 | For every run, a folder named after the ```task_name``` key in the config will be created
293 | 
294 | During training of the autoencoder the following outputs will be saved
295 | * Latest Autoencoder and discriminator checkpoints in the ```task_name``` directory
296 | * Sample reconstructions in ```task_name/vqvae_autoencoder_samples```
297 | 
298 | During inference of the autoencoder the following outputs will be saved
299 | * Reconstructions for random images in ```task_name```
300 | * Latents will be saved in ```task_name/vqvae_latent_dir_name``` if mentioned in the config
301 | 
302 | During training and inference of the ddpm the following outputs will be saved
303 | * During training of the unconditional or conditional DDPM, the latest checkpoint is saved in the ```task_name``` directory
304 | * During sampling, the unconditionally sampled image grid for all timesteps is saved in ```task_name/samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be the latent image predictions of the denoising process from T=999 to T=1.
Generated Image is at T=0 305 | * During sampling, class conditionally sampled image grid for all timesteps in ```task_name/cond_class_samples/*.png``` . The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0 306 | * During sampling, text only conditionally sampled image grid for all timesteps in ```task_name/cond_text_samples/*.png``` . The final decoded generated image will be `x0_0.png` . Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0 307 | * During sampling, image only conditionally sampled image grid for all timesteps in ```task_name/cond_text_image_samples/*.png``` . The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0 308 | 309 | 310 | 311 | 312 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_time_embedding(time_steps, temb_dim): 6 | r""" 7 | Convert time steps tensor into an embedding using the 8 | sinusoidal time embedding formula 9 | :param time_steps: 1D tensor of length batch size 10 | :param temb_dim: Dimension of the embedding 11 | :return: BxD embedding representation of B time steps 12 | """ 13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 14 | 15 | # factor = 10000^(2i/d_model) 16 | factor = 10000 ** ((torch.arange( 17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) 18 | ) 19 | 20 | # pos / factor 21 | # timesteps B -> B, 1 -> B, temb_dim 22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 24 | return t_emb 25 | 26 | 27 | class DownBlock(nn.Module): 28 | r""" 29 | Down conv block with attention. 30 | Sequence of following block 31 | 1. Resnet block with time embedding 32 | 2. Attention block 33 | 3. 
Downsample 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, t_emb_dim, 37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None): 38 | super().__init__() 39 | self.num_layers = num_layers 40 | self.down_sample = down_sample 41 | self.attn = attn 42 | self.context_dim = context_dim 43 | self.cross_attn = cross_attn 44 | self.t_emb_dim = t_emb_dim 45 | self.resnet_conv_first = nn.ModuleList( 46 | [ 47 | nn.Sequential( 48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 49 | nn.SiLU(), 50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 51 | kernel_size=3, stride=1, padding=1), 52 | ) 53 | for i in range(num_layers) 54 | ] 55 | ) 56 | if self.t_emb_dim is not None: 57 | self.t_emb_layers = nn.ModuleList([ 58 | nn.Sequential( 59 | nn.SiLU(), 60 | nn.Linear(self.t_emb_dim, out_channels) 61 | ) 62 | for _ in range(num_layers) 63 | ]) 64 | self.resnet_conv_second = nn.ModuleList( 65 | [ 66 | nn.Sequential( 67 | nn.GroupNorm(norm_channels, out_channels), 68 | nn.SiLU(), 69 | nn.Conv2d(out_channels, out_channels, 70 | kernel_size=3, stride=1, padding=1), 71 | ) 72 | for _ in range(num_layers) 73 | ] 74 | ) 75 | 76 | if self.attn: 77 | self.attention_norms = nn.ModuleList( 78 | [nn.GroupNorm(norm_channels, out_channels) 79 | for _ in range(num_layers)] 80 | ) 81 | 82 | self.attentions = nn.ModuleList( 83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 84 | for _ in range(num_layers)] 85 | ) 86 | 87 | if self.cross_attn: 88 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 89 | self.cross_attention_norms = nn.ModuleList( 90 | [nn.GroupNorm(norm_channels, out_channels) 91 | for _ in range(num_layers)] 92 | ) 93 | self.cross_attentions = nn.ModuleList( 94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 95 | for _ in range(num_layers)] 96 | ) 97 | self.context_proj = nn.ModuleList( 98 | [nn.Linear(context_dim, out_channels) 99 | for _ in range(num_layers)] 100 | ) 101 | 102 | self.residual_input_conv = nn.ModuleList( 103 | [ 104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 105 | for i in range(num_layers) 106 | ] 107 | ) 108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 109 | 4, 2, 1) if self.down_sample else nn.Identity() 110 | 111 | def forward(self, x, t_emb=None, context=None): 112 | out = x 113 | for i in range(self.num_layers): 114 | # Resnet block of Unet 115 | resnet_input = out 116 | out = self.resnet_conv_first[i](out) 117 | if self.t_emb_dim is not None: 118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 119 | out = self.resnet_conv_second[i](out) 120 | out = out + self.residual_input_conv[i](resnet_input) 121 | 122 | if self.attn: 123 | # Attention block of Unet 124 | batch_size, channels, h, w = out.shape 125 | in_attn = out.reshape(batch_size, channels, h * w) 126 | in_attn = self.attention_norms[i](in_attn) 127 | in_attn = in_attn.transpose(1, 2) 128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 130 | out = out + out_attn 131 | 132 | if self.cross_attn: 133 | assert context is not None, "context cannot be None if cross attention layers are used" 134 | batch_size, channels, h, w = out.shape 135 | in_attn = out.reshape(batch_size, channels, h * w) 136 | in_attn = self.cross_attention_norms[i](in_attn) 137 | in_attn = in_attn.transpose(1, 2) 138 | assert 
context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 139 | context_proj = self.context_proj[i](context) 140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 142 | out = out + out_attn 143 | 144 | # Downsample 145 | out = self.down_sample_conv(out) 146 | return out 147 | 148 | 149 | class MidBlock(nn.Module): 150 | r""" 151 | Mid conv block with attention. 152 | Sequence of following blocks 153 | 1. Resnet block with time embedding 154 | 2. Attention block 155 | 3. Resnet block with time embedding 156 | """ 157 | 158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None): 159 | super().__init__() 160 | self.num_layers = num_layers 161 | self.t_emb_dim = t_emb_dim 162 | self.context_dim = context_dim 163 | self.cross_attn = cross_attn 164 | self.resnet_conv_first = nn.ModuleList( 165 | [ 166 | nn.Sequential( 167 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 168 | nn.SiLU(), 169 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 170 | padding=1), 171 | ) 172 | for i in range(num_layers + 1) 173 | ] 174 | ) 175 | 176 | if self.t_emb_dim is not None: 177 | self.t_emb_layers = nn.ModuleList([ 178 | nn.Sequential( 179 | nn.SiLU(), 180 | nn.Linear(t_emb_dim, out_channels) 181 | ) 182 | for _ in range(num_layers + 1) 183 | ]) 184 | self.resnet_conv_second = nn.ModuleList( 185 | [ 186 | nn.Sequential( 187 | nn.GroupNorm(norm_channels, out_channels), 188 | nn.SiLU(), 189 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 190 | ) 191 | for _ in range(num_layers + 1) 192 | ] 193 | ) 194 | 195 | self.attention_norms = nn.ModuleList( 196 | [nn.GroupNorm(norm_channels, out_channels) 197 | for _ in range(num_layers)] 198 | ) 199 | 200 | self.attentions = nn.ModuleList( 201 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 202 | for _ in range(num_layers)] 203 | ) 204 | if self.cross_attn: 205 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 206 | self.cross_attention_norms = nn.ModuleList( 207 | [nn.GroupNorm(norm_channels, out_channels) 208 | for _ in range(num_layers)] 209 | ) 210 | self.cross_attentions = nn.ModuleList( 211 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 212 | for _ in range(num_layers)] 213 | ) 214 | self.context_proj = nn.ModuleList( 215 | [nn.Linear(context_dim, out_channels) 216 | for _ in range(num_layers)] 217 | ) 218 | self.residual_input_conv = nn.ModuleList( 219 | [ 220 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 221 | for i in range(num_layers + 1) 222 | ] 223 | ) 224 | 225 | def forward(self, x, t_emb=None, context=None): 226 | out = x 227 | 228 | # First resnet block 229 | resnet_input = out 230 | out = self.resnet_conv_first[0](out) 231 | if self.t_emb_dim is not None: 232 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] 233 | out = self.resnet_conv_second[0](out) 234 | out = out + self.residual_input_conv[0](resnet_input) 235 | 236 | for i in range(self.num_layers): 237 | # Attention Block 238 | batch_size, channels, h, w = out.shape 239 | in_attn = out.reshape(batch_size, channels, h * w) 240 | in_attn = self.attention_norms[i](in_attn) 241 | in_attn = in_attn.transpose(1, 2) 242 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 243 | out_attn 
= out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 244 | out = out + out_attn 245 | 246 | if self.cross_attn: 247 | assert context is not None, "context cannot be None if cross attention layers are used" 248 | batch_size, channels, h, w = out.shape 249 | in_attn = out.reshape(batch_size, channels, h * w) 250 | in_attn = self.cross_attention_norms[i](in_attn) 251 | in_attn = in_attn.transpose(1, 2) 252 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 253 | context_proj = self.context_proj[i](context) 254 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 255 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 256 | out = out + out_attn 257 | 258 | 259 | # Resnet Block 260 | resnet_input = out 261 | out = self.resnet_conv_first[i + 1](out) 262 | if self.t_emb_dim is not None: 263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] 264 | out = self.resnet_conv_second[i + 1](out) 265 | out = out + self.residual_input_conv[i + 1](resnet_input) 266 | 267 | return out 268 | 269 | 270 | class UpBlock(nn.Module): 271 | r""" 272 | Up conv block with attention. 273 | Sequence of following blocks 274 | 1. Upsample 275 | 1. Concatenate Down block output 276 | 2. Resnet block with time embedding 277 | 3. Attention Block 278 | """ 279 | 280 | def __init__(self, in_channels, out_channels, t_emb_dim, 281 | up_sample, num_heads, num_layers, attn, norm_channels): 282 | super().__init__() 283 | self.num_layers = num_layers 284 | self.up_sample = up_sample 285 | self.t_emb_dim = t_emb_dim 286 | self.attn = attn 287 | self.resnet_conv_first = nn.ModuleList( 288 | [ 289 | nn.Sequential( 290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 291 | nn.SiLU(), 292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 293 | padding=1), 294 | ) 295 | for i in range(num_layers) 296 | ] 297 | ) 298 | 299 | if self.t_emb_dim is not None: 300 | self.t_emb_layers = nn.ModuleList([ 301 | nn.Sequential( 302 | nn.SiLU(), 303 | nn.Linear(t_emb_dim, out_channels) 304 | ) 305 | for _ in range(num_layers) 306 | ]) 307 | 308 | self.resnet_conv_second = nn.ModuleList( 309 | [ 310 | nn.Sequential( 311 | nn.GroupNorm(norm_channels, out_channels), 312 | nn.SiLU(), 313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 314 | ) 315 | for _ in range(num_layers) 316 | ] 317 | ) 318 | if self.attn: 319 | self.attention_norms = nn.ModuleList( 320 | [ 321 | nn.GroupNorm(norm_channels, out_channels) 322 | for _ in range(num_layers) 323 | ] 324 | ) 325 | 326 | self.attentions = nn.ModuleList( 327 | [ 328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 329 | for _ in range(num_layers) 330 | ] 331 | ) 332 | 333 | self.residual_input_conv = nn.ModuleList( 334 | [ 335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 336 | for i in range(num_layers) 337 | ] 338 | ) 339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 340 | 4, 2, 1) \ 341 | if self.up_sample else nn.Identity() 342 | 343 | def forward(self, x, out_down=None, t_emb=None): 344 | # Upsample 345 | x = self.up_sample_conv(x) 346 | 347 | # Concat with Downblock output 348 | if out_down is not None: 349 | x = torch.cat([x, out_down], dim=1) 350 | 351 | out = x 352 | for i in range(self.num_layers): 353 | # Resnet Block 354 | resnet_input = out 355 | out = self.resnet_conv_first[i](out) 356 | if self.t_emb_dim is not None: 357 | out 
= out + self.t_emb_layers[i](t_emb)[:, :, None, None] 358 | out = self.resnet_conv_second[i](out) 359 | out = out + self.residual_input_conv[i](resnet_input) 360 | 361 | # Self Attention 362 | if self.attn: 363 | batch_size, channels, h, w = out.shape 364 | in_attn = out.reshape(batch_size, channels, h * w) 365 | in_attn = self.attention_norms[i](in_attn) 366 | in_attn = in_attn.transpose(1, 2) 367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 369 | out = out + out_attn 370 | return out 371 | 372 | 373 | class UpBlockUnet(nn.Module): 374 | r""" 375 | Up conv block with attention. 376 | Sequence of following blocks 377 | 1. Upsample 378 | 1. Concatenate Down block output 379 | 2. Resnet block with time embedding 380 | 3. Attention Block 381 | """ 382 | 383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, 384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): 385 | super().__init__() 386 | self.num_layers = num_layers 387 | self.up_sample = up_sample 388 | self.t_emb_dim = t_emb_dim 389 | self.cross_attn = cross_attn 390 | self.context_dim = context_dim 391 | self.resnet_conv_first = nn.ModuleList( 392 | [ 393 | nn.Sequential( 394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 395 | nn.SiLU(), 396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 397 | padding=1), 398 | ) 399 | for i in range(num_layers) 400 | ] 401 | ) 402 | 403 | if self.t_emb_dim is not None: 404 | self.t_emb_layers = nn.ModuleList([ 405 | nn.Sequential( 406 | nn.SiLU(), 407 | nn.Linear(t_emb_dim, out_channels) 408 | ) 409 | for _ in range(num_layers) 410 | ]) 411 | 412 | self.resnet_conv_second = nn.ModuleList( 413 | [ 414 | nn.Sequential( 415 | nn.GroupNorm(norm_channels, out_channels), 416 | nn.SiLU(), 417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 418 | ) 419 | for _ in range(num_layers) 420 | ] 421 | ) 422 | 423 | self.attention_norms = nn.ModuleList( 424 | [ 425 | nn.GroupNorm(norm_channels, out_channels) 426 | for _ in range(num_layers) 427 | ] 428 | ) 429 | 430 | self.attentions = nn.ModuleList( 431 | [ 432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 433 | for _ in range(num_layers) 434 | ] 435 | ) 436 | 437 | if self.cross_attn: 438 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 439 | self.cross_attention_norms = nn.ModuleList( 440 | [nn.GroupNorm(norm_channels, out_channels) 441 | for _ in range(num_layers)] 442 | ) 443 | self.cross_attentions = nn.ModuleList( 444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 445 | for _ in range(num_layers)] 446 | ) 447 | self.context_proj = nn.ModuleList( 448 | [nn.Linear(context_dim, out_channels) 449 | for _ in range(num_layers)] 450 | ) 451 | self.residual_input_conv = nn.ModuleList( 452 | [ 453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 454 | for i in range(num_layers) 455 | ] 456 | ) 457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 458 | 4, 2, 1) \ 459 | if self.up_sample else nn.Identity() 460 | 461 | def forward(self, x, out_down=None, t_emb=None, context=None): 462 | x = self.up_sample_conv(x) 463 | if out_down is not None: 464 | x = torch.cat([x, out_down], dim=1) 465 | 466 | out = x 467 | for i in range(self.num_layers): 468 | # Resnet 469 | resnet_input = out 470 | out = 
self.resnet_conv_first[i](out) 471 | if self.t_emb_dim is not None: 472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 473 | out = self.resnet_conv_second[i](out) 474 | out = out + self.residual_input_conv[i](resnet_input) 475 | # Self Attention 476 | batch_size, channels, h, w = out.shape 477 | in_attn = out.reshape(batch_size, channels, h * w) 478 | in_attn = self.attention_norms[i](in_attn) 479 | in_attn = in_attn.transpose(1, 2) 480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 482 | out = out + out_attn 483 | # Cross Attention 484 | if self.cross_attn: 485 | assert context is not None, "context cannot be None if cross attention layers are used" 486 | batch_size, channels, h, w = out.shape 487 | in_attn = out.reshape(batch_size, channels, h * w) 488 | in_attn = self.cross_attention_norms[i](in_attn) 489 | in_attn = in_attn.transpose(1, 2) 490 | assert len(context.shape) == 3, \ 491 | "Context shape does not match B,_,CONTEXT_DIM" 492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\ 493 | "Context shape does not match B,_,CONTEXT_DIM" 494 | context_proj = self.context_proj[i](context) 495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 497 | out = out + out_attn 498 | 499 | return out 500 | 501 | 502 | --------------------------------------------------------------------------------
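A minimal shape-check sketch for the blocks defined in `models/blocks.py` above (the batch size, channel counts and spatial size are arbitrary values chosen for illustration; this snippet is not part of the repository files):
```
import torch

from models.blocks import DownBlock, get_time_embedding

# Sinusoidal embedding for a batch of 4 random timesteps -> shape (4, 128)
t = torch.randint(0, 1000, (4,))
t_emb = get_time_embedding(t, 128)

# DownBlock with self-attention halves the spatial size and maps 64 -> 128 channels
block = DownBlock(in_channels=64, out_channels=128, t_emb_dim=128,
                  down_sample=True, num_heads=4, num_layers=1,
                  attn=True, norm_channels=32)
x = torch.randn(4, 64, 32, 32)
out = block(x, t_emb)  # -> (4, 128, 16, 16)
print(t_emb.shape, out.shape)
```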