├── U-ViT
│   ├── assets
│   │   ├── fid_stats
│   │   └── stable-diffusion
│   ├── libs
│   │   ├── __init__.py
│   │   ├── clip.py
│   │   ├── timm.py
│   │   ├── uvit_t2i.py
│   │   ├── uvit.py
│   │   ├── uvit_dynamic.py
│   │   └── uvit_router.py
│   ├── u-vit.gif
│   ├── ckpt
│   │   ├── dpm20_router.pth
│   │   └── dpm50_router.pth
│   ├── .gitignore
│   ├── fid.py
│   ├── configs
│   │   ├── imagenet256_uvit_huge.py
│   │   ├── imagenet256_uvit_huge_dynamic_cache.py
│   │   └── imagenet256_uvit_huge_router.py
│   ├── readme.md
│   ├── tools
│   │   ├── read_npz.py
│   │   ├── fid_score.py
│   │   └── inception.py
│   ├── eval.py
│   ├── sample_ldm_discrete.py
│   ├── eval_ldm_discrete.py
│   ├── utils.py
│   ├── sde.py
│   └── train_router_discrete.py
├── DiT
│   ├── requirement.txt
│   ├── assets
│   │   └── dit.gif
│   ├── .gitignore
│   ├── ckpt
│   │   ├── DDIM20_router.pt
│   │   └── DDIM50_router.pt
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── respace.py
│   │   └── timestep_sampler.py
│   ├── download.py
│   ├── README.md
│   ├── sample.py
│   ├── sample_ddp.py
│   ├── models
│   │   └── models.py
│   └── train_router.py
├── assets
│   ├── teaser.png
│   ├── dit_baseline.png
│   └── uvit_baseline.png
└── README.md
/U-ViT/assets/fid_stats:
--------------------------------------------------------------------------------
1 | ../../../U-ViT/assets/fid_stats
--------------------------------------------------------------------------------
/U-ViT/libs/__init__.py:
--------------------------------------------------------------------------------
1 | # codes from third party
2 |
--------------------------------------------------------------------------------
/U-ViT/assets/stable-diffusion:
--------------------------------------------------------------------------------
1 | ../../../U-ViT/assets/stable-diffusion
--------------------------------------------------------------------------------
/DiT/requirement.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | timm
4 | diffusers
5 | accelerate
--------------------------------------------------------------------------------
/U-ViT/u-vit.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/U-ViT/u-vit.gif
--------------------------------------------------------------------------------
/DiT/assets/dit.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/DiT/assets/dit.gif
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/DiT/.gitignore:
--------------------------------------------------------------------------------
1 | pretrained_models/
2 | *.png
3 | __pycache__
4 | *.pb
5 | samples/
6 | results/
7 | wandb/
--------------------------------------------------------------------------------
/assets/dit_baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/assets/dit_baseline.png
--------------------------------------------------------------------------------
/DiT/ckpt/DDIM20_router.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/DiT/ckpt/DDIM20_router.pt
--------------------------------------------------------------------------------
/DiT/ckpt/DDIM50_router.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/DiT/ckpt/DDIM50_router.pt
--------------------------------------------------------------------------------
/assets/uvit_baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/assets/uvit_baseline.png
--------------------------------------------------------------------------------
/U-ViT/ckpt/dpm20_router.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/U-ViT/ckpt/dpm20_router.pth
--------------------------------------------------------------------------------
/U-ViT/ckpt/dpm50_router.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/U-ViT/ckpt/dpm50_router.pth
--------------------------------------------------------------------------------
/U-ViT/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | assets/fid_stats/*.npz
3 | assets/stable-diffusion/*.pth
4 | imagenet256_uvit_huge.pth
5 | samples
6 | *.png
7 | workdir
--------------------------------------------------------------------------------
/U-ViT/fid.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import sys
3 |
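# Usage (per the argument handling below): python fid.py <sample_npz_path> <256|512>
# FID is computed against the matching reference statistics in assets/fid_stats/.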
4 | if __name__ == '__main__':
5 | sample_npz_path = sys.argv[1]
6 | res = sys.argv[2]
7 |
8 | if res == '256':
9 | ref_path = 'assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz'
10 | elif res == '512':
11 | ref_path = 'assets/fid_stats/fid_stats_imagenet512_guided_diffusion.npz'
12 | else:
13 | raise NotImplementedError
14 | fid_value = calculate_fid_given_paths([ref_path, sample_npz_path], batch_size=1000)
15 | print(fid_value)
16 |
17 |
--------------------------------------------------------------------------------
/U-ViT/libs/clip.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import CLIPTokenizer, CLIPTextModel
3 |
4 |
5 | class AbstractEncoder(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def encode(self, *args, **kwargs):
10 | raise NotImplementedError
11 |
12 |
13 | class FrozenCLIPEmbedder(AbstractEncoder):
14 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
16 | super().__init__()
17 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
18 | self.transformer = CLIPTextModel.from_pretrained(version)
19 | self.device = device
20 | self.max_length = max_length
21 | self.freeze()
22 |
23 | def freeze(self):
24 | self.transformer = self.transformer.eval()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def forward(self, text):
29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
31 | tokens = batch_encoding["input_ids"].to(self.device)
32 | outputs = self.transformer(input_ids=tokens)
33 |
34 | z = outputs.last_hidden_state
35 | return z
36 |
37 | def encode(self, text):
38 | return self(text)
39 |
--------------------------------------------------------------------------------
/DiT/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
--------------------------------------------------------------------------------
/U-ViT/configs/imagenet256_uvit_huge.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.train = d(
21 | n_steps=500000,
22 | batch_size=32,
23 | mode='cond',
24 | log_interval=10,
25 | eval_interval=5000,
26 | save_interval=50000,
27 | )
28 |
29 | config.optimizer = d(
30 | name='adamw',
31 | lr=0.0002,
32 | weight_decay=0.03,
33 | betas=(0.99, 0.99),
34 | )
35 |
36 | config.lr_scheduler = d(
37 | name='customized',
38 | warmup_steps=5000
39 | )
40 |
41 | config.nnet = d(
42 | name='uvit',
43 | img_size=32,
44 | patch_size=2,
45 | in_chans=4,
46 | embed_dim=1152,
47 | depth=28,
48 | num_heads=16,
49 | mlp_ratio=4,
50 | qkv_bias=False,
51 | mlp_time_embed=False,
52 | num_classes=1001,
53 | use_checkpoint=True,
54 | conv=False
55 | )
56 |
57 | config.dataset = d(
58 | name='imagenet256_features',
59 | path='assets/datasets/imagenet256_features',
60 | cfg=True,
61 | p_uncond=0.1
62 | )
63 |
64 | config.sample = d(
65 | n_samples=50000,
66 | mini_batch_size=50, # the decoder is large
67 | algorithm='dpm_solver',
68 | cfg=True,
69 | scale=0.4,
70 | path=''
71 | )
72 |
73 | return config
74 |
--------------------------------------------------------------------------------
/U-ViT/configs/imagenet256_uvit_huge_dynamic_cache.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.train = d(
21 | n_steps=500000,
22 | batch_size=1024,
23 | mode='cond',
24 | log_interval=10,
25 | eval_interval=5000,
26 | save_interval=50000,
27 | )
28 |
29 | config.optimizer = d(
30 | name='adamw',
31 | lr=0.0002,
32 | weight_decay=0.03,
33 | betas=(0.99, 0.99),
34 | )
35 |
36 | config.lr_scheduler = d(
37 | name='customized',
38 | warmup_steps=5000
39 | )
40 |
41 | config.nnet = d(
42 | name='uvit_dynamic',
43 | img_size=32,
44 | patch_size=2,
45 | in_chans=4,
46 | embed_dim=1152,
47 | depth=28,
48 | num_heads=16,
49 | mlp_ratio=4,
50 | qkv_bias=False,
51 | mlp_time_embed=False,
52 | num_classes=1001,
53 | use_checkpoint=True,
54 | conv=False
55 | )
56 |
57 | config.dataset = d(
58 | name='imagenet256_features',
59 | path='assets/datasets/imagenet256_features',
60 | cfg=True,
61 | p_uncond=0.1
62 | )
63 |
64 | config.sample = d(
65 | n_samples=50000,
66 | mini_batch_size=50, # the decoder is large
67 | algorithm='dpm_solver',
68 | cfg=True,
69 | scale=0.4,
70 | path='',
71 | dynamic=True
72 | )
73 |
74 | return config
75 |
--------------------------------------------------------------------------------
/U-ViT/configs/imagenet256_uvit_huge_router.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.nnet_path='imagenet256_uvit_huge.pth'
17 | config.autoencoder = d(
18 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
19 | )
20 |
21 | config.train = d(
22 | n_steps=40000,
23 | batch_size=64,
24 | mode='cond',
25 | log_interval=100,
26 | eval_interval=5000,
27 | save_interval=1000,
28 | )
29 |
30 | config.optimizer = d(
31 | name='adamw',
32 | lr=0.0002,
33 | weight_decay=0.03,
34 | betas=(0.99, 0.99),
35 | )
36 |
37 | config.lr_scheduler = d(
38 | name='customized',
39 | warmup_steps=5000
40 | )
41 |
42 | config.nnet = d(
43 | name='uvit_router',
44 | img_size=32,
45 | patch_size=2,
46 | in_chans=4,
47 | embed_dim=1152,
48 | depth=28,
49 | num_heads=16,
50 | mlp_ratio=4,
51 | qkv_bias=False,
52 | mlp_time_embed=False,
53 | num_classes=1001,
54 | use_checkpoint=True,
55 | conv=False
56 | )
57 |
58 | config.dataset = d(
59 | name='imagenet',
60 | path='PATH_TO_IMAGENET',
61 | resolution=256,
62 | cfg=True,
63 | p_uncond=0.1
64 | )
65 |
66 | config.sample = d(
67 | n_samples=50000,
68 | mini_batch_size=50, # the decoder is large
69 | algorithm='dpm_solver',
70 | cfg=True,
71 | scale=0.4,
72 | path=''
73 | )
74 |
75 | return config
76 |
--------------------------------------------------------------------------------
/DiT/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Functions for downloading pre-trained DiT models
9 | """
10 | from torchvision.datasets.utils import download_url
11 | import torch
12 | import os
13 |
14 |
15 | pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
16 |
17 |
18 | def find_model(model_name):
19 | """
20 | Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
21 | """
22 | if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
23 | return download_model(model_name)
24 | else: # Load a custom DiT checkpoint:
25 | assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
26 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
27 | if "ema" in checkpoint: # supports checkpoints from train.py
28 | checkpoint = checkpoint["ema"]
29 | return checkpoint
30 |
31 |
32 | def download_model(model_name):
33 | """
34 | Downloads a pre-trained DiT model from the web.
35 | """
36 | assert model_name in pretrained_models
37 | local_path = f'pretrained_models/{model_name}'
38 | if not os.path.isfile(local_path):
39 | os.makedirs('pretrained_models', exist_ok=True)
40 | web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
41 | download_url(web_path, 'pretrained_models')
42 | model = torch.load(local_path, map_location=lambda storage, loc: storage)
43 | return model
44 |
45 |
46 | if __name__ == "__main__":
47 | # Download all DiT checkpoints
48 | for model in pretrained_models:
49 | download_model(model)
50 | print('Done.')
51 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching
2 |
3 |
[image: assets/teaser.png]
4 |
5 |
6 | (Results on DiT-XL/2 and U-ViT-H/2)
7 |
8 |
9 |
10 |
11 | > **Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching** 🥯[[Arxiv]](https://arxiv.org/abs/2406.01733)
12 | > [Xinyin Ma](https://horseee.github.io/), [Gongfan Fang](https://fangggf.github.io/), [Michael Bi Mi](), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
13 | > [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore, Huawei Technologies Ltd
14 |
15 |
16 |
17 |
18 | ## Introduction
19 | We introduce a novel scheme, named **L**earning-to-**C**ache (L2C), that learns to conduct caching in a dynamic manner for diffusion transformers. A router is optimized to decide which layers to cache at each denoising step.
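To make the idea concrete, the sketch below is a minimal, illustrative example (not the code in this repository) of how a learnable per-step, per-layer router can choose between recomputing a transformer layer and reusing its cached output from an earlier step. The names `CachedBlock`, `denoise_step`, `router`, and `thres` are made up for illustration:

```python
import torch
import torch.nn as nn

class CachedBlock(nn.Module):
    """Wraps one transformer block and remembers its last computed output."""
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.cached = None  # output from the most recent step at which this layer ran

    def forward(self, x, compute: bool):
        if compute or self.cached is None:
            self.cached = self.block(x)  # recompute and refresh the cache
        return self.cached                # otherwise reuse the cached output

def denoise_step(blocks, router, step, x, thres=0.1):
    # router: learnable tensor of shape [num_steps, num_layers]; one possible
    # convention is to recompute layer i at step t only when
    # sigmoid(router[step, i]) > thres, and to reuse the cache otherwise.
    for i, block in enumerate(blocks):
        compute = torch.sigmoid(router[step, i]).item() > thres
        x = x + block(x, compute=compute)  # residual around the cached branch (illustrative)
    return x
```

In a sketch of this form, only the router (a small `num_steps × num_layers` table) would be trained while the diffusion transformer stays frozen, which matches the takeaway below that no model parameters are updated.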
20 |
21 |
22 |
[image: U-ViT/u-vit.gif]
23 |
24 |
25 | (Changes in the router for U-ViT during optimization, shown across layers (x-axis) and steps (y-axis). White indicates that a layer is activated; black indicates that it is disabled.)
26 |
27 |
28 |
29 |
30 | **Some takeaways**:
31 |
32 | 1. A large proportion of layers in the diffusion transformer can be removed without updating the model parameters.
33 |    - In U-ViT-H/2, up to 93.68% of the layers in the cache steps (46.84% of layers over all steps) can be removed, with less than a 0.01 change in FID.
34 |
35 | 2. L2C significantly outperforms fast samplers such as DDIM and DPM-Solver.
36 |
37 |
38 |
[image: assets/dit_baseline.png]
39 |
[image: assets/uvit_baseline.png]
40 |
41 |
42 | (Comparison with Baselines. Left: DiT-XL/2. Right: U-ViT-H/2)
43 |
44 |
45 |
46 | ## Checkpoint for Routers
47 | | Model | NFE | Checkpoint |
48 | | -- | -- | -- |
49 | | DiT-XL/2 | 50 | [link](DiT/ckpt/DDIM50_router.pt) |
50 | | DiT-XL/2 | 20 | [link](DiT/ckpt/DDIM20_router.pt) |
51 | | U-ViT-H/2 | 50 | [link](U-ViT/ckpt/dpm50_router.pth) |
52 | | U-ViT-H/2 | 20 | [link](U-ViT/ckpt/dpm20_router.pth)|
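These router checkpoints are small files that are loaded alongside the frozen, pre-trained diffusion model. As one example, in the DiT code path `DiT/sample.py` attaches a router roughly as in the sketch below; the checkpoint path, step count, and threshold are placeholders here:

```python
# Simplified from DiT/sample.py: attach a trained router to the frozen DiT model.
from diffusion import create_diffusion
from models.dynamic_models import DiT_models

num_sampling_steps, thres = 50, 0.1
diffusion = create_diffusion(str(num_sampling_steps))
model = DiT_models["DiT-XL/2"](input_size=32, num_classes=1000)  # latent size 32 for 256x256
model.load_ranking("ckpt/DDIM50_router.pt", num_sampling_steps, diffusion.timestep_map, thres)
```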
53 |
54 | ## Code
55 | We implement Learning-to-Cache on two base architectures, DiT and U-ViT. See the instructions below:
56 |
57 | 1. DiT: [README](https://github.com/horseee/learning-to-cache/tree/main/DiT#learning-to-cache-for-dit)
58 | 2. U-ViT: [README](https://github.com/horseee/learning-to-cache/blob/main/U-ViT/readme.md)
59 |
60 | ## Citation
61 | ```
62 | @misc{ma2024learningtocache,
63 | title={Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching},
64 | author={Xinyin Ma and Gongfan Fang and Michael Bi Mi and Xinchao Wang},
65 | year={2024},
66 | eprint={2406.01733},
67 | archivePrefix={arXiv},
68 | primaryClass={cs.LG}
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/U-ViT/readme.md:
--------------------------------------------------------------------------------
1 |
2 | ## Preparation
3 |
4 | Please follow [U-ViT](https://github.com/baofff/U-ViT) to:
5 | 1. Prepare the environment and install the necessary packages.
6 | 2. Download the autoencoder and the FID reference statistics into `assets/`.
7 | 3. Download the model [ImageNet 256x256 (U-ViT-H/2)](https://drive.google.com/file/d/13StUdrjaaSXjfqqF7M47BzPyhMAArQ4u/view?usp=share_link) and put it in this directory.
8 |
9 | After completing the above steps, the directory should contain the following files:
10 | ```
11 | - imagenet256_uvit_huge.pth
12 | - assets
13 |   | - fid_stats
14 |   |   | - fid_stats_imagenet256_guided_diffusion.npz
15 |   |   | - ...
16 |   | - stable-diffusion
17 |   |   | - autoencoder_kl_ema.pth
18 |   |   | - autoencoder_kl.pth
19 | ```
20 |
21 | ## Sample Images
22 | For 20 NFEs in DPM-Solver:
23 | ```bash
24 | python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path imagenet256_uvit_huge.pth --nfe 20 --router ckpt/dpm20_router.pth --thres 0.9
25 | ```
26 |
27 | For 50 NFEs in DPM-Solver:
28 | ```bash
29 | python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path imagenet256_uvit_huge.pth --nfe 50 --router ckpt/dpm50_router.pth --thres 0.9
30 | ```
31 |
32 | The code repeats the generation 5 times to reduce fluctuations in the measured inference time. To generate images without acceleration, use the following command:
33 |
34 | ```bash
35 | python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge.py --nnet_path imagenet256_uvit_huge.pth --nfe 50
36 | ```
37 |
38 | ## Sample 50k Images for Evaluation
39 |
40 | ```bash
41 | export NFE=50
42 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm_discrete.py --config=configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path=imagenet256_uvit_huge.pth --config.sample.path=samples/dpm${NFE}_router --nfe=$NFE --router ckpt/dpm${NFE}_router.pth --thres 0.9
43 | ```
44 |
45 | FID is evaluated automatically once all images have been sampled. Make sure `NFE` and the router checkpoint path (`ckpt/dpm${NFE}_router.pth`) match the desired number of sampling steps.
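If you already have samples and only want to (re)compute FID, `fid.py` in this directory wraps `tools/fid_score.py` and takes the sample path and the resolution as positional arguments; `PATH_TO_SAMPLES` below is a placeholder for your sampled npz file or folder:

```bash
python fid.py PATH_TO_SAMPLES 256
```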
46 |
47 | Results:
48 |
49 | | NFE | Router | FID |
50 | | -- | -- | -- |
51 | | 50 | - | 2.3728 |
52 | | 50 | ckpt/dpm50_router.pth | 2.3625 |
53 | | 20 | - | 2.5739 |
54 | | 20 | ckpt/dpm20_router.pth | 2.5809|
55 |
56 |
57 | ## Train the router
58 | Execute the following command to train the router:
59 | ```
60 | accelerate launch --multi_gpu --main_process_port 18100 --num_processes 8 --mixed_precision fp16 train_router_discrete.py --config=configs/imagenet256_uvit_huge_router.py --config.dataset.path=PATH_TO_IMAGENET --nnet_path=imagenet256_uvit_huge.pth --nfe=20 --router_lr=0.001 --l1_weight=0.1 --workdir=workdir/uvit_router_l1_0.1
61 | ```
62 | Change `PATH_TO_IMAGENET` to the path of your ImageNet dataset.
63 |
64 |
65 |
[image: u-vit.gif]
66 |
67 |
68 | (Changes in the router during training)
69 |
70 |
71 |
72 |
73 | ## Acknowledgement
74 | This implementation is based on [U-ViT](https://github.com/baofff/U-ViT).
--------------------------------------------------------------------------------
/DiT/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/DiT/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Learning-to-Cache for DiT
3 |
4 | ## Requirement
5 | With PyTorch (>2.0) installed, execute the following command to install the necessary packages:
6 | ```
7 | pip install accelerate diffusers timm torchvision wandb
8 | ```
9 |
10 | ## Sample Image
11 | For DDIM-20:
12 | ```
13 | python sample.py --model DiT-XL/2 --num-sampling-steps 20 --ddim-sample --accelerate-method dynamiclayer --path ckpt/DDIM20_router.pt --thres 0.1
14 | ```
15 |
16 | For DDIM-50:
17 | ```
18 | python sample.py --model DiT-XL/2 --num-sampling-steps 50 --ddim-sample --accelerate-method dynamiclayer --path ckpt/DDIM50_router.pt --thres 0.1
19 | ```
20 | The code repeats the generation several times (the first run is treated as warm-up) to reduce fluctuations in the measured inference time. To generate images without acceleration, use the following command:
21 | ```
22 | python sample.py --model DiT-XL/2 --num-sampling-steps 20 --ddim-sample
23 | ```
24 |
25 | ## Sample 50k images for Evaluation
26 | If you want to reproduce the FID results from the paper, you can use the following command to sample 50k images:
27 | ```
28 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 12345 sample_ddp.py --model DiT-XL/2 --num-sampling-steps NUM_STEPS --ddim-sample --accelerate-method dynamiclayer --path PATH_TO_TRAINED_ROUTER --thres 0.1
29 | ```
30 | Be sure to replace NUM_STEPS and PATH_TO_TRAINED_ROUTER with the desired number of sampling steps and the path to the trained router.
31 |
32 | ## Calculate FID
33 | Following DiT, we evaluate FID with [the evaluation code from guided-diffusion](https://github.com/openai/guided-diffusion/tree/main/evaluations). Please install the required packages, download the pre-computed reference batches, and then run the following command:
34 | ```
35 | python evaluator.py ~/ckpt/VIRTUAL_imagenet256_labeled.npz PATH_TO_NPZ
36 | ```
37 |
38 | Results:
39 |
40 | | NFE | Router | IS | sFID | FID | Precision | Recall | Latency |
41 | | -- | -- | -- | -- | -- | -- | -- | -- |
42 | | 50 | - | 238.64 | 2.264 | 4.290 | 80.16 | 59.89 | 7.245±0.029 |
43 | | 50 | ckpt/DDIM50_router.pt | 244.14 | 2.269| 4.226| 80.91| 58.80 | 5.568±0.017 |
44 | | 20 | - | 223.49 | 3.484 | 4.892 | 78.76 | 57.07 | 2.869±0.008 |
45 | | 20 | ckpt/DDIM20_router.pt | 227.04 | 3.455| 4.644| 79.16| 55.58 | 2.261±0.005 |
46 |
47 |
48 | ## Training
49 | Here is the command for training the router. Make sure to change PATH_TO_IMAGENET_TRAIN to the path of your ImageNet training set.
50 | ```
51 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 12345 train_router.py --model DiT-XL/2 --data-path PATH_TO_IMAGENET_TRAIN --global-batch-size 64 --image-size 256 --ckpt-every 1000 --l1 5e-6 --lr 0.001 --wandb
52 | ```
53 | The router checkpoint will be saved in `results/XXX-DiT-XL-2/checkpoints`. You can also monitor how the router changes during training on wandb.
54 |
55 |
56 |
[image: assets/dit.gif]
57 |
58 |
59 | (Changes in the router during training)
60 |
61 |
62 |
63 | * Hyperparameters for training the routers (an example 512-resolution sampling command follows the table):
64 |
65 | | Model | DiT-XL/2 | DiT-XL/2 | DiT-XL/2 | DiT-XL/2 | DiT-L/2 | DiT-L/2 |
66 | | -- | -- | -- | -- | -- | -- | -- |
67 | | NFE | 50 | 20 | 10 | 50 | 50 | 20 |
68 | | Resolution | 256 | 256 | 256 | 512 | 256 | 256 |
69 | | - For Training | | | | | | |
70 | | \lambda (--l1) | 1e-6 | 5e-6 | 1e-6 | 5e-6 | 1e-6 | 5e-6 |
71 | | learning rate (--lr) | 1e-3 | 1e-3 | 1e-3 | 1e-3 | 1e-3 | 1e-2 |
72 | | - For Inference | | | | | | |
73 | | \theta (--thres) | 0.1 | 0.1 | 0.1 | 0.9 | 0.1 | 0.1 |
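
For example, following the 512-resolution column of the table, sampling with a DiT-XL/2 router trained at 512x512 would use a higher inference threshold; the command mirrors the ones above (PATH_TO_512_ROUTER is a placeholder, since no 512 router checkpoint is shipped in `ckpt/`):
```
python sample.py --model DiT-XL/2 --image-size 512 --num-sampling-steps 50 --ddim-sample --accelerate-method dynamiclayer --path PATH_TO_512_ROUTER --thres 0.9
```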
74 |
75 |
76 |
77 |
78 | ## Acknowledgement
79 | This implementation is based on [DiT](https://github.com/facebookresearch/DiT).
80 |
81 |
--------------------------------------------------------------------------------
/U-ViT/tools/read_npz.py:
--------------------------------------------------------------------------------
1 | import io
2 | import zipfile
3 | import numpy as np
4 | from contextlib import contextmanager
5 | from abc import ABC, abstractmethod
6 | from typing import Iterable, Optional
7 |
8 |
9 |
10 | class NpzArrayReader(ABC):
11 | @abstractmethod
12 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
13 | pass
14 |
15 | @abstractmethod
16 | def remaining(self) -> int:
17 | pass
18 |
19 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
20 | def gen_fn():
21 | while True:
22 | batch = self.read_batch(batch_size)
23 | if batch is None:
24 | break
25 | yield batch
26 |
27 | rem = self.remaining()
28 | num_batches = rem // batch_size + int(rem % batch_size != 0)
29 | return BatchIterator(gen_fn, num_batches)
30 |
31 | class StreamingNpzArrayReader(NpzArrayReader):
32 | def __init__(self, arr_f, shape, dtype):
33 | self.arr_f = arr_f
34 | self.shape = shape
35 | self.dtype = dtype
36 | self.idx = 0
37 |
38 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
39 | if self.idx >= self.shape[0]:
40 | return None
41 |
42 | bs = min(batch_size, self.shape[0] - self.idx)
43 | self.idx += bs
44 |
45 | if self.dtype.itemsize == 0:
46 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
47 |
48 | read_count = bs * np.prod(self.shape[1:])
49 | read_size = int(read_count * self.dtype.itemsize)
50 | data = _read_bytes(self.arr_f, read_size, "array data")
51 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
52 |
53 | def remaining(self) -> int:
54 | return max(0, self.shape[0] - self.idx)
55 |
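# NOTE: open_npz_array below falls back to MemoryNpzArrayReader, which is not defined in
# this file. The class below is a minimal reconstruction (following the guided-diffusion
# evaluator this script appears to be adapted from); treat it as an assumed helper that
# simply loads the whole named array into memory.
class MemoryNpzArrayReader(NpzArrayReader):
    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        # Load the named array from the npz file fully into memory.
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None
        res = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return res

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)
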
56 | class BatchIterator:
57 | def __init__(self, gen_fn, length):
58 | self.gen_fn = gen_fn
59 | self.length = length
60 |
61 | def __len__(self):
62 | return self.length
63 |
64 | def __iter__(self):
65 | return self.gen_fn()
66 |
67 | @contextmanager
68 | def open_npz_array(path: str, arr_name: str):
69 | with _open_npy_file(path, arr_name) as arr_f:
70 | version = np.lib.format.read_magic(arr_f)
71 | if version == (1, 0):
72 | header = np.lib.format.read_array_header_1_0(arr_f)
73 | elif version == (2, 0):
74 | header = np.lib.format.read_array_header_2_0(arr_f)
75 | else:
76 | yield MemoryNpzArrayReader.load(path, arr_name)
77 | return
78 | print(header)
79 | shape, fortran, dtype = header
80 | if fortran or dtype.hasobject:
81 | yield MemoryNpzArrayReader.load(path, arr_name)
82 | else:
83 | yield StreamingNpzArrayReader(arr_f, shape, dtype)
84 |
85 | @contextmanager
86 | def _open_npy_file(path: str, arr_name: str):
87 | with open(path, "rb") as f:
88 | with zipfile.ZipFile(f, "r") as zip_f:
89 | if f"{arr_name}.npy" not in zip_f.namelist():
90 | raise ValueError(f"missing {arr_name} in npz file")
91 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
92 | yield arr_f
93 |
94 | def _read_bytes(fp, size, error_template="ran out of data"):
95 | """
96 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
97 |
98 | Read from file-like object until size bytes are read.
99 | Raises ValueError if not EOF is encountered before size bytes are read.
100 | Non-blocking objects only supported if they derive from io objects.
101 | Required as e.g. ZipExtFile in python 2.6 can return less data than
102 | requested.
103 | """
104 | data = bytes()
105 | while True:
106 | # io files (default in python3) return None or raise on
107 | # would-block, python2 file will truncate, probably nothing can be
108 | # done about that. note that regular files can't be non-blocking
109 | try:
110 | r = fp.read(size - len(data))
111 | data += r
112 | if len(r) == 0 or len(data) == size:
113 | break
114 | except io.BlockingIOError:
115 | pass
116 | if len(data) != size:
117 | msg = "EOF: reading %s, expected %d bytes got %d"
118 | raise ValueError(msg % (error_template, size, len(data)))
119 | else:
120 | return data
--------------------------------------------------------------------------------
/U-ViT/libs/timm.py:
--------------------------------------------------------------------------------
1 | # code from timm 0.3.2
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import warnings
6 |
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 |
65 | def drop_path(x, drop_prob: float = 0., training: bool = False):
66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67 |
68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72 | 'survival rate' as the argument.
73 |
74 | """
75 | if drop_prob == 0. or not training:
76 | return x
77 | keep_prob = 1 - drop_prob
78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80 | random_tensor.floor_() # binarize
81 | output = x.div(keep_prob) * random_tensor
82 | return output
83 |
84 |
85 | class DropPath(nn.Module):
86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87 | """
88 | def __init__(self, drop_prob=None):
89 | super(DropPath, self).__init__()
90 | self.drop_prob = drop_prob
91 |
92 | def forward(self, x):
93 | return drop_path(x, self.drop_prob, self.training)
94 |
95 |
96 | class Mlp(nn.Module):
97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
98 | super().__init__()
99 | out_features = out_features or in_features
100 | hidden_features = hidden_features or in_features
101 | self.fc1 = nn.Linear(in_features, hidden_features)
102 | self.act = act_layer()
103 | self.fc2 = nn.Linear(hidden_features, out_features)
104 | self.drop = nn.Dropout(drop)
105 |
106 | def forward(self, x):
107 | x = self.fc1(x)
108 | x = self.act(x)
109 | x = self.drop(x)
110 | x = self.fc2(x)
111 | x = self.drop(x)
112 | return x
113 |
--------------------------------------------------------------------------------
/U-ViT/eval.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | import sde
8 | from datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
11 | from absl import logging
12 | import builtins
13 |
14 |
15 | def evaluate(config):
16 | if config.get('benchmark', False):
17 | torch.backends.cudnn.benchmark = True
18 | torch.backends.cudnn.deterministic = False
19 |
20 | mp.set_start_method('spawn')
21 | accelerator = accelerate.Accelerator()
22 | device = accelerator.device
23 | accelerate.utils.set_seed(config.seed, device_specific=True)
24 | logging.info(f'Process {accelerator.process_index} using device: {device}')
25 |
26 | config.mixed_precision = accelerator.mixed_precision
27 | config = ml_collections.FrozenConfigDict(config)
28 | if accelerator.is_main_process:
29 | utils.set_logger(log_level='info', fname=config.output_path)
30 | else:
31 | utils.set_logger(log_level='error')
32 | builtins.print = lambda *args: None
33 |
34 | dataset = get_dataset(**config.dataset)
35 |
36 | nnet = utils.get_nnet(**config.nnet)
37 | nnet = accelerator.prepare(nnet)
38 | logging.info(f'load nnet from {config.nnet_path}')
39 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
40 | nnet.eval()
41 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
42 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
43 | def cfg_nnet(x, timesteps, y):
44 | _cond = nnet(x, timesteps, y=y)
45 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
46 | return _cond + config.sample.scale * (_cond - _uncond)
47 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
48 | else:
49 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
50 |
51 |
52 | logging.info(config.sample)
53 | assert os.path.exists(dataset.fid_stat)
54 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
55 |
56 | def sample_fn(_n_samples):
57 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
58 | if config.train.mode == 'uncond':
59 | kwargs = dict()
60 | elif config.train.mode == 'cond':
61 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
62 | else:
63 | raise NotImplementedError
64 |
65 | if config.sample.algorithm == 'euler_maruyama_sde':
66 | rsde = sde.ReverseSDE(score_model)
67 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
68 | elif config.sample.algorithm == 'euler_maruyama_ode':
69 | rsde = sde.ODE(score_model)
70 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
71 | elif config.sample.algorithm == 'dpm_solver':
72 | noise_schedule = NoiseScheduleVP(schedule='linear')
73 | model_fn = model_wrapper(
74 | score_model.noise_pred,
75 | noise_schedule,
76 | time_input_type='0',
77 | model_kwargs=kwargs
78 | )
79 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
80 | return dpm_solver.sample(
81 | x_init,
82 | steps=config.sample.sample_steps,
83 | eps=1e-4,
84 | adaptive_step_size=False,
85 | fast_version=True,
86 | )
87 | else:
88 | raise NotImplementedError
89 |
90 | with tempfile.TemporaryDirectory() as temp_path:
91 | path = config.sample.path or temp_path
92 | if accelerator.is_main_process:
93 | os.makedirs(path, exist_ok=True)
94 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
95 | if accelerator.is_main_process:
96 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
97 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
98 |
99 | from absl import flags
100 | from absl import app
101 | from ml_collections import config_flags
102 | import os
103 |
104 |
105 | FLAGS = flags.FLAGS
106 | config_flags.DEFINE_config_file(
107 | "config", None, "Training configuration.", lock_config=False)
108 | flags.mark_flags_as_required(["config"])
109 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
110 | flags.DEFINE_string("output_path", None, "The path to output log.")
111 |
112 |
113 | def main(argv):
114 | config = FLAGS.config
115 | config.nnet_path = FLAGS.nnet_path
116 | config.output_path = FLAGS.output_path
117 | evaluate(config)
118 |
119 |
120 | if __name__ == "__main__":
121 | app.run(main)
122 |
--------------------------------------------------------------------------------
/DiT/sample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Sample new images from a pre-trained DiT.
9 | """
10 | import torch
11 | torch.backends.cuda.matmul.allow_tf32 = True
12 | torch.backends.cudnn.allow_tf32 = True
13 | from torchvision.utils import save_image
14 | from diffusion import create_diffusion
15 | from diffusers.models import AutoencoderKL
16 | from download import find_model
17 |
18 | import argparse
19 | import numpy as np
20 |
21 |
22 | def main(args):
23 | # Setup PyTorch:
24 | torch.set_grad_enabled(False)
25 | device = "cuda" if torch.cuda.is_available() else "cpu"
26 |
27 | if args.ckpt is None:
28 | assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
29 | assert args.image_size in [256, 512]
30 | assert args.num_classes == 1000
31 |
32 |     # Initialize the diffusion process:
33 | diffusion = create_diffusion(str(args.num_sampling_steps))
34 |
35 | # Load model:
36 | latent_size = args.image_size // 8
37 | if args.accelerate_method is not None and args.accelerate_method == "dynamiclayer":
38 | from models.dynamic_models import DiT_models
39 | else:
40 | from models.models import DiT_models
41 |
42 | model = DiT_models[args.model](
43 | input_size=latent_size,
44 | num_classes=args.num_classes
45 | ).to(device)
46 |
47 | if args.accelerate_method is not None and 'dynamiclayer' in args.accelerate_method:
48 | model.load_ranking(args.path, args.num_sampling_steps, diffusion.timestep_map, args.thres)
49 |
50 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py:
51 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
52 | state_dict = find_model(ckpt_path)
53 | model.load_state_dict(state_dict)
54 | model.eval() # important!
55 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
56 |
57 | torch.manual_seed(args.seed)
58 | # Labels to condition the model with (feel free to change):
59 | class_labels = [207, 992, 387, 974, 142, 979, 417, 279]
60 |
61 | # Create sampling noise:
62 | n = len(class_labels)
63 | z = torch.randn(n, 4, latent_size, latent_size, device=device)
64 | y = torch.tensor(class_labels, device=device)
65 |
66 | # Setup classifier-free guidance:
67 | z = torch.cat([z, z], 0)
68 | y_null = torch.tensor([1000] * n, device=device)
69 | y = torch.cat([y, y_null], 0)
70 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
71 |
72 | # Sample images:
73 | import time
74 | times = []
75 | for _ in range(6):
76 | start_time = time.time()
77 | if args.p_sample:
78 | samples = diffusion.p_sample_loop(
79 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
80 | )
81 | elif args.ddim_sample:
82 | samples = diffusion.ddim_sample_loop(
83 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
84 | )
85 | times.append(time.time() - start_time)
86 | model.reset()
87 |
88 | if len(times) > 1:
89 | times = np.array(times[1:])
90 | print("Sampling time: {:.3f}±{:.3f}".format(np.mean(times), np.std(times)))
91 |
92 |
93 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
94 | samples = vae.decode(samples / 0.18215).sample
95 | save_image(samples, f"Sample_NFE{args.num_sampling_steps}_Method_{args.accelerate_method}.png", nrow=8, normalize=True, value_range=(-1, 1))
96 |
97 | if __name__ == "__main__":
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--model", type=str, default="DiT-XL/2")
100 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
101 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
102 | parser.add_argument("--num-classes", type=int, default=1000)
103 | parser.add_argument("--cfg-scale", type=float, default=4.0)
104 | parser.add_argument("--num-sampling-steps", type=int, default=250)
105 | parser.add_argument("--seed", type=int, default=0)
106 | parser.add_argument("--ckpt", type=str, default=None,
107 | help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
108 | parser.add_argument("--accelerate-method", type=str, default=None,
109 | help="Use the accelerated version of the model.")
110 |
111 | parser.add_argument("--ddim-sample", action="store_true", default=False,)
112 | parser.add_argument("--p-sample", action="store_true", default=False,)
113 |
114 | parser.add_argument("--path", type=str, default=None,
115 | help="Optional path to a router checkpoint")
116 | parser.add_argument("--thres", type=float, default=0.5)
117 |
118 | args = parser.parse_args()
119 | main(args)
120 |
--------------------------------------------------------------------------------
/DiT/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
/DiT/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
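152 | # Illustrative usage sketch (not part of the original module): weighting per-timestep
153 | # losses with a schedule sampler. It assumes a `diffusion` object exposing `num_timesteps`
154 | # and a `training_losses(model, x, t)` method returning a dict with a "loss" entry, as in
155 | # the OpenAI-style diffusion wrappers this file is adapted from.
156 | #
157 | #   sampler = create_named_schedule_sampler("loss-second-moment", diffusion)
158 | #   t, weights = sampler.sample(x.shape[0], x.device)
159 | #   losses = diffusion.training_losses(model, x, t)
160 | #   if isinstance(sampler, LossAwareSampler):
161 | #       sampler.update_with_local_losses(t, losses["loss"].detach())
162 | #   loss = (losses["loss"] * weights).mean()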
--------------------------------------------------------------------------------
/U-ViT/sample_ldm_discrete.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | from datasets import get_dataset
8 | import tempfile
9 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
10 | from absl import logging
11 | import builtins
12 | import libs.autoencoder
13 | from torchvision.utils import save_image
14 | import numpy as np
15 |
16 |
17 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
18 | _betas = (
19 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
20 | )
21 | return _betas.numpy()
22 |
23 |
24 | def evaluate(config):
25 | if config.get('benchmark', False):
26 | torch.backends.cudnn.benchmark = True
27 | torch.backends.cudnn.deterministic = False
28 |
29 | mp.set_start_method('spawn')
30 | accelerator = accelerate.Accelerator()
31 | device = accelerator.device
32 | accelerate.utils.set_seed(0, device_specific=True)
33 | logging.info(f'Process {accelerator.process_index} using device: {device}')
34 |
35 | config.mixed_precision = accelerator.mixed_precision
36 | config = ml_collections.FrozenConfigDict(config)
37 | if accelerator.is_main_process:
38 | utils.set_logger(log_level='info', fname=config.output_path)
39 | else:
40 | utils.set_logger(log_level='error')
41 | builtins.print = lambda *args: None
42 |
43 | dataset = get_dataset(**config.dataset)
44 |
45 | nnet = utils.get_nnet(**config.nnet)
46 | nnet = accelerator.prepare(nnet)
47 | logging.info(f'load nnet from {config.nnet_path}')
48 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
49 | nnet.eval()
50 |
51 | if 'dynamic' in config.sample:
52 | # Get Timestep Mapping
53 | t_0 = 1. / 1000
54 | t_T = 1.0
55 | order_value = 2
56 | N_steps = config.nfe // order_value
57 | timesteps = torch.linspace(t_T, t_0, N_steps + 1).cpu().numpy()
58 | #timesteps = timesteps.numpy()
59 | timestep_mapping = np.round(timesteps * 1000)
60 | #accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping)
61 |
62 | nnet.load_ranking(config.router, config.nfe, timestep_mapping, config.thres)
63 |
64 |
65 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
66 | autoencoder.to(device)
67 |
68 | @torch.cuda.amp.autocast()
69 | def encode(_batch):
70 | return autoencoder.encode(_batch)
71 |
72 | @torch.cuda.amp.autocast()
73 | def decode(_batch):
74 | return autoencoder.decode(_batch)
75 |
76 | def decode_large_batch(_batch):
77 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
78 | xs = []
79 | pt = 0
80 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
81 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
82 | pt += _decode_mini_batch_size
83 | xs.append(x)
84 | xs = torch.concat(xs, dim=0)
85 | assert xs.size(0) == _batch.size(0)
86 | return xs
87 |
88 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
89 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
90 | def cfg_nnet(x, timesteps, y):
91 | _cond = nnet(x, timesteps, y=y)
92 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
93 | return _cond + config.sample.scale * (_cond - _uncond)
94 | else:
95 | def cfg_nnet(x, timesteps, y):
96 | _cond = nnet(x, timesteps, y=y)
97 | return _cond
98 |
99 | logging.info(config.sample)
100 | assert os.path.exists(dataset.fid_stat)
101 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
102 |
103 | _betas = stable_diffusion_beta_schedule()
104 | N = len(_betas)
105 |
106 | def sample_z(_n_samples, _sample_steps, **kwargs):
107 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
108 |
109 | if config.sample.algorithm == 'dpm_solver':
110 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
111 |
112 | def model_fn(x, t_continuous):
113 | t = t_continuous * N
114 |
115 | eps_pre = cfg_nnet(x, t, **kwargs)
116 | return eps_pre
117 |
118 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
119 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1., order=2, method='singlestep')
120 |
121 | else:
122 | raise NotImplementedError
123 |
124 | return _z
125 |
126 | def sample_fn(_n_samples):
127 | class_labels = [207, 992, 387, 974, 142, 979, 417, 279]
128 |
129 | if config.train.mode == 'uncond':
130 | kwargs = dict()
131 | elif config.train.mode == 'cond':
132 | kwargs = dict(y=torch.tensor(class_labels, device=device))
133 | else:
134 | raise NotImplementedError
135 | _z = sample_z(_n_samples, _sample_steps=config.nfe, **kwargs)
136 | return decode_large_batch(_z)
137 |
138 | import time
139 | use_time = []
140 | if config.teaser:
141 | samples = sample_fn(8)
142 | samples = dataset.unpreprocess(samples)
143 | dynamic_flag = 'dynamic' in config.sample
144 | for idx, sample in enumerate(samples):
145 | save_image(sample, f"images/teaser_{dynamic_flag}_{idx}.png", nrow=1)
146 | else:
147 | logging.info("Start sampling")
148 | for _ in range(6):
149 | start_time = time.time()
150 | samples = sample_fn(8)
151 | use_time.append(time.time() - start_time)
152 | samples = dataset.unpreprocess(samples)
153 | nnet.reset()
154 |
155 | times = np.array(use_time[1:])
156 | logging.info("Sampling time: {:.2f}±{:.2f}".format(np.mean(times), np.std(times)))
157 | save_image(samples, "u-vit-H-2.png", nrow=8)
158 |
159 |
160 | from absl import flags
161 | from absl import app
162 | from ml_collections import config_flags
163 | import os
164 |
165 |
166 | FLAGS = flags.FLAGS
167 | config_flags.DEFINE_config_file(
168 | "config", None, "Training configuration.", lock_config=False)
169 | flags.mark_flags_as_required(["config"])
170 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
171 | flags.DEFINE_string("output_path", None, "The path to the output log.")
172 | flags.DEFINE_string("nfe", None, "Number of function evaluations (NFE).")
173 |
174 | flags.DEFINE_string("router", None, "Path to the router checkpoint.")
175 | flags.DEFINE_string("thres", "0", "Threshold for the router.")
176 |
177 | flags.DEFINE_boolean("teaser", False, "generate teaser image")
178 |
179 |
180 | def main(argv):
181 | config = FLAGS.config
182 | config.nnet_path = FLAGS.nnet_path
183 | config.output_path = FLAGS.output_path
184 | config.teaser = FLAGS.teaser
185 | config.nfe = int(FLAGS.nfe)
186 | config.router = FLAGS.router
187 | config.thres = float(FLAGS.thres)
188 | evaluate(config)
189 |
190 |
191 | if __name__ == "__main__":
192 | app.run(main)
193 |
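194 | # Illustrative invocation (all values below are placeholders; adjust to your setup):
195 | #   python sample_ldm_discrete.py --config=<config.py> --nnet_path=<uvit.pth> \
196 | #       --nfe=<steps> --router=<router.pth> --thres=<threshold> --output_path=<log.txt>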
--------------------------------------------------------------------------------
/U-ViT/eval_ldm_discrete.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | import numpy as np
5 | from torch import multiprocessing as mp
6 | import accelerate
7 | import utils
8 | from datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
11 | from absl import logging
12 | import builtins
13 | import libs.autoencoder
14 |
15 |
16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
17 | _betas = (
18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
19 | )
20 | return _betas.numpy()
21 |
22 |
23 | def evaluate(config):
24 | if config.get('benchmark', False):
25 | torch.backends.cudnn.benchmark = True
26 | torch.backends.cudnn.deterministic = False
27 |
28 | mp.set_start_method('spawn')
29 | accelerator = accelerate.Accelerator()
30 | device = accelerator.device
31 | accelerate.utils.set_seed(config.seed, device_specific=True)
32 | logging.info(f'Process {accelerator.process_index} using device: {device}')
33 |
34 | config.mixed_precision = accelerator.mixed_precision
35 | config = ml_collections.FrozenConfigDict(config)
36 | if accelerator.is_main_process:
37 | utils.set_logger(log_level='info', fname=config.output_path)
38 | else:
39 | utils.set_logger(log_level='error')
40 | builtins.print = lambda *args: None
41 |
42 | dataset = get_dataset(**config.dataset)
43 |
44 | nnet = utils.get_nnet(**config.nnet)
45 | nnet = accelerator.prepare(nnet)
46 | logging.info(f'load nnet from {config.nnet_path}')
47 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
48 | nnet.eval()
49 |
50 | if 'dynamic' in config.sample:
51 | # Get Timestep Mapping
52 | t_0 = 1. / 1000
53 | t_T = 1.0
54 | order_value = 2
55 | N_steps = config.nfe // order_value
56 | timesteps = torch.linspace(t_T, t_0, N_steps + 1).cpu().numpy()
57 | #timesteps = timesteps.numpy()
58 | timestep_mapping = np.round(timesteps * 1000)
59 | #accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping)
60 |
61 | accelerator.unwrap_model(nnet).load_ranking(config.router, config.nfe, timestep_mapping, config.thres)
62 | elif 'rank' in config.sample:
63 | # Get Timestep Mapping
64 | t_0 = 1. / 1000
65 | t_T = 1.0
66 | order_value = 2
67 | N_steps = config.nfe // order_value
68 | timesteps = torch.linspace(t_T, t_0, N_steps + 1).cpu().numpy()
69 | #timesteps = timesteps.numpy()
70 | timestep_mapping = np.round(timesteps * 1000)
71 | #accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping)
72 |
73 | accelerator.unwrap_model(nnet).load_ranking(config.nfe, config.thres)
74 |
75 | elif 'topk' in config.sample or 'random' in config.sample:
76 | accelerator.unwrap_model(nnet).load_ranking(config.sample.topk, config.sample.reverse, config.sample.random)
77 |
78 |
79 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
80 | autoencoder.to(device)
81 |
82 | @torch.cuda.amp.autocast()
83 | def encode(_batch):
84 | return autoencoder.encode(_batch)
85 |
86 | @torch.cuda.amp.autocast()
87 | def decode(_batch):
88 | return autoencoder.decode(_batch)
89 |
90 | def decode_large_batch(_batch):
91 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
92 | xs = []
93 | pt = 0
94 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
95 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
96 | pt += _decode_mini_batch_size
97 | xs.append(x)
98 | xs = torch.concat(xs, dim=0)
99 | assert xs.size(0) == _batch.size(0)
100 | return xs
101 |
102 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
103 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
104 | def cfg_nnet(x, timesteps, y):
105 | _cond = nnet(x, timesteps, y=y)
106 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
107 | return _cond + config.sample.scale * (_cond - _uncond)
108 | else:
109 | def cfg_nnet(x, timesteps, y):
110 | _cond = nnet(x, timesteps, y=y)
111 | return _cond
112 |
113 | logging.info(config.sample)
114 | assert os.path.exists(dataset.fid_stat)
115 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
116 |
117 | _betas = stable_diffusion_beta_schedule()
118 | N = len(_betas)
119 |
120 | def sample_z(_n_samples, _sample_steps, **kwargs):
121 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
122 |
123 | if config.sample.algorithm == 'dpm_solver':
124 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
125 |
126 | def model_fn(x, t_continuous):
127 | t = t_continuous * N
128 | eps_pre = cfg_nnet(x, t, **kwargs)
129 | return eps_pre
130 |
131 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
132 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1., order=2, method='singlestep')
133 |
134 | else:
135 | raise NotImplementedError
136 |
137 | return _z
138 |
139 | def sample_fn(_n_samples):
140 | if config.train.mode == 'uncond':
141 | kwargs = dict()
142 | elif config.train.mode == 'cond':
143 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
144 | else:
145 | raise NotImplementedError
146 | _z = sample_z(_n_samples, _sample_steps=config.nfe, **kwargs)
147 | return decode_large_batch(_z)
148 |
149 | with tempfile.TemporaryDirectory() as temp_path:
150 | path = config.sample.path or temp_path
151 | #if accelerator.is_main_process:
152 | # os.makedirs(path, exist_ok=True)
153 | logging.info(f'Samples are saved in {path}')
154 | #utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
155 | utils.sample2npz(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, reset_fn=accelerator.unwrap_model(nnet).reset)
156 |
157 | if accelerator.is_main_process:
158 | torch.cuda.empty_cache()
159 | fid = calculate_fid_given_paths((dataset.fid_stat, f"{path}.npz"), batch_size=1000)
160 | log_path = path.replace('manual_samples/', 'log/')
161 | with open(f"{log_path}.log", "a") as f:
162 |             f.write(f"npz_path={path}.npz, fid={fid}\n")
163 | logging.info(f'npz_path={path}.npz, fid={fid}')
164 |
165 |
166 | from absl import flags
167 | from absl import app
168 | from ml_collections import config_flags
169 | import os
170 |
171 |
172 | FLAGS = flags.FLAGS
173 | config_flags.DEFINE_config_file(
174 | "config", None, "Training configuration.", lock_config=False)
175 | flags.mark_flags_as_required(["config"])
176 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
177 | flags.DEFINE_string("output_path", None, "The path to the output log.")
178 | flags.DEFINE_string("nfe", None, "Number of function evaluations (NFE).")
179 | flags.DEFINE_string("router", None, "Path to the router checkpoint.")
180 | flags.DEFINE_string("thres", "0", "Threshold for the router.")
181 |
182 |
183 | def main(argv):
184 | config = FLAGS.config
185 | config.nnet_path = FLAGS.nnet_path
186 | config.output_path = FLAGS.output_path
187 | config.nfe = int(FLAGS.nfe)
188 | config.thres = float(FLAGS.thres)
189 | config.router = FLAGS.router
190 |
191 | evaluate(config)
192 |
193 |
194 | if __name__ == "__main__":
195 | app.run(main)
196 |
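197 | # Added note: evaluation gathers all samples into a single <path>.npz via utils.sample2npz
198 | # and then computes FID against the reference statistics in dataset.fid_stat using
199 | # calculate_fid_given_paths((dataset.fid_stat, f"{path}.npz"), batch_size=1000).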
--------------------------------------------------------------------------------
/U-ViT/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 | from torchvision.utils import save_image
7 | from absl import logging
8 |
9 |
10 | def set_logger(log_level='info', fname=None):
11 | import logging as _logging
12 | handler = logging.get_absl_handler()
13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
14 | handler.setFormatter(formatter)
15 | logging.set_verbosity(log_level)
16 | if fname is not None:
17 | handler = _logging.FileHandler(fname)
18 | handler.setFormatter(formatter)
19 | logging.get_absl_logger().addHandler(handler)
20 |
21 |
22 | def dct2str(dct):
23 | return str({k: f'{v:.6g}' for k, v in dct.items()})
24 |
25 |
26 | def get_nnet(name, **kwargs):
27 | if name == 'uvit':
28 | from libs.uvit import UViT
29 | return UViT(**kwargs)
30 | elif name == 'uvit_t2i':
31 | from libs.uvit_t2i import UViT
32 | return UViT(**kwargs)
33 | elif name == 'uvit_timecache':
34 | from libs.uvit_timecache import UViT
35 | return UViT(**kwargs)
36 | elif name == 'uvit_router':
37 | from libs.uvit_router import UViT
38 | return UViT(**kwargs)
39 | elif name == 'uvit_dynamic':
40 | from libs.uvit_dynamic import UViT
41 | return UViT(**kwargs)
42 | elif name == 'uvit_manual':
43 | from libs.uvit_manual import UViT
44 | return UViT(**kwargs)
45 | elif name == 'uvit_deepcache':
46 | from libs.uvit_deepcache import UViT
47 | return UViT(**kwargs)
48 | elif name == 'uvit_fasterdiffusion':
49 | from libs.uvit_fasterdiffusion import UViT
50 | return UViT(**kwargs)
51 | elif name == 'uvit_analysis':
52 | from libs.uvit_analysis import UViT
53 | return UViT(**kwargs)
54 | elif name == 'uvit_ranklayer':
55 | from libs.uvit_ranklayer import UViT
56 | return UViT(**kwargs)
57 | else:
58 | raise NotImplementedError(name)
59 |
60 |
61 | def set_seed(seed: int):
62 | if seed is not None:
63 | torch.manual_seed(seed)
64 | np.random.seed(seed)
65 |
66 |
67 | def get_optimizer(params, name, **kwargs):
68 | if name == 'adam':
69 | from torch.optim import Adam
70 | return Adam(params, **kwargs)
71 | elif name == 'adamw':
72 | from torch.optim import AdamW
73 | return AdamW(params, **kwargs)
74 | else:
75 | raise NotImplementedError(name)
76 |
77 |
78 | def customized_lr_scheduler(optimizer, warmup_steps=-1):
79 | from torch.optim.lr_scheduler import LambdaLR
80 | def fn(step):
81 | if warmup_steps > 0:
82 | return min(step / warmup_steps, 1)
83 | else:
84 | return 1
85 | return LambdaLR(optimizer, fn)
86 |
87 |
88 | def get_lr_scheduler(optimizer, name, **kwargs):
89 | if name == 'customized':
90 | return customized_lr_scheduler(optimizer, **kwargs)
91 | elif name == 'cosine':
92 | from torch.optim.lr_scheduler import CosineAnnealingLR
93 | return CosineAnnealingLR(optimizer, **kwargs)
94 | else:
95 | raise NotImplementedError(name)
96 |
97 |
98 | def ema(model_dest: nn.Module, model_src: nn.Module, rate):
99 | param_dict_src = dict(model_src.named_parameters())
100 | for p_name, p_dest in model_dest.named_parameters():
101 | p_src = param_dict_src[p_name]
102 | assert p_src is not p_dest
103 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
104 |
105 |
106 | class TrainState(object):
107 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
108 | self.optimizer = optimizer
109 | self.lr_scheduler = lr_scheduler
110 | self.step = step
111 | self.nnet = nnet
112 | self.nnet_ema = nnet_ema
113 |
114 | #def ema_update(self, rate=0.9999):
115 | # if self.nnet_ema is not None:
116 | # ema(self.nnet_ema, self.nnet, rate)
117 |
118 | def save(self, path):
119 | os.makedirs(path, exist_ok=True)
120 | torch.save(self.step, os.path.join(path, 'step.pth'))
121 | for key, val in self.__dict__.items():
122 | if key != 'step' and 'ema' not in key and val is not None:
123 | if key == 'nnet':
124 | torch.save(val.routers.state_dict(), os.path.join(path, f'{key}.pth'))
125 | else:
126 | torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
127 |
128 | def load(self, path):
129 | logging.info(f'load from {path}')
130 | self.step = torch.load(os.path.join(path, 'step.pth'))
131 | for key, val in self.__dict__.items():
132 | if key != 'step' and val is not None:
133 | val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
134 |
135 | def resume(self, ckpt_root, step=None):
136 | if not os.path.exists(ckpt_root):
137 | return
138 | if step is None:
139 | ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
140 | if not ckpts:
141 | return
142 | steps = map(lambda x: int(x.split(".")[0]), ckpts)
143 | step = max(steps)
144 | ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
145 | logging.info(f'resume from {ckpt_path}')
146 | self.load(ckpt_path)
147 |
148 | def to(self, device):
149 | for key, val in self.__dict__.items():
150 | if isinstance(val, nn.Module):
151 | val.to(device)
152 |
153 | def update_optimizer(self, optimizer):
154 | self.optimizer = optimizer
155 |
156 |
157 | def cnt_params(model):
158 | return sum(param.numel() for param in model.parameters())
159 |
160 |
161 | def initialize_train_state(config, device):
162 | params = []
163 |
164 | nnet = get_nnet(**config.nnet)
165 | params += nnet.parameters()
166 | nnet_ema = get_nnet(**config.nnet)
167 | nnet_ema.eval()
168 | logging.info(f'nnet has {cnt_params(nnet)} parameters')
169 |
170 | optimizer = get_optimizer(params, **config.optimizer)
171 | lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
172 |
173 | train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
174 | nnet=nnet, nnet_ema=nnet_ema)
175 | #train_state.ema_update(0)
176 | train_state.to(device)
177 | return train_state
178 |
179 |
180 | def amortize(n_samples, batch_size):
181 | k = n_samples // batch_size
182 | r = n_samples % batch_size
183 | return k * [batch_size] if r == 0 else k * [batch_size] + [r]
184 |
185 |
186 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None):
187 | os.makedirs(path, exist_ok=True)
188 | idx = 0
189 | batch_size = mini_batch_size * accelerator.num_processes
190 |
191 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
192 | samples = unpreprocess_fn(sample_fn(mini_batch_size))
193 | samples = accelerator.gather(samples.contiguous())[:_batch_size]
194 | if accelerator.is_main_process:
195 | for sample in samples:
196 | save_image(sample, os.path.join(path, f"{idx}.png"))
197 | idx += 1
198 | accelerator.wait_for_everyone()
199 |
200 | def sample2npz(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, reset_fn=None):
201 | #os.makedirs(path, exist_ok=True)
202 | idx = 0
203 | batch_size = mini_batch_size * accelerator.num_processes
204 |
205 | all_images = []
206 |     for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2npz'):
207 | samples = unpreprocess_fn(sample_fn(mini_batch_size))
208 | samples = accelerator.gather(samples.contiguous())[:_batch_size]
209 | if accelerator.is_main_process:
210 | samples = samples.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
211 | all_images.append(samples)
212 |
213 | reset_fn()
214 | accelerator.wait_for_everyone()
215 |
216 |
217 | if accelerator.is_main_process:
218 | arr = np.concatenate(all_images, axis=0)
219 | arr = arr[: n_samples]
220 | out_path = f"{path}.npz"
221 |
222 | print(f"saving to {out_path}")
223 | np.savez(out_path, arr_0=arr)
224 |
225 |
226 | def grad_norm(model):
227 | total_norm = 0.
228 | for p in model.parameters():
229 | param_norm = p.grad.data.norm(2)
230 | total_norm += param_norm.item() ** 2
231 | total_norm = total_norm ** (1. / 2)
232 | return total_norm
233 |
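234 | # Illustrative note: amortize(n_samples, batch_size) splits the total into full batches plus
235 | # a remainder, e.g. amortize(23, 10) -> [10, 10, 3]; sample2dir and sample2npz iterate over
236 | # this list and trim the gathered samples so exactly n_samples images are kept overall.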
--------------------------------------------------------------------------------
/U-ViT/libs/uvit_t2i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 |     except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 |             raise NotImplementedError(f'unknown attention mode: {ATTENTION_MODE}')
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
141 | clip_dim=768, num_clip_token=77, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.in_chans = in_chans
145 |
146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
147 | num_patches = (img_size // patch_size) ** 2
148 |
149 | self.time_embed = nn.Sequential(
150 | nn.Linear(embed_dim, 4 * embed_dim),
151 | nn.SiLU(),
152 | nn.Linear(4 * embed_dim, embed_dim),
153 | ) if mlp_time_embed else nn.Identity()
154 |
155 | self.context_embed = nn.Linear(clip_dim, embed_dim)
156 |
157 | self.extras = 1 + num_clip_token
158 |
159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
160 |
161 | self.in_blocks = nn.ModuleList([
162 | Block(
163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
165 | for _ in range(depth // 2)])
166 |
167 | self.mid_block = Block(
168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
170 |
171 | self.out_blocks = nn.ModuleList([
172 | Block(
173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
175 | for _ in range(depth // 2)])
176 |
177 | self.norm = norm_layer(embed_dim)
178 | self.patch_dim = patch_size ** 2 * in_chans
179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
181 |
182 | trunc_normal_(self.pos_embed, std=.02)
183 | self.apply(self._init_weights)
184 |
185 | def _init_weights(self, m):
186 | if isinstance(m, nn.Linear):
187 | trunc_normal_(m.weight, std=.02)
188 | if isinstance(m, nn.Linear) and m.bias is not None:
189 | nn.init.constant_(m.bias, 0)
190 | elif isinstance(m, nn.LayerNorm):
191 | nn.init.constant_(m.bias, 0)
192 | nn.init.constant_(m.weight, 1.0)
193 |
194 | @torch.jit.ignore
195 | def no_weight_decay(self):
196 | return {'pos_embed'}
197 |
198 | def forward(self, x, timesteps, context):
199 | x = self.patch_embed(x)
200 | B, L, D = x.shape
201 |
202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
203 | time_token = time_token.unsqueeze(dim=1)
204 | context_token = self.context_embed(context)
205 | x = torch.cat((time_token, context_token, x), dim=1)
206 | x = x + self.pos_embed
207 |
208 | skips = []
209 | for blk in self.in_blocks:
210 | x = blk(x)
211 | skips.append(x)
212 |
213 | x = self.mid_block(x)
214 |
215 | for blk in self.out_blocks:
216 | x = blk(x, skips.pop())
217 |
218 | x = self.norm(x)
219 | x = self.decoder_pred(x)
220 | assert x.size(1) == self.extras + L
221 | x = x[:, self.extras:, :]
222 | x = unpatchify(x, self.in_chans)
223 | x = self.final_layer(x)
224 | return x
225 |
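226 | # Illustrative note: the token sequence here is [time token, num_clip_token (77 by default)
227 | # context tokens, image patch tokens], so self.extras = 1 + num_clip_token and the decoder
228 | # drops the first `extras` tokens before unpatchify reassembles the image.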
--------------------------------------------------------------------------------
/U-ViT/sde.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from absl import logging
4 | import numpy as np
5 | import math
6 | from tqdm import tqdm
7 |
8 |
9 | def get_sde(name, **kwargs):
10 | if name == 'vpsde':
11 | return VPSDE(**kwargs)
12 | elif name == 'vpsde_cosine':
13 | return VPSDECosine(**kwargs)
14 | else:
15 | raise NotImplementedError
16 |
17 |
18 | def stp(s, ts: torch.Tensor): # scalar tensor product
19 | if isinstance(s, np.ndarray):
20 | s = torch.from_numpy(s).type_as(ts)
21 | extra_dims = (1,) * (ts.dim() - 1)
22 | return s.view(-1, *extra_dims) * ts
23 |
24 |
25 | def mos(a, start_dim=1): # mean of square
26 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
27 |
28 |
29 | def duplicate(tensor, *size):
30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape)
31 |
32 |
33 | class SDE(object):
34 | r"""
35 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
36 | f(x, t) is the drift
37 | g(t) is the diffusion
38 | """
39 | def drift(self, x, t):
40 | raise NotImplementedError
41 |
42 | def diffusion(self, t):
43 | raise NotImplementedError
44 |
45 | def cum_beta(self, t): # the variance of xt|x0
46 | raise NotImplementedError
47 |
48 | def cum_alpha(self, t):
49 | raise NotImplementedError
50 |
51 | def snr(self, t): # signal noise ratio
52 | raise NotImplementedError
53 |
54 | def nsr(self, t): # noise signal ratio
55 | raise NotImplementedError
56 |
57 | def marginal_prob(self, x0, t): # the mean and std of q(xt|x0)
58 | alpha = self.cum_alpha(t)
59 | beta = self.cum_beta(t)
60 | mean = stp(alpha ** 0.5, x0) # E[xt|x0]
61 | std = beta ** 0.5 # Cov[xt|x0] ** 0.5
62 | return mean, std
63 |
64 | def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform
65 | t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init
66 | mean, std = self.marginal_prob(x0, t)
67 | eps = torch.randn_like(x0)
68 | xt = mean + stp(std, eps)
69 | return t, eps, xt
70 |
71 |
72 | class VPSDE(SDE):
73 | def __init__(self, beta_min=0.1, beta_max=20):
74 | # 0 <= t <= 1
75 | self.beta_0 = beta_min
76 | self.beta_1 = beta_max
77 |
78 | def drift(self, x, t):
79 | return -0.5 * stp(self.squared_diffusion(t), x)
80 |
81 | def diffusion(self, t):
82 | return self.squared_diffusion(t) ** 0.5
83 |
84 | def squared_diffusion(self, t): # beta(t)
85 | return self.beta_0 + t * (self.beta_1 - self.beta_0)
86 |
87 | def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau
88 | return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5
89 |
90 | def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I
91 | return 1. - self.skip_alpha(s, t)
92 |
93 | def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs
94 | x = -self.squared_diffusion_integral(s, t)
95 | return x.exp()
96 |
97 | def cum_beta(self, t):
98 | return self.skip_beta(0, t)
99 |
100 | def cum_alpha(self, t):
101 | return self.skip_alpha(0, t)
102 |
103 | def nsr(self, t):
104 | return self.squared_diffusion_integral(0, t).expm1()
105 |
106 | def snr(self, t):
107 | return 1. / self.nsr(t)
108 |
109 | def __str__(self):
110 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
111 |
112 | def __repr__(self):
113 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
114 |
115 |
116 | class VPSDECosine(SDE):
117 | r"""
118 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
119 | f(x, t) is the drift
120 | g(t) is the diffusion
121 | """
122 | def __init__(self, s=0.008):
123 | self.s = s
124 | self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
125 | self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
126 |
127 | def drift(self, x, t):
128 | ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2
129 | return stp(ft, x)
130 |
131 | def diffusion(self, t):
132 | return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5
133 |
134 | def cum_beta(self, t): # the variance of xt|x0
135 | return 1 - self.cum_alpha(t)
136 |
137 | def cum_alpha(self, t):
138 | return self.F(t) / self.F0
139 |
140 | def snr(self, t): # signal noise ratio
141 | Ft = self.F(t)
142 | return Ft / (self.F0 - Ft)
143 |
144 | def nsr(self, t): # noise signal ratio
145 | Ft = self.F(t)
146 | return self.F0 / Ft - 1
147 |
148 | def __str__(self):
149 | return 'vpsde_cosine'
150 |
151 | def __repr__(self):
152 | return 'vpsde_cosine'
153 |
154 |
155 | class ScoreModel(object):
156 | r"""
157 | The forward process is q(x_[0,T])
158 | """
159 |
160 | def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1):
161 | assert T == 1
162 | self.nnet = nnet
163 | self.pred = pred
164 | self.sde = sde
165 | self.T = T
166 | print(f'ScoreModel with pred={pred}, sde={sde}, T={T}')
167 |
168 | def predict(self, xt, t, **kwargs):
169 | if not isinstance(t, torch.Tensor):
170 | t = torch.tensor(t)
171 | t = t.to(xt.device)
172 | if t.dim() == 0:
173 | t = duplicate(t, xt.size(0))
174 | return self.nnet(xt, t * 999, **kwargs) # follow SDE
175 |
176 | def noise_pred(self, xt, t, **kwargs):
177 | pred = self.predict(xt, t, **kwargs)
178 | if self.pred == 'noise_pred':
179 | noise_pred = pred
180 | elif self.pred == 'x0_pred':
181 | noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt)
182 | else:
183 | raise NotImplementedError
184 | return noise_pred
185 |
186 | def x0_pred(self, xt, t, **kwargs):
187 | pred = self.predict(xt, t, **kwargs)
188 | if self.pred == 'noise_pred':
189 | x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred)
190 | elif self.pred == 'x0_pred':
191 | x0_pred = pred
192 | else:
193 | raise NotImplementedError
194 | return x0_pred
195 |
196 | def score(self, xt, t, **kwargs):
197 | cum_beta = self.sde.cum_beta(t)
198 | noise_pred = self.noise_pred(xt, t, **kwargs)
199 | return stp(-cum_beta.rsqrt(), noise_pred)
200 |
201 |
202 | class ReverseSDE(object):
203 | r"""
204 | dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw
205 | """
206 | def __init__(self, score_model):
207 | self.sde = score_model.sde # the forward sde
208 | self.score_model = score_model
209 |
210 | def drift(self, x, t, **kwargs):
211 | drift = self.sde.drift(x, t) # f(x, t)
212 | diffusion = self.sde.diffusion(t) # g(t)
213 | score = self.score_model.score(x, t, **kwargs)
214 | return drift - stp(diffusion ** 2, score)
215 |
216 | def diffusion(self, t):
217 | return self.sde.diffusion(t)
218 |
219 |
220 | class ODE(object):
221 | r"""
222 | dx = [f(x, t) - g(t)^2 s(x, t)] dt
223 | """
224 |
225 | def __init__(self, score_model):
226 | self.sde = score_model.sde # the forward sde
227 | self.score_model = score_model
228 |
229 | def drift(self, x, t, **kwargs):
230 | drift = self.sde.drift(x, t) # f(x, t)
231 | diffusion = self.sde.diffusion(t) # g(t)
232 | score = self.score_model.score(x, t, **kwargs)
233 | return drift - 0.5 * stp(diffusion ** 2, score)
234 |
235 | def diffusion(self, t):
236 | return 0
237 |
238 |
239 | def dct2str(dct):
240 | return str({k: f'{v:.6g}' for k, v in dct.items()})
241 |
242 |
243 | @ torch.no_grad()
244 | def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs):
245 | r"""
246 | The Euler Maruyama sampler for reverse SDE / ODE
247 | See `Score-Based Generative Modeling through Stochastic Differential Equations`
248 | """
249 | assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE)
250 | print(f"euler_maruyama with sample_steps={sample_steps}")
251 | timesteps = np.append(0., np.linspace(eps, T, sample_steps))
252 | timesteps = torch.tensor(timesteps).to(x_init)
253 | x = x_init
254 | if trace is not None:
255 | trace.append(x)
256 | for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'):
257 | drift = rsde.drift(x, t, **kwargs)
258 | diffusion = rsde.diffusion(t)
259 | dt = s - t
260 | mean = x + drift * dt
261 | sigma = diffusion * (-dt).sqrt()
262 | x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean
263 | if trace is not None:
264 | trace.append(x)
265 | statistics = dict(s=s, t=t, sigma=sigma.item())
266 | logging.debug(dct2str(statistics))
267 | return x
268 |
269 |
270 | def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs):
271 | t, noise, xt = score_model.sde.sample(x0)
272 | if pred == 'noise_pred':
273 | noise_pred = score_model.noise_pred(xt, t, **kwargs)
274 | return mos(noise - noise_pred)
275 | elif pred == 'x0_pred':
276 | x0_pred = score_model.x0_pred(xt, t, **kwargs)
277 | return mos(x0 - x0_pred)
278 | else:
279 | raise NotImplementedError(pred)
280 |
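281 | # Added note (follows from the VPSDE definitions above):
282 | #   cum_alpha(t) = exp(-\int_0^t beta(tau) d tau),  cum_beta(t) = 1 - cum_alpha(t),
283 | # so marginal_prob gives q(x_t | x_0) = N(sqrt(cum_alpha(t)) * x_0, cum_beta(t) * I),
284 | # i.e. the standard variance-preserving forward process.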
--------------------------------------------------------------------------------
/U-ViT/libs/uvit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 |     except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 |             raise NotImplementedError(f'unknown attention mode: {ATTENTION_MODE}')
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
141 | use_checkpoint=False, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.num_classes = num_classes
145 | self.in_chans = in_chans
146 |
147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
148 | num_patches = (img_size // patch_size) ** 2
149 |
150 | self.time_embed = nn.Sequential(
151 | nn.Linear(embed_dim, 4 * embed_dim),
152 | nn.SiLU(),
153 | nn.Linear(4 * embed_dim, embed_dim),
154 | ) if mlp_time_embed else nn.Identity()
155 |
156 | if self.num_classes > 0:
157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim)
158 | self.extras = 2
159 | else:
160 | self.extras = 1
161 |
162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
163 | self.in_blocks = nn.ModuleList([
164 | Block(
165 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
166 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
167 | for _ in range(depth // 2)])
168 |
169 | self.mid_block = Block(
170 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
171 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
172 |
173 | self.out_blocks = nn.ModuleList([
174 | Block(
175 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
176 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
177 | for _ in range(depth // 2)])
178 |
179 | self.norm = norm_layer(embed_dim)
180 | self.patch_dim = patch_size ** 2 * in_chans
181 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
182 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
183 |
184 | trunc_normal_(self.pos_embed, std=.02)
185 | self.apply(self._init_weights)
186 |
187 | def _init_weights(self, m):
188 | if isinstance(m, nn.Linear):
189 | trunc_normal_(m.weight, std=.02)
190 | if isinstance(m, nn.Linear) and m.bias is not None:
191 | nn.init.constant_(m.bias, 0)
192 | elif isinstance(m, nn.LayerNorm):
193 | nn.init.constant_(m.bias, 0)
194 | nn.init.constant_(m.weight, 1.0)
195 |
196 | @torch.jit.ignore
197 | def no_weight_decay(self):
198 | return {'pos_embed'}
199 |
200 | def reset(self):
201 | pass
202 |
203 | def forward(self, x, timesteps, y=None):
204 | #print(timesteps)
205 | x = self.patch_embed(x)
206 | B, L, D = x.shape
207 |
208 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
209 | time_token = time_token.unsqueeze(dim=1)
210 | x = torch.cat((time_token, x), dim=1)
211 | if y is not None:
212 | label_emb = self.label_emb(y)
213 | label_emb = label_emb.unsqueeze(dim=1)
214 | x = torch.cat((label_emb, x), dim=1)
215 | x = x + self.pos_embed
216 |
217 | skips = []
218 | for blk in self.in_blocks:
219 | x = blk(x)
220 | skips.append(x)
221 |
222 | x = self.mid_block(x)
223 |
224 | for blk in self.out_blocks:
225 | x = blk(x, skips.pop())
226 |
227 | x = self.norm(x)
228 | x = self.decoder_pred(x)
229 | assert x.size(1) == self.extras + L
230 | x = x[:, self.extras:, :]
231 | x = unpatchify(x, self.in_chans)
232 | x = self.final_layer(x)
233 | return x
234 |
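235 | # Illustrative shape note: with img_size=32 and patch_size=2 (e.g. a latent-space input),
236 | # num_patches = (32 // 2) ** 2 = 256; the token sequence is [label token (only if
237 | # num_classes > 0), time token, patch tokens], i.e. self.extras + num_patches tokens in total.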
--------------------------------------------------------------------------------
/U-ViT/tools/fid_score.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2 |
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 |
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 |
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 |
15 | See --help to see further details.
16 |
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of Tensorflow
19 |
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 |
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 |
26 | http://www.apache.org/licenses/LICENSE-2.0
27 |
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License.
33 | """
34 | import os
35 | import pathlib
36 |
37 | import numpy as np
38 | import torch
39 | import torchvision.transforms as TF
40 | from PIL import Image
41 | from scipy import linalg
42 | from torch.nn.functional import adaptive_avg_pool2d
43 |
44 | from .read_npz import open_npz_array
45 |
46 | import matplotlib.pyplot as plt
47 |
48 | from torchvision.transforms.functional import to_tensor
49 |
50 | try:
51 | from tqdm import tqdm
52 | except ImportError:
53 | # If tqdm is not available, provide a mock version of it
54 | def tqdm(x):
55 | return x
56 |
57 | from .inception import InceptionV3
58 |
59 |
60 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
61 | 'tif', 'tiff', 'webp'}
62 |
63 |
64 | class ImagePathDataset(torch.utils.data.Dataset):
65 | def __init__(self, files, transforms=None):
66 | self.files = files
67 | self.transforms = transforms
68 |
69 | def __len__(self):
70 | return len(self.files)
71 |
72 | def __getitem__(self, i):
73 | path = self.files[i]
74 | img = Image.open(path).convert('RGB')
75 | if self.transforms is not None:
76 | img = self.transforms(img)
77 | return img
78 |
79 |
80 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
81 | """Calculates the activations of the pool_3 layer for all images.
82 |
83 | Params:
84 | -- files : List of image files paths
85 | -- model : Instance of inception model
86 | -- batch_size : Batch size of images for the model to process at once.
87 | Make sure that the number of samples is a multiple of
88 | the batch size, otherwise some samples are ignored. This
89 | behavior is retained to match the original FID score
90 | implementation.
91 | -- dims : Dimensionality of features returned by Inception
92 | -- device : Device to run calculations
93 | -- num_workers : Number of parallel dataloader workers
94 |
95 | Returns:
96 | -- A numpy array of dimension (num images, dims) that contains the
97 | activations of the given tensor when feeding inception with the
98 | query tensor.
99 | """
100 | model.eval()
101 |
102 | if batch_size > len(files):
103 | print(('Warning: batch size is bigger than the data size. '
104 | 'Setting batch size to data size'))
105 | batch_size = len(files)
106 |
107 | dataset = ImagePathDataset(files, transforms=TF.ToTensor())
108 | dataloader = torch.utils.data.DataLoader(dataset,
109 | batch_size=batch_size,
110 | shuffle=False,
111 | drop_last=False,
112 | num_workers=num_workers)
113 |
114 | pred_arr = np.empty((len(files), dims))
115 |
116 | start_idx = 0
117 |
118 | for batch in tqdm(dataloader):
119 | batch = batch.to(device)
120 | with torch.no_grad():
121 | pred = model(batch)[0]
122 |
123 | # If model output is not scalar, apply global spatial average pooling.
124 | # This happens if you choose a dimensionality not equal 2048.
125 | if pred.size(2) != 1 or pred.size(3) != 1:
126 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
127 |
128 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
129 |
130 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
131 |
132 | start_idx = start_idx + pred.shape[0]
133 |
134 | return pred_arr
135 |
136 |
137 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
138 | """Numpy implementation of the Frechet Distance.
139 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
140 | and X_2 ~ N(mu_2, C_2) is
141 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
142 |
143 | Stable version by Dougal J. Sutherland.
144 |
145 | Params:
146 | -- mu1 : Numpy array containing the activations of a layer of the
147 | inception net (like returned by the function 'get_predictions')
148 | for generated samples.
149 |     -- mu2   : The sample mean over activations, precalculated on a
150 |                representative data set.
151 |     -- sigma1: The covariance matrix over activations for generated samples.
152 |     -- sigma2: The covariance matrix over activations, precalculated on a
153 |                representative data set.
154 |
155 | Returns:
156 | -- : The Frechet Distance.
157 | """
158 |
159 | mu1 = np.atleast_1d(mu1)
160 | mu2 = np.atleast_1d(mu2)
161 |
162 | sigma1 = np.atleast_2d(sigma1)
163 | sigma2 = np.atleast_2d(sigma2)
164 |
165 | assert mu1.shape == mu2.shape, \
166 | 'Training and test mean vectors have different lengths'
167 | assert sigma1.shape == sigma2.shape, \
168 | 'Training and test covariances have different dimensions'
169 |
170 | diff = mu1 - mu2
171 |
172 | # Product might be almost singular
173 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
174 | if not np.isfinite(covmean).all():
175 | msg = ('fid calculation produces singular product; '
176 | 'adding %s to diagonal of cov estimates') % eps
177 | print(msg)
178 | offset = np.eye(sigma1.shape[0]) * eps
179 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
180 |
181 | # Numerical error might give slight imaginary component
182 | if np.iscomplexobj(covmean):
183 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
184 | m = np.max(np.abs(covmean.imag))
185 | raise ValueError('Imaginary component {}'.format(m))
186 | covmean = covmean.real
187 |
188 | tr_covmean = np.trace(covmean)
189 |
190 | return (diff.dot(diff) + np.trace(sigma1)
191 | + np.trace(sigma2) - 2 * tr_covmean)
192 |
193 |
194 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
195 | device='cpu', num_workers=8):
196 | """Calculation of the statistics used by the FID.
197 | Params:
198 | -- files : List of image files paths
199 | -- model : Instance of inception model
200 | -- batch_size : The images numpy array is split into batches with
201 | batch size batch_size. A reasonable batch size
202 | depends on the hardware.
203 | -- dims : Dimensionality of features returned by Inception
204 | -- device : Device to run calculations
205 | -- num_workers : Number of parallel dataloader workers
206 |
207 | Returns:
208 | -- mu : The mean over samples of the activations of the pool_3 layer of
209 | the inception model.
210 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
211 | the inception model.
212 | """
213 | act = get_activations(files, model, batch_size, dims, device, num_workers)
214 | mu = np.mean(act, axis=0)
215 | sigma = np.cov(act, rowvar=False)
216 | return mu, sigma
217 |
218 | def compute_statistics_of_images_in_npz(path, model, batch_size, dims, device, num_workers=8):
219 | model.eval()
220 |     pred_arr = np.empty((50000, dims))  # assumes the .npz holds exactly 50,000 samples
221 | start_idx = 0
222 |
223 | with open_npz_array(path, "arr_0") as reader:
224 | for samples in tqdm(reader.read_batches(batch_size), total=50000 // batch_size):
225 | samples = np.array(samples)
226 | samples = torch.from_numpy(samples.transpose(0, 3, 1, 2)).contiguous().to(device)
227 | samples = samples.div(255).float()
228 |
229 | #batch = torch.tensor(samples).to(device)
230 | with torch.no_grad():
231 | pred = model(samples)[0]
232 |
233 | # If model output is not scalar, apply global spatial average pooling.
234 | # This happens if you choose a dimensionality not equal 2048.
235 | if pred.size(2) != 1 or pred.size(3) != 1:
236 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
237 |
238 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
239 |
240 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
241 |
242 | start_idx = start_idx + pred.shape[0]
243 |
244 | act = pred_arr
245 | mu = np.mean(act, axis=0)
246 | sigma = np.cov(act, rowvar=False)
247 | return mu, sigma
248 |
249 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
250 | if path.endswith('.npz'):
251 | with np.load(path) as f:
252 | m, s = f['mu'][:], f['sigma'][:]
253 | else:
254 | path = pathlib.Path(path)
255 | files = sorted([file for ext in IMAGE_EXTENSIONS
256 | for file in path.glob('*.{}'.format(ext))])
257 | m, s = calculate_activation_statistics(files, model, batch_size,
258 | dims, device, num_workers)
259 |
260 | return m, s
261 |
262 |
263 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
264 | if device is None:
265 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
266 | else:
267 | device = torch.device(device)
268 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
269 | model = InceptionV3([block_idx]).to(device)
270 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
271 | np.savez(out_path, mu=m1, sigma=s1)
272 |
273 |
274 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
275 | """Calculates the FID of two paths"""
276 | if device is None:
277 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
278 | else:
279 | device = torch.device(device)
280 |
281 | for p in paths:
282 | if not os.path.exists(p):
283 | raise RuntimeError('Invalid path: %s' % p)
284 |
285 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
286 |
287 | model = InceptionV3([block_idx]).to(device)
288 |
289 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
290 | dims, device, num_workers)
291 | if paths[1].endswith('.npz'):
292 | m2, s2 = compute_statistics_of_images_in_npz(paths[1], model, batch_size,
293 | dims, device, num_workers)
294 | else:
295 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
296 | dims, device, num_workers)
297 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
298 |
299 | return fid_value
300 |
--------------------------------------------------------------------------------
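
A quick way to see what calculate_frechet_distance computes is to evaluate the closed-form d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)) on toy Gaussian statistics. A minimal, standalone sketch (illustrative only, not part of the repository; assumes numpy and scipy are installed):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean.real)

mu, sigma = np.zeros(4), np.eye(4)
print(frechet_distance(mu, sigma, mu, sigma))        # ~0.0 for identical Gaussians
print(frechet_distance(mu, sigma, mu + 1.0, sigma))  # ~4.0: the squared mean shift (1^2 per dim)

With real data, mu and sigma come from calculate_activation_statistics (or a precomputed reference .npz), and the eps-regularized branch in the file guards against nearly singular covariance products.
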
/DiT/sample_ddp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Samples a large number of images from a pre-trained DiT model using DDP.
9 | Subsequently saves a .npz file that can be used to compute FID and other
10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
11 |
12 | For a simple single-GPU/CPU sampling script, see sample.py.
13 | """
14 | import torch
15 | import torch.distributed as dist
16 | from download import find_model
17 | from diffusion import create_diffusion
18 | from diffusers.models import AutoencoderKL
19 | from tqdm import tqdm
20 | import os
21 | from PIL import Image
22 | import numpy as np
23 | import math
24 | import argparse
25 |
26 |
27 | def create_npz_from_sample_folder(sample_dir, num=50_000):
28 | """
29 | Builds a single .npz file from a folder of .png samples.
30 | """
31 | samples = []
32 | for i in tqdm(range(num), desc="Building .npz file from samples"):
33 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
34 | sample_np = np.asarray(sample_pil).astype(np.uint8)
35 | samples.append(sample_np)
36 | samples = np.stack(samples)
37 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
38 | npz_path = f"{sample_dir}.npz"
39 | np.savez(npz_path, arr_0=samples)
40 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
41 | return npz_path
42 |
43 |
44 | def main(args):
45 | """
46 | Run sampling.
47 | """
48 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
49 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
50 | torch.set_grad_enabled(False)
51 |
52 | # Setup DDP:
53 | dist.init_process_group("nccl")
54 | rank = dist.get_rank()
55 | device = rank % torch.cuda.device_count()
56 | seed = args.global_seed * dist.get_world_size() + rank
57 | torch.manual_seed(seed)
58 | torch.cuda.set_device(device)
59 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
60 |
61 | if args.ckpt is None:
62 | assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
63 | assert args.image_size in [256, 512]
64 | assert args.num_classes == 1000
65 |
66 | diffusion = create_diffusion(str(args.num_sampling_steps))
67 |
68 | # Load model:
69 | latent_size = args.image_size // 8
70 | if args.accelerate_method == "cache":
71 | from models.cache_models import DiT_models
72 | elif args.accelerate_method == "iterate":
73 | from models.iterate_models import DiT_models
74 | elif args.accelerate_method == "nolastlayer":
75 | from models.nolastlayer_models import DiT_models
76 | elif args.accelerate_method is not None and "ranklayer" in args.accelerate_method:
77 | from models.rankdrop_models import DiT_models
78 | elif args.accelerate_method is not None and "bottomlayer" in args.accelerate_method:
79 | from models.bottom_models import DiT_models
80 | elif args.accelerate_method is not None and "randomlayer" in args.accelerate_method:
81 | from models.randomlayer_models import DiT_models
82 | elif args.accelerate_method is not None and "fixlayer" in args.accelerate_method:
83 | from models.fixlayer_models import DiT_models
84 | elif args.accelerate_method is not None and args.accelerate_method == "dynamiclayer":
85 | from models.dynamic_models import DiT_models
86 | elif args.accelerate_method is not None and args.accelerate_method == "layerdropout":
87 | from models.layerdropout_models import DiT_models
88 | elif args.accelerate_method is not None and args.accelerate_method == "dynamiclayer_soft":
89 | from models.router_models_inference import DiT_models
90 | else:
91 | from models.models import DiT_models
92 |
93 | model = DiT_models[args.model](
94 | input_size=latent_size,
95 | num_classes=args.num_classes
96 | ).to(device)
97 |
98 | if args.accelerate_method is not None:
99 | if 'ranklayer' in args.accelerate_method:
100 | model.load_ranking(args.num_sampling_steps, args.accelerate_method)
101 | elif 'randomlayer' in args.accelerate_method:
102 | model.load_ranking(args.accelerate_method)
103 | elif 'bottomlayer' in args.accelerate_method or 'fixlayer' in args.accelerate_method:
104 | model.load_ranking(args.accelerate_method)
105 | elif 'dynamiclayer' in args.accelerate_method or 'layerdropout' in args.accelerate_method or 'dynamiclayer_soft' in args.accelerate_method:
106 | model.load_ranking(args.path, args.num_sampling_steps, diffusion.timestep_map, args.thres)
107 |
108 |
109 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py:
110 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
111 | state_dict = find_model(ckpt_path)
112 | model.load_state_dict(state_dict)
113 | model.eval() # important!
114 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
115 |     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
116 | using_cfg = args.cfg_scale > 1.0
117 |
118 | # Create folder to save samples:
119 | model_string_name = args.model.replace("/", "-")
120 | ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
121 | if args.accelerate_method is not None and 'dynamiclayer' in args.accelerate_method:
122 | router_name = args.path.split('/')[1].split('.')[0]
123 | folder_name = f"router-{router_name}-thres-{args.thres}-accelerate-{args.accelerate_method}-size-{args.image_size}-vae-{args.vae}-ddim-{args.ddim_sample}-" \
124 | f"steps-{args.num_sampling_steps}-cfg-{args.cfg_scale}-seed-{args.global_seed}"
125 | else:
126 | folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-psampler-{args.p_sample}-ddim-{args.ddim_sample}-" \
127 | f"steps-{args.num_sampling_steps}-accelerate-{args.accelerate_method}-cfg-{args.cfg_scale}-seed-{args.global_seed}"
128 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
129 |
130 | os.makedirs(f"{args.sample_dir}", exist_ok=True)
131 | if rank == 0 and args.save_to_disk:
132 | os.makedirs(sample_folder_dir, exist_ok=True)
133 | print(f"Saving .png samples at {sample_folder_dir}")
134 | dist.barrier()
135 |
136 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
137 | n = args.per_proc_batch_size
138 | global_batch_size = n * dist.get_world_size()
139 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
140 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
141 | if rank == 0:
142 | print(f"Total number of images that will be sampled: {total_samples}")
143 | all_images = []
144 |
145 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
146 | samples_needed_this_gpu = int(total_samples // dist.get_world_size())
147 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
148 | iterations = int(samples_needed_this_gpu // n)
149 | pbar = range(iterations)
150 | pbar = tqdm(pbar) if rank == 0 else pbar
151 | total = 0
152 |
153 | for _ in pbar:
154 | model.reset(args.num_sampling_steps)
155 |
156 | # Sample inputs:
157 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
158 | y = torch.randint(0, args.num_classes, (n,), device=device)
159 |
160 |
161 | # Setup classifier-free guidance:
162 | if using_cfg:
163 | z = torch.cat([z, z], 0)
164 | y_null = torch.tensor([1000] * n, device=device)
165 | y = torch.cat([y, y_null], 0)
166 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
167 | sample_fn = model.forward_with_cfg
168 | else:
169 | model_kwargs = dict(y=y)
170 | sample_fn = model.forward
171 |
172 | # Sample images:
173 | if args.p_sample:
174 | samples = diffusion.p_sample_loop(
175 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
176 | )
177 | elif args.ddim_sample:
178 | samples = diffusion.ddim_sample_loop(
179 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
180 | )
181 | else:
182 | raise NotImplementedError
183 |
184 | if using_cfg:
185 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
186 |
187 | samples = vae.decode(samples / 0.18215).sample
188 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(dtype=torch.uint8)
189 |
190 | # Save samples to disk as individual .png files
191 | if args.save_to_disk:
192 | for i, sample in enumerate(samples):
193 | index = i * dist.get_world_size() + rank + total
194 | sample = sample.cpu().numpy()
195 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
196 | else:
197 | samples = samples.contiguous()
198 | gathered_samples = [torch.zeros_like(samples) for _ in range(dist.get_world_size())]
199 | dist.all_gather(gathered_samples, samples)
200 |
201 | if rank == 0:
202 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
203 | total += global_batch_size
204 |
205 | dist.barrier()
206 |
207 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
208 | dist.barrier()
209 | if rank == 0:
210 | if args.save_to_disk:
211 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
212 | print("Done.")
213 | else:
214 | if rank == 0:
215 | arr = np.concatenate(all_images, axis=0)
216 | arr = arr[: args.num_fid_samples]
217 |
218 | out_path = f"{sample_folder_dir}.npz"
219 |
220 | print(f"saving to {out_path}")
221 | np.savez(out_path, arr_0=arr)
222 | dist.barrier()
223 | dist.destroy_process_group()
224 |
225 |
226 | if __name__ == "__main__":
227 | parser = argparse.ArgumentParser()
228 | parser.add_argument("--model", type=str, default="DiT-XL/2")
229 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
230 | parser.add_argument("--sample-dir", type=str, default="samples")
231 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
232 | parser.add_argument("--num-fid-samples", type=int, default=50_000)
233 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
234 | parser.add_argument("--num-classes", type=int, default=1000)
235 | parser.add_argument("--cfg-scale", type=float, default=1.5)
236 | parser.add_argument("--num-sampling-steps", type=int, default=250)
237 | parser.add_argument("--global-seed", type=int, default=0)
238 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
239 | help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
240 | parser.add_argument("--ckpt", type=str, default=None,
241 | help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
242 |
243 | parser.add_argument("--ddim-sample", action="store_true", default=False,)
244 | parser.add_argument("--p-sample", action="store_true", default=False,)
245 |
246 | parser.add_argument("--accelerate-method", type=str, default=None,
247 | help="Use the accelerated version of the model.")
248 | parser.add_argument("--thres", type=float, default=0.5)
249 |
250 | parser.add_argument("--name", type=str, default="None")
251 | parser.add_argument("--path", type=str, default=None,)
252 |
253 | parser.add_argument("--save-to-disk", action="store_true", default=False,)
254 |
255 |
256 | args = parser.parse_args()
257 | main(args)
258 |
--------------------------------------------------------------------------------
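
The even-divisibility bookkeeping in main() (global_batch_size, total_samples, iterations) is easiest to follow with concrete numbers. A minimal sketch of that arithmetic for a hypothetical 8-GPU run with the default per-process batch size (the values are illustrative, not taken from the repository):

import math

world_size = 8            # hypothetical number of DDP processes
n = 32                    # --per-proc-batch-size
num_fid_samples = 50_000  # --num-fid-samples

global_batch_size = n * world_size                                                        # 256
total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size)   # 50176
samples_needed_this_gpu = total_samples // world_size                                     # 6272
iterations = samples_needed_this_gpu // n                                                 # 196 sampling loops per process

print(global_batch_size, total_samples, samples_needed_this_gpu, iterations)

The 176 surplus images are discarded at the end, where the gathered array (or the folder of .png files) is truncated to num_fid_samples before the .npz file is written.
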
/U-ViT/tools/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 |         192: 1,   # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=(DEFAULT_BLOCK_INDEX,),
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 |
39 | Parameters
40 | ----------
41 | output_blocks : list of int
42 | Indices of blocks to return features of. Possible values are:
43 | - 0: corresponds to output of first max pooling
44 | - 1: corresponds to output of second max pooling
45 | - 2: corresponds to output which is fed to aux classifier
46 | - 3: corresponds to output of final average pooling
47 | resize_input : bool
48 | If true, bilinearly resizes input to width and height 299 before
49 | feeding input to model. As the network without fully connected
50 | layers is fully convolutional, it should be able to handle inputs
51 | of arbitrary size, so resizing might not be strictly needed
52 | normalize_input : bool
53 | If true, scales the input from range (0, 1) to the range the
54 | pretrained Inception network expects, namely (-1, 1)
55 | requires_grad : bool
56 | If true, parameters of the model require gradients. Possibly useful
57 | for finetuning the network
58 | use_fid_inception : bool
59 | If true, uses the pretrained Inception model used in Tensorflow's
60 | FID implementation. If false, uses the pretrained Inception model
61 | available in torchvision. The FID Inception model has different
62 | weights and a slightly different structure from torchvision's
63 | Inception model. If you want to compute FID scores, you are
64 | strongly advised to set this parameter to true to get comparable
65 | results.
66 | """
67 | super(InceptionV3, self).__init__()
68 |
69 | self.resize_input = resize_input
70 | self.normalize_input = normalize_input
71 | self.output_blocks = sorted(output_blocks)
72 | self.last_needed_block = max(output_blocks)
73 |
74 | assert self.last_needed_block <= 3, \
75 | 'Last possible output block index is 3'
76 |
77 | self.blocks = nn.ModuleList()
78 |
79 | if use_fid_inception:
80 | inception = fid_inception_v3()
81 | else:
82 | inception = _inception_v3(pretrained=True)
83 |
84 | # Block 0: input to maxpool1
85 | block0 = [
86 | inception.Conv2d_1a_3x3,
87 | inception.Conv2d_2a_3x3,
88 | inception.Conv2d_2b_3x3,
89 | nn.MaxPool2d(kernel_size=3, stride=2)
90 | ]
91 | self.blocks.append(nn.Sequential(*block0))
92 |
93 | # Block 1: maxpool1 to maxpool2
94 | if self.last_needed_block >= 1:
95 | block1 = [
96 | inception.Conv2d_3b_1x1,
97 | inception.Conv2d_4a_3x3,
98 | nn.MaxPool2d(kernel_size=3, stride=2)
99 | ]
100 | self.blocks.append(nn.Sequential(*block1))
101 |
102 | # Block 2: maxpool2 to aux classifier
103 | if self.last_needed_block >= 2:
104 | block2 = [
105 | inception.Mixed_5b,
106 | inception.Mixed_5c,
107 | inception.Mixed_5d,
108 | inception.Mixed_6a,
109 | inception.Mixed_6b,
110 | inception.Mixed_6c,
111 | inception.Mixed_6d,
112 | inception.Mixed_6e,
113 | ]
114 | self.blocks.append(nn.Sequential(*block2))
115 |
116 | # Block 3: aux classifier to final avgpool
117 | if self.last_needed_block >= 3:
118 | block3 = [
119 | inception.Mixed_7a,
120 | inception.Mixed_7b,
121 | inception.Mixed_7c,
122 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
123 | ]
124 | self.blocks.append(nn.Sequential(*block3))
125 |
126 | for param in self.parameters():
127 | param.requires_grad = requires_grad
128 |
129 | def forward(self, inp):
130 | """Get Inception feature maps
131 |
132 | Parameters
133 | ----------
134 | inp : torch.autograd.Variable
135 | Input tensor of shape Bx3xHxW. Values are expected to be in
136 | range (0, 1)
137 |
138 | Returns
139 | -------
140 | List of torch.autograd.Variable, corresponding to the selected output
141 | block, sorted ascending by index
142 | """
143 | outp = []
144 | x = inp
145 |
146 | if self.resize_input:
147 | x = F.interpolate(x,
148 | size=(299, 299),
149 | mode='bilinear',
150 | align_corners=False)
151 |
152 | if self.normalize_input:
153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154 |
155 | for idx, block in enumerate(self.blocks):
156 | x = block(x)
157 | if idx in self.output_blocks:
158 | outp.append(x)
159 |
160 | if idx == self.last_needed_block:
161 | break
162 |
163 | return outp
164 |
165 |
166 | def _inception_v3(*args, **kwargs):
167 | """Wraps `torchvision.models.inception_v3`
168 |
169 |     Skips default weight initialization if supported by torchvision version.
170 | See https://github.com/mseitzer/pytorch-fid/issues/28.
171 | """
172 | try:
173 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174 | except ValueError:
175 | # Just a caution against weird version strings
176 | version = (0,)
177 |
178 | if version >= (0, 6):
179 | kwargs['init_weights'] = False
180 |
181 | return torchvision.models.inception_v3(*args, **kwargs)
182 |
183 |
184 | def fid_inception_v3():
185 | """Build pretrained Inception model for FID computation
186 |
187 | The Inception model for FID computation uses a different set of weights
188 | and has a slightly different structure than torchvision's Inception.
189 |
190 | This method first constructs torchvision's Inception and then patches the
191 | necessary parts that are different in the FID Inception model.
192 | """
193 | inception = _inception_v3(num_classes=1008,
194 | aux_logits=False,
195 | pretrained=False)
196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203 | inception.Mixed_7b = FIDInceptionE_1(1280)
204 | inception.Mixed_7c = FIDInceptionE_2(2048)
205 |
206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
207 | inception.load_state_dict(state_dict)
208 | return inception
209 |
210 |
211 | class FIDInceptionA(torchvision.models.inception.InceptionA):
212 | """InceptionA block patched for FID computation"""
213 | def __init__(self, in_channels, pool_features):
214 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
215 |
216 | def forward(self, x):
217 | branch1x1 = self.branch1x1(x)
218 |
219 | branch5x5 = self.branch5x5_1(x)
220 | branch5x5 = self.branch5x5_2(branch5x5)
221 |
222 | branch3x3dbl = self.branch3x3dbl_1(x)
223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225 |
226 |         # Patch: Tensorflow's average pool does not use the padded zeros in
227 | # its average calculation
228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229 | count_include_pad=False)
230 | branch_pool = self.branch_pool(branch_pool)
231 |
232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233 | return torch.cat(outputs, 1)
234 |
235 |
236 | class FIDInceptionC(torchvision.models.inception.InceptionC):
237 | """InceptionC block patched for FID computation"""
238 | def __init__(self, in_channels, channels_7x7):
239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240 |
241 | def forward(self, x):
242 | branch1x1 = self.branch1x1(x)
243 |
244 | branch7x7 = self.branch7x7_1(x)
245 | branch7x7 = self.branch7x7_2(branch7x7)
246 | branch7x7 = self.branch7x7_3(branch7x7)
247 |
248 | branch7x7dbl = self.branch7x7dbl_1(x)
249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253 |
254 |         # Patch: Tensorflow's average pool does not use the padded zeros in
255 | # its average calculation
256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257 | count_include_pad=False)
258 | branch_pool = self.branch_pool(branch_pool)
259 |
260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261 | return torch.cat(outputs, 1)
262 |
263 |
264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265 | """First InceptionE block patched for FID computation"""
266 | def __init__(self, in_channels):
267 | super(FIDInceptionE_1, self).__init__(in_channels)
268 |
269 | def forward(self, x):
270 | branch1x1 = self.branch1x1(x)
271 |
272 | branch3x3 = self.branch3x3_1(x)
273 | branch3x3 = [
274 | self.branch3x3_2a(branch3x3),
275 | self.branch3x3_2b(branch3x3),
276 | ]
277 | branch3x3 = torch.cat(branch3x3, 1)
278 |
279 | branch3x3dbl = self.branch3x3dbl_1(x)
280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281 | branch3x3dbl = [
282 | self.branch3x3dbl_3a(branch3x3dbl),
283 | self.branch3x3dbl_3b(branch3x3dbl),
284 | ]
285 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
286 |
287 |         # Patch: Tensorflow's average pool does not use the padded zeros in
288 | # its average calculation
289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290 | count_include_pad=False)
291 | branch_pool = self.branch_pool(branch_pool)
292 |
293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294 | return torch.cat(outputs, 1)
295 |
296 |
297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298 | """Second InceptionE block patched for FID computation"""
299 | def __init__(self, in_channels):
300 | super(FIDInceptionE_2, self).__init__(in_channels)
301 |
302 | def forward(self, x):
303 | branch1x1 = self.branch1x1(x)
304 |
305 | branch3x3 = self.branch3x3_1(x)
306 | branch3x3 = [
307 | self.branch3x3_2a(branch3x3),
308 | self.branch3x3_2b(branch3x3),
309 | ]
310 | branch3x3 = torch.cat(branch3x3, 1)
311 |
312 | branch3x3dbl = self.branch3x3dbl_1(x)
313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314 | branch3x3dbl = [
315 | self.branch3x3dbl_3a(branch3x3dbl),
316 | self.branch3x3dbl_3b(branch3x3dbl),
317 | ]
318 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
319 |
320 | # Patch: The FID Inception model uses max pooling instead of average
321 | # pooling. This is likely an error in this specific Inception
322 | # implementation, as other Inception models use average pooling here
323 | # (which matches the description in the paper).
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325 | branch_pool = self.branch_pool(branch_pool)
326 |
327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328 | return torch.cat(outputs, 1)
329 |
--------------------------------------------------------------------------------
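
A minimal sketch of how the wrapper above is typically driven for FID features (standalone illustration; it assumes this file is importable as inception and that the FID Inception weights can be downloaded on first use):

import torch
from inception import InceptionV3  # assumed import path for U-ViT/tools/inception.py

dims = 2048                                      # final average-pooling features
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx]).eval()

# Inputs are expected in [0, 1]; the wrapper resizes to 299x299 and rescales to [-1, 1].
images = torch.rand(4, 3, 256, 256)
with torch.no_grad():
    feats = model(images)[0]                     # (4, 2048, 1, 1): one entry per requested block
print(feats.squeeze(-1).squeeze(-1).shape)       # torch.Size([4, 2048])
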
/U-ViT/libs/uvit_dynamic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import numpy as np
5 | from .timm import trunc_normal_, Mlp
6 | import einops
7 | import torch.utils.checkpoint
8 |
9 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
10 | ATTENTION_MODE = 'flash'
11 | else:
12 | try:
13 | import xformers
14 | import xformers.ops
15 | ATTENTION_MODE = 'xformers'
16 |     except ImportError:
17 | ATTENTION_MODE = 'math'
18 | print(f'attention mode is {ATTENTION_MODE}')
19 |
20 |
21 | def timestep_embedding(timesteps, dim, max_period=10000):
22 | """
23 | Create sinusoidal timestep embeddings.
24 |
25 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
26 | These may be fractional.
27 | :param dim: the dimension of the output.
28 | :param max_period: controls the minimum frequency of the embeddings.
29 | :return: an [N x dim] Tensor of positional embeddings.
30 | """
31 | half = dim // 2
32 | freqs = torch.exp(
33 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34 | ).to(device=timesteps.device)
35 | args = timesteps[:, None].float() * freqs[None]
36 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37 | if dim % 2:
38 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39 | return embedding
40 |
41 |
42 | def patchify(imgs, patch_size):
43 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
44 | return x
45 |
46 |
47 | def unpatchify(x, channels=3):
48 | patch_size = int((x.shape[2] // channels) ** 0.5)
49 | h = w = int(x.shape[1] ** .5)
50 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
51 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
52 | return x
53 |
54 |
55 | class Attention(nn.Module):
56 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
57 | super().__init__()
58 | self.num_heads = num_heads
59 | head_dim = dim // num_heads
60 | self.scale = qk_scale or head_dim ** -0.5
61 |
62 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
63 | self.attn_drop = nn.Dropout(attn_drop)
64 | self.proj = nn.Linear(dim, dim)
65 | self.proj_drop = nn.Dropout(proj_drop)
66 |
67 | def forward(self, x):
68 | B, L, C = x.shape
69 |
70 | qkv = self.qkv(x)
71 | if ATTENTION_MODE == 'flash':
72 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
73 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
74 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
75 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
76 | elif ATTENTION_MODE == 'xformers':
77 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
78 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
79 | x = xformers.ops.memory_efficient_attention(q, k, v)
80 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
81 | elif ATTENTION_MODE == 'math':
82 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
83 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
84 | attn = (q @ k.transpose(-2, -1)) * self.scale
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
88 | else:
89 |             raise NotImplementedError
90 |
91 | x = self.proj(x)
92 | x = self.proj_drop(x)
93 | return x
94 |
95 |
96 | class Block(nn.Module):
97 |
98 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
99 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
100 | super().__init__()
101 | self.norm1 = norm_layer(dim)
102 | self.attn = Attention(
103 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
104 | self.norm2 = norm_layer(dim)
105 | mlp_hidden_dim = int(dim * mlp_ratio)
106 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
107 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
108 | self.use_checkpoint = use_checkpoint
109 |
110 | def forward(self, x, skip=None, reuse_att=None, reuse_mlp=None):
111 | if self.use_checkpoint:
112 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip, reuse_att, reuse_mlp)
113 | else:
114 | return self._forward(x, skip, reuse_att, reuse_mlp)
115 |
116 | def _forward(self, x, skip=None, reuse_att=None, reuse_mlp=None):
117 | if self.skip_linear is not None:
118 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
119 |
120 | if reuse_att is not None:
121 | x = x + reuse_att
122 | else:
123 | reuse_att = self.attn(self.norm1(x))
124 | x = x + reuse_att
125 |
126 | if reuse_mlp is not None:
127 | x = x + reuse_mlp
128 | else:
129 | reuse_mlp = self.mlp(self.norm2(x))
130 | x = x + reuse_mlp
131 | return x, (reuse_att, reuse_mlp)
132 |
133 |
134 | class PatchEmbed(nn.Module):
135 | """ Image to Patch Embedding
136 | """
137 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
138 | super().__init__()
139 | self.patch_size = patch_size
140 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
141 |
142 | def forward(self, x):
143 | B, C, H, W = x.shape
144 | assert H % self.patch_size == 0 and W % self.patch_size == 0
145 | x = self.proj(x).flatten(2).transpose(1, 2)
146 | return x
147 |
148 |
149 | class UViT(nn.Module):
150 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
151 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
152 | use_checkpoint=False, conv=True, skip=True):
153 | super().__init__()
154 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
155 | self.num_classes = num_classes
156 | self.in_chans = in_chans
157 |
158 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
159 | num_patches = (img_size // patch_size) ** 2
160 |
161 | self.time_embed = nn.Sequential(
162 | nn.Linear(embed_dim, 4 * embed_dim),
163 | nn.SiLU(),
164 | nn.Linear(4 * embed_dim, embed_dim),
165 | ) if mlp_time_embed else nn.Identity()
166 |
167 | if self.num_classes > 0:
168 | self.label_emb = nn.Embedding(self.num_classes, embed_dim)
169 | self.extras = 2
170 | else:
171 | self.extras = 1
172 |
173 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
174 |
175 | self.in_blocks = nn.ModuleList([
176 | Block(
177 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
178 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
179 | for _ in range(depth // 2)])
180 |
181 | self.mid_block = Block(
182 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
183 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
184 |
185 | self.out_blocks = nn.ModuleList([
186 | Block(
187 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
188 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
189 | for _ in range(depth // 2)])
190 |
191 |         self.depth = depth + 1  # depth//2 in-blocks + 1 mid-block + depth//2 out-blocks
192 |
193 | self.norm = norm_layer(embed_dim)
194 | self.patch_dim = patch_size ** 2 * in_chans
195 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
196 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
197 |
198 | trunc_normal_(self.pos_embed, std=.02)
199 | self.apply(self._init_weights)
200 |
201 | self.reset()
202 |
203 | def reset_cache_features(self):
204 | self.cond_cache_features = [None] * self.depth
205 | self.uncond_cache_features = [None] * self.depth
206 |
207 | def reset(self):
208 | self.cur_step_idx = 0
209 | self.reset_cache_features()
210 |
211 | def _init_weights(self, m):
212 | if isinstance(m, nn.Linear):
213 | trunc_normal_(m.weight, std=.02)
214 | if isinstance(m, nn.Linear) and m.bias is not None:
215 | nn.init.constant_(m.bias, 0)
216 | elif isinstance(m, nn.LayerNorm):
217 | nn.init.constant_(m.bias, 0)
218 | nn.init.constant_(m.weight, 1.0)
219 |
220 | @torch.jit.ignore
221 | def no_weight_decay(self):
222 | return {'pos_embed'}
223 |
224 | def load_ranking(self, path, num_steps, timestep_map, thres):
225 | self.rank = [None] * num_steps
226 | from .uvit_router import Router
227 |
228 | act_layer, total_layer = 0, 0
229 | ckpt = torch.load(path, map_location='cpu')
230 | routers = torch.nn.ModuleList([
231 | Router(2*self.depth) for _ in range(num_steps)
232 | ])
233 | routers.load_state_dict(ckpt)
234 | self.timestep_map = {timestep: i for i, timestep in enumerate(timestep_map)}
235 | print(self.timestep_map)
236 |
237 | act_att, act_mlp = 0, 0
238 | for idx, router in enumerate(routers[:num_steps//2]):
239 | if idx != 0:
240 | self.rank[idx] = (router() > thres).float().nonzero().squeeze(0)
241 | total_layer += 2 * self.depth
242 | act_layer += len(self.rank[idx])
243 |                 print(f"Timestep {idx}: Not Reuse: {self.rank[idx].squeeze()}")
244 |
245 | if len(self.rank[idx]) > 0:
246 | act_att += sum(1 - torch.remainder(self.rank[idx], 2)).item()
247 | act_mlp += sum(torch.remainder(self.rank[idx], 2)).item()
248 |
249 |         print(f"Total Active Layers: {act_layer}/{total_layer}, Remove Ratio = {1 - act_layer/total_layer}")
250 |         print(f"Total Active Attention: {act_att}/{total_layer//2}, Remove Ratio = {1 - act_att/(total_layer//2)}")
251 |         print(f"Total Active MLP: {act_mlp}/{total_layer//2}, Remove Ratio = {1 - act_mlp/(total_layer//2)}")
252 |
253 | def forward(self, x, timesteps, y=None):
254 | x = self.patch_embed(x)
255 | B, L, D = x.shape
256 |
257 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
258 | time_token = time_token.unsqueeze(dim=1)
259 | x = torch.cat((time_token, x), dim=1)
260 | if y is not None:
261 | label_emb = self.label_emb(y)
262 | label_emb = label_emb.unsqueeze(dim=1)
263 | x = torch.cat((label_emb, x), dim=1)
264 | x = x + self.pos_embed
265 |
266 | skips = []
267 |
268 | if self.cur_step_idx % 4 == 2:
269 | self.reset_cache_features()
270 |
271 | if self.cur_step_idx % 2 == 0:
272 | cache_features = self.cond_cache_features
273 | else:
274 | cache_features = self.uncond_cache_features
275 |
276 | round_timestep = round(timesteps[0].item())
277 | router_idx = self.timestep_map[round_timestep] if round_timestep in self.timestep_map else None
278 | #print(f"Round Timestep: {round_timestep}, Router Index: {router_idx}")
279 |
280 | layer_idx = 0
281 | for blk in self.in_blocks:
282 | reuse_att, reuse_mlp = None, None
283 | if cache_features[layer_idx] is not None:
284 | if layer_idx * 2 not in self.rank[router_idx]:
285 | reuse_att, _ = cache_features[layer_idx]
286 | if layer_idx * 2 + 1 not in self.rank[router_idx]:
287 | _, reuse_mlp = cache_features[layer_idx]
288 |
289 | x, cache_feature = blk(x, reuse_att=reuse_att, reuse_mlp=reuse_mlp)
290 | skips.append(x)
291 | cache_features[layer_idx] = cache_feature
292 | layer_idx += 1
293 |
294 | reuse_att, reuse_mlp = None, None
295 | if cache_features[layer_idx] is not None:
296 | if layer_idx * 2 not in self.rank[router_idx]:
297 | reuse_att, _ = cache_features[layer_idx]
298 | if layer_idx * 2 + 1 not in self.rank[router_idx]:
299 | _, reuse_mlp = cache_features[layer_idx]
300 |
301 | x, cache_feature = self.mid_block(x, reuse_att=reuse_att, reuse_mlp=reuse_mlp)
302 | cache_features[layer_idx] = cache_feature
303 | layer_idx += 1
304 |
305 | for blk in self.out_blocks:
306 | reuse_att, reuse_mlp = None, None
307 | if cache_features[layer_idx] is not None:
308 | if layer_idx * 2 not in self.rank[router_idx]:
309 | reuse_att, _ = cache_features[layer_idx]
310 | if layer_idx * 2 + 1 not in self.rank[router_idx]:
311 | _, reuse_mlp = cache_features[layer_idx]
312 |         x, cache_feature = blk(x, skips.pop(), reuse_att=reuse_att, reuse_mlp=reuse_mlp)
313 | cache_features[layer_idx] = cache_feature
314 | layer_idx += 1
315 |
316 | x = self.norm(x)
317 | x = self.decoder_pred(x)
318 | assert x.size(1) == self.extras + L
319 | x = x[:, self.extras:, :]
320 | x = unpatchify(x, self.in_chans)
321 | x = self.final_layer(x)
322 |
323 | self.cur_step_idx += 1
324 | return x
325 |
--------------------------------------------------------------------------------
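
In uvit_dynamic.py, each cached block owns two slots of the router output: index 2*layer_idx for its attention output and 2*layer_idx + 1 for its MLP output, and an index listed in self.rank[step] means "recompute at this step" while every other slot is reused from the cache. A minimal sketch of that decoding, with a hypothetical score vector standing in for a trained router:

import torch

depth = 3      # pretend cached depth (in-blocks + mid + out-blocks), for illustration only
thres = 0.5
scores = torch.tensor([0.9, 0.2, 0.4, 0.7, 0.1, 0.3])  # hypothetical router() output, 2*depth slots

# Slots scoring above the threshold stay active (recomputed); the rest reuse cached features.
rank = (scores > thres).nonzero().squeeze(-1)           # tensor([0, 3]) for the scores above

for layer_idx in range(depth):
    att = 'recompute' if layer_idx * 2 in rank else 'reuse'
    mlp = 'recompute' if layer_idx * 2 + 1 in rank else 'reuse'
    print(f"layer {layer_idx}: attention -> {att}, mlp -> {mlp}")
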
/U-ViT/libs/uvit_router.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 | import numpy as np
8 |
9 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
10 | ATTENTION_MODE = 'flash'
11 | else:
12 | try:
13 | import xformers
14 | import xformers.ops
15 | ATTENTION_MODE = 'xformers'
16 |     except ImportError:
17 | ATTENTION_MODE = 'math'
18 | print(f'attention mode is {ATTENTION_MODE}')
19 |
20 |
21 | def timestep_embedding(timesteps, dim, max_period=10000):
22 | """
23 | Create sinusoidal timestep embeddings.
24 |
25 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
26 | These may be fractional.
27 | :param dim: the dimension of the output.
28 | :param max_period: controls the minimum frequency of the embeddings.
29 | :return: an [N x dim] Tensor of positional embeddings.
30 | """
31 | half = dim // 2
32 | freqs = torch.exp(
33 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34 | ).to(device=timesteps.device)
35 | args = timesteps[:, None].float() * freqs[None]
36 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37 | if dim % 2:
38 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39 | return embedding
40 |
41 |
42 | def patchify(imgs, patch_size):
43 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
44 | return x
45 |
46 |
47 | def unpatchify(x, channels=3):
48 | patch_size = int((x.shape[2] // channels) ** 0.5)
49 | h = w = int(x.shape[1] ** .5)
50 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
51 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
52 | return x
53 |
54 |
55 | class Attention(nn.Module):
56 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
57 | super().__init__()
58 | self.num_heads = num_heads
59 | head_dim = dim // num_heads
60 | self.scale = qk_scale or head_dim ** -0.5
61 |
62 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
63 | self.attn_drop = nn.Dropout(attn_drop)
64 | self.proj = nn.Linear(dim, dim)
65 | self.proj_drop = nn.Dropout(proj_drop)
66 |
67 | def forward(self, x):
68 | B, L, C = x.shape
69 |
70 | qkv = self.qkv(x)
71 | if ATTENTION_MODE == 'flash':
72 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
73 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
74 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
75 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
76 | elif ATTENTION_MODE == 'xformers':
77 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
78 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
79 | x = xformers.ops.memory_efficient_attention(q, k, v)
80 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
81 | elif ATTENTION_MODE == 'math':
82 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
83 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
84 | attn = (q @ k.transpose(-2, -1)) * self.scale
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
88 | else:
89 |             raise NotImplementedError
90 |
91 | x = self.proj(x)
92 | x = self.proj_drop(x)
93 | return x
94 |
95 |
96 | class Block(nn.Module):
97 |
98 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
99 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
100 | super().__init__()
101 | self.norm1 = norm_layer(dim)
102 | self.attn = Attention(
103 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
104 | self.norm2 = norm_layer(dim)
105 | mlp_hidden_dim = int(dim * mlp_ratio)
106 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
107 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
108 | self.use_checkpoint = False #use_checkpoint
109 |
110 | def forward(self, x, skip=None, reuse_att=None, reuse_mlp=None,
111 | reuse_att_weight=0, reuse_mlp_weight=0):
112 | if self.use_checkpoint:
113 | return torch.utils.checkpoint.checkpoint(
114 | self._forward, x, skip, reuse_att, reuse_mlp,
115 | reuse_att_weight, reuse_mlp_weight
116 | )
117 | else:
118 | return self._forward(
119 | x, skip, reuse_att, reuse_mlp,
120 | reuse_att_weight, reuse_mlp_weight
121 | )
122 |
123 | def _forward(self, x, skip=None, reuse_att=None, reuse_mlp=None, reuse_att_weight=None, reuse_mlp_weight=None):
124 | if self.skip_linear is not None:
125 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
126 |
127 | att_out = self.attn(self.norm1(x))
128 | if reuse_att is not None:
129 | att_out = att_out * (1 - reuse_att_weight) + reuse_att * reuse_att_weight
130 | x = x + att_out
131 |
132 | mlp_out = self.mlp(self.norm2(x))
133 | if reuse_mlp is not None:
134 | mlp_out = mlp_out * (1 - reuse_mlp_weight) + reuse_mlp * reuse_mlp_weight
135 | x = x + mlp_out
136 | return x, (att_out, mlp_out)
137 |
138 |
139 | class PatchEmbed(nn.Module):
140 | """ Image to Patch Embedding
141 | """
142 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
143 | super().__init__()
144 | self.patch_size = patch_size
145 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
146 |
147 | def forward(self, x):
148 | B, C, H, W = x.shape
149 | assert H % self.patch_size == 0 and W % self.patch_size == 0
150 | x = self.proj(x).flatten(2).transpose(1, 2)
151 | return x
152 |
153 | class Router(nn.Module):
154 | def __init__(self, num_choises):
155 | super().__init__()
156 | self.num_choises = num_choises
157 | self.prob = torch.nn.Parameter(torch.randn(num_choises), requires_grad=True)
158 |
159 | self.activation = torch.nn.Sigmoid()
160 |
161 | def forward(self, x=None): # Any input will be ignored, only for solving the issue of https://github.com/pytorch/pytorch/issues/37814
162 | return self.activation(self.prob)
163 |
164 | class UViT(nn.Module):
165 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
166 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
167 | use_checkpoint=False, conv=True, skip=True):
168 | super().__init__()
169 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
170 | self.num_classes = num_classes
171 | self.in_chans = in_chans
172 |
173 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
174 | num_patches = (img_size // patch_size) ** 2
175 |
176 | self.time_embed = nn.Sequential(
177 | nn.Linear(embed_dim, 4 * embed_dim),
178 | nn.SiLU(),
179 | nn.Linear(4 * embed_dim, embed_dim),
180 | ) if mlp_time_embed else nn.Identity()
181 |
182 | if self.num_classes > 0:
183 | self.label_emb = nn.Embedding(self.num_classes, embed_dim)
184 | self.extras = 2
185 | else:
186 | self.extras = 1
187 |
188 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
189 |
190 | self.in_blocks = nn.ModuleList([
191 | Block(
192 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
193 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
194 | for _ in range(depth // 2)])
195 |
196 | self.mid_block = Block(
197 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
198 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
199 |
200 | self.out_blocks = nn.ModuleList([
201 | Block(
202 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
203 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
204 | for _ in range(depth // 2)])
205 |
206 |         self.depth = depth + 1  # depth//2 in-blocks + 1 mid-block + depth//2 out-blocks
207 |
208 | self.norm = norm_layer(embed_dim)
209 | self.patch_dim = patch_size ** 2 * in_chans
210 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
211 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
212 |
213 | trunc_normal_(self.pos_embed, std=.02)
214 | self.apply(self._init_weights)
215 |
216 | self.reset()
217 |
218 | def reset_cache_features(self):
219 | self.cache_features = [None] * self.depth
220 | self.activate_cache = False
221 | self.record_cache = True
222 |
223 | def reset(self):
224 | self.cur_step_idx = 0
225 | self.reset_cache_features()
226 |
227 | def add_router(self, num_nfes):
228 | self.routers = torch.nn.ModuleList([
229 | Router(2*self.depth) for _ in range(num_nfes)
230 | ])
231 |
232 | def set_activate_cache(self, activate_cache):
233 | self.activate_cache = activate_cache
234 |
235 | def set_record_cache(self, record_cache):
236 | self.record_cache = record_cache
237 |
238 | def set_timestep_map(self, timestep_map):
239 | self.timestep_map = {timestep: i for i, timestep in enumerate(timestep_map)}
240 | print("Timestep -> Router IDX Map:", self.timestep_map)
241 |
242 | def _init_weights(self, m):
243 | if isinstance(m, nn.Linear):
244 | trunc_normal_(m.weight, std=.02)
245 | if isinstance(m, nn.Linear) and m.bias is not None:
246 | nn.init.constant_(m.bias, 0)
247 | elif isinstance(m, nn.LayerNorm):
248 | nn.init.constant_(m.bias, 0)
249 | nn.init.constant_(m.weight, 1.0)
250 |
251 | @torch.jit.ignore
252 | def no_weight_decay(self):
253 | return {'pos_embed'}
254 |
255 | def forward(self, x, timesteps, y=None):
256 | #print("In Model: Get y: ", y, ". Get Timesteps: ", timesteps)
257 | x = self.patch_embed(x)
258 | B, L, D = x.shape
259 |
260 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
261 |
262 | time_token = time_token.unsqueeze(dim=1)
263 | x = torch.cat((time_token, x), dim=1)
264 | if y is not None:
265 | label_emb = self.label_emb(y)
266 | label_emb = label_emb.unsqueeze(dim=1)
267 | x = torch.cat((label_emb, x), dim=1)
268 | x = x + self.pos_embed
269 |
270 | skips = []
271 | cache_features = self.cache_features
272 |         if self.activate_cache:
273 | router_idx = self.timestep_map[np.round(timesteps[0].item())]
274 | scores = self.routers[router_idx]()
275 | router_l1_loss = scores.sum()
276 | else:
277 | router_l1_loss = None
278 |
279 | layer_idx = 0
280 | for blk in self.in_blocks:
281 | if cache_features[layer_idx] is not None and self.activate_cache:
282 | reuse_att, reuse_mlp = cache_features[layer_idx]
283 | reuse_att_weight = 1 - scores[layer_idx*2]
284 | reuse_mlp_weight = 1 - scores[layer_idx*2+1]
285 | else:
286 | reuse_att, reuse_mlp = None, None
287 | reuse_att_weight, reuse_mlp_weight = 0, 0
288 |
289 | x, cache_feature = blk(
290 | x, reuse_att=reuse_att, reuse_mlp=reuse_mlp,
291 | reuse_att_weight=reuse_att_weight,
292 | reuse_mlp_weight=reuse_mlp_weight,
293 | )
294 | skips.append(x)
295 | if self.record_cache:
296 | cache_features[layer_idx] = cache_feature
297 | layer_idx += 1
298 |
299 | if cache_features[layer_idx] is not None and self.activate_cache:
300 | reuse_att, reuse_mlp = cache_features[layer_idx]
301 | reuse_att_weight = 1 - scores[layer_idx*2]
302 | reuse_mlp_weight = 1 - scores[layer_idx*2+1]
303 | else:
304 | reuse_att, reuse_mlp = None, None
305 | reuse_att_weight, reuse_mlp_weight = 0, 0
306 |
307 | x, cache_feature = self.mid_block(
308 | x, reuse_att=reuse_att, reuse_mlp=reuse_mlp,
309 | reuse_att_weight=reuse_att_weight,
310 | reuse_mlp_weight=reuse_mlp_weight,
311 | )
312 | if self.record_cache:
313 | cache_features[layer_idx] = cache_feature
314 | layer_idx += 1
315 |
316 | for blk in self.out_blocks:
317 | if cache_features[layer_idx] is not None and self.activate_cache:
318 | reuse_att, reuse_mlp = cache_features[layer_idx]
319 | reuse_att_weight = 1 - scores[layer_idx*2]
320 | reuse_mlp_weight = 1 - scores[layer_idx*2+1]
321 | else:
322 | reuse_att, reuse_mlp = None, None
323 | reuse_att_weight, reuse_mlp_weight = 0, 0
324 |
325 |             x, cache_feature = blk(
326 | x, skips.pop(), reuse_att=reuse_att, reuse_mlp=reuse_mlp,
327 | reuse_att_weight=reuse_att_weight,
328 | reuse_mlp_weight=reuse_mlp_weight,
329 | )
330 | if self.record_cache:
331 | cache_features[layer_idx] = cache_feature
332 | layer_idx += 1
333 |
334 | x = self.norm(x)
335 | x = self.decoder_pred(x)
336 | assert x.size(1) == self.extras + L
337 | x = x[:, self.extras:, :]
338 | x = unpatchify(x, self.in_chans)
339 | x = self.final_layer(x)
340 |
341 | self.cur_step_idx += 1
342 |
343 | if self.activate_cache:
344 | return x, router_l1_loss
345 | else:
346 | return x
347 |
--------------------------------------------------------------------------------
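
During router training (uvit_router.py above), caching is soft rather than binary: for a slot with sigmoid score s, the block output becomes new_out * s + cached_out * (1 - s) (the code stores the reuse weight 1 - s), and the sum of all scores is returned as router_l1_loss, presumably added to the training objective to encourage reuse. A minimal sketch of that blending and the L1 term for a single slot, using placeholder tensors:

import torch

logits = torch.nn.Parameter(torch.randn(2))   # hypothetical router logits for one step (att + mlp slots)
scores = torch.sigmoid(logits)                # s -> 1: always recompute, s -> 0: always reuse

new_att = torch.randn(1, 4, 8)                # freshly computed attention output (placeholder)
cached_att = torch.randn(1, 4, 8)             # attention output cached from the previous step

reuse_weight = 1 - scores[0]
blended = new_att * (1 - reuse_weight) + cached_att * reuse_weight  # = new*s + cached*(1-s)

router_l1_loss = scores.sum()                                       # pushes scores (and compute) down
loss = blended.pow(2).mean() + 0.1 * router_l1_loss                 # stand-in objective; 0.1 is arbitrary
loss.backward()
print(logits.grad)                                                  # gradients reach the router logits
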
/DiT/models/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17 |
18 |
19 | def modulate(x, shift, scale):
20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21 |
22 |
23 | #################################################################################
24 | # Embedding Layers for Timesteps and Class Labels #
25 | #################################################################################
26 |
27 | class TimestepEmbedder(nn.Module):
28 | """
29 | Embeds scalar timesteps into vector representations.
30 | """
31 | def __init__(self, hidden_size, frequency_embedding_size=256):
32 | super().__init__()
33 | self.mlp = nn.Sequential(
34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
35 | nn.SiLU(),
36 | nn.Linear(hidden_size, hidden_size, bias=True),
37 | )
38 | self.frequency_embedding_size = frequency_embedding_size
39 |
40 | @staticmethod
41 | def timestep_embedding(t, dim, max_period=10000):
42 | """
43 | Create sinusoidal timestep embeddings.
44 | :param t: a 1-D Tensor of N indices, one per batch element.
45 | These may be fractional.
46 | :param dim: the dimension of the output.
47 | :param max_period: controls the minimum frequency of the embeddings.
48 | :return: an (N, D) Tensor of positional embeddings.
49 | """
50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51 | half = dim // 2
52 | freqs = torch.exp(
53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
54 | ).to(device=t.device)
55 | args = t[:, None].float() * freqs[None]
56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57 | if dim % 2:
58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59 | return embedding
60 |
61 | def forward(self, t):
62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63 | t_emb = self.mlp(t_freq)
64 | return t_emb
65 |
66 |
67 | class LabelEmbedder(nn.Module):
68 | """
69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
70 | """
71 | def __init__(self, num_classes, hidden_size, dropout_prob):
72 | super().__init__()
73 | use_cfg_embedding = dropout_prob > 0
74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
75 | self.num_classes = num_classes
76 | self.dropout_prob = dropout_prob
77 |
78 | def token_drop(self, labels, force_drop_ids=None):
79 | """
80 | Drops labels to enable classifier-free guidance.
81 | """
82 | if force_drop_ids is None:
83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
84 | else:
85 | drop_ids = force_drop_ids == 1
86 | labels = torch.where(drop_ids, self.num_classes, labels)
87 | return labels
88 |
89 | def forward(self, labels, train, force_drop_ids=None):
90 | use_dropout = self.dropout_prob > 0
91 | if (train and use_dropout) or (force_drop_ids is not None):
92 | labels = self.token_drop(labels, force_drop_ids)
93 | embeddings = self.embedding_table(labels)
94 | return embeddings
95 |
96 |
97 | #################################################################################
98 | # Core DiT Model #
99 | #################################################################################
100 |
101 | class DiTBlock(nn.Module):
102 | """
103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
104 | """
105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
106 | super().__init__()
107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
111 | approx_gelu = lambda: nn.GELU(approximate="tanh")
112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
113 | self.adaLN_modulation = nn.Sequential(
114 | nn.SiLU(),
115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
116 | )
117 |
118 | def forward(self, x, c):
119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
122 | return x
123 |
124 |
125 | class FinalLayer(nn.Module):
126 | """
127 | The final layer of DiT.
128 | """
129 | def __init__(self, hidden_size, patch_size, out_channels):
130 | super().__init__()
131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
133 | self.adaLN_modulation = nn.Sequential(
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
136 | )
137 |
138 | def forward(self, x, c):
139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
140 | x = modulate(self.norm_final(x), shift, scale)
141 | x = self.linear(x)
142 | return x
143 |
144 |
145 | class DiT(nn.Module):
146 | """
147 | Diffusion model with a Transformer backbone.
148 | """
149 | def __init__(
150 | self,
151 | input_size=32,
152 | patch_size=2,
153 | in_channels=4,
154 | hidden_size=1152,
155 | depth=28,
156 | num_heads=16,
157 | mlp_ratio=4.0,
158 | class_dropout_prob=0.1,
159 | num_classes=1000,
160 | learn_sigma=True,
161 | ):
162 | super().__init__()
163 | self.learn_sigma = learn_sigma
164 | self.in_channels = in_channels
165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
166 | self.patch_size = patch_size
167 | self.num_heads = num_heads
168 |
169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
170 | self.t_embedder = TimestepEmbedder(hidden_size)
171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
172 | num_patches = self.x_embedder.num_patches
173 | # Will use fixed sin-cos embedding:
174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
175 |
176 | self.blocks = nn.ModuleList([
177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
178 | ])
179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
180 | self.initialize_weights()
181 |
182 | def reset(self):
183 | pass
184 |
185 | def initialize_weights(self):
186 | # Initialize transformer layers:
187 | def _basic_init(module):
188 | if isinstance(module, nn.Linear):
189 | torch.nn.init.xavier_uniform_(module.weight)
190 | if module.bias is not None:
191 | nn.init.constant_(module.bias, 0)
192 | self.apply(_basic_init)
193 |
194 | # Initialize (and freeze) pos_embed by sin-cos embedding:
195 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
196 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
197 |
198 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
199 | w = self.x_embedder.proj.weight.data
200 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
201 | nn.init.constant_(self.x_embedder.proj.bias, 0)
202 |
203 | # Initialize label embedding table:
204 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
205 |
206 | # Initialize timestep embedding MLP:
207 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
208 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
209 |
210 | # Zero-out adaLN modulation layers in DiT blocks:
211 | for block in self.blocks:
212 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
213 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
214 |
215 | # Zero-out output layers:
216 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
217 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
218 | nn.init.constant_(self.final_layer.linear.weight, 0)
219 | nn.init.constant_(self.final_layer.linear.bias, 0)
220 |
221 | def unpatchify(self, x):
222 | """
223 | x: (N, T, patch_size**2 * C)
224 | imgs: (N, C, H, W)
225 | """
226 | c = self.out_channels
227 | p = self.x_embedder.patch_size[0]
228 | h = w = int(x.shape[1] ** 0.5)
229 | assert h * w == x.shape[1]
230 |
231 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
232 | x = torch.einsum('nhwpqc->nchpwq', x)
233 | imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
234 | return imgs
235 |
236 | def forward(self, x, t, y):
237 | """
238 | Forward pass of DiT.
239 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
240 | t: (N,) tensor of diffusion timesteps
241 | y: (N,) tensor of class labels
242 | """
243 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
244 | t = self.t_embedder(t) # (N, D)
245 | y = self.y_embedder(y, self.training) # (N, D)
246 | c = t + y # (N, D)
247 | for block in self.blocks:
248 | x = block(x, c) # (N, T, D)
249 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
250 | x = self.unpatchify(x) # (N, out_channels, H, W)
251 | return x
252 |
253 | def forward_with_cfg(self, x, t, y, cfg_scale):
254 | """
255 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
256 | """
257 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
258 | half = x[: len(x) // 2]
259 | combined = torch.cat([half, half], dim=0)
260 | model_out = self.forward(combined, t, y)
261 | # For exact reproducibility reasons, we apply classifier-free guidance on only
262 | # three channels by default. The standard approach to cfg applies it to all channels.
263 | # This can be done by uncommenting the following line and commenting-out the line following that.
264 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
265 | eps, rest = model_out[:, :3], model_out[:, 3:]
266 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
267 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
268 | eps = torch.cat([half_eps, half_eps], dim=0)
269 | return torch.cat([eps, rest], dim=1)
270 |
271 |
272 | #################################################################################
273 | # Sine/Cosine Positional Embedding Functions #
274 | #################################################################################
275 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
276 |
277 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
278 | """
279 | grid_size: int of the grid height and width
280 | return:
281 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
282 | """
283 | grid_h = np.arange(grid_size, dtype=np.float32)
284 | grid_w = np.arange(grid_size, dtype=np.float32)
285 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
286 | grid = np.stack(grid, axis=0)
287 |
288 | grid = grid.reshape([2, 1, grid_size, grid_size])
289 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
290 | if cls_token and extra_tokens > 0:
291 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
292 | return pos_embed
293 |
294 |
295 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
296 | assert embed_dim % 2 == 0
297 |
298 | # use half of dimensions to encode grid_h
299 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
300 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
301 |
302 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
303 | return emb
304 |
305 |
306 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
307 | """
308 | embed_dim: output dimension for each position
309 | pos: a list of positions to be encoded: size (M,)
310 | out: (M, D)
311 | """
312 | assert embed_dim % 2 == 0
313 | omega = np.arange(embed_dim // 2, dtype=np.float64)
314 | omega /= embed_dim / 2.
315 | omega = 1. / 10000**omega # (D/2,)
316 |
317 | pos = pos.reshape(-1) # (M,)
318 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
319 |
320 | emb_sin = np.sin(out) # (M, D/2)
321 | emb_cos = np.cos(out) # (M, D/2)
322 |
323 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
324 | return emb
325 |
326 |
327 | #################################################################################
328 | # DiT Configs #
329 | #################################################################################
330 |
331 | def DiT_XL_2(**kwargs):
332 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
333 |
334 | def DiT_XL_4(**kwargs):
335 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
336 |
337 | def DiT_XL_8(**kwargs):
338 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
339 |
340 | def DiT_L_2(**kwargs):
341 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
342 |
343 | def DiT_L_4(**kwargs):
344 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
345 |
346 | def DiT_L_8(**kwargs):
347 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
348 |
349 | def DiT_B_2(**kwargs):
350 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
351 |
352 | def DiT_B_4(**kwargs):
353 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
354 |
355 | def DiT_B_8(**kwargs):
356 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
357 |
358 | def DiT_S_2(**kwargs):
359 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
360 |
361 | def DiT_S_4(**kwargs):
362 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
363 |
364 | def DiT_S_8(**kwargs):
365 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
366 |
367 |
368 | DiT_models = {
369 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
370 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
371 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
372 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
373 | }
374 |
--------------------------------------------------------------------------------
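Note (usage sketch, not part of the repository): the helpers at the bottom of this file only select a backbone size; a minimal smoke test of the vanilla DiT forward pass, assuming it is run from the DiT/ directory so that `models.models` is importable:

    import torch
    from models.models import DiT_S_2

    model = DiT_S_2(input_size=32, in_channels=4, num_classes=1000).eval()
    x = torch.randn(2, 4, 32, 32)        # VAE latents (N, C, H, W)
    t = torch.randint(0, 1000, (2,))     # diffusion timesteps
    y = torch.randint(0, 1000, (2,))     # class labels
    with torch.no_grad():
        out = model(x, t, y)
    print(out.shape)                     # (2, 8, 32, 32): eps and sigma, since learn_sigma=True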
/U-ViT/train_router_discrete.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | from datasets import get_dataset
5 | from torchvision.utils import make_grid, save_image
6 | import utils
7 | import einops
8 | from torch.utils._pytree import tree_map
9 | import accelerate
10 | from accelerate import DistributedDataParallelKwargs
11 | from torch.utils.data import DataLoader
12 | from tqdm.auto import tqdm
13 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
14 | import tempfile
15 | from tools.fid_score import calculate_fid_given_paths
16 | from absl import logging
17 | import builtins
18 | import os
19 | import wandb
20 | import libs.autoencoder
21 | import numpy as np
22 |
23 |
24 | def format_image_to_wandb(num_router, router_size, router_scores):
25 | image = np.zeros((num_router, router_size, 3), dtype=np.float32)
26 | ones = np.ones((3), dtype=np.float32)
27 | for idx, score in enumerate(router_scores):
28 | mask = score.cpu().detach()
29 | for pos in range(router_size):
30 | image[idx, pos] = ones * mask[pos].item()
31 | return image
32 |
33 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
34 | _betas = (
35 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
36 | )
37 | return _betas.numpy()
38 |
39 |
40 | def get_skip(alphas, betas):
41 | N = len(betas) - 1
42 | skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
43 | for s in range(N + 1):
44 | skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
45 | skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
46 | for t in range(N + 1):
47 | prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
48 | skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
49 | return skip_alphas, skip_betas
50 |
51 |
52 | def stp(s, ts: torch.Tensor): # scalar tensor product
53 | if isinstance(s, np.ndarray):
54 | s = torch.from_numpy(s).type_as(ts)
55 | extra_dims = (1,) * (ts.dim() - 1)
56 | return s.view(-1, *extra_dims) * ts
57 |
58 |
59 | def mos(a, start_dim=1): # mean of square
60 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
61 |
62 | def sos(a, start_dim=1): # sum of square
63 | e = a.pow(2).flatten(start_dim=start_dim)
64 | return e.sum(dim=-1)
65 |
66 |
67 | class Schedule(object): # discrete time
68 | def __init__(self, _betas):
69 | r""" _betas[0...999] = betas[1...1000]
70 | for n>=1, betas[n] is the variance of q(xn|xn-1)
71 | for n=0, betas[0]=0
72 | """
73 |
74 | self._betas = _betas
75 | self.betas = np.append(0., _betas)
76 | self.alphas = 1. - self.betas
77 | self.N = len(_betas)
78 |
79 | assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
80 | assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
81 | assert len(self.betas) == len(self.alphas)
82 |
83 | # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
84 | self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
85 | self.cum_alphas = self.skip_alphas[0] # cum_alphas = alphas.cumprod()
86 | self.cum_betas = self.skip_betas[0]
87 | self.snr = self.cum_alphas / self.cum_betas
88 |
89 | def tilde_beta(self, s, t):
90 | return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]
91 |
92 | def sample(self, x0): # sample from q(xn|x0), where n is uniform
93 | n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
94 | eps = torch.randn_like(x0)
95 | xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
96 | return torch.tensor(n, device=x0.device), eps, xn
97 |
98 | def get_xn(self, x0, n):
99 | eps = torch.randn_like(x0)
100 | xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
101 | return torch.tensor(n, device=x0.device), eps, xn
102 |
103 | def __repr__(self):
104 | return f'Schedule({self.betas[:10]}..., {self.N})'
105 |
106 |
107 | def LSimple(x0, nnet, schedule, **kwargs):
108 |
109 | n, eps, xn = schedule.sample(x0) # n in {1, ..., 1000}
110 | eps_pred = nnet(xn, n, **kwargs)
111 | return mos(eps - eps_pred)
112 |
113 |
114 | def LRouter(x0, nnet, schedule, order=None, timesteps=None, dpm_solver=None, **kwargs):
115 | #print(x0.shape)
116 | #print(order, timesteps)
117 |
118 | def model_fn(x, t_continuous):
119 | t = t_continuous * 1000
120 | eps_pre = nnet(x, t, **kwargs)
121 | return eps_pre
122 | dpm_solver.model = model_fn
123 | nnet.module.reset_cache_features()
124 | random_step = np.random.randint(0, len(order)-1)
125 | random_t = np.round(timesteps[random_step] * 1000).astype(int).repeat(x0.shape[0])
126 |
127 | #print(random_t)
128 | _, _, xn = schedule.get_xn(x0, random_t)
129 | vec_s = torch.ones((xn.shape[0],)).to(xn.device) * timesteps[random_step]
130 | vec_t = torch.ones((xn.shape[0],)).to(xn.device) * timesteps[random_step + 1]
131 | with torch.no_grad():
132 | xn_minus_1 = dpm_solver.dpm_solver_second_update(xn, vec_s, vec_t, return_noise=False, solver_type='dpm_solver')
133 |
134 | random_t_minus_1 = np.round(timesteps[random_step + 1] * 1000).astype(int).repeat(x0.shape[0])
135 | random_t_minus_1 = torch.tensor(random_t_minus_1).to(xn_minus_1.device)
136 |
137 | # Teacher
138 | nnet.module.set_activate_cache(False)
139 | nnet.module.set_record_cache(False)
140 | t_pred = nnet(xn_minus_1, random_t_minus_1, **kwargs)
141 |
142 | # Student
143 | nnet.module.set_activate_cache(True)
144 |
145 | s_pred, l1_loss = nnet(xn_minus_1, random_t_minus_1, **kwargs)
146 |
147 | nnet.module.set_activate_cache(False)
148 | nnet.module.set_record_cache(True)
149 |
150 | return sos(t_pred - s_pred), l1_loss
151 |
152 |
153 | def train(config):
154 | if config.get('benchmark', False):
155 | torch.backends.cudnn.benchmark = True
156 | torch.backends.cudnn.deterministic = False
157 |
158 | mp.set_start_method('spawn')
159 |
160 |
161 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
162 | accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs])
163 | #accelerator = accelerate.Accelerator()
164 | device = accelerator.device
165 | accelerate.utils.set_seed(config.seed, device_specific=True)
166 | logging.info(f'Process {accelerator.process_index} using device: {device}')
167 |
168 | config.mixed_precision = accelerator.mixed_precision
169 | config = ml_collections.FrozenConfigDict(config)
170 |
171 | assert config.train.batch_size % accelerator.num_processes == 0
172 | mini_batch_size = config.train.batch_size // accelerator.num_processes
173 |
174 | if accelerator.is_main_process:
175 | os.makedirs(config.ckpt_root, exist_ok=True)
176 | os.makedirs(config.sample_dir, exist_ok=True)
177 | accelerator.wait_for_everyone()
178 | if accelerator.is_main_process:
179 | wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
180 | name=config.hparams, job_type='train')#, mode='offline')
181 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
182 | logging.info(config)
183 | else:
184 | utils.set_logger(log_level='error')
185 | builtins.print = lambda *args: None
186 | logging.info(f'Run on {accelerator.num_processes} devices')
187 |
188 | # Load Dataset
189 | dataset = get_dataset(**config.dataset)
190 | assert os.path.exists(dataset.fid_stat)
191 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
192 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
193 | num_workers=8, pin_memory=True, persistent_workers=True)
194 |
195 | # Load Model and Optimizer
196 | train_state = utils.initialize_train_state(config, device)
197 | train_state.nnet.add_router(config.nfe)
198 | router_optim = torch.optim.AdamW(
199 | [param for name, param in train_state.nnet.named_parameters() if "routers" in name],
200 | lr=config.router_lr, weight_decay=0
201 | )
202 | train_state.update_optimizer(router_optim)
203 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
204 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
205 | logging.info(f'load nnet from {config.nnet_path}')
206 | msg = accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'), strict=False)
207 | logging.info(f'load nnet message: {msg}')
208 |
209 |
210 | lr_scheduler = train_state.lr_scheduler
211 | train_state.resume(config.ckpt_root)
212 |
213 | # Load Autoencoder
214 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
215 | autoencoder.to(device)
216 |
217 | # Setup DPM Solver
218 | _betas = stable_diffusion_beta_schedule()
219 | _schedule = Schedule(_betas)
220 | logging.info(f'use {_schedule}')
221 |
222 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
223 | dpm_solver = DPM_Solver(None, noise_schedule, predict_x0=True, thresholding=False)
224 | t_0 = 1. / _schedule.N
225 | t_T = 1.0
226 | order_value = 2
227 | N_steps = config.nfe // order_value
228 | order = [order_value,] * N_steps
229 | timesteps = dpm_solver.get_time_steps(
230 | skip_type='time_uniform', t_T=t_T, t_0=t_0, N=N_steps, device=device
231 | )
232 | timesteps = timesteps.cpu().numpy()
233 | timestep_mapping = np.round(timesteps * 1000)
234 | accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping)
235 |
236 | @ torch.cuda.amp.autocast()
237 | def encode(_batch):
238 | return autoencoder.encode(_batch)
239 |
240 | @ torch.cuda.amp.autocast()
241 | def decode(_batch):
242 | return autoencoder.decode(_batch)
243 |
244 | def get_data_generator():
245 | while True:
246 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
247 | yield data
248 |
249 | data_generator = get_data_generator()
250 |
251 |
252 | def train_step(_batch):
253 | _metrics = dict()
254 | optimizer.zero_grad()
255 | if config.train.mode == 'uncond':
256 | _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
257 | data_loss, l1_loss = LRouter(_z, nnet, _schedule, order=order, timesteps=timesteps, dpm_solver=dpm_solver, l1_weight=config.l1_weight)
258 | elif config.train.mode == 'cond':
259 | #print("Label = ", _batch[1])
260 | _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
261 | data_loss, l1_loss = LRouter(_z, nnet, _schedule, y=_batch[1], order=order, timesteps=timesteps, dpm_solver=dpm_solver)
262 | loss = data_loss + config.l1_weight * l1_loss
263 | else:
264 | raise NotImplementedError(config.train.mode)
265 | _metrics['loss'] = accelerator.gather(loss.detach()).mean()
266 | _metrics['data_loss'] = accelerator.gather(data_loss.detach()).mean()
267 | _metrics['l1_loss'] = accelerator.gather(l1_loss.detach()).mean()
268 |
269 |
270 | accelerator.backward(loss.mean())
271 | optimizer.step()
272 | lr_scheduler.step()
273 | train_state.step += 1
274 |
275 | #print("Router 0:", nnet.module.routers[0].prob.data)
276 | #print("Router 1:", nnet.module.routers[1].prob.data)
277 | #print()
278 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
279 |
280 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
281 |
282 | loss_metrics = 0
283 | data_loss_metrics = 0
284 | l1_loss_metrics = 0
285 | while train_state.step < config.train.n_steps:
286 | nnet.train()
287 | batch = tree_map(lambda x: x.to(device), next(data_generator))
288 | metrics = train_step(batch)
289 |
290 | if accelerator.is_main_process:
291 | loss_metrics += metrics['loss']
292 | data_loss_metrics += metrics['data_loss']
293 | l1_loss_metrics += metrics['l1_loss']
294 |
295 | nnet.eval()
296 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
297 | scores = [nnet.module.routers[idx]() for idx in range(1, config.nfe//2)]
298 | mask = format_image_to_wandb(config.nfe//2-1, nnet.module.depth*2, scores)
299 | mask = wandb.Image(
300 | mask,
301 | )
302 | metrics['loss'] = loss_metrics / config.train.log_interval
303 | metrics['data_loss'] = data_loss_metrics / config.train.log_interval
304 | metrics['l1_loss'] = l1_loss_metrics / config.train.log_interval
305 | final_score = [sum(score) for score in scores]
306 | metrics['non_zero'] = sum(final_score) / (len(final_score) * len(scores[0]))
307 |
308 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
309 | metrics['router'] = mask
310 | #logging.info(config.workdir)
311 | wandb.log(metrics, step=train_state.step)
312 | loss_metrics, data_loss_metrics, l1_loss_metrics = 0, 0, 0
313 |
314 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
315 | torch.cuda.empty_cache()
316 | logging.info(f'Save and eval checkpoint {train_state.step}...')
317 | if accelerator.local_process_index == 0:
318 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
319 | accelerator.wait_for_everyone()
320 | #fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint
321 | #step_fid.append((train_state.step, fid))
322 | torch.cuda.empty_cache()
323 | accelerator.wait_for_everyone()
324 |
325 | logging.info(f'Finish fitting, step={train_state.step}')
326 |
327 |
328 |
329 | from absl import flags
330 | from absl import app
331 | from ml_collections import config_flags
332 | import sys
333 | from pathlib import Path
334 |
335 |
336 | FLAGS = flags.FLAGS
337 | config_flags.DEFINE_config_file(
338 | "config", None, "Training configuration.", lock_config=False)
339 | flags.mark_flags_as_required(["config"])
340 | flags.DEFINE_string("workdir", None, "Work unit directory.")
341 | flags.DEFINE_string("nfe", None, "NFE")
342 | flags.DEFINE_string("router_lr", None, "learning rate for router")
343 | flags.DEFINE_string("l1_weight", None, "l1 weight for router loss")
344 | flags.DEFINE_string("nnet_path", None, "l1 weight for router loss")
345 |
346 |
347 |
348 | def get_config_name():
349 | argv = sys.argv
350 | for i in range(1, len(argv)):
351 | if argv[i].startswith('--config='):
352 | return Path(argv[i].split('=')[-1]).stem
353 |
354 |
355 | def get_hparams():
356 | argv = sys.argv
357 | lst = []
358 | for i in range(1, len(argv)):
359 | assert '=' in argv[i]
360 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
361 | hparam, val = argv[i].split('=')
362 | hparam = hparam.split('.')[-1]
363 | if hparam.endswith('path'):
364 | val = Path(val).stem
365 | lst.append(f'{hparam}={val}')
366 | hparams = '-'.join(lst)
367 | if hparams == '':
368 | hparams = 'default'
369 | return hparams
370 |
371 |
372 | def main(argv):
373 | config = FLAGS.config
374 | config.nfe = int(FLAGS.nfe)
375 | config.router_lr = float(FLAGS.router_lr)
376 | config.l1_weight = float(FLAGS.l1_weight)
377 | config.nnet_path = FLAGS.nnet_path
378 | config.config_name = get_config_name()
379 | config.hparams = get_hparams()
380 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
381 | config.ckpt_root = os.path.join(config.workdir, 'ckpts')
382 | config.sample_dir = os.path.join(config.workdir, 'samples')
383 | train(config)
384 |
385 |
386 | if __name__ == "__main__":
387 | app.run(main)
388 |
--------------------------------------------------------------------------------
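Note (illustrative sketch, not part of the repository): the DPM-Solver setup in `train()` above converts the solver's continuous times into the 1000 discrete steps that index the routers via `set_timestep_map`. Assuming the `time_uniform` skip type is a uniform grid in t, as in the DPM-Solver reference implementation, the mapping for `--nfe 20` can be reproduced with NumPy alone:

    import numpy as np

    nfe, order_value = 20, 2
    N_steps = nfe // order_value
    t_T, t_0 = 1.0, 1.0 / 1000
    timesteps = np.linspace(t_T, t_0, N_steps + 1)   # continuous times, high -> low
    timestep_mapping = np.round(timesteps * 1000)    # discrete indices handed to set_timestep_map
    # timestep_mapping is approximately [1000, 900, 800, ..., 101, 1]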
/DiT/train_router.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | A minimal training script for DiT using PyTorch DDP.
9 | """
10 | import torch
11 | # the first flag below was False when we tested this script but True makes A100 training a lot faster:
12 | torch.backends.cuda.matmul.allow_tf32 = True
13 | torch.backends.cudnn.allow_tf32 = True
14 | import torch.distributed as dist
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data.distributed import DistributedSampler
18 | from torchvision.datasets import ImageFolder
19 | from torchvision import transforms
20 | import numpy as np
21 | from collections import OrderedDict
22 | from PIL import Image
23 | from copy import deepcopy
24 | from glob import glob
25 | from time import time
26 | import argparse
27 | import logging
28 | import os
29 |
30 | from models.router_models import DiT_models, STE
31 | from diffusion import create_diffusion
32 | from diffusers.models import AutoencoderKL
33 | from download import find_model
34 |
35 |
36 | #################################################################################
37 | # Training Helper Functions #
38 | #################################################################################
39 |
40 | @torch.no_grad()
41 | def update_ema(ema_model, model, decay=0.9999):
42 | """
43 | Step the EMA model towards the current model.
44 | """
45 | ema_params = OrderedDict(ema_model.named_parameters())
46 | model_params = OrderedDict(model.named_parameters())
47 |
48 | for name, param in model_params.items():
49 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
50 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
51 |
52 |
53 | def requires_grad(model, flag=True):
54 | """
55 | Set requires_grad flag for all parameters in a model.
56 | """
57 | for p in model.parameters():
58 | p.requires_grad = flag
59 |
60 |
61 | def cleanup():
62 | """
63 | End DDP training.
64 | """
65 | dist.destroy_process_group()
66 |
67 |
68 | def create_logger(logging_dir):
69 | """
70 | Create a logger that writes to a log file and stdout.
71 | """
72 | if dist.get_rank() == 0: # real logger
73 | logging.basicConfig(
74 | level=logging.INFO,
75 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
76 | datefmt='%Y-%m-%d %H:%M:%S',
77 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
78 | )
79 | logger = logging.getLogger(__name__)
80 | else: # dummy logger (does nothing)
81 | logger = logging.getLogger(__name__)
82 | logger.addHandler(logging.NullHandler())
83 | return logger
84 |
85 | def format_image_to_wandb(num_router, router_size, router_scores):
86 | image = np.zeros((num_router, router_size, 3), dtype=np.float32)
87 | ones = np.ones((3), dtype=np.float32)
88 | for idx, score in enumerate(router_scores):
89 | mask = score.cpu().detach()
90 | for pos in range(router_size):
91 | image[idx, pos] = ones * mask[pos].item()
92 | return image
93 |
94 |
95 | def center_crop_arr(pil_image, image_size):
96 | """
97 | Center cropping implementation from ADM.
98 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
99 | """
100 | while min(*pil_image.size) >= 2 * image_size:
101 | pil_image = pil_image.resize(
102 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
103 | )
104 |
105 | scale = image_size / min(*pil_image.size)
106 | pil_image = pil_image.resize(
107 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
108 | )
109 |
110 | arr = np.array(pil_image)
111 | crop_y = (arr.shape[0] - image_size) // 2
112 | crop_x = (arr.shape[1] - image_size) // 2
113 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
114 |
115 |
116 | #################################################################################
117 | # Training Loop #
118 | #################################################################################
119 |
120 | def main(args):
121 | """
122 | Trains a new DiT model.
123 | """
124 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
125 |
126 | # Setup DDP:
127 | dist.init_process_group("nccl")
128 | assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
129 | rank = dist.get_rank()
130 | device = rank % torch.cuda.device_count()
131 | seed = args.global_seed * dist.get_world_size() + rank
132 | torch.manual_seed(seed)
133 | torch.cuda.set_device(device)
134 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
135 |
136 | # Setup an experiment folder:
137 | if rank == 0:
138 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
139 | experiment_index = len(glob(f"{args.results_dir}/*"))
140 | model_string_name = args.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
141 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
142 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
143 | os.makedirs(checkpoint_dir, exist_ok=True)
144 | logger = create_logger(experiment_dir)
145 | logger.info(f"Experiment directory created at {experiment_dir}")
146 | else:
147 | logger = create_logger(None)
148 |
149 | # Create model:
150 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
151 | latent_size = args.image_size // 8
152 |
153 |
154 | model = DiT_models[args.model](
155 | input_size=latent_size,
156 | num_classes=args.num_classes
157 | ).to(device)
158 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py:
159 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
160 | state_dict = find_model(ckpt_path)
161 | msg = model.load_state_dict(state_dict, strict=False)
162 | if rank == 0:
163 | logger.info(f"Loaded model from {ckpt_path} with msg: {msg}")
164 | model.eval() # important!
165 |
166 | diffusion = create_diffusion(str(args.num_sampling_steps))
167 | model.add_router(args.num_sampling_steps, diffusion.timestep_map)
168 | model = DDP(model.to(device), device_ids=[rank], find_unused_parameters=True)
169 |
170 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
171 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
172 |
173 | #routers = [Router(len(model.module.blocks)*2) for _ in range(args.num_sampling_steps//2)]
174 | #routers = [DDP(r.to(device), device_ids=[rank]) for r in routers]
175 | opts = torch.optim.AdamW(
176 | [param for name, param in model.named_parameters() if "routers" in name],
177 | lr=args.lr, weight_decay=0
178 | )
179 |
180 | # Setup data:
181 | transform = transforms.Compose([
182 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
183 | transforms.RandomHorizontalFlip(),
184 | transforms.ToTensor(),
185 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
186 | ])
187 | dataset = ImageFolder(args.data_path, transform=transform)
188 | sampler = DistributedSampler(
189 | dataset,
190 | num_replicas=dist.get_world_size(),
191 | rank=rank,
192 | shuffle=True,
193 | seed=args.global_seed
194 | )
195 | loader = DataLoader(
196 | dataset,
197 | batch_size=int(args.global_batch_size // dist.get_world_size()),
198 | shuffle=False,
199 | sampler=sampler,
200 | num_workers=args.num_workers,
201 | pin_memory=True,
202 | drop_last=True
203 | )
204 | logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
205 |
206 | if args.wandb and rank == 0:
207 | import wandb
208 | wandb.init(
209 | # Set the project where this run will be logged
210 | project="DiT-Router",
211 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
212 | name=f"{experiment_index:03d}-{model_string_name}",
213 | # Track hyperparameters and run metadata
214 | config=args.__dict__
215 | )
216 | wandb.define_metric("step")
217 | wandb.define_metric("loss", step_metric="step")
218 |
219 | # Prepare models for training:
220 | #update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
221 | model.train() # important! We need to use embedding dropout for classifier-free guidance here.
222 | #ema.eval() # EMA model should always be in eval mode
223 |
224 | # Variables for monitoring/logging purposes:
225 | train_steps = 0
226 | log_steps = 0
227 | running_loss = 0
228 | running_data_loss, running_l1_loss = 0, 0
229 | start_time = time()
230 |
231 | logger.info(f"Training for {args.epochs} epochs...")
232 | for epoch in range(args.epochs):
233 | sampler.set_epoch(epoch)
234 | logger.info(f"Beginning epoch {epoch}...")
235 | for x, y in loader:
236 | x = x.to(device)
237 | y = y.to(device)
238 |
239 | with torch.no_grad():
240 | # Map input images to latent space + normalize latents:
241 | x = vae.encode(x).latent_dist.sample().mul_(0.18215)
242 | model_kwargs = dict(y=y, thres=args.ste_threshold)
243 |
244 | #t = 1+2*torch.randint(0, diffusion.num_timesteps//2, (x.shape[0],), device=device)
245 | t = torch.randint(0, diffusion.num_timesteps//2, (1,), device=device)
246 | #t = torch.tensor(2, device=device)
247 | ts = t.repeat(x.shape[0])*2 + 1
248 |
249 | loss_dict = diffusion.router_training_losses(model, x, ts, model_kwargs)
250 | data_loss = loss_dict["mse"].mean()
251 | l1_loss = loss_dict["l1_loss"].mean()
252 |
253 | #print(f"Rank: {rank}, t: {t}, data loss: {data_loss}. L1 loss: {l1_loss}")
254 | loss = data_loss + args.l1 * l1_loss
255 | opts.zero_grad()
256 | model.zero_grad()
257 |
258 | loss.backward()
259 | #for idx, router in enumerate(model.module.routers):
260 | # print(f"Rank: {rank}, idx: {idx}, ", router.prob.grad)
261 | opts.step()
262 |
263 | with torch.no_grad():
264 | for name, param in model.named_parameters():
265 | if "routers" in name:
266 | param.clamp_(-5, 5)
267 |
268 | # Log loss values:
269 | running_loss += loss.item()
270 | running_data_loss += data_loss.item()
271 | running_l1_loss += args.l1 * l1_loss.item()
272 |
273 | log_steps += 1
274 | train_steps += 1
275 |
276 | model.module.reset()
277 |
278 | if train_steps % args.log_every == 0:
279 | # Measure training speed:
280 | torch.cuda.synchronize()
281 | end_time = time()
282 | steps_per_sec = log_steps / (end_time - start_time)
283 |
284 | # Reduce loss history over all processes:
285 | for name, loss in [("loss", running_loss), ("data_loss", running_data_loss), ("l1_loss", running_l1_loss)]:
286 | loss = torch.tensor(loss / log_steps, device=device)
287 | dist.all_reduce(loss, op=dist.ReduceOp.SUM)
288 | loss = loss.item() / dist.get_world_size()
289 | logger.info(f"(step={train_steps:07d}) Train {name} Loss: {loss:.7f}, Train Steps/Sec: {steps_per_sec:.2f}")
290 |
291 | scores = [model.module.routers[idx]() for idx in range(0, args.num_sampling_steps, 2)]
292 |
293 | if args.wandb and rank == 0:
294 | #print(scores)
295 | mask = format_image_to_wandb(args.num_sampling_steps//2, model.module.depth*2, scores)
296 | mask = wandb.Image(
297 | mask,
298 | )
299 | if args.ste_threshold is not None:
300 | final_score = [sum(STE.apply(score, args.ste_threshold)) for score in scores]
301 | else:
302 | final_score = [sum(score) for score in scores]
303 | wandb.log({
304 | "step": train_steps,
305 | "loss": loss,
306 | "data_loss": running_data_loss / log_steps,
307 | "l1_loss": running_l1_loss / log_steps,
308 | "non_zero": sum(final_score),
309 | "router": mask
310 | })
311 |
312 | # Reset monitoring variables:
313 | running_loss = 0
314 | running_data_loss, running_l1_loss = 0, 0
315 | log_steps = 0
316 | start_time = time()
317 |
318 | # Save DiT checkpoint:
319 | if train_steps % args.ckpt_every == 0 and train_steps > 0:
320 | if rank == 0:
321 | checkpoint = {
322 | #"model": model.module.state_dict(),
323 | #"ema": ema.state_dict(),
324 | "routers": model.module.routers.state_dict(),
325 | "opt": opts.state_dict(),
326 | "args": args
327 | }
328 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
329 | torch.save(checkpoint, checkpoint_path)
330 | logger.info(f"Saved checkpoint to {checkpoint_path}")
331 | dist.barrier()
332 |
333 | if train_steps > args.max_steps:
334 | print("Reach Maximum Step")
335 | break
336 |
337 | model.eval() # important! This disables randomized embedding dropout
338 | # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
339 |
340 | logger.info("Done!")
341 | cleanup()
342 |
343 |
344 | if __name__ == "__main__":
345 | # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
346 | parser = argparse.ArgumentParser()
347 | parser.add_argument("--data-path", type=str, required=True)
348 | parser.add_argument("--results-dir", type=str, default="results")
349 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
350 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
351 | parser.add_argument("--num-classes", type=int, default=1000)
352 | parser.add_argument("--epochs", type=int, default=1)
353 | parser.add_argument("--global-batch-size", type=int, default=256)
354 | parser.add_argument("--global-seed", type=int, default=0)
355 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training
356 | parser.add_argument("--num-workers", type=int, default=4)
357 | parser.add_argument("--log-every", type=int, default=100)
358 | parser.add_argument("--ckpt-every", type=int, default=50_000)
359 | parser.add_argument("--wandb", action="store_true")
360 |
361 | parser.add_argument("--ckpt", type=str, default=None)
362 | #parser.add_argument("--cfg-scale", type=float, required=True)
363 | parser.add_argument("--num-sampling-steps", type=int, default=20)
364 | parser.add_argument("--l1", type=float, default=1.0)
365 |
366 | parser.add_argument("--lr", type=float, default=1.0)
367 | parser.add_argument("--max-steps", type=int, default=50000)
368 |
369 | parser.add_argument("--ste-threshold", type=float, default=None)
370 |
371 | args = parser.parse_args()
372 | main(args)
373 |
--------------------------------------------------------------------------------
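Note (usage sketch, not part of the repository): checkpoints written by the loop above store only the router weights (plus optimizer state and args), so reloading them means rebuilding the router-augmented DiT first and then restoring the `routers` entry. A minimal sketch, assuming the released ckpt/DDIM20_router.pt follows the same layout as the checkpoints saved above and that `models.router_models` exposes the same `DiT_models` registry used during training:

    import torch
    from diffusion import create_diffusion
    from download import find_model
    from models.router_models import DiT_models

    model = DiT_models["DiT-XL/2"](input_size=32, num_classes=1000)     # 256x256 images -> latent size 32
    model.load_state_dict(find_model("DiT-XL-2-256x256.pt"), strict=False)
    diffusion = create_diffusion("20")                                  # 20 sampling steps
    model.add_router(20, diffusion.timestep_map)
    ckpt = torch.load("ckpt/DDIM20_router.pt", map_location="cpu")
    model.routers.load_state_dict(ckpt["routers"])
    model.eval()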