├── U-ViT ├── assets │ ├── fid_stats │ └── stable-diffusion ├── libs │ ├── __init__.py │ ├── clip.py │ ├── timm.py │ ├── uvit_t2i.py │ ├── uvit.py │ ├── uvit_dynamic.py │ └── uvit_router.py ├── u-vit.gif ├── ckpt │ ├── dpm20_router.pth │ └── dpm50_router.pth ├── .gitignore ├── fid.py ├── configs │ ├── imagenet256_uvit_huge.py │ ├── imagenet256_uvit_huge_dynamic_cache.py │ └── imagenet256_uvit_huge_router.py ├── readme.md ├── tools │ ├── read_npz.py │ ├── fid_score.py │ └── inception.py ├── eval.py ├── sample_ldm_discrete.py ├── eval_ldm_discrete.py ├── utils.py ├── sde.py └── train_router_discrete.py ├── DiT ├── requirement.txt ├── assets │ └── dit.gif ├── .gitignore ├── ckpt │ ├── DDIM20_router.pt │ └── DDIM50_router.pt ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── respace.py │ └── timestep_sampler.py ├── download.py ├── README.md ├── sample.py ├── sample_ddp.py ├── models │ └── models.py └── train_router.py ├── assets ├── teaser.png ├── dit_baseline.png └── uvit_baseline.png └── README.md /U-ViT/assets/fid_stats: -------------------------------------------------------------------------------- 1 | ../../../U-ViT/assets/fid_stats -------------------------------------------------------------------------------- /U-ViT/libs/__init__.py: -------------------------------------------------------------------------------- 1 | # codes from third party 2 | -------------------------------------------------------------------------------- /U-ViT/assets/stable-diffusion: -------------------------------------------------------------------------------- 1 | ../../../U-ViT/assets/stable-diffusion -------------------------------------------------------------------------------- /DiT/requirement.txt: -------------------------------------------------------------------------------- 1 | pytorch 2 | torchvision 3 | timm 4 | diffusers 5 | accelerate -------------------------------------------------------------------------------- /U-ViT/u-vit.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/U-ViT/u-vit.gif -------------------------------------------------------------------------------- /DiT/assets/dit.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/DiT/assets/dit.gif -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /DiT/.gitignore: -------------------------------------------------------------------------------- 1 | pretrained_models/ 2 | *.png 3 | __pycache__ 4 | *.pb 5 | samples/ 6 | results/ 7 | wandb/ -------------------------------------------------------------------------------- /assets/dit_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/assets/dit_baseline.png -------------------------------------------------------------------------------- /DiT/ckpt/DDIM20_router.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/DiT/ckpt/DDIM20_router.pt 
-------------------------------------------------------------------------------- /DiT/ckpt/DDIM50_router.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/DiT/ckpt/DDIM50_router.pt -------------------------------------------------------------------------------- /assets/uvit_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/assets/uvit_baseline.png -------------------------------------------------------------------------------- /U-ViT/ckpt/dpm20_router.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/U-ViT/ckpt/dpm20_router.pth -------------------------------------------------------------------------------- /U-ViT/ckpt/dpm50_router.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horseee/learning-to-cache/HEAD/U-ViT/ckpt/dpm50_router.pth -------------------------------------------------------------------------------- /U-ViT/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | assets/fid_stats/*.npz 3 | assets/stable-diffusion/*.pth 4 | imagenet256_uvit_huge.pth 5 | samples 6 | *.png 7 | workdir -------------------------------------------------------------------------------- /U-ViT/fid.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import sys 3 | 4 | if __name__ == '__main__': 5 | sample_npz_path = sys.argv[1] 6 | res = sys.argv[2] 7 | 8 | if res == '256': 9 | ref_path = 'assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz' 10 | elif res == '512': 11 | ref_path = 'assets/fid_stats/fid_stats_imagenet512_guided_diffusion.npz' 12 | else: 13 | raise NotImplementedError 14 | fid_value = calculate_fid_given_paths([ref_path, sample_npz_path], batch_size=1000) 15 | print(fid_value) 16 | 17 | -------------------------------------------------------------------------------- /U-ViT/libs/clip.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import CLIPTokenizer, CLIPTextModel 3 | 4 | 5 | class AbstractEncoder(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def encode(self, *args, **kwargs): 10 | raise NotImplementedError 11 | 12 | 13 | class FrozenCLIPEmbedder(AbstractEncoder): 14 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 16 | super().__init__() 17 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 18 | self.transformer = CLIPTextModel.from_pretrained(version) 19 | self.device = device 20 | self.max_length = max_length 21 | self.freeze() 22 | 23 | def freeze(self): 24 | self.transformer = self.transformer.eval() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def forward(self, text): 29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 31 | tokens = batch_encoding["input_ids"].to(self.device) 32 | outputs = self.transformer(input_ids=tokens) 33 | 34 | z = 
outputs.last_hidden_state 35 | return z 36 | 37 | def encode(self, text): 38 | return self(text) 39 | -------------------------------------------------------------------------------- /DiT/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet256_uvit_huge.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=32, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=32, 44 | patch_size=2, 45 | in_chans=4, 46 | embed_dim=1152, 47 | depth=28, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True, 54 | conv=False 55 | ) 56 | 57 | config.dataset = d( 58 | name='imagenet256_features', 59 | path='assets/datasets/imagenet256_features', 60 | cfg=True, 61 | p_uncond=0.1 62 | ) 63 | 64 | config.sample = d( 65 | n_samples=50000, 66 | mini_batch_size=50, # the decoder is large 67 | algorithm='dpm_solver', 68 | cfg=True, 69 | scale=0.4, 70 | path='' 71 | ) 72 | 73 | return config 74 | 
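# Usage note: this config is passed to the sampling/evaluation scripts via the
# --config flag (see U-ViT/readme.md), e.g.
#   python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge.py --nnet_path imagenet256_uvit_huge.pth --nfe 50
# Because the scripts load it through ml_collections config_flags, individual fields
# can also be overridden on the command line, e.g. --config.sample.path=samples/my_run
# (the sample path above is only an illustrative value).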
-------------------------------------------------------------------------------- /U-ViT/configs/imagenet256_uvit_huge_dynamic_cache.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit_dynamic', 43 | img_size=32, 44 | patch_size=2, 45 | in_chans=4, 46 | embed_dim=1152, 47 | depth=28, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True, 54 | conv=False 55 | ) 56 | 57 | config.dataset = d( 58 | name='imagenet256_features', 59 | path='assets/datasets/imagenet256_features', 60 | cfg=True, 61 | p_uncond=0.1 62 | ) 63 | 64 | config.sample = d( 65 | n_samples=50000, 66 | mini_batch_size=50, # the decoder is large 67 | algorithm='dpm_solver', 68 | cfg=True, 69 | scale=0.4, 70 | path='', 71 | dynamic=True 72 | ) 73 | 74 | return config 75 | -------------------------------------------------------------------------------- /U-ViT/configs/imagenet256_uvit_huge_router.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.nnet_path='imagenet256_uvit_huge.pth' 17 | config.autoencoder = d( 18 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 19 | ) 20 | 21 | config.train = d( 22 | n_steps=40000, 23 | batch_size=64, 24 | mode='cond', 25 | log_interval=100, 26 | eval_interval=5000, 27 | save_interval=1000, 28 | ) 29 | 30 | config.optimizer = d( 31 | name='adamw', 32 | lr=0.0002, 33 | weight_decay=0.03, 34 | betas=(0.99, 0.99), 35 | ) 36 | 37 | config.lr_scheduler = d( 38 | name='customized', 39 | warmup_steps=5000 40 | ) 41 | 42 | config.nnet = d( 43 | name='uvit_router', 44 | img_size=32, 45 | patch_size=2, 46 | in_chans=4, 47 | embed_dim=1152, 48 | depth=28, 49 | num_heads=16, 50 | mlp_ratio=4, 51 | qkv_bias=False, 52 | mlp_time_embed=False, 53 | num_classes=1001, 54 | use_checkpoint=True, 55 | conv=False 56 | ) 57 | 58 | config.dataset = d( 59 | name='imagenet', 60 | path='PATH_TO_IMAGENET', 61 | resolution=256, 62 | cfg=True, 63 | p_uncond=0.1 64 | ) 65 | 66 | config.sample = d( 67 | n_samples=50000, 68 | mini_batch_size=50, # the decoder is large 69 | algorithm='dpm_solver', 70 | cfg=True, 71 | scale=0.4, 72 | path='' 73 | ) 74 | 75 | return config 76 | -------------------------------------------------------------------------------- /DiT/download.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Functions for downloading pre-trained DiT models 9 | """ 10 | from torchvision.datasets.utils import download_url 11 | import torch 12 | import os 13 | 14 | 15 | pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'} 16 | 17 | 18 | def find_model(model_name): 19 | """ 20 | Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path. 21 | """ 22 | if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints 23 | return download_model(model_name) 24 | else: # Load a custom DiT checkpoint: 25 | assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}' 26 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) 27 | if "ema" in checkpoint: # supports checkpoints from train.py 28 | checkpoint = checkpoint["ema"] 29 | return checkpoint 30 | 31 | 32 | def download_model(model_name): 33 | """ 34 | Downloads a pre-trained DiT model from the web. 35 | """ 36 | assert model_name in pretrained_models 37 | local_path = f'pretrained_models/{model_name}' 38 | if not os.path.isfile(local_path): 39 | os.makedirs('pretrained_models', exist_ok=True) 40 | web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}' 41 | download_url(web_path, 'pretrained_models') 42 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 43 | return model 44 | 45 | 46 | if __name__ == "__main__": 47 | # Download all DiT checkpoints 48 | for model in pretrained_models: 49 | download_model(model) 50 | print('Done.') 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching 2 |
3 | 4 |
5 | 6 | (Results on DiT-XL/2 and U-ViT-H/2) 7 | 8 |
9 |
10 | 11 | > **Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching** 🥯[[Arxiv]](https://arxiv.org/abs/2406.01733) 12 | > [Xinyin Ma](https://horseee.github.io/), [Gongfan Fang](https://fangggf.github.io/), [Michael Bi Mi](), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) 13 | > [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore, Huawei Technologies Ltd 14 | 15 | 16 | 17 | 18 | ## Introduction 19 | We introduce a novel scheme, named **L**earning-to-**C**ache (L2C), that learns to conduct caching in a dynamic manner for diffusion transformers. A router is optimized to decide the layers to be cached. 20 | 21 |
22 | 23 |
24 | 25 | (Changes in the router for U-ViT when optimizing across different layers (x-axis) over all steps (y-axis). The white indicates the layer is activated, while the black indicates it is disabled.) 26 | 27 |
28 | 29 | 30 | **Some takeaways**: 31 | 32 | 1. A large proportion of layers in the diffusion transformer can be removed, without updating the model parameters. 33 | - In U-ViT-H/2, up to 93.68% of the layers in the cache steps (46.84% for all steps) can be removed, with less than 0.01 drop in FID. 34 | 35 | 2. L2C largely outperforms samplers such as DDIM and DPM-Solver. 36 | 37 |
38 | 39 | 40 |
41 | 42 | (Comparison with Baselines. Left: DiT-XL/2. Right: U-ViT-H/2) 43 | 44 |
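To make the idea concrete, the snippet below is a minimal, illustrative sketch of router-gated layer caching. It is not the implementation in this repo: the class names, the toy MLP block, and the hard thresholding of a per-(step, layer) router are simplifying assumptions. When the router disables a layer at a step, the block reuses the output it produced at the previous denoising step instead of recomputing it:

```python
import torch
import torch.nn as nn

class CachedBlock(nn.Module):
    """Toy stand-in for a transformer block; only the caching logic matters here."""
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.cache = None  # output remembered from the previous denoising step

    def forward(self, x, active):
        if active or self.cache is None:
            self.cache = x + self.mlp(x)  # recompute and refresh the cache
        return self.cache                 # otherwise reuse the cached output

class ToyCachedModel(nn.Module):
    def __init__(self, dim=64, depth=4, num_steps=20, thres=0.1):
        super().__init__()
        self.blocks = nn.ModuleList(CachedBlock(dim) for _ in range(depth))
        # One learnable router entry per (step, layer); at inference it is
        # binarized with a threshold to decide compute vs. reuse.
        self.router = nn.Parameter(torch.rand(num_steps, depth))
        self.thres = thres

    def forward(self, x, step):
        mask = self.router[step] > self.thres  # True = compute, False = reuse cache
        for block, active in zip(self.blocks, mask.tolist()):
            x = block(x, active)
        return x

model = ToyCachedModel()
latent = torch.randn(2, 16, 64)   # (batch, tokens, dim)
for step in range(20):            # denoising loop (schematic)
    eps = model(latent, step)     # layers disabled by the router reuse their cache
    latent = latent - 0.01 * eps  # placeholder update; a real sampler goes here
```

In the released code, the trained routers under `DiT/ckpt/` and `U-ViT/ckpt/` play the role of `router` here, and `--thres` sets the threshold applied to the router when it is loaded for inference (see `load_ranking` in `DiT/sample.py`).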
45 | 46 | ## Checkpoint for Routers 47 | | Model | NFE | Checkpoint | 48 | | -- | -- | -- | 49 | | DiT-XL/2 | 50 | [link](DiT/ckpt/DDIM50_router.pt) | 50 | | DiT-XL/2 | 20 | [link](DiT/ckpt/DDIM20_router.pt) | 51 | | U-ViT-H/2 | 50 | [link](U-ViT/ckpt/dpm50_router.pth) | 52 | | U-ViT-H/2 | 20 | [link](U-ViT/ckpt/dpm20_router.pth) | 53 | 54 | ## Code 55 | We implement Learning-to-Cache on two basic structures: DiT and U-ViT. Check the instructions below: 56 | 57 | 1. DiT: [README](https://github.com/horseee/learning-to-cache/tree/main/DiT#learning-to-cache-for-dit) 58 | 2. U-ViT: [README](https://github.com/horseee/learning-to-cache/blob/main/U-ViT/readme.md) 59 | 60 | ## Citation 61 | ``` 62 | @misc{ma2024learningtocache, 63 | title={Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching}, 64 | author={Xinyin Ma and Gongfan Fang and Michael Bi Mi and Xinchao Wang}, 65 | year={2024}, 66 | eprint={2406.01733}, 67 | archivePrefix={arXiv}, 68 | primaryClass={cs.LG} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /U-ViT/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## Preparation 3 | 4 | Please follow [U-ViT](https://github.com/baofff/U-ViT) to: 5 | 1. Prepare the environment and install the necessary packages 6 | 2. Download the autoencoder and the reference statistics for FID into `assets/` 7 | 3. Download the model [ImageNet 256x256 (U-ViT-H/2)](https://drive.google.com/file/d/13StUdrjaaSXjfqqF7M47BzPyhMAArQ4u/view?usp=share_link) and put it in this directory. 8 | 9 | After completing the above steps, the directory should contain the following files: 10 | ``` 11 | - imagenet256_uvit_huge.pth 12 | - assets 13 | | - fid_stats 14 | | - fid_stats_imagenet256_guided_diffusion.npz 15 | | - ... 16 | | - stable-diffusion 17 | | - autoencoder_kl_ema.pth 18 | | - autoencoder_kl.pth 19 | ``` 20 | 21 | ## Sample Images 22 | For 20 NFEs in DPM-Solver: 23 | ```bash 24 | python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path imagenet256_uvit_huge.pth --nfe 20 --router ckpt/dpm20_router.pth --thres 0.9 25 | ``` 26 | 27 | For 50 NFEs in DPM-Solver: 28 | ```bash 29 | python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path imagenet256_uvit_huge.pth --nfe 50 --router ckpt/dpm50_router.pth --thres 0.9 30 | ``` 31 | 32 | The script repeats the generation 5 times to smooth out fluctuations in the measured inference time. To generate images without acceleration, use the following command: 33 | 34 | ```bash 35 | python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge.py --nnet_path imagenet256_uvit_huge.pth --nfe 50 36 | ``` 37 | 38 | ## Sample 50k Images for Evaluation 39 | 40 | ```bash 41 | export NFE=50 42 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm_discrete.py --config=configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path=imagenet256_uvit_huge.pth --config.sample.path=samples/dpm${NFE}_router --nfe=$NFE --router ckpt/dpm${NFE}_router.pth --thres 0.9 43 | ``` 44 | 45 | The FID is evaluated automatically once all images have been sampled. Set the `NFE` environment variable to the desired number of sampling steps; the command above then uses the matching router checkpoint `ckpt/dpm${NFE}_router.pth` and output directory.
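If you want to recompute FID yourself from a saved sample archive, `fid.py` in this directory takes the path to the sample `.npz` and the resolution as positional arguments and compares against the reference statistics in `assets/fid_stats` (the sample path below is a placeholder):

```bash
python fid.py PATH_TO_SAMPLES_NPZ 256
```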
46 | 47 | Results: 48 | 49 | | NFE | Router | FID | 50 | | -- | -- | -- | 51 | | 50 | - | 2.3728 | 52 | | 50 | ckpt/dpm50_router.pth | 2.3625 | 53 | | 20 | - | 2.5739 | 54 | | 20 | ckpt/dpm20_router.pth | 2.5809| 55 | 56 | 57 | ## Train the router 58 | Execute the following command to train the router: 59 | ``` 60 | accelerate launch --multi_gpu --main_process_port 18100 --num_processes 8 --mixed_precision fp16 train_router_discrete.py --config=configs/imagenet256_uvit_huge_router.py --config.dataset.path=PATH_TO_IMAGENET --nnet_path=imagenet256_uvit_huge.pth --nfe=20 --router_lr=0.001 --l1_weight=0.1 --workdir=workdir/uvit_router_l1_0.1 61 | ``` 62 | Change `PATH_TO_IMAGENET` to your path to the imagenet dataset. 63 | 64 |
65 | 66 |
67 | 68 | (Changes in the router during training) 69 | 70 |
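To sample with a router trained by the command above, pass the saved checkpoint to `--router` in the sampling command from the earlier section (the checkpoint should be written under the `--workdir` directory; `PATH_TO_TRAINED_ROUTER` below is a placeholder for its actual path):

```bash
python sample_ldm_discrete.py --config configs/imagenet256_uvit_huge_dynamic_cache.py --nnet_path imagenet256_uvit_huge.pth --nfe 20 --router PATH_TO_TRAINED_ROUTER --thres 0.9
```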
71 | 72 | 73 | ## Acknowledgement 74 | This implementation is based on [U-ViT](https://github.com/baofff/U-ViT) -------------------------------------------------------------------------------- /DiT/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 
71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /DiT/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Learning-to-Cache for DiT 3 | 4 | ## Requirement 5 | With PyTorch (>2.0) installed, run the following command to install the necessary packages: 6 | ``` 7 | pip install accelerate diffusers timm torchvision wandb 8 | ``` 9 | 10 | ## Sample Image 11 | For DDIM-20: 12 | ``` 13 | python sample.py --model DiT-XL/2 --num-sampling-steps 20 --ddim-sample --accelerate-method dynamiclayer --path ckpt/DDIM20_router.pt --thres 0.1 14 | ``` 15 | 16 | For DDIM-50: 17 | ``` 18 | python sample.py --model DiT-XL/2 --num-sampling-steps 50 --ddim-sample --accelerate-method dynamiclayer --path ckpt/DDIM50_router.pt --thres 0.1 19 | ``` 20 | The script repeats the generation 5 times to smooth out fluctuations in the measured inference time. To generate images without acceleration, use the following command: 21 | ``` 22 | python sample.py --model DiT-XL/2 --num-sampling-steps 20 --ddim-sample 23 | ``` 24 | 25 | ## Sample 50k Images for Evaluation 26 | If you want to reproduce the FID results from the paper, you can use the following command to sample 50k images: 27 | ``` 28 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 12345 sample_ddp.py --model DiT-XL/2 --num-sampling-steps NUM_STEPS --ddim-sample --accelerate-method dynamiclayer --path PATH_TO_TRAINED_ROUTER --thres 0.1 29 | ``` 30 | Be sure to replace NUM_STEPS and PATH_TO_TRAINED_ROUTER with the desired number of NFE steps and the location of the corresponding router checkpoint. 31 | 32 | ## Calculate FID 33 | We follow DiT and evaluate FID using [the code](https://github.com/openai/guided-diffusion/tree/main/evaluations) from guided-diffusion. Please install the required packages, download the pre-computed sample batches, and then run the following command: 34 | ``` 35 | python evaluator.py ~/ckpt/VIRTUAL_imagenet256_labeled.npz PATH_TO_NPZ 36 | ``` 37 | 38 | Results: 39 | 40 | | NFE | Router | IS | sFID | FID | Precision | Recall | Latency | 41 | | -- | -- | -- | -- | -- | -- | -- | -- | 42 | | 50 | - | 238.64 | 2.264 | 4.290 | 80.16 | 59.89 | 7.245±0.029 | 43 | | 50 | ckpt/DDIM50_router.pt | 244.14 | 2.269 | 4.226 | 80.91 | 58.80 | 5.568±0.017 | 44 | | 20 | - | 223.49 | 3.484 | 4.892 | 78.76 | 57.07 | 2.869±0.008 | 45 | | 20 | ckpt/DDIM20_router.pt | 227.04 | 3.455 | 4.644 | 79.16 | 55.58 | 2.261±0.005 | 46 | 47 | 48 | ## Training 49 | Here is the command for training the router. Replace PATH_TO_IMAGENET_TRAIN with the path to your ImageNet training set.
50 | ``` 51 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 12345 train_router.py --model DiT-XL/2 --data-path PATH_TO_IMAGENET_TRAIN --global-batch-size 64 --image-size 256 --ckpt-every 1000 --l1 5e-6 --lr 0.001 --wandb 52 | ``` 53 | The checkpoint for the router would be saved in `results/XXX-DiT-XL-2/checkpoints`. You can also observe the changes in the router during the learning process on wandb. 54 | 55 |
56 | 57 |
58 | 59 | (Changes in the router during training) 60 | 61 |
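To use a trained router outside of `sample.py`, the relevant calls are the same ones `sample.py` makes: build the diffusion object, instantiate the dynamic model variant, and attach the router with `load_ranking` before loading the DiT weights. A condensed sketch (the checkpoint path is a placeholder; the threshold value is only an example):

```
import torch
from diffusion import create_diffusion
from download import find_model
from models.dynamic_models import DiT_models  # dynamic (layer-caching) model definitions

num_sampling_steps, thres = 20, 0.1
diffusion = create_diffusion(str(num_sampling_steps))

# DiT-XL/2 at 256x256: latent size = 256 // 8 = 32
model = DiT_models["DiT-XL/2"](input_size=32, num_classes=1000).to("cuda")
# Attach the learned router: (router ckpt, NFE, timestep map, threshold)
model.load_ranking("results/XXX-DiT-XL-2/checkpoints/XXX.pt", num_sampling_steps, diffusion.timestep_map, thres)
model.load_state_dict(find_model("DiT-XL-2-256x256.pt"))
model.eval()
```

From here, sampling proceeds exactly as in `sample.py` (e.g. `diffusion.ddim_sample_loop(model.forward_with_cfg, ...)`), and `model.reset()` clears the cache between runs.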
62 | 63 | * Hyperparameters for training the routers: 64 | 65 | | Model | DiT-XL/2 | DiT-XL/2 | DiT-XL/2 | DiT-XL/2 | DiT-L/2 | DiT-L/2 | 66 | | -- | -- | -- | -- | -- | -- | -- | 67 | | NFE | 50 | 20 | 10 | 50 | 50 | 20 | 68 | | Resolution | 256 | 256 | 256 | 512 | 256 | 256 | 69 | | - For Train | | | | | | | 70 | | \lambda (--l1) | 1e-6 | 5e-6 | 1e-6 | 5e-6 | 1e-6 | 5e-6 | 71 | | learning rate (--lr) | 1e-3 | 1e-3 | 1e-3 | 1e-3 | 1e-3 | 1e-2 | 72 | | - For Inference | | | | | | | 73 | | \theta (--thres) | 0.1 | 0.1 | 0.1 | 0.9 | 0.1 | 0.1 | 74 | 75 | 76 | 77 | 78 | ## Acknowledgement 79 | This implementation is based on [DiT](https://github.com/facebookresearch/DiT). 80 | 81 | -------------------------------------------------------------------------------- /U-ViT/tools/read_npz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from contextlib import contextmanager 4 | import zipfile 5 | from abc import ABC, abstractmethod 6 | from typing import Iterable, Optional, Tuple 7 | import io 8 | import matplotlib.pyplot as plt 9 | 10 | class NpzArrayReader(ABC): 11 | @abstractmethod 12 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 13 | pass 14 | 15 | @abstractmethod 16 | def remaining(self) -> int: 17 | pass 18 | 19 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: 20 | def gen_fn(): 21 | while True: 22 | batch = self.read_batch(batch_size) 23 | if batch is None: 24 | break 25 | yield batch 26 | 27 | rem = self.remaining() 28 | num_batches = rem // batch_size + int(rem % batch_size != 0) 29 | return BatchIterator(gen_fn, num_batches) 30 | 31 | class StreamingNpzArrayReader(NpzArrayReader): 32 | def __init__(self, arr_f, shape, dtype): 33 | self.arr_f = arr_f 34 | self.shape = shape 35 | self.dtype = dtype 36 | self.idx = 0 37 | 38 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 39 | if self.idx >= self.shape[0]: 40 | return None 41 | 42 | bs = min(batch_size, self.shape[0] - self.idx) 43 | self.idx += bs 44 | 45 | if self.dtype.itemsize == 0: 46 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 47 | 48 | read_count = bs * np.prod(self.shape[1:]) 49 | read_size = int(read_count * self.dtype.itemsize) 50 | data = _read_bytes(self.arr_f, read_size, "array data") 51 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 52 | 53 | def remaining(self) -> int: 54 | return max(0, self.shape[0] - self.idx) 55 | 56 | class BatchIterator: 57 | def __init__(self, gen_fn, length): 58 | self.gen_fn = gen_fn 59 | self.length = length 60 | 61 | def __len__(self): 62 | return self.length 63 | 64 | def __iter__(self): 65 | return self.gen_fn() 66 | 67 | @contextmanager 68 | def open_npz_array(path: str, arr_name: str): 69 | with _open_npy_file(path, arr_name) as arr_f: 70 | version = np.lib.format.read_magic(arr_f) 71 | if version == (1, 0): 72 | header = np.lib.format.read_array_header_1_0(arr_f) 73 | elif version == (2, 0): 74 | header = np.lib.format.read_array_header_2_0(arr_f) 75 | else: 76 | yield MemoryNpzArrayReader.load(path, arr_name) 77 | return 78 | print(header) 79 | shape, fortran, dtype = header 80 | if fortran or dtype.hasobject: 81 | yield MemoryNpzArrayReader.load(path, arr_name) 82 | else: 83 | yield StreamingNpzArrayReader(arr_f, shape, dtype) 84 | 85 | @contextmanager 86 | def _open_npy_file(path: str, arr_name: str): 87 | with open(path, "rb") as f: 88 | with zipfile.ZipFile(f, "r") as zip_f: 89 | if f"{arr_name}.npy" not in 
zip_f.namelist(): 90 | raise ValueError(f"missing {arr_name} in npz file") 91 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 92 | yield arr_f 93 | 94 | def _read_bytes(fp, size, error_template="ran out of data"): 95 | """ 96 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 97 | 98 | Read from file-like object until size bytes are read. 99 | Raises ValueError if not EOF is encountered before size bytes are read. 100 | Non-blocking objects only supported if they derive from io objects. 101 | Required as e.g. ZipExtFile in python 2.6 can return less data than 102 | requested. 103 | """ 104 | data = bytes() 105 | while True: 106 | # io files (default in python3) return None or raise on 107 | # would-block, python2 file will truncate, probably nothing can be 108 | # done about that. note that regular files can't be non-blocking 109 | try: 110 | r = fp.read(size - len(data)) 111 | data += r 112 | if len(r) == 0 or len(data) == size: 113 | break 114 | except io.BlockingIOError: 115 | pass 116 | if len(data) != size: 117 | msg = "EOF: reading %s, expected %d bytes got %d" 118 | raise ValueError(msg % (error_template, size, len(data))) 119 | else: 120 | return data -------------------------------------------------------------------------------- /U-ViT/libs/timm.py: -------------------------------------------------------------------------------- 1 | # code from timm 0.3.2 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import warnings 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 
52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def drop_path(x, drop_prob: float = 0., training: bool = False): 66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 67 | 68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 72 | 'survival rate' as the argument. 73 | 74 | """ 75 | if drop_prob == 0. or not training: 76 | return x 77 | keep_prob = 1 - drop_prob 78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 80 | random_tensor.floor_() # binarize 81 | output = x.div(keep_prob) * random_tensor 82 | return output 83 | 84 | 85 | class DropPath(nn.Module): 86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 87 | """ 88 | def __init__(self, drop_prob=None): 89 | super(DropPath, self).__init__() 90 | self.drop_prob = drop_prob 91 | 92 | def forward(self, x): 93 | return drop_path(x, self.drop_prob, self.training) 94 | 95 | 96 | class Mlp(nn.Module): 97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.fc2 = nn.Linear(hidden_features, out_features) 104 | self.drop = nn.Dropout(drop) 105 | 106 | def forward(self, x): 107 | x = self.fc1(x) 108 | x = self.act(x) 109 | x = self.drop(x) 110 | x = self.fc2(x) 111 | x = self.drop(x) 112 | return x 113 | -------------------------------------------------------------------------------- /U-ViT/eval.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | import sde 8 | from datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | 14 | 15 | def evaluate(config): 16 | if config.get('benchmark', False): 17 | torch.backends.cudnn.benchmark = True 18 | torch.backends.cudnn.deterministic = False 19 | 20 | mp.set_start_method('spawn') 21 | accelerator = accelerate.Accelerator() 22 | device = accelerator.device 23 | accelerate.utils.set_seed(config.seed, device_specific=True) 24 | logging.info(f'Process {accelerator.process_index} using device: {device}') 25 | 26 | config.mixed_precision = accelerator.mixed_precision 27 | config = ml_collections.FrozenConfigDict(config) 28 | if accelerator.is_main_process: 29 | 
utils.set_logger(log_level='info', fname=config.output_path) 30 | else: 31 | utils.set_logger(log_level='error') 32 | builtins.print = lambda *args: None 33 | 34 | dataset = get_dataset(**config.dataset) 35 | 36 | nnet = utils.get_nnet(**config.nnet) 37 | nnet = accelerator.prepare(nnet) 38 | logging.info(f'load nnet from {config.nnet_path}') 39 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 40 | nnet.eval() 41 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 42 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 43 | def cfg_nnet(x, timesteps, y): 44 | _cond = nnet(x, timesteps, y=y) 45 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 46 | return _cond + config.sample.scale * (_cond - _uncond) 47 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 48 | else: 49 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 50 | 51 | 52 | logging.info(config.sample) 53 | assert os.path.exists(dataset.fid_stat) 54 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 55 | 56 | def sample_fn(_n_samples): 57 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device) 58 | if config.train.mode == 'uncond': 59 | kwargs = dict() 60 | elif config.train.mode == 'cond': 61 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 62 | else: 63 | raise NotImplementedError 64 | 65 | if config.sample.algorithm == 'euler_maruyama_sde': 66 | rsde = sde.ReverseSDE(score_model) 67 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 68 | elif config.sample.algorithm == 'euler_maruyama_ode': 69 | rsde = sde.ODE(score_model) 70 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 71 | elif config.sample.algorithm == 'dpm_solver': 72 | noise_schedule = NoiseScheduleVP(schedule='linear') 73 | model_fn = model_wrapper( 74 | score_model.noise_pred, 75 | noise_schedule, 76 | time_input_type='0', 77 | model_kwargs=kwargs 78 | ) 79 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 80 | return dpm_solver.sample( 81 | x_init, 82 | steps=config.sample.sample_steps, 83 | eps=1e-4, 84 | adaptive_step_size=False, 85 | fast_version=True, 86 | ) 87 | else: 88 | raise NotImplementedError 89 | 90 | with tempfile.TemporaryDirectory() as temp_path: 91 | path = config.sample.path or temp_path 92 | if accelerator.is_main_process: 93 | os.makedirs(path, exist_ok=True) 94 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 95 | if accelerator.is_main_process: 96 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 97 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 98 | 99 | from absl import flags 100 | from absl import app 101 | from ml_collections import config_flags 102 | import os 103 | 104 | 105 | FLAGS = flags.FLAGS 106 | config_flags.DEFINE_config_file( 107 | "config", None, "Training configuration.", lock_config=False) 108 | flags.mark_flags_as_required(["config"]) 109 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 110 | flags.DEFINE_string("output_path", None, "The path to output log.") 111 | 112 | 113 | def main(argv): 114 | config = FLAGS.config 115 | config.nnet_path = 
FLAGS.nnet_path 116 | config.output_path = FLAGS.output_path 117 | evaluate(config) 118 | 119 | 120 | if __name__ == "__main__": 121 | app.run(main) 122 | -------------------------------------------------------------------------------- /DiT/sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Sample new images from a pre-trained DiT. 9 | """ 10 | import torch 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cudnn.allow_tf32 = True 13 | from torchvision.utils import save_image 14 | from diffusion import create_diffusion 15 | from diffusers.models import AutoencoderKL 16 | from download import find_model 17 | 18 | import argparse 19 | import numpy as np 20 | 21 | 22 | def main(args): 23 | # Setup PyTorch: 24 | torch.set_grad_enabled(False) 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | 27 | if args.ckpt is None: 28 | assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." 29 | assert args.image_size in [256, 512] 30 | assert args.num_classes == 1000 31 | 32 | # initialize diffusin process 33 | diffusion = create_diffusion(str(args.num_sampling_steps)) 34 | 35 | # Load model: 36 | latent_size = args.image_size // 8 37 | if args.accelerate_method is not None and args.accelerate_method == "dynamiclayer": 38 | from models.dynamic_models import DiT_models 39 | else: 40 | from models.models import DiT_models 41 | 42 | model = DiT_models[args.model]( 43 | input_size=latent_size, 44 | num_classes=args.num_classes 45 | ).to(device) 46 | 47 | if args.accelerate_method is not None and 'dynamiclayer' in args.accelerate_method: 48 | model.load_ranking(args.path, args.num_sampling_steps, diffusion.timestep_map, args.thres) 49 | 50 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: 51 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" 52 | state_dict = find_model(ckpt_path) 53 | model.load_state_dict(state_dict) 54 | model.eval() # important! 
55 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 56 | 57 | torch.manual_seed(args.seed) 58 | # Labels to condition the model with (feel free to change): 59 | class_labels = [207, 992, 387, 974, 142, 979, 417, 279] 60 | 61 | # Create sampling noise: 62 | n = len(class_labels) 63 | z = torch.randn(n, 4, latent_size, latent_size, device=device) 64 | y = torch.tensor(class_labels, device=device) 65 | 66 | # Setup classifier-free guidance: 67 | z = torch.cat([z, z], 0) 68 | y_null = torch.tensor([1000] * n, device=device) 69 | y = torch.cat([y, y_null], 0) 70 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) 71 | 72 | # Sample images: 73 | import time 74 | times = [] 75 | for _ in range(6): 76 | start_time = time.time() 77 | if args.p_sample: 78 | samples = diffusion.p_sample_loop( 79 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device 80 | ) 81 | elif args.ddim_sample: 82 | samples = diffusion.ddim_sample_loop( 83 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device 84 | ) 85 | times.append(time.time() - start_time) 86 | model.reset() 87 | 88 | if len(times) > 1: 89 | times = np.array(times[1:]) 90 | print("Sampling time: {:.3f}±{:.3f}".format(np.mean(times), np.std(times))) 91 | 92 | 93 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 94 | samples = vae.decode(samples / 0.18215).sample 95 | save_image(samples, f"Sample_NFE{args.num_sampling_steps}_Method_{args.accelerate_method}.png", nrow=8, normalize=True, value_range=(-1, 1)) 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model", type=str, default="DiT-XL/2") 100 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") 101 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 102 | parser.add_argument("--num-classes", type=int, default=1000) 103 | parser.add_argument("--cfg-scale", type=float, default=4.0) 104 | parser.add_argument("--num-sampling-steps", type=int, default=250) 105 | parser.add_argument("--seed", type=int, default=0) 106 | parser.add_argument("--ckpt", type=str, default=None, 107 | help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") 108 | parser.add_argument("--accelerate-method", type=str, default=None, 109 | help="Use the accelerated version of the model.") 110 | 111 | parser.add_argument("--ddim-sample", action="store_true", default=False,) 112 | parser.add_argument("--p-sample", action="store_true", default=False,) 113 | 114 | parser.add_argument("--path", type=str, default=None, 115 | help="Optional path to a router checkpoint") 116 | parser.add_argument("--thres", type=float, default=0.5) 117 | 118 | args = parser.parse_args() 119 | main(args) 120 | -------------------------------------------------------------------------------- /DiT/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | 
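# Example: space_timesteps(1000, "ddim20") searches for a stride i such that
# len(range(0, 1000, i)) == 20, finds i == 50, and returns {0, 50, 100, ..., 950},
# i.e. 20 evenly strided timesteps for DDIM sampling.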
def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 
71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /DiT/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 
31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 
117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /U-ViT/sample_ldm_discrete.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | from datasets import get_dataset 8 | import tempfile 9 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 10 | from absl import logging 11 | import builtins 12 | import libs.autoencoder 13 | from torchvision.utils import save_image 14 | import numpy as np 15 | 16 | 17 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 18 | _betas = ( 19 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 20 | ) 21 | return _betas.numpy() 22 | 23 | 24 | def evaluate(config): 25 | if config.get('benchmark', False): 26 | torch.backends.cudnn.benchmark = True 27 | torch.backends.cudnn.deterministic = False 28 | 29 | mp.set_start_method('spawn') 30 | accelerator = accelerate.Accelerator() 31 | device = accelerator.device 32 | accelerate.utils.set_seed(0, device_specific=True) 33 | logging.info(f'Process {accelerator.process_index} using device: {device}') 34 | 35 | config.mixed_precision = accelerator.mixed_precision 36 | config = ml_collections.FrozenConfigDict(config) 37 | if accelerator.is_main_process: 38 | utils.set_logger(log_level='info', fname=config.output_path) 39 | else: 40 | utils.set_logger(log_level='error') 41 | builtins.print = lambda *args: None 42 | 43 | dataset = get_dataset(**config.dataset) 44 | 45 | nnet = utils.get_nnet(**config.nnet) 46 | nnet = accelerator.prepare(nnet) 47 | logging.info(f'load nnet from {config.nnet_path}') 48 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 49 | nnet.eval() 50 | 51 | if 'dynamic' in config.sample: 52 | # Get Timestep Mapping 53 | t_0 = 1. 
/ 1000 54 | t_T = 1.0 55 | order_value = 2 56 | N_steps = config.nfe // order_value 57 | timesteps = torch.linspace(t_T, t_0, N_steps + 1).cpu().numpy() 58 | #timesteps = timesteps.numpy() 59 | timestep_mapping = np.round(timesteps * 1000) 60 | #accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping) 61 | 62 | nnet.load_ranking(config.router, config.nfe, timestep_mapping, config.thres) 63 | 64 | 65 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 66 | autoencoder.to(device) 67 | 68 | @torch.cuda.amp.autocast() 69 | def encode(_batch): 70 | return autoencoder.encode(_batch) 71 | 72 | @torch.cuda.amp.autocast() 73 | def decode(_batch): 74 | return autoencoder.decode(_batch) 75 | 76 | def decode_large_batch(_batch): 77 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 78 | xs = [] 79 | pt = 0 80 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 81 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 82 | pt += _decode_mini_batch_size 83 | xs.append(x) 84 | xs = torch.concat(xs, dim=0) 85 | assert xs.size(0) == _batch.size(0) 86 | return xs 87 | 88 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 89 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 90 | def cfg_nnet(x, timesteps, y): 91 | _cond = nnet(x, timesteps, y=y) 92 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 93 | return _cond + config.sample.scale * (_cond - _uncond) 94 | else: 95 | def cfg_nnet(x, timesteps, y): 96 | _cond = nnet(x, timesteps, y=y) 97 | return _cond 98 | 99 | logging.info(config.sample) 100 | assert os.path.exists(dataset.fid_stat) 101 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 102 | 103 | _betas = stable_diffusion_beta_schedule() 104 | N = len(_betas) 105 | 106 | def sample_z(_n_samples, _sample_steps, **kwargs): 107 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 108 | 109 | if config.sample.algorithm == 'dpm_solver': 110 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 111 | 112 | def model_fn(x, t_continuous): 113 | t = t_continuous * N 114 | 115 | eps_pre = cfg_nnet(x, t, **kwargs) 116 | return eps_pre 117 | 118 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 119 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. 
/ N, T=1., order=2, method='singlestep') 120 | 121 | else: 122 | raise NotImplementedError 123 | 124 | return _z 125 | 126 | def sample_fn(_n_samples): 127 | class_labels = [207, 992, 387, 974, 142, 979, 417, 279] 128 | 129 | if config.train.mode == 'uncond': 130 | kwargs = dict() 131 | elif config.train.mode == 'cond': 132 | kwargs = dict(y=torch.tensor(class_labels, device=device)) 133 | else: 134 | raise NotImplementedError 135 | _z = sample_z(_n_samples, _sample_steps=config.nfe, **kwargs) 136 | return decode_large_batch(_z) 137 | 138 | import time 139 | use_time = [] 140 | if config.teaser: 141 | samples = sample_fn(8) 142 | samples = dataset.unpreprocess(samples) 143 | dynamic_flag = 'dynamic' in config.sample 144 | for idx, sample in enumerate(samples): 145 | save_image(sample, f"images/teaser_{dynamic_flag}_{idx}.png", nrow=1) 146 | else: 147 | logging.info("Start sampling") 148 | for _ in range(6): 149 | start_time = time.time() 150 | samples = sample_fn(8) 151 | use_time.append(time.time() - start_time) 152 | samples = dataset.unpreprocess(samples) 153 | nnet.reset() 154 | 155 | times = np.array(use_time[1:]) 156 | logging.info("Sampling time: {:.2f}±{:.2f}".format(np.mean(times), np.std(times))) 157 | save_image(samples, "u-vit-H-2.png", nrow=8) 158 | 159 | 160 | from absl import flags 161 | from absl import app 162 | from ml_collections import config_flags 163 | import os 164 | 165 | 166 | FLAGS = flags.FLAGS 167 | config_flags.DEFINE_config_file( 168 | "config", None, "Training configuration.", lock_config=False) 169 | flags.mark_flags_as_required(["config"]) 170 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 171 | flags.DEFINE_string("output_path", None, "The path to output log.") 172 | flags.DEFINE_string("nfe", None, "number of evaluation") 173 | 174 | flags.DEFINE_string("router", None, "path of router") 175 | flags.DEFINE_string("thres", "0", "threshold of router") 176 | 177 | flags.DEFINE_boolean("teaser", False, "generate teaser image") 178 | 179 | 180 | def main(argv): 181 | config = FLAGS.config 182 | config.nnet_path = FLAGS.nnet_path 183 | config.output_path = FLAGS.output_path 184 | config.teaser = FLAGS.teaser 185 | config.nfe = int(FLAGS.nfe) 186 | config.router = FLAGS.router 187 | config.thres = float(FLAGS.thres) 188 | evaluate(config) 189 | 190 | 191 | if __name__ == "__main__": 192 | app.run(main) 193 | -------------------------------------------------------------------------------- /U-ViT/eval_ldm_discrete.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | import numpy as np 5 | from torch import multiprocessing as mp 6 | import accelerate 7 | import utils 8 | from datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | import libs.autoencoder 14 | 15 | 16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 17 | _betas = ( 18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 19 | ) 20 | return _betas.numpy() 21 | 22 | 23 | def evaluate(config): 24 | if config.get('benchmark', False): 25 | torch.backends.cudnn.benchmark = True 26 | torch.backends.cudnn.deterministic = False 27 | 28 | mp.set_start_method('spawn') 29 | accelerator = accelerate.Accelerator() 30 | device = accelerator.device 31 | 
accelerate.utils.set_seed(config.seed, device_specific=True) 32 | logging.info(f'Process {accelerator.process_index} using device: {device}') 33 | 34 | config.mixed_precision = accelerator.mixed_precision 35 | config = ml_collections.FrozenConfigDict(config) 36 | if accelerator.is_main_process: 37 | utils.set_logger(log_level='info', fname=config.output_path) 38 | else: 39 | utils.set_logger(log_level='error') 40 | builtins.print = lambda *args: None 41 | 42 | dataset = get_dataset(**config.dataset) 43 | 44 | nnet = utils.get_nnet(**config.nnet) 45 | nnet = accelerator.prepare(nnet) 46 | logging.info(f'load nnet from {config.nnet_path}') 47 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 48 | nnet.eval() 49 | 50 | if 'dynamic' in config.sample: 51 | # Get Timestep Mapping 52 | t_0 = 1. / 1000 53 | t_T = 1.0 54 | order_value = 2 55 | N_steps = config.nfe // order_value 56 | timesteps = torch.linspace(t_T, t_0, N_steps + 1).cpu().numpy() 57 | #timesteps = timesteps.numpy() 58 | timestep_mapping = np.round(timesteps * 1000) 59 | #accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping) 60 | 61 | accelerator.unwrap_model(nnet).load_ranking(config.router, config.nfe, timestep_mapping, config.thres) 62 | elif 'rank' in config.sample: 63 | # Get Timestep Mapping 64 | t_0 = 1. / 1000 65 | t_T = 1.0 66 | order_value = 2 67 | N_steps = config.nfe // order_value 68 | timesteps = torch.linspace(t_T, t_0, N_steps + 1).cpu().numpy() 69 | #timesteps = timesteps.numpy() 70 | timestep_mapping = np.round(timesteps * 1000) 71 | #accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping) 72 | 73 | accelerator.unwrap_model(nnet).load_ranking(config.nfe, config.thres) 74 | 75 | elif 'topk' in config.sample or 'random' in config.sample: 76 | accelerator.unwrap_model(nnet).load_ranking(config.sample.topk, config.sample.reverse, config.sample.random) 77 | 78 | 79 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 80 | autoencoder.to(device) 81 | 82 | @torch.cuda.amp.autocast() 83 | def encode(_batch): 84 | return autoencoder.encode(_batch) 85 | 86 | @torch.cuda.amp.autocast() 87 | def decode(_batch): 88 | return autoencoder.decode(_batch) 89 | 90 | def decode_large_batch(_batch): 91 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 92 | xs = [] 93 | pt = 0 94 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 95 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 96 | pt += _decode_mini_batch_size 97 | xs.append(x) 98 | xs = torch.concat(xs, dim=0) 99 | assert xs.size(0) == _batch.size(0) 100 | return xs 101 | 102 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 103 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 104 | def cfg_nnet(x, timesteps, y): 105 | _cond = nnet(x, timesteps, y=y) 106 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 107 | return _cond + config.sample.scale * (_cond - _uncond) 108 | else: 109 | def cfg_nnet(x, timesteps, y): 110 | _cond = nnet(x, timesteps, y=y) 111 | return _cond 112 | 113 | logging.info(config.sample) 114 | assert os.path.exists(dataset.fid_stat) 115 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 116 | 117 | _betas = stable_diffusion_beta_schedule() 118 | N = len(_betas) 119 | 120 | def 
sample_z(_n_samples, _sample_steps, **kwargs): 121 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 122 | 123 | if config.sample.algorithm == 'dpm_solver': 124 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 125 | 126 | def model_fn(x, t_continuous): 127 | t = t_continuous * N 128 | eps_pre = cfg_nnet(x, t, **kwargs) 129 | return eps_pre 130 | 131 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 132 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1., order=2, method='singlestep') 133 | 134 | else: 135 | raise NotImplementedError 136 | 137 | return _z 138 | 139 | def sample_fn(_n_samples): 140 | if config.train.mode == 'uncond': 141 | kwargs = dict() 142 | elif config.train.mode == 'cond': 143 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 144 | else: 145 | raise NotImplementedError 146 | _z = sample_z(_n_samples, _sample_steps=config.nfe, **kwargs) 147 | return decode_large_batch(_z) 148 | 149 | with tempfile.TemporaryDirectory() as temp_path: 150 | path = config.sample.path or temp_path 151 | #if accelerator.is_main_process: 152 | # os.makedirs(path, exist_ok=True) 153 | logging.info(f'Samples are saved in {path}') 154 | #utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 155 | utils.sample2npz(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, reset_fn=accelerator.unwrap_model(nnet).reset) 156 | 157 | if accelerator.is_main_process: 158 | torch.cuda.empty_cache() 159 | fid = calculate_fid_given_paths((dataset.fid_stat, f"{path}.npz"), batch_size=1000) 160 | log_path = path.replace('manual_samples/', 'log/') 161 | with open(f"{log_path}.log", "a") as f: 162 | f.write(f"npz_path={path}.npz, fid={fid}") 163 | logging.info(f'npz_path={path}.npz, fid={fid}') 164 | 165 | 166 | from absl import flags 167 | from absl import app 168 | from ml_collections import config_flags 169 | import os 170 | 171 | 172 | FLAGS = flags.FLAGS 173 | config_flags.DEFINE_config_file( 174 | "config", None, "Training configuration.", lock_config=False) 175 | flags.mark_flags_as_required(["config"]) 176 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 177 | flags.DEFINE_string("output_path", None, "The path to output log.") 178 | flags.DEFINE_string("nfe", None, "NFE") 179 | flags.DEFINE_string("router", None, "path of router") 180 | flags.DEFINE_string("thres", "0", "threshold of router") 181 | 182 | 183 | def main(argv): 184 | config = FLAGS.config 185 | config.nnet_path = FLAGS.nnet_path 186 | config.output_path = FLAGS.output_path 187 | config.nfe = int(FLAGS.nfe) 188 | config.thres = float(FLAGS.thres) 189 | config.router = FLAGS.router 190 | 191 | evaluate(config) 192 | 193 | 194 | if __name__ == "__main__": 195 | app.run(main) 196 | -------------------------------------------------------------------------------- /U-ViT/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | from torchvision.utils import save_image 7 | from absl import logging 8 | 9 | 10 | def set_logger(log_level='info', fname=None): 11 | import logging as _logging 12 | handler = logging.get_absl_handler() 13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s') 14 | 
handler.setFormatter(formatter) 15 | logging.set_verbosity(log_level) 16 | if fname is not None: 17 | handler = _logging.FileHandler(fname) 18 | handler.setFormatter(formatter) 19 | logging.get_absl_logger().addHandler(handler) 20 | 21 | 22 | def dct2str(dct): 23 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 24 | 25 | 26 | def get_nnet(name, **kwargs): 27 | if name == 'uvit': 28 | from libs.uvit import UViT 29 | return UViT(**kwargs) 30 | elif name == 'uvit_t2i': 31 | from libs.uvit_t2i import UViT 32 | return UViT(**kwargs) 33 | elif name == 'uvit_timecache': 34 | from libs.uvit_timecache import UViT 35 | return UViT(**kwargs) 36 | elif name == 'uvit_router': 37 | from libs.uvit_router import UViT 38 | return UViT(**kwargs) 39 | elif name == 'uvit_dynamic': 40 | from libs.uvit_dynamic import UViT 41 | return UViT(**kwargs) 42 | elif name == 'uvit_manual': 43 | from libs.uvit_manual import UViT 44 | return UViT(**kwargs) 45 | elif name == 'uvit_deepcache': 46 | from libs.uvit_deepcache import UViT 47 | return UViT(**kwargs) 48 | elif name == 'uvit_fasterdiffusion': 49 | from libs.uvit_fasterdiffusion import UViT 50 | return UViT(**kwargs) 51 | elif name == 'uvit_analysis': 52 | from libs.uvit_analysis import UViT 53 | return UViT(**kwargs) 54 | elif name == 'uvit_ranklayer': 55 | from libs.uvit_ranklayer import UViT 56 | return UViT(**kwargs) 57 | else: 58 | raise NotImplementedError(name) 59 | 60 | 61 | def set_seed(seed: int): 62 | if seed is not None: 63 | torch.manual_seed(seed) 64 | np.random.seed(seed) 65 | 66 | 67 | def get_optimizer(params, name, **kwargs): 68 | if name == 'adam': 69 | from torch.optim import Adam 70 | return Adam(params, **kwargs) 71 | elif name == 'adamw': 72 | from torch.optim import AdamW 73 | return AdamW(params, **kwargs) 74 | else: 75 | raise NotImplementedError(name) 76 | 77 | 78 | def customized_lr_scheduler(optimizer, warmup_steps=-1): 79 | from torch.optim.lr_scheduler import LambdaLR 80 | def fn(step): 81 | if warmup_steps > 0: 82 | return min(step / warmup_steps, 1) 83 | else: 84 | return 1 85 | return LambdaLR(optimizer, fn) 86 | 87 | 88 | def get_lr_scheduler(optimizer, name, **kwargs): 89 | if name == 'customized': 90 | return customized_lr_scheduler(optimizer, **kwargs) 91 | elif name == 'cosine': 92 | from torch.optim.lr_scheduler import CosineAnnealingLR 93 | return CosineAnnealingLR(optimizer, **kwargs) 94 | else: 95 | raise NotImplementedError(name) 96 | 97 | 98 | def ema(model_dest: nn.Module, model_src: nn.Module, rate): 99 | param_dict_src = dict(model_src.named_parameters()) 100 | for p_name, p_dest in model_dest.named_parameters(): 101 | p_src = param_dict_src[p_name] 102 | assert p_src is not p_dest 103 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) 104 | 105 | 106 | class TrainState(object): 107 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None): 108 | self.optimizer = optimizer 109 | self.lr_scheduler = lr_scheduler 110 | self.step = step 111 | self.nnet = nnet 112 | self.nnet_ema = nnet_ema 113 | 114 | #def ema_update(self, rate=0.9999): 115 | # if self.nnet_ema is not None: 116 | # ema(self.nnet_ema, self.nnet, rate) 117 | 118 | def save(self, path): 119 | os.makedirs(path, exist_ok=True) 120 | torch.save(self.step, os.path.join(path, 'step.pth')) 121 | for key, val in self.__dict__.items(): 122 | if key != 'step' and 'ema' not in key and val is not None: 123 | if key == 'nnet': 124 | torch.save(val.routers.state_dict(), os.path.join(path, f'{key}.pth')) 125 | else: 126 | 
torch.save(val.state_dict(), os.path.join(path, f'{key}.pth')) 127 | 128 | def load(self, path): 129 | logging.info(f'load from {path}') 130 | self.step = torch.load(os.path.join(path, 'step.pth')) 131 | for key, val in self.__dict__.items(): 132 | if key != 'step' and val is not None: 133 | val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')) 134 | 135 | def resume(self, ckpt_root, step=None): 136 | if not os.path.exists(ckpt_root): 137 | return 138 | if step is None: 139 | ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root))) 140 | if not ckpts: 141 | return 142 | steps = map(lambda x: int(x.split(".")[0]), ckpts) 143 | step = max(steps) 144 | ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt') 145 | logging.info(f'resume from {ckpt_path}') 146 | self.load(ckpt_path) 147 | 148 | def to(self, device): 149 | for key, val in self.__dict__.items(): 150 | if isinstance(val, nn.Module): 151 | val.to(device) 152 | 153 | def update_optimizer(self, optimizer): 154 | self.optimizer = optimizer 155 | 156 | 157 | def cnt_params(model): 158 | return sum(param.numel() for param in model.parameters()) 159 | 160 | 161 | def initialize_train_state(config, device): 162 | params = [] 163 | 164 | nnet = get_nnet(**config.nnet) 165 | params += nnet.parameters() 166 | nnet_ema = get_nnet(**config.nnet) 167 | nnet_ema.eval() 168 | logging.info(f'nnet has {cnt_params(nnet)} parameters') 169 | 170 | optimizer = get_optimizer(params, **config.optimizer) 171 | lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler) 172 | 173 | train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0, 174 | nnet=nnet, nnet_ema=nnet_ema) 175 | #train_state.ema_update(0) 176 | train_state.to(device) 177 | return train_state 178 | 179 | 180 | def amortize(n_samples, batch_size): 181 | k = n_samples // batch_size 182 | r = n_samples % batch_size 183 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 184 | 185 | 186 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None): 187 | os.makedirs(path, exist_ok=True) 188 | idx = 0 189 | batch_size = mini_batch_size * accelerator.num_processes 190 | 191 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): 192 | samples = unpreprocess_fn(sample_fn(mini_batch_size)) 193 | samples = accelerator.gather(samples.contiguous())[:_batch_size] 194 | if accelerator.is_main_process: 195 | for sample in samples: 196 | save_image(sample, os.path.join(path, f"{idx}.png")) 197 | idx += 1 198 | accelerator.wait_for_everyone() 199 | 200 | def sample2npz(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, reset_fn=None): 201 | #os.makedirs(path, exist_ok=True) 202 | idx = 0 203 | batch_size = mini_batch_size * accelerator.num_processes 204 | 205 | all_images = [] 206 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): 207 | samples = unpreprocess_fn(sample_fn(mini_batch_size)) 208 | samples = accelerator.gather(samples.contiguous())[:_batch_size] 209 | if accelerator.is_main_process: 210 | samples = samples.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy() 211 | all_images.append(samples) 212 | 213 | reset_fn() 214 | accelerator.wait_for_everyone() 215 | 216 | 217 | if accelerator.is_main_process: 218 | arr = np.concatenate(all_images, axis=0) 219 | arr = arr[: n_samples] 220 | out_path 
= f"{path}.npz" 221 | 222 | print(f"saving to {out_path}") 223 | np.savez(out_path, arr_0=arr) 224 | 225 | 226 | def grad_norm(model): 227 | total_norm = 0. 228 | for p in model.parameters(): 229 | param_norm = p.grad.data.norm(2) 230 | total_norm += param_norm.item() ** 2 231 | total_norm = total_norm ** (1. / 2) 232 | return total_norm 233 | -------------------------------------------------------------------------------- /U-ViT/libs/uvit_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L 
D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplemented 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False, 141 | clip_dim=768, num_clip_token=77, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.in_chans = in_chans 145 | 146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | num_patches = (img_size // patch_size) ** 2 148 | 149 | self.time_embed = nn.Sequential( 150 | nn.Linear(embed_dim, 4 * embed_dim), 151 | nn.SiLU(), 152 | nn.Linear(4 * embed_dim, embed_dim), 153 | ) if mlp_time_embed else nn.Identity() 154 | 155 | self.context_embed = nn.Linear(clip_dim, embed_dim) 156 | 157 | self.extras = 1 + num_clip_token 158 | 159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 160 | 161 | self.in_blocks = nn.ModuleList([ 162 | Block( 163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 165 | for _ in range(depth // 2)]) 166 | 167 | self.mid_block = Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 170 | 171 | self.out_blocks = nn.ModuleList([ 172 | Block( 173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, 
qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 175 | for _ in range(depth // 2)]) 176 | 177 | self.norm = norm_layer(embed_dim) 178 | self.patch_dim = patch_size ** 2 * in_chans 179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 181 | 182 | trunc_normal_(self.pos_embed, std=.02) 183 | self.apply(self._init_weights) 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | trunc_normal_(m.weight, std=.02) 188 | if isinstance(m, nn.Linear) and m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.LayerNorm): 191 | nn.init.constant_(m.bias, 0) 192 | nn.init.constant_(m.weight, 1.0) 193 | 194 | @torch.jit.ignore 195 | def no_weight_decay(self): 196 | return {'pos_embed'} 197 | 198 | def forward(self, x, timesteps, context): 199 | x = self.patch_embed(x) 200 | B, L, D = x.shape 201 | 202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 203 | time_token = time_token.unsqueeze(dim=1) 204 | context_token = self.context_embed(context) 205 | x = torch.cat((time_token, context_token, x), dim=1) 206 | x = x + self.pos_embed 207 | 208 | skips = [] 209 | for blk in self.in_blocks: 210 | x = blk(x) 211 | skips.append(x) 212 | 213 | x = self.mid_block(x) 214 | 215 | for blk in self.out_blocks: 216 | x = blk(x, skips.pop()) 217 | 218 | x = self.norm(x) 219 | x = self.decoder_pred(x) 220 | assert x.size(1) == self.extras + L 221 | x = x[:, self.extras:, :] 222 | x = unpatchify(x, self.in_chans) 223 | x = self.final_layer(x) 224 | return x 225 | -------------------------------------------------------------------------------- /U-ViT/sde.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from absl import logging 4 | import numpy as np 5 | import math 6 | from tqdm import tqdm 7 | 8 | 9 | def get_sde(name, **kwargs): 10 | if name == 'vpsde': 11 | return VPSDE(**kwargs) 12 | elif name == 'vpsde_cosine': 13 | return VPSDECosine(**kwargs) 14 | else: 15 | raise NotImplementedError 16 | 17 | 18 | def stp(s, ts: torch.Tensor): # scalar tensor product 19 | if isinstance(s, np.ndarray): 20 | s = torch.from_numpy(s).type_as(ts) 21 | extra_dims = (1,) * (ts.dim() - 1) 22 | return s.view(-1, *extra_dims) * ts 23 | 24 | 25 | def mos(a, start_dim=1): # mean of square 26 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1) 27 | 28 | 29 | def duplicate(tensor, *size): 30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape) 31 | 32 | 33 | class SDE(object): 34 | r""" 35 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1 36 | f(x, t) is the drift 37 | g(t) is the diffusion 38 | """ 39 | def drift(self, x, t): 40 | raise NotImplementedError 41 | 42 | def diffusion(self, t): 43 | raise NotImplementedError 44 | 45 | def cum_beta(self, t): # the variance of xt|x0 46 | raise NotImplementedError 47 | 48 | def cum_alpha(self, t): 49 | raise NotImplementedError 50 | 51 | def snr(self, t): # signal noise ratio 52 | raise NotImplementedError 53 | 54 | def nsr(self, t): # noise signal ratio 55 | raise NotImplementedError 56 | 57 | def marginal_prob(self, x0, t): # the mean and std of q(xt|x0) 58 | alpha = self.cum_alpha(t) 59 | beta = self.cum_beta(t) 60 | mean = stp(alpha ** 0.5, x0) # E[xt|x0] 61 | std = beta ** 0.5 # Cov[xt|x0] ** 0.5 62 | return mean, std 63 | 64 | def 
sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform 65 | t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init 66 | mean, std = self.marginal_prob(x0, t) 67 | eps = torch.randn_like(x0) 68 | xt = mean + stp(std, eps) 69 | return t, eps, xt 70 | 71 | 72 | class VPSDE(SDE): 73 | def __init__(self, beta_min=0.1, beta_max=20): 74 | # 0 <= t <= 1 75 | self.beta_0 = beta_min 76 | self.beta_1 = beta_max 77 | 78 | def drift(self, x, t): 79 | return -0.5 * stp(self.squared_diffusion(t), x) 80 | 81 | def diffusion(self, t): 82 | return self.squared_diffusion(t) ** 0.5 83 | 84 | def squared_diffusion(self, t): # beta(t) 85 | return self.beta_0 + t * (self.beta_1 - self.beta_0) 86 | 87 | def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau 88 | return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5 89 | 90 | def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I 91 | return 1. - self.skip_alpha(s, t) 92 | 93 | def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs 94 | x = -self.squared_diffusion_integral(s, t) 95 | return x.exp() 96 | 97 | def cum_beta(self, t): 98 | return self.skip_beta(0, t) 99 | 100 | def cum_alpha(self, t): 101 | return self.skip_alpha(0, t) 102 | 103 | def nsr(self, t): 104 | return self.squared_diffusion_integral(0, t).expm1() 105 | 106 | def snr(self, t): 107 | return 1. / self.nsr(t) 108 | 109 | def __str__(self): 110 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}' 111 | 112 | def __repr__(self): 113 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}' 114 | 115 | 116 | class VPSDECosine(SDE): 117 | r""" 118 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1 119 | f(x, t) is the drift 120 | g(t) is the diffusion 121 | """ 122 | def __init__(self, s=0.008): 123 | self.s = s 124 | self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2 125 | self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2 126 | 127 | def drift(self, x, t): 128 | ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2 129 | return stp(ft, x) 130 | 131 | def diffusion(self, t): 132 | return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5 133 | 134 | def cum_beta(self, t): # the variance of xt|x0 135 | return 1 - self.cum_alpha(t) 136 | 137 | def cum_alpha(self, t): 138 | return self.F(t) / self.F0 139 | 140 | def snr(self, t): # signal noise ratio 141 | Ft = self.F(t) 142 | return Ft / (self.F0 - Ft) 143 | 144 | def nsr(self, t): # noise signal ratio 145 | Ft = self.F(t) 146 | return self.F0 / Ft - 1 147 | 148 | def __str__(self): 149 | return 'vpsde_cosine' 150 | 151 | def __repr__(self): 152 | return 'vpsde_cosine' 153 | 154 | 155 | class ScoreModel(object): 156 | r""" 157 | The forward process is q(x_[0,T]) 158 | """ 159 | 160 | def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1): 161 | assert T == 1 162 | self.nnet = nnet 163 | self.pred = pred 164 | self.sde = sde 165 | self.T = T 166 | print(f'ScoreModel with pred={pred}, sde={sde}, T={T}') 167 | 168 | def predict(self, xt, t, **kwargs): 169 | if not isinstance(t, torch.Tensor): 170 | t = torch.tensor(t) 171 | t = t.to(xt.device) 172 | if t.dim() == 0: 173 | t = duplicate(t, xt.size(0)) 174 | return self.nnet(xt, t * 999, **kwargs) # follow SDE 175 | 176 | def noise_pred(self, xt, t, **kwargs): 177 | pred = self.predict(xt, t, **kwargs) 178 | if self.pred == 'noise_pred': 179 | noise_pred = pred 180 | elif self.pred == 'x0_pred': 181 | 
noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt) 182 | else: 183 | raise NotImplementedError 184 | return noise_pred 185 | 186 | def x0_pred(self, xt, t, **kwargs): 187 | pred = self.predict(xt, t, **kwargs) 188 | if self.pred == 'noise_pred': 189 | x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred) 190 | elif self.pred == 'x0_pred': 191 | x0_pred = pred 192 | else: 193 | raise NotImplementedError 194 | return x0_pred 195 | 196 | def score(self, xt, t, **kwargs): 197 | cum_beta = self.sde.cum_beta(t) 198 | noise_pred = self.noise_pred(xt, t, **kwargs) 199 | return stp(-cum_beta.rsqrt(), noise_pred) 200 | 201 | 202 | class ReverseSDE(object): 203 | r""" 204 | dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw 205 | """ 206 | def __init__(self, score_model): 207 | self.sde = score_model.sde # the forward sde 208 | self.score_model = score_model 209 | 210 | def drift(self, x, t, **kwargs): 211 | drift = self.sde.drift(x, t) # f(x, t) 212 | diffusion = self.sde.diffusion(t) # g(t) 213 | score = self.score_model.score(x, t, **kwargs) 214 | return drift - stp(diffusion ** 2, score) 215 | 216 | def diffusion(self, t): 217 | return self.sde.diffusion(t) 218 | 219 | 220 | class ODE(object): 221 | r""" 222 | dx = [f(x, t) - g(t)^2 s(x, t)] dt 223 | """ 224 | 225 | def __init__(self, score_model): 226 | self.sde = score_model.sde # the forward sde 227 | self.score_model = score_model 228 | 229 | def drift(self, x, t, **kwargs): 230 | drift = self.sde.drift(x, t) # f(x, t) 231 | diffusion = self.sde.diffusion(t) # g(t) 232 | score = self.score_model.score(x, t, **kwargs) 233 | return drift - 0.5 * stp(diffusion ** 2, score) 234 | 235 | def diffusion(self, t): 236 | return 0 237 | 238 | 239 | def dct2str(dct): 240 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 241 | 242 | 243 | @ torch.no_grad() 244 | def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs): 245 | r""" 246 | The Euler Maruyama sampler for reverse SDE / ODE 247 | See `Score-Based Generative Modeling through Stochastic Differential Equations` 248 | """ 249 | assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE) 250 | print(f"euler_maruyama with sample_steps={sample_steps}") 251 | timesteps = np.append(0., np.linspace(eps, T, sample_steps)) 252 | timesteps = torch.tensor(timesteps).to(x_init) 253 | x = x_init 254 | if trace is not None: 255 | trace.append(x) 256 | for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'): 257 | drift = rsde.drift(x, t, **kwargs) 258 | diffusion = rsde.diffusion(t) 259 | dt = s - t 260 | mean = x + drift * dt 261 | sigma = diffusion * (-dt).sqrt() 262 | x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean 263 | if trace is not None: 264 | trace.append(x) 265 | statistics = dict(s=s, t=t, sigma=sigma.item()) 266 | logging.debug(dct2str(statistics)) 267 | return x 268 | 269 | 270 | def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs): 271 | t, noise, xt = score_model.sde.sample(x0) 272 | if pred == 'noise_pred': 273 | noise_pred = score_model.noise_pred(xt, t, **kwargs) 274 | return mos(noise - noise_pred) 275 | elif pred == 'x0_pred': 276 | x0_pred = score_model.x0_pred(xt, t, **kwargs) 277 | return mos(x0 - x0_pred) 278 | else: 279 | raise NotImplementedError(pred) 280 | -------------------------------------------------------------------------------- /U-ViT/libs/uvit.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplemented 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., 
qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 141 | use_checkpoint=False, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.num_classes = num_classes 145 | self.in_chans = in_chans 146 | 147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 148 | num_patches = (img_size // patch_size) ** 2 149 | 150 | self.time_embed = nn.Sequential( 151 | nn.Linear(embed_dim, 4 * embed_dim), 152 | nn.SiLU(), 153 | nn.Linear(4 * embed_dim, embed_dim), 154 | ) if mlp_time_embed else nn.Identity() 155 | 156 | if self.num_classes > 0: 157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 158 | self.extras = 2 159 | else: 160 | self.extras = 1 161 | 162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 163 | self.in_blocks = nn.ModuleList([ 164 | Block( 165 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 166 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 167 | for _ in range(depth // 2)]) 168 | 169 | self.mid_block = Block( 170 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 171 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 172 | 173 | self.out_blocks = nn.ModuleList([ 174 | Block( 175 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 176 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 177 | for _ in range(depth // 2)]) 178 | 179 | self.norm = norm_layer(embed_dim) 180 | self.patch_dim = patch_size ** 2 * in_chans 181 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 182 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if 
conv else nn.Identity() 183 | 184 | trunc_normal_(self.pos_embed, std=.02) 185 | self.apply(self._init_weights) 186 | 187 | def _init_weights(self, m): 188 | if isinstance(m, nn.Linear): 189 | trunc_normal_(m.weight, std=.02) 190 | if isinstance(m, nn.Linear) and m.bias is not None: 191 | nn.init.constant_(m.bias, 0) 192 | elif isinstance(m, nn.LayerNorm): 193 | nn.init.constant_(m.bias, 0) 194 | nn.init.constant_(m.weight, 1.0) 195 | 196 | @torch.jit.ignore 197 | def no_weight_decay(self): 198 | return {'pos_embed'} 199 | 200 | def reset(self): 201 | pass 202 | 203 | def forward(self, x, timesteps, y=None): 204 | #print(timesteps) 205 | x = self.patch_embed(x) 206 | B, L, D = x.shape 207 | 208 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 209 | time_token = time_token.unsqueeze(dim=1) 210 | x = torch.cat((time_token, x), dim=1) 211 | if y is not None: 212 | label_emb = self.label_emb(y) 213 | label_emb = label_emb.unsqueeze(dim=1) 214 | x = torch.cat((label_emb, x), dim=1) 215 | x = x + self.pos_embed 216 | 217 | skips = [] 218 | for blk in self.in_blocks: 219 | x = blk(x) 220 | skips.append(x) 221 | 222 | x = self.mid_block(x) 223 | 224 | for blk in self.out_blocks: 225 | x = blk(x, skips.pop()) 226 | 227 | x = self.norm(x) 228 | x = self.decoder_pred(x) 229 | assert x.size(1) == self.extras + L 230 | x = x[:, self.extras:, :] 231 | x = unpatchify(x, self.in_chans) 232 | x = self.final_layer(x) 233 | return x 234 | -------------------------------------------------------------------------------- /U-ViT/tools/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 16 | 17 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of Tensorflow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License. 
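In this repository the module is used through calculate_fid_given_paths, whose first
path is a reference given either as a pre-computed statistics .npz (containing 'mu'
and 'sigma' arrays) or as a folder of images, and whose second path may be a folder
of images or an .npz archive of generated samples.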
33 | """ 34 | import os 35 | import pathlib 36 | 37 | import numpy as np 38 | import torch 39 | import torchvision.transforms as TF 40 | from PIL import Image 41 | from scipy import linalg 42 | from torch.nn.functional import adaptive_avg_pool2d 43 | 44 | from .read_npz import open_npz_array 45 | 46 | import matplotlib.pyplot as plt 47 | 48 | from torchvision.transforms.functional import to_tensor 49 | 50 | try: 51 | from tqdm import tqdm 52 | except ImportError: 53 | # If tqdm is not available, provide a mock version of it 54 | def tqdm(x): 55 | return x 56 | 57 | from .inception import InceptionV3 58 | 59 | 60 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 61 | 'tif', 'tiff', 'webp'} 62 | 63 | 64 | class ImagePathDataset(torch.utils.data.Dataset): 65 | def __init__(self, files, transforms=None): 66 | self.files = files 67 | self.transforms = transforms 68 | 69 | def __len__(self): 70 | return len(self.files) 71 | 72 | def __getitem__(self, i): 73 | path = self.files[i] 74 | img = Image.open(path).convert('RGB') 75 | if self.transforms is not None: 76 | img = self.transforms(img) 77 | return img 78 | 79 | 80 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): 81 | """Calculates the activations of the pool_3 layer for all images. 82 | 83 | Params: 84 | -- files : List of image files paths 85 | -- model : Instance of inception model 86 | -- batch_size : Batch size of images for the model to process at once. 87 | Make sure that the number of samples is a multiple of 88 | the batch size, otherwise some samples are ignored. This 89 | behavior is retained to match the original FID score 90 | implementation. 91 | -- dims : Dimensionality of features returned by Inception 92 | -- device : Device to run calculations 93 | -- num_workers : Number of parallel dataloader workers 94 | 95 | Returns: 96 | -- A numpy array of dimension (num images, dims) that contains the 97 | activations of the given tensor when feeding inception with the 98 | query tensor. 99 | """ 100 | model.eval() 101 | 102 | if batch_size > len(files): 103 | print(('Warning: batch size is bigger than the data size. ' 104 | 'Setting batch size to data size')) 105 | batch_size = len(files) 106 | 107 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 108 | dataloader = torch.utils.data.DataLoader(dataset, 109 | batch_size=batch_size, 110 | shuffle=False, 111 | drop_last=False, 112 | num_workers=num_workers) 113 | 114 | pred_arr = np.empty((len(files), dims)) 115 | 116 | start_idx = 0 117 | 118 | for batch in tqdm(dataloader): 119 | batch = batch.to(device) 120 | with torch.no_grad(): 121 | pred = model(batch)[0] 122 | 123 | # If model output is not scalar, apply global spatial average pooling. 124 | # This happens if you choose a dimensionality not equal 2048. 125 | if pred.size(2) != 1 or pred.size(3) != 1: 126 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 127 | 128 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 129 | 130 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 131 | 132 | start_idx = start_idx + pred.shape[0] 133 | 134 | return pred_arr 135 | 136 | 137 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 138 | """Numpy implementation of the Frechet Distance. 139 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 140 | and X_2 ~ N(mu_2, C_2) is 141 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 142 | 143 | Stable version by Dougal J. Sutherland. 
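As a quick sanity check of the formula: for two identical Gaussians the distance is
zero, and for the 1-D case N(0, 1) vs N(1, 1) it is
d^2 = ||0 - 1||^2 + Tr(1 + 1 - 2*sqrt(1*1)) = 1.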
144 | 145 | Params: 146 | -- mu1 : Numpy array containing the activations of a layer of the 147 | inception net (like returned by the function 'get_predictions') 148 | for generated samples. 149 | -- mu2 : The sample mean over activations, precalculated on an 150 | representative data set. 151 | -- sigma1: The covariance matrix over activations for generated samples. 152 | -- sigma2: The covariance matrix over activations, precalculated on an 153 | representative data set. 154 | 155 | Returns: 156 | -- : The Frechet Distance. 157 | """ 158 | 159 | mu1 = np.atleast_1d(mu1) 160 | mu2 = np.atleast_1d(mu2) 161 | 162 | sigma1 = np.atleast_2d(sigma1) 163 | sigma2 = np.atleast_2d(sigma2) 164 | 165 | assert mu1.shape == mu2.shape, \ 166 | 'Training and test mean vectors have different lengths' 167 | assert sigma1.shape == sigma2.shape, \ 168 | 'Training and test covariances have different dimensions' 169 | 170 | diff = mu1 - mu2 171 | 172 | # Product might be almost singular 173 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 174 | if not np.isfinite(covmean).all(): 175 | msg = ('fid calculation produces singular product; ' 176 | 'adding %s to diagonal of cov estimates') % eps 177 | print(msg) 178 | offset = np.eye(sigma1.shape[0]) * eps 179 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 180 | 181 | # Numerical error might give slight imaginary component 182 | if np.iscomplexobj(covmean): 183 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 184 | m = np.max(np.abs(covmean.imag)) 185 | raise ValueError('Imaginary component {}'.format(m)) 186 | covmean = covmean.real 187 | 188 | tr_covmean = np.trace(covmean) 189 | 190 | return (diff.dot(diff) + np.trace(sigma1) 191 | + np.trace(sigma2) - 2 * tr_covmean) 192 | 193 | 194 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 195 | device='cpu', num_workers=8): 196 | """Calculation of the statistics used by the FID. 197 | Params: 198 | -- files : List of image files paths 199 | -- model : Instance of inception model 200 | -- batch_size : The images numpy array is split into batches with 201 | batch size batch_size. A reasonable batch size 202 | depends on the hardware. 203 | -- dims : Dimensionality of features returned by Inception 204 | -- device : Device to run calculations 205 | -- num_workers : Number of parallel dataloader workers 206 | 207 | Returns: 208 | -- mu : The mean over samples of the activations of the pool_3 layer of 209 | the inception model. 210 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 211 | the inception model. 212 | """ 213 | act = get_activations(files, model, batch_size, dims, device, num_workers) 214 | mu = np.mean(act, axis=0) 215 | sigma = np.cov(act, rowvar=False) 216 | return mu, sigma 217 | 218 | def compute_statistics_of_images_in_npz(path, model, batch_size, dims, device, num_workers=8): 219 | model.eval() 220 | pred_arr = np.empty((50000, dims)) 221 | start_idx = 0 222 | 223 | with open_npz_array(path, "arr_0") as reader: 224 | for samples in tqdm(reader.read_batches(batch_size), total=50000 // batch_size): 225 | samples = np.array(samples) 226 | samples = torch.from_numpy(samples.transpose(0, 3, 1, 2)).contiguous().to(device) 227 | samples = samples.div(255).float() 228 | 229 | #batch = torch.tensor(samples).to(device) 230 | with torch.no_grad(): 231 | pred = model(samples)[0] 232 | 233 | # If model output is not scalar, apply global spatial average pooling. 
234 | # This happens if you choose a dimensionality not equal 2048. 235 | if pred.size(2) != 1 or pred.size(3) != 1: 236 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 237 | 238 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 239 | 240 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 241 | 242 | start_idx = start_idx + pred.shape[0] 243 | 244 | act = pred_arr 245 | mu = np.mean(act, axis=0) 246 | sigma = np.cov(act, rowvar=False) 247 | return mu, sigma 248 | 249 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 250 | if path.endswith('.npz'): 251 | with np.load(path) as f: 252 | m, s = f['mu'][:], f['sigma'][:] 253 | else: 254 | path = pathlib.Path(path) 255 | files = sorted([file for ext in IMAGE_EXTENSIONS 256 | for file in path.glob('*.{}'.format(ext))]) 257 | m, s = calculate_activation_statistics(files, model, batch_size, 258 | dims, device, num_workers) 259 | 260 | return m, s 261 | 262 | 263 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8): 264 | if device is None: 265 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 266 | else: 267 | device = torch.device(device) 268 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 269 | model = InceptionV3([block_idx]).to(device) 270 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers) 271 | np.savez(out_path, mu=m1, sigma=s1) 272 | 273 | 274 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8): 275 | """Calculates the FID of two paths""" 276 | if device is None: 277 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 278 | else: 279 | device = torch.device(device) 280 | 281 | for p in paths: 282 | if not os.path.exists(p): 283 | raise RuntimeError('Invalid path: %s' % p) 284 | 285 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 286 | 287 | model = InceptionV3([block_idx]).to(device) 288 | 289 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 290 | dims, device, num_workers) 291 | if paths[1].endswith('.npz'): 292 | m2, s2 = compute_statistics_of_images_in_npz(paths[1], model, batch_size, 293 | dims, device, num_workers) 294 | else: 295 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 296 | dims, device, num_workers) 297 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 298 | 299 | return fid_value 300 | -------------------------------------------------------------------------------- /DiT/sample_ddp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Samples a large number of images from a pre-trained DiT model using DDP. 9 | Subsequently saves a .npz file that can be used to compute FID and other 10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations 11 | 12 | For a simple single-GPU/CPU sampling script, see sample.py. 
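Example launch (illustrative; the number of GPUs, checkpoint and router path
depend on your setup):

    torchrun --nproc_per_node=8 sample_ddp.py --model DiT-XL/2 --image-size 256 \
        --ddim-sample --num-sampling-steps 50 --cfg-scale 1.5 --num-fid-samples 50000 \
        --accelerate-method dynamiclayer --path ckpt/DDIM50_router.pt

Either --ddim-sample or --p-sample must be given; otherwise the sampling loop
raises NotImplementedError.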
13 | """ 14 | import torch 15 | import torch.distributed as dist 16 | from download import find_model 17 | from diffusion import create_diffusion 18 | from diffusers.models import AutoencoderKL 19 | from tqdm import tqdm 20 | import os 21 | from PIL import Image 22 | import numpy as np 23 | import math 24 | import argparse 25 | 26 | 27 | def create_npz_from_sample_folder(sample_dir, num=50_000): 28 | """ 29 | Builds a single .npz file from a folder of .png samples. 30 | """ 31 | samples = [] 32 | for i in tqdm(range(num), desc="Building .npz file from samples"): 33 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 34 | sample_np = np.asarray(sample_pil).astype(np.uint8) 35 | samples.append(sample_np) 36 | samples = np.stack(samples) 37 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 38 | npz_path = f"{sample_dir}.npz" 39 | np.savez(npz_path, arr_0=samples) 40 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 41 | return npz_path 42 | 43 | 44 | def main(args): 45 | """ 46 | Run sampling. 47 | """ 48 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 49 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 50 | torch.set_grad_enabled(False) 51 | 52 | # Setup DDP: 53 | dist.init_process_group("nccl") 54 | rank = dist.get_rank() 55 | device = rank % torch.cuda.device_count() 56 | seed = args.global_seed * dist.get_world_size() + rank 57 | torch.manual_seed(seed) 58 | torch.cuda.set_device(device) 59 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 60 | 61 | if args.ckpt is None: 62 | assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." 
63 | assert args.image_size in [256, 512] 64 | assert args.num_classes == 1000 65 | 66 | diffusion = create_diffusion(str(args.num_sampling_steps)) 67 | 68 | # Load model: 69 | latent_size = args.image_size // 8 70 | if args.accelerate_method == "cache": 71 | from models.cache_models import DiT_models 72 | elif args.accelerate_method == "iterate": 73 | from models.iterate_models import DiT_models 74 | elif args.accelerate_method == "nolastlayer": 75 | from models.nolastlayer_models import DiT_models 76 | elif args.accelerate_method is not None and "ranklayer" in args.accelerate_method: 77 | from models.rankdrop_models import DiT_models 78 | elif args.accelerate_method is not None and "bottomlayer" in args.accelerate_method: 79 | from models.bottom_models import DiT_models 80 | elif args.accelerate_method is not None and "randomlayer" in args.accelerate_method: 81 | from models.randomlayer_models import DiT_models 82 | elif args.accelerate_method is not None and "fixlayer" in args.accelerate_method: 83 | from models.fixlayer_models import DiT_models 84 | elif args.accelerate_method is not None and args.accelerate_method == "dynamiclayer": 85 | from models.dynamic_models import DiT_models 86 | elif args.accelerate_method is not None and args.accelerate_method == "layerdropout": 87 | from models.layerdropout_models import DiT_models 88 | elif args.accelerate_method is not None and args.accelerate_method == "dynamiclayer_soft": 89 | from models.router_models_inference import DiT_models 90 | else: 91 | from models.models import DiT_models 92 | 93 | model = DiT_models[args.model]( 94 | input_size=latent_size, 95 | num_classes=args.num_classes 96 | ).to(device) 97 | 98 | if args.accelerate_method is not None: 99 | if 'ranklayer' in args.accelerate_method: 100 | model.load_ranking(args.num_sampling_steps, args.accelerate_method) 101 | elif 'randomlayer' in args.accelerate_method: 102 | model.load_ranking(args.accelerate_method) 103 | elif 'bottomlayer' in args.accelerate_method or 'fixlayer' in args.accelerate_method: 104 | model.load_ranking(args.accelerate_method) 105 | elif 'dynamiclayer' in args.accelerate_method or 'layerdropout' in args.accelerate_method or 'dynamiclayer_soft' in args.accelerate_method: 106 | model.load_ranking(args.path, args.num_sampling_steps, diffusion.timestep_map, args.thres) 107 | 108 | 109 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: 110 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" 111 | state_dict = find_model(ckpt_path) 112 | model.load_state_dict(state_dict) 113 | model.eval() # important! 
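    # Note on the accelerate methods above: for the "dynamiclayer" variants the
    # learned router checkpoint passed via --path is loaded with
    # model.load_ranking(...), together with the sampler's timestep map and the
    # --thres cutoff. Judging from the analogous U-ViT implementation in
    # U-ViT/libs/uvit_dynamic.py, sub-layers whose router score does not exceed
    # the threshold reuse their cached attention/MLP output instead of being
    # recomputed; the exact behaviour is implemented in the dedicated model
    # variants rather than in models/models.py below.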
114 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 115 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" 116 | using_cfg = args.cfg_scale > 1.0 117 | 118 | # Create folder to save samples: 119 | model_string_name = args.model.replace("/", "-") 120 | ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained" 121 | if args.accelerate_method is not None and 'dynamiclayer' in args.accelerate_method: 122 | router_name = args.path.split('/')[1].split('.')[0] 123 | folder_name = f"router-{router_name}-thres-{args.thres}-accelerate-{args.accelerate_method}-size-{args.image_size}-vae-{args.vae}-ddim-{args.ddim_sample}-" \ 124 | f"steps-{args.num_sampling_steps}-cfg-{args.cfg_scale}-seed-{args.global_seed}" 125 | else: 126 | folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-psampler-{args.p_sample}-ddim-{args.ddim_sample}-" \ 127 | f"steps-{args.num_sampling_steps}-accelerate-{args.accelerate_method}-cfg-{args.cfg_scale}-seed-{args.global_seed}" 128 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 129 | 130 | os.makedirs(f"{args.sample_dir}", exist_ok=True) 131 | if rank == 0 and args.save_to_disk: 132 | os.makedirs(sample_folder_dir, exist_ok=True) 133 | print(f"Saving .png samples at {sample_folder_dir}") 134 | dist.barrier() 135 | 136 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 137 | n = args.per_proc_batch_size 138 | global_batch_size = n * dist.get_world_size() 139 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 140 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) 141 | if rank == 0: 142 | print(f"Total number of images that will be sampled: {total_samples}") 143 | all_images = [] 144 | 145 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" 146 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 147 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" 148 | iterations = int(samples_needed_this_gpu // n) 149 | pbar = range(iterations) 150 | pbar = tqdm(pbar) if rank == 0 else pbar 151 | total = 0 152 | 153 | for _ in pbar: 154 | model.reset(args.num_sampling_steps) 155 | 156 | # Sample inputs: 157 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device) 158 | y = torch.randint(0, args.num_classes, (n,), device=device) 159 | 160 | 161 | # Setup classifier-free guidance: 162 | if using_cfg: 163 | z = torch.cat([z, z], 0) 164 | y_null = torch.tensor([1000] * n, device=device) 165 | y = torch.cat([y, y_null], 0) 166 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) 167 | sample_fn = model.forward_with_cfg 168 | else: 169 | model_kwargs = dict(y=y) 170 | sample_fn = model.forward 171 | 172 | # Sample images: 173 | if args.p_sample: 174 | samples = diffusion.p_sample_loop( 175 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 176 | ) 177 | elif args.ddim_sample: 178 | samples = diffusion.ddim_sample_loop( 179 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 180 | ) 181 | else: 182 | raise NotImplementedError 183 | 184 | if using_cfg: 185 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 186 | 187 | samples = 
vae.decode(samples / 0.18215).sample 188 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(dtype=torch.uint8) 189 | 190 | # Save samples to disk as individual .png files 191 | if args.save_to_disk: 192 | for i, sample in enumerate(samples): 193 | index = i * dist.get_world_size() + rank + total 194 | sample = sample.cpu().numpy() 195 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 196 | else: 197 | samples = samples.contiguous() 198 | gathered_samples = [torch.zeros_like(samples) for _ in range(dist.get_world_size())] 199 | dist.all_gather(gathered_samples, samples) 200 | 201 | if rank == 0: 202 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 203 | total += global_batch_size 204 | 205 | dist.barrier() 206 | 207 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 208 | dist.barrier() 209 | if rank == 0: 210 | if args.save_to_disk: 211 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) 212 | print("Done.") 213 | else: 214 | if rank == 0: 215 | arr = np.concatenate(all_images, axis=0) 216 | arr = arr[: args.num_fid_samples] 217 | 218 | out_path = f"{sample_folder_dir}.npz" 219 | 220 | print(f"saving to {out_path}") 221 | np.savez(out_path, arr_0=arr) 222 | dist.barrier() 223 | dist.destroy_process_group() 224 | 225 | 226 | if __name__ == "__main__": 227 | parser = argparse.ArgumentParser() 228 | parser.add_argument("--model", type=str, default="DiT-XL/2") 229 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") 230 | parser.add_argument("--sample-dir", type=str, default="samples") 231 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 232 | parser.add_argument("--num-fid-samples", type=int, default=50_000) 233 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 234 | parser.add_argument("--num-classes", type=int, default=1000) 235 | parser.add_argument("--cfg-scale", type=float, default=1.5) 236 | parser.add_argument("--num-sampling-steps", type=int, default=250) 237 | parser.add_argument("--global-seed", type=int, default=0) 238 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, 239 | help="By default, use TF32 matmuls. 
This massively accelerates sampling on Ampere GPUs.") 240 | parser.add_argument("--ckpt", type=str, default=None, 241 | help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") 242 | 243 | parser.add_argument("--ddim-sample", action="store_true", default=False,) 244 | parser.add_argument("--p-sample", action="store_true", default=False,) 245 | 246 | parser.add_argument("--accelerate-method", type=str, default=None, 247 | help="Use the accelerated version of the model.") 248 | parser.add_argument("--thres", type=float, default=0.5) 249 | 250 | parser.add_argument("--name", type=str, default="None") 251 | parser.add_argument("--path", type=str, default=None,) 252 | 253 | parser.add_argument("--save-to-disk", action="store_true", default=False,) 254 | 255 | 256 | args = parser.parse_args() 257 | main(args) 258 | -------------------------------------------------------------------------------- /U-ViT/tools/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. 
The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight inititialization if supported by torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 
171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zero's in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zero's in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, 
branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zero's in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | -------------------------------------------------------------------------------- /U-ViT/libs/uvit_dynamic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | from .timm import trunc_normal_, Mlp 6 | import einops 7 | import torch.utils.checkpoint 8 | 9 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 10 | ATTENTION_MODE = 'flash' 11 | else: 12 | try: 13 | import xformers 14 | import xformers.ops 15 | ATTENTION_MODE = 'xformers' 16 | except: 17 | ATTENTION_MODE = 'math' 18 | print(f'attention mode is {ATTENTION_MODE}') 19 | 20 | 21 | def timestep_embedding(timesteps, dim, max_period=10000): 22 | """ 23 | Create sinusoidal timestep embeddings. 24 | 25 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 26 | These may be fractional. 27 | :param dim: the dimension of the output. 28 | :param max_period: controls the minimum frequency of the embeddings. 29 | :return: an [N x dim] Tensor of positional embeddings. 
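    For example (illustrative): with dim=4 the embedding of a timestep t is
    [cos(t), cos(t / 10000**0.5), sin(t), sin(t / 10000**0.5)] -- the first half
    holds cosines and the second half sines, over a geometric ladder of
    frequencies running from 1 down to roughly 1/max_period.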
30 | """ 31 | half = dim // 2 32 | freqs = torch.exp( 33 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 34 | ).to(device=timesteps.device) 35 | args = timesteps[:, None].float() * freqs[None] 36 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 37 | if dim % 2: 38 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 39 | return embedding 40 | 41 | 42 | def patchify(imgs, patch_size): 43 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 44 | return x 45 | 46 | 47 | def unpatchify(x, channels=3): 48 | patch_size = int((x.shape[2] // channels) ** 0.5) 49 | h = w = int(x.shape[1] ** .5) 50 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 51 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 57 | super().__init__() 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 63 | self.attn_drop = nn.Dropout(attn_drop) 64 | self.proj = nn.Linear(dim, dim) 65 | self.proj_drop = nn.Dropout(proj_drop) 66 | 67 | def forward(self, x): 68 | B, L, C = x.shape 69 | 70 | qkv = self.qkv(x) 71 | if ATTENTION_MODE == 'flash': 72 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 73 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 74 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 75 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 76 | elif ATTENTION_MODE == 'xformers': 77 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 78 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 79 | x = xformers.ops.memory_efficient_attention(q, k, v) 80 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 81 | elif ATTENTION_MODE == 'math': 82 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 83 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 88 | else: 89 | raise NotImplemented 90 | 91 | x = self.proj(x) 92 | x = self.proj_drop(x) 93 | return x 94 | 95 | 96 | class Block(nn.Module): 97 | 98 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 99 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 100 | super().__init__() 101 | self.norm1 = norm_layer(dim) 102 | self.attn = Attention( 103 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 104 | self.norm2 = norm_layer(dim) 105 | mlp_hidden_dim = int(dim * mlp_ratio) 106 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 107 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 108 | self.use_checkpoint = use_checkpoint 109 | 110 | def forward(self, x, skip=None, reuse_att=None, reuse_mlp=None): 111 | if self.use_checkpoint: 112 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip, reuse_att, reuse_mlp) 113 | else: 114 | return self._forward(x, skip, reuse_att, reuse_mlp) 115 | 116 | def _forward(self, x, skip=None, reuse_att=None, reuse_mlp=None): 117 | if self.skip_linear is not None: 118 
| x = self.skip_linear(torch.cat([x, skip], dim=-1)) 119 | 120 | if reuse_att is not None: 121 | x = x + reuse_att 122 | else: 123 | reuse_att = self.attn(self.norm1(x)) 124 | x = x + reuse_att 125 | 126 | if reuse_mlp is not None: 127 | x = x + reuse_mlp 128 | else: 129 | reuse_mlp = self.mlp(self.norm2(x)) 130 | x = x + reuse_mlp 131 | return x, (reuse_att, reuse_mlp) 132 | 133 | 134 | class PatchEmbed(nn.Module): 135 | """ Image to Patch Embedding 136 | """ 137 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 138 | super().__init__() 139 | self.patch_size = patch_size 140 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 141 | 142 | def forward(self, x): 143 | B, C, H, W = x.shape 144 | assert H % self.patch_size == 0 and W % self.patch_size == 0 145 | x = self.proj(x).flatten(2).transpose(1, 2) 146 | return x 147 | 148 | 149 | class UViT(nn.Module): 150 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 151 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 152 | use_checkpoint=False, conv=True, skip=True): 153 | super().__init__() 154 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 155 | self.num_classes = num_classes 156 | self.in_chans = in_chans 157 | 158 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 159 | num_patches = (img_size // patch_size) ** 2 160 | 161 | self.time_embed = nn.Sequential( 162 | nn.Linear(embed_dim, 4 * embed_dim), 163 | nn.SiLU(), 164 | nn.Linear(4 * embed_dim, embed_dim), 165 | ) if mlp_time_embed else nn.Identity() 166 | 167 | if self.num_classes > 0: 168 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 169 | self.extras = 2 170 | else: 171 | self.extras = 1 172 | 173 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 174 | 175 | self.in_blocks = nn.ModuleList([ 176 | Block( 177 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 178 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 179 | for _ in range(depth // 2)]) 180 | 181 | self.mid_block = Block( 182 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 183 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 184 | 185 | self.out_blocks = nn.ModuleList([ 186 | Block( 187 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 188 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 189 | for _ in range(depth // 2)]) 190 | 191 | self.depth = depth + 1 # depth//2 for in/out, and 1 for mid 192 | 193 | self.norm = norm_layer(embed_dim) 194 | self.patch_dim = patch_size ** 2 * in_chans 195 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 196 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 197 | 198 | trunc_normal_(self.pos_embed, std=.02) 199 | self.apply(self._init_weights) 200 | 201 | self.reset() 202 | 203 | def reset_cache_features(self): 204 | self.cond_cache_features = [None] * self.depth 205 | self.uncond_cache_features = [None] * self.depth 206 | 207 | def reset(self): 208 | self.cur_step_idx = 0 209 | self.reset_cache_features() 210 | 211 | def _init_weights(self, m): 212 | if isinstance(m, nn.Linear): 213 | trunc_normal_(m.weight, std=.02) 214 | if isinstance(m, nn.Linear) 
and m.bias is not None: 215 | nn.init.constant_(m.bias, 0) 216 | elif isinstance(m, nn.LayerNorm): 217 | nn.init.constant_(m.bias, 0) 218 | nn.init.constant_(m.weight, 1.0) 219 | 220 | @torch.jit.ignore 221 | def no_weight_decay(self): 222 | return {'pos_embed'} 223 | 224 | def load_ranking(self, path, num_steps, timestep_map, thres): 225 | self.rank = [None] * num_steps 226 | from .uvit_router import Router 227 | 228 | act_layer, total_layer = 0, 0 229 | ckpt = torch.load(path, map_location='cpu') 230 | routers = torch.nn.ModuleList([ 231 | Router(2*self.depth) for _ in range(num_steps) 232 | ]) 233 | routers.load_state_dict(ckpt) 234 | self.timestep_map = {timestep: i for i, timestep in enumerate(timestep_map)} 235 | print(self.timestep_map) 236 | 237 | act_att, act_mlp = 0, 0 238 | for idx, router in enumerate(routers[:num_steps//2]): 239 | if idx != 0: 240 | self.rank[idx] = (router() > thres).float().nonzero().squeeze(0) 241 | total_layer += 2 * self.depth 242 | act_layer += len(self.rank[idx]) 243 | print(f"TImestep {idx}: Not Reuse: {self.rank[idx].squeeze()}") 244 | 245 | if len(self.rank[idx]) > 0: 246 | act_att += sum(1 - torch.remainder(self.rank[idx], 2)).item() 247 | act_mlp += sum(torch.remainder(self.rank[idx], 2)).item() 248 | 249 | print(f"Total Activate Layer: {act_layer}/{total_layer}, Remove Ratio = {1 - act_layer/total_layer}") 250 | print(f"Total Activate Attention: {act_att}/{total_layer//2}, Remove Ratio = {1 - act_att/(total_layer//2)}") 251 | print(f"Total Activate MLP: {act_mlp}/{total_layer//2}, Remove Ratio = {1 - act_mlp/(total_layer//2)}") 252 | 253 | def forward(self, x, timesteps, y=None): 254 | x = self.patch_embed(x) 255 | B, L, D = x.shape 256 | 257 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 258 | time_token = time_token.unsqueeze(dim=1) 259 | x = torch.cat((time_token, x), dim=1) 260 | if y is not None: 261 | label_emb = self.label_emb(y) 262 | label_emb = label_emb.unsqueeze(dim=1) 263 | x = torch.cat((label_emb, x), dim=1) 264 | x = x + self.pos_embed 265 | 266 | skips = [] 267 | 268 | if self.cur_step_idx % 4 == 2: 269 | self.reset_cache_features() 270 | 271 | if self.cur_step_idx % 2 == 0: 272 | cache_features = self.cond_cache_features 273 | else: 274 | cache_features = self.uncond_cache_features 275 | 276 | round_timestep = round(timesteps[0].item()) 277 | router_idx = self.timestep_map[round_timestep] if round_timestep in self.timestep_map else None 278 | #print(f"Round Timestep: {round_timestep}, Router Index: {router_idx}") 279 | 280 | layer_idx = 0 281 | for blk in self.in_blocks: 282 | reuse_att, reuse_mlp = None, None 283 | if cache_features[layer_idx] is not None: 284 | if layer_idx * 2 not in self.rank[router_idx]: 285 | reuse_att, _ = cache_features[layer_idx] 286 | if layer_idx * 2 + 1 not in self.rank[router_idx]: 287 | _, reuse_mlp = cache_features[layer_idx] 288 | 289 | x, cache_feature = blk(x, reuse_att=reuse_att, reuse_mlp=reuse_mlp) 290 | skips.append(x) 291 | cache_features[layer_idx] = cache_feature 292 | layer_idx += 1 293 | 294 | reuse_att, reuse_mlp = None, None 295 | if cache_features[layer_idx] is not None: 296 | if layer_idx * 2 not in self.rank[router_idx]: 297 | reuse_att, _ = cache_features[layer_idx] 298 | if layer_idx * 2 + 1 not in self.rank[router_idx]: 299 | _, reuse_mlp = cache_features[layer_idx] 300 | 301 | x, cache_feature = self.mid_block(x, reuse_att=reuse_att, reuse_mlp=reuse_mlp) 302 | cache_features[layer_idx] = cache_feature 303 | layer_idx += 1 304 | 305 | for blk in 
self.out_blocks: 306 | reuse_att, reuse_mlp = None, None 307 | if cache_features[layer_idx] is not None: 308 | if layer_idx * 2 not in self.rank[router_idx]: 309 | reuse_att, _ = cache_features[layer_idx] 310 | if layer_idx * 2 + 1 not in self.rank[router_idx]: 311 | _, reuse_mlp = cache_features[layer_idx] 312 | x , cache_feature = blk(x, skips.pop(), reuse_att=reuse_att, reuse_mlp=reuse_mlp) 313 | cache_features[layer_idx] = cache_feature 314 | layer_idx += 1 315 | 316 | x = self.norm(x) 317 | x = self.decoder_pred(x) 318 | assert x.size(1) == self.extras + L 319 | x = x[:, self.extras:, :] 320 | x = unpatchify(x, self.in_chans) 321 | x = self.final_layer(x) 322 | 323 | self.cur_step_idx += 1 324 | return x 325 | -------------------------------------------------------------------------------- /U-ViT/libs/uvit_router.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | import numpy as np 8 | 9 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 10 | ATTENTION_MODE = 'flash' 11 | else: 12 | try: 13 | import xformers 14 | import xformers.ops 15 | ATTENTION_MODE = 'xformers' 16 | except: 17 | ATTENTION_MODE = 'math' 18 | print(f'attention mode is {ATTENTION_MODE}') 19 | 20 | 21 | def timestep_embedding(timesteps, dim, max_period=10000): 22 | """ 23 | Create sinusoidal timestep embeddings. 24 | 25 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 26 | These may be fractional. 27 | :param dim: the dimension of the output. 28 | :param max_period: controls the minimum frequency of the embeddings. 29 | :return: an [N x dim] Tensor of positional embeddings. 
30 | """ 31 | half = dim // 2 32 | freqs = torch.exp( 33 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 34 | ).to(device=timesteps.device) 35 | args = timesteps[:, None].float() * freqs[None] 36 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 37 | if dim % 2: 38 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 39 | return embedding 40 | 41 | 42 | def patchify(imgs, patch_size): 43 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 44 | return x 45 | 46 | 47 | def unpatchify(x, channels=3): 48 | patch_size = int((x.shape[2] // channels) ** 0.5) 49 | h = w = int(x.shape[1] ** .5) 50 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 51 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 57 | super().__init__() 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 63 | self.attn_drop = nn.Dropout(attn_drop) 64 | self.proj = nn.Linear(dim, dim) 65 | self.proj_drop = nn.Dropout(proj_drop) 66 | 67 | def forward(self, x): 68 | B, L, C = x.shape 69 | 70 | qkv = self.qkv(x) 71 | if ATTENTION_MODE == 'flash': 72 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 73 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 74 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 75 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 76 | elif ATTENTION_MODE == 'xformers': 77 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 78 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 79 | x = xformers.ops.memory_efficient_attention(q, k, v) 80 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 81 | elif ATTENTION_MODE == 'math': 82 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 83 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 88 | else: 89 | raise NotImplemented 90 | 91 | x = self.proj(x) 92 | x = self.proj_drop(x) 93 | return x 94 | 95 | 96 | class Block(nn.Module): 97 | 98 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 99 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 100 | super().__init__() 101 | self.norm1 = norm_layer(dim) 102 | self.attn = Attention( 103 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 104 | self.norm2 = norm_layer(dim) 105 | mlp_hidden_dim = int(dim * mlp_ratio) 106 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 107 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 108 | self.use_checkpoint = False #use_checkpoint 109 | 110 | def forward(self, x, skip=None, reuse_att=None, reuse_mlp=None, 111 | reuse_att_weight=0, reuse_mlp_weight=0): 112 | if self.use_checkpoint: 113 | return torch.utils.checkpoint.checkpoint( 114 | self._forward, x, skip, reuse_att, reuse_mlp, 115 | reuse_att_weight, reuse_mlp_weight 116 | ) 117 | else: 118 | return self._forward( 119 | x, skip, reuse_att, reuse_mlp, 120 | 
reuse_att_weight, reuse_mlp_weight 121 | ) 122 | 123 | def _forward(self, x, skip=None, reuse_att=None, reuse_mlp=None, reuse_att_weight=None, reuse_mlp_weight=None): 124 | if self.skip_linear is not None: 125 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 126 | 127 | att_out = self.attn(self.norm1(x)) 128 | if reuse_att is not None: 129 | att_out = att_out * (1 - reuse_att_weight) + reuse_att * reuse_att_weight 130 | x = x + att_out 131 | 132 | mlp_out = self.mlp(self.norm2(x)) 133 | if reuse_mlp is not None: 134 | mlp_out = mlp_out * (1 - reuse_mlp_weight) + reuse_mlp * reuse_mlp_weight 135 | x = x + mlp_out 136 | return x, (att_out, mlp_out) 137 | 138 | 139 | class PatchEmbed(nn.Module): 140 | """ Image to Patch Embedding 141 | """ 142 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 143 | super().__init__() 144 | self.patch_size = patch_size 145 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 146 | 147 | def forward(self, x): 148 | B, C, H, W = x.shape 149 | assert H % self.patch_size == 0 and W % self.patch_size == 0 150 | x = self.proj(x).flatten(2).transpose(1, 2) 151 | return x 152 | 153 | class Router(nn.Module): 154 | def __init__(self, num_choises): 155 | super().__init__() 156 | self.num_choises = num_choises 157 | self.prob = torch.nn.Parameter(torch.randn(num_choises), requires_grad=True) 158 | 159 | self.activation = torch.nn.Sigmoid() 160 | 161 | def forward(self, x=None): # Any input will be ignored, only for solving the issue of https://github.com/pytorch/pytorch/issues/37814 162 | return self.activation(self.prob) 163 | 164 | class UViT(nn.Module): 165 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 166 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 167 | use_checkpoint=False, conv=True, skip=True): 168 | super().__init__() 169 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 170 | self.num_classes = num_classes 171 | self.in_chans = in_chans 172 | 173 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 174 | num_patches = (img_size // patch_size) ** 2 175 | 176 | self.time_embed = nn.Sequential( 177 | nn.Linear(embed_dim, 4 * embed_dim), 178 | nn.SiLU(), 179 | nn.Linear(4 * embed_dim, embed_dim), 180 | ) if mlp_time_embed else nn.Identity() 181 | 182 | if self.num_classes > 0: 183 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 184 | self.extras = 2 185 | else: 186 | self.extras = 1 187 | 188 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 189 | 190 | self.in_blocks = nn.ModuleList([ 191 | Block( 192 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 193 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 194 | for _ in range(depth // 2)]) 195 | 196 | self.mid_block = Block( 197 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 198 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 199 | 200 | self.out_blocks = nn.ModuleList([ 201 | Block( 202 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 203 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 204 | for _ in range(depth // 2)]) 205 | 206 | self.depth = depth + 1 # depth//2 for in/out, and 1 for mid 207 | 208 | self.norm = norm_layer(embed_dim) 
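        # Note on the caching mechanism (see Router and Block._forward above):
        # each Router holds 2 * depth learnable logits, one per attention and one
        # per MLP sub-layer of a sampling step. When the cache is active, the
        # fresh output is blended with the cached one as
        #     out = new * score + cached * (1 - score),  score = sigmoid(logit),
        # so a score near 1 means "recompute" and a score near 0 means "reuse the
        # cached feature". forward() also returns the sum of the scores
        # (router_l1_loss) so that the training script can penalise recomputation
        # and push the routers towards more caching.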
209 | self.patch_dim = patch_size ** 2 * in_chans 210 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 211 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 212 | 213 | trunc_normal_(self.pos_embed, std=.02) 214 | self.apply(self._init_weights) 215 | 216 | self.reset() 217 | 218 | def reset_cache_features(self): 219 | self.cache_features = [None] * self.depth 220 | self.activate_cache = False 221 | self.record_cache = True 222 | 223 | def reset(self): 224 | self.cur_step_idx = 0 225 | self.reset_cache_features() 226 | 227 | def add_router(self, num_nfes): 228 | self.routers = torch.nn.ModuleList([ 229 | Router(2*self.depth) for _ in range(num_nfes) 230 | ]) 231 | 232 | def set_activate_cache(self, activate_cache): 233 | self.activate_cache = activate_cache 234 | 235 | def set_record_cache(self, record_cache): 236 | self.record_cache = record_cache 237 | 238 | def set_timestep_map(self, timestep_map): 239 | self.timestep_map = {timestep: i for i, timestep in enumerate(timestep_map)} 240 | print("Timestep -> Router IDX Map:", self.timestep_map) 241 | 242 | def _init_weights(self, m): 243 | if isinstance(m, nn.Linear): 244 | trunc_normal_(m.weight, std=.02) 245 | if isinstance(m, nn.Linear) and m.bias is not None: 246 | nn.init.constant_(m.bias, 0) 247 | elif isinstance(m, nn.LayerNorm): 248 | nn.init.constant_(m.bias, 0) 249 | nn.init.constant_(m.weight, 1.0) 250 | 251 | @torch.jit.ignore 252 | def no_weight_decay(self): 253 | return {'pos_embed'} 254 | 255 | def forward(self, x, timesteps, y=None): 256 | #print("In Model: Get y: ", y, ". Get Timesteps: ", timesteps) 257 | x = self.patch_embed(x) 258 | B, L, D = x.shape 259 | 260 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 261 | 262 | time_token = time_token.unsqueeze(dim=1) 263 | x = torch.cat((time_token, x), dim=1) 264 | if y is not None: 265 | label_emb = self.label_emb(y) 266 | label_emb = label_emb.unsqueeze(dim=1) 267 | x = torch.cat((label_emb, x), dim=1) 268 | x = x + self.pos_embed 269 | 270 | skips = [] 271 | cache_features = self.cache_features 272 | if self.activate_cache : 273 | router_idx = self.timestep_map[np.round(timesteps[0].item())] 274 | scores = self.routers[router_idx]() 275 | router_l1_loss = scores.sum() 276 | else: 277 | router_l1_loss = None 278 | 279 | layer_idx = 0 280 | for blk in self.in_blocks: 281 | if cache_features[layer_idx] is not None and self.activate_cache: 282 | reuse_att, reuse_mlp = cache_features[layer_idx] 283 | reuse_att_weight = 1 - scores[layer_idx*2] 284 | reuse_mlp_weight = 1 - scores[layer_idx*2+1] 285 | else: 286 | reuse_att, reuse_mlp = None, None 287 | reuse_att_weight, reuse_mlp_weight = 0, 0 288 | 289 | x, cache_feature = blk( 290 | x, reuse_att=reuse_att, reuse_mlp=reuse_mlp, 291 | reuse_att_weight=reuse_att_weight, 292 | reuse_mlp_weight=reuse_mlp_weight, 293 | ) 294 | skips.append(x) 295 | if self.record_cache: 296 | cache_features[layer_idx] = cache_feature 297 | layer_idx += 1 298 | 299 | if cache_features[layer_idx] is not None and self.activate_cache: 300 | reuse_att, reuse_mlp = cache_features[layer_idx] 301 | reuse_att_weight = 1 - scores[layer_idx*2] 302 | reuse_mlp_weight = 1 - scores[layer_idx*2+1] 303 | else: 304 | reuse_att, reuse_mlp = None, None 305 | reuse_att_weight, reuse_mlp_weight = 0, 0 306 | 307 | x, cache_feature = self.mid_block( 308 | x, reuse_att=reuse_att, reuse_mlp=reuse_mlp, 309 | reuse_att_weight=reuse_att_weight, 310 | 
reuse_mlp_weight=reuse_mlp_weight, 311 | ) 312 | if self.record_cache: 313 | cache_features[layer_idx] = cache_feature 314 | layer_idx += 1 315 | 316 | for blk in self.out_blocks: 317 | if cache_features[layer_idx] is not None and self.activate_cache: 318 | reuse_att, reuse_mlp = cache_features[layer_idx] 319 | reuse_att_weight = 1 - scores[layer_idx*2] 320 | reuse_mlp_weight = 1 - scores[layer_idx*2+1] 321 | else: 322 | reuse_att, reuse_mlp = None, None 323 | reuse_att_weight, reuse_mlp_weight = 0, 0 324 | 325 | x , cache_feature = blk( 326 | x, skips.pop(), reuse_att=reuse_att, reuse_mlp=reuse_mlp, 327 | reuse_att_weight=reuse_att_weight, 328 | reuse_mlp_weight=reuse_mlp_weight, 329 | ) 330 | if self.record_cache: 331 | cache_features[layer_idx] = cache_feature 332 | layer_idx += 1 333 | 334 | x = self.norm(x) 335 | x = self.decoder_pred(x) 336 | assert x.size(1) == self.extras + L 337 | x = x[:, self.extras:, :] 338 | x = unpatchify(x, self.in_chans) 339 | x = self.final_layer(x) 340 | 341 | self.cur_step_idx += 1 342 | 343 | if self.activate_cache: 344 | return x, router_l1_loss 345 | else: 346 | return x 347 | -------------------------------------------------------------------------------- /DiT/models/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 
49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 
128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 148 | """ 149 | def __init__( 150 | self, 151 | input_size=32, 152 | patch_size=2, 153 | in_channels=4, 154 | hidden_size=1152, 155 | depth=28, 156 | num_heads=16, 157 | mlp_ratio=4.0, 158 | class_dropout_prob=0.1, 159 | num_classes=1000, 160 | learn_sigma=True, 161 | ): 162 | super().__init__() 163 | self.learn_sigma = learn_sigma 164 | self.in_channels = in_channels 165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 166 | self.patch_size = patch_size 167 | self.num_heads = num_heads 168 | 169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 170 | self.t_embedder = TimestepEmbedder(hidden_size) 171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 172 | num_patches = self.x_embedder.num_patches 173 | # Will use fixed sin-cos embedding: 174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 175 | 176 | self.blocks = nn.ModuleList([ 177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 178 | ]) 179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 180 | self.initialize_weights() 181 | 182 | def reset(self): 183 | pass 184 | 185 | def initialize_weights(self): 186 | # Initialize transformer layers: 187 | def _basic_init(module): 188 | if isinstance(module, nn.Linear): 189 | torch.nn.init.xavier_uniform_(module.weight) 190 | if module.bias is not None: 191 | nn.init.constant_(module.bias, 0) 192 | self.apply(_basic_init) 193 | 194 | # Initialize (and freeze) pos_embed by sin-cos embedding: 195 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 196 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 197 | 198 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 199 | w = self.x_embedder.proj.weight.data 200 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 201 | nn.init.constant_(self.x_embedder.proj.bias, 0) 202 | 203 | # Initialize label embedding table: 204 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 205 | 206 | # Initialize timestep embedding MLP: 207 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 208 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 209 | 210 | # Zero-out adaLN modulation layers in DiT blocks: 211 | for block in self.blocks: 212 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 214 | 215 | # Zero-out output layers: 216 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 217 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 218 | nn.init.constant_(self.final_layer.linear.weight, 0) 219 | nn.init.constant_(self.final_layer.linear.bias, 0) 220 | 221 | def 
unpatchify(self, x): 222 | """ 223 | x: (N, T, patch_size**2 * C) 224 | imgs: (N, H, W, C) 225 | """ 226 | c = self.out_channels 227 | p = self.x_embedder.patch_size[0] 228 | h = w = int(x.shape[1] ** 0.5) 229 | assert h * w == x.shape[1] 230 | 231 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 232 | x = torch.einsum('nhwpqc->nchpwq', x) 233 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 234 | return imgs 235 | 236 | def forward(self, x, t, y): 237 | """ 238 | Forward pass of DiT. 239 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 240 | t: (N,) tensor of diffusion timesteps 241 | y: (N,) tensor of class labels 242 | """ 243 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 244 | t = self.t_embedder(t) # (N, D) 245 | y = self.y_embedder(y, self.training) # (N, D) 246 | c = t + y # (N, D) 247 | for block in self.blocks: 248 | x = block(x, c) # (N, T, D) 249 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 250 | x = self.unpatchify(x) # (N, out_channels, H, W) 251 | return x 252 | 253 | def forward_with_cfg(self, x, t, y, cfg_scale): 254 | """ 255 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 256 | """ 257 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 258 | half = x[: len(x) // 2] 259 | combined = torch.cat([half, half], dim=0) 260 | model_out = self.forward(combined, t, y) 261 | # For exact reproducibility reasons, we apply classifier-free guidance on only 262 | # three channels by default. The standard approach to cfg applies it to all channels. 263 | # This can be done by uncommenting the following line and commenting-out the line following that. 264 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 265 | eps, rest = model_out[:, :3], model_out[:, 3:] 266 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 267 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 268 | eps = torch.cat([half_eps, half_eps], dim=0) 269 | return torch.cat([eps, rest], dim=1) 270 | 271 | 272 | ################################################################################# 273 | # Sine/Cosine Positional Embedding Functions # 274 | ################################################################################# 275 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 276 | 277 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 278 | """ 279 | grid_size: int of the grid height and width 280 | return: 281 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 282 | """ 283 | grid_h = np.arange(grid_size, dtype=np.float32) 284 | grid_w = np.arange(grid_size, dtype=np.float32) 285 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 286 | grid = np.stack(grid, axis=0) 287 | 288 | grid = grid.reshape([2, 1, grid_size, grid_size]) 289 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 290 | if cls_token and extra_tokens > 0: 291 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 292 | return pos_embed 293 | 294 | 295 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 296 | assert embed_dim % 2 == 0 297 | 298 | # use half of dimensions to encode grid_h 299 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 300 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, 
grid[1]) # (H*W, D/2) 301 | 302 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 303 | return emb 304 | 305 | 306 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 307 | """ 308 | embed_dim: output dimension for each position 309 | pos: a list of positions to be encoded: size (M,) 310 | out: (M, D) 311 | """ 312 | assert embed_dim % 2 == 0 313 | omega = np.arange(embed_dim // 2, dtype=np.float64) 314 | omega /= embed_dim / 2. 315 | omega = 1. / 10000**omega # (D/2,) 316 | 317 | pos = pos.reshape(-1) # (M,) 318 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 319 | 320 | emb_sin = np.sin(out) # (M, D/2) 321 | emb_cos = np.cos(out) # (M, D/2) 322 | 323 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 324 | return emb 325 | 326 | 327 | ################################################################################# 328 | # DiT Configs # 329 | ################################################################################# 330 | 331 | def DiT_XL_2(**kwargs): 332 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 333 | 334 | def DiT_XL_4(**kwargs): 335 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 336 | 337 | def DiT_XL_8(**kwargs): 338 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 339 | 340 | def DiT_L_2(**kwargs): 341 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 342 | 343 | def DiT_L_4(**kwargs): 344 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 345 | 346 | def DiT_L_8(**kwargs): 347 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 348 | 349 | def DiT_B_2(**kwargs): 350 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 351 | 352 | def DiT_B_4(**kwargs): 353 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 354 | 355 | def DiT_B_8(**kwargs): 356 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 357 | 358 | def DiT_S_2(**kwargs): 359 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 360 | 361 | def DiT_S_4(**kwargs): 362 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 363 | 364 | def DiT_S_8(**kwargs): 365 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 366 | 367 | 368 | DiT_models = { 369 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 370 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 371 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 372 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 373 | } 374 | -------------------------------------------------------------------------------- /U-ViT/train_router_discrete.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | from torch import multiprocessing as mp 4 | from datasets import get_dataset 5 | from torchvision.utils import make_grid, save_image 6 | import utils 7 | import einops 8 | from torch.utils._pytree import tree_map 9 | import accelerate 10 | from accelerate import DistributedDataParallelKwargs 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 14 | import tempfile 15 | from tools.fid_score import calculate_fid_given_paths 16 | from absl import logging 17 | import builtins 18 | import os 19 | import wandb 20 | import 
libs.autoencoder 21 | import numpy as np 22 | 23 | 24 | def format_image_to_wandb(num_router, router_size, router_scores): 25 | image = np.zeros((num_router, router_size, 3), dtype=np.float32) 26 | ones = np.ones((3), dtype=np.float32) 27 | for idx, score in enumerate(router_scores): 28 | mask = score.cpu().detach() 29 | for pos in range(router_size): 30 | image[idx, pos] = ones * mask[pos].item() 31 | return image 32 | 33 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 34 | _betas = ( 35 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 36 | ) 37 | return _betas.numpy() 38 | 39 | 40 | def get_skip(alphas, betas): 41 | N = len(betas) - 1 42 | skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype) 43 | for s in range(N + 1): 44 | skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod() 45 | skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype) 46 | for t in range(N + 1): 47 | prod = betas[1: t + 1] * skip_alphas[1: t + 1, t] 48 | skip_betas[:t, t] = (prod[::-1].cumsum())[::-1] 49 | return skip_alphas, skip_betas 50 | 51 | 52 | def stp(s, ts: torch.Tensor): # scalar tensor product 53 | if isinstance(s, np.ndarray): 54 | s = torch.from_numpy(s).type_as(ts) 55 | extra_dims = (1,) * (ts.dim() - 1) 56 | return s.view(-1, *extra_dims) * ts 57 | 58 | 59 | def mos(a, start_dim=1): # mean of square 60 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1) 61 | 62 | def sos(a, start_dim=1): # sum of square 63 | e = a.pow(2).flatten(start_dim=start_dim) 64 | return e.sum(dim=-1) 65 | 66 | 67 | class Schedule(object): # discrete time 68 | def __init__(self, _betas): 69 | r""" _betas[0...999] = betas[1...1000] 70 | for n>=1, betas[n] is the variance of q(xn|xn-1) 71 | for n=0, betas[0]=0 72 | """ 73 | 74 | self._betas = _betas 75 | self.betas = np.append(0., _betas) 76 | self.alphas = 1. 
- self.betas 77 | self.N = len(_betas) 78 | 79 | assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0 80 | assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1 81 | assert len(self.betas) == len(self.alphas) 82 | 83 | # skip_alphas[s, t] = alphas[s + 1: t + 1].prod() 84 | self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas) 85 | self.cum_alphas = self.skip_alphas[0] # cum_alphas = alphas.cumprod() 86 | self.cum_betas = self.skip_betas[0] 87 | self.snr = self.cum_alphas / self.cum_betas 88 | 89 | def tilde_beta(self, s, t): 90 | return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t] 91 | 92 | def sample(self, x0): # sample from q(xn|x0), where n is uniform 93 | n = np.random.choice(list(range(1, self.N + 1)), (len(x0),)) 94 | eps = torch.randn_like(x0) 95 | xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps) 96 | return torch.tensor(n, device=x0.device), eps, xn 97 | 98 | def get_xn(self, x0, n): 99 | eps = torch.randn_like(x0) 100 | xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps) 101 | return torch.tensor(n, device=x0.device), eps, xn 102 | 103 | def __repr__(self): 104 | return f'Schedule({self.betas[:10]}..., {self.N})' 105 | 106 | 107 | def LSimple(x0, nnet, schedule, **kwargs): 108 | 109 | n, eps, xn = schedule.sample(x0) # n in {1, ..., 1000} 110 | eps_pred = nnet(xn, n, **kwargs) 111 | return mos(eps - eps_pred) 112 | 113 | 114 | def LRouter(x0, nnet, schedule, order=None, timesteps=None, dpm_solver=None, **kwargs): 115 | #print(x0.shape) 116 | #print(order, timesteps) 117 | 118 | def model_fn(x, t_continuous): 119 | t = t_continuous * 1000 120 | eps_pre = nnet(x, t, **kwargs) 121 | return eps_pre 122 | dpm_solver.model = model_fn 123 | nnet.module.reset_cache_features() 124 | random_step = np.random.randint(0, len(order)-1) 125 | random_t = np.round(timesteps[random_step] * 1000).astype(int).repeat(x0.shape[0]) 126 | 127 | #print(random_t) 128 | _, _, xn = schedule.get_xn(x0, random_t) 129 | vec_s = torch.ones((xn.shape[0],)).to(xn.device) * timesteps[random_step] 130 | vec_t = torch.ones((xn.shape[0],)).to(xn.device) * timesteps[random_step + 1] 131 | with torch.no_grad(): 132 | xn_minus_1 = dpm_solver.dpm_solver_second_update(xn, vec_s, vec_t, return_noise=False, solver_type='dpm_solver') 133 | 134 | random_t_minus_1 = np.round(timesteps[random_step + 1] * 1000).astype(int).repeat(x0.shape[0]) 135 | random_t_minus_1 = torch.tensor(random_t_minus_1).to(xn_minus_1.device) 136 | 137 | # Teacher 138 | nnet.module.set_activate_cache(False) 139 | nnet.module.set_record_cache(False) 140 | t_pred = nnet(xn_minus_1, random_t_minus_1, **kwargs) 141 | 142 | # Student 143 | nnet.module.set_activate_cache(True) 144 | 145 | s_pred, l1_loss = nnet(xn_minus_1, random_t_minus_1, **kwargs) 146 | 147 | nnet.module.set_activate_cache(False) 148 | nnet.module.set_record_cache(True) 149 | 150 | return sos(t_pred - s_pred), l1_loss 151 | 152 | 153 | def train(config): 154 | if config.get('benchmark', False): 155 | torch.backends.cudnn.benchmark = True 156 | torch.backends.cudnn.deterministic = False 157 | 158 | mp.set_start_method('spawn') 159 | 160 | 161 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 162 | accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs]) 163 | #accelerator = accelerate.Accelerator() 164 | device = accelerator.device 165 | accelerate.utils.set_seed(config.seed, device_specific=True) 166 | logging.info(f'Process 
{accelerator.process_index} using device: {device}') 167 | 168 | config.mixed_precision = accelerator.mixed_precision 169 | config = ml_collections.FrozenConfigDict(config) 170 | 171 | assert config.train.batch_size % accelerator.num_processes == 0 172 | mini_batch_size = config.train.batch_size // accelerator.num_processes 173 | 174 | if accelerator.is_main_process: 175 | os.makedirs(config.ckpt_root, exist_ok=True) 176 | os.makedirs(config.sample_dir, exist_ok=True) 177 | accelerator.wait_for_everyone() 178 | if accelerator.is_main_process: 179 | wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(), 180 | name=config.hparams, job_type='train')#, mode='offline') 181 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log')) 182 | logging.info(config) 183 | else: 184 | utils.set_logger(log_level='error') 185 | builtins.print = lambda *args: None 186 | logging.info(f'Run on {accelerator.num_processes} devices') 187 | 188 | # Load Dataset 189 | dataset = get_dataset(**config.dataset) 190 | assert os.path.exists(dataset.fid_stat) 191 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond') 192 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True, 193 | num_workers=8, pin_memory=True, persistent_workers=True) 194 | 195 | # Load Model and Optimizer 196 | train_state = utils.initialize_train_state(config, device) 197 | train_state.nnet.add_router(config.nfe) 198 | router_optim = torch.optim.AdamW( 199 | [param for name, param in train_state.nnet.named_parameters() if "routers" in name], 200 | lr=config.router_lr, weight_decay=0 201 | ) 202 | train_state.update_optimizer(router_optim) 203 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare( 204 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader) 205 | logging.info(f'load nnet from {config.nnet_path}') 206 | msg = accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'), strict=False) 207 | logging.info(f'load nnet message = {msg}') 208 | 209 | 210 | lr_scheduler = train_state.lr_scheduler 211 | train_state.resume(config.ckpt_root) 212 | 213 | # Load Autoencoder 214 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 215 | autoencoder.to(device) 216 | 217 | # Setup DPM Solver 218 | _betas = stable_diffusion_beta_schedule() 219 | _schedule = Schedule(_betas) 220 | logging.info(f'use {_schedule}') 221 | 222 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 223 | dpm_solver = DPM_Solver(None, noise_schedule, predict_x0=True, thresholding=False) 224 | t_0 = 1.
/ _schedule.N 225 | t_T = 1.0 226 | order_value = 2 227 | N_steps = config.nfe // order_value 228 | order = [order_value,] * N_steps 229 | timesteps = dpm_solver.get_time_steps( 230 | skip_type='time_uniform', t_T=t_T, t_0=t_0, N=N_steps, device=device 231 | ) 232 | timesteps = timesteps.cpu().numpy() 233 | timestep_mapping = np.round(timesteps * 1000) 234 | accelerator.unwrap_model(nnet).set_timestep_map(timestep_mapping) 235 | 236 | @ torch.cuda.amp.autocast() 237 | def encode(_batch): 238 | return autoencoder.encode(_batch) 239 | 240 | @ torch.cuda.amp.autocast() 241 | def decode(_batch): 242 | return autoencoder.decode(_batch) 243 | 244 | def get_data_generator(): 245 | while True: 246 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'): 247 | yield data 248 | 249 | data_generator = get_data_generator() 250 | 251 | 252 | def train_step(_batch): 253 | _metrics = dict() 254 | optimizer.zero_grad() 255 | if config.train.mode == 'uncond': 256 | _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch) 257 | data_loss, l1_loss = LRouter(_z, nnet, _schedule, order=order, timesteps=timesteps, dpm_solver=dpm_solver, l1_weight=config.l1_weight) 258 | elif config.train.mode == 'cond': 259 | #print("Label = ", _batch[1]) 260 | _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0]) 261 | data_loss, l1_loss = LRouter(_z, nnet, _schedule, y=_batch[1], order=order, timesteps=timesteps, dpm_solver=dpm_solver) 262 | loss = data_loss + config.l1_weight * l1_loss 263 | else: 264 | raise NotImplementedError(config.train.mode) 265 | _metrics['loss'] = accelerator.gather(loss.detach()).mean() 266 | _metrics['data_loss'] = accelerator.gather(data_loss.detach()).mean() 267 | _metrics['l1_loss'] = accelerator.gather(l1_loss.detach()).mean() 268 | 269 | 270 | accelerator.backward(loss.mean()) 271 | optimizer.step() 272 | lr_scheduler.step() 273 | train_state.step += 1 274 | 275 | #print("Router 0:", nnet.module.routers[0].prob.data) 276 | #print("Router 1:", nnet.module.routers[1].prob.data) 277 | #print() 278 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics) 279 | 280 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}') 281 | 282 | loss_metrics = 0 283 | data_loss_metrics = 0 284 | l1_loss_metrics = 0 285 | while train_state.step < config.train.n_steps: 286 | nnet.train() 287 | batch = tree_map(lambda x: x.to(device), next(data_generator)) 288 | metrics = train_step(batch) 289 | 290 | if accelerator.is_main_process: 291 | loss_metrics += metrics['loss'] 292 | data_loss_metrics += metrics['data_loss'] 293 | l1_loss_metrics += metrics['l1_loss'] 294 | 295 | nnet.eval() 296 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0: 297 | scores = [nnet.module.routers[idx]() for idx in range(1, config.nfe//2)] 298 | mask = format_image_to_wandb(config.nfe//2-1, nnet.module.depth*2, scores) 299 | mask = wandb.Image( 300 | mask, 301 | ) 302 | metrics['loss'] = loss_metrics / config.train.log_interval 303 | metrics['data_loss'] = data_loss_metrics / config.train.log_interval 304 | metrics['l1_loss'] = l1_loss_metrics / config.train.log_interval 305 | final_score = [sum(score) for score in scores] 306 | metrics['non_zero'] = sum(final_score) / (len(final_score) * len(scores[0])) 307 | 308 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics))) 309 | metrics['router'] = mask 310 | 
#logging.info(config.workdir) 311 | wandb.log(metrics, step=train_state.step) 312 | loss_metrics, data_loss_metrics, l1_loss_metrics = 0, 0, 0 313 | 314 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps: 315 | torch.cuda.empty_cache() 316 | logging.info(f'Save and eval checkpoint {train_state.step}...') 317 | if accelerator.local_process_index == 0: 318 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt')) 319 | accelerator.wait_for_everyone() 320 | #fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint 321 | #step_fid.append((train_state.step, fid)) 322 | torch.cuda.empty_cache() 323 | accelerator.wait_for_everyone() 324 | 325 | logging.info(f'Finish fitting, step={train_state.step}') 326 | 327 | 328 | 329 | from absl import flags 330 | from absl import app 331 | from ml_collections import config_flags 332 | import sys 333 | from pathlib import Path 334 | 335 | 336 | FLAGS = flags.FLAGS 337 | config_flags.DEFINE_config_file( 338 | "config", None, "Training configuration.", lock_config=False) 339 | flags.mark_flags_as_required(["config"]) 340 | flags.DEFINE_string("workdir", None, "Work unit directory.") 341 | flags.DEFINE_string("nfe", None, "number of function evaluations (NFE)") 342 | flags.DEFINE_string("router_lr", None, "learning rate for router") 343 | flags.DEFINE_string("l1_weight", None, "l1 weight for router loss") 344 | flags.DEFINE_string("nnet_path", None, "path to the pretrained nnet checkpoint") 345 | 346 | 347 | 348 | def get_config_name(): 349 | argv = sys.argv 350 | for i in range(1, len(argv)): 351 | if argv[i].startswith('--config='): 352 | return Path(argv[i].split('=')[-1]).stem 353 | 354 | 355 | def get_hparams(): 356 | argv = sys.argv 357 | lst = [] 358 | for i in range(1, len(argv)): 359 | assert '=' in argv[i] 360 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'): 361 | hparam, val = argv[i].split('=') 362 | hparam = hparam.split('.')[-1] 363 | if hparam.endswith('path'): 364 | val = Path(val).stem 365 | lst.append(f'{hparam}={val}') 366 | hparams = '-'.join(lst) 367 | if hparams == '': 368 | hparams = 'default' 369 | return hparams 370 | 371 | 372 | def main(argv): 373 | config = FLAGS.config 374 | config.nfe = int(FLAGS.nfe) 375 | config.router_lr = float(FLAGS.router_lr) 376 | config.l1_weight = float(FLAGS.l1_weight) 377 | config.nnet_path = FLAGS.nnet_path 378 | config.config_name = get_config_name() 379 | config.hparams = get_hparams() 380 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams) 381 | config.ckpt_root = os.path.join(config.workdir, 'ckpts') 382 | config.sample_dir = os.path.join(config.workdir, 'samples') 383 | train(config) 384 | 385 | 386 | if __name__ == "__main__": 387 | app.run(main) 388 | -------------------------------------------------------------------------------- /DiT/train_router.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | A minimal training script for DiT using PyTorch DDP.
9 | """ 10 | import torch 11 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 12 | torch.backends.cuda.matmul.allow_tf32 = True 13 | torch.backends.cudnn.allow_tf32 = True 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data.distributed import DistributedSampler 18 | from torchvision.datasets import ImageFolder 19 | from torchvision import transforms 20 | import numpy as np 21 | from collections import OrderedDict 22 | from PIL import Image 23 | from copy import deepcopy 24 | from glob import glob 25 | from time import time 26 | import argparse 27 | import logging 28 | import os 29 | 30 | from models.router_models import DiT_models, STE 31 | from diffusion import create_diffusion 32 | from diffusers.models import AutoencoderKL 33 | from download import find_model 34 | 35 | 36 | ################################################################################# 37 | # Training Helper Functions # 38 | ################################################################################# 39 | 40 | @torch.no_grad() 41 | def update_ema(ema_model, model, decay=0.9999): 42 | """ 43 | Step the EMA model towards the current model. 44 | """ 45 | ema_params = OrderedDict(ema_model.named_parameters()) 46 | model_params = OrderedDict(model.named_parameters()) 47 | 48 | for name, param in model_params.items(): 49 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 50 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 51 | 52 | 53 | def requires_grad(model, flag=True): 54 | """ 55 | Set requires_grad flag for all parameters in a model. 56 | """ 57 | for p in model.parameters(): 58 | p.requires_grad = flag 59 | 60 | 61 | def cleanup(): 62 | """ 63 | End DDP training. 64 | """ 65 | dist.destroy_process_group() 66 | 67 | 68 | def create_logger(logging_dir): 69 | """ 70 | Create a logger that writes to a log file and stdout. 71 | """ 72 | if dist.get_rank() == 0: # real logger 73 | logging.basicConfig( 74 | level=logging.INFO, 75 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 76 | datefmt='%Y-%m-%d %H:%M:%S', 77 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 78 | ) 79 | logger = logging.getLogger(__name__) 80 | else: # dummy logger (does nothing) 81 | logger = logging.getLogger(__name__) 82 | logger.addHandler(logging.NullHandler()) 83 | return logger 84 | 85 | def format_image_to_wandb(num_router, router_size, router_scores): 86 | image = np.zeros((num_router, router_size, 3), dtype=np.float32) 87 | ones = np.ones((3), dtype=np.float32) 88 | for idx, score in enumerate(router_scores): 89 | mask = score.cpu().detach() 90 | for pos in range(router_size): 91 | image[idx, pos] = ones * mask[pos].item() 92 | return image 93 | 94 | 95 | def center_crop_arr(pil_image, image_size): 96 | """ 97 | Center cropping implementation from ADM. 
98 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 99 | """ 100 | while min(*pil_image.size) >= 2 * image_size: 101 | pil_image = pil_image.resize( 102 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 103 | ) 104 | 105 | scale = image_size / min(*pil_image.size) 106 | pil_image = pil_image.resize( 107 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 108 | ) 109 | 110 | arr = np.array(pil_image) 111 | crop_y = (arr.shape[0] - image_size) // 2 112 | crop_x = (arr.shape[1] - image_size) // 2 113 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 114 | 115 | 116 | ################################################################################# 117 | # Training Loop # 118 | ################################################################################# 119 | 120 | def main(args): 121 | """ 122 | Trains a new DiT model. 123 | """ 124 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 125 | 126 | # Setup DDP: 127 | dist.init_process_group("nccl") 128 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 129 | rank = dist.get_rank() 130 | device = rank % torch.cuda.device_count() 131 | seed = args.global_seed * dist.get_world_size() + rank 132 | torch.manual_seed(seed) 133 | torch.cuda.set_device(device) 134 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 135 | 136 | # Setup an experiment folder: 137 | if rank == 0: 138 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 139 | experiment_index = len(glob(f"{args.results_dir}/*")) 140 | model_string_name = args.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders) 141 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder 142 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints 143 | os.makedirs(checkpoint_dir, exist_ok=True) 144 | logger = create_logger(experiment_dir) 145 | logger.info(f"Experiment directory created at {experiment_dir}") 146 | else: 147 | logger = create_logger(None) 148 | 149 | # Create model: 150 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 151 | latent_size = args.image_size // 8 152 | 153 | 154 | model = DiT_models[args.model]( 155 | input_size=latent_size, 156 | num_classes=args.num_classes 157 | ).to(device) 158 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: 159 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" 160 | state_dict = find_model(ckpt_path) 161 | msg = model.load_state_dict(state_dict, strict=False) 162 | if rank == 0: 163 | logger.info(f"Loaded model from {ckpt_path} with msg: {msg}") 164 | model.eval() # important! 
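    # The pretrained DiT weights loaded above stay fixed in this script: the AdamW optimizer
    # built below only collects parameters whose names contain "routers", so only the cache
    # routers introduced by add_router() receive gradient updates.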
165 | 166 | diffusion = create_diffusion(str(args.num_sampling_steps)) 167 | model.add_router(args.num_sampling_steps, diffusion.timestep_map) 168 | model = DDP(model.to(device), device_ids=[rank], find_unused_parameters=True) 169 | 170 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 171 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 172 | 173 | #routers = [Router(len(model.module.blocks)*2) for _ in range(args.num_sampling_steps//2)] 174 | #routers = [DDP(r.to(device), device_ids=[rank]) for r in routers] 175 | opts = torch.optim.AdamW( 176 | [param for name, param in model.named_parameters() if "routers" in name], 177 | lr=args.lr, weight_decay=0 178 | ) 179 | 180 | # Setup data: 181 | transform = transforms.Compose([ 182 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 183 | transforms.RandomHorizontalFlip(), 184 | transforms.ToTensor(), 185 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 186 | ]) 187 | dataset = ImageFolder(args.data_path, transform=transform) 188 | sampler = DistributedSampler( 189 | dataset, 190 | num_replicas=dist.get_world_size(), 191 | rank=rank, 192 | shuffle=True, 193 | seed=args.global_seed 194 | ) 195 | loader = DataLoader( 196 | dataset, 197 | batch_size=int(args.global_batch_size // dist.get_world_size()), 198 | shuffle=False, 199 | sampler=sampler, 200 | num_workers=args.num_workers, 201 | pin_memory=True, 202 | drop_last=True 203 | ) 204 | logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") 205 | 206 | if args.wandb and rank == 0: 207 | import wandb 208 | wandb.init( 209 | # Set the project where this run will be logged 210 | project="DiT-Router", 211 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10) 212 | name=f"{experiment_index:03d}-{model_string_name}", 213 | # Track hyperparameters and run metadata 214 | config=args.__dict__ 215 | ) 216 | wandb.define_metric("step") 217 | wandb.define_metric("loss", step_metric="step") 218 | 219 | # Prepare models for training: 220 | #update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights 221 | model.train() # important! We need to use embedding dropout for classifier-free guidance here. 222 | #ema.eval() # EMA model should always be in eval mode 223 | 224 | # Variables for monitoring/logging purposes: 225 | train_steps = 0 226 | log_steps = 0 227 | running_loss = 0 228 | running_data_loss, running_l1_loss = 0, 0 229 | start_time = time() 230 | 231 | logger.info(f"Training for {args.epochs} epochs...") 232 | for epoch in range(args.epochs): 233 | sampler.set_epoch(epoch) 234 | logger.info(f"Beginning epoch {epoch}...") 235 | for x, y in loader: 236 | x = x.to(device) 237 | y = y.to(device) 238 | 239 | with torch.no_grad(): 240 | # Map input images to latent space + normalize latents: 241 | x = vae.encode(x).latent_dist.sample().mul_(0.18215) 242 | model_kwargs = dict(y=y, thres=args.ste_threshold) 243 | 244 | #t = 1+2*torch.randint(0, diffusion.num_timesteps//2, (x.shape[0],), device=device) 245 | t = torch.randint(0, diffusion.num_timesteps//2, (1,), device=device) 246 | #t = torch.tensor(2, device=device) 247 | ts = t.repeat(x.shape[0])*2 + 1 248 | 249 | loss_dict = diffusion.router_training_losses(model, x, ts, model_kwargs) 250 | data_loss = loss_dict["mse"].mean() 251 | l1_loss = loss_dict["l1_loss"].mean() 252 | 253 | #print(f"Rank: {rank}, t: {t}, data loss: {data_loss}. 
L1 loss: {l1_loss}") 254 | loss = data_loss + args.l1 * l1_loss 255 | opts.zero_grad() 256 | model.zero_grad() 257 | 258 | loss.backward() 259 | #for idx, router in enumerate(model.module.routers): 260 | # print(f"Rank: {rank}, idx: {idx}, ", router.prob.grad) 261 | opts.step() 262 | 263 | with torch.no_grad(): 264 | for name, param in model.named_parameters(): 265 | if "routers" in name: 266 | param.clamp_(-5, 5) 267 | 268 | # Log loss values: 269 | running_loss += loss.item() 270 | running_data_loss += data_loss.item() 271 | running_l1_loss += args.l1 * l1_loss.item() 272 | 273 | log_steps += 1 274 | train_steps += 1 275 | 276 | model.module.reset() 277 | 278 | if train_steps % args.log_every == 0: 279 | # Measure training speed: 280 | torch.cuda.synchronize() 281 | end_time = time() 282 | steps_per_sec = log_steps / (end_time - start_time) 283 | 284 | # Reduce loss history over all processes: 285 | for name, loss in [("loss", running_loss), ("data_loss", running_data_loss), ("l1_loss", running_l1_loss)]: 286 | loss = torch.tensor(loss / log_steps, device=device) 287 | dist.all_reduce(loss, op=dist.ReduceOp.SUM) 288 | loss = loss.item() / dist.get_world_size() 289 | logger.info(f"(step={train_steps:07d}) Train {name} Loss: {loss:.7f}, Train Steps/Sec: {steps_per_sec:.2f}") 290 | 291 | scores = [model.module.routers[idx]() for idx in range(0, args.num_sampling_steps, 2)] 292 | 293 | if args.wandb and rank == 0: 294 | #print(scores) 295 | mask = format_image_to_wandb(args.num_sampling_steps//2 , model.module.depth*2, scores) 296 | mask = wandb.Image( 297 | mask, 298 | ) 299 | if args.ste_threshold is not None: 300 | final_score = [sum(STE.apply(score, args.ste_threshold)) for score in scores] 301 | else: 302 | final_score = [sum(score) for score in scores] 303 | wandb.log({ 304 | "step": train_steps, 305 | "loss": loss, 306 | "data_loss": running_data_loss / log_steps, 307 | "l1_loss": running_l1_loss / log_steps, 308 | "non_zero": sum(final_score), 309 | "router": mask 310 | }) 311 | 312 | # Reset monitoring variables: 313 | running_loss = 0 314 | running_data_loss, running_l1_loss = 0, 0 315 | log_steps = 0 316 | start_time = time() 317 | 318 | # Save DiT checkpoint: 319 | if train_steps % args.ckpt_every == 0 and train_steps > 0: 320 | if rank == 0: 321 | checkpoint = { 322 | #"model": model.module.state_dict(), 323 | #"ema": ema.state_dict(), 324 | "routers": model.module.routers.state_dict(), 325 | "opt": opts.state_dict(), 326 | "args": args 327 | } 328 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" 329 | torch.save(checkpoint, checkpoint_path) 330 | logger.info(f"Saved checkpoint to {checkpoint_path}") 331 | dist.barrier() 332 | 333 | if train_steps > args.max_steps: 334 | print("Reach Maximum Step") 335 | break 336 | 337 | model.eval() # important! This disables randomized embedding dropout 338 | # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... 339 | 340 | logger.info("Done!") 341 | cleanup() 342 | 343 | 344 | if __name__ == "__main__": 345 | # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters). 
346 | parser = argparse.ArgumentParser() 347 | parser.add_argument("--data-path", type=str, required=True) 348 | parser.add_argument("--results-dir", type=str, default="results") 349 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") 350 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 351 | parser.add_argument("--num-classes", type=int, default=1000) 352 | parser.add_argument("--epochs", type=int, default=1) 353 | parser.add_argument("--global-batch-size", type=int, default=256) 354 | parser.add_argument("--global-seed", type=int, default=0) 355 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training 356 | parser.add_argument("--num-workers", type=int, default=4) 357 | parser.add_argument("--log-every", type=int, default=100) 358 | parser.add_argument("--ckpt-every", type=int, default=50_000) 359 | parser.add_argument("--wandb", action="store_true") 360 | 361 | parser.add_argument("--ckpt", type=str, default=None) 362 | #parser.add_argument("--cfg-scale", type=float, required=True) 363 | parser.add_argument("--num-sampling-steps", type=int, default=20) 364 | parser.add_argument("--l1", type=float, default=1.0) 365 | 366 | parser.add_argument("--lr", type=float, default=1.0) 367 | parser.add_argument("--max-steps", type=int, default=50000) 368 | 369 | parser.add_argument("--ste-threshold", type=float, default=None) 370 | 371 | args = parser.parse_args() 372 | main(args) 373 | --------------------------------------------------------------------------------
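A minimal shape check for the DiT backbone shown earlier in this dump; this is a sketch that assumes those definitions live in DiT/models/models.py, are importable as models.models (i.e., run from the DiT/ directory), and that torch and timm are installed.

import torch
from models.models import DiT_models  # assumed import path (DiT/models/models.py)

model = DiT_models['DiT-S/2']()        # depth=12, hidden_size=384, patch_size=2, num_heads=6
x = torch.randn(2, 4, 32, 32)          # (N, in_channels, H, W) latents, matching the defaults
t = torch.randint(0, 1000, (2,))       # (N,) diffusion timesteps
y = torch.randint(0, 1000, (2,))       # (N,) class labels
with torch.no_grad():
    out = model(x, t, y)               # (N, out_channels, H, W)
print(out.shape)                        # torch.Size([2, 8, 32, 32]); learn_sigma=True doubles the channels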