├── .gitignore ├── LICENSE ├── README.md ├── assets ├── poster.pdf └── slides.pdf ├── ddim ├── README.md ├── configs │ ├── bedroom.yml │ ├── celeba.yml │ ├── church.yml │ └── cifar10.yml ├── datasets │ ├── __init__.py │ ├── celeba.py │ ├── ffhq.py │ ├── lsun.py │ ├── utils.py │ └── vision.py ├── functions │ ├── __init__.py │ ├── ckpt_util.py │ ├── denoising.py │ └── losses.py ├── main.py ├── models │ ├── diffusion.py │ └── ema.py └── runners │ ├── __init__.py │ └── diffusion.py ├── img ├── ldm.png ├── overview.png └── sd.png ├── latent_imagenet_diffusion.py ├── linklink ├── __init__.py ├── dist_helper.py └── log_helper.py ├── quant ├── __init__.py ├── adaptive_rounding.py ├── calibration.py ├── data_generate.py ├── data_utill.py ├── quant_block.py ├── quant_layer.py ├── quant_model.py ├── reconstruction.py └── reconstruction_util.py ├── requirements.txt ├── sample_diffusion_ddim.py ├── sample_diffusion_ldm.py ├── stable-diffusion ├── README.md ├── Stable_Diffusion_v1_Model_Card.md ├── assets │ ├── a-painting-of-a-fire.png │ ├── a-photograph-of-a-fire.png │ ├── a-shirt-with-a-fire-printed-on-it.png │ ├── a-shirt-with-the-inscription-'fire'.png │ ├── a-watercolor-painting-of-a-fire.png │ ├── birdhouse.png │ ├── fire.png │ ├── inpainting.png │ ├── modelfigure.png │ ├── rdm-preview.jpg │ ├── reconstruction1.png │ ├── reconstruction2.png │ ├── results.gif │ ├── rick.jpeg │ ├── stable-samples │ │ ├── img2img │ │ │ ├── mountains-1.png │ │ │ ├── mountains-2.png │ │ │ ├── mountains-3.png │ │ │ ├── sketch-mountains-input.jpg │ │ │ ├── upscaling-in.png │ │ │ └── upscaling-out.png │ │ └── txt2img │ │ │ ├── 000002025.png │ │ │ ├── 000002035.png │ │ │ ├── merged-0005.png │ │ │ ├── merged-0006.png │ │ │ └── merged-0007.png │ ├── the-earth-is-on-fire,-oil-on-canvas.png │ ├── txt2img-convsample.png │ ├── txt2img-preview.png │ └── v1-variants-scores.jpg ├── configs │ ├── autoencoder │ │ ├── autoencoder_kl_16x16x16.yaml │ │ ├── autoencoder_kl_32x32x4.yaml │ │ ├── autoencoder_kl_64x64x3.yaml │ │ └── autoencoder_kl_8x8x64.yaml │ ├── latent-diffusion │ │ ├── celebahq-ldm-vq-4.yaml │ │ ├── cin-ldm-vq-f8.yaml │ │ ├── cin256-v2.yaml │ │ ├── ffhq-ldm-vq-4.yaml │ │ ├── lsun_bedrooms-ldm-vq-4.yaml │ │ ├── lsun_churches-ldm-kl-8.yaml │ │ └── txt2img-1p4B-eval.yaml │ ├── retrieval-augmented-diffusion │ │ └── 768x768.yaml │ └── stable-diffusion │ │ └── v1-inference.yaml ├── data │ ├── DejaVuSans.ttf │ ├── example_conditioning │ │ ├── superresolution │ │ │ └── sample_0.jpg │ │ └── text_conditional │ │ │ └── sample_0.txt │ ├── imagenet_clsidx_to_label.txt │ ├── imagenet_train_hr_indices.p │ ├── imagenet_val_hr_indices.p │ ├── index_synset.yaml │ └── inpainting_examples │ │ ├── 6458524847_2f4c361183_k.png │ │ ├── 6458524847_2f4c361183_k_mask.png │ │ ├── 8399166846_f6fb4e4b8e_k.png │ │ ├── 8399166846_f6fb4e4b8e_k_mask.png │ │ ├── alex-iby-G_Pk4D9rMLs.png │ │ ├── alex-iby-G_Pk4D9rMLs_mask.png │ │ ├── bench2.png │ │ ├── bench2_mask.png │ │ ├── bertrand-gabioud-CpuFzIsHYJ0.png │ │ ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png │ │ ├── billow926-12-Wc-Zgx6Y.png │ │ ├── billow926-12-Wc-Zgx6Y_mask.png │ │ ├── overture-creations-5sI6fQgYIuo.png │ │ ├── overture-creations-5sI6fQgYIuo_mask.png │ │ ├── photo-1583445095369-9c651e7e5d34.png │ │ └── photo-1583445095369-9c651e7e5d34_mask.png ├── environment.yaml ├── ldm │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── imagenet.py │ │ └── lsun.py │ ├── lr_scheduler.py │ ├── models │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ 
│ ├── ddpm.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ │ └── plms.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py ├── main.py ├── models │ ├── first_stage_models │ │ ├── kl-f16 │ │ │ └── config.yaml │ │ ├── kl-f32 │ │ │ └── config.yaml │ │ ├── kl-f4 │ │ │ └── config.yaml │ │ ├── kl-f8 │ │ │ └── config.yaml │ │ ├── vq-f16 │ │ │ └── config.yaml │ │ ├── vq-f4-noattn │ │ │ └── config.yaml │ │ ├── vq-f4 │ │ │ └── config.yaml │ │ ├── vq-f8-n256 │ │ │ └── config.yaml │ │ └── vq-f8 │ │ │ └── config.yaml │ └── ldm │ │ ├── bsr_sr │ │ └── config.yaml │ │ ├── celeba256 │ │ └── config.yaml │ │ ├── cin256 │ │ └── config.yaml │ │ ├── ffhq256 │ │ └── config.yaml │ │ ├── inpainting_big │ │ └── config.yaml │ │ ├── layout2img-openimages256 │ │ └── config.yaml │ │ ├── lsun_beds256 │ │ └── config.yaml │ │ ├── lsun_churches256 │ │ └── config.yaml │ │ ├── semantic_synthesis256 │ │ └── config.yaml │ │ ├── semantic_synthesis512 │ │ └── config.yaml │ │ └── text2img256 │ │ └── config.yaml ├── notebook_helpers.py ├── scripts │ ├── download_first_stages.sh │ ├── download_models.sh │ ├── img2img.py │ ├── inpaint.py │ ├── knn2img.py │ ├── latent_imagenet_diffusion.ipynb │ ├── sample_diffusion.py │ ├── tests │ │ └── test_watermark.py │ ├── train_searcher.py │ └── txt2img.py └── setup.py └── txt2img.py /.gitignore: -------------------------------------------------------------------------------- 1 | trace 2 | __pycache__ 3 | *.ckpt 4 | *.pth 5 | *.pt 6 | *.log 7 | *.DS_Store 8 | .vscode 9 | *.idea 10 | -------------------------------------------------------------------------------- /assets/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/assets/poster.pdf -------------------------------------------------------------------------------- /assets/slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/assets/slides.pdf -------------------------------------------------------------------------------- /ddim/README.md: -------------------------------------------------------------------------------- 1 | # Denoising Diffusion Implicit Models (DDIM) 2 | 3 | [Jiaming Song](http://tsong.me), [Chenlin Meng](http://cs.stanford.edu/~chenlin) and [Stefano Ermon](http://cs.stanford.edu/~ermon), Stanford 4 | 5 | Implements sampling from an implicit model that is trained with the same procedure as [Denoising Diffusion Probabilistic Model](https://hojonathanho.github.io/diffusion/), but costs much less time and compute if you want to sample from it (click image below for a video demo): 6 | 7 | ![](http://img.youtube.com/vi/WCKzxoSduJQ/0.jpg) 8 | 9 | ## **Integration with 🤗 Diffusers library** 10 | 11 | DDIM is now also available in 🧨 Diffusers and accesible via the 
[DDIMPipeline](https://huggingface.co/docs/diffusers/api/pipelines/ddim).
12 | Diffusers allows you to test DDIM in PyTorch in just a couple lines of code.
13 | 
14 | You can install diffusers as follows:
15 | 
16 | ```
17 | pip install diffusers torch accelerate
18 | ```
19 | 
20 | And then try out the model with just a couple lines of code:
21 | 
22 | ```python
23 | from diffusers import DDIMPipeline
24 | 
25 | model_id = "google/ddpm-cifar10-32"
26 | 
27 | # load model and scheduler
28 | ddim = DDIMPipeline.from_pretrained(model_id)
29 | 
30 | # run pipeline in inference (sample random noise and denoise)
31 | image = ddim(num_inference_steps=50).images[0]
32 | 
33 | # save image
34 | image.save("ddim_generated_image.png")
35 | ```
36 | 
37 | More DDPM/DDIM models compatible with the DDIM pipeline can be found directly [on the Hub](https://huggingface.co/models?library=diffusers&sort=downloads&search=ddpm)
38 | 
39 | To better understand the DDIM scheduler, you can check out [this introductory Google Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)
40 | 
41 | The DDIM scheduler can also be used with more powerful diffusion models such as [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#stable-diffusion-pipelines)
42 | 
43 | You simply need to [accept the license on the Hub](https://huggingface.co/runwayml/stable-diffusion-v1-5), login with `huggingface-cli login` and install transformers:
44 | 
45 | ```
46 | pip install transformers
47 | ```
48 | 
49 | Then you can run:
50 | 
51 | ```python
52 | from diffusers import StableDiffusionPipeline, DDIMScheduler
53 | 
54 | ddim = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
55 | pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=ddim)
56 | 
57 | image = pipeline("An astronaut riding a horse.").images[0]
58 | 
59 | image.save("astronaut_riding_a_horse.png")
60 | ```
61 | 
62 | ## Running the Experiments
63 | The code has been tested on PyTorch 1.6.
64 | 
65 | ### Train a model
66 | Training is exactly the same as DDPM with the following:
67 | ```
68 | python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --ni
69 | ```
70 | 
71 | ### Sampling from the model
72 | 
73 | #### Sampling from the generalized model for FID evaluation
74 | ```
75 | python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --timesteps {STEPS} --eta {ETA} --ni
76 | ```
77 | where
78 | - `ETA` controls the scale of the variance (0 is DDIM, and 1 is one type of DDPM); a sketch further below shows how `ETA` and `STEPS` enter the sampler.
79 | - `STEPS` controls how many timesteps are used in the process.
80 | - `MODEL_NAME` finds the pre-trained checkpoint according to its inferred path.
81 | 
82 | If you want to use the DDPM pretrained model:
83 | ```
84 | python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --use_pretrained --sample --fid --timesteps {STEPS} --eta {ETA} --ni
85 | ```
86 | The `--use_pretrained` option will automatically load the model according to the dataset.
87 | 
88 | We provide a CelebA 64x64 model [here](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing), and use the DDPM version for CIFAR10 and LSUN.
89 | 
90 | If you want to use the version with the larger variance in DDPM, use the `--sample_type ddpm_noisy` option.
91 | 
92 | #### Sampling from the model for image interpolation
93 | Use `--interpolation` option instead of `--fid`.
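To make `ETA` and `STEPS` concrete, the sketch below shows (i) how a sub-sequence of timesteps is typically selected from the full 1000-step schedule for `--skip_type uniform` or `quad` (this mirrors the usual runner logic, whose code is not shown in this excerpt), and (ii) the noise scale of the generalized update, which corresponds to the `c1` term in `ddim/functions/denoising.py` later in this listing. The function names here are illustrative, not part of the repository.

```python
import numpy as np

def timestep_subsequence(num_diffusion_timesteps=1000, steps=50, skip_type="uniform"):
    """Illustrative: pick the timesteps visited during accelerated sampling."""
    if skip_type == "uniform":
        skip = num_diffusion_timesteps // steps
        seq = list(range(0, num_diffusion_timesteps, skip))
    elif skip_type == "quad":
        seq = np.linspace(0, np.sqrt(num_diffusion_timesteps * 0.8), steps) ** 2
        seq = [int(s) for s in seq]
    else:
        raise NotImplementedError(f"unknown skip_type: {skip_type}")
    return seq

def ddim_sigma(alpha_bar_t, alpha_bar_prev, eta):
    """Noise scale of the generalized update (the `c1` term in functions/denoising.py)."""
    return eta * np.sqrt((1 - alpha_bar_t / alpha_bar_prev)
                         * (1 - alpha_bar_prev) / (1 - alpha_bar_t))

print(timestep_subsequence(steps=10))   # [0, 100, 200, ..., 900]
print(ddim_sigma(0.5, 0.8, eta=1.0))    # ~0.39; with eta=0.0 this is exactly 0
```

With `eta = 0` the update is deterministic (pure DDIM); with `eta = 1` the injected noise matches one DDPM variance choice.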
94 | 95 | #### Sampling from the sequence of images that lead to the sample 96 | Use `--sequence` option instead. 97 | 98 | The above two cases contain some hard-coded lines specific to producing the image, so modify them according to your needs. 99 | 100 | 101 | ## References and Acknowledgements 102 | ``` 103 | @article{song2020denoising, 104 | title={Denoising Diffusion Implicit Models}, 105 | author={Song, Jiaming and Meng, Chenlin and Ermon, Stefano}, 106 | journal={arXiv:2010.02502}, 107 | year={2020}, 108 | month={October}, 109 | abbr={Preprint}, 110 | url={https://arxiv.org/abs/2010.02502} 111 | } 112 | ``` 113 | 114 | 115 | This implementation is based on / inspired by: 116 | 117 | - [https://github.com/hojonathanho/diffusion](https://github.com/hojonathanho/diffusion) (the DDPM TensorFlow repo), 118 | - [https://github.com/pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion) (PyTorch helper that loads the DDPM model), and 119 | - [https://github.com/ermongroup/ncsnv2](https://github.com/ermongroup/ncsnv2) (code structure). 120 | -------------------------------------------------------------------------------- /ddim/configs/bedroom.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "bedroom" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | 13 | model: 14 | type: "simple" 15 | in_channels: 3 16 | out_ch: 3 17 | ch: 128 18 | ch_mult: [1, 1, 2, 2, 4, 4] 19 | num_res_blocks: 2 20 | attn_resolutions: [16, ] 21 | dropout: 0.0 22 | var_type: fixedsmall 23 | ema_rate: 0.999 24 | ema: True 25 | resamp_with_conv: True 26 | 27 | diffusion: 28 | beta_schedule: linear 29 | beta_start: 0.0001 30 | beta_end: 0.02 31 | num_diffusion_timesteps: 1000 32 | 33 | training: 34 | batch_size: 64 35 | n_epochs: 10000 36 | n_iters: 5000000 37 | snapshot_freq: 5000 38 | validation_freq: 2000 39 | 40 | sampling: 41 | batch_size: 32 42 | last_only: True 43 | 44 | optim: 45 | weight_decay: 0.000 46 | optimizer: "Adam" 47 | lr: 0.00002 48 | beta1: 0.9 49 | amsgrad: false 50 | eps: 0.00000001 51 | -------------------------------------------------------------------------------- /ddim/configs/celeba.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CELEBA" 3 | image_size: 64 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | 12 | model: 13 | type: "simple" 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 2, 2, 4] 18 | num_res_blocks: 2 19 | attn_resolutions: [16, ] 20 | dropout: 0.1 21 | var_type: fixedlarge 22 | ema_rate: 0.9999 23 | ema: True 24 | resamp_with_conv: True 25 | 26 | diffusion: 27 | beta_schedule: linear 28 | beta_start: 0.0001 29 | beta_end: 0.02 30 | num_diffusion_timesteps: 1000 31 | 32 | training: 33 | batch_size: 128 34 | n_epochs: 10000 35 | n_iters: 5000000 36 | snapshot_freq: 5000 37 | validation_freq: 20000 38 | 39 | sampling: 40 | batch_size: 32 41 | last_only: True 42 | 43 | optim: 44 | weight_decay: 0.000 45 | optimizer: "Adam" 46 | lr: 0.0002 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | grad_clip: 1.0 51 | -------------------------------------------------------------------------------- /ddim/configs/church.yml: 
-------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "church_outdoor" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | 13 | model: 14 | type: "simple" 15 | in_channels: 3 16 | out_ch: 3 17 | ch: 128 18 | ch_mult: [1, 1, 2, 2, 4, 4] 19 | num_res_blocks: 2 20 | attn_resolutions: [16, ] 21 | dropout: 0.0 22 | var_type: fixedsmall 23 | ema_rate: 0.999 24 | ema: True 25 | resamp_with_conv: True 26 | 27 | diffusion: 28 | beta_schedule: linear 29 | beta_start: 0.0001 30 | beta_end: 0.02 31 | num_diffusion_timesteps: 1000 32 | 33 | training: 34 | batch_size: 64 35 | n_epochs: 10000 36 | n_iters: 5000000 37 | snapshot_freq: 5000 38 | validation_freq: 2000 39 | 40 | sampling: 41 | batch_size: 32 42 | last_only: True 43 | 44 | optim: 45 | weight_decay: 0.000 46 | optimizer: "Adam" 47 | lr: 0.00002 48 | beta1: 0.9 49 | amsgrad: false 50 | eps: 0.00000001 51 | -------------------------------------------------------------------------------- /ddim/configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CIFAR10" 3 | image_size: 32 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | 12 | model: 13 | type: "simple" 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 2, 2] 18 | num_res_blocks: 2 19 | attn_resolutions: [16, ] 20 | dropout: 0.1 21 | var_type: fixedlarge 22 | ema_rate: 0.9999 23 | ema: True 24 | resamp_with_conv: True 25 | 26 | diffusion: 27 | beta_schedule: linear 28 | beta_start: 0.0001 29 | beta_end: 0.02 30 | num_diffusion_timesteps: 1000 31 | 32 | training: 33 | batch_size: 128 34 | n_epochs: 10000 35 | n_iters: 5000000 36 | snapshot_freq: 5000 37 | validation_freq: 2000 38 | 39 | sampling: 40 | batch_size: 64 41 | last_only: True 42 | 43 | optim: 44 | weight_decay: 0.000 45 | optimizer: "Adam" 46 | lr: 0.0002 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | grad_clip: 1.0 51 | -------------------------------------------------------------------------------- /ddim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numbers 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as F 6 | from torchvision.datasets import CIFAR10 7 | from .celeba import CelebA 8 | from .ffhq import FFHQ 9 | from .lsun import LSUN 10 | from torch.utils.data import Subset 11 | import numpy as np 12 | 13 | 14 | class Crop(object): 15 | def __init__(self, x1, x2, y1, y2): 16 | self.x1 = x1 17 | self.x2 = x2 18 | self.y1 = y1 19 | self.y2 = y2 20 | 21 | def __call__(self, img): 22 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 23 | 24 | def __repr__(self): 25 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 26 | self.x1, self.x2, self.y1, self.y2 27 | ) 28 | 29 | 30 | def get_dataset(args, config): 31 | if config.data.random_flip is False: 32 | tran_transform = test_transform = transforms.Compose( 33 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 34 | ) 35 | else: 36 | tran_transform = transforms.Compose( 37 | [ 38 | transforms.Resize(config.data.image_size), 39 | 
transforms.RandomHorizontalFlip(p=0.5), 40 | transforms.ToTensor(), 41 | ] 42 | ) 43 | test_transform = transforms.Compose( 44 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 45 | ) 46 | 47 | if config.data.dataset == "CIFAR10": 48 | dataset = CIFAR10( 49 | os.path.join(args.exp, "datasets", "cifar10"), 50 | train=True, 51 | download=True, 52 | transform=tran_transform, 53 | ) 54 | test_dataset = CIFAR10( 55 | os.path.join(args.exp, "datasets", "cifar10_test"), 56 | train=False, 57 | download=True, 58 | transform=test_transform, 59 | ) 60 | 61 | elif config.data.dataset == "CELEBA": 62 | cx = 89 63 | cy = 121 64 | x1 = cy - 64 65 | x2 = cy + 64 66 | y1 = cx - 64 67 | y2 = cx + 64 68 | if config.data.random_flip: 69 | dataset = CelebA( 70 | root=os.path.join(args.exp, "datasets", "celeba"), 71 | split="train", 72 | transform=transforms.Compose( 73 | [ 74 | Crop(x1, x2, y1, y2), 75 | transforms.Resize(config.data.image_size), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | ] 79 | ), 80 | download=True, 81 | ) 82 | else: 83 | dataset = CelebA( 84 | root=os.path.join(args.exp, "datasets", "celeba"), 85 | split="train", 86 | transform=transforms.Compose( 87 | [ 88 | Crop(x1, x2, y1, y2), 89 | transforms.Resize(config.data.image_size), 90 | transforms.ToTensor(), 91 | ] 92 | ), 93 | download=True, 94 | ) 95 | 96 | test_dataset = CelebA( 97 | root=os.path.join(args.exp, "datasets", "celeba"), 98 | split="test", 99 | transform=transforms.Compose( 100 | [ 101 | Crop(x1, x2, y1, y2), 102 | transforms.Resize(config.data.image_size), 103 | transforms.ToTensor(), 104 | ] 105 | ), 106 | download=True, 107 | ) 108 | 109 | elif config.data.dataset == "LSUN": 110 | train_folder = "{}_train".format(config.data.category) 111 | val_folder = "{}_val".format(config.data.category) 112 | if config.data.random_flip: 113 | dataset = LSUN( 114 | root=os.path.join(args.exp, "datasets", "lsun"), 115 | classes=[train_folder], 116 | transform=transforms.Compose( 117 | [ 118 | transforms.Resize(config.data.image_size), 119 | transforms.CenterCrop(config.data.image_size), 120 | transforms.RandomHorizontalFlip(p=0.5), 121 | transforms.ToTensor(), 122 | ] 123 | ), 124 | ) 125 | else: 126 | dataset = LSUN( 127 | root=os.path.join(args.exp, "datasets", "lsun"), 128 | classes=[train_folder], 129 | transform=transforms.Compose( 130 | [ 131 | transforms.Resize(config.data.image_size), 132 | transforms.CenterCrop(config.data.image_size), 133 | transforms.ToTensor(), 134 | ] 135 | ), 136 | ) 137 | 138 | test_dataset = LSUN( 139 | root=os.path.join(args.exp, "datasets", "lsun"), 140 | classes=[val_folder], 141 | transform=transforms.Compose( 142 | [ 143 | transforms.Resize(config.data.image_size), 144 | transforms.CenterCrop(config.data.image_size), 145 | transforms.ToTensor(), 146 | ] 147 | ), 148 | ) 149 | 150 | elif config.data.dataset == "FFHQ": 151 | if config.data.random_flip: 152 | dataset = FFHQ( 153 | path=os.path.join(args.exp, "datasets", "FFHQ"), 154 | transform=transforms.Compose( 155 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()] 156 | ), 157 | resolution=config.data.image_size, 158 | ) 159 | else: 160 | dataset = FFHQ( 161 | path=os.path.join(args.exp, "datasets", "FFHQ"), 162 | transform=transforms.ToTensor(), 163 | resolution=config.data.image_size, 164 | ) 165 | 166 | num_items = len(dataset) 167 | indices = list(range(num_items)) 168 | random_state = np.random.get_state() 169 | np.random.seed(2019) 170 | np.random.shuffle(indices) 171 | 
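# The shuffle above uses a dedicated seed (2019) and the original NumPy RNG state is
# restored on the next line, so the 90/10 FFHQ train/test split below is deterministic
# without affecting any other consumer of NumPy randomness.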
np.random.set_state(random_state) 172 | train_indices, test_indices = ( 173 | indices[: int(num_items * 0.9)], 174 | indices[int(num_items * 0.9) :], 175 | ) 176 | test_dataset = Subset(dataset, test_indices) 177 | dataset = Subset(dataset, train_indices) 178 | else: 179 | dataset, test_dataset = None, None 180 | 181 | return dataset, test_dataset 182 | 183 | 184 | def logit_transform(image, lam=1e-6): 185 | image = lam + (1 - 2 * lam) * image 186 | return torch.log(image) - torch.log1p(-image) 187 | 188 | 189 | def data_transform(config, X): 190 | if config.data.uniform_dequantization: 191 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 192 | if config.data.gaussian_dequantization: 193 | X = X + torch.randn_like(X) * 0.01 194 | 195 | if config.data.rescaled: 196 | X = 2 * X - 1.0 197 | elif config.data.logit_transform: 198 | X = logit_transform(X) 199 | 200 | if hasattr(config, "image_mean"): 201 | return X - config.image_mean.to(X.device)[None, ...] 202 | 203 | return X 204 | 205 | 206 | def inverse_data_transform(config, X): 207 | if hasattr(config, "image_mean"): 208 | X = X + config.image_mean.to(X.device)[None, ...] 209 | 210 | if config.data.logit_transform: 211 | X = torch.sigmoid(X) 212 | elif config.data.rescaled: 213 | X = (X + 1.0) / 2.0 214 | 215 | return torch.clamp(X, 0.0, 1.0) 216 | -------------------------------------------------------------------------------- /ddim/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import PIL 4 | from .vision import VisionDataset 5 | from .utils import download_file_from_google_drive, check_integrity 6 | 7 | 8 | class CelebA(VisionDataset): 9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory where images are downloaded to. 13 | split (string): One of {'train', 'valid', 'test'}. 14 | Accordingly dataset is selected. 15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 17 | The targets represent: 18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 19 | ``identity`` (int): label for each person (data points with the same identity are the same person) 20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 23 | Defaults to ``attr``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | 33 | base_folder = "celeba" 34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 36 | # right now. 
37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | target = [] 137 | for t in self.target_type: 138 | if t == "attr": 139 | target.append(self.attr[index, :]) 140 | elif t == "identity": 141 | target.append(self.identity[index, 0]) 142 | elif t == "bbox": 143 | target.append(self.bbox[index, :]) 144 | elif t == "landmarks": 145 | target.append(self.landmarks_align[index, :]) 146 | else: 147 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return X, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | -------------------------------------------------------------------------------- 
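As a usage illustration (a hypothetical snippet, not code from the repository): `get_dataset()` in `ddim/datasets/__init__.py` builds the CelebA 64x64 training split from this class with a 128x128 crop centred at `(cx, cy) = (89, 121)` followed by a resize. Run from the `ddim/` directory with an illustrative experiment path, that corresponds roughly to:

```python
import os
import torchvision.transforms as transforms
from datasets import Crop            # defined in datasets/__init__.py above
from datasets.celeba import CelebA

cx, cy = 89, 121                     # crop centre used by get_dataset() for CELEBA
x1, x2, y1, y2 = cy - 64, cy + 64, cx - 64, cx + 64

train_set = CelebA(
    root=os.path.join("exp", "datasets", "celeba"),   # illustrative {PROJECT_PATH}/datasets/celeba
    split="train",
    transform=transforms.Compose([
        Crop(x1, x2, y1, y2),                         # 128x128 crop around the face
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    download=True,
)
img, attr = train_set[0]   # image tensor in [0, 1] and a 40-dim binary attribute vector
```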
/ddim/datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FFHQ(Dataset): 9 | def __init__(self, path, transform, resolution=8): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | target = 0 40 | 41 | return img, target -------------------------------------------------------------------------------- /ddim/datasets/lsun.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import io 6 | from collections.abc import Iterable 7 | import pickle 8 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str 9 | 10 | 11 | class LSUNClass(VisionDataset): 12 | def __init__(self, root, transform=None, target_transform=None): 13 | import lmdb 14 | 15 | super(LSUNClass, self).__init__( 16 | root, transform=transform, target_transform=target_transform 17 | ) 18 | 19 | self.env = lmdb.open( 20 | root, 21 | max_readers=1, 22 | readonly=True, 23 | lock=False, 24 | readahead=False, 25 | meminit=False, 26 | ) 27 | with self.env.begin(write=False) as txn: 28 | self.length = txn.stat()["entries"] 29 | root_split = root.split("/") 30 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") 31 | if os.path.isfile(cache_file): 32 | self.keys = pickle.load(open(cache_file, "rb")) 33 | else: 34 | with self.env.begin(write=False) as txn: 35 | self.keys = [key for key, _ in txn.cursor()] 36 | pickle.dump(self.keys, open(cache_file, "wb")) 37 | 38 | def __getitem__(self, index): 39 | img, target = None, None 40 | env = self.env 41 | with env.begin(write=False) as txn: 42 | imgbuf = txn.get(self.keys[index]) 43 | 44 | buf = io.BytesIO() 45 | buf.write(imgbuf) 46 | buf.seek(0) 47 | img = Image.open(buf).convert("RGB") 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | if self.target_transform is not None: 53 | target = self.target_transform(target) 54 | 55 | return img, target 56 | 57 | def __len__(self): 58 | return self.length 59 | 60 | 61 | class LSUN(VisionDataset): 62 | """ 63 | `LSUN `_ dataset. 64 | 65 | Args: 66 | root (string): Root directory for the database files. 67 | classes (string or list): One of {'train', 'val', 'test'} or a list of 68 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 69 | transform (callable, optional): A function/transform that takes in an PIL image 70 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 71 | target_transform (callable, optional): A function/transform that takes in the 72 | target and transforms it. 
73 | """ 74 | 75 | def __init__(self, root, classes="train", transform=None, target_transform=None): 76 | super(LSUN, self).__init__( 77 | root, transform=transform, target_transform=target_transform 78 | ) 79 | self.classes = self._verify_classes(classes) 80 | 81 | # for each class, create an LSUNClassDataset 82 | self.dbs = [] 83 | for c in self.classes: 84 | self.dbs.append( 85 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) 86 | ) 87 | 88 | self.indices = [] 89 | count = 0 90 | for db in self.dbs: 91 | count += len(db) 92 | self.indices.append(count) 93 | 94 | self.length = count 95 | 96 | def _verify_classes(self, classes): 97 | categories = [ 98 | "bedroom", 99 | "bridge", 100 | "church_outdoor", 101 | "classroom", 102 | "conference_room", 103 | "dining_room", 104 | "kitchen", 105 | "living_room", 106 | "restaurant", 107 | "tower", 108 | ] 109 | dset_opts = ["train", "val", "test"] 110 | 111 | try: 112 | verify_str_arg(classes, "classes", dset_opts) 113 | if classes == "test": 114 | classes = [classes] 115 | else: 116 | classes = [c + "_" + classes for c in categories] 117 | except ValueError: 118 | if not isinstance(classes, Iterable): 119 | msg = ( 120 | "Expected type str or Iterable for argument classes, " 121 | "but got type {}." 122 | ) 123 | raise ValueError(msg.format(type(classes))) 124 | 125 | classes = list(classes) 126 | msg_fmtstr = ( 127 | "Expected type str for elements in argument classes, " 128 | "but got type {}." 129 | ) 130 | for c in classes: 131 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 132 | c_short = c.split("_") 133 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1] 134 | 135 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 136 | msg = msg_fmtstr.format( 137 | category, "LSUN class", iterable_to_str(categories) 138 | ) 139 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 140 | 141 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 142 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 143 | 144 | return classes 145 | 146 | def __getitem__(self, index): 147 | """ 148 | Args: 149 | index (int): Index 150 | 151 | Returns: 152 | tuple: Tuple (image, target) where target is the index of the target category. 
153 | """ 154 | target = 0 155 | sub = 0 156 | for ind in self.indices: 157 | if index < ind: 158 | break 159 | target += 1 160 | sub = ind 161 | 162 | db = self.dbs[target] 163 | index = index - sub 164 | 165 | if self.target_transform is not None: 166 | target = self.target_transform(target) 167 | 168 | img, _ = db[index] 169 | return img, target 170 | 171 | def __len__(self): 172 | return self.length 173 | 174 | def extra_repr(self): 175 | return "Classes: {classes}".format(**self.__dict__) 176 | -------------------------------------------------------------------------------- /ddim/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 
81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /ddim/datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, 'transform') and self.transform is not None: 41 | body += self._format_transform_repr(self.transform, 42 | "Transforms: ") 43 | if hasattr(self, 'target_transform') and self.target_transform is not None: 44 | body += self._format_transform_repr(self.target_transform, 45 | "Target transforms: ") 46 | lines = [head] + [" " * self._repr_indent + line for line in body] 47 | return '\n'.join(lines) 48 | 49 | def _format_transform_repr(self, transform, head): 50 | lines = transform.__repr__().splitlines() 51 | return (["{}{}".format(head, lines[0])] + 52 | ["{}{}".format(" " * len(head), line) for 
line in lines[1:]]) 53 | 54 | def extra_repr(self): 55 | return "" 56 | 57 | 58 | class StandardTransform(object): 59 | def __init__(self, transform=None, target_transform=None): 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def __call__(self, input, target): 64 | if self.transform is not None: 65 | input = self.transform(input) 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | return input, target 69 | 70 | def _format_transform_repr(self, transform, head): 71 | lines = transform.__repr__().splitlines() 72 | return (["{}{}".format(head, lines[0])] + 73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 74 | 75 | def __repr__(self): 76 | body = [self.__class__.__name__] 77 | if self.transform is not None: 78 | body += self._format_transform_repr(self.transform, 79 | "Transform: ") 80 | if self.target_transform is not None: 81 | body += self._format_transform_repr(self.target_transform, 82 | "Target transform: ") 83 | 84 | return '\n'.join(body) 85 | -------------------------------------------------------------------------------- /ddim/functions/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad, 8 | eps=config.optim.eps) 9 | elif config.optim.optimizer == 'RMSProp': 10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif config.optim.optimizer == 'SGD': 12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError( 15 | 'Optimizer {} not understood.'.format(config.optim.optimizer)) 16 | -------------------------------------------------------------------------------- /ddim/functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": 
"82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | assert name in URL_MAP 59 | # Modify the path when necessary 60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("/atlas/u/tsong/.cache")) 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /ddim/functions/denoising.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_alpha(beta, t): 5 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 6 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 7 | return a 8 | 9 | 10 | def generalized_steps(x, seq, model, b, **kwargs): 11 | with torch.no_grad(): 12 | n = x.size(0) 13 | seq_next = [-1] + list(seq[:-1]) 14 | x0_preds = [] 15 | xs = [x] 16 | cnt = 0 17 | t, xt = None, None 18 | for i, j in zip(reversed(seq), reversed(seq_next)): 19 | t = (torch.ones(n) * i).to(x.device) 20 | next_t = (torch.ones(n) * j).to(x.device) 21 | at = compute_alpha(b, t.long()) 22 | at_next = compute_alpha(b, next_t.long()) 23 | xt = xs[-1].to('cuda') 24 | if "untill_fake_t" in kwargs and cnt == kwargs["untill_fake_t"] - 1: 25 | break 26 | if 'tot' in kwargs and kwargs['tot'] is not None: 27 | # act = kwargs['cali_ckpt'][f'act_{kwargs["t_max"] - i // kwargs["tot"]}'] 28 | act = kwargs['cali_ckpt'][f'act_{cnt}'] 29 | model.load_state_dict(act, strict=False) 30 | et = model(xt, t) 31 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 32 | x0_preds.append(x0_t.to('cpu')) 33 | c1 = ( 34 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 35 | ) 36 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 37 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 38 | xs.append(xt_next.to('cpu')) 39 | cnt += 1 40 | 41 | return xs, x0_preds, xt, t 42 | 43 | 44 | def ddpm_steps(x, seq, model, b, **kwargs): 45 | with torch.no_grad(): 46 | n = x.size(0) 47 | seq_next = [-1] + list(seq[:-1]) 
48 | xs = [x] 49 | x0_preds = [] 50 | betas = b 51 | cnt = 0 52 | t, xt = None, None 53 | for i, j in zip(reversed(seq), reversed(seq_next)): 54 | t = (torch.ones(n) * i).to(x.device) 55 | next_t = (torch.ones(n) * j).to(x.device) 56 | at = compute_alpha(betas, t.long()) 57 | atm1 = compute_alpha(betas, next_t.long()) 58 | beta_t = 1 - at / atm1 59 | x = xs[-1].to('cuda') 60 | if "untill_fake_t" in kwargs and cnt == kwargs["untill_fake_t"] - 1: 61 | break 62 | if 'tot' in kwargs and kwargs['tot'] is not None: 63 | # act = kwargs['cali_ckpt'][f'act_{kwargs["t_max"] - i // kwargs["tot"]}'] 64 | act = kwargs['cali_ckpt'][f'act_{cnt}'] 65 | model.load_state_dict(act, strict=False) 66 | output = model(x, t.float()) 67 | e = output 68 | 69 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 70 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 71 | x0_preds.append(x0_from_e.to('cpu')) 72 | mean_eps = ( 73 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 74 | ) / (1.0 - at) 75 | 76 | mean = mean_eps 77 | noise = torch.randn_like(x) 78 | mask = 1 - (t == 0).float() 79 | mask = mask.view(-1, 1, 1, 1) 80 | logvar = beta_t.log() 81 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 82 | xs.append(sample.to('cpu')) 83 | return xs, x0_preds, xt, t 84 | -------------------------------------------------------------------------------- /ddim/functions/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def noise_estimation_loss(model, 5 | x0: torch.Tensor, 6 | t: torch.LongTensor, 7 | e: torch.Tensor, 8 | b: torch.Tensor, keepdim=False): 9 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 10 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt() 11 | output = model(x, t.float()) 12 | if keepdim: 13 | return (e - output).square().sum(dim=(1, 2, 3)) 14 | else: 15 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 16 | 17 | 18 | loss_registry = { 19 | 'simple': noise_estimation_loss, 20 | } 21 | -------------------------------------------------------------------------------- /ddim/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import shutil 4 | import logging 5 | import yaml 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | import torch.utils.tensorboard as tb 11 | 12 | from runners.diffusion import Diffusion 13 | 14 | torch.set_printoptions(sci_mode=False) 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 19 | 20 | parser.add_argument( 21 | "--config", type=str, required=True, help="Path to the config file" 22 | ) 23 | parser.add_argument("--seed", type=int, default=1234, help="Random seed") 24 | parser.add_argument( 25 | "--exp", type=str, default="exp", help="Path for saving running related data." 26 | ) 27 | parser.add_argument( 28 | "--doc", 29 | type=str, 30 | required=True, 31 | help="A string for documentation purpose. 
" 32 | "Will be the name of the log folder.", 33 | ) 34 | parser.add_argument( 35 | "--comment", type=str, default="", help="A string for experiment comment" 36 | ) 37 | parser.add_argument( 38 | "--verbose", 39 | type=str, 40 | default="info", 41 | help="Verbose level: info | debug | warning | critical", 42 | ) 43 | parser.add_argument("--test", action="store_true", help="Whether to test the model") 44 | parser.add_argument( 45 | "--sample", 46 | action="store_true", 47 | help="Whether to produce samples from the model", 48 | ) 49 | parser.add_argument("--fid", action="store_true") 50 | parser.add_argument("--interpolation", action="store_true") 51 | parser.add_argument( 52 | "--resume_training", action="store_true", help="Whether to resume training" 53 | ) 54 | parser.add_argument( 55 | "-i", 56 | "--image_folder", 57 | type=str, 58 | default="images", 59 | help="The folder name of samples", 60 | ) 61 | parser.add_argument( 62 | "--ni", 63 | action="store_true", 64 | help="No interaction. Suitable for Slurm Job launcher", 65 | ) 66 | parser.add_argument("--use_pretrained", action="store_true") 67 | parser.add_argument( 68 | "--sample_type", 69 | type=str, 70 | default="generalized", 71 | help="sampling approach (generalized or ddpm_noisy)", 72 | ) 73 | parser.add_argument( 74 | "--skip_type", 75 | type=str, 76 | default="uniform", 77 | help="skip according to (uniform or quadratic)", 78 | ) 79 | parser.add_argument( 80 | "--timesteps", type=int, default=1000, help="number of steps involved" 81 | ) 82 | parser.add_argument( 83 | "--eta", 84 | type=float, 85 | default=0.0, 86 | help="eta used to control the variances of sigma", 87 | ) 88 | parser.add_argument("--sequence", action="store_true") 89 | 90 | args = parser.parse_args() 91 | args.log_path = os.path.join(args.exp, "logs", args.doc) 92 | 93 | # parse config file 94 | with open(os.path.join("configs", args.config), "r") as f: 95 | config = yaml.safe_load(f) 96 | new_config = dict2namespace(config) 97 | 98 | tb_path = os.path.join(args.exp, "tensorboard", args.doc) 99 | 100 | if not args.test and not args.sample: 101 | if not args.resume_training: 102 | if os.path.exists(args.log_path): 103 | overwrite = False 104 | if args.ni: 105 | overwrite = True 106 | else: 107 | response = input("Folder already exists. Overwrite? (Y/N)") 108 | if response.upper() == "Y": 109 | overwrite = True 110 | 111 | if overwrite: 112 | shutil.rmtree(args.log_path) 113 | shutil.rmtree(tb_path) 114 | os.makedirs(args.log_path) 115 | if os.path.exists(tb_path): 116 | shutil.rmtree(tb_path) 117 | else: 118 | print("Folder exists. 
Program halted.") 119 | sys.exit(0) 120 | else: 121 | os.makedirs(args.log_path) 122 | 123 | with open(os.path.join(args.log_path, "config.yml"), "w") as f: 124 | yaml.dump(new_config, f, default_flow_style=False) 125 | 126 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path) 127 | # setup logger 128 | level = getattr(logging, args.verbose.upper(), None) 129 | if not isinstance(level, int): 130 | raise ValueError("level {} not supported".format(args.verbose)) 131 | 132 | handler1 = logging.StreamHandler() 133 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) 134 | formatter = logging.Formatter( 135 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 136 | ) 137 | handler1.setFormatter(formatter) 138 | handler2.setFormatter(formatter) 139 | logger = logging.getLogger() 140 | logger.addHandler(handler1) 141 | logger.addHandler(handler2) 142 | logger.setLevel(level) 143 | 144 | else: 145 | level = getattr(logging, args.verbose.upper(), None) 146 | if not isinstance(level, int): 147 | raise ValueError("level {} not supported".format(args.verbose)) 148 | 149 | handler1 = logging.StreamHandler() 150 | formatter = logging.Formatter( 151 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 152 | ) 153 | handler1.setFormatter(formatter) 154 | logger = logging.getLogger() 155 | logger.addHandler(handler1) 156 | logger.setLevel(level) 157 | 158 | if args.sample: 159 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) 160 | args.image_folder = os.path.join( 161 | args.exp, "image_samples", args.image_folder 162 | ) 163 | if not os.path.exists(args.image_folder): 164 | os.makedirs(args.image_folder) 165 | else: 166 | if not (args.fid or args.interpolation): 167 | overwrite = False 168 | if args.ni: 169 | overwrite = True 170 | else: 171 | response = input( 172 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)" 173 | ) 174 | if response.upper() == "Y": 175 | overwrite = True 176 | 177 | if overwrite: 178 | shutil.rmtree(args.image_folder) 179 | os.makedirs(args.image_folder) 180 | else: 181 | print("Output image folder exists. 
Program halted.") 182 | sys.exit(0) 183 | 184 | # add device 185 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 186 | logging.info("Using device: {}".format(device)) 187 | new_config.device = device 188 | 189 | # set random seed 190 | torch.manual_seed(args.seed) 191 | np.random.seed(args.seed) 192 | if torch.cuda.is_available(): 193 | torch.cuda.manual_seed_all(args.seed) 194 | 195 | torch.backends.cudnn.benchmark = True 196 | 197 | return args, new_config 198 | 199 | 200 | def dict2namespace(config): 201 | namespace = argparse.Namespace() 202 | for key, value in config.items(): 203 | if isinstance(value, dict): 204 | new_value = dict2namespace(value) 205 | else: 206 | new_value = value 207 | setattr(namespace, key, new_value) 208 | return namespace 209 | 210 | 211 | def main(): 212 | args, config = parse_args_and_config() 213 | logging.info("Writing log file to {}".format(args.log_path)) 214 | logging.info("Exp instance id = {}".format(os.getpid())) 215 | logging.info("Exp comment = {}".format(args.comment)) 216 | 217 | try: 218 | runner = Diffusion(args, config) 219 | if args.sample: 220 | runner.sample() 221 | elif args.test: 222 | runner.test() 223 | else: 224 | runner.train() 225 | except Exception: 226 | logging.error(traceback.format_exc()) 227 | 228 | return 0 229 | 230 | 231 | if __name__ == "__main__": 232 | sys.exit(main()) 233 | -------------------------------------------------------------------------------- /ddim/models/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /ddim/runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/ddim/runners/__init__.py -------------------------------------------------------------------------------- /img/ldm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/img/ldm.png -------------------------------------------------------------------------------- /img/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/img/overview.png -------------------------------------------------------------------------------- /img/sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/img/sd.png -------------------------------------------------------------------------------- /linklink/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | allreduce = dist.all_reduce 7 | allgather = dist.all_gather 8 | broadcast = dist.broadcast 9 | barrier = dist.barrier 10 | synchronize = torch.cuda.synchronize 11 | init_process_group = dist.init_process_group 12 | get_rank = dist.get_rank 13 | get_world_size = dist.get_world_size 14 | 15 | 16 | def get_local_rank(): 17 | rank = dist.get_rank() 18 | return rank % torch.cuda.device_count() 19 | 20 | 21 | def initialize(backend='nccl', port='2333', job_envrion='normal'): 22 | """ 23 | Function to initialize distributed enviroments. 24 | :param backend: nccl backend supports GPU DDP, this should not be modified. 25 | :param port: port to communication 26 | :param job_envrion: we refer normal enviroments as the pytorch suggested initialization, the slurm 27 | enviroment is used for SLURM job submit system. 
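    :note: under 'slurm', MASTER_ADDR is parsed out of SLURM_NODELIST and MASTER_PORT /
        WORLD_SIZE / RANK are exported before dist.init_process_group is called; the
        non-SLURM branch relies on a hard-coded TCP init_method. The first branch below
        compares against the literal 'nomal', so a caller passing 'normal' (as
        dist_helper.dist_init does) currently falls through to NotImplementedError.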
28 | """ 29 | 30 | if job_envrion == 'nomal': 31 | # this step is taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L129 32 | dist.init_process_group(backend=backend, init_method='tcp://224.66.41.62:23456') 33 | elif job_envrion == 'slurm': 34 | proc_id = int(os.environ['SLURM_PROCID']) 35 | ntasks = int(os.environ['SLURM_NTASKS']) 36 | node_list = os.environ['SLURM_NODELIST'] 37 | if '[' in node_list: 38 | beg = node_list.find('[') 39 | pos1 = node_list.find('-', beg) 40 | if pos1 < 0: 41 | pos1 = 1000 42 | pos2 = node_list.find(',', beg) 43 | if pos2 < 0: 44 | pos2 = 1000 45 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 46 | addr = node_list[8:].replace('-', '.') 47 | os.environ['MASTER_PORT'] = port 48 | os.environ['MASTER_ADDR'] = addr 49 | os.environ['WORLD_SIZE'] = str(ntasks) 50 | os.environ['RANK'] = str(proc_id) 51 | if backend == 'nccl': 52 | dist.init_process_group(backend='nccl') 53 | else: 54 | dist.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks) 55 | rank = dist.get_rank() 56 | device = rank % torch.cuda.device_count() 57 | torch.cuda.set_device(device) 58 | else: 59 | raise NotImplementedError 60 | 61 | 62 | def finalize(): 63 | pass 64 | 65 | 66 | class nn(object): 67 | SyncBatchNorm2d = torch.nn.BatchNorm2d 68 | print("You are using fake SyncBatchNorm2d who is actually the official BatchNorm2d") 69 | 70 | 71 | class syncbnVarMode_t(object): 72 | L1 = None 73 | L2 = None 74 | 75 | -------------------------------------------------------------------------------- /linklink/dist_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | import linklink as link 6 | 7 | 8 | def save_file(dict, name): 9 | if link.get_local_rank() == 0: 10 | torch.save(dict, name) 11 | 12 | 13 | def dist_finalize(): 14 | link.finalize() 15 | 16 | 17 | class AllReduce(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, input): 20 | output = torch.zeros_like(input) 21 | output.copy_(input) 22 | link.allreduce(output) 23 | return output 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | in_grad = torch.zeros_like(grad_output) 28 | in_grad.copy_(grad_output) 29 | link.allreduce(in_grad) 30 | return in_grad 31 | 32 | 33 | def allaverage(tensor): 34 | tensor.data /= link.get_world_size() 35 | link.allreduce(tensor.data) 36 | return tensor 37 | 38 | 39 | def allaverage_autograd(tensor): 40 | if tensor.is_cuda is True: 41 | tensor /= link.get_world_size() 42 | tensor = AllReduce().apply(tensor) 43 | return tensor 44 | 45 | 46 | def allreduce(tensor): 47 | link.allreduce(tensor.data) 48 | 49 | 50 | def link_dist(func): 51 | 52 | def wrapper(*args, **kwargs): 53 | dist_init() 54 | func(*args, **kwargs) 55 | dist_finalize() 56 | 57 | return wrapper 58 | 59 | 60 | def dist_init(method='slurm', device_id=0): 61 | if method == 'slurm': 62 | proc_id = int(os.environ['SLURM_PROCID']) 63 | # ntasks = int(os.environ['SLURM_NTASKS']) 64 | # node_list = os.environ['SLURM_NODELIST'] 65 | num_gpus = torch.cuda.device_count() 66 | torch.cuda.set_device(proc_id % num_gpus) 67 | elif method == 'normal': 68 | torch.cuda.set_device(device_id) 69 | link.initialize(backend='nccl', job_envrion=method) 70 | world_size = link.get_world_size() 71 | rank = link.get_rank() 72 | 73 | return rank, world_size 74 | 75 | 76 | def dist_finalize(): 77 | link.finalize() 78 | 79 | 80 | def simple_group_split(world_size, rank, num_groups): 81 | 
groups = [] 82 | rank_list = np.split(np.arange(world_size), num_groups) 83 | rank_list = [list(map(int, x)) for x in rank_list] 84 | for i in range(num_groups): 85 | groups.append(link.new_group(rank_list[i])) 86 | group_size = world_size // num_groups 87 | return groups[rank//group_size] 88 | 89 | 90 | class DistModule(torch.nn.Module): 91 | def __init__(self, module, sync=False): 92 | super(DistModule, self).__init__() 93 | self.module = module 94 | self.broadcast_params() 95 | 96 | self.sync = sync 97 | if not sync: 98 | self._grad_accs = [] 99 | self._register_hooks() 100 | 101 | def forward(self, *inputs, **kwargs): 102 | return self.module(*inputs, **kwargs) 103 | 104 | def _register_hooks(self): 105 | for i, (name, p) in enumerate(self.named_parameters()): 106 | if p.requires_grad: 107 | p_tmp = p.expand_as(p) 108 | grad_acc = p_tmp.grad_fn.next_functions[0][0] 109 | grad_acc.register_hook(self._make_hook(name, p, i)) 110 | self._grad_accs.append(grad_acc) 111 | 112 | def _make_hook(self, name, p, i): 113 | def hook(*ignore): 114 | link.allreduce_async(name, p.grad.data) 115 | return hook 116 | 117 | def sync_gradients(self): 118 | """ average gradients """ 119 | if self.sync and link.get_world_size() > 1: 120 | for name, param in self.module.named_parameters(): 121 | if param.requires_grad and param.grad is not None: 122 | link.allreduce(param.grad.data) 123 | else: 124 | link.synchronize() 125 | 126 | def broadcast_params(self): 127 | """ broadcast model parameters """ 128 | for name, param in self.module.state_dict().items(): 129 | link.broadcast(param, 0) 130 | 131 | 132 | def _serialize_to_tensor(data, group=None): 133 | # backend = link.get_backend(group) 134 | # assert backend in ["gloo", "nccl"] 135 | # device = torch.device("cpu" if backend == "gloo" else "cuda") 136 | device = torch.cuda.current_device() 137 | 138 | buffer = pickle.dumps(data) 139 | if len(buffer) > 1024 ** 3: 140 | import logging 141 | logger = logging.getLogger('global') 142 | logger.warning( 143 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 144 | link.get_rank(), len(buffer) / (1024 ** 3), device 145 | ) 146 | ) 147 | storage = torch.ByteStorage.from_buffer(buffer) 148 | tensor = torch.ByteTensor(storage).to(device=device) 149 | return tensor 150 | 151 | 152 | def broadcast_object(obj, group=None): 153 | """make suare obj is picklable 154 | """ 155 | if link.get_world_size() == 1: 156 | return obj 157 | 158 | serialized_tensor = _serialize_to_tensor(obj).cuda() 159 | numel = torch.IntTensor([serialized_tensor.numel()]).cuda() 160 | link.broadcast(numel, 0) 161 | # serialized_tensor from storage is not resizable 162 | serialized_tensor = serialized_tensor.clone() 163 | serialized_tensor.resize_(numel) 164 | link.broadcast(serialized_tensor, 0) 165 | serialized_bytes = serialized_tensor.cpu().numpy().tobytes() 166 | deserialized_obj = pickle.loads(serialized_bytes) 167 | return deserialized_obj 168 | -------------------------------------------------------------------------------- /linklink/log_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import linklink as link 5 | 6 | 7 | _logger = None 8 | _logger_fh = None 9 | _logger_names = [] 10 | 11 | 12 | def create_logger(log_file, level=logging.INFO): 13 | global _logger, _logger_fh 14 | if _logger is None: 15 | _logger = logging.getLogger() 16 | formatter = logging.Formatter( 17 | 
'[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s') 18 | fh = logging.FileHandler(log_file) 19 | fh.setFormatter(formatter) 20 | sh = logging.StreamHandler() 21 | sh.setFormatter(formatter) 22 | _logger.setLevel(level) 23 | _logger.addHandler(fh) 24 | _logger.addHandler(sh) 25 | _logger_fh = fh 26 | else: 27 | _logger.removeHandler(_logger_fh) 28 | _logger.setLevel(level) 29 | 30 | return _logger 31 | 32 | 33 | def get_logger(name, level=logging.INFO): 34 | global _logger_names 35 | logger = logging.getLogger(name) 36 | if name in _logger_names: 37 | return logger 38 | 39 | _logger_names.append(name) 40 | if link.get_rank() > 0: 41 | logger.addFilter(RankFilter()) 42 | 43 | return logger 44 | 45 | 46 | class RankFilter(logging.Filter): 47 | def filter(self, record): 48 | return False -------------------------------------------------------------------------------- /quant/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/quant/__init__.py -------------------------------------------------------------------------------- /quant/adaptive_rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from quant.quant_layer import UniformAffineQuantizer, ste_round 4 | from enum import Enum 5 | 6 | RMODE = Enum('RMODE', ('LEARNED_ROUND_SIGMOID', 7 | 'NEAREST', 8 | 'NEAREST_STE', 9 | 'STOCHASTIC', 10 | 'LEARNED_HARD_SIGMOID')) 11 | 12 | class AdaRoundQuantizer(nn.Module): 13 | 14 | def __init__(self, 15 | uaqtizer: UniformAffineQuantizer, 16 | w: torch.Tensor, 17 | rmode: RMODE = RMODE.LEARNED_ROUND_SIGMOID, 18 | ) -> None: 19 | 20 | super().__init__() 21 | self.level = uaqtizer.level 22 | self.symmetric = uaqtizer.symmetric 23 | self.delta = uaqtizer.delta 24 | self.zero_point = uaqtizer.zero_point 25 | self.rmode = rmode 26 | self.soft_tgt = False 27 | self.gamma, self.zeta = -0.1, 1.1 28 | self.alpha = None 29 | self.init_alpha(x=w.clone()) 30 | 31 | def init_alpha(self, x: torch.Tensor) -> None: 32 | self.delta = self.delta.to(x.device) 33 | if self.rmode == RMODE.LEARNED_HARD_SIGMOID: 34 | rest = (x / self.delta) - torch.floor(x / self.delta) 35 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) 36 | self.alpha = nn.Parameter(alpha) 37 | else: 38 | raise NotImplementedError 39 | 40 | def get_soft_tgt(self) -> torch.Tensor: 41 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) 42 | 43 | def forward(self, 44 | x: torch.Tensor 45 | ) -> torch.Tensor: 46 | if isinstance(self.delta, torch.Tensor): 47 | self.delta = self.delta.to(x.device) 48 | if isinstance(self.zero_point, torch.Tensor): 49 | self.zero_point = self.zero_point.to(x.device) 50 | 51 | x_floor = torch.floor(x / self.delta) 52 | if self.rmode == RMODE.NEAREST: 53 | x_int = torch.round(x / self.delta) 54 | elif self.rmode == RMODE.NEAREST_STE: 55 | x_int = ste_round(x / self.delta) 56 | elif self.rmode == RMODE.STOCHASTIC: 57 | x_int = x_floor + torch.bernoulli((x / self.delta) - x_floor) 58 | elif self.rmode == RMODE.LEARNED_HARD_SIGMOID: 59 | if self.soft_tgt: 60 | x_int = x_floor + self.get_soft_tgt().to(x.device) 61 | else: 62 | self.alpha = self.alpha.to(x.device) 63 | x_int = x_floor + (self.alpha >= 0).float() 64 | else: 65 | raise NotImplementedError 66 | 67 | NB, PB = -self.level // 2 if self.symmetric else 0, self.level // 2 - 1 
if self.symmetric else self.level - 1 68 | x_q = torch.clamp(x_int + self.zero_point, NB, PB) 69 | x_dq = self.delta * (x_q - self.zero_point) 70 | return x_dq 71 | 72 | def extra_repr(self) -> str: 73 | s = 'level={}, symmetric={}, rmode={}'.format(self.level, self.symmetric, self.rmode) 74 | return s.format(**self.__dict__) 75 | 76 | 77 | -------------------------------------------------------------------------------- /quant/data_generate.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | from typing import List, Union 3 | from ddim.models.diffusion import Model 4 | from ldm.models.diffusion.ddim import DDIMSampler 5 | from ldm.models.diffusion.ddpm import LatentDiffusion 6 | from ldm.models.diffusion.dpm_solver.sampler import DPMSolverSampler 7 | from ldm.models.diffusion.plms import PLMSSampler 8 | from typing import Tuple 9 | from torch import autocast 10 | import torch 11 | 12 | 13 | def generate_cali_text_guided_data(model: LatentDiffusion, 14 | sampler: Union[DPMSolverSampler, PLMSSampler, DDIMSampler], 15 | T: int, 16 | c: int, 17 | batch_size: int, 18 | prompts: Tuple[str], 19 | shape: List[int], 20 | precision_scope: Union[autocast, nullcontext], 21 | ) -> Tuple[torch.Tensor]: 22 | tmp = list() 23 | model.eval() 24 | with torch.no_grad(): 25 | with precision_scope("cuda"): 26 | for t in range(1, T + 1): 27 | # x_{t + 1} = f(x_t, t_t, c_t) 28 | if t % c == 0: 29 | for p in prompts: 30 | uc_t = model.get_learned_conditioning(batch_size * [""]) 31 | c_t = model.get_learned_conditioning(batch_size * [p]) 32 | x_t, t_t = sampler.sample(S=T, 33 | conditioning=c_t, 34 | batch_size=batch_size, 35 | shape=shape, 36 | verbose=False, 37 | unconditional_guidance_scale=7.5, 38 | unconditional_conditioning=uc_t, 39 | untill_fake_t=t) 40 | if isinstance(sampler, (PLMSSampler, DDIMSampler)): 41 | ddpm_time_num = 1000 # in yaml 42 | real_time = (T - t) * ddpm_time_num // T + 1 43 | t_t = torch.full((batch_size,), real_time, device=sampler.model.betas.device, dtype=torch.long) 44 | tmp += [[x_t, t_t, c_t], [x_t, t_t, uc_t]] 45 | 46 | cali_data = () 47 | for i in range(len(tmp[0])): 48 | cali_data += (torch.cat([x[i] for x in tmp]), ) 49 | return cali_data 50 | 51 | 52 | def generate_cali_data_ddim(runnr, 53 | model: Model, 54 | T: int, 55 | c: int, 56 | batch_size: int, 57 | shape: List[int], 58 | ) -> Tuple[torch.Tensor]: 59 | tmp = list() 60 | for i in range(1, T + 1): 61 | # x_{t + 1} = f(x_t, t_t, c_t) 62 | if i % c == 0: 63 | from ddim.runners.diffusion import Diffusion 64 | runnr: Diffusion 65 | N, C, H, W = batch_size, *shape 66 | x = torch.randn((N, C, H, W), device=runnr.device) 67 | x_t, t_t = runnr.sample_image(x, model, untill_fake_t=i)[1:] 68 | tmp += [[x_t, t_t]] 69 | cali_data = () 70 | for i in range(len(tmp[0])): 71 | cali_data += (torch.cat([x[i] for x in tmp]), ) 72 | return cali_data 73 | 74 | 75 | def generate_cali_data_ldm(model: LatentDiffusion, 76 | T: int, 77 | c: int, 78 | batch_size: int, 79 | shape: List[int], 80 | vanilla: bool = False, 81 | dpm: bool = False, 82 | plms: bool = False, 83 | eta: float = 0.0, 84 | ) -> Tuple[torch.Tensor]: 85 | if vanilla: 86 | pass 87 | elif dpm: 88 | sampler = DPMSolverSampler(model) 89 | elif plms: 90 | sampler = PLMSSampler(model) 91 | else: 92 | sampler = DDIMSampler(model) 93 | tmp = list() 94 | for t in range(1, T + 1): 95 | if t % c == 0: 96 | if not vanilla: 97 | x_t, t_t = sampler.sample(S=T, 98 | batch_size=batch_size, 99 | shape=shape, 100 | 
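                                           # untill_fake_t=t (passed below) stops the sampler after t
                                           # denoising steps and returns the intermediate latent, which
                                           # is what gets stacked into cali_data; the kwarg name matches
                                           # the one checked in ddim/functions/denoising.py, and the LDM
                                           # samplers are assumed to honor it the same way.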
verbose=False, 101 | eta=eta, 102 | untill_fake_t=t) 103 | if isinstance(sampler, (PLMSSampler, DDIMSampler)): 104 | ddpm_time_num = 1000 # in yaml 105 | real_time = (T - t) * ddpm_time_num // T + 1 106 | t_t = torch.full((batch_size,), real_time, device=sampler.model.betas.device, dtype=torch.long) 107 | tmp += [[x_t, t_t]] 108 | else: 109 | raise NotImplementedError("Vanilla LDM is not implemented yet, because it needs 1000 steps to generate one sample.") 110 | cali_data = () 111 | for i in range(len(tmp[0])): 112 | cali_data += (torch.cat([x[i] for x in tmp]), ) 113 | return cali_data 114 | 115 | 116 | def generate_cali_data_ldm_imagenet(model: LatentDiffusion, 117 | T: int, 118 | c: int, 119 | batch_size: int, # 8 120 | shape: List[int], 121 | eta: float = 0.0, 122 | scale: float = 3.0 123 | ) -> Tuple[torch.Tensor]: 124 | sampler = DDIMSampler(model) 125 | tmp = list() 126 | classes = [i for i in range(0, 1000, 1000 // 31)] 127 | with torch.no_grad(): 128 | with model.ema_scope(): 129 | for i in range(1, T + 1): 130 | if i % c == 0: 131 | uc_t = model.get_learned_conditioning( 132 | {model.cond_stage_key: torch.tensor(batch_size * [1000]).to(model.device)} 133 | ) 134 | for class_label in classes: 135 | xc = torch.tensor(batch_size * [class_label]) 136 | c_t = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)}) 137 | x_t, t_t = sampler.sample(S=T, 138 | batch_size=batch_size, 139 | shape=shape, 140 | verbose=False, 141 | eta=eta, 142 | unconditional_guidance_scale=scale, 143 | unconditional_conditioning=uc_t, 144 | conditioning=c_t, 145 | untill_fake_t=i) 146 | if isinstance(sampler, DDIMSampler): 147 | ddpm_time_num = 1000 148 | real_time = (T - i) * ddpm_time_num // T + 1 149 | t_t = torch.full((batch_size,), real_time, device=sampler.model.betas.device, dtype=torch.long) 150 | tmp += [[x_t, t_t, c_t], [x_t, t_t, uc_t]] 151 | cali_data = () 152 | for i in range(len(tmp[0])): 153 | cali_data += (torch.cat([x[i] for x in tmp]), ) 154 | return cali_data -------------------------------------------------------------------------------- /quant/quant_model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch.nn as nn 3 | import torch 4 | from ldm.modules.attention import BasicTransformerBlock 5 | from quant.quant_block import QuantAttentionBlock, QuantAttnBlock, QuantQKMatMul, QuantResnetBlock, QuantSMVMatMul, QuantTemporalInformationBlock, QuantTemporalInformationBlockDDIM, b2qb, BaseQuantBlock 6 | from quant.quant_block import QuantBasicTransformerBlock, QuantResBlock 7 | from quant.quant_layer import QMODE, QuantLayer, StraightThrough 8 | 9 | 10 | class QuantModel(nn.Module): 11 | 12 | def __init__(self, 13 | model: nn.Module, 14 | wq_params: dict = {}, 15 | aq_params: dict = {}, 16 | cali: bool = True, 17 | **kwargs 18 | ) -> None: 19 | super().__init__() 20 | self.model = model 21 | self.softmax_a_bit = kwargs.get("softmax_a_bit", 8) 22 | self.in_channels = model.in_channels 23 | if hasattr(model, 'image_size'): 24 | self.image_size = model.image_size 25 | self.B = b2qb(aq_params['leaf_param']) 26 | self.quant_module(self.model, wq_params, aq_params, aq_mode=kwargs.get("aq_mode", [QMODE.NORMAL.value]), prev_name=None) 27 | self.quant_block(self.model, wq_params, aq_params) 28 | if cali: 29 | self.get_tib(self.model, wq_params, aq_params) 30 | 31 | def get_tib(self, 32 | module: nn.Module, 33 | wq_params: dict = {}, 34 | aq_params: dict = {}, 35 | ) -> QuantTemporalInformationBlock: 36 | 
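        # Recursively walks the (already quantized) UNet and assembles the Temporal
        # Information Block used during calibration: a DDIM-style 'temb' child becomes a
        # QuantTemporalInformationBlockDDIM, an LDM/guided-diffusion-style 'time_embed'
        # child becomes a QuantTemporalInformationBlock, and every QuantResBlock /
        # QuantResnetBlock then registers its emb_layers / temb_proj with that block.
        # This assumes named_children() yields the time-embedding module before the
        # residual blocks, which appears to hold for the model definitions in this repo.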
for name, child in module.named_children(): 37 | if name == 'temb': 38 | self.tib = QuantTemporalInformationBlockDDIM(child, aq_params, self.model.ch) 39 | elif name == 'time_embed': 40 | self.tib = QuantTemporalInformationBlock(child, aq_params, self.model.model_channels, None) 41 | elif isinstance(child, QuantResBlock): 42 | self.tib.add_emb_layer(child.emb_layers) 43 | elif isinstance(child, QuantResnetBlock): 44 | self.tib.add_temb_proj(child.temb_proj) 45 | else: 46 | self.get_tib(child, wq_params, aq_params) 47 | 48 | 49 | def quant_module(self, 50 | module: nn.Module, 51 | wq_params: dict = {}, 52 | aq_params: dict = {}, 53 | aq_mode: List[int] = [QMODE.NORMAL.value], 54 | prev_name: str = None, 55 | ) -> None: 56 | for name, child in module.named_children(): 57 | if isinstance(child, tuple(QuantLayer.QMAP.keys())) and \ 58 | 'skip' not in name and 'op' not in name and not (prev_name == 'downsample' and name == 'conv') and 'shortcut' not in name: # refer to PTQD 59 | if prev_name is not None and 'emb_layers' in prev_name and '1' in name or 'temb_proj' in name: 60 | setattr(module, name, QuantLayer(child, wq_params, aq_params, aq_mode=aq_mode, quant_emb=True)) 61 | continue 62 | setattr(module, name, QuantLayer(child, wq_params, aq_params, aq_mode=aq_mode)) 63 | elif isinstance(child, StraightThrough): 64 | continue 65 | else: 66 | self.quant_module(child, wq_params, aq_params, aq_mode=aq_mode, prev_name=name) 67 | 68 | def quant_block(self, 69 | module: nn.Module, 70 | wq_params: dict = {}, 71 | aq_params: dict = {}, 72 | ) -> None: 73 | for name, child in module.named_children(): 74 | if child.__class__.__name__ in self.B: 75 | if self.B[child.__class__.__name__] in [QuantBasicTransformerBlock, QuantAttnBlock]: 76 | setattr(module, name, self.B[child.__class__.__name__](child, aq_params, softmax_a_bit = self.softmax_a_bit)) 77 | elif self.B[child.__class__.__name__] in [QuantResnetBlock, QuantAttentionBlock, QuantResBlock]: 78 | setattr(module, name, self.B[child.__class__.__name__](child, aq_params)) 79 | elif self.B[child.__class__.__name__] in [QuantSMVMatMul]: 80 | setattr(module, name, self.B[child.__class__.__name__](aq_params, softmax_a_bit = self.softmax_a_bit)) 81 | elif self.B[child.__class__.__name__] in [QuantQKMatMul]: 82 | setattr(module, name, self.B[child.__class__.__name__](aq_params)) 83 | else: 84 | self.quant_block(child, wq_params, aq_params) 85 | 86 | def set_quant_state(self, 87 | use_wq: bool = False, 88 | use_aq: bool = False 89 | ) -> None: 90 | for m in self.model.modules(): 91 | if isinstance(m, (BaseQuantBlock, QuantLayer)): 92 | m.set_quant_state(use_wq=use_wq, use_aq=use_aq) 93 | 94 | def forward(self, 95 | x: torch.Tensor, 96 | timestep: int = None, 97 | context: torch.Tensor = None, 98 | ) -> torch.Tensor: 99 | if context is None: 100 | return self.model(x, timestep) 101 | return self.model(x, timestep, context) 102 | 103 | def disable_out_quantization(self) -> None: 104 | modules = [] 105 | for m in self.model.modules(): 106 | if isinstance(m, QuantLayer): 107 | modules.append(m) 108 | modules: List[QuantLayer] 109 | # disable the last layer and the first layer 110 | modules[-1].use_wq = False 111 | modules[-1].disable_aq = True 112 | modules[0].disable_aq = True 113 | modules[0].use_wq = False 114 | modules[1].disable_aq = True 115 | modules[2].disable_aq = True 116 | modules[2].use_wq = False 117 | modules[3].disable_aq = True 118 | modules[0].ignore_recon = True 119 | modules[2].ignore_recon = True 120 | modules[-1].ignore_recon = True 121 | 122 
| def set_grad_ckpt(self, grad_ckpt: bool) -> None: 123 | for _, module in self.model.named_modules(): 124 | if isinstance(module, (QuantBasicTransformerBlock, BasicTransformerBlock)): 125 | module.checkpoint = grad_ckpt 126 | 127 | def synchorize_activation_statistics(self): 128 | import linklink.dist_helper as dist 129 | for module in self.modules(): 130 | if isinstance(module, QuantLayer): 131 | if module.aqtizer.delta is not None: 132 | dist.allaverage(module.aqtizer.delta) 133 | 134 | 135 | def set_running_stat(self, 136 | running_stat: bool = False 137 | ) -> None: 138 | for m in self.model.modules(): 139 | if isinstance(m, QuantBasicTransformerBlock): 140 | m.attn1.aqtizer_q.running_stat = running_stat 141 | m.attn1.aqtizer_k.running_stat = running_stat 142 | m.attn1.aqtizer_v.running_stat = running_stat 143 | m.attn1.aqtizer_w.running_stat = running_stat 144 | m.attn2.aqtizer_q.running_stat = running_stat 145 | m.attn2.aqtizer_k.running_stat = running_stat 146 | m.attn2.aqtizer_v.running_stat = running_stat 147 | m.attn2.aqtizer_w.running_stat = running_stat 148 | elif isinstance(m, QuantQKMatMul): 149 | m.aqtizer_q.running_stat = running_stat 150 | m.aqtizer_k.running_stat = running_stat 151 | elif isinstance(m, QuantSMVMatMul): 152 | m.aqtizer_v.running_stat = running_stat 153 | m.aqtizer_w.running_stat = running_stat 154 | elif isinstance(m, QuantAttnBlock): 155 | m.aqtizer_q.running_stat = running_stat 156 | m.aqtizer_k.running_stat = running_stat 157 | m.aqtizer_v.running_stat = running_stat 158 | m.aqtizer_w.running_stat = running_stat 159 | elif isinstance(m, QuantLayer): 160 | m.set_running_stat(running_stat) 161 | 162 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --index-url https://download.pytorch.org/whl/nightly/cu121 2 | --pre 3 | torch 4 | torchvision 5 | torchaudio 6 | transformers==4.31.0 7 | lmdb -------------------------------------------------------------------------------- /sample_diffusion_ddim.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import os 5 | from pytorch_lightning import seed_everything 6 | 7 | import yaml 8 | 9 | from ddim.runners.diffusion import Diffusion 10 | from quant.quant_layer import QMODE 11 | 12 | 13 | def get_parser(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--config", type=str, required=True, help="Path to the config file" 17 | ) 18 | parser.add_argument("--seed", type=int, default=1234, help="Random seed") 19 | parser.add_argument( 20 | "-l", 21 | "--logdir", 22 | type=str, 23 | nargs="?", 24 | help="extra logdir", 25 | default="none" 26 | ) 27 | parser.add_argument("--use_pretrained", action="store_true") 28 | parser.add_argument( 29 | "--sample_type", 30 | type=str, 31 | default="generalized", 32 | help="sampling approach (generalized or ddpm_noisy)", 33 | ) 34 | parser.add_argument( 35 | "--skip_type", 36 | type=str, 37 | default="uniform", 38 | help="skip according to (uniform or quadratic)", 39 | ) 40 | parser.add_argument( 41 | "--timesteps", type=int, default=1000, help="number of steps involved" 42 | ) 43 | parser.add_argument( 44 | "--eta", 45 | type=float, 46 | default=0.0, 47 | help="eta used to control the variances of sigma", 48 | ) 49 | parser.add_argument("--sequence", action="store_true") 50 | 51 | # quantization configs 52 | parser.add_argument( 53 
| "--ptq", action="store_true", help="apply post-training quantization" 54 | ) 55 | parser.add_argument( 56 | "--wq", 57 | type=int, 58 | default=8, 59 | help="int bit for weight quantization", 60 | ) 61 | parser.add_argument( 62 | "--aq", 63 | type=int, 64 | default=8, 65 | help="int bit for activation quantization", 66 | ) 67 | parser.add_argument( 68 | "--max_images", type=int, default=50000, help="number of images to sample" 69 | ) 70 | 71 | # qdiff specific configs 72 | parser.add_argument( 73 | "--cali_ckpt", type=str, 74 | help="path for calibrated model ckpt" 75 | ) 76 | parser.add_argument( 77 | "--softmax_a_bit",type=int, default=8, 78 | help="attn softmax activation bit" 79 | ) 80 | parser.add_argument( 81 | "--verbose", action="store_true", 82 | help="print out info like quantized model arch" 83 | ) 84 | 85 | parser.add_argument( 86 | "--cali", 87 | action="store_true", 88 | help="whether to calibrate the model" 89 | ) 90 | parser.add_argument( 91 | "--cali_save_path", 92 | type=str, 93 | default="cali_ckpt/quant_ddim.pth", 94 | help="path to save the calibrated ckpt" 95 | ) 96 | parser.add_argument( 97 | "--interval_length", 98 | type=int, 99 | default=1, 100 | help="calibration interval length" 101 | ) 102 | parser.add_argument( 103 | '--use_aq', 104 | action='store_true', 105 | help='whether to use activation quantization' 106 | ) 107 | return parser 108 | 109 | 110 | def dict2namespace(config): 111 | namespace = argparse.Namespace() 112 | for key, value in config.items(): 113 | if isinstance(value, dict): 114 | new_value = dict2namespace(value) 115 | else: 116 | new_value = value 117 | setattr(namespace, key, new_value) 118 | return namespace 119 | 120 | 121 | if __name__ == '__main__': 122 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 123 | 124 | parser = get_parser() 125 | args = parser.parse_args() 126 | with open(args.config, "r") as f: 127 | config = yaml.safe_load(f) 128 | config = dict2namespace(config) 129 | 130 | # fix random seed 131 | seed_everything(args.seed) 132 | 133 | # setup logger 134 | logdir = os.path.join(args.logdir, "samples", now) 135 | os.makedirs(logdir) 136 | log_path = os.path.join(logdir, "run.log") 137 | logging.basicConfig( 138 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 139 | datefmt='%m/%d/%Y %H:%M:%S', 140 | level=logging.INFO, 141 | handlers=[ 142 | logging.FileHandler(log_path), 143 | logging.StreamHandler() 144 | ] 145 | ) 146 | logger = logging.getLogger(__name__) 147 | logger.info(75 * "=") 148 | logger.info(f"Host {os.uname()[1]}") 149 | logger.info("logging to:") 150 | imglogdir = os.path.join(logdir, "img") 151 | nplogdir = os.path.join(logdir, "numpy") 152 | os.makedirs(nplogdir) 153 | args.image_folder = imglogdir 154 | args.numpy_folder = nplogdir 155 | 156 | os.makedirs(imglogdir) 157 | logger.info(logdir) 158 | logger.info(75 * "=") 159 | p = [QMODE.NORMAL.value] 160 | p.append(QMODE.QDIFF.value) 161 | args.q_mode = p 162 | args.fid = True 163 | args.log_path = "test/" 164 | args.use_pretrained = True 165 | args.use_aq = args.use_aq 166 | args.asym = True 167 | args.running_stat = True 168 | config.device = 'cuda0' 169 | runner = Diffusion(args, config) 170 | runner.sample() 171 | 172 | -------------------------------------------------------------------------------- /stable-diffusion/assets/a-painting-of-a-fire.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-painting-of-a-fire.png -------------------------------------------------------------------------------- /stable-diffusion/assets/a-photograph-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-photograph-of-a-fire.png -------------------------------------------------------------------------------- /stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png -------------------------------------------------------------------------------- /stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png -------------------------------------------------------------------------------- /stable-diffusion/assets/a-watercolor-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png -------------------------------------------------------------------------------- /stable-diffusion/assets/birdhouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/birdhouse.png -------------------------------------------------------------------------------- /stable-diffusion/assets/fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/fire.png -------------------------------------------------------------------------------- /stable-diffusion/assets/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/inpainting.png -------------------------------------------------------------------------------- /stable-diffusion/assets/modelfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/modelfigure.png -------------------------------------------------------------------------------- /stable-diffusion/assets/rdm-preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/rdm-preview.jpg -------------------------------------------------------------------------------- /stable-diffusion/assets/reconstruction1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/reconstruction1.png -------------------------------------------------------------------------------- /stable-diffusion/assets/reconstruction2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/reconstruction2.png -------------------------------------------------------------------------------- /stable-diffusion/assets/results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/results.gif -------------------------------------------------------------------------------- /stable-diffusion/assets/rick.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/rick.jpeg -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/img2img/mountains-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/mountains-1.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/img2img/mountains-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/mountains-2.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/img2img/mountains-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/mountains-3.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/img2img/upscaling-in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/img2img/upscaling-out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png -------------------------------------------------------------------------------- 
/stable-diffusion/assets/stable-samples/txt2img/000002025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/000002025.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/txt2img/000002035.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/000002035.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/txt2img/merged-0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/txt2img/merged-0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png -------------------------------------------------------------------------------- /stable-diffusion/assets/stable-samples/txt2img/merged-0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png -------------------------------------------------------------------------------- /stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png -------------------------------------------------------------------------------- /stable-diffusion/assets/txt2img-convsample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/txt2img-convsample.png -------------------------------------------------------------------------------- /stable-diffusion/assets/txt2img-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/txt2img-preview.png -------------------------------------------------------------------------------- /stable-diffusion/assets/v1-variants-scores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/assets/v1-variants-scores.jpg -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: 
ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: 
ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
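    # With warm_up_steps [10000], f_start 1.e-6, f_max = f_min = 1. and an effectively
    # infinite cycle length, the LambdaLinearScheduler (see ldm/lr_scheduler.py further
    # below in this dump) ramps the LR multiplier linearly from ~0 to 1.0 over the first
    # 10k steps and then holds it at exactly 1.0, so after warm-up the configured
    # learning rate is applied unchanged.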
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 
6 | linear_end: 0.015 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /stable-diffusion/configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /stable-diffusion/data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg -------------------------------------------------------------------------------- /stable-diffusion/data/example_conditioning/text_conditional/sample_0.txt: -------------------------------------------------------------------------------- 1 | A basket of cerries 2 | -------------------------------------------------------------------------------- /stable-diffusion/data/imagenet_train_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/imagenet_train_hr_indices.p -------------------------------------------------------------------------------- /stable-diffusion/data/imagenet_val_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/imagenet_val_hr_indices.p -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png -------------------------------------------------------------------------------- 
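The v1-inference.yaml above uses the same target/params convention as the other configs in this dump, and ldm/util.py (reproduced later) resolves it with instantiate_from_config. Below is a minimal sketch of loading the config and building the model from it, assuming the usual OmegaConf workflow; the checkpoint path is a placeholder and the "state_dict" key is an assumption about a Lightning-style .ckpt, not something the config itself guarantees.

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("stable-diffusion/configs/stable-diffusion/v1-inference.yaml")
model = instantiate_from_config(config.model)         # builds LatentDiffusion from target/params
ckpt_path = "path/to/sd-v1.ckpt"                      # placeholder: user-supplied weights
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state.get("state_dict", state), strict=False)
model.eval()

Note that image_size: 64 in this config refers to the latent resolution; with the f=8 KL autoencoder (ch_mult [1,2,4,4], i.e. len(ch_mult)-1 = 3 downsamplings) this corresponds to 512x512 pixel outputs.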
/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/bench2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bench2.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/bench2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bench2_mask.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png -------------------------------------------------------------------------------- 
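The inpainting examples listed here come in image/mask pairs (<name>.png with <name>_mask.png). A minimal sketch of composing one pair for inspection with PIL and numpy (both in environment.yaml); the assumption that the mask is white over the region to be filled is mine and should be checked against the inpainting script before relying on it.

import numpy as np
from PIL import Image

img = np.array(Image.open("stable-diffusion/data/inpainting_examples/bench2.png").convert("RGB")).astype(np.float32) / 255.0
mask = np.array(Image.open("stable-diffusion/data/inpainting_examples/bench2_mask.png").convert("L")).astype(np.float32) / 255.0
masked = img * (1.0 - mask[..., None])                # blank out the (assumed) region to inpaint
Image.fromarray((masked * 255).astype(np.uint8)).save("bench2_masked_preview.png")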
/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png -------------------------------------------------------------------------------- /stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png -------------------------------------------------------------------------------- /stable-diffusion/environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - invisible-watermark 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - transformers==4.19.2 27 | - torchmetrics==0.6.0 28 | - kornia==0.6 29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | - -e . 
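    # A typical setup for the stack above is `conda env create -f environment.yaml`
    # followed by `conda activate ldm`; the three editable installs (taming-transformers,
    # CLIP, and `-e .` for this repo) need network access at environment-creation time.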
32 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/data/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /stable-diffusion/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | 
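        # flip_p defaults to 0. so the validation split is deterministic (no random flips);
        # txt_file / data_root below point at the LSUN church file list expected under data/lsun/.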
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
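        # Schedules implemented by the methods below, for step n within cycle c:
        #   warm-up (n < warm_up_steps[c]):  f(n) = f_start[c] + (f_max[c] - f_start[c]) * n / warm_up_steps[c]
        #   cosine afterwards:               f(n) = f_min[c] + 0.5 * (f_max[c] - f_min[c]) * (1 + cos(pi * t)),
        #                                    with t = (n - warm_up_steps[c]) / (cycle_lengths[c] - warm_up_steps[c]), clipped to 1
        # LambdaLinearScheduler overrides the post-warm-up branch with the linear decay
        #   f(n) = f_min[c] + (f_max[c] - f_min[c]) * (cycle_lengths[c] - n) / cycle_lengths[c].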
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample(self, 23 | S, 24 | batch_size, 25 | shape, 26 | conditioning=None, 27 | callback=None, 28 | normals_sequence=None, 29 | img_callback=None, 30 | quantize_x0=False, 31 | eta=0., 32 | mask=None, 33 | x0=None, 34 | temperature=1., 35 | noise_dropout=0., 36 | 
score_corrector=None, 37 | corrector_kwargs=None, 38 | verbose=True, 39 | x_T=None, 40 | log_every_t=100, 41 | unconditional_guidance_scale=1., 42 | unconditional_conditioning=None, 43 | untill_fake_t: int = None, 44 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 45 | **kwargs 46 | ): 47 | if conditioning is not None: 48 | if isinstance(conditioning, dict): 49 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 50 | if cbs != batch_size: 51 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 52 | else: 53 | if conditioning.shape[0] != batch_size: 54 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 55 | 56 | if not untill_fake_t: 57 | untill_fake_t = untill_fake_t = float('inf') 58 | # sampling 59 | C, H, W = shape 60 | size = (batch_size, C, H, W) 61 | 62 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 63 | 64 | device = self.model.betas.device 65 | if x_T is None: 66 | img = torch.randn(size, device=device) 67 | else: 68 | img = x_T 69 | 70 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 71 | 72 | model_fn = model_wrapper( 73 | lambda x, t, c: self.model.apply_model(x, t, c), 74 | ns, 75 | model_type="noise", 76 | guidance_type="classifier-free", 77 | condition=conditioning, 78 | unconditional_condition=unconditional_conditioning, 79 | guidance_scale=unconditional_guidance_scale, 80 | ) 81 | 82 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 83 | x, vec_t = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, untill_fake_t=untill_fake_t) 84 | 85 | return x.to(device), vec_t 86 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if 
self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | 
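                    # in-place EMA step: shadow <- shadow - (1 - decay) * (shadow - param),
                    # i.e. shadow <- decay * shadow + (1 - decay) * param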
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/TFMQ-DM/c4de88c300a2935cc82d487f2fa3f95ef4253fab/stable-diffusion/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
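# The loss below combines, following the code of LPIPSWithDiscriminator.forward:
#   rec_loss = |x - x_rec| + perceptual_weight * LPIPS(x, x_rec)
#   nll_loss = rec_loss / exp(logvar) + logvar            (logvar is a learned scalar)
#   generator step (optimizer_idx 0):
#       loss = weighted_nll + kl_weight * KL(posterior) + d_weight * disc_factor * g_loss,
#       with g_loss = -mean(D(x_rec)) and
#       d_weight = ||grad_last(nll)|| / (||grad_last(g_loss)|| + 1e-4), clamped to [0, 1e4]
#       and scaled by discriminator_weight; disc_factor gates the GAN terms off until
#       global_step reaches disc_start
#   discriminator step (optimizer_idx 1):
#       d_loss = disc_factor * (hinge or vanilla) loss on D(x) vs D(x_rec)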
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | 
in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | 
disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | 
disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: 
ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: 
main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | 
attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 
56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 
41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | 
resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /stable-diffusion/scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
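
A minimal sketch (not part of the repository) of how one of the first-stage config.yaml files above might be instantiated once the corresponding archive from download_first_stages.sh has been unpacked, mirroring the OmegaConf + instantiate_from_config pattern used in scripts/inpaint.py further below. The checkpoint filename "model.ckpt" inside the unpacked zip and the posterior's .sample() method are assumptions, not confirmed by this dump.

    # Hedged sketch: load a downloaded first-stage autoencoder from its config.yaml.
    # Assumptions: the zip unpacks to model.ckpt, the checkpoint stores weights under
    # "state_dict", and AutoencoderKL.encode returns a distribution with .sample().
    import torch
    from omegaconf import OmegaConf
    from main import instantiate_from_config

    config = OmegaConf.load("models/first_stage_models/kl-f8/config.yaml")
    model = instantiate_from_config(config.model)
    sd = torch.load("models/first_stage_models/kl-f8/model.ckpt", map_location="cpu")["state_dict"]
    model.load_state_dict(sd, strict=False)
    model.eval()

    # Round-trip a dummy image batch scaled to [-1, 1].
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        posterior = model.encode(x)   # assumed to return a Gaussian posterior
        z = posterior.sample()        # latent of shape (1, z_channels, 32, 32) for kl-f8
        x_rec = model.decode(z)
    print(z.shape, x_rec.shape)
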
-------------------------------------------------------------------------------- /stable-diffusion/scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
50 | -------------------------------------------------------------------------------- /stable-diffusion/scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- 
/stable-diffusion/scripts/tests/test_watermark.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import fire 3 | from imwatermark import WatermarkDecoder 4 | 5 | 6 | def testit(img_path): 7 | bgr = cv2.imread(img_path) 8 | decoder = WatermarkDecoder('bytes', 136) 9 | watermark = decoder.decode(bgr, 'dwtDct') 10 | try: 11 | dec = watermark.decode('utf-8') 12 | except: 13 | dec = "null" 14 | print(dec) 15 | 16 | 17 | if __name__ == "__main__": 18 | fire.Fire(testit) -------------------------------------------------------------------------------- /stable-diffusion/scripts/train_searcher.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import scann 4 | import argparse 5 | import glob 6 | from multiprocessing import cpu_count 7 | from tqdm import tqdm 8 | 9 | from ldm.util import parallel_data_prefetch 10 | 11 | 12 | def search_bruteforce(searcher): 13 | return searcher.score_brute_force().build() 14 | 15 | 16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, 17 | partioning_trainsize, num_leaves, num_leaves_to_search): 18 | return searcher.tree(num_leaves=num_leaves, 19 | num_leaves_to_search=num_leaves_to_search, 20 | training_sample_size=partioning_trainsize). \ 21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() 22 | 23 | 24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): 25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( 26 | reorder_k).build() 27 | 28 | def load_datapool(dpath): 29 | 30 | 31 | def load_single_file(saved_embeddings): 32 | compressed = np.load(saved_embeddings) 33 | database = {key: compressed[key] for key in compressed.files} 34 | return database 35 | 36 | def load_multi_files(data_archive): 37 | database = {key: [] for key in data_archive[0].files} 38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): 39 | for key in d.files: 40 | database[key].append(d[key]) 41 | 42 | return database 43 | 44 | print(f'Load saved patch embedding from "{dpath}"') 45 | file_content = glob.glob(os.path.join(dpath, '*.npz')) 46 | 47 | if len(file_content) == 1: 48 | data_pool = load_single_file(file_content[0]) 49 | elif len(file_content) > 1: 50 | data = [np.load(f) for f in file_content] 51 | prefetched_data = parallel_data_prefetch(load_multi_files, data, 52 | n_proc=min(len(data), cpu_count()), target_data_type='dict') 53 | 54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} 55 | else: 56 | raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') 57 | 58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') 59 | return data_pool 60 | 61 | 62 | def train_searcher(opt, 63 | metric='dot_product', 64 | partioning_trainsize=None, 65 | reorder_k=None, 66 | # todo tune 67 | aiq_thld=0.2, 68 | dims_per_block=2, 69 | num_leaves=None, 70 | num_leaves_to_search=None,): 71 | 72 | data_pool = load_datapool(opt.database) 73 | k = opt.knn 74 | 75 | if not reorder_k: 76 | reorder_k = 2 * k 77 | 78 | # normalize 79 | # embeddings = 80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) 81 | pool_size = data_pool['embedding'].shape[0] 
82 | 83 | print(*(['#'] * 100)) 84 | print('Initializing scaNN searcher with the following values:') 85 | print(f'k: {k}') 86 | print(f'metric: {metric}') 87 | print(f'reorder_k: {reorder_k}') 88 | print(f'anisotropic_quantization_threshold: {aiq_thld}') 89 | print(f'dims_per_block: {dims_per_block}') 90 | print(*(['#'] * 100)) 91 | print('Start training searcher....') 92 | print(f'N samples in pool is {pool_size}') 93 | 94 | # this reflects the recommended design choices proposed at 95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md 96 | if pool_size < 2e4: 97 | print('Using brute force search.') 98 | searcher = search_bruteforce(searcher) 99 | elif 2e4 <= pool_size and pool_size < 1e5: 100 | print('Using asymmetric hashing search and reordering.') 101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 102 | else: 103 | print('Using using partioning, asymmetric hashing search and reordering.') 104 | 105 | if not partioning_trainsize: 106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10 107 | if not num_leaves: 108 | num_leaves = int(np.sqrt(pool_size)) 109 | 110 | if not num_leaves_to_search: 111 | num_leaves_to_search = max(num_leaves // 20, 1) 112 | 113 | print('Partitioning params:') 114 | print(f'num_leaves: {num_leaves}') 115 | print(f'num_leaves_to_search: {num_leaves_to_search}') 116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, 118 | partioning_trainsize, num_leaves, num_leaves_to_search) 119 | 120 | print('Finish training searcher') 121 | searcher_savedir = opt.target_path 122 | os.makedirs(searcher_savedir, exist_ok=True) 123 | searcher.serialize(searcher_savedir) 124 | print(f'Saved trained searcher under "{searcher_savedir}"') 125 | 126 | if __name__ == '__main__': 127 | sys.path.append(os.getcwd()) 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--database', 130 | '-d', 131 | default='data/rdm/retrieval_databases/openimages', 132 | type=str, 133 | help='path to folder containing the clip feature of the database') 134 | parser.add_argument('--target_path', 135 | '-t', 136 | default='data/rdm/searchers/openimages', 137 | type=str, 138 | help='path to the target folder where the searcher shall be stored.') 139 | parser.add_argument('--knn', 140 | '-k', 141 | default=20, 142 | type=int, 143 | help='number of nearest neighbors, for which the searcher shall be optimized') 144 | 145 | opt, _ = parser.parse_known_args() 146 | 147 | train_searcher(opt,) -------------------------------------------------------------------------------- /stable-diffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
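
A minimal sketch (not part of the repository) of querying the index that scripts/train_searcher.py serializes with searcher.serialize(). It assumes scann.scann_ops_pybind.load_searcher is available, that query embeddings are L2-normalized the same way as the database embeddings in train_searcher, and uses a hypothetical 512-dimensional embedding size; the target path is the script's default --target_path.

    # Hedged sketch: reload and query the serialized scaNN searcher.
    import numpy as np
    import scann

    searcher = scann.scann_ops_pybind.load_searcher("data/rdm/searchers/openimages")

    queries = np.random.randn(4, 512).astype(np.float32)       # hypothetical embedding dimension
    queries /= np.linalg.norm(queries, axis=1, keepdims=True)   # match the normalization used at build time

    # Returns the indices and scores of the nearest database entries per query.
    neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)
    print(neighbors.shape, distances.shape)                     # (4, 20) each
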