├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── poster.pdf
│   └── slides.pdf
├── ddim
│   ├── README.md
│   ├── configs
│   │   ├── bedroom.yml
│   │   ├── celeba.yml
│   │   ├── church.yml
│   │   └── cifar10.yml
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── celeba.py
│   │   ├── ffhq.py
│   │   ├── lsun.py
│   │   ├── utils.py
│   │   └── vision.py
│   ├── functions
│   │   ├── __init__.py
│   │   ├── ckpt_util.py
│   │   ├── denoising.py
│   │   └── losses.py
│   ├── main.py
│   ├── models
│   │   ├── diffusion.py
│   │   └── ema.py
│   └── runners
│       ├── __init__.py
│       └── diffusion.py
├── img
│   ├── ldm.png
│   ├── overview.png
│   └── sd.png
├── latent_imagenet_diffusion.py
├── linklink
│   ├── __init__.py
│   ├── dist_helper.py
│   └── log_helper.py
├── quant
│   ├── __init__.py
│   ├── adaptive_rounding.py
│   ├── calibration.py
│   ├── data_generate.py
│   ├── data_utill.py
│   ├── quant_block.py
│   ├── quant_layer.py
│   ├── quant_model.py
│   ├── reconstruction.py
│   └── reconstruction_util.py
├── requirements.txt
├── sample_diffusion_ddim.py
├── sample_diffusion_ldm.py
├── stable-diffusion
│   ├── README.md
│   ├── Stable_Diffusion_v1_Model_Card.md
│   ├── assets
│   │   ├── a-painting-of-a-fire.png
│   │   ├── a-photograph-of-a-fire.png
│   │   ├── a-shirt-with-a-fire-printed-on-it.png
│   │   ├── a-shirt-with-the-inscription-'fire'.png
│   │   ├── a-watercolor-painting-of-a-fire.png
│   │   ├── birdhouse.png
│   │   ├── fire.png
│   │   ├── inpainting.png
│   │   ├── modelfigure.png
│   │   ├── rdm-preview.jpg
│   │   ├── reconstruction1.png
│   │   ├── reconstruction2.png
│   │   ├── results.gif
│   │   ├── rick.jpeg
│   │   ├── stable-samples
│   │   │   ├── img2img
│   │   │   │   ├── mountains-1.png
│   │   │   │   ├── mountains-2.png
│   │   │   │   ├── mountains-3.png
│   │   │   │   ├── sketch-mountains-input.jpg
│   │   │   │   ├── upscaling-in.png
│   │   │   │   └── upscaling-out.png
│   │   │   └── txt2img
│   │   │       ├── 000002025.png
│   │   │       ├── 000002035.png
│   │   │       ├── merged-0005.png
│   │   │       ├── merged-0006.png
│   │   │       └── merged-0007.png
│   │   ├── the-earth-is-on-fire,-oil-on-canvas.png
│   │   ├── txt2img-convsample.png
│   │   ├── txt2img-preview.png
│   │   └── v1-variants-scores.jpg
│   ├── configs
│   │   ├── autoencoder
│   │   │   ├── autoencoder_kl_16x16x16.yaml
│   │   │   ├── autoencoder_kl_32x32x4.yaml
│   │   │   ├── autoencoder_kl_64x64x3.yaml
│   │   │   └── autoencoder_kl_8x8x64.yaml
│   │   ├── latent-diffusion
│   │   │   ├── celebahq-ldm-vq-4.yaml
│   │   │   ├── cin-ldm-vq-f8.yaml
│   │   │   ├── cin256-v2.yaml
│   │   │   ├── ffhq-ldm-vq-4.yaml
│   │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   │   └── txt2img-1p4B-eval.yaml
│   │   ├── retrieval-augmented-diffusion
│   │   │   └── 768x768.yaml
│   │   └── stable-diffusion
│   │       └── v1-inference.yaml
│   ├── data
│   │   ├── DejaVuSans.ttf
│   │   ├── example_conditioning
│   │   │   ├── superresolution
│   │   │   │   └── sample_0.jpg
│   │   │   └── text_conditional
│   │   │       └── sample_0.txt
│   │   ├── imagenet_clsidx_to_label.txt
│   │   ├── imagenet_train_hr_indices.p
│   │   ├── imagenet_val_hr_indices.p
│   │   ├── index_synset.yaml
│   │   └── inpainting_examples
│   │       ├── 6458524847_2f4c361183_k.png
│   │       ├── 6458524847_2f4c361183_k_mask.png
│   │       ├── 8399166846_f6fb4e4b8e_k.png
│   │       ├── 8399166846_f6fb4e4b8e_k_mask.png
│   │       ├── alex-iby-G_Pk4D9rMLs.png
│   │       ├── alex-iby-G_Pk4D9rMLs_mask.png
│   │       ├── bench2.png
│   │       ├── bench2_mask.png
│   │       ├── bertrand-gabioud-CpuFzIsHYJ0.png
│   │       ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png
│   │       ├── billow926-12-Wc-Zgx6Y.png
│   │       ├── billow926-12-Wc-Zgx6Y_mask.png
│   │       ├── overture-creations-5sI6fQgYIuo.png
│   │       ├── overture-creations-5sI6fQgYIuo_mask.png
│   │       ├── photo-1583445095369-9c651e7e5d34.png
│   │       └── photo-1583445095369-9c651e7e5d34_mask.png
│   ├── environment.yaml
│   ├── ldm
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── imagenet.py
│   │   │   └── lsun.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── __init__.py
│   │   │       ├── classifier.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       ├── dpm_solver
│   │   │       │   ├── __init__.py
│   │   │       │   ├── dpm_solver.py
│   │   │       │   └── sampler.py
│   │   │       └── plms.py
│   │   ├── modules
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   ├── utils
│   │   │   │   │   └── test.png
│   │   │   │   └── utils_image.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── x_transformer.py
│   │   └── util.py
│   ├── main.py
│   ├── models
│   │   ├── first_stage_models
│   │   │   ├── kl-f16
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f32
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f4
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f8
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f16
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4-noattn
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f8-n256
│   │   │   │   └── config.yaml
│   │   │   └── vq-f8
│   │   │       └── config.yaml
│   │   └── ldm
│   │       ├── bsr_sr
│   │       │   └── config.yaml
│   │       ├── celeba256
│   │       │   └── config.yaml
│   │       ├── cin256
│   │       │   └── config.yaml
│   │       ├── ffhq256
│   │       │   └── config.yaml
│   │       ├── inpainting_big
│   │       │   └── config.yaml
│   │       ├── layout2img-openimages256
│   │       │   └── config.yaml
│   │       ├── lsun_beds256
│   │       │   └── config.yaml
│   │       ├── lsun_churches256
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis256
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis512
│   │       │   └── config.yaml
│   │       └── text2img256
│   │           └── config.yaml
│   ├── notebook_helpers.py
│   ├── scripts
│   │   ├── download_first_stages.sh
│   │   ├── download_models.sh
│   │   ├── img2img.py
│   │   ├── inpaint.py
│   │   ├── knn2img.py
│   │   ├── latent_imagenet_diffusion.ipynb
│   │   ├── sample_diffusion.py
│   │   ├── tests
│   │   │   └── test_watermark.py
│   │   ├── train_searcher.py
│   │   └── txt2img.py
│   └── setup.py
└── txt2img.py
/.gitignore:
--------------------------------------------------------------------------------
1 | trace
2 | __pycache__
3 | *.ckpt
4 | *.pth
5 | *.pt
6 | *.log
7 | *.DS_Store
8 | .vscode
9 | *.idea
10 |
--------------------------------------------------------------------------------
/assets/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/assets/poster.pdf
--------------------------------------------------------------------------------
/assets/slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/assets/slides.pdf
--------------------------------------------------------------------------------
/ddim/README.md:
--------------------------------------------------------------------------------
1 | # Denoising Diffusion Implicit Models (DDIM)
2 |
3 | [Jiaming Song](http://tsong.me), [Chenlin Meng](http://cs.stanford.edu/~chenlin) and [Stefano Ermon](http://cs.stanford.edu/~ermon), Stanford
4 |
5 | Implements sampling from an implicit model that is trained with the same procedure as [Denoising Diffusion Probabilistic Models](https://hojonathanho.github.io/diffusion/), but requires much less time and compute to sample from (click the image below for a video demo):
6 |
7 | 
8 |
9 | ## **Integration with 🤗 Diffusers library**
10 |
11 | DDIM is now also available in 🧨 Diffusers and accessible via the [DDIMPipeline](https://huggingface.co/docs/diffusers/api/pipelines/ddim).
12 | Diffusers allows you to test DDIM in PyTorch in just a couple lines of code.
13 |
14 | You can install diffusers as follows:
15 |
16 | ```
17 | pip install diffusers torch accelerate
18 | ```
19 |
20 | And then try out the model with just a couple lines of code:
21 |
22 | ```python
23 | from diffusers import DDIMPipeline
24 |
25 | model_id = "google/ddpm-cifar10-32"
26 |
27 | # load model and scheduler
28 | ddim = DDIMPipeline.from_pretrained(model_id)
29 |
30 | # run pipeline in inference (sample random noise and denoise)
31 | image = ddim(num_inference_steps=50).images[0]
32 |
33 | # save image
34 | image.save("ddim_generated_image.png")
35 | ```
36 |
37 | More DDPM/DDIM models compatible with the DDIM pipeline can be found directly [on the Hub](https://huggingface.co/models?library=diffusers&sort=downloads&search=ddpm).
38 |
39 | To better understand the DDIM scheduler, you can check out [this introductory Google Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb).
40 |
41 | The DDIM scheduler can also be used with more powerful diffusion models such as [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#stable-diffusion-pipelines).
42 |
43 | You simply need to [accept the license on the Hub](https://huggingface.co/runwayml/stable-diffusion-v1-5), log in with `huggingface-cli login`, and install transformers:
44 |
45 | ```
46 | pip install transformers
47 | ```
48 |
49 | Then you can run:
50 |
51 | ```python
52 | from diffusers import StableDiffusionPipeline, DDIMScheduler
53 |
54 | ddim = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
55 | pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=ddim)
56 |
57 | image = pipeline("An astronaut riding a horse.").images[0]
58 |
59 | image.save("astronaut_riding_a_horse.png")
60 | ```
61 |
62 | ## Running the Experiments
63 | The code has been tested on PyTorch 1.6.
64 |
65 | ### Train a model
66 | Training is exactly the same as for DDPM, using the following command:
67 | ```
68 | python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --ni
69 | ```
70 |
71 | ### Sampling from the model
72 |
73 | #### Sampling from the generalized model for FID evaluation
74 | ```
75 | python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --timesteps {STEPS} --eta {ETA} --ni
76 | ```
77 | where
78 | - `ETA` controls the scale of the variance (0 is DDIM, and 1 is one type of DDPM).
79 | - `STEPS` controls how many timesteps are used in the process.
80 | - `MODEL_NAME` finds the pre-trained checkpoint according to its inferred path.
81 |
82 | If you want to use the DDPM pretrained model:
83 | ```
84 | python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --use_pretrained --sample --fid --timesteps {STEPS} --eta {ETA} --ni
85 | ```
86 | The `--use_pretrained` option will automatically load the model according to the dataset.
87 |
88 | We provide a CelebA 64x64 model [here](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing), and use the DDPM version for CIFAR10 and LSUN.
89 |
90 | If you want to use the version of DDPM with the larger variance, use the `--sample_type ddpm_noisy` option.
91 |
92 | #### Sampling from the model for image interpolation
93 | Use the `--interpolation` option instead of `--fid`.
94 |
95 | #### Sampling the sequence of images that leads to the sample
96 | Use the `--sequence` option instead.
97 |
98 | The above two cases contain some hard-coded lines specific to producing the images, so modify them according to your needs.
99 |
100 |
101 | ## References and Acknowledgements
102 | ```
103 | @article{song2020denoising,
104 | title={Denoising Diffusion Implicit Models},
105 | author={Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
106 | journal={arXiv:2010.02502},
107 | year={2020},
108 | month={October},
109 | abbr={Preprint},
110 | url={https://arxiv.org/abs/2010.02502}
111 | }
112 | ```
113 |
114 |
115 | This implementation is based on / inspired by:
116 |
117 | - [https://github.com/hojonathanho/diffusion](https://github.com/hojonathanho/diffusion) (the DDPM TensorFlow repo),
118 | - [https://github.com/pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion) (PyTorch helper that loads the DDPM model), and
119 | - [https://github.com/ermongroup/ncsnv2](https://github.com/ermongroup/ncsnv2) (code structure).
120 |
--------------------------------------------------------------------------------
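For reference, here is a minimal sketch of the generalized update step that the `--eta` and `--timesteps` flags control. It mirrors the math in `ddim/functions/denoising.py` further down in this dump; the standalone function, its name, and the cumulative-alpha arguments `at`/`at_next` are only illustrative.

```python
import torch


@torch.no_grad()
def ddim_step(xt, et, at, at_next, eta=0.0):
    """One generalized (DDIM) update from x_t towards the next timestep.

    `at` and `at_next` are the cumulative alpha products for the current and
    next timestep; eta=0 gives the deterministic DDIM update, eta=1 recovers a
    DDPM-like variance.
    """
    x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()                       # predicted x_0
    c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()    # stochastic part
    c2 = ((1 - at_next) - c1 ** 2).sqrt()                                # direction to x_t
    return at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et
```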
/ddim/configs/bedroom.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "LSUN"
3 | category: "bedroom"
4 | image_size: 256
5 | channels: 3
6 | logit_transform: false
7 | uniform_dequantization: false
8 | gaussian_dequantization: false
9 | random_flip: true
10 | rescaled: true
11 | num_workers: 32
12 |
13 | model:
14 | type: "simple"
15 | in_channels: 3
16 | out_ch: 3
17 | ch: 128
18 | ch_mult: [1, 1, 2, 2, 4, 4]
19 | num_res_blocks: 2
20 | attn_resolutions: [16, ]
21 | dropout: 0.0
22 | var_type: fixedsmall
23 | ema_rate: 0.999
24 | ema: True
25 | resamp_with_conv: True
26 |
27 | diffusion:
28 | beta_schedule: linear
29 | beta_start: 0.0001
30 | beta_end: 0.02
31 | num_diffusion_timesteps: 1000
32 |
33 | training:
34 | batch_size: 64
35 | n_epochs: 10000
36 | n_iters: 5000000
37 | snapshot_freq: 5000
38 | validation_freq: 2000
39 |
40 | sampling:
41 | batch_size: 32
42 | last_only: True
43 |
44 | optim:
45 | weight_decay: 0.000
46 | optimizer: "Adam"
47 | lr: 0.00002
48 | beta1: 0.9
49 | amsgrad: false
50 | eps: 0.00000001
51 |
--------------------------------------------------------------------------------
/ddim/configs/celeba.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "CELEBA"
3 | image_size: 64
4 | channels: 3
5 | logit_transform: false
6 | uniform_dequantization: false
7 | gaussian_dequantization: false
8 | random_flip: true
9 | rescaled: true
10 | num_workers: 4
11 |
12 | model:
13 | type: "simple"
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult: [1, 2, 2, 2, 4]
18 | num_res_blocks: 2
19 | attn_resolutions: [16, ]
20 | dropout: 0.1
21 | var_type: fixedlarge
22 | ema_rate: 0.9999
23 | ema: True
24 | resamp_with_conv: True
25 |
26 | diffusion:
27 | beta_schedule: linear
28 | beta_start: 0.0001
29 | beta_end: 0.02
30 | num_diffusion_timesteps: 1000
31 |
32 | training:
33 | batch_size: 128
34 | n_epochs: 10000
35 | n_iters: 5000000
36 | snapshot_freq: 5000
37 | validation_freq: 20000
38 |
39 | sampling:
40 | batch_size: 32
41 | last_only: True
42 |
43 | optim:
44 | weight_decay: 0.000
45 | optimizer: "Adam"
46 | lr: 0.0002
47 | beta1: 0.9
48 | amsgrad: false
49 | eps: 0.00000001
50 | grad_clip: 1.0
51 |
--------------------------------------------------------------------------------
/ddim/configs/church.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "LSUN"
3 | category: "church_outdoor"
4 | image_size: 256
5 | channels: 3
6 | logit_transform: false
7 | uniform_dequantization: false
8 | gaussian_dequantization: false
9 | random_flip: true
10 | rescaled: true
11 | num_workers: 32
12 |
13 | model:
14 | type: "simple"
15 | in_channels: 3
16 | out_ch: 3
17 | ch: 128
18 | ch_mult: [1, 1, 2, 2, 4, 4]
19 | num_res_blocks: 2
20 | attn_resolutions: [16, ]
21 | dropout: 0.0
22 | var_type: fixedsmall
23 | ema_rate: 0.999
24 | ema: True
25 | resamp_with_conv: True
26 |
27 | diffusion:
28 | beta_schedule: linear
29 | beta_start: 0.0001
30 | beta_end: 0.02
31 | num_diffusion_timesteps: 1000
32 |
33 | training:
34 | batch_size: 64
35 | n_epochs: 10000
36 | n_iters: 5000000
37 | snapshot_freq: 5000
38 | validation_freq: 2000
39 |
40 | sampling:
41 | batch_size: 32
42 | last_only: True
43 |
44 | optim:
45 | weight_decay: 0.000
46 | optimizer: "Adam"
47 | lr: 0.00002
48 | beta1: 0.9
49 | amsgrad: false
50 | eps: 0.00000001
51 |
--------------------------------------------------------------------------------
/ddim/configs/cifar10.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "CIFAR10"
3 | image_size: 32
4 | channels: 3
5 | logit_transform: false
6 | uniform_dequantization: false
7 | gaussian_dequantization: false
8 | random_flip: true
9 | rescaled: true
10 | num_workers: 4
11 |
12 | model:
13 | type: "simple"
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult: [1, 2, 2, 2]
18 | num_res_blocks: 2
19 | attn_resolutions: [16, ]
20 | dropout: 0.1
21 | var_type: fixedlarge
22 | ema_rate: 0.9999
23 | ema: True
24 | resamp_with_conv: True
25 |
26 | diffusion:
27 | beta_schedule: linear
28 | beta_start: 0.0001
29 | beta_end: 0.02
30 | num_diffusion_timesteps: 1000
31 |
32 | training:
33 | batch_size: 128
34 | n_epochs: 10000
35 | n_iters: 5000000
36 | snapshot_freq: 5000
37 | validation_freq: 2000
38 |
39 | sampling:
40 | batch_size: 64
41 | last_only: True
42 |
43 | optim:
44 | weight_decay: 0.000
45 | optimizer: "Adam"
46 | lr: 0.0002
47 | beta1: 0.9
48 | amsgrad: false
49 | eps: 0.00000001
50 | grad_clip: 1.0
51 |
--------------------------------------------------------------------------------
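As a point of reference, below is a minimal sketch (not the repository's own runner code, which is not included in this section) of how a `beta_schedule: linear` block like the ones above is commonly expanded into per-step betas and cumulative alpha products.

```python
import numpy as np
import torch

# Values from the `diffusion` block shared by the configs above.
beta_start, beta_end, num_diffusion_timesteps = 0.0001, 0.02, 1000

# "linear" schedule: evenly spaced betas between beta_start and beta_end.
betas = torch.from_numpy(
    np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
).float()

# Cumulative alpha products \bar{alpha}_t, used by the samplers and the loss.
alphas_cumprod = (1.0 - betas).cumprod(dim=0)
```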
/ddim/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numbers
4 | import torchvision.transforms as transforms
5 | import torchvision.transforms.functional as F
6 | from torchvision.datasets import CIFAR10
7 | from .celeba import CelebA
8 | from .ffhq import FFHQ
9 | from .lsun import LSUN
10 | from torch.utils.data import Subset
11 | import numpy as np
12 |
13 |
14 | class Crop(object):
15 | def __init__(self, x1, x2, y1, y2):
16 | self.x1 = x1
17 | self.x2 = x2
18 | self.y1 = y1
19 | self.y2 = y2
20 |
21 | def __call__(self, img):
22 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
23 |
24 | def __repr__(self):
25 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
26 | self.x1, self.x2, self.y1, self.y2
27 | )
28 |
29 |
30 | def get_dataset(args, config):
31 | if config.data.random_flip is False:
32 | tran_transform = test_transform = transforms.Compose(
33 | [transforms.Resize(config.data.image_size), transforms.ToTensor()]
34 | )
35 | else:
36 | tran_transform = transforms.Compose(
37 | [
38 | transforms.Resize(config.data.image_size),
39 | transforms.RandomHorizontalFlip(p=0.5),
40 | transforms.ToTensor(),
41 | ]
42 | )
43 | test_transform = transforms.Compose(
44 | [transforms.Resize(config.data.image_size), transforms.ToTensor()]
45 | )
46 |
47 | if config.data.dataset == "CIFAR10":
48 | dataset = CIFAR10(
49 | os.path.join(args.exp, "datasets", "cifar10"),
50 | train=True,
51 | download=True,
52 | transform=tran_transform,
53 | )
54 | test_dataset = CIFAR10(
55 | os.path.join(args.exp, "datasets", "cifar10_test"),
56 | train=False,
57 | download=True,
58 | transform=test_transform,
59 | )
60 |
61 | elif config.data.dataset == "CELEBA":
62 | cx = 89
63 | cy = 121
64 | x1 = cy - 64
65 | x2 = cy + 64
66 | y1 = cx - 64
67 | y2 = cx + 64
68 | if config.data.random_flip:
69 | dataset = CelebA(
70 | root=os.path.join(args.exp, "datasets", "celeba"),
71 | split="train",
72 | transform=transforms.Compose(
73 | [
74 | Crop(x1, x2, y1, y2),
75 | transforms.Resize(config.data.image_size),
76 | transforms.RandomHorizontalFlip(),
77 | transforms.ToTensor(),
78 | ]
79 | ),
80 | download=True,
81 | )
82 | else:
83 | dataset = CelebA(
84 | root=os.path.join(args.exp, "datasets", "celeba"),
85 | split="train",
86 | transform=transforms.Compose(
87 | [
88 | Crop(x1, x2, y1, y2),
89 | transforms.Resize(config.data.image_size),
90 | transforms.ToTensor(),
91 | ]
92 | ),
93 | download=True,
94 | )
95 |
96 | test_dataset = CelebA(
97 | root=os.path.join(args.exp, "datasets", "celeba"),
98 | split="test",
99 | transform=transforms.Compose(
100 | [
101 | Crop(x1, x2, y1, y2),
102 | transforms.Resize(config.data.image_size),
103 | transforms.ToTensor(),
104 | ]
105 | ),
106 | download=True,
107 | )
108 |
109 | elif config.data.dataset == "LSUN":
110 | train_folder = "{}_train".format(config.data.category)
111 | val_folder = "{}_val".format(config.data.category)
112 | if config.data.random_flip:
113 | dataset = LSUN(
114 | root=os.path.join(args.exp, "datasets", "lsun"),
115 | classes=[train_folder],
116 | transform=transforms.Compose(
117 | [
118 | transforms.Resize(config.data.image_size),
119 | transforms.CenterCrop(config.data.image_size),
120 | transforms.RandomHorizontalFlip(p=0.5),
121 | transforms.ToTensor(),
122 | ]
123 | ),
124 | )
125 | else:
126 | dataset = LSUN(
127 | root=os.path.join(args.exp, "datasets", "lsun"),
128 | classes=[train_folder],
129 | transform=transforms.Compose(
130 | [
131 | transforms.Resize(config.data.image_size),
132 | transforms.CenterCrop(config.data.image_size),
133 | transforms.ToTensor(),
134 | ]
135 | ),
136 | )
137 |
138 | test_dataset = LSUN(
139 | root=os.path.join(args.exp, "datasets", "lsun"),
140 | classes=[val_folder],
141 | transform=transforms.Compose(
142 | [
143 | transforms.Resize(config.data.image_size),
144 | transforms.CenterCrop(config.data.image_size),
145 | transforms.ToTensor(),
146 | ]
147 | ),
148 | )
149 |
150 | elif config.data.dataset == "FFHQ":
151 | if config.data.random_flip:
152 | dataset = FFHQ(
153 | path=os.path.join(args.exp, "datasets", "FFHQ"),
154 | transform=transforms.Compose(
155 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()]
156 | ),
157 | resolution=config.data.image_size,
158 | )
159 | else:
160 | dataset = FFHQ(
161 | path=os.path.join(args.exp, "datasets", "FFHQ"),
162 | transform=transforms.ToTensor(),
163 | resolution=config.data.image_size,
164 | )
165 |
166 | num_items = len(dataset)
167 | indices = list(range(num_items))
168 | random_state = np.random.get_state()
169 | np.random.seed(2019)
170 | np.random.shuffle(indices)
171 | np.random.set_state(random_state)
172 | train_indices, test_indices = (
173 | indices[: int(num_items * 0.9)],
174 | indices[int(num_items * 0.9) :],
175 | )
176 | test_dataset = Subset(dataset, test_indices)
177 | dataset = Subset(dataset, train_indices)
178 | else:
179 | dataset, test_dataset = None, None
180 |
181 | return dataset, test_dataset
182 |
183 |
184 | def logit_transform(image, lam=1e-6):
185 | image = lam + (1 - 2 * lam) * image
186 | return torch.log(image) - torch.log1p(-image)
187 |
188 |
189 | def data_transform(config, X):
190 | if config.data.uniform_dequantization:
191 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
192 | if config.data.gaussian_dequantization:
193 | X = X + torch.randn_like(X) * 0.01
194 |
195 | if config.data.rescaled:
196 | X = 2 * X - 1.0
197 | elif config.data.logit_transform:
198 | X = logit_transform(X)
199 |
200 | if hasattr(config, "image_mean"):
201 | return X - config.image_mean.to(X.device)[None, ...]
202 |
203 | return X
204 |
205 |
206 | def inverse_data_transform(config, X):
207 | if hasattr(config, "image_mean"):
208 | X = X + config.image_mean.to(X.device)[None, ...]
209 |
210 | if config.data.logit_transform:
211 | X = torch.sigmoid(X)
212 | elif config.data.rescaled:
213 | X = (X + 1.0) / 2.0
214 |
215 | return torch.clamp(X, 0.0, 1.0)
216 |
--------------------------------------------------------------------------------
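A small usage sketch for the helpers above, assuming it is run from the `ddim/` directory with the repository's requirements installed so that `datasets` and `main` (defined later in this dump) are importable; the `exp` path and the CIFAR-10 config are just examples.

```python
import argparse

import yaml
from torch.utils.data import DataLoader

from datasets import get_dataset, data_transform, inverse_data_transform  # this package
from main import dict2namespace  # defined in ddim/main.py below

args = argparse.Namespace(exp="exp")
with open("configs/cifar10.yml") as f:
    config = dict2namespace(yaml.safe_load(f))

train_set, _ = get_dataset(args, config)
loader = DataLoader(train_set, batch_size=config.training.batch_size,
                    shuffle=True, num_workers=config.data.num_workers)

x, _ = next(iter(loader))
x = data_transform(config, x)          # rescale to [-1, 1] per `rescaled: true`
x = inverse_data_transform(config, x)  # back to [0, 1] for saving or viewing
```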
/ddim/datasets/celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import PIL
4 | from .vision import VisionDataset
5 | from .utils import download_file_from_google_drive, check_integrity
6 |
7 |
8 | class CelebA(VisionDataset):
9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset.
10 |
11 | Args:
12 | root (string): Root directory where images are downloaded to.
13 | split (string): One of {'train', 'valid', 'test'}.
14 | Accordingly dataset is selected.
15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types.
17 | The targets represent:
18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
19 | ``identity`` (int): label for each person (data points with the same identity are the same person)
20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
23 | Defaults to ``attr``.
24 | transform (callable, optional): A function/transform that takes in an PIL image
25 | and returns a transformed version. E.g, ``transforms.ToTensor``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | download (bool, optional): If true, downloads the dataset from the internet and
29 | puts it in root directory. If dataset is already downloaded, it is not
30 | downloaded again.
31 | """
32 |
33 | base_folder = "celeba"
34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
36 | # right now.
37 | file_list = [
38 | # File ID MD5 Hash Filename
39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
48 | ]
49 |
50 | def __init__(self, root,
51 | split="train",
52 | target_type="attr",
53 | transform=None, target_transform=None,
54 | download=False):
55 | import pandas
56 | super(CelebA, self).__init__(root)
57 | self.split = split
58 | if isinstance(target_type, list):
59 | self.target_type = target_type
60 | else:
61 | self.target_type = [target_type]
62 | self.transform = transform
63 | self.target_transform = target_transform
64 |
65 | if download:
66 | self.download()
67 |
68 | if not self._check_integrity():
69 | raise RuntimeError('Dataset not found or corrupted.' +
70 | ' You can use download=True to download it')
71 |
72 | self.transform = transform
73 | self.target_transform = target_transform
74 |
75 | if split.lower() == "train":
76 | split = 0
77 | elif split.lower() == "valid":
78 | split = 1
79 | elif split.lower() == "test":
80 | split = 2
81 | else:
82 | raise ValueError('Wrong split entered! Please use split="train" '
83 | 'or split="valid" or split="test"')
84 |
85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
87 |
88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
90 |
91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
93 |
94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
96 |
97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
99 |
100 | mask = (splits[1] == split)
101 | self.filename = splits[mask].index.values
102 | self.identity = torch.as_tensor(self.identity[mask].values)
103 | self.bbox = torch.as_tensor(self.bbox[mask].values)
104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
105 | self.attr = torch.as_tensor(self.attr[mask].values)
106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
107 |
108 | def _check_integrity(self):
109 | for (_, md5, filename) in self.file_list:
110 | fpath = os.path.join(self.root, self.base_folder, filename)
111 | _, ext = os.path.splitext(filename)
112 | # Allow original archive to be deleted (zip and 7z)
113 | # Only need the extracted images
114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
115 | return False
116 |
117 | # Should check a hash of the images
118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
119 |
120 | def download(self):
121 | import zipfile
122 |
123 | if self._check_integrity():
124 | print('Files already downloaded and verified')
125 | return
126 |
127 | for (file_id, md5, filename) in self.file_list:
128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
129 |
130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
131 | f.extractall(os.path.join(self.root, self.base_folder))
132 |
133 | def __getitem__(self, index):
134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
135 |
136 | target = []
137 | for t in self.target_type:
138 | if t == "attr":
139 | target.append(self.attr[index, :])
140 | elif t == "identity":
141 | target.append(self.identity[index, 0])
142 | elif t == "bbox":
143 | target.append(self.bbox[index, :])
144 | elif t == "landmarks":
145 | target.append(self.landmarks_align[index, :])
146 | else:
147 | raise ValueError("Target type \"{}\" is not recognized.".format(t))
148 | target = tuple(target) if len(target) > 1 else target[0]
149 |
150 | if self.transform is not None:
151 | X = self.transform(X)
152 |
153 | if self.target_transform is not None:
154 | target = self.target_transform(target)
155 |
156 | return X, target
157 |
158 | def __len__(self):
159 | return len(self.attr)
160 |
161 | def extra_repr(self):
162 | lines = ["Target type: {target_type}", "Split: {split}"]
163 | return '\n'.join(lines).format(**self.__dict__)
164 |
--------------------------------------------------------------------------------
/ddim/datasets/ffhq.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class FFHQ(Dataset):
9 | def __init__(self, path, transform, resolution=8):
10 | self.env = lmdb.open(
11 | path,
12 | max_readers=32,
13 | readonly=True,
14 | lock=False,
15 | readahead=False,
16 | meminit=False,
17 | )
18 |
19 | if not self.env:
20 | raise IOError('Cannot open lmdb dataset', path)
21 |
22 | with self.env.begin(write=False) as txn:
23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
24 |
25 | self.resolution = resolution
26 | self.transform = transform
27 |
28 | def __len__(self):
29 | return self.length
30 |
31 | def __getitem__(self, index):
32 | with self.env.begin(write=False) as txn:
33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
34 | img_bytes = txn.get(key)
35 |
36 | buffer = BytesIO(img_bytes)
37 | img = Image.open(buffer)
38 | img = self.transform(img)
39 | target = 0
40 |
41 | return img, target
--------------------------------------------------------------------------------
/ddim/datasets/lsun.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import io
6 | from collections.abc import Iterable
7 | import pickle
8 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str
9 |
10 |
11 | class LSUNClass(VisionDataset):
12 | def __init__(self, root, transform=None, target_transform=None):
13 | import lmdb
14 |
15 | super(LSUNClass, self).__init__(
16 | root, transform=transform, target_transform=target_transform
17 | )
18 |
19 | self.env = lmdb.open(
20 | root,
21 | max_readers=1,
22 | readonly=True,
23 | lock=False,
24 | readahead=False,
25 | meminit=False,
26 | )
27 | with self.env.begin(write=False) as txn:
28 | self.length = txn.stat()["entries"]
29 | root_split = root.split("/")
30 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
31 | if os.path.isfile(cache_file):
32 | self.keys = pickle.load(open(cache_file, "rb"))
33 | else:
34 | with self.env.begin(write=False) as txn:
35 | self.keys = [key for key, _ in txn.cursor()]
36 | pickle.dump(self.keys, open(cache_file, "wb"))
37 |
38 | def __getitem__(self, index):
39 | img, target = None, None
40 | env = self.env
41 | with env.begin(write=False) as txn:
42 | imgbuf = txn.get(self.keys[index])
43 |
44 | buf = io.BytesIO()
45 | buf.write(imgbuf)
46 | buf.seek(0)
47 | img = Image.open(buf).convert("RGB")
48 |
49 | if self.transform is not None:
50 | img = self.transform(img)
51 |
52 | if self.target_transform is not None:
53 | target = self.target_transform(target)
54 |
55 | return img, target
56 |
57 | def __len__(self):
58 | return self.length
59 |
60 |
61 | class LSUN(VisionDataset):
62 | """
63 | `LSUN `_ dataset.
64 |
65 | Args:
66 | root (string): Root directory for the database files.
67 | classes (string or list): One of {'train', 'val', 'test'} or a list of
68 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train'].
69 | transform (callable, optional): A function/transform that takes in an PIL image
70 | and returns a transformed version. E.g, ``transforms.RandomCrop``
71 | target_transform (callable, optional): A function/transform that takes in the
72 | target and transforms it.
73 | """
74 |
75 | def __init__(self, root, classes="train", transform=None, target_transform=None):
76 | super(LSUN, self).__init__(
77 | root, transform=transform, target_transform=target_transform
78 | )
79 | self.classes = self._verify_classes(classes)
80 |
81 | # for each class, create an LSUNClassDataset
82 | self.dbs = []
83 | for c in self.classes:
84 | self.dbs.append(
85 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform)
86 | )
87 |
88 | self.indices = []
89 | count = 0
90 | for db in self.dbs:
91 | count += len(db)
92 | self.indices.append(count)
93 |
94 | self.length = count
95 |
96 | def _verify_classes(self, classes):
97 | categories = [
98 | "bedroom",
99 | "bridge",
100 | "church_outdoor",
101 | "classroom",
102 | "conference_room",
103 | "dining_room",
104 | "kitchen",
105 | "living_room",
106 | "restaurant",
107 | "tower",
108 | ]
109 | dset_opts = ["train", "val", "test"]
110 |
111 | try:
112 | verify_str_arg(classes, "classes", dset_opts)
113 | if classes == "test":
114 | classes = [classes]
115 | else:
116 | classes = [c + "_" + classes for c in categories]
117 | except ValueError:
118 | if not isinstance(classes, Iterable):
119 | msg = (
120 | "Expected type str or Iterable for argument classes, "
121 | "but got type {}."
122 | )
123 | raise ValueError(msg.format(type(classes)))
124 |
125 | classes = list(classes)
126 | msg_fmtstr = (
127 | "Expected type str for elements in argument classes, "
128 | "but got type {}."
129 | )
130 | for c in classes:
131 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
132 | c_short = c.split("_")
133 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
134 |
135 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
136 | msg = msg_fmtstr.format(
137 | category, "LSUN class", iterable_to_str(categories)
138 | )
139 | verify_str_arg(category, valid_values=categories, custom_msg=msg)
140 |
141 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
142 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
143 |
144 | return classes
145 |
146 | def __getitem__(self, index):
147 | """
148 | Args:
149 | index (int): Index
150 |
151 | Returns:
152 | tuple: Tuple (image, target) where target is the index of the target category.
153 | """
154 | target = 0
155 | sub = 0
156 | for ind in self.indices:
157 | if index < ind:
158 | break
159 | target += 1
160 | sub = ind
161 |
162 | db = self.dbs[target]
163 | index = index - sub
164 |
165 | if self.target_transform is not None:
166 | target = self.target_transform(target)
167 |
168 | img, _ = db[index]
169 | return img, target
170 |
171 | def __len__(self):
172 | return self.length
173 |
174 | def extra_repr(self):
175 | return "Classes: {classes}".format(**self.__dict__)
176 |
--------------------------------------------------------------------------------
/ddim/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import errno
5 | from torch.utils.model_zoo import tqdm
6 |
7 |
8 | def gen_bar_updater():
9 | pbar = tqdm(total=None)
10 |
11 | def bar_update(count, block_size, total_size):
12 | if pbar.total is None and total_size:
13 | pbar.total = total_size
14 | progress_bytes = count * block_size
15 | pbar.update(progress_bytes - pbar.n)
16 |
17 | return bar_update
18 |
19 |
20 | def check_integrity(fpath, md5=None):
21 | if md5 is None:
22 | return True
23 | if not os.path.isfile(fpath):
24 | return False
25 | md5o = hashlib.md5()
26 | with open(fpath, 'rb') as f:
27 | # read in 1MB chunks
28 | for chunk in iter(lambda: f.read(1024 * 1024), b''):
29 | md5o.update(chunk)
30 | md5c = md5o.hexdigest()
31 | if md5c != md5:
32 | return False
33 | return True
34 |
35 |
36 | def makedir_exist_ok(dirpath):
37 | """
38 | Python2 support for os.makedirs(.., exist_ok=True)
39 | """
40 | try:
41 | os.makedirs(dirpath)
42 | except OSError as e:
43 | if e.errno == errno.EEXIST:
44 | pass
45 | else:
46 | raise
47 |
48 |
49 | def download_url(url, root, filename=None, md5=None):
50 | """Download a file from a url and place it in root.
51 |
52 | Args:
53 | url (str): URL to download file from
54 | root (str): Directory to place downloaded file in
55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL
56 | md5 (str, optional): MD5 checksum of the download. If None, do not check
57 | """
58 | from six.moves import urllib
59 |
60 | root = os.path.expanduser(root)
61 | if not filename:
62 | filename = os.path.basename(url)
63 | fpath = os.path.join(root, filename)
64 |
65 | makedir_exist_ok(root)
66 |
67 | # downloads file
68 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
69 | print('Using downloaded and verified file: ' + fpath)
70 | else:
71 | try:
72 | print('Downloading ' + url + ' to ' + fpath)
73 | urllib.request.urlretrieve(
74 | url, fpath,
75 | reporthook=gen_bar_updater()
76 | )
77 | except OSError:
78 | if url[:5] == 'https':
79 | url = url.replace('https:', 'http:')
80 | print('Failed download. Trying https -> http instead.'
81 | ' Downloading ' + url + ' to ' + fpath)
82 | urllib.request.urlretrieve(
83 | url, fpath,
84 | reporthook=gen_bar_updater()
85 | )
86 |
87 |
88 | def list_dir(root, prefix=False):
89 | """List all directories at a given root
90 |
91 | Args:
92 | root (str): Path to directory whose folders need to be listed
93 | prefix (bool, optional): If true, prepends the path to each result, otherwise
94 | only returns the name of the directories found
95 | """
96 | root = os.path.expanduser(root)
97 | directories = list(
98 | filter(
99 | lambda p: os.path.isdir(os.path.join(root, p)),
100 | os.listdir(root)
101 | )
102 | )
103 |
104 | if prefix is True:
105 | directories = [os.path.join(root, d) for d in directories]
106 |
107 | return directories
108 |
109 |
110 | def list_files(root, suffix, prefix=False):
111 | """List all files ending with a suffix at a given root
112 |
113 | Args:
114 | root (str): Path to directory whose folders need to be listed
115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
116 | It uses the Python "str.endswith" method and is passed directly
117 | prefix (bool, optional): If true, prepends the path to each result, otherwise
118 | only returns the name of the files found
119 | """
120 | root = os.path.expanduser(root)
121 | files = list(
122 | filter(
123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
124 | os.listdir(root)
125 | )
126 | )
127 |
128 | if prefix is True:
129 | files = [os.path.join(root, d) for d in files]
130 |
131 | return files
132 |
133 |
134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None):
135 | """Download a Google Drive file from and place it in root.
136 |
137 | Args:
138 | file_id (str): id of file to be downloaded
139 | root (str): Directory to place downloaded file in
140 | filename (str, optional): Name to save the file under. If None, use the id of the file.
141 | md5 (str, optional): MD5 checksum of the download. If None, do not check
142 | """
143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
144 | import requests
145 | url = "https://docs.google.com/uc?export=download"
146 |
147 | root = os.path.expanduser(root)
148 | if not filename:
149 | filename = file_id
150 | fpath = os.path.join(root, filename)
151 |
152 | makedir_exist_ok(root)
153 |
154 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
155 | print('Using downloaded and verified file: ' + fpath)
156 | else:
157 | session = requests.Session()
158 |
159 | response = session.get(url, params={'id': file_id}, stream=True)
160 | token = _get_confirm_token(response)
161 |
162 | if token:
163 | params = {'id': file_id, 'confirm': token}
164 | response = session.get(url, params=params, stream=True)
165 |
166 | _save_response_content(response, fpath)
167 |
168 |
169 | def _get_confirm_token(response):
170 | for key, value in response.cookies.items():
171 | if key.startswith('download_warning'):
172 | return value
173 |
174 | return None
175 |
176 |
177 | def _save_response_content(response, destination, chunk_size=32768):
178 | with open(destination, "wb") as f:
179 | pbar = tqdm(total=None)
180 | progress = 0
181 | for chunk in response.iter_content(chunk_size):
182 | if chunk: # filter out keep-alive new chunks
183 | f.write(chunk)
184 | progress += len(chunk)
185 | pbar.update(progress - pbar.n)
186 | pbar.close()
187 |
--------------------------------------------------------------------------------
/ddim/datasets/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, torch._six.string_classes):
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, 'transform') and self.transform is not None:
41 | body += self._format_transform_repr(self.transform,
42 | "Transforms: ")
43 | if hasattr(self, 'target_transform') and self.target_transform is not None:
44 | body += self._format_transform_repr(self.target_transform,
45 | "Target transforms: ")
46 | lines = [head] + [" " * self._repr_indent + line for line in body]
47 | return '\n'.join(lines)
48 |
49 | def _format_transform_repr(self, transform, head):
50 | lines = transform.__repr__().splitlines()
51 | return (["{}{}".format(head, lines[0])] +
52 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
53 |
54 | def extra_repr(self):
55 | return ""
56 |
57 |
58 | class StandardTransform(object):
59 | def __init__(self, transform=None, target_transform=None):
60 | self.transform = transform
61 | self.target_transform = target_transform
62 |
63 | def __call__(self, input, target):
64 | if self.transform is not None:
65 | input = self.transform(input)
66 | if self.target_transform is not None:
67 | target = self.target_transform(target)
68 | return input, target
69 |
70 | def _format_transform_repr(self, transform, head):
71 | lines = transform.__repr__().splitlines()
72 | return (["{}{}".format(head, lines[0])] +
73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
74 |
75 | def __repr__(self):
76 | body = [self.__class__.__name__]
77 | if self.transform is not None:
78 | body += self._format_transform_repr(self.transform,
79 | "Transform: ")
80 | if self.target_transform is not None:
81 | body += self._format_transform_repr(self.target_transform,
82 | "Target transform: ")
83 |
84 | return '\n'.join(body)
85 |
--------------------------------------------------------------------------------
/ddim/functions/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def get_optimizer(config, parameters):
5 | if config.optim.optimizer == 'Adam':
6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad,
8 | eps=config.optim.eps)
9 | elif config.optim.optimizer == 'RMSProp':
10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
11 | elif config.optim.optimizer == 'SGD':
12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
13 | else:
14 | raise NotImplementedError(
15 | 'Optimizer {} not understood.'.format(config.optim.optimizer))
16 |
--------------------------------------------------------------------------------
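A quick usage sketch for `get_optimizer`; the hand-built namespaces stand in for the `optim` block of the YAML configs above, and the linear layer is only a placeholder for the diffusion U-Net.

```python
import argparse

import torch.nn as nn

from functions import get_optimizer  # ddim/functions/__init__.py

config = argparse.Namespace(
    optim=argparse.Namespace(optimizer="Adam", lr=0.0002, weight_decay=0.0,
                             beta1=0.9, amsgrad=False, eps=1e-8)
)

model = nn.Linear(16, 16)  # placeholder for the diffusion model
optimizer = get_optimizer(config, model.parameters())
```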
/ddim/functions/ckpt_util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
14 | }
15 | CKPT_MAP = {
16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
24 | }
25 | MD5_MAP = {
26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
34 | }
35 |
36 |
37 | def download(url, local_path, chunk_size=1024):
38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
39 | with requests.get(url, stream=True) as r:
40 | total_size = int(r.headers.get("content-length", 0))
41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
42 | with open(local_path, "wb") as f:
43 | for data in r.iter_content(chunk_size=chunk_size):
44 | if data:
45 | f.write(data)
46 | pbar.update(chunk_size)
47 |
48 |
49 | def md5_hash(path):
50 | with open(path, "rb") as f:
51 | content = f.read()
52 | return hashlib.md5(content).hexdigest()
53 |
54 |
55 | def get_ckpt_path(name, root=None, check=False):
56 | if 'church_outdoor' in name:
57 | name = name.replace('church_outdoor', 'church')
58 | assert name in URL_MAP
59 | # Modify the path when necessary
60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("/atlas/u/tsong/.cache"))
61 | root = (
62 | root
63 | if root is not None
64 | else os.path.join(cachedir, "diffusion_models_converted")
65 | )
66 | path = os.path.join(root, CKPT_MAP[name])
67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
69 | download(URL_MAP[name], path)
70 | md5 = md5_hash(path)
71 | assert md5 == MD5_MAP[name], md5
72 | return path
73 |
--------------------------------------------------------------------------------
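A usage sketch for `get_ckpt_path`, assuming it is run from the `ddim/` directory; the cache directory is arbitrary, and `model` (the U-Net from `ddim/models/diffusion.py`) is assumed to exist, so the final load is left as a comment.

```python
import torch

from functions.ckpt_util import get_ckpt_path  # this file

# Download (or reuse) the converted CIFAR-10 EMA checkpoint and verify its MD5.
path = get_ckpt_path("ema_cifar10", root="./ckpt_cache", check=True)
state_dict = torch.load(path, map_location="cpu")
# model.load_state_dict(state_dict)
```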
/ddim/functions/denoising.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def compute_alpha(beta, t):
5 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
6 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
7 | return a
8 |
9 |
10 | def generalized_steps(x, seq, model, b, **kwargs):
11 | with torch.no_grad():
12 | n = x.size(0)
13 | seq_next = [-1] + list(seq[:-1])
14 | x0_preds = []
15 | xs = [x]
16 | cnt = 0
17 | t, xt = None, None
18 | for i, j in zip(reversed(seq), reversed(seq_next)):
19 | t = (torch.ones(n) * i).to(x.device)
20 | next_t = (torch.ones(n) * j).to(x.device)
21 | at = compute_alpha(b, t.long())
22 | at_next = compute_alpha(b, next_t.long())
23 | xt = xs[-1].to('cuda')
24 | if "untill_fake_t" in kwargs and cnt == kwargs["untill_fake_t"] - 1:
25 | break
26 | if 'tot' in kwargs and kwargs['tot'] is not None:
27 | # act = kwargs['cali_ckpt'][f'act_{kwargs["t_max"] - i // kwargs["tot"]}']
28 | act = kwargs['cali_ckpt'][f'act_{cnt}']
29 | model.load_state_dict(act, strict=False)
30 | et = model(xt, t)
31 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
32 | x0_preds.append(x0_t.to('cpu'))
33 | c1 = (
34 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
35 | )
36 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
37 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
38 | xs.append(xt_next.to('cpu'))
39 | cnt += 1
40 |
41 | return xs, x0_preds, xt, t
42 |
43 |
44 | def ddpm_steps(x, seq, model, b, **kwargs):
45 | with torch.no_grad():
46 | n = x.size(0)
47 | seq_next = [-1] + list(seq[:-1])
48 | xs = [x]
49 | x0_preds = []
50 | betas = b
51 | cnt = 0
52 | t, xt = None, None
53 | for i, j in zip(reversed(seq), reversed(seq_next)):
54 | t = (torch.ones(n) * i).to(x.device)
55 | next_t = (torch.ones(n) * j).to(x.device)
56 | at = compute_alpha(betas, t.long())
57 | atm1 = compute_alpha(betas, next_t.long())
58 | beta_t = 1 - at / atm1
59 | x = xs[-1].to('cuda')
60 | if "untill_fake_t" in kwargs and cnt == kwargs["untill_fake_t"] - 1:
61 | break
62 | if 'tot' in kwargs and kwargs['tot'] is not None:
63 | # act = kwargs['cali_ckpt'][f'act_{kwargs["t_max"] - i // kwargs["tot"]}']
64 | act = kwargs['cali_ckpt'][f'act_{cnt}']
65 | model.load_state_dict(act, strict=False)
66 | output = model(x, t.float())
67 | e = output
68 |
69 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
70 | x0_from_e = torch.clamp(x0_from_e, -1, 1)
71 | x0_preds.append(x0_from_e.to('cpu'))
72 | mean_eps = (
73 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
74 | ) / (1.0 - at)
75 |
76 | mean = mean_eps
77 | noise = torch.randn_like(x)
78 | mask = 1 - (t == 0).float()
79 | mask = mask.view(-1, 1, 1, 1)
80 | logvar = beta_t.log()
81 | sample = mean + mask * torch.exp(0.5 * logvar) * noise
82 | xs.append(sample.to('cpu'))
83 | return xs, x0_preds, xt, t
84 |
--------------------------------------------------------------------------------
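An illustrative setup for calling `generalized_steps`: a linear beta schedule and a uniformly skipped 50-step sequence, roughly what `--timesteps 50 --skip_type uniform --eta 0` would produce in `main.py`. The noise-prediction model is assumed to exist and live on the GPU, so the sampling call itself is left as a comment.

```python
import torch

from functions.denoising import generalized_steps  # this file

betas = torch.linspace(0.0001, 0.02, 1000, device="cuda")
seq = range(0, 1000, 1000 // 50)                    # 50 uniformly skipped timesteps

x = torch.randn(4, 3, 32, 32, device="cuda")        # start from pure Gaussian noise
# xs, x0_preds, xt, t = generalized_steps(x, seq, model, betas, eta=0.0)
# sample = xs[-1]                                   # final batch, still in [-1, 1]
```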
/ddim/functions/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def noise_estimation_loss(model,
5 | x0: torch.Tensor,
6 | t: torch.LongTensor,
7 | e: torch.Tensor,
8 | b: torch.Tensor, keepdim=False):
9 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
10 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
11 | output = model(x, t.float())
12 | if keepdim:
13 | return (e - output).square().sum(dim=(1, 2, 3))
14 | else:
15 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
16 |
17 |
18 | loss_registry = {
19 | 'simple': noise_estimation_loss,
20 | }
21 |
--------------------------------------------------------------------------------
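One illustrative training step built around `noise_estimation_loss`; `model` (from `ddim/models/diffusion.py`) and `optimizer` (from `functions.get_optimizer`) are assumed to exist, so the actual update is shown as comments.

```python
import torch

from functions.losses import noise_estimation_loss  # this file

betas = torch.linspace(0.0001, 0.02, 1000)
x0 = torch.rand(8, 3, 32, 32) * 2 - 1               # a batch of images in [-1, 1]
t = torch.randint(0, betas.numel(), (x0.size(0),))  # one random timestep per image
e = torch.randn_like(x0)                            # the noise the model must predict

# loss = noise_estimation_loss(model, x0, t, e, betas)
# optimizer.zero_grad(); loss.backward(); optimizer.step()
```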
/ddim/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import traceback
3 | import shutil
4 | import logging
5 | import yaml
6 | import sys
7 | import os
8 | import torch
9 | import numpy as np
10 | import torch.utils.tensorboard as tb
11 |
12 | from runners.diffusion import Diffusion
13 |
14 | torch.set_printoptions(sci_mode=False)
15 |
16 |
17 | def parse_args_and_config():
18 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
19 |
20 | parser.add_argument(
21 | "--config", type=str, required=True, help="Path to the config file"
22 | )
23 | parser.add_argument("--seed", type=int, default=1234, help="Random seed")
24 | parser.add_argument(
25 | "--exp", type=str, default="exp", help="Path for saving running related data."
26 | )
27 | parser.add_argument(
28 | "--doc",
29 | type=str,
30 | required=True,
31 | help="A string for documentation purpose. "
32 | "Will be the name of the log folder.",
33 | )
34 | parser.add_argument(
35 | "--comment", type=str, default="", help="A string for experiment comment"
36 | )
37 | parser.add_argument(
38 | "--verbose",
39 | type=str,
40 | default="info",
41 | help="Verbose level: info | debug | warning | critical",
42 | )
43 | parser.add_argument("--test", action="store_true", help="Whether to test the model")
44 | parser.add_argument(
45 | "--sample",
46 | action="store_true",
47 | help="Whether to produce samples from the model",
48 | )
49 | parser.add_argument("--fid", action="store_true")
50 | parser.add_argument("--interpolation", action="store_true")
51 | parser.add_argument(
52 | "--resume_training", action="store_true", help="Whether to resume training"
53 | )
54 | parser.add_argument(
55 | "-i",
56 | "--image_folder",
57 | type=str,
58 | default="images",
59 | help="The folder name of samples",
60 | )
61 | parser.add_argument(
62 | "--ni",
63 | action="store_true",
64 | help="No interaction. Suitable for Slurm Job launcher",
65 | )
66 | parser.add_argument("--use_pretrained", action="store_true")
67 | parser.add_argument(
68 | "--sample_type",
69 | type=str,
70 | default="generalized",
71 | help="sampling approach (generalized or ddpm_noisy)",
72 | )
73 | parser.add_argument(
74 | "--skip_type",
75 | type=str,
76 | default="uniform",
77 | help="skip according to (uniform or quadratic)",
78 | )
79 | parser.add_argument(
80 | "--timesteps", type=int, default=1000, help="number of steps involved"
81 | )
82 | parser.add_argument(
83 | "--eta",
84 | type=float,
85 | default=0.0,
86 | help="eta used to control the variances of sigma",
87 | )
88 | parser.add_argument("--sequence", action="store_true")
89 |
90 | args = parser.parse_args()
91 | args.log_path = os.path.join(args.exp, "logs", args.doc)
92 |
93 | # parse config file
94 | with open(os.path.join("configs", args.config), "r") as f:
95 | config = yaml.safe_load(f)
96 | new_config = dict2namespace(config)
97 |
98 | tb_path = os.path.join(args.exp, "tensorboard", args.doc)
99 |
100 | if not args.test and not args.sample:
101 | if not args.resume_training:
102 | if os.path.exists(args.log_path):
103 | overwrite = False
104 | if args.ni:
105 | overwrite = True
106 | else:
107 | response = input("Folder already exists. Overwrite? (Y/N)")
108 | if response.upper() == "Y":
109 | overwrite = True
110 |
111 | if overwrite:
112 | shutil.rmtree(args.log_path)
113 | shutil.rmtree(tb_path)
114 | os.makedirs(args.log_path)
115 | if os.path.exists(tb_path):
116 | shutil.rmtree(tb_path)
117 | else:
118 | print("Folder exists. Program halted.")
119 | sys.exit(0)
120 | else:
121 | os.makedirs(args.log_path)
122 |
123 | with open(os.path.join(args.log_path, "config.yml"), "w") as f:
124 | yaml.dump(new_config, f, default_flow_style=False)
125 |
126 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
127 | # setup logger
128 | level = getattr(logging, args.verbose.upper(), None)
129 | if not isinstance(level, int):
130 | raise ValueError("level {} not supported".format(args.verbose))
131 |
132 | handler1 = logging.StreamHandler()
133 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))
134 | formatter = logging.Formatter(
135 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
136 | )
137 | handler1.setFormatter(formatter)
138 | handler2.setFormatter(formatter)
139 | logger = logging.getLogger()
140 | logger.addHandler(handler1)
141 | logger.addHandler(handler2)
142 | logger.setLevel(level)
143 |
144 | else:
145 | level = getattr(logging, args.verbose.upper(), None)
146 | if not isinstance(level, int):
147 | raise ValueError("level {} not supported".format(args.verbose))
148 |
149 | handler1 = logging.StreamHandler()
150 | formatter = logging.Formatter(
151 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
152 | )
153 | handler1.setFormatter(formatter)
154 | logger = logging.getLogger()
155 | logger.addHandler(handler1)
156 | logger.setLevel(level)
157 |
158 | if args.sample:
159 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
160 | args.image_folder = os.path.join(
161 | args.exp, "image_samples", args.image_folder
162 | )
163 | if not os.path.exists(args.image_folder):
164 | os.makedirs(args.image_folder)
165 | else:
166 | if not (args.fid or args.interpolation):
167 | overwrite = False
168 | if args.ni:
169 | overwrite = True
170 | else:
171 | response = input(
172 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
173 | )
174 | if response.upper() == "Y":
175 | overwrite = True
176 |
177 | if overwrite:
178 | shutil.rmtree(args.image_folder)
179 | os.makedirs(args.image_folder)
180 | else:
181 | print("Output image folder exists. Program halted.")
182 | sys.exit(0)
183 |
184 | # add device
185 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
186 | logging.info("Using device: {}".format(device))
187 | new_config.device = device
188 |
189 | # set random seed
190 | torch.manual_seed(args.seed)
191 | np.random.seed(args.seed)
192 | if torch.cuda.is_available():
193 | torch.cuda.manual_seed_all(args.seed)
194 |
195 | torch.backends.cudnn.benchmark = True
196 |
197 | return args, new_config
198 |
199 |
200 | def dict2namespace(config):
201 | namespace = argparse.Namespace()
202 | for key, value in config.items():
203 | if isinstance(value, dict):
204 | new_value = dict2namespace(value)
205 | else:
206 | new_value = value
207 | setattr(namespace, key, new_value)
208 | return namespace
209 |
210 |
211 | def main():
212 | args, config = parse_args_and_config()
213 | logging.info("Writing log file to {}".format(args.log_path))
214 | logging.info("Exp instance id = {}".format(os.getpid()))
215 | logging.info("Exp comment = {}".format(args.comment))
216 |
217 | try:
218 | runner = Diffusion(args, config)
219 | if args.sample:
220 | runner.sample()
221 | elif args.test:
222 | runner.test()
223 | else:
224 | runner.train()
225 | except Exception:
226 | logging.error(traceback.format_exc())
227 |
228 | return 0
229 |
230 |
231 | if __name__ == "__main__":
232 | sys.exit(main())
233 |
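A minimal sketch of what the dict2namespace helper above does to a nested config; the "data"/"training" keys here are hypothetical, for illustration only.

import argparse

def dict2namespace(config):
    # mirrors the helper above: recursively turn nested dicts into Namespaces
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace

cfg = dict2namespace({"data": {"image_size": 32}, "training": {"batch_size": 128}})
print(cfg.data.image_size)      # 32
print(cfg.training.batch_size)  # 128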
--------------------------------------------------------------------------------
/ddim/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class EMAHelper(object):
5 | def __init__(self, mu=0.999):
6 | self.mu = mu
7 | self.shadow = {}
8 |
9 | def register(self, module):
10 | if isinstance(module, nn.DataParallel):
11 | module = module.module
12 | for name, param in module.named_parameters():
13 | if param.requires_grad:
14 | self.shadow[name] = param.data.clone()
15 |
16 | def update(self, module):
17 | if isinstance(module, nn.DataParallel):
18 | module = module.module
19 | for name, param in module.named_parameters():
20 | if param.requires_grad:
21 | self.shadow[name].data = (
22 | 1. - self.mu) * param.data + self.mu * self.shadow[name].data
23 |
24 | def ema(self, module):
25 | if isinstance(module, nn.DataParallel):
26 | module = module.module
27 | for name, param in module.named_parameters():
28 | if param.requires_grad:
29 | param.data.copy_(self.shadow[name].data)
30 |
31 | def ema_copy(self, module):
32 | if isinstance(module, nn.DataParallel):
33 | inner_module = module.module
34 | module_copy = type(inner_module)(
35 | inner_module.config).to(inner_module.config.device)
36 | module_copy.load_state_dict(inner_module.state_dict())
37 | module_copy = nn.DataParallel(module_copy)
38 | else:
39 | module_copy = type(module)(module.config).to(module.config.device)
40 | module_copy.load_state_dict(module.state_dict())
41 | # module_copy = copy.deepcopy(module)
42 | self.ema(module_copy)
43 | return module_copy
44 |
45 | def state_dict(self):
46 | return self.shadow
47 |
48 | def load_state_dict(self, state_dict):
49 | self.shadow = state_dict
50 |
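A minimal usage sketch of EMAHelper with a toy nn.Linear standing in for the diffusion model; ema_copy is not exercised here, since it additionally assumes the wrapped model's constructor takes a config with a .device attribute.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)              # stand-in for the diffusion model
ema = EMAHelper(mu=0.999)
ema.register(model)                  # snapshot current parameters into the shadow dict

opt = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(10):                  # toy training loop
    loss = model(torch.randn(2, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model)                # shadow <- mu * shadow + (1 - mu) * param

ema.ema(model)                       # copy the EMA weights back into the live model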
--------------------------------------------------------------------------------
/ddim/runners/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/ddim/runners/__init__.py
--------------------------------------------------------------------------------
/img/ldm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/img/ldm.png
--------------------------------------------------------------------------------
/img/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/img/overview.png
--------------------------------------------------------------------------------
/img/sd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/img/sd.png
--------------------------------------------------------------------------------
/linklink/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 | allreduce = dist.all_reduce
7 | allgather = dist.all_gather
8 | broadcast = dist.broadcast
9 | barrier = dist.barrier
10 | synchronize = torch.cuda.synchronize
11 | init_process_group = dist.init_process_group
12 | get_rank = dist.get_rank
13 | get_world_size = dist.get_world_size
14 |
15 |
16 | def get_local_rank():
17 | rank = dist.get_rank()
18 | return rank % torch.cuda.device_count()
19 |
20 |
21 | def initialize(backend='nccl', port='2333', job_envrion='normal'):
22 |     """
23 |     Initialize the distributed environment.
24 |     :param backend: nccl backend supports GPU DDP; this should not be modified.
25 |     :param port: port used for communication.
26 |     :param job_envrion: 'normal' refers to the PyTorch-suggested initialization; 'slurm' is
27 |     used for the SLURM job submission system.
28 |     """
29 |
30 |     if job_envrion == 'normal':
31 | # this step is taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L129
32 | dist.init_process_group(backend=backend, init_method='tcp://224.66.41.62:23456')
33 | elif job_envrion == 'slurm':
34 | proc_id = int(os.environ['SLURM_PROCID'])
35 | ntasks = int(os.environ['SLURM_NTASKS'])
36 | node_list = os.environ['SLURM_NODELIST']
37 | if '[' in node_list:
38 | beg = node_list.find('[')
39 | pos1 = node_list.find('-', beg)
40 | if pos1 < 0:
41 | pos1 = 1000
42 | pos2 = node_list.find(',', beg)
43 | if pos2 < 0:
44 | pos2 = 1000
45 | node_list = node_list[:min(pos1, pos2)].replace('[', '')
46 | addr = node_list[8:].replace('-', '.')
47 | os.environ['MASTER_PORT'] = port
48 | os.environ['MASTER_ADDR'] = addr
49 | os.environ['WORLD_SIZE'] = str(ntasks)
50 | os.environ['RANK'] = str(proc_id)
51 | if backend == 'nccl':
52 | dist.init_process_group(backend='nccl')
53 | else:
54 | dist.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks)
55 | rank = dist.get_rank()
56 | device = rank % torch.cuda.device_count()
57 | torch.cuda.set_device(device)
58 | else:
59 | raise NotImplementedError
60 |
61 |
62 | def finalize():
63 | pass
64 |
65 |
66 | class nn(object):
67 | SyncBatchNorm2d = torch.nn.BatchNorm2d
68 |     print("You are using a fake SyncBatchNorm2d, which is actually the official BatchNorm2d")
69 |
70 |
71 | class syncbnVarMode_t(object):
72 | L1 = None
73 | L2 = None
74 |
75 |
--------------------------------------------------------------------------------
/linklink/dist_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import numpy as np
5 | import linklink as link
6 |
7 |
8 | def save_file(dict, name):
9 | if link.get_local_rank() == 0:
10 | torch.save(dict, name)
11 |
12 |
13 | def dist_finalize():
14 | link.finalize()
15 |
16 |
17 | class AllReduce(torch.autograd.Function):
18 | @staticmethod
19 | def forward(ctx, input):
20 | output = torch.zeros_like(input)
21 | output.copy_(input)
22 | link.allreduce(output)
23 | return output
24 |
25 | @staticmethod
26 | def backward(ctx, grad_output):
27 | in_grad = torch.zeros_like(grad_output)
28 | in_grad.copy_(grad_output)
29 | link.allreduce(in_grad)
30 | return in_grad
31 |
32 |
33 | def allaverage(tensor):
34 | tensor.data /= link.get_world_size()
35 | link.allreduce(tensor.data)
36 | return tensor
37 |
38 |
39 | def allaverage_autograd(tensor):
40 | if tensor.is_cuda is True:
41 | tensor /= link.get_world_size()
42 | tensor = AllReduce().apply(tensor)
43 | return tensor
44 |
45 |
46 | def allreduce(tensor):
47 | link.allreduce(tensor.data)
48 |
49 |
50 | def link_dist(func):
51 |
52 | def wrapper(*args, **kwargs):
53 | dist_init()
54 | func(*args, **kwargs)
55 | dist_finalize()
56 |
57 | return wrapper
58 |
59 |
60 | def dist_init(method='slurm', device_id=0):
61 | if method == 'slurm':
62 | proc_id = int(os.environ['SLURM_PROCID'])
63 | # ntasks = int(os.environ['SLURM_NTASKS'])
64 | # node_list = os.environ['SLURM_NODELIST']
65 | num_gpus = torch.cuda.device_count()
66 | torch.cuda.set_device(proc_id % num_gpus)
67 | elif method == 'normal':
68 | torch.cuda.set_device(device_id)
69 | link.initialize(backend='nccl', job_envrion=method)
70 | world_size = link.get_world_size()
71 | rank = link.get_rank()
72 |
73 | return rank, world_size
74 |
75 |
76 | def dist_finalize():
77 | link.finalize()
78 |
79 |
80 | def simple_group_split(world_size, rank, num_groups):
81 | groups = []
82 | rank_list = np.split(np.arange(world_size), num_groups)
83 | rank_list = [list(map(int, x)) for x in rank_list]
84 | for i in range(num_groups):
85 | groups.append(link.new_group(rank_list[i]))
86 | group_size = world_size // num_groups
87 | return groups[rank//group_size]
88 |
89 |
90 | class DistModule(torch.nn.Module):
91 | def __init__(self, module, sync=False):
92 | super(DistModule, self).__init__()
93 | self.module = module
94 | self.broadcast_params()
95 |
96 | self.sync = sync
97 | if not sync:
98 | self._grad_accs = []
99 | self._register_hooks()
100 |
101 | def forward(self, *inputs, **kwargs):
102 | return self.module(*inputs, **kwargs)
103 |
104 | def _register_hooks(self):
105 | for i, (name, p) in enumerate(self.named_parameters()):
106 | if p.requires_grad:
107 | p_tmp = p.expand_as(p)
108 | grad_acc = p_tmp.grad_fn.next_functions[0][0]
109 | grad_acc.register_hook(self._make_hook(name, p, i))
110 | self._grad_accs.append(grad_acc)
111 |
112 | def _make_hook(self, name, p, i):
113 | def hook(*ignore):
114 | link.allreduce_async(name, p.grad.data)
115 | return hook
116 |
117 | def sync_gradients(self):
118 | """ average gradients """
119 | if self.sync and link.get_world_size() > 1:
120 | for name, param in self.module.named_parameters():
121 | if param.requires_grad and param.grad is not None:
122 | link.allreduce(param.grad.data)
123 | else:
124 | link.synchronize()
125 |
126 | def broadcast_params(self):
127 | """ broadcast model parameters """
128 | for name, param in self.module.state_dict().items():
129 | link.broadcast(param, 0)
130 |
131 |
132 | def _serialize_to_tensor(data, group=None):
133 | # backend = link.get_backend(group)
134 | # assert backend in ["gloo", "nccl"]
135 | # device = torch.device("cpu" if backend == "gloo" else "cuda")
136 | device = torch.cuda.current_device()
137 |
138 | buffer = pickle.dumps(data)
139 | if len(buffer) > 1024 ** 3:
140 | import logging
141 | logger = logging.getLogger('global')
142 | logger.warning(
143 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
144 | link.get_rank(), len(buffer) / (1024 ** 3), device
145 | )
146 | )
147 | storage = torch.ByteStorage.from_buffer(buffer)
148 | tensor = torch.ByteTensor(storage).to(device=device)
149 | return tensor
150 |
151 |
152 | def broadcast_object(obj, group=None):
153 |     """Make sure obj is picklable.
154 |     """
155 | if link.get_world_size() == 1:
156 | return obj
157 |
158 | serialized_tensor = _serialize_to_tensor(obj).cuda()
159 | numel = torch.IntTensor([serialized_tensor.numel()]).cuda()
160 | link.broadcast(numel, 0)
161 | # serialized_tensor from storage is not resizable
162 | serialized_tensor = serialized_tensor.clone()
163 | serialized_tensor.resize_(numel)
164 | link.broadcast(serialized_tensor, 0)
165 | serialized_bytes = serialized_tensor.cpu().numpy().tobytes()
166 | deserialized_obj = pickle.loads(serialized_bytes)
167 | return deserialized_obj
168 |
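A minimal sketch of the link_dist decorator pattern above; it assumes the script is launched under SLURM (e.g. via srun) so that SLURM_PROCID and related variables are set, and that CUDA is available for the NCCL backend.

import torch
import linklink as link
from linklink.dist_helper import link_dist, allaverage

@link_dist                # wraps main() between dist_init() and dist_finalize()
def main():
    t = torch.ones(1).cuda() * link.get_rank()
    allaverage(t)         # divide by world size, then all-reduce: mean over ranks
    if link.get_rank() == 0:
        print("mean rank id:", t.item())

if __name__ == "__main__":
    main()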
--------------------------------------------------------------------------------
/linklink/log_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import linklink as link
5 |
6 |
7 | _logger = None
8 | _logger_fh = None
9 | _logger_names = []
10 |
11 |
12 | def create_logger(log_file, level=logging.INFO):
13 | global _logger, _logger_fh
14 | if _logger is None:
15 | _logger = logging.getLogger()
16 | formatter = logging.Formatter(
17 | '[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s')
18 | fh = logging.FileHandler(log_file)
19 | fh.setFormatter(formatter)
20 | sh = logging.StreamHandler()
21 | sh.setFormatter(formatter)
22 | _logger.setLevel(level)
23 | _logger.addHandler(fh)
24 | _logger.addHandler(sh)
25 | _logger_fh = fh
26 | else:
27 | _logger.removeHandler(_logger_fh)
28 | _logger.setLevel(level)
29 |
30 | return _logger
31 |
32 |
33 | def get_logger(name, level=logging.INFO):
34 | global _logger_names
35 | logger = logging.getLogger(name)
36 | if name in _logger_names:
37 | return logger
38 |
39 | _logger_names.append(name)
40 | if link.get_rank() > 0:
41 | logger.addFilter(RankFilter())
42 |
43 | return logger
44 |
45 |
46 | class RankFilter(logging.Filter):
47 | def filter(self, record):
48 | return False
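A minimal sketch of how these helpers are meant to be combined; it assumes the distributed backend has already been initialized (e.g. via linklink.initialize) so that link.get_rank() is valid.

import logging
from linklink.log_helper import create_logger, get_logger

create_logger("run.log", level=logging.INFO)   # root logger: file + console on every rank
logger = get_logger("calibration")             # RankFilter drops records on ranks > 0
logger.info("only rank 0 emits this line")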
--------------------------------------------------------------------------------
/quant/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/quant/__init__.py
--------------------------------------------------------------------------------
/quant/adaptive_rounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from quant.quant_layer import UniformAffineQuantizer, ste_round
4 | from enum import Enum
5 |
6 | RMODE = Enum('RMODE', ('LEARNED_ROUND_SIGMOID',
7 | 'NEAREST',
8 | 'NEAREST_STE',
9 | 'STOCHASTIC',
10 | 'LEARNED_HARD_SIGMOID'))
11 |
12 | class AdaRoundQuantizer(nn.Module):
13 |
14 | def __init__(self,
15 | uaqtizer: UniformAffineQuantizer,
16 | w: torch.Tensor,
17 | rmode: RMODE = RMODE.LEARNED_ROUND_SIGMOID,
18 | ) -> None:
19 |
20 | super().__init__()
21 | self.level = uaqtizer.level
22 | self.symmetric = uaqtizer.symmetric
23 | self.delta = uaqtizer.delta
24 | self.zero_point = uaqtizer.zero_point
25 | self.rmode = rmode
26 | self.soft_tgt = False
27 | self.gamma, self.zeta = -0.1, 1.1
28 | self.alpha = None
29 | self.init_alpha(x=w.clone())
30 |
31 | def init_alpha(self, x: torch.Tensor) -> None:
32 | self.delta = self.delta.to(x.device)
33 | if self.rmode == RMODE.LEARNED_HARD_SIGMOID:
34 | rest = (x / self.delta) - torch.floor(x / self.delta)
35 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)
36 | self.alpha = nn.Parameter(alpha)
37 | else:
38 | raise NotImplementedError
39 |
40 | def get_soft_tgt(self) -> torch.Tensor:
41 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1)
42 |
43 | def forward(self,
44 | x: torch.Tensor
45 | ) -> torch.Tensor:
46 | if isinstance(self.delta, torch.Tensor):
47 | self.delta = self.delta.to(x.device)
48 | if isinstance(self.zero_point, torch.Tensor):
49 | self.zero_point = self.zero_point.to(x.device)
50 |
51 | x_floor = torch.floor(x / self.delta)
52 | if self.rmode == RMODE.NEAREST:
53 | x_int = torch.round(x / self.delta)
54 | elif self.rmode == RMODE.NEAREST_STE:
55 | x_int = ste_round(x / self.delta)
56 | elif self.rmode == RMODE.STOCHASTIC:
57 | x_int = x_floor + torch.bernoulli((x / self.delta) - x_floor)
58 | elif self.rmode == RMODE.LEARNED_HARD_SIGMOID:
59 | if self.soft_tgt:
60 | x_int = x_floor + self.get_soft_tgt().to(x.device)
61 | else:
62 | self.alpha = self.alpha.to(x.device)
63 | x_int = x_floor + (self.alpha >= 0).float()
64 | else:
65 | raise NotImplementedError
66 |
67 | NB, PB = -self.level // 2 if self.symmetric else 0, self.level // 2 - 1 if self.symmetric else self.level - 1
68 | x_q = torch.clamp(x_int + self.zero_point, NB, PB)
69 | x_dq = self.delta * (x_q - self.zero_point)
70 | return x_dq
71 |
72 | def extra_repr(self) -> str:
73 |         s = 'level={}, symmetric={}, rmode={}'.format(self.level, self.symmetric, self.rmode)
74 |         return s
75 |
76 |
77 |
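A small numeric sketch of the rectified-sigmoid relaxation used above: init_alpha inverts h(alpha) = clamp(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1), so the soft target initially equals each weight's fractional part, and the hard rounding decision in forward() is simply (alpha >= 0).

import torch

gamma, zeta = -0.1, 1.1
w, delta = torch.tensor([0.37]), torch.tensor([0.1])

rest = (w / delta) - torch.floor(w / delta)              # fractional part, ~0.7
alpha = -torch.log((zeta - gamma) / (rest - gamma) - 1)  # inverse of the rectified sigmoid
soft = torch.clamp(torch.sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1)

print(soft)                  # ~0.7: matches `rest` at initialization
print((alpha >= 0).float())  # 1.0: hard mode rounds up, consistent with rest > 0.5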
--------------------------------------------------------------------------------
/quant/data_generate.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | from typing import List, Union
3 | from ddim.models.diffusion import Model
4 | from ldm.models.diffusion.ddim import DDIMSampler
5 | from ldm.models.diffusion.ddpm import LatentDiffusion
6 | from ldm.models.diffusion.dpm_solver.sampler import DPMSolverSampler
7 | from ldm.models.diffusion.plms import PLMSSampler
8 | from typing import Tuple
9 | from torch import autocast
10 | import torch
11 |
12 |
13 | def generate_cali_text_guided_data(model: LatentDiffusion,
14 | sampler: Union[DPMSolverSampler, PLMSSampler, DDIMSampler],
15 | T: int,
16 | c: int,
17 | batch_size: int,
18 | prompts: Tuple[str],
19 | shape: List[int],
20 | precision_scope: Union[autocast, nullcontext],
21 | ) -> Tuple[torch.Tensor]:
22 | tmp = list()
23 | model.eval()
24 | with torch.no_grad():
25 | with precision_scope("cuda"):
26 | for t in range(1, T + 1):
27 | # x_{t + 1} = f(x_t, t_t, c_t)
28 | if t % c == 0:
29 | for p in prompts:
30 | uc_t = model.get_learned_conditioning(batch_size * [""])
31 | c_t = model.get_learned_conditioning(batch_size * [p])
32 | x_t, t_t = sampler.sample(S=T,
33 | conditioning=c_t,
34 | batch_size=batch_size,
35 | shape=shape,
36 | verbose=False,
37 | unconditional_guidance_scale=7.5,
38 | unconditional_conditioning=uc_t,
39 | untill_fake_t=t)
40 | if isinstance(sampler, (PLMSSampler, DDIMSampler)):
41 | ddpm_time_num = 1000 # in yaml
42 | real_time = (T - t) * ddpm_time_num // T + 1
43 | t_t = torch.full((batch_size,), real_time, device=sampler.model.betas.device, dtype=torch.long)
44 | tmp += [[x_t, t_t, c_t], [x_t, t_t, uc_t]]
45 |
46 | cali_data = ()
47 | for i in range(len(tmp[0])):
48 | cali_data += (torch.cat([x[i] for x in tmp]), )
49 | return cali_data
50 |
51 |
52 | def generate_cali_data_ddim(runnr,
53 | model: Model,
54 | T: int,
55 | c: int,
56 | batch_size: int,
57 | shape: List[int],
58 | ) -> Tuple[torch.Tensor]:
59 | tmp = list()
60 | for i in range(1, T + 1):
61 | # x_{t + 1} = f(x_t, t_t, c_t)
62 | if i % c == 0:
63 | from ddim.runners.diffusion import Diffusion
64 | runnr: Diffusion
65 | N, C, H, W = batch_size, *shape
66 | x = torch.randn((N, C, H, W), device=runnr.device)
67 | x_t, t_t = runnr.sample_image(x, model, untill_fake_t=i)[1:]
68 | tmp += [[x_t, t_t]]
69 | cali_data = ()
70 | for i in range(len(tmp[0])):
71 | cali_data += (torch.cat([x[i] for x in tmp]), )
72 | return cali_data
73 |
74 |
75 | def generate_cali_data_ldm(model: LatentDiffusion,
76 | T: int,
77 | c: int,
78 | batch_size: int,
79 | shape: List[int],
80 | vanilla: bool = False,
81 | dpm: bool = False,
82 | plms: bool = False,
83 | eta: float = 0.0,
84 | ) -> Tuple[torch.Tensor]:
85 | if vanilla:
86 | pass
87 | elif dpm:
88 | sampler = DPMSolverSampler(model)
89 | elif plms:
90 | sampler = PLMSSampler(model)
91 | else:
92 | sampler = DDIMSampler(model)
93 | tmp = list()
94 | for t in range(1, T + 1):
95 | if t % c == 0:
96 | if not vanilla:
97 | x_t, t_t = sampler.sample(S=T,
98 | batch_size=batch_size,
99 | shape=shape,
100 | verbose=False,
101 | eta=eta,
102 | untill_fake_t=t)
103 | if isinstance(sampler, (PLMSSampler, DDIMSampler)):
104 | ddpm_time_num = 1000 # in yaml
105 | real_time = (T - t) * ddpm_time_num // T + 1
106 | t_t = torch.full((batch_size,), real_time, device=sampler.model.betas.device, dtype=torch.long)
107 | tmp += [[x_t, t_t]]
108 | else:
109 | raise NotImplementedError("Vanilla LDM is not implemented yet, because it needs 1000 steps to generate one sample.")
110 | cali_data = ()
111 | for i in range(len(tmp[0])):
112 | cali_data += (torch.cat([x[i] for x in tmp]), )
113 | return cali_data
114 |
115 |
116 | def generate_cali_data_ldm_imagenet(model: LatentDiffusion,
117 | T: int,
118 | c: int,
119 | batch_size: int, # 8
120 | shape: List[int],
121 | eta: float = 0.0,
122 | scale: float = 3.0
123 | ) -> Tuple[torch.Tensor]:
124 | sampler = DDIMSampler(model)
125 | tmp = list()
126 | classes = [i for i in range(0, 1000, 1000 // 31)]
127 | with torch.no_grad():
128 | with model.ema_scope():
129 | for i in range(1, T + 1):
130 | if i % c == 0:
131 | uc_t = model.get_learned_conditioning(
132 | {model.cond_stage_key: torch.tensor(batch_size * [1000]).to(model.device)}
133 | )
134 | for class_label in classes:
135 | xc = torch.tensor(batch_size * [class_label])
136 | c_t = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
137 | x_t, t_t = sampler.sample(S=T,
138 | batch_size=batch_size,
139 | shape=shape,
140 | verbose=False,
141 | eta=eta,
142 | unconditional_guidance_scale=scale,
143 | unconditional_conditioning=uc_t,
144 | conditioning=c_t,
145 | untill_fake_t=i)
146 | if isinstance(sampler, DDIMSampler):
147 | ddpm_time_num = 1000
148 | real_time = (T - i) * ddpm_time_num // T + 1
149 | t_t = torch.full((batch_size,), real_time, device=sampler.model.betas.device, dtype=torch.long)
150 | tmp += [[x_t, t_t, c_t], [x_t, t_t, uc_t]]
151 | cali_data = ()
152 | for i in range(len(tmp[0])):
153 | cali_data += (torch.cat([x[i] for x in tmp]), )
154 | return cali_data
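A minimal sketch of the calibration-data bookkeeping shared by the generators above: each sampled step appends a record such as [x_t, t_t], and the records are then transposed into a tuple of concatenated tensors. The tensors below are random stand-ins; no model or sampler is involved.

import torch

# two fake (x_t, t_t) records, as the loops above would collect them
tmp = [
    [torch.randn(4, 3, 32, 32), torch.full((4,), 981, dtype=torch.long)],
    [torch.randn(4, 3, 32, 32), torch.full((4,), 731, dtype=torch.long)],
]

cali_data = tuple(torch.cat([rec[i] for rec in tmp]) for i in range(len(tmp[0])))
print(cali_data[0].shape, cali_data[1].shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])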
--------------------------------------------------------------------------------
/quant/quant_model.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch.nn as nn
3 | import torch
4 | from ldm.modules.attention import BasicTransformerBlock
5 | from quant.quant_block import QuantAttentionBlock, QuantAttnBlock, QuantQKMatMul, QuantResnetBlock, QuantSMVMatMul, QuantTemporalInformationBlock, QuantTemporalInformationBlockDDIM, b2qb, BaseQuantBlock
6 | from quant.quant_block import QuantBasicTransformerBlock, QuantResBlock
7 | from quant.quant_layer import QMODE, QuantLayer, StraightThrough
8 |
9 |
10 | class QuantModel(nn.Module):
11 |
12 | def __init__(self,
13 | model: nn.Module,
14 | wq_params: dict = {},
15 | aq_params: dict = {},
16 | cali: bool = True,
17 | **kwargs
18 | ) -> None:
19 | super().__init__()
20 | self.model = model
21 | self.softmax_a_bit = kwargs.get("softmax_a_bit", 8)
22 | self.in_channels = model.in_channels
23 | if hasattr(model, 'image_size'):
24 | self.image_size = model.image_size
25 | self.B = b2qb(aq_params['leaf_param'])
26 | self.quant_module(self.model, wq_params, aq_params, aq_mode=kwargs.get("aq_mode", [QMODE.NORMAL.value]), prev_name=None)
27 | self.quant_block(self.model, wq_params, aq_params)
28 | if cali:
29 | self.get_tib(self.model, wq_params, aq_params)
30 |
31 | def get_tib(self,
32 | module: nn.Module,
33 | wq_params: dict = {},
34 | aq_params: dict = {},
35 | ) -> QuantTemporalInformationBlock:
36 | for name, child in module.named_children():
37 | if name == 'temb':
38 | self.tib = QuantTemporalInformationBlockDDIM(child, aq_params, self.model.ch)
39 | elif name == 'time_embed':
40 | self.tib = QuantTemporalInformationBlock(child, aq_params, self.model.model_channels, None)
41 | elif isinstance(child, QuantResBlock):
42 | self.tib.add_emb_layer(child.emb_layers)
43 | elif isinstance(child, QuantResnetBlock):
44 | self.tib.add_temb_proj(child.temb_proj)
45 | else:
46 | self.get_tib(child, wq_params, aq_params)
47 |
48 |
49 | def quant_module(self,
50 | module: nn.Module,
51 | wq_params: dict = {},
52 | aq_params: dict = {},
53 | aq_mode: List[int] = [QMODE.NORMAL.value],
54 | prev_name: str = None,
55 | ) -> None:
56 | for name, child in module.named_children():
57 | if isinstance(child, tuple(QuantLayer.QMAP.keys())) and \
58 | 'skip' not in name and 'op' not in name and not (prev_name == 'downsample' and name == 'conv') and 'shortcut' not in name: # refer to PTQD
59 | if prev_name is not None and 'emb_layers' in prev_name and '1' in name or 'temb_proj' in name:
60 | setattr(module, name, QuantLayer(child, wq_params, aq_params, aq_mode=aq_mode, quant_emb=True))
61 | continue
62 | setattr(module, name, QuantLayer(child, wq_params, aq_params, aq_mode=aq_mode))
63 | elif isinstance(child, StraightThrough):
64 | continue
65 | else:
66 | self.quant_module(child, wq_params, aq_params, aq_mode=aq_mode, prev_name=name)
67 |
68 | def quant_block(self,
69 | module: nn.Module,
70 | wq_params: dict = {},
71 | aq_params: dict = {},
72 | ) -> None:
73 | for name, child in module.named_children():
74 | if child.__class__.__name__ in self.B:
75 | if self.B[child.__class__.__name__] in [QuantBasicTransformerBlock, QuantAttnBlock]:
76 | setattr(module, name, self.B[child.__class__.__name__](child, aq_params, softmax_a_bit = self.softmax_a_bit))
77 | elif self.B[child.__class__.__name__] in [QuantResnetBlock, QuantAttentionBlock, QuantResBlock]:
78 | setattr(module, name, self.B[child.__class__.__name__](child, aq_params))
79 | elif self.B[child.__class__.__name__] in [QuantSMVMatMul]:
80 | setattr(module, name, self.B[child.__class__.__name__](aq_params, softmax_a_bit = self.softmax_a_bit))
81 | elif self.B[child.__class__.__name__] in [QuantQKMatMul]:
82 | setattr(module, name, self.B[child.__class__.__name__](aq_params))
83 | else:
84 | self.quant_block(child, wq_params, aq_params)
85 |
86 | def set_quant_state(self,
87 | use_wq: bool = False,
88 | use_aq: bool = False
89 | ) -> None:
90 | for m in self.model.modules():
91 | if isinstance(m, (BaseQuantBlock, QuantLayer)):
92 | m.set_quant_state(use_wq=use_wq, use_aq=use_aq)
93 |
94 | def forward(self,
95 | x: torch.Tensor,
96 | timestep: int = None,
97 | context: torch.Tensor = None,
98 | ) -> torch.Tensor:
99 | if context is None:
100 | return self.model(x, timestep)
101 | return self.model(x, timestep, context)
102 |
103 | def disable_out_quantization(self) -> None:
104 | modules = []
105 | for m in self.model.modules():
106 | if isinstance(m, QuantLayer):
107 | modules.append(m)
108 | modules: List[QuantLayer]
109 | # disable the last layer and the first layer
110 | modules[-1].use_wq = False
111 | modules[-1].disable_aq = True
112 | modules[0].disable_aq = True
113 | modules[0].use_wq = False
114 | modules[1].disable_aq = True
115 | modules[2].disable_aq = True
116 | modules[2].use_wq = False
117 | modules[3].disable_aq = True
118 | modules[0].ignore_recon = True
119 | modules[2].ignore_recon = True
120 | modules[-1].ignore_recon = True
121 |
122 | def set_grad_ckpt(self, grad_ckpt: bool) -> None:
123 | for _, module in self.model.named_modules():
124 | if isinstance(module, (QuantBasicTransformerBlock, BasicTransformerBlock)):
125 | module.checkpoint = grad_ckpt
126 |
127 | def synchorize_activation_statistics(self):
128 | import linklink.dist_helper as dist
129 | for module in self.modules():
130 | if isinstance(module, QuantLayer):
131 | if module.aqtizer.delta is not None:
132 | dist.allaverage(module.aqtizer.delta)
133 |
134 |
135 | def set_running_stat(self,
136 | running_stat: bool = False
137 | ) -> None:
138 | for m in self.model.modules():
139 | if isinstance(m, QuantBasicTransformerBlock):
140 | m.attn1.aqtizer_q.running_stat = running_stat
141 | m.attn1.aqtizer_k.running_stat = running_stat
142 | m.attn1.aqtizer_v.running_stat = running_stat
143 | m.attn1.aqtizer_w.running_stat = running_stat
144 | m.attn2.aqtizer_q.running_stat = running_stat
145 | m.attn2.aqtizer_k.running_stat = running_stat
146 | m.attn2.aqtizer_v.running_stat = running_stat
147 | m.attn2.aqtizer_w.running_stat = running_stat
148 | elif isinstance(m, QuantQKMatMul):
149 | m.aqtizer_q.running_stat = running_stat
150 | m.aqtizer_k.running_stat = running_stat
151 | elif isinstance(m, QuantSMVMatMul):
152 | m.aqtizer_v.running_stat = running_stat
153 | m.aqtizer_w.running_stat = running_stat
154 | elif isinstance(m, QuantAttnBlock):
155 | m.aqtizer_q.running_stat = running_stat
156 | m.aqtizer_k.running_stat = running_stat
157 | m.aqtizer_v.running_stat = running_stat
158 | m.aqtizer_w.running_stat = running_stat
159 | elif isinstance(m, QuantLayer):
160 | m.set_running_stat(running_stat)
161 |
162 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --index-url https://download.pytorch.org/whl/nightly/cu121
2 | --pre
3 | torch
4 | torchvision
5 | torchaudio
6 | transformers==4.31.0
7 | lmdb
--------------------------------------------------------------------------------
/sample_diffusion_ddim.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import logging
4 | import os
5 | from pytorch_lightning import seed_everything
6 |
7 | import yaml
8 |
9 | from ddim.runners.diffusion import Diffusion
10 | from quant.quant_layer import QMODE
11 |
12 |
13 | def get_parser():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--config", type=str, required=True, help="Path to the config file"
17 | )
18 | parser.add_argument("--seed", type=int, default=1234, help="Random seed")
19 | parser.add_argument(
20 | "-l",
21 | "--logdir",
22 | type=str,
23 | nargs="?",
24 | help="extra logdir",
25 | default="none"
26 | )
27 | parser.add_argument("--use_pretrained", action="store_true")
28 | parser.add_argument(
29 | "--sample_type",
30 | type=str,
31 | default="generalized",
32 | help="sampling approach (generalized or ddpm_noisy)",
33 | )
34 | parser.add_argument(
35 | "--skip_type",
36 | type=str,
37 | default="uniform",
38 | help="skip according to (uniform or quadratic)",
39 | )
40 | parser.add_argument(
41 | "--timesteps", type=int, default=1000, help="number of steps involved"
42 | )
43 | parser.add_argument(
44 | "--eta",
45 | type=float,
46 | default=0.0,
47 | help="eta used to control the variances of sigma",
48 | )
49 | parser.add_argument("--sequence", action="store_true")
50 |
51 | # quantization configs
52 | parser.add_argument(
53 | "--ptq", action="store_true", help="apply post-training quantization"
54 | )
55 | parser.add_argument(
56 | "--wq",
57 | type=int,
58 | default=8,
59 | help="int bit for weight quantization",
60 | )
61 | parser.add_argument(
62 | "--aq",
63 | type=int,
64 | default=8,
65 | help="int bit for activation quantization",
66 | )
67 | parser.add_argument(
68 | "--max_images", type=int, default=50000, help="number of images to sample"
69 | )
70 |
71 | # qdiff specific configs
72 | parser.add_argument(
73 | "--cali_ckpt", type=str,
74 | help="path for calibrated model ckpt"
75 | )
76 | parser.add_argument(
77 |         "--softmax_a_bit", type=int, default=8,
78 | help="attn softmax activation bit"
79 | )
80 | parser.add_argument(
81 | "--verbose", action="store_true",
82 | help="print out info like quantized model arch"
83 | )
84 |
85 | parser.add_argument(
86 | "--cali",
87 | action="store_true",
88 | help="whether to calibrate the model"
89 | )
90 | parser.add_argument(
91 | "--cali_save_path",
92 | type=str,
93 | default="cali_ckpt/quant_ddim.pth",
94 | help="path to save the calibrated ckpt"
95 | )
96 | parser.add_argument(
97 | "--interval_length",
98 | type=int,
99 | default=1,
100 | help="calibration interval length"
101 | )
102 | parser.add_argument(
103 | '--use_aq',
104 | action='store_true',
105 | help='whether to use activation quantization'
106 | )
107 | return parser
108 |
109 |
110 | def dict2namespace(config):
111 | namespace = argparse.Namespace()
112 | for key, value in config.items():
113 | if isinstance(value, dict):
114 | new_value = dict2namespace(value)
115 | else:
116 | new_value = value
117 | setattr(namespace, key, new_value)
118 | return namespace
119 |
120 |
121 | if __name__ == '__main__':
122 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
123 |
124 | parser = get_parser()
125 | args = parser.parse_args()
126 | with open(args.config, "r") as f:
127 | config = yaml.safe_load(f)
128 | config = dict2namespace(config)
129 |
130 | # fix random seed
131 | seed_everything(args.seed)
132 |
133 | # setup logger
134 | logdir = os.path.join(args.logdir, "samples", now)
135 | os.makedirs(logdir)
136 | log_path = os.path.join(logdir, "run.log")
137 | logging.basicConfig(
138 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
139 | datefmt='%m/%d/%Y %H:%M:%S',
140 | level=logging.INFO,
141 | handlers=[
142 | logging.FileHandler(log_path),
143 | logging.StreamHandler()
144 | ]
145 | )
146 | logger = logging.getLogger(__name__)
147 | logger.info(75 * "=")
148 | logger.info(f"Host {os.uname()[1]}")
149 | logger.info("logging to:")
150 | imglogdir = os.path.join(logdir, "img")
151 | nplogdir = os.path.join(logdir, "numpy")
152 | os.makedirs(nplogdir)
153 | args.image_folder = imglogdir
154 | args.numpy_folder = nplogdir
155 |
156 | os.makedirs(imglogdir)
157 | logger.info(logdir)
158 | logger.info(75 * "=")
159 | p = [QMODE.NORMAL.value]
160 | p.append(QMODE.QDIFF.value)
161 | args.q_mode = p
162 | args.fid = True
163 | args.log_path = "test/"
164 | args.use_pretrained = True
165 | args.use_aq = args.use_aq
166 | args.asym = True
167 | args.running_stat = True
168 |     config.device = 'cuda:0'
169 | runner = Diffusion(args, config)
170 | runner.sample()
171 |
172 |
--------------------------------------------------------------------------------
/stable-diffusion/assets/a-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-painting-of-a-fire.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/a-photograph-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-photograph-of-a-fire.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/birdhouse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/birdhouse.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/fire.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/inpainting.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/modelfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/modelfigure.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/rdm-preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/rdm-preview.jpg
--------------------------------------------------------------------------------
/stable-diffusion/assets/reconstruction1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/reconstruction1.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/reconstruction2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/reconstruction2.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/results.gif
--------------------------------------------------------------------------------
/stable-diffusion/assets/rick.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/rick.jpeg
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/img2img/mountains-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/mountains-1.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/img2img/mountains-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/mountains-2.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/img2img/mountains-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/mountains-3.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/txt2img/000002025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/000002025.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/txt2img/000002035.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/000002035.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/txt2img-convsample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/txt2img-convsample.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/txt2img-preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/txt2img-preview.png
--------------------------------------------------------------------------------
/stable-diffusion/assets/v1-variants-scores.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/v1-variants-scores.jpg
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 |         # note: this isn't actually the resolution but
24 |         # the downsampling factor, i.e. this corresponds to
25 |         # attention on spatial resolution 8,16,32, as the
26 |         # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 |         # note: this isn't actually the resolution but
26 |         # the downsampling factor, i.e. this corresponds to
27 |         # attention on spatial resolution 8,16,32, as the
28 |         # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolutions 8, 16, 32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolutions 8, 16, 32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
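
The '--scale_lr False' comment above refers to the CompVis-style training script, where the effective learning rate is typically base_learning_rate scaled by gradient-accumulation steps, GPU count, and batch size. A back-of-the-envelope sketch; the GPU count and accumulation values are assumptions for illustration, only base_lr and batch_size come from this file:

base_lr = 5.0e-5
batch_size = 96
ngpu = 8                      # assumed example value
accumulate_grad_batches = 1   # assumed example value
effective_lr = accumulate_grad_batches * ngpu * batch_size * base_lr
print(effective_lr)  # 0.0384 under these assumptions; --scale_lr False keeps 5.0e-5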
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/stable-diffusion/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
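
Each `target` in these configs is a dotted import path and `params` its constructor kwargs; the model classes call instantiate_from_config (defined in ldm/util.py, included later in this listing) on their nested sub-configs. A minimal sketch of building the v1 inference model from this file, assuming omegaconf is installed, the working directory is stable-diffusion/, and that a Stable Diffusion v1 checkpoint has been downloaded (the sd-v1-4.ckpt path below is an assumption):

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = instantiate_from_config(config.model)                    # LatentDiffusion
sd = torch.load("models/ldm/stable-diffusion-v1/sd-v1-4.ckpt",   # assumed checkpoint path
                map_location="cpu")["state_dict"]
model.load_state_dict(sd, strict=False)
model.eval()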
--------------------------------------------------------------------------------
/stable-diffusion/data/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/DejaVuSans.ttf
--------------------------------------------------------------------------------
/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg
--------------------------------------------------------------------------------
/stable-diffusion/data/example_conditioning/text_conditional/sample_0.txt:
--------------------------------------------------------------------------------
1 | A basket of cerries
2 |
--------------------------------------------------------------------------------
/stable-diffusion/data/imagenet_train_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/imagenet_train_hr_indices.p
--------------------------------------------------------------------------------
/stable-diffusion/data/imagenet_val_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/imagenet_val_hr_indices.p
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/bench2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bench2.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/bench2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bench2_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png
--------------------------------------------------------------------------------
/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png
--------------------------------------------------------------------------------
/stable-diffusion/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - diffusers
15 | - opencv-python==4.1.2.30
16 | - pudb==2019.2
17 | - invisible-watermark
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - transformers==4.19.2
27 | - torchmetrics==0.6.0
28 | - kornia==0.6
29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31 | - -e .
32 |
--------------------------------------------------------------------------------
/stable-diffusion/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/data/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
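
Txt2ImgIterableBaseDataset only fixes the constructor bookkeeping and leaves __iter__ abstract. A purely illustrative subclass (not part of the repo), yielding the image/caption dictionary format the txt2img configs expect:

import numpy as np
from ldm.data.base import Txt2ImgIterableBaseDataset

class RandomTxt2Img(Txt2ImgIterableBaseDataset):
    """Illustrative subclass: yields random images in [-1, 1] with a fixed caption."""
    def __iter__(self):
        for _ in range(self.num_records):
            image = np.random.uniform(-1, 1, (self.size, self.size, 3)).astype(np.float32)
            yield {"image": image, "caption": "a placeholder caption"}

ds = RandomTxt2Img(num_records=4, size=64)
print(len(ds), next(iter(ds))["image"].shape)  # 4 (64, 64, 3)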
--------------------------------------------------------------------------------
/stable-diffusion/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
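
A hedged usage sketch of the LSUN wrappers above: the file list and image folder referenced by the constructor (data/lsun/church_outdoor_train.txt and data/lsun/churches/) must already be prepared; nothing is downloaded here.

from ldm.data.lsun import LSUNChurchesTrain

ds = LSUNChurchesTrain(size=256)
example = ds[0]
print(example["image"].shape, example["image"].min(), example["image"].max())
# (256, 256, 3), values scaled to roughly [-1, 1] by the score-sde preprocessing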
--------------------------------------------------------------------------------
/stable-diffusion/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
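
LambdaLinearScheduler is what the "10000 warmup steps" scheduler_config blocks above instantiate; it returns a learning-rate multiplier intended for torch.optim.lr_scheduler.LambdaLR. A small sketch using the values from those configs:

from ldm.lr_scheduler import LambdaLinearScheduler

sched = LambdaLinearScheduler(warm_up_steps=[10000], f_min=[1.0], f_max=[1.0],
                              f_start=[1.0e-6], cycle_lengths=[10000000000000])
for step in (0, 5000, 10000, 1000000):
    print(step, sched(step))
# Linear ramp from ~1e-6 to 1.0 over the first 10k steps, then flat at 1.0 (f_max == f_min).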
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 |
5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6 |
7 |
8 | class DPMSolverSampler(object):
9 | def __init__(self, model, **kwargs):
10 | super().__init__()
11 | self.model = model
12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14 |
15 | def register_buffer(self, name, attr):
16 | if type(attr) == torch.Tensor:
17 | if attr.device != torch.device("cuda"):
18 | attr = attr.to(torch.device("cuda"))
19 | setattr(self, name, attr)
20 |
21 | @torch.no_grad()
22 | def sample(self,
23 | S,
24 | batch_size,
25 | shape,
26 | conditioning=None,
27 | callback=None,
28 | normals_sequence=None,
29 | img_callback=None,
30 | quantize_x0=False,
31 | eta=0.,
32 | mask=None,
33 | x0=None,
34 | temperature=1.,
35 | noise_dropout=0.,
36 | score_corrector=None,
37 | corrector_kwargs=None,
38 | verbose=True,
39 | x_T=None,
40 | log_every_t=100,
41 | unconditional_guidance_scale=1.,
42 | unconditional_conditioning=None,
43 | untill_fake_t: int = None,
44 | # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
45 | **kwargs
46 | ):
47 | if conditioning is not None:
48 | if isinstance(conditioning, dict):
49 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
50 | if cbs != batch_size:
51 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
52 | else:
53 | if conditioning.shape[0] != batch_size:
54 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
55 |
56 | if not untill_fake_t:
57 | untill_fake_t = float('inf')
58 | # sampling
59 | C, H, W = shape
60 | size = (batch_size, C, H, W)
61 |
62 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
63 |
64 | device = self.model.betas.device
65 | if x_T is None:
66 | img = torch.randn(size, device=device)
67 | else:
68 | img = x_T
69 |
70 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
71 |
72 | model_fn = model_wrapper(
73 | lambda x, t, c: self.model.apply_model(x, t, c),
74 | ns,
75 | model_type="noise",
76 | guidance_type="classifier-free",
77 | condition=conditioning,
78 | unconditional_condition=unconditional_conditioning,
79 | guidance_scale=unconditional_guidance_scale,
80 | )
81 |
82 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
83 | x, vec_t = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, untill_fake_t=untill_fake_t)
84 |
85 | return x.to(device), vec_t
86 |
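
A hedged usage sketch of DPMSolverSampler: `model`, `c` (conditioning) and `uc` (unconditional conditioning) are assumed to come from an already-loaded LatentDiffusion model, e.g. via model.get_learned_conditioning([...]); the shape matches the v1 configs (4 latent channels, 64x64 latent).

from ldm.models.diffusion.dpm_solver import DPMSolverSampler

sampler = DPMSolverSampler(model)
samples, vec_t = sampler.sample(S=20, batch_size=4, shape=(4, 64, 64),
                                conditioning=c,
                                unconditional_conditioning=uc,
                                unconditional_guidance_scale=7.5)
images = model.decode_first_stage(samples)  # decode latents back to image space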
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
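
A quick self-contained sketch of DiagonalGaussianDistribution: the AutoencoderKL encoder emits 2 * z_channels feature maps, which are split along dim 1 into mean and log-variance.

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

parameters = torch.randn(2, 8, 32, 32)       # 8 channels: 4 mean + 4 log-variance
posterior = DiagonalGaussianDistribution(parameters)
z = posterior.sample()                       # reparameterised sample
print(z.shape, posterior.kl().shape, posterior.mode().shape)
# torch.Size([2, 4, 32, 32]) torch.Size([2]) torch.Size([2, 4, 32, 32])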
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
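
A minimal sketch of the LitEma update/validate cycle on a toy module (in the repo this happens inside the DDPM training hooks, not manually as below):

import torch
from torch import nn
from ldm.modules.ema import LitEma

net = nn.Linear(4, 4)
ema = LitEma(net, decay=0.999)

for _ in range(10):                     # pretend training steps
    with torch.no_grad():
        net.weight.add_(0.01 * torch.randn_like(net.weight))
    ema(net)                            # update shadow weights towards net

ema.store(net.parameters())             # stash the raw weights
ema.copy_to(net)                        # evaluate with EMA weights ...
ema.restore(net.parameters())           # ... then restore the raw ones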
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
/stable-diffusion/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
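
Two quick sketches of the helpers above: instantiate_from_config is exactly what the lossconfig: {target: torch.nn.Identity} entries in the configs resolve to, and parallel_data_prefetch fans a function out over workers (thread-based here via cpu_intensive=False, so a lambda is fine).

from ldm.util import instantiate_from_config, parallel_data_prefetch

# The target/params convention used throughout the YAML configs above.
identity = instantiate_from_config({"target": "torch.nn.Identity"})
gelu = instantiate_from_config({"target": "torch.nn.GELU", "params": {}})
print(identity, gelu)

# Square a list of numbers across 2 workers; per-worker order is preserved.
squared = parallel_data_prefetch(lambda chunk: [x * x for x in chunk],
                                 list(range(10)), n_proc=2,
                                 target_data_type="list", cpu_intensive=False)
print(squared)  # [0, 1, 4, ..., 81]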
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
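The directory names of these first-stage models (kl-f4 through vq-f16) only restate what the YAML files already contain. A small illustrative helper, again assuming the repository root as working directory, that derives the downsampling factor, latent shape and, for the VQ variants, the codebook size from each config:

from pathlib import Path
from omegaconf import OmegaConf

for cfg_path in sorted(Path("models/first_stage_models").glob("*/config.yaml")):
    params = OmegaConf.load(cfg_path).model.params
    dd = params.ddconfig
    f = 2 ** (len(dd.ch_mult) - 1)        # one 2x downsampling stage per extra ch_mult entry
    side = dd.resolution // f             # latent spatial size at the training resolution
    kind = "VQ" if "n_embed" in params else "KL"
    extra = f" codebook={params.n_embed}" if kind == "VQ" else ""
    print(f"{cfg_path.parent.name}: {kind}, f={f}, latent {params.embed_dim}x{side}x{side}{extra}")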
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
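A hedged sketch of class-conditional sampling with the cin256 config, following the DDIMSampler call pattern used in scripts/inpaint.py further below. The checkpoint filename (models/ldm/cin256/model.ckpt, unpacked there by scripts/download_models.sh) and the ImageNet class index are assumptions made for illustration only.

import torch
from omegaconf import OmegaConf
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

config = OmegaConf.load("models/ldm/cin256/config.yaml")
model = instantiate_from_config(config.model)
state = torch.load("models/ldm/cin256/model.ckpt")["state_dict"]  # assumed checkpoint path
model.load_state_dict(state, strict=False)
model = model.cuda().eval()
sampler = DDIMSampler(model)

with torch.no_grad(), model.ema_scope():
    # The cond stage is a ClassEmbedder keyed on "class_label" (see cond_stage_config above).
    c = model.get_learned_conditioning({"class_label": torch.tensor([207], device="cuda")})
    samples, _ = sampler.sample(S=50, conditioning=c, batch_size=1,
                                shape=(4, 32, 32),  # channels x image_size x image_size
                                verbose=False)
    images = model.decode_first_stage(samples)      # back to 3 x 256 x 256 pixel space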
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/ffhq256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 42
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
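The UNet's in_channels of 7 in this config is the concat-mode bookkeeping that scripts/inpaint.py further below reproduces at sampling time; a quick check of the arithmetic (variable names are illustrative):

z_channels    = 3  # latent channels of the image being denoised (first-stage z_channels)
masked_latent = 3  # first-stage encoding of the masked image (cond stage is the first stage)
mask_channel  = 1  # mask downsampled to the latent resolution and concatenated
assert z_channels + masked_latent + mask_channel == 7  # unet_config in_channels above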
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/lsun_beds256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.lsun.LSUNBedroomsTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.lsun.LSUNBedroomsValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
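Note that this config sets scale_by_std: true, so the latent scale factor is estimated from the standard deviation of the first encoded batch rather than fixed by hand; with the kl-f8-style first stage above (z_channels: 4, four ch_mult entries), 256x256 images map to 4x32x32 latents, which is what image_size: 32 and channels: 4 refer to.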
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
--------------------------------------------------------------------------------
/stable-diffusion/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11 |
12 |
13 |
14 | cd models/first_stage_models/kl-f4
15 | unzip -o model.zip
16 |
17 | cd ../kl-f8
18 | unzip -o model.zip
19 |
20 | cd ../kl-f16
21 | unzip -o model.zip
22 |
23 | cd ../kl-f32
24 | unzip -o model.zip
25 |
26 | cd ../vq-f4
27 | unzip -o model.zip
28 |
29 | cd ../vq-f4-noattn
30 | unzip -o model.zip
31 |
32 | cd ../vq-f8
33 | unzip -o model.zip
34 |
35 | cd ../vq-f8-n256
36 | unzip -o model.zip
37 |
38 | cd ../vq-f16
39 | unzip -o model.zip
40 |
41 | cd ../..
--------------------------------------------------------------------------------
/stable-diffusion/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13 |
14 |
15 |
16 | cd models/ldm/celeba256
17 | unzip -o celeba-256.zip
18 |
19 | cd ../ffhq256
20 | unzip -o ffhq-256.zip
21 |
22 | cd ../lsun_churches256
23 | unzip -o lsun_churches-256.zip
24 |
25 | cd ../lsun_beds256
26 | unzip -o lsun_beds-256.zip
27 |
28 | cd ../text2img256
29 | unzip -o model.zip
30 |
31 | cd ../cin256
32 | unzip -o model.zip
33 |
34 | cd ../semantic_synthesis512
35 | unzip -o model.zip
36 |
37 | cd ../semantic_synthesis256
38 | unzip -o model.zip
39 |
40 | cd ../bsr_sr
41 | unzip -o model.zip
42 |
43 | cd ../layout2img-openimages256
44 | unzip -o model.zip
45 |
46 | cd ../inpainting_big
47 | unzip -o model.zip
48 |
49 | cd ../..
50 |
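Both download scripts use paths relative to the repository root (models/first_stage_models/... and models/ldm/...), so they are meant to be run from the root, e.g. bash scripts/download_models.sh; after unzipping each archive in place, the final cd ../.. returns to the root.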
--------------------------------------------------------------------------------
/stable-diffusion/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | opt = parser.parse_args()
54 |
55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
56 | images = [x.replace("_mask.png", ".png") for x in masks]
57 | print(f"Found {len(masks)} inputs.")
58 |
59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
60 | model = instantiate_from_config(config.model)
61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
62 | strict=False)
63 |
64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
65 | model = model.to(device)
66 | sampler = DDIMSampler(model)
67 |
68 | os.makedirs(opt.outdir, exist_ok=True)
69 | with torch.no_grad():
70 | with model.ema_scope():
71 | for image, mask in tqdm(zip(images, masks)):
72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
73 | batch = make_batch(image, mask, device=device)
74 |
75 | # encode masked image and concat downsampled mask
76 | c = model.cond_stage_model.encode(batch["masked_image"])
77 | cc = torch.nn.functional.interpolate(batch["mask"],
78 | size=c.shape[-2:])
79 | c = torch.cat((c, cc), dim=1)
80 |
81 | shape = (c.shape[1]-1,)+c.shape[2:]
82 | samples_ddim, _ = sampler.sample(S=opt.steps,
83 | conditioning=c,
84 | batch_size=c.shape[0],
85 | shape=shape,
86 | verbose=False)
87 | x_samples_ddim = model.decode_first_stage(samples_ddim)
88 |
89 | image = torch.clamp((batch["image"]+1.0)/2.0,
90 | min=0.0, max=1.0)
91 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
92 | min=0.0, max=1.0)
93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
94 | min=0.0, max=1.0)
95 |
96 | inpainted = (1-mask)*image+mask*predicted_image
97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
99 |
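A typical invocation, with illustrative paths: python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inpainting_results. The script expects --indir to contain example.png / example_mask.png pairs and the checkpoint to be available as models/ldm/inpainting_big/last.ckpt (see scripts/download_models.sh above).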
--------------------------------------------------------------------------------
/stable-diffusion/scripts/tests/test_watermark.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import fire
3 | from imwatermark import WatermarkDecoder
4 |
5 |
6 | def testit(img_path):
7 | bgr = cv2.imread(img_path)
8 | decoder = WatermarkDecoder('bytes', 136)
9 | watermark = decoder.decode(bgr, 'dwtDct')
10 | try:
11 | dec = watermark.decode('utf-8')
12 |     except Exception:
13 | dec = "null"
14 | print(dec)
15 |
16 |
17 | if __name__ == "__main__":
18 | fire.Fire(testit)
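The decoder above reads back a payload embedded with the companion WatermarkEncoder from the same invisible-watermark package; a minimal, hedged sketch of that embedding step (the payload string is an example whose 17 bytes match the 136-bit length expected by the decoder):

import cv2
from imwatermark import WatermarkEncoder


def embed(img_path, out_path, payload="StableDiffusionV1"):
    bgr = cv2.imread(img_path)
    encoder = WatermarkEncoder()
    encoder.set_watermark('bytes', payload.encode('utf-8'))  # 17 bytes = 136 bits
    cv2.imwrite(out_path, encoder.encode(bgr, 'dwtDct'))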
--------------------------------------------------------------------------------
/stable-diffusion/scripts/train_searcher.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import scann
4 | import argparse
5 | import glob
6 | from multiprocessing import cpu_count
7 | from tqdm import tqdm
8 |
9 | from ldm.util import parallel_data_prefetch
10 |
11 |
12 | def search_bruteforce(searcher):
13 | return searcher.score_brute_force().build()
14 |
15 |
16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
17 | partioning_trainsize, num_leaves, num_leaves_to_search):
18 | return searcher.tree(num_leaves=num_leaves,
19 | num_leaves_to_search=num_leaves_to_search,
20 | training_sample_size=partioning_trainsize). \
21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
22 |
23 |
24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
26 | reorder_k).build()
27 |
28 | def load_datapool(dpath):
29 |
30 |
31 | def load_single_file(saved_embeddings):
32 | compressed = np.load(saved_embeddings)
33 | database = {key: compressed[key] for key in compressed.files}
34 | return database
35 |
36 | def load_multi_files(data_archive):
37 | database = {key: [] for key in data_archive[0].files}
38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
39 | for key in d.files:
40 | database[key].append(d[key])
41 |
42 | return database
43 |
44 | print(f'Load saved patch embedding from "{dpath}"')
45 | file_content = glob.glob(os.path.join(dpath, '*.npz'))
46 |
47 | if len(file_content) == 1:
48 | data_pool = load_single_file(file_content[0])
49 | elif len(file_content) > 1:
50 | data = [np.load(f) for f in file_content]
51 | prefetched_data = parallel_data_prefetch(load_multi_files, data,
52 | n_proc=min(len(data), cpu_count()), target_data_type='dict')
53 |
54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
55 | else:
56 |         raise ValueError(f'No npz-files found in "{dpath}". Does this directory exist?')
57 |
58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
59 | return data_pool
60 |
61 |
62 | def train_searcher(opt,
63 | metric='dot_product',
64 | partioning_trainsize=None,
65 | reorder_k=None,
66 | # todo tune
67 | aiq_thld=0.2,
68 | dims_per_block=2,
69 | num_leaves=None,
70 | num_leaves_to_search=None,):
71 |
72 | data_pool = load_datapool(opt.database)
73 | k = opt.knn
74 |
75 | if not reorder_k:
76 | reorder_k = 2 * k
77 |
78 | # normalize
79 | # embeddings =
80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
81 | pool_size = data_pool['embedding'].shape[0]
82 |
83 | print(*(['#'] * 100))
84 | print('Initializing scaNN searcher with the following values:')
85 | print(f'k: {k}')
86 | print(f'metric: {metric}')
87 | print(f'reorder_k: {reorder_k}')
88 | print(f'anisotropic_quantization_threshold: {aiq_thld}')
89 | print(f'dims_per_block: {dims_per_block}')
90 | print(*(['#'] * 100))
91 | print('Start training searcher....')
92 | print(f'N samples in pool is {pool_size}')
93 |
94 | # this reflects the recommended design choices proposed at
95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
96 | if pool_size < 2e4:
97 | print('Using brute force search.')
98 | searcher = search_bruteforce(searcher)
99 | elif 2e4 <= pool_size and pool_size < 1e5:
100 | print('Using asymmetric hashing search and reordering.')
101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
102 | else:
103 |         print('Using partitioning, asymmetric hashing search and reordering.')
104 |
105 | if not partioning_trainsize:
106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10
107 | if not num_leaves:
108 | num_leaves = int(np.sqrt(pool_size))
109 |
110 | if not num_leaves_to_search:
111 | num_leaves_to_search = max(num_leaves // 20, 1)
112 |
113 | print('Partitioning params:')
114 | print(f'num_leaves: {num_leaves}')
115 | print(f'num_leaves_to_search: {num_leaves_to_search}')
116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
118 | partioning_trainsize, num_leaves, num_leaves_to_search)
119 |
120 | print('Finish training searcher')
121 | searcher_savedir = opt.target_path
122 | os.makedirs(searcher_savedir, exist_ok=True)
123 | searcher.serialize(searcher_savedir)
124 | print(f'Saved trained searcher under "{searcher_savedir}"')
125 |
126 | if __name__ == '__main__':
127 | sys.path.append(os.getcwd())
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--database',
130 | '-d',
131 | default='data/rdm/retrieval_databases/openimages',
132 | type=str,
133 | help='path to folder containing the clip feature of the database')
134 | parser.add_argument('--target_path',
135 | '-t',
136 | default='data/rdm/searchers/openimages',
137 | type=str,
138 | help='path to the target folder where the searcher shall be stored.')
139 | parser.add_argument('--knn',
140 | '-k',
141 | default=20,
142 | type=int,
143 | help='number of nearest neighbors, for which the searcher shall be optimized')
144 |
145 | opt, _ = parser.parse_known_args()
146 |
147 | train_searcher(opt,)
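Once serialized, the searcher can be reloaded and queried; a brief sketch under the assumption that query embeddings have the same dimensionality and L2 normalization as the database embeddings used above:

import numpy as np
import scann

searcher = scann.scann_ops_pybind.load_searcher('data/rdm/searchers/openimages')  # default target_path above

queries = np.random.randn(4, 512).astype(np.float32)  # stand-in; the dim must match data_pool['embedding']
queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)  # same normalization as at build time
neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)
print(neighbors.shape, distances.shape)  # (4, 20) each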
--------------------------------------------------------------------------------
/stable-diffusion/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------