├── LICENSE
├── README.md
├── assets
│   ├── classifier-free-cifar10.png
│   ├── classifier-free-imagenet.png
│   ├── clip-guidance-celebahq.png
│   ├── ddib-imagenet.png
│   ├── ddim-celebahq-interpolate.png
│   ├── ddim-celebahq-reconstruction.png
│   ├── ddim-cifar10-interpolate.png
│   ├── ddim-cifar10-reconstruction.png
│   ├── ddim-cifar10.png
│   ├── ddpm-celebahq-denoise.png
│   ├── ddpm-celebahq-progressive.png
│   ├── ddpm-celebahq-random.png
│   ├── ddpm-cifar10-denoise.png
│   ├── ddpm-cifar10-progressive.png
│   ├── ddpm-cifar10-random.png
│   ├── ddpm-mnist-random.png
│   ├── fidelity-speed-visualization.png
│   ├── ilvr-celebahq.png
│   ├── mask-guidance-imagenet.png
│   ├── sdedit.png
│   └── streamlit.png
├── configs
│   ├── ddpm_celebahq.yaml
│   ├── ddpm_cfg_cifar10.yaml
│   ├── ddpm_cifar10.yaml
│   └── ddpm_mnist.yaml
├── datasets
│   ├── ImageDir.py
│   ├── __init__.py
│   ├── celebahq.py
│   ├── cifar10.py
│   ├── imagenet.py
│   └── mnist.py
├── diffusions
│   ├── __init__.py
│   ├── ddim.py
│   ├── ddpm.py
│   ├── ddpm_ip.py
│   ├── euler.py
│   ├── guidance
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── clip_guidance.py
│   │   ├── ilvr.py
│   │   └── mask_guidance.py
│   ├── heun.py
│   └── schedule.py
├── docs
│   ├── CLIP Guidance.md
│   ├── Classifier-Free Guidance.md
│   ├── DDIB.md
│   ├── DDIM.md
│   ├── DDPM-IP.md
│   ├── DDPM.md
│   ├── ILVR.md
│   ├── Mask Guidance.md
│   ├── SDEdit.md
│   └── Samplers.md
├── models
│   ├── __init__.py
│   ├── adm
│   │   ├── __init__.py
│   │   ├── nn.py
│   │   ├── readme.md
│   │   ├── unet.py
│   │   └── unet_combined.py
│   ├── base_latent.py
│   ├── dit
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   ├── dit.py
│   │   ├── model.py
│   │   └── readme.md
│   ├── ema.py
│   ├── mdt
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   ├── mdt.py
│   │   ├── model.py
│   │   └── readme.md
│   ├── modules.py
│   ├── pesser
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── readme.md
│   ├── sdxl
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── autoencoder.py
│   │   ├── conditioner.py
│   │   ├── distributions.py
│   │   ├── modules.py
│   │   ├── readme.md
│   │   ├── regularizers.py
│   │   ├── stablediffusion.py
│   │   ├── unet.py
│   │   └── util.py
│   ├── stablediffusion
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── autoencoder.py
│   │   ├── distributions.py
│   │   ├── modules.py
│   │   ├── readme.md
│   │   ├── stablediffusion.py
│   │   ├── text_encoders.py
│   │   ├── unet.py
│   │   └── util.py
│   ├── unet.py
│   └── unet_categorial_adagn.py
├── requirements.txt
├── scripts
│   ├── sample_cfg.py
│   ├── sample_clip_guidance.py
│   ├── sample_ddib.py
│   ├── sample_ilvr.py
│   ├── sample_mask_guidance.py
│   ├── sample_sdedit.py
│   ├── sample_uncond.py
│   ├── train_ddpm.py
│   └── train_ddpm_cfg.py
├── streamlit
│   ├── Hello.py
│   └── pages
│       ├── 1_Unconditional_Image_Generation.py
│       ├── 2_Class_conditional_Image_Generation.py
│       ├── 3_Stable_Diffusion_v1.5.py
│       └── 4_Stable_Diffusion_XL.py
├── test_images
│   ├── celebahq
│   │   ├── 182660.jpg
│   │   ├── 182664.jpg
│   │   ├── 182690.jpg
│   │   ├── 182699.jpg
│   │   ├── 182704.jpg
│   │   ├── 182716.jpg
│   │   ├── 182722.jpg
│   │   ├── 182727.jpg
│   │   ├── 182734.jpg
│   │   └── 182743.jpg
│   ├── cifar10
│   │   ├── val_00.png
│   │   ├── val_01.png
│   │   ├── val_02.png
│   │   ├── val_03.png
│   │   ├── val_04.png
│   │   ├── val_05.png
│   │   ├── val_06.png
│   │   ├── val_07.png
│   │   ├── val_08.png
│   │   ├── val_09.png
│   │   ├── val_10.png
│   │   ├── val_11.png
│   │   ├── val_12.png
│   │   ├── val_13.png
│   │   ├── val_14.png
│   │   ├── val_15.png
│   │   ├── val_16.png
│   │   ├── val_17.png
│   │   ├── val_18.png
│   │   ├── val_19.png
│   │   ├── val_20.png
│   │   ├── val_21.png
│   │   ├── val_22.png
│   │   ├── val_23.png
│   │   ├── val_24.png
│   │   ├── val_25.png
│   │   ├── val_26.png
│   │   ├── val_27.png
│   │   ├── val_28.png
│   │   ├── val_29.png
│   │   ├── val_30.png
│   │   ├── val_31.png
│   │   ├── val_32.png
│   │   ├── val_33.png
│   │   ├── val_34.png
│   │   ├── val_35.png
│   │   ├── val_36.png
│   │   ├── val_37.png
│   │   ├── val_38.png
│   │   ├── val_39.png
│   │   ├── val_40.png
│   │   ├── val_41.png
│   │   ├── val_42.png
│   │   ├── val_43.png
│   │   ├── val_44.png
│   │   ├── val_45.png
│   │   ├── val_46.png
│   │   ├── val_47.png
│   │   ├── val_48.png
│   │   ├── val_49.png
│   │   ├── val_50.png
│   │   ├── val_51.png
│   │   ├── val_52.png
│   │   ├── val_53.png
│   │   ├── val_54.png
│   │   ├── val_55.png
│   │   ├── val_56.png
│   │   ├── val_57.png
│   │   ├── val_58.png
│   │   ├── val_59.png
│   │   ├── val_60.png
│   │   ├── val_61.png
│   │   ├── val_62.png
│   │   └── val_63.png
│   ├── imagenet
│   │   ├── 220
│   │   │   ├── ILSVRC2012_val_00000539.JPEG
│   │   │   ├── ILSVRC2012_val_00005027.JPEG
│   │   │   ├── ILSVRC2012_val_00008148.JPEG
│   │   │   ├── ILSVRC2012_val_00026146.JPEG
│   │   │   └── ILSVRC2012_val_00048337.JPEG
│   │   ├── 248
│   │   │   ├── ILSVRC2012_val_00000178.JPEG
│   │   │   ├── ILSVRC2012_val_00010011.JPEG
│   │   │   ├── ILSVRC2012_val_00019757.JPEG
│   │   │   ├── ILSVRC2012_val_00032646.JPEG
│   │   │   └── ILSVRC2012_val_00049841.JPEG
│   │   ├── 290
│   │   │   ├── ILSVRC2012_val_00002931.JPEG
│   │   │   ├── ILSVRC2012_val_00009297.JPEG
│   │   │   ├── ILSVRC2012_val_00018843.JPEG
│   │   │   ├── ILSVRC2012_val_00040784.JPEG
│   │   │   └── ILSVRC2012_val_00045488.JPEG
│   │   ├── 291
│   │   │   ├── ILSVRC2012_val_00003788.JPEG
│   │   │   ├── ILSVRC2012_val_00011645.JPEG
│   │   │   ├── ILSVRC2012_val_00012423.JPEG
│   │   │   ├── ILSVRC2012_val_00027413.JPEG
│   │   │   └── ILSVRC2012_val_00034811.JPEG
│   │   ├── 292
│   │   │   ├── ILSVRC2012_val_00029267.JPEG
│   │   │   ├── ILSVRC2012_val_00030651.JPEG
│   │   │   ├── ILSVRC2012_val_00047614.JPEG
│   │   │   ├── ILSVRC2012_val_00048380.JPEG
│   │   │   └── ILSVRC2012_val_00049631.JPEG
│   │   ├── 294
│   │   │   ├── ILSVRC2012_val_00010281.JPEG
│   │   │   ├── ILSVRC2012_val_00020362.JPEG
│   │   │   ├── ILSVRC2012_val_00031645.JPEG
│   │   │   ├── ILSVRC2012_val_00032453.JPEG
│   │   │   └── ILSVRC2012_val_00039227.JPEG
│   │   └── test
│   │       ├── ILSVRC2012_test_00000001.JPEG
│   │       ├── ILSVRC2012_test_00000002.JPEG
│   │       ├── ILSVRC2012_test_00000003.JPEG
│   │       ├── ILSVRC2012_test_00000004.JPEG
│   │       ├── ILSVRC2012_test_00000005.JPEG
│   │       ├── ILSVRC2012_test_00000006.JPEG
│   │       ├── ILSVRC2012_test_00000007.JPEG
│   │       ├── ILSVRC2012_test_00000008.JPEG
│   │       ├── ILSVRC2012_test_00000009.JPEG
│   │       └── ILSVRC2012_test_00000010.JPEG
│   └── strokes
│       ├── 0.png
│       ├── 1.png
│       └── 2.png
├── utils
│   ├── __init__.py
│   ├── load.py
│   ├── logger.py
│   ├── mask.py
│   ├── misc.py
│   └── resize_right
│       ├── __init__.py
│       ├── interp_methods.py
│       └── resize_right.py
└── weights
    ├── ChenWu98
    │   └── cycle-diffusion
    │       ├── cat_ema_0.9999_050000.yaml
    │       └── wild_ema_0.9999_050000.yaml
    ├── andreas128
    │   └── RePaint
    │       └── celeba256_250000.yaml
    ├── facebookresearch
    │   └── DiT
    │       ├── DiT-XL-2-256x256.yaml
    │       └── DiT-XL-2-512x512.yaml
    ├── jychoi118
    │   └── ilvr_adm
    │       └── afhqdog_p2.yaml
    ├── openai
    │   └── guided-diffusion
    │       ├── 256x256_diffusion.yaml
    │       ├── 256x256_diffusion_combined.yaml
    │       └── 256x256_diffusion_uncond.yaml
    ├── pesser
    │   └── pytorch_diffusion
    │       ├── ema_diffusion_celebahq_model-560000.yaml
    │       └── ema_diffusion_lsun_church_model-4432000.yaml
    ├── sail-sg
    │   └── MDT
    │       └── mdt_xl2_v2_ckpt.yaml
    ├── sdxl
    │   └── sd_xl_base.yaml
    └── stablediffusion
        ├── v1-inference.yaml
        └── v2-inference-v.yaml
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yifeng Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assets/classifier-free-cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/classifier-free-cifar10.png
--------------------------------------------------------------------------------
/assets/classifier-free-imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/classifier-free-imagenet.png
--------------------------------------------------------------------------------
/assets/clip-guidance-celebahq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/clip-guidance-celebahq.png
--------------------------------------------------------------------------------
/assets/ddib-imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddib-imagenet.png
--------------------------------------------------------------------------------
/assets/ddim-celebahq-interpolate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-celebahq-interpolate.png
--------------------------------------------------------------------------------
/assets/ddim-celebahq-reconstruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-celebahq-reconstruction.png
--------------------------------------------------------------------------------
/assets/ddim-cifar10-interpolate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-cifar10-interpolate.png
--------------------------------------------------------------------------------
/assets/ddim-cifar10-reconstruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-cifar10-reconstruction.png
--------------------------------------------------------------------------------
/assets/ddim-cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddim-cifar10.png
--------------------------------------------------------------------------------
/assets/ddpm-celebahq-denoise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-celebahq-denoise.png
--------------------------------------------------------------------------------
/assets/ddpm-celebahq-progressive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-celebahq-progressive.png
--------------------------------------------------------------------------------
/assets/ddpm-celebahq-random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-celebahq-random.png
--------------------------------------------------------------------------------
/assets/ddpm-cifar10-denoise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-cifar10-denoise.png
--------------------------------------------------------------------------------
/assets/ddpm-cifar10-progressive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-cifar10-progressive.png
--------------------------------------------------------------------------------
/assets/ddpm-cifar10-random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-cifar10-random.png
--------------------------------------------------------------------------------
/assets/ddpm-mnist-random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ddpm-mnist-random.png
--------------------------------------------------------------------------------
/assets/fidelity-speed-visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/fidelity-speed-visualization.png
--------------------------------------------------------------------------------
/assets/ilvr-celebahq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/ilvr-celebahq.png
--------------------------------------------------------------------------------
/assets/mask-guidance-imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/mask-guidance-imagenet.png
--------------------------------------------------------------------------------
/assets/sdedit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/sdedit.png
--------------------------------------------------------------------------------
/assets/streamlit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/assets/streamlit.png
--------------------------------------------------------------------------------
/configs/ddpm_celebahq.yaml:
--------------------------------------------------------------------------------
1 | seed: 2022
2 |
3 | data:
4 | target: datasets.celebahq.CelebAHQ
5 | params:
6 | root: ~/data/CelebA-HQ/
7 | img_size: 256
8 | img_channels: 3
9 |
10 | dataloader:
11 | num_workers: 4
12 | pin_memory: true
13 | prefetch_factor: 2
14 |
15 | model:
16 | target: models.unet.UNet
17 | params:
18 | in_channels: 3
19 | out_channels: 3
20 | dim: 128
21 | dim_mults: [1, 1, 2, 2, 4, 4]
22 | use_attn: [false, false, false, false, true, false]
23 | num_res_blocks: 2
24 | n_heads: 1
25 | dropout: 0.0
26 |
27 | diffusion:
28 | target: diffusions.ddpm.DDPM
29 | params:
30 | total_steps: 1000
31 | beta_schedule: linear
32 | beta_start: 0.0001
33 | beta_end: 0.02
34 | objective: pred_eps
35 | var_type: fixed_small
36 |
37 | train:
38 | n_steps: 500000
39 | batch_size: 64
40 | micro_batch: 0
41 |
42 | clip_grad_norm: 1.0
43 | ema_decay: 0.9999
44 | ema_gradual: true
45 |
46 | print_freq: 400
47 | save_freq: 10000
48 | sample_freq: 5000
49 | n_samples: 36
50 |
51 | optim:
52 | target: torch.optim.Adam
53 | params:
54 | lr: 0.00002
55 |
--------------------------------------------------------------------------------
/configs/ddpm_cfg_cifar10.yaml:
--------------------------------------------------------------------------------
1 | seed: 2022
2 |
3 | data:
4 | target: datasets.cifar10.CIFAR10
5 | params:
6 | root: ~/data/CIFAR-10/
7 | img_size: 32
8 | img_channels: 3
9 | num_classes: 10
10 |
11 | dataloader:
12 | num_workers: 4
13 | pin_memory: true
14 | prefetch_factor: 2
15 |
16 | model:
17 | target: models.unet_categorial_adagn.UNetCategorialAdaGN
18 | params:
19 | in_channels: 3
20 | out_channels: 3
21 | dim: 128
22 | dim_mults: [1, 2, 2, 2]
23 | use_attn: [false, true, true, false]
24 | num_res_blocks: 2
25 | num_classes: 10
26 | attn_head_dims: 64
27 | resblock_updown: true
28 | dropout: 0.1
29 |
30 | diffusion:
31 | target: diffusions.cfg.ddpm_cfg.DDPMCFG
32 | params:
33 | total_steps: 1000
34 | beta_schedule: cosine
35 | beta_start: 0.0001
36 | beta_end: 0.02
37 | objective: pred_eps
38 | var_type: fixed_large
39 |
40 | train:
41 | n_steps: 800000
42 | batch_size: 128
43 | micro_batch: 0
44 |
45 | clip_grad_norm: 1.0
46 | ema_decay: 0.9999
47 | ema_gradual: true
48 |
49 | print_freq: 400
50 | save_freq: 10000
51 | sample_freq: 5000
52 | n_samples_each_class: 10
53 |
54 | p_uncond: 0.2
55 |
56 | optim:
57 | target: torch.optim.AdamW
58 | params:
59 | lr: 0.0002
60 |
--------------------------------------------------------------------------------
/configs/ddpm_cifar10.yaml:
--------------------------------------------------------------------------------
1 | seed: 2022
2 |
3 | data:
4 | target: datasets.cifar10.CIFAR10
5 | params:
6 | root: ~/data/CIFAR-10/
7 | img_size: 32
8 | img_channels: 3
9 | num_classes: 10
10 |
11 | dataloader:
12 | num_workers: 4
13 | pin_memory: true
14 | prefetch_factor: 2
15 |
16 | model:
17 | target: models.unet.UNet
18 | params:
19 | in_channels: 3
20 | out_channels: 3
21 | dim: 128
22 | dim_mults: [1, 2, 2, 2]
23 | use_attn: [false, true, false, false]
24 | num_res_blocks: 2
25 | n_heads: 1
26 | dropout: 0.1
27 |
28 | diffusion:
29 | target: diffusions.ddpm.DDPM
30 | params:
31 | total_steps: 1000
32 | beta_schedule: linear
33 | beta_start: 0.0001
34 | beta_end: 0.02
35 | objective: pred_eps
36 | var_type: fixed_large
37 |
38 | train:
39 | n_steps: 800000
40 | batch_size: 128
41 | micro_batch: 0
42 |
43 | clip_grad_norm: 1.0
44 | ema_decay: 0.9999
45 | ema_gradual: true
46 |
47 | print_freq: 400
48 | save_freq: 10000
49 | sample_freq: 5000
50 | n_samples: 64
51 |
52 | optim:
53 | target: torch.optim.Adam
54 | params:
55 | lr: 0.0002
56 |
--------------------------------------------------------------------------------
/configs/ddpm_mnist.yaml:
--------------------------------------------------------------------------------
1 | seed: 2022
2 |
3 | data:
4 | target: datasets.mnist.MNIST
5 | params:
6 | root: ~/data/MNIST/
7 | img_size: 32
8 | img_channels: 1
9 | num_classes: 10
10 |
11 | dataloader:
12 | num_workers: 4
13 | pin_memory: true
14 | prefetch_factor: 2
15 |
16 | model:
17 | target: models.unet.UNet
18 | params:
19 | in_channels: 1
20 | out_channels: 1
21 | dim: 64
22 | dim_mults: [1, 2, 2, 2]
23 | use_attn: [false, true, false, false]
24 | num_res_blocks: 2
25 | n_heads: 1
26 | dropout: 0.1
27 |
28 | diffusion:
29 | target: diffusions.ddpm.DDPM
30 | params:
31 | total_steps: 200
32 | beta_schedule: linear
33 | beta_start: 0.0001
34 | beta_end: 0.02
35 | objective: pred_eps
36 | var_type: fixed_small
37 |
38 | train:
39 | n_steps: 50000
40 | batch_size: 128
41 | micro_batch: 0
42 |
43 | clip_grad_norm: 1.0
44 | ema_decay: 0.9999
45 | ema_gradual: true
46 |
47 | print_freq: 400
48 | save_freq: 10000
49 | sample_freq: 5000
50 | n_samples: 64
51 |
52 | optim:
53 | target: torch.optim.Adam
54 | params:
55 | lr: 0.0002
56 |
--------------------------------------------------------------------------------
/datasets/ImageDir.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 |
5 |
6 | def extract_images(root):
7 | """ Extract all images under root """
8 | img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
9 | root = os.path.expanduser(root)
10 | img_paths = []
11 | for curdir, subdirs, files in os.walk(root):
12 | for file in files:
13 | if os.path.splitext(file)[1].lower() in img_ext:
14 | img_paths.append(os.path.join(curdir, file))
15 | img_paths = sorted(img_paths)
16 | return img_paths
17 |
18 |
19 | class ImageDir(Dataset):
20 | def __init__(self, root, transform=None):
21 | root = os.path.expanduser(root)
22 | if not os.path.isdir(root):
23 | raise ValueError(f'{root} is not a valid directory')
24 |
25 | self.transform = transform
26 | self.img_paths = extract_images(root)
27 |
28 | def __len__(self):
29 | return len(self.img_paths)
30 |
31 | def __getitem__(self, item):
32 | X = Image.open(self.img_paths[item]).convert('RGB')
33 | if self.transform is not None:
34 | X = self.transform(X)
35 | return X
36 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .ImageDir import ImageDir
2 |
--------------------------------------------------------------------------------
/datasets/celebahq.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from typing import Optional, Callable
4 |
5 | import torchvision.transforms as T
6 | from torch.utils.data import Dataset
7 |
8 |
9 | def extract_images(root):
10 | """ Extract all images under root """
11 | img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
12 | root = os.path.expanduser(root)
13 | img_paths = []
14 | for curdir, subdirs, files in os.walk(root):
15 | for file in files:
16 | if os.path.splitext(file)[1].lower() in img_ext:
17 | img_paths.append(os.path.join(curdir, file))
18 | img_paths = sorted(img_paths)
19 | return img_paths
20 |
21 |
22 | class CelebAHQ(Dataset):
23 |     """The CelebA-HQ Dataset.
24 | 
25 |     The CelebA-HQ dataset is a high-quality version of CelebA that consists of 30,000 images at 1024×1024 resolution.
26 |     (Description from Papers with Code.)
27 | 
28 |     The official way to prepare the dataset is to download img_celeba.7z from the original CelebA dataset and the delta
29 |     files from the official GitHub repository. Then use dataset_tool.py to generate the high-quality images.
30 | 
31 |     However, I personally recommend downloading the CelebAMask-HQ dataset, which contains processed CelebA-HQ images.
32 |     Note that the filenames in CelebAMask-HQ are numbered from 0 to 29999, which is inconsistent with the original
33 |     CelebA filenames. I provide a python script (scripts/celebahq_map_filenames.py) to help convert the filenames.
34 | 
35 |     To load data with this class, the dataset should be organized in the following structure:
36 | 
37 |     root
38 |     ├── CelebA-HQ-img
39 |     │   ├── 000004.jpg
40 |     │   ├── ...
41 |     │   └── 202591.jpg
42 |     └── CelebA-HQ-to-CelebA-mapping.txt
43 | 
44 |     The train/valid/test sets are split according to the original CelebA dataset,
45 |     resulting in 24,183 training images, 2,993 validation images, and 2,824 test images.
46 | 
47 |     This class has one pre-defined transform:
48 |     - 'resize' (default): Resize the image directly to the target size
49 | 
50 |     References:
51 |     - https://github.com/tkarras/progressive_growing_of_gans
52 |     - https://paperswithcode.com/dataset/celeba-hq
53 |     - https://github.com/switchablenorms/CelebAMask-HQ
54 | 
55 |     """
56 | def __init__(
57 | self,
58 | root: str,
59 | img_size: int,
60 | split: str = 'train',
61 | transform_type: str = 'default',
62 | transform: Optional[Callable] = None,
63 | ):
64 | if split not in ['train', 'valid', 'test', 'all']:
65 | raise ValueError(f'Invalid split: {split}')
66 | root = os.path.expanduser(root)
67 | image_root = os.path.join(root, 'CelebA-HQ-img')
68 | if not os.path.isdir(image_root):
69 | raise ValueError(f'{image_root} is not an existing directory')
70 |
71 | self.root = root
72 | self.img_size = img_size
73 | self.split = split
74 | self.transform_type = transform_type
75 | self.transform = transform
76 | if transform is None:
77 | self.transform = self.get_transform()
78 |
79 | def filter_func(p):
80 | if split == 'all':
81 | return True
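            # Index boundaries of the original CelebA splits: train 1-162770, valid 162771-182637, test 182638-202599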
82 | celeba_splits = [1, 162771, 182638, 202600]
83 | k = 0 if split == 'train' else (1 if split == 'valid' else 2)
84 | return celeba_splits[k] <= int(os.path.splitext(os.path.basename(p))[0]) < celeba_splits[k+1]
85 |
86 | self.img_paths = extract_images(image_root)
87 | self.img_paths = list(filter(filter_func, self.img_paths))
88 |
89 | def __len__(self):
90 | return len(self.img_paths)
91 |
92 | def __getitem__(self, item):
93 | X = Image.open(self.img_paths[item])
94 | if self.transform is not None:
95 | X = self.transform(X)
96 | return X
97 |
98 | def get_transform(self):
99 | flip_p = 0.5 if self.split in ['train', 'all'] else 0.0
100 | if self.transform_type in ['default', 'resize']:
101 | transform = T.Compose([
102 | T.Resize((self.img_size, self.img_size)),
103 | T.RandomHorizontalFlip(flip_p),
104 | T.ToTensor(),
105 | T.Normalize([0.5] * 3, [0.5] * 3),
106 | ])
107 | elif self.transform_type == 'none':
108 | transform = None
109 | else:
110 | raise ValueError(f'Invalid transform_type: {self.transform_type}')
111 | return transform
112 |
--------------------------------------------------------------------------------
/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable
2 |
3 | import torchvision.datasets
4 | import torchvision.transforms as T
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class CIFAR10(Dataset):
9 | """Extend torchvision.datasets.CIFAR10 with one pre-defined transform.
10 |
11 | The pre-defined transform is:
12 | - 'resize' (default): Resize the image directly to the target size, followed by random horizontal flipping.
13 |
14 | """
15 |
16 | def __init__(
17 | self,
18 | root: str,
19 | img_size: int,
20 | split: str = 'train',
21 | transform_type: str = 'default',
22 | transform: Optional[Callable] = None,
23 | target_transform: Optional[Callable] = None,
24 | download: bool = False,
25 | ):
26 | if split not in ['train', 'test']:
27 | raise ValueError(f'Invalid split: {split}')
28 |
29 | self.img_size = img_size
30 | self.split = split
31 | self.transform_type = transform_type
32 | if transform is None:
33 | transform = self.get_transform()
34 |
35 | self.cifar10 = torchvision.datasets.CIFAR10(
36 | root=root,
37 | train=(split == 'train'),
38 | transform=transform,
39 | target_transform=target_transform,
40 | download=download,
41 | )
42 |
43 | def __len__(self):
44 | return len(self.cifar10)
45 |
46 | def __getitem__(self, item):
47 | X, y = self.cifar10[item]
48 | return X, y
49 |
50 | def get_transform(self):
51 | flip_p = 0.5 if self.split == 'train' else 0.0
52 | if self.transform_type in ['default', 'resize']:
53 | transform = T.Compose([
54 | T.Resize((self.img_size, self.img_size), antialias=True),
55 | T.RandomHorizontalFlip(flip_p),
56 | T.ToTensor(),
57 | T.Normalize([0.5] * 3, [0.5] * 3),
58 | ])
59 | elif self.transform_type == 'none':
60 | transform = None
61 | else:
62 | raise ValueError(f'Invalid transform_type: {self.transform_type}')
63 | return transform
64 |
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from typing import Optional, Callable
4 |
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms as T
7 |
8 |
9 | def extract_images(root):
10 | """ Extract all images under root """
11 | img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
12 | root = os.path.expanduser(root)
13 | img_paths = []
14 | for curdir, subdirs, files in os.walk(root):
15 | for file in files:
16 | if os.path.splitext(file)[1].lower() in img_ext:
17 | img_paths.append(os.path.join(curdir, file))
18 | img_paths = sorted(img_paths)
19 | return img_paths
20 |
21 |
22 | class ImageNet(Dataset):
23 |     """Extend torchvision.datasets.ImageNet with two pre-defined transforms and support for the test set.
24 | 
25 |     This class has two pre-defined transforms:
26 |     - 'resize-crop' (default): Resize the image so that the short side matches the target size, then crop a square patch
27 |     - 'resize': Resize the image directly to the target size
28 |     Both of the above transforms are followed by random horizontal flipping.
29 | 
30 |     To load data with this class, the dataset should be organized in the following structure:
31 | 
32 |     root
33 |     ├── train
34 |     │   ├── n01440764
35 |     │   ├── ...
36 |     │   └── n15075141
37 |     ├── valid (or val)
38 |     │   ├── n01440764 (or directly put all validation images here)
39 |     │   ├── ...
40 |     │   └── n15075141
41 |     └── test
42 |         ├── ILSVRC2012_test_00000001.JPEG
43 |         ├── ...
44 |         └── ILSVRC2012_test_00100000.JPEG
45 | 
46 |     References:
47 |     - https://image-net.org/challenges/LSVRC/2012/2012-downloads.php
48 | 
49 |     """
50 |
51 | def __init__(
52 | self,
53 | root: str,
54 | img_size: int,
55 | split: str = 'train',
56 | transform_type: str = 'default',
57 | transform: Optional[Callable] = None,
58 | ):
59 | if split not in ['train', 'valid', 'test']:
60 | raise ValueError(f'Invalid split: {split}')
61 | root = os.path.expanduser(root)
62 | image_root = os.path.join(root, split)
63 | if split == 'valid' and not os.path.isdir(image_root):
64 | image_root = os.path.join(root, 'val')
65 | if not os.path.isdir(image_root):
66 | raise ValueError(f'{image_root} is not an existing directory')
67 |
68 | self.img_size = img_size
69 | self.split = split
70 | self.transform_type = transform_type
71 |         self.transform = transform
        if transform is None:
            self.transform = self.get_transform()
72 |
73 | self.img_paths = extract_images(image_root)
74 |
75 | def __len__(self):
76 | return len(self.img_paths)
77 |
78 | def __getitem__(self, item):
79 | X = Image.open(self.img_paths[item]).convert('RGB')
80 | if self.transform is not None:
81 | X = self.transform(X)
82 | return X
83 |
84 | def get_transform(self):
85 | crop = T.RandomCrop if self.split == 'train' else T.CenterCrop
86 | flip_p = 0.5 if self.split == 'train' else 0.0
87 | if self.transform_type in ['default', 'resize-crop']:
88 | transform = T.Compose([
89 | T.Resize(self.img_size),
90 | crop((self.img_size, self.img_size)),
91 | T.RandomHorizontalFlip(flip_p),
92 | T.ToTensor(),
93 | T.Normalize([0.5] * 3, [0.5] * 3),
94 | ])
95 | elif self.transform_type == 'resize':
96 | transform = T.Compose([
97 | T.Resize((self.img_size, self.img_size)),
98 | T.RandomHorizontalFlip(flip_p),
99 | T.ToTensor(),
100 | T.Normalize([0.5] * 3, [0.5] * 3),
101 | ])
102 | elif self.transform_type == 'none':
103 | transform = None
104 | else:
105 | raise ValueError(f'Invalid transform_type: {self.transform_type}')
106 | return transform
107 |
--------------------------------------------------------------------------------
/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable
2 |
3 | import torchvision.datasets
4 | import torchvision.transforms as T
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class MNIST(Dataset):
9 | """Extend torchvision.datasets.MNIST with one pre-defined transform.
10 |
11 | The pre-defined transform is:
12 | - 'resize' (default): Resize the image directly to the target size
13 |
14 | """
15 |
16 | def __init__(
17 | self,
18 | root: str,
19 | img_size: int,
20 | split: str = 'train',
21 | transform_type: str = 'default',
22 | transform: Optional[Callable] = None,
23 | target_transform: Optional[Callable] = None,
24 | download: bool = False,
25 | ):
26 | if split not in ['train', 'test']:
27 | raise ValueError(f'Invalid split: {split}')
28 |
29 | self.img_size = img_size
30 | self.transform_type = transform_type
31 | if transform is None:
32 | transform = self.get_transform()
33 |
34 | self.mnist = torchvision.datasets.MNIST(
35 | root=root,
36 | train=(split == 'train'),
37 | transform=transform,
38 | target_transform=target_transform,
39 | download=download,
40 | )
41 |
42 | def __len__(self):
43 | return len(self.mnist)
44 |
45 | def __getitem__(self, item):
46 | X, y = self.mnist[item]
47 | return X, y
48 |
49 | def get_transform(self):
50 | if self.transform_type in ['default', 'resize']:
51 | transform = T.Compose([
52 | T.Resize((self.img_size, self.img_size), antialias=True),
53 | T.ToTensor(),
54 | T.Normalize([0.5], [0.5]),
55 | ])
56 | elif self.transform_type == 'none':
57 | transform = None
58 | else:
59 | raise ValueError(f'Invalid transform_type: {self.transform_type}')
60 | return transform
61 |
--------------------------------------------------------------------------------
/diffusions/__init__.py:
--------------------------------------------------------------------------------
1 | from .schedule import get_beta_schedule, get_respaced_seq
2 |
3 | from .ddpm import DDPM, DDPMCFG
4 | from .ddim import DDIM, DDIMCFG
5 | from .euler import EulerSampler
6 | from .heun import HeunSampler
7 |
8 | from .guidance.ilvr import ILVR
9 | from .guidance.mask_guidance import MaskGuidance
10 | from .guidance.clip_guidance import CLIPGuidance
11 |
--------------------------------------------------------------------------------
/diffusions/ddpm_ip.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from diffusions import DDPM
9 |
10 |
11 | class DDPM_IP(DDPM):
12 | def __init__(self, gamma: float = 0.1, *args, **kwargs):
13 | """Denoising Diffusion Probabilistic Models with Input Perturbation.
14 |
15 |         Perturb the input (xt) during training to simulate the gap between training and testing.
16 |         Surprisingly simple but effective.
17 |
18 | Args:
19 | gamma: Perturbation strength.
20 |
21 | References:
22 | [1] Ning, Mang, Enver Sangineto, Angelo Porrello, Simone Calderara, and Rita Cucchiara. "Input
23 | Perturbation Reduces Exposure Bias in Diffusion Models." arXiv preprint arXiv:2301.11706 (2023).
24 |
25 | """
26 | super().__init__(*args, **kwargs)
27 | self.gamma = gamma
28 |
29 | def loss_func(self, model: nn.Module, x0: Tensor, t: Tensor, eps: Tensor = None, model_kwargs: Dict = None):
30 | if model_kwargs is None:
31 | model_kwargs = dict()
32 | if eps is None:
33 | eps = torch.randn_like(x0)
34 | # input perturbation
35 | perturbed_eps = eps + self.gamma * torch.randn_like(eps)
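        # note: x_t below is formed with the perturbed noise, but the regression target remains the unperturbed eps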
36 | xt = self.diffuse(x0, t, perturbed_eps)
37 | if self.objective == 'pred_eps':
38 | pred_eps = model(xt, t, **model_kwargs)
39 | return F.mse_loss(pred_eps, eps)
40 | elif self.objective == 'pred_x0':
41 | pred_x0 = model(xt, t, **model_kwargs)
42 | return F.mse_loss(pred_x0, x0)
43 | elif self.objective == 'pred_v':
44 | v = self.get_v(x0, eps, t)
45 | pred_v = model(xt, t, **model_kwargs)
46 | return F.mse_loss(pred_v, v)
47 | else:
48 | raise ValueError(f'Objective {self.objective} is not supported.')
49 |
--------------------------------------------------------------------------------
/diffusions/euler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from diffusions.ddpm import DDPM
5 |
6 |
7 | class EulerSampler(DDPM):
8 | def __init__(
9 | self,
10 | total_steps: int = 1000,
11 | beta_schedule: str = 'linear',
12 | beta_start: float = 0.0001,
13 | beta_end: float = 0.02,
14 | betas: Tensor = None,
15 | objective: str = 'pred_eps',
16 |
17 | clip_denoised: bool = True,
18 | respace_type: str = None,
19 | respace_steps: int = 100,
20 | respaced_seq: Tensor = None,
21 |
22 | device: torch.device = 'cpu',
23 | **kwargs,
24 | ):
25 | """Euler sampler for DDPM-like diffusion process.
26 |
27 | References:
28 | [1] Karras, Tero, Miika Aittala, Timo Aila, and Samuli Laine. "Elucidating the design space of
29 | diffusion-based generative models." Advances in Neural Information Processing Systems 35 (2022):
30 | 26565-26577.
31 |
32 | """
33 | super().__init__(
34 | total_steps=total_steps,
35 | beta_schedule=beta_schedule,
36 | beta_start=beta_start,
37 | beta_end=beta_end,
38 | betas=betas,
39 | objective=objective,
40 | clip_denoised=clip_denoised,
41 | respace_type=respace_type,
42 | respace_steps=respace_steps,
43 | respaced_seq=respaced_seq,
44 | device=device,
45 | **kwargs,
46 | )
47 |
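        # Change of variables following Karras et al.: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t),
        # so that x_t / sqrt(alpha_bar_t) = sqrt(1 + sigma_t^2) * x_t = x0 + sigma_t * eps (see `bar_xt` in `denoise`)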
48 | self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod).sqrt()
49 |
50 | def denoise(self, model_output: Tensor, xt: Tensor, t: int, t_prev: int):
51 | """Denoise from x_t to x_{t-1}."""
52 | # Prepare parameters
53 | sigmas_t = self.sigmas[t]
54 | sigmas_t_prev = self.sigmas[t_prev] if t_prev >= 0 else torch.tensor(0.0)
55 |
56 | # Predict x0 and eps
57 | predict = self.predict(model_output, xt, t)
58 | pred_x0 = predict['pred_x0']
59 |
60 | # Calculate the x{t-1}
61 | bar_xt = (1 + sigmas_t ** 2).sqrt() * xt
62 | derivative = (bar_xt - pred_x0) / sigmas_t
63 | bar_sample = bar_xt + derivative * (sigmas_t_prev - sigmas_t)
64 | sample = bar_sample / (1 + sigmas_t_prev ** 2).sqrt()
65 |
66 | return {'sample': sample, 'pred_x0': pred_x0}
67 |
--------------------------------------------------------------------------------
/diffusions/guidance/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/diffusions/guidance/__init__.py
--------------------------------------------------------------------------------
/diffusions/guidance/clip_guidance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torchvision.transforms as T
4 |
5 | from transformers import CLIPProcessor, CLIPModel
6 |
7 | from diffusions.guidance.base import BaseGuidance
8 | from utils.misc import image_norm_to_uint8
9 |
10 |
11 | class CLIPGuidance(BaseGuidance):
12 | """Diffusion Models with CLIP Guidance.
13 |
14 |     Guide the diffusion process with the similarity between the CLIP image feature and the text feature, so that the generated
15 |     image matches the description of the input text.
16 | 
17 |     In each step, the guidance is applied to the predicted x0 rather than the noisy xt, since CLIP was not trained on noisy images.
18 |
19 | """
20 | def __init__(
21 | self,
22 | guidance_weight: float = 1.0,
23 | clip_pretrained: str = 'openai/clip-vit-base-patch32',
24 | **kwargs,
25 | ):
26 | super().__init__(**kwargs)
27 | self.guidance_weight = guidance_weight
28 |
29 | self.clip_processor = CLIPProcessor.from_pretrained(clip_pretrained)
30 | self.clip_model = CLIPModel.from_pretrained(clip_pretrained).to(self.device)
31 |
32 | self.text = None
33 |
34 | def set_text(self, text: str):
35 | assert isinstance(text, str)
36 | self.text = text
37 |
38 | @torch.enable_grad()
39 | def cond_fn_mean(self, t: int, xt: Tensor, pred_x0: Tensor, var: Tensor, **kwargs):
40 | if self.text is None:
41 | raise RuntimeError('Please call `set_text()` before sampling.')
42 | images = image_norm_to_uint8(pred_x0)
43 | processed = self.clip_processor(text=self.text, images=images, return_tensors="pt", padding=True)
44 | processed = {k: v.to(self.device) for k, v in processed.items()}
45 | processed['pixel_values'].requires_grad_(True)
46 | out = self.clip_model(**processed)
47 | similarities = torch.matmul(out['image_embeds'], out['text_embeds'].t()).squeeze(dim=1)
48 | grad = torch.autograd.grad(outputs=similarities.sum(), inputs=processed['pixel_values'])[0]
49 | grad = T.Resize(xt.shape[-2:], antialias=True)(grad)
50 | return self.guidance_weight * ((1. / self.alphas_cumprod[t]) ** 0.5) * var * grad
51 |
--------------------------------------------------------------------------------
/diffusions/guidance/ilvr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from diffusions.guidance.base import BaseGuidance
5 | from utils.resize_right import resize_right, interp_methods
6 |
7 |
8 | class ILVR(BaseGuidance):
9 | def __init__(
10 | self,
11 | ref_images: Tensor = None,
12 | downsample_factor: int = 8,
13 | interp_method: str = 'cubic',
14 | *args, **kwargs,
15 | ):
16 | """Iterative Latent Variable Refinement (ILVR).
17 |
18 | Args:
19 | ref_images: The reference images of shape [B, C, H, W].
20 | downsample_factor: The downsample factor.
21 | interp_method: The interpolation method. Options: 'cubic', 'lanczos2', 'lanczos3', 'linear', 'box'.
22 |
23 | References:
24 | [1] Choi, Jooyoung, Sungwon Kim, Yonghyun Jeong, Youngjune Gwon, and Sungroh Yoon. “ILVR: Conditioning Method
25 | for Denoising Diffusion Probabilistic Models.” In 2021 IEEE/CVF International Conference on Computer Vision
26 | (ICCV), pp. 14347-14356. IEEE, 2021.
27 |
28 | """
29 | super().__init__(*args, **kwargs)
30 | self.ref_images = ref_images
31 | self.downsample_factor = downsample_factor
32 | self.interp_method = getattr(interp_methods, interp_method)
33 |
34 | def set_ref_images(self, ref_images: Tensor):
35 | self.ref_images = ref_images
36 |
37 | def cond_fn_sample(self, t: int, t_prev: int, sample: Tensor, **kwargs):
38 | if self.ref_images is None:
39 | raise RuntimeError('Please call `set_ref_images()` before sampling.')
40 | if t == 0:
41 | noisy_ref_images = self.ref_images
42 | else:
43 | noisy_ref_images = self.diffuse(
44 | x0=self.ref_images,
45 | t=torch.full((sample.shape[0], ), t_prev, device=self.device),
46 | )
47 | return self.low_pass_filter(noisy_ref_images) - self.low_pass_filter(sample)
48 |
49 | def low_pass_filter(self, x: Tensor):
50 | x = resize_right.resize(x, scale_factors=1./self.downsample_factor, interp_method=self.interp_method)
51 | x = resize_right.resize(x, scale_factors=self.downsample_factor, interp_method=self.interp_method)
52 | return x
53 |
--------------------------------------------------------------------------------
/diffusions/guidance/mask_guidance.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | from typing import Dict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | from diffusions.guidance.base import BaseGuidance
9 |
10 |
11 | class MaskGuidance(BaseGuidance):
12 | def __init__(
13 | self,
14 | masked_image: Tensor = None,
15 | mask: Tensor = None,
16 | *args, **kwargs,
17 | ):
18 | """Diffusion Models with Mask Guidance.
19 |
20 |         The idea was first proposed in [1] and further developed in [2], [3], etc. for image inpainting. In each reverse
21 |         step, xt is computed by composing the noisy known part and denoised unknown part of the image.
22 |
23 | .. math::
24 |             x_{t-1} = m \odot x^{known}_{t-1} + (1 - m) \odot x^{unknown}_{t-1}
25 |
26 | Args:
27 | masked_image: The masked input images of shape [B, C, H, W].
28 | mask: The binary masks of shape [B, 1, H, W]. Note that 1 denotes known areas and 0 denotes unknown areas.
29 |
30 | References:
31 | [1]. Song, Yang, and Stefano Ermon. “Generative modeling by estimating gradients of the data distribution.”
32 | Advances in neural information processing systems 32 (2019).
33 |
34 | [2]. Avrahami, Omri, Dani Lischinski, and Ohad Fried. “Blended diffusion for text-driven editing of natural
35 | images.” In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 18208
36 | -18218. 2022.
37 |
38 | [3]. Lugmayr, Andreas, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, and Luc Van Gool. “Repaint:
39 | Inpainting using denoising diffusion probabilistic models.” In Proceedings of the IEEE/CVF Conference on
40 | Computer Vision and Pattern Recognition, pp. 11461-11471. 2022.
41 |
42 | """
43 | super().__init__(*args, **kwargs)
44 | self.masked_image = masked_image
45 | self.mask = mask
46 |
47 | def set_mask_and_image(self, masked_image: Tensor, mask: Tensor):
48 | self.masked_image = masked_image
49 | self.mask = mask
50 |
51 | def cond_fn_sample(self, t: int, t_prev: int, sample: Tensor, **kwargs):
52 | if self.masked_image is None or self.mask is None:
53 | raise RuntimeError('Please call `set_mask_and_image()` before sampling.')
54 | if t == 0:
55 | noisy_known = self.masked_image
56 | else:
57 | noisy_known = self.diffuse(
58 | x0=self.masked_image,
59 | t=torch.full((sample.shape[0], ), t_prev, device=self.device),
60 | )
61 | return (noisy_known - sample) * self.mask
62 |
63 | def q_sample_one_step(self, xt: Tensor, t: int, t_next: int):
64 | """Sample from q(x{t+1} | xt). """
65 | alphas_cumprod_t = self.alphas_cumprod[t]
66 | alphas_cumprod_t_next = self.alphas_cumprod[t_next] if t_next < self.total_steps else torch.tensor(0.0)
67 | alphas_t_next = alphas_cumprod_t_next / alphas_cumprod_t
68 | return torch.sqrt(alphas_t_next) * xt + torch.sqrt(1. - alphas_t_next) * torch.randn_like(xt)
69 |
70 | def resample_loop(
71 | self, model: nn.Module, init_noise: Tensor,
72 | resample_r: int = 10, resample_j: int = 10,
73 | tqdm_kwargs: Dict = None, model_kwargs: Dict = None,
74 | ):
75 | """Sample following RePaint paper. """
76 | tqdm_kwargs = dict() if tqdm_kwargs is None else tqdm_kwargs
77 | model_kwargs = dict() if model_kwargs is None else model_kwargs
78 |
79 | img = init_noise
80 | resample_seq1 = self.get_resample_seq(resample_r, resample_j)
81 | resample_seq2 = resample_seq1[1:] + [-1]
82 | pbar = tqdm.tqdm(total=len(resample_seq1), **tqdm_kwargs)
83 | for t1, t2 in zip(resample_seq1, resample_seq2):
84 | if t1 > t2:
85 | t_batch = torch.full((img.shape[0], ), t1, device=self.device, dtype=torch.long)
86 | model_output = model(img, t_batch, **model_kwargs)
87 | out = self.denoise(model_output, img, t1, t2)
88 | out = self.apply_guidance(**out, xt=img, t=t1, t_prev=t2) # apply guidance
89 | img = out['sample']
90 | yield out
91 | else:
92 | img = self.q_sample_one_step(img, t1, t2)
93 | yield {'sample': img}
94 | pbar.update(1)
95 | pbar.close()
96 |
97 | def resample(
98 | self, model: nn.Module, init_noise: Tensor,
99 | resample_r: int = 10, resample_j: int = 10,
100 | tqdm_kwargs: Dict = None, model_kwargs: Dict = None,
101 | ):
102 | sample = None
103 | for out in self.resample_loop(
104 | model, init_noise,
105 | resample_r, resample_j,
106 | tqdm_kwargs, model_kwargs,
107 | ):
108 | sample = out['sample']
109 | return sample
110 |
111 | def get_resample_seq(self, resample_r: int = 10, resample_j: int = 10):
112 | """Figure 9 in RePaint paper.
113 |
114 | Args:
115 | resample_r: Number of resampling, as proposed in RePaint paper.
116 | resample_j: Jump lengths of resampling, as proposed in RePaint paper.
117 |
118 | """
119 | t_T = len(self.respaced_seq)
120 |
121 | jumps = {}
122 | for j in range(0, t_T - resample_j, resample_j):
123 | jumps[j] = resample_r - 1
124 |
125 | t = t_T
126 | ts = []
127 | while t >= 1:
128 | t = t - 1
129 | ts.append(self.respaced_seq[t].item())
130 | if jumps.get(t, 0) > 0:
131 | jumps[t] = jumps[t] - 1
132 | for _ in range(resample_j):
133 | t = t + 1
134 | ts.append(self.respaced_seq[t].item())
135 | return ts
136 |
137 |
138 | def _test(r, j):
139 | import matplotlib.pyplot as plt
140 |
141 | dummy_image = torch.rand((10, 3, 256, 256))
142 | dummy_mask = torch.randint(0, 2, (10, 1, 256, 256))
143 |     mask_guided = MaskGuidance(dummy_image, dummy_mask, respace_type='uniform', respace_steps=250)
144 |
145 | ts = mask_guided.get_resample_seq(resample_r=r, resample_j=j)
146 | plt.rcParams["figure.figsize"] = (10, 5)
147 | plt.plot(range(len(ts)), ts)
148 | plt.title(f'r={r}, j={j}')
149 | plt.show()
150 |
151 |
152 | if __name__ == '__main__':
153 | _test(1, 10)
154 | _test(5, 10)
155 | _test(10, 10)
156 |
--------------------------------------------------------------------------------
/diffusions/heun.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | from typing import Dict
3 |
4 | import torch
5 | from torch import Tensor, nn as nn
6 |
7 | from diffusions.ddpm import DDPM
8 |
9 |
10 | class HeunSampler(DDPM):
11 | def __init__(
12 | self,
13 | total_steps: int = 1000,
14 | beta_schedule: str = 'linear',
15 | beta_start: float = 0.0001,
16 | beta_end: float = 0.02,
17 | betas: Tensor = None,
18 | objective: str = 'pred_eps',
19 |
20 | clip_denoised: bool = True,
21 | respace_type: str = None,
22 | respace_steps: int = 100,
23 | respaced_seq: Tensor = None,
24 |
25 | device: torch.device = 'cpu',
26 | **kwargs,
27 | ):
28 | """Heun sampler for DDPM-like diffusion process.
29 |
30 | References:
31 | [1] Karras, Tero, Miika Aittala, Timo Aila, and Samuli Laine. "Elucidating the design space of
32 | diffusion-based generative models." Advances in Neural Information Processing Systems 35 (2022):
33 | 26565-26577.
34 |
35 | """
36 | super().__init__(
37 | total_steps=total_steps,
38 | beta_schedule=beta_schedule,
39 | beta_start=beta_start,
40 | beta_end=beta_end,
41 | betas=betas,
42 | objective=objective,
43 | clip_denoised=clip_denoised,
44 | respace_type=respace_type,
45 | respace_steps=respace_steps,
46 | respaced_seq=respaced_seq,
47 | device=device,
48 | **kwargs,
49 | )
50 |
51 | self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod).sqrt()
52 |
53 | self._1st_order_derivative = None
54 | self._1st_order_xt = None
55 |
56 | def denoise_1st_order(self, model_output: Tensor, xt: Tensor, t: int, t_prev: int):
57 | """1st order step. Same as euler sampler."""
58 | # Prepare parameters
59 | sigmas_t = self.sigmas[t]
60 | sigmas_t_prev = self.sigmas[t_prev] if t_prev >= 0 else torch.tensor(0.0)
61 |
62 | # Predict x0
63 | predict = self.predict(model_output, xt, t)
64 | pred_x0 = predict['pred_x0']
65 |
66 | # Calculate the x{t-1}
67 | bar_xt = (1 + sigmas_t ** 2).sqrt() * xt
68 | derivative = (bar_xt - pred_x0) / sigmas_t
69 | bar_sample = bar_xt + derivative * (sigmas_t_prev - sigmas_t)
70 | sample = bar_sample / (1 + sigmas_t_prev ** 2).sqrt()
71 |
72 | # Store the 1st order info
73 | self._1st_order_derivative = derivative
74 | self._1st_order_xt = xt
75 |
76 | return {'sample': sample, 'pred_x0': pred_x0}
77 |
78 | def denoise_2nd_order(self, model_output: Tensor, xt_prev: Tensor, t: int, t_prev: int):
79 | """2nd order step."""
80 | # Prepare parameters
81 | sigmas_t = self.sigmas[t]
82 | sigmas_t_prev = self.sigmas[t_prev] if t_prev >= 0 else torch.tensor(0.0)
83 |
84 | # Predict x0
85 | predict = self.predict(model_output, xt_prev, t_prev)
86 | pred_x0 = predict['pred_x0']
87 |
88 | # Calculate derivative
89 | bar_xt_prev = (1 + sigmas_t_prev ** 2).sqrt() * xt_prev
90 | derivative = (bar_xt_prev - pred_x0) / sigmas_t_prev
91 | derivative = (derivative + self._1st_order_derivative) / 2
92 |
93 | # Calculate the x{t-1}
94 | bar_xt = (1 + sigmas_t ** 2).sqrt() * self._1st_order_xt
95 | bar_sample = bar_xt + derivative * (sigmas_t_prev - sigmas_t)
96 | sample = bar_sample / (1 + sigmas_t_prev ** 2).sqrt()
97 |
98 | # Clear the 1st order info
99 | self._1st_order_derivative = None
100 | self._1st_order_xt = None
101 |
102 | return {'sample': sample, 'pred_x0': pred_x0}
103 |
104 | def sample_loop(
105 | self, model: nn.Module, init_noise: Tensor,
106 | tqdm_kwargs: Dict = None, model_kwargs: Dict = None,
107 | ):
108 | tqdm_kwargs = dict() if tqdm_kwargs is None else tqdm_kwargs
109 | model_kwargs = dict() if model_kwargs is None else model_kwargs
110 |
111 | img = init_noise
112 | sample_seq = self.respaced_seq.tolist()
113 | sample_seq_prev = [-1] + self.respaced_seq[:-1].tolist()
114 | pbar = tqdm.tqdm(total=len(sample_seq), **tqdm_kwargs)
115 | for t, t_prev in zip(reversed(sample_seq), reversed(sample_seq_prev)):
116 | # 1st order step
117 | t_batch = torch.full((img.shape[0], ), t, device=self.device, dtype=torch.long)
118 | model_output = model(img, t_batch, **model_kwargs)
119 | out = self.denoise_1st_order(model_output, img, t, t_prev)
120 | img = out['sample']
121 |
122 | if t_prev >= 0:
123 | # 2nd order step
124 | t_prev_batch = torch.full((img.shape[0], ), t_prev, device=self.device, dtype=torch.long)
125 | model_output = model(img, t_prev_batch, **model_kwargs)
126 | out = self.denoise_2nd_order(model_output, img, t, t_prev)
127 | img = out['sample']
128 |
129 | pbar.update(1)
130 | yield out
131 | pbar.close()
132 |
--------------------------------------------------------------------------------
/diffusions/schedule.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | def get_beta_schedule(
6 | total_steps: int = 1000,
7 | beta_schedule: str = 'linear',
8 | beta_start: float = 0.0001,
9 | beta_end: float = 0.02,
10 | ):
11 | """Get the beta schedule for diffusion.
12 |
13 | Args:
14 | total_steps: Number of diffusion steps.
15 | beta_schedule: Type of beta schedule. Options: 'linear', 'quad', 'const', 'cosine'.
16 | beta_start: Starting beta value.
17 | beta_end: Ending beta value.
18 |
19 | Returns:
20 | A Tensor of length `total_steps`.
21 |
22 | """
23 | if beta_schedule == 'linear':
24 | return torch.linspace(beta_start, beta_end, total_steps, dtype=torch.float64)
25 | elif beta_schedule == 'quad':
26 | return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, total_steps, dtype=torch.float64) ** 2
27 | elif beta_schedule == 'const':
28 | return torch.full((total_steps, ), fill_value=beta_end, dtype=torch.float64)
29 | elif beta_schedule == 'cosine':
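        # Cosine schedule proposed in "Improved Denoising Diffusion Probabilistic Models" (Nichol & Dhariwal, 2021),
        # with offset s = 0.008; betas are clipped at 0.999 to avoid singularities near t = T.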
30 | def alpha_bar(t):
31 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
32 | betas = [
33 | min(1 - alpha_bar((i + 1) / total_steps) / alpha_bar(i / total_steps), 0.999)
34 | for i in range(total_steps)
35 | ]
36 | return torch.tensor(betas)
37 | else:
38 | raise ValueError(f'Beta schedule {beta_schedule} is not supported.')
39 |
40 |
41 | def get_respaced_seq(
42 | total_steps: int = 1000,
43 | respace_type: str = 'uniform',
44 | respace_steps: int = 100,
45 | ):
46 | """Get respaced time sequence for fast inference.
47 |
48 | Args:
49 | total_steps: Number of the original diffusion steps.
50 | respace_type: Type of respaced timestep sequence. Options: 'uniform', 'uniform-leading', 'uniform-linspace',
51 | 'uniform-trailing', 'quad', 'none', None.
52 | respace_steps: Length of respaced timestep sequence.
53 |
54 | Returns:
55 | A Tensor of length `respace_steps`, containing indices that are preserved in the respaced sequence.
56 |
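    Example:
        With total_steps=1000, respace_type='uniform' and respace_steps=10, the returned sequence is
        [0, 100, 200, ..., 900].
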
57 | """
58 | if respace_type in ['uniform', 'uniform-leading']:
59 | space = total_steps // respace_steps
60 | seq = torch.arange(0, total_steps, space).long()
61 | elif respace_type == 'uniform-linspace':
62 | seq = torch.linspace(0, total_steps - 1, respace_steps).long()
63 | elif respace_type == 'uniform-trailing':
64 | space = total_steps // respace_steps
65 | seq = torch.arange(total_steps-1, -1, -space).long().flip(dims=[0])
66 | elif respace_type == 'quad':
67 | seq = torch.linspace(0, math.sqrt(total_steps * 0.8), respace_steps) ** 2
68 | seq = torch.floor(seq).long()
69 | elif respace_type is None or respace_type == 'none':
70 | seq = torch.arange(0, total_steps).long()
71 | else:
72 | raise ValueError(f'Respace type {respace_type} is not supported.')
73 | return seq
74 |
75 |
76 | def _test_betas():
77 | import matplotlib.pyplot as plt
78 | fig, ax = plt.subplots(1, 2, figsize=(10, 4))
79 |
80 | betas_linear = get_beta_schedule(
81 | total_steps=1000,
82 | beta_schedule='linear',
83 | beta_start=0.0001,
84 | beta_end=0.02,
85 | )
86 | alphas_bar_linear = torch.cumprod(1. - betas_linear, dim=0)
87 | betas_quad = get_beta_schedule(
88 | total_steps=1000,
89 | beta_schedule='quad',
90 | beta_start=0.0001,
91 | beta_end=0.02,
92 | )
93 | alphas_bar_quad = torch.cumprod(1. - betas_quad, dim=0)
94 | betas_cosine = get_beta_schedule(
95 | total_steps=1000,
96 | beta_schedule='cosine',
97 | beta_start=0.0001,
98 | beta_end=0.02,
99 | )
100 | alphas_bar_cosine = torch.cumprod(1. - betas_cosine, dim=0)
101 |
102 | ax[0].plot(torch.arange(1000), betas_linear, label='linear')
103 | ax[0].plot(torch.arange(1000), betas_quad, label='quad')
104 | ax[0].plot(torch.arange(1000), betas_cosine, label='cosine')
105 | ax[0].set_title(r'$\beta_t$')
106 | ax[0].set_xlabel(r'$t$')
107 | ax[0].legend()
108 | ax[1].plot(torch.arange(1000), alphas_bar_linear, label='linear')
109 | ax[1].plot(torch.arange(1000), alphas_bar_quad, label='quad')
110 | ax[1].plot(torch.arange(1000), alphas_bar_cosine, label='cosine')
111 | ax[1].set_title(r'$\bar\alpha_t$')
112 | ax[1].set_xlabel(r'$t$')
113 | ax[1].legend()
114 | plt.show()
115 |
116 |
117 | def _test_respace():
118 | seq = get_respaced_seq(
119 | total_steps=1000,
120 | respace_type='uniform-leading',
121 | respace_steps=10,
122 | )
123 | print('uniform-leading:\t', seq)
124 | seq = get_respaced_seq(
125 | total_steps=1000,
126 | respace_type='uniform-linspace',
127 | respace_steps=10,
128 | )
129 | print('uniform-linspace:\t', seq)
130 | seq = get_respaced_seq(
131 | total_steps=1000,
132 | respace_type='uniform-trailing',
133 | respace_steps=10,
134 | )
135 | print('uniform-trailing:\t', seq)
136 |
137 |
138 | if __name__ == '__main__':
139 | _test_betas()
140 | # _test_respace()
141 |
--------------------------------------------------------------------------------
/docs/CLIP Guidance.md:
--------------------------------------------------------------------------------
1 | # CLIP Guidance
2 |
3 | CLIP Guidance is a technique to generate images following an input text description with a pretrained diffusion model and a pretrained CLIP model. It uses CLIP score to guide the reverse diffusion process during sampling. To be specific, each reverse step is modified to:
4 | $$
5 | \begin{align}
6 | &p_\theta(\mathbf x_{t-1}\vert\mathbf x_{t})=\mathcal N(\mathbf x_{t-1};\mu_\theta(\mathbf x_t,t){\color{dodgerblue}{+s\sigma_t^2\nabla_{\mathbf x_t}\mathcal L_{\text{CLIP}}}},\sigma_t^2\mathbf I)\\
7 | &\mathcal L_\text{CLIP}=E_\text{image}(\mathbf x_\theta(\mathbf x_t,t))\cdot E_\text{text}(y)
8 | \end{align}
9 | $$
10 | where $y$ is the input text, $E_\text{image}$ and $E_\text{text}$ are CLIP's image and text encoders, and $s$ is a hyper-parameter controlling the scale of guidance.
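
A minimal sketch of this guided mean in PyTorch, purely to illustrate the equation above (it is not the repo's `diffusions/guidance/clip_guidance.py`; `predict_x0`, `clip_image_encode` and `text_embed` are assumed callables/tensors):

```python
import torch

def clip_guided_mean(model_mean, var, xt, t, predict_x0, clip_image_encode, text_embed, s):
    """Shift the DDPM posterior mean by s * sigma_t^2 * grad_{x_t} L_CLIP."""
    with torch.enable_grad():
        xt = xt.detach().requires_grad_(True)
        x0_hat = predict_x0(xt, t)                     # x_theta(x_t, t)
        image_embed = clip_image_encode(x0_hat)        # E_image(.)
        clip_score = (image_embed * text_embed).sum()  # L_CLIP = E_image . E_text, summed over the batch
        grad = torch.autograd.grad(clip_score, xt)[0]  # gradient w.r.t. x_t
    return model_mean + s * var * grad
```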
11 |
12 |
13 |
14 | ## Sampling
15 |
16 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
17 |
18 | ```shell
19 | accelerate-launch scripts/sample_clip_guidance.py -c CONFIG \
20 | --weights WEIGHTS \
21 | --text TEXT \
22 | --n_samples N_SAMPLES \
23 | --save_dir SAVE_DIR \
24 | [--seed SEED] \
25 | [--var_type VAR_TYPE] \
26 | [--respace_type RESPACE_TYPE] \
27 | [--respace_steps RESPACE_STEPS] \
28 | [--guidance_weight GUIDANCE_WEIGHT] \
29 | [--clip_model CLIP_MODEL] \
30 | [--batch_size BATCH_SIZE]
31 | ```
32 |
33 | Basic arguments:
34 |
35 | - `-c CONFIG`: path to the configuration file.
36 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
37 | - `--text TEXT`: text description. Please wrap your description with quotation marks if it contains spaces, e.g., `--text 'a lovely dog'`.
38 | - `--n_samples N_SAMPLES`: number of samples to generate.
39 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
40 | - `--guidance_weight GUIDANCE_WEIGHT`: guidance weight (strength).
41 | - `--clip_model CLIP_MODEL`: name of the CLIP model. Defaults to "openai/clip-vit-base-patch32".
42 |
43 | Advanced arguments:
44 |
45 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
46 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
47 |
48 | See more details by running `python sample_clip_guidance.py -h`.
49 |
50 |
51 |
52 | ## Results
53 |
54 | **Pretrained on CelebA-HQ 256x256**:
55 |
56 |
57 |
58 |
59 |
60 | All the images are sampled with 50 DDPM steps.
61 |
62 | Images in the same row are sampled with the same random seed and thus share similar semantics. The first column shows the original samples (i.e., guidance scale=0). The next 3 columns are sampled using text description "a young girl with brown hair", with increasing guidance scales of 10, 50, and 100. The following 3 columns are similarly sampled with text description "an old man with a smile".
63 |
64 | As expected, a larger guidance scale leads to greater changes. However, some results fail to match the description in attributes such as gender and hair color. Furthermore, guidance scales larger than 100 drastically degrade image quality.
65 |
--------------------------------------------------------------------------------
/docs/Classifier-Free Guidance.md:
--------------------------------------------------------------------------------
1 | # Classifier-Free Guidance
2 |
3 | > Ho, Jonathan, and Tim Salimans. "Classifier-Free Diffusion Guidance." In *NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications*. 2021.
4 |
5 |
6 |
7 | ## Training
8 |
9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
10 |
11 | ```shell
12 | accelerate-launch scripts/train_ddpm_cfg.py -c CONFIG [-e EXP_DIR] [--key value ...]
13 | ```
14 |
15 | Arguments:
16 |
17 | - `-c CONFIG`: path to the training configuration file.
18 | - `-e EXP_DIR`: results (logs, checkpoints, tensorboard, etc.) will be saved to `EXP_DIR`. Defaults to `runs/exp-{current time}/`.
19 | - `--key value`: modify configuration items in `CONFIG` via CLI.
20 |
21 | For example, to train on CIFAR-10 with default settings:
22 |
23 | ```shell
24 | accelerate-launch scripts/train_ddpm_cfg.py -c ./configs/ddpm_cfg_cifar10.yaml
25 | ```
26 |
27 | To change the default `p_uncond` (the probability of dropping the condition during training; see the sketch below) in `./configs/ddpm_cfg_cifar10.yaml` from 0.2 to 0.1:
28 |
29 | ```shell
30 | accelerate-launch scripts/train_ddpm_cfg.py -c ./configs/ddpm_cfg_cifar10.yaml --train.p_uncond 0.1
31 | ```
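
For intuition, classifier-free training drops the class label with probability `p_uncond` so that a single network learns both the conditional and unconditional predictions. A minimal sketch of that dropout step (illustrative only; the actual logic lives in the training script, and `null_label` is an assumed placeholder id for the unconditional case):

```python
import torch

def drop_condition(y: torch.Tensor, p_uncond: float, null_label: int) -> torch.Tensor:
    # With probability p_uncond, replace each label with the "null" (unconditional) token.
    mask = torch.rand(y.shape[0], device=y.device) < p_uncond
    return torch.where(mask, torch.full_like(y, null_label), y)
```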
32 |
33 |
34 |
35 | ## Sampling
36 |
37 | ```shell
38 | accelerate-launch scripts/sample_cfg.py -c CONFIG \
39 | --weights WEIGHTS \
40 | --sampler {ddpm,ddim} \
41 | --n_samples_each_class N_SAMPLES_EACH_CLASS \
42 | --save_dir SAVE_DIR \
43 | --guidance_scale GUIDANCE_SCALE \
44 | [--seed SEED] \
45 | [--class_ids CLASS_IDS [CLASS_IDS ...]] \
46 | [--respace_type RESPACE_TYPE] \
47 | [--respace_steps RESPACE_STEPS] \
48 | [--ddim] \
49 | [--ddim_eta DDIM_ETA] \
50 | [--batch_size BATCH_SIZE]
51 | ```
52 |
53 | Basic arguments:
54 |
55 | - `-c CONFIG`: path to the configuration file.
56 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
57 | - `--sampler {ddpm,ddim}`: set the sampler.
58 | - `--n_samples_each_class N_SAMPLES_EACH_CLASS`: number of samples for each class.
59 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
60 | - `--guidance_scale GUIDANCE_SCALE`: the guidance scale factor $s$ (see the sketch below):
61 | - $s=0$: unconditional generation
62 | - $s=1$: non-guided conditional generation
63 | - $s>1$: guided conditional generation
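
The sketch below shows how the scale linearly combines the conditional and unconditional predictions (assuming the model predicts noise; this is the standard classifier-free guidance formula, not a copy of the repo's implementation):

```python
def classifier_free_guidance(eps_cond, eps_uncond, s: float):
    # s=0 -> unconditional, s=1 -> plain conditional, s>1 -> guided conditional
    return eps_uncond + s * (eps_cond - eps_uncond)
```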
64 |
65 | Advanced arguments:
66 |
67 | - `--class_ids CLASS_IDS [CLASS_IDS ...]`: a list of class ids to sample. If not specified, all classes will be sampled.
68 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
69 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
70 |
71 | See more details by running `python sample_cfg.py -h`.
72 |
73 | For example, to sample 10 images for class (0, 2, 4, 8) from a pretrained CIFAR-10 model with guidance scale 3 using 100 DDIM steps:
74 |
75 | ```shell
76 | accelerate-launch scripts/sample_cfg.py -c ./configs/ddpm_cfg_cifar10.yaml --weights /path/to/model/weights --sampler ddim --n_samples_each_class 10 --save_dir ./samples/cfg-cifar10 --guidance_scale 3 --class_ids 0 2 4 8 --respace_steps 100
77 | ```
78 |
79 |
80 |
81 | ## Evaluation
82 |
83 | Sample 10K-50K images following the previous section and evaluate image quality with tools like [torch-fidelity](https://github.com/toshas/torch-fidelity), [pytorch-fid](https://github.com/mseitzer/pytorch-fid), [clean-fid](https://github.com/GaParmar/clean-fid), etc.
84 |
85 |
86 |
87 | ## Results
88 |
89 | **FID and IS on CIFAR-10 32x32**:
90 |
91 | | guidance scale | FID ↓ | IS ↑ |
92 | | :------------------------: | :------: | :-------------: |
93 | | 0 (unconditional) | 6.2904 | 8.9851 ± 0.0825 |
94 | | 1 (non-guided conditional) | 4.6630 | 9.1763 ± 0.1201 |
95 | | 3 (guided conditional) | 10.2304 | 9.6252 ± 0.0977 |
96 | | 5 (guided conditional) | 16.23021 | 9.3210 ± 0.0744 |
97 |
98 | - The images are sampled using DDIM with 50 steps.
99 | - All the metrics are evaluated on 50K samples.
100 | - FID is more sensitive to diversity while IS mainly reflects fidelity, so the table shows the diversity-fidelity trade-off as the guidance scale increases.
101 |
102 |
103 |
104 | **Samples with different guidance scales on CIFAR-10 32x32**:
105 |
106 |
107 |
108 |
109 | From left to right: $s=0$ (unconditional), $s=1.0$ (non-guided conditional), $s=3.0$, $s=5.0$. Each row corresponds to a class.
110 |
111 |
112 |
113 | **Samples with different guidance scales on ImageNet 256x256**:
114 |
115 | The pretrained models are sourced from [openai/guided-diffusion](https://github.com/openai/guided-diffusion). Note that these models were originally designed for classifier guidance and are therefore either conditional-only or unconditional-only. For classifier-free guidance, however, it is more convenient if a single model can handle both the conditional and unconditional cases. To address this, I define a new class [UNetCombined](../models/adm/unet_combined.py), which combines the conditional-only and unconditional-only models into a single model. The pretrained weights also need to be combined before loading, which can be done with the following script:
116 |
117 | ```python
118 | import yaml
119 | from models.adm.unet_combined import UNetCombined
120 |
121 |
122 | config_path = './weights/openai/guided-diffusion/256x256_diffusion.yaml'
123 | weight_cond_path = './weights/openai/guided-diffusion/256x256_diffusion.pt'
124 | weight_uncond_path = './weights/openai/guided-diffusion/256x256_diffusion_uncond.pt'
125 | save_path = './weights/openai/guided-diffusion/256x256_diffusion_combined.pt'
126 |
127 | with open(config_path, 'r') as f:
128 | cfg = yaml.safe_load(f)
129 | model = UNetCombined(**cfg['model']['params'])
130 | model.combine_weights(weight_cond_path, weight_uncond_path, save_path)
131 | ```
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 | From left to right: $s=1.0$ (non-guided conditional), $s=2.0$, $s=3.0$. Each row corresponds to a class.
140 |
--------------------------------------------------------------------------------
/docs/DDIB.md:
--------------------------------------------------------------------------------
1 | # DDIB
2 |
3 | > Su, Xuan, Jiaming Song, Chenlin Meng, and Stefano Ermon. "Dual diffusion implicit bridges for image-to-image translation." *arXiv preprint arXiv:2203.08382* (2022).
4 |
5 |
6 |
7 | ## Sampling
8 |
9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
10 |
11 | ```shell
12 | accelerate-launch scripts/sample_ddib.py -c CONFIG \
13 | --weights WEIGHTS \
14 | --input_dir INPUT_DIR \
15 | --save_dir SAVE_DIR \
16 | --class_A CLASS_A \
17 | --class_B CLASS_B \
18 | [--seed SEED] \
19 | [--respace_type RESPACE_TYPE] \
20 | [--respace_steps RESPACE_STEPS] \
21 | [--batch_size BATCH_SIZE]
22 | ```
23 |
24 | Basic arguments:
25 |
26 | - `-c CONFIG`: path to the configuration file.
27 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
28 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved.
29 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
30 | - `--class_A CLASS_A`: input class label.
31 | - `--class_B CLASS_B`: output class label.
32 |
33 | Advanced arguments:
34 |
35 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
36 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
37 |
38 | See more details by running `python sample_ddib.py -h`.
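
Conceptually, the script encodes each input image into a latent noise with deterministic DDIM inversion under `class_A`, and then decodes that noise with DDIM sampling under `class_B`. A rough sketch with hypothetical helper names (`ddim_invert` / `ddim_sample` are not the repo's exact API):

```python
def ddib_translate(model, x_A, class_A, class_B, ddim_invert, ddim_sample):
    # source image -> shared Gaussian-like latent via deterministic DDIM inversion
    x_T = ddim_invert(model, x_A, condition=class_A)
    # shared latent -> target domain via deterministic DDIM sampling
    return ddim_sample(model, x_T, condition=class_B)
```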
39 |
40 |
41 |
42 | ## Results
43 |
44 | **ImageNet 256x256 (conditional)** with pretrained model from [openai/guided-diffusion](https://github.com/openai/guided-diffusion):
45 |
46 |
47 |
48 |
49 |
50 | Notes: All images are sampled with 100 DDIM steps.
51 |
52 | The results are not as good as expected. Some are acceptable, such as the Sussex Spaniel, Husky and Tiger in the 3rd row, but the others are not.
53 |
--------------------------------------------------------------------------------
/docs/DDIM.md:
--------------------------------------------------------------------------------
1 | # DDIM
2 |
3 | > Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." In *International Conference on Learning Representations*. 2021.
4 |
5 |
6 |
7 | ## Sampling
8 |
9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
10 |
11 | ```shell
12 | accelerate-launch scripts/sample_uncond.py -c CONFIG \
13 | --weights WEIGHTS \
14 | --sampler ddim \
15 | --ddim_eta DDIM_ETA \
16 | --n_samples N_SAMPLES \
17 | --save_dir SAVE_DIR \
18 | [--seed SEED] \
19 | [--batch_size BATCH_SIZE] \
20 | [--respace_type RESPACE_TYPE] \
21 | [--respace_steps RESPACE_STEPS] \
22 | [--mode {sample,interpolate,reconstruction}] \
23 | [--n_interpolate N_INTERPOLATE] \
24 | [--input_dir INPUT_DIR]
25 | ```
26 |
27 | Basic arguments:
28 |
29 | - `-c CONFIG`: path to the inference configuration file.
30 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
31 | - `--sampler ddim`: set the sampler to DDIM.
32 | - `--n_samples N_SAMPLES`: number of samples.
33 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
34 | - `--mode MODE`: choose a sampling mode; the options are:
35 | - "sample" (default): randomly sample images
36 | - "interpolate": sample two random images and interpolate between them. Use `--n_interpolate` to specify the number of images in between.
37 | - "reconstruction": encode a real image from dataset with **DDIM inversion** (DDIM encoding), and then decode it with DDIM sampling.
38 |
39 | Advanced arguments:
40 |
41 | - `--ddim_eta DDIM_ETA`: the parameter η in DDIM sampling. η=0 gives deterministic DDIM sampling, while η=1 recovers DDPM with the fixed-small variance.
42 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
43 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
44 |
45 | See more details by running `python sample_uncond.py -h`.
46 |
47 | For example, to sample 50000 images from a pretrained CIFAR-10 model with 100 DDIM steps:
48 |
49 | ```shell
50 | accelerate-launch scripts/sample_uncond.py -c ./configs/ddpm_cifar10.yaml --weights /path/to/model/weights --sampler ddim --n_samples 50000 --save_dir ./samples/ddim-cifar10 --respace_steps 100
51 | ```
52 |
53 |
54 |
55 | ## Evaluation
56 |
57 | Sample 10K-50K images following the previous section and evaluate image quality with tools like [torch-fidelity](https://github.com/toshas/torch-fidelity), [pytorch-fid](https://github.com/mseitzer/pytorch-fid), [clean-fid](https://github.com/GaParmar/clean-fid), etc.
58 |
59 |
60 |
61 | ## Results
62 |
63 | **FID and IS on CIFAR-10 32x32**:
64 |
65 | All the metrics are evaluated on 50K samples using [torch-fidelity](https://torch-fidelity.readthedocs.io/en/latest/index.html) library.
66 |
67 |
68 |
69 | |  eta  |    timesteps     |  FID ↓  |      IS ↑       |
70 | | :---: | :--------------: | :-----: | :-------------: |
71 | |  0.0  |       1000       | 4.1892  | 9.0626 ± 0.1093 |
72 | |       | 100 (10x faster) | 6.0508  | 8.8424 ± 0.0862 |
73 | |       | 50 (20x faster)  | 7.7011  | 8.7076 ± 0.1021 |
74 | |       | 20 (50x faster)  | 11.6506 | 8.4744 ± 0.0879 |
75 | |       | 10 (100x faster) | 18.9559 | 8.0852 ± 0.1137 |
102 |
103 |
104 | **Sample with fewer steps**:
105 |
106 |
107 |
108 |
109 |
110 | From top to bottom: 10 steps, 50 steps, 100 steps and 1000 steps. Fewer steps lead to blurrier results, but it is hard to visually distinguish the 50- and 100-step samples from the 1000-step ones.
111 |
112 |
113 |
114 | **Spherical linear interpolation (slerp) between two samples (sample with 100 steps)**:
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 | **Reconstruction (sample with 100 steps)**:
126 |
127 |
128 |
129 |
130 |
131 | In each pair, the image on the left is a real image sampled from the dataset, and the image on the right is its reconstruction produced by DDIM inversion followed by DDIM sampling.
132 |
133 |
134 |
135 | **Reconstruction (sample with 1000 steps)**:
136 |
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/docs/DDPM-IP.md:
--------------------------------------------------------------------------------
1 | # DDPM-IP
2 |
3 | > Ning, Mang, Enver Sangineto, Angelo Porrello, Simone Calderara, and Rita Cucchiara. "Input Perturbation Reduces Exposure Bias in Diffusion Models." *arXiv preprint arXiv:2301.11706* (2023).
4 |
5 |
6 |
7 | ## Training
8 |
9 | Training is almost the same as DDPM (see [doc](./DDPM.md)), except that `diffusions.ddpm_ip.DDPM_IP` is used instead of `diffusions.ddpm.DDPM` (the core idea is sketched below).
10 |
11 | For example, to train on CIFAR-10 with default settings:
12 |
13 | ```shell
14 | accelerate-launch scripts/train_ddpm.py -c ./configs/ddpm_cifar10.yaml --diffusion.target diffusions.ddpm_ip.DDPM_IP
15 | ```
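
The core change is in how the noisy input is constructed during training: the noise used to build $x_t$ is perturbed by a second Gaussian, while the network is still trained to predict the unperturbed noise. A minimal sketch (not the repo's `DDPM_IP` class; `gamma` is the perturbation strength, around 0.1 in the paper):

```python
import torch

def ddpm_ip_training_pair(x0, eps, alpha_bar_t, gamma: float = 0.1):
    # input perturbation: eps' = eps + gamma * xi, with xi ~ N(0, I)
    eps_perturbed = eps + gamma * torch.randn_like(eps)
    xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps_perturbed
    # the training target is still the original eps
    return xt, eps
```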
16 |
17 |
18 |
19 | ## Sampling & Evaluation
20 |
21 | Exactly the same as DDPM, refer to [doc](./DDPM.md) for more information.
22 |
23 |
24 |
25 | ## Results
26 |
27 | **FID and IS on CIFAR-10 32x32**:
28 |
29 | All the metrics are evaluated on 50K samples using [torch-fidelity](https://torch-fidelity.readthedocs.io/en/latest/index.html) library.
30 |
31 |
32 |
33 | | Type of variance | timesteps |  FID ↓   |       IS ↑       |
34 | | :--------------: | :-------: | :------: | :--------------: |
35 | |   fixed-large    |   1000    |  3.2497  | 9.4885 ± 0.09244 |
36 | |                  |    100    | 46.7994  | 8.5720 ± 0.0917  |
37 | |                  |    50     | 87.1883  | 6.1429 ± 0.0630  |
38 | |                  |    10     | 268.1108 | 1.5842 ± 0.0055  |
39 | |   fixed-small    |   1000    |  4.4868  | 9.1092 ± 0.1025  |
40 | |                  |    100    |  9.2460  | 8.7068 ± 0.0813  |
41 | |                  |    50     | 12.7965  | 8.4902 ± 0.0701  |
42 | |                  |    10     | 35.5062  | 7.3680 ± 0.1092  |
81 |
82 |
83 | The results are substantially better than DDPM with the fixed-small variance, but not with the fixed-large variance.
84 |
85 |
--------------------------------------------------------------------------------
/docs/ILVR.md:
--------------------------------------------------------------------------------
1 | # ILVR
2 |
3 | Iterative Latent Variable Refinement (ILVR).
4 |
5 | > Choi, Jooyoung, Sungwon Kim, Yonghyun Jeong, Youngjune Gwon, and Sungroh Yoon. “ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models.” In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp. 14347-14356. IEEE, 2021.
6 |
7 |
8 |
9 | ## Sampling
10 |
11 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
12 |
13 | ```shell
14 | accelerate-launch scripts/sample_ilvr.py -c CONFIG \
15 | --weights WEIGHTS \
16 | --input_dir INPUT_DIR \
17 | --save_dir SAVE_DIR \
18 | [--seed SEED] \
19 | [--var_type VAR_TYPE] \
20 | [--respace_type RESPACE_TYPE] \
21 | [--respace_steps RESPACE_STEPS] \
22 | [--downsample_factor DOWNSAMPLE_FACTOR] \
23 | [--interp_method {cubic,lanczos2,lanczos3,linear,box}] \
24 | [--batch_size BATCH_SIZE]
25 | ```
26 |
27 | Basic arguments:
28 |
29 | - `-c CONFIG`: path to the configuration file.
30 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
31 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved.
32 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
33 | - `--downsample_factor DOWNSAMPLE_FACTOR`: a higher factor leads to more diverse results.
34 |
35 | Advanced arguments:
36 |
37 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
38 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
39 |
40 | See more details by running `python sample_ilvr.py -h`.
41 |
42 |
43 |
44 | ## Notes
45 |
46 | Using correct image resizing methods ([ResizeRight](https://github.com/assafshocher/ResizeRight)) is **CRUCIAL**! The default resizing functions in PyTorch (`torch.nn.functional.interpolate`) will damage the results.
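
For reference, each ILVR reverse step keeps the model's high-frequency content but copies the low-frequency content from the (noised) reference image. A minimal sketch, where `phi` stands for the ResizeRight down-and-up resampling at the chosen `downsample_factor` and `y_prev` is the reference image diffused to the same timestep:

```python
def ilvr_refine(x_prev, y_prev, phi):
    # x'_{t-1} = phi(y_{t-1}) + (x_{t-1} - phi(x_{t-1}))
    return x_prev - phi(x_prev) + phi(y_prev)
```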
47 |
48 |
49 |
50 | ## Results
51 |
52 | **Pretrained CelebA-HQ 256x256**:
53 |
54 | Note: All the images are sampled with 50 DDPM steps.
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/docs/Mask Guidance.md:
--------------------------------------------------------------------------------
1 | # Mask Guidance
2 |
3 | Mask Guidance is a technique to fill the masked area of an input image with a pretrained diffusion model. It was first proposed in [1] and further developed in [2], [3], etc. for image inpainting.
4 | 
5 | Directly applying mask guidance may lead to semantic inconsistency between the masked and unmasked areas. To overcome this problem, RePaint [3] proposed a resampling strategy that repeatedly moves forward and backward along the Markov chain during sampling (see the sketch below).
6 |
7 | > [1]. Song, Yang, and Stefano Ermon. “Generative modeling by estimating gradients of the data distribution.”
8 | > Advances in neural information processing systems 32 (2019).
9 | >
10 | > [2]. Avrahami, Omri, Dani Lischinski, and Ohad Fried. “Blended diffusion for text-driven editing of natural
11 | > images.” In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 18208
12 | > -18218. 2022.
13 | >
14 | > [3]. Lugmayr, Andreas, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, and Luc Van Gool. “Repaint:
15 | > Inpainting using denoising diffusion probabilistic models.” In Proceedings of the IEEE/CVF Conference on
16 | > Computer Vision and Pattern Recognition, pp. 11461-11471. 2022.
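
A minimal sketch of the mask-guided reverse step described above: the known region is taken from the input image diffused to step t-1, and only the masked region comes from the model (illustrative names; `q_sample` stands for the forward diffusion of the known image):

```python
def mask_guided_step(x_prev_model, x0_known, mask, q_sample, t_prev):
    # mask == 1 marks known pixels, mask == 0 the area to be inpainted
    x_prev_known = q_sample(x0_known, t_prev)
    return mask * x_prev_known + (1 - mask) * x_prev_model
```

With RePaint's resampling, each such step is additionally followed by diffusing the intermediate result forward again by `j` steps and re-denoising, repeated `r` times, to harmonize the two regions.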
17 |
18 |
19 |
20 | ## Sampling
21 |
22 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
23 |
24 | ```shell
25 | accelerate-launch scripts/sample_mask_guidance.py -c CONFIG \
26 | --weights WEIGHTS \
27 | --input_dir INPUT_DIR \
28 | --save_dir SAVE_DIR \
29 | [--seed SEED] \
30 | [--var_type VAR_TYPE] \
31 | [--respace_type RESPACE_TYPE] \
32 | [--respace_steps RESPACE_STEPS] \
33 | [--resample] \
34 | [--resample_r RESAMPLE_R] \
35 | [--resample_j RESAMPLE_J] \
36 | [--batch_size BATCH_SIZE]
37 | ```
38 |
39 | Basic arguments:
40 |
41 | - `-c CONFIG`: path to the configuration file.
42 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
43 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved.
44 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
45 | - `--resample`: use the resampling strategy proposed in the RePaint paper [3]. This strategy has two hyperparameters:
46 | - `--resample_r RESAMPLE_R`: number of resampling iterations.
47 | - `--resample_j RESAMPLE_J`: jump length.
48 |
49 | Advanced arguments:
50 |
51 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
52 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
53 |
54 | See more details by running `python sample_mask_guidance.py -h`.
55 |
56 |
57 |
58 | ## Results
59 |
60 | **ImageNet 256x256** with pretrained model from [openai/guided-diffusion](https://github.com/openai/guided-diffusion):
61 |
62 |
63 |
64 |
65 |
66 | Notes:
67 |
68 | - All the images are sampled with 50 DDPM steps.
69 | - Jump length $j$ is fixed to 10.
70 | - $r=1$ is equivalent to the original DDPM sampling (w/o resampling).
71 |
72 |
--------------------------------------------------------------------------------
/docs/SDEdit.md:
--------------------------------------------------------------------------------
1 | # SDEdit
2 |
3 | > Meng, Chenlin, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan Zhu, and Stefano Ermon. "SDEdit: Guided image synthesis and editing with stochastic differential equations." In *International Conference on Learning Representations*. 2022.
4 |
5 |
6 |
7 | ## Sampling
8 |
9 | This repo uses the [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) library for multi-GPUs/fp16 supports. Please read the [documentation](https://huggingface.co/docs/accelerate/basic_tutorials/launch#using-accelerate-launch) on how to launch the script on different platforms.
10 |
11 | ```shell
12 | accelerate-launch scripts/sample_sdedit.py -c CONFIG \
13 | --weights WEIGHTS \
14 | --input_dir INPUT_DIR \
15 | --save_dir SAVE_DIR \
16 | --edit_steps EDIT_STEPS \
17 | [--seed SEED] \
18 | [--var_type VAR_TYPE] \
19 | [--respace_type RESPACE_TYPE] \
20 | [--respace_steps RESPACE_STEPS] \
21 | [--batch_size BATCH_SIZE]
22 | ```
23 |
24 | Basic arguments:
25 |
26 | - `-c CONFIG`: path to the configuration file.
27 | - `--weights WEIGHTS`: path to the model weights (checkpoint) file.
28 | - `--input_dir INPUT_DIR`: path to the directory where input images are saved.
29 | - `--save_dir SAVE_DIR`: path to the directory where samples will be saved.
30 | - `--edit_steps EDIT_STEPS`: number of edit steps, which controls the realism-faithfulness trade-off (see the sketch at the end of this section).
31 |
32 | Advanced arguments:
33 |
34 | - `--respace_steps RESPACE_STEPS`: faster sampling that uses respaced timesteps.
34 | - `--batch_size BATCH_SIZE`: batch size on each process. Sampling in batches is faster, so set it as large as possible to fully utilize your devices.
36 |
37 | See more details by running `python sample_sdedit.py -h`.
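
Under the hood, SDEdit diffuses the input image forward for `edit_steps` steps and then runs the ordinary reverse process from there, which is why `edit_steps` trades faithfulness (few steps) against realism (many steps). A rough sketch with illustrative helper names (`q_sample`, `p_sample` are not the repo's exact API):

```python
def sdedit(x_input, edit_steps, q_sample, p_sample):
    # perturb the input to timestep t0 = edit_steps with the forward process
    x_t = q_sample(x_input, edit_steps)
    # denoise back to t = 0 with the pretrained reverse process
    for t in reversed(range(edit_steps)):
        x_t = p_sample(x_t, t)
    return x_t
```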
38 |
39 |
40 |
41 | ## Results
42 |
43 | **LSUN-Church 256x256** with pretrained model from [pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion):
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/docs/Samplers.md:
--------------------------------------------------------------------------------
1 | # Samplers: Fidelity-Speed Visualization
2 |
3 | Once a diffusion model is trained, we can use different samplers (SDE / ODE solvers) to generate samples. Currently, this repo supports the following samplers:
4 |
5 | - DDPM
6 | - DDIM
7 | - Euler
8 | - Heun
8 |
9 | We can choose the number of steps for the samplers to generate samples. Generally speaking, the more steps we use, the better the fidelity of the samples. However, the speed of the sampler also decreases as the number of steps increases. Therefore, it is important to choose the right number of steps to balance the trade-off between fidelity and speed.
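
The step count is obtained by respacing the original 1000-step schedule; for instance, with `get_respaced_seq` from `diffusions/schedule.py` (module path assumed from the repo layout):

```python
from diffusions.schedule import get_respaced_seq

# keep 50 of the original 1000 timesteps for roughly 20x faster sampling
timesteps = get_respaced_seq(total_steps=1000, respace_type='uniform', respace_steps=50)
print(timesteps)  # tensor([0, 20, 40, ..., 980])
```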
10 |
11 | The table and figure below show the trade-off between fidelity and speed of different samplers, based on the same model trained on CIFAR-10 following the standard DDPM. All the metrics are evaluated on 50K samples using [torch-fidelity](https://torch-fidelity.readthedocs.io/en/latest/index.html) library.
12 |
13 |
14 |
15 | |      sampler       | Steps |  NFE  |  FID ↓   |      IS ↑       |
16 | | :----------------: | :---: | :---: | :------: | :-------------: |
17 | | DDPM (fixed-large) | 1000  | 1000  |  3.0459  | 9.4515 ± 0.1179 |
18 | |                    |  100  |  100  | 46.5454  | 8.7223 ± 0.0923 |
19 | |                    |  50   |  50   | 85.2221  | 6.3863 ± 0.0894 |
20 | |                    |  20   |  20   | 183.3468 | 2.6885 ± 0.0176 |
21 | |                    |  10   |  10   | 266.7540 | 1.5870 ± 0.0092 |
22 | | DDPM (fixed-small) | 1000  | 1000  |  5.3727  | 9.0118 ± 0.0968 |
23 | |                    |  100  |  100  | 11.2191  | 8.6237 ± 0.0921 |
24 | |                    |  50   |  50   | 15.0471  | 8.4077 ± 0.1623 |
25 | |                    |  20   |  20   | 24.5131  | 7.9957 ± 0.1067 |
26 | |                    |  10   |  10   | 41.0479  | 7.1373 ± 0.0801 |
27 | |    DDIM (eta=0)    | 1000  | 1000  |  4.1892  | 9.0626 ± 0.1093 |
28 | |                    |  100  |  100  |  6.0508  | 8.8424 ± 0.0862 |
29 | |                    |  50   |  50   |  7.7011  | 8.7076 ± 0.1021 |
30 | |                    |  20   |  20   | 11.6506  | 8.4744 ± 0.0879 |
31 | |                    |  10   |  10   | 18.9559  | 8.0852 ± 0.1137 |
32 | |       Euler        | 1000  | 1000  |  4.2099  | 9.0678 ± 0.1191 |
33 | |                    |  100  |  100  |  6.0469  | 8.8511 ± 0.1054 |
34 | |                    |  50   |  50   |  7.6770  | 8.7217 ± 0.1122 |
35 | |                    |  20   |  20   | 11.6681  | 8.4362 ± 0.1151 |
36 | |                    |  10   |  10   | 18.7698  | 8.0287 ± 0.0781 |
37 | |        Heun        |  500  |  999  |  4.0046  | 9.0509 ± 0.1475 |
38 | |                    |  50   |  99   |  3.4687  | 9.2595 ± 0.1323 |
39 | |                    |  25   |  49   |  5.8767  | 9.4325 ± 0.1308 |
40 | |                    |  10   |  19   | 29.6088  | 8.4687 ± 0.0864 |
41 | |                    |   5   |   9   | 82.0586  | 5.3521 ± 0.0646 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 | Notes:
187 |
188 | - DDPM (fixed-small) is equivalent to DDIM (η=1).
189 | - DDPM (fixed-large) performs better than DDPM (fixed-small) with 1000 steps, but degrades drastically as the number of steps decreases. If you inspect the samples from DDPM (fixed-large) with ≤ 100 steps, you'll find that they still contain noticeable noise.
190 | - Euler sampler and DDIM (η=0) have almost the same performance.
191 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ema import EMA
2 | from .unet import UNet
3 | from .unet_categorial_adagn import UNetCategorialAdaGN
4 |
--------------------------------------------------------------------------------
/models/adm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/adm/__init__.py
--------------------------------------------------------------------------------
/models/adm/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
/models/adm/readme.md:
--------------------------------------------------------------------------------
1 | The ADM UNet architecture proposed in [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233).
2 |
3 | Codes are copied and adapted from https://github.com/openai/guided-diffusion.
4 |
--------------------------------------------------------------------------------
/models/adm/unet_combined.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .unet import UNetModel
4 |
5 |
6 | class UNetCombined(nn.Module):
7 | """ Combines a conditional-only model and an unconditional-only model into a single nn.Module.
8 |
9 | The guided diffusion models proposed by OpenAI are trained to be either conditional or unconditional,
10 | leading to difficulties if we want to use their pretrained models in classifier-free guidance. This
11 | class wraps a conditional model and an unconditional model, and decides which one to use based on the
12 | input class label.
13 |
14 | """
15 | def __init__(self, *args, **kwargs):
16 | super().__init__()
17 | assert kwargs.get('num_classes') is not None
18 | self.unet_cond = UNetModel(*args, **kwargs)
19 | kwargs_uncond = kwargs.copy()
20 | kwargs_uncond.update({'num_classes': None})
21 | self.unet_uncond = UNetModel(*args, **kwargs_uncond)
22 |
23 | def forward(self, x, timesteps, y=None):
24 | unet = self.unet_uncond if y is None else self.unet_cond
25 | return unet(x, timesteps, y)
26 |
27 | def combine_weights(self, cond_path, uncond_path, save_path):
28 | ckpt_cond = torch.load(cond_path, map_location='cpu')
29 | ckpt_uncond = torch.load(uncond_path, map_location='cpu')
30 | self.unet_cond.load_state_dict(ckpt_cond)
31 | self.unet_uncond.load_state_dict(ckpt_uncond)
32 | torch.save(self.state_dict(), save_path)
33 |
--------------------------------------------------------------------------------
/models/base_latent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 |
6 | class BaseLatent(nn.Module):
7 | def __init__(self, scale_factor: float = 1.0):
8 | super().__init__()
9 | self.register_buffer('scale_factor', torch.tensor(scale_factor))
10 | self.device = self.scale_factor.device
11 |
12 | def to(self, *args, **kwargs):
13 | super().to(*args, **kwargs)
14 | self.device = self.scale_factor.device
15 | return self
16 |
17 | def forward(self, x: Tensor, timesteps: Tensor):
18 | raise NotImplementedError
19 |
20 | def encode_latent(self, x: Tensor):
21 | raise NotImplementedError
22 |
23 | def decode_latent(self, z: Tensor):
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/models/dit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/dit/__init__.py
--------------------------------------------------------------------------------
/models/dit/autoencoder.py:
--------------------------------------------------------------------------------
1 | import diffusers
2 |
3 |
4 | def AutoEncoderKL(from_pretrained: str):
5 | return diffusers.AutoencoderKL.from_pretrained(from_pretrained)
6 |
--------------------------------------------------------------------------------
/models/dit/dit.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping, Any
2 | from omegaconf import OmegaConf
3 |
4 | from torch import Tensor
5 |
6 | from ..base_latent import BaseLatent
7 | from utils.misc import instantiate_from_config
8 |
9 |
10 | class DiT(BaseLatent):
11 | def __init__(
12 | self,
13 | vae_config: OmegaConf,
14 | vit_config: OmegaConf,
15 | scale_factor: float = 0.18215,
16 | ):
17 | super().__init__(scale_factor=scale_factor)
18 |
19 | self.vae = instantiate_from_config(vae_config)
20 | self.vit = instantiate_from_config(vit_config)
21 |
22 | def decode_latent(self, z: Tensor):
23 | z = 1. / self.scale_factor * z
24 | return self.vae.decode(z).sample
25 |
26 | def vit_forward(self, x: Tensor, t: Tensor, y: Tensor):
27 | return self.vit(x, t, y)
28 |
29 | def forward(self, x: Tensor, timesteps: Tensor, y: Tensor = None):
30 | return self.vit_forward(x, timesteps, y)
31 |
32 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
33 | self.vit.load_state_dict(state_dict, strict=strict, assign=assign)
34 |
--------------------------------------------------------------------------------
/models/dit/readme.md:
--------------------------------------------------------------------------------
1 | The DiT architecture proposed in [Scalable Diffusion Models with Transformers](http://arxiv.org/abs/2212.09748).
2 |
3 | Codes are copied and adapted from https://github.com/facebookresearch/DiT.
4 |
--------------------------------------------------------------------------------
/models/ema.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class EMA:
8 | """Exponential moving average of model parameters.
9 |
10 | References:
11 | - https://github.com/huggingface/diffusers/blob/main/src/diffusers/training_utils.py#L76
12 | - https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
13 | - https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/ema.py#L5
14 | - https://github.com/lucidrains/ema-pytorch
15 |
16 | """
17 | def __init__(
18 | self,
19 | parameters: Iterable[nn.Parameter],
20 | decay: float = 0.9999,
21 | gradual: bool = True,
22 | ):
23 | """
24 | Args:
25 | parameters: Iterable of parameters, typically from `model.parameters()`.
26 | decay: The decay factor for exponential moving average.
27 | gradual: Whether to use a gradually increasing decay factor.
28 |
29 | """
30 | super().__init__()
31 | self.decay = decay
32 | self.gradual = gradual
33 |
34 | self.num_updates = 0
35 | self.shadow = [param.detach().clone() for param in parameters]
36 | self.backup = []
37 |
38 | def get_decay(self):
39 | if self.gradual:
40 | return min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
41 | else:
42 | return self.decay
43 |
44 | @torch.no_grad()
45 | def update(self, parameters: Iterable[nn.Parameter]):
46 | self.num_updates += 1
47 | decay = self.get_decay()
48 | for s_param, param in zip(self.shadow, parameters):
49 | if param.requires_grad:
50 | s_param.sub_((1. - decay) * (s_param - param))
51 | else:
52 | s_param.copy_(param)
53 |
54 | def apply_shadow(self, parameters: Iterable[nn.Parameter]):
55 | assert len(self.backup) == 0, 'backup is not empty'
56 | for s_param, param in zip(self.shadow, parameters):
57 | self.backup.append(param.detach().cpu().clone())
58 | param.data.copy_(s_param.data)
59 |
60 | def restore(self, parameters: Iterable[nn.Parameter]):
61 | assert len(self.backup) > 0, 'backup is empty'
62 | for b_param, param in zip(self.backup, parameters):
63 | param.data.copy_(b_param.to(param.device).data)
64 | self.backup = []
65 |
66 | def state_dict(self):
67 | return dict(
68 | decay=self.decay,
69 | shadow=self.shadow,
70 | num_updates=self.num_updates,
71 | )
72 |
73 | def load_state_dict(self, state_dict):
74 | self.decay = state_dict['decay']
75 | self.shadow = state_dict['shadow']
76 | self.num_updates = state_dict['num_updates']
77 |
78 | def to(self, device):
79 | self.shadow = [s_param.to(device) for s_param in self.shadow]
80 |
81 |
82 | def _test():
83 | # initialize to 0
84 | model = nn.Sequential(nn.Linear(5, 1))
85 | for p in model[0].parameters():
86 | p.data.fill_(0)
87 | ema = EMA(model.parameters(), decay=0.9, gradual=False)
88 | print(model.state_dict()) # 0
89 | print(ema.state_dict()['shadow']) # 0
90 | print()
91 |
92 | # update the model to 1
93 | for p in model[0].parameters():
94 | p.data.fill_(1)
95 | ema.update(model.parameters())
96 | print(model.state_dict()) # 1
97 | print(ema.state_dict()['shadow']) # 0.9 * 0 + 0.1 * 1 = 0.1
98 | print()
99 |
100 | # update the model to 2
101 | for p in model[0].parameters():
102 | p.data.fill_(2)
103 | ema.update(model.parameters())
104 | print(model.state_dict()) # 2
105 | print(ema.state_dict()['shadow']) # 0.9 * 0.1 + 0.1 * 2 = 0.29
106 | print()
107 |
108 | # apply shadow
109 | ema.apply_shadow(model.parameters())
110 | print(model.state_dict()) # 0.29
111 | print(ema.state_dict()['shadow']) # 0.29
112 | print()
113 |
114 | # restore
115 | ema.restore(model.parameters())
116 | print(model.state_dict()) # 2
117 | print(ema.state_dict()['shadow']) # 0.29
118 |
119 |
120 | if __name__ == '__main__':
121 | _test()
122 |
--------------------------------------------------------------------------------
/models/mdt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/mdt/__init__.py
--------------------------------------------------------------------------------
/models/mdt/autoencoder.py:
--------------------------------------------------------------------------------
1 | import diffusers
2 |
3 |
4 | def AutoEncoderKL(from_pretrained: str):
5 | return diffusers.AutoencoderKL.from_pretrained(from_pretrained)
6 |
--------------------------------------------------------------------------------
/models/mdt/mdt.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping, Any
2 | from omegaconf import OmegaConf
3 |
4 | from torch import Tensor
5 |
6 | from ..base_latent import BaseLatent
7 | from utils.misc import instantiate_from_config
8 |
9 |
10 | class MDT(BaseLatent):
11 | def __init__(
12 | self,
13 | vae_config: OmegaConf,
14 | vit_config: OmegaConf,
15 | scale_factor: float = 0.18215,
16 | ):
17 | super().__init__(scale_factor=scale_factor)
18 |
19 | self.vae = instantiate_from_config(vae_config)
20 | self.vit = instantiate_from_config(vit_config)
21 |
22 | def decode_latent(self, z: Tensor):
23 | z = 1. / self.scale_factor * z
24 | return self.vae.decode(z).sample
25 |
26 | def vit_forward(self, x: Tensor, t: Tensor, y: Tensor):
27 | return self.vit(x, t, y)
28 |
29 | def forward(self, x: Tensor, timesteps: Tensor, y: Tensor = None):
30 | return self.vit_forward(x, timesteps, y)
31 |
32 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
33 | self.vit.load_state_dict(state_dict, strict=strict, assign=assign)
34 |
--------------------------------------------------------------------------------
/models/mdt/readme.md:
--------------------------------------------------------------------------------
1 | The MDT architecture proposed in [Masked Diffusion Transformer is a Strong Image Synthesizer](https://arxiv.org/abs/2303.14389).
2 |
3 | Codes are copied and adapted from https://github.com/sail-sg/MDT.
4 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 |
8 | def init_weights(init_type=None, gain=0.02):
9 |
10 | def init_func(m):
11 | classname = m.__class__.__name__
12 |
13 | if classname.find('BatchNorm') != -1:
14 | if hasattr(m, 'weight') and m.weight is not None:
15 | nn.init.normal_(m.weight, 1.0, gain)
16 | if hasattr(m, 'bias') and m.bias is not None:
17 | nn.init.constant_(m.bias, 0.0)
18 |
19 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
20 | if init_type == 'normal':
21 | nn.init.normal_(m.weight, 0.0, gain)
22 | elif init_type == 'xavier':
23 | nn.init.xavier_normal_(m.weight, gain=gain)
24 | elif init_type == 'xavier_uniform':
25 | nn.init.xavier_uniform_(m.weight, gain=1.0)
26 | elif init_type == 'kaiming':
27 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
28 | elif init_type == 'orthogonal':
29 | nn.init.orthogonal_(m.weight, gain=gain)
30 | elif init_type is None:
31 | m.reset_parameters()
32 | else:
33 | raise ValueError(f'invalid initialization method: {init_type}.')
34 | if hasattr(m, 'bias') and m.bias is not None:
35 | nn.init.constant_(m.bias, 0.0)
36 |
37 | return init_func
38 |
39 |
40 | class SinusoidalPosEmb(nn.Module):
41 | def __init__(self, dim: int):
42 | super().__init__()
43 | self.dim = dim
44 |
45 | def forward(self, X: Tensor):
46 | """
47 | Args:
48 | X (Tensor): [bs]
49 | Returns:
50 | Sinusoidal embeddings of shape [bs, dim]
51 | """
52 | half_dim = self.dim // 2
53 | embed = math.log(10000) / (half_dim - 1)
54 | embed = torch.exp(torch.arange(half_dim, device=X.device) * -embed)
55 | embed = X[:, None] * embed[None, :]
56 | embed = torch.cat((embed.sin(), embed.cos()), dim=-1)
57 | return embed
58 |
59 |
60 | def Upsample(in_channels: int, out_channels: int, use_conv: bool = True):
61 | if use_conv:
62 | return nn.Sequential(
63 | nn.Upsample(scale_factor=2, mode='nearest'),
64 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
65 | )
66 | else:
67 | return nn.Upsample(scale_factor=2, mode='nearest')
68 |
69 |
70 | def Downsample(in_channels: int, out_channels: int, use_conv: bool = True):
71 | if use_conv:
72 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
73 | else:
74 | return nn.AvgPool2d(kernel_size=2, stride=2)
75 |
76 |
77 | class SelfAttentionBlock(nn.Module):
78 | def __init__(self, dim: int, n_heads: int = 1, groups: int = 32):
79 | super().__init__()
80 | assert dim % n_heads == 0
81 | self.n_heads = n_heads
82 | self.norm = nn.GroupNorm(groups, dim)
83 | self.q = nn.Conv2d(dim, dim, kernel_size=1)
84 | self.k = nn.Conv2d(dim, dim, kernel_size=1)
85 | self.v = nn.Conv2d(dim, dim, kernel_size=1)
86 | self.proj = nn.Conv2d(dim, dim, kernel_size=1)
87 | self.scale = (dim // n_heads) ** -0.5
88 |
89 | def forward(self, X: Tensor, return_attn_map: bool = False):
90 | bs, C, H, W = X.shape
91 | normX = self.norm(X)
92 | q = self.q(normX).view(bs * self.n_heads, -1, H*W)
93 | k = self.k(normX).view(bs * self.n_heads, -1, H*W)
94 | v = self.v(normX).view(bs * self.n_heads, -1, H*W)
95 | q = q * self.scale
96 | attn = torch.bmm(q.permute(0, 2, 1), k).softmax(dim=-1)
97 | output = torch.bmm(v, attn.permute(0, 2, 1)).view(bs, -1, H, W)
98 | output = self.proj(output)
99 | if not return_attn_map:
100 | return output + X
101 | else:
102 | return output + X, attn.view(bs, self.n_heads, H*W, H*W)
103 |
104 |
105 | class AdaGN(nn.Module):
106 | def __init__(self, num_groups: int, num_channels: int, embed_dim: int):
107 | super().__init__()
108 | self.gn = nn.GroupNorm(num_groups, num_channels)
109 | self.proj = nn.Sequential(
110 | nn.SiLU(),
111 | nn.Linear(embed_dim, num_channels * 2),
112 | )
113 |
114 | def forward(self, X: Tensor, embed: Tensor):
115 | """
116 | Args:
117 | X (Tensor): [bs, C, H, W]
118 | embed (Tensor): [bs, embed_dim]
119 | """
120 | ys, yb = torch.chunk(self.proj(embed), 2, dim=-1)
121 | ys = ys[:, :, None, None]
122 | yb = yb[:, :, None, None]
123 | return self.gn(X) * (1 + ys) + yb
124 |
--------------------------------------------------------------------------------
/models/pesser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/pesser/__init__.py
--------------------------------------------------------------------------------
/models/pesser/readme.md:
--------------------------------------------------------------------------------
1 | The DDPM UNet architecture proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239).
2 |
3 | The original codes are in TensorFlow, so we use the PyTorch version developed by [pesser](https://github.com/pesser).
4 |
5 | Codes are copied and adapted from https://github.com/pesser/pytorch_diffusion.
6 |
--------------------------------------------------------------------------------
/models/sdxl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/sdxl/__init__.py
--------------------------------------------------------------------------------
/models/sdxl/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.sum(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=(1, 2, 3)):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68 | dim=dims,
69 | )
70 |
71 | def mode(self):
72 | return self.mean
73 |
74 |
75 | def normal_kl(mean1, logvar1, mean2, logvar2):
76 | """
77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78 | Compute the KL divergence between two gaussians.
79 | Shapes are automatically broadcasted, so batches can be compared to
80 | scalars, among other use cases.
81 | """
82 | tensor = None
83 | for obj in (mean1, logvar1, mean2, logvar2):
84 | if isinstance(obj, torch.Tensor):
85 | tensor = obj
86 | break
87 | assert tensor is not None, "at least one argument must be a Tensor"
88 |
89 | # Force variances to be Tensors. Broadcasting helps convert scalars to
90 | # Tensors, but it does not work for torch.exp().
91 | logvar1, logvar2 = [
92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93 | for x in (logvar1, logvar2)
94 | ]
95 |
96 | return 0.5 * (
97 | -1.0
98 | + logvar2
99 | - logvar1
100 | + torch.exp(logvar1 - logvar2)
101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102 | )
103 |
--------------------------------------------------------------------------------
/models/sdxl/readme.md:
--------------------------------------------------------------------------------
1 | The StableDiffusion XL model architectures.
2 |
3 | Codes are copied and adapted from https://github.com/Stability-AI/generative-models. Files are reorganized. Un-used functions and classes are removed. Some PEP8 warnings are resolved.
4 |
--------------------------------------------------------------------------------
/models/sdxl/regularizers.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from .distributions import DiagonalGaussianDistribution
9 |
10 |
11 | class AbstractRegularizer(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
16 | raise NotImplementedError()
17 |
18 | @abstractmethod
19 | def get_trainable_parameters(self) -> Any:
20 | raise NotImplementedError()
21 |
22 |
23 | class IdentityRegularizer(AbstractRegularizer):
24 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
25 | return z, dict()
26 |
27 | def get_trainable_parameters(self) -> Any:
28 | yield from ()
29 |
30 |
31 | def measure_perplexity(
32 | predicted_indices: torch.Tensor, num_centroids: int
33 | ) -> Tuple[torch.Tensor, torch.Tensor]:
34 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
35 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
36 | encodings = (
37 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
38 | )
39 | avg_probs = encodings.mean(0)
40 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
41 | cluster_use = torch.sum(avg_probs > 0)
42 | return perplexity, cluster_use
43 |
44 |
45 | class DiagonalGaussianRegularizer(AbstractRegularizer):
46 | def __init__(self, sample: bool = True):
47 | super().__init__()
48 | self.sample = sample
49 |
50 | def get_trainable_parameters(self) -> Any:
51 | yield from ()
52 |
53 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
54 | log = dict()
55 | posterior = DiagonalGaussianDistribution(z)
56 | if self.sample:
57 | z = posterior.sample()
58 | else:
59 | z = posterior.mode()
60 | kl_loss = posterior.kl()
61 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
62 | log["kl_loss"] = kl_loss
63 | return z, log
64 |
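A short illustrative sketch of DiagonalGaussianRegularizer; the latent shape is an assumption (in practice z would come from a VAE encoder that emits 2*C channels):

import torch

from models.sdxl.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)

z = torch.randn(4, 8, 32, 32)       # mean and log-variance stacked along dim=1
z, log = reg(z)
print(z.shape)                      # torch.Size([4, 4, 32, 32])
print(log['kl_loss'])               # scalar KL penalty, averaged over the batch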
--------------------------------------------------------------------------------
/models/sdxl/stablediffusion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Mapping, Any, Dict
2 | from omegaconf import OmegaConf
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from ..base_latent import BaseLatent
8 | from utils.misc import instantiate_from_config
9 |
10 |
11 | class StableDiffusion(BaseLatent):
12 | def __init__(
13 | self,
14 | conditioner_config: OmegaConf,
15 | vae_config: OmegaConf,
16 | unet_config: OmegaConf,
17 | scale_factor: float = 0.13025,
18 | low_vram_shift_enabled: bool = False,
19 | ):
20 | super().__init__(scale_factor=scale_factor)
21 |
22 | self.conditioner = instantiate_from_config(conditioner_config)
23 | self.vae = instantiate_from_config(vae_config)
24 | self.unet = instantiate_from_config(unet_config)
25 |
26 | self.low_vram_shift_enabled = low_vram_shift_enabled
27 |
28 | def encode_latent(self, x: Tensor):
29 | if self.low_vram_shift_enabled:
30 | self.conditioner.to('cpu')
31 | self.unet.to('cpu')
32 | self.vae.to(self.device)
33 | torch.cuda.empty_cache()
34 | z = self.vae.encode(x)
35 | return self.scale_factor * z
36 |
37 | def decode_latent(self, z: Tensor):
38 | if self.low_vram_shift_enabled:
39 | self.conditioner.to('cpu')
40 | self.unet.to('cpu')
41 | self.vae.to(self.device)
42 | torch.cuda.empty_cache()
43 | z = 1. / self.scale_factor * z
44 | return self.vae.decode(z)
45 |
46 | def conditioner_forward(self, text: List[str], H: int, W: int):
47 | if self.low_vram_shift_enabled:
48 | self.vae.to('cpu')
49 | self.unet.to('cpu')
50 | self.conditioner.to(self.device)
51 | torch.cuda.empty_cache()
52 | batch = dict(
53 | txt=text,
54 | original_size_as_tuple=torch.tensor([1024, 1024], device=self.device).repeat(len(text), 1),
55 | crop_coords_top_left=torch.tensor([0, 0], device=self.device).repeat(len(text), 1),
56 | target_size_as_tuple=torch.tensor([H, W], device=self.device).repeat(len(text), 1),
57 | )
58 | return self.conditioner(batch)
59 |
60 | def unet_forward(self, x: Tensor, timesteps: Tensor, context: Tensor, y: Tensor):
61 | if self.low_vram_shift_enabled:
62 | self.vae.to('cpu')
63 | self.conditioner.to('cpu')
64 | self.unet.to(self.device)
65 | torch.cuda.empty_cache()
66 | return self.unet(x, timesteps=timesteps, context=context, y=y)
67 |
68 | def forward(
69 | self, x: Tensor, timesteps: Tensor, condition_dict: Dict = None,
70 | text: List[str] = None, H: int = None, W: int = None,
71 | ):
72 | if condition_dict is None:
73 | if text is None or H is None or W is None:
74 | raise ValueError('text, H and W must be provided when `condition_dict` is not provided.')
75 | condition_dict = self.conditioner_forward(text, H, W)
76 | context = condition_dict.get('crossattn')
77 | y = condition_dict.get('vector')
78 | x = self.unet_forward(x, timesteps=timesteps, context=context, y=y)
79 | return x
80 |
81 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
82 | # state_dict_conditioner = {k[12:]: v for k, v in state_dict.items() if k.startswith('conditioner.')}
83 | state_dict_vae = {k[18:]: v for k, v in state_dict.items() if k.startswith('first_stage_model.')}
84 | state_dict_unet = {k[22:]: v for k, v in state_dict.items() if k.startswith('model.diffusion_model.')}
85 | # self.conditioner.load_state_dict(state_dict_conditioner, strict=strict, assign=assign)
86 | self.vae.load_state_dict(state_dict_vae, strict=strict, assign=assign)
87 | self.unet.load_state_dict(state_dict_unet, strict=strict, assign=assign)
88 | # del state_dict_conditioner
89 | del state_dict_vae
90 | del state_dict_unet
91 |
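The load_state_dict override above strips checkpoint prefixes before delegating to the VAE and UNet; k[18:] and k[22:] are exactly the lengths of 'first_stage_model.' and 'model.diffusion_model.'. A self-contained sketch of that remapping, with made-up example keys:

import torch

state_dict = {
    'conditioner.embedders.0.transformer.dummy': torch.zeros(1),   # made-up key
    'first_stage_model.encoder.conv_in.weight': torch.zeros(1),    # made-up key
    'model.diffusion_model.time_embed.0.weight': torch.zeros(1),   # made-up key
}
state_dict_vae = {k[len('first_stage_model.'):]: v
                  for k, v in state_dict.items() if k.startswith('first_stage_model.')}
state_dict_unet = {k[len('model.diffusion_model.'):]: v
                   for k, v in state_dict.items() if k.startswith('model.diffusion_model.')}
print(list(state_dict_vae))    # ['encoder.conv_in.weight']
print(list(state_dict_unet))   # ['time_embed.0.weight']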
--------------------------------------------------------------------------------
/models/stablediffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyfJASON/diffusion-models-pytorch/3351aab1d5f0459a5af75c622932e071f9b6ee7a/models/stablediffusion/__init__.py
--------------------------------------------------------------------------------
/models/stablediffusion/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=(1, 2, 3)):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
/models/stablediffusion/modules.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 | import math
11 | import torch
12 | import torch.nn as nn
13 | from einops import repeat
14 |
15 |
16 | def checkpoint(func, inputs, params, flag):
17 | """
18 | Evaluate a function without caching intermediate activations, allowing for
19 | reduced memory at the expense of extra compute in the backward pass.
20 | :param func: the function to evaluate.
21 | :param inputs: the argument sequence to pass to `func`.
22 | :param params: a sequence of parameters `func` depends on but does not
23 | explicitly take as arguments.
24 | :param flag: if False, disable gradient checkpointing.
25 | """
26 | if flag:
27 | args = tuple(inputs) + tuple(params)
28 | return CheckpointFunction.apply(func, len(inputs), *args)
29 | else:
30 | return func(*inputs)
31 |
32 |
33 | class CheckpointFunction(torch.autograd.Function):
34 | @staticmethod
35 | def forward(ctx, run_function, length, *args):
36 | ctx.run_function = run_function
37 | ctx.input_tensors = list(args[:length])
38 | ctx.input_params = list(args[length:])
39 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
40 | "dtype": torch.get_autocast_gpu_dtype(),
41 | "cache_enabled": torch.is_autocast_cache_enabled()}
42 | with torch.no_grad():
43 | output_tensors = ctx.run_function(*ctx.input_tensors)
44 | return output_tensors
45 |
46 | @staticmethod
47 | def backward(ctx, *output_grads):
48 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
49 | with torch.enable_grad(), \
50 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
51 | # Fixes a bug where the first op in run_function modifies the
52 | # Tensor storage in place, which is not allowed for detach()'d
53 | # Tensors.
54 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
55 | output_tensors = ctx.run_function(*shallow_copies)
56 | input_grads = torch.autograd.grad(
57 | output_tensors,
58 | ctx.input_tensors + ctx.input_params,
59 | output_grads,
60 | allow_unused=True,
61 | )
62 | del ctx.input_tensors
63 | del ctx.input_params
64 | del output_tensors
65 | return (None, None) + input_grads
66 |
67 |
68 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
69 | """
70 | Create sinusoidal timestep embeddings.
71 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
72 | These may be fractional.
73 | :param dim: the dimension of the output.
74 | :param max_period: controls the minimum frequency of the embeddings.
75 | :return: an [N x dim] Tensor of positional embeddings.
76 | """
77 | if not repeat_only:
78 | half = dim // 2
79 | freqs = torch.exp(
80 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
81 | ).to(device=timesteps.device)
82 | args = timesteps[:, None].float() * freqs[None]
83 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
84 | if dim % 2:
85 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
86 | else:
87 | embedding = repeat(timesteps, 'b -> b d', d=dim)
88 | return embedding
89 |
90 |
91 | def zero_module(module):
92 | """
93 | Zero out the parameters of a module and return it.
94 | """
95 | for p in module.parameters():
96 | p.detach().zero_()
97 | return module
98 |
99 |
100 | def scale_module(module, scale):
101 | """
102 | Scale the parameters of a module and return it.
103 | """
104 | for p in module.parameters():
105 | p.detach().mul_(scale)
106 | return module
107 |
108 |
109 | def mean_flat(tensor):
110 | """
111 | Take the mean over all non-batch dimensions.
112 | """
113 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
114 |
115 |
116 | def normalization(channels):
117 | """
118 | Make a standard normalization layer.
119 | :param channels: number of input channels.
120 | :return: an nn.Module for normalization.
121 | """
122 | return GroupNorm32(32, channels)
123 |
124 |
125 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
126 | class SiLU(nn.Module):
127 | def forward(self, x):
128 | return x * torch.sigmoid(x)
129 |
130 |
131 | class GroupNorm32(nn.GroupNorm):
132 | def forward(self, x):
133 | return super().forward(x.float()).type(x.dtype)
134 |
135 |
136 | def conv_nd(dims, *args, **kwargs):
137 | """
138 | Create a 1D, 2D, or 3D convolution module.
139 | """
140 | if dims == 1:
141 | return nn.Conv1d(*args, **kwargs)
142 | elif dims == 2:
143 | return nn.Conv2d(*args, **kwargs)
144 | elif dims == 3:
145 | return nn.Conv3d(*args, **kwargs)
146 | raise ValueError(f"unsupported dimensions: {dims}")
147 |
148 |
149 | def linear(*args, **kwargs):
150 | """
151 | Create a linear module.
152 | """
153 | return nn.Linear(*args, **kwargs)
154 |
155 |
156 | def avg_pool_nd(dims, *args, **kwargs):
157 | """
158 | Create a 1D, 2D, or 3D average pooling module.
159 | """
160 | if dims == 1:
161 | return nn.AvgPool1d(*args, **kwargs)
162 | elif dims == 2:
163 | return nn.AvgPool2d(*args, **kwargs)
164 | elif dims == 3:
165 | return nn.AvgPool3d(*args, **kwargs)
166 | raise ValueError(f"unsupported dimensions: {dims}")
167 |
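A small usage sketch of the helpers above; the embedding dimension and convolution sizes are arbitrary choices for illustration:

import torch

from models.stablediffusion.modules import timestep_embedding, zero_module, conv_nd

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)        # [3, 320] sinusoidal embeddings
print(emb.shape)

# zero_module is typically applied to the last convolution of a residual branch,
# so a freshly initialized block starts out as (approximately) an identity mapping.
out_conv = zero_module(conv_nd(2, 320, 320, 3, padding=1))
print(out_conv.weight.abs().sum().item())   # 0.0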
--------------------------------------------------------------------------------
/models/stablediffusion/readme.md:
--------------------------------------------------------------------------------
1 | The StableDiffusion model architectures.
2 |
3 | The code is copied and adapted from https://github.com/Stability-AI/stablediffusion. Files have been reorganized, unused functions and classes removed, and some PEP8 warnings resolved.
4 |
--------------------------------------------------------------------------------
/models/stablediffusion/stablediffusion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Mapping, Any
2 | from omegaconf import OmegaConf
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from .distributions import DiagonalGaussianDistribution
8 | from ..base_latent import BaseLatent
9 | from utils.misc import instantiate_from_config
10 |
11 |
12 | class StableDiffusion(BaseLatent):
13 | def __init__(
14 | self,
15 | text_encoder_config: OmegaConf,
16 | vae_config: OmegaConf,
17 | unet_config: OmegaConf,
18 | scale_factor: float = 0.18215,
19 | low_vram_shift_enabled: bool = False,
20 | ):
21 | super().__init__(scale_factor=scale_factor)
22 |
23 | self.text_encoder = instantiate_from_config(text_encoder_config)
24 | self.vae = instantiate_from_config(vae_config)
25 | self.unet = instantiate_from_config(unet_config)
26 |
27 | self.low_vram_shift_enabled = low_vram_shift_enabled
28 |
29 | def encode_latent(self, x: Tensor):
30 | if self.low_vram_shift_enabled:
31 | self.text_encoder.to('cpu')
32 | self.unet.to('cpu')
33 | self.vae.to(self.device)
34 | torch.cuda.empty_cache()
35 | z = self.vae.encode(x)
36 | if isinstance(z, DiagonalGaussianDistribution):
37 | z = z.sample()
38 | return self.scale_factor * z
39 |
40 | def decode_latent(self, z: Tensor):
41 | if self.low_vram_shift_enabled:
42 | self.text_encoder.to('cpu')
43 | self.unet.to('cpu')
44 | self.vae.to(self.device)
45 | torch.cuda.empty_cache()
46 | z = 1. / self.scale_factor * z
47 | return self.vae.decode(z)
48 |
49 | def text_encoder_encode(self, text: List[str]):
50 | if self.low_vram_shift_enabled:
51 | self.vae.to('cpu')
52 | self.unet.to('cpu')
53 | self.text_encoder.to(self.device)
54 | torch.cuda.empty_cache()
55 | return self.text_encoder.encode(text)
56 |
57 | def unet_forward(self, x: Tensor, timesteps: Tensor, context: Tensor):
58 | if self.low_vram_shift_enabled:
59 | self.vae.to('cpu')
60 | self.text_encoder.to('cpu')
61 | self.unet.to(self.device)
62 | torch.cuda.empty_cache()
63 | return self.unet(x, timesteps=timesteps, context=context)
64 |
65 | def forward(self, x: Tensor, timesteps: Tensor, text_embed: Tensor = None, text: List[str] = None):
66 | if text_embed is None and text is None:
67 | raise ValueError('Either `text_embed` or `text` must be provided.')
68 | if text_embed is None:
69 | text_embed = self.text_encoder_encode(text)
70 | x = self.unet_forward(x, timesteps=timesteps, context=text_embed)
71 | return x
72 |
73 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
74 | state_dict_vae = {k[18:]: v for k, v in state_dict.items() if k.startswith('first_stage_model.')}
75 | state_dict_unet = {k[22:]: v for k, v in state_dict.items() if k.startswith('model.diffusion_model.')}
76 | self.vae.load_state_dict(state_dict_vae, strict=strict)
77 | self.unet.load_state_dict(state_dict_unet, strict=strict)
78 | del state_dict_vae
79 | del state_dict_unet
80 |
--------------------------------------------------------------------------------
/models/stablediffusion/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from inspect import isfunction
3 |
4 | import torch
5 |
6 |
7 | def ismap(x):
8 | if not isinstance(x, torch.Tensor):
9 | return False
10 | return (len(x.shape) == 4) and (x.shape[1] > 3)
11 |
12 |
13 | def isimage(x):
14 | if not isinstance(x, torch.Tensor):
15 | return False
16 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
17 |
18 |
19 | def exists(x):
20 | return x is not None
21 |
22 |
23 | def default(val, d):
24 | if exists(val):
25 | return val
26 | return d() if isfunction(d) else d
27 |
28 |
29 | def count_params(model, verbose=False):
30 | total_params = sum(p.numel() for p in model.parameters())
31 | if verbose:
32 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
33 | return total_params
34 |
35 |
36 | def instantiate_from_config(config):
37 | if "target" not in config:
38 | if config == '__is_first_stage__':
39 | return None
40 | elif config == "__is_unconditional__":
41 | return None
42 | raise KeyError("Expected key `target` to instantiate.")
43 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
44 |
45 |
46 | def get_obj_from_str(string, reload=False):
47 | module, cls = string.rsplit(".", 1)
48 | if reload:
49 | module_imp = importlib.import_module(module)
50 | importlib.reload(module_imp)
51 | return getattr(importlib.import_module(module, package=None), cls)
52 |
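An illustrative sketch of instantiate_from_config, which builds an object from a dict (or OmegaConf node) with a dotted `target` path and optional `params`; the GroupNorm example is arbitrary:

from models.stablediffusion.util import instantiate_from_config

config = {
    'target': 'torch.nn.GroupNorm',
    'params': {'num_groups': 32, 'num_channels': 64},
}
norm = instantiate_from_config(config)
print(norm)   # GroupNorm(32, 64, eps=1e-05, affine=True)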
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 | from models.modules import SinusoidalPosEmb, SelfAttentionBlock, Downsample, Upsample
8 |
9 |
10 | class ResBlock(nn.Module):
11 | def __init__(self, in_channels: int, out_channels: int, embed_dim: int, dropout: float = 0.1):
12 | super().__init__()
13 | self.blk1 = nn.Sequential(
14 | nn.GroupNorm(32, in_channels),
15 | nn.SiLU(),
16 | nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
17 | )
18 | self.proj = nn.Sequential(
19 | nn.SiLU(),
20 | nn.Linear(embed_dim, out_channels),
21 | )
22 | self.blk2 = nn.Sequential(
23 | nn.GroupNorm(32, out_channels),
24 | nn.SiLU(),
25 | nn.Dropout(dropout),
26 | nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
27 | )
28 | self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
29 |
30 | def forward(self, X: Tensor, time_embed: Tensor = None):
31 | """
32 | Args:
33 | X (Tensor): [bs, C, H, W]
34 | time_embed (Tensor): [bs, embed_dim]
35 | Returns:
36 | [bs, C', H, W]
37 | """
38 | shortcut = self.shortcut(X)
39 | X = self.blk1(X)
40 | if time_embed is not None:
41 | X = X + self.proj(time_embed)[:, :, None, None]
42 | X = self.blk2(X)
43 | return X + shortcut
44 |
45 |
46 | class UNet(nn.Module):
47 | def __init__(
48 | self,
49 | in_channels: int = 3,
50 | out_channels: int = 3,
51 | dim: int = 128,
52 | dim_mults: List[int] = (1, 2, 2, 2),
53 |             use_attn: List[bool] = (False, True, False, False),
54 | num_res_blocks: int = 2,
55 | n_heads: int = 1,
56 | dropout: float = 0.1,
57 | ):
58 | super().__init__()
59 | n_stages = len(dim_mults)
60 | dims = [dim]
61 |
62 | # Time embeddings
63 | time_embed_dim = dim * 4
64 | self.time_embed = nn.Sequential(
65 | SinusoidalPosEmb(dim),
66 | nn.Linear(dim, time_embed_dim),
67 | nn.SiLU(),
68 | nn.Linear(time_embed_dim, time_embed_dim),
69 | )
70 |
71 | # First convolution
72 | self.first_conv = nn.Conv2d(in_channels, dim, 3, stride=1, padding=1)
73 | cur_dim = dim
74 |
75 | # Down-sample blocks
76 | # Default: 32x32 -> 16x16 -> 8x8 -> 4x4
77 | self.down_blocks = nn.ModuleList([])
78 | for i in range(n_stages):
79 | out_dim = dim * dim_mults[i]
80 | stage_blocks = nn.ModuleList([])
81 | for j in range(num_res_blocks):
82 | stage_blocks.append(ResBlock(cur_dim, out_dim, embed_dim=time_embed_dim, dropout=dropout))
83 | if use_attn[i]:
84 | stage_blocks.append(SelfAttentionBlock(out_dim, n_heads=n_heads))
85 | dims.append(out_dim)
86 | cur_dim = out_dim
87 | if i < n_stages - 1:
88 | stage_blocks.append(Downsample(out_dim, out_dim))
89 | dims.append(out_dim)
90 | self.down_blocks.append(stage_blocks)
91 |
92 | # Bottleneck block
93 | self.bottleneck_block = nn.ModuleList([
94 | ResBlock(cur_dim, cur_dim, embed_dim=time_embed_dim, dropout=dropout),
95 | SelfAttentionBlock(cur_dim),
96 | ResBlock(cur_dim, cur_dim, embed_dim=time_embed_dim, dropout=dropout),
97 | ])
98 |
99 | # Up-sample blocks
100 | # Default: 4x4 -> 8x8 -> 16x16 -> 32x32
101 | self.up_blocks = nn.ModuleList([])
102 | for i in range(n_stages-1, -1, -1):
103 | out_dim = dim * dim_mults[i]
104 | stage_blocks = nn.ModuleList([])
105 | for j in range(num_res_blocks + 1):
106 | stage_blocks.append(ResBlock(dims.pop() + cur_dim, out_dim, embed_dim=time_embed_dim, dropout=dropout))
107 | if use_attn[i]:
108 | stage_blocks.append(SelfAttentionBlock(out_dim, n_heads=n_heads))
109 | cur_dim = out_dim
110 | if i > 0:
111 | stage_blocks.append(Upsample(out_dim, out_dim))
112 | self.up_blocks.append(stage_blocks)
113 |
114 | # Last convolution
115 | self.last_conv = nn.Sequential(
116 | nn.GroupNorm(32, cur_dim),
117 | nn.SiLU(),
118 | nn.Conv2d(cur_dim, out_channels, 3, stride=1, padding=1),
119 | )
120 |
121 | def forward(self, X: Tensor, T: Tensor):
122 | time_embed = self.time_embed(T)
123 | X = self.first_conv(X)
124 | skips = [X]
125 |
126 | for stage_blocks in self.down_blocks:
127 | for blk in stage_blocks: # noqa
128 | if isinstance(blk, ResBlock):
129 | X = blk(X, time_embed)
130 | skips.append(X)
131 | elif isinstance(blk, SelfAttentionBlock):
132 | X = blk(X)
133 | skips[-1] = X
134 | else: # Downsample
135 | X = blk(X)
136 | skips.append(X)
137 |
138 | X = self.bottleneck_block[0](X, time_embed)
139 | X = self.bottleneck_block[1](X)
140 | X = self.bottleneck_block[2](X, time_embed)
141 |
142 | for stage_blocks in self.up_blocks:
143 | for blk in stage_blocks: # noqa
144 | if isinstance(blk, ResBlock):
145 | X = blk(torch.cat((X, skips.pop()), dim=1), time_embed)
146 | elif isinstance(blk, SelfAttentionBlock):
147 | X = blk(X)
148 | else: # Upsample
149 | X = blk(X)
150 |
151 | X = self.last_conv(X)
152 | return X
153 |
154 |
155 | def _test():
156 | unet = UNet()
157 | X = torch.empty((10, 3, 32, 32))
158 | T = torch.arange(10)
159 | out = unet(X, T)
160 | print(out.shape)
161 | print(sum(p.numel() for p in unet.parameters()))
162 |
163 | unet = UNet(
164 | in_channels=1,
165 | out_channels=1,
166 | dim=128,
167 | dim_mults=[1, 1, 2, 2, 4, 4],
168 | use_attn=[False, False, False, False, True, False],
169 | dropout=0.0,
170 | )
171 | X = torch.empty((10, 1, 256, 256))
172 | T = torch.arange(10)
173 | out = unet(X, T)
174 | print(out.shape)
175 | print(sum(p.numel() for p in unet.parameters()))
176 |
177 |
178 | if __name__ == '__main__':
179 | _test()
180 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | einops
3 | omegaconf
4 | matplotlib
5 | tensorboard
6 | numpy
7 | pandas
8 | Pillow
9 | torch~=2.1.0
10 | torchvision~=0.16.0
11 | accelerate~=0.23.0
12 | transformers~=4.37.2
13 | diffusers~=0.21.4
14 | safetensors~=0.4.2
15 | open-clip-torch
16 | timm~=0.9.12
17 | streamlit~=1.31.0
18 |
--------------------------------------------------------------------------------
/scripts/sample_clip_guidance.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import math
6 | import argparse
7 | from omegaconf import OmegaConf
8 |
9 | import torch
10 | import accelerate
11 | from torchvision.utils import save_image
12 |
13 | import diffusions
14 | from utils.logger import get_logger
15 | from utils.load import load_weights
16 | from utils.misc import image_norm_to_float, instantiate_from_config, amortize
17 |
18 |
19 | def get_parser():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | '-c', '--config', type=str, required=True,
23 | help='Path to inference configuration file',
24 | )
25 | parser.add_argument(
26 | '--seed', type=int, default=2022,
27 | help='Set random seed',
28 | )
29 | parser.add_argument(
30 | '--weights', type=str, required=True,
31 | help='Path to pretrained model weights',
32 | )
33 | parser.add_argument(
34 | '--var_type', type=str, default=None,
35 | help='Type of variance of the reverse process',
36 | )
37 | parser.add_argument(
38 | '--respace_type', type=str, default='uniform',
39 | help='Type of respaced timestep sequence',
40 | )
41 | parser.add_argument(
42 | '--respace_steps', type=int, default=None,
43 | help='Length of respaced timestep sequence',
44 | )
45 | parser.add_argument(
46 | '--text', type=str, required=True,
47 | help='Text description of the generated image',
48 | )
49 | parser.add_argument(
50 | '--guidance_weight', type=float, default=100.,
51 | help='Weight of CLIP guidance',
52 | )
53 | parser.add_argument(
54 | '--clip_model', type=str, default='openai/clip-vit-large-patch14',
55 | help='Name of CLIP model',
56 | )
57 | parser.add_argument(
58 | '--n_samples', type=int, required=True,
59 | help='Number of samples',
60 | )
61 | parser.add_argument(
62 | '--save_dir', type=str, required=True,
63 | help='Path to directory saving samples',
64 | )
65 | parser.add_argument(
66 | '--batch_size', type=int, default=500,
67 | help='Batch size on each process',
68 | )
69 | return parser
70 |
71 |
72 | def main():
73 | # PARSE ARGS AND CONFIGS
74 | args, unknown_args = get_parser().parse_known_args()
75 | unknown_args = [(a[2:] if a.startswith('--') else a) for a in unknown_args]
76 | unknown_args = [f'{k}={v}' for k, v in zip(unknown_args[::2], unknown_args[1::2])]
77 | conf = OmegaConf.load(args.config)
78 | conf = OmegaConf.merge(conf, OmegaConf.from_dotlist(unknown_args))
79 |
80 | # INITIALIZE ACCELERATOR
81 | accelerator = accelerate.Accelerator()
82 | device = accelerator.device
83 | print(f'Process {accelerator.process_index} using device: {device}')
84 | accelerator.wait_for_everyone()
85 |
86 | # INITIALIZE LOGGER
87 | logger = get_logger(
88 | use_tqdm_handler=True,
89 | is_main_process=accelerator.is_main_process,
90 | )
91 |
92 | # SET SEED
93 | accelerate.utils.set_seed(args.seed, device_specific=True)
94 | logger.info('=' * 19 + ' System Info ' + '=' * 18)
95 | logger.info(f'Number of processes: {accelerator.num_processes}')
96 | logger.info(f'Distributed type: {accelerator.distributed_type}')
97 | logger.info(f'Mixed precision: {accelerator.mixed_precision}')
98 |
99 | accelerator.wait_for_everyone()
100 |
101 | # BUILD DIFFUSER
102 | diffusion_params = OmegaConf.to_container(conf.diffusion.params)
103 | diffusion_params.update({
104 | 'var_type': args.var_type or diffusion_params.get('var_type', None),
105 | 'respace_type': None if args.respace_steps is None else args.respace_type,
106 | 'respace_steps': args.respace_steps,
107 | 'device': device,
108 | 'guidance_weight': args.guidance_weight,
109 | 'clip_pretrained': args.clip_model,
110 | })
111 | diffuser = diffusions.CLIPGuidance(**diffusion_params)
112 | logger.info('=' * 19 + ' Model Info ' + '=' * 19)
113 | logger.info(f'Using CLIP model: `{args.clip_model}`')
114 |
115 | # BUILD MODEL
116 | model = instantiate_from_config(conf.model)
117 |
118 | # LOAD WEIGHTS
119 | weights = load_weights(args.weights)
120 | model.load_state_dict(weights)
121 |     logger.info(f'Successfully loaded model from {args.weights}')
122 | logger.info('=' * 50)
123 |
124 | # PREPARE FOR DISTRIBUTED MODE AND MIXED PRECISION
125 | model = accelerator.prepare(model)
126 | model.eval()
127 |
128 | accelerator.wait_for_everyone()
129 |
130 | @torch.no_grad()
131 | def sample():
132 | idx = 0
133 | img_shape = (conf.data.img_channels, conf.data.params.img_size, conf.data.params.img_size)
134 | bspp = min(args.batch_size, math.ceil(args.n_samples / accelerator.num_processes))
135 | folds = amortize(args.n_samples, bspp * accelerator.num_processes)
136 | for i, bs in enumerate(folds):
137 | init_noise = torch.randn((bspp, *img_shape), device=device)
138 | samples = diffuser.sample(
139 | model=accelerator.unwrap_model(model), init_noise=init_noise,
140 | tqdm_kwargs=dict(desc=f'Fold {i}/{len(folds)}', disable=not accelerator.is_main_process)
141 | ).clamp(-1, 1)
142 | samples = accelerator.gather(samples)[:bs]
143 | if accelerator.is_main_process:
144 | for x in samples:
145 | x = image_norm_to_float(x).cpu()
146 | save_image(x, os.path.join(args.save_dir, f'{idx}.png'), nrow=1)
147 | idx += 1
148 | with open(os.path.join(args.save_dir, 'description.txt'), 'w') as f:
149 | f.write(args.text)
150 |
151 | # START SAMPLING
152 | logger.info('Start sampling...')
153 | logger.info(f'The input text description is: {args.text}')
154 | logger.info(f'Guidance weight: {args.guidance_weight}')
155 | diffuser.set_text(args.text)
156 | os.makedirs(args.save_dir, exist_ok=True)
157 | logger.info(f'Samples will be saved to {args.save_dir}')
158 | sample()
159 | logger.info(f'Sampled images are saved to {args.save_dir}')
160 | logger.info('End of sampling')
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
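A rough sketch of the batching arithmetic used in sample() above, with a hypothetical stand-in for utils.misc.amortize (assumed to split a total count into chunks of at most a given size; the real helper is not reproduced here):

import math

def amortize(n, size):   # hypothetical stand-in, for illustration only
    return [size] * (n // size) + ([n % size] if n % size else [])

n_samples, batch_size, num_processes = 1000, 128, 4
bspp = min(batch_size, math.ceil(n_samples / num_processes))   # samples per process per fold
folds = amortize(n_samples, bspp * num_processes)              # total samples per fold
print(bspp, folds)   # 128 [512, 488]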
--------------------------------------------------------------------------------
/streamlit/Hello.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | st.set_page_config(page_title="Diffusion", layout="wide")
4 |
5 | st.markdown(
6 | """
7 |