├── train_eqvae ├── ldm │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── dataset.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ │ └── dataset.py │ ├── models │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ └── plms.py │ │ ├── __pycache__ │ │ │ ├── autoencoder.cpython-312.pyc │ │ │ └── autoencoder.cpython-38.pyc │ │ └── autoencoder.py │ ├── modules │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── util.cpython-38.pyc │ │ │ │ ├── model.cpython-312.pyc │ │ │ │ ├── model.cpython-38.pyc │ │ │ │ ├── util.cpython-312.pyc │ │ │ │ ├── __init__.cpython-312.pyc │ │ │ │ └── __init__.cpython-38.pyc │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-312.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── distributions.cpython-312.pyc │ │ │ │ └── distributions.cpython-38.pyc │ │ │ └── distributions.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-312.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── contperceptual.cpython-38.pyc │ │ │ │ └── contperceptual.cpython-312.pyc │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ ├── image_degradation │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ ├── __pycache__ │ │ │ │ ├── bsrgan.cpython-38.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── utils_image.cpython-38.pyc │ │ │ │ └── bsrgan_light.cpython-38.pyc │ │ │ └── __init__.py │ │ ├── __pycache__ │ │ │ ├── attention.cpython-38.pyc │ │ │ └── attention.cpython-312.pyc │ │ ├── ema.py │ │ └── attention.py │ ├── __pycache__ │ │ ├── util.cpython-38.pyc │ │ └── util.cpython-312.pyc │ ├── lr_scheduler.py │ └── util.py ├── download_sdvae.sh ├── setup.py ├── environment.yaml ├── README.md └── configs │ └── eqvae_config.yaml ├── media └── teaser.png ├── evaluation ├── __pycache__ │ ├── lpips.cpython-38.pyc │ └── calculate_fid.cpython-38.pyc └── lpips.py ├── models ├── __pycache__ │ └── dit_models.cpython-312.pyc ├── diffusion │ ├── __pycache__ │ │ ├── respace.cpython-312.pyc │ │ ├── __init__.cpython-312.pyc │ │ ├── diffusion_utils.cpython-312.pyc │ │ └── gaussian_diffusion.cpython-312.pyc │ ├── __init__.py │ ├── diffusion_utils.py │ ├── respace.py │ └── timestep_sampler.py └── dit_models.py ├── environment.yml ├── README.md ├── eval.py └── train_gen ├── extract_features.py ├── sample_ddp.py └── train.py /train_eqvae/ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eqvae/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /media/teaser.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/media/teaser.png -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /evaluation/__pycache__/lpips.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/evaluation/__pycache__/lpips.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/dit_models.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/models/__pycache__/dit_models.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/__pycache__/util.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/__pycache__/util.cpython-312.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/calculate_fid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/evaluation/__pycache__/calculate_fid.cpython-38.pyc -------------------------------------------------------------------------------- /models/diffusion/__pycache__/respace.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/models/diffusion/__pycache__/respace.cpython-312.pyc -------------------------------------------------------------------------------- /models/diffusion/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/models/diffusion/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/data/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/data/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /models/diffusion/__pycache__/diffusion_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/models/diffusion/__pycache__/diffusion_utils.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/models/__pycache__/autoencoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/models/__pycache__/autoencoder.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/__pycache__/attention.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/__pycache__/attention.cpython-312.pyc -------------------------------------------------------------------------------- /models/diffusion/__pycache__/gaussian_diffusion.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/models/diffusion/__pycache__/gaussian_diffusion.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/losses/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__pycache__/model.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/diffusionmodules/__pycache__/model.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__pycache__/util.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/diffusionmodules/__pycache__/util.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/distributions/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/distributions/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/image_degradation/__pycache__/bsrgan.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/image_degradation/__pycache__/bsrgan.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/__pycache__/contperceptual.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/losses/__pycache__/contperceptual.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/train_eqvae/ldm/modules/image_degradation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/image_degradation/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/distributions/__pycache__/distributions.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/distributions/__pycache__/distributions.cpython-312.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/image_degradation/__pycache__/utils_image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/image_degradation/__pycache__/utils_image.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/image_degradation/__pycache__/bsrgan_light.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zelaki/eqvae/HEAD/train_eqvae/ldm/modules/image_degradation/__pycache__/bsrgan_light.cpython-38.pyc -------------------------------------------------------------------------------- /train_eqvae/download_sdvae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p pretrained_models 4 | wget -O pretrained_models/kl-f8.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 5 | unzip pretrained_models/kl-f8.zip -d pretrained_models 6 | rm pretrained_models/kl-f8.zip -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /train_eqvae/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: eqvae 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >=3.8 7 | - pytorch >=1.13 8 | - torchvision 9 | - pytorch-cuda=12.1 10 | - pip: 11 | - timm 12 | - diffusers 13 | - accelerate 14 | - pytorch_lightning 15 | - omegaconf 16 | - einops 17 | 
- -e git+https://github.com/CompVis/taming-transformers.git@3ba01b241669f5ade541ce990f7650a3b8f65318#egg=taming_transformers 18 | -------------------------------------------------------------------------------- /train_eqvae/environment.yaml: -------------------------------------------------------------------------------- 1 | name: eqvae_train 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch_lightning==1.4.2 19 | - omegaconf==2.0.0 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - more-itertools>=8.0.0 24 | - transformers==4.3.1 25 | - torch-fidelity==0.3.0 26 | - torchmetrics==0.6.0 27 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 28 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 29 | - -e . 30 | -------------------------------------------------------------------------------- /models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /train_eqvae/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### 1. Environment setup 3 | ```bash 4 | conda env create -f environment.yaml 5 | conda activate eqvae_train 6 | pip install packaging==21.3 7 | pip install 'torchmetrics<0.8' 8 | pip install transformers==4.10.2 9 | pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html 10 | pip install Pillow==9.5.0 11 | ``` 12 | 13 | 14 | ### 2. 
Download SD-VAE 15 | To download the SD-VAE from the official LDM repository, run: 16 | 17 | 18 | ```bash 19 | bash download_sdvae.sh 20 | ``` 21 | 22 | 23 | 24 | ### 3. Dataset 25 | 26 | #### Dataset download 27 | 28 | Currently, we provide experiments for [OpenImages](https://storage.googleapis.com/openimages/web/index.html). After downloading, modify the paths `train_dir`, `val_dir`, and `dataset_name` in the [config file](configs/eqvae_config.yaml). 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | ### 4. Training 39 | 40 | To run EQ-VAE regularization on 8 GPUs: 41 | 42 | ```bash 43 | python main.py \ 44 | --base configs/eqvae_config.yaml \ 45 | -t \ 46 | --gpus 0,1,2,3,4,5,6,7 \ 47 | --resume pretrained_models/model.ckpt \ 48 | --logdir logs/eq-vae 49 | ``` 50 | 51 | 52 | This script automatically creates a folder under `logs/eq-vae` to save logs and checkpoints. 53 | The provided arguments in `configs/eqvae_config.yaml` are the ones used in our paper. You can adjust the following options for your experiments: 54 | 55 | - `anisotropic`: If `True`, anisotropic scaling is applied. 56 | - `uniform_sample_scale`: If `True`, scale factors are sampled uniformly from `[0.25, 1)`; if `False`, they are chosen randomly from `{0.25, 0.5, 0.75}`. 57 | - `p_prior`: Probability of applying prior preservation instead of equivariance regularization. 58 | - `p_prior_s`: Probability of applying prior preservation on lower resolutions instead of equivariance regularization. 59 |
-------------------------------------------------------------------------------- /train_eqvae/configs/eqvae_config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | anisotropic: False 8 | uniform_sample_scale: True 9 | p_prior: 0.5 10 | p_prior_s: 0.25 11 | lossconfig: 12 | target: ldm.modules.losses.LPIPSWithDiscriminator 13 | params: 14 | disc_start: 0 15 | kl_weight: 0.000001 16 | disc_weight: 0.5 17 | 18 | ddconfig: 19 | double_z: True 20 | z_channels: 4 21 | resolution: 256 22 | in_channels: 3 23 | out_ch: 3 24 | ch: 128 25 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 26 | num_res_blocks: 2 27 | attn_resolutions: [ ] 28 | dropout: 0.0 29 | 30 | 31 | data: 32 | target: main.DataModuleFromConfig 33 | params: 34 | batch_size: 10 35 | wrap: True 36 | train: 37 | target: ldm.data.dataset.DatasetTrain 38 | params: 39 | train_dir: "/data/openimages/target_dir/train/" 40 | dataset_name: openimages # Currently we support imagenet/openimages 41 | size: 256 42 | degradation: pil_nearest 43 | validation: 44 | target: ldm.data.dataset.DatasetVal 45 | params: 46 | val_dir: "/data/openimages/target_dir/validation" 47 | dataset_name: openimages # Currently we support imagenet/openimages 48 | size: 256 49 | degradation: pil_nearest 50 | 51 | 52 | 53 | 54 | # data: 55 | # target: main.DataModuleFromConfig 56 | # params: 57 | # batch_size: 10 58 | # wrap: True 59 | # train: 60 | # target: ldm.data.dataset.DatasetTrain 61 | # params: 62 | # train_dir: "/data/imagenet/train" 63 | # dataset_name: imagenet # Currently we support imagenet/openimages 64 | # size: 256 65 | # degradation: pil_nearest 66 | # validation: 67 | # target: ldm.data.dataset.DatasetVal 68 | # params: 69 | # val_dir: "/data/imagenet/val" 70 | # dataset_name: imagenet # Currently we support imagenet/openimages 71 | # size: 256 72 | # degradation: pil_nearest 73 | 74 | 75 | 76 | lightning: 77 |
callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 1000 82 | max_images: 8 83 | increase_log_steps: True 84 | 85 | trainer: 86 | benchmark: True 87 | accumulate_grad_batches: 2 88 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
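    # The value returned below is the closed-form KL between two diagonal Gaussians,
    #   KL = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
    #              + (mean1 - mean2)^2 * exp(-logvar2) - 1),
    # computed elementwise from the means and log-variances.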
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /models/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 
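    Each pixel's probability mass is taken over the interval [x - 1/255, x + 1/255],
    i.e. one uint8 quantization bin after rescaling to [-1, 1].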
71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /train_eqvae/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
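        # cum_cycles[i] is the global step at which cycle i starts; schedule()
        # subtracts it so warm-up and decay are computed within the current cycle.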
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 10 | 11 |
12 | EQ-VAE: Equivariance Regularized Latent Space for Improved Generative Image Modeling

19 | Theodoros Kouzelis 1,3 ·   20 | Ioannis Kakogeorgiou 1 ·   21 | Spyros Gidaris 2 ·   22 | Nikos Komodakis 1,4,5
24 | 1 Archimedes/Athena RC   2 valeo.ai   3 National Technical University of Athens
25 | 4 University of Crete   5 IACM-Forth

34 | ![teaser.png](media/teaser.png)

43 | TL;DR: We propose **EQ-VAE**, a straightforward regularization objective that promotes equivariance in the latent space of pretrained autoencoders under scaling and rotation. This leads to a more structured latent distribution, which accelerates generative model training and improves performance. 44 | 45 |
46 | ### 0. Quick Start with Hugging Face
47 | If you just want to use EQ-VAE to speed up 🚀 the training of your diffusion model, you can use our HuggingFace checkpoints 🤗. 48 | We provide two models, [eq-vae](https://huggingface.co/zelaki/eq-vae) 49 | and [eq-vae-ema](https://huggingface.co/zelaki/eq-vae-ema). 50 |
51 | | Model | Base model | Dataset | Epochs | rFID | PSNR | LPIPS | SSIM | 52 | |---------|-------------|-----------|--------|--------|--------|--------|--------| 53 | | [eq-vae](https://huggingface.co/zelaki/eq-vae) | SD-VAE | OpenImages | 5 | 0.82 | 25.95 | 0.141 | 0.72 | 54 | | [eq-vae-ema](https://huggingface.co/zelaki/eq-vae-ema) | SD-VAE | ImageNet | 44 | 0.55 | 26.15 | 0.133 | 0.72 | 55 | 56 |
57 | ```python 58 | from diffusers import AutoencoderKL 59 | eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae") 60 | ``` 61 |
62 | If you are looking for the weights in the original LDM format, you can find them here: [eq-vae-ldm](https://huggingface.co/zelaki/eq-vae-ldm), [eq-vae-ema-ldm](https://huggingface.co/zelaki/eq-vae-ema-ldm) 63 | 64 |
65 | ### 1. Environment setup 66 |
67 | ```bash 68 | conda env create -f environment.yml 69 | conda activate eqvae 70 | ``` 71 | 72 |
73 | ### 2. Train EQ-VAE
74 | We provide a training script to finetune [SD-VAE](https://ommer-lab.com/files/latent-diffusion/kl-f8.zip) with EQ-VAE regularization. For a detailed guide, see [train_eqvae](./train_eqvae/). 75 | 76 |
77 | ### 3. Evaluate Reconstruction
78 | To evaluate the reconstruction quality of EQ-VAE, compute rFID, LPIPS, SSIM, and PSNR on a validation set (we use the ImageNet validation set in our paper) with the following:
79 | ```bash 80 | torchrun --nproc_per_node=8 eval.py \ 81 | --data_path /path/to/imagenet/validation \ 82 | --output_path results \ 83 | --ckpt_path /path/to/your/ckpt 84 | ``` 85 |
86 | ### 4. Train DiT with EQ-VAE
87 | To train a DiT model with EQ-VAE on ImageNet:
88 | - First extract the latent representations:
89 | ```bash 90 | torchrun --nnodes=1 --nproc_per_node=8 train_gen/extract_features.py \ 91 | --data-path /path/to/imagenet/train \ 92 | --features-path /path/to/latents \ 93 | --vae-ckpt /path/to/eqvae.ckpt \ 94 | --vae-config configs/eqvae_config.yaml 95 | ```
96 | - Then train DiT on the precomputed latents:
97 | ```bash 98 | accelerate launch --mixed_precision fp16 train_gen/train.py \ 99 | --model DiT-XL/2 \ 100 | --feature-path /path/to/latents \ 101 | --results-dir results 102 | ```
103 | - Evaluate generation as follows:
104 | ```bash 105 | torchrun --nnodes=1 --nproc_per_node=8 train_gen/sample_ddp.py \ 106 | --model DiT-XL/2 \ 107 | --num-fid-samples 50000 \ 108 | --ckpt /path/to/dit.ckpt \ 109 | --sample-dir samples \ 110 | --vae-ckpt /path/to/eqvae.ckpt \ 111 | --vae-config configs/eqvae_config.yaml \ 112 | --ddpm True \ 113 | --cfg-scale 1.0 114 | ``` 115 |
116 | This script generates a folder of 50k samples as well as a .npz file that can be used directly with [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute gFID.
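As a quick sanity check of the equivariance property that EQ-VAE promotes, you can compare a latent rescaled in latent space against the latent of the correspondingly rescaled image. The sketch below is illustrative and not part of this repository: the image path, the 0.5 scale factor, bilinear resizing, and the use of the posterior mode are all arbitrary assumptions.

```python
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from torchvision.io import read_image

# EQ-VAE weights in diffusers format (see Section 0).
vae = AutoencoderKL.from_pretrained("zelaki/eq-vae").eval()

# Placeholder path; assumes a 3-channel image whose side length is a multiple
# of 16 (e.g. 256x256), rescaled from uint8 to [-1, 1].
x = read_image("example.png").float().unsqueeze(0) / 127.5 - 1.0

with torch.no_grad():
    # Latent of the full-resolution image, rescaled in latent space ...
    z = vae.encode(x).latent_dist.mode()
    z_scaled = F.interpolate(z, scale_factor=0.5, mode="bilinear")
    # ... versus the latent of the image rescaled in pixel space.
    x_scaled = F.interpolate(x, scale_factor=0.5, mode="bilinear")
    z_of_scaled = vae.encode(x_scaled).latent_dist.mode()

# With EQ-VAE the two latents should stay close; with the vanilla SD-VAE they drift apart.
print(torch.mean((z_scaled - z_of_scaled) ** 2).item())
```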
117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | ### Acknowledgement 126 | 127 | This code is mainly built upon [LDM](https://github.com/CompVis/latent-diffusion) and [fastDiT](https://github.com/chuanyangjin/fast-DiT). 128 | 129 | 130 | ### Citation 131 | 132 | ```bibtex 133 | @inproceedings{ 134 | kouzelis2025eqvae, 135 | title={{EQ}-{VAE}: Equivariance Regularized Latent Space for Improved Generative Image Modeling}, 136 | author={Theodoros Kouzelis and Ioannis Kakogeorgiou and Spyros Gidaris and Nikos Komodakis}, 137 | booktitle={Forty-second International Conference on Machine Learning}, 138 | year={2025}, 139 | url={https://openreview.net/forum?id=UWhW5YYLo6} 140 | } 141 | ``` 142 | `` 143 | -------------------------------------------------------------------------------- /models/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
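    For example, space_timesteps(1000, "ddim50") returns the 50 evenly spaced
    steps {0, 20, 40, ..., 980}.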
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 
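            # This loss module is invoked once per optimizer: with optimizer_idx == 0 it
            # returns the autoencoder (generator) objective, i.e. the reconstruction/LPIPS
            # term plus the KL term and the adaptively weighted adversarial term, while
            # optimizer_idx == 1 below returns the loss that trains the discriminator.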
64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /models/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 
31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 
117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /train_eqvae/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /evaluation/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hashlib 3 | import requests 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | import os 8 | from tqdm import tqdm 9 | 10 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 11 | 12 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 13 | 14 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 15 | 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class LPIPS(nn.Module): 48 | # Learned perceptual metric 49 | def __init__(self, use_dropout=True): 50 | 
super().__init__() 51 | self.scaling_layer = ScalingLayer() 52 | self.chns = [64, 128, 256, 512, 512] # vgg16 features 53 | self.net = vgg16(pretrained=True, requires_grad=False) 54 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 55 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 56 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 57 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 58 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 59 | self.load_from_pretrained() 60 | for param in self.parameters(): 61 | param.requires_grad = False 62 | 63 | def load_from_pretrained(self, name="vgg_lpips"): 64 | ckpt = get_ckpt_path(name, "movqgan/modules/losses/lpips") 65 | self.load_state_dict( 66 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 67 | ) 68 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 69 | 70 | @classmethod 71 | def from_pretrained(cls, name="vgg_lpips"): 72 | if name != "vgg_lpips": 73 | raise NotImplementedError 74 | model = cls() 75 | ckpt = get_ckpt_path(name) 76 | model.load_state_dict( 77 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 78 | ) 79 | return model 80 | 81 | def forward(self, input, target): 82 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 83 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 84 | feats0, feats1, diffs = {}, {}, {} 85 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 86 | for kk in range(len(self.chns)): 87 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 88 | outs1[kk] 89 | ) 90 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 91 | 92 | res = [ 93 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 94 | for kk in range(len(self.chns)) 95 | ] 96 | val = res[0] 97 | for l in range(1, len(self.chns)): 98 | val += res[l] 99 | return val 100 | 101 | 102 | class ScalingLayer(nn.Module): 103 | def __init__(self): 104 | super(ScalingLayer, self).__init__() 105 | self.register_buffer( 106 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 107 | ) 108 | self.register_buffer( 109 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 110 | ) 111 | 112 | def forward(self, inp): 113 | # convert imagenet normalized data to [-1, 1] 114 | return (inp - self.shift) / self.scale 115 | 116 | 117 | class NetLinLayer(nn.Module): 118 | """A single linear layer which does a 1x1 conv""" 119 | 120 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 121 | super(NetLinLayer, self).__init__() 122 | layers = ( 123 | [ 124 | nn.Dropout(), 125 | ] 126 | if (use_dropout) 127 | else [] 128 | ) 129 | layers += [ 130 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 131 | ] 132 | self.model = nn.Sequential(*layers) 133 | 134 | 135 | class vgg16(torch.nn.Module): 136 | def __init__(self, requires_grad=False, pretrained=True): 137 | super(vgg16, self).__init__() 138 | vgg_pretrained_features = models.vgg16(pretrained=pretrained) 139 | vgg_pretrained_features = vgg_pretrained_features.features 140 | self.slice1 = torch.nn.Sequential() 141 | self.slice2 = torch.nn.Sequential() 142 | self.slice3 = torch.nn.Sequential() 143 | self.slice4 = torch.nn.Sequential() 144 | self.slice5 = torch.nn.Sequential() 145 | self.N_slices = 5 146 | for x in range(4): 147 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 148 | for x in range(4, 9): 149 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 150 | for x in 
range(9, 16): 151 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 152 | for x in range(16, 23): 153 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 154 | for x in range(23, 30): 155 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 156 | if not requires_grad: 157 | for param in self.parameters(): 158 | param.requires_grad = False 159 | 160 | def forward(self, X): 161 | h = self.slice1(X) 162 | h_relu1_2 = h 163 | h = self.slice2(h) 164 | h_relu2_2 = h 165 | h = self.slice3(h) 166 | h_relu3_3 = h 167 | h = self.slice4(h) 168 | h_relu4_3 = h 169 | h = self.slice5(h) 170 | h_relu5_3 = h 171 | vgg_outputs = namedtuple( 172 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 173 | ) 174 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 175 | return out 176 | 177 | 178 | def normalize_tensor(x, eps=1e-10): 179 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 180 | return x / (norm_factor + eps) 181 | 182 | 183 | def spatial_average(x, keepdim=True): 184 | return x.mean([2, 3], keepdim=keepdim) 185 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | 138 | class FrozenCLIPTextEmbedder(nn.Module): 139 | """ 140 | Uses the CLIP transformer encoder for text. 
141 | """ 142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 143 | super().__init__() 144 | self.model, _ = clip.load(version, jit=False, device="cpu") 145 | self.device = device 146 | self.max_length = max_length 147 | self.n_repeat = n_repeat 148 | self.normalize = normalize 149 | 150 | def freeze(self): 151 | self.model = self.model.eval() 152 | for param in self.parameters(): 153 | param.requires_grad = False 154 | 155 | def forward(self, text): 156 | tokens = clip.tokenize(text).to(self.device) 157 | z = self.model.encode_text(tokens) 158 | if self.normalize: 159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 160 | return z 161 | 162 | def encode(self, text): 163 | z = self(text) 164 | if z.ndim==2: 165 | z = z[:, None, :] 166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 167 | return z 168 | 169 | 170 | class FrozenClipImageEmbedder(nn.Module): 171 | """ 172 | Uses the CLIP image encoder. 173 | """ 174 | def __init__( 175 | self, 176 | model, 177 | jit=False, 178 | device='cuda' if torch.cuda.is_available() else 'cpu', 179 | antialias=False, 180 | ): 181 | super().__init__() 182 | self.model, _ = clip.load(name=model, device=device, jit=jit) 183 | 184 | self.antialias = antialias 185 | 186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 188 | 189 | def preprocess(self, x): 190 | # normalize to [0,1] 191 | x = kornia.geometry.resize(x, (224, 224), 192 | interpolation='bicubic',align_corners=True, 193 | antialias=self.antialias) 194 | x = (x + 1.) / 2. 195 | # renormalize according to clip 196 | x = kornia.enhance.normalize(x, self.mean, self.std) 197 | return x 198 | 199 | def forward(self, x): 200 | # x is assumed to be in range [-1,1] 201 | return self.model.encode_image(self.preprocess(x)) 202 | 203 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
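# Note (descriptive, matches the code above and below): the reconstruction term here is the
# pixel loss plus perceptual_weight * LPIPS, reduced to a scalar with a plain mean over all
# elements, unlike the commented-out per-batch sum variant above.
# In the generator branch that follows, calculate_adaptive_weight balances the adversarial
# term against this reconstruction term:
#   d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4)
# with both gradients taken at the decoder's last layer, clamped to [0, 1e4] and scaled by
# discriminator_weight, so the GAN gradient cannot dominate the reconstruction gradient.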
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /train_eqvae/ldm/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os, yaml, pickle, shutil, tarfile, glob 2 | import cv2 3 | import albumentations 4 | import PIL 5 | import numpy as np 6 | import torchvision.transforms.functional as TF 7 | from omegaconf import OmegaConf 8 | from functools import partial 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset, Subset 12 | 13 | import taming.data.utils as tdu 14 | 15 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light 16 | 17 | 18 | 19 | class DatasetTrain(Dataset): 20 | def __init__(self, 21 | train_dir = "/data/openimages/target_dir/train/", 22 | size=None, dataset_name="openimages", 23 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., 24 | random_crop=True): 25 | 26 | 27 | 28 | if 
dataset_name == "openimages": 29 | rel_paths = [l for l in os.listdir(train_dir)] 30 | else: 31 | raise NotImplementedError 32 | 33 | 34 | self.labels = { 35 | "relative_file_path_": rel_paths, 36 | "file_path_": [os.path.join(train_dir, p) for p in rel_paths], 37 | } 38 | self.length = len(rel_paths) 39 | 40 | 41 | assert size 42 | assert (size / downscale_f).is_integer() 43 | self.size = size 44 | self.LR_size = int(size / downscale_f) 45 | self.min_crop_f = min_crop_f 46 | self.max_crop_f = max_crop_f 47 | assert(max_crop_f <= 1.) 48 | self.center_crop = not random_crop 49 | 50 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 51 | 52 | self.pil_interpolation = False # gets reset later if incase interp_op is from pillow 53 | 54 | if degradation == "bsrgan": 55 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) 56 | 57 | elif degradation == "bsrgan_light": 58 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) 59 | 60 | else: 61 | interpolation_fn = { 62 | "cv_nearest": cv2.INTER_NEAREST, 63 | "cv_bilinear": cv2.INTER_LINEAR, 64 | "cv_bicubic": cv2.INTER_CUBIC, 65 | "cv_area": cv2.INTER_AREA, 66 | "cv_lanczos": cv2.INTER_LANCZOS4, 67 | "pil_nearest": PIL.Image.NEAREST, 68 | "pil_bilinear": PIL.Image.BILINEAR, 69 | "pil_bicubic": PIL.Image.BICUBIC, 70 | "pil_box": PIL.Image.BOX, 71 | "pil_hamming": PIL.Image.HAMMING, 72 | "pil_lanczos": PIL.Image.LANCZOS, 73 | }[degradation] 74 | 75 | self.pil_interpolation = degradation.startswith("pil_") 76 | 77 | if self.pil_interpolation: 78 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) 79 | 80 | else: 81 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, 82 | interpolation=interpolation_fn) 83 | 84 | def __len__(self): 85 | return self.length 86 | 87 | 88 | def __getitem__(self, i): 89 | example = dict((k, self.labels[k][i]) for k in self.labels) 90 | image = Image.open(example["file_path_"]) 91 | 92 | if not image.mode == "RGB": 93 | image = image.convert("RGB") 94 | 95 | image = np.array(image).astype(np.uint8) 96 | 97 | min_side_len = min(image.shape[:2]) 98 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) 99 | crop_side_len = int(crop_side_len) 100 | 101 | if self.center_crop: 102 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) 103 | 104 | else: 105 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) 106 | 107 | image = self.cropper(image=image)["image"] 108 | image = self.image_rescaler(image=image)["image"] 109 | 110 | if self.pil_interpolation: 111 | image_pil = PIL.Image.fromarray(image) 112 | LR_image = self.degradation_process(image_pil) 113 | LR_image = np.array(LR_image).astype(np.uint8) 114 | 115 | else: 116 | LR_image = self.degradation_process(image=image)["image"] 117 | 118 | example["image"] = (image/127.5 - 1.0).astype(np.float32) 119 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) 120 | 121 | return example 122 | 123 | 124 | 125 | 126 | class DatasetVal(Dataset): 127 | def __init__(self, 128 | val_dir="/data/openimages/target_dir/validation/", 129 | dataset_name='openimages', 130 | size=None, 131 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., 132 | random_crop=True): 133 | 134 | if dataset_name == "openimages": 135 | rel_paths = [l for l in os.listdir(val_dir)] 136 | else: 137 | raise 
NotImplementedError 138 | 139 | self.labels = { 140 | "relative_file_path_": rel_paths, 141 | "file_path_": [os.path.join(val_dir, p) for p in rel_paths], 142 | } 143 | self.length = len(rel_paths) 144 | 145 | 146 | 147 | assert size 148 | assert (size / downscale_f).is_integer() 149 | self.size = size 150 | self.LR_size = int(size / downscale_f) 151 | self.min_crop_f = min_crop_f 152 | self.max_crop_f = max_crop_f 153 | assert(max_crop_f <= 1.) 154 | self.center_crop = not random_crop 155 | 156 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 157 | 158 | self.pil_interpolation = False # gets reset later if incase interp_op is from pillow 159 | 160 | if degradation == "bsrgan": 161 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) 162 | 163 | elif degradation == "bsrgan_light": 164 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) 165 | 166 | else: 167 | interpolation_fn = { 168 | "cv_nearest": cv2.INTER_NEAREST, 169 | "cv_bilinear": cv2.INTER_LINEAR, 170 | "cv_bicubic": cv2.INTER_CUBIC, 171 | "cv_area": cv2.INTER_AREA, 172 | "cv_lanczos": cv2.INTER_LANCZOS4, 173 | "pil_nearest": PIL.Image.NEAREST, 174 | "pil_bilinear": PIL.Image.BILINEAR, 175 | "pil_bicubic": PIL.Image.BICUBIC, 176 | "pil_box": PIL.Image.BOX, 177 | "pil_hamming": PIL.Image.HAMMING, 178 | "pil_lanczos": PIL.Image.LANCZOS, 179 | }[degradation] 180 | 181 | self.pil_interpolation = degradation.startswith("pil_") 182 | 183 | if self.pil_interpolation: 184 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) 185 | 186 | else: 187 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, 188 | interpolation=interpolation_fn) 189 | 190 | def __len__(self): 191 | return self.length 192 | 193 | def __getitem__(self, i): 194 | example = dict((k, self.labels[k][i]) for k in self.labels) 195 | image = Image.open(example["file_path_"]) 196 | 197 | if not image.mode == "RGB": 198 | image = image.convert("RGB") 199 | 200 | image = np.array(image).astype(np.uint8) 201 | 202 | min_side_len = min(image.shape[:2]) 203 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) 204 | crop_side_len = int(crop_side_len) 205 | 206 | if self.center_crop: 207 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) 208 | 209 | else: 210 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) 211 | 212 | image = self.cropper(image=image)["image"] 213 | image = self.image_rescaler(image=image)["image"] 214 | 215 | if self.pil_interpolation: 216 | image_pil = PIL.Image.fromarray(image) 217 | LR_image = self.degradation_process(image_pil) 218 | LR_image = np.array(LR_image).astype(np.uint8) 219 | 220 | else: 221 | LR_image = self.degradation_process(image=image)["image"] 222 | 223 | example["image"] = (image/127.5 - 1.0).astype(np.float32) 224 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) 225 | 226 | return example 227 | 228 | 229 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/hustvl/LightningDiT/blob/main/evaluate_tokenizer.py 2 | 3 | import os 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import torch.distributed as dist 9 | from torch.nn.parallel 
import DistributedDataParallel as DDP 10 | from omegaconf import OmegaConf 11 | from torch.utils.data import DataLoader, DistributedSampler 12 | from evaluation.calculate_fid import calculate_fid_given_paths 13 | from concurrent.futures import ThreadPoolExecutor, as_completed 14 | from torchmetrics import StructuralSimilarityIndexMeasure 15 | from evaluation.lpips import LPIPS 16 | from torchvision.datasets import ImageFolder 17 | from torchvision import transforms 18 | import csv 19 | import sys 20 | from ldm.models.autoencoder import AutoencoderKL 21 | from ldm.util import instantiate_from_config 22 | import yaml 23 | from omegaconf import OmegaConf 24 | 25 | def load_config(config_path, display=False): 26 | config = OmegaConf.load(config_path) 27 | if display: 28 | print(yaml.dump(OmegaConf.to_container(config))) 29 | return config 30 | 31 | 32 | def load_kl(config, type="sd", ckpt_path=None): 33 | model = AutoencoderKL(**config.model.params) 34 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 35 | missing, unexpected = model.load_state_dict(sd, strict=False) 36 | return model.eval() 37 | 38 | 39 | def print_with_prefix(content, prefix='Tokenizer Evaluation', rank=0): 40 | if rank == 0: 41 | print(f"\033[34m[{prefix}]\033[0m {content}") 42 | 43 | def save_image(image, filename): 44 | Image.fromarray(image).save(filename) 45 | 46 | 47 | 48 | def evaluate_tokenizer(config_path, model_name, data_path, output_path, ckpt_path): 49 | # Initialize distributed training 50 | dist.init_process_group(backend='nccl') 51 | local_rank = torch.distributed.get_rank() 52 | torch.cuda.set_device(local_rank) 53 | device = torch.device(f'cuda:{local_rank}') 54 | 55 | 56 | vae_config = load_config(config_path, display=False) 57 | model = load_kl(vae_config, ckpt_path=ckpt_path).to(device) 58 | 59 | # Image preprocessing 60 | transform = transforms.Compose([ 61 | transforms.ToTensor(), 62 | transforms.Resize(256), 63 | transforms.CenterCrop(256), 64 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 65 | ]) 66 | 67 | # Create dataset and dataloader 68 | dataset = ImageFolder(root=data_path, transform=transform) 69 | distributed_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=local_rank) 70 | val_dataloader = DataLoader( 71 | dataset, 72 | batch_size=1, 73 | shuffle=False, 74 | num_workers=4, 75 | sampler=distributed_sampler 76 | ) 77 | 78 | folder_name = model_name 79 | 80 | save_dir = os.path.join(output_path, folder_name, 'decoded_images') 81 | ref_path = os.path.join(output_path, folder_name, 'ref_images') 82 | metric_path = os.path.join(output_path, folder_name, 'metrics.csv') 83 | 84 | os.makedirs(save_dir, exist_ok=True) 85 | os.makedirs(ref_path, exist_ok=True) 86 | 87 | if local_rank == 0: 88 | print_with_prefix(f"Output dir: {save_dir}") 89 | print_with_prefix(f"Reference dir: {ref_path}") 90 | 91 | # Save reference images if needed 92 | ref_png_files = [f for f in os.listdir(ref_path) if f.endswith('.png')] 93 | if len(ref_png_files) < 50000: 94 | total_samples = 0 95 | for batch in val_dataloader: 96 | images = batch[0].to(device) 97 | for j in range(images.size(0)): 98 | img = torch.clamp(127.5 * images[j] + 128.0, 0, 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8) 99 | Image.fromarray(img).save(os.path.join(ref_path, f"ref_image_rank_{local_rank}_{total_samples}.png")) 100 | total_samples += 1 101 | if total_samples % 100 == 0 and local_rank == 0: 102 | print_with_prefix(f"Rank {local_rank}, Saved {total_samples} reference images") 
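# Note (descriptive): reference PNGs are only (re)written when fewer than 50,000 already
# exist, i.e. one per image of the ImageNet-1k validation split that the default
# --data_path assumes; inputs normalized to [-1, 1] are mapped back to uint8 via
# clamp(127.5 * x + 128, 0, 255). The barrier below makes every rank finish writing its
# reference images before the reconstruction pass starts.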
103 | dist.barrier() 104 | 105 | # Initialize metrics 106 | lpips_values = [] 107 | ssim_values = [] 108 | lpips = LPIPS().to(device).eval() 109 | ssim_metric = StructuralSimilarityIndexMeasure().to(device) 110 | 111 | # Generate reconstructions and compute metrics 112 | if local_rank == 0: 113 | print_with_prefix("Generating reconstructions...") 114 | all_indices = 0 115 | 116 | for batch in val_dataloader: 117 | images = batch[0].to(device) 118 | with torch.no_grad(): 119 | latents = model.encode(images).sample().to(torch.float32) 120 | decoded_images_tensor = model.decode(latents) 121 | 122 | decoded_images = torch.clamp(127.5 * decoded_images_tensor + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 123 | 124 | # Compute metrics 125 | lpips_values.append(lpips(decoded_images_tensor, images).mean()) 126 | ssim_values.append(ssim_metric(decoded_images_tensor, images)) 127 | 128 | # Save reconstructions 129 | for i, img in enumerate(decoded_images): 130 | save_image(img, os.path.join(save_dir, f"decoded_image_rank_{local_rank}_{all_indices + i}.png")) 131 | if (all_indices + i) % 100 == 0 and local_rank == 0: 132 | print_with_prefix(f"Rank {local_rank}, Processed {all_indices + i} images") 133 | all_indices += len(decoded_images) 134 | dist.barrier() 135 | 136 | # Aggregate metrics across GPUs 137 | lpips_values = torch.tensor(lpips_values).to(device) 138 | ssim_values = torch.tensor(ssim_values).to(device) 139 | dist.all_reduce(lpips_values, op=dist.ReduceOp.AVG) 140 | dist.all_reduce(ssim_values, op=dist.ReduceOp.AVG) 141 | 142 | avg_lpips = lpips_values.mean().item() 143 | avg_ssim = ssim_values.mean().item() 144 | 145 | if local_rank == 0: 146 | # Calculate FID 147 | print_with_prefix("Computing rFID...") 148 | fid = calculate_fid_given_paths([ref_path, save_dir], batch_size=50, dims=2048, device=device, num_workers=16) 149 | 150 | # Calculate PSNR 151 | print_with_prefix("Computing PSNR...") 152 | psnr_values = calculate_psnr_between_folders(ref_path, save_dir) 153 | avg_psnr = sum(psnr_values) / len(psnr_values) 154 | with open(metric_path, mode="w", newline="") as file: 155 | writer = csv.writer(file) 156 | writer.writerow(["FID", f"{fid:.3f}"]) 157 | writer.writerow(["PSNR", f"{avg_psnr:.3f}"]) 158 | writer.writerow(["LPIPS", f"{avg_lpips:.3f}"]) 159 | writer.writerow(["SSIM", f"{avg_ssim:.3f}"]) 160 | 161 | dist.destroy_process_group() 162 | 163 | 164 | def decode_to_images(model, z): 165 | with torch.no_grad(): 166 | images = model.decode(z) 167 | images = torch.clamp(127.5 * images + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 168 | return images 169 | 170 | def calculate_psnr(original, processed): 171 | mse = torch.mean((original - processed) ** 2) 172 | return 20 * torch.log10(255.0 / torch.sqrt(mse)).item() 173 | 174 | def load_image(image_path): 175 | image = Image.open(image_path).convert('RGB') 176 | return torch.tensor(np.array(image).transpose(2, 0, 1), dtype=torch.float32) 177 | 178 | def calculate_psnr_for_pair(original_path, processed_path): 179 | return calculate_psnr(load_image(original_path), load_image(processed_path)) 180 | 181 | def calculate_psnr_between_folders(original_folder, processed_folder): 182 | original_files = sorted(os.listdir(original_folder)) 183 | processed_files = sorted(os.listdir(processed_folder)) 184 | 185 | if len(original_files) != len(processed_files): 186 | print("Warning: Mismatched number of images in folders") 187 | return [] 188 | 189 | with ThreadPoolExecutor() as executor: 190 | 
futures = [ 191 | executor.submit(calculate_psnr_for_pair, 192 | os.path.join(original_folder, orig), 193 | os.path.join(processed_folder, proc)) 194 | for orig, proc in zip(original_files, processed_files) 195 | ] 196 | return [future.result() for future in as_completed(futures)] 197 | 198 | if __name__ == "__main__": 199 | import argparse 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument('--config_path', type=str, default='configs/eqvae_config.yaml') 202 | parser.add_argument('--model_name', type=str, default='eq_vae') 203 | parser.add_argument('--ckpt_path', type=str) 204 | parser.add_argument('--data_path', type=str, default='/path/to/your/imagenet/ILSVRC2012_validation/data') 205 | parser.add_argument('--output_path', type=str, default='/path/to/your/output') 206 | parser.add_argument('--seed', type=int, default=42) 207 | args = parser.parse_args() 208 | evaluate_tokenizer(config_path=args.config_path, model_name=args.model_name, data_path=args.data_path, output_path=args.output_path, ckpt_path=args.ckpt_path) 209 | -------------------------------------------------------------------------------- /train_gen/extract_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/chuanyangjin/fast-DiT 3 | """ 4 | import sys 5 | import os 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '../train_eqvae')) 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 8 | 9 | import torch 10 | torch.backends.cuda.matmul.allow_tf32 = True 11 | torch.backends.cudnn.allow_tf32 = True 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data import DataLoader 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torchvision.datasets import ImageFolder 17 | from torchvision import transforms 18 | import numpy as np 19 | from collections import OrderedDict 20 | from PIL import Image 21 | from copy import deepcopy 22 | from glob import glob 23 | from time import time 24 | import argparse 25 | import logging 26 | from models.dit_models import DiT_models 27 | from models.diffusion import create_diffusion 28 | from diffusers.models import AutoencoderKL 29 | from tqdm import tqdm 30 | 31 | from train_eqvae.ldm.models.autoencoder import AutoencoderKL as LDMAutoencoderKL 32 | from ldm.util import instantiate_from_config 33 | import yaml 34 | from omegaconf import OmegaConf 35 | 36 | def load_config(config_path, display=False): 37 | config = OmegaConf.load(config_path) 38 | if display: 39 | print(yaml.dump(OmegaConf.to_container(config))) 40 | return config 41 | 42 | 43 | def load_kl(config, ckpt_path=None): 44 | 45 | model = LDMAutoencoderKL(**config.model.params) 46 | 47 | if ckpt_path is not None: 48 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 49 | missing, unexpected = model.load_state_dict(sd, strict=False) 50 | return model.eval() 51 | 52 | def preprocess_kl(x): 53 | x = 2.*x - 1. 54 | return x 55 | 56 | 57 | 58 | @torch.no_grad() 59 | def update_ema(ema_model, model, decay=0.9999): 60 | """ 61 | Step the EMA model towards the current model. 
62 | """ 63 | ema_params = OrderedDict(ema_model.named_parameters()) 64 | model_params = OrderedDict(model.named_parameters()) 65 | 66 | for name, param in model_params.items(): 67 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 68 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 69 | 70 | 71 | def requires_grad(model, flag=True): 72 | """ 73 | Set requires_grad flag for all parameters in a model. 74 | """ 75 | for p in model.parameters(): 76 | p.requires_grad = flag 77 | 78 | 79 | def cleanup(): 80 | """ 81 | End DDP training. 82 | """ 83 | dist.destroy_process_group() 84 | 85 | 86 | def create_logger(logging_dir): 87 | """ 88 | Create a logger that writes to a log file and stdout. 89 | """ 90 | if dist.get_rank() == 0: # real logger 91 | logging.basicConfig( 92 | level=logging.INFO, 93 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 94 | datefmt='%Y-%m-%d %H:%M:%S', 95 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 96 | ) 97 | logger = logging.getLogger(__name__) 98 | else: # dummy logger (does nothing) 99 | logger = logging.getLogger(__name__) 100 | logger.addHandler(logging.NullHandler()) 101 | return logger 102 | 103 | 104 | def center_crop_arr(pil_image, image_size): 105 | """ 106 | Center cropping implementation from ADM. 107 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 108 | """ 109 | while min(*pil_image.size) >= 2 * image_size: 110 | pil_image = pil_image.resize( 111 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 112 | ) 113 | 114 | scale = image_size / min(*pil_image.size) 115 | pil_image = pil_image.resize( 116 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 117 | ) 118 | 119 | arr = np.array(pil_image) 120 | crop_y = (arr.shape[0] - image_size) // 2 121 | crop_x = (arr.shape[1] - image_size) // 2 122 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 123 | 124 | 125 | ################################################################################# 126 | # Training Loop # 127 | ################################################################################# 128 | 129 | def main(args): 130 | """ 131 | Trains a new DiT model. 132 | """ 133 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 134 | 135 | # Setup DDP: 136 | dist.init_process_group("nccl") 137 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 138 | rank = dist.get_rank() 139 | device = rank % torch.cuda.device_count() 140 | seed = args.global_seed * dist.get_world_size() + rank 141 | torch.manual_seed(seed) 142 | torch.cuda.set_device(device) 143 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 144 | 145 | # Setup a feature folder: 146 | if rank == 0: 147 | os.makedirs(args.features_path, exist_ok=True) 148 | os.makedirs(os.path.join(args.features_path, 'imagenet256_features'), exist_ok=True) 149 | os.makedirs(os.path.join(args.features_path, 'imagenet256_labels'), exist_ok=True) 150 | 151 | # Create model: 152 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 
153 | latent_size = args.image_size // 8 154 | 155 | if args.vae_ckpt is not None: 156 | vae_config = load_config(args.vae_config, display=False) 157 | vae = load_kl(vae_config, ckpt_path=args.vae_ckpt).to(device) 158 | else: 159 | from diffusers.models import AutoencoderKL 160 | vae = AutoencoderKL.from_pretrained(args.hf_model_name).to(device) 161 | 162 | 163 | # Setup data: 164 | transform = transforms.Compose([ 165 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 166 | transforms.RandomHorizontalFlip(), 167 | transforms.ToTensor(), 168 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 169 | ]) 170 | dataset = ImageFolder(args.data_path, transform=transform) 171 | sampler = DistributedSampler( 172 | dataset, 173 | num_replicas=dist.get_world_size(), 174 | rank=rank, 175 | shuffle=False, 176 | seed=args.global_seed 177 | ) 178 | loader = DataLoader( 179 | dataset, 180 | batch_size = 1, 181 | shuffle=False, 182 | sampler=sampler, 183 | num_workers=args.num_workers, 184 | pin_memory=True, 185 | drop_last=True 186 | ) 187 | 188 | train_steps = 0 189 | for x, y in tqdm(loader): 190 | x = x.to(device) 191 | y = y.to(device) 192 | with torch.no_grad(): 193 | 194 | if args.vae_ckpt is not None: 195 | 196 | x = vae.encode(x).sample().mul_(args.vae_scaling_factor) 197 | 198 | else: 199 | x = vae.encode(x).latent_dist.sample().mul_(args.vae_scaling_factor) 200 | 201 | x = x.detach().cpu().numpy() # (1, 4, 32, 32) 202 | np.save(f'{args.features_path}/imagenet256_features/{rank}-{train_steps}.npy', x) 203 | 204 | y = y.detach().cpu().numpy() # (1,) 205 | np.save(f'{args.features_path}/imagenet256_labels/{rank}-{train_steps}.npy', y) 206 | 207 | train_steps += 1 208 | 209 | if __name__ == "__main__": 210 | # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters). 
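# Example launch using the flags defined below (paths are placeholders and the world size
# of 8 is arbitrary):
#   torchrun --nproc_per_node=8 extract_features.py \
#     --data-path /path/to/imagenet/train \
#     --features-path features \
#     --vae-ckpt /path/to/eqvae_checkpoint.ckpt \
#     --vae-config /path/to/eqvae_config.yaml
# If --vae-ckpt is omitted, the script falls back to the diffusers AutoencoderKL named by
# --hf-model-name (default: zelaki/eq-vae).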
211 | parser = argparse.ArgumentParser() 212 | parser.add_argument("--data-path", type=str, required=True) 213 | parser.add_argument("--features-path", type=str, default="features") 214 | parser.add_argument("--results-dir", type=str, default="results") 215 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") 216 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 217 | parser.add_argument("--num-classes", type=int, default=1000) 218 | parser.add_argument("--epochs", type=int, default=1400) 219 | parser.add_argument("--global-batch-size", type=int, default=256) 220 | parser.add_argument("--global-seed", type=int, default=0) 221 | parser.add_argument("--vae", type=str, choices=["ema", "mse", "xl", "sd3","ours"], default="ema") 222 | parser.add_argument("--vae-ckpt", type=str, default=None) 223 | parser.add_argument("--vae-config", type=str, default=None) 224 | parser.add_argument("--hf-model-name", type=str, default="zelaki/eq-vae") 225 | parser.add_argument("--vae-scaling-factor", type=float, default=0.18215) 226 | parser.add_argument("--num-workers", type=int, default=4) 227 | parser.add_argument("--log-every", type=int, default=100) 228 | parser.add_argument("--ckpt-every", type=int, default=50_000) 229 | args = parser.parse_args() 230 | main(args) 231 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /train_eqvae/ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 7 | 8 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 9 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 10 | 11 | from ldm.util import instantiate_from_config 12 | import random 13 | import json 14 | import os 15 | import copy 16 | 17 | 18 | 19 | def 
flip_or_rotate_image(inputs, flip): 20 | if flip == "h": 21 | inputs = torch.flip(inputs, [-1]) 22 | 23 | elif flip == "v": 24 | inputs = torch.flip(inputs, [-2]) 25 | 26 | elif flip == "vh": 27 | inputs = torch.flip(inputs, [-1,-2]) 28 | 29 | elif flip == "90": 30 | inputs = torch.rot90(inputs, k=1, dims=[-1, -2]) 31 | 32 | else: 33 | inputs = torch.rot90(inputs, k=3, dims=[-1, -2]) 34 | 35 | return inputs 36 | 37 | 38 | class AutoencoderKL(pl.LightningModule): 39 | def __init__(self, 40 | ddconfig, 41 | lossconfig, 42 | embed_dim, 43 | anisotropic=False, 44 | uniform_sample_scale=True, 45 | ckpt_path=None, 46 | p_prior=0.5, 47 | p_prior_s=0.25, 48 | ignore_keys=[], 49 | image_key="image", 50 | colorize_nlabels=None, 51 | monitor=None, 52 | ): 53 | super().__init__() 54 | self.image_key = image_key 55 | self.encoder = Encoder(**ddconfig) 56 | self.decoder = Decoder(**ddconfig) 57 | self.loss = instantiate_from_config(lossconfig) 58 | self.uniform_sample_scale = uniform_sample_scale 59 | self.anisotropic = anisotropic 60 | self.p_prior=p_prior 61 | self.p_prior_s=p_prior_s 62 | assert ddconfig["double_z"] 63 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 64 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 65 | self.embed_dim = embed_dim 66 | if colorize_nlabels is not None: 67 | assert type(colorize_nlabels)==int 68 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 69 | if monitor is not None: 70 | self.monitor = monitor 71 | if ckpt_path is not None: 72 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 73 | 74 | def init_from_ckpt(self, path, ignore_keys=list()): 75 | sd = torch.load(path, map_location="cpu")["state_dict"] 76 | keys = list(sd.keys()) 77 | for k in keys: 78 | for ik in ignore_keys: 79 | if k.startswith(ik): 80 | print("Deleting key {} from state_dict.".format(k)) 81 | del sd[k] 82 | self.load_state_dict(sd, strict=False) 83 | print(f"Restored from {path}") 84 | 85 | def encode(self, x): 86 | h = self.encoder(x) 87 | moments = self.quant_conv(h) 88 | posterior = DiagonalGaussianDistribution(moments) 89 | return posterior 90 | 91 | def decode(self, z): 92 | z = self.post_quant_conv(z) 93 | dec = self.decoder(z) 94 | return dec 95 | 96 | def forward(self, input, scale=1, angle=0): 97 | posterior = self.encode(input) 98 | z = posterior.sample() 99 | 100 | if scale != 1: 101 | z = torch.nn.functional.interpolate(z, scale_factor=scale, mode='bilinear', align_corners=False) 102 | 103 | if angle != 0: 104 | z = torch.rot90(z, k=angle, dims=[-1, -2]) 105 | 106 | dec = self.decode(z) 107 | return dec, posterior, z 108 | 109 | def get_input(self, batch, k): 110 | x = batch[k] 111 | if len(x.shape) == 3: 112 | x = x[..., None] 113 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 114 | return x 115 | 116 | def training_step(self, batch, batch_idx, optimizer_idx): 117 | inputs = self.get_input(batch, self.image_key) 118 | 119 | # EQ-VAE regularization 120 | if random.random() < self.p_prior: 121 | 122 | mode = "latent" 123 | if self.anisotropic: 124 | scale_x = random.choice([s / 32 for s in range(8,32)]) 125 | scale_y = random.choice([s / 32 for s in range(8,32)]) 126 | scale=(scale_x, scale_y) 127 | else: 128 | scale = random.choice([s / 32 for s in range(8,32)]) 129 | 130 | # rotation angles 1 -> π/2, 2 -> π, 3 -> 3π/2 131 | angle = random.choice([1, 2, 3]) 132 | reconstructions, posterior, z_after = self(inputs, scale=scale, angle=angle) 133 | 134 | # Scale 
ground truth images with the same scale 135 | inputs = torch.nn.functional.interpolate(inputs, scale_factor=scale, mode='bilinear', align_corners=False) 136 | 137 | # Rotate ground truth images with the same angle 138 | inputs = torch.rot90(inputs, k=angle, dims=[-1, -2]) 139 | 140 | # prior preservation 141 | else: 142 | mode = "image" 143 | # this is prior preservation for low resolution images 144 | if random.random() < self.p_prior_s: 145 | 146 | scale = random.choice([s / 32 for s in range(8,32)]) 147 | inputs = torch.nn.functional.interpolate(inputs, scale_factor=scale, mode='bilinear', align_corners=False) 148 | reconstructions, posterior, _ = self(inputs) 149 | 150 | # this is prior preservation for full resolution images 151 | else: 152 | scale=1 153 | reconstructions, posterior, _ = self(inputs) 154 | 155 | 156 | 157 | if optimizer_idx == 0: 158 | 159 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 160 | last_layer=self.get_last_layer(), split="train") 161 | 162 | 163 | self.log(f"aeloss_scale-{scale}-{mode}", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 164 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 165 | return aeloss 166 | 167 | if optimizer_idx == 1: 168 | 169 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 170 | last_layer=self.get_last_layer(), split="train") 171 | 172 | 173 | self.log(f"discloss_scale-{scale}-{mode}", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 174 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 175 | return discloss 176 | 177 | def validation_step(self, batch, batch_idx): 178 | inputs = self.get_input(batch, self.image_key) 179 | reconstructions, posterior, _ = self(inputs) 180 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 181 | last_layer=self.get_last_layer(), split="val") 182 | 183 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 184 | last_layer=self.get_last_layer(), split="val") 185 | 186 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 187 | self.log_dict(log_dict_ae) 188 | self.log_dict(log_dict_disc) 189 | return self.log_dict 190 | 191 | def configure_optimizers(self): 192 | lr = self.learning_rate 193 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 194 | list(self.decoder.parameters())+ 195 | list(self.quant_conv.parameters())+ 196 | list(self.post_quant_conv.parameters()), 197 | lr=lr, betas=(0.5, 0.9)) 198 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 199 | lr=lr, betas=(0.5, 0.9)) 200 | return [opt_ae, opt_disc], [] 201 | 202 | def get_last_layer(self): 203 | return self.decoder.conv_out.weight 204 | 205 | @torch.no_grad() 206 | def log_images(self, batch, only_inputs=False, **kwargs): 207 | log = dict() 208 | x = self.get_input(batch, self.image_key) 209 | x = x.to(self.device) 210 | if not only_inputs: 211 | if random.random() < 0.5: 212 | xrec, posterior, _ = self(x) 213 | else: 214 | xrec, posterior, _ = self(x, scale=0.5) 215 | 216 | if x.shape[1] > 3: 217 | # colorize with random projection 218 | assert xrec.shape[1] > 3 219 | x = self.to_rgb(x) 220 | xrec = self.to_rgb(xrec) 221 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 222 | log["reconstructions"] = xrec 223 | log["inputs"] = x 224 | return log 225 | 226 | def to_rgb(self, x): 227 | 
assert self.image_key == "segmentation" 228 | if not hasattr(self, "colorize"): 229 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 230 | x = F.conv2d(x, weight=self.colorize) 231 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 232 | return x 233 | 234 | 235 | class IdentityFirstStage(torch.nn.Module): 236 | def __init__(self, *args, vq_interface=False, **kwargs): 237 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 238 | super().__init__() 239 | 240 | def encode(self, x, *args, **kwargs): 241 | return x 242 | 243 | def decode(self, x, *args, **kwargs): 244 | return x 245 | 246 | def quantize(self, x, *args, **kwargs): 247 | if self.vq_interface: 248 | return x, None, [None, None, None] 249 | return x 250 | 251 | def forward(self, x, *args, **kwargs): 252 | return x 253 | 254 | 255 | -------------------------------------------------------------------------------- /train_eqvae/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 
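# Illustrative example for make_ddim_timesteps above: with num_ddpm_timesteps=1000 and
# num_ddim_timesteps=50, the 'uniform' method gives c = 20, ddim_timesteps = [0, 20, ..., 980],
# and the returned steps_out = ddim_timesteps + 1 = [1, 21, ..., 981].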
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 |     # select alphas for computing the variance schedule
65 |     alphas = alphacums[ddim_timesteps]
66 |     alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 | 
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 |     sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 |     if verbose:
71 |         print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 |         print(f'For the chosen value of eta, which is {eta}, '
73 |               f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 |     return sigmas, alphas, alphas_prev
75 | 
76 | 
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 |     """
79 |     Create a beta schedule that discretizes the given alpha_t_bar function,
80 |     which defines the cumulative product of (1-beta) over time from t = [0,1].
81 |     :param num_diffusion_timesteps: the number of betas to produce.
82 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 |                       produces the cumulative product of (1-beta) up to that
84 |                       part of the diffusion process.
85 |     :param max_beta: the maximum beta to use; use values lower than 1 to
86 |                      prevent singularities.
87 |     """
88 |     betas = []
89 |     for i in range(num_diffusion_timesteps):
90 |         t1 = i / num_diffusion_timesteps
91 |         t2 = (i + 1) / num_diffusion_timesteps
92 |         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 |     return np.array(betas)
94 | 
95 | 
96 | def extract_into_tensor(a, t, x_shape):
97 |     b, *_ = t.shape
98 |     out = a.gather(-1, t)
99 |     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 | 
101 | 
102 | def checkpoint(func, inputs, params, flag):
103 |     """
104 |     Evaluate a function without caching intermediate activations, allowing for
105 |     reduced memory at the expense of extra compute in the backward pass.
106 |     :param func: the function to evaluate.
107 |     :param inputs: the argument sequence to pass to `func`.
108 |     :param params: a sequence of parameters `func` depends on but does not
109 |                    explicitly take as arguments.
110 |     :param flag: if False, disable gradient checkpointing.
111 |     """
112 |     if flag:
113 |         args = tuple(inputs) + tuple(params)
114 |         return CheckpointFunction.apply(func, len(inputs), *args)
115 |     else:
116 |         return func(*inputs)
117 | 
118 | 
119 | class CheckpointFunction(torch.autograd.Function):
120 |     @staticmethod
121 |     def forward(ctx, run_function, length, *args):
122 |         ctx.run_function = run_function
123 |         ctx.input_tensors = list(args[:length])
124 |         ctx.input_params = list(args[length:])
125 | 
126 |         with torch.no_grad():
127 |             output_tensors = ctx.run_function(*ctx.input_tensors)
128 |         return output_tensors
129 | 
130 |     @staticmethod
131 |     def backward(ctx, *output_grads):
132 |         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 |         with torch.enable_grad():
134 |             # Fixes a bug where the first op in run_function modifies the
135 |             # Tensor storage in place, which is not allowed for detach()'d
136 |             # Tensors.
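            # (Note: view_as(x) returns a fresh tensor object that shares storage with the
            # detached input, so in-place ops inside run_function hit the view rather than
            # the detached leaf tensors themselves.)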
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /train_eqvae/ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | 
self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def 
on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = 
step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /train_eqvae/ldm/models/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class DDIMSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | self.register_buffer('betas', to_torch(self.model.betas)) 32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 34 | 35 | # calculations for diffusion q(x_t | x_{t-1}) and others 36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 41 | 42 | # ddim sampling parameters 43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 44 | ddim_timesteps=self.ddim_timesteps, 45 | eta=ddim_eta,verbose=verbose) 46 | self.register_buffer('ddim_sigmas', ddim_sigmas) 47 | self.register_buffer('ddim_alphas', ddim_alphas) 48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 54 | 55 | @torch.no_grad() 56 | def sample(self, 57 | S, 58 | batch_size, 59 | shape, 60 | conditioning=None, 61 | callback=None, 62 | normals_sequence=None, 63 | img_callback=None, 64 | quantize_x0=False, 65 | eta=0., 66 | mask=None, 67 | x0=None, 68 | temperature=1., 69 | noise_dropout=0., 70 | score_corrector=None, 71 | corrector_kwargs=None, 72 | verbose=True, 73 | x_T=None, 74 | log_every_t=100, 75 | unconditional_guidance_scale=1., 76 | unconditional_conditioning=None, 77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 78 | **kwargs 79 | ): 80 | if conditioning is not None: 81 | if isinstance(conditioning, dict): 82 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 83 | if cbs != batch_size: 84 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 85 | else: 86 | if conditioning.shape[0] != batch_size: 87 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 88 | 89 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 90 | # sampling 91 | C, H, W = shape 92 | size = (batch_size, C, H, W) 93 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 94 | 95 | samples, intermediates = self.ddim_sampling(conditioning, size, 96 | callback=callback, 97 | img_callback=img_callback, 98 | quantize_denoised=quantize_x0, 99 | mask=mask, x0=x0, 100 | ddim_use_original_steps=False, 101 | noise_dropout=noise_dropout, 102 | temperature=temperature, 103 | score_corrector=score_corrector, 104 | corrector_kwargs=corrector_kwargs, 105 | x_T=x_T, 106 | log_every_t=log_every_t, 107 | unconditional_guidance_scale=unconditional_guidance_scale, 108 | unconditional_conditioning=unconditional_conditioning, 109 | ) 110 | return samples, intermediates 111 | 112 | @torch.no_grad() 113 | def ddim_sampling(self, cond, shape, 114 | x_T=None, ddim_use_original_steps=False, 115 | callback=None, timesteps=None, quantize_denoised=False, 116 | mask=None, x0=None, img_callback=None, log_every_t=100, 117 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 118 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 119 | device = self.model.betas.device 120 | b = shape[0] 121 | if x_T is None: 122 | img = torch.randn(shape, device=device) 123 | else: 124 | img = x_T 125 | 126 | if timesteps is None: 127 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 128 | elif timesteps is not None and not ddim_use_original_steps: 129 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 130 | timesteps = self.ddim_timesteps[:subset_end] 131 | 132 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 133 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 134 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 135 | print(f"Running DDIM Sampling with {total_steps} timesteps") 136 | 137 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 138 | 139 | for i, step in enumerate(iterator): 140 | index = total_steps - i - 1 141 | ts = torch.full((b,), step, device=device, dtype=torch.long) 142 
| 143 | if mask is not None: 144 | assert x0 is not None 145 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 146 | img = img_orig * mask + (1. - mask) * img 147 | 148 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 149 | quantize_denoised=quantize_denoised, temperature=temperature, 150 | noise_dropout=noise_dropout, score_corrector=score_corrector, 151 | corrector_kwargs=corrector_kwargs, 152 | unconditional_guidance_scale=unconditional_guidance_scale, 153 | unconditional_conditioning=unconditional_conditioning) 154 | img, pred_x0 = outs 155 | if callback: callback(i) 156 | if img_callback: img_callback(pred_x0, i) 157 | 158 | if index % log_every_t == 0 or index == total_steps - 1: 159 | intermediates['x_inter'].append(img) 160 | intermediates['pred_x0'].append(pred_x0) 161 | 162 | return img, intermediates 163 | 164 | @torch.no_grad() 165 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 166 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 167 | unconditional_guidance_scale=1., unconditional_conditioning=None): 168 | b, *_, device = *x.shape, x.device 169 | 170 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 171 | e_t = self.model.apply_model(x, t, c) 172 | else: 173 | x_in = torch.cat([x] * 2) 174 | t_in = torch.cat([t] * 2) 175 | c_in = torch.cat([unconditional_conditioning, c]) 176 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 177 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 178 | 179 | if score_corrector is not None: 180 | assert self.model.parameterization == "eps" 181 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 182 | 183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 186 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 187 | # select parameters corresponding to the currently considered timestep 188 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 189 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 190 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 191 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 192 | 193 | # current prediction for x_0 194 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 195 | if quantize_denoised: 196 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 197 | # direction pointing to x_t 198 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 199 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 200 | if noise_dropout > 0.: 201 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 202 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 203 | return x_prev, pred_x0 204 | -------------------------------------------------------------------------------- /train_gen/sample_ddp.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Code adapted from https://github.com/chuanyangjin/fast-DiT 4 | """ 5 | import sys 6 | import os 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../train_eqvae')) 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 9 | 10 | 11 | import torch 12 | import torch.distributed as dist 13 | from download import find_model 14 | from models.dit_models import DiT_models 15 | from models.diffusion import create_diffusion 16 | from diffusers.models import AutoencoderKL 17 | from tqdm import tqdm 18 | import os 19 | from PIL import Image 20 | import numpy as np 21 | import math 22 | import argparse 23 | from evaluator import Evaluator 24 | import tensorflow.compat.v1 as tf 25 | from train_eqvae.ldm.models.autoencoder import AutoencoderKL as LDMAutoencoderKL 26 | 27 | from ldm.util import instantiate_from_config 28 | import yaml 29 | from omegaconf import OmegaConf 30 | 31 | def load_config(config_path, display=False): 32 | config = OmegaConf.load(config_path) 33 | if display: 34 | print(yaml.dump(OmegaConf.to_container(config))) 35 | return config 36 | 37 | 38 | def load_kl(config, ckpt_path=None): 39 | 40 | model = LDMAutoencoderKL(**config.model.params) 41 | 42 | if ckpt_path is not None: 43 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 44 | missing, unexpected = model.load_state_dict(sd, strict=False) 45 | return model.eval() 46 | 47 | def preprocess_kl(x): 48 | x = 2.*x - 1. 49 | return x 50 | 51 | 52 | 53 | def create_npz_from_sample_folder(sample_dir, num=50_000): 54 | """ 55 | Builds a single .npz file from a folder of .png samples. 56 | """ 57 | samples = [] 58 | for i in tqdm(range(num), desc="Building .npz file from samples"): 59 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 60 | sample_np = np.asarray(sample_pil).astype(np.uint8) 61 | samples.append(sample_np) 62 | samples = np.stack(samples) 63 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 64 | npz_path = f"{sample_dir}.npz" 65 | np.savez(npz_path, arr_0=samples) 66 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 67 | return npz_path 68 | 69 | 70 | 71 | def custom_to_pil(x): 72 | x = x.detach().cpu() 73 | x = torch.clamp(x, -1., 1.) 74 | x = (x + 1.)/2. 
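    # Note: x is now a float CHW tensor in [0, 1]; the lines below convert it to a
    # uint8 HWC array and then to a PIL RGB image.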
75 | x = x.permute(1,2,0).numpy() 76 | x = (255*x).astype(np.uint8) 77 | x = Image.fromarray(x) 78 | if not x.mode == "RGB": 79 | x = x.convert("RGB") 80 | return x 81 | 82 | 83 | def calculate_metrics(ref_batch, sample_batch, fid_path): 84 | config = tf.ConfigProto( 85 | allow_soft_placement=True 86 | ) 87 | config.gpu_options.allow_growth = True 88 | evaluator = Evaluator(tf.Session(config=config)) 89 | 90 | evaluator.warmup() 91 | 92 | ref_acts = evaluator.read_activations(ref_batch) 93 | ref_stats, ref_stats_spatial = evaluator.read_statistics(ref_batch, ref_acts) 94 | 95 | sample_acts = evaluator.read_activations(sample_batch) 96 | sample_stats, sample_stats_spatial = evaluator.read_statistics(sample_batch, sample_acts) 97 | 98 | with open(fid_path, 'w') as fd: 99 | 100 | fd.write("Computing evaluations...\n") 101 | fd.write(f"Inception Score:{evaluator.compute_inception_score(sample_acts[0])}\n" ) 102 | fd.write(f"FID:{sample_stats.frechet_distance(ref_stats)}\n") 103 | fd.write(f"sFID:{sample_stats_spatial.frechet_distance(ref_stats_spatial)}\n") 104 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) 105 | fd.write(f"Precision:{prec}\n") 106 | fd.write(f"Recall:{recall}\n") 107 | 108 | 109 | 110 | def main(args): 111 | """ 112 | Run sampling. 113 | """ 114 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 115 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 116 | torch.set_grad_enabled(False) 117 | 118 | # Setup DDP: 119 | dist.init_process_group("nccl") 120 | rank = dist.get_rank() 121 | device = rank % torch.cuda.device_count() 122 | seed = args.global_seed * dist.get_world_size() + rank 123 | torch.manual_seed(seed) 124 | torch.cuda.set_device(device) 125 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 126 | 127 | if args.ckpt is None: 128 | assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." 129 | assert args.image_size in [256, 512] 130 | assert args.num_classes == 1000 131 | 132 | # Load model: 133 | latent_size = args.image_size // 8 134 | model = DiT_models[args.model]( 135 | input_size=latent_size, 136 | num_classes=args.num_classes, 137 | in_channels=args.in_channels 138 | ).to(device) 139 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: 140 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" 141 | state_dict = find_model(ckpt_path) 142 | model.load_state_dict(state_dict) 143 | model.eval() # important! 
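    # Note: eval() matters because model.train() enables embedding dropout for
    # classifier-free guidance (see train.py); sampling must run with it disabled.
    # TODO: main() also reads args.vae, args.vae_config, args.hf_model_name,
    # args.per_proc_batch_size and args.num_fid_samples, which are not defined in the
    # argparse block at the bottom of this file; they would need to be added, e.g.
    # following the --vae / --hf-model-name conventions used elsewhere in this repo.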
144 |     diffusion = create_diffusion(str(args.num_sampling_steps))
145 | 
146 |     if args.vae_ckpt is not None:
147 |         vae_config = load_config(args.vae_config, display=False)
148 |         vae = load_kl(vae_config, ckpt_path=args.vae_ckpt).to(device)
149 |     else:
150 |         from diffusers.models import AutoencoderKL
151 |         vae = AutoencoderKL.from_pretrained(args.hf_model_name).to(device)
152 | 
153 | 
154 |     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
155 |     using_cfg = args.cfg_scale > 1.0
156 | 
157 |     # Create folder to save samples:
158 |     model_string_name = args.model.replace("/", "-")
159 |     ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
160 |     folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \
161 |                   f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
162 | 
163 |     if args.ddpm:
164 |         folder_name += f"-ddpm"
165 | 
166 | 
167 |     sample_folder_dir = f"{args.sample_dir}/{folder_name}"
168 | 
169 |     exp_dir_name = os.path.dirname(args.sample_dir)
170 |     fid_dir = f"{exp_dir_name}/fids"
171 | 
172 |     if rank == 0:
173 |         os.makedirs(sample_folder_dir, exist_ok=True)
174 |         os.makedirs(fid_dir, exist_ok=True)
175 | 
176 |         print(f"Saving .png samples at {sample_folder_dir}")
177 |         print(f"Saving metrics at {fid_dir}")
178 | 
179 |     dist.barrier()
180 | 
181 |     fid_path = f"{fid_dir}/metrics-adm-{folder_name}.txt"
182 | 
183 |     # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
184 |     n = args.per_proc_batch_size
185 |     global_batch_size = n * dist.get_world_size()
186 |     # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
187 |     total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
188 |     if rank == 0:
189 |         print(f"Total number of images that will be sampled: {total_samples}")
190 |     assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
191 |     samples_needed_this_gpu = int(total_samples // dist.get_world_size())
192 |     assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
193 |     iterations = int(samples_needed_this_gpu // n)
194 |     pbar = range(iterations)
195 |     pbar = tqdm(pbar) if rank == 0 else pbar
196 |     total = 0
197 |     for _ in pbar:
198 |         # Sample inputs:
199 |         z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
200 |         y = torch.randint(0, args.num_classes, (n,), device=device)
201 | 
202 |         # Setup classifier-free guidance:
203 |         if using_cfg:
204 |             z = torch.cat([z, z], 0)
205 |             y_null = torch.tensor([1000] * n, device=device)
206 |             y = torch.cat([y, y_null], 0)
207 |             model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
208 |             sample_fn = model.forward_with_cfg
209 |         else:
210 |             model_kwargs = dict(y=y)
211 |             sample_fn = model.forward
212 | 
213 |         # Sample images:
214 |         if args.ddpm:
215 |             samples = diffusion.p_sample_loop(
216 |                 sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
217 |             )
218 |         else:
219 |             samples = diffusion.ddim_sample_loop(
220 |                 sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
221 |             )
222 |         if using_cfg:
223 |             samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
224 | 
225 |         if args.vae_ckpt is not None:
226 | 
227 |             samples = vae.decode(samples / args.vae_scaling_factor)
228 | 
229 |         else:
230 | 
231 |             samples = 
vae.decode(samples / args.vae_scaling_factor).sample 232 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 233 | 234 | # Save samples to disk as individual .png files 235 | for i, sample in enumerate(samples): 236 | index = i * dist.get_world_size() + rank + total 237 | 238 | if args.vae in ["ours", "hf"]: 239 | sample = custom_to_pil(sample) 240 | sample.save(f"{sample_folder_dir}/{index:06d}.png") 241 | else: 242 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 243 | total += global_batch_size 244 | 245 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 246 | dist.barrier() 247 | if rank == 0: 248 | sample_batch = create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) 249 | calculate_metrics(sample_batch, args.ref_batch, fid_path) 250 | 251 | dist.barrier() 252 | dist.destroy_process_group() 253 | 254 | 255 | if __name__ == "__main__": 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") 258 | parser.add_argument("--vae-ckpt", type=str, default=None) 259 | parser.add_argument("--sample-dir", type=str, default="samples") 260 | parser.add_argument("--ddpm", type=bool, default=False) 261 | parser.add_argument("--vae-scaling-factor", type=float, default=0.18215) 262 | parser.add_argument("--ref-batch", type=str, default="/data/imagenet/VIRTUAL_imagenet256_labeled.npz") 263 | parser.add_argument("--wavelets", type=str, default=False) 264 | parser.add_argument("--diff-proj", type=str, default=False) 265 | parser.add_argument("--gaussian-registers", type=str, default=False) 266 | parser.add_argument("--in-channels", type=int, default=4) 267 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 268 | parser.add_argument("--num-classes", type=int, default=1000) 269 | parser.add_argument("--cfg-scale", type=float, default=1.5) 270 | parser.add_argument("--num-sampling-steps", type=int, default=250) 271 | parser.add_argument("--global-seed", type=int, default=0) 272 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, 273 | help="By default, use TF32 matmuls. 
This massively accelerates sampling on Ampere GPUs.") 274 | parser.add_argument("--ckpt", type=str, default=None, 275 | help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") 276 | args = parser.parse_args() 277 | main(args) 278 | -------------------------------------------------------------------------------- /train_gen/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/chuanyangjin/fast-DiT 3 | """ 4 | import sys 5 | import os 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 7 | 8 | import torch 9 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 10 | torch.backends.cuda.matmul.allow_tf32 = True 11 | torch.backends.cudnn.allow_tf32 = True 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torchvision.datasets import ImageFolder 17 | from torchvision import transforms 18 | import numpy as np 19 | from collections import OrderedDict 20 | from PIL import Image 21 | from copy import deepcopy 22 | from glob import glob 23 | from time import time 24 | import argparse 25 | import logging 26 | from accelerate import Accelerator 27 | from models.dit_models import DiT_models 28 | from models.diffusion import create_diffusion 29 | 30 | # from models import DiT_models 31 | # from diffusion import create_diffusion 32 | from diffusers.models import AutoencoderKL 33 | 34 | 35 | ################################################################################# 36 | # Training Helper Functions # 37 | ################################################################################# 38 | 39 | @torch.no_grad() 40 | def update_ema(ema_model, model, decay=0.9999): 41 | """ 42 | Step the EMA model towards the current model. 43 | """ 44 | ema_params = OrderedDict(ema_model.named_parameters()) 45 | model_params = OrderedDict(model.named_parameters()) 46 | 47 | for name, param in model_params.items(): 48 | name = name.replace("module.", "") 49 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 50 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 51 | 52 | 53 | def requires_grad(model, flag=True): 54 | """ 55 | Set requires_grad flag for all parameters in a model. 56 | """ 57 | for p in model.parameters(): 58 | p.requires_grad = flag 59 | 60 | 61 | def create_logger(logging_dir): 62 | """ 63 | Create a logger that writes to a log file and stdout. 64 | """ 65 | logging.basicConfig( 66 | level=logging.INFO, 67 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 68 | datefmt='%Y-%m-%d %H:%M:%S', 69 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 70 | ) 71 | logger = logging.getLogger(__name__) 72 | return logger 73 | 74 | 75 | def center_crop_arr(pil_image, image_size): 76 | """ 77 | Center cropping implementation from ADM. 
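    Repeatedly halves the image with BOX resampling while its short side is at least
    2 * image_size, then resizes with BICUBIC so the short side equals image_size,
    and finally center-crops to image_size x image_size.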
78 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 79 | """ 80 | while min(*pil_image.size) >= 2 * image_size: 81 | pil_image = pil_image.resize( 82 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 83 | ) 84 | 85 | scale = image_size / min(*pil_image.size) 86 | pil_image = pil_image.resize( 87 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 88 | ) 89 | 90 | arr = np.array(pil_image) 91 | crop_y = (arr.shape[0] - image_size) // 2 92 | crop_x = (arr.shape[1] - image_size) // 2 93 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 94 | 95 | 96 | class CustomDataset(Dataset): 97 | def __init__(self, features_dir, labels_dir): 98 | self.features_dir = features_dir 99 | self.labels_dir = labels_dir 100 | 101 | self.features_files = sorted(os.listdir(features_dir)) 102 | self.labels_files = sorted(os.listdir(labels_dir)) 103 | 104 | def __len__(self): 105 | assert len(self.features_files) == len(self.labels_files), \ 106 | "Number of feature files and label files should be same" 107 | return len(self.features_files) 108 | 109 | def __getitem__(self, idx): 110 | feature_file = self.features_files[idx] 111 | label_file = self.labels_files[idx] 112 | 113 | features = np.load(os.path.join(self.features_dir, feature_file)) 114 | labels = np.load(os.path.join(self.labels_dir, label_file)) 115 | return torch.from_numpy(features), torch.from_numpy(labels) 116 | 117 | 118 | ################################################################################# 119 | # Training Loop # 120 | ################################################################################# 121 | 122 | def main(args): 123 | """ 124 | Trains a new DiT model. 125 | """ 126 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 127 | 128 | # Setup accelerator: 129 | accelerator = Accelerator() 130 | device = accelerator.device 131 | 132 | experiment_index = len(glob(f"{args.results_dir}/*")) 133 | model_string_name = args.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders) 134 | 135 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder 136 | 137 | 138 | if args.dataset_name: 139 | experiment_dir += f"-{args.dataset_name}" 140 | 141 | if args.vae_name: 142 | experiment_dir += f"-{args.vae_name}" 143 | 144 | 145 | # Setup an experiment folder: 146 | if accelerator.is_main_process: 147 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 148 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints 149 | os.makedirs(checkpoint_dir, exist_ok=True) 150 | logger = create_logger(experiment_dir) 151 | logger.info(f"Experiment directory created at {experiment_dir}") 152 | 153 | # Create model: 154 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 
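    # e.g. 256x256 images give 32x32 latents, matching the 8x spatial downsampling of the
    # VAE used to pre-extract the features loaded below.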
155 | latent_size = args.image_size // 8 156 | model = DiT_models[args.model]( 157 | input_size=latent_size, 158 | in_channels=args.in_channels 159 | ) 160 | # Note that parameter initialization is done within the DiT constructor 161 | model = model.to(device) 162 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training 163 | 164 | 165 | opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0) 166 | 167 | if args.ckpt is not None: 168 | ckpt_path = args.ckpt 169 | state_dict = torch.load(ckpt_path, map_location="cpu") 170 | model.load_state_dict(state_dict["model"]) 171 | ema.load_state_dict(state_dict["ema"]) 172 | opt.load_state_dict(state_dict["opt"]) 173 | args = state_dict["args"] 174 | 175 | 176 | requires_grad(ema, False) 177 | 178 | 179 | diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule 180 | if accelerator.is_main_process: 181 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 182 | 183 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): 184 | 185 | # Setup data: 186 | features_dir = f"{args.feature_path}/imagenet256_features" 187 | labels_dir = f"{args.feature_path}/imagenet256_labels" 188 | dataset = CustomDataset(features_dir, labels_dir) 189 | loader = DataLoader( 190 | dataset, 191 | batch_size=int(args.global_batch_size // accelerator.num_processes), 192 | shuffle=True, 193 | num_workers=args.num_workers, 194 | pin_memory=True, 195 | drop_last=True 196 | ) 197 | if accelerator.is_main_process: 198 | logger.info(f"Dataset contains {len(dataset):,} images ({args.feature_path})") 199 | 200 | # Prepare models for training: 201 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights 202 | model.train() # important! 
This enables embedding dropout for classifier-free guidance 203 | ema.eval() # EMA model should always be in eval mode 204 | model, opt, loader = accelerator.prepare(model, opt, loader) 205 | 206 | # Variables for monitoring/logging purposes: 207 | train_steps = 0 208 | log_steps = 0 209 | running_loss = 0 210 | start_time = time() 211 | 212 | if accelerator.is_main_process: 213 | logger.info(f"Training for {args.epochs} epochs...") 214 | for epoch in range(args.epochs): 215 | if accelerator.is_main_process: 216 | logger.info(f"Beginning epoch {epoch}...") 217 | for x, y in loader: 218 | x = x.to(device) 219 | y = y.to(device) 220 | x = x.squeeze(dim=1) 221 | y = y.squeeze(dim=1) 222 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device) 223 | model_kwargs = dict(y=y) 224 | loss_dict = diffusion.training_losses(model, x, t, model_kwargs) 225 | loss = loss_dict["loss"].mean() 226 | opt.zero_grad() 227 | accelerator.backward(loss) 228 | opt.step() 229 | update_ema(ema, model) 230 | 231 | # Log loss values: 232 | running_loss += loss.item() 233 | log_steps += 1 234 | train_steps += 1 235 | if train_steps % args.log_every == 0: 236 | # Measure training speed: 237 | torch.cuda.synchronize() 238 | end_time = time() 239 | steps_per_sec = log_steps / (end_time - start_time) 240 | # Reduce loss history over all processes: 241 | avg_loss = torch.tensor(running_loss / log_steps, device=device) 242 | avg_loss = avg_loss.item() / accelerator.num_processes 243 | if accelerator.is_main_process: 244 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") 245 | # Reset monitoring variables: 246 | running_loss = 0 247 | log_steps = 0 248 | start_time = time() 249 | 250 | # Save DiT checkpoint: 251 | if train_steps % args.ckpt_every == 0: 252 | if accelerator.is_main_process: 253 | checkpoint = { 254 | "model": model.module.state_dict(), 255 | "ema": ema.state_dict(), 256 | "opt": opt.state_dict(), 257 | "args": args 258 | } 259 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" 260 | torch.save(checkpoint, checkpoint_path) 261 | logger.info(f"Saved checkpoint to {checkpoint_path}") 262 | if train_steps > args.max_train_steps: 263 | break 264 | 265 | if accelerator.is_main_process: 266 | logger.info("Done!") 267 | 268 | 269 | if __name__ == "__main__": 270 | # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters). 
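# One possible launch command (a sketch; the process count and paths are placeholders, and precomputed latents/labels
# are expected under <feature-path>/imagenet256_features and <feature-path>/imagenet256_labels):
#   accelerate launch --multi_gpu --num_processes 8 train_gen/train.py \
#       --feature-path /path/to/features --results-dir results --model DiT-XL/2 --global-batch-size 256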
271 | parser = argparse.ArgumentParser() 272 | parser.add_argument("--feature-path", type=str, default="features") 273 | parser.add_argument("--results-dir", type=str, default="results") 274 | parser.add_argument("--dataset-name", type=str, default=None) 275 | parser.add_argument("--in-channels", type=int, default=4) 276 | parser.add_argument("--ckpt", type=str, default=None) 277 | parser.add_argument("--max-train-steps", type=int, default=400_000) 278 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") 279 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 280 | parser.add_argument("--num-classes", type=int, default=1000) 281 | parser.add_argument("--epochs", type=int, default=1400) 282 | parser.add_argument("--global-batch-size", type=int, default=256) 283 | parser.add_argument("--global-seed", type=int, default=0) 284 | parser.add_argument("--vae-name", type=str, default="ema") # Choice doesn't affect training 285 | parser.add_argument("--num-workers", type=int, default=4) 286 | parser.add_argument("--log-every", type=int, default=100) 287 | parser.add_argument("--ckpt-every", type=int, default=50_000) 288 | 289 | args = parser.parse_args() 290 | main(args) 291 | -------------------------------------------------------------------------------- /train_eqvae/ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class PLMSSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | if ddim_eta != 0: 26 | raise ValueError('ddim_eta must be 0 for PLMS') 27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 29 | alphas_cumprod = self.model.alphas_cumprod 30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 32 | 33 | self.register_buffer('betas', to_torch(self.model.betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 43 | 44 | # ddim sampling parameters 45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 46 | ddim_timesteps=self.ddim_timesteps, 47 | eta=ddim_eta,verbose=verbose) 48 | self.register_buffer('ddim_sigmas', ddim_sigmas) 49 | self.register_buffer('ddim_alphas', ddim_alphas) 50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 56 | 57 | @torch.no_grad() 58 | def sample(self, 59 | S, 60 | batch_size, 61 | shape, 62 | conditioning=None, 63 | callback=None, 64 | normals_sequence=None, 65 | img_callback=None, 66 | quantize_x0=False, 67 | eta=0., 68 | mask=None, 69 | x0=None, 70 | temperature=1., 71 | noise_dropout=0., 72 | score_corrector=None, 73 | corrector_kwargs=None, 74 | verbose=True, 75 | x_T=None, 76 | log_every_t=100, 77 | unconditional_guidance_scale=1., 78 | unconditional_conditioning=None, 79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 80 | **kwargs 81 | ): 82 | if conditioning is not None: 83 | if isinstance(conditioning, dict): 84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 85 | if cbs != batch_size: 86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 87 | else: 88 | if conditioning.shape[0] != batch_size: 89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 90 | 91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 92 | # sampling 93 | C, H, W = shape 94 | size = (batch_size, C, H, W) 95 | print(f'Data shape for PLMS sampling is {size}') 96 | 97 | samples, intermediates = self.plms_sampling(conditioning, size, 98 | callback=callback, 99 | img_callback=img_callback, 100 | quantize_denoised=quantize_x0, 101 | mask=mask, x0=x0, 102 | ddim_use_original_steps=False, 103 | noise_dropout=noise_dropout, 104 | temperature=temperature, 105 | score_corrector=score_corrector, 106 | corrector_kwargs=corrector_kwargs, 107 | x_T=x_T, 108 | log_every_t=log_every_t, 109 | unconditional_guidance_scale=unconditional_guidance_scale, 110 | unconditional_conditioning=unconditional_conditioning, 111 | ) 112 | return samples, intermediates 113 | 114 | @torch.no_grad() 115 | def plms_sampling(self, cond, shape, 116 | x_T=None, ddim_use_original_steps=False, 117 | callback=None, timesteps=None, quantize_denoised=False, 118 | mask=None, x0=None, img_callback=None, log_every_t=100, 119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 120 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 121 | device = self.model.betas.device 122 | b = shape[0] 123 | if x_T is None: 124 | img = torch.randn(shape, device=device) 125 | else: 126 | img = x_T 127 | 128 | if timesteps is None: 129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 130 | elif timesteps is not None and not ddim_use_original_steps: 131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 132 | timesteps = self.ddim_timesteps[:subset_end] 133 | 134 | intermediates = {'x_inter': [img], 'pred_x0': 
[img]} 135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 137 | print(f"Running PLMS Sampling with {total_steps} timesteps") 138 | 139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 140 | old_eps = [] 141 | 142 | for i, step in enumerate(iterator): 143 | index = total_steps - i - 1 144 | ts = torch.full((b,), step, device=device, dtype=torch.long) 145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 146 | 147 | if mask is not None: 148 | assert x0 is not None 149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 150 | img = img_orig * mask + (1. - mask) * img 151 | 152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 153 | quantize_denoised=quantize_denoised, temperature=temperature, 154 | noise_dropout=noise_dropout, score_corrector=score_corrector, 155 | corrector_kwargs=corrector_kwargs, 156 | unconditional_guidance_scale=unconditional_guidance_scale, 157 | unconditional_conditioning=unconditional_conditioning, 158 | old_eps=old_eps, t_next=ts_next) 159 | img, pred_x0, e_t = outs 160 | old_eps.append(e_t) 161 | if len(old_eps) >= 4: 162 | old_eps.pop(0) 163 | if callback: callback(i) 164 | if img_callback: img_callback(pred_x0, i) 165 | 166 | if index % log_every_t == 0 or index == total_steps - 1: 167 | intermediates['x_inter'].append(img) 168 | intermediates['pred_x0'].append(pred_x0) 169 | 170 | return img, intermediates 171 | 172 | @torch.no_grad() 173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): 176 | b, *_, device = *x.shape, x.device 177 | 178 | def get_model_output(x, t): 179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 180 | e_t = self.model.apply_model(x, t, c) 181 | else: 182 | x_in = torch.cat([x] * 2) 183 | t_in = torch.cat([t] * 2) 184 | c_in = torch.cat([unconditional_conditioning, c]) 185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 187 | 188 | if score_corrector is not None: 189 | assert self.model.parameterization == "eps" 190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 191 | 192 | return e_t 193 | 194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 198 | 199 | def get_x_prev_and_pred_x0(e_t, index): 200 | # select parameters corresponding to the currently considered timestep 201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 205 | 206 | # current 
prediction for x_0 207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 208 | if quantize_denoised: 209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 210 | # direction pointing to x_t 211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 213 | if noise_dropout > 0.: 214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 216 | return x_prev, pred_x0 217 | 218 | e_t = get_model_output(x, t) 219 | if len(old_eps) == 0: 220 | # Pseudo Improved Euler (2nd order) 221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 222 | e_t_next = get_model_output(x_prev, t_next) 223 | e_t_prime = (e_t + e_t_next) / 2 224 | elif len(old_eps) == 1: 225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 227 | elif len(old_eps) == 2: 228 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 230 | elif len(old_eps) >= 3: 231 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 233 | 234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 235 | 236 | return x_prev, pred_x0, e_t 237 | -------------------------------------------------------------------------------- /models/dit_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 
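Concretely, with half = dim // 2 the frequencies are freqs_i = max_period ** (-i / half) for i = 0..half-1,
and the output is cat([cos(t * freqs), sin(t * freqs)], dim=-1), zero-padded by one column when dim is odd.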
49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 
128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 148 | """ 149 | def __init__( 150 | self, 151 | input_size=32, 152 | patch_size=2, 153 | in_channels=4, 154 | hidden_size=1152, 155 | depth=28, 156 | num_heads=16, 157 | mlp_ratio=4.0, 158 | class_dropout_prob=0.1, 159 | num_classes=1000, 160 | learn_sigma=True, 161 | ): 162 | super().__init__() 163 | self.learn_sigma = learn_sigma 164 | self.in_channels = in_channels 165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 166 | self.patch_size = patch_size 167 | self.num_heads = num_heads 168 | 169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 170 | self.t_embedder = TimestepEmbedder(hidden_size) 171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 172 | num_patches = self.x_embedder.num_patches 173 | # Will use fixed sin-cos embedding: 174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 175 | 176 | self.blocks = nn.ModuleList([ 177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 178 | ]) 179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 180 | self.initialize_weights() 181 | 182 | def initialize_weights(self): 183 | # Initialize transformer layers: 184 | def _basic_init(module): 185 | if isinstance(module, nn.Linear): 186 | torch.nn.init.xavier_uniform_(module.weight) 187 | if module.bias is not None: 188 | nn.init.constant_(module.bias, 0) 189 | self.apply(_basic_init) 190 | 191 | # Initialize (and freeze) pos_embed by sin-cos embedding: 192 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 193 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 194 | 195 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 196 | w = self.x_embedder.proj.weight.data 197 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 198 | nn.init.constant_(self.x_embedder.proj.bias, 0) 199 | 200 | # Initialize label embedding table: 201 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 202 | 203 | # Initialize timestep embedding MLP: 204 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 205 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 206 | 207 | # Zero-out adaLN modulation layers in DiT blocks: 208 | for block in self.blocks: 209 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 210 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 211 | 212 | # Zero-out output layers: 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 214 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 215 | nn.init.constant_(self.final_layer.linear.weight, 0) 216 | nn.init.constant_(self.final_layer.linear.bias, 0) 217 | 218 | def unpatchify(self, x): 219 | """ 220 | x: (N, 
T, patch_size**2 * C) 221 | imgs: (N, H, W, C) 222 | """ 223 | c = self.out_channels 224 | p = self.x_embedder.patch_size[0] 225 | h = w = int(x.shape[1] ** 0.5) 226 | assert h * w == x.shape[1] 227 | 228 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 229 | x = torch.einsum('nhwpqc->nchpwq', x) 230 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 231 | return imgs 232 | 233 | def ckpt_wrapper(self, module): 234 | def ckpt_forward(*inputs): 235 | outputs = module(*inputs) 236 | return outputs 237 | return ckpt_forward 238 | 239 | def forward(self, x, t, y): 240 | """ 241 | Forward pass of DiT. 242 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 243 | t: (N,) tensor of diffusion timesteps 244 | y: (N,) tensor of class labels 245 | """ 246 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 247 | t = self.t_embedder(t) # (N, D) 248 | y = self.y_embedder(y, self.training) # (N, D) 249 | c = t + y # (N, D) 250 | for block in self.blocks: 251 | x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, c) # (N, T, D) 252 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 253 | x = self.unpatchify(x) # (N, out_channels, H, W) 254 | return x 255 | 256 | def forward_with_cfg(self, x, t, y, cfg_scale): 257 | """ 258 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 259 | """ 260 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 261 | half = x[: len(x) // 2] 262 | combined = torch.cat([half, half], dim=0) 263 | model_out = self.forward(combined, t, y) 264 | # For exact reproducibility reasons, we apply classifier-free guidance on only 265 | # three channels by default. The standard approach to cfg applies it to all channels. 266 | # This can be done by uncommenting the following line and commenting-out the line following that. 
267 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 268 | eps, rest = model_out[:, :3], model_out[:, 3:] 269 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 270 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 271 | eps = torch.cat([half_eps, half_eps], dim=0) 272 | return torch.cat([eps, rest], dim=1) 273 | 274 | 275 | ################################################################################# 276 | # Sine/Cosine Positional Embedding Functions # 277 | ################################################################################# 278 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 279 | 280 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 281 | """ 282 | grid_size: int of the grid height and width 283 | return: 284 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 285 | """ 286 | grid_h = np.arange(grid_size, dtype=np.float32) 287 | grid_w = np.arange(grid_size, dtype=np.float32) 288 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 289 | grid = np.stack(grid, axis=0) 290 | 291 | grid = grid.reshape([2, 1, grid_size, grid_size]) 292 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 293 | if cls_token and extra_tokens > 0: 294 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 295 | return pos_embed 296 | 297 | 298 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 299 | assert embed_dim % 2 == 0 300 | 301 | # use half of dimensions to encode grid_h 302 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 303 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 304 | 305 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 306 | return emb 307 | 308 | 309 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 310 | """ 311 | embed_dim: output dimension for each position 312 | pos: a list of positions to be encoded: size (M,) 313 | out: (M, D) 314 | """ 315 | assert embed_dim % 2 == 0 316 | omega = np.arange(embed_dim // 2, dtype=np.float64) 317 | omega /= embed_dim / 2. 318 | omega = 1. 
/ 10000**omega # (D/2,) 319 | 320 | pos = pos.reshape(-1) # (M,) 321 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 322 | 323 | emb_sin = np.sin(out) # (M, D/2) 324 | emb_cos = np.cos(out) # (M, D/2) 325 | 326 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 327 | return emb 328 | 329 | 330 | ################################################################################# 331 | # DiT Configs # 332 | ################################################################################# 333 | 334 | def DiT_XL_2(**kwargs): 335 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 336 | 337 | def DiT_XL_4(**kwargs): 338 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 339 | 340 | def DiT_XL_8(**kwargs): 341 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 342 | 343 | def DiT_L_2(**kwargs): 344 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 345 | 346 | def DiT_L_4(**kwargs): 347 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 348 | 349 | def DiT_L_8(**kwargs): 350 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 351 | 352 | def DiT_B_2(**kwargs): 353 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 354 | 355 | def DiT_B_4(**kwargs): 356 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 357 | 358 | def DiT_B_8(**kwargs): 359 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 360 | 361 | def DiT_S_2(**kwargs): 362 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 363 | 364 | def DiT_S_4(**kwargs): 365 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 366 | 367 | def DiT_S_8(**kwargs): 368 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 369 | 370 | 371 | DiT_models = { 372 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 373 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 374 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 375 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 376 | } --------------------------------------------------------------------------------
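
A minimal usage sketch for the registry above, run from the repository root (assuming the 256x256 ImageNet setup of train_gen/train.py, i.e. 32x32 SD-VAE latents with 4 channels; the weights here are randomly initialized, so this only verifies shapes):

    import torch
    from models.dit_models import DiT_models

    # DiT-XL/2 over 32x32, 4-channel latents (256x256 images through an 8x-downsampling VAE)
    model = DiT_models["DiT-XL/2"](input_size=32, in_channels=4)
    model.eval()

    x = torch.randn(2, 4, 32, 32)      # batch of latents
    t = torch.randint(0, 1000, (2,))   # diffusion timesteps
    y = torch.randint(0, 1000, (2,))   # ImageNet class labels

    with torch.no_grad():
        out = model(x, t, y)
    print(out.shape)                   # torch.Size([2, 8, 32, 32]): noise prediction plus learned-sigma channels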