├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   ├── celebhq.yaml
│   └── mnist.yaml
├── data
│   └── .gitkeep
├── dataset
│   ├── __init__.py
│   ├── celeb_dataset.py
│   └── mnist_dataset.py
├── models
│   ├── __init__.py
│   ├── blocks.py
│   ├── controlnet.py
│   ├── controlnet_ldm.py
│   ├── discriminator.py
│   ├── lpips.py
│   ├── unet_base.py
│   ├── unet_cond_base.py
│   ├── vae.py
│   └── weights
│       └── v0.1
│           └── .gitkeep
├── requirements.txt
├── scheduler
│   ├── __init__.py
│   └── linear_noise_scheduler.py
├── tools
│   ├── __init__.py
│   ├── infer_vae.py
│   ├── sample_ddpm.py
│   ├── sample_ddpm_controlnet.py
│   ├── sample_ldm_controlnet.py
│   ├── sample_ldm_vae.py
│   ├── train_ddpm.py
│   ├── train_ddpm_controlnet.py
│   ├── train_ldm_controlnet.py
│   ├── train_ldm_vae.py
│   └── train_vae.py
└── utils
    ├── __init__.py
    ├── config_utils.py
    └── diffusion_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore PyCharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 | *.json
15 |
16 | # Ignore checkpoints
17 | *.pth
18 |
19 | # Ignore pickle files
20 | *.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ExplainingAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ControlNet Implementation in PyTorch
2 | ========
3 | ## ControlNet Tutorial Video
4 |
5 |
7 |
8 |
9 | ## Sample Output for ControlNet with DDPM on MNIST and with LDM on CelebHQ
10 | Canny Edge Control - Top, Sample - Below
11 |
12 |
13 |
14 | ___
15 |
16 | This repository implements ControlNet in PyTorch for diffusion models.
17 | As of now, the repo provides code to do the following:
18 | * Training and Inference of Unconditional DDPM on MNIST dataset
19 | * Training and Inference of ControlNet with DDPM on MNIST using canny edges
20 | * Training and Inference of Unconditional Latent Diffusion Model on the CelebHQ dataset (resized to 128x128, with latent images being 32x32)
21 | * Training and Inference of ControlNet with Unconditional Latent Diffusion Model on CelebHQ using canny edges
22 |
23 |
24 | For the autoencoder of the Latent Diffusion Model, the repo provides training and inference code for a VAE.
25 |
26 | ## Setup
27 | * Create a new conda environment with Python 3.10, then run the commands below
28 | * `conda activate `
29 | * ```git clone https://github.com/explainingai-code/ControlNet-PyTorch.git```
30 | * ```cd ControlNet-PyTorch```
31 | * ```pip install -r requirements.txt```
32 | * Download the LPIPS weights by opening this link in a browser (don't use cURL or wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```models/weights/v0.1/vgg.pth``` (a quick sanity check is shown below)
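A quick way to sanity-check that the raw weights file (and not GitHub's HTML page) was downloaded is to load it with PyTorch; this is just an illustrative check, not part of the repo:

```python
import torch

# Should load without errors and yield a small state dict of LPIPS linear-layer weights
weights = torch.load('models/weights/v0.1/vgg.pth', map_location='cpu')
print(type(weights), len(weights))
```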
33 | ___
34 |
35 | ## Data Preparation
36 | ### MNIST
37 | 
38 | To set up the MNIST dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation
39 | 
40 | Ensure the directory structure is the following
41 | ```
42 | ControlNet-PyTorch
43 | -> data
44 | -> mnist
45 | -> train
46 | -> images
47 | -> *.png
48 | -> test
49 | -> images
50 | -> *.png
51 | ```
52 |
53 | ### CelebHQ
54 | To set up CelebHQ, download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file).
55 | 
56 | Ensure the directory structure is the following
57 | ```
58 | ControlNet-PyTorch
59 | -> data
60 | -> CelebAMask-HQ
61 | -> CelebA-HQ-img
62 | -> *.jpg
63 |
64 | ```
65 | ---
66 | ## Configuration
67 | The config files let you play with different components of DDPM and autoencoder training.
68 | * ```config/mnist.yaml``` - Config for the MNIST dataset
69 | * ```config/celebhq.yaml``` - Config for the CelebHQ dataset
70 |
71 | Relevant configuration parameters
72 |
73 | Most parameters are self-explanatory, but below are a couple that are specific to this repo.
74 | * ```autoencoder_acc_steps``` : Number of steps to accumulate gradients over, for when the image size is too large to fit bigger batch sizes in memory (see the sketch below)
75 | * ```save_latents``` : Enable this to save the latents during autoencoder inference, so that subsequent diffusion model training on latents is faster
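For reference, a minimal sketch (with hypothetical stand-ins, not the repo's actual training loop) of how gradient accumulation controlled by `autoencoder_acc_steps` is typically applied:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins; the repo's VAE, optimizer and dataloader would take their place
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
data_loader = [torch.randn(4, 3, 128, 128) for _ in range(8)]
acc_steps = 4  # corresponds to autoencoder_acc_steps in the config

optimizer.zero_grad()
for step, im in enumerate(data_loader):
    recon = model(im)
    loss = nn.functional.mse_loss(recon, im)
    # Scale the loss so the accumulated gradient matches one large batch
    (loss / acc_steps).backward()
    if (step + 1) % acc_steps == 0:
        optimizer.step()       # update only every acc_steps mini-batches
        optimizer.zero_grad()
```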
76 |
77 | ___
78 | ## Training
79 | The repo provides training and inference for MNIST (Unconditional DDPM) and CelebHQ (Unconditional LDM), plus ControlNet with both of these variants using canny edges.
80 |
81 | For working on your own dataset:
82 | * Create your own config and have the path in the config point to your images (look at `celebhq.yaml` for guidance)
83 | * Create your own dataset class, which just collects all the filenames and returns the image and its hint in its `__getitem__` method. Look at `mnist_dataset.py` or `celeb_dataset.py` for guidance; a minimal sketch follows this list.
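A minimal sketch of such a dataset class, modelled on `mnist_dataset.py` (the class name, paths and extension are illustrative):

```python
import glob
import os
import cv2
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset


class MyDataset(Dataset):
    # Hypothetical custom dataset: collects filenames, returns (image, canny hint)
    def __init__(self, split, im_path, im_ext='png', return_hints=False):
        self.split = split
        self.return_hints = return_hints
        self.images = glob.glob(os.path.join(im_path, '*.{}'.format(im_ext)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.ToTensor()(im)
        im_tensor = (2 * im_tensor) - 1  # scale to [-1, 1], as the repo's datasets do

        if self.return_hints:
            canny = cv2.Canny(np.array(Image.open(self.images[index])), 100, 200)
            canny = np.concatenate([canny[:, :, None]] * 3, axis=2)  # 3-channel hint
            return im_tensor, torchvision.transforms.ToTensor()(canny)
        return im_tensor
```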
84 |
85 | Once the config and dataset are set up:
86 | * For training and inference of Unconditional DDPM follow [this section](#training-unconditional-ddpm)
87 | * For training and inference of ControlNet with Unconditional DDPM follow [this section](#training-controlnet-for-unconditional-ddpm)
88 | * Train the auto encoder on your dataset using [this section](#training-autoencoder-for-ldm)
89 | * For training and inference of Unconditional LDM follow [this section](#training-unconditional-ldm)
90 | * For training and inference of ControlNet with Unconditional LDM follow [this section](#training-controlnet-for-unconditional-ldm)
91 |
92 |
93 |
94 | ## Training Unconditional DDPM
95 | * For training DDPM on MNIST, ensure the right path is mentioned in `mnist.yaml`
96 | * For training ddpm on your own dataset
97 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance)
98 | * Create your own dataset class, similar to celeb_dataset.py
99 | * Call the desired dataset class in training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ddpm.py#L40)
100 | * For training DDPM run ```python -m tools.train_ddpm --config config/mnist.yaml``` with the desired config file
101 | * For inference run ```python -m tools.sample_ddpm --config config/mnist.yaml``` to generate samples with the right config file.
102 |
103 | ## Training ControlNet for Unconditional DDPM
104 | * For training controlnet, ensure the right path is mentioned in `mnist.yaml`
105 | * For training controlnet with ddpm on your own dataset
106 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance)
107 | * Create your own dataset class, similar to celeb_dataset.py
108 | * Call the desired dataset class in training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ddpm_controlnet.py#L40)
109 | * Ensure ```return_hints``` is passed as True in the dataset class initialization
110 | * For training controlnet run ```python -m tools.train_ddpm_controlnet --config config/mnist.yaml``` with the desired config file
111 | * For inference run ```python -m tools.sample_ddpm_controlnet --config config/mnist.yaml``` to generate DDPM samples using canny hints with the right config file. A usage example for the hint-returning dataset is shown below.
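For reference, a hypothetical snippet showing what the dataset yields once `return_hints=True` (paths and batch size taken from `mnist.yaml`):

```python
from torch.utils.data import DataLoader
from dataset.mnist_dataset import MnistDataset

dataset = MnistDataset('train', im_path='data/mnist/train/images', return_hints=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

im, hint = next(iter(loader))
# im:   64 x 1 x 28 x 28 images in [-1, 1]
# hint: 64 x 3 x 28 x 28 canny edge maps in [0, 1]
print(im.shape, hint.shape)
```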
112 |
113 |
114 | ## Training AutoEncoder for LDM
115 | * For training the autoencoder on CelebHQ, ensure the right path is mentioned in `celebhq.yaml`
116 | * For training autoencoder on your own dataset
117 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance)
118 | * Create your own dataset class, similar to celeb_dataset.py
119 | * Call the desired dataset class in training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_vae.py#L49)
120 | * For training the autoencoder run ```python -m tools.train_vae --config config/celebhq.yaml``` with the desired config file
121 | * For inference make sure `save_latents` is `True` in the config
122 | * For inference run ```python -m tools.infer_vae --config config/celebhq.yaml``` to generate reconstructions and save latents with the right config file.
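Once latents have been saved, the dataset can consume them directly instead of re-encoding images; the snippet below mirrors the `__main__` block of `celeb_dataset.py` (the latent path assumes the default `task_name` and `vae_latent_dir_name` from `celebhq.yaml`):

```python
from dataset.celeb_dataset import CelebDataset

# use_latents makes __getitem__ return the saved latent instead of the image
dataset = CelebDataset('train', im_path='data/CelebAMask-HQ',
                       use_latents=True, latent_path='celebhq/vae_latents')
latent = dataset[0]
```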
123 |
124 |
125 | ## Training Unconditional LDM
126 | Train the autoencoder first and set up the dataset accordingly.
127 | 
128 | For training the unconditional LDM, ensure the right dataset is used in `train_ldm_vae.py` [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ldm_vae.py#L43)
129 | * ```python -m tools.train_ldm_vae --config config/celebhq.yaml``` for training the unconditional LDM with the right config
130 | * ```python -m tools.sample_ldm_vae --config config/celebhq.yaml``` for generating images using the trained LDM
131 |
132 |
133 | ## Training ControlNet for Unconditional LDM
134 | * For training controlnet with celebhq, ensure the right path is mentioned in `celebhq.yaml`
135 | * For training controlnet with ldm on your own dataset
136 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance)
137 | * Create your own dataset class, similar to celeb_dataset.py
138 | * Ensure Autoencoder and LDM have already been trained
139 | * Call the desired dataset class in training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ldm_controlnet.py#L43)
140 | * Ensure ```return_hints``` is passed as True in the dataset class initialization
141 | * Ensure ```down_sample_factor``` is correctly computed in the model initialization [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ldm_controlnet.py#L60)
142 | * For training controlnet run ```python -m tools.train_ldm_controlnet --config config/celebhq.yaml``` with the desired config file
143 | * For inference with controlnet run ```python -m tools.sample_ldm_controlnet --config config/celebhq.yaml``` to generate LDM samples using canny hints with the right config file.
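Regarding the ```down_sample_factor``` mentioned above: the hint block downsamples the canny image from `canny_im_size` to the latent resolution, so for the defaults in `celebhq.yaml` the value works out as follows (illustrative arithmetic, not repo code):

```python
# Values from config/celebhq.yaml
canny_im_size = 1024                    # resolution of the canny hint image
im_size = 128                           # resolution the training images are resized to
autoencoder_down_sample = [True, True]  # autoencoder_params['down_sample']

latent_size = im_size // (2 ** sum(autoencoder_down_sample))  # 128 / 4 = 32
down_sample_factor = canny_im_size // latent_size             # 1024 / 32 = 32
print(down_sample_factor)
```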
144 |
145 |
146 | ## Output
147 | Outputs will be saved according to the configuration present in the yaml files.
148 | 
149 | For every run, a folder named after the ```task_name``` key in the config will be created.
150 |
151 | During training of the autoencoder the following outputs will be saved
152 | * Latest Autoencoder and discriminator checkpoint in ```task_name``` directory
153 | * Sample reconstructions in ```task_name/vae_autoencoder_samples```
154 | 
155 | During inference of the autoencoder the following outputs will be saved
156 | * Reconstructions for random images in ```task_name```
157 | * Latents will be saved in ```task_name/vae_latent_dir_name``` if mentioned in the config
158 |
159 | During training and inference of unconditional DDPM or LDM the following outputs will be saved:
160 | * During training we will save the latest checkpoint in ```task_name``` directory
161 | * During sampling, an unconditionally sampled image grid for all timesteps is saved in ```task_name/samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be the latent image predictions of the denoising process from T=999 to T=1; the generated image is at T=0
162 | 
163 | During training and inference of ControlNet with DDPM/LDM the following outputs will be saved:
164 | * During training we will save the latest checkpoint in ```task_name``` directory
165 | * During sampling, randomly selected hints and generated samples will be saved in ```task_name/hint.png``` and ```task_name/controlnet_samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be the latent image predictions of the denoising process from T=999 to T=1; the generated image is at T=0
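For reference, a hypothetical sketch of how such intermediate `x0_t.png` grids could be written out during sampling (the actual sampling scripts may differ in details):

```python
import torch
import torchvision
from torchvision.utils import make_grid


def save_x0_grid(x0_pred, t, num_grid_rows, out_dir):
    # x0_pred: B x C x H x W prediction of x0 at timestep t, assumed to lie in [-1, 1]
    ims = torch.clamp(x0_pred, -1., 1.).detach().cpu()
    ims = (ims + 1) / 2                           # map back to [0, 1]
    grid = make_grid(ims, nrow=num_grid_rows)
    img = torchvision.transforms.ToPILImage()(grid)
    img.save('{}/x0_{}.png'.format(out_dir, t))   # e.g. samples/x0_999.png ... samples/x0_0.png


# Illustrative call with random data
save_x0_grid(torch.randn(25, 1, 28, 28), t=0, num_grid_rows=5, out_dir='.')
```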
166 |
167 |
168 |
169 | ## Citations
170 | ```
171 | @misc{zhang2023addingconditionalcontroltexttoimage,
172 | title={Adding Conditional Control to Text-to-Image Diffusion Models},
173 | author={Lvmin Zhang and Anyi Rao and Maneesh Agrawala},
174 | year={2023},
175 | eprint={2302.05543},
176 | archivePrefix={arXiv},
177 | primaryClass={cs.CV},
178 | url={https://arxiv.org/abs/2302.05543},
179 | }
180 | ```
181 |
182 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/config/__init__.py
--------------------------------------------------------------------------------
/config/celebhq.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/CelebAMask-HQ'
3 | im_channels : 3
4 | im_size : 128
5 | canny_im_size : 1024
6 | name: 'celebhq'
7 |
8 | diffusion_params:
9 | num_timesteps : 1000
10 | beta_start : 0.0015
11 | beta_end : 0.0195
12 |
13 | ldm_params:
14 | hint_channels : 3
15 | down_channels: [ 256, 384, 512, 768 ]
16 | mid_channels: [ 768, 512 ]
17 | down_sample: [ True, True, True ]
18 | attn_down : [True, True, True]
19 | time_emb_dim: 512
20 | norm_channels: 32
21 | num_heads: 16
22 | conv_out_channels : 128
23 | num_down_layers : 2
24 | num_mid_layers : 2
25 | num_up_layers : 2
26 |
27 | autoencoder_params:
28 | z_channels: 4
29 | codebook_size : 8192
30 | down_channels : [128, 256, 384]
31 | mid_channels : [384]
32 | down_sample : [True, True]
33 | attn_down : [False, False]
34 | norm_channels: 32
35 | num_heads: 4
36 | num_down_layers : 2
37 | num_mid_layers : 2
38 | num_up_layers : 2
39 |
40 |
41 | train_params:
42 | seed : 1111
43 | task_name: 'celebhq'
44 | ldm_batch_size: 16
45 | autoencoder_batch_size: 4
46 | disc_start: 7500
47 | disc_weight: 0.5
48 | codebook_weight: 1
49 | commitment_beta: 0.2
50 | perceptual_weight: 1
51 | kl_weight: 0.000005
52 | ldm_epochs: 200
53 | autoencoder_epochs: 3
54 | controlnet_epochs : 15
55 | num_samples: 2
56 | num_grid_rows: 2
57 | ldm_lr: 0.000025
58 | ldm_lr_steps : [25, 50, 75, 100]
59 | autoencoder_lr: 0.00001
60 | controlnet_lr: 0.00001
61 | controlnet_lr_steps : [10]
62 | autoencoder_acc_steps: 1
63 | autoencoder_img_save_steps: 64
64 | save_latents : True
65 | vae_latent_dir_name: 'vae_latents'
66 | vqvae_latent_dir_name: 'vqvae_latents'
67 | ldm_ckpt_name: 'ddpm_ckpt.pth'
68 | controlnet_ckpt_name: 'ddpm_controlnet_ckpt.pth'
69 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
70 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
71 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
72 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
73 |
--------------------------------------------------------------------------------
/config/mnist.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'data/mnist/train/images'
3 | im_test_path: 'data/mnist/test/images'
4 | canny_im_size: 28
5 |
6 | diffusion_params:
7 | num_timesteps : 1000
8 | beta_start : 0.0001
9 | beta_end : 0.02
10 |
11 | model_params:
12 | im_channels : 1
13 | im_size : 28
14 | hint_channels : 3
15 | down_channels : [32, 64, 128, 256]
16 | mid_channels : [256, 256, 128]
17 | down_sample : [True, True, False]
18 | time_emb_dim : 128
19 | num_down_layers : 2
20 | num_mid_layers : 2
21 | num_up_layers : 2
22 | num_heads : 4
23 |
24 | train_params:
25 | task_name: 'mnist'
26 | batch_size: 64
27 | num_epochs: 40
28 | controlnet_epochs : 1
29 | num_samples : 25
30 | num_grid_rows : 5
31 | ddpm_lr: 0.0001
32 | controlnet_lr: 0.0001
33 | ddpm_ckpt_name: 'ddpm_ckpt.pth'
34 | controlnet_ckpt_name: 'ddpm_controlnet_ckpt.pth'
35 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/data/.gitkeep
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/celeb_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import cv2
4 | import torchvision
5 | import numpy as np
6 | from PIL import Image
7 | from utils.diffusion_utils import load_latents
8 | from tqdm import tqdm
9 | from torch.utils.data.dataset import Dataset
10 |
11 |
12 | class CelebDataset(Dataset):
13 | r"""
14 | Celeb dataset will by default centre crop and resize the images.
15 |     This can be replaced by any other dataset, as long as all the images
16 | are under one directory.
17 | """
18 |
19 | def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
20 | use_latents=False, latent_path=None, return_hint=False):#, condition_config=None):
21 | self.split = split
22 | self.im_size = im_size
23 | self.im_channels = im_channels
24 | self.im_ext = im_ext
25 | self.im_path = im_path
26 | self.latent_maps = None
27 | self.use_latents = False
28 | self.return_hints = return_hint
29 | # self.condition_types = [] if condition_config is None else condition_config['condition_types']
30 |
31 | # self.idx_to_cls_map = {}
32 | # self.cls_to_idx_map = {}
33 |
34 | # if 'image' in self.condition_types:
35 | # self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels']
36 | # self.mask_h = condition_config['image_condition_config']['image_condition_h']
37 | # self.mask_w = condition_config['image_condition_config']['image_condition_w']
38 |
39 | #self.images, self.texts, self.masks = self.load_images(im_path)
40 | self.images = self.load_images(im_path)
41 |
42 | # Whether to load images or to load latents
43 | if use_latents and latent_path is not None:
44 | latent_maps = load_latents(latent_path)
45 | if len(latent_maps) == len(self.images):
46 | self.use_latents = True
47 | self.latent_maps = latent_maps
48 | print('Found {} latents'.format(len(self.latent_maps)))
49 | else:
50 | print('Latents not found')
51 |
52 | def load_images(self, im_path):
53 | r"""
54 | Gets all images from the path specified
55 | and stacks them all up
56 | """
57 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
58 | ims = []
59 | fnames = glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('png')))
60 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpg')))
61 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpeg')))
62 | texts = []
63 | masks = []
64 |
65 | # if 'image' in self.condition_types:
66 | # label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth',
67 | # 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
68 | # self.idx_to_cls_map = {idx: label_list[idx] for idx in range(len(label_list))}
69 | # self.cls_to_idx_map = {label_list[idx]: idx for idx in range(len(label_list))}
70 |
71 | for fname in tqdm(fnames):
72 | ims.append(fname)
73 |
74 | # if 'text' in self.condition_types:
75 | # im_name = os.path.split(fname)[1].split('.')[0]
76 | # captions_im = []
77 | # with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f:
78 | # for line in f.readlines():
79 | # captions_im.append(line.strip())
80 | # texts.append(captions_im)
81 |
82 | # if 'image' in self.condition_types:
83 | # im_name = int(os.path.split(fname)[1].split('.')[0])
84 | # masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))
85 | # if 'text' in self.condition_types:
86 | # assert len(texts) == len(ims), "Condition Type Text but could not find captions for all images"
87 | # if 'image' in self.condition_types:
88 | # assert len(masks) == len(ims), "Condition Type Image but could not find masks for all images"
89 | print('Found {} images'.format(len(ims)))
90 | #print('Found {} masks'.format(len(masks)))
91 | #print('Found {} captions'.format(len(texts)))
92 | return ims#, texts, masks
93 |
94 | # def get_mask(self, index):
95 | # r"""
96 | # Method to get the mask of WxH
97 | # for given index and convert it into
98 | # Classes x W x H mask image
99 | # :param index:
100 | # :return:
101 | # """
102 | # mask_im = Image.open(self.masks[index])
103 | # mask_im = np.array(mask_im)
104 | # im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels))
105 | # for orig_idx in range(len(self.idx_to_cls_map)):
106 | # im_base[mask_im == (orig_idx + 1), orig_idx] = 1
107 | # mask = torch.from_numpy(im_base).permute(2, 0, 1).float()
108 | # return mask
109 |
110 | def __len__(self):
111 | return len(self.images)
112 |
113 | def __getitem__(self, index):
114 | ######## Set Conditioning Info ########
115 | # cond_inputs = {}
116 | # if 'text' in self.condition_types:
117 | # cond_inputs['text'] = random.sample(self.texts[index], k=1)[0]
118 | # if 'image' in self.condition_types:
119 | # mask = self.get_mask(index)
120 | # cond_inputs['image'] = mask
121 | #######################################
122 | # im = Image.open(self.images[index])
123 | # im.save('original_image.png')
124 | # canny_image = np.array(im)
125 | # print(self.images[index])
126 | # low_threshold = 100
127 | # high_threshold = 200
128 | # import cv2
129 | # canny_image = cv2.Canny(canny_image, low_threshold, high_threshold)
130 | # canny_image = canny_image[:, :, None]
131 | # canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
132 | # canny_image = 255 - canny_image
133 | # canny_image = Image.fromarray(canny_image)
134 | # canny_image.save('canny_image.png')
135 | # print(list(self.latent_maps.keys())[0])
136 | # print(self.images[index] in self.latent_maps)
137 | # print(self.images[index].replace('../', '') in self.latent_maps)
138 | # latent = self.latent_maps[self.images[index].replace('../', '')]
139 | # latent = torch.clamp(latent, -1., 1.)
140 | # latent = (latent + 1) / 2
141 | # latent = torchvision.transforms.ToPILImage()(latent[0:-1, :, :])
142 | # latent.save('latent_image.png')
143 | # exit(0)
144 |
145 | if self.use_latents:
146 | latent = self.latent_maps[self.images[index]]
147 | if self.return_hints:
148 | canny_image = Image.open(self.images[index])
149 | canny_image = np.array(canny_image)
150 | canny_image = cv2.Canny(canny_image, 100, 200)
151 | canny_image = canny_image[:, :, None]
152 | canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
153 | canny_image_tensor = torchvision.transforms.ToTensor()(canny_image)
154 | return latent, canny_image_tensor
155 | else:
156 | return latent
157 |
158 | else:
159 | im = Image.open(self.images[index])
160 | im_tensor = torchvision.transforms.Compose([
161 | torchvision.transforms.Resize(self.im_size),
162 | torchvision.transforms.CenterCrop(self.im_size),
163 | torchvision.transforms.ToTensor(),
164 | ])(im)
165 | im.close()
166 |
167 | # Convert input to -1 to 1 range.
168 | im_tensor = (2 * im_tensor) - 1
169 |
170 | if self.return_hints:
171 | canny_image = Image.open(self.images[index])
172 | canny_image = np.array(canny_image)
173 | canny_image = cv2.Canny(canny_image, 100, 200)
174 | canny_image = canny_image[:, :, None]
175 | canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
176 | canny_image_tensor = torchvision.transforms.ToTensor()(canny_image)
177 | return im_tensor, canny_image_tensor
178 | else:
179 | return im_tensor
180 |
181 |
182 | if __name__ == '__main__':
183 | mnist = CelebDataset('train', im_path='../data/CelebAMask-HQ',
184 | use_latents=True, latent_path='../celebhq/vae_latents')
185 | x = mnist[1800]
186 |
--------------------------------------------------------------------------------
/dataset/mnist_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torchvision
6 | from PIL import Image
7 | from tqdm import tqdm
8 | from torch.utils.data.dataset import Dataset
9 |
10 |
11 | class MnistDataset(Dataset):
12 | r"""
13 | Nothing special here. Just a simple dataset class for mnist images.
14 |     Created a dataset class rather than using torchvision to allow
15 | replacement with any other image dataset
16 | """
17 | def __init__(self, split, im_path, im_ext='png', im_size=28, return_hints=False):
18 | r"""
19 | Init method for initializing the dataset properties
20 | :param split: train/test to locate the image files
21 | :param im_path: root folder of images
22 | :param im_ext: image extension. assumes all
23 | images would be this type.
24 | """
25 | self.split = split
26 | self.im_ext = im_ext
27 | self.return_hints = return_hints
28 | self.images = self.load_images(im_path)
29 |
30 | def load_images(self, im_path):
31 | r"""
32 | Gets all images from the path specified
33 | and stacks them all up
34 | :param im_path:
35 | :return:
36 | """
37 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
38 | ims = []
39 | labels = []
40 | for d_name in tqdm(os.listdir(im_path)):
41 | for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):
42 | ims.append(fname)
43 | print('Found {} images for split {}'.format(len(ims), self.split))
44 | return ims
45 |
46 | def __len__(self):
47 | return len(self.images)
48 |
49 | def __getitem__(self, index):
50 | im = Image.open(self.images[index])
51 | im_tensor = torchvision.transforms.ToTensor()(im)
52 |
53 | # Convert input to -1 to 1 range.
54 | im_tensor = (2 * im_tensor) - 1
55 |
56 | if self.return_hints:
57 | canny_image = Image.open(self.images[index])
58 | canny_image = np.array(canny_image)
59 | canny_image = cv2.Canny(canny_image, 100, 200)
60 | canny_image = canny_image[:, :, None]
61 | canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
62 | canny_image_tensor = torchvision.transforms.ToTensor()(canny_image)
63 | return im_tensor, canny_image_tensor
64 | else:
65 | return im_tensor
66 |
67 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/models/__init__.py
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_time_embedding(time_steps, temb_dim):
6 | r"""
7 | Convert time steps tensor into an embedding using the
8 | sinusoidal time embedding formula
9 | :param time_steps: 1D tensor of length batch size
10 | :param temb_dim: Dimension of the embedding
11 | :return: BxD embedding representation of B time steps
12 | """
13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
14 |
15 | # factor = 10000^(2i/d_model)
16 | factor = 10000 ** ((torch.arange(
17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
18 | )
19 |
20 | # pos / factor
21 | # timesteps B -> B, 1 -> B, temb_dim
22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
24 | return t_emb
25 |
26 |
27 | class DownBlock(nn.Module):
28 | r"""
29 | Down conv block with attention.
30 |     Sequence of the following blocks
31 | 1. Resnet block with time embedding
32 | 2. Attention block
33 | 3. Downsample
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, t_emb_dim,
37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
38 | super().__init__()
39 | self.num_layers = num_layers
40 | self.down_sample = down_sample
41 | self.attn = attn
42 | self.context_dim = context_dim
43 | self.cross_attn = cross_attn
44 | self.t_emb_dim = t_emb_dim
45 | self.resnet_conv_first = nn.ModuleList(
46 | [
47 | nn.Sequential(
48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
49 | nn.SiLU(),
50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
51 | kernel_size=3, stride=1, padding=1),
52 | )
53 | for i in range(num_layers)
54 | ]
55 | )
56 | if self.t_emb_dim is not None:
57 | self.t_emb_layers = nn.ModuleList([
58 | nn.Sequential(
59 | nn.SiLU(),
60 | nn.Linear(self.t_emb_dim, out_channels)
61 | )
62 | for _ in range(num_layers)
63 | ])
64 | self.resnet_conv_second = nn.ModuleList(
65 | [
66 | nn.Sequential(
67 | nn.GroupNorm(norm_channels, out_channels),
68 | nn.SiLU(),
69 | nn.Conv2d(out_channels, out_channels,
70 | kernel_size=3, stride=1, padding=1),
71 | )
72 | for _ in range(num_layers)
73 | ]
74 | )
75 |
76 | if self.attn:
77 | self.attention_norms = nn.ModuleList(
78 | [nn.GroupNorm(norm_channels, out_channels)
79 | for _ in range(num_layers)]
80 | )
81 |
82 | self.attentions = nn.ModuleList(
83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
84 | for _ in range(num_layers)]
85 | )
86 |
87 | if self.cross_attn:
88 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
89 | self.cross_attention_norms = nn.ModuleList(
90 | [nn.GroupNorm(norm_channels, out_channels)
91 | for _ in range(num_layers)]
92 | )
93 | self.cross_attentions = nn.ModuleList(
94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
95 | for _ in range(num_layers)]
96 | )
97 | self.context_proj = nn.ModuleList(
98 | [nn.Linear(context_dim, out_channels)
99 | for _ in range(num_layers)]
100 | )
101 |
102 | self.residual_input_conv = nn.ModuleList(
103 | [
104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
105 | for i in range(num_layers)
106 | ]
107 | )
108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
109 | 4, 2, 1) if self.down_sample else nn.Identity()
110 |
111 | def forward(self, x, t_emb=None, context=None):
112 | out = x
113 | for i in range(self.num_layers):
114 | # Resnet block of Unet
115 | resnet_input = out
116 | out = self.resnet_conv_first[i](out)
117 | if self.t_emb_dim is not None:
118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
119 | out = self.resnet_conv_second[i](out)
120 | out = out + self.residual_input_conv[i](resnet_input)
121 |
122 | if self.attn:
123 | # Attention block of Unet
124 | batch_size, channels, h, w = out.shape
125 | in_attn = out.reshape(batch_size, channels, h * w)
126 | in_attn = self.attention_norms[i](in_attn)
127 | in_attn = in_attn.transpose(1, 2)
128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
130 | out = out + out_attn
131 |
132 | if self.cross_attn:
133 | assert context is not None, "context cannot be None if cross attention layers are used"
134 | batch_size, channels, h, w = out.shape
135 | in_attn = out.reshape(batch_size, channels, h * w)
136 | in_attn = self.cross_attention_norms[i](in_attn)
137 | in_attn = in_attn.transpose(1, 2)
138 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
139 | context_proj = self.context_proj[i](context)
140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
142 | out = out + out_attn
143 |
144 | # Downsample
145 | out = self.down_sample_conv(out)
146 | return out
147 |
148 |
149 | class MidBlock(nn.Module):
150 | r"""
151 | Mid conv block with attention.
152 | Sequence of following blocks
153 | 1. Resnet block with time embedding
154 | 2. Attention block
155 | 3. Resnet block with time embedding
156 | """
157 |
158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None,
159 | context_dim=None):
160 | super().__init__()
161 | self.num_layers = num_layers
162 | self.t_emb_dim = t_emb_dim
163 | self.context_dim = context_dim
164 | self.cross_attn = cross_attn
165 | self.resnet_conv_first = nn.ModuleList(
166 | [
167 | nn.Sequential(
168 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
169 | nn.SiLU(),
170 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
171 | padding=1),
172 | )
173 | for i in range(num_layers + 1)
174 | ]
175 | )
176 |
177 | if self.t_emb_dim is not None:
178 | self.t_emb_layers = nn.ModuleList([
179 | nn.Sequential(
180 | nn.SiLU(),
181 | nn.Linear(t_emb_dim, out_channels)
182 | )
183 | for _ in range(num_layers + 1)
184 | ])
185 | self.resnet_conv_second = nn.ModuleList(
186 | [
187 | nn.Sequential(
188 | nn.GroupNorm(norm_channels, out_channels),
189 | nn.SiLU(),
190 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
191 | )
192 | for _ in range(num_layers + 1)
193 | ]
194 | )
195 |
196 | self.attention_norms = nn.ModuleList(
197 | [nn.GroupNorm(norm_channels, out_channels)
198 | for _ in range(num_layers)]
199 | )
200 |
201 | self.attentions = nn.ModuleList(
202 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
203 | for _ in range(num_layers)]
204 | )
205 | if self.cross_attn:
206 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
207 | self.cross_attention_norms = nn.ModuleList(
208 | [nn.GroupNorm(norm_channels, out_channels)
209 | for _ in range(num_layers)]
210 | )
211 | self.cross_attentions = nn.ModuleList(
212 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
213 | for _ in range(num_layers)]
214 | )
215 | self.context_proj = nn.ModuleList(
216 | [nn.Linear(context_dim, out_channels)
217 | for _ in range(num_layers)]
218 | )
219 | self.residual_input_conv = nn.ModuleList(
220 | [
221 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
222 | for i in range(num_layers + 1)
223 | ]
224 | )
225 |
226 | def forward(self, x, t_emb=None, context=None):
227 | out = x
228 |
229 | # First resnet block
230 | resnet_input = out
231 | out = self.resnet_conv_first[0](out)
232 | if self.t_emb_dim is not None:
233 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
234 | out = self.resnet_conv_second[0](out)
235 | out = out + self.residual_input_conv[0](resnet_input)
236 |
237 | for i in range(self.num_layers):
238 | # Attention Block
239 | batch_size, channels, h, w = out.shape
240 | in_attn = out.reshape(batch_size, channels, h * w)
241 | in_attn = self.attention_norms[i](in_attn)
242 | in_attn = in_attn.transpose(1, 2)
243 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
244 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
245 | out = out + out_attn
246 |
247 | if self.cross_attn:
248 | assert context is not None, "context cannot be None if cross attention layers are used"
249 | batch_size, channels, h, w = out.shape
250 | in_attn = out.reshape(batch_size, channels, h * w)
251 | in_attn = self.cross_attention_norms[i](in_attn)
252 | in_attn = in_attn.transpose(1, 2)
253 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
254 | context_proj = self.context_proj[i](context)
255 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
256 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
257 | out = out + out_attn
258 |
259 | # Resnet Block
260 | resnet_input = out
261 | out = self.resnet_conv_first[i + 1](out)
262 | if self.t_emb_dim is not None:
263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
264 | out = self.resnet_conv_second[i + 1](out)
265 | out = out + self.residual_input_conv[i + 1](resnet_input)
266 |
267 | return out
268 |
269 |
270 | class UpBlock(nn.Module):
271 | r"""
272 | Up conv block with attention.
273 | Sequence of following blocks
274 |     1. Upsample
275 |     2. Concatenate Down block output
276 |     3. Resnet block with time embedding
277 |     4. Attention Block
278 | """
279 |
280 | def __init__(self, in_channels, out_channels, t_emb_dim,
281 | up_sample, num_heads, num_layers, attn, norm_channels):
282 | super().__init__()
283 | self.num_layers = num_layers
284 | self.up_sample = up_sample
285 | self.t_emb_dim = t_emb_dim
286 | self.attn = attn
287 | self.resnet_conv_first = nn.ModuleList(
288 | [
289 | nn.Sequential(
290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
291 | nn.SiLU(),
292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
293 | padding=1),
294 | )
295 | for i in range(num_layers)
296 | ]
297 | )
298 |
299 | if self.t_emb_dim is not None:
300 | self.t_emb_layers = nn.ModuleList([
301 | nn.Sequential(
302 | nn.SiLU(),
303 | nn.Linear(t_emb_dim, out_channels)
304 | )
305 | for _ in range(num_layers)
306 | ])
307 |
308 | self.resnet_conv_second = nn.ModuleList(
309 | [
310 | nn.Sequential(
311 | nn.GroupNorm(norm_channels, out_channels),
312 | nn.SiLU(),
313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
314 | )
315 | for _ in range(num_layers)
316 | ]
317 | )
318 | if self.attn:
319 | self.attention_norms = nn.ModuleList(
320 | [
321 | nn.GroupNorm(norm_channels, out_channels)
322 | for _ in range(num_layers)
323 | ]
324 | )
325 |
326 | self.attentions = nn.ModuleList(
327 | [
328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
329 | for _ in range(num_layers)
330 | ]
331 | )
332 |
333 | self.residual_input_conv = nn.ModuleList(
334 | [
335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
336 | for i in range(num_layers)
337 | ]
338 | )
339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
340 | 4, 2, 1) \
341 | if self.up_sample else nn.Identity()
342 |
343 | def forward(self, x, out_down=None, t_emb=None):
344 | # Upsample
345 | x = self.up_sample_conv(x)
346 |
347 | # Concat with Downblock output
348 | if out_down is not None:
349 | x = torch.cat([x, out_down], dim=1)
350 |
351 | out = x
352 | for i in range(self.num_layers):
353 | # Resnet Block
354 | resnet_input = out
355 | out = self.resnet_conv_first[i](out)
356 | if self.t_emb_dim is not None:
357 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
358 | out = self.resnet_conv_second[i](out)
359 | out = out + self.residual_input_conv[i](resnet_input)
360 |
361 | # Self Attention
362 | if self.attn:
363 | batch_size, channels, h, w = out.shape
364 | in_attn = out.reshape(batch_size, channels, h * w)
365 | in_attn = self.attention_norms[i](in_attn)
366 | in_attn = in_attn.transpose(1, 2)
367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
369 | out = out + out_attn
370 | return out
371 |
372 |
373 | class UpBlockUnet(nn.Module):
374 | r"""
375 | Up conv block with attention.
376 | Sequence of following blocks
377 |     1. Upsample
378 |     2. Concatenate Down block output
379 |     3. Resnet block with time embedding
380 |     4. Attention Block
381 | """
382 |
383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
385 | super().__init__()
386 | self.num_layers = num_layers
387 | self.up_sample = up_sample
388 | self.t_emb_dim = t_emb_dim
389 | self.cross_attn = cross_attn
390 | self.context_dim = context_dim
391 | self.resnet_conv_first = nn.ModuleList(
392 | [
393 | nn.Sequential(
394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
395 | nn.SiLU(),
396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
397 | padding=1),
398 | )
399 | for i in range(num_layers)
400 | ]
401 | )
402 |
403 | if self.t_emb_dim is not None:
404 | self.t_emb_layers = nn.ModuleList([
405 | nn.Sequential(
406 | nn.SiLU(),
407 | nn.Linear(t_emb_dim, out_channels)
408 | )
409 | for _ in range(num_layers)
410 | ])
411 |
412 | self.resnet_conv_second = nn.ModuleList(
413 | [
414 | nn.Sequential(
415 | nn.GroupNorm(norm_channels, out_channels),
416 | nn.SiLU(),
417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
418 | )
419 | for _ in range(num_layers)
420 | ]
421 | )
422 |
423 | self.attention_norms = nn.ModuleList(
424 | [
425 | nn.GroupNorm(norm_channels, out_channels)
426 | for _ in range(num_layers)
427 | ]
428 | )
429 |
430 | self.attentions = nn.ModuleList(
431 | [
432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
433 | for _ in range(num_layers)
434 | ]
435 | )
436 |
437 | if self.cross_attn:
438 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
439 | self.cross_attention_norms = nn.ModuleList(
440 | [nn.GroupNorm(norm_channels, out_channels)
441 | for _ in range(num_layers)]
442 | )
443 | self.cross_attentions = nn.ModuleList(
444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
445 | for _ in range(num_layers)]
446 | )
447 | self.context_proj = nn.ModuleList(
448 | [nn.Linear(context_dim, out_channels)
449 | for _ in range(num_layers)]
450 | )
451 | self.residual_input_conv = nn.ModuleList(
452 | [
453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
454 | for i in range(num_layers)
455 | ]
456 | )
457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
458 | 4, 2, 1) \
459 | if self.up_sample else nn.Identity()
460 |
461 | def forward(self, x, out_down=None, t_emb=None, context=None):
462 | x = self.up_sample_conv(x)
463 | if out_down is not None:
464 | x = torch.cat([x, out_down], dim=1)
465 |
466 | out = x
467 | for i in range(self.num_layers):
468 | # Resnet
469 | resnet_input = out
470 | out = self.resnet_conv_first[i](out)
471 | if self.t_emb_dim is not None:
472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
473 | out = self.resnet_conv_second[i](out)
474 | out = out + self.residual_input_conv[i](resnet_input)
475 | # Self Attention
476 | batch_size, channels, h, w = out.shape
477 | in_attn = out.reshape(batch_size, channels, h * w)
478 | in_attn = self.attention_norms[i](in_attn)
479 | in_attn = in_attn.transpose(1, 2)
480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
482 | out = out + out_attn
483 | # Cross Attention
484 | if self.cross_attn:
485 | assert context is not None, "context cannot be None if cross attention layers are used"
486 | batch_size, channels, h, w = out.shape
487 | in_attn = out.reshape(batch_size, channels, h * w)
488 | in_attn = self.cross_attention_norms[i](in_attn)
489 | in_attn = in_attn.transpose(1, 2)
490 | assert len(context.shape) == 3, \
491 | "Context shape does not match B,_,CONTEXT_DIM"
492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
493 | "Context shape does not match B,_,CONTEXT_DIM"
494 | context_proj = self.context_proj[i](context)
495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
497 | out = out + out_attn
498 |
499 | return out
500 |
501 |
502 |
--------------------------------------------------------------------------------
/models/controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.unet_base import Unet
4 | from models.unet_base import get_time_embedding
5 |
6 |
7 | def make_zero_module(module):
8 | for p in module.parameters():
9 | p.detach().zero_()
10 | return module
11 |
12 |
13 | class ControlNet(nn.Module):
14 | r"""
15 | Control Net Module for DDPM
16 | """
17 | def __init__(self, model_config,
18 | model_locked=True,
19 | model_ckpt=None,
20 | device=None):
21 | super().__init__()
22 | # Trained DDPM
23 | self.model_locked = model_locked
24 | self.trained_unet = Unet(model_config)
25 |
26 | # Load weights for the trained model
27 | if model_ckpt is not None and device is not None:
28 | print('Loading Trained Diffusion Model')
29 | self.trained_unet.load_state_dict(torch.load(model_ckpt,
30 | map_location=device), strict=True)
31 |
32 | # ControlNet Copy of Trained DDPM
33 | # use_up = False removes the upblocks(decoder layers) from DDPM Unet
34 | self.control_copy_unet = Unet(model_config, use_up=False)
35 | # Load same weights as the trained model
36 | if model_ckpt is not None and device is not None:
37 | print('Loading Control Diffusion Model')
38 | self.control_copy_unet.load_state_dict(torch.load(model_ckpt,
39 | map_location=device), strict=False)
40 |
41 | # Hint Block for ControlNet
42 | # Stack of Conv activation and zero convolution at the end
43 | self.control_copy_unet_hint_block = nn.Sequential(
44 | nn.Conv2d(model_config['hint_channels'],
45 | 64,
46 | kernel_size=3,
47 | padding=(1, 1)),
48 | nn.SiLU(),
49 | nn.Conv2d(64,
50 | 128,
51 | kernel_size=3,
52 | padding=(1, 1)),
53 | nn.SiLU(),
54 | nn.Conv2d(128,
55 | self.trained_unet.down_channels[0],
56 | kernel_size=3,
57 | padding=(1, 1)),
58 | nn.SiLU(),
59 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[0],
60 | self.trained_unet.down_channels[0],
61 | kernel_size=1,
62 | padding=0))
63 | )
64 |
65 | # Zero Convolution Module for Downblocks(encoder Layers)
66 | self.control_copy_unet_down_zero_convs = nn.ModuleList([
67 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[i],
68 | self.trained_unet.down_channels[i],
69 | kernel_size=1,
70 | padding=0))
71 | for i in range(len(self.trained_unet.down_channels)-1)
72 | ])
73 |
74 | # Zero Convolution Module for MidBlocks
75 | self.control_copy_unet_mid_zero_convs = nn.ModuleList([
76 | make_zero_module(nn.Conv2d(self.trained_unet.mid_channels[i],
77 | self.trained_unet.mid_channels[i],
78 | kernel_size=1,
79 | padding=0))
80 | for i in range(1, len(self.trained_unet.mid_channels))
81 | ])
82 |
83 | def get_params(self):
84 | # Add all ControlNet parameters
85 | # First is our copy of unet
86 | params = list(self.control_copy_unet.parameters())
87 |
88 | # Add parameters of hint Blocks & Zero convolutions for down/mid blocks
89 | params += list(self.control_copy_unet_hint_block.parameters())
90 | params += list(self.control_copy_unet_down_zero_convs.parameters())
91 | params += list(self.control_copy_unet_mid_zero_convs.parameters())
92 |
93 | # If we desire to not have the decoder layers locked, then add
94 | # them as well
95 | if not self.model_locked:
96 | params += list(self.trained_unet.ups.parameters())
97 | params += list(self.trained_unet.norm_out.parameters())
98 | params += list(self.trained_unet.conv_out.parameters())
99 | return params
100 |
101 | def forward(self, x, t, hint):
102 | # Time embedding and timestep projection layers of trained unet
103 | trained_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(),
104 | self.trained_unet.t_emb_dim)
105 | trained_unet_t_emb = self.trained_unet.t_proj(trained_unet_t_emb)
106 |
107 | # Get all downblocks output of trained unet first
108 | trained_unet_down_outs = []
109 | with torch.no_grad():
110 | train_unet_out = self.trained_unet.conv_in(x)
111 | for idx, down in enumerate(self.trained_unet.downs):
112 | trained_unet_down_outs.append(train_unet_out)
113 | train_unet_out = down(train_unet_out, trained_unet_t_emb)
114 |
115 | # ControlNet Layers start here #
116 | # Time embedding and timestep projection layers of controlnet's copy of unet
117 | control_copy_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(),
118 | self.control_copy_unet.t_emb_dim)
119 | control_copy_unet_t_emb = self.control_copy_unet.t_proj(control_copy_unet_t_emb)
120 |
121 | # Hint block of controlnet's copy of unet
122 | control_copy_unet_hint_out = self.control_copy_unet_hint_block(hint)
123 |
124 | # Call conv_in layer for controlnet's copy of unet
125 | # and add hint blocks output to it
126 | control_copy_unet_out = self.control_copy_unet.conv_in(x)
127 | control_copy_unet_out += control_copy_unet_hint_out
128 |
129 | # Get all downblocks output for controlnet's copy of unet
130 | control_copy_unet_down_outs = []
131 | for idx, down in enumerate(self.control_copy_unet.downs):
132 | # Save the control nets copy output after passing it through zero conv layers
133 | control_copy_unet_down_outs.append(
134 | self.control_copy_unet_down_zero_convs[idx](control_copy_unet_out)
135 | )
136 | control_copy_unet_out = down(control_copy_unet_out, control_copy_unet_t_emb)
137 |
138 | for idx in range(len(self.control_copy_unet.mids)):
139 | # Get midblock output of controlnets copy of unet
140 | control_copy_unet_out = self.control_copy_unet.mids[idx](
141 | control_copy_unet_out,
142 | control_copy_unet_t_emb
143 | )
144 |
145 | # Get midblock output of trained unet
146 | train_unet_out = self.trained_unet.mids[idx](train_unet_out, trained_unet_t_emb)
147 |
148 | # Add midblock output of controlnets copy of unet to that of trained unet
149 | # but after passing them through zero conv layers
150 | train_unet_out += self.control_copy_unet_mid_zero_convs[idx](control_copy_unet_out)
151 |
152 | # Call upblocks of trained unet
153 | for up in self.trained_unet.ups:
154 | # Get downblocks output from both trained unet and controlnets copy of unet
155 | trained_unet_down_out = trained_unet_down_outs.pop()
156 | control_copy_unet_down_out = control_copy_unet_down_outs.pop()
157 |
158 | # Add these together and pass this as downblock input to upblock
159 | train_unet_out = up(train_unet_out,
160 | control_copy_unet_down_out + trained_unet_down_out,
161 | trained_unet_t_emb)
162 |
163 | # Call output layers of trained unet
164 | train_unet_out = self.trained_unet.norm_out(train_unet_out)
165 | train_unet_out = nn.SiLU()(train_unet_out)
166 | train_unet_out = self.trained_unet.conv_out(train_unet_out)
167 | # out B x C x H x W
168 | return train_unet_out
169 |
170 |
171 |
172 |
173 |
174 |
175 |
--------------------------------------------------------------------------------
/models/controlnet_ldm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.unet_cond_base import Unet
4 | from models.blocks import get_time_embedding
5 |
6 |
7 | def make_zero_module(module):
8 | for p in module.parameters():
9 | p.detach().zero_()
10 | return module
11 |
12 |
13 | class ControlNet(nn.Module):
14 | r"""
15 |     ControlNet Module for LDM
16 | """
17 | def __init__(self, im_channels,
18 | model_config,
19 | model_locked=True,
20 | model_ckpt=None,
21 | device=None,
22 | down_sample_factor=32):
23 | super().__init__()
24 | # Trained DDPM
25 | self.model_locked = model_locked
26 | self.trained_unet = Unet(im_channels, model_config)
27 |
28 | # Load weights for the trained model
29 | if model_ckpt is not None and device is not None:
30 | print('Loading Trained Diffusion Model')
31 | self.trained_unet.load_state_dict(torch.load(model_ckpt,
32 | map_location=device), strict=True)
33 |
34 | # ControlNet Copy of Trained DDPM
35 | # use_up = False removes the upblocks(decoder layers) from DDPM Unet
36 | self.control_unet = Unet(im_channels, model_config, use_up=False)
37 | # Load same weights as the trained model
38 | if model_ckpt is not None and device is not None:
39 | print('Loading Control Diffusion Model')
40 | self.control_unet.load_state_dict(torch.load(model_ckpt,
41 | map_location=device), strict=False)
42 |
43 | # Hint Block for ControlNet
44 | # Stack of Conv activation and zero convolution at the end
45 | base_hint_channel = 16
46 | curr_down_sample_factor = down_sample_factor
47 | hint_layers = [nn.Sequential(
48 | nn.Conv2d(model_config['hint_channels'],
49 | base_hint_channel,
50 | kernel_size=3,
51 | padding=(1, 1)),
52 | nn.SiLU())]
53 | while curr_down_sample_factor > 1:
54 | hint_layers.append(nn.Sequential(
55 | nn.Conv2d(base_hint_channel,
56 | base_hint_channel*2,
57 | kernel_size=3,
58 | padding=(1, 1),
59 | stride=2),
60 | nn.SiLU(),
61 | nn.Conv2d(base_hint_channel*2,
62 | base_hint_channel*2,
63 | kernel_size=3,
64 | padding=(1, 1))
65 | ))
66 | base_hint_channel = base_hint_channel * 2
67 | curr_down_sample_factor = curr_down_sample_factor / 2
68 | hint_layers.append(nn.Sequential(
69 | nn.Conv2d(base_hint_channel,
70 | self.trained_unet.down_channels[0],
71 | kernel_size=3,
72 | padding=(1, 1)),
73 | nn.SiLU(),
74 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[0],
75 | self.trained_unet.down_channels[0],
76 | kernel_size=1,
77 | padding=0))
78 | ))
79 | self.control_unet_hint_block = nn.Sequential(*hint_layers)
80 |
81 | # Zero Convolution Module for Downblocks(encoder Layers)
82 | self.control_unet_down_zero_convs = nn.ModuleList([
83 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[i],
84 | self.trained_unet.down_channels[i],
85 | kernel_size=1,
86 | padding=0))
87 | for i in range(len(self.trained_unet.down_channels)-1)
88 | ])
89 |
90 | # Zero Convolution Module for MidBlocks
91 | self.control_unet_mid_zero_convs = nn.ModuleList([
92 | make_zero_module(nn.Conv2d(self.trained_unet.mid_channels[i],
93 | self.trained_unet.mid_channels[i],
94 | kernel_size=1,
95 | padding=0))
96 | for i in range(1, len(self.trained_unet.mid_channels))
97 | ])
98 |
99 | def get_params(self):
100 | # Add all ControlNet parameters
101 | # First is our copy of unet
102 | params = list(self.control_unet.parameters())
103 |
104 | # Add parameters of hint Blocks & Zero convolutions for down/mid blocks
105 | params += list(self.control_unet_hint_block.parameters())
106 | params += list(self.control_unet_down_zero_convs.parameters())
107 | params += list(self.control_unet_mid_zero_convs.parameters())
108 |
109 | # If we desire to not have the decoder layers locked, then add
110 | # them as well
111 | if not self.model_locked:
112 | params += list(self.trained_unet.ups.parameters())
113 | params += list(self.trained_unet.norm_out.parameters())
114 | params += list(self.trained_unet.conv_out.parameters())
115 | return params
116 |
117 | def forward(self, x, t, hint):
118 | # Time embedding and timestep projection layers of trained unet
119 | trained_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(),
120 | self.trained_unet.t_emb_dim)
121 | trained_unet_t_emb = self.trained_unet.t_proj(trained_unet_t_emb)
122 |
123 | # Get all downblocks output of trained unet first
124 | trained_unet_down_outs = []
125 | with torch.no_grad():
126 | train_unet_out = self.trained_unet.conv_in(x)
127 | for idx, down in enumerate(self.trained_unet.downs):
128 | trained_unet_down_outs.append(train_unet_out)
129 | train_unet_out = down(train_unet_out, trained_unet_t_emb)
130 |
131 | # ControlNet Layers start here #
132 | # Time embedding and timestep projection layers of controlnet's copy of unet
133 | control_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(),
134 | self.control_unet.t_emb_dim)
135 | control_unet_t_emb = self.control_unet.t_proj(control_unet_t_emb)
136 |
137 | # Hint block of controlnet's copy of unet
138 | control_unet_hint_out = self.control_unet_hint_block(hint)
139 |
140 | # Call conv_in layer for controlnet's copy of unet
141 | # and add hint blocks output to it
142 | control_unet_out = self.control_unet.conv_in(x)
143 | control_unet_out += control_unet_hint_out
144 |
145 | # Get all downblocks output for controlnet's copy of unet
146 | control_unet_down_outs = []
147 | for idx, down in enumerate(self.control_unet.downs):
148 |             # Save the ControlNet copy's output after passing it through zero conv layers
149 | control_unet_down_outs.append(self.control_unet_down_zero_convs[idx](control_unet_out))
150 | control_unet_out = down(control_unet_out, control_unet_t_emb)
151 |
152 | for idx in range(len(self.control_unet.mids)):
153 |             # Get midblock output of ControlNet's copy of unet
154 | control_unet_out = self.control_unet.mids[idx](control_unet_out, control_unet_t_emb)
155 |
156 | # Get midblock output of trained unet
157 | train_unet_out = self.trained_unet.mids[idx](train_unet_out, trained_unet_t_emb)
158 |
159 |             # Add midblock output of ControlNet's copy of unet to that of trained unet
160 |             # but only after passing it through the zero conv layers
161 | train_unet_out += self.control_unet_mid_zero_convs[idx](control_unet_out)
162 |
163 | # Call upblocks of trained unet
164 | for up in self.trained_unet.ups:
165 |             # Get downblock outputs from both trained unet and ControlNet's copy of unet
166 | trained_unet_down_out = trained_unet_down_outs.pop()
167 | control_unet_down_out = control_unet_down_outs.pop()
168 |
169 | # Add these together and pass this as downblock input to upblock
170 | train_unet_out = up(train_unet_out,
171 | control_unet_down_out + trained_unet_down_out,
172 | trained_unet_t_emb)
173 |
174 | # Call output layers of trained unet
175 | train_unet_out = self.trained_unet.norm_out(train_unet_out)
176 | train_unet_out = nn.SiLU()(train_unet_out)
177 | train_unet_out = self.trained_unet.conv_out(train_unet_out)
178 | # out B x C x H x W
179 | return train_unet_out
180 |
181 |
182 |
183 |
184 |
185 |
186 |
--------------------------------------------------------------------------------
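Note on `models/controlnet_ldm.py` above: every trainable branch (the hint block's final projection and the down/mid adapters) is wrapped in `make_zero_module`, defined earlier in the file, so the ControlNet initially contributes nothing and the locked LDM's behaviour is preserved at the start of training. A minimal sketch of what such a zero-initialization helper typically looks like (the repo's actual implementation may differ slightly):

```python
import torch.nn as nn

def zero_module_sketch(module: nn.Module) -> nn.Module:
    # Zero every parameter so the wrapped conv starts as a no-op;
    # gradients still flow through it, so it gradually learns to
    # inject the hint signal into the locked UNet during training.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
```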
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Discriminator(nn.Module):
6 | r"""
7 | PatchGAN Discriminator.
8 |     Rather than mapping an IMG_CHANNELS x IMG_H x IMG_W image all the
9 |     way down to a single scalar value, we instead predict a grid of values,
10 |     where each grid cell is the discriminator's prediction of how likely
11 |     the image patch corresponding to
12 |     that grid cell is to be real.
13 | """
14 |
15 | def __init__(self, im_channels=3,
16 | conv_channels=[64, 128, 256],
17 | kernels=[4,4,4,4],
18 | strides=[2,2,2,1],
19 | paddings=[1,1,1,1]):
20 | super().__init__()
21 | self.im_channels = im_channels
22 | activation = nn.LeakyReLU(0.2)
23 | layers_dim = [self.im_channels] + conv_channels + [1]
24 | self.layers = nn.ModuleList([
25 | nn.Sequential(
26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1],
27 | kernel_size=kernels[i],
28 | stride=strides[i],
29 | padding=paddings[i],
30 | bias=False if i !=0 else True),
31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
32 | activation if i != len(layers_dim) - 2 else nn.Identity()
33 | )
34 | for i in range(len(layers_dim) - 1)
35 | ])
36 |
37 | def forward(self, x):
38 | out = x
39 | for layer in self.layers:
40 | out = layer(out)
41 | return out
42 |
43 |
44 | if __name__ == '__main__':
45 | x = torch.randn((2,3, 256, 256))
46 | prob = Discriminator(im_channels=3)(x)
47 | print(prob.shape)
48 |
--------------------------------------------------------------------------------
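With the default kernels [4,4,4,4], strides [2,2,2,1] and paddings [1,1,1,1], a 256x256 input is reduced to a 31x31 grid of patch logits, which is what the `__main__` check prints as `torch.Size([2, 1, 31, 31])`. A hedged sketch of how these patch logits would typically feed a BCE-style adversarial loss (the loss actually used during VAE training lives in tools/train_vae.py and may be weighted differently):

```python
import torch
import torch.nn as nn
from models.discriminator import Discriminator  # run from the repo root

disc = Discriminator(im_channels=3)
real_imgs = torch.randn(2, 3, 256, 256)
fake_imgs = torch.randn(2, 3, 256, 256)   # stand-in for VAE reconstructions

bce = nn.BCEWithLogitsLoss()
real_logits = disc(real_imgs)             # (2, 1, 31, 31) patch predictions
fake_logits = disc(fake_imgs.detach())
d_loss = 0.5 * (bce(real_logits, torch.ones_like(real_logits)) +
                bce(fake_logits, torch.zeros_like(fake_logits)))
```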
/models/lpips.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import namedtuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn
9 | import torchvision
10 |
11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 | def spatial_average(in_tens, keepdim=True):
19 | return in_tens.mean([2, 3], keepdim=keepdim)
20 |
21 |
22 | class vgg16(torch.nn.Module):
23 | def __init__(self, requires_grad=False, pretrained=True):
24 | super(vgg16, self).__init__()
25 | # Load pretrained vgg model from torchvision
26 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features
27 | self.slice1 = torch.nn.Sequential()
28 | self.slice2 = torch.nn.Sequential()
29 | self.slice3 = torch.nn.Sequential()
30 | self.slice4 = torch.nn.Sequential()
31 | self.slice5 = torch.nn.Sequential()
32 | self.N_slices = 5
33 | for x in range(4):
34 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
35 | for x in range(4, 9):
36 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
37 | for x in range(9, 16):
38 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
39 | for x in range(16, 23):
40 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
41 | for x in range(23, 30):
42 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
43 |
44 | # Freeze vgg model
45 | if not requires_grad:
46 | for param in self.parameters():
47 | param.requires_grad = False
48 |
49 | def forward(self, X):
50 | # Return output of vgg features
51 | h = self.slice1(X)
52 | h_relu1_2 = h
53 | h = self.slice2(h)
54 | h_relu2_2 = h
55 | h = self.slice3(h)
56 | h_relu3_3 = h
57 | h = self.slice4(h)
58 | h_relu4_3 = h
59 | h = self.slice5(h)
60 | h_relu5_3 = h
61 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
62 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
63 | return out
64 |
65 |
66 | # Learned perceptual metric
67 | class LPIPS(nn.Module):
68 | def __init__(self, net='vgg', version='0.1', use_dropout=True):
69 | super(LPIPS, self).__init__()
70 | self.version = version
71 | # Imagenet normalization
72 | self.scaling_layer = ScalingLayer()
73 | ########################
74 |
75 | # Instantiate vgg model
76 | self.chns = [64, 128, 256, 512, 512]
77 | self.L = len(self.chns)
78 | self.net = vgg16(pretrained=True, requires_grad=False)
79 |
80 | # Add 1x1 convolutional Layers
81 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
82 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
83 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
84 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
85 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
86 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
87 | self.lins = nn.ModuleList(self.lins)
88 | ########################
89 |
90 | # Load the weights of trained LPIPS model
91 | import inspect
92 | import os
93 | model_path = os.path.abspath(
94 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
95 | print('Loading model from: %s' % model_path)
96 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
97 | ########################
98 |
99 | # Freeze all parameters
100 | self.eval()
101 | for param in self.parameters():
102 | param.requires_grad = False
103 | ########################
104 |
105 | def forward(self, in0, in1, normalize=False):
106 | # Scale the inputs to -1 to +1 range if needed
107 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
108 | in0 = 2 * in0 - 1
109 | in1 = 2 * in1 - 1
110 | ########################
111 |
112 | # Normalize the inputs according to imagenet normalization
113 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
114 | ########################
115 |
116 | # Get VGG outputs for image0 and image1
117 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
118 | feats0, feats1, diffs = {}, {}, {}
119 | ########################
120 |
121 | # Compute Square of Difference for each layer output
122 | for kk in range(self.L):
123 | feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize(
124 |                 outs1[kk], dim=1)
125 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
126 | ########################
127 |
128 | # 1x1 convolution followed by spatial average on the square differences
129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
130 | val = 0
131 |
132 | # Aggregate the results of each layer
133 | for l in range(self.L):
134 | val += res[l]
135 | return val
136 |
137 |
138 | class ScalingLayer(nn.Module):
139 | def __init__(self):
140 | super(ScalingLayer, self).__init__()
141 |         # Imagenet normalization for (0-1)
142 | # mean = [0.485, 0.456, 0.406]
143 | # std = [0.229, 0.224, 0.225]
144 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
145 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
146 |
147 | def forward(self, inp):
148 | return (inp - self.shift) / self.scale
149 |
150 |
151 | class NetLinLayer(nn.Module):
152 | ''' A single linear layer which does a 1x1 conv '''
153 |
154 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
155 | super(NetLinLayer, self).__init__()
156 |
157 | layers = [nn.Dropout(), ] if (use_dropout) else []
158 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
159 | self.model = nn.Sequential(*layers)
160 |
161 | def forward(self, x):
162 | out = self.model(x)
163 | return out
164 |
--------------------------------------------------------------------------------
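A short usage sketch of the perceptual metric above, assuming the LPIPS linear-layer weights (vgg.pth) have been placed under models/weights/v0.1/ as the loader expects, that torchvision can fetch the pretrained VGG16 backbone, and that inputs are already in the [-1, 1] range the VAE works with:

```python
import torch
from models.lpips import LPIPS

lpips_loss = LPIPS().eval()                 # loads models/weights/v0.1/vgg.pth
img0 = torch.rand(4, 3, 128, 128) * 2 - 1   # dummy batch in [-1, 1]
img1 = torch.rand(4, 3, 128, 128) * 2 - 1

with torch.no_grad():
    dist = lpips_loss(img0, img1)           # (4, 1, 1, 1) perceptual distances
print(dist.squeeze())
```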
/models/unet_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_time_embedding(time_steps, temb_dim):
6 | r"""
7 | Convert time steps tensor into an embedding using the
8 | sinusoidal time embedding formula
9 | :param time_steps: 1D tensor of length batch size
10 | :param temb_dim: Dimension of the embedding
11 | :return: BxD embedding representation of B time steps
12 | """
13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
14 |
15 | # factor = 10000^(2i/d_model)
16 | factor = 10000 ** ((torch.arange(
17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
18 | )
19 |
20 | # pos / factor
21 | # timesteps B -> B, 1 -> B, temb_dim
22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
24 | return t_emb
25 |
26 |
27 | class DownBlock(nn.Module):
28 | r"""
29 | Down conv block with attention.
30 |     Sequence of following blocks
31 |     1. Resnet block with time embedding
32 |     2. Attention block
33 |     3. Downsample using a strided convolution
34 | """
35 | def __init__(self, in_channels, out_channels, t_emb_dim,
36 | down_sample=True, num_heads=4, num_layers=1):
37 | super().__init__()
38 | self.num_layers = num_layers
39 | self.down_sample = down_sample
40 | self.resnet_conv_first = nn.ModuleList(
41 | [
42 | nn.Sequential(
43 | nn.GroupNorm(8, in_channels if i == 0 else out_channels),
44 | nn.SiLU(),
45 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
46 | kernel_size=3, stride=1, padding=1),
47 | )
48 | for i in range(num_layers)
49 | ]
50 | )
51 | self.t_emb_layers = nn.ModuleList([
52 | nn.Sequential(
53 | nn.SiLU(),
54 | nn.Linear(t_emb_dim, out_channels)
55 | )
56 | for _ in range(num_layers)
57 | ])
58 | self.resnet_conv_second = nn.ModuleList(
59 | [
60 | nn.Sequential(
61 | nn.GroupNorm(8, out_channels),
62 | nn.SiLU(),
63 | nn.Conv2d(out_channels, out_channels,
64 | kernel_size=3, stride=1, padding=1),
65 | )
66 | for _ in range(num_layers)
67 | ]
68 | )
69 | self.attention_norms = nn.ModuleList(
70 | [nn.GroupNorm(8, out_channels)
71 | for _ in range(num_layers)]
72 | )
73 |
74 | self.attentions = nn.ModuleList(
75 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
76 | for _ in range(num_layers)]
77 | )
78 | self.residual_input_conv = nn.ModuleList(
79 | [
80 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
81 | for i in range(num_layers)
82 | ]
83 | )
84 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
85 | 4, 2, 1) if self.down_sample else nn.Identity()
86 |
87 | def forward(self, x, t_emb):
88 | out = x
89 | for i in range(self.num_layers):
90 |
91 | # Resnet block of Unet
92 | resnet_input = out
93 | out = self.resnet_conv_first[i](out)
94 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
95 | out = self.resnet_conv_second[i](out)
96 | out = out + self.residual_input_conv[i](resnet_input)
97 |
98 | # Attention block of Unet
99 | batch_size, channels, h, w = out.shape
100 | in_attn = out.reshape(batch_size, channels, h * w)
101 | in_attn = self.attention_norms[i](in_attn)
102 | in_attn = in_attn.transpose(1, 2)
103 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
104 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
105 | out = out + out_attn
106 |
107 | out = self.down_sample_conv(out)
108 | return out
109 |
110 |
111 | class MidBlock(nn.Module):
112 | r"""
113 | Mid conv block with attention.
114 | Sequence of following blocks
115 | 1. Resnet block with time embedding
116 | 2. Attention block
117 | 3. Resnet block with time embedding
118 | """
119 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
120 | super().__init__()
121 | self.num_layers = num_layers
122 | self.resnet_conv_first = nn.ModuleList(
123 | [
124 | nn.Sequential(
125 | nn.GroupNorm(8, in_channels if i == 0 else out_channels),
126 | nn.SiLU(),
127 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
128 | padding=1),
129 | )
130 | for i in range(num_layers+1)
131 | ]
132 | )
133 | self.t_emb_layers = nn.ModuleList([
134 | nn.Sequential(
135 | nn.SiLU(),
136 | nn.Linear(t_emb_dim, out_channels)
137 | )
138 | for _ in range(num_layers + 1)
139 | ])
140 | self.resnet_conv_second = nn.ModuleList(
141 | [
142 | nn.Sequential(
143 | nn.GroupNorm(8, out_channels),
144 | nn.SiLU(),
145 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
146 | )
147 | for _ in range(num_layers+1)
148 | ]
149 | )
150 |
151 | self.attention_norms = nn.ModuleList(
152 | [nn.GroupNorm(8, out_channels)
153 | for _ in range(num_layers)]
154 | )
155 |
156 | self.attentions = nn.ModuleList(
157 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
158 | for _ in range(num_layers)]
159 | )
160 | self.residual_input_conv = nn.ModuleList(
161 | [
162 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
163 | for i in range(num_layers+1)
164 | ]
165 | )
166 |
167 | def forward(self, x, t_emb):
168 | out = x
169 |
170 | # First resnet block
171 | resnet_input = out
172 | out = self.resnet_conv_first[0](out)
173 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
174 | out = self.resnet_conv_second[0](out)
175 | out = out + self.residual_input_conv[0](resnet_input)
176 |
177 | for i in range(self.num_layers):
178 |
179 | # Attention Block
180 | batch_size, channels, h, w = out.shape
181 | in_attn = out.reshape(batch_size, channels, h * w)
182 | in_attn = self.attention_norms[i](in_attn)
183 | in_attn = in_attn.transpose(1, 2)
184 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
185 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
186 | out = out + out_attn
187 |
188 | # Resnet Block
189 | resnet_input = out
190 | out = self.resnet_conv_first[i+1](out)
191 | out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]
192 | out = self.resnet_conv_second[i+1](out)
193 | out = out + self.residual_input_conv[i+1](resnet_input)
194 |
195 | return out
196 |
197 |
198 | class UpBlock(nn.Module):
199 | r"""
200 | Up conv block with attention.
201 |     Sequence of following blocks
202 |     1. Upsample
203 |     2. Concatenate Down block output
204 |     3. Resnet block with time embedding
205 |     4. Attention Block
206 | """
207 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
208 | super().__init__()
209 | self.num_layers = num_layers
210 | self.up_sample = up_sample
211 | self.resnet_conv_first = nn.ModuleList(
212 | [
213 | nn.Sequential(
214 | nn.GroupNorm(8, in_channels if i == 0 else out_channels),
215 | nn.SiLU(),
216 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
217 | padding=1),
218 | )
219 | for i in range(num_layers)
220 | ]
221 | )
222 | self.t_emb_layers = nn.ModuleList([
223 | nn.Sequential(
224 | nn.SiLU(),
225 | nn.Linear(t_emb_dim, out_channels)
226 | )
227 | for _ in range(num_layers)
228 | ])
229 | self.resnet_conv_second = nn.ModuleList(
230 | [
231 | nn.Sequential(
232 | nn.GroupNorm(8, out_channels),
233 | nn.SiLU(),
234 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
235 | )
236 | for _ in range(num_layers)
237 | ]
238 | )
239 |
240 | self.attention_norms = nn.ModuleList(
241 | [
242 | nn.GroupNorm(8, out_channels)
243 | for _ in range(num_layers)
244 | ]
245 | )
246 |
247 | self.attentions = nn.ModuleList(
248 | [
249 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
250 | for _ in range(num_layers)
251 | ]
252 | )
253 | self.residual_input_conv = nn.ModuleList(
254 | [
255 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
256 | for i in range(num_layers)
257 | ]
258 | )
259 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
260 | 4, 2, 1) \
261 | if self.up_sample else nn.Identity()
262 |
263 | def forward(self, x, out_down, t_emb):
264 | x = self.up_sample_conv(x)
265 | x = torch.cat([x, out_down], dim=1)
266 |
267 | out = x
268 | for i in range(self.num_layers):
269 | # Resnet Block
270 | resnet_input = out
271 | out = self.resnet_conv_first[i](out)
272 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
273 | out = self.resnet_conv_second[i](out)
274 | out = out + self.residual_input_conv[i](resnet_input)
275 |
276 | # Attention Block
277 | batch_size, channels, h, w = out.shape
278 | in_attn = out.reshape(batch_size, channels, h * w)
279 | in_attn = self.attention_norms[i](in_attn)
280 | in_attn = in_attn.transpose(1, 2)
281 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
282 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
283 | out = out + out_attn
284 |
285 | return out
286 |
287 |
288 | class Unet(nn.Module):
289 | r"""
290 | Unet model comprising
291 |     Down blocks, Midblocks and Upblocks
292 | """
293 | def __init__(self, model_config, use_up=True):
294 | super().__init__()
295 | im_channels = model_config['im_channels']
296 | self.down_channels = model_config['down_channels']
297 | self.mid_channels = model_config['mid_channels']
298 | self.t_emb_dim = model_config['time_emb_dim']
299 | self.down_sample = model_config['down_sample']
300 | self.num_down_layers = model_config['num_down_layers']
301 | self.num_mid_layers = model_config['num_mid_layers']
302 | self.num_up_layers = model_config['num_up_layers']
303 |
304 | assert self.mid_channels[0] == self.down_channels[-1]
305 | assert self.mid_channels[-1] == self.down_channels[-2]
306 | assert len(self.down_sample) == len(self.down_channels) - 1
307 |
308 | # Initial projection from sinusoidal time embedding
309 | self.t_proj = nn.Sequential(
310 | nn.Linear(self.t_emb_dim, self.t_emb_dim),
311 | nn.SiLU(),
312 | nn.Linear(self.t_emb_dim, self.t_emb_dim)
313 | )
314 |
315 | self.up_sample = list(reversed(self.down_sample))
316 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
317 |
318 | self.downs = nn.ModuleList([])
319 | for i in range(len(self.down_channels)-1):
320 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
321 | down_sample=self.down_sample[i], num_layers=self.num_down_layers))
322 |
323 | self.mids = nn.ModuleList([])
324 | for i in range(len(self.mid_channels)-1):
325 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
326 | num_layers=self.num_mid_layers))
327 |
328 | if use_up:
329 | self.ups = nn.ModuleList([])
330 | for i in reversed(range(len(self.down_channels)-1)):
331 | self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,
332 | self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))
333 |
334 | self.norm_out = nn.GroupNorm(8, 16)
335 | self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
336 |
337 | def forward(self, x, t):
338 | # Shapes assuming downblocks are [C1, C2, C3, C4]
339 | # Shapes assuming midblocks are [C4, C4, C3]
340 | # Shapes assuming downsamples are [True, True, False]
341 |
342 | # t_emb -> B x t_emb_dim
343 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
344 | t_emb = self.t_proj(t_emb)
345 |
346 | # B x C x H x W
347 | out = self.conv_in(x)
348 | # B x C1 x H x W
349 |
350 | down_outs = []
351 |
352 | for idx, down in enumerate(self.downs):
353 | down_outs.append(out)
354 | out = down(out, t_emb)
355 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
356 | # out B x C4 x H/4 x W/4
357 |
358 | for mid in self.mids:
359 | out = mid(out, t_emb)
360 | # out B x C3 x H/4 x W/4
361 |
362 | for up in self.ups:
363 | down_out = down_outs.pop()
364 | out = up(out, down_out, t_emb)
365 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
366 | out = self.norm_out(out)
367 | out = nn.SiLU()(out)
368 | out = self.conv_out(out)
369 | # out B x C x H x W
370 | return out
371 |
--------------------------------------------------------------------------------
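The asserts in `Unet.__init__` above encode the channel contract spelled out in the forward-pass comments: the first mid channel must equal the last down channel, the last mid channel must equal the second-to-last down channel, and there is one down-sample flag per DownBlock. A hypothetical `model_config` that satisfies them for MNIST-sized inputs (the values actually used live in config/mnist.yaml and may differ):

```python
import torch
from models.unet_base import Unet

model_config = {
    'im_channels': 1,
    'im_size': 28,
    'down_channels': [32, 64, 128, 256],   # C1, C2, C3, C4
    'mid_channels': [256, 256, 128],       # starts at C4, ends at C3
    'down_sample': [True, True, False],    # one flag per DownBlock
    'time_emb_dim': 128,
    'num_down_layers': 2,
    'num_mid_layers': 2,
    'num_up_layers': 2,
}

unet = Unet(model_config)
x = torch.randn(2, 1, 28, 28)
t = torch.randint(0, 1000, (2,))
print(unet(x, t).shape)                     # torch.Size([2, 1, 28, 28])
```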
/models/unet_cond_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import einsum
3 | import torch.nn as nn
4 | from models.blocks import get_time_embedding
5 | from models.blocks import DownBlock, MidBlock, UpBlockUnet
6 | from utils.config_utils import *
7 |
8 |
9 | class Unet(nn.Module):
10 | r"""
11 | Unet model comprising
12 |     Down blocks, Midblocks and Upblocks
13 | """
14 |
15 | def __init__(self, im_channels, model_config, use_up=True):
16 | super().__init__()
17 | self.down_channels = model_config['down_channels']
18 | self.mid_channels = model_config['mid_channels']
19 | self.t_emb_dim = model_config['time_emb_dim']
20 | self.down_sample = model_config['down_sample']
21 | self.num_down_layers = model_config['num_down_layers']
22 | self.num_mid_layers = model_config['num_mid_layers']
23 | self.num_up_layers = model_config['num_up_layers']
24 | self.attns = model_config['attn_down']
25 | self.norm_channels = model_config['norm_channels']
26 | self.num_heads = model_config['num_heads']
27 | self.conv_out_channels = model_config['conv_out_channels']
28 |
29 | # Validating Unet Model configurations
30 | assert self.mid_channels[0] == self.down_channels[-1]
31 | assert self.mid_channels[-1] == self.down_channels[-2]
32 | assert len(self.down_sample) == len(self.down_channels) - 1
33 | assert len(self.attns) == len(self.down_channels) - 1
34 |
35 | ######## Class, Mask and Text Conditioning Config #####
36 | self.class_cond = False
37 | self.text_cond = False
38 | self.image_cond = False
39 | self.text_embed_dim = None
40 | self.condition_config = get_config_value(model_config, 'condition_config', None)
41 | if self.condition_config is not None:
42 | assert 'condition_types' in self.condition_config, 'Condition Type not provided in model config'
43 | condition_types = self.condition_config['condition_types']
44 | if 'class' in condition_types:
45 | validate_class_config(self.condition_config)
46 | self.class_cond = True
47 | self.num_classes = self.condition_config['class_condition_config']['num_classes']
48 | if 'text' in condition_types:
49 | validate_text_config(self.condition_config)
50 | self.text_cond = True
51 | self.text_embed_dim = self.condition_config['text_condition_config']['text_embed_dim']
52 | if 'image' in condition_types:
53 | self.image_cond = True
54 | self.im_cond_input_ch = self.condition_config['image_condition_config'][
55 | 'image_condition_input_channels']
56 | self.im_cond_output_ch = self.condition_config['image_condition_config'][
57 | 'image_condition_output_channels']
58 | if self.class_cond:
59 |             # Rather than using a special null class, we don't add the
60 | # class embedding information for unconditional generation
61 | self.class_emb = nn.Embedding(self.num_classes,
62 | self.t_emb_dim)
63 |
64 | if self.image_cond:
65 | # Map the mask image to a N channel image and
66 | # concat that with input across channel dimension
67 | self.cond_conv_in = nn.Conv2d(in_channels=self.im_cond_input_ch,
68 | out_channels=self.im_cond_output_ch,
69 | kernel_size=1,
70 | bias=False)
71 | self.conv_in_concat = nn.Conv2d(im_channels + self.im_cond_output_ch,
72 | self.down_channels[0], kernel_size=3, padding=1)
73 | else:
74 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)
75 | self.cond = self.text_cond or self.image_cond or self.class_cond
76 | ###################################
77 |
78 | # Initial projection from sinusoidal time embedding
79 | self.t_proj = nn.Sequential(
80 | nn.Linear(self.t_emb_dim, self.t_emb_dim),
81 | nn.SiLU(),
82 | nn.Linear(self.t_emb_dim, self.t_emb_dim)
83 | )
84 |
85 | self.up_sample = list(reversed(self.down_sample))
86 | self.downs = nn.ModuleList([])
87 |
88 | # Build the Downblocks
89 | for i in range(len(self.down_channels) - 1):
90 | # Cross Attention and Context Dim only needed if text condition is present
91 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
92 | down_sample=self.down_sample[i],
93 | num_heads=self.num_heads,
94 | num_layers=self.num_down_layers,
95 | attn=self.attns[i], norm_channels=self.norm_channels,
96 | cross_attn=self.text_cond,
97 | context_dim=self.text_embed_dim))
98 |
99 | self.mids = nn.ModuleList([])
100 | # Build the Midblocks
101 | for i in range(len(self.mid_channels) - 1):
102 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
103 | num_heads=self.num_heads,
104 | num_layers=self.num_mid_layers,
105 | norm_channels=self.norm_channels,
106 | cross_attn=self.text_cond,
107 | context_dim=self.text_embed_dim))
108 |
109 | self.ups = nn.ModuleList([])
110 | if use_up:
111 | # Build the Upblocks
112 | for i in reversed(range(len(self.down_channels) - 1)):
113 | self.ups.append(
114 | UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
115 | self.t_emb_dim, up_sample=self.down_sample[i],
116 | num_heads=self.num_heads,
117 | num_layers=self.num_up_layers,
118 | norm_channels=self.norm_channels,
119 | cross_attn=self.text_cond,
120 | context_dim=self.text_embed_dim))
121 |
122 | self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
123 | self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)
124 |
125 | def forward(self, x, t, cond_input=None):
126 | # Shapes assuming downblocks are [C1, C2, C3, C4]
127 | # Shapes assuming midblocks are [C4, C4, C3]
128 | # Shapes assuming downsamples are [True, True, False]
129 | if self.cond:
130 | assert cond_input is not None, \
131 | "Model initialized with conditioning so cond_input cannot be None"
132 | if self.image_cond:
133 | ######## Mask Conditioning ########
134 | validate_image_conditional_input(cond_input, x)
135 | im_cond = cond_input['image']
136 | im_cond = torch.nn.functional.interpolate(im_cond, size=x.shape[-2:])
137 | im_cond = self.cond_conv_in(im_cond)
138 | assert im_cond.shape[-2:] == x.shape[-2:]
139 | x = torch.cat([x, im_cond], dim=1)
140 | # B x (C+N) x H x W
141 | out = self.conv_in_concat(x)
142 | #####################################
143 | else:
144 | # B x C x H x W
145 | out = self.conv_in(x)
146 | # B x C1 x H x W
147 |
148 | # t_emb -> B x t_emb_dim
149 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
150 | t_emb = self.t_proj(t_emb)
151 |
152 | ######## Class Conditioning ########
153 | if self.class_cond:
154 | validate_class_conditional_input(cond_input, x, self.num_classes)
155 | class_embed = einsum(cond_input['class'].float(), self.class_emb.weight, 'b n, n d -> b d')
156 | t_emb += class_embed
157 | ####################################
158 |
159 | context_hidden_states = None
160 | if self.text_cond:
161 | assert 'text' in cond_input, \
162 | "Model initialized with text conditioning but cond_input has no text information"
163 | context_hidden_states = cond_input['text']
164 | down_outs = []
165 |
166 | for idx, down in enumerate(self.downs):
167 | down_outs.append(out)
168 | out = down(out, t_emb, context_hidden_states)
169 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
170 | # out B x C4 x H/4 x W/4
171 |
172 | for mid in self.mids:
173 | out = mid(out, t_emb, context_hidden_states)
174 | # out B x C3 x H/4 x W/4
175 |
176 | for up in self.ups:
177 | down_out = down_outs.pop()
178 | out = up(out, down_out, t_emb, context_hidden_states)
179 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
180 | out = self.norm_out(out)
181 | out = nn.SiLU()(out)
182 | out = self.conv_out(out)
183 | # out B x C x H x W
184 | return out
185 |
--------------------------------------------------------------------------------
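For class conditioning, `cond_input['class']` is expected to be a one-hot B x num_classes float tensor, and the `einsum('b n, n d -> b d')` mixes the embedding table into one vector per sample, which is then added to the time embedding. A small standalone sketch of just that step (the tensor names here are illustrative, not the repo's):

```python
import torch
import torch.nn as nn
from einops import einsum

batch, num_classes, t_emb_dim = 4, 10, 128
class_emb = nn.Embedding(num_classes, t_emb_dim)

labels = torch.randint(0, num_classes, (batch,))
one_hot = nn.functional.one_hot(labels, num_classes).float()   # B x num_classes

# Equivalent to picking each label's embedding row, expressed with the
# same einsum pattern used in Unet.forward above
class_embed = einsum(one_hot, class_emb.weight, 'b n, n d -> b d')
print(class_embed.shape)                                        # torch.Size([4, 128])
```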
/models/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.blocks import DownBlock, MidBlock, UpBlock
4 |
5 |
6 | class VAE(nn.Module):
7 | def __init__(self, im_channels, model_config):
8 | super().__init__()
9 | self.down_channels = model_config['down_channels']
10 | self.mid_channels = model_config['mid_channels']
11 | self.down_sample = model_config['down_sample']
12 | self.num_down_layers = model_config['num_down_layers']
13 | self.num_mid_layers = model_config['num_mid_layers']
14 | self.num_up_layers = model_config['num_up_layers']
15 |
16 | # To disable attention in Downblock of Encoder and Upblock of Decoder
17 | self.attns = model_config['attn_down']
18 |
19 | # Latent Dimension
20 | self.z_channels = model_config['z_channels']
21 | self.norm_channels = model_config['norm_channels']
22 | self.num_heads = model_config['num_heads']
23 |
24 | # Assertion to validate the channel information
25 | assert self.mid_channels[0] == self.down_channels[-1]
26 | assert self.mid_channels[-1] == self.down_channels[-1]
27 | assert len(self.down_sample) == len(self.down_channels) - 1
28 | assert len(self.attns) == len(self.down_channels) - 1
29 |
30 | # Wherever we use downsampling in encoder correspondingly use
31 | # upsampling in decoder
32 | self.up_sample = list(reversed(self.down_sample))
33 |
34 | ##################### Encoder ######################
35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
36 |
37 | # Downblock + Midblock
38 | self.encoder_layers = nn.ModuleList([])
39 | for i in range(len(self.down_channels) - 1):
40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
41 | t_emb_dim=None, down_sample=self.down_sample[i],
42 | num_heads=self.num_heads,
43 | num_layers=self.num_down_layers,
44 | attn=self.attns[i],
45 | norm_channels=self.norm_channels))
46 |
47 | self.encoder_mids = nn.ModuleList([])
48 | for i in range(len(self.mid_channels) - 1):
49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
50 | t_emb_dim=None,
51 | num_heads=self.num_heads,
52 | num_layers=self.num_mid_layers,
53 | norm_channels=self.norm_channels))
54 |
55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1)
57 |
58 |         # Output is 2*z_channels because we are predicting both mean and log variance
59 | self.pre_quant_conv = nn.Conv2d(2 * self.z_channels, 2 * self.z_channels, kernel_size=1)
60 | ####################################################
61 |
62 | ##################### Decoder ######################
63 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
64 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
65 |
66 | # Midblock + Upblock
67 | self.decoder_mids = nn.ModuleList([])
68 | for i in reversed(range(1, len(self.mid_channels))):
69 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
70 | t_emb_dim=None,
71 | num_heads=self.num_heads,
72 | num_layers=self.num_mid_layers,
73 | norm_channels=self.norm_channels))
74 |
75 | self.decoder_layers = nn.ModuleList([])
76 | for i in reversed(range(1, len(self.down_channels))):
77 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
78 | t_emb_dim=None, up_sample=self.down_sample[i - 1],
79 | num_heads=self.num_heads,
80 | num_layers=self.num_up_layers,
81 | attn=self.attns[i - 1],
82 | norm_channels=self.norm_channels))
83 |
84 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
85 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
86 |
87 | def encode(self, x):
88 | out = self.encoder_conv_in(x)
89 | for idx, down in enumerate(self.encoder_layers):
90 | out = down(out)
91 | for mid in self.encoder_mids:
92 | out = mid(out)
93 | out = self.encoder_norm_out(out)
94 | out = nn.SiLU()(out)
95 | out = self.encoder_conv_out(out)
96 | out = self.pre_quant_conv(out)
97 | mean, logvar = torch.chunk(out, 2, dim=1)
98 | std = torch.exp(0.5 * logvar)
99 | sample = mean + std * torch.randn(mean.shape).to(device=x.device)
100 | return sample, out
101 |
102 | def decode(self, z):
103 | out = z
104 | out = self.post_quant_conv(out)
105 | out = self.decoder_conv_in(out)
106 | for mid in self.decoder_mids:
107 | out = mid(out)
108 | for idx, up in enumerate(self.decoder_layers):
109 | out = up(out)
110 |
111 | out = self.decoder_norm_out(out)
112 | out = nn.SiLU()(out)
113 | out = self.decoder_conv_out(out)
114 | return out
115 |
116 | def forward(self, x):
117 | z, encoder_output = self.encode(x)
118 | out = self.decode(z)
119 | return out, encoder_output
120 |
121 |
--------------------------------------------------------------------------------
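`encode` splits the pre-quant output into mean and log-variance along the channel dimension and samples a latent with the reparameterization trick; `forward` returns both the reconstruction and the raw encoder output so the training script can build a KL term from it. A minimal numeric sketch of that step, using the CelebHQ shapes described in the README (128x128 images, 32x32 latents) and an assumed z_channels of 4 (the real value is set in config/celebhq.yaml):

```python
import torch

B, z_channels = 2, 4
encoder_out = torch.randn(B, 2 * z_channels, 32, 32)   # pre_quant_conv output

mean, logvar = torch.chunk(encoder_out, 2, dim=1)      # each (B, 4, 32, 32)
std = torch.exp(0.5 * logvar)
sample = mean + std * torch.randn_like(mean)           # reparameterization trick

# Standard KL term against a unit Gaussian prior; the exact form and
# weighting used in tools/train_vae.py may differ.
kl = 0.5 * torch.mean(mean ** 2 + logvar.exp() - 1.0 - logvar)
```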
/models/weights/v0.1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/models/weights/v0.1/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.8.0
2 | numpy==2.0.1
3 | opencv_python==4.10.0.84
4 | Pillow==10.4.0
5 | PyYAML==6.0.1
6 | torch==2.3.1
7 | torchvision==0.18.1
8 | tqdm==4.66.4
9 |
--------------------------------------------------------------------------------
/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/scheduler/__init__.py
--------------------------------------------------------------------------------
/scheduler/linear_noise_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LinearNoiseScheduler:
5 | r"""
6 | Class for the linear noise scheduler that is used in DDPM.
7 | """
8 | def __init__(self, num_timesteps, beta_start, beta_end, ldm_scheduler=False):
9 | self.num_timesteps = num_timesteps
10 | self.beta_start = beta_start
11 | self.beta_end = beta_end
12 |
13 | if ldm_scheduler:
14 | # Mimicking how compvis repo creates schedule
15 | self.betas = (
16 | torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
17 | )
18 | else:
19 | self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
20 | self.alphas = 1. - self.betas
21 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
22 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
23 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
24 |
25 | def add_noise(self, original, noise, t):
26 | r"""
27 | Forward method for diffusion
28 | :param original: Image on which noise is to be applied
29 | :param noise: Random Noise Tensor (from normal dist)
30 | :param t: timestep of the forward process of shape -> (B,)
31 | :return:
32 | """
33 | original_shape = original.shape
34 | batch_size = original_shape[0]
35 |
36 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
37 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
38 |
39 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
40 | for _ in range(len(original_shape) - 1):
41 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
42 | for _ in range(len(original_shape) - 1):
43 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
44 |
45 | # Apply and Return Forward process equation
46 | return (sqrt_alpha_cum_prod.to(original.device) * original
47 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
48 |
49 | def sample_prev_timestep(self, xt, noise_pred, t):
50 | r"""
51 | Use the noise prediction by model to get
52 | xt-1 using xt and the noise predicted
53 | :param xt: current timestep sample
54 | :param noise_pred: model noise prediction
55 | :param t: current timestep we are at
56 | :return:
57 | """
58 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
59 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
60 | x0 = torch.clamp(x0, -1., 1.)
61 |
62 | mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
63 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
64 |
65 | if t == 0:
66 | return mean, x0
67 | else:
68 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
69 | variance = variance * self.betas.to(xt.device)[t]
70 | sigma = variance ** 0.5
71 | z = torch.randn(xt.shape).to(xt.device)
72 |
73 | # OR
74 | # variance = self.betas[t]
75 | # sigma = variance ** 0.5
76 | # z = torch.randn(xt.shape).to(xt.device)
77 | return mean + sigma * z, x0
78 |
--------------------------------------------------------------------------------
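`add_noise` implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and `sample_prev_timestep` inverts a single step from the predicted noise. A quick sketch exercising both on dummy data; the beta range here is a common DDPM choice and an assumption, not necessarily what the configs use:

```python
import torch
from scheduler.linear_noise_scheduler import LinearNoiseScheduler

scheduler = LinearNoiseScheduler(num_timesteps=1000,
                                 beta_start=1e-4, beta_end=0.02)

x0 = torch.randn(4, 1, 28, 28)             # clean images
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (4,))           # one timestep per sample

xt = scheduler.add_noise(x0, noise, t)     # noisy images, same shape as x0

# One reverse step, using the true noise as a stand-in for a model prediction
prev_sample, x0_pred = scheduler.sample_prev_timestep(xt[:1], noise[:1],
                                                      torch.tensor(500))
```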
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/tools/__init__.py
--------------------------------------------------------------------------------
/tools/infer_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import pickle
5 |
6 | import torch
7 | import torchvision
8 | import yaml
9 | from torch.utils.data.dataloader import DataLoader
10 | from torchvision.utils import make_grid
11 | from tqdm import tqdm
12 |
13 | from dataset.celeb_dataset import CelebDataset
14 | from dataset.mnist_dataset import MnistDataset
15 | from models.vae import VAE
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 | if torch.backends.mps.is_available():
19 | device = torch.device('mps')
20 | print('Using mps')
21 |
22 |
23 | def infer(args):
24 | ######## Read the config file #######
25 | with open(args.config_path, 'r') as file:
26 | try:
27 | config = yaml.safe_load(file)
28 | except yaml.YAMLError as exc:
29 | print(exc)
30 | print(config)
31 |
32 | dataset_config = config['dataset_params']
33 | autoencoder_config = config['autoencoder_params']
34 | train_config = config['train_params']
35 |
36 | # Create the dataset
37 | im_dataset_cls = {
38 | 'mnist': MnistDataset,
39 | 'celebhq': CelebDataset,
40 | }.get(dataset_config['name'])
41 |
42 | im_dataset = im_dataset_cls(split='train',
43 | im_path=dataset_config['im_path'],
44 | im_size=dataset_config['im_size'],
45 | im_channels=dataset_config['im_channels'])
46 |
47 |     # This is only used for saving latents, which as of now
48 |     # is not done in batches, hence batch size 1
49 | data_loader = DataLoader(im_dataset,
50 | batch_size=1,
51 | shuffle=False)
52 |
53 | num_images = train_config['num_samples']
54 | ngrid = train_config['num_grid_rows']
55 |
56 | idxs = torch.randint(0, len(im_dataset) - 1, (num_images,))
57 | ims = torch.cat([im_dataset[idx][None, :] for idx in idxs]).float()
58 | ims = ims.to(device)
59 |
60 | model = VAE(im_channels=dataset_config['im_channels'],
61 | model_config=autoencoder_config).to(device)
62 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
63 | train_config['vae_autoencoder_ckpt_name']),
64 | map_location=device))
65 | model.eval()
66 |
67 | with torch.no_grad():
68 |
69 | encoded_output, _ = model.encode(ims)
70 | decoded_output = model.decode(encoded_output)
71 | encoded_output = torch.clamp(encoded_output, -1., 1.)
72 | encoded_output = (encoded_output + 1) / 2
73 | decoded_output = torch.clamp(decoded_output, -1., 1.)
74 | decoded_output = (decoded_output + 1) / 2
75 | ims = (ims + 1) / 2
76 |
77 | encoder_grid = make_grid(encoded_output.cpu(), nrow=ngrid)
78 | decoder_grid = make_grid(decoded_output.cpu(), nrow=ngrid)
79 | input_grid = make_grid(ims.cpu(), nrow=ngrid)
80 | encoder_grid = torchvision.transforms.ToPILImage()(encoder_grid)
81 | decoder_grid = torchvision.transforms.ToPILImage()(decoder_grid)
82 | input_grid = torchvision.transforms.ToPILImage()(input_grid)
83 |
84 | input_grid.save(os.path.join(train_config['task_name'], 'input_samples.png'))
85 | encoder_grid.save(os.path.join(train_config['task_name'], 'encoded_samples.png'))
86 | decoder_grid.save(os.path.join(train_config['task_name'], 'reconstructed_samples.png'))
87 |
88 | if train_config['save_latents']:
89 | # save Latents (but in a very unoptimized way)
90 | latent_path = os.path.join(train_config['task_name'], train_config['vae_latent_dir_name'])
91 | latent_fnames = glob.glob(os.path.join(train_config['task_name'], train_config['vae_latent_dir_name'],
92 | '*.pkl'))
93 | assert len(latent_fnames) == 0, 'Latents already present. Delete all latent files and re-run'
94 | if not os.path.exists(latent_path):
95 | os.mkdir(latent_path)
96 | print('Saving Latents for {}'.format(dataset_config['name']))
97 |
98 | fname_latent_map = {}
99 | part_count = 0
100 | count = 0
101 | for idx, im in enumerate(tqdm(data_loader)):
102 | _, encoded_output = model.encode(im.float().to(device))
103 | fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu()
104 | # Save latents every 1000 images
105 | if (count + 1) % 1000 == 0:
106 | pickle.dump(fname_latent_map, open(os.path.join(latent_path,
107 | '{}.pkl'.format(part_count)), 'wb'))
108 | part_count += 1
109 | fname_latent_map = {}
110 | count += 1
111 | if len(fname_latent_map) > 0:
112 | pickle.dump(fname_latent_map, open(os.path.join(latent_path,
113 | '{}.pkl'.format(part_count)), 'wb'))
114 | print('Done saving latents')
115 |
116 |
117 | if __name__ == '__main__':
118 | parser = argparse.ArgumentParser(description='Arguments for vae inference')
119 | parser.add_argument('--config', dest='config_path',
120 | default='config/celebhq.yaml', type=str)
121 | args = parser.parse_args()
122 | infer(args)
123 |
--------------------------------------------------------------------------------
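The latents above are pickled in parts of 1000 entries, each part a dict mapping an image filename to its encoder output. A hedged sketch of how they could be read back into a single dict (the actual consumption happens inside the dataset classes and may differ); the directory name is hypothetical and would come from task_name / vae_latent_dir_name in the config:

```python
import glob
import os
import pickle

latent_dir = 'celebhq/vae_latents'   # hypothetical task_name / vae_latent_dir_name
fname_latent_map = {}
for pkl_path in glob.glob(os.path.join(latent_dir, '*.pkl')):
    with open(pkl_path, 'rb') as f:
        fname_latent_map.update(pickle.load(f))   # filename -> latent tensor

print('Loaded {} latents'.format(len(fname_latent_map)))
```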
/tools/sample_ddpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | from torchvision.utils import make_grid
7 | from tqdm import tqdm
8 | from models.unet_base import Unet
9 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
10 |
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 | if torch.backends.mps.is_available():
14 | device = torch.device('mps')
15 | print('Using mps')
16 |
17 |
18 | def sample(model, scheduler, train_config, model_config, diffusion_config):
19 | r"""
20 | Sample stepwise by going backward one timestep at a time.
21 | We save the x0 predictions
22 | """
23 | xt = torch.randn((train_config['num_samples'],
24 | model_config['im_channels'],
25 | model_config['im_size'],
26 | model_config['im_size'])).to(device)
27 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
28 | # Get prediction of noise
29 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
30 |
31 | # Use scheduler to get x0 and xt-1
32 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
33 |
34 | # Save x0
35 | ims = torch.clamp(xt, -1., 1.).detach().cpu()
36 | ims = (ims + 1) / 2
37 | grid = make_grid(ims, nrow=train_config['num_grid_rows'])
38 | img = torchvision.transforms.ToPILImage()(grid)
39 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
40 | os.mkdir(os.path.join(train_config['task_name'], 'samples'))
41 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
42 | img.close()
43 |
44 |
45 | def infer(args):
46 | # Read the config file #
47 | with open(args.config_path, 'r') as file:
48 | try:
49 | config = yaml.safe_load(file)
50 | except yaml.YAMLError as exc:
51 | print(exc)
52 | print(config)
53 | ########################
54 |
55 | diffusion_config = config['diffusion_params']
56 | model_config = config['model_params']
57 | train_config = config['train_params']
58 |
59 | # Load model with checkpoint
60 | model = Unet(model_config).to(device)
61 | assert os.path.exists(os.path.join(train_config['task_name'],
62 | train_config['ddpm_ckpt_name'])), "Train DDPM first"
63 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
64 | train_config['ddpm_ckpt_name']), map_location=device))
65 | model.eval()
66 |
67 | # Create the noise scheduler
68 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
69 | beta_start=diffusion_config['beta_start'],
70 | beta_end=diffusion_config['beta_end'])
71 | with torch.no_grad():
72 | sample(model, scheduler, train_config, model_config, diffusion_config)
73 |
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
77 | parser.add_argument('--config', dest='config_path',
78 | default='config/mnist.yaml', type=str)
79 | args = parser.parse_args()
80 | infer(args)
81 |
--------------------------------------------------------------------------------
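Because the script imports `models` and `scheduler` as packages, it is meant to be launched from the repository root, e.g. `python -m tools.sample_ddpm --config config/mnist.yaml` (the `--config` flag is the one defined by its argparse setup); the other sampling scripts in tools/ follow the same pattern.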
/tools/sample_ddpm_controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | import random
7 | from torchvision.utils import make_grid
8 | from tqdm import tqdm
9 | from dataset.mnist_dataset import MnistDataset
10 | from models.controlnet import ControlNet
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 |
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | if torch.backends.mps.is_available():
16 | device = torch.device('mps')
17 | print('Using mps')
18 |
19 |
20 | def sample(model, scheduler, train_config, model_config, diffusion_config, dataset):
21 | r"""
22 | Sample stepwise by going backward one timestep at a time.
23 | We save the x0 predictions
24 | """
25 | xt = torch.randn((train_config['num_samples'],
26 | model_config['im_channels'],
27 | model_config['im_size'],
28 | model_config['im_size'])).to(device)
29 |
30 | # Get random hints for the desired number of samples
31 | hints = []
32 | for idx in range(train_config['num_samples']):
33 |         hint_idx = random.randint(0, len(dataset) - 1)
34 | hints.append(dataset[hint_idx][1].unsqueeze(0).to(device))
35 | hints = torch.cat(hints, dim=0).to(device)
36 |
37 | # Save the hints
38 | hints_grid = make_grid(hints, nrow=train_config['num_grid_rows'])
39 | hints_img = torchvision.transforms.ToPILImage()(hints_grid)
40 | hints_img.save(os.path.join(train_config['task_name'], 'hint.png'))
41 |
42 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
43 | # Get prediction of noise
44 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device), hints)
45 |
46 | # Prediction from original model
47 | # noise_pred = model.trained_unet(xt, torch.as_tensor(i).unsqueeze(0).to(device))
48 |
49 | # Use scheduler to get x0 and xt-1
50 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
51 |
52 | # Save x0
53 | ims = torch.clamp(xt, -1., 1.).detach().cpu()
54 | ims = (ims + 1) / 2
55 | grid = make_grid(ims, nrow=train_config['num_grid_rows'])
56 | img = torchvision.transforms.ToPILImage()(grid)
57 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples_controlnet')):
58 | os.mkdir(os.path.join(train_config['task_name'], 'samples_controlnet'))
59 | img.save(os.path.join(train_config['task_name'], 'samples_controlnet', 'x0_{}.png'.format(i)))
60 | img.close()
61 |
62 |
63 | def infer(args):
64 | # Read the config file #
65 | with open(args.config_path, 'r') as file:
66 | try:
67 | config = yaml.safe_load(file)
68 | except yaml.YAMLError as exc:
69 | print(exc)
70 | print(config)
71 | ########################
72 |
73 | diffusion_config = config['diffusion_params']
74 | model_config = config['model_params']
75 | train_config = config['train_params']
76 | dataset_config = config['dataset_params']
77 |
78 | # Change to require hints
79 | mnist_canny = MnistDataset('test', im_path=dataset_config['im_test_path'], return_hints=True)
80 |
81 | # Load model with checkpoint
82 | model = ControlNet(model_config,
83 | model_ckpt=os.path.join(train_config['task_name'], train_config['ddpm_ckpt_name']),
84 | device=device).to(device)
85 |
86 | assert os.path.exists(os.path.join(train_config['task_name'],
87 | train_config['controlnet_ckpt_name'])), "Train ControlNet first"
88 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
89 | train_config['controlnet_ckpt_name']),
90 | map_location=device))
91 | model.eval()
92 | print('Loaded ControlNet checkpoint')
93 |
94 | # Create the noise scheduler
95 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
96 | beta_start=diffusion_config['beta_start'],
97 | beta_end=diffusion_config['beta_end'])
98 | with torch.no_grad():
99 | sample(model, scheduler, train_config, model_config, diffusion_config, mnist_canny)
100 |
101 |
102 | if __name__ == '__main__':
103 | parser = argparse.ArgumentParser(description='Arguments for controlnet ddpm image generation')
104 | parser.add_argument('--config', dest='config_path',
105 | default='config/mnist.yaml', type=str)
106 | args = parser.parse_args()
107 | infer(args)
108 |
--------------------------------------------------------------------------------
/tools/sample_ldm_controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | import random
7 | from torchvision.utils import make_grid
8 | from tqdm import tqdm
9 | from models.controlnet_ldm import ControlNet
10 | from dataset.celeb_dataset import CelebDataset
11 | from models.vae import VAE
12 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | if torch.backends.mps.is_available():
16 | device = torch.device('mps')
17 | print('Using mps')
18 |
19 |
20 | def sample(model, scheduler, train_config, diffusion_model_config,
21 | autoencoder_model_config, diffusion_config, dataset_config, vae, dataset):
22 | r"""
23 | Sample stepwise by going backward one timestep at a time.
24 | We save the x0 predictions
25 | """
26 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
27 | xt = torch.randn((train_config['num_samples'],
28 | autoencoder_model_config['z_channels'],
29 | im_size,
30 | im_size)).to(device)
31 |
32 | # Get random hints for the desired number of samples
33 | hints = []
34 |
35 | for idx in range(train_config['num_samples']):
36 |         hint_idx = random.randint(0, len(dataset) - 1)
37 | hints.append(dataset[hint_idx][1].unsqueeze(0).to(device))
38 | hints = torch.cat(hints, dim=0).to(device)
39 |
40 | # Save the hints
41 | hints_grid = make_grid(hints, nrow=train_config['num_grid_rows'])
42 | hints_img = torchvision.transforms.ToPILImage()(hints_grid)
43 | hints_img.save(os.path.join(train_config['task_name'], 'hint.png'))
44 |
45 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
46 | # Get prediction of noise
47 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device), hints)
48 |
49 | # Use scheduler to get x0 and xt-1
50 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
51 |
52 | # Save x0
53 | # ims = torch.clamp(xt, -1., 1.).detach().cpu()
54 | if i == 0:
55 | # Decode ONLY the final image to save time
56 | ims = vae.to(device).decode(xt)
57 | else:
58 | ims = xt
59 |
60 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
61 | ims = (ims + 1) / 2
62 | grid = make_grid(ims, nrow=train_config['num_grid_rows'])
63 | img = torchvision.transforms.ToPILImage()(grid)
64 |
65 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples_controlnet')):
66 | os.mkdir(os.path.join(train_config['task_name'], 'samples_controlnet'))
67 | img.save(os.path.join(train_config['task_name'], 'samples_controlnet', 'x0_{}.png'.format(i)))
68 | img.close()
69 |
70 |
71 | def infer(args):
72 | # Read the config file #
73 | with open(args.config_path, 'r') as file:
74 | try:
75 | config = yaml.safe_load(file)
76 | except yaml.YAMLError as exc:
77 | print(exc)
78 | print(config)
79 | ########################
80 |
81 | diffusion_config = config['diffusion_params']
82 | dataset_config = config['dataset_params']
83 | diffusion_model_config = config['ldm_params']
84 | autoencoder_model_config = config['autoencoder_params']
85 | train_config = config['train_params']
86 |
87 | # Create the noise scheduler
88 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
89 | beta_start=diffusion_config['beta_start'],
90 | beta_end=diffusion_config['beta_end'],
91 | ldm_scheduler=True)
92 |
93 | celeb_canny = CelebDataset('test',
94 | im_path=dataset_config['im_path'],
95 | im_size=dataset_config['im_size'],
96 | return_hint=True)
97 |
98 | # Instantiate the model
99 | latent_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
100 | downscale_factor = dataset_config['canny_im_size'] // latent_size
101 | model = ControlNet(im_channels=autoencoder_model_config['z_channels'],
102 | model_config=diffusion_model_config,
103 | model_locked=True,
104 | model_ckpt=os.path.join(train_config['task_name'], train_config['ldm_ckpt_name']),
105 | device=device,
106 | down_sample_factor=downscale_factor).to(device)
107 | model.eval()
108 |
109 | assert os.path.exists(os.path.join(train_config['task_name'],
110 | train_config['controlnet_ckpt_name'])), "Train ControlNet first"
111 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
112 | train_config['controlnet_ckpt_name']),
113 | map_location=device))
114 | print('Loaded controlnet checkpoint')
115 |
116 | vae = VAE(im_channels=dataset_config['im_channels'],
117 | model_config=autoencoder_model_config)
118 | vae.eval()
119 |
120 | # Load vae if found
121 | assert os.path.exists(os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])), \
122 | "VAE checkpoint not present. Train VAE first."
123 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
124 | train_config['vae_autoencoder_ckpt_name']),
125 | map_location=device), strict=True)
126 | print('Loaded vae checkpoint')
127 |
128 | with torch.no_grad():
129 | sample(model, scheduler, train_config, diffusion_model_config,
130 | autoencoder_model_config, diffusion_config, dataset_config, vae, celeb_canny)
131 |
132 |
133 | if __name__ == '__main__':
134 | parser = argparse.ArgumentParser(description='Arguments for ldm controlnet generation')
135 | parser.add_argument('--config', dest='config_path',
136 | default='config/celebhq.yaml', type=str)
137 | args = parser.parse_args()
138 | infer(args)
139 |
--------------------------------------------------------------------------------
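A note on the reverse loop in the sampling scripts: `sample_prev_timestep` is defined in `scheduler/linear_noise_scheduler.py` (not shown in this section) and is expected to apply the standard DDPM posterior update. A minimal sketch, assuming 1-D tensors `betas`, `alphas = 1 - betas`, the cumulative products `alpha_bars`, and a plain integer `t`:

```python
import torch

def ddpm_prev_timestep(xt, noise_pred, t, betas, alphas, alpha_bars):
    # Recover the predicted clean latent x0 from the predicted noise
    x0_pred = (xt - torch.sqrt(1 - alpha_bars[t]) * noise_pred) / torch.sqrt(alpha_bars[t])
    x0_pred = torch.clamp(x0_pred, -1., 1.)
    # Posterior mean of q(x_{t-1} | x_t, x0)
    mean = (xt - betas[t] * noise_pred / torch.sqrt(1 - alpha_bars[t])) / torch.sqrt(alphas[t])
    if t == 0:
        return mean, x0_pred
    variance = (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * betas[t]
    return mean + torch.sqrt(variance) * torch.randn_like(xt), x0_pred
```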
/tools/sample_ldm_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | from torchvision.utils import make_grid
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from models.unet_cond_base import Unet
10 | from models.vae import VAE
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 |
19 | def sample(model, scheduler, train_config, diffusion_model_config,
20 | autoencoder_model_config, diffusion_config, dataset_config, vae):
21 | r"""
22 | Sample stepwise by going backward one timestep at a time.
23 | We save the x0 predictions
24 | """
25 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
26 | xt = torch.randn((train_config['num_samples'],
27 | autoencoder_model_config['z_channels'],
28 | im_size,
29 | im_size)).to(device)
30 |
31 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
32 | # Get prediction of noise
33 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
34 |
35 | # Use scheduler to get x0 and xt-1
36 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
37 |
38 | # Save x0
39 | # ims = torch.clamp(xt, -1., 1.).detach().cpu()
40 | if i == 0:
41 | # Decode ONLY the final image to save time
42 | ims = vae.to(device).decode(xt)
43 | else:
44 | ims = xt
45 |
46 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
47 | ims = (ims + 1) / 2
48 |
49 | grid = make_grid(ims, nrow=train_config['num_grid_rows'])
50 | img = torchvision.transforms.ToPILImage()(grid)
51 |
52 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
53 | os.mkdir(os.path.join(train_config['task_name'], 'samples'))
54 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
55 | img.close()
56 |
57 |
58 | def infer(args):
59 | # Read the config file #
60 | with open(args.config_path, 'r') as file:
61 | try:
62 | config = yaml.safe_load(file)
63 | except yaml.YAMLError as exc:
64 | print(exc)
65 | print(config)
66 | ########################
67 |
68 | diffusion_config = config['diffusion_params']
69 | dataset_config = config['dataset_params']
70 | diffusion_model_config = config['ldm_params']
71 | autoencoder_model_config = config['autoencoder_params']
72 | train_config = config['train_params']
73 |
74 | # Create the noise scheduler
75 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
76 | beta_start=diffusion_config['beta_start'],
77 | beta_end=diffusion_config['beta_end'],
78 | ldm_scheduler=True)
79 |
80 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
81 | model_config=diffusion_model_config).to(device)
82 | model.eval()
83 |
84 | assert os.path.exists(os.path.join(train_config['task_name'],
85 | train_config['ldm_ckpt_name'])), "Train LDM first"
86 |
87 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
88 | train_config['ldm_ckpt_name']),
89 | map_location=device))
90 | print('Loaded unet checkpoint')
91 |
92 | # Create output directories
93 | if not os.path.exists(train_config['task_name']):
94 | os.mkdir(train_config['task_name'])
95 |
96 | vae = VAE(im_channels=dataset_config['im_channels'],
97 | model_config=autoencoder_model_config)
98 | vae.eval()
99 |
100 | # Load vae if found
101 | assert os.path.exists(os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])), \
102 | "VAE checkpoint not present. Train VAE first."
103 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
104 | train_config['vae_autoencoder_ckpt_name']),
105 | map_location=device), strict=True)
106 | print('Loaded vae checkpoint')
107 |
108 | with torch.no_grad():
109 | sample(model, scheduler, train_config, diffusion_model_config,
110 | autoencoder_model_config, diffusion_config, dataset_config, vae)
111 |
112 |
113 | if __name__ == '__main__':
114 | parser = argparse.ArgumentParser(description='Arguments for ldm image generation')
115 | parser.add_argument('--config', dest='config_path',
116 | default='config/celebhq.yaml', type=str)
117 | args = parser.parse_args()
118 | infer(args)
119 |
--------------------------------------------------------------------------------
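The latent spatial size used in `sample()` above follows directly from the autoencoder's downsampling configuration: every `True` entry in `down_sample` halves the resolution. A quick illustration with hypothetical config values:

```python
# Hypothetical values purely for illustration
im_size = 128                       # dataset_params['im_size']
down_sample = [True, True, False]   # autoencoder_params['down_sample']

latent_size = im_size // 2 ** sum(down_sample)  # 128 // 2**2 = 32
```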
/tools/train_ddpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.optim import Adam
8 | from dataset.mnist_dataset import MnistDataset
9 | from torch.utils.data import DataLoader
10 | from models.unet_base import Unet
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 |
19 | def train(args):
20 | # Read the config file #
21 | with open(args.config_path, 'r') as file:
22 | try:
23 | config = yaml.safe_load(file)
24 | except yaml.YAMLError as exc:
25 | print(exc)
26 | print(config)
27 | ########################
28 |
29 | diffusion_config = config['diffusion_params']
30 | dataset_config = config['dataset_params']
31 | model_config = config['model_params']
32 | train_config = config['train_params']
33 |
34 | # Create the noise scheduler
35 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
36 | beta_start=diffusion_config['beta_start'],
37 | beta_end=diffusion_config['beta_end'])
38 |
39 | # Create the dataset
40 | mnist = MnistDataset('train', im_path=dataset_config['im_path'])
41 | mnist_loader = DataLoader(mnist,
42 | batch_size=train_config['batch_size'],
43 | shuffle=True,
44 | num_workers=4)
45 |
46 | # Instantiate the model
47 | model = Unet(model_config).to(device)
48 | model.train()
49 |
50 | # Create output directories
51 | if not os.path.exists(train_config['task_name']):
52 | os.mkdir(train_config['task_name'])
53 |
54 | # Load checkpoint if found
55 |     if os.path.exists(os.path.join(train_config['task_name'], train_config['ddpm_ckpt_name'])):
56 |         print('Found existing checkpoint, loading it')
57 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
58 | train_config['ddpm_ckpt_name']), map_location=device))
59 | # Specify training parameters
60 | num_epochs = train_config['num_epochs']
61 | optimizer = Adam(model.parameters(), lr=train_config['ddpm_lr'])
62 | criterion = torch.nn.MSELoss()
63 |
64 | # Run training
65 | for epoch_idx in range(num_epochs):
66 | losses = []
67 | for im in tqdm(mnist_loader):
68 | optimizer.zero_grad()
69 | im = im.float().to(device)
70 |
71 | # Sample random noise
72 | noise = torch.randn_like(im).to(device)
73 |
74 | # Sample timestep
75 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
76 |
77 | # Add noise to images according to timestep
78 | noisy_im = scheduler.add_noise(im, noise, t)
79 | noise_pred = model(noisy_im, t)
80 |
81 | loss = criterion(noise_pred, noise)
82 | losses.append(loss.item())
83 | loss.backward()
84 | optimizer.step()
85 | print('Finished epoch:{} | Loss : {:.4f}'.format(
86 | epoch_idx + 1,
87 | np.mean(losses),
88 | ))
89 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
90 | train_config['ddpm_ckpt_name']))
91 |
92 | print('Done Training ...')
93 |
94 |
95 | if __name__ == '__main__':
96 | parser = argparse.ArgumentParser(description='Arguments for ddpm training')
97 | parser.add_argument('--config', dest='config_path',
98 | default='config/mnist.yaml', type=str)
99 | args = parser.parse_args()
100 | train(args)
101 |
--------------------------------------------------------------------------------
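The training loop above relies on `scheduler.add_noise` to produce the noisy image in a single step. A minimal sketch of the standard DDPM closed form it presumably implements, `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`, assuming `alpha_bars` is a 1-D tensor of cumulative products and `t` a batch of timestep indices:

```python
import torch

def add_noise(x0, noise, t, alpha_bars):
    # Broadcast the per-sample coefficients over (C, H, W)
    sqrt_ab = torch.sqrt(alpha_bars[t])[:, None, None, None]
    sqrt_one_minus_ab = torch.sqrt(1 - alpha_bars[t])[:, None, None, None]
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise
```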
/tools/train_ddpm_controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.optim import Adam
8 | from dataset.mnist_dataset import MnistDataset
9 | from torch.utils.data import DataLoader
10 | from models.controlnet import ControlNet
11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 |
19 | def train(args):
20 | # Read the config file #
21 | with open(args.config_path, 'r') as file:
22 | try:
23 | config = yaml.safe_load(file)
24 | except yaml.YAMLError as exc:
25 | print(exc)
26 | print(config)
27 | ########################
28 |
29 | diffusion_config = config['diffusion_params']
30 | dataset_config = config['dataset_params']
31 | model_config = config['model_params']
32 | train_config = config['train_params']
33 |
34 | # Create the noise scheduler
35 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
36 | beta_start=diffusion_config['beta_start'],
37 | beta_end=diffusion_config['beta_end'])
38 |
39 | # Create the dataset
40 | mnist = MnistDataset('train',
41 | im_path=dataset_config['im_path'],
42 | return_hints=True)
43 | mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True)
44 |
45 | # Load model with checkpoint
46 | model = ControlNet(model_config,
47 | model_locked=True,
48 | model_ckpt=os.path.join(train_config['task_name'],
49 | train_config['ddpm_ckpt_name']),
50 | device=device).to(device)
51 | model.train()
52 |
53 | # Create output directories
54 | if not os.path.exists(train_config['task_name']):
55 | os.mkdir(train_config['task_name'])
56 |
57 | # Load checkpoint if found
58 | if os.path.exists(os.path.join(train_config['task_name'],
59 | train_config['controlnet_ckpt_name'])):
60 |         print('Found existing checkpoint, loading it')
61 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
62 | train_config['controlnet_ckpt_name']),
63 | map_location=device))
64 |
65 | # Specify training parameters
66 | num_epochs = train_config['controlnet_epochs']
67 | optimizer = Adam(model.get_params(), lr=train_config['controlnet_lr'])
68 | criterion = torch.nn.MSELoss()
69 |
70 | # Run training
71 | steps = 0
72 | for epoch_idx in range(num_epochs):
73 | losses = []
74 | for im, hint in tqdm(mnist_loader):
75 | optimizer.zero_grad()
76 |
77 | im = im.float().to(device)
78 | hint = hint.float().to(device)
79 |
80 | # Sample random noise
81 | noise = torch.randn_like(im).to(device)
82 |
83 | # Sample timestep
84 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
85 |
86 | # Add noise to images according to timestep
87 | noisy_im = scheduler.add_noise(im, noise, t)
88 |
89 | # Additionally start passing the hint
90 | noise_pred = model(noisy_im, t, hint)
91 |
92 | loss = criterion(noise_pred, noise)
93 | losses.append(loss.item())
94 | loss.backward()
95 | optimizer.step()
96 | steps += 1
97 | print('Finished epoch:{} | Loss : {:.4f}'.format(
98 | epoch_idx + 1,
99 | np.mean(losses),
100 | ))
101 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
102 | train_config['controlnet_ckpt_name']))
103 |
104 | print('Done Training ...')
105 |
106 |
107 | if __name__ == '__main__':
108 | parser = argparse.ArgumentParser(description='Arguments for controlnet ddpm training')
109 | parser.add_argument('--config', dest='config_path',
110 | default='config/mnist.yaml', type=str)
111 | args = parser.parse_args()
112 | train(args)
113 |
--------------------------------------------------------------------------------
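In the script above the pretrained DDPM weights stay frozen (`model_locked=True`) and only the parameters returned by `model.get_params()` are optimized. In the original ControlNet design the trainable branch is attached through zero-initialized convolutions, so the hint pathway contributes nothing at the start of training and the locked model's behaviour is preserved. A minimal sketch of such a layer (the repository's actual version lives in `models/controlnet.py`, not shown here):

```python
import torch.nn as nn

def make_zero_conv(in_channels, out_channels):
    # 1x1 convolution with weights and bias set to zero: its output is exactly zero
    # until training moves the weights, leaving the locked base model untouched initially.
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv
```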
/tools/train_ldm_controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.optim import Adam
8 | from dataset.celeb_dataset import CelebDataset
9 | from torch.utils.data import DataLoader
10 | from models.controlnet_ldm import ControlNet
11 | from models.vae import VAE
12 | from torch.optim.lr_scheduler import MultiStepLR
13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | if torch.backends.mps.is_available():
17 | device = torch.device('mps')
18 | print('Using mps')
19 |
20 |
21 | def train(args):
22 | # Read the config file #
23 | with open(args.config_path, 'r') as file:
24 | try:
25 | config = yaml.safe_load(file)
26 | except yaml.YAMLError as exc:
27 | print(exc)
28 | print(config)
29 | ########################
30 |
31 | diffusion_config = config['diffusion_params']
32 | dataset_config = config['dataset_params']
33 | diffusion_model_config = config['ldm_params']
34 | autoencoder_model_config = config['autoencoder_params']
35 | train_config = config['train_params']
36 |
37 | # Create the noise scheduler
38 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
39 | beta_start=diffusion_config['beta_start'],
40 | beta_end=diffusion_config['beta_end'],
41 | ldm_scheduler=True)
42 |
43 | im_dataset = CelebDataset(split='train',
44 | im_path=dataset_config['im_path'],
45 | im_size=dataset_config['im_size'],
46 | im_channels=dataset_config['im_channels'],
47 | use_latents=True,
48 | latent_path=os.path.join(train_config['task_name'],
49 | train_config['vae_latent_dir_name']),
50 | return_hint=True
51 | )
52 |
53 | data_loader = DataLoader(im_dataset,
54 | batch_size=train_config['ldm_batch_size'],
55 | shuffle=True)
56 |
57 | # Instantiate the model
58 | # downscale factor = canny_image_size // latent_size
59 | latent_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])
60 | downscale_factor = dataset_config['canny_im_size'] // latent_size
61 | model = ControlNet(im_channels=autoencoder_model_config['z_channels'],
62 | model_config=diffusion_model_config,
63 | model_locked=True,
64 | model_ckpt=os.path.join(train_config['task_name'], train_config['ldm_ckpt_name']),
65 | device=device,
66 | down_sample_factor=downscale_factor).to(device)
67 | model.train()
68 | # Create output directories
69 | if not os.path.exists(train_config['task_name']):
70 | os.mkdir(train_config['task_name'])
71 |
72 | # Load checkpoint if found
73 | if os.path.exists(os.path.join(train_config['task_name'], train_config['controlnet_ckpt_name'])):
74 |         print('Found existing checkpoint, loading it')
75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
76 | train_config['controlnet_ckpt_name']),
77 | map_location=device))
78 |
79 | # Load VAE ONLY if latents are not to be used or are missing
80 | if not im_dataset.use_latents:
81 |         print('Loading vae model as latents are not present')
82 | vae = VAE(im_channels=dataset_config['im_channels'],
83 | model_config=autoencoder_model_config).to(device)
84 | vae.eval()
85 | # Load vae if found
86 | if os.path.exists(os.path.join(train_config['task_name'],
87 | train_config['vae_autoencoder_ckpt_name'])):
88 |             print('Found vae checkpoint, loading it')
89 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
90 | train_config['vae_autoencoder_ckpt_name']),
91 | map_location=device))
92 | # Specify training parameters
93 | num_epochs = train_config['ldm_epochs']
94 | optimizer = Adam(model.get_params(), lr=train_config['controlnet_lr'])
95 | lr_scheduler = MultiStepLR(optimizer, milestones=train_config['controlnet_lr_steps'], gamma=0.1)
96 | criterion = torch.nn.MSELoss()
97 |
98 | # Run training
99 | if not im_dataset.use_latents:
100 | for param in vae.parameters():
101 | param.requires_grad = False
102 |
103 | step_count = 0
104 | count = 0
105 | for epoch_idx in range(num_epochs):
106 | losses = []
107 | for im, hint in tqdm(data_loader):
108 | optimizer.zero_grad()
109 | im = im.float().to(device)
110 | if im_dataset.use_latents:
111 | mean, logvar = torch.chunk(im, 2, dim=1)
112 | std = torch.exp(0.5 * logvar)
113 | im = mean + std * torch.randn(mean.shape).to(device=im.device)
114 | else:
115 | with torch.no_grad():
116 | im, _ = vae.encode(im)
117 |
118 | hint = hint.float().to(device)
119 | # Sample random noise
120 | noise = torch.randn_like(im).to(device)
121 |
122 | # Sample timestep
123 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
124 |
125 | # Add noise to images according to timestep
126 | noisy_im = scheduler.add_noise(im, noise, t)
127 | noise_pred = model(noisy_im, t, hint)
128 |
129 | loss = criterion(noise_pred, noise)
130 | losses.append(loss.item())
131 | loss.backward()
132 | optimizer.step()
133 | step_count += 1
134 | print('Finished epoch:{} | Loss : {:.4f}'.format(
135 | epoch_idx + 1,
136 | np.mean(losses)))
137 | lr_scheduler.step()
138 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
139 | train_config['controlnet_ckpt_name']))
140 |
141 | print('Done Training ...')
142 |
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser(description='Arguments for ldm controlnet training')
146 | parser.add_argument('--config', dest='config_path',
147 | default='config/celebhq.yaml', type=str)
148 | args = parser.parse_args()
149 | train(args)
150 |
--------------------------------------------------------------------------------
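Since the canny hint is a pixel-space image while the diffusion model operates on latents, the hint has to be downscaled by `canny_im_size // latent_size` before injection, which is what `down_sample_factor` encodes above. With hypothetical config values:

```python
# Hypothetical values purely for illustration
canny_im_size = 128   # dataset_params['canny_im_size']
latent_size = 32      # im_size // 2 ** sum(down_sample)

downscale_factor = canny_im_size // latent_size  # 128 // 32 = 4
```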
/tools/train_ldm_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.optim import Adam
8 | from dataset.celeb_dataset import CelebDataset
9 | from torch.utils.data import DataLoader
10 | from models.unet_cond_base import Unet
11 | from models.vae import VAE
12 | from torch.optim.lr_scheduler import MultiStepLR
13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | if torch.backends.mps.is_available():
17 | device = torch.device('mps')
18 | print('Using mps')
19 |
20 |
21 | def train(args):
22 | # Read the config file #
23 | with open(args.config_path, 'r') as file:
24 | try:
25 | config = yaml.safe_load(file)
26 | except yaml.YAMLError as exc:
27 | print(exc)
28 | print(config)
29 | ########################
30 |
31 | diffusion_config = config['diffusion_params']
32 | dataset_config = config['dataset_params']
33 | diffusion_model_config = config['ldm_params']
34 | autoencoder_model_config = config['autoencoder_params']
35 | train_config = config['train_params']
36 |
37 | # Create the noise scheduler
38 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
39 | beta_start=diffusion_config['beta_start'],
40 | beta_end=diffusion_config['beta_end'],
41 | ldm_scheduler=True)
42 |
43 | im_dataset = CelebDataset(split='train',
44 | im_path=dataset_config['im_path'],
45 | im_size=dataset_config['im_size'],
46 | im_channels=dataset_config['im_channels'],
47 | use_latents=True,
48 | latent_path=os.path.join(train_config['task_name'],
49 | train_config['vae_latent_dir_name'])
50 | )
51 |
52 | data_loader = DataLoader(im_dataset,
53 | batch_size=train_config['ldm_batch_size'],
54 | shuffle=True)
55 |
56 | # Instantiate the model
57 | model = Unet(im_channels=autoencoder_model_config['z_channels'],
58 | model_config=diffusion_model_config).to(device)
59 | model.train()
60 |
61 | if os.path.exists(os.path.join(train_config['task_name'],
62 | train_config['ldm_ckpt_name'])):
63 |         print('Found unet checkpoint, loading it')
64 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
65 | train_config['ldm_ckpt_name']),
66 | map_location=device))
67 |
68 | # Load VAE ONLY if latents are not to be used or are missing
69 | if not im_dataset.use_latents:
70 |         print('Loading vae model as latents are not present')
71 | vae = VAE(im_channels=dataset_config['im_channels'],
72 | model_config=autoencoder_model_config).to(device)
73 | vae.eval()
74 | # Load vae if found
75 | if os.path.exists(os.path.join(train_config['task_name'],
76 | train_config['vae_autoencoder_ckpt_name'])):
77 |             print('Found vae checkpoint, loading it')
78 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'],
79 | train_config['vae_autoencoder_ckpt_name']),
80 | map_location=device))
81 | # Specify training parameters
82 | num_epochs = train_config['ldm_epochs']
83 | optimizer = Adam(model.parameters(), lr=train_config['ldm_lr'])
84 | lr_scheduler = MultiStepLR(optimizer, milestones=train_config['ldm_lr_steps'], gamma=0.5)
85 | criterion = torch.nn.MSELoss()
86 |
87 | # Run training
88 | if not im_dataset.use_latents:
89 | for param in vae.parameters():
90 | param.requires_grad = False
91 |
92 | step_count = 0
93 | for epoch_idx in range(num_epochs):
94 | losses = []
95 | for im in tqdm(data_loader):
96 | optimizer.zero_grad()
97 | im = im.float().to(device)
98 | if im_dataset.use_latents:
99 | mean, logvar = torch.chunk(im, 2, dim=1)
100 | std = torch.exp(0.5 * logvar)
101 | im = mean + std * torch.randn(mean.shape).to(device=im.device)
102 | else:
103 | with torch.no_grad():
104 | im, _ = vae.encode(im)
105 |
106 | # Sample random noise
107 | noise = torch.randn_like(im).to(device)
108 |
109 | # Sample timestep
110 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
111 |
112 | # Add noise to images according to timestep
113 | noisy_im = scheduler.add_noise(im, noise, t)
114 | noise_pred = model(noisy_im, t)
115 |
116 | loss = criterion(noise_pred, noise)
117 | losses.append(loss.item())
118 | loss.backward()
119 | optimizer.step()
120 | step_count += 1
121 | print('Finished epoch:{} | Loss : {:.4f}'.format(
122 | epoch_idx + 1,
123 | np.mean(losses)))
124 | lr_scheduler.step()
125 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
126 | train_config['ldm_ckpt_name']))
127 |
128 | print('Done Training ...')
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser(description='Arguments for ldm training')
133 | parser.add_argument('--config', dest='config_path',
134 | default='config/celebhq.yaml', type=str)
135 | args = parser.parse_args()
136 | train(args)
137 |
--------------------------------------------------------------------------------
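When cached latents are used, each stored tensor holds the encoder's mean and log-variance concatenated along the channel dimension, and the loop above resamples a latent per step via the reparameterization trick. The same step, isolated as a small helper for clarity:

```python
import torch

def sample_latent(stored):
    # stored: (B, 2 * z_channels, H, W) with mean and log-variance stacked channel-wise
    mean, logvar = torch.chunk(stored, 2, dim=1)
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(mean)
```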
/tools/train_vae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import random
5 | import torchvision
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | from models.vae import VAE
10 | from models.lpips import LPIPS
11 | from models.discriminator import Discriminator
12 | from torch.utils.data.dataloader import DataLoader
13 | from dataset.celeb_dataset import CelebDataset
14 | from torch.optim import Adam
15 | from torchvision.utils import make_grid
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 | if torch.backends.mps.is_available():
19 | device = torch.device('mps')
20 | print('Using mps')
21 |
22 |
23 | def train(args):
24 | # Read the config file #
25 | with open(args.config_path, 'r') as file:
26 | try:
27 | config = yaml.safe_load(file)
28 | except yaml.YAMLError as exc:
29 | print(exc)
30 | print(config)
31 |
32 | dataset_config = config['dataset_params']
33 | autoencoder_config = config['autoencoder_params']
34 | train_config = config['train_params']
35 |
36 | # Set the desired seed value #
37 | seed = train_config['seed']
38 | torch.manual_seed(seed)
39 | np.random.seed(seed)
40 | random.seed(seed)
41 |     if device.type == 'cuda':
42 | torch.cuda.manual_seed_all(seed)
43 | #############################
44 |
45 | # Create the model and dataset #
46 | model = VAE(im_channels=dataset_config['im_channels'],
47 | model_config=autoencoder_config).to(device)
48 | # Create the dataset
49 | im_dataset = CelebDataset(split='train',
50 | im_path=dataset_config['im_path'],
51 | im_size=dataset_config['im_size'],
52 | im_channels=dataset_config['im_channels'])
53 |
54 | data_loader = DataLoader(im_dataset,
55 | batch_size=train_config['autoencoder_batch_size'],
56 | shuffle=True)
57 |
58 | # Create output directories
59 | if not os.path.exists(train_config['task_name']):
60 | os.mkdir(train_config['task_name'])
61 |
62 | num_epochs = train_config['autoencoder_epochs']
63 |
64 | # L1/L2 loss for Reconstruction
65 | recon_criterion = torch.nn.MSELoss()
66 | # Disc Loss can even be BCEWithLogits
67 | disc_criterion = torch.nn.MSELoss()
68 |
69 | # No need to freeze lpips as lpips.py takes care of that
70 | lpips_model = LPIPS().eval().to(device)
71 | discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device)
72 |
73 | if os.path.exists(os.path.join(train_config['task_name'],
74 | train_config['vae_autoencoder_ckpt_name'])):
75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
76 | train_config['vae_autoencoder_ckpt_name']),
77 | map_location=device))
78 | print('Loaded autoencoder from checkpoint')
79 |
80 | if os.path.exists(os.path.join(train_config['task_name'],
81 | train_config['vae_discriminator_ckpt_name'])):
82 | discriminator.load_state_dict(torch.load(os.path.join(train_config['task_name'],
83 | train_config['vae_discriminator_ckpt_name']),
84 | map_location=device))
85 | print('Loaded discriminator from checkpoint')
86 |
87 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
88 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
89 |
90 | disc_step_start = train_config['disc_start']
91 | step_count = 0
92 |
93 |     # This is for accumulating gradients in case the images are huge
94 |     # and one can't afford higher batch sizes
95 | acc_steps = train_config['autoencoder_acc_steps']
96 | image_save_steps = train_config['autoencoder_img_save_steps']
97 | img_save_count = 0
98 |
99 | for epoch_idx in range(num_epochs):
100 | recon_losses = []
101 | perceptual_losses = []
102 | disc_losses = []
103 | gen_losses = []
104 | losses = []
105 |
106 | optimizer_g.zero_grad()
107 | optimizer_d.zero_grad()
108 |
109 | for im in tqdm(data_loader):
110 | step_count += 1
111 | im = im.float().to(device)
112 |
113 |             # Fetch the autoencoder's output (reconstructions)
114 | model_output = model(im)
115 | output, encoder_output = model_output
116 |
117 | # Image Saving Logic
118 | if step_count % image_save_steps == 0 or step_count == 1:
119 | sample_size = min(8, im.shape[0])
120 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
121 | save_output = ((save_output + 1) / 2)
122 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
123 |
124 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
125 | img = torchvision.transforms.ToPILImage()(grid)
126 | if not os.path.exists(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')):
127 | os.mkdir(os.path.join(train_config['task_name'], 'vae_autoencoder_samples'))
128 | img.save(os.path.join(train_config['task_name'], 'vae_autoencoder_samples',
129 | 'current_autoencoder_sample_{}.png'.format(img_save_count)))
130 | img_save_count += 1
131 | img.close()
132 |
133 | ######### Optimize Generator ##########
134 | # L2 Loss
135 | recon_loss = recon_criterion(output, im)
136 | recon_losses.append(recon_loss.item())
137 | recon_loss = recon_loss / acc_steps
138 |
139 | mean, logvar = torch.chunk(encoder_output, 2, dim=1)
140 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3]))
141 |
142 | g_loss = recon_loss + (train_config['kl_weight'] * kl_loss / acc_steps)
143 |
144 | # Adversarial loss only if disc_step_start steps passed
145 | if step_count > disc_step_start:
146 | disc_fake_pred = discriminator(model_output[0])
147 | disc_fake_loss = disc_criterion(disc_fake_pred,
148 | torch.ones(disc_fake_pred.shape,
149 | device=disc_fake_pred.device))
150 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item())
151 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps
152 | lpips_loss = torch.mean(lpips_model(output, im))
153 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item())
154 | g_loss += train_config['perceptual_weight'] * lpips_loss / acc_steps
155 | losses.append(g_loss.item())
156 | g_loss.backward()
157 | #####################################
158 |
159 | ######### Optimize Discriminator #######
160 | if step_count > disc_step_start:
161 | fake = output
162 | disc_fake_pred = discriminator(fake.detach())
163 | disc_real_pred = discriminator(im)
164 | disc_fake_loss = disc_criterion(disc_fake_pred,
165 | torch.zeros(disc_fake_pred.shape,
166 | device=disc_fake_pred.device))
167 | disc_real_loss = disc_criterion(disc_real_pred,
168 | torch.ones(disc_real_pred.shape,
169 | device=disc_real_pred.device))
170 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2
171 | disc_losses.append(disc_loss.item())
172 | disc_loss = disc_loss / acc_steps
173 | disc_loss.backward()
174 | if step_count % acc_steps == 0:
175 | optimizer_d.step()
176 | optimizer_d.zero_grad()
177 | #####################################
178 |
179 | if step_count % acc_steps == 0:
180 | optimizer_g.step()
181 | optimizer_g.zero_grad()
182 | optimizer_d.step()
183 | optimizer_d.zero_grad()
184 | optimizer_g.step()
185 | optimizer_g.zero_grad()
186 | if len(disc_losses) > 0:
187 | print(
188 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
189 | 'G Loss : {:.4f} | D Loss {:.4f}'.
190 | format(epoch_idx + 1,
191 | np.mean(recon_losses),
192 | np.mean(perceptual_losses),
193 | np.mean(gen_losses),
194 | np.mean(disc_losses)))
195 | else:
196 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}'.
197 | format(epoch_idx + 1,
198 | np.mean(recon_losses),
199 | np.mean(perceptual_losses)))
200 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
201 | train_config['vae_autoencoder_ckpt_name']))
202 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'],
203 | train_config['vae_discriminator_ckpt_name']))
204 | print('Done Training...')
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser(description='Arguments for vae training')
209 | parser.add_argument('--config', dest='config_path',
210 | default='config/celebhq.yaml', type=str)
211 | args = parser.parse_args()
212 | train(args)
213 |
--------------------------------------------------------------------------------
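For reference, the `kl_loss` term in the VAE training loop above is the closed-form KL divergence between the encoder's diagonal Gaussian posterior and a standard normal prior (with sigma_i^2 = exp(logvar_i)), summed over latent dimensions and averaged over the batch:

```latex
D_{KL}\big(\mathcal{N}(\mu, \sigma^{2})\,\|\,\mathcal{N}(0, I)\big)
  = \frac{1}{2} \sum_{i} \left( \sigma_i^{2} + \mu_i^{2} - 1 - \log \sigma_i^{2} \right)
```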
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | def validate_class_config(condition_config):
2 | assert 'class_condition_config' in condition_config, \
3 | "Class conditioning desired but class condition config missing"
4 | assert 'num_classes' in condition_config['class_condition_config'], \
5 | "num_class missing in class condition config"
6 |
7 |
8 | def validate_text_config(condition_config):
9 | assert 'text_condition_config' in condition_config, \
10 | "Text conditioning desired but text condition config missing"
11 | assert 'text_embed_dim' in condition_config['text_condition_config'], \
12 | "text_embed_dim missing in text condition config"
13 |
14 |
15 | def validate_image_config(condition_config):
16 | assert 'image_condition_config' in condition_config, \
17 | "Image conditioning desired but image condition config missing"
18 | assert 'image_condition_input_channels' in condition_config['image_condition_config'], \
19 | "image_condition_input_channels missing in image condition config"
20 | assert 'image_condition_output_channels' in condition_config['image_condition_config'], \
21 | "image_condition_output_channels missing in image condition config"
22 |
23 |
24 | def validate_image_conditional_input(cond_input, x):
25 | assert 'image' in cond_input, \
26 | "Model initialized with image conditioning but cond_input has no image information"
27 | assert cond_input['image'].shape[0] == x.shape[0], \
28 | "Batch size mismatch of image condition and input"
29 | assert cond_input['image'].shape[2] % x.shape[2] == 0, \
30 | "Height/Width of image condition must be divisible by latent input"
31 |
32 |
33 | def validate_class_conditional_input(cond_input, x, num_classes):
34 | assert 'class' in cond_input, \
35 | "Model initialized with class conditioning but cond_input has no class information"
36 | assert cond_input['class'].shape == (x.shape[0], num_classes), \
37 | "Shape of class condition input must match (Batch Size, )"
38 |
39 |
40 | def get_config_value(config, key, default_value):
41 | return config[key] if key in config else default_value
--------------------------------------------------------------------------------
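`get_config_value` is a thin helper for reading optional config keys with a fallback. A short usage sketch with hypothetical keys and values:

```python
from utils.config_utils import get_config_value

condition_config = {'condition_types': ['image']}

cond_types = get_config_value(condition_config, 'condition_types', [])  # -> ['image']
drop_prob = get_config_value(condition_config, 'cond_drop_prob', 0.0)   # -> 0.0 (default)
```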
/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import glob
3 | import os
4 | import torch
5 |
6 |
7 | def load_latents(latent_path):
8 | r"""
9 |     Simple utility to load latents that were saved to speed up ldm training
10 |     :param latent_path: directory containing the pickled latent files
11 |     :return: dict mapping image keys to their latent tensors
12 | """
13 | latent_maps = {}
14 | for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
15 | s = pickle.load(open(fname, 'rb'))
16 | for k, v in s.items():
17 | latent_maps[k] = v[0]
18 | return latent_maps
19 |
20 |
21 | def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob):
22 | if text_drop_prob > 0:
23 | text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0,
24 | 1) < text_drop_prob
25 | assert empty_text_embed is not None, ("Text Conditioning required as well as"
26 | " text dropping but empty text representation not created")
27 | text_embed[text_drop_mask, :, :] = empty_text_embed[0]
28 | return text_embed
29 |
30 |
31 | def drop_image_condition(image_condition, im, im_drop_prob):
32 | if im_drop_prob > 0:
33 | im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0,
34 | 1) > im_drop_prob
35 | return image_condition * im_drop_mask
36 | else:
37 | return image_condition
38 |
39 |
40 | def drop_class_condition(class_condition, class_drop_prob, im):
41 | if class_drop_prob > 0:
42 | class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0,
43 | 1) > class_drop_prob
44 | return class_condition * class_drop_mask
45 | else:
46 | return class_condition
47 |
--------------------------------------------------------------------------------
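The `drop_*` helpers above randomly zero out conditioning inputs for a fraction of the batch during training, the usual recipe for enabling classifier-free guidance at sampling time. A hypothetical call inside a conditional training step (shapes and drop probability chosen purely for illustration):

```python
import torch
from utils.diffusion_utils import drop_image_condition

im = torch.randn(4, 3, 28, 28)          # batch of training images (hypothetical shapes)
cond_image = torch.randn(4, 1, 28, 28)  # canny hints for the same batch

# Zero the hint for roughly 10% of the samples so the model also learns
# an unconditional prediction.
cond_image = drop_image_condition(cond_image, im, im_drop_prob=0.1)
```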