├── LICENSE ├── README.md ├── assets ├── classifier-free-cifar10.png ├── classifier-free-imagenet.png ├── clip-guidance-celebahq.png ├── ddib-imagenet.png ├── ddim-celebahq-interpolate.png ├── ddim-celebahq-reconstruction.png ├── ddim-cifar10-interpolate.png ├── ddim-cifar10-reconstruction.png ├── ddim-cifar10.png ├── ddpm-celebahq-denoise.png ├── ddpm-celebahq-progressive.png ├── ddpm-celebahq-random.png ├── ddpm-cifar10-denoise.png ├── ddpm-cifar10-progressive.png ├── ddpm-cifar10-random.png ├── ddpm-mnist-random.png ├── fidelity-speed-visualization.png ├── ilvr-celebahq.png ├── mask-guidance-imagenet.png ├── sdedit.png └── streamlit.png ├── configs ├── ddpm_celebahq.yaml ├── ddpm_cfg_cifar10.yaml ├── ddpm_cifar10.yaml └── ddpm_mnist.yaml ├── datasets ├── ImageDir.py ├── __init__.py ├── celebahq.py ├── cifar10.py ├── imagenet.py └── mnist.py ├── diffusions ├── __init__.py ├── ddim.py ├── ddpm.py ├── ddpm_ip.py ├── euler.py ├── guidance │ ├── __init__.py │ ├── base.py │ ├── clip_guidance.py │ ├── ilvr.py │ └── mask_guidance.py ├── heun.py └── schedule.py ├── docs ├── CLIP Guidance.md ├── Classifier-Free Guidance.md ├── DDIB.md ├── DDIM.md ├── DDPM-IP.md ├── DDPM.md ├── ILVR.md ├── Mask Guidance.md ├── SDEdit.md └── Samplers.md ├── models ├── __init__.py ├── adm │ ├── __init__.py │ ├── nn.py │ ├── readme.md │ ├── unet.py │ └── unet_combined.py ├── base_latent.py ├── dit │ ├── __init__.py │ ├── autoencoder.py │ ├── dit.py │ ├── model.py │ └── readme.md ├── ema.py ├── mdt │ ├── __init__.py │ ├── autoencoder.py │ ├── mdt.py │ ├── model.py │ └── readme.md ├── modules.py ├── pesser │ ├── __init__.py │ ├── model.py │ └── readme.md ├── sdxl │ ├── __init__.py │ ├── attention.py │ ├── autoencoder.py │ ├── conditioner.py │ ├── distributions.py │ ├── modules.py │ ├── readme.md │ ├── regularizers.py │ ├── stablediffusion.py │ ├── unet.py │ └── util.py ├── stablediffusion │ ├── __init__.py │ ├── attention.py │ ├── autoencoder.py │ ├── distributions.py │ ├── modules.py │ ├── readme.md │ ├── stablediffusion.py │ ├── text_encoders.py │ ├── unet.py │ └── util.py ├── unet.py └── unet_categorial_adagn.py ├── requirements.txt ├── scripts ├── sample_cfg.py ├── sample_clip_guidance.py ├── sample_ddib.py ├── sample_ilvr.py ├── sample_mask_guidance.py ├── sample_sdedit.py ├── sample_uncond.py ├── train_ddpm.py └── train_ddpm_cfg.py ├── streamlit ├── Hello.py └── pages │ ├── 1_Unconditional_Image_Generation.py │ ├── 2_Class_conditional_Image_Generation.py │ ├── 3_Stable_Diffusion_v1.5.py │ └── 4_Stable_Diffusion_XL.py ├── test_images ├── celebahq │ ├── 182660.jpg │ ├── 182664.jpg │ ├── 182690.jpg │ ├── 182699.jpg │ ├── 182704.jpg │ ├── 182716.jpg │ ├── 182722.jpg │ ├── 182727.jpg │ ├── 182734.jpg │ └── 182743.jpg ├── cifar10 │ ├── val_00.png │ ├── val_01.png │ ├── val_02.png │ ├── val_03.png │ ├── val_04.png │ ├── val_05.png │ ├── val_06.png │ ├── val_07.png │ ├── val_08.png │ ├── val_09.png │ ├── val_10.png │ ├── val_11.png │ ├── val_12.png │ ├── val_13.png │ ├── val_14.png │ ├── val_15.png │ ├── val_16.png │ ├── val_17.png │ ├── val_18.png │ ├── val_19.png │ ├── val_20.png │ ├── val_21.png │ ├── val_22.png │ ├── val_23.png │ ├── val_24.png │ ├── val_25.png │ ├── val_26.png │ ├── val_27.png │ ├── val_28.png │ ├── val_29.png │ ├── val_30.png │ ├── val_31.png │ ├── val_32.png │ ├── val_33.png │ ├── val_34.png │ ├── val_35.png │ ├── val_36.png │ ├── val_37.png │ ├── val_38.png │ ├── val_39.png │ ├── val_40.png │ ├── val_41.png │ ├── val_42.png │ ├── val_43.png │ ├── val_44.png │ ├── val_45.png │ ├── val_46.png │ ├── 
val_47.png │ ├── val_48.png │ ├── val_49.png │ ├── val_50.png │ ├── val_51.png │ ├── val_52.png │ ├── val_53.png │ ├── val_54.png │ ├── val_55.png │ ├── val_56.png │ ├── val_57.png │ ├── val_58.png │ ├── val_59.png │ ├── val_60.png │ ├── val_61.png │ ├── val_62.png │ └── val_63.png ├── imagenet │ ├── 220 │ │ ├── ILSVRC2012_val_00000539.JPEG │ │ ├── ILSVRC2012_val_00005027.JPEG │ │ ├── ILSVRC2012_val_00008148.JPEG │ │ ├── ILSVRC2012_val_00026146.JPEG │ │ └── ILSVRC2012_val_00048337.JPEG │ ├── 248 │ │ ├── ILSVRC2012_val_00000178.JPEG │ │ ├── ILSVRC2012_val_00010011.JPEG │ │ ├── ILSVRC2012_val_00019757.JPEG │ │ ├── ILSVRC2012_val_00032646.JPEG │ │ └── ILSVRC2012_val_00049841.JPEG │ ├── 290 │ │ ├── ILSVRC2012_val_00002931.JPEG │ │ ├── ILSVRC2012_val_00009297.JPEG │ │ ├── ILSVRC2012_val_00018843.JPEG │ │ ├── ILSVRC2012_val_00040784.JPEG │ │ └── ILSVRC2012_val_00045488.JPEG │ ├── 291 │ │ ├── ILSVRC2012_val_00003788.JPEG │ │ ├── ILSVRC2012_val_00011645.JPEG │ │ ├── ILSVRC2012_val_00012423.JPEG │ │ ├── ILSVRC2012_val_00027413.JPEG │ │ └── ILSVRC2012_val_00034811.JPEG │ ├── 292 │ │ ├── ILSVRC2012_val_00029267.JPEG │ │ ├── ILSVRC2012_val_00030651.JPEG │ │ ├── ILSVRC2012_val_00047614.JPEG │ │ ├── ILSVRC2012_val_00048380.JPEG │ │ └── ILSVRC2012_val_00049631.JPEG │ ├── 294 │ │ ├── ILSVRC2012_val_00010281.JPEG │ │ ├── ILSVRC2012_val_00020362.JPEG │ │ ├── ILSVRC2012_val_00031645.JPEG │ │ ├── ILSVRC2012_val_00032453.JPEG │ │ └── ILSVRC2012_val_00039227.JPEG │ └── test │ │ ├── ILSVRC2012_test_00000001.JPEG │ │ ├── ILSVRC2012_test_00000002.JPEG │ │ ├── ILSVRC2012_test_00000003.JPEG │ │ ├── ILSVRC2012_test_00000004.JPEG │ │ ├── ILSVRC2012_test_00000005.JPEG │ │ ├── ILSVRC2012_test_00000006.JPEG │ │ ├── ILSVRC2012_test_00000007.JPEG │ │ ├── ILSVRC2012_test_00000008.JPEG │ │ ├── ILSVRC2012_test_00000009.JPEG │ │ └── ILSVRC2012_test_00000010.JPEG └── strokes │ ├── 0.png │ ├── 1.png │ └── 2.png ├── utils ├── __init__.py ├── load.py ├── logger.py ├── mask.py ├── misc.py └── resize_right │ ├── __init__.py │ ├── interp_methods.py │ └── resize_right.py └── weights ├── ChenWu98 └── cycle-diffusion │ ├── cat_ema_0.9999_050000.yaml │ └── wild_ema_0.9999_050000.yaml ├── andreas128 └── RePaint │ └── celeba256_250000.yaml ├── facebookresearch └── DiT │ ├── DiT-XL-2-256x256.yaml │ └── DiT-XL-2-512x512.yaml ├── jychoi118 └── ilvr_adm │ └── afhqdog_p2.yaml ├── openai └── guided-diffusion │ ├── 256x256_diffusion.yaml │ ├── 256x256_diffusion_combined.yaml │ └── 256x256_diffusion_uncond.yaml ├── pesser └── pytorch_diffusion │ ├── ema_diffusion_celebahq_model-560000.yaml │ └── ema_diffusion_lsun_church_model-4432000.yaml ├── sail-sg └── MDT │ └── mdt_xl2_v2_ckpt.yaml ├── sdxl └── sd_xl_base.yaml └── stablediffusion ├── v1-inference.yaml └── v2-inference-v.yaml /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yifeng Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/classifier-free-cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/classifier-free-cifar10.png -------------------------------------------------------------------------------- /assets/classifier-free-imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/classifier-free-imagenet.png -------------------------------------------------------------------------------- /assets/clip-guidance-celebahq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/clip-guidance-celebahq.png -------------------------------------------------------------------------------- /assets/ddib-imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddib-imagenet.png -------------------------------------------------------------------------------- /assets/ddim-celebahq-interpolate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-celebahq-interpolate.png -------------------------------------------------------------------------------- /assets/ddim-celebahq-reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-celebahq-reconstruction.png -------------------------------------------------------------------------------- /assets/ddim-cifar10-interpolate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-cifar10-interpolate.png -------------------------------------------------------------------------------- /assets/ddim-cifar10-reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-cifar10-reconstruction.png -------------------------------------------------------------------------------- /assets/ddim-cifar10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-cifar10.png -------------------------------------------------------------------------------- /assets/ddpm-celebahq-denoise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-celebahq-denoise.png -------------------------------------------------------------------------------- /assets/ddpm-celebahq-progressive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-celebahq-progressive.png -------------------------------------------------------------------------------- /assets/ddpm-celebahq-random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-celebahq-random.png -------------------------------------------------------------------------------- /assets/ddpm-cifar10-denoise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-cifar10-denoise.png -------------------------------------------------------------------------------- /assets/ddpm-cifar10-progressive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-cifar10-progressive.png -------------------------------------------------------------------------------- /assets/ddpm-cifar10-random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-cifar10-random.png -------------------------------------------------------------------------------- /assets/ddpm-mnist-random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-mnist-random.png -------------------------------------------------------------------------------- /assets/fidelity-speed-visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/fidelity-speed-visualization.png -------------------------------------------------------------------------------- /assets/ilvr-celebahq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ilvr-celebahq.png -------------------------------------------------------------------------------- /assets/mask-guidance-imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/mask-guidance-imagenet.png 
-------------------------------------------------------------------------------- /assets/sdedit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/sdedit.png -------------------------------------------------------------------------------- /assets/streamlit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/streamlit.png -------------------------------------------------------------------------------- /configs/ddpm_celebahq.yaml: -------------------------------------------------------------------------------- 1 | seed: 2022 2 | 3 | data: 4 | target: datasets.celebahq.CelebAHQ 5 | params: 6 | root: ~/data/CelebA-HQ/ 7 | img_size: 256 8 | img_channels: 3 9 | 10 | dataloader: 11 | num_workers: 4 12 | pin_memory: true 13 | prefetch_factor: 2 14 | 15 | model: 16 | target: models.unet.UNet 17 | params: 18 | in_channels: 3 19 | out_channels: 3 20 | dim: 128 21 | dim_mults: [1, 1, 2, 2, 4, 4] 22 | use_attn: [false, false, false, false, true, false] 23 | num_res_blocks: 2 24 | n_heads: 1 25 | dropout: 0.0 26 | 27 | diffusion: 28 | target: diffusions.ddpm.DDPM 29 | params: 30 | total_steps: 1000 31 | beta_schedule: linear 32 | beta_start: 0.0001 33 | beta_end: 0.02 34 | objective: pred_eps 35 | var_type: fixed_small 36 | 37 | train: 38 | n_steps: 500000 39 | batch_size: 64 40 | micro_batch: 0 41 | 42 | clip_grad_norm: 1.0 43 | ema_decay: 0.9999 44 | ema_gradual: true 45 | 46 | print_freq: 400 47 | save_freq: 10000 48 | sample_freq: 5000 49 | n_samples: 36 50 | 51 | optim: 52 | target: torch.optim.Adam 53 | params: 54 | lr: 0.00002 55 | -------------------------------------------------------------------------------- /configs/ddpm_cfg_cifar10.yaml: -------------------------------------------------------------------------------- 1 | seed: 2022 2 | 3 | data: 4 | target: datasets.cifar10.CIFAR10 5 | params: 6 | root: ~/data/CIFAR-10/ 7 | img_size: 32 8 | img_channels: 3 9 | num_classes: 10 10 | 11 | dataloader: 12 | num_workers: 4 13 | pin_memory: true 14 | prefetch_factor: 2 15 | 16 | model: 17 | target: models.unet_categorial_adagn.UNetCategorialAdaGN 18 | params: 19 | in_channels: 3 20 | out_channels: 3 21 | dim: 128 22 | dim_mults: [1, 2, 2, 2] 23 | use_attn: [false, true, true, false] 24 | num_res_blocks: 2 25 | num_classes: 10 26 | attn_head_dims: 64 27 | resblock_updown: true 28 | dropout: 0.1 29 | 30 | diffusion: 31 | target: diffusions.cfg.ddpm_cfg.DDPMCFG 32 | params: 33 | total_steps: 1000 34 | beta_schedule: cosine 35 | beta_start: 0.0001 36 | beta_end: 0.02 37 | objective: pred_eps 38 | var_type: fixed_large 39 | 40 | train: 41 | n_steps: 800000 42 | batch_size: 128 43 | micro_batch: 0 44 | 45 | clip_grad_norm: 1.0 46 | ema_decay: 0.9999 47 | ema_gradual: true 48 | 49 | print_freq: 400 50 | save_freq: 10000 51 | sample_freq: 5000 52 | n_samples_each_class: 10 53 | 54 | p_uncond: 0.2 55 | 56 | optim: 57 | target: torch.optim.AdamW 58 | params: 59 | lr: 0.0002 60 | -------------------------------------------------------------------------------- /configs/ddpm_cifar10.yaml: -------------------------------------------------------------------------------- 1 | seed: 2022 2 | 3 | data: 4 | target: datasets.cifar10.CIFAR10 5 | params: 6 | root: ~/data/CIFAR-10/ 7 | img_size: 32 8 | img_channels: 3 9 | 
num_classes: 10 10 | 11 | dataloader: 12 | num_workers: 4 13 | pin_memory: true 14 | prefetch_factor: 2 15 | 16 | model: 17 | target: models.unet.UNet 18 | params: 19 | in_channels: 3 20 | out_channels: 3 21 | dim: 128 22 | dim_mults: [1, 2, 2, 2] 23 | use_attn: [false, true, false, false] 24 | num_res_blocks: 2 25 | n_heads: 1 26 | dropout: 0.1 27 | 28 | diffusion: 29 | target: diffusions.ddpm.DDPM 30 | params: 31 | total_steps: 1000 32 | beta_schedule: linear 33 | beta_start: 0.0001 34 | beta_end: 0.02 35 | objective: pred_eps 36 | var_type: fixed_large 37 | 38 | train: 39 | n_steps: 800000 40 | batch_size: 128 41 | micro_batch: 0 42 | 43 | clip_grad_norm: 1.0 44 | ema_decay: 0.9999 45 | ema_gradual: true 46 | 47 | print_freq: 400 48 | save_freq: 10000 49 | sample_freq: 5000 50 | n_samples: 64 51 | 52 | optim: 53 | target: torch.optim.Adam 54 | params: 55 | lr: 0.0002 56 | -------------------------------------------------------------------------------- /configs/ddpm_mnist.yaml: -------------------------------------------------------------------------------- 1 | seed: 2022 2 | 3 | data: 4 | target: datasets.mnist.MNIST 5 | params: 6 | root: ~/data/MNIST/ 7 | img_size: 32 8 | img_channels: 1 9 | num_classes: 10 10 | 11 | dataloader: 12 | num_workers: 4 13 | pin_memory: true 14 | prefetch_factor: 2 15 | 16 | model: 17 | target: models.unet.UNet 18 | params: 19 | in_channels: 1 20 | out_channels: 1 21 | dim: 64 22 | dim_mults: [1, 2, 2, 2] 23 | use_attn: [false, true, false, false] 24 | num_res_blocks: 2 25 | n_heads: 1 26 | dropout: 0.1 27 | 28 | diffusion: 29 | target: diffusions.ddpm.DDPM 30 | params: 31 | total_steps: 200 32 | beta_schedule: linear 33 | beta_start: 0.0001 34 | beta_end: 0.02 35 | objective: pred_eps 36 | var_type: fixed_small 37 | 38 | train: 39 | n_steps: 50000 40 | batch_size: 128 41 | micro_batch: 0 42 | 43 | clip_grad_norm: 1.0 44 | ema_decay: 0.9999 45 | ema_gradual: true 46 | 47 | print_freq: 400 48 | save_freq: 10000 49 | sample_freq: 5000 50 | n_samples: 64 51 | 52 | optim: 53 | target: torch.optim.Adam 54 | params: 55 | lr: 0.0002 56 | -------------------------------------------------------------------------------- /datasets/ImageDir.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | 5 | 6 | def extract_images(root): 7 | """ Extract all images under root """ 8 | img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff'] 9 | root = os.path.expanduser(root) 10 | img_paths = [] 11 | for curdir, subdirs, files in os.walk(root): 12 | for file in files: 13 | if os.path.splitext(file)[1].lower() in img_ext: 14 | img_paths.append(os.path.join(curdir, file)) 15 | img_paths = sorted(img_paths) 16 | return img_paths 17 | 18 | 19 | class ImageDir(Dataset): 20 | def __init__(self, root, transform=None): 21 | root = os.path.expanduser(root) 22 | if not os.path.isdir(root): 23 | raise ValueError(f'{root} is not a valid directory') 24 | 25 | self.transform = transform 26 | self.img_paths = extract_images(root) 27 | 28 | def __len__(self): 29 | return len(self.img_paths) 30 | 31 | def __getitem__(self, item): 32 | X = Image.open(self.img_paths[item]).convert('RGB') 33 | if self.transform is not None: 34 | X = self.transform(X) 35 | return X 36 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ImageDir import ImageDir 2 | 
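Note that the dataset classes in this package are referenced from the YAML configs above through a `target` import path plus `params` keyword arguments (the same convention is used for the model, diffusion, and optimizer entries). Below is a minimal sketch of how such an entry can be instantiated; this is an illustrative helper only, and the repo's actual loading logic (see `utils/load.py`) may be named and structured differently.

```python
import importlib


def instantiate_from_config(cfg: dict):
    # Hypothetical helper for illustration only.
    # Split e.g. 'models.unet.UNet' into module path and class name,
    # import the module, and construct the object with the given params.
    module_name, cls_name = cfg['target'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**cfg.get('params', {}))


# Example: build the model entry of configs/ddpm_mnist.yaml
# (assumes the repository root is on PYTHONPATH).
model = instantiate_from_config({
    'target': 'models.unet.UNet',
    'params': {
        'in_channels': 1, 'out_channels': 1, 'dim': 64,
        'dim_mults': [1, 2, 2, 2], 'use_attn': [False, True, False, False],
        'num_res_blocks': 2, 'n_heads': 1, 'dropout': 0.1,
    },
})
```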
-------------------------------------------------------------------------------- /datasets/celebahq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from typing import Optional, Callable 4 | 5 | import torchvision.transforms as T 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def extract_images(root): 10 | """ Extract all images under root """ 11 | img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff'] 12 | root = os.path.expanduser(root) 13 | img_paths = [] 14 | for curdir, subdirs, files in os.walk(root): 15 | for file in files: 16 | if os.path.splitext(file)[1].lower() in img_ext: 17 | img_paths.append(os.path.join(curdir, file)) 18 | img_paths = sorted(img_paths) 19 | return img_paths 20 | 21 | 22 | class CelebAHQ(Dataset): 23 | """The CelebA-HQ Dataset. 24 | 25 | The CelebA-HQ dataset is a high-quality version of CelebA that consists of 30,000 images at 1024×1024 resolution. 26 | (Copied from PaperWithCode) 27 | 28 | The official way to prepare the dataset is to download img_celeba.7z from the original CelebA dataset and the delta 29 | files from the official GitHub repository. Then use dataset_tool.py to generate the high-quality images. 30 | 31 | However, I personally recommend downloading the CelebAMask-HQ dataset, which contains processed CelebA-HQ images. 32 | Nevertheless, the filenames in CelebAMask-HQ are sorted from 0 to 29999, which is inconsistent with the original 33 | CelebA filenames. I provide a python script (scripts/celebahq_map_filenames.py) to help convert the filenames. 34 | 35 | To load data with this class, the dataset should be organized in the following structure: 36 | 37 | root 38 | ├── CelebA-HQ-img 39 | │ ├── 000004.jpg 40 | │ ├── ... 41 | │ └── 202591.jpg 42 | └── CelebA-HQ-to-CelebA-mapping.txt 43 | 44 | The train/valid/test sets are split according to the original CelebA dataset, 45 | resulting in 24,183 training images, 2,993 validation images, and 2,824 test images. 
46 | 47 | This class has one pre-defined transform: 48 | - 'resize' (default): Resize the image directly to the target size 49 | 50 | References: 51 | - https://github.com/tkarras/progressive_growing_of_gans 52 | - https://paperswithcode.com/dataset/celeba-hq 53 | - https://github.com/switchablenorms/CelebAMask-HQ 54 | 55 | """ 56 | def __init__( 57 | self, 58 | root: str, 59 | img_size: int, 60 | split: str = 'train', 61 | transform_type: str = 'default', 62 | transform: Optional[Callable] = None, 63 | ): 64 | if split not in ['train', 'valid', 'test', 'all']: 65 | raise ValueError(f'Invalid split: {split}') 66 | root = os.path.expanduser(root) 67 | image_root = os.path.join(root, 'CelebA-HQ-img') 68 | if not os.path.isdir(image_root): 69 | raise ValueError(f'{image_root} is not an existing directory') 70 | 71 | self.root = root 72 | self.img_size = img_size 73 | self.split = split 74 | self.transform_type = transform_type 75 | self.transform = transform 76 | if transform is None: 77 | self.transform = self.get_transform() 78 | 79 | def filter_func(p): 80 | if split == 'all': 81 | return True 82 | celeba_splits = [1, 162771, 182638, 202600] 83 | k = 0 if split == 'train' else (1 if split == 'valid' else 2) 84 | return celeba_splits[k] <= int(os.path.splitext(os.path.basename(p))[0]) < celeba_splits[k+1] 85 | 86 | self.img_paths = extract_images(image_root) 87 | self.img_paths = list(filter(filter_func, self.img_paths)) 88 | 89 | def __len__(self): 90 | return len(self.img_paths) 91 | 92 | def __getitem__(self, item): 93 | X = Image.open(self.img_paths[item]) 94 | if self.transform is not None: 95 | X = self.transform(X) 96 | return X 97 | 98 | def get_transform(self): 99 | flip_p = 0.5 if self.split in ['train', 'all'] else 0.0 100 | if self.transform_type in ['default', 'resize']: 101 | transform = T.Compose([ 102 | T.Resize((self.img_size, self.img_size)), 103 | T.RandomHorizontalFlip(flip_p), 104 | T.ToTensor(), 105 | T.Normalize([0.5] * 3, [0.5] * 3), 106 | ]) 107 | elif self.transform_type == 'none': 108 | transform = None 109 | else: 110 | raise ValueError(f'Invalid transform_type: {self.transform_type}') 111 | return transform 112 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | import torchvision.datasets 4 | import torchvision.transforms as T 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class CIFAR10(Dataset): 9 | """Extend torchvision.datasets.CIFAR10 with one pre-defined transform. 10 | 11 | The pre-defined transform is: 12 | - 'resize' (default): Resize the image directly to the target size, followed by random horizontal flipping. 
13 | 14 | """ 15 | 16 | def __init__( 17 | self, 18 | root: str, 19 | img_size: int, 20 | split: str = 'train', 21 | transform_type: str = 'default', 22 | transform: Optional[Callable] = None, 23 | target_transform: Optional[Callable] = None, 24 | download: bool = False, 25 | ): 26 | if split not in ['train', 'test']: 27 | raise ValueError(f'Invalid split: {split}') 28 | 29 | self.img_size = img_size 30 | self.split = split 31 | self.transform_type = transform_type 32 | if transform is None: 33 | transform = self.get_transform() 34 | 35 | self.cifar10 = torchvision.datasets.CIFAR10( 36 | root=root, 37 | train=(split == 'train'), 38 | transform=transform, 39 | target_transform=target_transform, 40 | download=download, 41 | ) 42 | 43 | def __len__(self): 44 | return len(self.cifar10) 45 | 46 | def __getitem__(self, item): 47 | X, y = self.cifar10[item] 48 | return X, y 49 | 50 | def get_transform(self): 51 | flip_p = 0.5 if self.split == 'train' else 0.0 52 | if self.transform_type in ['default', 'resize']: 53 | transform = T.Compose([ 54 | T.Resize((self.img_size, self.img_size), antialias=True), 55 | T.RandomHorizontalFlip(flip_p), 56 | T.ToTensor(), 57 | T.Normalize([0.5] * 3, [0.5] * 3), 58 | ]) 59 | elif self.transform_type == 'none': 60 | transform = None 61 | else: 62 | raise ValueError(f'Invalid transform_type: {self.transform_type}') 63 | return transform 64 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from typing import Optional, Callable 4 | 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms as T 7 | 8 | 9 | def extract_images(root): 10 | """ Extract all images under root """ 11 | img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff'] 12 | root = os.path.expanduser(root) 13 | img_paths = [] 14 | for curdir, subdirs, files in os.walk(root): 15 | for file in files: 16 | if os.path.splitext(file)[1].lower() in img_ext: 17 | img_paths.append(os.path.join(curdir, file)) 18 | img_paths = sorted(img_paths) 19 | return img_paths 20 | 21 | 22 | class ImageNet(Dataset): 23 | """Extend torchvision.datasets.ImageNet with two pre-defined transforms and support test set. 24 | 25 | This class has two pre-defined transforms: 26 | - 'resize-crop' (default): Resize the image so that the short side match the target size, then crop a square patch 27 | - 'resize': Resize the image directly to the target size 28 | All of the above transforms will be followed by random horizontal flipping. 29 | 30 | To load data with this class, the dataset should be organized in the following structure: 31 | 32 | root 33 | ├── train 34 | │ ├── n01440764 35 | │ ├── ... 36 | │ └── n15075141 37 | ├── valid (or val) 38 | │ ├── n01440764 (or directly put all validation images here) 39 | │ ├── ... 40 | │ └── n15075141 41 | └── test 42 | ├── ILSVRC2012_test_00000001.JPEG 43 | ├── ... 
44 | └── ILSVRC2012_test_00100000.JPEG 45 | 46 | References: 47 | - https://image-net.org/challenges/LSVRC/2012/2012-downloads.php 48 | 49 | """ 50 | 51 | def __init__( 52 | self, 53 | root: str, 54 | img_size: int, 55 | split: str = 'train', 56 | transform_type: str = 'default', 57 | transform: Optional[Callable] = None, 58 | ): 59 | if split not in ['train', 'valid', 'test']: 60 | raise ValueError(f'Invalid split: {split}') 61 | root = os.path.expanduser(root) 62 | image_root = os.path.join(root, split) 63 | if split == 'valid' and not os.path.isdir(image_root): 64 | image_root = os.path.join(root, 'val') 65 | if not os.path.isdir(image_root): 66 | raise ValueError(f'{image_root} is not an existing directory') 67 | 68 | self.img_size = img_size 69 | self.split = split 70 | self.transform_type = transform_type 71 | self.transform = transform 72 | 73 | self.img_paths = extract_images(image_root) 74 | 75 | def __len__(self): 76 | return len(self.img_paths) 77 | 78 | def __getitem__(self, item): 79 | X = Image.open(self.img_paths[item]).convert('RGB') 80 | if self.transform is not None: 81 | X = self.transform(X) 82 | return X 83 | 84 | def get_transform(self): 85 | crop = T.RandomCrop if self.split == 'train' else T.CenterCrop 86 | flip_p = 0.5 if self.split == 'train' else 0.0 87 | if self.transform_type in ['default', 'resize-crop']: 88 | transform = T.Compose([ 89 | T.Resize(self.img_size), 90 | crop((self.img_size, self.img_size)), 91 | T.RandomHorizontalFlip(flip_p), 92 | T.ToTensor(), 93 | T.Normalize([0.5] * 3, [0.5] * 3), 94 | ]) 95 | elif self.transform_type == 'resize': 96 | transform = T.Compose([ 97 | T.Resize((self.img_size, self.img_size)), 98 | T.RandomHorizontalFlip(flip_p), 99 | T.ToTensor(), 100 | T.Normalize([0.5] * 3, [0.5] * 3), 101 | ]) 102 | elif self.transform_type == 'none': 103 | transform = None 104 | else: 105 | raise ValueError(f'Invalid transform_type: {self.transform_type}') 106 | return transform 107 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | import torchvision.datasets 4 | import torchvision.transforms as T 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MNIST(Dataset): 9 | """Extend torchvision.datasets.MNIST with one pre-defined transform. 
10 | 11 | The pre-defined transform is: 12 | - 'resize' (default): Resize the image directly to the target size 13 | 14 | """ 15 | 16 | def __init__( 17 | self, 18 | root: str, 19 | img_size: int, 20 | split: str = 'train', 21 | transform_type: str = 'default', 22 | transform: Optional[Callable] = None, 23 | target_transform: Optional[Callable] = None, 24 | download: bool = False, 25 | ): 26 | if split not in ['train', 'test']: 27 | raise ValueError(f'Invalid split: {split}') 28 | 29 | self.img_size = img_size 30 | self.transform_type = transform_type 31 | if transform is None: 32 | transform = self.get_transform() 33 | 34 | self.mnist = torchvision.datasets.MNIST( 35 | root=root, 36 | train=(split == 'train'), 37 | transform=transform, 38 | target_transform=target_transform, 39 | download=download, 40 | ) 41 | 42 | def __len__(self): 43 | return len(self.mnist) 44 | 45 | def __getitem__(self, item): 46 | X, y = self.mnist[item] 47 | return X, y 48 | 49 | def get_transform(self): 50 | if self.transform_type in ['default', 'resize']: 51 | transform = T.Compose([ 52 | T.Resize((self.img_size, self.img_size), antialias=True), 53 | T.ToTensor(), 54 | T.Normalize([0.5], [0.5]), 55 | ]) 56 | elif self.transform_type == 'none': 57 | transform = None 58 | else: 59 | raise ValueError(f'Invalid transform_type: {self.transform_type}') 60 | return transform 61 | -------------------------------------------------------------------------------- /diffusions/__init__.py: -------------------------------------------------------------------------------- 1 | from .schedule import get_beta_schedule, get_respaced_seq 2 | 3 | from .ddpm import DDPM, DDPMCFG 4 | from .ddim import DDIM, DDIMCFG 5 | from .euler import EulerSampler 6 | from .heun import HeunSampler 7 | 8 | from .guidance.ilvr import ILVR 9 | from .guidance.mask_guidance import MaskGuidance 10 | from .guidance.clip_guidance import CLIPGuidance 11 | -------------------------------------------------------------------------------- /diffusions/ddpm_ip.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from diffusions import DDPM 9 | 10 | 11 | class DDPM_IP(DDPM): 12 | def __init__(self, gamma: float = 0.1, *args, **kwargs): 13 | """Denoising Diffusion Probabilistic Models with Input Perturbation. 14 | 15 | Perturb the input (xt) during training to simulate the gap between training and testing. 16 | Surprisingly simple but effective. 17 | 18 | Args: 19 | gamma: Perturbation strength. 20 | 21 | References: 22 | [1] Ning, Mang, Enver Sangineto, Angelo Porrello, Simone Calderara, and Rita Cucchiara. "Input 23 | Perturbation Reduces Exposure Bias in Diffusion Models." arXiv preprint arXiv:2301.11706 (2023). 
24 | 25 | """ 26 | super().__init__(*args, **kwargs) 27 | self.gamma = gamma 28 | 29 | def loss_func(self, model: nn.Module, x0: Tensor, t: Tensor, eps: Tensor = None, model_kwargs: Dict = None): 30 | if model_kwargs is None: 31 | model_kwargs = dict() 32 | if eps is None: 33 | eps = torch.randn_like(x0) 34 | # input perturbation 35 | perturbed_eps = eps + self.gamma * torch.randn_like(eps) 36 | xt = self.diffuse(x0, t, perturbed_eps) 37 | if self.objective == 'pred_eps': 38 | pred_eps = model(xt, t, **model_kwargs) 39 | return F.mse_loss(pred_eps, eps) 40 | elif self.objective == 'pred_x0': 41 | pred_x0 = model(xt, t, **model_kwargs) 42 | return F.mse_loss(pred_x0, x0) 43 | elif self.objective == 'pred_v': 44 | v = self.get_v(x0, eps, t) 45 | pred_v = model(xt, t, **model_kwargs) 46 | return F.mse_loss(pred_v, v) 47 | else: 48 | raise ValueError(f'Objective {self.objective} is not supported.') 49 | -------------------------------------------------------------------------------- /diffusions/euler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from diffusions.ddpm import DDPM 5 | 6 | 7 | class EulerSampler(DDPM): 8 | def __init__( 9 | self, 10 | total_steps: int = 1000, 11 | beta_schedule: str = 'linear', 12 | beta_start: float = 0.0001, 13 | beta_end: float = 0.02, 14 | betas: Tensor = None, 15 | objective: str = 'pred_eps', 16 | 17 | clip_denoised: bool = True, 18 | respace_type: str = None, 19 | respace_steps: int = 100, 20 | respaced_seq: Tensor = None, 21 | 22 | device: torch.device = 'cpu', 23 | **kwargs, 24 | ): 25 | """Euler sampler for DDPM-like diffusion process. 26 | 27 | References: 28 | [1] Karras, Tero, Miika Aittala, Timo Aila, and Samuli Laine. "Elucidating the design space of 29 | diffusion-based generative models." Advances in Neural Information Processing Systems 35 (2022): 30 | 26565-26577. 
31 | 32 | """ 33 | super().__init__( 34 | total_steps=total_steps, 35 | beta_schedule=beta_schedule, 36 | beta_start=beta_start, 37 | beta_end=beta_end, 38 | betas=betas, 39 | objective=objective, 40 | clip_denoised=clip_denoised, 41 | respace_type=respace_type, 42 | respace_steps=respace_steps, 43 | respaced_seq=respaced_seq, 44 | device=device, 45 | **kwargs, 46 | ) 47 | 48 | self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod).sqrt() 49 | 50 | def denoise(self, model_output: Tensor, xt: Tensor, t: int, t_prev: int): 51 | """Denoise from x_t to x_{t-1}.""" 52 | # Prepare parameters 53 | sigmas_t = self.sigmas[t] 54 | sigmas_t_prev = self.sigmas[t_prev] if t_prev >= 0 else torch.tensor(0.0) 55 | 56 | # Predict x0 and eps 57 | predict = self.predict(model_output, xt, t) 58 | pred_x0 = predict['pred_x0'] 59 | 60 | # Calculate the x{t-1} 61 | bar_xt = (1 + sigmas_t ** 2).sqrt() * xt 62 | derivative = (bar_xt - pred_x0) / sigmas_t 63 | bar_sample = bar_xt + derivative * (sigmas_t_prev - sigmas_t) 64 | sample = bar_sample / (1 + sigmas_t_prev ** 2).sqrt() 65 | 66 | return {'sample': sample, 'pred_x0': pred_x0} 67 | -------------------------------------------------------------------------------- /diffusions/guidance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/diffusions/guidance/__init__.py -------------------------------------------------------------------------------- /diffusions/guidance/clip_guidance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torchvision.transforms as T 4 | 5 | from transformers import CLIPProcessor, CLIPModel 6 | 7 | from diffusions.guidance.base import BaseGuidance 8 | from utils.misc import image_norm_to_uint8 9 | 10 | 11 | class CLIPGuidance(BaseGuidance): 12 | """Diffusion Models with CLIP Guidance. 13 | 14 | Guide the diffusion process with similarity between CLIP image feature and text feature, so that the generated 15 | image matches the description of the input text. 16 | 17 | In each step, the guidance is applied on the predicted x0 to avoid training CLIP on noisy images. 
18 | 19 | """ 20 | def __init__( 21 | self, 22 | guidance_weight: float = 1.0, 23 | clip_pretrained: str = 'openai/clip-vit-base-patch32', 24 | **kwargs, 25 | ): 26 | super().__init__(**kwargs) 27 | self.guidance_weight = guidance_weight 28 | 29 | self.clip_processor = CLIPProcessor.from_pretrained(clip_pretrained) 30 | self.clip_model = CLIPModel.from_pretrained(clip_pretrained).to(self.device) 31 | 32 | self.text = None 33 | 34 | def set_text(self, text: str): 35 | assert isinstance(text, str) 36 | self.text = text 37 | 38 | @torch.enable_grad() 39 | def cond_fn_mean(self, t: int, xt: Tensor, pred_x0: Tensor, var: Tensor, **kwargs): 40 | if self.text is None: 41 | raise RuntimeError('Please call `set_text()` before sampling.') 42 | images = image_norm_to_uint8(pred_x0) 43 | processed = self.clip_processor(text=self.text, images=images, return_tensors="pt", padding=True) 44 | processed = {k: v.to(self.device) for k, v in processed.items()} 45 | processed['pixel_values'].requires_grad_(True) 46 | out = self.clip_model(**processed) 47 | similarities = torch.matmul(out['image_embeds'], out['text_embeds'].t()).squeeze(dim=1) 48 | grad = torch.autograd.grad(outputs=similarities.sum(), inputs=processed['pixel_values'])[0] 49 | grad = T.Resize(xt.shape[-2:], antialias=True)(grad) 50 | return self.guidance_weight * ((1. / self.alphas_cumprod[t]) ** 0.5) * var * grad 51 | -------------------------------------------------------------------------------- /diffusions/guidance/ilvr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from diffusions.guidance.base import BaseGuidance 5 | from utils.resize_right import resize_right, interp_methods 6 | 7 | 8 | class ILVR(BaseGuidance): 9 | def __init__( 10 | self, 11 | ref_images: Tensor = None, 12 | downsample_factor: int = 8, 13 | interp_method: str = 'cubic', 14 | *args, **kwargs, 15 | ): 16 | """Iterative Latent Variable Refinement (ILVR). 17 | 18 | Args: 19 | ref_images: The reference images of shape [B, C, H, W]. 20 | downsample_factor: The downsample factor. 21 | interp_method: The interpolation method. Options: 'cubic', 'lanczos2', 'lanczos3', 'linear', 'box'. 22 | 23 | References: 24 | [1] Choi, Jooyoung, Sungwon Kim, Yonghyun Jeong, Youngjune Gwon, and Sungroh Yoon. “ILVR: Conditioning Method 25 | for Denoising Diffusion Probabilistic Models.” In 2021 IEEE/CVF International Conference on Computer Vision 26 | (ICCV), pp. 14347-14356. IEEE, 2021. 
27 | 28 | """ 29 | super().__init__(*args, **kwargs) 30 | self.ref_images = ref_images 31 | self.downsample_factor = downsample_factor 32 | self.interp_method = getattr(interp_methods, interp_method) 33 | 34 | def set_ref_images(self, ref_images: Tensor): 35 | self.ref_images = ref_images 36 | 37 | def cond_fn_sample(self, t: int, t_prev: int, sample: Tensor, **kwargs): 38 | if self.ref_images is None: 39 | raise RuntimeError('Please call `set_ref_images()` before sampling.') 40 | if t == 0: 41 | noisy_ref_images = self.ref_images 42 | else: 43 | noisy_ref_images = self.diffuse( 44 | x0=self.ref_images, 45 | t=torch.full((sample.shape[0], ), t_prev, device=self.device), 46 | ) 47 | return self.low_pass_filter(noisy_ref_images) - self.low_pass_filter(sample) 48 | 49 | def low_pass_filter(self, x: Tensor): 50 | x = resize_right.resize(x, scale_factors=1./self.downsample_factor, interp_method=self.interp_method) 51 | x = resize_right.resize(x, scale_factors=self.downsample_factor, interp_method=self.interp_method) 52 | return x 53 | -------------------------------------------------------------------------------- /diffusions/guidance/mask_guidance.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import Dict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from diffusions.guidance.base import BaseGuidance 9 | 10 | 11 | class MaskGuidance(BaseGuidance): 12 | def __init__( 13 | self, 14 | masked_image: Tensor = None, 15 | mask: Tensor = None, 16 | *args, **kwargs, 17 | ): 18 | """Diffusion Models with Mask Guidance. 19 | 20 | The idea was first proposed in [1] and further developed in [2], [3], etc. for image inpainting. In each reverse 21 | step, xt is computed by composing the noisy known part and denoised unknown part of the image. 22 | 23 | .. math:: 24 | x_{t−1} = m \odot x^{known}_{t−1} + (1 − m) \odot x^{unknown}_{t-1} 25 | 26 | Args: 27 | masked_image: The masked input images of shape [B, C, H, W]. 28 | mask: The binary masks of shape [B, 1, H, W]. Note that 1 denotes known areas and 0 denotes unknown areas. 29 | 30 | References: 31 | [1]. Song, Yang, and Stefano Ermon. “Generative modeling by estimating gradients of the data distribution.” 32 | Advances in neural information processing systems 32 (2019). 33 | 34 | [2]. Avrahami, Omri, Dani Lischinski, and Ohad Fried. “Blended diffusion for text-driven editing of natural 35 | images.” In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 18208 36 | -18218. 2022. 37 | 38 | [3]. Lugmayr, Andreas, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, and Luc Van Gool. “Repaint: 39 | Inpainting using denoising diffusion probabilistic models.” In Proceedings of the IEEE/CVF Conference on 40 | Computer Vision and Pattern Recognition, pp. 11461-11471. 2022. 
41 | 42 | """ 43 | super().__init__(*args, **kwargs) 44 | self.masked_image = masked_image 45 | self.mask = mask 46 | 47 | def set_mask_and_image(self, masked_image: Tensor, mask: Tensor): 48 | self.masked_image = masked_image 49 | self.mask = mask 50 | 51 | def cond_fn_sample(self, t: int, t_prev: int, sample: Tensor, **kwargs): 52 | if self.masked_image is None or self.mask is None: 53 | raise RuntimeError('Please call `set_mask_and_image()` before sampling.') 54 | if t == 0: 55 | noisy_known = self.masked_image 56 | else: 57 | noisy_known = self.diffuse( 58 | x0=self.masked_image, 59 | t=torch.full((sample.shape[0], ), t_prev, device=self.device), 60 | ) 61 | return (noisy_known - sample) * self.mask 62 | 63 | def q_sample_one_step(self, xt: Tensor, t: int, t_next: int): 64 | """Sample from q(x{t+1} | xt). """ 65 | alphas_cumprod_t = self.alphas_cumprod[t] 66 | alphas_cumprod_t_next = self.alphas_cumprod[t_next] if t_next < self.total_steps else torch.tensor(0.0) 67 | alphas_t_next = alphas_cumprod_t_next / alphas_cumprod_t 68 | return torch.sqrt(alphas_t_next) * xt + torch.sqrt(1. - alphas_t_next) * torch.randn_like(xt) 69 | 70 | def resample_loop( 71 | self, model: nn.Module, init_noise: Tensor, 72 | resample_r: int = 10, resample_j: int = 10, 73 | tqdm_kwargs: Dict = None, model_kwargs: Dict = None, 74 | ): 75 | """Sample following RePaint paper. """ 76 | tqdm_kwargs = dict() if tqdm_kwargs is None else tqdm_kwargs 77 | model_kwargs = dict() if model_kwargs is None else model_kwargs 78 | 79 | img = init_noise 80 | resample_seq1 = self.get_resample_seq(resample_r, resample_j) 81 | resample_seq2 = resample_seq1[1:] + [-1] 82 | pbar = tqdm.tqdm(total=len(resample_seq1), **tqdm_kwargs) 83 | for t1, t2 in zip(resample_seq1, resample_seq2): 84 | if t1 > t2: 85 | t_batch = torch.full((img.shape[0], ), t1, device=self.device, dtype=torch.long) 86 | model_output = model(img, t_batch, **model_kwargs) 87 | out = self.denoise(model_output, img, t1, t2) 88 | out = self.apply_guidance(**out, xt=img, t=t1, t_prev=t2) # apply guidance 89 | img = out['sample'] 90 | yield out 91 | else: 92 | img = self.q_sample_one_step(img, t1, t2) 93 | yield {'sample': img} 94 | pbar.update(1) 95 | pbar.close() 96 | 97 | def resample( 98 | self, model: nn.Module, init_noise: Tensor, 99 | resample_r: int = 10, resample_j: int = 10, 100 | tqdm_kwargs: Dict = None, model_kwargs: Dict = None, 101 | ): 102 | sample = None 103 | for out in self.resample_loop( 104 | model, init_noise, 105 | resample_r, resample_j, 106 | tqdm_kwargs, model_kwargs, 107 | ): 108 | sample = out['sample'] 109 | return sample 110 | 111 | def get_resample_seq(self, resample_r: int = 10, resample_j: int = 10): 112 | """Figure 9 in RePaint paper. 113 | 114 | Args: 115 | resample_r: Number of resampling, as proposed in RePaint paper. 116 | resample_j: Jump lengths of resampling, as proposed in RePaint paper. 
117 | 118 | """ 119 | t_T = len(self.respaced_seq) 120 | 121 | jumps = {} 122 | for j in range(0, t_T - resample_j, resample_j): 123 | jumps[j] = resample_r - 1 124 | 125 | t = t_T 126 | ts = [] 127 | while t >= 1: 128 | t = t - 1 129 | ts.append(self.respaced_seq[t].item()) 130 | if jumps.get(t, 0) > 0: 131 | jumps[t] = jumps[t] - 1 132 | for _ in range(resample_j): 133 | t = t + 1 134 | ts.append(self.respaced_seq[t].item()) 135 | return ts 136 | 137 | 138 | def _test(r, j): 139 | import matplotlib.pyplot as plt 140 | 141 | dummy_image = torch.rand((10, 3, 256, 256)) 142 | dummy_mask = torch.randint(0, 2, (10, 1, 256, 256)) 143 | mask_guided = MaskGuidance(dummy_image, dummy_mask, skip_type='uniform', skip_steps=250) 144 | 145 | ts = mask_guided.get_resample_seq(resample_r=r, resample_j=j) 146 | plt.rcParams["figure.figsize"] = (10, 5) 147 | plt.plot(range(len(ts)), ts) 148 | plt.title(f'r={r}, j={j}') 149 | plt.show() 150 | 151 | 152 | if __name__ == '__main__': 153 | _test(1, 10) 154 | _test(5, 10) 155 | _test(10, 10) 156 | -------------------------------------------------------------------------------- /diffusions/heun.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import Dict 3 | 4 | import torch 5 | from torch import Tensor, nn as nn 6 | 7 | from diffusions.ddpm import DDPM 8 | 9 | 10 | class HeunSampler(DDPM): 11 | def __init__( 12 | self, 13 | total_steps: int = 1000, 14 | beta_schedule: str = 'linear', 15 | beta_start: float = 0.0001, 16 | beta_end: float = 0.02, 17 | betas: Tensor = None, 18 | objective: str = 'pred_eps', 19 | 20 | clip_denoised: bool = True, 21 | respace_type: str = None, 22 | respace_steps: int = 100, 23 | respaced_seq: Tensor = None, 24 | 25 | device: torch.device = 'cpu', 26 | **kwargs, 27 | ): 28 | """Heun sampler for DDPM-like diffusion process. 29 | 30 | References: 31 | [1] Karras, Tero, Miika Aittala, Timo Aila, and Samuli Laine. "Elucidating the design space of 32 | diffusion-based generative models." Advances in Neural Information Processing Systems 35 (2022): 33 | 26565-26577. 34 | 35 | """ 36 | super().__init__( 37 | total_steps=total_steps, 38 | beta_schedule=beta_schedule, 39 | beta_start=beta_start, 40 | beta_end=beta_end, 41 | betas=betas, 42 | objective=objective, 43 | clip_denoised=clip_denoised, 44 | respace_type=respace_type, 45 | respace_steps=respace_steps, 46 | respaced_seq=respaced_seq, 47 | device=device, 48 | **kwargs, 49 | ) 50 | 51 | self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod).sqrt() 52 | 53 | self._1st_order_derivative = None 54 | self._1st_order_xt = None 55 | 56 | def denoise_1st_order(self, model_output: Tensor, xt: Tensor, t: int, t_prev: int): 57 | """1st order step. 
Same as euler sampler.""" 58 | # Prepare parameters 59 | sigmas_t = self.sigmas[t] 60 | sigmas_t_prev = self.sigmas[t_prev] if t_prev >= 0 else torch.tensor(0.0) 61 | 62 | # Predict x0 63 | predict = self.predict(model_output, xt, t) 64 | pred_x0 = predict['pred_x0'] 65 | 66 | # Calculate the x{t-1} 67 | bar_xt = (1 + sigmas_t ** 2).sqrt() * xt 68 | derivative = (bar_xt - pred_x0) / sigmas_t 69 | bar_sample = bar_xt + derivative * (sigmas_t_prev - sigmas_t) 70 | sample = bar_sample / (1 + sigmas_t_prev ** 2).sqrt() 71 | 72 | # Store the 1st order info 73 | self._1st_order_derivative = derivative 74 | self._1st_order_xt = xt 75 | 76 | return {'sample': sample, 'pred_x0': pred_x0} 77 | 78 | def denoise_2nd_order(self, model_output: Tensor, xt_prev: Tensor, t: int, t_prev: int): 79 | """2nd order step.""" 80 | # Prepare parameters 81 | sigmas_t = self.sigmas[t] 82 | sigmas_t_prev = self.sigmas[t_prev] if t_prev >= 0 else torch.tensor(0.0) 83 | 84 | # Predict x0 85 | predict = self.predict(model_output, xt_prev, t_prev) 86 | pred_x0 = predict['pred_x0'] 87 | 88 | # Calculate derivative 89 | bar_xt_prev = (1 + sigmas_t_prev ** 2).sqrt() * xt_prev 90 | derivative = (bar_xt_prev - pred_x0) / sigmas_t_prev 91 | derivative = (derivative + self._1st_order_derivative) / 2 92 | 93 | # Calculate the x{t-1} 94 | bar_xt = (1 + sigmas_t ** 2).sqrt() * self._1st_order_xt 95 | bar_sample = bar_xt + derivative * (sigmas_t_prev - sigmas_t) 96 | sample = bar_sample / (1 + sigmas_t_prev ** 2).sqrt() 97 | 98 | # Clear the 1st order info 99 | self._1st_order_derivative = None 100 | self._1st_order_xt = None 101 | 102 | return {'sample': sample, 'pred_x0': pred_x0} 103 | 104 | def sample_loop( 105 | self, model: nn.Module, init_noise: Tensor, 106 | tqdm_kwargs: Dict = None, model_kwargs: Dict = None, 107 | ): 108 | tqdm_kwargs = dict() if tqdm_kwargs is None else tqdm_kwargs 109 | model_kwargs = dict() if model_kwargs is None else model_kwargs 110 | 111 | img = init_noise 112 | sample_seq = self.respaced_seq.tolist() 113 | sample_seq_prev = [-1] + self.respaced_seq[:-1].tolist() 114 | pbar = tqdm.tqdm(total=len(sample_seq), **tqdm_kwargs) 115 | for t, t_prev in zip(reversed(sample_seq), reversed(sample_seq_prev)): 116 | # 1st order step 117 | t_batch = torch.full((img.shape[0], ), t, device=self.device, dtype=torch.long) 118 | model_output = model(img, t_batch, **model_kwargs) 119 | out = self.denoise_1st_order(model_output, img, t, t_prev) 120 | img = out['sample'] 121 | 122 | if t_prev >= 0: 123 | # 2nd order step 124 | t_prev_batch = torch.full((img.shape[0], ), t_prev, device=self.device, dtype=torch.long) 125 | model_output = model(img, t_prev_batch, **model_kwargs) 126 | out = self.denoise_2nd_order(model_output, img, t, t_prev) 127 | img = out['sample'] 128 | 129 | pbar.update(1) 130 | yield out 131 | pbar.close() 132 | -------------------------------------------------------------------------------- /diffusions/schedule.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def get_beta_schedule( 6 | total_steps: int = 1000, 7 | beta_schedule: str = 'linear', 8 | beta_start: float = 0.0001, 9 | beta_end: float = 0.02, 10 | ): 11 | """Get the beta schedule for diffusion. 12 | 13 | Args: 14 | total_steps: Number of diffusion steps. 15 | beta_schedule: Type of beta schedule. Options: 'linear', 'quad', 'const', 'cosine'. 16 | beta_start: Starting beta value. 17 | beta_end: Ending beta value. 
18 | 19 | Returns: 20 | A Tensor of length `total_steps`. 21 | 22 | """ 23 | if beta_schedule == 'linear': 24 | return torch.linspace(beta_start, beta_end, total_steps, dtype=torch.float64) 25 | elif beta_schedule == 'quad': 26 | return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, total_steps, dtype=torch.float64) ** 2 27 | elif beta_schedule == 'const': 28 | return torch.full((total_steps, ), fill_value=beta_end, dtype=torch.float64) 29 | elif beta_schedule == 'cosine': 30 | def alpha_bar(t): 31 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 32 | betas = [ 33 | min(1 - alpha_bar((i + 1) / total_steps) / alpha_bar(i / total_steps), 0.999) 34 | for i in range(total_steps) 35 | ] 36 | return torch.tensor(betas) 37 | else: 38 | raise ValueError(f'Beta schedule {beta_schedule} is not supported.') 39 | 40 | 41 | def get_respaced_seq( 42 | total_steps: int = 1000, 43 | respace_type: str = 'uniform', 44 | respace_steps: int = 100, 45 | ): 46 | """Get respaced time sequence for fast inference. 47 | 48 | Args: 49 | total_steps: Number of the original diffusion steps. 50 | respace_type: Type of respaced timestep sequence. Options: 'uniform', 'uniform-leading', 'uniform-linspace', 51 | 'uniform-trailing', 'quad', 'none', None. 52 | respace_steps: Length of respaced timestep sequence. 53 | 54 | Returns: 55 | A Tensor of length `respace_steps`, containing indices that are preserved in the respaced sequence. 56 | 57 | """ 58 | if respace_type in ['uniform', 'uniform-leading']: 59 | space = total_steps // respace_steps 60 | seq = torch.arange(0, total_steps, space).long() 61 | elif respace_type == 'uniform-linspace': 62 | seq = torch.linspace(0, total_steps - 1, respace_steps).long() 63 | elif respace_type == 'uniform-trailing': 64 | space = total_steps // respace_steps 65 | seq = torch.arange(total_steps-1, -1, -space).long().flip(dims=[0]) 66 | elif respace_type == 'quad': 67 | seq = torch.linspace(0, math.sqrt(total_steps * 0.8), respace_steps) ** 2 68 | seq = torch.floor(seq).long() 69 | elif respace_type is None or respace_type == 'none': 70 | seq = torch.arange(0, total_steps).long() 71 | else: 72 | raise ValueError(f'Respace type {respace_type} is not supported.') 73 | return seq 74 | 75 | 76 | def _test_betas(): 77 | import matplotlib.pyplot as plt 78 | fig, ax = plt.subplots(1, 2, figsize=(10, 4)) 79 | 80 | betas_linear = get_beta_schedule( 81 | total_steps=1000, 82 | beta_schedule='linear', 83 | beta_start=0.0001, 84 | beta_end=0.02, 85 | ) 86 | alphas_bar_linear = torch.cumprod(1. - betas_linear, dim=0) 87 | betas_quad = get_beta_schedule( 88 | total_steps=1000, 89 | beta_schedule='quad', 90 | beta_start=0.0001, 91 | beta_end=0.02, 92 | ) 93 | alphas_bar_quad = torch.cumprod(1. - betas_quad, dim=0) 94 | betas_cosine = get_beta_schedule( 95 | total_steps=1000, 96 | beta_schedule='cosine', 97 | beta_start=0.0001, 98 | beta_end=0.02, 99 | ) 100 | alphas_bar_cosine = torch.cumprod(1. 
- betas_cosine, dim=0) 101 | 102 | ax[0].plot(torch.arange(1000), betas_linear, label='linear') 103 | ax[0].plot(torch.arange(1000), betas_quad, label='quad') 104 | ax[0].plot(torch.arange(1000), betas_cosine, label='cosine') 105 | ax[0].set_title(r'$\beta_t$') 106 | ax[0].set_xlabel(r'$t$') 107 | ax[0].legend() 108 | ax[1].plot(torch.arange(1000), alphas_bar_linear, label='linear') 109 | ax[1].plot(torch.arange(1000), alphas_bar_quad, label='quad') 110 | ax[1].plot(torch.arange(1000), alphas_bar_cosine, label='cosine') 111 | ax[1].set_title(r'$\bar\alpha_t$') 112 | ax[1].set_xlabel(r'$t$') 113 | ax[1].legend() 114 | plt.show() 115 | 116 | 117 | def _test_respace(): 118 | seq = get_respaced_seq( 119 | total_steps=1000, 120 | respace_type='uniform-leading', 121 | respace_steps=10, 122 | ) 123 | print('uniform-leading:\t', seq) 124 | seq = get_respaced_seq( 125 | total_steps=1000, 126 | respace_type='uniform-linspace', 127 | respace_steps=10, 128 | ) 129 | print('uniform-linspace:\t', seq) 130 | seq = get_respaced_seq( 131 | total_steps=1000, 132 | respace_type='uniform-trailing', 133 | respace_steps=10, 134 | ) 135 | print('uniform-trailing:\t', seq) 136 | 137 | 138 | if __name__ == '__main__': 139 | _test_betas() 140 | # _test_respace() 141 | -------------------------------------------------------------------------------- /docs/CLIP Guidance.md: -------------------------------------------------------------------------------- 1 | # CLIP Guidance 2 | 3 | CLIP Guidance is a technique to generate images following an input text description with a pretrained diffusion model and a pretrained CLIP model. It uses CLIP score to guide the reverse diffusion process during sampling. To be specific, each reverse step is modified to: 4 | $$ 5 | \begin{align} 6 | &p_\theta(\mathbf x_{t-1}\vert\mathbf x_{t})=\mathcal N(\mathbf x_{t-1};\mu_\theta(\mathbf x_t,t){\color{dodgerblue}{+s\sigma_t^2\nabla_{\mathbf x_t}\mathcal L_{\text{CLIP}}}},\sigma_t^2\mathbf I)\\ 7 | &\mathcal L_\text{CLIP}=E_\text{image}(\mathbf x_\theta(\mathbf x_t,t))\cdot E_\text{text}(y) 8 | \end{align} 9 | $$ 10 | where $y$ is the input text, $E_\text{image}$ and $E_\text{text}$ are CLIP's image and text encoders, and $s$ is a hyper-parameter controlling the scale of guidance. 11 | 12 | 13 | 14 | ## Sampling 15 | 16 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 17 | 18 | ```shell 19 | accelerate-launch scripts/sample_clip_guidance.py -c CONFIG \ 20 | --weights WEIGHTS \ 21 | --text TEXT \ 22 | --n_samples N_SAMPLES \ 23 | --save_dir SAVE_DIR \ 24 | [--seed SEED] \ 25 | [--var_type VAR_TYPE] \ 26 | [--respace_type RESPACE_TYPE] \ 27 | [--respace_steps RESPACE_STEPS] \ 28 | [--guidance_weight GUIDANCE_WEIGHT] \ 29 | [--clip_model CLIP_MODEL] \ 30 | [--batch_size BATCH_SIZE] 31 | ``` 32 | 33 | Basic arguments: 34 | 35 | - `-c CONFIG`: path to the configuration file. 36 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 37 | - `--text TEXT`: text description. Please wrap your description with quotation marks if it contains spaces, e.g., `--text 'a lovely dog'`. 38 | - `--n_samples N_SAMPLES`: number of samples to generate. 39 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 40 | - `--guidance_weight GUIDANCE_WEIGHT`: guidance weight (strength). 
41 | - `--clip_model CLIP_MODEL`: name of the CLIP model. Defaults to "openai/clip-vit-base-patch32". 42 | 43 | Advanced arguments: 44 | 45 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 46 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 47 | 48 | See more details by running `python sample_clip_guidance.py -h`. 49 | 50 | 51 | 52 | ## Results 53 | 54 | **Pretrained on CelebA-HQ 256x256**: 55 | 56 |

57 | 58 |

59 | 60 | All the images are sampled with 50 DDPM steps. 61 | 62 | Images in the same row are sampled with the same random seed, thus share similar semantics. The first column shows the original samples (i.e., guidance scale=0). The next 3 columns are sampled using text description "a young girl with brown hair", with increasing guidance scales of 10, 50, and 100. The following 3 columns are similarly sampled with text description "an old man with a smile". 63 | 64 | As expected, bigger guidance scale leads to greater changes. However, some results fail to match the descriptions, such as gender and hair color. Furthermore, guidance scale larger than 100 damages the image quality drastically. 65 | -------------------------------------------------------------------------------- /docs/Classifier-Free Guidance.md: -------------------------------------------------------------------------------- 1 | # Classifier-Free Guidance 2 | 3 | > Ho, Jonathan, and Tim Salimans. "Classifier-Free Diffusion Guidance." In *NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications*. 2021. 4 | 5 | 6 | 7 | ## Training 8 | 9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 10 | 11 | ```shell 12 | accelerate-launch scripts/train_ddpm_cfg.py -c CONFIG [-e EXP_DIR] [--key value ...] 13 | ``` 14 | 15 | Arguments: 16 | 17 | - `-c CONFIG`: path to the training configuration file. 18 | - `-e EXP_DIR`: results (logs, checkpoints, tensorboard, etc.) will be saved to `EXP_DIR`. Default to `runs/exp-{current time}/`. 19 | - `--key value`: modify configuration items in `CONFIG` via CLI. 20 | 21 | For example, to train on CIFAR-10 with default settings: 22 | 23 | ```shell 24 | accelerate-launch scripts/train_ddpm_cfg.py -c ./configs/ddpm_cfg_cifar10.yaml 25 | ``` 26 | 27 | To change the default `p_uncond` (the probability to disable condition in training) in `./configs/ddpm_cfg_cifar10.yaml` from 0.2 to 0.1: 28 | 29 | ```shell 30 | accelerate-launch scripts/train_ddpm_cfg.py -c ./configs/ddpm_cfg_cifar10.yaml --train.p_uncond 0.1 31 | ``` 32 | 33 | 34 | 35 | ## Sampling 36 | 37 | ```shell 38 | accelerate-launch scripts/sample_cfg.py -c CONFIG \ 39 | --weights WEIGHTS \ 40 | --sampler {ddpm,ddim} \ 41 | --n_samples_each_class N_SAMPLES_EACH_CLASS \ 42 | --save_dir SAVE_DIR \ 43 | --guidance_scale GUIDANCE_SCALE \ 44 | [--seed SEED] \ 45 | [--class_ids CLASS_IDS [CLASS_IDS ...]] \ 46 | [--respace_type RESPACE_TYPE] \ 47 | [--respace_steps RESPACE_STEPS] \ 48 | [--ddim] \ 49 | [--ddim_eta DDIM_ETA] \ 50 | [--batch_size BATCH_SIZE] 51 | ``` 52 | 53 | Basic arguments: 54 | 55 | - `-c CONFIG`: path to the configuration file. 56 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 57 | - `--sampler {ddpm,ddim}`: set the sampler. 58 | - `--n_samples_each_class N_SAMPLES_EACH_CLASS`: number of samples for each class. 59 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 60 | - `--guidance_scale GUIDANCE_SCALE`: the guidance scale factor $s$ 61 | - $s=0$: unconditional generation 62 | - $s=1$: non-guided conditional generation 63 | - $s>1$: guided conditional generation 64 | 65 | Advanced arguments: 66 | 67 | - `--class_ids CLASS_IDS [CLASS_IDS ...]`: a list of class ids to sample. If not specified, all classes will be sampled. 
68 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 69 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 70 | 71 | See more details by running `python sample_cfg.py -h`. 72 | 73 | For example, to sample 10 images for class (0, 2, 4, 8) from a pretrained CIFAR-10 model with guidance scale 3 using 100 DDIM steps: 74 | 75 | ```shell 76 | accelerate-launch scripts/sample_cfg.py -c ./configs/ddpm_cfg_cifar10.yaml --weights /path/to/model/weights --sampler ddim --n_samples_each_class 10 --save_dir ./samples/cfg-cifar10 --guidance_scale 3 --class_ids 0 2 4 8 --respace_steps 100 77 | ``` 78 | 79 | 80 | 81 | ## Evaluation 82 | 83 | Sample 10K-50K images following the previous section and evaluate image quality with tools like [torch-fidelity](https://github.com/toshas/torch-fidelity), [pytorch-fid](https://github.com/mseitzer/pytorch-fid), [clean-fid](https://github.com/GaParmar/clean-fid), etc. 84 | 85 | 86 | 87 | ## Results 88 | 89 | **FID and IS on CIFAR-10 32x32**: 90 | 91 | | guidance scale | FID ↓ | IS ↑ | 92 | | :------------------------: | :------: | :-------------: | 93 | | 0 (unconditional) | 6.2904 | 8.9851 ± 0.0825 | 94 | | 1 (non-guided conditional) | 4.6630 | 9.1763 ± 0.1201 | 95 | | 3 (guided conditional) | 10.2304 | 9.6252 ± 0.0977 | 96 | | 5 (guided conditional) | 16.23021 | 9.3210 ± 0.0744 | 97 | 98 | - The images are sampled using DDIM with 50 steps. 99 | - All the metrics are evaluated on 50K samples. 100 | - FID measures diversity and IS measures fidelity. This table shows diversity-fidelity trade-off as guidance scale increases. 101 | 102 | 103 | 104 | **Samples with different guidance scales on CIFAR-10 32x32**: 105 | 106 |

107 | 108 |

109 | From left to right: $s=0$ (unconditional), $s=1.0$ (non-guided conditional), $s=3.0$, $s=5.0$. Each row corresponds to a class. 110 | 111 | 112 | 113 | **Samples with different guidance scales on ImageNet 256x256**: 114 | 115 | The pretrained models are sourced from [openai/guided-diffusion](https://github.com/openai/guided-diffusion). Note that these models were initially designed for classifier guidance and thus are either conditional-only or unconditional-only. However, to facilitate classifier-free guidance, it would be more convenient if the model can handle both conditional and unconditional cases. To address this, I define a new class [UNetCombined](../models/adm/unet_combined.py), which combines the conditional-only and unconditional-only models into a single model. Also, we need to combine the pretrained weights for loading, which can be done by the following script: 116 | 117 | ```python 118 | import yaml 119 | from models.adm.unet_combined import UNetCombined 120 | 121 | 122 | config_path = './weights/openai/guided-diffusion/256x256_diffusion.yaml' 123 | weight_cond_path = './weights/openai/guided-diffusion/256x256_diffusion.pt' 124 | weight_uncond_path = './weights/openai/guided-diffusion/256x256_diffusion_uncond.pt' 125 | save_path = './weights/openai/guided-diffusion/256x256_diffusion_combined.pt' 126 | 127 | with open(config_path, 'r') as f: 128 | cfg = yaml.safe_load(f) 129 | model = UNetCombined(**cfg['model']['params']) 130 | model.combine_weights(weight_cond_path, weight_uncond_path, save_path) 131 | ``` 132 | 133 | 134 | 135 |
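For reference, once the combined weights are saved, classifier-free guidance at sampling time simply mixes the outputs of the two branches at every reverse step. The snippet below is a minimal sketch of that mixing (not the repo's actual sampling loop), following the `--guidance_scale` convention described in the Sampling section above; for learned-variance models the mixing would in practice be applied to the noise-prediction channels only.

```python
import torch

@torch.no_grad()
def guided_prediction(model, x, t, y, guidance_scale: float):
    """Classifier-free guidance with a combined cond/uncond model (illustrative sketch).

    guidance_scale = 0 -> unconditional prediction,
    guidance_scale = 1 -> plain conditional prediction,
    guidance_scale > 1 -> guided (extrapolated) prediction.
    """
    pred_cond = model(x, t, y=y)        # dispatched to the conditional UNet
    pred_uncond = model(x, t, y=None)   # dispatched to the unconditional UNet
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```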

136 | 137 |

138 | 139 | From left to right: $s=1.0$ (non-guided conditional), $s=2.0$, $s=3.0$. Each row corresponds to a class. 140 | -------------------------------------------------------------------------------- /docs/DDIB.md: -------------------------------------------------------------------------------- 1 | # DDIB 2 | 3 | > Su, Xuan, Jiaming Song, Chenlin Meng, and Stefano Ermon. "Dual diffusion implicit bridges for image-to-image translation." *arXiv preprint arXiv:2203.08382* (2022). 4 | 5 | 6 | 7 | ## Sampling 8 | 9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 10 | 11 | ```shell 12 | accelerate-launch scripts/sample_ddib.py -c CONFIG \ 13 | --weights WEIGHTS \ 14 | --input_dir INPUT_DIR \ 15 | --save_dir SAVE_DIR \ 16 | --class_A CLASS_A \ 17 | --class_B CLASS_B \ 18 | [--seed SEED] \ 19 | [--respace_type RESPACE_TYPE] \ 20 | [--respace_steps RESPACE_STEPS] \ 21 | [--batch_size BATCH_SIZE] 22 | ``` 23 | 24 | Basic arguments: 25 | 26 | - `-c CONFIG`: path to the configuration file. 27 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 28 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved. 29 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 30 | - `--class_A CLASS_A`: input class label. 31 | - `--class_B CLASS_B`: output class label. 32 | 33 | Advanced arguments: 34 | 35 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 36 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 37 | 38 | See more details by running `python sample_ddib.py -h`. 39 | 40 | 41 | 42 | ## Results 43 | 44 | **ImageNet 256x256 (conditional)** with pretrained model from [openai/guided-diffusion](https://github.com/openai/guided-diffusion): 45 | 46 |
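For context, DDIB builds two deterministic DDIM bridges with the same conditional model: the input image is first encoded into the diffusion latent space by DDIM inversion under `class_A`, and the resulting latent is then decoded by DDIM sampling under `class_B`. A minimal sketch with hypothetical helper names (not the repo's `sample_ddib.py`):

```python
import torch

@torch.no_grad()
def ddib_translate(x_A, class_A, class_B, ddim_invert, ddim_sample):
    """Dual Diffusion Implicit Bridges: image of class A -> shared latent -> image of class B.

    ddim_invert / ddim_sample are deterministic DDIM encoding / decoding loops
    conditioned on a class label (hypothetical helpers).
    """
    x_T = ddim_invert(x_A, y=class_A)   # forward ODE: source image -> latent x_T
    x_B = ddim_sample(x_T, y=class_B)   # reverse ODE: latent x_T -> target-class image
    return x_B
```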

47 | 48 |

49 | 50 | Notes: All images are sampled with 100 DDIM steps. 51 | 52 | The results are not as good as expected. Some are acceptable, such as the Sussex Spaniel, Husky and Tiger in the 3rd row, but the others are not. 53 | -------------------------------------------------------------------------------- /docs/DDIM.md: -------------------------------------------------------------------------------- 1 | # DDIM 2 | 3 | > Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." In *International Conference on Learning Representations*. 2020. 4 | 5 | 6 | 7 | ## Sampling 8 | 9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 10 | 11 | ```shell 12 | accelerate-launch scripts/sample_uncond.py -c CONFIG \ 13 | --weights WEIGHTS \ 14 | --sampler ddim \ 15 | --ddim_eta DDIM_ETA \ 16 | --n_samples N_SAMPLES \ 17 | --save_dir SAVE_DIR \ 18 | [--seed SEED] \ 19 | [--batch_size BATCH_SIZE] \ 20 | [--respace_type RESPACE_TYPE] \ 21 | [--respace_steps RESPACE_STEPS] \ 22 | [--mode {sample,interpolate,reconstruction}] \ 23 | [--n_interpolate N_INTERPOLATE] \ 24 | [--input_dir INPUT_DIR] 25 | ``` 26 | 27 | Basic arguments: 28 | 29 | - `-c CONFIG`: path to the inference configuration file. 30 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 31 | - `--sampler ddim`: set the sampler to DDIM. 32 | - `--n_samples N_SAMPLES`: number of samples. 33 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 34 | - `--mode MODE`: choose a sampling mode, the options are: 35 | - "sample" (default): randomly sample images 36 | - "interpolate": sample two random images and interpolate between them. Use `--n_interpolate` to specify the number of images in between. 37 | - "reconstruction": encode a real image from dataset with **DDIM inversion** (DDIM encoding), and then decode it with DDIM sampling. 38 | 39 | Advanced arguments: 40 | 41 | - `--ddim_eta`: parameter eta in DDIM sampling. 42 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 43 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 44 | 45 | See more details by running `python sample_ddim.py -h`. 46 | 47 | For example, to sample 50000 images from a pretrained CIFAR-10 model with 100 DDIM steps: 48 | 49 | ```shell 50 | accelerate-launch scripts/sample_uncond.py -c ./configs/ddpm_cifar10.yaml --weights /path/to/model/weights --sampler ddim --n_samples 50000 --save_dir ./samples/ddim-cifar10 --respace_steps 100 51 | ``` 52 | 53 | 54 | 55 | ## Evaluation 56 | 57 | Sample 10K-50K images following the previous section and evaluate image quality with tools like [torch-fidelity](https://github.com/toshas/torch-fidelity), [pytorch-fid](https://github.com/mseitzer/pytorch-fid), [clean-fid](https://github.com/GaParmar/clean-fid), etc. 58 | 59 | 60 | 61 | ## Results 62 | 63 | **FID and IS on CIFAR-10 32x32**: 64 | 65 | All the metrics are evaluated on 50K samples using [torch-fidelity](https://torch-fidelity.readthedocs.io/en/latest/index.html) library. 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |
| eta | timesteps | FID ↓ | IS ↑ |
| :-: | :--------: | :----: | :-------------: |
| 0.0 | 1000 | 4.1892 | 9.0626 ± 0.1093 |
| 0.0 | 100 (10x faster) | 6.0508 | 8.8424 ± 0.0862 |
| 0.0 | 50 (20x faster) | 7.7011 | 8.7076 ± 0.1021 |
| 0.0 | 20 (50x faster) | 11.6506 | 8.4744 ± 0.0879 |
| 0.0 | 10 (100x faster) | 18.9559 | 8.0852 ± 0.1137 |
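For reference, the role of `eta` can be seen from a single DDIM update: `eta = 0` gives the deterministic sampler used in the table above, while `eta = 1` recovers a DDPM-like stochastic step. The following is a minimal sketch of one update (not the repo's `diffusions/ddim.py`), where `alpha_bar` denotes the cumulative product of `1 - beta`:

```python
import torch

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta: float = 0.0):
    """One DDIM update x_t -> x_{t-1} given the model's noise prediction (illustrative sketch)."""
    # predicted clean image x_0
    pred_x0 = (x_t - (1 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()
    # eta controls how much fresh noise is injected
    sigma = eta * ((1 - alpha_bar_prev) / (1 - alpha_bar_t)).sqrt() * (1 - alpha_bar_t / alpha_bar_prev).sqrt()
    # deterministic direction pointing towards x_t
    dir_xt = (1 - alpha_bar_prev - sigma ** 2).sqrt() * eps_pred
    noise = sigma * torch.randn_like(x_t) if eta > 0 else 0.0
    return alpha_bar_prev.sqrt() * pred_x0 + dir_xt + noise
```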
101 | 102 | 103 | 104 | **Sample with fewer steps**: 105 | 106 |

107 | 108 |

109 | 110 | From top to bottom: 10 steps, 50 steps, 100 steps and 1000 steps. Using fewer steps leads to blurrier results, but the human eye can hardly distinguish between the 50/100-step samples and the 1000-step ones. 111 | 112 | 113 | 114 | **Spherical linear interpolation (slerp) between two samples (sample with 100 steps)**: 115 | 116 |

117 | 118 |

119 | 120 |

121 | 122 |

123 | 124 | 125 | **Reconstruction (sample with 100 steps)**: 126 | 127 |

128 | 129 |

130 | 131 | In each pair, the image on the left is a real image from the dataset, and the image on the right is its reconstruction produced by DDIM inversion followed by DDIM sampling. 132 | 133 | 134 | 135 | **Reconstruction (sample with 1000 steps)**: 136 | 137 |

138 | 139 |

140 | -------------------------------------------------------------------------------- /docs/DDPM-IP.md: -------------------------------------------------------------------------------- 1 | # DDPM-IP 2 | 3 | > Ning, Mang, Enver Sangineto, Angelo Porrello, Simone Calderara, and Rita Cucchiara. "Input Perturbation Reduces Exposure Bias in Diffusion Models." *arXiv preprint arXiv:2301.11706* (2023). 4 | 5 | 6 | 7 | ## Training 8 | 9 | Almost the same as DDPM (see [doc](./DDPM.md)), except using `diffusions.ddpm_ip.DDPM_IP` instead of `diffusions.ddpm.DDPM`. 10 | 11 | For example, to train on CIFAR-10 with default settings: 12 | 13 | ```shell 14 | accelerate-launch scripts/train_ddpm.py -c ./configs/ddpm_cifar10.yaml --diffusion.target diffusions.ddpm_ip.DDPM_IP 15 | ``` 16 | 17 | 18 | 19 | ## Sampling & Evaluation 20 | 21 | Exactly the same as DDPM, refer to [doc](./DDPM.md) for more information. 22 | 23 | 24 | 25 | ## Results 26 | 27 | **FID and IS on CIFAR-10 32x32**: 28 | 29 | All the metrics are evaluated on 50K samples using [torch-fidelity](https://torch-fidelity.readthedocs.io/en/latest/index.html) library. 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 |
| Type of variance | timesteps | FID ↓ | IS ↑ |
| :--------------: | :-------: | :----: | :--------------: |
| fixed-large | 1000 | 3.2497 | 9.4885 ± 0.09244 |
| fixed-large | 100 | 46.7994 | 8.5720 ± 0.0917 |
| fixed-large | 50 | 87.1883 | 6.1429 ± 0.0630 |
| fixed-large | 10 | 268.1108 | 1.5842 ± 0.0055 |
| fixed-small | 1000 | 4.4868 | 9.1092 ± 0.1025 |
| fixed-small | 100 | 9.2460 | 8.7068 ± 0.0813 |
| fixed-small | 50 | 12.7965 | 8.4902 ± 0.0701 |
| fixed-small | 10 | 35.5062 | 7.3680 ± 0.1092 |
81 | 82 | 83 | The results are substantially better than DDPM for fixed-small variance, but not for fixed-large variance. 84 | 85 | -------------------------------------------------------------------------------- /docs/ILVR.md: -------------------------------------------------------------------------------- 1 | # ILVR 2 | 3 | Iterative Latent Variable Refinement (ILVR). 4 | 5 | > Choi, Jooyoung, Sungwon Kim, Yonghyun Jeong, Youngjune Gwon, and Sungroh Yoon. “ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models.” In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp. 14347-14356. IEEE, 2021. 6 | 7 | 8 | 9 | ## Sampling 10 | 11 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 12 | 13 | ```shell 14 | accelerate-launch scripts/sample_ilvr.py -c CONFIG \ 15 | --weights WEIGHTS \ 16 | --input_dir INPUT_DIR \ 17 | --save_dir SAVE_DIR \ 18 | [--seed SEED] \ 19 | [--var_type VAR_TYPE] \ 20 | [--respace_type RESPACE_TYPE] \ 21 | [--respace_steps RESPACE_STEPS] \ 22 | [--downsample_factor DOWNSAMPLE_FACTOR] \ 23 | [--interp_method {cubic,lanczos2,lanczos3,linear,box}] \ 24 | [--batch_size BATCH_SIZE] 25 | ``` 26 | 27 | Basic arguments: 28 | 29 | - `-c CONFIG`: path to the configuration file. 30 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 31 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved. 32 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 33 | - `--downsample_factor DOWNSAMPLE_FACTOR`: higher factor leads to more diverse results. 34 | 35 | Advanced arguments: 36 | 37 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 38 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 39 | 40 | See more details by running `python sample_ilvr.py -h`. 41 | 42 | 43 | 44 | ## Notes 45 | 46 | Using correct image resizing methods ([ResizeRight](https://github.com/assafshocher/ResizeRight)) is **CRUCIAL**! The default resizing functions in PyTorch (`torch.nn.functional.interpolate`) will damage the results. 47 | 48 | 49 | 50 | ## Results 51 | 52 | **Pretrained CelebA-HQ 256x256**: 53 | 54 | Note: All the images are sampled with 50 DDPM steps. 55 | 56 |
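For reference, the conditioning step that ILVR applies after every denoising update can be sketched as below, where `phi` is the low-pass filter obtained by downsampling and upsampling by `--downsample_factor` (built with ResizeRight, as stressed in the Notes above). Helper names are hypothetical; this is not the repo's `diffusions/guidance/ilvr.py`.

```python
def ilvr_refine(x_prev, ref_image, t_prev, phi, q_sample):
    """Replace the low-frequency content of the proposed sample with that of the noised reference.

    phi:      low-pass filter (resize down by the chosen factor, then back up).
    q_sample: forward diffusion q(x_t | x_0), used to noise the reference to timestep t_prev.
    """
    y_prev = q_sample(ref_image, t_prev)        # reference at the same noise level
    return x_prev - phi(x_prev) + phi(y_prev)   # keep high frequencies of x_prev, low frequencies of y_prev
```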

57 | 58 |

59 | -------------------------------------------------------------------------------- /docs/Mask Guidance.md: -------------------------------------------------------------------------------- 1 | # Mask Guidance 2 | 3 | Mask Guidance is a technique to fill the masked area in an input image with a pretrained diffusion model. It was first proposed in [1] and further developed in [2], [3], etc. for image inpainting. 4 | 5 | Directly applying mask guidance may lead to inconsistent semantic between masked and unmasked areas. To overcome this problem, RePaint[3] proposed a resampling strategy, which goes forward and backward on the Markov chain from time to time. 6 | 7 | > [1]. Song, Yang, and Stefano Ermon. “Generative modeling by estimating gradients of the data distribution.” 8 | > Advances in neural information processing systems 32 (2019). 9 | > 10 | > [2]. Avrahami, Omri, Dani Lischinski, and Ohad Fried. “Blended diffusion for text-driven editing of natural 11 | > images.” In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 18208 12 | > -18218. 2022. 13 | > 14 | > [3]. Lugmayr, Andreas, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, and Luc Van Gool. “Repaint: 15 | > Inpainting using denoising diffusion probabilistic models.” In Proceedings of the IEEE/CVF Conference on 16 | > Computer Vision and Pattern Recognition, pp. 11461-11471. 2022. 17 | 18 | 19 | 20 | ## Sampling 21 | 22 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 23 | 24 | ```shell 25 | accelerate-launch scripts/sample_mask_guidance.py -c CONFIG \ 26 | --weights WEIGHTS \ 27 | --input_dir INPUT_DIR \ 28 | --save_dir SAVE_DIR \ 29 | [--seed SEED] \ 30 | [--var_type VAR_TYPE] \ 31 | [--respace_type RESPACE_TYPE] \ 32 | [--respace_steps RESPACE_STEPS] \ 33 | [--resample] \ 34 | [--resample_r RESAMPLE_R] \ 35 | [--resample_j RESAMPLE_J] \ 36 | [--batch_size BATCH_SIZE] 37 | ``` 38 | 39 | Basic arguments: 40 | 41 | - `-c CONFIG`: path to the configuration file. 42 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 43 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved. 44 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 45 | - `--resample`: use the resample strategy proposed in RePaint paper[3]. This strategy has two hyperparameters: 46 | - `--resample_r RESAMPLE_R`: number of resampling. 47 | - `--resample_j RESAMPLE_J`: jump lengths. 48 | 49 | Advanced arguments: 50 | 51 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 52 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 53 | 54 | See more details by running `python sample_mask_guidance.py -h`. 55 | 56 | 57 | 58 | ## Results 59 | 60 | **ImageNet 256x256** with pretrained model from [openai/guided-diffusion](https://github.com/openai/guided-diffusion): 61 | 62 |
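For reference, the per-step composition described above can be sketched as follows: the known region is obtained by noising the input image to the current timestep, the masked region comes from the model, and (with `--resample`) this step is repeated `r` times with forward jumps of length `j`. Helper names are hypothetical; this is not the repo's `diffusions/guidance/mask_guidance.py`.

```python
def mask_guided_step(x_known, mask, x_prev_model, t_prev, q_sample):
    """Compose known and inpainted regions at one reverse step.

    x_known:      input image providing the known pixels.
    mask:         1 where pixels are known, 0 where they must be filled in.
    x_prev_model: x_{t-1} proposed by the diffusion model.
    q_sample:     forward diffusion q(x_t | x_0), used to noise the known pixels to t_prev.
    """
    x_prev_known = q_sample(x_known, t_prev)
    return mask * x_prev_known + (1 - mask) * x_prev_model
```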

63 | 64 |

65 | 66 | Notes: 67 | 68 | - All the images are sampled with 50 DDPM steps. 69 | - Jump length $j$ is fixed to 10. 70 | - $r=1$ is equivalent to the original DDPM sampling (w/o resampling). 71 | 72 | -------------------------------------------------------------------------------- /docs/SDEdit.md: -------------------------------------------------------------------------------- 1 | # SDEdit 2 | 3 | > Meng, Chenlin, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan Zhu, and Stefano Ermon. "Sdedit: Guided image synthesis and editing with stochastic differential equations." In *International Conference on Learning Representations*. 2021. 4 | 5 | 6 | 7 | ## Sampling 8 | 9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms. 10 | 11 | ```shell 12 | accelerate-launch scripts/sample_sdedit.py -c CONFIG \ 13 | --weights WEIGHTS \ 14 | --input_dir INPUT_DIR \ 15 | --save_dir SAVE_DIR \ 16 | --edit_steps EDIT_STEPS \ 17 | [--seed SEED] \ 18 | [--var_type VAR_TYPE] \ 19 | [--respace_type RESPACE_TYPE] \ 20 | [--respace_steps RESPACE_STEPS] \ 21 | [--batch_size BATCH_SIZE] 22 | ``` 23 | 24 | Basic arguments: 25 | 26 | - `-c CONFIG`: path to the configuration file. 27 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file. 28 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved. 29 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved. 30 | - `--edit_steps EDIT_STEPS`: number of edit steps. Controls realism-faithfulness trade-off. 31 | 32 | Advanced arguments: 33 | 34 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps. 35 | - `--batch_size BATCH_SIZE`: Batch size on each process. Sample by batch is faster, so set it as large as possible to fully utilize your devices. 36 | 37 | See more details by running `python sample_sdedit.py -h`. 38 | 39 | 40 | 41 | ## Results 42 | 43 | **LSUN-Church 256x256** with pretrained model from [pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion): 44 | 45 |
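For reference, SDEdit noises the input image up to an intermediate timestep `t0` determined by `--edit_steps` and then runs the reverse process from `t0`: fewer edit steps stay faithful to the input, more edit steps produce more realistic but less faithful results. A minimal sketch with a hypothetical `denoise_from` helper (not the repo's `sample_sdedit.py`):

```python
import torch

@torch.no_grad()
def sdedit(x_input, edit_steps, alphas_bar, denoise_from):
    """Perturb the input to timestep t0 = edit_steps, then denoise from t0 down to 0."""
    t0 = edit_steps
    noise = torch.randn_like(x_input)
    x_t0 = alphas_bar[t0].sqrt() * x_input + (1 - alphas_bar[t0]).sqrt() * noise
    return denoise_from(x_t0, t0)
```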

46 | 47 |

48 | -------------------------------------------------------------------------------- /docs/Samplers.md: -------------------------------------------------------------------------------- 1 | # Samplers: Fidelity-Speed Visualization 2 | 3 | Once a diffusion model is trained, we can use different samplers (SDE / ODE solvers) to generate samples. Currently, this repo supports the following samplers: 4 | 5 | - DDPM 6 | - DDIM 7 | - Euler 8 | 9 | We can choose the number of steps for the samplers to generate samples. Generally speaking, the more steps we use, the better the fidelity of the samples. However, the speed of the sampler also decreases as the number of steps increases. Therefore, it is important to choose the right number of steps to balance the trade-off between fidelity and speed. 10 | 11 | The table and figure below show the trade-off between fidelity and speed of different samplers, based on the same model trained on CIFAR-10 following the standard DDPM. All the metrics are evaluated on 50K samples using [torch-fidelity](https://torch-fidelity.readthedocs.io/en/latest/index.html) library. 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 |
samplerStepsNFEFID ↓IS ↑
DDPM (fixed-large)100010003.04599.4515 ± 0.1179
10010046.54548.7223 ± 0.0923
505085.22216.3863 ± 0.0894
2020183.34682.6885 ± 0.0176
1010266.75401.5870 ± 0.0092
DDPM (fixed-small)100010005.37279.0118 ± 0.0968
10010011.21918.6237 ± 0.0921
505015.04718.4077 ± 0.1623
202024.51317.9957 ± 0.1067
101041.04797.1373 ± 0.0801
DDIM (eta=0)100010004.18929.0626 ± 0.1093
1001006.05088.8424 ± 0.0862
50507.70118.7076 ± 0.1021
202011.65068.4744 ± 0.0879
101018.95598.0852 ± 0.1137
Euler100010004.20999.0678 ± 0.1191
1001006.04698.8511 ± 0.1054
50507.67708.7217 ± 0.1122
202011.66818.4362 ± 0.1151
101018.76988.0287 ± 0.0781
Heun5009994.00469.0509 ± 0.1475
50993.46879.2595 ± 0.1323
25495.87679.4325 ± 0.1308
101929.60888.4687 ± 0.0864
5982.05865.3521 ± 0.0646
177 | 178 | 179 | 180 |

181 | 182 |

183 | 184 | 185 | 186 | Notes: 187 | 188 | - DDPM (fixed-small) is equivalent to DDIM(η=1). 189 | - DDPM (fixed-large) performs better than DDPM (fixed-small) with 1000 steps, but degrades drastically as the number of steps decreases. If you check on the samples from DDPM (fixed-large) (<= 100 steps), you'll find that they still contain noticeable noises. 190 | - Euler sampler and DDIM (η=0) have almost the same performance. 191 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ema import EMA 2 | from .unet import UNet 3 | from .unet_categorial_adagn import UNetCategorialAdaGN 4 | -------------------------------------------------------------------------------- /models/adm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/adm/__init__.py -------------------------------------------------------------------------------- /models/adm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 
89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /models/adm/readme.md: -------------------------------------------------------------------------------- 1 | The ADM UNet architecture proposed in [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233). 2 | 3 | Codes are copied and adapted from https://github.com/openai/guided-diffusion. 
4 | -------------------------------------------------------------------------------- /models/adm/unet_combined.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .unet import UNetModel 4 | 5 | 6 | class UNetCombined(nn.Module): 7 | """ Combines a conditional-only model and an unconditional-only model into a single nn.Module. 8 | 9 | The guided diffusion models proposed by OpenAI are trained to be either conditional or unconditional, 10 | leading to difficulties if we want to use their pretrained models in classifier-free guidance. This 11 | class wraps a conditional model and an unconditional model, and decides which one to use based on the 12 | input class label. 13 | 14 | """ 15 | def __init__(self, *args, **kwargs): 16 | super().__init__() 17 | assert kwargs.get('num_classes') is not None 18 | self.unet_cond = UNetModel(*args, **kwargs) 19 | kwargs_uncond = kwargs.copy() 20 | kwargs_uncond.update({'num_classes': None}) 21 | self.unet_uncond = UNetModel(*args, **kwargs_uncond) 22 | 23 | def forward(self, x, timesteps, y=None): 24 | unet = self.unet_uncond if y is None else self.unet_cond 25 | return unet(x, timesteps, y) 26 | 27 | def combine_weights(self, cond_path, uncond_path, save_path): 28 | ckpt_cond = torch.load(cond_path, map_location='cpu') 29 | ckpt_uncond = torch.load(uncond_path, map_location='cpu') 30 | self.unet_cond.load_state_dict(ckpt_cond) 31 | self.unet_uncond.load_state_dict(ckpt_uncond) 32 | torch.save(self.state_dict(), save_path) 33 | -------------------------------------------------------------------------------- /models/base_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | 6 | class BaseLatent(nn.Module): 7 | def __init__(self, scale_factor: float = 1.0): 8 | super().__init__() 9 | self.register_buffer('scale_factor', torch.tensor(scale_factor)) 10 | self.device = self.scale_factor.device 11 | 12 | def to(self, *args, **kwargs): 13 | super().to(*args, **kwargs) 14 | self.device = self.scale_factor.device 15 | return self 16 | 17 | def forward(self, x: Tensor, timesteps: Tensor): 18 | raise NotImplementedError 19 | 20 | def encode_latent(self, x: Tensor): 21 | raise NotImplementedError 22 | 23 | def decode_latent(self, z: Tensor): 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /models/dit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/dit/__init__.py -------------------------------------------------------------------------------- /models/dit/autoencoder.py: -------------------------------------------------------------------------------- 1 | import diffusers 2 | 3 | 4 | def AutoEncoderKL(from_pretrained: str): 5 | return diffusers.AutoencoderKL.from_pretrained(from_pretrained) 6 | -------------------------------------------------------------------------------- /models/dit/dit.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Any 2 | from omegaconf import OmegaConf 3 | 4 | from torch import Tensor 5 | 6 | from ..base_latent import BaseLatent 7 | from utils.misc import instantiate_from_config 8 | 9 | 10 | class DiT(BaseLatent): 11 | def __init__( 12 | self, 13 | 
vae_config: OmegaConf, 14 | vit_config: OmegaConf, 15 | scale_factor: float = 0.18215, 16 | ): 17 | super().__init__(scale_factor=scale_factor) 18 | 19 | self.vae = instantiate_from_config(vae_config) 20 | self.vit = instantiate_from_config(vit_config) 21 | 22 | def decode_latent(self, z: Tensor): 23 | z = 1. / self.scale_factor * z 24 | return self.vae.decode(z).sample 25 | 26 | def vit_forward(self, x: Tensor, t: Tensor, y: Tensor): 27 | return self.vit(x, t, y) 28 | 29 | def forward(self, x: Tensor, timesteps: Tensor, y: Tensor = None): 30 | return self.vit_forward(x, timesteps, y) 31 | 32 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): 33 | self.vit.load_state_dict(state_dict, strict=strict, assign=assign) 34 | -------------------------------------------------------------------------------- /models/dit/readme.md: -------------------------------------------------------------------------------- 1 | The DiT architecture proposed in [Scalable Diffusion Models with Transformers](http://arxiv.org/abs/2212.09748). 2 | 3 | Codes are copied and adapted from https://github.com/facebookresearch/DiT. 4 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class EMA: 8 | """Exponential moving average of model parameters. 9 | 10 | References: 11 | - https://github.com/huggingface/diffusers/blob/main/src/diffusers/training_utils.py#L76 12 | - https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 13 | - https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/ema.py#L5 14 | - https://github.com/lucidrains/ema-pytorch 15 | 16 | """ 17 | def __init__( 18 | self, 19 | parameters: Iterable[nn.Parameter], 20 | decay: float = 0.9999, 21 | gradual: bool = True, 22 | ): 23 | """ 24 | Args: 25 | parameters: Iterable of parameters, typically from `model.parameters()`. 26 | decay: The decay factor for exponential moving average. 27 | gradual: Whether to a gradually increasing decay factor. 28 | 29 | """ 30 | super().__init__() 31 | self.decay = decay 32 | self.gradual = gradual 33 | 34 | self.num_updates = 0 35 | self.shadow = [param.detach().clone() for param in parameters] 36 | self.backup = [] 37 | 38 | def get_decay(self): 39 | if self.gradual: 40 | return min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 41 | else: 42 | return self.decay 43 | 44 | @torch.no_grad() 45 | def update(self, parameters: Iterable[nn.Parameter]): 46 | self.num_updates += 1 47 | decay = self.get_decay() 48 | for s_param, param in zip(self.shadow, parameters): 49 | if param.requires_grad: 50 | s_param.sub_((1. 
- decay) * (s_param - param)) 51 | else: 52 | s_param.copy_(param) 53 | 54 | def apply_shadow(self, parameters: Iterable[nn.Parameter]): 55 | assert len(self.backup) == 0, 'backup is not empty' 56 | for s_param, param in zip(self.shadow, parameters): 57 | self.backup.append(param.detach().cpu().clone()) 58 | param.data.copy_(s_param.data) 59 | 60 | def restore(self, parameters: Iterable[nn.Parameter]): 61 | assert len(self.backup) > 0, 'backup is empty' 62 | for b_param, param in zip(self.backup, parameters): 63 | param.data.copy_(b_param.to(param.device).data) 64 | self.backup = [] 65 | 66 | def state_dict(self): 67 | return dict( 68 | decay=self.decay, 69 | shadow=self.shadow, 70 | num_updates=self.num_updates, 71 | ) 72 | 73 | def load_state_dict(self, state_dict): 74 | self.decay = state_dict['decay'] 75 | self.shadow = state_dict['shadow'] 76 | self.num_updates = state_dict['num_updates'] 77 | 78 | def to(self, device): 79 | self.shadow = [s_param.to(device) for s_param in self.shadow] 80 | 81 | 82 | def _test(): 83 | # initialize to 0 84 | model = nn.Sequential(nn.Linear(5, 1)) 85 | for p in model[0].parameters(): 86 | p.data.fill_(0) 87 | ema = EMA(model.parameters(), decay=0.9, gradual=False) 88 | print(model.state_dict()) # 0 89 | print(ema.state_dict()['shadow']) # 0 90 | print() 91 | 92 | # update the model to 1 93 | for p in model[0].parameters(): 94 | p.data.fill_(1) 95 | ema.update(model.parameters()) 96 | print(model.state_dict()) # 1 97 | print(ema.state_dict()['shadow']) # 0.9 * 0 + 0.1 * 1 = 0.1 98 | print() 99 | 100 | # update the model to 2 101 | for p in model[0].parameters(): 102 | p.data.fill_(2) 103 | ema.update(model.parameters()) 104 | print(model.state_dict()) # 2 105 | print(ema.state_dict()['shadow']) # 0.9 * 0.1 + 0.1 * 2 = 0.29 106 | print() 107 | 108 | # apply shadow 109 | ema.apply_shadow(model.parameters()) 110 | print(model.state_dict()) # 0.29 111 | print(ema.state_dict()['shadow']) # 0.29 112 | print() 113 | 114 | # restore 115 | ema.restore(model.parameters()) 116 | print(model.state_dict()) # 2 117 | print(ema.state_dict()['shadow']) # 0.29 118 | 119 | 120 | if __name__ == '__main__': 121 | _test() 122 | -------------------------------------------------------------------------------- /models/mdt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/mdt/__init__.py -------------------------------------------------------------------------------- /models/mdt/autoencoder.py: -------------------------------------------------------------------------------- 1 | import diffusers 2 | 3 | 4 | def AutoEncoderKL(from_pretrained: str): 5 | return diffusers.AutoencoderKL.from_pretrained(from_pretrained) 6 | -------------------------------------------------------------------------------- /models/mdt/mdt.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Any 2 | from omegaconf import OmegaConf 3 | 4 | from torch import Tensor 5 | 6 | from ..base_latent import BaseLatent 7 | from utils.misc import instantiate_from_config 8 | 9 | 10 | class MDT(BaseLatent): 11 | def __init__( 12 | self, 13 | vae_config: OmegaConf, 14 | vit_config: OmegaConf, 15 | scale_factor: float = 0.18215, 16 | ): 17 | super().__init__(scale_factor=scale_factor) 18 | 19 | self.vae = instantiate_from_config(vae_config) 20 | self.vit = instantiate_from_config(vit_config) 21 | 22 | 
def decode_latent(self, z: Tensor): 23 | z = 1. / self.scale_factor * z 24 | return self.vae.decode(z).sample 25 | 26 | def vit_forward(self, x: Tensor, t: Tensor, y: Tensor): 27 | return self.vit(x, t, y) 28 | 29 | def forward(self, x: Tensor, timesteps: Tensor, y: Tensor = None): 30 | return self.vit_forward(x, timesteps, y) 31 | 32 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): 33 | self.vit.load_state_dict(state_dict, strict=strict, assign=assign) 34 | -------------------------------------------------------------------------------- /models/mdt/readme.md: -------------------------------------------------------------------------------- 1 | The MDT architecture proposed in [Masked Diffusion Transformer is a Strong Image Synthesizer](https://arxiv.org/abs/2303.14389). 2 | 3 | Codes are copied and adapted from https://github.com/sail-sg/MDT. 4 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | def init_weights(init_type=None, gain=0.02): 9 | 10 | def init_func(m): 11 | classname = m.__class__.__name__ 12 | 13 | if classname.find('BatchNorm') != -1: 14 | if hasattr(m, 'weight') and m.weight is not None: 15 | nn.init.normal_(m.weight, 1.0, gain) 16 | if hasattr(m, 'bias') and m.bias is not None: 17 | nn.init.constant_(m.bias, 0.0) 18 | 19 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 20 | if init_type == 'normal': 21 | nn.init.normal_(m.weight, 0.0, gain) 22 | elif init_type == 'xavier': 23 | nn.init.xavier_normal_(m.weight, gain=gain) 24 | elif init_type == 'xavier_uniform': 25 | nn.init.xavier_uniform_(m.weight, gain=1.0) 26 | elif init_type == 'kaiming': 27 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') 28 | elif init_type == 'orthogonal': 29 | nn.init.orthogonal_(m.weight, gain=gain) 30 | elif init_type is None: 31 | m.reset_parameters() 32 | else: 33 | raise ValueError(f'invalid initialization method: {init_type}.') 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | nn.init.constant_(m.bias, 0.0) 36 | 37 | return init_func 38 | 39 | 40 | class SinusoidalPosEmb(nn.Module): 41 | def __init__(self, dim: int): 42 | super().__init__() 43 | self.dim = dim 44 | 45 | def forward(self, X: Tensor): 46 | """ 47 | Args: 48 | X (Tensor): [bs] 49 | Returns: 50 | Sinusoidal embeddings of shape [bs, dim] 51 | """ 52 | half_dim = self.dim // 2 53 | embed = math.log(10000) / (half_dim - 1) 54 | embed = torch.exp(torch.arange(half_dim, device=X.device) * -embed) 55 | embed = X[:, None] * embed[None, :] 56 | embed = torch.cat((embed.sin(), embed.cos()), dim=-1) 57 | return embed 58 | 59 | 60 | def Upsample(in_channels: int, out_channels: int, use_conv: bool = True): 61 | if use_conv: 62 | return nn.Sequential( 63 | nn.Upsample(scale_factor=2, mode='nearest'), 64 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 65 | ) 66 | else: 67 | return nn.Upsample(scale_factor=2, mode='nearest') 68 | 69 | 70 | def Downsample(in_channels: int, out_channels: int, use_conv: bool = True): 71 | if use_conv: 72 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 73 | else: 74 | return nn.AvgPool2d(kernel_size=2, stride=2) 75 | 76 | 77 | class SelfAttentionBlock(nn.Module): 78 | def __init__(self, 
dim: int, n_heads: int = 1, groups: int = 32): 79 | super().__init__() 80 | assert dim % n_heads == 0 81 | self.n_heads = n_heads 82 | self.norm = nn.GroupNorm(groups, dim) 83 | self.q = nn.Conv2d(dim, dim, kernel_size=1) 84 | self.k = nn.Conv2d(dim, dim, kernel_size=1) 85 | self.v = nn.Conv2d(dim, dim, kernel_size=1) 86 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 87 | self.scale = (dim // n_heads) ** -0.5 88 | 89 | def forward(self, X: Tensor, return_attn_map: bool = False): 90 | bs, C, H, W = X.shape 91 | normX = self.norm(X) 92 | q = self.q(normX).view(bs * self.n_heads, -1, H*W) 93 | k = self.k(normX).view(bs * self.n_heads, -1, H*W) 94 | v = self.v(normX).view(bs * self.n_heads, -1, H*W) 95 | q = q * self.scale 96 | attn = torch.bmm(q.permute(0, 2, 1), k).softmax(dim=-1) 97 | output = torch.bmm(v, attn.permute(0, 2, 1)).view(bs, -1, H, W) 98 | output = self.proj(output) 99 | if not return_attn_map: 100 | return output + X 101 | else: 102 | return output + X, attn.view(bs, self.n_heads, H*W, H*W) 103 | 104 | 105 | class AdaGN(nn.Module): 106 | def __init__(self, num_groups: int, num_channels: int, embed_dim: int): 107 | super().__init__() 108 | self.gn = nn.GroupNorm(num_groups, num_channels) 109 | self.proj = nn.Sequential( 110 | nn.SiLU(), 111 | nn.Linear(embed_dim, num_channels * 2), 112 | ) 113 | 114 | def forward(self, X: Tensor, embed: Tensor): 115 | """ 116 | Args: 117 | X (Tensor): [bs, C, H, W] 118 | embed (Tensor): [bs, embed_dim] 119 | """ 120 | ys, yb = torch.chunk(self.proj(embed), 2, dim=-1) 121 | ys = ys[:, :, None, None] 122 | yb = yb[:, :, None, None] 123 | return self.gn(X) * (1 + ys) + yb 124 | -------------------------------------------------------------------------------- /models/pesser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/pesser/__init__.py -------------------------------------------------------------------------------- /models/pesser/readme.md: -------------------------------------------------------------------------------- 1 | The DDPM UNet architecture proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239). 2 | 3 | The original codes are in TensorFlow, so we use the PyTorch version developed by [pesser](https://github.com/pesser). 4 | 5 | Codes are copied and adapted from https://github.com/pesser/pytorch_diffusion. 
6 | -------------------------------------------------------------------------------- /models/sdxl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/sdxl/__init__.py -------------------------------------------------------------------------------- /models/sdxl/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=(1, 2, 3)): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /models/sdxl/readme.md: -------------------------------------------------------------------------------- 1 | The StableDiffusion XL model architectures. 2 | 3 | Codes are copied and adapted from https://github.com/Stability-AI/generative-models. Files are reorganized. Un-used functions and classes are removed. Some PEP8 warnings are resolved. 4 | -------------------------------------------------------------------------------- /models/sdxl/regularizers.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from .distributions import DiagonalGaussianDistribution 9 | 10 | 11 | class AbstractRegularizer(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 16 | raise NotImplementedError() 17 | 18 | @abstractmethod 19 | def get_trainable_parameters(self) -> Any: 20 | raise NotImplementedError() 21 | 22 | 23 | class IdentityRegularizer(AbstractRegularizer): 24 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 25 | return z, dict() 26 | 27 | def get_trainable_parameters(self) -> Any: 28 | yield from () 29 | 30 | 31 | def measure_perplexity( 32 | predicted_indices: torch.Tensor, num_centroids: int 33 | ) -> Tuple[torch.Tensor, torch.Tensor]: 34 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 35 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 36 | encodings = ( 37 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 38 | ) 39 | avg_probs = encodings.mean(0) 40 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 41 | cluster_use = torch.sum(avg_probs > 0) 42 | return perplexity, cluster_use 43 | 44 | 45 | class DiagonalGaussianRegularizer(AbstractRegularizer): 46 | def __init__(self, sample: bool = True): 47 | super().__init__() 48 | self.sample = sample 49 | 50 | def get_trainable_parameters(self) -> Any: 51 | yield from () 52 | 53 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 54 | log = dict() 55 | posterior = DiagonalGaussianDistribution(z) 56 | if self.sample: 57 | z = posterior.sample() 58 | else: 59 | z = posterior.mode() 60 | kl_loss = posterior.kl() 61 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 62 | log["kl_loss"] = kl_loss 63 | return z, log 64 | -------------------------------------------------------------------------------- /models/sdxl/stablediffusion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Mapping, Any, Dict 2 | from omegaconf import OmegaConf 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from ..base_latent import BaseLatent 8 | from utils.misc import instantiate_from_config 9 | 10 | 11 | class StableDiffusion(BaseLatent): 12 | def __init__( 13 | self, 14 | conditioner_config: OmegaConf, 15 | vae_config: OmegaConf, 16 | unet_config: OmegaConf, 17 | scale_factor: float = 0.13025, 18 | low_vram_shift_enabled: bool = False, 19 | ): 20 | super().__init__(scale_factor=scale_factor) 21 | 22 | self.conditioner = instantiate_from_config(conditioner_config) 23 | self.vae = instantiate_from_config(vae_config) 24 | self.unet = instantiate_from_config(unet_config) 25 | 26 | self.low_vram_shift_enabled = low_vram_shift_enabled 27 | 28 | def encode_latent(self, x: Tensor): 29 | if self.low_vram_shift_enabled: 30 | self.conditioner.to('cpu') 31 | self.unet.to('cpu') 32 | self.vae.to(self.device) 33 | torch.cuda.empty_cache() 34 | z = self.vae.encode(x) 35 | return self.scale_factor * z 36 | 37 | def decode_latent(self, z: Tensor): 38 | if self.low_vram_shift_enabled: 39 | self.conditioner.to('cpu') 40 | self.unet.to('cpu') 41 | self.vae.to(self.device) 42 | torch.cuda.empty_cache() 43 | z = 1. 
/ self.scale_factor * z 44 | return self.vae.decode(z) 45 | 46 | def conditioner_forward(self, text: List[str], H: int, W: int): 47 | if self.low_vram_shift_enabled: 48 | self.vae.to('cpu') 49 | self.unet.to('cpu') 50 | self.conditioner.to(self.device) 51 | torch.cuda.empty_cache() 52 | batch = dict( 53 | txt=text, 54 | original_size_as_tuple=torch.tensor([1024, 1024], device=self.device).repeat(len(text), 1), 55 | crop_coords_top_left=torch.tensor([0, 0], device=self.device).repeat(len(text), 1), 56 | target_size_as_tuple=torch.tensor([H, W], device=self.device).repeat(len(text), 1), 57 | ) 58 | return self.conditioner(batch) 59 | 60 | def unet_forward(self, x: Tensor, timesteps: Tensor, context: Tensor, y: Tensor): 61 | if self.low_vram_shift_enabled: 62 | self.vae.to('cpu') 63 | self.conditioner.to('cpu') 64 | self.unet.to(self.device) 65 | torch.cuda.empty_cache() 66 | return self.unet(x, timesteps=timesteps, context=context, y=y) 67 | 68 | def forward( 69 | self, x: Tensor, timesteps: Tensor, condition_dict: Dict = None, 70 | text: List[str] = None, H: int = None, W: int = None, 71 | ): 72 | if condition_dict is None: 73 | if text is None or H is None or W is None: 74 | raise ValueError('text, H and W must be provided when `condition_dict` is not provided.') 75 | condition_dict = self.conditioner_forward(text, H, W) 76 | context = condition_dict.get('crossattn') 77 | y = condition_dict.get('vector') 78 | x = self.unet_forward(x, timesteps=timesteps, context=context, y=y) 79 | return x 80 | 81 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): 82 | # state_dict_conditioner = {k[12:]: v for k, v in state_dict.items() if k.startswith('conditioner.')} 83 | state_dict_vae = {k[18:]: v for k, v in state_dict.items() if k.startswith('first_stage_model.')} 84 | state_dict_unet = {k[22:]: v for k, v in state_dict.items() if k.startswith('model.diffusion_model.')} 85 | # self.conditioner.load_state_dict(state_dict_conditioner, strict=strict, assign=assign) 86 | self.vae.load_state_dict(state_dict_vae, strict=strict, assign=assign) 87 | self.unet.load_state_dict(state_dict_unet, strict=strict, assign=assign) 88 | # del state_dict_conditioner 89 | del state_dict_vae 90 | del state_dict_unet 91 | -------------------------------------------------------------------------------- /models/stablediffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/stablediffusion/__init__.py -------------------------------------------------------------------------------- /models/stablediffusion/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | 
self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=(1, 2, 3)): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /models/stablediffusion/modules.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | from einops import repeat 14 | 15 | 16 | def checkpoint(func, inputs, params, flag): 17 | """ 18 | Evaluate a function without caching intermediate activations, allowing for 19 | reduced memory at the expense of extra compute in the backward pass. 20 | :param func: the function to evaluate. 21 | :param inputs: the argument sequence to pass to `func`. 22 | :param params: a sequence of parameters `func` depends on but does not 23 | explicitly take as arguments. 24 | :param flag: if False, disable gradient checkpointing. 
25 | """ 26 | if flag: 27 | args = tuple(inputs) + tuple(params) 28 | return CheckpointFunction.apply(func, len(inputs), *args) 29 | else: 30 | return func(*inputs) 31 | 32 | 33 | class CheckpointFunction(torch.autograd.Function): 34 | @staticmethod 35 | def forward(ctx, run_function, length, *args): 36 | ctx.run_function = run_function 37 | ctx.input_tensors = list(args[:length]) 38 | ctx.input_params = list(args[length:]) 39 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 40 | "dtype": torch.get_autocast_gpu_dtype(), 41 | "cache_enabled": torch.is_autocast_cache_enabled()} 42 | with torch.no_grad(): 43 | output_tensors = ctx.run_function(*ctx.input_tensors) 44 | return output_tensors 45 | 46 | @staticmethod 47 | def backward(ctx, *output_grads): 48 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 49 | with torch.enable_grad(), \ 50 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 51 | # Fixes a bug where the first op in run_function modifies the 52 | # Tensor storage in place, which is not allowed for detach()'d 53 | # Tensors. 54 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 55 | output_tensors = ctx.run_function(*shallow_copies) 56 | input_grads = torch.autograd.grad( 57 | output_tensors, 58 | ctx.input_tensors + ctx.input_params, 59 | output_grads, 60 | allow_unused=True, 61 | ) 62 | del ctx.input_tensors 63 | del ctx.input_params 64 | del output_tensors 65 | return (None, None) + input_grads 66 | 67 | 68 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 69 | """ 70 | Create sinusoidal timestep embeddings. 71 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 72 | These may be fractional. 73 | :param dim: the dimension of the output. 74 | :param max_period: controls the minimum frequency of the embeddings. 75 | :return: an [N x dim] Tensor of positional embeddings. 76 | """ 77 | if not repeat_only: 78 | half = dim // 2 79 | freqs = torch.exp( 80 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 81 | ).to(device=timesteps.device) 82 | args = timesteps[:, None].float() * freqs[None] 83 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 84 | if dim % 2: 85 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 86 | else: 87 | embedding = repeat(timesteps, 'b -> b d', d=dim) 88 | return embedding 89 | 90 | 91 | def zero_module(module): 92 | """ 93 | Zero out the parameters of a module and return it. 94 | """ 95 | for p in module.parameters(): 96 | p.detach().zero_() 97 | return module 98 | 99 | 100 | def scale_module(module, scale): 101 | """ 102 | Scale the parameters of a module and return it. 103 | """ 104 | for p in module.parameters(): 105 | p.detach().mul_(scale) 106 | return module 107 | 108 | 109 | def mean_flat(tensor): 110 | """ 111 | Take the mean over all non-batch dimensions. 112 | """ 113 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 114 | 115 | 116 | def normalization(channels): 117 | """ 118 | Make a standard normalization layer. 119 | :param channels: number of input channels. 120 | :return: an nn.Module for normalization. 121 | """ 122 | return GroupNorm32(32, channels) 123 | 124 | 125 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
126 | class SiLU(nn.Module): 127 | def forward(self, x): 128 | return x * torch.sigmoid(x) 129 | 130 | 131 | class GroupNorm32(nn.GroupNorm): 132 | def forward(self, x): 133 | return super().forward(x.float()).type(x.dtype) 134 | 135 | 136 | def conv_nd(dims, *args, **kwargs): 137 | """ 138 | Create a 1D, 2D, or 3D convolution module. 139 | """ 140 | if dims == 1: 141 | return nn.Conv1d(*args, **kwargs) 142 | elif dims == 2: 143 | return nn.Conv2d(*args, **kwargs) 144 | elif dims == 3: 145 | return nn.Conv3d(*args, **kwargs) 146 | raise ValueError(f"unsupported dimensions: {dims}") 147 | 148 | 149 | def linear(*args, **kwargs): 150 | """ 151 | Create a linear module. 152 | """ 153 | return nn.Linear(*args, **kwargs) 154 | 155 | 156 | def avg_pool_nd(dims, *args, **kwargs): 157 | """ 158 | Create a 1D, 2D, or 3D average pooling module. 159 | """ 160 | if dims == 1: 161 | return nn.AvgPool1d(*args, **kwargs) 162 | elif dims == 2: 163 | return nn.AvgPool2d(*args, **kwargs) 164 | elif dims == 3: 165 | return nn.AvgPool3d(*args, **kwargs) 166 | raise ValueError(f"unsupported dimensions: {dims}") 167 | -------------------------------------------------------------------------------- /models/stablediffusion/readme.md: -------------------------------------------------------------------------------- 1 | The StableDiffusion model architectures. 2 | 3 | Codes are copied and adapted from https://github.com/Stability-AI/stablediffusion. Files are reorganized. Un-used functions and classes are removed. Some PEP8 warnings are resolved. 4 | -------------------------------------------------------------------------------- /models/stablediffusion/stablediffusion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Mapping, Any 2 | from omegaconf import OmegaConf 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .distributions import DiagonalGaussianDistribution 8 | from ..base_latent import BaseLatent 9 | from utils.misc import instantiate_from_config 10 | 11 | 12 | class StableDiffusion(BaseLatent): 13 | def __init__( 14 | self, 15 | text_encoder_config: OmegaConf, 16 | vae_config: OmegaConf, 17 | unet_config: OmegaConf, 18 | scale_factor: float = 0.18215, 19 | low_vram_shift_enabled: bool = False, 20 | ): 21 | super().__init__(scale_factor=scale_factor) 22 | 23 | self.text_encoder = instantiate_from_config(text_encoder_config) 24 | self.vae = instantiate_from_config(vae_config) 25 | self.unet = instantiate_from_config(unet_config) 26 | 27 | self.low_vram_shift_enabled = low_vram_shift_enabled 28 | 29 | def encode_latent(self, x: Tensor): 30 | if self.low_vram_shift_enabled: 31 | self.text_encoder.to('cpu') 32 | self.unet.to('cpu') 33 | self.vae.to(self.device) 34 | torch.cuda.empty_cache() 35 | z = self.vae.encode(x) 36 | if isinstance(z, DiagonalGaussianDistribution): 37 | z = z.sample() 38 | return self.scale_factor * z 39 | 40 | def decode_latent(self, z: Tensor): 41 | if self.low_vram_shift_enabled: 42 | self.text_encoder.to('cpu') 43 | self.unet.to('cpu') 44 | self.vae.to(self.device) 45 | torch.cuda.empty_cache() 46 | z = 1. 
/ self.scale_factor * z 47 | return self.vae.decode(z) 48 | 49 | def text_encoder_encode(self, text: List[str]): 50 | if self.low_vram_shift_enabled: 51 | self.vae.to('cpu') 52 | self.unet.to('cpu') 53 | self.text_encoder.to(self.device) 54 | torch.cuda.empty_cache() 55 | return self.text_encoder.encode(text) 56 | 57 | def unet_forward(self, x: Tensor, timesteps: Tensor, context: Tensor): 58 | if self.low_vram_shift_enabled: 59 | self.vae.to('cpu') 60 | self.text_encoder.to('cpu') 61 | self.unet.to(self.device) 62 | torch.cuda.empty_cache() 63 | return self.unet(x, timesteps=timesteps, context=context) 64 | 65 | def forward(self, x: Tensor, timesteps: Tensor, text_embed: Tensor = None, text: List[str] = None): 66 | if text_embed is None and text is None: 67 | raise ValueError('Either `text_embed` or `text` must be provided.') 68 | if text_embed is None: 69 | text_embed = self.text_encoder_encode(text) 70 | x = self.unet_forward(x, timesteps=timesteps, context=text_embed) 71 | return x 72 | 73 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): 74 | state_dict_vae = {k[18:]: v for k, v in state_dict.items() if k.startswith('first_stage_model.')} 75 | state_dict_unet = {k[22:]: v for k, v in state_dict.items() if k.startswith('model.diffusion_model.')} 76 | self.vae.load_state_dict(state_dict_vae, strict=strict) 77 | self.unet.load_state_dict(state_dict_unet, strict=strict) 78 | del state_dict_vae 79 | del state_dict_unet 80 | -------------------------------------------------------------------------------- /models/stablediffusion/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from inspect import isfunction 3 | 4 | import torch 5 | 6 | 7 | def ismap(x): 8 | if not isinstance(x, torch.Tensor): 9 | return False 10 | return (len(x.shape) == 4) and (x.shape[1] > 3) 11 | 12 | 13 | def isimage(x): 14 | if not isinstance(x, torch.Tensor): 15 | return False 16 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 17 | 18 | 19 | def exists(x): 20 | return x is not None 21 | 22 | 23 | def default(val, d): 24 | if exists(val): 25 | return val 26 | return d() if isfunction(d) else d 27 | 28 | 29 | def count_params(model, verbose=False): 30 | total_params = sum(p.numel() for p in model.parameters()) 31 | if verbose: 32 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 33 | return total_params 34 | 35 | 36 | def instantiate_from_config(config): 37 | if "target" not in config: 38 | if config == '__is_first_stage__': 39 | return None 40 | elif config == "__is_unconditional__": 41 | return None 42 | raise KeyError("Expected key `target` to instantiate.") 43 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 44 | 45 | 46 | def get_obj_from_str(string, reload=False): 47 | module, cls = string.rsplit(".", 1) 48 | if reload: 49 | module_imp = importlib.import_module(module) 50 | importlib.reload(module_imp) 51 | return getattr(importlib.import_module(module, package=None), cls) 52 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from models.modules import SinusoidalPosEmb, SelfAttentionBlock, Downsample, Upsample 8 | 9 | 10 | class ResBlock(nn.Module): 11 | def __init__(self, in_channels: 
int, out_channels: int, embed_dim: int, dropout: float = 0.1): 12 | super().__init__() 13 | self.blk1 = nn.Sequential( 14 | nn.GroupNorm(32, in_channels), 15 | nn.SiLU(), 16 | nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1), 17 | ) 18 | self.proj = nn.Sequential( 19 | nn.SiLU(), 20 | nn.Linear(embed_dim, out_channels), 21 | ) 22 | self.blk2 = nn.Sequential( 23 | nn.GroupNorm(32, out_channels), 24 | nn.SiLU(), 25 | nn.Dropout(dropout), 26 | nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1), 27 | ) 28 | self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() 29 | 30 | def forward(self, X: Tensor, time_embed: Tensor = None): 31 | """ 32 | Args: 33 | X (Tensor): [bs, C, H, W] 34 | time_embed (Tensor): [bs, embed_dim] 35 | Returns: 36 | [bs, C', H, W] 37 | """ 38 | shortcut = self.shortcut(X) 39 | X = self.blk1(X) 40 | if time_embed is not None: 41 | X = X + self.proj(time_embed)[:, :, None, None] 42 | X = self.blk2(X) 43 | return X + shortcut 44 | 45 | 46 | class UNet(nn.Module): 47 | def __init__( 48 | self, 49 | in_channels: int = 3, 50 | out_channels: int = 3, 51 | dim: int = 128, 52 | dim_mults: List[int] = (1, 2, 2, 2), 53 | use_attn: List[int] = (False, True, False, False), 54 | num_res_blocks: int = 2, 55 | n_heads: int = 1, 56 | dropout: float = 0.1, 57 | ): 58 | super().__init__() 59 | n_stages = len(dim_mults) 60 | dims = [dim] 61 | 62 | # Time embeddings 63 | time_embed_dim = dim * 4 64 | self.time_embed = nn.Sequential( 65 | SinusoidalPosEmb(dim), 66 | nn.Linear(dim, time_embed_dim), 67 | nn.SiLU(), 68 | nn.Linear(time_embed_dim, time_embed_dim), 69 | ) 70 | 71 | # First convolution 72 | self.first_conv = nn.Conv2d(in_channels, dim, 3, stride=1, padding=1) 73 | cur_dim = dim 74 | 75 | # Down-sample blocks 76 | # Default: 32x32 -> 16x16 -> 8x8 -> 4x4 77 | self.down_blocks = nn.ModuleList([]) 78 | for i in range(n_stages): 79 | out_dim = dim * dim_mults[i] 80 | stage_blocks = nn.ModuleList([]) 81 | for j in range(num_res_blocks): 82 | stage_blocks.append(ResBlock(cur_dim, out_dim, embed_dim=time_embed_dim, dropout=dropout)) 83 | if use_attn[i]: 84 | stage_blocks.append(SelfAttentionBlock(out_dim, n_heads=n_heads)) 85 | dims.append(out_dim) 86 | cur_dim = out_dim 87 | if i < n_stages - 1: 88 | stage_blocks.append(Downsample(out_dim, out_dim)) 89 | dims.append(out_dim) 90 | self.down_blocks.append(stage_blocks) 91 | 92 | # Bottleneck block 93 | self.bottleneck_block = nn.ModuleList([ 94 | ResBlock(cur_dim, cur_dim, embed_dim=time_embed_dim, dropout=dropout), 95 | SelfAttentionBlock(cur_dim), 96 | ResBlock(cur_dim, cur_dim, embed_dim=time_embed_dim, dropout=dropout), 97 | ]) 98 | 99 | # Up-sample blocks 100 | # Default: 4x4 -> 8x8 -> 16x16 -> 32x32 101 | self.up_blocks = nn.ModuleList([]) 102 | for i in range(n_stages-1, -1, -1): 103 | out_dim = dim * dim_mults[i] 104 | stage_blocks = nn.ModuleList([]) 105 | for j in range(num_res_blocks + 1): 106 | stage_blocks.append(ResBlock(dims.pop() + cur_dim, out_dim, embed_dim=time_embed_dim, dropout=dropout)) 107 | if use_attn[i]: 108 | stage_blocks.append(SelfAttentionBlock(out_dim, n_heads=n_heads)) 109 | cur_dim = out_dim 110 | if i > 0: 111 | stage_blocks.append(Upsample(out_dim, out_dim)) 112 | self.up_blocks.append(stage_blocks) 113 | 114 | # Last convolution 115 | self.last_conv = nn.Sequential( 116 | nn.GroupNorm(32, cur_dim), 117 | nn.SiLU(), 118 | nn.Conv2d(cur_dim, out_channels, 3, stride=1, padding=1), 119 | ) 120 | 121 | def forward(self, X: Tensor, T: 
Tensor): 122 | time_embed = self.time_embed(T) 123 | X = self.first_conv(X) 124 | skips = [X] 125 | 126 | for stage_blocks in self.down_blocks: 127 | for blk in stage_blocks: # noqa 128 | if isinstance(blk, ResBlock): 129 | X = blk(X, time_embed) 130 | skips.append(X) 131 | elif isinstance(blk, SelfAttentionBlock): 132 | X = blk(X) 133 | skips[-1] = X 134 | else: # Downsample 135 | X = blk(X) 136 | skips.append(X) 137 | 138 | X = self.bottleneck_block[0](X, time_embed) 139 | X = self.bottleneck_block[1](X) 140 | X = self.bottleneck_block[2](X, time_embed) 141 | 142 | for stage_blocks in self.up_blocks: 143 | for blk in stage_blocks: # noqa 144 | if isinstance(blk, ResBlock): 145 | X = blk(torch.cat((X, skips.pop()), dim=1), time_embed) 146 | elif isinstance(blk, SelfAttentionBlock): 147 | X = blk(X) 148 | else: # Upsample 149 | X = blk(X) 150 | 151 | X = self.last_conv(X) 152 | return X 153 | 154 | 155 | def _test(): 156 | unet = UNet() 157 | X = torch.empty((10, 3, 32, 32)) 158 | T = torch.arange(10) 159 | out = unet(X, T) 160 | print(out.shape) 161 | print(sum(p.numel() for p in unet.parameters())) 162 | 163 | unet = UNet( 164 | in_channels=1, 165 | out_channels=1, 166 | dim=128, 167 | dim_mults=[1, 1, 2, 2, 4, 4], 168 | use_attn=[False, False, False, False, True, False], 169 | dropout=0.0, 170 | ) 171 | X = torch.empty((10, 1, 256, 256)) 172 | T = torch.arange(10) 173 | out = unet(X, T) 174 | print(out.shape) 175 | print(sum(p.numel() for p in unet.parameters())) 176 | 177 | 178 | if __name__ == '__main__': 179 | _test() 180 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | einops 3 | omegaconf 4 | matplotlib 5 | tensorboard 6 | numpy 7 | pandas 8 | Pillow 9 | torch~=2.1.0 10 | torchvision~=0.16.0 11 | accelerate~=0.23.0 12 | transformers~=4.37.2 13 | diffusers~=0.21.4 14 | safetensors~=0.4.2 15 | open-clip-torch 16 | timm~=0.9.12 17 | streamlit~=1.31.0 18 | -------------------------------------------------------------------------------- /scripts/sample_clip_guidance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import math 6 | import argparse 7 | from omegaconf import OmegaConf 8 | 9 | import torch 10 | import accelerate 11 | from torchvision.utils import save_image 12 | 13 | import diffusions 14 | from utils.logger import get_logger 15 | from utils.load import load_weights 16 | from utils.misc import image_norm_to_float, instantiate_from_config, amortize 17 | 18 | 19 | def get_parser(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '-c', '--config', type=str, required=True, 23 | help='Path to inference configuration file', 24 | ) 25 | parser.add_argument( 26 | '--seed', type=int, default=2022, 27 | help='Set random seed', 28 | ) 29 | parser.add_argument( 30 | '--weights', type=str, required=True, 31 | help='Path to pretrained model weights', 32 | ) 33 | parser.add_argument( 34 | '--var_type', type=str, default=None, 35 | help='Type of variance of the reverse process', 36 | ) 37 | parser.add_argument( 38 | '--respace_type', type=str, default='uniform', 39 | help='Type of respaced timestep sequence', 40 | ) 41 | parser.add_argument( 42 | '--respace_steps', type=int, default=None, 43 | help='Length of respaced timestep sequence', 44 | ) 45 | parser.add_argument( 46 
| '--text', type=str, required=True, 47 | help='Text description of the generated image', 48 | ) 49 | parser.add_argument( 50 | '--guidance_weight', type=float, default=100., 51 | help='Weight of CLIP guidance', 52 | ) 53 | parser.add_argument( 54 | '--clip_model', type=str, default='openai/clip-vit-large-patch14', 55 | help='Name of CLIP model', 56 | ) 57 | parser.add_argument( 58 | '--n_samples', type=int, required=True, 59 | help='Number of samples', 60 | ) 61 | parser.add_argument( 62 | '--save_dir', type=str, required=True, 63 | help='Path to directory saving samples', 64 | ) 65 | parser.add_argument( 66 | '--batch_size', type=int, default=500, 67 | help='Batch size on each process', 68 | ) 69 | return parser 70 | 71 | 72 | def main(): 73 | # PARSE ARGS AND CONFIGS 74 | args, unknown_args = get_parser().parse_known_args() 75 | unknown_args = [(a[2:] if a.startswith('--') else a) for a in unknown_args] 76 | unknown_args = [f'{k}={v}' for k, v in zip(unknown_args[::2], unknown_args[1::2])] 77 | conf = OmegaConf.load(args.config) 78 | conf = OmegaConf.merge(conf, OmegaConf.from_dotlist(unknown_args)) 79 | 80 | # INITIALIZE ACCELERATOR 81 | accelerator = accelerate.Accelerator() 82 | device = accelerator.device 83 | print(f'Process {accelerator.process_index} using device: {device}') 84 | accelerator.wait_for_everyone() 85 | 86 | # INITIALIZE LOGGER 87 | logger = get_logger( 88 | use_tqdm_handler=True, 89 | is_main_process=accelerator.is_main_process, 90 | ) 91 | 92 | # SET SEED 93 | accelerate.utils.set_seed(args.seed, device_specific=True) 94 | logger.info('=' * 19 + ' System Info ' + '=' * 18) 95 | logger.info(f'Number of processes: {accelerator.num_processes}') 96 | logger.info(f'Distributed type: {accelerator.distributed_type}') 97 | logger.info(f'Mixed precision: {accelerator.mixed_precision}') 98 | 99 | accelerator.wait_for_everyone() 100 | 101 | # BUILD DIFFUSER 102 | diffusion_params = OmegaConf.to_container(conf.diffusion.params) 103 | diffusion_params.update({ 104 | 'var_type': args.var_type or diffusion_params.get('var_type', None), 105 | 'respace_type': None if args.respace_steps is None else args.respace_type, 106 | 'respace_steps': args.respace_steps, 107 | 'device': device, 108 | 'guidance_weight': args.guidance_weight, 109 | 'clip_pretrained': args.clip_model, 110 | }) 111 | diffuser = diffusions.CLIPGuidance(**diffusion_params) 112 | logger.info('=' * 19 + ' Model Info ' + '=' * 19) 113 | logger.info(f'Using CLIP model: `{args.clip_model}`') 114 | 115 | # BUILD MODEL 116 | model = instantiate_from_config(conf.model) 117 | 118 | # LOAD WEIGHTS 119 | weights = load_weights(args.weights) 120 | model.load_state_dict(weights) 121 | logger.info(f'Successfully load model from {args.weights}') 122 | logger.info('=' * 50) 123 | 124 | # PREPARE FOR DISTRIBUTED MODE AND MIXED PRECISION 125 | model = accelerator.prepare(model) 126 | model.eval() 127 | 128 | accelerator.wait_for_everyone() 129 | 130 | @torch.no_grad() 131 | def sample(): 132 | idx = 0 133 | img_shape = (conf.data.img_channels, conf.data.params.img_size, conf.data.params.img_size) 134 | bspp = min(args.batch_size, math.ceil(args.n_samples / accelerator.num_processes)) 135 | folds = amortize(args.n_samples, bspp * accelerator.num_processes) 136 | for i, bs in enumerate(folds): 137 | init_noise = torch.randn((bspp, *img_shape), device=device) 138 | samples = diffuser.sample( 139 | model=accelerator.unwrap_model(model), init_noise=init_noise, 140 | tqdm_kwargs=dict(desc=f'Fold {i}/{len(folds)}', disable=not 
accelerator.is_main_process) 141 | ).clamp(-1, 1) 142 | samples = accelerator.gather(samples)[:bs] 143 | if accelerator.is_main_process: 144 | for x in samples: 145 | x = image_norm_to_float(x).cpu() 146 | save_image(x, os.path.join(args.save_dir, f'{idx}.png'), nrow=1) 147 | idx += 1 148 | with open(os.path.join(args.save_dir, 'description.txt'), 'w') as f: 149 | f.write(args.text) 150 | 151 | # START SAMPLING 152 | logger.info('Start sampling...') 153 | logger.info(f'The input text description is: {args.text}') 154 | logger.info(f'Guidance weight: {args.guidance_weight}') 155 | diffuser.set_text(args.text) 156 | os.makedirs(args.save_dir, exist_ok=True) 157 | logger.info(f'Samples will be saved to {args.save_dir}') 158 | sample() 159 | logger.info(f'Sampled images are saved to {args.save_dir}') 160 | logger.info('End of sampling') 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /streamlit/Hello.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | st.set_page_config(page_title="Diffusion", layout="wide") 4 | 5 | st.markdown( 6 | """ 7 |