├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── figure.png
├── configs
│   ├── fit
│   │   └── config_fit_xl.yaml
│   └── fitv2
│       ├── config_fitv2_3B.yaml
│       ├── config_fitv2_hr_3B.yaml
│       ├── config_fitv2_hr_xl.yaml
│       └── config_fitv2_xl.yaml
├── fit
│   ├── data
│   │   ├── in1k_dataset.py
│   │   └── in1k_latent_dataset.py
│   ├── model
│   │   ├── fit_model.py
│   │   ├── modules.py
│   │   ├── norms.py
│   │   ├── rope.py
│   │   └── utils.py
│   ├── scheduler
│   │   ├── improved_diffusion
│   │   │   ├── __init__.py
│   │   │   ├── diffusion_utils.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── respace.py
│   │   │   └── timestep_sampler.py
│   │   └── transport
│   │       ├── __init__.py
│   │       ├── integrators.py
│   │       ├── path.py
│   │       ├── transport.py
│   │       └── utils.py
│   └── utils
│       ├── eval_utils.py
│       ├── lr_scheduler.py
│       ├── sit_eval_utils.py
│       └── utils.py
├── requirements.txt
├── sample_fit_ddp.py
├── sample_fitv2_ddp.py
├── setup.py
├── tools
│   ├── download_in1k_latents_1024.sh
│   ├── download_in1k_latents_256.sh
│   ├── train_fit_xl.sh
│   ├── train_fitv2_3B.sh
│   ├── train_fitv2_hr_3B.sh
│   ├── train_fitv2_hr_xl.sh
│   └── train_fitv2_xl.sh
├── train_fit.py
└── train_fitv2.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.pyc
3 | *.ckpt
4 | *.pth
5 | *.egg-info
6 | *.mp4
7 | *.png
8 | *.jpg
9 | *.npz
10 | *.safetensors
11 | /outputs
12 | /checkpoints
13 | /tmp
14 | /logs
15 | /dataset
16 | /workdir
17 | /debug
18 | *__pycache__*
19 | batch*
20 | /.vscode
21 | /.ipynb_checkpoints
22 | *.ipynb_checkpoints
23 | /test
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # FiT: Flexible Vision Transformer for Diffusion Model
4 |
5 |
6 | 📃 FiT Paper •
7 | 📦 FiT Checkpoint •
8 | 📃 FiTv2 Paper •
9 | 📦 FiTv2 Checkpoint
10 |
11 |
12 | This is the official repo, containing the PyTorch model definitions, pre-trained weights, and sampling code for our flexible vision transformer (FiT).
13 | FiT is a diffusion-transformer-based model that can generate images at unrestricted resolutions and aspect ratios.
14 |
15 | The core features will include:
16 | * Pre-trained class-conditional FiT-XL-2-16 (2000K) model weights trained on ImageNet ($H\times W \le 256\times256$).
17 | * Pre-trained class-conditional FiTv2-XL-2-16 (2000K) and FiTv2-3B-2-16 (1000K) model weights trained on ImageNet ($H\times W \le 256\times256$).
18 | * High-resolution fine-tuned FiTv2-XL-2-32 (400K) and FiTv2-3B-2-32 (200K) model weights trained on ImageNet ($H\times W \le 512\times512$).
19 | * PyTorch sampling code for running pre-trained FiT and FiTv2 models to generate images at unrestricted resolutions and aspect ratios.
20 |
21 | Why do we need FiT?
22 | * 🧐 Nature is infinitely resolution-free. Like Sora, FiT is trained on images with unrestricted resolutions and aspect ratios, and it is capable of generating images at unrestricted resolutions and aspect ratios.
23 | * 🤗 FiT exhibits remarkable flexibility in resolution extrapolation generation.
24 |
25 | Stay tuned for this project! 😆
26 |
27 |
28 | ## Setup
29 | First, download and setup the repo:
30 | ```
31 | git clone https://github.com/whlzy/FiT.git
32 | cd FiT
33 | ```
34 | ## Installation
35 | ```
36 | conda create -n fit_env python=3.10
37 | pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118
38 | pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118
39 | pip install -r requirements.txt
40 | pip install -e .
41 | ```
42 |
43 | ## Sample
44 | ### Basic Sampling
45 |
46 | The base models are trained with images whose $H\times W\leqslant 256\times256$.
47 | Our FiTv1-XL/2 and FiTv2-XL/2 models are trained with a batch size of 256 for 2000K steps.
48 | Our FiTv2-3B/2 model is trained with a batch size of 256 for 1000K steps.
49 |
50 | The pre-trained FiT models can be downloaded directly from Hugging Face:
51 | | FiT Model | Checkpoint | FID-256x256 | FID-320x320 | Model Size | GFLOPs |
52 | |---------------|------------|---------|-----------------|------------| ------ |
53 | | [FiTv1-XL/2](https://huggingface.co/InfImagine/FiT/tree/main/FiTv1_xl) | [CKPT](https://huggingface.co/InfImagine/FiT/blob/main/FiTv1_xl/model_ema.bin) | 4.21 | 5.11 | 824M | 153 |
54 | | [FiTv2-XL/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_XL) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL/model_ema.safetensors?download=true) | 2.26 | 3.55 | 671M | 147 |
55 | | [FiTv2-3B/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_3B) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B/model_ema.bin?download=true) | 2.15 | 3.22 | 3B | 653 |
56 |
57 |
58 | #### Downloading the checkpoints
59 | Downloading via wget:
60 | ```
61 | mkdir checkpoints
62 |
63 | wget -c "https://huggingface.co/InfImagine/FiT/resolve/main/FiTv1_xl/model_ema.bin" -O checkpoints/fitv1_xl.bin
64 |
65 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL/model_ema.safetensors?download=true" -O checkpoints/fitv2_xl.safetensors
66 |
67 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B/model_ema.bin?download=true" -O checkpoints/fitv2_3B.bin
68 | ```
69 |
70 | #### Sampling 256x256 Images
71 | Sampling with FiTv1-XL/2 for $256\times 256$ Images:
72 | ```
73 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fit_ddp.py --num-fid-samples 50000 --cfgdir configs/fit/config_fit_xl.yaml --ckpt checkpoints/fitv1_xl.bin --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32
74 | ```
75 | Sampling with FiTv2-XL/2 for $256\times 256$ Images:
76 | ```
77 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE
78 | ```
79 | Sampling with FiTv2-3B/2 for $256\times 256$ Images:
80 | ```
81 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_3B.yaml --ckpt checkpoints/fitv2_3B.bin --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE
82 | ```
83 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified.
84 |
85 |
86 | #### Sampling Images with arbitrary resolutions
87 | You can set *image-height* and *image-width* to any values you want, but you also need to specify the original maximum positional-embedding length (*ori-max-pe-len*) and the interpolation method.
88 | Some examples are shown below.
89 |
90 | Sampling with FiTv2-XL/2 for $160\times 320$ images:
91 | ```
92 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 160 --image-width 320 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple
93 | ```
94 | Sampling with FiTv2-XL/2 for $320\times 320$ images:
95 | ```
96 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 320 --image-width 320 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation ntkpro2 --decouple
97 | ```
98 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified.
99 |
100 |
101 | ### High-resolution Sampling
102 |
103 | For high-resolution image generation, we use images whose $H\times W \leqslant 512\times512$.
104 | Our FiTv2-XL/2 is fine-tuned with a batch size of 256 for 400K steps, while
105 | FiTv2-3B/2 is fine-tuned with a batch size of 256 for 200K steps.
106 |
107 | The high-resolution fine-tuned FiT models can be downloaded directly from Hugging Face:
108 | | FiT Model | Checkpoint | FID-512x512 | FID-320x640 | Model Size | GFLOPs |
109 | |---------------|------------|---------|-----------------|------------| ------ |
110 | | [FiTv2-HR-XL/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_XL_HRFT) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL_HRFT/model_ema.safetensors?download=true) | 2.90 | 4.87 | 671M | 147 |
111 | | [FiTv2-HR-3B/2](https://huggingface.co/InfImagine/FiTv2/tree/main/FiTv2_3B_HRFT) | [CKPT](https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B_HRFT/model_ema.safetensors?download=true) | 2.41 | 4.54 | 3B | 653 |
112 |
113 |
114 | #### Downloading
115 | Downloading via wget:
116 | ```
117 | mkdir checkpoints
118 |
119 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL_HRFT/model_ema.safetensors?download=true" -O checkpoints/fitv2_hr_xl.safetensors
120 |
121 | wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B_HRFT/model_ema.safetensors?download=true" -O checkpoints/fitv2_hr_3B.safetensors
122 | ```
123 |
124 | #### Sampling 512x512 Images
125 | Sampling with FiTv2-HR-XL/2 for $512\times 512$ Images:
126 | ```
127 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_xl.yaml --ckpt checkpoints/fitv2_hr_xl.safetensors --image-height 512 --image-width 512 --num-sampling-steps 250 --cfg-scale 1.65 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple
128 | ```
129 | Sampling with FiTv2-HR-3B/2 for $512\times 512$ Images:
130 | ```
131 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_3B.yaml --ckpt checkpoints/fitv2_hr_3B.safetensors --image-height 512 --image-width 512 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple
132 | ```
133 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified.
134 |
135 | #### Sampling Images with arbitrary resolutions
136 | Sampling with FiTv2-HR-XL/2 for $320\times 640$ images:
137 | ```
138 | python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_xl.yaml --ckpt checkpoints/fitv2_hr_xl.safetensors --image-height 320 --image-width 640 --num-sampling-steps 250 --cfg-scale 1.65 --global-seed 0 --per-proc-batch-size 32 --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple
139 | ```
140 | Note that *NUM_NODE*, *NUM_GPU* and *MASTER_PORT* need to be specified.
141 |
142 | ## Evaluations
143 | The sampling generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow
144 | evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and
145 | other metrics.
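
As a quick sanity check before running the evaluation suite, you can inspect the generated `.npz` locally. A minimal sketch, assuming the samples are stored under the `arr_0` key (the ADM/DiT convention) and a hypothetical output path:

```python
import numpy as np

# Load the sampled archive produced by the sampling script (path is a placeholder).
samples = np.load("samples/fitv2_xl_256x256_cfg1.5.npz")["arr_0"]
print(samples.shape, samples.dtype)  # expect (50000, 256, 256, 3), uint8
assert samples.dtype == np.uint8 and samples.ndim == 4
```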
146 |
147 |
148 |
149 | ## Flexible Imagenet Latent Datasets
150 |
151 | We use [SD-VAE-FT-EMA](https://huggingface.co/stabilityai/sd-vae-ft-ema) to encode an image into the latent codes.
152 |
153 | Thanks to our flexible training pipeline, we can handle images with arbitrary resolutions and aspect ratios,
154 | so we preprocess the ImageNet-1k dataset according to the original height and width of each image.
155 | Following convention, we set the patch size $p$ to $2$, and the downsampling scale $s$ of the VAE encoder is $8$,
156 | so an image with $H=256$ and $W=256$ leads to $\frac{H\times W}{p^2\times s^2} = 256$ tokens.
157 |
158 | For our pre-training, we set the maximum token length to $256$, which corresponds to an image resolution of $S=H=W=256$.
159 | For the high-resolution fine-tuning, the maximum token length is $1024$, which corresponds to an image resolution of $S=H=W=512$.
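
As a quick worked example of the token budget, the helper below (illustrative only, not part of the repo) computes the token count from the patch size and VAE downsampling scale:

```python
def num_tokens(height: int, width: int, patch_size: int = 2, vae_scale: int = 8) -> int:
    """Number of transformer tokens for an image at the given pixel resolution."""
    return (height * width) // (patch_size ** 2 * vae_scale ** 2)

print(num_tokens(256, 256))  # 256  -> pre-training token budget
print(num_tokens(512, 512))  # 1024 -> high-resolution fine-tuning token budget
print(num_tokens(160, 320))  # 200  -> fits within the 256-token budget
```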
160 | Given an input image $I\in \mathbb{R}^{3\times H \times W}$ and a target resolution $S=256$ (pre-training) or $S=512$ (high-resolution fine-tuning), the preprocessing is:
161 | ```
162 | If H > S and W > S:
163 | img_resize = Resize(I)
164 | latent_resize = VAE_Encode(img_resize)
165 | save(latent_resize)
166 | img_crop = CenterCrop(Resize(I))
167 | latent_crop = VAE_Encode(img_crop)
168 | save(latent_crop)
169 | else:
170 | img_resize = Resize(I)
171 | latent_resize = VAE_Encode(img_resize)
172 | save(latent_resize)
173 | ```
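
A minimal PyTorch sketch of this preprocessing, assuming the SD-VAE-FT-EMA encoder from `diffusers`. The repo's actual resizing logic (see `fit/data/in1k_dataset.py`) preserves the image area when it exceeds the budget, keeps smaller images near their native size, snaps shapes to multiples of $p\times s=16$, and stores extra metadata (grid, size, label, flip variants); here a simpler shortest-side resize and placeholder file names stand in for all of that:

```python
import torch
from PIL import Image
from torchvision import transforms
from diffusers.models import AutoencoderKL
from safetensors.torch import save_file

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map pixels to [-1, 1]
])

@torch.no_grad()
def encode(pil_image: Image.Image) -> torch.Tensor:
    x = to_tensor(pil_image).unsqueeze(0)          # (1, 3, H, W)
    return vae.encode(x).latent_dist.sample()[0]   # (4, H/8, W/8)

def preprocess(pil_image: Image.Image, target_size: int = 256) -> None:
    w, h = pil_image.size
    resized = transforms.Resize(target_size)(pil_image)  # shortest side -> target_size
    save_file({"feature": encode(resized)}, "latent_resize.safetensors")  # placeholder path
    if h > target_size and w > target_size:
        # large images additionally get a square center crop
        cropped = transforms.CenterCrop(target_size)(resized)
        save_file({"feature": encode(cropped)}, "latent_crop.safetensors")  # placeholder path
```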
174 |
175 | ### Dataset for Pretraining
176 |
177 | All the image latent codes with maximum token length $256$ can be downloaded from [here](https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/tree/main).
178 |
179 | ```
180 | bash tools/download_in1k_latents_256.sh
181 | ```
182 |
183 |
184 |
185 | ### Dataset for High-resolution Fine-tuning
186 | All the image latent codes with maximum token length $1024$ can be downloaded from [here](https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema).
187 |
188 |
189 | ```
190 | bash tools/download_in1k_latents_1024.sh
191 | ```
192 |
193 | ### Dataset Architecture
194 |
195 | - imagenet1k_latents_256_sd_vae_ft_ema
196 | - less_than_16
197 | - xxxxxxx.safetensors
198 | - xxxxxxx.safetensors
199 | - from_16_to_256
200 | - xxxxxxx.safetensors
201 | - xxxxxxx.safetensors
202 | - greater_than_256_crop
203 | - xxxxxxx.safetensors
204 | - xxxxxxx.safetensors
205 | - greater_than_256_resize
206 | - xxxxxxx.safetensors
207 | - xxxxxxx.safetensors
208 | - imagenet1k_latents_1024_sd_vae_ft_ema
209 | - less_than_16
210 | - xxxxxxx.safetensors
211 | - xxxxxxx.safetensors
212 | - from_16_to_1024
213 | - xxxxxxx.safetensors
214 | - xxxxxxx.safetensors
215 | - greater_than_1024_crop
216 | - xxxxxxx.safetensors
217 | - xxxxxxx.safetensors
218 | - greater_than_1024_resize
219 | - xxxxxxx.safetensors
220 | - xxxxxxx.safetensors
221 |
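Each `.safetensors` file packs the VAE latents together with the metadata that the loader expects. A minimal sketch of inspecting one file, with key names taken from `fit/data/in1k_latent_dataset.py` and a placeholder path:

```python
from safetensors.torch import load_file

data = load_file("datasets/imagenet1k_latents_256_sd_vae_ft_ema/from_16_to_256/xxxxxxx.safetensors")
print(data["feature"].shape)  # (2, h, w, c): two flip variants of the latent feature map
print(data["grid"].shape)     # (2, seq_len): height/width index of every token
print(data["size"])           # height and width metadata for this image
print(data["label"])          # ImageNet-1k class label
```
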
222 | ## Train
223 | You need to set the number of nodes and GPUs for your training.
224 |
225 | Train FiT and FiTv2 models:
226 | ```
227 | bash tools/train_fit_xl.sh
228 |
229 | bash tools/train_fitv2_xl.sh
230 |
231 | bash tools/train_fitv2_3B.sh
232 | ```
233 |
234 | High-resolution Fine-tuning:
235 | ```
236 | bash tools/train_fitv2_hr_xl.sh
237 |
238 | bash tools/train_fitv2_hr_3B.sh
239 | ```
240 |
241 |
242 | ## BibTeX
243 | ```bibtex
244 | @article{Lu2024FiT,
245 | title={FiT: Flexible Vision Transformer for Diffusion Model},
246 | author={Zeyu Lu and Zidong Wang and Di Huang and Chengyue Wu and Xihui Liu and Wanli Ouyang and Lei Bai},
247 | year={2024},
248 | journal={arXiv preprint arXiv:2402.12376},
249 | }
250 | ```
251 | ```bibtex
252 | @article{wang2024fitv2,
253 | title={Fitv2: Scalable and improved flexible vision transformer for diffusion model},
254 | author={Wang, ZiDong and Lu, Zeyu and Huang, Di and Zhou, Cai and Ouyang, Wanli and others},
255 | journal={arXiv preprint arXiv:2410.13925},
256 | year={2024}
257 | }
258 | ```
--------------------------------------------------------------------------------
/assets/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whlzy/FiT/b22b7b0ac4bf841242711f5c20703d072708270e/assets/figure.png
--------------------------------------------------------------------------------
/configs/fit/config_fit_xl.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema
3 | improved_diffusion:
4 | timestep_respacing: ''
5 | noise_schedule: linear
6 | use_kl: false
7 | sigma_small: false
8 | predict_xstart: false
9 | learn_sigma: true
10 | rescale_learned_sigmas: false
11 | diffusion_steps: 1000
12 | noise_scheduler:
13 | num_train_timesteps: 1000
14 | beta_start: 0.0001
15 | beta_end: 0.02
16 | beta_schedule: linear
17 | prediction_type: epsilon
18 | steps_offset: 0
19 | clip_sample: false
20 | network_config:
21 | target: fit.model.fit_model.FiT
22 | params:
23 | context_size: 256
24 | patch_size: 2
25 | in_channels: 4
26 | hidden_size: 1152
27 | depth: 28
28 | num_heads: 16
29 | mlp_ratio: 4.0
30 | class_dropout_prob: 0.1
31 | num_classes: 1000
32 | learn_sigma: true
33 | use_swiglu: true
34 | use_swiglu_large: true
35 | rel_pos_embed: rope
36 |
37 |
38 | data:
39 | target: fit.data.in1k_latent_dataset.INLatentLoader
40 | params:
41 | train:
42 | data_path: datasets/imagenet1k_latents_256_sd_vae_ft_ema
43 | target_len: 256
44 | random: 'resize'
45 | loader:
46 | batch_size: 32
47 | num_workers: 2
48 | shuffle: True
49 |
50 | accelerate:
51 | # others
52 | gradient_accumulation_steps: 1
53 | mixed_precision: 'bf16'
54 | # training step config
55 | num_train_epochs:
56 | max_train_steps: 2000000
57 | # optimizer config
58 | learning_rate: 1.0e-4
59 | learning_rate_base_batch_size: 256
60 | max_grad_norm: 1.0
61 | optimizer:
62 | target: torch.optim.AdamW
63 | params:
64 | betas: ${tuple:0.9, 0.999}
65 | weight_decay: 0 #1.0e-2
66 | eps: 1.0e-8
67 | lr_scheduler: constant
68 | lr_warmup_steps: 500
69 | # checkpoint config
70 | logger: wandb
71 | checkpointing_epochs: False
72 | checkpointing_steps: 100000
73 | checkpointing_steps_list: [400000, 1000000, 2000000]
74 | checkpoints_total_limit: 2
75 | logging_steps: 10000
76 |
--------------------------------------------------------------------------------
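
The `network_config` entries in these configs follow the common `target`/`params` pattern; a generic sketch of resolving such an entry into a model instance (the training scripts presumably use an equivalent helper, so this is only illustrative):

```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config: dict):
    """Resolve a {target, params} entry into an object instance."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

cfg = OmegaConf.load("configs/fit/config_fit_xl.yaml")
# network_config is assumed to sit under the `diffusion` key, as in the YAML above.
net_cfg = OmegaConf.to_container(cfg.diffusion.network_config, resolve=True)
model = instantiate_from_config(net_cfg)
```
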
/configs/fitv2/config_fitv2_3B.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema
3 | transport:
4 | path_type: Linear
5 | prediction: velocity
6 | loss_weight: null
7 | sample_eps: null
8 | train_eps: null
9 | snr_type: lognorm
10 | sampler:
11 | mode: ODE
12 | sde:
13 | sampling_method: Euler
14 | diffusion_form: sigma
15 | diffusion_norm: 1.0
16 | last_step: Mean
17 | last_step_size: 0.04
18 | ode:
19 | sampling_method: dopri5
20 | atol: 1.0e-06
21 | rtol: 0.001
22 | reverse: false
23 | likelihood: false
24 | network_config:
25 | target: fit.model.fit_model.FiT
26 | params:
27 | context_size: 256
28 | patch_size: 2
29 | in_channels: 4
30 | hidden_size: 2304
31 | depth: 40
32 | num_heads: 24
33 | mlp_ratio: 4.0
34 | class_dropout_prob: 0.1
35 | num_classes: 1000
36 | learn_sigma: false
37 | use_sit: true
38 | use_swiglu: true
39 | use_swiglu_large: false
40 | q_norm: layernorm
41 | k_norm: layernorm
42 | qk_norm_weight: false
43 | rel_pos_embed: rope
44 | abs_pos_embed: null
45 | adaln_type: lora
46 | adaln_lora_dim: 576
47 |
48 |
49 | data:
50 | target: fit.data.in1k_latent_dataset.INLatentLoader
51 | params:
52 | train:
53 | data_path: datasets/imagenet1k_latents_256_sd_vae_ft_ema
54 | target_len: 256
55 | random: 'random'
56 | loader:
57 | batch_size: 16
58 | num_workers: 2
59 | shuffle: True
60 |
61 |
62 | accelerate:
63 | # others
64 | gradient_accumulation_steps: 1
65 | mixed_precision: 'bf16'
66 | # training step config
67 | num_train_epochs:
68 | max_train_steps: 2000000
69 | # optimizer config
70 | learning_rate: 1.0e-4
71 | learning_rate_base_batch_size: 256
72 | max_grad_norm: 1.0
73 | optimizer:
74 | target: torch.optim.AdamW
75 | params:
76 | betas: ${tuple:0.9, 0.999}
77 | weight_decay: 0 #1.0e-2
78 | eps: 1.0e-8
79 | lr_scheduler: constant_with_warmup
80 | lr_warmup_steps: 50000
81 | # checkpoint config
82 | logger: wandb
83 | checkpointing_epochs: False
84 | checkpointing_steps: 4000
85 | checkpointing_steps_list: [200000, 400000, 1000000, 1400000, 1500000, 1800000]
86 | checkpoints_total_limit: 2
87 | logging_steps: 1000
88 |
--------------------------------------------------------------------------------
/configs/fitv2/config_fitv2_hr_3B.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema
3 | transport:
4 | path_type: Linear
5 | prediction: velocity
6 | loss_weight: null
7 | sample_eps: null
8 | train_eps: null
9 | snr_type: lognorm
10 | sampler:
11 | mode: ODE
12 | sde:
13 | sampling_method: Euler
14 | diffusion_form: sigma
15 | diffusion_norm: 1.0
16 | last_step: Mean
17 | last_step_size: 0.04
18 | ode:
19 | sampling_method: dopri5
20 | atol: 1.0e-06
21 | rtol: 0.001
22 | reverse: false
23 | likelihood: false
24 | network_config:
25 | target: fit.model.fit_model.FiT
26 | params:
27 | context_size: 1024
28 | patch_size: 2
29 | in_channels: 4
30 | hidden_size: 2304
31 | depth: 40
32 | num_heads: 24
33 | mlp_ratio: 4.0
34 | class_dropout_prob: 0.1
35 | num_classes: 1000
36 | learn_sigma: false
37 | use_sit: true
38 | use_checkpoint: true
39 | use_swiglu: true
40 | use_swiglu_large: false
41 | q_norm: layernorm
42 | k_norm: layernorm
43 | qk_norm_weight: false
44 | rel_pos_embed: rope
45 | custom_freqs: ntk-aware
46 | decouple: true
47 | ori_max_pe_len: 16
48 | online_rope: true
49 | abs_pos_embed: null
50 | adaln_type: lora
51 | adaln_lora_dim: 576
52 | pretrain_ckpt: checkpoints/fitv2_3B.bin
53 | ignore_keys: ['x_embedder', 'bias', 'LN', 'final_layer']
54 | finetune: partial
55 |
56 | data:
57 | target: fit.data.in1k_latent_dataset.INLatentLoader
58 | params:
59 | train:
60 | data_path: datasets/imagenet1k_latents_1024_sd_vae_ft_ema
61 | target_len: 1024
62 | random: 'random'
63 | loader:
64 | batch_size: 16
65 | num_workers: 2
66 | shuffle: True
67 |
68 |
69 | accelerate:
70 | # others
71 | gradient_accumulation_steps: 1
72 | mixed_precision: 'bf16'
73 | # training step config
74 | num_train_epochs:
75 | max_train_steps: 200000
76 | # optimizer config
77 | learning_rate: 1.0e-4
78 | learning_rate_base_batch_size: 256
79 | max_grad_norm: 1.0
80 | optimizer:
81 | target: torch.optim.AdamW
82 | params:
83 | betas: ${tuple:0.9, 0.999}
84 | weight_decay: 0 #1.0e-2
85 | eps: 1.0e-8
86 | lr_scheduler: constant
87 | lr_warmup_steps: 0
88 | # checkpoint config
89 | logger: wandb
90 | checkpointing_epochs: False
91 | checkpointing_steps: 4000
92 | checkpointing_steps_list: [40000, 80000, 100000]
93 | checkpoints_total_limit: 2
94 | logging_steps: 1000
95 |
--------------------------------------------------------------------------------
/configs/fitv2/config_fitv2_hr_xl.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema
3 | transport:
4 | path_type: Linear
5 | prediction: velocity
6 | loss_weight: null
7 | sample_eps: null
8 | train_eps: null
9 | snr_type: lognorm
10 | sampler:
11 | mode: ODE
12 | sde:
13 | sampling_method: Euler
14 | diffusion_form: sigma
15 | diffusion_norm: 1.0
16 | last_step: Mean
17 | last_step_size: 0.04
18 | ode:
19 | sampling_method: dopri5
20 | atol: 1.0e-06
21 | rtol: 0.001
22 | reverse: false
23 | likelihood: false
24 | network_config:
25 | target: fit.model.fit_model.FiT
26 | params:
27 | context_size: 1024
28 | patch_size: 2
29 | in_channels: 4
30 | hidden_size: 1152
31 | depth: 36
32 | num_heads: 16
33 | mlp_ratio: 4.0
34 | class_dropout_prob: 0.1
35 | num_classes: 1000
36 | learn_sigma: false
37 | use_sit: true
38 | use_swiglu: true
39 | use_swiglu_large: false
40 | q_norm: layernorm
41 | k_norm: layernorm
42 | qk_norm_weight: false
43 | rel_pos_embed: rope
44 | custom_freqs: ntk-aware
45 | decouple: true
46 | ori_max_pe_len: 16
47 | online_rope: true
48 | abs_pos_embed: null
49 | adaln_type: lora
50 | adaln_lora_dim: 288
51 | pretrain_ckpt: checkpoints/fitv2_xl.safetensors
52 | ignore_keys: ['x_embedder', 'bias', 'LN', 'final_layer']
53 | finetune: partial
54 |
55 | data:
56 | target: fit.data.in1k_latent_dataset.INLatentLoader
57 | params:
58 | train:
59 | data_path: datasets/imagenet1k_latents_1024_sd_vae_ft_ema
60 | target_len: 1024
61 | random: 'random'
62 | loader:
63 | batch_size: 16
64 | num_workers: 2
65 | shuffle: True
66 |
67 |
68 | accelerate:
69 | # others
70 | gradient_accumulation_steps: 1
71 | mixed_precision: 'bf16'
72 | # training step config
73 | num_train_epochs:
74 | max_train_steps: 400000
75 | # optimizer config
76 | learning_rate: 1.0e-4
77 | learning_rate_base_batch_size: 256
78 | max_grad_norm: 1.0
79 | optimizer:
80 | target: torch.optim.AdamW
81 | params:
82 | betas: ${tuple:0.9, 0.999}
83 | weight_decay: 0 #1.0e-2
84 | eps: 1.0e-8
85 | lr_scheduler: constant
86 | lr_warmup_steps: 0
87 | # checkpoint config
88 | logger: wandb
89 | checkpointing_epochs: False
90 | checkpointing_steps: 4000
91 | checkpointing_steps_list: [40000, 80000, 100000, 200000]
92 | checkpoints_total_limit: 2
93 | logging_steps: 1000
94 |
--------------------------------------------------------------------------------
/configs/fitv2/config_fitv2_xl.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | pretrained_first_stage_model_path: stabilityai/sd-vae-ft-ema
3 | transport:
4 | path_type: Linear
5 | prediction: velocity
6 | loss_weight: null
7 | sample_eps: null
8 | train_eps: null
9 | snr_type: lognorm
10 | sampler:
11 | mode: ODE
12 | sde:
13 | sampling_method: Euler
14 | diffusion_form: sigma
15 | diffusion_norm: 1.0
16 | last_step: Mean
17 | last_step_size: 0.04
18 | ode:
19 | sampling_method: dopri5
20 | atol: 1.0e-06
21 | rtol: 0.001
22 | reverse: false
23 | likelihood: false
24 | network_config:
25 | target: fit.model.fit_model.FiT
26 | params:
27 | context_size: 256
28 | patch_size: 2
29 | in_channels: 4
30 | hidden_size: 1152
31 | depth: 36
32 | num_heads: 16
33 | mlp_ratio: 4.0
34 | class_dropout_prob: 0.1
35 | num_classes: 1000
36 | learn_sigma: false
37 | use_sit: true
38 | use_swiglu: true
39 | use_swiglu_large: false
40 | q_norm: layernorm
41 | k_norm: layernorm
42 | qk_norm_weight: false
43 | rel_pos_embed: rope
44 | abs_pos_embed: null
45 | adaln_type: lora
46 | adaln_lora_dim: 288
47 |
48 | data:
49 | target: fit.data.in1k_latent_dataset.INLatentLoader
50 | params:
51 | train:
52 | data_path: datasets/imagenet1k_latents_256_sd_vae_ft_ema
53 | target_len: 256
54 | random: 'random'
55 | loader:
56 | batch_size: 16
57 | num_workers: 2
58 | shuffle: True
59 |
60 |
61 | accelerate:
62 | # others
63 | gradient_accumulation_steps: 1
64 | mixed_precision: 'bf16'
65 | # training step config
66 | num_train_epochs:
67 | max_train_steps: 2000000
68 | # optimizer config
69 | learning_rate: 1.0e-4
70 | learning_rate_base_batch_size: 256
71 | max_grad_norm: 1.0
72 | optimizer:
73 | target: torch.optim.AdamW
74 | params:
75 | betas: ${tuple:0.9, 0.999}
76 | weight_decay: 0 #1.0e-2
77 | eps: 1.0e-8
78 | lr_scheduler: constant_with_warmup
79 | lr_warmup_steps: 50000
80 | # checkpoint config
81 | logger: wandb
82 | checkpointing_epochs: False
83 | checkpointing_steps: 4000
84 | checkpointing_steps_list: [200000, 400000, 1000000, 1400000, 1500000, 1800000]
85 | checkpoints_total_limit: 2
86 | logging_steps: 1000
87 |
--------------------------------------------------------------------------------
/fit/data/in1k_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import torchvision
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from torch.utils.data import DataLoader, Dataset
8 | from torchvision.datasets import ImageFolder
9 | from torchvision import transforms
10 | from accelerate.logging import get_logger
11 | logger = get_logger(__name__, log_level="INFO")
12 |
13 |
14 | def center_crop_arr(pil_image, image_size):
15 | """
16 | Center cropping implementation from ADM.
17 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
18 | """
19 | while min(*pil_image.size) >= 2 * image_size:
20 | pil_image = pil_image.resize(
21 | tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX
22 | )
23 |
24 | scale = image_size / min(*pil_image.size)
25 | pil_image = pil_image.resize(
26 | tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC
27 | )
28 |
29 | arr = np.array(pil_image)
30 | crop_y = (arr.shape[0] - image_size) // 2
31 | crop_x = (arr.shape[1] - image_size) // 2
32 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
33 |
34 | def resize_arr(pil_image, image_size, vq_vae_down=8, patch_size=2):
35 | w, h = pil_image.size
36 | min_len = int(vq_vae_down*patch_size) # 8*2=16 -> 256/16=16, 512/16=32
37 | if w * h >= image_size ** 2:
38 | new_w = np.sqrt(w/h) * image_size
39 | new_h = new_w * h / w
40 | elif w < min_len: # upsample, this case only happens twice in ImageNet1k_256
41 | new_w = min_len
42 | new_h = min_len * h / w
43 | elif h < min_len: # upsample, this case only happens once in ImageNet1k_256
44 | new_h = min_len
45 | new_w = min_len * w / h
46 | else:
47 | new_w, new_h = w, h
48 | new_w, new_h = int(new_w/min_len)*min_len, int(new_h/min_len)*min_len
49 |
50 | if new_w == w and new_h == h:
51 | return pil_image
52 | else:
53 | return pil_image.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)
54 |
55 | class ImagenetDataDictWrapper(Dataset):
56 | def __init__(self, dataset):
57 | super().__init__()
58 | self.dataset = dataset
59 |
60 | def __getitem__(self, i):
61 | x, y = self.dataset[i]
62 | return {"jpg": x, "cls": y}
63 |
64 | def __len__(self):
65 | return len(self.dataset)
66 |
67 | class ImagenetLoader():
68 | def __init__(self, train, rescale='crop'):
69 | super().__init__()
70 |
71 | self.train_config = train
72 |
73 | self.batch_size = self.train_config.loader.batch_size
74 | self.num_workers = self.train_config.loader.num_workers
75 | self.shuffle = self.train_config.loader.shuffle
76 |
77 | if rescale == 'crop':
78 | transform = transforms.Compose([
79 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, self.train_config.resize)),
80 | transforms.RandomHorizontalFlip(),
81 | transforms.ToTensor(),
82 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
83 | ])
84 | elif rescale == 'resize':
85 | transform = transforms.Compose([
86 | transforms.Lambda(lambda pil_image: resize_arr(pil_image, self.train_config.resize)),
87 | transforms.RandomHorizontalFlip(),
88 | transforms.ToTensor(),
89 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
90 | ])
91 | elif rescale == 'keep':
92 | transform = transforms.Compose([
93 | transforms.Lambda(lambda pil_image: resize_arr(pil_image, self.train_config.resize)),
94 | transforms.RandomHorizontalFlip(),
95 | transforms.ToTensor(),
96 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
97 | ])
98 | else:
99 | raise NotImplementedError
100 |
101 | self.train_dataset = ImagenetDataDictWrapper(ImageFolder(self.train_config.data_path, transform=transform))
102 |
103 | self.test_dataset = None
104 | self.val_dataset = None
105 |
106 | def train_len(self):
107 | return len(self.train_dataset)
108 |
109 | def train_dataloader(self):
110 | return DataLoader(
111 | self.train_dataset,
112 | batch_size=self.batch_size,
113 | shuffle=self.shuffle,
114 | num_workers=self.num_workers,
115 | pin_memory=True,
116 | drop_last=True
117 | )
118 |
119 | def test_dataloader(self):
120 | return None
121 |
122 | def val_dataloader(self):
123 | return DataLoader(
124 | self.train_dataset,
125 | batch_size=self.batch_size,
126 | shuffle=self.shuffle,
127 | num_workers=self.num_workers,
128 | pin_memory=True,
129 | drop_last=True
130 | )
131 |
132 | if __name__ == "__main__":
133 | from omegaconf import OmegaConf
134 | conf = OmegaConf.load('/home/luzeyu/projects/workspace/generative-models/configs/example_training/dataset/imagenet-256-streaming.yaml')
135 | indataloader=ImagenetLoader(train=conf.data.params.train).train_dataloader()
136 | from tqdm import tqdm
137 | for i in tqdm(indataloader):
138 | # print(i)
139 | # print("*"*20)
140 | pass
--------------------------------------------------------------------------------
/fit/data/in1k_latent_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import random
6 | from tqdm import tqdm
7 | from torch.utils.data import DataLoader, Dataset
8 | from typing import Optional
9 | from safetensors.torch import load_file, save_file
10 | from concurrent.futures import ThreadPoolExecutor, as_completed
11 | from einops import rearrange
12 |
13 |
14 |
15 | class IN1kLatentDataset(Dataset):
16 | def __init__(self, root_dir, target_len=256, random='random'):
17 | super().__init__()
18 | self.RandomHorizontalFlipProb = 0.5
19 | self.root_dir = root_dir
20 | self.target_len = target_len
21 | self.random = random
22 | self.files = []
23 | files_1 = os.listdir(osp.join(root_dir, f'from_16_to_{target_len}'))
24 | files_2 = os.listdir(osp.join(root_dir, f'greater_than_{target_len}_resize'))
25 | files_3 = os.listdir(osp.join(root_dir, f'greater_than_{target_len}_crop'))
26 | files_23 = list(set(files_2) - set(files_3)) # files_3 in files_2
27 | self.files.extend([
28 | [osp.join(root_dir, f'from_16_to_{target_len}', file)] for file in files_1
29 | ])
30 | self.files.extend([
31 | [osp.join(root_dir, f'greater_than_{target_len}_resize', file)] for file in files_23
32 | ])
33 | self.files.extend([
34 | [
35 | osp.join(root_dir, f'greater_than_{target_len}_resize', file),
36 | osp.join(root_dir, f'greater_than_{target_len}_crop', file)
37 | ] for file in files_3
38 | ])
39 |
40 |
41 | def __len__(self):
42 | return len(self.files)
43 |
44 | def __getitem__(self, idx):
45 | if self.random == 'random':
46 | path = random.choice(self.files[idx])
47 | elif self.random == 'resize':
48 | path = self.files[idx][0] # only resize
49 | elif self.random == 'crop':
50 | path = self.files[idx][-1] # only crop
51 | data = load_file(path)
52 | dtype = data['feature'].dtype
53 |
54 | feature = torch.zeros((self.target_len, 16), dtype=dtype)
55 | grid = torch.zeros((2, self.target_len), dtype=dtype)
56 | mask = torch.zeros((self.target_len), dtype=torch.uint8)
57 | size = torch.zeros(2, dtype=torch.int32)
58 |
59 |
60 | seq_len = data['grid'].shape[-1]
61 | if torch.rand(1) < self.RandomHorizontalFlipProb:
62 | feature[0: seq_len] = rearrange(data['feature'][0], 'h w c -> (h w) c')
63 | else:
64 | feature[0: seq_len] = rearrange(data['feature'][1], 'h w c -> (h w) c')
65 | grid[:, 0: seq_len] = data['grid']
66 | mask[0: seq_len] = 1
67 | size = data['size'][None, :]
68 | label = data['label']
69 | return dict(feature=feature, grid=grid, mask=mask, label=label, size=size)
70 |
71 |
72 |
73 | # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60
74 |
75 |
76 | def get_train_sampler(dataset, global_batch_size, max_steps, resume_steps, seed):
77 | sample_indices = torch.empty([max_steps * global_batch_size], dtype=torch.long)
78 | epoch_id, fill_ptr, offs = 0, 0, 0
79 | while fill_ptr < sample_indices.size(0):
80 | g = torch.Generator()
81 | g.manual_seed(seed + epoch_id)
82 | epoch_sample_indices = torch.randperm(len(dataset), generator=g)
83 | epoch_id += 1
84 | epoch_sample_indices = epoch_sample_indices[
85 | :sample_indices.size(0) - fill_ptr
86 | ]
87 | sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \
88 | epoch_sample_indices
89 | fill_ptr += epoch_sample_indices.size(0)
90 | return sample_indices[resume_steps * global_batch_size : ].tolist()
91 |
92 |
93 |
94 | class INLatentLoader():
95 | def __init__(self, train):
96 | super().__init__()
97 |
98 | self.train_config = train
99 |
100 | self.batch_size = self.train_config.loader.batch_size
101 | self.num_workers = self.train_config.loader.num_workers
102 | self.shuffle = self.train_config.loader.shuffle
103 |
104 | self.train_dataset = IN1kLatentDataset(
105 | self.train_config.data_path, self.train_config.target_len, self.train_config.random
106 | )
107 |
108 |
109 | self.test_dataset = None
110 | self.val_dataset = None
111 |
112 | def train_len(self):
113 | return len(self.train_dataset)
114 |
115 | def train_dataloader(self, global_batch_size, max_steps, resume_step, seed=42):
116 | sampler = get_train_sampler(
117 | self.train_dataset, global_batch_size, max_steps, resume_step, seed
118 | )
119 | return DataLoader(
120 | self.train_dataset,
121 | batch_size=self.batch_size,
122 | sampler=sampler,
123 | num_workers=self.num_workers,
124 | pin_memory=True,
125 | drop_last=True,
126 | )
127 |
128 | def test_dataloader(self):
129 | return None
130 |
131 | def val_dataloader(self):
132 | return DataLoader(
133 | self.train_dataset,
134 | batch_size=self.batch_size,
135 | shuffle=self.shuffle,
136 | num_workers=self.num_workers,
137 | pin_memory=True,
138 | drop_last=True
139 | )
140 |
--------------------------------------------------------------------------------
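
A short usage sketch for `INLatentLoader`, assuming the pre-training latents have been downloaded to `datasets/imagenet1k_latents_256_sd_vae_ft_ema` (a small `max_steps` keeps the precomputed index list cheap for this demo):

```python
from omegaconf import OmegaConf
from fit.data.in1k_latent_dataset import INLatentLoader

cfg = OmegaConf.load("configs/fitv2/config_fitv2_xl.yaml")
loader = INLatentLoader(train=cfg.data.params.train)

# get_train_sampler pre-computes max_steps * global_batch_size indices, so resuming
# from a given resume_step replays exactly the same data order.
dl = loader.train_dataloader(global_batch_size=256, max_steps=1000, resume_step=0, seed=42)
batch = next(iter(dl))
print(batch["feature"].shape, batch["grid"].shape, batch["mask"].shape, batch["label"].shape)
```
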
/fit/model/fit_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | from typing import Optional
5 | from einops import rearrange, repeat
6 | from fit.model.modules import (
7 | PatchEmbedder, TimestepEmbedder, LabelEmbedder,
8 | FiTBlock, FinalLayer
9 | )
10 | from fit.model.utils import get_parameter_dtype
11 | from fit.utils.eval_utils import init_from_ckpt
12 | from fit.model.sincos import get_2d_sincos_pos_embed_from_grid
13 | from fit.model.rope import VisionRotaryEmbedding
14 |
15 | #################################################################################
16 | # Core FiT Model #
17 | #################################################################################
18 |
19 |
20 |
21 | class FiT(nn.Module):
22 | """
23 | Flexible Diffusion model with a Transformer backbone.
24 | """
25 | def __init__(
26 | self,
27 | context_size: int = 256,
28 | patch_size: int = 2,
29 | in_channels: int = 4,
30 | hidden_size: int = 1152,
31 | depth: int = 28,
32 | num_heads: int = 16,
33 | mlp_ratio: float = 4.0,
34 | class_dropout_prob: float = 0.1,
35 | num_classes: int = 1000,
36 | learn_sigma: bool = True,
37 | use_sit: bool = False,
38 | use_checkpoint: bool=False,
39 | use_swiglu: bool = False,
40 | use_swiglu_large: bool = False,
41 | rel_pos_embed: Optional[str] = 'rope',
42 | norm_type: str = "layernorm",
43 | q_norm: Optional[str] = None,
44 | k_norm: Optional[str] = None,
45 | qk_norm_weight: bool = False,
46 | qkv_bias: bool = True,
47 | ffn_bias: bool = True,
48 | adaln_bias: bool = True,
49 | adaln_type: str = "normal",
50 | adaln_lora_dim: int = None,
51 | rope_theta: float = 10000.0,
52 | custom_freqs: str = 'normal',
53 | max_pe_len_h: Optional[int] = None,
54 | max_pe_len_w: Optional[int] = None,
55 | decouple: bool = False,
56 | ori_max_pe_len: Optional[int] = None,
57 | online_rope: bool = False,
58 | add_rel_pe_to_v: bool = False,
59 | pretrain_ckpt: str = None,
60 | ignore_keys: list = None,
61 | finetune: str = None,
62 | time_shifting: int = 1,
63 | **kwargs,
64 | ):
65 | super().__init__()
66 | self.context_size = context_size
67 | self.hidden_size = hidden_size
68 | assert not (learn_sigma and use_sit)
69 | self.learn_sigma = learn_sigma
70 | self.use_sit = use_sit
71 | self.use_checkpoint = use_checkpoint
72 | self.depth = depth
73 | self.mlp_ratio = mlp_ratio
74 | self.class_dropout_prob = class_dropout_prob
75 | self.num_classes = num_classes
76 | self.in_channels = in_channels
77 | self.out_channels = self.in_channels * 2 if learn_sigma else in_channels
78 | self.patch_size = patch_size
79 | self.num_heads = num_heads
80 | self.adaln_type = adaln_type
81 | self.online_rope = online_rope
82 | self.time_shifting = time_shifting
83 |
84 | self.x_embedder = PatchEmbedder(in_channels * patch_size**2, hidden_size, bias=True)
85 | self.t_embedder = TimestepEmbedder(hidden_size)
86 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
87 |
88 |
89 |
90 | self.rel_pos_embed = VisionRotaryEmbedding(
91 | head_dim=hidden_size//num_heads, theta=rope_theta, custom_freqs=custom_freqs, online_rope=online_rope,
92 | max_pe_len_h=max_pe_len_h, max_pe_len_w=max_pe_len_w, decouple=decouple, ori_max_pe_len=ori_max_pe_len,
93 | )
94 |
95 | if adaln_type == 'lora':
96 | self.global_adaLN_modulation = nn.Sequential(
97 | nn.SiLU(),
98 | nn.Linear(hidden_size, 6 * hidden_size, bias=adaln_bias)
99 | )
100 | else:
101 | self.global_adaLN_modulation = None
102 |
103 | self.blocks = nn.ModuleList([FiTBlock(
104 | hidden_size, num_heads, mlp_ratio=mlp_ratio, swiglu=use_swiglu, swiglu_large=use_swiglu_large,
105 | rel_pos_embed=rel_pos_embed, add_rel_pe_to_v=add_rel_pe_to_v, norm_layer=norm_type,
106 | q_norm=q_norm, k_norm=k_norm, qk_norm_weight=qk_norm_weight, qkv_bias=qkv_bias, ffn_bias=ffn_bias,
107 | adaln_bias=adaln_bias, adaln_type=adaln_type, adaln_lora_dim=adaln_lora_dim
108 | ) for _ in range(depth)])
109 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, norm_layer=norm_type, adaln_bias=adaln_bias, adaln_type=adaln_type)
110 | self.initialize_weights(pretrain_ckpt=pretrain_ckpt, ignore=ignore_keys)
111 | if finetune != None:
112 | self.finetune(type=finetune, unfreeze=ignore_keys)
113 |
114 |
115 | def initialize_weights(self, pretrain_ckpt=None, ignore=None):
116 | # Initialize transformer layers:
117 | def _basic_init(module):
118 | if isinstance(module, nn.Linear):
119 | torch.nn.init.xavier_uniform_(module.weight)
120 | if module.bias is not None:
121 | nn.init.constant_(module.bias, 0)
122 | self.apply(_basic_init)
123 |
124 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
125 | w = self.x_embedder.proj.weight.data
126 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
127 | nn.init.constant_(self.x_embedder.proj.bias, 0)
128 |
129 | # Initialize label embedding table:
130 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
131 |
132 | # Initialize timestep embedding MLP:
133 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
134 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
135 |
136 | # Zero-out adaLN modulation layers in DiT blocks:
137 | for block in self.blocks:
138 | if self.adaln_type in ['normal', 'lora']:
139 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
140 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
141 | elif self.adaln_type == 'swiglu':
142 | nn.init.constant_(block.adaLN_modulation.fc2.weight, 0)
143 | nn.init.constant_(block.adaLN_modulation.fc2.bias, 0)
144 | if self.adaln_type == 'lora':
145 | nn.init.constant_(self.global_adaLN_modulation[-1].weight, 0)
146 | nn.init.constant_(self.global_adaLN_modulation[-1].bias, 0)
147 | # Zero-out output layers:
148 | if self.adaln_type == 'swiglu':
149 | nn.init.constant_(self.final_layer.adaLN_modulation.fc2.weight, 0)
150 | nn.init.constant_(self.final_layer.adaLN_modulation.fc2.bias, 0)
151 | else: # adaln_type in ['normal', 'lora']
152 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
153 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
154 | nn.init.constant_(self.final_layer.linear.weight, 0)
155 | nn.init.constant_(self.final_layer.linear.bias, 0)
156 |
157 | keys = list(self.state_dict().keys())
158 | ignore_keys = []
159 | if ignore != None:
160 | for ign in ignore:
161 | for key in keys:
162 | if ign in key:
163 | ignore_keys.append(key)
164 | ignore_keys = list(set(ignore_keys))
165 | if pretrain_ckpt != None:
166 | init_from_ckpt(self, pretrain_ckpt, ignore_keys, verbose=True)
167 |
168 |
169 | def unpatchify(self, x, hw):
170 | """
171 | args:
172 | x: (B, p**2 * C_out, N)
173 | N = h//p * w//p
174 | return:
175 | imgs: (B, C_out, H, W)
176 | """
177 | h, w = hw
178 | p = self.patch_size
179 | if self.use_sit:
180 | x = rearrange(x, "b (h w) c -> b h w c", h=h//p, w=w//p) # (B, h//2 * w//2, 16) -> (B, h//2, w//2, 16)
181 | x = rearrange(x, "b h w (c p1 p2) -> b c (h p1) (w p2)", p1=p, p2=p) # (B, h//2, w//2, 16) -> (B, 4, h, w)
182 | else:
183 | x = rearrange(x, "b c (h w) -> b c h w", h=h//p, w=w//p) # (B, 16, h//2 * w//2) -> (B, 16, h//2, w//2)
184 | x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=p, p2=p) # (B, 16, h//2, w//2) -> (B, 4, h, w)
185 | return x
186 |
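    # Editorial shape note (illustrative, assuming p = 2 and C_out = 4): with an h x w latent,
    # there are N = (h/p) * (w/p) tokens, each of dimension p**2 * C_out = 16. unpatchify folds
    # every token back into a p x p spatial patch:
    #   use_sit:   (B, N, 16) -> (B, h/2, w/2, 16) -> (B, 4, h, w)
    #   otherwise: (B, 16, N) -> (B, 16, h/2, w/2) -> (B, 4, h, w)
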
187 | def forward(self, x, t, y, grid, mask, size=None):
188 | """
189 | Forward pass of FiT.
190 | x: (B, p**2 * C_in, N), tensor of sequential inputs (flattened latent features of images, N=H*W/(p**2))
191 | t: (B,), tensor of diffusion timesteps
192 | y: (B,), tensor of class labels
193 |         grid: (B, 2, N), tensor of height and width indices that span a grid
194 |         mask: (B, N), tensor of the mask for the sequence
195 |         size: (B, n, 2), tensor of the height and width, n is the number of the packed images
196 |         --------------------------------------------------------------------------------------------
197 |         return: (B, p**2 * C_out, N), where C_out=2*C_in if learn_sigma, C_out=C_in otherwise.
198 | """
199 |
200 | t = torch.clamp(self.time_shifting * t / (1 + (self.time_shifting - 1) * t), max=1.0)
201 | t = t.float().to(x.dtype)
202 | if not self.use_sit:
203 | x = rearrange(x, 'B C N -> B N C') # (B, C, N) -> (B, N, C), where C = p**2 * C_in
204 | x = self.x_embedder(x) # (B, N, C) -> (B, N, D)
205 | t = self.t_embedder(t) # (B, D)
206 | y = self.y_embedder(y, self.training) # (B, D)
207 | c = t + y # (B, D)
208 |
209 | # get RoPE frequences in advance, then calculate attention.
210 | if self.online_rope:
211 | freqs_cos, freqs_sin = self.rel_pos_embed.online_get_2d_rope_from_grid(grid, size)
212 | freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
213 | else:
214 | freqs_cos, freqs_sin = self.rel_pos_embed.get_cached_2d_rope_from_grid(grid)
215 | freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
216 | if self.global_adaLN_modulation != None:
217 | global_adaln = self.global_adaLN_modulation(c)
218 | else:
219 | global_adaln = 0.0
220 |
221 | if not self.use_checkpoint:
222 | for block in self.blocks: # (B, N, D)
223 | x = block(x, c, mask, freqs_cos, freqs_sin, global_adaln)
224 | else:
225 | for block in self.blocks:
226 | x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, c, mask, freqs_cos, freqs_sin, global_adaln)
227 |         x = self.final_layer(x, c)                # (B, N, p ** 2 * C_out), where C_out=2*C_in if learn_sigma, C_out=C_in otherwise.
228 | x = x * mask[..., None] # mask the padding tokens
229 | if not self.use_sit:
230 | x = rearrange(x, 'B N C -> B C N') # (B, N, C) -> (B, C, N), where C = p**2 * C_out
231 | return x
232 |
233 |
234 |
235 | def forward_with_cfg(self, x, t, y, grid, mask, size, cfg_scale, scale_pow=0.0):
236 | """
237 | Forward pass with classifier free guidance of FiT.
238 | x: (2B, N, p**2 * C_in) if use_sit else (2B, p**2 * C_in, N) tensor of sequential inputs (flattened latent features of images, N=H*W/(p**2))
239 | t: (2B,) tensor of diffusion timesteps
240 | y: (2B,) tensor of class labels
241 |         grid: (2B, 2, N): tensor of height and width indices that span a grid
242 |         mask: (2B, N): tensor of the mask for the sequence
243 |         cfg_scale: float > 1.0
244 |         return: (2B, N, p**2 * C_out) if use_sit else (2B, p**2 * C_out, N), where C_out=2*C_in if learn_sigma, C_out=C_in otherwise.
245 | """
246 | half = x[: len(x) // 2] # (2B, ...) -> (B, ...)
247 | combined = torch.cat([half, half], dim=0) # (2B, ...)
248 |         model_out = self.forward(combined, t, y, grid, mask, size)     # (2B, N, C) if use_sit else (2B, C, N), where C = p**2 * C_out
249 | # For exact reproducibility reasons, we apply classifier-free guidance on only
250 | # three channels by default. The standard approach to cfg applies it to all channels.
251 | # This can be done by uncommenting the following line and commenting-out the line following that.
252 | # C_cfg = self.in_channels * self.patch_size * self.patch_size
253 | C_cfg = 3 * self.patch_size * self.patch_size
254 | if self.use_sit:
255 | eps, rest = model_out[:, :, :C_cfg], model_out[:, :, C_cfg:] # eps: (2B, N, C_cfg), where C_cfg = p**2 * 3
256 | else:
257 | eps, rest = model_out[:, :C_cfg], model_out[:, C_cfg:] # eps: (2B, C_cfg, N), where C_cfg = p**2 * 3
258 |         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)  # split (2B, ...) into two (B, ...) halves
259 | # from https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L506
260 | # improved classifier-free guidance
261 | if scale_pow == 0.0:
262 | real_cfg_scale = cfg_scale
263 | else:
264 | scale_step = (
265 | 1-torch.cos(((1-torch.clamp_max(t, 1.0))**scale_pow)*torch.pi)
266 | )*1/2 # power-cos scaling
267 | real_cfg_scale = (cfg_scale-1)*scale_step + 1
268 | real_cfg_scale = real_cfg_scale[: len(x) // 2].view(-1, 1, 1)
269 | t = t / (self.time_shifting + (1 - self.time_shifting) * t)
270 | half_eps = uncond_eps + real_cfg_scale * (cond_eps - uncond_eps) # (B, ...)
271 | eps = torch.cat([half_eps, half_eps], dim=0) # (B, ...) -> (2B, ...)
272 | if self.use_sit:
273 | return torch.cat([eps, rest], dim=2) # (2B, N, C), where C = p**2 * C_out
274 | else:
275 | return torch.cat([eps, rest], dim=1) # (2B, C, N), where C = p**2 * C_out
276 |
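    # Editorial note on the power-cos schedule above: for scale_pow > 0,
    #   scale_step = (1 - cos(((1 - t) ** scale_pow) * pi)) / 2
    # equals 1 at t = 0 and 0 at t = 1, so real_cfg_scale ramps between the full cfg_scale
    # and 1.0 along the trajectory instead of applying one constant guidance weight.
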
277 | def ckpt_wrapper(self, module):
278 | def ckpt_forward(*inputs):
279 | outputs = module(*inputs)
280 | return outputs
281 | return ckpt_forward
282 |
283 |
284 | @property
285 | def dtype(self) -> torch.dtype:
286 | """
287 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
288 | """
289 | return get_parameter_dtype(self)
290 |
291 |
292 | def finetune(self, type, unfreeze):
293 | if type == 'full':
294 | return
295 | for name, param in self.named_parameters():
296 | param.requires_grad = False
297 | for unf in unfreeze:
298 | for name, param in self.named_parameters():
299 | if unf in name: # LN means Layer Norm
300 | param.requires_grad = True
301 |
--------------------------------------------------------------------------------
/fit/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch import nn, Tensor
6 | from torch.jit import Final
7 | from timm.layers.mlp import SwiGLU, Mlp
8 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
9 | from fit.model.rope import rotate_half
10 | from fit.model.utils import modulate
11 | from fit.model.norms import create_norm
12 | from functools import partial
13 | from einops import rearrange, repeat
14 |
15 | #################################################################################
16 | # Embedding Layers for Patches, Timesteps and Class Labels #
17 | #################################################################################
18 |
19 | class PatchEmbedder(nn.Module):
20 | """
21 | Embeds latent features into vector representations
22 | """
23 | def __init__(self,
24 | input_dim,
25 | embed_dim,
26 | bias: bool = True,
27 | norm_layer: Optional[Callable] = None,
28 | ):
29 | super().__init__()
30 |
31 | self.proj = nn.Linear(input_dim, embed_dim, bias=bias)
32 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
33 |
34 | def forward(self, x):
35 | x = self.proj(x) # (B, L, patch_size ** 2 * C) -> (B, L, D)
36 | x = self.norm(x)
37 | return x
38 |
39 | class TimestepEmbedder(nn.Module):
40 | """
41 | Embeds scalar timesteps into vector representations.
42 | """
43 | def __init__(self, hidden_size, frequency_embedding_size=256):
44 | super().__init__()
45 | self.mlp = nn.Sequential(
46 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
47 | nn.SiLU(),
48 | nn.Linear(hidden_size, hidden_size, bias=True),
49 | )
50 | self.frequency_embedding_size = frequency_embedding_size
51 |
52 | @staticmethod
53 | def timestep_embedding(t, dim, max_period=10000):
54 | """
55 | Create sinusoidal timestep embeddings.
56 | :param t: a 1-D Tensor of N indices, one per batch element.
57 | These may be fractional.
58 | :param dim: the dimension of the output.
59 | :param max_period: controls the minimum frequency of the embeddings.
60 | :return: an (N, D) Tensor of positional embeddings.
61 | """
62 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
63 | half = dim // 2
64 | freqs = torch.exp(
65 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
66 | ).to(device=t.device)
67 | args = t[:, None] * freqs[None]
68 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
69 | if dim % 2:
70 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1]).to(device=t.device)], dim=-1)
71 | return embedding.to(dtype=t.dtype)
72 |
73 | def forward(self, t):
74 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
75 | t_emb = self.mlp(t_freq)
76 | return t_emb
77 |
78 |
79 | class LabelEmbedder(nn.Module):
80 | """
81 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
82 | """
83 | def __init__(self, num_classes, hidden_size, dropout_prob):
84 | super().__init__()
85 | use_cfg_embedding = dropout_prob > 0
86 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
87 | self.num_classes = num_classes
88 | self.dropout_prob = dropout_prob
89 |
90 | def token_drop(self, labels, force_drop_ids=None):
91 | """
92 | Drops labels to enable classifier-free guidance.
93 | """
94 | if force_drop_ids is None:
95 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
96 | else:
97 | drop_ids = force_drop_ids == 1
98 | labels = torch.where(drop_ids, self.num_classes, labels)
99 | return labels
100 |
101 | def forward(self, labels, train, force_drop_ids=None):
102 | use_dropout = self.dropout_prob > 0
103 | if (train and use_dropout) or (force_drop_ids is not None):
104 | labels = self.token_drop(labels, force_drop_ids)
105 | embeddings = self.embedding_table(labels)
106 | return embeddings
107 |
108 |
109 |
110 |
111 | #################################################################################
112 | # Attention #
113 | #################################################################################
114 |
115 | # modified from timm and eva-02
116 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
117 | # https://github.com/baaivision/EVA/blob/master/EVA-02/asuka/modeling_finetune.py
118 |
119 |
120 | class Attention(nn.Module):
121 |
122 | def __init__(self,
123 | dim: int,
124 | num_heads: int = 8,
125 | qkv_bias: bool = False,
126 | q_norm: Optional[str] = None,
127 | k_norm: Optional[str] = None,
128 | qk_norm_weight: bool = False,
129 | attn_drop: float = 0.,
130 | proj_drop: float = 0.,
131 | rel_pos_embed: Optional[str] = None,
132 | add_rel_pe_to_v: bool = False,
133 | ) -> None:
134 | super().__init__()
135 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
136 | self.num_heads = num_heads
137 | self.head_dim = dim // num_heads
138 | self.scale = self.head_dim ** -0.5
139 |
140 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
141 | if q_norm == 'layernorm' and qk_norm_weight == True:
142 | q_norm = 'w_layernorm'
143 | if k_norm == 'layernorm' and qk_norm_weight == True:
144 | k_norm = 'w_layernorm'
145 |
146 | self.q_norm = create_norm(q_norm, self.head_dim)
147 | self.k_norm = create_norm(k_norm, self.head_dim)
148 |
149 |
150 | self.attn_drop = nn.Dropout(attn_drop)
151 | self.proj = nn.Linear(dim, dim)
152 | self.proj_drop = nn.Dropout(proj_drop)
153 |
154 | self.rel_pos_embed = None if rel_pos_embed==None else rel_pos_embed.lower()
155 | self.add_rel_pe_to_v = add_rel_pe_to_v
156 |
157 |
158 |
159 | def forward(self,
160 | x: torch.Tensor,
161 | mask: Optional[torch.Tensor] = None,
162 | freqs_cos: Optional[torch.Tensor] = None,
163 | freqs_sin: Optional[torch.Tensor] = None,
164 | ) -> torch.Tensor:
165 | B, N, C = x.shape
166 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
167 | q, k, v = qkv.unbind(0) # (B, n_h, N, D_h)
168 | q, k = self.q_norm(q), self.k_norm(k)
169 |
170 | if self.rel_pos_embed in ['rope', 'xpos']: # multiplicative rel_pos_embed
171 | if self.add_rel_pe_to_v:
172 | v = v * freqs_cos + rotate_half(v) * freqs_sin
173 | q = q * freqs_cos + rotate_half(q) * freqs_sin
174 | k = k * freqs_cos + rotate_half(k) * freqs_sin
175 |
176 | attn_mask = mask[:, None, None, :] # (B, N) -> (B, 1, 1, N)
177 | attn_mask = (attn_mask == attn_mask.transpose(-2, -1)) # (B, 1, 1, N) x (B, 1, N, 1) -> (B, 1, N, N)
178 | mask = torch.not_equal(mask, torch.zeros_like(mask)).to(mask) # (B, N) -> (B, N)
179 |
180 |
181 | if x.device.type == "cpu":
182 | x = F.scaled_dot_product_attention(
183 | q, k, v, attn_mask=attn_mask,
184 | dropout_p=self.attn_drop.p if self.training else 0.,
185 | )
186 | else:
187 | with torch.backends.cuda.sdp_kernel(enable_flash=True):
188 | '''
189 | F.scaled_dot_product_attention is the efficient implementation equivalent to the following:
190 | attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
191 | attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
192 | attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1)
193 | attn_weight = torch.dropout(attn_weight, dropout_p)
194 | return attn_weight @ V
195 | In conclusion:
196 | boolean attn_mask will mask the attention matrix where attn_mask is False
197 | non-boolean attn_mask will be directly added to Q@K.T
198 | '''
199 | x = F.scaled_dot_product_attention(
200 | q, k, v, attn_mask=attn_mask,
201 | dropout_p=self.attn_drop.p if self.training else 0.,
202 | )
203 | x = x.transpose(1, 2).reshape(B, N, C)
204 | x = x * mask[..., None] # mask: (B, N) -> (B, N, 1)
205 | x = self.proj(x)
206 | x = self.proj_drop(x)
207 | return x
208 |
209 | #################################################################################
210 | # Basic FiT Module #
211 | #################################################################################
212 |
213 | class FiTBlock(nn.Module):
214 | """
215 |     A FiT transformer block with adaptive layer norm zero (adaLN-Zero) conditioning.
216 | """
217 | def __init__(self,
218 | hidden_size,
219 | num_heads,
220 | mlp_ratio=4.0,
221 | swiglu=True,
222 | swiglu_large=False,
223 | rel_pos_embed=None,
224 | add_rel_pe_to_v=False,
225 | norm_layer: str = 'layernorm',
226 | q_norm: Optional[str] = None,
227 | k_norm: Optional[str] = None,
228 | qk_norm_weight: bool = False,
229 | qkv_bias=True,
230 | ffn_bias=True,
231 | adaln_bias=True,
232 | adaln_type='normal',
233 | adaln_lora_dim: int = None,
234 | **block_kwargs
235 | ):
236 | super().__init__()
237 | self.norm1 = create_norm(norm_layer, hidden_size)
238 | self.norm2 = create_norm(norm_layer, hidden_size)
239 |
240 | self.attn = Attention(
241 | hidden_size, num_heads=num_heads, rel_pos_embed=rel_pos_embed,
242 | q_norm=q_norm, k_norm=k_norm, qk_norm_weight=qk_norm_weight,
243 | qkv_bias=qkv_bias, add_rel_pe_to_v=add_rel_pe_to_v,
244 | **block_kwargs
245 | )
246 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
247 | if swiglu:
248 | if swiglu_large:
249 | self.mlp = SwiGLU(in_features=hidden_size, hidden_features=mlp_hidden_dim, bias=ffn_bias)
250 | else:
251 | self.mlp = SwiGLU(in_features=hidden_size, hidden_features=(mlp_hidden_dim*2)//3, bias=ffn_bias)
252 | else:
253 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=ffn_bias)
254 | if adaln_type == 'normal':
255 | self.adaLN_modulation = nn.Sequential(
256 | nn.SiLU(),
257 | nn.Linear(hidden_size, 6 * hidden_size, bias=adaln_bias)
258 | )
259 | elif adaln_type == 'lora':
260 | self.adaLN_modulation = nn.Sequential(
261 | nn.SiLU(),
262 | nn.Linear(hidden_size, adaln_lora_dim, bias=adaln_bias),
263 | nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=adaln_bias)
264 | )
265 | elif adaln_type == 'swiglu':
266 | self.adaLN_modulation = SwiGLU(
267 | in_features=hidden_size, hidden_features=(hidden_size//4)*3, out_features=6*hidden_size, bias=adaln_bias
268 | )
269 |
270 | def forward(self, x, c, mask, freqs_cos, freqs_sin, global_adaln=0.0):
271 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.adaLN_modulation(c) + global_adaln).chunk(6, dim=1)
272 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask, freqs_cos, freqs_sin)
273 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
274 | return x
275 |
276 | class FinalLayer(nn.Module):
277 | """
278 |     The final layer of FiT.
279 | """
280 | def __init__(self, hidden_size, patch_size, out_channels, norm_layer: str = 'layernorm', adaln_bias=True, adaln_type='normal'):
281 | super().__init__()
282 | self.norm_final = create_norm(norm_type=norm_layer, dim=hidden_size)
283 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
284 | if adaln_type == 'swiglu':
285 | self.adaLN_modulation = SwiGLU(in_features=hidden_size, hidden_features=hidden_size//2, out_features=2*hidden_size, bias=adaln_bias)
286 | else: # adaln_type in ['normal', 'lora']
287 | self.adaLN_modulation = nn.Sequential(
288 | nn.SiLU(),
289 | nn.Linear(hidden_size, 2 * hidden_size, bias=adaln_bias)
290 | )
291 |
292 | def forward(self, x, c):
293 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
294 | x = modulate(self.norm_final(x), shift, scale)
295 | x = self.linear(x)
296 | return x
297 |
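
# ----------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): a minimal shape check of the modules
# above. The width 1152 and the 2x2 patch over 4 latent channels are illustrative values,
# not configuration read from this repository.
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    B, L, patch_dim, hidden = 4, 64, 2 * 2 * 4, 1152
    x_embedder = PatchEmbedder(input_dim=patch_dim, embed_dim=hidden)
    t_embedder = TimestepEmbedder(hidden_size=hidden)
    y_embedder = LabelEmbedder(num_classes=1000, hidden_size=hidden, dropout_prob=0.1)
    attn = Attention(dim=hidden, num_heads=16, qkv_bias=True)

    tokens = x_embedder(torch.randn(B, L, patch_dim))                        # (B, L, D)
    cond = t_embedder(torch.rand(B)) + y_embedder(torch.randint(0, 1000, (B,)), train=False)
    mask = torch.ones(B, L)                                                  # no padding tokens
    print(tokens.shape, cond.shape, attn(tokens, mask=mask).shape)
    # torch.Size([4, 64, 1152]) torch.Size([4, 1152]) torch.Size([4, 64, 1152])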
--------------------------------------------------------------------------------
/fit/model/norms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | from functools import partial
10 |
11 | import torch
12 | import torch.nn as nn
13 |
16 |
17 |
18 |
19 | def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
20 | """
21 | Creates the specified normalization layer based on the norm_type.
22 |
23 | Args:
24 | norm_type (str): The type of normalization layer to create.
25 |             Supported types: layernorm, w_layernorm, rmsnorm, w_rmsnorm, none
26 | dim (int): The dimension of the normalization layer.
27 | eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
28 |
29 | Returns:
30 | The created normalization layer.
31 |
32 | Raises:
33 | NotImplementedError: If an unknown norm_type is provided.
34 | """
35 | if norm_type == None or norm_type == "":
36 | return nn.Identity()
37 | norm_type = norm_type.lower() # Normalize to lowercase
38 |
39 | if norm_type == "w_layernorm":
40 | return nn.LayerNorm(dim, eps=eps, bias=False)
41 | elif norm_type == "layernorm":
42 | return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
43 | elif norm_type == "w_rmsnorm":
44 | return RMSNorm(dim, eps=eps)
45 | elif norm_type == "rmsnorm":
46 | return RMSNorm(dim, include_weight=False, eps=eps)
47 | elif norm_type == 'none':
48 | return nn.Identity()
49 | else:
50 | raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
51 |
52 |
53 | class RMSNorm(nn.Module):
54 | """
55 | Initialize the RMSNorm normalization layer.
56 |
57 | Args:
58 | dim (int): The dimension of the input tensor.
59 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
60 |
61 | Attributes:
62 | eps (float): A small value added to the denominator for numerical stability.
63 | weight (nn.Parameter): Learnable scaling parameter.
64 |
65 | """
66 |
67 |     def __init__(self, dim: int, eps: float = 1e-6, include_weight: bool = True):
68 |         super().__init__()
69 |         self.eps = eps
70 |         self.weight = nn.Parameter(torch.ones(dim)) if include_weight else None  # no learnable scale for plain 'rmsnorm'
71 | 
72 |     def _norm(self, x: torch.Tensor):
73 |         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
74 | 
75 |     def forward(self, x: torch.Tensor):
76 |         output = self._norm(x.float()).type_as(x)
77 |         return output if self.weight is None else output * self.weight
78 | 
79 |     def reset_parameters(self):
80 |         if self.weight is not None:
81 |             torch.nn.init.ones_(self.weight) # type: ignore
82 |
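
# ----------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): the norm_type strings accepted by
# create_norm, as used by Attention / FiTBlock (e.g. q_norm='layernorm' is promoted to
# 'w_layernorm' when qk_norm_weight=True).
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(2, 16, 64)
    for norm_type in ["layernorm", "w_layernorm", "rmsnorm", "w_rmsnorm", "none"]:
        layer = create_norm(norm_type, dim=64)
        print(f"{norm_type:12s} -> {type(layer).__name__:9s} {tuple(layer(x).shape)}")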
--------------------------------------------------------------------------------
/fit/model/rope.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # FiT: A Flexible Vision Transformer for Image Generation
3 | #
4 | # Based on the following repository
5 | # https://github.com/lucidrains/rotary-embedding-torch
6 | # https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope
7 | # https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37
8 | # --------------------------------------------------------
9 |
10 | import math
11 | from math import pi
12 | from typing import Optional, Any, Union, Tuple
13 | import torch
14 | from torch import nn
15 |
16 | from einops import rearrange, repeat
17 | from functools import lru_cache
18 |
19 | #################################################################################
20 | # NTK Operations #
21 | #################################################################################
22 |
23 | def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
24 | return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations
25 |
26 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
27 | low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
28 | high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
29 | return max(low, 0), min(high, dim-1) #Clamp values just in case
30 |
31 | def linear_ramp_mask(min, max, dim):
32 | if min == max:
33 | max += 0.001 #Prevent singularity
34 |
35 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
36 | ramp_func = torch.clamp(linear_func, 0, 1)
37 | return ramp_func
38 |
39 | def find_newbase_ntk(dim, base=10000, scale=1):
40 | # Base change formula
41 | return base * scale ** (dim / (dim-2))
42 |
43 | def get_mscale(scale: torch.Tensor):
44 | # if scale <= 1:
45 | # return 1.0
46 | # return 0.1 * math.log(scale) + 1.0
47 | return torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
48 |
49 | def get_proportion(L_test, L_train):
50 | L_test = L_test * 2
51 | return torch.where(torch.tensor(L_test/L_train) <= 1., torch.tensor(1.0), torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train))))
52 | # return torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train)))
53 |
54 |
55 |
56 | #################################################################################
57 | # Rotate Q or K #
58 | #################################################################################
59 |
60 | def rotate_half(x):
61 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
62 | x1, x2 = x.unbind(dim = -1)
63 | x = torch.stack((-x2, x1), dim = -1)
64 | return rearrange(x, '... d r -> ... (d r)')
65 |
66 |
67 |
68 | #################################################################################
69 | # Core Vision RoPE #
70 | #################################################################################
71 |
72 | class VisionRotaryEmbedding(nn.Module):
73 | def __init__(
74 | self,
75 | head_dim: int, # embed dimension for each head
76 | custom_freqs: str = 'normal',
77 | theta: int = 10000,
78 | online_rope: bool = False,
79 | max_cached_len: int = 256,
80 | max_pe_len_h: Optional[int] = None,
81 | max_pe_len_w: Optional[int] = None,
82 | decouple: bool = False,
83 | ori_max_pe_len: Optional[int] = None,
84 | ):
85 | super().__init__()
86 |
87 | dim = head_dim // 2
88 |         assert dim % 2 == 0 # actually, this is important
89 |         self.dim = dim
90 |         self.custom_freqs = custom_freqs.lower()
91 |         self.theta = theta
92 |         self.decouple = decouple
93 |         self.ori_max_pe_len = ori_max_pe_len
94 | 
95 | 
96 | if not online_rope:
97 | if self.custom_freqs == 'normal':
98 | freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
99 | freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
100 | else:
101 | if decouple:
102 | freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)
103 | freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)
104 | else:
105 | max_pe_len = max(max_pe_len_h, max_pe_len_w)
106 | freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
107 | freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
108 |
109 | attn_factor = 1.0
110 | scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0) # dynamic scale
111 | self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation
112 | self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
113 | self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)
114 |
115 | self.register_buffer('freqs_h', freqs_h, persistent=False)
116 | self.register_buffer('freqs_w', freqs_w, persistent=False)
117 |
118 | freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h)
119 | freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2)
120 | self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False)
121 | freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w)
122 | freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2)
123 | self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False)
124 |
125 |
126 | def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
127 | # scaling operations for extrapolation
128 | assert isinstance(ori_max_pe_len, int)
129 | # scale = max_pe_len / ori_max_pe_len
130 | if not isinstance(max_pe_len, torch.Tensor):
131 | max_pe_len = torch.tensor(max_pe_len)
132 | scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0) # dynamic scale
133 |
134 | if self.custom_freqs == 'linear': # equal to position interpolation
135 | freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim))
136 | elif self.custom_freqs == 'ntk-aware' or self.custom_freqs == 'ntk-aware-pro1' or self.custom_freqs == 'ntk-aware-pro2':
137 | freqs = 1. / torch.pow(
138 | find_newbase_ntk(dim, theta, scale).view(-1, 1),
139 | (torch.arange(0, dim, 2).to(scale).float() / dim)
140 | ).squeeze()
141 | elif self.custom_freqs == 'ntk-by-parts':
142 | #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
143 | #Do not change unless there is a good reason for doing so!
144 | beta_0 = 1.25
145 | beta_1 = 0.75
146 | gamma_0 = 16
147 | gamma_1 = 2
148 | ntk_factor = 1
149 | extrapolation_factor = 1
150 |
151 | #Three RoPE extrapolation/interpolation methods
152 | freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
153 | freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
154 | freqs_ntk = 1. / torch.pow(
155 | find_newbase_ntk(dim, theta, scale).view(-1, 1),
156 | (torch.arange(0, dim, 2).to(scale).float() / dim)
157 | ).squeeze()
158 |
159 | #Combine NTK and Linear
160 | low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
161 | freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor
162 | freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
163 |
164 | #Combine Extrapolation and NTK and Linear
165 | low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
166 | freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor
167 | freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
168 |
169 | elif self.custom_freqs == 'yarn':
170 | #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
171 | #Do not change unless there is a good reason for doing so!
172 | beta_fast = 32
173 | beta_slow = 1
174 | extrapolation_factor = 1
175 |
176 | freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))
177 | freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
178 |
179 | low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
180 | freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
181 | freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
182 | else:
183 | raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!')
184 | return freqs
185 |
186 |
187 | def online_get_2d_rope_from_grid(self, grid, size):
188 | '''
189 | grid: (B, 2, N)
190 | N = H * W
191 |         the first dimension represents width, and the second represents height
192 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
193 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
194 | size: (B, 1, 2), h goes first and w goes last
195 | '''
196 |         size = size.squeeze(1) # (B, 1, 2) -> (B, 2); squeeze only dim 1 so a batch of size 1 keeps its batch dimension
197 | if self.decouple:
198 | size_h = size[:, 0]
199 | size_w = size[:, 1]
200 | freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)
201 | freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)
202 | else:
203 | size_max = torch.max(size[:, 0], size[:, 1])
204 | freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
205 | freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
206 | freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :]
207 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
208 |
209 | freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :]
210 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
211 |
212 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
213 |
214 | if self.custom_freqs == 'yarn':
215 | freqs_cos = freqs.cos() * self.mscale[:, None, None]
216 | freqs_sin = freqs.sin() * self.mscale[:, None, None]
217 | elif self.custom_freqs == 'ntk-aware-pro1':
218 | freqs_cos = freqs.cos() * self.proportion1[:, None, None]
219 | freqs_sin = freqs.sin() * self.proportion1[:, None, None]
220 | elif self.custom_freqs == 'ntk-aware-pro2':
221 | freqs_cos = freqs.cos() * self.proportion2[:, None, None]
222 | freqs_sin = freqs.sin() * self.proportion2[:, None, None]
223 | else:
224 | freqs_cos = freqs.cos()
225 | freqs_sin = freqs.sin()
226 |
227 | return freqs_cos, freqs_sin
228 |
229 | @lru_cache()
230 | def get_2d_rope_from_grid(self, grid):
231 | '''
232 | grid: (B, 2, N)
233 | N = H * W
234 |         the first dimension represents width, and the second represents height
235 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
236 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
237 | '''
238 | freqs_w = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_w)
239 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
240 |
241 | freqs_h = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_h)
242 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
243 |
244 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
245 |
246 | if self.custom_freqs == 'yarn':
247 | freqs_cos = freqs.cos() * self.mscale
248 | freqs_sin = freqs.sin() * self.mscale
249 | elif self.custom_freqs == 'ntk-aware-pro1':
250 | freqs_cos = freqs.cos() * self.proportion1
251 | freqs_sin = freqs.sin() * self.proportion1
252 | elif self.custom_freqs == 'ntk-aware-pro2':
253 | freqs_cos = freqs.cos() * self.proportion2
254 | freqs_sin = freqs.sin() * self.proportion2
255 | else:
256 | freqs_cos = freqs.cos()
257 | freqs_sin = freqs.sin()
258 |
259 | return freqs_cos, freqs_sin
260 |
261 | @lru_cache()
262 | def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
263 | '''
264 | grid: (B, 2, N)
265 | N = H * W
266 |         the first dimension represents width, and the second represents height
267 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
268 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
269 | '''
270 | freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]], self.freqs_h_cached[grid[:, 1]]
271 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
272 |
273 | if self.custom_freqs == 'yarn':
274 | freqs_cos = freqs.cos() * self.mscale
275 | freqs_sin = freqs.sin() * self.mscale
276 | elif self.custom_freqs == 'ntk-aware-pro1':
277 | freqs_cos = freqs.cos() * self.proportion1
278 | freqs_sin = freqs.sin() * self.proportion1
279 | elif self.custom_freqs == 'ntk-aware-pro2':
280 | freqs_cos = freqs.cos() * self.proportion2
281 | freqs_sin = freqs.sin() * self.proportion2
282 | else:
283 | freqs_cos = freqs.cos()
284 | freqs_sin = freqs.sin()
285 |
286 | return freqs_cos, freqs_sin
287 |
288 | @lru_cache()
289 | def get_cached_21d_rope_from_grid(self, grid: torch.Tensor): # for 3d rope formulation 2 !
290 | '''
291 | grid: (B, 3, N)
292 | N = H * W * T
293 |         the first dimension represents width, the second represents height, and the third represents time
294 | e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
295 | [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
296 | [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
297 | '''
298 | freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]+grid[:, 2]], self.freqs_h_cached[grid[:, 1]+grid[:, 2]]
299 | freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
300 |
301 | if self.custom_freqs == 'yarn':
302 | freqs_cos = freqs.cos() * self.mscale
303 | freqs_sin = freqs.sin() * self.mscale
304 | elif self.custom_freqs == 'ntk-aware-pro1':
305 | freqs_cos = freqs.cos() * self.proportion1
306 | freqs_sin = freqs.sin() * self.proportion1
307 | elif self.custom_freqs == 'ntk-aware-pro2':
308 | freqs_cos = freqs.cos() * self.proportion2
309 | freqs_sin = freqs.sin() * self.proportion2
310 | else:
311 | freqs_cos = freqs.cos()
312 | freqs_sin = freqs.sin()
313 |
314 | return freqs_cos, freqs_sin
315 |
316 | def forward(self, x, grid):
317 | '''
318 | x: (B, n_head, N, D)
319 | grid: (B, 2, N)
320 | '''
321 | # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid)
322 | # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
323 | # using cache to accelerate, this is the same with the above codes:
324 | freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)
325 | freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
326 | return x * freqs_cos + rotate_half(x) * freqs_sin
327 |
328 |
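
# ----------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): applying a 1-D RoPE rotation by hand
# with rotate_half and checking that it preserves the norm of the rotated vectors. The toy
# sizes below are placeholders, not values used elsewhere in this repository.
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    axis_dim = 8                                          # per-axis dim (head_dim // 2 in the class above)
    pos = torch.arange(4).float()                         # 4 token positions
    freqs = 1. / (10000 ** (torch.arange(0, axis_dim, 2).float() / axis_dim))
    angles = repeat(pos[:, None] * freqs[None], '... n -> ... (n r)', r=2)   # (4, axis_dim)
    x = torch.randn(1, 1, 4, axis_dim)                    # (B, n_head, N, D) toy slice
    x_rot = x * angles.cos() + rotate_half(x) * angles.sin()
    print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))     # True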
--------------------------------------------------------------------------------
/fit/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from typing import List, Tuple
4 |
5 |
6 | def modulate(x, shift, scale):
7 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
8 |
9 |
10 | def get_parameter_dtype(parameter: torch.nn.Module):
11 | try:
12 | params = tuple(parameter.parameters())
13 | if len(params) > 0:
14 | return params[0].dtype
15 |
16 | buffers = tuple(parameter.buffers())
17 | if len(buffers) > 0:
18 | return buffers[0].dtype
19 |
20 | except StopIteration:
21 | # For torch.nn.DataParallel compatibility in PyTorch 1.5
22 |
23 | def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
24 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
25 | return tuples
26 |
27 | gen = parameter._named_members(get_members_fn=find_tensor_attributes)
28 | first_tuple = next(gen)
29 | return first_tuple[1].dtype
30 |
31 |
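
# ----------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): modulate() applies a per-sample affine
# transform to every token, broadcasting (B, D) shift/scale over a (B, N, D) sequence.
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(2, 5, 8)                 # (B, N, D)
    shift = torch.zeros(2, 8)
    scale = torch.ones(2, 8)                 # scale=1 -> x * (1 + 1) + 0 = 2x
    print(torch.allclose(modulate(x, shift, scale), 2 * x))   # True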
--------------------------------------------------------------------------------
/fit/scheduler/improved_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
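
# Editorial usage sketch (comment only; the call sites are an assumption, not shown in this
# file): training typically keeps the full schedule, while sampling respaces it, e.g.
#   diffusion = create_diffusion(timestep_respacing="")         # full 1000-step schedule
#   sampler   = create_diffusion(timestep_respacing="ddim250")  # 250 DDIM-strided steps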
--------------------------------------------------------------------------------
/fit/scheduler/improved_diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
90 |
91 | def get_flexible_mask_and_ratio(model_kwargs: dict, x: th.Tensor):
92 | '''
93 | sequential case (fit):
94 | x: (B, C, N)
95 | model_kwargs: {y: (B,), mask: (B, N), grid: (B, 2, N)}
96 | mask: (B, N) -> (B, 1, N)
97 | spatial case (dit):
98 | x: (B, C, H, W)
99 | model_kwargs: {y: (B,)}
100 | mask: (B, C) -> (B, C, 1, 1)
101 | '''
102 | mask = model_kwargs.get('mask', th.ones(x.shape[:2])) # (B, N) or (B, C)
103 | ratio = float(mask.shape[-1]) / th.count_nonzero(mask, dim=-1) # (B,)
104 | if len(x.shape) == 3: # sequential x: (B, C, N)
105 | mask = mask[:, None, :] # (B, N) -> (B, 1, N)
106 | elif len(x.shape) == 4: # spatial x: (B, C, H, W)
107 | mask = mask[..., None, None] # (B, C) -> (B, C, 1, 1)
108 | else:
109 | raise NotImplementedError
110 | return mask.to(x), ratio.to(x)
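
# ----------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): the sequential (FiT) branch of
# get_flexible_mask_and_ratio on a toy batch where the second sequence has 2 padded tokens.
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    x = th.randn(2, 4, 4)                                  # (B, C, N)
    mask = th.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
    m, ratio = get_flexible_mask_and_ratio({"mask": mask}, x)
    print(m.shape)   # torch.Size([2, 1, 4])
    print(ratio)     # tensor([1., 2.]) -> the padded sample's mean loss is rescaled by 2x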
--------------------------------------------------------------------------------
/fit/scheduler/improved_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
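# Editorial example (comment only): space_timesteps(1000, "ddim250") returns
# set(range(0, 1000, 4)), i.e. 250 evenly strided steps taken from the original 1000.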
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
/fit/scheduler/improved_diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
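
# ----------------------------------------------------------------------------------------
# Editorial sketch (not part of the original file): UniformSampler only reads num_timesteps
# from the diffusion object, so a tiny stand-in is used here instead of a real diffusion.
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyDiffusion:
        num_timesteps = 1000

    sampler = create_named_schedule_sampler("uniform", _ToyDiffusion())
    ts, weights = sampler.sample(batch_size=8, device="cpu")
    print(ts.shape, weights.unique())   # torch.Size([8]) tensor([1.])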
--------------------------------------------------------------------------------
/fit/scheduler/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import Transport, ModelType, WeightType, PathType, SNRType, Sampler
2 |
3 | def create_transport(
4 | path_type='Linear',
5 | prediction="velocity",
6 | loss_weight=None,
7 | train_eps=None,
8 | sample_eps=None,
9 | snr_type='uniform',
10 | ):
11 | """function for creating Transport object
12 | **Note**: model prediction defaults to velocity
13 | Args:
14 |     - path_type: type of path to use; default to linear
15 |     - prediction: model prediction type; "velocity" (default), "noise" or "score"
16 |     - loss_weight: loss weighting; None (default), "velocity" or "likelihood"
17 |     - train_eps: small epsilon for avoiding instability during training
18 |     - sample_eps: small epsilon for avoiding instability during sampling
19 |     - snr_type: distribution for sampling timesteps; "uniform" (default) or "lognorm"
20 |     """
22 |
23 | if prediction == "noise":
24 | model_type = ModelType.NOISE
25 | elif prediction == "score":
26 | model_type = ModelType.SCORE
27 | else:
28 | model_type = ModelType.VELOCITY
29 |
30 | if loss_weight == "velocity":
31 | loss_type = WeightType.VELOCITY
32 | elif loss_weight == "likelihood":
33 | loss_type = WeightType.LIKELIHOOD
34 | else:
35 | loss_type = WeightType.NONE
36 |
37 | if snr_type == "lognorm":
38 | snr_type = SNRType.LOGNORM
39 | elif snr_type == "uniform":
40 | snr_type = SNRType.UNIFORM
41 | else:
42 | raise ValueError(f"Invalid snr type {snr_type}")
43 |
44 | path_choice = {
45 | "Linear": PathType.LINEAR,
46 | "GVP": PathType.GVP,
47 | "VP": PathType.VP,
48 | }
49 |
50 | path_type = path_choice[path_type]
51 |
52 | if (path_type in [PathType.VP]):
53 | train_eps = 1e-5 if train_eps is None else train_eps
54 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
55 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
56 | train_eps = 1e-3 if train_eps is None else train_eps
57 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
58 | else: # velocity & [GVP, LINEAR] is stable everywhere
59 | train_eps = 0
60 | sample_eps = 0
61 |
62 | # create flow state
63 | state = Transport(
64 | model_type=model_type,
65 | path_type=path_type,
66 | loss_type=loss_type,
67 | train_eps=train_eps,
68 | sample_eps=sample_eps,
69 | snr_type=snr_type,
70 | )
71 |
72 | return state
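
# Editorial usage sketch (comment only): a caller might build the flow-matching scheduler as
#   transport = create_transport(path_type="Linear", prediction="velocity", snr_type="lognorm")
# which, per the branches above, gives velocity prediction on a linear interpolant with
# train_eps = sample_eps = 0.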
--------------------------------------------------------------------------------
/fit/scheduler/transport/integrators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | import torch.nn as nn
4 | from torchdiffeq import odeint
5 | from functools import partial
6 | from tqdm import tqdm
7 |
8 | class sde:
9 | """SDE solver class"""
10 | def __init__(
11 | self,
12 | drift,
13 | diffusion,
14 | *,
15 | t0,
16 | t1,
17 | num_steps,
18 | sampler_type,
19 | ):
20 | assert t0 < t1, "SDE sampler has to be in forward time"
21 |
22 | self.num_timesteps = num_steps
23 | self.t = th.linspace(t0, t1, num_steps)
24 | self.dt = self.t[1] - self.t[0]
25 | self.drift = drift
26 | self.diffusion = diffusion
27 | self.sampler_type = sampler_type
28 |
29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
30 | w_cur = th.randn(x.size()).to(x)
31 | t = th.ones(x.size(0)).to(x) * t
32 | dw = w_cur * th.sqrt(self.dt)
33 | drift = self.drift(x, t, model, **model_kwargs)
34 | diffusion = self.diffusion(x, t)
35 | mean_x = x + drift * self.dt
36 | x = mean_x + th.sqrt(2 * diffusion) * dw
37 | return x, mean_x
38 |
39 | def __Heun_step(self, x, _, t, model, **model_kwargs):
40 | w_cur = th.randn(x.size()).to(x)
41 | dw = w_cur * th.sqrt(self.dt)
42 | t_cur = th.ones(x.size(0)).to(x) * t
43 | diffusion = self.diffusion(x, t_cur)
44 | xhat = x + th.sqrt(2 * diffusion) * dw
45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs)
46 | xp = xhat + self.dt * K1
47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
48 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
49 |
50 | def __forward_fn(self):
51 | """TODO: generalize here by adding all private functions ending with steps to it"""
52 | sampler_dict = {
53 | "Euler": self.__Euler_Maruyama_step,
54 | "Heun": self.__Heun_step,
55 | }
56 |
57 | try:
58 | sampler = sampler_dict[self.sampler_type]
59 |         except KeyError:
60 |             raise NotImplementedError("Sampler type not implemented.")
61 |
62 | return sampler
63 |
64 | def sample(self, init, model, **model_kwargs):
65 | """forward loop of sde"""
66 | x = init
67 | mean_x = init
68 | samples = []
69 | sampler = self.__forward_fn()
70 | for ti in self.t[:-1]:
71 | with th.no_grad():
72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
73 | samples.append(x)
74 |
75 | return samples
76 |
77 | class ode:
78 | """ODE solver class"""
79 | def __init__(
80 | self,
81 | drift,
82 | *,
83 | t0,
84 | t1,
85 | sampler_type,
86 | num_steps,
87 | atol,
88 | rtol,
89 | ):
90 | assert t0 < t1, "ODE sampler has to be in forward time"
91 |
92 | self.drift = drift
93 | self.t = th.linspace(t0, t1, num_steps)
94 | self.atol = atol
95 | self.rtol = rtol
96 | self.sampler_type = sampler_type
97 |
98 | def sample(self, x, model, **model_kwargs):
99 |
100 | device = x[0].device if isinstance(x, tuple) else x.device
101 | def _fn(t, x):
102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
103 | model_output = self.drift(x, t, model, **model_kwargs)
104 | return model_output
105 |
106 | t = self.t.to(device)
107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
109 | samples = odeint(
110 | _fn,
111 | x,
112 | t,
113 | method=self.sampler_type,
114 | atol=atol,
115 | rtol=rtol
116 | )
117 | return samples
--------------------------------------------------------------------------------
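A minimal sketch of how the ode wrapper above can be driven, assuming a toy drift closure and a toy model; in the repository the drift comes from Transport.get_drift() and the model is the FiT network.

# Hedged sketch: the drift receives (x, t, model, **kwargs) and returns dx/dt.
import torch as th
from fit.scheduler.transport.integrators import ode

def toy_drift(x, t, model, **kwargs):
    # velocity parametrization: the probability-flow drift is the model output
    return model(x, t, **kwargs)

toy_model = lambda x, t: -x  # hypothetical velocity field pulling samples toward zero

solver = ode(drift=toy_drift, t0=0.0, t1=1.0,
             sampler_type="dopri5", num_steps=8, atol=1e-6, rtol=1e-3)
x0 = th.randn(2, 3)
trajectory = solver.sample(x0, toy_model)  # (num_steps, 2, 3): saved states along the ODE path

--------------------------------------------------------------------------------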
/fit/scheduler/transport/path.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | from functools import partial
4 |
5 | def expand_t_like_x(t, x):
6 | """Function to reshape time t to broadcastable dimension of x
7 | Args:
8 | t: [batch_dim,], time vector
9 | x: [batch_dim,...], data point
10 | """
11 | dims = [1] * (len(x.size()) - 1)
12 | t = t.view(t.size(0), *dims)
13 | return t
14 |
15 |
16 | #################### Coupling Plans ####################
17 |
18 | class ICPlan:
19 | """Linear Coupling Plan"""
20 | def __init__(self, sigma=0.0):
21 | self.sigma = sigma
22 |
23 | def compute_alpha_t(self, t):
24 | """Compute the data coefficient along the path"""
25 | return t, 1
26 |
27 | def compute_sigma_t(self, t):
28 | """Compute the noise coefficient along the path"""
29 | return 1 - t, -1
30 |
31 | def compute_d_alpha_alpha_ratio_t(self, t):
32 | """Compute the ratio between d_alpha and alpha"""
33 | return 1 / t
34 |
35 | def compute_drift(self, x, t):
36 |         """We always output the SDE drift and diffusion according to the score parametrization."""
37 | t = expand_t_like_x(t, x)
38 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
39 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
40 | drift = alpha_ratio * x
41 | diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
42 |
43 | return -drift, diffusion
44 |
45 | def compute_diffusion(self, x, t, form="constant", norm=1.0):
46 | """Compute the diffusion term of the SDE
47 | Args:
48 | x: [batch_dim, ...], data point
49 | t: [batch_dim,], time vector
50 | form: str, form of the diffusion term
51 | norm: float, norm of the diffusion term
52 | """
53 | t = expand_t_like_x(t, x)
54 | choices = {
55 | "constant": th.tensor(norm).to(x),
56 | "SBDM": norm * self.compute_drift(x, t)[1],
57 | "sigma": norm * self.compute_sigma_t(t)[0],
58 | "linear": norm * (1 - t),
59 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
60 | "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
61 | }
62 |
63 | try:
64 | diffusion = choices[form]
65 | except KeyError:
66 | raise NotImplementedError(f"Diffusion form {form} not implemented")
67 |
68 | return diffusion
69 |
70 | def get_score_from_velocity(self, velocity, x, t):
71 |         """Wrapper function: transform velocity prediction model to score
72 | Args:
73 | velocity: [batch_dim, ...] shaped tensor; velocity model output
74 | x: [batch_dim, ...] shaped tensor; x_t data point
75 | t: [batch_dim,] time tensor
76 | """
77 | t = expand_t_like_x(t, x)
78 | alpha_t, d_alpha_t = self.compute_alpha_t(t)
79 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
80 | mean = x
81 | reverse_alpha_ratio = alpha_t / d_alpha_t
82 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
83 | score = (reverse_alpha_ratio * velocity - mean) / var
84 | return score
85 |
86 | def get_noise_from_velocity(self, velocity, x, t):
87 |         """Wrapper function: transform velocity prediction model to denoiser
88 | Args:
89 | velocity: [batch_dim, ...] shaped tensor; velocity model output
90 | x: [batch_dim, ...] shaped tensor; x_t data point
91 | t: [batch_dim,] time tensor
92 | """
93 | t = expand_t_like_x(t, x)
94 | alpha_t, d_alpha_t = self.compute_alpha_t(t)
95 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
96 | mean = x
97 | reverse_alpha_ratio = alpha_t / d_alpha_t
98 | var = reverse_alpha_ratio * d_sigma_t - sigma_t
99 | noise = (reverse_alpha_ratio * velocity - mean) / var
100 | return noise
101 |
102 | def get_velocity_from_score(self, score, x, t):
103 |         """Wrapper function: transform score prediction model to velocity
104 | Args:
105 | score: [batch_dim, ...] shaped tensor; score model output
106 | x: [batch_dim, ...] shaped tensor; x_t data point
107 | t: [batch_dim,] time tensor
108 | """
109 | t = expand_t_like_x(t, x)
110 | drift, var = self.compute_drift(x, t)
111 | velocity = var * score - drift
112 | return velocity
113 |
114 | def compute_mu_t(self, t, x0, x1):
115 | """Compute the mean of time-dependent density p_t"""
116 | t = expand_t_like_x(t, x1)
117 | alpha_t, _ = self.compute_alpha_t(t)
118 | sigma_t, _ = self.compute_sigma_t(t)
119 | return alpha_t * x1 + sigma_t * x0
120 |
121 | def compute_xt(self, t, x0, x1):
122 | """Sample xt from time-dependent density p_t; rng is required"""
123 | xt = self.compute_mu_t(t, x0, x1)
124 | return xt
125 |
126 | def compute_ut(self, t, x0, x1, xt):
127 | """Compute the vector field corresponding to p_t"""
128 | t = expand_t_like_x(t, x1)
129 | _, d_alpha_t = self.compute_alpha_t(t)
130 | _, d_sigma_t = self.compute_sigma_t(t)
131 | return d_alpha_t * x1 + d_sigma_t * x0
132 |
133 | def plan(self, t, x0, x1):
134 | xt = self.compute_xt(t, x0, x1)
135 | ut = self.compute_ut(t, x0, x1, xt)
136 | return t, xt, ut
137 |
138 |
139 | class VPCPlan(ICPlan):
140 | """class for VP path flow matching"""
141 |
142 | def __init__(self, sigma_min=0.1, sigma_max=20.0):
143 | self.sigma_min = sigma_min
144 | self.sigma_max = sigma_max
145 | self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
146 | self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
147 |
148 |
149 | def compute_alpha_t(self, t):
150 | """Compute coefficient of x1"""
151 | alpha_t = self.log_mean_coeff(t)
152 | alpha_t = th.exp(alpha_t)
153 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
154 | return alpha_t, d_alpha_t
155 |
156 | def compute_sigma_t(self, t):
157 | """Compute coefficient of x0"""
158 | p_sigma_t = 2 * self.log_mean_coeff(t)
159 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
160 | d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
161 | return sigma_t, d_sigma_t
162 |
163 | def compute_d_alpha_alpha_ratio_t(self, t):
164 |         """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
165 | return self.d_log_mean_coeff(t)
166 |
167 | def compute_drift(self, x, t):
168 | """Compute the drift term of the SDE"""
169 | t = expand_t_like_x(t, x)
170 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
171 | return -0.5 * beta_t * x, beta_t / 2
172 |
173 |
174 | class GVPCPlan(ICPlan):
175 | def __init__(self, sigma=0.0):
176 | super().__init__(sigma)
177 |
178 | def compute_alpha_t(self, t):
179 | """Compute coefficient of x1"""
180 | alpha_t = th.sin(t * np.pi / 2)
181 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
182 | return alpha_t, d_alpha_t
183 |
184 | def compute_sigma_t(self, t):
185 | """Compute coefficient of x0"""
186 | sigma_t = th.cos(t * np.pi / 2)
187 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
188 | return sigma_t, d_sigma_t
189 |
190 | def compute_d_alpha_alpha_ratio_t(self, t):
191 |         """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
192 | return np.pi / (2 * th.tan(t * np.pi / 2))
--------------------------------------------------------------------------------
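A small worked check of the linear coupling plan: for ICPlan, alpha_t = t and sigma_t = 1 - t, so x_t = t * x1 + (1 - t) * x0 and the target vector field is u_t = x1 - x0. The shapes below are illustrative.

# Hedged sketch verifying the linear plan against its closed form.
import torch as th
from fit.scheduler.transport.path import ICPlan

plan = ICPlan()
x0 = th.randn(4, 8)   # noise sample
x1 = th.randn(4, 8)   # data sample
t = th.rand(4)        # one time per batch element

t, xt, ut = plan.plan(t, x0, x1)
t_b = t.view(-1, 1)   # broadcast time over the feature dimension
assert th.allclose(xt, t_b * x1 + (1 - t_b) * x0)
assert th.allclose(ut, x1 - x0)

--------------------------------------------------------------------------------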
/fit/scheduler/transport/transport.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | import logging
4 |
5 | import enum
6 |
7 | from . import path
8 | from .utils import EasyDict, log_state, mean_flat, get_flexible_mask_and_ratio
9 | from .integrators import ode, sde
10 |
11 | class ModelType(enum.Enum):
12 | """
13 | Which type of output the model predicts.
14 | """
15 |
16 | NOISE = enum.auto() # the model predicts epsilon
17 | SCORE = enum.auto() # the model predicts \nabla \log p(x)
18 | VELOCITY = enum.auto() # the model predicts v(x)
19 |
20 | class PathType(enum.Enum):
21 | """
22 | Which type of path to use.
23 | """
24 |
25 | LINEAR = enum.auto()
26 | GVP = enum.auto()
27 | VP = enum.auto()
28 |
29 | class WeightType(enum.Enum):
30 | """
31 | Which type of weighting to use.
32 | """
33 |
34 | NONE = enum.auto()
35 | VELOCITY = enum.auto()
36 | LIKELIHOOD = enum.auto()
37 |
38 |
39 | class SNRType(enum.Enum):
40 | UNIFORM = enum.auto()
41 | LOGNORM = enum.auto()
42 |
43 |
44 | class Transport:
45 |
46 | def __init__(
47 | self,
48 | *,
49 | model_type,
50 | path_type,
51 | loss_type,
52 | train_eps,
53 | sample_eps,
54 | snr_type
55 | ):
56 | path_options = {
57 | PathType.LINEAR: path.ICPlan,
58 | PathType.GVP: path.GVPCPlan,
59 | PathType.VP: path.VPCPlan,
60 | }
61 |
62 | self.loss_type = loss_type
63 | self.model_type = model_type
64 | self.path_sampler = path_options[path_type]()
65 | self.train_eps = train_eps
66 | self.sample_eps = sample_eps
67 |
68 | self.snr_type = snr_type
69 |
70 | def prior_logp(self, z):
71 | '''
72 | Standard multivariate normal prior
73 | Assume z is batched
74 | '''
75 | shape = th.tensor(z.size())
76 | N = th.prod(shape[1:])
77 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
78 | return th.vmap(_fn)(z)
79 |
80 |
81 | def check_interval(
82 | self,
83 | train_eps,
84 | sample_eps,
85 | *,
86 | diffusion_form="SBDM",
87 | sde=False,
88 | reverse=False,
89 | eval=False,
90 | last_step_size=0.0,
91 | ):
92 | t0 = 0
93 | t1 = 1
94 | eps = train_eps if not eval else sample_eps
95 | if (type(self.path_sampler) in [path.VPCPlan]):
96 |
97 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
98 |
99 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
100 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
101 |
102 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
103 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
104 |
105 | if reverse:
106 | t0, t1 = 1 - t0, 1 - t1
107 |
108 | return t0, t1
109 |
110 |
111 | def sample(self, x1):
112 | """Sampling x0 & t based on shape of x1 (if needed)
113 | Args:
114 | x1 - data point; [batch, *dim]
115 | """
116 |
117 | x0 = th.randn_like(x1)
118 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
119 |
120 | if self.snr_type == SNRType.UNIFORM:
121 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
122 | elif self.snr_type == SNRType.LOGNORM:
123 | u = th.normal(mean=0.0, std=1.0, size=(x1.shape[0],))
124 | t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
125 | else:
126 | raise ValueError(f"Unknown snr type: {self.snr_type}")
127 |
128 | t = t.to(x1)
129 | return t, x0, x1
130 |
131 |
132 | def training_losses(
133 | self,
134 | model,
135 | x1,
136 | model_kwargs=None
137 | ):
138 | """Loss for training the score model
139 | Args:
140 | - model: backbone model; could be score, noise, or velocity
141 | - x1: datapoint
142 | - model_kwargs: additional arguments for the model
143 | """
144 |         if model_kwargs is None:
145 | model_kwargs = {}
146 |
147 | t, x0, x1 = self.sample(x1)
148 | t, xt, ut = self.path_sampler.plan(t, x0, x1)
149 | model_output = model(xt, t, **model_kwargs)
150 | B, *_, C = xt.shape
151 | assert model_output.size() == (B, *xt.size()[1:-1], C)
152 | mask, ratio = get_flexible_mask_and_ratio(model_kwargs, x1)
153 |
154 | terms = {}
155 | terms['pred'] = model_output
156 | if self.model_type == ModelType.VELOCITY:
157 | terms['loss'] = mean_flat((((model_output - ut)*mask) ** 2)) * ratio
158 | else:
159 | _, drift_var = self.path_sampler.compute_drift(xt, t)
160 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
161 | if self.loss_type in [WeightType.VELOCITY]:
162 | weight = (drift_var / sigma_t) ** 2
163 | elif self.loss_type in [WeightType.LIKELIHOOD]:
164 | weight = drift_var / (sigma_t ** 2)
165 | elif self.loss_type in [WeightType.NONE]:
166 | weight = 1
167 | else:
168 | raise NotImplementedError()
169 |
170 | if self.model_type == ModelType.NOISE:
171 | terms['loss'] = mean_flat(weight * (((model_output - x0)*mask) ** 2)) * ratio
172 | else:
173 | terms['loss'] = mean_flat(weight * (((model_output * sigma_t + x0)*mask) ** 2)) * ratio
174 |
175 | return terms
176 |
177 |
178 | def get_drift(
179 | self
180 | ):
181 | """member function for obtaining the drift of the probability flow ODE"""
182 | def score_ode(x, t, model, **model_kwargs):
183 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
184 | model_output = model(x, t, **model_kwargs)
185 | return (-drift_mean + drift_var * model_output) # by change of variable
186 |
187 | def noise_ode(x, t, model, **model_kwargs):
188 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
189 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
190 | model_output = model(x, t, **model_kwargs)
191 | score = model_output / -sigma_t
192 | return (-drift_mean + drift_var * score)
193 |
194 | def velocity_ode(x, t, model, **model_kwargs):
195 | model_output = model(x, t, **model_kwargs)
196 | return model_output
197 |
198 | if self.model_type == ModelType.NOISE:
199 | drift_fn = noise_ode
200 | elif self.model_type == ModelType.SCORE:
201 | drift_fn = score_ode
202 | else:
203 | drift_fn = velocity_ode
204 |
205 | def body_fn(x, t, model, **model_kwargs):
206 | model_output = drift_fn(x, t, model, **model_kwargs)
207 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
208 | return model_output
209 |
210 | return body_fn
211 |
212 |
213 | def get_score(
214 | self,
215 | ):
216 | """member function for obtaining score of
217 | x_t = alpha_t * x + sigma_t * eps"""
218 | if self.model_type == ModelType.NOISE:
219 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
220 | elif self.model_type == ModelType.SCORE:
221 |             score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
222 | elif self.model_type == ModelType.VELOCITY:
223 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
224 | else:
225 | raise NotImplementedError()
226 |
227 | return score_fn
228 |
229 |
230 | class Sampler:
231 | """Sampler class for the transport model"""
232 | def __init__(
233 | self,
234 | transport,
235 | ):
236 | """Constructor for a general sampler; supporting different sampling methods
237 | Args:
238 |         - transport: a transport object specifying model prediction & interpolant type
239 | """
240 |
241 | self.transport = transport
242 | self.drift = self.transport.get_drift()
243 | self.score = self.transport.get_score()
244 |
245 | def __get_sde_diffusion_and_drift(
246 | self,
247 | *,
248 | diffusion_form="SBDM",
249 | diffusion_norm=1.0,
250 | ):
251 |
252 | def diffusion_fn(x, t):
253 | diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
254 | return diffusion
255 |
256 | sde_drift = \
257 | lambda x, t, model, **kwargs: \
258 | self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
259 |
260 | sde_diffusion = diffusion_fn
261 |
262 | return sde_drift, sde_diffusion
263 |
264 | def __get_last_step(
265 | self,
266 | sde_drift,
267 | *,
268 | last_step,
269 | last_step_size,
270 | ):
271 | """Get the last step function of the SDE solver"""
272 |
273 | if last_step is None:
274 | last_step_fn = \
275 | lambda x, t, model, **model_kwargs: \
276 | x
277 | elif last_step == "Mean":
278 | last_step_fn = \
279 | lambda x, t, model, **model_kwargs: \
280 | x + sde_drift(x, t, model, **model_kwargs) * last_step_size
281 | elif last_step == "Tweedie":
282 | alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
283 | sigma = self.transport.path_sampler.compute_sigma_t
284 | last_step_fn = \
285 | lambda x, t, model, **model_kwargs: \
286 | x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
287 | elif last_step == "Euler":
288 | last_step_fn = \
289 | lambda x, t, model, **model_kwargs: \
290 | x + self.drift(x, t, model, **model_kwargs) * last_step_size
291 | else:
292 | raise NotImplementedError()
293 |
294 | return last_step_fn
295 |
296 | def sample_sde(
297 | self,
298 | *,
299 | sampling_method="Euler",
300 | diffusion_form="SBDM",
301 | diffusion_norm=1.0,
302 | last_step="Mean",
303 | last_step_size=0.04,
304 | num_steps=250,
305 | ):
306 | """returns a sampling function with given SDE settings
307 | Args:
308 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
309 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
310 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1
311 | - last_step: type of the last step; default to identity
312 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
313 | - num_steps: total integration step of SDE
314 | """
315 |
316 | if last_step is None:
317 | last_step_size = 0.0
318 |
319 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
320 | diffusion_form=diffusion_form,
321 | diffusion_norm=diffusion_norm,
322 | )
323 |
324 | t0, t1 = self.transport.check_interval(
325 | self.transport.train_eps,
326 | self.transport.sample_eps,
327 | diffusion_form=diffusion_form,
328 | sde=True,
329 | eval=True,
330 | reverse=False,
331 | last_step_size=last_step_size,
332 | )
333 |
334 | _sde = sde(
335 | sde_drift,
336 | sde_diffusion,
337 | t0=t0,
338 | t1=t1,
339 | num_steps=num_steps,
340 | sampler_type=sampling_method
341 | )
342 |
343 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
344 |
345 |
346 | def _sample(init, model, **model_kwargs):
347 | xs = _sde.sample(init, model, **model_kwargs)
348 | ts = th.ones(init.size(0), device=init.device) * t1
349 | x = last_step_fn(xs[-1], ts, model, **model_kwargs)
350 | xs.append(x)
351 |
352 |             assert len(xs) == num_steps, "Number of samples does not match the number of steps"
353 |
354 | return xs
355 |
356 | return _sample
357 |
358 | def sample_ode(
359 | self,
360 | *,
361 | sampling_method="dopri5",
362 | num_steps=50,
363 | atol=1e-6,
364 | rtol=1e-3,
365 | reverse=False,
366 | ):
367 | """returns a sampling function with given ODE settings
368 | Args:
369 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
370 | - num_steps:
371 | - fixed solver (Euler, Heun): the actual number of integration steps performed
372 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
373 | - atol: absolute error tolerance for the solver
374 | - rtol: relative error tolerance for the solver
375 | - reverse: whether solving the ODE in reverse (data to noise); default to False
376 | """
377 | if reverse:
378 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
379 | else:
380 | drift = self.drift
381 |
382 | t0, t1 = self.transport.check_interval(
383 | self.transport.train_eps,
384 | self.transport.sample_eps,
385 | sde=False,
386 | eval=True,
387 | reverse=reverse,
388 | last_step_size=0.0,
389 | )
390 |
391 | _ode = ode(
392 | drift=drift,
393 | t0=t0,
394 | t1=t1,
395 | sampler_type=sampling_method,
396 | num_steps=num_steps,
397 | atol=atol,
398 | rtol=rtol,
399 | )
400 |
401 | return _ode.sample
402 |
403 | def sample_ode_likelihood(
404 | self,
405 | *,
406 | sampling_method="dopri5",
407 | num_steps=50,
408 | atol=1e-6,
409 | rtol=1e-3,
410 | ):
411 |
412 | """returns a sampling function for calculating likelihood with given ODE settings
413 | Args:
414 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
415 | - num_steps:
416 | - fixed solver (Euler, Heun): the actual number of integration steps performed
417 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
418 | - atol: absolute error tolerance for the solver
419 | - rtol: relative error tolerance for the solver
420 | """
421 | def _likelihood_drift(x, t, model, **model_kwargs):
422 | x, _ = x
423 | eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
424 | t = th.ones_like(t) * (1 - t)
425 | with th.enable_grad():
426 | x.requires_grad = True
427 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
428 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
429 | drift = self.drift(x, t, model, **model_kwargs)
430 | return (-drift, logp_grad)
431 |
432 | t0, t1 = self.transport.check_interval(
433 | self.transport.train_eps,
434 | self.transport.sample_eps,
435 | sde=False,
436 | eval=True,
437 | reverse=False,
438 | last_step_size=0.0,
439 | )
440 |
441 | _ode = ode(
442 | drift=_likelihood_drift,
443 | t0=t0,
444 | t1=t1,
445 | sampler_type=sampling_method,
446 | num_steps=num_steps,
447 | atol=atol,
448 | rtol=rtol,
449 | )
450 |
451 | def _sample_fn(x, model, **model_kwargs):
452 | init_logp = th.zeros(x.size(0)).to(x)
453 | input = (x, init_logp)
454 | drift, delta_logp = _ode.sample(input, model, **model_kwargs)
455 | drift, delta_logp = drift[-1], delta_logp[-1]
456 | prior_logp = self.transport.prior_logp(drift)
457 | logp = prior_logp - delta_logp
458 | return logp, drift
459 |
460 | return _sample_fn
--------------------------------------------------------------------------------
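A minimal sketch wiring Transport and Sampler together for ODE sampling, assuming a dummy velocity function and an illustrative latent shape; in the sampling scripts the model is the FiT network and model_kwargs carries y, grid, mask and size.

# Hedged sketch: velocity model + linear path, sampled with the ODE solver.
import torch as th
from fit.scheduler.transport import create_transport, Sampler

transport = create_transport(path_type="Linear", prediction="velocity")
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=16)

def dummy_velocity(x, t, **kwargs):  # hypothetical stand-in for model.forward
    return -x

z = th.randn(2, 16, 4)               # noise in latent-token layout (B, N, C)
xs = sample_fn(z, dummy_velocity)    # (num_steps, B, N, C); xs[-1] is the final sample

--------------------------------------------------------------------------------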
/fit/scheduler/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | class EasyDict:
4 |
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 | def mean_flat(x):
13 | """
14 | Take the mean over all non-batch dimensions.
15 | """
16 | return th.mean(x, dim=list(range(1, len(x.size()))))
17 |
18 | def log_state(state):
19 | result = []
20 |
21 | sorted_state = dict(sorted(state.items()))
22 | for key, value in sorted_state.items():
23 | # Check if the value is an instance of a class
24 |         if "<object" in str(value) or "object at" in str(value):
25 |             result.append(f"{key}: [{value.__class__.__name__}]")
26 |         else:
27 |             result.append(f"{key}: {value}")
28 | 
29 |     return '\n'.join(result)
30 | 
31 | 
32 | def get_flexible_mask_and_ratio(model_kwargs, x):
33 |     '''
34 |     Build a broadcastable token mask and a per-sample loss-rescaling ratio
35 |     from model_kwargs, matching the layout of x.
36 | 
37 |     sequential case (fit):
38 |         x: (B, N, C)
39 |         model_kwargs: {y: (B,), grid: (B, 2, N), mask: (B, N), size: (B, 1, 2)}
40 |         mask: (B, N) -> (B, N, 1)
41 | 
42 |     spatial case (dit):
43 |         x: (B, C, H, W)
44 |         model_kwargs: {y: (B,)}
45 |         mask: (B, C) -> (B, C, 1, 1)
46 |     '''
47 | mask = model_kwargs.get('mask', th.ones(x.shape[:2])) # (B, N) or (B, C)
48 | ratio = float(mask.shape[-1]) / th.count_nonzero(mask, dim=-1) # (B,)
49 | if len(x.shape) == 3: # sequential x: (B, N, C)
50 | mask = mask[..., None] # (B, N) -> (B, N, 1)
51 | elif len(x.shape) == 4: # spatial x: (B, C, H, W)
52 | mask = mask[..., None, None] # (B, C) -> (B, C, 1, 1)
53 | else:
54 | raise NotImplementedError
55 | return mask.to(x), ratio.to(x)
56 |
--------------------------------------------------------------------------------
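A small example of how the flexible mask rescales the loss for padded tokens, assuming illustrative mask values: padded positions are zeroed out, and the per-sample ratio N / num_valid_tokens compensates for the smaller number of contributing terms.

import torch as th
from fit.scheduler.transport.utils import get_flexible_mask_and_ratio, mean_flat

x = th.randn(2, 4, 3)                        # (B, N, C) sequential case
mask = th.tensor([[1., 1., 1., 1.],
                  [1., 1., 0., 0.]])         # second sample has two padded tokens
mask_b, ratio = get_flexible_mask_and_ratio({"mask": mask}, x)
print(mask_b.shape)   # torch.Size([2, 4, 1]) -> broadcasts over channels
print(ratio)          # tensor([1., 2.])     -> N / num_valid_tokens per sample
loss = mean_flat((x * mask_b) ** 2) * ratio  # masked MSE against a zero target (illustrative)

--------------------------------------------------------------------------------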
/fit/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 | import re
6 | import os
7 |
8 | from safetensors.torch import load_file
9 |
10 |
11 | def create_npz_from_sample_folder(sample_dir, num=50_000):
12 | """
13 | Builds a single .npz file from a folder of .png samples.
14 | """
15 | samples = []
16 | imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split('.')[0]))
17 | print(len(imgs))
18 | assert len(imgs) >= num
19 | for i in tqdm(range(num), desc="Building .npz file from samples"):
20 | sample_pil = Image.open(f"{sample_dir}/{imgs[i]}")
21 | sample_np = np.asarray(sample_pil).astype(np.uint8)
22 | samples.append(sample_np)
23 | samples = np.stack(samples)
24 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
25 | npz_path = f"{sample_dir}.npz"
26 | np.savez(npz_path, arr_0=samples)
27 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
28 | return npz_path
29 |
30 | def init_from_ckpt(
31 | model, checkpoint_dir, ignore_keys=None, verbose=False
32 | ) -> None:
33 | if checkpoint_dir.endswith(".safetensors"):
34 | try:
35 | model_state_dict=load_file(checkpoint_dir)
36 |         except: # legacy compatibility fallback, do not delete
37 | model_state_dict=torch.load(checkpoint_dir, map_location="cpu")
38 | else:
39 | model_state_dict=torch.load(checkpoint_dir, map_location="cpu")
40 | model_new_ckpt=dict()
41 | for i in model_state_dict.keys():
42 | model_new_ckpt[i] = model_state_dict[i]
43 | keys = list(model_new_ckpt.keys())
44 | for k in keys:
45 | if ignore_keys:
46 | for ik in ignore_keys:
47 | if re.match(ik, k):
48 | print("Deleting key {} from state_dict.".format(k))
49 | del model_new_ckpt[k]
50 | missing, unexpected = model.load_state_dict(model_new_ckpt, strict=False)
51 | if verbose:
52 | print(
53 | f"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys"
54 | )
55 | if len(missing) > 0:
56 | print(f"Missing Keys: {missing}")
57 | if len(unexpected) > 0:
58 | print(f"Unexpected Keys: {unexpected}")
59 | if verbose:
60 | print("")
61 |
--------------------------------------------------------------------------------
/fit/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim import Optimizer
2 | from torch.optim.lr_scheduler import LambdaLR
3 |
4 |
5 | # coding=utf-8
6 | # Copyright 2023 The HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | """PyTorch optimization for diffusion models."""
20 |
21 | import math
22 | from enum import Enum
23 | from typing import Optional, Union
24 |
25 | from torch.optim import Optimizer
26 | from torch.optim.lr_scheduler import LambdaLR
27 |
28 |
29 | class SchedulerType(Enum):
30 | LINEAR = "linear"
31 | COSINE = "cosine"
32 | COSINE_WITH_RESTARTS = "cosine_with_restarts"
33 | POLYNOMIAL = "polynomial"
34 | CONSTANT = "constant"
35 | CONSTANT_WITH_WARMUP = "constant_with_warmup"
36 | PIECEWISE_CONSTANT = "piecewise_constant"
37 |     WARMUP_STABLE_DECAY = "warmup_stable_decay"
38 |
39 |
40 | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
41 | """
42 | Create a schedule with a constant learning rate, using the learning rate set in optimizer.
43 |
44 | Args:
45 | optimizer ([`~torch.optim.Optimizer`]):
46 | The optimizer for which to schedule the learning rate.
47 | last_epoch (`int`, *optional*, defaults to -1):
48 | The index of the last epoch when resuming training.
49 |
50 | Return:
51 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
52 | """
53 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
54 |
55 | def get_constant_schedule_with_warmup(
56 | optimizer: Optimizer, num_warmup_steps: int, div_factor: int = 1e-4, last_epoch: int = -1
57 | ):
58 | def lr_lambda(current_step):
59 | # 0,y0 step,y1
60 | #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1
61 | if current_step < num_warmup_steps:
62 | return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor
63 | return 1.0
64 |
65 | return LambdaLR(optimizer, lr_lambda, last_epoch)
66 |
67 | def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
68 | """
69 |     Create a schedule with a piecewise constant learning rate, following the multipliers given in `step_rules`.
70 |
71 | Args:
72 | optimizer ([`~torch.optim.Optimizer`]):
73 | The optimizer for which to schedule the learning rate.
74 | step_rules (`string`):
75 |             The rules for the learning rate. e.g. step_rules="1:10,0.1:20,0.01:30,0.005" means that the learning rate
76 |             is multiplied by 1 for the first 10 steps, by 0.1 for the next 20 steps, by 0.01 for the next 30
77 |             steps, and by 0.005 for the remaining steps.
78 | last_epoch (`int`, *optional*, defaults to -1):
79 | The index of the last epoch when resuming training.
80 |
81 | Return:
82 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
83 | """
84 |
85 | rules_dict = {}
86 | rule_list = step_rules.split(",")
87 | for rule_str in rule_list[:-1]:
88 | value_str, steps_str = rule_str.split(":")
89 | steps = int(steps_str)
90 | value = float(value_str)
91 | rules_dict[steps] = value
92 | last_lr_multiple = float(rule_list[-1])
93 |
94 | def create_rules_function(rules_dict, last_lr_multiple):
95 | def rule_func(steps: int) -> float:
96 | sorted_steps = sorted(rules_dict.keys())
97 | for i, sorted_step in enumerate(sorted_steps):
98 | if steps < sorted_step:
99 | return rules_dict[sorted_steps[i]]
100 | return last_lr_multiple
101 |
102 | return rule_func
103 |
104 | rules_func = create_rules_function(rules_dict, last_lr_multiple)
105 |
106 | return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
107 |
108 |
109 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
110 | """
111 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
112 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
113 |
114 | Args:
115 | optimizer ([`~torch.optim.Optimizer`]):
116 | The optimizer for which to schedule the learning rate.
117 | num_warmup_steps (`int`):
118 | The number of steps for the warmup phase.
119 | num_training_steps (`int`):
120 | The total number of training steps.
121 | last_epoch (`int`, *optional*, defaults to -1):
122 | The index of the last epoch when resuming training.
123 |
124 | Return:
125 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
126 | """
127 |
128 | def lr_lambda(current_step: int):
129 | if current_step < num_warmup_steps:
130 | return float(current_step) / float(max(1, num_warmup_steps))
131 | return max(
132 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
133 | )
134 |
135 | return LambdaLR(optimizer, lr_lambda, last_epoch)
136 |
137 |
138 | def get_cosine_schedule_with_warmup(
139 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
140 | ):
141 | """
142 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
143 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
144 | initial lr set in the optimizer.
145 |
146 | Args:
147 | optimizer ([`~torch.optim.Optimizer`]):
148 | The optimizer for which to schedule the learning rate.
149 | num_warmup_steps (`int`):
150 | The number of steps for the warmup phase.
151 | num_training_steps (`int`):
152 | The total number of training steps.
153 |         num_cycles (`float`, *optional*, defaults to 0.5):
154 | The number of periods of the cosine function in a schedule (the default is to just decrease from the max
155 | value to 0 following a half-cosine).
156 | last_epoch (`int`, *optional*, defaults to -1):
157 | The index of the last epoch when resuming training.
158 |
159 | Return:
160 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
161 | """
162 |
163 | def lr_lambda(current_step):
164 | if current_step < num_warmup_steps:
165 | return float(current_step) / float(max(1, num_warmup_steps))
166 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
167 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
168 |
169 | return LambdaLR(optimizer, lr_lambda, last_epoch)
170 |
171 |
172 | def get_cosine_with_hard_restarts_schedule_with_warmup(
173 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
174 | ):
175 | """
176 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
177 | initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
178 | linearly between 0 and the initial lr set in the optimizer.
179 |
180 | Args:
181 | optimizer ([`~torch.optim.Optimizer`]):
182 | The optimizer for which to schedule the learning rate.
183 | num_warmup_steps (`int`):
184 | The number of steps for the warmup phase.
185 | num_training_steps (`int`):
186 | The total number of training steps.
187 | num_cycles (`int`, *optional*, defaults to 1):
188 | The number of hard restarts to use.
189 | last_epoch (`int`, *optional*, defaults to -1):
190 | The index of the last epoch when resuming training.
191 |
192 | Return:
193 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
194 | """
195 |
196 | def lr_lambda(current_step):
197 | if current_step < num_warmup_steps:
198 | return float(current_step) / float(max(1, num_warmup_steps))
199 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
200 | if progress >= 1.0:
201 | return 0.0
202 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
203 |
204 | return LambdaLR(optimizer, lr_lambda, last_epoch)
205 |
206 |
207 | def get_polynomial_decay_schedule_with_warmup(
208 | optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
209 | ):
210 | """
211 | Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
212 | optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
213 | initial lr set in the optimizer.
214 |
215 | Args:
216 | optimizer ([`~torch.optim.Optimizer`]):
217 | The optimizer for which to schedule the learning rate.
218 | num_warmup_steps (`int`):
219 | The number of steps for the warmup phase.
220 | num_training_steps (`int`):
221 | The total number of training steps.
222 | lr_end (`float`, *optional*, defaults to 1e-7):
223 | The end LR.
224 | power (`float`, *optional*, defaults to 1.0):
225 | Power factor.
226 | last_epoch (`int`, *optional*, defaults to -1):
227 | The index of the last epoch when resuming training.
228 |
229 | Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
230 | implementation at
231 | https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
232 |
233 | Return:
234 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
235 |
236 | """
237 |
238 | lr_init = optimizer.defaults["lr"]
239 | if not (lr_init > lr_end):
240 |         raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
241 |
242 | def lr_lambda(current_step: int):
243 | if current_step < num_warmup_steps:
244 | return float(current_step) / float(max(1, num_warmup_steps))
245 | elif current_step > num_training_steps:
246 | return lr_end / lr_init # as LambdaLR multiplies by lr_init
247 | else:
248 | lr_range = lr_init - lr_end
249 | decay_steps = num_training_steps - num_warmup_steps
250 | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
251 | decay = lr_range * pct_remaining**power + lr_end
252 | return decay / lr_init # as LambdaLR multiplies by lr_init
253 |
254 | return LambdaLR(optimizer, lr_lambda, last_epoch)
255 |
256 |
257 | def get_constant_schedule_with_warmup_and_decay(
258 | optimizer: Optimizer, num_warmup_steps: int, num_decay_steps: int, decay_T: int = 50000, div_factor: int = 1e-4, last_epoch: int = -1
259 | ):
260 | def lr_lambda(current_step):
261 | # 0,y0 step,y1
262 | #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1
263 | if current_step < num_warmup_steps:
264 | return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor
265 | if current_step > num_decay_steps:
266 | return 0.5 ** ((current_step - num_decay_steps) / decay_T)
267 | return 1.0
268 |
269 | return LambdaLR(optimizer, lr_lambda, last_epoch)
270 |
271 | TYPE_TO_SCHEDULER_FUNCTION = {
272 | SchedulerType.LINEAR: get_linear_schedule_with_warmup,
273 | SchedulerType.COSINE: get_cosine_schedule_with_warmup,
274 | SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
275 | SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
276 | SchedulerType.CONSTANT: get_constant_schedule,
277 | SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
278 | SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
279 |     SchedulerType.WARMUP_STABLE_DECAY: get_constant_schedule_with_warmup_and_decay
280 | }
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 | def get_scheduler(
289 | name: Union[str, SchedulerType],
290 | optimizer: Optimizer,
291 | step_rules: Optional[str] = None,
292 | num_warmup_steps: Optional[int] = None,
293 | num_decay_steps: Optional[int] = None,
294 | num_training_steps: Optional[int] = None,
295 | num_cycles: int = 1,
296 | decay_T: Optional[int] = 50000,
297 | power: float = 1.0,
298 | last_epoch: int = -1,
299 | ):
300 | """
301 | Unified API to get any scheduler from its name.
302 |
303 | Args:
304 | name (`str` or `SchedulerType`):
305 | The name of the scheduler to use.
306 | optimizer (`torch.optim.Optimizer`):
307 | The optimizer that will be used during training.
308 | step_rules (`str`, *optional*):
309 | A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
310 | num_warmup_steps (`int`, *optional*):
311 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being
312 | optional), the function will raise an error if it's unset and the scheduler type requires it.
313 | num_decay_steps (`int`, *optional*):
314 | The number of decay steps to do. This is not required by all schedulers (hence the argument being
315 | optional), the function will raise an error if it's unset and the scheduler type requires it.
316 |         num_training_steps (`int`, *optional*):
317 | The number of training steps to do. This is not required by all schedulers (hence the argument being
318 | optional), the function will raise an error if it's unset and the scheduler type requires it.
319 | num_cycles (`int`, *optional*):
320 | The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
321 | power (`float`, *optional*, defaults to 1.0):
322 | Power factor. See `POLYNOMIAL` scheduler
323 | decay_T (`int`, *optional*, defaults to 50000):
324 |             Half-life (in steps) of the exponential decay used by the `WARMUP_STABLE_DECAY` scheduler.
325 | last_epoch (`int`, *optional*, defaults to -1):
326 | The index of the last epoch when resuming training.
327 | """
328 | name = SchedulerType(name)
329 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
330 | if name == SchedulerType.CONSTANT:
331 | return schedule_func(optimizer, last_epoch=last_epoch)
332 |
333 | if name == SchedulerType.PIECEWISE_CONSTANT:
334 | return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
335 |
336 | # All other schedulers require `num_warmup_steps`
337 | if num_warmup_steps is None:
338 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
339 |
340 | if name == SchedulerType.CONSTANT_WITH_WARMUP:
341 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
342 |
343 |     if name == SchedulerType.WARMUP_STABLE_DECAY:
344 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_decay_steps=num_decay_steps, decay_T=decay_T, last_epoch=last_epoch)
345 |
346 | # All other schedulers require `num_training_steps`
347 | if num_training_steps is None:
348 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
349 |
350 | if name == SchedulerType.COSINE_WITH_RESTARTS:
351 | return schedule_func(
352 | optimizer,
353 | num_warmup_steps=num_warmup_steps,
354 | num_training_steps=num_training_steps,
355 | num_cycles=num_cycles,
356 | last_epoch=last_epoch,
357 | )
358 |
359 | if name == SchedulerType.POLYNOMIAL:
360 | return schedule_func(
361 | optimizer,
362 | num_warmup_steps=num_warmup_steps,
363 | num_training_steps=num_training_steps,
364 | power=power,
365 | last_epoch=last_epoch,
366 | )
367 |
368 | return schedule_func(
369 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
370 | )
371 |
--------------------------------------------------------------------------------
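A minimal sketch of the warmup_stable_decay schedule, assuming an illustrative optimizer, learning rate, and step counts: the multiplier warms up linearly from div_factor to 1, stays at 1, then halves every decay_T steps once the step count exceeds num_decay_steps.

import torch
from fit.utils.lr_scheduler import get_scheduler

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = get_scheduler(
    "warmup_stable_decay", opt,
    num_warmup_steps=100, num_decay_steps=1000, decay_T=500,
)
for step in range(1500):
    opt.step()
    sched.step()
# after 1500 steps the multiplier is 0.5 ** ((1500 - 1000) / 500) = 0.5
print(opt.param_groups[0]["lr"])  # ~5e-05

--------------------------------------------------------------------------------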
/fit/utils/sit_eval_utils.py:
--------------------------------------------------------------------------------
1 | def none_or_str(value):
2 | if value == 'None':
3 | return None
4 | return value
5 |
6 | def parse_sde_args(parser):
7 | group = parser.add_argument_group("SDE arguments")
8 | group.add_argument("--sde-sampling-method", type=str, default="Euler", choices=["Euler", "Heun"])
9 | group.add_argument("--diffusion-form", type=str, default="sigma", \
10 | choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\
11 | help="form of diffusion coefficient in the SDE")
12 | group.add_argument("--diffusion-norm", type=float, default=1.0)
13 | group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\
14 | help="form of last step taken in the SDE")
15 | group.add_argument("--last-step-size", type=float, default=0.04, \
16 | help="size of the last step taken")
17 |
18 | def parse_ode_args(parser):
19 | group = parser.add_argument_group("ODE arguments")
20 | group.add_argument("--ode-sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq")
21 | group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance")
22 | group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance")
23 | group.add_argument("--reverse", action="store_true")
24 | group.add_argument("--likelihood", action="store_true")
25 |
26 | # ode solvers:
27 | # - Adaptive-step:
28 | # - dopri8 Runge-Kutta 7(8) of Dormand-Prince-Shampine
29 | # - dopri5 Runge-Kutta 4(5) of Dormand-Prince [default].
30 | # - bosh3 Runge-Kutta 2(3) of Bogacki-Shampine
31 | # - adaptive_heun Runge-Kutta 1(2)
32 | # - Fixed-step:
33 | # - euler Euler method.
34 | # - midpoint Midpoint method.
35 | # - rk4 Fourth-order Runge-Kutta with 3/8 rule.
36 | # - explicit_adams Explicit Adams.
37 | # - implicit_adams Implicit Adams.
--------------------------------------------------------------------------------
/fit/utils/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from collections import OrderedDict
3 | from inspect import isfunction
4 | import torch
5 | def get_obj_from_str(string, reload=False, invalidate_cache=True):
6 | module, cls = string.rsplit(".", 1)
7 | if invalidate_cache:
8 | importlib.invalidate_caches()
9 | if reload:
10 | module_imp = importlib.import_module(module)
11 | importlib.reload(module_imp)
12 | return getattr(importlib.import_module(module, package=None), cls)
13 |
14 |
15 | def instantiate_from_config(config):
16 | if not "target" in config:
17 | if config == "__is_first_stage__":
18 | return None
19 | elif config == "__is_unconditional__":
20 | return None
21 | raise KeyError("Expected key `target` to instantiate.")
22 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
23 |
24 |
25 | @torch.no_grad()
26 | def update_ema(ema_model, model, decay=0.9999):
27 | """
28 | Step the EMA model towards the current model.
29 | """
30 | if hasattr(model, 'module'):
31 | model = model.module
32 | if hasattr(ema_model, 'module'):
33 | ema_model = ema_model.module
34 | ema_params = OrderedDict(ema_model.named_parameters())
35 | model_params = OrderedDict(model.named_parameters())
36 |
37 | for name, param in model_params.items():
38 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
39 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
40 | 
41 | def exists(val):  # helper required by default() below
42 |     return val is not None
43 | def default(val, d):
44 | if exists(val):
45 | return val
46 | return d() if isfunction(d) else d
47 |
48 |
--------------------------------------------------------------------------------
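A minimal sketch of instantiate_from_config, which builds an object from a {target, params} dict in the same way the YAML configs specify network_config; the target below is a standard torch class chosen purely for illustration.

from fit.utils.utils import instantiate_from_config

config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(8, 4)

--------------------------------------------------------------------------------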
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested
2 | transformers>=4.44.2 # The development team is working on version 4.44.2
3 | accelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested
4 | sentencepiece>=0.2.0 # T5 used
5 | numpy==1.26.0
6 | streamlit>=1.38.0 # For streamlit web demo
7 | imageio==2.34.2 # For diffusers inference export video
8 | imageio-ffmpeg==0.5.1 # For diffusers inference export video
9 | openai>=1.42.0 # For prompt refiner
10 | moviepy==1.0.3 # For export video
11 | pillow==9.5.0
12 | timm==0.6.13
13 | safetensors==0.4.5
14 | einops
--------------------------------------------------------------------------------
/sample_fit_ddp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Samples a large number of images from a pre-trained DiT model using DDP.
9 | Subsequently saves a .npz file that can be used to compute FID and other
10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
11 |
12 | For a simple single-GPU/CPU sampling script, see sample.py.
13 | """
14 | import os
15 | import math
16 | import torch
17 | import argparse
18 | import numpy as np
19 | import torch.distributed as dist
20 | import re
21 |
22 | from omegaconf import OmegaConf
23 | from tqdm import tqdm
24 | from PIL import Image
25 | from diffusers.models import AutoencoderKL
26 | from fit.scheduler.improved_diffusion import create_diffusion
27 | from fit.utils.eval_utils import create_npz_from_sample_folder, init_from_ckpt
28 | from fit.utils.utils import instantiate_from_config
29 | 
30 |
31 | def ntk_scaled_init(head_dim, base=10000, alpha=8):
32 | #The method is just these two lines
33 | dim_h = head_dim // 2 # for x and y
34 | base = base * alpha ** (dim_h / (dim_h-2)) #Base change formula
35 | return base
36 |
37 | def main(args):
38 | """
39 | Run sampling.
40 | """
41 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
42 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
43 | torch.set_grad_enabled(False)
44 |
45 | # Setup DDP:
46 | dist.init_process_group("nccl")
47 | rank = dist.get_rank()
48 | device = rank % torch.cuda.device_count()
49 | seed = args.global_seed * dist.get_world_size() + rank
50 | torch.manual_seed(seed)
51 | torch.cuda.set_device(device)
52 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
53 |
54 | if args.mixed == "fp32":
55 | weight_dtype = torch.float32
56 | elif args.mixed == "bf16":
57 | weight_dtype = torch.bfloat16
58 |
59 | if args.cfgdir == "":
60 | args.cfgdir = os.path.join(args.ckpt.split("/")[0], args.ckpt.split("/")[1], "configs/config.yaml")
61 | print("config dir: ",args.cfgdir)
62 | config = OmegaConf.load(args.cfgdir)
63 | config_diffusion = config.diffusion
64 |
65 |
66 | H, W = args.image_height // 8, args.image_width // 8
67 | patch_size = config_diffusion.network_config.params.patch_size
68 | n_patch_h, n_patch_w = H // patch_size, W // patch_size
69 |
70 | if args.interpolation != 'no':
71 |         # sqrt(256) or sqrt(512); we set the max PE length for inference (some of these PEs have already been seen during training).
72 | ori_max_pe_len = int(config_diffusion.network_config.params.context_size ** 0.5)
73 |         if args.interpolation == 'linear': # positional index interpolation; previously called 'normal', now 'linear'
74 | config_diffusion.network_config.params['custom_freqs'] = 'linear'
75 |         elif args.interpolation == 'dynntk': # i.e. ntk-aware
76 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware'
77 |         elif args.interpolation == 'partntk': # i.e. ntk-by-parts
78 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-by-parts'
79 | elif args.interpolation == 'yarn':
80 | config_diffusion.network_config.params['custom_freqs'] = 'yarn'
81 | else:
82 | raise NotImplementedError
83 | config_diffusion.network_config.params['max_pe_len_h'] = n_patch_h
84 | config_diffusion.network_config.params['max_pe_len_w'] = n_patch_w
85 | config_diffusion.network_config.params['decouple'] = args.decouple
86 | config_diffusion.network_config.params['ori_max_pe_len'] = int(ori_max_pe_len)
87 |
88 | else: # there is no need to do interpolation!
89 | pass
90 |
91 | model = instantiate_from_config(config_diffusion.network_config).to(device, dtype=weight_dtype)
92 | init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True)
93 | model.eval() # important
94 |
95 | # prepare first stage model
96 | if args.vae_decoder == 'sd-ft-mse':
97 | vae_model = 'stabilityai/sd-vae-ft-mse'
98 | elif args.vae_decoder == 'sd-ft-ema':
99 | vae_model = 'stabilityai/sd-vae-ft-ema'
100 | vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(device, dtype=weight_dtype)
101 | vae.eval() # important
102 |
103 |
104 | config_diffusion.improved_diffusion.timestep_respacing = str(args.num_sampling_steps)
105 | diffusion = create_diffusion(**OmegaConf.to_container(config_diffusion.improved_diffusion))
106 |
107 |
108 | workdir_name = 'official_fit'
109 | folder_name = f'{args.ckpt.split("/")[-1].split(".")[0]}'
110 |
111 | sample_folder_dir = f"{args.sample_dir}/{workdir_name}/{folder_name}"
112 | if rank == 0:
113 | os.makedirs(os.path.join(args.sample_dir, workdir_name), exist_ok=True)
114 | os.makedirs(sample_folder_dir, exist_ok=True)
115 | print(f"Saving .png samples at {sample_folder_dir}")
116 | dist.barrier()
117 | args.cfg_scale = float(args.cfg_scale)
118 |     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
119 | using_cfg = args.cfg_scale > 1.0
120 |
121 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
122 | n = args.per_proc_batch_size
123 | global_batch_size = n * dist.get_world_size()
124 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
125 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
126 | if rank == 0:
127 | print(f"Total number of images that will be sampled: {total_samples}")
128 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
129 | samples_needed_this_gpu = int(total_samples // dist.get_world_size())
130 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
131 | iterations = int(samples_needed_this_gpu // n)
132 | pbar = range(iterations)
133 | pbar = tqdm(pbar) if rank == 0 else pbar
134 | total = 0
135 | index = 0
136 | all_images = []
137 | while len(all_images) * n < int(args.num_fid_samples):
138 | print(device, "device: ", index, flush=True)
139 | index+=1
140 | # Sample inputs:
141 | z = torch.randn(
142 | (n, (patch_size**2)*model.in_channels, n_patch_h*n_patch_w)
143 | ).to(device=device, dtype=weight_dtype)
144 | y = torch.randint(0, args.num_classes, (n,), device=device)
145 |
146 | # prepare for x
147 | grid_h = torch.arange(n_patch_h, dtype=torch.float32)
148 | grid_w = torch.arange(n_patch_w, dtype=torch.float32)
149 | grid = torch.meshgrid(grid_w, grid_h, indexing='xy')
150 | grid = torch.cat(
151 | [grid[0].reshape(1,-1), grid[1].reshape(1,-1)], dim=0
152 | ).repeat(n,1,1).to(device=device, dtype=weight_dtype)
153 | mask = torch.ones(n, n_patch_h*n_patch_w).to(device=device, dtype=weight_dtype)
154 | size = torch.tensor((n_patch_h, n_patch_w)).repeat(n,1).to(device=device, dtype=torch.long)
155 | size = size[:, None, :]
156 |
157 |
158 |
159 | # Setup classifier-free guidance:
160 | if using_cfg:
161 | z = torch.cat([z, z], 0) # (B, patch_size**2 * C, N) -> (2B, patch_size**2 * C, N)
162 | y_null = torch.tensor([1000] * n, device=device)
163 | y = torch.cat([y, y_null], 0) # (B,) -> (2B, )
164 | grid = torch.cat([grid, grid], 0) # (B, 2, N) -> (2B, 2, N)
165 | mask = torch.cat([mask, mask], 0) # (B, N) -> (2B, N)
166 | model_kwargs = dict(y=y, grid=grid.long(), mask=mask, size=size, cfg_scale=args.cfg_scale)
167 | sample_fn = model.forward_with_cfg
168 | else:
169 | model_kwargs = dict(y=y, grid=grid.long(), mask=mask, size=size)
170 | sample_fn = model.forward
171 |
172 | # Sample images:
173 | samples = diffusion.p_sample_loop(
174 | sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
175 | )
176 | if using_cfg:
177 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
178 |
179 | samples = samples[..., : n_patch_h*n_patch_w]
180 | samples = model.unpatchify(samples, (H, W))
181 | samples = vae.decode(samples / vae.config.scaling_factor).sample
182 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous()
183 |
184 | # gather samples
185 | gathered_samples = [torch.zeros_like(samples) for _ in range(dist.get_world_size())]
186 | dist.all_gather(gathered_samples, samples) # gather not supported with NCCL
187 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
188 | torch.cuda.empty_cache()
189 | # Save samples to disk as individual .png files
190 | for i, sample in enumerate(samples.cpu().numpy()):
191 | index = i * dist.get_world_size() + rank + total
192 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
193 | total += global_batch_size
194 | if rank == 0:
195 | pbar.update()
196 |
197 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
198 | dist.barrier()
199 | if rank == 0:
200 | arr = np.concatenate(all_images, axis=0)
201 | arr = arr[: int(args.num_fid_samples)]
202 | npz_path = f"{sample_folder_dir}.npz"
203 | np.savez(npz_path, arr_0=arr)
204 | print(f"Saved .npz file to {npz_path} [shape={arr.shape}].")
205 | print("Done.")
206 | dist.barrier()
207 | dist.destroy_process_group()
208 |
209 | if __name__ == "__main__":
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument("--cfgdir", type=str, default="")
212 | parser.add_argument("--ckpt", type=str, default="")
213 | parser.add_argument("--sample-dir", type=str, default="workdir/eval")
214 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
215 | parser.add_argument("--num-fid-samples", type=int, default=50_000)
216 | parser.add_argument("--image-height", type=int, default=256)
217 | parser.add_argument("--image-width", type=int, default=256)
218 | parser.add_argument("--num-classes", type=int, default=1000)
219 | parser.add_argument("--vae-decoder", type=str, choices=['sd-ft-mse', 'sd-ft-ema'], default='sd-ft-ema')
220 | parser.add_argument("--cfg-scale", type=str, default='1.5')
221 | parser.add_argument("--num-sampling-steps", type=int, default=250)
222 | parser.add_argument("--global-seed", type=int, default=0)
223 | parser.add_argument("--interpolation", type=str, choices=['no', 'linear', 'yarn', 'dynntk', 'partntk'], default='no') # interpolation
224 | parser.add_argument("--decouple", default=False, action="store_true") # interpolation
225 | parser.add_argument("--tf32", action='store_true', default=True)
226 | parser.add_argument("--mixed", type=str, default="fp32")
227 | args = parser.parse_args()
228 | main(args)
229 |
--------------------------------------------------------------------------------
/sample_fitv2_ddp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Samples a large number of images from a pre-trained FiT model using DDP.
9 | Subsequently saves a .npz file that can be used to compute FID and other
10 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
11 |
12 | For a simple single-GPU/CPU sampling script, see sample.py.
13 | """
14 | import os
15 | import sys
16 | import math
17 | import torch
18 | import argparse
19 | import numpy as np
20 | import torch.distributed as dist
21 | import re
22 |
23 | from omegaconf import OmegaConf
24 | from tqdm import tqdm
25 | from PIL import Image
26 | from diffusers.models import AutoencoderKL
27 | from fit.scheduler.transport import create_transport, Sampler
28 | from fit.utils.eval_utils import create_npz_from_sample_folder, init_from_ckpt
29 | from fit.utils.utils import instantiate_from_config
30 | from fit.utils.sit_eval_utils import parse_sde_args, parse_ode_args
31 |
32 | def ntk_scaled_init(head_dim, base=10000, alpha=8):
33 | #The method is just these two lines
34 | dim_h = head_dim // 2 # for x and y
35 | base = base * alpha ** (dim_h / (dim_h-2)) #Base change formula
36 | return base
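# Worked example (added for illustration): with head_dim=64, dim_h = 32, so alpha=8
# rescales the RoPE base to 10000 * 8**(32/30) ≈ 9.2e4; a larger base stretches the
# lowest rotary frequencies so positions beyond the training resolution stay within
# the phase range seen during training.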
37 |
38 | def main(args):
39 | """
40 | Run sampling.
41 | """
42 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
43 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
44 | torch.set_grad_enabled(False)
45 |
46 | # Setup DDP:
47 | dist.init_process_group("nccl")
48 | rank = dist.get_rank()
49 | device = rank % torch.cuda.device_count()
50 | seed = args.global_seed * dist.get_world_size() + rank
51 | torch.manual_seed(seed)
52 | torch.cuda.set_device(device)
53 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
54 |
55 | if args.mixed == "fp32":
56 | weight_dtype = torch.float32
57 | elif args.mixed == "bf16":
58 | weight_dtype = torch.bfloat16
59 |
60 | if args.cfgdir == "":
61 | args.cfgdir = os.path.join(args.ckpt.split("/")[0], args.ckpt.split("/")[1], "configs/config.yaml")
62 | print("config dir: ",args.cfgdir)
63 | config = OmegaConf.load(args.cfgdir)
64 | config_diffusion = config.diffusion
65 |
66 |
67 | H, W = args.image_height // 8, args.image_width // 8
68 | patch_size = config_diffusion.network_config.params.patch_size
69 | n_patch_h, n_patch_w = H // patch_size, W // patch_size
70 |
71 | if args.interpolation != 'no':
72 | if args.interpolation == 'linear': # positional index interpolation; previously called 'normal', now 'linear'
73 | config_diffusion.network_config.params['custom_freqs'] = 'linear'
74 | elif args.interpolation == 'dynntk': # NTK-aware interpolation
75 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware'
76 | elif args.interpolation == 'ntkpro1':
77 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware-pro1'
78 | elif args.interpolation == 'ntkpro2':
79 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-aware-pro2'
80 | elif args.interpolation == 'partntk': # NTK-by-parts interpolation
81 | config_diffusion.network_config.params['custom_freqs'] = 'ntk-by-parts'
82 | elif args.interpolation == 'yarn':
83 | config_diffusion.network_config.params['custom_freqs'] = 'yarn'
84 | else:
85 | raise NotImplementedError
86 | config_diffusion.network_config.params['max_pe_len_h'] = n_patch_h
87 | config_diffusion.network_config.params['max_pe_len_w'] = n_patch_w
88 | config_diffusion.network_config.params['decouple'] = args.decouple
89 | config_diffusion.network_config.params['ori_max_pe_len'] = int(args.ori_max_pe_len)
90 |
91 | config_diffusion.network_config.params['online_rope'] = False
92 |
93 | else: # there is no need to do interpolation!
94 | config_diffusion.network_config.params['custom_freqs'] = 'normal'
95 | config_diffusion.network_config.params['online_rope'] = False
96 |
97 |
98 |
99 | model = instantiate_from_config(config_diffusion.network_config).to(device, dtype=weight_dtype)
100 | init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True)
101 | model.eval() # important
102 |
103 | # prepare first stage model
104 | if args.vae_decoder == 'sd-ft-mse':
105 | vae_model = 'stabilityai/sd-vae-ft-mse'
106 | elif args.vae_decoder == 'sd-ft-ema':
107 | vae_model = 'stabilityai/sd-vae-ft-ema'
108 | vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(device, dtype=weight_dtype)
109 | vae.eval() # important
110 |
111 |
112 | # prepare transport
113 | transport = create_transport(**OmegaConf.to_container(config_diffusion.transport)) # default: velocity;
114 | sampler = Sampler(transport)
115 | sampler_mode = args.sampler_mode
116 | if sampler_mode == "ODE":
117 | if args.likelihood:
118 | assert float(args.cfg_scale) == 1, "Likelihood is incompatible with guidance"  # cfg_scale is still a string here; it is cast to float later
119 | sample_fn = sampler.sample_ode_likelihood(
120 | sampling_method=args.ode_sampling_method,
121 | num_steps=args.num_sampling_steps,
122 | atol=args.atol,
123 | rtol=args.rtol,
124 | )
125 | else:
126 | sample_fn = sampler.sample_ode(
127 | sampling_method=args.ode_sampling_method,
128 | num_steps=args.num_sampling_steps,
129 | atol=args.atol,
130 | rtol=args.rtol,
131 | reverse=args.reverse
132 | )
133 | elif sampler_mode == "SDE":
134 | sample_fn = sampler.sample_sde(
135 | sampling_method=args.sde_sampling_method,
136 | diffusion_form=args.diffusion_form,
137 | diffusion_norm=args.diffusion_norm,
138 | last_step=args.last_step,
139 | last_step_size=args.last_step_size,
140 | num_steps=args.num_sampling_steps,
141 | )
142 | else:
143 | raise NotImplementedError
144 |
145 |
146 |
147 | workdir_name = 'official_fit'
148 | folder_name = f'{args.ckpt.split("/")[-1].split(".")[0]}'
149 |
150 |
151 | sample_folder_dir = f"{args.sample_dir}/{workdir_name}/{folder_name}"
152 | if rank == 0:
153 | os.makedirs(os.path.join(args.sample_dir, workdir_name), exist_ok=True)
154 | os.makedirs(sample_folder_dir, exist_ok=True)
155 | print(f"Saving .png samples at {sample_folder_dir}")
156 | dist.barrier()
157 | args.cfg_scale = float(args.cfg_scale)
158 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
159 | using_cfg = args.cfg_scale > 1.0
160 |
161 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
162 | n = args.per_proc_batch_size
163 | global_batch_size = n * dist.get_world_size()
164 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
165 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
166 | if rank == 0:
167 | print(f"Total number of images that will be sampled: {total_samples}")
168 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
169 | samples_needed_this_gpu = int(total_samples // dist.get_world_size())
170 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
171 | iterations = int(samples_needed_this_gpu // n)
172 | pbar = range(iterations)
173 | pbar = tqdm(pbar) if rank == 0 else pbar
174 | total = 0
175 | index = 0
176 | all_images = []
177 | while len(all_images) * n < int(args.num_fid_samples):
178 | print(device, "device: ", index, flush=True)
179 | index+=1
180 | # Sample inputs:
181 | z = torch.randn(
182 | (n, n_patch_h*n_patch_w, (patch_size**2)*model.in_channels)
183 | ).to(device=device, dtype=weight_dtype)
184 | y = torch.randint(0, args.num_classes, (n,), device=device)
185 |
186 | # prepare for x
187 | grid_h = torch.arange(n_patch_h, dtype=torch.long)
188 | grid_w = torch.arange(n_patch_w, dtype=torch.long)
189 | grid = torch.meshgrid(grid_w, grid_h, indexing='xy')
190 | grid = torch.cat(
191 | [grid[0].reshape(1,-1), grid[1].reshape(1,-1)], dim=0
192 | ).repeat(n,1,1).to(device=device, dtype=torch.long)
193 | mask = torch.ones(n, n_patch_h*n_patch_w).to(device=device, dtype=weight_dtype)
194 | size = torch.tensor((n_patch_h, n_patch_w)).repeat(n,1).to(device=device, dtype=torch.long)
195 | size = size[:, None, :]
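# For reference (comment added for clarity), the shapes at this point are:
#   z:    (B, N, patch_size**2 * C) noise tokens, with N = n_patch_h * n_patch_w
#   y:    (B,) class labels, grid: (B, 2, N) integer patch coordinates
#   mask: (B, N) all-ones validity mask, size: (B, 1, 2) holding (n_patch_h, n_patch_w)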
196 | # Setup classifier-free guidance:
197 | if using_cfg:
198 | z = torch.cat([z, z], 0) # (B, N, patch_size**2 * C) -> (2B, N, patch_size**2 * C)
199 | y_null = torch.tensor([1000] * n, device=device)
200 | y = torch.cat([y, y_null], 0) # (B,) -> (2B, )
201 | grid = torch.cat([grid, grid], 0) # (B, 2, N) -> (2B, 2, N)
202 | mask = torch.cat([mask, mask], 0) # (B, N) -> (2B, N)
203 | size = torch.cat([size, size], 0)
204 | model_kwargs = dict(y=y, grid=grid, mask=mask, size=size, cfg_scale=args.cfg_scale, scale_pow=args.scale_pow)
205 | model_fn = model.forward_with_cfg
206 | else:
207 | model_kwargs = dict(y=y, grid=grid, mask=mask, size=size)
208 | model_fn = model.forward
209 |
210 |
211 | # Sample images:
212 | samples = sample_fn(z, model_fn, **model_kwargs)[-1]
213 | if using_cfg:
214 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
215 |
216 | samples = samples[..., : n_patch_h*n_patch_w]
217 | samples = model.unpatchify(samples, (H, W))
218 | samples = vae.decode(samples / vae.config.scaling_factor).sample
219 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous()
220 |
221 | # gather samples
222 | gathered_samples = [torch.zeros_like(samples) for _ in range(dist.get_world_size())]
223 | dist.all_gather(gathered_samples, samples) # gather not supported with NCCL
224 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
225 | # Save samples to disk as individual .png files
226 | for i, sample in enumerate(samples.cpu().numpy()):
227 | img_index = i * dist.get_world_size() + rank + total  # use a separate name so the iteration counter `index` is not clobbered
228 | Image.fromarray(sample).save(f"{sample_folder_dir}/{img_index:06d}.png")
229 | total += global_batch_size
230 | if rank == 0:
231 | pbar.update()
232 |
233 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
234 | dist.barrier()
235 | if rank == 0:
236 | import time
237 | time.sleep(20)
238 | arr = np.concatenate(all_images, axis=0)
239 | arr = arr[: int(args.num_fid_samples)]
240 | npz_path = f"{sample_folder_dir}.npz"
241 | np.savez(npz_path, arr_0=arr)
242 | print(f"Saved .npz file to {npz_path} [shape={arr.shape}].")
243 | print("Done.")
244 | dist.barrier()
245 | dist.destroy_process_group()
246 |
247 | if __name__ == "__main__":
248 | parser = argparse.ArgumentParser()
249 | parser.add_argument("--cfgdir", type=str, default="")
250 | parser.add_argument("--ckpt", type=str, default="")
251 | parser.add_argument("--sample-dir", type=str, default="workdir/eval")
252 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
253 | parser.add_argument("--num-fid-samples", type=int, default=50_000)
254 | parser.add_argument("--image-height", type=int, default=256)
255 | parser.add_argument("--image-width", type=int, default=256)
256 | parser.add_argument("--num-classes", type=int, default=1000)
257 | parser.add_argument("--vae-decoder", type=str, choices=['sd-ft-mse', 'sd-ft-ema'], default='sd-ft-ema')
258 | parser.add_argument("--cfg-scale", type=str, default='1.0')
259 | parser.add_argument("--scale-pow", type=float, default=0.0)
260 | parser.add_argument("--num-sampling-steps", type=int, default=250)
261 | parser.add_argument("--global-seed", type=int, default=0)
262 | parser.add_argument("--interpolation", type=str, choices=['no', 'linear', 'yarn', 'dynntk', 'partntk', 'ntkpro1', 'ntkpro2'], default='no') # interpolation
263 | parser.add_argument("--ori-max-pe-len", default=None, type=int)
264 | parser.add_argument("--decouple", default=False, action="store_true") # interpolation
265 | parser.add_argument("--sampler-mode", default='SDE', choices=['SDE', 'ODE'])
266 | parser.add_argument("--tf32", action='store_true', default=True)
267 | parser.add_argument("--mixed", type=str, default="fp32")
268 | parser.add_argument("--save-images", action='store_true', default=False)
269 | parse_ode_args(parser)
270 | parse_sde_args(parser)
271 | args = parser.parse_args()
272 | main(args)
273 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | setup(
5 | name='fit',
6 | version='0.0.1',
7 | description='',
8 | packages=find_packages(),
9 | install_requires=[
10 | 'torch',
11 | 'numpy',
12 | 'tqdm',
13 | ],
14 | )
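# For local development the package is typically installed in editable mode:
#   pip install -e .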
--------------------------------------------------------------------------------
/tools/download_in1k_latents_1024.sh:
--------------------------------------------------------------------------------
1 | mkdir datasets
2 | cd datasets
3 | mkdir imagenet1k_latents_1024_sd_vae_ft_ema
4 | cd imagenet1k_latents_1024_sd_vae_ft_ema
5 |
6 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/from_16_to_1024.tar.gz.part_aa?download=true" -O from_16_to_1024.tar.gz.part_aa
7 |
8 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/from_16_to_1024.tar.gz.part_ab?download=true" -O from_16_to_1024.tar.gz.part_ab
9 |
10 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/from_16_to_1024.tar.gz.part_ac?download=true" -O from_16_to_1024.tar.gz.part_ac
11 |
12 | # from_16_to_1024.tar.gz is split into .part_aa, .part_ab and .part_ac;
13 | # concatenate the pieces before extracting instead of extracting a single part
14 | cat from_16_to_1024.tar.gz.part_a* | tar -xzv
17 |
18 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/greater_than_1024_crop.tar.gz?download=true" -O greater_than_1024_crop.tar.gz
19 |
20 | tar -xzvf greater_than_1024_crop.tar.gz
21 |
22 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet_features_1024_sd_vae_ft_ema/resolve/main/greater_than_1024_resize.tar.gz?download=true" -O greater_than_1024_resize.tar.gz
23 |
24 | tar -xzvf greater_than_1024_resize.tar.gz
25 |
26 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/less_than_16.tar.gz?download=true" -O less_than_16.tar.gz
27 |
28 | tar -xzvf less_than_16.tar.gz
29 |
--------------------------------------------------------------------------------
/tools/download_in1k_latents_256.sh:
--------------------------------------------------------------------------------
1 | mkdir datasets
2 | cd datasets
3 | mkdir imagenet1k_latents_256_sd_vae_ft_ema
4 | cd imagenet1k_latents_256_sd_vae_ft_ema
5 |
6 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/from_16_to_256.tar.gz?download=true" -O from_16_to_256.tar.gz
7 |
8 | tar -xzvf from_16_to_256.tar.gz
9 |
10 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/greater_than_256_crop.tar.gz?download=true" -O greater_than_256_crop.tar.gz
11 |
12 | tar -xzvf greater_than_256_crop.tar.gz
13 |
14 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/greater_than_256_resize.tar.gz?download=true" -O greater_than_256_resize.tar.gz
15 |
16 | tar -xzvf greater_than_256_resize.tar.gz
17 |
18 | wget -c "https://huggingface.co/datasets/InfImagine/imagenet1k_features_256_sd_vae_ft_ema/resolve/main/less_than_16.tar.gz?download=true" -O less_than_16.tar.gz
19 |
20 | tar -xzvf less_than_16.tar.gz
21 |
--------------------------------------------------------------------------------
/tools/train_fit_xl.sh:
--------------------------------------------------------------------------------
1 | JOB_NAME="train_fit_xl"
2 | NNODES=1
3 | GPUS_PER_NODE=8
4 | MASTER_ADDR="localhost"
5 | export MASTER_PORT=60563
6 |
7 | CMD=" \
8 | projects/FiT/FiT/train_fit.py \
9 | --project_name ${JOB_NAME} \
10 | --main_project_name image_generation \
11 | --seed 0 \
12 | --scale_lr \
13 | --allow_tf32 \
14 | --resume_from_checkpoint latest \
15 | --workdir workdir/fit_xl \
16 | --cfgdir "projects/FiT/FiT/configs/fit/config_fit_xl.yaml" \
17 | --use_ema
18 | "
19 | TORCHLAUNCHER="torchrun \
20 | --nnodes $NNODES \
21 | --nproc_per_node $GPUS_PER_NODE \
22 | --rdzv_id $RANDOM \
23 | --rdzv_backend c10d \
24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
25 | "
26 | bash -c "$TORCHLAUNCHER $CMD"
--------------------------------------------------------------------------------
/tools/train_fitv2_3B.sh:
--------------------------------------------------------------------------------
1 | JOB_NAME="train_fitv2_3B"
2 | NNODES=1
3 | GPUS_PER_NODE=8
4 | MASTER_ADDR="localhost"
5 | export MASTER_PORT=60563
6 |
7 | CMD=" \
8 | projects/FiT/FiT/train_fitv2.py \
9 | --project_name ${JOB_NAME} \
10 | --main_project_name image_generation \
11 | --seed 0 \
12 | --scale_lr \
13 | --allow_tf32 \
14 | --resume_from_checkpoint latest \
15 | --workdir workdir/fitv2_3B \
16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_3B.yaml" \
17 | --use_ema
18 | "
19 | TORCHLAUNCHER="torchrun \
20 | --nnodes $NNODES \
21 | --nproc_per_node $GPUS_PER_NODE \
22 | --rdzv_id $RANDOM \
23 | --rdzv_backend c10d \
24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
25 | "
26 | bash -c "$TORCHLAUNCHER $CMD"
--------------------------------------------------------------------------------
/tools/train_fitv2_hr_3B.sh:
--------------------------------------------------------------------------------
1 | JOB_NAME="train_fitv2_hr_3B"
2 | NNODES=1
3 | GPUS_PER_NODE=8
4 | MASTER_ADDR="localhost"
5 | export MASTER_PORT=60563
6 |
7 | CMD=" \
8 | projects/FiT/FiT/train_fitv2.py \
9 | --project_name ${JOB_NAME} \
10 | --main_project_name image_generation \
11 | --seed 0 \
12 | --scale_lr \
13 | --allow_tf32 \
14 | --resume_from_checkpoint latest \
15 | --workdir workdir/fitv2_hr_3B \
16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_hr_3B.yaml" \
17 | --use_ema
18 | "
19 | TORCHLAUNCHER="torchrun \
20 | --nnodes $NNODES \
21 | --nproc_per_node $GPUS_PER_NODE \
22 | --rdzv_id $RANDOM \
23 | --rdzv_backend c10d \
24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
25 | "
26 | bash -c "$TORCHLAUNCHER $CMD"
--------------------------------------------------------------------------------
/tools/train_fitv2_hr_xl.sh:
--------------------------------------------------------------------------------
1 | JOB_NAME="train_fitv2_hr_xl"
2 | NNODES=1
3 | GPUS_PER_NODE=8
4 | MASTER_ADDR="localhost"
5 | export MASTER_PORT=60563
6 |
7 | CMD=" \
8 | projects/FiT/FiT/train_fitv2.py \
9 | --project_name ${JOB_NAME} \
10 | --main_project_name image_generation \
11 | --seed 0 \
12 | --scale_lr \
13 | --allow_tf32 \
14 | --resume_from_checkpoint latest \
15 | --workdir workdir/fitv2_hr_xl \
16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_hr_xl.yaml" \
17 | --use_ema
18 | "
19 | TORCHLAUNCHER="torchrun \
20 | --nnodes $NNODES \
21 | --nproc_per_node $GPUS_PER_NODE \
22 | --rdzv_id $RANDOM \
23 | --rdzv_backend c10d \
24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
25 | "
26 | bash -c "$TORCHLAUNCHER $CMD"
--------------------------------------------------------------------------------
/tools/train_fitv2_xl.sh:
--------------------------------------------------------------------------------
1 | JOB_NAME="train_fitv2_xl"
2 | NNODES=1
3 | GPUS_PER_NODE=8
4 | MASTER_ADDR="localhost"
5 | export MASTER_PORT=60563
6 |
7 | CMD=" \
8 | projects/FiT/FiT/train_fitv2.py \
9 | --project_name ${JOB_NAME} \
10 | --main_project_name image_generation \
11 | --seed 0 \
12 | --scale_lr \
13 | --allow_tf32 \
14 | --resume_from_checkpoint latest \
15 | --workdir workdir/fitv2_xl \
16 | --cfgdir "projects/FiT/FiT/configs/fitv2/config_fitv2_xl.yaml" \
17 | --use_ema
18 | "
19 | TORCHLAUNCHER="torchrun \
20 | --nnodes $NNODES \
21 | --nproc_per_node $GPUS_PER_NODE \
22 | --rdzv_id $RANDOM \
23 | --rdzv_backend c10d \
24 | --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
25 | "
26 | bash -c "$TORCHLAUNCHER $CMD"
--------------------------------------------------------------------------------
/train_fitv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import datetime
5 | import time
6 | import torchvision
7 | import wandb
8 | import logging
9 | import math
10 | import shutil
11 | import accelerate
12 | import torch
13 | import torch.utils.checkpoint
14 | import diffusers
15 | import numpy as np
16 | import torch.nn.functional as F
17 |
18 | from functools import partial
19 | from torch.cuda import amp
20 | from omegaconf import OmegaConf
21 | from accelerate import Accelerator, skip_first_batches
22 | from accelerate.logging import get_logger
23 | from accelerate.state import AcceleratorState
24 | from accelerate.utils import ProjectConfiguration, set_seed, save, FullyShardedDataParallelPlugin
25 | from tqdm.auto import tqdm
26 | from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
27 | from safetensors import safe_open
28 | from safetensors.torch import load_file
29 | from copy import deepcopy
30 | from einops import rearrange
31 | from fit.scheduler.transport import create_transport
32 | from fit.utils.utils import (
33 | instantiate_from_config,
34 | default,
35 | get_obj_from_str,
36 | update_ema,
37 |
38 | )
39 | from fit.utils.eval_utils import init_from_ckpt
40 | from fit.utils.lr_scheduler import get_scheduler
41 |
42 | logger = get_logger(__name__, log_level="INFO")
43 |
44 | # For Omegaconf Tuple
45 | def resolve_tuple(*args):
46 | return tuple(args)
47 | OmegaConf.register_new_resolver("tuple", resolve_tuple)
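# With this resolver a YAML config can write, e.g., `some_size: ${tuple:256,256}`
# (illustrative key name), which OmegaConf resolves to the Python tuple (256, 256).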
48 |
49 | def parse_args():
50 | parser = argparse.ArgumentParser(description="Argument.")
51 | parser.add_argument(
52 | "--project_name",
53 | type=str,
54 | const=True,
55 | default="",
56 | nargs="?",
57 | help="if set, the log directory will be named after the project name",
58 | )
59 | parser.add_argument(
60 | "--main_project_name",
61 | type=str,
62 | default="image_generation",
63 | )
64 | parser.add_argument(
65 | "--workdir",
66 | type=str,
67 | default="workdir",
68 | help="workdir",
69 | )
70 | parser.add_argument( # if resuming, this can be left unset; the config will be loaded from the resume directory
71 | "--cfgdir",
72 | nargs="*",
73 | help="paths to base configs. Loaded from left-to-right. "
74 | "Parameters can be overwritten or added with command-line options of the form `--key value`.",
75 | default=list(),
76 | )
77 | parser.add_argument(
78 | "-s",
79 | "--seed",
80 | type=int,
81 | default=0,
82 | help="seed for seed_everything",
83 | )
84 | parser.add_argument(
85 | "--resume_from_checkpoint",
86 | type=str,
87 | default='latest',
88 | help=(
89 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
90 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
91 | ),
92 | )
93 | parser.add_argument(
94 | "--load_model_from_checkpoint",
95 | type=str,
96 | default=None,
97 | help=(
98 | "Whether training should be loaded from a pretrained model checkpoint."
99 | "Or you can set diffusion.pretrained_model_path in Config for loading!!!"
100 | ),
101 | )
102 | parser.add_argument(
103 | "--scale_lr",
104 | action="store_true",
105 | default=False,
106 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
107 | )
108 | parser.add_argument(
109 | "--allow_tf32",
110 | action="store_true",
111 | help=(
112 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
113 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
114 | ),
115 | )
116 | parser.add_argument(
117 | "--use_ema",
118 | action="store_true",
119 | default=True,
120 | help="Whether to use EMA model."
121 | )
122 | parser.add_argument(
123 | "--ema_decay",
124 | type=float,
125 | default=0.9999,
126 | help="The decay rate for ema."
127 | )
128 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
129 | args = parser.parse_args()
130 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
131 | if env_local_rank != -1 and env_local_rank != args.local_rank:
132 | args.local_rank = env_local_rank
133 | return args
134 |
135 |
136 | def main():
137 | args = parse_args()
138 |
139 | datenow = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
140 | project_name = None
141 | workdir = None
142 | workdirnow = None
143 | cfgdir = None
144 | ckptdir = None
145 | logging_dir = None
146 | imagedir = None
147 |
148 | if args.project_name:
149 | project_name = args.project_name
150 | if os.path.exists(os.path.join(args.workdir, project_name)): # resume into an existing workdir
151 | workdir=os.path.join(args.workdir, project_name)
152 | else: # create a new workdir
153 | workdir = os.path.join(args.workdir, project_name)
154 | # if accelerator.is_main_process:
155 | os.makedirs(workdir, exist_ok=True)
156 | workdirnow = workdir
157 |
158 | cfgdir = os.path.join(workdirnow, "configs")
159 | ckptdir = os.path.join(workdirnow, "checkpoints")
160 | logging_dir = os.path.join(workdirnow, "logs")
161 | imagedir = os.path.join(workdirnow, "images")
162 |
163 | # if accelerator.is_main_process:
164 | os.makedirs(cfgdir, exist_ok=True)
165 | os.makedirs(ckptdir, exist_ok=True)
166 | os.makedirs(logging_dir, exist_ok=True)
167 | os.makedirs(imagedir, exist_ok=True)
168 | if args.cfgdir:
169 | load_cfgdir = args.cfgdir
170 |
171 | # setup config
172 | configs_list = load_cfgdir # read config from a config dir
173 | configs = [OmegaConf.load(cfg) for cfg in configs_list]
174 | config = OmegaConf.merge(*configs)
175 | accelerate_cfg = config.accelerate
176 | diffusion_cfg = config.diffusion
177 | data_cfg = config.data
178 | grad_accu_steps = accelerate_cfg.gradient_accumulation_steps
179 |
180 | train_strtg_cfg = getattr(config, 'training_strategy', None)
181 | if train_strtg_cfg != None:
182 | warp_pos_idx = hasattr(train_strtg_cfg, 'warp_pos_idx')
183 | if warp_pos_idx:
184 | warp_pos_idx_fn = partial(warp_pos_idx_from_grid,
185 | shift=train_strtg_cfg.warp_pos_idx.shift,
186 | scale=train_strtg_cfg.warp_pos_idx.scale,
187 | max_len=train_strtg_cfg.warp_pos_idx.max_len
188 | )
189 |
190 | accelerator_project_cfg = ProjectConfiguration(project_dir=workdirnow, logging_dir=logging_dir)
191 |
192 | if getattr(accelerate_cfg, 'fsdp_config', None) != None:
193 | import functools
194 | from torch.distributed.fsdp.fully_sharded_data_parallel import (
195 | BackwardPrefetch, CPUOffload, ShardingStrategy, MixedPrecision, StateDictType, FullStateDictConfig, FullOptimStateDictConfig,
196 | )
197 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
198 | fsdp_cfg = accelerate_cfg.fsdp_config
199 | if accelerate_cfg.mixed_precision == "fp16":
200 | dtype = torch.float16
201 | elif accelerate_cfg.mixed_precision == "bf16":
202 | dtype = torch.bfloat16
203 | else:
204 | dtype = torch.float32
205 | fsdp_plugin = FullyShardedDataParallelPlugin(
206 | sharding_strategy = {
207 | 'FULL_SHARD': ShardingStrategy.FULL_SHARD,
208 | 'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
209 | 'NO_SHARD': ShardingStrategy.NO_SHARD,
210 | 'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,
211 | 'HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,
212 | }[fsdp_cfg.sharding_strategy],
213 | backward_prefetch = {
214 | 'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
215 | 'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
216 | }[fsdp_cfg.backward_prefetch],
217 | mixed_precision_policy = MixedPrecision(
218 | param_dtype=dtype,
219 | reduce_dtype=dtype,
220 | ),
221 | auto_wrap_policy = functools.partial(
222 | size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params
223 | ),
224 | cpu_offload = CPUOffload(offload_params=fsdp_cfg.cpu_offload),
225 | state_dict_type = {
226 | 'FULL_STATE_DICT': StateDictType.FULL_STATE_DICT,
227 | 'LOCAL_STATE_DICT': StateDictType.LOCAL_STATE_DICT,
228 | 'SHARDED_STATE_DICT': StateDictType.SHARDED_STATE_DICT
229 | }[fsdp_cfg.state_dict_type],
230 | state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
231 | optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
232 | limit_all_gathers = fsdp_cfg.limit_all_gathers, # False
233 | use_orig_params = fsdp_cfg.use_orig_params, # True
234 | sync_module_states = fsdp_cfg.sync_module_states, #True
235 | forward_prefetch = fsdp_cfg.forward_prefetch, # False
236 | activation_checkpointing = fsdp_cfg.activation_checkpointing, # False
237 | )
238 | else:
239 | fsdp_plugin = None
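# The accelerate.fsdp_config block consumed above is expected to look roughly like the
# following YAML (field names taken from the attribute accesses; values are illustrative,
# not the repository defaults):
#   fsdp_config:
#     sharding_strategy: FULL_SHARD      # FULL_SHARD / SHARD_GRAD_OP / NO_SHARD / HYBRID_SHARD / HYBRID_SHARD_ZERO2
#     backward_prefetch: BACKWARD_PRE    # or BACKWARD_POST
#     min_num_params: 100000000
#     cpu_offload: false
#     state_dict_type: FULL_STATE_DICT   # or LOCAL_STATE_DICT / SHARDED_STATE_DICT
#     limit_all_gathers: false
#     use_orig_params: true
#     sync_module_states: true
#     forward_prefetch: false
#     activation_checkpointing: false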
240 | accelerator = Accelerator(
241 | gradient_accumulation_steps=grad_accu_steps,
242 | mixed_precision=accelerate_cfg.mixed_precision,
243 | fsdp_plugin=fsdp_plugin,
244 | log_with=getattr(accelerate_cfg, 'logger', 'wandb'),
245 | project_config=accelerator_project_cfg,
246 | )
247 | device = accelerator.device
248 |
249 | # Make one log on every process with the configuration for debugging.
250 | logging.basicConfig(
251 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
252 | datefmt="%m/%d/%Y %H:%M:%S",
253 | level=logging.INFO,
254 | )
255 | logger.info(accelerator.state, main_process_only=False)
256 |
257 | if accelerator.is_local_main_process:
258 | File_handler = logging.FileHandler(os.path.join(logging_dir, project_name+"_"+datenow+".log"), encoding="utf-8")
259 | File_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
260 | File_handler.setLevel(logging.INFO)
261 | logger.logger.addHandler(File_handler)
262 |
263 | diffusers.utils.logging.set_verbosity_info()
264 | else:
265 | diffusers.utils.logging.set_verbosity_error()
268 |
269 | if args.seed is not None:
270 | set_seed(args.seed)
271 |
272 | if args.allow_tf32: # for A100
273 | torch.backends.cuda.matmul.allow_tf32 = True
274 | torch.backends.cudnn.allow_tf32 = True
275 |
276 | if args.scale_lr:
277 | learning_rate = (
278 | accelerate_cfg.learning_rate *
279 | grad_accu_steps *
280 | data_cfg.params.train.loader.batch_size * # local batch size per device
281 | accelerator.num_processes / accelerate_cfg.learning_rate_base_batch_size # global batch size
282 | )
283 | else:
284 | learning_rate = accelerate_cfg.learning_rate
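# Example of the scaling rule above (illustrative numbers): a base learning rate of 1e-4
# defined for learning_rate_base_batch_size=256, trained on 8 GPUs with a per-device
# batch size of 32 and no gradient accumulation, stays at 1e-4 * 1 * 32 * 8 / 256 = 1e-4;
# doubling the per-device batch size to 64 doubles it to 2e-4.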
285 |
286 |
287 | model = instantiate_from_config(diffusion_cfg.network_config).to(device=device)
288 | # update ema
289 | if args.use_ema:
290 | # ema_dtype = torch.float32
291 | if hasattr(model, 'module'):
292 | ema_model = deepcopy(model.module).to(device=device)
293 | else:
294 | ema_model = deepcopy(model).to(device=device)
295 | if getattr(diffusion_cfg, 'pretrain_config', None) != None: # transfer to a larger resolution
296 | if getattr(diffusion_cfg.pretrain_config, 'ema_ckpt', None) != None:
297 | init_from_ckpt(
298 | ema_model, checkpoint_dir=diffusion_cfg.pretrain_config.ema_ckpt,
299 | ignore_keys=diffusion_cfg.pretrain_config.ignore_keys, verbose=True
300 | )
301 | for p in ema_model.parameters():
302 | p.requires_grad = False
303 |
304 | if args.use_ema:
305 | model = accelerator.prepare_model(model, device_placement=False)
306 | ema_model = accelerator.prepare_model(ema_model, device_placement=False)
307 | else:
308 | model = accelerator.prepare_model(model, device_placement=False)
309 |
310 | # In SiT, we use transport instead of diffusion
311 | transport = create_transport(**OmegaConf.to_container(diffusion_cfg.transport)) # default: velocity;
312 | # schedule_sampler = create_named_schedule_sampler()
313 |
314 | # Setup Dataloader
315 | total_batch_size = data_cfg.params.train.loader.batch_size * accelerator.num_processes * grad_accu_steps
316 | global_steps = 0
317 | if args.resume_from_checkpoint:
318 | # normal read with safety check
319 | if args.resume_from_checkpoint != "latest":
320 | resume_from_path = os.path.basename(args.resume_from_checkpoint)
321 | else: # Get the most recent checkpoint
322 | dirs = os.listdir(ckptdir)
323 | dirs = [d for d in dirs if d.startswith("checkpoint")]
324 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
325 | resume_from_path = dirs[-1] if len(dirs) > 0 else None
326 |
327 | if resume_from_path is None:
328 | logger.info(
329 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
330 | )
331 | args.resume_from_checkpoint = None
332 | else:
333 | global_steps = int(resume_from_path.split("-")[1]) # global_steps does not account for gradient_accumulation_steps
334 | logger.info(f"Resuming from steps: {global_steps}")
335 |
336 | get_train_dataloader = instantiate_from_config(data_cfg)
337 | train_len = get_train_dataloader.train_len()
338 | train_dataloader = get_train_dataloader.train_dataloader(
339 | global_batch_size=total_batch_size, max_steps=accelerate_cfg.max_train_steps,
340 | resume_step=global_steps, seed=args.seed
341 | )
342 |
343 | # Setup optimizer and lr_scheduler
344 | if accelerator.is_main_process:
345 | for name, param in model.named_parameters():
346 | print(name, param.requires_grad)
347 | if getattr(diffusion_cfg, 'pretrain_config', None) != None: # transfer to a larger resolution
348 | params = filter(lambda p: p.requires_grad, model.parameters())
349 | else:
350 | params = list(model.parameters())
351 | optimizer_cfg = default(
352 | accelerate_cfg.optimizer, {"target": "torch.optim.AdamW"}
353 | )
354 | optimizer = get_obj_from_str(optimizer_cfg["target"])(
355 | params, lr=learning_rate, **optimizer_cfg.get("params", dict())
356 | )
357 | lr_scheduler = get_scheduler(
358 | accelerate_cfg.lr_scheduler,
359 | optimizer=optimizer,
360 | num_warmup_steps=accelerate_cfg.lr_warmup_steps,
361 | num_training_steps=accelerate_cfg.max_train_steps,
362 | )
363 |
364 | # Prepare Accelerate
365 | optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
366 | optimizer, train_dataloader, lr_scheduler
367 | )
368 |
369 | # We need to initialize the trackers we use, and also store our configuration.
370 | # The trackers initialize automatically on the main process.
371 | if accelerator.is_main_process and getattr(accelerate_cfg, 'logger', 'wandb') != None:
372 | os.environ["WANDB_DIR"] = os.path.join(os.getcwd(), workdirnow)
373 | accelerator.init_trackers(
374 | project_name=args.main_project_name,
375 | config=config,
376 | init_kwargs={"wandb": {"group": args.project_name}}
377 | )
378 |
379 | # Train!
380 | logger.info("***** Running training *****")
381 | logger.info(f" Num examples = {train_len}")
382 | logger.info(f" Instantaneous batch size per device = {data_cfg.params.train.loader.batch_size}")
383 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
384 | logger.info(f" Learning rate = {learning_rate}")
385 | logger.info(f" Gradient Accumulation steps = {grad_accu_steps}")
386 | logger.info(f" Total optimization steps = {accelerate_cfg.max_train_steps}")
387 | logger.info(f" Current optimization steps = {global_steps}")
388 | logger.info(f" Train dataloader length = {len(train_dataloader)} ")
389 | logger.info(f" Training Mixed-Precision = {accelerate_cfg.mixed_precision}")
390 |
391 | # Potentially load in the weights and states from a previous save
392 | if args.resume_from_checkpoint:
393 | # normal read with safety check
394 | error_times=0
395 | while True:
396 | if error_times >= 100:
397 | raise RuntimeError(f"Failed to resume from checkpoint {resume_from_path} after {error_times} attempts.")
398 | try:
399 | logger.info(f"Resuming from checkpoint {resume_from_path}")
400 | accelerator.load_state(os.path.join(ckptdir, resume_from_path))
401 | break
402 | except Exception as err:
403 | error_times+=1
404 | if accelerator.is_local_main_process:
405 | logger.warning(err)
406 | logger.warning(f"Failed to resume from checkpoint {resume_from_path}")
407 | shutil.rmtree(os.path.join(ckptdir, resume_from_path))
408 | else:
409 | time.sleep(2)
410 |
411 | # save config
412 | OmegaConf.save(config=config, f=os.path.join(cfgdir, "config.yaml"))
413 |
414 | # Only show the progress bar once on each machine.
415 | progress_bar = tqdm(
416 | range(0, accelerate_cfg.max_train_steps),
417 | disable = not accelerator.is_main_process
418 | )
419 | progress_bar.set_description("Optim Steps")
420 | progress_bar.update(global_steps)
421 |
422 | if args.use_ema:
423 | # ema_model = ema_model.to(ema_dtype)
424 | ema_model.eval()
425 | # Training Loop
426 | model.train()
427 | train_loss = 0.0
428 | for step, batch in enumerate(train_dataloader, start=global_steps):
429 | for batch_key in batch.keys():
430 | if not isinstance(batch[batch_key], list):
431 | batch[batch_key] = batch[batch_key].to(device=device)
432 | x = batch['feature'] # (B, N, C)
433 | grid = batch['grid'] # (B, 2, N)
434 | mask = batch['mask'] # (B, N)
435 | y = batch['label'] # (B, 1)
436 | size = batch['size'] # (B, N_pack, 2), order: h, w. When pack is not used, N_pack=1.
437 | with accelerator.accumulate(model):
438 | # save memory for x, grid, mask
439 | N_batch = int(torch.max(torch.sum(size[..., 0] * size[..., 1], dim=-1)))
440 | x, grid, mask = x[:, : N_batch], grid[..., : N_batch], mask[:, : N_batch]
441 |
442 | # prepare other parameters
443 | y = y.squeeze(dim=-1).to(torch.int)
444 | model_kwargs = dict(y=y, grid=grid.long(), mask=mask, size=size)
445 | # forward model and compute loss
446 | with accelerator.autocast():
447 | loss_dict = transport.training_losses(model, x, model_kwargs)
448 | loss = loss_dict["loss"].mean()
449 | # Backpropagate
450 | optimizer.zero_grad()
451 | accelerator.backward(loss)
452 | if accelerator.sync_gradients and accelerate_cfg.max_grad_norm > 0.:
453 | all_norm = accelerator.clip_grad_norm_(
454 | model.parameters(), accelerate_cfg.max_grad_norm
455 | )
456 | optimizer.step()
457 | lr_scheduler.step()
458 | # Gather the losses across all processes for logging (if we use distributed training).
459 | avg_loss = accelerator.gather(loss.repeat(data_cfg.params.train.loader.batch_size)).mean()
460 | train_loss += avg_loss.item() / grad_accu_steps
461 |
462 | # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation
463 | if accelerator.sync_gradients:
464 | if args.use_ema:
465 | # update_ema(ema_model, deepcopy(model).type(ema_dtype), args.ema_decay)
466 | update_ema(ema_model, model, args.ema_decay)
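# update_ema typically applies the exponential moving average
#   ema_param = decay * ema_param + (1 - decay) * param
# to every parameter, with decay defaulting to 0.9999 (--ema_decay).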
467 |
468 | progress_bar.update(1)
469 | global_steps += 1
470 | if getattr(accelerate_cfg, 'logger', 'wandb') != None:
471 | accelerator.log({"train_loss": train_loss}, step=global_steps)
472 | accelerator.log({"lr": lr_scheduler.get_last_lr()[0]}, step=global_steps)
473 | if accelerate_cfg.max_grad_norm != 0.0:
474 | accelerator.log({"grad_norm": all_norm.item()}, step=global_steps)
475 | train_loss = 0.0
476 | if global_steps % accelerate_cfg.checkpointing_steps == 0:
477 | if accelerate_cfg.checkpoints_total_limit is not None:
478 | checkpoints = os.listdir(ckptdir)
479 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
480 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
481 |
482 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
483 | if accelerator.is_main_process and len(checkpoints) >= accelerate_cfg.checkpoints_total_limit:
484 | num_to_remove = len(checkpoints) - accelerate_cfg.checkpoints_total_limit + 1
485 | removing_checkpoints = checkpoints[0:num_to_remove]
486 |
487 | logger.info(
488 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
489 | )
490 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
491 |
492 | for removing_checkpoint in removing_checkpoints:
493 | removing_checkpoint = os.path.join(ckptdir, removing_checkpoint)
494 | shutil.rmtree(removing_checkpoint)
495 |
496 | save_path = os.path.join(ckptdir, f"checkpoint-{global_steps}")
497 | if accelerator.is_main_process:
498 | os.makedirs(save_path)
499 | accelerator.wait_for_everyone()
500 | accelerator.save_state(save_path)
501 | logger.info(f"Saved state to {save_path}")
502 | accelerator.wait_for_everyone()
503 |
504 | if global_steps in accelerate_cfg.checkpointing_steps_list:
505 | save_path = os.path.join(ckptdir, f"save-checkpoint-{global_steps}")
506 | accelerator.wait_for_everyone()
507 | accelerator.save_state(save_path)
508 | logger.info(f"Saved state to {save_path}")
509 | accelerator.wait_for_everyone()
510 |
511 | logs = {"step_loss": loss.detach().item(),
512 | "lr": lr_scheduler.get_last_lr()[0]}
513 | progress_bar.set_postfix(**logs)
514 | if global_steps % accelerate_cfg.logging_steps == 0:
515 | if accelerator.is_main_process:
516 | logger.info("step="+str(global_steps)+" / total_step="+str(accelerate_cfg.max_train_steps)+", step_loss="+str(logs["step_loss"])+', lr='+str(logs["lr"]))
517 |
518 | if global_steps >= accelerate_cfg.max_train_steps:
519 | logger.info(f'global step ({global_steps}) >= max_train_steps ({accelerate_cfg.max_train_steps}), stop training!!!')
520 | break
521 | accelerator.wait_for_everyone()
522 | accelerator.end_training()
523 |
524 | if __name__ == "__main__":
525 | main()
--------------------------------------------------------------------------------