├── data
│   └── .gitkeep
├── config
│   ├── __init__.py
│   ├── mnist.yaml
│   ├── celebhq.yaml
│   ├── mnist_class_cond.yaml
│   ├── celebhq_text_cond.yaml
│   ├── celebhq_text_cond_bert.yaml
│   └── celebhq_text_image_cond.yaml
├── models
│   ├── __init__.py
│   ├── weights
│   │   ├── .gitkeep
│   │   └── v0.1
│   │       └── .gitkeep
│   ├── discriminator.py
│   ├── unet_base.py
│   ├── vae.py
│   ├── lpips.py
│   ├── vqvae.py
│   ├── unet_cond_base.py
│   └── blocks.py
├── tools
│   ├── __init__.py
│   ├── sample_ddpm_vqvae.py
│   ├── train_ddpm_vqvae.py
│   ├── infer_vqvae.py
│   ├── sample_ddpm_class_cond.py
│   ├── sample_ddpm_text_cond.py
│   ├── train_ddpm_cond.py
│   ├── train_vqvae.py
│   └── sample_ddpm_text_image_cond.py
├── utils
│   ├── __init__.py
│   ├── create_celeb_mask.py
│   ├── text_utils.py
│   ├── diffusion_utils.py
│   └── config_utils.py
├── dataset
│   ├── __init__.py
│   ├── mnist_dataset.py
│   └── celeb_dataset.py
├── scheduler
│   ├── __init__.py
│   └── linear_noise_scheduler.py
├── .github
│   └── FUNDING.yml
├── .gitignore
├── requirements.txt
├── LICENSE
└── README.md
/data/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/weights/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/weights/v0.1/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: explainingai-code
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 | *.json
15 |
16 | # Ignore checkpoints
17 | *.pth
18 |
19 | # Ignore pickle files
20 | *.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3==1.34.36
2 | botocore==1.34.36
3 | certifi==2023.7.22
4 | charset-normalizer==3.2.0
5 | click==8.1.7
6 | contourpy==1.1.0
7 | cycler==0.11.0
8 | einops==0.6.1
9 | filelock==3.13.1
10 | fonttools==4.42.0
11 | fsspec==2024.2.0
12 | ftfy==6.1.3
13 | huggingface-hub==0.20.3
14 | idna==3.4
15 | importlib-metadata==7.0.1
16 | importlib-resources==6.0.1
17 | jmespath==1.0.1
18 | joblib==1.3.2
19 | kiwisolver==1.4.4
20 | matplotlib==3.7.2
21 | numpy==1.23.5
22 | opencv-python==4.8.0.74
23 | packaging==23.1
24 | Pillow==10.0.0
25 | pyparsing==3.0.9
26 | python-dateutil==2.8.2
27 | PyYAML==6.0
28 | regex==2023.12.25
29 | requests==2.31.0
30 | s3transfer==0.10.0
31 | sacremoses==0.1.1
32 | safetensors==0.4.2
33 | scipy==1.10.1
34 | sentencepiece==0.1.99
35 | six==1.16.0
36 | tokenizers==0.15.1
37 | torch==1.11.0
38 | torchvision==0.12.0
39 | tqdm==4.65.0
40 | transformers==4.37.2
41 | typing_extensions==4.7.1
42 | urllib3==1.26.18
43 | wcwidth==0.2.13
44 | zipp==3.16.2
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ExplainingAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/create_celeb_mask.py:
--------------------------------------------------------------------------------
1 | """
2 | This is an exact copy of https://github.com/switchablenorms/CelebAMask-HQ/blob/master/face_parsing/Data_preprocessing/g_mask.py
3 | Only the folder locations are modified.
4 | """
5 |
6 | import os
7 | import cv2
8 | import glob
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
13 |
14 | folder_base = 'data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
15 | folder_save = 'data/CelebAMask-HQ/CelebAMask-HQ-mask'
16 | img_num = 30000
17 |
18 | if not os.path.exists(folder_save):
19 | os.mkdir(folder_save)
20 |
21 | for k in tqdm(range(img_num)):
22 | folder_num = k // 2000
23 | im_base = np.zeros((512, 512))
24 | for idx, label in enumerate(label_list):
25 | filename = os.path.join(folder_base, str(folder_num), str(k).rjust(5, '0') + '_' + label + '.png')
26 | if os.path.exists(filename):
27 | im = cv2.imread(filename)
28 | im = im[:, :, 0]
29 | im_base[im != 0] = (idx + 1)
30 |
31 | filename_save = os.path.join(folder_save, str(k) + '.png')
32 | cv2.imwrite(filename_save, im_base)
33 |
--------------------------------------------------------------------------------
/utils/text_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import DistilBertModel, DistilBertTokenizer, CLIPTokenizer, CLIPTextModel
3 |
4 |
5 | def get_tokenizer_and_model(model_type, device, eval_mode=True):
6 | assert model_type in ('bert', 'clip'), "Text model can only be one of clip or bert"
7 | if model_type == 'bert':
8 | text_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
9 | text_model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
10 | else:
11 | text_tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')
12 | text_model = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch16').to(device)
13 | if eval_mode:
14 | text_model.eval()
15 | return text_tokenizer, text_model
16 |
17 |
18 | def get_text_representation(text, text_tokenizer, text_model, device,
19 | truncation=True,
20 | padding='max_length',
21 | max_length=77):
22 | token_output = text_tokenizer(text,
23 | truncation=truncation,
24 | padding=padding,
25 | return_attention_mask=True,
26 | max_length=max_length)
27 | indexed_tokens = token_output['input_ids']
28 | att_masks = token_output['attention_mask']
29 | tokens_tensor = torch.tensor(indexed_tokens).to(device)
30 | mask_tensor = torch.tensor(att_masks).to(device)
31 | text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state
32 | return text_embed
33 |
--------------------------------------------------------------------------------
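Usage sketch (not part of the repository): a minimal example of calling the two helpers above with the 'clip' variant. It assumes an internet connection so the Hugging Face hub can fetch the pretrained weights on first run; the captions are made-up placeholders.

    import torch
    from utils.text_utils import get_tokenizer_and_model, get_text_representation

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the CLIP text tokenizer/encoder pair (in eval mode by default)
    tokenizer, text_model = get_tokenizer_and_model('clip', device)

    # Encode a batch of two captions into (B, 77, 512) embeddings
    with torch.no_grad():
        text_embed = get_text_representation(['A smiling woman with black hair',
                                              'A man wearing glasses'],
                                             tokenizer, text_model, device)
    print(text_embed.shape)  # torch.Size([2, 77, 512])
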
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Discriminator(nn.Module):
6 | r"""
7 | PatchGAN Discriminator.
8 |     Rather than mapping the IMG_CHANNELS x IMG_H x IMG_W input all the way
9 |     down to a single scalar value, we instead predict a grid of values,
10 |     where each grid cell is the discriminator's prediction of how likely
11 |     the image patch corresponding to that cell
12 |     is to be real.
13 | """
14 |
15 | def __init__(self, im_channels=3,
16 | conv_channels=[64, 128, 256],
17 | kernels=[4,4,4,4],
18 | strides=[2,2,2,1],
19 | paddings=[1,1,1,1]):
20 | super().__init__()
21 | self.im_channels = im_channels
22 | activation = nn.LeakyReLU(0.2)
23 | layers_dim = [self.im_channels] + conv_channels + [1]
24 | self.layers = nn.ModuleList([
25 | nn.Sequential(
26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1],
27 | kernel_size=kernels[i],
28 | stride=strides[i],
29 | padding=paddings[i],
30 | bias=False if i !=0 else True),
31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
32 | activation if i != len(layers_dim) - 2 else nn.Identity()
33 | )
34 | for i in range(len(layers_dim) - 1)
35 | ])
36 |
37 | def forward(self, x):
38 | out = x
39 | for layer in self.layers:
40 | out = layer(out)
41 | return out
42 |
43 |
44 | if __name__ == '__main__':
45 | x = torch.randn((2,3, 256, 256))
46 | prob = Discriminator(im_channels=3)(x)
47 | print(prob.shape)
48 |
--------------------------------------------------------------------------------
/config/mnist.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/mnist/train/images'
3 | im_channels : 1
4 | im_size : 28
5 | name: 'mnist'
6 |
7 | diffusion_params:
8 | num_timesteps : 1000
9 | beta_start : 0.0015
10 | beta_end : 0.0195
11 |
12 | ldm_params:
13 | down_channels: [ 128, 256, 256, 256]
14 | mid_channels: [ 256, 256]
15 | down_sample: [ False, False, False ]
16 | attn_down : [True, True, True]
17 | time_emb_dim: 256
18 | norm_channels : 32
19 | num_heads : 16
20 | conv_out_channels : 128
21 | num_down_layers: 2
22 | num_mid_layers: 2
23 | num_up_layers: 2
24 |
25 | autoencoder_params:
26 | z_channels: 3
27 | codebook_size : 20
28 | down_channels : [32, 64, 128]
29 | mid_channels : [128, 128]
30 | down_sample : [True, True]
31 | attn_down : [False, False]
32 | norm_channels: 32
33 | num_heads: 16
34 | num_down_layers : 1
35 | num_mid_layers : 1
36 | num_up_layers : 1
37 |
38 | train_params:
39 | seed : 1111
40 | task_name: 'mnist'
41 | ldm_batch_size: 64
42 | autoencoder_batch_size: 64
43 | disc_start: 1000
44 | disc_weight: 0.5
45 | codebook_weight: 1
46 | commitment_beta: 0.2
47 | perceptual_weight: 1
48 | kl_weight: 0.000005
49 | ldm_epochs : 100
50 | autoencoder_epochs : 10
51 | num_samples : 25
52 | num_grid_rows : 5
53 | ldm_lr: 0.00001
54 | autoencoder_lr: 0.0001
55 | autoencoder_acc_steps : 1
56 | autoencoder_img_save_steps : 8
57 | save_latents : False
58 | vae_latent_dir_name : 'vae_latents'
59 | vqvae_latent_dir_name : 'vqvae_latents'
60 | ldm_ckpt_name: 'ddpm_ckpt.pth'
61 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
62 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
63 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
64 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
65 |
--------------------------------------------------------------------------------
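All of the YAML files in config/ are read the same way by the tools: yaml.safe_load produces a nested dict whose top-level sections (dataset_params, diffusion_params, ldm_params, autoencoder_params, train_params) are then pulled apart. A minimal sketch, using keys from the MNIST config above:

    import yaml

    # Load the MNIST config and pull out the sections the tools use
    with open('config/mnist.yaml', 'r') as f:
        config = yaml.safe_load(f)

    diffusion_config = config['diffusion_params']
    autoencoder_config = config['autoencoder_params']

    print(diffusion_config['num_timesteps'])       # 1000
    print(autoencoder_config['down_sample'])       # [True, True]
    print(config['train_params']['task_name'])     # 'mnist'
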
/config/celebhq.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/CelebAMask-HQ'
3 | im_channels : 3
4 | im_size : 256
5 | name: 'celebhq'
6 |
7 | diffusion_params:
8 | num_timesteps : 1000
9 | beta_start : 0.0015
10 | beta_end : 0.0195
11 |
12 | ldm_params:
13 | down_channels: [ 256, 384, 512, 768 ]
14 | mid_channels: [ 768, 512 ]
15 | down_sample: [ True, True, True ]
16 | attn_down : [True, True, True]
17 | time_emb_dim: 512
18 | norm_channels: 32
19 | num_heads: 16
20 | conv_out_channels : 128
21 | num_down_layers : 2
22 | num_mid_layers : 2
23 | num_up_layers : 2
24 |
25 | autoencoder_params:
26 | z_channels: 4
27 | codebook_size : 8192
28 | down_channels : [64, 128, 256, 256]
29 | mid_channels : [256, 256]
30 | down_sample : [True, True, True]
31 | attn_down : [False, False, False]
32 | norm_channels: 32
33 | num_heads: 4
34 | num_down_layers : 2
35 | num_mid_layers : 2
36 | num_up_layers : 2
37 |
38 |
39 | train_params:
40 | seed : 1111
41 | task_name: 'celebhq'
42 | ldm_batch_size: 16
43 | autoencoder_batch_size: 4
44 | disc_start: 15000
45 | disc_weight: 0.5
46 | codebook_weight: 1
47 | commitment_beta: 0.2
48 | perceptual_weight: 1
49 | kl_weight: 0.000005
50 | ldm_epochs: 100
51 | autoencoder_epochs: 20
52 | num_samples: 1
53 | num_grid_rows: 1
54 | ldm_lr: 0.000005
55 | autoencoder_lr: 0.00001
56 | autoencoder_acc_steps: 4
57 | autoencoder_img_save_steps: 64
58 | save_latents : False
59 | vae_latent_dir_name: 'vae_latents'
60 | vqvae_latent_dir_name: 'vqvae_latents'
61 | ldm_ckpt_name: 'ddpm_ckpt.pth'
62 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
63 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
64 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
65 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
66 |
--------------------------------------------------------------------------------
/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import glob
3 | import os
4 | import torch
5 |
6 |
7 | def load_latents(latent_path):
8 | r"""
9 |     Simple utility to load the latents that were saved to speed up ldm training
10 | :param latent_path:
11 | :return:
12 | """
13 | latent_maps = {}
14 | for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
15 | s = pickle.load(open(fname, 'rb'))
16 | for k, v in s.items():
17 | latent_maps[k] = v[0]
18 | return latent_maps
19 |
20 |
21 | def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob):
22 | if text_drop_prob > 0:
23 | text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0,
24 | 1) < text_drop_prob
25 | assert empty_text_embed is not None, ("Text Conditioning required as well as"
26 | " text dropping but empty text representation not created")
27 | text_embed[text_drop_mask, :, :] = empty_text_embed[0]
28 | return text_embed
29 |
30 |
31 | def drop_image_condition(image_condition, im, im_drop_prob):
32 | if im_drop_prob > 0:
33 | im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0,
34 | 1) > im_drop_prob
35 | return image_condition * im_drop_mask
36 | else:
37 | return image_condition
38 |
39 |
40 | def drop_class_condition(class_condition, class_drop_prob, im):
41 | if class_drop_prob > 0:
42 | class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0,
43 | 1) > class_drop_prob
44 | return class_condition * class_drop_mask
45 | else:
46 | return class_condition
--------------------------------------------------------------------------------
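A small illustration (not from the repository) of how the condition-dropping helpers above are meant to be used when preparing classifier-free-guidance style training inputs. The tensors below are random placeholders standing in for a latent batch, per-sample text embeddings, the empty-string embedding, and a mask condition; the drop probabilities are arbitrary.

    import torch
    from utils.diffusion_utils import drop_text_condition, drop_image_condition

    B = 4
    im = torch.randn(B, 3, 32, 32)                 # latent/image batch (placeholder)
    text_embed = torch.randn(B, 77, 512)           # per-sample text embeddings
    empty_text_embed = torch.randn(1, 77, 512)     # embedding of the empty string
    mask_cond = torch.randn(B, 3, 256, 256)        # image condition (e.g. masks)

    # Randomly replace some text embeddings with the empty-text embedding
    text_embed = drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob=0.1)

    # Randomly zero out the image condition for some samples
    mask_cond = drop_image_condition(mask_cond, im, im_drop_prob=0.1)
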
/config/mnist_class_cond.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/mnist/train/images'
3 | im_channels : 1
4 | im_size : 28
5 | name: 'mnist'
6 |
7 | diffusion_params:
8 | num_timesteps : 1000
9 | beta_start : 0.0015
10 | beta_end : 0.0195
11 |
12 | ldm_params:
13 | down_channels: [ 128, 256, 256, 256]
14 | mid_channels: [ 256, 256]
15 | down_sample: [ False, False, False ]
16 | attn_down : [True, True, True]
17 | time_emb_dim: 256
18 | norm_channels : 32
19 | num_heads : 16
20 | conv_out_channels : 128
21 | num_down_layers: 2
22 | num_mid_layers: 2
23 | num_up_layers: 2
24 | condition_config:
25 | condition_types: ['class']
26 | class_condition_config :
27 | num_classes : 10
28 | cond_drop_prob : 0.1
29 |
30 | autoencoder_params:
31 | z_channels: 3
32 | codebook_size : 20
33 | down_channels : [32, 64, 128]
34 | mid_channels : [128, 128]
35 | down_sample : [True, True]
36 | attn_down : [False, False]
37 | norm_channels: 32
38 | num_heads: 16
39 | num_down_layers : 1
40 | num_mid_layers : 1
41 | num_up_layers : 1
42 |
43 | train_params:
44 | seed : 1111
45 | task_name: 'mnist'
46 | ldm_batch_size: 64
47 | autoencoder_batch_size: 64
48 | disc_start: 1000
49 | disc_weight: 0.5
50 | codebook_weight: 1
51 | commitment_beta: 0.2
52 | perceptual_weight: 1
53 | kl_weight: 0.000005
54 | ldm_epochs : 100
55 | autoencoder_epochs : 10
56 | num_samples : 25
57 | num_grid_rows : 5
58 | ldm_lr: 0.00001
59 | autoencoder_lr: 0.0001
60 | autoencoder_acc_steps : 1
61 | autoencoder_img_save_steps : 8
62 | save_latents : False
63 | cf_guidance_scale : 1.0
64 | vae_latent_dir_name : 'vae_latents'
65 | vqvae_latent_dir_name : 'vqvae_latents'
66 | ldm_ckpt_name: 'ddpm_ckpt_class_cond.pth'
67 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
68 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
69 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
70 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
71 |
--------------------------------------------------------------------------------
/config/celebhq_text_cond.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/CelebAMask-HQ'
3 | im_channels : 3
4 | im_size : 256
5 | name: 'celebhq'
6 |
7 | diffusion_params:
8 | num_timesteps : 1000
9 | beta_start : 0.00085
10 | beta_end : 0.012
11 |
12 | ldm_params:
13 | down_channels: [ 256, 384, 512, 768 ]
14 | mid_channels: [ 768, 512 ]
15 | down_sample: [ True, True, True ]
16 | attn_down : [True, True, True]
17 | time_emb_dim: 512
18 | norm_channels: 32
19 | num_heads: 16
20 | conv_out_channels : 128
21 | num_down_layers : 2
22 | num_mid_layers : 2
23 | num_up_layers : 2
24 | condition_config:
25 | condition_types: [ 'text' ]
26 | text_condition_config:
27 | text_embed_model: 'clip'
28 | train_text_embed_model: False
29 | text_embed_dim: 512
30 | cond_drop_prob: 0.1
31 |
32 | autoencoder_params:
33 | z_channels: 4
34 | codebook_size : 8192
35 | down_channels : [64, 128, 256, 256]
36 | mid_channels : [256, 256]
37 | down_sample : [True, True, True]
38 | attn_down : [False, False, False]
39 | norm_channels: 32
40 | num_heads: 4
41 | num_down_layers : 2
42 | num_mid_layers : 2
43 | num_up_layers : 2
44 |
45 |
46 | train_params:
47 | seed : 1111
48 | task_name: 'celebhq'
49 | ldm_batch_size: 16
50 | autoencoder_batch_size: 4
51 | disc_start: 15000
52 | disc_weight: 0.5
53 | codebook_weight: 1
54 | commitment_beta: 0.2
55 | perceptual_weight: 1
56 | kl_weight: 0.000005
57 | ldm_epochs: 100
58 | autoencoder_epochs: 20
59 | num_samples: 1
60 | num_grid_rows: 1
61 | ldm_lr: 0.000005
62 | autoencoder_lr: 0.00001
63 | autoencoder_acc_steps: 4
64 | autoencoder_img_save_steps: 64
65 | save_latents : False
66 | cf_guidance_scale : 1.0
67 | vae_latent_dir_name: 'vae_latents'
68 | vqvae_latent_dir_name: 'vqvae_latents'
69 | ldm_ckpt_name: 'ddpm_ckpt_text_cond_clip.pth'
70 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
71 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
72 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
73 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
74 |
--------------------------------------------------------------------------------
/config/celebhq_text_cond_bert.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/CelebAMask-HQ'
3 | im_channels : 3
4 | im_size : 256
5 | name: 'celebhq'
6 |
7 | diffusion_params:
8 | num_timesteps : 1000
9 | beta_start : 0.00085
10 | beta_end : 0.012
11 |
12 | ldm_params:
13 | down_channels: [ 256, 384, 512, 768 ]
14 | mid_channels: [ 768, 512 ]
15 | down_sample: [ True, True, True ]
16 | attn_down : [True, True, True]
17 | time_emb_dim: 512
18 | norm_channels: 32
19 | num_heads: 16
20 | conv_out_channels : 128
21 | num_down_layers : 2
22 | num_mid_layers : 2
23 | num_up_layers : 2
24 | condition_config:
25 | condition_types: [ 'text' ]
26 | text_condition_config:
27 | text_embed_model: 'bert'
28 | train_text_embed_model: False
29 | text_embed_dim: 768
30 | cond_drop_prob: 0.1
31 |
32 | autoencoder_params:
33 | z_channels: 4
34 | codebook_size : 8192
35 | down_channels : [64, 128, 256, 256]
36 | mid_channels : [256, 256]
37 | down_sample : [True, True, True]
38 | attn_down : [False, False, False]
39 | norm_channels: 32
40 | num_heads: 4
41 | num_down_layers : 2
42 | num_mid_layers : 2
43 | num_up_layers : 2
44 |
45 |
46 | train_params:
47 | seed : 1111
48 | task_name: 'celebhq'
49 | ldm_batch_size: 16
50 | autoencoder_batch_size: 4
51 | disc_start: 15000
52 | disc_weight: 0.5
53 | codebook_weight: 1
54 | commitment_beta: 0.2
55 | perceptual_weight: 1
56 | kl_weight: 0.000005
57 | ldm_epochs: 100
58 | autoencoder_epochs: 20
59 | num_samples: 1
60 | num_grid_rows: 1
61 | ldm_lr: 0.000005
62 | autoencoder_lr: 0.00001
63 | autoencoder_acc_steps: 4
64 | autoencoder_img_save_steps: 64
65 | save_latents : False
66 | cf_guidance_scale : 1.0
67 | vae_latent_dir_name: 'vae_latents'
68 | vqvae_latent_dir_name: 'vqvae_latents'
69 | ldm_ckpt_name: 'ddpm_ckpt_text_cond_bert.pth'
70 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
71 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
72 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
73 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
74 |
--------------------------------------------------------------------------------
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 |
2 | def validate_class_config(condition_config):
3 | assert 'class_condition_config' in condition_config, \
4 | "Class conditioning desired but class condition config missing"
5 | assert 'num_classes' in condition_config['class_condition_config'], \
6 | "num_class missing in class condition config"
7 |
8 |
9 | def validate_text_config(condition_config):
10 | assert 'text_condition_config' in condition_config, \
11 | "Text conditioning desired but text condition config missing"
12 | assert 'text_embed_dim' in condition_config['text_condition_config'], \
13 | "text_embed_dim missing in text condition config"
14 |
15 |
16 | def validate_image_config(condition_config):
17 | assert 'image_condition_config' in condition_config, \
18 | "Image conditioning desired but image condition config missing"
19 | assert 'image_condition_input_channels' in condition_config['image_condition_config'], \
20 | "image_condition_input_channels missing in image condition config"
21 | assert 'image_condition_output_channels' in condition_config['image_condition_config'], \
22 | "image_condition_output_channels missing in image condition config"
23 |
24 |
25 | def validate_image_conditional_input(cond_input, x):
26 | assert 'image' in cond_input, \
27 | "Model initialized with image conditioning but cond_input has no image information"
28 | assert cond_input['image'].shape[0] == x.shape[0], \
29 | "Batch size mismatch of image condition and input"
30 | assert cond_input['image'].shape[2] % x.shape[2] == 0, \
31 | "Height/Width of image condition must be divisible by latent input"
32 |
33 |
34 | def validate_class_conditional_input(cond_input, x, num_classes):
35 | assert 'class' in cond_input, \
36 | "Model initialized with class conditioning but cond_input has no class information"
37 | assert cond_input['class'].shape == (x.shape[0], num_classes), \
38 |         "Shape of class condition input must match (Batch Size, Number of Classes)"
39 |
40 |
41 | def get_config_value(config, key, default_value):
42 |     return config[key] if key in config else default_value
--------------------------------------------------------------------------------
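A quick sketch (illustrative only) of how these validators and get_config_value are typically used when parsing a condition_config block like the one under ldm_params in celebhq_text_cond.yaml:

    from utils.config_utils import get_config_value, validate_text_config

    # condition_config as it appears under ldm_params in config/celebhq_text_cond.yaml
    condition_config = {
        'condition_types': ['text'],
        'text_condition_config': {
            'text_embed_model': 'clip',
            'text_embed_dim': 512,
            'cond_drop_prob': 0.1,
        }
    }

    condition_types = get_config_value(condition_config, 'condition_types', [])
    if 'text' in condition_types:
        # Raises an AssertionError if text_condition_config or text_embed_dim is missing
        validate_text_config(condition_config)
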
/config/celebhq_text_image_cond.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/CelebAMask-HQ'
3 | im_channels : 3
4 | im_size : 256
5 | name: 'celebhq'
6 |
7 | diffusion_params:
8 | num_timesteps : 1000
9 | beta_start: 0.00085
10 | beta_end: 0.012
11 |
12 | ldm_params:
13 | down_channels: [ 256, 384, 512, 768 ]
14 | mid_channels: [ 768, 512 ]
15 | down_sample: [ True, True, True ]
16 | attn_down : [True, True, True]
17 | time_emb_dim: 512
18 | norm_channels: 32
19 | num_heads: 16
20 | conv_out_channels : 128
21 | num_down_layers : 2
22 | num_mid_layers : 2
23 | num_up_layers : 2
24 | condition_config:
25 | condition_types: [ 'text', 'image' ]
26 | text_condition_config:
27 | text_embed_model: 'clip'
28 | train_text_embed_model: False
29 | text_embed_dim: 512
30 | cond_drop_prob: 0.1
31 | image_condition_config:
32 | image_condition_input_channels: 18
33 | image_condition_output_channels: 3
34 | image_condition_h : 512
35 | image_condition_w : 512
36 | cond_drop_prob: 0.1
37 |
38 |
39 | autoencoder_params:
40 | z_channels: 4
41 | codebook_size : 8192
42 | down_channels : [64, 128, 256, 256]
43 | mid_channels : [256, 256]
44 | down_sample : [True, True, True]
45 | attn_down : [False, False, False]
46 | norm_channels: 32
47 | num_heads: 4
48 | num_down_layers : 2
49 | num_mid_layers : 2
50 | num_up_layers : 2
51 |
52 |
53 | train_params:
54 | seed : 1111
55 | task_name: 'celebhq'
56 | ldm_batch_size: 16
57 | autoencoder_batch_size: 4
58 | disc_start: 15000
59 | disc_weight: 0.5
60 | codebook_weight: 1
61 | commitment_beta: 0.2
62 | perceptual_weight: 1
63 | kl_weight: 0.000005
64 | ldm_epochs: 100
65 | autoencoder_epochs: 20
66 | num_samples: 1
67 | num_grid_rows: 1
68 | ldm_lr: 0.000005
69 | autoencoder_lr: 0.00001
70 | autoencoder_acc_steps: 4
71 | autoencoder_img_save_steps: 64
72 | save_latents : False
73 | cf_guidance_scale : 1.0
74 | vae_latent_dir_name: 'vae_latents'
75 | vqvae_latent_dir_name: 'vqvae_latents'
76 | ldm_ckpt_name: 'ddpm_ckpt_text_image_cond_clip.pth'
77 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
78 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
79 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
80 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
81 |
--------------------------------------------------------------------------------
/scheduler/linear_noise_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class LinearNoiseScheduler:
6 | r"""
7 | Class for the linear noise scheduler that is used in DDPM.
8 | """
9 |
10 | def __init__(self, num_timesteps, beta_start, beta_end):
11 | self.num_timesteps = num_timesteps
12 | self.beta_start = beta_start
13 | self.beta_end = beta_end
14 | # Mimicking how compvis repo creates schedule
15 | self.betas = (
16 | torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
17 | )
18 | self.alphas = 1. - self.betas
19 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
20 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
21 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
22 |
23 | def add_noise(self, original, noise, t):
24 | r"""
25 | Forward method for diffusion
26 | :param original: Image on which noise is to be applied
27 | :param noise: Random Noise Tensor (from normal dist)
28 | :param t: timestep of the forward process of shape -> (B,)
29 | :return:
30 | """
31 | original_shape = original.shape
32 | batch_size = original_shape[0]
33 |
34 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
35 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
36 |
37 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
38 | for _ in range(len(original_shape) - 1):
39 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
40 | for _ in range(len(original_shape) - 1):
41 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
42 |
43 | # Apply and Return Forward process equation
44 | return (sqrt_alpha_cum_prod.to(original.device) * original
45 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
46 |
47 | def sample_prev_timestep(self, xt, noise_pred, t):
48 | r"""
49 | Use the noise prediction by model to get
50 |         xt-1 using xt and the noise predicted
51 | :param xt: current timestep sample
52 | :param noise_pred: model noise prediction
53 | :param t: current timestep we are at
54 | :return:
55 | """
56 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
57 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
58 | x0 = torch.clamp(x0, -1., 1.)
59 |
60 | mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
61 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
62 |
63 | if t == 0:
64 | return mean, x0
65 | else:
66 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
67 | variance = variance * self.betas.to(xt.device)[t]
68 | sigma = variance ** 0.5
69 | z = torch.randn(xt.shape).to(xt.device)
70 |
71 | # OR
72 | # variance = self.betas[t]
73 | # sigma = variance ** 0.5
74 | # z = torch.randn(xt.shape).to(xt.device)
75 | return mean + sigma * z, x0
76 |
--------------------------------------------------------------------------------
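Usage sketch for the scheduler (not part of the repository): the shapes below mirror what the training and sampling tools do, i.e. a (B,) timestep tensor for the forward process and a scalar timestep for the reverse step. The noise prediction here is a random stand-in for the U-Net output.

    import torch
    from scheduler.linear_noise_scheduler import LinearNoiseScheduler

    scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.0015, beta_end=0.0195)

    # Forward process: noise a batch of 4 latents at random timesteps
    x0 = torch.randn(4, 3, 7, 7)
    noise = torch.randn_like(x0)
    t = torch.randint(0, 1000, (4,))
    xt = scheduler.add_noise(x0, noise, t)

    # Reverse process: one denoising step at a single timestep
    noise_pred = torch.randn_like(xt)   # stand-in for the U-Net's prediction
    xt_prev, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(999))
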
/dataset/mnist_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import torchvision
4 | from PIL import Image
5 | from tqdm import tqdm
6 | from utils.diffusion_utils import load_latents
7 | from torch.utils.data.dataset import Dataset
8 |
9 |
10 | class MnistDataset(Dataset):
11 | r"""
12 | Nothing special here. Just a simple dataset class for mnist images.
13 |     Created a dataset class rather than using torchvision to allow
14 | replacement with any other image dataset
15 | """
16 |
17 | def __init__(self, split, im_path, im_size, im_channels,
18 | use_latents=False, latent_path=None, condition_config=None):
19 | r"""
20 | Init method for initializing the dataset properties
21 | :param split: train/test to locate the image files
22 | :param im_path: root folder of images
23 |         :param im_size: size of the images
24 |         :param im_channels: number of channels in the images
25 | """
26 | self.split = split
27 | self.im_size = im_size
28 | self.im_channels = im_channels
29 |
30 | # Should we use latents or not
31 | self.latent_maps = None
32 | self.use_latents = False
33 |
34 | # Conditioning for the dataset
35 | self.condition_types = [] if condition_config is None else condition_config['condition_types']
36 |
37 | self.images, self.labels = self.load_images(im_path)
38 |
39 | # Whether to load images and call vae or to load latents
40 | if use_latents and latent_path is not None:
41 | latent_maps = load_latents(latent_path)
42 | if len(latent_maps) == len(self.images):
43 | self.use_latents = True
44 | self.latent_maps = latent_maps
45 | print('Found {} latents'.format(len(self.latent_maps)))
46 | else:
47 | print('Latents not found')
48 |
49 | def load_images(self, im_path):
50 | r"""
51 | Gets all images from the path specified
52 | and stacks them all up
53 | :param im_path:
54 | :return:
55 | """
56 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
57 | ims = []
58 | labels = []
59 | for d_name in tqdm(os.listdir(im_path)):
60 | fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png')))
61 | fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg')))
62 | fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg')))
63 | for fname in fnames:
64 | ims.append(fname)
65 | if 'class' in self.condition_types:
66 | labels.append(int(d_name))
67 | print('Found {} images for split {}'.format(len(ims), self.split))
68 | return ims, labels
69 |
70 | def __len__(self):
71 | return len(self.images)
72 |
73 | def __getitem__(self, index):
74 | ######## Set Conditioning Info ########
75 | cond_inputs = {}
76 | if 'class' in self.condition_types:
77 | cond_inputs['class'] = self.labels[index]
78 | #######################################
79 |
80 | if self.use_latents:
81 | latent = self.latent_maps[self.images[index]]
82 | if len(self.condition_types) == 0:
83 | return latent
84 | else:
85 | return latent, cond_inputs
86 | else:
87 | im = Image.open(self.images[index])
88 | im_tensor = torchvision.transforms.ToTensor()(im)
89 |
90 | # Convert input to -1 to 1 range.
91 | im_tensor = (2 * im_tensor) - 1
92 | if len(self.condition_types) == 0:
93 | return im_tensor
94 | else:
95 | return im_tensor, cond_inputs
96 |
97 |
--------------------------------------------------------------------------------
/models/unet_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.blocks import get_time_embedding
4 | from models.blocks import DownBlock, MidBlock, UpBlockUnet
5 |
6 |
7 | class Unet(nn.Module):
8 | r"""
9 | Unet model comprising
10 |     Down blocks, Midblocks and Upblocks
11 | """
12 |
13 | def __init__(self, im_channels, model_config):
14 | super().__init__()
15 | self.down_channels = model_config['down_channels']
16 | self.mid_channels = model_config['mid_channels']
17 | self.t_emb_dim = model_config['time_emb_dim']
18 | self.down_sample = model_config['down_sample']
19 | self.num_down_layers = model_config['num_down_layers']
20 | self.num_mid_layers = model_config['num_mid_layers']
21 | self.num_up_layers = model_config['num_up_layers']
22 | self.attns = model_config['attn_down']
23 | self.norm_channels = model_config['norm_channels']
24 | self.num_heads = model_config['num_heads']
25 | self.conv_out_channels = model_config['conv_out_channels']
26 |
27 | assert self.mid_channels[0] == self.down_channels[-1]
28 | assert self.mid_channels[-1] == self.down_channels[-2]
29 | assert len(self.down_sample) == len(self.down_channels) - 1
30 | assert len(self.attns) == len(self.down_channels) - 1
31 |
32 | # Initial projection from sinusoidal time embedding
33 | self.t_proj = nn.Sequential(
34 | nn.Linear(self.t_emb_dim, self.t_emb_dim),
35 | nn.SiLU(),
36 | nn.Linear(self.t_emb_dim, self.t_emb_dim)
37 | )
38 |
39 | self.up_sample = list(reversed(self.down_sample))
40 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)
41 |
42 | self.downs = nn.ModuleList([])
43 | for i in range(len(self.down_channels) - 1):
44 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
45 | down_sample=self.down_sample[i],
46 | num_heads=self.num_heads,
47 | num_layers=self.num_down_layers,
48 | attn=self.attns[i], norm_channels=self.norm_channels))
49 |
50 | self.mids = nn.ModuleList([])
51 | for i in range(len(self.mid_channels) - 1):
52 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
53 | num_heads=self.num_heads,
54 | num_layers=self.num_mid_layers,
55 | norm_channels=self.norm_channels))
56 |
57 | self.ups = nn.ModuleList([])
58 | for i in reversed(range(len(self.down_channels) - 1)):
59 | self.ups.append(UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
60 | self.t_emb_dim, up_sample=self.down_sample[i],
61 | num_heads=self.num_heads,
62 | num_layers=self.num_up_layers,
63 | norm_channels=self.norm_channels))
64 |
65 | self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
66 | self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)
67 |
68 | def forward(self, x, t):
69 | # Shapes assuming downblocks are [C1, C2, C3, C4]
70 | # Shapes assuming midblocks are [C4, C4, C3]
71 | # Shapes assuming downsamples are [True, True, False]
72 | # B x C x H x W
73 | out = self.conv_in(x)
74 | # B x C1 x H x W
75 |
76 | # t_emb -> B x t_emb_dim
77 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
78 | t_emb = self.t_proj(t_emb)
79 |
80 | down_outs = []
81 |
82 | for idx, down in enumerate(self.downs):
83 | down_outs.append(out)
84 | out = down(out, t_emb)
85 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
86 | # out B x C4 x H/4 x W/4
87 |
88 | for mid in self.mids:
89 | out = mid(out, t_emb)
90 | # out B x C3 x H/4 x W/4
91 |
92 | for up in self.ups:
93 | down_out = down_outs.pop()
94 | out = up(out, down_out, t_emb)
95 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
96 | out = self.norm_out(out)
97 | out = nn.SiLU()(out)
98 | out = self.conv_out(out)
99 | # out B x C x H x W
100 | return out
101 |
--------------------------------------------------------------------------------
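A shape-check sketch for the base U-Net (illustrative, not from the repository). The model_config copies ldm_params from config/mnist.yaml, im_channels is set to the MNIST VQVAE's z_channels, and the 7x7 spatial size matches the MNIST latent (28 downsampled twice by the autoencoder).

    import torch
    from models.unet_base import Unet

    # ldm_params from config/mnist.yaml
    model_config = {
        'down_channels': [128, 256, 256, 256],
        'mid_channels': [256, 256],
        'down_sample': [False, False, False],
        'attn_down': [True, True, True],
        'time_emb_dim': 256,
        'norm_channels': 32,
        'num_heads': 16,
        'conv_out_channels': 128,
        'num_down_layers': 2,
        'num_mid_layers': 2,
        'num_up_layers': 2,
    }

    model = Unet(im_channels=3, model_config=model_config)  # 3 = z_channels of the MNIST VQVAE
    x = torch.randn(2, 3, 7, 7)                             # batch of noisy latents
    t = torch.randint(0, 1000, (2,))                        # one timestep per sample
    noise_pred = model(x, t)
    print(noise_pred.shape)  # torch.Size([2, 3, 7, 7])
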
/tools/sample_ddpm_vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | from torchvision.utils import make_grid
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from models.unet_base import Unet
10 | from models.vqvae import VQVAE
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 |
16 | def sample(model, scheduler, train_config, diffusion_model_config,
17 | autoencoder_model_config, diffusion_config, dataset_config, vae):
18 | r"""
19 | Sample stepwise by going backward one timestep at a time.
20 | We save the x0 predictions
21 | """
22 | im_size = dataset_config['im_size'] // 2**sum(autoencoder_model_config['down_sample'])
23 | xt = torch.randn((train_config['num_samples'],
24 | autoencoder_model_config['z_channels'],
25 | im_size,
26 | im_size)).to(device)
27 |
28 | save_count = 0
29 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
30 | # Get prediction of noise
31 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
32 |
33 | # Use scheduler to get x0 and xt-1
34 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
35 |
36 | # Save x0
37 | #ims = torch.clamp(xt, -1., 1.).detach().cpu()
38 | if i == 0:
39 |             # Decode ONLY the final image to save time
40 | ims = vae.decode(xt)
41 | else:
42 | ims = xt
43 |
44 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
45 | ims = (ims + 1) / 2
46 | grid = make_grid(ims, nrow=train_config['num_grid_rows'])
47 | img = torchvision.transforms.ToPILImage()(grid)
48 |
49 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
50 | os.mkdir(os.path.join(train_config['task_name'], 'samples'))
51 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
52 | img.close()
53 |
54 |
55 | def infer(args):
56 | # Read the config file #
57 | with open(args.config_path, 'r') as file:
58 | try:
59 | config = yaml.safe_load(file)
60 | except yaml.YAMLError as exc:
61 | print(exc)
62 | print(config)
63 | ########################
64 |
65 | diffusion_config = config['diffusion_params']
66 | dataset_config = config['dataset_params']
67 | diffusion_model_config = config['ldm_params']
68 | autoencoder_model_config = config['autoencoder_params']
69 | train_config = config['train_params']
70 |
71 | # Create the noise scheduler
72 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
73 | beta_start=diffusion_config['beta_start'],
74 | beta_end=diffusion_config['beta_end'])
75 |
76 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
77 | model_config=diffusion_model_config).to(device)
78 | model.eval()
79 | if os.path.exists(os.path.join(train_config['task_name'],
80 | train_config['ldm_ckpt_name'])):
81 | print('Loaded unet checkpoint')
82 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
83 | train_config['ldm_ckpt_name']),
84 | map_location=device))
85 | # Create output directories
86 | if not os.path.exists(train_config['task_name']):
87 | os.mkdir(train_config['task_name'])
88 |
89 | vae = VQVAE(im_channels=dataset_config['im_channels'],
90 | model_config=autoencoder_model_config).to(device)
91 | vae.eval()
92 |
93 | # Load vae if found
94 | if os.path.exists(os.path.join(train_config['task_name'],
95 | train_config['vqvae_autoencoder_ckpt_name'])):
96 | print('Loaded vae checkpoint')
97 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
98 | train_config['vqvae_autoencoder_ckpt_name']),
99 | map_location=device), strict=True)
100 | with torch.no_grad():
101 | sample(model, scheduler, train_config, diffusion_model_config,
102 | autoencoder_model_config, diffusion_config, dataset_config, vae)
103 |
104 |
105 | if __name__ == '__main__':
106 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
107 | parser.add_argument('--config', dest='config_path',
108 | default='config/mnist.yaml', type=str)
109 | args = parser.parse_args()
110 | infer(args)
111 |
--------------------------------------------------------------------------------
/tools/train_ddpm_vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.optim import Adam
8 | from dataset.mnist_dataset import MnistDataset
9 | from dataset.celeb_dataset import CelebDataset
10 | from torch.utils.data import DataLoader
11 | from models.unet_base import Unet
12 | from models.vqvae import VQVAE
13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 |
18 | def train(args):
19 | # Read the config file #
20 | with open(args.config_path, 'r') as file:
21 | try:
22 | config = yaml.safe_load(file)
23 | except yaml.YAMLError as exc:
24 | print(exc)
25 | print(config)
26 | ########################
27 |
28 | diffusion_config = config['diffusion_params']
29 | dataset_config = config['dataset_params']
30 | diffusion_model_config = config['ldm_params']
31 | autoencoder_model_config = config['autoencoder_params']
32 | train_config = config['train_params']
33 |
34 | # Create the noise scheduler
35 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
36 | beta_start=diffusion_config['beta_start'],
37 | beta_end=diffusion_config['beta_end'])
38 |
39 | im_dataset_cls = {
40 | 'mnist': MnistDataset,
41 | 'celebhq': CelebDataset,
42 | }.get(dataset_config['name'])
43 |
44 | im_dataset = im_dataset_cls(split='train',
45 | im_path=dataset_config['im_path'],
46 | im_size=dataset_config['im_size'],
47 | im_channels=dataset_config['im_channels'],
48 | use_latents=True,
49 | latent_path=os.path.join(train_config['task_name'],
50 | train_config['vqvae_latent_dir_name'])
51 | )
52 |
53 | data_loader = DataLoader(im_dataset,
54 | batch_size=train_config['ldm_batch_size'],
55 | shuffle=True)
56 |
57 | # Instantiate the model
58 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
59 | model_config=diffusion_model_config).to(device)
60 | model.train()
61 |
62 | # Load VAE ONLY if latents are not to be used or are missing
63 | if not im_dataset.use_latents:
64 | print('Loading vqvae model as latents not present')
65 | vae = VQVAE(im_channels=dataset_config['im_channels'],
66 | model_config=autoencoder_model_config).to(device)
67 | vae.eval()
68 | # Load vae if found
69 | if os.path.exists(os.path.join(train_config['task_name'],
70 | train_config['vqvae_autoencoder_ckpt_name'])):
71 | print('Loaded vae checkpoint')
72 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
73 | train_config['vqvae_autoencoder_ckpt_name']),
74 | map_location=device))
75 | # Specify training parameters
76 | num_epochs = train_config['ldm_epochs']
77 | optimizer = Adam(model.parameters(), lr=train_config['ldm_lr'])
78 | criterion = torch.nn.MSELoss()
79 |
80 | # Run training
81 | if not im_dataset.use_latents:
82 | for param in vae.parameters():
83 | param.requires_grad = False
84 |
85 | for epoch_idx in range(num_epochs):
86 | losses = []
87 | for im in tqdm(data_loader):
88 | optimizer.zero_grad()
89 | im = im.float().to(device)
90 | if not im_dataset.use_latents:
91 | with torch.no_grad():
92 | im, _ = vae.encode(im)
93 |
94 | # Sample random noise
95 | noise = torch.randn_like(im).to(device)
96 |
97 | # Sample timestep
98 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
99 |
100 | # Add noise to images according to timestep
101 | noisy_im = scheduler.add_noise(im, noise, t)
102 | noise_pred = model(noisy_im, t)
103 |
104 | loss = criterion(noise_pred, noise)
105 | losses.append(loss.item())
106 | loss.backward()
107 | optimizer.step()
108 | print('Finished epoch:{} | Loss : {:.4f}'.format(
109 | epoch_idx + 1,
110 | np.mean(losses)))
111 |
112 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
113 | train_config['ldm_ckpt_name']))
114 |
115 | print('Done Training ...')
116 |
117 |
118 | if __name__ == '__main__':
119 | parser = argparse.ArgumentParser(description='Arguments for ddpm training')
120 | parser.add_argument('--config', dest='config_path',
121 | default='config/mnist.yaml', type=str)
122 | args = parser.parse_args()
123 | train(args)
124 |
--------------------------------------------------------------------------------
/tools/infer_vqvae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import pickle
5 |
6 | import torch
7 | import torchvision
8 | import yaml
9 | from torch.utils.data.dataloader import DataLoader
10 | from torchvision.utils import make_grid
11 | from tqdm import tqdm
12 |
13 | from dataset.celeb_dataset import CelebDataset
14 | from dataset.mnist_dataset import MnistDataset
15 | from models.vqvae import VQVAE
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 |
20 | def infer(args):
21 | ######## Read the config file #######
22 | with open(args.config_path, 'r') as file:
23 | try:
24 | config = yaml.safe_load(file)
25 | except yaml.YAMLError as exc:
26 | print(exc)
27 | print(config)
28 |
29 | dataset_config = config['dataset_params']
30 | autoencoder_config = config['autoencoder_params']
31 | train_config = config['train_params']
32 |
33 | # Create the dataset
34 | im_dataset_cls = {
35 | 'mnist': MnistDataset,
36 | 'celebhq': CelebDataset,
37 | }.get(dataset_config['name'])
38 |
39 | im_dataset = im_dataset_cls(split='train',
40 | im_path=dataset_config['im_path'],
41 | im_size=dataset_config['im_size'],
42 | im_channels=dataset_config['im_channels'])
43 |
44 |     # This dataloader is only used for saving latents, which as of now
45 |     # is not done in batches, hence batch size 1
46 | data_loader = DataLoader(im_dataset,
47 | batch_size=1,
48 | shuffle=False)
49 |
50 | num_images = train_config['num_samples']
51 | ngrid = train_config['num_grid_rows']
52 |
53 | idxs = torch.randint(0, len(im_dataset) - 1, (num_images,))
54 | ims = torch.cat([im_dataset[idx][None, :] for idx in idxs]).float()
55 | ims = ims.to(device)
56 |
57 | model = VQVAE(im_channels=dataset_config['im_channels'],
58 | model_config=autoencoder_config).to(device)
59 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
60 | train_config['vqvae_autoencoder_ckpt_name']),
61 | map_location=device))
62 | model.eval()
63 |
64 | with torch.no_grad():
65 |
66 | encoded_output, _ = model.encode(ims)
67 | decoded_output = model.decode(encoded_output)
68 | encoded_output = torch.clamp(encoded_output, -1., 1.)
69 | encoded_output = (encoded_output + 1) / 2
70 | decoded_output = torch.clamp(decoded_output, -1., 1.)
71 | decoded_output = (decoded_output + 1) / 2
72 | ims = (ims + 1) / 2
73 |
74 | encoder_grid = make_grid(encoded_output.cpu(), nrow=ngrid)
75 | decoder_grid = make_grid(decoded_output.cpu(), nrow=ngrid)
76 | input_grid = make_grid(ims.cpu(), nrow=ngrid)
77 | encoder_grid = torchvision.transforms.ToPILImage()(encoder_grid)
78 | decoder_grid = torchvision.transforms.ToPILImage()(decoder_grid)
79 | input_grid = torchvision.transforms.ToPILImage()(input_grid)
80 |
81 | input_grid.save(os.path.join(train_config['task_name'], 'input_samples.png'))
82 | encoder_grid.save(os.path.join(train_config['task_name'], 'encoded_samples.png'))
83 | decoder_grid.save(os.path.join(train_config['task_name'], 'reconstructed_samples.png'))
84 |
85 | if train_config['save_latents']:
86 | # save Latents (but in a very unoptimized way)
87 | latent_path = os.path.join(train_config['task_name'], train_config['vqvae_latent_dir_name'])
88 | latent_fnames = glob.glob(os.path.join(train_config['task_name'], train_config['vqvae_latent_dir_name'],
89 | '*.pkl'))
90 | assert len(latent_fnames) == 0, 'Latents already present. Delete all latent files and re-run'
91 | if not os.path.exists(latent_path):
92 | os.mkdir(latent_path)
93 | print('Saving Latents for {}'.format(dataset_config['name']))
94 |
95 | fname_latent_map = {}
96 | part_count = 0
97 | count = 0
98 | for idx, im in enumerate(tqdm(data_loader)):
99 | encoded_output, _ = model.encode(im.float().to(device))
100 | fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu()
101 | # Save latents every 1000 images
102 | if (count+1) % 1000 == 0:
103 | pickle.dump(fname_latent_map, open(os.path.join(latent_path,
104 | '{}.pkl'.format(part_count)), 'wb'))
105 | part_count += 1
106 | fname_latent_map = {}
107 | count += 1
108 | if len(fname_latent_map) > 0:
109 | pickle.dump(fname_latent_map, open(os.path.join(latent_path,
110 | '{}.pkl'.format(part_count)), 'wb'))
111 | print('Done saving latents')
112 |
113 |
114 | if __name__ == '__main__':
115 | parser = argparse.ArgumentParser(description='Arguments for vq vae inference')
116 | parser.add_argument('--config', dest='config_path',
117 | default='config/mnist.yaml', type=str)
118 | args = parser.parse_args()
119 | infer(args)
120 |
--------------------------------------------------------------------------------
/models/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.blocks import DownBlock, MidBlock, UpBlock
4 |
5 |
6 | class VAE(nn.Module):
7 | def __init__(self, im_channels, model_config):
8 | super().__init__()
9 | self.down_channels = model_config['down_channels']
10 | self.mid_channels = model_config['mid_channels']
11 | self.down_sample = model_config['down_sample']
12 | self.num_down_layers = model_config['num_down_layers']
13 | self.num_mid_layers = model_config['num_mid_layers']
14 | self.num_up_layers = model_config['num_up_layers']
15 |
16 | # To disable attention in Downblock of Encoder and Upblock of Decoder
17 | self.attns = model_config['attn_down']
18 |
19 | # Latent Dimension
20 | self.z_channels = model_config['z_channels']
21 | self.norm_channels = model_config['norm_channels']
22 | self.num_heads = model_config['num_heads']
23 |
24 | # Assertion to validate the channel information
25 | assert self.mid_channels[0] == self.down_channels[-1]
26 | assert self.mid_channels[-1] == self.down_channels[-1]
27 | assert len(self.down_sample) == len(self.down_channels) - 1
28 | assert len(self.attns) == len(self.down_channels) - 1
29 |
30 | # Wherever we use downsampling in encoder correspondingly use
31 | # upsampling in decoder
32 | self.up_sample = list(reversed(self.down_sample))
33 |
34 | ##################### Encoder ######################
35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
36 |
37 | # Downblock + Midblock
38 | self.encoder_layers = nn.ModuleList([])
39 | for i in range(len(self.down_channels) - 1):
40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
41 | t_emb_dim=None, down_sample=self.down_sample[i],
42 | num_heads=self.num_heads,
43 | num_layers=self.num_down_layers,
44 | attn=self.attns[i],
45 | norm_channels=self.norm_channels))
46 |
47 | self.encoder_mids = nn.ModuleList([])
48 | for i in range(len(self.mid_channels) - 1):
49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
50 | t_emb_dim=None,
51 | num_heads=self.num_heads,
52 | num_layers=self.num_mid_layers,
53 | norm_channels=self.norm_channels))
54 |
55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1)
57 |
58 |         # Latent Dimension is 2*Latent because we are predicting mean & log variance
59 | self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1)
60 | ####################################################
61 |
62 |
63 | ##################### Decoder ######################
64 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
65 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
66 |
67 | # Midblock + Upblock
68 | self.decoder_mids = nn.ModuleList([])
69 | for i in reversed(range(1, len(self.mid_channels))):
70 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
71 | t_emb_dim=None,
72 | num_heads=self.num_heads,
73 | num_layers=self.num_mid_layers,
74 | norm_channels=self.norm_channels))
75 |
76 | self.decoder_layers = nn.ModuleList([])
77 | for i in reversed(range(1, len(self.down_channels))):
78 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
79 | t_emb_dim=None, up_sample=self.down_sample[i - 1],
80 | num_heads=self.num_heads,
81 | num_layers=self.num_up_layers,
82 | attn=self.attns[i - 1],
83 | norm_channels=self.norm_channels))
84 |
85 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
86 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
87 |
88 | def encode(self, x):
89 | out = self.encoder_conv_in(x)
90 | for idx, down in enumerate(self.encoder_layers):
91 | out = down(out)
92 | for mid in self.encoder_mids:
93 | out = mid(out)
94 | out = self.encoder_norm_out(out)
95 | out = nn.SiLU()(out)
96 | out = self.encoder_conv_out(out)
97 | out = self.pre_quant_conv(out)
98 | mean, logvar = torch.chunk(out, 2, dim=1)
99 | std = torch.exp(0.5 * logvar)
100 | sample = mean + std * torch.randn(mean.shape).to(device=x.device)
101 | return sample, out
102 |
103 | def decode(self, z):
104 | out = z
105 | out = self.post_quant_conv(out)
106 | out = self.decoder_conv_in(out)
107 | for mid in self.decoder_mids:
108 | out = mid(out)
109 | for idx, up in enumerate(self.decoder_layers):
110 | out = up(out)
111 |
112 | out = self.decoder_norm_out(out)
113 | out = nn.SiLU()(out)
114 | out = self.decoder_conv_out(out)
115 | return out
116 |
117 | def forward(self, x):
118 | z, encoder_output = self.encode(x)
119 | out = self.decode(z)
120 | return out, encoder_output
121 |
122 |
--------------------------------------------------------------------------------
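A round-trip sketch for the VAE (illustrative only): the model_config reuses autoencoder_params from config/celebhq.yaml, whose codebook_size key the VAE constructor above simply never reads. With three downsampling stages, a 256x256 input yields a 32x32 latent.

    import torch
    from models.vae import VAE

    # autoencoder_params from config/celebhq.yaml (codebook_size is unused by the VAE)
    model_config = {
        'z_channels': 4,
        'down_channels': [64, 128, 256, 256],
        'mid_channels': [256, 256],
        'down_sample': [True, True, True],
        'attn_down': [False, False, False],
        'norm_channels': 32,
        'num_heads': 4,
        'num_down_layers': 2,
        'num_mid_layers': 2,
        'num_up_layers': 2,
    }

    vae = VAE(im_channels=3, model_config=model_config)
    x = torch.randn(1, 3, 256, 256)

    with torch.no_grad():
        z, encoder_out = vae.encode(x)   # z: (1, 4, 32, 32); encoder_out holds mean/logvar
        recon = vae.decode(z)            # (1, 3, 256, 256)
    print(z.shape, recon.shape)
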
/dataset/celeb_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import torch
5 | import torchvision
6 | import numpy as np
7 | from PIL import Image
8 | from utils.diffusion_utils import load_latents
9 | from tqdm import tqdm
10 | from torch.utils.data.dataset import Dataset
11 |
12 |
13 | class CelebDataset(Dataset):
14 | r"""
15 | Celeb dataset will by default centre crop and resize the images.
16 |     This can be replaced by any other dataset, as long as all the images
17 |     are under one directory.
18 | """
19 |
20 | def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
21 | use_latents=False, latent_path=None, condition_config=None):
22 | self.split = split
23 | self.im_size = im_size
24 | self.im_channels = im_channels
25 | self.im_ext = im_ext
26 | self.im_path = im_path
27 | self.latent_maps = None
28 | self.use_latents = False
29 |
30 | self.condition_types = [] if condition_config is None else condition_config['condition_types']
31 |
32 | self.idx_to_cls_map = {}
33 |         self.cls_to_idx_map = {}
34 |
35 | if 'image' in self.condition_types:
36 | self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels']
37 | self.mask_h = condition_config['image_condition_config']['image_condition_h']
38 | self.mask_w = condition_config['image_condition_config']['image_condition_w']
39 |
40 | self.images, self.texts, self.masks = self.load_images(im_path)
41 |
42 | # Whether to load images or to load latents
43 | if use_latents and latent_path is not None:
44 | latent_maps = load_latents(latent_path)
45 | if len(latent_maps) == len(self.images):
46 | self.use_latents = True
47 | self.latent_maps = latent_maps
48 | print('Found {} latents'.format(len(self.latent_maps)))
49 | else:
50 | print('Latents not found')
51 |
52 | def load_images(self, im_path):
53 | r"""
54 | Gets all images from the path specified
55 | and stacks them all up
56 | """
57 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
58 | ims = []
59 | fnames = glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('png')))
60 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpg')))
61 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpeg')))
62 | texts = []
63 | masks = []
64 |
65 | if 'image' in self.condition_types:
66 | label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth',
67 | 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
68 | self.idx_to_cls_map = {idx: label_list[idx] for idx in range(len(label_list))}
69 | self.cls_to_idx_map = {label_list[idx]: idx for idx in range(len(label_list))}
70 |
71 | for fname in tqdm(fnames):
72 | ims.append(fname)
73 |
74 | if 'text' in self.condition_types:
75 | im_name = os.path.split(fname)[1].split('.')[0]
76 | captions_im = []
77 | with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f:
78 | for line in f.readlines():
79 | captions_im.append(line.strip())
80 | texts.append(captions_im)
81 |
82 | if 'image' in self.condition_types:
83 | im_name = int(os.path.split(fname)[1].split('.')[0])
84 | masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))
85 | if 'text' in self.condition_types:
86 | assert len(texts) == len(ims), "Condition Type Text but could not find captions for all images"
87 | if 'image' in self.condition_types:
88 | assert len(masks) == len(ims), "Condition Type Image but could not find masks for all images"
89 | print('Found {} images'.format(len(ims)))
90 | print('Found {} masks'.format(len(masks)))
91 | print('Found {} captions'.format(len(texts)))
92 | return ims, texts, masks
93 |
94 | def get_mask(self, index):
95 | r"""
96 | Method to get the mask of WxH
97 | for given index and convert it into
98 | Classes x W x H mask image
99 | :param index:
100 | :return:
101 | """
102 | mask_im = Image.open(self.masks[index])
103 | mask_im = np.array(mask_im)
104 | im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels))
105 | for orig_idx in range(len(self.idx_to_cls_map)):
106 | im_base[mask_im == (orig_idx+1), orig_idx] = 1
107 | mask = torch.from_numpy(im_base).permute(2, 0, 1).float()
108 | return mask
109 |
110 | def __len__(self):
111 | return len(self.images)
112 |
113 | def __getitem__(self, index):
114 | ######## Set Conditioning Info ########
115 | cond_inputs = {}
116 | if 'text' in self.condition_types:
117 | cond_inputs['text'] = random.sample(self.texts[index], k=1)[0]
118 | if 'image' in self.condition_types:
119 | mask = self.get_mask(index)
120 | cond_inputs['image'] = mask
121 | #######################################
122 |
123 | if self.use_latents:
124 | latent = self.latent_maps[self.images[index]]
125 | if len(self.condition_types) == 0:
126 | return latent
127 | else:
128 | return latent, cond_inputs
129 | else:
130 | im = Image.open(self.images[index])
131 | im_tensor = torchvision.transforms.Compose([
132 | torchvision.transforms.Resize(self.im_size),
133 | torchvision.transforms.CenterCrop(self.im_size),
134 | torchvision.transforms.ToTensor(),
135 | ])(im)
136 | im.close()
137 |
138 | # Convert input to -1 to 1 range.
139 | im_tensor = (2 * im_tensor) - 1
140 | if len(self.condition_types) == 0:
141 | return im_tensor
142 | else:
143 | return im_tensor, cond_inputs
144 |
--------------------------------------------------------------------------------
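For reference, a minimal usage sketch of CelebDataset without any conditioning. The im_path below is a placeholder; point it at a directory that contains the CelebA-HQ-img folder:

from torch.utils.data import DataLoader
from dataset.celeb_dataset import CelebDataset

# Hypothetical local path, shown for illustration only.
dataset = CelebDataset(split='train',
                       im_path='data/CelebAMask-HQ',
                       im_size=256,
                       im_channels=3)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
im = next(iter(loader))    # (4, 3, 256, 256) tensor with values in [-1, 1]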
/models/lpips.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import namedtuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn
9 | import torchvision
10 |
11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 |
16 | def spatial_average(in_tens, keepdim=True):
17 | return in_tens.mean([2, 3], keepdim=keepdim)
18 |
19 |
20 | class vgg16(torch.nn.Module):
21 | def __init__(self, requires_grad=False, pretrained=True):
22 | super(vgg16, self).__init__()
23 | # Load pretrained vgg model from torchvision
24 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features
25 | self.slice1 = torch.nn.Sequential()
26 | self.slice2 = torch.nn.Sequential()
27 | self.slice3 = torch.nn.Sequential()
28 | self.slice4 = torch.nn.Sequential()
29 | self.slice5 = torch.nn.Sequential()
30 | self.N_slices = 5
31 | for x in range(4):
32 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
33 | for x in range(4, 9):
34 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
35 | for x in range(9, 16):
36 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
37 | for x in range(16, 23):
38 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
39 | for x in range(23, 30):
40 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
41 |
42 | # Freeze vgg model
43 | if not requires_grad:
44 | for param in self.parameters():
45 | param.requires_grad = False
46 |
47 | def forward(self, X):
48 | # Return output of vgg features
49 | h = self.slice1(X)
50 | h_relu1_2 = h
51 | h = self.slice2(h)
52 | h_relu2_2 = h
53 | h = self.slice3(h)
54 | h_relu3_3 = h
55 | h = self.slice4(h)
56 | h_relu4_3 = h
57 | h = self.slice5(h)
58 | h_relu5_3 = h
59 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
60 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
61 | return out
62 |
63 |
64 | # Learned perceptual metric
65 | class LPIPS(nn.Module):
66 | def __init__(self, net='vgg', version='0.1', use_dropout=True):
67 | super(LPIPS, self).__init__()
68 | self.version = version
69 | # Imagenet normalization
70 | self.scaling_layer = ScalingLayer()
71 | ########################
72 |
73 | # Instantiate vgg model
74 | self.chns = [64, 128, 256, 512, 512]
75 | self.L = len(self.chns)
76 | self.net = vgg16(pretrained=True, requires_grad=False)
77 |
78 | # Add 1x1 convolutional Layers
79 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
80 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
81 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
82 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
83 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
84 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
85 | self.lins = nn.ModuleList(self.lins)
86 | ########################
87 |
88 | # Load the weights of trained LPIPS model
89 | import inspect
90 | import os
91 | model_path = os.path.abspath(
92 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
93 | print('Loading model from: %s' % model_path)
94 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
95 | ########################
96 |
97 | # Freeze all parameters
98 | self.eval()
99 | for param in self.parameters():
100 | param.requires_grad = False
101 | ########################
102 |
103 | def forward(self, in0, in1, normalize=False):
104 | # Scale the inputs to -1 to +1 range if needed
105 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
106 | in0 = 2 * in0 - 1
107 | in1 = 2 * in1 - 1
108 | ########################
109 |
110 | # Normalize the inputs according to imagenet normalization
111 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
112 | ########################
113 |
114 | # Get VGG outputs for image0 and image1
115 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
116 | feats0, feats1, diffs = {}, {}, {}
117 | ########################
118 |
119 | # Compute Square of Difference for each layer output
120 | for kk in range(self.L):
121 |             feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize(
122 |                 outs1[kk], dim=1)
123 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
124 | ########################
125 |
126 | # 1x1 convolution followed by spatial average on the square differences
127 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
128 | val = 0
129 |
130 | # Aggregate the results of each layer
131 | for l in range(self.L):
132 | val += res[l]
133 | return val
134 |
135 |
136 | class ScalingLayer(nn.Module):
137 | def __init__(self):
138 | super(ScalingLayer, self).__init__()
139 |         # Imagenet normalization for (0-1)
140 | # mean = [0.485, 0.456, 0.406]
141 | # std = [0.229, 0.224, 0.225]
142 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
143 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
144 |
145 | def forward(self, inp):
146 | return (inp - self.shift) / self.scale
147 |
148 |
149 | class NetLinLayer(nn.Module):
150 | ''' A single linear layer which does a 1x1 conv '''
151 |
152 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
153 | super(NetLinLayer, self).__init__()
154 |
155 | layers = [nn.Dropout(), ] if (use_dropout) else []
156 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
157 | self.model = nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | out = self.model(x)
161 | return out
162 |
--------------------------------------------------------------------------------
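A small sketch of how the LPIPS module above is typically called. It assumes the pretrained weights file weights/v0.1/vgg.pth is present next to lpips.py; random tensors stand in for real image batches, and normalize=True rescales [0, 1] inputs to the [-1, 1] range the forward pass expects:

import torch
from models.lpips import LPIPS

lpips_model = LPIPS().eval()
x = torch.rand(1, 3, 256, 256)    # stand-in "reconstruction" in [0, 1]
y = torch.rand(1, 3, 256, 256)    # stand-in "target" in [0, 1]
with torch.no_grad():
    dist = lpips_model(x, y, normalize=True)    # shape (1, 1, 1, 1)
print(dist.item())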
/tools/sample_ddpm_class_cond.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | from torchvision.utils import make_grid
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from models.unet_cond_base import Unet
10 | from models.vqvae import VQVAE
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 | from utils.config_utils import *
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 |
16 |
17 | def sample(model, scheduler, train_config, diffusion_model_config,
18 | autoencoder_model_config, diffusion_config, dataset_config, vae):
19 | r"""
20 | Sample stepwise by going backward one timestep at a time.
21 | We save the x0 predictions
22 | """
23 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
24 |
25 | ########### Sample random noise latent ##########
26 | xt = torch.randn((train_config['num_samples'],
27 | autoencoder_model_config['z_channels'],
28 | im_size,
29 | im_size)).to(device)
30 | ###############################################
31 |
32 | ############# Validate the config #################
33 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
34 | assert condition_config is not None, ("This sampling script is for class conditional "
35 | "but no conditioning config found")
36 | condition_types = get_config_value(condition_config, 'condition_types', [])
37 | assert 'class' in condition_types, ("This sampling script is for class conditional "
38 | "but no class condition found in config")
39 | validate_class_config(condition_config)
40 | ###############################################
41 |
42 | ############ Create Conditional input ###############
43 | num_classes = condition_config['class_condition_config']['num_classes']
44 | sample_classes = torch.randint(0, num_classes, (train_config['num_samples'], ))
45 | print('Generating images for {}'.format(list(sample_classes.numpy())))
46 | cond_input = {
47 | 'class': torch.nn.functional.one_hot(sample_classes, num_classes).to(device)
48 | }
49 | # Unconditional input for classifier free guidance
50 | uncond_input = {
51 | 'class': cond_input['class'] * 0
52 | }
53 | ###############################################
54 |
55 | # By default classifier free guidance is disabled
56 | # Change value in config or change default value here to enable it
57 | cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0)
58 |
59 | ################# Sampling Loop ########################
60 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
61 | # Get prediction of noise
62 | t = (torch.ones((xt.shape[0],))*i).long().to(device)
63 | noise_pred_cond = model(xt, t, cond_input)
64 |
65 | if cf_guidance_scale > 1:
66 | noise_pred_uncond = model(xt, t, uncond_input)
67 | noise_pred = noise_pred_uncond + cf_guidance_scale*(noise_pred_cond - noise_pred_uncond)
68 | else:
69 | noise_pred = noise_pred_cond
70 |
71 | # Use scheduler to get x0 and xt-1
72 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
73 |
74 | if i == 0:
75 | # Decode ONLY the final image to save time
76 | ims = vae.decode(xt)
77 | else:
78 | ims = x0_pred
79 |
80 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
81 | ims = (ims + 1) / 2
82 | grid = make_grid(ims, nrow=1)
83 | img = torchvision.transforms.ToPILImage()(grid)
84 |
85 | if not os.path.exists(os.path.join(train_config['task_name'], 'cond_class_samples')):
86 | os.mkdir(os.path.join(train_config['task_name'], 'cond_class_samples'))
87 | img.save(os.path.join(train_config['task_name'], 'cond_class_samples', 'x0_{}.png'.format(i)))
88 | img.close()
89 | ##############################################################
90 |
91 | def infer(args):
92 | # Read the config file #
93 | with open(args.config_path, 'r') as file:
94 | try:
95 | config = yaml.safe_load(file)
96 | except yaml.YAMLError as exc:
97 | print(exc)
98 | print(config)
99 | ########################
100 |
101 | diffusion_config = config['diffusion_params']
102 | dataset_config = config['dataset_params']
103 | diffusion_model_config = config['ldm_params']
104 | autoencoder_model_config = config['autoencoder_params']
105 | train_config = config['train_params']
106 |
107 | ########## Create the noise scheduler #############
108 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
109 | beta_start=diffusion_config['beta_start'],
110 | beta_end=diffusion_config['beta_end'])
111 | ###############################################
112 |
113 | ########## Load Unet #############
114 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
115 | model_config=diffusion_model_config).to(device)
116 | model.eval()
117 | if os.path.exists(os.path.join(train_config['task_name'],
118 | train_config['ldm_ckpt_name'])):
119 | print('Loaded unet checkpoint')
120 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
121 | train_config['ldm_ckpt_name']),
122 | map_location=device))
123 | else:
124 | raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'],
125 | train_config['ldm_ckpt_name'])))
126 | #####################################
127 |
128 | # Create output directories
129 | if not os.path.exists(train_config['task_name']):
130 | os.mkdir(train_config['task_name'])
131 |
132 | ########## Load VQVAE #############
133 | vae = VQVAE(im_channels=dataset_config['im_channels'],
134 | model_config=autoencoder_model_config).to(device)
135 | vae.eval()
136 |
137 | # Load vae if found
138 | if os.path.exists(os.path.join(train_config['task_name'],
139 | train_config['vqvae_autoencoder_ckpt_name'])):
140 | print('Loaded vae checkpoint')
141 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
142 | train_config['vqvae_autoencoder_ckpt_name']),
143 | map_location=device), strict=True)
144 | else:
145 | raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'],
146 | train_config['vqvae_autoencoder_ckpt_name'])))
147 | #####################################
148 |
149 | with torch.no_grad():
150 | sample(model, scheduler, train_config, diffusion_model_config,
151 | autoencoder_model_config, diffusion_config, dataset_config, vae)
152 |
153 |
154 | if __name__ == '__main__':
155 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation for class conditional '
156 | 'Mnist generation')
157 | parser.add_argument('--config', dest='config_path',
158 | default='config/mnist_class_cond.yaml', type=str)
159 | args = parser.parse_args()
160 | infer(args)
161 |
--------------------------------------------------------------------------------
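The sampling loop above applies classifier-free guidance by blending the conditional and unconditional noise predictions. The arithmetic in isolation, with placeholder tensors standing in for the two model calls:

import torch

cf_guidance_scale = 4.0                          # > 1 enables guidance
noise_pred_cond = torch.randn(2, 4, 7, 7)        # stand-in for model(xt, t, cond_input)
noise_pred_uncond = torch.randn(2, 4, 7, 7)      # stand-in for model(xt, t, uncond_input)

# Push the prediction away from the unconditional estimate, towards the conditional one.
noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond)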
/models/vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.blocks import DownBlock, MidBlock, UpBlock
4 |
5 |
6 | class VQVAE(nn.Module):
7 | def __init__(self, im_channels, model_config):
8 | super().__init__()
9 | self.down_channels = model_config['down_channels']
10 | self.mid_channels = model_config['mid_channels']
11 | self.down_sample = model_config['down_sample']
12 | self.num_down_layers = model_config['num_down_layers']
13 | self.num_mid_layers = model_config['num_mid_layers']
14 | self.num_up_layers = model_config['num_up_layers']
15 |
16 | # To disable attention in Downblock of Encoder and Upblock of Decoder
17 | self.attns = model_config['attn_down']
18 |
19 | # Latent Dimension
20 | self.z_channels = model_config['z_channels']
21 | self.codebook_size = model_config['codebook_size']
22 | self.norm_channels = model_config['norm_channels']
23 | self.num_heads = model_config['num_heads']
24 |
25 | # Assertion to validate the channel information
26 | assert self.mid_channels[0] == self.down_channels[-1]
27 | assert self.mid_channels[-1] == self.down_channels[-1]
28 | assert len(self.down_sample) == len(self.down_channels) - 1
29 | assert len(self.attns) == len(self.down_channels) - 1
30 |
31 | # Wherever we use downsampling in encoder correspondingly use
32 | # upsampling in decoder
33 | self.up_sample = list(reversed(self.down_sample))
34 |
35 | ##################### Encoder ######################
36 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
37 |
38 | # Downblock + Midblock
39 | self.encoder_layers = nn.ModuleList([])
40 | for i in range(len(self.down_channels) - 1):
41 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
42 | t_emb_dim=None, down_sample=self.down_sample[i],
43 | num_heads=self.num_heads,
44 | num_layers=self.num_down_layers,
45 | attn=self.attns[i],
46 | norm_channels=self.norm_channels))
47 |
48 | self.encoder_mids = nn.ModuleList([])
49 | for i in range(len(self.mid_channels) - 1):
50 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
51 | t_emb_dim=None,
52 | num_heads=self.num_heads,
53 | num_layers=self.num_mid_layers,
54 | norm_channels=self.norm_channels))
55 |
56 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
57 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
58 |
59 | # Pre Quantization Convolution
60 | self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
61 |
62 | # Codebook
63 | self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
64 | ####################################################
65 |
66 | ##################### Decoder ######################
67 |
68 | # Post Quantization Convolution
69 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
70 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
71 |
72 | # Midblock + Upblock
73 | self.decoder_mids = nn.ModuleList([])
74 | for i in reversed(range(1, len(self.mid_channels))):
75 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
76 | t_emb_dim=None,
77 | num_heads=self.num_heads,
78 | num_layers=self.num_mid_layers,
79 | norm_channels=self.norm_channels))
80 |
81 | self.decoder_layers = nn.ModuleList([])
82 | for i in reversed(range(1, len(self.down_channels))):
83 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
84 | t_emb_dim=None, up_sample=self.down_sample[i - 1],
85 | num_heads=self.num_heads,
86 | num_layers=self.num_up_layers,
87 | attn=self.attns[i-1],
88 | norm_channels=self.norm_channels))
89 |
90 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
91 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
92 |
93 | def quantize(self, x):
94 | B, C, H, W = x.shape
95 |
96 | # B, C, H, W -> B, H, W, C
97 | x = x.permute(0, 2, 3, 1)
98 |
99 | # B, H, W, C -> B, H*W, C
100 | x = x.reshape(x.size(0), -1, x.size(-1))
101 |
102 | # Find nearest embedding/codebook vector
103 | # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
104 | dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
105 | # (B, H*W)
106 | min_encoding_indices = torch.argmin(dist, dim=-1)
107 |
108 | # Replace encoder output with nearest codebook
109 | # quant_out -> B*H*W, C
110 | quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
111 |
112 | # x -> B*H*W, C
113 | x = x.reshape((-1, x.size(-1)))
114 |         commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
115 | codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
116 | quantize_losses = {
117 | 'codebook_loss': codebook_loss,
118 |             'commitment_loss': commitment_loss
119 | }
120 | # Straight through estimation
121 | quant_out = x + (quant_out - x).detach()
122 |
123 | # quant_out -> B, C, H, W
124 | quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
125 | min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
126 | return quant_out, quantize_losses, min_encoding_indices
127 |
128 | def encode(self, x):
129 | out = self.encoder_conv_in(x)
130 | for idx, down in enumerate(self.encoder_layers):
131 | out = down(out)
132 | for mid in self.encoder_mids:
133 | out = mid(out)
134 | out = self.encoder_norm_out(out)
135 | out = nn.SiLU()(out)
136 | out = self.encoder_conv_out(out)
137 | out = self.pre_quant_conv(out)
138 | out, quant_losses, _ = self.quantize(out)
139 | return out, quant_losses
140 |
141 | def decode(self, z):
142 | out = z
143 | out = self.post_quant_conv(out)
144 | out = self.decoder_conv_in(out)
145 | for mid in self.decoder_mids:
146 | out = mid(out)
147 | for idx, up in enumerate(self.decoder_layers):
148 | out = up(out)
149 |
150 | out = self.decoder_norm_out(out)
151 | out = nn.SiLU()(out)
152 | out = self.decoder_conv_out(out)
153 | return out
154 |
155 | def forward(self, x):
156 | z, quant_losses = self.encode(x)
157 | out = self.decode(z)
158 | return out, z, quant_losses
159 |
160 |
--------------------------------------------------------------------------------
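The quantize() method above performs a nearest-codebook lookup followed by a straight-through gradient estimator. A compact sketch of the same idea on a toy tensor (codebook size and channel count are arbitrary here):

import torch
import torch.nn as nn

codebook = nn.Embedding(8, 3)                        # 8 codes, 3 channels
x = torch.randn(2, 3, 4, 4, requires_grad=True)      # toy "encoder output"

flat = x.permute(0, 2, 3, 1).reshape(-1, 3)          # (B*H*W, C)
dist = torch.cdist(flat, codebook.weight)            # distance to every code: (B*H*W, 8)
idx = dist.argmin(dim=-1)                            # nearest code per spatial position
quant = codebook.weight[idx]                         # (B*H*W, C)

commitment_loss = torch.mean((quant.detach() - flat) ** 2)
codebook_loss = torch.mean((quant - flat.detach()) ** 2)

# Straight-through estimator: forward pass uses the quantized values,
# backward pass copies gradients straight through to the encoder output.
quant = flat + (quant - flat).detach()
quant = quant.reshape(2, 4, 4, 3).permute(0, 3, 1, 2)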
/tools/sample_ddpm_text_cond.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
2 | import argparse
3 | import yaml
4 | import os
5 | from torchvision.utils import make_grid
6 | from tqdm import tqdm
7 | from models.unet_cond_base import Unet
8 | from models.vqvae import VQVAE
9 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
10 | from utils.config_utils import *
11 | from utils.text_utils import *
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 |
16 | def sample(model, scheduler, train_config, diffusion_model_config,
17 | autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model):
18 | r"""
19 | Sample stepwise by going backward one timestep at a time.
20 | We save the x0 predictions
21 | """
22 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
23 |
24 | ########### Sample random noise latent ##########
25 |     # For now, fix generation to a single sample
26 | xt = torch.randn((1,
27 | autoencoder_model_config['z_channels'],
28 | im_size,
29 | im_size)).to(device)
30 | ###############################################
31 |
32 | ############ Create Conditional input ###############
33 | text_prompt = ['She is a woman with blond hair. She is wearing lipstick.']
34 | neg_prompt = ['He is a man.']
35 | empty_prompt = ['']
36 | text_prompt_embed = get_text_representation(text_prompt,
37 | text_tokenizer,
38 | text_model,
39 | device)
40 | # Can replace empty prompt with negative prompt
41 | empty_text_embed = get_text_representation(empty_prompt, text_tokenizer, text_model, device)
42 | assert empty_text_embed.shape == text_prompt_embed.shape
43 |
44 | uncond_input = {
45 | 'text': empty_text_embed
46 | }
47 | cond_input = {
48 | 'text': text_prompt_embed
49 | }
50 | ###############################################
51 |
52 | # By default classifier free guidance is disabled
53 | # Change value in config or change default value here to enable it
54 | cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0)
55 |
56 | ################# Sampling Loop ########################
57 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
58 | # Get prediction of noise
59 | t = (torch.ones((xt.shape[0],)) * i).long().to(device)
60 | noise_pred_cond = model(xt, t, cond_input)
61 |
62 | if cf_guidance_scale > 1:
63 | noise_pred_uncond = model(xt, t, uncond_input)
64 | noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond)
65 | else:
66 | noise_pred = noise_pred_cond
67 |
68 | # Use scheduler to get x0 and xt-1
69 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
70 |
71 | # Save x0
72 | # ims = torch.clamp(xt, -1., 1.).detach().cpu()
73 | if i == 0:
74 |             # Decode ONLY the final image to save time
75 | ims = vae.decode(xt)
76 | else:
77 | ims = x0_pred
78 |
79 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
80 | ims = (ims + 1) / 2
81 | grid = make_grid(ims, nrow=1)
82 | img = torchvision.transforms.ToPILImage()(grid)
83 |
84 | if not os.path.exists(os.path.join(train_config['task_name'], 'cond_text_samples')):
85 | os.mkdir(os.path.join(train_config['task_name'], 'cond_text_samples'))
86 | img.save(os.path.join(train_config['task_name'], 'cond_text_samples', 'x0_{}.png'.format(i)))
87 | img.close()
88 | ##############################################################
89 |
90 |
91 | def infer(args):
92 | # Read the config file #
93 | with open(args.config_path, 'r') as file:
94 | try:
95 | config = yaml.safe_load(file)
96 | except yaml.YAMLError as exc:
97 | print(exc)
98 | print(config)
99 | ########################
100 |
101 | diffusion_config = config['diffusion_params']
102 | dataset_config = config['dataset_params']
103 | diffusion_model_config = config['ldm_params']
104 | autoencoder_model_config = config['autoencoder_params']
105 | train_config = config['train_params']
106 |
107 | ########## Create the noise scheduler #############
108 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
109 | beta_start=diffusion_config['beta_start'],
110 | beta_end=diffusion_config['beta_end'])
111 | ###############################################
112 |
113 | text_tokenizer = None
114 | text_model = None
115 |
116 | ############# Validate the config #################
117 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
118 | assert condition_config is not None, ("This sampling script is for text conditional "
119 | "but no conditioning config found")
120 | condition_types = get_config_value(condition_config, 'condition_types', [])
121 | assert 'text' in condition_types, ("This sampling script is for text conditional "
122 | "but no text condition found in config")
123 | validate_text_config(condition_config)
124 | ###############################################
125 |
126 | ############# Load tokenizer and text model #################
127 | with torch.no_grad():
128 | # Load tokenizer and text model based on config
129 | # Also get empty text representation
130 | text_tokenizer, text_model = get_tokenizer_and_model(condition_config['text_condition_config']
131 | ['text_embed_model'], device=device)
132 | ###############################################
133 |
134 | ########## Load Unet #############
135 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
136 | model_config=diffusion_model_config).to(device)
137 | model.eval()
138 | if os.path.exists(os.path.join(train_config['task_name'],
139 | train_config['ldm_ckpt_name'])):
140 | print('Loaded unet checkpoint')
141 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
142 | train_config['ldm_ckpt_name']),
143 | map_location=device))
144 | else:
145 | raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'],
146 | train_config['ldm_ckpt_name'])))
147 | #####################################
148 |
149 | # Create output directories
150 | if not os.path.exists(train_config['task_name']):
151 | os.mkdir(train_config['task_name'])
152 |
153 | ########## Load VQVAE #############
154 | vae = VQVAE(im_channels=dataset_config['im_channels'],
155 | model_config=autoencoder_model_config).to(device)
156 | vae.eval()
157 |
158 | # Load vae if found
159 | if os.path.exists(os.path.join(train_config['task_name'],
160 | train_config['vqvae_autoencoder_ckpt_name'])):
161 | print('Loaded vae checkpoint')
162 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
163 | train_config['vqvae_autoencoder_ckpt_name']),
164 | map_location=device), strict=True)
165 | else:
166 | raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'],
167 | train_config['vqvae_autoencoder_ckpt_name'])))
168 | #####################################
169 |
170 | with torch.no_grad():
171 | sample(model, scheduler, train_config, diffusion_model_config,
172 |                autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model)
173 |
174 |
175 | if __name__ == '__main__':
176 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation with only '
177 | 'text conditioning')
178 | parser.add_argument('--config', dest='config_path',
179 | default='config/celebhq_text_cond.yaml', type=str)
180 | args = parser.parse_args()
181 | infer(args)
182 |
--------------------------------------------------------------------------------
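The text-conditional sampler above builds its conditional and unconditional inputs from text embeddings before the denoising loop starts. A minimal sketch of that preparation step, reusing the helpers with the same signatures as they are called in this script (the 'clip' embed-model name is an assumption; use whatever your config's text_embed_model specifies):

import torch
from utils.text_utils import get_tokenizer_and_model, get_text_representation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with torch.no_grad():
    # 'clip' is a hypothetical value for the text_embed_model config key.
    tokenizer, text_model = get_tokenizer_and_model('clip', device=device)
    cond_embed = get_text_representation(['She is a woman with blond hair.'],
                                         tokenizer, text_model, device)
    uncond_embed = get_text_representation([''], tokenizer, text_model, device)

cond_input = {'text': cond_embed}        # used for the conditional prediction
uncond_input = {'text': uncond_embed}    # used when classifier-free guidance is enabled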
/tools/train_ddpm_cond.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import yaml
2 | import argparse
3 | import numpy as np
4 | from tqdm import tqdm
5 | from torch.optim import Adam
6 | from dataset.mnist_dataset import MnistDataset
7 | from dataset.celeb_dataset import CelebDataset
8 | from torch.utils.data import DataLoader
9 | from models.unet_cond_base import Unet
10 | from models.vqvae import VQVAE
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 | from utils.text_utils import *
13 | from utils.config_utils import *
14 | from utils.diffusion_utils import *
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 |
18 |
19 | def train(args):
20 | # Read the config file #
21 | with open(args.config_path, 'r') as file:
22 | try:
23 | config = yaml.safe_load(file)
24 | except yaml.YAMLError as exc:
25 | print(exc)
26 | print(config)
27 | ########################
28 |
29 | diffusion_config = config['diffusion_params']
30 | dataset_config = config['dataset_params']
31 | diffusion_model_config = config['ldm_params']
32 | autoencoder_model_config = config['autoencoder_params']
33 | train_config = config['train_params']
34 |
35 | ########## Create the noise scheduler #############
36 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
37 | beta_start=diffusion_config['beta_start'],
38 | beta_end=diffusion_config['beta_end'])
39 | ###############################################
40 |
41 | # Instantiate Condition related components
42 | text_tokenizer = None
43 | text_model = None
44 | empty_text_embed = None
45 | condition_types = []
46 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
47 | if condition_config is not None:
48 | assert 'condition_types' in condition_config, \
49 | "condition type missing in conditioning config"
50 | condition_types = condition_config['condition_types']
51 | if 'text' in condition_types:
52 | validate_text_config(condition_config)
53 | with torch.no_grad():
54 | # Load tokenizer and text model based on config
55 | # Also get empty text representation
56 | text_tokenizer, text_model = get_tokenizer_and_model(condition_config['text_condition_config']
57 | ['text_embed_model'], device=device)
58 | empty_text_embed = get_text_representation([''], text_tokenizer, text_model, device)
59 |
60 | im_dataset_cls = {
61 | 'mnist': MnistDataset,
62 | 'celebhq': CelebDataset,
63 | }.get(dataset_config['name'])
64 |
65 | im_dataset = im_dataset_cls(split='train',
66 | im_path=dataset_config['im_path'],
67 | im_size=dataset_config['im_size'],
68 | im_channels=dataset_config['im_channels'],
69 | use_latents=True,
70 | latent_path=os.path.join(train_config['task_name'],
71 | train_config['vqvae_latent_dir_name']),
72 | condition_config=condition_config)
73 |
74 | data_loader = DataLoader(im_dataset,
75 | batch_size=train_config['ldm_batch_size'],
76 | shuffle=True)
77 |
78 | # Instantiate the unet model
79 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
80 | model_config=diffusion_model_config).to(device)
81 | model.train()
82 |
83 | vae = None
84 |     # Load VAE ONLY if latents are not being used or are missing
85 | if not im_dataset.use_latents:
86 | print('Loading vqvae model as latents not present')
87 | vae = VQVAE(im_channels=dataset_config['im_channels'],
88 | model_config=autoencoder_model_config).to(device)
89 | vae.eval()
90 | # Load vae if found
91 | if os.path.exists(os.path.join(train_config['task_name'],
92 | train_config['vqvae_autoencoder_ckpt_name'])):
93 | print('Loaded vae checkpoint')
94 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
95 | train_config['vqvae_autoencoder_ckpt_name']),
96 | map_location=device))
97 | else:
98 | raise Exception('VAE checkpoint not found and use_latents was disabled')
99 |
100 | # Specify training parameters
101 | num_epochs = train_config['ldm_epochs']
102 | optimizer = Adam(model.parameters(), lr=train_config['ldm_lr'])
103 | criterion = torch.nn.MSELoss()
104 |
105 | # Load vae and freeze parameters ONLY if latents already not saved
106 | if not im_dataset.use_latents:
107 | assert vae is not None
108 | for param in vae.parameters():
109 | param.requires_grad = False
110 |
111 | # Run training
112 | for epoch_idx in range(num_epochs):
113 | losses = []
114 | for data in tqdm(data_loader):
115 | cond_input = None
116 | if condition_config is not None:
117 | im, cond_input = data
118 | else:
119 | im = data
120 | optimizer.zero_grad()
121 | im = im.float().to(device)
122 | if not im_dataset.use_latents:
123 | with torch.no_grad():
124 | im, _ = vae.encode(im)
125 |
126 | ########### Handling Conditional Input ###########
127 | if 'text' in condition_types:
128 | with torch.no_grad():
129 | assert 'text' in cond_input, 'Conditioning Type Text but no text conditioning input present'
130 | validate_text_config(condition_config)
131 | text_condition = get_text_representation(cond_input['text'],
132 | text_tokenizer,
133 | text_model,
134 | device)
135 | text_drop_prob = get_config_value(condition_config['text_condition_config'],
136 | 'cond_drop_prob', 0.)
137 | text_condition = drop_text_condition(text_condition, im, empty_text_embed, text_drop_prob)
138 | cond_input['text'] = text_condition
139 | if 'image' in condition_types:
140 | assert 'image' in cond_input, 'Conditioning Type Image but no image conditioning input present'
141 | validate_image_config(condition_config)
142 | cond_input_image = cond_input['image'].to(device)
143 | # Drop condition
144 | im_drop_prob = get_config_value(condition_config['image_condition_config'],
145 | 'cond_drop_prob', 0.)
146 | cond_input['image'] = drop_image_condition(cond_input_image, im, im_drop_prob)
147 | if 'class' in condition_types:
148 | assert 'class' in cond_input, 'Conditioning Type Class but no class conditioning input present'
149 | validate_class_config(condition_config)
150 | class_condition = torch.nn.functional.one_hot(
151 | cond_input['class'],
152 | condition_config['class_condition_config']['num_classes']).to(device)
153 | class_drop_prob = get_config_value(condition_config['class_condition_config'],
154 | 'cond_drop_prob', 0.)
155 | # Drop condition
156 | cond_input['class'] = drop_class_condition(class_condition, class_drop_prob, im)
157 | ################################################
158 |
159 | # Sample random noise
160 | noise = torch.randn_like(im).to(device)
161 |
162 | # Sample timestep
163 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
164 |
165 | # Add noise to images according to timestep
166 | noisy_im = scheduler.add_noise(im, noise, t)
167 | noise_pred = model(noisy_im, t, cond_input=cond_input)
168 | loss = criterion(noise_pred, noise)
169 | losses.append(loss.item())
170 | loss.backward()
171 | optimizer.step()
172 | print('Finished epoch:{} | Loss : {:.4f}'.format(
173 | epoch_idx + 1,
174 | np.mean(losses)))
175 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
176 | train_config['ldm_ckpt_name']))
177 |
178 | print('Done Training ...')
179 |
180 |
181 | if __name__ == '__main__':
182 | parser = argparse.ArgumentParser(description='Arguments for ddpm training')
183 | parser.add_argument('--config', dest='config_path',
184 |                         default='config/celebhq_text_cond.yaml', type=str)
185 | args = parser.parse_args()
186 | train(args)
187 |
--------------------------------------------------------------------------------
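Each training step above follows the standard DDPM recipe: sample noise and a timestep, noise the (latent) image, and regress the model output onto the noise with MSE. A stripped-down version of that step with stand-in tensors; the closed-form q(x_t | x_0) used here is the usual DDPM forward process, which is what the repo's LinearNoiseScheduler.add_noise is expected to implement:

import torch

num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(4, 3, 16, 16)                        # stand-in latents
noise = torch.randn_like(x0)
t = torch.randint(0, num_timesteps, (x0.shape[0],))

a_bar = alpha_bars[t].view(-1, 1, 1, 1)
noisy = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise    # q(x_t | x_0)

# model(noisy, t, cond_input) would predict the added noise; the loss is then
# torch.nn.functional.mse_loss(noise_pred, noise)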
/models/unet_cond_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import einsum
3 | import torch.nn as nn
4 | from models.blocks import get_time_embedding
5 | from models.blocks import DownBlock, MidBlock, UpBlockUnet
6 | from utils.config_utils import *
7 |
8 |
9 | class Unet(nn.Module):
10 | r"""
11 | Unet model comprising
12 |     Down blocks, Midblocks and Upblocks
13 | """
14 |
15 | def __init__(self, im_channels, model_config):
16 | super().__init__()
17 | self.down_channels = model_config['down_channels']
18 | self.mid_channels = model_config['mid_channels']
19 | self.t_emb_dim = model_config['time_emb_dim']
20 | self.down_sample = model_config['down_sample']
21 | self.num_down_layers = model_config['num_down_layers']
22 | self.num_mid_layers = model_config['num_mid_layers']
23 | self.num_up_layers = model_config['num_up_layers']
24 | self.attns = model_config['attn_down']
25 | self.norm_channels = model_config['norm_channels']
26 | self.num_heads = model_config['num_heads']
27 | self.conv_out_channels = model_config['conv_out_channels']
28 |
29 | # Validating Unet Model configurations
30 | assert self.mid_channels[0] == self.down_channels[-1]
31 | assert self.mid_channels[-1] == self.down_channels[-2]
32 | assert len(self.down_sample) == len(self.down_channels) - 1
33 | assert len(self.attns) == len(self.down_channels) - 1
34 |
35 | ######## Class, Mask and Text Conditioning Config #####
36 | self.class_cond = False
37 | self.text_cond = False
38 | self.image_cond = False
39 | self.text_embed_dim = None
40 | self.condition_config = get_config_value(model_config, 'condition_config', None)
41 | if self.condition_config is not None:
42 | assert 'condition_types' in self.condition_config, 'Condition Type not provided in model config'
43 | condition_types = self.condition_config['condition_types']
44 | if 'class' in condition_types:
45 | validate_class_config(self.condition_config)
46 | self.class_cond = True
47 | self.num_classes = self.condition_config['class_condition_config']['num_classes']
48 | if 'text' in condition_types:
49 | validate_text_config(self.condition_config)
50 | self.text_cond = True
51 | self.text_embed_dim = self.condition_config['text_condition_config']['text_embed_dim']
52 | if 'image' in condition_types:
53 | self.image_cond = True
54 | self.im_cond_input_ch = self.condition_config['image_condition_config'][
55 | 'image_condition_input_channels']
56 | self.im_cond_output_ch = self.condition_config['image_condition_config'][
57 | 'image_condition_output_channels']
58 | if self.class_cond:
59 |             # Rather than using a special null class we don't add the
60 | # class embedding information for unconditional generation
61 | self.class_emb = nn.Embedding(self.num_classes,
62 | self.t_emb_dim)
63 |
64 | if self.image_cond:
65 | # Map the mask image to a N channel image and
66 | # concat that with input across channel dimension
67 | self.cond_conv_in = nn.Conv2d(in_channels=self.im_cond_input_ch,
68 | out_channels=self.im_cond_output_ch,
69 | kernel_size=1,
70 | bias=False)
71 | self.conv_in_concat = nn.Conv2d(im_channels + self.im_cond_output_ch,
72 | self.down_channels[0], kernel_size=3, padding=1)
73 | else:
74 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)
75 | self.cond = self.text_cond or self.image_cond or self.class_cond
76 | ###################################
77 |
78 | # Initial projection from sinusoidal time embedding
79 | self.t_proj = nn.Sequential(
80 | nn.Linear(self.t_emb_dim, self.t_emb_dim),
81 | nn.SiLU(),
82 | nn.Linear(self.t_emb_dim, self.t_emb_dim)
83 | )
84 |
85 | self.up_sample = list(reversed(self.down_sample))
86 | self.downs = nn.ModuleList([])
87 |
88 | # Build the Downblocks
89 | for i in range(len(self.down_channels) - 1):
90 | # Cross Attention and Context Dim only needed if text condition is present
91 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
92 | down_sample=self.down_sample[i],
93 | num_heads=self.num_heads,
94 | num_layers=self.num_down_layers,
95 | attn=self.attns[i], norm_channels=self.norm_channels,
96 | cross_attn=self.text_cond,
97 | context_dim=self.text_embed_dim))
98 |
99 | self.mids = nn.ModuleList([])
100 | # Build the Midblocks
101 | for i in range(len(self.mid_channels) - 1):
102 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
103 | num_heads=self.num_heads,
104 | num_layers=self.num_mid_layers,
105 | norm_channels=self.norm_channels,
106 | cross_attn=self.text_cond,
107 | context_dim=self.text_embed_dim))
108 |
109 | self.ups = nn.ModuleList([])
110 | # Build the Upblocks
111 | for i in reversed(range(len(self.down_channels) - 1)):
112 | self.ups.append(
113 | UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
114 | self.t_emb_dim, up_sample=self.down_sample[i],
115 | num_heads=self.num_heads,
116 | num_layers=self.num_up_layers,
117 | norm_channels=self.norm_channels,
118 | cross_attn=self.text_cond,
119 | context_dim=self.text_embed_dim))
120 |
121 | self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
122 | self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)
123 |
124 | def forward(self, x, t, cond_input=None):
125 | # Shapes assuming downblocks are [C1, C2, C3, C4]
126 | # Shapes assuming midblocks are [C4, C4, C3]
127 | # Shapes assuming downsamples are [True, True, False]
128 | if self.cond:
129 | assert cond_input is not None, \
130 | "Model initialized with conditioning so cond_input cannot be None"
131 | if self.image_cond:
132 | ######## Mask Conditioning ########
133 | validate_image_conditional_input(cond_input, x)
134 | im_cond = cond_input['image']
135 | im_cond = torch.nn.functional.interpolate(im_cond, size=x.shape[-2:])
136 | im_cond = self.cond_conv_in(im_cond)
137 | assert im_cond.shape[-2:] == x.shape[-2:]
138 | x = torch.cat([x, im_cond], dim=1)
139 | # B x (C+N) x H x W
140 | out = self.conv_in_concat(x)
141 | #####################################
142 | else:
143 | # B x C x H x W
144 | out = self.conv_in(x)
145 | # B x C1 x H x W
146 |
147 | # t_emb -> B x t_emb_dim
148 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
149 | t_emb = self.t_proj(t_emb)
150 |
151 | ######## Class Conditioning ########
152 | if self.class_cond:
153 | validate_class_conditional_input(cond_input, x, self.num_classes)
154 | class_embed = einsum(cond_input['class'].float(), self.class_emb.weight, 'b n, n d -> b d')
155 | t_emb += class_embed
156 | ####################################
157 |
158 | context_hidden_states = None
159 | if self.text_cond:
160 | assert 'text' in cond_input, \
161 | "Model initialized with text conditioning but cond_input has no text information"
162 | context_hidden_states = cond_input['text']
163 | down_outs = []
164 |
165 | for idx, down in enumerate(self.downs):
166 | down_outs.append(out)
167 | out = down(out, t_emb, context_hidden_states)
168 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
169 | # out B x C4 x H/4 x W/4
170 |
171 | for mid in self.mids:
172 | out = mid(out, t_emb, context_hidden_states)
173 | # out B x C3 x H/4 x W/4
174 |
175 | for up in self.ups:
176 | down_out = down_outs.pop()
177 | out = up(out, down_out, t_emb, context_hidden_states)
178 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
179 | out = self.norm_out(out)
180 | out = nn.SiLU()(out)
181 | out = self.conv_out(out)
182 | # out B x C x H x W
183 | return out
184 |
--------------------------------------------------------------------------------
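Class conditioning in the Unet forward pass above converts a one-hot class vector into an embedding via einops.einsum with the embedding table and adds it to the time embedding. A toy version of just that lookup (sizes are arbitrary):

import torch
import torch.nn as nn
from einops import einsum

num_classes, t_emb_dim = 10, 16
class_emb = nn.Embedding(num_classes, t_emb_dim)

classes = torch.tensor([3, 7])
one_hot = torch.nn.functional.one_hot(classes, num_classes).float()    # (B, num_classes)

# (b n), (n d) -> (b d); for one-hot rows this is simply a row lookup of the embedding table.
class_embed = einsum(one_hot, class_emb.weight, 'b n, n d -> b d')
t_emb = torch.zeros(len(classes), t_emb_dim) + class_embed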
/tools/train_vqvae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import random
5 | import torchvision
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | from models.vqvae import VQVAE
10 | from models.lpips import LPIPS
11 | from models.discriminator import Discriminator
12 | from torch.utils.data.dataloader import DataLoader
13 | from dataset.mnist_dataset import MnistDataset
14 | from dataset.celeb_dataset import CelebDataset
15 | from torch.optim import Adam
16 | from torchvision.utils import make_grid
17 |
18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 |
20 |
21 | def train(args):
22 | # Read the config file #
23 | with open(args.config_path, 'r') as file:
24 | try:
25 | config = yaml.safe_load(file)
26 | except yaml.YAMLError as exc:
27 | print(exc)
28 | print(config)
29 |
30 | dataset_config = config['dataset_params']
31 | autoencoder_config = config['autoencoder_params']
32 | train_config = config['train_params']
33 |
34 | # Set the desired seed value #
35 | seed = train_config['seed']
36 | torch.manual_seed(seed)
37 | np.random.seed(seed)
38 | random.seed(seed)
39 |     if device.type == 'cuda':
40 | torch.cuda.manual_seed_all(seed)
41 | #############################
42 |
43 | # Create the model and dataset #
44 | model = VQVAE(im_channels=dataset_config['im_channels'],
45 | model_config=autoencoder_config).to(device)
46 | # Create the dataset
47 | im_dataset_cls = {
48 | 'mnist': MnistDataset,
49 | 'celebhq': CelebDataset,
50 | }.get(dataset_config['name'])
51 |
52 | im_dataset = im_dataset_cls(split='train',
53 | im_path=dataset_config['im_path'],
54 | im_size=dataset_config['im_size'],
55 | im_channels=dataset_config['im_channels'])
56 |
57 | data_loader = DataLoader(im_dataset,
58 | batch_size=train_config['autoencoder_batch_size'],
59 | shuffle=True)
60 |
61 | # Create output directories
62 | if not os.path.exists(train_config['task_name']):
63 | os.mkdir(train_config['task_name'])
64 |
65 | num_epochs = train_config['autoencoder_epochs']
66 |
67 | # L1/L2 loss for Reconstruction
68 | recon_criterion = torch.nn.MSELoss()
69 | # Disc Loss can even be BCEWithLogits
70 | disc_criterion = torch.nn.MSELoss()
71 |
72 | # No need to freeze lpips as lpips.py takes care of that
73 | lpips_model = LPIPS().eval().to(device)
74 | discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device)
75 |
76 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
77 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
78 |
79 | disc_step_start = train_config['disc_start']
80 | step_count = 0
81 |
82 |     # This is for accumulating gradients in case the images are huge
83 |     # and one can't afford higher batch sizes
84 | acc_steps = train_config['autoencoder_acc_steps']
85 | image_save_steps = train_config['autoencoder_img_save_steps']
86 | img_save_count = 0
87 |
88 | for epoch_idx in range(num_epochs):
89 | recon_losses = []
90 | codebook_losses = []
91 | #commitment_losses = []
92 | perceptual_losses = []
93 | disc_losses = []
94 | gen_losses = []
95 | losses = []
96 |
97 | optimizer_g.zero_grad()
98 | optimizer_d.zero_grad()
99 |
100 | for im in tqdm(data_loader):
101 | step_count += 1
102 | im = im.float().to(device)
103 |
104 |             # Fetch autoencoder's output (reconstructions)
105 | model_output = model(im)
106 | output, z, quantize_losses = model_output
107 |
108 | # Image Saving Logic
109 | if step_count % image_save_steps == 0 or step_count == 1:
110 | sample_size = min(8, im.shape[0])
111 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
112 | save_output = ((save_output + 1) / 2)
113 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
114 |
115 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
116 | img = torchvision.transforms.ToPILImage()(grid)
117 | if not os.path.exists(os.path.join(train_config['task_name'],'vqvae_autoencoder_samples')):
118 | os.mkdir(os.path.join(train_config['task_name'], 'vqvae_autoencoder_samples'))
119 | img.save(os.path.join(train_config['task_name'],'vqvae_autoencoder_samples',
120 | 'current_autoencoder_sample_{}.png'.format(img_save_count)))
121 | img_save_count += 1
122 | img.close()
123 |
124 | ######### Optimize Generator ##########
125 | # L2 Loss
126 | recon_loss = recon_criterion(output, im)
127 | recon_losses.append(recon_loss.item())
128 | recon_loss = recon_loss / acc_steps
129 | g_loss = (recon_loss +
130 | (train_config['codebook_weight'] * quantize_losses['codebook_loss'] / acc_steps) +
131 | (train_config['commitment_beta'] * quantize_losses['commitment_loss'] / acc_steps))
132 | codebook_losses.append(train_config['codebook_weight'] * quantize_losses['codebook_loss'].item())
133 | # Adversarial loss only if disc_step_start steps passed
134 | if step_count > disc_step_start:
135 | disc_fake_pred = discriminator(model_output[0])
136 | disc_fake_loss = disc_criterion(disc_fake_pred,
137 | torch.ones(disc_fake_pred.shape,
138 | device=disc_fake_pred.device))
139 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item())
140 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps
141 | lpips_loss = torch.mean(lpips_model(output, im))
142 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item())
143 | g_loss += train_config['perceptual_weight']*lpips_loss / acc_steps
144 | losses.append(g_loss.item())
145 | g_loss.backward()
146 | #####################################
147 |
148 | ######### Optimize Discriminator #######
149 | if step_count > disc_step_start:
150 | fake = output
151 | disc_fake_pred = discriminator(fake.detach())
152 | disc_real_pred = discriminator(im)
153 | disc_fake_loss = disc_criterion(disc_fake_pred,
154 | torch.zeros(disc_fake_pred.shape,
155 | device=disc_fake_pred.device))
156 | disc_real_loss = disc_criterion(disc_real_pred,
157 | torch.ones(disc_real_pred.shape,
158 | device=disc_real_pred.device))
159 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2
160 | disc_losses.append(disc_loss.item())
161 | disc_loss = disc_loss / acc_steps
162 | disc_loss.backward()
163 | if step_count % acc_steps == 0:
164 | optimizer_d.step()
165 | optimizer_d.zero_grad()
166 | #####################################
167 |
168 | if step_count % acc_steps == 0:
169 | optimizer_g.step()
170 | optimizer_g.zero_grad()
171 | optimizer_d.step()
172 | optimizer_d.zero_grad()
173 | optimizer_g.step()
174 | optimizer_g.zero_grad()
175 | if len(disc_losses) > 0:
176 | print(
177 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
178 | 'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'.
179 | format(epoch_idx + 1,
180 | np.mean(recon_losses),
181 | np.mean(perceptual_losses),
182 | np.mean(codebook_losses),
183 | np.mean(gen_losses),
184 | np.mean(disc_losses)))
185 | else:
186 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'.
187 | format(epoch_idx + 1,
188 | np.mean(recon_losses),
189 | np.mean(perceptual_losses),
190 | np.mean(codebook_losses)))
191 |
192 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
193 | train_config['vqvae_autoencoder_ckpt_name']))
194 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'],
195 | train_config['vqvae_discriminator_ckpt_name']))
196 | print('Done Training...')
197 |
198 |
199 | if __name__ == '__main__':
200 | parser = argparse.ArgumentParser(description='Arguments for vq vae training')
201 | parser.add_argument('--config', dest='config_path',
202 | default='config/mnist.yaml', type=str)
203 | args = parser.parse_args()
204 | train(args)
205 |
--------------------------------------------------------------------------------
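The generator and discriminator updates above use gradient accumulation: every loss term is divided by acc_steps and the optimizers only step every acc_steps batches, with one extra step at the end of the epoch to flush whatever is left. A minimal sketch of that pattern on a dummy model:

import torch
from torch.optim import Adam

model = torch.nn.Linear(4, 4)
optimizer = Adam(model.parameters(), lr=1e-3)
acc_steps = 4
step_count = 0

for batch in [torch.randn(8, 4) for _ in range(10)]:
    step_count += 1
    loss = model(batch).pow(2).mean() / acc_steps    # scale so accumulated grads average out
    loss.backward()
    if step_count % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Flush any remaining accumulated gradients, mirroring the end-of-epoch step in train() above.
optimizer.step()
optimizer.zero_grad()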
/tools/sample_ddpm_text_image_cond.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import torchvision
5 | import argparse
6 | import yaml
7 | import os
8 | from torchvision.utils import make_grid
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from models.unet_cond_base import Unet
12 | from models.vqvae import VQVAE
13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
14 | from transformers import DistilBertModel, DistilBertTokenizer, CLIPTokenizer, CLIPTextModel
15 | from utils.config_utils import *
16 | from utils.text_utils import *
17 | from dataset.celeb_dataset import CelebDataset
18 |
19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 |
21 |
22 | def sample(model, scheduler, train_config, diffusion_model_config,
23 | autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model):
24 | r"""
25 | Sample stepwise by going backward one timestep at a time.
26 | We save the x0 predictions
27 | """
28 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
29 |
30 | ########### Sample random noise latent ##########
31 |     # For now, fix generation to a single sample
32 | xt = torch.randn((1,
33 | autoencoder_model_config['z_channels'],
34 | im_size,
35 | im_size)).to(device)
36 | ###############################################
37 |
38 | ############ Create Conditional input ###############
39 | text_prompt = ['She is a woman with blond hair. She is wearing lipstick.']
40 | neg_prompts = ['He is a man.']
41 | empty_prompt = ['']
42 | text_prompt_embed = get_text_representation(text_prompt,
43 | text_tokenizer,
44 | text_model,
45 | device)
46 | # Can replace empty prompt with negative prompt
47 | empty_text_embed = get_text_representation(empty_prompt, text_tokenizer, text_model, device)
48 | assert empty_text_embed.shape == text_prompt_embed.shape
49 |
50 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
51 | validate_image_config(condition_config)
52 |
53 | # This is required to get a random but valid mask
54 | dataset = CelebDataset(split='train',
55 | im_path=dataset_config['im_path'],
56 | im_size=dataset_config['im_size'],
57 | im_channels=dataset_config['im_channels'],
58 | use_latents=True,
59 | latent_path=os.path.join(train_config['task_name'],
60 | train_config['vqvae_latent_dir_name']),
61 | condition_config=condition_config)
62 |     mask_idx = random.randint(0, len(dataset.masks) - 1)  # randint is inclusive on both ends
63 | mask = dataset.get_mask(mask_idx).unsqueeze(0).to(device)
64 | uncond_input = {
65 | 'text': empty_text_embed,
66 | 'image': torch.zeros_like(mask)
67 | }
68 | cond_input = {
69 | 'text': text_prompt_embed,
70 | 'image': mask
71 | }
72 | ###############################################
73 |
74 | # By default classifier free guidance is disabled
75 | # Change value in config or change default value here to enable it
76 | cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0)
77 |
78 | ################# Sampling Loop ########################
79 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
80 | # Get prediction of noise
81 | t = (torch.ones((xt.shape[0],)) * i).long().to(device)
82 | noise_pred_cond = model(xt, t, cond_input)
83 |
84 | if cf_guidance_scale > 1:
85 | noise_pred_uncond = model(xt, t, uncond_input)
86 | noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond)
87 | else:
88 | noise_pred = noise_pred_cond
89 |
90 | # Use scheduler to get x0 and xt-1
91 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
92 |
93 | # Save x0
94 | if i == 0:
95 | # Decode ONLY the final image to save time
96 | ims = vae.decode(xt)
97 | else:
98 | ims = x0_pred
99 |
100 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
101 | ims = (ims + 1) / 2
102 | grid = make_grid(ims, nrow=10)
103 | img = torchvision.transforms.ToPILImage()(grid)
104 |
105 | if not os.path.exists(os.path.join(train_config['task_name'], 'cond_text_image_samples')):
106 | os.mkdir(os.path.join(train_config['task_name'], 'cond_text_image_samples'))
107 | img.save(os.path.join(train_config['task_name'], 'cond_text_image_samples', 'x0_{}.png'.format(i)))
108 | img.close()
109 | ##############################################################
110 |
111 | def infer(args):
112 | # Read the config file #
113 | with open(args.config_path, 'r') as file:
114 | try:
115 | config = yaml.safe_load(file)
116 | except yaml.YAMLError as exc:
117 | print(exc)
118 | print(config)
119 | ########################
120 |
121 | diffusion_config = config['diffusion_params']
122 | dataset_config = config['dataset_params']
123 | diffusion_model_config = config['ldm_params']
124 | autoencoder_model_config = config['autoencoder_params']
125 | train_config = config['train_params']
126 |
127 | ########## Create the noise scheduler #############
128 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
129 | beta_start=diffusion_config['beta_start'],
130 | beta_end=diffusion_config['beta_end'])
131 | ###############################################
132 |
133 | ############# Validate the config #################
134 | condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
135 | assert condition_config is not None, ("This sampling script is for image and text conditional "
136 | "but no conditioning config found")
137 | condition_types = get_config_value(condition_config, 'condition_types', [])
138 | assert 'text' in condition_types, ("This sampling script is for image and text conditional "
139 | "but no text condition found in config")
140 | assert 'image' in condition_types, ("This sampling script is for image and text conditional "
141 | "but no image condition found in config")
142 | validate_text_config(condition_config)
143 | validate_image_config(condition_config)
144 | ###############################################
145 |
146 | ############# Load tokenizer and text model #################
147 | with torch.no_grad():
148 | # Load tokenizer and text model based on config
149 | # Also get empty text representation
150 | text_tokenizer, text_model = get_tokenizer_and_model(condition_config['text_condition_config']
151 | ['text_embed_model'], device=device)
152 | ###############################################
153 |
154 | ########## Load Unet #############
155 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
156 | model_config=diffusion_model_config).to(device)
157 | model.eval()
158 | if os.path.exists(os.path.join(train_config['task_name'],
159 | train_config['ldm_ckpt_name'])):
160 | print('Loaded unet checkpoint')
161 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
162 | train_config['ldm_ckpt_name']),
163 | map_location=device))
164 | else:
165 | raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'],
166 | train_config['ldm_ckpt_name'])))
167 | #####################################
168 |
169 | # Create output directories
170 | if not os.path.exists(train_config['task_name']):
171 | os.mkdir(train_config['task_name'])
172 |
173 | ########## Load VQVAE #############
174 | vae = VQVAE(im_channels=dataset_config['im_channels'],
175 | model_config=autoencoder_model_config).to(device)
176 | vae.eval()
177 |
178 | # Load vae if found
179 | if os.path.exists(os.path.join(train_config['task_name'],
180 | train_config['vqvae_autoencoder_ckpt_name'])):
181 | print('Loaded vae checkpoint')
182 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
183 | train_config['vqvae_autoencoder_ckpt_name']),
184 | map_location=device))
185 | else:
186 | raise Exception('VAE checkpoint {} not found'.format(os.path.join(train_config['task_name'],
187 | train_config['vqvae_autoencoder_ckpt_name'])))
188 | #####################################
189 |
190 | with torch.no_grad():
191 | sample(model, scheduler, train_config, diffusion_model_config,
192 | autoencoder_model_config, diffusion_config, dataset_config, vae, text_tokenizer, text_model)
193 |
194 |
195 | if __name__ == '__main__':
196 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation '
197 | 'with text and mask conditioning')
198 | parser.add_argument('--config', dest='config_path',
199 | default='config/celebhq_text_image_cond.yaml', type=str)
200 | args = parser.parse_args()
201 | infer(args)
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Stable Diffusion Implementation in PyTorch
2 | ========
3 |
4 | This repository implements Stable Diffusion.
5 | As of today the repo provides code to do the following:
6 | * Training and Inference on Unconditional Latent Diffusion Models
7 | * Training a Class Conditional Latent Diffusion Model
8 | * Training a Text Conditioned Latent Diffusion Model
9 | * Training a Semantic Mask Conditioned Latent Diffusion Model
10 | * Any combination of the above three types of conditioning
11 |
12 | For the autoencoder, the repo provides code for both VAE and VQVAE,
13 | but both training stages use only the VQVAE. One can easily switch to the VAE if needed.
14 |
15 | For the diffusion part, only DDPM with a linear noise schedule is implemented as of now.
16 |
17 |
18 | ## Stable Diffusion Tutorial Videos
19 |
20 |
22 |
23 |
24 |
26 |
27 | ___
28 |
29 |
30 | ## Sample Output for Autoencoder on CelebHQ
31 | Image - Top, Reconstructions - Below
32 |
33 |
34 |
35 | ## Sample Output for Unconditional LDM on CelebHQ (not fully converged)
36 |
37 |
38 |
39 | ## Sample Output for Conditional LDM
40 | ### Sample Output for Class Conditioned on MNIST
41 | 
42 | 
43 | 
44 | 
45 | 
46 |
47 | ### Sample Output for Text(using CLIP) and Mask Conditioned on CelebHQ (not converged)
48 |
49 |
50 |
51 | Text - She is a woman with blond hair
52 |
53 |
54 | Text - She is a woman with black hair
55 |
56 | ___
57 |
58 | ## Setup
59 | * Create a new conda environment with python 3.8, activate it, then run the below commands
60 | * `conda activate <environment_name>`
61 | * ```git clone https://github.com/explainingai-code/StableDiffusion-PyTorch.git```
62 | * ```cd StableDiffusion-PyTorch```
63 | * ```pip install -r requirements.txt```
64 | * Download the lpips weights by opening this link in a browser (don't use cURL or wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```models/weights/v0.1/vgg.pth```
65 |
66 | ___
67 |
68 | ## Data Preparation
69 | ### Mnist
70 |
71 | For setting up the mnist dataset follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation
72 |
73 | Ensure the directory structure is the following
74 | ```
75 | StableDiffusion-PyTorch
76 | -> data
77 | -> mnist
78 | -> train
79 | -> images
80 | -> *.png
81 | -> test
82 | -> images
83 | -> *.png
84 | ```
85 |
86 | ### CelebHQ
87 | #### Unconditional
88 | For setting up CelebHQ for unconditional training, simply download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file).
89 |
90 | Ensure the directory structure is the following
91 | ```
92 | StableDiffusion-PyTorch
93 | -> data
94 | -> CelebAMask-HQ
95 | -> CelebA-HQ-img
96 | -> *.jpg
97 |
98 | ```
99 | #### Mask Conditional
100 | For the mask conditional LDM on CelebHQ, additionally do the following:
101 |
102 | Ensure the directory structure is the following
103 | ```
104 | StableDiffusion-PyTorch
105 | -> data
106 | -> CelebAMask-HQ
107 | -> CelebA-HQ-img
108 | -> *.jpg
109 | -> CelebAMask-HQ-mask-anno
110 | -> 0/1/2/3.../14
111 | -> *.png
112 |
113 | ```
114 |
115 | * Run `python -m utils.create_celeb_mask` from repo root to create the mask images from mask annotations
116 |
117 | Ensure the directory structure is the following
118 | ```
119 | StableDiffusion-PyTorch
120 | -> data
121 | -> CelebAMask-HQ
122 | -> CelebA-HQ-img
123 | -> *.jpg
124 | -> CelebAMask-HQ-mask-anno
125 | -> 0/1/2/3.../14
126 | -> *.png
127 | -> CelebAMask-HQ-mask
128 | -> *.png
129 | ```
130 |
131 | #### Text Conditional
132 | For the text conditional LDM on CelebHQ, additionally do the following:
133 | * The repo uses the captions collected as part of this repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file
134 | * Download the captions from the `text` link provided in the repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file#overview
135 | * This will download a `celeba-caption` folder; simply move it inside the `data/CelebAMask-HQ` folder, as that is where the dataset class expects it to be.
136 |
137 | Ensure the directory structure is the following
138 | ```
139 | StableDiffusion-PyTorch
140 | -> data
141 | -> CelebAMask-HQ
142 | -> CelebA-HQ-img
143 | -> *.jpg
144 | -> CelebAMask-HQ-mask-anno
145 | -> 0/1/2/3.../14
146 | -> *.png
147 | -> CelebAMask-HQ-mask
148 | -> *.png
149 | -> celeba-caption
150 | -> *.txt
151 | ```
152 | ---
153 | ## Configuration
154 | The config files allow you to play with different components of the ddpm and autoencoder training
155 | * ```config/mnist.yaml``` - Small autoencoder and ldm can even be trained on CPU
156 | * ```config/celebhq.yaml``` - Configuration used for celebhq dataset
157 |
158 | Relevant configuration parameters
159 |
160 | Most parameters are self-explanatory, but below I mention a couple that are specific to this repo.
161 | * ```autoencoder_acc_steps``` : For accumulating gradients over multiple batches when the image size is too large to fit a larger batch size in memory (see the sketch below)
162 | * ```save_latents``` : Enable this to save the latents during autoencoder inference. That way ddpm training will be faster, since images do not need to be re-encoded every epoch
163 |
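A minimal, self-contained sketch of what gradient accumulation with ```autoencoder_acc_steps``` amounts to (the toy model, data and variable names here are placeholders; the real loop in `tools/train_vqvae.py` accumulates the VQVAE and discriminator losses in the same way):

```
import torch
import torch.nn as nn

model = nn.Linear(8, 8)                            # toy stand-in for the autoencoder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batches = [torch.randn(4, 8) for _ in range(10)]   # toy stand-in for the data loader
acc_steps = 4                                      # plays the role of autoencoder_acc_steps

optimizer.zero_grad()
for step_count, batch in enumerate(batches, start=1):
    loss = ((model(batch) - batch) ** 2).mean()
    (loss / acc_steps).backward()                  # accumulate scaled gradients
    if step_count % acc_steps == 0:                # update weights only every acc_steps batches
        optimizer.step()
        optimizer.zero_grad()
optimizer.step()                                   # flush any leftover accumulated gradients
optimizer.zero_grad()
```
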
164 | ___
165 | ## Training
166 | The repo provides training and inference for Mnist (Unconditional and Class Conditional) and CelebHQ (Unconditional, Text and/or Mask Conditional).
167 |
168 | For working on your own dataset:
169 | * Create your own config and have the path in config point to images (look at `celebhq.yaml` for guidance)
170 | * Create your own dataset class, which just collects all the filenames and returns the image in its `__getitem__` method (see the minimal sketch below). Look at `mnist_dataset.py` or `celeb_dataset.py` for guidance
171 |
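A minimal sketch of such a dataset class (the class name, image path handling and file extension here are hypothetical placeholders; see `celeb_dataset.py` for the full version with conditioning and latent support):

```
import glob
import os

import torchvision
from PIL import Image
from torch.utils.data import Dataset


class MyImageDataset(Dataset):
    # Hypothetical minimal dataset: collect filenames, return image tensors scaled to [-1, 1]
    def __init__(self, split, im_path, im_size=256, im_channels=3):
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        self.images = sorted(glob.glob(os.path.join(im_path, '*.png')))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index]).convert('RGB')
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.im_size),
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()
        return (2 * im_tensor) - 1  # scale from [0, 1] to [-1, 1], matching the repo's convention
```
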
172 | Once the config and dataset are set up:
173 | * Train the auto encoder on your dataset using [this section](#training-autoencoder-for-ldm)
174 | * For training Unconditional LDM follow [this section](#training-unconditional-ldm)
175 | * For class conditional ldm go through [this section](#training-class-conditional-ldm)
176 | * For text conditional ldm go through [this section](#training-text-conditional-ldm)
177 | * For text and mask conditional ldm go through [this section](#training-text-and-mask-conditional-ldm)
178 |
179 |
180 | ## Training AutoEncoder for LDM
181 | * For training the autoencoder on mnist, ensure the right path is mentioned in `mnist.yaml`
182 | * For training the autoencoder on celebhq, ensure the right path is mentioned in `celebhq.yaml`
183 | * For training autoencoder on your own dataset
184 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance)
185 |   * Create your own dataset class, similar to celeb_dataset.py without the conditioning parts
186 | * Map the dataset name to the right class in the training code [here](https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/tools/train_ddpm_vqvae.py#L40)
187 | * For training the autoencoder, run ```python -m tools.train_vqvae --config config/mnist.yaml``` with the desired config file
188 | * For inference using the trained autoencoder, run ```python -m tools.infer_vqvae --config config/mnist.yaml``` with the right config file to generate reconstructions. Use `save_latents` in the config to save the latent files
189 |
190 |
191 | ## Training Unconditional LDM
192 | Train the autoencoder first and set up the dataset accordingly.
193 |
194 | For training the unconditional LDM, map the dataset to the right class in `train_ddpm_vqvae.py`
195 | * ```python -m tools.train_ddpm_vqvae --config config/mnist.yaml``` for training unconditional ddpm using right config
196 | * ```python -m tools.sample_ddpm_vqvae --config config/mnist.yaml``` for generating images using trained ddpm
197 |
198 | ## Training Conditional LDM
199 | For training conditional models we need two changes:
200 | * Dataset classes must provide the additional conditional inputs (see below)
201 | * The config must be extended with an additional conditioning config
202 |
203 | Specifically, the dataset's `__getitem__` will return the following:
204 | * `image_tensor` for unconditional training
205 | * a tuple of `(image_tensor, cond_input)` for conditional training, where cond_input is a dictionary consisting of the keys ```{class/text/image}``` (see the sketch below)
206 |
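For illustration only, a conditional sample could look like the following (a hedged sketch with made-up shapes; the 18 x 512 x 512 mask matches the celebhq text and mask config shown later, and only the keys listed under your config's `condition_types` should be present):

```
import torch

image_tensor = torch.randn(3, 256, 256)           # image scaled to [-1, 1] in the real dataset
cond_input = {
    'class': 7,                                   # integer class id (class conditioning)
    'text': 'a sample caption for image_tensor',  # raw caption string (text conditioning)
    'image': torch.zeros(18, 512, 512),           # NUM_CLASSES x MASK_H x MASK_W mask (image conditioning)
}
sample = (image_tensor, cond_input)               # what __getitem__ returns for conditional training
```
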
207 | ### Training Class Conditional LDM
208 | The repo provides class conditional latent diffusion model training code for the mnist dataset, so one
209 | can follow the same recipe for their own dataset
210 |
211 | * Use `mnist_class_cond.yaml` config file as a guide to create your class conditional config file.
212 |   Specifically, the following new keys need to be modified according to your dataset within `ldm_params`.
213 | * ```
214 | condition_config:
215 | condition_types: ['class']
216 | class_condition_config :
217 | num_classes :
218 | cond_drop_prob :
219 | ```
220 | * Create a dataset class similar to mnist where the `__getitem__` method now returns a tuple of the image tensor and a dictionary of conditional inputs.
221 | * For class, conditional input will ONLY be the integer class
222 | * ```
223 | (image_tensor, {
224 | 'class' : {0/1/.../num_classes}
225 |   })
226 |   ```
227 |   For training the class conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
228 | * ```python -m tools.train_ddpm_cond --config config/mnist_class_cond.yaml``` for training class conditional on mnist
229 | * ```python -m tools.sample_ddpm_class_cond --config config/mnist_class_cond.yaml``` for generating images using the class conditional trained ddpm
230 |
231 | ### Training Text Conditional LDM
232 | The repo provides text conditional latent diffusion model training code for the celebhq dataset, so one
233 | can follow the same recipe for their own dataset
234 |
235 | * Use `celebhq_text_cond.yaml` config file as a guide to create your config file.
236 |   Specifically, the following new keys need to be modified according to your dataset within `ldm_params`.
237 | * ```
238 | condition_config:
239 | condition_types: [ 'text' ]
240 | text_condition_config:
241 | text_embed_model: 'clip' or 'bert'
242 | text_embed_dim: 512 or 768
243 | cond_drop_prob: 0.1
244 | ```
245 | * Create a dataset class similar to celebhq where the `__getitem__` method now returns a tuple of the image tensor and a dictionary of conditional inputs.
246 | * For text, conditional input will ONLY be the caption
247 | * ```
248 | (image_tensor, {
249 | 'text' : 'a sample caption for image_tensor'
250 |   })
251 |   ```
252 | For training the text conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
253 | * ```python -m tools.train_ddpm_cond --config config/celebhq_text_cond.yaml``` for training text conditioned ldm on celebhq
254 | * ```python -m tools.sample_ddpm_text_cond --config config/celebhq_text_cond.yaml``` for generating images using text conditional trained ddpm
255 |
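For reference, a minimal sketch of how a caption is turned into the `text_embed_dim`-sized sequence embeddings that the LDM cross-attends to, using the Hugging Face `transformers` classes imported by the sampling scripts (the checkpoint name below is an assumption for illustration; the exact one used by `utils/text_utils.py` may differ, and a DistilBERT model would give 768-dim states instead of 512):

```
import torch
from transformers import CLIPTokenizer, CLIPTextModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 'openai/clip-vit-base-patch16' is an assumed checkpoint whose text encoder outputs
# 512-dim hidden states, matching text_embed_dim: 512 in the config above.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')
text_model = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch16').eval().to(device)

captions = ['She is a woman with blond hair.']
with torch.no_grad():
    tokens = tokenizer(captions, padding='max_length', truncation=True,
                       max_length=77, return_tensors='pt').to(device)
    text_embed = text_model(**tokens).last_hidden_state  # shape: (batch, 77, 512)
```
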
256 | ### Training Text and Mask Conditional LDM
257 | The repo provides text and mask conditional latent diffusion model training code for the celebhq dataset, so one
258 | can follow the same recipe for their own dataset and can even use it to train a mask-only conditional ldm
259 |
260 | * Use `celebhq_text_image_cond.yaml` config file as a guide to create your config file.
261 |   Specifically, the following new keys need to be modified according to your dataset within `ldm_params`.
262 | * ```
263 | condition_config:
264 | condition_types: [ 'text', 'image' ]
265 | text_condition_config:
266 |       text_embed_model: 'clip' or 'bert'
267 | text_embed_dim: 512 or 768
268 | cond_drop_prob: 0.1
269 | image_condition_config:
270 | image_condition_input_channels: 18
271 | image_condition_output_channels: 3
272 | image_condition_h : 512
273 | image_condition_w : 512
274 | cond_drop_prob: 0.1
275 | ```
276 | * Create a dataset class similar to celebhq where the `__getitem__` method now returns a tuple of the image tensor and a dictionary of conditional inputs.
277 | * For text and mask, conditional input will be caption and mask image
278 | * ```
279 | (image_tensor, {
280 | 'text' : 'a sample caption for image_tensor',
281 | 'image' : NUM_CLASSES x MASK_H x MASK_W
282 |   })
283 |   ```
284 | For training the text and mask conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
285 | * ```python -m tools.train_ddpm_cond --config config/celebhq_text_image_cond.yaml``` for training text and mask conditioned ldm on celebhq
286 | * ```python -m tools.sample_ddpm_text_image_cond --config config/celebhq_text_image_cond.yaml``` for generating images using text and mask conditional trained ddpm
287 |
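A rough sketch, under the assumption that the mask conditioning is a per-class one-hot tensor as the `NUM_CLASSES x MASK_H x MASK_W` shape above suggests, of turning a label mask into the conditioning input (18 classes and 512x512 match `image_condition_input_channels/_h/_w` in the config; the actual mask handling lives in the dataset class and `utils/create_celeb_mask.py`):

```
import torch
import torch.nn.functional as F

num_classes, mask_h, mask_w = 18, 512, 512  # image_condition_input_channels / _h / _w

# Placeholder label map where each pixel holds a class index in [0, num_classes - 1];
# in practice this comes from the saved CelebAMask-HQ-mask pngs.
label_map = torch.randint(0, num_classes, (mask_h, mask_w))

# One-hot encode and move channels first -> (num_classes, mask_h, mask_w)
mask = F.one_hot(label_map, num_classes=num_classes).permute(2, 0, 1).float()
assert mask.shape == (num_classes, mask_h, mask_w)
```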
288 |
289 | ## Output
290 | Outputs will be saved according to the configuration present in yaml files.
291 |
292 | For every run, a folder named after the ```task_name``` key in the config will be created
293 |
294 | During training of the autoencoder the following outputs will be saved
295 | * Latest Autoencoder and discriminator checkpoint in ```task_name``` directory
296 | * Sample reconstructions in ```task_name/vqvae_autoencoder_samples```
297 |
298 | During inference of the autoencoder the following outputs will be saved
299 | * Reconstructions for random images in ```task_name```
300 | * Latents will be saved in ```task_name/vqvae_latent_dir_name``` if mentioned in the config
301 |
302 | During training and inference of the ddpm the following outputs will be saved
303 | * During training of unconditional or conditional DDPM we will save the latest checkpoint in ```task_name``` directory
304 | * During sampling, unconditional sampled image grid for all timesteps in ```task_name/samples/*.png``` . The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0
305 | * During sampling, class conditionally sampled image grid for all timesteps in ```task_name/cond_class_samples/*.png``` . The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0
306 | * During sampling, text only conditionally sampled image grid for all timesteps in ```task_name/cond_text_samples/*.png``` . The final decoded generated image will be `x0_0.png` . Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0
307 | * During sampling, text and mask conditionally sampled image grid for all timesteps in ```task_name/cond_text_image_samples/*.png``` . The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of denoising process from T=999 to T=1. Generated Image is at T=0
308 |
309 |
310 |
311 |
312 |
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_time_embedding(time_steps, temb_dim):
6 | r"""
7 | Convert time steps tensor into an embedding using the
8 | sinusoidal time embedding formula
9 | :param time_steps: 1D tensor of length batch size
10 | :param temb_dim: Dimension of the embedding
11 | :return: BxD embedding representation of B time steps
12 | """
13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
14 |
15 | # factor = 10000^(2i/d_model)
16 | factor = 10000 ** ((torch.arange(
17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
18 | )
19 |
20 | # pos / factor
21 | # timesteps B -> B, 1 -> B, temb_dim
22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
24 | return t_emb
25 |
26 |
27 | class DownBlock(nn.Module):
28 | r"""
29 | Down conv block with attention.
30 |     Sequence of the following blocks
31 | 1. Resnet block with time embedding
32 | 2. Attention block
33 | 3. Downsample
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, t_emb_dim,
37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
38 | super().__init__()
39 | self.num_layers = num_layers
40 | self.down_sample = down_sample
41 | self.attn = attn
42 | self.context_dim = context_dim
43 | self.cross_attn = cross_attn
44 | self.t_emb_dim = t_emb_dim
45 | self.resnet_conv_first = nn.ModuleList(
46 | [
47 | nn.Sequential(
48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
49 | nn.SiLU(),
50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
51 | kernel_size=3, stride=1, padding=1),
52 | )
53 | for i in range(num_layers)
54 | ]
55 | )
56 | if self.t_emb_dim is not None:
57 | self.t_emb_layers = nn.ModuleList([
58 | nn.Sequential(
59 | nn.SiLU(),
60 | nn.Linear(self.t_emb_dim, out_channels)
61 | )
62 | for _ in range(num_layers)
63 | ])
64 | self.resnet_conv_second = nn.ModuleList(
65 | [
66 | nn.Sequential(
67 | nn.GroupNorm(norm_channels, out_channels),
68 | nn.SiLU(),
69 | nn.Conv2d(out_channels, out_channels,
70 | kernel_size=3, stride=1, padding=1),
71 | )
72 | for _ in range(num_layers)
73 | ]
74 | )
75 |
76 | if self.attn:
77 | self.attention_norms = nn.ModuleList(
78 | [nn.GroupNorm(norm_channels, out_channels)
79 | for _ in range(num_layers)]
80 | )
81 |
82 | self.attentions = nn.ModuleList(
83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
84 | for _ in range(num_layers)]
85 | )
86 |
87 | if self.cross_attn:
88 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
89 | self.cross_attention_norms = nn.ModuleList(
90 | [nn.GroupNorm(norm_channels, out_channels)
91 | for _ in range(num_layers)]
92 | )
93 | self.cross_attentions = nn.ModuleList(
94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
95 | for _ in range(num_layers)]
96 | )
97 | self.context_proj = nn.ModuleList(
98 | [nn.Linear(context_dim, out_channels)
99 | for _ in range(num_layers)]
100 | )
101 |
102 | self.residual_input_conv = nn.ModuleList(
103 | [
104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
105 | for i in range(num_layers)
106 | ]
107 | )
108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
109 | 4, 2, 1) if self.down_sample else nn.Identity()
110 |
111 | def forward(self, x, t_emb=None, context=None):
112 | out = x
113 | for i in range(self.num_layers):
114 | # Resnet block of Unet
115 | resnet_input = out
116 | out = self.resnet_conv_first[i](out)
117 | if self.t_emb_dim is not None:
118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
119 | out = self.resnet_conv_second[i](out)
120 | out = out + self.residual_input_conv[i](resnet_input)
121 |
122 | if self.attn:
123 | # Attention block of Unet
124 | batch_size, channels, h, w = out.shape
125 | in_attn = out.reshape(batch_size, channels, h * w)
126 | in_attn = self.attention_norms[i](in_attn)
127 | in_attn = in_attn.transpose(1, 2)
128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
130 | out = out + out_attn
131 |
132 | if self.cross_attn:
133 | assert context is not None, "context cannot be None if cross attention layers are used"
134 | batch_size, channels, h, w = out.shape
135 | in_attn = out.reshape(batch_size, channels, h * w)
136 | in_attn = self.cross_attention_norms[i](in_attn)
137 | in_attn = in_attn.transpose(1, 2)
138 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
139 | context_proj = self.context_proj[i](context)
140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
142 | out = out + out_attn
143 |
144 | # Downsample
145 | out = self.down_sample_conv(out)
146 | return out
147 |
148 |
149 | class MidBlock(nn.Module):
150 | r"""
151 | Mid conv block with attention.
152 | Sequence of following blocks
153 | 1. Resnet block with time embedding
154 | 2. Attention block
155 | 3. Resnet block with time embedding
156 | """
157 |
158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
159 | super().__init__()
160 | self.num_layers = num_layers
161 | self.t_emb_dim = t_emb_dim
162 | self.context_dim = context_dim
163 | self.cross_attn = cross_attn
164 | self.resnet_conv_first = nn.ModuleList(
165 | [
166 | nn.Sequential(
167 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
168 | nn.SiLU(),
169 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
170 | padding=1),
171 | )
172 | for i in range(num_layers + 1)
173 | ]
174 | )
175 |
176 | if self.t_emb_dim is not None:
177 | self.t_emb_layers = nn.ModuleList([
178 | nn.Sequential(
179 | nn.SiLU(),
180 | nn.Linear(t_emb_dim, out_channels)
181 | )
182 | for _ in range(num_layers + 1)
183 | ])
184 | self.resnet_conv_second = nn.ModuleList(
185 | [
186 | nn.Sequential(
187 | nn.GroupNorm(norm_channels, out_channels),
188 | nn.SiLU(),
189 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
190 | )
191 | for _ in range(num_layers + 1)
192 | ]
193 | )
194 |
195 | self.attention_norms = nn.ModuleList(
196 | [nn.GroupNorm(norm_channels, out_channels)
197 | for _ in range(num_layers)]
198 | )
199 |
200 | self.attentions = nn.ModuleList(
201 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
202 | for _ in range(num_layers)]
203 | )
204 | if self.cross_attn:
205 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
206 | self.cross_attention_norms = nn.ModuleList(
207 | [nn.GroupNorm(norm_channels, out_channels)
208 | for _ in range(num_layers)]
209 | )
210 | self.cross_attentions = nn.ModuleList(
211 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
212 | for _ in range(num_layers)]
213 | )
214 | self.context_proj = nn.ModuleList(
215 | [nn.Linear(context_dim, out_channels)
216 | for _ in range(num_layers)]
217 | )
218 | self.residual_input_conv = nn.ModuleList(
219 | [
220 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
221 | for i in range(num_layers + 1)
222 | ]
223 | )
224 |
225 | def forward(self, x, t_emb=None, context=None):
226 | out = x
227 |
228 | # First resnet block
229 | resnet_input = out
230 | out = self.resnet_conv_first[0](out)
231 | if self.t_emb_dim is not None:
232 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
233 | out = self.resnet_conv_second[0](out)
234 | out = out + self.residual_input_conv[0](resnet_input)
235 |
236 | for i in range(self.num_layers):
237 | # Attention Block
238 | batch_size, channels, h, w = out.shape
239 | in_attn = out.reshape(batch_size, channels, h * w)
240 | in_attn = self.attention_norms[i](in_attn)
241 | in_attn = in_attn.transpose(1, 2)
242 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
243 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
244 | out = out + out_attn
245 |
246 | if self.cross_attn:
247 | assert context is not None, "context cannot be None if cross attention layers are used"
248 | batch_size, channels, h, w = out.shape
249 | in_attn = out.reshape(batch_size, channels, h * w)
250 | in_attn = self.cross_attention_norms[i](in_attn)
251 | in_attn = in_attn.transpose(1, 2)
252 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
253 | context_proj = self.context_proj[i](context)
254 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
255 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
256 | out = out + out_attn
257 |
258 |
259 | # Resnet Block
260 | resnet_input = out
261 | out = self.resnet_conv_first[i + 1](out)
262 | if self.t_emb_dim is not None:
263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
264 | out = self.resnet_conv_second[i + 1](out)
265 | out = out + self.residual_input_conv[i + 1](resnet_input)
266 |
267 | return out
268 |
269 |
270 | class UpBlock(nn.Module):
271 | r"""
272 | Up conv block with attention.
273 |     Sequence of the following blocks
274 |     1. Upsample
275 |     2. Concatenate Down block output
276 |     3. Resnet block with time embedding
277 |     4. Attention Block
278 | """
279 |
280 | def __init__(self, in_channels, out_channels, t_emb_dim,
281 | up_sample, num_heads, num_layers, attn, norm_channels):
282 | super().__init__()
283 | self.num_layers = num_layers
284 | self.up_sample = up_sample
285 | self.t_emb_dim = t_emb_dim
286 | self.attn = attn
287 | self.resnet_conv_first = nn.ModuleList(
288 | [
289 | nn.Sequential(
290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
291 | nn.SiLU(),
292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
293 | padding=1),
294 | )
295 | for i in range(num_layers)
296 | ]
297 | )
298 |
299 | if self.t_emb_dim is not None:
300 | self.t_emb_layers = nn.ModuleList([
301 | nn.Sequential(
302 | nn.SiLU(),
303 | nn.Linear(t_emb_dim, out_channels)
304 | )
305 | for _ in range(num_layers)
306 | ])
307 |
308 | self.resnet_conv_second = nn.ModuleList(
309 | [
310 | nn.Sequential(
311 | nn.GroupNorm(norm_channels, out_channels),
312 | nn.SiLU(),
313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
314 | )
315 | for _ in range(num_layers)
316 | ]
317 | )
318 | if self.attn:
319 | self.attention_norms = nn.ModuleList(
320 | [
321 | nn.GroupNorm(norm_channels, out_channels)
322 | for _ in range(num_layers)
323 | ]
324 | )
325 |
326 | self.attentions = nn.ModuleList(
327 | [
328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
329 | for _ in range(num_layers)
330 | ]
331 | )
332 |
333 | self.residual_input_conv = nn.ModuleList(
334 | [
335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
336 | for i in range(num_layers)
337 | ]
338 | )
339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
340 | 4, 2, 1) \
341 | if self.up_sample else nn.Identity()
342 |
343 | def forward(self, x, out_down=None, t_emb=None):
344 | # Upsample
345 | x = self.up_sample_conv(x)
346 |
347 | # Concat with Downblock output
348 | if out_down is not None:
349 | x = torch.cat([x, out_down], dim=1)
350 |
351 | out = x
352 | for i in range(self.num_layers):
353 | # Resnet Block
354 | resnet_input = out
355 | out = self.resnet_conv_first[i](out)
356 | if self.t_emb_dim is not None:
357 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
358 | out = self.resnet_conv_second[i](out)
359 | out = out + self.residual_input_conv[i](resnet_input)
360 |
361 | # Self Attention
362 | if self.attn:
363 | batch_size, channels, h, w = out.shape
364 | in_attn = out.reshape(batch_size, channels, h * w)
365 | in_attn = self.attention_norms[i](in_attn)
366 | in_attn = in_attn.transpose(1, 2)
367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
369 | out = out + out_attn
370 | return out
371 |
372 |
373 | class UpBlockUnet(nn.Module):
374 | r"""
375 | Up conv block with attention.
376 |     Sequence of the following blocks
377 |     1. Upsample
378 |     2. Concatenate Down block output
379 |     3. Resnet block with time embedding
380 |     4. Self Attention followed by optional Cross Attention Block
381 | """
382 |
383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
385 | super().__init__()
386 | self.num_layers = num_layers
387 | self.up_sample = up_sample
388 | self.t_emb_dim = t_emb_dim
389 | self.cross_attn = cross_attn
390 | self.context_dim = context_dim
391 | self.resnet_conv_first = nn.ModuleList(
392 | [
393 | nn.Sequential(
394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
395 | nn.SiLU(),
396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
397 | padding=1),
398 | )
399 | for i in range(num_layers)
400 | ]
401 | )
402 |
403 | if self.t_emb_dim is not None:
404 | self.t_emb_layers = nn.ModuleList([
405 | nn.Sequential(
406 | nn.SiLU(),
407 | nn.Linear(t_emb_dim, out_channels)
408 | )
409 | for _ in range(num_layers)
410 | ])
411 |
412 | self.resnet_conv_second = nn.ModuleList(
413 | [
414 | nn.Sequential(
415 | nn.GroupNorm(norm_channels, out_channels),
416 | nn.SiLU(),
417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
418 | )
419 | for _ in range(num_layers)
420 | ]
421 | )
422 |
423 | self.attention_norms = nn.ModuleList(
424 | [
425 | nn.GroupNorm(norm_channels, out_channels)
426 | for _ in range(num_layers)
427 | ]
428 | )
429 |
430 | self.attentions = nn.ModuleList(
431 | [
432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
433 | for _ in range(num_layers)
434 | ]
435 | )
436 |
437 | if self.cross_attn:
438 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
439 | self.cross_attention_norms = nn.ModuleList(
440 | [nn.GroupNorm(norm_channels, out_channels)
441 | for _ in range(num_layers)]
442 | )
443 | self.cross_attentions = nn.ModuleList(
444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
445 | for _ in range(num_layers)]
446 | )
447 | self.context_proj = nn.ModuleList(
448 | [nn.Linear(context_dim, out_channels)
449 | for _ in range(num_layers)]
450 | )
451 | self.residual_input_conv = nn.ModuleList(
452 | [
453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
454 | for i in range(num_layers)
455 | ]
456 | )
457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
458 | 4, 2, 1) \
459 | if self.up_sample else nn.Identity()
460 |
461 | def forward(self, x, out_down=None, t_emb=None, context=None):
462 | x = self.up_sample_conv(x)
463 | if out_down is not None:
464 | x = torch.cat([x, out_down], dim=1)
465 |
466 | out = x
467 | for i in range(self.num_layers):
468 | # Resnet
469 | resnet_input = out
470 | out = self.resnet_conv_first[i](out)
471 | if self.t_emb_dim is not None:
472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
473 | out = self.resnet_conv_second[i](out)
474 | out = out + self.residual_input_conv[i](resnet_input)
475 | # Self Attention
476 | batch_size, channels, h, w = out.shape
477 | in_attn = out.reshape(batch_size, channels, h * w)
478 | in_attn = self.attention_norms[i](in_attn)
479 | in_attn = in_attn.transpose(1, 2)
480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
482 | out = out + out_attn
483 | # Cross Attention
484 | if self.cross_attn:
485 | assert context is not None, "context cannot be None if cross attention layers are used"
486 | batch_size, channels, h, w = out.shape
487 | in_attn = out.reshape(batch_size, channels, h * w)
488 | in_attn = self.cross_attention_norms[i](in_attn)
489 | in_attn = in_attn.transpose(1, 2)
490 | assert len(context.shape) == 3, \
491 | "Context shape does not match B,_,CONTEXT_DIM"
492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
493 | "Context shape does not match B,_,CONTEXT_DIM"
494 | context_proj = self.context_proj[i](context)
495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
497 | out = out + out_attn
498 |
499 | return out
500 |
501 |
502 |
--------------------------------------------------------------------------------