├── README.md
├── configs
│ ├── autoencoder
│ │ ├── autoencoder_kl_16x16x16.yaml
│ │ ├── autoencoder_kl_32x32x4.yaml
│ │ ├── autoencoder_kl_64x64x3.yaml
│ │ └── autoencoder_kl_8x8x64.yaml
│ ├── cifar10.yml
│ ├── latent-diffusion
│ │ ├── celebahq-ldm-vq-4.yaml
│ │ ├── cin-ldm-vq-f8.yaml
│ │ ├── cin256-v2.yaml
│ │ ├── ffhq-ldm-vq-4.yaml
│ │ ├── lsun_bedrooms-ldm-vq-4.yaml
│ │ ├── lsun_churches-ldm-kl-8.yaml
│ │ └── txt2img-1p4B-eval.yaml
│ ├── retrieval-augmented-diffusion
│ │ └── 768x768.yaml
│ └── stable-diffusion
│   ├── v1-inference.yaml
│   └── v1-inpainting-inference.yaml
├── ddim
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-37.pyc
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── artcifar10.py
│ │ ├── celeba.py
│ │ ├── ffhq.py
│ │ ├── lsun.py
│ │ ├── utils.py
│ │ └── vision.py
│ ├── dpm_solver_pytorch.py
│ ├── functions
│ │ ├── __init__.py
│ │ ├── ckpt_util.py
│ │ ├── denoising.py
│ │ └── losses.py
│ └── models
│   ├── __init__.py
│   ├── __pycache__
│   │ ├── __init__.cpython-37.pyc
│   │ └── diffusion.cpython-37.pyc
│   ├── diffusion.py
│   └── ema.py
├── environment.yml
├── get_calibration_set_imagenet_ddim.py
├── imgs
│ ├── fig1.png
│ ├── imagenet_example.png
│ └── sd_images.png
├── ldm
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ └── util.cpython-37.pyc
│ ├── lr_scheduler.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ └── autoencoder.cpython-37.pyc
│ │ ├── autoencoder.py
│ │ └── diffusion
│ │   ├── __init__.py
│ │   ├── __pycache__
│ │   │ ├── __init__.cpython-37.pyc
│ │   │ ├── ddim.cpython-37.pyc
│ │   │ └── ddpm.cpython-37.pyc
│ │   ├── classifier.py
│ │   ├── ddim.py
│ │   ├── ddpm.py
│ │   ├── dpm_solver
│ │   │ ├── __init__.py
│ │   │ ├── __pycache__
│ │   │ │ ├── __init__.cpython-37.pyc
│ │   │ │ ├── dpm_solver.cpython-37.pyc
│ │   │ │ └── sampler.cpython-37.pyc
│ │   │ ├── dpm_solver.py
│ │   │ └── sampler.py
│ │   └── plms.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── attention.cpython-37.pyc
│ │ │ └── ema.cpython-37.pyc
│ │ ├── attention.py
│ │ ├── diffusionmodules
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-37.pyc
│ │ │ │ ├── model.cpython-37.pyc
│ │ │ │ ├── openaimodel.cpython-37.pyc
│ │ │ │ └── util.cpython-37.pyc
│ │ │ ├── model.py
│ │ │ ├── openaimodel.py
│ │ │ └── util.py
│ │ ├── distributions
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-37.pyc
│ │ │ │ └── distributions.cpython-37.pyc
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── encoders
│ │ │ ├── __init__.py
│ │ │ └── modules.py
│ │ └── x_transformer.py
│ └── util.py
├── models
│ ├── first_stage_models
│ │ ├── kl-f16
│ │ │ └── config.yaml
│ │ ├── kl-f32
│ │ │ └── config.yaml
│ │ ├── kl-f4
│ │ │ └── config.yaml
│ │ ├── kl-f8
│ │ │ └── config.yaml
│ │ ├── vq-f16
│ │ │ └── config.yaml
│ │ ├── vq-f4-noattn
│ │ │ └── config.yaml
│ │ ├── vq-f4
│ │ │ └── config.yaml
│ │ ├── vq-f8-n256
│ │ │ └── config.yaml
│ │ └── vq-f8
│ │   └── config.yaml
│ └── ldm
│   ├── bsr_sr
│   │ └── config.yaml
│   ├── celeba256
│   │ └── config.yaml
│   ├── cin256-v2
│   │ └── config.yaml
│   ├── cin256
│   │ └── config.yaml
│   ├── inpainting_big
│   │ └── config.yaml
│   ├── layout2img-openimages256
│   │ └── config.yaml
│   ├── lsun_churches256
│   │ └── config.yaml
│   ├── semantic_synthesis256
│   │ └── config.yaml
│   ├── semantic_synthesis512
│   │ └── config.yaml
│   └── text2img256
│     └── config.yaml
├── qdiff
│ ├── __init__.py
│ ├── adaptive_rounding.py
│ ├── block_recon.py
│ ├── layer_recon.py
│ ├── post_layer_recon_imagenet.py
│ ├── post_layer_recon_sd.py
│ ├── post_layer_recon_uncond.py
│ ├── quant_block.py
│ ├── quant_layer.py
│ ├── quant_model.py
│ └── utils.py
├── sample_diffusion_ldm_bedroom.py
├── sample_diffusion_ldm_church.py
├── sample_diffusion_ldm_imagenet.py
└── txt2img.py
/README.md:
--------------------------------------------------------------------------------
1 | # QuEST
2 | The official repository for **QuEST: Low-bit Diffusion Model Quantization via Efficient Selective Finetuning** [[ArXiv]](https://arxiv.org/abs/2402.03666)
3 |
4 | ## Update Log
5 | **(2024.2.28)** Reorganized the code structure.
6 |
7 | **(2024.12.15)** Fixed some inconsistencies in the implementation.
8 |
9 | ## Features
10 | QuEST achieves state-of-the-art performance on multiple high-resolution image generation tasks, including unconditional image generation, class-conditional image generation, and text-to-image generation. We also achieve superior performance on full 4-bit (W4A4) generation.
11 |
12 |
13 |
14 | On ImageNet 256×256:
15 |
16 |
17 |
18 | On Stable Diffusion v1.4 (512×512):
19 |
20 |
21 |
22 |
23 | ## Get Started
24 | ### Prerequisites
25 | Make sure you have conda installed first, then:
26 | ```
27 | git clone https://github.com/hatchetProject/QuEST.git
28 | cd QuEST
29 | conda env create -f environment.yml
30 | conda activate quest
31 | ```
32 | ### Usage
33 | 1. For Latent Diffusion and Stable Diffusion experiments, first download pretrained model checkpoints following the instructions in the [latent-diffusion](https://github.com/CompVis/latent-diffusion/tree/main) and [stable-diffusion](https://github.com/CompVis/stable-diffusion#weights) repos from CompVis. We currently use sd-v1-4.ckpt for Stable Diffusion.
34 | 2. The calibration data for LSUN-Bedrooms/Churches and Stable Diffusion (COCO) can be downloaded from the [Q-Diffusion](https://github.com/Xiuyu-Li/q-diffusion/tree/master) repository. We will upload the calibration data for ImageNet soon.
35 | 3. Use the following commands to reproduce the models. With `act_bit=4`, channel-wise quantization is additionally applied along a more hardware-friendly dimension, which reduces computation cost (a minimal sketch of channel-wise quantization is given after the commands below). Also, exclude the `--running_stat` argument for W4A4 quantization.
36 | 4. Change line 151 in ldm/models/diffusion/ddim.py according to the number of timesteps you use: replace '10' with the result of `-c // --cali_st`, e.g. 200 // 20 = 10. Comment out lines 149~158 if you would like to run inference with the FP model.
37 | 5. We highly recommend using the checkpoints provided by Q-Diffusion and resuming them with the `--resume_w` argument.
38 | ```
39 | # LSUN-Bedrooms (LDM-4)
40 | python sample_diffusion_ldm_bedroom.py -r models/ldm/lsun_beds256/model.ckpt -n 100 --batch_size 20 -c 200 -e 1.0 --seed 40 --ptq --weight_bit <4 or 8> --quant_mode qdiff --cali_st 20 --cali_batch_size 32 --cali_n 256 --quant_act --act_bit <4 or 8> --a_sym --a_min_max --running_stat --cali_data_path -l
41 | # LSUN-Churches (LDM-8)
42 | python sample_diffusion_ldm_church.py -r models/ldm/lsun_churches256/model.ckpt -n 50000 --batch_size 10 -c 500 -e 0.0 --seed 40 --ptq --weight_bit <4 or 8> --quant_mode qdiff --cali_st 20 --cali_batch_size 32 --cali_n 256 --quant_act --act_bit <4 or 8> --cali_data_path -l
43 | # ImageNet
44 | python sample_diffusion_ldm_imagenet.py -r models/ldm/cin256-v2/model.ckpt -n 50 --batch_size 50 -c 20 -e 1.0 --seed 40 --ptq --weight_bit <4 or 8> --quant_mode qdiff --cali_st 20 --cali_batch_size 32 --cali_n 256 --quant_act --act_bit <4 or 8> --a_sym --a_min_max --running_stat --cond --cali_data_path -l
45 | # Stable Diffusion
46 | python txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --cond --ptq --weight_bit <4 or 8> --quant_mode qdiff --quant_act --act_bit <4 or 8> --cali_st 25 --cali_batch_size 8 --cali_n 128 --no_grad_ckpt --split --running_stat --sm_abit 16 --cali_data_path --outdir
47 | ```
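
The sketch below is only an illustration of the idea behind the channel-wise activation quantization mentioned in step 3: instead of one scale/zero-point for the whole tensor, a scale/zero-point is computed per slice along one dimension. It is not the QuEST implementation (the repo's actual quantizers live under `qdiff/`, e.g. `quant_layer.py` and `quant_block.py`), and the tensor shape is an arbitrary example.

```python
# Illustrative per-tensor vs. channel-wise asymmetric uniform fake quantization.
# NOT the QuEST implementation; it only shows what quantizing along one dimension means.
import torch

def fake_quantize(x, n_bits=4, channel_dim=None):
    """Quantize-dequantize x to n_bits. If channel_dim is set, the scale and
    zero-point are computed per slice along that dimension instead of per tensor."""
    if channel_dim is None:
        x_min, x_max = x.min(), x.max()
    else:
        reduce_dims = [d for d in range(x.dim()) if d != channel_dim]
        x_min = x.amin(dim=reduce_dims, keepdim=True)
        x_max = x.amax(dim=reduce_dims, keepdim=True)
    scale = (x_max - x_min).clamp(min=1e-8) / (2 ** n_bits - 1)
    zero_point = torch.round(-x_min / scale)
    x_int = torch.clamp(torch.round(x / scale) + zero_point, 0, 2 ** n_bits - 1)
    return (x_int - zero_point) * scale  # dequantized ("fake quantized") tensor

x = torch.randn(8, 320, 32, 32)  # a made-up UNet activation map
err_tensor = (x - fake_quantize(x, 4)).abs().mean()
err_channel = (x - fake_quantize(x, 4, channel_dim=1)).abs().mean()
print(f"per-tensor L1 error: {err_tensor:.4f}, per-channel L1 error: {err_channel:.4f}")
```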
48 |
49 | ### Calibration dataset
50 | We will release the calibration data, but you can also generate it yourself with the following command (10 images per class over all timesteps); a quick way to inspect the resulting file is sketched after the command:
51 | ```
52 | python get_calibration_set_imagenet_ddim.py -r -n 10 --batch_size 10 -c 20 -e 1.0 --seed 40 -l output/ --cond
53 | ```
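
If you want to sanity-check a downloaded or self-generated calibration file before running PTQ, a minimal, format-agnostic inspection looks like this. It assumes only that the file was saved with `torch.save` (as the Q-Diffusion calibration sets are); the filename is hypothetical, so point it at whatever you pass via `--cali_data_path`.

```python
# Load a calibration file and print the shape of every tensor it contains,
# regardless of whether it is stored as a dict, list, tuple, or bare tensor.
import torch

def describe(obj, prefix="cali"):
    if torch.is_tensor(obj):
        print(f"{prefix}: tensor {tuple(obj.shape)} {obj.dtype}")
    elif isinstance(obj, dict):
        for k, v in obj.items():
            describe(v, f"{prefix}[{k!r}]")
    elif isinstance(obj, (list, tuple)):
        for i, v in enumerate(obj):
            describe(v, f"{prefix}[{i}]")
    else:
        print(f"{prefix}: {type(obj).__name__}")

describe(torch.load("cali_data_bedroom.pt", map_location="cpu"))  # hypothetical path
```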
54 |
55 | ### Evaluation
56 | We use ADM's TensorFlow evaluation suite [link](https://github.com/openai/guided-diffusion/tree/main/evaluations) for evaluating FID, sFID and IS. For Stable Diffusion, we generate 10,000 samples based on the prompts from the COCO2014 dataset and calculate the average CLIP score.
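
For the CLIP score, a minimal sketch of the computation (mean cosine similarity between CLIP image and text embeddings over prompt/image pairs) is shown below, using the Hugging Face `openai/clip-vit-base-patch32` checkpoint. The exact CLIP backbone and scaling behind the reported numbers may differ, so treat this as an illustration rather than the official evaluation script.

```python
# Average CLIP score over prompt/image pairs with Hugging Face CLIP ViT-B/32.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.no_grad()
def average_clip_score(prompts, image_paths):
    images = [Image.open(p).convert("RGB") for p in image_paths]
    inputs = processor(text=prompts, images=images,
                       return_tensors="pt", padding=True, truncation=True)
    out = model(**inputs)
    img = out.image_embeds / out.image_embeds.norm(dim=-1, keepdim=True)
    txt = out.text_embeds / out.text_embeds.norm(dim=-1, keepdim=True)
    return (img * txt).sum(dim=-1).mean().item()  # mean cosine similarity

# e.g. average_clip_score(["a photograph of an astronaut riding a horse"],
#                         ["outputs/sample_00000.png"])  # hypothetical output path
```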
57 |
58 | ## Acknowledgement
59 | This project is heavily based on [LDM](https://github.com/CompVis/latent-diffusion/tree/main) and [Q-Diffusion](https://github.com/Xiuyu-Li/q-diffusion/tree/master).
60 |
61 | ## Citation
62 | If you find this work helpful, please consider citing our paper:
63 | ```
64 | @misc{wang2024quest,
65 | title={QuEST: Low-bit Diffusion Model Quantization via Efficient Selective Finetuning},
66 | author={Haoxuan Wang and Yuzhang Shang and Zhihang Yuan and Junyi Wu and Yan Yan},
67 | year={2024},
68 | eprint={2402.03666},
69 | archivePrefix={arXiv},
70 | primaryClass={cs.CV}
71 | }
72 | ```
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/cifar10.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "CIFAR10"
3 | image_size: 32
4 | channels: 3
5 | logit_transform: false
6 | uniform_dequantization: false
7 | gaussian_dequantization: false
8 | random_flip: true
9 | rescaled: true
10 | num_workers: 4
11 |
12 | model:
13 | type: "simple"
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult: [1, 2, 2, 2]
18 | num_res_blocks: 2
19 | attn_resolutions: [16, ]
20 | dropout: 0.1
21 | var_type: fixedlarge
22 | ema_rate: 0.9999
23 | ema: True
24 | resamp_with_conv: True
25 |
26 | diffusion:
27 | beta_schedule: linear
28 | beta_start: 0.0001
29 | beta_end: 0.02
30 | num_diffusion_timesteps: 1000
31 |
32 | training:
33 | batch_size: 128
34 | n_epochs: 10000
35 | n_iters: 5000000
36 | snapshot_freq: 5000
37 | validation_freq: 2000
38 |
39 | sampling:
40 | batch_size: 64
41 | last_only: True
42 |
43 | optim:
44 | weight_decay: 0.000
45 | optimizer: "Adam"
46 | lr: 0.0002
47 | beta1: 0.9
48 | amsgrad: false
49 | eps: 0.00000001
50 | grad_clip: 1.0
51 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 |         # note: this isn't actually the resolution but
24 |         # the downsampling factor, i.e. this corresponds to
25 |         # attention on spatial resolution 8,16,32, as the
26 |         # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 |         # note: this isn't actually the resolution but
26 |         # the downsampling factor, i.e. this corresponds to
27 |         # attention on spatial resolution 8,16,32, as the
28 |         # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 |         # note: this isn't actually the resolution but
23 |         # the downsampling factor, i.e. this corresponds to
24 |         # attention on spatial resolution 8,16,32, as the
25 |         # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 |         # note: this isn't actually the resolution but
23 |         # the downsampling factor, i.e. this corresponds to
24 |         # attention on spatial resolution 8,16,32, as the
25 |         # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inpainting-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | finetune_keys: null
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/ddim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/__init__.py
--------------------------------------------------------------------------------
/ddim/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ddim/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numbers
4 | import torchvision.transforms as transforms
5 | import torchvision.transforms.functional as F
6 | from torchvision.datasets import CIFAR10
7 | from ddim.datasets.artcifar10 import artCIFAR10
8 | from ddim.datasets.celeba import CelebA
9 | from ddim.datasets.ffhq import FFHQ
10 | from ddim.datasets.lsun import LSUN
11 | from torch.utils.data import Subset
12 | import numpy as np
13 |
14 |
15 | class Crop(object):
16 | def __init__(self, x1, x2, y1, y2):
17 | self.x1 = x1
18 | self.x2 = x2
19 | self.y1 = y1
20 | self.y2 = y2
21 |
22 | def __call__(self, img):
23 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
24 |
25 | def __repr__(self):
26 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
27 | self.x1, self.x2, self.y1, self.y2
28 | )
29 |
30 |
31 | def get_dataset(args, config):
32 | if config.data.random_flip is False:
33 | tran_transform = test_transform = transforms.Compose(
34 | [transforms.Resize(config.data.image_size), transforms.ToTensor()]
35 | )
36 | else:
37 | tran_transform = transforms.Compose(
38 | [
39 | transforms.Resize(config.data.image_size),
40 | transforms.RandomHorizontalFlip(p=0.5),
41 | transforms.ToTensor(),
42 | ]
43 | )
44 | test_transform = transforms.Compose(
45 | [transforms.Resize(config.data.image_size), transforms.ToTensor()]
46 | )
47 |
48 | if config.data.dataset == "CIFAR10":
49 | dataset = CIFAR10(
50 | os.path.join(args.dataset, "cifar10"),
51 | train=True,
52 | download=True,
53 | transform=tran_transform,
54 | )
55 | test_dataset = CIFAR10(
56 | os.path.join(args.dataset, "cifar10_test"),
57 | train=False,
58 | download=True,
59 | transform=test_transform,
60 | )
61 |
62 | elif config.data.dataset == "artCIFAR10":
63 | dataset = artCIFAR10(
64 | os.path.join(args.dataset),
65 | train=True,
66 | download=False,
67 | transform=tran_transform,
68 | )
69 | test_dataset = artCIFAR10(
70 | os.path.join(args.dataset),
71 | train=False,
72 | download=False,
73 | transform=test_transform,
74 | )
75 |
76 | elif config.data.dataset == "CELEBA":
77 | cx = 89
78 | cy = 121
79 | x1 = cy - 64
80 | x2 = cy + 64
81 | y1 = cx - 64
82 | y2 = cx + 64
83 | if config.data.random_flip:
84 | dataset = CelebA(
85 | root=os.path.join(args.dataset, "celeba"),
86 | split="train",
87 | transform=transforms.Compose(
88 | [
89 | Crop(x1, x2, y1, y2),
90 | transforms.Resize(config.data.image_size),
91 | transforms.RandomHorizontalFlip(),
92 | transforms.ToTensor(),
93 | ]
94 | ),
95 | download=True,
96 | )
97 | else:
98 | dataset = CelebA(
99 | root=os.path.join(args.dataset, "celeba"),
100 | split="train",
101 | transform=transforms.Compose(
102 | [
103 | Crop(x1, x2, y1, y2),
104 | transforms.Resize(config.data.image_size),
105 | transforms.ToTensor(),
106 | ]
107 | ),
108 | download=True,
109 | )
110 |
111 | test_dataset = CelebA(
112 | root=os.path.join(args.dataset, "celeba"),
113 | split="test",
114 | transform=transforms.Compose(
115 | [
116 | Crop(x1, x2, y1, y2),
117 | transforms.Resize(config.data.image_size),
118 | transforms.ToTensor(),
119 | ]
120 | ),
121 | download=True,
122 | )
123 |
124 | elif config.data.dataset == "LSUN":
125 | train_folder = "{}_train".format(config.data.category)
126 | val_folder = "{}_val".format(config.data.category)
127 | if config.data.random_flip:
128 | dataset = LSUN(
129 | root=os.path.join(args.dataset, "lsun"),
130 | classes=[train_folder],
131 | transform=transforms.Compose(
132 | [
133 | transforms.Resize(config.data.image_size),
134 | transforms.CenterCrop(config.data.image_size),
135 | transforms.RandomHorizontalFlip(p=0.5),
136 | transforms.ToTensor(),
137 | ]
138 | ),
139 | )
140 | else:
141 | dataset = LSUN(
142 | root=os.path.join(args.dataset, "lsun"),
143 | classes=[train_folder],
144 | transform=transforms.Compose(
145 | [
146 | transforms.Resize(config.data.image_size),
147 | transforms.CenterCrop(config.data.image_size),
148 | transforms.ToTensor(),
149 | ]
150 | ),
151 | )
152 |
153 | test_dataset = LSUN(
154 | root=os.path.join(args.dataset, "lsun"),
155 | classes=[val_folder],
156 | transform=transforms.Compose(
157 | [
158 | transforms.Resize(config.data.image_size),
159 | transforms.CenterCrop(config.data.image_size),
160 | transforms.ToTensor(),
161 | ]
162 | ),
163 | )
164 |
165 | elif config.data.dataset == "FFHQ":
166 | if config.data.random_flip:
167 | dataset = FFHQ(
168 | path=os.path.join(args.dataset, "FFHQ"),
169 | transform=transforms.Compose(
170 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()]
171 | ),
172 | resolution=config.data.image_size,
173 | )
174 | else:
175 | dataset = FFHQ(
176 | path=os.path.join(args.dataset, "FFHQ"),
177 | transform=transforms.ToTensor(),
178 | resolution=config.data.image_size,
179 | )
180 |
181 | num_items = len(dataset)
182 | indices = list(range(num_items))
183 | random_state = np.random.get_state()
184 | np.random.seed(2019)
185 | np.random.shuffle(indices)
186 | np.random.set_state(random_state)
187 | train_indices, test_indices = (
188 | indices[: int(num_items * 0.9)],
189 | indices[int(num_items * 0.9) :],
190 | )
191 | test_dataset = Subset(dataset, test_indices)
192 | dataset = Subset(dataset, train_indices)
193 | else:
194 | dataset, test_dataset = None, None
195 |
196 | return dataset, test_dataset
197 |
198 |
199 | def logit_transform(image, lam=1e-6):
200 | image = lam + (1 - 2 * lam) * image
201 | return torch.log(image) - torch.log1p(-image)
202 |
203 |
204 | def data_transform(config, X):
205 | if config.data.uniform_dequantization:
206 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
207 | if config.data.gaussian_dequantization:
208 | X = X + torch.randn_like(X) * 0.01
209 |
210 | if config.data.rescaled:
211 | X = 2 * X - 1.0
212 | elif config.data.logit_transform:
213 | X = logit_transform(X)
214 |
215 | if hasattr(config, "image_mean"):
216 | return X - config.image_mean.to(X.device)[None, ...]
217 |
218 | return X
219 |
220 |
221 | def inverse_data_transform(config, X):
222 | if hasattr(config, "image_mean"):
223 | X = X + config.image_mean.to(X.device)[None, ...]
224 |
225 | if config.data.logit_transform:
226 | X = torch.sigmoid(X)
227 | elif config.data.rescaled:
228 | X = (X + 1.0) / 2.0
229 |
230 | return torch.clamp(X, 0.0, 1.0)
231 |
--------------------------------------------------------------------------------
/ddim/datasets/artcifar10.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10
2 | class artCIFAR10(CIFAR10):
3 | """artCIFAR10
4 | """
5 |
6 | base_folder = "artcifar-10-batches-py"
7 | url = "https://artcifar.s3.us-east-2.amazonaws.com/artcifar-10-python.tar.gz"
8 | filename = "artcifar-10-python.tar.gz"
9 | tgz_md5 = "a6b71c6e0e3435d34e17896cc83ae1c1"
10 | train_list = [
11 | ["data_batch_1", "866ea0e474da9d5a033cfdb6adf4a631"],
12 | ["data_batch_2", "bba43dc11082b32fa8119cba55bebddd"],
13 | ["data_batch_3", "4978cbeb187b89e77e31fe39449c95ec"],
14 | ["data_batch_4", "1de1d0514f0c9cd28c5991386fa45f12"],
15 | ["data_batch_5", "81a7899dd79469824b75550bc0be3267"],
16 | ]
17 |
18 | test_list = [
19 | ["test_batch", "2ce6fb1ee71ffba9cef5df551fcce572"],
20 | ]
21 | meta = {
22 | "filename": "meta",
23 | "key": "styles",
24 | "md5": "6993b54be992f459837552f042419cdf",
25 | }
--------------------------------------------------------------------------------
/ddim/datasets/celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import PIL
4 | from .vision import VisionDataset
5 | from .utils import download_file_from_google_drive, check_integrity
6 |
7 |
8 | class CelebA(VisionDataset):
9 |     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
10 |
11 | Args:
12 | root (string): Root directory where images are downloaded to.
13 | split (string): One of {'train', 'valid', 'test'}.
14 | Accordingly dataset is selected.
15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types.
17 | The targets represent:
18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
19 | ``identity`` (int): label for each person (data points with the same identity are the same person)
20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
23 | Defaults to ``attr``.
24 |         transform (callable, optional): A function/transform that takes in a PIL image
25 |             and returns a transformed version. E.g., ``transforms.ToTensor``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | download (bool, optional): If true, downloads the dataset from the internet and
29 | puts it in root directory. If dataset is already downloaded, it is not
30 | downloaded again.
31 | """
32 |
33 | base_folder = "celeba"
34 |     # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
36 | # right now.
37 | file_list = [
38 | # File ID MD5 Hash Filename
39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
48 | ]
49 |
50 | def __init__(self, root,
51 | split="train",
52 | target_type="attr",
53 | transform=None, target_transform=None,
54 | download=False):
55 | import pandas
56 | super(CelebA, self).__init__(root)
57 | self.split = split
58 | if isinstance(target_type, list):
59 | self.target_type = target_type
60 | else:
61 | self.target_type = [target_type]
62 | self.transform = transform
63 | self.target_transform = target_transform
64 |
65 | if download:
66 | self.download()
67 |
68 | if not self._check_integrity():
69 | raise RuntimeError('Dataset not found or corrupted.' +
70 | ' You can use download=True to download it')
71 |
72 | self.transform = transform
73 | self.target_transform = target_transform
74 |
75 | if split.lower() == "train":
76 | split = 0
77 | elif split.lower() == "valid":
78 | split = 1
79 | elif split.lower() == "test":
80 | split = 2
81 | else:
82 | raise ValueError('Wrong split entered! Please use split="train" '
83 | 'or split="valid" or split="test"')
84 |
85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
87 |
88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
90 |
91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
93 |
94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
96 |
97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
99 |
100 | mask = (splits[1] == split)
101 | self.filename = splits[mask].index.values
102 | self.identity = torch.as_tensor(self.identity[mask].values)
103 | self.bbox = torch.as_tensor(self.bbox[mask].values)
104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
105 | self.attr = torch.as_tensor(self.attr[mask].values)
106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
107 |
108 | def _check_integrity(self):
109 | for (_, md5, filename) in self.file_list:
110 | fpath = os.path.join(self.root, self.base_folder, filename)
111 | _, ext = os.path.splitext(filename)
112 | # Allow original archive to be deleted (zip and 7z)
113 | # Only need the extracted images
114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
115 | return False
116 |
117 | # Should check a hash of the images
118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
119 |
120 | def download(self):
121 | import zipfile
122 |
123 | if self._check_integrity():
124 | print('Files already downloaded and verified')
125 | return
126 |
127 | for (file_id, md5, filename) in self.file_list:
128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
129 |
130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
131 | f.extractall(os.path.join(self.root, self.base_folder))
132 |
133 | def __getitem__(self, index):
134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
135 |
136 | target = []
137 | for t in self.target_type:
138 | if t == "attr":
139 | target.append(self.attr[index, :])
140 | elif t == "identity":
141 | target.append(self.identity[index, 0])
142 | elif t == "bbox":
143 | target.append(self.bbox[index, :])
144 | elif t == "landmarks":
145 | target.append(self.landmarks_align[index, :])
146 | else:
147 | raise ValueError("Target type \"{}\" is not recognized.".format(t))
148 | target = tuple(target) if len(target) > 1 else target[0]
149 |
150 | if self.transform is not None:
151 | X = self.transform(X)
152 |
153 | if self.target_transform is not None:
154 | target = self.target_transform(target)
155 |
156 | return X, target
157 |
158 | def __len__(self):
159 | return len(self.attr)
160 |
161 | def extra_repr(self):
162 | lines = ["Target type: {target_type}", "Split: {split}"]
163 | return '\n'.join(lines).format(**self.__dict__)
164 |
--------------------------------------------------------------------------------
/ddim/datasets/ffhq.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class FFHQ(Dataset):
9 | def __init__(self, path, transform, resolution=8):
10 | self.env = lmdb.open(
11 | path,
12 | max_readers=32,
13 | readonly=True,
14 | lock=False,
15 | readahead=False,
16 | meminit=False,
17 | )
18 |
19 | if not self.env:
20 | raise IOError('Cannot open lmdb dataset', path)
21 |
22 | with self.env.begin(write=False) as txn:
23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
24 |
25 | self.resolution = resolution
26 | self.transform = transform
27 |
28 | def __len__(self):
29 | return self.length
30 |
31 | def __getitem__(self, index):
32 | with self.env.begin(write=False) as txn:
33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
34 | img_bytes = txn.get(key)
35 |
36 | buffer = BytesIO(img_bytes)
37 | img = Image.open(buffer)
38 | img = self.transform(img)
39 | target = 0
40 |
41 | return img, target
--------------------------------------------------------------------------------
/ddim/datasets/lsun.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import io
6 | from collections.abc import Iterable
7 | import pickle
8 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str
9 |
10 |
11 | class LSUNClass(VisionDataset):
12 | def __init__(self, root, transform=None, target_transform=None):
13 | import lmdb
14 |
15 | super(LSUNClass, self).__init__(
16 | root, transform=transform, target_transform=target_transform
17 | )
18 |
19 | self.env = lmdb.open(
20 | root,
21 | max_readers=1,
22 | readonly=True,
23 | lock=False,
24 | readahead=False,
25 | meminit=False,
26 | )
27 | with self.env.begin(write=False) as txn:
28 | self.length = txn.stat()["entries"]
29 | root_split = root.split("/")
30 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
31 | if os.path.isfile(cache_file):
32 | self.keys = pickle.load(open(cache_file, "rb"))
33 | else:
34 | with self.env.begin(write=False) as txn:
35 | self.keys = [key for key, _ in txn.cursor()]
36 | pickle.dump(self.keys, open(cache_file, "wb"))
37 |
38 | def __getitem__(self, index):
39 | img, target = None, None
40 | env = self.env
41 | with env.begin(write=False) as txn:
42 | imgbuf = txn.get(self.keys[index])
43 |
44 | buf = io.BytesIO()
45 | buf.write(imgbuf)
46 | buf.seek(0)
47 | img = Image.open(buf).convert("RGB")
48 |
49 | if self.transform is not None:
50 | img = self.transform(img)
51 |
52 | if self.target_transform is not None:
53 | target = self.target_transform(target)
54 |
55 | return img, target
56 |
57 | def __len__(self):
58 | return self.length
59 |
60 |
61 | class LSUN(VisionDataset):
62 | """
63 | `LSUN `_ dataset.
64 |
65 | Args:
66 | root (string): Root directory for the database files.
67 | classes (string or list): One of {'train', 'val', 'test'} or a list of
68 |             categories to load. e.g. ['bedroom_train', 'church_outdoor_train'].
69 |         transform (callable, optional): A function/transform that takes in a PIL image
70 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
71 | target_transform (callable, optional): A function/transform that takes in the
72 | target and transforms it.
73 | """
74 |
75 | def __init__(self, root, classes="train", transform=None, target_transform=None):
76 | super(LSUN, self).__init__(
77 | root, transform=transform, target_transform=target_transform
78 | )
79 | self.classes = self._verify_classes(classes)
80 |
81 | # for each class, create an LSUNClassDataset
82 | self.dbs = []
83 | for c in self.classes:
84 | self.dbs.append(
85 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform)
86 | )
87 |
88 | self.indices = []
89 | count = 0
90 | for db in self.dbs:
91 | count += len(db)
92 | self.indices.append(count)
93 |
94 | self.length = count
95 |
96 | def _verify_classes(self, classes):
97 | categories = [
98 | "bedroom",
99 | "bridge",
100 | "church_outdoor",
101 | "classroom",
102 | "conference_room",
103 | "dining_room",
104 | "kitchen",
105 | "living_room",
106 | "restaurant",
107 | "tower",
108 | ]
109 | dset_opts = ["train", "val", "test"]
110 |
111 | try:
112 | verify_str_arg(classes, "classes", dset_opts)
113 | if classes == "test":
114 | classes = [classes]
115 | else:
116 | classes = [c + "_" + classes for c in categories]
117 | except ValueError:
118 | if not isinstance(classes, Iterable):
119 | msg = (
120 | "Expected type str or Iterable for argument classes, "
121 | "but got type {}."
122 | )
123 | raise ValueError(msg.format(type(classes)))
124 |
125 | classes = list(classes)
126 | msg_fmtstr = (
127 | "Expected type str for elements in argument classes, "
128 | "but got type {}."
129 | )
130 | for c in classes:
131 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
132 | c_short = c.split("_")
133 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
134 |
135 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
136 | msg = msg_fmtstr.format(
137 | category, "LSUN class", iterable_to_str(categories)
138 | )
139 | verify_str_arg(category, valid_values=categories, custom_msg=msg)
140 |
141 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
142 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
143 |
144 | return classes
145 |
146 | def __getitem__(self, index):
147 | """
148 | Args:
149 | index (int): Index
150 |
151 | Returns:
152 | tuple: Tuple (image, target) where target is the index of the target category.
153 | """
154 | target = 0
155 | sub = 0
156 | for ind in self.indices:
157 | if index < ind:
158 | break
159 | target += 1
160 | sub = ind
161 |
162 | db = self.dbs[target]
163 | index = index - sub
164 |
165 | if self.target_transform is not None:
166 | target = self.target_transform(target)
167 |
168 | img, _ = db[index]
169 | return img, target
170 |
171 | def __len__(self):
172 | return self.length
173 |
174 | def extra_repr(self):
175 | return "Classes: {classes}".format(**self.__dict__)
176 |
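A minimal usage sketch for the wrapper above, assuming an LSUN LMDB export under a hypothetical root directory (it must contain the `<category>_<split>_lmdb` folders):

    from torchvision import transforms
    from ddim.datasets.lsun import LSUN

    dataset = LSUN(
        root="/data/lsun",                        # hypothetical path
        classes=["church_outdoor_train"],         # expects /data/lsun/church_outdoor_train_lmdb
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ]),
    )
    img, target = dataset[0]                      # image tensor and category index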
--------------------------------------------------------------------------------
/ddim/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import errno
5 | from torch.utils.model_zoo import tqdm
6 |
7 |
8 | def gen_bar_updater():
9 | pbar = tqdm(total=None)
10 |
11 | def bar_update(count, block_size, total_size):
12 | if pbar.total is None and total_size:
13 | pbar.total = total_size
14 | progress_bytes = count * block_size
15 | pbar.update(progress_bytes - pbar.n)
16 |
17 | return bar_update
18 |
19 |
20 | def check_integrity(fpath, md5=None):
21 | if md5 is None:
22 | return True
23 | if not os.path.isfile(fpath):
24 | return False
25 | md5o = hashlib.md5()
26 | with open(fpath, 'rb') as f:
27 | # read in 1MB chunks
28 | for chunk in iter(lambda: f.read(1024 * 1024), b''):
29 | md5o.update(chunk)
30 | md5c = md5o.hexdigest()
31 | if md5c != md5:
32 | return False
33 | return True
34 |
35 |
36 | def makedir_exist_ok(dirpath):
37 | """
38 | Python2 support for os.makedirs(.., exist_ok=True)
39 | """
40 | try:
41 | os.makedirs(dirpath)
42 | except OSError as e:
43 | if e.errno == errno.EEXIST:
44 | pass
45 | else:
46 | raise
47 |
48 |
49 | def download_url(url, root, filename=None, md5=None):
50 | """Download a file from a url and place it in root.
51 |
52 | Args:
53 | url (str): URL to download file from
54 | root (str): Directory to place downloaded file in
55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL
56 | md5 (str, optional): MD5 checksum of the download. If None, do not check
57 | """
58 | from six.moves import urllib
59 |
60 | root = os.path.expanduser(root)
61 | if not filename:
62 | filename = os.path.basename(url)
63 | fpath = os.path.join(root, filename)
64 |
65 | makedir_exist_ok(root)
66 |
67 | # downloads file
68 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
69 | print('Using downloaded and verified file: ' + fpath)
70 | else:
71 | try:
72 | print('Downloading ' + url + ' to ' + fpath)
73 | urllib.request.urlretrieve(
74 | url, fpath,
75 | reporthook=gen_bar_updater()
76 | )
77 | except OSError:
78 | if url[:5] == 'https':
79 | url = url.replace('https:', 'http:')
80 | print('Failed download. Trying https -> http instead.'
81 | ' Downloading ' + url + ' to ' + fpath)
82 | urllib.request.urlretrieve(
83 | url, fpath,
84 | reporthook=gen_bar_updater()
85 | )
86 |
87 |
88 | def list_dir(root, prefix=False):
89 | """List all directories at a given root
90 |
91 | Args:
92 | root (str): Path to directory whose folders need to be listed
93 | prefix (bool, optional): If true, prepends the path to each result, otherwise
94 | only returns the name of the directories found
95 | """
96 | root = os.path.expanduser(root)
97 | directories = list(
98 | filter(
99 | lambda p: os.path.isdir(os.path.join(root, p)),
100 | os.listdir(root)
101 | )
102 | )
103 |
104 | if prefix is True:
105 | directories = [os.path.join(root, d) for d in directories]
106 |
107 | return directories
108 |
109 |
110 | def list_files(root, suffix, prefix=False):
111 | """List all files ending with a suffix at a given root
112 |
113 | Args:
114 | root (str): Path to directory whose folders need to be listed
115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
116 | It uses the Python "str.endswith" method and is passed directly
117 | prefix (bool, optional): If true, prepends the path to each result, otherwise
118 | only returns the name of the files found
119 | """
120 | root = os.path.expanduser(root)
121 | files = list(
122 | filter(
123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
124 | os.listdir(root)
125 | )
126 | )
127 |
128 | if prefix is True:
129 | files = [os.path.join(root, d) for d in files]
130 |
131 | return files
132 |
133 |
134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None):
135 | """Download a Google Drive file from and place it in root.
136 |
137 | Args:
138 | file_id (str): id of file to be downloaded
139 | root (str): Directory to place downloaded file in
140 | filename (str, optional): Name to save the file under. If None, use the id of the file.
141 | md5 (str, optional): MD5 checksum of the download. If None, do not check
142 | """
143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
144 | import requests
145 | url = "https://docs.google.com/uc?export=download"
146 |
147 | root = os.path.expanduser(root)
148 | if not filename:
149 | filename = file_id
150 | fpath = os.path.join(root, filename)
151 |
152 | makedir_exist_ok(root)
153 |
154 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
155 | print('Using downloaded and verified file: ' + fpath)
156 | else:
157 | session = requests.Session()
158 |
159 | response = session.get(url, params={'id': file_id}, stream=True)
160 | token = _get_confirm_token(response)
161 |
162 | if token:
163 | params = {'id': file_id, 'confirm': token}
164 | response = session.get(url, params=params, stream=True)
165 |
166 | _save_response_content(response, fpath)
167 |
168 |
169 | def _get_confirm_token(response):
170 | for key, value in response.cookies.items():
171 | if key.startswith('download_warning'):
172 | return value
173 |
174 | return None
175 |
176 |
177 | def _save_response_content(response, destination, chunk_size=32768):
178 | with open(destination, "wb") as f:
179 | pbar = tqdm(total=None)
180 | progress = 0
181 | for chunk in response.iter_content(chunk_size):
182 | if chunk: # filter out keep-alive new chunks
183 | f.write(chunk)
184 | progress += len(chunk)
185 | pbar.update(progress - pbar.n)
186 | pbar.close()
187 |
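A short sketch of the download helper above; the URL is a placeholder and the MD5 check is skipped when `md5` is None:

    import os
    from ddim.datasets.utils import download_url

    cache_dir = os.path.expanduser("~/.cache/quest_data")     # hypothetical location
    download_url(
        "https://example.com/data/archive.zip",               # placeholder URL
        root=cache_dir,
        filename="archive.zip",
        md5=None,                                              # pass an MD5 hex string to verify
    )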
--------------------------------------------------------------------------------
/ddim/datasets/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, str):  # torch._six was removed in recent PyTorch
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, 'transform') and self.transform is not None:
41 | body += self._format_transform_repr(self.transform,
42 | "Transforms: ")
43 | if hasattr(self, 'target_transform') and self.target_transform is not None:
44 | body += self._format_transform_repr(self.target_transform,
45 | "Target transforms: ")
46 | lines = [head] + [" " * self._repr_indent + line for line in body]
47 | return '\n'.join(lines)
48 |
49 | def _format_transform_repr(self, transform, head):
50 | lines = transform.__repr__().splitlines()
51 | return (["{}{}".format(head, lines[0])] +
52 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
53 |
54 | def extra_repr(self):
55 | return ""
56 |
57 |
58 | class StandardTransform(object):
59 | def __init__(self, transform=None, target_transform=None):
60 | self.transform = transform
61 | self.target_transform = target_transform
62 |
63 | def __call__(self, input, target):
64 | if self.transform is not None:
65 | input = self.transform(input)
66 | if self.target_transform is not None:
67 | target = self.target_transform(target)
68 | return input, target
69 |
70 | def _format_transform_repr(self, transform, head):
71 | lines = transform.__repr__().splitlines()
72 | return (["{}{}".format(head, lines[0])] +
73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
74 |
75 | def __repr__(self):
76 | body = [self.__class__.__name__]
77 | if self.transform is not None:
78 | body += self._format_transform_repr(self.transform,
79 | "Transform: ")
80 | if self.target_transform is not None:
81 | body += self._format_transform_repr(self.target_transform,
82 | "Target transform: ")
83 |
84 | return '\n'.join(body)
85 |
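For reference, a tiny runnable check of `StandardTransform`, which is what `VisionDataset` builds when separate `transform`/`target_transform` callables are passed:

    from PIL import Image
    from torchvision import transforms
    from ddim.datasets.vision import StandardTransform

    joint = StandardTransform(transform=transforms.ToTensor(), target_transform=None)
    img, target = joint(Image.new("RGB", (32, 32)), 3)   # dummy image and label
    print(img.shape, target)                             # torch.Size([3, 32, 32]) 3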
--------------------------------------------------------------------------------
/ddim/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/functions/__init__.py
--------------------------------------------------------------------------------
/ddim/functions/ckpt_util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
14 | }
15 | CKPT_MAP = {
16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
24 | }
25 | MD5_MAP = {
26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
34 | }
35 |
36 |
37 | def download(url, local_path, chunk_size=1024):
38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
39 | with requests.get(url, stream=True) as r:
40 | total_size = int(r.headers.get("content-length", 0))
41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
42 | with open(local_path, "wb") as f:
43 | for data in r.iter_content(chunk_size=chunk_size):
44 | if data:
45 | f.write(data)
46 | pbar.update(len(data))  # advance by the bytes actually written
47 |
48 |
49 | def md5_hash(path):
50 | with open(path, "rb") as f:
51 | content = f.read()
52 | return hashlib.md5(content).hexdigest()
53 |
54 |
55 | def get_ckpt_path(name, root=None, check=False):
56 | if 'church_outdoor' in name:
57 | name = name.replace('church_outdoor', 'church')
58 | assert name in URL_MAP
59 | # Modify the path when necessary
60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
61 | root = (
62 | root
63 | if root is not None
64 | else os.path.join(cachedir, "diffusion_models_converted")
65 | )
66 | path = os.path.join(root, CKPT_MAP[name])
67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
69 | download(URL_MAP[name], path)
70 | md5 = md5_hash(path)
71 | assert md5 == MD5_MAP[name], md5
72 | return path
73 |
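Typical use of the helper above; on first call it downloads the converted checkpoint into `$XDG_CACHE_HOME/diffusion_models_converted` (or `~/.cache/...`) and optionally verifies the MD5:

    from ddim.functions.ckpt_util import get_ckpt_path

    ckpt = get_ckpt_path("ema_cifar10", check=True)   # download on first use, then verify MD5
    print(ckpt)   # .../ema_diffusion_cifar10_model/model-790000.ckpt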
--------------------------------------------------------------------------------
/ddim/functions/denoising.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def compute_alpha(beta, t):
5 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
6 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
7 | return a
8 |
9 |
10 | def generalized_steps(x, seq, model, b, **kwargs):
11 | with torch.no_grad():
12 | n = x.size(0)
13 | seq_next = [-1] + list(seq[:-1])
14 | x0_preds = []
15 | xs = [x]
16 | for i, j in zip(reversed(seq), reversed(seq_next)):
17 | t = (torch.ones(n) * i).to(x.device)
18 | next_t = (torch.ones(n) * j).to(x.device)
19 | at = compute_alpha(b, t.long())
20 | at_next = compute_alpha(b, next_t.long())
21 | xt = xs[-1].to('cuda')
22 | et = model(xt, t)
23 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
24 | x0_preds.append(x0_t.to('cpu'))
25 | c1 = (
26 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
27 | )
28 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
29 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
30 | xs.append(xt_next.to('cpu'))
31 |
32 | return xs, x0_preds
33 |
34 |
35 | def ddpm_steps(x, seq, model, b, **kwargs):
36 | with torch.no_grad():
37 | n = x.size(0)
38 | seq_next = [-1] + list(seq[:-1])
39 | xs = [x]
40 | x0_preds = []
41 | betas = b
42 | for i, j in zip(reversed(seq), reversed(seq_next)):
43 | t = (torch.ones(n) * i).to(x.device)
44 | next_t = (torch.ones(n) * j).to(x.device)
45 | at = compute_alpha(betas, t.long())
46 | atm1 = compute_alpha(betas, next_t.long())
47 | beta_t = 1 - at / atm1
48 | x = xs[-1].to('cuda')
49 |
50 | output = model(x, t.float())
51 | e = output
52 |
53 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
54 | x0_from_e = torch.clamp(x0_from_e, -1, 1)
55 | x0_preds.append(x0_from_e.to('cpu'))
56 | mean_eps = (
57 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
58 | ) / (1.0 - at)
59 |
60 | mean = mean_eps
61 | noise = torch.randn_like(x)
62 | mask = 1 - (t == 0).float()
63 | mask = mask.view(-1, 1, 1, 1)
64 | logvar = beta_t.log()
65 | sample = mean + mask * torch.exp(0.5 * logvar) * noise
66 | xs.append(sample.to('cpu'))
67 | return xs, x0_preds
68 |
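A runnable sketch of driving `generalized_steps` with a placeholder epsilon-predictor (CUDA is assumed, since the helper moves tensors to the GPU internally). Each iteration applies the DDIM update xt_next = sqrt(at_next) * x0_t + c1 * z + c2 * et, so eta=0 gives the deterministic DDIM sampler:

    import torch
    from ddim.functions.denoising import generalized_steps

    class DummyEps(torch.nn.Module):                 # placeholder for the diffusion UNet
        def forward(self, x, t):
            return torch.zeros_like(x)

    device = "cuda"
    betas = torch.linspace(1e-4, 2e-2, 1000).to(device)   # linear beta schedule
    seq = list(range(0, 1000, 50))                         # 20 DDIM timesteps
    x_T = torch.randn(2, 3, 32, 32, device=device)         # start from pure noise
    xs, x0_preds = generalized_steps(x_T, seq, DummyEps().to(device), betas, eta=0.0)
    x_0 = xs[-1]                                           # final sample, returned on CPU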
--------------------------------------------------------------------------------
/ddim/functions/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def noise_estimation_loss(model,
5 | x0: torch.Tensor,
6 | t: torch.LongTensor,
7 | e: torch.Tensor,
8 | b: torch.Tensor, keepdim=False):
9 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
10 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
11 | output = model(x, t.float())
12 | if keepdim:
13 | return (e - output).square().sum(dim=(1, 2, 3))
14 | else:
15 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
16 |
17 |
18 | loss_registry = {
19 | 'simple': noise_estimation_loss,
20 | }
21 |
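The loss above is the standard epsilon-prediction objective; a minimal sketch with a placeholder model (CPU only):

    import torch
    from ddim.functions.losses import noise_estimation_loss

    model = lambda x, t: torch.zeros_like(x)      # placeholder epsilon-predictor
    betas = torch.linspace(1e-4, 2e-2, 1000)
    x0 = torch.randn(4, 3, 32, 32)                # clean images
    t = torch.randint(0, 1000, (4,))              # random timestep per sample
    e = torch.randn_like(x0)                      # the noise the model should predict
    loss = noise_estimation_loss(model, x0, t, e, betas)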
--------------------------------------------------------------------------------
/ddim/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/models/__init__.py
--------------------------------------------------------------------------------
/ddim/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ddim/models/__pycache__/diffusion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ddim/models/__pycache__/diffusion.cpython-37.pyc
--------------------------------------------------------------------------------
/ddim/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class EMAHelper(object):
5 | def __init__(self, mu=0.999):
6 | self.mu = mu
7 | self.shadow = {}
8 |
9 | def register(self, module):
10 | if isinstance(module, nn.DataParallel):
11 | module = module.module
12 | for name, param in module.named_parameters():
13 | if param.requires_grad:
14 | self.shadow[name] = param.data.clone()
15 |
16 | def update(self, module):
17 | if isinstance(module, nn.DataParallel):
18 | module = module.module
19 | for name, param in module.named_parameters():
20 | if param.requires_grad:
21 | self.shadow[name].data = (
22 | 1. - self.mu) * param.data + self.mu * self.shadow[name].data
23 |
24 | def ema(self, module):
25 | if isinstance(module, nn.DataParallel):
26 | module = module.module
27 | for name, param in module.named_parameters():
28 | if param.requires_grad:
29 | param.data.copy_(self.shadow[name].data)
30 |
31 | def ema_copy(self, module):
32 | if isinstance(module, nn.DataParallel):
33 | inner_module = module.module
34 | module_copy = type(inner_module)(
35 | inner_module.config).to(inner_module.config.device)
36 | module_copy.load_state_dict(inner_module.state_dict())
37 | module_copy = nn.DataParallel(module_copy)
38 | else:
39 | module_copy = type(module)(module.config).to(module.config.device)
40 | module_copy.load_state_dict(module.state_dict())
41 | # module_copy = copy.deepcopy(module)
42 | self.ema(module_copy)
43 | return module_copy
44 |
45 | def state_dict(self):
46 | return self.shadow
47 |
48 | def load_state_dict(self, state_dict):
49 | self.shadow = state_dict
50 |
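A sketch of the usual register/update/apply cycle for the helper above; the linear layer stands in for the diffusion model:

    import torch
    from ddim.models.ema import EMAHelper

    model = torch.nn.Linear(8, 8)                 # stand-in for the diffusion UNet
    ema_helper = EMAHelper(mu=0.999)
    ema_helper.register(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for _ in range(10):                           # abbreviated training loop
        loss = model(torch.randn(4, 8)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_helper.update(model)                  # shadow <- mu * shadow + (1 - mu) * param

    ema_helper.ema(model)                         # load the averaged weights for evaluation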
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: quest
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - python=3.8.5
8 | - pip=20.3
9 | - pytorch
10 | - pytorch-cuda=11.7
11 | - torchvision
12 | - torchaudio
13 | - numpy
14 | - pip:
15 | - albumentations==0.4.3
16 | - diffusers==0.3.0
17 | - pudb==2019.2
18 | - invisible-watermark
19 | - imageio==2.9.0
20 | - imageio-ffmpeg==0.4.2
21 | - test-tube>=0.7.5
22 | - streamlit>=0.73.1
23 | - torch-fidelity==0.3.0
24 | - torchmetrics==0.6.0
25 | - streamlit-drawable-canvas==0.8
26 | - einops==0.3.0
27 | - kornia==0.6.9
28 | - lmdb==1.3.0
29 | - natsort==8.3.1
30 | - omegaconf==2.1.1
31 | - opencv_python==4.1.2.30
32 | - opencv_python_headless==4.6.0.66
33 | - pandas==1.4.2
34 | - Pillow==9.0.1
35 | - pytorch_lightning==1.4.2
36 | - PyYAML==6.0
37 | - six==1.16.0
38 | - tqdm==4.64.0
39 | - transformers==4.22.2
40 | - git+https://github.com/openai/CLIP.git@main#egg=clip
41 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
42 | - -e .
43 |
--------------------------------------------------------------------------------
/imgs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/imgs/fig1.png
--------------------------------------------------------------------------------
/imgs/imagenet_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/imgs/imagenet_example.png
--------------------------------------------------------------------------------
/imgs/sd_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/imgs/sd_images.png
--------------------------------------------------------------------------------
/ldm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/__init__.py
--------------------------------------------------------------------------------
/ldm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
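A sketch of attaching one of these schedules to an optimizer via torch's LambdaLR, the same pattern used in ldm/models/diffusion/classifier.py later in this dump; the warm-up and cycle values here are illustrative, not taken from any config, and the schedules return a multiplier meant for a base_lr of 1.0:

    import torch
    from torch.optim.lr_scheduler import LambdaLR
    from ldm.lr_scheduler import LambdaLinearScheduler

    sched_fn = LambdaLinearScheduler(
        warm_up_steps=[100],
        f_min=[1.0], f_max=[1.0], f_start=[1e-6],
        cycle_lengths=[10_000_000_000],           # effectively a single, never-restarting cycle
    )
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = LambdaLR(optimizer, lr_lambda=sched_fn.schedule)
    # per training step: optimizer.step() followed by scheduler.step()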
--------------------------------------------------------------------------------
/ldm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/__init__.py
--------------------------------------------------------------------------------
/ldm/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/__pycache__/autoencoder.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__pycache__/ddim.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/__pycache__/ddpm.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
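The `disabled_train` override above is a small but useful trick: after assignment, later calls to `.train()` cannot flip the frozen diffusion model out of eval mode. A generic illustration (the explicit `__get__` binding is a minor variation of the plain attribute assignment used in `load_diffusion`):

    import torch.nn as nn

    def disabled_train(self, mode=True):          # same idea as the helper above
        return self

    frozen = nn.BatchNorm1d(8).eval()
    frozen.train = disabled_train.__get__(frozen) # bind as a method of this instance
    frozen.train()                                # ignored: the module stays in eval mode
    print(frozen.training)                        # False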
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 |
5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6 |
7 |
8 | class DPMSolverSampler(object):
9 | def __init__(self, model, **kwargs):
10 | super().__init__()
11 | self.model = model
12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14 |
15 | def register_buffer(self, name, attr):
16 | if type(attr) == torch.Tensor:
17 | if attr.device != torch.device("cuda"):
18 | attr = attr.to(torch.device("cuda"))
19 | setattr(self, name, attr)
20 |
21 | @torch.no_grad()
22 | def sample(self,
23 | S,
24 | batch_size,
25 | shape,
26 | conditioning=None,
27 | callback=None,
28 | normals_sequence=None,
29 | img_callback=None,
30 | quantize_x0=False,
31 | eta=0.,
32 | mask=None,
33 | x0=None,
34 | temperature=1.,
35 | noise_dropout=0.,
36 | score_corrector=None,
37 | corrector_kwargs=None,
38 | verbose=True,
39 | x_T=None,
40 | log_every_t=100,
41 | unconditional_guidance_scale=1.,
42 | unconditional_conditioning=None,
43 | # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
44 | **kwargs
45 | ):
46 | if conditioning is not None:
47 | if isinstance(conditioning, dict):
48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
49 | if cbs != batch_size:
50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
51 | else:
52 | if conditioning.shape[0] != batch_size:
53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
54 |
55 | # sampling
56 | C, H, W = shape
57 | size = (batch_size, C, H, W)
58 |
59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60 |
61 | device = self.model.betas.device
62 | if x_T is None:
63 | img = torch.randn(size, device=device)
64 | else:
65 | img = x_T
66 |
67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
68 |
69 | model_fn = model_wrapper(
70 | lambda x, t, c: self.model.apply_model(x, t, c),
71 | ns,
72 | model_type="noise",
73 | guidance_type="classifier-free",
74 | condition=conditioning,
75 | unconditional_condition=unconditional_conditioning,
76 | guidance_scale=unconditional_guidance_scale,
77 | )
78 |
79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
81 |
82 | return x.to(device), None
83 |
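A sketch of invoking the sampler from a text-to-image script; `model`, `c`, and `uc` are placeholders for a loaded LatentDiffusion instance and its conditioning tensors (e.g. from `model.get_learned_conditioning`), and the latent shape assumes a 256x256 image with an f=8, 4-channel first stage:

    from ldm.models.diffusion.dpm_solver import DPMSolverSampler

    sampler = DPMSolverSampler(model)             # `model` assumed to be a LatentDiffusion
    samples, _ = sampler.sample(
        S=20,                                     # DPM-Solver steps
        batch_size=4,
        shape=(4, 32, 32),                        # latent (C, H, W)
        conditioning=c,                           # assumed conditioning batch
        unconditional_conditioning=uc,            # assumed unconditional (empty-prompt) batch
        unconditional_guidance_scale=5.0,
    )
    images = model.decode_first_stage(samples)    # decode latents back to image space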
--------------------------------------------------------------------------------
/ldm/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__pycache__/attention.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/__pycache__/ema.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 | return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.qk_matmul = CrossQKMatMul(self.scale)
166 | self.smv_matmul = CrossSMVMatMul()
167 |
168 | self.to_out = nn.Sequential(
169 | nn.Linear(inner_dim, query_dim),
170 | nn.Dropout(dropout)
171 | )
172 |
173 | def forward(self, x, context=None, mask=None):
174 | h = self.heads
175 |
176 | q = self.to_q(x)
177 | context = default(context, x)
178 | k = self.to_k(context)
179 | v = self.to_v(context)
180 |
181 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
182 |
183 | # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
184 | sim = self.qk_matmul(q, k)
185 |
186 | if exists(mask):
187 | mask = rearrange(mask, 'b ... -> b (...)')
188 | max_neg_value = -torch.finfo(sim.dtype).max
189 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
190 | sim.masked_fill_(~mask, max_neg_value)
191 |
192 | # attention, what we cannot get enough of
193 | attn = sim.softmax(dim=-1)
194 |
195 | # out = einsum('b i j, b j d -> b i d', attn, v)
196 | out = self.smv_matmul(attn, v)
197 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
198 | return self.to_out(out)
199 |
200 |
201 | class CrossQKMatMul(nn.Module):
202 |
203 | def __init__(self, scale):
204 | super().__init__()
205 | self.scale = scale
206 |
207 | def forward(self, q, k):
208 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
209 | return sim
210 |
211 |
212 | class CrossSMVMatMul(nn.Module):
213 |
214 | def __init__(self):
215 | super().__init__()
216 |
217 | def forward(self, attn, v):
218 | out = einsum('b i j, b j d -> b i d', attn, v)
219 | return out
220 |
221 |
222 | class BasicTransformerBlock(nn.Module):
223 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
224 | super().__init__()
225 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
226 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
227 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
228 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
229 | self.norm1 = nn.LayerNorm(dim)
230 | self.norm2 = nn.LayerNorm(dim)
231 | self.norm3 = nn.LayerNorm(dim)
232 | self.checkpoint = checkpoint
233 |
234 | def forward(self, x, context=None):
235 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
236 |
237 | def _forward(self, x, context=None):
238 | x = self.attn1(self.norm1(x)) + x
239 | x = self.attn2(self.norm2(x), context=context) + x
240 | x = self.ff(self.norm3(x)) + x
241 | return x
242 |
243 |
244 | class SpatialTransformer(nn.Module):
245 | """
246 | Transformer block for image-like data.
247 | First, project the input (aka embedding)
248 | and reshape to b, t, d.
249 | Then apply standard transformer action.
250 | Finally, reshape to image
251 | """
252 | def __init__(self, in_channels, n_heads, d_head,
253 | depth=1, dropout=0., context_dim=None):
254 | super().__init__()
255 | self.in_channels = in_channels
256 | inner_dim = n_heads * d_head
257 | self.norm = Normalize(in_channels)
258 |
259 | self.proj_in = nn.Conv2d(in_channels,
260 | inner_dim,
261 | kernel_size=1,
262 | stride=1,
263 | padding=0)
264 |
265 | self.transformer_blocks = nn.ModuleList(
266 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
267 | for d in range(depth)]
268 | )
269 |
270 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
271 | in_channels,
272 | kernel_size=1,
273 | stride=1,
274 | padding=0))
275 |
276 | def forward(self, x, context=None):
277 | # note: if no context is given, cross-attention defaults to self-attention
278 | b, c, h, w = x.shape
279 | x_in = x
280 | x = self.norm(x)
281 | x = self.proj_in(x)
282 | x = rearrange(x, 'b c h w -> b (h w) c')
283 | for block in self.transformer_blocks:
284 | x = block(x, context)
285 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
286 | x = self.proj_out(x)
287 | return x + x_in
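A small self-contained check of the attention blocks above with arbitrary dimensions; when `context` is omitted, `CrossAttention` falls back to self-attention:

    import torch
    from ldm.modules.attention import CrossAttention, SpatialTransformer

    attn = CrossAttention(query_dim=64, context_dim=32, heads=4, dim_head=16)
    x = torch.randn(2, 100, 64)                   # (batch, query tokens, query_dim)
    ctx = torch.randn(2, 77, 32)                  # (batch, context tokens, context_dim)
    print(attn(x, context=ctx).shape)             # torch.Size([2, 100, 64])

    st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=32)
    feat = torch.randn(2, 64, 16, 16)             # image-like feature map
    print(st(feat, context=ctx).shape)            # torch.Size([2, 64, 16, 16])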
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
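
The helpers above cover the whole schedule pipeline used elsewhere in the repo: make_beta_schedule produces the per-step betas, make_ddim_timesteps and make_ddim_sampling_parameters subsample them for DDIM sampling, and timestep_embedding turns integer timesteps into sinusoidal features for the UNet. A minimal sketch of how they fit together (the schedule values, step counts and embedding width are illustrative, not prescriptive):

import numpy as np
import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, make_ddim_timesteps,
    make_ddim_sampling_parameters, timestep_embedding,
)

# 1000-step linear schedule with start/end values like those in the LDM configs below
betas = make_beta_schedule("linear", n_timestep=1000,
                           linear_start=0.0015, linear_end=0.0195)
alphas_cumprod = np.cumprod(1.0 - betas)

# pick 50 DDIM steps and derive per-step (sigma, alpha, alpha_prev) with eta = 0
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphacums=alphas_cumprod, ddim_timesteps=ddim_steps, eta=0.0, verbose=False)

# sinusoidal embeddings for a batch of 4 timesteps; 320 is just an example width
t = torch.tensor([1, 250, 500, 999])
emb = timestep_embedding(t, dim=320)
print(emb.shape)  # torch.Size([4, 320])
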
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
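
DiagonalGaussianDistribution is what the KL autoencoder's encoder returns: the parameters tensor stacks means and log-variances along the channel axis, so a 2*z_channels feature map yields a z_channels latent. A minimal sketch with random parameters (shapes are illustrative):

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# the encoder emits 2 * z_channels maps: means, then log-variances
params = torch.randn(2, 8, 32, 32)            # batch of 2, z_channels = 4
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()                        # stochastic latent, shape (2, 4, 32, 32)
z_det = posterior.mode()                      # deterministic latent (the mean)
kl = posterior.kl()                           # KL to a standard normal, one value per sample
print(z.shape, kl.shape)                      # torch.Size([2, 4, 32, 32]) torch.Size([2])
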
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | # print(key)
52 | # try:
53 | # print(self.m_name2s_name[key])
54 | # except:
55 | # print(self.m_name2s_name.keys())
56 | shadow_key = key.replace('.model.','.') if '.model.' in key else key
57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[shadow_key]].data)
58 | else:
59 | assert not key in self.m_name2s_name
60 |
61 | def store(self, parameters):
62 | """
63 | Save the current parameters for restoring later.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 | Args:
78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79 | updated with the stored parameters.
80 | """
81 | for c_param, param in zip(self.collected_params, parameters):
82 | param.data.copy_(c_param.data)
83 |
--------------------------------------------------------------------------------
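
LitEma keeps an exponential moving average of every trainable parameter in flat-named buffers: forward pulls the shadow copy toward the live weights, copy_to swaps the shadow weights into the model, and store/restore bracket an EMA-weight evaluation. A minimal usage sketch with a stand-in module (not the actual training loop in this repo):

import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(16, 16)                # stand-in for the diffusion model
ema = LitEma(model, decay=0.9999)

# after each optimizer step: update the shadow weights
ema(model)

# evaluate with EMA weights, then put the raw training weights back
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation / sampling here ...
ema.restore(model.parameters())
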
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hatchetProject/QuEST/c1c917f71844e17575b311c7d08339925bccce2b/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 | """Uses the BERT tokenizr model and add some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 | class FrozenCLIPEmbedder(AbstractEncoder):
138 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
140 | super().__init__()
141 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
142 | self.transformer = CLIPTextModel.from_pretrained(version)
143 | self.device = device
144 | self.max_length = max_length
145 | self.freeze()
146 |
147 | def freeze(self):
148 | self.transformer = self.transformer.eval()
149 | for param in self.parameters():
150 | param.requires_grad = False
151 |
152 | def forward(self, text):
153 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155 | tokens = batch_encoding["input_ids"].to(self.device)
156 | outputs = self.transformer(input_ids=tokens)
157 |
158 | z = outputs.last_hidden_state
159 | return z
160 |
161 | def encode(self, text):
162 | return self(text)
163 |
164 |
165 | class FrozenCLIPTextEmbedder(nn.Module):
166 | """
167 | Uses the CLIP transformer encoder for text.
168 | """
169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
170 | super().__init__()
171 | self.model, _ = clip.load(version, jit=False, device="cpu")
172 | self.device = device
173 | self.max_length = max_length
174 | self.n_repeat = n_repeat
175 | self.normalize = normalize
176 |
177 | def freeze(self):
178 | self.model = self.model.eval()
179 | for param in self.parameters():
180 | param.requires_grad = False
181 |
182 | def forward(self, text):
183 | tokens = clip.tokenize(text).to(self.device)
184 | z = self.model.encode_text(tokens)
185 | if self.normalize:
186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187 | return z
188 |
189 | def encode(self, text):
190 | z = self(text)
191 | if z.ndim==2:
192 | z = z[:, None, :]
193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194 | return z
195 |
196 |
197 | class FrozenClipImageEmbedder(nn.Module):
198 | """
199 | Uses the CLIP image encoder.
200 | """
201 | def __init__(
202 | self,
203 | model,
204 | jit=False,
205 | device='cuda' if torch.cuda.is_available() else 'cpu',
206 | antialias=False,
207 | ):
208 | super().__init__()
209 | self.model, _ = clip.load(name=model, device=device, jit=jit)
210 |
211 | self.antialias = antialias
212 |
213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
215 |
216 | def preprocess(self, x):
217 | # normalize to [0,1]
218 | x = kornia.geometry.resize(x, (224, 224),
219 | interpolation='bicubic',align_corners=True,
220 | antialias=self.antialias)
221 | x = (x + 1.) / 2.
222 | # renormalize according to clip
223 | x = kornia.enhance.normalize(x, self.mean, self.std)
224 | return x
225 |
226 | def forward(self, x):
227 | # x is assumed to be in range [-1,1]
228 | return self.model.encode_image(self.preprocess(x))
229 |
230 |
231 | if __name__ == "__main__":
232 | from ldm.util import count_params
233 | model = FrozenCLIPEmbedder()
234 | count_params(model, verbose=True)
--------------------------------------------------------------------------------
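
FrozenCLIPEmbedder is the text conditioner used by the Stable Diffusion configs: it tokenizes the prompt to 77 tokens and returns the CLIP text transformer's last hidden state for cross-attention. A minimal sketch (this downloads the openai/clip-vit-large-patch14 weights from Hugging Face on first use):

import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder(device="cpu")   # weights are fetched on first use
with torch.no_grad():
    z = embedder.encode(["a photograph of an astronaut riding a horse"])
print(z.shape)  # torch.Size([1, 77, 768]) -- 77 tokens, 768-dim CLIP text features
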
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 | import logging
9 |
10 | import multiprocessing as mp
11 | from threading import Thread
12 | from queue import Queue
13 |
14 | from inspect import isfunction
15 | from PIL import Image, ImageDraw, ImageFont
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def log_txt_as_img(wh, xc, size=10):
21 | # wh a tuple of (width, height)
22 | # xc a list of captions to plot
23 | b = len(xc)
24 | txts = list()
25 | for bi in range(b):
26 | txt = Image.new("RGB", wh, color="white")
27 | draw = ImageDraw.Draw(txt)
28 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
29 | nc = int(40 * (wh[0] / 256))
30 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
31 |
32 | try:
33 | draw.text((0, 0), lines, fill="black", font=font)
34 | except UnicodeEncodeError:
35 | print("Cant encode string for logging. Skipping.")
36 |
37 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
38 | txts.append(txt)
39 | txts = np.stack(txts)
40 | txts = torch.tensor(txts)
41 | return txts
42 |
43 |
44 | def ismap(x):
45 | if not isinstance(x, torch.Tensor):
46 | return False
47 | return (len(x.shape) == 4) and (x.shape[1] > 3)
48 |
49 |
50 | def isimage(x):
51 | if not isinstance(x, torch.Tensor):
52 | return False
53 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
54 |
55 |
56 | def exists(x):
57 | return x is not None
58 |
59 |
60 | def default(val, d):
61 | if exists(val):
62 | return val
63 | return d() if isfunction(d) else d
64 |
65 |
66 | def mean_flat(tensor):
67 | """
68 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
69 | Take the mean over all non-batch dimensions.
70 | """
71 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
72 |
73 |
74 | def count_params(model, verbose=False):
75 | total_params = sum(p.numel() for p in model.parameters())
76 | if verbose:
77 | logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
78 | return total_params
79 |
80 |
81 | def instantiate_from_config(config):
82 | if not "target" in config:
83 | if config == '__is_first_stage__':
84 | return None
85 | elif config == "__is_unconditional__":
86 | return None
87 | raise KeyError("Expected key `target` to instantiate.")
88 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
89 |
90 |
91 | def get_obj_from_str(string, reload=False):
92 | module, cls = string.rsplit(".", 1)
93 | if reload:
94 | module_imp = importlib.import_module(module)
95 | importlib.reload(module_imp)
96 | return getattr(importlib.import_module(module, package=None), cls)
97 |
98 |
99 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
100 | # create dummy dataset instance
101 |
102 | # run prefetching
103 | if idx_to_fn:
104 | res = func(data, worker_id=idx)
105 | else:
106 | res = func(data)
107 | Q.put([idx, res])
108 | Q.put("Done")
109 |
110 |
111 | def parallel_data_prefetch(
112 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
113 | ):
114 | # if target_data_type not in ["ndarray", "list"]:
115 | # raise ValueError(
116 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
117 | # )
118 | if isinstance(data, np.ndarray) and target_data_type == "list":
119 | raise ValueError("list expected but function got ndarray.")
120 | elif isinstance(data, abc.Iterable):
121 | if isinstance(data, dict):
122 | print(
123 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
124 | )
125 | data = list(data.values())
126 | if target_data_type == "ndarray":
127 | data = np.asarray(data)
128 | else:
129 | data = list(data)
130 | else:
131 | raise TypeError(
132 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
133 | )
134 |
135 | if cpu_intensive:
136 | Q = mp.Queue(1000)
137 | proc = mp.Process
138 | else:
139 | Q = Queue(1000)
140 | proc = Thread
141 | # spawn processes
142 | if target_data_type == "ndarray":
143 | arguments = [
144 | [func, Q, part, i, use_worker_id]
145 | for i, part in enumerate(np.array_split(data, n_proc))
146 | ]
147 | else:
148 | step = (
149 | int(len(data) / n_proc + 1)
150 | if len(data) % n_proc != 0
151 | else int(len(data) / n_proc)
152 | )
153 | arguments = [
154 | [func, Q, part, i, use_worker_id]
155 | for i, part in enumerate(
156 | [data[i: i + step] for i in range(0, len(data), step)]
157 | )
158 | ]
159 | processes = []
160 | for i in range(n_proc):
161 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
162 | processes += [p]
163 |
164 | # start processes
165 | print(f"Start prefetching...")
166 | import time
167 |
168 | start = time.time()
169 | gather_res = [[] for _ in range(n_proc)]
170 | try:
171 | for p in processes:
172 | p.start()
173 |
174 | k = 0
175 | while k < n_proc:
176 | # get result
177 | res = Q.get()
178 | if res == "Done":
179 | k += 1
180 | else:
181 | gather_res[res[0]] = res[1]
182 |
183 | except Exception as e:
184 | print("Exception: ", e)
185 | for p in processes:
186 | p.terminate()
187 |
188 | raise e
189 | finally:
190 | for p in processes:
191 | p.join()
192 | print(f"Prefetching complete. [{time.time() - start} sec.]")
193 |
194 | if target_data_type == 'ndarray':
195 | if not isinstance(gather_res[0], np.ndarray):
196 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
197 |
198 | # order outputs
199 | return np.concatenate(gather_res, axis=0)
200 | elif target_data_type == 'list':
201 | out = []
202 | for r in gather_res:
203 | out.extend(r)
204 | return out
205 | else:
206 | return gather_res
207 |
--------------------------------------------------------------------------------
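
instantiate_from_config is the glue for every YAML below: the "target" string is resolved to a class via get_obj_from_str and the "params" mapping is passed as keyword arguments to its constructor. A minimal sketch with a plain torch layer as the target (the config dict here is illustrative, not one of the repo's):

from ldm.util import instantiate_from_config

# any importable class works as "target"; "params" become keyword arguments
config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
layer = instantiate_from_config(config)
print(layer)  # Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1))
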
/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
--------------------------------------------------------------------------------
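
The "f16" in the name follows from ch_mult: five entries mean four downsampling stages, so a 256x256 input maps to a 16x16 latent with z_channels: 16 channels. A quick check of that arithmetic:

# downsampling factor of the autoencoder = 2 ** (number of ch_mult entries - 1)
ch_mult = [1, 1, 2, 2, 4]
f = 2 ** (len(ch_mult) - 1)                              # 16
resolution, z_channels = 256, 16
print(f, (z_channels, resolution // f, resolution // f))  # 16 (16, 16, 16)
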
/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
--------------------------------------------------------------------------------
/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/cin256-v2/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
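
Here the conditioning is a learned class embedding fed through cross-attention; n_classes: 1001 leaves room for the 1000 ImageNet labels plus one extra index, which this kind of model typically reserves as a "null" class for classifier-free guidance. A minimal sketch of what the cond stage produces (the class ids are arbitrary examples):

import torch
from ldm.modules.encoders.modules import ClassEmbedder

embedder = ClassEmbedder(embed_dim=512, n_classes=1001, key="class_label")
batch = {"class_label": torch.tensor([207, 360, 387, 974])}   # arbitrary ImageNet ids
c = embedder(batch)
print(c.shape)  # torch.Size([4, 1, 512]) -- one 512-dim token per sample for cross-attention
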
/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
--------------------------------------------------------------------------------
/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
--------------------------------------------------------------------------------
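
The UNet's in_channels: 7 is the concat-mode conditioning stack: 3 channels of noisy latent, 3 channels of the encoded masked image (the first stage is an f4 VQ model with z_channels: 3), and 1 channel for the mask resized to latent resolution. A small sketch of that layout, assuming this usual channel ordering:

import torch

z_noisy  = torch.randn(1, 3, 64, 64)   # noisy latent (z_channels = 3)
z_masked = torch.randn(1, 3, 64, 64)   # encoded masked image, same latent space
mask     = torch.rand(1, 1, 64, 64)    # mask downsampled to latent resolution
unet_in  = torch.cat([z_noisy, z_masked, mask], dim=1)
print(unet_in.shape[1])  # 7 -> matches in_channels: 7 above
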
/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
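For this text-to-image model the caption is encoded by a BERTEmbedder and injected through cross-attention, so the embedder width (n_embed: 640) has to equal the UNet's context_dim. A small consistency check over the config above, as a sketch:

from omegaconf import OmegaConf

cfg = OmegaConf.load("models/ldm/text2img256/config.yaml")
unet = cfg.model.params.unet_config.params
embedder = cfg.model.params.cond_stage_config.params
assert unet.context_dim == embedder.n_embed == 640    # caption embedding width feeds cross-attention
assert unet.use_spatial_transformer                   # spatial transformer blocks consume the context
assert cfg.model.params.conditioning_key == "crossattn"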
--------------------------------------------------------------------------------
/qdiff/__init__.py:
--------------------------------------------------------------------------------
1 | from qdiff.block_recon import block_reconstruction
2 | from qdiff.layer_recon import layer_reconstruction
3 | from qdiff.quant_block import BaseQuantBlock, QuantSMVMatMul, QuantQKMatMul, QuantBasicTransformerBlock
4 | from qdiff.quant_layer import QuantModule
5 | from qdiff.quant_model import QuantModel
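
The package exposes the post-training quantization entry points used by the qdiff/post_layer_recon_*.py driver scripts. A sketch of the usual reconstruction dispatch; qnn (a QuantModel wrapping the diffusion UNet) and cali_data (the calibration tensors) are placeholders here:

import torch.nn as nn
from qdiff import QuantModel, QuantModule, BaseQuantBlock
from qdiff import layer_reconstruction, block_reconstruction

def recon_model(module: nn.Module, qnn: QuantModel, cali_data, outpath="./tmp"):
    """Reconstruct quantized blocks as a unit and remaining standalone layers one by one."""
    for name, child in module.named_children():
        if isinstance(child, QuantModule):
            layer_reconstruction(qnn, child, cali_data, outpath=outpath)
        elif isinstance(child, BaseQuantBlock):
            block_reconstruction(qnn, child, cali_data, outpath=outpath)
        else:
            recon_model(child, qnn, cali_data, outpath)

recon_model(qnn.model, qnn, cali_data)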
--------------------------------------------------------------------------------
/qdiff/adaptive_rounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import logging
4 | from qdiff.quant_layer import UniformAffineQuantizer, round_ste, floor_ste
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class AdaRoundQuantizer(nn.Module):
10 | """
11 | Adaptive Rounding Quantizer, used to optimize the rounding policy
12 | by reconstructing the intermediate output.
13 | Based on
14 | Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568
15 |
16 | :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer
17 | :param round_mode: controls the forward pass in this quantizer
18 |     :param weight_tensor: weight tensor used to initialize alpha
19 | """
20 |
21 |     def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_hard_sigmoid'):
22 | super(AdaRoundQuantizer, self).__init__()
23 | # copying all attributes from UniformAffineQuantizer
24 | self.n_bits = uaq.n_bits
25 | self.sym = uaq.sym
26 | self.delta = nn.Parameter(uaq.delta)
27 | self.zero_point = nn.Parameter(uaq.zero_point)
28 | self.n_levels = uaq.n_levels
29 |
30 | self.round_mode = round_mode
31 | self.alpha = None
32 | self.soft_targets = False
33 |
34 | # params for sigmoid function
35 | self.gamma, self.zeta = -0.1, 1.1
36 | self.beta = 2/3
37 | self.init_alpha(x=weight_tensor.clone())
38 |
39 | def forward(self, x):
40 | if self.round_mode == 'nearest':
41 | x_int = round_ste(x / self.delta)
42 | elif self.round_mode == 'nearest_ste':
43 | x_int = round_ste(x / self.delta)
44 | elif self.round_mode == 'stochastic':
45 | # x_floor = torch.floor(x / self.delta)
46 | x_floor = floor_ste(x / self.delta)
47 | rest = (x / self.delta) - x_floor # rest of rounding
48 | x_int = x_floor + torch.bernoulli(rest)
49 | logger.info('Draw stochastic sample')
50 | elif self.round_mode == 'learned_hard_sigmoid':
51 | # x_floor = torch.floor(x / self.delta)
52 | x_floor = floor_ste(x / self.delta)
53 | if self.soft_targets:
54 | x_int = x_floor + self.get_soft_targets()
55 | else:
56 | x_int = x_floor + (self.alpha >= 0).float()
57 | else:
58 | raise ValueError('Wrong rounding mode')
59 |
60 | x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1)
61 | x_float_q = (x_quant - self.zero_point) * self.delta
62 |
63 | return x_float_q
64 |
65 | def get_soft_targets(self):
66 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1)
67 |
68 | def init_alpha(self, x: torch.Tensor):
69 | x_floor = torch.floor(x / self.delta)
70 | if self.round_mode == 'learned_hard_sigmoid':
71 | # logger.info('Init alpha to be FP32')
72 | rest = (x / self.delta) - x_floor # rest of rounding [0, 1)
73 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest
74 | self.alpha = nn.Parameter(alpha)
75 | else:
76 | raise NotImplementedError
77 |
78 | def extra_repr(self):
79 | s = 'bit={n_bits}, symmetric={sym}, round_mode={round_mode}'
80 | return s.format(**self.__dict__)
81 |
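A standalone numeric illustration (not part of the repository) of the rectified sigmoid used above: h(alpha) = clamp(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1) with gamma = -0.1, zeta = 1.1. init_alpha inverts this mapping so the soft rounding target starts exactly at the fractional part of w / delta, i.e. soft quantization initially reproduces the FP weight; after reconstruction the rounding is hardened to (alpha >= 0).

import torch

gamma, zeta = -0.1, 1.1
w, delta = torch.tensor([0.37]), 0.1               # w / delta = 3.7 -> floor 3, rest ~0.7
rest = w / delta - torch.floor(w / delta)
alpha = -torch.log((zeta - gamma) / (rest - gamma) - 1)   # same formula as init_alpha
h = torch.clamp(torch.sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1)
print(h)                                           # tensor([0.7000]): soft target equals the residual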
--------------------------------------------------------------------------------
/qdiff/block_recon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import linklink as link
3 | import logging
4 | from qdiff.quant_layer import QuantModule, StraightThrough, lp_loss
5 | from qdiff.quant_model import QuantModel
6 | from qdiff.quant_block import BaseQuantBlock
7 | from qdiff.adaptive_rounding import AdaRoundQuantizer
8 | from qdiff.utils import save_grad_data, save_inp_oup_data
9 | import os
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def block_reconstruction(model: QuantModel, block: BaseQuantBlock, cali_data: torch.Tensor,
15 | batch_size: int = 32, iters: int = 20000, weight: float = 0.01, opt_mode: str = 'mse',
16 | asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2),
17 | warmup: float = 0.0, act_quant: bool = False, lr: float = 4e-5, p: float = 2.0,
18 | multi_gpu: bool = False, cond: bool = False, is_sm: bool = False, outpath: str = None):
19 | """
20 | Block reconstruction to optimize the output from each block.
21 |
22 | :param model: QuantModel
23 | :param block: BaseQuantBlock that needs to be optimized
24 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
25 | :param batch_size: mini-batch size for reconstruction
26 |     :param iters: number of optimization iterations for reconstruction
27 |     :param weight: the weight of the rounding regularization term
28 |     :param opt_mode: optimization mode
29 |     :param asym: asymmetric optimization as designed in AdaRound: use the quantized input to reconstruct the FP output
30 |     :param include_act_func: optimize the output after the activation function
31 |     :param b_range: temperature range
32 |     :param warmup: proportion of iterations during which the temperature is not annealed
33 |     :param act_quant: whether to use activation quantization
34 |     :param lr: learning rate for learning the activation step size (delta)
35 |     :param p: order of the L_p norm used in the reconstruction loss
36 |     :param multi_gpu: whether to use multiple GPUs; if enabled, gradients must be synchronized
37 |     :param cond: whether generation is conditional
38 |     :param is_sm: avoid OOM when caching the n^2 attention matrix for large n
39 | """
40 | model.set_quant_state(False, False)
41 | block.set_quant_state(True, act_quant)
42 | round_mode = 'learned_hard_sigmoid'
43 |
44 | if not include_act_func:
45 | org_act_func = block.activation_function
46 | block.activation_function = StraightThrough()
47 |
48 | if not act_quant:
49 | # Replace weight quantizer to AdaRoundQuantizer
50 | for name, module in block.named_modules():
51 | if isinstance(module, QuantModule):
52 | if module.split != 0:
53 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode,
54 | weight_tensor=module.org_weight.data[:, :module.split, ...])
55 | module.weight_quantizer_0 = AdaRoundQuantizer(uaq=module.weight_quantizer_0, round_mode=round_mode,
56 | weight_tensor=module.org_weight.data[:, module.split:, ...])
57 | else:
58 | module.weight_quantizer = AdaRoundQuantizer(uaq=module.weight_quantizer, round_mode=round_mode,
59 | weight_tensor=module.org_weight.data)
60 | module.weight_quantizer.soft_targets = True
61 | if module.split != 0:
62 | module.weight_quantizer_0.soft_targets = True
63 |
64 | # Set up optimizer
65 | opt_params = []
66 | for name, module in block.named_modules():
67 | if isinstance(module, QuantModule):
68 | opt_params += [module.weight_quantizer.alpha]
69 | if module.split != 0:
70 | opt_params += [module.weight_quantizer_0.alpha]
71 | optimizer = torch.optim.Adam(opt_params)
72 | scheduler = None
73 | else:
74 | # Use UniformAffineQuantizer to learn delta
75 | if hasattr(block.act_quantizer, 'delta') and block.act_quantizer.delta is not None:
76 | opt_params = [block.act_quantizer.delta]
77 | else:
78 | opt_params = []
79 |
80 | if hasattr(block, 'attn1'):
81 | opt_params += [
82 | block.attn1.act_quantizer_q.delta,
83 | block.attn1.act_quantizer_k.delta,
84 | block.attn1.act_quantizer_v.delta,
85 | block.attn2.act_quantizer_q.delta,
86 | block.attn2.act_quantizer_k.delta,
87 | block.attn2.act_quantizer_v.delta]
88 | if block.attn1.act_quantizer_w.n_bits != 16:
89 | opt_params += [block.attn1.act_quantizer_w.delta]
90 | if block.attn2.act_quantizer_w.n_bits != 16:
91 | opt_params += [block.attn2.act_quantizer_w.delta]
92 | if hasattr(block, 'act_quantizer_q'):
93 | opt_params += [
94 | block.act_quantizer_q.delta,
95 | block.act_quantizer_k.delta]
96 | if hasattr(block, 'act_quantizer_w'):
97 | opt_params += [block.act_quantizer_v.delta]
98 | if block.act_quantizer_w.n_bits != 16:
99 | opt_params += [block.act_quantizer_w.delta]
100 |
101 | for name, module in block.named_modules():
102 | if isinstance(module, QuantModule):
103 | if module.act_quantizer.delta is not None:
104 | opt_params += [module.act_quantizer.delta]
105 | if module.split != 0 and module.act_quantizer_0.delta is not None:
106 | opt_params += [module.act_quantizer_0.delta]
107 |
108 | optimizer = torch.optim.Adam(opt_params, lr=lr)
109 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.)
110 |
111 | loss_mode = 'none' if act_quant else 'relaxation'
112 | rec_loss = opt_mode
113 |
114 | loss_func = LossFunction(block, round_loss=loss_mode, weight=weight, max_count=iters, rec_loss=rec_loss,
115 | b_range=b_range, decay_start=0, warmup=warmup, p=p)
116 |
117 | # Save data before optimizing the rounding
118 |     logger.info(f"cond {cond}")
119 | num_split = 10
120 | b_size = cali_data[0].shape[0] // num_split
121 | for k in range(num_split):
122 | logger.info(f"Saving {num_split} intermediate results to disk to avoid OOM")
123 | if cond:
124 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size], cali_data[2][k*b_size:(k+1)*b_size])
125 | else:
126 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size])
127 | cached_inps, cached_outs = save_inp_oup_data(
128 | model, block, cali_data_t, asym, act_quant, batch_size=8, keep_gpu=False, cond=cond, is_sm=is_sm)
129 | cached_path = os.path.join(outpath, 'tmp_cached/')
130 | if not os.path.exists(cached_path):
131 | os.makedirs(cached_path)
132 | torch.save(cached_inps, os.path.join(cached_path, f'cached_inps_t{k}.pt'))
133 | torch.save(cached_outs, os.path.join(cached_path, f'cached_outs_t{k}.pt'))
134 | # cached_inps, cached_outs = save_inp_oup_data(
135 | # model, block, cali_data, asym, act_quant, 8, keep_gpu=False, cond=cond, is_sm=is_sm)
136 |
137 | if opt_mode != 'mse':
138 | cached_grads = save_grad_data(model, block, cali_data, act_quant, batch_size=batch_size)
139 | else:
140 | cached_grads = None
141 | device = 'cuda'
142 |
143 | for k in range(num_split):
144 | cached_inps = torch.load(os.path.join(cached_path, f'cached_inps_t{k}.pt'))
145 | cached_outs = torch.load(os.path.join(cached_path, f'cached_outs_t{k}.pt'))
146 | for i in range(iters // num_split):
147 | if isinstance(cached_inps, list):
148 | idx = torch.randperm(cached_inps[0].size(0))[:batch_size]
149 | cur_x = cached_inps[0][idx].to(device)
150 | cur_t = cached_inps[1][idx].to(device)
151 | cur_inp = (cur_x, cur_t)
152 | else:
153 | idx = torch.randperm(cached_inps.size(0))[:batch_size]
154 | cur_inp = cached_inps[idx].to(device)
155 | cur_out = cached_outs[idx].to(device)
156 | cur_grad = cached_grads[idx].to(device) if opt_mode != 'mse' else None
157 |
158 | optimizer.zero_grad()
159 | if isinstance(cur_inp, tuple):
160 | out_quant = block(cur_inp[0], cur_inp[1])
161 | else:
162 | out_quant = block(cur_inp)
163 |
164 | err = loss_func(out_quant, cur_out, cur_grad)
165 | err.backward(retain_graph=True)
166 | if multi_gpu:
167 | raise NotImplementedError
168 | # for p in opt_params:
169 | # link.allreduce(p.grad)
170 | optimizer.step()
171 | if scheduler:
172 | scheduler.step()
173 | torch.cuda.empty_cache()
174 |
175 | # Finish optimization, use hard rounding.
176 | for name, module in block.named_modules():
177 | if isinstance(module, QuantModule):
178 | module.weight_quantizer.soft_targets = False
179 | if module.split != 0:
180 | module.weight_quantizer_0.soft_targets = False
181 |
182 | # Reset original activation function
183 | if not include_act_func:
184 | block.activation_function = org_act_func
185 |
186 |
187 | class LossFunction:
188 | def __init__(self,
189 | block: BaseQuantBlock,
190 | round_loss: str = 'relaxation',
191 | weight: float = 1.,
192 | rec_loss: str = 'mse',
193 | max_count: int = 2000,
194 | b_range: tuple = (10, 2),
195 | decay_start: float = 0.0,
196 | warmup: float = 0.0,
197 | p: float = 2.):
198 |
199 | self.block = block
200 | self.round_loss = round_loss
201 | self.weight = weight
202 | self.rec_loss = rec_loss
203 | self.loss_start = max_count * warmup
204 | self.p = p
205 |
206 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start,
207 | start_b=b_range[0], end_b=b_range[1])
208 | self.count = 0
209 |
210 | def __call__(self, pred, tgt, grad=None):
211 | """
212 | Compute the total loss for adaptive rounding:
213 | rec_loss is the quadratic output reconstruction loss, round_loss is
214 | a regularization term to optimize the rounding policy
215 |
216 | :param pred: output from quantized model
217 | :param tgt: output from FP model
218 | :param grad: gradients to compute fisher information
219 |         :return: total loss
220 | """
221 | self.count += 1
222 | if self.rec_loss == 'mse':
223 | rec_loss = lp_loss(pred, tgt, p=self.p)
224 | elif self.rec_loss == 'fisher_diag':
225 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean()
226 | elif self.rec_loss == 'fisher_full':
227 | a = (pred - tgt).abs()
228 | grad = grad.abs()
229 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1)
230 | rec_loss = (batch_dotprod * a * grad).mean() / 100
231 | else:
232 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))
233 |
234 | b = self.temp_decay(self.count)
235 | if self.count < self.loss_start or self.round_loss == 'none':
236 | b = round_loss = 0
237 | elif self.round_loss == 'relaxation':
238 | round_loss = 0
239 | for name, module in self.block.named_modules():
240 | if isinstance(module, QuantModule):
241 | round_vals = module.weight_quantizer.get_soft_targets()
242 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
243 | else:
244 | raise NotImplementedError
245 |
246 | total_loss = rec_loss + round_loss
247 | if self.count % 500 == 0:
248 | logger.info('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
249 | float(total_loss), float(rec_loss), float(round_loss), b, self.count))
250 | return total_loss
251 |
252 |
253 | class LinearTempDecay:
254 | def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
255 | self.t_max = t_max
256 | self.start_decay = rel_start_decay * t_max
257 | self.start_b = start_b
258 | self.end_b = end_b
259 |
260 | def __call__(self, t):
261 | """
262 |         Linear annealing scheduler for temperature b.
263 | :param t: the current time step
264 | :return: scheduled temperature
265 | """
266 | if t < self.start_decay:
267 | return self.start_b
268 | else:
269 | rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
270 | return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
271 |
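How the temperature b shapes the rounding regularizer round_loss = weight * sum(1 - |2h - 1|^b): for large b the term is nearly flat in h, so early iterations are dominated by the reconstruction loss, while LinearTempDecay anneals b linearly from b_range[0] to b_range[1] and the penalty increasingly pushes the soft targets toward 0 or 1. A small standalone gradient check, illustrative and mirroring the expression in LossFunction:

import torch

def round_reg(h, b, weight=0.01):
    return weight * (1 - ((h - 0.5).abs() * 2).pow(b)).sum()

for b in (20.0, 2.0):                              # b_range=(20, 2) in block_reconstruction
    h = torch.tensor([0.7], requires_grad=True)    # a soft rounding target away from {0, 1}
    round_reg(h, b).backward()
    print(b, h.grad.item())
# b=20: gradient ~ -1e-8, the penalty is effectively flat and reconstruction dominates
# b=2:  gradient ~ -1.6e-2, h is pushed decisively toward 1 (binary rounding)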
--------------------------------------------------------------------------------
/qdiff/layer_recon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import linklink as link
3 | import logging
4 | from qdiff.quant_layer import QuantModule, StraightThrough, lp_loss
5 | from qdiff.quant_model import QuantModel
6 | from qdiff.block_recon import LinearTempDecay
7 | from qdiff.adaptive_rounding import AdaRoundQuantizer
8 | from qdiff.utils import save_grad_data, save_inp_oup_data
9 | import os
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def layer_reconstruction(model: QuantModel, layer: QuantModule, cali_data: torch.Tensor,
15 | batch_size: int = 32, iters: int = 20000, weight: float = 0.001, opt_mode: str = 'mse',
16 | asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2),
17 | warmup: float = 0.0, act_quant: bool = False, lr: float = 4e-5, p: float = 2.0,
18 | multi_gpu: bool = False, cond: bool = False, is_sm: bool = False, outpath: str = None):
19 | """
20 |     Layer reconstruction to optimize the output from each layer.
21 |
22 | :param model: QuantModel
23 | :param layer: QuantModule that needs to be optimized
24 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound
25 | :param batch_size: mini-batch size for reconstruction
26 |     :param iters: number of optimization iterations for reconstruction
27 |     :param weight: the weight of the rounding regularization term
28 |     :param opt_mode: optimization mode
29 |     :param asym: asymmetric optimization as designed in AdaRound: use the quantized input to reconstruct the FP output
30 |     :param include_act_func: optimize the output after the activation function
31 |     :param b_range: temperature range
32 |     :param warmup: proportion of iterations during which the temperature is not annealed
33 |     :param act_quant: whether to use activation quantization
34 |     :param lr: learning rate for learning the activation step size (delta)
35 |     :param p: order of the L_p norm used in the reconstruction loss
36 |     :param multi_gpu: whether to use multiple GPUs; if enabled, gradients must be synchronized
37 |     :param cond: whether generation is conditional
38 |     :param is_sm: avoid OOM when caching the n^2 attention matrix for large n
39 | """
40 |
41 | model.set_quant_state(False, False)
42 | layer.set_quant_state(True, act_quant)
43 | round_mode = 'learned_hard_sigmoid'
44 |
45 | if not include_act_func:
46 | org_act_func = layer.activation_function
47 | layer.activation_function = StraightThrough()
48 |
49 | if not act_quant:
50 | # Replace weight quantizer to AdaRoundQuantizer
51 | if layer.split != 0:
52 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode,
53 | weight_tensor=layer.org_weight.data[:, :layer.split, ...])
54 | layer.weight_quantizer_0 = AdaRoundQuantizer(uaq=layer.weight_quantizer_0, round_mode=round_mode,
55 | weight_tensor=layer.org_weight.data[:, layer.split:, ...])
56 | else:
57 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode,
58 | weight_tensor=layer.org_weight.data)
59 | layer.weight_quantizer.soft_targets = True
60 |
61 | # Set up optimizer
62 | opt_params = [layer.weight_quantizer.alpha]
63 | if layer.split != 0:
64 | opt_params += [layer.weight_quantizer_0.alpha]
65 | optimizer = torch.optim.Adam(opt_params)
66 | scheduler = None
67 | else:
68 | # Use UniformAffineQuantizer to learn delta
69 | opt_params = [layer.act_quantizer.delta]
70 | if layer.split != 0 and layer.act_quantizer_0.delta is not None:
71 | opt_params += [layer.act_quantizer_0.delta]
72 | optimizer = torch.optim.Adam(opt_params, lr=lr)
73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.)
74 |
75 | loss_mode = 'none' if act_quant else 'relaxation'
76 | rec_loss = opt_mode
77 |
78 | loss_func = LossFunction(layer, round_loss=loss_mode, weight=weight,
79 | max_count=iters, rec_loss=rec_loss, b_range=b_range,
80 | decay_start=0, warmup=warmup, p=p)
81 |
82 | # Save data before optimizing the rounding
83 | num_split = 10
84 | b_size = cali_data[0].shape[0] // num_split
85 | for k in range(num_split):
86 | logger.info(f"Saving {num_split} intermediate results to disk to avoid OOM")
87 | if cond:
88 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size], cali_data[2][k*b_size:(k+1)*b_size])
89 | else:
90 | cali_data_t = (cali_data[0][k*b_size:(k+1)*b_size], cali_data[1][k*b_size:(k+1)*b_size])
91 | cached_inps, cached_outs = save_inp_oup_data(
92 | model, layer, cali_data_t, asym, act_quant, batch_size=8, keep_gpu=False, cond=cond, is_sm=is_sm)
93 | cached_path = os.path.join(outpath, 'tmp_cached/')
94 | if not os.path.exists(cached_path):
95 | os.makedirs(cached_path)
96 | torch.save(cached_inps, os.path.join(cached_path, f'cached_inps_t{k}.pt'))
97 | torch.save(cached_outs, os.path.join(cached_path, f'cached_outs_t{k}.pt'))
98 | # cached_inps, cached_outs = save_inp_oup_data(
99 | # model, layer, cali_data, asym, act_quant, 8, keep_gpu=False, cond=cond, is_sm=is_sm)
100 |
101 | if opt_mode != 'mse':
102 | cached_grads = save_grad_data(model, layer, cali_data, act_quant, batch_size=batch_size)
103 | else:
104 | cached_grads = None
105 | device = 'cuda'
106 | for k in range(num_split):
107 | cached_inps = torch.load(os.path.join(cached_path, f'cached_inps_t{k}.pt'))
108 | cached_outs = torch.load(os.path.join(cached_path, f'cached_outs_t{k}.pt'))
109 | for i in range(iters // num_split):
110 | idx = torch.randperm(cached_inps.size(0))[:batch_size]
111 | cur_inp = cached_inps[idx].to(device)
112 | cur_out = cached_outs[idx].to(device)
113 | cur_grad = cached_grads[idx] if opt_mode != 'mse' else None
114 |
115 | optimizer.zero_grad()
116 | out_quant = layer(cur_inp)
117 |
118 | err = loss_func(out_quant, cur_out, cur_grad)
119 | err.backward(retain_graph=True)
120 | if multi_gpu:
121 | raise NotImplementedError
122 | # for p in opt_params:
123 | # link.allreduce(p.grad)
124 | optimizer.step()
125 | if scheduler:
126 | scheduler.step()
127 |
128 | torch.cuda.empty_cache()
129 |
130 | # Finish optimization, use hard rounding.
131 | layer.weight_quantizer.soft_targets = False
132 | if layer.split != 0:
133 | layer.weight_quantizer_0.soft_targets = False
134 |
135 | # Reset original activation function
136 | if not include_act_func:
137 | layer.activation_function = org_act_func
138 |
139 |
140 | class LossFunction:
141 | def __init__(self,
142 | layer: QuantModule,
143 | round_loss: str = 'relaxation',
144 | weight: float = 1.,
145 | rec_loss: str = 'mse',
146 | max_count: int = 2000,
147 | b_range: tuple = (10, 2),
148 | decay_start: float = 0.0,
149 | warmup: float = 0.0,
150 | p: float = 2.):
151 |
152 | self.layer = layer
153 | self.round_loss = round_loss
154 | self.weight = weight
155 | self.rec_loss = rec_loss
156 | self.loss_start = max_count * warmup
157 | self.p = p
158 |
159 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start,
160 | start_b=b_range[0], end_b=b_range[1])
161 | self.count = 0
162 |
163 | def __call__(self, pred, tgt, grad=None):
164 | """
165 | Compute the total loss for adaptive rounding:
166 | rec_loss is the quadratic output reconstruction loss, round_loss is
167 | a regularization term to optimize the rounding policy
168 |
169 | :param pred: output from quantized model
170 | :param tgt: output from FP model
171 | :param grad: gradients to compute fisher information
172 |         :return: total loss
173 | """
174 | self.count += 1
175 | if self.rec_loss == 'mse':
176 | rec_loss = lp_loss(pred, tgt, p=self.p)
177 | elif self.rec_loss == 'fisher_diag':
178 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean()
179 | elif self.rec_loss == 'fisher_full':
180 | a = (pred - tgt).abs()
181 | grad = grad.abs()
182 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1)
183 | rec_loss = (batch_dotprod * a * grad).mean() / 100
184 | else:
185 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss))
186 |
187 | b = self.temp_decay(self.count)
188 | if self.count < self.loss_start or self.round_loss == 'none':
189 | b = round_loss = 0
190 | elif self.round_loss == 'relaxation':
191 | round_loss = 0
192 | round_vals = self.layer.weight_quantizer.get_soft_targets()
193 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
194 | else:
195 | raise NotImplementedError
196 |
197 | total_loss = rec_loss + round_loss
198 | if self.count % 500 == 0:
199 | logger.info('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
200 | float(total_loss), float(rec_loss), float(round_loss), b, self.count))
201 | return total_loss
202 |
203 |
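Calibration is typically run in two passes over the same module, which is what the act_quant flag above switches between. A sketch only: qnn, layer and cali_data are placeholders, and the real orchestration over the whole model lives in the qdiff/post_layer_recon_*.py scripts.

from qdiff import layer_reconstruction

kwargs = dict(cali_data=cali_data, iters=20000, b_range=(20, 2), warmup=0.2,
              opt_mode='mse', outpath='./tmp')     # illustrative values
# pass 1: wrap the weight quantizer in AdaRoundQuantizer and learn the rounding (alpha)
layer_reconstruction(qnn, layer, act_quant=False, **kwargs)
# pass 2: keep the hardened rounding, learn the activation step size (delta) with Adam + cosine annealing
layer_reconstruction(qnn, layer, act_quant=True, lr=4e-5, **kwargs)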
--------------------------------------------------------------------------------
/qdiff/quant_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | from qdiff.quant_block import get_specials, BaseQuantBlock
4 | from qdiff.quant_block import QuantBasicTransformerBlock, QuantResBlock
5 | from qdiff.quant_block import QuantQKMatMul, QuantSMVMatMul  # , QuantAttnBlock
6 | from qdiff.quant_layer import QuantModule, StraightThrough, TimewiseUniformQuantizer
7 | from ldm.modules.attention import BasicTransformerBlock
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class QuantModel(nn.Module):
13 |
14 | def __init__(self, model: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}, **kwargs):
15 | super().__init__()
16 | self.model = model
17 | timewise = kwargs.get('timewise', True)
18 | self.timewise = timewise
19 | list_timesteps = kwargs.get('list_timesteps', 50)
20 | self.timesteps = list_timesteps
21 | self.sm_abit = kwargs.get('sm_abit', 8)
22 | self.in_channels = model.in_channels
23 | if hasattr(model, 'image_size'):
24 | self.image_size = model.image_size
25 | self.specials = get_specials(act_quant_params['leaf_param'])
26 | self.quant_module_refactor(self.model, weight_quant_params, act_quant_params, timewise, list_timesteps)
27 | self.quant_block_refactor(self.model, weight_quant_params, act_quant_params, timewise, list_timesteps)
28 |
29 | def quant_module_refactor(self, module, weight_quant_params, act_quant_params, timewise, list_timesteps):
30 | """
31 |         Recursively replace normal layers (nn.Conv2d, nn.Conv1d, nn.Linear, etc.) with QuantModule
32 | :param module: nn.Module with nn.Conv2d, nn.Conv1d, or nn.Linear in its children
33 | :param weight_quant_params: quantization parameters like n_bits for weight quantizer
34 | :param act_quant_params: quantization parameters like n_bits for activation quantizer
35 | """
36 | prev_quantmodule = None
37 | for name, child_module in module.named_children():
38 | if isinstance(child_module, (nn.Conv2d, nn.Conv1d, nn.Linear)): # nn.Conv1d
39 | setattr(module, name, QuantModule(
40 | child_module, weight_quant_params, act_quant_params, timewise=timewise, list_timesteps=list_timesteps))
41 | prev_quantmodule = getattr(module, name)
42 |
43 | elif isinstance(child_module, StraightThrough):
44 | continue
45 |
46 | else:
47 | self.quant_module_refactor(child_module, weight_quant_params, act_quant_params, timewise, list_timesteps)
48 |
49 | def quant_block_refactor(self, module, weight_quant_params, act_quant_params, timewise, list_timesteps):
50 | for name, child_module in module.named_children():
51 | if type(child_module) in self.specials:
52 | if self.specials[type(child_module)] in [QuantBasicTransformerBlock]:
53 | setattr(module, name, self.specials[type(child_module)](child_module,
54 | act_quant_params, sm_abit=self.sm_abit, timewise=timewise, list_timesteps=list_timesteps))
55 | elif self.specials[type(child_module)] == QuantSMVMatMul:
56 | setattr(module, name, self.specials[type(child_module)](
57 | act_quant_params, sm_abit=self.sm_abit, timewise=timewise, list_timesteps=list_timesteps))
58 | elif self.specials[type(child_module)] == QuantQKMatMul:
59 | setattr(module, name, self.specials[type(child_module)](
60 | act_quant_params, timewise=timewise, list_timesteps=list_timesteps))
61 | else:
62 | setattr(module, name, self.specials[type(child_module)](child_module,
63 | act_quant_params))
64 | else:
65 | self.quant_block_refactor(child_module, weight_quant_params, act_quant_params, timewise, list_timesteps)
66 |
67 |
68 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
69 | for m in self.model.modules():
70 | if isinstance(m, (QuantModule, BaseQuantBlock)):
71 | m.set_quant_state(weight_quant, act_quant)
72 |
73 | def set_timestep(self, t):
74 | for m in self.model.modules():
75 | if isinstance(m, (QuantModule, BaseQuantBlock)):
76 | m.set_timestep(t)
77 |
78 | def forward(self, x, timesteps=None, context=None):
79 | return self.model(x, timesteps, context)
80 |
81 | def set_running_stat(self, running_stat: bool, sm_only=False):
82 | # Only consider timewise=True here
83 | for m in self.model.modules():
84 | if isinstance(m, QuantBasicTransformerBlock):
85 | if sm_only:
86 | m.attn1.act_quantizer_w.set_running_stat(running_stat)
87 | m.attn2.act_quantizer_w.set_running_stat(running_stat)
88 | else:
89 | m.attn1.act_quantizer_q.set_running_stat(running_stat)
90 | m.attn1.act_quantizer_k.set_running_stat(running_stat)
91 | m.attn1.act_quantizer_v.set_running_stat(running_stat)
92 | m.attn1.act_quantizer_w.set_running_stat(running_stat)
93 | m.attn2.act_quantizer_q.set_running_stat(running_stat)
94 | m.attn2.act_quantizer_k.set_running_stat(running_stat)
95 | m.attn2.act_quantizer_v.set_running_stat(running_stat)
96 | m.attn2.act_quantizer_w.set_running_stat(running_stat)
97 |             if isinstance(m, QuantQKMatMul):
98 | m.act_quantizer_q.set_running_stat(running_stat)
99 | m.act_quantizer_k.set_running_stat(running_stat)
100 |             if isinstance(m, QuantSMVMatMul):
101 | m.act_quantizer_v.set_running_stat(running_stat)
102 | m.act_quantizer_w.set_running_stat(running_stat)
103 | if isinstance(m, QuantModule) and not sm_only:
104 | m.set_running_stat(running_stat)
105 |
106 | def save_dict_params(self):
107 | for m in self.model.modules():
108 | if isinstance(m, TimewiseUniformQuantizer):
109 | m.save_dict_params()
110 |
111 | def load_dict_params(self):
112 | for m in self.model.modules():
113 | if isinstance(m, TimewiseUniformQuantizer):
114 | m.load_dict_params()
115 |
116 | def set_grad_ckpt(self, grad_ckpt: bool):
117 | for name, m in self.model.named_modules():
118 | if isinstance(m, (QuantBasicTransformerBlock, BasicTransformerBlock)):
119 | # logger.info(name)
120 | m.checkpoint = grad_ckpt
121 | # elif isinstance(m, QuantResBlock):
122 | # logger.info(name)
123 | # m.use_checkpoint = grad_ckpt
124 |
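A usage sketch of the calibration hooks defined above; qnn is a QuantModel wrapping a diffusion UNet, and cali_xs, cali_ts and cali_cs are per-timestep calibration latents, timestep tensors and conditioning, all placeholders. With timewise=True the activation quantizers (TimewiseUniformQuantizer) keep per-timestep parameters, so ranges are collected with running statistics one sampler step at a time and then saved:

import torch

qnn.set_quant_state(weight_quant=True, act_quant=True)
qnn.set_running_stat(True)                  # open running-stat collection for activation ranges
with torch.no_grad():
    for t in range(qnn.timesteps):          # one calibration pass per sampler timestep
        qnn.set_timestep(t)
        _ = qnn(cali_xs[t].cuda(), cali_ts[t].cuda(), cali_cs[t].cuda())
qnn.set_running_stat(False)                 # freeze the collected ranges
qnn.save_dict_params()                      # persist per-timestep quantizer parameters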
--------------------------------------------------------------------------------