├── train_eqvae
│   ├── ldm
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── dataset.cpython-38.pyc
│   │   │   │   └── __init__.cpython-38.pyc
│   │   │   └── dataset.py
│   │   ├── models
│   │   │   ├── diffusion
│   │   │   │   ├── __init__.py
│   │   │   │   ├── classifier.py
│   │   │   │   ├── ddim.py
│   │   │   │   └── plms.py
│   │   │   ├── __pycache__
│   │   │   │   ├── autoencoder.cpython-312.pyc
│   │   │   │   └── autoencoder.cpython-38.pyc
│   │   │   └── autoencoder.py
│   │   ├── modules
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── util.cpython-38.pyc
│   │   │   │   │   ├── model.cpython-312.pyc
│   │   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   │   ├── util.cpython-312.pyc
│   │   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   │   └── __init__.cpython-38.pyc
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   │   ├── distributions.cpython-312.pyc
│   │   │   │   │   └── distributions.cpython-38.pyc
│   │   │   │   └── distributions.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   │   ├── contperceptual.cpython-38.pyc
│   │   │   │   │   └── contperceptual.cpython-312.pyc
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   ├── image_degradation
│   │   │   │   ├── utils
│   │   │   │   │   └── test.png
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── bsrgan.cpython-38.pyc
│   │   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   │   ├── utils_image.cpython-38.pyc
│   │   │   │   │   └── bsrgan_light.cpython-38.pyc
│   │   │   │   └── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── attention.cpython-38.pyc
│   │   │   │   └── attention.cpython-312.pyc
│   │   │   ├── ema.py
│   │   │   └── attention.py
│   │   ├── __pycache__
│   │   │   ├── util.cpython-38.pyc
│   │   │   └── util.cpython-312.pyc
│   │   ├── lr_scheduler.py
│   │   └── util.py
│   ├── download_sdvae.sh
│   ├── setup.py
│   ├── environment.yaml
│   ├── README.md
│   └── configs
│       └── eqvae_config.yaml
├── media
│   └── teaser.png
├── evaluation
│   ├── __pycache__
│   │   ├── lpips.cpython-38.pyc
│   │   └── calculate_fid.cpython-38.pyc
│   └── lpips.py
├── models
│   ├── __pycache__
│   │   └── dit_models.cpython-312.pyc
│   ├── diffusion
│   │   ├── __pycache__
│   │   │   ├── respace.cpython-312.pyc
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   ├── diffusion_utils.cpython-312.pyc
│   │   │   └── gaussian_diffusion.cpython-312.pyc
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── respace.py
│   │   └── timestep_sampler.py
│   └── dit_models.py
├── environment.yml
├── README.md
├── eval.py
└── train_gen
    ├── extract_features.py
    ├── sample_ddp.py
    └── train.py
/train_eqvae/ldm/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zelaki/eqvae/HEAD/media/teaser.png
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/train_eqvae/download_sdvae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p pretrained_models
4 | wget -O pretrained_models/kl-f8.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
5 | unzip pretrained_models/kl-f8.zip -d pretrained_models
6 | rm pretrained_models/kl-f8.zip
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/train_eqvae/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: eqvae
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python >=3.8
7 | - pytorch >=1.13
8 | - torchvision
9 | - pytorch-cuda=12.1
10 | - pip:
11 | - timm
12 | - diffusers
13 | - accelerate
14 | - pytorch_lightning
15 | - omegaconf
16 | - einops
17 | - -e git+https://github.com/CompVis/taming-transformers.git@3ba01b241669f5ade541ce990f7650a3b8f65318#egg=taming_transformers
18 |
--------------------------------------------------------------------------------
/train_eqvae/environment.yaml:
--------------------------------------------------------------------------------
1 | name: eqvae_train
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=10.2
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch_lightning==1.4.2
19 | - omegaconf==2.0.0
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - more-itertools>=8.0.0
24 | - transformers==4.3.1
25 | - torch-fidelity==0.3.0
26 | - torchmetrics==0.6.0
27 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
28 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
29 | - -e .
30 |
--------------------------------------------------------------------------------
/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
--------------------------------------------------------------------------------
/train_eqvae/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### 1. Environment setup
3 | ```bash
4 | conda env create -f environment.yaml
5 | conda activate eqvae_train
6 | pip install packaging==21.3
7 | pip install 'torchmetrics<0.8'
8 | pip install transformers==4.10.2
9 | pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
10 | pip install Pillow==9.5.0
11 | ```
12 |
13 |
14 | ### 2. Download SD-VAE
15 | To download the SD-VAE from the official LDM repository, run:
16 |
17 |
18 | ```bash
19 | bash download_sdvae.sh
20 | ```
21 |
22 |
23 |
24 | ### 3. Dataset
25 |
26 | #### Dataset download
27 |
28 | Currently, we provide experiments for [OpenImages](https://storage.googleapis.com/openimages/web/index.html). After downloading the dataset, modify the paths `train_dir` and `val_dir` and the `dataset_name` field in the [config file](configs/eqvae_config.yaml).
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | ### 4. Training
39 |
40 | To run EQ-VAE regularization on 8 GPUs:
41 |
42 | ```bash
43 | python main.py \
44 | --base configs/eqvae_config.yaml \
45 | -t \
46 | --gpus 0,1,2,3,4,5,6,7 \
47 | --resume pretrained_models/model.ckpt \
48 | --logdir logs/eq-vae
49 | ```
50 |
51 |
52 | This script automatically creates a folder under `logs/eq-vae` to save logs and checkpoints.
53 | The arguments provided in `configs/eqvae_config.yaml` are the ones used in our paper. You can adjust the following options for your experiments (a rough sketch of how they interact follows the list):
54 |
55 | - `anisotropic`: If `True`, applies anisotropic scaling.
56 | - `uniform_sample_scale`: If `True`, samples scale factors uniformly from `[0.25, 1)`; if `False`, randomly chooses a scale from `{0.25, 0.5, 0.75}`.
57 | - `p_prior`: Probability of applying prior preservation instead of equivariance regularization.
58 | - `p_prior_s`: Probability of applying prior preservation at lower resolutions instead of equivariance regularization.
59 |
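For intuition, here is a minimal sketch of how these options could combine, based purely on the descriptions above. It is an illustrative assumption, not the repository's actual sampling code (the real logic is configured through `ldm.models.autoencoder.AutoencoderKL` in `configs/eqvae_config.yaml`), and the function and branch names are hypothetical:

```python
import random

def pick_scale_factors(anisotropic: bool, uniform_sample_scale: bool):
    """Illustrative only: choose the scale factor(s) for the latent transform."""
    def one_scale():
        if uniform_sample_scale:
            return random.uniform(0.25, 1.0)        # continuous scales in [0.25, 1)
        return random.choice([0.25, 0.5, 0.75])     # discrete scale set
    s_h = one_scale()
    s_w = one_scale() if anisotropic else s_h       # anisotropic: independent height/width scales
    return s_h, s_w

def pick_objective(p_prior: float, p_prior_s: float):
    """Illustrative only: choose between prior preservation and equivariance regularization."""
    r = random.random()
    if r < p_prior:
        return "prior_preservation"
    if r < p_prior + p_prior_s:
        return "prior_preservation_low_res"
    return "equivariance"
```

Under this reading, the defaults in `configs/eqvae_config.yaml` (`p_prior: 0.5`, `p_prior_s: 0.25`) would split updates roughly 50/25/25 between the three branches.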
--------------------------------------------------------------------------------
/train_eqvae/configs/eqvae_config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | anisotropic: False
8 | uniform_sample_scale: True
9 | p_prior: 0.5
10 | p_prior_s: 0.25
11 | lossconfig:
12 | target: ldm.modules.losses.LPIPSWithDiscriminator
13 | params:
14 | disc_start: 0
15 | kl_weight: 0.000001
16 | disc_weight: 0.5
17 |
18 | ddconfig:
19 | double_z: True
20 | z_channels: 4
21 | resolution: 256
22 | in_channels: 3
23 | out_ch: 3
24 | ch: 128
25 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
26 | num_res_blocks: 2
27 | attn_resolutions: [ ]
28 | dropout: 0.0
29 |
30 |
31 | data:
32 | target: main.DataModuleFromConfig
33 | params:
34 | batch_size: 10
35 | wrap: True
36 | train:
37 | target: ldm.data.dataset.DatasetTrain
38 | params:
39 | train_dir: "/data/openimages/target_dir/train/"
40 | dataset_name: openimages # Currently we support imagenet/openimages
41 | size: 256
42 | degradation: pil_nearest
43 | validation:
44 | target: ldm.data.dataset.DatasetVal
45 | params:
46 | val_dir: "/data/openimages/target_dir/validation"
47 | dataset_name: openimages # Currently we support imagenet/openimages
48 | size: 256
49 | degradation: pil_nearest
50 |
51 |
52 |
53 |
54 | # data:
55 | # target: main.DataModuleFromConfig
56 | # params:
57 | # batch_size: 10
58 | # wrap: True
59 | # train:
60 | # target: ldm.data.dataset.DatasetTrain
61 | # params:
62 | # train_dir: "/data/imagenet/train"
63 | # dataset_name: imagenet # Currently we support imagenet/openimages
64 | # size: 256
65 | # degradation: pil_nearest
66 | # validation:
67 | # target: ldm.data.dataset.DatasetVal
68 | # params:
69 | # val_dir: "/data/imagenet/val"
70 | # dataset_name: imagenet # Currently we support imagenet/openimages
71 | # size: 256
72 | # degradation: pil_nearest
73 |
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 1000
82 | max_images: 8
83 | increase_log_steps: True
84 |
85 | trainer:
86 | benchmark: True
87 | accumulate_grad_batches: 2
88 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
/models/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
12 | # EQ-VAE: Equivariance Regularized Latent Space for Improved Generative Image Modeling
13 |
14 | ![EQ-VAE teaser](media/teaser.png)
15 |
43 | TL;DR: We propose **EQ-VAE**, a straightforward regularization objective that promotes equivariance in the latent space of pretrained autoencoders under scaling and rotation. This leads to a more structured latent distribution, which accelerates generative model training and improves performance.
44 |
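In code terms, the objective can be pictured as asking the decoder to commute with simple spatial transforms of the latent. The snippet below is only a conceptual sketch with a generic `encoder`/`decoder` pair and a single downscaling transform; the actual training objective is implemented in `train_eqvae/ldm/models/autoencoder.py` and configured via `train_eqvae/configs/eqvae_config.yaml`.

```python
import torch.nn.functional as F

def equivariance_loss(encoder, decoder, x, scale=0.5):
    """Conceptual sketch: decoding a rescaled latent should match the rescaled image."""
    z = encoder(x)                                                   # latent, e.g. [B, 4, H/8, W/8]
    z_s = F.interpolate(z, scale_factor=scale, mode="bilinear", align_corners=False)
    x_s = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
    return F.l1_loss(decoder(z_s), x_s)                              # penalize D(t(E(x))) vs t(x)
```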
45 |
46 | ### 0. Quick Start with Hugging Face
47 | If you just want to use EQ-VAE to speed up 🚀 the training of your diffusion model, you can use our Hugging Face checkpoints 🤗.
48 | We provide two models [eq-vae](https://huggingface.co/zelaki/eq-vae)
49 | and [eq-vae-ema](https://huggingface.co/zelaki/eq-vae-ema).
50 |
51 | | Model | Basemodel | Dataset | Epochs | rFID | PSNR | LPIPS | SSIM |
52 | |---------|-------------|-----------|--------|--------|--------|--------|--------|
53 | | [eq-vae](https://huggingface.co/zelaki/eq-vae) | SD-VAE | OpenImages | 5 | 0.82 | 25.95 | 0.141 | 0.72|
54 | | [eq-vae-ema](https://huggingface.co/zelaki/eq-vae-ema) | SD-VAE | Imagenet | 44 | 0.55 | 26.15 | 0.133 | 0.72 |
55 |
56 |
57 | ```python
58 | from diffusers import AutoencoderKL
59 | eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae")
60 | ```
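
A minimal encode/decode sketch using the standard `diffusers` `AutoencoderKL` API is shown below; the random input tensor is only a placeholder, and the scaling factor to use for diffusion training should be taken from `vae.config.scaling_factor` / the model card rather than assumed.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("zelaki/eq-vae").eval()

# Placeholder batch of images scaled to [-1, 1]; use your own preprocessed images here.
x = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample()   # [1, 4, 32, 32] for a 256x256 input
    recon = vae.decode(latents).sample             # reconstruction, [1, 3, 256, 256]

# For diffusion training, latents are usually multiplied by vae.config.scaling_factor;
# check the model card for the value appropriate to this checkpoint.
```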
61 |
62 | If you are looking for the weights in the original LDM format, you can find them here: [eq-vae-ldm](https://huggingface.co/zelaki/eq-vae-ldm), [eq-vae-ema-ldm](https://huggingface.co/zelaki/eq-vae-ema-ldm)
63 |
64 |
65 | ### 1. Environment setup
66 |
67 | ```bash
68 | conda env create -f environment.yml
69 | conda activate eqvae
70 | ```
71 |
72 |
73 | ### 2. Train EQ-VAE
74 | We provide a training script to finetune [SD-VAE](https://ommer-lab.com/files/latent-diffusion/kl-f8.zip) with EQ-VAE regularization. For a detailed guide, see [train_eqvae](./train_eqvae/).
75 |
76 |
77 | ### 3. Evaluate Reconstruction
78 | To evaluate the reconstruction quality of EQ-VAE, compute rFID, LPIPS, SSIM, and PSNR on a validation set (we use the ImageNet validation set in our paper) with the following:
79 | ```bash
80 | torchrun --nproc_per_node=8 eval.py \
81 | --data_path /path/to/imagenet/validation \
82 | --output_path results \
83 | --ckpt_path /path/to/your/ckpt
84 | ```
85 |
86 | ### 4. Train DiT with EQ-VAE
87 | To train a DiT model with EQ-VAE on ImageNet:
88 | - First extract the latent representations:
89 | ```bash
90 | torchrun --nnodes=1 --nproc_per_node=8 train_gen/extract_features.py \
91 | --data-path /path/to/imagenet/train \
92 | --features-path /path/to/latents \
93 | --vae-ckpt /path/to/eqvae.ckpt \
94 | --vae-config configs/eqvae_config.yaml
95 | ```
96 | - Then train DiT on the precomputed latents:
97 | ```bash
98 | accelerate launch --mixed_precision fp16 train_gen/train.py \
99 | --model DiT-XL/2 \
100 | --feature-path /path/to/latents \
101 | --results-dir results
102 | ```
103 | - Evaluate generation as follows:
104 | ```bash
105 | torchrun --nnodes=1 --nproc_per_node=8 train_gen/sample_ddp.py \
106 | --model DiT-XL/2 \
107 | --num-fid-samples 50000 \
108 | --ckpt /path/to/dit.ckpt \
109 | --sample-dir samples \
110 | --vae-ckpt /path/to/eqvae.ckpt \
111 | --vae-config configs/eqvae_config.yaml \
112 | --ddpm True \
113 | --cfg-scale 1.0
114 | ```
115 |
116 | This script generates a folder of 50K samples as well as a .npz file that can be used directly with [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute gFID.
117 |
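As an optional sanity check before running ADM's evaluator, you can inspect the generated batch. This assumes the `.npz` stores samples under the conventional `arr_0` key used by the original DiT sampling script; the path below is a placeholder.

```python
import numpy as np

batch = np.load("samples/your_samples.npz")   # placeholder: the .npz written by sample_ddp.py
samples = batch["arr_0"]
print(samples.shape, samples.dtype)           # expect roughly (50000, 256, 256, 3), uint8
```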
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 | ### Acknowledgement
126 |
127 | This code is mainly built upon [LDM](https://github.com/CompVis/latent-diffusion) and [fastDiT](https://github.com/chuanyangjin/fast-DiT).
128 |
129 |
130 | ### Citation
131 |
132 | ```bibtex
133 | @inproceedings{
134 | kouzelis2025eqvae,
135 | title={{EQ}-{VAE}: Equivariance Regularized Latent Space for Improved Generative Image Modeling},
136 | author={Theodoros Kouzelis and Ioannis Kakogeorgiou and Spyros Gidaris and Nikos Komodakis},
137 | booktitle={Forty-second International Conference on Machine Learning},
138 | year={2025},
139 | url={https://openreview.net/forum?id=UWhW5YYLo6}
140 | }
141 | ```
143 |
--------------------------------------------------------------------------------
/models/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 |                 f"cannot create exactly {desired_count} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
/models/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
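A standalone sketch (not part of the repository) of the importance-sampling identity that ScheduleSampler.sample relies on: timesteps drawn from any positive weighting, with each loss rescaled by 1/(T * p_t), still estimate the uniform-timestep objective.

import numpy as np

rng = np.random.default_rng(0)
num_timesteps = 1000
per_step_loss = rng.uniform(0.5, 1.5, size=num_timesteps)  # stand-in for E[loss | t]

w = np.sqrt(per_step_loss)                     # any positive, unnormalized weights
p = w / w.sum()                                # sampling distribution over timesteps
idx = rng.choice(num_timesteps, size=200_000, p=p)
importance = 1.0 / (num_timesteps * p[idx])    # same reweighting as sample()

estimate = np.mean(importance * per_step_loss[idx])  # importance-weighted objective
target = per_step_loss.mean()                        # uniform-timestep objective
print(abs(estimate - target))                        # close to zero
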
/train_eqvae/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 |     if "target" not in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 |     # run the prefetch function on this worker's chunk of data and push the
98 |     # result, followed by a "Done" marker, onto the shared queue
99 |
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 |                 'WARNING: "data" argument passed to parallel_data_prefetch is a dict; using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
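A minimal usage sketch of instantiate_from_config above, assuming train_eqvae is on PYTHONPATH; the target below is a stdlib class chosen purely as a hypothetical example, whereas the repo's configs point at its own classes.

from ldm.util import instantiate_from_config

cfg = {"target": "collections.Counter", "params": {}}  # hypothetical target/params
obj = instantiate_from_config(cfg)
print(type(obj))  # <class 'collections.Counter'>
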
/evaluation/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import hashlib
3 | import requests
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 | import os
8 | from tqdm import tqdm
9 |
10 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
11 |
12 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
13 |
14 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
15 |
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(chunk_size)
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class LPIPS(nn.Module):
48 | # Learned perceptual metric
49 | def __init__(self, use_dropout=True):
50 | super().__init__()
51 | self.scaling_layer = ScalingLayer()
52 | self.chns = [64, 128, 256, 512, 512] # vgg16 features
53 | self.net = vgg16(pretrained=True, requires_grad=False)
54 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
55 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
56 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
57 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
58 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
59 | self.load_from_pretrained()
60 | for param in self.parameters():
61 | param.requires_grad = False
62 |
63 | def load_from_pretrained(self, name="vgg_lpips"):
64 | ckpt = get_ckpt_path(name, "movqgan/modules/losses/lpips")
65 | self.load_state_dict(
66 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
67 | )
68 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
69 |
70 | @classmethod
71 | def from_pretrained(cls, name="vgg_lpips"):
72 | if name != "vgg_lpips":
73 | raise NotImplementedError
74 | model = cls()
75 |         ckpt = get_ckpt_path(name, "movqgan/modules/losses/lpips")
76 | model.load_state_dict(
77 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
78 | )
79 | return model
80 |
81 | def forward(self, input, target):
82 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
83 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
84 | feats0, feats1, diffs = {}, {}, {}
85 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
86 | for kk in range(len(self.chns)):
87 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
88 | outs1[kk]
89 | )
90 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
91 |
92 | res = [
93 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
94 | for kk in range(len(self.chns))
95 | ]
96 | val = res[0]
97 | for l in range(1, len(self.chns)):
98 | val += res[l]
99 | return val
100 |
101 |
102 | class ScalingLayer(nn.Module):
103 | def __init__(self):
104 | super(ScalingLayer, self).__init__()
105 | self.register_buffer(
106 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
107 | )
108 | self.register_buffer(
109 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
110 | )
111 |
112 | def forward(self, inp):
113 | # convert imagenet normalized data to [-1, 1]
114 | return (inp - self.shift) / self.scale
115 |
116 |
117 | class NetLinLayer(nn.Module):
118 | """A single linear layer which does a 1x1 conv"""
119 |
120 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
121 | super(NetLinLayer, self).__init__()
122 | layers = (
123 | [
124 | nn.Dropout(),
125 | ]
126 | if (use_dropout)
127 | else []
128 | )
129 | layers += [
130 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
131 | ]
132 | self.model = nn.Sequential(*layers)
133 |
134 |
135 | class vgg16(torch.nn.Module):
136 | def __init__(self, requires_grad=False, pretrained=True):
137 | super(vgg16, self).__init__()
138 | vgg_pretrained_features = models.vgg16(pretrained=pretrained)
139 | vgg_pretrained_features = vgg_pretrained_features.features
140 | self.slice1 = torch.nn.Sequential()
141 | self.slice2 = torch.nn.Sequential()
142 | self.slice3 = torch.nn.Sequential()
143 | self.slice4 = torch.nn.Sequential()
144 | self.slice5 = torch.nn.Sequential()
145 | self.N_slices = 5
146 | for x in range(4):
147 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
148 | for x in range(4, 9):
149 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
150 | for x in range(9, 16):
151 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
152 | for x in range(16, 23):
153 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
154 | for x in range(23, 30):
155 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
156 | if not requires_grad:
157 | for param in self.parameters():
158 | param.requires_grad = False
159 |
160 | def forward(self, X):
161 | h = self.slice1(X)
162 | h_relu1_2 = h
163 | h = self.slice2(h)
164 | h_relu2_2 = h
165 | h = self.slice3(h)
166 | h_relu3_3 = h
167 | h = self.slice4(h)
168 | h_relu4_3 = h
169 | h = self.slice5(h)
170 | h_relu5_3 = h
171 | vgg_outputs = namedtuple(
172 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
173 | )
174 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
175 | return out
176 |
177 |
178 | def normalize_tensor(x, eps=1e-10):
179 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
180 | return x / (norm_factor + eps)
181 |
182 |
183 | def spatial_average(x, keepdim=True):
184 | return x.mean([2, 3], keepdim=keepdim)
185 |
--------------------------------------------------------------------------------
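A minimal usage sketch of the LPIPS module above, assuming the repository root is on PYTHONPATH and that the VGG weights can be downloaded on first use; inputs are expected as RGB tensors in [-1, 1].

import torch
from evaluation.lpips import LPIPS

lpips = LPIPS().eval()
x = torch.rand(2, 3, 256, 256) * 2 - 1                 # fake reference batch in [-1, 1]
y = (x + 0.05 * torch.randn_like(x)).clamp(-1, 1)      # slightly perturbed copy
with torch.no_grad():
    d = lpips(x, y)                                    # per-image distances, shape (2, 1, 1, 1)
print(d.flatten())
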
/train_eqvae/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | import kornia
7 |
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 |     """Uses a pretrained BERT tokenizer from Hugging Face. Vocab size: 30522."""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 | """Uses the BERT tokenizr model and add some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 |
138 | class FrozenCLIPTextEmbedder(nn.Module):
139 | """
140 | Uses the CLIP transformer encoder for text.
141 | """
142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
143 | super().__init__()
144 | self.model, _ = clip.load(version, jit=False, device="cpu")
145 | self.device = device
146 | self.max_length = max_length
147 | self.n_repeat = n_repeat
148 | self.normalize = normalize
149 |
150 | def freeze(self):
151 | self.model = self.model.eval()
152 | for param in self.parameters():
153 | param.requires_grad = False
154 |
155 | def forward(self, text):
156 | tokens = clip.tokenize(text).to(self.device)
157 | z = self.model.encode_text(tokens)
158 | if self.normalize:
159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
160 | return z
161 |
162 | def encode(self, text):
163 | z = self(text)
164 | if z.ndim==2:
165 | z = z[:, None, :]
166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
167 | return z
168 |
169 |
170 | class FrozenClipImageEmbedder(nn.Module):
171 | """
172 | Uses the CLIP image encoder.
173 | """
174 | def __init__(
175 | self,
176 | model,
177 | jit=False,
178 | device='cuda' if torch.cuda.is_available() else 'cpu',
179 | antialias=False,
180 | ):
181 | super().__init__()
182 | self.model, _ = clip.load(name=model, device=device, jit=jit)
183 |
184 | self.antialias = antialias
185 |
186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
188 |
189 | def preprocess(self, x):
190 | # normalize to [0,1]
191 | x = kornia.geometry.resize(x, (224, 224),
192 | interpolation='bicubic',align_corners=True,
193 | antialias=self.antialias)
194 | x = (x + 1.) / 2.
195 | # renormalize according to clip
196 | x = kornia.enhance.normalize(x, self.mean, self.std)
197 | return x
198 |
199 | def forward(self, x):
200 | # x is assumed to be in range [-1,1]
201 | return self.model.encode_image(self.preprocess(x))
202 |
203 |
--------------------------------------------------------------------------------
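A plain-torch sketch (deliberately not importing the module above, whose top-level imports pull in clip, kornia and x_transformer) of what SpatialRescaler computes: n_stages interpolations by multiplier, followed by an optional 1x1 conv that remaps channels.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)
multiplier, n_stages = 0.5, 2
for _ in range(n_stages):
    x = F.interpolate(x, scale_factor=multiplier, mode="bilinear")
channel_mapper = torch.nn.Conv2d(3, 4, kernel_size=1, bias=False)  # out_channels=4 here
x = channel_mapper(x)
print(x.shape)  # torch.Size([1, 4, 64, 64])
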
/train_eqvae/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
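A standalone sketch (toy tensors, not repo code) of the adaptive discriminator weight computed in calculate_adaptive_weight above: the GAN gradient at the decoder's last layer is rescaled so its norm matches the reconstruction gradient, which keeps the adversarial term from dominating.

import torch

last_layer = torch.nn.Parameter(torch.randn(8, 8))   # stands in for the decoder's last weight
nll_loss = (last_layer - 1.0).pow(2).sum()           # stand-in reconstruction/NLL loss
g_loss = -(last_layer * 2.0).sum()                   # stand-in generator (GAN) loss

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()  # same clamp as the loss class
print(d_weight)
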
/train_eqvae/ldm/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os, yaml, pickle, shutil, tarfile, glob
2 | import cv2
3 | import albumentations
4 | import PIL
5 | import numpy as np
6 | import torchvision.transforms.functional as TF
7 | from omegaconf import OmegaConf
8 | from functools import partial
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torch.utils.data import Dataset, Subset
12 |
13 | import taming.data.utils as tdu
14 |
15 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
16 |
17 |
18 |
19 | class DatasetTrain(Dataset):
20 | def __init__(self,
21 | train_dir = "/data/openimages/target_dir/train/",
22 | size=None, dataset_name="openimages",
23 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
24 | random_crop=True):
25 |
26 |
27 |
28 | if dataset_name == "openimages":
29 |             rel_paths = os.listdir(train_dir)
30 | else:
31 | raise NotImplementedError
32 |
33 |
34 | self.labels = {
35 | "relative_file_path_": rel_paths,
36 | "file_path_": [os.path.join(train_dir, p) for p in rel_paths],
37 | }
38 | self.length = len(rel_paths)
39 |
40 |
41 | assert size
42 | assert (size / downscale_f).is_integer()
43 | self.size = size
44 | self.LR_size = int(size / downscale_f)
45 | self.min_crop_f = min_crop_f
46 | self.max_crop_f = max_crop_f
47 | assert(max_crop_f <= 1.)
48 | self.center_crop = not random_crop
49 |
50 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
51 |
52 |         self.pil_interpolation = False # reset later if the interpolation op comes from Pillow
53 |
54 | if degradation == "bsrgan":
55 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
56 |
57 | elif degradation == "bsrgan_light":
58 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
59 |
60 | else:
61 | interpolation_fn = {
62 | "cv_nearest": cv2.INTER_NEAREST,
63 | "cv_bilinear": cv2.INTER_LINEAR,
64 | "cv_bicubic": cv2.INTER_CUBIC,
65 | "cv_area": cv2.INTER_AREA,
66 | "cv_lanczos": cv2.INTER_LANCZOS4,
67 | "pil_nearest": PIL.Image.NEAREST,
68 | "pil_bilinear": PIL.Image.BILINEAR,
69 | "pil_bicubic": PIL.Image.BICUBIC,
70 | "pil_box": PIL.Image.BOX,
71 | "pil_hamming": PIL.Image.HAMMING,
72 | "pil_lanczos": PIL.Image.LANCZOS,
73 | }[degradation]
74 |
75 | self.pil_interpolation = degradation.startswith("pil_")
76 |
77 | if self.pil_interpolation:
78 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
79 |
80 | else:
81 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
82 | interpolation=interpolation_fn)
83 |
84 | def __len__(self):
85 | return self.length
86 |
87 |
88 | def __getitem__(self, i):
89 | example = dict((k, self.labels[k][i]) for k in self.labels)
90 | image = Image.open(example["file_path_"])
91 |
92 | if not image.mode == "RGB":
93 | image = image.convert("RGB")
94 |
95 | image = np.array(image).astype(np.uint8)
96 |
97 | min_side_len = min(image.shape[:2])
98 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
99 | crop_side_len = int(crop_side_len)
100 |
101 | if self.center_crop:
102 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
103 |
104 | else:
105 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
106 |
107 | image = self.cropper(image=image)["image"]
108 | image = self.image_rescaler(image=image)["image"]
109 |
110 | if self.pil_interpolation:
111 | image_pil = PIL.Image.fromarray(image)
112 | LR_image = self.degradation_process(image_pil)
113 | LR_image = np.array(LR_image).astype(np.uint8)
114 |
115 | else:
116 | LR_image = self.degradation_process(image=image)["image"]
117 |
118 | example["image"] = (image/127.5 - 1.0).astype(np.float32)
119 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
120 |
121 | return example
122 |
123 |
124 |
125 |
126 | class DatasetVal(Dataset):
127 | def __init__(self,
128 | val_dir="/data/openimages/target_dir/validation/",
129 | dataset_name='openimages',
130 | size=None,
131 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
132 | random_crop=True):
133 |
134 | if dataset_name == "openimages":
135 |             rel_paths = os.listdir(val_dir)
136 | else:
137 | raise NotImplementedError
138 |
139 | self.labels = {
140 | "relative_file_path_": rel_paths,
141 | "file_path_": [os.path.join(val_dir, p) for p in rel_paths],
142 | }
143 | self.length = len(rel_paths)
144 |
145 |
146 |
147 | assert size
148 | assert (size / downscale_f).is_integer()
149 | self.size = size
150 | self.LR_size = int(size / downscale_f)
151 | self.min_crop_f = min_crop_f
152 | self.max_crop_f = max_crop_f
153 | assert(max_crop_f <= 1.)
154 | self.center_crop = not random_crop
155 |
156 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
157 |
158 |         self.pil_interpolation = False # reset later if the interpolation op comes from Pillow
159 |
160 | if degradation == "bsrgan":
161 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
162 |
163 | elif degradation == "bsrgan_light":
164 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
165 |
166 | else:
167 | interpolation_fn = {
168 | "cv_nearest": cv2.INTER_NEAREST,
169 | "cv_bilinear": cv2.INTER_LINEAR,
170 | "cv_bicubic": cv2.INTER_CUBIC,
171 | "cv_area": cv2.INTER_AREA,
172 | "cv_lanczos": cv2.INTER_LANCZOS4,
173 | "pil_nearest": PIL.Image.NEAREST,
174 | "pil_bilinear": PIL.Image.BILINEAR,
175 | "pil_bicubic": PIL.Image.BICUBIC,
176 | "pil_box": PIL.Image.BOX,
177 | "pil_hamming": PIL.Image.HAMMING,
178 | "pil_lanczos": PIL.Image.LANCZOS,
179 | }[degradation]
180 |
181 | self.pil_interpolation = degradation.startswith("pil_")
182 |
183 | if self.pil_interpolation:
184 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
185 |
186 | else:
187 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
188 | interpolation=interpolation_fn)
189 |
190 | def __len__(self):
191 | return self.length
192 |
193 | def __getitem__(self, i):
194 | example = dict((k, self.labels[k][i]) for k in self.labels)
195 | image = Image.open(example["file_path_"])
196 |
197 | if not image.mode == "RGB":
198 | image = image.convert("RGB")
199 |
200 | image = np.array(image).astype(np.uint8)
201 |
202 | min_side_len = min(image.shape[:2])
203 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
204 | crop_side_len = int(crop_side_len)
205 |
206 | if self.center_crop:
207 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
208 |
209 | else:
210 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
211 |
212 | image = self.cropper(image=image)["image"]
213 | image = self.image_rescaler(image=image)["image"]
214 |
215 | if self.pil_interpolation:
216 | image_pil = PIL.Image.fromarray(image)
217 | LR_image = self.degradation_process(image_pil)
218 | LR_image = np.array(LR_image).astype(np.uint8)
219 |
220 | else:
221 | LR_image = self.degradation_process(image=image)["image"]
222 |
223 | example["image"] = (image/127.5 - 1.0).astype(np.float32)
224 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
225 |
226 | return example
227 |
228 |
229 |
--------------------------------------------------------------------------------
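A minimal usage sketch of DatasetTrain above; the directory is the constructor's own default and is assumed to contain training images, and the dependencies (albumentations, cv2, taming) are assumed to be installed. Batched items come back channels-last, since the dataset returns HWC float32 arrays in [-1, 1].

from torch.utils.data import DataLoader
from ldm.data.dataset import DatasetTrain

ds = DatasetTrain(train_dir="/data/openimages/target_dir/train/",  # default path, adjust to your data
                  size=256, degradation="pil_bicubic", downscale_f=4)
loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch["image"].shape, batch["LR_image"].shape)  # (4, 256, 256, 3) and (4, 64, 64, 3)
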
/eval.py:
--------------------------------------------------------------------------------
1 | # Code modified from https://github.com/hustvl/LightningDiT/blob/main/evaluate_tokenizer.py
2 |
3 | import os
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | import torch.distributed as dist
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from omegaconf import OmegaConf
11 | from torch.utils.data import DataLoader, DistributedSampler
12 | from evaluation.calculate_fid import calculate_fid_given_paths
13 | from concurrent.futures import ThreadPoolExecutor, as_completed
14 | from torchmetrics import StructuralSimilarityIndexMeasure
15 | from evaluation.lpips import LPIPS
16 | from torchvision.datasets import ImageFolder
17 | from torchvision import transforms
18 | import csv
19 | import sys
20 | from ldm.models.autoencoder import AutoencoderKL
21 | from ldm.util import instantiate_from_config
22 | import yaml
23 | from omegaconf import OmegaConf
24 |
25 | def load_config(config_path, display=False):
26 | config = OmegaConf.load(config_path)
27 | if display:
28 | print(yaml.dump(OmegaConf.to_container(config)))
29 | return config
30 |
31 |
32 | def load_kl(config, type="sd", ckpt_path=None):
33 | model = AutoencoderKL(**config.model.params)
34 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
35 | missing, unexpected = model.load_state_dict(sd, strict=False)
36 | return model.eval()
37 |
38 |
39 | def print_with_prefix(content, prefix='Tokenizer Evaluation', rank=0):
40 | if rank == 0:
41 | print(f"\033[34m[{prefix}]\033[0m {content}")
42 |
43 | def save_image(image, filename):
44 | Image.fromarray(image).save(filename)
45 |
46 |
47 |
48 | def evaluate_tokenizer(config_path, model_name, data_path, output_path, ckpt_path):
49 | # Initialize distributed training
50 | dist.init_process_group(backend='nccl')
51 | local_rank = torch.distributed.get_rank()
52 | torch.cuda.set_device(local_rank)
53 | device = torch.device(f'cuda:{local_rank}')
54 |
55 |
56 | vae_config = load_config(config_path, display=False)
57 | model = load_kl(vae_config, ckpt_path=ckpt_path).to(device)
58 |
59 | # Image preprocessing
60 | transform = transforms.Compose([
61 | transforms.ToTensor(),
62 | transforms.Resize(256),
63 | transforms.CenterCrop(256),
64 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
65 | ])
66 |
67 | # Create dataset and dataloader
68 | dataset = ImageFolder(root=data_path, transform=transform)
69 | distributed_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=local_rank)
70 | val_dataloader = DataLoader(
71 | dataset,
72 | batch_size=1,
73 | shuffle=False,
74 | num_workers=4,
75 | sampler=distributed_sampler
76 | )
77 |
78 | folder_name = model_name
79 |
80 | save_dir = os.path.join(output_path, folder_name, 'decoded_images')
81 | ref_path = os.path.join(output_path, folder_name, 'ref_images')
82 | metric_path = os.path.join(output_path, folder_name, 'metrics.csv')
83 |
84 | os.makedirs(save_dir, exist_ok=True)
85 | os.makedirs(ref_path, exist_ok=True)
86 |
87 | if local_rank == 0:
88 | print_with_prefix(f"Output dir: {save_dir}")
89 | print_with_prefix(f"Reference dir: {ref_path}")
90 |
91 | # Save reference images if needed
92 | ref_png_files = [f for f in os.listdir(ref_path) if f.endswith('.png')]
93 | if len(ref_png_files) < 50000:
94 | total_samples = 0
95 | for batch in val_dataloader:
96 | images = batch[0].to(device)
97 | for j in range(images.size(0)):
98 | img = torch.clamp(127.5 * images[j] + 128.0, 0, 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
99 | Image.fromarray(img).save(os.path.join(ref_path, f"ref_image_rank_{local_rank}_{total_samples}.png"))
100 | total_samples += 1
101 | if total_samples % 100 == 0 and local_rank == 0:
102 | print_with_prefix(f"Rank {local_rank}, Saved {total_samples} reference images")
103 | dist.barrier()
104 |
105 | # Initialize metrics
106 | lpips_values = []
107 | ssim_values = []
108 | lpips = LPIPS().to(device).eval()
109 | ssim_metric = StructuralSimilarityIndexMeasure().to(device)
110 |
111 | # Generate reconstructions and compute metrics
112 | if local_rank == 0:
113 | print_with_prefix("Generating reconstructions...")
114 | all_indices = 0
115 |
116 | for batch in val_dataloader:
117 | images = batch[0].to(device)
118 | with torch.no_grad():
119 | latents = model.encode(images).sample().to(torch.float32)
120 | decoded_images_tensor = model.decode(latents)
121 |
122 | decoded_images = torch.clamp(127.5 * decoded_images_tensor + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
123 |
124 | # Compute metrics
125 | lpips_values.append(lpips(decoded_images_tensor, images).mean())
126 | ssim_values.append(ssim_metric(decoded_images_tensor, images))
127 |
128 | # Save reconstructions
129 | for i, img in enumerate(decoded_images):
130 | save_image(img, os.path.join(save_dir, f"decoded_image_rank_{local_rank}_{all_indices + i}.png"))
131 | if (all_indices + i) % 100 == 0 and local_rank == 0:
132 | print_with_prefix(f"Rank {local_rank}, Processed {all_indices + i} images")
133 | all_indices += len(decoded_images)
134 | dist.barrier()
135 |
136 | # Aggregate metrics across GPUs
137 | lpips_values = torch.tensor(lpips_values).to(device)
138 | ssim_values = torch.tensor(ssim_values).to(device)
139 | dist.all_reduce(lpips_values, op=dist.ReduceOp.AVG)
140 | dist.all_reduce(ssim_values, op=dist.ReduceOp.AVG)
141 |
142 | avg_lpips = lpips_values.mean().item()
143 | avg_ssim = ssim_values.mean().item()
144 |
145 | if local_rank == 0:
146 | # Calculate FID
147 | print_with_prefix("Computing rFID...")
148 | fid = calculate_fid_given_paths([ref_path, save_dir], batch_size=50, dims=2048, device=device, num_workers=16)
149 |
150 | # Calculate PSNR
151 | print_with_prefix("Computing PSNR...")
152 | psnr_values = calculate_psnr_between_folders(ref_path, save_dir)
153 | avg_psnr = sum(psnr_values) / len(psnr_values)
154 | with open(metric_path, mode="w", newline="") as file:
155 | writer = csv.writer(file)
156 | writer.writerow(["FID", f"{fid:.3f}"])
157 | writer.writerow(["PSNR", f"{avg_psnr:.3f}"])
158 | writer.writerow(["LPIPS", f"{avg_lpips:.3f}"])
159 | writer.writerow(["SSIM", f"{avg_ssim:.3f}"])
160 |
161 | dist.destroy_process_group()
162 |
163 |
164 | def decode_to_images(model, z):
165 | with torch.no_grad():
166 | images = model.decode(z)
167 | images = torch.clamp(127.5 * images + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
168 | return images
169 |
170 | def calculate_psnr(original, processed):
171 | mse = torch.mean((original - processed) ** 2)
172 | return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
173 |
174 | def load_image(image_path):
175 | image = Image.open(image_path).convert('RGB')
176 | return torch.tensor(np.array(image).transpose(2, 0, 1), dtype=torch.float32)
177 |
178 | def calculate_psnr_for_pair(original_path, processed_path):
179 | return calculate_psnr(load_image(original_path), load_image(processed_path))
180 |
181 | def calculate_psnr_between_folders(original_folder, processed_folder):
182 | original_files = sorted(os.listdir(original_folder))
183 | processed_files = sorted(os.listdir(processed_folder))
184 |
185 | if len(original_files) != len(processed_files):
186 | print("Warning: Mismatched number of images in folders")
187 | return []
188 |
189 | with ThreadPoolExecutor() as executor:
190 | futures = [
191 | executor.submit(calculate_psnr_for_pair,
192 | os.path.join(original_folder, orig),
193 | os.path.join(processed_folder, proc))
194 | for orig, proc in zip(original_files, processed_files)
195 | ]
196 | return [future.result() for future in as_completed(futures)]
197 |
198 | if __name__ == "__main__":
199 | import argparse
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument('--config_path', type=str, default='configs/eqvae_config.yaml')
202 | parser.add_argument('--model_name', type=str, default='eq_vae')
203 | parser.add_argument('--ckpt_path', type=str)
204 | parser.add_argument('--data_path', type=str, default='/path/to/your/imagenet/ILSVRC2012_validation/data')
205 | parser.add_argument('--output_path', type=str, default='/path/to/your/output')
206 | parser.add_argument('--seed', type=int, default=42)
207 | args = parser.parse_args()
208 | evaluate_tokenizer(config_path=args.config_path, model_name=args.model_name, data_path=args.data_path, output_path=args.output_path, ckpt_path=args.ckpt_path)
209 |
--------------------------------------------------------------------------------
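A single-process sketch of the reconstruction step that eval.py performs per batch, without torchrun/DDP; the config path is the script's default, while the checkpoint path is a hypothetical placeholder.

import torch
from omegaconf import OmegaConf
from ldm.models.autoencoder import AutoencoderKL

config = OmegaConf.load("configs/eqvae_config.yaml")
model = AutoencoderKL(**config.model.params)
sd = torch.load("/path/to/eqvae.ckpt", map_location="cpu")["state_dict"]  # hypothetical checkpoint
model.load_state_dict(sd, strict=False)
model.eval()

x = torch.rand(1, 3, 256, 256) * 2 - 1      # stand-in for a normalized image batch
with torch.no_grad():
    z = model.encode(x).sample()            # posterior sample in latent space
    x_rec = model.decode(z)
print(z.shape, x_rec.shape)
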
/train_gen/extract_features.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/chuanyangjin/fast-DiT
3 | """
4 | import sys
5 | import os
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '../train_eqvae'))
7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
8 |
9 | import torch
10 | torch.backends.cuda.matmul.allow_tf32 = True
11 | torch.backends.cudnn.allow_tf32 = True
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.utils.data import DataLoader
15 | from torch.utils.data.distributed import DistributedSampler
16 | from torchvision.datasets import ImageFolder
17 | from torchvision import transforms
18 | import numpy as np
19 | from collections import OrderedDict
20 | from PIL import Image
21 | from copy import deepcopy
22 | from glob import glob
23 | from time import time
24 | import argparse
25 | import logging
26 | from models.dit_models import DiT_models
27 | from models.diffusion import create_diffusion
28 | from diffusers.models import AutoencoderKL
29 | from tqdm import tqdm
30 |
31 | from train_eqvae.ldm.models.autoencoder import AutoencoderKL as LDMAutoencoderKL
32 | from ldm.util import instantiate_from_config
33 | import yaml
34 | from omegaconf import OmegaConf
35 |
36 | def load_config(config_path, display=False):
37 | config = OmegaConf.load(config_path)
38 | if display:
39 | print(yaml.dump(OmegaConf.to_container(config)))
40 | return config
41 |
42 |
43 | def load_kl(config, ckpt_path=None):
44 |
45 | model = LDMAutoencoderKL(**config.model.params)
46 |
47 | if ckpt_path is not None:
48 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
49 | missing, unexpected = model.load_state_dict(sd, strict=False)
50 | return model.eval()
51 |
52 | def preprocess_kl(x):
53 | x = 2.*x - 1.
54 | return x
55 |
56 |
57 |
58 | @torch.no_grad()
59 | def update_ema(ema_model, model, decay=0.9999):
60 | """
61 | Step the EMA model towards the current model.
62 | """
63 | ema_params = OrderedDict(ema_model.named_parameters())
64 | model_params = OrderedDict(model.named_parameters())
65 |
66 | for name, param in model_params.items():
67 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
68 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
69 |
70 |
71 | def requires_grad(model, flag=True):
72 | """
73 | Set requires_grad flag for all parameters in a model.
74 | """
75 | for p in model.parameters():
76 | p.requires_grad = flag
77 |
78 |
79 | def cleanup():
80 | """
81 | End DDP training.
82 | """
83 | dist.destroy_process_group()
84 |
85 |
86 | def create_logger(logging_dir):
87 | """
88 | Create a logger that writes to a log file and stdout.
89 | """
90 | if dist.get_rank() == 0: # real logger
91 | logging.basicConfig(
92 | level=logging.INFO,
93 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
94 | datefmt='%Y-%m-%d %H:%M:%S',
95 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
96 | )
97 | logger = logging.getLogger(__name__)
98 | else: # dummy logger (does nothing)
99 | logger = logging.getLogger(__name__)
100 | logger.addHandler(logging.NullHandler())
101 | return logger
102 |
103 |
104 | def center_crop_arr(pil_image, image_size):
105 | """
106 | Center cropping implementation from ADM.
107 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
108 | """
109 | while min(*pil_image.size) >= 2 * image_size:
110 | pil_image = pil_image.resize(
111 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
112 | )
113 |
114 | scale = image_size / min(*pil_image.size)
115 | pil_image = pil_image.resize(
116 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
117 | )
118 |
119 | arr = np.array(pil_image)
120 | crop_y = (arr.shape[0] - image_size) // 2
121 | crop_x = (arr.shape[1] - image_size) // 2
122 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
123 |
124 |
125 | #################################################################################
126 | # Training Loop #
127 | #################################################################################
128 |
129 | def main(args):
130 | """
131 | Trains a new DiT model.
132 | """
133 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
134 |
135 | # Setup DDP:
136 | dist.init_process_group("nccl")
137 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
138 | rank = dist.get_rank()
139 | device = rank % torch.cuda.device_count()
140 | seed = args.global_seed * dist.get_world_size() + rank
141 | torch.manual_seed(seed)
142 | torch.cuda.set_device(device)
143 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
144 |
145 | # Setup a feature folder:
146 | if rank == 0:
147 | os.makedirs(args.features_path, exist_ok=True)
148 | os.makedirs(os.path.join(args.features_path, 'imagenet256_features'), exist_ok=True)
149 | os.makedirs(os.path.join(args.features_path, 'imagenet256_labels'), exist_ok=True)
150 |
151 | # Create model:
152 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
153 | latent_size = args.image_size // 8
154 |
155 | if args.vae_ckpt is not None:
156 | vae_config = load_config(args.vae_config, display=False)
157 | vae = load_kl(vae_config, ckpt_path=args.vae_ckpt).to(device)
158 | else:
159 | from diffusers.models import AutoencoderKL
160 | vae = AutoencoderKL.from_pretrained(args.hf_model_name).to(device)
161 |
162 |
163 | # Setup data:
164 | transform = transforms.Compose([
165 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
166 | transforms.RandomHorizontalFlip(),
167 | transforms.ToTensor(),
168 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
169 | ])
170 | dataset = ImageFolder(args.data_path, transform=transform)
171 | sampler = DistributedSampler(
172 | dataset,
173 | num_replicas=dist.get_world_size(),
174 | rank=rank,
175 | shuffle=False,
176 | seed=args.global_seed
177 | )
178 | loader = DataLoader(
179 | dataset,
180 | batch_size = 1,
181 | shuffle=False,
182 | sampler=sampler,
183 | num_workers=args.num_workers,
184 | pin_memory=True,
185 | drop_last=True
186 | )
187 |
188 | train_steps = 0
189 | for x, y in tqdm(loader):
190 | x = x.to(device)
191 | y = y.to(device)
192 | with torch.no_grad():
193 |
194 | if args.vae_ckpt is not None:
195 |
196 | x = vae.encode(x).sample().mul_(args.vae_scaling_factor)
197 |
198 | else:
199 | x = vae.encode(x).latent_dist.sample().mul_(args.vae_scaling_factor)
200 |
201 | x = x.detach().cpu().numpy() # (1, 4, 32, 32)
202 | np.save(f'{args.features_path}/imagenet256_features/{rank}-{train_steps}.npy', x)
203 |
204 | y = y.detach().cpu().numpy() # (1,)
205 | np.save(f'{args.features_path}/imagenet256_labels/{rank}-{train_steps}.npy', y)
206 |
207 | train_steps += 1
208 |
209 | if __name__ == "__main__":
210 | # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
211 | parser = argparse.ArgumentParser()
212 | parser.add_argument("--data-path", type=str, required=True)
213 | parser.add_argument("--features-path", type=str, default="features")
214 | parser.add_argument("--results-dir", type=str, default="results")
215 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
216 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
217 | parser.add_argument("--num-classes", type=int, default=1000)
218 | parser.add_argument("--epochs", type=int, default=1400)
219 | parser.add_argument("--global-batch-size", type=int, default=256)
220 | parser.add_argument("--global-seed", type=int, default=0)
221 | parser.add_argument("--vae", type=str, choices=["ema", "mse", "xl", "sd3","ours"], default="ema")
222 | parser.add_argument("--vae-ckpt", type=str, default=None)
223 | parser.add_argument("--vae-config", type=str, default=None)
224 | parser.add_argument("--hf-model-name", type=str, default="zelaki/eq-vae")
225 | parser.add_argument("--vae-scaling-factor", type=float, default=0.18215)
226 | parser.add_argument("--num-workers", type=int, default=4)
227 | parser.add_argument("--log-every", type=int, default=100)
228 | parser.add_argument("--ckpt-every", type=int, default=50_000)
229 | args = parser.parse_args()
230 | main(args)
231 |
--------------------------------------------------------------------------------
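A hypothetical loader sketch (not taken from train_gen/train.py, which is not shown here) for reading back the per-sample .npy features and labels written by extract_features.py; the class name is illustrative.

import os
import numpy as np
import torch
from torch.utils.data import Dataset

class NpyFeatureDataset(Dataset):  # hypothetical helper, not a repo class
    def __init__(self, features_dir, labels_dir):
        self.features_dir = features_dir
        self.labels_dir = labels_dir
        self.files = sorted(os.listdir(features_dir))  # "{rank}-{step}.npy" in both dirs

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        feats = np.load(os.path.join(self.features_dir, name))  # e.g. (1, 4, 32, 32)
        label = np.load(os.path.join(self.labels_dir, name))    # (1,)
        return torch.from_numpy(feats).squeeze(0), torch.from_numpy(label).squeeze(0)

# example: NpyFeatureDataset("features/imagenet256_features", "features/imagenet256_labels")
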
/train_eqvae/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 |     Finally, reshape back to an image.
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
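262 | 
263 | 
264 | if __name__ == "__main__":
265 |     # Illustrative smoke test, not used anywhere in the repo: it shows the tensor
266 |     # shapes CrossAttention and SpatialTransformer expect. All sizes below are
267 |     # arbitrary assumptions chosen to keep the example small; run it from the
268 |     # train_eqvae root so the ldm package is importable.
269 |     self_attn = CrossAttention(query_dim=32, heads=4, dim_head=8)   # context_dim defaults to query_dim
270 |     cross_attn = CrossAttention(query_dim=32, context_dim=16, heads=4, dim_head=8)
271 |     x = torch.randn(2, 10, 32)    # (batch, tokens, query_dim)
272 |     ctx = torch.randn(2, 7, 16)   # (batch, context tokens, context_dim)
273 |     print(self_attn(x).shape)                 # torch.Size([2, 10, 32])
274 |     print(cross_attn(x, context=ctx).shape)   # torch.Size([2, 10, 32])
275 | 
276 |     st = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, context_dim=16)
277 |     feat = torch.randn(2, 32, 8, 8)           # image-like features (b, c, h, w)
278 |     print(st(feat, context=ctx).shape)        # torch.Size([2, 32, 8, 8])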
--------------------------------------------------------------------------------
/train_eqvae/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7 |
8 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
9 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10 |
11 | from ldm.util import instantiate_from_config
12 | import random
13 | import json
14 | import os
15 | import copy
16 |
17 |
18 |
19 | def flip_or_rotate_image(inputs, flip):
20 | if flip == "h":
21 | inputs = torch.flip(inputs, [-1])
22 |
23 | elif flip == "v":
24 | inputs = torch.flip(inputs, [-2])
25 |
26 | elif flip == "vh":
27 | inputs = torch.flip(inputs, [-1,-2])
28 |
29 | elif flip == "90":
30 | inputs = torch.rot90(inputs, k=1, dims=[-1, -2])
31 |
32 |     else:  # any other value (e.g. "270") falls through to a 270-degree rotation
33 | inputs = torch.rot90(inputs, k=3, dims=[-1, -2])
34 |
35 | return inputs
36 |
37 |
38 | class AutoencoderKL(pl.LightningModule):
39 | def __init__(self,
40 | ddconfig,
41 | lossconfig,
42 | embed_dim,
43 | anisotropic=False,
44 | uniform_sample_scale=True,
45 | ckpt_path=None,
46 | p_prior=0.5,
47 | p_prior_s=0.25,
48 | ignore_keys=[],
49 | image_key="image",
50 | colorize_nlabels=None,
51 | monitor=None,
52 | ):
53 | super().__init__()
54 | self.image_key = image_key
55 | self.encoder = Encoder(**ddconfig)
56 | self.decoder = Decoder(**ddconfig)
57 | self.loss = instantiate_from_config(lossconfig)
58 | self.uniform_sample_scale = uniform_sample_scale
59 | self.anisotropic = anisotropic
60 | self.p_prior=p_prior
61 | self.p_prior_s=p_prior_s
62 | assert ddconfig["double_z"]
63 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
64 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
65 | self.embed_dim = embed_dim
66 | if colorize_nlabels is not None:
67 | assert type(colorize_nlabels)==int
68 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
69 | if monitor is not None:
70 | self.monitor = monitor
71 | if ckpt_path is not None:
72 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
73 |
74 | def init_from_ckpt(self, path, ignore_keys=list()):
75 | sd = torch.load(path, map_location="cpu")["state_dict"]
76 | keys = list(sd.keys())
77 | for k in keys:
78 | for ik in ignore_keys:
79 | if k.startswith(ik):
80 | print("Deleting key {} from state_dict.".format(k))
81 | del sd[k]
82 | self.load_state_dict(sd, strict=False)
83 | print(f"Restored from {path}")
84 |
85 | def encode(self, x):
86 | h = self.encoder(x)
87 | moments = self.quant_conv(h)
88 | posterior = DiagonalGaussianDistribution(moments)
89 | return posterior
90 |
91 | def decode(self, z):
92 | z = self.post_quant_conv(z)
93 | dec = self.decoder(z)
94 | return dec
95 |
96 | def forward(self, input, scale=1, angle=0):
97 | posterior = self.encode(input)
98 | z = posterior.sample()
99 |
100 | if scale != 1:
101 | z = torch.nn.functional.interpolate(z, scale_factor=scale, mode='bilinear', align_corners=False)
102 |
103 | if angle != 0:
104 | z = torch.rot90(z, k=angle, dims=[-1, -2])
105 |
106 | dec = self.decode(z)
107 | return dec, posterior, z
108 |
109 | def get_input(self, batch, k):
110 | x = batch[k]
111 | if len(x.shape) == 3:
112 | x = x[..., None]
113 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
114 | return x
115 |
116 | def training_step(self, batch, batch_idx, optimizer_idx):
117 | inputs = self.get_input(batch, self.image_key)
118 |
119 | # EQ-VAE regularization
120 | if random.random() < self.p_prior:
121 |
122 | mode = "latent"
123 | if self.anisotropic:
124 | scale_x = random.choice([s / 32 for s in range(8,32)])
125 | scale_y = random.choice([s / 32 for s in range(8,32)])
126 | scale=(scale_x, scale_y)
127 | else:
128 | scale = random.choice([s / 32 for s in range(8,32)])
129 |
130 | # rotation angles 1 -> π/2, 2 -> π, 3 -> 3π/2
131 | angle = random.choice([1, 2, 3])
132 | reconstructions, posterior, z_after = self(inputs, scale=scale, angle=angle)
133 |
134 | # Scale ground truth images with the same scale
135 | inputs = torch.nn.functional.interpolate(inputs, scale_factor=scale, mode='bilinear', align_corners=False)
136 |
137 | # Rotate ground truth images with the same angle
138 | inputs = torch.rot90(inputs, k=angle, dims=[-1, -2])
139 |
140 | # prior preservation
141 | else:
142 | mode = "image"
143 | # this is prior preservation for low resolution images
144 | if random.random() < self.p_prior_s:
145 |
146 | scale = random.choice([s / 32 for s in range(8,32)])
147 | inputs = torch.nn.functional.interpolate(inputs, scale_factor=scale, mode='bilinear', align_corners=False)
148 | reconstructions, posterior, _ = self(inputs)
149 |
150 | # this is prior preservation for full resolution images
151 | else:
152 | scale=1
153 | reconstructions, posterior, _ = self(inputs)
154 |
155 |
156 |
157 | if optimizer_idx == 0:
158 |
159 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
160 | last_layer=self.get_last_layer(), split="train")
161 |
162 |
163 | self.log(f"aeloss_scale-{scale}-{mode}", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
164 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
165 | return aeloss
166 |
167 | if optimizer_idx == 1:
168 |
169 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
170 | last_layer=self.get_last_layer(), split="train")
171 |
172 |
173 | self.log(f"discloss_scale-{scale}-{mode}", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
174 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
175 | return discloss
176 |
177 | def validation_step(self, batch, batch_idx):
178 | inputs = self.get_input(batch, self.image_key)
179 | reconstructions, posterior, _ = self(inputs)
180 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
181 | last_layer=self.get_last_layer(), split="val")
182 |
183 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
184 | last_layer=self.get_last_layer(), split="val")
185 |
186 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
187 | self.log_dict(log_dict_ae)
188 | self.log_dict(log_dict_disc)
189 | return self.log_dict
190 |
191 | def configure_optimizers(self):
192 | lr = self.learning_rate
193 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
194 | list(self.decoder.parameters())+
195 | list(self.quant_conv.parameters())+
196 | list(self.post_quant_conv.parameters()),
197 | lr=lr, betas=(0.5, 0.9))
198 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
199 | lr=lr, betas=(0.5, 0.9))
200 | return [opt_ae, opt_disc], []
201 |
202 | def get_last_layer(self):
203 | return self.decoder.conv_out.weight
204 |
205 | @torch.no_grad()
206 | def log_images(self, batch, only_inputs=False, **kwargs):
207 | log = dict()
208 | x = self.get_input(batch, self.image_key)
209 | x = x.to(self.device)
210 | if not only_inputs:
211 | if random.random() < 0.5:
212 | xrec, posterior, _ = self(x)
213 | else:
214 | xrec, posterior, _ = self(x, scale=0.5)
215 |
216 | if x.shape[1] > 3:
217 | # colorize with random projection
218 | assert xrec.shape[1] > 3
219 | x = self.to_rgb(x)
220 | xrec = self.to_rgb(xrec)
221 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
222 | log["reconstructions"] = xrec
223 | log["inputs"] = x
224 | return log
225 |
226 | def to_rgb(self, x):
227 | assert self.image_key == "segmentation"
228 | if not hasattr(self, "colorize"):
229 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
230 | x = F.conv2d(x, weight=self.colorize)
231 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
232 | return x
233 |
234 |
235 | class IdentityFirstStage(torch.nn.Module):
236 | def __init__(self, *args, vq_interface=False, **kwargs):
237 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
238 | super().__init__()
239 |
240 | def encode(self, x, *args, **kwargs):
241 | return x
242 |
243 | def decode(self, x, *args, **kwargs):
244 | return x
245 |
246 | def quantize(self, x, *args, **kwargs):
247 | if self.vq_interface:
248 | return x, None, [None, None, None]
249 | return x
250 |
251 | def forward(self, x, *args, **kwargs):
252 | return x
253 |
254 |
255 |
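256 | if __name__ == "__main__":
257 |     # Minimal sketch of how the EQ-VAE targets are built (illustrative only; the
258 |     # real path is AutoencoderKL.training_step above). The same scale/rotation is
259 |     # applied to the sampled latent and to the reconstruction target, so the
260 |     # decoder learns to map transformed latents to transformed images. The tensor
261 |     # sizes and the (scale, angle) pair below are arbitrary assumptions.
262 |     dummy_latent = torch.randn(1, 4, 32, 32)
263 |     dummy_image = torch.randn(1, 3, 256, 256)
264 |     scale, angle = 0.5, 1  # half resolution, one 90-degree rotation
265 |     z_t = torch.rot90(F.interpolate(dummy_latent, scale_factor=scale, mode="bilinear", align_corners=False),
266 |                       k=angle, dims=[-1, -2])
267 |     x_t = torch.rot90(F.interpolate(dummy_image, scale_factor=scale, mode="bilinear", align_corners=False),
268 |                       k=angle, dims=[-1, -2])
269 |     # decode(z_t) is compared against x_t by the reconstruction/GAN losses.
270 |     print(z_t.shape, x_t.shape)  # torch.Size([1, 4, 16, 16]) torch.Size([1, 3, 128, 128])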
--------------------------------------------------------------------------------
/train_eqvae/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
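268 | 
269 | 
270 | if __name__ == "__main__":
271 |     # Illustrative smoke test only (not used by training). It prints the shape of a
272 |     # sinusoidal timestep embedding, the endpoints of the "linear" beta schedule,
273 |     # and runs a tiny gradient-checkpointed backward pass. Sizes are assumptions.
274 |     t = torch.arange(0, 4)
275 |     emb = timestep_embedding(t, dim=128)
276 |     print(emb.shape)                         # torch.Size([4, 128])
277 | 
278 |     betas = make_beta_schedule("linear", n_timestep=1000)
279 |     print(betas.shape, betas[0], betas[-1])  # (1000,) ~1e-4 ... ~2e-2
280 | 
281 |     # checkpoint() trades memory for an extra forward pass during backward.
282 |     lin = nn.Linear(8, 8)
283 |     x = torch.randn(2, 8, requires_grad=True)
284 |     y = checkpoint(lambda a: lin(a).sum(), (x,), lin.parameters(), True)
285 |     y.backward()
286 |     print(x.grad.shape)                      # torch.Size([2, 8])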
--------------------------------------------------------------------------------
/train_eqvae/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/train_eqvae/ldm/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class DDIMSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | self.register_buffer('betas', to_torch(self.model.betas))
32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34 |
35 | # calculations for diffusion q(x_t | x_{t-1}) and others
36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41 |
42 | # ddim sampling parameters
43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44 | ddim_timesteps=self.ddim_timesteps,
45 | eta=ddim_eta,verbose=verbose)
46 | self.register_buffer('ddim_sigmas', ddim_sigmas)
47 | self.register_buffer('ddim_alphas', ddim_alphas)
48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54 |
55 | @torch.no_grad()
56 | def sample(self,
57 | S,
58 | batch_size,
59 | shape,
60 | conditioning=None,
61 | callback=None,
62 | normals_sequence=None,
63 | img_callback=None,
64 | quantize_x0=False,
65 | eta=0.,
66 | mask=None,
67 | x0=None,
68 | temperature=1.,
69 | noise_dropout=0.,
70 | score_corrector=None,
71 | corrector_kwargs=None,
72 | verbose=True,
73 | x_T=None,
74 | log_every_t=100,
75 | unconditional_guidance_scale=1.,
76 | unconditional_conditioning=None,
77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
78 | **kwargs
79 | ):
80 | if conditioning is not None:
81 | if isinstance(conditioning, dict):
82 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
83 | if cbs != batch_size:
84 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
85 | else:
86 | if conditioning.shape[0] != batch_size:
87 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
88 |
89 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
90 | # sampling
91 | C, H, W = shape
92 | size = (batch_size, C, H, W)
93 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
94 |
95 | samples, intermediates = self.ddim_sampling(conditioning, size,
96 | callback=callback,
97 | img_callback=img_callback,
98 | quantize_denoised=quantize_x0,
99 | mask=mask, x0=x0,
100 | ddim_use_original_steps=False,
101 | noise_dropout=noise_dropout,
102 | temperature=temperature,
103 | score_corrector=score_corrector,
104 | corrector_kwargs=corrector_kwargs,
105 | x_T=x_T,
106 | log_every_t=log_every_t,
107 | unconditional_guidance_scale=unconditional_guidance_scale,
108 | unconditional_conditioning=unconditional_conditioning,
109 | )
110 | return samples, intermediates
111 |
112 | @torch.no_grad()
113 | def ddim_sampling(self, cond, shape,
114 | x_T=None, ddim_use_original_steps=False,
115 | callback=None, timesteps=None, quantize_denoised=False,
116 | mask=None, x0=None, img_callback=None, log_every_t=100,
117 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
118 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
119 | device = self.model.betas.device
120 | b = shape[0]
121 | if x_T is None:
122 | img = torch.randn(shape, device=device)
123 | else:
124 | img = x_T
125 |
126 | if timesteps is None:
127 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
128 | elif timesteps is not None and not ddim_use_original_steps:
129 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
130 | timesteps = self.ddim_timesteps[:subset_end]
131 |
132 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
133 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
134 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
135 | print(f"Running DDIM Sampling with {total_steps} timesteps")
136 |
137 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
138 |
139 | for i, step in enumerate(iterator):
140 | index = total_steps - i - 1
141 | ts = torch.full((b,), step, device=device, dtype=torch.long)
142 |
143 | if mask is not None:
144 | assert x0 is not None
145 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
146 | img = img_orig * mask + (1. - mask) * img
147 |
148 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
149 | quantize_denoised=quantize_denoised, temperature=temperature,
150 | noise_dropout=noise_dropout, score_corrector=score_corrector,
151 | corrector_kwargs=corrector_kwargs,
152 | unconditional_guidance_scale=unconditional_guidance_scale,
153 | unconditional_conditioning=unconditional_conditioning)
154 | img, pred_x0 = outs
155 | if callback: callback(i)
156 | if img_callback: img_callback(pred_x0, i)
157 |
158 | if index % log_every_t == 0 or index == total_steps - 1:
159 | intermediates['x_inter'].append(img)
160 | intermediates['pred_x0'].append(pred_x0)
161 |
162 | return img, intermediates
163 |
164 | @torch.no_grad()
165 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
166 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
167 | unconditional_guidance_scale=1., unconditional_conditioning=None):
168 | b, *_, device = *x.shape, x.device
169 |
170 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
171 | e_t = self.model.apply_model(x, t, c)
172 | else:
173 | x_in = torch.cat([x] * 2)
174 | t_in = torch.cat([t] * 2)
175 | c_in = torch.cat([unconditional_conditioning, c])
176 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
177 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
178 |
179 | if score_corrector is not None:
180 | assert self.model.parameterization == "eps"
181 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
182 |
183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187 | # select parameters corresponding to the currently considered timestep
188 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
189 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
190 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
191 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
192 |
193 | # current prediction for x_0
194 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
195 | if quantize_denoised:
196 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
197 | # direction pointing to x_t
198 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200 | if noise_dropout > 0.:
201 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203 | return x_prev, pred_x0
204 |
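205 | 
206 | if __name__ == "__main__":
207 |     # Illustrative only: the timestep subset make_schedule() builds for a model with
208 |     # 1000 DDPM steps subsampled to 10 DDIM steps (the step counts are assumptions).
209 |     steps = make_ddim_timesteps("uniform", num_ddim_timesteps=10,
210 |                                 num_ddpm_timesteps=1000, verbose=False)
211 |     print(steps)  # [  1 101 201 301 401 501 601 701 801 901]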
--------------------------------------------------------------------------------
/train_gen/sample_ddp.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Code adapted from https://github.com/chuanyangjin/fast-DiT
4 | """
5 | import sys
6 | import os
7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../train_eqvae'))
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
9 |
10 |
11 | import torch
12 | import torch.distributed as dist
13 | from download import find_model
14 | from models.dit_models import DiT_models
15 | from models.diffusion import create_diffusion
16 | from diffusers.models import AutoencoderKL
17 | from tqdm import tqdm
18 | import os
19 | from PIL import Image
20 | import numpy as np
21 | import math
22 | import argparse
23 | from evaluator import Evaluator
24 | import tensorflow.compat.v1 as tf
25 | from train_eqvae.ldm.models.autoencoder import AutoencoderKL as LDMAutoencoderKL
26 |
27 | from ldm.util import instantiate_from_config
28 | import yaml
29 | from omegaconf import OmegaConf
30 |
31 | def load_config(config_path, display=False):
32 | config = OmegaConf.load(config_path)
33 | if display:
34 | print(yaml.dump(OmegaConf.to_container(config)))
35 | return config
36 |
37 |
38 | def load_kl(config, ckpt_path=None):
39 |
40 | model = LDMAutoencoderKL(**config.model.params)
41 |
42 | if ckpt_path is not None:
43 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
44 | missing, unexpected = model.load_state_dict(sd, strict=False)
45 | return model.eval()
46 |
47 | def preprocess_kl(x):
48 | x = 2.*x - 1.
49 | return x
50 |
51 |
52 |
53 | def create_npz_from_sample_folder(sample_dir, num=50_000):
54 | """
55 | Builds a single .npz file from a folder of .png samples.
56 | """
57 | samples = []
58 | for i in tqdm(range(num), desc="Building .npz file from samples"):
59 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
60 | sample_np = np.asarray(sample_pil).astype(np.uint8)
61 | samples.append(sample_np)
62 | samples = np.stack(samples)
63 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
64 | npz_path = f"{sample_dir}.npz"
65 | np.savez(npz_path, arr_0=samples)
66 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
67 | return npz_path
68 |
69 |
70 |
71 | def custom_to_pil(x):
72 | x = x.detach().cpu()
73 | x = torch.clamp(x, -1., 1.)
74 | x = (x + 1.)/2.
75 | x = x.permute(1,2,0).numpy()
76 | x = (255*x).astype(np.uint8)
77 | x = Image.fromarray(x)
78 | if not x.mode == "RGB":
79 | x = x.convert("RGB")
80 | return x
81 |
82 |
83 | def calculate_metrics(ref_batch, sample_batch, fid_path):
84 | config = tf.ConfigProto(
85 | allow_soft_placement=True
86 | )
87 | config.gpu_options.allow_growth = True
88 | evaluator = Evaluator(tf.Session(config=config))
89 |
90 | evaluator.warmup()
91 |
92 | ref_acts = evaluator.read_activations(ref_batch)
93 | ref_stats, ref_stats_spatial = evaluator.read_statistics(ref_batch, ref_acts)
94 |
95 | sample_acts = evaluator.read_activations(sample_batch)
96 | sample_stats, sample_stats_spatial = evaluator.read_statistics(sample_batch, sample_acts)
97 |
98 | with open(fid_path, 'w') as fd:
99 |
100 | fd.write("Computing evaluations...\n")
101 | fd.write(f"Inception Score:{evaluator.compute_inception_score(sample_acts[0])}\n" )
102 | fd.write(f"FID:{sample_stats.frechet_distance(ref_stats)}\n")
103 | fd.write(f"sFID:{sample_stats_spatial.frechet_distance(ref_stats_spatial)}\n")
104 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
105 | fd.write(f"Precision:{prec}\n")
106 | fd.write(f"Recall:{recall}\n")
107 |
108 |
109 |
110 | def main(args):
111 | """
112 | Run sampling.
113 | """
114 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
115 |     assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU; use sample.py for CPU-only sampling."
116 | torch.set_grad_enabled(False)
117 |
118 | # Setup DDP:
119 | dist.init_process_group("nccl")
120 | rank = dist.get_rank()
121 | device = rank % torch.cuda.device_count()
122 | seed = args.global_seed * dist.get_world_size() + rank
123 | torch.manual_seed(seed)
124 | torch.cuda.set_device(device)
125 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
126 |
127 | if args.ckpt is None:
128 | assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
129 | assert args.image_size in [256, 512]
130 | assert args.num_classes == 1000
131 |
132 | # Load model:
133 | latent_size = args.image_size // 8
134 | model = DiT_models[args.model](
135 | input_size=latent_size,
136 | num_classes=args.num_classes,
137 | in_channels=args.in_channels
138 | ).to(device)
139 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py:
140 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
141 | state_dict = find_model(ckpt_path)
142 | model.load_state_dict(state_dict)
143 | model.eval() # important!
144 | diffusion = create_diffusion(str(args.num_sampling_steps))
145 |
146 | if args.vae_ckpt is not None:
147 | vae_config = load_config(args.vae_config, display=False)
148 | vae = load_kl(vae_config, ckpt_path=args.vae_ckpt).to(device)
149 | else:
150 | from diffusers.models import AutoencoderKL
151 | vae = AutoencoderKL.from_pretrained(args.hf_model_name).to(device)
152 |
153 |
154 |     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
155 | using_cfg = args.cfg_scale > 1.0
156 |
157 | # Create folder to save samples:
158 | model_string_name = args.model.replace("/", "-")
159 | ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
160 | folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \
161 | f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
162 |
163 | if args.ddpm:
164 | folder_name += f"-ddpm"
165 |
166 |
167 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
168 |
169 | exp_dir_name = os.path.dirname(args.sample_dir)
170 | fid_dir = f"{exp_dir_name}/fids"
171 |
172 | if rank == 0:
173 | os.makedirs(sample_folder_dir, exist_ok=True)
174 | os.makedirs(fid_dir, exist_ok=True)
175 |
176 | print(f"Saving .png samples at {sample_folder_dir}")
177 | print(f"Saving metrics at {fid_dir}")
178 |
179 | dist.barrier()
180 |
181 | fid_path = f"{fid_dir}/metrics-adm-{folder_name}.txt"
182 |
183 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
184 | n = args.per_proc_batch_size
185 | global_batch_size = n * dist.get_world_size()
186 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
187 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
188 | if rank == 0:
189 | print(f"Total number of images that will be sampled: {total_samples}")
190 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
191 | samples_needed_this_gpu = int(total_samples // dist.get_world_size())
192 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
193 | iterations = int(samples_needed_this_gpu // n)
194 | pbar = range(iterations)
195 | pbar = tqdm(pbar) if rank == 0 else pbar
196 | total = 0
197 | for _ in pbar:
198 | # Sample inputs:
199 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
200 | y = torch.randint(0, args.num_classes, (n,), device=device)
201 |
202 | # Setup classifier-free guidance:
203 | if using_cfg:
204 | z = torch.cat([z, z], 0)
205 | y_null = torch.tensor([1000] * n, device=device)
206 | y = torch.cat([y, y_null], 0)
207 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
208 | sample_fn = model.forward_with_cfg
209 | else:
210 | model_kwargs = dict(y=y)
211 | sample_fn = model.forward
212 |
213 | # Sample images:
214 | if args.ddpm:
215 | samples = diffusion.p_sample_loop(
216 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
217 | )
218 | else:
219 | samples = diffusion.ddim_sample_loop(
220 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
221 | )
222 | if using_cfg:
223 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
224 |
225 | if args.vae_ckpt is not None:
226 |
227 | samples = vae.decode(samples / args.vae_scaling_factor)
228 |
229 | else:
230 |
231 | samples = vae.decode(samples / args.vae_scaling_factor).sample
232 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
233 |
234 | # Save samples to disk as individual .png files
235 | for i, sample in enumerate(samples):
236 | index = i * dist.get_world_size() + rank + total
237 |
238 | if args.vae in ["ours", "hf"]:
239 | sample = custom_to_pil(sample)
240 | sample.save(f"{sample_folder_dir}/{index:06d}.png")
241 | else:
242 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
243 | total += global_batch_size
244 |
245 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
246 | dist.barrier()
247 | if rank == 0:
248 | sample_batch = create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
249 |         calculate_metrics(args.ref_batch, sample_batch, fid_path)  # (ref, sample) order matches calculate_metrics' signature
250 |
251 | dist.barrier()
252 | dist.destroy_process_group()
253 |
254 |
255 | if __name__ == "__main__":
256 | parser = argparse.ArgumentParser()
257 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
258 | parser.add_argument("--vae-ckpt", type=str, default=None)
259 | parser.add_argument("--sample-dir", type=str, default="samples")
260 |     parser.add_argument("--ddpm", action=argparse.BooleanOptionalAction, default=False)  # type=bool would treat any non-empty string (even "False") as True
261 | parser.add_argument("--vae-scaling-factor", type=float, default=0.18215)
262 | parser.add_argument("--ref-batch", type=str, default="/data/imagenet/VIRTUAL_imagenet256_labeled.npz")
263 | parser.add_argument("--wavelets", type=str, default=False)
264 | parser.add_argument("--diff-proj", type=str, default=False)
265 | parser.add_argument("--gaussian-registers", type=str, default=False)
266 | parser.add_argument("--in-channels", type=int, default=4)
267 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
268 | parser.add_argument("--num-classes", type=int, default=1000)
269 | parser.add_argument("--cfg-scale", type=float, default=1.5)
270 | parser.add_argument("--num-sampling-steps", type=int, default=250)
271 | parser.add_argument("--global-seed", type=int, default=0)
272 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
273 | help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
274 |     parser.add_argument("--ckpt", type=str, default=None,
275 |                         help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
276 |     # Arguments referenced in main() but missing from the original parser; the defaults below are assumptions.
277 |     parser.add_argument("--vae", type=str, default="ours", help="Tag used in the sample folder name and to pick the PIL conversion path.")
278 |     parser.add_argument("--vae-config", type=str, default=None, help="YAML config for the EQ-VAE checkpoint given via --vae-ckpt.")
279 |     parser.add_argument("--hf-model-name", type=str, default=None, help="HuggingFace VAE id used when --vae-ckpt is not set (e.g. stabilityai/sd-vae-ft-ema).")
280 |     parser.add_argument("--per-proc-batch-size", type=int, default=32)
281 |     parser.add_argument("--num-fid-samples", type=int, default=50_000)
282 |     args = parser.parse_args()
283 |     main(args)
284 | 
--------------------------------------------------------------------------------
/train_gen/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/chuanyangjin/fast-DiT
3 | """
4 | import sys
5 | import os
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 |
8 | import torch
9 | # the first flag below was False when we tested this script but True makes A100 training a lot faster:
10 | torch.backends.cuda.matmul.allow_tf32 = True
11 | torch.backends.cudnn.allow_tf32 = True
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.utils.data import Dataset, DataLoader
15 | from torch.utils.data.distributed import DistributedSampler
16 | from torchvision.datasets import ImageFolder
17 | from torchvision import transforms
18 | import numpy as np
19 | from collections import OrderedDict
20 | from PIL import Image
21 | from copy import deepcopy
22 | from glob import glob
23 | from time import time
24 | import argparse
25 | import logging
26 | from accelerate import Accelerator
27 | from models.dit_models import DiT_models
28 | from models.diffusion import create_diffusion
29 |
30 | # from models import DiT_models
31 | # from diffusion import create_diffusion
32 | from diffusers.models import AutoencoderKL
33 |
34 |
35 | #################################################################################
36 | # Training Helper Functions #
37 | #################################################################################
38 |
39 | @torch.no_grad()
40 | def update_ema(ema_model, model, decay=0.9999):
41 | """
42 | Step the EMA model towards the current model.
43 | """
44 | ema_params = OrderedDict(ema_model.named_parameters())
45 | model_params = OrderedDict(model.named_parameters())
46 |
47 | for name, param in model_params.items():
48 | name = name.replace("module.", "")
49 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
50 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
51 |
52 |
53 | def requires_grad(model, flag=True):
54 | """
55 | Set requires_grad flag for all parameters in a model.
56 | """
57 | for p in model.parameters():
58 | p.requires_grad = flag
59 |
60 |
61 | def create_logger(logging_dir):
62 | """
63 | Create a logger that writes to a log file and stdout.
64 | """
65 | logging.basicConfig(
66 | level=logging.INFO,
67 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
68 | datefmt='%Y-%m-%d %H:%M:%S',
69 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
70 | )
71 | logger = logging.getLogger(__name__)
72 | return logger
73 |
74 |
75 | def center_crop_arr(pil_image, image_size):
76 | """
77 | Center cropping implementation from ADM.
78 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
79 | """
80 | while min(*pil_image.size) >= 2 * image_size:
81 | pil_image = pil_image.resize(
82 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
83 | )
84 |
85 | scale = image_size / min(*pil_image.size)
86 | pil_image = pil_image.resize(
87 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
88 | )
89 |
90 | arr = np.array(pil_image)
91 | crop_y = (arr.shape[0] - image_size) // 2
92 | crop_x = (arr.shape[1] - image_size) // 2
93 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
94 |
95 |
96 | class CustomDataset(Dataset):
97 | def __init__(self, features_dir, labels_dir):
98 | self.features_dir = features_dir
99 | self.labels_dir = labels_dir
100 |
101 | self.features_files = sorted(os.listdir(features_dir))
102 | self.labels_files = sorted(os.listdir(labels_dir))
103 |
104 | def __len__(self):
105 | assert len(self.features_files) == len(self.labels_files), \
106 |             "Number of feature files and label files must be the same"
107 | return len(self.features_files)
108 |
109 | def __getitem__(self, idx):
110 | feature_file = self.features_files[idx]
111 | label_file = self.labels_files[idx]
112 |
113 | features = np.load(os.path.join(self.features_dir, feature_file))
114 | labels = np.load(os.path.join(self.labels_dir, label_file))
115 | return torch.from_numpy(features), torch.from_numpy(labels)
116 |
117 |
118 | #################################################################################
119 | # Training Loop #
120 | #################################################################################
121 |
122 | def main(args):
123 | """
124 | Trains a new DiT model.
125 | """
126 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
127 |
128 | # Setup accelerator:
129 | accelerator = Accelerator()
130 | device = accelerator.device
131 |
132 | experiment_index = len(glob(f"{args.results_dir}/*"))
133 | model_string_name = args.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
134 |
135 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
136 |
137 |
138 | if args.dataset_name:
139 | experiment_dir += f"-{args.dataset_name}"
140 |
141 | if args.vae_name:
142 | experiment_dir += f"-{args.vae_name}"
143 |
144 |
145 | # Setup an experiment folder:
146 | if accelerator.is_main_process:
147 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
148 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
149 | os.makedirs(checkpoint_dir, exist_ok=True)
150 | logger = create_logger(experiment_dir)
151 | logger.info(f"Experiment directory created at {experiment_dir}")
152 |
153 | # Create model:
154 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
155 | latent_size = args.image_size // 8
156 | model = DiT_models[args.model](
157 | input_size=latent_size,
158 | in_channels=args.in_channels
159 | )
160 | # Note that parameter initialization is done within the DiT constructor
161 | model = model.to(device)
162 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
163 |
164 |     # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
165 | opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
166 |
167 | if args.ckpt is not None:
168 | ckpt_path = args.ckpt
169 | state_dict = torch.load(ckpt_path, map_location="cpu")
170 | model.load_state_dict(state_dict["model"])
171 | ema.load_state_dict(state_dict["ema"])
172 | opt.load_state_dict(state_dict["opt"])
173 | args = state_dict["args"]
174 |
175 |
176 | requires_grad(ema, False)
177 |
178 |
179 | diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
180 | if accelerator.is_main_process:
181 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
182 |
183 |
184 |
185 | # Setup data:
186 | features_dir = f"{args.feature_path}/imagenet256_features"
187 | labels_dir = f"{args.feature_path}/imagenet256_labels"
188 | dataset = CustomDataset(features_dir, labels_dir)
189 | loader = DataLoader(
190 | dataset,
191 | batch_size=int(args.global_batch_size // accelerator.num_processes),
192 | shuffle=True,
193 | num_workers=args.num_workers,
194 | pin_memory=True,
195 | drop_last=True
196 | )
197 | if accelerator.is_main_process:
198 | logger.info(f"Dataset contains {len(dataset):,} images ({args.feature_path})")
199 |
200 | # Prepare models for training:
201 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
202 | model.train() # important! This enables embedding dropout for classifier-free guidance
203 | ema.eval() # EMA model should always be in eval mode
204 | model, opt, loader = accelerator.prepare(model, opt, loader)
205 |
206 | # Variables for monitoring/logging purposes:
207 | train_steps = 0
208 | log_steps = 0
209 | running_loss = 0
210 | start_time = time()
211 |
212 | if accelerator.is_main_process:
213 | logger.info(f"Training for {args.epochs} epochs...")
214 | for epoch in range(args.epochs):
215 | if accelerator.is_main_process:
216 | logger.info(f"Beginning epoch {epoch}...")
217 | for x, y in loader:
218 | x = x.to(device)
219 | y = y.to(device)
220 | x = x.squeeze(dim=1)
221 | y = y.squeeze(dim=1)
222 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
223 | model_kwargs = dict(y=y)
224 | loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
225 | loss = loss_dict["loss"].mean()
226 | opt.zero_grad()
227 | accelerator.backward(loss)
228 | opt.step()
229 | update_ema(ema, model)
230 |
231 | # Log loss values:
232 | running_loss += loss.item()
233 | log_steps += 1
234 | train_steps += 1
235 | if train_steps % args.log_every == 0:
236 | # Measure training speed:
237 | torch.cuda.synchronize()
238 | end_time = time()
239 | steps_per_sec = log_steps / (end_time - start_time)
240 | # Reduce loss history over all processes:
241 | avg_loss = torch.tensor(running_loss / log_steps, device=device)
242 |                 avg_loss = accelerator.reduce(avg_loss, reduction="mean").item()
243 | if accelerator.is_main_process:
244 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
245 | # Reset monitoring variables:
246 | running_loss = 0
247 | log_steps = 0
248 | start_time = time()
249 |
250 | # Save DiT checkpoint:
251 | if train_steps % args.ckpt_every == 0:
252 | if accelerator.is_main_process:
253 | checkpoint = {
254 | "model": model.module.state_dict(),
255 | "ema": ema.state_dict(),
256 | "opt": opt.state_dict(),
257 | "args": args
258 | }
259 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
260 | torch.save(checkpoint, checkpoint_path)
261 | logger.info(f"Saved checkpoint to {checkpoint_path}")
262 | if train_steps > args.max_train_steps:
263 | break
264 |
265 | if accelerator.is_main_process:
266 | logger.info("Done!")
267 |
268 |
269 | if __name__ == "__main__":
270 | # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
271 | parser = argparse.ArgumentParser()
272 | parser.add_argument("--feature-path", type=str, default="features")
273 | parser.add_argument("--results-dir", type=str, default="results")
274 |     parser.add_argument("--dataset-name", type=str, default=None)  # optional tag appended to the experiment folder name
275 | parser.add_argument("--in-channels", type=int, default=4)
276 | parser.add_argument("--ckpt", type=str, default=None)
277 |     parser.add_argument("--max-train-steps", type=int, default=400_000)
278 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
279 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
280 | parser.add_argument("--num-classes", type=int, default=1000)
281 | parser.add_argument("--epochs", type=int, default=1400)
282 | parser.add_argument("--global-batch-size", type=int, default=256)
283 | parser.add_argument("--global-seed", type=int, default=0)
284 | parser.add_argument("--vae-name", type=str, default="ema") # Choice doesn't affect training
285 | parser.add_argument("--num-workers", type=int, default=4)
286 | parser.add_argument("--log-every", type=int, default=100)
287 | parser.add_argument("--ckpt-every", type=int, default=50_000)
288 |
289 | args = parser.parse_args()
290 | main(args)
291 |
--------------------------------------------------------------------------------
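Note on train.py above: update_ema() applies the standard exponential moving average, ema <- decay * ema + (1 - decay) * param, in place for every parameter. A tiny self-contained check of that identity with arbitrary values:

import torch

decay = 0.9999
ema_param = torch.tensor([1.0, 2.0])
param = torch.tensor([3.0, 4.0])
# Same in-place form as update_ema: ema = decay * ema + (1 - decay) * param
ema_param.mul_(decay).add_(param, alpha=1 - decay)
expected = decay * torch.tensor([1.0, 2.0]) + (1 - decay) * torch.tensor([3.0, 4.0])
assert torch.allclose(ema_param, expected)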
/train_eqvae/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class PLMSSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | if ddim_eta != 0:
26 | raise ValueError('ddim_eta must be 0 for PLMS')
27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29 | alphas_cumprod = self.model.alphas_cumprod
30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32 |
33 | self.register_buffer('betas', to_torch(self.model.betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43 |
44 | # ddim sampling parameters
45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46 | ddim_timesteps=self.ddim_timesteps,
47 | eta=ddim_eta,verbose=verbose)
48 | self.register_buffer('ddim_sigmas', ddim_sigmas)
49 | self.register_buffer('ddim_alphas', ddim_alphas)
50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56 |
57 | @torch.no_grad()
58 | def sample(self,
59 | S,
60 | batch_size,
61 | shape,
62 | conditioning=None,
63 | callback=None,
64 | normals_sequence=None,
65 | img_callback=None,
66 | quantize_x0=False,
67 | eta=0.,
68 | mask=None,
69 | x0=None,
70 | temperature=1.,
71 | noise_dropout=0.,
72 | score_corrector=None,
73 | corrector_kwargs=None,
74 | verbose=True,
75 | x_T=None,
76 | log_every_t=100,
77 | unconditional_guidance_scale=1.,
78 | unconditional_conditioning=None,
79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80 | **kwargs
81 | ):
82 | if conditioning is not None:
83 | if isinstance(conditioning, dict):
84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85 | if cbs != batch_size:
86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87 | else:
88 | if conditioning.shape[0] != batch_size:
89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90 |
91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92 | # sampling
93 | C, H, W = shape
94 | size = (batch_size, C, H, W)
95 | print(f'Data shape for PLMS sampling is {size}')
96 |
97 | samples, intermediates = self.plms_sampling(conditioning, size,
98 | callback=callback,
99 | img_callback=img_callback,
100 | quantize_denoised=quantize_x0,
101 | mask=mask, x0=x0,
102 | ddim_use_original_steps=False,
103 | noise_dropout=noise_dropout,
104 | temperature=temperature,
105 | score_corrector=score_corrector,
106 | corrector_kwargs=corrector_kwargs,
107 | x_T=x_T,
108 | log_every_t=log_every_t,
109 | unconditional_guidance_scale=unconditional_guidance_scale,
110 | unconditional_conditioning=unconditional_conditioning,
111 | )
112 | return samples, intermediates
113 |
114 | @torch.no_grad()
115 | def plms_sampling(self, cond, shape,
116 | x_T=None, ddim_use_original_steps=False,
117 | callback=None, timesteps=None, quantize_denoised=False,
118 | mask=None, x0=None, img_callback=None, log_every_t=100,
119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
121 | device = self.model.betas.device
122 | b = shape[0]
123 | if x_T is None:
124 | img = torch.randn(shape, device=device)
125 | else:
126 | img = x_T
127 |
128 | if timesteps is None:
129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130 | elif timesteps is not None and not ddim_use_original_steps:
131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132 | timesteps = self.ddim_timesteps[:subset_end]
133 |
134 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137 | print(f"Running PLMS Sampling with {total_steps} timesteps")
138 |
139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140 | old_eps = []
141 |
142 | for i, step in enumerate(iterator):
143 | index = total_steps - i - 1
144 | ts = torch.full((b,), step, device=device, dtype=torch.long)
145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146 |
147 | if mask is not None:
148 | assert x0 is not None
149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150 | img = img_orig * mask + (1. - mask) * img
151 |
152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153 | quantize_denoised=quantize_denoised, temperature=temperature,
154 | noise_dropout=noise_dropout, score_corrector=score_corrector,
155 | corrector_kwargs=corrector_kwargs,
156 | unconditional_guidance_scale=unconditional_guidance_scale,
157 | unconditional_conditioning=unconditional_conditioning,
158 | old_eps=old_eps, t_next=ts_next)
159 | img, pred_x0, e_t = outs
160 | old_eps.append(e_t)
161 | if len(old_eps) >= 4:
162 | old_eps.pop(0)
163 | if callback: callback(i)
164 | if img_callback: img_callback(pred_x0, i)
165 |
166 | if index % log_every_t == 0 or index == total_steps - 1:
167 | intermediates['x_inter'].append(img)
168 | intermediates['pred_x0'].append(pred_x0)
169 |
170 | return img, intermediates
171 |
172 | @torch.no_grad()
173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176 | b, *_, device = *x.shape, x.device
177 |
178 | def get_model_output(x, t):
179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180 | e_t = self.model.apply_model(x, t, c)
181 | else:
182 | x_in = torch.cat([x] * 2)
183 | t_in = torch.cat([t] * 2)
184 | c_in = torch.cat([unconditional_conditioning, c])
185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187 |
188 | if score_corrector is not None:
189 | assert self.model.parameterization == "eps"
190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191 |
192 | return e_t
193 |
194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198 |
199 | def get_x_prev_and_pred_x0(e_t, index):
200 | # select parameters corresponding to the currently considered timestep
201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205 |
206 | # current prediction for x_0
207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208 | if quantize_denoised:
209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210 | # direction pointing to x_t
211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213 | if noise_dropout > 0.:
214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216 | return x_prev, pred_x0
217 |
218 | e_t = get_model_output(x, t)
219 | if len(old_eps) == 0:
220 | # Pseudo Improved Euler (2nd order)
221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222 | e_t_next = get_model_output(x_prev, t_next)
223 | e_t_prime = (e_t + e_t_next) / 2
224 | elif len(old_eps) == 1:
225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
227 | elif len(old_eps) == 2:
228 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230 | elif len(old_eps) >= 3:
231 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233 |
234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235 |
236 | return x_prev, pred_x0, e_t
237 |
--------------------------------------------------------------------------------
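Note on plms.py above: PLMS forces ddim_eta == 0, so sigma_t is always zero and get_x_prev_and_pred_x0 reduces to the deterministic DDIM update x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * e_t, applied to the multistep noise estimate e_t_prime (the 2nd/3rd/4th-order weights above are the usual Adams-Bashforth coefficients). A scalar sketch of that update with made-up alpha values:

import torch

a_t, a_prev = torch.tensor(0.5), torch.tensor(0.7)     # arbitrary cumulative alphas
x, e_t = torch.tensor(1.0), torch.tensor(0.2)          # current latent and predicted noise

pred_x0 = (x - (1.0 - a_t).sqrt() * e_t) / a_t.sqrt()  # current prediction for x_0
dir_xt = (1.0 - a_prev).sqrt() * e_t                   # direction pointing to x_t (sigma_t = 0)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt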
/models/dit_models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17 |
18 |
19 | def modulate(x, shift, scale):
20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21 |
22 |
23 | #################################################################################
24 | # Embedding Layers for Timesteps and Class Labels #
25 | #################################################################################
26 |
27 | class TimestepEmbedder(nn.Module):
28 | """
29 | Embeds scalar timesteps into vector representations.
30 | """
31 | def __init__(self, hidden_size, frequency_embedding_size=256):
32 | super().__init__()
33 | self.mlp = nn.Sequential(
34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
35 | nn.SiLU(),
36 | nn.Linear(hidden_size, hidden_size, bias=True),
37 | )
38 | self.frequency_embedding_size = frequency_embedding_size
39 |
40 | @staticmethod
41 | def timestep_embedding(t, dim, max_period=10000):
42 | """
43 | Create sinusoidal timestep embeddings.
44 | :param t: a 1-D Tensor of N indices, one per batch element.
45 | These may be fractional.
46 | :param dim: the dimension of the output.
47 | :param max_period: controls the minimum frequency of the embeddings.
48 | :return: an (N, D) Tensor of positional embeddings.
49 | """
50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51 | half = dim // 2
52 | freqs = torch.exp(
53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
54 | ).to(device=t.device)
55 | args = t[:, None].float() * freqs[None]
56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57 | if dim % 2:
58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59 | return embedding
60 |
61 | def forward(self, t):
62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63 | t_emb = self.mlp(t_freq)
64 | return t_emb
65 |
66 |
67 | class LabelEmbedder(nn.Module):
68 | """
69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
70 | """
71 | def __init__(self, num_classes, hidden_size, dropout_prob):
72 | super().__init__()
73 | use_cfg_embedding = dropout_prob > 0
74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
75 | self.num_classes = num_classes
76 | self.dropout_prob = dropout_prob
77 |
78 | def token_drop(self, labels, force_drop_ids=None):
79 | """
80 | Drops labels to enable classifier-free guidance.
81 | """
82 | if force_drop_ids is None:
83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
84 | else:
85 | drop_ids = force_drop_ids == 1
86 | labels = torch.where(drop_ids, self.num_classes, labels)
87 | return labels
88 |
89 | def forward(self, labels, train, force_drop_ids=None):
90 | use_dropout = self.dropout_prob > 0
91 | if (train and use_dropout) or (force_drop_ids is not None):
92 | labels = self.token_drop(labels, force_drop_ids)
93 | embeddings = self.embedding_table(labels)
94 | return embeddings
95 |
96 |
97 | #################################################################################
98 | # Core DiT Model #
99 | #################################################################################
100 |
101 | class DiTBlock(nn.Module):
102 | """
103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
104 | """
105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
106 | super().__init__()
107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
111 | approx_gelu = lambda: nn.GELU(approximate="tanh")
112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
113 | self.adaLN_modulation = nn.Sequential(
114 | nn.SiLU(),
115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
116 | )
117 |
118 | def forward(self, x, c):
119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
122 | return x
123 |
124 |
125 | class FinalLayer(nn.Module):
126 | """
127 | The final layer of DiT.
128 | """
129 | def __init__(self, hidden_size, patch_size, out_channels):
130 | super().__init__()
131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
133 | self.adaLN_modulation = nn.Sequential(
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
136 | )
137 |
138 | def forward(self, x, c):
139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
140 | x = modulate(self.norm_final(x), shift, scale)
141 | x = self.linear(x)
142 | return x
143 |
144 |
145 | class DiT(nn.Module):
146 | """
147 | Diffusion model with a Transformer backbone.
148 | """
149 | def __init__(
150 | self,
151 | input_size=32,
152 | patch_size=2,
153 | in_channels=4,
154 | hidden_size=1152,
155 | depth=28,
156 | num_heads=16,
157 | mlp_ratio=4.0,
158 | class_dropout_prob=0.1,
159 | num_classes=1000,
160 | learn_sigma=True,
161 | ):
162 | super().__init__()
163 | self.learn_sigma = learn_sigma
164 | self.in_channels = in_channels
165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
166 | self.patch_size = patch_size
167 | self.num_heads = num_heads
168 |
169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
170 | self.t_embedder = TimestepEmbedder(hidden_size)
171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
172 | num_patches = self.x_embedder.num_patches
173 | # Will use fixed sin-cos embedding:
174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
175 |
176 | self.blocks = nn.ModuleList([
177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
178 | ])
179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
180 | self.initialize_weights()
181 |
182 | def initialize_weights(self):
183 | # Initialize transformer layers:
184 | def _basic_init(module):
185 | if isinstance(module, nn.Linear):
186 | torch.nn.init.xavier_uniform_(module.weight)
187 | if module.bias is not None:
188 | nn.init.constant_(module.bias, 0)
189 | self.apply(_basic_init)
190 |
191 | # Initialize (and freeze) pos_embed by sin-cos embedding:
192 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
193 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
194 |
195 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
196 | w = self.x_embedder.proj.weight.data
197 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
198 | nn.init.constant_(self.x_embedder.proj.bias, 0)
199 |
200 | # Initialize label embedding table:
201 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
202 |
203 | # Initialize timestep embedding MLP:
204 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
205 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
206 |
207 | # Zero-out adaLN modulation layers in DiT blocks:
208 | for block in self.blocks:
209 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
210 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
211 |
212 | # Zero-out output layers:
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
214 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
215 | nn.init.constant_(self.final_layer.linear.weight, 0)
216 | nn.init.constant_(self.final_layer.linear.bias, 0)
217 |
218 | def unpatchify(self, x):
219 | """
220 | x: (N, T, patch_size**2 * C)
221 |         imgs: (N, C, H, W)
222 | """
223 | c = self.out_channels
224 | p = self.x_embedder.patch_size[0]
225 | h = w = int(x.shape[1] ** 0.5)
226 | assert h * w == x.shape[1]
227 |
228 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
229 | x = torch.einsum('nhwpqc->nchpwq', x)
230 |         imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
231 | return imgs
232 |
233 | def ckpt_wrapper(self, module):
234 | def ckpt_forward(*inputs):
235 | outputs = module(*inputs)
236 | return outputs
237 | return ckpt_forward
238 |
239 | def forward(self, x, t, y):
240 | """
241 | Forward pass of DiT.
242 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
243 | t: (N,) tensor of diffusion timesteps
244 | y: (N,) tensor of class labels
245 | """
246 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
247 | t = self.t_embedder(t) # (N, D)
248 | y = self.y_embedder(y, self.training) # (N, D)
249 | c = t + y # (N, D)
250 | for block in self.blocks:
251 | x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, c) # (N, T, D)
252 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
253 | x = self.unpatchify(x) # (N, out_channels, H, W)
254 | return x
255 |
256 | def forward_with_cfg(self, x, t, y, cfg_scale):
257 | """
258 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
259 | """
260 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
261 | half = x[: len(x) // 2]
262 | combined = torch.cat([half, half], dim=0)
263 | model_out = self.forward(combined, t, y)
264 | # For exact reproducibility reasons, we apply classifier-free guidance on only
265 | # three channels by default. The standard approach to cfg applies it to all channels.
266 | # This can be done by uncommenting the following line and commenting-out the line following that.
267 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
268 | eps, rest = model_out[:, :3], model_out[:, 3:]
269 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
270 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
271 | eps = torch.cat([half_eps, half_eps], dim=0)
272 | return torch.cat([eps, rest], dim=1)
273 |
274 |
275 | #################################################################################
276 | # Sine/Cosine Positional Embedding Functions #
277 | #################################################################################
278 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
279 |
280 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
281 | """
282 | grid_size: int of the grid height and width
283 | return:
284 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
285 | """
286 | grid_h = np.arange(grid_size, dtype=np.float32)
287 | grid_w = np.arange(grid_size, dtype=np.float32)
288 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
289 | grid = np.stack(grid, axis=0)
290 |
291 | grid = grid.reshape([2, 1, grid_size, grid_size])
292 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
293 | if cls_token and extra_tokens > 0:
294 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
295 | return pos_embed
296 |
297 |
298 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
299 | assert embed_dim % 2 == 0
300 |
301 | # use half of dimensions to encode grid_h
302 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
303 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
304 |
305 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
306 | return emb
307 |
308 |
309 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
310 | """
311 | embed_dim: output dimension for each position
312 | pos: a list of positions to be encoded: size (M,)
313 | out: (M, D)
314 | """
315 | assert embed_dim % 2 == 0
316 | omega = np.arange(embed_dim // 2, dtype=np.float64)
317 | omega /= embed_dim / 2.
318 | omega = 1. / 10000**omega # (D/2,)
319 |
320 | pos = pos.reshape(-1) # (M,)
321 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
322 |
323 | emb_sin = np.sin(out) # (M, D/2)
324 | emb_cos = np.cos(out) # (M, D/2)
325 |
326 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
327 | return emb
328 |
329 |
330 | #################################################################################
331 | # DiT Configs #
332 | #################################################################################
333 |
334 | def DiT_XL_2(**kwargs):
335 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
336 |
337 | def DiT_XL_4(**kwargs):
338 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
339 |
340 | def DiT_XL_8(**kwargs):
341 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
342 |
343 | def DiT_L_2(**kwargs):
344 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
345 |
346 | def DiT_L_4(**kwargs):
347 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
348 |
349 | def DiT_L_8(**kwargs):
350 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
351 |
352 | def DiT_B_2(**kwargs):
353 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
354 |
355 | def DiT_B_4(**kwargs):
356 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
357 |
358 | def DiT_B_8(**kwargs):
359 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
360 |
361 | def DiT_S_2(**kwargs):
362 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
363 |
364 | def DiT_S_4(**kwargs):
365 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
366 |
367 | def DiT_S_8(**kwargs):
368 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
369 |
370 |
371 | DiT_models = {
372 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
373 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
374 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
375 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
376 | }
--------------------------------------------------------------------------------
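Note on dit_models.py above: a short smoke test for the DiT_models registry (run from the repository root so that models/ is importable, as train.py does; timm must be installed). The smallest config keeps it cheap; latent size 32 corresponds to 256x256 images after the 8x VAE downsampling, and learn_sigma=True doubles the output channels:

import torch
from models.dit_models import DiT_models

model = DiT_models["DiT-S/2"](input_size=32, in_channels=4)
x = torch.randn(2, 4, 32, 32)        # (N, C, H, W) latents
t = torch.randint(0, 1000, (2,))     # diffusion timesteps
y = torch.randint(0, 1000, (2,))     # class labels
out = model(x, t, y)
print(out.shape)                     # torch.Size([2, 8, 32, 32])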