├── README.md ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── cifar10.yml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml └── stable-diffusion │ ├── v1-inference.yaml │ └── v1-inpainting-inference.yaml ├── ddim ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── datasets │ ├── __init__.py │ ├── artcifar10.py │ ├── celeba.py │ ├── ffhq.py │ ├── lsun.py │ ├── utils.py │ └── vision.py ├── dpm_solver_pytorch.py ├── functions │ ├── __init__.py │ ├── ckpt_util.py │ ├── denoising.py │ └── losses.py └── models │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── diffusion.cpython-37.pyc │ ├── diffusion.py │ └── ema.py ├── environment.yml ├── get_calibration_set_imagenet_ddim.py ├── imgs ├── fig1.png ├── imagenet_example.png └── sd_images.png ├── ldm ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── util.cpython-37.pyc ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── autoencoder.cpython-37.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── ddim.cpython-37.pyc │ │ └── ddpm.cpython-37.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── dpm_solver.cpython-37.pyc │ │ │ └── sampler.cpython-37.pyc │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ └── plms.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── attention.cpython-37.pyc │ │ └── ema.cpython-37.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── model.cpython-37.pyc │ │ │ ├── openaimodel.cpython-37.pyc │ │ │ └── util.cpython-37.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── distributions.cpython-37.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ └── x_transformer.py └── util.py ├── models ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml └── ldm │ ├── bsr_sr │ └── config.yaml │ ├── celeba256 │ └── config.yaml │ ├── cin256-v2 │ └── config.yaml │ ├── cin256 │ └── config.yaml │ ├── inpainting_big │ └── config.yaml │ ├── layout2img-openimages256 │ └── config.yaml │ ├── lsun_churches256 │ └── config.yaml │ ├── semantic_synthesis256 │ └── config.yaml │ ├── semantic_synthesis512 │ └── config.yaml │ └── text2img256 │ └── config.yaml ├── qdiff ├── __init__.py ├── adaptive_rounding.py ├── block_recon.py ├── layer_recon.py ├── post_layer_recon_imagenet.py ├── post_layer_recon_sd.py ├── post_layer_recon_uncond.py ├── quant_block.py ├── quant_layer.py ├── quant_model.py └── utils.py ├── sample_diffusion_ldm_bedroom.py ├── sample_diffusion_ldm_church.py ├── sample_diffusion_ldm_imagenet.py └── 
txt2img.py /README.md: -------------------------------------------------------------------------------- 1 | # QuEST 2 | The official repository for **QuEST: Low-bit Diffusion Model Quantization via Efficient Selective Finetuning** [[arXiv]](https://arxiv.org/abs/2402.03666) 3 | 4 | ## Update Log 5 | **(2024.2.28)** Reorganized the code structure. 6 | 7 | **(2024.12.15)** Fixed some inconsistencies in the implementation. 8 | 9 | ## Features 10 | QuEST achieves state-of-the-art performance on multiple high-resolution image generation tasks, including unconditional, class-conditional, and text-to-image generation. We also achieve superior performance on full 4-bit (W4A4) generation. 11 |

12 | [figure] 13 |

14 | On ImageNet 256×256: 15 |

16 | [figure] 17 |

18 | On Stable Diffusion v1.4 (512×512): 19 |

20 | [figure] 21 |

22 | 23 | ## Get Started 24 | ### Prerequisites 25 | Make sure you have conda installed first, then: 26 | ``` 27 | git clone https://github.com/hatchetProject/QuEST.git 28 | cd QuEST 29 | conda env create -f environment.yml 30 | conda activate quest 31 | ``` 32 | ### Usage 33 | 1. For the Latent Diffusion and Stable Diffusion experiments, first download the pretrained model checkpoints following the instructions in the [latent-diffusion](https://github.com/CompVis/latent-diffusion/tree/main) and [stable-diffusion](https://github.com/CompVis/stable-diffusion#weights) repos from CompVis. We currently use sd-v1-4.ckpt for Stable Diffusion. 34 | 2. The calibration data for LSUN-Bedrooms/Churches and Stable Diffusion (COCO) can be downloaded from the [Q-Diffusion](https://github.com/Xiuyu-Li/q-diffusion/tree/master) repository. We will upload the calibration data for ImageNet soon. 35 | 3. Use the following commands to reproduce the models. `act_bit=4` additionally uses channel-wise quantization on a more hardware-friendly dimension, which reduces computation cost (see the illustrative sketch of per-channel quantization after the Evaluation section below). Also, exclude the `--running_stat` argument for W4A4 quantization. 36 | 4. Change line 151 in ldm/models/diffusion/ddim.py according to the number of timesteps you use: replace '10' with the result of `-c // --cali_st`, e.g. 200 // 20 = 10. Comment out lines 149~158 if you would like to run inference with the FP model. 37 | 5. We highly recommend using the checkpoints provided by Q-Diffusion and resuming them with the `--resume_w` argument. 38 | ``` 39 | # LSUN-Bedrooms (LDM-4) 40 | python sample_diffusion_ldm_bedroom.py -r models/ldm/lsun_beds256/model.ckpt -n 100 --batch_size 20 -c 200 -e 1.0 --seed 40 --ptq --weight_bit <4 or 8> --quant_mode qdiff --cali_st 20 --cali_batch_size 32 --cali_n 256 --quant_act --act_bit <4 or 8> --a_sym --a_min_max --running_stat --cali_data_path -l 41 | # LSUN-Churches (LDM-8) 42 | python sample_diffusion_ldm_church.py -r models/ldm/lsun_churches256/model.ckpt -n 50000 --batch_size 10 -c 500 -e 0.0 --seed 40 --ptq --weight_bit <4 or 8> --quant_mode qdiff --cali_st 20 --cali_batch_size 32 --cali_n 256 --quant_act --act_bit <4 or 8> --cali_data_path -l 43 | # ImageNet 44 | python sample_diffusion_ldm_imagenet.py -r models/ldm/cin256-v2/model.ckpt -n 50 --batch_size 50 -c 20 -e 1.0 --seed 40 --ptq --weight_bit <4 or 8> --quant_mode qdiff --cali_st 20 --cali_batch_size 32 --cali_n 256 --quant_act --act_bit <4 or 8> --a_sym --a_min_max --running_stat --cond --cali_data_path -l 45 | # Stable Diffusion 46 | python txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --cond --ptq --weight_bit <4 or 8> --quant_mode qdiff --quant_act --act_bit <4 or 8> --cali_st 25 --cali_batch_size 8 --cali_n 128 --no_grad_ckpt --split --running_stat --sm_abit 16 --cali_data_path --outdir 47 | ``` 48 | 49 | ### Calibration dataset 50 | We will release the calibration data, but you can also generate it yourself with the following command (10 images per class over all timesteps): 51 | ``` 52 | python get_calibration_set_imagenet_ddim.py -r -n 10 --batch_size 10 -c 20 -e 1.0 --seed 40 -l output/ --cond 53 | ``` 54 | 55 | ### Evaluation 56 | We use ADM's TensorFlow evaluation suite [link](https://github.com/openai/guided-diffusion/tree/main/evaluations) to evaluate FID, sFID and IS. For Stable Diffusion, we generate 10,000 samples based on prompts from the COCO2014 dataset and calculate the average CLIP score (a sketch of one way to compute this is given below).
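The repository tree above does not include a CLIP-score script, so the following is only a minimal sketch of one common way to compute an average CLIP score over (image, prompt) pairs, using the Hugging Face `transformers` CLIP implementation. The model choice (`openai/clip-vit-base-patch32`), the scaling by 100, and the clamping at zero follow the usual CLIP-score convention and are assumptions here, not necessarily the exact pipeline used for the paper's numbers.

```python
# Hedged sketch: average CLIP score over generated images and their prompts.
# Assumes `pip install torch transformers pillow`; the model name is an assumption.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.no_grad()
def average_clip_score(image_paths, prompts):
    scores = []
    for path, prompt in zip(image_paths, prompts):
        inputs = processor(text=[prompt], images=Image.open(path),
                           return_tensors="pt", padding=True, truncation=True)
        img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
        txt_emb = model.get_text_features(input_ids=inputs["input_ids"],
                                          attention_mask=inputs["attention_mask"])
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
        # CLIP score: scaled cosine similarity, clamped at zero.
        scores.append(100.0 * (img_emb * txt_emb).sum().clamp(min=0.0).item())
    return sum(scores) / len(scores)
```

Batching the images and prompts would be faster; the per-pair loop above is kept only for clarity.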
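Step 3 of the Usage section refers to channel-wise activation quantization. QuEST's actual quantizers live in `qdiff/quant_layer.py` and `qdiff/quant_block.py`; the snippet below is only a self-contained illustration, assuming simple per-channel min-max calibration, of what per-channel uniform fake-quantization does: each channel gets its own `(scale, zero_point)` pair instead of a single pair for the whole tensor.

```python
# Illustrative sketch only -- not the qdiff implementation.
import torch

def fake_quant_per_channel(x: torch.Tensor, n_bits: int = 4, channel_dim: int = 1):
    """Uniform asymmetric fake-quantization with one (scale, zero_point) per channel."""
    qmin, qmax = 0, 2 ** n_bits - 1
    # Put the channel axis first and flatten the remaining dimensions.
    x_perm = x.transpose(0, channel_dim).contiguous()
    flat = x_perm.view(x_perm.size(0), -1)
    # Min-max calibration per channel (the repo calibrates on real activations instead).
    x_min, x_max = flat.min(dim=1).values, flat.max(dim=1).values
    scale = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)
    zero_point = (-x_min / scale).round()
    # Quantize, clamp to the n-bit grid, then dequantize back to floats.
    q = torch.clamp((flat / scale[:, None]).round() + zero_point[:, None], qmin, qmax)
    dq = (q - zero_point[:, None]) * scale[:, None]
    return dq.view_as(x_perm).transpose(0, channel_dim).contiguous()

# Example: 4-bit fake-quantization of an (N, C, H, W) activation map along C.
act = torch.randn(8, 320, 32, 32)
act_q = fake_quant_per_channel(act, n_bits=4, channel_dim=1)
print("mean abs quantization error:", (act - act_q).abs().mean().item())
```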
57 | 58 | ## Acknowledgement 59 | This project is heavily based on [LDM](https://github.com/CompVis/latent-diffusion/tree/main) and [Q-Diffusion](https://github.com/Xiuyu-Li/q-diffusion/tree/master). 60 | 61 | ## Citation 62 | If you find this work helpful, please consider citing our paper: 63 | ``` 64 | @misc{wang2024quest, 65 | title={QuEST: Low-bit Diffusion Model Quantization via Efficient Selective Finetuning}, 66 | author={Haoxuan Wang and Yuzhang Shang and Zhihang Yuan and Junyi Wu and Yan Yan}, 67 | year={2024}, 68 | eprint={2402.03666}, 69 | archivePrefix={arXiv}, 70 | primaryClass={cs.CV} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_64x64x3.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CIFAR10" 3 | image_size: 32 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | 12 | model: 13 | type: "simple" 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 2, 2] 18 | num_res_blocks: 2 19 | attn_resolutions: [16, ] 20 | dropout: 0.1 21 | var_type: fixedlarge 22 | ema_rate: 0.9999 23 | ema: True 24 | resamp_with_conv: True 25 | 26 | diffusion: 27 | beta_schedule: linear 28 | beta_start: 0.0001 29 | beta_end: 0.02 30 | num_diffusion_timesteps: 1000 31 | 32 | training: 33 | batch_size: 128 34 | n_epochs: 10000 35 | n_iters: 5000000 36 | snapshot_freq: 5000 37 | validation_freq: 2000 38 | 
39 | sampling: 40 | batch_size: 64 41 | last_only: True 42 | 43 | optim: 44 | weight_decay: 0.000 45 | optimizer: "Adam" 46 | lr: 0.0002 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | grad_clip: 1.0 51 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | 
timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | finetune_keys: null 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /ddim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/__init__.py -------------------------------------------------------------------------------- /ddim/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ddim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numbers 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as F 6 | from torchvision.datasets import CIFAR10 7 | from ddim.datasets.artcifar10 import artCIFAR10 8 | from ddim.datasets.celeba import CelebA 9 | from ddim.datasets.ffhq import FFHQ 10 | from ddim.datasets.lsun import LSUN 11 | from torch.utils.data import Subset 12 | import numpy as np 13 | 14 | 15 | class Crop(object): 16 | def __init__(self, x1, x2, y1, y2): 17 | self.x1 = x1 18 | self.x2 = x2 19 | self.y1 = y1 20 | self.y2 = y2 21 | 22 | def __call__(self, img): 23 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 24 | 25 | def __repr__(self): 26 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 27 | self.x1, self.x2, self.y1, self.y2 28 | ) 29 | 30 | 31 | def get_dataset(args, config): 32 | if config.data.random_flip is False: 33 | tran_transform = test_transform = transforms.Compose( 34 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 35 | ) 36 | else: 37 | tran_transform = transforms.Compose( 38 | [ 39 | transforms.Resize(config.data.image_size), 40 | transforms.RandomHorizontalFlip(p=0.5), 41 | transforms.ToTensor(), 42 | ] 43 | ) 44 | test_transform = transforms.Compose( 45 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 46 | ) 47 | 48 | if config.data.dataset == "CIFAR10": 49 | dataset = CIFAR10( 50 | os.path.join(args.dataset, "cifar10"), 51 | train=True, 52 | download=True, 53 | transform=tran_transform, 54 | ) 55 | test_dataset = CIFAR10( 56 | os.path.join(args.dataset, "cifar10_test"), 57 | train=False, 58 | download=True, 59 | 
transform=test_transform, 60 | ) 61 | 62 | elif config.data.dataset == "artCIFAR10": 63 | dataset = artCIFAR10( 64 | os.path.join(args.dataset), 65 | train=True, 66 | download=False, 67 | transform=tran_transform, 68 | ) 69 | test_dataset = artCIFAR10( 70 | os.path.join(args.dataset), 71 | train=False, 72 | download=False, 73 | transform=test_transform, 74 | ) 75 | 76 | elif config.data.dataset == "CELEBA": 77 | cx = 89 78 | cy = 121 79 | x1 = cy - 64 80 | x2 = cy + 64 81 | y1 = cx - 64 82 | y2 = cx + 64 83 | if config.data.random_flip: 84 | dataset = CelebA( 85 | root=os.path.join(args.dataset, "celeba"), 86 | split="train", 87 | transform=transforms.Compose( 88 | [ 89 | Crop(x1, x2, y1, y2), 90 | transforms.Resize(config.data.image_size), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | ] 94 | ), 95 | download=True, 96 | ) 97 | else: 98 | dataset = CelebA( 99 | root=os.path.join(args.dataset, "celeba"), 100 | split="train", 101 | transform=transforms.Compose( 102 | [ 103 | Crop(x1, x2, y1, y2), 104 | transforms.Resize(config.data.image_size), 105 | transforms.ToTensor(), 106 | ] 107 | ), 108 | download=True, 109 | ) 110 | 111 | test_dataset = CelebA( 112 | root=os.path.join(args.dataset, "celeba"), 113 | split="test", 114 | transform=transforms.Compose( 115 | [ 116 | Crop(x1, x2, y1, y2), 117 | transforms.Resize(config.data.image_size), 118 | transforms.ToTensor(), 119 | ] 120 | ), 121 | download=True, 122 | ) 123 | 124 | elif config.data.dataset == "LSUN": 125 | train_folder = "{}_train".format(config.data.category) 126 | val_folder = "{}_val".format(config.data.category) 127 | if config.data.random_flip: 128 | dataset = LSUN( 129 | root=os.path.join(args.dataset, "lsun"), 130 | classes=[train_folder], 131 | transform=transforms.Compose( 132 | [ 133 | transforms.Resize(config.data.image_size), 134 | transforms.CenterCrop(config.data.image_size), 135 | transforms.RandomHorizontalFlip(p=0.5), 136 | transforms.ToTensor(), 137 | ] 138 | ), 139 | ) 140 | else: 141 | dataset = LSUN( 142 | root=os.path.join(args.dataset, "lsun"), 143 | classes=[train_folder], 144 | transform=transforms.Compose( 145 | [ 146 | transforms.Resize(config.data.image_size), 147 | transforms.CenterCrop(config.data.image_size), 148 | transforms.ToTensor(), 149 | ] 150 | ), 151 | ) 152 | 153 | test_dataset = LSUN( 154 | root=os.path.join(args.dataset, "lsun"), 155 | classes=[val_folder], 156 | transform=transforms.Compose( 157 | [ 158 | transforms.Resize(config.data.image_size), 159 | transforms.CenterCrop(config.data.image_size), 160 | transforms.ToTensor(), 161 | ] 162 | ), 163 | ) 164 | 165 | elif config.data.dataset == "FFHQ": 166 | if config.data.random_flip: 167 | dataset = FFHQ( 168 | path=os.path.join(args.dataset, "FFHQ"), 169 | transform=transforms.Compose( 170 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()] 171 | ), 172 | resolution=config.data.image_size, 173 | ) 174 | else: 175 | dataset = FFHQ( 176 | path=os.path.join(args.dataset, "FFHQ"), 177 | transform=transforms.ToTensor(), 178 | resolution=config.data.image_size, 179 | ) 180 | 181 | num_items = len(dataset) 182 | indices = list(range(num_items)) 183 | random_state = np.random.get_state() 184 | np.random.seed(2019) 185 | np.random.shuffle(indices) 186 | np.random.set_state(random_state) 187 | train_indices, test_indices = ( 188 | indices[: int(num_items * 0.9)], 189 | indices[int(num_items * 0.9) :], 190 | ) 191 | test_dataset = Subset(dataset, test_indices) 192 | dataset = Subset(dataset, train_indices) 193 | 
else: 194 | dataset, test_dataset = None, None 195 | 196 | return dataset, test_dataset 197 | 198 | 199 | def logit_transform(image, lam=1e-6): 200 | image = lam + (1 - 2 * lam) * image 201 | return torch.log(image) - torch.log1p(-image) 202 | 203 | 204 | def data_transform(config, X): 205 | if config.data.uniform_dequantization: 206 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 207 | if config.data.gaussian_dequantization: 208 | X = X + torch.randn_like(X) * 0.01 209 | 210 | if config.data.rescaled: 211 | X = 2 * X - 1.0 212 | elif config.data.logit_transform: 213 | X = logit_transform(X) 214 | 215 | if hasattr(config, "image_mean"): 216 | return X - config.image_mean.to(X.device)[None, ...] 217 | 218 | return X 219 | 220 | 221 | def inverse_data_transform(config, X): 222 | if hasattr(config, "image_mean"): 223 | X = X + config.image_mean.to(X.device)[None, ...] 224 | 225 | if config.data.logit_transform: 226 | X = torch.sigmoid(X) 227 | elif config.data.rescaled: 228 | X = (X + 1.0) / 2.0 229 | 230 | return torch.clamp(X, 0.0, 1.0) 231 | -------------------------------------------------------------------------------- /ddim/datasets/artcifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | class artCIFAR10(CIFAR10): 3 | """artCIFAR10 4 | """ 5 | 6 | base_folder = "artcifar-10-batches-py" 7 | url = "https://artcifar.s3.us-east-2.amazonaws.com/artcifar-10-python.tar.gz" 8 | filename = "artcifar-10-python.tar.gz" 9 | tgz_md5 = "a6b71c6e0e3435d34e17896cc83ae1c1" 10 | train_list = [ 11 | ["data_batch_1", "866ea0e474da9d5a033cfdb6adf4a631"], 12 | ["data_batch_2", "bba43dc11082b32fa8119cba55bebddd"], 13 | ["data_batch_3", "4978cbeb187b89e77e31fe39449c95ec"], 14 | ["data_batch_4", "1de1d0514f0c9cd28c5991386fa45f12"], 15 | ["data_batch_5", "81a7899dd79469824b75550bc0be3267"], 16 | ] 17 | 18 | test_list = [ 19 | ["test_batch", "2ce6fb1ee71ffba9cef5df551fcce572"], 20 | ] 21 | meta = { 22 | "filename": "meta", 23 | "key": "styles", 24 | "md5": "6993b54be992f459837552f042419cdf", 25 | } -------------------------------------------------------------------------------- /ddim/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import PIL 4 | from .vision import VisionDataset 5 | from .utils import download_file_from_google_drive, check_integrity 6 | 7 | 8 | class CelebA(VisionDataset): 9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory where images are downloaded to. 13 | split (string): One of {'train', 'valid', 'test'}. 14 | Accordingly dataset is selected. 15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 17 | The targets represent: 18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 19 | ``identity`` (int): label for each person (data points with the same identity are the same person) 20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 23 | Defaults to ``attr``. 
24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | 33 | base_folder = "celeba" 34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 36 | # right now. 37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | target = [] 137 | for t in self.target_type: 138 | if t == "attr": 139 | target.append(self.attr[index, :]) 140 | elif t == "identity": 141 | target.append(self.identity[index, 0]) 142 | elif t == "bbox": 143 | target.append(self.bbox[index, :]) 144 | elif t == "landmarks": 145 | target.append(self.landmarks_align[index, :]) 146 | else: 147 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return X, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | -------------------------------------------------------------------------------- 
/ddim/datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FFHQ(Dataset): 9 | def __init__(self, path, transform, resolution=8): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | target = 0 40 | 41 | return img, target -------------------------------------------------------------------------------- /ddim/datasets/lsun.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import io 6 | from collections.abc import Iterable 7 | import pickle 8 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str 9 | 10 | 11 | class LSUNClass(VisionDataset): 12 | def __init__(self, root, transform=None, target_transform=None): 13 | import lmdb 14 | 15 | super(LSUNClass, self).__init__( 16 | root, transform=transform, target_transform=target_transform 17 | ) 18 | 19 | self.env = lmdb.open( 20 | root, 21 | max_readers=1, 22 | readonly=True, 23 | lock=False, 24 | readahead=False, 25 | meminit=False, 26 | ) 27 | with self.env.begin(write=False) as txn: 28 | self.length = txn.stat()["entries"] 29 | root_split = root.split("/") 30 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") 31 | if os.path.isfile(cache_file): 32 | self.keys = pickle.load(open(cache_file, "rb")) 33 | else: 34 | with self.env.begin(write=False) as txn: 35 | self.keys = [key for key, _ in txn.cursor()] 36 | pickle.dump(self.keys, open(cache_file, "wb")) 37 | 38 | def __getitem__(self, index): 39 | img, target = None, None 40 | env = self.env 41 | with env.begin(write=False) as txn: 42 | imgbuf = txn.get(self.keys[index]) 43 | 44 | buf = io.BytesIO() 45 | buf.write(imgbuf) 46 | buf.seek(0) 47 | img = Image.open(buf).convert("RGB") 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | if self.target_transform is not None: 53 | target = self.target_transform(target) 54 | 55 | return img, target 56 | 57 | def __len__(self): 58 | return self.length 59 | 60 | 61 | class LSUN(VisionDataset): 62 | """ 63 | `LSUN `_ dataset. 64 | 65 | Args: 66 | root (string): Root directory for the database files. 67 | classes (string or list): One of {'train', 'val', 'test'} or a list of 68 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 69 | transform (callable, optional): A function/transform that takes in an PIL image 70 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 71 | target_transform (callable, optional): A function/transform that takes in the 72 | target and transforms it. 
73 | """ 74 | 75 | def __init__(self, root, classes="train", transform=None, target_transform=None): 76 | super(LSUN, self).__init__( 77 | root, transform=transform, target_transform=target_transform 78 | ) 79 | self.classes = self._verify_classes(classes) 80 | 81 | # for each class, create an LSUNClassDataset 82 | self.dbs = [] 83 | for c in self.classes: 84 | self.dbs.append( 85 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) 86 | ) 87 | 88 | self.indices = [] 89 | count = 0 90 | for db in self.dbs: 91 | count += len(db) 92 | self.indices.append(count) 93 | 94 | self.length = count 95 | 96 | def _verify_classes(self, classes): 97 | categories = [ 98 | "bedroom", 99 | "bridge", 100 | "church_outdoor", 101 | "classroom", 102 | "conference_room", 103 | "dining_room", 104 | "kitchen", 105 | "living_room", 106 | "restaurant", 107 | "tower", 108 | ] 109 | dset_opts = ["train", "val", "test"] 110 | 111 | try: 112 | verify_str_arg(classes, "classes", dset_opts) 113 | if classes == "test": 114 | classes = [classes] 115 | else: 116 | classes = [c + "_" + classes for c in categories] 117 | except ValueError: 118 | if not isinstance(classes, Iterable): 119 | msg = ( 120 | "Expected type str or Iterable for argument classes, " 121 | "but got type {}." 122 | ) 123 | raise ValueError(msg.format(type(classes))) 124 | 125 | classes = list(classes) 126 | msg_fmtstr = ( 127 | "Expected type str for elements in argument classes, " 128 | "but got type {}." 129 | ) 130 | for c in classes: 131 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 132 | c_short = c.split("_") 133 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1] 134 | 135 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 136 | msg = msg_fmtstr.format( 137 | category, "LSUN class", iterable_to_str(categories) 138 | ) 139 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 140 | 141 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 142 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 143 | 144 | return classes 145 | 146 | def __getitem__(self, index): 147 | """ 148 | Args: 149 | index (int): Index 150 | 151 | Returns: 152 | tuple: Tuple (image, target) where target is the index of the target category. 
153 | """ 154 | target = 0 155 | sub = 0 156 | for ind in self.indices: 157 | if index < ind: 158 | break 159 | target += 1 160 | sub = ind 161 | 162 | db = self.dbs[target] 163 | index = index - sub 164 | 165 | if self.target_transform is not None: 166 | target = self.target_transform(target) 167 | 168 | img, _ = db[index] 169 | return img, target 170 | 171 | def __len__(self): 172 | return self.length 173 | 174 | def extra_repr(self): 175 | return "Classes: {classes}".format(**self.__dict__) 176 | -------------------------------------------------------------------------------- /ddim/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 
81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /ddim/datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, 'transform') and self.transform is not None: 41 | body += self._format_transform_repr(self.transform, 42 | "Transforms: ") 43 | if hasattr(self, 'target_transform') and self.target_transform is not None: 44 | body += self._format_transform_repr(self.target_transform, 45 | "Target transforms: ") 46 | lines = [head] + [" " * self._repr_indent + line for line in body] 47 | return '\n'.join(lines) 48 | 49 | def _format_transform_repr(self, transform, head): 50 | lines = transform.__repr__().splitlines() 51 | return (["{}{}".format(head, lines[0])] + 52 | ["{}{}".format(" " * len(head), line) for 
line in lines[1:]]) 53 | 54 | def extra_repr(self): 55 | return "" 56 | 57 | 58 | class StandardTransform(object): 59 | def __init__(self, transform=None, target_transform=None): 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def __call__(self, input, target): 64 | if self.transform is not None: 65 | input = self.transform(input) 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | return input, target 69 | 70 | def _format_transform_repr(self, transform, head): 71 | lines = transform.__repr__().splitlines() 72 | return (["{}{}".format(head, lines[0])] + 73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 74 | 75 | def __repr__(self): 76 | body = [self.__class__.__name__] 77 | if self.transform is not None: 78 | body += self._format_transform_repr(self.transform, 79 | "Transform: ") 80 | if self.target_transform is not None: 81 | body += self._format_transform_repr(self.target_transform, 82 | "Target transform: ") 83 | 84 | return '\n'.join(body) 85 | -------------------------------------------------------------------------------- /ddim/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/functions/__init__.py -------------------------------------------------------------------------------- /ddim/functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with 
requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | assert name in URL_MAP 59 | # Modify the path when necessary 60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /ddim/functions/denoising.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_alpha(beta, t): 5 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 6 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 7 | return a 8 | 9 | 10 | def generalized_steps(x, seq, model, b, **kwargs): 11 | with torch.no_grad(): 12 | n = x.size(0) 13 | seq_next = [-1] + list(seq[:-1]) 14 | x0_preds = [] 15 | xs = [x] 16 | for i, j in zip(reversed(seq), reversed(seq_next)): 17 | t = (torch.ones(n) * i).to(x.device) 18 | next_t = (torch.ones(n) * j).to(x.device) 19 | at = compute_alpha(b, t.long()) 20 | at_next = compute_alpha(b, next_t.long()) 21 | xt = xs[-1].to('cuda') 22 | et = model(xt, t) 23 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 24 | x0_preds.append(x0_t.to('cpu')) 25 | c1 = ( 26 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 27 | ) 28 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 29 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 30 | xs.append(xt_next.to('cpu')) 31 | 32 | return xs, x0_preds 33 | 34 | 35 | def ddpm_steps(x, seq, model, b, **kwargs): 36 | with torch.no_grad(): 37 | n = x.size(0) 38 | seq_next = [-1] + list(seq[:-1]) 39 | xs = [x] 40 | x0_preds = [] 41 | betas = b 42 | for i, j in zip(reversed(seq), reversed(seq_next)): 43 | t = (torch.ones(n) * i).to(x.device) 44 | next_t = (torch.ones(n) * j).to(x.device) 45 | at = compute_alpha(betas, t.long()) 46 | atm1 = compute_alpha(betas, next_t.long()) 47 | beta_t = 1 - at / atm1 48 | x = xs[-1].to('cuda') 49 | 50 | output = model(x, t.float()) 51 | e = output 52 | 53 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 54 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 55 | x0_preds.append(x0_from_e.to('cpu')) 56 | mean_eps = ( 57 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 58 | ) / (1.0 - at) 59 | 60 | mean = mean_eps 61 | noise = torch.randn_like(x) 62 | mask = 1 - (t == 0).float() 63 | mask = mask.view(-1, 1, 1, 1) 64 | logvar = beta_t.log() 65 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 66 | xs.append(sample.to('cpu')) 67 | return xs, x0_preds 68 | 
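# ---------------------------------------------------------------------------
# Editor's note: the block below is not part of the original denoising.py. It
# is a minimal, hypothetical sketch of how generalized_steps() above can be
# driven end to end. The linear beta schedule, the 100-step DDIM subsequence,
# the batch shape, and the zero-output _DummyEps stand-in for the diffusion
# UNet are all assumptions for illustration; the repository's sampling scripts
# build these objects from their own configs and checkpoints. A CUDA device is
# required because generalized_steps() moves each intermediate sample to
# 'cuda' before calling the model.
import torch
import torch.nn as nn


class _DummyEps(nn.Module):
    # Stand-in noise predictor; a real run would use the (quantized) UNet.
    def forward(self, x, t):
        return torch.zeros_like(x)


if torch.cuda.is_available():
    num_timesteps = 1000
    betas = torch.linspace(1e-4, 2e-2, num_timesteps).to("cuda")  # assumed linear schedule
    seq = range(0, num_timesteps, num_timesteps // 100)           # 100 DDIM steps
    x_T = torch.randn(4, 3, 64, 64, device="cuda")                # start from Gaussian noise
    model = _DummyEps().to("cuda")

    # eta=0.0 gives deterministic DDIM updates; xs holds every intermediate
    # sample (moved to CPU inside generalized_steps), so the final batch is xs[-1].
    xs, x0_preds = generalized_steps(x_T, seq, model, betas, eta=0.0)
    x_0 = xs[-1]
# ---------------------------------------------------------------------------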
-------------------------------------------------------------------------------- /ddim/functions/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def noise_estimation_loss(model, 5 | x0: torch.Tensor, 6 | t: torch.LongTensor, 7 | e: torch.Tensor, 8 | b: torch.Tensor, keepdim=False): 9 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 10 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt() 11 | output = model(x, t.float()) 12 | if keepdim: 13 | return (e - output).square().sum(dim=(1, 2, 3)) 14 | else: 15 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 16 | 17 | 18 | loss_registry = { 19 | 'simple': noise_estimation_loss, 20 | } 21 | -------------------------------------------------------------------------------- /ddim/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/models/__init__.py -------------------------------------------------------------------------------- /ddim/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ddim/models/__pycache__/diffusion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/models/__pycache__/diffusion.cpython-37.pyc -------------------------------------------------------------------------------- /ddim/models/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: quest 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python=3.8.5 8 | - pip=20.3 9 | - pytorch 10 | - pytorch-cuda=11.7 11 | - torchvision 12 | - torchaudio 13 | - numpy 14 | - pip: 15 | - albumentations==0.4.3 16 | - diffusers==0.3.0 17 | - pudb==2019.2 18 | - invisible-watermark 19 | - imageio==2.9.0 20 | - imageio-ffmpeg==0.4.2 21 | - test-tube>=0.7.5 22 | - streamlit>=0.73.1 23 | - torch-fidelity==0.3.0 24 | - torchmetrics==0.6.0 25 | - streamlit-drawable-canvas==0.8 26 | - einops==0.3.0 27 | - kornia==0.6.9 28 | - lmdb==1.3.0 29 | - natsort==8.3.1 30 | - omegaconf==2.1.1 31 | - opencv_python==4.1.2.30 32 | - opencv_python_headless==4.6.0.66 33 | - pandas==1.4.2 34 | - Pillow==9.0.1 35 | - pytorch_lightning==1.4.2 36 | - PyYAML==6.0 37 | - six==1.16.0 38 | - tqdm==4.64.0 39 | - transformers==4.22.2 40 | - omegaconf 41 | - git+https://github.com/openai/CLIP.git@main#egg=clip 42 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 43 | - -e . 
44 | -------------------------------------------------------------------------------- /imgs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/imgs/fig1.png -------------------------------------------------------------------------------- /imgs/imagenet_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/imgs/imagenet_example.png -------------------------------------------------------------------------------- /imgs/sd_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/imgs/sd_images.png -------------------------------------------------------------------------------- /ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/__init__.py -------------------------------------------------------------------------------- /ldm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 
40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/__init__.py -------------------------------------------------------------------------------- /ldm/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/__pycache__/autoencoder.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__pycache__/ddim.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__pycache__/ddpm.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | 
self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, 
None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] 
= y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample(self, 23 | S, 24 | batch_size, 25 | shape, 26 | conditioning=None, 27 | callback=None, 28 | normals_sequence=None, 29 | img_callback=None, 30 | quantize_x0=False, 31 | eta=0., 32 | mask=None, 33 | x0=None, 34 | temperature=1., 35 | noise_dropout=0., 36 | score_corrector=None, 37 | corrector_kwargs=None, 38 | verbose=True, 39 | x_T=None, 40 | log_every_t=100, 41 | unconditional_guidance_scale=1., 42 | unconditional_conditioning=None, 43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
44 | **kwargs 45 | ): 46 | if conditioning is not None: 47 | if isinstance(conditioning, dict): 48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 49 | if cbs != batch_size: 50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 51 | else: 52 | if conditioning.shape[0] != batch_size: 53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 54 | 55 | # sampling 56 | C, H, W = shape 57 | size = (batch_size, C, H, W) 58 | 59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 60 | 61 | device = self.model.betas.device 62 | if x_T is None: 63 | img = torch.randn(size, device=device) 64 | else: 65 | img = x_T 66 | 67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 68 | 69 | model_fn = model_wrapper( 70 | lambda x, t, c: self.model.apply_model(x, t, c), 71 | ns, 72 | model_type="noise", 73 | guidance_type="classifier-free", 74 | condition=conditioning, 75 | unconditional_condition=unconditional_conditioning, 76 | guidance_scale=unconditional_guidance_scale, 77 | ) 78 | 79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 81 | 82 | return x.to(device), None 83 | -------------------------------------------------------------------------------- /ldm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__pycache__/ema.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / 
math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, 
query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.qk_matmul = CrossQKMatMul(self.scale) 166 | self.smv_matmul = CrossSMVMatMul() 167 | 168 | self.to_out = nn.Sequential( 169 | nn.Linear(inner_dim, query_dim), 170 | nn.Dropout(dropout) 171 | ) 172 | 173 | def forward(self, x, context=None, mask=None): 174 | h = self.heads 175 | 176 | q = self.to_q(x) 177 | context = default(context, x) 178 | k = self.to_k(context) 179 | v = self.to_v(context) 180 | 181 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 182 | 183 | # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 184 | sim = self.qk_matmul(q, k) 185 | 186 | if exists(mask): 187 | mask = rearrange(mask, 'b ... -> b (...)') 188 | max_neg_value = -torch.finfo(sim.dtype).max 189 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 190 | sim.masked_fill_(~mask, max_neg_value) 191 | 192 | # attention, what we cannot get enough of 193 | attn = sim.softmax(dim=-1) 194 | 195 | # out = einsum('b i j, b j d -> b i d', attn, v) 196 | out = self.smv_matmul(attn, v) 197 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 198 | return self.to_out(out) 199 | 200 | 201 | class CrossQKMatMul(nn.Module): 202 | 203 | def __init__(self, scale): 204 | super().__init__() 205 | self.scale = scale 206 | 207 | def forward(self, q, k): 208 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 209 | return sim 210 | 211 | 212 | class CrossSMVMatMul(nn.Module): 213 | 214 | def __init__(self): 215 | super().__init__() 216 | 217 | def forward(self, attn, v): 218 | out = einsum('b i j, b j d -> b i d', attn, v) 219 | return out 220 | 221 | 222 | class BasicTransformerBlock(nn.Module): 223 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 224 | super().__init__() 225 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 226 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 227 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 228 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 229 | self.norm1 = nn.LayerNorm(dim) 230 | self.norm2 = nn.LayerNorm(dim) 231 | self.norm3 = nn.LayerNorm(dim) 232 | self.checkpoint = checkpoint 233 | 234 | def forward(self, x, context=None): 235 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 236 | 237 | def _forward(self, x, context=None): 238 | x = self.attn1(self.norm1(x)) + x 239 | x = self.attn2(self.norm2(x), context=context) + x 240 | x = self.ff(self.norm3(x)) + x 241 | return x 242 | 243 | 244 | class SpatialTransformer(nn.Module): 245 | """ 246 | Transformer block for image-like data. 247 | First, project the input (aka embedding) 248 | and reshape to b, t, d. 249 | Then apply standard transformer action. 
250 | Finally, reshape to image 251 | """ 252 | def __init__(self, in_channels, n_heads, d_head, 253 | depth=1, dropout=0., context_dim=None): 254 | super().__init__() 255 | self.in_channels = in_channels 256 | inner_dim = n_heads * d_head 257 | self.norm = Normalize(in_channels) 258 | 259 | self.proj_in = nn.Conv2d(in_channels, 260 | inner_dim, 261 | kernel_size=1, 262 | stride=1, 263 | padding=0) 264 | 265 | self.transformer_blocks = nn.ModuleList( 266 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 267 | for d in range(depth)] 268 | ) 269 | 270 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 271 | in_channels, 272 | kernel_size=1, 273 | stride=1, 274 | padding=0)) 275 | 276 | def forward(self, x, context=None): 277 | # note: if no context is given, cross-attention defaults to self-attention 278 | b, c, h, w = x.shape 279 | x_in = x 280 | x = self.norm(x) 281 | x = self.proj_in(x) 282 | x = rearrange(x, 'b c h w -> b (h w) c') 283 | for block in self.transformer_blocks: 284 | x = block(x, context) 285 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 286 | x = self.proj_out(x) 287 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # 
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 
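The checkpoint helper above trades compute for memory: with flag set, the wrapped function runs under no_grad in the forward pass and its activations are recomputed during backward. A small usage sketch in the same style as the attention blocks that call it (TinyBlock is an illustrative module, not part of this repository):

import torch
import torch.nn as nn
from ldm.modules.diffusionmodules.util import checkpoint

class TinyBlock(nn.Module):
    def __init__(self, dim=64, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        # parameters are passed explicitly so their gradients survive the
        # no_grad forward inside CheckpointFunction
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        return x + self.net(x)

x = torch.randn(2, 64, requires_grad=True)
out = TinyBlock()(x)
out.sum().backward()   # the activations of self.net are recomputed here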
159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
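timestep_embedding above produces the sinusoidal encoding that the UNet projects through its time-conditioning MLP; the dim argument matches model_channels in the UNet configs further below. A quick shape check with illustrative values:

import torch
from ldm.modules.diffusionmodules.util import timestep_embedding

t = torch.tensor([0, 250, 999])        # one diffusion timestep per batch element
emb = timestep_embedding(t, dim=192)   # e.g. model_channels = 192
print(emb.shape)                       # torch.Size([3, 192])
# the first half of the last dimension holds cosine terms and the second half
# sine terms, with frequencies spaced geometrically down to roughly 1/max_period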
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return 
torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | # 
print(key) 52 | # try: 53 | # print(self.m_name2s_name[key]) 54 | # except: 55 | # print(self.m_name2s_name.keys()) 56 | shadow_key = key.replace('.model.','.') if '.model.' in key else key 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[shadow_key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import kornia 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
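ClassEmbedder above is the conditioning module targeted by the class-conditional ImageNet (cin256) configs further down in this listing (n_classes: 1001, embed_dim: 512, key: class_label): it looks up a learned embedding for each class id and returns it as a length-one token sequence for cross-attention. A quick illustration:

import torch
from ldm.modules.encoders.modules import ClassEmbedder

embedder = ClassEmbedder(embed_dim=512, n_classes=1001, key="class_label")
batch = {"class_label": torch.tensor([207, 360, 387])}   # ImageNet class ids
c = embedder(batch)
print(c.shape)   # torch.Size([3, 1, 512]), one conditioning token per image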
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | class FrozenCLIPEmbedder(AbstractEncoder): 138 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 140 | super().__init__() 141 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 142 | self.transformer = CLIPTextModel.from_pretrained(version) 143 | self.device = device 144 | self.max_length = max_length 145 | self.freeze() 146 | 147 | def freeze(self): 148 | self.transformer = self.transformer.eval() 149 | for param in self.parameters(): 150 | param.requires_grad = False 151 | 152 | def forward(self, text): 153 | batch_encoding = 
self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 155 | tokens = batch_encoding["input_ids"].to(self.device) 156 | outputs = self.transformer(input_ids=tokens) 157 | 158 | z = outputs.last_hidden_state 159 | return z 160 | 161 | def encode(self, text): 162 | return self(text) 163 | 164 | 165 | class FrozenCLIPTextEmbedder(nn.Module): 166 | """ 167 | Uses the CLIP transformer encoder for text. 168 | """ 169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 170 | super().__init__() 171 | self.model, _ = clip.load(version, jit=False, device="cpu") 172 | self.device = device 173 | self.max_length = max_length 174 | self.n_repeat = n_repeat 175 | self.normalize = normalize 176 | 177 | def freeze(self): 178 | self.model = self.model.eval() 179 | for param in self.parameters(): 180 | param.requires_grad = False 181 | 182 | def forward(self, text): 183 | tokens = clip.tokenize(text).to(self.device) 184 | z = self.model.encode_text(tokens) 185 | if self.normalize: 186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 187 | return z 188 | 189 | def encode(self, text): 190 | z = self(text) 191 | if z.ndim==2: 192 | z = z[:, None, :] 193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 194 | return z 195 | 196 | 197 | class FrozenClipImageEmbedder(nn.Module): 198 | """ 199 | Uses the CLIP image encoder. 200 | """ 201 | def __init__( 202 | self, 203 | model, 204 | jit=False, 205 | device='cuda' if torch.cuda.is_available() else 'cpu', 206 | antialias=False, 207 | ): 208 | super().__init__() 209 | self.model, _ = clip.load(name=model, device=device, jit=jit) 210 | 211 | self.antialias = antialias 212 | 213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 215 | 216 | def preprocess(self, x): 217 | # normalize to [0,1] 218 | x = kornia.geometry.resize(x, (224, 224), 219 | interpolation='bicubic',align_corners=True, 220 | antialias=self.antialias) 221 | x = (x + 1.) / 2. 
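# at this point x has been resized to 224x224 with bicubic interpolation and
# mapped from the model's [-1, 1] range into [0, 1]; the next step applies
# CLIP's channel-wise mean/std so the input matches the pretrained image encoder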
222 | # renormalize according to clip 223 | x = kornia.enhance.normalize(x, self.mean, self.std) 224 | return x 225 | 226 | def forward(self, x): 227 | # x is assumed to be in range [-1,1] 228 | return self.model.encode_image(self.preprocess(x)) 229 | 230 | 231 | if __name__ == "__main__": 232 | from ldm.util import count_params 233 | model = FrozenCLIPEmbedder() 234 | count_params(model, verbose=True) -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | import logging 9 | 10 | import multiprocessing as mp 11 | from threading import Thread 12 | from queue import Queue 13 | 14 | from inspect import isfunction 15 | from PIL import Image, ImageDraw, ImageFont 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def log_txt_as_img(wh, xc, size=10): 21 | # wh a tuple of (width, height) 22 | # xc a list of captions to plot 23 | b = len(xc) 24 | txts = list() 25 | for bi in range(b): 26 | txt = Image.new("RGB", wh, color="white") 27 | draw = ImageDraw.Draw(txt) 28 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 29 | nc = int(40 * (wh[0] / 256)) 30 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 31 | 32 | try: 33 | draw.text((0, 0), lines, fill="black", font=font) 34 | except UnicodeEncodeError: 35 | print("Cant encode string for logging. Skipping.") 36 | 37 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 38 | txts.append(txt) 39 | txts = np.stack(txts) 40 | txts = torch.tensor(txts) 41 | return txts 42 | 43 | 44 | def ismap(x): 45 | if not isinstance(x, torch.Tensor): 46 | return False 47 | return (len(x.shape) == 4) and (x.shape[1] > 3) 48 | 49 | 50 | def isimage(x): 51 | if not isinstance(x, torch.Tensor): 52 | return False 53 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 54 | 55 | 56 | def exists(x): 57 | return x is not None 58 | 59 | 60 | def default(val, d): 61 | if exists(val): 62 | return val 63 | return d() if isfunction(d) else d 64 | 65 | 66 | def mean_flat(tensor): 67 | """ 68 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 69 | Take the mean over all non-batch dimensions. 
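FrozenCLIPEmbedder, defined earlier in ldm/modules/encoders/modules.py, is the frozen text encoder used for Stable Diffusion style text conditioning: it tokenizes the prompt with the Hugging Face CLIP tokenizer, runs the frozen CLIP text transformer, and returns the last hidden state as cross-attention context. A minimal sketch (the pretrained weights are fetched from the Hugging Face hub on first use):

import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

encoder = FrozenCLIPEmbedder(device="cpu")   # "cuda" in normal use
with torch.no_grad():
    ctx = encoder.encode(["a photograph of an astronaut riding a horse"])
print(ctx.shape)   # torch.Size([1, 77, 768]), 77 padded tokens of 768-dim hidden states
# ctx is what the UNet receives as `context` in its cross-attention layers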
70 | """ 71 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 72 | 73 | 74 | def count_params(model, verbose=False): 75 | total_params = sum(p.numel() for p in model.parameters()) 76 | if verbose: 77 | logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 78 | return total_params 79 | 80 | 81 | def instantiate_from_config(config): 82 | if not "target" in config: 83 | if config == '__is_first_stage__': 84 | return None 85 | elif config == "__is_unconditional__": 86 | return None 87 | raise KeyError("Expected key `target` to instantiate.") 88 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 89 | 90 | 91 | def get_obj_from_str(string, reload=False): 92 | module, cls = string.rsplit(".", 1) 93 | if reload: 94 | module_imp = importlib.import_module(module) 95 | importlib.reload(module_imp) 96 | return getattr(importlib.import_module(module, package=None), cls) 97 | 98 | 99 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 100 | # create dummy dataset instance 101 | 102 | # run prefetching 103 | if idx_to_fn: 104 | res = func(data, worker_id=idx) 105 | else: 106 | res = func(data) 107 | Q.put([idx, res]) 108 | Q.put("Done") 109 | 110 | 111 | def parallel_data_prefetch( 112 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 113 | ): 114 | # if target_data_type not in ["ndarray", "list"]: 115 | # raise ValueError( 116 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 117 | # ) 118 | if isinstance(data, np.ndarray) and target_data_type == "list": 119 | raise ValueError("list expected but function got ndarray.") 120 | elif isinstance(data, abc.Iterable): 121 | if isinstance(data, dict): 122 | print( 123 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 124 | ) 125 | data = list(data.values()) 126 | if target_data_type == "ndarray": 127 | data = np.asarray(data) 128 | else: 129 | data = list(data) 130 | else: 131 | raise TypeError( 132 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
133 | ) 134 | 135 | if cpu_intensive: 136 | Q = mp.Queue(1000) 137 | proc = mp.Process 138 | else: 139 | Q = Queue(1000) 140 | proc = Thread 141 | # spawn processes 142 | if target_data_type == "ndarray": 143 | arguments = [ 144 | [func, Q, part, i, use_worker_id] 145 | for i, part in enumerate(np.array_split(data, n_proc)) 146 | ] 147 | else: 148 | step = ( 149 | int(len(data) / n_proc + 1) 150 | if len(data) % n_proc != 0 151 | else int(len(data) / n_proc) 152 | ) 153 | arguments = [ 154 | [func, Q, part, i, use_worker_id] 155 | for i, part in enumerate( 156 | [data[i: i + step] for i in range(0, len(data), step)] 157 | ) 158 | ] 159 | processes = [] 160 | for i in range(n_proc): 161 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 162 | processes += [p] 163 | 164 | # start processes 165 | print(f"Start prefetching...") 166 | import time 167 | 168 | start = time.time() 169 | gather_res = [[] for _ in range(n_proc)] 170 | try: 171 | for p in processes: 172 | p.start() 173 | 174 | k = 0 175 | while k < n_proc: 176 | # get result 177 | res = Q.get() 178 | if res == "Done": 179 | k += 1 180 | else: 181 | gather_res[res[0]] = res[1] 182 | 183 | except Exception as e: 184 | print("Exception: ", e) 185 | for p in processes: 186 | p.terminate() 187 | 188 | raise e 189 | finally: 190 | for p in processes: 191 | p.join() 192 | print(f"Prefetching complete. [{time.time() - start} sec.]") 193 | 194 | if target_data_type == 'ndarray': 195 | if not isinstance(gather_res[0], np.ndarray): 196 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 197 | 198 | # order outputs 199 | return np.concatenate(gather_res, axis=0) 200 | elif target_data_type == 'list': 201 | out = [] 202 | for r in gather_res: 203 | out.extend(r) 204 | return out 205 | else: 206 | return gather_res 207 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | 
ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: 
main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 
34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | 
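Every config.yaml in this models/ tree follows the same target/params layout, and instantiate_from_config from ldm/util.py (shown earlier) resolves it recursively into live modules. A sketch of the usual loading pattern, with an OmegaConf-loaded config and an illustrative checkpoint file name (the sampling scripts in this repository follow roughly this recipe):

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("models/ldm/lsun_churches256/config.yaml")
# config.model.target == "ldm.models.diffusion.ddpm.LatentDiffusion"
model = instantiate_from_config(config.model)    # builds the LDM, randomly initialized
# nested unet_config / first_stage_config / cond_stage_config entries are
# resolved the same way inside LatentDiffusion.__init__

state = torch.load("lsun_churches256.ckpt", map_location="cpu")   # illustrative file name
model.load_state_dict(state["state_dict"], strict=False)
model.eval()                                     # move to .cuda() before sampling on GPU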
-------------------------------------------------------------------------------- /models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/cin256-v2/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 
1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /models/ldm/layout2img-openimages256/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | 
lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | 
params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /qdiff/__init__.py: -------------------------------------------------------------------------------- 1 | from qdiff.block_recon import block_reconstruction 2 | from qdiff.layer_recon import layer_reconstruction 3 | from qdiff.quant_block import BaseQuantBlock, QuantSMVMatMul, QuantQKMatMul, QuantBasicTransformerBlock 4 | from qdiff.quant_layer import QuantModule 5 | from qdiff.quant_model import QuantModel -------------------------------------------------------------------------------- /qdiff/adaptive_rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import logging 4 | from qdiff.quant_layer import UniformAffineQuantizer, round_ste, floor_ste 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class AdaRoundQuantizer(nn.Module): 10 | """ 11 | Adaptive Rounding Quantizer, used to optimize the rounding policy 12 | by reconstructing the intermediate output. 13 | Based on 14 | Up or Down? 
Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568 15 | 16 | :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer 17 | :param round_mode: controls the forward pass in this quantizer 18 | :param weight_tensor: initialize alpha 19 | """ 20 | 21 | def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_round_sigmoid'): 22 | super(AdaRoundQuantizer, self).__init__() 23 | # copying all attributes from UniformAffineQuantizer 24 | self.n_bits = uaq.n_bits 25 | self.sym = uaq.sym 26 | self.delta = nn.Parameter(uaq.delta) 27 | self.zero_point = nn.Parameter(uaq.zero_point) 28 | self.n_levels = uaq.n_levels 29 | 30 | self.round_mode = round_mode 31 | self.alpha = None 32 | self.soft_targets = False 33 | 34 | # params for sigmoid function 35 | self.gamma, self.zeta = -0.1, 1.1 36 | self.beta = 2/3 37 | self.init_alpha(x=weight_tensor.clone()) 38 | 39 | def forward(self, x): 40 | if self.round_mode == 'nearest': 41 | x_int = round_ste(x / self.delta) 42 | elif self.round_mode == 'nearest_ste': 43 | x_int = round_ste(x / self.delta) 44 | elif self.round_mode == 'stochastic': 45 | # x_floor = torch.floor(x / self.delta) 46 | x_floor = floor_ste(x / self.delta) 47 | rest = (x / self.delta) - x_floor # rest of rounding 48 | x_int = x_floor + torch.bernoulli(rest) 49 | logger.info('Draw stochastic sample') 50 | elif self.round_mode == 'learned_hard_sigmoid': 51 | # x_floor = torch.floor(x / self.delta) 52 | x_floor = floor_ste(x / self.delta) 53 | if self.soft_targets: 54 | x_int = x_floor + self.get_soft_targets() 55 | else: 56 | x_int = x_floor + (self.alpha >= 0).float() 57 | else: 58 | raise ValueError('Wrong rounding mode') 59 | 60 | x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1) 61 | x_float_q = (x_quant - self.zero_point) * self.delta 62 | 63 | return x_float_q 64 | 65 | def get_soft_targets(self): 66 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) 67 | 68 | def init_alpha(self, x: torch.Tensor): 69 | x_floor = torch.floor(x / self.delta) 70 | if self.round_mode == 'learned_hard_sigmoid': 71 | # logger.info('Init alpha to be FP32') 72 | rest = (x / self.delta) - x_floor # rest of rounding [0, 1) 73 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest 74 | self.alpha = nn.Parameter(alpha) 75 | else: 76 | raise NotImplementedError 77 | 78 | def extra_repr(self): 79 | s = 'bit={n_bits}, symmetric={sym}, round_mode={round_mode}' 80 | return s.format(**self.__dict__) 81 | -------------------------------------------------------------------------------- /qdiff/block_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import linklink as link 3 | import logging 4 | from qdiff.quant_layer import QuantModule, StraightThrough, lp_loss 5 | from qdiff.quant_model import QuantModel 6 | from qdiff.quant_block import BaseQuantBlock 7 | from qdiff.adaptive_rounding import AdaRoundQuantizer 8 | from qdiff.utils import save_grad_data, save_inp_oup_data 9 | import os 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def block_reconstruction(model: QuantModel, block: BaseQuantBlock, cali_data: torch.Tensor, 15 | batch_size: int = 32, iters: int = 20000, weight: float = 0.01, opt_mode: str = 'mse', 16 | asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2), 17 | warmup: float = 0.0, act_quant: bool = 
False, lr: float = 4e-5, p: float = 2.0, 18 | multi_gpu: bool = False, cond: bool = False, is_sm: bool = False, outpath: str = None): 19 | """ 20 | Block reconstruction to optimize the output from each block. 21 | 22 | :param model: QuantModel 23 | :param block: BaseQuantBlock that needs to be optimized 24 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound 25 | :param batch_size: mini-batch size for reconstruction 26 | :param iters: optimization iterations for reconstruction, 27 | :param weight: the weight of rounding regularization term 28 | :param opt_mode: optimization mode 29 | :param asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output 30 | :param include_act_func: optimize the output after activation function 31 | :param b_range: temperature range 32 | :param warmup: proportion of iterations that no scheduling for temperature 33 | :param act_quant: use activation quantization or not. 34 | :param lr: learning rate for act delta learning 35 | :param p: L_p norm minimization 36 | :param multi_gpu: use multi-GPU or not, if enabled, we should sync the gradients 37 | :param cond: conditional generation or not 38 | :param is_sm: avoid OOM when caching n^2 attention matrix when n is large 39 | """ 40 | model.set_quant_state(False, False) 41 | block.set_quant_state(True, act_quant) 42 | round_mode = 'learned_hard_sigmoid' 43 | 44 | if not include_act_func: 45 | org_act_func = block.activation_function 46 | block.activation_function = StraightThrough() 47 | 48 | if not act_quant: 49 | # Replace weight quantizer to AdaRoundQuantizer 50 | for name, module in block.named_modules(): 51 | if isinstance(module, QuantModule): 52 | if module.split != 0: 53 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode, 54 | weight_tensor=module.org_weight.data[:, :module.split, ...]) 55 | module.weight_quantizer_0 = AdaRoundQuantizer(uaq=module.weight_quantizer_0, round_mode=round_mode, 56 | weight_tensor=module.org_weight.data[:, module.split:, ...]) 57 | else: 58 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode, 59 | weight_tensor=module.org_weight.data) 60 | module.weight_quantizer.soft_targets = True 61 | if module.split != 0: 62 | module.weight_quantizer_0.soft_targets = True 63 | 64 | # Set up optimizer 65 | opt_params = [] 66 | for name, module in block.named_modules(): 67 | if isinstance(module, QuantModule): 68 | opt_params += [module.weight_quantizer.alpha] 69 | if module.split != 0: 70 | opt_params += [module.weight_quantizer_0.alpha] 71 | optimizer = torch.optim.Adam(opt_params) 72 | scheduler = None 73 | else: 74 | # Use UniformAffineQuantizer to learn delta 75 | if hasattr(block.act_quantizer, 'delta') and block.act_quantizer.delta is not None: 76 | opt_params = [block.act_quantizer.delta] 77 | else: 78 | opt_params = [] 79 | 80 | if hasattr(block, 'attn1'): 81 | opt_params += [ 82 | block.attn1.act_quantizer_q.delta, 83 | block.attn1.act_quantizer_k.delta, 84 | block.attn1.act_quantizer_v.delta, 85 | block.attn2.act_quantizer_q.delta, 86 | block.attn2.act_quantizer_k.delta, 87 | block.attn2.act_quantizer_v.delta] 88 | if block.attn1.act_quantizer_w.n_bits != 16: 89 | opt_params += [block.attn1.act_quantizer_w.delta] 90 | if block.attn2.act_quantizer_w.n_bits != 16: 91 | opt_params += [block.attn2.act_quantizer_w.delta] 92 | if hasattr(block, 'act_quantizer_q'): 93 | opt_params += [ 94 | block.act_quantizer_q.delta, 95 | 
block.act_quantizer_k.delta] 96 | if hasattr(block, 'act_quantizer_w'): 97 | opt_params += [block.act_quantizer_v.delta] 98 | if block.act_quantizer_w.n_bits != 16: 99 | opt_params += [block.act_quantizer_w.delta] 100 | 101 | for name, module in block.named_modules(): 102 | if isinstance(module, QuantModule): 103 | if module.act_quantizer.delta is not None: 104 | opt_params += [module.act_quantizer.delta] 105 | if module.split != 0 and module.act_quantizer_0.delta is not None: 106 | opt_params += [module.act_quantizer_0.delta] 107 | 108 | optimizer = torch.optim.Adam(opt_params, lr=lr) 109 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.) 110 | 111 | loss_mode = 'none' if act_quant else 'relaxation' 112 | rec_loss = opt_mode 113 | 114 | loss_func = LossFunction(block, round_loss=loss_mode, weight=weight, max_count=iters, rec_loss=rec_loss, 115 | b_range=b_range, decay_start=0, warmup=warmup, p=p) 116 | 117 | # Save data before optimizing the rounding 118 | print(f"cond {cond}") 119 | num_split = 10 120 | b_size = cali_data[0].shape[0] // num_split 121 | for k in range(num_split): 122 | logger.info(f"Saving {num_split} intermediate results to disk to avoid OOM") 123 | if cond: 124 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size], cali_data[2][k*b_size:(k+1)*b_size]) 125 | else: 126 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size]) 127 | cached_inps, cached_outs = save_inp_oup_data( 128 | model, block, cali_data_t, asym, act_quant, batch_size=8, keep_gpu=False, cond=cond, is_sm=is_sm) 129 | cached_path = os.path.join(outpath, 'tmp_cached/') 130 | if not os.path.exists(cached_path): 131 | os.makedirs(cached_path) 132 | torch.save(cached_inps, os.path.join(cached_path, f'cached_inps_t{k}.pt')) 133 | torch.save(cached_outs, os.path.join(cached_path, f'cached_outs_t{k}.pt')) 134 | # cached_inps, cached_outs = save_inp_oup_data( 135 | # model, block, cali_data, asym, act_quant, 8, keep_gpu=False, cond=cond, is_sm=is_sm) 136 | 137 | if opt_mode != 'mse': 138 | cached_grads = save_grad_data(model, block, cali_data, act_quant, batch_size=batch_size) 139 | else: 140 | cached_grads = None 141 | device = 'cuda' 142 | 143 | for k in range(num_split): 144 | cached_inps = torch.load(os.path.join(cached_path, f'cached_inps_t{k}.pt')) 145 | cached_outs = torch.load(os.path.join(cached_path, f'cached_outs_t{k}.pt')) 146 | for i in range(iters // num_split): 147 | if isinstance(cached_inps, list): 148 | idx = torch.randperm(cached_inps[0].size(0))[:batch_size] 149 | cur_x = cached_inps[0][idx].to(device) 150 | cur_t = cached_inps[1][idx].to(device) 151 | cur_inp = (cur_x, cur_t) 152 | else: 153 | idx = torch.randperm(cached_inps.size(0))[:batch_size] 154 | cur_inp = cached_inps[idx].to(device) 155 | cur_out = cached_outs[idx].to(device) 156 | cur_grad = cached_grads[idx].to(device) if opt_mode != 'mse' else None 157 | 158 | optimizer.zero_grad() 159 | if isinstance(cur_inp, tuple): 160 | out_quant = block(cur_inp[0], cur_inp[1]) 161 | else: 162 | out_quant = block(cur_inp) 163 | 164 | err = loss_func(out_quant, cur_out, cur_grad) 165 | err.backward(retain_graph=True) 166 | if multi_gpu: 167 | raise NotImplementedError 168 | # for p in opt_params: 169 | # link.allreduce(p.grad) 170 | optimizer.step() 171 | if scheduler: 172 | scheduler.step() 173 | torch.cuda.empty_cache() 174 | 175 | # Finish optimization, use hard rounding. 
176 | for name, module in block.named_modules(): 177 | if isinstance(module, QuantModule): 178 | module.weight_quantizer.soft_targets = False 179 | if module.split != 0: 180 | module.weight_quantizer_0.soft_targets = False 181 | 182 | # Reset original activation function 183 | if not include_act_func: 184 | block.activation_function = org_act_func 185 | 186 | 187 | class LossFunction: 188 | def __init__(self, 189 | block: BaseQuantBlock, 190 | round_loss: str = 'relaxation', 191 | weight: float = 1., 192 | rec_loss: str = 'mse', 193 | max_count: int = 2000, 194 | b_range: tuple = (10, 2), 195 | decay_start: float = 0.0, 196 | warmup: float = 0.0, 197 | p: float = 2.): 198 | 199 | self.block = block 200 | self.round_loss = round_loss 201 | self.weight = weight 202 | self.rec_loss = rec_loss 203 | self.loss_start = max_count * warmup 204 | self.p = p 205 | 206 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start, 207 | start_b=b_range[0], end_b=b_range[1]) 208 | self.count = 0 209 | 210 | def __call__(self, pred, tgt, grad=None): 211 | """ 212 | Compute the total loss for adaptive rounding: 213 | rec_loss is the quadratic output reconstruction loss, round_loss is 214 | a regularization term to optimize the rounding policy 215 | 216 | :param pred: output from quantized model 217 | :param tgt: output from FP model 218 | :param grad: gradients to compute fisher information 219 | :return: total loss function 220 | """ 221 | self.count += 1 222 | if self.rec_loss == 'mse': 223 | rec_loss = lp_loss(pred, tgt, p=self.p) 224 | elif self.rec_loss == 'fisher_diag': 225 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() 226 | elif self.rec_loss == 'fisher_full': 227 | a = (pred - tgt).abs() 228 | grad = grad.abs() 229 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) 230 | rec_loss = (batch_dotprod * a * grad).mean() / 100 231 | else: 232 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss)) 233 | 234 | b = self.temp_decay(self.count) 235 | if self.count < self.loss_start or self.round_loss == 'none': 236 | b = round_loss = 0 237 | elif self.round_loss == 'relaxation': 238 | round_loss = 0 239 | for name, module in self.block.named_modules(): 240 | if isinstance(module, QuantModule): 241 | round_vals = module.weight_quantizer.get_soft_targets() 242 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() 243 | else: 244 | raise NotImplementedError 245 | 246 | total_loss = rec_loss + round_loss 247 | if self.count % 500 == 0: 248 | logger.info('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format( 249 | float(total_loss), float(rec_loss), float(round_loss), b, self.count)) 250 | return total_loss 251 | 252 | 253 | class LinearTempDecay: 254 | def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2): 255 | self.t_max = t_max 256 | self.start_decay = rel_start_decay * t_max 257 | self.start_b = start_b 258 | self.end_b = end_b 259 | 260 | def __call__(self, t): 261 | """ 262 | Cosine annealing scheduler for temperature b. 
263 | :param t: the current time step 264 | :return: scheduled temperature 265 | """ 266 | if t < self.start_decay: 267 | return self.start_b 268 | else: 269 | rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) 270 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) 271 | -------------------------------------------------------------------------------- /qdiff/layer_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import linklink as link 3 | import logging 4 | from qdiff.quant_layer import QuantModule, StraightThrough, lp_loss 5 | from qdiff.quant_model import QuantModel 6 | from qdiff.block_recon import LinearTempDecay 7 | from qdiff.adaptive_rounding import AdaRoundQuantizer 8 | from qdiff.utils import save_grad_data, save_inp_oup_data 9 | import os 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def layer_reconstruction(model: QuantModel, layer: QuantModule, cali_data: torch.Tensor, 15 | batch_size: int = 32, iters: int = 20000, weight: float = 0.001, opt_mode: str = 'mse', 16 | asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2), 17 | warmup: float = 0.0, act_quant: bool = False, lr: float = 4e-5, p: float = 2.0, 18 | multi_gpu: bool = False, cond: bool = False, is_sm: bool = False, outpath: str = None): 19 | """ 20 | Layer reconstruction to optimize the output from a single layer. 21 | 22 | :param model: QuantModel 23 | :param layer: QuantModule that needs to be optimized 24 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound 25 | :param batch_size: mini-batch size for reconstruction 26 | :param iters: optimization iterations for reconstruction 27 | :param weight: the weight of the rounding regularization term 28 | :param opt_mode: optimization mode 29 | :param asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output 30 | :param include_act_func: optimize the output after the activation function 31 | :param b_range: temperature range 32 | :param warmup: proportion of iterations without temperature scheduling 33 | :param act_quant: use activation quantization or not. 
34 | :param lr: learning rate for act delta learning 35 | :param p: L_p norm minimization 36 | :param multi_gpu: use multi-GPU or not, if enabled, we should sync the gradients 37 | :param cond: conditional generation or not 38 | :param is_sm: avoid OOM when caching n^2 attention matrix when n is large 39 | """ 40 | 41 | model.set_quant_state(False, False) 42 | layer.set_quant_state(True, act_quant) 43 | round_mode = 'learned_hard_sigmoid' 44 | 45 | if not include_act_func: 46 | org_act_func = layer.activation_function 47 | layer.activation_function = StraightThrough() 48 | 49 | if not act_quant: 50 | # Replace weight quantizer to AdaRoundQuantizer 51 | if layer.split != 0: 52 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode, 53 | weight_tensor=layer.org_weight.data[:, :layer.split, ...]) 54 | layer.weight_quantizer_0 = AdaRoundQuantizer(uaq=layer.weight_quantizer_0, round_mode=round_mode, 55 | weight_tensor=layer.org_weight.data[:, layer.split:, ...]) 56 | else: 57 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode, 58 | weight_tensor=layer.org_weight.data) 59 | layer.weight_quantizer.soft_targets = True 60 | 61 | # Set up optimizer 62 | opt_params = [layer.weight_quantizer.alpha] 63 | if layer.split != 0: 64 | opt_params += [layer.weight_quantizer_0.alpha] 65 | optimizer = torch.optim.Adam(opt_params) 66 | scheduler = None 67 | else: 68 | # Use UniformAffineQuantizer to learn delta 69 | opt_params = [layer.act_quantizer.delta] 70 | if layer.split != 0 and layer.act_quantizer_0.delta is not None: 71 | opt_params += [layer.act_quantizer_0.delta] 72 | optimizer = torch.optim.Adam(opt_params, lr=lr) 73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.) 
74 | 75 | loss_mode = 'none' if act_quant else 'relaxation' 76 | rec_loss = opt_mode 77 | 78 | loss_func = LossFunction(layer, round_loss=loss_mode, weight=weight, 79 | max_count=iters, rec_loss=rec_loss, b_range=b_range, 80 | decay_start=0, warmup=warmup, p=p) 81 | 82 | # Save data before optimizing the rounding 83 | num_split = 10 84 | b_size = cali_data[0].shape[0] // num_split 85 | for k in range(num_split): 86 | logger.info(f"Saving {num_split} intermediate results to disk to avoid OOM") 87 | if cond: 88 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size], cali_data[2][k*b_size:(k+1)*b_size]) 89 | else: 90 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size]) 91 | cached_inps, cached_outs = save_inp_oup_data( 92 | model, layer, cali_data_t, asym, act_quant, batch_size=8, keep_gpu=False, cond=cond, is_sm=is_sm) 93 | cached_path = os.path.join(outpath, 'tmp_cached/') 94 | if not os.path.exists(cached_path): 95 | os.makedirs(cached_path) 96 | torch.save(cached_inps, os.path.join(cached_path, f'cached_inps_t{k}.pt')) 97 | torch.save(cached_outs, os.path.join(cached_path, f'cached_outs_t{k}.pt')) 98 | # cached_inps, cached_outs = save_inp_oup_data( 99 | # model, layer, cali_data, asym, act_quant, 8, keep_gpu=False, cond=cond, is_sm=is_sm) 100 | 101 | if opt_mode != 'mse': 102 | cached_grads = save_grad_data(model, layer, cali_data, act_quant, batch_size=batch_size) 103 | else: 104 | cached_grads = None 105 | device = 'cuda' 106 | for k in range(num_split): 107 | cached_inps = torch.load(os.path.join(cached_path, f'cached_inps_t{k}.pt')) 108 | cached_outs = torch.load(os.path.join(cached_path, f'cached_outs_t{k}.pt')) 109 | for i in range(iters // num_split): 110 | idx = torch.randperm(cached_inps.size(0))[:batch_size] 111 | cur_inp = cached_inps[idx].to(device) 112 | cur_out = cached_outs[idx].to(device) 113 | cur_grad = cached_grads[idx] if opt_mode != 'mse' else None 114 | 115 | optimizer.zero_grad() 116 | out_quant = layer(cur_inp) 117 | 118 | err = loss_func(out_quant, cur_out, cur_grad) 119 | err.backward(retain_graph=True) 120 | if multi_gpu: 121 | raise NotImplementedError 122 | # for p in opt_params: 123 | # link.allreduce(p.grad) 124 | optimizer.step() 125 | if scheduler: 126 | scheduler.step() 127 | 128 | torch.cuda.empty_cache() 129 | 130 | # Finish optimization, use hard rounding. 
131 | layer.weight_quantizer.soft_targets = False 132 | if layer.split != 0: 133 | layer.weight_quantizer_0.soft_targets = False 134 | 135 | # Reset original activation function 136 | if not include_act_func: 137 | layer.activation_function = org_act_func 138 | 139 | 140 | class LossFunction: 141 | def __init__(self, 142 | layer: QuantModule, 143 | round_loss: str = 'relaxation', 144 | weight: float = 1., 145 | rec_loss: str = 'mse', 146 | max_count: int = 2000, 147 | b_range: tuple = (10, 2), 148 | decay_start: float = 0.0, 149 | warmup: float = 0.0, 150 | p: float = 2.): 151 | 152 | self.layer = layer 153 | self.round_loss = round_loss 154 | self.weight = weight 155 | self.rec_loss = rec_loss 156 | self.loss_start = max_count * warmup 157 | self.p = p 158 | 159 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start, 160 | start_b=b_range[0], end_b=b_range[1]) 161 | self.count = 0 162 | 163 | def __call__(self, pred, tgt, grad=None): 164 | """ 165 | Compute the total loss for adaptive rounding: 166 | rec_loss is the quadratic output reconstruction loss, round_loss is 167 | a regularization term to optimize the rounding policy 168 | 169 | :param pred: output from quantized model 170 | :param tgt: output from FP model 171 | :param grad: gradients to compute fisher information 172 | :return: total loss function 173 | """ 174 | self.count += 1 175 | if self.rec_loss == 'mse': 176 | rec_loss = lp_loss(pred, tgt, p=self.p) 177 | elif self.rec_loss == 'fisher_diag': 178 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() 179 | elif self.rec_loss == 'fisher_full': 180 | a = (pred - tgt).abs() 181 | grad = grad.abs() 182 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) 183 | rec_loss = (batch_dotprod * a * grad).mean() / 100 184 | else: 185 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss)) 186 | 187 | b = self.temp_decay(self.count) 188 | if self.count < self.loss_start or self.round_loss == 'none': 189 | b = round_loss = 0 190 | elif self.round_loss == 'relaxation': 191 | round_loss = 0 192 | round_vals = self.layer.weight_quantizer.get_soft_targets() 193 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() 194 | else: 195 | raise NotImplementedError 196 | 197 | total_loss = rec_loss + round_loss 198 | if self.count % 500 == 0: 199 | logger.info('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format( 200 | float(total_loss), float(rec_loss), float(round_loss), b, self.count)) 201 | return total_loss 202 | 203 | -------------------------------------------------------------------------------- /qdiff/quant_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | from qdiff.quant_block import get_specials, BaseQuantBlock 4 | from qdiff.quant_block import QuantBasicTransformerBlock, QuantResBlock 5 | from qdiff.quant_block import QuantQKMatMul, QuantSMVMatMul, QuantBasicTransformerBlock#, QuantAttnBlock 6 | from qdiff.quant_layer import QuantModule, StraightThrough, TimewiseUniformQuantizer 7 | from ldm.modules.attention import BasicTransformerBlock 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class QuantModel(nn.Module): 13 | 14 | def __init__(self, model: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | timewise = kwargs.get('timewise', True) 18 | 
self.timewise = timewise 19 | list_timesteps = kwargs.get('list_timesteps', 50) 20 | self.timesteps = list_timesteps 21 | self.sm_abit = kwargs.get('sm_abit', 8) 22 | self.in_channels = model.in_channels 23 | if hasattr(model, 'image_size'): 24 | self.image_size = model.image_size 25 | self.specials = get_specials(act_quant_params['leaf_param']) 26 | self.quant_module_refactor(self.model, weight_quant_params, act_quant_params, timewise, list_timesteps) 27 | self.quant_block_refactor(self.model, weight_quant_params, act_quant_params, timewise, list_timesteps) 28 | 29 | def quant_module_refactor(self, module, weight_quant_params, act_quant_params, timewise, list_timesteps): 30 | """ 31 | Recursively replace the normal layers (conv2D, conv1D, Linear etc.) to QuantModule 32 | :param module: nn.Module with nn.Conv2d, nn.Conv1d, or nn.Linear in its children 33 | :param weight_quant_params: quantization parameters like n_bits for weight quantizer 34 | :param act_quant_params: quantization parameters like n_bits for activation quantizer 35 | """ 36 | prev_quantmodule = None 37 | for name, child_module in module.named_children(): 38 | if isinstance(child_module, (nn.Conv2d, nn.Conv1d, nn.Linear)): # nn.Conv1d 39 | setattr(module, name, QuantModule( 40 | child_module, weight_quant_params, act_quant_params, timewise=timewise, list_timesteps=list_timesteps)) 41 | prev_quantmodule = getattr(module, name) 42 | 43 | elif isinstance(child_module, StraightThrough): 44 | continue 45 | 46 | else: 47 | self.quant_module_refactor(child_module, weight_quant_params, act_quant_params, timewise, list_timesteps) 48 | 49 | def quant_block_refactor(self, module, weight_quant_params, act_quant_params, timewise, list_timesteps): 50 | for name, child_module in module.named_children(): 51 | if type(child_module) in self.specials: 52 | if self.specials[type(child_module)] in [QuantBasicTransformerBlock]: 53 | setattr(module, name, self.specials[type(child_module)](child_module, 54 | act_quant_params, sm_abit=self.sm_abit, timewise=timewise, list_timesteps=list_timesteps)) 55 | elif self.specials[type(child_module)] == QuantSMVMatMul: 56 | setattr(module, name, self.specials[type(child_module)]( 57 | act_quant_params, sm_abit=self.sm_abit, timewise=timewise, list_timesteps=list_timesteps)) 58 | elif self.specials[type(child_module)] == QuantQKMatMul: 59 | setattr(module, name, self.specials[type(child_module)]( 60 | act_quant_params, timewise=timewise, list_timesteps=list_timesteps)) 61 | else: 62 | setattr(module, name, self.specials[type(child_module)](child_module, 63 | act_quant_params)) 64 | else: 65 | self.quant_block_refactor(child_module, weight_quant_params, act_quant_params, timewise, list_timesteps) 66 | 67 | 68 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 69 | for m in self.model.modules(): 70 | if isinstance(m, (QuantModule, BaseQuantBlock)): 71 | m.set_quant_state(weight_quant, act_quant) 72 | 73 | def set_timestep(self, t): 74 | for m in self.model.modules(): 75 | if isinstance(m, (QuantModule, BaseQuantBlock)): 76 | m.set_timestep(t) 77 | 78 | def forward(self, x, timesteps=None, context=None): 79 | return self.model(x, timesteps, context) 80 | 81 | def set_running_stat(self, running_stat: bool, sm_only=False): 82 | # Only consider timewise=True here 83 | for m in self.model.modules(): 84 | if isinstance(m, QuantBasicTransformerBlock): 85 | if sm_only: 86 | m.attn1.act_quantizer_w.set_running_stat(running_stat) 87 | m.attn2.act_quantizer_w.set_running_stat(running_stat) 
88 | else: 89 | m.attn1.act_quantizer_q.set_running_stat(running_stat) 90 | m.attn1.act_quantizer_k.set_running_stat(running_stat) 91 | m.attn1.act_quantizer_v.set_running_stat(running_stat) 92 | m.attn1.act_quantizer_w.set_running_stat(running_stat) 93 | m.attn2.act_quantizer_q.set_running_stat(running_stat) 94 | m.attn2.act_quantizer_k.set_running_stat(running_stat) 95 | m.attn2.act_quantizer_v.set_running_stat(running_stat) 96 | m.attn2.act_quantizer_w.set_running_stat(running_stat) 97 | if isinstance(m , QuantQKMatMul): 98 | m.act_quantizer_q.set_running_stat(running_stat) 99 | m.act_quantizer_k.set_running_stat(running_stat) 100 | if isinstance(m , QuantSMVMatMul): 101 | m.act_quantizer_v.set_running_stat(running_stat) 102 | m.act_quantizer_w.set_running_stat(running_stat) 103 | if isinstance(m, QuantModule) and not sm_only: 104 | m.set_running_stat(running_stat) 105 | 106 | def save_dict_params(self): 107 | for m in self.model.modules(): 108 | if isinstance(m, TimewiseUniformQuantizer): 109 | m.save_dict_params() 110 | 111 | def load_dict_params(self): 112 | for m in self.model.modules(): 113 | if isinstance(m, TimewiseUniformQuantizer): 114 | m.load_dict_params() 115 | 116 | def set_grad_ckpt(self, grad_ckpt: bool): 117 | for name, m in self.model.named_modules(): 118 | if isinstance(m, (QuantBasicTransformerBlock, BasicTransformerBlock)): 119 | # logger.info(name) 120 | m.checkpoint = grad_ckpt 121 | # elif isinstance(m, QuantResBlock): 122 | # logger.info(name) 123 | # m.use_checkpoint = grad_ckpt 124 | --------------------------------------------------------------------------------
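
A few illustrative notes on the qdiff code listed above. The `AdaRoundQuantizer` in `qdiff/adaptive_rounding.py` learns the rounding direction through a rectified sigmoid: `init_alpha` inverts the sigmoid so that the soft target initially equals the fractional remainder of `w / delta`, which means the soft-quantized weight starts out identical to the full-precision weight and optimization only has to decide whether each weight eventually rounds up or down. The standalone sketch below uses made-up values and omits `zero_point` and level clamping; it mirrors `init_alpha`/`get_soft_targets` but is not part of the repository.

```python
# Standalone illustration of AdaRound's rectified-sigmoid relaxation
# (mirrors AdaRoundQuantizer.init_alpha / get_soft_targets; values are made up,
# zero_point and level clamping are omitted for brevity).
import torch

gamma, zeta = -0.1, 1.1                      # stretch parameters, as in the class
delta = torch.tensor(0.05)                   # hypothetical quantization step size
w = torch.tensor([0.112, -0.237, 0.049])     # hypothetical FP weights

w_floor = torch.floor(w / delta)
rest = (w / delta) - w_floor                 # fractional remainder in [0, 1)

# init_alpha: pick alpha so the rectified sigmoid reproduces `rest` exactly
alpha = -torch.log((zeta - gamma) / (rest - gamma) - 1)

def soft_targets(a):
    return torch.clamp(torch.sigmoid(a) * (zeta - gamma) + gamma, 0, 1)

# At initialization the soft-quantized weight equals the FP weight, so
# reconstruction starts from a lossless point and only the rounding
# direction is left to learn.
w_soft = (w_floor + soft_targets(alpha)) * delta
print(torch.allclose(w_soft, w))             # True

# After calibration, soft targets are replaced by a hard 0/1 decision
# (soft_targets = False in the quantizer):
w_hard = (w_floor + (alpha >= 0).float()) * delta
print(w_hard)
```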
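
Both `block_reconstruction` and `layer_reconstruction` add the rounding regularizer `weight * sum(1 - |2*h - 1|^b)` on top of the output-reconstruction loss, with the temperature `b` annealed by `LinearTempDecay` from `b_range[0]` down to `b_range[1]` (despite the docstring's mention of cosine annealing, the schedule is linear). Early on the term is nearly flat, so output matching dominates; as `b` shrinks, the penalty increasingly forces each soft target toward a hard 0 or 1. A small standalone sketch of that behaviour, assuming default-style settings (`iters=20000`, `b_range=(20, 2)`, no warmup) and random stand-in soft targets:

```python
# Illustrative schedule for the rounding-regularization term in LossFunction /
# LinearTempDecay (standalone; the soft targets here are random stand-ins).
import torch

iters, start_b, end_b, warmup = 20000, 20, 2, 0.0
weight = 0.01                                   # rounding-loss weight

def temp_b(t):
    # LinearTempDecay: hold start_b during the warmup fraction, then decay
    # linearly to end_b over the remaining iterations.
    start_decay = warmup * iters
    if t < start_decay:
        return start_b
    rel_t = (t - start_decay) / (iters - start_decay)
    return end_b + (start_b - end_b) * max(0.0, 1 - rel_t)

h = torch.rand(1000)                            # stand-in for get_soft_targets()
for t in (1, 5000, 15000, 20000):
    b = temp_b(t)
    round_loss = weight * (1 - ((h - 0.5).abs() * 2).pow(b)).sum()
    print(f"t={t:>6}  b={b:5.2f}  round_loss={round_loss.item():.2f}")
# With large b the penalty is roughly `weight` per entry regardless of h (flat
# gradient); as b decreases, its gradient grows and pushes every h toward 0 or 1.
```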
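
Finally, `QuantModel` in `qdiff/quant_model.py` ties the pieces together: it swaps `Conv1d`/`Conv2d`/`Linear` layers for `QuantModule`s, replaces attention blocks with their quantized counterparts, and exposes `set_quant_state`, `set_running_stat` and the per-timestep `set_timestep` hooks. Below is a condensed, hypothetical weight-calibration sketch built only from the API shown in this listing; the quantization-parameter dictionaries use q-diff-style keys that are assumptions here (check `qdiff/quant_layer.py` for the exact option names), and the recursive `recon` helper is a simplified stand-in for the reconstruction loops in the `post_layer_recon_*.py` scripts.

```python
# Hypothetical, condensed weight-calibration flow using the qdiff API above.
# Dict keys marked "assumed" may differ from the actual quant_layer.py options.
from qdiff import (QuantModel, QuantModule, BaseQuantBlock,
                   block_reconstruction, layer_reconstruction)

def calibrate_weights(unet, cali_xs, cali_ts, outpath, n_steps=50):
    # unet: an LDM UNetModel (exposes .in_channels); cali_xs / cali_ts: a
    # calibration set of latent inputs and timesteps, e.g. ~1024 samples.
    wq_params = {'n_bits': 4, 'channel_wise': True, 'scale_method': 'max'}   # assumed keys
    aq_params = {'n_bits': 8, 'symmetric': False, 'channel_wise': False,
                 'scale_method': 'max', 'leaf_param': True}                  # 'leaf_param' is read by QuantModel
    qnn = QuantModel(unet, wq_params, aq_params,
                     sm_abit=8, timewise=True, list_timesteps=n_steps).cuda().eval()

    cali_data = (cali_xs, cali_ts)   # unconditional case; use (xs, ts, conds) with cond=True otherwise
    kwargs = dict(cali_data=cali_data, iters=20000, weight=0.01, asym=True,
                  b_range=(20, 2), warmup=0.2, act_quant=False,
                  opt_mode='mse', outpath=outpath)

    def recon(module):
        # Simplified stand-in for the recursive loops in post_layer_recon_*.py:
        # reconstruct whole blocks where available, individual layers otherwise.
        for _, child in module.named_children():
            if isinstance(child, BaseQuantBlock):
                block_reconstruction(qnn, child, **kwargs)
            elif isinstance(child, QuantModule):
                layer_reconstruction(qnn, child, **kwargs)
            else:
                recon(child)

    recon(qnn.model)
    qnn.set_quant_state(weight_quant=True, act_quant=False)
    return qnn
```

A second pass with `act_quant=True` (after collecting activation ranges via `set_running_stat(True)` over a few forward passes) learns the activation step sizes, and `set_timestep` selects the matching time-wise quantizer parameters during sampling.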