├── LICENSE
├── README.md
├── UViT_ImageNet_demo.ipynb
├── configs
│   ├── celeba64_uvit_small.py
│   ├── cifar10_uvit_small.py
│   ├── imagenet256_uvit_huge.py
│   ├── imagenet256_uvit_large.py
│   ├── imagenet512_uvit_huge.py
│   ├── imagenet512_uvit_large.py
│   ├── imagenet64_uvit_large.py
│   ├── imagenet64_uvit_mid.py
│   └── mscoco_uvit_small.py
├── datasets.py
├── dpm_solver_pp.py
├── dpm_solver_pytorch.py
├── eval.py
├── eval_ldm.py
├── eval_ldm_discrete.py
├── eval_t2i_discrete.py
├── libs
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── clip.py
│   ├── timm.py
│   ├── uvit.py
│   └── uvit_t2i.py
├── sample.png
├── sample_t2i_discrete.py
├── scripts
│   ├── extract_empty_feature.py
│   ├── extract_imagenet_feature.py
│   ├── extract_mscoco_feature.py
│   └── extract_test_prompt_feature.py
├── sde.py
├── skip_im.png
├── tools
│   ├── fid_score.py
│   └── inception.py
├── train.py
├── train_ldm.py
├── train_ldm_discrete.py
├── train_t2i_discrete.py
├── utils.py
└── uvit.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Fan Bao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## U-ViT<br><sub>Official PyTorch implementation of [All are Worth Words: A ViT Backbone for Diffusion Models](https://arxiv.org/abs/2209.12152) (CVPR 2023)</sub>
2 |
3 |
4 | 💡Projects with U-ViT:
5 | * [UniDiffuser](https://github.com/thu-ml/unidiffuser), a multi-modal large-scale diffusion model based on a 1B U-ViT, is open-sourced
6 | * [DPT](https://arxiv.org/abs/2302.10586) ([code](https://github.com/ML-GSAI/DPT), [demo](https://ml-gsai.github.io/DPT-demo)), a conditional diffusion model trained with only 1 label per class, with SOTA SSL generation and classification results on ImageNet
7 |
8 |
9 |
10 | Vision transformers (ViT) have shown promise in various vision tasks while the U-Net based on a convolutional neural network (CNN) remains dominant in diffusion models.
11 | We design a simple and general ViT-based architecture (named U-ViT) for image generation with diffusion models.
12 | U-ViT is characterized by treating all inputs including the time, condition and noisy image patches as tokens
13 | and employing long skip connections between shallow and deep layers.
14 | We evaluate U-ViT in unconditional and class-conditional image generation,
15 | as well as text-to-image generation tasks, where U-ViT is comparable if not superior to a CNN-based U-Net of a similar size.
16 | In particular, latent diffusion models with U-ViT achieve record-breaking FID scores of 2.29 in class-conditional image generation
17 | on ImageNet 256x256, and 5.48 in text-to-image generation on MS-COCO, among methods without accessing
18 | large external datasets during the training of generative models.
19 |
20 | Our results suggest that, for diffusion-based image modeling, the long skip connection is crucial while the down-sampling and up-sampling operators in CNN-based U-Net are not always necessary. We believe that U-ViT can provide insights for future research on backbones in diffusion models and benefit generative modeling on large scale cross-modality datasets.
21 |
22 | --------------------
23 |
24 |
25 |
26 | This codebase implements the transformer-based backbone 📌*U-ViT*📌 for diffusion models, as introduced in the [paper](https://arxiv.org/abs/2209.12152).
27 | U-ViT treats all inputs as tokens and employs long skip connections. *The long skip connections greatly improve the performance and the convergence speed* (see the sketch below).
28 |
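To make the two ideas concrete, here is a minimal, self-contained PyTorch sketch (with hypothetical names such as `TinyUViT`; the real implementation lives in `libs/uvit.py` and uses timm-style attention and a proper patch embedding): the timestep, the condition and the image patches all enter the transformer as tokens, and each "out" block fuses a long skip connection from a matching "in" block through a linear layer.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Pre-norm transformer block; optionally fuses a long skip connection."""
    def __init__(self, dim, num_heads, skip=False):
        super().__init__()
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))   # fuse shallow features
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class TinyUViT(nn.Module):
    """Toy U-ViT: time, class and patches are all tokens; long skips pair in/out blocks."""
    def __init__(self, dim=64, depth=4, num_heads=4, num_classes=10):
        super().__init__()
        self.time_embed = nn.Linear(1, dim)
        self.label_embed = nn.Embedding(num_classes + 1, dim)     # +1: "empty" token for CFG
        self.patch_embed = nn.Linear(dim, dim)                    # stand-in for a real patch embed
        self.in_blocks = nn.ModuleList(Block(dim, num_heads) for _ in range(depth // 2))
        self.mid_block = Block(dim, num_heads)
        self.out_blocks = nn.ModuleList(Block(dim, num_heads, skip=True) for _ in range(depth // 2))

    def forward(self, patches, t, y):
        x = torch.cat([self.time_embed(t[:, None]).unsqueeze(1),  # time token
                       self.label_embed(y).unsqueeze(1),          # condition token
                       self.patch_embed(patches)], dim=1)         # patch tokens
        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)                                       # stash shallow activations
        x = self.mid_block(x)
        for blk in self.out_blocks:
            x = blk(x, skip=skips.pop())                          # long skip: shallow -> deep
        return x[:, 2:]                                           # keep only the patch tokens


out = TinyUViT()(torch.randn(2, 16, 64), torch.rand(2), torch.randint(0, 10, (2,)))
print(out.shape)  # torch.Size([2, 16, 64])
```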
29 |
30 |
31 |
32 |
33 |
34 | 💡This codebase contains:
35 | * An implementation of [U-ViT](libs/uvit.py) with optimized attention computation
36 | * Pretrained U-ViT models on common image generation benchmarks (CIFAR10, CelebA 64x64, ImageNet 64x64, ImageNet 256x256, ImageNet 512x512)
37 | * Efficient training scripts for [pixel-space diffusion models](train.py), [latent space diffusion models](train_ldm_discrete.py) and [text-to-image diffusion models](train_t2i_discrete.py)
38 | * Efficient evaluation scripts for [pixel-space diffusion models](eval.py) and [latent space diffusion models](eval_ldm_discrete.py) and [text-to-image diffusion models](eval_t2i_discrete.py)
39 | * A Colab notebook demo for sampling from U-ViT on ImageNet (FID=2.29): [Open in Colab](https://colab.research.google.com/github/baofff/U-ViT/blob/main/UViT_ImageNet_demo.ipynb)
40 |
41 |
42 |
43 |
44 |
45 | 💡This codebase supports useful techniques for efficient training and sampling of diffusion models:
46 | * Mixed precision training with the [huggingface accelerate](https://github.com/huggingface/accelerate) library (🥰automatically turned on)
47 | * Efficient attention computation with the [facebook xformers](https://github.com/facebookresearch/xformers) library (needs additional installation)
48 | * Gradient checkpointing, which reduces memory usage by ~65% (🥰automatically turned on; see the sketch after this list)
49 | * With these techniques, we are able to train our largest U-ViT-H on ImageNet at high resolutions such as 256x256 and 512x512 using a large batch size of 1024 with *only 2 A100s*❗
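For intuition, here is a minimal sketch of the checkpointing trick (a generic wrapper, not the exact code in `libs/uvit.py`, which enables it per block via the `use_checkpoint` config flag): wrapping a block's forward in `torch.utils.checkpoint.checkpoint` stores only the block inputs and recomputes the intermediate activations during the backward pass, trading extra compute for memory.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Wraps any block so its activations are recomputed in backward when enabled."""
    def __init__(self, block, use_checkpoint=True):
        super().__init__()
        self.block = block
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        if self.use_checkpoint and self.training:
            return checkpoint(self.block, x, use_reentrant=False)  # recompute in backward
        return self.block(x)


blk = CheckpointedBlock(nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)))
y = blk(torch.randn(4, 512, requires_grad=True))
y.sum().backward()
```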
50 |
51 |
52 | Training speed and memory of U-ViT-H/2 on ImageNet 256x256 using a batch size of 128 with an A100:
53 |
54 | | mixed precision training | xformers | gradient checkpointing | training speed | memory |
55 | |:------------------------:|:--------:|:----------------------:|:-----------------:|:-------------:|
56 | | ❌ | ❌ | ❌ | - | out of memory |
57 | | ✔ | ❌ | ❌ | 0.97 steps/second | 78852 MB |
58 | | ✔ | ✔ | ❌ | 1.14 steps/second | 54324 MB |
59 | | ✔ | ✔ | ✔ | 0.87 steps/second | 18858 MB |
60 |
61 |
62 |
63 | ## Dependency
64 |
65 | ```sh
66 | pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 # install torch-1.13.1
67 | pip install accelerate==0.12.0 absl-py ml_collections einops wandb ftfy==6.1.1 transformers==4.23.1
68 |
69 | # xformers is optional, but it would greatly speed up the attention computation.
70 | pip install -U xformers
71 | pip install -U --pre triton
72 | ```
73 |
74 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed for it to work with PyTorch 1.8.1+. (Other versions may also work, but I haven't tested them.)
75 | * We highly suggest installing [xformers](https://github.com/facebookresearch/xformers), which greatly speeds up attention computation for *both training and inference*.
76 |
77 |
78 |
79 | ## Pretrained Models
80 |
81 |
82 | | Model | FID | training iterations | batch size |
83 | |:----------------------------------------------------------------------------------------------------------------------:|:-----:|:-------------------:|:----------:|
84 | | [CIFAR10 (U-ViT-S/2)](https://drive.google.com/file/d/1yoYyuzR_hQYWU0mkTj659tMTnoCWCMv-/view?usp=share_link) | 3.11 | 500K | 128 |
85 | | [CelebA 64x64 (U-ViT-S/4)](https://drive.google.com/file/d/13YpbRtlqF1HDBNLNRlKxLTbKbKeLE06C/view?usp=share_link) | 2.87 | 500K | 128 |
86 | | [ImageNet 64x64 (U-ViT-M/4)](https://drive.google.com/file/d/1igVgRY7-A0ZV3XqdNcMGOnIGOxKr9azv/view?usp=share_link) | 5.85 | 300K | 1024 |
87 | | [ImageNet 64x64 (U-ViT-L/4)](https://drive.google.com/file/d/19rmun-T7RwkNC1feEPWinIo-1JynpW7J/view?usp=share_link) | 4.26 | 300K | 1024 |
88 | | [ImageNet 256x256 (U-ViT-L/2)](https://drive.google.com/file/d/1w7T1hiwKODgkYyMH9Nc9JNUThbxFZgs3/view?usp=share_link) | 3.40 | 300K | 1024 |
89 | | [ImageNet 256x256 (U-ViT-H/2)](https://drive.google.com/file/d/13StUdrjaaSXjfqqF7M47BzPyhMAArQ4u/view?usp=share_link) | 2.29 | 500K | 1024 |
90 | | [ImageNet 512x512 (U-ViT-L/4)](https://drive.google.com/file/d/1mkj4aN2utHMBTWQX9l1nYue9vleL7ZSB/view?usp=share_link) | 4.67 | 500K | 1024 |
91 | | [ImageNet 512x512 (U-ViT-H/4)](https://drive.google.com/file/d/1uegr2o7cuKXtf2akWGAN2Vnlrtw5YKQq/view?usp=share_link) | 4.05 | 500K | 1024 |
92 | | [MS-COCO (U-ViT-S/2)](https://drive.google.com/file/d/15JsZWRz2byYNU6K093et5e5Xqd4uwA8S/view?usp=share_link) | 5.95 | 1M | 256 |
93 | | [MS-COCO (U-ViT-S/2, Deep)](https://drive.google.com/file/d/1gHRy8sn039Wy-iFL21wH8TiheHK8Ky71/view?usp=share_link) | 5.48 | 1M | 256 |
94 |
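Each checkpoint above is a plain PyTorch `state_dict` for the corresponding U-ViT configuration. A minimal loading sketch, mirroring how `eval.py` restores a model (the local filename is simply whatever you downloaded, here the ImageNet 256x256 U-ViT-H/2 checkpoint):

```python
import torch
import utils                                              # from this repo
from configs.imagenet256_uvit_huge import get_config

config = get_config()
nnet = utils.get_nnet(**config.nnet)                      # build the U-ViT-H/2 backbone
nnet.load_state_dict(torch.load('imagenet256_uvit_huge.pth', map_location='cpu'))
nnet.eval()
```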
95 |
96 |
97 | ## Preparation Before Training and Evaluation
98 |
99 | #### Autoencoder
100 | Download the `stable-diffusion` directory from this [link](https://drive.google.com/drive/folders/1yo-XhqbPue3rp5P57j6QbA5QZx6KybvP?usp=sharing) (it contains image autoencoders converted from [Stable Diffusion](https://github.com/CompVis/stable-diffusion)).
101 | Place the downloaded directory at `assets/stable-diffusion` in this codebase.
102 | The autoencoders are used in latent diffusion models.
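For reference, a short sketch of how the latent-space scripts in this codebase use the autoencoder (mirroring the `encode`/`decode` helpers in `eval_ldm.py`; the exact behavior, e.g. whether `encode` samples from the posterior, is defined in `libs/autoencoder.py`):

```python
import torch
import libs.autoencoder

device = 'cuda'
autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth')
autoencoder.to(device)

images = torch.randn(4, 3, 256, 256, device=device)       # images scaled to [-1, 1]
with torch.cuda.amp.autocast():
    z = autoencoder.encode(images)                         # latents, (4, 4, 32, 32) at 256x256
    recon = autoencoder.decode(z)                          # back to (4, 3, 256, 256)
```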
103 |
104 | #### Data
105 | * ImageNet 64x64: Put the standard ImageNet dataset (which contains the `train` and `val` directory) to `assets/datasets/ImageNet`.
106 | * ImageNet 256x256 and ImageNet 512x512: Extract ImageNet features according to `scripts/extract_imagenet_feature.py`.
107 | * MS-COCO: Download the COCO 2014 [training](http://images.cocodataset.org/zips/train2014.zip) and [validation](http://images.cocodataset.org/zips/val2014.zip) data and the [annotations](http://images.cocodataset.org/annotations/annotations_trainval2014.zip). Then extract their features according to `scripts/extract_mscoco_feature.py`, `scripts/extract_test_prompt_feature.py` and `scripts/extract_empty_feature.py`.
108 |
109 | #### Reference statistics for FID
110 | Download the `fid_stats` directory from this [link](https://drive.google.com/drive/folders/1yo-XhqbPue3rp5P57j6QbA5QZx6KybvP?usp=sharing) (it contains reference statistics for FID).
111 | Place the downloaded directory at `assets/fid_stats` in this codebase.
112 | In addition to evaluation, these reference statistics are used to monitor FID during the training process.
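A minimal sketch of how these statistics are consumed (mirroring the call at the end of `eval.py`; the sample directory here is hypothetical):

```python
from tools.fid_score import calculate_fid_given_paths

# first path: reference statistics (.npz); second path: a directory of generated images
fid = calculate_fid_given_paths(('assets/fid_stats/fid_stats_cifar10_train_pytorch.npz',
                                 'workdir/cifar10_samples'))
print(f'FID = {fid:.2f}')
```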
113 |
114 | ## Training
115 |
116 |
117 |
118 | We use the [huggingface accelerate](https://github.com/huggingface/accelerate) library to help train with distributed data parallel and mixed precision. The following is the training command:
119 | ```sh
120 | # the training setting
121 | num_processes=2 # the number of gpus you have, e.g., 2
122 | train_script=train.py # the train script, one of
123 | # train.py: training on pixel space
124 | # train_ldm.py: training on latent space with continuous timesteps
125 | # train_ldm_discrete.py: training on latent space with discrete timesteps
126 | # train_t2i_discrete.py: text-to-image training on latent space
127 | config=configs/cifar10_uvit_small.py # the training configuration
128 | # you can change other hyperparameters by modifying the configuration file
129 |
130 | # launch training
131 | accelerate launch --multi_gpu --num_processes $num_processes --mixed_precision fp16 $train_script --config=$config
132 | ```
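The `--config=...` flag is parsed with `ml_collections.config_flags` (the same pattern appears at the bottom of `train.py` and `eval.py`), so any field in the config can also be overridden from the command line with dotted flags such as `--config.nnet.depth=16`. A stand-alone sketch (the script name `show_config.py` is hypothetical):

```python
# show_config.py -- inspect a training configuration and any command-line overrides
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])


def main(argv):
    config = FLAGS.config
    print(config.nnet)          # e.g. depth, embed_dim, num_heads, ...
    print(config.optimizer.lr)


if __name__ == "__main__":
    app.run(main)

# usage: python show_config.py --config=configs/cifar10_uvit_small.py --config.nnet.depth=16
```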
133 |
134 |
135 | We provide all commands to reproduce U-ViT training in the paper:
136 | ```sh
137 | # CIFAR10 (U-ViT-S/2)
138 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/cifar10_uvit_small.py
139 |
140 | # CelebA 64x64 (U-ViT-S/4)
141 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/celeba64_uvit_small.py
142 |
143 | # ImageNet 64x64 (U-ViT-M/4)
144 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train.py --config=configs/imagenet64_uvit_mid.py
145 |
146 | # ImageNet 64x64 (U-ViT-L/4)
147 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train.py --config=configs/imagenet64_uvit_large.py
148 |
149 | # ImageNet 256x256 (U-ViT-L/2)
150 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm.py --config=configs/imagenet256_uvit_large.py
151 |
152 | # ImageNet 256x256 (U-ViT-H/2)
153 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm_discrete.py --config=configs/imagenet256_uvit_huge.py
154 |
155 | # ImageNet 512x512 (U-ViT-L/4)
156 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm.py --config=configs/imagenet512_uvit_large.py
157 |
158 | # ImageNet 512x512 (U-ViT-H/4)
159 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 train_ldm_discrete.py --config=configs/imagenet512_uvit_huge.py
160 |
161 | # MS-COCO (U-ViT-S/2)
162 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train_t2i_discrete.py --config=configs/mscoco_uvit_small.py
163 |
164 | # MS-COCO (U-ViT-S/2, Deep)
165 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train_t2i_discrete.py --config=configs/mscoco_uvit_small.py --config.nnet.depth=16
166 | ```
167 |
168 |
169 |
170 | ## Evaluation (Compute FID)
171 |
172 | We use the [huggingface accelerate](https://github.com/huggingface/accelerate) library for efficient inference with mixed precision and multiple gpus. The following is the evaluation command:
173 | ```sh
174 | # the evaluation setting
175 | num_processes=2 # the number of gpus you have, e.g., 2
176 | eval_script=eval.py # the evaluation script, one of
177 | # eval.py: for models trained with train.py (i.e., pixel space models)
178 | # eval_ldm.py: for models trained with train_ldm.py (i.e., latent space models with continuous timesteps)
179 | # eval_ldm_discrete.py: for models trained with train_ldm_discrete.py (i.e., latent space models with discrete timesteps)
180 | # eval_t2i_discrete.py: for models trained with train_t2i_discrete.py (i.e., text-to-image models on latent space)
181 | config=configs/cifar10_uvit_small.py # the training configuration
182 |
183 | # launch evaluation
184 | accelerate launch --multi_gpu --num_processes $num_processes --mixed_precision fp16 $eval_script --config=$config
185 | ```
186 | The generated images are stored in a temporary directory and deleted after evaluation. If you want to keep them, set `--config.sample.path=/save/dir`.
187 |
188 |
189 | We provide all commands to reproduce FID results in the paper:
190 | ```sh
191 | # CIFAR10 (U-ViT-S/2)
192 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/cifar10_uvit_small.py --nnet_path=cifar10_uvit_small.pth
193 |
194 | # CelebA 64x64 (U-ViT-S/4)
195 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/celeba64_uvit_small.py --nnet_path=celeba64_uvit_small.pth
196 |
197 | # ImageNet 64x64 (U-ViT-M/4)
198 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval.py --config=configs/imagenet64_uvit_mid.py --nnet_path=imagenet64_uvit_mid.pth
199 |
200 | # ImageNet 64x64 (U-ViT-L/4)
201 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval.py --config=configs/imagenet64_uvit_large.py --nnet_path=imagenet64_uvit_large.pth
202 |
203 | # ImageNet 256x256 (U-ViT-L/2)
204 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm.py --config=configs/imagenet256_uvit_large.py --nnet_path=imagenet256_uvit_large.pth
205 |
206 | # ImageNet 256x256 (U-ViT-H/2)
207 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm_discrete.py --config=configs/imagenet256_uvit_huge.py --nnet_path=imagenet256_uvit_huge.pth
208 |
209 | # ImageNet 512x512 (U-ViT-L/4)
210 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm.py --config=configs/imagenet512_uvit_large.py --nnet_path=imagenet512_uvit_large.pth
211 |
212 | # ImageNet 512x512 (U-ViT-H/4)
213 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 eval_ldm_discrete.py --config=configs/imagenet512_uvit_huge.py --nnet_path=imagenet512_uvit_huge.pth
214 |
215 | # MS-COCO (U-ViT-S/2)
216 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval_t2i_discrete.py --config=configs/mscoco_uvit_small.py --nnet_path=mscoco_uvit_small.pth
217 |
218 | # MS-COCO (U-ViT-S/2, Deep)
219 | accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval_t2i_discrete.py --config=configs/mscoco_uvit_small.py --config.nnet.depth=16 --nnet_path=mscoco_uvit_small_deep.pth
220 | ```
221 |
222 |
223 |
224 |
225 | ## References
226 | If you find the code useful for your research, please consider citing
227 | ```bib
228 | @inproceedings{bao2022all,
229 | title={All are Worth Words: A ViT Backbone for Diffusion Models},
230 | author={Bao, Fan and Nie, Shen and Xue, Kaiwen and Cao, Yue and Li, Chongxuan and Su, Hang and Zhu, Jun},
231 | booktitle = {CVPR},
232 | year={2023}
233 | }
234 | ```
235 |
236 | This implementation is based on
237 | * [Extended Analytic-DPM](https://github.com/baofff/Extended-Analytic-DPM) (provides the FID reference statistics on CIFAR10 and CelebA 64x64)
238 | * [guided-diffusion](https://github.com/openai/guided-diffusion) (provides the FID reference statistics on ImageNet)
239 | * [pytorch-fid](https://github.com/mseitzer/pytorch-fid) (provides a PyTorch port of the official FID implementation)
240 | * [dpm-solver](https://github.com/LuChengTHU/dpm-solver) (provides the DPM-Solver sampler)
241 |
--------------------------------------------------------------------------------
/configs/celeba64_uvit_small.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 |
15 | config.train = d(
16 | n_steps=500000,
17 | batch_size=128,
18 | mode='uncond',
19 | log_interval=10,
20 | eval_interval=5000,
21 | save_interval=50000,
22 | )
23 |
24 | config.optimizer = d(
25 | name='adamw',
26 | lr=0.0002,
27 | weight_decay=0.03,
28 | betas=(0.99, 0.99),
29 | )
30 |
31 | config.lr_scheduler = d(
32 | name='customized',
33 | warmup_steps=5000
34 | )
35 |
36 | config.nnet = d(
37 | name='uvit',
38 | img_size=64,
39 | patch_size=4,
40 | embed_dim=512,
41 | depth=12,
42 | num_heads=8,
43 | mlp_ratio=4,
44 | qkv_bias=False,
45 | mlp_time_embed=False,
46 | num_classes=-1,
47 | )
48 |
49 | config.dataset = d(
50 | name='celeba',
51 | path='assets/datasets/celeba',
52 | resolution=64,
53 | )
54 |
55 | config.sample = d(
56 | sample_steps=1000,
57 | n_samples=50000,
58 | mini_batch_size=500,
59 | algorithm='euler_maruyama_sde',
60 | path=''
61 | )
62 |
63 | return config
64 |
--------------------------------------------------------------------------------
/configs/cifar10_uvit_small.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 |
15 | config.train = d(
16 | n_steps=500000,
17 | batch_size=128,
18 | mode='uncond',
19 | log_interval=10,
20 | eval_interval=5000,
21 | save_interval=50000,
22 | )
23 |
24 | config.optimizer = d(
25 | name='adamw',
26 | lr=0.0002,
27 | weight_decay=0.03,
28 | betas=(0.99, 0.999),
29 | )
30 |
31 | config.lr_scheduler = d(
32 | name='customized',
33 | warmup_steps=2500
34 | )
35 |
36 | config.nnet = d(
37 | name='uvit',
38 | img_size=32,
39 | patch_size=2,
40 | embed_dim=512,
41 | depth=12,
42 | num_heads=8,
43 | mlp_ratio=4,
44 | qkv_bias=False,
45 | mlp_time_embed=False,
46 | num_classes=-1,
47 | )
48 |
49 | config.dataset = d(
50 | name='cifar10',
51 | path='assets/datasets/cifar10',
52 | random_flip=True,
53 | )
54 |
55 | config.sample = d(
56 | sample_steps=1000,
57 | n_samples=50000,
58 | mini_batch_size=500,
59 | algorithm='euler_maruyama_sde',
60 | path=''
61 | )
62 |
63 | return config
64 |
--------------------------------------------------------------------------------
/configs/imagenet256_uvit_huge.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.train = d(
21 | n_steps=500000,
22 | batch_size=1024,
23 | mode='cond',
24 | log_interval=10,
25 | eval_interval=5000,
26 | save_interval=50000,
27 | )
28 |
29 | config.optimizer = d(
30 | name='adamw',
31 | lr=0.0002,
32 | weight_decay=0.03,
33 | betas=(0.99, 0.99),
34 | )
35 |
36 | config.lr_scheduler = d(
37 | name='customized',
38 | warmup_steps=5000
39 | )
40 |
41 | config.nnet = d(
42 | name='uvit',
43 | img_size=32,
44 | patch_size=2,
45 | in_chans=4,
46 | embed_dim=1152,
47 | depth=28,
48 | num_heads=16,
49 | mlp_ratio=4,
50 | qkv_bias=False,
51 | mlp_time_embed=False,
52 | num_classes=1001,
53 | use_checkpoint=True,
54 | conv=False
55 | )
56 |
57 | config.dataset = d(
58 | name='imagenet256_features',
59 | path='assets/datasets/imagenet256_features',
60 | cfg=True,
61 | p_uncond=0.1
62 | )
63 |
64 | config.sample = d(
65 | sample_steps=50,
66 | n_samples=50000,
67 | mini_batch_size=50, # the decoder is large
68 | algorithm='dpm_solver',
69 | cfg=True,
70 | scale=0.4,
71 | path=''
72 | )
73 |
74 | return config
75 |
--------------------------------------------------------------------------------
/configs/imagenet256_uvit_large.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth'
18 | )
19 |
20 | config.train = d(
21 | n_steps=300000,
22 | batch_size=1024,
23 | mode='cond',
24 | log_interval=10,
25 | eval_interval=5000,
26 | save_interval=50000,
27 | )
28 |
29 | config.optimizer = d(
30 | name='adamw',
31 | lr=0.0002,
32 | weight_decay=0.03,
33 | betas=(0.99, 0.99),
34 | )
35 |
36 | config.lr_scheduler = d(
37 | name='customized',
38 | warmup_steps=5000
39 | )
40 |
41 | config.nnet = d(
42 | name='uvit',
43 | img_size=32,
44 | patch_size=2,
45 | in_chans=4,
46 | embed_dim=1024,
47 | depth=20,
48 | num_heads=16,
49 | mlp_ratio=4,
50 | qkv_bias=False,
51 | mlp_time_embed=False,
52 | num_classes=1001,
53 | use_checkpoint=True
54 | )
55 |
56 | config.dataset = d(
57 | name='imagenet256_features',
58 | path='assets/datasets/imagenet256_features',
59 | cfg=True,
60 | p_uncond=0.15
61 | )
62 |
63 | config.sample = d(
64 | sample_steps=50,
65 | n_samples=50000,
66 | mini_batch_size=50, # the decoder is large
67 | algorithm='dpm_solver',
68 | cfg=True,
69 | scale=0.4,
70 | path=''
71 | )
72 |
73 | return config
74 |
--------------------------------------------------------------------------------
/configs/imagenet512_uvit_huge.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 64, 64)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.train = d(
21 | n_steps=500000,
22 | batch_size=1024,
23 | mode='cond',
24 | log_interval=10,
25 | eval_interval=5000,
26 | save_interval=50000,
27 | )
28 |
29 | config.optimizer = d(
30 | name='adamw',
31 | lr=0.0002,
32 | weight_decay=0.03,
33 | betas=(0.99, 0.99),
34 | )
35 |
36 | config.lr_scheduler = d(
37 | name='customized',
38 | warmup_steps=5000
39 | )
40 |
41 | config.nnet = d(
42 | name='uvit',
43 | img_size=64,
44 | patch_size=4,
45 | in_chans=4,
46 | embed_dim=1152,
47 | depth=28,
48 | num_heads=16,
49 | mlp_ratio=4,
50 | qkv_bias=False,
51 | mlp_time_embed=False,
52 | num_classes=1001,
53 | use_checkpoint=True,
54 | conv=False
55 | )
56 |
57 | config.dataset = d(
58 | name='imagenet512_features',
59 | path='assets/datasets/imagenet512_features',
60 | cfg=True,
61 | p_uncond=0.1
62 | )
63 |
64 | config.sample = d(
65 | sample_steps=50,
66 | n_samples=50000,
67 | mini_batch_size=50, # the decoder is large
68 | algorithm='dpm_solver',
69 | cfg=True,
70 | scale=0.7,
71 | path=''
72 | )
73 |
74 | return config
75 |
--------------------------------------------------------------------------------
/configs/imagenet512_uvit_large.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 64, 64)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth'
18 | )
19 |
20 | config.train = d(
21 | n_steps=500000,
22 | batch_size=1024,
23 | mode='cond',
24 | log_interval=10,
25 | eval_interval=5000,
26 | save_interval=50000,
27 | )
28 |
29 | config.optimizer = d(
30 | name='adamw',
31 | lr=0.0002,
32 | weight_decay=0.03,
33 | betas=(0.99, 0.99),
34 | )
35 |
36 | config.lr_scheduler = d(
37 | name='customized',
38 | warmup_steps=5000
39 | )
40 |
41 | config.nnet = d(
42 | name='uvit',
43 | img_size=64,
44 | patch_size=4,
45 | in_chans=4,
46 | embed_dim=1024,
47 | depth=20,
48 | num_heads=16,
49 | mlp_ratio=4,
50 | qkv_bias=False,
51 | mlp_time_embed=False,
52 | num_classes=1001,
53 | use_checkpoint=True
54 | )
55 |
56 | config.dataset = d(
57 | name='imagenet512_features',
58 | path='assets/datasets/imagenet512_features',
59 | cfg=True,
60 | p_uncond=0.15
61 | )
62 |
63 | config.sample = d(
64 | sample_steps=50,
65 | n_samples=50000,
66 | mini_batch_size=50, # the decoder is large
67 | algorithm='dpm_solver',
68 | cfg=True,
69 | scale=0.7,
70 | path=''
71 | )
72 |
73 | return config
74 |
--------------------------------------------------------------------------------
/configs/imagenet64_uvit_large.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 |
15 | config.train = d(
16 | n_steps=300000,
17 | batch_size=1024,
18 | mode='cond',
19 | log_interval=10,
20 | eval_interval=5000,
21 | save_interval=50000,
22 | )
23 |
24 | config.optimizer = d(
25 | name='adamw',
26 | lr=0.0003,
27 | weight_decay=0.03,
28 | betas=(0.99, 0.99),
29 | )
30 |
31 | config.lr_scheduler = d(
32 | name='customized',
33 | warmup_steps=5000
34 | )
35 |
36 | config.nnet = d(
37 | name='uvit',
38 | img_size=64,
39 | patch_size=4,
40 | embed_dim=1024,
41 | depth=20,
42 | num_heads=16,
43 | mlp_ratio=4,
44 | qkv_bias=False,
45 | mlp_time_embed=False,
46 | num_classes=1000,
47 | use_checkpoint=True
48 | )
49 |
50 | config.dataset = d(
51 | name='imagenet',
52 | path='assets/datasets/ImageNet',
53 | resolution=64,
54 | )
55 |
56 | config.sample = d(
57 | sample_steps=50,
58 | n_samples=50000,
59 | mini_batch_size=200,
60 | algorithm='dpm_solver',
61 | path=''
62 | )
63 |
64 | return config
65 |
--------------------------------------------------------------------------------
/configs/imagenet64_uvit_mid.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 |
15 | config.train = d(
16 | n_steps=300000,
17 | batch_size=1024,
18 | mode='cond',
19 | log_interval=10,
20 | eval_interval=5000,
21 | save_interval=50000,
22 | )
23 |
24 | config.optimizer = d(
25 | name='adamw',
26 | lr=0.0003,
27 | weight_decay=0.03,
28 | betas=(0.99, 0.99),
29 | )
30 |
31 | config.lr_scheduler = d(
32 | name='customized',
33 | warmup_steps=5000
34 | )
35 |
36 | config.nnet = d(
37 | name='uvit',
38 | img_size=64,
39 | patch_size=4,
40 | embed_dim=768,
41 | depth=16,
42 | num_heads=12,
43 | mlp_ratio=4,
44 | qkv_bias=False,
45 | mlp_time_embed=False,
46 | num_classes=1000,
47 | use_checkpoint=True
48 | )
49 |
50 | config.dataset = d(
51 | name='imagenet',
52 | path='assets/datasets/ImageNet',
53 | resolution=64,
54 | )
55 |
56 | config.sample = d(
57 | sample_steps=50,
58 | n_samples=50000,
59 | mini_batch_size=200,
60 | algorithm='dpm_solver',
61 | path=''
62 | )
63 |
64 | return config
65 |
--------------------------------------------------------------------------------
/configs/mscoco_uvit_small.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.z_shape = (4, 32, 32)
14 |
15 | config.autoencoder = d(
16 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
17 | scale_factor=0.23010
18 | )
19 |
20 | config.train = d(
21 | n_steps=1000000,
22 | batch_size=256,
23 | log_interval=10,
24 | eval_interval=5000,
25 | save_interval=50000,
26 | )
27 |
28 | config.optimizer = d(
29 | name='adamw',
30 | lr=0.0002,
31 | weight_decay=0.03,
32 | betas=(0.9, 0.9),
33 | )
34 |
35 | config.lr_scheduler = d(
36 | name='customized',
37 | warmup_steps=5000
38 | )
39 |
40 | config.nnet = d(
41 | name='uvit_t2i',
42 | img_size=32,
43 | in_chans=4,
44 | patch_size=2,
45 | embed_dim=512,
46 | depth=12,
47 | num_heads=8,
48 | mlp_ratio=4,
49 | qkv_bias=False,
50 | mlp_time_embed=False,
51 | clip_dim=768,
52 | num_clip_token=77
53 | )
54 |
55 | config.dataset = d(
56 | name='mscoco256_features',
57 | path='assets/datasets/coco256_features',
58 | cfg=True,
59 | p_uncond=0.1
60 | )
61 |
62 | config.sample = d(
63 | sample_steps=50,
64 | n_samples=30000,
65 | mini_batch_size=50,
66 | cfg=True,
67 | scale=1.,
68 | path=''
69 | )
70 |
71 | return config
72 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from torchvision import datasets
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import torch
6 | import math
7 | import random
8 | from PIL import Image
9 | import os
10 | import glob
11 | import einops
12 | import torchvision.transforms.functional as F
13 |
14 |
15 | class UnlabeledDataset(Dataset):
16 | def __init__(self, dataset):
17 | self.dataset = dataset
18 |
19 | def __len__(self):
20 | return len(self.dataset)
21 |
22 | def __getitem__(self, item):
23 | data = tuple(self.dataset[item][:-1]) # remove label
24 | if len(data) == 1:
25 | data = data[0]
26 | return data
27 |
28 |
29 | class LabeledDataset(Dataset):
30 | def __init__(self, dataset, labels):
31 | self.dataset = dataset
32 | self.labels = labels
33 |
34 | def __len__(self):
35 | return len(self.dataset)
36 |
37 | def __getitem__(self, item):
38 | return self.dataset[item], self.labels[item]
39 |
40 |
41 | class CFGDataset(Dataset): # for classifier free guidance
42 | def __init__(self, dataset, p_uncond, empty_token):
43 | self.dataset = dataset
44 | self.p_uncond = p_uncond
45 | self.empty_token = empty_token
46 |
47 | def __len__(self):
48 | return len(self.dataset)
49 |
50 | def __getitem__(self, item):
51 | x, y = self.dataset[item]
52 | if random.random() < self.p_uncond:
53 | y = self.empty_token
54 | return x, y
55 |
56 |
57 | class DatasetFactory(object):
58 |
59 | def __init__(self):
60 | self.train = None
61 | self.test = None
62 |
63 | def get_split(self, split, labeled=False):
64 | if split == "train":
65 | dataset = self.train
66 | elif split == "test":
67 | dataset = self.test
68 | else:
69 | raise ValueError
70 |
71 | if self.has_label:
72 | return dataset if labeled else UnlabeledDataset(dataset)
73 | else:
74 | assert not labeled
75 | return dataset
76 |
77 | def unpreprocess(self, v): # to B C H W and [0, 1]
78 | v = 0.5 * (v + 1.)
79 | v.clamp_(0., 1.)
80 | return v
81 |
82 | @property
83 | def has_label(self):
84 | return True
85 |
86 | @property
87 | def data_shape(self):
88 | raise NotImplementedError
89 |
90 | @property
91 | def data_dim(self):
92 | return int(np.prod(self.data_shape))
93 |
94 | @property
95 | def fid_stat(self):
96 | return None
97 |
98 | def sample_label(self, n_samples, device):
99 | raise NotImplementedError
100 |
101 | def label_prob(self, k):
102 | raise NotImplementedError
103 |
104 |
105 | # CIFAR10
106 |
107 | class CIFAR10(DatasetFactory):
108 | r""" CIFAR10 dataset
109 |
110 | Information of the raw dataset:
111 | train: 50,000
112 | test: 10,000
113 | shape: 3 * 32 * 32
114 | """
115 |
116 | def __init__(self, path, random_flip=False, cfg=False, p_uncond=None):
117 | super().__init__()
118 |
119 | transform_train = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
120 | transform_test = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
121 | if random_flip: # only for train
122 | transform_train.append(transforms.RandomHorizontalFlip())
123 | transform_train = transforms.Compose(transform_train)
124 | transform_test = transforms.Compose(transform_test)
125 | self.train = datasets.CIFAR10(path, train=True, transform=transform_train, download=True)
126 | self.test = datasets.CIFAR10(path, train=False, transform=transform_test, download=True)
127 |
128 | assert len(self.train.targets) == 50000
129 | self.K = max(self.train.targets) + 1
130 | self.cnt = torch.tensor([len(np.where(np.array(self.train.targets) == k)[0]) for k in range(self.K)]).float()
131 | self.frac = [self.cnt[k] / 50000 for k in range(self.K)]
132 | print(f'{self.K} classes')
133 | print(f'cnt: {self.cnt}')
134 | print(f'frac: {self.frac}')
135 |
136 | if cfg: # classifier free guidance
137 | assert p_uncond is not None
138 | print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
139 | self.train = CFGDataset(self.train, p_uncond, self.K)
140 |
141 | @property
142 | def data_shape(self):
143 | return 3, 32, 32
144 |
145 | @property
146 | def fid_stat(self):
147 | return 'assets/fid_stats/fid_stats_cifar10_train_pytorch.npz'
148 |
149 | def sample_label(self, n_samples, device):
150 | return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)
151 |
152 | def label_prob(self, k):
153 | return self.frac[k]
154 |
155 |
156 | # ImageNet
157 |
158 |
159 | class FeatureDataset(Dataset):
160 | def __init__(self, path):
161 | super().__init__()
162 | self.path = path
163 | # names = sorted(os.listdir(path))
164 | # self.files = [os.path.join(path, name) for name in names]
165 |
166 | def __len__(self):
167 |         return 1_281_167 * 2  # twice the ImageNet train set size (features of the horizontally flipped images are stored as additional files)
168 |
169 | def __getitem__(self, idx):
170 | path = os.path.join(self.path, f'{idx}.npy')
171 | z, label = np.load(path, allow_pickle=True)
172 | return z, label
173 |
174 |
175 | class ImageNet256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
176 | def __init__(self, path, cfg=False, p_uncond=None):
177 | super().__init__()
178 | print('Prepare dataset...')
179 | self.train = FeatureDataset(path)
180 | print('Prepare dataset ok')
181 | self.K = 1000
182 |
183 | if cfg: # classifier free guidance
184 | assert p_uncond is not None
185 | print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
186 | self.train = CFGDataset(self.train, p_uncond, self.K)
187 |
188 | @property
189 | def data_shape(self):
190 | return 4, 32, 32
191 |
192 | @property
193 | def fid_stat(self):
194 | return f'assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz'
195 |
196 | def sample_label(self, n_samples, device):
197 | return torch.randint(0, 1000, (n_samples,), device=device)
198 |
199 |
200 | class ImageNet512Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
201 | def __init__(self, path, cfg=False, p_uncond=None):
202 | super().__init__()
203 | print('Prepare dataset...')
204 | self.train = FeatureDataset(path)
205 | print('Prepare dataset ok')
206 | self.K = 1000
207 |
208 | if cfg: # classifier free guidance
209 | assert p_uncond is not None
210 | print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
211 | self.train = CFGDataset(self.train, p_uncond, self.K)
212 |
213 | @property
214 | def data_shape(self):
215 | return 4, 64, 64
216 |
217 | @property
218 | def fid_stat(self):
219 | return f'assets/fid_stats/fid_stats_imagenet512_guided_diffusion.npz'
220 |
221 | def sample_label(self, n_samples, device):
222 | return torch.randint(0, 1000, (n_samples,), device=device)
223 |
224 |
225 | class ImageNet(DatasetFactory):
226 | def __init__(self, path, resolution, random_crop=False, random_flip=True):
227 | super().__init__()
228 |
229 | print(f'Counting ImageNet files from {path}')
230 | train_files = _list_image_files_recursively(os.path.join(path, 'train'))
231 | class_names = [os.path.basename(path).split("_")[0] for path in train_files]
232 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
233 | train_labels = [sorted_classes[x] for x in class_names]
234 | print('Finish counting ImageNet files')
235 |
236 | self.train = ImageDataset(resolution, train_files, labels=train_labels, random_crop=random_crop, random_flip=random_flip)
237 | self.resolution = resolution
238 | if len(self.train) != 1_281_167:
239 | print(f'Missing train samples: {len(self.train)} < 1281167')
240 |
241 | self.K = max(self.train.labels) + 1
242 | cnt = dict(zip(*np.unique(self.train.labels, return_counts=True)))
243 | self.cnt = torch.tensor([cnt[k] for k in range(self.K)]).float()
244 | self.frac = [self.cnt[k] / len(self.train.labels) for k in range(self.K)]
245 | print(f'{self.K} classes')
246 | print(f'cnt[:10]: {self.cnt[:10]}')
247 | print(f'frac[:10]: {self.frac[:10]}')
248 |
249 | @property
250 | def data_shape(self):
251 | return 3, self.resolution, self.resolution
252 |
253 | @property
254 | def fid_stat(self):
255 | return f'assets/fid_stats/fid_stats_imagenet{self.resolution}_guided_diffusion.npz'
256 |
257 | def sample_label(self, n_samples, device):
258 | return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)
259 |
260 | def label_prob(self, k):
261 | return self.frac[k]
262 |
263 |
264 | def _list_image_files_recursively(data_dir):
265 | results = []
266 | for entry in sorted(os.listdir(data_dir)):
267 | full_path = os.path.join(data_dir, entry)
268 | ext = entry.split(".")[-1]
269 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
270 | results.append(full_path)
271 |         elif os.path.isdir(full_path):  # recurse into subdirectories
272 | results.extend(_list_image_files_recursively(full_path))
273 | return results
274 |
275 |
276 | class ImageDataset(Dataset):
277 | def __init__(
278 | self,
279 | resolution,
280 | image_paths,
281 | labels,
282 | random_crop=False,
283 | random_flip=True,
284 | ):
285 | super().__init__()
286 | self.resolution = resolution
287 | self.image_paths = image_paths
288 | self.labels = labels
289 | self.random_crop = random_crop
290 | self.random_flip = random_flip
291 |
292 | def __len__(self):
293 | return len(self.image_paths)
294 |
295 | def __getitem__(self, idx):
296 | path = self.image_paths[idx]
297 | pil_image = Image.open(path)
298 | pil_image.load()
299 | pil_image = pil_image.convert("RGB")
300 |
301 | if self.random_crop:
302 | arr = random_crop_arr(pil_image, self.resolution)
303 | else:
304 | arr = center_crop_arr(pil_image, self.resolution)
305 |
306 | if self.random_flip and random.random() < 0.5:
307 | arr = arr[:, ::-1]
308 |
309 | arr = arr.astype(np.float32) / 127.5 - 1
310 |
311 | label = np.array(self.labels[idx], dtype=np.int64)
312 | return np.transpose(arr, [2, 0, 1]), label
313 |
314 |
315 | def center_crop_arr(pil_image, image_size):
316 | # We are not on a new enough PIL to support the `reducing_gap`
317 | # argument, which uses BOX downsampling at powers of two first.
318 | # Thus, we do it by hand to improve downsample quality.
319 | while min(*pil_image.size) >= 2 * image_size:
320 | pil_image = pil_image.resize(
321 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
322 | )
323 |
324 | scale = image_size / min(*pil_image.size)
325 | pil_image = pil_image.resize(
326 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
327 | )
328 |
329 | arr = np.array(pil_image)
330 | crop_y = (arr.shape[0] - image_size) // 2
331 | crop_x = (arr.shape[1] - image_size) // 2
332 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
333 |
334 |
335 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
336 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
337 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
338 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
339 |
340 | # We are not on a new enough PIL to support the `reducing_gap`
341 | # argument, which uses BOX downsampling at powers of two first.
342 | # Thus, we do it by hand to improve downsample quality.
343 | while min(*pil_image.size) >= 2 * smaller_dim_size:
344 | pil_image = pil_image.resize(
345 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
346 | )
347 |
348 | scale = smaller_dim_size / min(*pil_image.size)
349 | pil_image = pil_image.resize(
350 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
351 | )
352 |
353 | arr = np.array(pil_image)
354 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
355 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
356 | return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
357 |
358 |
359 | # CelebA
360 |
361 |
362 | class Crop(object):
363 | def __init__(self, x1, x2, y1, y2):
364 | self.x1 = x1
365 | self.x2 = x2
366 | self.y1 = y1
367 | self.y2 = y2
368 |
369 | def __call__(self, img):
370 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
371 |
372 | def __repr__(self):
373 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
374 | self.x1, self.x2, self.y1, self.y2
375 | )
376 |
377 |
378 | class CelebA(DatasetFactory):
379 | r""" train: 162,770
380 | val: 19,867
381 | test: 19,962
382 | shape: 3 * width * width
383 | """
384 |
385 | def __init__(self, path, resolution=64):
386 | super().__init__()
387 |
388 | self.resolution = resolution
389 |
390 | cx = 89
391 | cy = 121
392 | x1 = cy - 64
393 | x2 = cy + 64
394 | y1 = cx - 64
395 | y2 = cx + 64
396 |
397 | transform = transforms.Compose([Crop(x1, x2, y1, y2), transforms.Resize(self.resolution),
398 | transforms.RandomHorizontalFlip(), transforms.ToTensor(),
399 | transforms.Normalize(0.5, 0.5)])
400 | self.train = datasets.CelebA(root=path, split="train", target_type=[], transform=transform, download=True)
401 | self.train = UnlabeledDataset(self.train)
402 |
403 | @property
404 | def data_shape(self):
405 | return 3, self.resolution, self.resolution
406 |
407 | @property
408 | def fid_stat(self):
409 | return 'assets/fid_stats/fid_stats_celeba64_train_50000_ddim.npz'
410 |
411 | @property
412 | def has_label(self):
413 | return False
414 |
415 |
416 | # MS COCO
417 |
418 |
419 | def center_crop(width, height, img):
420 | resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
421 | crop = np.min(img.shape[:2])
422 | img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
423 | (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
424 | try:
425 | img = Image.fromarray(img, 'RGB')
426 | except:
427 | img = Image.fromarray(img)
428 | img = img.resize((width, height), resample)
429 |
430 | return np.array(img).astype(np.uint8)
431 |
432 |
433 | class MSCOCODatabase(Dataset):
434 | def __init__(self, root, annFile, size=None):
435 | from pycocotools.coco import COCO
436 | self.root = root
437 | self.height = self.width = size
438 |
439 | self.coco = COCO(annFile)
440 | self.keys = list(sorted(self.coco.imgs.keys()))
441 |
442 | def _load_image(self, key: int):
443 | path = self.coco.loadImgs(key)[0]["file_name"]
444 | return Image.open(os.path.join(self.root, path)).convert("RGB")
445 |
446 | def _load_target(self, key: int):
447 | return self.coco.loadAnns(self.coco.getAnnIds(key))
448 |
449 | def __len__(self):
450 | return len(self.keys)
451 |
452 | def __getitem__(self, index):
453 | key = self.keys[index]
454 | image = self._load_image(key)
455 | image = np.array(image).astype(np.uint8)
456 | image = center_crop(self.width, self.height, image).astype(np.float32)
457 | image = (image / 127.5 - 1.0).astype(np.float32)
458 | image = einops.rearrange(image, 'h w c -> c h w')
459 |
460 | anns = self._load_target(key)
461 | target = []
462 | for ann in anns:
463 | target.append(ann['caption'])
464 |
465 | return image, target
466 |
467 |
468 | def get_feature_dir_info(root):
469 | files = glob.glob(os.path.join(root, '*.npy'))
470 | files_caption = glob.glob(os.path.join(root, '*_*.npy'))
471 | num_data = len(files) - len(files_caption)
472 | n_captions = {k: 0 for k in range(num_data)}
473 | for f in files_caption:
474 | name = os.path.split(f)[-1]
475 | k1, k2 = os.path.splitext(name)[0].split('_')
476 | n_captions[int(k1)] += 1
477 | return num_data, n_captions
478 |
479 |
480 | class MSCOCOFeatureDataset(Dataset):
481 |     # image latents obtained via sampling; '{index}.npy' holds the latent and '{index}_{k}.npy' the k-th caption context
482 | def __init__(self, root):
483 | self.root = root
484 | self.num_data, self.n_captions = get_feature_dir_info(root)
485 |
486 | def __len__(self):
487 | return self.num_data
488 |
489 | def __getitem__(self, index):
490 | z = np.load(os.path.join(self.root, f'{index}.npy'))
491 | k = random.randint(0, self.n_captions[index] - 1)
492 | c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
493 | return z, c
494 |
495 |
496 | class MSCOCO256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
497 | def __init__(self, path, cfg=False, p_uncond=None):
498 | super().__init__()
499 | print('Prepare dataset...')
500 | self.train = MSCOCOFeatureDataset(os.path.join(path, 'train'))
501 | self.test = MSCOCOFeatureDataset(os.path.join(path, 'val'))
502 | assert len(self.train) == 82783
503 | assert len(self.test) == 40504
504 | print('Prepare dataset ok')
505 |
506 | self.empty_context = np.load(os.path.join(path, 'empty_context.npy'))
507 |
508 | if cfg: # classifier free guidance
509 | assert p_uncond is not None
510 | print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
511 | self.train = CFGDataset(self.train, p_uncond, self.empty_context)
512 |
513 |         # text embeddings extracted by CLIP
514 |         # for visualization during t2i training
515 | self.prompts, self.contexts = [], []
516 | for f in sorted(os.listdir(os.path.join(path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
517 | prompt, context = np.load(os.path.join(path, 'run_vis', f), allow_pickle=True)
518 | self.prompts.append(prompt)
519 | self.contexts.append(context)
520 | self.contexts = np.array(self.contexts)
521 |
522 | @property
523 | def data_shape(self):
524 | return 4, 32, 32
525 |
526 | @property
527 | def fid_stat(self):
528 | return f'assets/fid_stats/fid_stats_mscoco256_val.npz'
529 |
530 |
531 | def get_dataset(name, **kwargs):
532 | if name == 'cifar10':
533 | return CIFAR10(**kwargs)
534 | elif name == 'imagenet':
535 | return ImageNet(**kwargs)
536 | elif name == 'imagenet256_features':
537 | return ImageNet256Features(**kwargs)
538 | elif name == 'imagenet512_features':
539 | return ImageNet512Features(**kwargs)
540 | elif name == 'celeba':
541 | return CelebA(**kwargs)
542 | elif name == 'mscoco256_features':
543 | return MSCOCO256Features(**kwargs)
544 | else:
545 | raise NotImplementedError(name)
546 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | import sde
8 | from datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
11 | from absl import logging
12 | import builtins
13 |
14 |
15 | def evaluate(config):
16 | if config.get('benchmark', False):
17 | torch.backends.cudnn.benchmark = True
18 | torch.backends.cudnn.deterministic = False
19 |
20 | mp.set_start_method('spawn')
21 | accelerator = accelerate.Accelerator()
22 | device = accelerator.device
23 | accelerate.utils.set_seed(config.seed, device_specific=True)
24 | logging.info(f'Process {accelerator.process_index} using device: {device}')
25 |
26 | config.mixed_precision = accelerator.mixed_precision
27 | config = ml_collections.FrozenConfigDict(config)
28 | if accelerator.is_main_process:
29 | utils.set_logger(log_level='info', fname=config.output_path)
30 | else:
31 | utils.set_logger(log_level='error')
32 | builtins.print = lambda *args: None
33 |
34 | dataset = get_dataset(**config.dataset)
35 |
36 | nnet = utils.get_nnet(**config.nnet)
37 | nnet = accelerator.prepare(nnet)
38 | logging.info(f'load nnet from {config.nnet_path}')
39 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
40 | nnet.eval()
41 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
42 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
43 | def cfg_nnet(x, timesteps, y):
44 | _cond = nnet(x, timesteps, y=y)
45 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
46 |                 return _cond + config.sample.scale * (_cond - _uncond)  # guided prediction: eps_cond + s * (eps_cond - eps_uncond)
47 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
48 | else:
49 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
50 |
51 |
52 | logging.info(config.sample)
53 | assert os.path.exists(dataset.fid_stat)
54 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
55 |
56 | def sample_fn(_n_samples):
57 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
58 | if config.train.mode == 'uncond':
59 | kwargs = dict()
60 | elif config.train.mode == 'cond':
61 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
62 | else:
63 | raise NotImplementedError
64 |
65 | if config.sample.algorithm == 'euler_maruyama_sde':
66 | rsde = sde.ReverseSDE(score_model)
67 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
68 | elif config.sample.algorithm == 'euler_maruyama_ode':
69 | rsde = sde.ODE(score_model)
70 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
71 | elif config.sample.algorithm == 'dpm_solver':
72 | noise_schedule = NoiseScheduleVP(schedule='linear')
73 | model_fn = model_wrapper(
74 | score_model.noise_pred,
75 | noise_schedule,
76 | time_input_type='0',
77 | model_kwargs=kwargs
78 | )
79 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
80 | return dpm_solver.sample(
81 | x_init,
82 | steps=config.sample.sample_steps,
83 | eps=1e-4,
84 | adaptive_step_size=False,
85 | fast_version=True,
86 | )
87 | else:
88 | raise NotImplementedError
89 |
90 | with tempfile.TemporaryDirectory() as temp_path:
91 | path = config.sample.path or temp_path
92 | if accelerator.is_main_process:
93 | os.makedirs(path, exist_ok=True)
94 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
95 | if accelerator.is_main_process:
96 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
97 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
98 |
99 |
100 | from absl import flags
101 | from absl import app
102 | from ml_collections import config_flags
103 | import os
104 |
105 |
106 | FLAGS = flags.FLAGS
107 | config_flags.DEFINE_config_file(
108 | "config", None, "Training configuration.", lock_config=False)
109 | flags.mark_flags_as_required(["config"])
110 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
111 | flags.DEFINE_string("output_path", None, "The path to output log.")
112 |
113 |
114 | def main(argv):
115 | config = FLAGS.config
116 | config.nnet_path = FLAGS.nnet_path
117 | config.output_path = FLAGS.output_path
118 | evaluate(config)
119 |
120 |
121 | if __name__ == "__main__":
122 | app.run(main)
123 |
--------------------------------------------------------------------------------
/eval_ldm.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | import sde
8 | from datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
11 | from absl import logging
12 | import builtins
13 | import libs.autoencoder
14 |
15 |
16 | def evaluate(config):
17 | if config.get('benchmark', False):
18 | torch.backends.cudnn.benchmark = True
19 | torch.backends.cudnn.deterministic = False
20 |
21 | mp.set_start_method('spawn')
22 | accelerator = accelerate.Accelerator()
23 | device = accelerator.device
24 | accelerate.utils.set_seed(config.seed, device_specific=True)
25 | logging.info(f'Process {accelerator.process_index} using device: {device}')
26 |
27 | config.mixed_precision = accelerator.mixed_precision
28 | config = ml_collections.FrozenConfigDict(config)
29 | if accelerator.is_main_process:
30 | utils.set_logger(log_level='info', fname=config.output_path)
31 | else:
32 | utils.set_logger(log_level='error')
33 | builtins.print = lambda *args: None
34 |
35 | dataset = get_dataset(**config.dataset)
36 |
37 | nnet = utils.get_nnet(**config.nnet)
38 | nnet = accelerator.prepare(nnet)
39 | logging.info(f'load nnet from {config.nnet_path}')
40 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
41 | nnet.eval()
42 |
43 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
44 | autoencoder.to(device)
45 |
46 | @torch.cuda.amp.autocast()
47 | def encode(_batch):
48 | return autoencoder.encode(_batch)
49 |
50 | @torch.cuda.amp.autocast()
51 | def decode(_batch):
52 | return autoencoder.decode(_batch)
53 |
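   |     # Decode a large latent batch in chunks of at most `decode_mini_batch_size`
   |     # (utils.amortize yields chunk sizes that sum to the full batch), so the
   |     # memory-hungry VAE decoder never sees the whole batch at once.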
54 | def decode_large_batch(_batch):
55 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
56 | xs = []
57 | pt = 0
58 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
59 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
60 | pt += _decode_mini_batch_size
61 | xs.append(x)
62 | xs = torch.concat(xs, dim=0)
63 | assert xs.size(0) == _batch.size(0)
64 | return xs
65 |
66 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
67 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
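   |         # Classifier-free guidance: the "unconditional" branch reuses the extra
   |         # class index dataset.K as a learned null label, and the two outputs are
   |         # combined as cond + scale * (cond - uncond), i.e. scale plays the role
   |         # of w - 1 in the usual eps_uncond + w * (eps_cond - eps_uncond) form.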
68 | def cfg_nnet(x, timesteps, y):
69 | _cond = nnet(x, timesteps, y=y)
70 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
71 | return _cond + config.sample.scale * (_cond - _uncond)
72 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
73 | else:
74 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
75 |
76 | logging.info(config.sample)
77 | assert os.path.exists(dataset.fid_stat)
78 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
79 |
80 | def sample_fn(_n_samples):
81 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
82 | if config.train.mode == 'uncond':
83 | kwargs = dict()
84 | elif config.train.mode == 'cond':
85 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
86 | else:
87 | raise NotImplementedError
88 |
89 | if config.sample.algorithm == 'euler_maruyama_sde':
90 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
91 | elif config.sample.algorithm == 'euler_maruyama_ode':
92 | _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
93 | elif config.sample.algorithm == 'dpm_solver':
94 | noise_schedule = NoiseScheduleVP(schedule='linear')
95 | model_fn = model_wrapper(
96 | score_model.noise_pred,
97 | noise_schedule,
98 | time_input_type='0',
99 | model_kwargs=kwargs
100 | )
101 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
102 | _z = dpm_solver.sample(
103 | _z_init,
104 | steps=config.sample.sample_steps,
105 | eps=1e-4,
106 | adaptive_step_size=False,
107 | fast_version=True,
108 | )
109 | else:
110 | raise NotImplementedError
111 | return decode_large_batch(_z)
112 |
113 | with tempfile.TemporaryDirectory() as temp_path:
114 | path = config.sample.path or temp_path
115 | if accelerator.is_main_process:
116 | os.makedirs(path, exist_ok=True)
117 | logging.info(f'Samples are saved in {path}')
118 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
119 | if accelerator.is_main_process:
120 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
121 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
122 |
123 |
124 | from absl import flags
125 | from absl import app
126 | from ml_collections import config_flags
127 | import os
128 |
129 |
130 | FLAGS = flags.FLAGS
131 | config_flags.DEFINE_config_file(
132 | "config", None, "Training configuration.", lock_config=False)
133 | flags.mark_flags_as_required(["config"])
134 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
135 | flags.DEFINE_string("output_path", None, "The path to output log.")
136 |
137 |
138 | def main(argv):
139 | config = FLAGS.config
140 | config.nnet_path = FLAGS.nnet_path
141 | config.output_path = FLAGS.output_path
142 | evaluate(config)
143 |
144 |
145 | if __name__ == "__main__":
146 | app.run(main)
147 |
--------------------------------------------------------------------------------
/eval_ldm_discrete.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | from datasets import get_dataset
8 | import tempfile
9 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
10 | from absl import logging
11 | import builtins
12 | import libs.autoencoder
13 |
14 |
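   | # Stable Diffusion's "scaled linear" noise schedule: interpolate linearly in
   | # sqrt(beta) between sqrt(linear_start) and sqrt(linear_end) over n_timestep
   | # steps, then square.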
15 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
16 | _betas = (
17 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
18 | )
19 | return _betas.numpy()
20 |
21 |
22 | def evaluate(config):
23 | if config.get('benchmark', False):
24 | torch.backends.cudnn.benchmark = True
25 | torch.backends.cudnn.deterministic = False
26 |
27 | mp.set_start_method('spawn')
28 | accelerator = accelerate.Accelerator()
29 | device = accelerator.device
30 | accelerate.utils.set_seed(config.seed, device_specific=True)
31 | logging.info(f'Process {accelerator.process_index} using device: {device}')
32 |
33 | config.mixed_precision = accelerator.mixed_precision
34 | config = ml_collections.FrozenConfigDict(config)
35 | if accelerator.is_main_process:
36 | utils.set_logger(log_level='info', fname=config.output_path)
37 | else:
38 | utils.set_logger(log_level='error')
39 | builtins.print = lambda *args: None
40 |
41 | dataset = get_dataset(**config.dataset)
42 |
43 | nnet = utils.get_nnet(**config.nnet)
44 | nnet = accelerator.prepare(nnet)
45 | logging.info(f'load nnet from {config.nnet_path}')
46 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
47 | nnet.eval()
48 |
49 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
50 | autoencoder.to(device)
51 |
52 | @torch.cuda.amp.autocast()
53 | def encode(_batch):
54 | return autoencoder.encode(_batch)
55 |
56 | @torch.cuda.amp.autocast()
57 | def decode(_batch):
58 | return autoencoder.decode(_batch)
59 |
60 | def decode_large_batch(_batch):
61 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
62 | xs = []
63 | pt = 0
64 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
65 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
66 | pt += _decode_mini_batch_size
67 | xs.append(x)
68 | xs = torch.concat(xs, dim=0)
69 | assert xs.size(0) == _batch.size(0)
70 | return xs
71 |
72 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
73 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
74 | def cfg_nnet(x, timesteps, y):
75 | _cond = nnet(x, timesteps, y=y)
76 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
77 | return _cond + config.sample.scale * (_cond - _uncond)
78 | else:
79 | def cfg_nnet(x, timesteps, y):
80 | _cond = nnet(x, timesteps, y=y)
81 | return _cond
82 |
83 | logging.info(config.sample)
84 | assert os.path.exists(dataset.fid_stat)
85 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
86 |
87 | _betas = stable_diffusion_beta_schedule()
88 | N = len(_betas)
89 |
90 | def sample_z(_n_samples, _sample_steps, **kwargs):
91 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
92 |
93 | if config.sample.algorithm == 'dpm_solver':
94 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
95 |
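   |             # DPM-Solver integrates in continuous time on (0, 1]; the network was
   |             # trained on discrete steps, so t_continuous is scaled by N to recover
   |             # the (roughly) discrete timestep index the network expects.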
96 | def model_fn(x, t_continuous):
97 | t = t_continuous * N
98 | eps_pre = cfg_nnet(x, t, **kwargs)
99 | return eps_pre
100 |
101 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
102 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.)
103 |
104 | else:
105 | raise NotImplementedError
106 |
107 | return _z
108 |
109 | def sample_fn(_n_samples):
110 | if config.train.mode == 'uncond':
111 | kwargs = dict()
112 | elif config.train.mode == 'cond':
113 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
114 | else:
115 | raise NotImplementedError
116 | _z = sample_z(_n_samples, _sample_steps=config.sample.sample_steps, **kwargs)
117 | return decode_large_batch(_z)
118 |
119 | with tempfile.TemporaryDirectory() as temp_path:
120 | path = config.sample.path or temp_path
121 | if accelerator.is_main_process:
122 | os.makedirs(path, exist_ok=True)
123 | logging.info(f'Samples are saved in {path}')
124 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
125 | if accelerator.is_main_process:
126 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
127 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
128 |
129 |
130 | from absl import flags
131 | from absl import app
132 | from ml_collections import config_flags
133 | import os
134 |
135 |
136 | FLAGS = flags.FLAGS
137 | config_flags.DEFINE_config_file(
138 | "config", None, "Training configuration.", lock_config=False)
139 | flags.mark_flags_as_required(["config"])
140 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
141 | flags.DEFINE_string("output_path", None, "The path to output log.")
142 |
143 |
144 | def main(argv):
145 | config = FLAGS.config
146 | config.nnet_path = FLAGS.nnet_path
147 | config.output_path = FLAGS.output_path
148 | evaluate(config)
149 |
150 |
151 | if __name__ == "__main__":
152 | app.run(main)
153 |
--------------------------------------------------------------------------------
/eval_t2i_discrete.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | from torch.utils.data import DataLoader
7 | import utils
8 | from datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
11 | from absl import logging
12 | import builtins
13 | import einops
14 | import libs.autoencoder
15 |
16 |
17 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
18 | _betas = (
19 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
20 | )
21 | return _betas.numpy()
22 |
23 |
24 | def evaluate(config):
25 | if config.get('benchmark', False):
26 | torch.backends.cudnn.benchmark = True
27 | torch.backends.cudnn.deterministic = False
28 |
29 | mp.set_start_method('spawn')
30 | accelerator = accelerate.Accelerator()
31 | device = accelerator.device
32 | accelerate.utils.set_seed(config.seed, device_specific=True)
33 | logging.info(f'Process {accelerator.process_index} using device: {device}')
34 |
35 | config.mixed_precision = accelerator.mixed_precision
36 | config = ml_collections.FrozenConfigDict(config)
37 | if accelerator.is_main_process:
38 | utils.set_logger(log_level='info', fname=config.output_path)
39 | else:
40 | utils.set_logger(log_level='error')
41 | builtins.print = lambda *args: None
42 |
43 | dataset = get_dataset(**config.dataset)
44 | test_dataset = dataset.get_split(split='test', labeled=True) # for sampling
45 | test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True,
46 | drop_last=True, num_workers=8, pin_memory=True, persistent_workers=True)
47 |
48 | nnet = utils.get_nnet(**config.nnet)
49 | nnet, test_dataset_loader = accelerator.prepare(nnet, test_dataset_loader)
50 | logging.info(f'load nnet from {config.nnet_path}')
51 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
52 | nnet.eval()
53 |
54 | def cfg_nnet(x, timesteps, context):
55 | _cond = nnet(x, timesteps, context=context)
56 | if config.sample.scale == 0:
57 | return _cond
58 | _empty_context = torch.tensor(dataset.empty_context, device=device)
59 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
60 | _uncond = nnet(x, timesteps, context=_empty_context)
61 | return _cond + config.sample.scale * (_cond - _uncond)
62 |
63 | autoencoder = libs.autoencoder.get_model(**config.autoencoder)
64 | autoencoder.to(device)
65 |
66 | @torch.cuda.amp.autocast()
67 | def encode(_batch):
68 | return autoencoder.encode(_batch)
69 |
70 | @torch.cuda.amp.autocast()
71 | def decode(_batch):
72 | return autoencoder.decode(_batch)
73 |
74 | def decode_large_batch(_batch):
75 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
76 | xs = []
77 | pt = 0
78 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
79 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
80 | pt += _decode_mini_batch_size
81 | xs.append(x)
82 | xs = torch.concat(xs, dim=0)
83 | assert xs.size(0) == _batch.size(0)
84 | return xs
85 |
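   |     # Cycle over the test split indefinitely, yielding one batch of CLIP text
   |     # features per call, so sample_fn always has fresh ground-truth prompts to
   |     # condition on.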
86 | def get_context_generator():
87 | while True:
88 | for data in test_dataset_loader:
89 | _, _context = data
90 | yield _context
91 |
92 | context_generator = get_context_generator()
93 |
94 | _betas = stable_diffusion_beta_schedule()
95 | N = len(_betas)
96 |
97 | logging.info(config.sample)
98 | assert os.path.exists(dataset.fid_stat)
99 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
100 |
101 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
102 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
103 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
104 |
105 | def model_fn(x, t_continuous):
106 | t = t_continuous * N
107 | return cfg_nnet(x, t, **kwargs)
108 |
109 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
110 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.)
111 | return decode_large_batch(_z)
112 |
113 | def sample_fn(_n_samples):
114 | _context = next(context_generator)
115 | assert _context.size(0) == _n_samples
116 | return dpm_solver_sample(_n_samples, config.sample.sample_steps, context=_context)
117 |
118 | with tempfile.TemporaryDirectory() as temp_path:
119 | path = config.sample.path or temp_path
120 | if accelerator.is_main_process:
121 | os.makedirs(path, exist_ok=True)
122 | logging.info(f'Samples are saved in {path}')
123 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
124 | if accelerator.is_main_process:
125 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
126 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
127 |
128 |
129 | from absl import flags
130 | from absl import app
131 | from ml_collections import config_flags
132 | import os
133 |
134 |
135 | FLAGS = flags.FLAGS
136 | config_flags.DEFINE_config_file(
137 | "config", None, "Training configuration.", lock_config=False)
138 | flags.mark_flags_as_required(["config"])
139 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
140 | flags.DEFINE_string("output_path", None, "The path to output log.")
141 |
142 |
143 | def main(argv):
144 | config = FLAGS.config
145 | config.nnet_path = FLAGS.nnet_path
146 | config.output_path = FLAGS.output_path
147 | evaluate(config)
148 |
149 |
150 | if __name__ == "__main__":
151 | app.run(main)
152 |
--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
1 | # code from third parties
2 |
--------------------------------------------------------------------------------
/libs/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from einops import rearrange
5 |
6 |
7 | class LinearAttention(nn.Module):
8 | def __init__(self, dim, heads=4, dim_head=32):
9 | super().__init__()
10 | self.heads = heads
11 | hidden_dim = dim_head * heads
12 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
13 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
14 |
15 | def forward(self, x):
16 | b, c, h, w = x.shape
17 | qkv = self.to_qkv(x)
18 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
19 | k = k.softmax(dim=-1)
20 | context = torch.einsum('bhdn,bhen->bhde', k, v)
21 | out = torch.einsum('bhde,bhdn->bhen', context, q)
22 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
23 | return self.to_out(out)
24 |
25 |
26 | def nonlinearity(x):
27 | # swish
28 | return x*torch.sigmoid(x)
29 |
30 |
31 | def Normalize(in_channels, num_groups=32):
32 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
33 |
34 |
35 | class Upsample(nn.Module):
36 | def __init__(self, in_channels, with_conv):
37 | super().__init__()
38 | self.with_conv = with_conv
39 | if self.with_conv:
40 | self.conv = torch.nn.Conv2d(in_channels,
41 | in_channels,
42 | kernel_size=3,
43 | stride=1,
44 | padding=1)
45 |
46 | def forward(self, x):
47 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
48 | if self.with_conv:
49 | x = self.conv(x)
50 | return x
51 |
52 |
53 | class Downsample(nn.Module):
54 | def __init__(self, in_channels, with_conv):
55 | super().__init__()
56 | self.with_conv = with_conv
57 | if self.with_conv:
58 | # no asymmetric padding in torch conv, must do it ourselves
59 | self.conv = torch.nn.Conv2d(in_channels,
60 | in_channels,
61 | kernel_size=3,
62 | stride=2,
63 | padding=0)
64 |
65 | def forward(self, x):
66 | if self.with_conv:
67 | pad = (0,1,0,1)
68 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
69 | x = self.conv(x)
70 | else:
71 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
72 | return x
73 |
74 |
75 | class ResnetBlock(nn.Module):
76 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
77 | dropout, temb_channels=512):
78 | super().__init__()
79 | self.in_channels = in_channels
80 | out_channels = in_channels if out_channels is None else out_channels
81 | self.out_channels = out_channels
82 | self.use_conv_shortcut = conv_shortcut
83 |
84 | self.norm1 = Normalize(in_channels)
85 | self.conv1 = torch.nn.Conv2d(in_channels,
86 | out_channels,
87 | kernel_size=3,
88 | stride=1,
89 | padding=1)
90 | if temb_channels > 0:
91 | self.temb_proj = torch.nn.Linear(temb_channels,
92 | out_channels)
93 | self.norm2 = Normalize(out_channels)
94 | self.dropout = torch.nn.Dropout(dropout)
95 | self.conv2 = torch.nn.Conv2d(out_channels,
96 | out_channels,
97 | kernel_size=3,
98 | stride=1,
99 | padding=1)
100 | if self.in_channels != self.out_channels:
101 | if self.use_conv_shortcut:
102 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
103 | out_channels,
104 | kernel_size=3,
105 | stride=1,
106 | padding=1)
107 | else:
108 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
109 | out_channels,
110 | kernel_size=1,
111 | stride=1,
112 | padding=0)
113 |
114 | def forward(self, x, temb):
115 | h = x
116 | h = self.norm1(h)
117 | h = nonlinearity(h)
118 | h = self.conv1(h)
119 |
120 | if temb is not None:
121 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
122 |
123 | h = self.norm2(h)
124 | h = nonlinearity(h)
125 | h = self.dropout(h)
126 | h = self.conv2(h)
127 |
128 | if self.in_channels != self.out_channels:
129 | if self.use_conv_shortcut:
130 | x = self.conv_shortcut(x)
131 | else:
132 | x = self.nin_shortcut(x)
133 |
134 | return x+h
135 |
136 |
137 | class LinAttnBlock(LinearAttention):
138 | """to match AttnBlock usage"""
139 | def __init__(self, in_channels):
140 | super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
141 |
142 |
143 | class AttnBlock(nn.Module):
144 | def __init__(self, in_channels):
145 | super().__init__()
146 | self.in_channels = in_channels
147 |
148 | self.norm = Normalize(in_channels)
149 | self.q = torch.nn.Conv2d(in_channels,
150 | in_channels,
151 | kernel_size=1,
152 | stride=1,
153 | padding=0)
154 | self.k = torch.nn.Conv2d(in_channels,
155 | in_channels,
156 | kernel_size=1,
157 | stride=1,
158 | padding=0)
159 | self.v = torch.nn.Conv2d(in_channels,
160 | in_channels,
161 | kernel_size=1,
162 | stride=1,
163 | padding=0)
164 | self.proj_out = torch.nn.Conv2d(in_channels,
165 | in_channels,
166 | kernel_size=1,
167 | stride=1,
168 | padding=0)
169 |
170 |
171 | def forward(self, x):
172 | h_ = x
173 | h_ = self.norm(h_)
174 | q = self.q(h_)
175 | k = self.k(h_)
176 | v = self.v(h_)
177 |
178 | # compute attention
179 | b,c,h,w = q.shape
180 | q = q.reshape(b,c,h*w)
181 | q = q.permute(0,2,1) # b,hw,c
182 | k = k.reshape(b,c,h*w) # b,c,hw
183 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
184 | w_ = w_ * (int(c)**(-0.5))
185 | w_ = torch.nn.functional.softmax(w_, dim=2)
186 |
187 | # attend to values
188 | v = v.reshape(b,c,h*w)
189 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
190 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
191 | h_ = h_.reshape(b,c,h,w)
192 |
193 | h_ = self.proj_out(h_)
194 |
195 | return x+h_
196 |
197 |
198 | def make_attn(in_channels, attn_type="vanilla"):
199 | assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
200 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
201 | if attn_type == "vanilla":
202 | return AttnBlock(in_channels)
203 | elif attn_type == "none":
204 | return nn.Identity(in_channels)
205 | else:
206 | return LinAttnBlock(in_channels)
207 |
208 |
209 | class Encoder(nn.Module):
210 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
211 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
212 | resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
213 | **ignore_kwargs):
214 | super().__init__()
215 | if use_linear_attn: attn_type = "linear"
216 | self.ch = ch
217 | self.temb_ch = 0
218 | self.num_resolutions = len(ch_mult)
219 | self.num_res_blocks = num_res_blocks
220 | self.resolution = resolution
221 | self.in_channels = in_channels
222 |
223 | # downsampling
224 | self.conv_in = torch.nn.Conv2d(in_channels,
225 | self.ch,
226 | kernel_size=3,
227 | stride=1,
228 | padding=1)
229 |
230 | curr_res = resolution
231 | in_ch_mult = (1,)+tuple(ch_mult)
232 | self.in_ch_mult = in_ch_mult
233 | self.down = nn.ModuleList()
234 | for i_level in range(self.num_resolutions):
235 | block = nn.ModuleList()
236 | attn = nn.ModuleList()
237 | block_in = ch*in_ch_mult[i_level]
238 | block_out = ch*ch_mult[i_level]
239 | for i_block in range(self.num_res_blocks):
240 | block.append(ResnetBlock(in_channels=block_in,
241 | out_channels=block_out,
242 | temb_channels=self.temb_ch,
243 | dropout=dropout))
244 | block_in = block_out
245 | if curr_res in attn_resolutions:
246 | attn.append(make_attn(block_in, attn_type=attn_type))
247 | down = nn.Module()
248 | down.block = block
249 | down.attn = attn
250 | if i_level != self.num_resolutions-1:
251 | down.downsample = Downsample(block_in, resamp_with_conv)
252 | curr_res = curr_res // 2
253 | self.down.append(down)
254 |
255 | # middle
256 | self.mid = nn.Module()
257 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
258 | out_channels=block_in,
259 | temb_channels=self.temb_ch,
260 | dropout=dropout)
261 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
262 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
263 | out_channels=block_in,
264 | temb_channels=self.temb_ch,
265 | dropout=dropout)
266 |
267 | # end
268 | self.norm_out = Normalize(block_in)
269 | self.conv_out = torch.nn.Conv2d(block_in,
270 | 2*z_channels if double_z else z_channels,
271 | kernel_size=3,
272 | stride=1,
273 | padding=1)
274 |
275 | def forward(self, x):
276 | # timestep embedding
277 | temb = None
278 |
279 | # downsampling
280 | hs = [self.conv_in(x)]
281 | for i_level in range(self.num_resolutions):
282 | for i_block in range(self.num_res_blocks):
283 | h = self.down[i_level].block[i_block](hs[-1], temb)
284 | if len(self.down[i_level].attn) > 0:
285 | h = self.down[i_level].attn[i_block](h)
286 | hs.append(h)
287 | if i_level != self.num_resolutions-1:
288 | hs.append(self.down[i_level].downsample(hs[-1]))
289 |
290 | # middle
291 | h = hs[-1]
292 | h = self.mid.block_1(h, temb)
293 | h = self.mid.attn_1(h)
294 | h = self.mid.block_2(h, temb)
295 |
296 | # end
297 | h = self.norm_out(h)
298 | h = nonlinearity(h)
299 | h = self.conv_out(h)
300 | return h
301 |
302 |
303 | class Decoder(nn.Module):
304 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
305 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
306 | resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
307 | attn_type="vanilla", **ignorekwargs):
308 | super().__init__()
309 | if use_linear_attn: attn_type = "linear"
310 | self.ch = ch
311 | self.temb_ch = 0
312 | self.num_resolutions = len(ch_mult)
313 | self.num_res_blocks = num_res_blocks
314 | self.resolution = resolution
315 | self.in_channels = in_channels
316 | self.give_pre_end = give_pre_end
317 | self.tanh_out = tanh_out
318 |
319 | # compute in_ch_mult, block_in and curr_res at lowest res
320 | in_ch_mult = (1,)+tuple(ch_mult)
321 | block_in = ch*ch_mult[self.num_resolutions-1]
322 | curr_res = resolution // 2**(self.num_resolutions-1)
323 | self.z_shape = (1,z_channels,curr_res,curr_res)
324 | print("Working with z of shape {} = {} dimensions.".format(
325 | self.z_shape, np.prod(self.z_shape)))
326 |
327 | # z to block_in
328 | self.conv_in = torch.nn.Conv2d(z_channels,
329 | block_in,
330 | kernel_size=3,
331 | stride=1,
332 | padding=1)
333 |
334 | # middle
335 | self.mid = nn.Module()
336 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
337 | out_channels=block_in,
338 | temb_channels=self.temb_ch,
339 | dropout=dropout)
340 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
341 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
342 | out_channels=block_in,
343 | temb_channels=self.temb_ch,
344 | dropout=dropout)
345 |
346 | # upsampling
347 | self.up = nn.ModuleList()
348 | for i_level in reversed(range(self.num_resolutions)):
349 | block = nn.ModuleList()
350 | attn = nn.ModuleList()
351 | block_out = ch*ch_mult[i_level]
352 | for i_block in range(self.num_res_blocks+1):
353 | block.append(ResnetBlock(in_channels=block_in,
354 | out_channels=block_out,
355 | temb_channels=self.temb_ch,
356 | dropout=dropout))
357 | block_in = block_out
358 | if curr_res in attn_resolutions:
359 | attn.append(make_attn(block_in, attn_type=attn_type))
360 | up = nn.Module()
361 | up.block = block
362 | up.attn = attn
363 | if i_level != 0:
364 | up.upsample = Upsample(block_in, resamp_with_conv)
365 | curr_res = curr_res * 2
366 | self.up.insert(0, up) # prepend to get consistent order
367 |
368 | # end
369 | self.norm_out = Normalize(block_in)
370 | self.conv_out = torch.nn.Conv2d(block_in,
371 | out_ch,
372 | kernel_size=3,
373 | stride=1,
374 | padding=1)
375 |
376 | def forward(self, z):
377 | #assert z.shape[1:] == self.z_shape[1:]
378 | self.last_z_shape = z.shape
379 |
380 | # timestep embedding
381 | temb = None
382 |
383 | # z to block_in
384 | h = self.conv_in(z)
385 |
386 | # middle
387 | h = self.mid.block_1(h, temb)
388 | h = self.mid.attn_1(h)
389 | h = self.mid.block_2(h, temb)
390 |
391 | # upsampling
392 | for i_level in reversed(range(self.num_resolutions)):
393 | for i_block in range(self.num_res_blocks+1):
394 | h = self.up[i_level].block[i_block](h, temb)
395 | if len(self.up[i_level].attn) > 0:
396 | h = self.up[i_level].attn[i_block](h)
397 | if i_level != 0:
398 | h = self.up[i_level].upsample(h)
399 |
400 | # end
401 | if self.give_pre_end:
402 | return h
403 |
404 | h = self.norm_out(h)
405 | h = nonlinearity(h)
406 | h = self.conv_out(h)
407 | if self.tanh_out:
408 | h = torch.tanh(h)
409 | return h
410 |
411 |
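    | # A frozen, pretrained KL autoencoder (Stable Diffusion style). encode() samples
    | # z from the diagonal-Gaussian posterior and multiplies by scale_factor (0.18215
    | # by default) so the latents are roughly unit-variance; decode() divides it out.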
412 | class FrozenAutoencoderKL(nn.Module):
413 | def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
414 | super().__init__()
415 | print(f'Create autoencoder with scale_factor={scale_factor}')
416 | self.encoder = Encoder(**ddconfig)
417 | self.decoder = Decoder(**ddconfig)
418 | assert ddconfig["double_z"]
419 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
420 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
421 | self.embed_dim = embed_dim
422 | self.scale_factor = scale_factor
423 | m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
424 | assert len(m) == 0 and len(u) == 0
425 | self.eval()
426 | self.requires_grad_(False)
427 |
428 | def encode_moments(self, x):
429 | h = self.encoder(x)
430 | moments = self.quant_conv(h)
431 | return moments
432 |
433 | def sample(self, moments):
434 | mean, logvar = torch.chunk(moments, 2, dim=1)
435 | logvar = torch.clamp(logvar, -30.0, 20.0)
436 | std = torch.exp(0.5 * logvar)
437 | z = mean + std * torch.randn_like(mean)
438 | z = self.scale_factor * z
439 | return z
440 |
441 | def encode(self, x):
442 | moments = self.encode_moments(x)
443 | z = self.sample(moments)
444 | return z
445 |
446 | def decode(self, z):
447 | z = (1. / self.scale_factor) * z
448 | z = self.post_quant_conv(z)
449 | dec = self.decoder(z)
450 | return dec
451 |
452 | def forward(self, inputs, fn):
453 | if fn == 'encode_moments':
454 | return self.encode_moments(inputs)
455 | elif fn == 'encode':
456 | return self.encode(inputs)
457 | elif fn == 'decode':
458 | return self.decode(inputs)
459 | else:
460 | raise NotImplementedError
461 |
462 |
463 | def get_model(pretrained_path, scale_factor=0.18215):
464 | ddconfig = dict(
465 | double_z=True,
466 | z_channels=4,
467 | resolution=256,
468 | in_channels=3,
469 | out_ch=3,
470 | ch=128,
471 | ch_mult=[1, 2, 4, 4],
472 | num_res_blocks=2,
473 | attn_resolutions=[],
474 | dropout=0.0
475 | )
476 | return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
477 |
478 |
479 | def main():
480 | import torchvision.transforms as transforms
481 | from torchvision.utils import save_image
482 | import os
483 | from PIL import Image
484 |
485 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
486 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
487 | model = model.to(device)
488 |
489 | scale_factor = 0.18215
490 | T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
491 | path = 'imgs'
492 | fnames = os.listdir(path)
493 | for fname in fnames:
494 | p = os.path.join(path, fname)
495 | img = Image.open(p)
496 | img = T(img)
497 | img = img * 2. - 1
498 | img = img[None, ...]
499 | img = img.to(device)
500 |
501 | # with torch.cuda.amp.autocast():
502 | # moments = model.encode_moments(img)
503 | # mean, logvar = torch.chunk(moments, 2, dim=1)
504 | # logvar = torch.clamp(logvar, -30.0, 20.0)
505 | # std = torch.exp(0.5 * logvar)
506 | # zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
507 | # recons = [model.decode(z) for z in zs]
508 |
509 | with torch.cuda.amp.autocast():
510 | print('test encode & decode')
511 | recons = [model.decode(model.encode(img)) for _ in range(4)]
512 |
513 | out = torch.cat([img, *recons], dim=0)
514 | out = (out + 1) * 0.5
515 | save_image(out, f'recons_{fname}')
516 |
517 |
518 | if __name__ == "__main__":
519 | main()
520 |
--------------------------------------------------------------------------------
/libs/clip.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import CLIPTokenizer, CLIPTextModel
3 |
4 |
5 | class AbstractEncoder(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def encode(self, *args, **kwargs):
10 | raise NotImplementedError
11 |
12 |
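   | # Thin wrapper around the Hugging Face CLIP text encoder. forward()/encode()
   | # pad or truncate to max_length tokens and return last_hidden_state, i.e. one
   | # embedding per token: [batch, max_length, hidden] ([B, 77, 768] for the default
   | # clip-vit-large-patch14).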
13 | class FrozenCLIPEmbedder(AbstractEncoder):
14 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
16 | super().__init__()
17 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
18 | self.transformer = CLIPTextModel.from_pretrained(version)
19 | self.device = device
20 | self.max_length = max_length
21 | self.freeze()
22 |
23 | def freeze(self):
24 | self.transformer = self.transformer.eval()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def forward(self, text):
29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
31 | tokens = batch_encoding["input_ids"].to(self.device)
32 | outputs = self.transformer(input_ids=tokens)
33 |
34 | z = outputs.last_hidden_state
35 | return z
36 |
37 | def encode(self, text):
38 | return self(text)
39 |
--------------------------------------------------------------------------------
/libs/timm.py:
--------------------------------------------------------------------------------
1 | # code from timm 0.3.2
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import warnings
6 |
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 |
65 | def drop_path(x, drop_prob: float = 0., training: bool = False):
66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67 |
68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72 | 'survival rate' as the argument.
73 |
74 | """
75 | if drop_prob == 0. or not training:
76 | return x
77 | keep_prob = 1 - drop_prob
78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80 | random_tensor.floor_() # binarize
81 | output = x.div(keep_prob) * random_tensor
82 | return output
83 |
84 |
85 | class DropPath(nn.Module):
86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87 | """
88 | def __init__(self, drop_prob=None):
89 | super(DropPath, self).__init__()
90 | self.drop_prob = drop_prob
91 |
92 | def forward(self, x):
93 | return drop_path(x, self.drop_prob, self.training)
94 |
95 |
96 | class Mlp(nn.Module):
97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
98 | super().__init__()
99 | out_features = out_features or in_features
100 | hidden_features = hidden_features or in_features
101 | self.fc1 = nn.Linear(in_features, hidden_features)
102 | self.act = act_layer()
103 | self.fc2 = nn.Linear(hidden_features, out_features)
104 | self.drop = nn.Dropout(drop)
105 |
106 | def forward(self, x):
107 | x = self.fc1(x)
108 | x = self.act(x)
109 | x = self.drop(x)
110 | x = self.fc2(x)
111 | x = self.drop(x)
112 | return x
113 |
--------------------------------------------------------------------------------
/libs/uvit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
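   | # Pick the fastest available attention backend: PyTorch's fused
   | # scaled_dot_product_attention (torch >= 2.0), else xformers memory-efficient
   | # attention, else a plain matmul/softmax fallback.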
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 |     except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 |             raise NotImplementedError
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
141 | use_checkpoint=False, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.num_classes = num_classes
145 | self.in_chans = in_chans
146 |
147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
148 | num_patches = (img_size // patch_size) ** 2
149 |
150 | self.time_embed = nn.Sequential(
151 | nn.Linear(embed_dim, 4 * embed_dim),
152 | nn.SiLU(),
153 | nn.Linear(4 * embed_dim, embed_dim),
154 | ) if mlp_time_embed else nn.Identity()
155 |
156 | if self.num_classes > 0:
157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim)
158 | self.extras = 2
159 | else:
160 | self.extras = 1
161 |
162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
163 |
164 | self.in_blocks = nn.ModuleList([
165 | Block(
166 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
167 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
168 | for _ in range(depth // 2)])
169 |
170 | self.mid_block = Block(
171 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
172 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
173 |
174 | self.out_blocks = nn.ModuleList([
175 | Block(
176 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
177 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
178 | for _ in range(depth // 2)])
179 |
180 | self.norm = norm_layer(embed_dim)
181 | self.patch_dim = patch_size ** 2 * in_chans
182 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
183 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
184 |
185 | trunc_normal_(self.pos_embed, std=.02)
186 | self.apply(self._init_weights)
187 |
188 | def _init_weights(self, m):
189 | if isinstance(m, nn.Linear):
190 | trunc_normal_(m.weight, std=.02)
191 | if isinstance(m, nn.Linear) and m.bias is not None:
192 | nn.init.constant_(m.bias, 0)
193 | elif isinstance(m, nn.LayerNorm):
194 | nn.init.constant_(m.bias, 0)
195 | nn.init.constant_(m.weight, 1.0)
196 |
197 | @torch.jit.ignore
198 | def no_weight_decay(self):
199 | return {'pos_embed'}
200 |
201 | def forward(self, x, timesteps, y=None):
202 | x = self.patch_embed(x)
203 | B, L, D = x.shape
204 |
205 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
206 | time_token = time_token.unsqueeze(dim=1)
207 | x = torch.cat((time_token, x), dim=1)
208 | if y is not None:
209 | label_emb = self.label_emb(y)
210 | label_emb = label_emb.unsqueeze(dim=1)
211 | x = torch.cat((label_emb, x), dim=1)
212 | x = x + self.pos_embed
213 |
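    |         # Long skip connections: the first depth//2 blocks push their outputs,
    |         # the last depth//2 pop them in reverse order and fuse them via
    |         # skip_linear (concatenate then project), forming the U shape of U-ViT.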
214 | skips = []
215 | for blk in self.in_blocks:
216 | x = blk(x)
217 | skips.append(x)
218 |
219 | x = self.mid_block(x)
220 |
221 | for blk in self.out_blocks:
222 | x = blk(x, skips.pop())
223 |
224 | x = self.norm(x)
225 | x = self.decoder_pred(x)
226 | assert x.size(1) == self.extras + L
227 | x = x[:, self.extras:, :]
228 | x = unpatchify(x, self.in_chans)
229 | x = self.final_layer(x)
230 | return x
231 |
--------------------------------------------------------------------------------
/libs/uvit_t2i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 |     except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 |             raise NotImplementedError
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
141 | clip_dim=768, num_clip_token=77, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.in_chans = in_chans
145 |
146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
147 | num_patches = (img_size // patch_size) ** 2
148 |
149 | self.time_embed = nn.Sequential(
150 | nn.Linear(embed_dim, 4 * embed_dim),
151 | nn.SiLU(),
152 | nn.Linear(4 * embed_dim, embed_dim),
153 | ) if mlp_time_embed else nn.Identity()
154 |
155 | self.context_embed = nn.Linear(clip_dim, embed_dim)
156 |
157 | self.extras = 1 + num_clip_token
158 |
159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
160 |
161 | self.in_blocks = nn.ModuleList([
162 | Block(
163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
165 | for _ in range(depth // 2)])
166 |
167 | self.mid_block = Block(
168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
170 |
171 | self.out_blocks = nn.ModuleList([
172 | Block(
173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
175 | for _ in range(depth // 2)])
176 |
177 | self.norm = norm_layer(embed_dim)
178 | self.patch_dim = patch_size ** 2 * in_chans
179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
181 |
182 | trunc_normal_(self.pos_embed, std=.02)
183 | self.apply(self._init_weights)
184 |
185 | def _init_weights(self, m):
186 | if isinstance(m, nn.Linear):
187 | trunc_normal_(m.weight, std=.02)
188 | if isinstance(m, nn.Linear) and m.bias is not None:
189 | nn.init.constant_(m.bias, 0)
190 | elif isinstance(m, nn.LayerNorm):
191 | nn.init.constant_(m.bias, 0)
192 | nn.init.constant_(m.weight, 1.0)
193 |
194 | @torch.jit.ignore
195 | def no_weight_decay(self):
196 | return {'pos_embed'}
197 |
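    |     # Conditioning enters purely as tokens: one time token plus num_clip_token
    |     # CLIP text tokens are prepended to the patch tokens (self.extras counts
    |     # them) and stripped again before unpatchify.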
198 | def forward(self, x, timesteps, context):
199 | x = self.patch_embed(x)
200 | B, L, D = x.shape
201 |
202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
203 | time_token = time_token.unsqueeze(dim=1)
204 | context_token = self.context_embed(context)
205 | x = torch.cat((time_token, context_token, x), dim=1)
206 | x = x + self.pos_embed
207 |
208 | skips = []
209 | for blk in self.in_blocks:
210 | x = blk(x)
211 | skips.append(x)
212 |
213 | x = self.mid_block(x)
214 |
215 | for blk in self.out_blocks:
216 | x = blk(x, skips.pop())
217 |
218 | x = self.norm(x)
219 | x = self.decoder_pred(x)
220 | assert x.size(1) == self.extras + L
221 | x = x[:, self.extras:, :]
222 | x = unpatchify(x, self.in_chans)
223 | x = self.final_layer(x)
224 | return x
225 |
--------------------------------------------------------------------------------
/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baofff/U-ViT/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/sample.png
--------------------------------------------------------------------------------
/sample_t2i_discrete.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | import accelerate
5 | import utils
6 | from datasets import get_dataset
7 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
8 | from absl import logging
9 | import builtins
10 | import einops
11 | import libs.autoencoder
12 | import libs.clip
13 | from torchvision.utils import save_image
14 |
15 |
16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
17 | _betas = (
18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
19 | )
20 | return _betas.numpy()
21 |
22 |
23 | def evaluate(config):
24 | if config.get('benchmark', False):
25 | torch.backends.cudnn.benchmark = True
26 | torch.backends.cudnn.deterministic = False
27 |
28 | mp.set_start_method('spawn')
29 | accelerator = accelerate.Accelerator()
30 | device = accelerator.device
31 | accelerate.utils.set_seed(config.seed, device_specific=True)
32 | logging.info(f'Process {accelerator.process_index} using device: {device}')
33 |
34 | config.mixed_precision = accelerator.mixed_precision
35 | config = ml_collections.FrozenConfigDict(config)
36 | if accelerator.is_main_process:
37 | utils.set_logger(log_level='info')
38 | else:
39 | utils.set_logger(log_level='error')
40 | builtins.print = lambda *args: None
41 |
42 | dataset = get_dataset(**config.dataset)
43 |
44 | with open(config.input_path, 'r') as f:
45 | prompts = f.read().strip().split('\n')
46 |
47 | print(prompts)
48 |
49 | clip = libs.clip.FrozenCLIPEmbedder()
50 | clip.eval()
51 | clip.to(device)
52 |
53 | contexts = clip.encode(prompts)
54 |
55 | nnet = utils.get_nnet(**config.nnet)
56 | nnet = accelerator.prepare(nnet)
57 | logging.info(f'load nnet from {config.nnet_path}')
58 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
59 | nnet.eval()
60 |
61 | def cfg_nnet(x, timesteps, context):
62 | _cond = nnet(x, timesteps, context=context)
63 | if config.sample.scale == 0:
64 | return _cond
65 | _empty_context = torch.tensor(dataset.empty_context, device=device)
66 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
67 | _uncond = nnet(x, timesteps, context=_empty_context)
68 | return _cond + config.sample.scale * (_cond - _uncond)
69 |
70 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
71 | autoencoder.to(device)
72 |
73 | @torch.cuda.amp.autocast()
74 | def encode(_batch):
75 | return autoencoder.encode(_batch)
76 |
77 | @torch.cuda.amp.autocast()
78 | def decode(_batch):
79 | return autoencoder.decode(_batch)
80 |
81 | _betas = stable_diffusion_beta_schedule()
82 | N = len(_betas)
83 |
84 | logging.info(config.sample)
85 | logging.info(f'mixed_precision={config.mixed_precision}')
86 | logging.info(f'N={N}')
87 |
88 | z_init = torch.randn(contexts.size(0), *config.z_shape, device=device)
89 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
90 |
91 | def model_fn(x, t_continuous):
92 | t = t_continuous * N
93 | return cfg_nnet(x, t, context=contexts)
94 |
95 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
96 | z = dpm_solver.sample(z_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
97 | samples = dataset.unpreprocess(decode(z))
98 |
99 | os.makedirs(config.output_path, exist_ok=True)
100 | for sample, prompt in zip(samples, prompts):
101 | save_image(sample, os.path.join(config.output_path, f"{prompt}.png"))
102 |
103 |
104 |
105 | from absl import flags
106 | from absl import app
107 | from ml_collections import config_flags
108 | import os
109 |
110 |
111 | FLAGS = flags.FLAGS
112 | config_flags.DEFINE_config_file(
113 | "config", None, "Training configuration.", lock_config=False)
114 | flags.mark_flags_as_required(["config"])
115 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
116 | flags.DEFINE_string("output_path", None, "The path to output images.")
117 | flags.DEFINE_string("input_path", None, "The path to input texts.")
118 |
119 |
120 | def main(argv):
121 | config = FLAGS.config
122 | config.nnet_path = FLAGS.nnet_path
123 | config.output_path = FLAGS.output_path
124 | config.input_path = FLAGS.input_path
125 | evaluate(config)
126 |
127 |
128 | if __name__ == "__main__":
129 | app.run(main)
130 |
--------------------------------------------------------------------------------
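The `cfg_nnet` wrapper above implements classifier-free guidance by pushing the conditional prediction away from the unconditional one. A minimal, self-contained sketch of that combination, with a toy stand-in for the real U-ViT network, might look like:

```python
import torch

def classifier_free_guidance(nnet, x, timesteps, context, empty_context, scale):
    """Combine conditional and unconditional predictions, as cfg_nnet does above."""
    cond = nnet(x, timesteps, context=context)
    if scale == 0:
        return cond
    uncond = nnet(x, timesteps, context=empty_context)
    return cond + scale * (cond - uncond)

# Toy stand-in for the network, only to exercise the function.
toy_nnet = lambda x, t, context: 0.5 * x
x = torch.randn(2, 4, 32, 32)
t = torch.full((2,), 500.0)
out = classifier_free_guidance(toy_nnet, x, t, context=None, empty_context=None, scale=1.0)
print(out.shape)  # torch.Size([2, 4, 32, 32])
```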
/scripts/extract_empty_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main():
12 | prompts = [
13 | '',
14 | ]
15 |
16 | device = 'cuda'
17 | clip = libs.clip.FrozenCLIPEmbedder()
18 | clip.eval()
19 | clip.to(device)
20 |
21 | save_dir = f'assets/datasets/coco256_features'
22 | latent = clip.encode(prompts)
23 | print(latent.shape)
24 | c = latent[0].detach().cpu().numpy()
25 | np.save(os.path.join(save_dir, f'empty_context.npy'), c)
26 |
27 |
28 | if __name__ == '__main__':
29 | main()
30 |
--------------------------------------------------------------------------------
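The empty-prompt embedding written by this script is what `sample_t2i_discrete.py` later uses as the unconditional context for classifier-free guidance. A quick sketch of reading it back (the path comes from the script; the printed shape depends on the CLIP text encoder):

```python
import numpy as np

# Load the CLIP embedding of the empty prompt saved by the script above.
empty_context = np.load('assets/datasets/coco256_features/empty_context.npy')
print(empty_context.shape)  # (sequence_length, embedding_dim), e.g. (77, 768)
```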
/scripts/extract_imagenet_feature.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | from datasets import ImageNet
5 | from torch.utils.data import DataLoader
6 | from libs.autoencoder import get_model
7 | import argparse
8 | from tqdm import tqdm
9 | torch.manual_seed(0)
10 | np.random.seed(0)
11 |
12 |
13 | def main(resolution=256):
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('path')
16 | args = parser.parse_args()
17 |
18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False)
19 | train_dataset = dataset.get_split(split='train', labeled=True)
20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False,
21 | num_workers=8, pin_memory=True, persistent_workers=True)
22 |
23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
24 | model = nn.DataParallel(model)
25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26 | model.to(device)
27 |
28 | # features = []
29 | # labels = []
30 |
31 | idx = 0
32 | for batch in tqdm(train_dataset_loader):
33 | img, label = batch
34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0)
35 | img = img.to(device)
36 | moments = model(img, fn='encode_moments')
37 | moments = moments.detach().cpu().numpy()
38 |
39 | label = torch.cat([label, label], dim=0)
40 | label = label.detach().cpu().numpy()
41 |
42 | for moment, lb in zip(moments, label):
43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb))
44 | idx += 1
45 |
46 | print(f'save {idx} files')
47 |
48 | # features = np.concatenate(features, axis=0)
49 | # labels = np.concatenate(labels, axis=0)
50 | # print(f'features.shape={features.shape}')
51 | # print(f'labels.shape={labels.shape}')
52 | # np.save(f'imagenet{resolution}_features.npy', features)
53 | # np.save(f'imagenet{resolution}_labels.npy', labels)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
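Each feature file written above stores a `(moment, label)` tuple, which NumPy saves as an object array; reading it back therefore needs `allow_pickle=True`. A short sketch (the file index and exact latent shape are illustrative):

```python
import numpy as np

# Read back one (moment, label) pair saved by the extraction script.
moment, label = np.load('assets/datasets/imagenet256_features/0.npy', allow_pickle=True)
print(moment.shape, int(label))  # latent-moment shape depends on the autoencoder, e.g. (8, 32, 32)
```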
/scripts/extract_mscoco_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main(resolution=256):
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--split', default='train')
14 | args = parser.parse_args()
15 | print(args)
16 |
17 |
18 | if args.split == "train":
19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014',
20 | annFile='assets/datasets/coco/annotations/captions_train2014.json',
21 | size=resolution)
22 | save_dir = f'assets/datasets/coco{resolution}_features/train'
23 | elif args.split == "val":
24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014',
25 | annFile='assets/datasets/coco/annotations/captions_val2014.json',
26 | size=resolution)
27 | save_dir = f'assets/datasets/coco{resolution}_features/val'
28 | else:
29 |         raise NotImplementedError(f'unknown split: {args.split}')
30 |
31 | device = "cuda"
32 | os.makedirs(save_dir)
33 |
34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth')
35 | autoencoder.to(device)
36 | clip = libs.clip.FrozenCLIPEmbedder()
37 | clip.eval()
38 | clip.to(device)
39 |
40 | with torch.no_grad():
41 | for idx, data in tqdm(enumerate(datas)):
42 | x, captions = data
43 |
44 | if len(x.shape) == 3:
45 | x = x[None, ...]
46 | x = torch.tensor(x, device=device)
47 | moments = autoencoder(x, fn='encode_moments').squeeze(0)
48 | moments = moments.detach().cpu().numpy()
49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments)
50 |
51 | latent = clip.encode(captions)
52 | for i in range(len(latent)):
53 | c = latent[i].detach().cpu().numpy()
54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/scripts/extract_test_prompt_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main():
12 | prompts = [
13 | 'A green train is coming down the tracks.',
14 | 'A group of skiers are preparing to ski down a mountain.',
15 | 'A small kitchen with a low ceiling.',
16 | 'A group of elephants walking in muddy water.',
17 | 'A living area with a television and a table.',
18 | 'A road with traffic lights, street lights and cars.',
19 | 'A bus driving in a city area with traffic signs.',
20 | 'A bus pulls over to the curb close to an intersection.',
21 | 'A group of people are walking and one is holding an umbrella.',
22 | 'A baseball player taking a swing at an incoming ball.',
23 | 'A city street line with brick buildings and trees.',
24 | 'A close up of a plate of broccoli and sauce.',
25 | ]
26 |
27 | device = 'cuda'
28 | clip = libs.clip.FrozenCLIPEmbedder()
29 | clip.eval()
30 | clip.to(device)
31 |
32 | save_dir = f'assets/datasets/coco256_features/run_vis'
33 | latent = clip.encode(prompts)
34 | for i in range(len(latent)):
35 | c = latent[i].detach().cpu().numpy()
36 | np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c))
37 |
38 |
39 | if __name__ == '__main__':
40 | main()
41 |
--------------------------------------------------------------------------------
/sde.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from absl import logging
4 | import numpy as np
5 | import math
6 | from tqdm import tqdm
7 |
8 |
9 | def get_sde(name, **kwargs):
10 | if name == 'vpsde':
11 | return VPSDE(**kwargs)
12 | elif name == 'vpsde_cosine':
13 | return VPSDECosine(**kwargs)
14 | else:
15 | raise NotImplementedError
16 |
17 |
18 | def stp(s, ts: torch.Tensor): # scalar tensor product
19 | if isinstance(s, np.ndarray):
20 | s = torch.from_numpy(s).type_as(ts)
21 | extra_dims = (1,) * (ts.dim() - 1)
22 | return s.view(-1, *extra_dims) * ts
23 |
24 |
25 | def mos(a, start_dim=1): # mean of square
26 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
27 |
28 |
29 | def duplicate(tensor, *size):
30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape)
31 |
32 |
33 | class SDE(object):
34 | r"""
35 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
36 | f(x, t) is the drift
37 | g(t) is the diffusion
38 | """
39 | def drift(self, x, t):
40 | raise NotImplementedError
41 |
42 | def diffusion(self, t):
43 | raise NotImplementedError
44 |
45 | def cum_beta(self, t): # the variance of xt|x0
46 | raise NotImplementedError
47 |
48 | def cum_alpha(self, t):
49 | raise NotImplementedError
50 |
51 |     def snr(self, t):  # signal-to-noise ratio
52 | raise NotImplementedError
53 |
54 |     def nsr(self, t):  # noise-to-signal ratio
55 | raise NotImplementedError
56 |
57 | def marginal_prob(self, x0, t): # the mean and std of q(xt|x0)
58 | alpha = self.cum_alpha(t)
59 | beta = self.cum_beta(t)
60 | mean = stp(alpha ** 0.5, x0) # E[xt|x0]
61 | std = beta ** 0.5 # Cov[xt|x0] ** 0.5
62 | return mean, std
63 |
64 | def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform
65 | t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init
66 | mean, std = self.marginal_prob(x0, t)
67 | eps = torch.randn_like(x0)
68 | xt = mean + stp(std, eps)
69 | return t, eps, xt
70 |
71 |
72 | class VPSDE(SDE):
73 | def __init__(self, beta_min=0.1, beta_max=20):
74 | # 0 <= t <= 1
75 | self.beta_0 = beta_min
76 | self.beta_1 = beta_max
77 |
78 | def drift(self, x, t):
79 | return -0.5 * stp(self.squared_diffusion(t), x)
80 |
81 | def diffusion(self, t):
82 | return self.squared_diffusion(t) ** 0.5
83 |
84 | def squared_diffusion(self, t): # beta(t)
85 | return self.beta_0 + t * (self.beta_1 - self.beta_0)
86 |
87 | def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau
88 | return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5
89 |
90 | def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I
91 | return 1. - self.skip_alpha(s, t)
92 |
93 | def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs
94 | x = -self.squared_diffusion_integral(s, t)
95 | return x.exp()
96 |
97 | def cum_beta(self, t):
98 | return self.skip_beta(0, t)
99 |
100 | def cum_alpha(self, t):
101 | return self.skip_alpha(0, t)
102 |
103 | def nsr(self, t):
104 | return self.squared_diffusion_integral(0, t).expm1()
105 |
106 | def snr(self, t):
107 | return 1. / self.nsr(t)
108 |
109 | def __str__(self):
110 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
111 |
112 | def __repr__(self):
113 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
114 |
115 |
116 | class VPSDECosine(SDE):
117 | r"""
118 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
119 | f(x, t) is the drift
120 | g(t) is the diffusion
121 | """
122 | def __init__(self, s=0.008):
123 | self.s = s
124 | self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
125 | self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
126 |
127 | def drift(self, x, t):
128 | ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2
129 | return stp(ft, x)
130 |
131 | def diffusion(self, t):
132 | return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5
133 |
134 | def cum_beta(self, t): # the variance of xt|x0
135 | return 1 - self.cum_alpha(t)
136 |
137 | def cum_alpha(self, t):
138 | return self.F(t) / self.F0
139 |
140 |     def snr(self, t):  # signal-to-noise ratio
141 | Ft = self.F(t)
142 | return Ft / (self.F0 - Ft)
143 |
144 |     def nsr(self, t):  # noise-to-signal ratio
145 | Ft = self.F(t)
146 | return self.F0 / Ft - 1
147 |
148 | def __str__(self):
149 | return 'vpsde_cosine'
150 |
151 | def __repr__(self):
152 | return 'vpsde_cosine'
153 |
154 |
155 | class ScoreModel(object):
156 | r"""
157 | The forward process is q(x_[0,T])
158 | """
159 |
160 | def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1):
161 | assert T == 1
162 | self.nnet = nnet
163 | self.pred = pred
164 | self.sde = sde
165 | self.T = T
166 | print(f'ScoreModel with pred={pred}, sde={sde}, T={T}')
167 |
168 | def predict(self, xt, t, **kwargs):
169 | if not isinstance(t, torch.Tensor):
170 | t = torch.tensor(t)
171 | t = t.to(xt.device)
172 | if t.dim() == 0:
173 | t = duplicate(t, xt.size(0))
174 | return self.nnet(xt, t * 999, **kwargs) # follow SDE
175 |
176 | def noise_pred(self, xt, t, **kwargs):
177 | pred = self.predict(xt, t, **kwargs)
178 | if self.pred == 'noise_pred':
179 | noise_pred = pred
180 | elif self.pred == 'x0_pred':
181 | noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt)
182 | else:
183 | raise NotImplementedError
184 | return noise_pred
185 |
186 | def x0_pred(self, xt, t, **kwargs):
187 | pred = self.predict(xt, t, **kwargs)
188 | if self.pred == 'noise_pred':
189 | x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred)
190 | elif self.pred == 'x0_pred':
191 | x0_pred = pred
192 | else:
193 | raise NotImplementedError
194 | return x0_pred
195 |
196 | def score(self, xt, t, **kwargs):
197 | cum_beta = self.sde.cum_beta(t)
198 | noise_pred = self.noise_pred(xt, t, **kwargs)
199 | return stp(-cum_beta.rsqrt(), noise_pred)
200 |
201 |
202 | class ReverseSDE(object):
203 | r"""
204 | dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw
205 | """
206 | def __init__(self, score_model):
207 | self.sde = score_model.sde # the forward sde
208 | self.score_model = score_model
209 |
210 | def drift(self, x, t, **kwargs):
211 | drift = self.sde.drift(x, t) # f(x, t)
212 | diffusion = self.sde.diffusion(t) # g(t)
213 | score = self.score_model.score(x, t, **kwargs)
214 | return drift - stp(diffusion ** 2, score)
215 |
216 | def diffusion(self, t):
217 | return self.sde.diffusion(t)
218 |
219 |
220 | class ODE(object):
221 | r"""
222 | dx = [f(x, t) - g(t)^2 s(x, t)] dt
223 | """
224 |
225 | def __init__(self, score_model):
226 | self.sde = score_model.sde # the forward sde
227 | self.score_model = score_model
228 |
229 | def drift(self, x, t, **kwargs):
230 | drift = self.sde.drift(x, t) # f(x, t)
231 | diffusion = self.sde.diffusion(t) # g(t)
232 | score = self.score_model.score(x, t, **kwargs)
233 | return drift - 0.5 * stp(diffusion ** 2, score)
234 |
235 | def diffusion(self, t):
236 | return 0
237 |
238 |
239 | def dct2str(dct):
240 | return str({k: f'{v:.6g}' for k, v in dct.items()})
241 |
242 |
243 | @ torch.no_grad()
244 | def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs):
245 | r"""
246 | The Euler Maruyama sampler for reverse SDE / ODE
247 | See `Score-Based Generative Modeling through Stochastic Differential Equations`
248 | """
249 | assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE)
250 | print(f"euler_maruyama with sample_steps={sample_steps}")
251 | timesteps = np.append(0., np.linspace(eps, T, sample_steps))
252 | timesteps = torch.tensor(timesteps).to(x_init)
253 | x = x_init
254 | if trace is not None:
255 | trace.append(x)
256 | for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'):
257 | drift = rsde.drift(x, t, **kwargs)
258 | diffusion = rsde.diffusion(t)
259 | dt = s - t
260 | mean = x + drift * dt
261 | sigma = diffusion * (-dt).sqrt()
262 | x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean
263 | if trace is not None:
264 | trace.append(x)
265 | statistics = dict(s=s, t=t, sigma=sigma.item())
266 | logging.debug(dct2str(statistics))
267 | return x
268 |
269 |
270 | def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs):
271 | t, noise, xt = score_model.sde.sample(x0)
272 | if pred == 'noise_pred':
273 | noise_pred = score_model.noise_pred(xt, t, **kwargs)
274 | return mos(noise - noise_pred)
275 | elif pred == 'x0_pred':
276 | x0_pred = score_model.x0_pred(xt, t, **kwargs)
277 | return mos(x0 - x0_pred)
278 | else:
279 | raise NotImplementedError(pred)
280 |
--------------------------------------------------------------------------------
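As a quick sanity check of the VPSDE marginals defined in `sde.py`, one can verify numerically that samples from `marginal_prob` have mean `cum_alpha(t)**0.5 * x0` and variance `cum_beta(t)`. This is only a sketch, assuming the repository root is on the Python path:

```python
import torch
from sde import VPSDE, stp

vpsde = VPSDE(beta_min=0.1, beta_max=20)
t = torch.full((100_000,), 0.5)
x0 = torch.ones(100_000, 1)

mean, std = vpsde.marginal_prob(x0, t)
xt = mean + stp(std, torch.randn_like(x0))

# Empirical statistics should match the closed-form marginals.
print(xt.mean().item(), vpsde.cum_alpha(t)[0].sqrt().item())
print(xt.var().item(), vpsde.cum_beta(t)[0].item())
```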
/skip_im.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baofff/U-ViT/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/skip_im.png
--------------------------------------------------------------------------------
/tools/fid_score.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2 |
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 |
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 |
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 |
15 | See --help to see further details.
16 |
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of Tensorflow
19 |
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 |
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 |
26 | http://www.apache.org/licenses/LICENSE-2.0
27 |
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License.
33 | """
34 | import os
35 | import pathlib
36 |
37 | import numpy as np
38 | import torch
39 | import torchvision.transforms as TF
40 | from PIL import Image
41 | from scipy import linalg
42 | from torch.nn.functional import adaptive_avg_pool2d
43 |
44 | try:
45 | from tqdm import tqdm
46 | except ImportError:
47 | # If tqdm is not available, provide a mock version of it
48 | def tqdm(x):
49 | return x
50 |
51 | from .inception import InceptionV3
52 |
53 |
54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
55 | 'tif', 'tiff', 'webp'}
56 |
57 |
58 | class ImagePathDataset(torch.utils.data.Dataset):
59 | def __init__(self, files, transforms=None):
60 | self.files = files
61 | self.transforms = transforms
62 |
63 | def __len__(self):
64 | return len(self.files)
65 |
66 | def __getitem__(self, i):
67 | path = self.files[i]
68 | img = Image.open(path).convert('RGB')
69 | if self.transforms is not None:
70 | img = self.transforms(img)
71 | return img
72 |
73 |
74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
75 | """Calculates the activations of the pool_3 layer for all images.
76 |
77 | Params:
78 | -- files : List of image files paths
79 | -- model : Instance of inception model
80 | -- batch_size : Batch size of images for the model to process at once.
81 | Make sure that the number of samples is a multiple of
82 | the batch size, otherwise some samples are ignored. This
83 | behavior is retained to match the original FID score
84 | implementation.
85 | -- dims : Dimensionality of features returned by Inception
86 | -- device : Device to run calculations
87 | -- num_workers : Number of parallel dataloader workers
88 |
89 | Returns:
90 | -- A numpy array of dimension (num images, dims) that contains the
91 | activations of the given tensor when feeding inception with the
92 | query tensor.
93 | """
94 | model.eval()
95 |
96 | if batch_size > len(files):
97 | print(('Warning: batch size is bigger than the data size. '
98 | 'Setting batch size to data size'))
99 | batch_size = len(files)
100 |
101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor())
102 | dataloader = torch.utils.data.DataLoader(dataset,
103 | batch_size=batch_size,
104 | shuffle=False,
105 | drop_last=False,
106 | num_workers=num_workers)
107 |
108 | pred_arr = np.empty((len(files), dims))
109 |
110 | start_idx = 0
111 |
112 | for batch in tqdm(dataloader):
113 | batch = batch.to(device)
114 |
115 | with torch.no_grad():
116 | pred = model(batch)[0]
117 |
118 | # If model output is not scalar, apply global spatial average pooling.
119 | # This happens if you choose a dimensionality not equal 2048.
120 | if pred.size(2) != 1 or pred.size(3) != 1:
121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
122 |
123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
124 |
125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
126 |
127 | start_idx = start_idx + pred.shape[0]
128 |
129 | return pred_arr
130 |
131 |
132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
133 | """Numpy implementation of the Frechet Distance.
134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
135 | and X_2 ~ N(mu_2, C_2) is
136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
137 |
138 | Stable version by Dougal J. Sutherland.
139 |
140 | Params:
141 | -- mu1 : Numpy array containing the activations of a layer of the
142 | inception net (like returned by the function 'get_predictions')
143 | for generated samples.
144 |     -- mu2   : The sample mean over activations, precalculated on a
145 | representative data set.
146 | -- sigma1: The covariance matrix over activations for generated samples.
147 |     -- sigma2: The covariance matrix over activations, precalculated on a
148 | representative data set.
149 |
150 | Returns:
151 | -- : The Frechet Distance.
152 | """
153 |
154 | mu1 = np.atleast_1d(mu1)
155 | mu2 = np.atleast_1d(mu2)
156 |
157 | sigma1 = np.atleast_2d(sigma1)
158 | sigma2 = np.atleast_2d(sigma2)
159 |
160 | assert mu1.shape == mu2.shape, \
161 | 'Training and test mean vectors have different lengths'
162 | assert sigma1.shape == sigma2.shape, \
163 | 'Training and test covariances have different dimensions'
164 |
165 | diff = mu1 - mu2
166 |
167 | # Product might be almost singular
168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
169 | if not np.isfinite(covmean).all():
170 | msg = ('fid calculation produces singular product; '
171 | 'adding %s to diagonal of cov estimates') % eps
172 | print(msg)
173 | offset = np.eye(sigma1.shape[0]) * eps
174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
175 |
176 | # Numerical error might give slight imaginary component
177 | if np.iscomplexobj(covmean):
178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
179 | m = np.max(np.abs(covmean.imag))
180 | raise ValueError('Imaginary component {}'.format(m))
181 | covmean = covmean.real
182 |
183 | tr_covmean = np.trace(covmean)
184 |
185 | return (diff.dot(diff) + np.trace(sigma1)
186 | + np.trace(sigma2) - 2 * tr_covmean)
187 |
188 |
189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
190 | device='cpu', num_workers=8):
191 | """Calculation of the statistics used by the FID.
192 | Params:
193 | -- files : List of image files paths
194 | -- model : Instance of inception model
195 | -- batch_size : The images numpy array is split into batches with
196 | batch size batch_size. A reasonable batch size
197 | depends on the hardware.
198 | -- dims : Dimensionality of features returned by Inception
199 | -- device : Device to run calculations
200 | -- num_workers : Number of parallel dataloader workers
201 |
202 | Returns:
203 | -- mu : The mean over samples of the activations of the pool_3 layer of
204 | the inception model.
205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
206 | the inception model.
207 | """
208 | act = get_activations(files, model, batch_size, dims, device, num_workers)
209 | mu = np.mean(act, axis=0)
210 | sigma = np.cov(act, rowvar=False)
211 | return mu, sigma
212 |
213 |
214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
215 | if path.endswith('.npz'):
216 | with np.load(path) as f:
217 | m, s = f['mu'][:], f['sigma'][:]
218 | else:
219 | path = pathlib.Path(path)
220 | files = sorted([file for ext in IMAGE_EXTENSIONS
221 | for file in path.glob('*.{}'.format(ext))])
222 | m, s = calculate_activation_statistics(files, model, batch_size,
223 | dims, device, num_workers)
224 |
225 | return m, s
226 |
227 |
228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
229 | if device is None:
230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
231 | else:
232 | device = torch.device(device)
233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
234 | model = InceptionV3([block_idx]).to(device)
235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
236 | np.savez(out_path, mu=m1, sigma=s1)
237 |
238 |
239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
240 | """Calculates the FID of two paths"""
241 | if device is None:
242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
243 | else:
244 | device = torch.device(device)
245 |
246 | for p in paths:
247 | if not os.path.exists(p):
248 | raise RuntimeError('Invalid path: %s' % p)
249 |
250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
251 |
252 | model = InceptionV3([block_idx]).to(device)
253 |
254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
255 | dims, device, num_workers)
256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
257 | dims, device, num_workers)
258 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
259 |
260 | return fid_value
261 |
--------------------------------------------------------------------------------
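A minimal usage sketch for the two entry points above (all paths are placeholders): cache the reference statistics once with `save_statistics_of_path`, then score a folder of generated PNGs against the cached `.npz`:

```python
from tools.fid_score import save_statistics_of_path, calculate_fid_given_paths

# One-off: save mean/covariance of the reference images to an .npz file.
save_statistics_of_path('path/to/reference_images', 'path/to/fid_stats.npz')

# Later: FID between the cached statistics and a directory of generated samples.
fid = calculate_fid_given_paths(('path/to/fid_stats.npz', 'path/to/generated_samples'))
print(f'FID: {fid:.2f}')
```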
/tools/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 |         192: 1,   # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=(DEFAULT_BLOCK_INDEX,),
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 |
39 | Parameters
40 | ----------
41 | output_blocks : list of int
42 | Indices of blocks to return features of. Possible values are:
43 | - 0: corresponds to output of first max pooling
44 | - 1: corresponds to output of second max pooling
45 | - 2: corresponds to output which is fed to aux classifier
46 | - 3: corresponds to output of final average pooling
47 | resize_input : bool
48 | If true, bilinearly resizes input to width and height 299 before
49 | feeding input to model. As the network without fully connected
50 | layers is fully convolutional, it should be able to handle inputs
51 | of arbitrary size, so resizing might not be strictly needed
52 | normalize_input : bool
53 | If true, scales the input from range (0, 1) to the range the
54 | pretrained Inception network expects, namely (-1, 1)
55 | requires_grad : bool
56 | If true, parameters of the model require gradients. Possibly useful
57 | for finetuning the network
58 | use_fid_inception : bool
59 | If true, uses the pretrained Inception model used in Tensorflow's
60 | FID implementation. If false, uses the pretrained Inception model
61 | available in torchvision. The FID Inception model has different
62 | weights and a slightly different structure from torchvision's
63 | Inception model. If you want to compute FID scores, you are
64 | strongly advised to set this parameter to true to get comparable
65 | results.
66 | """
67 | super(InceptionV3, self).__init__()
68 |
69 | self.resize_input = resize_input
70 | self.normalize_input = normalize_input
71 | self.output_blocks = sorted(output_blocks)
72 | self.last_needed_block = max(output_blocks)
73 |
74 | assert self.last_needed_block <= 3, \
75 | 'Last possible output block index is 3'
76 |
77 | self.blocks = nn.ModuleList()
78 |
79 | if use_fid_inception:
80 | inception = fid_inception_v3()
81 | else:
82 | inception = _inception_v3(pretrained=True)
83 |
84 | # Block 0: input to maxpool1
85 | block0 = [
86 | inception.Conv2d_1a_3x3,
87 | inception.Conv2d_2a_3x3,
88 | inception.Conv2d_2b_3x3,
89 | nn.MaxPool2d(kernel_size=3, stride=2)
90 | ]
91 | self.blocks.append(nn.Sequential(*block0))
92 |
93 | # Block 1: maxpool1 to maxpool2
94 | if self.last_needed_block >= 1:
95 | block1 = [
96 | inception.Conv2d_3b_1x1,
97 | inception.Conv2d_4a_3x3,
98 | nn.MaxPool2d(kernel_size=3, stride=2)
99 | ]
100 | self.blocks.append(nn.Sequential(*block1))
101 |
102 | # Block 2: maxpool2 to aux classifier
103 | if self.last_needed_block >= 2:
104 | block2 = [
105 | inception.Mixed_5b,
106 | inception.Mixed_5c,
107 | inception.Mixed_5d,
108 | inception.Mixed_6a,
109 | inception.Mixed_6b,
110 | inception.Mixed_6c,
111 | inception.Mixed_6d,
112 | inception.Mixed_6e,
113 | ]
114 | self.blocks.append(nn.Sequential(*block2))
115 |
116 | # Block 3: aux classifier to final avgpool
117 | if self.last_needed_block >= 3:
118 | block3 = [
119 | inception.Mixed_7a,
120 | inception.Mixed_7b,
121 | inception.Mixed_7c,
122 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
123 | ]
124 | self.blocks.append(nn.Sequential(*block3))
125 |
126 | for param in self.parameters():
127 | param.requires_grad = requires_grad
128 |
129 | def forward(self, inp):
130 | """Get Inception feature maps
131 |
132 | Parameters
133 | ----------
134 | inp : torch.autograd.Variable
135 | Input tensor of shape Bx3xHxW. Values are expected to be in
136 | range (0, 1)
137 |
138 | Returns
139 | -------
140 | List of torch.autograd.Variable, corresponding to the selected output
141 | block, sorted ascending by index
142 | """
143 | outp = []
144 | x = inp
145 |
146 | if self.resize_input:
147 | x = F.interpolate(x,
148 | size=(299, 299),
149 | mode='bilinear',
150 | align_corners=False)
151 |
152 | if self.normalize_input:
153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154 |
155 | for idx, block in enumerate(self.blocks):
156 | x = block(x)
157 | if idx in self.output_blocks:
158 | outp.append(x)
159 |
160 | if idx == self.last_needed_block:
161 | break
162 |
163 | return outp
164 |
165 |
166 | def _inception_v3(*args, **kwargs):
167 | """Wraps `torchvision.models.inception_v3`
168 |
169 |     Skips default weight initialization if supported by torchvision version.
170 | See https://github.com/mseitzer/pytorch-fid/issues/28.
171 | """
172 | try:
173 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174 | except ValueError:
175 | # Just a caution against weird version strings
176 | version = (0,)
177 |
178 | if version >= (0, 6):
179 | kwargs['init_weights'] = False
180 |
181 | return torchvision.models.inception_v3(*args, **kwargs)
182 |
183 |
184 | def fid_inception_v3():
185 | """Build pretrained Inception model for FID computation
186 |
187 | The Inception model for FID computation uses a different set of weights
188 | and has a slightly different structure than torchvision's Inception.
189 |
190 | This method first constructs torchvision's Inception and then patches the
191 | necessary parts that are different in the FID Inception model.
192 | """
193 | inception = _inception_v3(num_classes=1008,
194 | aux_logits=False,
195 | pretrained=False)
196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203 | inception.Mixed_7b = FIDInceptionE_1(1280)
204 | inception.Mixed_7c = FIDInceptionE_2(2048)
205 |
206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
207 | inception.load_state_dict(state_dict)
208 | return inception
209 |
210 |
211 | class FIDInceptionA(torchvision.models.inception.InceptionA):
212 | """InceptionA block patched for FID computation"""
213 | def __init__(self, in_channels, pool_features):
214 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
215 |
216 | def forward(self, x):
217 | branch1x1 = self.branch1x1(x)
218 |
219 | branch5x5 = self.branch5x5_1(x)
220 | branch5x5 = self.branch5x5_2(branch5x5)
221 |
222 | branch3x3dbl = self.branch3x3dbl_1(x)
223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225 |
226 |         # Patch: Tensorflow's average pool does not use the padded zeros in
227 | # its average calculation
228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229 | count_include_pad=False)
230 | branch_pool = self.branch_pool(branch_pool)
231 |
232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233 | return torch.cat(outputs, 1)
234 |
235 |
236 | class FIDInceptionC(torchvision.models.inception.InceptionC):
237 | """InceptionC block patched for FID computation"""
238 | def __init__(self, in_channels, channels_7x7):
239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240 |
241 | def forward(self, x):
242 | branch1x1 = self.branch1x1(x)
243 |
244 | branch7x7 = self.branch7x7_1(x)
245 | branch7x7 = self.branch7x7_2(branch7x7)
246 | branch7x7 = self.branch7x7_3(branch7x7)
247 |
248 | branch7x7dbl = self.branch7x7dbl_1(x)
249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253 |
254 |         # Patch: Tensorflow's average pool does not use the padded zeros in
255 | # its average calculation
256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257 | count_include_pad=False)
258 | branch_pool = self.branch_pool(branch_pool)
259 |
260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261 | return torch.cat(outputs, 1)
262 |
263 |
264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265 | """First InceptionE block patched for FID computation"""
266 | def __init__(self, in_channels):
267 | super(FIDInceptionE_1, self).__init__(in_channels)
268 |
269 | def forward(self, x):
270 | branch1x1 = self.branch1x1(x)
271 |
272 | branch3x3 = self.branch3x3_1(x)
273 | branch3x3 = [
274 | self.branch3x3_2a(branch3x3),
275 | self.branch3x3_2b(branch3x3),
276 | ]
277 | branch3x3 = torch.cat(branch3x3, 1)
278 |
279 | branch3x3dbl = self.branch3x3dbl_1(x)
280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281 | branch3x3dbl = [
282 | self.branch3x3dbl_3a(branch3x3dbl),
283 | self.branch3x3dbl_3b(branch3x3dbl),
284 | ]
285 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
286 |
287 |         # Patch: Tensorflow's average pool does not use the padded zeros in
288 | # its average calculation
289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290 | count_include_pad=False)
291 | branch_pool = self.branch_pool(branch_pool)
292 |
293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294 | return torch.cat(outputs, 1)
295 |
296 |
297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298 | """Second InceptionE block patched for FID computation"""
299 | def __init__(self, in_channels):
300 | super(FIDInceptionE_2, self).__init__(in_channels)
301 |
302 | def forward(self, x):
303 | branch1x1 = self.branch1x1(x)
304 |
305 | branch3x3 = self.branch3x3_1(x)
306 | branch3x3 = [
307 | self.branch3x3_2a(branch3x3),
308 | self.branch3x3_2b(branch3x3),
309 | ]
310 | branch3x3 = torch.cat(branch3x3, 1)
311 |
312 | branch3x3dbl = self.branch3x3dbl_1(x)
313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314 | branch3x3dbl = [
315 | self.branch3x3dbl_3a(branch3x3dbl),
316 | self.branch3x3dbl_3b(branch3x3dbl),
317 | ]
318 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
319 |
320 | # Patch: The FID Inception model uses max pooling instead of average
321 | # pooling. This is likely an error in this specific Inception
322 | # implementation, as other Inception models use average pooling here
323 | # (which matches the description in the paper).
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325 | branch_pool = self.branch_pool(branch_pool)
326 |
327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328 | return torch.cat(outputs, 1)
329 |
--------------------------------------------------------------------------------
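To extract the pool_3 activations used for FID directly, one can build the model with the 2048-dimensional block. A small sketch (note that the FID Inception weights are downloaded on first use):

```python
import torch
from tools.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()

images = torch.rand(4, 3, 64, 64)  # values in [0, 1]; inputs are resized to 299x299 internally
with torch.no_grad():
    features = model(images)[0]    # one entry per requested output block
print(features.shape)              # torch.Size([4, 2048, 1, 1])
```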
/train.py:
--------------------------------------------------------------------------------
1 | import sde
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | from datasets import get_dataset
6 | from torchvision.utils import make_grid, save_image
7 | import utils
8 | import einops
9 | from torch.utils._pytree import tree_map
10 | import accelerate
11 | from torch.utils.data import DataLoader
12 | from tqdm.auto import tqdm
13 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
14 | import tempfile
15 | from tools.fid_score import calculate_fid_given_paths
16 | from absl import logging
17 | import builtins
18 | import os
19 | import wandb
20 |
21 |
22 | def train(config):
23 | if config.get('benchmark', False):
24 | torch.backends.cudnn.benchmark = True
25 | torch.backends.cudnn.deterministic = False
26 |
27 | mp.set_start_method('spawn')
28 | accelerator = accelerate.Accelerator()
29 | device = accelerator.device
30 | accelerate.utils.set_seed(config.seed, device_specific=True)
31 | logging.info(f'Process {accelerator.process_index} using device: {device}')
32 |
33 | config.mixed_precision = accelerator.mixed_precision
34 | config = ml_collections.FrozenConfigDict(config)
35 |
36 | assert config.train.batch_size % accelerator.num_processes == 0
37 | mini_batch_size = config.train.batch_size // accelerator.num_processes
38 |
39 | if accelerator.is_main_process:
40 | os.makedirs(config.ckpt_root, exist_ok=True)
41 | os.makedirs(config.sample_dir, exist_ok=True)
42 | accelerator.wait_for_everyone()
43 | if accelerator.is_main_process:
44 | wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
45 | name=config.hparams, job_type='train', mode='offline')
46 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
47 | logging.info(config)
48 | else:
49 | utils.set_logger(log_level='error')
50 | builtins.print = lambda *args: None
51 |
52 | dataset = get_dataset(**config.dataset)
53 | assert os.path.exists(dataset.fid_stat)
54 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
55 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
56 | num_workers=8, pin_memory=True, persistent_workers=True)
57 |
58 | train_state = utils.initialize_train_state(config, device)
59 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
60 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
61 | lr_scheduler = train_state.lr_scheduler
62 | train_state.resume(config.ckpt_root)
63 |
64 | def get_data_generator():
65 | while True:
66 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
67 | yield data
68 |
69 | data_generator = get_data_generator()
70 |
71 |
72 | # set the score_model to train
73 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
74 | score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=sde.VPSDE())
75 |
76 |
77 | def train_step(_batch):
78 | _metrics = dict()
79 | optimizer.zero_grad()
80 | if config.train.mode == 'uncond':
81 | loss = sde.LSimple(score_model, _batch, pred=config.pred)
82 | elif config.train.mode == 'cond':
83 | loss = sde.LSimple(score_model, _batch[0], pred=config.pred, y=_batch[1])
84 | else:
85 | raise NotImplementedError(config.train.mode)
86 | _metrics['loss'] = accelerator.gather(loss.detach()).mean()
87 | accelerator.backward(loss.mean())
88 | if 'grad_clip' in config and config.grad_clip > 0:
89 | accelerator.clip_grad_norm_(nnet.parameters(), max_norm=config.grad_clip)
90 | optimizer.step()
91 | lr_scheduler.step()
92 | train_state.ema_update(config.get('ema_rate', 0.9999))
93 | train_state.step += 1
94 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
95 |
96 |
97 | def eval_step(n_samples, sample_steps, algorithm):
98 | logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, '
99 | f'mini_batch_size={config.sample.mini_batch_size}')
100 |
101 | def sample_fn(_n_samples):
102 | _x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
103 | if config.train.mode == 'uncond':
104 | kwargs = dict()
105 | elif config.train.mode == 'cond':
106 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
107 | else:
108 | raise NotImplementedError
109 |
110 | if algorithm == 'euler_maruyama_sde':
111 | return sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _x_init, sample_steps, **kwargs)
112 | elif algorithm == 'euler_maruyama_ode':
113 | return sde.euler_maruyama(sde.ODE(score_model_ema), _x_init, sample_steps, **kwargs)
114 | elif algorithm == 'dpm_solver':
115 | noise_schedule = NoiseScheduleVP(schedule='linear')
116 | model_fn = model_wrapper(
117 | score_model_ema.noise_pred,
118 | noise_schedule,
119 | time_input_type='0',
120 | model_kwargs=kwargs
121 | )
122 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
123 | return dpm_solver.sample(
124 | _x_init,
125 | steps=sample_steps,
126 | eps=1e-4,
127 | adaptive_step_size=False,
128 | fast_version=True,
129 | )
130 | else:
131 | raise NotImplementedError
132 |
133 | with tempfile.TemporaryDirectory() as temp_path:
134 | path = config.sample.path or temp_path
135 | if accelerator.is_main_process:
136 | os.makedirs(path, exist_ok=True)
137 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
138 |
139 | _fid = 0
140 | if accelerator.is_main_process:
141 | _fid = calculate_fid_given_paths((dataset.fid_stat, path))
142 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
143 | with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
144 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
145 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
146 | _fid = torch.tensor(_fid, device=device)
147 | _fid = accelerator.reduce(_fid, reduction='sum')
148 |
149 | return _fid.item()
150 |
151 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
152 |
153 | step_fid = []
154 | while train_state.step < config.train.n_steps:
155 | nnet.train()
156 | batch = tree_map(lambda x: x.to(device), next(data_generator))
157 | metrics = train_step(batch)
158 |
159 | nnet.eval()
160 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
161 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
162 | logging.info(config.workdir)
163 | wandb.log(metrics, step=train_state.step)
164 |
165 | if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
166 | logging.info('Save a grid of images...')
167 | x_init = torch.randn(100, *dataset.data_shape, device=device)
168 | if config.train.mode == 'uncond':
169 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50)
170 | elif config.train.mode == 'cond':
171 | y = einops.repeat(torch.arange(10, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
172 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50, y=y)
173 | else:
174 | raise NotImplementedError
175 | samples = make_grid(dataset.unpreprocess(samples), 10)
176 | save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
177 | wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
178 | torch.cuda.empty_cache()
179 | accelerator.wait_for_everyone()
180 |
181 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
182 | logging.info(f'Save and eval checkpoint {train_state.step}...')
183 | if accelerator.local_process_index == 0:
184 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
185 | accelerator.wait_for_everyone()
186 | fid = eval_step(n_samples=10000, sample_steps=50, algorithm='dpm_solver') # calculate fid of the saved checkpoint
187 | step_fid.append((train_state.step, fid))
188 | torch.cuda.empty_cache()
189 | accelerator.wait_for_everyone()
190 |
191 | logging.info(f'Finish fitting, step={train_state.step}')
192 | logging.info(f'step_fid: {step_fid}')
193 | step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
194 | logging.info(f'step_best: {step_best}')
195 | train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
196 | del metrics
197 | accelerator.wait_for_everyone()
198 | eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm)
199 |
200 |
201 |
202 | from absl import flags
203 | from absl import app
204 | from ml_collections import config_flags
205 | import sys
206 | from pathlib import Path
207 |
208 |
209 | FLAGS = flags.FLAGS
210 | config_flags.DEFINE_config_file(
211 | "config", None, "Training configuration.", lock_config=False)
212 | flags.mark_flags_as_required(["config"])
213 | flags.DEFINE_string("workdir", None, "Work unit directory.")
214 |
215 |
216 | def get_config_name():
217 | argv = sys.argv
218 | for i in range(1, len(argv)):
219 | if argv[i].startswith('--config='):
220 | return Path(argv[i].split('=')[-1]).stem
221 |
222 |
223 | def get_hparams():
224 | argv = sys.argv
225 | lst = []
226 | for i in range(1, len(argv)):
227 | assert '=' in argv[i]
228 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
229 | hparam, val = argv[i].split('=')
230 | hparam = hparam.split('.')[-1]
231 | if hparam.endswith('path'):
232 | val = Path(val).stem
233 | lst.append(f'{hparam}={val}')
234 | hparams = '-'.join(lst)
235 | if hparams == '':
236 | hparams = 'default'
237 | return hparams
238 |
239 |
240 | def main(argv):
241 | config = FLAGS.config
242 | config.config_name = get_config_name()
243 | config.hparams = get_hparams()
244 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
245 | config.ckpt_root = os.path.join(config.workdir, 'ckpts')
246 | config.sample_dir = os.path.join(config.workdir, 'samples')
247 | train(config)
248 |
249 |
250 | if __name__ == "__main__":
251 | app.run(main)
252 |
--------------------------------------------------------------------------------
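For reference, the run name produced by `get_hparams` above is just the non-path `--config.*` overrides joined with dashes, and `get_config_name` is the config file's stem. A small sketch of how a command line maps to the workdir naming (the flag values are illustrative, and importing `train` assumes the repository's dependencies such as torch, accelerate and wandb are installed):

```python
import sys
from train import get_config_name, get_hparams

# Simulate e.g.: accelerate launch train.py --config=configs/cifar10_uvit_small.py --config.train.batch_size=128
sys.argv = ['train.py', '--config=configs/cifar10_uvit_small.py', '--config.train.batch_size=128']
print(get_config_name())  # cifar10_uvit_small
print(get_hparams())      # batch_size=128
```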
/train_ldm.py:
--------------------------------------------------------------------------------
1 | import sde
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | from datasets import get_dataset
6 | from torchvision.utils import make_grid, save_image
7 | import utils
8 | import einops
9 | from torch.utils._pytree import tree_map
10 | import accelerate
11 | from torch.utils.data import DataLoader
12 | from tqdm.auto import tqdm
13 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
14 | import tempfile
15 | from tools.fid_score import calculate_fid_given_paths
16 | from absl import logging
17 | import builtins
18 | import os
19 | import wandb
20 | import libs.autoencoder
21 |
22 |
23 | def train(config):
24 | if config.get('benchmark', False):
25 | torch.backends.cudnn.benchmark = True
26 | torch.backends.cudnn.deterministic = False
27 |
28 | mp.set_start_method('spawn')
29 | accelerator = accelerate.Accelerator()
30 | device = accelerator.device
31 | accelerate.utils.set_seed(config.seed, device_specific=True)
32 | logging.info(f'Process {accelerator.process_index} using device: {device}')
33 |
34 | config.mixed_precision = accelerator.mixed_precision
35 | config = ml_collections.FrozenConfigDict(config)
36 |
37 | assert config.train.batch_size % accelerator.num_processes == 0
38 | mini_batch_size = config.train.batch_size // accelerator.num_processes
39 |
40 | if accelerator.is_main_process:
41 | os.makedirs(config.ckpt_root, exist_ok=True)
42 | os.makedirs(config.sample_dir, exist_ok=True)
43 | accelerator.wait_for_everyone()
44 | if accelerator.is_main_process:
45 | wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
46 | name=config.hparams, job_type='train', mode='offline')
47 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
48 | logging.info(config)
49 | else:
50 | utils.set_logger(log_level='error')
51 | builtins.print = lambda *args: None
52 | logging.info(f'Run on {accelerator.num_processes} devices')
53 |
54 | dataset = get_dataset(**config.dataset)
55 | assert os.path.exists(dataset.fid_stat)
56 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
57 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
58 | num_workers=8, pin_memory=True, persistent_workers=True)
59 |
60 | train_state = utils.initialize_train_state(config, device)
61 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
62 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
63 | lr_scheduler = train_state.lr_scheduler
64 | train_state.resume(config.ckpt_root)
65 |
66 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
67 | autoencoder.to(device)
68 |
69 | @ torch.cuda.amp.autocast()
70 | def encode(_batch):
71 | return autoencoder.encode(_batch)
72 |
73 | @ torch.cuda.amp.autocast()
74 | def decode(_batch):
75 | return autoencoder.decode(_batch)
76 |
77 | def get_data_generator():
78 | while True:
79 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
80 | yield data
81 |
82 | data_generator = get_data_generator()
83 |
84 |
85 | # set the score_model to train
86 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
87 | score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=sde.VPSDE())
88 |
89 |
90 | def train_step(_batch):
91 | _metrics = dict()
92 | optimizer.zero_grad()
93 | if config.train.mode == 'uncond':
94 | _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
95 | loss = sde.LSimple(score_model, _z, pred=config.pred)
96 | elif config.train.mode == 'cond':
97 | _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
98 | loss = sde.LSimple(score_model, _z, pred=config.pred, y=_batch[1])
99 | else:
100 | raise NotImplementedError(config.train.mode)
101 | _metrics['loss'] = accelerator.gather(loss.detach()).mean()
102 | accelerator.backward(loss.mean())
103 | optimizer.step()
104 | lr_scheduler.step()
105 | train_state.ema_update(config.get('ema_rate', 0.9999))
106 | train_state.step += 1
107 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
108 |
109 |
110 | def eval_step(n_samples, sample_steps, algorithm):
111 | logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, '
112 | f'mini_batch_size={config.sample.mini_batch_size}')
113 |
114 | def sample_fn(_n_samples):
115 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
116 | if config.train.mode == 'uncond':
117 | kwargs = dict()
118 | elif config.train.mode == 'cond':
119 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
120 | else:
121 | raise NotImplementedError
122 |
123 | if algorithm == 'euler_maruyama_sde':
124 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _z_init, sample_steps, **kwargs)
125 | elif algorithm == 'euler_maruyama_ode':
126 | _z = sde.euler_maruyama(sde.ODE(score_model_ema), _z_init, sample_steps, **kwargs)
127 | elif algorithm == 'dpm_solver':
128 | noise_schedule = NoiseScheduleVP(schedule='linear')
129 | model_fn = model_wrapper(
130 | score_model_ema.noise_pred,
131 | noise_schedule,
132 | time_input_type='0',
133 | model_kwargs=kwargs
134 | )
135 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
136 | _z = dpm_solver.sample(
137 | _z_init,
138 | steps=sample_steps,
139 | eps=1e-4,
140 | adaptive_step_size=False,
141 | fast_version=True,
142 | )
143 | else:
144 | raise NotImplementedError
145 | return decode(_z)
146 |
147 | with tempfile.TemporaryDirectory() as temp_path:
148 | path = config.sample.path or temp_path
149 | if accelerator.is_main_process:
150 | os.makedirs(path, exist_ok=True)
151 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
152 |
153 | _fid = 0
154 | if accelerator.is_main_process:
155 | _fid = calculate_fid_given_paths((dataset.fid_stat, path))
156 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
157 | with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
158 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
159 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
160 | _fid = torch.tensor(_fid, device=device)
161 | _fid = accelerator.reduce(_fid, reduction='sum')
162 |
163 | return _fid.item()
164 |
165 |     logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
166 |
167 |     step_fid = []
168 |     while train_state.step < config.train.n_steps:
169 |         nnet.train()
170 |         batch = tree_map(lambda x: x.to(device), next(data_generator))
171 |         metrics = train_step(batch)
172 |
173 |         nnet.eval()
174 |         if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
175 |             logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
176 |             logging.info(config.workdir)
177 |             wandb.log(metrics, step=train_state.step)
178 |
179 |         if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
180 |             torch.cuda.empty_cache()
181 |             logging.info('Save a grid of images...')
182 |             z_init = torch.randn(5 * 10, *config.z_shape, device=device)
183 |             if config.train.mode == 'uncond':
184 |                 z = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=z_init, sample_steps=50)
185 |             elif config.train.mode == 'cond':
186 |                 y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
187 |                 z = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=z_init, sample_steps=50, y=y)
188 |             else:
189 |                 raise NotImplementedError
190 |             samples = decode(z)
191 |             samples = make_grid(dataset.unpreprocess(samples), 10)
192 |             save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
193 |             wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
194 |             torch.cuda.empty_cache()
195 |         accelerator.wait_for_everyone()
196 |
197 |         if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
198 |             torch.cuda.empty_cache()
199 |             logging.info(f'Save and eval checkpoint {train_state.step}...')
200 |             if accelerator.local_process_index == 0:
201 |                 train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
202 |             accelerator.wait_for_everyone()
203 |             fid = eval_step(n_samples=10000, sample_steps=50, algorithm='dpm_solver') # calculate fid of the saved checkpoint
204 |             step_fid.append((train_state.step, fid))
205 |             torch.cuda.empty_cache()
206 |         accelerator.wait_for_everyone()
207 |
208 |     logging.info(f'Finish fitting, step={train_state.step}')
209 |     logging.info(f'step_fid: {step_fid}')
210 |     step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
211 |     logging.info(f'step_best: {step_best}')
212 |     train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
213 |     del metrics
214 |     accelerator.wait_for_everyone()
215 |     eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm)
216 |
217 |
218 |
219 | from absl import flags
220 | from absl import app
221 | from ml_collections import config_flags
222 | import sys
223 | from pathlib import Path
224 |
225 |
226 | FLAGS = flags.FLAGS
227 | config_flags.DEFINE_config_file(
228 |     "config", None, "Training configuration.", lock_config=False)
229 | flags.mark_flags_as_required(["config"])
230 | flags.DEFINE_string("workdir", None, "Work unit directory.")
231 |
232 |
233 | def get_config_name():
234 |     argv = sys.argv
235 |     for i in range(1, len(argv)):
236 |         if argv[i].startswith('--config='):
237 |             return Path(argv[i].split('=')[-1]).stem
238 |
239 |
240 | def get_hparams():
241 |     argv = sys.argv
242 |     lst = []
243 |     for i in range(1, len(argv)):
244 |         assert '=' in argv[i]
245 |         if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
246 |             hparam, val = argv[i].split('=')
247 |             hparam = hparam.split('.')[-1]
248 |             if hparam.endswith('path'):
249 |                 val = Path(val).stem
250 |             lst.append(f'{hparam}={val}')
251 |     hparams = '-'.join(lst)
252 |     if hparams == '':
253 |         hparams = 'default'
254 |     return hparams
255 |
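# Illustrative example (added by the editor, not in the original source; the override values are
# hypothetical): get_config_name / get_hparams read the command-line flags, so e.g.
#   --config=configs/imagenet256_uvit_large.py --config.train.batch_size=1024
# gives config_name == 'imagenet256_uvit_large' and hparams == 'batch_size=1024';
# with no --config.* overrides, hparams falls back to 'default'.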
256 |
257 | def main(argv):
258 |     config = FLAGS.config
259 |     config.config_name = get_config_name()
260 |     config.hparams = get_hparams()
261 |     config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
262 |     config.ckpt_root = os.path.join(config.workdir, 'ckpts')
263 |     config.sample_dir = os.path.join(config.workdir, 'samples')
264 |     train(config)
265 |
266 |
267 | if __name__ == "__main__":
268 |     app.run(main)
269 |
--------------------------------------------------------------------------------
/train_ldm_discrete.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | from datasets import get_dataset
5 | from torchvision.utils import make_grid, save_image
6 | import utils
7 | import einops
8 | from torch.utils._pytree import tree_map
9 | import accelerate
10 | from torch.utils.data import DataLoader
11 | from tqdm.auto import tqdm
12 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
13 | import tempfile
14 | from tools.fid_score import calculate_fid_given_paths
15 | from absl import logging
16 | import builtins
17 | import os
18 | import wandb
19 | import libs.autoencoder
20 | import numpy as np
21 |
22 |
23 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
24 | _betas = (
25 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
26 | )
27 | return _betas.numpy()
28 |
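# Note (added by the editor, not in the original source): this is the "scaled linear" schedule
# used by Stable Diffusion: sqrt(beta) is spaced linearly, so with the defaults beta_1 = 0.00085
# and beta_1000 = 0.0120 over 1000 training timesteps.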
29 |
30 | def get_skip(alphas, betas):
31 | N = len(betas) - 1
32 | skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
33 | for s in range(N + 1):
34 | skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
35 | skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
36 | for t in range(N + 1):
37 | prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
38 | skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
39 | return skip_alphas, skip_betas
40 |
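# Note (added by the editor, not in the original source): with betas[0] = 0 prepended (see the
# Schedule class below), skip_alphas[s, t] = prod(alphas[s+1 : t+1]) and, for this
# variance-preserving schedule, skip_betas[s, t] = 1 - skip_alphas[s, t], i.e.
#   q(x_t | x_s) = N(sqrt(skip_alphas[s, t]) * x_s, skip_betas[s, t] * I)   for s < t.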
41 |
42 | def stp(s, ts: torch.Tensor): # scalar tensor product
43 | if isinstance(s, np.ndarray):
44 | s = torch.from_numpy(s).type_as(ts)
45 | extra_dims = (1,) * (ts.dim() - 1)
46 | return s.view(-1, *extra_dims) * ts
47 |
48 |
49 | def mos(a, start_dim=1): # mean of square
50 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
51 |
52 |
53 | class Schedule(object): # discrete time
54 | def __init__(self, _betas):
55 | r""" _betas[0...999] = betas[1...1000]
56 | for n>=1, betas[n] is the variance of q(xn|xn-1)
57 | for n=0, betas[0]=0
58 | """
59 |
60 | self._betas = _betas
61 | self.betas = np.append(0., _betas)
62 | self.alphas = 1. - self.betas
63 | self.N = len(_betas)
64 |
65 | assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
66 | assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
67 | assert len(self.betas) == len(self.alphas)
68 |
69 | # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
70 | self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
71 | self.cum_alphas = self.skip_alphas[0] # cum_alphas = alphas.cumprod()
72 | self.cum_betas = self.skip_betas[0]
73 | self.snr = self.cum_alphas / self.cum_betas
74 |
75 | def tilde_beta(self, s, t):
76 | return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]
77 |
78 | def sample(self, x0): # sample from q(xn|x0), where n is uniform
79 | n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
80 | eps = torch.randn_like(x0)
81 | xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
82 | return torch.tensor(n, device=x0.device), eps, xn
83 |
84 | def __repr__(self):
85 | return f'Schedule({self.betas[:10]}..., {self.N})'
86 |
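# Usage sketch (added by the editor; `x0` is assumed to be a batch of latents of shape [B, C, H, W]):
#   _schedule = Schedule(stable_diffusion_beta_schedule())
#   n, eps, xn = _schedule.sample(x0)   # n: integer tensor of shape [B] with values in {1, ..., 1000}
#   # xn = sqrt(cum_alphas[n]) * x0 + sqrt(cum_betas[n]) * eps, i.e. a draw from q(x_n | x_0)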
87 |
88 | def LSimple(x0, nnet, schedule, **kwargs):
89 | n, eps, xn = schedule.sample(x0) # n in {1, ..., 1000}
90 | eps_pred = nnet(xn, n, **kwargs)
91 | return mos(eps - eps_pred)
92 |
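# Note (added by the editor, not in the original source): LSimple is the standard
# epsilon-prediction objective: the per-sample mean squared error between the injected noise eps
# and the prediction nnet(xn, n, **kwargs), reduced to a scalar later via loss.mean(); extra
# kwargs (e.g. y=... for class labels or context=... for text features) are passed to the network.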
93 |
94 | def train(config):
95 | if config.get('benchmark', False):
96 | torch.backends.cudnn.benchmark = True
97 | torch.backends.cudnn.deterministic = False
98 |
99 | mp.set_start_method('spawn')
100 | accelerator = accelerate.Accelerator()
101 | device = accelerator.device
102 | accelerate.utils.set_seed(config.seed, device_specific=True)
103 | logging.info(f'Process {accelerator.process_index} using device: {device}')
104 |
105 | config.mixed_precision = accelerator.mixed_precision
106 | config = ml_collections.FrozenConfigDict(config)
107 |
108 | assert config.train.batch_size % accelerator.num_processes == 0
109 | mini_batch_size = config.train.batch_size // accelerator.num_processes
110 |
111 | if accelerator.is_main_process:
112 | os.makedirs(config.ckpt_root, exist_ok=True)
113 | os.makedirs(config.sample_dir, exist_ok=True)
114 | accelerator.wait_for_everyone()
115 | if accelerator.is_main_process:
116 | wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
117 | name=config.hparams, job_type='train', mode='offline')
118 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
119 | logging.info(config)
120 | else:
121 | utils.set_logger(log_level='error')
122 |         builtins.print = lambda *args, **kwargs: None  # ignore kwargs too, so e.g. print(..., file=...) is also silenced
123 | logging.info(f'Run on {accelerator.num_processes} devices')
124 |
125 | dataset = get_dataset(**config.dataset)
126 | assert os.path.exists(dataset.fid_stat)
127 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
128 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
129 | num_workers=8, pin_memory=True, persistent_workers=True)
130 |
131 | train_state = utils.initialize_train_state(config, device)
132 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
133 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
134 | lr_scheduler = train_state.lr_scheduler
135 | train_state.resume(config.ckpt_root)
136 |
137 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
138 | autoencoder.to(device)
139 |
140 | @ torch.cuda.amp.autocast()
141 | def encode(_batch):
142 | return autoencoder.encode(_batch)
143 |
144 | @ torch.cuda.amp.autocast()
145 | def decode(_batch):
146 | return autoencoder.decode(_batch)
147 |
148 | def get_data_generator():
149 | while True:
150 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
151 | yield data
152 |
153 | data_generator = get_data_generator()
154 |
155 | _betas = stable_diffusion_beta_schedule()
156 | _schedule = Schedule(_betas)
157 | logging.info(f'use {_schedule}')
158 |
159 |
160 | def train_step(_batch):
161 | _metrics = dict()
162 | optimizer.zero_grad()
163 | if config.train.mode == 'uncond':
164 | _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
165 | loss = LSimple(_z, nnet, _schedule)
166 | elif config.train.mode == 'cond':
167 | _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
168 | loss = LSimple(_z, nnet, _schedule, y=_batch[1])
169 | else:
170 | raise NotImplementedError(config.train.mode)
171 | _metrics['loss'] = accelerator.gather(loss.detach()).mean()
172 | accelerator.backward(loss.mean())
173 | optimizer.step()
174 | lr_scheduler.step()
175 | train_state.ema_update(config.get('ema_rate', 0.9999))
176 | train_state.step += 1
177 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
178 |
179 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
180 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
181 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
182 |
183 | def model_fn(x, t_continuous):
184 | t = t_continuous * _schedule.N
185 | eps_pre = nnet_ema(x, t, **kwargs)
186 | return eps_pre
187 |
188 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
189 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
190 | return decode(_z)
191 |
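    # Note (added by the editor, not in the original source): sampling runs DPM-Solver on the
    # discrete schedule; model_fn maps the solver's continuous time t in (0, 1] to the discrete
    # step t * N expected by nnet_ema, and the sampled latent is decoded back to image space.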
192 | def eval_step(n_samples, sample_steps):
193 |         logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, '
194 | f'mini_batch_size={config.sample.mini_batch_size}')
195 |
196 | def sample_fn(_n_samples):
197 | if config.train.mode == 'uncond':
198 | kwargs = dict()
199 | elif config.train.mode == 'cond':
200 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
201 | else:
202 | raise NotImplementedError
203 | return dpm_solver_sample(_n_samples, sample_steps, **kwargs)
204 |
205 |
206 | with tempfile.TemporaryDirectory() as temp_path:
207 | path = config.sample.path or temp_path
208 | if accelerator.is_main_process:
209 | os.makedirs(path, exist_ok=True)
210 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
211 |
212 | _fid = 0
213 | if accelerator.is_main_process:
214 | _fid = calculate_fid_given_paths((dataset.fid_stat, path))
215 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
216 | with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
217 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
218 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
219 | _fid = torch.tensor(_fid, device=device)
220 | _fid = accelerator.reduce(_fid, reduction='sum')
221 |
222 | return _fid.item()
223 |
224 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
225 |
226 | step_fid = []
227 | while train_state.step < config.train.n_steps:
228 | nnet.train()
229 | batch = tree_map(lambda x: x.to(device), next(data_generator))
230 | metrics = train_step(batch)
231 |
232 | nnet.eval()
233 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
234 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
235 | logging.info(config.workdir)
236 | wandb.log(metrics, step=train_state.step)
237 |
238 | if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
239 | torch.cuda.empty_cache()
240 | logging.info('Save a grid of images...')
241 | if config.train.mode == 'uncond':
242 | samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50)
243 | elif config.train.mode == 'cond':
244 | y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
245 | samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50, y=y)
246 | else:
247 | raise NotImplementedError
248 | samples = make_grid(dataset.unpreprocess(samples), 10)
249 | save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
250 | wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
251 | torch.cuda.empty_cache()
252 | accelerator.wait_for_everyone()
253 |
254 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
255 | torch.cuda.empty_cache()
256 | logging.info(f'Save and eval checkpoint {train_state.step}...')
257 | if accelerator.local_process_index == 0:
258 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
259 | accelerator.wait_for_everyone()
260 | fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint
261 | step_fid.append((train_state.step, fid))
262 | torch.cuda.empty_cache()
263 | accelerator.wait_for_everyone()
264 |
265 | logging.info(f'Finish fitting, step={train_state.step}')
266 | logging.info(f'step_fid: {step_fid}')
267 | step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
268 | logging.info(f'step_best: {step_best}')
269 | train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
270 | del metrics
271 | accelerator.wait_for_everyone()
272 | eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
273 |
274 |
275 |
276 | from absl import flags
277 | from absl import app
278 | from ml_collections import config_flags
279 | import sys
280 | from pathlib import Path
281 |
282 |
283 | FLAGS = flags.FLAGS
284 | config_flags.DEFINE_config_file(
285 | "config", None, "Training configuration.", lock_config=False)
286 | flags.mark_flags_as_required(["config"])
287 | flags.DEFINE_string("workdir", None, "Work unit directory.")
288 |
289 |
290 | def get_config_name():
291 | argv = sys.argv
292 | for i in range(1, len(argv)):
293 | if argv[i].startswith('--config='):
294 | return Path(argv[i].split('=')[-1]).stem
295 |
296 |
297 | def get_hparams():
298 | argv = sys.argv
299 | lst = []
300 | for i in range(1, len(argv)):
301 | assert '=' in argv[i]
302 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
303 | hparam, val = argv[i].split('=')
304 | hparam = hparam.split('.')[-1]
305 | if hparam.endswith('path'):
306 | val = Path(val).stem
307 | lst.append(f'{hparam}={val}')
308 | hparams = '-'.join(lst)
309 | if hparams == '':
310 | hparams = 'default'
311 | return hparams
312 |
313 |
314 | def main(argv):
315 | config = FLAGS.config
316 | config.config_name = get_config_name()
317 | config.hparams = get_hparams()
318 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
319 | config.ckpt_root = os.path.join(config.workdir, 'ckpts')
320 | config.sample_dir = os.path.join(config.workdir, 'samples')
321 | train(config)
322 |
323 |
324 | if __name__ == "__main__":
325 | app.run(main)
326 |
--------------------------------------------------------------------------------
/train_t2i_discrete.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | from datasets import get_dataset
5 | from torchvision.utils import make_grid, save_image
6 | import utils
7 | import einops
8 | from torch.utils._pytree import tree_map
9 | import accelerate
10 | from torch.utils.data import DataLoader
11 | from tqdm.auto import tqdm
12 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
13 | import tempfile
14 | from tools.fid_score import calculate_fid_given_paths
15 | from absl import logging
16 | import builtins
17 | import os
18 | import wandb
19 | import libs.autoencoder
20 | import numpy as np
21 |
22 |
23 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
24 | _betas = (
25 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
26 | )
27 | return _betas.numpy()
28 |
29 |
30 | def get_skip(alphas, betas):
31 | N = len(betas) - 1
32 | skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
33 | for s in range(N + 1):
34 | skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
35 | skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
36 | for t in range(N + 1):
37 | prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
38 | skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
39 | return skip_alphas, skip_betas
40 |
41 |
42 | def stp(s, ts: torch.Tensor): # scalar tensor product
43 | if isinstance(s, np.ndarray):
44 | s = torch.from_numpy(s).type_as(ts)
45 | extra_dims = (1,) * (ts.dim() - 1)
46 | return s.view(-1, *extra_dims) * ts
47 |
48 |
49 | def mos(a, start_dim=1): # mean of square
50 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
51 |
52 |
53 | class Schedule(object): # discrete time
54 | def __init__(self, _betas):
55 | r""" _betas[0...999] = betas[1...1000]
56 | for n>=1, betas[n] is the variance of q(xn|xn-1)
57 | for n=0, betas[0]=0
58 | """
59 |
60 | self._betas = _betas
61 | self.betas = np.append(0., _betas)
62 | self.alphas = 1. - self.betas
63 | self.N = len(_betas)
64 |
65 | assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
66 | assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
67 | assert len(self.betas) == len(self.alphas)
68 |
69 | # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
70 | self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
71 | self.cum_alphas = self.skip_alphas[0] # cum_alphas = alphas.cumprod()
72 | self.cum_betas = self.skip_betas[0]
73 | self.snr = self.cum_alphas / self.cum_betas
74 |
75 | def tilde_beta(self, s, t):
76 | return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]
77 |
78 | def sample(self, x0): # sample from q(xn|x0), where n is uniform
79 | n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
80 | eps = torch.randn_like(x0)
81 | xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
82 | return torch.tensor(n, device=x0.device), eps, xn
83 |
84 | def __repr__(self):
85 | return f'Schedule({self.betas[:10]}..., {self.N})'
86 |
87 |
88 | def LSimple(x0, nnet, schedule, **kwargs):
89 | n, eps, xn = schedule.sample(x0) # n in {1, ..., 1000}
90 | eps_pred = nnet(xn, n, **kwargs)
91 | return mos(eps - eps_pred)
92 |
93 |
94 | def train(config):
95 | if config.get('benchmark', False):
96 | torch.backends.cudnn.benchmark = True
97 | torch.backends.cudnn.deterministic = False
98 |
99 | mp.set_start_method('spawn')
100 | accelerator = accelerate.Accelerator()
101 | device = accelerator.device
102 | accelerate.utils.set_seed(config.seed, device_specific=True)
103 | logging.info(f'Process {accelerator.process_index} using device: {device}')
104 |
105 | config.mixed_precision = accelerator.mixed_precision
106 | config = ml_collections.FrozenConfigDict(config)
107 |
108 | assert config.train.batch_size % accelerator.num_processes == 0
109 | mini_batch_size = config.train.batch_size // accelerator.num_processes
110 |
111 | if accelerator.is_main_process:
112 | os.makedirs(config.ckpt_root, exist_ok=True)
113 | os.makedirs(config.sample_dir, exist_ok=True)
114 | accelerator.wait_for_everyone()
115 | if accelerator.is_main_process:
116 | wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
117 | name=config.hparams, job_type='train', mode='offline')
118 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
119 | logging.info(config)
120 | else:
121 | utils.set_logger(log_level='error')
122 |         builtins.print = lambda *args, **kwargs: None  # ignore kwargs too, so e.g. print(..., file=...) is also silenced
123 | logging.info(f'Run on {accelerator.num_processes} devices')
124 |
125 | dataset = get_dataset(**config.dataset)
126 | assert os.path.exists(dataset.fid_stat)
127 | train_dataset = dataset.get_split(split='train', labeled=True)
128 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
129 | num_workers=8, pin_memory=True, persistent_workers=True)
130 | test_dataset = dataset.get_split(split='test', labeled=True) # for sampling
131 | test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
132 | num_workers=8, pin_memory=True, persistent_workers=True)
133 |
134 | train_state = utils.initialize_train_state(config, device)
135 | nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
136 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
137 | lr_scheduler = train_state.lr_scheduler
138 | train_state.resume(config.ckpt_root)
139 |
140 | autoencoder = libs.autoencoder.get_model(**config.autoencoder)
141 | autoencoder.to(device)
142 |
143 | @ torch.cuda.amp.autocast()
144 | def encode(_batch):
145 | return autoencoder.encode(_batch)
146 |
147 | @ torch.cuda.amp.autocast()
148 | def decode(_batch):
149 | return autoencoder.decode(_batch)
150 |
151 | def get_data_generator():
152 | while True:
153 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
154 | yield data
155 |
156 | data_generator = get_data_generator()
157 |
158 | def get_context_generator():
159 | while True:
160 | for data in test_dataset_loader:
161 | _, _context = data
162 | yield _context
163 |
164 | context_generator = get_context_generator()
165 |
166 | _betas = stable_diffusion_beta_schedule()
167 | _schedule = Schedule(_betas)
168 | logging.info(f'use {_schedule}')
169 |
170 | def cfg_nnet(x, timesteps, context):
171 | _cond = nnet_ema(x, timesteps, context=context)
172 | _empty_context = torch.tensor(dataset.empty_context, device=device)
173 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
174 | _uncond = nnet_ema(x, timesteps, context=_empty_context)
175 | return _cond + config.sample.scale * (_cond - _uncond)
176 |
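    # Note (added by the editor, not in the original source): cfg_nnet is classifier-free guidance
    # with the precomputed empty-prompt context, i.e. (1 + scale) * eps_cond - scale * eps_uncond;
    # setting config.sample.scale = 0 recovers the purely conditional prediction.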
177 | def train_step(_batch):
178 | _metrics = dict()
179 | optimizer.zero_grad()
180 | _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
181 |         loss = LSimple(_z, nnet, _schedule, context=_batch[1])  # currently only supports the extracted-feature version
182 | _metrics['loss'] = accelerator.gather(loss.detach()).mean()
183 | accelerator.backward(loss.mean())
184 | optimizer.step()
185 | lr_scheduler.step()
186 | train_state.ema_update(config.get('ema_rate', 0.9999))
187 | train_state.step += 1
188 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
189 |
190 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
191 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
192 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
193 |
194 | def model_fn(x, t_continuous):
195 | t = t_continuous * _schedule.N
196 | return cfg_nnet(x, t, **kwargs)
197 |
198 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
199 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
200 | return decode(_z)
201 |
202 | def eval_step(n_samples, sample_steps):
203 | logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=dpm_solver, '
204 | f'mini_batch_size={config.sample.mini_batch_size}')
205 |
206 | def sample_fn(_n_samples):
207 | _context = next(context_generator)
208 | assert _context.size(0) == _n_samples
209 | return dpm_solver_sample(_n_samples, sample_steps, context=_context)
210 |
211 | with tempfile.TemporaryDirectory() as temp_path:
212 | path = config.sample.path or temp_path
213 | if accelerator.is_main_process:
214 | os.makedirs(path, exist_ok=True)
215 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
216 |
217 | _fid = 0
218 | if accelerator.is_main_process:
219 | _fid = calculate_fid_given_paths((dataset.fid_stat, path))
220 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
221 | with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
222 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
223 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
224 | _fid = torch.tensor(_fid, device=device)
225 | _fid = accelerator.reduce(_fid, reduction='sum')
226 |
227 | return _fid.item()
228 |
229 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
230 |
231 | step_fid = []
232 | while train_state.step < config.train.n_steps:
233 | nnet.train()
234 | batch = tree_map(lambda x: x.to(device), next(data_generator))
235 | metrics = train_step(batch)
236 |
237 | nnet.eval()
238 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
239 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
240 | logging.info(config.workdir)
241 | wandb.log(metrics, step=train_state.step)
242 |
243 | if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
244 | torch.cuda.empty_cache()
245 | logging.info('Save a grid of images...')
246 | contexts = torch.tensor(dataset.contexts, device=device)[: 2 * 5]
247 | samples = dpm_solver_sample(_n_samples=2 * 5, _sample_steps=50, context=contexts)
248 | samples = make_grid(dataset.unpreprocess(samples), 5)
249 | save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
250 | wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
251 | torch.cuda.empty_cache()
252 | accelerator.wait_for_everyone()
253 |
254 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
255 | torch.cuda.empty_cache()
256 | logging.info(f'Save and eval checkpoint {train_state.step}...')
257 | if accelerator.local_process_index == 0:
258 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
259 | accelerator.wait_for_everyone()
260 | fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint
261 | step_fid.append((train_state.step, fid))
262 | torch.cuda.empty_cache()
263 | accelerator.wait_for_everyone()
264 |
265 | logging.info(f'Finish fitting, step={train_state.step}')
266 | logging.info(f'step_fid: {step_fid}')
267 | step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
268 | logging.info(f'step_best: {step_best}')
269 | train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
270 | del metrics
271 | accelerator.wait_for_everyone()
272 | eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
273 |
274 |
275 |
276 | from absl import flags
277 | from absl import app
278 | from ml_collections import config_flags
279 | import sys
280 | from pathlib import Path
281 |
282 |
283 | FLAGS = flags.FLAGS
284 | config_flags.DEFINE_config_file(
285 | "config", None, "Training configuration.", lock_config=False)
286 | flags.mark_flags_as_required(["config"])
287 | flags.DEFINE_string("workdir", None, "Work unit directory.")
288 |
289 |
290 | def get_config_name():
291 | argv = sys.argv
292 | for i in range(1, len(argv)):
293 | if argv[i].startswith('--config='):
294 | return Path(argv[i].split('=')[-1]).stem
295 |
296 |
297 | def get_hparams():
298 | argv = sys.argv
299 | lst = []
300 | for i in range(1, len(argv)):
301 | assert '=' in argv[i]
302 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
303 | hparam, val = argv[i].split('=')
304 | hparam = hparam.split('.')[-1]
305 | if hparam.endswith('path'):
306 | val = Path(val).stem
307 | lst.append(f'{hparam}={val}')
308 | hparams = '-'.join(lst)
309 | if hparams == '':
310 | hparams = 'default'
311 | return hparams
312 |
313 |
314 | def main(argv):
315 | config = FLAGS.config
316 | config.config_name = get_config_name()
317 | config.hparams = get_hparams()
318 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
319 | config.ckpt_root = os.path.join(config.workdir, 'ckpts')
320 | config.sample_dir = os.path.join(config.workdir, 'samples')
321 | train(config)
322 |
323 |
324 | if __name__ == "__main__":
325 | app.run(main)
326 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 | from torchvision.utils import save_image
7 | from absl import logging
8 |
9 |
10 | def set_logger(log_level='info', fname=None):
11 | import logging as _logging
12 | handler = logging.get_absl_handler()
13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
14 | handler.setFormatter(formatter)
15 | logging.set_verbosity(log_level)
16 | if fname is not None:
17 | handler = _logging.FileHandler(fname)
18 | handler.setFormatter(formatter)
19 | logging.get_absl_logger().addHandler(handler)
20 |
21 |
22 | def dct2str(dct):
23 | return str({k: f'{v:.6g}' for k, v in dct.items()})
24 |
25 |
26 | def get_nnet(name, **kwargs):
27 | if name == 'uvit':
28 | from libs.uvit import UViT
29 | return UViT(**kwargs)
30 | elif name == 'uvit_t2i':
31 | from libs.uvit_t2i import UViT
32 | return UViT(**kwargs)
33 | else:
34 | raise NotImplementedError(name)
35 |
36 |
37 | def set_seed(seed: int):
38 | if seed is not None:
39 | torch.manual_seed(seed)
40 | np.random.seed(seed)
41 |
42 |
43 | def get_optimizer(params, name, **kwargs):
44 | if name == 'adam':
45 | from torch.optim import Adam
46 | return Adam(params, **kwargs)
47 | elif name == 'adamw':
48 | from torch.optim import AdamW
49 | return AdamW(params, **kwargs)
50 | else:
51 | raise NotImplementedError(name)
52 |
53 |
54 | def customized_lr_scheduler(optimizer, warmup_steps=-1):
55 | from torch.optim.lr_scheduler import LambdaLR
56 | def fn(step):
57 | if warmup_steps > 0:
58 | return min(step / warmup_steps, 1)
59 | else:
60 | return 1
61 | return LambdaLR(optimizer, fn)
62 |
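# Note (added by the editor, not in the original source; warmup_steps=5000 is a hypothetical value):
# 'customized' is a linear warmup: the LR multiplier is min(step / warmup_steps, 1), so with
# warmup_steps=5000 the learning rate ramps from 0 to the optimizer's base LR over the first
# 5000 steps and then stays constant; warmup_steps <= 0 disables warmup.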
63 |
64 | def get_lr_scheduler(optimizer, name, **kwargs):
65 | if name == 'customized':
66 | return customized_lr_scheduler(optimizer, **kwargs)
67 | elif name == 'cosine':
68 | from torch.optim.lr_scheduler import CosineAnnealingLR
69 | return CosineAnnealingLR(optimizer, **kwargs)
70 | else:
71 | raise NotImplementedError(name)
72 |
73 |
74 | def ema(model_dest: nn.Module, model_src: nn.Module, rate):
75 | param_dict_src = dict(model_src.named_parameters())
76 | for p_name, p_dest in model_dest.named_parameters():
77 | p_src = param_dict_src[p_name]
78 | assert p_src is not p_dest
79 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
80 |
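# Note (added by the editor, not in the original source): in-place EMA update of the parameters,
# p_ema <- rate * p_ema + (1 - rate) * p_online; rate=0 simply copies the online weights, which is
# how initialize_train_state() syncs nnet_ema with nnet via train_state.ema_update(0).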
81 |
82 | class TrainState(object):
83 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
84 | self.optimizer = optimizer
85 | self.lr_scheduler = lr_scheduler
86 | self.step = step
87 | self.nnet = nnet
88 | self.nnet_ema = nnet_ema
89 |
90 | def ema_update(self, rate=0.9999):
91 | if self.nnet_ema is not None:
92 | ema(self.nnet_ema, self.nnet, rate)
93 |
94 | def save(self, path):
95 | os.makedirs(path, exist_ok=True)
96 | torch.save(self.step, os.path.join(path, 'step.pth'))
97 | for key, val in self.__dict__.items():
98 | if key != 'step' and val is not None:
99 | torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
100 |
101 | def load(self, path):
102 | logging.info(f'load from {path}')
103 | self.step = torch.load(os.path.join(path, 'step.pth'))
104 | for key, val in self.__dict__.items():
105 | if key != 'step' and val is not None:
106 | val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
107 |
108 | def resume(self, ckpt_root, step=None):
109 | if not os.path.exists(ckpt_root):
110 | return
111 | if step is None:
112 | ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
113 | if not ckpts:
114 | return
115 | steps = map(lambda x: int(x.split(".")[0]), ckpts)
116 | step = max(steps)
117 | ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
118 | logging.info(f'resume from {ckpt_path}')
119 | self.load(ckpt_path)
120 |
121 | def to(self, device):
122 | for key, val in self.__dict__.items():
123 | if isinstance(val, nn.Module):
124 | val.to(device)
125 |
126 |
127 | def cnt_params(model):
128 | return sum(param.numel() for param in model.parameters())
129 |
130 |
131 | def initialize_train_state(config, device):
132 | params = []
133 |
134 | nnet = get_nnet(**config.nnet)
135 | params += nnet.parameters()
136 | nnet_ema = get_nnet(**config.nnet)
137 | nnet_ema.eval()
138 | logging.info(f'nnet has {cnt_params(nnet)} parameters')
139 |
140 | optimizer = get_optimizer(params, **config.optimizer)
141 | lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
142 |
143 | train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
144 | nnet=nnet, nnet_ema=nnet_ema)
145 | train_state.ema_update(0)
146 | train_state.to(device)
147 | return train_state
148 |
149 |
150 | def amortize(n_samples, batch_size):
151 | k = n_samples // batch_size
152 | r = n_samples % batch_size
153 | return k * [batch_size] if r == 0 else k * [batch_size] + [r]
154 |
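# Illustrative example (added by the editor, not in the original source):
#   amortize(10, 4) -> [4, 4, 2]   # three passes of sample_fn cover n_samples=10 with batch_size=4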
155 |
156 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None):
157 | os.makedirs(path, exist_ok=True)
158 | idx = 0
159 | batch_size = mini_batch_size * accelerator.num_processes
160 |
161 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
162 | samples = unpreprocess_fn(sample_fn(mini_batch_size))
163 | samples = accelerator.gather(samples.contiguous())[:_batch_size]
164 | if accelerator.is_main_process:
165 | for sample in samples:
166 | save_image(sample, os.path.join(path, f"{idx}.png"))
167 | idx += 1
168 |
169 |
170 | def grad_norm(model):
171 | total_norm = 0.
172 | for p in model.parameters():
173 | param_norm = p.grad.data.norm(2)
174 | total_norm += param_norm.item() ** 2
175 | total_norm = total_norm ** (1. / 2)
176 | return total_norm
177 |
--------------------------------------------------------------------------------
/uvit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baofff/U-ViT/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/uvit.png
--------------------------------------------------------------------------------