├── LICENSE
├── README.md
├── UViT_ImageNet_demo.ipynb
├── configs
│   ├── celeba64_uvit_small.py
│   ├── cifar10_uvit_small.py
│   ├── imagenet256_uvit_huge.py
│   ├── imagenet256_uvit_large.py
│   ├── imagenet512_uvit_huge.py
│   ├── imagenet512_uvit_large.py
│   ├── imagenet64_uvit_large.py
│   ├── imagenet64_uvit_mid.py
│   └── mscoco_uvit_small.py
├── datasets.py
├── dpm_solver_pp.py
├── dpm_solver_pytorch.py
├── eval.py
├── eval_ldm.py
├── eval_ldm_discrete.py
├── eval_t2i_discrete.py
├── libs
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── clip.py
│   ├── timm.py
│   ├── uvit.py
│   └── uvit_t2i.py
├── sample.png
├── sample_t2i_discrete.py
├── scripts
│   ├── extract_empty_feature.py
│   ├── extract_imagenet_feature.py
│   ├── extract_mscoco_feature.py
│   └── extract_test_prompt_feature.py
├── sde.py
├── skip_im.png
├── tools
│   ├── fid_score.py
│   └── inception.py
├── train.py
├── train_ldm.py
├── train_ldm_discrete.py
├── train_t2i_discrete.py
├── utils.py
└── uvit.png


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Fan Bao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## U-ViT
Official PyTorch implementation of [All are Worth Words: A ViT Backbone for Diffusion Models](https://arxiv.org/abs/2209.12152) (CVPR 2023)


💡Projects with U-ViT:
* [UniDiffuser](https://github.com/thu-ml/unidiffuser), a multi-modal large-scale diffusion model based on a 1B U-ViT, is open-sourced
* [DPT](https://arxiv.org/abs/2302.10586) ([code](https://github.com/ML-GSAI/DPT), [demo](https://ml-gsai.github.io/DPT-demo)), a conditional diffusion model trained with 1 label/class, achieves SOTA SSL generation and classification results on ImageNet

Vision transformers (ViT) have shown promise in various vision tasks, while the U-Net based on a convolutional neural network (CNN) remains dominant in diffusion models.
We design a simple and general ViT-based architecture (named U-ViT) for image generation with diffusion models.
U-ViT is characterized by treating all inputs, including the time, the condition, and the noisy image patches, as tokens,
and by employing long skip connections between shallow and deep layers.
We evaluate U-ViT on unconditional and class-conditional image generation,
as well as text-to-image generation tasks, where U-ViT is comparable if not superior to a CNN-based U-Net of a similar size.
In particular, latent diffusion models with U-ViT achieve record-breaking FID scores of 2.29 in class-conditional image generation
on ImageNet 256x256 and 5.48 in text-to-image generation on MS-COCO, among methods that do not access
large external datasets during the training of generative models.

Our results suggest that, for diffusion-based image modeling, the long skip connection is crucial, while the down-sampling and up-sampling operators in CNN-based U-Net are not always necessary. We believe that U-ViT can provide insights for future research on backbones in diffusion models and benefit generative modeling on large-scale cross-modality datasets.
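
The intro above says U-ViT treats the time, the condition, and the noisy image patches as tokens and links shallow and deep layers with long skip connections. The dependency-free sketch below only does the bookkeeping for those two ideas; it is not code from `libs/uvit.py`, and both the helper names and the block layout it assumes (`depth // 2` shallow blocks, one middle block, `depth // 2` deep blocks) are illustrative assumptions.

```python
from typing import List, Tuple


def uvit_token_count(img_size: int, patch_size: int, num_condition_tokens: int = 0) -> int:
    """Sequence length for a U-ViT-style backbone (illustrative assumption).

    One token for the diffusion time, `num_condition_tokens` tokens for the
    condition (0 unconditional, 1 for a class label, 77 for CLIP text tokens
    as in `num_clip_token=77`), and one token per non-overlapping patch.
    """
    if img_size % patch_size != 0:
        raise ValueError("patch_size must divide img_size")
    num_patches = (img_size // patch_size) ** 2
    return 1 + num_condition_tokens + num_patches


def long_skip_pairs(depth: int) -> List[Tuple[int, int]]:
    """Which shallow block feeds each deep block through a long skip.

    Assumes depth // 2 "in" blocks, a middle block, and depth // 2 "out"
    blocks. In-block outputs are pushed on a stack and the out-blocks pop
    them, so the first out-block pairs with the last in-block (U-Net style).
    """
    n_half = depth // 2
    stack: List[int] = []
    for i in range(n_half):
        stack.append(i)                 # in-block i saves its activations
    pairs = []
    for j in range(n_half):
        pairs.append((stack.pop(), j))  # out-block j consumes the most recent one
    return pairs


print(uvit_token_count(32, 2, num_condition_tokens=1))  # 258
print(long_skip_pairs(12)[:2])                          # [(5, 0), (4, 1)]
```

For example, with the `img_size=32, patch_size=2` latent setting used by the ImageNet 256x256 configs and a single class-label token, the sketch counts 258 tokens; the stack makes the "shallow-to-deep" pairing explicit without any down- or up-sampling.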

--------------------

This codebase implements the transformer-based backbone 📌*U-ViT*📌 for diffusion models, as introduced in the [paper](https://arxiv.org/abs/2209.12152).
U-ViT treats all inputs as tokens and employs long skip connections. *The long skip connections greatly improve both performance and convergence speed*.


💡This codebase contains:
* An implementation of [U-ViT](libs/uvit.py) with optimized attention computation
* Pretrained U-ViT models on common image generation benchmarks (CIFAR10, CelebA 64x64, ImageNet 64x64, ImageNet 256x256, ImageNet 512x512)
* Efficient training scripts for [pixel-space diffusion models](train.py), [latent space diffusion models](train_ldm_discrete.py) and [text-to-image diffusion models](train_t2i_discrete.py)
* Efficient evaluation scripts for [pixel-space diffusion models](eval.py), [latent space diffusion models](eval_ldm_discrete.py) and [text-to-image diffusion models](eval_t2i_discrete.py)
* A Colab notebook demo for sampling from U-ViT on ImageNet (FID=2.29) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/baofff/U-ViT/blob/main/UViT_ImageNet_demo.ipynb)


💡This codebase supports useful techniques for efficient training and sampling of diffusion models:
* Mixed precision training with the [huggingface accelerate](https://github.com/huggingface/accelerate) library (🥰automatically turned on)
* Efficient attention computation with the [facebook xformers](https://github.com/facebookresearch/xformers) library (needs additional installation)
* Gradient checkpointing, which reduces memory usage by ~65% (🥰automatically turned on)
* With these techniques, we are able to train our largest U-ViT-H on ImageNet at high resolutions such as 256x256 and 512x512 using a large batch size of 1024 with *only 2 A100s*❗


Training speed and memory of U-ViT-H/2 on ImageNet 256x256 using a batch size of 128 with an A100:

| mixed precision training | xformers | gradient checkpointing | training speed | memory |
|:------------------------:|:--------:|:----------------------:|:-----------------:|:-------------:|
| ❌ | ❌ | ❌ | - | out of memory |
| ✔ | ❌ | ❌ | 0.97 steps/second | 78852 MB |
| ✔ | ✔ | ❌ | 1.14 steps/second | 54324 MB |
| ✔ | ✔ | ✔ | 0.87 steps/second | 18858 MB |



## Dependency

```sh
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116  # install torch-1.13.1
pip install accelerate==0.12.0 absl-py ml_collections einops wandb ftfy==6.1.1 transformers==4.23.1

# xformers is optional, but it would greatly speed up the attention computation.
pip install -U xformers
pip install -U --pre triton
```

* This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. (Perhaps other versions also work, but I haven't tested it.)
* We highly recommend installing [xformers](https://github.com/facebookresearch/xformers), which greatly speeds up the attention computation for *both training and inference*.
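
The pip commands above pin a few package versions. As a hedged convenience, the sketch below uses only the standard library's `importlib.metadata` to report whether the current environment matches those pins. `PINS` and `check_pins` are hypothetical names, not part of this repo, and `torch` is left out because CUDA builds usually carry a local version suffix such as `+cu116`.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

# Pins copied from the pip commands above (hypothetical helper table).
PINS = {
    "accelerate": "0.12.0",
    "ftfy": "6.1.1",
    "transformers": "4.23.1",
}


def check_pins(pins: Dict[str, str],
               get_version: Callable[[str], str] = metadata.version,
               ) -> Dict[str, Tuple[str, Optional[str], bool]]:
    """Return {package: (wanted, installed_or_None, matches)}."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        report[name] = (wanted, installed, installed == wanted)
    return report


if __name__ == "__main__":
    for name, (wanted, installed, ok) in check_pins(PINS).items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name:<14} wanted {wanted:<8} found {installed}  [{status}]")
```

The `get_version` parameter exists only so the check can be exercised without touching the real environment.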


## Pretrained Models


| Model | FID | training iterations | batch size |
|:----------------------------------------------------------------------------------------------------------------------:|:-----:|:-------------------:|:----------:|
| [CIFAR10 (U-ViT-S/2)](https://drive.google.com/file/d/1yoYyuzR_hQYWU0mkTj659tMTnoCWCMv-/view?usp=share_link) | 3.11 | 500K | 128 |
| [CelebA 64x64 (U-ViT-S/4)](https://drive.google.com/file/d/13YpbRtlqF1HDBNLNRlKxLTbKbKeLE06C/view?usp=share_link) | 2.87 | 500K | 128 |
| [ImageNet 64x64 (U-ViT-M/4)](https://drive.google.com/file/d/1igVgRY7-A0ZV3XqdNcMGOnIGOxKr9azv/view?usp=share_link) | 5.85 | 300K | 1024 |
| [ImageNet 64x64 (U-ViT-L/4)](https://drive.google.com/file/d/19rmun-T7RwkNC1feEPWinIo-1JynpW7J/view?usp=share_link) | 4.26 | 300K | 1024 |
| [ImageNet 256x256 (U-ViT-L/2)](https://drive.google.com/file/d/1w7T1hiwKODgkYyMH9Nc9JNUThbxFZgs3/view?usp=share_link) | 3.40 | 300K | 1024 |
| [ImageNet 256x256 (U-ViT-H/2)](https://drive.google.com/file/d/13StUdrjaaSXjfqqF7M47BzPyhMAArQ4u/view?usp=share_link) | 2.29 | 500K | 1024 |
| [ImageNet 512x512 (U-ViT-L/4)](https://drive.google.com/file/d/1mkj4aN2utHMBTWQX9l1nYue9vleL7ZSB/view?usp=share_link) | 4.67 | 500K | 1024 |
| [ImageNet 512x512 (U-ViT-H/4)](https://drive.google.com/file/d/1uegr2o7cuKXtf2akWGAN2Vnlrtw5YKQq/view?usp=share_link) | 4.05 | 500K | 1024 |
| [MS-COCO (U-ViT-S/2)](https://drive.google.com/file/d/15JsZWRz2byYNU6K093et5e5Xqd4uwA8S/view?usp=share_link) | 5.95 | 1M | 256 |
| [MS-COCO (U-ViT-S/2, Deep)](https://drive.google.com/file/d/1gHRy8sn039Wy-iFL21wH8TiheHK8Ky71/view?usp=share_link) | 5.48 | 1M | 256 |



## Preparation Before Training and Evaluation

#### Autoencoder
Download the `stable-diffusion` directory from this [link](https://drive.google.com/drive/folders/1yo-XhqbPue3rp5P57j6QbA5QZx6KybvP?usp=sharing) (it contains image autoencoders converted from [Stable Diffusion](https://github.com/CompVis/stable-diffusion)).
Put the downloaded directory at `assets/stable-diffusion` in this codebase.
The autoencoders are used in latent diffusion models.

#### Data
* ImageNet 64x64: Put the standard ImageNet dataset (which contains the `train` and `val` directories) in `assets/datasets/ImageNet`.
* ImageNet 256x256 and ImageNet 512x512: Extract ImageNet features using `scripts/extract_imagenet_feature.py`.
* MS-COCO: Download the COCO 2014 [training](http://images.cocodataset.org/zips/train2014.zip) and [validation](http://images.cocodataset.org/zips/val2014.zip) data and [annotations](http://images.cocodataset.org/annotations/annotations_trainval2014.zip). Then extract their features using `scripts/extract_mscoco_feature.py`, `scripts/extract_test_prompt_feature.py`, and `scripts/extract_empty_feature.py`.

#### Reference statistics for FID
Download the `fid_stats` directory from this [link](https://drive.google.com/drive/folders/1yo-XhqbPue3rp5P57j6QbA5QZx6KybvP?usp=sharing) (it contains reference statistics for FID).
Put the downloaded directory at `assets/fid_stats` in this codebase.
Besides evaluation, these reference statistics are used to monitor FID during training.

## Training


We use the [huggingface accelerate](https://github.com/huggingface/accelerate) library to help train with distributed data parallel and mixed precision.
The following is the training command:
```sh
# the training setting
num_processes=2  # the number of gpus you have, e.g., 2
train_script=train.py  # the train script, one of
                       # train.py: training on pixel space
                       # train_ldm.py: training on latent space with continuous timesteps
                       # train_ldm_discrete.py: training on latent space with discrete timesteps
                       # train_t2i_discrete.py: text-to-image training on latent space
config=configs/cifar10_uvit_small.py  # the training configuration
                                      # you can change other hyperparameters by modifying the configuration file

# launch training
accelerate launch --multi_gpu --num_processes $num_processes --mixed_precision fp16 $train_script --config=$config
```


We provide all commands to reproduce U-ViT training in the paper:
```sh
# CIFAR10 (U-ViT-S/2)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/cifar10_uvit_small.py

# CelebA 64x64 (U-ViT-S/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/celeba64_uvit_small.py

# ImageNet 64x64 (U-ViT-M/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train.py --config=configs/imagenet64_uvit_mid.py

# ImageNet 64x64 (U-ViT-L/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train.py --config=configs/imagenet64_uvit_large.py

# ImageNet 256x256 (U-ViT-L/2)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm.py --config=configs/imagenet256_uvit_large.py

# ImageNet 256x256 (U-ViT-H/2)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm_discrete.py --config=configs/imagenet256_uvit_huge.py

# ImageNet 512x512 (U-ViT-L/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm.py --config=configs/imagenet512_uvit_large.py

# ImageNet 512x512 (U-ViT-H/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm_discrete.py --config=configs/imagenet512_uvit_huge.py

# MS-COCO (U-ViT-S/2)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train_t2i_discrete.py --config=configs/mscoco_uvit_small.py

# MS-COCO (U-ViT-S/2, Deep)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train_t2i_discrete.py --config=configs/mscoco_uvit_small.py --config.nnet.depth=16
```



## Evaluation (Compute FID)

We use the [huggingface accelerate](https://github.com/huggingface/accelerate) library for efficient inference with mixed precision and multiple GPUs. The following is the evaluation command:
```sh
# the evaluation setting
num_processes=2  # the number of gpus you have, e.g., 2
eval_script=eval.py  # the evaluation script, one of
                     # eval.py: for models trained with train.py (i.e., pixel space models)
                     # eval_ldm.py: for models trained with train_ldm.py (i.e., latent space models with continuous timesteps)
                     # eval_ldm_discrete.py: for models trained with train_ldm_discrete.py (i.e., latent space models with discrete timesteps)
                     # eval_t2i_discrete.py: for models trained with train_t2i_discrete.py (i.e., text-to-image models on latent space)
config=configs/cifar10_uvit_small.py  # the training configuration

# launch evaluation
accelerate launch --multi_gpu --num_processes $num_processes --mixed_precision fp16 $eval_script --config=$config
```
The generated images are stored in a temporary directory and are deleted after evaluation. To keep them, set `--config.sample.path=/save/dir`.
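
FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images, $\lVert\mu_1-\mu_2\rVert_2^2 + \mathrm{Tr}\big(\Sigma_1+\Sigma_2-2(\Sigma_1\Sigma_2)^{1/2}\big)$; the reference statistics in `assets/fid_stats` supply the real-data side. The dependency-free sketch below illustrates the formula only. It is not the `tools/fid_score.py` implementation and it assumes both covariances are diagonal, in which case the matrix square root reduces to elementwise square roots.

```python
import math
from typing import Sequence


def frechet_distance_diag(mu1: Sequence[float], var1: Sequence[float],
                          mu2: Sequence[float], var2: Sequence[float]) -> float:
    """Fréchet distance between N(mu1, diag(var1)) and N(mu2, diag(var2)).

    Simplifying assumption: diagonal covariances, so
    Tr(S1 + S2 - 2 (S1 S2)^{1/2}) = sum_i (v1_i + v2_i - 2 sqrt(v1_i * v2_i)).
    """
    if not (len(mu1) == len(var1) == len(mu2) == len(var2)):
        raise ValueError("all statistics must have the same dimension")
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum(a + b - 2.0 * math.sqrt(a * b) for a, b in zip(var1, var2))
    return mean_term + cov_term


# Identical statistics give 0; shifting one mean by 3 adds 3**2 = 9.
print(frechet_distance_diag([0.0, 1.0], [1.0, 4.0], [0.0, 1.0], [1.0, 4.0]))  # 0.0
print(frechet_distance_diag([0.0], [1.0], [3.0], [1.0]))                       # 9.0
```

Real Inception features have full 2048x2048 covariances, so the actual metric needs a matrix square root rather than this per-dimension shortcut.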

We provide all commands to reproduce FID results in the paper:
```sh
# CIFAR10 (U-ViT-S/2)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/cifar10_uvit_small.py --nnet_path=cifar10_uvit_small.pth

# CelebA 64x64 (U-ViT-S/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/celeba64_uvit_small.py --nnet_path=celeba64_uvit_small.pth

# ImageNet 64x64 (U-ViT-M/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval.py --config=configs/imagenet64_uvit_mid.py --nnet_path=imagenet64_uvit_mid.pth

# ImageNet 64x64 (U-ViT-L/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval.py --config=configs/imagenet64_uvit_large.py --nnet_path=imagenet64_uvit_large.pth

# ImageNet 256x256 (U-ViT-L/2)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm.py --config=configs/imagenet256_uvit_large.py --nnet_path=imagenet256_uvit_large.pth

# ImageNet 256x256 (U-ViT-H/2)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm_discrete.py --config=configs/imagenet256_uvit_huge.py --nnet_path=imagenet256_uvit_huge.pth

# ImageNet 512x512 (U-ViT-L/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm.py --config=configs/imagenet512_uvit_large.py --nnet_path=imagenet512_uvit_large.pth

# ImageNet 512x512 (U-ViT-H/4)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm_discrete.py --config=configs/imagenet512_uvit_huge.py --nnet_path=imagenet512_uvit_huge.pth

# MS-COCO (U-ViT-S/2)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval_t2i_discrete.py --config=configs/mscoco_uvit_small.py --nnet_path=mscoco_uvit_small.pth

# MS-COCO (U-ViT-S/2, Deep)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval_t2i_discrete.py --config=configs/mscoco_uvit_small.py --config.nnet.depth=16 --nnet_path=mscoco_uvit_small_deep.pth
```




## References
If you find the code useful for your research, please consider citing:
```bib
@inproceedings{bao2022all,
  title={All are Worth Words: A ViT Backbone for Diffusion Models},
  author={Bao, Fan and Nie, Shen and Xue, Kaiwen and Cao, Yue and Li, Chongxuan and Su, Hang and Zhu, Jun},
  booktitle = {CVPR},
  year={2023}
}
```

This implementation is based on
* [Extended Analytic-DPM](https://github.com/baofff/Extended-Analytic-DPM) (provides the FID reference statistics on CIFAR10 and CelebA 64x64)
* [guided-diffusion](https://github.com/openai/guided-diffusion) (provides the FID reference statistics on ImageNet)
* [pytorch-fid](https://github.com/mseitzer/pytorch-fid) (provides a PyTorch port of the official FID implementation)
* [dpm-solver](https://github.com/LuChengTHU/dpm-solver) (provides the sampler)


--------------------------------------------------------------------------------
/configs/celeba64_uvit_small.py:
--------------------------------------------------------------------------------
import ml_collections


def d(**kwargs):
    """Helper of creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()

    config.seed = 1234
    config.pred = 'noise_pred'

    config.train = d(
        n_steps=500000,
        batch_size=128,
        mode='uncond',
        log_interval=10,
        eval_interval=5000,
        save_interval=50000,
    )

    config.optimizer = d(
        name='adamw',
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )

    config.lr_scheduler = d(
        name='customized',
        warmup_steps=5000
    )

    config.nnet = d(
        name='uvit',
        img_size=64,
        patch_size=4,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        mlp_time_embed=False,
        num_classes=-1,
    )

    config.dataset = d(
        name='celeba',
        path='assets/datasets/celeba',
        resolution=64,
    )

    config.sample = d(
        sample_steps=1000,
        n_samples=50000,
        mini_batch_size=500,
        algorithm='euler_maruyama_sde',
        path=''
    )

    return config


--------------------------------------------------------------------------------
/configs/cifar10_uvit_small.py:
--------------------------------------------------------------------------------
import ml_collections


def d(**kwargs):
    """Helper of creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()

    config.seed = 1234
    config.pred = 'noise_pred'

    config.train = d(
        n_steps=500000,
        batch_size=128,
        mode='uncond',
        log_interval=10,
        eval_interval=5000,
        save_interval=50000,
    )

    config.optimizer = d(
        name='adamw',
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.999),
    )

    config.lr_scheduler = d(
        name='customized',
        warmup_steps=2500
    )

    config.nnet = d(
        name='uvit',
        img_size=32,
        patch_size=2,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        mlp_time_embed=False,
        num_classes=-1,
    )

    config.dataset = d(
        name='cifar10',
        path='assets/datasets/cifar10',
        random_flip=True,
    )

    config.sample = d(
        sample_steps=1000,
        n_samples=50000,
        mini_batch_size=500,
        algorithm='euler_maruyama_sde',
        path=''
    )

    return config
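
The `nnet` entries in the configs above (`embed_dim`, `depth`, `mlp_ratio`) largely determine model size. The back-of-the-envelope sketch below is a hedged estimate, not taken from `libs/uvit.py`: it assumes `depth // 2` input blocks, one middle block, and `depth // 2` output blocks, each output block owning a `Linear(2D, D)` for its long skip, and it counts only the large weight matrices (biases, norms, patch/position/time/label embeddings and the output head are ignored).

```python
def approx_uvit_params(embed_dim: int, depth: int, mlp_ratio: int = 4) -> int:
    """Rough weight count for the transformer body of a U-ViT-style model.

    Hypothetical estimate under the block layout described above.
    """
    d = embed_dim
    n_half = depth // 2
    n_blocks = 2 * n_half + 1          # in-blocks + middle block + out-blocks
    attn = 4 * d * d                   # qkv projection (3 d^2) + output projection (d^2)
    mlp = 2 * mlp_ratio * d * d        # d -> mlp_ratio * d -> d
    skip = n_half * 2 * d * d          # one Linear(2d -> d) per out-block
    return n_blocks * (attn + mlp) + skip


# embed_dim=512, depth=12 as in the CIFAR10 / CelebA configs above
print(approx_uvit_params(512, 12))     # 44040192, i.e. about 44M
```

Under the same assumptions, `embed_dim=1152, depth=28` (the setting in `imagenet256_uvit_huge.py`) gives roughly 499M, which shows how quickly the `D^2` terms dominate as width grows.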
-------------------------------------------------------------------------------- /configs/imagenet256_uvit_huge.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=32, 44 | patch_size=2, 45 | in_chans=4, 46 | embed_dim=1152, 47 | depth=28, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True, 54 | conv=False 55 | ) 56 | 57 | config.dataset = d( 58 | name='imagenet256_features', 59 | path='assets/datasets/imagenet256_features', 60 | cfg=True, 61 | p_uncond=0.1 62 | ) 63 | 64 | config.sample = d( 65 | sample_steps=50, 66 | n_samples=50000, 67 | mini_batch_size=50, # the decoder is large 68 | algorithm='dpm_solver', 69 | cfg=True, 70 | scale=0.4, 71 | path='' 72 | ) 73 | 74 | return config 75 | -------------------------------------------------------------------------------- /configs/imagenet256_uvit_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | 
return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=300000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=32, 44 | patch_size=2, 45 | in_chans=4, 46 | embed_dim=1024, 47 | depth=20, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True 54 | ) 55 | 56 | config.dataset = d( 57 | name='imagenet256_features', 58 | path='assets/datasets/imagenet256_features', 59 | cfg=True, 60 | p_uncond=0.15 61 | ) 62 | 63 | config.sample = d( 64 | sample_steps=50, 65 | n_samples=50000, 66 | mini_batch_size=50, # the decoder is large 67 | algorithm='dpm_solver', 68 | cfg=True, 69 | scale=0.4, 70 | path='' 71 | ) 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /configs/imagenet512_uvit_huge.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | 
pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=64, 44 | patch_size=4, 45 | in_chans=4, 46 | embed_dim=1152, 47 | depth=28, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True, 54 | conv=False 55 | ) 56 | 57 | config.dataset = d( 58 | name='imagenet512_features', 59 | path='assets/datasets/imagenet512_features', 60 | cfg=True, 61 | p_uncond=0.1 62 | ) 63 | 64 | config.sample = d( 65 | sample_steps=50, 66 | n_samples=50000, 67 | mini_batch_size=50, # the decoder is large 68 | algorithm='dpm_solver', 69 | cfg=True, 70 | scale=0.7, 71 | path='' 72 | ) 73 | 74 | return config 75 | -------------------------------------------------------------------------------- /configs/imagenet512_uvit_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth' 18 | ) 19 | 20 | config.train = d( 21 | n_steps=500000, 22 | batch_size=1024, 23 | mode='cond', 24 | log_interval=10, 25 | eval_interval=5000, 26 | save_interval=50000, 27 | ) 28 | 29 | config.optimizer = d( 30 | name='adamw', 31 | lr=0.0002, 
32 | weight_decay=0.03, 33 | betas=(0.99, 0.99), 34 | ) 35 | 36 | config.lr_scheduler = d( 37 | name='customized', 38 | warmup_steps=5000 39 | ) 40 | 41 | config.nnet = d( 42 | name='uvit', 43 | img_size=64, 44 | patch_size=4, 45 | in_chans=4, 46 | embed_dim=1024, 47 | depth=20, 48 | num_heads=16, 49 | mlp_ratio=4, 50 | qkv_bias=False, 51 | mlp_time_embed=False, 52 | num_classes=1001, 53 | use_checkpoint=True 54 | ) 55 | 56 | config.dataset = d( 57 | name='imagenet512_features', 58 | path='assets/datasets/imagenet512_features', 59 | cfg=True, 60 | p_uncond=0.15 61 | ) 62 | 63 | config.sample = d( 64 | sample_steps=50, 65 | n_samples=50000, 66 | mini_batch_size=50, # the decoder is large 67 | algorithm='dpm_solver', 68 | cfg=True, 69 | scale=0.7, 70 | path='' 71 | ) 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /configs/imagenet64_uvit_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | config.train = d( 16 | n_steps=300000, 17 | batch_size=1024, 18 | mode='cond', 19 | log_interval=10, 20 | eval_interval=5000, 21 | save_interval=50000, 22 | ) 23 | 24 | config.optimizer = d( 25 | name='adamw', 26 | lr=0.0003, 27 | weight_decay=0.03, 28 | betas=(0.99, 0.99), 29 | ) 30 | 31 | config.lr_scheduler = d( 32 | name='customized', 33 | warmup_steps=5000 34 | ) 35 | 36 | config.nnet = d( 37 | name='uvit', 38 | img_size=64, 39 | patch_size=4, 40 | embed_dim=1024, 41 | depth=20, 42 | num_heads=16, 43 | mlp_ratio=4, 44 | qkv_bias=False, 45 | mlp_time_embed=False, 46 | num_classes=1000, 47 | use_checkpoint=True 48 | ) 49 | 50 | config.dataset = d( 51 | name='imagenet', 52 | 
path='assets/datasets/ImageNet', 53 | resolution=64, 54 | ) 55 | 56 | config.sample = d( 57 | sample_steps=50, 58 | n_samples=50000, 59 | mini_batch_size=200, 60 | algorithm='dpm_solver', 61 | path='' 62 | ) 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /configs/imagenet64_uvit_mid.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | config.train = d( 16 | n_steps=300000, 17 | batch_size=1024, 18 | mode='cond', 19 | log_interval=10, 20 | eval_interval=5000, 21 | save_interval=50000, 22 | ) 23 | 24 | config.optimizer = d( 25 | name='adamw', 26 | lr=0.0003, 27 | weight_decay=0.03, 28 | betas=(0.99, 0.99), 29 | ) 30 | 31 | config.lr_scheduler = d( 32 | name='customized', 33 | warmup_steps=5000 34 | ) 35 | 36 | config.nnet = d( 37 | name='uvit', 38 | img_size=64, 39 | patch_size=4, 40 | embed_dim=768, 41 | depth=16, 42 | num_heads=12, 43 | mlp_ratio=4, 44 | qkv_bias=False, 45 | mlp_time_embed=False, 46 | num_classes=1000, 47 | use_checkpoint=True 48 | ) 49 | 50 | config.dataset = d( 51 | name='imagenet', 52 | path='assets/datasets/ImageNet', 53 | resolution=64, 54 | ) 55 | 56 | config.sample = d( 57 | sample_steps=50, 58 | n_samples=50000, 59 | mini_batch_size=200, 60 | algorithm='dpm_solver', 61 | path='' 62 | ) 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /configs/mscoco_uvit_small.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return 
ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()

    config.seed = 1234
    config.z_shape = (4, 32, 32)

    config.autoencoder = d(
        pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
        scale_factor=0.23010
    )

    config.train = d(
        n_steps=1000000,
        batch_size=256,
        log_interval=10,
        eval_interval=5000,
        save_interval=50000,
    )

    config.optimizer = d(
        name='adamw',
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.9, 0.9),
    )

    config.lr_scheduler = d(
        name='customized',
        warmup_steps=5000
    )

    config.nnet = d(
        name='uvit_t2i',
        img_size=32,
        in_chans=4,
        patch_size=2,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        mlp_time_embed=False,
        clip_dim=768,
        num_clip_token=77
    )

    config.dataset = d(
        name='mscoco256_features',
        path='assets/datasets/coco256_features',
        cfg=True,
        p_uncond=0.1
    )

    config.sample = d(
        sample_steps=50,
        n_samples=30000,
        mini_batch_size=50,
        cfg=True,
        scale=1.,
        path=''
    )

    return config

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
import torch
import math
import random
from PIL import Image
import os
import glob
import einops
import torchvision.transforms.functional as F


class UnlabeledDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = tuple(self.dataset[item][:-1])  # remove label
        if len(data) == 1:
            data = data[0]
        return data


class LabeledDataset(Dataset):
    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        return self.dataset[item], self.labels[item]


class CFGDataset(Dataset):  # for classifier free guidance
    def __init__(self, dataset, p_uncond, empty_token):
        self.dataset = dataset
        self.p_uncond = p_uncond
        self.empty_token = empty_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        x, y = self.dataset[item]
        if random.random() < self.p_uncond:
            y = self.empty_token
        return x, y


class DatasetFactory(object):

    def __init__(self):
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        if split == "train":
            dataset = self.train
        elif split == "test":
            dataset = self.test
        else:
            raise ValueError

        if self.has_label:
            return dataset if labeled else UnlabeledDataset(dataset)
        else:
            assert not labeled
            return dataset

    def unpreprocess(self, v):  # to B C H W and [0, 1]
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v

    @property
    def has_label(self):
        return True

    @property
    def data_shape(self):
        raise NotImplementedError

    @property
    def data_dim(self):
        return int(np.prod(self.data_shape))

    @property
    def fid_stat(self):
        return None

    def sample_label(self, n_samples, device):
        raise NotImplementedError

    def label_prob(self, k):
        raise NotImplementedError


# CIFAR10

class CIFAR10(DatasetFactory):
    r""" CIFAR10 dataset

    Information of the raw dataset:
         train: 50,000
         test:  10,000
         shape: 3 * 32 * 32
    """

    def __init__(self, path, random_flip=False, cfg=False, p_uncond=None):
        super().__init__()

        transform_train = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
        transform_test = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
        if random_flip:  # only for train
            transform_train.append(transforms.RandomHorizontalFlip())
        transform_train = transforms.Compose(transform_train)
        transform_test = transforms.Compose(transform_test)
        self.train = datasets.CIFAR10(path, train=True, transform=transform_train, download=True)
        self.test = datasets.CIFAR10(path, train=False, transform=transform_test, download=True)

        assert len(self.train.targets) == 50000
        self.K = max(self.train.targets) + 1
        self.cnt = torch.tensor([len(np.where(np.array(self.train.targets) == k)[0]) for k in range(self.K)]).float()
        self.frac = [self.cnt[k] / 50000 for k in range(self.K)]
        print(f'{self.K} classes')
        print(f'cnt: {self.cnt}')
        print(f'frac: {self.frac}')

        if cfg:  # classifier free guidance
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.K)

    @property
    def data_shape(self):
        return 3, 32, 32

    @property
    def fid_stat(self):
        return 'assets/fid_stats/fid_stats_cifar10_train_pytorch.npz'

    def sample_label(self, n_samples, device):
        return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)

    def label_prob(self, k):
        return self.frac[k]


# ImageNet


class FeatureDataset(Dataset):
    def __init__(self, path):
        super().__init__()
        self.path = path
        # names = sorted(os.listdir(path))
        # self.files = [os.path.join(path, name) for name in names]

    def __len__(self):
        return 1_281_167 * 2  # consider the random flip

    def __getitem__(self, idx):
        path = os.path.join(self.path, f'{idx}.npy')
        z, label = np.load(path, allow_pickle=True)
        return z, label


class ImageNet256Features(DatasetFactory):  # the moments calculated by Stable Diffusion image encoder
    def __init__(self, path, cfg=False, p_uncond=None):
        super().__init__()
        print('Prepare dataset...')
        self.train = FeatureDataset(path)
        print('Prepare dataset ok')
        self.K = 1000

        if cfg:  # classifier free guidance
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.K)

    @property
    def data_shape(self):
        return 4, 32, 32

    @property
    def fid_stat(self):
        return f'assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz'

    def sample_label(self, n_samples, device):
        return torch.randint(0, 1000, (n_samples,), device=device)


class ImageNet512Features(DatasetFactory):  # the moments calculated by Stable Diffusion image encoder
    def __init__(self, path, cfg=False, p_uncond=None):
        super().__init__()
        print('Prepare dataset...')
        self.train = FeatureDataset(path)
        print('Prepare dataset ok')
        self.K = 1000

        if cfg:  # classifier free guidance
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.K)

    @property
    def data_shape(self):
        return 4, 64, 64

    @property
    def fid_stat(self):
        return f'assets/fid_stats/fid_stats_imagenet512_guided_diffusion.npz'

    def sample_label(self, n_samples, device):
        return torch.randint(0, 1000, (n_samples,), device=device)


class ImageNet(DatasetFactory):
    def __init__(self, path, resolution, random_crop=False, random_flip=True):
        super().__init__()

        print(f'Counting ImageNet files from {path}')
        train_files = _list_image_files_recursively(os.path.join(path, 'train'))
        class_names = [os.path.basename(path).split("_")[0] for path in train_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        train_labels = [sorted_classes[x] for x in class_names]
        print('Finish counting ImageNet files')

        self.train = ImageDataset(resolution, train_files, labels=train_labels, random_crop=random_crop, random_flip=random_flip)
        self.resolution = resolution
        if len(self.train) != 1_281_167:
            print(f'Missing train samples: {len(self.train)} < 1281167')

        self.K = max(self.train.labels) + 1
        cnt = dict(zip(*np.unique(self.train.labels, return_counts=True)))
        self.cnt = torch.tensor([cnt[k] for k in range(self.K)]).float()
        self.frac = [self.cnt[k] / len(self.train.labels) for k in range(self.K)]
        print(f'{self.K} classes')
        print(f'cnt[:10]: {self.cnt[:10]}')
        print(f'frac[:10]: {self.frac[:10]}')

    @property
    def data_shape(self):
        return 3, self.resolution, self.resolution

    @property
    def fid_stat(self):
        return f'assets/fid_stats/fid_stats_imagenet{self.resolution}_guided_diffusion.npz'

    def sample_label(self, n_samples, device):
        return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)

    def label_prob(self, k):
        return self.frac[k]


def _list_image_files_recursively(data_dir):
    results = []
    for entry in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, entry)
        ext = entry.split(".")[-1]
        if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
            results.append(full_path)
        elif os.path.isdir(full_path):  # recurse into subdirectories; skip other (non-image) files
            results.extend(_list_image_files_recursively(full_path))
    return results


class ImageDataset(Dataset):
    def __init__(
        self,
        resolution,
        image_paths,
        labels,
        random_crop=False,
        random_flip=True,
    ):
        super().__init__()
        self.resolution = resolution
        self.image_paths = image_paths
        self.labels = labels
        self.random_crop = random_crop
        self.random_flip = random_flip

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        pil_image = Image.open(path)
        pil_image.load()
        pil_image = pil_image.convert("RGB")

        if self.random_crop:
            arr = random_crop_arr(pil_image, self.resolution)
        else:
            arr = center_crop_arr(pil_image, self.resolution)

        if self.random_flip and random.random() < 0.5:
            arr = arr[:, ::-1]

        arr = arr.astype(np.float32) / 127.5 - 1

        label = np.array(self.labels[idx], dtype=np.int64)
        return np.transpose(arr, [2, 0, 1]), label


def center_crop_arr(pil_image, image_size):
    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]


def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
    max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
    smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)

    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(*pil_image.size) >= 2 * smaller_dim_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = smaller_dim_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = random.randrange(arr.shape[0] - image_size + 1)
    crop_x = random.randrange(arr.shape[1] - image_size + 1)
    return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]


# CelebA


class Crop(object):
    def __init__(self, x1, x2, y1, y2):
        self.x1 = x1
        self.x2 = x2
        self.y1 = y1
        self.y2 = y2

    def __call__(self, img):
        return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)

    def __repr__(self):
        return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
            self.x1, self.x2, self.y1, self.y2
        )


class CelebA(DatasetFactory):
    r""" train: 162,770
         val:   19,867
         test:  19,962
         shape: 3 * width * width
    """

    def __init__(self, path, resolution=64):
        super().__init__()

        self.resolution = resolution

        cx = 89
        cy = 121
        x1 = cy - 64
        x2 = cy + 64
        y1 = cx - 64
        y2 = cx + 64

        transform = transforms.Compose([Crop(x1, x2, y1, y2), transforms.Resize(self.resolution),
                                        transforms.RandomHorizontalFlip(), transforms.ToTensor(),
                                        transforms.Normalize(0.5, 0.5)])
        self.train = datasets.CelebA(root=path, split="train", target_type=[], transform=transform, download=True)
        self.train = UnlabeledDataset(self.train)

    @property
    def data_shape(self):
        return 3, self.resolution, self.resolution

    @property
    def fid_stat(self):
        return 'assets/fid_stats/fid_stats_celeba64_train_50000_ddim.npz'

    @property
    def has_label(self):
        return False


# MS COCO


def center_crop(width, height, img):
    resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
    crop = np.min(img.shape[:2])
    img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
              (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
    try:
        img = Image.fromarray(img, 'RGB')
    except:
        img = Image.fromarray(img)
    img = img.resize((width, height), resample)

    return np.array(img).astype(np.uint8)


class MSCOCODatabase(Dataset):
    def __init__(self, root, annFile, size=None):
        from pycocotools.coco import COCO
        self.root = root
        self.height = self.width = size

        self.coco = COCO(annFile)
        self.keys = list(sorted(self.coco.imgs.keys()))

    def _load_image(self, key: int):
        path = self.coco.loadImgs(key)[0]["file_name"]
        return Image.open(os.path.join(self.root, path)).convert("RGB")

    def _load_target(self, key: int):
        return self.coco.loadAnns(self.coco.getAnnIds(key))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]
        image = self._load_image(key)
        image = np.array(image).astype(np.uint8)
        image = center_crop(self.width, self.height, image).astype(np.float32)
        image = (image / 127.5 - 1.0).astype(np.float32)
        image = einops.rearrange(image, 'h w c -> c h w')

        anns = self._load_target(key)
        target = []
        for ann in anns:
            target.append(ann['caption'])

        return image, target


def get_feature_dir_info(root):
    files = glob.glob(os.path.join(root, '*.npy'))
    files_caption = glob.glob(os.path.join(root, '*_*.npy'))
    num_data = len(files) - len(files_caption)
    n_captions = {k: 0 for k in range(num_data)}
    for f in files_caption:
        name = os.path.split(f)[-1]
        k1, k2 = os.path.splitext(name)[0].split('_')
        n_captions[int(k1)] += 1
    return num_data, n_captions


class MSCOCOFeatureDataset(Dataset):
    # the image features are obtained through sampling
    def __init__(self, root):
        self.root = root
        self.num_data, self.n_captions = get_feature_dir_info(root)

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        z = np.load(os.path.join(self.root, f'{index}.npy'))
        k = random.randint(0, self.n_captions[index] - 1)
        c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
        return z, c


class MSCOCO256Features(DatasetFactory):  # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
    def __init__(self, path, cfg=False, p_uncond=None):
        super().__init__()
        print('Prepare dataset...')
        self.train = MSCOCOFeatureDataset(os.path.join(path, 'train'))
        self.test = MSCOCOFeatureDataset(os.path.join(path, 'val'))
        assert len(self.train) == 82783
        assert len(self.test) == 40504
        print('Prepare dataset ok')

        self.empty_context = np.load(os.path.join(path, 'empty_context.npy'))

        if cfg:  # classifier free guidance
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.empty_context)

        # text embedding extracted by clip
        # for visualization in t2i
        self.prompts, self.contexts = [], []
        for f in sorted(os.listdir(os.path.join(path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
            prompt, context = np.load(os.path.join(path, 'run_vis', f), allow_pickle=True)
            self.prompts.append(prompt)
            self.contexts.append(context)
        self.contexts = np.array(self.contexts)

    @property
    def data_shape(self):
        return 4, 32, 32

    @property
    def fid_stat(self):
        return f'assets/fid_stats/fid_stats_mscoco256_val.npz'


def get_dataset(name, **kwargs):
    if name == 'cifar10':
        return CIFAR10(**kwargs)
    elif name == 'imagenet':
        return ImageNet(**kwargs)
    elif name == 'imagenet256_features':
        return ImageNet256Features(**kwargs)
    elif name == 'imagenet512_features':
        return ImageNet512Features(**kwargs)
    elif name == 'celeba':
        return CelebA(**kwargs)
    elif name == 'mscoco256_features':
        return MSCOCO256Features(**kwargs)
    else:
        raise NotImplementedError(name)

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
from tools.fid_score import calculate_fid_given_paths
import ml_collections
import torch
from torch import multiprocessing as mp
import accelerate
import utils
import sde
from datasets import get_dataset
import tempfile
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from absl import logging
import builtins


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        utils.set_logger(log_level='error')
utils.set_logger(log_level='error') 32 | builtins.print = lambda *args: None 33 | 34 | dataset = get_dataset(**config.dataset) 35 | 36 | nnet = utils.get_nnet(**config.nnet) 37 | nnet = accelerator.prepare(nnet) 38 | logging.info(f'load nnet from {config.nnet_path}') 39 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 40 | nnet.eval() 41 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 42 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 43 | def cfg_nnet(x, timesteps, y): 44 | _cond = nnet(x, timesteps, y=y) 45 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 46 | return _cond + config.sample.scale * (_cond - _uncond) 47 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 48 | else: 49 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 50 | 51 | 52 | logging.info(config.sample) 53 | assert os.path.exists(dataset.fid_stat) 54 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 55 | 56 | def sample_fn(_n_samples): 57 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device) 58 | if config.train.mode == 'uncond': 59 | kwargs = dict() 60 | elif config.train.mode == 'cond': 61 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 62 | else: 63 | raise NotImplementedError 64 | 65 | if config.sample.algorithm == 'euler_maruyama_sde': 66 | rsde = sde.ReverseSDE(score_model) 67 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 68 | elif config.sample.algorithm == 'euler_maruyama_ode': 69 | rsde = sde.ODE(score_model) 70 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 71 | elif config.sample.algorithm == 'dpm_solver': 72 
            noise_schedule = NoiseScheduleVP(schedule='linear')
            model_fn = model_wrapper(
                score_model.noise_pred,
                noise_schedule,
                time_input_type='0',
                model_kwargs=kwargs
            )
            dpm_solver = DPM_Solver(model_fn, noise_schedule)
            return dpm_solver.sample(
                x_init,
                steps=config.sample.sample_steps,
                eps=1e-4,
                adaptive_step_size=False,
                fast_version=True,
            )
        else:
            raise NotImplementedError

    with tempfile.TemporaryDirectory() as temp_path:
        path = config.sample.path or temp_path
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
        if accelerator.is_main_process:
            fid = calculate_fid_given_paths((dataset.fid_stat, path))
            logging.info(f'nnet_path={config.nnet_path}, fid={fid}')


from absl import flags
from absl import app
from ml_collections import config_flags
import os


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("output_path", None, "The path to output log.")


def main(argv):
    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.output_path = FLAGS.output_path
    evaluate(config)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/eval_ldm.py:
--------------------------------------------------------------------------------
from tools.fid_score import calculate_fid_given_paths
import ml_collections
import torch
from torch import multiprocessing as mp
import accelerate
import utils
import sde
from datasets import get_dataset
import tempfile
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from absl import logging
import builtins
import libs.autoencoder


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args, **kwargs: None  # silence print (incl. keyword args) on non-main processes

    dataset = get_dataset(**config.dataset)

    nnet = utils.get_nnet(**config.nnet)
    nnet = accelerator.prepare(nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()

    autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def decode_large_batch(_batch):
        decode_mini_batch_size = 50  # use a small batch size since the decoder is large
        xs = []
        pt = 0
        for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
            x = decode(_batch[pt: pt + _decode_mini_batch_size])
            pt += _decode_mini_batch_size
            xs.append(x)
        xs = torch.concat(xs, dim=0)
        assert xs.size(0) == _batch.size(0)
        return xs

    if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0:  # classifier free guidance
        logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
        def cfg_nnet(x, timesteps, y):
            _cond = nnet(x, timesteps, y=y)
            _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
            return _cond + config.sample.scale * (_cond - _uncond)
        score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
    else:
        score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())

    logging.info(config.sample)
    assert os.path.exists(dataset.fid_stat)
    logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')

    def sample_fn(_n_samples):
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        if config.train.mode == 'uncond':
            kwargs = dict()
        elif config.train.mode == 'cond':
            kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
        else:
            raise NotImplementedError

        if config.sample.algorithm == 'euler_maruyama_sde':
            _z = sde.euler_maruyama(sde.ReverseSDE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
        elif config.sample.algorithm == 'euler_maruyama_ode':
            _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
        elif config.sample.algorithm == 'dpm_solver':
            noise_schedule = NoiseScheduleVP(schedule='linear')
            model_fn = model_wrapper(
                score_model.noise_pred,
                noise_schedule,
                time_input_type='0',
                model_kwargs=kwargs
            )
            dpm_solver = DPM_Solver(model_fn, noise_schedule)
            _z = dpm_solver.sample(
                _z_init,
                steps=config.sample.sample_steps,
                eps=1e-4,
                adaptive_step_size=False,
                fast_version=True,
            )
        else:
            raise NotImplementedError
        return decode_large_batch(_z)

    with tempfile.TemporaryDirectory() as temp_path:
        path = config.sample.path or temp_path
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        logging.info(f'Samples are saved in {path}')
        utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
        if accelerator.is_main_process:
            fid = calculate_fid_given_paths((dataset.fid_stat, path))
            logging.info(f'nnet_path={config.nnet_path}, fid={fid}')


from absl import flags
from absl import app
from ml_collections import config_flags
import os


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("output_path", None, "The path to output log.")


def main(argv):
    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.output_path = FLAGS.output_path
    evaluate(config)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/eval_ldm_discrete.py:
--------------------------------------------------------------------------------
from tools.fid_score import calculate_fid_given_paths
import ml_collections
import torch
from torch import multiprocessing as mp
import accelerate
import utils
from datasets import get_dataset
import tempfile
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import builtins
import libs.autoencoder


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args, **kwargs: None  # silence print (incl. keyword args) on non-main processes

    dataset = get_dataset(**config.dataset)

    nnet = utils.get_nnet(**config.nnet)
    nnet = accelerator.prepare(nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()

    autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def decode_large_batch(_batch):
        decode_mini_batch_size = 50  # use a small batch size since the decoder is large
        xs = []
        pt = 0
        for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
            x = decode(_batch[pt: pt + _decode_mini_batch_size])
            pt += _decode_mini_batch_size
            xs.append(x)
        xs = torch.concat(xs, dim=0)
        assert xs.size(0) == _batch.size(0)
        return xs

    if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0:  # classifier free guidance
        logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
        def cfg_nnet(x, timesteps, y):
            _cond = nnet(x, timesteps, y=y)
            _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
            return _cond + config.sample.scale * (_cond - _uncond)
    else:
        def cfg_nnet(x, timesteps, y):
            _cond = nnet(x, timesteps, y=y)
            return _cond

    logging.info(config.sample)
    assert os.path.exists(dataset.fid_stat)
    logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    def sample_z(_n_samples, _sample_steps, **kwargs):
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)

        if config.sample.algorithm == 'dpm_solver':
            noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

            def model_fn(x, t_continuous):
                t = t_continuous * N
                eps_pre = cfg_nnet(x, t, **kwargs)
                return eps_pre

            dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
            _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.)
        else:
            raise NotImplementedError

        return _z

    def sample_fn(_n_samples):
        if config.train.mode == 'uncond':
            kwargs = dict()
        elif config.train.mode == 'cond':
            kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
        else:
            raise NotImplementedError
        _z = sample_z(_n_samples, _sample_steps=config.sample.sample_steps, **kwargs)
        return decode_large_batch(_z)

    with tempfile.TemporaryDirectory() as temp_path:
        path = config.sample.path or temp_path
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        logging.info(f'Samples are saved in {path}')
        utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
        if accelerator.is_main_process:
            fid = calculate_fid_given_paths((dataset.fid_stat, path))
            logging.info(f'nnet_path={config.nnet_path}, fid={fid}')


from absl import flags
from absl import app
from ml_collections import config_flags
import os


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("output_path", None, "The path to output log.")


def main(argv):
    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.output_path = FLAGS.output_path
    evaluate(config)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/eval_t2i_discrete.py:
--------------------------------------------------------------------------------
from tools.fid_score import calculate_fid_given_paths
import ml_collections
import torch
from torch import multiprocessing as mp
import accelerate
from torch.utils.data import DataLoader
import utils
from datasets import get_dataset
import tempfile
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import builtins
import einops
import libs.autoencoder


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args, **kwargs: None  # silence print (incl. keyword args) on non-main processes

    dataset = get_dataset(**config.dataset)
    test_dataset = dataset.get_split(split='test', labeled=True)  # for sampling
    test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True,
                                     drop_last=True, num_workers=8, pin_memory=True, persistent_workers=True)

    nnet = utils.get_nnet(**config.nnet)
    nnet, test_dataset_loader = accelerator.prepare(nnet, test_dataset_loader)
    logging.info(f'load nnet from {config.nnet_path}')
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()

    def cfg_nnet(x, timesteps, context):
        _cond =
nnet(x, timesteps, context=context) 56 | if config.sample.scale == 0: 57 | return _cond 58 | _empty_context = torch.tensor(dataset.empty_context, device=device) 59 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) 60 | _uncond = nnet(x, timesteps, context=_empty_context) 61 | return _cond + config.sample.scale * (_cond - _uncond) 62 | 63 | autoencoder = libs.autoencoder.get_model(**config.autoencoder) 64 | autoencoder.to(device) 65 | 66 | @torch.cuda.amp.autocast() 67 | def encode(_batch): 68 | return autoencoder.encode(_batch) 69 | 70 | @torch.cuda.amp.autocast() 71 | def decode(_batch): 72 | return autoencoder.decode(_batch) 73 | 74 | def decode_large_batch(_batch): 75 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 76 | xs = [] 77 | pt = 0 78 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 79 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 80 | pt += _decode_mini_batch_size 81 | xs.append(x) 82 | xs = torch.concat(xs, dim=0) 83 | assert xs.size(0) == _batch.size(0) 84 | return xs 85 | 86 | def get_context_generator(): 87 | while True: 88 | for data in test_dataset_loader: 89 | _, _context = data 90 | yield _context 91 | 92 | context_generator = get_context_generator() 93 | 94 | _betas = stable_diffusion_beta_schedule() 95 | N = len(_betas) 96 | 97 | logging.info(config.sample) 98 | assert os.path.exists(dataset.fid_stat) 99 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}') 100 | 101 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs): 102 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 103 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 104 | 105 | def model_fn(x, t_continuous): 106 | t = t_continuous * N 107 | return cfg_nnet(x, t, **kwargs) 108 | 109 | dpm_solver = DPM_Solver(model_fn, noise_schedule, 
predict_x0=True, thresholding=False) 110 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.) 111 | return decode_large_batch(_z) 112 | 113 | def sample_fn(_n_samples): 114 | _context = next(context_generator) 115 | assert _context.size(0) == _n_samples 116 | return dpm_solver_sample(_n_samples, config.sample.sample_steps, context=_context) 117 | 118 | with tempfile.TemporaryDirectory() as temp_path: 119 | path = config.sample.path or temp_path 120 | if accelerator.is_main_process: 121 | os.makedirs(path, exist_ok=True) 122 | logging.info(f'Samples are saved in {path}') 123 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 124 | if accelerator.is_main_process: 125 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 126 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 127 | 128 | 129 | from absl import flags 130 | from absl import app 131 | from ml_collections import config_flags 132 | import os 133 | 134 | 135 | FLAGS = flags.FLAGS 136 | config_flags.DEFINE_config_file( 137 | "config", None, "Training configuration.", lock_config=False) 138 | flags.mark_flags_as_required(["config"]) 139 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 140 | flags.DEFINE_string("output_path", None, "The path to output log.") 141 | 142 | 143 | def main(argv): 144 | config = FLAGS.config 145 | config.nnet_path = FLAGS.nnet_path 146 | config.output_path = FLAGS.output_path 147 | evaluate(config) 148 | 149 | 150 | if __name__ == "__main__": 151 | app.run(main) 152 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # codes from third party 2 | -------------------------------------------------------------------------------- /libs/autoencoder.py: -------------------------------------------------------------------------------- 
1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from einops import rearrange 5 | 6 | 7 | class LinearAttention(nn.Module): 8 | def __init__(self, dim, heads=4, dim_head=32): 9 | super().__init__() 10 | self.heads = heads 11 | hidden_dim = dim_head * heads 12 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 13 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 14 | 15 | def forward(self, x): 16 | b, c, h, w = x.shape 17 | qkv = self.to_qkv(x) 18 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 19 | k = k.softmax(dim=-1) 20 | context = torch.einsum('bhdn,bhen->bhde', k, v) 21 | out = torch.einsum('bhde,bhdn->bhen', context, q) 22 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 23 | return self.to_out(out) 24 | 25 | 26 | def nonlinearity(x): 27 | # swish 28 | return x*torch.sigmoid(x) 29 | 30 | 31 | def Normalize(in_channels, num_groups=32): 32 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) 33 | 34 | 35 | class Upsample(nn.Module): 36 | def __init__(self, in_channels, with_conv): 37 | super().__init__() 38 | self.with_conv = with_conv 39 | if self.with_conv: 40 | self.conv = torch.nn.Conv2d(in_channels, 41 | in_channels, 42 | kernel_size=3, 43 | stride=1, 44 | padding=1) 45 | 46 | def forward(self, x): 47 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 48 | if self.with_conv: 49 | x = self.conv(x) 50 | return x 51 | 52 | 53 | class Downsample(nn.Module): 54 | def __init__(self, in_channels, with_conv): 55 | super().__init__() 56 | self.with_conv = with_conv 57 | if self.with_conv: 58 | # no asymmetric padding in torch conv, must do it ourselves 59 | self.conv = torch.nn.Conv2d(in_channels, 60 | in_channels, 61 | kernel_size=3, 62 | stride=2, 63 | padding=0) 64 | 65 | def forward(self, x): 66 | if self.with_conv: 67 | pad = (0,1,0,1) 68 | x = torch.nn.functional.pad(x, 
pad, mode="constant", value=0) 69 | x = self.conv(x) 70 | else: 71 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 72 | return x 73 | 74 | 75 | class ResnetBlock(nn.Module): 76 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 77 | dropout, temb_channels=512): 78 | super().__init__() 79 | self.in_channels = in_channels 80 | out_channels = in_channels if out_channels is None else out_channels 81 | self.out_channels = out_channels 82 | self.use_conv_shortcut = conv_shortcut 83 | 84 | self.norm1 = Normalize(in_channels) 85 | self.conv1 = torch.nn.Conv2d(in_channels, 86 | out_channels, 87 | kernel_size=3, 88 | stride=1, 89 | padding=1) 90 | if temb_channels > 0: 91 | self.temb_proj = torch.nn.Linear(temb_channels, 92 | out_channels) 93 | self.norm2 = Normalize(out_channels) 94 | self.dropout = torch.nn.Dropout(dropout) 95 | self.conv2 = torch.nn.Conv2d(out_channels, 96 | out_channels, 97 | kernel_size=3, 98 | stride=1, 99 | padding=1) 100 | if self.in_channels != self.out_channels: 101 | if self.use_conv_shortcut: 102 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 103 | out_channels, 104 | kernel_size=3, 105 | stride=1, 106 | padding=1) 107 | else: 108 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 109 | out_channels, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0) 113 | 114 | def forward(self, x, temb): 115 | h = x 116 | h = self.norm1(h) 117 | h = nonlinearity(h) 118 | h = self.conv1(h) 119 | 120 | if temb is not None: 121 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 122 | 123 | h = self.norm2(h) 124 | h = nonlinearity(h) 125 | h = self.dropout(h) 126 | h = self.conv2(h) 127 | 128 | if self.in_channels != self.out_channels: 129 | if self.use_conv_shortcut: 130 | x = self.conv_shortcut(x) 131 | else: 132 | x = self.nin_shortcut(x) 133 | 134 | return x+h 135 | 136 | 137 | class LinAttnBlock(LinearAttention): 138 | """to match AttnBlock usage""" 139 | def __init__(self, in_channels): 140 | 
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) 141 | 142 | 143 | class AttnBlock(nn.Module): 144 | def __init__(self, in_channels): 145 | super().__init__() 146 | self.in_channels = in_channels 147 | 148 | self.norm = Normalize(in_channels) 149 | self.q = torch.nn.Conv2d(in_channels, 150 | in_channels, 151 | kernel_size=1, 152 | stride=1, 153 | padding=0) 154 | self.k = torch.nn.Conv2d(in_channels, 155 | in_channels, 156 | kernel_size=1, 157 | stride=1, 158 | padding=0) 159 | self.v = torch.nn.Conv2d(in_channels, 160 | in_channels, 161 | kernel_size=1, 162 | stride=1, 163 | padding=0) 164 | self.proj_out = torch.nn.Conv2d(in_channels, 165 | in_channels, 166 | kernel_size=1, 167 | stride=1, 168 | padding=0) 169 | 170 | 171 | def forward(self, x): 172 | h_ = x 173 | h_ = self.norm(h_) 174 | q = self.q(h_) 175 | k = self.k(h_) 176 | v = self.v(h_) 177 | 178 | # compute attention 179 | b,c,h,w = q.shape 180 | q = q.reshape(b,c,h*w) 181 | q = q.permute(0,2,1) # b,hw,c 182 | k = k.reshape(b,c,h*w) # b,c,hw 183 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 184 | w_ = w_ * (int(c)**(-0.5)) 185 | w_ = torch.nn.functional.softmax(w_, dim=2) 186 | 187 | # attend to values 188 | v = v.reshape(b,c,h*w) 189 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 190 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 191 | h_ = h_.reshape(b,c,h,w) 192 | 193 | h_ = self.proj_out(h_) 194 | 195 | return x+h_ 196 | 197 | 198 | def make_attn(in_channels, attn_type="vanilla"): 199 | assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' 200 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 201 | if attn_type == "vanilla": 202 | return AttnBlock(in_channels) 203 | elif attn_type == "none": 204 | return nn.Identity(in_channels) 205 | else: 206 | return LinAttnBlock(in_channels) 207 | 208 | 209 | class Encoder(nn.Module): 210 | def __init__(self, *, ch, 
out_ch, ch_mult=(1,2,4,8), num_res_blocks, 211 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 212 | resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", 213 | **ignore_kwargs): 214 | super().__init__() 215 | if use_linear_attn: attn_type = "linear" 216 | self.ch = ch 217 | self.temb_ch = 0 218 | self.num_resolutions = len(ch_mult) 219 | self.num_res_blocks = num_res_blocks 220 | self.resolution = resolution 221 | self.in_channels = in_channels 222 | 223 | # downsampling 224 | self.conv_in = torch.nn.Conv2d(in_channels, 225 | self.ch, 226 | kernel_size=3, 227 | stride=1, 228 | padding=1) 229 | 230 | curr_res = resolution 231 | in_ch_mult = (1,)+tuple(ch_mult) 232 | self.in_ch_mult = in_ch_mult 233 | self.down = nn.ModuleList() 234 | for i_level in range(self.num_resolutions): 235 | block = nn.ModuleList() 236 | attn = nn.ModuleList() 237 | block_in = ch*in_ch_mult[i_level] 238 | block_out = ch*ch_mult[i_level] 239 | for i_block in range(self.num_res_blocks): 240 | block.append(ResnetBlock(in_channels=block_in, 241 | out_channels=block_out, 242 | temb_channels=self.temb_ch, 243 | dropout=dropout)) 244 | block_in = block_out 245 | if curr_res in attn_resolutions: 246 | attn.append(make_attn(block_in, attn_type=attn_type)) 247 | down = nn.Module() 248 | down.block = block 249 | down.attn = attn 250 | if i_level != self.num_resolutions-1: 251 | down.downsample = Downsample(block_in, resamp_with_conv) 252 | curr_res = curr_res // 2 253 | self.down.append(down) 254 | 255 | # middle 256 | self.mid = nn.Module() 257 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 258 | out_channels=block_in, 259 | temb_channels=self.temb_ch, 260 | dropout=dropout) 261 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 262 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 263 | out_channels=block_in, 264 | temb_channels=self.temb_ch, 265 | dropout=dropout) 266 | 267 | # end 268 | self.norm_out = Normalize(block_in) 269 
| self.conv_out = torch.nn.Conv2d(block_in, 270 | 2*z_channels if double_z else z_channels, 271 | kernel_size=3, 272 | stride=1, 273 | padding=1) 274 | 275 | def forward(self, x): 276 | # timestep embedding 277 | temb = None 278 | 279 | # downsampling 280 | hs = [self.conv_in(x)] 281 | for i_level in range(self.num_resolutions): 282 | for i_block in range(self.num_res_blocks): 283 | h = self.down[i_level].block[i_block](hs[-1], temb) 284 | if len(self.down[i_level].attn) > 0: 285 | h = self.down[i_level].attn[i_block](h) 286 | hs.append(h) 287 | if i_level != self.num_resolutions-1: 288 | hs.append(self.down[i_level].downsample(hs[-1])) 289 | 290 | # middle 291 | h = hs[-1] 292 | h = self.mid.block_1(h, temb) 293 | h = self.mid.attn_1(h) 294 | h = self.mid.block_2(h, temb) 295 | 296 | # end 297 | h = self.norm_out(h) 298 | h = nonlinearity(h) 299 | h = self.conv_out(h) 300 | return h 301 | 302 | 303 | class Decoder(nn.Module): 304 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 305 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 306 | resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, 307 | attn_type="vanilla", **ignorekwargs): 308 | super().__init__() 309 | if use_linear_attn: attn_type = "linear" 310 | self.ch = ch 311 | self.temb_ch = 0 312 | self.num_resolutions = len(ch_mult) 313 | self.num_res_blocks = num_res_blocks 314 | self.resolution = resolution 315 | self.in_channels = in_channels 316 | self.give_pre_end = give_pre_end 317 | self.tanh_out = tanh_out 318 | 319 | # compute in_ch_mult, block_in and curr_res at lowest res 320 | in_ch_mult = (1,)+tuple(ch_mult) 321 | block_in = ch*ch_mult[self.num_resolutions-1] 322 | curr_res = resolution // 2**(self.num_resolutions-1) 323 | self.z_shape = (1,z_channels,curr_res,curr_res) 324 | print("Working with z of shape {} = {} dimensions.".format( 325 | self.z_shape, np.prod(self.z_shape))) 326 | 327 | # z to block_in 328 | self.conv_in = 
torch.nn.Conv2d(z_channels, 329 | block_in, 330 | kernel_size=3, 331 | stride=1, 332 | padding=1) 333 | 334 | # middle 335 | self.mid = nn.Module() 336 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 337 | out_channels=block_in, 338 | temb_channels=self.temb_ch, 339 | dropout=dropout) 340 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 341 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 342 | out_channels=block_in, 343 | temb_channels=self.temb_ch, 344 | dropout=dropout) 345 | 346 | # upsampling 347 | self.up = nn.ModuleList() 348 | for i_level in reversed(range(self.num_resolutions)): 349 | block = nn.ModuleList() 350 | attn = nn.ModuleList() 351 | block_out = ch*ch_mult[i_level] 352 | for i_block in range(self.num_res_blocks+1): 353 | block.append(ResnetBlock(in_channels=block_in, 354 | out_channels=block_out, 355 | temb_channels=self.temb_ch, 356 | dropout=dropout)) 357 | block_in = block_out 358 | if curr_res in attn_resolutions: 359 | attn.append(make_attn(block_in, attn_type=attn_type)) 360 | up = nn.Module() 361 | up.block = block 362 | up.attn = attn 363 | if i_level != 0: 364 | up.upsample = Upsample(block_in, resamp_with_conv) 365 | curr_res = curr_res * 2 366 | self.up.insert(0, up) # prepend to get consistent order 367 | 368 | # end 369 | self.norm_out = Normalize(block_in) 370 | self.conv_out = torch.nn.Conv2d(block_in, 371 | out_ch, 372 | kernel_size=3, 373 | stride=1, 374 | padding=1) 375 | 376 | def forward(self, z): 377 | #assert z.shape[1:] == self.z_shape[1:] 378 | self.last_z_shape = z.shape 379 | 380 | # timestep embedding 381 | temb = None 382 | 383 | # z to block_in 384 | h = self.conv_in(z) 385 | 386 | # middle 387 | h = self.mid.block_1(h, temb) 388 | h = self.mid.attn_1(h) 389 | h = self.mid.block_2(h, temb) 390 | 391 | # upsampling 392 | for i_level in reversed(range(self.num_resolutions)): 393 | for i_block in range(self.num_res_blocks+1): 394 | h = self.up[i_level].block[i_block](h, temb) 395 | if 
len(self.up[i_level].attn) > 0: 396 | h = self.up[i_level].attn[i_block](h) 397 | if i_level != 0: 398 | h = self.up[i_level].upsample(h) 399 | 400 | # end 401 | if self.give_pre_end: 402 | return h 403 | 404 | h = self.norm_out(h) 405 | h = nonlinearity(h) 406 | h = self.conv_out(h) 407 | if self.tanh_out: 408 | h = torch.tanh(h) 409 | return h 410 | 411 | 412 | class FrozenAutoencoderKL(nn.Module): 413 | def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215): 414 | super().__init__() 415 | print(f'Create autoencoder with scale_factor={scale_factor}') 416 | self.encoder = Encoder(**ddconfig) 417 | self.decoder = Decoder(**ddconfig) 418 | assert ddconfig["double_z"] 419 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 420 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 421 | self.embed_dim = embed_dim 422 | self.scale_factor = scale_factor 423 | m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu')) 424 | assert len(m) == 0 and len(u) == 0 425 | self.eval() 426 | self.requires_grad_(False) 427 | 428 | def encode_moments(self, x): 429 | h = self.encoder(x) 430 | moments = self.quant_conv(h) 431 | return moments 432 | 433 | def sample(self, moments): 434 | mean, logvar = torch.chunk(moments, 2, dim=1) 435 | logvar = torch.clamp(logvar, -30.0, 20.0) 436 | std = torch.exp(0.5 * logvar) 437 | z = mean + std * torch.randn_like(mean) 438 | z = self.scale_factor * z 439 | return z 440 | 441 | def encode(self, x): 442 | moments = self.encode_moments(x) 443 | z = self.sample(moments) 444 | return z 445 | 446 | def decode(self, z): 447 | z = (1. 
/ self.scale_factor) * z 448 | z = self.post_quant_conv(z) 449 | dec = self.decoder(z) 450 | return dec 451 | 452 | def forward(self, inputs, fn): 453 | if fn == 'encode_moments': 454 | return self.encode_moments(inputs) 455 | elif fn == 'encode': 456 | return self.encode(inputs) 457 | elif fn == 'decode': 458 | return self.decode(inputs) 459 | else: 460 | raise NotImplementedError 461 | 462 | 463 | def get_model(pretrained_path, scale_factor=0.18215): 464 | ddconfig = dict( 465 | double_z=True, 466 | z_channels=4, 467 | resolution=256, 468 | in_channels=3, 469 | out_ch=3, 470 | ch=128, 471 | ch_mult=[1, 2, 4, 4], 472 | num_res_blocks=2, 473 | attn_resolutions=[], 474 | dropout=0.0 475 | ) 476 | return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor) 477 | 478 | 479 | def main(): 480 | import torchvision.transforms as transforms 481 | from torchvision.utils import save_image 482 | import os 483 | from PIL import Image 484 | 485 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth') 486 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 487 | model = model.to(device) 488 | 489 | scale_factor = 0.18215 490 | T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()]) 491 | path = 'imgs' 492 | fnames = os.listdir(path) 493 | for fname in fnames: 494 | p = os.path.join(path, fname) 495 | img = Image.open(p) 496 | img = T(img) 497 | img = img * 2. - 1 498 | img = img[None, ...] 
499 | img = img.to(device) 500 | 501 | # with torch.cuda.amp.autocast(): 502 | # moments = model.encode_moments(img) 503 | # mean, logvar = torch.chunk(moments, 2, dim=1) 504 | # logvar = torch.clamp(logvar, -30.0, 20.0) 505 | # std = torch.exp(0.5 * logvar) 506 | # zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)] 507 | # recons = [model.decode(z) for z in zs] 508 | 509 | with torch.cuda.amp.autocast(): 510 | print('test encode & decode') 511 | recons = [model.decode(model.encode(img)) for _ in range(4)] 512 | 513 | out = torch.cat([img, *recons], dim=0) 514 | out = (out + 1) * 0.5 515 | save_image(out, f'recons_{fname}') 516 | 517 | 518 | if __name__ == "__main__": 519 | main() 520 | -------------------------------------------------------------------------------- /libs/clip.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import CLIPTokenizer, CLIPTextModel 3 | 4 | 5 | class AbstractEncoder(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def encode(self, *args, **kwargs): 10 | raise NotImplementedError 11 | 12 | 13 | class FrozenCLIPEmbedder(AbstractEncoder): 14 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 16 | super().__init__() 17 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 18 | self.transformer = CLIPTextModel.from_pretrained(version) 19 | self.device = device 20 | self.max_length = max_length 21 | self.freeze() 22 | 23 | def freeze(self): 24 | self.transformer = self.transformer.eval() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def forward(self, text): 29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 31 | tokens = 
batch_encoding["input_ids"].to(self.device) 32 | outputs = self.transformer(input_ids=tokens) 33 | 34 | z = outputs.last_hidden_state 35 | return z 36 | 37 | def encode(self, text): 38 | return self(text) 39 | -------------------------------------------------------------------------------- /libs/timm.py: -------------------------------------------------------------------------------- 1 | # code from timm 0.3.2 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import warnings 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 
29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def drop_path(x, drop_prob: float = 0., training: bool = False): 66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 67 | 68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 72 | 'survival rate' as the argument. 73 | 74 | """ 75 | if drop_prob == 0. 
or not training: 76 | return x 77 | keep_prob = 1 - drop_prob 78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 80 | random_tensor.floor_() # binarize 81 | output = x.div(keep_prob) * random_tensor 82 | return output 83 | 84 | 85 | class DropPath(nn.Module): 86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 87 | """ 88 | def __init__(self, drop_prob=None): 89 | super(DropPath, self).__init__() 90 | self.drop_prob = drop_prob 91 | 92 | def forward(self, x): 93 | return drop_path(x, self.drop_prob, self.training) 94 | 95 | 96 | class Mlp(nn.Module): 97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.fc2 = nn.Linear(hidden_features, out_features) 104 | self.drop = nn.Dropout(drop) 105 | 106 | def forward(self, x): 107 | x = self.fc1(x) 108 | x = self.act(x) 109 | x = self.drop(x) 110 | x = self.fc2(x) 111 | x = self.drop(x) 112 | return x 113 | -------------------------------------------------------------------------------- /libs/uvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, 
max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # 
B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplementedError(f'unknown ATTENTION_MODE: {ATTENTION_MODE}') 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 141 | use_checkpoint=False, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.num_classes = num_classes 145 | self.in_chans = in_chans 146 | 147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 148 | num_patches = (img_size // patch_size) ** 2 149 | 150 | self.time_embed = nn.Sequential( 151 | nn.Linear(embed_dim, 4 * embed_dim), 152 | nn.SiLU(), 153 | nn.Linear(4 * embed_dim, embed_dim), 154 | ) if mlp_time_embed else nn.Identity() 155 | 156 | if self.num_classes > 0: 157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 158 | self.extras = 2 159 | else: 160 | self.extras = 1 161 | 162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 163 | 164 | self.in_blocks = nn.ModuleList([ 165 | Block( 166 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 167 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 168 | for _ in range(depth // 2)]) 169 | 170 | self.mid_block = Block( 171 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 172 | norm_layer=norm_layer, 
use_checkpoint=use_checkpoint) 173 | 174 | self.out_blocks = nn.ModuleList([ 175 | Block( 176 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 177 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 178 | for _ in range(depth // 2)]) 179 | 180 | self.norm = norm_layer(embed_dim) 181 | self.patch_dim = patch_size ** 2 * in_chans 182 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 183 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 184 | 185 | trunc_normal_(self.pos_embed, std=.02) 186 | self.apply(self._init_weights) 187 | 188 | def _init_weights(self, m): 189 | if isinstance(m, nn.Linear): 190 | trunc_normal_(m.weight, std=.02) 191 | if isinstance(m, nn.Linear) and m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | elif isinstance(m, nn.LayerNorm): 194 | nn.init.constant_(m.bias, 0) 195 | nn.init.constant_(m.weight, 1.0) 196 | 197 | @torch.jit.ignore 198 | def no_weight_decay(self): 199 | return {'pos_embed'} 200 | 201 | def forward(self, x, timesteps, y=None): 202 | x = self.patch_embed(x) 203 | B, L, D = x.shape 204 | 205 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 206 | time_token = time_token.unsqueeze(dim=1) 207 | x = torch.cat((time_token, x), dim=1) 208 | if y is not None: 209 | label_emb = self.label_emb(y) 210 | label_emb = label_emb.unsqueeze(dim=1) 211 | x = torch.cat((label_emb, x), dim=1) 212 | x = x + self.pos_embed 213 | 214 | skips = [] 215 | for blk in self.in_blocks: 216 | x = blk(x) 217 | skips.append(x) 218 | 219 | x = self.mid_block(x) 220 | 221 | for blk in self.out_blocks: 222 | x = blk(x, skips.pop()) 223 | 224 | x = self.norm(x) 225 | x = self.decoder_pred(x) 226 | assert x.size(1) == self.extras + L 227 | x = x[:, self.extras:, :] 228 | x = unpatchify(x, self.in_chans) 229 | x = self.final_layer(x) 230 | return x 231 | 
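The long skip connections in `UViT` above are easy to misread from the flattened code, so here is a minimal pure-Python sketch (no PyTorch) of the bookkeeping done by `UViT.__init__` and `UViT.forward`. `uvit_layout` is a hypothetical helper, not part of the repository; it assumes the same rules as the code above: one time token plus an optional label token are prepended to the patch tokens, `depth // 2` blocks sit on each side of `mid_block`, and `forward` pairs input and output blocks through a stack (`skips.append` / `skips.pop`). The example sizes are arbitrary and not taken from the repository configs.

```python
def uvit_layout(img_size, patch_size, depth, num_classes=-1):
    """Hypothetical helper mirroring the token/block bookkeeping of UViT.

    Returns (seq_len, num_blocks, skip_pairs), where skip_pairs lists
    (out_block_index, in_block_index) in the order forward() consumes them.
    """
    # Patch tokens produced by PatchEmbed (a conv with stride = patch_size).
    num_patches = (img_size // patch_size) ** 2
    # One time token, plus one label token when the model is class-conditional.
    extras = 2 if num_classes > 0 else 1
    seq_len = extras + num_patches

    n_half = depth // 2
    # in_blocks + mid_block + out_blocks
    num_blocks = n_half + 1 + n_half

    # forward() pushes every in_block output onto a list and pops one per
    # out_block, so the pairing is last-in, first-out.
    skips = []
    for i in range(n_half):
        skips.append(i)
    skip_pairs = [(j, skips.pop()) for j in range(n_half)]
    return seq_len, num_blocks, skip_pairs


if __name__ == "__main__":
    print(uvit_layout(img_size=32, patch_size=2, depth=12))
    # (257, 13, [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1), (5, 0)])
```

The first output block is fused (via `skip_linear`) with the deepest input block and the last output block with the shallowest one, as in a U-Net. Because `depth // 2` rounds down, `depth=12` and `depth=13` both build 13 `Block` modules.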
-------------------------------------------------------------------------------- /libs/uvit_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 
29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 
| x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplementedError 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0
and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False, 141 | clip_dim=768, num_clip_token=77, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.in_chans = in_chans 145 | 146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | num_patches = (img_size // patch_size) ** 2 148 | 149 | self.time_embed = nn.Sequential( 150 | nn.Linear(embed_dim, 4 * embed_dim), 151 | nn.SiLU(), 152 | nn.Linear(4 * embed_dim, embed_dim), 153 | ) if mlp_time_embed else nn.Identity() 154 | 155 | self.context_embed = nn.Linear(clip_dim, embed_dim) 156 | 157 | self.extras = 1 + num_clip_token 158 | 159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 160 | 161 | self.in_blocks = nn.ModuleList([ 162 | Block( 163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 165 | for _ in range(depth // 2)]) 166 | 167 | self.mid_block = Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 170 | 171 | self.out_blocks = nn.ModuleList([ 172 | Block( 173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 175 | for _ in range(depth // 2)]) 176 | 177 | self.norm = norm_layer(embed_dim) 178 | self.patch_dim = patch_size ** 2 * in_chans 179 | self.decoder_pred = nn.Linear(embed_dim, 
self.patch_dim, bias=True) 180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 181 | 182 | trunc_normal_(self.pos_embed, std=.02) 183 | self.apply(self._init_weights) 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | trunc_normal_(m.weight, std=.02) 188 | if isinstance(m, nn.Linear) and m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.LayerNorm): 191 | nn.init.constant_(m.bias, 0) 192 | nn.init.constant_(m.weight, 1.0) 193 | 194 | @torch.jit.ignore 195 | def no_weight_decay(self): 196 | return {'pos_embed'} 197 | 198 | def forward(self, x, timesteps, context): 199 | x = self.patch_embed(x) 200 | B, L, D = x.shape 201 | 202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 203 | time_token = time_token.unsqueeze(dim=1) 204 | context_token = self.context_embed(context) 205 | x = torch.cat((time_token, context_token, x), dim=1) 206 | x = x + self.pos_embed 207 | 208 | skips = [] 209 | for blk in self.in_blocks: 210 | x = blk(x) 211 | skips.append(x) 212 | 213 | x = self.mid_block(x) 214 | 215 | for blk in self.out_blocks: 216 | x = blk(x, skips.pop()) 217 | 218 | x = self.norm(x) 219 | x = self.decoder_pred(x) 220 | assert x.size(1) == self.extras + L 221 | x = x[:, self.extras:, :] 222 | x = unpatchify(x, self.in_chans) 223 | x = self.final_layer(x) 224 | return x 225 | -------------------------------------------------------------------------------- /sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baofff/U-ViT/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/sample.png -------------------------------------------------------------------------------- /sample_t2i_discrete.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | from torch import multiprocessing as mp 4 | import 
accelerate 5 | import utils 6 | from datasets import get_dataset 7 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 8 | from absl import logging 9 | import builtins 10 | import einops 11 | import libs.autoencoder 12 | import libs.clip 13 | from torchvision.utils import save_image 14 | 15 | 16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 17 | _betas = ( 18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 19 | ) 20 | return _betas.numpy() 21 | 22 | 23 | def evaluate(config): 24 | if config.get('benchmark', False): 25 | torch.backends.cudnn.benchmark = True 26 | torch.backends.cudnn.deterministic = False 27 | 28 | mp.set_start_method('spawn') 29 | accelerator = accelerate.Accelerator() 30 | device = accelerator.device 31 | accelerate.utils.set_seed(config.seed, device_specific=True) 32 | logging.info(f'Process {accelerator.process_index} using device: {device}') 33 | 34 | config.mixed_precision = accelerator.mixed_precision 35 | config = ml_collections.FrozenConfigDict(config) 36 | if accelerator.is_main_process: 37 | utils.set_logger(log_level='info') 38 | else: 39 | utils.set_logger(log_level='error') 40 | builtins.print = lambda *args: None 41 | 42 | dataset = get_dataset(**config.dataset) 43 | 44 | with open(config.input_path, 'r') as f: 45 | prompts = f.read().strip().split('\n') 46 | 47 | print(prompts) 48 | 49 | clip = libs.clip.FrozenCLIPEmbedder() 50 | clip.eval() 51 | clip.to(device) 52 | 53 | contexts = clip.encode(prompts) 54 | 55 | nnet = utils.get_nnet(**config.nnet) 56 | nnet = accelerator.prepare(nnet) 57 | logging.info(f'load nnet from {config.nnet_path}') 58 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 59 | nnet.eval() 60 | 61 | def cfg_nnet(x, timesteps, context): 62 | _cond = nnet(x, timesteps, context=context) 63 | if config.sample.scale == 0: 64 | return _cond 65 | _empty_context = 
torch.tensor(dataset.empty_context, device=device) 66 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) 67 | _uncond = nnet(x, timesteps, context=_empty_context) 68 | return _cond + config.sample.scale * (_cond - _uncond) 69 | 70 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 71 | autoencoder.to(device) 72 | 73 | @torch.cuda.amp.autocast() 74 | def encode(_batch): 75 | return autoencoder.encode(_batch) 76 | 77 | @torch.cuda.amp.autocast() 78 | def decode(_batch): 79 | return autoencoder.decode(_batch) 80 | 81 | _betas = stable_diffusion_beta_schedule() 82 | N = len(_betas) 83 | 84 | logging.info(config.sample) 85 | logging.info(f'mixed_precision={config.mixed_precision}') 86 | logging.info(f'N={N}') 87 | 88 | z_init = torch.randn(contexts.size(0), *config.z_shape, device=device) 89 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 90 | 91 | def model_fn(x, t_continuous): 92 | t = t_continuous * N 93 | return cfg_nnet(x, t, context=contexts) 94 | 95 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 96 | z = dpm_solver.sample(z_init, steps=config.sample.sample_steps, eps=1. / N, T=1.) 
97 | samples = dataset.unpreprocess(decode(z)) 98 | 99 | os.makedirs(config.output_path, exist_ok=True) 100 | for sample, prompt in zip(samples, prompts): 101 | save_image(sample, os.path.join(config.output_path, f"{prompt}.png")) 102 | 103 | 104 | 105 | from absl import flags 106 | from absl import app 107 | from ml_collections import config_flags 108 | import os 109 | 110 | 111 | FLAGS = flags.FLAGS 112 | config_flags.DEFINE_config_file( 113 | "config", None, "Training configuration.", lock_config=False) 114 | flags.mark_flags_as_required(["config"]) 115 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 116 | flags.DEFINE_string("output_path", None, "The path to output images.") 117 | flags.DEFINE_string("input_path", None, "The path to input texts.") 118 | 119 | 120 | def main(argv): 121 | config = FLAGS.config 122 | config.nnet_path = FLAGS.nnet_path 123 | config.output_path = FLAGS.output_path 124 | config.input_path = FLAGS.input_path 125 | evaluate(config) 126 | 127 | 128 | if __name__ == "__main__": 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /scripts/extract_empty_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | '', 14 | ] 15 | 16 | device = 'cuda' 17 | clip = libs.clip.FrozenCLIPEmbedder() 18 | clip.eval() 19 | clip.to(device) 20 | 21 | save_dir = f'assets/datasets/coco256_features' 22 | latent = clip.encode(prompts) 23 | print(latent.shape) 24 | c = latent[0].detach().cpu().numpy() 25 | np.save(os.path.join(save_dir, f'empty_context.npy'), c) 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- 
/scripts/extract_imagenet_feature.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | from datasets import ImageNet 5 | from torch.utils.data import DataLoader 6 | from libs.autoencoder import get_model 7 | import argparse 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | 12 | 13 | def main(resolution=256): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('path') 16 | args = parser.parse_args() 17 | 18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False) 19 | train_dataset = dataset.get_split(split='train', labeled=True) 20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False, 21 | num_workers=8, pin_memory=True, persistent_workers=True) 22 | 23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth') 24 | model = nn.DataParallel(model) 25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model.to(device) 27 | 28 | # features = [] 29 | # labels = [] 30 | 31 | idx = 0 32 | for batch in tqdm(train_dataset_loader): 33 | img, label = batch 34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 35 | img = img.to(device) 36 | moments = model(img, fn='encode_moments') 37 | moments = moments.detach().cpu().numpy() 38 | 39 | label = torch.cat([label, label], dim=0) 40 | label = label.detach().cpu().numpy() 41 | 42 | for moment, lb in zip(moments, label): 43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb)) 44 | idx += 1 45 | 46 | print(f'save {idx} files') 47 | 48 | # features = np.concatenate(features, axis=0) 49 | # labels = np.concatenate(labels, axis=0) 50 | # print(f'features.shape={features.shape}') 51 | # print(f'labels.shape={labels.shape}') 52 | # np.save(f'imagenet{resolution}_features.npy', features) 53 | # np.save(f'imagenet{resolution}_labels.npy', labels) 54 | 
55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /scripts/extract_mscoco_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(resolution=256): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--split', default='train') 14 | args = parser.parse_args() 15 | print(args) 16 | 17 | 18 | if args.split == "train": 19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014', 20 | annFile='assets/datasets/coco/annotations/captions_train2014.json', 21 | size=resolution) 22 | save_dir = f'assets/datasets/coco{resolution}_features/train' 23 | elif args.split == "val": 24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014', 25 | annFile='assets/datasets/coco/annotations/captions_val2014.json', 26 | size=resolution) 27 | save_dir = f'assets/datasets/coco{resolution}_features/val' 28 | else: 29 | raise NotImplementedError("ERROR!") 30 | 31 | device = "cuda" 32 | os.makedirs(save_dir) 33 | 34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth') 35 | autoencoder.to(device) 36 | clip = libs.clip.FrozenCLIPEmbedder() 37 | clip.eval() 38 | clip.to(device) 39 | 40 | with torch.no_grad(): 41 | for idx, data in tqdm(enumerate(datas)): 42 | x, captions = data 43 | 44 | if len(x.shape) == 3: 45 | x = x[None, ...] 
46 | x = torch.tensor(x, device=device) 47 | moments = autoencoder(x, fn='encode_moments').squeeze(0) 48 | moments = moments.detach().cpu().numpy() 49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments) 50 | 51 | latent = clip.encode(captions) 52 | for i in range(len(latent)): 53 | c = latent[i].detach().cpu().numpy() 54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/extract_test_prompt_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | 'A green train is coming down the tracks.', 14 | 'A group of skiers are preparing to ski down a mountain.', 15 | 'A small kitchen with a low ceiling.', 16 | 'A group of elephants walking in muddy water.', 17 | 'A living area with a television and a table.', 18 | 'A road with traffic lights, street lights and cars.', 19 | 'A bus driving in a city area with traffic signs.', 20 | 'A bus pulls over to the curb close to an intersection.', 21 | 'A group of people are walking and one is holding an umbrella.', 22 | 'A baseball player taking a swing at an incoming ball.', 23 | 'A city street line with brick buildings and trees.', 24 | 'A close up of a plate of broccoli and sauce.', 25 | ] 26 | 27 | device = 'cuda' 28 | clip = libs.clip.FrozenCLIPEmbedder() 29 | clip.eval() 30 | clip.to(device) 31 | 32 | save_dir = f'assets/datasets/coco256_features/run_vis' 33 | latent = clip.encode(prompts) 34 | for i in range(len(latent)): 35 | c = latent[i].detach().cpu().numpy() 36 | np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c)) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | 
-------------------------------------------------------------------------------- /sde.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from absl import logging 4 | import numpy as np 5 | import math 6 | from tqdm import tqdm 7 | 8 | 9 | def get_sde(name, **kwargs): 10 | if name == 'vpsde': 11 | return VPSDE(**kwargs) 12 | elif name == 'vpsde_cosine': 13 | return VPSDECosine(**kwargs) 14 | else: 15 | raise NotImplementedError 16 | 17 | 18 | def stp(s, ts: torch.Tensor): # scalar tensor product 19 | if isinstance(s, np.ndarray): 20 | s = torch.from_numpy(s).type_as(ts) 21 | extra_dims = (1,) * (ts.dim() - 1) 22 | return s.view(-1, *extra_dims) * ts 23 | 24 | 25 | def mos(a, start_dim=1): # mean of square 26 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1) 27 | 28 | 29 | def duplicate(tensor, *size): 30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape) 31 | 32 | 33 | class SDE(object): 34 | r""" 35 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1 36 | f(x, t) is the drift 37 | g(t) is the diffusion 38 | """ 39 | def drift(self, x, t): 40 | raise NotImplementedError 41 | 42 | def diffusion(self, t): 43 | raise NotImplementedError 44 | 45 | def cum_beta(self, t): # the variance of xt|x0 46 | raise NotImplementedError 47 | 48 | def cum_alpha(self, t): 49 | raise NotImplementedError 50 | 51 | def snr(self, t): # signal noise ratio 52 | raise NotImplementedError 53 | 54 | def nsr(self, t): # noise signal ratio 55 | raise NotImplementedError 56 | 57 | def marginal_prob(self, x0, t): # the mean and std of q(xt|x0) 58 | alpha = self.cum_alpha(t) 59 | beta = self.cum_beta(t) 60 | mean = stp(alpha ** 0.5, x0) # E[xt|x0] 61 | std = beta ** 0.5 # Cov[xt|x0] ** 0.5 62 | return mean, std 63 | 64 | def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform 65 | t = torch.rand(x0.shape[0], device=x0.device) * (1. 
- t_init) + t_init 66 | mean, std = self.marginal_prob(x0, t) 67 | eps = torch.randn_like(x0) 68 | xt = mean + stp(std, eps) 69 | return t, eps, xt 70 | 71 | 72 | class VPSDE(SDE): 73 | def __init__(self, beta_min=0.1, beta_max=20): 74 | # 0 <= t <= 1 75 | self.beta_0 = beta_min 76 | self.beta_1 = beta_max 77 | 78 | def drift(self, x, t): 79 | return -0.5 * stp(self.squared_diffusion(t), x) 80 | 81 | def diffusion(self, t): 82 | return self.squared_diffusion(t) ** 0.5 83 | 84 | def squared_diffusion(self, t): # beta(t) 85 | return self.beta_0 + t * (self.beta_1 - self.beta_0) 86 | 87 | def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau 88 | return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5 89 | 90 | def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I 91 | return 1. - self.skip_alpha(s, t) 92 | 93 | def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs 94 | x = -self.squared_diffusion_integral(s, t) 95 | return x.exp() 96 | 97 | def cum_beta(self, t): 98 | return self.skip_beta(0, t) 99 | 100 | def cum_alpha(self, t): 101 | return self.skip_alpha(0, t) 102 | 103 | def nsr(self, t): 104 | return self.squared_diffusion_integral(0, t).expm1() 105 | 106 | def snr(self, t): 107 | return 1. 
/ self.nsr(t) 108 | 109 | def __str__(self): 110 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}' 111 | 112 | def __repr__(self): 113 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}' 114 | 115 | 116 | class VPSDECosine(SDE): 117 | r""" 118 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1 119 | f(x, t) is the drift 120 | g(t) is the diffusion 121 | """ 122 | def __init__(self, s=0.008): 123 | self.s = s 124 | self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2 125 | self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2 126 | 127 | def drift(self, x, t): 128 | ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2 129 | return stp(ft, x) 130 | 131 | def diffusion(self, t): 132 | return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5 133 | 134 | def cum_beta(self, t): # the variance of xt|x0 135 | return 1 - self.cum_alpha(t) 136 | 137 | def cum_alpha(self, t): 138 | return self.F(t) / self.F0 139 | 140 | def snr(self, t): # signal noise ratio 141 | Ft = self.F(t) 142 | return Ft / (self.F0 - Ft) 143 | 144 | def nsr(self, t): # noise signal ratio 145 | Ft = self.F(t) 146 | return self.F0 / Ft - 1 147 | 148 | def __str__(self): 149 | return 'vpsde_cosine' 150 | 151 | def __repr__(self): 152 | return 'vpsde_cosine' 153 | 154 | 155 | class ScoreModel(object): 156 | r""" 157 | The forward process is q(x_[0,T]) 158 | """ 159 | 160 | def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1): 161 | assert T == 1 162 | self.nnet = nnet 163 | self.pred = pred 164 | self.sde = sde 165 | self.T = T 166 | print(f'ScoreModel with pred={pred}, sde={sde}, T={T}') 167 | 168 | def predict(self, xt, t, **kwargs): 169 | if not isinstance(t, torch.Tensor): 170 | t = torch.tensor(t) 171 | t = t.to(xt.device) 172 | if t.dim() == 0: 173 | t = duplicate(t, xt.size(0)) 174 | return self.nnet(xt, t * 999, **kwargs) # follow SDE 175 | 176 | def noise_pred(self, xt, t, **kwargs): 177 | pred 
= self.predict(xt, t, **kwargs) 178 | if self.pred == 'noise_pred': 179 | noise_pred = pred 180 | elif self.pred == 'x0_pred': 181 | noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt) 182 | else: 183 | raise NotImplementedError 184 | return noise_pred 185 | 186 | def x0_pred(self, xt, t, **kwargs): 187 | pred = self.predict(xt, t, **kwargs) 188 | if self.pred == 'noise_pred': 189 | x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred) 190 | elif self.pred == 'x0_pred': 191 | x0_pred = pred 192 | else: 193 | raise NotImplementedError 194 | return x0_pred 195 | 196 | def score(self, xt, t, **kwargs): 197 | cum_beta = self.sde.cum_beta(t) 198 | noise_pred = self.noise_pred(xt, t, **kwargs) 199 | return stp(-cum_beta.rsqrt(), noise_pred) 200 | 201 | 202 | class ReverseSDE(object): 203 | r""" 204 | dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw 205 | """ 206 | def __init__(self, score_model): 207 | self.sde = score_model.sde # the forward sde 208 | self.score_model = score_model 209 | 210 | def drift(self, x, t, **kwargs): 211 | drift = self.sde.drift(x, t) # f(x, t) 212 | diffusion = self.sde.diffusion(t) # g(t) 213 | score = self.score_model.score(x, t, **kwargs) 214 | return drift - stp(diffusion ** 2, score) 215 | 216 | def diffusion(self, t): 217 | return self.sde.diffusion(t) 218 | 219 | 220 | class ODE(object): 221 | r""" 222 | dx = [f(x, t) - 0.5 g(t)^2 s(x, t)] dt 223 | """ 224 | 225 | def __init__(self, score_model): 226 | self.sde = score_model.sde # the forward sde 227 | self.score_model = score_model 228 | 229 | def drift(self, x, t, **kwargs): 230 | drift = self.sde.drift(x, t) # f(x, t) 231 | diffusion = self.sde.diffusion(t) # g(t) 232 | score = self.score_model.score(x, t, **kwargs) 233 | return drift - 0.5 * stp(diffusion ** 2, score) 234 | 235 | def diffusion(self, t): 236 | return 0 237 | 238 | 239 | def dct2str(dct): 240 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 241 | 242 |
243 | @ torch.no_grad() 244 | def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs): 245 | r""" 246 | The Euler Maruyama sampler for reverse SDE / ODE 247 | See `Score-Based Generative Modeling through Stochastic Differential Equations` 248 | """ 249 | assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE) 250 | print(f"euler_maruyama with sample_steps={sample_steps}") 251 | timesteps = np.append(0., np.linspace(eps, T, sample_steps)) 252 | timesteps = torch.tensor(timesteps).to(x_init) 253 | x = x_init 254 | if trace is not None: 255 | trace.append(x) 256 | for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'): 257 | drift = rsde.drift(x, t, **kwargs) 258 | diffusion = rsde.diffusion(t) 259 | dt = s - t 260 | mean = x + drift * dt 261 | sigma = diffusion * (-dt).sqrt() 262 | x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean 263 | if trace is not None: 264 | trace.append(x) 265 | statistics = dict(s=s, t=t, sigma=sigma.item()) 266 | logging.debug(dct2str(statistics)) 267 | return x 268 | 269 | 270 | def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs): 271 | t, noise, xt = score_model.sde.sample(x0) 272 | if pred == 'noise_pred': 273 | noise_pred = score_model.noise_pred(xt, t, **kwargs) 274 | return mos(noise - noise_pred) 275 | elif pred == 'x0_pred': 276 | x0_pred = score_model.x0_pred(xt, t, **kwargs) 277 | return mos(x0 - x0_pred) 278 | else: 279 | raise NotImplementedError(pred) 280 | -------------------------------------------------------------------------------- /skip_im.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baofff/U-ViT/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/skip_im.png -------------------------------------------------------------------------------- /tools/fid_score.py: 
-------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real-world 13 | samples respectively. 14 | 15 | See --help for further details. 16 | 17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of TensorFlow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License.
33 | """ 34 | import os 35 | import pathlib 36 | 37 | import numpy as np 38 | import torch 39 | import torchvision.transforms as TF 40 | from PIL import Image 41 | from scipy import linalg 42 | from torch.nn.functional import adaptive_avg_pool2d 43 | 44 | try: 45 | from tqdm import tqdm 46 | except ImportError: 47 | # If tqdm is not available, provide a mock version of it 48 | def tqdm(x): 49 | return x 50 | 51 | from .inception import InceptionV3 52 | 53 | 54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 55 | 'tif', 'tiff', 'webp'} 56 | 57 | 58 | class ImagePathDataset(torch.utils.data.Dataset): 59 | def __init__(self, files, transforms=None): 60 | self.files = files 61 | self.transforms = transforms 62 | 63 | def __len__(self): 64 | return len(self.files) 65 | 66 | def __getitem__(self, i): 67 | path = self.files[i] 68 | img = Image.open(path).convert('RGB') 69 | if self.transforms is not None: 70 | img = self.transforms(img) 71 | return img 72 | 73 | 74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): 75 | """Calculates the activations of the pool_3 layer for all images. 76 | 77 | Params: 78 | -- files : List of image file paths 79 | -- model : Instance of inception model 80 | -- batch_size : Batch size of images for the model to process at once. 81 | The dataloader is built with drop_last=False, so if the 82 | number of samples is not a multiple of the batch size, 83 | the final batch is simply smaller and no samples are 84 | ignored. 85 | -- dims : Dimensionality of features returned by Inception 86 | -- device : Device to run calculations 87 | -- num_workers : Number of parallel dataloader workers 88 | 89 | Returns: 90 | -- A numpy array of dimension (num images, dims) that contains the 91 | activations of the selected Inception layer for each input 92 | image.
93 | """ 94 | model.eval() 95 | 96 | if batch_size > len(files): 97 | print(('Warning: batch size is bigger than the data size. ' 98 | 'Setting batch size to data size')) 99 | batch_size = len(files) 100 | 101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 102 | dataloader = torch.utils.data.DataLoader(dataset, 103 | batch_size=batch_size, 104 | shuffle=False, 105 | drop_last=False, 106 | num_workers=num_workers) 107 | 108 | pred_arr = np.empty((len(files), dims)) 109 | 110 | start_idx = 0 111 | 112 | for batch in tqdm(dataloader): 113 | batch = batch.to(device) 114 | 115 | with torch.no_grad(): 116 | pred = model(batch)[0] 117 | 118 | # If the model output is not spatially 1x1, apply global spatial average pooling. 119 | # This happens if you choose a dimensionality not equal to 2048. 120 | if pred.size(2) != 1 or pred.size(3) != 1: 121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 122 | 123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 124 | 125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 126 | 127 | start_idx = start_idx + pred.shape[0] 128 | 129 | return pred_arr 130 | 131 | 132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 133 | """Numpy implementation of the Frechet Distance. 134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 135 | and X_2 ~ N(mu_2, C_2) is 136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 137 | 138 | Stable version by Dougal J. Sutherland. 139 | 140 | Params: 141 | -- mu1 : Numpy array containing the sample mean over activations of a 142 | layer of the inception net (e.g. the mean of the array returned by 143 | 'get_activations') for generated samples. 144 | -- mu2 : The sample mean over activations, precalculated on a 145 | representative data set. 146 | -- sigma1: The covariance matrix over activations for generated samples. 147 | -- sigma2: The covariance matrix over activations, precalculated on a 148 | representative data set.
149 | 150 | Returns: 151 | -- : The Frechet Distance. 152 | """ 153 | 154 | mu1 = np.atleast_1d(mu1) 155 | mu2 = np.atleast_1d(mu2) 156 | 157 | sigma1 = np.atleast_2d(sigma1) 158 | sigma2 = np.atleast_2d(sigma2) 159 | 160 | assert mu1.shape == mu2.shape, \ 161 | 'Training and test mean vectors have different lengths' 162 | assert sigma1.shape == sigma2.shape, \ 163 | 'Training and test covariances have different dimensions' 164 | 165 | diff = mu1 - mu2 166 | 167 | # Product might be almost singular 168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 169 | if not np.isfinite(covmean).all(): 170 | msg = ('fid calculation produces singular product; ' 171 | 'adding %s to diagonal of cov estimates') % eps 172 | print(msg) 173 | offset = np.eye(sigma1.shape[0]) * eps 174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 175 | 176 | # Numerical error might give slight imaginary component 177 | if np.iscomplexobj(covmean): 178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 179 | m = np.max(np.abs(covmean.imag)) 180 | raise ValueError('Imaginary component {}'.format(m)) 181 | covmean = covmean.real 182 | 183 | tr_covmean = np.trace(covmean) 184 | 185 | return (diff.dot(diff) + np.trace(sigma1) 186 | + np.trace(sigma2) - 2 * tr_covmean) 187 | 188 | 189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 190 | device='cpu', num_workers=8): 191 | """Calculation of the statistics used by the FID. 192 | Params: 193 | -- files : List of image file paths 194 | -- model : Instance of inception model 195 | -- batch_size : The list of image files is split into batches with 196 | batch size batch_size. A reasonable batch size 197 | depends on the hardware.
198 | -- dims : Dimensionality of features returned by Inception 199 | -- device : Device to run calculations 200 | -- num_workers : Number of parallel dataloader workers 201 | 202 | Returns: 203 | -- mu : The mean over samples of the activations of the pool_3 layer of 204 | the inception model. 205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 206 | the inception model. 207 | """ 208 | act = get_activations(files, model, batch_size, dims, device, num_workers) 209 | mu = np.mean(act, axis=0) 210 | sigma = np.cov(act, rowvar=False) 211 | return mu, sigma 212 | 213 | 214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 215 | if path.endswith('.npz'): 216 | with np.load(path) as f: 217 | m, s = f['mu'][:], f['sigma'][:] 218 | else: 219 | path = pathlib.Path(path) 220 | files = sorted([file for ext in IMAGE_EXTENSIONS 221 | for file in path.glob('*.{}'.format(ext))]) 222 | m, s = calculate_activation_statistics(files, model, batch_size, 223 | dims, device, num_workers) 224 | 225 | return m, s 226 | 227 | 228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8): 229 | if device is None: 230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 231 | else: 232 | device = torch.device(device) 233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 234 | model = InceptionV3([block_idx]).to(device) 235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers) 236 | np.savez(out_path, mu=m1, sigma=s1) 237 | 238 | 239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8): 240 | """Calculates the FID of two paths""" 241 | if device is None: 242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 243 | else: 244 | device = torch.device(device) 245 | 246 | for p in paths: 247 | if not os.path.exists(p): 248 | raise RuntimeError('Invalid path: %s' % p) 
249 | 250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 251 | 252 | model = InceptionV3([block_idx]).to(device) 253 | 254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 255 | dims, device, num_workers) 256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 257 | dims, device, num_workers) 258 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 259 | 260 | return fid_value 261 | -------------------------------------------------------------------------------- /tools/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to PyTorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output block indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling features 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 |
Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight initialization if supported by the torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model.
192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zeros in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC,
self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zeros in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zeros in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 |
return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sde 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | from datasets import get_dataset 6 | from torchvision.utils import make_grid, save_image 7 | import utils 8 | import einops 9 | from torch.utils._pytree import tree_map 10 | import accelerate 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 14 | import tempfile 15 | from tools.fid_score import calculate_fid_given_paths 16 | from absl import logging 17 | import builtins 18 | import os 19 | import wandb 20 | 21 | 22 | def train(config): 23 | if config.get('benchmark', False): 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.deterministic = False 26 | 27 | mp.set_start_method('spawn') 28 | accelerator = accelerate.Accelerator() 29 | device = accelerator.device 30 | accelerate.utils.set_seed(config.seed, device_specific=True) 31 | logging.info(f'Process {accelerator.process_index} using device: {device}') 32 | 33 | config.mixed_precision = accelerator.mixed_precision 34 | config = ml_collections.FrozenConfigDict(config) 35 | 36 | assert config.train.batch_size % accelerator.num_processes == 0 37 | mini_batch_size = config.train.batch_size // accelerator.num_processes 38 | 39 | if accelerator.is_main_process: 40 | os.makedirs(config.ckpt_root, exist_ok=True) 41 | os.makedirs(config.sample_dir, exist_ok=True) 42 | accelerator.wait_for_everyone() 43 | if accelerator.is_main_process: 44 | wandb.init(dir=os.path.abspath(config.workdir), 
project=f'uvit_{config.dataset.name}', config=config.to_dict(), 45 | name=config.hparams, job_type='train', mode='offline') 46 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log')) 47 | logging.info(config) 48 | else: 49 | utils.set_logger(log_level='error') 50 | builtins.print = lambda *args: None 51 | 52 | dataset = get_dataset(**config.dataset) 53 | assert os.path.exists(dataset.fid_stat) 54 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond') 55 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True, 56 | num_workers=8, pin_memory=True, persistent_workers=True) 57 | 58 | train_state = utils.initialize_train_state(config, device) 59 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare( 60 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader) 61 | lr_scheduler = train_state.lr_scheduler 62 | train_state.resume(config.ckpt_root) 63 | 64 | def get_data_generator(): 65 | while True: 66 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'): 67 | yield data 68 | 69 | data_generator = get_data_generator() 70 | 71 | 72 | # set the score_model to train 73 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 74 | score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=sde.VPSDE()) 75 | 76 | 77 | def train_step(_batch): 78 | _metrics = dict() 79 | optimizer.zero_grad() 80 | if config.train.mode == 'uncond': 81 | loss = sde.LSimple(score_model, _batch, pred=config.pred) 82 | elif config.train.mode == 'cond': 83 | loss = sde.LSimple(score_model, _batch[0], pred=config.pred, y=_batch[1]) 84 | else: 85 | raise NotImplementedError(config.train.mode) 86 | _metrics['loss'] = accelerator.gather(loss.detach()).mean() 87 | accelerator.backward(loss.mean()) 88 | if 'grad_clip' in config and config.grad_clip > 0: 89 | 
accelerator.clip_grad_norm_(nnet.parameters(), max_norm=config.grad_clip) 90 | optimizer.step() 91 | lr_scheduler.step() 92 | train_state.ema_update(config.get('ema_rate', 0.9999)) 93 | train_state.step += 1 94 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics) 95 | 96 | 97 | def eval_step(n_samples, sample_steps, algorithm): 98 | logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, ' 99 | f'mini_batch_size={config.sample.mini_batch_size}') 100 | 101 | def sample_fn(_n_samples): 102 | _x_init = torch.randn(_n_samples, *dataset.data_shape, device=device) 103 | if config.train.mode == 'uncond': 104 | kwargs = dict() 105 | elif config.train.mode == 'cond': 106 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 107 | else: 108 | raise NotImplementedError 109 | 110 | if algorithm == 'euler_maruyama_sde': 111 | return sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _x_init, sample_steps, **kwargs) 112 | elif algorithm == 'euler_maruyama_ode': 113 | return sde.euler_maruyama(sde.ODE(score_model_ema), _x_init, sample_steps, **kwargs) 114 | elif algorithm == 'dpm_solver': 115 | noise_schedule = NoiseScheduleVP(schedule='linear') 116 | model_fn = model_wrapper( 117 | score_model_ema.noise_pred, 118 | noise_schedule, 119 | time_input_type='0', 120 | model_kwargs=kwargs 121 | ) 122 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 123 | return dpm_solver.sample( 124 | _x_init, 125 | steps=sample_steps, 126 | eps=1e-4, 127 | adaptive_step_size=False, 128 | fast_version=True, 129 | ) 130 | else: 131 | raise NotImplementedError 132 | 133 | with tempfile.TemporaryDirectory() as temp_path: 134 | path = config.sample.path or temp_path 135 | if accelerator.is_main_process: 136 | os.makedirs(path, exist_ok=True) 137 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 138 | 139 | _fid = 0 140 | if accelerator.is_main_process: 141 | 
_fid = calculate_fid_given_paths((dataset.fid_stat, path)) 142 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}') 143 | with open(os.path.join(config.workdir, 'eval.log'), 'a') as f: 144 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f) 145 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step) 146 | _fid = torch.tensor(_fid, device=device) 147 | _fid = accelerator.reduce(_fid, reduction='sum') 148 | 149 | return _fid.item() 150 | 151 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}') 152 | 153 | step_fid = [] 154 | while train_state.step < config.train.n_steps: 155 | nnet.train() 156 | batch = tree_map(lambda x: x.to(device), next(data_generator)) 157 | metrics = train_step(batch) 158 | 159 | nnet.eval() 160 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0: 161 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics))) 162 | logging.info(config.workdir) 163 | wandb.log(metrics, step=train_state.step) 164 | 165 | if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0: 166 | logging.info('Save a grid of images...') 167 | x_init = torch.randn(100, *dataset.data_shape, device=device) 168 | if config.train.mode == 'uncond': 169 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50) 170 | elif config.train.mode == 'cond': 171 | y = einops.repeat(torch.arange(10, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10) 172 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50, y=y) 173 | else: 174 | raise NotImplementedError 175 | samples = make_grid(dataset.unpreprocess(samples), 10) 176 | save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png')) 177 | wandb.log({'samples': wandb.Image(samples)}, step=train_state.step) 178 | torch.cuda.empty_cache() 179 | accelerator.wait_for_everyone() 180 | 181 | if 
train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=10000, sample_steps=50, algorithm='dpm_solver')  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
    logging.info(f'step_best: {step_best}')
    train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm)



from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None, "Work unit directory.")


def get_config_name():
    argv = sys.argv
    for i in range(1, len(argv)):
        if argv[i].startswith('--config='):
            return Path(argv[i].split('=')[-1]).stem


def get_hparams():
    argv = sys.argv
    lst = []
    for i in range(1, len(argv)):
        assert '=' in argv[i]
        if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
            hparam, val = argv[i].split('=')
            hparam = hparam.split('.')[-1]
            if hparam.endswith('path'):
                val = Path(val).stem
            lst.append(f'{hparam}={val}')
    hparams = '-'.join(lst)
    if hparams == '':
        hparams = 'default'
    return hparams


def main(argv):
    config = FLAGS.config
    config.config_name = get_config_name()
    config.hparams = get_hparams()
    config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
    config.ckpt_root = os.path.join(config.workdir, 'ckpts')
    config.sample_dir = os.path.join(config.workdir, 'samples')
    train(config)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/train_ldm.py:
--------------------------------------------------------------------------------
import sde
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
from torch.utils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
import tempfile
from tools.fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import libs.autoencoder


def train(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    assert os.path.exists(dataset.fid_stat)
    train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=8, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)

    autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def get_data_generator():
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()


    # set the score_model to train
    score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
    score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=sde.VPSDE())


    def train_step(_batch):
        _metrics = dict()
        optimizer.zero_grad()
        if config.train.mode == 'uncond':
            _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
            loss = sde.LSimple(score_model, _z, pred=config.pred)
        elif config.train.mode == 'cond':
            _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
            loss = sde.LSimple(score_model, _z, pred=config.pred, y=_batch[1])
        else:
            raise NotImplementedError(config.train.mode)
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)


    def eval_step(n_samples, sample_steps, algorithm):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples):
            _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
            if config.train.mode == 'uncond':
                kwargs = dict()
            elif config.train.mode == 'cond':
                kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
            else:
                raise NotImplementedError

            if algorithm == 'euler_maruyama_sde':
                _z = sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _z_init, sample_steps, **kwargs)
            elif algorithm == 'euler_maruyama_ode':
                _z = sde.euler_maruyama(sde.ODE(score_model_ema), _z_init, sample_steps, **kwargs)
            elif algorithm == 'dpm_solver':
                noise_schedule = NoiseScheduleVP(schedule='linear')
                model_fn = model_wrapper(
                    score_model_ema.noise_pred,
                    noise_schedule,
                    time_input_type='0',
                    model_kwargs=kwargs
                )
                dpm_solver = DPM_Solver(model_fn, noise_schedule)
                _z = dpm_solver.sample(
                    _z_init,
                    steps=sample_steps,
                    eps=1e-4,
                    adaptive_step_size=False,
                    fast_version=True,
                )
            else:
                raise NotImplementedError
            return decode(_z)

        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)

            _fid = 0
            if accelerator.is_main_process:
                _fid = calculate_fid_given_paths((dataset.fid_stat, path))
                logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
            _fid = torch.tensor(_fid, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_fid = []
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x.to(device), next(data_generator))
        metrics = train_step(batch)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            z_init = torch.randn(5 * 10, *config.z_shape, device=device)
            if config.train.mode == 'uncond':
                z = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=z_init, sample_steps=50)
            elif config.train.mode == 'cond':
                y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
                z = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=z_init, sample_steps=50, y=y)
            else:
                raise NotImplementedError
            samples = decode(z)
            samples = make_grid(dataset.unpreprocess(samples), 10)
            save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
            wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=10000, sample_steps=50, algorithm='dpm_solver')  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
    logging.info(f'step_best: {step_best}')
    train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm)



from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None, "Work unit directory.")


def get_config_name():
    argv = sys.argv
    for i in range(1, len(argv)):
        if argv[i].startswith('--config='):
            return Path(argv[i].split('=')[-1]).stem


def get_hparams():
    argv = sys.argv
    lst = []
    for i in range(1, len(argv)):
        assert '=' in argv[i]
        if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
            hparam, val = argv[i].split('=')
            hparam = hparam.split('.')[-1]
            if hparam.endswith('path'):
                val = Path(val).stem
            lst.append(f'{hparam}={val}')
    hparams = '-'.join(lst)
    if hparams == '':
        hparams = 'default'
    return hparams


def main(argv):
    config = FLAGS.config
    config.config_name = get_config_name()
    config.hparams = get_hparams()
    config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
    config.ckpt_root = os.path.join(config.workdir, 'ckpts')
    config.sample_dir = os.path.join(config.workdir, 'samples')
    train(config)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/train_ldm_discrete.py:
--------------------------------------------------------------------------------
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
from torch.utils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import tempfile
from tools.fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import libs.autoencoder
import numpy as np


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()


def get_skip(alphas, betas):
    N = len(betas) - 1
    skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
    for s in range(N + 1):
        skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
    skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
    for t in range(N + 1):
        prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
        skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
    return skip_alphas, skip_betas


def stp(s, ts: torch.Tensor):  # scalar tensor product
    if isinstance(s, np.ndarray):
        s = torch.from_numpy(s).type_as(ts)
    extra_dims = (1,) * (ts.dim() - 1)
    return s.view(-1, *extra_dims) * ts


def mos(a, start_dim=1):  # mean of square
    return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)


class Schedule(object):  # discrete time
    def __init__(self, _betas):
        r""" _betas[0...999] = betas[1...1000]
             for n>=1, betas[n] is the variance of q(xn|xn-1)
             for n=0,  betas[0]=0
        """

        self._betas = _betas
        self.betas = np.append(0., _betas)
        self.alphas = 1. - self.betas
        self.N = len(_betas)

        assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
        assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
        assert len(self.betas) == len(self.alphas)

        # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
        self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
        self.cum_alphas = self.skip_alphas[0]  # cum_alphas = alphas.cumprod()
        self.cum_betas = self.skip_betas[0]
        self.snr = self.cum_alphas / self.cum_betas

    def tilde_beta(self, s, t):
        return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]

    def sample(self, x0):  # sample from q(xn|x0), where n is uniform
        n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
        eps = torch.randn_like(x0)
        xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
        return torch.tensor(n, device=x0.device), eps, xn

    def __repr__(self):
        return f'Schedule({self.betas[:10]}..., {self.N})'


def LSimple(x0, nnet, schedule, **kwargs):
    n, eps, xn = schedule.sample(x0)  # n in {1, ..., 1000}
    eps_pred = nnet(xn, n, **kwargs)
    return mos(eps - eps_pred)


def train(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    assert os.path.exists(dataset.fid_stat)
    train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=8, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)

    autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def get_data_generator():
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()

    _betas = stable_diffusion_beta_schedule()
    _schedule = Schedule(_betas)
    logging.info(f'use {_schedule}')


    def train_step(_batch):
        _metrics = dict()
        optimizer.zero_grad()
        if config.train.mode == 'uncond':
            _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
            loss = LSimple(_z, nnet, _schedule)
        elif config.train.mode == 'cond':
            _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
            loss = LSimple(_z, nnet, _schedule, y=_batch[1])
        else:
            raise NotImplementedError(config.train.mode)
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)

    def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * _schedule.N
            eps_pre = nnet_ema(x, t, **kwargs)
            return eps_pre

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
        return decode(_z)

    def eval_step(n_samples, sample_steps):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples):
            if config.train.mode == 'uncond':
                kwargs = dict()
            elif config.train.mode == 'cond':
                kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
            else:
                raise NotImplementedError
            return dpm_solver_sample(_n_samples, sample_steps, **kwargs)


        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)

            _fid = 0
            if accelerator.is_main_process:
                _fid = calculate_fid_given_paths((dataset.fid_stat, path))
                logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
            _fid = torch.tensor(_fid, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_fid = []
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x.to(device), next(data_generator))
        metrics = train_step(batch)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            if config.train.mode == 'uncond':
                samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50)
            elif config.train.mode == 'cond':
                y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
                samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50, y=y)
            else:
                raise NotImplementedError
            samples = make_grid(dataset.unpreprocess(samples), 10)
            save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
            wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=10000, sample_steps=50)  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
    logging.info(f'step_best: {step_best}')
    train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)



from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None, "Work unit directory.")


def get_config_name():
    argv = sys.argv
    for i in range(1, len(argv)):
        if argv[i].startswith('--config='):
            return Path(argv[i].split('=')[-1]).stem


def get_hparams():
    argv = sys.argv
    lst = []
    for i in range(1, len(argv)):
        assert '=' in argv[i]
        if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
            hparam, val = argv[i].split('=')
            hparam = hparam.split('.')[-1]
            if hparam.endswith('path'):
                val = Path(val).stem
            lst.append(f'{hparam}={val}')
    hparams = '-'.join(lst)
    if hparams == '':
        hparams = 'default'
    return hparams


def main(argv):
    config = FLAGS.config
    config.config_name = get_config_name()
    config.hparams = get_hparams()
    config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
    config.ckpt_root = os.path.join(config.workdir, 'ckpts')
    config.sample_dir = os.path.join(config.workdir, 'samples')
    train(config)


if __name__ == "__main__":
    app.run(main)

--------------------------------------------------------------------------------
/train_t2i_discrete.py:
--------------------------------------------------------------------------------
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
from torch.utils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import tempfile
from tools.fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import libs.autoencoder
import numpy as np


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()


def get_skip(alphas, betas):
    N = len(betas) - 1
    skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
    for s in range(N + 1):
        skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
    skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
    for t in range(N + 1):
        prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
        skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
    return skip_alphas, skip_betas


def stp(s, ts: torch.Tensor):  # scalar tensor product
    if isinstance(s, np.ndarray):
        s = torch.from_numpy(s).type_as(ts)
    extra_dims = (1,) * (ts.dim() - 1)
    return s.view(-1, *extra_dims) * ts


def mos(a, start_dim=1):  # mean of square
    return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)


class Schedule(object):  # discrete time
    def __init__(self, _betas):
        r""" _betas[0...999] = betas[1...1000]
             for n>=1, betas[n] is the variance of q(xn|xn-1)
             for n=0,  betas[0]=0
        """

        self._betas = _betas
        self.betas = np.append(0., _betas)
        self.alphas = 1. - self.betas
        self.N = len(_betas)

        assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
        assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
        assert len(self.betas) == len(self.alphas)

        # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
        self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
        self.cum_alphas = self.skip_alphas[0]  # cum_alphas = alphas.cumprod()
        self.cum_betas = self.skip_betas[0]
        self.snr = self.cum_alphas / self.cum_betas

    def tilde_beta(self, s, t):
        return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]

    def sample(self, x0):  # sample from q(xn|x0), where n is uniform
        n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
        eps = torch.randn_like(x0)
        xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
        return torch.tensor(n, device=x0.device), eps, xn

    def __repr__(self):
        return f'Schedule({self.betas[:10]}..., {self.N})'


def LSimple(x0, nnet, schedule, **kwargs):
    n, eps, xn = schedule.sample(x0)  # n in {1, ..., 1000}
    eps_pred = nnet(xn, n, **kwargs)
    return mos(eps - eps_pred)


def train(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    assert os.path.exists(dataset.fid_stat)
    train_dataset = dataset.get_split(split='train', labeled=True)
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=8, pin_memory=True, persistent_workers=True)
    test_dataset = dataset.get_split(split='test', labeled=True)  # for sampling
    test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
                                     num_workers=8, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)

    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def get_data_generator():
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()

    def get_context_generator():
        while True:
            for data in test_dataset_loader:
                _, _context = data
                yield _context

    context_generator = get_context_generator()

    _betas = stable_diffusion_beta_schedule()
    _schedule = Schedule(_betas)
    logging.info(f'use {_schedule}')

    def cfg_nnet(x, timesteps, context):
        _cond = nnet_ema(x, timesteps, context=context)
        _empty_context = torch.tensor(dataset.empty_context, device=device)
        _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
        _uncond = nnet_ema(x, timesteps, context=_empty_context)
        return _cond + config.sample.scale * (_cond - _uncond)

    def train_step(_batch):
        _metrics = dict()
        optimizer.zero_grad()
        _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
        loss = LSimple(_z, nnet, _schedule, context=_batch[1])  # currently only support the extracted feature version
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)

    def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * _schedule.N
            return cfg_nnet(x, t, **kwargs)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
        return decode(_z)

    def eval_step(n_samples, sample_steps):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=dpm_solver, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples):
            _context = next(context_generator)
            assert _context.size(0) == _n_samples
            return dpm_solver_sample(_n_samples, sample_steps, context=_context)

        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)

            _fid = 0
            if accelerator.is_main_process:
                _fid = calculate_fid_given_paths((dataset.fid_stat, path))
                logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
            _fid = torch.tensor(_fid, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_fid = []
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x.to(device), next(data_generator))
        metrics = train_step(batch)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            contexts = torch.tensor(dataset.contexts, device=device)[: 2 * 5]
            samples = dpm_solver_sample(_n_samples=2 * 5, _sample_steps=50, context=contexts)
            samples = make_grid(dataset.unpreprocess(samples), 5)
            save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
            wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=10000, sample_steps=50)  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
    logging.info(f'step_best: {step_best}')
    train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)



from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None,
"Work unit directory.") 288 | 289 | 290 | def get_config_name(): 291 | argv = sys.argv 292 | for i in range(1, len(argv)): 293 | if argv[i].startswith('--config='): 294 | return Path(argv[i].split('=')[-1]).stem 295 | 296 | 297 | def get_hparams(): 298 | argv = sys.argv 299 | lst = [] 300 | for i in range(1, len(argv)): 301 | assert '=' in argv[i] 302 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'): 303 | hparam, val = argv[i].split('=') 304 | hparam = hparam.split('.')[-1] 305 | if hparam.endswith('path'): 306 | val = Path(val).stem 307 | lst.append(f'{hparam}={val}') 308 | hparams = '-'.join(lst) 309 | if hparams == '': 310 | hparams = 'default' 311 | return hparams 312 | 313 | 314 | def main(argv): 315 | config = FLAGS.config 316 | config.config_name = get_config_name() 317 | config.hparams = get_hparams() 318 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams) 319 | config.ckpt_root = os.path.join(config.workdir, 'ckpts') 320 | config.sample_dir = os.path.join(config.workdir, 'samples') 321 | train(config) 322 | 323 | 324 | if __name__ == "__main__": 325 | app.run(main) 326 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | from torchvision.utils import save_image 7 | from absl import logging 8 | 9 | 10 | def set_logger(log_level='info', fname=None): 11 | import logging as _logging 12 | handler = logging.get_absl_handler() 13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s') 14 | handler.setFormatter(formatter) 15 | logging.set_verbosity(log_level) 16 | if fname is not None: 17 | handler = _logging.FileHandler(fname) 18 | handler.setFormatter(formatter) 19 | logging.get_absl_logger().addHandler(handler) 20 | 21 | 22 | def 
dct2str(dct): 23 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 24 | 25 | 26 | def get_nnet(name, **kwargs): 27 | if name == 'uvit': 28 | from libs.uvit import UViT 29 | return UViT(**kwargs) 30 | elif name == 'uvit_t2i': 31 | from libs.uvit_t2i import UViT 32 | return UViT(**kwargs) 33 | else: 34 | raise NotImplementedError(name) 35 | 36 | 37 | def set_seed(seed: int): 38 | if seed is not None: 39 | torch.manual_seed(seed) 40 | np.random.seed(seed) 41 | 42 | 43 | def get_optimizer(params, name, **kwargs): 44 | if name == 'adam': 45 | from torch.optim import Adam 46 | return Adam(params, **kwargs) 47 | elif name == 'adamw': 48 | from torch.optim import AdamW 49 | return AdamW(params, **kwargs) 50 | else: 51 | raise NotImplementedError(name) 52 | 53 | 54 | def customized_lr_scheduler(optimizer, warmup_steps=-1): 55 | from torch.optim.lr_scheduler import LambdaLR 56 | def fn(step): 57 | if warmup_steps > 0: 58 | return min(step / warmup_steps, 1) 59 | else: 60 | return 1 61 | return LambdaLR(optimizer, fn) 62 | 63 | 64 | def get_lr_scheduler(optimizer, name, **kwargs): 65 | if name == 'customized': 66 | return customized_lr_scheduler(optimizer, **kwargs) 67 | elif name == 'cosine': 68 | from torch.optim.lr_scheduler import CosineAnnealingLR 69 | return CosineAnnealingLR(optimizer, **kwargs) 70 | else: 71 | raise NotImplementedError(name) 72 | 73 | 74 | def ema(model_dest: nn.Module, model_src: nn.Module, rate): 75 | param_dict_src = dict(model_src.named_parameters()) 76 | for p_name, p_dest in model_dest.named_parameters(): 77 | p_src = param_dict_src[p_name] 78 | assert p_src is not p_dest 79 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) 80 | 81 | 82 | class TrainState(object): 83 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None): 84 | self.optimizer = optimizer 85 | self.lr_scheduler = lr_scheduler 86 | self.step = step 87 | self.nnet = nnet 88 | self.nnet_ema = nnet_ema 89 | 90 | def ema_update(self, rate=0.9999): 91 | 
        if self.nnet_ema is not None:
            ema(self.nnet_ema, self.nnet, rate)

    def save(self, path):
        os.makedirs(path, exist_ok=True)
        torch.save(self.step, os.path.join(path, 'step.pth'))
        for key, val in self.__dict__.items():
            if key != 'step' and val is not None:
                torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))

    def load(self, path):
        logging.info(f'load from {path}')
        self.step = torch.load(os.path.join(path, 'step.pth'))
        for key, val in self.__dict__.items():
            if key != 'step' and val is not None:
                val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))

    def resume(self, ckpt_root, step=None):
        if not os.path.exists(ckpt_root):
            return
        if step is None:
            ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
            if not ckpts:
                return
            steps = map(lambda x: int(x.split(".")[0]), ckpts)
            step = max(steps)
        ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
        logging.info(f'resume from {ckpt_path}')
        self.load(ckpt_path)

    def to(self, device):
        for key, val in self.__dict__.items():
            if isinstance(val, nn.Module):
                val.to(device)


def cnt_params(model):
    return sum(param.numel() for param in model.parameters())


def initialize_train_state(config, device):
    params = []

    nnet = get_nnet(**config.nnet)
    params += nnet.parameters()
    nnet_ema = get_nnet(**config.nnet)
    nnet_ema.eval()
    logging.info(f'nnet has {cnt_params(nnet)} parameters')

    optimizer = get_optimizer(params, **config.optimizer)
    lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)

    train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
                             nnet=nnet, nnet_ema=nnet_ema)
    train_state.ema_update(0)
    train_state.to(device)
    return train_state


def amortize(n_samples, batch_size):
    k = n_samples // batch_size
    r = n_samples % batch_size
    return k * [batch_size] if r == 0 else k * [batch_size] + [r]


def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None):
    os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes

    for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
        samples = sample_fn(mini_batch_size)
        if unpreprocess_fn is not None:  # the default is None, so only apply it when one is given
            samples = unpreprocess_fn(samples)
        samples = accelerator.gather(samples.contiguous())[:_batch_size]
        if accelerator.is_main_process:
            for sample in samples:
                save_image(sample, os.path.join(path, f"{idx}.png"))
                idx += 1


def grad_norm(model):
    total_norm = 0.
    for p in model.parameters():
        if p.grad is None:  # parameters without a gradient contribute nothing to the norm
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm

--------------------------------------------------------------------------------
/uvit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baofff/U-ViT/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/uvit.png
--------------------------------------------------------------------------------
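`cfg_nnet` in `train_t2i_discrete.py` implements classifier-free guidance: `nnet_ema` is evaluated once with the caption context and once with `dataset.empty_context` repeated over the batch, and the two predictions are combined as `_cond + config.sample.scale * (_cond - _uncond)`. With `s = config.sample.scale` this equals `(1 + s) * cond - s * uncond`, so `s = 0` falls back to the plain conditional prediction and a larger `s` pushes further away from the unconditional one. The sketch below reproduces only that arithmetic on plain Python lists; `cfg_combine` is a hypothetical helper, not part of the repo, and the real code applies the same rule element-wise to tensors.

```python
import math
from typing import List


def cfg_combine(cond: List[float], uncond: List[float], scale: float) -> List[float]:
    """Hypothetical helper: element-wise cond + scale * (cond - uncond), as in cfg_nnet."""
    if len(cond) != len(uncond):
        raise ValueError("cond and uncond must have the same length")
    return [c + scale * (c - u) for c, u in zip(cond, uncond)]


cond = [1.0, 2.0, -0.5]    # stand-in for the caption-conditioned prediction
uncond = [0.0, 1.0, 0.5]   # stand-in for the empty-context prediction

print(cfg_combine(cond, uncond, 1.0))  # [2.0, 3.0, -1.5]
print(cfg_combine(cond, uncond, 0.0))  # [1.0, 2.0, -0.5], i.e. scale 0 returns cond unchanged

# The rewritten form (1 + s) * cond - s * uncond gives the same numbers.
s = 0.7
alt = [(1 + s) * c - s * u for c, u in zip(cond, uncond)]
assert all(math.isclose(a, b) for a, b in zip(alt, cfg_combine(cond, uncond, s)))
```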
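`customized_lr_scheduler` wraps `torch.optim.lr_scheduler.LambdaLR` with a linear warmup multiplier: the learning rate is `base_lr * min(step / warmup_steps, 1)` when `warmup_steps > 0`, and the plain base rate otherwise. Because `LambdaLR` applies the multiplier for step 0 when it is constructed, and `train_step` calls `optimizer.step()` before `lr_scheduler.step()`, the very first update runs with a multiplier of 0. The sketch restates only the multiplier in plain Python; `warmup_multiplier` is a hypothetical name, and the base rate and warmup length are illustrative values, not taken from any config in this repo.

```python
def warmup_multiplier(step: int, warmup_steps: int = -1) -> float:
    """Restatement of the `fn` closure in utils.customized_lr_scheduler."""
    if warmup_steps > 0:
        return min(step / warmup_steps, 1)   # linear ramp from 0, capped at 1
    return 1                                 # no warmup: constant multiplier


base_lr = 2e-4        # illustrative only
warmup_steps = 5000   # illustrative only
for step in (0, 1, 2500, 5000, 20000):
    lr = base_lr * warmup_multiplier(step, warmup_steps)
    print(f"scheduler step {step:>5}: lr = {lr:.3g}")
```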
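`utils.ema` keeps `nnet_ema` as an exponential moving average of `nnet`: for every named parameter it computes `p_dest = rate * p_dest + (1 - rate) * p_src` in place. `initialize_train_state` calls `ema_update(0)`, which makes the EMA copy identical to the freshly initialised network, and `train_step` then calls it with `config.get('ema_rate', 0.9999)` after every optimizer step; as a rough rule of thumb, a rate of 0.9999 averages over about 1 / (1 - 0.9999) = 10,000 recent steps. Only parameters are averaged, not buffers. The sketch replays the same update rule on plain dictionaries of floats; `ema_dict` is a hypothetical stand-in for the tensor version.

```python
from typing import Dict


def ema_dict(dest: Dict[str, float], src: Dict[str, float], rate: float) -> None:
    """Hypothetical in-place EMA on name -> value dicts, same rule as utils.ema."""
    for name in dest:
        dest[name] = rate * dest[name] + (1 - rate) * src[name]


online = {"w": 1.0, "b": -2.0}
shadow = {"w": 0.0, "b": 0.0}

ema_dict(shadow, online, 0)      # rate 0: shadow becomes an exact copy, like ema_update(0)
print(shadow)                    # {'w': 1.0, 'b': -2.0}

online["w"] = 2.0                # pretend an optimizer step changed the weight
ema_dict(shadow, online, 0.9)    # shadow moves 10% of the way towards the new value
print(round(shadow["w"], 6))     # 1.1
```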
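`sample2dir` generates `n_samples` images in rounds of `mini_batch_size * accelerator.num_processes`. `amortize` splits the total into full rounds plus one remainder round. In every round each process still samples a full `mini_batch_size`, the results are gathered, and the slice `[:_batch_size]` discards the surplus so that exactly `n_samples` PNGs are written by the main process. In `train_t2i_discrete.py`, each `sample_fn` call also takes one batch from `test_dataset_loader` and asserts that it holds exactly `_n_samples` contexts, so the batches coming out of that loader must have exactly `config.sample.mini_batch_size` rows. The sketch below does the bookkeeping without torch or accelerate; `simulate_sample2dir` is hypothetical, and the 4-process, 128-per-process numbers are only an example.

```python
from typing import List, Tuple


def amortize(n_samples: int, batch_size: int) -> List[int]:
    """Same logic as utils.amortize: full batches plus an optional remainder."""
    k = n_samples // batch_size
    r = n_samples % batch_size
    return k * [batch_size] if r == 0 else k * [batch_size] + [r]


def simulate_sample2dir(n_samples: int, mini_batch_size: int, num_processes: int) -> Tuple[int, int]:
    """Hypothetical accounting of utils.sample2dir.

    Returns (images generated across all processes, images saved).
    """
    global_batch = mini_batch_size * num_processes
    generated = saved = 0
    for batch_size in amortize(n_samples, global_batch):
        gathered = mini_batch_size * num_processes   # every process samples a full mini-batch
        generated += gathered
        saved += min(batch_size, gathered)           # the [:_batch_size] slice drops the surplus
    return generated, saved


print(amortize(10, 4))                        # [4, 4, 2]
print(simulate_sample2dir(10_000, 128, 4))    # (10240, 10000)
```

With these example numbers the last round generates 512 images but keeps only 272, so roughly 2% of the sampling work in a 10,000-image FID run is thrown away.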
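`main` derives the working directory from the command line. `get_config_name` takes the stem of the `--config=` file, and `get_hparams` turns every `--config.<...>.<key>=<value>` override (except `--config.dataset.path`) into `key=value`, shortening values of keys that end in `path` to their file stem and joining the pieces with `-`. Without a `--workdir` flag the result is `workdir/<config_name>/<hparams>`, with `default` when there are no overrides. Because `assert '=' in argv[i]` runs on every argument, all flags have to be written in `--flag=value` form. Both helpers read `sys.argv` directly, so the sketch below re-expresses them as hypothetical functions that take the argument list explicitly; the override flags in the example are illustrative.

```python
from pathlib import Path
from typing import List, Optional


def config_name_from_argv(argv: List[str]) -> Optional[str]:
    """Same logic as get_config_name(), but with argv passed in instead of sys.argv."""
    for arg in argv[1:]:
        if arg.startswith('--config='):
            return Path(arg.split('=')[-1]).stem
    return None


def hparams_from_argv(argv: List[str]) -> str:
    """Same logic as get_hparams(), but with argv passed in instead of sys.argv."""
    lst = []
    for arg in argv[1:]:
        assert '=' in arg                                   # every flag must use --flag=value
        if arg.startswith('--config.') and not arg.startswith('--config.dataset.path'):
            hparam, val = arg.split('=')
            hparam = hparam.split('.')[-1]                  # keep only the last key component
            if hparam.endswith('path'):
                val = Path(val).stem                        # shorten paths to their file stem
            lst.append(f'{hparam}={val}')
    return '-'.join(lst) or 'default'


argv = ['train_t2i_discrete.py',
        '--config=configs/mscoco_uvit_small.py',
        '--config.train.n_steps=1000',                      # illustrative override
        '--config.dataset.path=/data/coco',                 # excluded from the name
        '--config.nnet.depth=12']                           # illustrative override
print(config_name_from_argv(argv))   # mscoco_uvit_small
print(hparams_from_argv(argv))       # n_steps=1000-depth=12
```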
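`TrainState.save` writes a directory named `<step>.ckpt` holding `step.pth` plus one `<attribute>.pth` state dict for each non-`None` attribute (`optimizer`, `lr_scheduler`, `nnet`, `nnet_ema`). `resume` picks the largest step among the entries of `ckpt_root` whose name contains `.ckpt`, comparing steps as integers. After training, the script reloads the checkpoint with the lowest FID via `sorted(step_fid, key=lambda x: x[1])[0][0]`. The sketch restates both selections with hypothetical helper names. Using `min` with the same key returns the same step as the sort without sorting the whole list, because both keep the earliest entry on ties.

```python
from typing import Iterable, List, Optional, Tuple


def latest_ckpt_step(filenames: Iterable[str]) -> Optional[int]:
    """Hypothetical mirror of the step lookup in TrainState.resume()."""
    ckpts = [x for x in filenames if '.ckpt' in x]
    if not ckpts:
        return None                                   # resume() simply returns in this case
    return max(int(x.split('.')[0]) for x in ckpts)   # numeric, not lexicographic, maximum


def best_step(step_fid: List[Tuple[int, float]]) -> int:
    """Same result as sorted(step_fid, key=lambda x: x[1])[0][0]."""
    return min(step_fid, key=lambda x: x[1])[0]


print(latest_ckpt_step(['50000.ckpt', '100000.ckpt', 'eval.log']))   # 100000
fids = [(50000, 9.1), (100000, 7.4), (150000, 7.9)]
print(best_step(fids))                                               # 100000
assert best_step(fids) == sorted(fids, key=lambda x: x[1])[0][0]
```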
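`grad_norm` computes the global L2 norm of all gradients: the square root of the sum of the squared per-parameter norms, which equals the norm of every gradient entry flattened into a single vector. The sketch below shows the same reduction on plain lists; `global_grad_norm` is a hypothetical name, and, as a choice of this sketch, it skips parameters whose gradient is `None`, for example parameters not reached by the backward pass. In PyTorch itself, `torch.nn.utils.clip_grad_norm_` also returns this total norm (while clipping to `max_norm`).

```python
import math
from typing import Optional, Sequence


def global_grad_norm(grads: Sequence[Optional[Sequence[float]]]) -> float:
    """Hypothetical helper: sqrt of the sum of squared per-parameter L2 norms."""
    total = 0.0
    for g in grads:
        if g is None:                                   # no gradient recorded for this parameter
            continue
        param_norm = math.sqrt(sum(x * x for x in g))   # per-parameter L2 norm
        total += param_norm ** 2                        # same accumulation as utils.grad_norm
    return math.sqrt(total)


print(global_grad_norm([[3.0], [4.0], None]))     # 5.0
print(global_grad_norm([[1.0, 2.0, 2.0], None]))  # 3.0
```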