├── .gitignore ├── LICENSE ├── README.md ├── config ├── __init__.py ├── celebhq.yaml └── mnist.yaml ├── data └── .gitkeep ├── dataset ├── __init__.py ├── celeb_dataset.py └── mnist_dataset.py ├── models ├── __init__.py ├── blocks.py ├── controlnet.py ├── controlnet_ldm.py ├── discriminator.py ├── lpips.py ├── unet_base.py ├── unet_cond_base.py ├── vae.py └── weights │ └── v0.1 │ └── .gitkeep ├── requirements.txt ├── scheduler ├── __init__.py └── linear_noise_scheduler.py ├── tools ├── __init__.py ├── infer_vae.py ├── sample_ddpm.py ├── sample_ddpm_controlnet.py ├── sample_ldm_controlnet.py ├── sample_ldm_vae.py ├── train_ddpm.py ├── train_ddpm_controlnet.py ├── train_ldm_controlnet.py ├── train_ldm_vae.py └── train_vae.py └── utils ├── __init__.py ├── config_utils.py └── diffusion_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all image files 2 | *.jpg 3 | *.png 4 | *.jpeg 5 | 6 | # Ignore pycharm and system files 7 | .DS_Store 8 | *.idea 9 | __pycache__ 10 | *.zip 11 | 12 | # Ignore dataset files 13 | *.csv 14 | *.json 15 | 16 | # Ignore checkpoints 17 | *.pth 18 | 19 | # Ignore pickle files 20 | *.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ExplainingAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ControlNet Implementation in PyTorch 2 | ======== 3 | ## ControlNet Tutorial Video 4 | 5 | ControlNet Tutorial 7 | 8 | 9 | ## Sample Output for ControlNet with DDPM on MNIST and with LDM on CelebHQ 10 | Canny Edge Control - Top, Sample - Below 11 | 12 | 13 | 14 | ___ 15 | 16 | This repository implements ControlNet in PyTorch for diffusion models. 
17 | As of now, the repo provides code to do the following: 18 | * Training and inference of an unconditional DDPM on the MNIST dataset 19 | * Training and inference of ControlNet with DDPM on MNIST using canny edges 20 | * Training and inference of an unconditional Latent Diffusion Model on the CelebHQ dataset (resized to 128x128, with latent images being 32x32) 21 | * Training and inference of ControlNet with the unconditional Latent Diffusion Model on CelebHQ using canny edges 22 | 23 | 24 | For the autoencoder of the Latent Diffusion Model, training and inference code for a VAE is provided. 25 | 26 | ## Setup 27 | * Create a new conda environment with Python 3.10, then run the commands below 28 | * `conda activate <environment_name>` 29 | * ```git clone https://github.com/explainingai-code/ControlNet-PyTorch.git``` 30 | * ```cd ControlNet-PyTorch``` 31 | * ```pip install -r requirements.txt``` 32 | * Download the LPIPS weights by opening this link in a browser (don't use cURL or wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```models/weights/v0.1/vgg.pth``` 33 | ___ 34 | 35 | ## Data Preparation 36 | ### MNIST 37 | 38 | For setting up the MNIST dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation 39 | 40 | Ensure the directory structure is the following 41 | ``` 42 | ControlNet-PyTorch 43 | -> data 44 | -> mnist 45 | -> train 46 | -> images 47 | -> *.png 48 | -> test 49 | -> images 50 | -> *.png 51 | ``` 52 | 53 | ### CelebHQ 54 | For setting up CelebHQ, simply download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file). 55 | 56 | Ensure the directory structure is the following 57 | ``` 58 | ControlNet-PyTorch 59 | -> data 60 | -> CelebAMask-HQ 61 | -> CelebA-HQ-img 62 | -> *.jpg 63 | 64 | ``` 65 | --- 66 | ## Configuration 67 | The config files let you play with the different components of DDPM and autoencoder training 68 | * ```config/mnist.yaml``` - Config for the MNIST dataset 69 | * ```config/celebhq.yaml``` - Config used for the CelebHQ dataset 70 | 71 | Relevant configuration parameters 72 | 73 | Most parameters are self-explanatory, but below are a couple that are specific to this repo. 74 | * ```autoencoder_acc_steps``` : For accumulating gradients when the image size is too large to allow larger batch sizes 75 | * ```save_latents``` : Enable this to save the latents during inference of the autoencoder, so that DDPM training is faster 76 | 77 | ___ 78 | ## Training 79 | The repo provides training and inference for MNIST (unconditional DDPM) and CelebHQ (unconditional LDM), and for ControlNet with both of these variants using canny edges. 80 | 81 | For working on your own dataset: 82 | * Create your own config and have the path in the config point to your images (look at `celebhq.yaml` for guidance) 83 | * Create your own dataset class which just collects all the filenames and returns the image and its hint in its `__getitem__` method; a minimal sketch is shown below.
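A minimal sketch of such a dataset class follows. It mirrors the hint computation in `mnist_dataset.py`/`celeb_dataset.py` (canny edges via OpenCV, replicated to 3 channels); the class name, the `im_ext` glob pattern and the directory layout are illustrative and should be adapted to your data:

```python
import glob
import os

import cv2
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset


class CustomControlNetDataset(Dataset):
    """Collects image file paths and returns (image, canny hint) pairs."""

    def __init__(self, split, im_path, im_ext='png', return_hints=False):
        self.split = split
        self.return_hints = return_hints
        # Only paths are stored; images are read lazily in __getitem__
        self.images = glob.glob(os.path.join(im_path, '**', '*.{}'.format(im_ext)),
                                recursive=True)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.ToTensor()(im)
        # Scale to the [-1, 1] range expected by the diffusion model
        im_tensor = (2 * im_tensor) - 1
        if not self.return_hints:
            return im_tensor
        # Canny edge hint, replicated to 3 channels and kept in [0, 1]
        canny = cv2.Canny(np.array(Image.open(self.images[index])), 100, 200)
        canny = np.repeat(canny[:, :, None], 3, axis=2)
        return im_tensor, torchvision.transforms.ToTensor()(canny)
```

The ControlNet training scripts expect `return_hints=True` so that `__getitem__` yields `(image, hint)` pairs.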
Look at `mnist_dataset.py` or `celeb_dataset.py` for full examples 84 | 85 | Once the config and dataset are set up: 86 | * For training and inference of the unconditional DDPM follow [this section](#training-unconditional-ddpm) 87 | * For training and inference of ControlNet with the unconditional DDPM follow [this section](#training-controlnet-for-unconditional-ddpm) 88 | * Train the autoencoder on your dataset using [this section](#training-autoencoder-for-ldm) 89 | * For training and inference of the unconditional LDM follow [this section](#training-unconditional-ldm) 90 | * For training and inference of ControlNet with the unconditional LDM follow [this section](#training-controlnet-for-unconditional-ldm) 91 | 92 | 93 | 94 | ## Training Unconditional DDPM 95 | * For training DDPM on MNIST, ensure the right path is mentioned in `mnist.yaml` 96 | * For training DDPM on your own dataset 97 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance) 98 | * Create your own dataset class, similar to celeb_dataset.py 99 | * Call the desired dataset class in the training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ddpm.py#L40) 100 | * For training, run ```python -m tools.train_ddpm --config config/mnist.yaml``` to train the DDPM with the desired config file 101 | * For inference, run ```python -m tools.sample_ddpm --config config/mnist.yaml``` to generate samples with the right config file. 102 | 103 | ## Training ControlNet for Unconditional DDPM 104 | * For training ControlNet, ensure the right path is mentioned in `mnist.yaml` 105 | * For training ControlNet with DDPM on your own dataset 106 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance) 107 | * Create your own dataset class, similar to celeb_dataset.py 108 | * Call the desired dataset class in the training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ddpm_controlnet.py#L40) 109 | * Ensure ```return_hints``` is passed as True in the dataset class initialization 110 | * For training, run ```python -m tools.train_ddpm_controlnet --config config/mnist.yaml``` to train ControlNet with DDPM using the desired config file 111 | * For inference, run ```python -m tools.sample_ddpm_controlnet --config config/mnist.yaml``` to generate DDPM samples using canny hints with the right config file. 112 | 113 | 114 | ## Training AutoEncoder for LDM 115 | * For training the autoencoder on CelebHQ, ensure the right path is mentioned in `celebhq.yaml` 116 | * For training the autoencoder on your own dataset 117 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance) 118 | * Create your own dataset class, similar to celeb_dataset.py 119 | * Call the desired dataset class in the training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_vae.py#L49) 120 | * For training, run ```python -m tools.train_vae --config config/celebhq.yaml``` to train the autoencoder with the desired config file 121 | * For inference, make sure `save_latents` is `True` in the config 122 | * For inference, run ```python -m tools.infer_vae --config config/celebhq.yaml``` to generate reconstructions and save latents with the right config file. 123 | 124 | 125 | ## Training Unconditional LDM 126 | Train the autoencoder first and set up the dataset accordingly; for orientation, the full CelebHQ command sequence is summarized below.
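As a reference, assuming the default `celebhq.yaml`, the end-to-end CelebHQ pipeline (autoencoder, then latent caching, then LDM, then ControlNet) chains the commands from the sections above and below in this order:
* ```python -m tools.train_vae --config config/celebhq.yaml```
* ```python -m tools.infer_vae --config config/celebhq.yaml``` (with `save_latents` set to `True` so that latents are cached)
* ```python -m tools.train_ldm_vae --config config/celebhq.yaml```
* ```python -m tools.sample_ldm_vae --config config/celebhq.yaml```
* ```python -m tools.train_ldm_controlnet --config config/celebhq.yaml```
* ```python -m tools.sample_ldm_controlnet --config config/celebhq.yaml```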
127 | 128 | For training the unconditional LDM, ensure the right dataset is used in `train_ldm_vae.py` [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ldm_vae.py#L43) 129 | * ```python -m tools.train_ldm_vae --config config/celebhq.yaml``` for training the unconditional LDM with the right config 130 | * ```python -m tools.sample_ldm_vae --config config/celebhq.yaml``` for generating images using the trained LDM 131 | 132 | 133 | ## Training ControlNet for Unconditional LDM 134 | * For training ControlNet on CelebHQ, ensure the right path is mentioned in `celebhq.yaml` 135 | * For training ControlNet with LDM on your own dataset 136 | * Create your own config and have the path point to images (look at celebhq.yaml for guidance) 137 | * Create your own dataset class, similar to celeb_dataset.py 138 | * Ensure the autoencoder and LDM have already been trained 139 | * Call the desired dataset class in the training file [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ldm_controlnet.py#L43) 140 | * Ensure ```return_hints``` is passed as True in the dataset class initialization 141 | * Ensure ```down_sample_factor``` is correctly computed in the model initialization [here](https://github.com/explainingai-code/ControlNet-PyTorch/blob/main/tools/train_ldm_controlnet.py#L60) 142 | * For training, run ```python -m tools.train_ldm_controlnet --config config/celebhq.yaml``` to train ControlNet with LDM using the desired config file 143 | * For inference with ControlNet, run ```python -m tools.sample_ldm_controlnet --config config/celebhq.yaml``` to generate LDM samples using canny hints with the right config file. 144 | 145 | 146 | ## Output 147 | Outputs will be saved according to the configuration present in the yaml files. 148 | 149 | For every run, a folder named after the ```task_name``` key in the config will be created 150 | 151 | During training of the autoencoder the following output will be saved 152 | * Latest autoencoder and discriminator checkpoints in the ```task_name``` directory 153 | * Sample reconstructions in ```task_name/vae_autoencoder_samples``` 154 | 155 | During inference of the autoencoder the following output will be saved 156 | * Reconstructions for random images in ```task_name``` 157 | * Latents will be saved in ```task_name/vae_latent_dir_name``` if mentioned in the config 158 | 159 | During training and inference of the unconditional DDPM or LDM the following output will be saved: 160 | * During training, the latest checkpoint in the ```task_name``` directory 161 | * During sampling, an unconditional sampled image grid for all timesteps in ```task_name/samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of the denoising process from T=999 to T=1. Generated Image is at T=0 162 | 163 | During training and inference of ControlNet with DDPM/LDM the following output will be saved: 164 | * During training, the latest checkpoint in the ```task_name``` directory 165 | * During sampling, randomly selected hints and generated samples will be saved in ```task_name/hint.png``` and ```task_name/controlnet_samples/*.png```. The final decoded generated image will be `x0_0.png`. Images from `x0_999.png` to `x0_1.png` will be latent image predictions of the denoising process from T=999 to T=1.
Generated Image is at T=0 166 | 167 | 168 | 169 | ## Citations 170 | ``` 171 | @misc{zhang2023addingconditionalcontroltexttoimage, 172 | title={Adding Conditional Control to Text-to-Image Diffusion Models}, 173 | author={Lvmin Zhang and Anyi Rao and Maneesh Agrawala}, 174 | year={2023}, 175 | eprint={2302.05543}, 176 | archivePrefix={arXiv}, 177 | primaryClass={cs.CV}, 178 | url={https://arxiv.org/abs/2302.05543}, 179 | } 180 | ``` 181 | 182 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/config/__init__.py -------------------------------------------------------------------------------- /config/celebhq.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/CelebAMask-HQ' 3 | im_channels : 3 4 | im_size : 128 5 | canny_im_size : 1024 6 | name: 'celebhq' 7 | 8 | diffusion_params: 9 | num_timesteps : 1000 10 | beta_start : 0.0015 11 | beta_end : 0.0195 12 | 13 | ldm_params: 14 | hint_channels : 3 15 | down_channels: [ 256, 384, 512, 768 ] 16 | mid_channels: [ 768, 512 ] 17 | down_sample: [ True, True, True ] 18 | attn_down : [True, True, True] 19 | time_emb_dim: 512 20 | norm_channels: 32 21 | num_heads: 16 22 | conv_out_channels : 128 23 | num_down_layers : 2 24 | num_mid_layers : 2 25 | num_up_layers : 2 26 | 27 | autoencoder_params: 28 | z_channels: 4 29 | codebook_size : 8192 30 | down_channels : [128, 256, 384] 31 | mid_channels : [384] 32 | down_sample : [True, True] 33 | attn_down : [False, False] 34 | norm_channels: 32 35 | num_heads: 4 36 | num_down_layers : 2 37 | num_mid_layers : 2 38 | num_up_layers : 2 39 | 40 | 41 | train_params: 42 | seed : 1111 43 | task_name: 'celebhq' 44 | ldm_batch_size: 16 45 | autoencoder_batch_size: 4 46 | disc_start: 7500 47 | disc_weight: 0.5 48 | codebook_weight: 1 49 | commitment_beta: 0.2 50 | perceptual_weight: 1 51 | kl_weight: 0.000005 52 | ldm_epochs: 200 53 | autoencoder_epochs: 3 54 | controlnet_epochs : 15 55 | num_samples: 2 56 | num_grid_rows: 2 57 | ldm_lr: 0.000025 58 | ldm_lr_steps : [25, 50, 75, 100] 59 | autoencoder_lr: 0.00001 60 | controlnet_lr: 0.00001 61 | controlnet_lr_steps : [10] 62 | autoencoder_acc_steps: 1 63 | autoencoder_img_save_steps: 64 64 | save_latents : True 65 | vae_latent_dir_name: 'vae_latents' 66 | vqvae_latent_dir_name: 'vqvae_latents' 67 | ldm_ckpt_name: 'ddpm_ckpt.pth' 68 | controlnet_ckpt_name: 'ddpm_controlnet_ckpt.pth' 69 | vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth' 70 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 71 | vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth' 72 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 73 | -------------------------------------------------------------------------------- /config/mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'data/mnist/train/images' 3 | im_test_path: 'data/mnist/test/images' 4 | canny_im_size: 28 5 | 6 | diffusion_params: 7 | num_timesteps : 1000 8 | beta_start : 0.0001 9 | beta_end : 0.02 10 | 11 | model_params: 12 | im_channels : 1 13 | im_size : 28 14 | hint_channels : 3 15 | down_channels : [32, 64, 128, 256] 16 | mid_channels : [256, 256, 128] 17 | down_sample : [True, True, False] 18 | time_emb_dim : 128 19 | 
num_down_layers : 2 20 | num_mid_layers : 2 21 | num_up_layers : 2 22 | num_heads : 4 23 | 24 | train_params: 25 | task_name: 'mnist' 26 | batch_size: 64 27 | num_epochs: 40 28 | controlnet_epochs : 1 29 | num_samples : 25 30 | num_grid_rows : 5 31 | ddpm_lr: 0.0001 32 | controlnet_lr: 0.0001 33 | ddpm_ckpt_name: 'ddpm_ckpt.pth' 34 | controlnet_ckpt_name: 'ddpm_controlnet_ckpt.pth' 35 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/data/.gitkeep -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/celeb_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import torchvision 5 | import numpy as np 6 | from PIL import Image 7 | from utils.diffusion_utils import load_latents 8 | from tqdm import tqdm 9 | from torch.utils.data.dataset import Dataset 10 | 11 | 12 | class CelebDataset(Dataset): 13 | r""" 14 | Celeb dataset will by default centre crop and resize the images. 15 | This can be replaced by any other dataset. As long as all the images 16 | are under one directory. 17 | """ 18 | 19 | def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg', 20 | use_latents=False, latent_path=None, return_hint=False):#, condition_config=None): 21 | self.split = split 22 | self.im_size = im_size 23 | self.im_channels = im_channels 24 | self.im_ext = im_ext 25 | self.im_path = im_path 26 | self.latent_maps = None 27 | self.use_latents = False 28 | self.return_hints = return_hint 29 | # self.condition_types = [] if condition_config is None else condition_config['condition_types'] 30 | 31 | # self.idx_to_cls_map = {} 32 | # self.cls_to_idx_map = {} 33 | 34 | # if 'image' in self.condition_types: 35 | # self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels'] 36 | # self.mask_h = condition_config['image_condition_config']['image_condition_h'] 37 | # self.mask_w = condition_config['image_condition_config']['image_condition_w'] 38 | 39 | #self.images, self.texts, self.masks = self.load_images(im_path) 40 | self.images = self.load_images(im_path) 41 | 42 | # Whether to load images or to load latents 43 | if use_latents and latent_path is not None: 44 | latent_maps = load_latents(latent_path) 45 | if len(latent_maps) == len(self.images): 46 | self.use_latents = True 47 | self.latent_maps = latent_maps 48 | print('Found {} latents'.format(len(self.latent_maps))) 49 | else: 50 | print('Latents not found') 51 | 52 | def load_images(self, im_path): 53 | r""" 54 | Gets all images from the path specified 55 | and stacks them all up 56 | """ 57 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 58 | ims = [] 59 | fnames = glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('png'))) 60 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpg'))) 61 | fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format('jpeg'))) 
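        # Only file paths are collected below; pixels are loaded lazily in __getitem__,
        # where they may be swapped for pre-computed VAE latents (latent_maps is keyed by these same paths)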
62 | texts = [] 63 | masks = [] 64 | 65 | # if 'image' in self.condition_types: 66 | # label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 67 | # 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'] 68 | # self.idx_to_cls_map = {idx: label_list[idx] for idx in range(len(label_list))} 69 | # self.cls_to_idx_map = {label_list[idx]: idx for idx in range(len(label_list))} 70 | 71 | for fname in tqdm(fnames): 72 | ims.append(fname) 73 | 74 | # if 'text' in self.condition_types: 75 | # im_name = os.path.split(fname)[1].split('.')[0] 76 | # captions_im = [] 77 | # with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f: 78 | # for line in f.readlines(): 79 | # captions_im.append(line.strip()) 80 | # texts.append(captions_im) 81 | 82 | # if 'image' in self.condition_types: 83 | # im_name = int(os.path.split(fname)[1].split('.')[0]) 84 | # masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name))) 85 | # if 'text' in self.condition_types: 86 | # assert len(texts) == len(ims), "Condition Type Text but could not find captions for all images" 87 | # if 'image' in self.condition_types: 88 | # assert len(masks) == len(ims), "Condition Type Image but could not find masks for all images" 89 | print('Found {} images'.format(len(ims))) 90 | #print('Found {} masks'.format(len(masks))) 91 | #print('Found {} captions'.format(len(texts))) 92 | return ims#, texts, masks 93 | 94 | # def get_mask(self, index): 95 | # r""" 96 | # Method to get the mask of WxH 97 | # for given index and convert it into 98 | # Classes x W x H mask image 99 | # :param index: 100 | # :return: 101 | # """ 102 | # mask_im = Image.open(self.masks[index]) 103 | # mask_im = np.array(mask_im) 104 | # im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels)) 105 | # for orig_idx in range(len(self.idx_to_cls_map)): 106 | # im_base[mask_im == (orig_idx + 1), orig_idx] = 1 107 | # mask = torch.from_numpy(im_base).permute(2, 0, 1).float() 108 | # return mask 109 | 110 | def __len__(self): 111 | return len(self.images) 112 | 113 | def __getitem__(self, index): 114 | ######## Set Conditioning Info ######## 115 | # cond_inputs = {} 116 | # if 'text' in self.condition_types: 117 | # cond_inputs['text'] = random.sample(self.texts[index], k=1)[0] 118 | # if 'image' in self.condition_types: 119 | # mask = self.get_mask(index) 120 | # cond_inputs['image'] = mask 121 | ####################################### 122 | # im = Image.open(self.images[index]) 123 | # im.save('original_image.png') 124 | # canny_image = np.array(im) 125 | # print(self.images[index]) 126 | # low_threshold = 100 127 | # high_threshold = 200 128 | # import cv2 129 | # canny_image = cv2.Canny(canny_image, low_threshold, high_threshold) 130 | # canny_image = canny_image[:, :, None] 131 | # canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) 132 | # canny_image = 255 - canny_image 133 | # canny_image = Image.fromarray(canny_image) 134 | # canny_image.save('canny_image.png') 135 | # print(list(self.latent_maps.keys())[0]) 136 | # print(self.images[index] in self.latent_maps) 137 | # print(self.images[index].replace('../', '') in self.latent_maps) 138 | # latent = self.latent_maps[self.images[index].replace('../', '')] 139 | # latent = torch.clamp(latent, -1., 1.) 
140 | # latent = (latent + 1) / 2 141 | # latent = torchvision.transforms.ToPILImage()(latent[0:-1, :, :]) 142 | # latent.save('latent_image.png') 143 | # exit(0) 144 | 145 | if self.use_latents: 146 | latent = self.latent_maps[self.images[index]] 147 | if self.return_hints: 148 | canny_image = Image.open(self.images[index]) 149 | canny_image = np.array(canny_image) 150 | canny_image = cv2.Canny(canny_image, 100, 200) 151 | canny_image = canny_image[:, :, None] 152 | canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) 153 | canny_image_tensor = torchvision.transforms.ToTensor()(canny_image) 154 | return latent, canny_image_tensor 155 | else: 156 | return latent 157 | 158 | else: 159 | im = Image.open(self.images[index]) 160 | im_tensor = torchvision.transforms.Compose([ 161 | torchvision.transforms.Resize(self.im_size), 162 | torchvision.transforms.CenterCrop(self.im_size), 163 | torchvision.transforms.ToTensor(), 164 | ])(im) 165 | im.close() 166 | 167 | # Convert input to -1 to 1 range. 168 | im_tensor = (2 * im_tensor) - 1 169 | 170 | if self.return_hints: 171 | canny_image = Image.open(self.images[index]) 172 | canny_image = np.array(canny_image) 173 | canny_image = cv2.Canny(canny_image, 100, 200) 174 | canny_image = canny_image[:, :, None] 175 | canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) 176 | canny_image_tensor = torchvision.transforms.ToTensor()(canny_image) 177 | return im_tensor, canny_image_tensor 178 | else: 179 | return im_tensor 180 | 181 | 182 | if __name__ == '__main__': 183 | mnist = CelebDataset('train', im_path='../data/CelebAMask-HQ', 184 | use_latents=True, latent_path='../celebhq/vae_latents') 185 | x = mnist[1800] 186 | -------------------------------------------------------------------------------- /dataset/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torchvision 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from torch.utils.data.dataset import Dataset 9 | 10 | 11 | class MnistDataset(Dataset): 12 | r""" 13 | Nothing special here. Just a simple dataset class for mnist images. 14 | Created a dataset class rather using torchvision to allow 15 | replacement with any other image dataset 16 | """ 17 | def __init__(self, split, im_path, im_ext='png', im_size=28, return_hints=False): 18 | r""" 19 | Init method for initializing the dataset properties 20 | :param split: train/test to locate the image files 21 | :param im_path: root folder of images 22 | :param im_ext: image extension. assumes all 23 | images would be this type. 
24 | """ 25 | self.split = split 26 | self.im_ext = im_ext 27 | self.return_hints = return_hints 28 | self.images = self.load_images(im_path) 29 | 30 | def load_images(self, im_path): 31 | r""" 32 | Gets all images from the path specified 33 | and stacks them all up 34 | :param im_path: 35 | :return: 36 | """ 37 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 38 | ims = [] 39 | labels = [] 40 | for d_name in tqdm(os.listdir(im_path)): 41 | for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))): 42 | ims.append(fname) 43 | print('Found {} images for split {}'.format(len(ims), self.split)) 44 | return ims 45 | 46 | def __len__(self): 47 | return len(self.images) 48 | 49 | def __getitem__(self, index): 50 | im = Image.open(self.images[index]) 51 | im_tensor = torchvision.transforms.ToTensor()(im) 52 | 53 | # Convert input to -1 to 1 range. 54 | im_tensor = (2 * im_tensor) - 1 55 | 56 | if self.return_hints: 57 | canny_image = Image.open(self.images[index]) 58 | canny_image = np.array(canny_image) 59 | canny_image = cv2.Canny(canny_image, 100, 200) 60 | canny_image = canny_image[:, :, None] 61 | canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) 62 | canny_image_tensor = torchvision.transforms.ToTensor()(canny_image) 63 | return im_tensor, canny_image_tensor 64 | else: 65 | return im_tensor 66 | 67 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/models/__init__.py -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_time_embedding(time_steps, temb_dim): 6 | r""" 7 | Convert time steps tensor into an embedding using the 8 | sinusoidal time embedding formula 9 | :param time_steps: 1D tensor of length batch size 10 | :param temb_dim: Dimension of the embedding 11 | :return: BxD embedding representation of B time steps 12 | """ 13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 14 | 15 | # factor = 10000^(2i/d_model) 16 | factor = 10000 ** ((torch.arange( 17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) 18 | ) 19 | 20 | # pos / factor 21 | # timesteps B -> B, 1 -> B, temb_dim 22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 24 | return t_emb 25 | 26 | 27 | class DownBlock(nn.Module): 28 | r""" 29 | Down conv block with attention. 30 | Sequence of following block 31 | 1. Resnet block with time embedding 32 | 2. Attention block 33 | 3. 
Downsample 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, t_emb_dim, 37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None): 38 | super().__init__() 39 | self.num_layers = num_layers 40 | self.down_sample = down_sample 41 | self.attn = attn 42 | self.context_dim = context_dim 43 | self.cross_attn = cross_attn 44 | self.t_emb_dim = t_emb_dim 45 | self.resnet_conv_first = nn.ModuleList( 46 | [ 47 | nn.Sequential( 48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 49 | nn.SiLU(), 50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 51 | kernel_size=3, stride=1, padding=1), 52 | ) 53 | for i in range(num_layers) 54 | ] 55 | ) 56 | if self.t_emb_dim is not None: 57 | self.t_emb_layers = nn.ModuleList([ 58 | nn.Sequential( 59 | nn.SiLU(), 60 | nn.Linear(self.t_emb_dim, out_channels) 61 | ) 62 | for _ in range(num_layers) 63 | ]) 64 | self.resnet_conv_second = nn.ModuleList( 65 | [ 66 | nn.Sequential( 67 | nn.GroupNorm(norm_channels, out_channels), 68 | nn.SiLU(), 69 | nn.Conv2d(out_channels, out_channels, 70 | kernel_size=3, stride=1, padding=1), 71 | ) 72 | for _ in range(num_layers) 73 | ] 74 | ) 75 | 76 | if self.attn: 77 | self.attention_norms = nn.ModuleList( 78 | [nn.GroupNorm(norm_channels, out_channels) 79 | for _ in range(num_layers)] 80 | ) 81 | 82 | self.attentions = nn.ModuleList( 83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 84 | for _ in range(num_layers)] 85 | ) 86 | 87 | if self.cross_attn: 88 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 89 | self.cross_attention_norms = nn.ModuleList( 90 | [nn.GroupNorm(norm_channels, out_channels) 91 | for _ in range(num_layers)] 92 | ) 93 | self.cross_attentions = nn.ModuleList( 94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 95 | for _ in range(num_layers)] 96 | ) 97 | self.context_proj = nn.ModuleList( 98 | [nn.Linear(context_dim, out_channels) 99 | for _ in range(num_layers)] 100 | ) 101 | 102 | self.residual_input_conv = nn.ModuleList( 103 | [ 104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 105 | for i in range(num_layers) 106 | ] 107 | ) 108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 109 | 4, 2, 1) if self.down_sample else nn.Identity() 110 | 111 | def forward(self, x, t_emb=None, context=None): 112 | out = x 113 | for i in range(self.num_layers): 114 | # Resnet block of Unet 115 | resnet_input = out 116 | out = self.resnet_conv_first[i](out) 117 | if self.t_emb_dim is not None: 118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 119 | out = self.resnet_conv_second[i](out) 120 | out = out + self.residual_input_conv[i](resnet_input) 121 | 122 | if self.attn: 123 | # Attention block of Unet 124 | batch_size, channels, h, w = out.shape 125 | in_attn = out.reshape(batch_size, channels, h * w) 126 | in_attn = self.attention_norms[i](in_attn) 127 | in_attn = in_attn.transpose(1, 2) 128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 130 | out = out + out_attn 131 | 132 | if self.cross_attn: 133 | assert context is not None, "context cannot be None if cross attention layers are used" 134 | batch_size, channels, h, w = out.shape 135 | in_attn = out.reshape(batch_size, channels, h * w) 136 | in_attn = self.cross_attention_norms[i](in_attn) 137 | in_attn = in_attn.transpose(1, 2) 138 | assert 
context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 139 | context_proj = self.context_proj[i](context) 140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 142 | out = out + out_attn 143 | 144 | # Downsample 145 | out = self.down_sample_conv(out) 146 | return out 147 | 148 | 149 | class MidBlock(nn.Module): 150 | r""" 151 | Mid conv block with attention. 152 | Sequence of following blocks 153 | 1. Resnet block with time embedding 154 | 2. Attention block 155 | 3. Resnet block with time embedding 156 | """ 157 | 158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, 159 | context_dim=None): 160 | super().__init__() 161 | self.num_layers = num_layers 162 | self.t_emb_dim = t_emb_dim 163 | self.context_dim = context_dim 164 | self.cross_attn = cross_attn 165 | self.resnet_conv_first = nn.ModuleList( 166 | [ 167 | nn.Sequential( 168 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 169 | nn.SiLU(), 170 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 171 | padding=1), 172 | ) 173 | for i in range(num_layers + 1) 174 | ] 175 | ) 176 | 177 | if self.t_emb_dim is not None: 178 | self.t_emb_layers = nn.ModuleList([ 179 | nn.Sequential( 180 | nn.SiLU(), 181 | nn.Linear(t_emb_dim, out_channels) 182 | ) 183 | for _ in range(num_layers + 1) 184 | ]) 185 | self.resnet_conv_second = nn.ModuleList( 186 | [ 187 | nn.Sequential( 188 | nn.GroupNorm(norm_channels, out_channels), 189 | nn.SiLU(), 190 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 191 | ) 192 | for _ in range(num_layers + 1) 193 | ] 194 | ) 195 | 196 | self.attention_norms = nn.ModuleList( 197 | [nn.GroupNorm(norm_channels, out_channels) 198 | for _ in range(num_layers)] 199 | ) 200 | 201 | self.attentions = nn.ModuleList( 202 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 203 | for _ in range(num_layers)] 204 | ) 205 | if self.cross_attn: 206 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 207 | self.cross_attention_norms = nn.ModuleList( 208 | [nn.GroupNorm(norm_channels, out_channels) 209 | for _ in range(num_layers)] 210 | ) 211 | self.cross_attentions = nn.ModuleList( 212 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 213 | for _ in range(num_layers)] 214 | ) 215 | self.context_proj = nn.ModuleList( 216 | [nn.Linear(context_dim, out_channels) 217 | for _ in range(num_layers)] 218 | ) 219 | self.residual_input_conv = nn.ModuleList( 220 | [ 221 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 222 | for i in range(num_layers + 1) 223 | ] 224 | ) 225 | 226 | def forward(self, x, t_emb=None, context=None): 227 | out = x 228 | 229 | # First resnet block 230 | resnet_input = out 231 | out = self.resnet_conv_first[0](out) 232 | if self.t_emb_dim is not None: 233 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] 234 | out = self.resnet_conv_second[0](out) 235 | out = out + self.residual_input_conv[0](resnet_input) 236 | 237 | for i in range(self.num_layers): 238 | # Attention Block 239 | batch_size, channels, h, w = out.shape 240 | in_attn = out.reshape(batch_size, channels, h * w) 241 | in_attn = self.attention_norms[i](in_attn) 242 | in_attn = in_attn.transpose(1, 2) 243 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 244 | 
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 245 | out = out + out_attn 246 | 247 | if self.cross_attn: 248 | assert context is not None, "context cannot be None if cross attention layers are used" 249 | batch_size, channels, h, w = out.shape 250 | in_attn = out.reshape(batch_size, channels, h * w) 251 | in_attn = self.cross_attention_norms[i](in_attn) 252 | in_attn = in_attn.transpose(1, 2) 253 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 254 | context_proj = self.context_proj[i](context) 255 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 256 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 257 | out = out + out_attn 258 | 259 | # Resnet Block 260 | resnet_input = out 261 | out = self.resnet_conv_first[i + 1](out) 262 | if self.t_emb_dim is not None: 263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] 264 | out = self.resnet_conv_second[i + 1](out) 265 | out = out + self.residual_input_conv[i + 1](resnet_input) 266 | 267 | return out 268 | 269 | 270 | class UpBlock(nn.Module): 271 | r""" 272 | Up conv block with attention. 273 | Sequence of following blocks 274 | 1. Upsample 275 | 1. Concatenate Down block output 276 | 2. Resnet block with time embedding 277 | 3. Attention Block 278 | """ 279 | 280 | def __init__(self, in_channels, out_channels, t_emb_dim, 281 | up_sample, num_heads, num_layers, attn, norm_channels): 282 | super().__init__() 283 | self.num_layers = num_layers 284 | self.up_sample = up_sample 285 | self.t_emb_dim = t_emb_dim 286 | self.attn = attn 287 | self.resnet_conv_first = nn.ModuleList( 288 | [ 289 | nn.Sequential( 290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 291 | nn.SiLU(), 292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 293 | padding=1), 294 | ) 295 | for i in range(num_layers) 296 | ] 297 | ) 298 | 299 | if self.t_emb_dim is not None: 300 | self.t_emb_layers = nn.ModuleList([ 301 | nn.Sequential( 302 | nn.SiLU(), 303 | nn.Linear(t_emb_dim, out_channels) 304 | ) 305 | for _ in range(num_layers) 306 | ]) 307 | 308 | self.resnet_conv_second = nn.ModuleList( 309 | [ 310 | nn.Sequential( 311 | nn.GroupNorm(norm_channels, out_channels), 312 | nn.SiLU(), 313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 314 | ) 315 | for _ in range(num_layers) 316 | ] 317 | ) 318 | if self.attn: 319 | self.attention_norms = nn.ModuleList( 320 | [ 321 | nn.GroupNorm(norm_channels, out_channels) 322 | for _ in range(num_layers) 323 | ] 324 | ) 325 | 326 | self.attentions = nn.ModuleList( 327 | [ 328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 329 | for _ in range(num_layers) 330 | ] 331 | ) 332 | 333 | self.residual_input_conv = nn.ModuleList( 334 | [ 335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 336 | for i in range(num_layers) 337 | ] 338 | ) 339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 340 | 4, 2, 1) \ 341 | if self.up_sample else nn.Identity() 342 | 343 | def forward(self, x, out_down=None, t_emb=None): 344 | # Upsample 345 | x = self.up_sample_conv(x) 346 | 347 | # Concat with Downblock output 348 | if out_down is not None: 349 | x = torch.cat([x, out_down], dim=1) 350 | 351 | out = x 352 | for i in range(self.num_layers): 353 | # Resnet Block 354 | resnet_input = out 355 | out = self.resnet_conv_first[i](out) 356 | if self.t_emb_dim is not None: 357 | 
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 358 | out = self.resnet_conv_second[i](out) 359 | out = out + self.residual_input_conv[i](resnet_input) 360 | 361 | # Self Attention 362 | if self.attn: 363 | batch_size, channels, h, w = out.shape 364 | in_attn = out.reshape(batch_size, channels, h * w) 365 | in_attn = self.attention_norms[i](in_attn) 366 | in_attn = in_attn.transpose(1, 2) 367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 369 | out = out + out_attn 370 | return out 371 | 372 | 373 | class UpBlockUnet(nn.Module): 374 | r""" 375 | Up conv block with attention. 376 | Sequence of following blocks 377 | 1. Upsample 378 | 1. Concatenate Down block output 379 | 2. Resnet block with time embedding 380 | 3. Attention Block 381 | """ 382 | 383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, 384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): 385 | super().__init__() 386 | self.num_layers = num_layers 387 | self.up_sample = up_sample 388 | self.t_emb_dim = t_emb_dim 389 | self.cross_attn = cross_attn 390 | self.context_dim = context_dim 391 | self.resnet_conv_first = nn.ModuleList( 392 | [ 393 | nn.Sequential( 394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 395 | nn.SiLU(), 396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 397 | padding=1), 398 | ) 399 | for i in range(num_layers) 400 | ] 401 | ) 402 | 403 | if self.t_emb_dim is not None: 404 | self.t_emb_layers = nn.ModuleList([ 405 | nn.Sequential( 406 | nn.SiLU(), 407 | nn.Linear(t_emb_dim, out_channels) 408 | ) 409 | for _ in range(num_layers) 410 | ]) 411 | 412 | self.resnet_conv_second = nn.ModuleList( 413 | [ 414 | nn.Sequential( 415 | nn.GroupNorm(norm_channels, out_channels), 416 | nn.SiLU(), 417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 418 | ) 419 | for _ in range(num_layers) 420 | ] 421 | ) 422 | 423 | self.attention_norms = nn.ModuleList( 424 | [ 425 | nn.GroupNorm(norm_channels, out_channels) 426 | for _ in range(num_layers) 427 | ] 428 | ) 429 | 430 | self.attentions = nn.ModuleList( 431 | [ 432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 433 | for _ in range(num_layers) 434 | ] 435 | ) 436 | 437 | if self.cross_attn: 438 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 439 | self.cross_attention_norms = nn.ModuleList( 440 | [nn.GroupNorm(norm_channels, out_channels) 441 | for _ in range(num_layers)] 442 | ) 443 | self.cross_attentions = nn.ModuleList( 444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 445 | for _ in range(num_layers)] 446 | ) 447 | self.context_proj = nn.ModuleList( 448 | [nn.Linear(context_dim, out_channels) 449 | for _ in range(num_layers)] 450 | ) 451 | self.residual_input_conv = nn.ModuleList( 452 | [ 453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 454 | for i in range(num_layers) 455 | ] 456 | ) 457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 458 | 4, 2, 1) \ 459 | if self.up_sample else nn.Identity() 460 | 461 | def forward(self, x, out_down=None, t_emb=None, context=None): 462 | x = self.up_sample_conv(x) 463 | if out_down is not None: 464 | x = torch.cat([x, out_down], dim=1) 465 | 466 | out = x 467 | for i in range(self.num_layers): 468 | # Resnet 469 | resnet_input = out 470 | 
out = self.resnet_conv_first[i](out) 471 | if self.t_emb_dim is not None: 472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 473 | out = self.resnet_conv_second[i](out) 474 | out = out + self.residual_input_conv[i](resnet_input) 475 | # Self Attention 476 | batch_size, channels, h, w = out.shape 477 | in_attn = out.reshape(batch_size, channels, h * w) 478 | in_attn = self.attention_norms[i](in_attn) 479 | in_attn = in_attn.transpose(1, 2) 480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 482 | out = out + out_attn 483 | # Cross Attention 484 | if self.cross_attn: 485 | assert context is not None, "context cannot be None if cross attention layers are used" 486 | batch_size, channels, h, w = out.shape 487 | in_attn = out.reshape(batch_size, channels, h * w) 488 | in_attn = self.cross_attention_norms[i](in_attn) 489 | in_attn = in_attn.transpose(1, 2) 490 | assert len(context.shape) == 3, \ 491 | "Context shape does not match B,_,CONTEXT_DIM" 492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \ 493 | "Context shape does not match B,_,CONTEXT_DIM" 494 | context_proj = self.context_proj[i](context) 495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 497 | out = out + out_attn 498 | 499 | return out 500 | 501 | 502 | -------------------------------------------------------------------------------- /models/controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.unet_base import Unet 4 | from models.unet_base import get_time_embedding 5 | 6 | 7 | def make_zero_module(module): 8 | for p in module.parameters(): 9 | p.detach().zero_() 10 | return module 11 | 12 | 13 | class ControlNet(nn.Module): 14 | r""" 15 | Control Net Module for DDPM 16 | """ 17 | def __init__(self, model_config, 18 | model_locked=True, 19 | model_ckpt=None, 20 | device=None): 21 | super().__init__() 22 | # Trained DDPM 23 | self.model_locked = model_locked 24 | self.trained_unet = Unet(model_config) 25 | 26 | # Load weights for the trained model 27 | if model_ckpt is not None and device is not None: 28 | print('Loading Trained Diffusion Model') 29 | self.trained_unet.load_state_dict(torch.load(model_ckpt, 30 | map_location=device), strict=True) 31 | 32 | # ControlNet Copy of Trained DDPM 33 | # use_up = False removes the upblocks(decoder layers) from DDPM Unet 34 | self.control_copy_unet = Unet(model_config, use_up=False) 35 | # Load same weights as the trained model 36 | if model_ckpt is not None and device is not None: 37 | print('Loading Control Diffusion Model') 38 | self.control_copy_unet.load_state_dict(torch.load(model_ckpt, 39 | map_location=device), strict=False) 40 | 41 | # Hint Block for ControlNet 42 | # Stack of Conv activation and zero convolution at the end 43 | self.control_copy_unet_hint_block = nn.Sequential( 44 | nn.Conv2d(model_config['hint_channels'], 45 | 64, 46 | kernel_size=3, 47 | padding=(1, 1)), 48 | nn.SiLU(), 49 | nn.Conv2d(64, 50 | 128, 51 | kernel_size=3, 52 | padding=(1, 1)), 53 | nn.SiLU(), 54 | nn.Conv2d(128, 55 | self.trained_unet.down_channels[0], 56 | kernel_size=3, 57 | padding=(1, 1)), 58 | nn.SiLU(), 59 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[0], 60 | self.trained_unet.down_channels[0], 61 | kernel_size=1, 62 | padding=0)) 63 
| ) 64 | 65 | # Zero Convolution Module for Downblocks(encoder Layers) 66 | self.control_copy_unet_down_zero_convs = nn.ModuleList([ 67 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[i], 68 | self.trained_unet.down_channels[i], 69 | kernel_size=1, 70 | padding=0)) 71 | for i in range(len(self.trained_unet.down_channels)-1) 72 | ]) 73 | 74 | # Zero Convolution Module for MidBlocks 75 | self.control_copy_unet_mid_zero_convs = nn.ModuleList([ 76 | make_zero_module(nn.Conv2d(self.trained_unet.mid_channels[i], 77 | self.trained_unet.mid_channels[i], 78 | kernel_size=1, 79 | padding=0)) 80 | for i in range(1, len(self.trained_unet.mid_channels)) 81 | ]) 82 | 83 | def get_params(self): 84 | # Add all ControlNet parameters 85 | # First is our copy of unet 86 | params = list(self.control_copy_unet.parameters()) 87 | 88 | # Add parameters of hint Blocks & Zero convolutions for down/mid blocks 89 | params += list(self.control_copy_unet_hint_block.parameters()) 90 | params += list(self.control_copy_unet_down_zero_convs.parameters()) 91 | params += list(self.control_copy_unet_mid_zero_convs.parameters()) 92 | 93 | # If we desire to not have the decoder layers locked, then add 94 | # them as well 95 | if not self.model_locked: 96 | params += list(self.trained_unet.ups.parameters()) 97 | params += list(self.trained_unet.norm_out.parameters()) 98 | params += list(self.trained_unet.conv_out.parameters()) 99 | return params 100 | 101 | def forward(self, x, t, hint): 102 | # Time embedding and timestep projection layers of trained unet 103 | trained_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(), 104 | self.trained_unet.t_emb_dim) 105 | trained_unet_t_emb = self.trained_unet.t_proj(trained_unet_t_emb) 106 | 107 | # Get all downblocks output of trained unet first 108 | trained_unet_down_outs = [] 109 | with torch.no_grad(): 110 | train_unet_out = self.trained_unet.conv_in(x) 111 | for idx, down in enumerate(self.trained_unet.downs): 112 | trained_unet_down_outs.append(train_unet_out) 113 | train_unet_out = down(train_unet_out, trained_unet_t_emb) 114 | 115 | # ControlNet Layers start here # 116 | # Time embedding and timestep projection layers of controlnet's copy of unet 117 | control_copy_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(), 118 | self.control_copy_unet.t_emb_dim) 119 | control_copy_unet_t_emb = self.control_copy_unet.t_proj(control_copy_unet_t_emb) 120 | 121 | # Hint block of controlnet's copy of unet 122 | control_copy_unet_hint_out = self.control_copy_unet_hint_block(hint) 123 | 124 | # Call conv_in layer for controlnet's copy of unet 125 | # and add hint blocks output to it 126 | control_copy_unet_out = self.control_copy_unet.conv_in(x) 127 | control_copy_unet_out += control_copy_unet_hint_out 128 | 129 | # Get all downblocks output for controlnet's copy of unet 130 | control_copy_unet_down_outs = [] 131 | for idx, down in enumerate(self.control_copy_unet.downs): 132 | # Save the control nets copy output after passing it through zero conv layers 133 | control_copy_unet_down_outs.append( 134 | self.control_copy_unet_down_zero_convs[idx](control_copy_unet_out) 135 | ) 136 | control_copy_unet_out = down(control_copy_unet_out, control_copy_unet_t_emb) 137 | 138 | for idx in range(len(self.control_copy_unet.mids)): 139 | # Get midblock output of controlnets copy of unet 140 | control_copy_unet_out = self.control_copy_unet.mids[idx]( 141 | control_copy_unet_out, 142 | control_copy_unet_t_emb 143 | ) 144 | 145 | # Get midblock output of trained unet 146 | 
train_unet_out = self.trained_unet.mids[idx](train_unet_out, trained_unet_t_emb) 147 | 148 | # Add midblock output of controlnets copy of unet to that of trained unet 149 | # but after passing them through zero conv layers 150 | train_unet_out += self.control_copy_unet_mid_zero_convs[idx](control_copy_unet_out) 151 | 152 | # Call upblocks of trained unet 153 | for up in self.trained_unet.ups: 154 | # Get downblocks output from both trained unet and controlnets copy of unet 155 | trained_unet_down_out = trained_unet_down_outs.pop() 156 | control_copy_unet_down_out = control_copy_unet_down_outs.pop() 157 | 158 | # Add these together and pass this as downblock input to upblock 159 | train_unet_out = up(train_unet_out, 160 | control_copy_unet_down_out + trained_unet_down_out, 161 | trained_unet_t_emb) 162 | 163 | # Call output layers of trained unet 164 | train_unet_out = self.trained_unet.norm_out(train_unet_out) 165 | train_unet_out = nn.SiLU()(train_unet_out) 166 | train_unet_out = self.trained_unet.conv_out(train_unet_out) 167 | # out B x C x H x W 168 | return train_unet_out 169 | 170 | 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /models/controlnet_ldm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.unet_cond_base import Unet 4 | from models.blocks import get_time_embedding 5 | 6 | 7 | def make_zero_module(module): 8 | for p in module.parameters(): 9 | p.detach().zero_() 10 | return module 11 | 12 | 13 | class ControlNet(nn.Module): 14 | r""" 15 | Control Net Module for DDPM 16 | """ 17 | def __init__(self, im_channels, 18 | model_config, 19 | model_locked=True, 20 | model_ckpt=None, 21 | device=None, 22 | down_sample_factor=32): 23 | super().__init__() 24 | # Trained DDPM 25 | self.model_locked = model_locked 26 | self.trained_unet = Unet(im_channels, model_config) 27 | 28 | # Load weights for the trained model 29 | if model_ckpt is not None and device is not None: 30 | print('Loading Trained Diffusion Model') 31 | self.trained_unet.load_state_dict(torch.load(model_ckpt, 32 | map_location=device), strict=True) 33 | 34 | # ControlNet Copy of Trained DDPM 35 | # use_up = False removes the upblocks(decoder layers) from DDPM Unet 36 | self.control_unet = Unet(im_channels, model_config, use_up=False) 37 | # Load same weights as the trained model 38 | if model_ckpt is not None and device is not None: 39 | print('Loading Control Diffusion Model') 40 | self.control_unet.load_state_dict(torch.load(model_ckpt, 41 | map_location=device), strict=False) 42 | 43 | # Hint Block for ControlNet 44 | # Stack of Conv activation and zero convolution at the end 45 | base_hint_channel = 16 46 | curr_down_sample_factor = down_sample_factor 47 | hint_layers = [nn.Sequential( 48 | nn.Conv2d(model_config['hint_channels'], 49 | base_hint_channel, 50 | kernel_size=3, 51 | padding=(1, 1)), 52 | nn.SiLU())] 53 | while curr_down_sample_factor > 1: 54 | hint_layers.append(nn.Sequential( 55 | nn.Conv2d(base_hint_channel, 56 | base_hint_channel*2, 57 | kernel_size=3, 58 | padding=(1, 1), 59 | stride=2), 60 | nn.SiLU(), 61 | nn.Conv2d(base_hint_channel*2, 62 | base_hint_channel*2, 63 | kernel_size=3, 64 | padding=(1, 1)) 65 | )) 66 | base_hint_channel = base_hint_channel * 2 67 | curr_down_sample_factor = curr_down_sample_factor / 2 68 | hint_layers.append(nn.Sequential( 69 | nn.Conv2d(base_hint_channel, 70 | self.trained_unet.down_channels[0], 71 | 
kernel_size=3, 72 | padding=(1, 1)), 73 | nn.SiLU(), 74 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[0], 75 | self.trained_unet.down_channels[0], 76 | kernel_size=1, 77 | padding=0)) 78 | )) 79 | self.control_unet_hint_block = nn.Sequential(*hint_layers) 80 | 81 | # Zero Convolution Module for Downblocks(encoder Layers) 82 | self.control_unet_down_zero_convs = nn.ModuleList([ 83 | make_zero_module(nn.Conv2d(self.trained_unet.down_channels[i], 84 | self.trained_unet.down_channels[i], 85 | kernel_size=1, 86 | padding=0)) 87 | for i in range(len(self.trained_unet.down_channels)-1) 88 | ]) 89 | 90 | # Zero Convolution Module for MidBlocks 91 | self.control_unet_mid_zero_convs = nn.ModuleList([ 92 | make_zero_module(nn.Conv2d(self.trained_unet.mid_channels[i], 93 | self.trained_unet.mid_channels[i], 94 | kernel_size=1, 95 | padding=0)) 96 | for i in range(1, len(self.trained_unet.mid_channels)) 97 | ]) 98 | 99 | def get_params(self): 100 | # Add all ControlNet parameters 101 | # First is our copy of unet 102 | params = list(self.control_unet.parameters()) 103 | 104 | # Add parameters of hint Blocks & Zero convolutions for down/mid blocks 105 | params += list(self.control_unet_hint_block.parameters()) 106 | params += list(self.control_unet_down_zero_convs.parameters()) 107 | params += list(self.control_unet_mid_zero_convs.parameters()) 108 | 109 | # If we desire to not have the decoder layers locked, then add 110 | # them as well 111 | if not self.model_locked: 112 | params += list(self.trained_unet.ups.parameters()) 113 | params += list(self.trained_unet.norm_out.parameters()) 114 | params += list(self.trained_unet.conv_out.parameters()) 115 | return params 116 | 117 | def forward(self, x, t, hint): 118 | # Time embedding and timestep projection layers of trained unet 119 | trained_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(), 120 | self.trained_unet.t_emb_dim) 121 | trained_unet_t_emb = self.trained_unet.t_proj(trained_unet_t_emb) 122 | 123 | # Get all downblocks output of trained unet first 124 | trained_unet_down_outs = [] 125 | with torch.no_grad(): 126 | train_unet_out = self.trained_unet.conv_in(x) 127 | for idx, down in enumerate(self.trained_unet.downs): 128 | trained_unet_down_outs.append(train_unet_out) 129 | train_unet_out = down(train_unet_out, trained_unet_t_emb) 130 | 131 | # ControlNet Layers start here # 132 | # Time embedding and timestep projection layers of controlnet's copy of unet 133 | control_unet_t_emb = get_time_embedding(torch.as_tensor(t).long(), 134 | self.control_unet.t_emb_dim) 135 | control_unet_t_emb = self.control_unet.t_proj(control_unet_t_emb) 136 | 137 | # Hint block of controlnet's copy of unet 138 | control_unet_hint_out = self.control_unet_hint_block(hint) 139 | 140 | # Call conv_in layer for controlnet's copy of unet 141 | # and add hint blocks output to it 142 | control_unet_out = self.control_unet.conv_in(x) 143 | control_unet_out += control_unet_hint_out 144 | 145 | # Get all downblocks output for controlnet's copy of unet 146 | control_unet_down_outs = [] 147 | for idx, down in enumerate(self.control_unet.downs): 148 | # Save the control nets copy output after passing it through zero conv layers 149 | control_unet_down_outs.append(self.control_unet_down_zero_convs[idx](control_unet_out)) 150 | control_unet_out = down(control_unet_out, control_unet_t_emb) 151 | 152 | for idx in range(len(self.control_unet.mids)): 153 | # Get midblock output of controlnets copy of unet 154 | control_unet_out = 
self.control_unet.mids[idx](control_unet_out, control_unet_t_emb) 155 | 156 | # Get midblock output of trained unet 157 | train_unet_out = self.trained_unet.mids[idx](train_unet_out, trained_unet_t_emb) 158 | 159 | # Add midblock output of controlnets copy of unet to that of trained unet 160 | # but after passing them through zero conv layers 161 | train_unet_out += self.control_unet_mid_zero_convs[idx](control_unet_out) 162 | 163 | # Call upblocks of trained unet 164 | for up in self.trained_unet.ups: 165 | # Get downblocks output from both trained unet and controlnets copy of unet 166 | trained_unet_down_out = trained_unet_down_outs.pop() 167 | control_unet_down_out = control_unet_down_outs.pop() 168 | 169 | # Add these together and pass this as downblock input to upblock 170 | train_unet_out = up(train_unet_out, 171 | control_unet_down_out + trained_unet_down_out, 172 | trained_unet_t_emb) 173 | 174 | # Call output layers of trained unet 175 | train_unet_out = self.trained_unet.norm_out(train_unet_out) 176 | train_unet_out = nn.SiLU()(train_unet_out) 177 | train_unet_out = self.trained_unet.conv_out(train_unet_out) 178 | # out B x C x H x W 179 | return train_unet_out 180 | 181 | 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Discriminator(nn.Module): 6 | r""" 7 | PatchGAN Discriminator. 8 | Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to 9 | 1 scalar value , we instead predict grid of values. 10 | Where each grid is prediction of how likely 11 | the discriminator thinks that the image patch corresponding 12 | to the grid cell is real 13 | """ 14 | 15 | def __init__(self, im_channels=3, 16 | conv_channels=[64, 128, 256], 17 | kernels=[4,4,4,4], 18 | strides=[2,2,2,1], 19 | paddings=[1,1,1,1]): 20 | super().__init__() 21 | self.im_channels = im_channels 22 | activation = nn.LeakyReLU(0.2) 23 | layers_dim = [self.im_channels] + conv_channels + [1] 24 | self.layers = nn.ModuleList([ 25 | nn.Sequential( 26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1], 27 | kernel_size=kernels[i], 28 | stride=strides[i], 29 | padding=paddings[i], 30 | bias=False if i !=0 else True), 31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(), 32 | activation if i != len(layers_dim) - 2 else nn.Identity() 33 | ) 34 | for i in range(len(layers_dim) - 1) 35 | ]) 36 | 37 | def forward(self, x): 38 | out = x 39 | for layer in self.layers: 40 | out = layer(out) 41 | return out 42 | 43 | 44 | if __name__ == '__main__': 45 | x = torch.randn((2,3, 256, 256)) 46 | prob = Discriminator(im_channels=3)(x) 47 | print(prob.shape) 48 | -------------------------------------------------------------------------------- /models/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn 9 | import torchvision 10 | 11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 
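# --- Annotation (not part of the original file): a minimal usage sketch for the
# LPIPS class defined below when it is used as a perceptual loss. `recon` and
# `target` are hypothetical B x 3 x H x W tensors already scaled to [-1, 1]:
#
#     lpips_model = LPIPS().eval().to(device)
#     perceptual_loss = torch.mean(lpips_model(recon, target))
#
# If the inputs are in [0, 1] instead, pass normalize=True so they are rescaled
# to [-1, 1] before the ImageNet normalization applied inside forward().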
| 18 | def spatial_average(in_tens, keepdim=True): 19 | return in_tens.mean([2, 3], keepdim=keepdim) 20 | 21 | 22 | class vgg16(torch.nn.Module): 23 | def __init__(self, requires_grad=False, pretrained=True): 24 | super(vgg16, self).__init__() 25 | # Load pretrained vgg model from torchvision 26 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features 27 | self.slice1 = torch.nn.Sequential() 28 | self.slice2 = torch.nn.Sequential() 29 | self.slice3 = torch.nn.Sequential() 30 | self.slice4 = torch.nn.Sequential() 31 | self.slice5 = torch.nn.Sequential() 32 | self.N_slices = 5 33 | for x in range(4): 34 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 35 | for x in range(4, 9): 36 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 37 | for x in range(9, 16): 38 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 39 | for x in range(16, 23): 40 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 41 | for x in range(23, 30): 42 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 43 | 44 | # Freeze vgg model 45 | if not requires_grad: 46 | for param in self.parameters(): 47 | param.requires_grad = False 48 | 49 | def forward(self, X): 50 | # Return output of vgg features 51 | h = self.slice1(X) 52 | h_relu1_2 = h 53 | h = self.slice2(h) 54 | h_relu2_2 = h 55 | h = self.slice3(h) 56 | h_relu3_3 = h 57 | h = self.slice4(h) 58 | h_relu4_3 = h 59 | h = self.slice5(h) 60 | h_relu5_3 = h 61 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 62 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 63 | return out 64 | 65 | 66 | # Learned perceptual metric 67 | class LPIPS(nn.Module): 68 | def __init__(self, net='vgg', version='0.1', use_dropout=True): 69 | super(LPIPS, self).__init__() 70 | self.version = version 71 | # Imagenet normalization 72 | self.scaling_layer = ScalingLayer() 73 | ######################## 74 | 75 | # Instantiate vgg model 76 | self.chns = [64, 128, 256, 512, 512] 77 | self.L = len(self.chns) 78 | self.net = vgg16(pretrained=True, requires_grad=False) 79 | 80 | # Add 1x1 convolutional Layers 81 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 82 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 83 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 84 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 85 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 86 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 87 | self.lins = nn.ModuleList(self.lins) 88 | ######################## 89 | 90 | # Load the weights of trained LPIPS model 91 | import inspect 92 | import os 93 | model_path = os.path.abspath( 94 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))) 95 | print('Loading model from: %s' % model_path) 96 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False) 97 | ######################## 98 | 99 | # Freeze all parameters 100 | self.eval() 101 | for param in self.parameters(): 102 | param.requires_grad = False 103 | ######################## 104 | 105 | def forward(self, in0, in1, normalize=False): 106 | # Scale the inputs to -1 to +1 range if needed 107 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 108 | in0 = 2 * in0 - 1 109 | in1 = 2 * in1 - 1 110 | ######################## 111 | 112 | # Normalize the inputs according to imagenet 
normalization 113 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) 114 | ######################## 115 | 116 | # Get VGG outputs for image0 and image1 117 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 118 | feats0, feats1, diffs = {}, {}, {} 119 | ######################## 120 | 121 | # Compute Square of Difference for each layer output 122 | for kk in range(self.L): 123 | feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize( 124 | outs1[kk]) 125 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 126 | ######################## 127 | 128 | # 1x1 convolution followed by spatial average on the square differences 129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 130 | val = 0 131 | 132 | # Aggregate the results of each layer 133 | for l in range(self.L): 134 | val += res[l] 135 | return val 136 | 137 | 138 | class ScalingLayer(nn.Module): 139 | def __init__(self): 140 | super(ScalingLayer, self).__init__() 141 | # Imagnet normalization for (0-1) 142 | # mean = [0.485, 0.456, 0.406] 143 | # std = [0.229, 0.224, 0.225] 144 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 145 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 146 | 147 | def forward(self, inp): 148 | return (inp - self.shift) / self.scale 149 | 150 | 151 | class NetLinLayer(nn.Module): 152 | ''' A single linear layer which does a 1x1 conv ''' 153 | 154 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 155 | super(NetLinLayer, self).__init__() 156 | 157 | layers = [nn.Dropout(), ] if (use_dropout) else [] 158 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 159 | self.model = nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | out = self.model(x) 163 | return out 164 | -------------------------------------------------------------------------------- /models/unet_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_time_embedding(time_steps, temb_dim): 6 | r""" 7 | Convert time steps tensor into an embedding using the 8 | sinusoidal time embedding formula 9 | :param time_steps: 1D tensor of length batch size 10 | :param temb_dim: Dimension of the embedding 11 | :return: BxD embedding representation of B time steps 12 | """ 13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 14 | 15 | # factor = 10000^(2i/d_model) 16 | factor = 10000 ** ((torch.arange( 17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) 18 | ) 19 | 20 | # pos / factor 21 | # timesteps B -> B, 1 -> B, temb_dim 22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 24 | return t_emb 25 | 26 | 27 | class DownBlock(nn.Module): 28 | r""" 29 | Down conv block with attention. 30 | Sequence of following block 31 | 1. Resnet block with time embedding 32 | 2. Attention block 33 | 3. 
Downsample using 2x2 average pooling 34 | """ 35 | def __init__(self, in_channels, out_channels, t_emb_dim, 36 | down_sample=True, num_heads=4, num_layers=1): 37 | super().__init__() 38 | self.num_layers = num_layers 39 | self.down_sample = down_sample 40 | self.resnet_conv_first = nn.ModuleList( 41 | [ 42 | nn.Sequential( 43 | nn.GroupNorm(8, in_channels if i == 0 else out_channels), 44 | nn.SiLU(), 45 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 46 | kernel_size=3, stride=1, padding=1), 47 | ) 48 | for i in range(num_layers) 49 | ] 50 | ) 51 | self.t_emb_layers = nn.ModuleList([ 52 | nn.Sequential( 53 | nn.SiLU(), 54 | nn.Linear(t_emb_dim, out_channels) 55 | ) 56 | for _ in range(num_layers) 57 | ]) 58 | self.resnet_conv_second = nn.ModuleList( 59 | [ 60 | nn.Sequential( 61 | nn.GroupNorm(8, out_channels), 62 | nn.SiLU(), 63 | nn.Conv2d(out_channels, out_channels, 64 | kernel_size=3, stride=1, padding=1), 65 | ) 66 | for _ in range(num_layers) 67 | ] 68 | ) 69 | self.attention_norms = nn.ModuleList( 70 | [nn.GroupNorm(8, out_channels) 71 | for _ in range(num_layers)] 72 | ) 73 | 74 | self.attentions = nn.ModuleList( 75 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 76 | for _ in range(num_layers)] 77 | ) 78 | self.residual_input_conv = nn.ModuleList( 79 | [ 80 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 81 | for i in range(num_layers) 82 | ] 83 | ) 84 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 85 | 4, 2, 1) if self.down_sample else nn.Identity() 86 | 87 | def forward(self, x, t_emb): 88 | out = x 89 | for i in range(self.num_layers): 90 | 91 | # Resnet block of Unet 92 | resnet_input = out 93 | out = self.resnet_conv_first[i](out) 94 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 95 | out = self.resnet_conv_second[i](out) 96 | out = out + self.residual_input_conv[i](resnet_input) 97 | 98 | # Attention block of Unet 99 | batch_size, channels, h, w = out.shape 100 | in_attn = out.reshape(batch_size, channels, h * w) 101 | in_attn = self.attention_norms[i](in_attn) 102 | in_attn = in_attn.transpose(1, 2) 103 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 104 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 105 | out = out + out_attn 106 | 107 | out = self.down_sample_conv(out) 108 | return out 109 | 110 | 111 | class MidBlock(nn.Module): 112 | r""" 113 | Mid conv block with attention. 114 | Sequence of following blocks 115 | 1. Resnet block with time embedding 116 | 2. Attention block 117 | 3. 
Resnet block with time embedding 118 | """ 119 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1): 120 | super().__init__() 121 | self.num_layers = num_layers 122 | self.resnet_conv_first = nn.ModuleList( 123 | [ 124 | nn.Sequential( 125 | nn.GroupNorm(8, in_channels if i == 0 else out_channels), 126 | nn.SiLU(), 127 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 128 | padding=1), 129 | ) 130 | for i in range(num_layers+1) 131 | ] 132 | ) 133 | self.t_emb_layers = nn.ModuleList([ 134 | nn.Sequential( 135 | nn.SiLU(), 136 | nn.Linear(t_emb_dim, out_channels) 137 | ) 138 | for _ in range(num_layers + 1) 139 | ]) 140 | self.resnet_conv_second = nn.ModuleList( 141 | [ 142 | nn.Sequential( 143 | nn.GroupNorm(8, out_channels), 144 | nn.SiLU(), 145 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 146 | ) 147 | for _ in range(num_layers+1) 148 | ] 149 | ) 150 | 151 | self.attention_norms = nn.ModuleList( 152 | [nn.GroupNorm(8, out_channels) 153 | for _ in range(num_layers)] 154 | ) 155 | 156 | self.attentions = nn.ModuleList( 157 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 158 | for _ in range(num_layers)] 159 | ) 160 | self.residual_input_conv = nn.ModuleList( 161 | [ 162 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 163 | for i in range(num_layers+1) 164 | ] 165 | ) 166 | 167 | def forward(self, x, t_emb): 168 | out = x 169 | 170 | # First resnet block 171 | resnet_input = out 172 | out = self.resnet_conv_first[0](out) 173 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] 174 | out = self.resnet_conv_second[0](out) 175 | out = out + self.residual_input_conv[0](resnet_input) 176 | 177 | for i in range(self.num_layers): 178 | 179 | # Attention Block 180 | batch_size, channels, h, w = out.shape 181 | in_attn = out.reshape(batch_size, channels, h * w) 182 | in_attn = self.attention_norms[i](in_attn) 183 | in_attn = in_attn.transpose(1, 2) 184 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 185 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 186 | out = out + out_attn 187 | 188 | # Resnet Block 189 | resnet_input = out 190 | out = self.resnet_conv_first[i+1](out) 191 | out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None] 192 | out = self.resnet_conv_second[i+1](out) 193 | out = out + self.residual_input_conv[i+1](resnet_input) 194 | 195 | return out 196 | 197 | 198 | class UpBlock(nn.Module): 199 | r""" 200 | Up conv block with attention. 201 | Sequence of following blocks 202 | 1. Upsample 203 | 1. Concatenate Down block output 204 | 2. Resnet block with time embedding 205 | 3. 
Attention Block 206 | """ 207 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1): 208 | super().__init__() 209 | self.num_layers = num_layers 210 | self.up_sample = up_sample 211 | self.resnet_conv_first = nn.ModuleList( 212 | [ 213 | nn.Sequential( 214 | nn.GroupNorm(8, in_channels if i == 0 else out_channels), 215 | nn.SiLU(), 216 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 217 | padding=1), 218 | ) 219 | for i in range(num_layers) 220 | ] 221 | ) 222 | self.t_emb_layers = nn.ModuleList([ 223 | nn.Sequential( 224 | nn.SiLU(), 225 | nn.Linear(t_emb_dim, out_channels) 226 | ) 227 | for _ in range(num_layers) 228 | ]) 229 | self.resnet_conv_second = nn.ModuleList( 230 | [ 231 | nn.Sequential( 232 | nn.GroupNorm(8, out_channels), 233 | nn.SiLU(), 234 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 235 | ) 236 | for _ in range(num_layers) 237 | ] 238 | ) 239 | 240 | self.attention_norms = nn.ModuleList( 241 | [ 242 | nn.GroupNorm(8, out_channels) 243 | for _ in range(num_layers) 244 | ] 245 | ) 246 | 247 | self.attentions = nn.ModuleList( 248 | [ 249 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 250 | for _ in range(num_layers) 251 | ] 252 | ) 253 | self.residual_input_conv = nn.ModuleList( 254 | [ 255 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 256 | for i in range(num_layers) 257 | ] 258 | ) 259 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 260 | 4, 2, 1) \ 261 | if self.up_sample else nn.Identity() 262 | 263 | def forward(self, x, out_down, t_emb): 264 | x = self.up_sample_conv(x) 265 | x = torch.cat([x, out_down], dim=1) 266 | 267 | out = x 268 | for i in range(self.num_layers): 269 | # Resnet Block 270 | resnet_input = out 271 | out = self.resnet_conv_first[i](out) 272 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 273 | out = self.resnet_conv_second[i](out) 274 | out = out + self.residual_input_conv[i](resnet_input) 275 | 276 | # Attention Block 277 | batch_size, channels, h, w = out.shape 278 | in_attn = out.reshape(batch_size, channels, h * w) 279 | in_attn = self.attention_norms[i](in_attn) 280 | in_attn = in_attn.transpose(1, 2) 281 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 282 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 283 | out = out + out_attn 284 | 285 | return out 286 | 287 | 288 | class Unet(nn.Module): 289 | r""" 290 | Unet model comprising 291 | Down blocks, Midblocks and Uplocks 292 | """ 293 | def __init__(self, model_config, use_up=True): 294 | super().__init__() 295 | im_channels = model_config['im_channels'] 296 | self.down_channels = model_config['down_channels'] 297 | self.mid_channels = model_config['mid_channels'] 298 | self.t_emb_dim = model_config['time_emb_dim'] 299 | self.down_sample = model_config['down_sample'] 300 | self.num_down_layers = model_config['num_down_layers'] 301 | self.num_mid_layers = model_config['num_mid_layers'] 302 | self.num_up_layers = model_config['num_up_layers'] 303 | 304 | assert self.mid_channels[0] == self.down_channels[-1] 305 | assert self.mid_channels[-1] == self.down_channels[-2] 306 | assert len(self.down_sample) == len(self.down_channels) - 1 307 | 308 | # Initial projection from sinusoidal time embedding 309 | self.t_proj = nn.Sequential( 310 | nn.Linear(self.t_emb_dim, self.t_emb_dim), 311 | nn.SiLU(), 312 | nn.Linear(self.t_emb_dim, 
self.t_emb_dim) 313 | ) 314 | 315 | self.up_sample = list(reversed(self.down_sample)) 316 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) 317 | 318 | self.downs = nn.ModuleList([]) 319 | for i in range(len(self.down_channels)-1): 320 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim, 321 | down_sample=self.down_sample[i], num_layers=self.num_down_layers)) 322 | 323 | self.mids = nn.ModuleList([]) 324 | for i in range(len(self.mid_channels)-1): 325 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim, 326 | num_layers=self.num_mid_layers)) 327 | 328 | if use_up: 329 | self.ups = nn.ModuleList([]) 330 | for i in reversed(range(len(self.down_channels)-1)): 331 | self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16, 332 | self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers)) 333 | 334 | self.norm_out = nn.GroupNorm(8, 16) 335 | self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1) 336 | 337 | def forward(self, x, t): 338 | # Shapes assuming downblocks are [C1, C2, C3, C4] 339 | # Shapes assuming midblocks are [C4, C4, C3] 340 | # Shapes assuming downsamples are [True, True, False] 341 | 342 | # t_emb -> B x t_emb_dim 343 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) 344 | t_emb = self.t_proj(t_emb) 345 | 346 | # B x C x H x W 347 | out = self.conv_in(x) 348 | # B x C1 x H x W 349 | 350 | down_outs = [] 351 | 352 | for idx, down in enumerate(self.downs): 353 | down_outs.append(out) 354 | out = down(out, t_emb) 355 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4] 356 | # out B x C4 x H/4 x W/4 357 | 358 | for mid in self.mids: 359 | out = mid(out, t_emb) 360 | # out B x C3 x H/4 x W/4 361 | 362 | for up in self.ups: 363 | down_out = down_outs.pop() 364 | out = up(out, down_out, t_emb) 365 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] 366 | out = self.norm_out(out) 367 | out = nn.SiLU()(out) 368 | out = self.conv_out(out) 369 | # out B x C x H x W 370 | return out 371 | -------------------------------------------------------------------------------- /models/unet_cond_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum 3 | import torch.nn as nn 4 | from models.blocks import get_time_embedding 5 | from models.blocks import DownBlock, MidBlock, UpBlockUnet 6 | from utils.config_utils import * 7 | 8 | 9 | class Unet(nn.Module): 10 | r""" 11 | Unet model comprising 12 | Down blocks, Midblocks and Uplocks 13 | """ 14 | 15 | def __init__(self, im_channels, model_config, use_up=True): 16 | super().__init__() 17 | self.down_channels = model_config['down_channels'] 18 | self.mid_channels = model_config['mid_channels'] 19 | self.t_emb_dim = model_config['time_emb_dim'] 20 | self.down_sample = model_config['down_sample'] 21 | self.num_down_layers = model_config['num_down_layers'] 22 | self.num_mid_layers = model_config['num_mid_layers'] 23 | self.num_up_layers = model_config['num_up_layers'] 24 | self.attns = model_config['attn_down'] 25 | self.norm_channels = model_config['norm_channels'] 26 | self.num_heads = model_config['num_heads'] 27 | self.conv_out_channels = model_config['conv_out_channels'] 28 | 29 | # Validating Unet Model configurations 30 | assert self.mid_channels[0] == self.down_channels[-1] 31 | assert self.mid_channels[-1] == 
self.down_channels[-2] 32 | assert len(self.down_sample) == len(self.down_channels) - 1 33 | assert len(self.attns) == len(self.down_channels) - 1 34 | 35 | ######## Class, Mask and Text Conditioning Config ##### 36 | self.class_cond = False 37 | self.text_cond = False 38 | self.image_cond = False 39 | self.text_embed_dim = None 40 | self.condition_config = get_config_value(model_config, 'condition_config', None) 41 | if self.condition_config is not None: 42 | assert 'condition_types' in self.condition_config, 'Condition Type not provided in model config' 43 | condition_types = self.condition_config['condition_types'] 44 | if 'class' in condition_types: 45 | validate_class_config(self.condition_config) 46 | self.class_cond = True 47 | self.num_classes = self.condition_config['class_condition_config']['num_classes'] 48 | if 'text' in condition_types: 49 | validate_text_config(self.condition_config) 50 | self.text_cond = True 51 | self.text_embed_dim = self.condition_config['text_condition_config']['text_embed_dim'] 52 | if 'image' in condition_types: 53 | self.image_cond = True 54 | self.im_cond_input_ch = self.condition_config['image_condition_config'][ 55 | 'image_condition_input_channels'] 56 | self.im_cond_output_ch = self.condition_config['image_condition_config'][ 57 | 'image_condition_output_channels'] 58 | if self.class_cond: 59 | # Rather than using a special null class we dont add the 60 | # class embedding information for unconditional generation 61 | self.class_emb = nn.Embedding(self.num_classes, 62 | self.t_emb_dim) 63 | 64 | if self.image_cond: 65 | # Map the mask image to a N channel image and 66 | # concat that with input across channel dimension 67 | self.cond_conv_in = nn.Conv2d(in_channels=self.im_cond_input_ch, 68 | out_channels=self.im_cond_output_ch, 69 | kernel_size=1, 70 | bias=False) 71 | self.conv_in_concat = nn.Conv2d(im_channels + self.im_cond_output_ch, 72 | self.down_channels[0], kernel_size=3, padding=1) 73 | else: 74 | self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1) 75 | self.cond = self.text_cond or self.image_cond or self.class_cond 76 | ################################### 77 | 78 | # Initial projection from sinusoidal time embedding 79 | self.t_proj = nn.Sequential( 80 | nn.Linear(self.t_emb_dim, self.t_emb_dim), 81 | nn.SiLU(), 82 | nn.Linear(self.t_emb_dim, self.t_emb_dim) 83 | ) 84 | 85 | self.up_sample = list(reversed(self.down_sample)) 86 | self.downs = nn.ModuleList([]) 87 | 88 | # Build the Downblocks 89 | for i in range(len(self.down_channels) - 1): 90 | # Cross Attention and Context Dim only needed if text condition is present 91 | self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim, 92 | down_sample=self.down_sample[i], 93 | num_heads=self.num_heads, 94 | num_layers=self.num_down_layers, 95 | attn=self.attns[i], norm_channels=self.norm_channels, 96 | cross_attn=self.text_cond, 97 | context_dim=self.text_embed_dim)) 98 | 99 | self.mids = nn.ModuleList([]) 100 | # Build the Midblocks 101 | for i in range(len(self.mid_channels) - 1): 102 | self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim, 103 | num_heads=self.num_heads, 104 | num_layers=self.num_mid_layers, 105 | norm_channels=self.norm_channels, 106 | cross_attn=self.text_cond, 107 | context_dim=self.text_embed_dim)) 108 | 109 | self.ups = nn.ModuleList([]) 110 | if use_up: 111 | # Build the Upblocks 112 | for i in reversed(range(len(self.down_channels) - 1)): 113 | 
self.ups.append( 114 | UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels, 115 | self.t_emb_dim, up_sample=self.down_sample[i], 116 | num_heads=self.num_heads, 117 | num_layers=self.num_up_layers, 118 | norm_channels=self.norm_channels, 119 | cross_attn=self.text_cond, 120 | context_dim=self.text_embed_dim)) 121 | 122 | self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) 123 | self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1) 124 | 125 | def forward(self, x, t, cond_input=None): 126 | # Shapes assuming downblocks are [C1, C2, C3, C4] 127 | # Shapes assuming midblocks are [C4, C4, C3] 128 | # Shapes assuming downsamples are [True, True, False] 129 | if self.cond: 130 | assert cond_input is not None, \ 131 | "Model initialized with conditioning so cond_input cannot be None" 132 | if self.image_cond: 133 | ######## Mask Conditioning ######## 134 | validate_image_conditional_input(cond_input, x) 135 | im_cond = cond_input['image'] 136 | im_cond = torch.nn.functional.interpolate(im_cond, size=x.shape[-2:]) 137 | im_cond = self.cond_conv_in(im_cond) 138 | assert im_cond.shape[-2:] == x.shape[-2:] 139 | x = torch.cat([x, im_cond], dim=1) 140 | # B x (C+N) x H x W 141 | out = self.conv_in_concat(x) 142 | ##################################### 143 | else: 144 | # B x C x H x W 145 | out = self.conv_in(x) 146 | # B x C1 x H x W 147 | 148 | # t_emb -> B x t_emb_dim 149 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) 150 | t_emb = self.t_proj(t_emb) 151 | 152 | ######## Class Conditioning ######## 153 | if self.class_cond: 154 | validate_class_conditional_input(cond_input, x, self.num_classes) 155 | class_embed = einsum(cond_input['class'].float(), self.class_emb.weight, 'b n, n d -> b d') 156 | t_emb += class_embed 157 | #################################### 158 | 159 | context_hidden_states = None 160 | if self.text_cond: 161 | assert 'text' in cond_input, \ 162 | "Model initialized with text conditioning but cond_input has no text information" 163 | context_hidden_states = cond_input['text'] 164 | down_outs = [] 165 | 166 | for idx, down in enumerate(self.downs): 167 | down_outs.append(out) 168 | out = down(out, t_emb, context_hidden_states) 169 | # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4] 170 | # out B x C4 x H/4 x W/4 171 | 172 | for mid in self.mids: 173 | out = mid(out, t_emb, context_hidden_states) 174 | # out B x C3 x H/4 x W/4 175 | 176 | for up in self.ups: 177 | down_out = down_outs.pop() 178 | out = up(out, down_out, t_emb, context_hidden_states) 179 | # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] 180 | out = self.norm_out(out) 181 | out = nn.SiLU()(out) 182 | out = self.conv_out(out) 183 | # out B x C x H x W 184 | return out 185 | -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.blocks import DownBlock, MidBlock, UpBlock 4 | 5 | 6 | class VAE(nn.Module): 7 | def __init__(self, im_channels, model_config): 8 | super().__init__() 9 | self.down_channels = model_config['down_channels'] 10 | self.mid_channels = model_config['mid_channels'] 11 | self.down_sample = model_config['down_sample'] 12 | self.num_down_layers = model_config['num_down_layers'] 13 | self.num_mid_layers = model_config['num_mid_layers'] 14 | self.num_up_layers = 
model_config['num_up_layers'] 15 | 16 | # To disable attention in Downblock of Encoder and Upblock of Decoder 17 | self.attns = model_config['attn_down'] 18 | 19 | # Latent Dimension 20 | self.z_channels = model_config['z_channels'] 21 | self.norm_channels = model_config['norm_channels'] 22 | self.num_heads = model_config['num_heads'] 23 | 24 | # Assertion to validate the channel information 25 | assert self.mid_channels[0] == self.down_channels[-1] 26 | assert self.mid_channels[-1] == self.down_channels[-1] 27 | assert len(self.down_sample) == len(self.down_channels) - 1 28 | assert len(self.attns) == len(self.down_channels) - 1 29 | 30 | # Wherever we use downsampling in encoder correspondingly use 31 | # upsampling in decoder 32 | self.up_sample = list(reversed(self.down_sample)) 33 | 34 | ##################### Encoder ###################### 35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) 36 | 37 | # Downblock + Midblock 38 | self.encoder_layers = nn.ModuleList([]) 39 | for i in range(len(self.down_channels) - 1): 40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], 41 | t_emb_dim=None, down_sample=self.down_sample[i], 42 | num_heads=self.num_heads, 43 | num_layers=self.num_down_layers, 44 | attn=self.attns[i], 45 | norm_channels=self.norm_channels)) 46 | 47 | self.encoder_mids = nn.ModuleList([]) 48 | for i in range(len(self.mid_channels) - 1): 49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], 50 | t_emb_dim=None, 51 | num_heads=self.num_heads, 52 | num_layers=self.num_mid_layers, 53 | norm_channels=self.norm_channels)) 54 | 55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) 56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1) 57 | 58 | # Latent Dimension is 2*Latent because we are predicting mean & variance 59 | self.pre_quant_conv = nn.Conv2d(2 * self.z_channels, 2 * self.z_channels, kernel_size=1) 60 | #################################################### 61 | 62 | ##################### Decoder ###################### 63 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) 64 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) 65 | 66 | # Midblock + Upblock 67 | self.decoder_mids = nn.ModuleList([]) 68 | for i in reversed(range(1, len(self.mid_channels))): 69 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], 70 | t_emb_dim=None, 71 | num_heads=self.num_heads, 72 | num_layers=self.num_mid_layers, 73 | norm_channels=self.norm_channels)) 74 | 75 | self.decoder_layers = nn.ModuleList([]) 76 | for i in reversed(range(1, len(self.down_channels))): 77 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], 78 | t_emb_dim=None, up_sample=self.down_sample[i - 1], 79 | num_heads=self.num_heads, 80 | num_layers=self.num_up_layers, 81 | attn=self.attns[i - 1], 82 | norm_channels=self.norm_channels)) 83 | 84 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) 85 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) 86 | 87 | def encode(self, x): 88 | out = self.encoder_conv_in(x) 89 | for idx, down in enumerate(self.encoder_layers): 90 | out = down(out) 91 | for mid in self.encoder_mids: 92 | out = mid(out) 93 | out = self.encoder_norm_out(out) 94 | out = 
nn.SiLU()(out) 95 | out = self.encoder_conv_out(out) 96 | out = self.pre_quant_conv(out) 97 | mean, logvar = torch.chunk(out, 2, dim=1) 98 | std = torch.exp(0.5 * logvar) 99 | sample = mean + std * torch.randn(mean.shape).to(device=x.device) 100 | return sample, out 101 | 102 | def decode(self, z): 103 | out = z 104 | out = self.post_quant_conv(out) 105 | out = self.decoder_conv_in(out) 106 | for mid in self.decoder_mids: 107 | out = mid(out) 108 | for idx, up in enumerate(self.decoder_layers): 109 | out = up(out) 110 | 111 | out = self.decoder_norm_out(out) 112 | out = nn.SiLU()(out) 113 | out = self.decoder_conv_out(out) 114 | return out 115 | 116 | def forward(self, x): 117 | z, encoder_output = self.encode(x) 118 | out = self.decode(z) 119 | return out, encoder_output 120 | 121 | -------------------------------------------------------------------------------- /models/weights/v0.1/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/models/weights/v0.1/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | numpy==2.0.1 3 | opencv_python==4.10.0.84 4 | Pillow==10.4.0 5 | PyYAML==6.0.1 6 | torch==2.3.1 7 | torchvision==0.18.1 8 | tqdm==4.66.4 9 | -------------------------------------------------------------------------------- /scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/scheduler/__init__.py -------------------------------------------------------------------------------- /scheduler/linear_noise_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LinearNoiseScheduler: 5 | r""" 6 | Class for the linear noise scheduler that is used in DDPM. 7 | """ 8 | def __init__(self, num_timesteps, beta_start, beta_end, ldm_scheduler=False): 9 | self.num_timesteps = num_timesteps 10 | self.beta_start = beta_start 11 | self.beta_end = beta_end 12 | 13 | if ldm_scheduler: 14 | # Mimicking how compvis repo creates schedule 15 | self.betas = ( 16 | torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2 17 | ) 18 | else: 19 | self.betas = torch.linspace(beta_start, beta_end, num_timesteps) 20 | self.alphas = 1. 
- self.betas 21 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) 22 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) 23 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) 24 | 25 | def add_noise(self, original, noise, t): 26 | r""" 27 | Forward method for diffusion 28 | :param original: Image on which noise is to be applied 29 | :param noise: Random Noise Tensor (from normal dist) 30 | :param t: timestep of the forward process of shape -> (B,) 31 | :return: 32 | """ 33 | original_shape = original.shape 34 | batch_size = original_shape[0] 35 | 36 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 37 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 38 | 39 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W) 40 | for _ in range(len(original_shape) - 1): 41 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) 42 | for _ in range(len(original_shape) - 1): 43 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) 44 | 45 | # Apply and Return Forward process equation 46 | return (sqrt_alpha_cum_prod.to(original.device) * original 47 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise) 48 | 49 | def sample_prev_timestep(self, xt, noise_pred, t): 50 | r""" 51 | Use the noise prediction by model to get 52 | xt-1 using xt and the noise predicted 53 | :param xt: current timestep sample 54 | :param noise_pred: model noise prediction 55 | :param t: current timestep we are at 56 | :return: 57 | """ 58 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) / 59 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) 60 | x0 = torch.clamp(x0, -1., 1.) 
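# --- Annotation: the lines below implement the standard DDPM reverse step. With
# eps = noise_pred, the mean of p(x_{t-1} | x_t) is computed as
#     mu_t = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
# and, for t > 0, a sample is drawn with standard deviation sigma_t where
#     sigma_t^2 = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
# (the commented-out alternative further below uses sigma_t^2 = beta_t instead).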
61 | 62 | mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) 63 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) 64 | 65 | if t == 0: 66 | return mean, x0 67 | else: 68 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) 69 | variance = variance * self.betas.to(xt.device)[t] 70 | sigma = variance ** 0.5 71 | z = torch.randn(xt.shape).to(xt.device) 72 | 73 | # OR 74 | # variance = self.betas[t] 75 | # sigma = variance ** 0.5 76 | # z = torch.randn(xt.shape).to(xt.device) 77 | return mean + sigma * z, x0 78 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/tools/__init__.py -------------------------------------------------------------------------------- /tools/infer_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pickle 5 | 6 | import torch 7 | import torchvision 8 | import yaml 9 | from torch.utils.data.dataloader import DataLoader 10 | from torchvision.utils import make_grid 11 | from tqdm import tqdm 12 | 13 | from dataset.celeb_dataset import CelebDataset 14 | from dataset.mnist_dataset import MnistDataset 15 | from models.vae import VAE 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | if torch.backends.mps.is_available(): 19 | device = torch.device('mps') 20 | print('Using mps') 21 | 22 | 23 | def infer(args): 24 | ######## Read the config file ####### 25 | with open(args.config_path, 'r') as file: 26 | try: 27 | config = yaml.safe_load(file) 28 | except yaml.YAMLError as exc: 29 | print(exc) 30 | print(config) 31 | 32 | dataset_config = config['dataset_params'] 33 | autoencoder_config = config['autoencoder_params'] 34 | train_config = config['train_params'] 35 | 36 | # Create the dataset 37 | im_dataset_cls = { 38 | 'mnist': MnistDataset, 39 | 'celebhq': CelebDataset, 40 | }.get(dataset_config['name']) 41 | 42 | im_dataset = im_dataset_cls(split='train', 43 | im_path=dataset_config['im_path'], 44 | im_size=dataset_config['im_size'], 45 | im_channels=dataset_config['im_channels']) 46 | 47 | # This is only used for saving latents. Which as of now 48 | # is not done in batches hence batch size 1 49 | data_loader = DataLoader(im_dataset, 50 | batch_size=1, 51 | shuffle=False) 52 | 53 | num_images = train_config['num_samples'] 54 | ngrid = train_config['num_grid_rows'] 55 | 56 | idxs = torch.randint(0, len(im_dataset) - 1, (num_images,)) 57 | ims = torch.cat([im_dataset[idx][None, :] for idx in idxs]).float() 58 | ims = ims.to(device) 59 | 60 | model = VAE(im_channels=dataset_config['im_channels'], 61 | model_config=autoencoder_config).to(device) 62 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 63 | train_config['vae_autoencoder_ckpt_name']), 64 | map_location=device)) 65 | model.eval() 66 | 67 | with torch.no_grad(): 68 | 69 | encoded_output, _ = model.encode(ims) 70 | decoded_output = model.decode(encoded_output) 71 | encoded_output = torch.clamp(encoded_output, -1., 1.) 72 | encoded_output = (encoded_output + 1) / 2 73 | decoded_output = torch.clamp(decoded_output, -1., 1.) 
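# --- Annotation: encode() returns (sampled latent, concatenated mean/logvar), so
# `encoded_output` above is the sampled latent z, clamped and shifted to [0, 1]
# only so it can be visualized with make_grid/ToPILImage. The latent-saving loop
# further below keeps the second return value (mean and logvar concatenated)
# instead of the sampled latent.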
74 | decoded_output = (decoded_output + 1) / 2 75 | ims = (ims + 1) / 2 76 | 77 | encoder_grid = make_grid(encoded_output.cpu(), nrow=ngrid) 78 | decoder_grid = make_grid(decoded_output.cpu(), nrow=ngrid) 79 | input_grid = make_grid(ims.cpu(), nrow=ngrid) 80 | encoder_grid = torchvision.transforms.ToPILImage()(encoder_grid) 81 | decoder_grid = torchvision.transforms.ToPILImage()(decoder_grid) 82 | input_grid = torchvision.transforms.ToPILImage()(input_grid) 83 | 84 | input_grid.save(os.path.join(train_config['task_name'], 'input_samples.png')) 85 | encoder_grid.save(os.path.join(train_config['task_name'], 'encoded_samples.png')) 86 | decoder_grid.save(os.path.join(train_config['task_name'], 'reconstructed_samples.png')) 87 | 88 | if train_config['save_latents']: 89 | # save Latents (but in a very unoptimized way) 90 | latent_path = os.path.join(train_config['task_name'], train_config['vae_latent_dir_name']) 91 | latent_fnames = glob.glob(os.path.join(train_config['task_name'], train_config['vae_latent_dir_name'], 92 | '*.pkl')) 93 | assert len(latent_fnames) == 0, 'Latents already present. Delete all latent files and re-run' 94 | if not os.path.exists(latent_path): 95 | os.mkdir(latent_path) 96 | print('Saving Latents for {}'.format(dataset_config['name'])) 97 | 98 | fname_latent_map = {} 99 | part_count = 0 100 | count = 0 101 | for idx, im in enumerate(tqdm(data_loader)): 102 | _, encoded_output = model.encode(im.float().to(device)) 103 | fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu() 104 | # Save latents every 1000 images 105 | if (count + 1) % 1000 == 0: 106 | pickle.dump(fname_latent_map, open(os.path.join(latent_path, 107 | '{}.pkl'.format(part_count)), 'wb')) 108 | part_count += 1 109 | fname_latent_map = {} 110 | count += 1 111 | if len(fname_latent_map) > 0: 112 | pickle.dump(fname_latent_map, open(os.path.join(latent_path, 113 | '{}.pkl'.format(part_count)), 'wb')) 114 | print('Done saving latents') 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser(description='Arguments for vae inference') 119 | parser.add_argument('--config', dest='config_path', 120 | default='config/celebhq.yaml', type=str) 121 | args = parser.parse_args() 122 | infer(args) 123 | -------------------------------------------------------------------------------- /tools/sample_ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | from torchvision.utils import make_grid 7 | from tqdm import tqdm 8 | from models.unet_base import Unet 9 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 10 | 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | if torch.backends.mps.is_available(): 14 | device = torch.device('mps') 15 | print('Using mps') 16 | 17 | 18 | def sample(model, scheduler, train_config, model_config, diffusion_config): 19 | r""" 20 | Sample stepwise by going backward one timestep at a time. 
21 | We save the x0 predictions 22 | """ 23 | xt = torch.randn((train_config['num_samples'], 24 | model_config['im_channels'], 25 | model_config['im_size'], 26 | model_config['im_size'])).to(device) 27 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 28 | # Get prediction of noise 29 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device)) 30 | 31 | # Use scheduler to get x0 and xt-1 32 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 33 | 34 | # Save x0 35 | ims = torch.clamp(xt, -1., 1.).detach().cpu() 36 | ims = (ims + 1) / 2 37 | grid = make_grid(ims, nrow=train_config['num_grid_rows']) 38 | img = torchvision.transforms.ToPILImage()(grid) 39 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')): 40 | os.mkdir(os.path.join(train_config['task_name'], 'samples')) 41 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i))) 42 | img.close() 43 | 44 | 45 | def infer(args): 46 | # Read the config file # 47 | with open(args.config_path, 'r') as file: 48 | try: 49 | config = yaml.safe_load(file) 50 | except yaml.YAMLError as exc: 51 | print(exc) 52 | print(config) 53 | ######################## 54 | 55 | diffusion_config = config['diffusion_params'] 56 | model_config = config['model_params'] 57 | train_config = config['train_params'] 58 | 59 | # Load model with checkpoint 60 | model = Unet(model_config).to(device) 61 | assert os.path.exists(os.path.join(train_config['task_name'], 62 | train_config['ddpm_ckpt_name'])), "Train DDPM first" 63 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 64 | train_config['ddpm_ckpt_name']), map_location=device)) 65 | model.eval() 66 | 67 | # Create the noise scheduler 68 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 69 | beta_start=diffusion_config['beta_start'], 70 | beta_end=diffusion_config['beta_end']) 71 | with torch.no_grad(): 72 | sample(model, scheduler, train_config, model_config, diffusion_config) 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser(description='Arguments for ddpm image generation') 77 | parser.add_argument('--config', dest='config_path', 78 | default='config/mnist.yaml', type=str) 79 | args = parser.parse_args() 80 | infer(args) 81 | -------------------------------------------------------------------------------- /tools/sample_ddpm_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | import random 7 | from torchvision.utils import make_grid 8 | from tqdm import tqdm 9 | from dataset.mnist_dataset import MnistDataset 10 | from models.controlnet import ControlNet 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | if torch.backends.mps.is_available(): 16 | device = torch.device('mps') 17 | print('Using mps') 18 | 19 | 20 | def sample(model, scheduler, train_config, model_config, diffusion_config, dataset): 21 | r""" 22 | Sample stepwise by going backward one timestep at a time. 
23 | We save the x0 predictions 24 | """ 25 | xt = torch.randn((train_config['num_samples'], 26 | model_config['im_channels'], 27 | model_config['im_size'], 28 | model_config['im_size'])).to(device) 29 | 30 | # Get random hints for the desired number of samples 31 | hints = [] 32 | for idx in range(train_config['num_samples']): 33 | hint_idx = random.randint(0, len(dataset)) 34 | hints.append(dataset[hint_idx][1].unsqueeze(0).to(device)) 35 | hints = torch.cat(hints, dim=0).to(device) 36 | 37 | # Save the hints 38 | hints_grid = make_grid(hints, nrow=train_config['num_grid_rows']) 39 | hints_img = torchvision.transforms.ToPILImage()(hints_grid) 40 | hints_img.save(os.path.join(train_config['task_name'], 'hint.png')) 41 | 42 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 43 | # Get prediction of noise 44 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device), hints) 45 | 46 | # Prediction from original model 47 | # noise_pred = model.trained_unet(xt, torch.as_tensor(i).unsqueeze(0).to(device)) 48 | 49 | # Use scheduler to get x0 and xt-1 50 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 51 | 52 | # Save x0 53 | ims = torch.clamp(xt, -1., 1.).detach().cpu() 54 | ims = (ims + 1) / 2 55 | grid = make_grid(ims, nrow=train_config['num_grid_rows']) 56 | img = torchvision.transforms.ToPILImage()(grid) 57 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples_controlnet')): 58 | os.mkdir(os.path.join(train_config['task_name'], 'samples_controlnet')) 59 | img.save(os.path.join(train_config['task_name'], 'samples_controlnet', 'x0_{}.png'.format(i))) 60 | img.close() 61 | 62 | 63 | def infer(args): 64 | # Read the config file # 65 | with open(args.config_path, 'r') as file: 66 | try: 67 | config = yaml.safe_load(file) 68 | except yaml.YAMLError as exc: 69 | print(exc) 70 | print(config) 71 | ######################## 72 | 73 | diffusion_config = config['diffusion_params'] 74 | model_config = config['model_params'] 75 | train_config = config['train_params'] 76 | dataset_config = config['dataset_params'] 77 | 78 | # Change to require hints 79 | mnist_canny = MnistDataset('test', im_path=dataset_config['im_test_path'], return_hints=True) 80 | 81 | # Load model with checkpoint 82 | model = ControlNet(model_config, 83 | model_ckpt=os.path.join(train_config['task_name'], train_config['ddpm_ckpt_name']), 84 | device=device).to(device) 85 | 86 | assert os.path.exists(os.path.join(train_config['task_name'], 87 | train_config['controlnet_ckpt_name'])), "Train ControlNet first" 88 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 89 | train_config['controlnet_ckpt_name']), 90 | map_location=device)) 91 | model.eval() 92 | print('Loaded ControlNet checkpoint') 93 | 94 | # Create the noise scheduler 95 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 96 | beta_start=diffusion_config['beta_start'], 97 | beta_end=diffusion_config['beta_end']) 98 | with torch.no_grad(): 99 | sample(model, scheduler, train_config, model_config, diffusion_config, mnist_canny) 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser(description='Arguments for controlnet ddpm image generation') 104 | parser.add_argument('--config', dest='config_path', 105 | default='config/mnist.yaml', type=str) 106 | args = parser.parse_args() 107 | infer(args) 108 | -------------------------------------------------------------------------------- 
/tools/sample_ldm_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | import random 7 | from torchvision.utils import make_grid 8 | from tqdm import tqdm 9 | from models.controlnet_ldm import ControlNet 10 | from dataset.celeb_dataset import CelebDataset 11 | from models.vae import VAE 12 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | if torch.backends.mps.is_available(): 16 | device = torch.device('mps') 17 | print('Using mps') 18 | 19 | 20 | def sample(model, scheduler, train_config, diffusion_model_config, 21 | autoencoder_model_config, diffusion_config, dataset_config, vae, dataset): 22 | r""" 23 | Sample stepwise by going backward one timestep at a time. 24 | We save the x0 predictions 25 | """ 26 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 27 | xt = torch.randn((train_config['num_samples'], 28 | autoencoder_model_config['z_channels'], 29 | im_size, 30 | im_size)).to(device) 31 | 32 | # Get random hints for the desired number of samples 33 | hints = [] 34 | 35 | for idx in range(train_config['num_samples']): 36 | hint_idx = random.randint(0, len(dataset)) 37 | hints.append(dataset[hint_idx][1].unsqueeze(0).to(device)) 38 | hints = torch.cat(hints, dim=0).to(device) 39 | 40 | # Save the hints 41 | hints_grid = make_grid(hints, nrow=train_config['num_grid_rows']) 42 | hints_img = torchvision.transforms.ToPILImage()(hints_grid) 43 | hints_img.save(os.path.join(train_config['task_name'], 'hint.png')) 44 | 45 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 46 | # Get prediction of noise 47 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device), hints) 48 | 49 | # Use scheduler to get x0 and xt-1 50 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 51 | 52 | # Save x0 53 | # ims = torch.clamp(xt, -1., 1.).detach().cpu() 54 | if i == 0: 55 | # Decode ONLY the final image to save time 56 | ims = vae.to(device).decode(xt) 57 | else: 58 | ims = xt 59 | 60 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 61 | ims = (ims + 1) / 2 62 | grid = make_grid(ims, nrow=train_config['num_grid_rows']) 63 | img = torchvision.transforms.ToPILImage()(grid) 64 | 65 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples_controlnet')): 66 | os.mkdir(os.path.join(train_config['task_name'], 'samples_controlnet')) 67 | img.save(os.path.join(train_config['task_name'], 'samples_controlnet', 'x0_{}.png'.format(i))) 68 | img.close() 69 | 70 | 71 | def infer(args): 72 | # Read the config file # 73 | with open(args.config_path, 'r') as file: 74 | try: 75 | config = yaml.safe_load(file) 76 | except yaml.YAMLError as exc: 77 | print(exc) 78 | print(config) 79 | ######################## 80 | 81 | diffusion_config = config['diffusion_params'] 82 | dataset_config = config['dataset_params'] 83 | diffusion_model_config = config['ldm_params'] 84 | autoencoder_model_config = config['autoencoder_params'] 85 | train_config = config['train_params'] 86 | 87 | # Create the noise scheduler 88 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 89 | beta_start=diffusion_config['beta_start'], 90 | beta_end=diffusion_config['beta_end'], 91 | ldm_scheduler=True) 92 | 93 | celeb_canny = CelebDataset('test', 94 | 
im_path=dataset_config['im_path'], 95 | im_size=dataset_config['im_size'], 96 | return_hint=True) 97 | 98 | # Instantiate the model 99 | latent_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 100 | downscale_factor = dataset_config['canny_im_size'] // latent_size 101 | model = ControlNet(im_channels=autoencoder_model_config['z_channels'], 102 | model_config=diffusion_model_config, 103 | model_locked=True, 104 | model_ckpt=os.path.join(train_config['task_name'], train_config['ldm_ckpt_name']), 105 | device=device, 106 | down_sample_factor=downscale_factor).to(device) 107 | model.eval() 108 | 109 | assert os.path.exists(os.path.join(train_config['task_name'], 110 | train_config['controlnet_ckpt_name'])), "Train ControlNet first" 111 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 112 | train_config['controlnet_ckpt_name']), 113 | map_location=device)) 114 | print('Loaded controlnet checkpoint') 115 | 116 | vae = VAE(im_channels=dataset_config['im_channels'], 117 | model_config=autoencoder_model_config) 118 | vae.eval() 119 | 120 | # Load vae if found 121 | assert os.path.exists(os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])), \ 122 | "VAE checkpoint not present. Train VAE first." 123 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 124 | train_config['vae_autoencoder_ckpt_name']), 125 | map_location=device), strict=True) 126 | print('Loaded vae checkpoint') 127 | 128 | with torch.no_grad(): 129 | sample(model, scheduler, train_config, diffusion_model_config, 130 | autoencoder_model_config, diffusion_config, dataset_config, vae, celeb_canny) 131 | 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser(description='Arguments for ldm controlnet generation') 135 | parser.add_argument('--config', dest='config_path', 136 | default='config/celebhq.yaml', type=str) 137 | args = parser.parse_args() 138 | infer(args) 139 | -------------------------------------------------------------------------------- /tools/sample_ldm_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | from torchvision.utils import make_grid 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from models.unet_cond_base import Unet 10 | from models.vae import VAE 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 | 18 | 19 | def sample(model, scheduler, train_config, diffusion_model_config, 20 | autoencoder_model_config, diffusion_config, dataset_config, vae): 21 | r""" 22 | Sample stepwise by going backward one timestep at a time. 
23 | We save the x0 predictions 24 | """ 25 | im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 26 | xt = torch.randn((train_config['num_samples'], 27 | autoencoder_model_config['z_channels'], 28 | im_size, 29 | im_size)).to(device) 30 | 31 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 32 | # Get prediction of noise 33 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device)) 34 | 35 | # Use scheduler to get x0 and xt-1 36 | xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) 37 | 38 | # Save x0 39 | # ims = torch.clamp(xt, -1., 1.).detach().cpu() 40 | if i == 0: 41 | # Decode ONLY the final image to save time 42 | ims = vae.to(device).decode(xt) 43 | else: 44 | ims = xt 45 | 46 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 47 | ims = (ims + 1) / 2 48 | 49 | grid = make_grid(ims, nrow=train_config['num_grid_rows']) 50 | img = torchvision.transforms.ToPILImage()(grid) 51 | 52 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')): 53 | os.mkdir(os.path.join(train_config['task_name'], 'samples')) 54 | img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i))) 55 | img.close() 56 | 57 | 58 | def infer(args): 59 | # Read the config file # 60 | with open(args.config_path, 'r') as file: 61 | try: 62 | config = yaml.safe_load(file) 63 | except yaml.YAMLError as exc: 64 | print(exc) 65 | print(config) 66 | ######################## 67 | 68 | diffusion_config = config['diffusion_params'] 69 | dataset_config = config['dataset_params'] 70 | diffusion_model_config = config['ldm_params'] 71 | autoencoder_model_config = config['autoencoder_params'] 72 | train_config = config['train_params'] 73 | 74 | # Create the noise scheduler 75 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 76 | beta_start=diffusion_config['beta_start'], 77 | beta_end=diffusion_config['beta_end'], 78 | ldm_scheduler=True) 79 | 80 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 81 | model_config=diffusion_model_config).to(device) 82 | model.eval() 83 | 84 | assert os.path.exists(os.path.join(train_config['task_name'], 85 | train_config['ldm_ckpt_name'])), "Train LDM first" 86 | 87 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 88 | train_config['ldm_ckpt_name']), 89 | map_location=device)) 90 | print('Loaded unet checkpoint') 91 | 92 | # Create output directories 93 | if not os.path.exists(train_config['task_name']): 94 | os.mkdir(train_config['task_name']) 95 | 96 | vae = VAE(im_channels=dataset_config['im_channels'], 97 | model_config=autoencoder_model_config) 98 | vae.eval() 99 | 100 | # Load vae if found 101 | assert os.path.exists(os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])), \ 102 | "VAE checkpoint not present. Train VAE first." 
103 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 104 | train_config['vae_autoencoder_ckpt_name']), 105 | map_location=device), strict=True) 106 | print('Loaded vae checkpoint') 107 | 108 | with torch.no_grad(): 109 | sample(model, scheduler, train_config, diffusion_model_config, 110 | autoencoder_model_config, diffusion_config, dataset_config, vae) 111 | 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser(description='Arguments for ldm image generation') 115 | parser.add_argument('--config', dest='config_path', 116 | default='config/celebhq.yaml', type=str) 117 | args = parser.parse_args() 118 | infer(args) 119 | -------------------------------------------------------------------------------- /tools/train_ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim import Adam 8 | from dataset.mnist_dataset import MnistDataset 9 | from torch.utils.data import DataLoader 10 | from models.unet_base import Unet 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 | 18 | 19 | def train(args): 20 | # Read the config file # 21 | with open(args.config_path, 'r') as file: 22 | try: 23 | config = yaml.safe_load(file) 24 | except yaml.YAMLError as exc: 25 | print(exc) 26 | print(config) 27 | ######################## 28 | 29 | diffusion_config = config['diffusion_params'] 30 | dataset_config = config['dataset_params'] 31 | model_config = config['model_params'] 32 | train_config = config['train_params'] 33 | 34 | # Create the noise scheduler 35 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 36 | beta_start=diffusion_config['beta_start'], 37 | beta_end=diffusion_config['beta_end']) 38 | 39 | # Create the dataset 40 | mnist = MnistDataset('train', im_path=dataset_config['im_path']) 41 | mnist_loader = DataLoader(mnist, 42 | batch_size=train_config['batch_size'], 43 | shuffle=True, 44 | num_workers=4) 45 | 46 | # Instantiate the model 47 | model = Unet(model_config).to(device) 48 | model.train() 49 | 50 | # Create output directories 51 | if not os.path.exists(train_config['task_name']): 52 | os.mkdir(train_config['task_name']) 53 | 54 | # Load checkpoint if found 55 | if os.path.exists(os.path.join(train_config['task_name'],train_config['ddpm_ckpt_name'])): 56 | print('Loading checkpoint as found one') 57 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 58 | train_config['ddpm_ckpt_name']), map_location=device)) 59 | # Specify training parameters 60 | num_epochs = train_config['num_epochs'] 61 | optimizer = Adam(model.parameters(), lr=train_config['ddpm_lr']) 62 | criterion = torch.nn.MSELoss() 63 | 64 | # Run training 65 | for epoch_idx in range(num_epochs): 66 | losses = [] 67 | for im in tqdm(mnist_loader): 68 | optimizer.zero_grad() 69 | im = im.float().to(device) 70 | 71 | # Sample random noise 72 | noise = torch.randn_like(im).to(device) 73 | 74 | # Sample timestep 75 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device) 76 | 77 | # Add noise to images according to timestep 78 | noisy_im = scheduler.add_noise(im, noise, t) 79 | noise_pred = model(noisy_im, t) 80 | 81 | loss = 
criterion(noise_pred, noise) 82 | losses.append(loss.item()) 83 | loss.backward() 84 | optimizer.step() 85 | print('Finished epoch:{} | Loss : {:.4f}'.format( 86 | epoch_idx + 1, 87 | np.mean(losses), 88 | )) 89 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 90 | train_config['ddpm_ckpt_name'])) 91 | 92 | print('Done Training ...') 93 | 94 | 95 | if __name__ == '__main__': 96 | parser = argparse.ArgumentParser(description='Arguments for ddpm training') 97 | parser.add_argument('--config', dest='config_path', 98 | default='config/mnist.yaml', type=str) 99 | args = parser.parse_args() 100 | train(args) 101 | -------------------------------------------------------------------------------- /tools/train_ddpm_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim import Adam 8 | from dataset.mnist_dataset import MnistDataset 9 | from torch.utils.data import DataLoader 10 | from models.controlnet import ControlNet 11 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 | 18 | 19 | def train(args): 20 | # Read the config file # 21 | with open(args.config_path, 'r') as file: 22 | try: 23 | config = yaml.safe_load(file) 24 | except yaml.YAMLError as exc: 25 | print(exc) 26 | print(config) 27 | ######################## 28 | 29 | diffusion_config = config['diffusion_params'] 30 | dataset_config = config['dataset_params'] 31 | model_config = config['model_params'] 32 | train_config = config['train_params'] 33 | 34 | # Create the noise scheduler 35 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 36 | beta_start=diffusion_config['beta_start'], 37 | beta_end=diffusion_config['beta_end']) 38 | 39 | # Create the dataset 40 | mnist = MnistDataset('train', 41 | im_path=dataset_config['im_path'], 42 | return_hints=True) 43 | mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True) 44 | 45 | # Load model with checkpoint 46 | model = ControlNet(model_config, 47 | model_locked=True, 48 | model_ckpt=os.path.join(train_config['task_name'], 49 | train_config['ddpm_ckpt_name']), 50 | device=device).to(device) 51 | model.train() 52 | 53 | # Create output directories 54 | if not os.path.exists(train_config['task_name']): 55 | os.mkdir(train_config['task_name']) 56 | 57 | # Load checkpoint if found 58 | if os.path.exists(os.path.join(train_config['task_name'], 59 | train_config['controlnet_ckpt_name'])): 60 | print('Loading checkpoint as found one') 61 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 62 | train_config['controlnet_ckpt_name']), 63 | map_location=device)) 64 | 65 | # Specify training parameters 66 | num_epochs = train_config['controlnet_epochs'] 67 | optimizer = Adam(model.get_params(), lr=train_config['controlnet_lr']) 68 | criterion = torch.nn.MSELoss() 69 | 70 | # Run training 71 | steps = 0 72 | for epoch_idx in range(num_epochs): 73 | losses = [] 74 | for im, hint in tqdm(mnist_loader): 75 | optimizer.zero_grad() 76 | 77 | im = im.float().to(device) 78 | hint = hint.float().to(device) 79 | 80 | # Sample random noise 81 | noise = torch.randn_like(im).to(device) 82 | 83 | # Sample timestep 84 | t = torch.randint(0, 
diffusion_config['num_timesteps'], (im.shape[0],)).to(device) 85 | 86 | # Add noise to images according to timestep 87 | noisy_im = scheduler.add_noise(im, noise, t) 88 | 89 | # Additionally start passing the hint 90 | noise_pred = model(noisy_im, t, hint) 91 | 92 | loss = criterion(noise_pred, noise) 93 | losses.append(loss.item()) 94 | loss.backward() 95 | optimizer.step() 96 | steps += 1 97 | print('Finished epoch:{} | Loss : {:.4f}'.format( 98 | epoch_idx + 1, 99 | np.mean(losses), 100 | )) 101 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 102 | train_config['controlnet_ckpt_name'])) 103 | 104 | print('Done Training ...') 105 | 106 | 107 | if __name__ == '__main__': 108 | parser = argparse.ArgumentParser(description='Arguments for controlnet ddpm training') 109 | parser.add_argument('--config', dest='config_path', 110 | default='config/mnist.yaml', type=str) 111 | args = parser.parse_args() 112 | train(args) 113 | -------------------------------------------------------------------------------- /tools/train_ldm_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim import Adam 8 | from dataset.celeb_dataset import CelebDataset 9 | from torch.utils.data import DataLoader 10 | from models.controlnet_ldm import ControlNet 11 | from models.vae import VAE 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | if torch.backends.mps.is_available(): 17 | device = torch.device('mps') 18 | print('Using mps') 19 | 20 | 21 | def train(args): 22 | # Read the config file # 23 | with open(args.config_path, 'r') as file: 24 | try: 25 | config = yaml.safe_load(file) 26 | except yaml.YAMLError as exc: 27 | print(exc) 28 | print(config) 29 | ######################## 30 | 31 | diffusion_config = config['diffusion_params'] 32 | dataset_config = config['dataset_params'] 33 | diffusion_model_config = config['ldm_params'] 34 | autoencoder_model_config = config['autoencoder_params'] 35 | train_config = config['train_params'] 36 | 37 | # Create the noise scheduler 38 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 39 | beta_start=diffusion_config['beta_start'], 40 | beta_end=diffusion_config['beta_end'], 41 | ldm_scheduler=True) 42 | 43 | im_dataset = CelebDataset(split='train', 44 | im_path=dataset_config['im_path'], 45 | im_size=dataset_config['im_size'], 46 | im_channels=dataset_config['im_channels'], 47 | use_latents=True, 48 | latent_path=os.path.join(train_config['task_name'], 49 | train_config['vae_latent_dir_name']), 50 | return_hint=True 51 | ) 52 | 53 | data_loader = DataLoader(im_dataset, 54 | batch_size=train_config['ldm_batch_size'], 55 | shuffle=True) 56 | 57 | # Instantiate the model 58 | # downscale factor = canny_image_size // latent_size 59 | latent_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample']) 60 | downscale_factor = dataset_config['canny_im_size'] // latent_size 61 | model = ControlNet(im_channels=autoencoder_model_config['z_channels'], 62 | model_config=diffusion_model_config, 63 | model_locked=True, 64 | model_ckpt=os.path.join(train_config['task_name'], train_config['ldm_ckpt_name']), 65 | device=device, 66 | down_sample_factor=downscale_factor).to(device) 67 | 
model.train() 68 | # Create output directories 69 | if not os.path.exists(train_config['task_name']): 70 | os.mkdir(train_config['task_name']) 71 | 72 | # Load checkpoint if found 73 | if os.path.exists(os.path.join(train_config['task_name'], train_config['controlnet_ckpt_name'])): 74 | print('Loading checkpoint as found one') 75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 76 | train_config['controlnet_ckpt_name']), 77 | map_location=device)) 78 | 79 | # Load VAE ONLY if latents are not to be used or are missing 80 | if not im_dataset.use_latents: 81 | print('Loading vae model as latents not present') 82 | vae = VAE(im_channels=dataset_config['im_channels'], 83 | model_config=autoencoder_model_config).to(device) 84 | vae.eval() 85 | # Load vae if found 86 | if os.path.exists(os.path.join(train_config['task_name'], 87 | train_config['vae_autoencoder_ckpt_name'])): 88 | print('Loaded vae checkpoint') 89 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 90 | train_config['vae_autoencoder_ckpt_name']), 91 | map_location=device)) 92 | # Specify training parameters 93 | num_epochs = train_config['ldm_epochs'] 94 | optimizer = Adam(model.get_params(), lr=train_config['controlnet_lr']) 95 | lr_scheduler = MultiStepLR(optimizer, milestones=train_config['controlnet_lr_steps'], gamma=0.1) 96 | criterion = torch.nn.MSELoss() 97 | 98 | # Run training 99 | if not im_dataset.use_latents: 100 | for param in vae.parameters(): 101 | param.requires_grad = False 102 | 103 | step_count = 0 104 | count = 0 105 | for epoch_idx in range(num_epochs): 106 | losses = [] 107 | for im, hint in tqdm(data_loader): 108 | optimizer.zero_grad() 109 | im = im.float().to(device) 110 | if im_dataset.use_latents: 111 | mean, logvar = torch.chunk(im, 2, dim=1) 112 | std = torch.exp(0.5 * logvar) 113 | im = mean + std * torch.randn(mean.shape).to(device=im.device) 114 | else: 115 | with torch.no_grad(): 116 | im, _ = vae.encode(im) 117 | 118 | hint = hint.float().to(device) 119 | # Sample random noise 120 | noise = torch.randn_like(im).to(device) 121 | 122 | # Sample timestep 123 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device) 124 | 125 | # Add noise to images according to timestep 126 | noisy_im = scheduler.add_noise(im, noise, t) 127 | noise_pred = model(noisy_im, t, hint) 128 | 129 | loss = criterion(noise_pred, noise) 130 | losses.append(loss.item()) 131 | loss.backward() 132 | optimizer.step() 133 | step_count += 1 134 | print('Finished epoch:{} | Loss : {:.4f}'.format( 135 | epoch_idx + 1, 136 | np.mean(losses))) 137 | lr_scheduler.step() 138 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 139 | train_config['controlnet_ckpt_name'])) 140 | 141 | print('Done Training ...') 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser(description='Arguments for ldm controlnet training') 146 | parser.add_argument('--config', dest='config_path', 147 | default='config/celebhq.yaml', type=str) 148 | args = parser.parse_args() 149 | train(args) 150 | -------------------------------------------------------------------------------- /tools/train_ldm_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim import Adam 8 | from dataset.celeb_dataset import CelebDataset 9 | from torch.utils.data import DataLoader 10 | from 
models.unet_cond_base import Unet 11 | from models.vae import VAE 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | from scheduler.linear_noise_scheduler import LinearNoiseScheduler 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | if torch.backends.mps.is_available(): 17 | device = torch.device('mps') 18 | print('Using mps') 19 | 20 | 21 | def train(args): 22 | # Read the config file # 23 | with open(args.config_path, 'r') as file: 24 | try: 25 | config = yaml.safe_load(file) 26 | except yaml.YAMLError as exc: 27 | print(exc) 28 | print(config) 29 | ######################## 30 | 31 | diffusion_config = config['diffusion_params'] 32 | dataset_config = config['dataset_params'] 33 | diffusion_model_config = config['ldm_params'] 34 | autoencoder_model_config = config['autoencoder_params'] 35 | train_config = config['train_params'] 36 | 37 | # Create the noise scheduler 38 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 39 | beta_start=diffusion_config['beta_start'], 40 | beta_end=diffusion_config['beta_end'], 41 | ldm_scheduler=True) 42 | 43 | im_dataset = CelebDataset(split='train', 44 | im_path=dataset_config['im_path'], 45 | im_size=dataset_config['im_size'], 46 | im_channels=dataset_config['im_channels'], 47 | use_latents=True, 48 | latent_path=os.path.join(train_config['task_name'], 49 | train_config['vae_latent_dir_name']) 50 | ) 51 | 52 | data_loader = DataLoader(im_dataset, 53 | batch_size=train_config['ldm_batch_size'], 54 | shuffle=True) 55 | 56 | # Instantiate the model 57 | model = Unet(im_channels=autoencoder_model_config['z_channels'], 58 | model_config=diffusion_model_config).to(device) 59 | model.train() 60 | 61 | if os.path.exists(os.path.join(train_config['task_name'], 62 | train_config['ldm_ckpt_name'])): 63 | print('Loaded unet checkpoint') 64 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 65 | train_config['ldm_ckpt_name']), 66 | map_location=device)) 67 | 68 | # Load VAE ONLY if latents are not to be used or are missing 69 | if not im_dataset.use_latents: 70 | print('Loading vae model as latents not present') 71 | vae = VAE(im_channels=dataset_config['im_channels'], 72 | model_config=autoencoder_model_config).to(device) 73 | vae.eval() 74 | # Load vae if found 75 | if os.path.exists(os.path.join(train_config['task_name'], 76 | train_config['vae_autoencoder_ckpt_name'])): 77 | print('Loaded vae checkpoint') 78 | vae.load_state_dict(torch.load(os.path.join(train_config['task_name'], 79 | train_config['vae_autoencoder_ckpt_name']), 80 | map_location=device)) 81 | # Specify training parameters 82 | num_epochs = train_config['ldm_epochs'] 83 | optimizer = Adam(model.parameters(), lr=train_config['ldm_lr']) 84 | lr_scheduler = MultiStepLR(optimizer, milestones=train_config['ldm_lr_steps'], gamma=0.5) 85 | criterion = torch.nn.MSELoss() 86 | 87 | # Run training 88 | if not im_dataset.use_latents: 89 | for param in vae.parameters(): 90 | param.requires_grad = False 91 | 92 | step_count = 0 93 | for epoch_idx in range(num_epochs): 94 | losses = [] 95 | for im in tqdm(data_loader): 96 | optimizer.zero_grad() 97 | im = im.float().to(device) 98 | if im_dataset.use_latents: 99 | mean, logvar = torch.chunk(im, 2, dim=1) 100 | std = torch.exp(0.5 * logvar) 101 | im = mean + std * torch.randn(mean.shape).to(device=im.device) 102 | else: 103 | with torch.no_grad(): 104 | im, _ = vae.encode(im) 105 | 106 | # Sample random noise 107 | noise = torch.randn_like(im).to(device) 108 | 109 
| # Sample timestep 110 | t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device) 111 | 112 | # Add noise to images according to timestep 113 | noisy_im = scheduler.add_noise(im, noise, t) 114 | noise_pred = model(noisy_im, t) 115 | 116 | loss = criterion(noise_pred, noise) 117 | losses.append(loss.item()) 118 | loss.backward() 119 | optimizer.step() 120 | step_count += 1 121 | print('Finished epoch:{} | Loss : {:.4f}'.format( 122 | epoch_idx + 1, 123 | np.mean(losses))) 124 | lr_scheduler.step() 125 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 126 | train_config['ldm_ckpt_name'])) 127 | 128 | print('Done Training ...') 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser(description='Arguments for ldm training') 133 | parser.add_argument('--config', dest='config_path', 134 | default='config/celebhq.yaml', type=str) 135 | args = parser.parse_args() 136 | train(args) 137 | -------------------------------------------------------------------------------- /tools/train_vae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import random 5 | import torchvision 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from models.vae import VAE 10 | from models.lpips import LPIPS 11 | from models.discriminator import Discriminator 12 | from torch.utils.data.dataloader import DataLoader 13 | from dataset.celeb_dataset import CelebDataset 14 | from torch.optim import Adam 15 | from torchvision.utils import make_grid 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | if torch.backends.mps.is_available(): 19 | device = torch.device('mps') 20 | print('Using mps') 21 | 22 | 23 | def train(args): 24 | # Read the config file # 25 | with open(args.config_path, 'r') as file: 26 | try: 27 | config = yaml.safe_load(file) 28 | except yaml.YAMLError as exc: 29 | print(exc) 30 | print(config) 31 | 32 | dataset_config = config['dataset_params'] 33 | autoencoder_config = config['autoencoder_params'] 34 | train_config = config['train_params'] 35 | 36 | # Set the desired seed value # 37 | seed = train_config['seed'] 38 | torch.manual_seed(seed) 39 | np.random.seed(seed) 40 | random.seed(seed) 41 | if device == 'cuda': 42 | torch.cuda.manual_seed_all(seed) 43 | ############################# 44 | 45 | # Create the model and dataset # 46 | model = VAE(im_channels=dataset_config['im_channels'], 47 | model_config=autoencoder_config).to(device) 48 | # Create the dataset 49 | im_dataset = CelebDataset(split='train', 50 | im_path=dataset_config['im_path'], 51 | im_size=dataset_config['im_size'], 52 | im_channels=dataset_config['im_channels']) 53 | 54 | data_loader = DataLoader(im_dataset, 55 | batch_size=train_config['autoencoder_batch_size'], 56 | shuffle=True) 57 | 58 | # Create output directories 59 | if not os.path.exists(train_config['task_name']): 60 | os.mkdir(train_config['task_name']) 61 | 62 | num_epochs = train_config['autoencoder_epochs'] 63 | 64 | # L1/L2 loss for Reconstruction 65 | recon_criterion = torch.nn.MSELoss() 66 | # Disc Loss can even be BCEWithLogits 67 | disc_criterion = torch.nn.MSELoss() 68 | 69 | # No need to freeze lpips as lpips.py takes care of that 70 | lpips_model = LPIPS().eval().to(device) 71 | discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device) 72 | 73 | if os.path.exists(os.path.join(train_config['task_name'], 74 | 
train_config['vae_autoencoder_ckpt_name'])): 75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 76 | train_config['vae_autoencoder_ckpt_name']), 77 | map_location=device)) 78 | print('Loaded autoencoder from checkpoint') 79 | 80 | if os.path.exists(os.path.join(train_config['task_name'], 81 | train_config['vae_discriminator_ckpt_name'])): 82 | discriminator.load_state_dict(torch.load(os.path.join(train_config['task_name'], 83 | train_config['vae_discriminator_ckpt_name']), 84 | map_location=device)) 85 | print('Loaded discriminator from checkpoint') 86 | 87 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 88 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 89 | 90 | disc_step_start = train_config['disc_start'] 91 | step_count = 0 92 | 93 | # This is for accumulating gradients incase the images are huge 94 | # And one cant afford higher batch sizes 95 | acc_steps = train_config['autoencoder_acc_steps'] 96 | image_save_steps = train_config['autoencoder_img_save_steps'] 97 | img_save_count = 0 98 | 99 | for epoch_idx in range(num_epochs): 100 | recon_losses = [] 101 | perceptual_losses = [] 102 | disc_losses = [] 103 | gen_losses = [] 104 | losses = [] 105 | 106 | optimizer_g.zero_grad() 107 | optimizer_d.zero_grad() 108 | 109 | for im in tqdm(data_loader): 110 | step_count += 1 111 | im = im.float().to(device) 112 | 113 | # Fetch autoencoders output(reconstructions) 114 | model_output = model(im) 115 | output, encoder_output = model_output 116 | 117 | # Image Saving Logic 118 | if step_count % image_save_steps == 0 or step_count == 1: 119 | sample_size = min(8, im.shape[0]) 120 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu() 121 | save_output = ((save_output + 1) / 2) 122 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu() 123 | 124 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) 125 | img = torchvision.transforms.ToPILImage()(grid) 126 | if not os.path.exists(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')): 127 | os.mkdir(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')) 128 | img.save(os.path.join(train_config['task_name'], 'vae_autoencoder_samples', 129 | 'current_autoencoder_sample_{}.png'.format(img_save_count))) 130 | img_save_count += 1 131 | img.close() 132 | 133 | ######### Optimize Generator ########## 134 | # L2 Loss 135 | recon_loss = recon_criterion(output, im) 136 | recon_losses.append(recon_loss.item()) 137 | recon_loss = recon_loss / acc_steps 138 | 139 | mean, logvar = torch.chunk(encoder_output, 2, dim=1) 140 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3])) 141 | 142 | g_loss = recon_loss + (train_config['kl_weight'] * kl_loss / acc_steps) 143 | 144 | # Adversarial loss only if disc_step_start steps passed 145 | if step_count > disc_step_start: 146 | disc_fake_pred = discriminator(model_output[0]) 147 | disc_fake_loss = disc_criterion(disc_fake_pred, 148 | torch.ones(disc_fake_pred.shape, 149 | device=disc_fake_pred.device)) 150 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item()) 151 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps 152 | lpips_loss = torch.mean(lpips_model(output, im)) 153 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item()) 154 | g_loss += train_config['perceptual_weight'] * lpips_loss / acc_steps 155 | 
losses.append(g_loss.item()) 156 | g_loss.backward() 157 | ##################################### 158 | 159 | ######### Optimize Discriminator ####### 160 | if step_count > disc_step_start: 161 | fake = output 162 | disc_fake_pred = discriminator(fake.detach()) 163 | disc_real_pred = discriminator(im) 164 | disc_fake_loss = disc_criterion(disc_fake_pred, 165 | torch.zeros(disc_fake_pred.shape, 166 | device=disc_fake_pred.device)) 167 | disc_real_loss = disc_criterion(disc_real_pred, 168 | torch.ones(disc_real_pred.shape, 169 | device=disc_real_pred.device)) 170 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2 171 | disc_losses.append(disc_loss.item()) 172 | disc_loss = disc_loss / acc_steps 173 | disc_loss.backward() 174 | if step_count % acc_steps == 0: 175 | optimizer_d.step() 176 | optimizer_d.zero_grad() 177 | ##################################### 178 | 179 | if step_count % acc_steps == 0: 180 | optimizer_g.step() 181 | optimizer_g.zero_grad() 182 | optimizer_d.step() 183 | optimizer_d.zero_grad() 184 | optimizer_g.step() 185 | optimizer_g.zero_grad() 186 | if len(disc_losses) > 0: 187 | print( 188 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | ' 189 | 'G Loss : {:.4f} | D Loss {:.4f}'. 190 | format(epoch_idx + 1, 191 | np.mean(recon_losses), 192 | np.mean(perceptual_losses), 193 | np.mean(gen_losses), 194 | np.mean(disc_losses))) 195 | else: 196 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}'. 197 | format(epoch_idx + 1, 198 | np.mean(recon_losses), 199 | np.mean(perceptual_losses))) 200 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 201 | train_config['vae_autoencoder_ckpt_name'])) 202 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'], 203 | train_config['vae_discriminator_ckpt_name'])) 204 | print('Done Training...') 205 | 206 | 207 | if __name__ == '__main__': 208 | parser = argparse.ArgumentParser(description='Arguments for vae training') 209 | parser.add_argument('--config', dest='config_path', 210 | default='config/celebhq.yaml', type=str) 211 | args = parser.parse_args() 212 | train(args) 213 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/ControlNet-PyTorch/9d20a8aaadab942eb5504774e467006d5a22302a/utils/__init__.py -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | def validate_class_config(condition_config): 2 | assert 'class_condition_config' in condition_config, \ 3 | "Class conditioning desired but class condition config missing" 4 | assert 'num_classes' in condition_config['class_condition_config'], \ 5 | "num_class missing in class condition config" 6 | 7 | 8 | def validate_text_config(condition_config): 9 | assert 'text_condition_config' in condition_config, \ 10 | "Text conditioning desired but text condition config missing" 11 | assert 'text_embed_dim' in condition_config['text_condition_config'], \ 12 | "text_embed_dim missing in text condition config" 13 | 14 | 15 | def validate_image_config(condition_config): 16 | assert 'image_condition_config' in condition_config, \ 17 | "Image conditioning desired but image condition config missing" 18 | assert 'image_condition_input_channels' in 
condition_config['image_condition_config'], \ 19 | "image_condition_input_channels missing in image condition config" 20 | assert 'image_condition_output_channels' in condition_config['image_condition_config'], \ 21 | "image_condition_output_channels missing in image condition config" 22 | 23 | 24 | def validate_image_conditional_input(cond_input, x): 25 | assert 'image' in cond_input, \ 26 | "Model initialized with image conditioning but cond_input has no image information" 27 | assert cond_input['image'].shape[0] == x.shape[0], \ 28 | "Batch size mismatch of image condition and input" 29 | assert cond_input['image'].shape[2] % x.shape[2] == 0, \ 30 | "Height/Width of image condition must be divisible by latent input" 31 | 32 | 33 | def validate_class_conditional_input(cond_input, x, num_classes): 34 | assert 'class' in cond_input, \ 35 | "Model initialized with class conditioning but cond_input has no class information" 36 | assert cond_input['class'].shape == (x.shape[0], num_classes), \ 37 | "Shape of class condition input must match (Batch Size, num_classes)" 38 | 39 | 40 | def get_config_value(config, key, default_value): 41 | return config[key] if key in config else default_value -------------------------------------------------------------------------------- /utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import glob 3 | import os 4 | import torch 5 | 6 | 7 | def load_latents(latent_path): 8 | r""" 9 | Simple utility to load latents saved during VAE inference to speed up ldm training 10 | :param latent_path: directory containing the saved latent pickle files 11 | :return: dictionary mapping each image to its latent 12 | """ 13 | latent_maps = {} 14 | for fname in glob.glob(os.path.join(latent_path, '*.pkl')): 15 | s = pickle.load(open(fname, 'rb')) 16 | for k, v in s.items(): 17 | latent_maps[k] = v[0] 18 | return latent_maps 19 | 20 | 21 | def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob): 22 | if text_drop_prob > 0: 23 | text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0, 24 | 1) < text_drop_prob 25 | assert empty_text_embed is not None, ("Text Conditioning required as well as" 26 | " text dropping but empty text representation not created") 27 | text_embed[text_drop_mask, :, :] = empty_text_embed[0] 28 | return text_embed 29 | 30 | 31 | def drop_image_condition(image_condition, im, im_drop_prob): 32 | if im_drop_prob > 0: 33 | im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0, 34 | 1) > im_drop_prob 35 | return image_condition * im_drop_mask 36 | else: 37 | return image_condition 38 | 39 | 40 | def drop_class_condition(class_condition, class_drop_prob, im): 41 | if class_drop_prob > 0: 42 | class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0, 43 | 1) > class_drop_prob 44 | return class_condition * class_drop_mask 45 | else: 46 | return class_condition 47 | --------------------------------------------------------------------------------
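
Note: the `drop_text_condition` / `drop_image_condition` / `drop_class_condition` helpers in `utils/diffusion_utils.py` implement conditioning dropout in the classifier-free-guidance style: with probability `*_drop_prob` a sample's conditioning signal is replaced by an empty or zero value, so the network also learns an unconditional prediction. The sketch below shows one way such a helper could be combined with the noise-prediction training pattern used in the training scripts above; the `training_step` function, the `model` with a ControlNet-style `(x, t, hint)` signature, and the `im_drop_prob=0.1` value are illustrative assumptions, not code from this repository.

```python
import torch

from utils.diffusion_utils import drop_image_condition


def training_step(model, scheduler, im, hint, num_timesteps, im_drop_prob=0.1):
    # Standard DDPM noise-prediction objective, as in the training scripts above
    noise = torch.randn_like(im)
    t = torch.randint(0, num_timesteps, (im.shape[0],), device=im.device)
    noisy_im = scheduler.add_noise(im, noise, t)

    # Zero out the hint for roughly im_drop_prob of the samples so the model
    # also learns an unconditional prediction (enables classifier-free guidance)
    hint = drop_image_condition(hint, im, im_drop_prob)

    noise_pred = model(noisy_im, t, hint)
    return torch.nn.functional.mse_loss(noise_pred, noise)
```

Zeroing the condition (rather than removing it) keeps the model's call signature unchanged, so a single network can serve both the conditional and the unconditional branch when guidance is applied at sampling time.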