├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── .DS_Store └── figure.png ├── configs ├── fit │ └── config_fit_xl.yaml └── fitv2 │ ├── config_fitv2_3B.yaml │ ├── config_fitv2_hr_3B.yaml │ ├── config_fitv2_hr_xl.yaml │ └── config_fitv2_xl.yaml ├── fit ├── data │ ├── in1k_dataset.py │ └── in1k_latent_dataset.py ├── model │ ├── fit_model.py │ ├── modules.py │ ├── norms.py │ ├── rope.py │ └── utils.py ├── scheduler │ ├── improved_diffusion │ │ ├── __init__.py │ │ ├── diffusion_utils.py │ │ ├── gaussian_diffusion.py │ │ ├── respace.py │ │ └── timestep_sampler.py │ └── transport │ │ ├── __init__.py │ │ ├── integrators.py │ │ ├── path.py │ │ ├── transport.py │ │ └── utils.py └── utils │ ├── eval_utils.py │ ├── lr_scheduler.py │ ├── sit_eval_utils.py │ └── utils.py ├── requirements.txt ├── sample_fit_ddp.py ├── sample_fitv2_ddp.py ├── setup.py ├── tools ├── download_in1k_latents_1024.sh ├── download_in1k_latents_256.sh ├── train_fit_xl.sh ├── train_fitv2_3B.sh ├── train_fitv2_hr_3B.sh ├── train_fitv2_hr_xl.sh └── train_fitv2_xl.sh ├── train_fit.py └── train_fitv2.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whlzy/FiT/b22b7b0ac4bf841242711f5c20703d072708270e/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | *.ckpt 4 | *.pth 5 | *.egg-info 6 | *.mp4 7 | *.png 8 | *.jpg 9 | *.npz 10 | *.safetensors 11 | /outputs 12 | /checkpoints 13 | /tmp 14 | /logs 15 | /dataset 16 | /workdir 17 | /debug 18 | *__pycache__* 19 | batch* 20 | /.vscode 21 | /.ipynb_checkpoints 22 | *.ipynb_checkpoints 23 | /test -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Figure](assets/figure.png) 2 | 3 | # FiT: Flexible Vision Transformer for Diffusion Model 4 | 5 |

6 | 📃 [FiT Paper](https://arxiv.org/abs/2402.12376) • 7 | 📦 [FiT Checkpoint](https://huggingface.co/InfImagine/FiT)
• 8 | 📃 [FiTv2 Paper](https://arxiv.org/abs/2410.13925) • 9 | 📦 [FiTv2 Checkpoint](https://huggingface.co/InfImagine/FiTv2)
10 |

11 | 12 | This is the official repo, which contains PyTorch model definitions, pre-trained weights, and sampling code for our flexible vision transformer (FiT). 13 | FiT is a diffusion-transformer-based model that can generate images at unrestricted resolutions and aspect ratios. 14 | 15 | The core features include: 16 | * Pre-trained class-conditional FiT-XL-2-16 (2000K) model weights trained on ImageNet ($H\times W \le 256\times256$). 17 | * Pre-trained class-conditional FiTv2-XL-2-16 (2000K) and FiTv2-3B-2-16 (1000K) model weights trained on ImageNet ($H\times W \le 256\times256$). 18 | * High-resolution fine-tuned FiTv2-XL-2-32 (400K) and FiTv2-3B-2-32 (200K) model weights trained on ImageNet ($H\times W \le 512\times512$). 19 | * PyTorch sampling code for running the pre-trained FiT and FiTv2 models to generate images at unrestricted resolutions and aspect ratios. 20 | 21 | Why do we need FiT? 22 | * 🧐 Nature is infinitely resolution-free. FiT, like Sora, is trained on images with unrestricted resolutions and aspect ratios, and it is capable of generating images at unrestricted resolutions and aspect ratios. 23 | * 🤗 FiT exhibits remarkable flexibility in resolution extrapolation. 24 | 25 | Stay tuned for this project! 😆 26 | 27 | 28 | ## Setup 29 | First, download and set up the repo: 30 | ``` 31 | git clone https://github.com/whlzy/FiT.git 32 | cd FiT 33 | ``` 34 | ## Installation 35 | ``` 36 | conda create -n fit_env python=3.10 37 | pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118 38 | pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118 39 | pip install -r requirements.txt 40 | pip install -e . 41 | ``` 42 | 43 | ## Sample 44 | ### Basic Sampling 45 | 46 | The base models are trained with images whose $H\times W\leqslant 256\times256$. 47 | Our FiTv1-XL/2 and FiTv2-XL/2 models are trained with a batch size of 256 for 2000K steps. 48 | Our FiTv2-3B/2 model is trained with a batch size of 256 for 1000K steps.
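FiT operates on packed token sequences rather than fixed-size pixel grids: with patch size $p=2$ and VAE downsampling scale $s=8$ (see the dataset section below), an $H\times W$ image becomes $\frac{H\times W}{p^2 s^2}$ latent tokens, and the models above are pre-trained with a maximum sequence length of 256 tokens. The following throwaway snippet (plain Python, nothing from this repo) shows how the resolutions used in this README map to token counts:
```
# How many latent tokens an H x W image produces with patch size p=2 and
# VAE downsampling scale s=8 (both fixed in the configs of this repo).
def num_tokens(h, w, p=2, s=8):
    return (h * w) // (p * s) ** 2

print(num_tokens(256, 256))  # 256  -> exactly fills the 256-token pre-training context
print(num_tokens(160, 320))  # 200  -> shorter sequence, the remainder is masked padding
print(num_tokens(320, 320))  # 400  -> longer than 256, i.e. resolution extrapolation
print(num_tokens(512, 512))  # 1024 -> context length used for high-resolution fine-tuning
```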
49 | 50 | The pre-trained FiT models can be downloaded directly from huggingface: 51 | | FiT Model | Checkpoint | FID-256x256 | FID-320x320 | Model Size | GFLOPs | 52 | |---------------|------------|---------|-----------------|------------| ------ | 53 | | [FiTv1-XL/2](https://huggingface.co/InfImagine/FiT/tree/main/FiTv1_xl) | [CKPT](https://huggingface.co/InfImagine/FiT/blob/main/FiTv1_xl/model_ema.bin) | 4.21 | 5.11 | 824M | 153 | 54 | | [FiTv2-XL/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_XL) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL/model_ema.safetensors?download=true) | 2.26 | 3.55 | 671M | 147 | 55 | | [FiTv2-3B/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_3B) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B/model_ema.bin?download=true) | 2.15 | 3.22 | 3B | 653 | 56 | 57 | 58 | #### Downloading the checkpoints 59 | Downloading via wget: 60 | ``` 61 | mkdir checkpoints 62 | 63 | wget -c "https://huggingface.co/InfImagine/FiT/resolve/main/FiTv1_xl/model_ema.bin" -O checkpoints/fitv1_xl.bin 64 | 65 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL/model_ema.safetensors?download=true" -O checkpoints/fitv2_xl.safetensors 66 | 67 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B/model_ema.bin?download=true" -O checkpoints/fitv2_3B.bin 68 | ``` 69 | 70 | #### Sampling 256x256 Images 71 | Sampling with FiTv1-XL/2 for $256\times 256$ Images: 72 | ``` 73 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fit_ddp.py --num-fid-samples 50000 --cfgdir configs/fit/config_fit_xl.yaml --ckpt checkpoints/fitv1_xl.bin --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 74 | ``` 75 | Sampling with FiTv2-XL/2 for $256\times 256$ Images: 76 | ``` 77 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE 78 | ``` 79 | Sampling with FiTv2-3B/2 for $256\times 256$ Images: 80 | ``` 81 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_3B.yaml --ckpt checkpoints/fitv2_3B.bin --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE 82 | ``` 83 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified. 84 | 85 | 86 | #### Sampling Images with arbitrary resolutions 87 | We can set *image-height* and *image-width* to any values we want, but we also need to specify the original maximum positional embedding length (*ori-max-pe-len*) and the interpolation method. 88 | We show some examples as follows.
89 | 90 | Sampling with FiTv2-XL/2 for $160\times 320$ images: 91 | ``` 92 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 160 --image-width 320 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple 93 | ``` 94 | Sampling with FiTv2-XL/2 for $320\times 320$ images: 95 | ``` 96 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 320 --image-width 320 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation ntkpro2 --decouple 97 | ``` 98 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified. 99 | 100 | 101 | ### High-resolution Sampling 102 | 103 | For high-resolution image generation, we use images whose $H\times W \leqslant 512\times512$. 104 | Our FiTv2-XL/2 is fine-tuned with a batch size of 256 for 400K steps, while 105 | FiTv2-3B/2 is fine-tuned with a batch size of 256 for 200K steps. 106 | 107 | The high-resolution fine-tuned FiT models can be downloaded directly from huggingface: 108 | | FiT Model | Checkpoint | FID-512x512 | FID-320x640 | Model Size | GFLOPs | 109 | |---------------|------------|---------|-----------------|------------| ------ | 110 | | [FiTv2-HR-XL/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_XL_HRFT) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL_HRFT/model_ema.safetensors?download=true) | 2.90 | 4.87 | 671M | 147 | 111 | | [FiTv2-HR-3B/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_3B_HRFT) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B_HRFT/model_ema.safetensors?download=true) | 2.41 | 4.54 | 3B | 653 | 112 | 113 | 114 | #### Downloading 115 | Downloading via wget: 116 | ``` 117 | mkdir checkpoints 118 | 119 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL_HRFT/model_ema.safetensors?download=true" -O checkpoints/fitv2_hr_xl.safetensors 120 | 121 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B_HRFT/model_ema.safetensors?download=true" -O checkpoints/fitv2_hr_3B.safetensors 122 | ``` 123 | 124 | #### Sampling 512x512 Images 125 | Sampling with FiTv2-HR-XL/2 for $512\times 512$ Images: 126 | ``` 127 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_xl.yaml --ckpt checkpoints/fitv2_hr_xl.safetensors --image-height 512 --image-width 512 --num-sampling-steps 250 --cfg-scale 1.65 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple 128 | ``` 129 | Sampling with FiTv2-HR-3B/2 for $512\times 512$ Images: 130 | ``` 131 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_3B.yaml --ckpt checkpoints/fitv2_hr_3B.safetensors --image-height 512 --image-width 512 --num-sampling-steps
250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple 132 | ``` 133 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified. 134 | 135 | #### Sampling Images with arbitrary resolutions 136 | Sampling with FiTv2-HR-XL/2 for $320\times 640$ images: 137 | ``` 138 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_xl.yaml --ckpt checkpoints/fitv2_hr_xl.safetensors --image-height 320 --image-width 640 --num-sampling-steps 250 --cfg-scale 1.65 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple 139 | ``` 140 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified. 141 | 142 | ## Evaluations 143 | The sampling generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow 144 | evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and 145 | other metrics. 146 | 147 | 148 | 149 | ## Flexible ImageNet Latent Datasets 150 | 151 | We use [SD-VAE-FT-EMA](https://huggingface.co/stabilityai/sd-vae-ft-ema) to encode an image into the latent codes. 152 | 153 | Thanks to our flexible training pipeline, we can handle images with arbitrary resolutions and aspect ratios. 154 | So we preprocess the ImageNet1k dataset according to the original height and width of an image. 155 | Conventionally, we set patch size $p$ to $2$ and the downsampling scale $s$ of the VAE encoder is $8$, 156 | so an image with $H=256$ and $W=256$ will lead to $\frac{H\times W}{p^2\times s^2} = 256$ tokens. 157 | 158 | For our pre-training, we set the maximum token length to $256$, which corresponds to image resolution size $S=H=W=256$. 159 | For the high-resolution fine-tuning, the token length is $1024$, which corresponds to image resolution size $S=H=W=512$. 160 | Given an input image $I\in \mathbb{R}^{3\times H \times W}$ and a target resolution size $S=256/512$, the preprocessing is: 161 | ``` 162 | If H > S and W > S: 163 | img_resize = Resize(I) 164 | latent_resize = VAE_Encode(img_resize) 165 | save(latent_resize) 166 | img_crop = CenterCrop(Resize(I)) 167 | latent_crop = VAE_Encode(img_crop) 168 | save(latent_crop) 169 | else: 170 | img_resize = Resize(I) 171 | latent_resize = VAE_Encode(img_resize) 172 | save(latent_resize) 173 | ``` 174 | 175 | ### Dataset for Pretraining 176 | 177 | All the image latent codes with maximum token length $256$ can be downloaded from [here](https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/tree/main). 178 | 179 | ``` 180 | bash tools/download_in1k_latents_256.sh 181 | ``` 182 | 183 | 184 | 185 | ### Dataset for High-resolution Fine-tuning 186 | All the image latent codes with maximum token length $1024$ can be downloaded from [here](https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema).
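As an alternative to the download script below, the latent datasets can also be fetched with the Hugging Face hub client. This is only a sketch: the dataset repo ids are taken from the links above, the target directories match the `data_path` entries in the training configs, and depending on how the repositories are laid out you may still need to arrange the shards into the structure shown under *Dataset Architecture* below.
```
# Sketch: fetch the latent datasets with huggingface_hub instead of the shell scripts.
# Repo ids come from the dataset links above; local_dir matches the configs' data_path.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="InfImagine/imagenet_features_1024_sd_vae_ft_ema",   # high-resolution (1024-token) latents
    repo_type="dataset",
    local_dir="datasets/imagenet1k_latents_1024_sd_vae_ft_ema",
)
snapshot_download(
    repo_id="InfImagine/imagenet1k_features_256_sd_vae_ft_ema",  # pre-training (256-token) latents
    repo_type="dataset",
    local_dir="datasets/imagenet1k_latents_256_sd_vae_ft_ema",
)
```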
187 | 188 | 189 | ``` 190 | bash tools/download_in1k_latents_1024.sh 191 | ``` 192 | 193 | ### Dataset Architecture 194 | 195 | - imagenet1k_latents_256_sd_vae_ft_ema 196 | - less_than_16 197 | - xxxxxxx.safetensors 198 | - xxxxxxx.safetensors 199 | - from_16_to_256 200 | - xxxxxxx.safetensors 201 | - xxxxxxx.safetensors 202 | - greater_than_256_crop 203 | - xxxxxxx.safetensors 204 | - xxxxxxx.safetensors 205 | - greater_than_256_resize 206 | - xxxxxxx.safetensors 207 | - xxxxxxx.safetensors 208 | - imagenet1k_latents_1024_sd_vae_ft_ema 209 | - less_than_16 210 | - xxxxxxx.safetensors 211 | - xxxxxxx.safetensors 212 | - from_16_to_1024 213 | - xxxxxxx.safetensors 214 | - xxxxxxx.safetensors 215 | - greater_than_1024_crop 216 | - xxxxxxx.safetensors 217 | - xxxxxxx.safetensors 218 | - greater_than_1024_resize 219 | - xxxxxxx.safetensors 220 | - xxxxxxx.safetensors 221 | 222 | ## Train 223 | You need to determine the number of node and GPU for your training. 224 | 225 | Train FiT and FiTv2 models: 226 | ``` 227 | bash tools/train_fit_xl.sh 228 | 229 | bash tools/train_fitv2_xl.sh 230 | 231 | bash tools/train_fitv2_3B.sh 232 | ``` 233 | 234 | High-resolution Fine-tuning: 235 | ``` 236 | bash tools/train_fitv2_hr_xl.sh 237 | 238 | bash tools/train_fitv2_hr_3B.sh 239 | ``` 240 | 241 | 242 | ## BibTeX 243 | ```bibtex 244 | @article{Lu2024FiT, 245 | title={FiT: Flexible Vision Transformer for Diffusion Model}, 246 | author={Zeyu Lu and Zidong Wang and Di Huang and Chengyue Wu and Xihui Liu and Wanli Ouyang and Lei Bai}, 247 | year={2024}, 248 | journal={arXiv preprint arXiv:2402.12376}, 249 | } 250 | ``` 251 | ```bibtex 252 | @article{wang2024fitv2, 253 | title={Fitv2: Scalable and improved flexible vision transformer for diffusion model}, 254 | author={Wang, ZiDong and Lu, Zeyu and Huang, Di and Zhou, Cai and Ouyang, Wanli and others}, 255 | journal={arXiv preprint arXiv:2410.13925}, 256 | year={2024} 257 | } 258 | ``` -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whlzy/FiT/b22b7b0ac4bf841242711f5c20703d072708270e/assets/.DS_Store -------------------------------------------------------------------------------- /assets/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whlzy/FiT/b22b7b0ac4bf841242711f5c20703d072708270e/assets/figure.png -------------------------------------------------------------------------------- /configs/fit/config_fit_xl.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema 3 | improved_diffusion: 4 | timestep_respacing: '' 5 | noise_schedule: linear 6 | use_kl: false 7 | sigma_small: false 8 | predict_xstart: false 9 | learn_sigma: true 10 | rescale_learned_sigmas: false 11 | diffusion_steps: 1000 12 | noise_scheduler: 13 | num_train_timesteps: 1000 14 | beta_start: 0.0001 15 | beta_end: 0.02 16 | beta_schedule: linear 17 | prediction_type: epsilon 18 | steps_offset: 0 19 | clip_sample: false 20 | network_config: 21 | target: fit.model.fit_model.FiT 22 | params: 23 | context_size: 256 24 | patch_size: 2 25 | in_channels: 4 26 | hidden_size: 1152 27 | depth: 28 28 | num_heads: 16 29 | mlp_ratio: 4.0 30 | class_dropout_prob: 0.1 31 | num_classes: 1000 32 | learn_sigma: true 33 | use_swiglu: true 34 | use_swiglu_large: true 
35 | rel_pos_embed: rope 36 | 37 | 38 | data: 39 | target: fit.data.in1k_latent_dataset.INLatentLoader 40 | params: 41 | train: 42 | data_path: datasets/imagenet1k_latents_256_sd_vae_ft_ema 43 | target_len: 256 44 | random: 'resize' 45 | loader: 46 | batch_size: 32 47 | num_workers: 2 48 | shuffle: True 49 | 50 | accelerate: 51 | # others 52 | gradient_accumulation_steps: 1 53 | mixed_precision: 'bf16' 54 | # training step config 55 | num_train_epochs: 56 | max_train_steps: 2000000 57 | # optimizer config 58 | learning_rate: 1.0e-4 59 | learning_rate_base_batch_size: 256 60 | max_grad_norm: 1.0 61 | optimizer: 62 | target: torch.optim.AdamW 63 | params: 64 | betas: ${tuple:0.9, 0.999} 65 | weight_decay: 0 #1.0e-2 66 | eps: 1.0e-8 67 | lr_scheduler: constant 68 | lr_warmup_steps: 500 69 | # checkpoint config 70 | logger: wandb 71 | checkpointing_epochs: False 72 | checkpointing_steps: 100000 73 | checkpointing_steps_list: [400000, 1000000, 2000000] 74 | checkpoints_total_limit: 2 75 | logging_steps: 10000 76 | -------------------------------------------------------------------------------- /configs/fitv2/config_fitv2_3B.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema 3 | transport: 4 | path_type: Linear 5 | prediction: velocity 6 | loss_weight: null 7 | sample_eps: null 8 | train_eps: null 9 | snr_type: lognorm 10 | sampler: 11 | mode: ODE 12 | sde: 13 | sampling_method: Euler 14 | diffusion_form: sigma 15 | diffusion_norm: 1.0 16 | last_step: Mean 17 | last_step_size: 0.04 18 | ode: 19 | sampling_method: dopri5 20 | atol: 1.0e-06 21 | rtol: 0.001 22 | reverse: false 23 | likelihood: false 24 | network_config: 25 | target: fit.model.fit_model.FiT 26 | params: 27 | context_size: 256 28 | patch_size: 2 29 | in_channels: 4 30 | hidden_size: 2304 31 | depth: 40 32 | num_heads: 24 33 | mlp_ratio: 4.0 34 | class_dropout_prob: 0.1 35 | num_classes: 1000 36 | learn_sigma: false 37 | use_sit: true 38 | use_swiglu: true 39 | use_swiglu_large: false 40 | q_norm: layernorm 41 | k_norm: layernorm 42 | qk_norm_weight: false 43 | rel_pos_embed: rope 44 | abs_pos_embed: null 45 | adaln_type: lora 46 | adaln_lora_dim: 576 47 | 48 | 49 | data: 50 | target: fit.data.in1k_latent_dataset.INLatentLoader 51 | params: 52 | train: 53 | data_path: datasets/imagenet1k_latents_256_sd_vae_ft_ema 54 | target_len: 256 55 | random: 'random' 56 | loader: 57 | batch_size: 16 58 | num_workers: 2 59 | shuffle: True 60 | 61 | 62 | accelerate: 63 | # others 64 | gradient_accumulation_steps: 1 65 | mixed_precision: 'bf16' 66 | # training step config 67 | num_train_epochs: 68 | max_train_steps: 2000000 69 | # optimizer config 70 | learning_rate: 1.0e-4 71 | learning_rate_base_batch_size: 256 72 | max_grad_norm: 1.0 73 | optimizer: 74 | target: torch.optim.AdamW 75 | params: 76 | betas: ${tuple:0.9, 0.999} 77 | weight_decay: 0 #1.0e-2 78 | eps: 1.0e-8 79 | lr_scheduler: constant_with_warmup 80 | lr_warmup_steps: 50000 81 | # checkpoint config 82 | logger: wandb 83 | checkpointing_epochs: False 84 | checkpointing_steps: 4000 85 | checkpointing_steps_list: [200000, 400000, 1000000, 1400000, 1500000, 1800000] 86 | checkpoints_total_limit: 2 87 | logging_steps: 1000 88 | -------------------------------------------------------------------------------- /configs/fitv2/config_fitv2_hr_3B.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | 
pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema 3 | transport: 4 | path_type: Linear 5 | prediction: velocity 6 | loss_weight: null 7 | sample_eps: null 8 | train_eps: null 9 | snr_type: lognorm 10 | sampler: 11 | mode: ODE 12 | sde: 13 | sampling_method: Euler 14 | diffusion_form: sigma 15 | diffusion_norm: 1.0 16 | last_step: Mean 17 | last_step_size: 0.04 18 | ode: 19 | sampling_method: dopri5 20 | atol: 1.0e-06 21 | rtol: 0.001 22 | reverse: false 23 | likelihood: false 24 | network_config: 25 | target: fit.model.fit_model.FiT 26 | params: 27 | context_size: 1024 28 | patch_size: 2 29 | in_channels: 4 30 | hidden_size: 2304 31 | depth: 40 32 | num_heads: 24 33 | mlp_ratio: 4.0 34 | class_dropout_prob: 0.1 35 | num_classes: 1000 36 | learn_sigma: false 37 | use_sit: true 38 | use_checkpoint: true 39 | use_swiglu: true 40 | use_swiglu_large: false 41 | q_norm: layernorm 42 | k_norm: layernorm 43 | qk_norm_weight: false 44 | rel_pos_embed: rope 45 | custom_freqs: ntk-aware 46 | decouple: true 47 | ori_max_pe_len: 16 48 | online_rope: true 49 | abs_pos_embed: null 50 | adaln_type: lora 51 | adaln_lora_dim: 576 52 | pretrain_ckpt: checkpoints/fitv2_3B.bin 53 | ignore_keys: ['x_embedder', 'bias', 'LN', 'final_layer'] 54 | finetune: partial 55 | 56 | data: 57 | target: fit.data.in1k_latent_dataset.INLatentLoader 58 | params: 59 | train: 60 | data_path: datasets/imagenet1k_latents_1024_sd_vae_ft_ema 61 | target_len: 1024 62 | random: 'random' 63 | loader: 64 | batch_size: 16 65 | num_workers: 2 66 | shuffle: True 67 | 68 | 69 | accelerate: 70 | # others 71 | gradient_accumulation_steps: 1 72 | mixed_precision: 'bf16' 73 | # training step config 74 | num_train_epochs: 75 | max_train_steps: 200000 76 | # optimizer config 77 | learning_rate: 1.0e-4 78 | learning_rate_base_batch_size: 256 79 | max_grad_norm: 1.0 80 | optimizer: 81 | target: torch.optim.AdamW 82 | params: 83 | betas: ${tuple:0.9, 0.999} 84 | weight_decay: 0 #1.0e-2 85 | eps: 1.0e-8 86 | lr_scheduler: constant 87 | lr_warmup_steps: 0 88 | # checkpoint config 89 | logger: wandb 90 | checkpointing_epochs: False 91 | checkpointing_steps: 4000 92 | checkpointing_steps_list: [40000, 80000, 100000] 93 | checkpoints_total_limit: 2 94 | logging_steps: 1000 95 | -------------------------------------------------------------------------------- /configs/fitv2/config_fitv2_hr_xl.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema 3 | transport: 4 | path_type: Linear 5 | prediction: velocity 6 | loss_weight: null 7 | sample_eps: null 8 | train_eps: null 9 | snr_type: lognorm 10 | sampler: 11 | mode: ODE 12 | sde: 13 | sampling_method: Euler 14 | diffusion_form: sigma 15 | diffusion_norm: 1.0 16 | last_step: Mean 17 | last_step_size: 0.04 18 | ode: 19 | sampling_method: dopri5 20 | atol: 1.0e-06 21 | rtol: 0.001 22 | reverse: false 23 | likelihood: false 24 | network_config: 25 | target: fit.model.fit_model.FiT 26 | params: 27 | context_size: 1024 28 | patch_size: 2 29 | in_channels: 4 30 | hidden_size: 1152 31 | depth: 36 32 | num_heads: 16 33 | mlp_ratio: 4.0 34 | class_dropout_prob: 0.1 35 | num_classes: 1000 36 | learn_sigma: false 37 | use_sit: true 38 | use_swiglu: true 39 | use_swiglu_large: false 40 | q_norm: layernorm 41 | k_norm: layernorm 42 | qk_norm_weight: false 43 | rel_pos_embed: rope 44 | custom_freqs: ntk-aware 45 | decouple: true 46 | ori_max_pe_len: 16 47 | online_rope: true 48 | abs_pos_embed: null 49 
| adaln_type: lora 50 | adaln_lora_dim: 288 51 | pretrain_ckpt: checkpoints/fitv2_xl.safetensors 52 | ignore_keys: ['x_embedder', 'bias', 'LN', 'final_layer'] 53 | finetune: partial 54 | 55 | data: 56 | target: fit.data.in1k_latent_dataset.INLatentLoader 57 | params: 58 | train: 59 | data_path: datasets/imagenet1k_latents_1024_sd_vae_ft_ema 60 | target_len: 1024 61 | random: 'random' 62 | loader: 63 | batch_size: 16 64 | num_workers: 2 65 | shuffle: True 66 | 67 | 68 | accelerate: 69 | # others 70 | gradient_accumulation_steps: 1 71 | mixed_precision: 'bf16' 72 | # training step config 73 | num_train_epochs: 74 | max_train_steps: 400000 75 | # optimizer config 76 | learning_rate: 1.0e-4 77 | learning_rate_base_batch_size: 256 78 | max_grad_norm: 1.0 79 | optimizer: 80 | target: torch.optim.AdamW 81 | params: 82 | betas: ${tuple:0.9, 0.999} 83 | weight_decay: 0 #1.0e-2 84 | eps: 1.0e-8 85 | lr_scheduler: constant 86 | lr_warmup_steps: 0 87 | # checkpoint config 88 | logger: wandb 89 | checkpointing_epochs: False 90 | checkpointing_steps: 4000 91 | checkpointing_steps_list: [40000, 80000, 100000, 200000] 92 | checkpoints_total_limit: 2 93 | logging_steps: 1000 94 | -------------------------------------------------------------------------------- /configs/fitv2/config_fitv2_xl.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema 3 | transport: 4 | path_type: Linear 5 | prediction: velocity 6 | loss_weight: null 7 | sample_eps: null 8 | train_eps: null 9 | snr_type: lognorm 10 | sampler: 11 | mode: ODE 12 | sde: 13 | sampling_method: Euler 14 | diffusion_form: sigma 15 | diffusion_norm: 1.0 16 | last_step: Mean 17 | last_step_size: 0.04 18 | ode: 19 | sampling_method: dopri5 20 | atol: 1.0e-06 21 | rtol: 0.001 22 | reverse: false 23 | likelihood: false 24 | network_config: 25 | target: fit.model.fit_model.FiT 26 | params: 27 | context_size: 256 28 | patch_size: 2 29 | in_channels: 4 30 | hidden_size: 1152 31 | depth: 36 32 | num_heads: 16 33 | mlp_ratio: 4.0 34 | class_dropout_prob: 0.1 35 | num_classes: 1000 36 | learn_sigma: false 37 | use_sit: true 38 | use_swiglu: true 39 | use_swiglu_large: false 40 | q_norm: layernorm 41 | k_norm: layernorm 42 | qk_norm_weight: false 43 | rel_pos_embed: rope 44 | abs_pos_embed: null 45 | adaln_type: lora 46 | adaln_lora_dim: 288 47 | 48 | data: 49 | target: fit.data.in1k_latent_dataset.INLatentLoader 50 | params: 51 | train: 52 | data_path: datasets/imagenet1k_latents_256_sd_vae_ft_ema 53 | target_len: 256 54 | random: 'random' 55 | loader: 56 | batch_size: 16 57 | num_workers: 2 58 | shuffle: True 59 | 60 | 61 | accelerate: 62 | # others 63 | gradient_accumulation_steps: 1 64 | mixed_precision: 'bf16' 65 | # training step config 66 | num_train_epochs: 67 | max_train_steps: 2000000 68 | # optimizer config 69 | learning_rate: 1.0e-4 70 | learning_rate_base_batch_size: 256 71 | max_grad_norm: 1.0 72 | optimizer: 73 | target: torch.optim.AdamW 74 | params: 75 | betas: ${tuple:0.9, 0.999} 76 | weight_decay: 0 #1.0e-2 77 | eps: 1.0e-8 78 | lr_scheduler: constant_with_warmup 79 | lr_warmup_steps: 50000 80 | # checkpoint config 81 | logger: wandb 82 | checkpointing_epochs: False 83 | checkpointing_steps: 4000 84 | checkpointing_steps_list: [200000, 400000, 1000000, 1400000, 1500000, 1800000] 85 | checkpoints_total_limit: 2 86 | logging_steps: 1000 87 | -------------------------------------------------------------------------------- 
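All of the YAML configs above describe the network, the data loader, and the optimizer as `target`/`params` pairs that are resolved with OmegaConf. The helper below is a minimal illustrative sketch of that pattern (the function name `instantiate_from_config` and the standalone usage are assumptions for illustration; the training and sampling scripts use the repo's own loading utilities):
```
# Illustrative sketch: build an object from a `target`/`params` block of the configs above.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(cfg):
    # cfg.target is a dotted import path such as "fit.model.fit_model.FiT";
    # cfg.params holds the keyword arguments passed to its constructor.
    module_path, cls_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    params = OmegaConf.to_container(cfg["params"], resolve=True) if "params" in cfg else {}
    return cls(**params)

if __name__ == "__main__":
    conf = OmegaConf.load("configs/fitv2/config_fitv2_xl.yaml")
    model = instantiate_from_config(conf.diffusion.network_config)
    print(f"FiTv2-XL: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```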
/fit/data/in1k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import torchvision 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchvision.datasets import ImageFolder 9 | from torchvision import transforms 10 | from accelerate.logging import get_logger 11 | logger = get_logger(__name__, log_level="INFO") 12 | 13 | 14 | def center_crop_arr(pil_image, image_size): 15 | """ 16 | Center cropping implementation from ADM. 17 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 18 | """ 19 | while min(*pil_image.size) >= 2 * image_size: 20 | pil_image = pil_image.resize( 21 | tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX 22 | ) 23 | 24 | scale = image_size / min(*pil_image.size) 25 | pil_image = pil_image.resize( 26 | tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC 27 | ) 28 | 29 | arr = np.array(pil_image) 30 | crop_y = (arr.shape[0] - image_size) // 2 31 | crop_x = (arr.shape[1] - image_size) // 2 32 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 33 | 34 | def resize_arr(pil_image, image_size, vq_vae_down=8, patch_size=2): 35 | w, h = pil_image.size 36 | min_len = int(vq_vae_down*patch_size) # 8*2=16 -> 256/16=16, 512/16=32 37 | if w * h >= image_size ** 2: 38 | new_w = np.sqrt(w/h) * image_size 39 | new_h = new_w * h / w 40 | elif w < min_len: # upsample, this case only happens twice in ImageNet1k_256 41 | new_w = min_len 42 | new_h = min_len * h / w 43 | elif h < min_len: # upsample, this case only happens once in ImageNet1k_256 44 | new_h = min_len 45 | new_w = min_len * w / h 46 | else: 47 | new_w, new_h = w, h 48 | new_w, new_h = int(new_w/min_len)*min_len, int(new_h/min_len)*min_len 49 | 50 | if new_w == w and new_h == h: 51 | return pil_image 52 | else: 53 | return pil_image.resize((new_w, new_h), resample=Image.Resampling.BICUBIC) 54 | 55 | class ImagenetDataDictWrapper(Dataset): 56 | def __init__(self, dataset): 57 | super().__init__() 58 | self.dataset = dataset 59 | 60 | def __getitem__(self, i): 61 | x, y = self.dataset[i] 62 | return {"jpg": x, "cls": y} 63 | 64 | def __len__(self): 65 | return len(self.dataset) 66 | 67 | class ImagenetLoader(): 68 | def __init__(self, train, rescale='crop'): 69 | super().__init__() 70 | 71 | self.train_config = train 72 | 73 | self.batch_size = self.train_config.loader.batch_size 74 | self.num_workers = self.train_config.loader.num_workers 75 | self.shuffle = self.train_config.loader.shuffle 76 | 77 | if rescale == 'crop': 78 | transform = transforms.Compose([ 79 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, self.train_config.resize)), 80 | transforms.RandomHorizontalFlip(), 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 83 | ]) 84 | elif rescale == 'resize': 85 | transform = transforms.Compose([ 86 | transforms.Lambda(lambda pil_image: resize_arr(pil_image, self.train_config.resize)), 87 | transforms.RandomHorizontalFlip(), 88 | transforms.ToTensor(), 89 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 90 | ]) 91 | elif rescale == 'keep': 92 | transform = transforms.Compose([ 93 | transforms.Lambda(lambda pil_image: resize_arr(pil_image, self.train_config.resize)), 94 | transforms.RandomHorizontalFlip(), 95 | 
transforms.ToTensor(), 96 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 97 | ]) 98 | else: 99 | raise NotImplementedError 100 | 101 | self.train_dataset = ImagenetDataDictWrapper(ImageFolder(self.train_config.data_path, transform=transform)) 102 | 103 | self.test_dataset = None 104 | self.val_dataset = None 105 | 106 | def train_len(self): 107 | return len(self.train_dataset) 108 | 109 | def train_dataloader(self): 110 | return DataLoader( 111 | self.train_dataset, 112 | batch_size=self.batch_size, 113 | shuffle=self.shuffle, 114 | num_workers=self.num_workers, 115 | pin_memory=True, 116 | drop_last=True 117 | ) 118 | 119 | def test_dataloader(self): 120 | return None 121 | 122 | def val_dataloader(self): 123 | return DataLoader( 124 | self.train_dataset, 125 | batch_size=self.batch_size, 126 | shuffle=self.shuffle, 127 | num_workers=self.num_workers, 128 | pin_memory=True, 129 | drop_last=True 130 | ) 131 | 132 | if __name__ == "__main__": 133 | from omegaconf import OmegaConf 134 | conf = OmegaConf.load('/home/luzeyu/projects/workspace/generative-models/configs/example_training/dataset/imagenet-256-streaming.yaml') 135 | indataloader=ImagenetLoader(train=conf.data.params.train).train_dataloader() 136 | from tqdm import tqdm 137 | for i in tqdm(indataloader): 138 | # print(i) 139 | # print("*"*20) 140 | pass -------------------------------------------------------------------------------- /fit/data/in1k_latent_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | import random 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader, Dataset 8 | from typing import Optional 9 | from safetensors.torch import load_file, save_file 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | from einops import rearrange 12 | 13 | 14 | 15 | class IN1kLatentDataset(Dataset): 16 | def __init__(self, root_dir, target_len=256, random='random'): 17 | super().__init__() 18 | self.RandomHorizontalFlipProb = 0.5 19 | self.root_dir = root_dir 20 | self.target_len = target_len 21 | self.random = random 22 | self.files = [] 23 | files_1 = os.listdir(osp.join(root_dir, f'from_16_to_{target_len}')) 24 | files_2 = os.listdir(osp.join(root_dir, f'greater_than_{target_len}_resize')) 25 | files_3 = os.listdir(osp.join(root_dir, f'greater_than_{target_len}_crop')) 26 | files_23 = list(set(files_2) - set(files_3)) # files_3 in files_2 27 | self.files.extend([ 28 | [osp.join(root_dir, f'from_16_to_{target_len}', file)] for file in files_1 29 | ]) 30 | self.files.extend([ 31 | [osp.join(root_dir, f'greater_than_{target_len}_resize', file)] for file in files_23 32 | ]) 33 | self.files.extend([ 34 | [ 35 | osp.join(root_dir, f'greater_than_{target_len}_resize', file), 36 | osp.join(root_dir, f'greater_than_{target_len}_crop', file) 37 | ] for file in files_3 38 | ]) 39 | 40 | 41 | def __len__(self): 42 | return len(self.files) 43 | 44 | def __getitem__(self, idx): 45 | if self.random == 'random': 46 | path = random.choice(self.files[idx]) 47 | elif self.random == 'resize': 48 | path = self.files[idx][0] # only resize 49 | elif self.random == 'crop': 50 | path = self.files[idx][-1] # only crop 51 | data = load_file(path) 52 | dtype = data['feature'].dtype 53 | 54 | feature = torch.zeros((self.target_len, 16), dtype=dtype) 55 | grid = torch.zeros((2, self.target_len), dtype=dtype) 56 | mask = torch.zeros((self.target_len), 
dtype=torch.uint8) 57 | size = torch.zeros(2, dtype=torch.int32) 58 | 59 | 60 | seq_len = data['grid'].shape[-1] 61 | if torch.rand(1) < self.RandomHorizontalFlipProb: 62 | feature[0: seq_len] = rearrange(data['feature'][0], 'h w c -> (h w) c') 63 | else: 64 | feature[0: seq_len] = rearrange(data['feature'][1], 'h w c -> (h w) c') 65 | grid[:, 0: seq_len] = data['grid'] 66 | mask[0: seq_len] = 1 67 | size = data['size'][None, :] 68 | label = data['label'] 69 | return dict(feature=feature, grid=grid, mask=mask, label=label, size=size) 70 | 71 | 72 | 73 | # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60 74 | 75 | # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60 76 | def get_train_sampler(dataset, global_batch_size, max_steps, resume_steps, seed): 77 | sample_indices = torch.empty([max_steps * global_batch_size], dtype=torch.long) 78 | epoch_id, fill_ptr, offs = 0, 0, 0 79 | while fill_ptr < sample_indices.size(0): 80 | g = torch.Generator() 81 | g.manual_seed(seed + epoch_id) 82 | epoch_sample_indices = torch.randperm(len(dataset), generator=g) 83 | epoch_id += 1 84 | epoch_sample_indices = epoch_sample_indices[ 85 | :sample_indices.size(0) - fill_ptr 86 | ] 87 | sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \ 88 | epoch_sample_indices 89 | fill_ptr += epoch_sample_indices.size(0) 90 | return sample_indices[resume_steps * global_batch_size : ].tolist() 91 | 92 | 93 | 94 | class INLatentLoader(): 95 | def __init__(self, train): 96 | super().__init__() 97 | 98 | self.train_config = train 99 | 100 | self.batch_size = self.train_config.loader.batch_size 101 | self.num_workers = self.train_config.loader.num_workers 102 | self.shuffle = self.train_config.loader.shuffle 103 | 104 | self.train_dataset = IN1kLatentDataset( 105 | self.train_config.data_path, self.train_config.target_len, self.train_config.random 106 | ) 107 | 108 | 109 | self.test_dataset = None 110 | self.val_dataset = None 111 | 112 | def train_len(self): 113 | return len(self.train_dataset) 114 | 115 | def train_dataloader(self, global_batch_size, max_steps, resume_step, seed=42): 116 | sampler = get_train_sampler( 117 | self.train_dataset, global_batch_size, max_steps, resume_steps, seed 118 | ) 119 | return DataLoader( 120 | self.train_dataset, 121 | batch_size=self.batch_size, 122 | sampler=sampler, 123 | num_workers=self.num_workers, 124 | pin_memory=True, 125 | drop_last=True, 126 | ) 127 | 128 | def test_dataloader(self): 129 | return None 130 | 131 | def val_dataloader(self): 132 | return DataLoader( 133 | self.train_dataset, 134 | batch_size=self.batch_size, 135 | shuffle=self.shuffle, 136 | num_workers=self.num_workers, 137 | pin_memory=True, 138 | drop_last=True 139 | ) 140 | -------------------------------------------------------------------------------- /fit/model/fit_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | from typing import Optional 5 | from einops import rearrange, repeat 6 | from fit.model.modules import ( 7 | PatchEmbedder, TimestepEmbedder, LabelEmbedder, 8 | FiTBlock, FinalLayer 9 | ) 10 | from fit.model.utils import get_parameter_dtype 11 | from fit.utils.eval_utils import init_from_ckpt 12 | from fit.model.sincos import get_2d_sincos_pos_embed_from_grid 13 | from fit.model.rope import VisionRotaryEmbedding 14 | 15 | 
################################################################################# 16 | # Core FiT Model # 17 | ################################################################################# 18 | 19 | 20 | 21 | class FiT(nn.Module): 22 | """ 23 | Flexible Diffusion model with a Transformer backbone. 24 | """ 25 | def __init__( 26 | self, 27 | context_size: int = 256, 28 | patch_size: int = 2, 29 | in_channels: int = 4, 30 | hidden_size: int = 1152, 31 | depth: int = 28, 32 | num_heads: int = 16, 33 | mlp_ratio: float = 4.0, 34 | class_dropout_prob: float = 0.1, 35 | num_classes: int = 1000, 36 | learn_sigma: bool = True, 37 | use_sit: bool = False, 38 | use_checkpoint: bool=False, 39 | use_swiglu: bool = False, 40 | use_swiglu_large: bool = False, 41 | rel_pos_embed: Optional[str] = 'rope', 42 | norm_type: str = "layernorm", 43 | q_norm: Optional[str] = None, 44 | k_norm: Optional[str] = None, 45 | qk_norm_weight: bool = False, 46 | qkv_bias: bool = True, 47 | ffn_bias: bool = True, 48 | adaln_bias: bool = True, 49 | adaln_type: str = "normal", 50 | adaln_lora_dim: int = None, 51 | rope_theta: float = 10000.0, 52 | custom_freqs: str = 'normal', 53 | max_pe_len_h: Optional[int] = None, 54 | max_pe_len_w: Optional[int] = None, 55 | decouple: bool = False, 56 | ori_max_pe_len: Optional[int] = None, 57 | online_rope: bool = False, 58 | add_rel_pe_to_v: bool = False, 59 | pretrain_ckpt: str = None, 60 | ignore_keys: list = None, 61 | finetune: str = None, 62 | time_shifting: int = 1, 63 | **kwargs, 64 | ): 65 | super().__init__() 66 | self.context_size = context_size 67 | self.hidden_size = hidden_size 68 | assert not (learn_sigma and use_sit) 69 | self.learn_sigma = learn_sigma 70 | self.use_sit = use_sit 71 | self.use_checkpoint = use_checkpoint 72 | self.depth = depth 73 | self.mlp_ratio = mlp_ratio 74 | self.class_dropout_prob = class_dropout_prob 75 | self.num_classes = num_classes 76 | self.in_channels = in_channels 77 | self.out_channels = self.in_channels * 2 if learn_sigma else in_channels 78 | self.patch_size = patch_size 79 | self.num_heads = num_heads 80 | self.adaln_type = adaln_type 81 | self.online_rope = online_rope 82 | self.time_shifting = time_shifting 83 | 84 | self.x_embedder = PatchEmbedder(in_channels * patch_size**2, hidden_size, bias=True) 85 | self.t_embedder = TimestepEmbedder(hidden_size) 86 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 87 | 88 | 89 | 90 | self.rel_pos_embed = VisionRotaryEmbedding( 91 | head_dim=hidden_size//num_heads, theta=rope_theta, custom_freqs=custom_freqs, online_rope=online_rope, 92 | max_pe_len_h=max_pe_len_h, max_pe_len_w=max_pe_len_w, decouple=decouple, ori_max_pe_len=ori_max_pe_len, 93 | ) 94 | 95 | if adaln_type == 'lora': 96 | self.global_adaLN_modulation = nn.Sequential( 97 | nn.SiLU(), 98 | nn.Linear(hidden_size, 6 * hidden_size, bias=adaln_bias) 99 | ) 100 | else: 101 | self.global_adaLN_modulation = None 102 | 103 | self.blocks = nn.ModuleList([FiTBlock( 104 | hidden_size, num_heads, mlp_ratio=mlp_ratio, swiglu=use_swiglu, swiglu_large=use_swiglu_large, 105 | rel_pos_embed=rel_pos_embed, add_rel_pe_to_v=add_rel_pe_to_v, norm_layer=norm_type, 106 | q_norm=q_norm, k_norm=k_norm, qk_norm_weight=qk_norm_weight, qkv_bias=qkv_bias, ffn_bias=ffn_bias, 107 | adaln_bias=adaln_bias, adaln_type=adaln_type, adaln_lora_dim=adaln_lora_dim 108 | ) for _ in range(depth)]) 109 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, norm_layer=norm_type, adaln_bias=adaln_bias, 
adaln_type=adaln_type) 110 | self.initialize_weights(pretrain_ckpt=pretrain_ckpt, ignore=ignore_keys) 111 | if finetune != None: 112 | self.finetune(type=finetune, unfreeze=ignore_keys) 113 | 114 | 115 | def initialize_weights(self, pretrain_ckpt=None, ignore=None): 116 | # Initialize transformer layers: 117 | def _basic_init(module): 118 | if isinstance(module, nn.Linear): 119 | torch.nn.init.xavier_uniform_(module.weight) 120 | if module.bias is not None: 121 | nn.init.constant_(module.bias, 0) 122 | self.apply(_basic_init) 123 | 124 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 125 | w = self.x_embedder.proj.weight.data 126 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 127 | nn.init.constant_(self.x_embedder.proj.bias, 0) 128 | 129 | # Initialize label embedding table: 130 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 131 | 132 | # Initialize timestep embedding MLP: 133 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 134 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 135 | 136 | # Zero-out adaLN modulation layers in DiT blocks: 137 | for block in self.blocks: 138 | if self.adaln_type in ['normal', 'lora']: 139 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 140 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 141 | elif self.adaln_type == 'swiglu': 142 | nn.init.constant_(block.adaLN_modulation.fc2.weight, 0) 143 | nn.init.constant_(block.adaLN_modulation.fc2.bias, 0) 144 | if self.adaln_type == 'lora': 145 | nn.init.constant_(self.global_adaLN_modulation[-1].weight, 0) 146 | nn.init.constant_(self.global_adaLN_modulation[-1].bias, 0) 147 | # Zero-out output layers: 148 | if self.adaln_type == 'swiglu': 149 | nn.init.constant_(self.final_layer.adaLN_modulation.fc2.weight, 0) 150 | nn.init.constant_(self.final_layer.adaLN_modulation.fc2.bias, 0) 151 | else: # adaln_type in ['normal', 'lora'] 152 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 153 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 154 | nn.init.constant_(self.final_layer.linear.weight, 0) 155 | nn.init.constant_(self.final_layer.linear.bias, 0) 156 | 157 | keys = list(self.state_dict().keys()) 158 | ignore_keys = [] 159 | if ignore != None: 160 | for ign in ignore: 161 | for key in keys: 162 | if ign in key: 163 | ignore_keys.append(key) 164 | ignore_keys = list(set(ignore_keys)) 165 | if pretrain_ckpt != None: 166 | init_from_ckpt(self, pretrain_ckpt, ignore_keys, verbose=True) 167 | 168 | 169 | def unpatchify(self, x, hw): 170 | """ 171 | args: 172 | x: (B, p**2 * C_out, N) 173 | N = h//p * w//p 174 | return: 175 | imgs: (B, C_out, H, W) 176 | """ 177 | h, w = hw 178 | p = self.patch_size 179 | if self.use_sit: 180 | x = rearrange(x, "b (h w) c -> b h w c", h=h//p, w=w//p) # (B, h//2 * w//2, 16) -> (B, h//2, w//2, 16) 181 | x = rearrange(x, "b h w (c p1 p2) -> b c (h p1) (w p2)", p1=p, p2=p) # (B, h//2, w//2, 16) -> (B, h, w, 4) 182 | else: 183 | x = rearrange(x, "b c (h w) -> b c h w", h=h//p, w=w//p) # (B, 16, h//2 * w//2) -> (B, 16, h//2, w//2) 184 | x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=p, p2=p) # (B, 16, h//2, w//2) -> (B, h, w, 4) 185 | return x 186 | 187 | def forward(self, x, t, y, grid, mask, size=None): 188 | """ 189 | Forward pass of FiT. 
190 | x: (B, p**2 * C_in, N), tensor of sequential inputs (flattened latent features of images, N=H*W/(p**2)) 191 | t: (B,), tensor of diffusion timesteps 192 | y: (B,), tensor of class labels 193 | grid: (B, 2, N), tensor of height and width indices that spans a grid 194 | mask: (B, N), tensor of the mask for the sequence 195 | size: (B, n, 2), tensor of the height and width, n is the number of the packed images 196 | -------------------------------------------------------------------------------------------- 197 | return: (B, p**2 * C_out, N), where C_out=2*C_in if learn_sigma, C_out=C_in otherwise. 198 | """ 199 | 200 | t = torch.clamp(self.time_shifting * t / (1 + (self.time_shifting - 1) * t), max=1.0) 201 | t = t.float().to(x.dtype) 202 | if not self.use_sit: 203 | x = rearrange(x, 'B C N -> B N C') # (B, C, N) -> (B, N, C), where C = p**2 * C_in 204 | x = self.x_embedder(x) # (B, N, C) -> (B, N, D) 205 | t = self.t_embedder(t) # (B, D) 206 | y = self.y_embedder(y, self.training) # (B, D) 207 | c = t + y # (B, D) 208 | 209 | # get RoPE frequencies in advance, then calculate attention. 210 | if self.online_rope: 211 | freqs_cos, freqs_sin = self.rel_pos_embed.online_get_2d_rope_from_grid(grid, size) 212 | freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1) 213 | else: 214 | freqs_cos, freqs_sin = self.rel_pos_embed.get_cached_2d_rope_from_grid(grid) 215 | freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1) 216 | if self.global_adaLN_modulation != None: 217 | global_adaln = self.global_adaLN_modulation(c) 218 | else: 219 | global_adaln = 0.0 220 | 221 | if not self.use_checkpoint: 222 | for block in self.blocks: # (B, N, D) 223 | x = block(x, c, mask, freqs_cos, freqs_sin, global_adaln) 224 | else: 225 | for block in self.blocks: 226 | x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, c, mask, freqs_cos, freqs_sin, global_adaln) 227 | x = self.final_layer(x, c) # (B, N, p ** 2 * C_out), where C_out=2*C_in if learn_sigma, C_out=C_in otherwise. 228 | x = x * mask[..., None] # mask the padding tokens 229 | if not self.use_sit: 230 | x = rearrange(x, 'B N C -> B C N') # (B, N, C) -> (B, C, N), where C = p**2 * C_out 231 | return x 232 | 233 | 234 | 235 | def forward_with_cfg(self, x, t, y, grid, mask, size, cfg_scale, scale_pow=0.0): 236 | """ 237 | Forward pass with classifier-free guidance of FiT. 238 | x: (2B, N, p**2 * C_in) if use_sit else (2B, p**2 * C_in, N) tensor of sequential inputs (flattened latent features of images, N=H*W/(p**2)) 239 | t: (2B,) tensor of diffusion timesteps 240 | y: (2B,) tensor of class labels 241 | grid: (2B, 2, N): tensor of height and width indices that spans a grid 242 | mask: (2B, N): tensor of the mask for the sequence 243 | cfg_scale: float > 1.0 244 | return: (2B, p**2 * C_out, N), where C_out=2*C_in if learn_sigma, C_out=C_in otherwise. 245 | """ 246 | half = x[: len(x) // 2] # (2B, ...) -> (B, ...) 247 | combined = torch.cat([half, half], dim=0) # (2B, ...) 248 | model_out = self.forward(combined, t, y, grid, mask, size) # (2B, N, C) if use_sit else (2B, C, N), where C = p**2 * C_out 249 | # For exact reproducibility reasons, we apply classifier-free guidance on only 250 | # three channels by default. The standard approach to cfg applies it to all channels. 251 | # This can be done by uncommenting the following line and commenting-out the line following that.
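        # Note: whichever C_cfg is chosen, the guidance below computes
        # half_eps = uncond_eps + real_cfg_scale * (cond_eps - uncond_eps) on the first
        # C_cfg channels only; the remaining channels ("rest", e.g. the predicted sigma
        # when learn_sigma is enabled) are passed through unguided.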
252 | # C_cfg = self.in_channels * self.patch_size * self.patch_size 253 | C_cfg = 3 * self.patch_size * self.patch_size 254 | if self.use_sit: 255 | eps, rest = model_out[:, :, :C_cfg], model_out[:, :, C_cfg:] # eps: (2B, N, C_cfg), where C_cfg = p**2 * 3 256 | else: 257 | eps, rest = model_out[:, :C_cfg], model_out[:, C_cfg:] # eps: (2B, C_cfg, N), where C_cfg = p**2 * 3 258 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) # (2B, C_cfg, H, W) -> (B, C_cfg, H, W) 259 | # from https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L506 260 | # improved classifier-free guidance 261 | if scale_pow == 0.0: 262 | real_cfg_scale = cfg_scale 263 | else: 264 | scale_step = ( 265 | 1-torch.cos(((1-torch.clamp_max(t, 1.0))**scale_pow)*torch.pi) 266 | )*1/2 # power-cos scaling 267 | real_cfg_scale = (cfg_scale-1)*scale_step + 1 268 | real_cfg_scale = real_cfg_scale[: len(x) // 2].view(-1, 1, 1) 269 | t = t / (self.time_shifting + (1 - self.time_shifting) * t) 270 | half_eps = uncond_eps + real_cfg_scale * (cond_eps - uncond_eps) # (B, ...) 271 | eps = torch.cat([half_eps, half_eps], dim=0) # (B, ...) -> (2B, ...) 272 | if self.use_sit: 273 | return torch.cat([eps, rest], dim=2) # (2B, N, C), where C = p**2 * C_out 274 | else: 275 | return torch.cat([eps, rest], dim=1) # (2B, C, N), where C = p**2 * C_out 276 | 277 | def ckpt_wrapper(self, module): 278 | def ckpt_forward(*inputs): 279 | outputs = module(*inputs) 280 | return outputs 281 | return ckpt_forward 282 | 283 | 284 | @property 285 | def dtype(self) -> torch.dtype: 286 | """ 287 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 288 | """ 289 | return get_parameter_dtype(self) 290 | 291 | 292 | def finetune(self, type, unfreeze): 293 | if type == 'full': 294 | return 295 | for name, param in self.named_parameters(): 296 | param.requires_grad = False 297 | for unf in unfreeze: 298 | for name, param in self.named_parameters(): 299 | if unf in name: # LN means Layer Norm 300 | param.requires_grad = True 301 | -------------------------------------------------------------------------------- /fit/model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | from torch.jit import Final 7 | from timm.layers.mlp import SwiGLU, Mlp 8 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 9 | from fit.model.rope import rotate_half 10 | from fit.model.utils import modulate 11 | from fit.model.norms import create_norm 12 | from functools import partial 13 | from einops import rearrange, repeat 14 | 15 | ################################################################################# 16 | # Embedding Layers for Patches, Timesteps and Class Labels # 17 | ################################################################################# 18 | 19 | class PatchEmbedder(nn.Module): 20 | """ 21 | Embeds latent features into vector representations 22 | """ 23 | def __init__(self, 24 | input_dim, 25 | embed_dim, 26 | bias: bool = True, 27 | norm_layer: Optional[Callable] = None, 28 | ): 29 | super().__init__() 30 | 31 | self.proj = nn.Linear(input_dim, embed_dim, bias=bias) 32 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 33 | 34 | def forward(self, x): 35 | x = self.proj(x) # (B, L, patch_size ** 2 * C) -> (B, L, D) 36 | x = self.norm(x) 37 | return x 38 | 39 | class 
TimestepEmbedder(nn.Module): 40 | """ 41 | Embeds scalar timesteps into vector representations. 42 | """ 43 | def __init__(self, hidden_size, frequency_embedding_size=256): 44 | super().__init__() 45 | self.mlp = nn.Sequential( 46 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 47 | nn.SiLU(), 48 | nn.Linear(hidden_size, hidden_size, bias=True), 49 | ) 50 | self.frequency_embedding_size = frequency_embedding_size 51 | 52 | @staticmethod 53 | def timestep_embedding(t, dim, max_period=10000): 54 | """ 55 | Create sinusoidal timestep embeddings. 56 | :param t: a 1-D Tensor of N indices, one per batch element. 57 | These may be fractional. 58 | :param dim: the dimension of the output. 59 | :param max_period: controls the minimum frequency of the embeddings. 60 | :return: an (N, D) Tensor of positional embeddings. 61 | """ 62 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 63 | half = dim // 2 64 | freqs = torch.exp( 65 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 66 | ).to(device=t.device) 67 | args = t[:, None] * freqs[None] 68 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 69 | if dim % 2: 70 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1]).to(device=t.device)], dim=-1) 71 | return embedding.to(dtype=t.dtype) 72 | 73 | def forward(self, t): 74 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 75 | t_emb = self.mlp(t_freq) 76 | return t_emb 77 | 78 | 79 | class LabelEmbedder(nn.Module): 80 | """ 81 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 82 | """ 83 | def __init__(self, num_classes, hidden_size, dropout_prob): 84 | super().__init__() 85 | use_cfg_embedding = dropout_prob > 0 86 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 87 | self.num_classes = num_classes 88 | self.dropout_prob = dropout_prob 89 | 90 | def token_drop(self, labels, force_drop_ids=None): 91 | """ 92 | Drops labels to enable classifier-free guidance. 
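        With probability `dropout_prob`, a label is replaced by the index `num_classes`, which selects the
        extra "null" row of the embedding table and serves as the unconditional embedding;
        `force_drop_ids` can instead mark specific samples to drop deterministically.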
93 | """ 94 | if force_drop_ids is None: 95 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 96 | else: 97 | drop_ids = force_drop_ids == 1 98 | labels = torch.where(drop_ids, self.num_classes, labels) 99 | return labels 100 | 101 | def forward(self, labels, train, force_drop_ids=None): 102 | use_dropout = self.dropout_prob > 0 103 | if (train and use_dropout) or (force_drop_ids is not None): 104 | labels = self.token_drop(labels, force_drop_ids) 105 | embeddings = self.embedding_table(labels) 106 | return embeddings 107 | 108 | 109 | 110 | 111 | ################################################################################# 112 | # Attention # 113 | ################################################################################# 114 | 115 | # modified from timm and eva-02 116 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 117 | # https://github.com/baaivision/EVA/blob/master/EVA-02/asuka/modeling_finetune.py 118 | 119 | 120 | class Attention(nn.Module): 121 | 122 | def __init__(self, 123 | dim: int, 124 | num_heads: int = 8, 125 | qkv_bias: bool = False, 126 | q_norm: Optional[str] = None, 127 | k_norm: Optional[str] = None, 128 | qk_norm_weight: bool = False, 129 | attn_drop: float = 0., 130 | proj_drop: float = 0., 131 | rel_pos_embed: Optional[str] = None, 132 | add_rel_pe_to_v: bool = False, 133 | ) -> None: 134 | super().__init__() 135 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 136 | self.num_heads = num_heads 137 | self.head_dim = dim // num_heads 138 | self.scale = self.head_dim ** -0.5 139 | 140 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 141 | if q_norm == 'layernorm' and qk_norm_weight == True: 142 | q_norm = 'w_layernorm' 143 | if k_norm == 'layernorm' and qk_norm_weight == True: 144 | k_norm = 'w_layernorm' 145 | 146 | self.q_norm = create_norm(q_norm, self.head_dim) 147 | self.k_norm = create_norm(k_norm, self.head_dim) 148 | 149 | 150 | self.attn_drop = nn.Dropout(attn_drop) 151 | self.proj = nn.Linear(dim, dim) 152 | self.proj_drop = nn.Dropout(proj_drop) 153 | 154 | self.rel_pos_embed = None if rel_pos_embed==None else rel_pos_embed.lower() 155 | self.add_rel_pe_to_v = add_rel_pe_to_v 156 | 157 | 158 | 159 | def forward(self, 160 | x: torch.Tensor, 161 | mask: Optional[torch.Tensor] = None, 162 | freqs_cos: Optional[torch.Tensor] = None, 163 | freqs_sin: Optional[torch.Tensor] = None, 164 | ) -> torch.Tensor: 165 | B, N, C = x.shape 166 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 167 | q, k, v = qkv.unbind(0) # (B, n_h, N, D_h) 168 | q, k = self.q_norm(q), self.k_norm(k) 169 | 170 | if self.rel_pos_embed in ['rope', 'xpos']: # multiplicative rel_pos_embed 171 | if self.add_rel_pe_to_v: 172 | v = v * freqs_cos + rotate_half(v) * freqs_sin 173 | q = q * freqs_cos + rotate_half(q) * freqs_sin 174 | k = k * freqs_cos + rotate_half(k) * freqs_sin 175 | 176 | attn_mask = mask[:, None, None, :] # (B, N) -> (B, 1, 1, N) 177 | attn_mask = (attn_mask == attn_mask.transpose(-2, -1)) # (B, 1, 1, N) x (B, 1, N, 1) -> (B, 1, N, N) 178 | mask = torch.not_equal(mask, torch.zeros_like(mask)).to(mask) # (B, N) -> (B, N) 179 | 180 | 181 | if x.device.type == "cpu": 182 | x = F.scaled_dot_product_attention( 183 | q, k, v, attn_mask=attn_mask, 184 | dropout_p=self.attn_drop.p if self.training else 0., 185 | ) 186 | else: 187 | with torch.backends.cuda.sdp_kernel(enable_flash=True): 188 | ''' 189 | 
F.scaled_dot_product_attention is the efficient implementation equivalent to the following: 190 | attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask 191 | attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask 192 | attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1) 193 | attn_weight = torch.dropout(attn_weight, dropout_p) 194 | return attn_weight @ V 195 | In conclusion: 196 | boolean attn_mask will mask the attention matrix where attn_mask is False 197 | non-boolean attn_mask will be directly added to Q@K.T 198 | ''' 199 | x = F.scaled_dot_product_attention( 200 | q, k, v, attn_mask=attn_mask, 201 | dropout_p=self.attn_drop.p if self.training else 0., 202 | ) 203 | x = x.transpose(1, 2).reshape(B, N, C) 204 | x = x * mask[..., None] # mask: (B, N) -> (B, N, 1) 205 | x = self.proj(x) 206 | x = self.proj_drop(x) 207 | return x 208 | 209 | ################################################################################# 210 | # Basic FiT Module # 211 | ################################################################################# 212 | 213 | class FiTBlock(nn.Module): 214 | """ 215 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 216 | """ 217 | def __init__(self, 218 | hidden_size, 219 | num_heads, 220 | mlp_ratio=4.0, 221 | swiglu=True, 222 | swiglu_large=False, 223 | rel_pos_embed=None, 224 | add_rel_pe_to_v=False, 225 | norm_layer: str = 'layernorm', 226 | q_norm: Optional[str] = None, 227 | k_norm: Optional[str] = None, 228 | qk_norm_weight: bool = False, 229 | qkv_bias=True, 230 | ffn_bias=True, 231 | adaln_bias=True, 232 | adaln_type='normal', 233 | adaln_lora_dim: int = None, 234 | **block_kwargs 235 | ): 236 | super().__init__() 237 | self.norm1 = create_norm(norm_layer, hidden_size) 238 | self.norm2 = create_norm(norm_layer, hidden_size) 239 | 240 | self.attn = Attention( 241 | hidden_size, num_heads=num_heads, rel_pos_embed=rel_pos_embed, 242 | q_norm=q_norm, k_norm=k_norm, qk_norm_weight=qk_norm_weight, 243 | qkv_bias=qkv_bias, add_rel_pe_to_v=add_rel_pe_to_v, 244 | **block_kwargs 245 | ) 246 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 247 | if swiglu: 248 | if swiglu_large: 249 | self.mlp = SwiGLU(in_features=hidden_size, hidden_features=mlp_hidden_dim, bias=ffn_bias) 250 | else: 251 | self.mlp = SwiGLU(in_features=hidden_size, hidden_features=(mlp_hidden_dim*2)//3, bias=ffn_bias) 252 | else: 253 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=ffn_bias) 254 | if adaln_type == 'normal': 255 | self.adaLN_modulation = nn.Sequential( 256 | nn.SiLU(), 257 | nn.Linear(hidden_size, 6 * hidden_size, bias=adaln_bias) 258 | ) 259 | elif adaln_type == 'lora': 260 | self.adaLN_modulation = nn.Sequential( 261 | nn.SiLU(), 262 | nn.Linear(hidden_size, adaln_lora_dim, bias=adaln_bias), 263 | nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=adaln_bias) 264 | ) 265 | elif adaln_type == 'swiglu': 266 | self.adaLN_modulation = SwiGLU( 267 | in_features=hidden_size, hidden_features=(hidden_size//4)*3, out_features=6*hidden_size, bias=adaln_bias 268 | ) 269 | 270 | def forward(self, x, c, mask, freqs_cos, freqs_sin, global_adaln=0.0): 271 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.adaLN_modulation(c) + global_adaln).chunk(6, dim=1) 272 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, 
scale_msa), mask, freqs_cos, freqs_sin) 273 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 274 | return x 275 | 276 | class FinalLayer(nn.Module): 277 | """ 278 | The final layer of DiT. 279 | """ 280 | def __init__(self, hidden_size, patch_size, out_channels, norm_layer: str = 'layernorm', adaln_bias=True, adaln_type='normal'): 281 | super().__init__() 282 | self.norm_final = create_norm(norm_type=norm_layer, dim=hidden_size) 283 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 284 | if adaln_type == 'swiglu': 285 | self.adaLN_modulation = SwiGLU(in_features=hidden_size, hidden_features=hidden_size//2, out_features=2*hidden_size, bias=adaln_bias) 286 | else: # adaln_type in ['normal', 'lora'] 287 | self.adaLN_modulation = nn.Sequential( 288 | nn.SiLU(), 289 | nn.Linear(hidden_size, 2 * hidden_size, bias=adaln_bias) 290 | ) 291 | 292 | def forward(self, x, c): 293 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 294 | x = modulate(self.norm_final(x), shift, scale) 295 | x = self.linear(x) 296 | return x 297 | -------------------------------------------------------------------------------- /fit/model/norms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from functools import partial 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | import triton 15 | import triton.language as tl 16 | 17 | 18 | 19 | def create_norm(norm_type: str, dim: int, eps: float = 1e-6): 20 | """ 21 | Creates the specified normalization layer based on the norm_type. 22 | 23 | Args: 24 | norm_type (str): The type of normalization layer to create. 25 | Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm 26 | dim (int): The dimension of the normalization layer. 27 | eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. 28 | 29 | Returns: 30 | The created normalization layer. 31 | 32 | Raises: 33 | NotImplementedError: If an unknown norm_type is provided. 34 | """ 35 | if norm_type == None or norm_type == "": 36 | return nn.Identity() 37 | norm_type = norm_type.lower() # Normalize to lowercase 38 | 39 | if norm_type == "w_layernorm": 40 | return nn.LayerNorm(dim, eps=eps, bias=False) 41 | elif norm_type == "layernorm": 42 | return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) 43 | elif norm_type == "w_rmsnorm": 44 | return RMSNorm(dim, eps=eps) 45 | elif norm_type == "rmsnorm": 46 | return RMSNorm(dim, include_weight=False, eps=eps) 47 | elif norm_type == 'none': 48 | return nn.Identity() 49 | else: 50 | raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") 51 | 52 | 53 | class RMSNorm(nn.Module): 54 | """ 55 | Initialize the RMSNorm normalization layer. 56 | 57 | Args: 58 | dim (int): The dimension of the input tensor. 59 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 60 | 61 | Attributes: 62 | eps (float): A small value added to the denominator for numerical stability. 63 | weight (nn.Parameter): Learnable scaling parameter. 
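        Computes y = x / sqrt(mean(x^2, dim=-1) + eps) * weight, i.e. root-mean-square
        normalization over the last dimension followed by a learnable per-channel scale.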
64 | 65 | """ 66 | 67 | def __init__(self, dim: int, eps: float = 1e-6): 68 | super().__init__() 69 | self.eps = eps 70 | self.weight = nn.Parameter(torch.ones(dim)) 71 | 72 | def _norm(self, x: torch.Tensor): 73 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 74 | 75 | def forward(self, x: torch.Tensor): 76 | output = self._norm(x.float()).type_as(x) 77 | return output * self.weight 78 | 79 | def reset_parameters(self): 80 | torch.nn.init.ones_(self.weight) # type: ignore 81 | 82 | -------------------------------------------------------------------------------- /fit/model/rope.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # FiT: A Flexible Vision Transformer for Image Generation 3 | # 4 | # Based on the following repository 5 | # https://github.com/lucidrains/rotary-embedding-torch 6 | # https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope 7 | # https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | from math import pi 12 | from typing import Optional, Any, Union, Tuple 13 | import torch 14 | from torch import nn 15 | 16 | from einops import rearrange, repeat 17 | from functools import lru_cache 18 | 19 | ################################################################################# 20 | # NTK Operations # 21 | ################################################################################# 22 | 23 | def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048): 24 | return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations 25 | 26 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): 27 | low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings)) 28 | high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings)) 29 | return max(low, 0), min(high, dim-1) #Clamp values just in case 30 | 31 | def linear_ramp_mask(min, max, dim): 32 | if min == max: 33 | max += 0.001 #Prevent singularity 34 | 35 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 36 | ramp_func = torch.clamp(linear_func, 0, 1) 37 | return ramp_func 38 | 39 | def find_newbase_ntk(dim, base=10000, scale=1): 40 | # Base change formula 41 | return base * scale ** (dim / (dim-2)) 42 | 43 | def get_mscale(scale=torch.Tensor): 44 | # if scale <= 1: 45 | # return 1.0 46 | # return 0.1 * math.log(scale) + 1.0 47 | return torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0) 48 | 49 | def get_proportion(L_test, L_train): 50 | L_test = L_test * 2 51 | return torch.where(torch.tensor(L_test/L_train) <= 1., torch.tensor(1.0), torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train)))) 52 | # return torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train))) 53 | 54 | 55 | 56 | ################################################################################# 57 | # Rotate Q or K # 58 | ################################################################################# 59 | 60 | def rotate_half(x): 61 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 62 | x1, x2 = x.unbind(dim = -1) 63 | x = torch.stack((-x2, x1), dim = -1) 64 | return rearrange(x, '... d r -> ... 
(d r)') 65 | 66 | 67 | 68 | ################################################################################# 69 | # Core Vision RoPE # 70 | ################################################################################# 71 | 72 | class VisionRotaryEmbedding(nn.Module): 73 | def __init__( 74 | self, 75 | head_dim: int, # embed dimension for each head 76 | custom_freqs: str = 'normal', 77 | theta: int = 10000, 78 | online_rope: bool = False, 79 | max_cached_len: int = 256, 80 | max_pe_len_h: Optional[int] = None, 81 | max_pe_len_w: Optional[int] = None, 82 | decouple: bool = False, 83 | ori_max_pe_len: Optional[int] = None, 84 | ): 85 | super().__init__() 86 | 87 | dim = head_dim // 2 88 | assert dim % 2 == 0 # accually, this is important 89 | self.dim = dim 90 | self.custom_freqs = custom_freqs.lower() 91 | self.theta = theta 92 | self.decouple = decouple 93 | self.ori_max_pe_len = ori_max_pe_len 94 | 95 | self.custom_freqs = custom_freqs.lower() 96 | if not online_rope: 97 | if self.custom_freqs == 'normal': 98 | freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) 99 | freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) 100 | else: 101 | if decouple: 102 | freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len) 103 | freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len) 104 | else: 105 | max_pe_len = max(max_pe_len_h, max_pe_len_w) 106 | freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len) 107 | freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len) 108 | 109 | attn_factor = 1.0 110 | scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0) # dynamic scale 111 | self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation 112 | self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len) 113 | self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2) 114 | 115 | self.register_buffer('freqs_h', freqs_h, persistent=False) 116 | self.register_buffer('freqs_w', freqs_w, persistent=False) 117 | 118 | freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h) 119 | freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2) 120 | self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False) 121 | freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w) 122 | freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2) 123 | self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False) 124 | 125 | 126 | def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len): 127 | # scaling operations for extrapolation 128 | assert isinstance(ori_max_pe_len, int) 129 | # scale = max_pe_len / ori_max_pe_len 130 | if not isinstance(max_pe_len, torch.Tensor): 131 | max_pe_len = torch.tensor(max_pe_len) 132 | scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0) # dynamic scale 133 | 134 | if self.custom_freqs == 'linear': # equal to position interpolation 135 | freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim)) 136 | elif self.custom_freqs == 'ntk-aware' or self.custom_freqs == 'ntk-aware-pro1' or self.custom_freqs == 'ntk-aware-pro2': 137 | freqs = 1. 
/ torch.pow( 138 | find_newbase_ntk(dim, theta, scale).view(-1, 1), 139 | (torch.arange(0, dim, 2).to(scale).float() / dim) 140 | ).squeeze() 141 | elif self.custom_freqs == 'ntk-by-parts': 142 | #Interpolation constants found experimentally for LLaMA (might not be totally optimal though) 143 | #Do not change unless there is a good reason for doing so! 144 | beta_0 = 1.25 145 | beta_1 = 0.75 146 | gamma_0 = 16 147 | gamma_1 = 2 148 | ntk_factor = 1 149 | extrapolation_factor = 1 150 | 151 | #Three RoPE extrapolation/interpolation methods 152 | freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) 153 | freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))) 154 | freqs_ntk = 1. / torch.pow( 155 | find_newbase_ntk(dim, theta, scale).view(-1, 1), 156 | (torch.arange(0, dim, 2).to(scale).float() / dim) 157 | ).squeeze() 158 | 159 | #Combine NTK and Linear 160 | low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len) 161 | freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor 162 | freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask 163 | 164 | #Combine Extrapolation and NTK and Linear 165 | low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len) 166 | freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor 167 | freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask 168 | 169 | elif self.custom_freqs == 'yarn': 170 | #Interpolation constants found experimentally for LLaMA (might not be totally optimal though) 171 | #Do not change unless there is a good reason for doing so! 172 | beta_fast = 32 173 | beta_slow = 1 174 | extrapolation_factor = 1 175 | 176 | freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)) 177 | freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))) 178 | 179 | low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len) 180 | freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation 181 | freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask 182 | else: 183 | raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!') 184 | return freqs 185 | 186 | 187 | def online_get_2d_rope_from_grid(self, grid, size): 188 | ''' 189 | grid: (B, 2, N) 190 | N = H * W 191 | the first dimension represents width, and the second reprensents height 192 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] 193 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] 194 | size: (B, 1, 2), h goes first and w goes last 195 | ''' 196 | size = size.squeeze() # (B, 1, 2) -> (B, 2) 197 | if self.decouple: 198 | size_h = size[:, 0] 199 | size_w = size[:, 1] 200 | freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len) 201 | freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len) 202 | else: 203 | size_max = torch.max(size[:, 0], size[:, 1]) 204 | freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len) 205 | freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len) 206 | freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :] 207 | freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2) 208 | 209 | freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :] 210 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 211 | 212 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) 213 | 214 | if self.custom_freqs == 'yarn': 215 | freqs_cos = freqs.cos() * self.mscale[:, None, None] 216 | freqs_sin = freqs.sin() * self.mscale[:, None, None] 217 | elif self.custom_freqs == 'ntk-aware-pro1': 218 | freqs_cos = freqs.cos() * self.proportion1[:, None, None] 219 | freqs_sin = freqs.sin() * self.proportion1[:, None, None] 220 | elif self.custom_freqs == 'ntk-aware-pro2': 221 | freqs_cos = freqs.cos() * self.proportion2[:, None, None] 222 | freqs_sin = freqs.sin() * self.proportion2[:, None, None] 223 | else: 224 | freqs_cos = freqs.cos() 225 | freqs_sin = freqs.sin() 226 | 227 | return freqs_cos, freqs_sin 228 | 229 | @lru_cache() 230 | def get_2d_rope_from_grid(self, grid): 231 | ''' 232 | grid: (B, 2, N) 233 | N = H * W 234 | the first dimension represents width, and the second reprensents height 235 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] 236 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] 237 | ''' 238 | freqs_w = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_w) 239 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 240 | 241 | freqs_h = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_h) 242 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 243 | 244 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) 245 | 246 | if self.custom_freqs == 'yarn': 247 | freqs_cos = freqs.cos() * self.mscale 248 | freqs_sin = freqs.sin() * self.mscale 249 | elif self.custom_freqs == 'ntk-aware-pro1': 250 | freqs_cos = freqs.cos() * self.proportion1 251 | freqs_sin = freqs.sin() * self.proportion1 252 | elif self.custom_freqs == 'ntk-aware-pro2': 253 | freqs_cos = freqs.cos() * self.proportion2 254 | freqs_sin = freqs.sin() * self.proportion2 255 | else: 256 | freqs_cos = freqs.cos() 257 | freqs_sin = freqs.sin() 258 | 259 | return freqs_cos, freqs_sin 260 | 261 | @lru_cache() 262 | def get_cached_2d_rope_from_grid(self, grid: torch.Tensor): 263 | ''' 264 | grid: (B, 2, N) 265 | N = H * W 266 | the first dimension represents width, and the second reprensents height 267 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] 268 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] 269 | ''' 270 | freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]], self.freqs_h_cached[grid[:, 1]] 271 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) 272 | 273 | if self.custom_freqs == 'yarn': 274 | freqs_cos = freqs.cos() * self.mscale 275 | freqs_sin = freqs.sin() * self.mscale 276 | elif self.custom_freqs == 'ntk-aware-pro1': 277 | freqs_cos = freqs.cos() * self.proportion1 278 | freqs_sin = freqs.sin() * self.proportion1 279 | elif self.custom_freqs == 'ntk-aware-pro2': 280 | freqs_cos = freqs.cos() * self.proportion2 281 | freqs_sin = freqs.sin() * self.proportion2 282 | else: 283 | freqs_cos = freqs.cos() 284 | freqs_sin = freqs.sin() 285 | 286 | return freqs_cos, freqs_sin 287 | 288 | @lru_cache() 289 | def get_cached_21d_rope_from_grid(self, grid: torch.Tensor): # for 3d rope formulation 2 ! 290 | ''' 291 | grid: (B, 3, N) 292 | N = H * W * T 293 | the first dimension represents width, and the second reprensents height, and the third reprensents time 294 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] 295 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] 296 | [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
297 | ''' 298 | freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]+grid[:, 2]], self.freqs_h_cached[grid[:, 1]+grid[:, 2]] 299 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) 300 | 301 | if self.custom_freqs == 'yarn': 302 | freqs_cos = freqs.cos() * self.mscale 303 | freqs_sin = freqs.sin() * self.mscale 304 | elif self.custom_freqs == 'ntk-aware-pro1': 305 | freqs_cos = freqs.cos() * self.proportion1 306 | freqs_sin = freqs.sin() * self.proportion1 307 | elif self.custom_freqs == 'ntk-aware-pro2': 308 | freqs_cos = freqs.cos() * self.proportion2 309 | freqs_sin = freqs.sin() * self.proportion2 310 | else: 311 | freqs_cos = freqs.cos() 312 | freqs_sin = freqs.sin() 313 | 314 | return freqs_cos, freqs_sin 315 | 316 | def forward(self, x, grid): 317 | ''' 318 | x: (B, n_head, N, D) 319 | grid: (B, 2, N) 320 | ''' 321 | # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid) 322 | # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1) 323 | # using cache to accelerate, this is the same with the above codes: 324 | freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid) 325 | freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1) 326 | return x * freqs_cos + rotate_half(x) * freqs_sin 327 | 328 | -------------------------------------------------------------------------------- /fit/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List, Tuple 4 | 5 | 6 | def modulate(x, shift, scale): 7 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 8 | 9 | 10 | def get_parameter_dtype(parameter: torch.nn.Module): 11 | try: 12 | params = tuple(parameter.parameters()) 13 | if len(params) > 0: 14 | return params[0].dtype 15 | 16 | buffers = tuple(parameter.buffers()) 17 | if len(buffers) > 0: 18 | return buffers[0].dtype 19 | 20 | except StopIteration: 21 | # For torch.nn.DataParallel compatibility in PyTorch 1.5 22 | 23 | def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: 24 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 25 | return tuples 26 | 27 | gen = parameter._named_members(get_members_fn=find_tensor_attributes) 28 | first_tuple = next(gen) 29 | return first_tuple[1].dtype 30 | 31 | -------------------------------------------------------------------------------- /fit/scheduler/improved_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /fit/scheduler/improved_diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | 90 | 91 | def get_flexible_mask_and_ratio(model_kwargs: dict, x: th.Tensor): 92 | ''' 93 | sequential case (fit): 94 | x: (B, C, N) 95 | model_kwargs: {y: (B,), mask: (B, N), grid: (B, 2, N)} 96 | mask: (B, N) -> (B, 1, N) 97 | spatial case (dit): 98 | x: (B, C, H, W) 99 | model_kwargs: {y: (B,)} 100 | mask: (B, C) -> (B, C, 1, 1) 101 | ''' 102 | mask = model_kwargs.get('mask', th.ones(x.shape[:2])) # (B, N) or (B, C) 103 | ratio = float(mask.shape[-1]) / th.count_nonzero(mask, dim=-1) # (B,) 104 | if len(x.shape) == 3: # sequential x: (B, C, N) 105 | mask = mask[:, None, :] # (B, N) -> (B, 1, N) 106 | elif len(x.shape) == 4: # spatial x: (B, C, H, W) 107 | mask = mask[..., None, None] # (B, C) -> (B, C, 1, 1) 108 | else: 109 | raise NotImplementedError 110 | return mask.to(x), ratio.to(x) -------------------------------------------------------------------------------- /fit/scheduler/improved_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 
22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /fit/scheduler/improved_diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term.
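                # _loss_history[t] acts as a fixed-size FIFO of the last `history_per_term`
                # losses for timestep t: drop the oldest entry and append the newest at the end.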
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /fit/scheduler/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import Transport, ModelType, WeightType, PathType, SNRType, Sampler 2 | 3 | def create_transport( 4 | path_type='Linear', 5 | prediction="velocity", 6 | loss_weight=None, 7 | train_eps=None, 8 | sample_eps=None, 9 | snr_type='uniform', 10 | ): 11 | """function for creating Transport object 12 | **Note**: model prediction defaults to velocity 13 | Args: 14 | - path_type: type of path to use; default to linear 15 | - learn_score: set model prediction to score 16 | - learn_noise: set model prediction to noise 17 | - velocity_weighted: weight loss by velocity weight 18 | - likelihood_weighted: weight loss by likelihood weight 19 | - train_eps: small epsilon for avoiding instability during training 20 | - sample_eps: small epsilon for avoiding instability during sampling 21 | """ 22 | 23 | if prediction == "noise": 24 | model_type = ModelType.NOISE 25 | elif prediction == "score": 26 | model_type = ModelType.SCORE 27 | else: 28 | model_type = ModelType.VELOCITY 29 | 30 | if loss_weight == "velocity": 31 | loss_type = WeightType.VELOCITY 32 | elif loss_weight == "likelihood": 33 | loss_type = WeightType.LIKELIHOOD 34 | else: 35 | loss_type = WeightType.NONE 36 | 37 | if snr_type == "lognorm": 38 | snr_type = SNRType.LOGNORM 39 | elif snr_type == "uniform": 40 | snr_type = SNRType.UNIFORM 41 | else: 42 | raise ValueError(f"Invalid snr type {snr_type}") 43 | 44 | path_choice = { 45 | "Linear": PathType.LINEAR, 46 | "GVP": PathType.GVP, 47 | "VP": PathType.VP, 48 | } 49 | 50 | path_type = path_choice[path_type] 51 | 52 | if (path_type in [PathType.VP]): 53 | train_eps = 1e-5 if train_eps is None else train_eps 54 | sample_eps = 1e-3 if train_eps is None else sample_eps 55 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): 56 | train_eps = 1e-3 if train_eps is None else train_eps 57 | sample_eps = 1e-3 if train_eps is None else sample_eps 58 | else: # velocity & [GVP, LINEAR] is stable everywhere 59 | train_eps = 0 60 | sample_eps = 0 61 | 62 | # create flow state 63 | state = Transport( 64 | model_type=model_type, 65 | path_type=path_type, 66 | loss_type=loss_type, 67 | train_eps=train_eps, 68 | sample_eps=sample_eps, 69 | snr_type=snr_type, 70 | ) 71 | 72 | return state -------------------------------------------------------------------------------- /fit/scheduler/transport/integrators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | from torchdiffeq import odeint 5 | from functools import partial 6 | from tqdm import tqdm 7 | 8 | class sde: 9 | """SDE solver class""" 10 | def __init__( 11 | self, 12 | drift, 13 | diffusion, 14 | *, 15 | t0, 16 | t1, 17 | num_steps, 18 | sampler_type, 19 | ): 20 | assert t0 < t1, "SDE sampler has to be in forward time" 21 | 22 | self.num_timesteps = num_steps 23 | self.t = th.linspace(t0, t1, num_steps) 24 | self.dt = self.t[1] - self.t[0] 25 | self.drift = drift 26 | self.diffusion = diffusion 27 | 
self.sampler_type = sampler_type 28 | 29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): 30 | w_cur = th.randn(x.size()).to(x) 31 | t = th.ones(x.size(0)).to(x) * t 32 | dw = w_cur * th.sqrt(self.dt) 33 | drift = self.drift(x, t, model, **model_kwargs) 34 | diffusion = self.diffusion(x, t) 35 | mean_x = x + drift * self.dt 36 | x = mean_x + th.sqrt(2 * diffusion) * dw 37 | return x, mean_x 38 | 39 | def __Heun_step(self, x, _, t, model, **model_kwargs): 40 | w_cur = th.randn(x.size()).to(x) 41 | dw = w_cur * th.sqrt(self.dt) 42 | t_cur = th.ones(x.size(0)).to(x) * t 43 | diffusion = self.diffusion(x, t_cur) 44 | xhat = x + th.sqrt(2 * diffusion) * dw 45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs) 46 | xp = xhat + self.dt * K1 47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) 48 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step 49 | 50 | def __forward_fn(self): 51 | """TODO: generalize here by adding all private functions ending with steps to it""" 52 | sampler_dict = { 53 | "Euler": self.__Euler_Maruyama_step, 54 | "Heun": self.__Heun_step, 55 | } 56 | 57 | try: 58 | sampler = sampler_dict[self.sampler_type] 59 | except: 60 | raise NotImplementedError("Smapler type not implemented.") 61 | 62 | return sampler 63 | 64 | def sample(self, init, model, **model_kwargs): 65 | """forward loop of sde""" 66 | x = init 67 | mean_x = init 68 | samples = [] 69 | sampler = self.__forward_fn() 70 | for ti in self.t[:-1]: 71 | with th.no_grad(): 72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) 73 | samples.append(x) 74 | 75 | return samples 76 | 77 | class ode: 78 | """ODE solver class""" 79 | def __init__( 80 | self, 81 | drift, 82 | *, 83 | t0, 84 | t1, 85 | sampler_type, 86 | num_steps, 87 | atol, 88 | rtol, 89 | ): 90 | assert t0 < t1, "ODE sampler has to be in forward time" 91 | 92 | self.drift = drift 93 | self.t = th.linspace(t0, t1, num_steps) 94 | self.atol = atol 95 | self.rtol = rtol 96 | self.sampler_type = sampler_type 97 | 98 | def sample(self, x, model, **model_kwargs): 99 | 100 | device = x[0].device if isinstance(x, tuple) else x.device 101 | def _fn(t, x): 102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t 103 | model_output = self.drift(x, t, model, **model_kwargs) 104 | return model_output 105 | 106 | t = self.t.to(device) 107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 109 | samples = odeint( 110 | _fn, 111 | x, 112 | t, 113 | method=self.sampler_type, 114 | atol=atol, 115 | rtol=rtol 116 | ) 117 | return samples -------------------------------------------------------------------------------- /fit/scheduler/transport/path.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from functools import partial 4 | 5 | def expand_t_like_x(t, x): 6 | """Function to reshape time t to broadcastable dimension of x 7 | Args: 8 | t: [batch_dim,], time vector 9 | x: [batch_dim,...], data point 10 | """ 11 | dims = [1] * (len(x.size()) - 1) 12 | t = t.view(t.size(0), *dims) 13 | return t 14 | 15 | 16 | #################### Coupling Plans #################### 17 | 18 | class ICPlan: 19 | """Linear Coupling Plan""" 20 | def __init__(self, sigma=0.0): 21 | self.sigma = sigma 22 | 23 | def compute_alpha_t(self, t): 24 | """Compute the data coefficient 
along the path""" 25 | return t, 1 26 | 27 | def compute_sigma_t(self, t): 28 | """Compute the noise coefficient along the path""" 29 | return 1 - t, -1 30 | 31 | def compute_d_alpha_alpha_ratio_t(self, t): 32 | """Compute the ratio between d_alpha and alpha""" 33 | return 1 / t 34 | 35 | def compute_drift(self, x, t): 36 | """We always output sde according to score parametrization; """ 37 | t = expand_t_like_x(t, x) 38 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) 39 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 40 | drift = alpha_ratio * x 41 | diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t 42 | 43 | return -drift, diffusion 44 | 45 | def compute_diffusion(self, x, t, form="constant", norm=1.0): 46 | """Compute the diffusion term of the SDE 47 | Args: 48 | x: [batch_dim, ...], data point 49 | t: [batch_dim,], time vector 50 | form: str, form of the diffusion term 51 | norm: float, norm of the diffusion term 52 | """ 53 | t = expand_t_like_x(t, x) 54 | choices = { 55 | "constant": th.tensor(norm).to(x), 56 | "SBDM": norm * self.compute_drift(x, t)[1], 57 | "sigma": norm * self.compute_sigma_t(t)[0], 58 | "linear": norm * (1 - t), 59 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, 60 | "increasing-decreasing": norm * th.sin(np.pi * t) ** 2, 61 | } 62 | 63 | try: 64 | diffusion = choices[form] 65 | except KeyError: 66 | raise NotImplementedError(f"Diffusion form {form} not implemented") 67 | 68 | return diffusion 69 | 70 | def get_score_from_velocity(self, velocity, x, t): 71 | """Wrapper function: transfrom velocity prediction model to score 72 | Args: 73 | velocity: [batch_dim, ...] shaped tensor; velocity model output 74 | x: [batch_dim, ...] shaped tensor; x_t data point 75 | t: [batch_dim,] time tensor 76 | """ 77 | t = expand_t_like_x(t, x) 78 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 79 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 80 | mean = x 81 | reverse_alpha_ratio = alpha_t / d_alpha_t 82 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 83 | score = (reverse_alpha_ratio * velocity - mean) / var 84 | return score 85 | 86 | def get_noise_from_velocity(self, velocity, x, t): 87 | """Wrapper function: transfrom velocity prediction model to denoiser 88 | Args: 89 | velocity: [batch_dim, ...] shaped tensor; velocity model output 90 | x: [batch_dim, ...] shaped tensor; x_t data point 91 | t: [batch_dim,] time tensor 92 | """ 93 | t = expand_t_like_x(t, x) 94 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 95 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 96 | mean = x 97 | reverse_alpha_ratio = alpha_t / d_alpha_t 98 | var = reverse_alpha_ratio * d_sigma_t - sigma_t 99 | noise = (reverse_alpha_ratio * velocity - mean) / var 100 | return noise 101 | 102 | def get_velocity_from_score(self, score, x, t): 103 | """Wrapper function: transfrom score prediction model to velocity 104 | Args: 105 | score: [batch_dim, ...] shaped tensor; score model output 106 | x: [batch_dim, ...] 
shaped tensor; x_t data point 107 | t: [batch_dim,] time tensor 108 | """ 109 | t = expand_t_like_x(t, x) 110 | drift, var = self.compute_drift(x, t) 111 | velocity = var * score - drift 112 | return velocity 113 | 114 | def compute_mu_t(self, t, x0, x1): 115 | """Compute the mean of time-dependent density p_t""" 116 | t = expand_t_like_x(t, x1) 117 | alpha_t, _ = self.compute_alpha_t(t) 118 | sigma_t, _ = self.compute_sigma_t(t) 119 | return alpha_t * x1 + sigma_t * x0 120 | 121 | def compute_xt(self, t, x0, x1): 122 | """Sample xt from time-dependent density p_t; rng is required""" 123 | xt = self.compute_mu_t(t, x0, x1) 124 | return xt 125 | 126 | def compute_ut(self, t, x0, x1, xt): 127 | """Compute the vector field corresponding to p_t""" 128 | t = expand_t_like_x(t, x1) 129 | _, d_alpha_t = self.compute_alpha_t(t) 130 | _, d_sigma_t = self.compute_sigma_t(t) 131 | return d_alpha_t * x1 + d_sigma_t * x0 132 | 133 | def plan(self, t, x0, x1): 134 | xt = self.compute_xt(t, x0, x1) 135 | ut = self.compute_ut(t, x0, x1, xt) 136 | return t, xt, ut 137 | 138 | 139 | class VPCPlan(ICPlan): 140 | """class for VP path flow matching""" 141 | 142 | def __init__(self, sigma_min=0.1, sigma_max=20.0): 143 | self.sigma_min = sigma_min 144 | self.sigma_max = sigma_max 145 | self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min 146 | self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min 147 | 148 | 149 | def compute_alpha_t(self, t): 150 | """Compute coefficient of x1""" 151 | alpha_t = self.log_mean_coeff(t) 152 | alpha_t = th.exp(alpha_t) 153 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t) 154 | return alpha_t, d_alpha_t 155 | 156 | def compute_sigma_t(self, t): 157 | """Compute coefficient of x0""" 158 | p_sigma_t = 2 * self.log_mean_coeff(t) 159 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) 160 | d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) 161 | return sigma_t, d_sigma_t 162 | 163 | def compute_d_alpha_alpha_ratio_t(self, t): 164 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" 165 | return self.d_log_mean_coeff(t) 166 | 167 | def compute_drift(self, x, t): 168 | """Compute the drift term of the SDE""" 169 | t = expand_t_like_x(t, x) 170 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) 171 | return -0.5 * beta_t * x, beta_t / 2 172 | 173 | 174 | class GVPCPlan(ICPlan): 175 | def __init__(self, sigma=0.0): 176 | super().__init__(sigma) 177 | 178 | def compute_alpha_t(self, t): 179 | """Compute coefficient of x1""" 180 | alpha_t = th.sin(t * np.pi / 2) 181 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) 182 | return alpha_t, d_alpha_t 183 | 184 | def compute_sigma_t(self, t): 185 | """Compute coefficient of x0""" 186 | sigma_t = th.cos(t * np.pi / 2) 187 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) 188 | return sigma_t, d_sigma_t 189 | 190 | def compute_d_alpha_alpha_ratio_t(self, t): 191 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" 192 | return np.pi / (2 * th.tan(t * np.pi / 2)) -------------------------------------------------------------------------------- /fit/scheduler/transport/transport.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | import logging 4 | 5 | import enum 6 | 7 | from . 
import path 8 | from .utils import EasyDict, log_state, mean_flat, get_flexible_mask_and_ratio 9 | from .integrators import ode, sde 10 | 11 | class ModelType(enum.Enum): 12 | """ 13 | Which type of output the model predicts. 14 | """ 15 | 16 | NOISE = enum.auto() # the model predicts epsilon 17 | SCORE = enum.auto() # the model predicts \nabla \log p(x) 18 | VELOCITY = enum.auto() # the model predicts v(x) 19 | 20 | class PathType(enum.Enum): 21 | """ 22 | Which type of path to use. 23 | """ 24 | 25 | LINEAR = enum.auto() 26 | GVP = enum.auto() 27 | VP = enum.auto() 28 | 29 | class WeightType(enum.Enum): 30 | """ 31 | Which type of weighting to use. 32 | """ 33 | 34 | NONE = enum.auto() 35 | VELOCITY = enum.auto() 36 | LIKELIHOOD = enum.auto() 37 | 38 | 39 | class SNRType(enum.Enum): 40 | UNIFORM = enum.auto() 41 | LOGNORM = enum.auto() 42 | 43 | 44 | class Transport: 45 | 46 | def __init__( 47 | self, 48 | *, 49 | model_type, 50 | path_type, 51 | loss_type, 52 | train_eps, 53 | sample_eps, 54 | snr_type 55 | ): 56 | path_options = { 57 | PathType.LINEAR: path.ICPlan, 58 | PathType.GVP: path.GVPCPlan, 59 | PathType.VP: path.VPCPlan, 60 | } 61 | 62 | self.loss_type = loss_type 63 | self.model_type = model_type 64 | self.path_sampler = path_options[path_type]() 65 | self.train_eps = train_eps 66 | self.sample_eps = sample_eps 67 | 68 | self.snr_type = snr_type 69 | 70 | def prior_logp(self, z): 71 | ''' 72 | Standard multivariate normal prior 73 | Assume z is batched 74 | ''' 75 | shape = th.tensor(z.size()) 76 | N = th.prod(shape[1:]) 77 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. 78 | return th.vmap(_fn)(z) 79 | 80 | 81 | def check_interval( 82 | self, 83 | train_eps, 84 | sample_eps, 85 | *, 86 | diffusion_form="SBDM", 87 | sde=False, 88 | reverse=False, 89 | eval=False, 90 | last_step_size=0.0, 91 | ): 92 | t0 = 0 93 | t1 = 1 94 | eps = train_eps if not eval else sample_eps 95 | if (type(self.path_sampler) in [path.VPCPlan]): 96 | 97 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 98 | 99 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ 100 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step 101 | 102 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 103 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 104 | 105 | if reverse: 106 | t0, t1 = 1 - t0, 1 - t1 107 | 108 | return t0, t1 109 | 110 | 111 | def sample(self, x1): 112 | """Sampling x0 & t based on shape of x1 (if needed) 113 | Args: 114 | x1 - data point; [batch, *dim] 115 | """ 116 | 117 | x0 = th.randn_like(x1) 118 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps) 119 | 120 | if self.snr_type == SNRType.UNIFORM: 121 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 122 | elif self.snr_type == SNRType.LOGNORM: 123 | u = th.normal(mean=0.0, std=1.0, size=(x1.shape[0],)) 124 | t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0 125 | else: 126 | raise ValueError(f"Unknown snr type: {self.snr_type}") 127 | 128 | t = t.to(x1) 129 | return t, x0, x1 130 | 131 | 132 | def training_losses( 133 | self, 134 | model, 135 | x1, 136 | model_kwargs=None 137 | ): 138 | """Loss for training the score model 139 | Args: 140 | - model: backbone model; could be score, noise, or velocity 141 | - x1: datapoint 142 | - model_kwargs: additional arguments for the model 143 | """ 144 | if model_kwargs == None: 145 | model_kwargs = {} 146 | 147 | t, x0, 
x1 = self.sample(x1) 148 | t, xt, ut = self.path_sampler.plan(t, x0, x1) 149 | model_output = model(xt, t, **model_kwargs) 150 | B, *_, C = xt.shape 151 | assert model_output.size() == (B, *xt.size()[1:-1], C) 152 | mask, ratio = get_flexible_mask_and_ratio(model_kwargs, x1) 153 | 154 | terms = {} 155 | terms['pred'] = model_output 156 | if self.model_type == ModelType.VELOCITY: 157 | terms['loss'] = mean_flat((((model_output - ut)*mask) ** 2)) * ratio 158 | else: 159 | _, drift_var = self.path_sampler.compute_drift(xt, t) 160 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) 161 | if self.loss_type in [WeightType.VELOCITY]: 162 | weight = (drift_var / sigma_t) ** 2 163 | elif self.loss_type in [WeightType.LIKELIHOOD]: 164 | weight = drift_var / (sigma_t ** 2) 165 | elif self.loss_type in [WeightType.NONE]: 166 | weight = 1 167 | else: 168 | raise NotImplementedError() 169 | 170 | if self.model_type == ModelType.NOISE: 171 | terms['loss'] = mean_flat(weight * (((model_output - x0)*mask) ** 2)) * ratio 172 | else: 173 | terms['loss'] = mean_flat(weight * (((model_output * sigma_t + x0)*mask) ** 2)) * ratio 174 | 175 | return terms 176 | 177 | 178 | def get_drift( 179 | self 180 | ): 181 | """member function for obtaining the drift of the probability flow ODE""" 182 | def score_ode(x, t, model, **model_kwargs): 183 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 184 | model_output = model(x, t, **model_kwargs) 185 | return (-drift_mean + drift_var * model_output) # by change of variable 186 | 187 | def noise_ode(x, t, model, **model_kwargs): 188 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 189 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) 190 | model_output = model(x, t, **model_kwargs) 191 | score = model_output / -sigma_t 192 | return (-drift_mean + drift_var * score) 193 | 194 | def velocity_ode(x, t, model, **model_kwargs): 195 | model_output = model(x, t, **model_kwargs) 196 | return model_output 197 | 198 | if self.model_type == ModelType.NOISE: 199 | drift_fn = noise_ode 200 | elif self.model_type == ModelType.SCORE: 201 | drift_fn = score_ode 202 | else: 203 | drift_fn = velocity_ode 204 | 205 | def body_fn(x, t, model, **model_kwargs): 206 | model_output = drift_fn(x, t, model, **model_kwargs) 207 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" 208 | return model_output 209 | 210 | return body_fn 211 | 212 | 213 | def get_score( 214 | self, 215 | ): 216 | """member function for obtaining score of 217 | x_t = alpha_t * x + sigma_t * eps""" 218 | if self.model_type == ModelType.NOISE: 219 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] 220 | elif self.model_type == ModelType.SCORE: 221 | score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) 222 | elif self.model_type == ModelType.VELOCITY: 223 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) 224 | else: 225 | raise NotImplementedError() 226 | 227 | return score_fn 228 | 229 | 230 | class Sampler: 231 | """Sampler class for the transport model""" 232 | def __init__( 233 | self, 234 | transport, 235 | ): 236 | """Constructor for a general sampler; supporting different sampling methods 237 | Args: 238 | - transport: an tranport object specify model prediction & interpolant type 239 | """ 240 | 241 | self.transport = transport 242 | self.drift = 
self.transport.get_drift() 243 | self.score = self.transport.get_score() 244 | 245 | def __get_sde_diffusion_and_drift( 246 | self, 247 | *, 248 | diffusion_form="SBDM", 249 | diffusion_norm=1.0, 250 | ): 251 | 252 | def diffusion_fn(x, t): 253 | diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) 254 | return diffusion 255 | 256 | sde_drift = \ 257 | lambda x, t, model, **kwargs: \ 258 | self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) 259 | 260 | sde_diffusion = diffusion_fn 261 | 262 | return sde_drift, sde_diffusion 263 | 264 | def __get_last_step( 265 | self, 266 | sde_drift, 267 | *, 268 | last_step, 269 | last_step_size, 270 | ): 271 | """Get the last step function of the SDE solver""" 272 | 273 | if last_step is None: 274 | last_step_fn = \ 275 | lambda x, t, model, **model_kwargs: \ 276 | x 277 | elif last_step == "Mean": 278 | last_step_fn = \ 279 | lambda x, t, model, **model_kwargs: \ 280 | x + sde_drift(x, t, model, **model_kwargs) * last_step_size 281 | elif last_step == "Tweedie": 282 | alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long 283 | sigma = self.transport.path_sampler.compute_sigma_t 284 | last_step_fn = \ 285 | lambda x, t, model, **model_kwargs: \ 286 | x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) 287 | elif last_step == "Euler": 288 | last_step_fn = \ 289 | lambda x, t, model, **model_kwargs: \ 290 | x + self.drift(x, t, model, **model_kwargs) * last_step_size 291 | else: 292 | raise NotImplementedError() 293 | 294 | return last_step_fn 295 | 296 | def sample_sde( 297 | self, 298 | *, 299 | sampling_method="Euler", 300 | diffusion_form="SBDM", 301 | diffusion_norm=1.0, 302 | last_step="Mean", 303 | last_step_size=0.04, 304 | num_steps=250, 305 | ): 306 | """returns a sampling function with given SDE settings 307 | Args: 308 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama 309 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM 310 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1 311 | - last_step: type of the last step; default to identity 312 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] 313 | - num_steps: total integration step of SDE 314 | """ 315 | 316 | if last_step is None: 317 | last_step_size = 0.0 318 | 319 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( 320 | diffusion_form=diffusion_form, 321 | diffusion_norm=diffusion_norm, 322 | ) 323 | 324 | t0, t1 = self.transport.check_interval( 325 | self.transport.train_eps, 326 | self.transport.sample_eps, 327 | diffusion_form=diffusion_form, 328 | sde=True, 329 | eval=True, 330 | reverse=False, 331 | last_step_size=last_step_size, 332 | ) 333 | 334 | _sde = sde( 335 | sde_drift, 336 | sde_diffusion, 337 | t0=t0, 338 | t1=t1, 339 | num_steps=num_steps, 340 | sampler_type=sampling_method 341 | ) 342 | 343 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) 344 | 345 | 346 | def _sample(init, model, **model_kwargs): 347 | xs = _sde.sample(init, model, **model_kwargs) 348 | ts = th.ones(init.size(0), device=init.device) * t1 349 | x = last_step_fn(xs[-1], ts, model, **model_kwargs) 350 | xs.append(x) 351 | 352 | assert len(xs) == num_steps, "Samples does not match the number of steps" 353 | 354 
| return xs 355 | 356 | return _sample 357 | 358 | def sample_ode( 359 | self, 360 | *, 361 | sampling_method="dopri5", 362 | num_steps=50, 363 | atol=1e-6, 364 | rtol=1e-3, 365 | reverse=False, 366 | ): 367 | """returns a sampling function with given ODE settings 368 | Args: 369 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 370 | - num_steps: 371 | - fixed solver (Euler, Heun): the actual number of integration steps performed 372 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation 373 | - atol: absolute error tolerance for the solver 374 | - rtol: relative error tolerance for the solver 375 | - reverse: whether solving the ODE in reverse (data to noise); default to False 376 | """ 377 | if reverse: 378 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) 379 | else: 380 | drift = self.drift 381 | 382 | t0, t1 = self.transport.check_interval( 383 | self.transport.train_eps, 384 | self.transport.sample_eps, 385 | sde=False, 386 | eval=True, 387 | reverse=reverse, 388 | last_step_size=0.0, 389 | ) 390 | 391 | _ode = ode( 392 | drift=drift, 393 | t0=t0, 394 | t1=t1, 395 | sampler_type=sampling_method, 396 | num_steps=num_steps, 397 | atol=atol, 398 | rtol=rtol, 399 | ) 400 | 401 | return _ode.sample 402 | 403 | def sample_ode_likelihood( 404 | self, 405 | *, 406 | sampling_method="dopri5", 407 | num_steps=50, 408 | atol=1e-6, 409 | rtol=1e-3, 410 | ): 411 | 412 | """returns a sampling function for calculating likelihood with given ODE settings 413 | Args: 414 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 415 | - num_steps: 416 | - fixed solver (Euler, Heun): the actual number of integration steps performed 417 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation 418 | - atol: absolute error tolerance for the solver 419 | - rtol: relative error tolerance for the solver 420 | """ 421 | def _likelihood_drift(x, t, model, **model_kwargs): 422 | x, _ = x 423 | eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 424 | t = th.ones_like(t) * (1 - t) 425 | with th.enable_grad(): 426 | x.requires_grad = True 427 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0] 428 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) 429 | drift = self.drift(x, t, model, **model_kwargs) 430 | return (-drift, logp_grad) 431 | 432 | t0, t1 = self.transport.check_interval( 433 | self.transport.train_eps, 434 | self.transport.sample_eps, 435 | sde=False, 436 | eval=True, 437 | reverse=False, 438 | last_step_size=0.0, 439 | ) 440 | 441 | _ode = ode( 442 | drift=_likelihood_drift, 443 | t0=t0, 444 | t1=t1, 445 | sampler_type=sampling_method, 446 | num_steps=num_steps, 447 | atol=atol, 448 | rtol=rtol, 449 | ) 450 | 451 | def _sample_fn(x, model, **model_kwargs): 452 | init_logp = th.zeros(x.size(0)).to(x) 453 | input = (x, init_logp) 454 | drift, delta_logp = _ode.sample(input, model, **model_kwargs) 455 | drift, delta_logp = drift[-1], delta_logp[-1] 456 | prior_logp = self.transport.prior_logp(drift) 457 | logp = prior_logp - delta_logp 458 | return logp, drift 459 | 460 | return _sample_fn -------------------------------------------------------------------------------- /fit/scheduler/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | class 
EasyDict: 4 | 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | def mean_flat(x): 13 | """ 14 | Take the mean over all non-batch dimensions. 15 | """ 16 | return th.mean(x, dim=list(range(1, len(x.size())))) 17 | 18 | def log_state(state): 19 | result = [] 20 | 21 | sorted_state = dict(sorted(state.items())) 22 | for key, value in sorted_state.items(): 23 | # Check if the value is an instance of a class 24 | if " (B, 1, N) 42 | spatial case (dit): 43 | x: (B, C, H, W) 44 | model_kwargs: {y: (B,)} 45 | mask: (B, C) -> (B, C, 1, 1) 46 | ''' 47 | mask = model_kwargs.get('mask', th.ones(x.shape[:2])) # (B, N) or (B, C) 48 | ratio = float(mask.shape[-1]) / th.count_nonzero(mask, dim=-1) # (B,) 49 | if len(x.shape) == 3: # sequential x: (B, N, C) 50 | mask = mask[..., None] # (B, N) -> (B, N, 1) 51 | elif len(x.shape) == 4: # spatial x: (B, C, H, W) 52 | mask = mask[..., None, None] # (B, C) -> (B, C, 1, 1) 53 | else: 54 | raise NotImplementedError 55 | return mask.to(x), ratio.to(x) 56 | -------------------------------------------------------------------------------- /fit/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | import re 6 | import os 7 | 8 | from safetensors.torch import load_file 9 | 10 | 11 | def create_npz_from_sample_folder(sample_dir, num=50_000): 12 | """ 13 | Builds a single .npz file from a folder of .png samples. 14 | """ 15 | samples = [] 16 | imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split('.')[0])) 17 | print(len(imgs)) 18 | assert len(imgs) >= num 19 | for i in tqdm(range(num), desc="Building .npz file from samples"): 20 | sample_pil = Image.open(f"{sample_dir}/{imgs[i]}") 21 | sample_np = np.asarray(sample_pil).astype(np.uint8) 22 | samples.append(sample_np) 23 | samples = np.stack(samples) 24 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 25 | npz_path = f"{sample_dir}.npz" 26 | np.savez(npz_path, arr_0=samples) 27 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 28 | return npz_path 29 | 30 | def init_from_ckpt( 31 | model, checkpoint_dir, ignore_keys=None, verbose=False 32 | ) -> None: 33 | if checkpoint_dir.endswith(".safetensors"): 34 | try: 35 | model_state_dict=load_file(checkpoint_dir) 36 | except: # 历史遗留问题,千万别删 37 | model_state_dict=torch.load(checkpoint_dir, map_location="cpu") 38 | else: 39 | model_state_dict=torch.load(checkpoint_dir, map_location="cpu") 40 | model_new_ckpt=dict() 41 | for i in model_state_dict.keys(): 42 | model_new_ckpt[i] = model_state_dict[i] 43 | keys = list(model_new_ckpt.keys()) 44 | for k in keys: 45 | if ignore_keys: 46 | for ik in ignore_keys: 47 | if re.match(ik, k): 48 | print("Deleting key {} from state_dict.".format(k)) 49 | del model_new_ckpt[k] 50 | missing, unexpected = model.load_state_dict(model_new_ckpt, strict=False) 51 | if verbose: 52 | print( 53 | f"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys" 54 | ) 55 | if len(missing) > 0: 56 | print(f"Missing Keys: {missing}") 57 | if len(unexpected) > 0: 58 | print(f"Unexpected Keys: {unexpected}") 59 | if verbose: 60 | print("") 61 | -------------------------------------------------------------------------------- /fit/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from 
torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import LambdaLR 3 | 4 | 5 | # coding=utf-8 6 | # Copyright 2023 The HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | """PyTorch optimization for diffusion models.""" 20 | 21 | import math 22 | from enum import Enum 23 | from typing import Optional, Union 24 | 25 | from torch.optim import Optimizer 26 | from torch.optim.lr_scheduler import LambdaLR 27 | 28 | 29 | class SchedulerType(Enum): 30 | LINEAR = "linear" 31 | COSINE = "cosine" 32 | COSINE_WITH_RESTARTS = "cosine_with_restarts" 33 | POLYNOMIAL = "polynomial" 34 | CONSTANT = "constant" 35 | CONSTANT_WITH_WARMUP = "constant_with_warmup" 36 | PIECEWISE_CONSTANT = "piecewise_constant" 37 | WARMDUP_STABLE_DECAY = "warmup_stable_decay" 38 | 39 | 40 | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): 41 | """ 42 | Create a schedule with a constant learning rate, using the learning rate set in optimizer. 43 | 44 | Args: 45 | optimizer ([`~torch.optim.Optimizer`]): 46 | The optimizer for which to schedule the learning rate. 47 | last_epoch (`int`, *optional*, defaults to -1): 48 | The index of the last epoch when resuming training. 49 | 50 | Return: 51 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 52 | """ 53 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 54 | 55 | def get_constant_schedule_with_warmup( 56 | optimizer: Optimizer, num_warmup_steps: int, div_factor: int = 1e-4, last_epoch: int = -1 57 | ): 58 | def lr_lambda(current_step): 59 | # 0,y0 step,y1 60 | #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1 61 | if current_step < num_warmup_steps: 62 | return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor 63 | return 1.0 64 | 65 | return LambdaLR(optimizer, lr_lambda, last_epoch) 66 | 67 | def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): 68 | """ 69 | Create a schedule with a constant learning rate, using the learning rate set in optimizer. 70 | 71 | Args: 72 | optimizer ([`~torch.optim.Optimizer`]): 73 | The optimizer for which to schedule the learning rate. 74 | step_rules (`string`): 75 | The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate 76 | if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 77 | steps and multiple 0.005 for the other steps. 78 | last_epoch (`int`, *optional*, defaults to -1): 79 | The index of the last epoch when resuming training. 80 | 81 | Return: 82 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
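        Example (illustrative; `optimizer` stands for any already-constructed `torch.optim.Optimizer`):
            scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")
            # multiplier is 1.0 while step < 10, 0.1 while step < 20, 0.01 while step < 30, then 0.005.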
83 | """ 84 | 85 | rules_dict = {} 86 | rule_list = step_rules.split(",") 87 | for rule_str in rule_list[:-1]: 88 | value_str, steps_str = rule_str.split(":") 89 | steps = int(steps_str) 90 | value = float(value_str) 91 | rules_dict[steps] = value 92 | last_lr_multiple = float(rule_list[-1]) 93 | 94 | def create_rules_function(rules_dict, last_lr_multiple): 95 | def rule_func(steps: int) -> float: 96 | sorted_steps = sorted(rules_dict.keys()) 97 | for i, sorted_step in enumerate(sorted_steps): 98 | if steps < sorted_step: 99 | return rules_dict[sorted_steps[i]] 100 | return last_lr_multiple 101 | 102 | return rule_func 103 | 104 | rules_func = create_rules_function(rules_dict, last_lr_multiple) 105 | 106 | return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) 107 | 108 | 109 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 110 | """ 111 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 112 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 113 | 114 | Args: 115 | optimizer ([`~torch.optim.Optimizer`]): 116 | The optimizer for which to schedule the learning rate. 117 | num_warmup_steps (`int`): 118 | The number of steps for the warmup phase. 119 | num_training_steps (`int`): 120 | The total number of training steps. 121 | last_epoch (`int`, *optional*, defaults to -1): 122 | The index of the last epoch when resuming training. 123 | 124 | Return: 125 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 126 | """ 127 | 128 | def lr_lambda(current_step: int): 129 | if current_step < num_warmup_steps: 130 | return float(current_step) / float(max(1, num_warmup_steps)) 131 | return max( 132 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 133 | ) 134 | 135 | return LambdaLR(optimizer, lr_lambda, last_epoch) 136 | 137 | 138 | def get_cosine_schedule_with_warmup( 139 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 140 | ): 141 | """ 142 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 143 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 144 | initial lr set in the optimizer. 145 | 146 | Args: 147 | optimizer ([`~torch.optim.Optimizer`]): 148 | The optimizer for which to schedule the learning rate. 149 | num_warmup_steps (`int`): 150 | The number of steps for the warmup phase. 151 | num_training_steps (`int`): 152 | The total number of training steps. 153 | num_periods (`float`, *optional*, defaults to 0.5): 154 | The number of periods of the cosine function in a schedule (the default is to just decrease from the max 155 | value to 0 following a half-cosine). 156 | last_epoch (`int`, *optional*, defaults to -1): 157 | The index of the last epoch when resuming training. 158 | 159 | Return: 160 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
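        With the default `num_cycles=0.5` the post-warmup multiplier is
        0.5 * (1 + cos(pi * progress)), where progress runs from 0 to 1 over the remaining
        `num_training_steps - num_warmup_steps` steps, i.e. a single half-cosine decay to 0.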
161 | """ 162 | 163 | def lr_lambda(current_step): 164 | if current_step < num_warmup_steps: 165 | return float(current_step) / float(max(1, num_warmup_steps)) 166 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 167 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 168 | 169 | return LambdaLR(optimizer, lr_lambda, last_epoch) 170 | 171 | 172 | def get_cosine_with_hard_restarts_schedule_with_warmup( 173 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 174 | ): 175 | """ 176 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 177 | initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases 178 | linearly between 0 and the initial lr set in the optimizer. 179 | 180 | Args: 181 | optimizer ([`~torch.optim.Optimizer`]): 182 | The optimizer for which to schedule the learning rate. 183 | num_warmup_steps (`int`): 184 | The number of steps for the warmup phase. 185 | num_training_steps (`int`): 186 | The total number of training steps. 187 | num_cycles (`int`, *optional*, defaults to 1): 188 | The number of hard restarts to use. 189 | last_epoch (`int`, *optional*, defaults to -1): 190 | The index of the last epoch when resuming training. 191 | 192 | Return: 193 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 194 | """ 195 | 196 | def lr_lambda(current_step): 197 | if current_step < num_warmup_steps: 198 | return float(current_step) / float(max(1, num_warmup_steps)) 199 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 200 | if progress >= 1.0: 201 | return 0.0 202 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 203 | 204 | return LambdaLR(optimizer, lr_lambda, last_epoch) 205 | 206 | 207 | def get_polynomial_decay_schedule_with_warmup( 208 | optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 209 | ): 210 | """ 211 | Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the 212 | optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the 213 | initial lr set in the optimizer. 214 | 215 | Args: 216 | optimizer ([`~torch.optim.Optimizer`]): 217 | The optimizer for which to schedule the learning rate. 218 | num_warmup_steps (`int`): 219 | The number of steps for the warmup phase. 220 | num_training_steps (`int`): 221 | The total number of training steps. 222 | lr_end (`float`, *optional*, defaults to 1e-7): 223 | The end LR. 224 | power (`float`, *optional*, defaults to 1.0): 225 | Power factor. 226 | last_epoch (`int`, *optional*, defaults to -1): 227 | The index of the last epoch when resuming training. 228 | 229 | Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT 230 | implementation at 231 | https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 232 | 233 | Return: 234 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
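        Worked example (illustrative): with lr_init=1e-4 (taken from the optimizer), lr_end=1e-7,
        power=1.0 and no warmup, the returned multiplier is
        ((lr_init - lr_end) * (1 - progress) + lr_end) / lr_init, so the effective learning rate
        decays linearly from 1e-4 to 1e-7 over `num_training_steps` and stays at lr_end afterwards.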
235 | 236 | """ 237 | 238 | lr_init = optimizer.defaults["lr"] 239 | if not (lr_init > lr_end): 240 | raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") 241 | 242 | def lr_lambda(current_step: int): 243 | if current_step < num_warmup_steps: 244 | return float(current_step) / float(max(1, num_warmup_steps)) 245 | elif current_step > num_training_steps: 246 | return lr_end / lr_init # as LambdaLR multiplies by lr_init 247 | else: 248 | lr_range = lr_init - lr_end 249 | decay_steps = num_training_steps - num_warmup_steps 250 | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps 251 | decay = lr_range * pct_remaining**power + lr_end 252 | return decay / lr_init # as LambdaLR multiplies by lr_init 253 | 254 | return LambdaLR(optimizer, lr_lambda, last_epoch) 255 | 256 | 257 | def get_constant_schedule_with_warmup_and_decay( 258 | optimizer: Optimizer, num_warmup_steps: int, num_decay_steps: int, decay_T: int = 50000, div_factor: int = 1e-4, last_epoch: int = -1 259 | ): 260 | def lr_lambda(current_step): 261 | # 0,y0 step,y1 262 | #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1 263 | if current_step < num_warmup_steps: 264 | return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor 265 | if current_step > num_decay_steps: 266 | return 0.5 ** ((current_step - num_decay_steps) / decay_T) 267 | return 1.0 268 | 269 | return LambdaLR(optimizer, lr_lambda, last_epoch) 270 | 271 | TYPE_TO_SCHEDULER_FUNCTION = { 272 | SchedulerType.LINEAR: get_linear_schedule_with_warmup, 273 | SchedulerType.COSINE: get_cosine_schedule_with_warmup, 274 | SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, 275 | SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, 276 | SchedulerType.CONSTANT: get_constant_schedule, 277 | SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, 278 | SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, 279 | SchedulerType.WARMDUP_STABLE_DECAY: get_constant_schedule_with_warmup_and_decay 280 | } 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | def get_scheduler( 289 | name: Union[str, SchedulerType], 290 | optimizer: Optimizer, 291 | step_rules: Optional[str] = None, 292 | num_warmup_steps: Optional[int] = None, 293 | num_decay_steps: Optional[int] = None, 294 | num_training_steps: Optional[int] = None, 295 | num_cycles: int = 1, 296 | decay_T: Optional[int] = 50000, 297 | power: float = 1.0, 298 | last_epoch: int = -1, 299 | ): 300 | """ 301 | Unified API to get any scheduler from its name. 302 | 303 | Args: 304 | name (`str` or `SchedulerType`): 305 | The name of the scheduler to use. 306 | optimizer (`torch.optim.Optimizer`): 307 | The optimizer that will be used during training. 308 | step_rules (`str`, *optional*): 309 | A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. 310 | num_warmup_steps (`int`, *optional*): 311 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 312 | optional), the function will raise an error if it's unset and the scheduler type requires it. 313 | num_decay_steps (`int`, *optional*): 314 | The number of decay steps to do. This is not required by all schedulers (hence the argument being 315 | optional), the function will raise an error if it's unset and the scheduler type requires it. 316 | num_training_steps (`int``, *optional*): 317 | The number of training steps to do. 
This is not required by all schedulers (hence the argument being 318 | optional), the function will raise an error if it's unset and the scheduler type requires it. 319 | num_cycles (`int`, *optional*): 320 | The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. 321 | power (`float`, *optional*, defaults to 1.0): 322 | Power factor. See `POLYNOMIAL` scheduler 323 | decay_T (`int`, *optional*, defaults to 50000): 324 | Power factor. See `POLYNOMIAL` scheduler 325 | last_epoch (`int`, *optional*, defaults to -1): 326 | The index of the last epoch when resuming training. 327 | """ 328 | name = SchedulerType(name) 329 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 330 | if name == SchedulerType.CONSTANT: 331 | return schedule_func(optimizer, last_epoch=last_epoch) 332 | 333 | if name == SchedulerType.PIECEWISE_CONSTANT: 334 | return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) 335 | 336 | # All other schedulers require `num_warmup_steps` 337 | if num_warmup_steps is None: 338 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 339 | 340 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 341 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) 342 | 343 | if name == SchedulerType.WARMDUP_STABLE_DECAY: 344 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_decay_steps=num_decay_steps, decay_T=decay_T, last_epoch=last_epoch) 345 | 346 | # All other schedulers require `num_training_steps` 347 | if num_training_steps is None: 348 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 349 | 350 | if name == SchedulerType.COSINE_WITH_RESTARTS: 351 | return schedule_func( 352 | optimizer, 353 | num_warmup_steps=num_warmup_steps, 354 | num_training_steps=num_training_steps, 355 | num_cycles=num_cycles, 356 | last_epoch=last_epoch, 357 | ) 358 | 359 | if name == SchedulerType.POLYNOMIAL: 360 | return schedule_func( 361 | optimizer, 362 | num_warmup_steps=num_warmup_steps, 363 | num_training_steps=num_training_steps, 364 | power=power, 365 | last_epoch=last_epoch, 366 | ) 367 | 368 | return schedule_func( 369 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch 370 | ) 371 | -------------------------------------------------------------------------------- /fit/utils/sit_eval_utils.py: -------------------------------------------------------------------------------- 1 | def none_or_str(value): 2 | if value == 'None': 3 | return None 4 | return value 5 | 6 | def parse_sde_args(parser): 7 | group = parser.add_argument_group("SDE arguments") 8 | group.add_argument("--sde-sampling-method", type=str, default="Euler", choices=["Euler", "Heun"]) 9 | group.add_argument("--diffusion-form", type=str, default="sigma", \ 10 | choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\ 11 | help="form of diffusion coefficient in the SDE") 12 | group.add_argument("--diffusion-norm", type=float, default=1.0) 13 | group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\ 14 | help="form of last step taken in the SDE") 15 | group.add_argument("--last-step-size", type=float, default=0.04, \ 16 | help="size of the last step taken") 17 | 18 | def parse_ode_args(parser): 19 | group = parser.add_argument_group("ODE arguments") 20 | group.add_argument("--ode-sampling-method", type=str, default="dopri5", help="blackbox 
ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq") 21 | group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance") 22 | group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") 23 | group.add_argument("--reverse", action="store_true") 24 | group.add_argument("--likelihood", action="store_true") 25 | 26 | # ode solvers: 27 | # - Adaptive-step: 28 | # - dopri8 Runge-Kutta 7(8) of Dormand-Prince-Shampine 29 | # - dopri5 Runge-Kutta 4(5) of Dormand-Prince [default]. 30 | # - bosh3 Runge-Kutta 2(3) of Bogacki-Shampine 31 | # - adaptive_heun Runge-Kutta 1(2) 32 | # - Fixed-step: 33 | # - euler Euler method. 34 | # - midpoint Midpoint method. 35 | # - rk4 Fourth-order Runge-Kutta with 3/8 rule. 36 | # - explicit_adams Explicit Adams. 37 | # - implicit_adams Implicit Adams.
-------------------------------------------------------------------------------- /fit/utils/utils.py: --------------------------------------------------------------------------------
1 | import importlib 2 | from collections import OrderedDict 3 | import torch 4 | 5 | def get_obj_from_str(string, reload=False, invalidate_cache=True): 6 | module, cls = string.rsplit(".", 1) 7 | if invalidate_cache: 8 | importlib.invalidate_caches() 9 | if reload: 10 | module_imp = importlib.import_module(module) 11 | importlib.reload(module_imp) 12 | return getattr(importlib.import_module(module, package=None), cls) 13 | 14 | 15 | def instantiate_from_config(config): 16 | if "target" not in config: 17 | if config == "__is_first_stage__": 18 | return None 19 | elif config == "__is_unconditional__": 20 | return None 21 | raise KeyError("Expected key `target` to instantiate.") 22 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 23 | 24 | 25 | @torch.no_grad() 26 | def update_ema(ema_model, model, decay=0.9999): 27 | """ 28 | Step the EMA model towards the current model.
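    Equivalently, for every parameter this performs ema = decay * ema + (1 - decay) * param
    in place; with the default decay=0.9999 the EMA averages over a horizon of roughly
    1 / (1 - decay) = 10,000 recent steps.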
29 | """ 30 | if hasattr(model, 'module'): 31 | model = model.module 32 | if hasattr(ema_model, 'module'): 33 | ema_model = ema_model.module 34 | ema_params = OrderedDict(ema_model.named_parameters()) 35 | model_params = OrderedDict(model.named_parameters()) 36 | 37 | for name, param in model_params.items(): 38 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 39 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 40 | 41 | 42 | 43 | def default(val, d): 44 | if exists(val): 45 | return val 46 | return d() if isfunction(d) else d 47 | 48 |
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested 2 | transformers>=4.44.2 # The development team is working on version 4.44.2 3 | accelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested 4 | sentencepiece>=0.2.0 # T5 used 5 | numpy==1.26.0 6 | streamlit>=1.38.0 # For streamlit web demo 7 | imageio==2.34.2 # For diffusers inference export video 8 | imageio-ffmpeg==0.5.1 # For diffusers inference export video 9 | openai>=1.42.0 # For prompt refiner 10 | moviepy==1.0.3 # For export video 11 | pillow==9.5.0 12 | timm==0.6.13 13 | safetensors==0.4.5 14 | einops
-------------------------------------------------------------------------------- /sample_fit_ddp.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Samples a large number of images from a pre-trained DiT model using DDP. 9 | Subsequently saves a .npz file that can be used to compute FID and other 10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations 11 | 12 | For a simple single-GPU/CPU sampling script, see sample.py. 13 | """ 14 | import os 15 | import math 16 | import torch 17 | import argparse 18 | import numpy as np 19 | import torch.distributed as dist 20 | import re 21 | 22 | from omegaconf import OmegaConf 23 | from tqdm import tqdm 24 | from PIL import Image 25 | from diffusers.models import AutoencoderKL 26 | from fit.scheduler.improved_diffusion import create_diffusion 27 | from fit.utils.eval_utils import create_npz_from_sample_folder, init_from_ckpt 28 | from fit.utils.utils import instantiate_from_config 29 | 30 | 31 | def ntk_scaled_init(head_dim, base=10000, alpha=8): 32 | #The method is just these two lines 33 | dim_h = head_dim // 2 # for x and y 34 | base = base * alpha ** (dim_h / (dim_h-2)) #Base change formula 35 | return base 36 | 37 | def main(args): 38 | """ 39 | Run sampling. 40 | """ 41 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 42 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU.
sample.py supports CPU-only usage" 43 | torch.set_grad_enabled(False) 44 | 45 | # Setup DDP: 46 | dist.init_process_group("nccl") 47 | rank = dist.get_rank() 48 | device = rank % torch.cuda.device_count() 49 | seed = args.global_seed * dist.get_world_size() + rank 50 | torch.manual_seed(seed) 51 | torch.cuda.set_device(device) 52 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 53 | 54 | if args.mixed == "fp32": 55 | weight_dtype = torch.float32 56 | elif args.mixed == "bf16": 57 | weight_dtype = torch.bfloat16 58 | 59 | if args.cfgdir == "": 60 | args.cfgdir = os.path.join(args.ckpt.split("/")[0], args.ckpt.split("/")[1], "configs/config.yaml") 61 | print("config dir: ",args.cfgdir) 62 | config = OmegaConf.load(args.cfgdir) 63 | config_diffusion = config.diffusion 64 | 65 | 66 | H, W = args.image_height // 8, args.image_width // 8 67 | patch_size = config_diffusion.network_config.params.patch_size 68 | n_patch_h, n_patch_w = H // patch_size, W // patch_size 69 | 70 | if args.interpolation != 'no': 71 | # sqrt(256) or sqrt(512), we set max PE length for inference, in fact some PE has been seen in the training stage. 72 | ori_max_pe_len = int(config_diffusion.network_config.params.context_size ** 0.5) 73 | if args.interpolation == 'linear': # 这个就是positional index interpolation,原来叫normal,现在叫linear 74 | config_diffusion.network_config.params['custom_freqs'] = 'linear' 75 | elif args.interpolation == 'dynntk': # 这个就是ntk-aware 76 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware' 77 | elif args.interpolation == 'partntk': # 这个就是ntk-by-parts 78 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-by-parts' 79 | elif args.interpolation == 'yarn': 80 | config_diffusion.network_config.params['custom_freqs'] = 'yarn' 81 | else: 82 | raise NotImplementedError 83 | config_diffusion.network_config.params['max_pe_len_h'] = n_patch_h 84 | config_diffusion.network_config.params['max_pe_len_w'] = n_patch_w 85 | config_diffusion.network_config.params['decouple'] = args.decouple 86 | config_diffusion.network_config.params['ori_max_pe_len'] = int(ori_max_pe_len) 87 | 88 | else: # there is no need to do interpolation! 
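# Illustrative launch command (paths and values are placeholders, not taken from the repo's docs):
#   torchrun --nproc_per_node=8 sample_fit_ddp.py \
#       --ckpt <path/to/checkpoint.safetensors> --cfgdir <path/to/config.yaml> \
#       --image-height 320 --image-width 320 --interpolation yarn --cfg-scale 1.5
# --interpolation selects the RoPE extension scheme handled above; --image-height/--image-width
# determine H, W and hence the n_patch_h x n_patch_w token grid used for sampling.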
89 | pass 90 | 91 | model = instantiate_from_config(config_diffusion.network_config).to(device, dtype=weight_dtype) 92 | init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True) 93 | model.eval() # important 94 | 95 | # prepare first stage model 96 | if args.vae_decoder == 'sd-ft-mse': 97 | vae_model = 'stabilityai/sd-vae-ft-mse' 98 | elif args.vae_decoder == 'sd-ft-ema': 99 | vae_model = 'stabilityai/sd-vae-ft-ema' 100 | vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(device, dtype=weight_dtype) 101 | vae.eval() # important 102 | 103 | 104 | config_diffusion.improved_diffusion.timestep_respacing = str(args.num_sampling_steps) 105 | diffusion = create_diffusion(**OmegaConf.to_container(config_diffusion.improved_diffusion)) 106 | 107 | 108 | workdir_name = 'official_fit' 109 | folder_name = f'{args.ckpt.split("/")[-1].split(".")[0]}' 110 | 111 | sample_folder_dir = f"{args.sample_dir}/{workdir_name}/{folder_name}" 112 | if rank == 0: 113 | os.makedirs(os.path.join(args.sample_dir, workdir_name), exist_ok=True) 114 | os.makedirs(sample_folder_dir, exist_ok=True) 115 | print(f"Saving .png samples at {sample_folder_dir}") 116 | dist.barrier() 117 | args.cfg_scale = float(args.cfg_scale) 118 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" 119 | using_cfg = args.cfg_scale > 1.0 120 | 121 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 122 | n = args.per_proc_batch_size 123 | global_batch_size = n * dist.get_world_size() 124 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 125 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) 126 | if rank == 0: 127 | print(f"Total number of images that will be sampled: {total_samples}") 128 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" 129 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 130 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" 131 | iterations = int(samples_needed_this_gpu // n) 132 | pbar = range(iterations) 133 | pbar = tqdm(pbar) if rank == 0 else pbar 134 | total = 0 135 | index = 0 136 | all_images = [] 137 | while len(all_images) * n < int(args.num_fid_samples): 138 | print(device, "device: ", index, flush=True) 139 | index+=1 140 | # Sample inputs: 141 | z = torch.randn( 142 | (n, (patch_size**2)*model.in_channels, n_patch_h*n_patch_w) 143 | ).to(device=device, dtype=weight_dtype) 144 | y = torch.randint(0, args.num_classes, (n,), device=device) 145 | 146 | # prepare for x 147 | grid_h = torch.arange(n_patch_h, dtype=torch.float32) 148 | grid_w = torch.arange(n_patch_w, dtype=torch.float32) 149 | grid = torch.meshgrid(grid_w, grid_h, indexing='xy') 150 | grid = torch.cat( 151 | [grid[0].reshape(1,-1), grid[1].reshape(1,-1)], dim=0 152 | ).repeat(n,1,1).to(device=device, dtype=weight_dtype) 153 | mask = torch.ones(n, n_patch_h*n_patch_w).to(device=device, dtype=weight_dtype) 154 | size = torch.tensor((n_patch_h, n_patch_w)).repeat(n,1).to(device=device, dtype=torch.long) 155 | size = size[:, None, :] 156 | 157 | 158 | 159 | # Setup classifier-free guidance: 160 | if using_cfg: 161 | z = torch.cat([z, z], 0) # (B, patch_size**2 * C, N) -> (2B, patch_size**2 * C, N) 162 | y_null = torch.tensor([1000] * n, device=device) 163 | y = torch.cat([y, y_null], 0) # (B,) -> (2B, ) 
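# Label id 1000 (== the default num_classes) serves as the "null"/unconditional class for
# classifier-free guidance, so stacking [y, y_null] lets a single batched forward pass return both
# the conditional and unconditional predictions that forward_with_cfg combines using cfg_scale.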
164 | grid = torch.cat([grid, grid], 0) # (B, 2, N) -> (2B, 2, N) 165 | mask = torch.cat([mask, mask], 0) # (B, N) -> (2B, N) 166 | model_kwargs = dict(y=y, grid=grid.long(), mask=mask, size=size, cfg_scale=args.cfg_scale) 167 | sample_fn = model.forward_with_cfg 168 | else: 169 | model_kwargs = dict(y=y, grid=grid.long(), mask=mask, size=size) 170 | sample_fn = model.forward 171 | 172 | # Sample images: 173 | samples = diffusion.p_sample_loop( 174 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 175 | ) 176 | if using_cfg: 177 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 178 | 179 | samples = samples[..., : n_patch_h*n_patch_w] 180 | samples = model.unpatchify(samples, (H, W)) 181 | samples = vae.decode(samples / vae.config.scaling_factor).sample 182 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous() 183 | 184 | # gather samples 185 | gathered_samples = [torch.zeros_like(samples) for _ in range(dist.get_world_size())] 186 | dist.all_gather(gathered_samples, samples) # gather not supported with NCCL 187 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 188 | torch.cuda.empty_cache() 189 | # Save samples to disk as individual .png files 190 | for i, sample in enumerate(samples.cpu().numpy()): 191 | index = i * dist.get_world_size() + rank + total 192 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 193 | total += global_batch_size 194 | if rank == 0: 195 | pbar.update() 196 | 197 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 198 | dist.barrier() 199 | if rank == 0: 200 | arr = np.concatenate(all_images, axis=0) 201 | arr = arr[: int(args.num_fid_samples)] 202 | npz_path = f"{sample_folder_dir}.npz" 203 | np.savez(npz_path, arr_0=arr) 204 | print(f"Saved .npz file to {npz_path} [shape={arr.shape}].") 205 | print("Done.") 206 | dist.barrier() 207 | dist.destroy_process_group() 208 | 209 | if __name__ == "__main__": 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument("--cfgdir", type=str, default="") 212 | parser.add_argument("--ckpt", type=str, default="") 213 | parser.add_argument("--sample-dir", type=str, default="workdir/eval") 214 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 215 | parser.add_argument("--num-fid-samples", type=int, default=50_000) 216 | parser.add_argument("--image-height", type=int, default=256) 217 | parser.add_argument("--image-width", type=int, default=256) 218 | parser.add_argument("--num-classes", type=int, default=1000) 219 | parser.add_argument("--vae-decoder", type=str, choices=['sd-ft-mse', 'sd-ft-ema'], default='sd-ft-ema') 220 | parser.add_argument("--cfg-scale", type=str, default='1.5') 221 | parser.add_argument("--num-sampling-steps", type=int, default=250) 222 | parser.add_argument("--global-seed", type=int, default=0) 223 | parser.add_argument("--interpolation", type=str, choices=['no', 'linear', 'yarn', 'dynntk', 'partntk'], default='no') # interpolation 224 | parser.add_argument("--decouple", default=False, action="store_true") # interpolation 225 | parser.add_argument("--tf32", action='store_true', default=True) 226 | parser.add_argument("--mixed", type=str, default="fp32") 227 | args = parser.parse_args() 228 | main(args) 229 | -------------------------------------------------------------------------------- /sample_fitv2_ddp.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Samples a large number of images from a pre-trained DiT model using DDP. 9 | Subsequently saves a .npz file that can be used to compute FID and other 10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations 11 | 12 | For a simple single-GPU/CPU sampling script, see sample.py. 13 | """ 14 | import os 15 | import sys 16 | import math 17 | import torch 18 | import argparse 19 | import numpy as np 20 | import torch.distributed as dist 21 | import re 22 | 23 | from omegaconf import OmegaConf 24 | from tqdm import tqdm 25 | from PIL import Image 26 | from diffusers.models import AutoencoderKL 27 | from fit.scheduler.transport import create_transport, Sampler 28 | from fit.utils.eval_utils import create_npz_from_sample_folder, init_from_ckpt 29 | from fit.utils.utils import instantiate_from_config 30 | from fit.utils.sit_eval_utils import parse_sde_args, parse_ode_args 31 | 32 | def ntk_scaled_init(head_dim, base=10000, alpha=8): 33 | #The method is just these two lines 34 | dim_h = head_dim // 2 # for x and y 35 | base = base * alpha ** (dim_h / (dim_h-2)) #Base change formula 36 | return base 37 | 38 | def main(args): 39 | """ 40 | Run sampling. 41 | """ 42 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 43 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 44 | torch.set_grad_enabled(False) 45 | 46 | # Setup DDP: 47 | dist.init_process_group("nccl") 48 | rank = dist.get_rank() 49 | device = rank % torch.cuda.device_count() 50 | seed = args.global_seed * dist.get_world_size() + rank 51 | torch.manual_seed(seed) 52 | torch.cuda.set_device(device) 53 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 54 | 55 | if args.mixed == "fp32": 56 | weight_dtype = torch.float32 57 | elif args.mixed == "bf16": 58 | weight_dtype = torch.bfloat16 59 | 60 | if args.cfgdir == "": 61 | args.cfgdir = os.path.join(args.ckpt.split("/")[0], args.ckpt.split("/")[1], "configs/config.yaml") 62 | print("config dir: ",args.cfgdir) 63 | config = OmegaConf.load(args.cfgdir) 64 | config_diffusion = config.diffusion 65 | 66 | 67 | H, W = args.image_height // 8, args.image_width // 8 68 | patch_size = config_diffusion.network_config.params.patch_size 69 | n_patch_h, n_patch_w = H // patch_size, W // patch_size 70 | 71 | if args.interpolation != 'no': 72 | if args.interpolation == 'linear': # 这个就是positional index interpolation,原来叫normal,现在叫linear 73 | config_diffusion.network_config.params['custom_freqs'] = 'linear' 74 | elif args.interpolation == 'dynntk': # 这个就是ntk-aware 75 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware' 76 | elif args.interpolation == 'ntkpro1': 77 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware-pro1' 78 | elif args.interpolation == 'ntkpro2': 79 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware-pro2' 80 | elif args.interpolation == 'partntk': # 这个就是ntk-by-parts 81 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-by-parts' 82 | elif args.interpolation == 'yarn': 83 | 
config_diffusion.network_config.params['custom_freqs'] = 'yarn' 84 | else: 85 | raise NotImplementedError 86 | config_diffusion.network_config.params['max_pe_len_h'] = n_patch_h 87 | config_diffusion.network_config.params['max_pe_len_w'] = n_patch_w 88 | config_diffusion.network_config.params['decouple'] = args.decouple 89 | config_diffusion.network_config.params['ori_max_pe_len'] = int(args.ori_max_pe_len) 90 | 91 | config_diffusion.network_config.params['online_rope'] = False 92 | 93 | else: # there is no need to do interpolation! 94 | config_diffusion.network_config.params['custom_freqs'] = 'normal' 95 | config_diffusion.network_config.params['online_rope'] = False 96 | 97 | 98 | 99 | model = instantiate_from_config(config_diffusion.network_config).to(device, dtype=weight_dtype) 100 | init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True) 101 | model.eval() # important 102 | 103 | # prepare first stage model 104 | if args.vae_decoder == 'sd-ft-mse': 105 | vae_model = 'stabilityai/sd-vae-ft-mse' 106 | elif args.vae_decoder == 'sd-ft-ema': 107 | vae_model = 'stabilityai/sd-vae-ft-ema' 108 | vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(device, dtype=weight_dtype) 109 | vae.eval() # important 110 | 111 | 112 | # prepare transport 113 | transport = create_transport(**OmegaConf.to_container(config_diffusion.transport)) # default: velocity; 114 | sampler = Sampler(transport) 115 | sampler_mode = args.sampler_mode 116 | if sampler_mode == "ODE": 117 | if args.likelihood: 118 | assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" 119 | sample_fn = sampler.sample_ode_likelihood( 120 | sampling_method=args.ode_sampling_method, 121 | num_steps=args.num_sampling_steps, 122 | atol=args.atol, 123 | rtol=args.rtol, 124 | ) 125 | else: 126 | sample_fn = sampler.sample_ode( 127 | sampling_method=args.ode_sampling_method, 128 | num_steps=args.num_sampling_steps, 129 | atol=args.atol, 130 | rtol=args.rtol, 131 | reverse=args.reverse 132 | ) 133 | elif sampler_mode == "SDE": 134 | sample_fn = sampler.sample_sde( 135 | sampling_method=args.sde_sampling_method, 136 | diffusion_form=args.diffusion_form, 137 | diffusion_norm=args.diffusion_norm, 138 | last_step=args.last_step, 139 | last_step_size=args.last_step_size, 140 | num_steps=args.num_sampling_steps, 141 | ) 142 | else: 143 | raise NotImplementedError 144 | 145 | 146 | 147 | workdir_name = 'official_fit' 148 | folder_name = f'{args.ckpt.split("/")[-1].split(".")[0]}' 149 | 150 | 151 | sample_folder_dir = f"{args.sample_dir}/{workdir_name}/{folder_name}" 152 | if rank == 0: 153 | os.makedirs(os.path.join(args.sample_dir, workdir_name), exist_ok=True) 154 | os.makedirs(sample_folder_dir, exist_ok=True) 155 | print(f"Saving .png samples at {sample_folder_dir}") 156 | dist.barrier() 157 | args.cfg_scale = float(args.cfg_scale) 158 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0" 159 | using_cfg = args.cfg_scale > 1.0 160 | 161 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 162 | n = args.per_proc_batch_size 163 | global_batch_size = n * dist.get_world_size() 164 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 165 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) 166 | if rank == 0: 167 | print(f"Total number of images that will be sampled: {total_samples}") 168 | assert total_samples % 
dist.get_world_size() == 0, "total_samples must be divisible by world_size" 169 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 170 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" 171 | iterations = int(samples_needed_this_gpu // n) 172 | pbar = range(iterations) 173 | pbar = tqdm(pbar) if rank == 0 else pbar 174 | total = 0 175 | index = 0 176 | all_images = [] 177 | while len(all_images) * n < int(args.num_fid_samples): 178 | print(device, "device: ", index, flush=True) 179 | index+=1 180 | # Sample inputs: 181 | z = torch.randn( 182 | (n, n_patch_h*n_patch_w, (patch_size**2)*model.in_channels) 183 | ).to(device=device, dtype=weight_dtype) 184 | y = torch.randint(0, args.num_classes, (n,), device=device) 185 | 186 | # prepare for x 187 | grid_h = torch.arange(n_patch_h, dtype=torch.long) 188 | grid_w = torch.arange(n_patch_w, dtype=torch.long) 189 | grid = torch.meshgrid(grid_w, grid_h, indexing='xy') 190 | grid = torch.cat( 191 | [grid[0].reshape(1,-1), grid[1].reshape(1,-1)], dim=0 192 | ).repeat(n,1,1).to(device=device, dtype=torch.long) 193 | mask = torch.ones(n, n_patch_h*n_patch_w).to(device=device, dtype=weight_dtype) 194 | size = torch.tensor((n_patch_h, n_patch_w)).repeat(n,1).to(device=device, dtype=torch.long) 195 | size = size[:, None, :] 196 | # Setup classifier-free guidance: 197 | if using_cfg: 198 | z = torch.cat([z, z], 0) # (B, N, patch_size**2 * C) -> (2B, N, patch_size**2 * C) 199 | y_null = torch.tensor([1000] * n, device=device) 200 | y = torch.cat([y, y_null], 0) # (B,) -> (2B, ) 201 | grid = torch.cat([grid, grid], 0) # (B, 2, N) -> (2B, 2, N) 202 | mask = torch.cat([mask, mask], 0) # (B, N) -> (2B, N) 203 | size = torch.cat([size, size], 0) 204 | model_kwargs = dict(y=y, grid=grid, mask=mask, size=size, cfg_scale=args.cfg_scale, scale_pow=args.scale_pow) 205 | model_fn = model.forward_with_cfg 206 | else: 207 | model_kwargs = dict(y=y, grid=grid, mask=mask, size=size) 208 | model_fn = model.forward 209 | 210 | 211 | # Sample images: 212 | samples = sample_fn(z, model_fn, **model_kwargs)[-1] 213 | if using_cfg: 214 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 215 | 216 | samples = samples[..., : n_patch_h*n_patch_w] 217 | samples = model.unpatchify(samples, (H, W)) 218 | samples = vae.decode(samples / vae.config.scaling_factor).sample 219 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous() 220 | 221 | # gather samples 222 | gathered_samples = [torch.zeros_like(samples) for _ in range(dist.get_world_size())] 223 | dist.all_gather(gathered_samples, samples) # gather not supported with NCCL 224 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 225 | # Save samples to disk as individual .png files 226 | for i, sample in enumerate(samples.cpu().numpy()): 227 | index = i * dist.get_world_size() + rank + total 228 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 229 | total += global_batch_size 230 | if rank == 0: 231 | pbar.update() 232 | 233 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 234 | dist.barrier() 235 | if rank == 0: 236 | import time 237 | time.sleep(20) 238 | arr = np.concatenate(all_images, axis=0) 239 | arr = arr[: int(args.num_fid_samples)] 240 | npz_path = f"{sample_folder_dir}.npz" 241 | np.savez(npz_path, arr_0=arr) 242 | print(f"Saved .npz file to {npz_path} [shape={arr.shape}].") 243 | 
print("Done.") 244 | dist.barrier() 245 | dist.destroy_process_group() 246 | 247 | if __name__ == "__main__": 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument("--cfgdir", type=str, default="") 250 | parser.add_argument("--ckpt", type=str, default="") 251 | parser.add_argument("--sample-dir", type=str, default="workdir/eval") 252 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 253 | parser.add_argument("--num-fid-samples", type=int, default=50_000) 254 | parser.add_argument("--image-height", type=int, default=256) 255 | parser.add_argument("--image-width", type=int, default=256) 256 | parser.add_argument("--num-classes", type=int, default=1000) 257 | parser.add_argument("--vae-decoder", type=str, choices=['sd-ft-mse', 'sd-ft-ema'], default='sd-ft-ema') 258 | parser.add_argument("--cfg-scale", type=str, default='1.0') 259 | parser.add_argument("--scale-pow", type=float, default=0.0) 260 | parser.add_argument("--num-sampling-steps", type=int, default=250) 261 | parser.add_argument("--global-seed", type=int, default=0) 262 | parser.add_argument("--interpolation", type=str, choices=['no', 'linear', 'yarn', 'dynntk', 'partntk', 'ntkpro1', 'ntkpro2'], default='no') # interpolation 263 | parser.add_argument("--ori-max-pe-len", default=None, type=int) 264 | parser.add_argument("--decouple", default=False, action="store_true") # interpolation 265 | parser.add_argument("--sampler-mode", default='SDE', choices=['SDE', 'ODE']) 266 | parser.add_argument("--tf32", action='store_true', default=True) 267 | parser.add_argument("--mixed", type=str, default="fp32") 268 | parser.add_argument("--save-images", action='store_true', default=False) 269 | parse_ode_args(parser) 270 | parse_sde_args(parser) 271 | args = parser.parse_args() 272 | main(args) 273 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | setup( 5 | name='fit', 6 | version='0.0.1', 7 | description='', 8 | packages=find_packages(), 9 | install_requires=[ 10 | 'torch', 11 | 'numpy', 12 | 'tqdm', 13 | ], 14 | ) -------------------------------------------------------------------------------- /tools/download_in1k_latents_1024.sh: -------------------------------------------------------------------------------- 1 | mkdir datasets 2 | cd datasets 3 | mkdir imagenet1k_latents_1024_sd_vae_ft_ema 4 | cd imagenet1k_latents_1024_sd_vae_ft_ema 5 | 6 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/from_16_to_1024.tar.gz.part_aa?download=true" -O from_16_to_1024.tar.gz.part_aa 7 | 8 | tar -xzvf from_16_to_1024.tar.gz.part_aa 9 | 10 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/from_16_to_1024.tar.gz.part_ab?download=true" -O from_16_to_1024.tar.gz.part_ab 11 | 12 | tar -xzvf from_16_to_1024.tar.gz.part_aa 13 | 14 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/from_16_to_1024.tar.gz.part_ac?download=true" -O from_16_to_1024.tar.gz.part_ac 15 | 16 | tar -xzvf from_16_to_1024.tar.gz.part_aa 17 | 18 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/greater_than_1024_crop.tar.gz?download=true" -O greater_than_1024_crop.tar.gz 19 | 20 | tar -xzvf greater_than_1024_crop.tar.gz 21 | 22 | wget -c 
"https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/greater_than_1024_resize.tar.gz?download=true" -O greater_than_1024_resize.tar.gz 23 | 24 | tar -xzvf greater_than_1024_resize.tar.gz 25 | 26 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/less_than_16.tar.gz?download=true" -O less_than_16.tar.gz 27 | 28 | tar -xzvf less_than_16.tar.gz 29 | -------------------------------------------------------------------------------- /tools/download_in1k_latents_256.sh: -------------------------------------------------------------------------------- 1 | mkdir datasets 2 | cd datasets 3 | mkdir imagenet1k_latents_256_sd_vae_ft_ema 4 | cd imagenet1k_latents_256_sd_vae_ft_ema 5 | 6 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/from_16_to_256.tar.gz?download=true" -O from_16_to_256.tar.gz 7 | 8 | tar -xzvf from_16_to_256.tar.gz 9 | 10 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/greater_than_256_crop.tar.gz?download=true" -O greater_than_256_crop.tar.gz 11 | 12 | tar -xzvf greater_than_256_crop.tar.gz 13 | 14 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/greater_than_256_resize.tar.gz?download=true" -O greater_than_256_resize.tar.gz 15 | 16 | tar -xzvf greater_than_256_resize.tar.gz 17 | 18 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/less_than_16.tar.gz?download=true" -O less_than_16.tar.gz 19 | 20 | tar -xzvf less_than_16.tar.gz 21 | -------------------------------------------------------------------------------- /tools/train_fit_xl.sh: -------------------------------------------------------------------------------- 1 | JOB_NAME = "train_fit_xl" 2 | NNODES = 1 3 | GPUS_PER_NODE = 8 4 | MASTER_ADDR = "localhost" 5 | export MASTER_PORT=60563 6 | 7 | CMD=" \ 8 | projects/FiT/FiT/train_fit.py \ 9 | --project_name ${JOB_NAME} \ 10 | --main_project_name image_generation \ 11 | --seed 0 \ 12 | --scale_lr \ 13 | --allow_tf32 \ 14 | --resume_from_checkpoint latest \ 15 | --workdir workdir/fit_xl \ 16 | --cfgdir "projects/FiT/FiT/configs/fit/config_fit_xl.yaml" \ 17 | --use_ema 18 | " 19 | TORCHLAUNCHER="torchrun \ 20 | --nnodes $NNODES \ 21 | --nproc_per_node $GPUS_PER_NODE \ 22 | --rdzv_id $RANDOM \ 23 | --rdzv_backend c10d \ 24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ 25 | " 26 | bash -c "$TORCHLAUNCHER $CMD" -------------------------------------------------------------------------------- /tools/train_fitv2_3B.sh: -------------------------------------------------------------------------------- 1 | JOB_NAME = "train_fitv2_3B" 2 | NNODES = 1 3 | GPUS_PER_NODE = 8 4 | MASTER_ADDR = "localhost" 5 | export MASTER_PORT=60563 6 | 7 | CMD=" \ 8 | projects/FiT/FiT/train_fitv2.py \ 9 | --project_name ${JOB_NAME} \ 10 | --main_project_name image_generation \ 11 | --seed 0 \ 12 | --scale_lr \ 13 | --allow_tf32 \ 14 | --resume_from_checkpoint latest \ 15 | --workdir workdir/fitv2_3B \ 16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_3B.yaml" \ 17 | --use_ema 18 | " 19 | TORCHLAUNCHER="torchrun \ 20 | --nnodes $NNODES \ 21 | --nproc_per_node $GPUS_PER_NODE \ 22 | --rdzv_id $RANDOM \ 23 | --rdzv_backend c10d \ 24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ 25 | " 26 | bash -c "$TORCHLAUNCHER $CMD" -------------------------------------------------------------------------------- 
/tools/train_fitv2_hr_3B.sh: -------------------------------------------------------------------------------- 1 | JOB_NAME="train_fitv2_hr_3B" 2 | NNODES=1 3 | GPUS_PER_NODE=8 4 | MASTER_ADDR="localhost" 5 | export MASTER_PORT=60563 6 | 7 | CMD=" \ 8 | projects/FiT/FiT/train_fitv2.py \ 9 | --project_name ${JOB_NAME} \ 10 | --main_project_name image_generation \ 11 | --seed 0 \ 12 | --scale_lr \ 13 | --allow_tf32 \ 14 | --resume_from_checkpoint latest \ 15 | --workdir workdir/fitv2_hr_3B \ 16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_hr_3B.yaml" \ 17 | --use_ema 18 | " 19 | TORCHLAUNCHER="torchrun \ 20 | --nnodes $NNODES \ 21 | --nproc_per_node $GPUS_PER_NODE \ 22 | --rdzv_id $RANDOM \ 23 | --rdzv_backend c10d \ 24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ 25 | " 26 | bash -c "$TORCHLAUNCHER $CMD" -------------------------------------------------------------------------------- /tools/train_fitv2_hr_xl.sh: -------------------------------------------------------------------------------- 1 | JOB_NAME="train_fitv2_hr_xl" 2 | NNODES=1 3 | GPUS_PER_NODE=8 4 | MASTER_ADDR="localhost" 5 | export MASTER_PORT=60563 6 | 7 | CMD=" \ 8 | projects/FiT/FiT/train_fitv2.py \ 9 | --project_name ${JOB_NAME} \ 10 | --main_project_name image_generation \ 11 | --seed 0 \ 12 | --scale_lr \ 13 | --allow_tf32 \ 14 | --resume_from_checkpoint latest \ 15 | --workdir workdir/fitv2_hr_xl \ 16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_hr_xl.yaml" \ 17 | --use_ema 18 | " 19 | TORCHLAUNCHER="torchrun \ 20 | --nnodes $NNODES \ 21 | --nproc_per_node $GPUS_PER_NODE \ 22 | --rdzv_id $RANDOM \ 23 | --rdzv_backend c10d \ 24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ 25 | " 26 | bash -c "$TORCHLAUNCHER $CMD" -------------------------------------------------------------------------------- /tools/train_fitv2_xl.sh: -------------------------------------------------------------------------------- 1 | JOB_NAME="train_fitv2_xl" 2 | NNODES=1 3 | GPUS_PER_NODE=8 4 | MASTER_ADDR="localhost" 5 | export MASTER_PORT=60563 6 | 7 | CMD=" \ 8 | projects/FiT/FiT/train_fitv2.py \ 9 | --project_name ${JOB_NAME} \ 10 | --main_project_name image_generation \ 11 | --seed 0 \ 12 | --scale_lr \ 13 | --allow_tf32 \ 14 | --resume_from_checkpoint latest \ 15 | --workdir workdir/fitv2_xl \ 16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_xl.yaml" \ 17 | --use_ema 18 | " 19 | TORCHLAUNCHER="torchrun \ 20 | --nnodes $NNODES \ 21 | --nproc_per_node $GPUS_PER_NODE \ 22 | --rdzv_id $RANDOM \ 23 | --rdzv_backend c10d \ 24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ 25 | " 26 | bash -c "$TORCHLAUNCHER $CMD" -------------------------------------------------------------------------------- /train_fitv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import datetime 5 | import time 6 | import torchvision 7 | import wandb 8 | import logging 9 | import math 10 | import shutil 11 | import accelerate 12 | import torch 13 | import torch.utils.checkpoint 14 | import diffusers 15 | import numpy as np 16 | import torch.nn.functional as F 17 | 18 | from functools import partial 19 | from torch.cuda import amp 20 | from omegaconf import OmegaConf 21 | from accelerate import Accelerator, skip_first_batches 22 | from accelerate.logging import get_logger 23 | from accelerate.state import AcceleratorState 24 | from accelerate.utils import ProjectConfiguration, set_seed, save, FullyShardedDataParallelPlugin 25 | from 
tqdm.auto import tqdm 26 | from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler 27 | from safetensors import safe_open 28 | from safetensors.torch import load_file 29 | from copy import deepcopy 30 | from einops import rearrange 31 | from fit.scheduler.transport import create_transport 32 | from fit.utils.utils import ( 33 | instantiate_from_config, 34 | default, 35 | get_obj_from_str, 36 | update_ema, 37 | 38 | ) 39 | from fit.utils.eval_utils import init_from_ckpt 40 | from fit.utils.lr_scheduler import get_scheduler 41 | 42 | logger = get_logger(__name__, log_level="INFO") 43 | 44 | # For Omegaconf Tuple 45 | def resolve_tuple(*args): 46 | return tuple(args) 47 | OmegaConf.register_new_resolver("tuple", resolve_tuple) 48 | 49 | def parse_args(): 50 | parser = argparse.ArgumentParser(description="Argument.") 51 | parser.add_argument( 52 | "--project_name", 53 | type=str, 54 | const=True, 55 | default="", 56 | nargs="?", 57 | help="if set, the logdir will be: workdir/project_name", 58 | ) 59 | parser.add_argument( 60 | "--main_project_name", 61 | type=str, 62 | default="image_generation", 63 | ) 64 | parser.add_argument( 65 | "--workdir", 66 | type=str, 67 | default="workdir", 68 | help="workdir", 69 | ) 70 | parser.add_argument( # if resuming, this can be left unset; the config will be loaded from the resume dir 71 | "--cfgdir", 72 | nargs="*", 73 | help="paths to base configs. Loaded from left-to-right. " 74 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 75 | default=list(), 76 | ) 77 | parser.add_argument( 78 | "-s", 79 | "--seed", 80 | type=int, 81 | default=0, 82 | help="seed for seed_everything", 83 | ) 84 | parser.add_argument( 85 | "--resume_from_checkpoint", 86 | type=str, 87 | default='latest', 88 | help=( 89 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 90 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 91 | ), 92 | ) 93 | parser.add_argument( 94 | "--load_model_from_checkpoint", 95 | type=str, 96 | default=None, 97 | help=( 98 | "Whether training should be loaded from a pretrained model checkpoint. " 99 | "Or you can set diffusion.pretrained_model_path in Config for loading!!!" 100 | ), 101 | ) 102 | parser.add_argument( 103 | "--scale_lr", 104 | action="store_true", 105 | default=False, 106 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 107 | ) 108 | parser.add_argument( 109 | "--allow_tf32", 110 | action="store_true", 111 | help=( 112 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 113 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 114 | ), 115 | ) 116 | parser.add_argument( 117 | "--use_ema", 118 | action="store_true", 119 | default=True, 120 | help="Whether to use EMA model." 121 | ) 122 | parser.add_argument( 123 | "--ema_decay", 124 | type=float, 125 | default=0.9999, 126 | help="The decay rate for ema." 
127 | ) 128 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 129 | args = parser.parse_args() 130 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 131 | if env_local_rank != -1 and env_local_rank != args.local_rank: 132 | args.local_rank = env_local_rank 133 | return args 134 | 135 | 136 | def main(): 137 | args = parse_args() 138 | 139 | datenow = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 140 | project_name = None 141 | workdir = None 142 | workdirnow = None 143 | cfgdir = None 144 | ckptdir = None 145 | logging_dir = None 146 | imagedir = None 147 | 148 | if args.project_name: 149 | project_name = args.project_name 150 | if os.path.exists(os.path.join(args.workdir, project_name)): #open resume 151 | workdir=os.path.join(args.workdir, project_name) 152 | else: # new a workdir 153 | workdir = os.path.join(args.workdir, project_name) 154 | # if accelerator.is_main_process: 155 | os.makedirs(workdir, exist_ok=True) 156 | workdirnow = workdir 157 | 158 | cfgdir = os.path.join(workdirnow, "configs") 159 | ckptdir = os.path.join(workdirnow, "checkpoints") 160 | logging_dir = os.path.join(workdirnow, "logs") 161 | imagedir = os.path.join(workdirnow, "images") 162 | 163 | # if accelerator.is_main_process: 164 | os.makedirs(cfgdir, exist_ok=True) 165 | os.makedirs(ckptdir, exist_ok=True) 166 | os.makedirs(logging_dir, exist_ok=True) 167 | os.makedirs(imagedir, exist_ok=True) 168 | if args.cfgdir: 169 | load_cfgdir = args.cfgdir 170 | 171 | # setup config 172 | configs_list = load_cfgdir # read config from a config dir 173 | configs = [OmegaConf.load(cfg) for cfg in configs_list] 174 | config = OmegaConf.merge(*configs) 175 | accelerate_cfg = config.accelerate 176 | diffusion_cfg = config.diffusion 177 | data_cfg = config.data 178 | grad_accu_steps = accelerate_cfg.gradient_accumulation_steps 179 | 180 | train_strtg_cfg = getattr(config, 'training_strategy', None) 181 | if train_strtg_cfg != None: 182 | warp_pos_idx = hasattr(train_strtg_cfg, 'warp_pos_idx') 183 | if warp_pos_idx: 184 | warp_pos_idx_fn = partial(warp_pos_idx_from_grid, 185 | shift=train_strtg_cfg.warp_pos_idx.shift, 186 | scale=train_strtg_cfg.warp_pos_idx.scale, 187 | max_len=train_strtg_cfg.warp_pos_idx.max_len 188 | ) 189 | 190 | accelerator_project_cfg = ProjectConfiguration(project_dir=workdirnow, logging_dir=logging_dir) 191 | 192 | if getattr(accelerate_cfg, 'fsdp_config', None) != None: 193 | import functools 194 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 195 | BackwardPrefetch, CPUOffload, ShardingStrategy, MixedPrecision, StateDictType, FullStateDictConfig, FullOptimStateDictConfig, 196 | ) 197 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy 198 | fsdp_cfg = accelerate_cfg.fsdp_config 199 | if accelerate_cfg.mixed_precision == "fp16": 200 | dtype = torch.float16 201 | elif accelerate_cfg.mixed_precision == "bf16": 202 | dtype = torch.bfloat16 203 | else: 204 | dtype = torch.float32 205 | fsdp_plugin = FullyShardedDataParallelPlugin( 206 | sharding_strategy = { 207 | 'FULL_SHARD': ShardingStrategy.FULL_SHARD, 208 | 'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP, 209 | 'NO_SHARD': ShardingStrategy.NO_SHARD, 210 | 'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD, 211 | 'HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2, 212 | }[fsdp_cfg.sharding_strategy], 213 | backward_prefetch = { 214 | 'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE, 215 | 'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST, 216 | 
}[fsdp_cfg.backward_prefetch], 217 | mixed_precision_policy = MixedPrecision( 218 | param_dtype=dtype, 219 | reduce_dtype=dtype, 220 | ), 221 | auto_wrap_policy = functools.partial( 222 | size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params 223 | ), 224 | cpu_offload = CPUOffload(offload_params=fsdp_cfg.cpu_offload), 225 | state_dict_type = { 226 | 'FULL_STATE_DICT': StateDictType.FULL_STATE_DICT, 227 | 'LOCAL_STATE_DICT': StateDictType.LOCAL_STATE_DICT, 228 | 'SHARDED_STATE_DICT': StateDictType.SHARDED_STATE_DICT 229 | }[fsdp_cfg.state_dict_type], 230 | state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True), 231 | optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), 232 | limit_all_gathers = fsdp_cfg.limit_all_gathers, # False 233 | use_orig_params = fsdp_cfg.use_orig_params, # True 234 | sync_module_states = fsdp_cfg.sync_module_states, #True 235 | forward_prefetch = fsdp_cfg.forward_prefetch, # False 236 | activation_checkpointing = fsdp_cfg.activation_checkpointing, # False 237 | ) 238 | else: 239 | fsdp_plugin = None 240 | accelerator = Accelerator( 241 | gradient_accumulation_steps=grad_accu_steps, 242 | mixed_precision=accelerate_cfg.mixed_precision, 243 | fsdp_plugin=fsdp_plugin, 244 | log_with=getattr(accelerate_cfg, 'logger', 'wandb'), 245 | project_config=accelerator_project_cfg, 246 | ) 247 | device = accelerator.device 248 | 249 | # Make one log on every process with the configuration for debugging. 250 | logging.basicConfig( 251 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 252 | datefmt="%m/%d/%Y %H:%M:%S", 253 | level=logging.INFO, 254 | ) 255 | logger.info(accelerator.state, main_process_only=False) 256 | 257 | if accelerator.is_local_main_process: 258 | File_handler = logging.FileHandler(os.path.join(logging_dir, project_name+"_"+datenow+".log"), encoding="utf-8") 259 | File_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) 260 | File_handler.setLevel(logging.INFO) 261 | logger.logger.addHandler(File_handler) 262 | 263 | diffusers.utils.logging.set_verbosity_warning() 264 | diffusers.utils.logging.set_verbosity_info() 265 | else: 266 | diffusers.utils.logging.set_verbosity_error() 267 | diffusers.utils.logging.set_verbosity_error() 268 | 269 | if args.seed is not None: 270 | set_seed(args.seed) 271 | 272 | if args.allow_tf32: # for A100 273 | torch.backends.cuda.matmul.allow_tf32 = True 274 | torch.backends.cudnn.allow_tf32 = True 275 | 276 | if args.scale_lr: 277 | learning_rate = ( 278 | accelerate_cfg.learning_rate * 279 | grad_accu_steps * 280 | data_cfg.params.train.loader.batch_size * # local batch size per device 281 | accelerator.num_processes / accelerate_cfg.learning_rate_base_batch_size # global batch size 282 | ) 283 | else: 284 | learning_rate = accelerate_cfg.learning_rate 285 | 286 | 287 | model = instantiate_from_config(diffusion_cfg.network_config).to(device=device) 288 | # update ema 289 | if args.use_ema: 290 | # ema_dtype = torch.float32 291 | if hasattr(model, 'module'): 292 | ema_model = deepcopy(model.module).to(device=device) 293 | else: 294 | ema_model = deepcopy(model).to(device=device) 295 | if getattr(diffusion_cfg, 'pretrain_config', None) != None: # transfer to larger reolution 296 | if getattr(diffusion_cfg.pretrain_config, 'ema_ckpt', None) != None: 297 | init_from_ckpt( 298 | ema_model, checkpoint_dir=diffusion_cfg.pretrain_config.ema_ckpt, 299 | ignore_keys=diffusion_cfg.pretrain_config.ignore_keys, 
verbose=True 300 | ) 301 | for p in ema_model.parameters(): 302 | p.requires_grad = False 303 | 304 | if args.use_ema: 305 | model = accelerator.prepare_model(model, device_placement=False) 306 | ema_model = accelerator.prepare_model(ema_model, device_placement=False) 307 | else: 308 | model = accelerator.prepare_model(model, device_placement=False) 309 | 310 | # In SiT, we use transport instead of diffusion 311 | transport = create_transport(**OmegaConf.to_container(diffusion_cfg.transport)) # default: velocity; 312 | # schedule_sampler = create_named_schedule_sampler() 313 | 314 | # Setup Dataloader 315 | total_batch_size = data_cfg.params.train.loader.batch_size * accelerator.num_processes * grad_accu_steps 316 | global_steps = 0 317 | if args.resume_from_checkpoint: 318 | # normal read with safety check 319 | if args.resume_from_checkpoint != "latest": 320 | resume_from_path = os.path.basename(args.resume_from_checkpoint) 321 | else: # Get the most recent checkpoint 322 | dirs = os.listdir(ckptdir) 323 | dirs = [d for d in dirs if d.startswith("checkpoint")] 324 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 325 | resume_from_path = dirs[-1] if len(dirs) > 0 else None 326 | 327 | if resume_from_path is None: 328 | logger.info( 329 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 330 | ) 331 | args.resume_from_checkpoint = None 332 | else: 333 | global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps 334 | logger.info(f"Resuming from steps: {global_steps}") 335 | 336 | get_train_dataloader = instantiate_from_config(data_cfg) 337 | train_len = get_train_dataloader.train_len() 338 | train_dataloader = get_train_dataloader.train_dataloader( 339 | global_batch_size=total_batch_size, max_steps=accelerate_cfg.max_train_steps, 340 | resume_step=global_steps, seed=args.seed 341 | ) 342 | 343 | # Setup optimizer and lr_scheduler 344 | if accelerator.is_main_process: 345 | for name, param in model.named_parameters(): 346 | print(name, param.requires_grad) 347 | if getattr(diffusion_cfg, 'pretrain_config', None) != None: # transfer to larger reolution 348 | params = filter(lambda p: p.requires_grad, model.parameters()) 349 | else: 350 | params = list(model.parameters()) 351 | optimizer_cfg = default( 352 | accelerate_cfg.optimizer, {"target": "torch.optim.AdamW"} 353 | ) 354 | optimizer = get_obj_from_str(optimizer_cfg["target"])( 355 | params, lr=learning_rate, **optimizer_cfg.get("params", dict()) 356 | ) 357 | lr_scheduler = get_scheduler( 358 | accelerate_cfg.lr_scheduler, 359 | optimizer=optimizer, 360 | num_warmup_steps=accelerate_cfg.lr_warmup_steps, 361 | num_training_steps=accelerate_cfg.max_train_steps, 362 | ) 363 | 364 | # Prepare Accelerate 365 | optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 366 | optimizer, train_dataloader, lr_scheduler 367 | ) 368 | 369 | # We need to initialize the trackers we use, and also store our configuration. 370 | # The trackers initializes automatically on the main process. 371 | if accelerator.is_main_process and getattr(accelerate_cfg, 'logger', 'wandb') != None: 372 | os.environ["WANDB_DIR"] = os.path.join(os.getcwd(), workdirnow) 373 | accelerator.init_trackers( 374 | project_name=args.main_project_name, 375 | config=config, 376 | init_kwargs={"wandb": {"group": args.project_name}} 377 | ) 378 | 379 | # Train! 
380 | logger.info("***** Running training *****") 381 | logger.info(f" Num examples = {train_len}") 382 | logger.info(f" Instantaneous batch size per device = {data_cfg.params.train.loader.batch_size}") 383 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 384 | logger.info(f" Learning rate = {learning_rate}") 385 | logger.info(f" Gradient Accumulation steps = {grad_accu_steps}") 386 | logger.info(f" Total optimization steps = {accelerate_cfg.max_train_steps}") 387 | logger.info(f" Current optimization steps = {global_steps}") 388 | logger.info(f" Train dataloader length = {len(train_dataloader)} ") 389 | logger.info(f" Training Mixed-Precision = {accelerate_cfg.mixed_precision}") 390 | 391 | # Potentially load in the weights and states from a previous save 392 | if args.resume_from_checkpoint: 393 | # normal read with safety check 394 | error_times=0 395 | while(True): 396 | if error_times >= 100: 397 | raise 398 | try: 399 | logger.info(f"Resuming from checkpoint {resume_from_path}") 400 | accelerator.load_state(os.path.join(ckptdir, resume_from_path)) 401 | break 402 | except (RuntimeError, Exception) as err: 403 | error_times+=1 404 | if accelerator.is_local_main_process: 405 | logger.warning(err) 406 | logger.warning(f"Failed to resume from checkpoint {resume_from_path}") 407 | shutil.rmtree(os.path.join(ckptdir, resume_from_path)) 408 | else: 409 | time.sleep(2) 410 | 411 | # save config 412 | OmegaConf.save(config=config, f=os.path.join(cfgdir, "config.yaml")) 413 | 414 | # Only show the progress bar once on each machine. 415 | progress_bar = tqdm( 416 | range(0, accelerate_cfg.max_train_steps), 417 | disable = not accelerator.is_main_process 418 | ) 419 | progress_bar.set_description("Optim Steps") 420 | progress_bar.update(global_steps) 421 | 422 | if args.use_ema: 423 | # ema_model = ema_model.to(ema_dtype) 424 | ema_model.eval() 425 | # Training Loop 426 | model.train() 427 | train_loss = 0.0 428 | for step, batch in enumerate(train_dataloader, start=global_steps): 429 | for batch_key in batch.keys(): 430 | if not isinstance(batch[batch_key], list): 431 | batch[batch_key] = batch[batch_key].to(device=device) 432 | x = batch['feature'] # (B, N, C) 433 | grid = batch['grid'] # (B, 2, N) 434 | mask = batch['mask'] # (B, N) 435 | y = batch['label'] # (B, 1) 436 | size = batch['size'] # (B, N_pack, 2), order: h, w. When pack is not used, N_pack=1. 437 | with accelerator.accumulate(model): 438 | # save memory for x, grid, mask 439 | N_batch = int(torch.max(torch.sum(size[..., 0] * size[..., 1], dim=-1))) 440 | x, grid, mask = x[:, : N_batch], grid[..., : N_batch], mask[:, : N_batch] 441 | 442 | # prepare other parameters 443 | y = y.squeeze(dim=-1).to(torch.int) 444 | model_kwargs = dict(y=y, grid=grid.long(), mask=mask, size=size) 445 | # forward model and compute loss 446 | with accelerator.autocast(): 447 | loss_dict = transport.training_losses(model, x, model_kwargs) 448 | loss = loss_dict["loss"].mean() 449 | # Backpropagate 450 | optimizer.zero_grad() 451 | accelerator.backward(loss) 452 | if accelerator.sync_gradients and accelerate_cfg.max_grad_norm > 0.: 453 | all_norm = accelerator.clip_grad_norm_( 454 | model.parameters(), accelerate_cfg.max_grad_norm 455 | ) 456 | optimizer.step() 457 | lr_scheduler.step() 458 | # Gather the losses across all processes for logging (if we use distributed training). 
459 | avg_loss = accelerator.gather(loss.repeat(data_cfg.params.train.loader.batch_size)).mean() 460 | train_loss += avg_loss.item() / grad_accu_steps 461 | 462 | # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation 463 | if accelerator.sync_gradients: 464 | if args.use_ema: 465 | # update_ema(ema_model, deepcopy(model).type(ema_dtype), args.ema_decay) 466 | update_ema(ema_model, model, args.ema_decay) 467 | 468 | progress_bar.update(1) 469 | global_steps += 1 470 | if getattr(accelerate_cfg, 'logger', 'wandb') != None: 471 | accelerator.log({"train_loss": train_loss}, step=global_steps) 472 | accelerator.log({"lr": lr_scheduler.get_last_lr()[0]}, step=global_steps) 473 | if accelerate_cfg.max_grad_norm != 0.0: 474 | accelerator.log({"grad_norm": all_norm.item()}, step=global_steps) 475 | train_loss = 0.0 476 | if global_steps % accelerate_cfg.checkpointing_steps == 0: 477 | if accelerate_cfg.checkpoints_total_limit is not None: 478 | checkpoints = os.listdir(ckptdir) 479 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 480 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 481 | 482 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 483 | if accelerator.is_main_process and len(checkpoints) >= accelerate_cfg.checkpoints_total_limit: 484 | num_to_remove = len(checkpoints) - accelerate_cfg.checkpoints_total_limit + 1 485 | removing_checkpoints = checkpoints[0:num_to_remove] 486 | 487 | logger.info( 488 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 489 | ) 490 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 491 | 492 | for removing_checkpoint in removing_checkpoints: 493 | removing_checkpoint = os.path.join(ckptdir, removing_checkpoint) 494 | shutil.rmtree(removing_checkpoint) 495 | 496 | save_path = os.path.join(ckptdir, f"checkpoint-{global_steps}") 497 | if accelerator.is_main_process: 498 | os.makedirs(save_path) 499 | accelerator.wait_for_everyone() 500 | accelerator.save_state(save_path) 501 | logger.info(f"Saved state to {save_path}") 502 | accelerator.wait_for_everyone() 503 | 504 | if global_steps in accelerate_cfg.checkpointing_steps_list: 505 | save_path = os.path.join(ckptdir, f"save-checkpoint-{global_steps}") 506 | accelerator.wait_for_everyone() 507 | accelerator.save_state(save_path) 508 | logger.info(f"Saved state to {save_path}") 509 | accelerator.wait_for_everyone() 510 | 511 | logs = {"step_loss": loss.detach().item(), 512 | "lr": lr_scheduler.get_last_lr()[0]} 513 | progress_bar.set_postfix(**logs) 514 | if global_steps % accelerate_cfg.logging_steps == 0: 515 | if accelerator.is_main_process: 516 | logger.info("step="+str(global_steps)+" / total_step="+str(accelerate_cfg.max_train_steps)+", step_loss="+str(logs["step_loss"])+', lr='+str(logs["lr"])) 517 | 518 | if global_steps >= accelerate_cfg.max_train_steps: 519 | logger.info(f'global step ({global_steps}) >= max_train_steps ({accelerate_cfg.max_train_steps}), stop training!!!') 520 | break 521 | accelerator.wait_for_everyone() 522 | accelerator.end_training() 523 | 524 | if __name__ == "__main__": 525 | main() --------------------------------------------------------------------------------
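For reference, a minimal launch sketch for the DDP sampling script sample_fitv2_ddp.py dumped above, assuming a single 8-GPU node; the checkpoint and config paths are hypothetical placeholders, while every flag used here is defined in that script's argparse block:

# sketch only: paths and GPU count are placeholders
torchrun --nnodes 1 --nproc_per_node 8 sample_fitv2_ddp.py \
 --ckpt workdir/fitv2_xl/checkpoints/checkpoint-400000/model.safetensors \
 --cfgdir workdir/fitv2_xl/configs/config.yaml \
 --image-height 256 --image-width 256 \
 --cfg-scale 1.5 --scale-pow 0.0 \
 --sampler-mode SDE --num-sampling-steps 250 \
 --per-proc-batch-size 32 --num-fid-samples 50000
# The resulting <sample_folder>.npz can then be scored with the ADM evaluation suite linked in the script's docstring.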